Compare commits

...

102 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod 135df0dcc0 format with Black 2022-12-10 09:40:12 -08:00
Gilles Boccon-Gibod 297246fa4c Merge pull request #92 from yuyangh/yuyangh/add_ASHA_GATT
add ASHA profile
2022-12-07 13:16:01 -08:00
Yuyang Huang 52db1cfcc1 improve code style 2022-12-06 07:38:05 -08:00
Yuyang Huang 29f9a79502 improve get service advertising data 2022-12-05 11:22:07 -08:00
Gilles Boccon-Gibod c86125de4f Merge pull request #93 from AlanRosenthal/alan/add_default_services
Add Device::add_default_services()
2022-12-01 12:36:03 -08:00
Yuyang Huang 697d5df3f8 code style update 2022-12-01 10:50:15 -08:00
Yuyang Huang 87aa4f617e add ASHA advertising factory method 2022-12-01 10:40:30 -08:00
Alan Rosenthal a8eff737e6 Add Device::add_default_services()
This will allow a test to:
a: add services to a device
b: reset services via `Server()`
c: add the default services back
2022-12-01 17:02:54 +00:00
Gilles Boccon-Gibod 4417eb636c Merge pull request #83 from AlanRosenthal/alan/pytest_fixes_2
Test all python versions in CI
2022-11-29 12:48:35 -08:00
Alan Rosenthal f4e5e61bbb Test all python versions in CI
Followed instructions here: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#using-the-python-starter-workflow
2022-11-29 20:22:56 +00:00
Lucas Abel ba7a60025f Merge pull request #89 from google/uael/misc
Typos & CI fixes
2022-11-29 12:14:03 -08:00
Yuyang Huang d92b7e9b74 add ASHA Profile 2022-11-29 10:34:59 -08:00
Yuyang Huang b0336adf1c add ASHA GATT UUID 2022-11-29 10:02:40 -08:00
Abel Lucas 691450c7de gatt: fix CharacteristicDeclaration.__str__ and associated test 2022-11-29 16:43:47 +00:00
Abel Lucas 99a0eb21c1 address: fix deprecated use of combined @classmethod and @property 2022-11-29 16:33:12 +00:00
Abel Lucas ab4859bd94 device: fix typos 2022-11-29 16:33:12 +00:00
Lucas Abel 0d70cbde64 Merge pull request #75 from google/uael/fixes
Pairing: device/host fixes & improvements
2022-11-28 21:42:43 -08:00
Gilles Boccon-Gibod f41d0682b2 Merge pull request #80 from AlanRosenthal/alan/gatt_server_getter
Added class CharacteristicDeclaration, gatt_server getters
2022-11-28 19:21:08 -08:00
Gilles Boccon-Gibod 062dc1e53d Merge pull request #85 from AlanRosenthal/alan/gatt_server_console2
Add `bumble-console --device-config` support for gatt services
2022-11-28 19:19:25 -08:00
Abel Lucas 662704e551 classic: complete authentication when being the .authenticate acceptor 2022-11-29 00:28:39 +00:00
Abel Lucas 02a474c44e smp: emit enough information on pairing complete to deduce security level 2022-11-29 00:28:38 +00:00
Abel Lucas a1c7aec492 device: fix .find_connection_by_bd_addr 2022-11-29 00:28:38 +00:00
Abel Lucas 6112f00049 device: introduce BR/EDR pending connections
This commit enable the BR/EDR pairing to run asynchronously to
the connection being established.

When in security mode 3, a controller shall start authentication as
part of the connection, which result in HCI events being sent on a BD
address without a completed connection (ie. no connection handle).
2022-11-29 00:28:38 +00:00
Alan Rosenthal f56ac14f2c Add bumble-console --device-config support for gatt services
This PR adds support for bumble-console to be preloaded with gatt services via `--device-config`.
This PR also adds some type annotations
2022-11-28 14:11:27 -05:00
Gilles Boccon-Gibod a739fc71ce Merge pull request #84 from google/gbg/78
use libusb auto-detach feature
2022-11-27 12:42:02 -08:00
Alan Rosenthal b89f9030a0 Added class CharacteristicDeclaration, gatt_server getters
* Converted CharacteristicDeclaration implementation to class
* Added ability to get a gatt_server attribute by service UUID, characteristics UUID, descriptor UUID
2022-11-27 19:22:25 +00:00
Gilles Boccon-Gibod 9e5a85bd10 use libusb auto-detach feature 2022-11-25 17:52:13 -08:00
Gilles Boccon-Gibod b437bd8619 Merge pull request #82 from AlanRosenthal/alan/pytest_fixes
Fix test failures
2022-11-25 13:19:53 -08:00
Alan Rosenthal a3e4674819 Fix test failures
a. `DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead`
Updated call in `bumble/smp.py`

b. `ModuleNotFoundError: No module named 'bumble.apps'`
Updated imports in `tests/import_test.py`

c. Added `pytest-html` for easier viewing of test results
Added package in `setup.cfg`, and hook in `tasks.py`

d. Updated workflows to use `invoke test`

This is a partial fix of #81
2022-11-23 11:31:27 -05:00
Abel Lucas 5f1d57fcb0 device: simplify and fixes remote name request 2022-11-22 21:20:56 +00:00
Gilles Boccon-Gibod ae0b739e4a Merge pull request #79 from google/gbg/fix-host-reset
fix sequencing logic broken by earlier merge
2022-11-22 09:16:26 -08:00
Michael Mogenson 0570b59796 Merge pull request #77 from mogenson/l2cap-on-event
Fix for 'Host' object has no attribute 'add_listener'
2022-11-22 09:39:04 -05:00
Gilles Boccon-Gibod 22218627f6 fix sequencing logic broken by earlier merge 2022-11-21 21:07:47 -08:00
Michael Mogenson 1c72242264 Fix for 'Host' object has no attribute 'add_listener'
Pyee's add_listener() method was not added until release 9.0.0. Bumble's
setup.cfg specifies a minimum pyee version of 8.2.2.

Remove the call to add_listener() in l2cap.py. If the add_listener() API
is prefered over the on(), another solution would be to bump the pyee
version requirement.
2022-11-21 12:31:21 -05:00
Abel Lucas 9c133706e6 keys: add a way to remove all bonds from key store 2022-11-18 18:22:15 +00:00
Michael Mogenson 4988a31487 Merge pull request #76 from mogenson/connection-error-params
Swap arguments to ConnectionError in RFCOMM Multiplexer
2022-11-18 10:48:02 -05:00
Michael Mogenson e6c062117f Swap arguments to ConnectionError in RFCOMM Multiplexer
Minor fixup. Change the order of arguments to ConnectionError to set the
transport and address correctly in rfcomm.py on_dm_frame().
2022-11-18 10:02:40 -05:00
Gilles Boccon-Gibod f2133235d5 Merge pull request #73 from google/gbg/faster-l2cap-test
lower the number of test cases for l2cap in order to speed up the test
2022-11-15 10:49:55 -08:00
Gilles Boccon-Gibod 867e8c13dc lower the number of test cases for l2cap in order to speed up the test 2022-11-14 17:26:09 -08:00
Lucas Abel 25ce38c3f5 Merge pull request #72 from google/uael/public-str-address
address: add public information to the stringified value
2022-11-14 17:16:47 -08:00
Abel Lucas c0810230a6 address: add public information to the stringified value
This affect the way security keys are stored. For instance the same
key can be used both as public and random, and it need to be stored
separately one from each other.
2022-11-14 20:05:12 +00:00
Michael Mogenson 27c46eef9d Merge pull request #71 from mogenson/prefer-notify
Add prefer_notify option to gatt_client subscribe()
2022-11-13 19:53:09 -05:00
Michael Mogenson c140876157 Add prefer_notify option to gatt_client subscribe()
If characteristic supports Notify and Indicate, the prefer_notify option
will subscribe with Notify if True or Indicate if False.

If characteristic only supports one property, Notify or Indicate, that
mode will be selected, regardless of the prefer_notify setting.

Tested with a characteristic that supports both Notify and Indicate and
verified that prefer_notify sets the desired mode.
2022-11-13 19:38:12 -05:00
Lucas Abel d743656f09 Merge pull request #68 from google/uael/pairing-improvements
Pairing improvements
2022-11-11 21:03:17 -08:00
Abel Lucas b91d0e24c1 device: handle HCI passkey notification event 2022-11-11 18:43:35 +00:00
Abel Lucas eb46f60c87 le: save own_address_type on ACL connection for SMP to be able to use the right self address 2022-11-10 02:06:37 +00:00
Abel Lucas 8bbba7c84c pairing: always ask user for confirmation, even in JUST_WORKS method 2022-11-10 01:58:02 +00:00
Gilles Boccon-Gibod ee54df2d08 Merge pull request #65 from google/gbg/fix-classic-connect-await
fix classic connection event filtering
2022-11-09 14:40:29 -08:00
Gilles Boccon-Gibod 6549e53398 Merge pull request #60 from google/gbg/fix-console-logs
use a formatter object, not a string
2022-11-09 13:19:26 -08:00
Gilles Boccon-Gibod 0f219eff12 address PR comments 2022-11-09 13:18:30 -08:00
Gilles Boccon-Gibod 4a1345cf95 only force the type if the address is passed as a string 2022-11-08 19:10:13 -08:00
Gilles Boccon-Gibod 8a1cdef152 fix classic connection event filtering 2022-11-08 17:33:29 -08:00
Gilles Boccon-Gibod 6e1baf0344 use a formatter object, not a string 2022-11-08 13:19:41 -08:00
Lucas Abel cea1905ffb Merge pull request #59 from google/uael/device-cleanup
le: pass `own_address_type` to BLE `Device.connect`
2022-11-08 11:50:40 -08:00
Abel Lucas af8e0d4dc7 le: pass own_address_type to BLE Device.connect 2022-11-08 18:22:54 +00:00
Gilles Boccon-Gibod 875195aebb Merge pull request #58 from AlanRosenthal/main
Add definition of `Client Characteristic Configuration bit`
2022-11-08 09:34:22 -08:00
Gilles Boccon-Gibod 5aee37aeab Merge pull request #34 from google/gbg/l2cap-bridge
Add L2CAP CoC support
2022-11-07 16:57:17 -08:00
Gilles Boccon-Gibod edcb7d05d6 fix merge conflict 2022-11-07 16:51:40 -08:00
Gilles Boccon-Gibod ce9004f0ac Add L2CAP CoC support (squashed)
[85542e0] fix test
[3748781] add ASAH sink example
[e782e29] add app
[83daa30] wip
[7f138a0] add test
[f732108] allow different address syntax
[9d0bbf8] rename deprecated methods
[eb303d5] add LE CoC support
2022-11-07 16:45:37 -08:00
Alan Rosenthal d4228e3b5b Add definition of Client Characteristic Configuration bit 2022-11-07 19:43:22 -05:00
Lucas Abel be8f8ac68f Merge pull request #55 from google/uael/device-improvements
Device improvements
2022-11-07 15:22:41 -08:00
Abel Lucas ca16410a6d device: add option to check for the address type when using find_connection_by_bd_addr 2022-11-07 22:17:01 +00:00
Abel Lucas b95888eb39 le: permit legacy scanning even when extended is supported 2022-11-07 22:15:54 +00:00
Abel Lucas 56ed46adfa classic: add BR/EDR accept connection logic 2022-11-04 17:26:59 +00:00
Abel Lucas 7044102e05 classic: upgrade Device.cancel_connection logic to support canceling ongoing BR/EDR connections 2022-11-04 17:26:59 +00:00
Abel Lucas ca8f284888 le: add own_address_type parameter to Device.start_advertising 2022-11-04 17:26:59 +00:00
Abel Lucas e9e14f5183 le: make the device connecting state relative to LE only
We may need to add a distinct BR/EDR connecting state in the future.
2022-11-04 17:26:59 +00:00
Abel Lucas b961affd3d device: update Device.connect documentation to match BR/EDR behavior 2022-11-04 17:26:59 +00:00
Abel Lucas 51ddb36c91 device: add auto_restart mechanism to .start_discovery (default to True) 2022-11-04 17:26:59 +00:00
Abel Lucas 78534b659a device: enhance .request_remote_name to also accept an Address as argument 2022-11-04 17:26:59 +00:00
Abel Lucas ce9472bf42 core: change AdvertisingData.get default raw behavior to False 2022-11-04 17:26:59 +00:00
Abel Lucas fc331b7aea core: improve Advertisement.ad_data_to_object with support for more data types 2022-11-04 17:26:59 +00:00
Abel Lucas 8119bc210c host: pass remote_host_supported_features event to upper layer
The `HCI_Remote_Name_Request` command may trigger this HCI event.
Instead of warn for not being handled, pass it to upper layer.
2022-11-02 20:23:14 +00:00
Abel Lucas 65deefdc64 host: allow bytes return paramaters when checking command result 2022-11-02 20:23:14 +00:00
Michael Mogenson 2920c3b0d1 Merge pull request #53 from mogenson/mogenson/show-device-tab
Add a show device tab
2022-10-24 09:15:43 -04:00
Michael Mogenson f5cd825dbc Merge pull request #51 from mogenson/mogenson/console-py-rand-addr
Use random address in console.py if device config is not provided
2022-10-24 09:15:10 -04:00
Gilles Boccon-Gibod cf4c43c4ff Merge pull request #48 from google/uael/classic-parallel-connect
classic: update `Device.connect` to allow parallels connection creation
2022-10-23 20:52:08 -07:00
Gilles Boccon-Gibod da2f596e52 Merge pull request #50 from google/uael/command-timeout
device: raise a `CommandTimeoutError` error on command timeout
2022-10-23 20:49:06 -07:00
Gilles Boccon-Gibod c8aa0b4ef6 Merge pull request #54 from google/gbg/fix-regression-001
use the correct constants as previously renamed
2022-10-23 20:43:43 -07:00
Gilles Boccon-Gibod 75ac276c8b use the correct constants as previously renamed 2022-10-21 17:12:26 -07:00
Michael Mogenson dd4023ff56 Add a show device tab
Show configuration data about the Bumble device. Make this the default
tab on startup.
2022-10-21 16:00:03 -04:00
Michael Mogenson dde8c5e1c2 Use random address in console.py if device config is not provided
If a device configuration is not provided on startup, generate a random
BT address instead of using a default static value of
"F0:F1:F2:F3:F4:F5". This is helpful to avoid colisions when there are
two instances of console.py running nearby.

Testing:
Started console.py and began advertising a few times. Scanned from a
second instance of console.py and observed that the advertising address
changed with each restart.
2022-10-21 15:32:58 -04:00
Michael Mogenson 8ed1f4b50d Merge pull request #52 from mogenson/mogenson/console-py-clear-scan-results
add 'scan clear' command to console.py
2022-10-21 14:34:08 -04:00
Michael Mogenson 92de7dff4f add 'scan clear' command to console.py
Add command to clear scan results and known addresses. Useful for
determining if a peripheral has stopped advertising.

Also, check if a scan is in progress before connecting. If it is, stop
scanning. Some BT controllers will fail to connect while scanning.

Testing:
Can clear scan results before, during, and after scan. Can clear scan
results while disconnected and connected.
2022-10-21 13:58:21 -04:00
Abel Lucas 16b4f18c92 tests: add parallel device connection test 2022-10-21 15:49:03 +00:00
Gilles Boccon-Gibod 46f4b82d29 Merge pull request #46 from AlanRosenthal/main
Add runtime switch for filtering by address.
2022-10-20 19:20:28 -07:00
Abel Lucas 4e2f66f709 device: raise a CommandTimeoutError error on command timeout 2022-10-20 22:11:07 +00:00
Alan Rosenthal 3d79d7def5 Add runtime switch for filtering by address.
* scan on [filter pattern]
* filter address <filter pattern>
2022-10-20 14:47:14 -04:00
Abel Lucas 915405a9bd examples: update run_classic_connect example to take multiple addresses instead of one 2022-10-20 14:53:39 +00:00
Abel Lucas 45dd849d9f classic: update ConnectionError to take transport and peer address 2022-10-20 14:53:03 +00:00
Abel Lucas 7208fd6642 classic: update Device.connect to allow parallels connection creation
According to the specification nothing prevent the Host from creating
multiple connections at the same time. This commit add this mechanisme
by matching the `connection` and `connection_failure` events against the
peer address.
2022-10-19 17:44:44 +00:00
Gilles Boccon-Gibod eb8556ccf6 gbg/extended scanning (#47)
Squashed:
* add extended report class
* more HCI commands
* add AdvertisingType
* add phy options
* fix tests
2022-10-19 10:06:00 -07:00
Octavian Purdila 4d96b821bc Merge pull request #44 from google/tavip/fix-address-resolution
Fix address resolution handling
2022-10-12 10:09:33 -07:00
Gilles Boccon-Gibod 78b36d2049 Merge pull request #45 from google/gbg/add-missing-app
add controller-info CLI app to setup
2022-10-11 22:21:08 -07:00
Gilles Boccon-Gibod 3e0cad1456 add controller-info CLI app to setup 2022-10-11 22:15:23 -07:00
Octavian Purdila b4de38cdc3 Fix address resolution handling
In one of the refactors the command address_resolution field was
changed to address_reslution_enable but the controller code was not
updated.
2022-10-11 22:53:42 +00:00
Gilles Boccon-Gibod 68d9fbc159 Merge pull request #42 from google/gbg/improve-linux-doc
Refactor and improve the doc for Bumble on Linux
2022-10-11 14:35:14 -07:00
Gilles Boccon-Gibod a916b7a21a Merge pull request #43 from google/gbg/proxy-write-with-response
support with_response on adapters
2022-10-11 07:41:28 -07:00
Gilles Boccon-Gibod 6ff52df8bd better/safer Linux recommendations 2022-10-10 20:11:55 -07:00
Gilles Boccon-Gibod 7fa2eb7658 support with_response on adapters 2022-10-10 12:11:51 -07:00
Gilles Boccon-Gibod 86618e52ef Refactor and improve the doc for Bumble on Linux 2022-10-09 12:56:06 -07:00
Gilles Boccon-Gibod fbb46dd736 Merge pull request #41 from google/gbg/cli-scripts
use arg-less main() functions in all scripts
2022-10-07 16:16:35 -07:00
112 changed files with 12641 additions and 5369 deletions
+9 -5
View File
@@ -14,6 +14,10 @@ jobs:
build: build:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
fail-fast: false
steps: steps:
- name: Check out from Git - name: Check out from Git
@@ -22,17 +26,17 @@ jobs:
run: | run: |
git fetch --prune --unshallow git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/* git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python 3.10 - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3 uses: actions/setup-python@v4
with: with:
python-version: "3.10" python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install ".[build,test,development,documentation]" python -m pip install ".[build,test,development,documentation]"
- name: Test with pytest - name: Test
run: | run: |
pytest invoke test
- name: Build - name: Build
run: | run: |
inv build inv build
+1 -4
View File
@@ -3,9 +3,6 @@ build/
dist/ dist/
*.egg-info/ *.egg-info/
*~ *~
bumble/__pycache__
docs/mkdocs/site docs/mkdocs/site
tests/__pycache__
test-results.xml test-results.xml
bumble/transport/__pycache__ __pycache__
bumble/profiles/__pycache__
+405 -96
View File
@@ -20,19 +20,26 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
from bumble.hci import HCI_Constant
import os
import os.path
import logging import logging
import click import os
import random
import re
from collections import OrderedDict from collections import OrderedDict
import click
import colors import colors
from bumble.core import UUID, AdvertisingData from bumble.core import UUID, AdvertisingData, TimeoutError, BT_LE_TRANSPORT
from bumble.device import Device, Connection, Peer from bumble.device import ConnectionParametersPreferences, Device, Connection, Peer
from bumble.utils import AsyncRunner from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.gatt import Characteristic from bumble.gatt import Characteristic
from bumble.hci import (
HCI_Constant,
HCI_LE_1M_PHY,
HCI_LE_2M_PHY,
HCI_LE_CODED_PHY,
)
from prompt_toolkit import Application from prompt_toolkit import Application
from prompt_toolkit.history import FileHistory from prompt_toolkit.history import FileHistory
@@ -43,6 +50,7 @@ from prompt_toolkit.styles import Style
from prompt_toolkit.filters import Condition from prompt_toolkit.filters import Condition
from prompt_toolkit.widgets import TextArea, Frame from prompt_toolkit.widgets import TextArea, Frame
from prompt_toolkit.widgets.toolbars import FormattedTextToolbar from prompt_toolkit.widgets.toolbars import FormattedTextToolbar
from prompt_toolkit.data_structures import Point
from prompt_toolkit.layout import ( from prompt_toolkit.layout import (
Layout, Layout,
HSplit, HSplit,
@@ -51,17 +59,20 @@ from prompt_toolkit.layout import (
Float, Float,
FormattedTextControl, FormattedTextControl,
FloatContainer, FloatContainer,
ConditionalContainer ConditionalContainer,
Dimension,
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
BUMBLE_USER_DIR = os.path.expanduser('~/.bumble') BUMBLE_USER_DIR = os.path.expanduser('~/.bumble')
DEFAULT_PROMPT_HEIGHT = 20
DEFAULT_RSSI_BAR_WIDTH = 20 DEFAULT_RSSI_BAR_WIDTH = 20
DEFAULT_CONNECTION_TIMEOUT = 30.0
DISPLAY_MIN_RSSI = -100 DISPLAY_MIN_RSSI = -100
DISPLAY_MAX_RSSI = -30 DISPLAY_MAX_RSSI = -30
RSSI_MONITOR_INTERVAL = 5.0 # Seconds
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Globals # Globals
@@ -69,6 +80,44 @@ DISPLAY_MAX_RSSI = -30
App = None App = None
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def le_phy_name(phy_id):
return {HCI_LE_1M_PHY: '1M', HCI_LE_2M_PHY: '2M', HCI_LE_CODED_PHY: 'CODED'}.get(
phy_id, HCI_Constant.le_phy_name(phy_id)
)
def rssi_bar(rssi):
blocks = ['', '', '', '', '', '', '', '']
bar_width = (rssi - DISPLAY_MIN_RSSI) / (DISPLAY_MAX_RSSI - DISPLAY_MIN_RSSI)
bar_width = min(max(bar_width, 0), 1)
bar_ticks = int(bar_width * DEFAULT_RSSI_BAR_WIDTH * 8)
bar_blocks = ('' * int(bar_ticks / 8)) + blocks[bar_ticks % 8]
return f'{rssi:4} {bar_blocks}'
def parse_phys(phys):
if phys.lower() == '*':
return None
else:
phy_list = []
elements = phys.lower().split(',')
for element in elements:
if element == '1m':
phy_list.append(HCI_LE_1M_PHY)
elif element == '2m':
phy_list.append(HCI_LE_2M_PHY)
elif element == 'coded':
phy_list.append(HCI_LE_CODED_PHY)
else:
raise ValueError('invalid PHY name')
return phy_list
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Console App # Console App
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -78,14 +127,18 @@ class ConsoleApp:
self.known_attributes = [] self.known_attributes = []
self.device = None self.device = None
self.connected_peer = None self.connected_peer = None
self.top_tab = 'scan' self.top_tab = 'device'
self.monitor_rssi = False
self.connection_rssi = None
style = Style.from_dict({ style = Style.from_dict(
{
'output-field': 'bg:#000044 #ffffff', 'output-field': 'bg:#000044 #ffffff',
'input-field': 'bg:#000000 #ffffff', 'input-field': 'bg:#000000 #ffffff',
'line': '#004400', 'line': '#004400',
'error': 'fg:ansired' 'error': 'fg:ansired',
}) }
)
class LiveCompleter(Completer): class LiveCompleter(Completer):
def __init__(self, words): def __init__(self, words):
@@ -97,36 +150,37 @@ class ConsoleApp:
yield Completion(word, start_position=-len(prefix)) yield Completion(word, start_position=-len(prefix))
def make_completer(): def make_completer():
return NestedCompleter.from_nested_dict({ return NestedCompleter.from_nested_dict(
'scan': { {
'on': None, 'scan': {'on': None, 'off': None, 'clear': None},
'off': None 'advertise': {'on': None, 'off': None},
}, 'rssi': {'on': None, 'off': None},
'advertise': {
'on': None,
'off': None
},
'show': { 'show': {
'scan': None, 'scan': None,
'services': None, 'services': None,
'attributes': None, 'attributes': None,
'log': None 'log': None,
'device': None,
},
'filter': {
'address': None,
}, },
'connect': LiveCompleter(self.known_addresses), 'connect': LiveCompleter(self.known_addresses),
'update-parameters': None, 'update-parameters': None,
'encrypt': None, 'encrypt': None,
'disconnect': None, 'disconnect': None,
'discover': { 'discover': {'services': None, 'attributes': None},
'services': None, 'request-mtu': None,
'attributes': None
},
'read': LiveCompleter(self.known_attributes), 'read': LiveCompleter(self.known_attributes),
'write': LiveCompleter(self.known_attributes), 'write': LiveCompleter(self.known_attributes),
'subscribe': LiveCompleter(self.known_attributes), 'subscribe': LiveCompleter(self.known_attributes),
'unsubscribe': LiveCompleter(self.known_attributes), 'unsubscribe': LiveCompleter(self.known_attributes),
'set-phy': {'1m': None, '2m': None, 'coded': None},
'set-default-phy': None,
'quit': None, 'quit': None,
'exit': None 'exit': None,
}) }
)
self.input_field = TextArea( self.input_field = TextArea(
height=1, height=1,
@@ -134,43 +188,55 @@ class ConsoleApp:
multiline=False, multiline=False,
wrap_lines=False, wrap_lines=False,
completer=make_completer(), completer=make_completer(),
history=FileHistory(os.path.join(BUMBLE_USER_DIR, 'history')) history=FileHistory(os.path.join(BUMBLE_USER_DIR, 'history')),
) )
self.input_field.accept_handler = self.accept_input self.input_field.accept_handler = self.accept_input
self.output_height = 7 self.output_height = Dimension(min=7, max=7, weight=1)
self.output_lines = [] self.output_lines = []
self.output = FormattedTextControl() self.output = FormattedTextControl(
get_cursor_position=lambda: Point(0, max(0, len(self.output_lines) - 1))
)
self.output_max_lines = 20
self.scan_results_text = FormattedTextControl() self.scan_results_text = FormattedTextControl()
self.services_text = FormattedTextControl() self.services_text = FormattedTextControl()
self.attributes_text = FormattedTextControl() self.attributes_text = FormattedTextControl()
self.log_text = FormattedTextControl() self.device_text = FormattedTextControl()
self.log_height = 20 self.log_text = FormattedTextControl(
get_cursor_position=lambda: Point(0, max(0, len(self.log_lines) - 1))
)
self.log_height = Dimension(min=7, weight=4)
self.log_max_lines = 100
self.log_lines = [] self.log_lines = []
container = HSplit([ container = HSplit(
[
ConditionalContainer( ConditionalContainer(
Frame(Window(self.scan_results_text), title='Scan Results'), Frame(Window(self.scan_results_text), title='Scan Results'),
filter=Condition(lambda: self.top_tab == 'scan') filter=Condition(lambda: self.top_tab == 'scan'),
), ),
ConditionalContainer( ConditionalContainer(
Frame(Window(self.services_text), title='Services'), Frame(Window(self.services_text), title='Services'),
filter=Condition(lambda: self.top_tab == 'services') filter=Condition(lambda: self.top_tab == 'services'),
), ),
ConditionalContainer( ConditionalContainer(
Frame(Window(self.attributes_text), title='Attributes'), Frame(Window(self.attributes_text), title='Attributes'),
filter=Condition(lambda: self.top_tab == 'attributes') filter=Condition(lambda: self.top_tab == 'attributes'),
), ),
ConditionalContainer( ConditionalContainer(
Frame(Window(self.log_text), title='Log'), Frame(Window(self.log_text, height=self.log_height), title='Log'),
filter=Condition(lambda: self.top_tab == 'log') filter=Condition(lambda: self.top_tab == 'log'),
), ),
Frame(Window(self.output), height=self.output_height), ConditionalContainer(
# HorizontalLine(), Frame(Window(self.device_text), title='Device'),
filter=Condition(lambda: self.top_tab == 'device'),
),
Frame(Window(self.output, height=self.output_height)),
FormattedTextToolbar(text=self.get_status_bar_text, style='reverse'), FormattedTextToolbar(text=self.get_status_bar_text, style='reverse'),
self.input_field self.input_field,
]) ]
)
container = FloatContainer( container = FloatContainer(
container, container,
@@ -186,30 +252,43 @@ class ConsoleApp:
layout = Layout(container, focused_element=self.input_field) layout = Layout(container, focused_element=self.input_field)
kb = KeyBindings() kb = KeyBindings()
@kb.add("c-c") @kb.add("c-c")
@kb.add("c-q") @kb.add("c-q")
def _(event): def _(event):
event.app.exit() event.app.exit()
self.ui = Application( self.ui = Application(
layout=layout, layout=layout, style=style, key_bindings=kb, full_screen=True
style=style,
key_bindings=kb,
full_screen=True
) )
async def run_async(self, device_config, transport): async def run_async(self, device_config, transport):
rssi_monitoring_task = asyncio.create_task(self.rssi_monitor_loop())
async with await open_transport_or_link(transport) as (hci_source, hci_sink): async with await open_transport_or_link(transport) as (hci_source, hci_sink):
if device_config: if device_config:
self.device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink) self.device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else: else:
self.device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink) random_address = (
f"{random.randint(192,255):02X}" # address is static random
)
for c in random.sample(range(255), 5):
random_address += f":{c:02X}"
self.append_to_log(f"Setting random address: {random_address}")
self.device = Device.with_hci(
'Bumble', random_address, hci_source, hci_sink
)
self.device.listener = DeviceListener(self) self.device.listener = DeviceListener(self)
await self.device.power_on() await self.device.power_on()
self.show_device(self.device)
# Run the UI # Run the UI
await self.ui.run_async() await self.ui.run_async()
rssi_monitoring_task.cancel()
def add_known_address(self, address): def add_known_address(self, address):
self.known_addresses.add(address) self.known_addresses.add(address)
@@ -224,22 +303,35 @@ class ConsoleApp:
connection_state = 'NONE' connection_state = 'NONE'
encryption_state = '' encryption_state = ''
att_mtu = ''
rssi = '' if self.connection_rssi is None else rssi_bar(self.connection_rssi)
if self.device: if self.device:
if self.device.is_connecting: if self.device.is_le_connecting:
connection_state = 'CONNECTING' connection_state = 'CONNECTING'
elif self.connected_peer: elif self.connected_peer:
connection = self.connected_peer.connection connection = self.connected_peer.connection
connection_parameters = f'{connection.parameters.connection_interval}/{connection.parameters.connection_latency}/{connection.parameters.supervision_timeout}' connection_parameters = f'{connection.parameters.connection_interval}/{connection.parameters.peripheral_latency}/{connection.parameters.supervision_timeout}'
connection_state = f'{connection.peer_address} {connection_parameters} {connection.data_length}' if connection.transport == BT_LE_TRANSPORT:
encryption_state = 'ENCRYPTED' if connection.is_encrypted else 'NOT ENCRYPTED' phy_state = f' RX={le_phy_name(connection.phy.rx_phy)}/TX={le_phy_name(connection.phy.tx_phy)}'
else:
phy_state = ''
connection_state = f'{connection.peer_address} {connection_parameters} {connection.data_length}{phy_state}'
encryption_state = (
'ENCRYPTED' if connection.is_encrypted else 'NOT ENCRYPTED'
)
att_mtu = f'ATT_MTU: {connection.att_mtu}'
return [ return [
('ansigreen', f' SCAN: {scanning} '), ('ansigreen', f' SCAN: {scanning} '),
('', ' '), ('', ' '),
('ansiblue', f' CONNECTION: {connection_state} '), ('ansiblue', f' CONNECTION: {connection_state} '),
('', ' '), ('', ' '),
('ansimagenta', f' {encryption_state} ') ('ansimagenta', f' {encryption_state} '),
('', ' '),
('ansicyan', f' {att_mtu} '),
('', ' '),
('ansiyellow', f' {rssi} '),
] ]
def show_error(self, title, details=None): def show_error(self, title, details=None):
@@ -265,7 +357,9 @@ class ConsoleApp:
for characteristic in service.characteristics: for characteristic in service.characteristics:
lines.append(('ansimagenta', ' ' + str(characteristic) + '\n')) lines.append(('ansimagenta', ' ' + str(characteristic) + '\n'))
self.known_attributes.append(f'{service.uuid.to_hex_str()}.{characteristic.uuid.to_hex_str()}') self.known_attributes.append(
f'{service.uuid.to_hex_str()}.{characteristic.uuid.to_hex_str()}'
)
self.known_attributes.append(f'*.{characteristic.uuid.to_hex_str()}') self.known_attributes.append(f'*.{characteristic.uuid.to_hex_str()}')
self.known_attributes.append(f'#{characteristic.handle:X}') self.known_attributes.append(f'#{characteristic.handle:X}')
for descriptor in characteristic.descriptors: for descriptor in characteristic.descriptors:
@@ -274,7 +368,7 @@ class ConsoleApp:
self.services_text.text = lines self.services_text.text = lines
self.ui.invalidate() self.ui.invalidate()
async def show_attributes(self, attributes): def show_attributes(self, attributes):
lines = [] lines = []
for attribute in attributes: for attribute in attributes:
@@ -283,10 +377,48 @@ class ConsoleApp:
self.attributes_text.text = lines self.attributes_text.text = lines
self.ui.invalidate() self.ui.invalidate()
def show_device(self, device):
lines = []
lines.append(('ansicyan', 'Name: '))
lines.append(('', f'{device.name}\n'))
lines.append(('ansicyan', 'Public Address: '))
lines.append(('', f'{device.public_address}\n'))
lines.append(('ansicyan', 'Random Address: '))
lines.append(('', f'{device.random_address}\n'))
lines.append(('ansicyan', 'LE Enabled: '))
lines.append(('', f'{device.le_enabled}\n'))
lines.append(('ansicyan', 'Classic Enabled: '))
lines.append(('', f'{device.classic_enabled}\n'))
lines.append(('ansicyan', 'Classic SC Enabled: '))
lines.append(('', f'{device.classic_sc_enabled}\n'))
lines.append(('ansicyan', 'Classic SSP Enabled: '))
lines.append(('', f'{device.classic_ssp_enabled}\n'))
lines.append(('ansicyan', 'Classic Class: '))
lines.append(('', f'{device.class_of_device}\n'))
lines.append(('ansicyan', 'Discoverable: '))
lines.append(('', f'{device.discoverable}\n'))
lines.append(('ansicyan', 'Connectable: '))
lines.append(('', f'{device.connectable}\n'))
lines.append(('ansicyan', 'Advertising Data: '))
lines.append(('', f'{device.advertising_data}\n'))
lines.append(('ansicyan', 'Scan Response Data: '))
lines.append(('', f'{device.scan_response_data}\n'))
advertising_interval = (
device.advertising_interval_min
if device.advertising_interval_min == device.advertising_interval_max
else f"{device.advertising_interval_min} to {device.advertising_interval_max}"
)
lines.append(('ansicyan', 'Advertising Interval: '))
lines.append(('', f'{advertising_interval}\n'))
self.device_text.text = lines
self.ui.invalidate()
def append_to_output(self, line, invalidate=True): def append_to_output(self, line, invalidate=True):
if type(line) is str: if type(line) is str:
line = [('', line)] line = [('', line)]
self.output_lines = self.output_lines[-(self.output_height - 3):] self.output_lines = self.output_lines[-self.output_max_lines :]
self.output_lines.append(line) self.output_lines.append(line)
formatted_text = [] formatted_text = []
for line in self.output_lines: for line in self.output_lines:
@@ -298,7 +430,7 @@ class ConsoleApp:
def append_to_log(self, lines, invalidate=True): def append_to_log(self, lines, invalidate=True):
self.log_lines.extend(lines.split('\n')) self.log_lines.extend(lines.split('\n'))
self.log_lines = self.log_lines[-(self.log_height - 3):] self.log_lines = self.log_lines[-self.log_max_lines :]
self.log_text.text = ANSI('\n'.join(self.log_lines)) self.log_text.text = ANSI('\n'.join(self.log_lines))
if invalidate: if invalidate:
self.ui.invalidate() self.ui.invalidate()
@@ -311,7 +443,10 @@ class ConsoleApp:
# Discover all services, characteristics and descriptors # Discover all services, characteristics and descriptors
self.append_to_output('discovering services...') self.append_to_output('discovering services...')
await self.connected_peer.discover_services() await self.connected_peer.discover_services()
self.append_to_output(f'found {len(self.connected_peer.services)} services, discovering charateristics...') self.append_to_output(
f'found {len(self.connected_peer.services)} services,'
' discovering characteristics...'
)
await self.connected_peer.discover_characteristics() await self.connected_peer.discover_characteristics()
self.append_to_output('found characteristics, discovering descriptors...') self.append_to_output('found characteristics, discovering descriptors...')
for service in self.connected_peer.services: for service in self.connected_peer.services:
@@ -331,7 +466,7 @@ class ConsoleApp:
attributes = await self.connected_peer.discover_attributes() attributes = await self.connected_peer.discover_attributes()
self.append_to_output(f'discovered {len(attributes)} attributes...') self.append_to_output(f'discovered {len(attributes)} attributes...')
await self.show_attributes(attributes) self.show_attributes(attributes)
def find_characteristic(self, param): def find_characteristic(self, param):
parts = param.split('.') parts = param.split('.')
@@ -351,6 +486,12 @@ class ConsoleApp:
if characteristic.handle == attribute_handle: if characteristic.handle == attribute_handle:
return characteristic return characteristic
async def rssi_monitor_loop(self):
while True:
if self.monitor_rssi and self.connected_peer:
self.connection_rssi = await self.connected_peer.connection.get_rssi()
await asyncio.sleep(RSSI_MONITOR_INTERVAL)
async def command(self, command): async def command(self, command):
try: try:
(keyword, *params) = command.strip().split(' ') (keyword, *params) = command.strip().split(' ')
@@ -372,23 +513,75 @@ class ConsoleApp:
else: else:
await self.device.start_scanning() await self.device.start_scanning()
elif params[0] == 'on': elif params[0] == 'on':
if len(params) == 2:
if not params[1].startswith("filter="):
self.show_error(
'invalid syntax',
'expected address filter=key1:value1,key2:value,... available filters: address',
)
# regex: (word):(any char except ,)
matches = re.findall(r"(\w+):([^,]+)", params[1])
for match in matches:
if match[0] == "address":
self.device.listener.address_filter = match[1]
await self.device.start_scanning() await self.device.start_scanning()
self.top_tab = 'scan' self.top_tab = 'scan'
elif params[0] == 'off': elif params[0] == 'off':
await self.device.stop_scanning() await self.device.stop_scanning()
elif params[0] == 'clear':
self.device.listener.scan_results.clear()
self.known_addresses.clear()
self.show_scan_results(self.device.listener.scan_results)
else: else:
self.show_error('unsupported arguments for scan command') self.show_error('unsupported arguments for scan command')
async def do_rssi(self, params):
if len(params) == 0:
# Toggle monitoring
self.monitor_rssi = not self.monitor_rssi
elif params[0] == 'on':
self.monitor_rssi = True
elif params[0] == 'off':
self.monitor_rssi = False
else:
self.show_error('unsupported arguments for rssi command')
async def do_connect(self, params): async def do_connect(self, params):
if len(params) != 1: if len(params) != 1 and len(params) != 2:
self.show_error('invalid syntax', 'expected connect <address>') self.show_error('invalid syntax', 'expected connect <address> [phys]')
return return
if len(params) == 1:
phys = None
else:
phys = parse_phys(params[1])
if phys is None:
connection_parameters_preferences = None
else:
connection_parameters_preferences = {
phy: ConnectionParametersPreferences() for phy in phys
}
if self.device.is_scanning:
await self.device.stop_scanning()
self.append_to_output('connecting...') self.append_to_output('connecting...')
await self.device.connect(params[0])
try:
await self.device.connect(
params[0],
connection_parameters_preferences=connection_parameters_preferences,
timeout=DEFAULT_CONNECTION_TIMEOUT,
)
self.top_tab = 'services' self.top_tab = 'services'
except TimeoutError:
self.show_error('connection timed out')
async def do_disconnect(self, params): async def do_disconnect(self, params):
if self.device.is_le_connecting:
await self.device.cancel_connection()
else:
if not self.connected_peer: if not self.connected_peer:
self.show_error('not connected') self.show_error('not connected')
return return
@@ -397,22 +590,27 @@ class ConsoleApp:
async def do_update_parameters(self, params): async def do_update_parameters(self, params):
if len(params) != 1 or len(params[0].split('/')) != 3: if len(params) != 1 or len(params[0].split('/')) != 3:
self.show_error('invalid syntax', 'expected update-parameters <interval-min>-<interval-max>/<latency>/<supervision>') self.show_error(
'invalid syntax',
'expected update-parameters <interval-min>-<interval-max>/<max-latency>/<supervision>',
)
return return
if not self.connected_peer: if not self.connected_peer:
self.show_error('not connected') self.show_error('not connected')
return return
connection_intervals, connection_latency, supervision_timeout = params[0].split('/') connection_intervals, max_latency, supervision_timeout = params[0].split('/')
connection_interval_min, connection_interval_max = [int(x) for x in connection_intervals.split('-')] connection_interval_min, connection_interval_max = [
connection_latency = int(connection_latency) int(x) for x in connection_intervals.split('-')
]
max_latency = int(max_latency)
supervision_timeout = int(supervision_timeout) supervision_timeout = int(supervision_timeout)
await self.connected_peer.connection.update_parameters( await self.connected_peer.connection.update_parameters(
connection_interval_min, connection_interval_min,
connection_interval_max, connection_interval_max,
connection_latency, max_latency,
supervision_timeout supervision_timeout,
) )
async def do_encrypt(self, params): async def do_encrypt(self, params):
@@ -438,10 +636,31 @@ class ConsoleApp:
async def do_show(self, params): async def do_show(self, params):
if params: if params:
if params[0] in {'scan', 'services', 'attributes', 'log'}: if params[0] in {'scan', 'services', 'attributes', 'log', 'device'}:
self.top_tab = params[0] self.top_tab = params[0]
self.ui.invalidate() self.ui.invalidate()
async def do_get_phy(self, params):
if not self.connected_peer:
self.show_error('not connected')
return
phy = await self.connected_peer.connection.get_phy()
self.append_to_output(
f'PHY: RX={HCI_Constant.le_phy_name(phy[0])}, TX={HCI_Constant.le_phy_name(phy[1])}'
)
async def do_request_mtu(self, params):
if len(params) != 1:
self.show_error('invalid syntax', 'expected request-mtu <mtu>')
return
if not self.connected_peer:
self.show_error('not connected')
return
await self.connected_peer.request_mtu(int(params[0]))
async def do_discover(self, params): async def do_discover(self, params):
if not params: if not params:
self.show_error('invalid syntax', 'expected discover services|attributes') self.show_error('invalid syntax', 'expected discover services|attributes')
@@ -454,14 +673,14 @@ class ConsoleApp:
await self.discover_attributes() await self.discover_attributes()
async def do_read(self, params): async def do_read(self, params):
if not self.connected_peer:
self.show_error('not connected')
return
if len(params) != 1: if len(params) != 1:
self.show_error('invalid syntax', 'expected read <attribute>') self.show_error('invalid syntax', 'expected read <attribute>')
return return
if not self.connected_peer:
self.show_error('not connected')
return
characteristic = self.find_characteristic(params[0]) characteristic = self.find_characteristic(params[0])
if characteristic is None: if characteristic is None:
self.show_error('no such characteristic') self.show_error('no such characteristic')
@@ -511,7 +730,9 @@ class ConsoleApp:
return return
await characteristic.subscribe( await characteristic.subscribe(
lambda value: self.append_to_output(f"{characteristic} VALUE: 0x{value.hex()}"), lambda value: self.append_to_output(
f"{characteristic} VALUE: 0x{value.hex()}"
),
) )
async def do_unsubscribe(self, params): async def do_unsubscribe(self, params):
@@ -530,12 +751,58 @@ class ConsoleApp:
await characteristic.unsubscribe() await characteristic.unsubscribe()
async def do_set_phy(self, params):
if len(params) != 1:
self.show_error(
'invalid syntax', 'expected set-phy <tx_rx_phys>|<tx_phys>/<rx_phys>'
)
return
if not self.connected_peer:
self.show_error('not connected')
return
if '/' in params[0]:
tx_phys, rx_phys = params[0].split('/')
else:
tx_phys = params[0]
rx_phys = tx_phys
await self.connected_peer.connection.set_phy(
tx_phys=parse_phys(tx_phys), rx_phys=parse_phys(rx_phys)
)
async def do_set_default_phy(self, params):
if len(params) != 1:
self.show_error(
'invalid syntax',
'expected set-default-phy <tx_rx_phys>|<tx_phys>/<rx_phys>',
)
return
if '/' in params[0]:
tx_phys, rx_phys = params[0].split('/')
else:
tx_phys = params[0]
rx_phys = tx_phys
await self.device.set_default_phy(
tx_phys=parse_phys(tx_phys), rx_phys=parse_phys(rx_phys)
)
async def do_exit(self, params): async def do_exit(self, params):
self.ui.exit() self.ui.exit()
async def do_quit(self, params): async def do_quit(self, params):
self.ui.exit() self.ui.exit()
async def do_filter(self, params):
if params[0] == "address":
if len(params) != 2:
self.show_error('invalid syntax', 'expected filter address <pattern>')
return
self.device.listener.address_filter = params[1]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Device and Connection Listener # Device and Connection Listener
@@ -544,42 +811,87 @@ class DeviceListener(Device.Listener, Connection.Listener):
def __init__(self, app): def __init__(self, app):
self.app = app self.app = app
self.scan_results = OrderedDict() self.scan_results = OrderedDict()
self.address_filter = None
@property
def address_filter(self):
return self._address_filter
@address_filter.setter
def address_filter(self, filter_addr):
if filter_addr is None:
self._address_filter = re.compile(r".*")
else:
self._address_filter = re.compile(filter_addr)
self.scan_results = OrderedDict(
filter(lambda x: self.filter_address_match(x), self.scan_results)
)
self.app.show_scan_results(self.scan_results)
def filter_address_match(self, address):
"""
Returns true if an address matches the filter
"""
return bool(self.address_filter.match(address))
@AsyncRunner.run_in_task() @AsyncRunner.run_in_task()
async def on_connection(self, connection): async def on_connection(self, connection):
self.app.connected_peer = Peer(connection) self.app.connected_peer = Peer(connection)
self.app.connection_rssi = None
self.app.append_to_output(f'connected to {self.app.connected_peer}') self.app.append_to_output(f'connected to {self.app.connected_peer}')
connection.listener = self connection.listener = self
def on_disconnection(self, reason): def on_disconnection(self, reason):
self.app.append_to_output(f'disconnected from {self.app.connected_peer}, reason: {HCI_Constant.error_name(reason)}') self.app.append_to_output(
f'disconnected from {self.app.connected_peer}, reason: {HCI_Constant.error_name(reason)}'
)
self.app.connected_peer = None self.app.connected_peer = None
self.app.connection_rssi = None
def on_connection_parameters_update(self): def on_connection_parameters_update(self):
self.app.append_to_output(f'connection parameters update: {self.app.connected_peer.connection.parameters}') self.app.append_to_output(
f'connection parameters update: {self.app.connected_peer.connection.parameters}'
)
def on_connection_phy_update(self): def on_connection_phy_update(self):
self.app.append_to_output(f'connection phy update: {self.app.connected_peer.connection.phy}') self.app.append_to_output(
f'connection phy update: {self.app.connected_peer.connection.phy}'
)
def on_connection_att_mtu_update(self): def on_connection_att_mtu_update(self):
self.app.append_to_output(f'connection att mtu update: {self.app.connected_peer.connection.att_mtu}') self.app.append_to_output(
f'connection att mtu update: {self.app.connected_peer.connection.att_mtu}'
)
def on_connection_encryption_change(self): def on_connection_encryption_change(self):
self.app.append_to_output(f'connection encryption change: {"encrypted" if self.app.connected_peer.connection.is_encrypted else "not encrypted"}') self.app.append_to_output(
f'connection encryption change: {"encrypted" if self.app.connected_peer.connection.is_encrypted else "not encrypted"}'
)
def on_connection_data_length_change(self): def on_connection_data_length_change(self):
self.app.append_to_output(f'connection data length change: {self.app.connected_peer.connection.data_length}') self.app.append_to_output(
f'connection data length change: {self.app.connected_peer.connection.data_length}'
)
def on_advertisement(self, address, ad_data, rssi, connectable): def on_advertisement(self, advertisement):
entry_key = f'{address}/{address.address_type}' if not self.filter_address_match(str(advertisement.address)):
return
entry_key = f'{advertisement.address}/{advertisement.address.address_type}'
entry = self.scan_results.get(entry_key) entry = self.scan_results.get(entry_key)
if entry: if entry:
entry.ad_data = ad_data entry.ad_data = advertisement.data
entry.rssi = rssi entry.rssi = advertisement.rssi
entry.connectable = connectable entry.connectable = advertisement.is_connectable
else: else:
self.app.add_known_address(str(address)) self.app.add_known_address(str(advertisement.address))
self.scan_results[entry_key] = ScanResult(address, address.address_type, ad_data, rssi, connectable) self.scan_results[entry_key] = ScanResult(
advertisement.address,
advertisement.address.address_type,
advertisement.data,
advertisement.rssi,
advertisement.is_connectable,
)
self.app.show_scan_results(self.scan_results) self.app.show_scan_results(self.scan_results)
@@ -603,9 +915,9 @@ class ScanResult:
else: else:
type_color = colors.cyan type_color = colors.cyan
name = self.ad_data.get(AdvertisingData.COMPLETE_LOCAL_NAME) name = self.ad_data.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True)
if name is None: if name is None:
name = self.ad_data.get(AdvertisingData.SHORTENED_LOCAL_NAME) name = self.ad_data.get(AdvertisingData.SHORTENED_LOCAL_NAME, raw=True)
if name: if name:
# Convert to string # Convert to string
try: try:
@@ -616,12 +928,7 @@ class ScanResult:
name = '' name = ''
# RSSI bar # RSSI bar
blocks = ['', '', '', '', '', '', '', ''] bar_string = rssi_bar(self.rssi)
bar_width = (self.rssi - DISPLAY_MIN_RSSI) / (DISPLAY_MAX_RSSI - DISPLAY_MIN_RSSI)
bar_width = min(max(bar_width, 0), 1)
bar_ticks = int(bar_width * DEFAULT_RSSI_BAR_WIDTH * 8)
bar_blocks = ('' * int(bar_ticks / 8)) + blocks[bar_ticks % 8]
bar_string = f'{self.rssi} {bar_blocks}'
bar_padding = ' ' * (DEFAULT_RSSI_BAR_WIDTH + 5 - len(bar_string)) bar_padding = ' ' * (DEFAULT_RSSI_BAR_WIDTH + 5 - len(bar_string))
return f'{address_color(str(self.address))} [{type_color(address_type_string)}] {bar_string} {bar_padding} {name}' return f'{address_color(str(self.address))} [{type_color(address_type_string)}] {bar_string} {bar_padding} {name}'
@@ -633,6 +940,7 @@ class LogHandler(logging.Handler):
def __init__(self, app): def __init__(self, app):
super().__init__() super().__init__()
self.app = app self.app = app
self.setFormatter(logging.Formatter('[%(asctime)s][%(levelname)s] %(message)s'))
def emit(self, record): def emit(self, record):
message = self.format(record) message = self.format(record)
@@ -657,6 +965,7 @@ def main(device_config, transport):
# logging.basicConfig(level = 'FATAL') # logging.basicConfig(level = 'FATAL')
# logging.basicConfig(level = 'DEBUG') # logging.basicConfig(level = 'DEBUG')
root_logger = logging.getLogger() root_logger = logging.getLogger()
root_logger.addHandler(LogHandler(app)) root_logger.addHandler(LogHandler(app))
root_logger.setLevel(logging.DEBUG) root_logger.setLevel(logging.DEBUG)
+65 -8
View File
@@ -25,15 +25,21 @@ from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.core import name_or_number from bumble.core import name_or_number
from bumble.hci import ( from bumble.hci import (
map_null_terminated_utf8_string, map_null_terminated_utf8_string,
HCI_LE_SUPPORTED_FEATURES_NAMES,
HCI_SUCCESS, HCI_SUCCESS,
HCI_LE_SUPPORTED_FEATURES_NAMES,
HCI_VERSION_NAMES, HCI_VERSION_NAMES,
LMP_VERSION_NAMES, LMP_VERSION_NAMES,
HCI_Command, HCI_Command,
HCI_Read_BD_ADDR_Command,
HCI_READ_BD_ADDR_COMMAND, HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_Read_Local_Name_Command, HCI_Read_Local_Name_Command,
HCI_READ_LOCAL_NAME_COMMAND HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
) )
from bumble.host import Host from bumble.host import Host
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
@@ -45,18 +51,60 @@ async def get_classic_info(host):
response = await host.send_command(HCI_Read_BD_ADDR_Command()) response = await host.send_command(HCI_Read_BD_ADDR_Command())
if response.return_parameters.status == HCI_SUCCESS: if response.return_parameters.status == HCI_SUCCESS:
print() print()
print(color('Classic Address:', 'yellow'), response.return_parameters.bd_addr) print(
color('Classic Address:', 'yellow'), response.return_parameters.bd_addr
)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND): if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
response = await host.send_command(HCI_Read_Local_Name_Command()) response = await host.send_command(HCI_Read_Local_Name_Command())
if response.return_parameters.status == HCI_SUCCESS: if response.return_parameters.status == HCI_SUCCESS:
print() print()
print(color('Local Name:', 'yellow'), map_null_terminated_utf8_string(response.return_parameters.local_name)) print(
color('Local Name:', 'yellow'),
map_null_terminated_utf8_string(response.return_parameters.local_name),
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_le_info(host): async def get_le_info(host):
print() print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
response = await host.send_command(
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
)
if response.return_parameters.status == HCI_SUCCESS:
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response.return_parameters.num_supported_advertising_sets,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND):
response = await host.send_command(
HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
)
if response.return_parameters.status == HCI_SUCCESS:
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response.return_parameters.max_advertising_data_length,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND):
response = await host.send_command(HCI_LE_Read_Maximum_Data_Length_Command())
if response.return_parameters.status == HCI_SUCCESS:
print(
color('Maximum Data Length:', 'yellow'),
(
f'tx:{response.return_parameters.supported_max_tx_octets}/'
f'{response.return_parameters.supported_max_tx_time}, '
f'rx:{response.return_parameters.supported_max_rx_octets}/'
f'{response.return_parameters.supported_max_rx_time}'
),
'\n',
)
print(color('LE Features:', 'yellow')) print(color('LE Features:', 'yellow'))
for feature in host.supported_le_features: for feature in host.supported_le_features:
print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature)) print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
@@ -73,10 +121,19 @@ async def async_main(transport):
# Print version # Print version
print(color('Version:', 'yellow')) print(color('Version:', 'yellow'))
print(color(' Manufacturer: ', 'green'), name_or_number(COMPANY_IDENTIFIERS, host.local_version.company_identifier)) print(
print(color(' HCI Version: ', 'green'), name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version)) color(' Manufacturer: ', 'green'),
name_or_number(COMPANY_IDENTIFIERS, host.local_version.company_identifier),
)
print(
color(' HCI Version: ', 'green'),
name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version),
)
print(color(' HCI Subversion:', 'green'), host.local_version.hci_subversion) print(color(' HCI Subversion:', 'green'), host.local_version.hci_subversion)
print(color(' LMP Version: ', 'green'), name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version)) print(
color(' LMP Version: ', 'green'),
name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version),
)
print(color(' LMP Subversion:', 'green'), host.local_version.lmp_subversion) print(color(' LMP Subversion:', 'green'), host.local_version.lmp_subversion)
# Get the Classic info # Get the Classic info
+9 -2
View File
@@ -28,7 +28,9 @@ from bumble.transport import open_transport_or_link
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_main(): async def async_main():
if len(sys.argv) != 3: if len(sys.argv) != 3:
print('Usage: controllers.py <hci-transport-1> <hci-transport-2> [<hci-transport-3> ...]') print(
'Usage: controllers.py <hci-transport-1> <hci-transport-2> [<hci-transport-3> ...]'
)
print('example: python controllers.py pty:ble1 pty:ble2') print('example: python controllers.py pty:ble1 pty:ble2')
return return
@@ -41,7 +43,12 @@ async def async_main():
for index, transport_name in enumerate(sys.argv[1:]): for index, transport_name in enumerate(sys.argv[1:]):
transport = await open_transport_or_link(transport_name) transport = await open_transport_or_link(transport_name)
transports.append(transport) transports.append(transport)
controller = Controller(f'C{index}', host_source = transport.source, host_sink = transport.sink, link = link) controller = Controller(
f'C{index}',
host_source=transport.source,
host_sink=transport.sink,
link=link,
)
controllers.append(controller) controllers.append(controller)
# Wait until the user interrupts # Wait until the user interrupts
+12 -3
View File
@@ -64,9 +64,13 @@ async def async_main(device_config, encrypt, transport, address_or_name):
# Create a device # Create a device
if device_config: if device_config:
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink) device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else: else:
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink) device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
await device.power_on() await device.power_on()
if address_or_name: if address_or_name:
@@ -81,7 +85,12 @@ async def async_main(device_config, encrypt, transport, address_or_name):
else: else:
# Wait for a connection # Wait for a connection
done = asyncio.get_running_loop().create_future() done = asyncio.get_running_loop().create_future()
device.on('connection', lambda connection: asyncio.create_task(dump_gatt_db(Peer(connection), done))) device.on(
'connection',
lambda connection: asyncio.create_task(
dump_gatt_db(Peer(connection), done)
),
)
await device.start_advertising(auto_restart=True) await device.start_advertising(auto_restart=True)
print(color('### Waiting for connection...', 'blue')) print(color('### Waiting for connection...', 'blue'))
+249 -69
View File
@@ -17,13 +17,14 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import os import os
import struct
import logging import logging
import click import click
from colors import color from colors import color
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.core import AdvertisingData from bumble.core import AdvertisingData
from bumble.gatt import Service, Characteristic from bumble.gatt import Service, Characteristic, CharacteristicValue
from bumble.utils import AsyncRunner from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.hci import HCI_Constant from bumble.hci import HCI_Constant
@@ -35,19 +36,67 @@ from bumble.hci import HCI_Constant
GG_GATTLINK_SERVICE_UUID = 'ABBAFF00-E56A-484C-B832-8B17CF6CBFE8' GG_GATTLINK_SERVICE_UUID = 'ABBAFF00-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_RX_CHARACTERISTIC_UUID = 'ABBAFF01-E56A-484C-B832-8B17CF6CBFE8' GG_GATTLINK_RX_CHARACTERISTIC_UUID = 'ABBAFF01-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_TX_CHARACTERISTIC_UUID = 'ABBAFF02-E56A-484C-B832-8B17CF6CBFE8' GG_GATTLINK_TX_CHARACTERISTIC_UUID = 'ABBAFF02-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = 'ABBAFF03-E56A-484C-B832-8B17CF6CBFE8' GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = (
'ABBAFF03-E56A-484C-B832-8B17CF6CBFE8'
)
GG_PREFERRED_MTU = 256 GG_PREFERRED_MTU = 256
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class GattlinkHubBridge(Device.Listener): class GattlinkL2capEndpoint:
def __init__(self): def __init__(self):
self.l2cap_channel = None
self.l2cap_packet = b''
self.l2cap_packet_size = 0
# Called when an L2CAP SDU has been received
def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
while len(sdu):
if self.l2cap_packet_size == 0:
# Expect a new packet
self.l2cap_packet_size = sdu[0] + 1
sdu = sdu[1:]
else:
bytes_needed = self.l2cap_packet_size - len(self.l2cap_packet)
chunk = min(bytes_needed, len(sdu))
self.l2cap_packet += sdu[:chunk]
sdu = sdu[chunk:]
if len(self.l2cap_packet) == self.l2cap_packet_size:
self.on_l2cap_packet(self.l2cap_packet)
self.l2cap_packet = b''
self.l2cap_packet_size = 0
# -----------------------------------------------------------------------------
class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device, peer_address):
super().__init__()
self.device = device
self.peer_address = peer_address
self.peer = None self.peer = None
self.rx_socket = None
self.tx_socket = None self.tx_socket = None
self.rx_characteristic = None self.rx_characteristic = None
self.tx_characteristic = None self.tx_characteristic = None
self.l2cap_psm_characteristic = None
device.listener = self
async def start(self):
# Connect to the peer
print(f'=== Connecting to {self.peer_address}...')
await self.device.connect(self.peer_address)
async def connect_l2cap(self, psm):
print(color(f'### Connecting with L2CAP on PSM = {psm}', 'yellow'))
try:
self.l2cap_channel = await self.peer.connection.open_l2cap_channel(psm)
print(color('*** Connected', 'yellow'), self.l2cap_channel)
self.l2cap_channel.sink = self.on_coc_sdu
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
@AsyncRunner.run_in_task() @AsyncRunner.run_in_task()
async def on_connection(self, connection): async def on_connection(self, connection):
@@ -80,50 +129,60 @@ class GattlinkHubBridge(Device.Listener):
self.rx_characteristic = characteristic self.rx_characteristic = characteristic
elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID: elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID:
self.tx_characteristic = characteristic self.tx_characteristic = characteristic
elif (
characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID
):
self.l2cap_psm_characteristic = characteristic
print('RX:', self.rx_characteristic) print('RX:', self.rx_characteristic)
print('TX:', self.tx_characteristic) print('TX:', self.tx_characteristic)
print('PSM:', self.l2cap_psm_characteristic)
if self.l2cap_psm_characteristic:
# Subscribe to and then read the PSM value
await self.peer.subscribe(
self.l2cap_psm_characteristic, self.on_l2cap_psm_received
)
psm_bytes = await self.peer.read_value(self.l2cap_psm_characteristic)
psm = struct.unpack('<H', psm_bytes)[0]
await self.connect_l2cap(psm)
elif self.tx_characteristic:
# Subscribe to TX # Subscribe to TX
if self.tx_characteristic:
await self.peer.subscribe(self.tx_characteristic, self.on_tx_received) await self.peer.subscribe(self.tx_characteristic, self.on_tx_received)
print(color('=== Subscribed to Gattlink TX', 'yellow')) print(color('=== Subscribed to Gattlink TX', 'yellow'))
else: else:
print(color('!!! Gattlink TX not found', 'red')) print(color('!!! No Gattlink TX or PSM found', 'red'))
def on_connection_failure(self, error): def on_connection_failure(self, error):
print(color(f'!!! Connection failed: {error}')) print(color(f'!!! Connection failed: {error}'))
def on_disconnection(self, reason): def on_disconnection(self, reason):
print(color(f'!!! Disconnected from {self.peer}, reason={HCI_Constant.error_name(reason)}', 'red')) print(
color(
f'!!! Disconnected from {self.peer}, reason={HCI_Constant.error_name(reason)}',
'red',
)
)
self.tx_characteristic = None self.tx_characteristic = None
self.rx_characteristic = None self.rx_characteristic = None
self.peer = None self.peer = None
# Called when an L2CAP packet has been received
def on_l2cap_packet(self, packet):
print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(packet)
# Called by the GATT client when a notification is received # Called by the GATT client when a notification is received
def on_tx_received(self, value): def on_tx_received(self, value):
print(color('>>> TX:', 'magenta'), value.hex()) print(color(f'<<< [GATT TX]: {len(value)} bytes', 'cyan'))
if self.tx_socket: if self.tx_socket:
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(value) self.tx_socket.sendto(value)
# Called by asyncio when the UDP socket is created # Called by asyncio when the UDP socket is created
def connection_made(self, transport): def on_l2cap_psm_received(self, value):
pass psm = struct.unpack('<H', value)[0]
asyncio.create_task(self.connect_l2cap(psm))
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address):
print(color('<<< RX:', 'magenta'), data.hex())
# TODO: use a queue instead of creating a task everytime
if self.peer and self.rx_characteristic:
asyncio.create_task(self.peer.write_value(self.rx_characteristic, data))
# -----------------------------------------------------------------------------
class GattlinkNodeBridge(Device.Listener):
def __init__(self):
self.peer = None
self.rx_socket = None
self.tx_socket = None
# Called by asyncio when the UDP socket is created # Called by asyncio when the UDP socket is created
def connection_made(self, transport): def connection_made(self, transport):
@@ -131,64 +190,156 @@ class GattlinkNodeBridge(Device.Listener):
# Called by asyncio when a UDP datagram is received # Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address): def datagram_received(self, data, address):
print(color('<<< RX:', 'magenta'), data.hex()) print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
# TODO: use a queue instead of creating a task everytime if self.l2cap_channel:
if self.peer and self.rx_characteristic: print(color('>>> [L2CAP]', 'yellow'))
self.l2cap_channel.write(bytes([len(data) - 1]) + data)
elif self.peer and self.rx_characteristic:
print(color('>>> [GATT RX]', 'yellow'))
asyncio.create_task(self.peer.write_value(self.rx_characteristic, data)) asyncio.create_task(self.peer.write_value(self.rx_characteristic, data))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def run(hci_transport, device_address, send_host, send_port, receive_host, receive_port): class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device):
super().__init__()
self.device = device
self.peer = None
self.tx_socket = None
self.tx_subscriber = None
self.rx_characteristic = None
# Register as a listener
device.listener = self
# Listen for incoming L2CAP CoC connections
psm = 0xFB
device.register_l2cap_channel_server(0xFB, self.on_coc)
print(f'### Listening for CoC connection on PSM {psm}')
# Setup the Gattlink service
self.rx_characteristic = Characteristic(
GG_GATTLINK_RX_CHARACTERISTIC_UUID,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=self.on_rx_write),
)
self.tx_characteristic = Characteristic(
GG_GATTLINK_TX_CHARACTERISTIC_UUID,
Characteristic.NOTIFY,
Characteristic.READABLE,
)
self.tx_characteristic.on('subscription', self.on_tx_subscription)
self.psm_characteristic = Characteristic(
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([psm, 0]),
)
gattlink_service = Service(
GG_GATTLINK_SERVICE_UUID,
[self.rx_characteristic, self.tx_characteristic, self.psm_characteristic],
)
device.add_services([gattlink_service])
device.advertising_data = bytes(
AdvertisingData(
[
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
bytes(
reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))
),
),
]
)
)
async def start(self):
await self.device.start_advertising()
# Called by asyncio when the UDP socket is created
def connection_made(self, transport):
self.transport = transport
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address):
print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
if self.l2cap_channel:
print(color('>>> [L2CAP]', 'yellow'))
self.l2cap_channel.write(bytes([len(data) - 1]) + data)
elif self.tx_subscriber:
print(color('>>> [GATT TX]', 'yellow'))
self.tx_characteristic.value = data
asyncio.create_task(self.device.notify_subscribers(self.tx_characteristic))
# Called when a write to the RX characteristic has been received
def on_rx_write(self, connection, data):
print(color(f'<<< [GATT RX]: {len(data)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(data)
# Called when the subscription to the TX characteristic has changed
def on_tx_subscription(self, peer, enabled):
print(
f'### [GATT TX] subscription from {peer}: {"enabled" if enabled else "disabled"}'
)
if enabled:
self.tx_subscriber = peer
else:
self.tx_subscriber = None
# Called when an L2CAP packet is received
def on_l2cap_packet(self, packet):
print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(packet)
# Called when a new connection is established
def on_coc(self, channel):
print('*** CoC Connection', channel)
self.l2cap_channel = channel
channel.sink = self.on_coc_sdu
# -----------------------------------------------------------------------------
async def run(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink): async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
# Instantiate a bridge object # Instantiate a bridge object
bridge = GattlinkNodeBridge() device = Device.with_hci('Bumble GG', device_address, hci_source, hci_sink)
# Instantiate a bridge object
if role_or_peer_address == 'node':
bridge = GattlinkNodeBridge(device)
else:
bridge = GattlinkHubBridge(device, role_or_peer_address)
# Create a UDP to RX bridge (receive from UDP, send to RX) # Create a UDP to RX bridge (receive from UDP, send to RX)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
await loop.create_datagram_endpoint( await loop.create_datagram_endpoint(
lambda: bridge, lambda: bridge, local_addr=(receive_host, receive_port)
local_addr=(receive_host, receive_port)
) )
# Create a UDP to TX bridge (receive from TX, send to UDP) # Create a UDP to TX bridge (receive from TX, send to UDP)
bridge.tx_socket, _ = await loop.create_datagram_endpoint( bridge.tx_socket, _ = await loop.create_datagram_endpoint(
lambda: asyncio.DatagramProtocol(), lambda: asyncio.DatagramProtocol(), remote_addr=(send_host, send_port)
remote_addr=(send_host, send_port)
) )
# Create a device to manage the host, with a custom listener
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
device.listener = bridge
await device.power_on() await device.power_on()
await bridge.start()
# Connect to the peer
# print(f'=== Connecting to {device_address}...')
# await device.connect(device_address)
# TODO move to class
gattlink_service = Service(
GG_GATTLINK_SERVICE_UUID,
[
Characteristic(
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
Characteristic.READ,
Characteristic.READABLE,
bytes([193, 0])
)
]
)
device.add_services([gattlink_service])
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, bytes(reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))))
])
)
await device.start_advertising()
# Wait until the source terminates # Wait until the source terminates
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
@@ -197,15 +348,44 @@ async def run(hci_transport, device_address, send_host, send_port, receive_host,
@click.command() @click.command()
@click.argument('hci_transport') @click.argument('hci_transport')
@click.argument('device_address') @click.argument('device_address')
@click.option('-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to') @click.argument('role_or_peer_address')
@click.option(
'-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to'
)
@click.option('-sp', '--send-port', type=int, default=9001, help='UDP port to send to') @click.option('-sp', '--send-port', type=int, default=9001, help='UDP port to send to')
@click.option('-rh', '--receive-host', type=str, default='127.0.0.1', help='UDP host to receive on') @click.option(
@click.option('-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on') '-rh',
def main(hci_transport, device_address, send_host, send_port, receive_host, receive_port): '--receive-host',
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) type=str,
asyncio.run(run(hci_transport, device_address, send_host, send_port, receive_host, receive_port)) default='127.0.0.1',
help='UDP host to receive on',
)
@click.option(
'-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on'
)
def main(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
):
asyncio.run(
run(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__': if __name__ == '__main__':
main() main()
+23 -8
View File
@@ -34,16 +34,26 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_main(): async def async_main():
if len(sys.argv) < 3: if len(sys.argv) < 3:
print('Usage: hci_bridge.py <host-transport-spec> <controller-transport-spec> [command-short-circuit-list]') print(
print('example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078') 'Usage: hci_bridge.py <host-transport-spec> <controller-transport-spec> [command-short-circuit-list]'
)
print(
'example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078'
)
return return
print('>>> connecting to HCI...') print('>>> connecting to HCI...')
async with await transport.open_transport_or_link(sys.argv[1]) as (hci_host_source, hci_host_sink): async with await transport.open_transport_or_link(sys.argv[1]) as (
hci_host_source,
hci_host_sink,
):
print('>>> connected') print('>>> connected')
print('>>> connecting to HCI...') print('>>> connecting to HCI...')
async with await transport.open_transport_or_link(sys.argv[2]) as (hci_controller_source, hci_controller_sink): async with await transport.open_transport_or_link(sys.argv[2]) as (
hci_controller_source,
hci_controller_sink,
):
print('>>> connected') print('>>> connected')
command_short_circuits = [] command_short_circuits = []
@@ -51,18 +61,23 @@ async def async_main():
for op_code_str in sys.argv[3].split(','): for op_code_str in sys.argv[3].split(','):
if ':' in op_code_str: if ':' in op_code_str:
ogf, ocf = op_code_str.split(':') ogf, ocf = op_code_str.split(':')
command_short_circuits.append(hci.hci_command_op_code(int(ogf, 16), int(ocf, 16))) command_short_circuits.append(
hci.hci_command_op_code(int(ogf, 16), int(ocf, 16))
)
else: else:
command_short_circuits.append(int(op_code_str, 16)) command_short_circuits.append(int(op_code_str, 16))
def host_to_controller_filter(hci_packet): def host_to_controller_filter(hci_packet):
if hci_packet.hci_packet_type == hci.HCI_COMMAND_PACKET and hci_packet.op_code in command_short_circuits: if (
hci_packet.hci_packet_type == hci.HCI_COMMAND_PACKET
and hci_packet.op_code in command_short_circuits
):
# Respond with a success response # Respond with a success response
logger.debug('short-circuiting packet') logger.debug('short-circuiting packet')
response = hci.HCI_Command_Complete_Event( response = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode=hci_packet.op_code, command_opcode=hci_packet.op_code,
return_parameters = bytes([hci.HCI_SUCCESS]) return_parameters=bytes([hci.HCI_SUCCESS]),
) )
# Return a packet with 'respond to sender' set to True # Return a packet with 'respond to sender' set to True
return (response.to_bytes(), True) return (response.to_bytes(), True)
@@ -73,7 +88,7 @@ async def async_main():
hci_controller_source, hci_controller_source,
hci_controller_sink, hci_controller_sink,
host_to_controller_filter, host_to_controller_filter,
None None,
) )
await asyncio.get_running_loop().create_future() await asyncio.get_running_loop().create_future()
+349
View File
@@ -0,0 +1,349 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import click
import logging
import os
from colors import color
from bumble.transport import open_transport_or_link
from bumble.device import Device
from bumble.utils import FlowControlAsyncPipe
from bumble.hci import HCI_Constant
# -----------------------------------------------------------------------------
class ServerBridge:
"""
L2CAP CoC server bridge: waits for a peer to connect an L2CAP CoC channel
on a specified PSM. When the connection is made, the bridge connects a TCP
socket to a remote host and bridges the data in both directions, with flow
control.
When the L2CAP CoC channel is closed, the bridge disconnects the TCP socket
and waits for a new L2CAP CoC channel to be connected.
When the TCP connection is closed by the TCP server, XXXX
"""
def __init__(self, psm, max_credits, mtu, mps, tcp_host, tcp_port):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device):
# Listen for incoming L2CAP CoC connections
device.register_l2cap_channel_server(
psm=self.psm,
server=self.on_coc,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
)
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
def on_ble_connection(connection):
def on_ble_disconnection(reason):
print(
color('@@@ Bluetooth disconnection:', 'red'),
HCI_Constant.error_name(reason),
)
print(color('@@@ Bluetooth connection:', 'green'), connection)
connection.on('disconnection', on_ble_disconnection)
device.on('connection', on_ble_connection)
await device.start_advertising(auto_restart=True)
# Called when a new L2CAP connection is established
def on_coc(self, l2cap_channel):
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
class Pipe:
def __init__(self, bridge, l2cap_channel):
self.bridge = bridge
self.tcp_transport = None
self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_coc_sdu
async def connect_to_tcp(self):
# Connect to the TCP server
print(
color(
f'### Connecting to TCP {self.bridge.tcp_host}:{self.bridge.tcp_port}...',
'yellow',
)
)
class TcpClientProtocol(asyncio.Protocol):
def __init__(self, pipe):
self.pipe = pipe
def connection_lost(self, error):
print(color(f'!!! TCP connection lost: {error}', 'red'))
if self.pipe.l2cap_channel is not None:
asyncio.create_task(self.pipe.l2cap_channel.disconnect())
def data_received(self, data):
print(f'<<< Received on TCP: {len(data)}')
self.pipe.l2cap_channel.write(data)
try:
(
self.tcp_transport,
_,
) = await asyncio.get_running_loop().create_connection(
lambda: TcpClientProtocol(self),
host=self.bridge.tcp_host,
port=self.bridge.tcp_port,
)
print(color('### Connected', 'green'))
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
await self.l2cap_channel.disconnect()
def on_l2cap_close(self):
self.l2cap_channel = None
if self.tcp_transport is not None:
self.tcp_transport.close()
def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
if self.tcp_transport is None:
print(color('!!! TCP socket not open, dropping', 'red'))
return
self.tcp_transport.write(sdu)
pipe = Pipe(self, l2cap_channel)
asyncio.create_task(pipe.connect_to_tcp())
# -----------------------------------------------------------------------------
class ClientBridge:
"""
L2CAP CoC client bridge: connects to a BLE device, then waits for an inbound
TCP connection on a specified port number. When a TCP client connects, an
L2CAP CoC channel connection to the BLE device is established, and the data
is bridged in both directions, with flow control.
When the TCP connection is closed by the client, the L2CAP CoC channel is
disconnected, but the connection to the BLE device remains, ready for a new
TCP client to connect.
When the L2CAP CoC channel is closed, XXXX
"""
READ_CHUNK_SIZE = 4096
def __init__(self, psm, max_credits, mtu, mps, address, tcp_host, tcp_port):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.address = address
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device):
print(color(f'### Connecting to {self.address}...', 'yellow'))
connection = await device.connect(self.address)
print(color('### Connected', 'green'))
# Called when the BLE connection is disconnected
def on_ble_disconnection(reason):
print(
color('@@@ Bluetooth disconnection:', 'red'),
HCI_Constant.error_name(reason),
)
connection.on('disconnection', on_ble_disconnection)
# Called when a TCP connection is established
async def on_tcp_connection(reader, writer):
peername = writer.get_extra_info('peername')
print(color(f'<<< TCP connection from {peername}', 'magenta'))
def on_coc_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
l2cap_to_tcp_pipe.write(sdu)
def on_l2cap_close():
print(color('*** L2CAP channel closed', 'red'))
l2cap_to_tcp_pipe.stop()
writer.close()
# Connect a new L2CAP channel
print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
try:
l2cap_channel = await connection.open_l2cap_channel(
psm=self.psm,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
)
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
writer.close()
return
l2cap_channel.sink = on_coc_sdu
l2cap_channel.on('close', on_l2cap_close)
# Start a flow control pipe from L2CAP to TCP
l2cap_to_tcp_pipe = FlowControlAsyncPipe(
l2cap_channel.pause_reading,
l2cap_channel.resume_reading,
writer.write,
writer.drain,
)
l2cap_to_tcp_pipe.start()
# Pipe data from TCP to L2CAP
while True:
try:
data = await reader.read(self.READ_CHUNK_SIZE)
if len(data) == 0:
print(color('!!! End of stream', 'red'))
await l2cap_channel.disconnect()
return
print(color(f'<<< [TCP DATA]: {len(data)} bytes', 'blue'))
l2cap_channel.write(data)
await l2cap_channel.drain()
except Exception as error:
print(f'!!! Exception: {error}')
break
writer.close()
print(color('~~~ Bye bye', 'magenta'))
await asyncio.start_server(
on_tcp_connection,
host=self.tcp_host if self.tcp_host != '_' else None,
port=self.tcp_port,
)
print(
color(
f'### Listening for TCP connections on port {self.tcp_port}', 'magenta'
)
)
# -----------------------------------------------------------------------------
async def run(device_config, hci_transport, bridge):
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
# Let's go
await device.power_on()
await bridge.start(device)
# Wait until the transport terminates
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
@click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
@click.option(
'--l2cap-coc-max-credits',
help='Maximum L2CAP CoC Credits',
type=click.IntRange(1, 65535),
default=128,
)
@click.option(
'--l2cap-coc-mtu',
help='L2CAP CoC MTU',
type=click.IntRange(23, 65535),
default=1022,
)
@click.option(
'--l2cap-coc-mps',
help='L2CAP CoC MPS',
type=click.IntRange(23, 65533),
default=1024,
)
def cli(
context,
device_config,
hci_transport,
psm,
l2cap_coc_max_credits,
l2cap_coc_mtu,
l2cap_coc_mps,
):
context.ensure_object(dict)
context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport
context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_coc_max_credits
context.obj['mtu'] = l2cap_coc_mtu
context.obj['mps'] = l2cap_coc_mps
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.option('--tcp-host', help='TCP host', default='localhost')
@click.option('--tcp-port', help='TCP port', default=9544)
def server(context, tcp_host, tcp_port):
bridge = ServerBridge(
context.obj['psm'],
context.obj['max_credits'],
context.obj['mtu'],
context.obj['mps'],
tcp_host,
tcp_port,
)
asyncio.run(run(context.obj['device_config'], context.obj['hci_transport'], bridge))
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.argument('bluetooth-address')
@click.option('--tcp-host', help='TCP host', default='_')
@click.option('--tcp-port', help='TCP port', default=9543)
def client(context, bluetooth_address, tcp_host, tcp_port):
bridge = ClientBridge(
context.obj['psm'],
context.obj['max_credits'],
context.obj['mtu'],
context.obj['mps'],
bluetooth_address,
tcp_host,
tcp_port,
)
asyncio.run(run(context.obj['device_config'], context.obj['hci_transport'], bridge))
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__':
cli(obj={})
+13 -7
View File
@@ -145,7 +145,9 @@ class Room:
# This is an RPC request # This is an RPC request
await self.on_rpc_request(connection, message) await self.on_rpc_request(connection, message)
else: else:
await connection.send_message(f'result:{error_to_json("error: invalid message")}') await connection.send_message(
f'result:{error_to_json("error: invalid message")}'
)
async def broadcast_message(self, sender, message): async def broadcast_message(self, sender, message):
''' '''
@@ -155,7 +157,9 @@ class Room:
async def on_rpc_request(self, connection, message): async def on_rpc_request(self, connection, message):
command, *params = message.split(' ', 1) command, *params = message.split(' ', 1)
if handler := getattr(self, f'on_{command[1:].lower().replace("-","_")}_command', None): if handler := getattr(
self, f'on_{command[1:].lower().replace("-","_")}_command', None
):
try: try:
result = await handler(connection, params) result = await handler(connection, params)
except Exception as error: except Exception as error:
@@ -192,7 +196,9 @@ class Room:
current_address = connection.address current_address = connection.address
new_address = params[0] new_address = params[0]
connection.set_address(new_address) connection.set_address(new_address)
await self.broadcast_message(connection, f'address-changed:from={current_address},to={new_address}') await self.broadcast_message(
connection, f'address-changed:from={current_address},to={new_address}'
)
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------
@@ -252,15 +258,15 @@ def main():
arg_parser = argparse.ArgumentParser(description='Bumble Link Relay') arg_parser = argparse.ArgumentParser(description='Bumble Link Relay')
arg_parser.add_argument('--log-level', default='INFO', help='logger level') arg_parser.add_argument('--log-level', default='INFO', help='logger level')
arg_parser.add_argument('--log-config', help='logger config file (YAML)') arg_parser.add_argument('--log-config', help='logger config file (YAML)')
arg_parser.add_argument('--port', arg_parser.add_argument(
type = int, '--port', type=int, default=DEFAULT_RELAY_PORT, help='Port to listen on'
default = DEFAULT_RELAY_PORT, )
help = 'Port to listen on')
args = arg_parser.parse_args() args = arg_parser.parse_args()
# Setup logger # Setup logger
if args.log_config: if args.log_config:
from logging import config from logging import config
config.fileConfig(args.log_config) config.fileConfig(args.log_config)
else: else:
logging.basicConfig(level=getattr(logging, args.log_level.upper())) logging.basicConfig(level=getattr(logging, args.log_level.upper()))
+85 -24
View File
@@ -33,25 +33,27 @@ from bumble.gatt import (
GATT_GENERIC_ACCESS_SERVICE, GATT_GENERIC_ACCESS_SERVICE,
Service, Service,
Characteristic, Characteristic,
CharacteristicValue CharacteristicValue,
) )
from bumble.att import ( from bumble.att import (
ATT_Error, ATT_Error,
ATT_INSUFFICIENT_AUTHENTICATION_ERROR, ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
ATT_INSUFFICIENT_ENCRYPTION_ERROR ATT_INSUFFICIENT_ENCRYPTION_ERROR,
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Delegate(PairingDelegate): class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, prompt): def __init__(self, mode, connection, capability_string, prompt):
super().__init__({ super().__init__(
{
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY, 'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY, 'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT, 'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT, 'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
'none': PairingDelegate.NO_OUTPUT_NO_INPUT 'none': PairingDelegate.NO_OUTPUT_NO_INPUT,
}[capability_string.lower()]) }[capability_string.lower()]
)
self.mode = mode self.mode = mode
self.peer = Peer(connection) self.peer = Peer(connection)
@@ -103,7 +105,11 @@ class Delegate(PairingDelegate):
print(color(f'### Pairing with {self.peer_name}', 'yellow')) print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow')) print(color('###-----------------------------------', 'yellow'))
while True: while True:
response = await aioconsole.ainput(color(f'>>> Does the other device display {number:0{digits}}? ', 'yellow')) response = await aioconsole.ainput(
color(
f'>>> Does the other device display {number:0{digits}}? ', 'yellow'
)
)
response = response.lower().strip() response = response.lower().strip()
if response == 'yes': if response == 'yes':
return True return True
@@ -149,7 +155,9 @@ async def get_peer_name(peer, mode):
if not services: if not services:
return None return None
values = await peer.read_characteristics_by_uuid(GATT_DEVICE_NAME_CHARACTERISTIC, services[0]) values = await peer.read_characteristics_by_uuid(
GATT_DEVICE_NAME_CHARACTERISTIC, services[0]
)
if values: if values:
return values[0].decode('utf-8') return values[0].decode('utf-8')
@@ -190,7 +198,7 @@ def on_connection(connection, request):
# Listen for encryption changes # Listen for encryption changes
connection.on( connection.on(
'connection_encryption_change', 'connection_encryption_change',
lambda: on_connection_encryption_change(connection) lambda: on_connection_encryption_change(connection),
) )
# Request pairing if needed # Request pairing if needed
@@ -202,7 +210,12 @@ def on_connection(connection, request):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_connection_encryption_change(connection): def on_connection_encryption_change(connection):
print(color('@@@-----------------------------------', 'blue')) print(color('@@@-----------------------------------', 'blue'))
print(color(f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted', 'blue')) print(
color(
f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted',
'blue',
)
)
print(color('@@@-----------------------------------', 'blue')) print(color('@@@-----------------------------------', 'blue'))
@@ -241,7 +254,7 @@ async def pair(
keystore_file, keystore_file,
device_config, device_config,
hci_transport, hci_transport,
address_or_name address_or_name,
): ):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink): async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
@@ -272,9 +285,11 @@ async def pair(
'552957FB-CF1F-4A31-9535-E78847E1A714', '552957FB-CF1F-4A31-9535-E78847E1A714',
Characteristic.READ | Characteristic.WRITE, Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE, Characteristic.READABLE | Characteristic.WRITEABLE,
CharacteristicValue(read=read_with_error, write=write_with_error) CharacteristicValue(
read=read_with_error, write=write_with_error
),
) )
] ],
) )
) )
@@ -288,10 +303,7 @@ async def pair(
# Set up a pairing config factory # Set up a pairing config factory
device.pairing_config_factory = lambda connection: PairingConfig( device.pairing_config_factory = lambda connection: PairingConfig(
sc, sc, mitm, bond, Delegate(mode, connection, io, prompt)
mitm,
bond,
Delegate(mode, connection, io, prompt)
) )
# Connect to a peer or wait for a connection # Connect to a peer or wait for a connection
@@ -319,21 +331,70 @@ async def pair(
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@click.command() @click.command()
@click.option('--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True) @click.option(
@click.option('--sc', type=bool, default=True, help='Use the Secure Connections protocol', show_default=True) '--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True
@click.option('--mitm', type=bool, default=True, help='Request MITM protection', show_default=True) )
@click.option('--bond', type=bool, default=True, help='Enable bonding', show_default=True) @click.option(
@click.option('--io', type=click.Choice(['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']), default='display+keyboard', show_default=True) '--sc',
type=bool,
default=True,
help='Use the Secure Connections protocol',
show_default=True,
)
@click.option(
'--mitm', type=bool, default=True, help='Request MITM protection', show_default=True
)
@click.option(
'--bond', type=bool, default=True, help='Enable bonding', show_default=True
)
@click.option(
'--io',
type=click.Choice(
['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']
),
default='display+keyboard',
show_default=True,
)
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request') @click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
@click.option('--request', is_flag=True, help='Request that the connecting peer initiate pairing') @click.option(
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
)
@click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing') @click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing')
@click.option('--keystore-file', help='File in which to store the pairing keys') @click.option('--keystore-file', help='File in which to store the pairing keys')
@click.argument('device-config') @click.argument('device-config')
@click.argument('hci_transport') @click.argument('hci_transport')
@click.argument('address-or-name', required=False) @click.argument('address-or-name', required=False)
def main(mode, sc, mitm, bond, io, prompt, request, print_keys, keystore_file, device_config, hci_transport, address_or_name): def main(
mode,
sc,
mitm,
bond,
io,
prompt,
request,
print_keys,
keystore_file,
device_config,
hci_transport,
address_or_name,
):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(pair(mode, sc, mitm, bond, io, prompt, request, print_keys, keystore_file, device_config, hci_transport, address_or_name)) asyncio.run(
pair(
mode,
sc,
mitm,
bond,
io,
prompt,
request,
print_keys,
keystore_file,
device_config,
hci_transport,
address_or_name,
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+92 -25
View File
@@ -25,8 +25,8 @@ from bumble.device import Device
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.keys import JsonKeyStore from bumble.keys import JsonKeyStore
from bumble.smp import AddressResolver from bumble.smp import AddressResolver
from bumble.hci import HCI_LE_Advertising_Report_Event from bumble.device import Advertisement
from bumble.core import AdvertisingData from bumble.hci import HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -48,19 +48,24 @@ class AdvertisementPrinter:
self.min_rssi = min_rssi self.min_rssi = min_rssi
self.resolver = resolver self.resolver = resolver
def print_advertisement(self, address, address_color, ad_data, rssi): def print_advertisement(self, advertisement):
if self.min_rssi is not None and rssi < self.min_rssi: address = advertisement.address
address_color = 'yellow' if advertisement.is_connectable else 'red'
if self.min_rssi is not None and advertisement.rssi < self.min_rssi:
return return
address_qualifier = '' address_qualifier = ''
resolution_qualifier = '' resolution_qualifier = ''
if self.resolver and address.is_resolvable: if self.resolver and advertisement.address.is_resolvable:
resolved = self.resolver.resolve(address) resolved = self.resolver.resolve(advertisement.address)
if resolved is not None: if resolved is not None:
resolution_qualifier = f'(resolved from {address})' resolution_qualifier = f'(resolved from {advertisement.address})'
address = resolved address = resolved
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[address.address_type] address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
address.address_type
]
if address.is_public: if address.is_public:
type_color = 'cyan' type_color = 'cyan'
else: else:
@@ -74,18 +79,31 @@ class AdvertisementPrinter:
type_color = 'blue' type_color = 'blue'
address_qualifier = '(non-resolvable)' address_qualifier = '(non-resolvable)'
rssi_bar = make_rssi_bar(rssi)
separator = '\n ' separator = '\n '
print(f'>>> {color(address, address_color)} [{color(address_type_string, type_color)}]{address_qualifier}{resolution_qualifier}:{separator}RSSI:{rssi:4} {rssi_bar}{separator}{ad_data.to_string(separator)}\n') rssi_bar = make_rssi_bar(advertisement.rssi)
if not advertisement.is_legacy:
phy_info = (
f'PHY: {HCI_Constant.le_phy_name(advertisement.primary_phy)}/'
f'{HCI_Constant.le_phy_name(advertisement.secondary_phy)} '
f'{separator}'
)
else:
phy_info = ''
def on_advertisement(self, address, ad_data, rssi, connectable): print(
address_color = 'yellow' if connectable else 'red' f'>>> {color(address, address_color)} '
self.print_advertisement(address, address_color, ad_data, rssi) f'[{color(address_type_string, type_color)}]{address_qualifier}{resolution_qualifier}:{separator}'
f'{phy_info}'
f'RSSI:{advertisement.rssi:4} {rssi_bar}{separator}'
f'{advertisement.data.to_string(separator)}\n'
)
def on_advertising_report(self, address, ad_data, rssi, event_type): def on_advertisement(self, advertisement):
print(f'{color("EVENT", "green")}: {HCI_LE_Advertising_Report_Event.event_type_name(event_type)}') self.print_advertisement(advertisement)
ad_data = AdvertisingData.from_bytes(ad_data)
self.print_advertisement(address, 'yellow', ad_data, rssi) def on_advertising_report(self, report):
print(f'{color("EVENT", "green")}: {report.event_type_string()}')
self.print_advertisement(Advertisement.from_advertising_report(report))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -94,20 +112,25 @@ async def scan(
passive, passive,
scan_interval, scan_interval,
scan_window, scan_window,
phy,
filter_duplicates, filter_duplicates,
raw, raw,
keystore_file, keystore_file,
device_config, device_config,
transport transport,
): ):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(transport) as (hci_source, hci_sink): async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
if device_config: if device_config:
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink) device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else: else:
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink) device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
if keystore_file: if keystore_file:
keystore = JsonKeyStore(namespace=None, filename=keystore_file) keystore = JsonKeyStore(namespace=None, filename=keystore_file)
@@ -126,11 +149,18 @@ async def scan(
device.on('advertisement', printer.on_advertisement) device.on('advertisement', printer.on_advertisement)
await device.power_on() await device.power_on()
if phy is None:
scanning_phys = [HCI_LE_1M_PHY, HCI_LE_CODED_PHY]
else:
scanning_phys = [{'1m': HCI_LE_1M_PHY, 'coded': HCI_LE_CODED_PHY}[phy]]
await device.start_scanning( await device.start_scanning(
active=(not passive), active=(not passive),
scan_interval=scan_interval, scan_interval=scan_interval,
scan_window=scan_window, scan_window=scan_window,
filter_duplicates=filter_duplicates filter_duplicates=filter_duplicates,
scanning_phys=scanning_phys,
) )
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
@@ -142,14 +172,51 @@ async def scan(
@click.option('--passive', is_flag=True, default=False, help='Perform passive scanning') @click.option('--passive', is_flag=True, default=False, help='Perform passive scanning')
@click.option('--scan-interval', type=int, default=60, help='Scan interval') @click.option('--scan-interval', type=int, default=60, help='Scan interval')
@click.option('--scan-window', type=int, default=60, help='Scan window') @click.option('--scan-window', type=int, default=60, help='Scan window')
@click.option('--filter-duplicates', type=bool, default=True, help='Filter duplicates at the controller level') @click.option(
@click.option('--raw', is_flag=True, default=False, help='Listen for raw advertising reports instead of processed ones') '--phy', type=click.Choice(['1m', 'coded']), help='Only scan on the specified PHY'
)
@click.option(
'--filter-duplicates',
type=bool,
default=True,
help='Filter duplicates at the controller level',
)
@click.option(
'--raw',
is_flag=True,
default=False,
help='Listen for raw advertising reports instead of processed ones',
)
@click.option('--keystore-file', help='Keystore file to use when resolving addresses') @click.option('--keystore-file', help='Keystore file to use when resolving addresses')
@click.option('--device-config', help='Device config file for the scanning device') @click.option('--device-config', help='Device config file for the scanning device')
@click.argument('transport') @click.argument('transport')
def main(min_rssi, passive, scan_interval, scan_window, filter_duplicates, raw, keystore_file, device_config, transport): def main(
min_rssi,
passive,
scan_interval,
scan_window,
phy,
filter_duplicates,
raw,
keystore_file,
device_config,
transport,
):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(scan(min_rssi, passive, scan_interval, scan_window, filter_duplicates, raw, keystore_file, device_config, transport)) asyncio.run(
scan(
min_rssi,
passive,
scan_interval,
scan_window,
phy,
filter_duplicates,
raw,
keystore_file,
device_config,
transport,
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+22 -6
View File
@@ -41,9 +41,16 @@ class SnoopPacketReader:
# Read the header # Read the header
identification_pattern = source.read(8) identification_pattern = source.read(8)
if identification_pattern.hex().lower() != '6274736e6f6f7000': if identification_pattern.hex().lower() != '6274736e6f6f7000':
raise ValueError('not a valid snoop file, unexpected identification pattern') raise ValueError(
(self.version_number, self.data_link_type) = struct.unpack('>II', source.read(8)) 'not a valid snoop file, unexpected identification pattern'
if self.data_link_type != self.DATALINK_H4 and self.data_link_type != self.DATALINK_H1: )
(self.version_number, self.data_link_type) = struct.unpack(
'>II', source.read(8)
)
if (
self.data_link_type != self.DATALINK_H4
and self.data_link_type != self.DATALINK_H1
):
raise ValueError(f'datalink type {self.data_link_type} not supported') raise ValueError(f'datalink type {self.data_link_type} not supported')
def next_packet(self): def next_packet(self):
@@ -57,7 +64,7 @@ class SnoopPacketReader:
packet_flags, packet_flags,
cumulative_drops, cumulative_drops,
timestamp_seconds, timestamp_seconds,
timestamp_microsecond timestamp_microsecond,
) = struct.unpack('>IIIIII', header) ) = struct.unpack('>IIIIII', header)
# Abort on truncated packets # Abort on truncated packets
@@ -79,7 +86,10 @@ class SnoopPacketReader:
else: else:
packet_type = hci.HCI_ACL_DATA_PACKET packet_type = hci.HCI_ACL_DATA_PACKET
return (packet_flags & 1, bytes([packet_type]) + self.source.read(included_length)) return (
packet_flags & 1,
bytes([packet_type]) + self.source.read(included_length),
)
else: else:
return (packet_flags & 1, self.source.read(included_length)) return (packet_flags & 1, self.source.read(included_length))
@@ -88,7 +98,12 @@ class SnoopPacketReader:
# Main # Main
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@click.command() @click.command()
@click.option('--format', type=click.Choice(['h4', 'snoop']), default='h4', help='Format of the input file') @click.option(
'--format',
type=click.Choice(['h4', 'snoop']),
default='h4',
help='Format of the input file',
)
@click.argument('filename') @click.argument('filename')
def main(format, filename): def main(format, filename):
input = open(filename, 'rb') input = open(filename, 'rb')
@@ -97,6 +112,7 @@ def main(format, filename):
def read_next_packet(): def read_next_packet():
(0, packet_reader.next_packet()) (0, packet_reader.next_packet())
else: else:
packet_reader = SnoopPacketReader(input) packet_reader = SnoopPacketReader(input)
read_next_packet = packet_reader.next_packet read_next_packet = packet_reader.next_packet
+58 -22
View File
@@ -69,13 +69,13 @@ USB_DEVICE_CLASSES = {
0x01: 'Bluetooth', 0x01: 'Bluetooth',
0x02: 'UWB', 0x02: 'UWB',
0x03: 'Remote NDIS', 0x03: 'Remote NDIS',
0x04: 'Bluetooth AMP' 0x04: 'Bluetooth AMP',
}
} }
},
), ),
0xEF: 'Miscellaneous', 0xEF: 'Miscellaneous',
0xFE: 'Application Specific', 0xFE: 'Application Specific',
0xFF: 'Vendor Specific' 0xFF: 'Vendor Specific',
} }
USB_ENDPOINT_IN = 0x80 USB_ENDPOINT_IN = 0x80
@@ -84,7 +84,7 @@ USB_ENDPOINT_TYPES = ['CONTROL', 'ISOCHRONOUS', 'BULK', 'INTERRUPT']
USB_BT_HCI_CLASS_TUPLE = ( USB_BT_HCI_CLASS_TUPLE = (
USB_DEVICE_CLASS_WIRELESS_CONTROLLER, USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
USB_DEVICE_SUBCLASS_RF_CONTROLLER, USB_DEVICE_SUBCLASS_RF_CONTROLLER,
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
) )
@@ -95,18 +95,24 @@ def show_device_details(device):
for interface in configuration: for interface in configuration:
for setting in interface: for setting in interface:
alternateSetting = setting.getAlternateSetting() alternateSetting = setting.getAlternateSetting()
suffix = f'/{alternateSetting}' if interface.getNumSettings() > 1 else '' suffix = (
f'/{alternateSetting}' if interface.getNumSettings() > 1 else ''
)
(class_string, subclass_string) = get_class_info( (class_string, subclass_string) = get_class_info(
setting.getClass(), setting.getClass(), setting.getSubClass(), setting.getProtocol()
setting.getSubClass(),
setting.getProtocol()
) )
details = f'({class_string}, {subclass_string})' details = f'({class_string}, {subclass_string})'
print(f' Interface: {setting.getNumber()}{suffix} {details}') print(f' Interface: {setting.getNumber()}{suffix} {details}')
for endpoint in setting: for endpoint in setting:
endpoint_type = USB_ENDPOINT_TYPES[endpoint.getAttributes() & 3] endpoint_type = USB_ENDPOINT_TYPES[endpoint.getAttributes() & 3]
endpoint_direction = 'OUT' if (endpoint.getAddress() & USB_ENDPOINT_IN == 0) else 'IN' endpoint_direction = (
print(f' Endpoint 0x{endpoint.getAddress():02X}: {endpoint_type} {endpoint_direction}') 'OUT'
if (endpoint.getAddress() & USB_ENDPOINT_IN == 0)
else 'IN'
)
print(
f' Endpoint 0x{endpoint.getAddress():02X}: {endpoint_type} {endpoint_direction}'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -135,7 +141,11 @@ def get_class_info(cls, subclass, protocol):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def is_bluetooth_hci(device): def is_bluetooth_hci(device):
# Check if the device class indicates a match # Check if the device class indicates a match
if (device.getDeviceClass(), device.getDeviceSubClass(), device.getDeviceProtocol()) == USB_BT_HCI_CLASS_TUPLE: if (
device.getDeviceClass(),
device.getDeviceSubClass(),
device.getDeviceProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True return True
# If the device class is 'Device', look for a matching interface # If the device class is 'Device', look for a matching interface
@@ -143,7 +153,11 @@ def is_bluetooth_hci(device):
for configuration in device: for configuration in device:
for interface in configuration: for interface in configuration:
for setting in interface: for setting in interface:
if (setting.getClass(), setting.getSubClass(), setting.getProtocol()) == USB_BT_HCI_CLASS_TUPLE: if (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True return True
return False return False
@@ -167,9 +181,7 @@ def main(verbose):
device_id = (device.getVendorID(), device.getProductID()) device_id = (device.getVendorID(), device.getProductID())
(device_class_string, device_subclass_string) = get_class_info( (device_class_string, device_subclass_string) = get_class_info(
device_class, device_class, device_subclass, device_protocol
device_subclass,
device_protocol
) )
try: try:
@@ -198,7 +210,9 @@ def main(verbose):
# Compute the different ways this can be referenced as a Bumble transport # Compute the different ways this can be referenced as a Bumble transport
bumble_transport_names = [] bumble_transport_names = []
basic_transport_name = f'usb:{device.getVendorID():04X}:{device.getProductID():04X}' basic_transport_name = (
f'usb:{device.getVendorID():04X}:{device.getProductID():04X}'
)
if device_is_bluetooth_hci: if device_is_bluetooth_hci:
bumble_transport_names.append(f'usb:{bluetooth_device_count - 1}') bumble_transport_names.append(f'usb:{bluetooth_device_count - 1}')
@@ -206,17 +220,39 @@ def main(verbose):
if device_id not in devices: if device_id not in devices:
bumble_transport_names.append(basic_transport_name) bumble_transport_names.append(basic_transport_name)
else: else:
bumble_transport_names.append(f'{basic_transport_name}#{len(devices[device_id])}') bumble_transport_names.append(
f'{basic_transport_name}#{len(devices[device_id])}'
)
if device_serial_number is not None: if device_serial_number is not None:
if device_id not in devices or device_serial_number not in devices[device_id]: if (
bumble_transport_names.append(f'{basic_transport_name}/{device_serial_number}') device_id not in devices
or device_serial_number not in devices[device_id]
):
bumble_transport_names.append(
f'{basic_transport_name}/{device_serial_number}'
)
# Print the results # Print the results
print(color(f'ID {device.getVendorID():04X}:{device.getProductID():04X}', fg=fg_color, bg=bg_color)) print(
color(
f'ID {device.getVendorID():04X}:{device.getProductID():04X}',
fg=fg_color,
bg=bg_color,
)
)
if bumble_transport_names: if bumble_transport_names:
print(color(' Bumble Transport Names:', 'blue'), ' or '.join(color(x, 'cyan' if device_is_bluetooth_hci else 'red') for x in bumble_transport_names)) print(
print(color(' Bus/Device: ', 'green'), f'{device.getBusNumber():03}/{device.getDeviceAddress():03}') color(' Bumble Transport Names:', 'blue'),
' or '.join(
color(x, 'cyan' if device_is_bluetooth_hci else 'red')
for x in bumble_transport_names
),
)
print(
color(' Bus/Device: ', 'green'),
f'{device.getBusNumber():03}/{device.getDeviceAddress():03}',
)
print(color(' Class: ', 'green'), device_class_string) print(color(' Class: ', 'green'), device_class_string)
print(color(' Subclass/Protocol: ', 'green'), device_subclass_string) print(color(' Subclass/Protocol: ', 'green'), device_subclass_string)
if device_serial_number is not None: if device_serial_number is not None:
+159 -120
View File
@@ -30,7 +30,7 @@ from .sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
) )
from .core import ( from .core import (
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
@@ -38,7 +38,7 @@ from .core import (
BT_AUDIO_SINK_SERVICE, BT_AUDIO_SINK_SERVICE,
BT_AVDTP_PROTOCOL_ID, BT_AVDTP_PROTOCOL_ID,
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE, BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
name_or_number name_or_number,
) )
@@ -51,6 +51,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
A2DP_SBC_CODEC_TYPE = 0x00 A2DP_SBC_CODEC_TYPE = 0x00
A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = 0x01 A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = 0x01
@@ -127,6 +128,8 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE' MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
} }
# fmt: on
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def flags_to_list(flags, values): def flags_to_list(flags, values):
@@ -140,58 +143,98 @@ def flags_to_list(flags, values):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)): def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)):
from .avdtp import AVDTP_PSM from .avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1] version_int = version[0] << 8 | version[1]
return [ return [
ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(service_record_handle)), ServiceAttribute(
ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([ SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT) DataElement.unsigned_integer_32(service_record_handle),
])), ),
ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([ ServiceAttribute(
DataElement.uuid(BT_AUDIO_SOURCE_SERVICE) SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
])), DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([ ),
DataElement.sequence([ ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(BT_AUDIO_SOURCE_SERVICE)]),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_L2CAP_PROTOCOL_ID), DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(AVDTP_PSM) DataElement.unsigned_integer_16(AVDTP_PSM),
]), ]
DataElement.sequence([ ),
DataElement.sequence(
[
DataElement.uuid(BT_AVDTP_PROTOCOL_ID), DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int) DataElement.unsigned_integer_16(version_int),
]) ]
])), ),
ServiceAttribute(SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([ ]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int) DataElement.unsigned_integer_16(version_int),
])), ]
),
),
] ]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)): def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
from .avdtp import AVDTP_PSM from .avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1] version_int = version[0] << 8 | version[1]
return [ return [
ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(service_record_handle)), ServiceAttribute(
ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([ SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT) DataElement.unsigned_integer_32(service_record_handle),
])), ),
ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([ ServiceAttribute(
DataElement.uuid(BT_AUDIO_SINK_SERVICE) SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
])), DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([ ),
DataElement.sequence([ ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(BT_AUDIO_SINK_SERVICE)]),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_L2CAP_PROTOCOL_ID), DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(AVDTP_PSM) DataElement.unsigned_integer_16(AVDTP_PSM),
]), ]
DataElement.sequence([ ),
DataElement.sequence(
[
DataElement.uuid(BT_AVDTP_PROTOCOL_ID), DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int) DataElement.unsigned_integer_16(version_int),
]) ]
])), ),
ServiceAttribute(SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([ ]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int) DataElement.unsigned_integer_16(version_int),
])), ]
),
),
] ]
@@ -206,8 +249,8 @@ class SbcMediaCodecInformation(
'subbands', 'subbands',
'allocation_method', 'allocation_method',
'minimum_bitpool_value', 'minimum_bitpool_value',
'maximum_bitpool_value' 'maximum_bitpool_value',
] ],
) )
): ):
''' '''
@@ -215,36 +258,25 @@ class SbcMediaCodecInformation(
''' '''
BIT_FIELDS = 'u4u4u4u2u2u8u8' BIT_FIELDS = 'u4u4u4u2u2u8u8'
SAMPLING_FREQUENCY_BITS = { SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
16000: 1 << 3,
32000: 1 << 2,
44100: 1 << 1,
48000: 1
}
CHANNEL_MODE_BITS = { CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3, SBC_MONO_CHANNEL_MODE: 1 << 3,
SBC_DUAL_CHANNEL_MODE: 1 << 2, SBC_DUAL_CHANNEL_MODE: 1 << 2,
SBC_STEREO_CHANNEL_MODE: 1 << 1, SBC_STEREO_CHANNEL_MODE: 1 << 1,
SBC_JOINT_STEREO_CHANNEL_MODE: 1 SBC_JOINT_STEREO_CHANNEL_MODE: 1,
}
BLOCK_LENGTH_BITS = {
4: 1 << 3,
8: 1 << 2,
12: 1 << 1,
16: 1
}
SUBBANDS_BITS = {
4: 1 << 1,
8: 1
} }
BLOCK_LENGTH_BITS = {4: 1 << 3, 8: 1 << 2, 12: 1 << 1, 16: 1}
SUBBANDS_BITS = {4: 1 << 1, 8: 1}
ALLOCATION_METHOD_BITS = { ALLOCATION_METHOD_BITS = {
SBC_SNR_ALLOCATION_METHOD: 1 << 1, SBC_SNR_ALLOCATION_METHOD: 1 << 1,
SBC_LOUDNESS_ALLOCATION_METHOD: 1 SBC_LOUDNESS_ALLOCATION_METHOD: 1,
} }
@staticmethod @staticmethod
def from_bytes(data): def from_bytes(data):
return SbcMediaCodecInformation(*bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data)) return SbcMediaCodecInformation(
*bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data)
)
@classmethod @classmethod
def from_discrete_values( def from_discrete_values(
@@ -255,7 +287,7 @@ class SbcMediaCodecInformation(
subbands, subbands,
allocation_method, allocation_method,
minimum_bitpool_value, minimum_bitpool_value,
maximum_bitpool_value maximum_bitpool_value,
): ):
return SbcMediaCodecInformation( return SbcMediaCodecInformation(
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency], sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
@@ -264,7 +296,7 @@ class SbcMediaCodecInformation(
subbands=cls.SUBBANDS_BITS[subbands], subbands=cls.SUBBANDS_BITS[subbands],
allocation_method=cls.ALLOCATION_METHOD_BITS[allocation_method], allocation_method=cls.ALLOCATION_METHOD_BITS[allocation_method],
minimum_bitpool_value=minimum_bitpool_value, minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value = maximum_bitpool_value maximum_bitpool_value=maximum_bitpool_value,
) )
@classmethod @classmethod
@@ -276,16 +308,20 @@ class SbcMediaCodecInformation(
subbands, subbands,
allocation_methods, allocation_methods,
minimum_bitpool_value, minimum_bitpool_value,
maximum_bitpool_value maximum_bitpool_value,
): ):
return SbcMediaCodecInformation( return SbcMediaCodecInformation(
sampling_frequency = sum(cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies), sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channel_mode=sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes), channel_mode=sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes),
block_length=sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths), block_length=sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths),
subbands=sum(cls.SUBBANDS_BITS[x] for x in subbands), subbands=sum(cls.SUBBANDS_BITS[x] for x in subbands),
allocation_method = sum(cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods), allocation_method=sum(
cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods
),
minimum_bitpool_value=minimum_bitpool_value, minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value = maximum_bitpool_value maximum_bitpool_value=maximum_bitpool_value,
) )
def __bytes__(self): def __bytes__(self):
@@ -294,7 +330,8 @@ class SbcMediaCodecInformation(
def __str__(self): def __str__(self):
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO'] channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness'] allocation_methods = ['SNR', 'Loudness']
return '\n'.join([ return '\n'.join(
[
'SbcMediaCodecInformation(', 'SbcMediaCodecInformation(',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}', f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}',
f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}', f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}',
@@ -302,22 +339,16 @@ class SbcMediaCodecInformation(
f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}', f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}',
f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}', f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}',
f' minimum_bitpool_value: {self.minimum_bitpool_value}', f' minimum_bitpool_value: {self.minimum_bitpool_value}',
f' maximum_bitpool_value: {self.maximum_bitpool_value}' f' maximum_bitpool_value: {self.maximum_bitpool_value}' ')',
')' ]
]) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AacMediaCodecInformation( class AacMediaCodecInformation(
namedtuple( namedtuple(
'AacMediaCodecInformation', 'AacMediaCodecInformation',
[ ['object_type', 'sampling_frequency', 'channels', 'vbr', 'bitrate'],
'object_type',
'sampling_frequency',
'channels',
'vbr',
'bitrate'
]
) )
): ):
''' '''
@@ -329,7 +360,7 @@ class AacMediaCodecInformation(
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7, MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6, MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5, MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5,
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4 MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4,
} }
SAMPLING_FREQUENCY_BITS = { SAMPLING_FREQUENCY_BITS = {
8000: 1 << 11, 8000: 1 << 11,
@@ -343,66 +374,65 @@ class AacMediaCodecInformation(
48000: 1 << 3, 48000: 1 << 3,
64000: 1 << 2, 64000: 1 << 2,
88200: 1 << 1, 88200: 1 << 1,
96000: 1 96000: 1,
}
CHANNELS_BITS = {
1: 1 << 1,
2: 1
} }
CHANNELS_BITS = {1: 1 << 1, 2: 1}
@staticmethod @staticmethod
def from_bytes(data): def from_bytes(data):
return AacMediaCodecInformation(*bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data)) return AacMediaCodecInformation(
*bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data)
)
@classmethod @classmethod
def from_discrete_values( def from_discrete_values(
cls, cls, object_type, sampling_frequency, channels, vbr, bitrate
object_type,
sampling_frequency,
channels,
vbr,
bitrate
): ):
return AacMediaCodecInformation( return AacMediaCodecInformation(
object_type=cls.OBJECT_TYPE_BITS[object_type], object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency], sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channels=cls.CHANNELS_BITS[channels], channels=cls.CHANNELS_BITS[channels],
vbr=vbr, vbr=vbr,
bitrate = bitrate bitrate=bitrate,
) )
@classmethod @classmethod
def from_lists( def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate):
cls,
object_types,
sampling_frequencies,
channels,
vbr,
bitrate
):
return AacMediaCodecInformation( return AacMediaCodecInformation(
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types), object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency = sum(cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies), sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channels=sum(cls.CHANNELS_BITS[x] for x in channels), channels=sum(cls.CHANNELS_BITS[x] for x in channels),
vbr=vbr, vbr=vbr,
bitrate = bitrate bitrate=bitrate,
) )
def __bytes__(self): def __bytes__(self):
return bitstruct.pack(self.BIT_FIELDS, *self) return bitstruct.pack(self.BIT_FIELDS, *self)
def __str__(self): def __str__(self):
object_types = ['MPEG_2_AAC_LC', 'MPEG_4_AAC_LC', 'MPEG_4_AAC_LTP', 'MPEG_4_AAC_SCALABLE', '[4]', '[5]', '[6]', '[7]'] object_types = [
'MPEG_2_AAC_LC',
'MPEG_4_AAC_LC',
'MPEG_4_AAC_LTP',
'MPEG_4_AAC_SCALABLE',
'[4]',
'[5]',
'[6]',
'[7]',
]
channels = [1, 2] channels = [1, 2]
return '\n'.join([ return '\n'.join(
[
'AacMediaCodecInformation(', 'AacMediaCodecInformation(',
f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}', f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}', f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}',
f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}', f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}',
f' vbr: {self.vbr}', f' vbr: {self.vbr}',
f' bitrate: {self.bitrate}' f' bitrate: {self.bitrate}' ')',
')' ]
]) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -425,24 +455,20 @@ class VendorSpecificMediaCodecInformation:
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value) return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
def __str__(self): def __str__(self):
return '\n'.join([ return '\n'.join(
[
'VendorSpecificMediaCodecInformation(', 'VendorSpecificMediaCodecInformation(',
f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})', f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})',
f' codec_id: {self.codec_id:04X}', f' codec_id: {self.codec_id:04X}',
f' value: {self.value.hex()}' f' value: {self.value.hex()}' ')',
')' ]
]) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class SbcFrame: class SbcFrame:
def __init__( def __init__(
self, self, sampling_frequency, block_count, channel_mode, subband_count, payload
sampling_frequency,
block_count,
channel_mode,
subband_count,
payload
): ):
self.sampling_frequency = sampling_frequency self.sampling_frequency = sampling_frequency
self.block_count = block_count self.block_count = block_count
@@ -498,13 +524,19 @@ class SbcParser:
if channel_mode in (SBC_MONO_CHANNEL_MODE, SBC_DUAL_CHANNEL_MODE): if channel_mode in (SBC_MONO_CHANNEL_MODE, SBC_DUAL_CHANNEL_MODE):
frame_length += (blocks * channels * bitpool) // 8 frame_length += (blocks * channels * bitpool) // 8
else: else:
frame_length += ((1 if channel_mode == SBC_JOINT_STEREO_CHANNEL_MODE else 0) * subbands + blocks * bitpool) // 8 frame_length += (
(1 if channel_mode == SBC_JOINT_STEREO_CHANNEL_MODE else 0)
* subbands
+ blocks * bitpool
) // 8
# Read the rest of the frame # Read the rest of the frame
payload = header + await self.read(frame_length - 4) payload = header + await self.read(frame_length - 4)
# Emit the next frame # Emit the next frame
yield SbcFrame(sampling_frequency, blocks, channel_mode, subbands, payload) yield SbcFrame(
sampling_frequency, blocks, channel_mode, subbands, payload
)
return generate_frames() return generate_frames()
@@ -532,12 +564,19 @@ class SbcPacketSource:
async for frame in sbc_parser.frames: async for frame in sbc_parser.frames:
print(frame) print(frame)
if frames_size + len(frame.payload) > max_rtp_payload or len(frames) == 16: if (
frames_size + len(frame.payload) > max_rtp_payload
or len(frames) == 16
):
# Need to flush what has been accumulated so far # Need to flush what has been accumulated so far
# Emit a packet # Emit a packet
sbc_payload = bytes([len(frames)]) + b''.join([frame.payload for frame in frames]) sbc_payload = bytes([len(frames)]) + b''.join(
packet = MediaPacket(2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload) [frame.payload for frame in frames]
)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload
)
packet.timestamp_seconds = timestamp / frame.sampling_frequency packet.timestamp_seconds = timestamp / frame.sampling_frequency
yield packet yield packet
+144 -100
View File
@@ -31,6 +31,8 @@ from .hci import *
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
ATT_CID = 0x04 ATT_CID = 0x04
ATT_ERROR_RESPONSE = 0x01 ATT_ERROR_RESPONSE = 0x01
@@ -166,6 +168,8 @@ HANDLE_FIELD_SPEC = {'size': 2, 'mapper': lambda x: f'0x{x:04X}'}
UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y) # noqa: E731 UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y) # noqa: E731
UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731 UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
# fmt: on
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Utils
@@ -196,6 +200,7 @@ class ATT_PDU:
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.3 ATTRIBUTE PDU See Bluetooth spec @ Vol 3, Part F - 3.3 ATTRIBUTE PDU
''' '''
pdu_classes = {} pdu_classes = {}
op_code = 0 op_code = 0
@@ -274,11 +279,13 @@ class ATT_PDU:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
[
('request_opcode_in_error', {'size': 1, 'mapper': ATT_PDU.pdu_name}), ('request_opcode_in_error', {'size': 1, 'mapper': ATT_PDU.pdu_name}),
('attribute_handle_in_error', HANDLE_FIELD_SPEC), ('attribute_handle_in_error', HANDLE_FIELD_SPEC),
('error_code', {'size': 1, 'mapper': ATT_PDU.error_name}) ('error_code', {'size': 1, 'mapper': ATT_PDU.error_name}),
]) ]
)
class ATT_Error_Response(ATT_PDU): class ATT_Error_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response See Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response
@@ -286,9 +293,7 @@ class ATT_Error_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('client_rx_mtu', 2)])
('client_rx_mtu', 2)
])
class ATT_Exchange_MTU_Request(ATT_PDU): class ATT_Exchange_MTU_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.2.1 Exchange MTU Request See Bluetooth spec @ Vol 3, Part F - 3.4.2.1 Exchange MTU Request
@@ -296,9 +301,7 @@ class ATT_Exchange_MTU_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('server_rx_mtu', 2)])
('server_rx_mtu', 2)
])
class ATT_Exchange_MTU_Response(ATT_PDU): class ATT_Exchange_MTU_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.2.2 Exchange MTU Response See Bluetooth spec @ Vol 3, Part F - 3.4.2.2 Exchange MTU Response
@@ -306,10 +309,9 @@ class ATT_Exchange_MTU_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
('starting_handle', HANDLE_FIELD_SPEC), [('starting_handle', HANDLE_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC)]
('ending_handle', HANDLE_FIELD_SPEC) )
])
class ATT_Find_Information_Request(ATT_PDU): class ATT_Find_Information_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.1 Find Information Request See Bluetooth spec @ Vol 3, Part F - 3.4.3.1 Find Information Request
@@ -317,10 +319,7 @@ class ATT_Find_Information_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('format', 1), ('information_data', '*')])
('format', 1),
('information_data', '*')
])
class ATT_Find_Information_Response(ATT_PDU): class ATT_Find_Information_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.2 Find Information Response See Bluetooth spec @ Vol 3, Part F - 3.4.3.2 Find Information Response
@@ -346,20 +345,33 @@ class ATT_Find_Information_Response(ATT_PDU):
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [ result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
('format', 1), ('format', 1),
('information', {'mapper': lambda x: ', '.join([f'0x{handle:04X}:{uuid.hex()}' for handle, uuid in x])}) (
], ' ') 'information',
{
'mapper': lambda x: ', '.join(
[f'0x{handle:04X}:{uuid.hex()}' for handle, uuid in x]
)
},
),
],
' ',
)
return result return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
[
('starting_handle', HANDLE_FIELD_SPEC), ('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC),
('attribute_type', UUID_2_FIELD_SPEC), ('attribute_type', UUID_2_FIELD_SPEC),
('attribute_value', '*') ('attribute_value', '*'),
]) ]
)
class ATT_Find_By_Type_Value_Request(ATT_PDU): class ATT_Find_By_Type_Value_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.3 Find By Type Value Request See Bluetooth spec @ Vol 3, Part F - 3.4.3.3 Find By Type Value Request
@@ -367,9 +379,7 @@ class ATT_Find_By_Type_Value_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('handles_information_list', '*')])
('handles_information_list', '*')
])
class ATT_Find_By_Type_Value_Response(ATT_PDU): class ATT_Find_By_Type_Value_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.4 Find By Type Value Response See Bluetooth spec @ Vol 3, Part F - 3.4.3.4 Find By Type Value Response
@@ -379,7 +389,9 @@ class ATT_Find_By_Type_Value_Response(ATT_PDU):
self.handles_information = [] self.handles_information = []
offset = 0 offset = 0
while offset + 4 <= len(self.handles_information_list): while offset + 4 <= len(self.handles_information_list):
found_attribute_handle, group_end_handle = struct.unpack_from('<HH', self.handles_information_list, offset) found_attribute_handle, group_end_handle = struct.unpack_from(
'<HH', self.handles_information_list, offset
)
self.handles_information.append((found_attribute_handle, group_end_handle)) self.handles_information.append((found_attribute_handle, group_end_handle))
offset += 4 offset += 4
@@ -393,18 +405,34 @@ class ATT_Find_By_Type_Value_Response(ATT_PDU):
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [ result += ':\n' + HCI_Object.format_fields(
('handles_information', {'mapper': lambda x: ', '.join([f'0x{handle1:04X}-0x{handle2:04X}' for handle1, handle2 in x])}) self.__dict__,
], ' ') [
(
'handles_information',
{
'mapper': lambda x: ', '.join(
[
f'0x{handle1:04X}-0x{handle2:04X}'
for handle1, handle2 in x
]
)
},
)
],
' ',
)
return result return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
[
('starting_handle', HANDLE_FIELD_SPEC), ('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC),
('attribute_type', UUID_2_16_FIELD_SPEC) ('attribute_type', UUID_2_16_FIELD_SPEC),
]) ]
)
class ATT_Read_By_Type_Request(ATT_PDU): class ATT_Read_By_Type_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.1 Read By Type Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.1 Read By Type Request
@@ -412,10 +440,7 @@ class ATT_Read_By_Type_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('length', 1), ('attribute_data_list', '*')])
('length', 1),
('attribute_data_list', '*')
])
class ATT_Read_By_Type_Response(ATT_PDU): class ATT_Read_By_Type_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.2 Read By Type Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.2 Read By Type Response
@@ -424,9 +449,15 @@ class ATT_Read_By_Type_Response(ATT_PDU):
def parse_attribute_data_list(self): def parse_attribute_data_list(self):
self.attributes = [] self.attributes = []
offset = 0 offset = 0
while self.length != 0 and offset + self.length <= len(self.attribute_data_list): while self.length != 0 and offset + self.length <= len(
attribute_handle, = struct.unpack_from('<H', self.attribute_data_list, offset) self.attribute_data_list
attribute_value = self.attribute_data_list[offset + 2:offset + self.length] ):
(attribute_handle,) = struct.unpack_from(
'<H', self.attribute_data_list, offset
)
attribute_value = self.attribute_data_list[
offset + 2 : offset + self.length
]
self.attributes.append((attribute_handle, attribute_value)) self.attributes.append((attribute_handle, attribute_value))
offset += self.length offset += self.length
@@ -440,17 +471,26 @@ class ATT_Read_By_Type_Response(ATT_PDU):
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [ result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
('length', 1), ('length', 1),
('attributes', {'mapper': lambda x: ', '.join([f'0x{handle:04X}:{value.hex()}' for handle, value in x])}) (
], ' ') 'attributes',
{
'mapper': lambda x: ', '.join(
[f'0x{handle:04X}:{value.hex()}' for handle, value in x]
)
},
),
],
' ',
)
return result return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC)])
('attribute_handle', HANDLE_FIELD_SPEC)
])
class ATT_Read_Request(ATT_PDU): class ATT_Read_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.3 Read Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.3 Read Request
@@ -458,9 +498,7 @@ class ATT_Read_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('attribute_value', '*')])
('attribute_value', '*')
])
class ATT_Read_Response(ATT_PDU): class ATT_Read_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.4 Read Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.4 Read Response
@@ -468,10 +506,7 @@ class ATT_Read_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('value_offset', 2)])
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2)
])
class ATT_Read_Blob_Request(ATT_PDU): class ATT_Read_Blob_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.5 Read Blob Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.5 Read Blob Request
@@ -479,9 +514,7 @@ class ATT_Read_Blob_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('part_attribute_value', '*')])
('part_attribute_value', '*')
])
class ATT_Read_Blob_Response(ATT_PDU): class ATT_Read_Blob_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.6 Read Blob Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.6 Read Blob Response
@@ -489,9 +522,7 @@ class ATT_Read_Blob_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('set_of_handles', '*')])
('set_of_handles', '*')
])
class ATT_Read_Multiple_Request(ATT_PDU): class ATT_Read_Multiple_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.7 Read Multiple Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.7 Read Multiple Request
@@ -499,9 +530,7 @@ class ATT_Read_Multiple_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('set_of_values', '*')])
('set_of_values', '*')
])
class ATT_Read_Multiple_Response(ATT_PDU): class ATT_Read_Multiple_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.8 Read Multiple Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.8 Read Multiple Response
@@ -509,11 +538,13 @@ class ATT_Read_Multiple_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
[
('starting_handle', HANDLE_FIELD_SPEC), ('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC),
('attribute_group_type', UUID_2_16_FIELD_SPEC) ('attribute_group_type', UUID_2_16_FIELD_SPEC),
]) ]
)
class ATT_Read_By_Group_Type_Request(ATT_PDU): class ATT_Read_By_Group_Type_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.9 Read by Group Type Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.9 Read by Group Type Request
@@ -521,10 +552,7 @@ class ATT_Read_By_Group_Type_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('length', 1), ('attribute_data_list', '*')])
('length', 1),
('attribute_data_list', '*')
])
class ATT_Read_By_Group_Type_Response(ATT_PDU): class ATT_Read_By_Group_Type_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.10 Read by Group Type Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.10 Read by Group Type Response
@@ -533,10 +561,18 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU):
def parse_attribute_data_list(self): def parse_attribute_data_list(self):
self.attributes = [] self.attributes = []
offset = 0 offset = 0
while self.length != 0 and offset + self.length <= len(self.attribute_data_list): while self.length != 0 and offset + self.length <= len(
attribute_handle, end_group_handle = struct.unpack_from('<HH', self.attribute_data_list, offset) self.attribute_data_list
attribute_value = self.attribute_data_list[offset + 4:offset + self.length] ):
self.attributes.append((attribute_handle, end_group_handle, attribute_value)) attribute_handle, end_group_handle = struct.unpack_from(
'<HH', self.attribute_data_list, offset
)
attribute_value = self.attribute_data_list[
offset + 4 : offset + self.length
]
self.attributes.append(
(attribute_handle, end_group_handle, attribute_value)
)
offset += self.length offset += self.length
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@@ -549,18 +585,29 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU):
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [ result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
('length', 1), ('length', 1),
('attributes', {'mapper': lambda x: ', '.join([f'0x{handle:04X}-0x{end:04X}:{value.hex()}' for handle, end, value in x])}) (
], ' ') 'attributes',
{
'mapper': lambda x: ', '.join(
[
f'0x{handle:04X}-0x{end:04X}:{value.hex()}'
for handle, end, value in x
]
)
},
),
],
' ',
)
return result return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Write_Request(ATT_PDU): class ATT_Write_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.1 Write Request See Bluetooth spec @ Vol 3, Part F - 3.4.5.1 Write Request
@@ -576,10 +623,7 @@ class ATT_Write_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Write_Command(ATT_PDU): class ATT_Write_Command(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.3 Write Command See Bluetooth spec @ Vol 3, Part F - 3.4.5.3 Write Command
@@ -587,11 +631,13 @@ class ATT_Write_Command(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
[
('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*') ('attribute_value', '*')
# ('authentication_signature', 'TODO') # ('authentication_signature', 'TODO')
]) ]
)
class ATT_Signed_Write_Command(ATT_PDU): class ATT_Signed_Write_Command(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.4 Signed Write Command See Bluetooth spec @ Vol 3, Part F - 3.4.5.4 Signed Write Command
@@ -599,11 +645,13 @@ class ATT_Signed_Write_Command(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
[
('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2), ('value_offset', 2),
('part_attribute_value', '*') ('part_attribute_value', '*'),
]) ]
)
class ATT_Prepare_Write_Request(ATT_PDU): class ATT_Prepare_Write_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.6.1 Prepare Write Request See Bluetooth spec @ Vol 3, Part F - 3.4.6.1 Prepare Write Request
@@ -611,11 +659,13 @@ class ATT_Prepare_Write_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass(
[
('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2), ('value_offset', 2),
('part_attribute_value', '*') ('part_attribute_value', '*'),
]) ]
)
class ATT_Prepare_Write_Response(ATT_PDU): class ATT_Prepare_Write_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.6.2 Prepare Write Response See Bluetooth spec @ Vol 3, Part F - 3.4.6.2 Prepare Write Response
@@ -639,10 +689,7 @@ class ATT_Execute_Write_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Handle_Value_Notification(ATT_PDU): class ATT_Handle_Value_Notification(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.7.1 Handle Value Notification See Bluetooth spec @ Vol 3, Part F - 3.4.7.1 Handle Value Notification
@@ -650,10 +697,7 @@ class ATT_Handle_Value_Notification(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([ @ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Handle_Value_Indication(ATT_PDU): class ATT_Handle_Value_Indication(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.7.2 Handle Value Indication See Bluetooth spec @ Vol 3, Part F - 3.4.7.2 Handle Value Indication
+206 -104
View File
@@ -26,7 +26,7 @@ from .core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE, BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
InvalidStateError, InvalidStateError,
ProtocolError, ProtocolError,
name_or_number name_or_number,
) )
from .a2dp import ( from .a2dp import (
A2DP_CODEC_TYPE_NAMES, A2DP_CODEC_TYPE_NAMES,
@@ -35,7 +35,7 @@ from .a2dp import (
A2DP_SBC_CODEC_TYPE, A2DP_SBC_CODEC_TYPE,
AacMediaCodecInformation, AacMediaCodecInformation,
SbcMediaCodecInformation, SbcMediaCodecInformation,
VendorSpecificMediaCodecInformation VendorSpecificMediaCodecInformation,
) )
from . import sdp from . import sdp
@@ -48,6 +48,8 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
AVDTP_PSM = 0x0019 AVDTP_PSM = 0x0019
AVDTP_DEFAULT_RTX_SIG_TIMER = 5 # Seconds AVDTP_DEFAULT_RTX_SIG_TIMER = 5 # Seconds
@@ -195,6 +197,8 @@ AVDTP_STATE_NAMES = {
AVDTP_ABORTING_STATE: 'AVDTP_ABORTING_STATE' AVDTP_ABORTING_STATE: 'AVDTP_ABORTING_STATE'
} }
# fmt: on
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def find_avdtp_service_with_sdp_client(sdp_client): async def find_avdtp_service_with_sdp_client(sdp_client):
@@ -206,14 +210,11 @@ async def find_avdtp_service_with_sdp_client(sdp_client):
# Search for services with an Audio Sink service class # Search for services with an Audio Sink service class
search_result = await sdp_client.search_attributes( search_result = await sdp_client.search_attributes(
[BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE], [BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE],
[ [sdp.SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID],
sdp.SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
]
) )
for attribute_list in search_result: for attribute_list in search_result:
profile_descriptor_list = sdp.ServiceAttribute.find_attribute_in_list( profile_descriptor_list = sdp.ServiceAttribute.find_attribute_in_list(
attribute_list, attribute_list, sdp.SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
sdp.SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
) )
if profile_descriptor_list: if profile_descriptor_list:
for profile_descriptor in profile_descriptor_list.value: for profile_descriptor in profile_descriptor_list.value:
@@ -260,7 +261,9 @@ class MediaPacket:
sequence_number = struct.unpack_from('>H', data, 2)[0] sequence_number = struct.unpack_from('>H', data, 2)[0]
timestamp = struct.unpack_from('>I', data, 4)[0] timestamp = struct.unpack_from('>I', data, 4)[0]
ssrc = struct.unpack_from('>I', data, 8)[0] ssrc = struct.unpack_from('>I', data, 8)[0]
csrc_list = [struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)] csrc_list = [
struct.unpack_from('>I', data, 12 + i)[0] for i in range(csrc_count)
]
payload = data[12 + csrc_count * 4 :] payload = data[12 + csrc_count * 4 :]
return MediaPacket( return MediaPacket(
@@ -273,7 +276,7 @@ class MediaPacket:
ssrc, ssrc,
csrc_list, csrc_list,
payload_type, payload_type,
payload payload,
) )
def __init__( def __init__(
@@ -287,7 +290,7 @@ class MediaPacket:
ssrc, ssrc,
csrc_list, csrc_list,
payload_type, payload_type,
payload payload,
): ):
self.version = version self.version = version
self.padding = padding self.padding = padding
@@ -301,13 +304,15 @@ class MediaPacket:
self.payload = payload self.payload = payload
def __bytes__(self): def __bytes__(self):
header = ( header = bytes(
bytes([ [
self.version << 6 | self.padding << 5 | self.extension << 4 | len(self.csrc_list), self.version << 6
self.marker << 7 | self.payload_type | self.padding << 5
]) + | self.extension << 4
struct.pack('>HII', self.sequence_number, self.timestamp, self.ssrc) | len(self.csrc_list),
) self.marker << 7 | self.payload_type,
]
) + struct.pack('>HII', self.sequence_number, self.timestamp, self.ssrc)
for csrc in self.csrc_list: for csrc in self.csrc_list:
header += struct.pack('>I', csrc) header += struct.pack('>I', csrc)
return header + self.payload return header + self.payload
@@ -346,12 +351,14 @@ class MediaPacketPump:
# Emit # Emit
rtp_channel.send_pdu(bytes(packet)) rtp_channel.send_pdu(bytes(packet))
logger.debug(f'{color(">>> sending RTP packet:", "green")} {packet}') logger.debug(
f'{color(">>> sending RTP packet:", "green")} {packet}'
)
except asyncio.exceptions.CancelledError: except asyncio.exceptions.CancelledError:
logger.debug('pump canceled') logger.debug('pump canceled')
# Pump packets # Pump packets
self.pump_task = asyncio.get_running_loop().create_task(pump_packets()) self.pump_task = asyncio.create_task(pump_packets())
async def stop(self): async def stop(self):
# Stop the pump # Stop the pump
@@ -382,11 +389,18 @@ class MessageAssembler:
packet_type = (pdu[0] >> 2) & 3 packet_type = (pdu[0] >> 2) & 3
message_type = pdu[0] & 3 message_type = pdu[0] & 3
logger.debug(f'transaction_label={transaction_label}, packet_type={Protocol.packet_type_name(packet_type)}, message_type={Message.message_type_name(message_type)}') logger.debug(
if packet_type == Protocol.SINGLE_PACKET or packet_type == Protocol.START_PACKET: f'transaction_label={transaction_label}, packet_type={Protocol.packet_type_name(packet_type)}, message_type={Message.message_type_name(message_type)}'
)
if (
packet_type == Protocol.SINGLE_PACKET
or packet_type == Protocol.START_PACKET
):
if self.message is not None: if self.message is not None:
# The previous message has not been terminated # The previous message has not been terminated
logger.warning('received a start or single packet when expecting an end or continuation') logger.warning(
'received a start or single packet when expecting an end or continuation'
)
self.reset() self.reset()
self.transaction_label = transaction_label self.transaction_label = transaction_label
@@ -399,36 +413,49 @@ class MessageAssembler:
else: else:
self.number_of_signal_packets = pdu[2] self.number_of_signal_packets = pdu[2]
self.message = pdu[3:] self.message = pdu[3:]
elif packet_type == Protocol.CONTINUE_PACKET or packet_type == Protocol.END_PACKET: elif (
packet_type == Protocol.CONTINUE_PACKET
or packet_type == Protocol.END_PACKET
):
if self.packet_count == 0: if self.packet_count == 0:
logger.warning('unexpected continuation') logger.warning('unexpected continuation')
return return
if transaction_label != self.transaction_label: if transaction_label != self.transaction_label:
logger.warning(f'transaction label mismatch: expected {self.transaction_label}, received {transaction_label}') logger.warning(
f'transaction label mismatch: expected {self.transaction_label}, received {transaction_label}'
)
return return
if message_type != self.message_type: if message_type != self.message_type:
logger.warning(f'message type mismatch: expected {self.message_type}, received {message_type}') logger.warning(
f'message type mismatch: expected {self.message_type}, received {message_type}'
)
return return
self.message += pdu[1:] self.message += pdu[1:]
if packet_type == Protocol.END_PACKET: if packet_type == Protocol.END_PACKET:
if self.packet_count != self.number_of_signal_packets: if self.packet_count != self.number_of_signal_packets:
logger.warning(f'incomplete fragmented message: expected {self.number_of_signal_packets} packets, received {self.packet_count}') logger.warning(
f'incomplete fragmented message: expected {self.number_of_signal_packets} packets, received {self.packet_count}'
)
self.reset() self.reset()
return return
self.on_message_complete() self.on_message_complete()
else: else:
if self.packet_count > self.number_of_signal_packets: if self.packet_count > self.number_of_signal_packets:
logger.warning(f'too many packets: expected {self.number_of_signal_packets}, received {self.packet_count}') logger.warning(
f'too many packets: expected {self.number_of_signal_packets}, received {self.packet_count}'
)
self.reset() self.reset()
return return
def on_message_complete(self): def on_message_complete(self):
message = Message.create(self.signal_identifier, self.message_type, self.message) message = Message.create(
self.signal_identifier, self.message_type, self.message
)
try: try:
self.callback(self.transaction_label, message) self.callback(self.transaction_label, message)
@@ -463,7 +490,9 @@ class ServiceCapabilities:
service_category = payload[0] service_category = payload[0]
length_of_service_capabilities = payload[1] length_of_service_capabilities = payload[1]
service_capabilities_bytes = payload[2 : 2 + length_of_service_capabilities] service_capabilities_bytes = payload[2 : 2 + length_of_service_capabilities]
capabilities.append(ServiceCapabilities.create(service_category, service_capabilities_bytes)) capabilities.append(
ServiceCapabilities.create(service_category, service_capabilities_bytes)
)
payload = payload[2 + length_of_service_capabilities :] payload = payload[2 + length_of_service_capabilities :]
@@ -473,10 +502,10 @@ class ServiceCapabilities:
def serialize_capabilities(capabilities): def serialize_capabilities(capabilities):
serialized = b'' serialized = b''
for item in capabilities: for item in capabilities:
serialized += bytes([ serialized += (
item.service_category, bytes([item.service_category, len(item.service_capabilities_bytes)])
len(item.service_capabilities_bytes) + item.service_capabilities_bytes
]) + item.service_capabilities_bytes )
return serialized return serialized
def init_from_bytes(self): def init_from_bytes(self):
@@ -487,7 +516,10 @@ class ServiceCapabilities:
self.service_capabilities_bytes = service_capabilities_bytes self.service_capabilities_bytes = service_capabilities_bytes
def to_string(self, details=[]): def to_string(self, details=[]):
attributes = ','.join([name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)] + details) attributes = ','.join(
[name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)]
+ details
)
return f'ServiceCapabilities({attributes})' return f'ServiceCapabilities({attributes})'
def __str__(self): def __str__(self):
@@ -506,16 +538,24 @@ class MediaCodecCapabilities(ServiceCapabilities):
self.media_codec_information = self.service_capabilities_bytes[2:] self.media_codec_information = self.service_capabilities_bytes[2:]
if self.media_codec_type == A2DP_SBC_CODEC_TYPE: if self.media_codec_type == A2DP_SBC_CODEC_TYPE:
self.media_codec_information = SbcMediaCodecInformation.from_bytes(self.media_codec_information) self.media_codec_information = SbcMediaCodecInformation.from_bytes(
self.media_codec_information
)
elif self.media_codec_type == A2DP_MPEG_2_4_AAC_CODEC_TYPE: elif self.media_codec_type == A2DP_MPEG_2_4_AAC_CODEC_TYPE:
self.media_codec_information = AacMediaCodecInformation.from_bytes(self.media_codec_information) self.media_codec_information = AacMediaCodecInformation.from_bytes(
self.media_codec_information
)
elif self.media_codec_type == A2DP_NON_A2DP_CODEC_TYPE: elif self.media_codec_type == A2DP_NON_A2DP_CODEC_TYPE:
self.media_codec_information = VendorSpecificMediaCodecInformation.from_bytes(self.media_codec_information) self.media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(
self.media_codec_information
)
)
def __init__(self, media_type, media_codec_type, media_codec_information): def __init__(self, media_type, media_codec_type, media_codec_information):
super().__init__( super().__init__(
AVDTP_MEDIA_CODEC_SERVICE_CATEGORY, AVDTP_MEDIA_CODEC_SERVICE_CATEGORY,
bytes([media_type, media_codec_type]) + bytes(media_codec_information) bytes([media_type, media_codec_type]) + bytes(media_codec_information),
) )
self.media_type = media_type self.media_type = media_type
self.media_codec_type = media_codec_type self.media_codec_type = media_codec_type
@@ -525,7 +565,7 @@ class MediaCodecCapabilities(ServiceCapabilities):
details = [ details = [
f'media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}', f'media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}',
f'codec={name_or_number(A2DP_CODEC_TYPE_NAMES, self.media_codec_type)}', f'codec={name_or_number(A2DP_CODEC_TYPE_NAMES, self.media_codec_type)}',
f'codec_info={self.media_codec_information.hex() if type(self.media_codec_information) is bytes else str(self.media_codec_information)}' f'codec_info={self.media_codec_information.hex() if type(self.media_codec_information) is bytes else str(self.media_codec_information)}',
] ]
return self.to_string(details) return self.to_string(details)
@@ -535,17 +575,13 @@ class EndPointInfo:
@staticmethod @staticmethod
def from_bytes(payload): def from_bytes(payload):
return EndPointInfo( return EndPointInfo(
payload[0] >> 2, payload[0] >> 2, payload[0] >> 1 & 1, payload[1] >> 4, payload[1] >> 3 & 1
payload[0] >> 1 & 1,
payload[1] >> 4,
payload[1] >> 3 & 1
) )
def __bytes__(self): def __bytes__(self):
return bytes([ return bytes(
self.seid << 2 | self.in_use << 1, [self.seid << 2 | self.in_use << 1, self.media_type << 4 | self.tsep << 3]
self.media_type << 4 | self.tsep << 3 )
])
def __init__(self, seid, in_use, media_type, tsep): def __init__(self, seid, in_use, media_type, tsep):
self.seid = seid self.seid = seid
@@ -565,7 +601,7 @@ class Message:
COMMAND: 'COMMAND', COMMAND: 'COMMAND',
GENERAL_REJECT: 'GENERAL_REJECT', GENERAL_REJECT: 'GENERAL_REJECT',
RESPONSE_ACCEPT: 'RESPONSE_ACCEPT', RESPONSE_ACCEPT: 'RESPONSE_ACCEPT',
RESPONSE_REJECT: 'RESPONSE_REJECT' RESPONSE_REJECT: 'RESPONSE_REJECT',
} }
subclasses = {} # Subclasses, by signal identifier and message type subclasses = {} # Subclasses, by signal identifier and message type
@@ -603,7 +639,9 @@ class Message:
break break
# Register the subclass # Register the subclass
Message.subclasses.setdefault(cls.signal_identifier, {})[cls.message_type] = cls Message.subclasses.setdefault(cls.signal_identifier, {})[
cls.message_type
] = cls
return cls return cls
@@ -643,7 +681,11 @@ class Message:
if type(details) is str: if type(details) is str:
return f'{base}: {details}' return f'{base}: {details}'
else: else:
return base + ':\n' + '\n'.join([' ' + color(detail, 'cyan') for detail in details]) return (
base
+ ':\n'
+ '\n'.join([' ' + color(detail, 'cyan') for detail in details])
)
else: else:
return base return base
@@ -682,9 +724,7 @@ class Simple_Reject(Message):
self.payload = bytes([self.error_code]) self.payload = bytes([self.error_code])
def __str__(self): def __str__(self):
details = [ details = [f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}']
f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}'
]
return self.to_string(details) return self.to_string(details)
@@ -707,7 +747,9 @@ class Discover_Response(Message):
self.endpoints = [] self.endpoints = []
endpoint_count = len(self.payload) // 2 endpoint_count = len(self.payload) // 2
for i in range(endpoint_count): for i in range(endpoint_count):
self.endpoints.append(EndPointInfo.from_bytes(self.payload[i * 2:(i + 1) * 2])) self.endpoints.append(
EndPointInfo.from_bytes(self.payload[i * 2 : (i + 1) * 2])
)
def __init__(self, endpoints): def __init__(self, endpoints):
self.endpoints = endpoints self.endpoints = endpoints
@@ -721,7 +763,7 @@ class Discover_Response(Message):
f'ACP SEID: {endpoint.seid}', f'ACP SEID: {endpoint.seid}',
f' in_use: {endpoint.in_use}', f' in_use: {endpoint.in_use}',
f' media_type: {name_or_number(AVDTP_MEDIA_TYPE_NAMES, endpoint.media_type)}', f' media_type: {name_or_number(AVDTP_MEDIA_TYPE_NAMES, endpoint.media_type)}',
f' tsep: {name_or_number(AVDTP_TSEP_NAMES, endpoint.tsep)}' f' tsep: {name_or_number(AVDTP_TSEP_NAMES, endpoint.tsep)}',
] ]
) )
return self.to_string(details) return self.to_string(details)
@@ -802,13 +844,14 @@ class Set_Configuration_Command(Message):
self.acp_seid = acp_seid self.acp_seid = acp_seid
self.int_seid = int_seid self.int_seid = int_seid
self.capabilities = capabilities self.capabilities = capabilities
self.payload = bytes([acp_seid << 2, int_seid << 2]) + ServiceCapabilities.serialize_capabilities(capabilities) self.payload = bytes(
[acp_seid << 2, int_seid << 2]
) + ServiceCapabilities.serialize_capabilities(capabilities)
def __str__(self): def __str__(self):
details = [ details = [f'ACP SEID: {self.acp_seid}', f'INT SEID: {self.int_seid}'] + [
f'ACP SEID: {self.acp_seid}', str(capability) for capability in self.capabilities
f'INT SEID: {self.int_seid}' ]
] + [str(capability) for capability in self.capabilities]
return self.to_string(details) return self.to_string(details)
@@ -839,7 +882,7 @@ class Set_Configuration_Reject(Message):
def __str__(self): def __str__(self):
details = [ details = [
f'service_category: {name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)}', f'service_category: {name_or_number(AVDTP_SERVICE_CATEGORY_NAMES, self.service_category)}',
f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}' f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}',
] ]
return self.to_string(details) return self.to_string(details)
@@ -982,7 +1025,7 @@ class Start_Reject(Message):
def __str__(self): def __str__(self):
details = [ details = [
f'acp_seid: {self.acp_seid}', f'acp_seid: {self.acp_seid}',
f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}' f'error_code: {name_or_number(AVDTP_ERROR_NAMES, self.error_code)}',
] ]
return self.to_string(details) return self.to_string(details)
@@ -1098,10 +1141,7 @@ class DelayReport_Command(Message):
self.delay = (self.payload[1] << 8) | (self.payload[2]) self.delay = (self.payload[1] << 8) | (self.payload[2])
def __str__(self): def __str__(self):
return self.to_string([ return self.to_string([f'ACP_SEID: {self.acp_seid}', f'delay: {self.delay}'])
f'ACP_SEID: {self.acp_seid}',
f'delay: {self.delay}'
])
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1131,7 +1171,7 @@ class Protocol:
SINGLE_PACKET: 'SINGLE_PACKET', SINGLE_PACKET: 'SINGLE_PACKET',
START_PACKET: 'START_PACKET', START_PACKET: 'START_PACKET',
CONTINUE_PACKET: 'CONTINUE_PACKET', CONTINUE_PACKET: 'CONTINUE_PACKET',
END_PACKET: 'END_PACKET' END_PACKET: 'END_PACKET',
} }
@staticmethod @staticmethod
@@ -1205,7 +1245,9 @@ class Protocol:
response = await self.send_command(Discover_Command()) response = await self.send_command(Discover_Command())
for endpoint_entry in response.endpoints: for endpoint_entry in response.endpoints:
logger.debug(f'getting endpoint capabilities for endpoint {endpoint_entry.seid}') logger.debug(
f'getting endpoint capabilities for endpoint {endpoint_entry.seid}'
)
get_capabilities_response = await self.get_capabilities(endpoint_entry.seid) get_capabilities_response = await self.get_capabilities(endpoint_entry.seid)
endpoint = DiscoveredStreamEndPoint( endpoint = DiscoveredStreamEndPoint(
self, self,
@@ -1213,7 +1255,7 @@ class Protocol:
endpoint_entry.media_type, endpoint_entry.media_type,
endpoint_entry.tsep, endpoint_entry.tsep,
endpoint_entry.in_use, endpoint_entry.in_use,
get_capabilities_response.capabilities get_capabilities_response.capabilities,
) )
self.remote_endpoints[endpoint_entry.seid] = endpoint self.remote_endpoints[endpoint_entry.seid] = endpoint
@@ -1221,14 +1263,27 @@ class Protocol:
def find_remote_sink_by_codec(self, media_type, codec_type): def find_remote_sink_by_codec(self, media_type, codec_type):
for endpoint in self.remote_endpoints.values(): for endpoint in self.remote_endpoints.values():
if not endpoint.in_use and endpoint.media_type == media_type and endpoint.tsep == AVDTP_TSEP_SNK: if (
not endpoint.in_use
and endpoint.media_type == media_type
and endpoint.tsep == AVDTP_TSEP_SNK
):
has_media_transport = False has_media_transport = False
has_codec = False has_codec = False
for capabilities in endpoint.capabilities: for capabilities in endpoint.capabilities:
if capabilities.service_category == AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY: if (
capabilities.service_category
== AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY
):
has_media_transport = True has_media_transport = True
elif capabilities.service_category == AVDTP_MEDIA_CODEC_SERVICE_CATEGORY: elif (
if capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE and capabilities.media_codec_type == codec_type: capabilities.service_category
== AVDTP_MEDIA_CODEC_SERVICE_CATEGORY
):
if (
capabilities.media_type == AVDTP_AUDIO_MEDIA_TYPE
and capabilities.media_codec_type == codec_type
):
has_codec = True has_codec = True
if has_media_transport and has_codec: if has_media_transport and has_codec:
return endpoint return endpoint
@@ -1237,7 +1292,9 @@ class Protocol:
self.message_assembler.on_pdu(pdu) self.message_assembler.on_pdu(pdu)
def on_message(self, transaction_label, message): def on_message(self, transaction_label, message):
logger.debug(f'{color("<<< Received AVDTP message", "magenta")}: [{transaction_label}] {message}') logger.debug(
f'{color("<<< Received AVDTP message", "magenta")}: [{transaction_label}] {message}'
)
# Check that the identifier is not reserved # Check that the identifier is not reserved
if message.signal_identifier == 0: if message.signal_identifier == 0:
@@ -1245,7 +1302,10 @@ class Protocol:
return return
# Check that the identifier is valid # Check that the identifier is valid
if message.signal_identifier < 0 or message.signal_identifier > AVDTP_DELAYREPORT: if (
message.signal_identifier < 0
or message.signal_identifier > AVDTP_DELAYREPORT
):
logger.warning('!!! invalid signal identifier') logger.warning('!!! invalid signal identifier')
self.send_message(transaction_label, General_Reject()) self.send_message(transaction_label, General_Reject())
@@ -1258,7 +1318,9 @@ class Protocol:
response = handler(message) response = handler(message)
self.send_message(transaction_label, response) self.send_message(transaction_label, response)
except Exception as error: except Exception as error:
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}') logger.warning(
f'{color("!!! Exception in handler:", "red")} {error}'
)
else: else:
logger.warning('unhandled command') logger.warning('unhandled command')
else: else:
@@ -1281,8 +1343,12 @@ class Protocol:
logger.debug(color('<<< L2CAP channel open', 'magenta')) logger.debug(color('<<< L2CAP channel open', 'magenta'))
def send_message(self, transaction_label, message): def send_message(self, transaction_label, message):
logger.debug(f'{color(">>> Sending AVDTP message", "magenta")}: [{transaction_label}] {message}') logger.debug(
max_fragment_size = self.l2cap_channel.mtu - 3 # Enough space for a 3-byte start packet header f'{color(">>> Sending AVDTP message", "magenta")}: [{transaction_label}] {message}'
)
max_fragment_size = (
self.l2cap_channel.mtu - 3
) # Enough space for a 3-byte start packet header
payload = message.payload payload = message.payload
if len(payload) + 2 <= self.l2cap_channel.mtu: if len(payload) + 2 <= self.l2cap_channel.mtu:
# Fits in a single packet # Fits in a single packet
@@ -1292,13 +1358,19 @@ class Protocol:
done = False done = False
while not done: while not done:
first_header_byte = transaction_label << 4 | packet_type << 2 | message.message_type first_header_byte = (
transaction_label << 4 | packet_type << 2 | message.message_type
)
if packet_type == self.SINGLE_PACKET: if packet_type == self.SINGLE_PACKET:
header = bytes([first_header_byte, message.signal_identifier]) header = bytes([first_header_byte, message.signal_identifier])
elif packet_type == self.START_PACKET: elif packet_type == self.START_PACKET:
packet_count = (max_fragment_size - 1 + len(payload)) // max_fragment_size packet_count = (
header = bytes([first_header_byte, message.signal_identifier, packet_count]) max_fragment_size - 1 + len(payload)
) // max_fragment_size
header = bytes(
[first_header_byte, message.signal_identifier, packet_count]
)
else: else:
header = bytes([first_header_byte]) header = bytes([first_header_byte])
@@ -1308,7 +1380,11 @@ class Protocol:
# Prepare for the next packet # Prepare for the next packet
payload = payload[max_fragment_size:] payload = payload[max_fragment_size:]
if payload: if payload:
packet_type = self.CONTINUE_PACKET if payload > max_fragment_size else self.END_PACKET packet_type = (
self.CONTINUE_PACKET
if payload > max_fragment_size
else self.END_PACKET
)
else: else:
done = True done = True
@@ -1322,7 +1398,10 @@ class Protocol:
response = await transaction_result response = await transaction_result
# Check for errors # Check for errors
if response.message_type == Message.GENERAL_REJECT or response.message_type == Message.RESPONSE_REJECT: if (
response.message_type == Message.GENERAL_REJECT
or response.message_type == Message.RESPONSE_REJECT
):
raise ProtocolError(response.error_code, 'avdtp') raise ProtocolError(response.error_code, 'avdtp')
return response return response
@@ -1340,7 +1419,7 @@ class Protocol:
self.transaction_count += 1 self.transaction_count += 1
return (transaction_label, transaction_result) return (transaction_label, transaction_result)
assert(False) # Should never reach this assert False # Should never reach this
async def get_capabilities(self, seid): async def get_capabilities(self, seid):
if self.version > (1, 2): if self.version > (1, 2):
@@ -1349,7 +1428,9 @@ class Protocol:
return await self.send_command(Get_Capabilities_Command(seid)) return await self.send_command(Get_Capabilities_Command(seid))
async def set_configuration(self, acp_seid, int_seid, capabilities): async def set_configuration(self, acp_seid, int_seid, capabilities):
return await self.send_command(Set_Configuration_Command(acp_seid, int_seid, capabilities)) return await self.send_command(
Set_Configuration_Command(acp_seid, int_seid, capabilities)
)
async def get_configuration(self, seid): async def get_configuration(self, seid):
response = await self.send_command(Get_Configuration_Command(seid)) response = await self.send_command(Get_Configuration_Command(seid))
@@ -1537,6 +1618,7 @@ class Listener(EventEmitter):
server = Protocol(channel, self.version) server = Protocol(channel, self.version)
self.set_server(channel.connection, server) self.set_server(channel.connection, server)
self.emit('connection', server) self.emit('connection', server)
channel.on('open', on_channel_open) channel.on('open', on_channel_open)
@@ -1562,8 +1644,7 @@ class Stream:
raise InvalidStateError('current state is not IDLE') raise InvalidStateError('current state is not IDLE')
await self.remote_endpoint.set_configuration( await self.remote_endpoint.set_configuration(
self.local_endpoint.seid, self.local_endpoint.seid, self.local_endpoint.configuration
self.local_endpoint.configuration
) )
self.change_state(AVDTP_CONFIGURED_STATE) self.change_state(AVDTP_CONFIGURED_STATE)
@@ -1639,7 +1720,11 @@ class Stream:
self.change_state(AVDTP_CONFIGURED_STATE) self.change_state(AVDTP_CONFIGURED_STATE)
def on_get_configuration_command(self, configuration): def on_get_configuration_command(self, configuration):
if self.state not in {AVDTP_CONFIGURED_STATE, AVDTP_OPEN_STATE, AVDTP_STREAMING_STATE}: if self.state not in {
AVDTP_CONFIGURED_STATE,
AVDTP_OPEN_STATE,
AVDTP_STREAMING_STATE,
}:
return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR) return Get_Configuration_Reject(AVDTP_BAD_STATE_ERROR)
return self.local_endpoint.on_get_configuration_command(configuration) return self.local_endpoint.on_get_configuration_command(configuration)
@@ -1767,7 +1852,8 @@ class StreamEndPoint:
self.capabilities = capabilities self.capabilities = capabilities
def __str__(self): def __str__(self):
return '\n'.join([ return '\n'.join(
[
'SEP(', 'SEP(',
f' seid={self.seid}', f' seid={self.seid}',
f' media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}', f' media_type={name_or_number(AVDTP_MEDIA_TYPE_NAMES, self.media_type)}',
@@ -1776,8 +1862,9 @@ class StreamEndPoint:
' capabilities=[', ' capabilities=[',
'\n'.join([f' {x}' for x in self.capabilities]), '\n'.join([f' {x}' for x in self.capabilities]),
' ]', ' ]',
')' ')',
]) ]
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1787,11 +1874,7 @@ class StreamEndPointProxy:
self.protocol = protocol self.protocol = protocol
async def set_configuration(self, int_seid, configuration): async def set_configuration(self, int_seid, configuration):
return await self.protocol.set_configuration( return await self.protocol.set_configuration(self.seid, int_seid, configuration)
self.seid,
int_seid,
configuration
)
async def open(self): async def open(self):
return await self.protocol.open(self.seid) return await self.protocol.open(self.seid)
@@ -1818,7 +1901,9 @@ class DiscoveredStreamEndPoint(StreamEndPoint, StreamEndPointProxy):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class LocalStreamEndPoint(StreamEndPoint): class LocalStreamEndPoint(StreamEndPoint):
def __init__(self, protocol, seid, media_type, tsep, capabilities, configuration=[]): def __init__(
self, protocol, seid, media_type, tsep, capabilities, configuration=[]
):
super().__init__(seid, media_type, tsep, 0, capabilities) super().__init__(seid, media_type, tsep, 0, capabilities)
self.protocol = protocol self.protocol = protocol
self.configuration = configuration self.configuration = configuration
@@ -1866,9 +1951,17 @@ class LocalSource(LocalStreamEndPoint, EventEmitter):
def __init__(self, protocol, seid, codec_capabilities, packet_pump): def __init__(self, protocol, seid, codec_capabilities, packet_pump):
capabilities = [ capabilities = [
ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY), ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY),
codec_capabilities codec_capabilities,
] ]
LocalStreamEndPoint.__init__(self, protocol, seid, codec_capabilities.media_type, AVDTP_TSEP_SRC, capabilities, capabilities) LocalStreamEndPoint.__init__(
self,
protocol,
seid,
codec_capabilities.media_type,
AVDTP_TSEP_SRC,
capabilities,
capabilities,
)
EventEmitter.__init__(self) EventEmitter.__init__(self)
self.packet_pump = packet_pump self.packet_pump = packet_pump
@@ -1890,10 +1983,10 @@ class LocalSource(LocalStreamEndPoint, EventEmitter):
self.configuration = configuration self.configuration = configuration
def on_start_command(self): def on_start_command(self):
asyncio.get_running_loop().create_task(self.start()) asyncio.create_task(self.start())
def on_suspend_command(self): def on_suspend_command(self):
asyncio.get_running_loop().create_task(self.stop()) asyncio.create_task(self.stop())
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1901,9 +1994,16 @@ class LocalSink(LocalStreamEndPoint, EventEmitter):
def __init__(self, protocol, seid, codec_capabilities): def __init__(self, protocol, seid, codec_capabilities):
capabilities = [ capabilities = [
ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY), ServiceCapabilities(AVDTP_MEDIA_TRANSPORT_SERVICE_CATEGORY),
codec_capabilities codec_capabilities,
] ]
LocalStreamEndPoint.__init__(self, protocol, seid, codec_capabilities.media_type, AVDTP_TSEP_SNK, capabilities) LocalStreamEndPoint.__init__(
self,
protocol,
seid,
codec_capabilities.media_type,
AVDTP_TSEP_SNK,
capabilities,
)
EventEmitter.__init__(self) EventEmitter.__init__(self)
def on_set_configuration_command(self, configuration): def on_set_configuration_command(self, configuration):
@@ -1917,5 +2017,7 @@ class LocalSink(LocalStreamEndPoint, EventEmitter):
def on_avdtp_packet(self, packet): def on_avdtp_packet(self, packet):
rtp_packet = MediaPacket.from_bytes(packet) rtp_packet = MediaPacket.from_bytes(packet)
logger.debug(f'{color("<<< RTP Packet:", "green")} {rtp_packet} {rtp_packet.payload[:16].hex()}') logger.debug(
f'{color("<<< RTP Packet:", "green")} {rtp_packet} {rtp_packet.payload[:16].hex()}'
)
self.emit('rtp_packet', rtp_packet) self.emit('rtp_packet', rtp_packet)
+3 -3
View File
@@ -62,14 +62,14 @@ class HCI_Bridge:
hci_controller_source, hci_controller_source,
hci_controller_sink, hci_controller_sink,
host_to_controller_filter=None, host_to_controller_filter=None,
controller_to_host_filter = None controller_to_host_filter=None,
): ):
tracer = PacketTracer(emit_message=logger.info) tracer = PacketTracer(emit_message=logger.info)
host_to_controller_forwarder = HCI_Bridge.Forwarder( host_to_controller_forwarder = HCI_Bridge.Forwarder(
hci_controller_sink, hci_controller_sink,
hci_host_sink, hci_host_sink,
host_to_controller_filter, host_to_controller_filter,
lambda packet: tracer.trace(packet, 0) lambda packet: tracer.trace(packet, 0),
) )
hci_host_source.set_packet_sink(host_to_controller_forwarder) hci_host_source.set_packet_sink(host_to_controller_forwarder)
@@ -77,6 +77,6 @@ class HCI_Bridge:
hci_host_sink, hci_host_sink,
hci_controller_sink, hci_controller_sink,
controller_to_host_filter, controller_to_host_filter,
lambda packet: tracer.trace(packet, 1) lambda packet: tracer.trace(packet, 1),
) )
hci_controller_source.set_packet_sink(controller_to_host_forwarder) hci_controller_source.set_packet_sink(controller_to_host_forwarder)
+1 -1
View File
@@ -2704,5 +2704,5 @@ COMPANY_IDENTIFIERS = {
0x0A7C: "WAFERLOCK", 0x0A7C: "WAFERLOCK",
0x0A7D: "Freedman Electronics Pty Ltd", 0x0A7D: "Freedman Electronics Pty Ltd",
0x0A7E: "Keba AG", 0x0A7E: "Keba AG",
0x0A7F: "Intuity Medical" 0x0A7F: "Intuity Medical",
} }
+170 -88
View File
@@ -48,11 +48,15 @@ class Connection:
def on_hci_acl_data_packet(self, packet): def on_hci_acl_data_packet(self, packet):
self.assembler.feed_packet(packet) self.assembler.feed_packet(packet)
self.controller.send_hci_packet(HCI_Number_Of_Completed_Packets_Event([(self.handle, 1)])) self.controller.send_hci_packet(
HCI_Number_Of_Completed_Packets_Event([(self.handle, 1)])
)
def on_acl_pdu(self, data): def on_acl_pdu(self, data):
if self.link: if self.link:
self.link.send_acl_data(self.controller.random_address, self.peer_address, data) self.link.send_acl_data(
self.controller.random_address, self.peer_address, data
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -62,21 +66,29 @@ class Controller:
self.hci_sink = None self.hci_sink = None
self.link = link self.link = link
self.central_connections = {} # Connections where this controller is the central self.central_connections = (
self.peripheral_connections = {} # Connections where this controller is the peripheral {}
) # Connections where this controller is the central
self.peripheral_connections = (
{}
) # Connections where this controller is the peripheral
self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0 self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0
self.hci_revision = 0 self.hci_revision = 0
self.lmp_version = HCI_VERSION_BLUETOOTH_CORE_5_0 self.lmp_version = HCI_VERSION_BLUETOOTH_CORE_5_0
self.lmp_subversion = 0 self.lmp_subversion = 0
self.lmp_features = bytes.fromhex('0000000060000000') # BR/EDR Not Supported, LE Supported (Controller) self.lmp_features = bytes.fromhex(
'0000000060000000'
) # BR/EDR Not Supported, LE Supported (Controller)
self.manufacturer_name = 0xFFFF self.manufacturer_name = 0xFFFF
self.hc_le_data_packet_length = 27 self.hc_le_data_packet_length = 27
self.hc_total_num_le_data_packets = 64 self.hc_total_num_le_data_packets = 64
self.supported_commands = bytes.fromhex('2000800000c000000000e40000002822000000000000040000f7ffff7f00000030f0f9ff01008004000000000000000000000000000000000000000000000000') self.supported_commands = bytes.fromhex(
'2000800000c000000000e40000002822000000000000040000f7ffff7f00000030f0f9ff01008004000000000000000000000000000000000000000000000000'
)
self.le_features = bytes.fromhex('ff49010000000000') self.le_features = bytes.fromhex('ff49010000000000')
self.le_states = bytes.fromhex('ffff3fffff030000') self.le_states = bytes.fromhex('ffff3fffff030000')
self.avertising_channel_tx_power = 0 self.advertising_channel_tx_power = 0
self.filter_accept_list_size = 8 self.filter_accept_list_size = 8
self.resolving_list_size = 8 self.resolving_list_size = 8
self.supported_max_tx_octets = 27 self.supported_max_tx_octets = 27
@@ -162,7 +174,9 @@ class Controller:
self.on_hci_packet(HCI_Packet.from_bytes(packet)) self.on_hci_packet(HCI_Packet.from_bytes(packet))
def on_hci_packet(self, packet): def on_hci_packet(self, packet):
logger.debug(f'{color("<<<", "blue")} [{self.name}] {color("HOST -> CONTROLLER", "blue")}: {packet}') logger.debug(
f'{color("<<<", "blue")} [{self.name}] {color("HOST -> CONTROLLER", "blue")}: {packet}'
)
# If the packet is a command, invoke the handler for this packet # If the packet is a command, invoke the handler for this packet
if packet.hci_packet_type == HCI_COMMAND_PACKET: if packet.hci_packet_type == HCI_COMMAND_PACKET:
@@ -179,11 +193,13 @@ class Controller:
handler = getattr(self, handler_name, self.on_hci_command) handler = getattr(self, handler_name, self.on_hci_command)
result = handler(command) result = handler(command)
if type(result) is bytes: if type(result) is bytes:
self.send_hci_packet(HCI_Command_Complete_Event( self.send_hci_packet(
HCI_Command_Complete_Event(
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode=command.op_code, command_opcode=command.op_code,
return_parameters = result return_parameters=result,
)) )
)
def on_hci_event_packet(self, event): def on_hci_event_packet(self, event):
logger.warning('!!! unexpected event packet') logger.warning('!!! unexpected event packet')
@@ -192,14 +208,18 @@ class Controller:
# Look for the connection to which this data belongs # Look for the connection to which this data belongs
connection = self.find_connection_by_handle(packet.connection_handle) connection = self.find_connection_by_handle(packet.connection_handle)
if connection is None: if connection is None:
logger.warning(f'!!! no connection for handle 0x{packet.connection_handle:04X}') logger.warning(
f'!!! no connection for handle 0x{packet.connection_handle:04X}'
)
return return
# Pass the packet to the connection # Pass the packet to the connection
connection.on_hci_acl_data_packet(packet) connection.on_hci_acl_data_packet(packet)
def send_hci_packet(self, packet): def send_hci_packet(self, packet):
logger.debug(f'{color(">>>", "green")} [{self.name}] {color("CONTROLLER -> HOST", "green")}: {packet}') logger.debug(
f'{color(">>>", "green")} [{self.name}] {color("CONTROLLER -> HOST", "green")}: {packet}'
)
if self.host: if self.host:
self.host.on_packet(packet.to_bytes()) self.host.on_packet(packet.to_bytes())
@@ -215,8 +235,7 @@ class Controller:
handle = 0 handle = 0
max_handle = 0 max_handle = 0
for connection in itertools.chain( for connection in itertools.chain(
self.central_connections.values(), self.central_connections.values(), self.peripheral_connections.values()
self.peripheral_connections.values()
): ):
max_handle = max(max_handle, connection.handle) max_handle = max(max_handle, connection.handle)
if connection.handle == handle: if connection.handle == handle:
@@ -225,12 +244,13 @@ class Controller:
return handle return handle
def find_connection_by_address(self, address): def find_connection_by_address(self, address):
return self.central_connections.get(address) or self.peripheral_connections.get(address) return self.central_connections.get(address) or self.peripheral_connections.get(
address
)
def find_connection_by_handle(self, handle): def find_connection_by_handle(self, handle):
for connection in itertools.chain( for connection in itertools.chain(
self.central_connections.values(), self.central_connections.values(), self.peripheral_connections.values()
self.peripheral_connections.values()
): ):
if connection.handle == handle: if connection.handle == handle:
return connection return connection
@@ -253,22 +273,26 @@ class Controller:
connection = self.peripheral_connections.get(peer_address) connection = self.peripheral_connections.get(peer_address)
if connection is None: if connection is None:
connection_handle = self.allocate_connection_handle() connection_handle = self.allocate_connection_handle()
connection = Connection(self, connection_handle, BT_PERIPHERAL_ROLE, peer_address, self.link) connection = Connection(
self, connection_handle, BT_PERIPHERAL_ROLE, peer_address, self.link
)
self.peripheral_connections[peer_address] = connection self.peripheral_connections[peer_address] = connection
logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}') logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}')
# Then say that the connection has completed # Then say that the connection has completed
self.send_hci_packet(HCI_LE_Connection_Complete_Event( self.send_hci_packet(
HCI_LE_Connection_Complete_Event(
status=HCI_SUCCESS, status=HCI_SUCCESS,
connection_handle=connection.handle, connection_handle=connection.handle,
role=connection.role, role=connection.role,
peer_address_type=peer_address_type, peer_address_type=peer_address_type,
peer_address=peer_address, peer_address=peer_address,
conn_interval = 10, # FIXME connection_interval=10, # FIXME
conn_latency = 0, # FIXME peripheral_latency=0, # FIXME
supervision_timeout=10, # FIXME supervision_timeout=10, # FIXME
master_clock_accuracy = 7 # FIXME central_clock_accuracy=7, # FIXME
)) )
)
def on_link_central_disconnected(self, peer_address, reason): def on_link_central_disconnected(self, peer_address, reason):
''' '''
@@ -277,18 +301,22 @@ class Controller:
# Send a disconnection complete event # Send a disconnection complete event
if connection := self.peripheral_connections.get(peer_address): if connection := self.peripheral_connections.get(peer_address):
self.send_hci_packet(HCI_Disconnection_Complete_Event( self.send_hci_packet(
HCI_Disconnection_Complete_Event(
status=HCI_SUCCESS, status=HCI_SUCCESS,
connection_handle=connection.handle, connection_handle=connection.handle,
reason = reason reason=reason,
)) )
)
# Remove the connection # Remove the connection
del self.peripheral_connections[peer_address] del self.peripheral_connections[peer_address]
else: else:
logger.warn(f'!!! No peripheral connection found for {peer_address}') logger.warn(f'!!! No peripheral connection found for {peer_address}')
def on_link_peripheral_connection_complete(self, le_create_connection_command, status): def on_link_peripheral_connection_complete(
self, le_create_connection_command, status
):
''' '''
Called by the link when a connection has been made or has failed to be made Called by the link when a connection has been made or has failed to be made
''' '''
@@ -300,29 +328,29 @@ class Controller:
if connection is None: if connection is None:
connection_handle = self.allocate_connection_handle() connection_handle = self.allocate_connection_handle()
connection = Connection( connection = Connection(
self, self, connection_handle, BT_CENTRAL_ROLE, peer_address, self.link
connection_handle,
BT_CENTRAL_ROLE,
peer_address,
self.link
) )
self.central_connections[peer_address] = connection self.central_connections[peer_address] = connection
logger.debug(f'New CENTRAL connection handle: 0x{connection_handle:04X}') logger.debug(
f'New CENTRAL connection handle: 0x{connection_handle:04X}'
)
else: else:
connection = None connection = None
# Say that the connection has completed # Say that the connection has completed
self.send_hci_packet(HCI_LE_Connection_Complete_Event( self.send_hci_packet(
HCI_LE_Connection_Complete_Event(
status=status, status=status,
connection_handle=connection.handle if connection else 0, connection_handle=connection.handle if connection else 0,
role=BT_CENTRAL_ROLE, role=BT_CENTRAL_ROLE,
peer_address_type=le_create_connection_command.peer_address_type, peer_address_type=le_create_connection_command.peer_address_type,
peer_address=le_create_connection_command.peer_address, peer_address=le_create_connection_command.peer_address,
conn_interval = le_create_connection_command.conn_interval_min, connection_interval=le_create_connection_command.connection_interval_min,
conn_latency = le_create_connection_command.conn_latency, peripheral_latency=le_create_connection_command.max_latency,
supervision_timeout=le_create_connection_command.supervision_timeout, supervision_timeout=le_create_connection_command.supervision_timeout,
master_clock_accuracy = 0 central_clock_accuracy=0,
)) )
)
def on_link_peripheral_disconnection_complete(self, disconnection_command, status): def on_link_peripheral_disconnection_complete(self, disconnection_command, status):
''' '''
@@ -330,14 +358,18 @@ class Controller:
''' '''
# Send a disconnection complete event # Send a disconnection complete event
self.send_hci_packet(HCI_Disconnection_Complete_Event( self.send_hci_packet(
HCI_Disconnection_Complete_Event(
status=status, status=status,
connection_handle=disconnection_command.connection_handle, connection_handle=disconnection_command.connection_handle,
reason = disconnection_command.reason reason=disconnection_command.reason,
)) )
)
# Remove the connection # Remove the connection
if connection := self.find_central_connection_by_handle(disconnection_command.connection_handle): if connection := self.find_central_connection_by_handle(
disconnection_command.connection_handle
):
logger.debug(f'CENTRAL Connection removed: {connection}') logger.debug(f'CENTRAL Connection removed: {connection}')
del self.central_connections[connection.peer_address] del self.central_connections[connection.peer_address]
@@ -348,11 +380,13 @@ class Controller:
# Send a disconnection complete event # Send a disconnection complete event
if connection := self.central_connections.get(peer_address): if connection := self.central_connections.get(peer_address):
self.send_hci_packet(HCI_Disconnection_Complete_Event( self.send_hci_packet(
HCI_Disconnection_Complete_Event(
status=HCI_SUCCESS, status=HCI_SUCCESS,
connection_handle=connection.handle, connection_handle=connection.handle,
reason = HCI_CONNECTION_TIMEOUT_ERROR reason=HCI_CONNECTION_TIMEOUT_ERROR,
)) )
)
# Remove the connection # Remove the connection
del self.central_connections[peer_address] del self.central_connections[peer_address]
@@ -364,9 +398,7 @@ class Controller:
if connection := self.find_connection_by_address(peer_address): if connection := self.find_connection_by_address(peer_address):
self.send_hci_packet( self.send_hci_packet(
HCI_Encryption_Change_Event( HCI_Encryption_Change_Event(
status = 0, status=0, connection_handle=connection.handle, encryption_enabled=1
connection_handle = connection.handle,
encryption_enabled = 1
) )
) )
@@ -394,7 +426,7 @@ class Controller:
address_type=sender_address.address_type, address_type=sender_address.address_type,
address=sender_address, address=sender_address,
data=data, data=data,
rssi = -50 rssi=-50,
) )
self.send_hci_packet(HCI_LE_Advertising_Report_Event([report])) self.send_hci_packet(HCI_LE_Advertising_Report_Event([report]))
@@ -405,7 +437,7 @@ class Controller:
address_type=sender_address.address_type, address_type=sender_address.address_type,
address=sender_address, address=sender_address,
data=data, data=data,
rssi = -50 rssi=-50,
) )
self.send_hci_packet(HCI_LE_Advertising_Report_Event([report])) self.send_hci_packet(HCI_LE_Advertising_Report_Event([report]))
@@ -414,14 +446,18 @@ class Controller:
############################################################ ############################################################
def on_advertising_timer_fired(self): def on_advertising_timer_fired(self):
self.send_advertising_data() self.send_advertising_data()
self.advertising_timer_handle = asyncio.get_running_loop().call_later(self.advertising_interval / 1000.0, self.on_advertising_timer_fired) self.advertising_timer_handle = asyncio.get_running_loop().call_later(
self.advertising_interval / 1000.0, self.on_advertising_timer_fired
)
def start_advertising(self): def start_advertising(self):
# Stop any ongoing advertising before we start again # Stop any ongoing advertising before we start again
self.stop_advertising() self.stop_advertising()
# Advertise now # Advertise now
self.advertising_timer_handle = asyncio.get_running_loop().call_soon(self.on_advertising_timer_fired) self.advertising_timer_handle = asyncio.get_running_loop().call_soon(
self.on_advertising_timer_fired
)
def stop_advertising(self): def stop_advertising(self):
if self.advertising_timer_handle is not None: if self.advertising_timer_handle is not None:
@@ -455,14 +491,20 @@ class Controller:
See Bluetooth spec Vol 2, Part E - 7.1.6 Disconnect Command See Bluetooth spec Vol 2, Part E - 7.1.6 Disconnect Command
''' '''
# First, say that the disconnection is pending # First, say that the disconnection is pending
self.send_hci_packet(HCI_Command_Status_Event( self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING, status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode = command.op_code command_opcode=command.op_code,
)) )
)
# Notify the link of the disconnection # Notify the link of the disconnection
if not (connection := self.find_central_connection_by_handle(command.connection_handle)): if not (
connection := self.find_central_connection_by_handle(
command.connection_handle
)
):
logger.warn('connection not found') logger.warn('connection not found')
return return
@@ -583,13 +625,15 @@ class Controller:
''' '''
See Bluetooth spec Vol 2, Part E - 7.4.1 Read Local Version Information Command See Bluetooth spec Vol 2, Part E - 7.4.1 Read Local Version Information Command
''' '''
return struct.pack('<BBHBHH', return struct.pack(
'<BBHBHH',
HCI_SUCCESS, HCI_SUCCESS,
self.hci_version, self.hci_version,
self.hci_revision, self.hci_revision,
self.lmp_version, self.lmp_version,
self.manufacturer_name, self.manufacturer_name,
self.lmp_subversion) self.lmp_subversion,
)
def on_hci_read_local_supported_commands_command(self, command): def on_hci_read_local_supported_commands_command(self, command):
''' '''
@@ -607,7 +651,11 @@ class Controller:
''' '''
See Bluetooth spec Vol 2, Part E - 7.4.6 Read BD_ADDR Command See Bluetooth spec Vol 2, Part E - 7.4.6 Read BD_ADDR Command
''' '''
bd_addr = self._public_address.to_bytes() if self._public_address is not None else bytes(6) bd_addr = (
self._public_address.to_bytes()
if self._public_address is not None
else bytes(6)
)
return bytes([HCI_SUCCESS]) + bd_addr return bytes([HCI_SUCCESS]) + bd_addr
def on_hci_le_set_event_mask_command(self, command): def on_hci_le_set_event_mask_command(self, command):
@@ -621,10 +669,12 @@ class Controller:
''' '''
See Bluetooth spec Vol 2, Part E - 7.8.2 LE Read Buffer Size Command See Bluetooth spec Vol 2, Part E - 7.8.2 LE Read Buffer Size Command
''' '''
return struct.pack('<BHB', return struct.pack(
'<BHB',
HCI_SUCCESS, HCI_SUCCESS,
self.hc_le_data_packet_length, self.hc_le_data_packet_length,
self.hc_total_num_le_data_packets) self.hc_total_num_le_data_packets,
)
def on_hci_le_read_local_supported_features_command(self, command): def on_hci_le_read_local_supported_features_command(self, command):
''' '''
@@ -650,7 +700,7 @@ class Controller:
''' '''
See Bluetooth spec Vol 2, Part E - 7.8.6 LE Read Advertising Channel Tx Power Command See Bluetooth spec Vol 2, Part E - 7.8.6 LE Read Advertising Channel Tx Power Command
''' '''
return bytes([HCI_SUCCESS, self.avertising_channel_tx_power]) return bytes([HCI_SUCCESS, self.advertising_channel_tx_power])
def on_hci_le_set_advertising_data_command(self, command): def on_hci_le_set_advertising_data_command(self, command):
''' '''
@@ -708,22 +758,26 @@ class Controller:
# Check that we don't already have a pending connection # Check that we don't already have a pending connection
if self.link.get_pending_connection(): if self.link.get_pending_connection():
self.send_hci_packet(HCI_Command_Status_Event( self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_DISALLOWED_ERROR, status=HCI_COMMAND_DISALLOWED_ERROR,
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode = command.op_code command_opcode=command.op_code,
)) )
)
return return
# Initiate the connection # Initiate the connection
self.link.connect(self.random_address, command) self.link.connect(self.random_address, command)
# Say that the connection is pending # Say that the connection is pending
self.send_hci_packet(HCI_Command_Status_Event( self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING, status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode = command.op_code command_opcode=command.op_code,
)) )
)
def on_hci_le_create_connection_cancel_command(self, command): def on_hci_le_create_connection_cancel_command(self, command):
''' '''
@@ -761,18 +815,22 @@ class Controller:
''' '''
# First, say that the command is pending # First, say that the command is pending
self.send_hci_packet(HCI_Command_Status_Event( self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING, status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode = command.op_code command_opcode=command.op_code,
)) )
)
# Then send the remote features # Then send the remote features
self.send_hci_packet(HCI_LE_Read_Remote_Features_Complete_Event( self.send_hci_packet(
HCI_LE_Read_Remote_Features_Complete_Event(
status=HCI_SUCCESS, status=HCI_SUCCESS,
connection_handle=0, connection_handle=0,
le_features = bytes.fromhex('dd40000000000000') le_features=bytes.fromhex('dd40000000000000'),
)) )
)
def on_hci_le_rand_command(self, command): def on_hci_le_rand_command(self, command):
''' '''
@@ -786,7 +844,11 @@ class Controller:
''' '''
# Check the parameters # Check the parameters
if not (connection := self.find_central_connection_by_handle(command.connection_handle)): if not (
connection := self.find_central_connection_by_handle(
command.connection_handle
)
):
logger.warn('connection not found') logger.warn('connection not found')
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR]) return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
@@ -796,14 +858,16 @@ class Controller:
connection.peer_address, connection.peer_address,
command.random_number, command.random_number,
command.encrypted_diversifier, command.encrypted_diversifier,
command.long_term_key command.long_term_key,
) )
self.send_hci_packet(HCI_Command_Status_Event( self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING, status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode = command.op_code command_opcode=command.op_code,
)) )
)
def on_hci_le_read_supported_states_command(self, command): def on_hci_le_read_supported_states_command(self, command):
''' '''
@@ -815,16 +879,20 @@ class Controller:
''' '''
See Bluetooth spec Vol 2, Part E - 7.8.34 LE Read Suggested Default Data Length Command See Bluetooth spec Vol 2, Part E - 7.8.34 LE Read Suggested Default Data Length Command
''' '''
return struct.pack('<BHH', return struct.pack(
'<BHH',
HCI_SUCCESS, HCI_SUCCESS,
self.suggested_max_tx_octets, self.suggested_max_tx_octets,
self.suggested_max_tx_time) self.suggested_max_tx_time,
)
def on_hci_le_write_suggested_default_data_length_command(self, command): def on_hci_le_write_suggested_default_data_length_command(self, command):
''' '''
See Bluetooth spec Vol 2, Part E - 7.8.35 LE Write Suggested Default Data Length Command See Bluetooth spec Vol 2, Part E - 7.8.35 LE Write Suggested Default Data Length Command
''' '''
self.suggested_max_tx_octets, self.suggested_max_tx_time = struct.unpack('<HH', command.parameters[:4]) self.suggested_max_tx_octets, self.suggested_max_tx_time = struct.unpack(
'<HH', command.parameters[:4]
)
return bytes([HCI_SUCCESS]) return bytes([HCI_SUCCESS])
def on_hci_le_read_local_p_256_public_key_command(self, command): def on_hci_le_read_local_p_256_public_key_command(self, command):
@@ -857,9 +925,9 @@ class Controller:
See Bluetooth spec Vol 2, Part E - 7.8.44 LE Set Address Resolution Enable Command See Bluetooth spec Vol 2, Part E - 7.8.44 LE Set Address Resolution Enable Command
''' '''
ret = HCI_SUCCESS ret = HCI_SUCCESS
if command.address_resolution == 1: if command.address_resolution_enable == 1:
self.le_address_resolution = True self.le_address_resolution = True
elif command.address_resolution == 0: elif command.address_resolution_enable == 0:
self.le_address_resolution = False self.le_address_resolution = False
else: else:
ret = HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR ret = HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR
@@ -876,12 +944,26 @@ class Controller:
''' '''
See Bluetooth spec Vol 2, Part E - 7.8.46 LE Read Maximum Data Length Command See Bluetooth spec Vol 2, Part E - 7.8.46 LE Read Maximum Data Length Command
''' '''
return struct.pack('<BHHHH', return struct.pack(
'<BHHHH',
HCI_SUCCESS, HCI_SUCCESS,
self.supported_max_tx_octets, self.supported_max_tx_octets,
self.supported_max_tx_time, self.supported_max_tx_time,
self.supported_max_rx_octets, self.supported_max_rx_octets,
self.supported_max_rx_time) self.supported_max_rx_time,
)
def on_hci_le_read_phy_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.47 LE Read PHY command
'''
return struct.pack(
'<BHBB',
HCI_SUCCESS,
command.connection_handle,
HCI_LE_1M_PHY,
HCI_LE_1M_PHY,
)
def on_hci_le_set_default_phy_command(self, command): def on_hci_le_set_default_phy_command(self, command):
''' '''
@@ -890,6 +972,6 @@ class Controller:
self.default_phy = { self.default_phy = {
'all_phys': command.all_phys, 'all_phys': command.all_phys,
'tx_phys': command.tx_phys, 'tx_phys': command.tx_phys,
'rx_phys': command.rx_phys 'rx_phys': command.rx_phys,
} }
return bytes([HCI_SUCCESS]) return bytes([HCI_SUCCESS])
+122 -33
View File
@@ -23,6 +23,8 @@ from .company_ids import COMPANY_IDENTIFIERS
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
BT_CENTRAL_ROLE = 0 BT_CENTRAL_ROLE = 0
BT_PERIPHERAL_ROLE = 1 BT_PERIPHERAL_ROLE = 1
@@ -30,6 +32,9 @@ BT_BR_EDR_TRANSPORT = 0
BT_LE_TRANSPORT = 1 BT_LE_TRANSPORT = 1
# fmt: on
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -58,11 +63,19 @@ def padded_bytes(buffer, size):
return buffer + bytes(padding_size) return buffer + bytes(padding_size)
def get_dict_key_by_value(dictionary, value):
for key, val in dictionary.items():
if val == value:
return key
return None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Exceptions # Exceptions
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class BaseError(Exception): class BaseError(Exception):
"""Base class for errors with an error code, error name and namespace""" """Base class for errors with an error code, error name and namespace"""
def __init__(self, error_code, error_namespace='', error_name='', details=''): def __init__(self, error_code, error_namespace='', error_name='', details=''):
super().__init__() super().__init__()
self.error_code = error_code self.error_code = error_code
@@ -91,15 +104,33 @@ class TimeoutError(Exception):
"""Timeout Error""" """Timeout Error"""
class CommandTimeoutError(Exception):
"""Command Timeout Error"""
class InvalidStateError(Exception): class InvalidStateError(Exception):
"""Invalid State Error""" """Invalid State Error"""
class ConnectionError(BaseError): class ConnectionError(BaseError):
"""Connection Error""" """Connection Error"""
FAILURE = 0x01 FAILURE = 0x01
CONNECTION_REFUSED = 0x02 CONNECTION_REFUSED = 0x02
def __init__(
self,
error_code,
transport,
peer_address,
error_namespace='',
error_name='',
details='',
):
super().__init__(error_code, error_namespace, error_name, details)
self.transport = transport
self.peer_address = peer_address
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# UUID # UUID
@@ -112,6 +143,7 @@ class UUID:
''' '''
See Bluetooth spec Vol 3, Part B - 2.5.1 UUID See Bluetooth spec Vol 3, Part B - 2.5.1 UUID
''' '''
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB') BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')
UUIDS = [] # Registry of all instances created UUIDS = [] # Registry of all instances created
@@ -120,13 +152,18 @@ class UUID:
self.uuid_bytes = struct.pack('<H', uuid_str_or_int) self.uuid_bytes = struct.pack('<H', uuid_str_or_int)
else: else:
if len(uuid_str_or_int) == 36: if len(uuid_str_or_int) == 36:
if uuid_str_or_int[8] != '-' or uuid_str_or_int[13] != '-' or uuid_str_or_int[18] != '-' or uuid_str_or_int[23] != '-': if (
uuid_str_or_int[8] != '-'
or uuid_str_or_int[13] != '-'
or uuid_str_or_int[18] != '-'
or uuid_str_or_int[23] != '-'
):
raise ValueError('invalid UUID format') raise ValueError('invalid UUID format')
uuid_str = uuid_str_or_int.replace('-', '') uuid_str = uuid_str_or_int.replace('-', '')
else: else:
uuid_str = uuid_str_or_int uuid_str = uuid_str_or_int
if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4: if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4:
raise ValueError('invalid UUID format') raise ValueError(f"invalid UUID format: {uuid_str}")
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str))) self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
self.name = name self.name = name
@@ -189,13 +226,15 @@ class UUID:
if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4: if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4:
return bytes(reversed(self.uuid_bytes)).hex().upper() return bytes(reversed(self.uuid_bytes)).hex().upper()
else: else:
return ''.join([ return ''.join(
[
bytes(reversed(self.uuid_bytes[12:16])).hex(), bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(), bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(), bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(), bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex() bytes(reversed(self.uuid_bytes[0:6])).hex(),
]).upper() ]
).upper()
def __bytes__(self): def __bytes__(self):
return self.to_bytes() return self.to_bytes()
@@ -219,13 +258,15 @@ class UUID:
v = struct.unpack('<I', self.uuid_bytes)[0] v = struct.unpack('<I', self.uuid_bytes)[0]
result = f'UUID-32:{v:08X}' result = f'UUID-32:{v:08X}'
else: else:
result = '-'.join([ result = '-'.join(
[
bytes(reversed(self.uuid_bytes[12:16])).hex(), bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(), bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(), bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(), bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex() bytes(reversed(self.uuid_bytes[0:6])).hex(),
]).upper() ]
).upper()
if self.name is not None: if self.name is not None:
return result + f' ({self.name})' return result + f' ({self.name})'
else: else:
@@ -238,6 +279,7 @@ class UUID:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Common UUID constants # Common UUID constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
# Protocol Identifiers # Protocol Identifiers
BT_SDP_PROTOCOL_ID = UUID.from_16_bits(0x0001, 'SDP') BT_SDP_PROTOCOL_ID = UUID.from_16_bits(0x0001, 'SDP')
@@ -343,11 +385,15 @@ BT_HDP_SERVICE = UUID.from_16_bits(0x1400,
BT_HDP_SOURCE_SERVICE = UUID.from_16_bits(0x1401, 'HDP Source') BT_HDP_SOURCE_SERVICE = UUID.from_16_bits(0x1401, 'HDP Source')
BT_HDP_SINK_SERVICE = UUID.from_16_bits(0x1402, 'HDP Sink') BT_HDP_SINK_SERVICE = UUID.from_16_bits(0x1402, 'HDP Sink')
# fmt: on
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# DeviceClass # DeviceClass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class DeviceClass: class DeviceClass:
# fmt: off
# Major Service Classes (flags combined with OR) # Major Service Classes (flags combined with OR)
LIMITED_DISCOVERABLE_MODE_SERVICE_CLASS = (1 << 0) LIMITED_DISCOVERABLE_MODE_SERVICE_CLASS = (1 << 0)
LE_AUDIO_SERVICE_CLASS = (1 << 1) LE_AUDIO_SERVICE_CLASS = (1 << 1)
@@ -515,11 +561,17 @@ class DeviceClass:
PERIPHERAL_MAJOR_DEVICE_CLASS: PERIPHERAL_MINOR_DEVICE_CLASS_NAMES PERIPHERAL_MAJOR_DEVICE_CLASS: PERIPHERAL_MINOR_DEVICE_CLASS_NAMES
} }
# fmt: on
@staticmethod @staticmethod
def split_class_of_device(class_of_device): def split_class_of_device(class_of_device):
# Split the bit fields of the composite class of device value into: # Split the bit fields of the composite class of device value into:
# (service_classes, major_device_class, minor_device_class) # (service_classes, major_device_class, minor_device_class)
return ((class_of_device >> 13 & 0x7FF), (class_of_device >> 8 & 0x1F), (class_of_device >> 2 & 0x3F)) return (
(class_of_device >> 13 & 0x7FF),
(class_of_device >> 8 & 0x1F),
(class_of_device >> 2 & 0x3F),
)
@staticmethod @staticmethod
def pack_class_of_device(service_classes, major_device_class, minor_device_class): def pack_class_of_device(service_classes, major_device_class, minor_device_class):
@@ -527,7 +579,9 @@ class DeviceClass:
@staticmethod @staticmethod
def service_class_labels(service_class_flags): def service_class_labels(service_class_flags):
return bit_flags_to_strings(service_class_flags, DeviceClass.SERVICE_CLASS_LABELS) return bit_flags_to_strings(
service_class_flags, DeviceClass.SERVICE_CLASS_LABELS
)
@staticmethod @staticmethod
def major_device_class_name(device_class): def major_device_class_name(device_class):
@@ -545,6 +599,8 @@ class DeviceClass:
# Advertising Data # Advertising Data
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AdvertisingData: class AdvertisingData:
# fmt: off
# This list is only partial, it still needs to be filled in from the spec # This list is only partial, it still needs to be filled in from the spec
FLAGS = 0x01 FLAGS = 0x01
INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x02 INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x02
@@ -656,6 +712,8 @@ class AdvertisingData:
BR_EDR_CONTROLLER_FLAG = 0x08 BR_EDR_CONTROLLER_FLAG = 0x08
BR_EDR_HOST_FLAG = 0x10 BR_EDR_HOST_FLAG = 0x10
# fmt: on
def __init__(self, ad_structures=[]): def __init__(self, ad_structures=[]):
self.ad_structures = ad_structures[:] self.ad_structures = ad_structures[:]
@@ -667,19 +725,17 @@ class AdvertisingData:
@staticmethod @staticmethod
def flags_to_string(flags, short=False): def flags_to_string(flags, short=False):
flag_names = [ flag_names = (
'LE Limited', ['LE Limited', 'LE General', 'No BR/EDR', 'BR/EDR C', 'BR/EDR H']
'LE General', if short
'No BR/EDR', else [
'BR/EDR C',
'BR/EDR H'
] if short else [
'LE Limited Discoverable Mode', 'LE Limited Discoverable Mode',
'LE General Discoverable Mode', 'LE General Discoverable Mode',
'BR/EDR Not Supported', 'BR/EDR Not Supported',
'Simultaneous LE and BR/EDR (Controller)', 'Simultaneous LE and BR/EDR (Controller)',
'Simultaneous LE and BR/EDR (Host)' 'Simultaneous LE and BR/EDR (Host)',
] ]
)
return ','.join(bit_flags_to_strings(flags, flag_names)) return ','.join(bit_flags_to_strings(flags, flag_names))
@staticmethod @staticmethod
@@ -693,10 +749,12 @@ class AdvertisingData:
@staticmethod @staticmethod
def uuid_list_to_string(ad_data, uuid_size): def uuid_list_to_string(ad_data, uuid_size):
return ', '.join([ return ', '.join(
[
str(uuid) str(uuid)
for uuid in AdvertisingData.uuid_list_to_objects(ad_data, uuid_size) for uuid in AdvertisingData.uuid_list_to_objects(ad_data, uuid_size)
]) ]
)
@staticmethod @staticmethod
def ad_data_to_string(ad_type, ad_data): def ad_data_to_string(ad_type, ad_data):
@@ -760,17 +818,20 @@ class AdvertisingData:
def ad_data_to_object(ad_type, ad_data): def ad_data_to_object(ad_type, ad_data):
if ad_type in { if ad_type in {
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
}: }:
return AdvertisingData.uuid_list_to_objects(ad_data, 2) return AdvertisingData.uuid_list_to_objects(ad_data, 2)
elif ad_type in { elif ad_type in {
AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
}: }:
return AdvertisingData.uuid_list_to_objects(ad_data, 4) return AdvertisingData.uuid_list_to_objects(ad_data, 4)
elif ad_type in { elif ad_type in {
AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
}: }:
return AdvertisingData.uuid_list_to_objects(ad_data, 16) return AdvertisingData.uuid_list_to_objects(ad_data, 16)
elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID: elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
@@ -781,11 +842,21 @@ class AdvertisingData:
return (UUID.from_bytes(ad_data[:16]), ad_data[16:]) return (UUID.from_bytes(ad_data[:16]), ad_data[16:])
elif ad_type in { elif ad_type in {
AdvertisingData.SHORTENED_LOCAL_NAME, AdvertisingData.SHORTENED_LOCAL_NAME,
AdvertisingData.COMPLETE_LOCAL_NAME AdvertisingData.COMPLETE_LOCAL_NAME,
AdvertisingData.URI,
}: }:
return ad_data.decode("utf-8") return ad_data.decode("utf-8")
elif ad_type == AdvertisingData.TX_POWER_LEVEL: elif ad_type in {AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS}:
return ad_data[0] return ad_data[0]
elif ad_type in {
AdvertisingData.APPEARANCE,
AdvertisingData.ADVERTISING_INTERVAL,
}:
return struct.unpack('<H', ad_data)[0]
elif ad_type == AdvertisingData.CLASS_OF_DEVICE:
return struct.unpack('<I', bytes([*ad_data, 0]))[0]
elif ad_type == AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE:
return struct.unpack('<HH', ad_data)
elif ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA: elif ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
return (struct.unpack_from('<H', ad_data, 0)[0], ad_data[2:]) return (struct.unpack_from('<H', ad_data, 0)[0], ad_data[2:])
else: else:
@@ -802,26 +873,40 @@ class AdvertisingData:
self.ad_structures.append((ad_type, ad_data)) self.ad_structures.append((ad_type, ad_data))
offset += length offset += length
def get(self, type_id, return_all=False, raw=True): def get(self, type_id, return_all=False, raw=False):
''' '''
Get Advertising Data Structure(s) with a given type Get Advertising Data Structure(s) with a given type
If return_all is True, returns a (possibly empty) list of matches, If return_all is True, returns a (possibly empty) list of matches,
else returns the first entry, or None if no structure matches. else returns the first entry, or None if no structure matches.
''' '''
def process_ad_data(ad_data): def process_ad_data(ad_data):
return ad_data if raw else self.ad_data_to_object(type_id, ad_data) return ad_data if raw else self.ad_data_to_object(type_id, ad_data)
if return_all: if return_all:
return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id] return [
process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id
]
else: else:
return next((process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id), None) return next(
(
process_ad_data(ad[1])
for ad in self.ad_structures
if ad[0] == type_id
),
None,
)
def __bytes__(self): def __bytes__(self):
return b''.join([bytes([len(x[1]) + 1, x[0]]) + x[1] for x in self.ad_structures]) return b''.join(
[bytes([len(x[1]) + 1, x[0]]) + x[1] for x in self.ad_structures]
)
def to_string(self, separator=', '): def to_string(self, separator=', '):
return separator.join([AdvertisingData.ad_data_to_string(x[0], x[1]) for x in self.ad_structures]) return separator.join(
[AdvertisingData.ad_data_to_string(x[0], x[1]) for x in self.ad_structures]
)
def __str__(self): def __str__(self):
return self.to_string() return self.to_string()
@@ -831,13 +916,17 @@ class AdvertisingData:
# Connection Parameters # Connection Parameters
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ConnectionParameters: class ConnectionParameters:
def __init__(self, connection_interval, connection_latency, supervision_timeout): def __init__(self, connection_interval, peripheral_latency, supervision_timeout):
self.connection_interval = connection_interval self.connection_interval = connection_interval
self.connection_latency = connection_latency self.peripheral_latency = peripheral_latency
self.supervision_timeout = supervision_timeout self.supervision_timeout = supervision_timeout
def __str__(self): def __str__(self):
return f'ConnectionParameters(connection_interval={self.connection_interval}, connection_latency={self.connection_latency}, supervision_timeout={self.supervision_timeout}' return (
f'ConnectionParameters(connection_interval={self.connection_interval}, '
f'peripheral_latency={self.peripheral_latency}, '
f'supervision_timeout={self.supervision_timeout}'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+71 -43
View File
@@ -24,19 +24,16 @@
import logging import logging
import operator import operator
import platform import platform
if platform.system() != 'Emscripten': if platform.system() != 'Emscripten':
import secrets import secrets
from cryptography.hazmat.primitives.ciphers import ( from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
Cipher,
algorithms,
modes
)
from cryptography.hazmat.primitives.asymmetric.ec import ( from cryptography.hazmat.primitives.asymmetric.ec import (
generate_private_key, generate_private_key,
ECDH, ECDH,
EllipticCurvePublicNumbers, EllipticCurvePublicNumbers,
EllipticCurvePrivateNumbers, EllipticCurvePrivateNumbers,
SECP256R1 SECP256R1,
) )
from cryptography.hazmat.primitives import cmac from cryptography.hazmat.primitives import cmac
else: else:
@@ -66,16 +63,26 @@ class EccKey:
d = int.from_bytes(d_bytes, byteorder='big', signed=False) d = int.from_bytes(d_bytes, byteorder='big', signed=False)
x = int.from_bytes(x_bytes, byteorder='big', signed=False) x = int.from_bytes(x_bytes, byteorder='big', signed=False)
y = int.from_bytes(y_bytes, byteorder='big', signed=False) y = int.from_bytes(y_bytes, byteorder='big', signed=False)
private_key = EllipticCurvePrivateNumbers(d, EllipticCurvePublicNumbers(x, y, SECP256R1())).private_key() private_key = EllipticCurvePrivateNumbers(
d, EllipticCurvePublicNumbers(x, y, SECP256R1())
).private_key()
return cls(private_key) return cls(private_key)
@property @property
def x(self): def x(self):
return self.private_key.public_key().public_numbers().x.to_bytes(32, byteorder='big') return (
self.private_key.public_key()
.public_numbers()
.x.to_bytes(32, byteorder='big')
)
@property @property
def y(self): def y(self):
return self.private_key.public_key().public_numbers().y.to_bytes(32, byteorder='big') return (
self.private_key.public_key()
.public_numbers()
.y.to_bytes(32, byteorder='big')
)
def dh(self, public_key_x, public_key_y): def dh(self, public_key_x, public_key_y):
x = int.from_bytes(public_key_x, byteorder='big', signed=False) x = int.from_bytes(public_key_x, byteorder='big', signed=False)
@@ -92,7 +99,7 @@ class EccKey:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def xor(x, y): def xor(x, y):
assert(len(x) == len(y)) assert len(x) == len(y)
return bytes(map(operator.xor, x, y)) return bytes(map(operator.xor, x, y))
@@ -165,7 +172,11 @@ def f4(u, v, x, z):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value Generation Function f4 See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value Generation Function f4
''' '''
return bytes(reversed(aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x))))) return bytes(
reversed(
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -177,28 +188,36 @@ def f5(w, n1, n2, a1, a2):
''' '''
salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE') salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE')
t = aes_cmac(bytes(reversed(w)), salt) t = aes_cmac(bytes(reversed(w)), salt)
key_id = bytes([0x62, 0x74, 0x6c, 0x65]) key_id = bytes([0x62, 0x74, 0x6C, 0x65])
return ( return (
bytes(reversed(aes_cmac( bytes(
bytes([0]) + reversed(
key_id + aes_cmac(
bytes(reversed(n1)) + bytes([0])
bytes(reversed(n2)) + + key_id
bytes(reversed(a1)) + + bytes(reversed(n1))
bytes(reversed(a2)) + + bytes(reversed(n2))
bytes([1, 0]), + bytes(reversed(a1))
t + bytes(reversed(a2))
))), + bytes([1, 0]),
bytes(reversed(aes_cmac( t,
bytes([1]) + )
key_id + )
bytes(reversed(n1)) + ),
bytes(reversed(n2)) + bytes(
bytes(reversed(a1)) + reversed(
bytes(reversed(a2)) + aes_cmac(
bytes([1, 0]), bytes([1])
t + key_id
))) + bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(a1))
+ bytes(reversed(a2))
+ bytes([1, 0]),
t,
)
)
),
) )
@@ -207,15 +226,19 @@ def f6(w, n1, n2, r, io_cap, a1, a2):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value Generation Function f6 See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value Generation Function f6
''' '''
return bytes(reversed(aes_cmac( return bytes(
bytes(reversed(n1)) + reversed(
bytes(reversed(n2)) + aes_cmac(
bytes(reversed(r)) + bytes(reversed(n1))
bytes(reversed(io_cap)) + + bytes(reversed(n2))
bytes(reversed(a1)) + + bytes(reversed(r))
bytes(reversed(a2)), + bytes(reversed(io_cap))
bytes(reversed(w)) + bytes(reversed(a1))
))) + bytes(reversed(a2)),
bytes(reversed(w)),
)
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -224,10 +247,14 @@ def g2(u, v, x, y):
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison Value Generation Function g2 See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison Value Generation Function g2
''' '''
return int.from_bytes( return int.from_bytes(
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)), bytes(reversed(x)))[-4:], aes_cmac(
byteorder='big' bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)),
bytes(reversed(x)),
)[-4:],
byteorder='big',
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def h6(w, key_id): def h6(w, key_id):
''' '''
@@ -235,6 +262,7 @@ def h6(w, key_id):
''' '''
return aes_cmac(key_id, w) return aes_cmac(key_id, w)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def h7(salt, w): def h7(salt, w):
''' '''
+1458 -297
View File
File diff suppressed because it is too large Load Diff
+7 -7
View File
@@ -23,7 +23,7 @@ from .gatt import (
Characteristic, Characteristic,
GATT_GENERIC_ACCESS_SERVICE, GATT_GENERIC_ACCESS_SERVICE,
GATT_DEVICE_NAME_CHARACTERISTIC, GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC GATT_APPEARANCE_CHARACTERISTIC,
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -43,17 +43,17 @@ class GenericAccessService(Service):
GATT_DEVICE_NAME_CHARACTERISTIC, GATT_DEVICE_NAME_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
device_name.encode('utf-8')[:248] device_name.encode('utf-8')[:248],
) )
appearance_characteristic = Characteristic( appearance_characteristic = Characteristic(
GATT_APPEARANCE_CHARACTERISTIC, GATT_APPEARANCE_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
struct.pack('<H', (appearance[0] << 6) | appearance[1]) struct.pack('<H', (appearance[0] << 6) | appearance[1]),
) )
super().__init__(GATT_GENERIC_ACCESS_SERVICE, [ super().__init__(
device_name_characteristic, GATT_GENERIC_ACCESS_SERVICE,
appearance_characteristic [device_name_characteristic, appearance_characteristic],
]) )
+109 -20
View File
@@ -22,9 +22,12 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio import asyncio
import enum
import types import types
import logging import logging
from pyee import EventEmitter
from colors import color from colors import color
from .core import * from .core import *
@@ -39,6 +42,8 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
GATT_REQUEST_TIMEOUT = 30 # seconds GATT_REQUEST_TIMEOUT = 30 # seconds
GATT_MAX_ATTRIBUTE_VALUE_SIZE = 512 GATT_MAX_ATTRIBUTE_VALUE_SIZE = 512
@@ -149,6 +154,14 @@ GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A39, 'Heart
# Battery Service # Battery Service
GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level') GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level')
# ASHA Service
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID('f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint')
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID('38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus')
GATT_ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT')
# Misc # Misc
GATT_DEVICE_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2A00, 'Device Name') GATT_DEVICE_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2A00, 'Device Name')
GATT_APPEARANCE_CHARACTERISTIC = UUID.from_16_bits(0x2A01, 'Appearance') GATT_APPEARANCE_CHARACTERISTIC = UUID.from_16_bits(0x2A01, 'Appearance')
@@ -163,11 +176,14 @@ GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bi
GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report') GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report')
GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution') GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
# fmt: on
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def show_services(services): def show_services(services):
for service in services: for service in services:
print(color(str(service), 'cyan')) print(color(str(service), 'cyan'))
@@ -185,21 +201,31 @@ class Service(Attribute):
See Vol 3, Part G - 3.1 SERVICE DEFINITION See Vol 3, Part G - 3.1 SERVICE DEFINITION
''' '''
def __init__(self, uuid, characteristics, primary=True): def __init__(self, uuid, characteristics: list[Characteristic], primary=True):
# Convert the uuid to a UUID object if it isn't already # Convert the uuid to a UUID object if it isn't already
if type(uuid) is str: if type(uuid) is str:
uuid = UUID(uuid) uuid = UUID(uuid)
super().__init__( super().__init__(
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE if primary else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Attribute.READABLE, Attribute.READABLE,
uuid.to_pdu_bytes() uuid.to_pdu_bytes(),
) )
self.uuid = uuid self.uuid = uuid
self.included_services = [] self.included_services = []
self.characteristics = characteristics[:] self.characteristics = characteristics[:]
self.primary = primary self.primary = primary
def get_advertising_data(self):
"""
Get Service specific advertising data
Defined by each Service, default value is empty
:return Service data for advertising
"""
return None
def __str__(self): def __str__(self):
return f'Service(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}){"" if self.primary else "*"}' return f'Service(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}){"" if self.primary else "*"}'
@@ -210,6 +236,7 @@ class TemplateService(Service):
Convenience abstract class that can be used by profile-specific subclasses that want Convenience abstract class that can be used by profile-specific subclasses that want
to expose their UUID as a class property to expose their UUID as a class property
''' '''
UUID = None UUID = None
def __init__(self, characteristics, primary=True): def __init__(self, characteristics, primary=True):
@@ -228,9 +255,9 @@ class Characteristic(Attribute):
WRITE_WITHOUT_RESPONSE = 0x04 WRITE_WITHOUT_RESPONSE = 0x04
WRITE = 0x08 WRITE = 0x08
NOTIFY = 0x10 NOTIFY = 0x10
INDICATE = 0X20 INDICATE = 0x20
AUTHENTICATED_SIGNED_WRITES = 0X40 AUTHENTICATED_SIGNED_WRITES = 0x40
EXTENDED_PROPERTIES = 0X80 EXTENDED_PROPERTIES = 0x80
PROPERTY_NAMES = { PROPERTY_NAMES = {
BROADCAST: 'BROADCAST', BROADCAST: 'BROADCAST',
@@ -240,7 +267,7 @@ class Characteristic(Attribute):
NOTIFY: 'NOTIFY', NOTIFY: 'NOTIFY',
INDICATE: 'INDICATE', INDICATE: 'INDICATE',
AUTHENTICATED_SIGNED_WRITES: 'AUTHENTICATED_SIGNED_WRITES', AUTHENTICATED_SIGNED_WRITES: 'AUTHENTICATED_SIGNED_WRITES',
EXTENDED_PROPERTIES: 'EXTENDED_PROPERTIES' EXTENDED_PROPERTIES: 'EXTENDED_PROPERTIES',
} }
@staticmethod @staticmethod
@@ -249,32 +276,75 @@ class Characteristic(Attribute):
@staticmethod @staticmethod
def properties_as_string(properties): def properties_as_string(properties):
return ','.join([ return ','.join(
Characteristic.property_name(p) for p in Characteristic.PROPERTY_NAMES.keys() [
Characteristic.property_name(p)
for p in Characteristic.PROPERTY_NAMES.keys()
if properties & p if properties & p
]) ]
)
def __init__(self, uuid, properties, permissions, value = b'', descriptors = []): @staticmethod
def string_to_properties(properties_str: str):
return functools.reduce(
lambda x, y: x | get_dict_key_by_value(Characteristic.PROPERTY_NAMES, y),
properties_str.split(","),
0,
)
def __init__(
self,
uuid,
properties,
permissions,
value=b'',
descriptors: list[Descriptor] = [],
):
super().__init__(uuid, permissions, value) super().__init__(uuid, permissions, value)
self.uuid = self.type self.uuid = self.type
if type(properties) is str:
self.properties = Characteristic.string_to_properties(properties)
else:
self.properties = properties self.properties = properties
self.descriptors = descriptors self.descriptors = descriptors
def get_descriptor(self, descriptor_type): def get_descriptor(self, descriptor_type):
for descriptor in self.descriptors: for descriptor in self.descriptors:
if descriptor.uuid == descriptor_type: if descriptor.type == descriptor_type:
return descriptor return descriptor
def __str__(self): def __str__(self):
return f'Characteristic(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}, properties={Characteristic.properties_as_string(self.properties)})' return f'Characteristic(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}, properties={Characteristic.properties_as_string(self.properties)})'
# -----------------------------------------------------------------------------
class CharacteristicDeclaration(Attribute):
'''
See Vol 3, Part G - 3.3.1 CHARACTERISTIC DECLARATION
'''
def __init__(self, characteristic, value_handle):
declaration_bytes = (
struct.pack('<BH', characteristic.properties, value_handle)
+ characteristic.uuid.to_pdu_bytes()
)
super().__init__(
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
)
self.value_handle = value_handle
self.characteristic = characteristic
def __str__(self):
return f'CharacteristicDeclaration(handle=0x{self.handle:04X}, value_handle=0x{self.value_handle:04X}, uuid={self.characteristic.uuid}, properties={Characteristic.properties_as_string(self.characteristic.properties)})'
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CharacteristicValue: class CharacteristicValue:
''' '''
Characteristic value where reading and/or writing is delegated to functions Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor. passed as arguments to the constructor.
''' '''
def __init__(self, read=None, write=None): def __init__(self, read=None, write=None):
self._read = read self._read = read
self._write = write self._write = write
@@ -301,14 +371,14 @@ class CharacteristicAdapter:
If the characteristic has a `subscribe` method, it is wrapped with one where If the characteristic has a `subscribe` method, it is wrapped with one where
the values are decoded before being passed to the subscriber. the values are decoded before being passed to the subscriber.
''' '''
def __init__(self, characteristic): def __init__(self, characteristic):
self.wrapped_characteristic = characteristic self.wrapped_characteristic = characteristic
self.subscribers = {} # Map from subscriber to proxy subscriber self.subscribers = {} # Map from subscriber to proxy subscriber
if ( if asyncio.iscoroutinefunction(
asyncio.iscoroutinefunction(characteristic.read_value) and characteristic.read_value
asyncio.iscoroutinefunction(characteristic.write_value) ) and asyncio.iscoroutinefunction(characteristic.write_value):
):
self.read_value = self.read_decoded_value self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value self.write_value = self.write_decoded_value
else: else:
@@ -331,7 +401,7 @@ class CharacteristicAdapter:
'read_value', 'read_value',
'write_value', 'write_value',
'subscribe', 'subscribe',
'unsubscribe' 'unsubscribe',
}: }:
super().__setattr__(name, value) super().__setattr__(name, value)
else: else:
@@ -341,13 +411,17 @@ class CharacteristicAdapter:
return self.encode_value(self.wrapped_characteristic.read_value(connection)) return self.encode_value(self.wrapped_characteristic.read_value(connection))
def write_encoded_value(self, connection, value): def write_encoded_value(self, connection, value):
return self.wrapped_characteristic.write_value(connection, self.decode_value(value)) return self.wrapped_characteristic.write_value(
connection, self.decode_value(value)
)
async def read_decoded_value(self): async def read_decoded_value(self):
return self.decode_value(await self.wrapped_characteristic.read_value()) return self.decode_value(await self.wrapped_characteristic.read_value())
async def write_decoded_value(self, value): async def write_decoded_value(self, value, with_response=False):
return await self.wrapped_characteristic.write_value(self.encode_value(value)) return await self.wrapped_characteristic.write_value(
self.encode_value(value), with_response
)
def encode_value(self, value): def encode_value(self, value):
return value return value
@@ -366,6 +440,7 @@ class CharacteristicAdapter:
def on_change(value): def on_change(value):
original_subscriber(self.decode_value(value)) original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change self.subscribers[subscriber] = on_change
subscriber = on_change subscriber = on_change
@@ -387,6 +462,7 @@ class DelegatedCharacteristicAdapter(CharacteristicAdapter):
''' '''
Adapter that converts bytes values using an encode and a decode function. Adapter that converts bytes values using an encode and a decode function.
''' '''
def __init__(self, characteristic, encode=None, decode=None): def __init__(self, characteristic, encode=None, decode=None):
super().__init__(characteristic) super().__init__(characteristic)
self.encode = encode self.encode = encode
@@ -409,6 +485,7 @@ class PackedCharacteristicAdapter(CharacteristicAdapter):
they return/accept a tuple with the same number of elements as is required for they return/accept a tuple with the same number of elements as is required for
the format. the format.
''' '''
def __init__(self, characteristic, format): def __init__(self, characteristic, format):
super().__init__(characteristic) super().__init__(characteristic)
self.struct = struct.Struct(format) self.struct = struct.Struct(format)
@@ -436,6 +513,7 @@ class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
is packed/unpacked according to format, with the arguments extracted from the dictionary is packed/unpacked according to format, with the arguments extracted from the dictionary
by key, in the same order as they occur in the `keys` parameter. by key, in the same order as they occur in the `keys` parameter.
''' '''
def __init__(self, characteristic, format, keys): def __init__(self, characteristic, format, keys):
super().__init__(characteristic, format) super().__init__(characteristic, format)
self.keys = keys self.keys = keys
@@ -452,6 +530,7 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
''' '''
Adapter that converts strings to/from bytes using UTF-8 encoding Adapter that converts strings to/from bytes using UTF-8 encoding
''' '''
def encode_value(self, value): def encode_value(self, value):
return value.encode('utf-8') return value.encode('utf-8')
@@ -470,3 +549,13 @@ class Descriptor(Attribute):
def __str__(self): def __str__(self):
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type}, value={self.read_value(None).hex()})' return f'Descriptor(handle=0x{self.handle:04X}, type={self.type}, value={self.read_value(None).hex()})'
class ClientCharacteristicConfigurationBits(enum.IntFlag):
'''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit field definition
'''
DEFAULT = 0x0000
NOTIFICATION = 0x0001
INDICATION = 0x0002
+170 -76
View File
@@ -26,19 +26,21 @@
import asyncio import asyncio
import logging import logging
import struct import struct
from colors import color from colors import color
from .core import ProtocolError, TimeoutError
from .hci import *
from .att import * from .att import *
from .core import InvalidStateError, ProtocolError, TimeoutError
from .gatt import ( from .gatt import (
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_REQUEST_TIMEOUT,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
Characteristic GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic,
ClientCharacteristicConfigurationBits,
) )
from .hci import *
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -58,10 +60,14 @@ class AttributeProxy(EventEmitter):
self.type = attribute_type self.type = attribute_type
async def read_value(self, no_long_read=False): async def read_value(self, no_long_read=False):
return self.decode_value(await self.client.read_value(self.handle, no_long_read)) return self.decode_value(
await self.client.read_value(self.handle, no_long_read)
)
async def write_value(self, value, with_response=False): async def write_value(self, value, with_response=False):
return await self.client.write_value(self.handle, self.encode_value(value), with_response) return await self.client.write_value(
self.handle, self.encode_value(value), with_response
)
def encode_value(self, value): def encode_value(self, value):
return value return value
@@ -82,7 +88,11 @@ class ServiceProxy(AttributeProxy):
return cls(service) if service else None return cls(service) if service else None
def __init__(self, client, handle, end_group_handle, uuid, primary=True): def __init__(self, client, handle, end_group_handle, uuid, primary=True):
attribute_type = GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE if primary else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE attribute_type = (
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE
)
super().__init__(client, handle, end_group_handle, attribute_type) super().__init__(client, handle, end_group_handle, attribute_type)
self.uuid = uuid self.uuid = uuid
self.characteristics = [] self.characteristics = []
@@ -114,7 +124,7 @@ class CharacteristicProxy(AttributeProxy):
async def discover_descriptors(self): async def discover_descriptors(self):
return await self.client.discover_descriptors(self) return await self.client.discover_descriptors(self)
async def subscribe(self, subscriber=None): async def subscribe(self, subscriber=None, prefer_notify=True):
if subscriber is not None: if subscriber is not None:
if subscriber in self.subscribers: if subscriber in self.subscribers:
# We already have a proxy subscriber # We already have a proxy subscriber
@@ -125,10 +135,11 @@ class CharacteristicProxy(AttributeProxy):
def on_change(value): def on_change(value):
original_subscriber(self.decode_value(value)) original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change self.subscribers[subscriber] = on_change
subscriber = on_change subscriber = on_change
return await self.client.subscribe(self, subscriber) return await self.client.subscribe(self, subscriber, prefer_notify)
async def unsubscribe(self, subscriber=None): async def unsubscribe(self, subscriber=None):
if subscriber in self.subscribers: if subscriber in self.subscribers:
@@ -152,6 +163,7 @@ class ProfileServiceProxy:
''' '''
Base class for profile-specific service proxies Base class for profile-specific service proxies
''' '''
@classmethod @classmethod
def from_client(cls, client): def from_client(cls, client):
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID) return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -167,7 +179,9 @@ class Client:
self.request_semaphore = asyncio.Semaphore(1) self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None self.pending_request = None
self.pending_response = None self.pending_response = None
self.notification_subscribers = {} # Notification subscribers, by attribute handle self.notification_subscribers = (
{}
) # Notification subscribers, by attribute handle
self.indication_subscribers = {} # Indication subscribers, by attribute handle self.indication_subscribers = {} # Indication subscribers, by attribute handle
self.services = [] self.services = []
@@ -175,17 +189,21 @@ class Client:
self.connection.send_l2cap_pdu(ATT_CID, pdu) self.connection.send_l2cap_pdu(ATT_CID, pdu)
async def send_command(self, command): async def send_command(self, command):
logger.debug(f'GATT Command from client: [0x{self.connection.handle:04X}] {command}') logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
self.send_gatt_pdu(command.to_bytes()) self.send_gatt_pdu(command.to_bytes())
async def send_request(self, request): async def send_request(self, request):
logger.debug(f'GATT Request from client: [0x{self.connection.handle:04X}] {request}') logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
# Wait until we can send (only one pending command at a time for the connection) # Wait until we can send (only one pending command at a time for the connection)
response = None response = None
async with self.request_semaphore: async with self.request_semaphore:
assert(self.pending_request is None) assert self.pending_request is None
assert(self.pending_response is None) assert self.pending_response is None
# Create a future value to hold the eventual response # Create a future value to hold the eventual response
self.pending_response = asyncio.get_running_loop().create_future() self.pending_response = asyncio.get_running_loop().create_future()
@@ -193,7 +211,9 @@ class Client:
try: try:
self.send_gatt_pdu(request.to_bytes()) self.send_gatt_pdu(request.to_bytes())
response = await asyncio.wait_for(self.pending_response, GATT_REQUEST_TIMEOUT) response = await asyncio.wait_for(
self.pending_response, GATT_REQUEST_TIMEOUT
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning(color('!!! GATT Request timeout', 'red')) logger.warning(color('!!! GATT Request timeout', 'red'))
raise TimeoutError(f'GATT timeout for {request.name}') raise TimeoutError(f'GATT timeout for {request.name}')
@@ -204,7 +224,9 @@ class Client:
return response return response
def send_confirmation(self, confirmation): def send_confirmation(self, confirmation):
logger.debug(f'GATT Confirmation from client: [0x{self.connection.handle:04X}] {confirmation}') logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] {confirmation}'
)
self.send_gatt_pdu(confirmation.to_bytes()) self.send_gatt_pdu(confirmation.to_bytes())
async def request_mtu(self, mtu): async def request_mtu(self, mtu):
@@ -226,7 +248,7 @@ class Client:
response.error_code, response.error_code,
'att', 'att',
ATT_PDU.error_name(response.error_code), ATT_PDU.error_name(response.error_code),
response response,
) )
# Compute the final MTU # Compute the final MTU
@@ -239,7 +261,11 @@ class Client:
def get_characteristics_by_uuid(self, uuid, service=None): def get_characteristics_by_uuid(self, uuid, service=None):
services = [service] if service else self.services services = [service] if service else self.services
return [c for c in [c for s in services for c in s.characteristics] if c.uuid == uuid] return [
c
for c in [c for s in services for c in s.characteristics]
if c.uuid == uuid
]
def on_service_discovered(self, service): def on_service_discovered(self, service):
'''Add a service to the service list if it wasn't already there''' '''Add a service to the service list if it wasn't already there'''
@@ -262,7 +288,7 @@ class Client:
ATT_Read_By_Group_Type_Request( ATT_Read_By_Group_Type_Request(
starting_handle=starting_handle, starting_handle=starting_handle,
ending_handle=0xFFFF, ending_handle=0xFFFF,
attribute_group_type = GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE attribute_group_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
) )
) )
if response is None: if response is None:
@@ -273,15 +299,26 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}') logger.warning(
f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception # TODO raise appropriate exception
return return
break break
for attribute_handle, end_group_handle, attribute_value in response.attributes: for (
if attribute_handle < starting_handle or end_group_handle < attribute_handle: attribute_handle,
end_group_handle,
attribute_value,
) in response.attributes:
if (
attribute_handle < starting_handle
or end_group_handle < attribute_handle
):
# Something's not right # Something's not right
logger.warning(f'bogus handle values: {attribute_handle} {end_group_handle}') logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
return return
# Create a service proxy for this service # Create a service proxy for this service
@@ -290,7 +327,7 @@ class Client:
attribute_handle, attribute_handle,
end_group_handle, end_group_handle,
UUID.from_bytes(attribute_value), UUID.from_bytes(attribute_value),
True True,
) )
# Filter out returned services based on the given uuids list # Filter out returned services based on the given uuids list
@@ -326,7 +363,7 @@ class Client:
starting_handle=starting_handle, starting_handle=starting_handle,
ending_handle=0xFFFF, ending_handle=0xFFFF,
attribute_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, attribute_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
attribute_value = uuid.to_pdu_bytes() attribute_value=uuid.to_pdu_bytes(),
) )
) )
if response is None: if response is None:
@@ -337,19 +374,28 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}') logger.warning(
f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception # TODO raise appropriate exception
return return
break break
for attribute_handle, end_group_handle in response.handles_information: for attribute_handle, end_group_handle in response.handles_information:
if attribute_handle < starting_handle or end_group_handle < attribute_handle: if (
attribute_handle < starting_handle
or end_group_handle < attribute_handle
):
# Something's not right # Something's not right
logger.warning(f'bogus handle values: {attribute_handle} {end_group_handle}') logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
return return
# Create a service proxy for this service # Create a service proxy for this service
service = ServiceProxy(self, attribute_handle, end_group_handle, uuid, True) service = ServiceProxy(
self, attribute_handle, end_group_handle, uuid, True
)
# Add the service to the peer's service list # Add the service to the peer's service list
services.append(service) services.append(service)
@@ -398,7 +444,7 @@ class Client:
ATT_Read_By_Type_Request( ATT_Read_By_Type_Request(
starting_handle=starting_handle, starting_handle=starting_handle,
ending_handle=ending_handle, ending_handle=ending_handle,
attribute_type = GATT_CHARACTERISTIC_ATTRIBUTE_TYPE attribute_type=GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
) )
) )
if response is None: if response is None:
@@ -409,7 +455,9 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning(f'!!! unexpected error while discovering characteristics: {HCI_Constant.error_name(response.error_code)}') logger.warning(
f'!!! unexpected error while discovering characteristics: {HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception # TODO raise appropriate exception
return return
break break
@@ -427,7 +475,9 @@ class Client:
properties, handle = struct.unpack_from('<BH', attribute_value) properties, handle = struct.unpack_from('<BH', attribute_value)
characteristic_uuid = UUID.from_bytes(attribute_value[3:]) characteristic_uuid = UUID.from_bytes(attribute_value[3:])
characteristic = CharacteristicProxy(self, handle, 0, characteristic_uuid, properties) characteristic = CharacteristicProxy(
self, handle, 0, characteristic_uuid, properties
)
# Set the previous characteristic's end handle # Set the previous characteristic's end handle
if characteristics: if characteristics:
@@ -443,13 +493,17 @@ class Client:
characteristics[-1].end_group_handle = service.end_group_handle characteristics[-1].end_group_handle = service.end_group_handle
# Set the service's characteristics # Set the service's characteristics
characteristics = [c for c in characteristics if not uuids or c.uuid in uuids] characteristics = [
c for c in characteristics if not uuids or c.uuid in uuids
]
service.characteristics = characteristics service.characteristics = characteristics
discovered_characteristics.extend(characteristics) discovered_characteristics.extend(characteristics)
return discovered_characteristics return discovered_characteristics
async def discover_descriptors(self, characteristic = None, start_handle = None, end_handle = None): async def discover_descriptors(
self, characteristic=None, start_handle=None, end_handle=None
):
''' '''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
''' '''
@@ -466,8 +520,7 @@ class Client:
while starting_handle <= ending_handle: while starting_handle <= ending_handle:
response = await self.send_request( response = await self.send_request(
ATT_Find_Information_Request( ATT_Find_Information_Request(
starting_handle = starting_handle, starting_handle=starting_handle, ending_handle=ending_handle
ending_handle = ending_handle
) )
) )
if response is None: if response is None:
@@ -478,7 +531,9 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning(f'!!! unexpected error while discovering descriptors: {HCI_Constant.error_name(response.error_code)}') logger.warning(
f'!!! unexpected error while discovering descriptors: {HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception # TODO raise appropriate exception
return [] return []
break break
@@ -494,7 +549,9 @@ class Client:
logger.warning(f'bogus handle value: {attribute_handle}') logger.warning(f'bogus handle value: {attribute_handle}')
return [] return []
descriptor = DescriptorProxy(self, attribute_handle, UUID.from_bytes(attribute_uuid)) descriptor = DescriptorProxy(
self, attribute_handle, UUID.from_bytes(attribute_uuid)
)
descriptors.append(descriptor) descriptors.append(descriptor)
# TODO: read descriptor value # TODO: read descriptor value
@@ -517,8 +574,7 @@ class Client:
while True: while True:
response = await self.send_request( response = await self.send_request(
ATT_Find_Information_Request( ATT_Find_Information_Request(
starting_handle = starting_handle, starting_handle=starting_handle, ending_handle=ending_handle
ending_handle = ending_handle
) )
) )
if response is None: if response is None:
@@ -528,7 +584,9 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning(f'!!! unexpected error while discovering attributes: {HCI_Constant.error_name(response.error_code)}') logger.warning(
f'!!! unexpected error while discovering attributes: {HCI_Constant.error_name(response.error_code)}'
)
return [] return []
break break
@@ -538,7 +596,9 @@ class Client:
logger.warning(f'bogus handle value: {attribute_handle}') logger.warning(f'bogus handle value: {attribute_handle}')
return [] return []
attribute = AttributeProxy(self, attribute_handle, 0, UUID.from_bytes(attribute_uuid)) attribute = AttributeProxy(
self, attribute_handle, 0, UUID.from_bytes(attribute_uuid)
)
attributes.append(attribute) attributes.append(attribute)
# Move on to the next attributes # Move on to the next attributes
@@ -546,29 +606,40 @@ class Client:
return attributes return attributes
async def subscribe(self, characteristic, subscriber=None): async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
# If we haven't already discovered the descriptors for this characteristic, do it now # If we haven't already discovered the descriptors for this characteristic, do it now
if not characteristic.descriptors_discovered: if not characteristic.descriptors_discovered:
await self.discover_descriptors(characteristic) await self.discover_descriptors(characteristic)
# Look for the CCCD descriptor # Look for the CCCD descriptor
cccd = characteristic.get_descriptor(GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR) cccd = characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
if not cccd: if not cccd:
logger.warning('subscribing to characteristic with no CCCD descriptor') logger.warning('subscribing to characteristic with no CCCD descriptor')
return return
# Set the subscription bits and select the subscriber set if (
bits = 0 characteristic.properties & Characteristic.NOTIFY
subscriber_sets = [] and characteristic.properties & Characteristic.INDICATE
if characteristic.properties & Characteristic.NOTIFY: ):
bits |= 0x0001 if prefer_notify:
subscriber_sets.append(self.notification_subscribers.setdefault(characteristic.handle, set())) bits = ClientCharacteristicConfigurationBits.NOTIFICATION
if characteristic.properties & Characteristic.INDICATE: subscribers = self.notification_subscribers
bits |= 0x0002 else:
subscriber_sets.append(self.indication_subscribers.setdefault(characteristic.handle, set())) bits = ClientCharacteristicConfigurationBits.INDICATION
subscribers = self.indication_subscribers
elif characteristic.properties & Characteristic.NOTIFY:
bits = ClientCharacteristicConfigurationBits.NOTIFICATION
subscribers = self.notification_subscribers
elif characteristic.properties & Characteristic.INDICATE:
bits = ClientCharacteristicConfigurationBits.INDICATION
subscribers = self.indication_subscribers
else:
raise InvalidStateError("characteristic is not notify or indicate")
# Add subscribers to the sets # Add subscribers to the sets
for subscriber_set in subscriber_sets: subscriber_set = subscribers.setdefault(characteristic.handle, set())
if subscriber is not None: if subscriber is not None:
subscriber_set.add(subscriber) subscriber_set.add(subscriber)
# Add the characteristic as a subscriber, which will result in the characteristic # Add the characteristic as a subscriber, which will result in the characteristic
@@ -583,14 +654,19 @@ class Client:
await self.discover_descriptors(characteristic) await self.discover_descriptors(characteristic)
# Look for the CCCD descriptor # Look for the CCCD descriptor
cccd = characteristic.get_descriptor(GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR) cccd = characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
if not cccd: if not cccd:
logger.warning('unsubscribing from characteristic with no CCCD descriptor') logger.warning('unsubscribing from characteristic with no CCCD descriptor')
return return
if subscriber is not None: if subscriber is not None:
# Remove matching subscriber from subscriber sets # Remove matching subscriber from subscriber sets
for subscriber_set in (self.notification_subscribers, self.indication_subscribers): for subscriber_set in (
self.notification_subscribers,
self.indication_subscribers,
):
subscribers = subscriber_set.get(characteristic.handle, []) subscribers = subscriber_set.get(characteristic.handle, [])
if subscriber in subscribers: if subscriber in subscribers:
subscribers.remove(subscriber) subscribers.remove(subscriber)
@@ -616,7 +692,9 @@ class Client:
# Send a request to read # Send a request to read
attribute_handle = attribute if type(attribute) is int else attribute.handle attribute_handle = attribute if type(attribute) is int else attribute.handle
response = await self.send_request(ATT_Read_Request(attribute_handle = attribute_handle)) response = await self.send_request(
ATT_Read_Request(attribute_handle=attribute_handle)
)
if response is None: if response is None:
raise TimeoutError('read timeout') raise TimeoutError('read timeout')
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
@@ -624,7 +702,7 @@ class Client:
response.error_code, response.error_code,
'att', 'att',
ATT_PDU.error_name(response.error_code), ATT_PDU.error_name(response.error_code),
response response,
) )
# If the value is the max size for the MTU, try to read more unless the caller # If the value is the max size for the MTU, try to read more unless the caller
@@ -635,18 +713,23 @@ class Client:
offset = len(attribute_value) offset = len(attribute_value)
while True: while True:
response = await self.send_request( response = await self.send_request(
ATT_Read_Blob_Request(attribute_handle = attribute_handle, value_offset = offset) ATT_Read_Blob_Request(
attribute_handle=attribute_handle, value_offset=offset
)
) )
if response is None: if response is None:
raise TimeoutError('read timeout') raise TimeoutError('read timeout')
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code == ATT_ATTRIBUTE_NOT_LONG_ERROR or response.error_code == ATT_INVALID_OFFSET_ERROR: if (
response.error_code == ATT_ATTRIBUTE_NOT_LONG_ERROR
or response.error_code == ATT_INVALID_OFFSET_ERROR
):
break break
raise ProtocolError( raise ProtocolError(
response.error_code, response.error_code,
'att', 'att',
ATT_PDU.error_name(response.error_code), ATT_PDU.error_name(response.error_code),
response response,
) )
part = response.part_attribute_value part = response.part_attribute_value
@@ -678,7 +761,7 @@ class Client:
ATT_Read_By_Type_Request( ATT_Read_By_Type_Request(
starting_handle=starting_handle, starting_handle=starting_handle,
ending_handle=ending_handle, ending_handle=ending_handle,
attribute_type = uuid attribute_type=uuid,
) )
) )
if response is None: if response is None:
@@ -689,7 +772,9 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning(f'!!! unexpected error while reading characteristics: {HCI_Constant.error_name(response.error_code)}') logger.warning(
f'!!! unexpected error while reading characteristics: {HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception # TODO raise appropriate exception
return [] return []
break break
@@ -724,26 +809,27 @@ class Client:
if with_response: if with_response:
response = await self.send_request( response = await self.send_request(
ATT_Write_Request( ATT_Write_Request(
attribute_handle = attribute_handle, attribute_handle=attribute_handle, attribute_value=value
attribute_value = value
) )
) )
if response.op_code == ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
raise ProtocolError( raise ProtocolError(
response.error_code, response.error_code,
'att', 'att',
ATT_PDU.error_name(response.error_code), response ATT_PDU.error_name(response.error_code),
response,
) )
else: else:
await self.send_command( await self.send_command(
ATT_Write_Command( ATT_Write_Command(
attribute_handle = attribute_handle, attribute_handle=attribute_handle, attribute_value=value
attribute_value = value
) )
) )
def on_gatt_pdu(self, att_pdu): def on_gatt_pdu(self, att_pdu):
logger.debug(f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}') logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
if att_pdu.op_code in ATT_RESPONSES: if att_pdu.op_code in ATT_RESPONSES:
if self.pending_request is None: if self.pending_request is None:
# Not expected! # Not expected!
@@ -752,9 +838,13 @@ class Client:
# Sanity check: the response should match the pending request unless it is an error response # Sanity check: the response should match the pending request unless it is an error response
if att_pdu.op_code != ATT_ERROR_RESPONSE: if att_pdu.op_code != ATT_ERROR_RESPONSE:
expected_response_name = self.pending_request.name.replace('_REQUEST', '_RESPONSE') expected_response_name = self.pending_request.name.replace(
'_REQUEST', '_RESPONSE'
)
if att_pdu.name != expected_response_name: if att_pdu.name != expected_response_name:
logger.warning(f'!!! mismatched response: expected {expected_response_name}') logger.warning(
f'!!! mismatched response: expected {expected_response_name}'
)
return return
# Return the response to the coroutine that is waiting for it # Return the response to the coroutine that is waiting for it
@@ -765,11 +855,15 @@ class Client:
if handler is not None: if handler is not None:
handler(att_pdu) handler(att_pdu)
else: else:
logger.warning(f'{color(f"--- Ignoring GATT Response from [0x{self.connection.handle:04X}]:", "red")} {att_pdu}') logger.warning(
f'{color(f"--- Ignoring GATT Response from [0x{self.connection.handle:04X}]:", "red")} {att_pdu}'
)
def on_att_handle_value_notification(self, notification): def on_att_handle_value_notification(self, notification):
# Call all subscribers # Call all subscribers
subscribers = self.notification_subscribers.get(notification.attribute_handle, []) subscribers = self.notification_subscribers.get(
notification.attribute_handle, []
)
if not subscribers: if not subscribers:
logger.warning('!!! received notification with no subscriber') logger.warning('!!! received notification with no subscriber')
for subscriber in subscribers: for subscriber in subscribers:
+243 -101
View File
@@ -26,6 +26,7 @@
import asyncio import asyncio
import logging import logging
from collections import defaultdict from collections import defaultdict
from typing import Tuple, Optional
from pyee import EventEmitter from pyee import EventEmitter
from colors import color from colors import color
@@ -55,17 +56,32 @@ class Server(EventEmitter):
self.device = device self.device = device
self.attributes = [] # Attributes, ordered by increasing handle values self.attributes = [] # Attributes, ordered by increasing handle values
self.attributes_by_handle = {} # Map for fast attribute access by handle self.attributes_by_handle = {} # Map for fast attribute access by handle
self.max_mtu = GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate self.max_mtu = (
self.subscribers = {} # Map of subscriber states by connection handle and attribute handle GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
)
self.subscribers = (
{}
) # Map of subscriber states by connection handle and attribute handle
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1)) self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None) self.pending_confirmations = defaultdict(lambda: None)
def __str__(self):
return "\n".join(map(str, self.attributes))
def send_gatt_pdu(self, connection_handle, pdu): def send_gatt_pdu(self, connection_handle, pdu):
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu) self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
def next_handle(self): def next_handle(self):
return 1 + len(self.attributes) return 1 + len(self.attributes)
def get_advertising_service_data(self):
return {
attribute: data
for attribute in self.attributes
if isinstance(attribute, Service)
and (data := attribute.get_advertising_data())
}
def get_attribute(self, handle): def get_attribute(self, handle):
attribute = self.attributes_by_handle.get(handle) attribute = self.attributes_by_handle.get(handle)
if attribute: if attribute:
@@ -79,15 +95,74 @@ class Server(EventEmitter):
return attribute return attribute
return None return None
def get_service_attribute(self, service_uuid: UUID) -> Optional[Service]:
return next(
(
attribute
for attribute in self.attributes
if attribute.type == GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
and attribute.uuid == service_uuid
),
None,
)
def get_characteristic_attributes(
self, service_uuid: UUID, characteristic_uuid: UUID
) -> Optional[Tuple[CharacteristicDeclaration, Characteristic]]:
service_handle = self.get_service_attribute(service_uuid)
if not service_handle:
return None
return next(
(
(attribute, self.get_attribute(attribute.characteristic.handle))
for attribute in map(
self.get_attribute,
range(service_handle.handle, service_handle.end_group_handle + 1),
)
if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
and attribute.characteristic.uuid == characteristic_uuid
),
None,
)
def get_descriptor_attribute(
self, service_uuid: UUID, characteristic_uuid: UUID, descriptor_uuid: UUID
) -> Optional[Descriptor]:
characteristics = self.get_characteristic_attributes(
service_uuid, characteristic_uuid
)
if not characteristics:
return None
(_, characteristic_value) = characteristics
return next(
(
attribute
for attribute in map(
self.get_attribute,
range(
characteristic_value.handle + 1,
characteristic_value.end_group_handle + 1,
),
)
if attribute.type == descriptor_uuid
),
None,
)
def add_attribute(self, attribute): def add_attribute(self, attribute):
# Assign a handle to this attribute # Assign a handle to this attribute
attribute.handle = self.next_handle() attribute.handle = self.next_handle()
attribute.end_group_handle = attribute.handle # TODO: keep track of descriptors in the group attribute.end_group_handle = (
attribute.handle
) # TODO: keep track of descriptors in the group
# Add this attribute to the list # Add this attribute to the list
self.attributes.append(attribute) self.attributes.append(attribute)
def add_service(self, service): def add_service(self, service: Service):
# Add the service attribute to the DB # Add the service attribute to the DB
self.add_attribute(service) self.add_attribute(service)
@@ -95,16 +170,9 @@ class Server(EventEmitter):
# Add all characteristics # Add all characteristics
for characteristic in service.characteristics: for characteristic in service.characteristics:
# Add a Characteristic Declaration (Vol 3, Part G - 3.3.1 Characteristic Declaration) # Add a Characteristic Declaration
declaration_bytes = struct.pack( characteristic_declaration = CharacteristicDeclaration(
'<BH', characteristic, self.next_handle() + 1
characteristic.properties,
self.next_handle() + 1, # The value will be the next attribute after this declaration
) + characteristic.uuid.to_pdu_bytes()
characteristic_declaration = Attribute(
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
Attribute.READABLE,
declaration_bytes
) )
self.add_attribute(characteristic_declaration) self.add_attribute(characteristic_declaration)
@@ -118,17 +186,25 @@ class Server(EventEmitter):
# If the characteristic supports subscriptions, add a CCCD descriptor # If the characteristic supports subscriptions, add a CCCD descriptor
# unless there is one already # unless there is one already
if ( if (
characteristic.properties & (Characteristic.NOTIFY | Characteristic.INDICATE) and characteristic.properties
characteristic.get_descriptor(GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR) is None & (Characteristic.NOTIFY | Characteristic.INDICATE)
and characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
is None
): ):
self.add_attribute( self.add_attribute(
Descriptor( Descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
Attribute.READABLE | Attribute.WRITEABLE, Attribute.READABLE | Attribute.WRITEABLE,
CharacteristicValue( CharacteristicValue(
read=lambda connection, characteristic=characteristic: self.read_cccd(connection, characteristic), read=lambda connection, characteristic=characteristic: self.read_cccd(
write=lambda connection, value, characteristic=characteristic: self.write_cccd(connection, characteristic, value) connection, characteristic
) ),
write=lambda connection, value, characteristic=characteristic: self.write_cccd(
connection, characteristic, value
),
),
) )
) )
@@ -155,7 +231,9 @@ class Server(EventEmitter):
return cccd or bytes([0, 0]) return cccd or bytes([0, 0])
def write_cccd(self, connection, characteristic, value): def write_cccd(self, connection, characteristic, value):
logger.debug(f'Subscription update for connection={connection.handle:04X}, handle={characteristic.handle:04X}: {value.hex()}') logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, handle=0x{characteristic.handle:04X}: {value.hex()}'
)
# Sanity check # Sanity check
if len(value) != 2: if len(value) != 2:
@@ -165,13 +243,23 @@ class Server(EventEmitter):
cccds = self.subscribers.setdefault(connection.handle, {}) cccds = self.subscribers.setdefault(connection.handle, {})
cccds[characteristic.handle] = value cccds[characteristic.handle] = value
logger.debug(f'CCCDs: {cccds}') logger.debug(f'CCCDs: {cccds}')
notify_enabled = (value[0] & 0x01 != 0) notify_enabled = value[0] & 0x01 != 0
indicate_enabled = (value[0] & 0x02 != 0) indicate_enabled = value[0] & 0x02 != 0
characteristic.emit('subscription', connection, notify_enabled, indicate_enabled) characteristic.emit(
self.emit('characteristic_subscription', connection, characteristic, notify_enabled, indicate_enabled) 'subscription', connection, notify_enabled, indicate_enabled
)
self.emit(
'characteristic_subscription',
connection,
characteristic,
notify_enabled,
indicate_enabled,
)
def send_response(self, connection, response): def send_response(self, connection, response):
logger.debug(f'GATT Response from server: [0x{connection.handle:04X}] {response}') logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
)
self.send_gatt_pdu(connection.handle, response.to_bytes()) self.send_gatt_pdu(connection.handle, response.to_bytes())
async def notify_subscriber(self, connection, attribute, value=None, force=False): async def notify_subscriber(self, connection, attribute, value=None, force=False):
@@ -183,14 +271,20 @@ class Server(EventEmitter):
return return
cccd = subscribers.get(attribute.handle) cccd = subscribers.get(attribute.handle)
if not cccd: if not cccd:
logger.debug(f'not notifying, no subscribers for handle {attribute.handle:04X}') logger.debug(
f'not notifying, no subscribers for handle {attribute.handle:04X}'
)
return return
if len(cccd) != 2 or (cccd[0] & 0x01 == 0): if len(cccd) != 2 or (cccd[0] & 0x01 == 0):
logger.debug(f'not notifying, cccd={cccd.hex()}') logger.debug(f'not notifying, cccd={cccd.hex()}')
return return
# Get or encode the value # Get or encode the value
value = attribute.read_value(connection) if value is None else attribute.encode_value(value) value = (
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed # Truncate if needed
if len(value) > connection.att_mtu - 3: if len(value) > connection.att_mtu - 3:
@@ -198,10 +292,11 @@ class Server(EventEmitter):
# Notify # Notify
notification = ATT_Handle_Value_Notification( notification = ATT_Handle_Value_Notification(
attribute_handle = attribute.handle, attribute_handle=attribute.handle, attribute_value=value
attribute_value = value )
logger.debug(
f'GATT Notify from server: [0x{connection.handle:04X}] {notification}'
) )
logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}')
self.send_gatt_pdu(connection.handle, bytes(notification)) self.send_gatt_pdu(connection.handle, bytes(notification))
async def indicate_subscriber(self, connection, attribute, value=None, force=False): async def indicate_subscriber(self, connection, attribute, value=None, force=False):
@@ -213,14 +308,20 @@ class Server(EventEmitter):
return return
cccd = subscribers.get(attribute.handle) cccd = subscribers.get(attribute.handle)
if not cccd: if not cccd:
logger.debug(f'not indicating, no subscribers for handle {attribute.handle:04X}') logger.debug(
f'not indicating, no subscribers for handle {attribute.handle:04X}'
)
return return
if len(cccd) != 2 or (cccd[0] & 0x02 == 0): if len(cccd) != 2 or (cccd[0] & 0x02 == 0):
logger.debug(f'not indicating, cccd={cccd.hex()}') logger.debug(f'not indicating, cccd={cccd.hex()}')
return return
# Get or encode the value # Get or encode the value
value = attribute.read_value(connection) if value is None else attribute.encode_value(value) value = (
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed # Truncate if needed
if len(value) > connection.att_mtu - 3: if len(value) > connection.att_mtu - 3:
@@ -228,31 +329,39 @@ class Server(EventEmitter):
# Indicate # Indicate
indication = ATT_Handle_Value_Indication( indication = ATT_Handle_Value_Indication(
attribute_handle = attribute.handle, attribute_handle=attribute.handle, attribute_value=value
attribute_value = value )
logger.debug(
f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}'
) )
logger.debug(f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}')
# Wait until we can send (only one pending indication at a time per connection) # Wait until we can send (only one pending indication at a time per connection)
async with self.indication_semaphores[connection.handle]: async with self.indication_semaphores[connection.handle]:
assert(self.pending_confirmations[connection.handle] is None) assert self.pending_confirmations[connection.handle] is None
# Create a future value to hold the eventual response # Create a future value to hold the eventual response
self.pending_confirmations[connection.handle] = asyncio.get_running_loop().create_future() self.pending_confirmations[
connection.handle
] = asyncio.get_running_loop().create_future()
try: try:
self.send_gatt_pdu(connection.handle, indication.to_bytes()) self.send_gatt_pdu(connection.handle, indication.to_bytes())
await asyncio.wait_for(self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT) await asyncio.wait_for(
self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning(color('!!! GATT Indicate timeout', 'red')) logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}') raise TimeoutError(f'GATT timeout for {indication.name}')
finally: finally:
self.pending_confirmations[connection.handle] = None self.pending_confirmations[connection.handle] = None
async def notify_or_indicate_subscribers(self, indicate, attribute, value=None, force=False): async def notify_or_indicate_subscribers(
self, indicate, attribute, value=None, force=False
):
# Get all the connections for which there's at least one subscription # Get all the connections for which there's at least one subscription
connections = [ connections = [
connection for connection in [ connection
for connection in [
self.device.lookup_connection(connection_handle) self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items() for (connection_handle, subscribers) in self.subscribers.items()
if force or subscribers.get(attribute.handle) if force or subscribers.get(attribute.handle)
@@ -263,10 +372,12 @@ class Server(EventEmitter):
# Indicate or notify for each connection # Indicate or notify for each connection
if connections: if connections:
coroutine = self.indicate_subscriber if indicate else self.notify_subscriber coroutine = self.indicate_subscriber if indicate else self.notify_subscriber
await asyncio.wait([ await asyncio.wait(
[
asyncio.create_task(coroutine(connection, attribute, value, force)) asyncio.create_task(coroutine(connection, attribute, value, force))
for connection in connections for connection in connections
]) ]
)
async def notify_subscribers(self, attribute, value=None, force=False): async def notify_subscribers(self, attribute, value=None, force=False):
return await self.notify_or_indicate_subscribers(False, attribute, value, force) return await self.notify_or_indicate_subscribers(False, attribute, value, force)
@@ -294,7 +405,7 @@ class Server(EventEmitter):
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=att_pdu.op_code, request_opcode_in_error=att_pdu.op_code,
attribute_handle_in_error=error.att_handle, attribute_handle_in_error=error.att_handle,
error_code = error.error_code error_code=error.error_code,
) )
self.send_response(connection, response) self.send_response(connection, response)
except Exception as error: except Exception as error:
@@ -302,7 +413,7 @@ class Server(EventEmitter):
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=att_pdu.op_code, request_opcode_in_error=att_pdu.op_code,
attribute_handle_in_error=0x0000, attribute_handle_in_error=0x0000,
error_code = ATT_UNLIKELY_ERROR_ERROR error_code=ATT_UNLIKELY_ERROR_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
raise error raise error
@@ -313,7 +424,9 @@ class Server(EventEmitter):
self.on_att_request(connection, att_pdu) self.on_att_request(connection, att_pdu)
else: else:
# Just ignore # Just ignore
logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}') logger.warning(
f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}'
)
####################################################### #######################################################
# ATT handlers # ATT handlers
@@ -322,11 +435,13 @@ class Server(EventEmitter):
''' '''
Handler for requests without a more specific handler Handler for requests without a more specific handler
''' '''
logger.warning(f'{color(f"--- Unsupported ATT Request from [0x{connection.handle:04X}]:", "red")} {pdu}') logger.warning(
f'{color(f"--- Unsupported ATT Request from [0x{connection.handle:04X}]:", "red")} {pdu}'
)
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=pdu.op_code, request_opcode_in_error=pdu.op_code,
attribute_handle_in_error=0x0000, attribute_handle_in_error=0x0000,
error_code = ATT_REQUEST_NOT_SUPPORTED_ERROR error_code=ATT_REQUEST_NOT_SUPPORTED_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
@@ -334,7 +449,9 @@ class Server(EventEmitter):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
''' '''
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = self.max_mtu)) self.send_response(
connection, ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
)
# Compute the final MTU # Compute the final MTU
if request.client_rx_mtu >= ATT_DEFAULT_MTU: if request.client_rx_mtu >= ATT_DEFAULT_MTU:
@@ -351,12 +468,18 @@ class Server(EventEmitter):
''' '''
# Check the request parameters # Check the request parameters
if request.starting_handle == 0 or request.starting_handle > request.ending_handle: if (
self.send_response(connection, ATT_Error_Response( request.starting_handle == 0
or request.starting_handle > request.ending_handle
):
self.send_response(
connection,
ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle, attribute_handle_in_error=request.starting_handle,
error_code = ATT_INVALID_HANDLE_ERROR error_code=ATT_INVALID_HANDLE_ERROR,
)) ),
)
return return
# Build list of returned attributes # Build list of returned attributes
@@ -364,9 +487,10 @@ class Server(EventEmitter):
attributes = [] attributes = []
uuid_size = 0 uuid_size = 0
for attribute in ( for attribute in (
attribute for attribute in self.attributes if attribute
attribute.handle >= request.starting_handle and for attribute in self.attributes
attribute.handle <= request.ending_handle if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
): ):
# TODO: check permissions # TODO: check permissions
@@ -394,13 +518,13 @@ class Server(EventEmitter):
] ]
response = ATT_Find_Information_Response( response = ATT_Find_Information_Response(
format=1 if len(attributes[0].type.to_pdu_bytes()) == 2 else 2, format=1 if len(attributes[0].type.to_pdu_bytes()) == 2 else 2,
information_data = b''.join(information_data_list) information_data=b''.join(information_data_list),
) )
else: else:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle, attribute_handle_in_error=request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
@@ -414,12 +538,13 @@ class Server(EventEmitter):
pdu_space_available = connection.att_mtu - 2 pdu_space_available = connection.att_mtu - 2
attributes = [] attributes = []
for attribute in ( for attribute in (
attribute for attribute in self.attributes if attribute
attribute.handle >= request.starting_handle and for attribute in self.attributes
attribute.handle <= request.ending_handle and if attribute.handle >= request.starting_handle
attribute.type == request.attribute_type and and attribute.handle <= request.ending_handle
attribute.read_value(connection) == request.attribute_value and and attribute.type == request.attribute_type
pdu_space_available >= 4 and attribute.read_value(connection) == request.attribute_value
and pdu_space_available >= 4
): ):
# TODO: check permissions # TODO: check permissions
@@ -434,14 +559,16 @@ class Server(EventEmitter):
if attribute.type in { if attribute.type in {
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
}: }:
# Part of a group # Part of a group
group_end_handle = attribute.end_group_handle group_end_handle = attribute.end_group_handle
else: else:
# Not part of a group # Not part of a group
group_end_handle = attribute.handle group_end_handle = attribute.handle
handles_information_list.append(struct.pack('<HH', attribute.handle, group_end_handle)) handles_information_list.append(
struct.pack('<HH', attribute.handle, group_end_handle)
)
response = ATT_Find_By_Type_Value_Response( response = ATT_Find_By_Type_Value_Response(
handles_information_list=b''.join(handles_information_list) handles_information_list=b''.join(handles_information_list)
) )
@@ -449,7 +576,7 @@ class Server(EventEmitter):
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle, attribute_handle_in_error=request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
@@ -462,11 +589,12 @@ class Server(EventEmitter):
pdu_space_available = connection.att_mtu - 2 pdu_space_available = connection.att_mtu - 2
attributes = [] attributes = []
for attribute in ( for attribute in (
attribute for attribute in self.attributes if attribute
attribute.type == request.attribute_type and for attribute in self.attributes
attribute.handle >= request.starting_handle and if attribute.type == request.attribute_type
attribute.handle <= request.ending_handle and and attribute.handle >= request.starting_handle
pdu_space_available and attribute.handle <= request.ending_handle
and pdu_space_available
): ):
# TODO: check permissions # TODO: check permissions
@@ -490,16 +618,17 @@ class Server(EventEmitter):
pdu_space_available -= entry_size pdu_space_available -= entry_size
if attributes: if attributes:
attribute_data_list = [struct.pack('<H', handle) + value for handle, value in attributes] attribute_data_list = [
struct.pack('<H', handle) + value for handle, value in attributes
]
response = ATT_Read_By_Type_Response( response = ATT_Read_By_Type_Response(
length = entry_size, length=entry_size, attribute_data_list=b''.join(attribute_data_list)
attribute_data_list = b''.join(attribute_data_list)
) )
else: else:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle, attribute_handle_in_error=request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
@@ -513,14 +642,12 @@ class Server(EventEmitter):
# TODO: check permissions # TODO: check permissions
value = attribute.read_value(connection) value = attribute.read_value(connection)
value_size = min(connection.att_mtu - 1, len(value)) value_size = min(connection.att_mtu - 1, len(value))
response = ATT_Read_Response( response = ATT_Read_Response(attribute_value=value[:value_size])
attribute_value = value[:value_size]
)
else: else:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
error_code = ATT_INVALID_HANDLE_ERROR error_code=ATT_INVALID_HANDLE_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
@@ -536,24 +663,28 @@ class Server(EventEmitter):
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
error_code = ATT_INVALID_OFFSET_ERROR error_code=ATT_INVALID_OFFSET_ERROR,
) )
elif len(value) <= connection.att_mtu - 1: elif len(value) <= connection.att_mtu - 1:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR,
) )
else: else:
part_size = min(connection.att_mtu - 1, len(value) - request.value_offset) part_size = min(
connection.att_mtu - 1, len(value) - request.value_offset
)
response = ATT_Read_Blob_Response( response = ATT_Read_Blob_Response(
part_attribute_value = value[request.value_offset:request.value_offset + part_size] part_attribute_value=value[
request.value_offset : request.value_offset + part_size
]
) )
else: else:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
error_code = ATT_INVALID_HANDLE_ERROR error_code=ATT_INVALID_HANDLE_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
@@ -564,12 +695,12 @@ class Server(EventEmitter):
if request.attribute_group_type not in { if request.attribute_group_type not in {
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_INCLUDE_ATTRIBUTE_TYPE GATT_INCLUDE_ATTRIBUTE_TYPE,
}: }:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle, attribute_handle_in_error=request.starting_handle,
error_code = ATT_UNSUPPORTED_GROUP_TYPE_ERROR error_code=ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
return return
@@ -577,11 +708,12 @@ class Server(EventEmitter):
pdu_space_available = connection.att_mtu - 2 pdu_space_available = connection.att_mtu - 2
attributes = [] attributes = []
for attribute in ( for attribute in (
attribute for attribute in self.attributes if attribute
attribute.type == request.attribute_group_type and for attribute in self.attributes
attribute.handle >= request.starting_handle and if attribute.type == request.attribute_group_type
attribute.handle <= request.ending_handle and and attribute.handle >= request.starting_handle
pdu_space_available and attribute.handle <= request.ending_handle
and pdu_space_available
): ):
# Check the attribute value size # Check the attribute value size
attribute_value = attribute.read_value(connection) attribute_value = attribute.read_value(connection)
@@ -599,7 +731,9 @@ class Server(EventEmitter):
break break
# Add the attribute to the list # Add the attribute to the list
attributes.append((attribute.handle, attribute.end_group_handle, attribute_value)) attributes.append(
(attribute.handle, attribute.end_group_handle, attribute_value)
)
pdu_space_available -= entry_size pdu_space_available -= entry_size
if attributes: if attributes:
@@ -609,13 +743,13 @@ class Server(EventEmitter):
] ]
response = ATT_Read_By_Group_Type_Response( response = ATT_Read_By_Group_Type_Response(
length=len(attribute_data_list[0]), length=len(attribute_data_list[0]),
attribute_data_list = b''.join(attribute_data_list) attribute_data_list=b''.join(attribute_data_list),
) )
else: else:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle, attribute_handle_in_error=request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
) )
self.send_response(connection, response) self.send_response(connection, response)
@@ -628,22 +762,28 @@ class Server(EventEmitter):
# Check that the attribute exists # Check that the attribute exists
attribute = self.get_attribute(request.attribute_handle) attribute = self.get_attribute(request.attribute_handle)
if attribute is None: if attribute is None:
self.send_response(connection, ATT_Error_Response( self.send_response(
connection,
ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
error_code = ATT_INVALID_HANDLE_ERROR error_code=ATT_INVALID_HANDLE_ERROR,
)) ),
)
return return
# TODO: check permissions # TODO: check permissions
# Check the request parameters # Check the request parameters
if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE: if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE:
self.send_response(connection, ATT_Error_Response( self.send_response(
connection,
ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle, attribute_handle_in_error=request.attribute_handle,
error_code = ATT_INVALID_ATTRIBUTE_LENGTH_ERROR error_code=ATT_INVALID_ATTRIBUTE_LENGTH_ERROR,
)) ),
)
return return
# Accept the value # Accept the value
@@ -680,7 +820,9 @@ class Server(EventEmitter):
''' '''
if self.pending_confirmations[connection.handle] is None: if self.pending_confirmations[connection.handle] is None:
# Not expected! # Not expected!
logger.warning('!!! unexpected confirmation, there is no pending indication') logger.warning(
'!!! unexpected confirmation, there is no pending indication'
)
return return
self.pending_confirmations[connection.handle].set_result(None) self.pending_confirmations[connection.handle].set_result(None)
+1660 -596
View File
File diff suppressed because it is too large Load Diff
+41 -18
View File
@@ -29,20 +29,17 @@ from .l2cap import (
L2CAP_SIGNALING_CID, L2CAP_SIGNALING_CID,
L2CAP_LE_SIGNALING_CID, L2CAP_LE_SIGNALING_CID,
L2CAP_Control_Frame, L2CAP_Control_Frame,
L2CAP_Connection_Response L2CAP_Connection_Response,
) )
from .hci import ( from .hci import (
HCI_EVENT_PACKET, HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET, HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT, HCI_DISCONNECTION_COMPLETE_EVENT,
HCI_AclDataPacketAssembler HCI_AclDataPacketAssembler,
) )
from .rfcomm import RFCOMM_Frame, RFCOMM_PSM from .rfcomm import RFCOMM_Frame, RFCOMM_PSM
from .sdp import SDP_PDU, SDP_PSM from .sdp import SDP_PDU, SDP_PSM
from .avdtp import ( from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM
MessageAssembler as AVDTP_MessageAssembler,
AVDTP_PSM
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -78,7 +75,10 @@ class PacketTracer:
elif l2cap_pdu.cid == SMP_CID: elif l2cap_pdu.cid == SMP_CID:
smp_command = SMP_Command.from_bytes(l2cap_pdu.payload) smp_command = SMP_Command.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(smp_command) self.analyzer.emit(smp_command)
elif l2cap_pdu.cid == L2CAP_SIGNALING_CID or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID: elif (
l2cap_pdu.cid == L2CAP_SIGNALING_CID
or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID
):
control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload) control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(control_frame) self.analyzer.emit(control_frame)
@@ -86,7 +86,10 @@ class PacketTracer:
if control_frame.code == L2CAP_CONNECTION_REQUEST: if control_frame.code == L2CAP_CONNECTION_REQUEST:
self.psms[control_frame.source_cid] = control_frame.psm self.psms[control_frame.source_cid] = control_frame.psm
elif control_frame.code == L2CAP_CONNECTION_RESPONSE: elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
if control_frame.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL: if (
control_frame.result
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
):
if self.peer: if self.peer:
if psm := self.peer.psms.get(control_frame.source_cid): if psm := self.peer.psms.get(control_frame.source_cid):
# Found a pending connection # Found a pending connection
@@ -94,8 +97,14 @@ class PacketTracer:
# For AVDTP connections, create a packet assembler for each direction # For AVDTP connections, create a packet assembler for each direction
if psm == AVDTP_PSM: if psm == AVDTP_PSM:
self.avdtp_assemblers[control_frame.source_cid] = AVDTP_MessageAssembler(self.on_avdtp_message) self.avdtp_assemblers[
self.peer.avdtp_assemblers[control_frame.destination_cid] = AVDTP_MessageAssembler(self.peer.on_avdtp_message) control_frame.source_cid
] = AVDTP_MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
control_frame.destination_cid
] = AVDTP_MessageAssembler(
self.peer.on_avdtp_message
)
else: else:
# Try to find the PSM associated with this PDU # Try to find the PSM associated with this PDU
@@ -107,18 +116,24 @@ class PacketTracer:
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload) rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(rfcomm_frame) self.analyzer.emit(rfcomm_frame)
elif psm == AVDTP_PSM: elif psm == AVDTP_PSM:
self.analyzer.emit(f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM=AVDTP]: {l2cap_pdu.payload.hex()}') self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
)
assembler = self.avdtp_assemblers.get(l2cap_pdu.cid) assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
if assembler: if assembler:
assembler.on_pdu(l2cap_pdu.payload) assembler.on_pdu(l2cap_pdu.payload)
else: else:
psm_string = name_or_number(PSM_NAMES, psm) psm_string = name_or_number(PSM_NAMES, psm)
self.analyzer.emit(f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM={psm_string}]: {l2cap_pdu.payload.hex()}') self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM={psm_string}]: {l2cap_pdu.payload.hex()}'
)
else: else:
self.analyzer.emit(l2cap_pdu) self.analyzer.emit(l2cap_pdu)
def on_avdtp_message(self, transaction_label, message): def on_avdtp_message(self, transaction_label, message):
self.analyzer.emit(f'{color("AVDTP", "green")} [{transaction_label}] {message}') self.analyzer.emit(
f'{color("AVDTP", "green")} [{transaction_label}] {message}'
)
def feed_packet(self, packet): def feed_packet(self, packet):
self.packet_assembler.feed_packet(packet) self.packet_assembler.feed_packet(packet)
@@ -131,7 +146,9 @@ class PacketTracer:
self.peer = None # Analyzer in the other direction self.peer = None # Analyzer in the other direction
def start_acl_stream(self, connection_handle): def start_acl_stream(self, connection_handle):
logger.info(f'[{self.label}] +++ Creating ACL stream for connection 0x{connection_handle:04X}') logger.info(
f'[{self.label}] +++ Creating ACL stream for connection 0x{connection_handle:04X}'
)
stream = PacketTracer.AclStream(self) stream = PacketTracer.AclStream(self)
self.acl_streams[connection_handle] = stream self.acl_streams[connection_handle] = stream
@@ -144,7 +161,9 @@ class PacketTracer:
def end_acl_stream(self, connection_handle): def end_acl_stream(self, connection_handle):
if connection_handle in self.acl_streams: if connection_handle in self.acl_streams:
logger.info(f'[{self.label}] --- Removing ACL stream for connection 0x{connection_handle:04X}') logger.info(
f'[{self.label}] --- Removing ACL stream for connection 0x{connection_handle:04X}'
)
del self.acl_streams[connection_handle] del self.acl_streams[connection_handle]
# Let the other forwarder know so it can cleanup its stream as well # Let the other forwarder know so it can cleanup its stream as well
@@ -176,9 +195,13 @@ class PacketTracer:
self, self,
host_to_controller_label=color('HOST->CONTROLLER', 'blue'), host_to_controller_label=color('HOST->CONTROLLER', 'blue'),
controller_to_host_label=color('CONTROLLER->HOST', 'cyan'), controller_to_host_label=color('CONTROLLER->HOST', 'cyan'),
emit_message=logger.info emit_message=logger.info,
): ):
self.host_to_controller_analyzer = PacketTracer.Analyzer(host_to_controller_label, emit_message) self.host_to_controller_analyzer = PacketTracer.Analyzer(
self.controller_to_host_analyzer = PacketTracer.Analyzer(controller_to_host_label, emit_message) host_to_controller_label, emit_message
)
self.controller_to_host_analyzer = PacketTracer.Analyzer(
controller_to_host_label, emit_message
)
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer
+243 -107
View File
@@ -36,11 +36,15 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27 HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1 HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1
HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27 HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1 HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
# fmt: on
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Connection: class Connection:
@@ -76,9 +80,11 @@ class Host(EventEmitter):
self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS
self.acl_packet_queue = collections.deque() self.acl_packet_queue = collections.deque()
self.acl_packets_in_flight = 0 self.acl_packets_in_flight = 0
self.local_version = HCI_VERSION_BLUETOOTH_CORE_4_0 self.local_version = None
self.local_supported_commands = bytes(64) self.local_supported_commands = bytes(64)
self.local_le_features = 0 self.local_le_features = 0
self.suggested_max_tx_octets = 251 # Max allowed
self.suggested_max_tx_time = 2120 # Max allowed
self.command_semaphore = asyncio.Semaphore(1) self.command_semaphore = asyncio.Semaphore(1)
self.long_term_key_provider = None self.long_term_key_provider = None
self.link_key_provider = None self.link_key_provider = None
@@ -91,68 +97,105 @@ class Host(EventEmitter):
self.set_packet_sink(controller_sink) self.set_packet_sink(controller_sink)
async def reset(self): async def reset(self):
await self.send_command(HCI_Reset_Command()) await self.send_command(HCI_Reset_Command(), check_result=True)
self.ready = True self.ready = True
response = await self.send_command(HCI_Read_Local_Supported_Commands_Command()) response = await self.send_command(
if response.return_parameters.status == HCI_SUCCESS: HCI_Read_Local_Supported_Commands_Command(), check_result=True
)
self.local_supported_commands = response.return_parameters.supported_commands self.local_supported_commands = response.return_parameters.supported_commands
else:
logger.warn(f'HCI_Read_Local_Supported_Commands_Command failed: {response.return_parameters.status}')
if self.supports_command(HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND): if self.supports_command(HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command(HCI_LE_Read_Local_Supported_Features_Command()) response = await self.send_command(
if response.return_parameters.status == HCI_SUCCESS: HCI_LE_Read_Local_Supported_Features_Command(), check_result=True
self.local_le_features = struct.unpack('<Q', response.return_parameters.le_features)[0] )
else: self.local_le_features = struct.unpack(
logger.warn(f'HCI_LE_Read_Supported_Features_Command failed: {response.return_parameters.status}') '<Q', response.return_parameters.le_features
)[0]
if self.supports_command(HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND): if self.supports_command(HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND):
response = await self.send_command(HCI_Read_Local_Version_Information_Command()) response = await self.send_command(
if response.return_parameters.status == HCI_SUCCESS: HCI_Read_Local_Version_Information_Command(), check_result=True
)
self.local_version = response.return_parameters self.local_version = response.return_parameters
else:
logger.warn(f'HCI_Read_Local_Version_Information_Command failed: {response.return_parameters.status}')
await self.send_command(HCI_Set_Event_Mask_Command(event_mask = bytes.fromhex('FFFFFFFFFFFFFF3F'))) await self.send_command(
HCI_Set_Event_Mask_Command(event_mask=bytes.fromhex('FFFFFFFFFFFFFF3F'))
)
if self.local_version.hci_version <= HCI_VERSION_BLUETOOTH_CORE_4_0: if (
self.local_version is not None
and self.local_version.hci_version <= HCI_VERSION_BLUETOOTH_CORE_4_0
):
# Some older controllers don't like event masks with bits they don't understand # Some older controllers don't like event masks with bits they don't understand
le_event_mask = bytes.fromhex('1F00000000000000') le_event_mask = bytes.fromhex('1F00000000000000')
else: else:
le_event_mask = bytes.fromhex('FFFFF00000000000') le_event_mask = bytes.fromhex('FFFFF00000000000')
await self.send_command(HCI_LE_Set_Event_Mask_Command(le_event_mask = le_event_mask)) await self.send_command(
HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask)
)
if self.supports_command(HCI_READ_BUFFER_SIZE_COMMAND): if self.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(HCI_Read_Buffer_Size_Command()) response = await self.send_command(
if response.return_parameters.status == HCI_SUCCESS: HCI_Read_Buffer_Size_Command(), check_result=True
self.hc_acl_data_packet_length = response.return_parameters.hc_acl_data_packet_length )
self.hc_total_num_acl_data_packets = response.return_parameters.hc_total_num_acl_data_packets self.hc_acl_data_packet_length = (
else: response.return_parameters.hc_acl_data_packet_length
logger.warn(f'HCI_Read_Buffer_Size_Command failed: {response.return_parameters.status}') )
self.hc_total_num_acl_data_packets = (
if self.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND): response.return_parameters.hc_total_num_acl_data_packets
response = await self.send_command(HCI_LE_Read_Buffer_Size_Command()) )
if response.return_parameters.status == HCI_SUCCESS:
self.hc_le_acl_data_packet_length = response.return_parameters.hc_le_acl_data_packet_length
self.hc_total_num_le_acl_data_packets = response.return_parameters.hc_total_num_le_acl_data_packets
else:
logger.warn(f'HCI_LE_Read_Buffer_Size_Command failed: {response.return_parameters.status}')
if response.return_parameters.hc_le_acl_data_packet_length == 0 or response.return_parameters.hc_total_num_le_acl_data_packets == 0:
# LE and Classic share the same values
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length
self.hc_total_num_le_acl_data_packets = self.hc_total_num_acl_data_packets
logger.debug( logger.debug(
f'HCI ACL flow control: hc_acl_data_packet_length={self.hc_acl_data_packet_length},' f'HCI ACL flow control: hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
f'hc_total_num_acl_data_packets={self.hc_total_num_acl_data_packets}' f'hc_total_num_acl_data_packets={self.hc_total_num_acl_data_packets}'
) )
if self.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
self.hc_le_acl_data_packet_length = (
response.return_parameters.hc_le_acl_data_packet_length
)
self.hc_total_num_le_acl_data_packets = (
response.return_parameters.hc_total_num_le_acl_data_packets
)
logger.debug( logger.debug(
f'HCI LE ACL flow control: hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},' f'HCI LE ACL flow control: hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
f'hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}' f'hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}'
) )
if (
response.return_parameters.hc_le_acl_data_packet_length == 0
or response.return_parameters.hc_total_num_le_acl_data_packets == 0
):
# LE and Classic share the same values
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length
self.hc_total_num_le_acl_data_packets = (
self.hc_total_num_acl_data_packets
)
if self.supports_command(
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
) and self.supports_command(HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response = await self.send_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
if (
suggested_max_tx_octets != self.suggested_max_tx_octets
or suggested_max_tx_time != self.suggested_max_tx_time
):
await self.send_command(
HCI_LE_Write_Suggested_Default_Data_Length_Command(
suggested_max_tx_octets=self.suggested_max_tx_octets,
suggested_max_tx_time=self.suggested_max_tx_time,
)
)
self.reset_done = True self.reset_done = True
@property @property
@@ -171,7 +214,7 @@ class Host(EventEmitter):
def send_hci_packet(self, packet): def send_hci_packet(self, packet):
self.hci_sink.on_packet(packet.to_bytes()) self.hci_sink.on_packet(packet.to_bytes())
async def send_command(self, command): async def send_command(self, command, check_result=False):
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}') logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}')
# Wait until we can send (only one pending command at a time) # Wait until we can send (only one pending command at a time)
@@ -186,11 +229,29 @@ class Host(EventEmitter):
try: try:
self.send_hci_packet(command) self.send_hci_packet(command)
response = await self.pending_response response = await self.pending_response
# TODO: check error values
# Check the return parameters if required
if check_result:
if type(response.return_parameters) is int:
status = response.return_parameters
elif type(response.return_parameters) is bytes:
# return parameters first field is a one byte status code
status = response.return_parameters[0]
else:
status = response.return_parameters.status
if status != HCI_SUCCESS:
logger.warning(
f'{command.name} failed ({HCI_Constant.error_name(status)})'
)
raise HCI_Error(status)
return response return response
except Exception as error: except Exception as error:
logger.warning(f'{color("!!! Exception while sending HCI packet:", "red")} {error}') logger.warning(
# raise error f'{color("!!! Exception while sending HCI packet:", "red")} {error}'
)
raise error
finally: finally:
self.pending_command = None self.pending_command = None
self.pending_response = None self.pending_response = None
@@ -217,9 +278,11 @@ class Host(EventEmitter):
pb_flag=pb_flag, pb_flag=pb_flag,
bc_flag=0, bc_flag=0,
data_total_length=data_total_length, data_total_length=data_total_length,
data = l2cap_pdu[offset:offset + data_total_length] data=l2cap_pdu[offset : offset + data_total_length],
)
logger.debug(
f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}'
) )
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}')
self.queue_acl_packet(acl_packet) self.queue_acl_packet(acl_packet)
pb_flag = 1 pb_flag = 1
offset += data_total_length offset += data_total_length
@@ -230,11 +293,16 @@ class Host(EventEmitter):
self.check_acl_packet_queue() self.check_acl_packet_queue()
if len(self.acl_packet_queue): if len(self.acl_packet_queue):
logger.debug(f'{self.acl_packets_in_flight} ACL packets in flight, {len(self.acl_packet_queue)} in queue') logger.debug(
f'{self.acl_packets_in_flight} ACL packets in flight, {len(self.acl_packet_queue)} in queue'
)
def check_acl_packet_queue(self): def check_acl_packet_queue(self):
# Send all we can (TODO: support different LE/Classic limits) # Send all we can (TODO: support different LE/Classic limits)
while len(self.acl_packet_queue) > 0 and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets: while (
len(self.acl_packet_queue) > 0
and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets
):
packet = self.acl_packet_queue.pop() packet = self.acl_packet_queue.pop()
self.send_hci_packet(packet) self.send_hci_packet(packet)
self.acl_packets_in_flight += 1 self.acl_packets_in_flight += 1
@@ -246,7 +314,9 @@ class Host(EventEmitter):
if value == command: if value == command:
# Check if the flag is set # Check if the flag is set
if octet < len(self.local_supported_commands) and flag_position < 8: if octet < len(self.local_supported_commands) and flag_position < 8:
return (self.local_supported_commands[octet] & (1 << flag_position)) != 0 return (
self.local_supported_commands[octet] & (1 << flag_position)
) != 0
return False return False
@@ -268,15 +338,17 @@ class Host(EventEmitter):
@property @property
def supported_le_features(self): def supported_le_features(self):
return [feature for feature in range(64) if self.local_le_features & (1 << feature)] return [
feature for feature in range(64) if self.local_le_features & (1 << feature)
]
# Packet Sink protocol (packets coming from the controller via HCI) # Packet Sink protocol (packets coming from the controller via HCI)
def on_packet(self, packet): def on_packet(self, packet):
hci_packet = HCI_Packet.from_bytes(packet) hci_packet = HCI_Packet.from_bytes(packet)
if self.ready or ( if self.ready or (
hci_packet.hci_packet_type == HCI_EVENT_PACKET and hci_packet.hci_packet_type == HCI_EVENT_PACKET
hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT and and hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT
hci_packet.command_opcode == HCI_RESET_COMMAND and hci_packet.command_opcode == HCI_RESET_COMMAND
): ):
self.on_hci_packet(hci_packet) self.on_hci_packet(hci_packet)
else: else:
@@ -315,7 +387,9 @@ class Host(EventEmitter):
if self.pending_response: if self.pending_response:
# Check that it is what we were expecting # Check that it is what we were expecting
if self.pending_command.op_code != event.command_opcode: if self.pending_command.op_code != event.command_opcode:
logger.warning(f'!!! command result mismatch, expected 0x{self.pending_command.op_code:X} but got 0x{event.command_opcode:X}') logger.warning(
f'!!! command result mismatch, expected 0x{self.pending_command.op_code:X} but got 0x{event.command_opcode:X}'
)
self.pending_response.set_result(event) self.pending_response.set_result(event)
else: else:
@@ -343,36 +417,47 @@ class Host(EventEmitter):
self.acl_packets_in_flight -= total_packets self.acl_packets_in_flight -= total_packets
self.check_acl_packet_queue() self.check_acl_packet_queue()
else: else:
logger.warning(color(f'!!! {total_packets} completed but only {self.acl_packets_in_flight} in flight')) logger.warning(
color(
f'!!! {total_packets} completed but only {self.acl_packets_in_flight} in flight'
)
)
self.acl_packets_in_flight = 0 self.acl_packets_in_flight = 0
# Classic only # Classic only
def on_hci_connection_request_event(self, event): def on_hci_connection_request_event(self, event):
# For now, just accept everything # Notify the listeners
# TODO: delegate the decision self.emit(
self.send_command_sync( 'connection_request',
HCI_Accept_Connection_Request_Command( event.bd_addr,
bd_addr = event.bd_addr, event.class_of_device,
role = 0x01 # Remain the peripheral event.link_type,
)
) )
def on_hci_le_connection_complete_event(self, event): def on_hci_le_connection_complete_event(self, event):
# Check if this is a cancellation # Check if this is a cancellation
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
# Create/update the connection # Create/update the connection
logger.debug(f'### CONNECTION: [0x{event.connection_handle:04X}] {event.peer_address} as {HCI_Constant.role_name(event.role)}') logger.debug(
f'### CONNECTION: [0x{event.connection_handle:04X}] {event.peer_address} as {HCI_Constant.role_name(event.role)}'
)
connection = self.connections.get(event.connection_handle) connection = self.connections.get(event.connection_handle)
if connection is None: if connection is None:
connection = Connection(self, event.connection_handle, event.role, event.peer_address, BT_LE_TRANSPORT) connection = Connection(
self,
event.connection_handle,
event.role,
event.peer_address,
BT_LE_TRANSPORT,
)
self.connections[event.connection_handle] = connection self.connections[event.connection_handle] = connection
# Notify the client # Notify the client
connection_parameters = ConnectionParameters( connection_parameters = ConnectionParameters(
event.conn_interval, event.connection_interval,
event.conn_latency, event.peripheral_latency,
event.supervision_timeout event.supervision_timeout,
) )
self.emit( self.emit(
'connection', 'connection',
@@ -381,13 +466,15 @@ class Host(EventEmitter):
event.peer_address, event.peer_address,
None, None,
event.role, event.role,
connection_parameters connection_parameters,
) )
else: else:
logger.debug(f'### CONNECTION FAILED: {event.status}') logger.debug(f'### CONNECTION FAILED: {event.status}')
# Notify the listeners # Notify the listeners
self.emit('connection_failure', event.status) self.emit(
'connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status
)
def on_hci_le_enhanced_connection_complete_event(self, event): def on_hci_le_enhanced_connection_complete_event(self, event):
# Just use the same implementation as for the non-enhanced event for now # Just use the same implementation as for the non-enhanced event for now
@@ -396,11 +483,19 @@ class Host(EventEmitter):
def on_hci_connection_complete_event(self, event): def on_hci_connection_complete_event(self, event):
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
# Create/update the connection # Create/update the connection
logger.debug(f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}') logger.debug(
f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}'
)
connection = self.connections.get(event.connection_handle) connection = self.connections.get(event.connection_handle)
if connection is None: if connection is None:
connection = Connection(self, event.connection_handle, BT_CENTRAL_ROLE, event.bd_addr, BT_BR_EDR_TRANSPORT) connection = Connection(
self,
event.connection_handle,
BT_CENTRAL_ROLE,
event.bd_addr,
BT_BR_EDR_TRANSPORT,
)
self.connections[event.connection_handle] = connection self.connections[event.connection_handle] = connection
# Notify the client # Notify the client
@@ -411,13 +506,15 @@ class Host(EventEmitter):
event.bd_addr, event.bd_addr,
None, None,
BT_CENTRAL_ROLE, BT_CENTRAL_ROLE,
None None,
) )
else: else:
logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}') logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')
# Notify the client # Notify the client
self.emit('connection_failure', event.connection_handle, event.status) self.emit(
'connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status
)
def on_hci_disconnection_complete_event(self, event): def on_hci_disconnection_complete_event(self, event):
# Find the connection # Find the connection
@@ -426,7 +523,9 @@ class Host(EventEmitter):
return return
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
logger.debug(f'### DISCONNECTION: [0x{event.connection_handle:04X}] {connection.peer_address} as {HCI_Constant.role_name(connection.role)}, reason={event.reason}') logger.debug(
f'### DISCONNECTION: [0x{event.connection_handle:04X}] {connection.peer_address} as {HCI_Constant.role_name(connection.role)}, reason={event.reason}'
)
del self.connections[event.connection_handle] del self.connections[event.connection_handle]
# Notify the listeners # Notify the listeners
@@ -435,7 +534,7 @@ class Host(EventEmitter):
logger.debug(f'### DISCONNECTION FAILED: {event.status}') logger.debug(f'### DISCONNECTION FAILED: {event.status}')
# Notify the listeners # Notify the listeners
self.emit('disconnection_failure', event.status) self.emit('disconnection_failure', event.connection_handle, event.status)
def on_hci_le_connection_update_complete_event(self, event): def on_hci_le_connection_update_complete_event(self, event):
if (connection := self.connections.get(event.connection_handle)) is None: if (connection := self.connections.get(event.connection_handle)) is None:
@@ -445,13 +544,17 @@ class Host(EventEmitter):
# Notify the client # Notify the client
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
connection_parameters = ConnectionParameters( connection_parameters = ConnectionParameters(
event.conn_interval, event.connection_interval,
event.conn_latency, event.peripheral_latency,
event.supervision_timeout event.supervision_timeout,
)
self.emit(
'connection_parameters_update', connection.handle, connection_parameters
) )
self.emit('connection_parameters_update', connection.handle, connection_parameters)
else: else:
self.emit('connection_parameters_update_failure', connection.handle, event.status) self.emit(
'connection_parameters_update_failure', connection.handle, event.status
)
def on_hci_le_phy_update_complete_event(self, event): def on_hci_le_phy_update_complete_event(self, event):
if (connection := self.connections.get(event.connection_handle)) is None: if (connection := self.connections.get(event.connection_handle)) is None:
@@ -467,13 +570,10 @@ class Host(EventEmitter):
def on_hci_le_advertising_report_event(self, event): def on_hci_le_advertising_report_event(self, event):
for report in event.reports: for report in event.reports:
self.emit( self.emit('advertising_report', report)
'advertising_report',
report.address, def on_hci_le_extended_advertising_report_event(self, event):
report.data, self.on_hci_le_advertising_report_event(event)
report.rssi,
report.event_type
)
def on_hci_le_remote_connection_parameter_request_event(self, event): def on_hci_le_remote_connection_parameter_request_event(self, event):
if event.connection_handle not in self.connections: if event.connection_handle not in self.connections:
@@ -489,8 +589,8 @@ class Host(EventEmitter):
interval_max=event.interval_max, interval_max=event.interval_max,
latency=event.latency, latency=event.latency,
timeout=event.timeout, timeout=event.timeout,
minimum_ce_length = 0, min_ce_length=0,
maximum_ce_length = 0 max_ce_length=0,
) )
) )
@@ -505,14 +605,12 @@ class Host(EventEmitter):
long_term_key = None long_term_key = None
else: else:
long_term_key = await self.long_term_key_provider( long_term_key = await self.long_term_key_provider(
connection.handle, connection.handle, event.random_number, event.encryption_diversifier
event.random_number,
event.encryption_diversifier
) )
if long_term_key: if long_term_key:
response = HCI_LE_Long_Term_Key_Request_Reply_Command( response = HCI_LE_Long_Term_Key_Request_Reply_Command(
connection_handle=event.connection_handle, connection_handle=event.connection_handle,
long_term_key = long_term_key long_term_key=long_term_key,
) )
else: else:
response = HCI_LE_Long_Term_Key_Request_Negative_Reply_Command( response = HCI_LE_Long_Term_Key_Request_Negative_Reply_Command(
@@ -531,10 +629,14 @@ class Host(EventEmitter):
def on_hci_role_change_event(self, event): def on_hci_role_change_event(self, event):
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
logger.debug(f'role change for {event.bd_addr}: {HCI_Constant.role_name(event.new_role)}') logger.debug(
f'role change for {event.bd_addr}: {HCI_Constant.role_name(event.new_role)}'
)
# TODO: lookup the connection and update the role # TODO: lookup the connection and update the role
else: else:
logger.debug(f'role change for {event.bd_addr} failed: {HCI_Constant.error_name(event.status)}') logger.debug(
f'role change for {event.bd_addr} failed: {HCI_Constant.error_name(event.status)}'
)
def on_hci_le_data_length_change_event(self, event): def on_hci_le_data_length_change_event(self, event):
self.emit( self.emit(
@@ -543,7 +645,7 @@ class Host(EventEmitter):
event.max_tx_octets, event.max_tx_octets,
event.max_tx_time, event.max_tx_time,
event.max_rx_octets, event.max_rx_octets,
event.max_rx_time event.max_rx_time,
) )
def on_hci_authentication_complete_event(self, event): def on_hci_authentication_complete_event(self, event):
@@ -551,21 +653,35 @@ class Host(EventEmitter):
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
self.emit('connection_authentication', event.connection_handle) self.emit('connection_authentication', event.connection_handle)
else: else:
self.emit('connection_authentication_failure', event.connection_handle, event.status) self.emit(
'connection_authentication_failure',
event.connection_handle,
event.status,
)
def on_hci_encryption_change_event(self, event): def on_hci_encryption_change_event(self, event):
# Notify the client # Notify the client
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
self.emit('connection_encryption_change', event.connection_handle, event.encryption_enabled) self.emit(
'connection_encryption_change',
event.connection_handle,
event.encryption_enabled,
)
else: else:
self.emit('connection_encryption_failure', event.connection_handle, event.status) self.emit(
'connection_encryption_failure', event.connection_handle, event.status
)
def on_hci_encryption_key_refresh_complete_event(self, event): def on_hci_encryption_key_refresh_complete_event(self, event):
# Notify the client # Notify the client
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
self.emit('connection_encryption_key_refresh', event.connection_handle) self.emit('connection_encryption_key_refresh', event.connection_handle)
else: else:
self.emit('connection_encryption_key_refresh_failure', event.connection_handle, event.status) self.emit(
'connection_encryption_key_refresh_failure',
event.connection_handle,
event.status,
)
def on_hci_link_supervision_timeout_changed_event(self, event): def on_hci_link_supervision_timeout_changed_event(self, event):
pass pass
@@ -577,19 +693,24 @@ class Host(EventEmitter):
pass pass
def on_hci_link_key_notification_event(self, event): def on_hci_link_key_notification_event(self, event):
logger.debug(f'link key for {event.bd_addr}: {event.link_key.hex()}, type={HCI_Constant.link_key_type_name(event.key_type)}') logger.debug(
f'link key for {event.bd_addr}: {event.link_key.hex()}, type={HCI_Constant.link_key_type_name(event.key_type)}'
)
self.emit('link_key', event.bd_addr, event.link_key, event.key_type) self.emit('link_key', event.bd_addr, event.link_key, event.key_type)
def on_hci_simple_pairing_complete_event(self, event): def on_hci_simple_pairing_complete_event(self, event):
logger.debug(f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}') logger.debug(
f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}'
)
# Notify the client
if event.status == HCI_SUCCESS:
self.emit('ssp_complete', event.bd_addr)
def on_hci_pin_code_request_event(self, event): def on_hci_pin_code_request_event(self, event):
# For now, just refuse all requests # For now, just refuse all requests
# TODO: delegate the decision # TODO: delegate the decision
self.send_command_sync( self.send_command_sync(
HCI_PIN_Code_Request_Negative_Reply_Command( HCI_PIN_Code_Request_Negative_Reply_Command(bd_addr=event.bd_addr)
bd_addr = event.bd_addr
)
) )
def on_hci_link_key_request_event(self, event): def on_hci_link_key_request_event(self, event):
@@ -601,8 +722,7 @@ class Host(EventEmitter):
link_key = await self.link_key_provider(event.bd_addr) link_key = await self.link_key_provider(event.bd_addr)
if link_key: if link_key:
response = HCI_Link_Key_Request_Reply_Command( response = HCI_Link_Key_Request_Reply_Command(
bd_addr = event.bd_addr, bd_addr=event.bd_addr, link_key=link_key
link_key = link_key
) )
else: else:
response = HCI_Link_Key_Request_Negative_Reply_Command( response = HCI_Link_Key_Request_Negative_Reply_Command(
@@ -620,11 +740,20 @@ class Host(EventEmitter):
pass pass
def on_hci_user_confirmation_request_event(self, event): def on_hci_user_confirmation_request_event(self, event):
self.emit('authentication_user_confirmation_request', event.bd_addr, event.numeric_value) self.emit(
'authentication_user_confirmation_request',
event.bd_addr,
event.numeric_value,
)
def on_hci_user_passkey_request_event(self, event): def on_hci_user_passkey_request_event(self, event):
self.emit('authentication_user_passkey_request', event.bd_addr) self.emit('authentication_user_passkey_request', event.bd_addr)
def on_hci_user_passkey_notification_event(self, event):
self.emit(
'authentication_user_passkey_notification', event.bd_addr, event.passkey
)
def on_hci_inquiry_complete_event(self, event): def on_hci_inquiry_complete_event(self, event):
self.emit('inquiry_complete') self.emit('inquiry_complete')
@@ -635,7 +764,7 @@ class Host(EventEmitter):
response.bd_addr, response.bd_addr,
response.class_of_device, response.class_of_device,
b'', b'',
response.rssi response.rssi,
) )
def on_hci_extended_inquiry_result_event(self, event): def on_hci_extended_inquiry_result_event(self, event):
@@ -644,7 +773,7 @@ class Host(EventEmitter):
event.bd_addr, event.bd_addr,
event.class_of_device, event.class_of_device,
event.extended_inquiry_response, event.extended_inquiry_response,
event.rssi event.rssi,
) )
def on_hci_remote_name_request_complete_event(self, event): def on_hci_remote_name_request_complete_event(self, event):
@@ -652,3 +781,10 @@ class Host(EventEmitter):
self.emit('remote_name_failure', event.bd_addr, event.status) self.emit('remote_name_failure', event.bd_addr, event.status)
else: else:
self.emit('remote_name', event.bd_addr, event.remote_name) self.emit('remote_name', event.bd_addr, event.remote_name)
def on_hci_remote_host_supported_features_notification_event(self, event):
self.emit(
'remote_host_supported_features',
event.bd_addr,
event.host_supported_features,
)
+17 -3
View File
@@ -20,6 +20,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio
import logging import logging
import os import os
import json import json
@@ -143,6 +144,10 @@ class KeyStore:
async def get_all(self): async def get_all(self):
return [] return []
async def delete_all(self):
all_keys = await self.get_all()
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
async def get_resolving_keys(self): async def get_resolving_keys(self):
all_keys = await self.get_all() all_keys = await self.get_all()
resolving_keys = [] resolving_keys = []
@@ -189,9 +194,9 @@ class JsonKeyStore(KeyStore):
if filename is None: if filename is None:
# Use a default for the current user # Use a default for the current user
import appdirs import appdirs
self.directory_name = os.path.join( self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
self.KEYS_DIR
) )
json_filename = f'{self.namespace}.json'.lower().replace(':', '-') json_filename = f'{self.namespace}.json'.lower().replace(':', '-')
self.filename = os.path.join(self.directory_name, json_filename) self.filename = os.path.join(self.directory_name, json_filename)
@@ -257,7 +262,16 @@ class JsonKeyStore(KeyStore):
if namespace is None: if namespace is None:
return [] return []
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()] return [
(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()
]
async def delete_all(self):
db = await self.load()
db.pop(self.namespace, None)
await self.save(db)
async def get(self, name): async def get(self, name):
db = await self.load() db = await self.load()
+972 -186
View File
File diff suppressed because it is too large Load Diff
+78 -25
View File
@@ -25,7 +25,7 @@ from bumble.hci import (
Address, Address,
HCI_SUCCESS, HCI_SUCCESS,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_CONNECTION_TIMEOUT_ERROR HCI_CONNECTION_TIMEOUT_ERROR,
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -103,23 +103,30 @@ class LocalLink:
return return
# Connect to the first controller with a matching address # Connect to the first controller with a matching address
if peripheral_controller := self.find_controller(le_create_connection_command.peer_address): if peripheral_controller := self.find_controller(
central_controller.on_link_peripheral_connection_complete(le_create_connection_command, HCI_SUCCESS) le_create_connection_command.peer_address
):
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, HCI_SUCCESS
)
peripheral_controller.on_link_central_connected(central_address) peripheral_controller.on_link_central_connected(central_address)
return return
# No peripheral found # No peripheral found
central_controller.on_link_peripheral_connection_complete( central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
) )
def connect(self, central_address, le_create_connection_command): def connect(self, central_address, le_create_connection_command):
logger.debug(f'$$$ CONNECTION {central_address} -> {le_create_connection_command.peer_address}') logger.debug(
f'$$$ CONNECTION {central_address} -> {le_create_connection_command.peer_address}'
)
self.pending_connection = (central_address, le_create_connection_command) self.pending_connection = (central_address, le_create_connection_command)
asyncio.get_running_loop().call_soon(self.on_connection_complete) asyncio.get_running_loop().call_soon(self.on_connection_complete)
def on_disconnection_complete(self, central_address, peripheral_address, disconnect_command): def on_disconnection_complete(
self, central_address, peripheral_address, disconnect_command
):
# Find the controller that initiated the disconnection # Find the controller that initiated the disconnection
if not (central_controller := self.find_controller(central_address)): if not (central_controller := self.find_controller(central_address)):
logger.warning('!!! Initiating controller not found') logger.warning('!!! Initiating controller not found')
@@ -127,16 +134,24 @@ class LocalLink:
# Disconnect from the first controller with a matching address # Disconnect from the first controller with a matching address
if peripheral_controller := self.find_controller(peripheral_address): if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_central_disconnected(central_address, disconnect_command.reason) peripheral_controller.on_link_central_disconnected(
central_address, disconnect_command.reason
)
central_controller.on_link_peripheral_disconnection_complete(disconnect_command, HCI_SUCCESS) central_controller.on_link_peripheral_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, central_address, peripheral_address, disconnect_command): def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(f'$$$ DISCONNECTION {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}') logger.debug(
f'$$$ DISCONNECTION {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}'
)
args = [central_address, peripheral_address, disconnect_command] args = [central_address, peripheral_address, disconnect_command]
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args) asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)
def on_connection_encrypted(self, central_address, peripheral_address, rand, ediv, ltk): def on_connection_encrypted(
self, central_address, peripheral_address, rand, ediv, ltk
):
logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}') logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')
if central_controller := self.find_controller(central_address): if central_controller := self.find_controller(central_address):
@@ -152,6 +167,7 @@ class RemoteLink:
A Link implementation that communicates with other virtual controllers via a A Link implementation that communicates with other virtual controllers via a
WebSocket relay WebSocket relay
''' '''
def __init__(self, uri): def __init__(self, uri):
self.controller = None self.controller = None
self.uri = uri self.uri = uri
@@ -160,7 +176,9 @@ class RemoteLink:
self.rpc_result = None self.rpc_result = None
self.pending_connection = None self.pending_connection = None
self.central_connections = set() # List of addresses that we have connected to self.central_connections = set() # List of addresses that we have connected to
self.peripheral_connections = set() # List of addresses that have connected to us self.peripheral_connections = (
set()
) # List of addresses that have connected to us
# Connect and run asynchronously # Connect and run asynchronously
asyncio.create_task(self.run_connection()) asyncio.create_task(self.run_connection())
@@ -192,7 +210,9 @@ class RemoteLink:
try: try:
await item await item
except Exception as error: except Exception as error:
logger.warning(f'{color("!!! Exception in async handler:", "red")} {error}') logger.warning(
f'{color("!!! Exception in async handler:", "red")} {error}'
)
async def run_connection(self): async def run_connection(self):
# Connect to the relay # Connect to the relay
@@ -227,7 +247,9 @@ class RemoteLink:
self.central_connections.remove(address) self.central_connections.remove(address)
if address in self.peripheral_connections: if address in self.peripheral_connections:
self.controller.on_link_central_disconnected(address, HCI_CONNECTION_TIMEOUT_ERROR) self.controller.on_link_central_disconnected(
address, HCI_CONNECTION_TIMEOUT_ERROR
)
self.peripheral_connections.remove(address) self.peripheral_connections.remove(address)
async def on_unreachable_received(self, target): async def on_unreachable_received(self, target):
@@ -244,7 +266,9 @@ class RemoteLink:
async def on_advertisement_message_received(self, sender, advertisement): async def on_advertisement_message_received(self, sender, advertisement):
try: try:
self.controller.on_link_advertising_data(Address(sender), bytes.fromhex(advertisement)) self.controller.on_link_advertising_data(
Address(sender), bytes.fromhex(advertisement)
)
except Exception: except Exception:
logger.exception('exception') logger.exception('exception')
@@ -275,7 +299,9 @@ class RemoteLink:
# Notify the controller # Notify the controller
logger.debug(f'connected to peripheral {self.pending_connection.peer_address}') logger.debug(f'connected to peripheral {self.pending_connection.peer_address}')
self.controller.on_link_peripheral_connection_complete(self.pending_connection, HCI_SUCCESS) self.controller.on_link_peripheral_connection_complete(
self.pending_connection, HCI_SUCCESS
)
async def on_disconnect_message_received(self, sender, message): async def on_disconnect_message_received(self, sender, message):
# Notify the controller # Notify the controller
@@ -296,7 +322,7 @@ class RemoteLink:
websocket = await self.websocket websocket = await self.websocket
# Create a future value to hold the eventual result # Create a future value to hold the eventual result
assert(self.rpc_result is None) assert self.rpc_result is None
self.rpc_result = asyncio.get_running_loop().create_future() self.rpc_result = asyncio.get_running_loop().create_future()
# Send the command # Send the command
@@ -345,16 +371,43 @@ class RemoteLink:
logger.warn('connection already pending') logger.warn('connection already pending')
return return
self.pending_connection = le_create_connection_command self.pending_connection = le_create_connection_command
self.execute(partial(self.send_connection_request_to_relay, str(le_create_connection_command.peer_address))) self.execute(
partial(
self.send_connection_request_to_relay,
str(le_create_connection_command.peer_address),
)
)
def on_disconnection_complete(self, disconnect_command): def on_disconnection_complete(self, disconnect_command):
self.controller.on_link_peripheral_disconnection_complete(disconnect_command, HCI_SUCCESS) self.controller.on_link_peripheral_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, central_address, peripheral_address, disconnect_command): def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(f'disconnect {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}') logger.debug(
self.execute(partial(self.send_targetted_message, peripheral_address, f'disconnect:reason={disconnect_command.reason}')) f'disconnect {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}'
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, disconnect_command) )
self.execute(
partial(
self.send_targetted_message,
peripheral_address,
f'disconnect:reason={disconnect_command.reason}',
)
)
asyncio.get_running_loop().call_soon(
self.on_disconnection_complete, disconnect_command
)
def on_connection_encrypted(self, central_address, peripheral_address, rand, ediv, ltk): def on_connection_encrypted(
asyncio.get_running_loop().call_soon(self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk) self, central_address, peripheral_address, rand, ediv, ltk
self.execute(partial(self.send_targetted_message, peripheral_address, f'encrypted:ltk={ltk.hex()}')) ):
asyncio.get_running_loop().call_soon(
self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
)
self.execute(
partial(
self.send_targetted_message,
peripheral_address,
f'encrypted:ltk={ltk.hex()}',
)
)
+157
View File
@@ -0,0 +1,157 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
import logging
from ..core import AdvertisingData
from ..gatt import (
GATT_ASHA_SERVICE,
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
GATT_ASHA_VOLUME_CHARACTERISTIC,
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
TemplateService,
Characteristic,
CharacteristicValue,
PackedCharacteristicAdapter,
)
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class AshaService(TemplateService):
UUID = GATT_ASHA_SERVICE
OPCODE_START = 1
OPCODE_STOP = 2
OPCODE_STATUS = 3
PROTOCOL_VERSION = 0x01
RESERVED_FOR_FUTURE_USE = [00, 00]
FEATURE_MAP = [0x01] # [LE CoC audio output streaming supported]
SUPPORTED_CODEC_ID = [0x02, 0x01] # Codec IDs [G.722 at 16 kHz]
RENDER_DELAY = [00, 00]
def __init__(self, capability: int, hisyncid: [int]):
self.hisyncid = hisyncid
self.capability = capability # Device Capabilities [Left, Monaural]
# Handler for volume control
def on_volume_write(connection, value):
logger.info(f'--- VOLUME Write:{value[0]}')
# Handler for audio control commands
def on_audio_control_point_write(connection, value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == AshaService.OPCODE_START:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
logger.info(
f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}'
)
elif opcode == AshaService.OPCODE_STOP:
logger.info('### STOP')
elif opcode == AshaService.OPCODE_STATUS:
logger.info(f'### STATUS: connected={value[1]}')
# TODO Respond with a status
# asyncio.create_task(device.notify_subscribers(audio_status_characteristic, force=True))
self.read_only_properties_characteristic = Characteristic(
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes(
[
AshaService.PROTOCOL_VERSION, # Version
self.capability,
]
)
+ bytes(self.hisyncid)
+ bytes(AshaService.FEATURE_MAP)
+ bytes(AshaService.RENDER_DELAY)
+ bytes(AshaService.RESERVED_FOR_FUTURE_USE)
+ bytes(AshaService.SUPPORTED_CODEC_ID),
)
self.audio_control_point_characteristic = Characteristic(
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write),
)
self.audio_status_characteristic = Characteristic(
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([0]),
)
self.volume_characteristic = Characteristic(
GATT_ASHA_VOLUME_CHARACTERISTIC,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write),
)
# TODO add real psm value
self.psm = 0x0080
# self.psm = device.register_l2cap_channel_server(0, on_coc, 8)
self.le_psm_out_characteristic = Characteristic(
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
struct.pack('<H', self.psm),
)
characteristics = [
self.read_only_properties_characteristic,
self.audio_control_point_characteristic,
self.audio_status_characteristic,
self.volume_characteristic,
self.le_psm_out_characteristic,
]
super().__init__(characteristics)
def get_advertising_data(self):
# Advertisement only uses 4 least significant bytes of the HiSyncId.
return bytes(
AdvertisingData(
[
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(GATT_ASHA_SERVICE),
),
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(GATT_ASHA_SERVICE)
+ bytes(
[
AshaService.PROTOCOL_VERSION,
self.capability,
]
)
+ bytes(self.hisyncid[:4]),
),
]
)
)
+7 -6
View File
@@ -23,7 +23,7 @@ from ..gatt import (
TemplateService, TemplateService,
Characteristic, Characteristic,
CharacteristicValue, CharacteristicValue,
PackedCharacteristicAdapter PackedCharacteristicAdapter,
) )
@@ -38,9 +38,9 @@ class BatteryService(TemplateService):
GATT_BATTERY_LEVEL_CHARACTERISTIC, GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.READ | Characteristic.NOTIFY, Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE, Characteristic.READABLE,
CharacteristicValue(read=read_battery_level) CharacteristicValue(read=read_battery_level),
), ),
format=BatteryService.BATTERY_LEVEL_FORMAT format=BatteryService.BATTERY_LEVEL_FORMAT,
) )
super().__init__([self.battery_level_characteristic]) super().__init__([self.battery_level_characteristic])
@@ -52,10 +52,11 @@ class BatteryServiceProxy(ProfileServiceProxy):
def __init__(self, service_proxy): def __init__(self, service_proxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_BATTERY_LEVEL_CHARACTERISTIC): if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BATTERY_LEVEL_CHARACTERISTIC
):
self.battery_level = PackedCharacteristicAdapter( self.battery_level = PackedCharacteristicAdapter(
characteristics[0], characteristics[0], format=BatteryService.BATTERY_LEVEL_FORMAT
format=BatteryService.BATTERY_LEVEL_FORMAT
) )
else: else:
self.battery_level = None self.battery_level = None
+21 -18
View File
@@ -33,7 +33,7 @@ from ..gatt import (
TemplateService, TemplateService,
Characteristic, Characteristic,
DelegatedCharacteristicAdapter, DelegatedCharacteristicAdapter,
UTF8CharacteristicAdapter UTF8CharacteristicAdapter,
) )
@@ -63,38 +63,37 @@ class DeviceInformationService(TemplateService):
# TODO: pnp_id # TODO: pnp_id
): ):
characteristics = [ characteristics = [
Characteristic( Characteristic(uuid, Characteristic.READ, Characteristic.READABLE, field)
uuid,
Characteristic.READ,
Characteristic.READABLE,
field
)
for (field, uuid) in ( for (field, uuid) in (
(manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC), (manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
(model_number, GATT_MODEL_NUMBER_STRING_CHARACTERISTIC), (model_number, GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
(serial_number, GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC), (serial_number, GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
(hardware_revision, GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC), (hardware_revision, GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC),
(firmware_revision, GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC), (firmware_revision, GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC),
(software_revision, GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC) (software_revision, GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC),
) )
if field is not None if field is not None
] ]
if system_id is not None: if system_id is not None:
characteristics.append(Characteristic( characteristics.append(
Characteristic(
GATT_SYSTEM_ID_CHARACTERISTIC, GATT_SYSTEM_ID_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
self.pack_system_id(*system_id) self.pack_system_id(*system_id),
)) )
)
if ieee_regulatory_certification_data_list is not None: if ieee_regulatory_certification_data_list is not None:
characteristics.append(Characteristic( characteristics.append(
Characteristic(
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC, GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
ieee_regulatory_certification_data_list ieee_regulatory_certification_data_list,
)) )
)
super().__init__(characteristics) super().__init__(characteristics)
@@ -112,7 +111,7 @@ class DeviceInformationServiceProxy(ProfileServiceProxy):
('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC), ('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
('hardware_revision', GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC), ('hardware_revision', GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC),
('firmware_revision', GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC), ('firmware_revision', GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC),
('software_revision', GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC) ('software_revision', GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC),
): ):
if characteristics := service_proxy.get_characteristics_by_uuid(uuid): if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
characteristic = UTF8CharacteristicAdapter(characteristics[0]) characteristic = UTF8CharacteristicAdapter(characteristics[0])
@@ -120,16 +119,20 @@ class DeviceInformationServiceProxy(ProfileServiceProxy):
characteristic = None characteristic = None
self.__setattr__(field, characteristic) self.__setattr__(field, characteristic)
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_SYSTEM_ID_CHARACTERISTIC): if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_SYSTEM_ID_CHARACTERISTIC
):
self.system_id = DelegatedCharacteristicAdapter( self.system_id = DelegatedCharacteristicAdapter(
characteristics[0], characteristics[0],
encode=lambda v: DeviceInformationService.pack_system_id(*v), encode=lambda v: DeviceInformationService.pack_system_id(*v),
decode=DeviceInformationService.unpack_system_id decode=DeviceInformationService.unpack_system_id,
) )
else: else:
self.system_id = None self.system_id = None
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC): if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC
):
self.ieee_regulatory_certification_data_list = characteristics[0] self.ieee_regulatory_certification_data_list = characteristics[0]
else: else:
self.ieee_regulatory_certification_data_list = None self.ieee_regulatory_certification_data_list = None
+45 -30
View File
@@ -30,7 +30,7 @@ from ..gatt import (
Characteristic, Characteristic,
CharacteristicValue, CharacteristicValue,
DelegatedCharacteristicAdapter, DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter PackedCharacteristicAdapter,
) )
@@ -42,12 +42,12 @@ class HeartRateService(TemplateService):
RESET_ENERGY_EXPENDED = 0x01 RESET_ENERGY_EXPENDED = 0x01
class BodySensorLocation(IntEnum): class BodySensorLocation(IntEnum):
OTHER = 0, OTHER = (0,)
CHEST = 1, CHEST = (1,)
WRIST = 2, WRIST = (2,)
FINGER = 3, FINGER = (3,)
HAND = 4, HAND = (4,)
EAR_LOBE = 5, EAR_LOBE = (5,)
FOOT = 6 FOOT = 6
class HeartRateMeasurement: class HeartRateMeasurement:
@@ -56,12 +56,14 @@ class HeartRateService(TemplateService):
heart_rate, heart_rate,
sensor_contact_detected=None, sensor_contact_detected=None,
energy_expended=None, energy_expended=None,
rr_intervals=None rr_intervals=None,
): ):
if heart_rate < 0 or heart_rate > 0xFFFF: if heart_rate < 0 or heart_rate > 0xFFFF:
raise ValueError('heart_rate out of range') raise ValueError('heart_rate out of range')
if energy_expended is not None and (energy_expended < 0 or energy_expended > 0xFFFF): if energy_expended is not None and (
energy_expended < 0 or energy_expended > 0xFFFF
):
raise ValueError('energy_expended out of range') raise ValueError('energy_expended out of range')
if rr_intervals: if rr_intervals:
@@ -87,7 +89,7 @@ class HeartRateService(TemplateService):
offset += 1 offset += 1
if flags & (1 << 2): if flags & (1 << 2):
sensor_contact_detected = (flags & (1 << 1) != 0) sensor_contact_detected = flags & (1 << 1) != 0
else: else:
sensor_contact_detected = None sensor_contact_detected = None
@@ -119,38 +121,42 @@ class HeartRateService(TemplateService):
flags |= ((1 if self.sensor_contact_detected else 0) << 1) | (1 << 2) flags |= ((1 if self.sensor_contact_detected else 0) << 1) | (1 << 2)
if self.energy_expended is not None: if self.energy_expended is not None:
flags |= (1 << 3) flags |= 1 << 3
data += struct.pack('<H', self.energy_expended) data += struct.pack('<H', self.energy_expended)
if self.rr_intervals: if self.rr_intervals:
flags |= (1 << 4) flags |= 1 << 4
data += b''.join([ data += b''.join(
[
struct.pack('<H', int(rr_interval * 1024)) struct.pack('<H', int(rr_interval * 1024))
for rr_interval in self.rr_intervals for rr_interval in self.rr_intervals
]) ]
)
return bytes([flags]) + data return bytes([flags]) + data
def __str__(self): def __str__(self):
return f'HeartRateMeasurement(heart_rate={self.heart_rate},'\ return (
f' sensor_contact_detected={self.sensor_contact_detected},'\ f'HeartRateMeasurement(heart_rate={self.heart_rate},'
f' energy_expended={self.energy_expended},'\ f' sensor_contact_detected={self.sensor_contact_detected},'
f' energy_expended={self.energy_expended},'
f' rr_intervals={self.rr_intervals})' f' rr_intervals={self.rr_intervals})'
)
def __init__( def __init__(
self, self,
read_heart_rate_measurement, read_heart_rate_measurement,
body_sensor_location=None, body_sensor_location=None,
reset_energy_expended=None reset_energy_expended=None,
): ):
self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter( self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter(
Characteristic( Characteristic(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC, GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Characteristic.NOTIFY, Characteristic.NOTIFY,
0, 0,
CharacteristicValue(read=read_heart_rate_measurement) CharacteristicValue(read=read_heart_rate_measurement),
), ),
encode=lambda value: bytes(value) encode=lambda value: bytes(value),
) )
characteristics = [self.heart_rate_measurement_characteristic] characteristics = [self.heart_rate_measurement_characteristic]
@@ -159,11 +165,12 @@ class HeartRateService(TemplateService):
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC, GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
bytes([int(body_sensor_location)]) bytes([int(body_sensor_location)]),
) )
characteristics.append(self.body_sensor_location_characteristic) characteristics.append(self.body_sensor_location_characteristic)
if reset_energy_expended: if reset_energy_expended:
def write_heart_rate_control_point_value(connection, value): def write_heart_rate_control_point_value(connection, value):
if value == self.RESET_ENERGY_EXPENDED: if value == self.RESET_ENERGY_EXPENDED:
if reset_energy_expended is not None: if reset_energy_expended is not None:
@@ -176,9 +183,9 @@ class HeartRateService(TemplateService):
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC, GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE, Characteristic.WRITE,
Characteristic.WRITEABLE, Characteristic.WRITEABLE,
CharacteristicValue(write=write_heart_rate_control_point_value) CharacteristicValue(write=write_heart_rate_control_point_value),
), ),
format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
) )
characteristics.append(self.heart_rate_control_point_characteristic) characteristics.append(self.heart_rate_control_point_characteristic)
@@ -192,30 +199,38 @@ class HeartRateServiceProxy(ProfileServiceProxy):
def __init__(self, service_proxy): def __init__(self, service_proxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC): if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
):
self.heart_rate_measurement = DelegatedCharacteristicAdapter( self.heart_rate_measurement = DelegatedCharacteristicAdapter(
characteristics[0], characteristics[0],
decode=HeartRateService.HeartRateMeasurement.from_bytes decode=HeartRateService.HeartRateMeasurement.from_bytes,
) )
else: else:
self.heart_rate_measurement = None self.heart_rate_measurement = None
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC): if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC
):
self.body_sensor_location = DelegatedCharacteristicAdapter( self.body_sensor_location = DelegatedCharacteristicAdapter(
characteristics[0], characteristics[0],
decode=lambda value: HeartRateService.BodySensorLocation(value[0]) decode=lambda value: HeartRateService.BodySensorLocation(value[0]),
) )
else: else:
self.body_sensor_location = None self.body_sensor_location = None
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC): if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC
):
self.heart_rate_control_point = PackedCharacteristicAdapter( self.heart_rate_control_point = PackedCharacteristicAdapter(
characteristics[0], characteristics[0],
format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
) )
else: else:
self.heart_rate_control_point = None self.heart_rate_control_point = None
async def reset_energy_expended(self): async def reset_energy_expended(self):
if self.heart_rate_control_point is not None: if self.heart_rate_control_point is not None:
return await self.heart_rate_control_point.write_value(HeartRateService.RESET_ENERGY_EXPENDED) return await self.heart_rate_control_point.write_value(
HeartRateService.RESET_ENERGY_EXPENDED
)
+102 -89
View File
@@ -21,7 +21,7 @@ import asyncio
from colors import color from colors import color
from pyee import EventEmitter from pyee import EventEmitter
from .core import InvalidStateError, ProtocolError, ConnectionError from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError, ConnectionError
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -32,6 +32,8 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
RFCOMM_PSM = 0x0003 RFCOMM_PSM = 0x0003
@@ -98,6 +100,8 @@ RFCOMM_DEFAULT_PREFERRED_MTU = 1280
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1 RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30 RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# fmt: on
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def fcs(buffer): def fcs(buffer):
@@ -150,7 +154,10 @@ class RFCOMM_Frame:
@staticmethod @staticmethod
def make_mcc(type, c_r, data): def make_mcc(type, c_r, data):
return bytes([(type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1]) + data return (
bytes([(type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1])
+ data
)
@staticmethod @staticmethod
def sabm(c_r, dlci): def sabm(c_r, dlci):
@@ -170,7 +177,9 @@ class RFCOMM_Frame:
@staticmethod @staticmethod
def uih(c_r, dlci, information, p_f=0): def uih(c_r, dlci, information, p_f=0):
return RFCOMM_Frame(RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits = (p_f == 1)) return RFCOMM_Frame(
RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
)
@staticmethod @staticmethod
def from_bytes(data): def from_bytes(data):
@@ -197,7 +206,12 @@ class RFCOMM_Frame:
return frame return frame
def __bytes__(self): def __bytes__(self):
return bytes([self.address, self.control]) + self.length + self.information + bytes([self.fcs]) return (
bytes([self.address, self.control])
+ self.length
+ self.information
+ bytes([self.fcs])
)
def __str__(self): def __str__(self):
return f'{color(self.type_name(), "yellow")}(c/r={self.c_r},dlci={self.dlci},p/f={self.p_f},length={len(self.information)},fcs=0x{self.fcs:02X})' return f'{color(self.type_name(), "yellow")}(c/r={self.c_r},dlci={self.dlci},p/f={self.p_f},length={len(self.information)},fcs=0x{self.fcs:02X})'
@@ -205,7 +219,16 @@ class RFCOMM_Frame:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class RFCOMM_MCC_PN: class RFCOMM_MCC_PN:
def __init__(self, dlci, cl, priority, ack_timer, max_frame_size, max_retransmissions, window_size): def __init__(
self,
dlci,
cl,
priority,
ack_timer,
max_frame_size,
max_retransmissions,
window_size,
):
self.dlci = dlci self.dlci = dlci
self.cl = cl self.cl = cl
self.priority = priority self.priority = priority
@@ -223,11 +246,12 @@ class RFCOMM_MCC_PN:
ack_timer=data[3], ack_timer=data[3],
max_frame_size=data[4] | data[5] << 8, max_frame_size=data[4] | data[5] << 8,
max_retransmissions=data[6], max_retransmissions=data[6],
window_size = data[7] window_size=data[7],
) )
def __bytes__(self): def __bytes__(self):
return bytes([ return bytes(
[
self.dlci & 0xFF, self.dlci & 0xFF,
self.cl & 0xFF, self.cl & 0xFF,
self.priority & 0xFF, self.priority & 0xFF,
@@ -235,8 +259,9 @@ class RFCOMM_MCC_PN:
self.max_frame_size & 0xFF, self.max_frame_size & 0xFF,
(self.max_frame_size >> 8) & 0xFF, (self.max_frame_size >> 8) & 0xFF,
self.max_retransmissions & 0xFF, self.max_retransmissions & 0xFF,
self.window_size & 0xFF self.window_size & 0xFF,
]) ]
)
def __str__(self): def __str__(self):
return f'PN(dlci={self.dlci},cl={self.cl},priority={self.priority},ack_timer={self.ack_timer},max_frame_size={self.max_frame_size},max_retransmissions={self.max_retransmissions},window_size={self.window_size})' return f'PN(dlci={self.dlci},cl={self.cl},priority={self.priority},ack_timer={self.ack_timer},max_frame_size={self.max_frame_size},max_retransmissions={self.max_retransmissions},window_size={self.window_size})'
@@ -260,14 +285,21 @@ class RFCOMM_MCC_MSC:
rtc=data[1] >> 2 & 1, rtc=data[1] >> 2 & 1,
rtr=data[1] >> 3 & 1, rtr=data[1] >> 3 & 1,
ic=data[1] >> 6 & 1, ic=data[1] >> 6 & 1,
dv = data[1] >> 7 & 1 dv=data[1] >> 7 & 1,
) )
def __bytes__(self): def __bytes__(self):
return bytes([ return bytes(
[
(self.dlci << 2) | 3, (self.dlci << 2) | 3,
1 | self.fc << 1 | self.rtc << 2 | self.rtr << 3 | self.ic << 6 | self.dv << 7 1
]) | self.fc << 1
| self.rtc << 2
| self.rtr << 3
| self.ic << 6
| self.dv << 7,
]
)
def __str__(self): def __str__(self):
return f'MSC(dlci={self.dlci},fc={self.fc},rtc={self.rtc},rtr={self.rtr},ic={self.ic},dv={self.dv})' return f'MSC(dlci={self.dlci},fc={self.fc},rtc={self.rtc},rtr={self.rtr},ic={self.ic},dv={self.dv})'
@@ -289,7 +321,7 @@ class DLC(EventEmitter):
CONNECTED: 'CONNECTED', CONNECTED: 'CONNECTED',
DISCONNECTING: 'DISCONNECTING', DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED', DISCONNECTED: 'DISCONNECTED',
RESET: 'RESET' RESET: 'RESET',
} }
def __init__(self, multiplexer, dlci, max_frame_size, initial_tx_credits): def __init__(self, multiplexer, dlci, max_frame_size, initial_tx_credits):
@@ -307,14 +339,18 @@ class DLC(EventEmitter):
# Compute the MTU # Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs max_overhead = 4 + 1 # header with 2-byte length + fcs
self.mtu = min(max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead) self.mtu = min(
max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead
)
@staticmethod @staticmethod
def state_name(state): def state_name(state):
return DLC.STATE_NAMES[state] return DLC.STATE_NAMES[state]
def change_state(self, new_state): def change_state(self, new_state):
logger.debug(f'{self} state change -> {color(self.state_name(new_state), "magenta")}') logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "magenta")}'
)
self.state = new_state self.state = new_state
def send_frame(self, frame): def send_frame(self, frame):
@@ -332,23 +368,10 @@ class DLC(EventEmitter):
self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci)) self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
# Exchange the modem status with the peer # Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC( msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
dlci = self.dlci,
fc = 0,
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
)
mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)) mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}') logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame( self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.change_state(DLC.CONNECTED) self.change_state(DLC.CONNECTED)
self.emit('open') self.emit('open')
@@ -359,23 +382,10 @@ class DLC(EventEmitter):
return return
# Exchange the modem status with the peer # Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC( msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
dlci = self.dlci,
fc = 0,
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
)
mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)) mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}') logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame( self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.change_state(DLC.CONNECTED) self.change_state(DLC.CONNECTED)
self.multiplexer.on_dlc_open_complete(self) self.multiplexer.on_dlc_open_complete(self)
@@ -395,10 +405,14 @@ class DLC(EventEmitter):
credits = frame.information[0] credits = frame.information[0]
self.tx_credits += credits self.tx_credits += credits
logger.debug(f'<<< Credits [{self.dlci}]: received {credits}, total={self.tx_credits}') logger.debug(
f'<<< Credits [{self.dlci}]: received {credits}, total={self.tx_credits}'
)
data = data[1:] data = data[1:]
logger.debug(f'{color("<<< Data", "yellow")} [{self.dlci}] {len(data)} bytes, rx_credits={self.rx_credits}: {data.hex()}') logger.debug(
f'{color("<<< Data", "yellow")} [{self.dlci}] {len(data)} bytes, rx_credits={self.rx_credits}: {data.hex()}'
)
if len(data) and self.sink: if len(data) and self.sink:
self.sink(data) self.sink(data)
@@ -418,23 +432,12 @@ class DLC(EventEmitter):
if c_r: if c_r:
# Command # Command
logger.debug(f'<<< MCC MSC Command: {msc}') logger.debug(f'<<< MCC MSC Command: {msc}')
msc = RFCOMM_MCC_MSC( msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
dlci = self.dlci, mcc = RFCOMM_Frame.make_mcc(
fc = 0, type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc)
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
) )
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 0, data = bytes(msc))
logger.debug(f'>>> MCC MSC Response: {msc}') logger.debug(f'>>> MCC MSC Response: {msc}')
self.send_frame( self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
else: else:
# Response # Response
logger.debug(f'<<< MCC MSC Response: {msc}') logger.debug(f'<<< MCC MSC Response: {msc}')
@@ -445,12 +448,7 @@ class DLC(EventEmitter):
self.change_state(DLC.CONNECTING) self.change_state(DLC.CONNECTING)
self.connection_result = asyncio.get_running_loop().create_future() self.connection_result = asyncio.get_running_loop().create_future()
self.send_frame( self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))
RFCOMM_Frame.sabm(
c_r = self.c_r,
dlci = self.dlci
)
)
def accept(self): def accept(self):
if not self.state == DLC.INIT: if not self.state == DLC.INIT:
@@ -463,17 +461,11 @@ class DLC(EventEmitter):
ack_timer=0, ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU, max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions=0, max_retransmissions=0,
window_size = RFCOMM_DEFAULT_INITIAL_RX_CREDITS window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
) )
mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn)) mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
logger.debug(f'>>> PN Response: {pn}') logger.debug(f'>>> PN Response: {pn}')
self.send_frame( self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.change_state(DLC.CONNECTING) self.change_state(DLC.CONNECTING)
def rx_credits_needed(self): def rx_credits_needed(self):
@@ -491,7 +483,7 @@ class DLC(EventEmitter):
chunk = bytes([rx_credits_needed]) + self.tx_buffer[: self.mtu - 1] chunk = bytes([rx_credits_needed]) + self.tx_buffer[: self.mtu - 1]
self.tx_buffer = self.tx_buffer[len(chunk) - 1 :] self.tx_buffer = self.tx_buffer[len(chunk) - 1 :]
self.rx_credits += rx_credits_needed self.rx_credits += rx_credits_needed
tx_credit_spent = (len(chunk) > 1) tx_credit_spent = len(chunk) > 1
else: else:
chunk = self.tx_buffer[: self.mtu] chunk = self.tx_buffer[: self.mtu]
self.tx_buffer = self.tx_buffer[len(chunk) :] self.tx_buffer = self.tx_buffer[len(chunk) :]
@@ -503,13 +495,15 @@ class DLC(EventEmitter):
self.tx_credits -= 1 self.tx_credits -= 1
# Send the frame # Send the frame
logger.debug(f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, rx_credits={self.rx_credits}, tx_credits={self.tx_credits}') logger.debug(
f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, rx_credits={self.rx_credits}, tx_credits={self.tx_credits}'
)
self.send_frame( self.send_frame(
RFCOMM_Frame.uih( RFCOMM_Frame.uih(
c_r=self.c_r, c_r=self.c_r,
dlci=self.dlci, dlci=self.dlci,
information=chunk, information=chunk,
p_f = 1 if rx_credits_needed > 0 else 0 p_f=1 if rx_credits_needed > 0 else 0,
) )
) )
@@ -558,7 +552,7 @@ class Multiplexer(EventEmitter):
OPENING: 'OPENING', OPENING: 'OPENING',
DISCONNECTING: 'DISCONNECTING', DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED', DISCONNECTED: 'DISCONNECTED',
RESET: 'RESET' RESET: 'RESET',
} }
def __init__(self, l2cap_channel, role): def __init__(self, l2cap_channel, role):
@@ -580,7 +574,9 @@ class Multiplexer(EventEmitter):
return Multiplexer.STATE_NAMES[state] return Multiplexer.STATE_NAMES[state]
def change_state(self, new_state): def change_state(self, new_state):
logger.debug(f'{self} state change -> {color(self.state_name(new_state), "cyan")}') logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
)
self.state = new_state self.state = new_state
def send_frame(self, frame): def send_frame(self, frame):
@@ -634,13 +630,22 @@ class Multiplexer(EventEmitter):
if self.state == Multiplexer.OPENING: if self.state == Multiplexer.OPENING:
self.change_state(Multiplexer.CONNECTED) self.change_state(Multiplexer.CONNECTED)
if self.open_result: if self.open_result:
self.open_result.set_exception(ConnectionError(ConnectionError.CONNECTION_REFUSED)) self.open_result.set_exception(
ConnectionError(
ConnectionError.CONNECTION_REFUSED,
BT_BR_EDR_TRANSPORT,
self.l2cap_channel.connection.peer_address,
'rfcomm',
)
)
else: else:
logger.warn(f'unexpected state for DM: {self}') logger.warn(f'unexpected state for DM: {self}')
def on_disc_frame(self, frame): def on_disc_frame(self, frame):
self.change_state(Multiplexer.DISCONNECTED) self.change_state(Multiplexer.DISCONNECTED)
self.send_frame(RFCOMM_Frame.ua(c_r = 0 if self.role == Multiplexer.INITIATOR else 1, dlci = 0)) self.send_frame(
RFCOMM_Frame.ua(c_r=0 if self.role == Multiplexer.INITIATOR else 1, dlci=0)
)
def on_uih_frame(self, frame): def on_uih_frame(self, frame):
(type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information) (type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
@@ -716,7 +721,11 @@ class Multiplexer(EventEmitter):
self.disconnection_result = asyncio.get_running_loop().create_future() self.disconnection_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.DISCONNECTING) self.change_state(Multiplexer.DISCONNECTING)
self.send_frame(RFCOMM_Frame.disc(c_r = 1 if self.role == Multiplexer.INITIATOR else 0, dlci = 0)) self.send_frame(
RFCOMM_Frame.disc(
c_r=1 if self.role == Multiplexer.INITIATOR else 0, dlci=0
)
)
await self.disconnection_result await self.disconnection_result
async def open_dlc(self, channel): async def open_dlc(self, channel):
@@ -733,7 +742,7 @@ class Multiplexer(EventEmitter):
ack_timer=0, ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU, max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions=0, max_retransmissions=0,
window_size = RFCOMM_DEFAULT_INITIAL_RX_CREDITS window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
) )
mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn)) mcc = RFCOMM_Frame.make_mcc(type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
logger.debug(f'>>> Sending MCC: {pn}') logger.debug(f'>>> Sending MCC: {pn}')
@@ -743,7 +752,7 @@ class Multiplexer(EventEmitter):
RFCOMM_Frame.uih( RFCOMM_Frame.uih(
c_r=1 if self.role == Multiplexer.INITIATOR else 0, c_r=1 if self.role == Multiplexer.INITIATOR else 0,
dlci=0, dlci=0,
information = mcc information=mcc,
) )
) )
result = await self.open_result result = await self.open_result
@@ -771,7 +780,9 @@ class Client:
async def start(self): async def start(self):
# Create a new L2CAP connection # Create a new L2CAP connection
try: try:
self.l2cap_channel = await self.device.l2cap_channel_manager.connect(self.connection, RFCOMM_PSM) self.l2cap_channel = await self.device.l2cap_channel_manager.connect(
self.connection, RFCOMM_PSM
)
except ProtocolError as error: except ProtocolError as error:
logger.warn(f'L2CAP connection failed: {error}') logger.warn(f'L2CAP connection failed: {error}')
raise raise
@@ -806,7 +817,9 @@ class Server(EventEmitter):
def listen(self, acceptor): def listen(self, acceptor):
# Find a free channel number # Find a free channel number
for channel in range(RFCOMM_DYNAMIC_CHANNEL_NUMBER_START, RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1): for channel in range(
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START, RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1
):
if channel not in self.acceptors: if channel not in self.acceptors:
self.acceptors[channel] = acceptor self.acceptors[channel] = acceptor
return channel return channel
+154 -72
View File
@@ -33,6 +33,8 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
SDP_CONTINUATION_WATCHDOG = 64 # Maximum number of continuations we're willing to do SDP_CONTINUATION_WATCHDOG = 64 # Maximum number of continuations we're willing to do
SDP_PSM = 0x0001 SDP_PSM = 0x0001
@@ -112,6 +114,8 @@ SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
# To be used in searches where an attribute ID list allows a range to be specified # To be used in searches where an attribute ID list allows a range to be specified
SDP_ALL_ATTRIBUTES_RANGE = (0x0000FFFF, 4) # Express this as tuple so we can convey the desired encoding size SDP_ALL_ATTRIBUTES_RANGE = (0x0000FFFF, 4) # Express this as tuple so we can convey the desired encoding size
# fmt: on
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class DataElement: class DataElement:
@@ -134,19 +138,33 @@ class DataElement:
BOOLEAN: 'BOOLEAN', BOOLEAN: 'BOOLEAN',
SEQUENCE: 'SEQUENCE', SEQUENCE: 'SEQUENCE',
ALTERNATIVE: 'ALTERNATIVE', ALTERNATIVE: 'ALTERNATIVE',
URL: 'URL' URL: 'URL',
} }
type_constructors = { type_constructors = {
NIL: lambda x: DataElement(DataElement.NIL, None), NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(DataElement.UNSIGNED_INTEGER, DataElement.unsigned_integer_from_bytes(x), value_size=y), UNSIGNED_INTEGER: lambda x, y: DataElement(
SIGNED_INTEGER: lambda x, y: DataElement(DataElement.SIGNED_INTEGER, DataElement.signed_integer_from_bytes(x), value_size=y), DataElement.UNSIGNED_INTEGER,
UUID: lambda x: DataElement(DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))), DataElement.unsigned_integer_from_bytes(x),
value_size=y,
),
SIGNED_INTEGER: lambda x, y: DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(x),
value_size=y,
),
UUID: lambda x: DataElement(
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')), TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1), BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(DataElement.SEQUENCE, DataElement.list_from_bytes(x)), SEQUENCE: lambda x: DataElement(
ALTERNATIVE: lambda x: DataElement(DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)), DataElement.SEQUENCE, DataElement.list_from_bytes(x)
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')) ),
ALTERNATIVE: lambda x: DataElement(
DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
),
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
} }
def __init__(self, type, value, value_size=None): def __init__(self, type, value, value_size=None):
@@ -289,13 +307,18 @@ class DataElement:
value_data = data[1 + value_offset : 1 + value_offset + value_size] value_data = data[1 + value_offset : 1 + value_offset + value_size]
constructor = DataElement.type_constructors.get(type) constructor = DataElement.type_constructors.get(type)
if constructor: if constructor:
if type == DataElement.UNSIGNED_INTEGER or type == DataElement.SIGNED_INTEGER: if (
type == DataElement.UNSIGNED_INTEGER
or type == DataElement.SIGNED_INTEGER
):
result = constructor(value_data, value_size) result = constructor(value_data, value_size)
else: else:
result = constructor(value_data) result = constructor(value_data)
else: else:
result = DataElement(type, value_data) result = DataElement(type, value_data)
result.bytes = data[:1 + value_offset + value_size] # Keep a copy so we can re-serialize to an exact replica result.bytes = data[
: 1 + value_offset + value_size
] # Keep a copy so we can re-serialize to an exact replica
return result return result
def to_bytes(self): def to_bytes(self):
@@ -349,9 +372,11 @@ class DataElement:
if size != 0: if size != 0:
raise ValueError('NIL must be empty') raise ValueError('NIL must be empty')
size_index = 0 size_index = 0
elif (self.type == DataElement.UNSIGNED_INTEGER or elif (
self.type == DataElement.SIGNED_INTEGER or self.type == DataElement.UNSIGNED_INTEGER
self.type == DataElement.UUID): or self.type == DataElement.SIGNED_INTEGER
or self.type == DataElement.UUID
):
if size <= 1: if size <= 1:
size_index = 0 size_index = 0
elif size == 2: elif size == 2:
@@ -364,10 +389,12 @@ class DataElement:
size_index = 4 size_index = 4
else: else:
raise ValueError('invalid data size') raise ValueError('invalid data size')
elif (self.type == DataElement.TEXT_STRING or elif (
self.type == DataElement.SEQUENCE or self.type == DataElement.TEXT_STRING
self.type == DataElement.ALTERNATIVE or or self.type == DataElement.SEQUENCE
self.type == DataElement.URL): or self.type == DataElement.ALTERNATIVE
or self.type == DataElement.URL
):
if size <= 0xFF: if size <= 0xFF:
size_index = 5 size_index = 5
size_bytes = bytes([size]) size_bytes = bytes([size])
@@ -396,7 +423,10 @@ class DataElement:
container_separator = '\n' if pretty else '' container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ',' element_separator = '\n' if pretty else ','
value_string = f'[{container_separator}{element_separator.join([element.to_string(pretty, indentation + 1 if pretty else 0) for element in self.value])}{container_separator}{prefix}]' value_string = f'[{container_separator}{element_separator.join([element.to_string(pretty, indentation + 1 if pretty else 0) for element in self.value])}{container_separator}{prefix}]'
elif self.type == DataElement.UNSIGNED_INTEGER or self.type == DataElement.SIGNED_INTEGER: elif (
self.type == DataElement.UNSIGNED_INTEGER
or self.type == DataElement.SIGNED_INTEGER
):
value_string = f'{self.value}#{self.value_size}' value_string = f'{self.value}#{self.value_size}'
elif isinstance(self.value, DataElement): elif isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation) value_string = self.value.to_string(pretty, indentation)
@@ -428,7 +458,14 @@ class ServiceAttribute:
@staticmethod @staticmethod
def find_attribute_in_list(attribute_list, attribute_id): def find_attribute_in_list(attribute_list, attribute_id):
return next((attribute.value for attribute in attribute_list if attribute.id == attribute_id), None) return next(
(
attribute.value
for attribute in attribute_list
if attribute.id == attribute_id
),
None,
)
@staticmethod @staticmethod
def id_name(id): def id_name(id):
@@ -462,6 +499,7 @@ class SDP_PDU:
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
''' '''
sdp_pdu_classes = {} sdp_pdu_classes = {}
@staticmethod @staticmethod
@@ -484,7 +522,9 @@ class SDP_PDU:
@staticmethod @staticmethod
def parse_service_record_handle_list_preceded_by_count(data, offset): def parse_service_record_handle_list_preceded_by_count(data, offset):
count = struct.unpack_from('>H', data, offset - 2)[0] count = struct.unpack_from('>H', data, offset - 2)[0]
handle_list = [struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)] handle_list = [
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
]
return offset + count * 4, handle_list return offset + count * 4, handle_list
@staticmethod @staticmethod
@@ -532,7 +572,10 @@ class SDP_PDU:
HCI_Object.init_from_fields(self, self.fields, kwargs) HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None: if pdu is None:
parameters = HCI_Object.dict_to_bytes(kwargs, self.fields) parameters = HCI_Object.dict_to_bytes(kwargs, self.fields)
pdu = struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters)) + parameters pdu = (
struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters))
+ parameters
)
self.pdu = pdu self.pdu = pdu
self.transaction_id = transaction_id self.transaction_id = transaction_id
@@ -555,9 +598,7 @@ class SDP_PDU:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass([ @SDP_PDU.subclass([('error_code', {'size': 2, 'mapper': SDP_PDU.error_name})])
('error_code', {'size': 2, 'mapper': SDP_PDU.error_name})
])
class SDP_ErrorResponse(SDP_PDU): class SDP_ErrorResponse(SDP_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
@@ -565,11 +606,13 @@ class SDP_ErrorResponse(SDP_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass([ @SDP_PDU.subclass(
[
('service_search_pattern', DataElement.parse_from_bytes), ('service_search_pattern', DataElement.parse_from_bytes),
('maximum_service_record_count', '>2'), ('maximum_service_record_count', '>2'),
('continuation_state', '*') ('continuation_state', '*'),
]) ]
)
class SDP_ServiceSearchRequest(SDP_PDU): class SDP_ServiceSearchRequest(SDP_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
@@ -577,12 +620,17 @@ class SDP_ServiceSearchRequest(SDP_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass([ @SDP_PDU.subclass(
[
('total_service_record_count', '>2'), ('total_service_record_count', '>2'),
('current_service_record_count', '>2'), ('current_service_record_count', '>2'),
('service_record_handle_list', SDP_PDU.parse_service_record_handle_list_preceded_by_count), (
('continuation_state', '*') 'service_record_handle_list',
]) SDP_PDU.parse_service_record_handle_list_preceded_by_count,
),
('continuation_state', '*'),
]
)
class SDP_ServiceSearchResponse(SDP_PDU): class SDP_ServiceSearchResponse(SDP_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
@@ -590,12 +638,14 @@ class SDP_ServiceSearchResponse(SDP_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass([ @SDP_PDU.subclass(
[
('service_record_handle', '>4'), ('service_record_handle', '>4'),
('maximum_attribute_byte_count', '>2'), ('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes), ('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*') ('continuation_state', '*'),
]) ]
)
class SDP_ServiceAttributeRequest(SDP_PDU): class SDP_ServiceAttributeRequest(SDP_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
@@ -603,11 +653,13 @@ class SDP_ServiceAttributeRequest(SDP_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass([ @SDP_PDU.subclass(
[
('attribute_list_byte_count', '>2'), ('attribute_list_byte_count', '>2'),
('attribute_list', SDP_PDU.parse_bytes_preceded_by_length), ('attribute_list', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*') ('continuation_state', '*'),
]) ]
)
class SDP_ServiceAttributeResponse(SDP_PDU): class SDP_ServiceAttributeResponse(SDP_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
@@ -615,12 +667,14 @@ class SDP_ServiceAttributeResponse(SDP_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass([ @SDP_PDU.subclass(
[
('service_search_pattern', DataElement.parse_from_bytes), ('service_search_pattern', DataElement.parse_from_bytes),
('maximum_attribute_byte_count', '>2'), ('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes), ('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*') ('continuation_state', '*'),
]) ]
)
class SDP_ServiceSearchAttributeRequest(SDP_PDU): class SDP_ServiceSearchAttributeRequest(SDP_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
@@ -628,11 +682,13 @@ class SDP_ServiceSearchAttributeRequest(SDP_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@SDP_PDU.subclass([ @SDP_PDU.subclass(
[
('attribute_lists_byte_count', '>2'), ('attribute_lists_byte_count', '>2'),
('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length), ('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*') ('continuation_state', '*'),
]) ]
)
class SDP_ServiceSearchAttributeResponse(SDP_PDU): class SDP_ServiceSearchAttributeResponse(SDP_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
@@ -659,7 +715,9 @@ class Client:
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
service_search_pattern = DataElement.sequence([DataElement.uuid(uuid) for uuid in uuids]) service_search_pattern = DataElement.sequence(
[DataElement.uuid(uuid) for uuid in uuids]
)
# Request and accumulate until there's no more continuation # Request and accumulate until there's no more continuation
service_record_handle_list = [] service_record_handle_list = []
@@ -671,7 +729,7 @@ class Client:
transaction_id=0, # Transaction ID TODO: pick a real value transaction_id=0, # Transaction ID TODO: pick a real value
service_search_pattern=service_search_pattern, service_search_pattern=service_search_pattern,
maximum_service_record_count=0xFFFF, maximum_service_record_count=0xFFFF,
continuation_state = continuation_state continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu) response = SDP_PDU.from_bytes(response_pdu)
@@ -689,10 +747,14 @@ class Client:
if self.pending_request is not None: if self.pending_request is not None:
raise InvalidStateError('request already pending') raise InvalidStateError('request already pending')
service_search_pattern = DataElement.sequence([DataElement.uuid(uuid) for uuid in uuids]) service_search_pattern = DataElement.sequence(
[DataElement.uuid(uuid) for uuid in uuids]
)
attribute_id_list = DataElement.sequence( attribute_id_list = DataElement.sequence(
[ [
DataElement.unsigned_integer(attribute_id[0], value_size=attribute_id[1]) DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
)
if type(attribute_id) is tuple if type(attribute_id) is tuple
else DataElement.unsigned_integer_16(attribute_id) else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids for attribute_id in attribute_ids
@@ -710,7 +772,7 @@ class Client:
service_search_pattern=service_search_pattern, service_search_pattern=service_search_pattern,
maximum_attribute_byte_count=0xFFFF, maximum_attribute_byte_count=0xFFFF,
attribute_id_list=attribute_id_list, attribute_id_list=attribute_id_list,
continuation_state = continuation_state continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu) response = SDP_PDU.from_bytes(response_pdu)
@@ -740,7 +802,9 @@ class Client:
attribute_id_list = DataElement.sequence( attribute_id_list = DataElement.sequence(
[ [
DataElement.unsigned_integer(attribute_id[0], value_size=attribute_id[1]) DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
)
if type(attribute_id) is tuple if type(attribute_id) is tuple
else DataElement.unsigned_integer_16(attribute_id) else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids for attribute_id in attribute_ids
@@ -758,7 +822,7 @@ class Client:
service_record_handle=service_record_handle, service_record_handle=service_record_handle,
maximum_attribute_byte_count=0xFFFF, maximum_attribute_byte_count=0xFFFF,
attribute_id_list=attribute_id_list, attribute_id_list=attribute_id_list,
continuation_state = continuation_state continuation_state=continuation_state,
) )
) )
response = SDP_PDU.from_bytes(response_pdu) response = SDP_PDU.from_bytes(response_pdu)
@@ -823,8 +887,7 @@ class Server:
logger.warn(color(f'failed to parse SDP Request PDU: {error}', 'red')) logger.warn(color(f'failed to parse SDP Request PDU: {error}', 'red'))
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id = 0, transaction_id=0, error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR
error_code = SDP_INVALID_REQUEST_SYNTAX_ERROR
) )
) )
@@ -841,7 +904,7 @@ class Server:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id=sdp_pdu.transaction_id, transaction_id=sdp_pdu.transaction_id,
error_code = SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR error_code=SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR,
) )
) )
else: else:
@@ -849,7 +912,7 @@ class Server:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id=sdp_pdu.transaction_id, transaction_id=sdp_pdu.transaction_id,
error_code = SDP_INVALID_REQUEST_SYNTAX_ERROR error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR,
) )
) )
@@ -877,7 +940,8 @@ class Server:
id_range_start = attribute_id.value id_range_start = attribute_id.value
id_range_end = attribute_id.value id_range_end = attribute_id.value
attributes += [ attributes += [
attribute for attribute in service attribute
for attribute in service
if attribute.id >= id_range_start and attribute.id <= id_range_end if attribute.id >= id_range_start and attribute.id <= id_range_end
] ]
@@ -897,7 +961,7 @@ class Server:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
) )
) )
return return
@@ -910,30 +974,38 @@ class Server:
service_record_handles = list(matching_services.keys()) service_record_handles = list(matching_services.keys())
# Only return up to the maximum requested # Only return up to the maximum requested
service_record_handles_subset = service_record_handles[:request.maximum_service_record_count] service_record_handles_subset = service_record_handles[
: request.maximum_service_record_count
]
# Serialize to a byte array, and remember the total count # Serialize to a byte array, and remember the total count
logger.debug(f'Service Record Handles: {service_record_handles}') logger.debug(f'Service Record Handles: {service_record_handles}')
self.current_response = ( self.current_response = (
len(service_record_handles), len(service_record_handles),
service_record_handles_subset service_record_handles_subset,
) )
# Respond, keeping any unsent handles for later # Respond, keeping any unsent handles for later
service_record_handles = self.current_response[1][:request.maximum_service_record_count] service_record_handles = self.current_response[1][
: request.maximum_service_record_count
]
self.current_response = ( self.current_response = (
self.current_response[0], self.current_response[0],
self.current_response[1][request.maximum_service_record_count:] self.current_response[1][request.maximum_service_record_count :],
)
continuation_state = (
Server.CONTINUATION_STATE if self.current_response[1] else bytes([0])
)
service_record_handle_list = b''.join(
[struct.pack('>I', handle) for handle in service_record_handles]
) )
continuation_state = Server.CONTINUATION_STATE if self.current_response[1] else bytes([0])
service_record_handle_list = b''.join([struct.pack('>I', handle) for handle in service_record_handles])
self.send_response( self.send_response(
SDP_ServiceSearchResponse( SDP_ServiceSearchResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
total_service_record_count=self.current_response[0], total_service_record_count=self.current_response[0],
current_service_record_count=len(service_record_handles), current_service_record_count=len(service_record_handles),
service_record_handle_list=service_record_handle_list, service_record_handle_list=service_record_handle_list,
continuation_state = continuation_state continuation_state=continuation_state,
) )
) )
@@ -944,7 +1016,7 @@ class Server:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
) )
) )
return return
@@ -958,26 +1030,30 @@ class Server:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
error_code = SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR error_code=SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR,
) )
) )
return return
# Get the attributes for the service # Get the attributes for the service
attribute_list = Server.get_service_attributes(service, request.attribute_id_list.value) attribute_list = Server.get_service_attributes(
service, request.attribute_id_list.value
)
# Serialize to a byte array # Serialize to a byte array
logger.debug(f'Attributes: {attribute_list}') logger.debug(f'Attributes: {attribute_list}')
self.current_response = bytes(attribute_list) self.current_response = bytes(attribute_list)
# Respond, keeping any pending chunks for later # Respond, keeping any pending chunks for later
attribute_list, continuation_state = self.get_next_response_payload(request.maximum_attribute_byte_count) attribute_list, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count
)
self.send_response( self.send_response(
SDP_ServiceAttributeResponse( SDP_ServiceAttributeResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
attribute_list_byte_count=len(attribute_list), attribute_list_byte_count=len(attribute_list),
attribute_list=attribute_list, attribute_list=attribute_list,
continuation_state = continuation_state continuation_state=continuation_state,
) )
) )
@@ -988,7 +1064,7 @@ class Server:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
) )
) )
else: else:
@@ -996,12 +1072,16 @@ class Server:
self.current_response = None self.current_response = None
# Find the matching services # Find the matching services
matching_services = self.match_services(request.service_search_pattern).values() matching_services = self.match_services(
request.service_search_pattern
).values()
# Filter the required attributes # Filter the required attributes
attribute_lists = DataElement.sequence([]) attribute_lists = DataElement.sequence([])
for service in matching_services: for service in matching_services:
attribute_list = Server.get_service_attributes(service, request.attribute_id_list.value) attribute_list = Server.get_service_attributes(
service, request.attribute_id_list.value
)
if attribute_list.value: if attribute_list.value:
attribute_lists.value.append(attribute_list) attribute_lists.value.append(attribute_list)
@@ -1010,12 +1090,14 @@ class Server:
self.current_response = bytes(attribute_lists) self.current_response = bytes(attribute_lists)
# Respond, keeping any pending chunks for later # Respond, keeping any pending chunks for later
attribute_lists, continuation_state = self.get_next_response_payload(request.maximum_attribute_byte_count) attribute_lists, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count
)
self.send_response( self.send_response(
SDP_ServiceSearchAttributeResponse( SDP_ServiceSearchAttributeResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
attribute_lists_byte_count=len(attribute_lists), attribute_lists_byte_count=len(attribute_lists),
attribute_lists=attribute_lists, attribute_lists=attribute_lists,
continuation_state = continuation_state continuation_state=continuation_state,
) )
) )
+330 -204
View File
File diff suppressed because it is too large Load Diff
+13
View File
@@ -38,42 +38,55 @@ async def open_transport(name):
scheme, *spec = name.split(':', 1) scheme, *spec = name.split(':', 1)
if scheme == 'serial' and spec: if scheme == 'serial' and spec:
from .serial import open_serial_transport from .serial import open_serial_transport
return await open_serial_transport(spec[0]) return await open_serial_transport(spec[0])
elif scheme == 'udp' and spec: elif scheme == 'udp' and spec:
from .udp import open_udp_transport from .udp import open_udp_transport
return await open_udp_transport(spec[0]) return await open_udp_transport(spec[0])
elif scheme == 'tcp-client' and spec: elif scheme == 'tcp-client' and spec:
from .tcp_client import open_tcp_client_transport from .tcp_client import open_tcp_client_transport
return await open_tcp_client_transport(spec[0]) return await open_tcp_client_transport(spec[0])
elif scheme == 'tcp-server' and spec: elif scheme == 'tcp-server' and spec:
from .tcp_server import open_tcp_server_transport from .tcp_server import open_tcp_server_transport
return await open_tcp_server_transport(spec[0]) return await open_tcp_server_transport(spec[0])
elif scheme == 'ws-client' and spec: elif scheme == 'ws-client' and spec:
from .ws_client import open_ws_client_transport from .ws_client import open_ws_client_transport
return await open_ws_client_transport(spec[0]) return await open_ws_client_transport(spec[0])
elif scheme == 'ws-server' and spec: elif scheme == 'ws-server' and spec:
from .ws_server import open_ws_server_transport from .ws_server import open_ws_server_transport
return await open_ws_server_transport(spec[0]) return await open_ws_server_transport(spec[0])
elif scheme == 'pty': elif scheme == 'pty':
from .pty import open_pty_transport from .pty import open_pty_transport
return await open_pty_transport(spec[0] if spec else None) return await open_pty_transport(spec[0] if spec else None)
elif scheme == 'file': elif scheme == 'file':
from .file import open_file_transport from .file import open_file_transport
return await open_file_transport(spec[0] if spec else None) return await open_file_transport(spec[0] if spec else None)
elif scheme == 'vhci': elif scheme == 'vhci':
from .vhci import open_vhci_transport from .vhci import open_vhci_transport
return await open_vhci_transport(spec[0] if spec else None) return await open_vhci_transport(spec[0] if spec else None)
elif scheme == 'hci-socket': elif scheme == 'hci-socket':
from .hci_socket import open_hci_socket_transport from .hci_socket import open_hci_socket_transport
return await open_hci_socket_transport(spec[0] if spec else None) return await open_hci_socket_transport(spec[0] if spec else None)
elif scheme == 'usb': elif scheme == 'usb':
from .usb import open_usb_transport from .usb import open_usb_transport
return await open_usb_transport(spec[0] if spec else None) return await open_usb_transport(spec[0] if spec else None)
elif scheme == 'pyusb': elif scheme == 'pyusb':
from .pyusb import open_pyusb_transport from .pyusb import open_pyusb_transport
return await open_pyusb_transport(spec[0] if spec else None) return await open_pyusb_transport(spec[0] if spec else None)
elif scheme == 'android-emulator': elif scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport from .android_emulator import open_android_emulator_transport
return await open_android_emulator_transport(spec[0] if spec else None) return await open_android_emulator_transport(spec[0] if spec else None)
else: else:
raise ValueError('unknown transport scheme') raise ValueError('unknown transport scheme')
+2 -7
View File
@@ -59,12 +59,7 @@ async def open_android_emulator_transport(spec):
return bytes([packet.type]) + packet.packet return bytes([packet.type]) + packet.packet
async def write(self, packet): async def write(self, packet):
await self.hci_device.write( await self.hci_device.write(HCIPacket(type=packet[0], packet=packet[1:]))
HCIPacket(
type = packet[0],
packet = packet[1:]
)
)
# Parse the parameters # Parse the parameters
mode = 'host' mode = 'host'
@@ -100,7 +95,7 @@ async def open_android_emulator_transport(spec):
transport = PumpedTransport( transport = PumpedTransport(
PumpedPacketSource(hci_device.read), PumpedPacketSource(hci_device.read),
PumpedPacketSink(hci_device.write), PumpedPacketSink(hci_device.write),
channel.close channel.close,
) )
transport.start() transport.start()
+14 -6
View File
@@ -36,7 +36,7 @@ HCI_PACKET_INFO = {
hci.HCI_COMMAND_PACKET: (1, 2, 'B'), hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'), hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'), hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
hci.HCI_EVENT_PACKET: (1, 1, 'B') hci.HCI_EVENT_PACKET: (1, 1, 'B'),
} }
@@ -67,6 +67,7 @@ class PacketParser:
''' '''
In-line parser that accepts data and emits 'on_packet' when a full packet has been parsed In-line parser that accepts data and emits 'on_packet' when a full packet has been parsed
''' '''
NEED_TYPE = 0 NEED_TYPE = 0
NEED_LENGTH = 1 NEED_LENGTH = 1
NEED_BODY = 2 NEED_BODY = 2
@@ -95,13 +96,17 @@ class PacketParser:
if self.bytes_needed == 0: if self.bytes_needed == 0:
if self.state == PacketParser.NEED_TYPE: if self.state == PacketParser.NEED_TYPE:
packet_type = self.packet[0] packet_type = self.packet[0]
self.packet_info = HCI_PACKET_INFO.get(packet_type) or self.extended_packet_info.get(packet_type) self.packet_info = HCI_PACKET_INFO.get(
packet_type
) or self.extended_packet_info.get(packet_type)
if self.packet_info is None: if self.packet_info is None:
raise ValueError(f'invalid packet type {packet_type}') raise ValueError(f'invalid packet type {packet_type}')
self.state = PacketParser.NEED_LENGTH self.state = PacketParser.NEED_LENGTH
self.bytes_needed = self.packet_info[0] + self.packet_info[1] self.bytes_needed = self.packet_info[0] + self.packet_info[1]
elif self.state == PacketParser.NEED_LENGTH: elif self.state == PacketParser.NEED_LENGTH:
body_length = struct.unpack_from(self.packet_info[2], self.packet, 1 + self.packet_info[1])[0] body_length = struct.unpack_from(
self.packet_info[2], self.packet, 1 + self.packet_info[1]
)[0]
self.bytes_needed = body_length self.bytes_needed = body_length
self.state = PacketParser.NEED_BODY self.state = PacketParser.NEED_BODY
@@ -111,7 +116,9 @@ class PacketParser:
try: try:
self.sink.on_packet(bytes(self.packet)) self.sink.on_packet(bytes(self.packet))
except Exception as error: except Exception as error:
logger.warning(color(f'!!! Exception in on_packet: {error}', 'red')) logger.warning(
color(f'!!! Exception in on_packet: {error}', 'red')
)
self.reset() self.reset()
def set_packet_sink(self, sink): def set_packet_sink(self, sink):
@@ -187,6 +194,7 @@ class AsyncPipeSink:
''' '''
Sink that forwards packets asynchronously to another sink Sink that forwards packets asynchronously to another sink
''' '''
def __init__(self, sink): def __init__(self, sink):
self.sink = sink self.sink = sink
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
@@ -274,7 +282,7 @@ class PumpedPacketSource(ParserSource):
self.terminated.set_result(error) self.terminated.set_result(error)
break break
self.pump_task = asyncio.get_running_loop().create_task(pump_packets()) self.pump_task = asyncio.create_task(pump_packets())
def close(self): def close(self):
if self.pump_task: if self.pump_task:
@@ -304,7 +312,7 @@ class PumpedPacketSink:
logger.warn(f'exception while sending packet: {error}') logger.warn(f'exception while sending packet: {error}')
break break
self.pump_task = asyncio.get_running_loop().create_task(pump_packets()) self.pump_task = asyncio.create_task(pump_packets())
def close(self): def close(self):
if self.pump_task: if self.pump_task:
@@ -21,24 +21,28 @@ from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3'
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3') )
_HCIPACKET = DESCRIPTOR.message_types_by_name['HCIPacket'] _HCIPACKET = DESCRIPTOR.message_types_by_name['HCIPacket']
_HCIPACKET_PACKETTYPE = _HCIPACKET.enum_types_by_name['PacketType'] _HCIPACKET_PACKETTYPE = _HCIPACKET.enum_types_by_name['PacketType']
HCIPacket = _reflection.GeneratedProtocolMessageType('HCIPacket', (_message.Message,), { HCIPacket = _reflection.GeneratedProtocolMessageType(
'HCIPacket',
(_message.Message,),
{
'DESCRIPTOR': _HCIPACKET, 'DESCRIPTOR': _HCIPACKET,
'__module__': 'emulated_bluetooth_packets_pb2' '__module__': 'emulated_bluetooth_packets_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.HCIPacket) # @@protoc_insertion_point(class_scope:android.emulation.bluetooth.HCIPacket)
}) },
)
_sym_db.RegisterMessage(HCIPacket) _sym_db.RegisterMessage(HCIPacket)
if _descriptor._USE_C_DESCRIPTORS == False: if _descriptor._USE_C_DESCRIPTORS == False:
+10 -4
View File
@@ -21,6 +21,7 @@ from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@@ -29,16 +30,21 @@ _sym_db = _symbol_database.Default()
from . import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2 from . import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x65mulated_bluetooth.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto\"\x19\n\x07RawData\x12\x0e\n\x06packet\x18\x01 \x01(\x0c\x32\xcb\x02\n\x18\x45mulatedBluetoothService\x12\x64\n\x12registerClassicPhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12`\n\x0eregisterBlePhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12g\n\x11registerHCIDevice\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42\"\n\x1e\x63om.android.emulator.bluetoothP\x01\x62\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x18\x65mulated_bluetooth.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto\"\x19\n\x07RawData\x12\x0e\n\x06packet\x18\x01 \x01(\x0c\x32\xcb\x02\n\x18\x45mulatedBluetoothService\x12\x64\n\x12registerClassicPhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12`\n\x0eregisterBlePhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12g\n\x11registerHCIDevice\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42\"\n\x1e\x63om.android.emulator.bluetoothP\x01\x62\x06proto3'
)
_RAWDATA = DESCRIPTOR.message_types_by_name['RawData'] _RAWDATA = DESCRIPTOR.message_types_by_name['RawData']
RawData = _reflection.GeneratedProtocolMessageType('RawData', (_message.Message,), { RawData = _reflection.GeneratedProtocolMessageType(
'RawData',
(_message.Message,),
{
'DESCRIPTOR': _RAWDATA, 'DESCRIPTOR': _RAWDATA,
'__module__': 'emulated_bluetooth_pb2' '__module__': 'emulated_bluetooth_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.RawData) # @@protoc_insertion_point(class_scope:android.emulation.bluetooth.RawData)
}) },
)
_sym_db.RegisterMessage(RawData) _sym_db.RegisterMessage(RawData)
_EMULATEDBLUETOOTHSERVICE = DESCRIPTOR.services_by_name['EmulatedBluetoothService'] _EMULATEDBLUETOOTHSERVICE = DESCRIPTOR.services_by_name['EmulatedBluetoothService']
+53 -16
View File
@@ -138,7 +138,8 @@ def add_EmulatedBluetoothServiceServicer_to_server(servicer, server):
), ),
} }
generic_handler = grpc.method_handlers_generic_handler( generic_handler = grpc.method_handlers_generic_handler(
'android.emulation.bluetooth.EmulatedBluetoothService', rpc_method_handlers) 'android.emulation.bluetooth.EmulatedBluetoothService', rpc_method_handlers
)
server.add_generic_rpc_handlers((generic_handler,)) server.add_generic_rpc_handlers((generic_handler,))
@@ -156,7 +157,8 @@ class EmulatedBluetoothService(object):
""" """
@staticmethod @staticmethod
def registerClassicPhy(request_iterator, def registerClassicPhy(
request_iterator,
target, target,
options=(), options=(),
channel_credentials=None, channel_credentials=None,
@@ -165,15 +167,27 @@ class EmulatedBluetoothService(object):
compression=None, compression=None,
wait_for_ready=None, wait_for_ready=None,
timeout=None, timeout=None,
metadata=None): metadata=None,
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy', ):
return grpc.experimental.stream_stream(
request_iterator,
target,
'/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy',
emulated__bluetooth__pb2.RawData.SerializeToString, emulated__bluetooth__pb2.RawData.SerializeToString,
emulated__bluetooth__pb2.RawData.FromString, emulated__bluetooth__pb2.RawData.FromString,
options, channel_credentials, options,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
@staticmethod @staticmethod
def registerBlePhy(request_iterator, def registerBlePhy(
request_iterator,
target, target,
options=(), options=(),
channel_credentials=None, channel_credentials=None,
@@ -182,15 +196,27 @@ class EmulatedBluetoothService(object):
compression=None, compression=None,
wait_for_ready=None, wait_for_ready=None,
timeout=None, timeout=None,
metadata=None): metadata=None,
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy', ):
return grpc.experimental.stream_stream(
request_iterator,
target,
'/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy',
emulated__bluetooth__pb2.RawData.SerializeToString, emulated__bluetooth__pb2.RawData.SerializeToString,
emulated__bluetooth__pb2.RawData.FromString, emulated__bluetooth__pb2.RawData.FromString,
options, channel_credentials, options,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
@staticmethod @staticmethod
def registerHCIDevice(request_iterator, def registerHCIDevice(
request_iterator,
target, target,
options=(), options=(),
channel_credentials=None, channel_credentials=None,
@@ -199,9 +225,20 @@ class EmulatedBluetoothService(object):
compression=None, compression=None,
wait_for_ready=None, wait_for_ready=None,
timeout=None, timeout=None,
metadata=None): metadata=None,
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice', ):
return grpc.experimental.stream_stream(
request_iterator,
target,
'/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice',
emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
emulated__bluetooth__packets__pb2.HCIPacket.FromString, emulated__bluetooth__packets__pb2.HCIPacket.FromString,
options, channel_credentials, options,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
@@ -21,6 +21,7 @@ from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message from google.protobuf import message as _message
from google.protobuf import reflection as _reflection from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default() _sym_db = _symbol_database.Default()
@@ -29,8 +30,9 @@ _sym_db = _symbol_database.Default()
import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2 import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1d\x65mulated_bluetooth_vhci.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto2y\n\x15VhciForwardingService\x12`\n\nattachVhci\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x1d\x65mulated_bluetooth_vhci.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto2y\n\x15VhciForwardingService\x12`\n\nattachVhci\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3'
)
_VHCIFORWARDINGSERVICE = DESCRIPTOR.services_by_name['VhciForwardingService'] _VHCIFORWARDINGSERVICE = DESCRIPTOR.services_by_name['VhciForwardingService']
@@ -82,7 +82,8 @@ def add_VhciForwardingServiceServicer_to_server(servicer, server):
), ),
} }
generic_handler = grpc.method_handlers_generic_handler( generic_handler = grpc.method_handlers_generic_handler(
'android.emulation.bluetooth.VhciForwardingService', rpc_method_handlers) 'android.emulation.bluetooth.VhciForwardingService', rpc_method_handlers
)
server.add_generic_rpc_handlers((generic_handler,)) server.add_generic_rpc_handlers((generic_handler,))
@@ -97,7 +98,8 @@ class VhciForwardingService(object):
""" """
@staticmethod @staticmethod
def attachVhci(request_iterator, def attachVhci(
request_iterator,
target, target,
options=(), options=(),
channel_credentials=None, channel_credentials=None,
@@ -106,9 +108,20 @@ class VhciForwardingService(object):
compression=None, compression=None,
wait_for_ready=None, wait_for_ready=None,
timeout=None, timeout=None,
metadata=None): metadata=None,
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.VhciForwardingService/attachVhci', ):
return grpc.experimental.stream_stream(
request_iterator,
target,
'/android.emulation.bluetooth.VhciForwardingService/attachVhci',
emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString, emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
emulated__bluetooth__packets__pb2.HCIPacket.FromString, emulated__bluetooth__packets__pb2.HCIPacket.FromString,
options, channel_credentials, options,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
+2 -5
View File
@@ -39,14 +39,12 @@ async def open_file_transport(spec):
# Setup reading # Setup reading
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe( read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(
lambda: StreamPacketSource(), lambda: StreamPacketSource(), file
file
) )
# Setup writing # Setup writing
write_transport, _ = await asyncio.get_running_loop().connect_write_pipe( write_transport, _ = await asyncio.get_running_loop().connect_write_pipe(
lambda: asyncio.BaseProtocol(), lambda: asyncio.BaseProtocol(), file
file
) )
packet_sink = StreamPacketSink(write_transport) packet_sink = StreamPacketSink(write_transport)
@@ -57,4 +55,3 @@ async def open_file_transport(spec):
file.close() file.close()
return FileTransport(packet_source, packet_sink) return FileTransport(packet_source, packet_sink)
+22 -5
View File
@@ -44,7 +44,11 @@ async def open_hci_socket_transport(spec):
# Create a raw HCI socket # Create a raw HCI socket
try: try:
hci_socket = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_RAW | socket.SOCK_NONBLOCK, socket.BTPROTO_HCI) hci_socket = socket.socket(
socket.AF_BLUETOOTH,
socket.SOCK_RAW | socket.SOCK_NONBLOCK,
socket.BTPROTO_HCI,
)
except AttributeError: except AttributeError:
# Not supported on this platform # Not supported on this platform
logger.info("HCI sockets not supported on this platform") logger.info("HCI sockets not supported on this platform")
@@ -67,15 +71,26 @@ async def open_hci_socket_transport(spec):
raise Exception('Bluetooth HCI sockets not supported on this platform') raise Exception('Bluetooth HCI sockets not supported on this platform')
libc.bind.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_char), ctypes.c_int) libc.bind.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_char), ctypes.c_int)
libc.bind.restype = ctypes.c_int libc.bind.restype = ctypes.c_int
bind_address = struct.pack('<HHH', socket.AF_BLUETOOTH, adapter_index, HCI_CHANNEL_USER) bind_address = struct.pack(
if libc.bind(hci_socket.fileno(), ctypes.create_string_buffer(bind_address), len(bind_address)) != 0: '<HHH', socket.AF_BLUETOOTH, adapter_index, HCI_CHANNEL_USER
)
if (
libc.bind(
hci_socket.fileno(),
ctypes.create_string_buffer(bind_address),
len(bind_address),
)
!= 0
):
raise IOError(ctypes.get_errno(), os.strerror(ctypes.get_errno())) raise IOError(ctypes.get_errno(), os.strerror(ctypes.get_errno()))
class HciSocketSource(ParserSource): class HciSocketSource(ParserSource):
def __init__(self, socket): def __init__(self, socket):
super().__init__() super().__init__()
self.socket = socket self.socket = socket
asyncio.get_running_loop().add_reader(socket.fileno(), self.recv_until_would_block) asyncio.get_running_loop().add_reader(
socket.fileno(), self.recv_until_would_block
)
def recv_until_would_block(self): def recv_until_would_block(self):
logger.debug('recv until would block +++') logger.debug('recv until would block +++')
@@ -114,7 +129,9 @@ async def open_hci_socket_transport(spec):
if self.packets: if self.packets:
# There's still something to send, ensure that we are monitoring the socket # There's still something to send, ensure that we are monitoring the socket
if not self.writer_added: if not self.writer_added:
asyncio.get_running_loop().add_writer(socket.fileno(), self.send_until_would_block) asyncio.get_running_loop().add_writer(
socket.fileno(), self.send_until_would_block
)
self.writer_added = True self.writer_added = True
else: else:
# Nothing left to send, stop monitoring the socket # Nothing left to send, stop monitoring the socket
+2 -4
View File
@@ -47,13 +47,11 @@ async def open_pty_transport(spec):
tty.setraw(replica) tty.setraw(replica)
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe( read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(
lambda: StreamPacketSource(), lambda: StreamPacketSource(), io.open(primary, 'rb', closefd=False)
io.open(primary, 'rb', closefd=False)
) )
write_transport, _ = await asyncio.get_running_loop().connect_write_pipe( write_transport, _ = await asyncio.get_running_loop().connect_write_pipe(
lambda: asyncio.BaseProtocol(), lambda: asyncio.BaseProtocol(), io.open(primary, 'wb', closefd=False)
io.open(primary, 'wb', closefd=False)
) )
packet_sink = StreamPacketSink(write_transport) packet_sink = StreamPacketSink(write_transport)
+21 -11
View File
@@ -80,9 +80,17 @@ async def open_pyusb_transport(spec):
if packet_type == hci.HCI_ACL_DATA_PACKET: if packet_type == hci.HCI_ACL_DATA_PACKET:
self.device.write(USB_ENDPOINT_ACL_OUT, packet[1:]) self.device.write(USB_ENDPOINT_ACL_OUT, packet[1:])
elif packet_type == hci.HCI_COMMAND_PACKET: elif packet_type == hci.HCI_COMMAND_PACKET:
self.device.ctrl_transfer(USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, 0, 0, 0, packet[1:]) self.device.ctrl_transfer(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
packet[1:],
)
else: else:
logger.warning(color(f'unsupported packet type {packet_type}', 'red')) logger.warning(
color(f'unsupported packet type {packet_type}', 'red')
)
except usb.core.USBTimeoutError: except usb.core.USBTimeoutError:
logger.warning('USB Write Timeout') logger.warning('USB Write Timeout')
except usb.core.USBError as error: except usb.core.USBError as error:
@@ -109,13 +117,11 @@ async def open_pyusb_transport(spec):
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
self.event_thread = threading.Thread( self.event_thread = threading.Thread(
target=self.run, target=self.run, args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET)
args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET)
) )
self.event_thread.stop_event = None self.event_thread.stop_event = None
self.acl_thread = threading.Thread( self.acl_thread = threading.Thread(
target=self.run, target=self.run, args=(USB_ENDPOINT_ACL_IN, hci.HCI_ACL_DATA_PACKET)
args=(USB_ENDPOINT_ACL_IN, hci.HCI_ACL_DATA_PACKET)
) )
self.acl_thread.stop_event = None self.acl_thread.stop_event = None
@@ -124,7 +130,7 @@ async def open_pyusb_transport(spec):
if sco_enabled: if sco_enabled:
self.sco_thread = threading.Thread( self.sco_thread = threading.Thread(
target=self.run, target=self.run,
args=(USB_ENDPOINT_SCO_IN, hci.HCI_SYNCHRONOUS_DATA_PACKET) args=(USB_ENDPOINT_SCO_IN, hci.HCI_SYNCHRONOUS_DATA_PACKET),
) )
self.sco_thread.stop_event = None self.sco_thread.stop_event = None
@@ -197,15 +203,19 @@ async def open_pyusb_transport(spec):
# Find the device according to the spec moniker # Find the device according to the spec moniker
if ':' in spec: if ':' in spec:
vendor_id, product_id = spec.split(':') vendor_id, product_id = spec.split(':')
device = usb.core.find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16)) device = usb.core.find(
idVendor=int(vendor_id, 16), idProduct=int(product_id, 16)
)
else: else:
device_index = int(spec) device_index = int(spec)
devices = list(usb.core.find( devices = list(
usb.core.find(
find_all=1, find_all=1,
bDeviceClass=USB_DEVICE_CLASS_WIRELESS_CONTROLLER, bDeviceClass=USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
bDeviceSubClass=USB_DEVICE_SUBCLASS_RF_CONTROLLER, bDeviceSubClass=USB_DEVICE_SUBCLASS_RF_CONTROLLER,
bDeviceProtocol = USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER bDeviceProtocol=USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)) )
)
if len(devices) > device_index: if len(devices) > device_index:
device = devices[device_index] device = devices[device_index]
else: else:
+1 -2
View File
@@ -64,9 +64,8 @@ async def open_serial_transport(spec):
device, device,
baudrate=speed, baudrate=speed,
rtscts=rtscts, rtscts=rtscts,
dsrdtr=dsrdtr dsrdtr=dsrdtr,
) )
packet_sink = StreamPacketSink(serial_transport) packet_sink = StreamPacketSink(serial_transport)
return Transport(packet_source, packet_sink) return Transport(packet_source, packet_sink)
+5 -2
View File
@@ -53,10 +53,13 @@ async def open_udp_transport(spec):
local, remote = spec.split(',') local, remote = spec.split(',')
local_host, local_port = local.split(':') local_host, local_port = local.split(':')
remote_host, remote_port = remote.split(':') remote_host, remote_port = remote.split(':')
udp_transport, packet_source = await asyncio.get_running_loop().create_datagram_endpoint( (
udp_transport,
packet_source,
) = await asyncio.get_running_loop().create_datagram_endpoint(
lambda: UdpPacketSource(), lambda: UdpPacketSource(),
local_addr=(local_host, int(local_port)), local_addr=(local_host, int(local_port)),
remote_addr=(remote_host, int(remote_port)) remote_addr=(remote_host, int(remote_port)),
) )
packet_sink = UdpPacketSink(udp_transport) packet_sink = UdpPacketSink(udp_transport)
+71 -35
View File
@@ -72,7 +72,7 @@ async def open_usb_transport(spec):
USB_BT_HCI_CLASS_TUPLE = ( USB_BT_HCI_CLASS_TUPLE = (
USB_DEVICE_CLASS_WIRELESS_CONTROLLER, USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
USB_DEVICE_SUBCLASS_RF_CONTROLLER, USB_DEVICE_SUBCLASS_RF_CONTROLLER,
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
) )
READ_SIZE = 1024 READ_SIZE = 1024
@@ -114,7 +114,9 @@ async def open_usb_transport(spec):
elif status == usb1.TRANSFER_CANCELLED: elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None) self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
else: else:
logger.warning(color(f'!!! out transfer not completed: status={status}', 'red')) logger.warning(
color(f'!!! out transfer not completed: status={status}', 'red')
)
def on_packet_sent_(self): def on_packet_sent_(self):
if self.packets: if self.packets:
@@ -129,17 +131,18 @@ async def open_usb_transport(spec):
packet_type = packet[0] packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET: if packet_type == hci.HCI_ACL_DATA_PACKET:
self.transfer.setBulk( self.transfer.setBulk(
self.acl_out, self.acl_out, packet[1:], callback=self.on_packet_sent
packet[1:],
callback=self.on_packet_sent
) )
logger.debug('submit ACL') logger.debug('submit ACL')
self.transfer.submit() self.transfer.submit()
elif packet_type == hci.HCI_COMMAND_PACKET: elif packet_type == hci.HCI_COMMAND_PACKET:
self.transfer.setControl( self.transfer.setControl(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, 0, 0, 0, USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
packet[1:], packet[1:],
callback=self.on_packet_sent callback=self.on_packet_sent,
) )
logger.debug('submit COMMAND') logger.debug('submit COMMAND')
self.transfer.submit() self.transfer.submit()
@@ -177,7 +180,7 @@ async def open_usb_transport(spec):
self.event_loop_done = self.loop.create_future() self.event_loop_done = self.loop.create_future()
self.cancel_done = { self.cancel_done = {
hci.HCI_EVENT_PACKET: self.loop.create_future(), hci.HCI_EVENT_PACKET: self.loop.create_future(),
hci.HCI_ACL_DATA_PACKET: self.loop.create_future() hci.HCI_ACL_DATA_PACKET: self.loop.create_future(),
} }
# Create a thread to process events # Create a thread to process events
@@ -190,7 +193,7 @@ async def open_usb_transport(spec):
self.events_in, self.events_in,
READ_SIZE, READ_SIZE,
callback=self.on_packet_received, callback=self.on_packet_received,
user_data=hci.HCI_EVENT_PACKET user_data=hci.HCI_EVENT_PACKET,
) )
self.events_in_transfer.submit() self.events_in_transfer.submit()
@@ -199,7 +202,7 @@ async def open_usb_transport(spec):
self.acl_in, self.acl_in,
READ_SIZE, READ_SIZE,
callback=self.on_packet_received, callback=self.on_packet_received,
user_data=hci.HCI_ACL_DATA_PACKET user_data=hci.HCI_ACL_DATA_PACKET,
) )
self.acl_in_transfer.submit() self.acl_in_transfer.submit()
@@ -212,13 +215,20 @@ async def open_usb_transport(spec):
# logger.debug(f'<<< USB IN transfer callback: status={status} packet_type={packet_type} length={transfer.getActualLength()}') # logger.debug(f'<<< USB IN transfer callback: status={status} packet_type={packet_type} length={transfer.getActualLength()}')
if status == usb1.TRANSFER_COMPLETED: if status == usb1.TRANSFER_COMPLETED:
packet = bytes([packet_type]) + transfer.getBuffer()[:transfer.getActualLength()] packet = (
bytes([packet_type])
+ transfer.getBuffer()[: transfer.getActualLength()]
)
self.loop.call_soon_threadsafe(self.queue.put_nowait, packet) self.loop.call_soon_threadsafe(self.queue.put_nowait, packet)
elif status == usb1.TRANSFER_CANCELLED: elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.cancel_done[packet_type].set_result, None) self.loop.call_soon_threadsafe(
self.cancel_done[packet_type].set_result, None
)
return return
else: else:
logger.warning(color(f'!!! transfer not completed: status={status}', 'red')) logger.warning(
color(f'!!! transfer not completed: status={status}', 'red')
)
# Re-submit the transfer so we can receive more data # Re-submit the transfer so we can receive more data
transfer.submit() transfer.submit()
@@ -233,7 +243,10 @@ async def open_usb_transport(spec):
def run(self): def run(self):
logger.debug('starting USB event loop') logger.debug('starting USB event loop')
while self.events_in_transfer.isSubmitted() or self.acl_in_transfer.isSubmitted(): while (
self.events_in_transfer.isSubmitted()
or self.acl_in_transfer.isSubmitted()
):
try: try:
self.context.handleEvents() self.context.handleEvents()
except usb1.USBErrorInterrupted: except usb1.USBErrorInterrupted:
@@ -253,11 +266,15 @@ async def open_usb_transport(spec):
packet_type = transfer.getUserData() packet_type = transfer.getUserData()
try: try:
transfer.cancel() transfer.cancel()
logger.debug(f'waiting for IN[{packet_type}] transfer cancellation to be done...') logger.debug(
f'waiting for IN[{packet_type}] transfer cancellation to be done...'
)
await self.cancel_done[packet_type] await self.cancel_done[packet_type]
logger.debug(f'IN[{packet_type}] transfer cancellation done') logger.debug(f'IN[{packet_type}] transfer cancellation done')
except usb1.USBError: except usb1.USBError:
logger.debug(f'IN[{packet_type}] transfer likely already completed') logger.debug(
f'IN[{packet_type}] transfer likely already completed'
)
# Wait for the thread to terminate # Wait for the thread to terminate
await self.event_loop_done await self.event_loop_done
@@ -315,9 +332,9 @@ async def open_usb_transport(spec):
except usb1.USBError: except usb1.USBError:
device_serial_number = None device_serial_number = None
if ( if (
device.getVendorID() == int(vendor_id, 16) and device.getVendorID() == int(vendor_id, 16)
device.getProductID() == int(product_id, 16) and and device.getProductID() == int(product_id, 16)
(serial_number is None or serial_number == device_serial_number) and (serial_number is None or serial_number == device_serial_number)
): ):
if device_index == 0: if device_index == 0:
found = device found = device
@@ -328,8 +345,11 @@ async def open_usb_transport(spec):
# Look for a compatible device by index # Look for a compatible device by index
def device_is_bluetooth_hci(device): def device_is_bluetooth_hci(device):
# Check if the device class indicates a match # Check if the device class indicates a match
if (device.getDeviceClass(), device.getDeviceSubClass(), device.getDeviceProtocol()) == \ if (
USB_BT_HCI_CLASS_TUPLE: device.getDeviceClass(),
device.getDeviceSubClass(),
device.getDeviceProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True return True
# If the device class is 'Device', look for a matching interface # If the device class is 'Device', look for a matching interface
@@ -337,8 +357,11 @@ async def open_usb_transport(spec):
for configuration in device: for configuration in device:
for interface in configuration: for interface in configuration:
for setting in interface: for setting in interface:
if (setting.getClass(), setting.getSubClass(), setting.getProtocol()) == \ if (
USB_BT_HCI_CLASS_TUPLE: setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True return True
return False return False
@@ -366,8 +389,13 @@ async def open_usb_transport(spec):
setting = None setting = None
for setting in interface: for setting in interface:
if ( if (
not forced_mode and not forced_mode
(setting.getClass(), setting.getSubClass(), setting.getProtocol()) != USB_BT_HCI_CLASS_TUPLE and (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
)
!= USB_BT_HCI_CLASS_TUPLE
): ):
continue continue
@@ -382,22 +410,31 @@ async def open_usb_transport(spec):
acl_in = address acl_in = address
elif acl_out is None: elif acl_out is None:
acl_out = address acl_out = address
elif attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT: elif (
attributes & 0x03
== USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT
):
if address & USB_ENDPOINT_IN and events_in is None: if address & USB_ENDPOINT_IN and events_in is None:
events_in = address events_in = address
# Return if we found all 3 endpoints # Return if we found all 3 endpoints
if acl_in is not None and acl_out is not None and events_in is not None: if (
acl_in is not None
and acl_out is not None
and events_in is not None
):
return ( return (
configuration_index + 1, configuration_index + 1,
setting.getNumber(), setting.getNumber(),
setting.getAlternateSetting(), setting.getAlternateSetting(),
acl_in, acl_in,
acl_out, acl_out,
events_in events_in,
) )
else: else:
logger.debug(f'skipping configuration {configuration_index + 1} / interface {setting.getNumber()}') logger.debug(
f'skipping configuration {configuration_index + 1} / interface {setting.getNumber()}'
)
endpoints = find_endpoints(found) endpoints = find_endpoints(found)
if endpoints is None: if endpoints is None:
@@ -414,14 +451,13 @@ async def open_usb_transport(spec):
device = found.open() device = found.open()
# Detach the kernel driver if supported and needed # Auto-detach the kernel driver if supported
if usb1.hasCapability(usb1.CAP_SUPPORTS_DETACH_KERNEL_DRIVER): if usb1.hasCapability(usb1.CAP_SUPPORTS_DETACH_KERNEL_DRIVER):
try: try:
if device.kernelDriverActive(interface): logger.debug('auto-detaching kernel driver')
logger.debug("detaching kernel driver") device.setAutoDetachKernelDriver(True)
device.detachKernelDriver(interface) except usb1.USBError as error:
except usb1.USBError: logger.warning(f'unable to auto-detach kernel driver: {error}')
pass
# Set the configuration if needed # Set the configuration if needed
try: try:
+1 -2
View File
@@ -33,7 +33,7 @@ async def open_vhci_transport(spec):
path at /dev/vhci), or the path of a VHCI device path at /dev/vhci), or the path of a VHCI device
''' '''
HCI_VENDOR_PKT = 0xff HCI_VENDOR_PKT = 0xFF
HCI_BREDR = 0x00 # Controller type HCI_BREDR = 0x00 # Controller type
# Open the VHCI device # Open the VHCI device
@@ -56,4 +56,3 @@ async def open_vhci_transport(spec):
transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR])) transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR]))
return transport return transport
+1 -1
View File
@@ -43,7 +43,7 @@ async def open_ws_client_transport(spec):
transport = PumpedTransport( transport = PumpedTransport(
PumpedPacketSource(websocket.recv), PumpedPacketSource(websocket.recv),
PumpedPacketSink(websocket.send), PumpedPacketSink(websocket.send),
websocket.close websocket.close,
) )
transport.start() transport.start()
return transport return transport
+4 -2
View File
@@ -52,12 +52,14 @@ async def open_ws_server_transport(spec):
self.server = await websockets.serve( self.server = await websockets.serve(
ws_handler=self.on_connection, ws_handler=self.on_connection,
host=local_host if local_host != '_' else None, host=local_host if local_host != '_' else None,
port = int(local_port) port=int(local_port),
) )
logger.debug(f'websocket server ready on port {local_port}') logger.debug(f'websocket server ready on port {local_port}')
async def on_connection(self, connection): async def on_connection(self, connection):
logger.debug(f'new connection on {connection.local_address} from {connection.remote_address}') logger.debug(
f'new connection on {connection.local_address} from {connection.remote_address}'
)
self.connection.set_result(connection) self.connection.set_result(connection)
try: try:
async for packet in connection: async for packet in connection:
+109 -2
View File
@@ -18,6 +18,7 @@
import asyncio import asyncio
import logging import logging
import traceback import traceback
import collections
from functools import wraps from functools import wraps
from colors import color from colors import color
from pyee import EventEmitter from pyee import EventEmitter
@@ -33,6 +34,7 @@ logger = logging.getLogger(__name__)
def setup_event_forwarding(emitter, forwarder, event_name): def setup_event_forwarding(emitter, forwarder, event_name):
def emit(*args, **kwargs): def emit(*args, **kwargs):
forwarder.emit(event_name, *args, **kwargs) forwarder.emit(event_name, *args, **kwargs)
emitter.on(event_name, emit) emitter.on(event_name, emit)
@@ -43,6 +45,7 @@ def composite_listener(cls):
registers/deregisters all methods named `on_<event_name>` as a listener for registers/deregisters all methods named `on_<event_name>` as a listener for
the <event_name> event with an emitter. the <event_name> event with an emitter.
""" """
def register(self, emitter): def register(self, emitter):
for method_name in dir(cls): for method_name in dir(cls):
if method_name.startswith('on_'): if method_name.startswith('on_'):
@@ -109,7 +112,9 @@ class AsyncRunner:
try: try:
await item await item
except Exception as error: except Exception as error:
logger.warning(f'{color("!!! Exception in work queue:", "red")} {error}') logger.warning(
f'{color("!!! Exception in work queue:", "red")} {error}'
)
# Shared default queue # Shared default queue
default_queue = WorkQueue() default_queue = WorkQueue()
@@ -130,7 +135,9 @@ class AsyncRunner:
try: try:
await coroutine await coroutine
except Exception: except Exception:
logger.warning(f'{color("!!! Exception in wrapper:", "red")} {traceback.format_exc()}') logger.warning(
f'{color("!!! Exception in wrapper:", "red")} {traceback.format_exc()}'
)
asyncio.create_task(run()) asyncio.create_task(run())
else: else:
@@ -140,3 +147,103 @@ class AsyncRunner:
return wrapper return wrapper
return decorator return decorator
# -----------------------------------------------------------------------------
class FlowControlAsyncPipe:
"""
Asyncio pipe with flow control. When writing to the pipe, the source is
paused (by calling a function passed in when the pipe is created) if the
amount of queued data exceeds a specified threshold.
"""
def __init__(
self,
pause_source,
resume_source,
write_to_sink=None,
drain_sink=None,
threshold=0,
):
self.pause_source = pause_source
self.resume_source = resume_source
self.write_to_sink = write_to_sink
self.drain_sink = drain_sink
self.threshold = threshold
self.queue = collections.deque() # Queue of packets
self.queued_bytes = 0 # Number of bytes in the queue
self.ready_to_pump = asyncio.Event()
self.paused = False
self.source_paused = False
self.pump_task = None
def start(self):
if self.pump_task is None:
self.pump_task = asyncio.create_task(self.pump())
self.check_pump()
def stop(self):
if self.pump_task is not None:
self.pump_task.cancel()
self.pump_task = None
def write(self, packet):
self.queued_bytes += len(packet)
self.queue.append(packet)
# Pause the source if we're over the threshold
if self.queued_bytes > self.threshold and not self.source_paused:
logger.debug(f'pausing source (queued={self.queued_bytes})')
self.pause_source()
self.source_paused = True
self.check_pump()
def pause(self):
if not self.paused:
self.paused = True
if not self.source_paused:
self.pause_source()
self.source_paused = True
self.check_pump()
def resume(self):
if self.paused:
self.paused = False
if self.source_paused:
self.resume_source()
self.source_paused = False
self.check_pump()
def can_pump(self):
return self.queue and not self.paused and self.write_to_sink is not None
def check_pump(self):
if self.can_pump():
self.ready_to_pump.set()
else:
self.ready_to_pump.clear()
async def pump(self):
while True:
# Wait until we can try to pump packets
await self.ready_to_pump.wait()
# Try to pump a packet
if self.can_pump():
packet = self.queue.pop()
self.write_to_sink(packet)
self.queued_bytes -= len(packet)
# Drain the sink if we can
if self.drain_sink:
await self.drain_sink()
# Check if we can accept more
if self.queued_bytes <= self.threshold and self.source_paused:
logger.debug(f'resuming source (queued={self.queued_bytes})')
self.source_paused = False
self.resume_source()
self.check_pump()
+5 -5
View File
@@ -1,6 +1,6 @@
# This requirements file is for python3 # This requirements file is for python3
mkdocs == 1.2.3 mkdocs == 1.4.0
mkdocs-material == 7.1.7 mkdocs-material == 8.5.6
mkdocs-material-extensions == 1.0.1 mkdocs-material-extensions == 1.0.3
pymdown-extensions == 8.2 pymdown-extensions == 9.6
mkdocstrings == 0.15.1 mkdocstrings-python == 0.7.1
+114 -44
View File
@@ -1,47 +1,85 @@
:material-linux: LINUX PLATFORM :material-linux: LINUX PLATFORM
=============================== ===============================
In addition to all the standard functionality available from the project by running the python tools and/or writing your own apps by leveraging the API, it is also possible on Linux hosts to interface the Bumble stack with the native BlueZ stack, and with Bluetooth controllers. Using Bumble With Physical Bluetooth Controllers
------------------------------------------------
Using Bumble With BlueZ A Bumble application can interface with a local Bluetooth controller on a Linux host.
----------------------- The 3 main types of physical Bluetooth controllers are:
A Bumble virtual controller can be attached to the BlueZ stack. * Bluetooth USB Dongle
Attaching a controller to BlueZ can be done by either simulating a UART HCI interface, or by using the VHCI driver interface if available. * HCI over UART (via a serial port)
In both cases, the controller can run locally on the Linux host, or remotely on a different host, with a bridge between the remote controller and the local BlueZ host, which may be useful when the BlueZ stack is running on an embedded system, or a host on which running the Bumble controller is not convenient. * Kernel-managed Bluetooth HCI (HCI Sockets)
### Using VHCI !!! tip "Conflicts with the kernel and BlueZ"
If your use a USB dongle that is recognized by your kernel as a supported Bluetooth device, it is
likely that the kernel driver will claim that USB device and attach it to the BlueZ stack.
If you want to claim ownership of it to use with Bumble, you will need to set the state of the corresponding HCI interface as `DOWN`.
HCI interfaces are numbered, starting from 0 (i.e `hci0`, `hci1`, ...).
With the [VHCI transport](../transports/vhci.md) you can attach a Bumble virtual controller to the BlueZ stack. Once attached, the controller will appear just like any other controller, and thus can be used with the standard BlueZ tools. For example, to bring `hci0` down:
!!! example "Attaching a virtual controller"
With the example app `run_controller.py`:
``` ```
PYTHONPATH=. python3 examples/run_controller.py F6:F7:F8:F9:FA:FB examples/device1.json vhci $ sudo hciconfig hci0 down
``` ```
You should see a 'Virtual Bus' controller. For example: You can use the `hciconfig` command with no arguments to get a list of HCI interfaces seen by
``` the kernel.
$ hciconfig
hci0: Type: Primary Bus: Virtual
BD Address: F6:F7:F8:F9:FA:FB ACL MTU: 27:64 SCO MTU: 0:0
UP RUNNING
RX bytes:0 acl:0 sco:0 events:43 errors:0
TX bytes:274 acl:0 sco:0 commands:43 errors:0
```
And scanning for devices should show the virtual 'Bumble' device that's running as part of the `run_controller.py` example app: Also, if `bluetoothd` is running on your system, it will likely re-claim the interface after you
close it, so you may need to bring the interface back `UP` before using it again, or to disable
`bluetoothd` altogether (see the section further below about BlueZ and `bluetoothd`).
### Using a USB Dongle
See the [USB Transport page](../transports/usb.md) for general information on how to use HCI USB controllers.
!!! tip "USB Permissions"
By default, when running as a regular user, you won't have the permission to use
arbitrary USB devices.
You can change the permissions for a specific USB device based on its bus number and
device number (you can use `lsusb` to find the Bus and Device numbers for your Bluetooth
dongle).
Example:
``` ```
pi@raspberrypi:~ $ sudo hcitool -i hci2 lescan $ sudo chmod o+w /dev/bus/usb/001/004
LE Scan ...
F0:F1:F2:F3:F4:F5 Bumble
``` ```
This will change the permissions for Device 4 on Bus 1.
Note that the USB Bus number and Device number may change depending on where you plug the USB
dongle and what other USB devices and hubs are also plugged in.
If you need to make the permission changes permanent across reboots, you can create a `udev`
rule for your specific Bluetooth dongle. Visit [this Arch Linux Wiki page](https://wiki.archlinux.org/title/udev) for a
good overview of how you may do that.
### Using HCI over UART
See the [Serial Transport page](../transports/serial.md) for general information on how to use HCI over a UART (serial port).
### Using HCI Sockets ### Using HCI Sockets
HCI sockets provide a way to send/receive HCI packets to/from a Bluetooth controller managed by the kernel. HCI sockets provide a way to send/receive HCI packets to/from a Bluetooth controller managed by the kernel.
The HCI device referenced by an `hci-socket` transport (`hciX`, where `X` is an integer, with `hci0` being the first controller device, and so on) must be in the `DOWN` state before it can be opened as a transport. See the [HCI Socket Transport page](../transports/hci_socket.md) for details on the `hci-socket` tansport syntax.
You can bring a HCI controller `UP` or `DOWN` with `hciconfig`.
The HCI device referenced by an `hci-socket` transport (`hci<X>`, where `<X>` is an integer, with `hci0` being the first controller device, and so on) must be in the `DOWN` state before it can be opened as a transport.
You can bring a HCI controller `UP` or `DOWN` with `hciconfig hci<X> up` and `hciconfig hci<X> up`.
!!! tip "HCI Socket Permissions"
By default, when running as a regular user, you won't have the permission to use
an HCI socket to a Bluetooth controller (you may see an exception like `PermissionError: [Errno 1] Operation not permitted`).
If you want to run without using `sudo`, you need to manage the capabilities by adding the appropriate entries in `/etc/security/capability.conf` to grant a user or group the `cap_net_admin` capability.
See [this manpage](https://manpages.ubuntu.com/manpages/bionic/man5/capability.conf.5.html) for details.
Alternatively, if you are just experimenting temporarily, the `capsh` command may be useful in order
to execute a single command with enhanced permissions, as in this example:
```
$ sudo capsh --caps="cap_net_admin+eip cap_setpcap,cap_setuid,cap_setgid+ep" --keep=1 --user=$USER --addamb=cap_net_admin -- -c "<path/to/executable> <executable-args>"
```
Where `<path/to/executable>` is the path to your `python3` executable or to one of the Bumble bundled command-line applications.
!!! tip "List all available controllers" !!! tip "List all available controllers"
The command The command
@@ -72,29 +110,16 @@ You can bring a HCI controller `UP` or `DOWN` with `hciconfig`.
``` ```
$ hciconfig hci0 down $ hciconfig hci0 down
``` ```
(or `hciX` with `X` being the index of the controller device you want to use), but a simpler solution is to just stop the `bluetoothd` daemon, with a command like: (or `hci<X>` with `<X>` being the index of the controller device you want to use), but a simpler solution is to just stop the `bluetoothd` daemon, with a command like:
``` ```
$ sudo systemctl stop bluetooth.service $ sudo systemctl stop bluetooth.service
``` ```
You can always re-start the daemon with You can always re-start the daemon with
``` ```
$ sudo systemctl start bluetooth.service $ sudo systemctl start bluetooth.service
```
### Using a Simulated UART HCI Bumble on the Raspberry Pi
--------------------------
### Bridge to a Remote Controller
Using Bumble With Bluetooth Controllers
---------------------------------------
A Bumble application can interface with a local Bluetooth controller.
If your Bluetooth controller is a standard HCI USB controller, see the [USB Transport page](../transports/usb.md) for details on how to use HCI USB controllers.
If your Bluetooth controller is a standard HCI UART controller, see the [Serial Transport page](../transports/serial.md).
Alternatively, a Bumble Host object can communicate with one of the platform's controllers via an HCI Socket.
`<details to be filled in>`
### Raspberry Pi 4 :fontawesome-brands-raspberry-pi: ### Raspberry Pi 4 :fontawesome-brands-raspberry-pi:
@@ -102,9 +127,10 @@ You can use the Bluetooth controller either via the kernel, or directly to the d
#### Via The Kernel #### Via The Kernel
Use an HCI Socket transport Use an HCI Socket transport (see section above)
#### Directly #### Directly
In order to use the Bluetooth controller directly on a Raspberry Pi 4 board, you need to ensure that it isn't being used by the BlueZ stack (which it probably is by default). In order to use the Bluetooth controller directly on a Raspberry Pi 4 board, you need to ensure that it isn't being used by the BlueZ stack (which it probably is by default).
``` ```
@@ -136,3 +162,47 @@ should detach the controller from the stack, after which you can use the HCI UAR
python3 run_scanner.py serial:/dev/serial1,3000000 python3 run_scanner.py serial:/dev/serial1,3000000
``` ```
Using Bumble With BlueZ
-----------------------
In addition to all the standard functionality available from the project by running the python tools and/or writing your own apps by leveraging the API, it is also possible on Linux hosts to interface the Bumble stack with the native BlueZ stack, and with Bluetooth controllers.
A Bumble virtual controller can be attached to the BlueZ stack.
Attaching a controller to BlueZ can be done by either simulating a UART HCI interface, or by using the VHCI driver interface if available.
In both cases, the controller can run locally on the Linux host, or remotely on a different host, with a bridge between the remote controller and the local BlueZ host, which may be useful when the BlueZ stack is running on an embedded system, or a host on which running the Bumble controller is not convenient.
### Using VHCI
With the [VHCI transport](../transports/vhci.md) you can attach a Bumble virtual controller to the BlueZ stack. Once attached, the controller will appear just like any other controller, and thus can be used with the standard BlueZ tools.
!!! example "Attaching a virtual controller"
With the example app `run_controller.py`:
```
python3 examples/run_controller.py F6:F7:F8:F9:FA:FB examples/device1.json vhci
```
You should see a 'Virtual Bus' controller. For example:
```
$ hciconfig
hci0: Type: Primary Bus: Virtual
BD Address: F6:F7:F8:F9:FA:FB ACL MTU: 27:64 SCO MTU: 0:0
UP RUNNING
RX bytes:0 acl:0 sco:0 events:43 errors:0
TX bytes:274 acl:0 sco:0 commands:43 errors:0
```
And scanning for devices should show the virtual 'Bumble' device that's running as part of the `run_controller.py` example app:
```
pi@raspberrypi:~ $ sudo hcitool -i hci2 lescan
LE Scan ...
F0:F1:F2:F3:F4:F5 Bumble
```
```
### Using a Simulated UART HCI
### Bridge to a Remote Controller
@@ -5,8 +5,9 @@ The Android emulator transport either connects, as a host, to a "Root Canal" vir
("host" mode), or attaches a virtual controller to the Android Bluetooth host stack ("controller" mode). ("host" mode), or attaches a virtual controller to the Android Bluetooth host stack ("controller" mode).
## Moniker ## Moniker
The moniker syntax for an Android Emulator transport is: `android-emulator:[mode=<host|controller>][mode=<host|controller>]`. The moniker syntax for an Android Emulator transport is: `android-emulator:[mode=<host|controller>][<hostname>:<port>]`, where
Both the `mode=<host|controller>` and `mode=<host|controller>` parameters are optional (so the moniker `android-emulator` by itself is a valid moniker, which will create a transport in `host` mode, connected to `localhost` on the default gRPC port for the emulator) the `mode` parameter can specify running as a host or a controller, and `<hostname>:<port>` can specify a host name (or IP address) and TCP port number on which to reach the gRPC server for the emulator.
Both the `mode=<host|controller>` and `<hostname>:<port>` parameters are optional (so the moniker `android-emulator` by itself is a valid moniker, which will create a transport in `host` mode, connected to `localhost` on the default gRPC port for the emulator).
!!! example Example !!! example Example
`android-emulator` `android-emulator`
+5
View File
@@ -0,0 +1,5 @@
{
"name": "Bumble Aid Left",
"address": "F1:F2:F3:F4:F5:F6",
"keystore": "JsonKeyStore"
}
+5
View File
@@ -0,0 +1,5 @@
{
"name": "Bumble Aid Right",
"address": "F7:F8:F9:FA:FB:FC",
"keystore": "JsonKeyStore"
}
+1
View File
@@ -80,6 +80,7 @@ async def main():
await my_work_queue2.run() await my_work_queue2.run()
print("MAIN: end (should never get here)") print("MAIN: end (should never get here)")
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+3 -1
View File
@@ -55,7 +55,9 @@ async def main():
# Subscribe to and read the battery level # Subscribe to and read the battery level
if battery_service.battery_level: if battery_service.battery_level:
await battery_service.battery_level.subscribe( await battery_service.battery_level.subscribe(
lambda value: print(f'{color("Battery Level Update:", "green")} {value}') lambda value: print(
f'{color("Battery Level Update:", "green")} {value}'
)
) )
value = await battery_service.battery_level.read_value() value = await battery_service.battery_level.read_value()
print(f'{color("Initial Battery Level:", "green")} {value}') print(f'{color("Initial Battery Level:", "green")} {value}')
+16 -6
View File
@@ -44,11 +44,19 @@ async def main():
# Set the advertising data # Set the advertising data
device.advertising_data = bytes( device.advertising_data = bytes(
AdvertisingData([ AdvertisingData(
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble Battery', 'utf-8')), [
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(battery_service.uuid)), (
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)) AdvertisingData.COMPLETE_LOCAL_NAME,
]) bytes('Bumble Battery', 'utf-8'),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(battery_service.uuid),
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
]
)
) )
# Go! # Go!
@@ -58,7 +66,9 @@ async def main():
# Notify every 3 seconds # Notify every 3 seconds
while True: while True:
await asyncio.sleep(3.0) await asyncio.sleep(3.0)
await device.notify_subscribers(battery_service.battery_level_characteristic) await device.notify_subscribers(
battery_service.battery_level_characteristic
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+44 -11
View File
@@ -28,7 +28,9 @@ from bumble.transport import open_transport
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main(): async def main():
if len(sys.argv) != 3: if len(sys.argv) != 3:
print('Usage: device_information_client.py <transport-spec> <bluetooth-address>') print(
'Usage: device_information_client.py <transport-spec> <bluetooth-address>'
)
print('example: device_information_client.py usb:0 E1:CA:72:48:C4:E8') print('example: device_information_client.py usb:0 E1:CA:72:48:C4:E8')
return return
@@ -49,7 +51,9 @@ async def main():
# Discover the Device Information service # Discover the Device Information service
peer = Peer(connection) peer = Peer(connection)
print('=== Discovering Device Information Service') print('=== Discovering Device Information Service')
device_information_service = await peer.discover_service_and_create_proxy(DeviceInformationServiceProxy) device_information_service = await peer.discover_service_and_create_proxy(
DeviceInformationServiceProxy
)
# Check that the service was found # Check that the service was found
if device_information_service is None: if device_information_service is None:
@@ -58,21 +62,50 @@ async def main():
# Read and print the fields # Read and print the fields
if device_information_service.manufacturer_name is not None: if device_information_service.manufacturer_name is not None:
print(color('Manufacturer Name: ', 'green'), await device_information_service.manufacturer_name.read_value()) print(
color('Manufacturer Name: ', 'green'),
await device_information_service.manufacturer_name.read_value(),
)
if device_information_service.model_number is not None: if device_information_service.model_number is not None:
print(color('Model Number: ', 'green'), await device_information_service.model_number.read_value()) print(
color('Model Number: ', 'green'),
await device_information_service.model_number.read_value(),
)
if device_information_service.serial_number is not None: if device_information_service.serial_number is not None:
print(color('Serial Number: ', 'green'), await device_information_service.serial_number.read_value()) print(
color('Serial Number: ', 'green'),
await device_information_service.serial_number.read_value(),
)
if device_information_service.hardware_revision is not None: if device_information_service.hardware_revision is not None:
print(color('Hardware Revision: ', 'green'), await device_information_service.hardware_revision.read_value()) print(
color('Hardware Revision: ', 'green'),
await device_information_service.hardware_revision.read_value(),
)
if device_information_service.firmware_revision is not None: if device_information_service.firmware_revision is not None:
print(color('Firmware Revision: ', 'green'), await device_information_service.firmware_revision.read_value()) print(
color('Firmware Revision: ', 'green'),
await device_information_service.firmware_revision.read_value(),
)
if device_information_service.software_revision is not None: if device_information_service.software_revision is not None:
print(color('Software Revision: ', 'green'), await device_information_service.software_revision.read_value()) print(
color('Software Revision: ', 'green'),
await device_information_service.software_revision.read_value(),
)
if device_information_service.system_id is not None: if device_information_service.system_id is not None:
print(color('System ID: ', 'green'), await device_information_service.system_id.read_value()) print(
if device_information_service.ieee_regulatory_certification_data_list is not None: color('System ID: ', 'green'),
print(color('Regulatory Certification:', 'green'), (await device_information_service.ieee_regulatory_certification_data_list.read_value()).hex()) await device_information_service.system_id.read_value(),
)
if (
device_information_service.ieee_regulatory_certification_data_list
is not None
):
print(
color('Regulatory Certification:', 'green'),
(
await device_information_service.ieee_regulatory_certification_data_list.read_value()
).hex(),
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+11 -5
View File
@@ -44,16 +44,21 @@ async def main():
serial_number='7654321', serial_number='7654321',
hardware_revision='1.1.3', hardware_revision='1.1.3',
software_revision='2.5.6', software_revision='2.5.6',
system_id = (0x123456, 0x8877665544) system_id=(0x123456, 0x8877665544),
) )
device.add_service(device_information_service) device.add_service(device_information_service)
# Set the advertising data # Set the advertising data
device.advertising_data = bytes( device.advertising_data = bytes(
AdvertisingData([ AdvertisingData(
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble Device', 'utf-8')), [
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)) (
]) AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Device', 'utf-8'),
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
]
)
) )
# Go! # Go!
@@ -61,6 +66,7 @@ async def main():
await device.start_advertising(auto_restart=True) await device.start_advertising(auto_restart=True)
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+3 -1
View File
@@ -61,7 +61,9 @@ async def main():
# Subscribe to the heart rate measurement # Subscribe to the heart rate measurement
if heart_rate_service.heart_rate_measurement: if heart_rate_service.heart_rate_measurement:
await heart_rate_service.heart_rate_measurement.subscribe( await heart_rate_service.heart_rate_measurement.subscribe(
lambda value: print(f'{color("Heart Rate Measurement:", "green")} {value}') lambda value: print(
f'{color("Heart Rate Measurement:", "green")} {value}'
)
) )
await peer.sustain() await peer.sustain()
+30 -10
View File
@@ -55,29 +55,47 @@ async def main():
serial_number='7654321', serial_number='7654321',
hardware_revision='1.1.3', hardware_revision='1.1.3',
software_revision='2.5.6', software_revision='2.5.6',
system_id = (0x123456, 0x8877665544) system_id=(0x123456, 0x8877665544),
) )
heart_rate_service = HeartRateService( heart_rate_service = HeartRateService(
read_heart_rate_measurement=lambda _: HeartRateService.HeartRateMeasurement( read_heart_rate_measurement=lambda _: HeartRateService.HeartRateMeasurement(
heart_rate=100 + int(50 * math.sin(time.time() * math.pi / 60)), heart_rate=100 + int(50 * math.sin(time.time() * math.pi / 60)),
sensor_contact_detected=random.choice((True, False, None)), sensor_contact_detected=random.choice((True, False, None)),
energy_expended = random.choice((int((time.time() - energy_start_time) * 100), None)), energy_expended=random.choice(
rr_intervals = random.choice(((random.randint(900, 1100) / 1000, random.randint(900, 1100) / 1000), None)) (int((time.time() - energy_start_time) * 100), None)
),
rr_intervals=random.choice(
(
(
random.randint(900, 1100) / 1000,
random.randint(900, 1100) / 1000,
),
None,
)
),
), ),
body_sensor_location=HeartRateService.BodySensorLocation.WRIST, body_sensor_location=HeartRateService.BodySensorLocation.WRIST,
reset_energy_expended=lambda _: reset_energy_expended() reset_energy_expended=lambda _: reset_energy_expended(),
) )
device.add_services([device_information_service, heart_rate_service]) device.add_services([device_information_service, heart_rate_service])
# Set the advertising data # Set the advertising data
device.advertising_data = bytes( device.advertising_data = bytes(
AdvertisingData([ AdvertisingData(
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble Heart', 'utf-8')), [
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(heart_rate_service.uuid)), (
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)) AdvertisingData.COMPLETE_LOCAL_NAME,
]) bytes('Bumble Heart', 'utf-8'),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(heart_rate_service.uuid),
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
]
)
) )
# Go! # Go!
@@ -87,7 +105,9 @@ async def main():
# Notify every 3 seconds # Notify every 3 seconds
while True: while True:
await asyncio.sleep(3.0) await asyncio.sleep(3.0)
await device.notify_subscribers(heart_rate_service.heart_rate_measurement_characteristic) await device.notify_subscribers(
heart_rate_service.heart_rate_measurement_characteristic
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+165 -78
View File
@@ -34,8 +34,8 @@ from bumble.gatt import (
Characteristic, Characteristic,
CharacteristicValue, CharacteristicValue,
GATT_DEVICE_INFORMATION_SERVICE, GATT_DEVICE_INFORMATION_SERVICE,
GATT_DEVICE_HUMAN_INTERFACE_DEVICE_SERVICE, GATT_HUMAN_INTERFACE_DEVICE_SERVICE,
GATT_DEVICE_BATTERY_SERVICE, GATT_BATTERY_SERVICE,
GATT_BATTERY_LEVEL_CHARACTERISTIC, GATT_BATTERY_LEVEL_CHARACTERISTIC,
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
GATT_REPORT_CHARACTERISTIC, GATT_REPORT_CHARACTERISTIC,
@@ -43,7 +43,7 @@ from bumble.gatt import (
GATT_PROTOCOL_MODE_CHARACTERISTIC, GATT_PROTOCOL_MODE_CHARACTERISTIC,
GATT_HID_INFORMATION_CHARACTERISTIC, GATT_HID_INFORMATION_CHARACTERISTIC,
GATT_HID_CONTROL_POINT_CHARACTERISTIC, GATT_HID_CONTROL_POINT_CHARACTERISTIC,
GATT_REPORT_REFERENCE_DESCRIPTOR GATT_REPORT_REFERENCE_DESCRIPTOR,
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -58,41 +58,75 @@ HID_OUTPUT_REPORT = 0x02
HID_FEATURE_REPORT = 0x03 HID_FEATURE_REPORT = 0x03
# Report Map # Report Map
HID_KEYBOARD_REPORT_MAP = bytes([ HID_KEYBOARD_REPORT_MAP = bytes(
0x05, 0x01, # Usage Page (Generic Desktop Ctrls) [
0x09, 0x06, # Usage (Keyboard) 0x05,
0xA1, 0x01, # Collection (Application) 0x01, # Usage Page (Generic Desktop Ctrls)
0x85, 0x01, # . Report ID (1) 0x09,
0x05, 0x07, # . Usage Page (Kbrd/Keypad) 0x06, # Usage (Keyboard)
0x19, 0xE0, # . Usage Minimum (0xE0) 0xA1,
0x29, 0xE7, # . Usage Maximum (0xE7) 0x01, # Collection (Application)
0x15, 0x00, # . Logical Minimum (0) 0x85,
0x25, 0x01, # . Logical Maximum (1) 0x01, # . Report ID (1)
0x75, 0x01, # . Report Size (1) 0x05,
0x95, 0x08, # . Report Count (8) 0x07, # . Usage Page (Kbrd/Keypad)
0x81, 0x02, # . Input (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position) 0x19,
0x95, 0x01, # . Report Count (1) 0xE0, # . Usage Minimum (0xE0)
0x75, 0x08, # . Report Size (8) 0x29,
0x81, 0x01, # . Input (Const,Array,Abs,No Wrap,Linear,Preferred State,No Null Position) 0xE7, # . Usage Maximum (0xE7)
0x95, 0x06, # . Report Count (6) 0x15,
0x75, 0x08, # . Report Size (8) 0x00, # . Logical Minimum (0)
0x15, 0x00, # . Logical Minimum (0x00) 0x25,
0x25, 0x94, # . Logical Maximum (0x94) 0x01, # . Logical Maximum (1)
0x05, 0x07, # . Usage Page (Kbrd/Keypad) 0x75,
0x19, 0x00, # . Usage Minimum (0x00) 0x01, # . Report Size (1)
0x29, 0x94, # . Usage Maximum (0x94) 0x95,
0x81, 0x00, # . Input (Data,Array,Abs,No Wrap,Linear,Preferred State,No Null Position) 0x08, # . Report Count (8)
0x95, 0x05, # . Report Count (5) 0x81,
0x75, 0x01, # . Report Size (1) 0x02, # . Input (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x05, 0x08, # . Usage Page (LEDs) 0x95,
0x19, 0x01, # . Usage Minimum (Num Lock) 0x01, # . Report Count (1)
0x29, 0x05, # . Usage Maximum (Kana) 0x75,
0x91, 0x02, # . Output (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile) 0x08, # . Report Size (8)
0x95, 0x01, # . Report Count (1) 0x81,
0x75, 0x03, # . Report Size (3) 0x01, # . Input (Const,Array,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x91, 0x01, # . Output (Const,Array,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile) 0x95,
0xC0 # End Collection 0x06, # . Report Count (6)
]) 0x75,
0x08, # . Report Size (8)
0x15,
0x00, # . Logical Minimum (0x00)
0x25,
0x94, # . Logical Maximum (0x94)
0x05,
0x07, # . Usage Page (Kbrd/Keypad)
0x19,
0x00, # . Usage Minimum (0x00)
0x29,
0x94, # . Usage Maximum (0x94)
0x81,
0x00, # . Input (Data,Array,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95,
0x05, # . Report Count (5)
0x75,
0x01, # . Report Size (1)
0x05,
0x08, # . Usage Page (LEDs)
0x19,
0x01, # . Usage Minimum (Num Lock)
0x29,
0x05, # . Usage Maximum (Kana)
0x91,
0x02, # . Output (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
0x95,
0x01, # . Report Count (1)
0x75,
0x03, # . Report Size (3)
0x91,
0x01, # . Output (Const,Array,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
0xC0, # End Collection
]
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -126,38 +160,48 @@ async def keyboard_host(device, peer_address):
connection = await device.connect(peer_address) connection = await device.connect(peer_address)
await connection.pair() await connection.pair()
peer = Peer(connection) peer = Peer(connection)
await peer.discover_service(GATT_DEVICE_HUMAN_INTERFACE_DEVICE_SERVICE) await peer.discover_service(GATT_HUMAN_INTERFACE_DEVICE_SERVICE)
hid_services = peer.get_services_by_uuid(GATT_DEVICE_HUMAN_INTERFACE_DEVICE_SERVICE) hid_services = peer.get_services_by_uuid(GATT_HUMAN_INTERFACE_DEVICE_SERVICE)
if not hid_services: if not hid_services:
print(color('!!! No HID service', 'red')) print(color('!!! No HID service', 'red'))
return return
await peer.discover_characteristics() await peer.discover_characteristics()
protocol_mode_characteristics = peer.get_characteristics_by_uuid(GATT_PROTOCOL_MODE_CHARACTERISTIC) protocol_mode_characteristics = peer.get_characteristics_by_uuid(
GATT_PROTOCOL_MODE_CHARACTERISTIC
)
if not protocol_mode_characteristics: if not protocol_mode_characteristics:
print(color('!!! No Protocol Mode characteristic', 'red')) print(color('!!! No Protocol Mode characteristic', 'red'))
return return
protocol_mode_characteristic = protocol_mode_characteristics[0] protocol_mode_characteristic = protocol_mode_characteristics[0]
hid_information_characteristics = peer.get_characteristics_by_uuid(GATT_HID_INFORMATION_CHARACTERISTIC) hid_information_characteristics = peer.get_characteristics_by_uuid(
GATT_HID_INFORMATION_CHARACTERISTIC
)
if not hid_information_characteristics: if not hid_information_characteristics:
print(color('!!! No HID Information characteristic', 'red')) print(color('!!! No HID Information characteristic', 'red'))
return return
hid_information_characteristic = hid_information_characteristics[0] hid_information_characteristic = hid_information_characteristics[0]
report_map_characteristics = peer.get_characteristics_by_uuid(GATT_REPORT_MAP_CHARACTERISTIC) report_map_characteristics = peer.get_characteristics_by_uuid(
GATT_REPORT_MAP_CHARACTERISTIC
)
if not report_map_characteristics: if not report_map_characteristics:
print(color('!!! No Report Map characteristic', 'red')) print(color('!!! No Report Map characteristic', 'red'))
return return
report_map_characteristic = report_map_characteristics[0] report_map_characteristic = report_map_characteristics[0]
control_point_characteristics = peer.get_characteristics_by_uuid(GATT_HID_CONTROL_POINT_CHARACTERISTIC) control_point_characteristics = peer.get_characteristics_by_uuid(
GATT_HID_CONTROL_POINT_CHARACTERISTIC
)
if not control_point_characteristics: if not control_point_characteristics:
print(color('!!! No Control Point characteristic', 'red')) print(color('!!! No Control Point characteristic', 'red'))
return return
# control_point_characteristic = control_point_characteristics[0] # control_point_characteristic = control_point_characteristics[0]
report_characteristics = peer.get_characteristics_by_uuid(GATT_REPORT_CHARACTERISTIC) report_characteristics = peer.get_characteristics_by_uuid(
GATT_REPORT_CHARACTERISTIC
)
if not report_characteristics: if not report_characteristics:
print(color('!!! No Report characteristic', 'red')) print(color('!!! No Report characteristic', 'red'))
return return
@@ -165,13 +209,20 @@ async def keyboard_host(device, peer_address):
print(color('REPORT:', 'yellow'), characteristic) print(color('REPORT:', 'yellow'), characteristic)
if characteristic.properties & Characteristic.NOTIFY: if characteristic.properties & Characteristic.NOTIFY:
await peer.discover_descriptors(characteristic) await peer.discover_descriptors(characteristic)
report_reference_descriptor = characteristic.get_descriptor(GATT_REPORT_REFERENCE_DESCRIPTOR) report_reference_descriptor = characteristic.get_descriptor(
GATT_REPORT_REFERENCE_DESCRIPTOR
)
if report_reference_descriptor: if report_reference_descriptor:
report_reference = await peer.read_value(report_reference_descriptor) report_reference = await peer.read_value(report_reference_descriptor)
print(color(' Report Reference:', 'blue'), report_reference.hex()) print(color(' Report Reference:', 'blue'), report_reference.hex())
else: else:
report_reference = bytes([0, 0]) report_reference = bytes([0, 0])
await peer.subscribe(characteristic, lambda value, param=f'[{i}] {report_reference.hex()}': on_report(param, value)) await peer.subscribe(
characteristic,
lambda value, param=f'[{i}] {report_reference.hex()}': on_report(
param, value
),
)
protocol_mode = await peer.read_value(protocol_mode_characteristic) protocol_mode = await peer.read_value(protocol_mode_characteristic)
print(f'Protocol Mode: {protocol_mode.hex()}') print(f'Protocol Mode: {protocol_mode.hex()}')
@@ -192,23 +243,34 @@ async def keyboard_device(device, command):
Characteristic.READABLE | Characteristic.WRITEABLE, Characteristic.READABLE | Characteristic.WRITEABLE,
bytes([0, 0, 0, 0, 0, 0, 0, 0]), bytes([0, 0, 0, 0, 0, 0, 0, 0]),
[ [
Descriptor(GATT_REPORT_REFERENCE_DESCRIPTOR, Descriptor.READABLE, bytes([0x01, HID_INPUT_REPORT])) Descriptor(
] GATT_REPORT_REFERENCE_DESCRIPTOR,
Descriptor.READABLE,
bytes([0x01, HID_INPUT_REPORT]),
)
],
) )
# Create an 'output report' characteristic to receive keyboard reports from the host # Create an 'output report' characteristic to receive keyboard reports from the host
output_report_characteristic = Characteristic( output_report_characteristic = Characteristic(
GATT_REPORT_CHARACTERISTIC, GATT_REPORT_CHARACTERISTIC,
Characteristic.READ | Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE, Characteristic.READ
| Characteristic.WRITE
| Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.READABLE | Characteristic.WRITEABLE, Characteristic.READABLE | Characteristic.WRITEABLE,
bytes([0]), bytes([0]),
[ [
Descriptor(GATT_REPORT_REFERENCE_DESCRIPTOR, Descriptor.READABLE, bytes([0x01, HID_OUTPUT_REPORT])) Descriptor(
] GATT_REPORT_REFERENCE_DESCRIPTOR,
Descriptor.READABLE,
bytes([0x01, HID_OUTPUT_REPORT]),
)
],
) )
# Add the services to the GATT sever # Add the services to the GATT sever
device.add_services([ device.add_services(
[
Service( Service(
GATT_DEVICE_INFORMATION_SERVICE, GATT_DEVICE_INFORMATION_SERVICE,
[ [
@@ -216,53 +278,56 @@ async def keyboard_device(device, command):
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
'Bumble' 'Bumble',
) )
] ],
), ),
Service( Service(
GATT_DEVICE_HUMAN_INTERFACE_DEVICE_SERVICE, GATT_HUMAN_INTERFACE_DEVICE_SERVICE,
[ [
Characteristic( Characteristic(
GATT_PROTOCOL_MODE_CHARACTERISTIC, GATT_PROTOCOL_MODE_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
bytes([HID_REPORT_PROTOCOL]) bytes([HID_REPORT_PROTOCOL]),
), ),
Characteristic( Characteristic(
GATT_HID_INFORMATION_CHARACTERISTIC, GATT_HID_INFORMATION_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
bytes([0x11, 0x01, 0x00, 0x03]) # bcdHID=1.1, bCountryCode=0x00, Flags=RemoteWake|NormallyConnectable bytes(
[0x11, 0x01, 0x00, 0x03]
), # bcdHID=1.1, bCountryCode=0x00, Flags=RemoteWake|NormallyConnectable
), ),
Characteristic( Characteristic(
GATT_HID_CONTROL_POINT_CHARACTERISTIC, GATT_HID_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE_WITHOUT_RESPONSE, Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE, Characteristic.WRITEABLE,
CharacteristicValue(write=on_hid_control_point_write) CharacteristicValue(write=on_hid_control_point_write),
), ),
Characteristic( Characteristic(
GATT_REPORT_MAP_CHARACTERISTIC, GATT_REPORT_MAP_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
HID_KEYBOARD_REPORT_MAP HID_KEYBOARD_REPORT_MAP,
), ),
input_report_characteristic, input_report_characteristic,
output_report_characteristic output_report_characteristic,
] ],
), ),
Service( Service(
GATT_DEVICE_BATTERY_SERVICE, GATT_BATTERY_SERVICE,
[ [
Characteristic( Characteristic(
GATT_BATTERY_LEVEL_CHARACTERISTIC, GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
bytes([100]) bytes([100]),
) )
],
),
] ]
) )
])
# Debug print # Debug print
for attribute in device.gatt_server.attributes: for attribute in device.gatt_server.attributes:
@@ -270,13 +335,20 @@ async def keyboard_device(device, command):
# Set the advertising data # Set the advertising data
device.advertising_data = bytes( device.advertising_data = bytes(
AdvertisingData([ AdvertisingData(
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble Keyboard', 'utf-8')), [
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, (
bytes(GATT_DEVICE_HUMAN_INTERFACE_DEVICE_SERVICE)), AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Keyboard', 'utf-8'),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(GATT_HUMAN_INTERFACE_DEVICE_SERVICE),
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x03C1)), (AdvertisingData.APPEARANCE, struct.pack('<H', 0x03C1)),
(AdvertisingData.FLAGS, bytes([0x05])) (AdvertisingData.FLAGS, bytes([0x05])),
]) ]
)
) )
# Attach a listener # Attach a listener
@@ -303,14 +375,21 @@ async def keyboard_device(device, command):
code = ord(key) code = ord(key)
if code >= ord('a') and code <= ord('z'): if code >= ord('a') and code <= ord('z'):
hid_code = 0x04 + code - ord('a') hid_code = 0x04 + code - ord('a')
input_report_characteristic.value = bytes([0, 0, hid_code, 0, 0, 0, 0, 0]) input_report_characteristic.value = bytes(
await device.notify_subscribers(input_report_characteristic) [0, 0, hid_code, 0, 0, 0, 0, 0]
)
await device.notify_subscribers(
input_report_characteristic
)
elif message_type == 'keyup': elif message_type == 'keyup':
input_report_characteristic.value = bytes.fromhex('0000000000000000') input_report_characteristic.value = bytes.fromhex(
'0000000000000000'
)
await device.notify_subscribers(input_report_characteristic) await device.notify_subscribers(input_report_characteristic)
except websockets.exceptions.ConnectionClosedOK: except websockets.exceptions.ConnectionClosedOK:
pass pass
await websockets.serve(serve, 'localhost', 8989) await websockets.serve(serve, 'localhost', 8989)
await asyncio.get_event_loop().create_future() await asyncio.get_event_loop().create_future()
else: else:
@@ -321,7 +400,9 @@ async def keyboard_device(device, command):
# Keypress for the letter # Keypress for the letter
keycode = 0x04 + letter - 0x61 keycode = 0x04 + letter - 0x61
input_report_characteristic.value = bytes([0, 0, keycode, 0, 0, 0, 0, 0]) input_report_characteristic.value = bytes(
[0, 0, keycode, 0, 0, 0, 0, 0]
)
await device.notify_subscribers(input_report_characteristic) await device.notify_subscribers(input_report_characteristic)
# Key release # Key release
@@ -335,10 +416,16 @@ async def main():
print('Usage: python keyboard.py <device-config> <transport-spec> <command>') print('Usage: python keyboard.py <device-config> <transport-spec> <command>')
print(' where <command> is one of:') print(' where <command> is one of:')
print(' connect <address> (run a keyboard host, connecting to a keyboard)') print(' connect <address> (run a keyboard host, connecting to a keyboard)')
print(' web (run a keyboard with keypress input from a web page, see keyboard.html') print(
print(' sim (run a keyboard simulation, emitting a canned sequence of keystrokes') ' web (run a keyboard with keypress input from a web page, see keyboard.html'
)
print(
' sim (run a keyboard simulation, emitting a canned sequence of keystrokes'
)
print('example: python keyboard.py keyboard.json usb:0 sim') print('example: python keyboard.py keyboard.json usb:0 sim')
print('example: python keyboard.py keyboard.json usb:0 connect A0:A1:A2:A3:A4:A5') print(
'example: python keyboard.py keyboard.json usb:0 connect A0:A1:A2:A3:A4:A5'
)
return return
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink): async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
+24 -19
View File
@@ -27,12 +27,9 @@ from bumble.core import (
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
BT_AVDTP_PROTOCOL_ID, BT_AVDTP_PROTOCOL_ID,
BT_AUDIO_SINK_SERVICE, BT_AUDIO_SINK_SERVICE,
BT_L2CAP_PROTOCOL_ID BT_L2CAP_PROTOCOL_ID,
)
from bumble.avdtp import (
Protocol as AVDTP_Protocol,
find_avdtp_service_with_connection
) )
from bumble.avdtp import Protocol as AVDTP_Protocol, find_avdtp_service_with_connection
from bumble.a2dp import make_audio_source_service_sdp_records from bumble.a2dp import make_audio_source_service_sdp_records
from bumble.sdp import ( from bumble.sdp import (
Client as SDP_Client, Client as SDP_Client,
@@ -40,7 +37,7 @@ from bumble.sdp import (
DataElement, DataElement,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
) )
@@ -48,7 +45,9 @@ from bumble.sdp import (
def sdp_records(): def sdp_records():
service_record_handle = 0x00010001 service_record_handle = 0x00010001
return { return {
service_record_handle: make_audio_source_service_sdp_records(service_record_handle) service_record_handle: make_audio_source_service_sdp_records(
service_record_handle
)
} }
@@ -64,8 +63,8 @@ async def find_a2dp_service(device, connection):
[ [
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
] ],
) )
print(color('==================================', 'blue')) print(color('==================================', 'blue'))
@@ -78,8 +77,7 @@ async def find_a2dp_service(device, connection):
# Service classes # Service classes
service_class_id_list = ServiceAttribute.find_attribute_in_list( service_class_id_list = ServiceAttribute.find_attribute_in_list(
attribute_list, attribute_list, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
) )
if service_class_id_list: if service_class_id_list:
if service_class_id_list.value: if service_class_id_list.value:
@@ -89,8 +87,7 @@ async def find_a2dp_service(device, connection):
# Protocol info # Protocol info
protocol_descriptor_list = ServiceAttribute.find_attribute_in_list( protocol_descriptor_list = ServiceAttribute.find_attribute_in_list(
attribute_list, attribute_list, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
) )
if protocol_descriptor_list: if protocol_descriptor_list:
print(color(' Protocol:', 'green')) print(color(' Protocol:', 'green'))
@@ -103,18 +100,24 @@ async def find_a2dp_service(device, connection):
if len(protocol_descriptor.value) >= 2: if len(protocol_descriptor.value) >= 2:
avdtp_version_major = protocol_descriptor.value[1].value >> 8 avdtp_version_major = protocol_descriptor.value[1].value >> 8
avdtp_version_minor = protocol_descriptor.value[1].value & 0xFF avdtp_version_minor = protocol_descriptor.value[1].value & 0xFF
print(f'{color(" AVDTP Version:", "cyan")} {avdtp_version_major}.{avdtp_version_minor}') print(
f'{color(" AVDTP Version:", "cyan")} {avdtp_version_major}.{avdtp_version_minor}'
)
service_version = (avdtp_version_major, avdtp_version_minor) service_version = (avdtp_version_major, avdtp_version_minor)
# Profile info # Profile info
bluetooth_profile_descriptor_list = ServiceAttribute.find_attribute_in_list( bluetooth_profile_descriptor_list = ServiceAttribute.find_attribute_in_list(
attribute_list, attribute_list, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
) )
if bluetooth_profile_descriptor_list: if bluetooth_profile_descriptor_list:
if bluetooth_profile_descriptor_list.value: if bluetooth_profile_descriptor_list.value:
if bluetooth_profile_descriptor_list.value[0].type == DataElement.SEQUENCE: if (
bluetooth_profile_descriptors = bluetooth_profile_descriptor_list.value bluetooth_profile_descriptor_list.value[0].type
== DataElement.SEQUENCE
):
bluetooth_profile_descriptors = (
bluetooth_profile_descriptor_list.value
)
else: else:
# Sometimes, instead of a list of lists, we just find a list. Fix that # Sometimes, instead of a list of lists, we just find a list. Fix that
bluetooth_profile_descriptors = [bluetooth_profile_descriptor_list] bluetooth_profile_descriptors = [bluetooth_profile_descriptor_list]
@@ -123,7 +126,9 @@ async def find_a2dp_service(device, connection):
for bluetooth_profile_descriptor in bluetooth_profile_descriptors: for bluetooth_profile_descriptor in bluetooth_profile_descriptors:
version_major = bluetooth_profile_descriptor.value[1].value >> 8 version_major = bluetooth_profile_descriptor.value[1].value >> 8
version_minor = bluetooth_profile_descriptor.value[1].value & 0xFF version_minor = bluetooth_profile_descriptor.value[1].value & 0xFF
print(f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}') print(
f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}'
)
await sdp_client.disconnect() await sdp_client.disconnect()
return service_version return service_version
+19 -12
View File
@@ -28,7 +28,7 @@ from bumble.avdtp import (
AVDTP_AUDIO_MEDIA_TYPE, AVDTP_AUDIO_MEDIA_TYPE,
Protocol, Protocol,
Listener, Listener,
MediaCodecCapabilities MediaCodecCapabilities,
) )
from bumble.a2dp import ( from bumble.a2dp import (
make_audio_sink_service_sdp_records, make_audio_sink_service_sdp_records,
@@ -39,19 +39,19 @@ from bumble.a2dp import (
SBC_LOUDNESS_ALLOCATION_METHOD, SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_STEREO_CHANNEL_MODE, SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE, SBC_JOINT_STEREO_CHANNEL_MODE,
SbcMediaCodecInformation SbcMediaCodecInformation,
) )
Context = { Context = {'output': None}
'output': None
}
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def sdp_records(): def sdp_records():
service_record_handle = 0x00010001 service_record_handle = 0x00010001
return { return {
service_record_handle: make_audio_sink_service_sdp_records(service_record_handle) service_record_handle: make_audio_sink_service_sdp_records(
service_record_handle
)
} }
@@ -67,14 +67,17 @@ def codec_capabilities():
SBC_MONO_CHANNEL_MODE, SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE, SBC_DUAL_CHANNEL_MODE,
SBC_STEREO_CHANNEL_MODE, SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE SBC_JOINT_STEREO_CHANNEL_MODE,
], ],
block_lengths=[4, 8, 12, 16], block_lengths=[4, 8, 12, 16],
subbands=[4, 8], subbands=[4, 8],
allocation_methods = [SBC_LOUDNESS_ALLOCATION_METHOD, SBC_SNR_ALLOCATION_METHOD], allocation_methods=[
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_SNR_ALLOCATION_METHOD,
],
minimum_bitpool_value=2, minimum_bitpool_value=2,
maximum_bitpool_value = 53 maximum_bitpool_value=53,
) ),
) )
@@ -104,7 +107,9 @@ def on_rtp_packet(packet):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main(): async def main():
if len(sys.argv) < 4: if len(sys.argv) < 4:
print('Usage: run_a2dp_sink.py <device-config> <transport-spec> <sbc-file> [<bt-addr>]') print(
'Usage: run_a2dp_sink.py <device-config> <transport-spec> <sbc-file> [<bt-addr>]'
)
print('example: run_a2dp_sink.py classic1.json usb:0 output.sbc') print('example: run_a2dp_sink.py classic1.json usb:0 output.sbc')
return return
@@ -133,7 +138,9 @@ async def main():
# Connect to the source # Connect to the source
target_address = sys.argv[4] target_address = sys.argv[4]
print(f'=== Connecting to {target_address}...') print(f'=== Connecting to {target_address}...')
connection = await device.connect(target_address, transport=BT_BR_EDR_TRANSPORT) connection = await device.connect(
target_address, transport=BT_BR_EDR_TRANSPORT
)
print(f'=== Connected to {connection.peer_address}!') print(f'=== Connected to {connection.peer_address}!')
# Request authentication # Request authentication
+31 -13
View File
@@ -30,7 +30,7 @@ from bumble.avdtp import (
MediaCodecCapabilities, MediaCodecCapabilities,
MediaPacketPump, MediaPacketPump,
Protocol, Protocol,
Listener Listener,
) )
from bumble.a2dp import ( from bumble.a2dp import (
SBC_JOINT_STEREO_CHANNEL_MODE, SBC_JOINT_STEREO_CHANNEL_MODE,
@@ -38,7 +38,7 @@ from bumble.a2dp import (
make_audio_source_service_sdp_records, make_audio_source_service_sdp_records,
A2DP_SBC_CODEC_TYPE, A2DP_SBC_CODEC_TYPE,
SbcMediaCodecInformation, SbcMediaCodecInformation,
SbcPacketSource SbcPacketSource,
) )
@@ -46,7 +46,9 @@ from bumble.a2dp import (
def sdp_records(): def sdp_records():
service_record_handle = 0x00010001 service_record_handle = 0x00010001
return { return {
service_record_handle: make_audio_source_service_sdp_records(service_record_handle) service_record_handle: make_audio_source_service_sdp_records(
service_record_handle
)
} }
@@ -63,14 +65,16 @@ def codec_capabilities():
subbands=8, subbands=8,
allocation_method=SBC_LOUDNESS_ALLOCATION_METHOD, allocation_method=SBC_LOUDNESS_ALLOCATION_METHOD,
minimum_bitpool_value=2, minimum_bitpool_value=2,
maximum_bitpool_value = 53 maximum_bitpool_value=53,
) ),
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_avdtp_connection(read_function, protocol): def on_avdtp_connection(read_function, protocol):
packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.mtu, codec_capabilities()) packet_source = SbcPacketSource(
read_function, protocol.l2cap_channel.mtu, codec_capabilities()
)
packet_pump = MediaPacketPump(packet_source.packets) packet_pump = MediaPacketPump(packet_source.packets)
protocol.add_source(packet_source.codec_capabilities, packet_pump) protocol.add_source(packet_source.codec_capabilities, packet_pump)
@@ -83,14 +87,18 @@ async def stream_packets(read_function, protocol):
print('@@@', endpoint) print('@@@', endpoint)
# Select a sink # Select a sink
sink = protocol.find_remote_sink_by_codec(AVDTP_AUDIO_MEDIA_TYPE, A2DP_SBC_CODEC_TYPE) sink = protocol.find_remote_sink_by_codec(
AVDTP_AUDIO_MEDIA_TYPE, A2DP_SBC_CODEC_TYPE
)
if sink is None: if sink is None:
print(color('!!! no SBC sink found', 'red')) print(color('!!! no SBC sink found', 'red'))
return return
print(f'### Selected sink: {sink.seid}') print(f'### Selected sink: {sink.seid}')
# Stream the packets # Stream the packets
packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.mtu, codec_capabilities()) packet_source = SbcPacketSource(
read_function, protocol.l2cap_channel.mtu, codec_capabilities()
)
packet_pump = MediaPacketPump(packet_source.packets) packet_pump = MediaPacketPump(packet_source.packets)
source = protocol.add_source(packet_source.codec_capabilities, packet_pump) source = protocol.add_source(packet_source.codec_capabilities, packet_pump)
stream = await protocol.create_stream(source, sink) stream = await protocol.create_stream(source, sink)
@@ -107,8 +115,12 @@ async def stream_packets(read_function, protocol):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main(): async def main():
if len(sys.argv) < 4: if len(sys.argv) < 4:
print('Usage: run_a2dp_source.py <device-config> <transport-spec> <sbc-file> [<bluetooth-address>]') print(
print('example: run_a2dp_source.py classic1.json usb:0 test.sbc E1:CA:72:48:C4:E8') 'Usage: run_a2dp_source.py <device-config> <transport-spec> <sbc-file> [<bluetooth-address>]'
)
print(
'example: run_a2dp_source.py classic1.json usb:0 test.sbc E1:CA:72:48:C4:E8'
)
return return
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
@@ -134,7 +146,9 @@ async def main():
# Connect to a peer # Connect to a peer
target_address = sys.argv[4] target_address = sys.argv[4]
print(f'=== Connecting to {target_address}...') print(f'=== Connecting to {target_address}...')
connection = await device.connect(target_address, transport=BT_BR_EDR_TRANSPORT) connection = await device.connect(
target_address, transport=BT_BR_EDR_TRANSPORT
)
print(f'=== Connected to {connection.peer_address}!') print(f'=== Connected to {connection.peer_address}!')
# Request authentication # Request authentication
@@ -148,7 +162,9 @@ async def main():
print('*** Encryption on') print('*** Encryption on')
# Look for an A2DP service # Look for an A2DP service
avdtp_version = await find_avdtp_service_with_connection(device, connection) avdtp_version = await find_avdtp_service_with_connection(
device, connection
)
if not avdtp_version: if not avdtp_version:
print(color('!!! no A2DP service found')) print(color('!!! no A2DP service found'))
return return
@@ -161,7 +177,9 @@ async def main():
else: else:
# Create a listener to wait for AVDTP connections # Create a listener to wait for AVDTP connections
listener = Listener(Listener.create_registrar(device), version=(1, 2)) listener = Listener(Listener.create_registrar(device), version=(1, 2))
listener.on('connection', lambda protocol: on_avdtp_connection(read, protocol)) listener.on(
'connection', lambda protocol: on_avdtp_connection(read, protocol)
)
# Become connectable and wait for a connection # Become connectable and wait for a connection
await device.set_discoverable(True) await device.set_discoverable(True)
+20 -4
View File
@@ -29,20 +29,36 @@ from bumble.transport import open_transport_or_link
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main(): async def main():
if len(sys.argv) != 3: if len(sys.argv) < 3:
print('Usage: run_advertiser.py <config-file> <transport-spec>') print(
print('example: run_advertiser.py device1.json link-relay:ws://localhost:8888/test') 'Usage: run_advertiser.py <config-file> <transport-spec> [type] [address]'
)
print('example: run_advertiser.py device1.json usb:0')
return return
if len(sys.argv) >= 4:
advertising_type = AdvertisingType(int(sys.argv[3]))
else:
advertising_type = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE
if advertising_type.is_directed:
if len(sys.argv) < 5:
print('<address> required for directed advertising')
return
target = Address(sys.argv[4])
else:
target = None
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink): async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink) device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
await device.power_on() await device.power_on()
await device.start_advertising() await device.start_advertising(advertising_type=advertising_type, target=target)
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+200
View File
@@ -0,0 +1,200 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import struct
import sys
import os
import logging
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.hci import UUID
from bumble.gatt import Service, Characteristic, CharacteristicValue
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID(
'6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties'
)
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID(
'f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint'
)
ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID(
'38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus'
)
ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID(
'2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT'
)
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) != 4:
print(
'Usage: python run_asha_sink.py <device-config> <transport-spec> <audio-file>'
)
print('example: python run_asha_sink.py device1.json usb:0 audio_out.g722')
return
audio_out = open(sys.argv[3], 'wb')
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
# Handler for audio control commands
def on_audio_control_point_write(connection, value):
print('--- AUDIO CONTROL POINT Write:', value.hex())
opcode = value[0]
if opcode == 1:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
print(
f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}'
)
elif opcode == 2:
print('### STOP')
elif opcode == 3:
print(f'### STATUS: connected={value[1]}')
# Respond with a status
asyncio.create_task(
device.notify_subscribers(audio_status_characteristic, force=True)
)
# Handler for volume control
def on_volume_write(connection, value):
print('--- VOLUME Write:', value[0])
# Register an L2CAP CoC server
def on_coc(channel):
def on_data(data):
print('<<< Voice data received:', data.hex())
audio_out.write(data)
channel.sink = on_data
psm = device.register_l2cap_channel_server(0, on_coc, 8)
print(f'### LE_PSM_OUT = {psm}')
# Add the ASHA service to the GATT server
read_only_properties_characteristic = Characteristic(
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes(
[
0x01, # Version
0x00, # Device Capabilities [Left, Monaural]
0x01,
0x02,
0x03,
0x04,
0x05,
0x06,
0x07,
0x08, # HiSyncId
0x01, # Feature Map [LE CoC audio output streaming supported]
0x00,
0x00, # Render Delay
0x00,
0x00, # RFU
0x02,
0x00, # Codec IDs [G.722 at 16 kHz]
]
),
)
audio_control_point_characteristic = Characteristic(
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write),
)
audio_status_characteristic = Characteristic(
ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([0]),
)
volume_characteristic = Characteristic(
ASHA_VOLUME_CHARACTERISTIC,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write),
)
le_psm_out_characteristic = Characteristic(
ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
struct.pack('<H', psm),
)
device.add_service(
Service(
ASHA_SERVICE,
[
read_only_properties_characteristic,
audio_control_point_characteristic,
audio_status_characteristic,
volume_characteristic,
le_psm_out_characteristic,
],
)
)
# Set the advertising data
device.advertising_data = bytes(
AdvertisingData(
[
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(device.name, 'utf-8')),
(AdvertisingData.FLAGS, bytes([0x06])),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(ASHA_SERVICE),
),
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(ASHA_SERVICE)
+ bytes(
[
0x01, # Protocol Version
0x00, # Capability
0x01,
0x02,
0x03,
0x04, # Truncated HiSyncID
]
),
),
]
)
)
# Go!
await device.power_on()
await device.start_advertising(auto_restart=True)
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())
+37 -9
View File
@@ -24,14 +24,22 @@ from colors import color
from bumble.device import Device from bumble.device import Device
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.core import BT_BR_EDR_TRANSPORT, BT_L2CAP_PROTOCOL_ID from bumble.core import BT_BR_EDR_TRANSPORT, BT_L2CAP_PROTOCOL_ID
from bumble.sdp import Client as SDP_Client, SDP_PUBLIC_BROWSE_ROOT, SDP_ALL_ATTRIBUTES_RANGE from bumble.sdp import (
Client as SDP_Client,
SDP_PUBLIC_BROWSE_ROOT,
SDP_ALL_ATTRIBUTES_RANGE,
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main(): async def main():
if len(sys.argv) < 3: if len(sys.argv) < 3:
print('Usage: run_classic_connect.py <device-config> <transport-spec> <bluetooth-address>') print(
print('example: run_classic_connect.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8') 'Usage: run_classic_connect.py <device-config> <transport-spec> <bluetooth-addresses..>'
)
print(
'example: run_classic_connect.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8'
)
return return
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
@@ -43,8 +51,7 @@ async def main():
device.classic_enabled = True device.classic_enabled = True
await device.power_on() await device.power_on()
# Connect to a peer async def connect(target_address):
target_address = sys.argv[3]
print(f'=== Connecting to {target_address}...') print(f'=== Connecting to {target_address}...')
connection = await device.connect(target_address, transport=BT_BR_EDR_TRANSPORT) connection = await device.connect(target_address, transport=BT_BR_EDR_TRANSPORT)
print(f'=== Connected to {connection.peer_address}!') print(f'=== Connected to {connection.peer_address}!')
@@ -54,28 +61,49 @@ async def main():
await sdp_client.connect(connection) await sdp_client.connect(connection)
# List all services in the root browse group # List all services in the root browse group
service_record_handles = await sdp_client.search_services([SDP_PUBLIC_BROWSE_ROOT]) service_record_handles = await sdp_client.search_services(
[SDP_PUBLIC_BROWSE_ROOT]
)
print(color('\n==================================', 'blue')) print(color('\n==================================', 'blue'))
print(color('SERVICES:', 'yellow'), service_record_handles) print(color('SERVICES:', 'yellow'), service_record_handles)
# For each service in the root browse group, get all its attributes # For each service in the root browse group, get all its attributes
for service_record_handle in service_record_handles: for service_record_handle in service_record_handles:
attributes = await sdp_client.get_attributes(service_record_handle, [SDP_ALL_ATTRIBUTES_RANGE]) attributes = await sdp_client.get_attributes(
service_record_handle, [SDP_ALL_ATTRIBUTES_RANGE]
)
print(color(f'SERVICE {service_record_handle:04X} attributes:', 'yellow')) print(color(f'SERVICE {service_record_handle:04X} attributes:', 'yellow'))
for attribute in attributes: for attribute in attributes:
print(' ', attribute.to_string(color=True)) print(' ', attribute.to_string(color=True))
# Search for services with an L2CAP service attribute # Search for services with an L2CAP service attribute
search_result = await sdp_client.search_attributes([BT_L2CAP_PROTOCOL_ID], [SDP_ALL_ATTRIBUTES_RANGE]) search_result = await sdp_client.search_attributes(
[BT_L2CAP_PROTOCOL_ID], [SDP_ALL_ATTRIBUTES_RANGE]
)
print(color('\n==================================', 'blue')) print(color('\n==================================', 'blue'))
print(color('SEARCH RESULTS:', 'yellow')) print(color('SEARCH RESULTS:', 'yellow'))
for attribute_list in search_result: for attribute_list in search_result:
print(color('SERVICE:', 'green')) print(color('SERVICE:', 'green'))
print(' ' + '\n '.join([attribute.to_string(color=True) for attribute in attribute_list])) print(
' '
+ '\n '.join(
[attribute.to_string(color=True) for attribute in attribute_list]
)
)
await sdp_client.disconnect() await sdp_client.disconnect()
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
# Connect to a peer
target_addresses = sys.argv[3:]
await asyncio.wait(
[
asyncio.create_task(connect(target_address))
for target_address in target_addresses
]
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+35 -20
View File
@@ -30,49 +30,63 @@ from bumble.sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
) )
from bumble.core import ( from bumble.core import (
BT_AUDIO_SINK_SERVICE, BT_AUDIO_SINK_SERVICE,
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
BT_AVDTP_PROTOCOL_ID, BT_AVDTP_PROTOCOL_ID,
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
SDP_SERVICE_RECORDS = { SDP_SERVICE_RECORDS = {
0x00010001: [ 0x00010001: [
ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(0x00010001)), ServiceAttribute(
ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([ SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT) DataElement.unsigned_integer_32(0x00010001),
])), ),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute( ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(BT_AUDIO_SINK_SERVICE)]) DataElement.sequence([DataElement.uuid(BT_AUDIO_SINK_SERVICE)]),
), ),
ServiceAttribute( ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence([ DataElement.sequence(
DataElement.sequence([ [
DataElement.sequence(
[
DataElement.uuid(BT_L2CAP_PROTOCOL_ID), DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(25) DataElement.unsigned_integer_16(25),
]), ]
DataElement.sequence([ ),
DataElement.sequence(
[
DataElement.uuid(BT_AVDTP_PROTOCOL_ID), DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(256) DataElement.unsigned_integer_16(256),
]) ]
]) ),
]
),
), ),
ServiceAttribute( ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence([ DataElement.sequence(
DataElement.sequence([ [
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(256) DataElement.unsigned_integer_16(256),
]) ]
])
) )
] ]
),
),
]
} }
@@ -99,6 +113,7 @@ async def main():
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+15 -4
View File
@@ -29,13 +29,23 @@ from bumble.core import DeviceClass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class DiscoveryListener(Device.Listener): class DiscoveryListener(Device.Listener):
def on_inquiry_result(self, address, class_of_device, eir_data, rssi): def on_inquiry_result(self, address, class_of_device, eir_data, rssi):
service_classes, major_device_class, minor_device_class = DeviceClass.split_class_of_device(class_of_device) (
service_classes,
major_device_class,
minor_device_class,
) = DeviceClass.split_class_of_device(class_of_device)
separator = '\n ' separator = '\n '
print(f'>>> {color(address, "yellow")}:') print(f'>>> {color(address, "yellow")}:')
print(f' Device Class (raw): {class_of_device:06X}') print(f' Device Class (raw): {class_of_device:06X}')
print(f' Device Major Class: {DeviceClass.major_device_class_name(major_device_class)}') print(
print(f' Device Minor Class: {DeviceClass.minor_device_class_name(major_device_class, minor_device_class)}') f' Device Major Class: {DeviceClass.major_device_class_name(major_device_class)}'
print(f' Device Services: {", ".join(DeviceClass.service_class_labels(service_classes))}') )
print(
f' Device Minor Class: {DeviceClass.minor_device_class_name(major_device_class, minor_device_class)}'
)
print(
f' Device Services: {", ".join(DeviceClass.service_class_labels(service_classes))}'
)
print(f' RSSI: {rssi}') print(f' RSSI: {rssi}')
if eir_data.ad_structures: if eir_data.ad_structures:
print(f' {eir_data.to_string(separator)}') print(f' {eir_data.to_string(separator)}')
@@ -59,6 +69,7 @@ async def main():
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+7 -2
View File
@@ -27,8 +27,12 @@ from bumble.transport import open_transport_or_link
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main(): async def main():
if len(sys.argv) < 3: if len(sys.argv) < 3:
print('Usage: run_connect_and_encrypt.py <device-config> <transport-spec> <bluetooth-address>') print(
print('example: run_connect_and_encrypt.py device1.json usb:0 E1:CA:72:48:C4:E8') 'Usage: run_connect_and_encrypt.py <device-config> <transport-spec> <bluetooth-address>'
)
print(
'example: run_connect_and_encrypt.py device1.json usb:0 E1:CA:72:48:C4:E8'
)
return return
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
@@ -53,6 +57,7 @@ async def main():
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+19 -8
View File
@@ -32,8 +32,12 @@ from bumble.transport import open_transport_or_link
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main(): async def main():
if len(sys.argv) != 4: if len(sys.argv) != 4:
print('Usage: run_controller.py <controller-address> <device-config> <transport-spec>') print(
print('example: run_controller.py F2:F3:F4:F5:F6:F7 device1.json udp:0.0.0.0:22333,172.16.104.161:22333') 'Usage: run_controller.py <controller-address> <device-config> <transport-spec>'
)
print(
'example: run_controller.py F2:F3:F4:F5:F6:F7 device1.json udp:0.0.0.0:22333,172.16.104.161:22333'
)
return return
print('>>> connecting to HCI...') print('>>> connecting to HCI...')
@@ -44,7 +48,9 @@ async def main():
link = LocalLink() link = LocalLink()
# Create a first controller using the packet source/sink as its host interface # Create a first controller using the packet source/sink as its host interface
controller1 = Controller('C1', host_source = hci_source, host_sink = hci_sink, link = link) controller1 = Controller(
'C1', host_source=hci_source, host_sink=hci_sink, link=link
)
controller1.random_address = sys.argv[1] controller1.random_address = sys.argv[1]
# Create a second controller using the same link # Create a second controller using the same link
@@ -59,17 +65,21 @@ async def main():
device.host = host device.host = host
# Add some basic services to the device's GATT server # Add some basic services to the device's GATT server
descriptor = Descriptor(GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, Descriptor.READABLE, 'My Description') descriptor = Descriptor(
GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR,
Descriptor.READABLE,
'My Description',
)
manufacturer_name_characteristic = Characteristic( manufacturer_name_characteristic = Characteristic(
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
"Fitbit", "Fitbit",
[descriptor] [descriptor],
)
device_info_service = Service(
GATT_DEVICE_INFORMATION_SERVICE, [manufacturer_name_characteristic]
) )
device_info_service = Service(GATT_DEVICE_INFORMATION_SERVICE, [
manufacturer_name_characteristic
])
device.add_service(device_info_service) device.add_service(device_info_service)
# Debug print # Debug print
@@ -82,6 +92,7 @@ async def main():
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+13 -6
View File
@@ -29,15 +29,17 @@ from bumble.transport import open_transport_or_link
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ScannerListener(Device.Listener): class ScannerListener(Device.Listener):
def on_advertisement(self, address, ad_data, rssi, connectable): def on_advertisement(self, advertisement):
address_type_string = ('P', 'R', 'PI', 'RI')[address.address_type] address_type_string = ('P', 'R', 'PI', 'RI')[advertisement.address.address_type]
address_color = 'yellow' if connectable else 'red' address_color = 'yellow' if advertisement.is_connectable else 'red'
if address_type_string.startswith('P'): if address_type_string.startswith('P'):
type_color = 'green' type_color = 'green'
else: else:
type_color = 'cyan' type_color = 'cyan'
print(f'>>> {color(address, address_color)} [{color(address_type_string, type_color)}]: RSSI={rssi}, {ad_data}') print(
f'>>> {color(advertisement.address, address_color)} [{color(address_type_string, type_color)}]: RSSI={advertisement.rssi}, {advertisement.data}'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -55,20 +57,25 @@ async def main():
link = LocalLink() link = LocalLink()
# Create a first controller using the packet source/sink as its host interface # Create a first controller using the packet source/sink as its host interface
controller1 = Controller('C1', host_source = hci_source, host_sink = hci_sink, link = link) controller1 = Controller(
'C1', host_source=hci_source, host_sink=hci_sink, link=link
)
controller1.address = 'E0:E1:E2:E3:E4:E5' controller1.address = 'E0:E1:E2:E3:E4:E5'
# Create a second controller using the same link # Create a second controller using the same link
controller2 = Controller('C2', link=link) controller2 = Controller('C2', link=link)
# Create a device with a scanner listener # Create a device with a scanner listener
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', controller2, controller2) device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', controller2, controller2
)
device.listener = ScannerListener() device.listener = ScannerListener()
await device.power_on() await device.power_on()
await device.start_scanning() await device.start_scanning()
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+4 -1
View File
@@ -70,7 +70,9 @@ class Listener(Device.Listener):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main(): async def main():
if len(sys.argv) < 3: if len(sys.argv) < 3:
print('Usage: run_gatt_client.py <device-config> <transport-spec> [<bluetooth-address>]') print(
'Usage: run_gatt_client.py <device-config> <transport-spec> [<bluetooth-address>]'
)
print('example: run_gatt_client.py device1.json usb:0 E1:CA:72:48:C4:E8') print('example: run_gatt_client.py device1.json usb:0 E1:CA:72:48:C4:E8')
return return
@@ -93,6 +95,7 @@ async def main():
await asyncio.get_running_loop().create_future() await asyncio.get_running_loop().create_future()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+11 -6
View File
@@ -32,7 +32,7 @@ from bumble.gatt import (
show_services, show_services,
GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR,
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
GATT_DEVICE_INFORMATION_SERVICE GATT_DEVICE_INFORMATION_SERVICE,
) )
@@ -63,17 +63,21 @@ async def main():
await server_device.power_on() await server_device.power_on()
# Add a few entries to the device's GATT server # Add a few entries to the device's GATT server
descriptor = Descriptor(GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, Descriptor.READABLE, 'My Description') descriptor = Descriptor(
GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR,
Descriptor.READABLE,
'My Description',
)
manufacturer_name_characteristic = Characteristic( manufacturer_name_characteristic = Characteristic(
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
"Fitbit", "Fitbit",
[descriptor] [descriptor],
)
device_info_service = Service(
GATT_DEVICE_INFORMATION_SERVICE, [manufacturer_name_characteristic]
) )
device_info_service = Service(GATT_DEVICE_INFORMATION_SERVICE, [
manufacturer_name_characteristic
])
server_device.add_service(device_info_service) server_device.add_service(device_info_service)
# Connect the client to the server # Connect the client to the server
@@ -109,6 +113,7 @@ async def main():
await asyncio.get_running_loop().create_future() await asyncio.get_running_loop().create_future()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+22 -16
View File
@@ -22,10 +22,7 @@ import logging
from bumble.device import Device, Connection from bumble.device import Device, Connection
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.att import ( from bumble.att import ATT_Error, ATT_INSUFFICIENT_ENCRYPTION_ERROR
ATT_Error,
ATT_INSUFFICIENT_ENCRYPTION_ERROR
)
from bumble.gatt import ( from bumble.gatt import (
Service, Service,
Characteristic, Characteristic,
@@ -33,7 +30,7 @@ from bumble.gatt import (
Descriptor, Descriptor,
GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR,
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
GATT_DEVICE_INFORMATION_SERVICE GATT_DEVICE_INFORMATION_SERVICE,
) )
@@ -76,7 +73,9 @@ def my_custom_write_with_error(connection, value):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main(): async def main():
if len(sys.argv) < 3: if len(sys.argv) < 3:
print('Usage: run_gatt_server.py <device-config> <transport-spec> [<bluetooth-address>]') print(
'Usage: run_gatt_server.py <device-config> <transport-spec> [<bluetooth-address>]'
)
print('example: run_gatt_server.py device1.json usb:0 E1:CA:72:48:C4:E8') print('example: run_gatt_server.py device1.json usb:0 E1:CA:72:48:C4:E8')
return return
@@ -89,17 +88,21 @@ async def main():
device.listener = Listener(device) device.listener = Listener(device)
# Add a few entries to the device's GATT server # Add a few entries to the device's GATT server
descriptor = Descriptor(GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR, Descriptor.READABLE, 'My Description') descriptor = Descriptor(
GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR,
Descriptor.READABLE,
'My Description',
)
manufacturer_name_characteristic = Characteristic( manufacturer_name_characteristic = Characteristic(
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
Characteristic.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
'Fitbit', 'Fitbit',
[descriptor] [descriptor],
)
device_info_service = Service(
GATT_DEVICE_INFORMATION_SERVICE, [manufacturer_name_characteristic]
) )
device_info_service = Service(GATT_DEVICE_INFORMATION_SERVICE, [
manufacturer_name_characteristic
])
custom_service1 = Service( custom_service1 = Service(
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5', '50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
[ [
@@ -107,21 +110,23 @@ async def main():
'D901B45B-4916-412E-ACCA-376ECB603B2C', 'D901B45B-4916-412E-ACCA-376ECB603B2C',
Characteristic.READ | Characteristic.WRITE, Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE, Characteristic.READABLE | Characteristic.WRITEABLE,
CharacteristicValue(read=my_custom_read, write=my_custom_write) CharacteristicValue(read=my_custom_read, write=my_custom_write),
), ),
Characteristic( Characteristic(
'552957FB-CF1F-4A31-9535-E78847E1A714', '552957FB-CF1F-4A31-9535-E78847E1A714',
Characteristic.READ | Characteristic.WRITE, Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE, Characteristic.READABLE | Characteristic.WRITEABLE,
CharacteristicValue(read=my_custom_read_with_error, write=my_custom_write_with_error) CharacteristicValue(
read=my_custom_read_with_error, write=my_custom_write_with_error
),
), ),
Characteristic( Characteristic(
'486F64C6-4B5F-4B3B-8AFF-EDE134A8446A', '486F64C6-4B5F-4B3B-8AFF-EDE134A8446A',
Characteristic.READ | Characteristic.NOTIFY, Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE, Characteristic.READABLE,
'hello' 'hello',
) ),
] ],
) )
device.add_services([device_info_service, custom_service1]) device.add_services([device_info_service, custom_service1])
@@ -142,6 +147,7 @@ async def main():
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+53 -26
View File
@@ -31,12 +31,9 @@ from bumble.sdp import (
ServiceAttribute, ServiceAttribute,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
from bumble.hci import (
BT_HANDSFREE_SERVICE,
BT_RFCOMM_PROTOCOL_ID
) )
from bumble.hci import BT_HANDSFREE_SERVICE, BT_RFCOMM_PROTOCOL_ID
from bumble.hfp import HfpProtocol from bumble.hfp import HfpProtocol
@@ -52,8 +49,8 @@ async def list_rfcomm_channels(device, connection):
[ [
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
] ],
) )
print(color('==================================', 'blue')) print(color('==================================', 'blue'))
print(color('Handsfree Services:', 'yellow')) print(color('Handsfree Services:', 'yellow'))
@@ -61,40 +58,59 @@ async def list_rfcomm_channels(device, connection):
for attribute_list in search_result: for attribute_list in search_result:
# Look for the RFCOMM Channel number # Look for the RFCOMM Channel number
protocol_descriptor_list = ServiceAttribute.find_attribute_in_list( protocol_descriptor_list = ServiceAttribute.find_attribute_in_list(
attribute_list, attribute_list, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
) )
if protocol_descriptor_list: if protocol_descriptor_list:
for protocol_descriptor in protocol_descriptor_list.value: for protocol_descriptor in protocol_descriptor_list.value:
if len(protocol_descriptor.value) >= 2: if len(protocol_descriptor.value) >= 2:
if protocol_descriptor.value[0].value == BT_RFCOMM_PROTOCOL_ID: if protocol_descriptor.value[0].value == BT_RFCOMM_PROTOCOL_ID:
print(color('SERVICE:', 'green')) print(color('SERVICE:', 'green'))
print(color(' RFCOMM Channel:', 'cyan'), protocol_descriptor.value[1].value) print(
color(' RFCOMM Channel:', 'cyan'),
protocol_descriptor.value[1].value,
)
rfcomm_channels.append(protocol_descriptor.value[1].value) rfcomm_channels.append(protocol_descriptor.value[1].value)
# List profiles # List profiles
bluetooth_profile_descriptor_list = ServiceAttribute.find_attribute_in_list( bluetooth_profile_descriptor_list = (
ServiceAttribute.find_attribute_in_list(
attribute_list, attribute_list,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
) )
if bluetooth_profile_descriptor_list: if bluetooth_profile_descriptor_list:
if bluetooth_profile_descriptor_list.value: if bluetooth_profile_descriptor_list.value:
if bluetooth_profile_descriptor_list.value[0].type == DataElement.SEQUENCE: if (
bluetooth_profile_descriptors = bluetooth_profile_descriptor_list.value bluetooth_profile_descriptor_list.value[0].type
== DataElement.SEQUENCE
):
bluetooth_profile_descriptors = (
bluetooth_profile_descriptor_list.value
)
else: else:
# Sometimes, instead of a list of lists, we just find a list. Fix that # Sometimes, instead of a list of lists, we just find a list. Fix that
bluetooth_profile_descriptors = [bluetooth_profile_descriptor_list] bluetooth_profile_descriptors = [
bluetooth_profile_descriptor_list
]
print(color(' Profiles:', 'green')) print(color(' Profiles:', 'green'))
for bluetooth_profile_descriptor in bluetooth_profile_descriptors: for (
version_major = bluetooth_profile_descriptor.value[1].value >> 8 bluetooth_profile_descriptor
version_minor = bluetooth_profile_descriptor.value[1].value & 0xFF ) in bluetooth_profile_descriptors:
print(f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}') version_major = (
bluetooth_profile_descriptor.value[1].value >> 8
)
version_minor = (
bluetooth_profile_descriptor.value[1].value
& 0xFF
)
print(
f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}'
)
# List service classes # List service classes
service_class_id_list = ServiceAttribute.find_attribute_in_list( service_class_id_list = ServiceAttribute.find_attribute_in_list(
attribute_list, attribute_list, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
) )
if service_class_id_list: if service_class_id_list:
if service_class_id_list.value: if service_class_id_list.value:
@@ -109,9 +125,15 @@ async def list_rfcomm_channels(device, connection):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main(): async def main():
if len(sys.argv) < 4: if len(sys.argv) < 4:
print('Usage: run_hfp_gateway.py <device-config> <transport-spec> <bluetooth-address>') print(
print(' specifying a channel number, or "discover" to list all RFCOMM channels') 'Usage: run_hfp_gateway.py <device-config> <transport-spec> <bluetooth-address>'
print('example: run_hfp_gateway.py hfp_gateway.json usb:04b4:f901 E1:CA:72:48:C4:E8') )
print(
' specifying a channel number, or "discover" to list all RFCOMM channels'
)
print(
'example: run_hfp_gateway.py hfp_gateway.json usb:04b4:f901 E1:CA:72:48:C4:E8'
)
return return
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
@@ -173,7 +195,9 @@ async def main():
protocol.send_response_line('+BRSF: 30') protocol.send_response_line('+BRSF: 30')
protocol.send_response_line('OK') protocol.send_response_line('OK')
elif line.startswith('AT+CIND=?'): elif line.startswith('AT+CIND=?'):
protocol.send_response_line('+CIND: ("call",(0,1)),("callsetup",(0-3)),("service",(0-1)),("signal",(0-5)),("roam",(0,1)),("battchg",(0-5)),("callheld",(0-2))') protocol.send_response_line(
'+CIND: ("call",(0,1)),("callsetup",(0-3)),("service",(0-1)),("signal",(0-5)),("roam",(0,1)),("battchg",(0-5)),("callheld",(0-2))'
)
protocol.send_response_line('OK') protocol.send_response_line('OK')
elif line.startswith('AT+CIND?'): elif line.startswith('AT+CIND?'):
protocol.send_response_line('+CIND: 0,0,1,4,1,5,0') protocol.send_response_line('+CIND: 0,0,1,4,1,5,0')
@@ -193,7 +217,9 @@ async def main():
elif line.startswith('AT+BIA='): elif line.startswith('AT+BIA='):
protocol.send_response_line('OK') protocol.send_response_line('OK')
elif line.startswith('AT+BVRA='): elif line.startswith('AT+BVRA='):
protocol.send_response_line('+BVRA: 1,1,12AA,1,1,"Message 1 from Janina"') protocol.send_response_line(
'+BVRA: 1,1,12AA,1,1,"Message 1 from Janina"'
)
elif line.startswith('AT+XEVENT='): elif line.startswith('AT+XEVENT='):
protocol.send_response_line('OK') protocol.send_response_line('OK')
elif line.startswith('AT+XAPL='): elif line.startswith('AT+XAPL='):
@@ -204,6 +230,7 @@ async def main():
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+29 -19
View File
@@ -32,13 +32,13 @@ from bumble.sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
) )
from bumble.core import ( from bumble.core import (
BT_GENERIC_AUDIO_SERVICE, BT_GENERIC_AUDIO_SERVICE,
BT_HANDSFREE_SERVICE, BT_HANDSFREE_SERVICE,
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
BT_RFCOMM_PROTOCOL_ID BT_RFCOMM_PROTOCOL_ID,
) )
from bumble.hfp import HfpProtocol from bumble.hfp import HfpProtocol
@@ -49,37 +49,45 @@ def make_sdp_records(rfcomm_channel):
0x00010001: [ 0x00010001: [
ServiceAttribute( ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(0x00010001) DataElement.unsigned_integer_32(0x00010001),
), ),
ServiceAttribute( ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([ DataElement.sequence(
[
DataElement.uuid(BT_HANDSFREE_SERVICE), DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.uuid(BT_GENERIC_AUDIO_SERVICE) DataElement.uuid(BT_GENERIC_AUDIO_SERVICE),
]) ]
),
), ),
ServiceAttribute( ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence([ DataElement.sequence(
DataElement.sequence([ [
DataElement.uuid(BT_L2CAP_PROTOCOL_ID) DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
]), DataElement.sequence(
DataElement.sequence([ [
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID), DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(rfcomm_channel) DataElement.unsigned_integer_8(rfcomm_channel),
]) ]
]) ),
]
),
), ),
ServiceAttribute( ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence([ DataElement.sequence(
DataElement.sequence([ [
DataElement.sequence(
[
DataElement.uuid(BT_HANDSFREE_SERVICE), DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.unsigned_integer_16(0x0105) DataElement.unsigned_integer_16(0x0105),
]) ]
])
) )
] ]
),
),
]
} }
@@ -103,6 +111,7 @@ class UiServer:
except websockets.exceptions.ConnectionClosedOK: except websockets.exceptions.ConnectionClosedOK:
pass pass
await websockets.serve(serve, 'localhost', 8989) await websockets.serve(serve, 'localhost', 8989)
@@ -160,6 +169,7 @@ async def main():
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+18 -11
View File
@@ -23,10 +23,7 @@ import logging
from bumble.device import Device, Connection from bumble.device import Device, Connection
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.gatt import ( from bumble.gatt import Service, Characteristic
Service,
Characteristic
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -41,7 +38,9 @@ class Listener(Device.Listener, Connection.Listener):
def on_disconnection(self, reason): def on_disconnection(self, reason):
print(f'### Disconnected, reason={reason}') print(f'### Disconnected, reason={reason}')
def on_characteristic_subscription(self, connection, characteristic, notify_enabled, indicate_enabled): def on_characteristic_subscription(
self, connection, characteristic, notify_enabled, indicate_enabled
):
print( print(
f'$$$ Characteristic subscription for handle {characteristic.handle} from {connection}: ' f'$$$ Characteristic subscription for handle {characteristic.handle} from {connection}: '
f'notify {"enabled" if notify_enabled else "disabled"}, ' f'notify {"enabled" if notify_enabled else "disabled"}, '
@@ -49,11 +48,18 @@ class Listener(Device.Listener, Connection.Listener):
) )
# -----------------------------------------------------------------------------
# Alternative way to listen for subscriptions
# -----------------------------------------------------------------------------
def on_my_characteristic_subscription(peer, enabled):
print(f'### My characteristic from {peer}: {"enabled" if enabled else "disabled"}')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main(): async def main():
if len(sys.argv) < 3: if len(sys.argv) < 3:
print('Usage: run_gatt_server.py <device-config> <transport-spec>') print('Usage: run_notifier.py <device-config> <transport-spec>')
print('example: run_gatt_server.py device1.json usb:0') print('example: run_notifier.py device1.json usb:0')
return return
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
@@ -69,23 +75,24 @@ async def main():
'486F64C6-4B5F-4B3B-8AFF-EDE134A8446A', '486F64C6-4B5F-4B3B-8AFF-EDE134A8446A',
Characteristic.READ | Characteristic.NOTIFY, Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE, Characteristic.READABLE,
bytes([0x40]) bytes([0x40]),
) )
characteristic2 = Characteristic( characteristic2 = Characteristic(
'8EBDEBAE-0017-418E-8D3B-3A3809492165', '8EBDEBAE-0017-418E-8D3B-3A3809492165',
Characteristic.READ | Characteristic.INDICATE, Characteristic.READ | Characteristic.INDICATE,
Characteristic.READABLE, Characteristic.READABLE,
bytes([0x41]) bytes([0x41]),
) )
characteristic3 = Characteristic( characteristic3 = Characteristic(
'8EBDEBAE-0017-418E-8D3B-3A3809492165', '8EBDEBAE-0017-418E-8D3B-3A3809492165',
Characteristic.READ | Characteristic.NOTIFY | Characteristic.INDICATE, Characteristic.READ | Characteristic.NOTIFY | Characteristic.INDICATE,
Characteristic.READABLE, Characteristic.READABLE,
bytes([0x42]) bytes([0x42]),
) )
characteristic3.on('subscription', on_my_characteristic_subscription)
custom_service = Service( custom_service = Service(
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5', '50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
[characteristic1, characteristic2, characteristic3] [characteristic1, characteristic2, characteristic3],
) )
device.add_services([custom_service]) device.add_services([custom_service])
+49 -22
View File
@@ -31,7 +31,7 @@ from bumble.sdp import (
ServiceAttribute, ServiceAttribute,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
) )
from bumble.hci import BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID from bumble.hci import BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID
@@ -48,47 +48,66 @@ async def list_rfcomm_channels(device, connection):
[ [
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
] ],
) )
print(color('==================================', 'blue')) print(color('==================================', 'blue'))
print(color('RFCOMM Services:', 'yellow')) print(color('RFCOMM Services:', 'yellow'))
for attribute_list in search_result: for attribute_list in search_result:
# Look for the RFCOMM Channel number # Look for the RFCOMM Channel number
protocol_descriptor_list = ServiceAttribute.find_attribute_in_list( protocol_descriptor_list = ServiceAttribute.find_attribute_in_list(
attribute_list, attribute_list, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
) )
if protocol_descriptor_list: if protocol_descriptor_list:
for protocol_descriptor in protocol_descriptor_list.value: for protocol_descriptor in protocol_descriptor_list.value:
if len(protocol_descriptor.value) >= 2: if len(protocol_descriptor.value) >= 2:
if protocol_descriptor.value[0].value == BT_RFCOMM_PROTOCOL_ID: if protocol_descriptor.value[0].value == BT_RFCOMM_PROTOCOL_ID:
print(color('SERVICE:', 'green')) print(color('SERVICE:', 'green'))
print(color(' RFCOMM Channel:', 'cyan'), protocol_descriptor.value[1].value) print(
color(' RFCOMM Channel:', 'cyan'),
protocol_descriptor.value[1].value,
)
# List profiles # List profiles
bluetooth_profile_descriptor_list = ServiceAttribute.find_attribute_in_list( bluetooth_profile_descriptor_list = (
ServiceAttribute.find_attribute_in_list(
attribute_list, attribute_list,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
) )
if bluetooth_profile_descriptor_list: if bluetooth_profile_descriptor_list:
if bluetooth_profile_descriptor_list.value: if bluetooth_profile_descriptor_list.value:
if bluetooth_profile_descriptor_list.value[0].type == DataElement.SEQUENCE: if (
bluetooth_profile_descriptors = bluetooth_profile_descriptor_list.value bluetooth_profile_descriptor_list.value[0].type
== DataElement.SEQUENCE
):
bluetooth_profile_descriptors = (
bluetooth_profile_descriptor_list.value
)
else: else:
# Sometimes, instead of a list of lists, we just find a list. Fix that # Sometimes, instead of a list of lists, we just find a list. Fix that
bluetooth_profile_descriptors = [bluetooth_profile_descriptor_list] bluetooth_profile_descriptors = [
bluetooth_profile_descriptor_list
]
print(color(' Profiles:', 'green')) print(color(' Profiles:', 'green'))
for bluetooth_profile_descriptor in bluetooth_profile_descriptors: for (
version_major = bluetooth_profile_descriptor.value[1].value >> 8 bluetooth_profile_descriptor
version_minor = bluetooth_profile_descriptor.value[1].value & 0xFF ) in bluetooth_profile_descriptors:
print(f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}') version_major = (
bluetooth_profile_descriptor.value[1].value >> 8
)
version_minor = (
bluetooth_profile_descriptor.value[1].value
& 0xFF
)
print(
f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}'
)
# List service classes # List service classes
service_class_id_list = ServiceAttribute.find_attribute_in_list( service_class_id_list = ServiceAttribute.find_attribute_in_list(
attribute_list, attribute_list, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
) )
if service_class_id_list: if service_class_id_list:
if service_class_id_list.value: if service_class_id_list.value:
@@ -98,6 +117,7 @@ async def list_rfcomm_channels(device, connection):
await sdp_client.disconnect() await sdp_client.disconnect()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class TcpServerProtocol(asyncio.Protocol): class TcpServerProtocol(asyncio.Protocol):
def __init__(self, rfcomm_session): def __init__(self, rfcomm_session):
@@ -137,9 +157,15 @@ async def tcp_server(tcp_port, rfcomm_session):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main(): async def main():
if len(sys.argv) < 5: if len(sys.argv) < 5:
print('Usage: run_rfcomm_client.py <device-config> <transport-spec> <bluetooth-address> <channel>|discover [tcp-port]') print(
print(' specifying a channel number, or "discover" to list all RFCOMM channels') 'Usage: run_rfcomm_client.py <device-config> <transport-spec> <bluetooth-address> <channel>|discover [tcp-port]'
print('example: run_rfcomm_client.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8 8') )
print(
' specifying a channel number, or "discover" to list all RFCOMM channels'
)
print(
'example: run_rfcomm_client.py classic1.json usb:04b4:f901 E1:CA:72:48:C4:E8 8'
)
return return
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
@@ -173,7 +199,7 @@ async def main():
print('*** Encryption on') print('*** Encryption on')
# Create a client and start it # Create a client and start it
print('@@@ Starting to RFCOMM client...') print('@@@ Starting RFCOMM client...')
rfcomm_client = Client(device, connection) rfcomm_client = Client(device, connection)
rfcomm_mux = await rfcomm_client.start() rfcomm_mux = await rfcomm_client.start()
print('@@@ Started') print('@@@ Started')
@@ -192,10 +218,11 @@ async def main():
if len(sys.argv) == 6: if len(sys.argv) == 6:
# A TCP port was specified, start listening # A TCP port was specified, start listening
tcp_port = int(sys.argv[5]) tcp_port = int(sys.argv[5])
asyncio.get_running_loop().create_task(tcp_server(tcp_port, session)) asyncio.create_task(tcp_server(tcp_port, session))
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+29 -16
View File
@@ -31,7 +31,7 @@ from bumble.sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
) )
from bumble.hci import BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID from bumble.hci import BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID
@@ -40,22 +40,34 @@ from bumble.hci import BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID
def sdp_records(channel): def sdp_records(channel):
return { return {
0x00010001: [ 0x00010001: [
ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(0x00010001)), ServiceAttribute(
ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([ SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT) DataElement.unsigned_integer_32(0x00010001),
])), ),
ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([ ServiceAttribute(
DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')) SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
])), DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([ ),
DataElement.sequence([ ServiceAttribute(
DataElement.uuid(BT_L2CAP_PROTOCOL_ID) SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
]), DataElement.sequence(
DataElement.sequence([ [DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID), DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(channel) DataElement.unsigned_integer_8(channel),
]) ]
])) ),
]
),
),
] ]
} }
@@ -113,6 +125,7 @@ async def main():
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+12 -7
View File
@@ -35,35 +35,40 @@ async def main():
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[1]) as (hci_source, hci_sink): async with await open_transport_or_link(sys.argv[1]) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
filter_duplicates = (len(sys.argv) == 3 and sys.argv[2] == 'filter') filter_duplicates = len(sys.argv) == 3 and sys.argv[2] == 'filter'
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink) device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
@device.on('advertisement') @device.on('advertisement')
def _(address, ad_data, rssi, connectable): def _(advertisement):
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[address.address_type] address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
address_color = 'yellow' if connectable else 'red' advertisement.address.address_type
]
address_color = 'yellow' if advertisement.is_connectable else 'red'
address_qualifier = '' address_qualifier = ''
if address_type_string.startswith('P'): if address_type_string.startswith('P'):
type_color = 'cyan' type_color = 'cyan'
else: else:
if address.is_static: if advertisement.address.is_static:
type_color = 'green' type_color = 'green'
address_qualifier = '(static)' address_qualifier = '(static)'
elif address.is_resolvable: elif advertisement.address.is_resolvable:
type_color = 'magenta' type_color = 'magenta'
address_qualifier = '(resolvable)' address_qualifier = '(resolvable)'
else: else:
type_color = 'white' type_color = 'white'
separator = '\n ' separator = '\n '
print(f'>>> {color(address, address_color)} [{color(address_type_string, type_color)}]{address_qualifier}:{separator}RSSI:{rssi}{separator}{ad_data.to_string(separator)}') print(
f'>>> {color(advertisement.address, address_color)} [{color(address_type_string, type_color)}]{address_qualifier}:{separator}RSSI:{advertisement.rssi}{separator}{advertisement.data.to_string(separator)}'
)
await device.power_on() await device.power_on()
await device.start_scanning(filter_duplicates=filter_duplicates) await device.start_scanning(filter_duplicates=filter_duplicates)
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main()) asyncio.run(main())
+6 -2
View File
@@ -48,8 +48,10 @@ install_requires =
[options.entry_points] [options.entry_points]
console_scripts = console_scripts =
bumble-console = bumble.apps.console:main bumble-console = bumble.apps.console:main
bumble-controller-info = bumble.apps.controller_info:main
bumble-gatt-dump = bumble.apps.gatt_dump:main bumble-gatt-dump = bumble.apps.gatt_dump:main
bumble-hci-bridge = bumble.apps.hci_bridge:main bumble-hci-bridge = bumble.apps.hci_bridge:main
bumble-l2cap-bridge = bumble.apps.l2cap_bridge:main
bumble-pair = bumble.apps.pair:main bumble-pair = bumble.apps.pair:main
bumble-scan = bumble.apps.scan:main bumble-scan = bumble.apps.scan:main
bumble-show = bumble.apps.show:main bumble-show = bumble.apps.show:main
@@ -63,10 +65,12 @@ build =
test = test =
pytest >= 6.2 pytest >= 6.2
pytest-asyncio >= 0.17 pytest-asyncio >= 0.17
pytest-html >= 3.2.0
coverage >= 6.4
development = development =
invoke >= 1.4 invoke >= 1.4
nox >= 2022 nox >= 2022
documentation = documentation =
mkdocs >= 1.2.3 mkdocs >= 1.4.0
mkdocs-material >= 8.1.9 mkdocs-material >= 8.5.6
mkdocstrings[python] >= 0.19.0 mkdocstrings[python] >= 0.19.0
+1
View File
@@ -13,4 +13,5 @@
# limitations under the License. # limitations under the License.
from setuptools import setup from setuptools import setup
setup() setup()
+18 -4
View File
@@ -27,6 +27,7 @@ ns = Collection()
build_tasks = Collection() build_tasks = Collection()
ns.add_collection(build_tasks, name="build") ns.add_collection(build_tasks, name="build")
@task @task
def build(ctx, install=False): def build(ctx, install=False):
if install: if install:
@@ -34,26 +35,32 @@ def build(ctx, install=False):
ctx.run("python -m build") ctx.run("python -m build")
build_tasks.add_task(build, default=True) build_tasks.add_task(build, default=True)
@task @task
def release_build(ctx): def release_build(ctx):
build(ctx, install=True) build(ctx, install=True)
build_tasks.add_task(release_build, name="release") build_tasks.add_task(release_build, name="release")
@task @task
def mkdocs(ctx): def mkdocs(ctx):
ctx.run("mkdocs build -f docs/mkdocs/mkdocs.yml") ctx.run("mkdocs build -f docs/mkdocs/mkdocs.yml")
build_tasks.add_task(mkdocs, name="mkdocs") build_tasks.add_task(mkdocs, name="mkdocs")
# Testing # Testing
test_tasks = Collection() test_tasks = Collection()
ns.add_collection(test_tasks, name="test") ns.add_collection(test_tasks, name="test")
@task
def test(ctx, filter=None, junit=False, install=False): @task(incrementable=["verbose"])
def test(ctx, filter=None, junit=False, install=False, html=False, verbose=0):
# Install the package before running the tests # Install the package before running the tests
if install: if install:
ctx.run("python -m pip install .[test]") ctx.run("python -m pip install .[test]")
@@ -62,13 +69,20 @@ def test(ctx, filter=None, junit=False, install=False):
if junit: if junit:
args += "--junit-xml test-results.xml" args += "--junit-xml test-results.xml"
if filter is not None: if filter is not None:
args += " -k '{}'".format(filter) args += f" -k '{filter}'"
ctx.run("python -m pytest {} {}".format(os.path.join(ROOT_DIR, "tests"), args)) if html:
args += " --html results.html"
if verbose > 0:
args += f" -{'v' * verbose}"
ctx.run(f"python -m pytest {os.path.join(ROOT_DIR, 'tests')} {args}")
test_tasks.add_task(test, default=True) test_tasks.add_task(test, default=True)
@task @task
def release_test(ctx): def release_test(ctx):
test(ctx, install=True) test(ctx, install=True)
test_tasks.add_task(release_test, name="release") test_tasks.add_task(release_test, name="release")
+57 -47
View File
@@ -35,7 +35,7 @@ from bumble.avdtp import (
MediaPacket, MediaPacket,
AVDTP_AUDIO_MEDIA_TYPE, AVDTP_AUDIO_MEDIA_TYPE,
AVDTP_TSEP_SNK, AVDTP_TSEP_SNK,
A2DP_SBC_CODEC_TYPE A2DP_SBC_CODEC_TYPE,
) )
from bumble.a2dp import ( from bumble.a2dp import (
SbcMediaCodecInformation, SbcMediaCodecInformation,
@@ -44,7 +44,7 @@ from bumble.a2dp import (
SBC_STEREO_CHANNEL_MODE, SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE, SBC_JOINT_STEREO_CHANNEL_MODE,
SBC_LOUDNESS_ALLOCATION_METHOD, SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_SNR_ALLOCATION_METHOD SBC_SNR_ALLOCATION_METHOD,
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -61,17 +61,17 @@ class TwoDevices:
self.link = LocalLink() self.link = LocalLink()
self.controllers = [ self.controllers = [
Controller('C1', link=self.link), Controller('C1', link=self.link),
Controller('C2', link = self.link) Controller('C2', link=self.link),
] ]
self.devices = [ self.devices = [
Device( Device(
address='F0:F1:F2:F3:F4:F5', address='F0:F1:F2:F3:F4:F5',
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0])) host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
), ),
Device( Device(
address='F5:F4:F3:F2:F1:F0', address='F5:F4:F3:F2:F1:F0',
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1])) host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
) ),
] ]
self.paired = [None, None] self.paired = [None, None]
@@ -87,8 +87,12 @@ async def test_self_connection():
two_devices = TwoDevices() two_devices = TwoDevices()
# Attach listeners # Attach listeners
two_devices.devices[0].on('connection', lambda connection: two_devices.on_connection(0, connection)) two_devices.devices[0].on(
two_devices.devices[1].on('connection', lambda connection: two_devices.on_connection(1, connection)) 'connection', lambda connection: two_devices.on_connection(0, connection)
)
two_devices.devices[1].on(
'connection', lambda connection: two_devices.on_connection(1, connection)
)
# Start # Start
await two_devices.devices[0].power_on() await two_devices.devices[0].power_on()
@@ -98,8 +102,8 @@ async def test_self_connection():
await two_devices.devices[0].connect(two_devices.devices[1].random_address) await two_devices.devices[0].connect(two_devices.devices[1].random_address)
# Check the post conditions # Check the post conditions
assert(two_devices.connections[0] is not None) assert two_devices.connections[0] is not None
assert(two_devices.connections[1] is not None) assert two_devices.connections[1] is not None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -114,8 +118,8 @@ def source_codec_capabilities():
subbands=8, subbands=8,
allocation_method=SBC_LOUDNESS_ALLOCATION_METHOD, allocation_method=SBC_LOUDNESS_ALLOCATION_METHOD,
minimum_bitpool_value=2, minimum_bitpool_value=2,
maximum_bitpool_value = 53 maximum_bitpool_value=53,
) ),
) )
@@ -130,14 +134,17 @@ def sink_codec_capabilities():
SBC_MONO_CHANNEL_MODE, SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE, SBC_DUAL_CHANNEL_MODE,
SBC_STEREO_CHANNEL_MODE, SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE SBC_JOINT_STEREO_CHANNEL_MODE,
], ],
block_lengths=[4, 8, 12, 16], block_lengths=[4, 8, 12, 16],
subbands=[4, 8], subbands=[4, 8],
allocation_methods = [SBC_LOUDNESS_ALLOCATION_METHOD, SBC_SNR_ALLOCATION_METHOD], allocation_methods=[
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_SNR_ALLOCATION_METHOD,
],
minimum_bitpool_value=2, minimum_bitpool_value=2,
maximum_bitpool_value = 53 maximum_bitpool_value=53,
) ),
) )
@@ -164,21 +171,25 @@ async def test_source_sink_1():
listener = Listener(Listener.create_registrar(two_devices.devices[1])) listener = Listener(Listener.create_registrar(two_devices.devices[1]))
listener.on('connection', on_avdtp_connection) listener.on('connection', on_avdtp_connection)
connection = await two_devices.devices[0].connect(two_devices.devices[1].random_address) connection = await two_devices.devices[0].connect(
two_devices.devices[1].random_address
)
client = await Protocol.connect(connection) client = await Protocol.connect(connection)
endpoints = await client.discover_remote_endpoints() endpoints = await client.discover_remote_endpoints()
assert(len(endpoints) == 1) assert len(endpoints) == 1
remote_sink = list(endpoints)[0] remote_sink = list(endpoints)[0]
assert(remote_sink.in_use == 0) assert remote_sink.in_use == 0
assert(remote_sink.media_type == AVDTP_AUDIO_MEDIA_TYPE) assert remote_sink.media_type == AVDTP_AUDIO_MEDIA_TYPE
assert(remote_sink.tsep == AVDTP_TSEP_SNK) assert remote_sink.tsep == AVDTP_TSEP_SNK
async def generate_packets(packet_count): async def generate_packets(packet_count):
sequence_number = 0 sequence_number = 0
timestamp = 0 timestamp = 0
for i in range(packet_count): for i in range(packet_count):
payload = bytes([sequence_number % 256]) payload = bytes([sequence_number % 256])
packet = MediaPacket(2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, payload) packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, payload
)
packet.timestamp_seconds = timestamp / 44100 packet.timestamp_seconds = timestamp / 44100
timestamp += 10 timestamp += 10
sequence_number += 1 sequence_number += 1
@@ -192,50 +203,49 @@ async def test_source_sink_1():
source = client.add_source(source_codec_capabilities(), pump) source = client.add_source(source_codec_capabilities(), pump)
stream = await client.create_stream(source, remote_sink) stream = await client.create_stream(source, remote_sink)
await stream.start() await stream.start()
assert(stream.state == AVDTP_STREAMING_STATE) assert stream.state == AVDTP_STREAMING_STATE
assert(stream.local_endpoint.in_use == 1) assert stream.local_endpoint.in_use == 1
assert(stream.rtp_channel is not None) assert stream.rtp_channel is not None
assert(sink.in_use == 1) assert sink.in_use == 1
assert(sink.stream is not None) assert sink.stream is not None
assert(sink.stream.state == AVDTP_STREAMING_STATE) assert sink.stream.state == AVDTP_STREAMING_STATE
await rtp_packets_fully_received await rtp_packets_fully_received
await stream.close() await stream.close()
assert(stream.rtp_channel is None) assert stream.rtp_channel is None
assert(source.in_use == 0) assert source.in_use == 0
assert(source.stream.state == AVDTP_IDLE_STATE) assert source.stream.state == AVDTP_IDLE_STATE
assert(sink.in_use == 0) assert sink.in_use == 0
assert(sink.stream.state == AVDTP_IDLE_STATE) assert sink.stream.state == AVDTP_IDLE_STATE
# Send packets manually # Send packets manually
rtp_packets_fully_received = asyncio.get_running_loop().create_future() rtp_packets_fully_received = asyncio.get_running_loop().create_future()
rtp_packets_expected = 3 rtp_packets_expected = 3
rtp_packets = [] rtp_packets = []
source_packets = [ source_packets = [
MediaPacket(2, 0, 0, 0, i, i * 10, 0, [], 96, bytes([i])) MediaPacket(2, 0, 0, 0, i, i * 10, 0, [], 96, bytes([i])) for i in range(3)
for i in range(3)
] ]
source = client.add_source(source_codec_capabilities(), None) source = client.add_source(source_codec_capabilities(), None)
stream = await client.create_stream(source, remote_sink) stream = await client.create_stream(source, remote_sink)
await stream.start() await stream.start()
assert(stream.state == AVDTP_STREAMING_STATE) assert stream.state == AVDTP_STREAMING_STATE
assert(stream.local_endpoint.in_use == 1) assert stream.local_endpoint.in_use == 1
assert(stream.rtp_channel is not None) assert stream.rtp_channel is not None
assert(sink.in_use == 1) assert sink.in_use == 1
assert(sink.stream is not None) assert sink.stream is not None
assert(sink.stream.state == AVDTP_STREAMING_STATE) assert sink.stream.state == AVDTP_STREAMING_STATE
stream.send_media_packet(source_packets[0]) stream.send_media_packet(source_packets[0])
stream.send_media_packet(source_packets[1]) stream.send_media_packet(source_packets[1])
stream.send_media_packet(source_packets[2]) stream.send_media_packet(source_packets[2])
await stream.close() await stream.close()
assert(stream.rtp_channel is None) assert stream.rtp_channel is None
assert(len(rtp_packets) == 3) assert len(rtp_packets) == 3
assert(source.in_use == 0) assert source.in_use == 0
assert(source.stream.state == AVDTP_IDLE_STATE) assert source.stream.state == AVDTP_IDLE_STATE
assert(sink.in_use == 0) assert sink.in_use == 0
assert(sink.stream.state == AVDTP_IDLE_STATE) assert sink.stream.state == AVDTP_IDLE_STATE
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+11 -7
View File
@@ -28,7 +28,7 @@ from bumble.avdtp import (
Set_Configuration_Command, Set_Configuration_Command,
Set_Configuration_Response, Set_Configuration_Response,
ServiceCapabilities, ServiceCapabilities,
MediaCodecCapabilities MediaCodecCapabilities,
) )
@@ -39,22 +39,26 @@ def test_messages():
MediaCodecCapabilities( MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE, media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE, media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information = bytes.fromhex('211502fa') media_codec_information=bytes.fromhex('211502fa'),
), ),
ServiceCapabilities(AVDTP_DELAY_REPORTING_SERVICE_CATEGORY) ServiceCapabilities(AVDTP_DELAY_REPORTING_SERVICE_CATEGORY),
] ]
message = Get_Capabilities_Response(capabilities) message = Get_Capabilities_Response(capabilities)
parsed = Message.create(AVDTP_GET_CAPABILITIES, Message.RESPONSE_ACCEPT, message.payload) parsed = Message.create(
assert(message.payload == parsed.payload) AVDTP_GET_CAPABILITIES, Message.RESPONSE_ACCEPT, message.payload
)
assert message.payload == parsed.payload
message = Set_Configuration_Command(3, 4, capabilities) message = Set_Configuration_Command(3, 4, capabilities)
parsed = Message.create(AVDTP_SET_CONFIGURATION, Message.COMMAND, message.payload) parsed = Message.create(AVDTP_SET_CONFIGURATION, Message.COMMAND, message.payload)
assert(message.payload == parsed.payload) assert message.payload == parsed.payload
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_rtp(): def test_rtp():
packet = bytes.fromhex('8060000103141c6a000000000a9cbd2adbfe75443333542210037eeeed5f76dfbbbb57ddb890eed5f76e2ad3958613d3d04a5f596fc2b54d613a6a95570b4b49c2d0955ac710ca6abb293bb4580d5896b106cd6a7c4b557d8bb73aac56b8e633aa161447caa86585ae4cbc9576cc9cbd2a54fe7443322064221000b44a5cd51929bc96328916b1694e1f3611d6b6928dbf554b01e96d23a6ad879834d99326a649b94ca6adbeab1311e372a3aa3468e9582d2d9c857da28e5b76a2d363089367432930a0160af22d48911bc46cea549cbd2a03fe754332206532210054cf1d3d9260d3bc9895566f124b22c4b3cb6bc66648cf9b21e1613a48b3592466e90cee3424cc6cc56d2f569b12145234c6bd73560c95ad9c584c9d6c26552cea9905da55b3eab182c40e2dae64b46c328ba64d9cbd2a3cde74433220643211001e8d1ad6210d5c26b296d40d298a29b073b46bb4542ceb1aea011612c6df64c731068d49b56bb48afb2456ea9b5903222bb63b8b1a60c52896325a22aad781486cdb36269d9dc6dd38d9acf5b0e9328e0b23542c9cbd2adffe744323206432200095731b2a62604accea58da8ee6aba6d6fc9169ab66a824527412a66ac6c5c41d12c85295673c3263848c88ae934f62619c46ed2adccaaeb3eac70c396bb28cb8cecaf22423c548cd4adca92d30d1370ba34a772d9cbd2a3efe6442221064322100cc932cd12222dcd854d6da8d09330d2708b392a3997ec8a2f30b9312b8c562d9353513eda7733c4b835176eeca695909cc10d08614574d36cac669c583e68d9778daca9b92d6e4bb5cd008ef3562aa52332bc54a9cbd2a1efe6443332064322100a6e91a6ddc58a3a4b966a3452cb6d0b9c5334d2b695929128dcd6123b8b366d491122fd545f9b96cf769d530d2e2646b15c6a43695b12d33aa214e622e45b1ac132309a39eddc82caad35115b3d2350c5c6dcd749cbd2a9c7e654332207433110086ed5b68531a54c6e7bb052d15add1b204bd62568d8922d3379418b9c4e202482909ab712a744d81f392fa94193d62293ac6dfa7278f79b451c70c3b4b2b64d70f0b3463323c46f598ecd70d35e5a743282307099cbd2ae9fe654332106432110082acdb4aca734b843b6699f491ad3a511aab6db2344eeed386d0aa34c49c4b0a4b2aa59ec98bba6419b06310d2f9626c42a7466728f0ca0f1db579b46c0a701264e59153535228dc6497492dac722596138bd74a9cbd2a0b7e655432107432110056a8d22a62d643b428e513b52ea4a66c7a41991719370c8d9664ce2bca685dd2690b1c368c5dce36d26b38d10e0c672343ca8c25c58d0d5c568de433b7561c61268aaf83260b4b868dca8ee6dc6ba573abcb5093') packet = bytes.fromhex(
'8060000103141c6a000000000a9cbd2adbfe75443333542210037eeeed5f76dfbbbb57ddb890eed5f76e2ad3958613d3d04a5f596fc2b54d613a6a95570b4b49c2d0955ac710ca6abb293bb4580d5896b106cd6a7c4b557d8bb73aac56b8e633aa161447caa86585ae4cbc9576cc9cbd2a54fe7443322064221000b44a5cd51929bc96328916b1694e1f3611d6b6928dbf554b01e96d23a6ad879834d99326a649b94ca6adbeab1311e372a3aa3468e9582d2d9c857da28e5b76a2d363089367432930a0160af22d48911bc46cea549cbd2a03fe754332206532210054cf1d3d9260d3bc9895566f124b22c4b3cb6bc66648cf9b21e1613a48b3592466e90cee3424cc6cc56d2f569b12145234c6bd73560c95ad9c584c9d6c26552cea9905da55b3eab182c40e2dae64b46c328ba64d9cbd2a3cde74433220643211001e8d1ad6210d5c26b296d40d298a29b073b46bb4542ceb1aea011612c6df64c731068d49b56bb48afb2456ea9b5903222bb63b8b1a60c52896325a22aad781486cdb36269d9dc6dd38d9acf5b0e9328e0b23542c9cbd2adffe744323206432200095731b2a62604accea58da8ee6aba6d6fc9169ab66a824527412a66ac6c5c41d12c85295673c3263848c88ae934f62619c46ed2adccaaeb3eac70c396bb28cb8cecaf22423c548cd4adca92d30d1370ba34a772d9cbd2a3efe6442221064322100cc932cd12222dcd854d6da8d09330d2708b392a3997ec8a2f30b9312b8c562d9353513eda7733c4b835176eeca695909cc10d08614574d36cac669c583e68d9778daca9b92d6e4bb5cd008ef3562aa52332bc54a9cbd2a1efe6443332064322100a6e91a6ddc58a3a4b966a3452cb6d0b9c5334d2b695929128dcd6123b8b366d491122fd545f9b96cf769d530d2e2646b15c6a43695b12d33aa214e622e45b1ac132309a39eddc82caad35115b3d2350c5c6dcd749cbd2a9c7e654332207433110086ed5b68531a54c6e7bb052d15add1b204bd62568d8922d3379418b9c4e202482909ab712a744d81f392fa94193d62293ac6dfa7278f79b451c70c3b4b2b64d70f0b3463323c46f598ecd70d35e5a743282307099cbd2ae9fe654332106432110082acdb4aca734b843b6699f491ad3a511aab6db2344eeed386d0aa34c49c4b0a4b2aa59ec98bba6419b06310d2f9626c42a7466728f0ca0f1db579b46c0a701264e59153535228dc6497492dac722596138bd74a9cbd2a0b7e655432107432110056a8d22a62d643b428e513b52ea4a66c7a41991719370c8d9664ce2bca685dd2690b1c368c5dce36d26b38d10e0c672343ca8c25c58d0d5c568de433b7561c61268aaf83260b4b868dca8ee6dc6ba573abcb5093'
)
media_packet = MediaPacket.from_bytes(packet) media_packet = MediaPacket.from_bytes(packet)
print(media_packet) print(media_packet)
+24 -12
View File
@@ -15,28 +15,40 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from bumble.core import AdvertisingData from bumble.core import AdvertisingData, get_dict_key_by_value
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_ad_data(): def test_ad_data():
data = bytes([2, AdvertisingData.TX_POWER_LEVEL, 123]) data = bytes([2, AdvertisingData.TX_POWER_LEVEL, 123])
ad = AdvertisingData.from_bytes(data) ad = AdvertisingData.from_bytes(data)
ad_bytes = bytes(ad) ad_bytes = bytes(ad)
assert(data == ad_bytes) assert data == ad_bytes
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME) is None) assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None
assert(ad.get(AdvertisingData.TX_POWER_LEVEL) == bytes([123])) assert ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123])
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True) == []) assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == []
assert(ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True) == [bytes([123])]) assert ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [
bytes([123])
]
data2 = bytes([2, AdvertisingData.TX_POWER_LEVEL, 234]) data2 = bytes([2, AdvertisingData.TX_POWER_LEVEL, 234])
ad.append(data2) ad.append(data2)
ad_bytes = bytes(ad) ad_bytes = bytes(ad)
assert(ad_bytes == data + data2) assert ad_bytes == data + data2
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME) is None) assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) is None
assert(ad.get(AdvertisingData.TX_POWER_LEVEL) == bytes([123])) assert ad.get(AdvertisingData.TX_POWER_LEVEL, raw=True) == bytes([123])
assert(ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True) == []) assert ad.get(AdvertisingData.COMPLETE_LOCAL_NAME, return_all=True, raw=True) == []
assert(ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True) == [bytes([123]), bytes([234])]) assert ad.get(AdvertisingData.TX_POWER_LEVEL, return_all=True, raw=True) == [
bytes([123]),
bytes([234]),
]
# -----------------------------------------------------------------------------
def test_get_dict_key_by_value():
dictionary = {"A": 1, "B": 2}
assert get_dict_key_by_value(dictionary, 1) == "A"
assert get_dict_key_by_value(dictionary, 2) == "B"
assert get_dict_key_by_value(dictionary, 3) is None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+254
View File
@@ -0,0 +1,254 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
from types import LambdaType
import pytest
from bumble.core import BT_BR_EDR_TRANSPORT
from bumble.device import Connection, Device
from bumble.host import Host
from bumble.hci import (
HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
HCI_COMMAND_STATUS_PENDING,
HCI_CREATE_CONNECTION_COMMAND,
HCI_SUCCESS,
Address,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_Connection_Complete_Event,
HCI_Connection_Request_Event,
HCI_Packet,
)
from bumble.gatt import (
GATT_GENERIC_ACCESS_SERVICE,
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC,
)
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class Sink:
def __init__(self, flow):
self.flow = flow
next(self.flow)
def on_packet(self, packet):
self.flow.send(packet)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_device_connect_parallel():
d0 = Device(host=Host(None, None))
d1 = Device(host=Host(None, None))
d2 = Device(host=Host(None, None))
# enable classic
d0.classic_enabled = True
d1.classic_enabled = True
d2.classic_enabled = True
# set public addresses
d0.public_address = Address(
'F0:F1:F2:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
)
d1.public_address = Address(
'F5:F4:F3:F2:F1:F0', address_type=Address.PUBLIC_DEVICE_ADDRESS
)
d2.public_address = Address(
'F5:F4:F3:F3:F4:F5', address_type=Address.PUBLIC_DEVICE_ADDRESS
)
def d0_flow():
packet = HCI_Packet.from_bytes((yield))
assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
assert packet.bd_addr == d1.public_address
d0.host.on_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=HCI_CREATE_CONNECTION_COMMAND,
)
)
d1.host.on_hci_packet(
HCI_Connection_Request_Event(
bd_addr=d0.public_address,
class_of_device=0,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
)
)
packet = HCI_Packet.from_bytes((yield))
assert packet.name == 'HCI_CREATE_CONNECTION_COMMAND'
assert packet.bd_addr == d2.public_address
d0.host.on_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=HCI_CREATE_CONNECTION_COMMAND,
)
)
d2.host.on_hci_packet(
HCI_Connection_Request_Event(
bd_addr=d0.public_address,
class_of_device=0,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
)
)
assert (yield) == None
def d1_flow():
packet = HCI_Packet.from_bytes((yield))
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d1.host.on_hci_packet(
HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters=b"\x00",
)
)
d1.host.on_hci_packet(
HCI_Connection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=0x100,
bd_addr=d0.public_address,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled=True,
)
)
d0.host.on_hci_packet(
HCI_Connection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=0x100,
bd_addr=d1.public_address,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled=True,
)
)
assert (yield) == None
def d2_flow():
packet = HCI_Packet.from_bytes((yield))
assert packet.name == 'HCI_ACCEPT_CONNECTION_REQUEST_COMMAND'
d2.host.on_hci_packet(
HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
return_parameters=b"\x00",
)
)
d2.host.on_hci_packet(
HCI_Connection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=0x101,
bd_addr=d0.public_address,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled=True,
)
)
d0.host.on_hci_packet(
HCI_Connection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=0x101,
bd_addr=d2.public_address,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
encryption_enabled=True,
)
)
assert (yield) == None
d0.host.set_packet_sink(Sink(d0_flow()))
d1.host.set_packet_sink(Sink(d1_flow()))
d2.host.set_packet_sink(Sink(d2_flow()))
[c01, c02, a10, a20, a01] = await asyncio.gather(
*[
asyncio.create_task(
d0.connect(d1.public_address, transport=BT_BR_EDR_TRANSPORT)
),
asyncio.create_task(
d0.connect(d2.public_address, transport=BT_BR_EDR_TRANSPORT)
),
asyncio.create_task(d1.accept(peer_address=d0.public_address)),
asyncio.create_task(d2.accept()),
asyncio.create_task(d0.accept(peer_address=d1.public_address)),
]
)
assert type(c01) == Connection
assert type(c02) == Connection
assert type(a10) == Connection
assert type(a20) == Connection
assert type(a01) == Connection
assert c01.handle == a10.handle and c01.handle == 0x100
assert c02.handle == a20.handle and c02.handle == 0x101
assert a01 == c01
# -----------------------------------------------------------------------------
async def run_test_device():
await test_device_connect_parallel()
# -----------------------------------------------------------------------------
def test_gatt_services_with_gas():
device = Device(host=Host(None, None))
# there should be one service and two chars, therefore 5 attributes
assert len(device.gatt_server.attributes) == 5
assert device.gatt_server.attributes[0].uuid == GATT_GENERIC_ACCESS_SERVICE
assert device.gatt_server.attributes[1].type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
assert device.gatt_server.attributes[2].uuid == GATT_DEVICE_NAME_CHARACTERISTIC
assert device.gatt_server.attributes[3].type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
assert device.gatt_server.attributes[4].uuid == GATT_APPEARANCE_CHARACTERISTIC
# -----------------------------------------------------------------------------
def test_gatt_services_without_gas():
device = Device(host=Host(None, None), generic_access_service=False)
# there should be no services
assert len(device.gatt_server.attributes) == 0
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run_test_device())
+243 -63
View File
@@ -28,6 +28,7 @@ from bumble.device import Device, Peer
from bumble.host import Host from bumble.host import Host
from bumble.gatt import ( from bumble.gatt import (
GATT_BATTERY_LEVEL_CHARACTERISTIC, GATT_BATTERY_LEVEL_CHARACTERISTIC,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
CharacteristicAdapter, CharacteristicAdapter,
DelegatedCharacteristicAdapter, DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter, PackedCharacteristicAdapter,
@@ -35,7 +36,7 @@ from bumble.gatt import (
UTF8CharacteristicAdapter, UTF8CharacteristicAdapter,
Service, Service,
Characteristic, Characteristic,
CharacteristicValue CharacteristicValue,
) )
from bumble.transport import AsyncPipeSink from bumble.transport import AsyncPipeSink
from bumble.core import UUID from bumble.core import UUID
@@ -44,7 +45,7 @@ from bumble.att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR, ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_PDU, ATT_PDU,
ATT_Error_Response, ATT_Error_Response,
ATT_Read_By_Group_Type_Request ATT_Read_By_Group_Type_Request,
) )
@@ -75,7 +76,7 @@ def test_UUID():
u2 = UUID.from_bytes(b1) u2 = UUID.from_bytes(b1)
assert u1 == u2 assert u1 == u2
u3 = UUID.from_16_bits(0x180a) u3 = UUID.from_16_bits(0x180A)
assert str(u3) == 'UUID-16:180A (Device Information)' assert str(u3) == 'UUID-16:180A (Device Information)'
@@ -84,7 +85,7 @@ def test_ATT_Error_Response():
pdu = ATT_Error_Response( pdu = ATT_Error_Response(
request_opcode_in_error=ATT_EXCHANGE_MTU_REQUEST, request_opcode_in_error=ATT_EXCHANGE_MTU_REQUEST,
attribute_handle_in_error=0x0000, attribute_handle_in_error=0x0000,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
) )
basic_check(pdu) basic_check(pdu)
@@ -94,7 +95,7 @@ def test_ATT_Read_By_Group_Type_Request():
pdu = ATT_Read_By_Group_Type_Request( pdu = ATT_Read_By_Group_Type_Request(
starting_handle=0x0001, starting_handle=0x0001,
ending_handle=0xFFFF, ending_handle=0xFFFF,
attribute_group_type = UUID.from_16_bits(0x2800) attribute_group_type=UUID.from_16_bits(0x2800),
) )
basic_check(pdu) basic_check(pdu)
@@ -109,7 +110,12 @@ async def test_characteristic_encoding():
def decode_value(self, value_bytes): def decode_value(self, value_bytes):
return value_bytes[0] return value_bytes[0]
c = Foo(GATT_BATTERY_LEVEL_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, 123) c = Foo(
GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
123,
)
x = c.read_value(None) x = c.read_value(None)
assert x == bytes([123]) assert x == bytes([123])
c.write_value(None, bytes([122])) c.write_value(None, bytes([122]))
@@ -122,7 +128,7 @@ async def test_characteristic_encoding():
characteristic.handle, characteristic.handle,
characteristic.end_group_handle, characteristic.end_group_handle,
characteristic.uuid, characteristic.uuid,
characteristic.properties characteristic.properties,
) )
def encode_value(self, value): def encode_value(self, value):
@@ -137,13 +143,10 @@ async def test_characteristic_encoding():
'FDB159DB-036C-49E3-B3DB-6325AC750806', 'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY, Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY,
Characteristic.READABLE | Characteristic.WRITEABLE, Characteristic.READABLE | Characteristic.WRITEABLE,
bytes([123]) bytes([123]),
) )
service = Service( service = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic])
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
[characteristic]
)
server.add_service(service) server.add_service(service)
await client.power_on() await client.power_on()
@@ -164,6 +167,17 @@ async def test_characteristic_encoding():
await async_barrier() await async_barrier()
assert characteristic.value == bytes([124]) assert characteristic.value == bytes([124])
v = await cp.read_value()
assert v == 124
await cp.write_value(125, with_response=True)
await async_barrier()
assert characteristic.value == bytes([125])
cd = DelegatedCharacteristicAdapter(c, encode=lambda x: bytes([x // 2]))
await cd.write_value(100, with_response=True)
await async_barrier()
assert characteristic.value == bytes([50])
last_change = None last_change = None
def on_change(value): def on_change(value):
@@ -215,11 +229,63 @@ async def test_characteristic_encoding():
assert last_change is None assert last_change is None
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_attribute_getters():
[client, server] = LinkedDevices().devices[:2]
characteristic_uuid = UUID('FDB159DB-036C-49E3-B3DB-6325AC750806')
characteristic = Characteristic(
characteristic_uuid,
Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY,
Characteristic.READABLE | Characteristic.WRITEABLE,
bytes([123]),
)
service_uuid = UUID('3A657F47-D34F-46B3-B1EC-698E29B6B829')
service = Service(service_uuid, [characteristic])
server.add_service(service)
service_attr = server.gatt_server.get_service_attribute(service_uuid)
assert service_attr
(
char_decl_attr,
char_value_attr,
) = server.gatt_server.get_characteristic_attributes(
service_uuid, characteristic_uuid
)
assert char_decl_attr and char_value_attr
desc_attr = server.gatt_server.get_descriptor_attribute(
service_uuid,
characteristic_uuid,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
)
assert desc_attr
# assert all handles are in expected order
assert (
service_attr.handle
< char_decl_attr.handle
< char_value_attr.handle
< desc_attr.handle
== service_attr.end_group_handle
)
# assert characteristic declarations attribute is followed by characteristic value attribute
assert char_decl_attr.handle + 1 == char_value_attr.handle
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_CharacteristicAdapter(): def test_CharacteristicAdapter():
# Check that the CharacteristicAdapter base class is transparent # Check that the CharacteristicAdapter base class is transparent
v = bytes([1, 2, 3]) v = bytes([1, 2, 3])
c = Characteristic(GATT_BATTERY_LEVEL_CHARACTERISTIC, Characteristic.READ, Characteristic.READABLE, v) c = Characteristic(
GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
v,
)
a = CharacteristicAdapter(c) a = CharacteristicAdapter(c)
value = a.read_value(None) value = a.read_value(None)
@@ -230,7 +296,9 @@ def test_CharacteristicAdapter():
assert c.value == v assert c.value == v
# Simple delegated adapter # Simple delegated adapter
a = DelegatedCharacteristicAdapter(c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x))) a = DelegatedCharacteristicAdapter(
c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x))
)
value = a.read_value(None) value = a.read_value(None)
assert value == bytes(reversed(v)) assert value == bytes(reversed(v))
@@ -299,7 +367,9 @@ def test_CharacteristicValue():
assert x == b assert x == b
result = [] result = []
c = CharacteristicValue(write=lambda connection, value: result.append((connection, value))) c = CharacteristicValue(
write=lambda connection, value: result.append((connection, value))
)
z = object() z = object()
c.write(z, b) c.write(z, b)
assert result == [(z, b)] assert result == [(z, b)]
@@ -314,21 +384,21 @@ class LinkedDevices:
self.controllers = [ self.controllers = [
Controller('C1', link=self.link), Controller('C1', link=self.link),
Controller('C2', link=self.link), Controller('C2', link=self.link),
Controller('C3', link = self.link) Controller('C3', link=self.link),
] ]
self.devices = [ self.devices = [
Device( Device(
address='F0:F1:F2:F3:F4:F5', address='F0:F1:F2:F3:F4:F5',
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0])) host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
), ),
Device( Device(
address='F1:F2:F3:F4:F5:F6', address='F1:F2:F3:F4:F5:F6',
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1])) host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
), ),
Device( Device(
address='F2:F3:F4:F5:F6:F7', address='F2:F3:F4:F5:F6:F7',
host = Host(self.controllers[2], AsyncPipeSink(self.controllers[2])) host=Host(self.controllers[2], AsyncPipeSink(self.controllers[2])),
) ),
] ]
self.paired = [None, None, None] self.paired = [None, None, None]
@@ -349,7 +419,7 @@ async def test_read_write():
characteristic1 = Characteristic( characteristic1 = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806', 'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.READ | Characteristic.WRITE, Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE Characteristic.READABLE | Characteristic.WRITEABLE,
) )
def on_characteristic1_write(connection, value): def on_characteristic1_write(connection, value):
@@ -367,15 +437,13 @@ async def test_read_write():
'66DE9057-C848-4ACA-B993-D675644EBB85', '66DE9057-C848-4ACA-B993-D675644EBB85',
Characteristic.READ | Characteristic.WRITE, Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE, Characteristic.READABLE | Characteristic.WRITEABLE,
CharacteristicValue(read=on_characteristic2_read, write=on_characteristic2_write) CharacteristicValue(
read=on_characteristic2_read, write=on_characteristic2_write
),
) )
service1 = Service( service1 = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829', '3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic1, characteristic2]
[
characteristic1,
characteristic2
]
) )
server.add_services([service1]) server.add_services([service1])
@@ -403,7 +471,9 @@ async def test_read_write():
assert v1 == b assert v1 == b
assert type(characteristic1._last_value is tuple) assert type(characteristic1._last_value is tuple)
assert len(characteristic1._last_value) == 2 assert len(characteristic1._last_value) == 2
assert str(characteristic1._last_value[0].peer_address) == str(client.random_address) assert str(characteristic1._last_value[0].peer_address) == str(
client.random_address
)
assert characteristic1._last_value[1] == b assert characteristic1._last_value[1] == b
bb = bytes([3, 4, 5, 6]) bb = bytes([3, 4, 5, 6])
characteristic1.value = bb characteristic1.value = bb
@@ -414,7 +484,9 @@ async def test_read_write():
await async_barrier() await async_barrier()
assert type(characteristic2._last_value is tuple) assert type(characteristic2._last_value is tuple)
assert len(characteristic2._last_value) == 2 assert len(characteristic2._last_value) == 2
assert str(characteristic2._last_value[0].peer_address) == str(client.random_address) assert str(characteristic2._last_value[0].peer_address) == str(
client.random_address
)
assert characteristic2._last_value[1] == b assert characteristic2._last_value[1] == b
@@ -428,15 +500,10 @@ async def test_read_write2():
'FDB159DB-036C-49E3-B3DB-6325AC750806', 'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.READ | Characteristic.WRITE, Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE, Characteristic.READABLE | Characteristic.WRITEABLE,
value=v value=v,
) )
service1 = Service( service1 = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic1])
'3A657F47-D34F-46B3-B1EC-698E29B6B829',
[
characteristic1
]
)
server.add_services([service1]) server.add_services([service1])
await client.power_on() await client.power_on()
@@ -477,11 +544,15 @@ async def test_subscribe_notify():
'FDB159DB-036C-49E3-B3DB-6325AC750806', 'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.READ | Characteristic.NOTIFY, Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE, Characteristic.READABLE,
bytes([1, 2, 3]) bytes([1, 2, 3]),
) )
def on_characteristic1_subscription(connection, notify_enabled, indicate_enabled): def on_characteristic1_subscription(connection, notify_enabled, indicate_enabled):
characteristic1._last_subscription = (connection, notify_enabled, indicate_enabled) characteristic1._last_subscription = (
connection,
notify_enabled,
indicate_enabled,
)
characteristic1.on('subscription', on_characteristic1_subscription) characteristic1.on('subscription', on_characteristic1_subscription)
@@ -489,11 +560,15 @@ async def test_subscribe_notify():
'66DE9057-C848-4ACA-B993-D675644EBB85', '66DE9057-C848-4ACA-B993-D675644EBB85',
Characteristic.READ | Characteristic.INDICATE, Characteristic.READ | Characteristic.INDICATE,
Characteristic.READABLE, Characteristic.READABLE,
bytes([4, 5, 6]) bytes([4, 5, 6]),
) )
def on_characteristic2_subscription(connection, notify_enabled, indicate_enabled): def on_characteristic2_subscription(connection, notify_enabled, indicate_enabled):
characteristic2._last_subscription = (connection, notify_enabled, indicate_enabled) characteristic2._last_subscription = (
connection,
notify_enabled,
indicate_enabled,
)
characteristic2.on('subscription', on_characteristic2_subscription) characteristic2.on('subscription', on_characteristic2_subscription)
@@ -501,26 +576,33 @@ async def test_subscribe_notify():
'AB5E639C-40C1-4238-B9CB-AF41F8B806E4', 'AB5E639C-40C1-4238-B9CB-AF41F8B806E4',
Characteristic.READ | Characteristic.NOTIFY | Characteristic.INDICATE, Characteristic.READ | Characteristic.NOTIFY | Characteristic.INDICATE,
Characteristic.READABLE, Characteristic.READABLE,
bytes([7, 8, 9]) bytes([7, 8, 9]),
) )
def on_characteristic3_subscription(connection, notify_enabled, indicate_enabled): def on_characteristic3_subscription(connection, notify_enabled, indicate_enabled):
characteristic3._last_subscription = (connection, notify_enabled, indicate_enabled) characteristic3._last_subscription = (
connection,
notify_enabled,
indicate_enabled,
)
characteristic3.on('subscription', on_characteristic3_subscription) characteristic3.on('subscription', on_characteristic3_subscription)
service1 = Service( service1 = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829', '3A657F47-D34F-46B3-B1EC-698E29B6B829',
[ [characteristic1, characteristic2, characteristic3],
characteristic1,
characteristic2,
characteristic3
]
) )
server.add_services([service1]) server.add_services([service1])
def on_characteristic_subscription(connection, characteristic, notify_enabled, indicate_enabled): def on_characteristic_subscription(
server._last_subscription = (connection, characteristic, notify_enabled, indicate_enabled) connection, characteristic, notify_enabled, indicate_enabled
):
server._last_subscription = (
connection,
characteristic,
notify_enabled,
indicate_enabled,
)
server.on('characteristic_subscription', on_characteristic_subscription) server.on('characteristic_subscription', on_characteristic_subscription)
@@ -587,53 +669,89 @@ async def test_subscribe_notify():
await peer.subscribe(c2, on_c2_update) await peer.subscribe(c2, on_c2_update)
await async_barrier() await async_barrier()
await server.notify_subscriber(characteristic2._last_subscription[0], characteristic2) await server.notify_subscriber(
characteristic2._last_subscription[0], characteristic2
)
await async_barrier() await async_barrier()
assert not c2._called assert not c2._called
await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2) await server.indicate_subscriber(
characteristic2._last_subscription[0], characteristic2
)
await async_barrier() await async_barrier()
assert c2._called assert c2._called
assert c2._last_update == characteristic2.value assert c2._last_update == characteristic2.value
c2._called = False c2._called = False
await peer.unsubscribe(c2, on_c2_update) await peer.unsubscribe(c2, on_c2_update)
await server.indicate_subscriber(characteristic2._last_subscription[0], characteristic2) await server.indicate_subscriber(
characteristic2._last_subscription[0], characteristic2
)
await async_barrier() await async_barrier()
assert not c2._called assert not c2._called
c3._called = False
c3._called_2 = False
c3._called_3 = False
c3._last_update = None
c3._last_update_2 = None
c3._last_update_3 = None
def on_c3_update(value): def on_c3_update(value):
c3._called = True c3._called = True
c3._last_update = value c3._last_update = value
def on_c3_update_2(value): def on_c3_update_2(value): # for notify
c3._called_2 = True c3._called_2 = True
c3._last_update_2 = value c3._last_update_2 = value
def on_c3_update_3(value): # for indicate
c3._called_3 = True
c3._last_update_3 = value
c3.on('update', on_c3_update) c3.on('update', on_c3_update)
await peer.subscribe(c3, on_c3_update_2) await peer.subscribe(c3, on_c3_update_2)
await async_barrier() await async_barrier()
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3) await server.notify_subscriber(
await async_barrier() characteristic3._last_subscription[0], characteristic3
assert c3._called )
assert c3._last_update == characteristic3.value
assert c3._called_2
assert c3._last_update_2 == characteristic3.value
characteristic3.value = bytes([1, 2, 3])
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3)
await async_barrier() await async_barrier()
assert c3._called assert c3._called
assert c3._last_update == characteristic3.value assert c3._last_update == characteristic3.value
assert c3._called_2 assert c3._called_2
assert c3._last_update_2 == characteristic3.value assert c3._last_update_2 == characteristic3.value
assert not c3._called_3
c3._called = False c3._called = False
c3._called_2 = False c3._called_2 = False
c3._called_3 = False
await peer.unsubscribe(c3) await peer.unsubscribe(c3)
await server.notify_subscriber(characteristic3._last_subscription[0], characteristic3) await peer.subscribe(c3, on_c3_update_3, prefer_notify=False)
await server.indicate_subscriber(characteristic3._last_subscription[0], characteristic3) await async_barrier()
characteristic3.value = bytes([1, 2, 3])
await server.indicate_subscriber(
characteristic3._last_subscription[0], characteristic3
)
await async_barrier()
assert c3._called
assert c3._last_update == characteristic3.value
assert not c3._called_2
assert c3._called_3
assert c3._last_update_3 == characteristic3.value
c3._called = False
c3._called_2 = False
c3._called_3 = False
await peer.unsubscribe(c3)
await server.notify_subscriber(
characteristic3._last_subscription[0], characteristic3
)
await server.indicate_subscriber(
characteristic3._last_subscription[0], characteristic3
)
await async_barrier() await async_barrier()
assert not c3._called assert not c3._called
assert not c3._called_2 assert not c3._called_2
assert not c3._called_3
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -644,6 +762,7 @@ async def test_mtu_exchange():
d3.gatt_server.max_mtu = 100 d3.gatt_server.max_mtu = 100
d3_connections = [] d3_connections = []
@d3.on('connection') @d3.on('connection')
def on_d3_connection(connection): def on_d3_connection(connection):
d3_connections.append(connection) d3_connections.append(connection)
@@ -672,6 +791,67 @@ async def test_mtu_exchange():
assert d2_connection.att_mtu == 50 assert d2_connection.att_mtu == 50
# -----------------------------------------------------------------------------
def test_char_property_to_string():
# single
assert Characteristic.property_name(0x01) == "BROADCAST"
assert Characteristic.property_name(Characteristic.BROADCAST) == "BROADCAST"
# double
assert Characteristic.properties_as_string(0x03) == "BROADCAST,READ"
assert (
Characteristic.properties_as_string(
Characteristic.BROADCAST | Characteristic.READ
)
== "BROADCAST,READ"
)
# -----------------------------------------------------------------------------
def test_char_property_string_to_type():
# single
assert Characteristic.string_to_properties("BROADCAST") == Characteristic.BROADCAST
# double
assert (
Characteristic.string_to_properties("BROADCAST,READ")
== Characteristic.BROADCAST | Characteristic.READ
)
assert (
Characteristic.string_to_properties("READ,BROADCAST")
== Characteristic.BROADCAST | Characteristic.READ
)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_server_string():
[_, server] = LinkedDevices().devices[:2]
characteristic = Characteristic(
'FDB159DB-036C-49E3-B3DB-6325AC750806',
Characteristic.READ | Characteristic.WRITE | Characteristic.NOTIFY,
Characteristic.READABLE | Characteristic.WRITEABLE,
bytes([123]),
)
service = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic])
server.add_service(service)
assert (
str(server.gatt_server)
== """Service(handle=0x0001, end=0x0005, uuid=UUID-16:1800 (Generic Access))
CharacteristicDeclaration(handle=0x0002, value_handle=0x0003, uuid=UUID-16:2A00 (Device Name), properties=READ)
Characteristic(handle=0x0003, end=0x0003, uuid=UUID-16:2A00 (Device Name), properties=READ)
CharacteristicDeclaration(handle=0x0004, value_handle=0x0005, uuid=UUID-16:2A01 (Appearance), properties=READ)
Characteristic(handle=0x0005, end=0x0005, uuid=UUID-16:2A01 (Appearance), properties=READ)
Service(handle=0x0006, end=0x0009, uuid=3A657F47-D34F-46B3-B1EC-698E29B6B829)
CharacteristicDeclaration(handle=0x0007, value_handle=0x0008, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, properties=READ,WRITE,NOTIFY)
Characteristic(handle=0x0008, end=0x0009, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, properties=READ,WRITE,NOTIFY)
Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration), value=0000)"""
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_main(): async def async_main():
await test_read_write() await test_read_write()

Some files were not shown because too many files have changed in this diff Show More