Compare commits

...

76 Commits

Author SHA1 Message Date
Lucas Abel
9a9b4e5bf1 Merge pull request #117 from google/uael/host-fixes
host: fixed `.latency` attribute error
2023-01-27 17:38:11 -08:00
Abel Lucas
895f1618d8 host: fixed .latency attribute error 2023-01-27 23:05:43 +00:00
Gilles Boccon-Gibod
52746e0c68 Merge pull request #116 from google/barbibulle-patch-1
fix libusb-package dependency
2023-01-25 15:59:42 -08:00
Gilles Boccon-Gibod
f9b7072423 Update setup.cfg 2023-01-25 15:37:33 -08:00
Gilles Boccon-Gibod
fa4be1958f Merge pull request #114 from google/gbg/fix-constant-typo
fix typo in constant name
2023-01-23 08:50:07 -08:00
Gilles Boccon-Gibod
f1686d8a9a fix typo in constant name 2023-01-22 19:10:13 -08:00
Gilles Boccon-Gibod
5c6a7f2036 Merge pull request #113 from google/gbg/mypy
add basic support for mypy type checking
2023-01-20 08:08:19 -08:00
Gilles Boccon-Gibod
99758e4b7d add basic support for mypy type checking 2023-01-20 00:20:50 -08:00
Alan Rosenthal
7385de6a69 Merge pull request #95 from AlanRosenthal/alan/fix_show_attributes
Fix `show attributes`
2023-01-19 14:57:22 -05:00
Alan Rosenthal
bb297e7516 Fix show attributes
`show attributes` wasn't being populated since `show_attributes()` was never called.

Also updated `show attributes` to match the color and indentation of `show services`
2023-01-19 12:21:37 -05:00
Lucas Abel
8a91c614c7 Merge pull request #109 from qiaoccolato/main
transport: make libusb_package optional
2023-01-18 14:48:05 -08:00
Qiao Yang
70a50a74b7 transport: make libusb_package optional 2023-01-17 15:17:11 -08:00
Gilles Boccon-Gibod
6a16c61c5f Merge pull request #111 from google/gbg/fix-null-address-setting
don't set a random address when it is 00:00:00:00:00:00
2023-01-13 21:35:32 -08:00
Gilles Boccon-Gibod
0a22f2f7c7 use HCI_LE_Rand 2023-01-13 16:59:34 -08:00
Gilles Boccon-Gibod
422b05ad51 don't set a random address when it is 00:00:00:00:00:00 2023-01-13 13:22:27 -08:00
Gilles Boccon-Gibod
16e926a216 Merge pull request #107 from yuyangh/yuyangh/add_ASHA_L2CAP
add ASHA L2CAP and Event Emitter
2023-01-13 11:05:16 -08:00
Gilles Boccon-Gibod
e94dc66d0c Merge pull request #110 from aleksandrovrts/hci-socket_fix
Fix bug when use hci-socket transport
2023-01-11 09:35:23 -08:00
Aleksandr Aleksandrov
e37c77532b hci_socket.py: fix socket.fileno() call 2023-01-11 16:16:45 +03:00
Gilles Boccon-Gibod
8b9ce03e86 Merge pull request #108 from google/gbg/fix-bluez-vhci
support more commands in controller.py
2023-01-08 14:40:26 -08:00
Gilles Boccon-Gibod
7e854efbbb support more commands in controller.py 2023-01-06 21:51:47 -08:00
Yuyang Huang
64b75be29b add psm parameter for testing support 2023-01-03 16:39:45 -08:00
Yuyang Huang
06018211fe emit event for ASHA l2cap packet 2023-01-03 15:01:32 -08:00
Yuyang Huang
e640991608 Merge branch 'google:main' into yuyangh/add_ASHA_L2CAP 2023-01-03 14:58:37 -08:00
Yuyang Huang
1068a6858d improve logging 2022-12-20 13:33:18 -08:00
Lucas Abel
17db5dd4ff Merge pull request #103 from google/uael/device-fixes
Misc device fixes
2022-12-20 12:15:49 -08:00
Abel Lucas
ea0a7e2347 device: commit LE connection **before** reading it's PHY 2022-12-20 19:25:43 +00:00
Yuyang Huang
6febd1ba35 add L2CAP CoC to ASHA 2022-12-20 11:15:58 -08:00
Gilles Boccon-Gibod
ea6a8d4339 Merge pull request #104 from google/gbg/fix-windll-load
fix libusb loading on Windows
2022-12-20 08:05:57 -08:00
Abel Lucas
ce049865a4 device: always prefer R2 for remote name request 2022-12-20 01:48:08 +00:00
Gilles Boccon-Gibod
6e0129b71d fix libusb loading on Windows 2022-12-18 22:00:26 -08:00
Gilles Boccon-Gibod
7ae3a1d973 Merge pull request #101 from google/gbg/formatting-linting-automation
formatting linting automation
2022-12-16 19:39:28 -08:00
Gilles Boccon-Gibod
c2959dadb4 formatting and linting automation
Squashed commits:
[cd479ba] formatting and linting automation
[7fbfabb] formatting and linting automation
[c4f9505] fix after rebase
[f506ad4] rename job
[441d517] update doc (+7 squashed commits)
[2e1b416] fix invoke and github action
[6ae5bb4] doc for git blame
[44b5461] add GitHub action
[b07474f] add docs
[4cd9a6f] more linter fixes
[db71901] wip
[540dc88] wip
2022-12-15 23:07:17 -08:00
Michael Mogenson
80fe2ea422 Merge pull request #102 from mogenson/libusb_package
Load libusb-1.0 shared library from libusb_package wheel
2022-12-15 21:53:56 -05:00
Lucas Abel
08e6590a76 Merge pull request #88 from google/uael/abort_on_event
Host: spawn each asynchronous task with the right aliveness
2022-12-15 12:46:37 -08:00
Abel Lucas
f580ffcbc3 device: set as Secure Connection when encrypted with AES 2022-12-15 17:02:21 +00:00
Abel Lucas
5178c866ac classic: add to .encrypt the possibilty to disable encryption 2022-12-15 17:02:21 +00:00
Abel Lucas
441933bd64 reverted: 662704e "classic: complete authentication when being the .authenticate acceptor" 2022-12-15 17:02:21 +00:00
Abel Lucas
287df94090 host: spawn each asynchronous task with the right aliveness 2022-12-15 17:02:21 +00:00
Michael Mogenson
86f9496575 Load libusb-1.0 shared library from libusb_package wheel
It would be nice to pip install bumble without having to first install
the libusb system dependency. Expecially on platforms like Windows and
Mac, without a default package manager.

The libusb_package Python package distributes prebuilt libusb-1.0 shared
libraries for each OS and architecture as binary wheels for the pyusb
project. Add this package as a dependency for bumble.

For the pyusb transport, the libusb_package.find() function is a drop-in
replacement for pyusb.core.find(). It searches the libusb_package
site-path before system paths and creates a pyusb backend.

For the usb transport, use libusb_package.get_library_path() to return a
path to the libusb-1.0 library in site-packages. If this path exists,
create a ctypes DLL and init the usb1 backend. This only needs to be
done once. All future calls to usb1 will use this opened library.
If the library path does not exist, do nothing, and usb1 will search
default system paths when the usb1.USBContext object is created.

This commit pins the libusb_package dependency at 1.0.26.0 to ensure
every bumble install uses the exact same version of the libusb library.
2022-12-15 10:22:02 -05:00
Gilles Boccon-Gibod
f5fe3d87f2 Merge pull request #98 from yuyangh/yuyangh/update_asha_advertising
update ASHA AdvertisingData
2022-12-12 15:23:47 -08:00
Yuyang Huang
f65bed2ec4 Merge branch 'main' into yuyangh/update_asha_advertising 2022-12-12 13:35:17 -08:00
Gilles Boccon-Gibod
3efe35065d Merge pull request #96 from google/gbg/black
format with Black
2022-12-12 13:27:58 -08:00
Yuyang Huang
83b42488ea update ASHA AdvertisingData
previously the ASHA AdvertisingData uses INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, now it lets user to define whether it is complete list or incomplete list
2022-12-12 11:42:48 -08:00
Gilles Boccon-Gibod
135df0dcc0 format with Black 2022-12-10 09:40:12 -08:00
Gilles Boccon-Gibod
8bef344879 Merge pull request #94 from AlanRosenthal/alan/bumble_version_in_show_device
Add bumble's version to `show device`
2022-12-10 09:00:30 -08:00
Alan Rosenthal
55e2f23e29 Add bumble's version to show device 2022-12-09 12:23:45 -05:00
Gilles Boccon-Gibod
297246fa4c Merge pull request #92 from yuyangh/yuyangh/add_ASHA_GATT
add ASHA profile
2022-12-07 13:16:01 -08:00
Yuyang Huang
52db1cfcc1 improve code style 2022-12-06 07:38:05 -08:00
Yuyang Huang
29f9a79502 improve get service advertising data 2022-12-05 11:22:07 -08:00
Gilles Boccon-Gibod
c86125de4f Merge pull request #93 from AlanRosenthal/alan/add_default_services
Add Device::add_default_services()
2022-12-01 12:36:03 -08:00
Yuyang Huang
697d5df3f8 code style update 2022-12-01 10:50:15 -08:00
Yuyang Huang
87aa4f617e add ASHA advertising factory method 2022-12-01 10:40:30 -08:00
Alan Rosenthal
a8eff737e6 Add Device::add_default_services()
This will allow a test to:
a: add services to a device
b: reset services via `Server()`
c: add the default services back
2022-12-01 17:02:54 +00:00
Gilles Boccon-Gibod
4417eb636c Merge pull request #83 from AlanRosenthal/alan/pytest_fixes_2
Test all python versions in CI
2022-11-29 12:48:35 -08:00
Alan Rosenthal
f4e5e61bbb Test all python versions in CI
Followed instructions here: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#using-the-python-starter-workflow
2022-11-29 20:22:56 +00:00
Lucas Abel
ba7a60025f Merge pull request #89 from google/uael/misc
Typos & CI fixes
2022-11-29 12:14:03 -08:00
Yuyang Huang
d92b7e9b74 add ASHA Profile 2022-11-29 10:34:59 -08:00
Yuyang Huang
b0336adf1c add ASHA GATT UUID 2022-11-29 10:02:40 -08:00
Abel Lucas
691450c7de gatt: fix CharacteristicDeclaration.__str__ and associated test 2022-11-29 16:43:47 +00:00
Abel Lucas
99a0eb21c1 address: fix deprecated use of combined @classmethod and @property 2022-11-29 16:33:12 +00:00
Abel Lucas
ab4859bd94 device: fix typos 2022-11-29 16:33:12 +00:00
Lucas Abel
0d70cbde64 Merge pull request #75 from google/uael/fixes
Pairing: device/host fixes & improvements
2022-11-28 21:42:43 -08:00
Gilles Boccon-Gibod
f41d0682b2 Merge pull request #80 from AlanRosenthal/alan/gatt_server_getter
Added class CharacteristicDeclaration, gatt_server getters
2022-11-28 19:21:08 -08:00
Gilles Boccon-Gibod
062dc1e53d Merge pull request #85 from AlanRosenthal/alan/gatt_server_console2
Add `bumble-console --device-config` support for gatt services
2022-11-28 19:19:25 -08:00
Abel Lucas
662704e551 classic: complete authentication when being the .authenticate acceptor 2022-11-29 00:28:39 +00:00
Abel Lucas
02a474c44e smp: emit enough information on pairing complete to deduce security level 2022-11-29 00:28:38 +00:00
Abel Lucas
a1c7aec492 device: fix .find_connection_by_bd_addr 2022-11-29 00:28:38 +00:00
Abel Lucas
6112f00049 device: introduce BR/EDR pending connections
This commit enable the BR/EDR pairing to run asynchronously to
the connection being established.

When in security mode 3, a controller shall start authentication as
part of the connection, which result in HCI events being sent on a BD
address without a completed connection (ie. no connection handle).
2022-11-29 00:28:38 +00:00
Alan Rosenthal
f56ac14f2c Add bumble-console --device-config support for gatt services
This PR adds support for bumble-console to be preloaded with gatt services via `--device-config`.
This PR also adds some type annotations
2022-11-28 14:11:27 -05:00
Gilles Boccon-Gibod
a739fc71ce Merge pull request #84 from google/gbg/78
use libusb auto-detach feature
2022-11-27 12:42:02 -08:00
Alan Rosenthal
b89f9030a0 Added class CharacteristicDeclaration, gatt_server getters
* Converted CharacteristicDeclaration implementation to class
* Added ability to get a gatt_server attribute by service UUID, characteristics UUID, descriptor UUID
2022-11-27 19:22:25 +00:00
Gilles Boccon-Gibod
9e5a85bd10 use libusb auto-detach feature 2022-11-25 17:52:13 -08:00
Gilles Boccon-Gibod
b437bd8619 Merge pull request #82 from AlanRosenthal/alan/pytest_fixes
Fix test failures
2022-11-25 13:19:53 -08:00
Alan Rosenthal
a3e4674819 Fix test failures
a. `DeprecationWarning: The 'warn' method is deprecated, use 'warning' instead`
Updated call in `bumble/smp.py`

b. `ModuleNotFoundError: No module named 'bumble.apps'`
Updated imports in `tests/import_test.py`

c. Added `pytest-html` for easier viewing of test results
Added package in `setup.cfg`, and hook in `tasks.py`

d. Updated workflows to use `invoke test`

This is a partial fix of #81
2022-11-23 11:31:27 -05:00
Abel Lucas
5f1d57fcb0 device: simplify and fixes remote name request 2022-11-22 21:20:56 +00:00
Abel Lucas
9c133706e6 keys: add a way to remove all bonds from key store 2022-11-18 18:22:15 +00:00
181 changed files with 11831 additions and 6853 deletions

2
.git-blame-ignore-revs Normal file
View File

@@ -0,0 +1,2 @@
# Migrate code style to Black
135df0dcc01ab765f432e19b1a5202d29bd55545

35
.github/workflows/code-check.yml vendored Normal file
View File

@@ -0,0 +1,35 @@
# Check the code against the formatter and linter
name: Code format and lint check
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
permissions:
contents: read
jobs:
check:
name: Check Code
runs-on: ubuntu-latest
steps:
- name: Check out from Git
uses: actions/checkout@v3
- name: Get history and tags for SCM versioning to work
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[build,test,development]"
- name: Check
run: |
invoke project.pre-commit

View File

@@ -14,6 +14,10 @@ jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
fail-fast: false
steps:
- name: Check out from Git
@@ -22,17 +26,17 @@ jobs:
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python 3.10
uses: actions/setup-python@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[build,test,development,documentation]"
- name: Test with pytest
- name: Test
run: |
pytest
invoke test
- name: Build
run: |
inv build

8
.gitignore vendored
View File

@@ -3,9 +3,9 @@ build/
dist/
*.egg-info/
*~
bumble/__pycache__
docs/mkdocs/site
tests/__pycache__
test-results.xml
bumble/transport/__pycache__
bumble/profiles/__pycache__
__pycache__
# generated by setuptools_scm
bumble/_version.py
.vscode/launch.json

75
.vscode/settings.json vendored Normal file
View File

@@ -0,0 +1,75 @@
{
"cSpell.words": [
"Abortable",
"altsetting",
"ansiblue",
"ansicyan",
"ansigreen",
"ansimagenta",
"ansired",
"ansiyellow",
"appendleft",
"ASHA",
"asyncio",
"ATRAC",
"avdtp",
"bitpool",
"bitstruct",
"BSCP",
"BTPROTO",
"CCCD",
"cccds",
"cmac",
"CONNECTIONLESS",
"csrcs",
"datagram",
"DATALINK",
"delayreport",
"deregisters",
"deregistration",
"dhkey",
"diversifier",
"Fitbit",
"GATTLINK",
"HANDSFREE",
"keydown",
"keyup",
"levelname",
"libc",
"libusb",
"MITM",
"NDIS",
"NONBLOCK",
"NONCONN",
"OXIMETER",
"popleft",
"psms",
"pyee",
"pyusb",
"rfcomm",
"ROHC",
"rssi",
"SEID",
"seids",
"SERV",
"ssrc",
"strerror",
"subband",
"subbands",
"subevent",
"Subrating",
"substates",
"tobytes",
"tsep",
"usbmodem",
"vhci",
"websockets",
"xcursor",
"ycursor"
],
"[python]": {
"editor.rulers": [88]
},
"python.formatting.provider": "black",
"pylint.importStrategy": "useBundled"
}

View File

@@ -29,7 +29,7 @@ For a quick start to using Bumble, see the [Getting Started](docs/mkdocs/src/get
### Dependencies
To install package dependencies needed to run the bumble examples execute the following commands:
To install package dependencies needed to run the bumble examples, execute the following commands:
```
python -m pip install --upgrade pip
@@ -50,7 +50,7 @@ Bumble is easiest to use with a dedicated USB dongle.
This is because internal Bluetooth interfaces tend to be locked down by the operating system.
You can use the [usb_probe](/docs/mkdocs/src/apps_and_tools/usb_probe.md) tool (all platforms) or `lsusb` (Linux or macOS) to list the available USB devices on your system.
See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices.
See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices. Also, if your are on a mac, see [these instructions](docs/mkdocs/src/platforms/macos.md).
## License

View File

@@ -47,5 +47,3 @@ NOTE: this assumes you're running a Link Relay on port `10723`.
## `console.py`
A simple text-based-ui interactive Bluetooth device with GATT client capabilities.

View File

@@ -29,18 +29,6 @@ from collections import OrderedDict
import click
import colors
from bumble.core import UUID, AdvertisingData, TimeoutError, BT_LE_TRANSPORT
from bumble.device import ConnectionParametersPreferences, Device, Connection, Peer
from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link
from bumble.gatt import Characteristic
from bumble.hci import (
HCI_Constant,
HCI_LE_1M_PHY,
HCI_LE_2M_PHY,
HCI_LE_CODED_PHY,
)
from prompt_toolkit import Application
from prompt_toolkit.history import FileHistory
from prompt_toolkit.completion import Completer, Completion, NestedCompleter
@@ -60,9 +48,24 @@ from prompt_toolkit.layout import (
FormattedTextControl,
FloatContainer,
ConditionalContainer,
Dimension
Dimension,
)
from bumble import __version__
import bumble.core
from bumble.core import UUID, AdvertisingData, BT_LE_TRANSPORT
from bumble.device import ConnectionParametersPreferences, Device, Connection, Peer
from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
from bumble.hci import (
HCI_Constant,
HCI_LE_1M_PHY,
HCI_LE_2M_PHY,
HCI_LE_CODED_PHY,
)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
@@ -74,22 +77,15 @@ DISPLAY_MAX_RSSI = -30
RSSI_MONITOR_INTERVAL = 5.0 # Seconds
# -----------------------------------------------------------------------------
# Globals
# -----------------------------------------------------------------------------
App = None
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def le_phy_name(phy_id):
return {
HCI_LE_1M_PHY: '1M',
HCI_LE_2M_PHY: '2M',
HCI_LE_CODED_PHY: 'CODED'
}.get(phy_id, HCI_Constant.le_phy_name(phy_id))
return {HCI_LE_1M_PHY: '1M', HCI_LE_2M_PHY: '2M', HCI_LE_CODED_PHY: 'CODED'}.get(
phy_id, HCI_Constant.le_phy_name(phy_id)
)
def rssi_bar(rssi):
@@ -104,7 +100,7 @@ def rssi_bar(rssi):
def parse_phys(phys):
if phys.lower() == '*':
return None
else:
phy_list = []
elements = phys.lower().split(',')
for element in elements:
@@ -132,12 +128,14 @@ class ConsoleApp:
self.monitor_rssi = False
self.connection_rssi = None
style = Style.from_dict({
style = Style.from_dict(
{
'output-field': 'bg:#000044 #ffffff',
'input-field': 'bg:#000000 #ffffff',
'line': '#004400',
'error': 'fg:ansired'
})
'error': 'fg:ansired',
}
)
class LiveCompleter(Completer):
def __init__(self, words):
@@ -149,26 +147,17 @@ class ConsoleApp:
yield Completion(word, start_position=-len(prefix))
def make_completer():
return NestedCompleter.from_nested_dict({
'scan': {
'on': None,
'off': None,
'clear': None
},
'advertise': {
'on': None,
'off': None
},
'rssi': {
'on': None,
'off': None
},
return NestedCompleter.from_nested_dict(
{
'scan': {'on': None, 'off': None, 'clear': None},
'advertise': {'on': None, 'off': None},
'rssi': {'on': None, 'off': None},
'show': {
'scan': None,
'services': None,
'attributes': None,
'log': None,
'device': None
'device': None,
'local-services': None,
'remote-services': None,
},
'filter': {
'address': None,
@@ -177,24 +166,18 @@ class ConsoleApp:
'update-parameters': None,
'encrypt': None,
'disconnect': None,
'discover': {
'services': None,
'attributes': None
},
'discover': {'services': None, 'attributes': None},
'request-mtu': None,
'read': LiveCompleter(self.known_attributes),
'write': LiveCompleter(self.known_attributes),
'subscribe': LiveCompleter(self.known_attributes),
'unsubscribe': LiveCompleter(self.known_attributes),
'set-phy': {
'1m': None,
'2m': None,
'coded': None
},
'set-phy': {'1m': None, '2m': None, 'coded': None},
'set-default-phy': None,
'quit': None,
'exit': None
})
'exit': None,
}
)
self.input_field = TextArea(
height=1,
@@ -202,49 +185,55 @@ class ConsoleApp:
multiline=False,
wrap_lines=False,
completer=make_completer(),
history=FileHistory(os.path.join(BUMBLE_USER_DIR, 'history'))
history=FileHistory(os.path.join(BUMBLE_USER_DIR, 'history')),
)
self.input_field.accept_handler = self.accept_input
self.output_height = Dimension(min=7, max=7, weight=1)
self.output_lines = []
self.output = FormattedTextControl(get_cursor_position=lambda: Point(0, max(0, len(self.output_lines) - 1)))
self.output = FormattedTextControl(
get_cursor_position=lambda: Point(0, max(0, len(self.output_lines) - 1))
)
self.output_max_lines = 20
self.scan_results_text = FormattedTextControl()
self.services_text = FormattedTextControl()
self.attributes_text = FormattedTextControl()
self.local_services_text = FormattedTextControl()
self.remote_services_text = FormattedTextControl()
self.device_text = FormattedTextControl()
self.log_text = FormattedTextControl(get_cursor_position=lambda: Point(0, max(0, len(self.log_lines) - 1)))
self.log_text = FormattedTextControl(
get_cursor_position=lambda: Point(0, max(0, len(self.log_lines) - 1))
)
self.log_height = Dimension(min=7, weight=4)
self.log_max_lines = 100
self.log_lines = []
container = HSplit([
container = HSplit(
[
ConditionalContainer(
Frame(Window(self.scan_results_text), title='Scan Results'),
filter=Condition(lambda: self.top_tab == 'scan')
filter=Condition(lambda: self.top_tab == 'scan'),
),
ConditionalContainer(
Frame(Window(self.services_text), title='Services'),
filter=Condition(lambda: self.top_tab == 'services')
Frame(Window(self.local_services_text), title='Local Services'),
filter=Condition(lambda: self.top_tab == 'local-services'),
),
ConditionalContainer(
Frame(Window(self.attributes_text), title='Attributes'),
filter=Condition(lambda: self.top_tab == 'attributes')
Frame(Window(self.remote_services_text), title='Remove Services'),
filter=Condition(lambda: self.top_tab == 'remote-services'),
),
ConditionalContainer(
Frame(Window(self.log_text, height=self.log_height), title='Log'),
filter=Condition(lambda: self.top_tab == 'log')
filter=Condition(lambda: self.top_tab == 'log'),
),
ConditionalContainer(
Frame(Window(self.device_text), title='Device'),
filter=Condition(lambda: self.top_tab == 'device')
filter=Condition(lambda: self.top_tab == 'device'),
),
Frame(Window(self.output, height=self.output_height)),
FormattedTextToolbar(text=self.get_status_bar_text, style='reverse'),
self.input_field
])
self.input_field,
]
)
container = FloatContainer(
container,
@@ -259,17 +248,16 @@ class ConsoleApp:
layout = Layout(container, focused_element=self.input_field)
kb = KeyBindings()
@kb.add("c-c")
@kb.add("c-q")
key_bindings = KeyBindings()
@key_bindings.add("c-c")
@key_bindings.add("c-q")
def _(event):
event.app.exit()
# pylint: disable=invalid-name
self.ui = Application(
layout=layout,
style=style,
key_bindings=kb,
full_screen=True
layout=layout, style=style, key_bindings=key_bindings, full_screen=True
)
async def run_async(self, device_config, transport):
@@ -277,16 +265,23 @@ class ConsoleApp:
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
if device_config:
self.device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
self.device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else:
random_address = f"{random.randint(192,255):02X}" # address is static random
for c in random.sample(range(255), 5):
random_address += f":{c:02X}"
random_address = (
f"{random.randint(192,255):02X}" # address is static random
)
for random_byte in random.sample(range(255), 5):
random_address += f":{random_byte:02X}"
self.append_to_log(f"Setting random address: {random_address}")
self.device = Device.with_hci('Bumble', random_address, hci_source, hci_sink)
self.device = Device.with_hci(
'Bumble', random_address, hci_source, hci_sink
)
self.device.listener = DeviceListener(self)
await self.device.power_on()
self.show_device(self.device)
self.show_local_services(self.device.gatt_server.attributes)
# Run the UI
await self.ui.run_async()
@@ -296,7 +291,7 @@ class ConsoleApp:
def add_known_address(self, address):
self.known_addresses.add(address)
def accept_input(self, buff):
def accept_input(self, _):
if len(self.input_field.text) == 0:
return
self.append_to_output([('', '* '), ('ansicyan', self.input_field.text)], False)
@@ -315,13 +310,27 @@ class ConsoleApp:
connection_state = 'CONNECTING'
elif self.connected_peer:
connection = self.connected_peer.connection
connection_parameters = f'{connection.parameters.connection_interval}/{connection.parameters.peripheral_latency}/{connection.parameters.supervision_timeout}'
connection_parameters = (
f'{connection.parameters.connection_interval}/'
f'{connection.parameters.peripheral_latency}/'
f'{connection.parameters.supervision_timeout}'
)
if connection.transport == BT_LE_TRANSPORT:
phy_state = f' RX={le_phy_name(connection.phy.rx_phy)}/TX={le_phy_name(connection.phy.tx_phy)}'
phy_state = (
f' RX={le_phy_name(connection.phy.rx_phy)}/'
f'TX={le_phy_name(connection.phy.tx_phy)}'
)
else:
phy_state = ''
connection_state = f'{connection.peer_address} {connection_parameters} {connection.data_length}{phy_state}'
encryption_state = 'ENCRYPTED' if connection.is_encrypted else 'NOT ENCRYPTED'
connection_state = (
f'{connection.peer_address} '
f'{connection_parameters} '
f'{connection.data_length}'
f'{phy_state}'
)
encryption_state = (
'ENCRYPTED' if connection.is_encrypted else 'NOT ENCRYPTED'
)
att_mtu = f'ATT_MTU: {connection.att_mtu}'
return [
@@ -333,7 +342,7 @@ class ConsoleApp:
('', ' '),
('ansicyan', f' {att_mtu} '),
('', ' '),
('ansiyellow', f' {rssi} ')
('ansiyellow', f' {rssi} '),
]
def show_error(self, title, details=None):
@@ -351,35 +360,45 @@ class ConsoleApp:
self.scan_results_text.text = ANSI('\n'.join(lines))
self.ui.invalidate()
def show_services(self, services):
def show_remote_services(self, services):
lines = []
del self.known_attributes[:]
for service in services:
lines.append(('ansicyan', str(service) + '\n'))
lines.append(("ansicyan", f"{service}\n"))
for characteristic in service.characteristics:
lines.append(('ansimagenta', ' ' + str(characteristic) + '\n'))
self.known_attributes.append(f'{service.uuid.to_hex_str()}.{characteristic.uuid.to_hex_str()}')
lines.append(('ansimagenta', f' {characteristic} + \n'))
self.known_attributes.append(
f'{service.uuid.to_hex_str()}.{characteristic.uuid.to_hex_str()}'
)
self.known_attributes.append(f'*.{characteristic.uuid.to_hex_str()}')
self.known_attributes.append(f'#{characteristic.handle:X}')
for descriptor in characteristic.descriptors:
lines.append(('ansigreen', ' ' + str(descriptor) + '\n'))
lines.append(("ansigreen", f" {descriptor}\n"))
self.services_text.text = lines
self.remote_services_text.text = lines
self.ui.invalidate()
def show_attributes(self, attributes):
def show_local_services(self, attributes):
lines = []
for attribute in attributes:
lines.append(('ansicyan', f'{attribute}\n'))
if isinstance(attribute, Service):
lines.append(("ansicyan", f"{attribute}\n"))
elif isinstance(attribute, (Characteristic, CharacteristicDeclaration)):
lines.append(("ansimagenta", f" {attribute}\n"))
elif isinstance(attribute, Descriptor):
lines.append(("ansigreen", f" {attribute}\n"))
else:
lines.append(("ansiyellow", f"{attribute}\n"))
self.attributes_text.text = lines
self.local_services_text.text = lines
self.ui.invalidate()
def show_device(self, device):
lines = []
lines.append(('ansicyan', 'Bumble Version: '))
lines.append(('', f'{__version__}\n'))
lines.append(('ansicyan', 'Name: '))
lines.append(('', f'{device.name}\n'))
lines.append(('ansicyan', 'Public Address: '))
@@ -407,7 +426,10 @@ class ConsoleApp:
advertising_interval = (
device.advertising_interval_min
if device.advertising_interval_min == device.advertising_interval_max
else f"{device.advertising_interval_min} to {device.advertising_interval_max}"
else (
f'{device.advertising_interval_min} to '
f'{device.advertising_interval_max}'
)
)
lines.append(('ansicyan', 'Advertising Interval: '))
lines.append(('', f'{advertising_interval}\n'))
@@ -416,7 +438,7 @@ class ConsoleApp:
self.ui.invalidate()
def append_to_output(self, line, invalidate=True):
if type(line) is str:
if isinstance(line, str):
line = [('', line)]
self.output_lines = self.output_lines[-self.output_max_lines :]
self.output_lines.append(line)
@@ -454,7 +476,7 @@ class ConsoleApp:
await self.connected_peer.discover_descriptors(characteristic)
self.append_to_output('discovery completed')
self.show_services(self.connected_peer.services)
self.show_remote_services(self.connected_peer.services)
async def discover_attributes(self):
if not self.connected_peer:
@@ -486,6 +508,8 @@ class ConsoleApp:
if characteristic.handle == attribute_handle:
return characteristic
return None
async def rssi_monitor_loop(self):
while True:
if self.monitor_rssi and self.connected_peer:
@@ -515,7 +539,11 @@ class ConsoleApp:
elif params[0] == 'on':
if len(params) == 2:
if not params[1].startswith("filter="):
self.show_error('invalid syntax', 'expected address filter=key1:value1,key2:value,... available filters: address')
self.show_error(
'invalid syntax',
'expected address filter=key1:value1,key2:value,... '
'available filters: address',
)
# regex: (word):(any char except ,)
matches = re.findall(r"(\w+):([^,]+)", params[1])
for match in matches:
@@ -557,8 +585,7 @@ class ConsoleApp:
connection_parameters_preferences = None
else:
connection_parameters_preferences = {
phy: ConnectionParametersPreferences()
for phy in phys
phy: ConnectionParametersPreferences() for phy in phys
}
if self.device.is_scanning:
@@ -570,13 +597,13 @@ class ConsoleApp:
await self.device.connect(
params[0],
connection_parameters_preferences=connection_parameters_preferences,
timeout=DEFAULT_CONNECTION_TIMEOUT
timeout=DEFAULT_CONNECTION_TIMEOUT,
)
self.top_tab = 'services'
except TimeoutError:
except bumble.core.TimeoutError:
self.show_error('connection timed out')
async def do_disconnect(self, params):
async def do_disconnect(self, _):
if self.device.is_le_connecting:
await self.device.cancel_connection()
else:
@@ -588,7 +615,11 @@ class ConsoleApp:
async def do_update_parameters(self, params):
if len(params) != 1 or len(params[0].split('/')) != 3:
self.show_error('invalid syntax', 'expected update-parameters <interval-min>-<interval-max>/<max-latency>/<supervision>')
self.show_error(
'invalid syntax',
'expected update-parameters <interval-min>-<interval-max>'
'/<max-latency>/<supervision>',
)
return
if not self.connected_peer:
@@ -596,17 +627,19 @@ class ConsoleApp:
return
connection_intervals, max_latency, supervision_timeout = params[0].split('/')
connection_interval_min, connection_interval_max = [int(x) for x in connection_intervals.split('-')]
connection_interval_min, connection_interval_max = [
int(x) for x in connection_intervals.split('-')
]
max_latency = int(max_latency)
supervision_timeout = int(supervision_timeout)
await self.connected_peer.connection.update_parameters(
connection_interval_min,
connection_interval_max,
max_latency,
supervision_timeout
supervision_timeout,
)
async def do_encrypt(self, params):
async def do_encrypt(self, _):
if not self.connected_peer:
self.show_error('not connected')
return
@@ -629,17 +662,26 @@ class ConsoleApp:
async def do_show(self, params):
if params:
if params[0] in {'scan', 'services', 'attributes', 'log', 'device'}:
if params[0] in {
'scan',
'log',
'device',
'local-services',
'remote-services',
}:
self.top_tab = params[0]
self.ui.invalidate()
async def do_get_phy(self, params):
async def do_get_phy(self, _):
if not self.connected_peer:
self.show_error('not connected')
return
phy = await self.connected_peer.connection.get_phy()
self.append_to_output(f'PHY: RX={HCI_Constant.le_phy_name(phy[0])}, TX={HCI_Constant.le_phy_name(phy[1])}')
self.append_to_output(
f'PHY: RX={HCI_Constant.le_phy_name(phy[0])}, '
f'TX={HCI_Constant.le_phy_name(phy[1])}'
)
async def do_request_mtu(self, params):
if len(params) != 1:
@@ -721,7 +763,9 @@ class ConsoleApp:
return
await characteristic.subscribe(
lambda value: self.append_to_output(f"{characteristic} VALUE: 0x{value.hex()}"),
lambda value: self.append_to_output(
f"{characteristic} VALUE: 0x{value.hex()}"
),
)
async def do_unsubscribe(self, params):
@@ -742,7 +786,9 @@ class ConsoleApp:
async def do_set_phy(self, params):
if len(params) != 1:
self.show_error('invalid syntax', 'expected set-phy <tx_rx_phys>|<tx_phys>/<rx_phys>')
self.show_error(
'invalid syntax', 'expected set-phy <tx_rx_phys>|<tx_phys>/<rx_phys>'
)
return
if not self.connected_peer:
@@ -756,13 +802,15 @@ class ConsoleApp:
rx_phys = tx_phys
await self.connected_peer.connection.set_phy(
tx_phys=parse_phys(tx_phys),
rx_phys=parse_phys(rx_phys)
tx_phys=parse_phys(tx_phys), rx_phys=parse_phys(rx_phys)
)
async def do_set_default_phy(self, params):
if len(params) != 1:
self.show_error('invalid syntax', 'expected set-default-phy <tx_rx_phys>|<tx_phys>/<rx_phys>')
self.show_error(
'invalid syntax',
'expected set-default-phy <tx_rx_phys>|<tx_phys>/<rx_phys>',
)
return
if '/' in params[0]:
@@ -772,14 +820,13 @@ class ConsoleApp:
rx_phys = tx_phys
await self.device.set_default_phy(
tx_phys=parse_phys(tx_phys),
rx_phys=parse_phys(rx_phys)
tx_phys=parse_phys(tx_phys), rx_phys=parse_phys(rx_phys)
)
async def do_exit(self, params):
async def do_exit(self, _):
self.ui.exit()
async def do_quit(self, params):
async def do_quit(self, _):
self.ui.exit()
async def do_filter(self, params):
@@ -789,6 +836,7 @@ class ConsoleApp:
return
self.device.listener.address_filter = params[1]
# -----------------------------------------------------------------------------
# Device and Connection Listener
# -----------------------------------------------------------------------------
@@ -808,7 +856,9 @@ class DeviceListener(Device.Listener, Connection.Listener):
self._address_filter = re.compile(r".*")
else:
self._address_filter = re.compile(filter_addr)
self.scan_results = OrderedDict(filter(lambda x: self.filter_address_match(x), self.scan_results))
self.scan_results = OrderedDict(
filter(self.filter_address_match, self.scan_results)
)
self.app.show_scan_results(self.scan_results)
def filter_address_match(self, address):
@@ -818,6 +868,7 @@ class DeviceListener(Device.Listener, Connection.Listener):
return bool(self.address_filter.match(address))
@AsyncRunner.run_in_task()
# pylint: disable=invalid-overridden-method
async def on_connection(self, connection):
self.app.connected_peer = Peer(connection)
self.app.connection_rssi = None
@@ -825,24 +876,44 @@ class DeviceListener(Device.Listener, Connection.Listener):
connection.listener = self
def on_disconnection(self, reason):
self.app.append_to_output(f'disconnected from {self.app.connected_peer}, reason: {HCI_Constant.error_name(reason)}')
self.app.append_to_output(
f'disconnected from {self.app.connected_peer}, '
f'reason: {HCI_Constant.error_name(reason)}'
)
self.app.connected_peer = None
self.app.connection_rssi = None
def on_connection_parameters_update(self):
self.app.append_to_output(f'connection parameters update: {self.app.connected_peer.connection.parameters}')
self.app.append_to_output(
f'connection parameters update: '
f'{self.app.connected_peer.connection.parameters}'
)
def on_connection_phy_update(self):
self.app.append_to_output(f'connection phy update: {self.app.connected_peer.connection.phy}')
self.app.append_to_output(
f'connection phy update: {self.app.connected_peer.connection.phy}'
)
def on_connection_att_mtu_update(self):
self.app.append_to_output(f'connection att mtu update: {self.app.connected_peer.connection.att_mtu}')
self.app.append_to_output(
f'connection att mtu update: {self.app.connected_peer.connection.att_mtu}'
)
def on_connection_encryption_change(self):
self.app.append_to_output(f'connection encryption change: {"encrypted" if self.app.connected_peer.connection.is_encrypted else "not encrypted"}')
encryption_state = (
'encrypted'
if self.app.connected_peer.connection.is_encrypted
else 'not encrypted'
)
self.app.append_to_output(
'connection encryption change: ' f'{encryption_state}'
)
def on_connection_data_length_change(self):
self.app.append_to_output(f'connection data length change: {self.app.connected_peer.connection.data_length}')
self.app.append_to_output(
'connection data length change: '
f'{self.app.connected_peer.connection.data_length}'
)
def on_advertisement(self, advertisement):
if not self.filter_address_match(str(advertisement.address)):
@@ -856,7 +927,13 @@ class DeviceListener(Device.Listener, Connection.Listener):
entry.connectable = advertisement.is_connectable
else:
self.app.add_known_address(str(advertisement.address))
self.scan_results[entry_key] = ScanResult(advertisement.address, advertisement.address.address_type, advertisement.data, advertisement.rssi, advertisement.is_connectable)
self.scan_results[entry_key] = ScanResult(
advertisement.address,
advertisement.address.address_type,
advertisement.data,
advertisement.rssi,
advertisement.is_connectable,
)
self.app.show_scan_results(self.scan_results)
@@ -892,10 +969,16 @@ class ScanResult:
else:
name = ''
# Remove any '/P' qualifier suffix from the address string
address_str = str(self.address).replace('/P', '')
# RSSI bar
bar_string = rssi_bar(self.rssi)
bar_padding = ' ' * (DEFAULT_RSSI_BAR_WIDTH + 5 - len(bar_string))
return f'{address_color(str(self.address))} [{type_color(address_type_string)}] {bar_string} {bar_padding} {name}'
return (
f'{address_color(address_str)} [{type_color(address_type_string)}] '
f'{bar_string} {bar_padding} {name}'
)
# -----------------------------------------------------------------------------
@@ -923,7 +1006,7 @@ def main(device_config, transport):
if not os.path.isdir(BUMBLE_USER_DIR):
os.mkdir(BUMBLE_USER_DIR)
# Create an instane of the app
# Create an instance of the app
app = ConsoleApp()
# Setup logging
@@ -940,4 +1023,4 @@ def main(device_config, transport):
# -----------------------------------------------------------------------------
if __name__ == "__main__":
main()
main() # pylint: disable=no-value-for-parameter

View File

@@ -39,7 +39,7 @@ from bumble.hci import (
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
@@ -51,13 +51,18 @@ async def get_classic_info(host):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
if response.return_parameters.status == HCI_SUCCESS:
print()
print(color('Classic Address:', 'yellow'), response.return_parameters.bd_addr)
print(
color('Classic Address:', 'yellow'), response.return_parameters.bd_addr
)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
response = await host.send_command(HCI_Read_Local_Name_Command())
if response.return_parameters.status == HCI_SUCCESS:
print()
print(color('Local Name:', 'yellow'), map_null_terminated_utf8_string(response.return_parameters.local_name))
print(
color('Local Name:', 'yellow'),
map_null_terminated_utf8_string(response.return_parameters.local_name),
)
# -----------------------------------------------------------------------------
@@ -65,21 +70,25 @@ async def get_le_info(host):
print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
response = await host.send_command(HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command())
response = await host.send_command(
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
)
if response.return_parameters.status == HCI_SUCCESS:
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response.return_parameters.num_supported_advertising_sets,
'\n'
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND):
response = await host.send_command(HCI_LE_Read_Maximum_Advertising_Data_Length_Command())
response = await host.send_command(
HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
)
if response.return_parameters.status == HCI_SUCCESS:
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response.return_parameters.max_advertising_data_length,
'\n'
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND):
@@ -93,7 +102,7 @@ async def get_le_info(host):
f'rx:{response.return_parameters.supported_max_rx_octets}/'
f'{response.return_parameters.supported_max_rx_time}'
),
'\n'
'\n',
)
print(color('LE Features:', 'yellow'))
@@ -112,10 +121,19 @@ async def async_main(transport):
# Print version
print(color('Version:', 'yellow'))
print(color(' Manufacturer: ', 'green'), name_or_number(COMPANY_IDENTIFIERS, host.local_version.company_identifier))
print(color(' HCI Version: ', 'green'), name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version))
print(
color(' Manufacturer: ', 'green'),
name_or_number(COMPANY_IDENTIFIERS, host.local_version.company_identifier),
)
print(
color(' HCI Version: ', 'green'),
name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version),
)
print(color(' HCI Subversion:', 'green'), host.local_version.hci_subversion)
print(color(' LMP Version: ', 'green'), name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version))
print(
color(' LMP Version: ', 'green'),
name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version),
)
print(color(' LMP Subversion:', 'green'), host.local_version.lmp_subversion)
# Get the Classic info

View File

@@ -28,11 +28,14 @@ from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def async_main():
if len(sys.argv) != 3:
print('Usage: controllers.py <hci-transport-1> <hci-transport-2> [<hci-transport-3> ...]')
print(
'Usage: controllers.py <hci-transport-1> <hci-transport-2> '
'[<hci-transport-3> ...]'
)
print('example: python controllers.py pty:ble1 pty:ble2')
return
# Create a loccal link to attach the controllers to
# Create a local link to attach the controllers to
link = LocalLink()
# Create a transport and controller for all requested names
@@ -41,7 +44,12 @@ async def async_main():
for index, transport_name in enumerate(sys.argv[1:]):
transport = await open_transport_or_link(transport_name)
transports.append(transport)
controller = Controller(f'C{index}', host_source = transport.source, host_sink = transport.sink, link = link)
controller = Controller(
f'C{index}',
host_source=transport.source,
host_sink=transport.sink,
link=link,
)
controllers.append(controller)
# Wait until the user interrupts

View File

@@ -21,7 +21,7 @@ import logging
import click
from colors import color
from bumble.core import ProtocolError, TimeoutError
import bumble.core
from bumble.device import Device, Peer
from bumble.gatt import show_services
from bumble.transport import open_transport_or_link
@@ -49,9 +49,9 @@ async def dump_gatt_db(peer, done):
try:
value = await attribute.read_value()
print(color(f'{value.hex()}', 'green'))
except ProtocolError as error:
except bumble.core.ProtocolError as error:
print(color(error, 'red'))
except TimeoutError:
except bumble.core.TimeoutError:
print(color('read timeout', 'red'))
if done is not None:
@@ -64,9 +64,13 @@ async def async_main(device_config, encrypt, transport, address_or_name):
# Create a device
if device_config:
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else:
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
await device.power_on()
if address_or_name:
@@ -81,7 +85,12 @@ async def async_main(device_config, encrypt, transport, address_or_name):
else:
# Wait for a connection
done = asyncio.get_running_loop().create_future()
device.on('connection', lambda connection: asyncio.create_task(dump_gatt_db(Peer(connection), done)))
device.on(
'connection',
lambda connection: asyncio.create_task(
dump_gatt_db(Peer(connection), done)
),
)
await device.start_advertising(auto_restart=True)
print(color('### Waiting for connection...', 'blue'))

View File

@@ -36,7 +36,9 @@ from bumble.hci import HCI_Constant
GG_GATTLINK_SERVICE_UUID = 'ABBAFF00-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_RX_CHARACTERISTIC_UUID = 'ABBAFF01-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_TX_CHARACTERISTIC_UUID = 'ABBAFF02-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = 'ABBAFF03-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = (
'ABBAFF03-E56A-484C-B832-8B17CF6CBFE8'
)
GG_PREFERRED_MTU = 256
@@ -97,6 +99,7 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
print(color(f'!!! Connection failed: {error}', 'red'))
@AsyncRunner.run_in_task()
# pylint: disable=invalid-overridden-method
async def on_connection(self, connection):
print(f'=== Connected to {connection}')
self.peer = Peer(connection)
@@ -127,7 +130,9 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
self.rx_characteristic = characteristic
elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID:
self.tx_characteristic = characteristic
elif characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID:
elif (
characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID
):
self.l2cap_psm_characteristic = characteristic
print('RX:', self.rx_characteristic)
print('TX:', self.tx_characteristic)
@@ -135,7 +140,9 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
if self.l2cap_psm_characteristic:
# Subscribe to and then read the PSM value
await self.peer.subscribe(self.l2cap_psm_characteristic, self.on_l2cap_psm_received)
await self.peer.subscribe(
self.l2cap_psm_characteristic, self.on_l2cap_psm_received
)
psm_bytes = await self.peer.read_value(self.l2cap_psm_characteristic)
psm = struct.unpack('<H', psm_bytes)[0]
await self.connect_l2cap(psm)
@@ -150,7 +157,13 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
print(color(f'!!! Connection failed: {error}'))
def on_disconnection(self, reason):
print(color(f'!!! Disconnected from {self.peer}, reason={HCI_Constant.error_name(reason)}', 'red'))
print(
color(
f'!!! Disconnected from {self.peer}, '
f'reason={HCI_Constant.error_name(reason)}',
'red',
)
)
self.tx_characteristic = None
self.rx_characteristic = None
self.peer = None
@@ -178,7 +191,7 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
pass
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address):
def datagram_received(self, data, _address):
print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
if self.l2cap_channel:
@@ -198,6 +211,7 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
self.tx_socket = None
self.tx_subscriber = None
self.rx_characteristic = None
self.transport = None
# Register as a listener
device.listener = self
@@ -212,35 +226,37 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
GG_GATTLINK_RX_CHARACTERISTIC_UUID,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=self.on_rx_write)
CharacteristicValue(write=self.on_rx_write),
)
self.tx_characteristic = Characteristic(
GG_GATTLINK_TX_CHARACTERISTIC_UUID,
Characteristic.NOTIFY,
Characteristic.READABLE
Characteristic.READABLE,
)
self.tx_characteristic.on('subscription', self.on_tx_subscription)
self.psm_characteristic = Characteristic(
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([psm, 0])
bytes([psm, 0]),
)
gattlink_service = Service(
GG_GATTLINK_SERVICE_UUID,
[
self.rx_characteristic,
self.tx_characteristic,
self.psm_characteristic
]
[self.rx_characteristic, self.tx_characteristic, self.psm_characteristic],
)
device.add_services([gattlink_service])
device.advertising_data = bytes(
AdvertisingData([
AdvertisingData(
[
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
bytes(reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))))
])
(
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
bytes(
reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))
),
),
]
)
)
async def start(self):
@@ -251,7 +267,7 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
self.transport = transport
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address):
def datagram_received(self, data, _address):
print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
if self.l2cap_channel:
@@ -263,14 +279,17 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
asyncio.create_task(self.device.notify_subscribers(self.tx_characteristic))
# Called when a write to the RX characteristic has been received
def on_rx_write(self, connection, data):
def on_rx_write(self, _connection, data):
print(color(f'<<< [GATT RX]: {len(data)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(data)
# Called when the subscription to the TX characteristic has changed
def on_tx_subscription(self, peer, enabled):
print(f'### [GATT TX] subscription from {peer}: {"enabled" if enabled else "disabled"}')
print(
f'### [GATT TX] subscription from {peer}: '
f'{"enabled" if enabled else "disabled"}'
)
if enabled:
self.tx_subscriber = peer
else:
@@ -290,7 +309,15 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
# -----------------------------------------------------------------------------
async def run(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port):
async def run(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
):
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
@@ -307,14 +334,14 @@ async def run(hci_transport, device_address, role_or_peer_address, send_host, se
# Create a UDP to RX bridge (receive from UDP, send to RX)
loop = asyncio.get_running_loop()
await loop.create_datagram_endpoint(
lambda: bridge,
local_addr=(receive_host, receive_port)
lambda: bridge, local_addr=(receive_host, receive_port)
)
# Create a UDP to TX bridge (receive from TX, send to UDP)
bridge.tx_socket, _ = await loop.create_datagram_endpoint(
# pylint: disable-next=unnecessary-lambda
lambda: asyncio.DatagramProtocol(),
remote_addr=(send_host, send_port)
remote_addr=(send_host, send_port),
)
await device.power_on()
@@ -328,12 +355,40 @@ async def run(hci_transport, device_address, role_or_peer_address, send_host, se
@click.argument('hci_transport')
@click.argument('device_address')
@click.argument('role_or_peer_address')
@click.option('-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to')
@click.option(
'-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to'
)
@click.option('-sp', '--send-port', type=int, default=9001, help='UDP port to send to')
@click.option('-rh', '--receive-host', type=str, default='127.0.0.1', help='UDP host to receive on')
@click.option('-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on')
def main(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port):
asyncio.run(run(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port))
@click.option(
'-rh',
'--receive-host',
type=str,
default='127.0.0.1',
help='UDP host to receive on',
)
@click.option(
'-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on'
)
def main(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
):
asyncio.run(
run(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
)
)
# -----------------------------------------------------------------------------

View File

@@ -34,16 +34,29 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def async_main():
if len(sys.argv) < 3:
print('Usage: hci_bridge.py <host-transport-spec> <controller-transport-spec> [command-short-circuit-list]')
print('example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078')
print(
'Usage: hci_bridge.py <host-transport-spec> <controller-transport-spec> '
'[command-short-circuit-list]'
)
print(
'example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 '
'serial:/dev/tty.usbmodem0006839912171,1000000 '
'0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078'
)
return
print('>>> connecting to HCI...')
async with await transport.open_transport_or_link(sys.argv[1]) as (hci_host_source, hci_host_sink):
async with await transport.open_transport_or_link(sys.argv[1]) as (
hci_host_source,
hci_host_sink,
):
print('>>> connected')
print('>>> connecting to HCI...')
async with await transport.open_transport_or_link(sys.argv[2]) as (hci_controller_source, hci_controller_sink):
async with await transport.open_transport_or_link(sys.argv[2]) as (
hci_controller_source,
hci_controller_sink,
):
print('>>> connected')
command_short_circuits = []
@@ -51,29 +64,36 @@ async def async_main():
for op_code_str in sys.argv[3].split(','):
if ':' in op_code_str:
ogf, ocf = op_code_str.split(':')
command_short_circuits.append(hci.hci_command_op_code(int(ogf, 16), int(ocf, 16)))
command_short_circuits.append(
hci.hci_command_op_code(int(ogf, 16), int(ocf, 16))
)
else:
command_short_circuits.append(int(op_code_str, 16))
def host_to_controller_filter(hci_packet):
if hci_packet.hci_packet_type == hci.HCI_COMMAND_PACKET and hci_packet.op_code in command_short_circuits:
if (
hci_packet.hci_packet_type == hci.HCI_COMMAND_PACKET
and hci_packet.op_code in command_short_circuits
):
# Respond with a success response
logger.debug('short-circuiting packet')
response = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=hci_packet.op_code,
return_parameters = bytes([hci.HCI_SUCCESS])
return_parameters=bytes([hci.HCI_SUCCESS]),
)
# Return a packet with 'respond to sender' set to True
return (response.to_bytes(), True)
return None
_ = HCI_Bridge(
hci_host_source,
hci_host_sink,
hci_controller_source,
hci_controller_sink,
host_to_controller_filter,
None
None,
)
await asyncio.get_running_loop().create_future()

View File

@@ -16,9 +16,9 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import click
import logging
import os
import click
from colors import color
from bumble.transport import open_transport_or_link
@@ -38,15 +38,8 @@ class ServerBridge:
and waits for a new L2CAP CoC channel to be connected.
When the TCP connection is closed by the TCP server, XXXX
"""
def __init__(
self,
psm,
max_credits,
mtu,
mps,
tcp_host,
tcp_port
):
def __init__(self, psm, max_credits, mtu, mps, tcp_host, tcp_port):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
@@ -61,13 +54,16 @@ class ServerBridge:
server=self.on_coc,
max_credits=self.max_credits,
mtu=self.mtu,
mps = self.mps
mps=self.mps,
)
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
def on_ble_connection(connection):
def on_ble_disconnection(reason):
print(color('@@@ Bluetooth disconnection:', 'red'), HCI_Constant.error_name(reason))
print(
color('@@@ Bluetooth disconnection:', 'red'),
HCI_Constant.error_name(reason),
)
print(color('@@@ Bluetooth connection:', 'green'), connection)
connection.on('disconnection', on_ble_disconnection)
@@ -91,14 +87,20 @@ class ServerBridge:
async def connect_to_tcp(self):
# Connect to the TCP server
print(color(f'### Connecting to TCP {self.bridge.tcp_host}:{self.bridge.tcp_port}...', 'yellow'))
print(
color(
f'### Connecting to TCP {self.bridge.tcp_host}:'
f'{self.bridge.tcp_port}...',
'yellow',
)
)
class TcpClientProtocol(asyncio.Protocol):
def __init__(self, pipe):
self.pipe = pipe
def connection_lost(self, error):
print(color(f'!!! TCP connection lost: {error}', 'red'))
def connection_lost(self, exc):
print(color(f'!!! TCP connection lost: {exc}', 'red'))
if self.pipe.l2cap_channel is not None:
asyncio.create_task(self.pipe.l2cap_channel.disconnect())
@@ -107,7 +109,10 @@ class ServerBridge:
self.pipe.l2cap_channel.write(data)
try:
self.tcp_transport, _ = await asyncio.get_running_loop().create_connection(
(
self.tcp_transport,
_,
) = await asyncio.get_running_loop().create_connection(
lambda: TcpClientProtocol(self),
host=self.bridge.tcp_host,
port=self.bridge.tcp_port,
@@ -149,16 +154,7 @@ class ClientBridge:
READ_CHUNK_SIZE = 4096
def __init__(
self,
psm,
max_credits,
mtu,
mps,
address,
tcp_host,
tcp_port
):
def __init__(self, psm, max_credits, mtu, mps, address, tcp_host, tcp_port):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
@@ -174,14 +170,17 @@ class ClientBridge:
# Called when the BLE connection is disconnected
def on_ble_disconnection(reason):
print(color('@@@ Bluetooth disconnection:', 'red'), HCI_Constant.error_name(reason))
print(
color('@@@ Bluetooth disconnection:', 'red'),
HCI_Constant.error_name(reason),
)
connection.on('disconnection', on_ble_disconnection)
# Called when a TCP connection is established
async def on_tcp_connection(reader, writer):
peername = writer.get_extra_info('peername')
print(color(f'<<< TCP connection from {peername}', 'magenta'))
peer_name = writer.get_extra_info('peer_name')
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
def on_coc_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
@@ -199,7 +198,7 @@ class ClientBridge:
psm=self.psm,
max_credits=self.max_credits,
mtu=self.mtu,
mps = self.mps
mps=self.mps,
)
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
except Exception as error:
@@ -215,7 +214,7 @@ class ClientBridge:
l2cap_channel.pause_reading,
l2cap_channel.resume_reading,
writer.write,
writer.drain
writer.drain,
)
l2cap_to_tcp_pipe.start()
@@ -242,9 +241,13 @@ class ClientBridge:
await asyncio.start_server(
on_tcp_connection,
host=self.tcp_host if self.tcp_host != '_' else None,
port=self.tcp_port
port=self.tcp_port,
)
print(
color(
f'### Listening for TCP connections on port {self.tcp_port}', 'magenta'
)
)
print(color(f'### Listening for TCP connections on port {self.tcp_port}', 'magenta'))
# -----------------------------------------------------------------------------
@@ -269,10 +272,33 @@ async def run(device_config, hci_transport, bridge):
@click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
@click.option('--l2cap-coc-max-credits', help='Maximum L2CAP CoC Credits', type=click.IntRange(1, 65535), default=128)
@click.option('--l2cap-coc-mtu', help='L2CAP CoC MTU', type=click.IntRange(23, 65535), default=1022)
@click.option('--l2cap-coc-mps', help='L2CAP CoC MPS', type=click.IntRange(23, 65533), default=1024)
def cli(context, device_config, hci_transport, psm, l2cap_coc_max_credits, l2cap_coc_mtu, l2cap_coc_mps):
@click.option(
'--l2cap-coc-max-credits',
help='Maximum L2CAP CoC Credits',
type=click.IntRange(1, 65535),
default=128,
)
@click.option(
'--l2cap-coc-mtu',
help='L2CAP CoC MTU',
type=click.IntRange(23, 65535),
default=1022,
)
@click.option(
'--l2cap-coc-mps',
help='L2CAP CoC MPS',
type=click.IntRange(23, 65533),
default=1024,
)
def cli(
context,
device_config,
hci_transport,
psm,
l2cap_coc_max_credits,
l2cap_coc_mtu,
l2cap_coc_mps,
):
context.ensure_object(dict)
context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport
@@ -294,12 +320,9 @@ def server(context, tcp_host, tcp_port):
context.obj['mtu'],
context.obj['mps'],
tcp_host,
tcp_port)
asyncio.run(run(
context.obj['device_config'],
context.obj['hci_transport'],
bridge
))
tcp_port,
)
asyncio.run(run(context.obj['device_config'], context.obj['hci_transport'], bridge))
# -----------------------------------------------------------------------------
@@ -316,16 +339,12 @@ def client(context, bluetooth_address, tcp_host, tcp_port):
context.obj['mps'],
bluetooth_address,
tcp_host,
tcp_port
tcp_port,
)
asyncio.run(run(
context.obj['device_config'],
context.obj['hci_transport'],
bridge
))
asyncio.run(run(context.obj['device_config'], context.obj['hci_transport'], bridge))
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__':
cli(obj={})
cli(obj={}) # pylint: disable=no-value-for-parameter

View File

@@ -16,7 +16,6 @@
# Imports
# ----------------------------------------------------------------------------
import sys
import websockets
import logging
import json
import asyncio
@@ -25,6 +24,7 @@ import uuid
import os
from urllib.parse import urlparse
from colors import color
import websockets
# -----------------------------------------------------------------------------
# Logging
@@ -98,7 +98,11 @@ class Connection:
self.address = address
def __str__(self):
return f'Connection(address="{self.address}", client={self.websocket.remote_address[0]}:{self.websocket.remote_address[1]})'
return (
f'Connection(address="{self.address}", '
f'client={self.websocket.remote_address[0]}:'
f'{self.websocket.remote_address[1]})'
)
# ----------------------------------------------------------------------------
@@ -139,13 +143,15 @@ class Room:
# Parse the message to decide how to handle it
if message.startswith('@'):
# This is a targetted message
await self.on_targetted_message(connection, message)
# This is a targeted message
await self.on_targeted_message(connection, message)
elif message.startswith('/'):
# This is an RPC request
await self.on_rpc_request(connection, message)
else:
await connection.send_message(f'result:{error_to_json("error: invalid message")}')
await connection.send_message(
f'result:{error_to_json("error: invalid message")}'
)
async def broadcast_message(self, sender, message):
'''
@@ -155,7 +161,9 @@ class Room:
async def on_rpc_request(self, connection, message):
command, *params = message.split(' ', 1)
if handler := getattr(self, f'on_{command[1:].lower().replace("-","_")}_command', None):
if handler := getattr(
self, f'on_{command[1:].lower().replace("-","_")}_command', None
):
try:
result = await handler(connection, params)
except Exception as error:
@@ -165,7 +173,7 @@ class Room:
await connection.send_message(result or 'result:{}')
async def on_targetted_message(self, connection, message):
async def on_targeted_message(self, connection, message):
target, *payload = message.split(' ', 1)
if not payload:
return error_to_json('missing arguments')
@@ -174,7 +182,8 @@ class Room:
# Determine what targets to send to
if target == '*':
# Send to all connections in the room except the connection from which the message was received
# Send to all connections in the room except the connection from which the
# message was received
connections = [c for c in self.connections if c != connection]
else:
connections = self.find_connections_by_address(target)
@@ -192,7 +201,9 @@ class Room:
current_address = connection.address
new_address = params[0]
connection.set_address(new_address)
await self.broadcast_message(connection, f'address-changed:from={current_address},to={new_address}')
await self.broadcast_message(
connection, f'address-changed:from={current_address},to={new_address}'
)
# ----------------------------------------------------------------------------
@@ -210,9 +221,10 @@ class Relay:
def start(self):
logger.info(f'Starting Relay on port {self.port}')
# pylint: disable-next=no-member
return websockets.serve(self.serve, '0.0.0.0', self.port, ping_interval=None)
async def serve_as_controller(connection):
async def serve_as_controller(self, connection):
pass
async def serve(self, websocket, path):
@@ -252,15 +264,15 @@ def main():
arg_parser = argparse.ArgumentParser(description='Bumble Link Relay')
arg_parser.add_argument('--log-level', default='INFO', help='logger level')
arg_parser.add_argument('--log-config', help='logger config file (YAML)')
arg_parser.add_argument('--port',
type = int,
default = DEFAULT_RELAY_PORT,
help = 'Port to listen on')
arg_parser.add_argument(
'--port', type=int, default=DEFAULT_RELAY_PORT, help='Port to listen on'
)
args = arg_parser.parse_args()
# Setup logger
if args.log_config:
from logging import config
from logging import config # pylint: disable=import-outside-toplevel
config.fileConfig(args.log_config)
else:
logging.basicConfig(level=getattr(logging, args.log_level.upper()))

View File

@@ -33,25 +33,27 @@ from bumble.gatt import (
GATT_GENERIC_ACCESS_SERVICE,
Service,
Characteristic,
CharacteristicValue
CharacteristicValue,
)
from bumble.att import (
ATT_Error,
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
ATT_INSUFFICIENT_ENCRYPTION_ERROR
ATT_INSUFFICIENT_ENCRYPTION_ERROR,
)
# -----------------------------------------------------------------------------
class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, prompt):
super().__init__({
super().__init__(
{
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
'none': PairingDelegate.NO_OUTPUT_NO_INPUT
}[capability_string.lower()])
'none': PairingDelegate.NO_OUTPUT_NO_INPUT,
}[capability_string.lower()]
)
self.mode = mode
self.peer = Peer(connection)
@@ -84,15 +86,17 @@ class Delegate(PairingDelegate):
while True:
response = await aioconsole.ainput(color('>>> Accept? ', 'yellow'))
response = response.lower().strip()
if response == 'yes':
return True
elif response == 'no':
if response == 'no':
return False
else:
# Accept silently
return True
async def compare_numbers(self, number, digits):
async def compare_numbers(self, number, digits=6):
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
@@ -103,11 +107,17 @@ class Delegate(PairingDelegate):
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
while True:
response = await aioconsole.ainput(color(f'>>> Does the other device display {number:0{digits}}? ', 'yellow'))
response = await aioconsole.ainput(
color(
f'>>> Does the other device display {number:0{digits}}? ', 'yellow'
)
)
response = response.lower().strip()
if response == 'yes':
return True
elif response == 'no':
if response == 'no':
return False
async def get_number(self):
@@ -126,7 +136,7 @@ class Delegate(PairingDelegate):
except ValueError:
pass
async def display_number(self, number, digits):
async def display_number(self, number, digits=6):
await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
@@ -143,16 +153,20 @@ class Delegate(PairingDelegate):
async def get_peer_name(peer, mode):
if mode == 'classic':
return await peer.request_name()
else:
# Try to get the peer name from GATT
services = await peer.discover_service(GATT_GENERIC_ACCESS_SERVICE)
if not services:
return None
values = await peer.read_characteristics_by_uuid(GATT_DEVICE_NAME_CHARACTERISTIC, services[0])
values = await peer.read_characteristics_by_uuid(
GATT_DEVICE_NAME_CHARACTERISTIC, services[0]
)
if values:
return values[0].decode('utf-8')
return None
# -----------------------------------------------------------------------------
AUTHENTICATION_ERROR_RETURNED = [False, False]
@@ -164,12 +178,12 @@ def read_with_error(connection):
if AUTHENTICATION_ERROR_RETURNED[0]:
return bytes([1])
else:
AUTHENTICATION_ERROR_RETURNED[0] = True
raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)
def write_with_error(connection, value):
def write_with_error(connection, _value):
if not connection.is_encrypted:
raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR)
@@ -190,7 +204,7 @@ def on_connection(connection, request):
# Listen for encryption changes
connection.on(
'connection_encryption_change',
lambda: on_connection_encryption_change(connection)
lambda: on_connection_encryption_change(connection),
)
# Request pairing if needed
@@ -202,7 +216,12 @@ def on_connection(connection, request):
# -----------------------------------------------------------------------------
def on_connection_encryption_change(connection):
print(color('@@@-----------------------------------', 'blue'))
print(color(f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted', 'blue'))
print(
color(
f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted',
'blue',
)
)
print(color('@@@-----------------------------------', 'blue'))
@@ -241,7 +260,7 @@ async def pair(
keystore_file,
device_config,
hci_transport,
address_or_name
address_or_name,
):
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
@@ -272,9 +291,11 @@ async def pair(
'552957FB-CF1F-4A31-9535-E78847E1A714',
Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE,
CharacteristicValue(read=read_with_error, write=write_with_error)
CharacteristicValue(
read=read_with_error, write=write_with_error
),
)
]
],
)
)
@@ -288,10 +309,7 @@ async def pair(
# Set up a pairing config factory
device.pairing_config_factory = lambda connection: PairingConfig(
sc,
mitm,
bond,
Delegate(mode, connection, io, prompt)
sc, mitm, bond, Delegate(mode, connection, io, prompt)
)
# Connect to a peer or wait for a connection
@@ -319,21 +337,70 @@ async def pair(
# -----------------------------------------------------------------------------
@click.command()
@click.option('--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True)
@click.option('--sc', type=bool, default=True, help='Use the Secure Connections protocol', show_default=True)
@click.option('--mitm', type=bool, default=True, help='Request MITM protection', show_default=True)
@click.option('--bond', type=bool, default=True, help='Enable bonding', show_default=True)
@click.option('--io', type=click.Choice(['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']), default='display+keyboard', show_default=True)
@click.option(
'--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True
)
@click.option(
'--sc',
type=bool,
default=True,
help='Use the Secure Connections protocol',
show_default=True,
)
@click.option(
'--mitm', type=bool, default=True, help='Request MITM protection', show_default=True
)
@click.option(
'--bond', type=bool, default=True, help='Enable bonding', show_default=True
)
@click.option(
'--io',
type=click.Choice(
['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']
),
default='display+keyboard',
show_default=True,
)
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
@click.option('--request', is_flag=True, help='Request that the connecting peer initiate pairing')
@click.option(
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
)
@click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing')
@click.option('--keystore-file', help='File in which to store the pairing keys')
@click.argument('device-config')
@click.argument('hci_transport')
@click.argument('address-or-name', required=False)
def main(mode, sc, mitm, bond, io, prompt, request, print_keys, keystore_file, device_config, hci_transport, address_or_name):
def main(
mode,
sc,
mitm,
bond,
io,
prompt,
request,
print_keys,
keystore_file,
device_config,
hci_transport,
address_or_name,
):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(pair(mode, sc, mitm, bond, io, prompt, request, print_keys, keystore_file, device_config, hci_transport, address_or_name))
asyncio.run(
pair(
mode,
sc,
mitm,
bond,
io,
prompt,
request,
print_keys,
keystore_file,
device_config,
hci_transport,
address_or_name,
)
)
# -----------------------------------------------------------------------------

View File

@@ -63,7 +63,9 @@ class AdvertisementPrinter:
resolution_qualifier = f'(resolved from {advertisement.address})'
address = resolved
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[address.address_type]
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
address.address_type
]
if address.is_public:
type_color = 'cyan'
else:
@@ -90,10 +92,12 @@ class AdvertisementPrinter:
print(
f'>>> {color(address, address_color)} '
f'[{color(address_type_string, type_color)}]{address_qualifier}{resolution_qualifier}:{separator}'
f'[{color(address_type_string, type_color)}]{address_qualifier}'
f'{resolution_qualifier}:{separator}'
f'{phy_info}'
f'RSSI:{advertisement.rssi:4} {rssi_bar}{separator}'
f'{advertisement.data.to_string(separator)}\n')
f'{advertisement.data.to_string(separator)}\n'
)
def on_advertisement(self, advertisement):
self.print_advertisement(advertisement)
@@ -114,16 +118,20 @@ async def scan(
raw,
keystore_file,
device_config,
transport
transport,
):
print('<<< connecting to HCI...')
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected')
if device_config:
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else:
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
device = Device.with_hci(
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
if keystore_file:
keystore = JsonKeyStore(namespace=None, filename=keystore_file)
@@ -153,7 +161,7 @@ async def scan(
scan_interval=scan_interval,
scan_window=scan_window,
filter_duplicates=filter_duplicates,
scanning_phys=scanning_phys
scanning_phys=scanning_phys,
)
await hci_source.wait_for_termination()
@@ -165,15 +173,51 @@ async def scan(
@click.option('--passive', is_flag=True, default=False, help='Perform passive scanning')
@click.option('--scan-interval', type=int, default=60, help='Scan interval')
@click.option('--scan-window', type=int, default=60, help='Scan window')
@click.option('--phy', type=click.Choice(['1m', 'coded']), help='Only scan on the specified PHY')
@click.option('--filter-duplicates', type=bool, default=True, help='Filter duplicates at the controller level')
@click.option('--raw', is_flag=True, default=False, help='Listen for raw advertising reports instead of processed ones')
@click.option(
'--phy', type=click.Choice(['1m', 'coded']), help='Only scan on the specified PHY'
)
@click.option(
'--filter-duplicates',
type=bool,
default=True,
help='Filter duplicates at the controller level',
)
@click.option(
'--raw',
is_flag=True,
default=False,
help='Listen for raw advertising reports instead of processed ones',
)
@click.option('--keystore-file', help='Keystore file to use when resolving addresses')
@click.option('--device-config', help='Device config file for the scanning device')
@click.argument('transport')
def main(min_rssi, passive, scan_interval, scan_window, phy, filter_duplicates, raw, keystore_file, device_config, transport):
def main(
min_rssi,
passive,
scan_interval,
scan_window,
phy,
filter_duplicates,
raw,
keystore_file,
device_config,
transport,
):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(scan(min_rssi, passive, scan_interval, scan_window, phy, filter_duplicates, raw, keystore_file, device_config, transport))
asyncio.run(
scan(
min_rssi,
passive,
scan_interval,
scan_window,
phy,
filter_duplicates,
raw,
keystore_file,
device_config,
transport,
)
)
# -----------------------------------------------------------------------------

View File

@@ -27,7 +27,8 @@ from bumble.helpers import PacketTracer
# -----------------------------------------------------------------------------
class SnoopPacketReader:
'''
Reader that reads HCI packets from a "snoop" file (based on RFC 1761, but not exactly the same...)
Reader that reads HCI packets from a "snoop" file (based on RFC 1761, but not
exactly the same...)
'''
DATALINK_H1 = 1001
@@ -41,9 +42,13 @@ class SnoopPacketReader:
# Read the header
identification_pattern = source.read(8)
if identification_pattern.hex().lower() != '6274736e6f6f7000':
raise ValueError('not a valid snoop file, unexpected identification pattern')
(self.version_number, self.data_link_type) = struct.unpack('>II', source.read(8))
if self.data_link_type != self.DATALINK_H4 and self.data_link_type != self.DATALINK_H1:
raise ValueError(
'not a valid snoop file, unexpected identification pattern'
)
(self.version_number, self.data_link_type) = struct.unpack(
'>II', source.read(8)
)
if self.data_link_type not in (self.DATALINK_H4, self.DATALINK_H1):
raise ValueError(f'datalink type {self.data_link_type} not supported')
def next_packet(self):
@@ -55,9 +60,9 @@ class SnoopPacketReader:
original_length,
included_length,
packet_flags,
cumulative_drops,
timestamp_seconds,
timestamp_microsecond
_cumulative_drops,
_timestamp_seconds,
_timestamp_microsecond,
) = struct.unpack('>IIIIII', header)
# Abort on truncated packets
@@ -79,8 +84,11 @@ class SnoopPacketReader:
else:
packet_type = hci.HCI_ACL_DATA_PACKET
return (packet_flags & 1, bytes([packet_type]) + self.source.read(included_length))
else:
return (
packet_flags & 1,
bytes([packet_type]) + self.source.read(included_length),
)
return (packet_flags & 1, self.source.read(included_length))
@@ -88,15 +96,22 @@ class SnoopPacketReader:
# Main
# -----------------------------------------------------------------------------
@click.command()
@click.option('--format', type=click.Choice(['h4', 'snoop']), default='h4', help='Format of the input file')
@click.option(
'--format',
type=click.Choice(['h4', 'snoop']),
default='h4',
help='Format of the input file',
)
@click.argument('filename')
# pylint: disable=redefined-builtin
def main(format, filename):
input = open(filename, 'rb')
if format == 'h4':
packet_reader = PacketReader(input)
def read_next_packet():
(0, packet_reader.next_packet())
return (0, packet_reader.next_packet())
else:
packet_reader = SnoopPacketReader(input)
read_next_packet = packet_reader.next_packet
@@ -112,9 +127,8 @@ def main(format, filename):
except Exception as error:
print(color(f'!!! {error}', 'red'))
pass
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()
main() # pylint: disable=no-value-for-parameter

View File

@@ -28,11 +28,12 @@
# -----------------------------------------------------------------------------
import os
import logging
import sys
import click
import usb1
from colors import color
from bumble.transport.usb import load_libusb
# -----------------------------------------------------------------------------
# Constants
@@ -69,13 +70,13 @@ USB_DEVICE_CLASSES = {
0x01: 'Bluetooth',
0x02: 'UWB',
0x03: 'Remote NDIS',
0x04: 'Bluetooth AMP'
}
0x04: 'Bluetooth AMP',
}
},
),
0xEF: 'Miscellaneous',
0xFE: 'Application Specific',
0xFF: 'Vendor Specific'
0xFF: 'Vendor Specific',
}
USB_ENDPOINT_IN = 0x80
@@ -84,7 +85,7 @@ USB_ENDPOINT_TYPES = ['CONTROL', 'ISOCHRONOUS', 'BULK', 'INTERRUPT']
USB_BT_HCI_CLASS_TUPLE = (
USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
USB_DEVICE_SUBCLASS_RF_CONTROLLER,
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
@@ -94,19 +95,26 @@ def show_device_details(device):
print(f' Configuration {configuration.getConfigurationValue()}')
for interface in configuration:
for setting in interface:
alternateSetting = setting.getAlternateSetting()
suffix = f'/{alternateSetting}' if interface.getNumSettings() > 1 else ''
alternate_setting = setting.getAlternateSetting()
suffix = (
f'/{alternate_setting}' if interface.getNumSettings() > 1 else ''
)
(class_string, subclass_string) = get_class_info(
setting.getClass(),
setting.getSubClass(),
setting.getProtocol()
setting.getClass(), setting.getSubClass(), setting.getProtocol()
)
details = f'({class_string}, {subclass_string})'
print(f' Interface: {setting.getNumber()}{suffix} {details}')
for endpoint in setting:
endpoint_type = USB_ENDPOINT_TYPES[endpoint.getAttributes() & 3]
endpoint_direction = 'OUT' if (endpoint.getAddress() & USB_ENDPOINT_IN == 0) else 'IN'
print(f' Endpoint 0x{endpoint.getAddress():02X}: {endpoint_type} {endpoint_direction}')
endpoint_direction = (
'OUT'
if (endpoint.getAddress() & USB_ENDPOINT_IN == 0)
else 'IN'
)
print(
f' Endpoint 0x{endpoint.getAddress():02X}: '
f'{endpoint_type} {endpoint_direction}'
)
# -----------------------------------------------------------------------------
@@ -116,7 +124,7 @@ def get_class_info(cls, subclass, protocol):
if class_info is None:
class_string = f'0x{cls:02X}'
else:
if type(class_info) is tuple:
if isinstance(class_info, tuple):
class_string = class_info[0]
subclass_info = class_info[1].get(subclass)
if subclass_info:
@@ -135,7 +143,11 @@ def get_class_info(cls, subclass, protocol):
# -----------------------------------------------------------------------------
def is_bluetooth_hci(device):
# Check if the device class indicates a match
if (device.getDeviceClass(), device.getDeviceSubClass(), device.getDeviceProtocol()) == USB_BT_HCI_CLASS_TUPLE:
if (
device.getDeviceClass(),
device.getDeviceSubClass(),
device.getDeviceProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
# If the device class is 'Device', look for a matching interface
@@ -143,7 +155,11 @@ def is_bluetooth_hci(device):
for configuration in device:
for interface in configuration:
for setting in interface:
if (setting.getClass(), setting.getSubClass(), setting.getProtocol()) == USB_BT_HCI_CLASS_TUPLE:
if (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
return False
@@ -155,6 +171,7 @@ def is_bluetooth_hci(device):
def main(verbose):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
load_libusb()
with usb1.USBContext() as context:
bluetooth_device_count = 0
devices = {}
@@ -167,9 +184,7 @@ def main(verbose):
device_id = (device.getVendorID(), device.getProductID())
(device_class_string, device_subclass_string) = get_class_info(
device_class,
device_subclass,
device_protocol
device_class, device_subclass, device_protocol
)
try:
@@ -198,7 +213,9 @@ def main(verbose):
# Compute the different ways this can be referenced as a Bumble transport
bumble_transport_names = []
basic_transport_name = f'usb:{device.getVendorID():04X}:{device.getProductID():04X}'
basic_transport_name = (
f'usb:{device.getVendorID():04X}:{device.getProductID():04X}'
)
if device_is_bluetooth_hci:
bumble_transport_names.append(f'usb:{bluetooth_device_count - 1}')
@@ -206,17 +223,39 @@ def main(verbose):
if device_id not in devices:
bumble_transport_names.append(basic_transport_name)
else:
bumble_transport_names.append(f'{basic_transport_name}#{len(devices[device_id])}')
bumble_transport_names.append(
f'{basic_transport_name}#{len(devices[device_id])}'
)
if device_serial_number is not None:
if device_id not in devices or device_serial_number not in devices[device_id]:
bumble_transport_names.append(f'{basic_transport_name}/{device_serial_number}')
if (
device_id not in devices
or device_serial_number not in devices[device_id]
):
bumble_transport_names.append(
f'{basic_transport_name}/{device_serial_number}'
)
# Print the results
print(color(f'ID {device.getVendorID():04X}:{device.getProductID():04X}', fg=fg_color, bg=bg_color))
print(
color(
f'ID {device.getVendorID():04X}:{device.getProductID():04X}',
fg=fg_color,
bg=bg_color,
)
)
if bumble_transport_names:
print(color(' Bumble Transport Names:', 'blue'), ' or '.join(color(x, 'cyan' if device_is_bluetooth_hci else 'red') for x in bumble_transport_names))
print(color(' Bus/Device: ', 'green'), f'{device.getBusNumber():03}/{device.getDeviceAddress():03}')
print(
color(' Bumble Transport Names:', 'blue'),
' or '.join(
color(x, 'cyan' if device_is_bluetooth_hci else 'red')
for x in bumble_transport_names
),
)
print(
color(' Bus/Device: ', 'green'),
f'{device.getBusNumber():03}/{device.getDeviceAddress():03}',
)
print(color(' Class: ', 'green'), device_class_string)
print(color(' Subclass/Protocol: ', 'green'), device_subclass_string)
if device_serial_number is not None:
@@ -236,4 +275,4 @@ def main(verbose):
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()
main() # pylint: disable=no-value-for-parameter

View File

@@ -0,0 +1,4 @@
try:
from ._version import version as __version__
except ImportError:
__version__ = "unknown version"

View File

@@ -16,10 +16,9 @@
# Imports
# -----------------------------------------------------------------------------
import struct
import bitstruct
import logging
from collections import namedtuple
from colors import color
import bitstruct
from .company_ids import COMPANY_IDENTIFIERS
from .sdp import (
@@ -30,7 +29,7 @@ from .sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
from .core import (
BT_L2CAP_PROTOCOL_ID,
@@ -38,7 +37,7 @@ from .core import (
BT_AUDIO_SINK_SERVICE,
BT_AVDTP_PROTOCOL_ID,
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
name_or_number
name_or_number,
)
@@ -51,6 +50,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
A2DP_SBC_CODEC_TYPE = 0x00
A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = 0x01
@@ -127,71 +127,115 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
}
# fmt: on
# -----------------------------------------------------------------------------
def flags_to_list(flags, values):
result = []
for i in range(len(values)):
for i, value in enumerate(values):
if flags & (1 << (len(values) - i - 1)):
result.append(values[i])
result.append(value)
return result
# -----------------------------------------------------------------------------
def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)):
# pylint: disable=import-outside-toplevel
from .avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1]
return [
ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(service_record_handle)),
ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)
])),
ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(BT_AUDIO_SOURCE_SERVICE)
])),
ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.sequence([
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(BT_AUDIO_SOURCE_SERVICE)]),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(AVDTP_PSM)
]),
DataElement.sequence([
DataElement.unsigned_integer_16(AVDTP_PSM),
]
),
DataElement.sequence(
[
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int)
])
])),
ServiceAttribute(SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.unsigned_integer_16(version_int),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int)
])),
DataElement.unsigned_integer_16(version_int),
]
),
),
]
# -----------------------------------------------------------------------------
def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
# pylint: disable=import-outside-toplevel
from .avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1]
return [
ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(service_record_handle)),
ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)
])),
ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(BT_AUDIO_SINK_SERVICE)
])),
ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.sequence([
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(BT_AUDIO_SINK_SERVICE)]),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(AVDTP_PSM)
]),
DataElement.sequence([
DataElement.unsigned_integer_16(AVDTP_PSM),
]
),
DataElement.sequence(
[
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int)
])
])),
ServiceAttribute(SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.unsigned_integer_16(version_int),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int)
])),
DataElement.unsigned_integer_16(version_int),
]
),
),
]
@@ -206,8 +250,8 @@ class SbcMediaCodecInformation(
'subbands',
'allocation_method',
'minimum_bitpool_value',
'maximum_bitpool_value'
]
'maximum_bitpool_value',
],
)
):
'''
@@ -215,36 +259,25 @@ class SbcMediaCodecInformation(
'''
BIT_FIELDS = 'u4u4u4u2u2u8u8'
SAMPLING_FREQUENCY_BITS = {
16000: 1 << 3,
32000: 1 << 2,
44100: 1 << 1,
48000: 1
}
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3,
SBC_DUAL_CHANNEL_MODE: 1 << 2,
SBC_STEREO_CHANNEL_MODE: 1 << 1,
SBC_JOINT_STEREO_CHANNEL_MODE: 1
}
BLOCK_LENGTH_BITS = {
4: 1 << 3,
8: 1 << 2,
12: 1 << 1,
16: 1
}
SUBBANDS_BITS = {
4: 1 << 1,
8: 1
SBC_JOINT_STEREO_CHANNEL_MODE: 1,
}
BLOCK_LENGTH_BITS = {4: 1 << 3, 8: 1 << 2, 12: 1 << 1, 16: 1}
SUBBANDS_BITS = {4: 1 << 1, 8: 1}
ALLOCATION_METHOD_BITS = {
SBC_SNR_ALLOCATION_METHOD: 1 << 1,
SBC_LOUDNESS_ALLOCATION_METHOD: 1
SBC_LOUDNESS_ALLOCATION_METHOD: 1,
}
@staticmethod
def from_bytes(data):
return SbcMediaCodecInformation(*bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data))
return SbcMediaCodecInformation(
*bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data)
)
@classmethod
def from_discrete_values(
@@ -255,7 +288,7 @@ class SbcMediaCodecInformation(
subbands,
allocation_method,
minimum_bitpool_value,
maximum_bitpool_value
maximum_bitpool_value,
):
return SbcMediaCodecInformation(
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
@@ -264,7 +297,7 @@ class SbcMediaCodecInformation(
subbands=cls.SUBBANDS_BITS[subbands],
allocation_method=cls.ALLOCATION_METHOD_BITS[allocation_method],
minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value = maximum_bitpool_value
maximum_bitpool_value=maximum_bitpool_value,
)
@classmethod
@@ -276,16 +309,20 @@ class SbcMediaCodecInformation(
subbands,
allocation_methods,
minimum_bitpool_value,
maximum_bitpool_value
maximum_bitpool_value,
):
return SbcMediaCodecInformation(
sampling_frequency = sum(cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies),
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channel_mode=sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes),
block_length=sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths),
subbands=sum(cls.SUBBANDS_BITS[x] for x in subbands),
allocation_method = sum(cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods),
allocation_method=sum(
cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods
),
minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value = maximum_bitpool_value
maximum_bitpool_value=maximum_bitpool_value,
)
def __bytes__(self):
@@ -294,7 +331,9 @@ class SbcMediaCodecInformation(
def __str__(self):
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness']
return '\n'.join([
return '\n'.join(
# pylint: disable=line-too-long
[
'SbcMediaCodecInformation(',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}',
f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}',
@@ -302,22 +341,16 @@ class SbcMediaCodecInformation(
f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}',
f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}',
f' minimum_bitpool_value: {self.minimum_bitpool_value}',
f' maximum_bitpool_value: {self.maximum_bitpool_value}'
')'
])
f' maximum_bitpool_value: {self.maximum_bitpool_value}' ')',
]
)
# -----------------------------------------------------------------------------
class AacMediaCodecInformation(
namedtuple(
'AacMediaCodecInformation',
[
'object_type',
'sampling_frequency',
'channels',
'vbr',
'bitrate'
]
['object_type', 'sampling_frequency', 'channels', 'vbr', 'bitrate'],
)
):
'''
@@ -329,7 +362,7 @@ class AacMediaCodecInformation(
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5,
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4,
}
SAMPLING_FREQUENCY_BITS = {
8000: 1 << 11,
@@ -343,66 +376,66 @@ class AacMediaCodecInformation(
48000: 1 << 3,
64000: 1 << 2,
88200: 1 << 1,
96000: 1
}
CHANNELS_BITS = {
1: 1 << 1,
2: 1
96000: 1,
}
CHANNELS_BITS = {1: 1 << 1, 2: 1}
@staticmethod
def from_bytes(data):
return AacMediaCodecInformation(*bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data))
return AacMediaCodecInformation(
*bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data)
)
@classmethod
def from_discrete_values(
cls,
object_type,
sampling_frequency,
channels,
vbr,
bitrate
cls, object_type, sampling_frequency, channels, vbr, bitrate
):
return AacMediaCodecInformation(
object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channels=cls.CHANNELS_BITS[channels],
vbr=vbr,
bitrate = bitrate
bitrate=bitrate,
)
@classmethod
def from_lists(
cls,
object_types,
sampling_frequencies,
channels,
vbr,
bitrate
):
def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate):
return AacMediaCodecInformation(
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency = sum(cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies),
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
),
channels=sum(cls.CHANNELS_BITS[x] for x in channels),
vbr=vbr,
bitrate = bitrate
bitrate=bitrate,
)
def __bytes__(self):
return bitstruct.pack(self.BIT_FIELDS, *self)
def __str__(self):
object_types = ['MPEG_2_AAC_LC', 'MPEG_4_AAC_LC', 'MPEG_4_AAC_LTP', 'MPEG_4_AAC_SCALABLE', '[4]', '[5]', '[6]', '[7]']
object_types = [
'MPEG_2_AAC_LC',
'MPEG_4_AAC_LC',
'MPEG_4_AAC_LTP',
'MPEG_4_AAC_SCALABLE',
'[4]',
'[5]',
'[6]',
'[7]',
]
channels = [1, 2]
return '\n'.join([
# pylint: disable=line-too-long
return '\n'.join(
[
'AacMediaCodecInformation(',
f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}',
f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}',
f' vbr: {self.vbr}',
f' bitrate: {self.bitrate}'
')'
])
f' bitrate: {self.bitrate}' ')',
]
)
# -----------------------------------------------------------------------------
@@ -425,24 +458,21 @@ class VendorSpecificMediaCodecInformation:
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
def __str__(self):
return '\n'.join([
# pylint: disable=line-too-long
return '\n'.join(
[
'VendorSpecificMediaCodecInformation(',
f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})',
f' codec_id: {self.codec_id:04X}',
f' value: {self.value.hex()}'
')'
])
f' value: {self.value.hex()}' ')',
]
)
# -----------------------------------------------------------------------------
class SbcFrame:
def __init__(
self,
sampling_frequency,
block_count,
channel_mode,
subband_count,
payload
self, sampling_frequency, block_count, channel_mode, subband_count, payload
):
self.sampling_frequency = sampling_frequency
self.block_count = block_count
@@ -463,7 +493,13 @@ class SbcFrame:
return self.sample_count / self.sampling_frequency
def __str__(self):
return f'SBC(sf={self.sampling_frequency},cm={self.channel_mode},br={self.bitrate},sc={self.sample_count},size={len(self.payload)})'
return (
f'SBC(sf={self.sampling_frequency},'
f'cm={self.channel_mode},'
f'br={self.bitrate},'
f'sc={self.sample_count},'
f'size={len(self.payload)})'
)
# -----------------------------------------------------------------------------
@@ -498,13 +534,19 @@ class SbcParser:
if channel_mode in (SBC_MONO_CHANNEL_MODE, SBC_DUAL_CHANNEL_MODE):
frame_length += (blocks * channels * bitpool) // 8
else:
frame_length += ((1 if channel_mode == SBC_JOINT_STEREO_CHANNEL_MODE else 0) * subbands + blocks * bitpool) // 8
frame_length += (
(1 if channel_mode == SBC_JOINT_STEREO_CHANNEL_MODE else 0)
* subbands
+ blocks * bitpool
) // 8
# Read the rest of the frame
payload = header + await self.read(frame_length - 4)
# Emit the next frame
yield SbcFrame(sampling_frequency, blocks, channel_mode, subbands, payload)
yield SbcFrame(
sampling_frequency, blocks, channel_mode, subbands, payload
)
return generate_frames()
@@ -519,6 +561,7 @@ class SbcPacketSource:
@property
def packets(self):
async def generate_packets():
# pylint: disable=import-outside-toplevel
from .avdtp import MediaPacket # Import here to avoid a circular reference
sequence_number = 0
@@ -532,18 +575,25 @@ class SbcPacketSource:
async for frame in sbc_parser.frames:
print(frame)
if frames_size + len(frame.payload) > max_rtp_payload or len(frames) == 16:
if (
frames_size + len(frame.payload) > max_rtp_payload
or len(frames) == 16
):
# Need to flush what has been accumulated so far
# Emit a packet
sbc_payload = bytes([len(frames)]) + b''.join([frame.payload for frame in frames])
packet = MediaPacket(2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload)
sbc_payload = bytes([len(frames)]) + b''.join(
[frame.payload for frame in frames]
)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload
)
packet.timestamp_seconds = timestamp / frame.sampling_frequency
yield packet
# Prepare for next packets
sequence_number += 1
timestamp += sum([frame.sample_count for frame in frames])
timestamp += sum((frame.sample_count for frame in frames))
frames = [frame]
frames_size = len(frame.payload)
else:

View File

@@ -22,15 +22,22 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
from colors import color
from pyee import EventEmitter
from typing import Dict, Type
from bumble.core import UUID, name_or_number
from bumble.hci import HCI_Object, key_with_value
from .core import *
from .hci import *
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
ATT_CID = 0x04
ATT_ERROR_RESPONSE = 0x01
@@ -163,19 +170,14 @@ ATT_ERROR_NAMES = {
ATT_DEFAULT_MTU = 23
HANDLE_FIELD_SPEC = {'size': 2, 'mapper': lambda x: f'0x{x:04X}'}
UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y) # noqa: E731
# pylint: disable-next=unnecessary-lambda-assignment,unnecessary-lambda
UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y)
# pylint: disable-next=unnecessary-lambda-assignment,unnecessary-lambda
UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def key_with_value(dictionary, target_value):
for key, value in dictionary.items():
if value == target_value:
return key
return None
# fmt: on
# pylint: enable=line-too-long
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------
# Exceptions
@@ -196,8 +198,10 @@ class ATT_PDU:
'''
See Bluetooth spec @ Vol 3, Part F - 3.3 ATTRIBUTE PDU
'''
pdu_classes = {}
pdu_classes: Dict[int, Type[ATT_PDU]] = {}
op_code = 0
name = None
@staticmethod
def from_bytes(pdu):
@@ -274,11 +278,13 @@ class ATT_PDU:
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
@ATT_PDU.subclass(
[
('request_opcode_in_error', {'size': 1, 'mapper': ATT_PDU.pdu_name}),
('attribute_handle_in_error', HANDLE_FIELD_SPEC),
('error_code', {'size': 1, 'mapper': ATT_PDU.error_name})
])
('error_code', {'size': 1, 'mapper': ATT_PDU.error_name}),
]
)
class ATT_Error_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response
@@ -286,9 +292,7 @@ class ATT_Error_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('client_rx_mtu', 2)
])
@ATT_PDU.subclass([('client_rx_mtu', 2)])
class ATT_Exchange_MTU_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.2.1 Exchange MTU Request
@@ -296,9 +300,7 @@ class ATT_Exchange_MTU_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('server_rx_mtu', 2)
])
@ATT_PDU.subclass([('server_rx_mtu', 2)])
class ATT_Exchange_MTU_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.2.2 Exchange MTU Response
@@ -306,10 +308,9 @@ class ATT_Exchange_MTU_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC)
])
@ATT_PDU.subclass(
[('starting_handle', HANDLE_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC)]
)
class ATT_Find_Information_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.1 Find Information Request
@@ -317,10 +318,7 @@ class ATT_Find_Information_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('format', 1),
('information_data', '*')
])
@ATT_PDU.subclass([('format', 1), ('information_data', '*')])
class ATT_Find_Information_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.2 Find Information Response
@@ -346,20 +344,33 @@ class ATT_Find_Information_Response(ATT_PDU):
def __str__(self):
result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [
result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
('format', 1),
('information', {'mapper': lambda x: ', '.join([f'0x{handle:04X}:{uuid.hex()}' for handle, uuid in x])})
], ' ')
(
'information',
{
'mapper': lambda x: ', '.join(
[f'0x{handle:04X}:{uuid.hex()}' for handle, uuid in x]
)
},
),
],
' ',
)
return result
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
@ATT_PDU.subclass(
[
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_type', UUID_2_FIELD_SPEC),
('attribute_value', '*')
])
('attribute_value', '*'),
]
)
class ATT_Find_By_Type_Value_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.3 Find By Type Value Request
@@ -367,9 +378,7 @@ class ATT_Find_By_Type_Value_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('handles_information_list', '*')
])
@ATT_PDU.subclass([('handles_information_list', '*')])
class ATT_Find_By_Type_Value_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.4 Find By Type Value Response
@@ -379,7 +388,9 @@ class ATT_Find_By_Type_Value_Response(ATT_PDU):
self.handles_information = []
offset = 0
while offset + 4 <= len(self.handles_information_list):
found_attribute_handle, group_end_handle = struct.unpack_from('<HH', self.handles_information_list, offset)
found_attribute_handle, group_end_handle = struct.unpack_from(
'<HH', self.handles_information_list, offset
)
self.handles_information.append((found_attribute_handle, group_end_handle))
offset += 4
@@ -393,18 +404,34 @@ class ATT_Find_By_Type_Value_Response(ATT_PDU):
def __str__(self):
result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [
('handles_information', {'mapper': lambda x: ', '.join([f'0x{handle1:04X}-0x{handle2:04X}' for handle1, handle2 in x])})
], ' ')
result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
(
'handles_information',
{
'mapper': lambda x: ', '.join(
[
f'0x{handle1:04X}-0x{handle2:04X}'
for handle1, handle2 in x
]
)
},
)
],
' ',
)
return result
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
@ATT_PDU.subclass(
[
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_type', UUID_2_16_FIELD_SPEC)
])
('attribute_type', UUID_2_16_FIELD_SPEC),
]
)
class ATT_Read_By_Type_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.1 Read By Type Request
@@ -412,10 +439,7 @@ class ATT_Read_By_Type_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('length', 1),
('attribute_data_list', '*')
])
@ATT_PDU.subclass([('length', 1), ('attribute_data_list', '*')])
class ATT_Read_By_Type_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.2 Read By Type Response
@@ -424,9 +448,15 @@ class ATT_Read_By_Type_Response(ATT_PDU):
def parse_attribute_data_list(self):
self.attributes = []
offset = 0
while self.length != 0 and offset + self.length <= len(self.attribute_data_list):
attribute_handle, = struct.unpack_from('<H', self.attribute_data_list, offset)
attribute_value = self.attribute_data_list[offset + 2:offset + self.length]
while self.length != 0 and offset + self.length <= len(
self.attribute_data_list
):
(attribute_handle,) = struct.unpack_from(
'<H', self.attribute_data_list, offset
)
attribute_value = self.attribute_data_list[
offset + 2 : offset + self.length
]
self.attributes.append((attribute_handle, attribute_value))
offset += self.length
@@ -440,17 +470,26 @@ class ATT_Read_By_Type_Response(ATT_PDU):
def __str__(self):
result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [
result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
('length', 1),
('attributes', {'mapper': lambda x: ', '.join([f'0x{handle:04X}:{value.hex()}' for handle, value in x])})
], ' ')
(
'attributes',
{
'mapper': lambda x: ', '.join(
[f'0x{handle:04X}:{value.hex()}' for handle, value in x]
)
},
),
],
' ',
)
return result
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC)
])
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC)])
class ATT_Read_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.3 Read Request
@@ -458,9 +497,7 @@ class ATT_Read_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_value', '*')
])
@ATT_PDU.subclass([('attribute_value', '*')])
class ATT_Read_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.4 Read Response
@@ -468,10 +505,7 @@ class ATT_Read_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2)
])
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('value_offset', 2)])
class ATT_Read_Blob_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.5 Read Blob Request
@@ -479,9 +513,7 @@ class ATT_Read_Blob_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('part_attribute_value', '*')
])
@ATT_PDU.subclass([('part_attribute_value', '*')])
class ATT_Read_Blob_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.6 Read Blob Response
@@ -489,9 +521,7 @@ class ATT_Read_Blob_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('set_of_handles', '*')
])
@ATT_PDU.subclass([('set_of_handles', '*')])
class ATT_Read_Multiple_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.7 Read Multiple Request
@@ -499,9 +529,7 @@ class ATT_Read_Multiple_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('set_of_values', '*')
])
@ATT_PDU.subclass([('set_of_values', '*')])
class ATT_Read_Multiple_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.8 Read Multiple Response
@@ -509,11 +537,13 @@ class ATT_Read_Multiple_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
@ATT_PDU.subclass(
[
('starting_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC),
('attribute_group_type', UUID_2_16_FIELD_SPEC)
])
('attribute_group_type', UUID_2_16_FIELD_SPEC),
]
)
class ATT_Read_By_Group_Type_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.9 Read by Group Type Request
@@ -521,10 +551,7 @@ class ATT_Read_By_Group_Type_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('length', 1),
('attribute_data_list', '*')
])
@ATT_PDU.subclass([('length', 1), ('attribute_data_list', '*')])
class ATT_Read_By_Group_Type_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.10 Read by Group Type Response
@@ -533,10 +560,18 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU):
def parse_attribute_data_list(self):
self.attributes = []
offset = 0
while self.length != 0 and offset + self.length <= len(self.attribute_data_list):
attribute_handle, end_group_handle = struct.unpack_from('<HH', self.attribute_data_list, offset)
attribute_value = self.attribute_data_list[offset + 4:offset + self.length]
self.attributes.append((attribute_handle, end_group_handle, attribute_value))
while self.length != 0 and offset + self.length <= len(
self.attribute_data_list
):
attribute_handle, end_group_handle = struct.unpack_from(
'<HH', self.attribute_data_list, offset
)
attribute_value = self.attribute_data_list[
offset + 4 : offset + self.length
]
self.attributes.append(
(attribute_handle, end_group_handle, attribute_value)
)
offset += self.length
def __init__(self, *args, **kwargs):
@@ -549,18 +584,29 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU):
def __str__(self):
result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields(self.__dict__, [
result += ':\n' + HCI_Object.format_fields(
self.__dict__,
[
('length', 1),
('attributes', {'mapper': lambda x: ', '.join([f'0x{handle:04X}-0x{end:04X}:{value.hex()}' for handle, end, value in x])})
], ' ')
(
'attributes',
{
'mapper': lambda x: ', '.join(
[
f'0x{handle:04X}-0x{end:04X}:{value.hex()}'
for handle, end, value in x
]
)
},
),
],
' ',
)
return result
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
class ATT_Write_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.1 Write Request
@@ -576,10 +622,7 @@ class ATT_Write_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
class ATT_Write_Command(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.3 Write Command
@@ -587,11 +630,13 @@ class ATT_Write_Command(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
@ATT_PDU.subclass(
[
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
# ('authentication_signature', 'TODO')
])
]
)
class ATT_Signed_Write_Command(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.4 Signed Write Command
@@ -599,11 +644,13 @@ class ATT_Signed_Write_Command(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
@ATT_PDU.subclass(
[
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2),
('part_attribute_value', '*')
])
('part_attribute_value', '*'),
]
)
class ATT_Prepare_Write_Request(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.6.1 Prepare Write Request
@@ -611,11 +658,13 @@ class ATT_Prepare_Write_Request(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
@ATT_PDU.subclass(
[
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2),
('part_attribute_value', '*')
])
('part_attribute_value', '*'),
]
)
class ATT_Prepare_Write_Response(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.6.2 Prepare Write Response
@@ -639,10 +688,7 @@ class ATT_Execute_Write_Response(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
class ATT_Handle_Value_Notification(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.7.1 Handle Value Notification
@@ -650,10 +696,7 @@ class ATT_Handle_Value_Notification(ATT_PDU):
# -----------------------------------------------------------------------------
@ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')])
class ATT_Handle_Value_Indication(ATT_PDU):
'''
See Bluetooth spec @ Vol 3, Part F - 3.4.7.2 Handle Value Indication
@@ -687,15 +730,15 @@ class Attribute(EventEmitter):
self.permissions = permissions
# Convert the type to a UUID object if it isn't already
if type(attribute_type) is str:
if isinstance(attribute_type, str):
self.type = UUID(attribute_type)
elif type(attribute_type) is bytes:
elif isinstance(attribute_type, bytes):
self.type = UUID.from_bytes(attribute_type)
else:
self.type = attribute_type
# Convert the value to a byte array
if type(value) is str:
if isinstance(value, str):
self.value = bytes(value, 'utf-8')
else:
self.value = value
@@ -709,9 +752,11 @@ class Attribute(EventEmitter):
def read_value(self, connection):
if read := getattr(self.value, 'read', None):
try:
value = read(connection)
value = read(connection) # pylint: disable=not-callable
except ATT_Error as error:
raise ATT_Error(error_code=error.error_code, att_handle=self.handle)
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
value = self.value
@@ -722,16 +767,18 @@ class Attribute(EventEmitter):
if write := getattr(self.value, 'write', None):
try:
write(connection, value)
write(connection, value) # pylint: disable=not-callable
except ATT_Error as error:
raise ATT_Error(error_code=error.error_code, att_handle=self.handle)
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
) from error
else:
self.value = value
self.emit('write', connection, value)
def __repr__(self):
if type(self.value) is bytes:
if isinstance(self.value, bytes):
value_str = self.value.hex()
else:
value_str = str(self.value)
@@ -739,4 +786,8 @@ class Attribute(EventEmitter):
value_string = f', value={self.value.hex()}'
else:
value_string = ''
return f'Attribute(handle=0x{self.handle:04X}, type={self.type}, permissions={self.permissions}{value_string})'
return (
f'Attribute(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'permissions={self.permissions}{value_string})'
)

File diff suppressed because it is too large Load Diff

View File

@@ -62,14 +62,14 @@ class HCI_Bridge:
hci_controller_source,
hci_controller_sink,
host_to_controller_filter=None,
controller_to_host_filter = None
controller_to_host_filter=None,
):
tracer = PacketTracer(emit_message=logger.info)
host_to_controller_forwarder = HCI_Bridge.Forwarder(
hci_controller_sink,
hci_host_sink,
host_to_controller_filter,
lambda packet: tracer.trace(packet, 0)
lambda packet: tracer.trace(packet, 0),
)
hci_host_source.set_packet_sink(host_to_controller_forwarder)
@@ -77,6 +77,6 @@ class HCI_Bridge:
hci_host_sink,
hci_controller_sink,
controller_to_host_filter,
lambda packet: tracer.trace(packet, 1)
lambda packet: tracer.trace(packet, 1),
)
hci_controller_source.set_packet_sink(controller_to_host_forwarder)

View File

@@ -17,6 +17,7 @@
# the `generate_company_id_list.py` script
# -----------------------------------------------------------------------------
# pylint: disable=line-too-long
COMPANY_IDENTIFIERS = {
0x0000: "Ericsson Technology Licensing",
0x0001: "Nokia Mobile Phones",
@@ -196,28 +197,28 @@ COMPANY_IDENTIFIERS = {
0x00AF: "Cinetix",
0x00B0: "Passif Semiconductor Corp",
0x00B1: "Saris Cycling Group, Inc",
0x00B2: "Bekey A/S",
0x00B3: "Clarinox Technologies Pty. Ltd.",
0x00B4: "BDE Technology Co., Ltd.",
0x00B2: "Bekey A/S",
0x00B3: "Clarinox Technologies Pty. Ltd.",
0x00B4: "BDE Technology Co., Ltd.",
0x00B5: "Swirl Networks",
0x00B6: "Meso international",
0x00B7: "TreLab Ltd",
0x00B8: "Qualcomm Innovation Center, Inc. (QuIC)",
0x00B9: "Johnson Controls, Inc.",
0x00BA: "Starkey Laboratories Inc.",
0x00BB: "S-Power Electronics Limited",
0x00BC: "Ace Sensor Inc",
0x00BD: "Aplix Corporation",
0x00BE: "AAMP of America",
0x00BF: "Stalmart Technology Limited",
0x00C0: "AMICCOM Electronics Corporation",
0x00C1: "Shenzhen Excelsecu Data Technology Co.,Ltd",
0x00C2: "Geneq Inc.",
0x00C3: "adidas AG",
0x00C4: "LG Electronics",
0x00C5: "Onset Computer Corporation",
0x00C6: "Selfly BV",
0x00C7: "Quuppa Oy.",
0x00B6: "Meso international",
0x00B7: "TreLab Ltd",
0x00B8: "Qualcomm Innovation Center, Inc. (QuIC)",
0x00B9: "Johnson Controls, Inc.",
0x00BA: "Starkey Laboratories Inc.",
0x00BB: "S-Power Electronics Limited",
0x00BC: "Ace Sensor Inc",
0x00BD: "Aplix Corporation",
0x00BE: "AAMP of America",
0x00BF: "Stalmart Technology Limited",
0x00C0: "AMICCOM Electronics Corporation",
0x00C1: "Shenzhen Excelsecu Data Technology Co.,Ltd",
0x00C2: "Geneq Inc.",
0x00C3: "adidas AG",
0x00C4: "LG Electronics",
0x00C5: "Onset Computer Corporation",
0x00C6: "Selfly BV",
0x00C7: "Quuppa Oy.",
0x00C8: "GeLo Inc",
0x00C9: "Evluma",
0x00CA: "MC10",
@@ -249,10 +250,10 @@ COMPANY_IDENTIFIERS = {
0x00E4: "Laird Connectivity, Inc. formerly L.S. Research Inc.",
0x00E5: "Eden Software Consultants Ltd.",
0x00E6: "Freshtemp",
0x00E7: "KS Technologies",
0x00E8: "ACTS Technologies",
0x00E9: "Vtrack Systems",
0x00EA: "Nielsen-Kellerman Company",
0x00E7: "KS Technologies",
0x00E8: "ACTS Technologies",
0x00E9: "Vtrack Systems",
0x00EA: "Nielsen-Kellerman Company",
0x00EB: "Server Technology Inc.",
0x00EC: "BioResearch Associates",
0x00ED: "Jolly Logic, LLC",
@@ -2704,5 +2705,5 @@ COMPANY_IDENTIFIERS = {
0x0A7C: "WAFERLOCK",
0x0A7D: "Freedman Electronics Pty Ltd",
0x0A7E: "Keba AG",
0x0A7F: "Intuity Medical"
0x0A7F: "Intuity Medical",
}

View File

@@ -19,9 +19,36 @@ import logging
import asyncio
import itertools
import random
import struct
from colors import color
from bumble.core import BT_CENTRAL_ROLE, BT_PERIPHERAL_ROLE
from bumble.hci import (
HCI_ACL_DATA_PACKET,
HCI_COMMAND_DISALLOWED_ERROR,
HCI_COMMAND_PACKET,
HCI_COMMAND_STATUS_PENDING,
HCI_CONNECTION_TIMEOUT_ERROR,
HCI_EVENT_PACKET,
HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR,
HCI_LE_1M_PHY,
HCI_SUCCESS,
HCI_UNKNOWN_HCI_COMMAND_ERROR,
HCI_VERSION_BLUETOOTH_CORE_5_0,
Address,
HCI_AclDataPacket,
HCI_AclDataPacketAssembler,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_Disconnection_Complete_Event,
HCI_Encryption_Change_Event,
HCI_LE_Advertising_Report_Event,
HCI_LE_Connection_Complete_Event,
HCI_LE_Read_Remote_Features_Complete_Event,
HCI_Number_Of_Completed_Packets_Event,
HCI_Packet,
)
from .hci import *
from .l2cap import *
# -----------------------------------------------------------------------------
# Logging
@@ -48,11 +75,15 @@ class Connection:
def on_hci_acl_data_packet(self, packet):
self.assembler.feed_packet(packet)
self.controller.send_hci_packet(HCI_Number_Of_Completed_Packets_Event([(self.handle, 1)]))
self.controller.send_hci_packet(
HCI_Number_Of_Completed_Packets_Event([(self.handle, 1)])
)
def on_acl_pdu(self, data):
if self.link:
self.link.send_acl_data(self.controller.random_address, self.peer_address, data)
self.link.send_acl_data(
self.controller.random_address, self.peer_address, data
)
# -----------------------------------------------------------------------------
@@ -62,22 +93,36 @@ class Controller:
self.hci_sink = None
self.link = link
self.central_connections = {} # Connections where this controller is the central
self.peripheral_connections = {} # Connections where this controller is the peripheral
self.central_connections = (
{}
) # Connections where this controller is the central
self.peripheral_connections = (
{}
) # Connections where this controller is the peripheral
self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0
self.hci_revision = 0
self.lmp_version = HCI_VERSION_BLUETOOTH_CORE_5_0
self.lmp_subversion = 0
self.lmp_features = bytes.fromhex('0000000060000000') # BR/EDR Not Supported, LE Supported (Controller)
self.lmp_features = bytes.fromhex(
'0000000060000000'
) # BR/EDR Not Supported, LE Supported (Controller)
self.manufacturer_name = 0xFFFF
self.hc_le_data_packet_length = 27
self.hc_total_num_le_data_packets = 64
self.supported_commands = bytes.fromhex('2000800000c000000000e40000002822000000000000040000f7ffff7f00000030f0f9ff01008004000000000000000000000000000000000000000000000000')
self.event_mask = 0
self.event_mask_page_2 = 0
self.supported_commands = bytes.fromhex(
'2000800000c000000000e40000002822000000000000040000f7ffff7f000000'
'30f0f9ff01008004000000000000000000000000000000000000000000000000'
)
self.le_event_mask = 0
self.advertising_parameters = None
self.le_features = bytes.fromhex('ff49010000000000')
self.le_states = bytes.fromhex('ffff3fffff030000')
self.advertising_channel_tx_power = 0
self.filter_accept_list_size = 8
self.filter_duplicates = False
self.resolving_list_size = 8
self.supported_max_tx_octets = 27
self.supported_max_tx_time = 10000 # microseconds
@@ -121,7 +166,8 @@ class Controller:
@host.setter
def host(self, host):
'''
Sets the host (sink) for this controller, and set this controller as the controller (sink) for the host
Sets the host (sink) for this controller, and set this controller as the
controller (sink) for the host
'''
self.set_packet_sink(host)
if host:
@@ -139,7 +185,7 @@ class Controller:
@public_address.setter
def public_address(self, address):
if type(address) is str:
if isinstance(address, str):
address = Address(address)
self._public_address = address
@@ -149,7 +195,7 @@ class Controller:
@random_address.setter
def random_address(self, address):
if type(address) is str:
if isinstance(address, str):
address = Address(address)
self._random_address = address
logger.debug(f'new random address: {address}')
@@ -162,7 +208,10 @@ class Controller:
self.on_hci_packet(HCI_Packet.from_bytes(packet))
def on_hci_packet(self, packet):
logger.debug(f'{color("<<<", "blue")} [{self.name}] {color("HOST -> CONTROLLER", "blue")}: {packet}')
logger.debug(
f'{color("<<<", "blue")} [{self.name}] '
f'{color("HOST -> CONTROLLER", "blue")}: {packet}'
)
# If the packet is a command, invoke the handler for this packet
if packet.hci_packet_type == HCI_COMMAND_PACKET:
@@ -178,28 +227,35 @@ class Controller:
handler_name = f'on_{command.name.lower()}'
handler = getattr(self, handler_name, self.on_hci_command)
result = handler(command)
if type(result) is bytes:
self.send_hci_packet(HCI_Command_Complete_Event(
if isinstance(result, bytes):
self.send_hci_packet(
HCI_Command_Complete_Event(
num_hci_command_packets=1,
command_opcode=command.op_code,
return_parameters = result
))
return_parameters=result,
)
)
def on_hci_event_packet(self, event):
def on_hci_event_packet(self, _event):
logger.warning('!!! unexpected event packet')
def on_hci_acl_data_packet(self, packet):
# Look for the connection to which this data belongs
connection = self.find_connection_by_handle(packet.connection_handle)
if connection is None:
logger.warning(f'!!! no connection for handle 0x{packet.connection_handle:04X}')
logger.warning(
f'!!! no connection for handle 0x{packet.connection_handle:04X}'
)
return
# Pass the packet to the connection
connection.on_hci_acl_data_packet(packet)
def send_hci_packet(self, packet):
logger.debug(f'{color(">>>", "green")} [{self.name}] {color("CONTROLLER -> HOST", "green")}: {packet}')
logger.debug(
f'{color(">>>", "green")} [{self.name}] '
f'{color("CONTROLLER -> HOST", "green")}: {packet}'
)
if self.host:
self.host.on_packet(packet.to_bytes())
@@ -215,8 +271,7 @@ class Controller:
handle = 0
max_handle = 0
for connection in itertools.chain(
self.central_connections.values(),
self.peripheral_connections.values()
self.central_connections.values(), self.peripheral_connections.values()
):
max_handle = max(max_handle, connection.handle)
if connection.handle == handle:
@@ -225,12 +280,13 @@ class Controller:
return handle
def find_connection_by_address(self, address):
return self.central_connections.get(address) or self.peripheral_connections.get(address)
return self.central_connections.get(address) or self.peripheral_connections.get(
address
)
def find_connection_by_handle(self, handle):
for connection in itertools.chain(
self.central_connections.values(),
self.peripheral_connections.values()
self.central_connections.values(), self.peripheral_connections.values()
):
if connection.handle == handle:
return connection
@@ -253,12 +309,15 @@ class Controller:
connection = self.peripheral_connections.get(peer_address)
if connection is None:
connection_handle = self.allocate_connection_handle()
connection = Connection(self, connection_handle, BT_PERIPHERAL_ROLE, peer_address, self.link)
connection = Connection(
self, connection_handle, BT_PERIPHERAL_ROLE, peer_address, self.link
)
self.peripheral_connections[peer_address] = connection
logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}')
# Then say that the connection has completed
self.send_hci_packet(HCI_LE_Connection_Complete_Event(
self.send_hci_packet(
HCI_LE_Connection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=connection.handle,
role=connection.role,
@@ -267,8 +326,9 @@ class Controller:
connection_interval=10, # FIXME
peripheral_latency=0, # FIXME
supervision_timeout=10, # FIXME
central_clock_accuracy = 7 # FIXME
))
central_clock_accuracy=7, # FIXME
)
)
def on_link_central_disconnected(self, peer_address, reason):
'''
@@ -277,18 +337,22 @@ class Controller:
# Send a disconnection complete event
if connection := self.peripheral_connections.get(peer_address):
self.send_hci_packet(HCI_Disconnection_Complete_Event(
self.send_hci_packet(
HCI_Disconnection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=connection.handle,
reason = reason
))
reason=reason,
)
)
# Remove the connection
del self.peripheral_connections[peer_address]
else:
logger.warn(f'!!! No peripheral connection found for {peer_address}')
logger.warning(f'!!! No peripheral connection found for {peer_address}')
def on_link_peripheral_connection_complete(self, le_create_connection_command, status):
def on_link_peripheral_connection_complete(
self, le_create_connection_command, status
):
'''
Called by the link when a connection has been made or has failed to be made
'''
@@ -300,19 +364,19 @@ class Controller:
if connection is None:
connection_handle = self.allocate_connection_handle()
connection = Connection(
self,
connection_handle,
BT_CENTRAL_ROLE,
peer_address,
self.link
self, connection_handle, BT_CENTRAL_ROLE, peer_address, self.link
)
self.central_connections[peer_address] = connection
logger.debug(f'New CENTRAL connection handle: 0x{connection_handle:04X}')
logger.debug(
f'New CENTRAL connection handle: 0x{connection_handle:04X}'
)
else:
connection = None
# Say that the connection has completed
self.send_hci_packet(HCI_LE_Connection_Complete_Event(
self.send_hci_packet(
# pylint: disable=line-too-long
HCI_LE_Connection_Complete_Event(
status=status,
connection_handle=connection.handle if connection else 0,
role=BT_CENTRAL_ROLE,
@@ -321,8 +385,9 @@ class Controller:
connection_interval=le_create_connection_command.connection_interval_min,
peripheral_latency=le_create_connection_command.max_latency,
supervision_timeout=le_create_connection_command.supervision_timeout,
central_clock_accuracy = 0
))
central_clock_accuracy=0,
)
)
def on_link_peripheral_disconnection_complete(self, disconnection_command, status):
'''
@@ -330,14 +395,18 @@ class Controller:
'''
# Send a disconnection complete event
self.send_hci_packet(HCI_Disconnection_Complete_Event(
self.send_hci_packet(
HCI_Disconnection_Complete_Event(
status=status,
connection_handle=disconnection_command.connection_handle,
reason = disconnection_command.reason
))
reason=disconnection_command.reason,
)
)
# Remove the connection
if connection := self.find_central_connection_by_handle(disconnection_command.connection_handle):
if connection := self.find_central_connection_by_handle(
disconnection_command.connection_handle
):
logger.debug(f'CENTRAL Connection removed: {connection}')
del self.central_connections[connection.peer_address]
@@ -348,25 +417,25 @@ class Controller:
# Send a disconnection complete event
if connection := self.central_connections.get(peer_address):
self.send_hci_packet(HCI_Disconnection_Complete_Event(
self.send_hci_packet(
HCI_Disconnection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=connection.handle,
reason = HCI_CONNECTION_TIMEOUT_ERROR
))
reason=HCI_CONNECTION_TIMEOUT_ERROR,
)
)
# Remove the connection
del self.central_connections[peer_address]
else:
logger.warn(f'!!! No central connection found for {peer_address}')
logger.warning(f'!!! No central connection found for {peer_address}')
def on_link_encrypted(self, peer_address, rand, ediv, ltk):
def on_link_encrypted(self, peer_address, _rand, _ediv, _ltk):
# For now, just setup the encryption without asking the host
if connection := self.find_connection_by_address(peer_address):
self.send_hci_packet(
HCI_Encryption_Change_Event(
status = 0,
connection_handle = connection.handle,
encryption_enabled = 1
status=0, connection_handle=connection.handle, encryption_enabled=1
)
)
@@ -388,24 +457,24 @@ class Controller:
return
# Send a scan report
report = HCI_Object(
HCI_LE_Advertising_Report_Event.REPORT_FIELDS,
report = HCI_LE_Advertising_Report_Event.Report(
HCI_LE_Advertising_Report_Event.Report.FIELDS,
event_type=HCI_LE_Advertising_Report_Event.ADV_IND,
address_type=sender_address.address_type,
address=sender_address,
data=data,
rssi = -50
rssi=-50,
)
self.send_hci_packet(HCI_LE_Advertising_Report_Event([report]))
# Simulate a scan response
report = HCI_Object(
HCI_LE_Advertising_Report_Event.REPORT_FIELDS,
report = HCI_LE_Advertising_Report_Event.Report(
HCI_LE_Advertising_Report_Event.Report.FIELDS,
event_type=HCI_LE_Advertising_Report_Event.SCAN_RSP,
address_type=sender_address.address_type,
address=sender_address,
data=data,
rssi = -50
rssi=-50,
)
self.send_hci_packet(HCI_LE_Advertising_Report_Event([report]))
@@ -414,14 +483,18 @@ class Controller:
############################################################
def on_advertising_timer_fired(self):
self.send_advertising_data()
self.advertising_timer_handle = asyncio.get_running_loop().call_later(self.advertising_interval / 1000.0, self.on_advertising_timer_fired)
self.advertising_timer_handle = asyncio.get_running_loop().call_later(
self.advertising_interval / 1000.0, self.on_advertising_timer_fired
)
def start_advertising(self):
# Stop any ongoing advertising before we start again
self.stop_advertising()
# Advertise now
self.advertising_timer_handle = asyncio.get_running_loop().call_soon(self.on_advertising_timer_fired)
self.advertising_timer_handle = asyncio.get_running_loop().call_soon(
self.on_advertising_timer_fired
)
def stop_advertising(self):
if self.advertising_timer_handle is not None:
@@ -455,15 +528,21 @@ class Controller:
See Bluetooth spec Vol 2, Part E - 7.1.6 Disconnect Command
'''
# First, say that the disconnection is pending
self.send_hci_packet(HCI_Command_Status_Event(
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode = command.op_code
))
command_opcode=command.op_code,
)
)
# Notify the link of the disconnection
if not (connection := self.find_central_connection_by_handle(command.connection_handle)):
logger.warn('connection not found')
if not (
connection := self.find_central_connection_by_handle(
command.connection_handle
)
):
logger.warning('connection not found')
return
if self.link:
@@ -479,7 +558,7 @@ class Controller:
self.event_mask = command.event_mask
return bytes([HCI_SUCCESS])
def on_hci_reset_command(self, command):
def on_hci_reset_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.2 Reset Command
'''
@@ -501,7 +580,7 @@ class Controller:
pass
return bytes([HCI_SUCCESS])
def on_hci_read_local_name_command(self, command):
def on_hci_read_local_name_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.12 Read Local Name Command
'''
@@ -511,21 +590,22 @@ class Controller:
return bytes([HCI_SUCCESS]) + local_name
def on_hci_read_class_of_device_command(self, command):
def on_hci_read_class_of_device_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.25 Read Class of Device Command
'''
return bytes([HCI_SUCCESS, 0, 0, 0])
def on_hci_write_class_of_device_command(self, command):
def on_hci_write_class_of_device_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.26 Write Class of Device Command
'''
return bytes([HCI_SUCCESS])
def on_hci_read_synchronous_flow_control_enable_command(self, command):
def on_hci_read_synchronous_flow_control_enable_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.36 Read Synchronous Flow Control Enable Command
See Bluetooth spec Vol 2, Part E - 7.3.36 Read Synchronous Flow Control Enable
Command
'''
if self.sync_flow_control:
ret = 1
@@ -535,7 +615,8 @@ class Controller:
def on_hci_write_synchronous_flow_control_enable_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.37 Write Synchronous Flow Control Enable Command
See Bluetooth spec Vol 2, Part E - 7.3.37 Write Synchronous Flow Control Enable
Command
'''
ret = HCI_SUCCESS
if command.synchronous_flow_control_enable == 1:
@@ -546,7 +627,7 @@ class Controller:
ret = HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR
return bytes([ret])
def on_hci_write_simple_pairing_mode_command(self, command):
def on_hci_write_simple_pairing_mode_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.59 Write Simple Pairing Mode Command
'''
@@ -559,13 +640,13 @@ class Controller:
self.event_mask_page_2 = command.event_mask_page_2
return bytes([HCI_SUCCESS])
def on_hci_read_le_host_support_command(self, command):
def on_hci_read_le_host_support_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.78 Write LE Host Support Command
'''
return bytes([HCI_SUCCESS, 1, 0])
def on_hci_write_le_host_support_command(self, command):
def on_hci_write_le_host_support_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.79 Write LE Host Support Command
'''
@@ -574,12 +655,13 @@ class Controller:
def on_hci_write_authenticated_payload_timeout_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.94 Write Authenticated Payload Timeout Command
See Bluetooth spec Vol 2, Part E - 7.3.94 Write Authenticated Payload Timeout
Command
'''
# TODO
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
def on_hci_read_local_version_information_command(self, command):
def on_hci_read_local_version_information_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.4.1 Read Local Version Information Command
'''
@@ -590,26 +672,30 @@ class Controller:
self.hci_revision,
self.lmp_version,
self.manufacturer_name,
self.lmp_subversion
self.lmp_subversion,
)
def on_hci_read_local_supported_commands_command(self, command):
def on_hci_read_local_supported_commands_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.4.2 Read Local Supported Commands Command
'''
return bytes([HCI_SUCCESS]) + self.supported_commands
def on_hci_read_local_supported_features_command(self, command):
def on_hci_read_local_supported_features_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.4.3 Read Local Supported Features Command
'''
return bytes([HCI_SUCCESS]) + self.lmp_features
def on_hci_read_bd_addr_command(self, command):
def on_hci_read_bd_addr_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.4.6 Read BD_ADDR Command
'''
bd_addr = self._public_address.to_bytes() if self._public_address is not None else bytes(6)
bd_addr = (
self._public_address.to_bytes()
if self._public_address is not None
else bytes(6)
)
return bytes([HCI_SUCCESS]) + bd_addr
def on_hci_le_set_event_mask_command(self, command):
@@ -619,18 +705,21 @@ class Controller:
self.le_event_mask = command.le_event_mask
return bytes([HCI_SUCCESS])
def on_hci_le_read_buffer_size_command(self, command):
def on_hci_le_read_buffer_size_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.2 LE Read Buffer Size Command
'''
return struct.pack('<BHB',
return struct.pack(
'<BHB',
HCI_SUCCESS,
self.hc_le_data_packet_length,
self.hc_total_num_le_data_packets)
self.hc_total_num_le_data_packets,
)
def on_hci_le_read_local_supported_features_command(self, command):
def on_hci_le_read_local_supported_features_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.3 LE Read Local Supported Features Command
See Bluetooth spec Vol 2, Part E - 7.8.3 LE Read Local Supported Features
Command
'''
return bytes([HCI_SUCCESS]) + self.le_features
@@ -648,9 +737,10 @@ class Controller:
self.advertising_parameters = command
return bytes([HCI_SUCCESS])
def on_hci_le_read_advertising_channel_tx_power_command(self, command):
def on_hci_le_read_advertising_physical_channel_tx_power_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.6 LE Read Advertising Channel Tx Power Command
See Bluetooth spec Vol 2, Part E - 7.8.6 LE Read Advertising Physical Channel
Tx Power Command
'''
return bytes([HCI_SUCCESS, self.advertising_channel_tx_power])
@@ -710,50 +800,57 @@ class Controller:
# Check that we don't already have a pending connection
if self.link.get_pending_connection():
self.send_hci_packet(HCI_Command_Status_Event(
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_DISALLOWED_ERROR,
num_hci_command_packets=1,
command_opcode = command.op_code
))
command_opcode=command.op_code,
)
)
return
# Initiate the connection
self.link.connect(self.random_address, command)
# Say that the connection is pending
self.send_hci_packet(HCI_Command_Status_Event(
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode = command.op_code
))
command_opcode=command.op_code,
)
)
def on_hci_le_create_connection_cancel_command(self, command):
def on_hci_le_create_connection_cancel_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.13 LE Create Connection Cancel Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_read_filter_accept_list_size_command(self, command):
def on_hci_le_read_filter_accept_list_size_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.14 LE Read Filter Accept List Size Command
See Bluetooth spec Vol 2, Part E - 7.8.14 LE Read Filter Accept List Size
Command
'''
return bytes([HCI_SUCCESS, self.filter_accept_list_size])
def on_hci_le_clear_filter_accept_list_command(self, command):
def on_hci_le_clear_filter_accept_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.15 LE Clear Filter Accept List Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_add_device_to_filter_accept_list_command(self, command):
def on_hci_le_add_device_to_filter_accept_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.16 LE Add Device To Filter Accept List Command
See Bluetooth spec Vol 2, Part E - 7.8.16 LE Add Device To Filter Accept List
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_remove_device_from_filter_accept_list_command(self, command):
def on_hci_le_remove_device_from_filter_accept_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.17 LE Remove Device From Filter Accept List Command
See Bluetooth spec Vol 2, Part E - 7.8.17 LE Remove Device From Filter Accept
List Command
'''
return bytes([HCI_SUCCESS])
@@ -763,20 +860,24 @@ class Controller:
'''
# First, say that the command is pending
self.send_hci_packet(HCI_Command_Status_Event(
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode = command.op_code
))
command_opcode=command.op_code,
)
)
# Then send the remote features
self.send_hci_packet(HCI_LE_Read_Remote_Features_Complete_Event(
self.send_hci_packet(
HCI_LE_Read_Remote_Features_Complete_Event(
status=HCI_SUCCESS,
connection_handle=0,
le_features = bytes.fromhex('dd40000000000000')
))
le_features=bytes.fromhex('dd40000000000000'),
)
)
def on_hci_le_rand_command(self, command):
def on_hci_le_rand_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.23 LE Rand Command
'''
@@ -788,8 +889,12 @@ class Controller:
'''
# Check the parameters
if not (connection := self.find_central_connection_by_handle(command.connection_handle)):
logger.warn('connection not found')
if not (
connection := self.find_central_connection_by_handle(
command.connection_handle
)
):
logger.warning('connection not found')
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
# Notify that the connection is now encrypted
@@ -798,57 +903,68 @@ class Controller:
connection.peer_address,
command.random_number,
command.encrypted_diversifier,
command.long_term_key
command.long_term_key,
)
self.send_hci_packet(HCI_Command_Status_Event(
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode = command.op_code
))
command_opcode=command.op_code,
)
)
def on_hci_le_read_supported_states_command(self, command):
return None
def on_hci_le_read_supported_states_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.27 LE Read Supported States Command
'''
return bytes([HCI_SUCCESS]) + self.le_states
def on_hci_le_read_suggested_default_data_length_command(self, command):
def on_hci_le_read_suggested_default_data_length_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.34 LE Read Suggested Default Data Length Command
See Bluetooth spec Vol 2, Part E - 7.8.34 LE Read Suggested Default Data Length
Command
'''
return struct.pack('<BHH',
return struct.pack(
'<BHH',
HCI_SUCCESS,
self.suggested_max_tx_octets,
self.suggested_max_tx_time)
self.suggested_max_tx_time,
)
def on_hci_le_write_suggested_default_data_length_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.35 LE Write Suggested Default Data Length Command
See Bluetooth spec Vol 2, Part E - 7.8.35 LE Write Suggested Default Data Length
Command
'''
self.suggested_max_tx_octets, self.suggested_max_tx_time = struct.unpack('<HH', command.parameters[:4])
self.suggested_max_tx_octets, self.suggested_max_tx_time = struct.unpack(
'<HH', command.parameters[:4]
)
return bytes([HCI_SUCCESS])
def on_hci_le_read_local_p_256_public_key_command(self, command):
def on_hci_le_read_local_p_256_public_key_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.36 LE Read P-256 Public Key Command
'''
# TODO create key and send HCI_LE_Read_Local_P-256_Public_Key_Complete event
return bytes([HCI_SUCCESS])
def on_hci_le_add_device_to_resolving_list_command(self, command):
def on_hci_le_add_device_to_resolving_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.38 LE Add Device To Resolving List Command
See Bluetooth spec Vol 2, Part E - 7.8.38 LE Add Device To Resolving List
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_clear_resolving_list_command(self, command):
def on_hci_le_clear_resolving_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.40 LE Clear Resolving List Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_read_resolving_list_size_command(self, command):
def on_hci_le_read_resolving_list_size_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.41 LE Read Resolving List Size Command
'''
@@ -856,7 +972,8 @@ class Controller:
def on_hci_le_set_address_resolution_enable_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.44 LE Set Address Resolution Enable Command
See Bluetooth spec Vol 2, Part E - 7.8.44 LE Set Address Resolution Enable
Command
'''
ret = HCI_SUCCESS
if command.address_resolution_enable == 1:
@@ -869,12 +986,13 @@ class Controller:
def on_hci_le_set_resolvable_private_address_timeout_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.45 LE Set Resolvable Private Address Timeout Command
See Bluetooth spec Vol 2, Part E - 7.8.45 LE Set Resolvable Private Address
Timeout Command
'''
self.le_rpa_timeout = command.rpa_timeout
return bytes([HCI_SUCCESS])
def on_hci_le_read_maximum_data_length_command(self, command):
def on_hci_le_read_maximum_data_length_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.46 LE Read Maximum Data Length Command
'''
@@ -884,19 +1002,19 @@ class Controller:
self.supported_max_tx_octets,
self.supported_max_tx_time,
self.supported_max_rx_octets,
self.supported_max_rx_time
self.supported_max_rx_time,
)
def on_hci_le_read_phy_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.47 LE Read PHY command
See Bluetooth spec Vol 2, Part E - 7.8.47 LE Read PHY Command
'''
return struct.pack(
'<BHBB',
HCI_SUCCESS,
command.connection_handle,
HCI_LE_1M_PHY,
HCI_LE_1M_PHY
HCI_LE_1M_PHY,
)
def on_hci_le_set_default_phy_command(self, command):
@@ -906,7 +1024,12 @@ class Controller:
self.default_phy = {
'all_phys': command.all_phys,
'tx_phys': command.tx_phys,
'rx_phys': command.rx_phys
'rx_phys': command.rx_phys,
}
return bytes([HCI_SUCCESS])
def on_hci_le_read_transmit_power_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.74 LE Read Transmit Power Command
'''
return struct.pack('<BBB', HCI_SUCCESS, 0, 0)

View File

@@ -15,6 +15,7 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
from .company_ids import COMPANY_IDENTIFIERS
@@ -23,6 +24,8 @@ from .company_ids import COMPANY_IDENTIFIERS
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
BT_CENTRAL_ROLE = 0
BT_PERIPHERAL_ROLE = 1
@@ -30,6 +33,9 @@ BT_BR_EDR_TRANSPORT = 0
BT_LE_TRANSPORT = 1
# fmt: on
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
@@ -58,11 +64,19 @@ def padded_bytes(buffer, size):
return buffer + bytes(padding_size)
def get_dict_key_by_value(dictionary, value):
for key, val in dictionary.items():
if val == value:
return key
return None
# -----------------------------------------------------------------------------
# Exceptions
# -----------------------------------------------------------------------------
class BaseError(Exception):
"""Base class for errors with an error code, error name and namespace"""
def __init__(self, error_code, error_namespace='', error_name='', details=''):
super().__init__()
self.error_code = error_code
@@ -87,7 +101,7 @@ class ProtocolError(BaseError):
"""Protocol Error"""
class TimeoutError(Exception):
class TimeoutError(Exception): # pylint: disable=redefined-builtin
"""Timeout Error"""
@@ -99,12 +113,21 @@ class InvalidStateError(Exception):
"""Invalid State Error"""
class ConnectionError(BaseError):
class ConnectionError(BaseError): # pylint: disable=redefined-builtin
"""Connection Error"""
FAILURE = 0x01
CONNECTION_REFUSED = 0x02
def __init__(self, error_code, transport, peer_address, error_namespace='', error_name='', details=''):
def __init__(
self,
error_code,
transport,
peer_address,
error_namespace='',
error_name='',
details='',
):
super().__init__(error_code, error_namespace, error_name, details)
self.transport = transport
self.peer_address = peer_address
@@ -121,26 +144,33 @@ class UUID:
'''
See Bluetooth spec Vol 3, Part B - 2.5.1 UUID
'''
BASE_UUID = bytes.fromhex('00001000800000805F9B34FB')
UUIDS = [] # Registry of all instances created
UUIDS: list[UUID] = [] # Registry of all instances created
def __init__(self, uuid_str_or_int, name=None):
if type(uuid_str_or_int) is int:
if isinstance(uuid_str_or_int, int):
self.uuid_bytes = struct.pack('<H', uuid_str_or_int)
else:
if len(uuid_str_or_int) == 36:
if uuid_str_or_int[8] != '-' or uuid_str_or_int[13] != '-' or uuid_str_or_int[18] != '-' or uuid_str_or_int[23] != '-':
if (
uuid_str_or_int[8] != '-'
or uuid_str_or_int[13] != '-'
or uuid_str_or_int[18] != '-'
or uuid_str_or_int[23] != '-'
):
raise ValueError('invalid UUID format')
uuid_str = uuid_str_or_int.replace('-', '')
else:
uuid_str = uuid_str_or_int
if len(uuid_str) != 32 and len(uuid_str) != 8 and len(uuid_str) != 4:
raise ValueError('invalid UUID format')
raise ValueError(f"invalid UUID format: {uuid_str}")
self.uuid_bytes = bytes(reversed(bytes.fromhex(uuid_str)))
self.name = name
def register(self):
# Register this object in the class registry, and update the entry's name if it wasn't set already
# Register this object in the class registry, and update the entry's name if
# it wasn't set already
for uuid in self.UUIDS:
if self == uuid:
if uuid.name is None:
@@ -152,13 +182,13 @@ class UUID:
@classmethod
def from_bytes(cls, uuid_bytes, name=None):
if len(uuid_bytes) in {2, 4, 16}:
if len(uuid_bytes) in (2, 4, 16):
self = cls.__new__(cls)
self.uuid_bytes = uuid_bytes
self.name = name
return self.register()
else:
raise ValueError('only 2, 4 and 16 bytes are allowed')
@classmethod
@@ -170,19 +200,20 @@ class UUID:
return cls.from_bytes(struct.pack('<I', uuid_32), name)
@classmethod
def parse_uuid(cls, bytes, offset):
return len(bytes), cls.from_bytes(bytes[offset:])
def parse_uuid(cls, uuid_as_bytes, offset):
return len(uuid_as_bytes), cls.from_bytes(uuid_as_bytes[offset:])
@classmethod
def parse_uuid_2(cls, bytes, offset):
return offset + 2, cls.from_bytes(bytes[offset:offset + 2])
def parse_uuid_2(cls, uuid_as_bytes, offset):
return offset + 2, cls.from_bytes(uuid_as_bytes[offset : offset + 2])
def to_bytes(self, force_128=False):
if len(self.uuid_bytes) == 16 or not force_128:
return self.uuid_bytes
elif len(self.uuid_bytes) == 4:
if len(self.uuid_bytes) == 4:
return self.uuid_bytes + UUID.BASE_UUID
else:
return self.uuid_bytes + bytes([0, 0]) + UUID.BASE_UUID
def to_pdu_bytes(self):
@@ -197,14 +228,16 @@ class UUID:
def to_hex_str(self):
if len(self.uuid_bytes) == 2 or len(self.uuid_bytes) == 4:
return bytes(reversed(self.uuid_bytes)).hex().upper()
else:
return ''.join([
return ''.join(
[
bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex()
]).upper()
bytes(reversed(self.uuid_bytes[0:6])).hex(),
]
).upper()
def __bytes__(self):
return self.to_bytes()
@@ -212,7 +245,8 @@ class UUID:
def __eq__(self, other):
if isinstance(other, UUID):
return self.to_bytes(force_128=True) == other.to_bytes(force_128=True)
elif type(other) is str:
if isinstance(other, str):
return UUID(other) == self
return False
@@ -222,22 +256,25 @@ class UUID:
def __str__(self):
if len(self.uuid_bytes) == 2:
v = struct.unpack('<H', self.uuid_bytes)[0]
result = f'UUID-16:{v:04X}'
uuid = struct.unpack('<H', self.uuid_bytes)[0]
result = f'UUID-16:{uuid:04X}'
elif len(self.uuid_bytes) == 4:
v = struct.unpack('<I', self.uuid_bytes)[0]
result = f'UUID-32:{v:08X}'
uuid = struct.unpack('<I', self.uuid_bytes)[0]
result = f'UUID-32:{uuid:08X}'
else:
result = '-'.join([
result = '-'.join(
[
bytes(reversed(self.uuid_bytes[12:16])).hex(),
bytes(reversed(self.uuid_bytes[10:12])).hex(),
bytes(reversed(self.uuid_bytes[8:10])).hex(),
bytes(reversed(self.uuid_bytes[6:8])).hex(),
bytes(reversed(self.uuid_bytes[0:6])).hex()
]).upper()
bytes(reversed(self.uuid_bytes[0:6])).hex(),
]
).upper()
if self.name is not None:
return result + f' ({self.name})'
else:
return result
def __repr__(self):
@@ -247,6 +284,8 @@ class UUID:
# -----------------------------------------------------------------------------
# Common UUID constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
# Protocol Identifiers
BT_SDP_PROTOCOL_ID = UUID.from_16_bits(0x0001, 'SDP')
@@ -352,11 +391,17 @@ BT_HDP_SERVICE = UUID.from_16_bits(0x1400,
BT_HDP_SOURCE_SERVICE = UUID.from_16_bits(0x1401, 'HDP Source')
BT_HDP_SINK_SERVICE = UUID.from_16_bits(0x1402, 'HDP Sink')
# fmt: on
# pylint: enable=line-too-long
# -----------------------------------------------------------------------------
# DeviceClass
# -----------------------------------------------------------------------------
class DeviceClass:
# fmt: off
# pylint: disable=line-too-long
# Major Service Classes (flags combined with OR)
LIMITED_DISCOVERABLE_MODE_SERVICE_CLASS = (1 << 0)
LE_AUDIO_SERVICE_CLASS = (1 << 1)
@@ -524,11 +569,18 @@ class DeviceClass:
PERIPHERAL_MAJOR_DEVICE_CLASS: PERIPHERAL_MINOR_DEVICE_CLASS_NAMES
}
# fmt: on
# pylint: enable=line-too-long
@staticmethod
def split_class_of_device(class_of_device):
# Split the bit fields of the composite class of device value into:
# (service_classes, major_device_class, minor_device_class)
return ((class_of_device >> 13 & 0x7FF), (class_of_device >> 8 & 0x1F), (class_of_device >> 2 & 0x3F))
return (
(class_of_device >> 13 & 0x7FF),
(class_of_device >> 8 & 0x1F),
(class_of_device >> 2 & 0x3F),
)
@staticmethod
def pack_class_of_device(service_classes, major_device_class, minor_device_class):
@@ -536,7 +588,9 @@ class DeviceClass:
@staticmethod
def service_class_labels(service_class_flags):
return bit_flags_to_strings(service_class_flags, DeviceClass.SERVICE_CLASS_LABELS)
return bit_flags_to_strings(
service_class_flags, DeviceClass.SERVICE_CLASS_LABELS
)
@staticmethod
def major_device_class_name(device_class):
@@ -554,6 +608,9 @@ class DeviceClass:
# Advertising Data
# -----------------------------------------------------------------------------
class AdvertisingData:
# fmt: off
# pylint: disable=line-too-long
# This list is only partial, it still needs to be filled in from the spec
FLAGS = 0x01
INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x02
@@ -665,7 +722,12 @@ class AdvertisingData:
BR_EDR_CONTROLLER_FLAG = 0x08
BR_EDR_HOST_FLAG = 0x10
def __init__(self, ad_structures = []):
# fmt: on
# pylint: enable=line-too-long
def __init__(self, ad_structures=None):
if ad_structures is None:
ad_structures = []
self.ad_structures = ad_structures[:]
@staticmethod
@@ -676,19 +738,17 @@ class AdvertisingData:
@staticmethod
def flags_to_string(flags, short=False):
flag_names = [
'LE Limited',
'LE General',
'No BR/EDR',
'BR/EDR C',
'BR/EDR H'
] if short else [
flag_names = (
['LE Limited', 'LE General', 'No BR/EDR', 'BR/EDR C', 'BR/EDR H']
if short
else [
'LE Limited Discoverable Mode',
'LE General Discoverable Mode',
'BR/EDR Not Supported',
'Simultaneous LE and BR/EDR (Controller)',
'Simultaneous LE and BR/EDR (Host)'
'Simultaneous LE and BR/EDR (Host)',
]
)
return ','.join(bit_flags_to_strings(flags, flag_names))
@staticmethod
@@ -702,10 +762,12 @@ class AdvertisingData:
@staticmethod
def uuid_list_to_string(ad_data, uuid_size):
return ', '.join([
return ', '.join(
[
str(uuid)
for uuid in AdvertisingData.uuid_list_to_objects(ad_data, uuid_size)
])
]
)
@staticmethod
def ad_data_to_string(ad_type, ad_data):
@@ -765,55 +827,64 @@ class AdvertisingData:
return f'[{ad_type_str}]: {ad_data_str}'
# pylint: disable=too-many-return-statements
@staticmethod
def ad_data_to_object(ad_type, ad_data):
if ad_type in {
if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS
}:
AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS,
):
return AdvertisingData.uuid_list_to_objects(ad_data, 2)
elif ad_type in {
if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS
}:
AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS,
):
return AdvertisingData.uuid_list_to_objects(ad_data, 4)
elif ad_type in {
if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS
}:
AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS,
):
return AdvertisingData.uuid_list_to_objects(ad_data, 16)
elif ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
if ad_type == AdvertisingData.SERVICE_DATA_16_BIT_UUID:
return (UUID.from_bytes(ad_data[:2]), ad_data[2:])
elif ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID:
if ad_type == AdvertisingData.SERVICE_DATA_32_BIT_UUID:
return (UUID.from_bytes(ad_data[:4]), ad_data[4:])
elif ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID:
if ad_type == AdvertisingData.SERVICE_DATA_128_BIT_UUID:
return (UUID.from_bytes(ad_data[:16]), ad_data[16:])
elif ad_type in {
if ad_type in (
AdvertisingData.SHORTENED_LOCAL_NAME,
AdvertisingData.COMPLETE_LOCAL_NAME,
AdvertisingData.URI
}:
AdvertisingData.URI,
):
return ad_data.decode("utf-8")
elif ad_type in {
AdvertisingData.TX_POWER_LEVEL,
AdvertisingData.FLAGS
}:
if ad_type in (AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS):
return ad_data[0]
elif ad_type in {
if ad_type in (
AdvertisingData.APPEARANCE,
AdvertisingData.ADVERTISING_INTERVAL
}:
AdvertisingData.ADVERTISING_INTERVAL,
):
return struct.unpack('<H', ad_data)[0]
elif ad_type == AdvertisingData.CLASS_OF_DEVICE:
if ad_type == AdvertisingData.CLASS_OF_DEVICE:
return struct.unpack('<I', bytes([*ad_data, 0]))[0]
elif ad_type == AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE:
if ad_type == AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE:
return struct.unpack('<HH', ad_data)
elif ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
if ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
return (struct.unpack_from('<H', ad_data, 0)[0], ad_data[2:])
else:
return ad_data
def append(self, data):
@@ -834,19 +905,29 @@ class AdvertisingData:
If return_all is True, returns a (possibly empty) list of matches,
else returns the first entry, or None if no structure matches.
'''
def process_ad_data(ad_data):
return ad_data if raw else self.ad_data_to_object(type_id, ad_data)
if return_all:
return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id]
else:
return next((process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id), None)
return [
process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id
]
return next(
(process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id),
None,
)
def __bytes__(self):
return b''.join([bytes([len(x[1]) + 1, x[0]]) + x[1] for x in self.ad_structures])
return b''.join(
[bytes([len(x[1]) + 1, x[0]]) + x[1] for x in self.ad_structures]
)
def to_string(self, separator=', '):
return separator.join([AdvertisingData.ad_data_to_string(x[0], x[1]) for x in self.ad_structures])
return separator.join(
[AdvertisingData.ad_data_to_string(x[0], x[1]) for x in self.ad_structures]
)
def __str__(self):
return self.to_string()

View File

@@ -24,19 +24,16 @@
import logging
import operator
import platform
if platform.system() != 'Emscripten':
import secrets
from cryptography.hazmat.primitives.ciphers import (
Cipher,
algorithms,
modes
)
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.asymmetric.ec import (
generate_private_key,
ECDH,
EllipticCurvePublicNumbers,
EllipticCurvePrivateNumbers,
SECP256R1
SECP256R1,
)
from cryptography.hazmat.primitives import cmac
else:
@@ -66,16 +63,26 @@ class EccKey:
d = int.from_bytes(d_bytes, byteorder='big', signed=False)
x = int.from_bytes(x_bytes, byteorder='big', signed=False)
y = int.from_bytes(y_bytes, byteorder='big', signed=False)
private_key = EllipticCurvePrivateNumbers(d, EllipticCurvePublicNumbers(x, y, SECP256R1())).private_key()
private_key = EllipticCurvePrivateNumbers(
d, EllipticCurvePublicNumbers(x, y, SECP256R1())
).private_key()
return cls(private_key)
@property
def x(self):
return self.private_key.public_key().public_numbers().x.to_bytes(32, byteorder='big')
return (
self.private_key.public_key()
.public_numbers()
.x.to_bytes(32, byteorder='big')
)
@property
def y(self):
return self.private_key.public_key().public_numbers().y.to_bytes(32, byteorder='big')
return (
self.private_key.public_key()
.public_numbers()
.y.to_bytes(32, byteorder='big')
)
def dh(self, public_key_x, public_key_y):
x = int.from_bytes(public_key_x, byteorder='big', signed=False)
@@ -92,7 +99,7 @@ class EccKey:
# -----------------------------------------------------------------------------
def xor(x, y):
assert(len(x) == len(y))
assert len(x) == len(y)
return bytes(map(operator.xor, x, y))
@@ -118,7 +125,7 @@ def e(key, data):
# -----------------------------------------------------------------------------
def ah(k, r):
def ah(k, r): # pylint: disable=redefined-outer-name
'''
See Bluetooth spec Vol 3, Part H - 2.2.2 Random Address Hash function ah
'''
@@ -129,9 +136,10 @@ def ah(k, r):
# -----------------------------------------------------------------------------
def c1(k, r, preq, pres, iat, rat, ia, ra):
def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-name
'''
See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for LE Legacy Pairing
See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for
LE Legacy Pairing
'''
p1 = bytes([iat, rat]) + preq + pres
@@ -142,7 +150,8 @@ def c1(k, r, preq, pres, iat, rat, ia, ra):
# -----------------------------------------------------------------------------
def s1(k, r1, r2):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy Pairing
See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy
Pairing
'''
return e(k, r2[0:8] + r1[0:8])
@@ -163,71 +172,95 @@ def aes_cmac(m, k):
# -----------------------------------------------------------------------------
def f4(u, v, x, z):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value Generation Function f4
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value
Generation Function f4
'''
return bytes(reversed(aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))))
return bytes(
reversed(
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))
)
)
# -----------------------------------------------------------------------------
def f5(w, n1, n2, a1, a2):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation Function f5
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation
Function f5
NOTE: this returns a tuple: (MacKey, LTK) in little-endian byte order
'''
salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE')
t = aes_cmac(bytes(reversed(w)), salt)
key_id = bytes([0x62, 0x74, 0x6c, 0x65])
key_id = bytes([0x62, 0x74, 0x6C, 0x65])
return (
bytes(reversed(aes_cmac(
bytes([0]) +
key_id +
bytes(reversed(n1)) +
bytes(reversed(n2)) +
bytes(reversed(a1)) +
bytes(reversed(a2)) +
bytes([1, 0]),
t
))),
bytes(reversed(aes_cmac(
bytes([1]) +
key_id +
bytes(reversed(n1)) +
bytes(reversed(n2)) +
bytes(reversed(a1)) +
bytes(reversed(a2)) +
bytes([1, 0]),
t
)))
bytes(
reversed(
aes_cmac(
bytes([0])
+ key_id
+ bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(a1))
+ bytes(reversed(a2))
+ bytes([1, 0]),
t,
)
)
),
bytes(
reversed(
aes_cmac(
bytes([1])
+ key_id
+ bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(a1))
+ bytes(reversed(a2))
+ bytes([1, 0]),
t,
)
)
),
)
# -----------------------------------------------------------------------------
def f6(w, n1, n2, r, io_cap, a1, a2):
def f6(w, n1, n2, r, io_cap, a1, a2): # pylint: disable=redefined-outer-name
'''
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value Generation Function f6
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value
Generation Function f6
'''
return bytes(reversed(aes_cmac(
bytes(reversed(n1)) +
bytes(reversed(n2)) +
bytes(reversed(r)) +
bytes(reversed(io_cap)) +
bytes(reversed(a1)) +
bytes(reversed(a2)),
bytes(reversed(w))
)))
return bytes(
reversed(
aes_cmac(
bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(r))
+ bytes(reversed(io_cap))
+ bytes(reversed(a1))
+ bytes(reversed(a2)),
bytes(reversed(w)),
)
)
)
# -----------------------------------------------------------------------------
def g2(u, v, x, y):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison Value Generation Function g2
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison
Value Generation Function g2
'''
return int.from_bytes(
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)), bytes(reversed(x)))[-4:],
byteorder='big'
aes_cmac(
bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)),
bytes(reversed(x)),
)[-4:],
byteorder='big',
)
# -----------------------------------------------------------------------------
def h6(w, key_id):
'''
@@ -235,6 +268,7 @@ def h6(w, key_id):
'''
return aes_cmac(key_id, w)
# -----------------------------------------------------------------------------
def h7(salt, w):
'''

File diff suppressed because it is too large Load Diff

View File

@@ -23,7 +23,7 @@ from .gatt import (
Characteristic,
GATT_GENERIC_ACCESS_SERVICE,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC
GATT_APPEARANCE_CHARACTERISTIC,
)
# -----------------------------------------------------------------------------
@@ -43,17 +43,17 @@ class GenericAccessService(Service):
GATT_DEVICE_NAME_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
device_name.encode('utf-8')[:248]
device_name.encode('utf-8')[:248],
)
appearance_characteristic = Characteristic(
GATT_APPEARANCE_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
struct.pack('<H', (appearance[0] << 6) | appearance[1])
struct.pack('<H', (appearance[0] << 6) | appearance[1]),
)
super().__init__(GATT_GENERIC_ACCESS_SERVICE, [
device_name_characteristic,
appearance_characteristic
])
super().__init__(
GATT_GENERIC_ACCESS_SERVICE,
[device_name_characteristic, appearance_characteristic],
)

View File

@@ -22,16 +22,18 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import enum
import types
import functools
import logging
from pyee import EventEmitter
import struct
from typing import Sequence
from colors import color
from .core import *
from .hci import *
from .att import *
from .core import UUID, get_dict_key_by_value
from .att import Attribute
# -----------------------------------------------------------------------------
# Logging
@@ -41,6 +43,9 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
GATT_REQUEST_TIMEOUT = 30 # seconds
GATT_MAX_ATTRIBUTE_VALUE_SIZE = 512
@@ -151,6 +156,14 @@ GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A39, 'Heart
# Battery Service
GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level')
# ASHA Service
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID('f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint')
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID('38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus')
GATT_ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT')
# Misc
GATT_DEVICE_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2A00, 'Device Name')
GATT_APPEARANCE_CHARACTERISTIC = UUID.from_16_bits(0x2A01, 'Appearance')
@@ -165,11 +178,15 @@ GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bi
GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report')
GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
# fmt: on
# pylint: enable=line-too-long
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def show_services(services):
for service in services:
print(color(str(service), 'cyan'))
@@ -187,23 +204,38 @@ class Service(Attribute):
See Vol 3, Part G - 3.1 SERVICE DEFINITION
'''
def __init__(self, uuid, characteristics, primary=True):
def __init__(self, uuid, characteristics: list[Characteristic], primary=True):
# Convert the uuid to a UUID object if it isn't already
if type(uuid) is str:
if isinstance(uuid, str):
uuid = UUID(uuid)
super().__init__(
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE if primary else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Attribute.READABLE,
uuid.to_pdu_bytes()
uuid.to_pdu_bytes(),
)
self.uuid = uuid
self.included_services = []
# self.included_services = []
self.characteristics = characteristics[:]
self.primary = primary
def get_advertising_data(self):
"""
Get Service specific advertising data
Defined by each Service, default value is empty
:return Service data for advertising
"""
return None
def __str__(self):
return f'Service(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}){"" if self.primary else "*"}'
return (
f'Service(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
f'uuid={self.uuid})'
f'{"" if self.primary else "*"}'
)
# -----------------------------------------------------------------------------
@@ -212,6 +244,7 @@ class TemplateService(Service):
Convenience abstract class that can be used by profile-specific subclasses that want
to expose their UUID as a class property
'''
UUID = None
def __init__(self, characteristics, primary=True):
@@ -230,9 +263,9 @@ class Characteristic(Attribute):
WRITE_WITHOUT_RESPONSE = 0x04
WRITE = 0x08
NOTIFY = 0x10
INDICATE = 0X20
AUTHENTICATED_SIGNED_WRITES = 0X40
EXTENDED_PROPERTIES = 0X80
INDICATE = 0x20
AUTHENTICATED_SIGNED_WRITES = 0x40
EXTENDED_PROPERTIES = 0x80
PROPERTY_NAMES = {
BROADCAST: 'BROADCAST',
@@ -242,23 +275,44 @@ class Characteristic(Attribute):
NOTIFY: 'NOTIFY',
INDICATE: 'INDICATE',
AUTHENTICATED_SIGNED_WRITES: 'AUTHENTICATED_SIGNED_WRITES',
EXTENDED_PROPERTIES: 'EXTENDED_PROPERTIES'
EXTENDED_PROPERTIES: 'EXTENDED_PROPERTIES',
}
@staticmethod
def property_name(property):
return Characteristic.PROPERTY_NAMES.get(property, '')
def property_name(property_int):
return Characteristic.PROPERTY_NAMES.get(property_int, '')
@staticmethod
def properties_as_string(properties):
return ','.join([
Characteristic.property_name(p) for p in Characteristic.PROPERTY_NAMES.keys()
return ','.join(
[
Characteristic.property_name(p)
for p in Characteristic.PROPERTY_NAMES
if properties & p
])
]
)
def __init__(self, uuid, properties, permissions, value = b'', descriptors = []):
@staticmethod
def string_to_properties(properties_str: str):
return functools.reduce(
lambda x, y: x | get_dict_key_by_value(Characteristic.PROPERTY_NAMES, y),
properties_str.split(","),
0,
)
def __init__(
self,
uuid,
properties,
permissions,
value=b'',
descriptors: Sequence[Descriptor] = (),
):
super().__init__(uuid, permissions, value)
self.uuid = self.type
if isinstance(properties, str):
self.properties = Characteristic.string_to_properties(properties)
else:
self.properties = properties
self.descriptors = descriptors
@@ -267,8 +321,41 @@ class Characteristic(Attribute):
if descriptor.type == descriptor_type:
return descriptor
return None
def __str__(self):
return f'Characteristic(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}, properties={Characteristic.properties_as_string(self.properties)})'
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
f'uuid={self.uuid}, '
f'properties={Characteristic.properties_as_string(self.properties)})'
)
# -----------------------------------------------------------------------------
class CharacteristicDeclaration(Attribute):
'''
See Vol 3, Part G - 3.3.1 CHARACTERISTIC DECLARATION
'''
def __init__(self, characteristic, value_handle):
declaration_bytes = (
struct.pack('<BH', characteristic.properties, value_handle)
+ characteristic.uuid.to_pdu_bytes()
)
super().__init__(
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
)
self.value_handle = value_handle
self.characteristic = characteristic
def __str__(self):
return (
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
f'value_handle=0x{self.value_handle:04X}, '
f'uuid={self.characteristic.uuid}, properties='
f'{Characteristic.properties_as_string(self.characteristic.properties)})'
)
# -----------------------------------------------------------------------------
@@ -277,6 +364,7 @@ class CharacteristicValue:
Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(self, read=None, write=None):
self._read = read
self._write = write
@@ -303,14 +391,14 @@ class CharacteristicAdapter:
If the characteristic has a `subscribe` method, it is wrapped with one where
the values are decoded before being passed to the subscriber.
'''
def __init__(self, characteristic):
self.wrapped_characteristic = characteristic
self.subscribers = {} # Map from subscriber to proxy subscriber
if (
asyncio.iscoroutinefunction(characteristic.read_value) and
asyncio.iscoroutinefunction(characteristic.write_value)
):
if asyncio.iscoroutinefunction(
characteristic.read_value
) and asyncio.iscoroutinefunction(characteristic.write_value):
self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
else:
@@ -327,14 +415,14 @@ class CharacteristicAdapter:
return getattr(self.wrapped_characteristic, name)
def __setattr__(self, name, value):
if name in {
if name in (
'wrapped_characteristic',
'subscribers',
'read_value',
'write_value',
'subscribe',
'unsubscribe'
}:
'unsubscribe',
):
super().__setattr__(name, value)
else:
setattr(self.wrapped_characteristic, name, value)
@@ -343,15 +431,16 @@ class CharacteristicAdapter:
return self.encode_value(self.wrapped_characteristic.read_value(connection))
def write_encoded_value(self, connection, value):
return self.wrapped_characteristic.write_value(connection, self.decode_value(value))
return self.wrapped_characteristic.write_value(
connection, self.decode_value(value)
)
async def read_decoded_value(self):
return self.decode_value(await self.wrapped_characteristic.read_value())
async def write_decoded_value(self, value, with_response=False):
return await self.wrapped_characteristic.write_value(
self.encode_value(value),
with_response
self.encode_value(value), with_response
)
def encode_value(self, value):
@@ -371,6 +460,7 @@ class CharacteristicAdapter:
def on_change(value):
original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change
subscriber = on_change
@@ -392,6 +482,7 @@ class DelegatedCharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that converts bytes values using an encode and a decode function.
'''
def __init__(self, characteristic, encode=None, decode=None):
super().__init__(characteristic)
self.encode = encode
@@ -414,9 +505,10 @@ class PackedCharacteristicAdapter(CharacteristicAdapter):
they return/accept a tuple with the same number of elements as is required for
the format.
'''
def __init__(self, characteristic, format):
def __init__(self, characteristic, pack_format):
super().__init__(characteristic)
self.struct = struct.Struct(format)
self.struct = struct.Struct(pack_format)
def pack(self, *values):
return self.struct.pack(*values)
@@ -425,7 +517,7 @@ class PackedCharacteristicAdapter(CharacteristicAdapter):
return self.struct.unpack(buffer)
def encode_value(self, value):
return self.pack(*value if type(value) is tuple else (value,))
return self.pack(*value if isinstance(value, tuple) else (value,))
def decode_value(self, value):
unpacked = self.unpack(value)
@@ -438,13 +530,15 @@ class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
Adapter that packs/unpacks characteristic values according to a standard
Python `struct` format.
The adapted `read_value` and `write_value` methods return/accept aa dictionary which
is packed/unpacked according to format, with the arguments extracted from the dictionary
by key, in the same order as they occur in the `keys` parameter.
is packed/unpacked according to format, with the arguments extracted from the
dictionary by key, in the same order as they occur in the `keys` parameter.
'''
def __init__(self, characteristic, format, keys):
super().__init__(characteristic, format)
def __init__(self, characteristic, pack_format, keys):
super().__init__(characteristic, pack_format)
self.keys = keys
# pylint: disable=arguments-differ
def pack(self, values):
return super().pack(*(values[key] for key in self.keys))
@@ -457,6 +551,7 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
'''
Adapter that converts strings to/from bytes using UTF-8 encoding
'''
def encode_value(self, value):
return value.encode('utf-8')
@@ -470,17 +565,20 @@ class Descriptor(Attribute):
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
'''
def __init__(self, descriptor_type, permissions, value = b''):
super().__init__(descriptor_type, permissions, value)
def __str__(self):
return f'Descriptor(handle=0x{self.handle:04X}, type={self.type}, value={self.read_value(None).hex()})'
return (
f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'value={self.read_value(None).hex()})'
)
class ClientCharacteristicConfigurationBits(enum.IntFlag):
'''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit field definition
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit
field definition
'''
DEFAULT = 0x0000
NOTIFICATION = 0x0001
INDICATION = 0x0002

View File

@@ -28,15 +28,40 @@ import logging
import struct
from colors import color
from pyee import EventEmitter
from .att import *
from .core import InvalidStateError, ProtocolError, TimeoutError
from .gatt import (GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
from .hci import HCI_Constant
from .att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
ATT_DEFAULT_MTU,
ATT_ERROR_RESPONSE,
ATT_INVALID_OFFSET_ERROR,
ATT_PDU,
ATT_RESPONSES,
ATT_Exchange_MTU_Request,
ATT_Find_By_Type_Value_Request,
ATT_Find_Information_Request,
ATT_Handle_Value_Confirmation,
ATT_Read_Blob_Request,
ATT_Read_By_Group_Type_Request,
ATT_Read_By_Type_Request,
ATT_Read_Request,
ATT_Write_Command,
ATT_Write_Request,
)
from . import core
from .core import UUID, InvalidStateError, ProtocolError
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, Characteristic,
ClientCharacteristicConfigurationBits)
from .hci import *
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic,
ClientCharacteristicConfigurationBits,
)
# -----------------------------------------------------------------------------
# Logging
@@ -56,10 +81,14 @@ class AttributeProxy(EventEmitter):
self.type = attribute_type
async def read_value(self, no_long_read=False):
return self.decode_value(await self.client.read_value(self.handle, no_long_read))
return self.decode_value(
await self.client.read_value(self.handle, no_long_read)
)
async def write_value(self, value, with_response=False):
return await self.client.write_value(self.handle, self.encode_value(value), with_response)
return await self.client.write_value(
self.handle, self.encode_value(value), with_response
)
def encode_value(self, value):
return value
@@ -68,24 +97,29 @@ class AttributeProxy(EventEmitter):
return value_bytes
def __str__(self):
return f'Attribute(handle=0x{self.handle:04X}, type={self.uuid})'
return f'Attribute(handle=0x{self.handle:04X}, type={self.type})'
class ServiceProxy(AttributeProxy):
@staticmethod
def from_client(cls, client, service_uuid):
# The service and its characteristics are considered to have already been discovered
def from_client(service_class, client, service_uuid):
# The service and its characteristics are considered to have already been
# discovered
services = client.get_services_by_uuid(service_uuid)
service = services[0] if services else None
return cls(service) if service else None
return service_class(service) if service else None
def __init__(self, client, handle, end_group_handle, uuid, primary=True):
attribute_type = GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE if primary else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE
attribute_type = (
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE
)
super().__init__(client, handle, end_group_handle, attribute_type)
self.uuid = uuid
self.characteristics = []
async def discover_characteristics(self, uuids=[]):
async def discover_characteristics(self, uuids=()):
return await self.client.discover_characteristics(uuids, self)
def get_characteristics_by_uuid(self, uuid):
@@ -109,6 +143,8 @@ class CharacteristicProxy(AttributeProxy):
if descriptor.type == descriptor_type:
return descriptor
return None
async def discover_descriptors(self):
return await self.client.discover_descriptors(self)
@@ -123,6 +159,7 @@ class CharacteristicProxy(AttributeProxy):
def on_change(value):
original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change
subscriber = on_change
@@ -135,7 +172,11 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.unsubscribe(self, subscriber)
def __str__(self):
return f'Characteristic(handle=0x{self.handle:04X}, uuid={self.uuid}, properties={Characteristic.properties_as_string(self.properties)})'
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'uuid={self.uuid}, '
f'properties={Characteristic.properties_as_string(self.properties)})'
)
class DescriptorProxy(AttributeProxy):
@@ -150,6 +191,7 @@ class ProfileServiceProxy:
'''
Base class for profile-specific service proxies
'''
@classmethod
def from_client(cls, client):
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -165,7 +207,9 @@ class Client:
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
self.pending_response = None
self.notification_subscribers = {} # Notification subscribers, by attribute handle
self.notification_subscribers = (
{}
) # Notification subscribers, by attribute handle
self.indication_subscribers = {} # Indication subscribers, by attribute handle
self.services = []
@@ -173,11 +217,15 @@ class Client:
self.connection.send_l2cap_pdu(ATT_CID, pdu)
async def send_command(self, command):
logger.debug(f'GATT Command from client: [0x{self.connection.handle:04X}] {command}')
logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
self.send_gatt_pdu(command.to_bytes())
async def send_request(self, request):
logger.debug(f'GATT Request from client: [0x{self.connection.handle:04X}] {request}')
logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
# Wait until we can send (only one pending command at a time for the connection)
response = None
@@ -191,10 +239,12 @@ class Client:
try:
self.send_gatt_pdu(request.to_bytes())
response = await asyncio.wait_for(self.pending_response, GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError:
response = await asyncio.wait_for(
self.pending_response, GATT_REQUEST_TIMEOUT
)
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Request timeout', 'red'))
raise TimeoutError(f'GATT timeout for {request.name}')
raise core.TimeoutError(f'GATT timeout for {request.name}') from error
finally:
self.pending_request = None
self.pending_response = None
@@ -202,7 +252,10 @@ class Client:
return response
def send_confirmation(self, confirmation):
logger.debug(f'GATT Confirmation from client: [0x{self.connection.handle:04X}] {confirmation}')
logger.debug(
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
f'{confirmation}'
)
self.send_gatt_pdu(confirmation.to_bytes())
async def request_mtu(self, mtu):
@@ -224,7 +277,7 @@ class Client:
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response
response,
)
# Compute the final MTU
@@ -237,7 +290,11 @@ class Client:
def get_characteristics_by_uuid(self, uuid, service=None):
services = [service] if service else self.services
return [c for c in [c for s in services for c in s.characteristics] if c.uuid == uuid]
return [
c
for c in [c for s in services for c in s.characteristics]
if c.uuid == uuid
]
def on_service_discovered(self, service):
'''Add a service to the service list if it wasn't already there'''
@@ -260,7 +317,7 @@ class Client:
ATT_Read_By_Group_Type_Request(
starting_handle=starting_handle,
ending_handle=0xFFFF,
attribute_group_type = GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
attribute_group_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
)
)
if response is None:
@@ -271,15 +328,27 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
logger.warning(
'!!! unexpected error while discovering services: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return
break
for attribute_handle, end_group_handle, attribute_value in response.attributes:
if attribute_handle < starting_handle or end_group_handle < attribute_handle:
for (
attribute_handle,
end_group_handle,
attribute_value,
) in response.attributes:
if (
attribute_handle < starting_handle
or end_group_handle < attribute_handle
):
# Something's not right
logger.warning(f'bogus handle values: {attribute_handle} {end_group_handle}')
logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
return
# Create a service proxy for this service
@@ -288,7 +357,7 @@ class Client:
attribute_handle,
end_group_handle,
UUID.from_bytes(attribute_value),
True
True,
)
# Filter out returned services based on the given uuids list
@@ -313,7 +382,7 @@ class Client:
'''
# Force uuid to be a UUID object
if type(uuid) is str:
if isinstance(uuid, str):
uuid = UUID(uuid)
starting_handle = 0x0001
@@ -324,7 +393,7 @@ class Client:
starting_handle=starting_handle,
ending_handle=0xFFFF,
attribute_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
attribute_value = uuid.to_pdu_bytes()
attribute_value=uuid.to_pdu_bytes(),
)
)
if response is None:
@@ -335,19 +404,29 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
logger.warning(
'!!! unexpected error while discovering services: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return
break
for attribute_handle, end_group_handle in response.handles_information:
if attribute_handle < starting_handle or end_group_handle < attribute_handle:
if (
attribute_handle < starting_handle
or end_group_handle < attribute_handle
):
# Something's not right
logger.warning(f'bogus handle values: {attribute_handle} {end_group_handle}')
logger.warning(
f'bogus handle values: {attribute_handle} {end_group_handle}'
)
return
# Create a service proxy for this service
service = ServiceProxy(self, attribute_handle, end_group_handle, uuid, True)
service = ServiceProxy(
self, attribute_handle, end_group_handle, uuid, True
)
# Add the service to the peer's service list
services.append(service)
@@ -366,7 +445,7 @@ class Client:
return services
async def discover_included_services(self, service):
async def discover_included_services(self, _service):
'''
See Vol 3, Part G - 4.5.1 Find Included Services
'''
@@ -375,11 +454,12 @@ class Client:
async def discover_characteristics(self, uuids, service):
'''
See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2 Discover Characteristics by UUID
See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2
Discover Characteristics by UUID
'''
# Cast the UUIDs type from string to object if needed
uuids = [UUID(uuid) if type(uuid) is str else uuid for uuid in uuids]
uuids = [UUID(uuid) if isinstance(uuid, str) else uuid for uuid in uuids]
# Decide which services to discover for
services = [service] if service else self.services
@@ -396,7 +476,7 @@ class Client:
ATT_Read_By_Type_Request(
starting_handle=starting_handle,
ending_handle=ending_handle,
attribute_type = GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
attribute_type=GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
)
)
if response is None:
@@ -407,7 +487,10 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(f'!!! unexpected error while discovering characteristics: {HCI_Constant.error_name(response.error_code)}')
logger.warning(
'!!! unexpected error while discovering characteristics: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return
break
@@ -425,7 +508,9 @@ class Client:
properties, handle = struct.unpack_from('<BH', attribute_value)
characteristic_uuid = UUID.from_bytes(attribute_value[3:])
characteristic = CharacteristicProxy(self, handle, 0, characteristic_uuid, properties)
characteristic = CharacteristicProxy(
self, handle, 0, characteristic_uuid, properties
)
# Set the previous characteristic's end handle
if characteristics:
@@ -441,13 +526,17 @@ class Client:
characteristics[-1].end_group_handle = service.end_group_handle
# Set the service's characteristics
characteristics = [c for c in characteristics if not uuids or c.uuid in uuids]
characteristics = [
c for c in characteristics if not uuids or c.uuid in uuids
]
service.characteristics = characteristics
discovered_characteristics.extend(characteristics)
return discovered_characteristics
async def discover_descriptors(self, characteristic = None, start_handle = None, end_handle = None):
async def discover_descriptors(
self, characteristic=None, start_handle=None, end_handle=None
):
'''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
'''
@@ -464,8 +553,7 @@ class Client:
while starting_handle <= ending_handle:
response = await self.send_request(
ATT_Find_Information_Request(
starting_handle = starting_handle,
ending_handle = ending_handle
starting_handle=starting_handle, ending_handle=ending_handle
)
)
if response is None:
@@ -476,7 +564,10 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(f'!!! unexpected error while discovering descriptors: {HCI_Constant.error_name(response.error_code)}')
logger.warning(
'!!! unexpected error while discovering descriptors: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return []
break
@@ -492,7 +583,9 @@ class Client:
logger.warning(f'bogus handle value: {attribute_handle}')
return []
descriptor = DescriptorProxy(self, attribute_handle, UUID.from_bytes(attribute_uuid))
descriptor = DescriptorProxy(
self, attribute_handle, UUID.from_bytes(attribute_uuid)
)
descriptors.append(descriptor)
# TODO: read descriptor value
@@ -515,8 +608,7 @@ class Client:
while True:
response = await self.send_request(
ATT_Find_Information_Request(
starting_handle = starting_handle,
ending_handle = ending_handle
starting_handle=starting_handle, ending_handle=ending_handle
)
)
if response is None:
@@ -526,7 +618,10 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(f'!!! unexpected error while discovering attributes: {HCI_Constant.error_name(response.error_code)}')
logger.warning(
'!!! unexpected error while discovering attributes: '
f'{HCI_Constant.error_name(response.error_code)}'
)
return []
break
@@ -536,7 +631,9 @@ class Client:
logger.warning(f'bogus handle value: {attribute_handle}')
return []
attribute = AttributeProxy(self, attribute_handle, 0, UUID.from_bytes(attribute_uuid))
attribute = AttributeProxy(
self, attribute_handle, 0, UUID.from_bytes(attribute_uuid)
)
attributes.append(attribute)
# Move on to the next attributes
@@ -545,12 +642,15 @@ class Client:
return attributes
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
# If we haven't already discovered the descriptors for this characteristic, do it now
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
await self.discover_descriptors(characteristic)
# Look for the CCCD descriptor
cccd = characteristic.get_descriptor(GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR)
cccd = characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
if not cccd:
logger.warning('subscribing to characteristic with no CCCD descriptor')
return
@@ -578,33 +678,40 @@ class Client:
subscriber_set = subscribers.setdefault(characteristic.handle, set())
if subscriber is not None:
subscriber_set.add(subscriber)
# Add the characteristic as a subscriber, which will result in the characteristic
# emitting an 'update' event when a notification or indication is received
# Add the characteristic as a subscriber, which will result in the
# characteristic emitting an 'update' event when a notification or indication
# is received
subscriber_set.add(characteristic)
await self.write_value(cccd, struct.pack('<H', bits), with_response=True)
async def unsubscribe(self, characteristic, subscriber=None):
# If we haven't already discovered the descriptors for this characteristic, do it now
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
await self.discover_descriptors(characteristic)
# Look for the CCCD descriptor
cccd = characteristic.get_descriptor(GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR)
cccd = characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
if not cccd:
logger.warning('unsubscribing from characteristic with no CCCD descriptor')
return
if subscriber is not None:
# Remove matching subscriber from subscriber sets
for subscriber_set in (self.notification_subscribers, self.indication_subscribers):
for subscriber_set in (
self.notification_subscribers,
self.indication_subscribers,
):
subscribers = subscriber_set.get(characteristic.handle, [])
if subscriber in subscribers:
subscribers.remove(subscriber)
# Cleanup if we removed the last one
if not subscribers:
subscriber_set.remove(characteristic.handle)
del subscriber_set[characteristic.handle]
else:
# Remove all subscribers for this attribute from the sets!
self.notification_subscribers.pop(characteristic.handle, None)
@@ -622,8 +729,10 @@ class Client:
'''
# Send a request to read
attribute_handle = attribute if type(attribute) is int else attribute.handle
response = await self.send_request(ATT_Read_Request(attribute_handle = attribute_handle))
attribute_handle = attribute if isinstance(attribute, int) else attribute.handle
response = await self.send_request(
ATT_Read_Request(attribute_handle=attribute_handle)
)
if response is None:
raise TimeoutError('read timeout')
if response.op_code == ATT_ERROR_RESPONSE:
@@ -631,7 +740,7 @@ class Client:
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response
response,
)
# If the value is the max size for the MTU, try to read more unless the caller
@@ -642,18 +751,23 @@ class Client:
offset = len(attribute_value)
while True:
response = await self.send_request(
ATT_Read_Blob_Request(attribute_handle = attribute_handle, value_offset = offset)
ATT_Read_Blob_Request(
attribute_handle=attribute_handle, value_offset=offset
)
)
if response is None:
raise TimeoutError('read timeout')
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code == ATT_ATTRIBUTE_NOT_LONG_ERROR or response.error_code == ATT_INVALID_OFFSET_ERROR:
if response.error_code in (
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_INVALID_OFFSET_ERROR,
):
break
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code),
response
response,
)
part = response.part_attribute_value
@@ -685,7 +799,7 @@ class Client:
ATT_Read_By_Type_Request(
starting_handle=starting_handle,
ending_handle=ending_handle,
attribute_type = uuid
attribute_type=uuid,
)
)
if response is None:
@@ -696,7 +810,10 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.warning(f'!!! unexpected error while reading characteristics: {HCI_Constant.error_name(response.error_code)}')
logger.warning(
'!!! unexpected error while reading characteristics: '
f'{HCI_Constant.error_name(response.error_code)}'
)
# TODO raise appropriate exception
return []
break
@@ -721,47 +838,54 @@ class Client:
async def write_value(self, attribute, value, with_response=False):
'''
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic Value
See Vol 3, Part G - 4.9.1 Write Without Response & 4.9.3 Write Characteristic
Value
`attribute` can be an Attribute object, or a handle value
'''
# Send a request or command to write
attribute_handle = attribute if type(attribute) is int else attribute.handle
attribute_handle = attribute if isinstance(attribute, int) else attribute.handle
if with_response:
response = await self.send_request(
ATT_Write_Request(
attribute_handle = attribute_handle,
attribute_value = value
attribute_handle=attribute_handle, attribute_value=value
)
)
if response.op_code == ATT_ERROR_RESPONSE:
raise ProtocolError(
response.error_code,
'att',
ATT_PDU.error_name(response.error_code), response
ATT_PDU.error_name(response.error_code),
response,
)
else:
await self.send_command(
ATT_Write_Command(
attribute_handle = attribute_handle,
attribute_value = value
attribute_handle=attribute_handle, attribute_value=value
)
)
def on_gatt_pdu(self, att_pdu):
logger.debug(f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}')
logger.debug(
f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
if att_pdu.op_code in ATT_RESPONSES:
if self.pending_request is None:
# Not expected!
logger.warning('!!! unexpected response, there is no pending request')
return
# Sanity check: the response should match the pending request unless it is an error response
# Sanity check: the response should match the pending request unless it is
# an error response
if att_pdu.op_code != ATT_ERROR_RESPONSE:
expected_response_name = self.pending_request.name.replace('_REQUEST', '_RESPONSE')
expected_response_name = self.pending_request.name.replace(
'_REQUEST', '_RESPONSE'
)
if att_pdu.name != expected_response_name:
logger.warning(f'!!! mismatched response: expected {expected_response_name}')
logger.warning(
f'!!! mismatched response: expected {expected_response_name}'
)
return
# Return the response to the coroutine that is waiting for it
@@ -772,11 +896,20 @@ class Client:
if handler is not None:
handler(att_pdu)
else:
logger.warning(f'{color(f"--- Ignoring GATT Response from [0x{self.connection.handle:04X}]:", "red")} {att_pdu}')
logger.warning(
color(
'--- Ignoring GATT Response from '
f'[0x{self.connection.handle:04X}]: ',
'red',
)
+ str(att_pdu)
)
def on_att_handle_value_notification(self, notification):
# Call all subscribers
subscribers = self.notification_subscribers.get(notification.attribute_handle, [])
subscribers = self.notification_subscribers.get(
notification.attribute_handle, []
)
if not subscribers:
logger.warning('!!! received notification with no subscriber')
for subscriber in subscribers:

View File

@@ -26,13 +26,53 @@
import asyncio
import logging
from collections import defaultdict
import struct
from typing import Tuple, Optional
from pyee import EventEmitter
from colors import color
from .core import *
from .hci import *
from .att import *
from .gatt import *
from .core import UUID
from .att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
ATT_DEFAULT_MTU,
ATT_INVALID_ATTRIBUTE_LENGTH_ERROR,
ATT_INVALID_HANDLE_ERROR,
ATT_INVALID_OFFSET_ERROR,
ATT_REQUEST_NOT_SUPPORTED_ERROR,
ATT_REQUESTS,
ATT_UNLIKELY_ERROR_ERROR,
ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
ATT_Error,
ATT_Error_Response,
ATT_Exchange_MTU_Response,
ATT_Find_By_Type_Value_Response,
ATT_Find_Information_Response,
ATT_Handle_Value_Indication,
ATT_Handle_Value_Notification,
ATT_Read_Blob_Response,
ATT_Read_By_Group_Type_Response,
ATT_Read_By_Type_Response,
ATT_Read_Response,
ATT_Write_Response,
Attribute,
)
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_INCLUDE_ATTRIBUTE_TYPE,
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
Characteristic,
CharacteristicDeclaration,
CharacteristicValue,
Descriptor,
Service,
)
# -----------------------------------------------------------------------------
# Logging
@@ -55,17 +95,32 @@ class Server(EventEmitter):
self.device = device
self.attributes = [] # Attributes, ordered by increasing handle values
self.attributes_by_handle = {} # Map for fast attribute access by handle
self.max_mtu = GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
self.subscribers = {} # Map of subscriber states by connection handle and attribute handle
self.max_mtu = (
GATT_SERVER_DEFAULT_MAX_MTU # The max MTU we're willing to negotiate
)
self.subscribers = (
{}
) # Map of subscriber states by connection handle and attribute handle
self.indication_semaphores = defaultdict(lambda: asyncio.Semaphore(1))
self.pending_confirmations = defaultdict(lambda: None)
def __str__(self):
return "\n".join(map(str, self.attributes))
def send_gatt_pdu(self, connection_handle, pdu):
self.device.send_l2cap_pdu(connection_handle, ATT_CID, pdu)
def next_handle(self):
return 1 + len(self.attributes)
def get_advertising_service_data(self):
return {
attribute: data
for attribute in self.attributes
if isinstance(attribute, Service)
and (data := attribute.get_advertising_data())
}
def get_attribute(self, handle):
attribute = self.attributes_by_handle.get(handle)
if attribute:
@@ -79,15 +134,74 @@ class Server(EventEmitter):
return attribute
return None
def get_service_attribute(self, service_uuid: UUID) -> Optional[Service]:
return next(
(
attribute
for attribute in self.attributes
if attribute.type == GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
and attribute.uuid == service_uuid
),
None,
)
def get_characteristic_attributes(
self, service_uuid: UUID, characteristic_uuid: UUID
) -> Optional[Tuple[CharacteristicDeclaration, Characteristic]]:
service_handle = self.get_service_attribute(service_uuid)
if not service_handle:
return None
return next(
(
(attribute, self.get_attribute(attribute.characteristic.handle))
for attribute in map(
self.get_attribute,
range(service_handle.handle, service_handle.end_group_handle + 1),
)
if attribute.type == GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
and attribute.characteristic.uuid == characteristic_uuid
),
None,
)
def get_descriptor_attribute(
self, service_uuid: UUID, characteristic_uuid: UUID, descriptor_uuid: UUID
) -> Optional[Descriptor]:
characteristics = self.get_characteristic_attributes(
service_uuid, characteristic_uuid
)
if not characteristics:
return None
(_, characteristic_value) = characteristics
return next(
(
attribute
for attribute in map(
self.get_attribute,
range(
characteristic_value.handle + 1,
characteristic_value.end_group_handle + 1,
),
)
if attribute.type == descriptor_uuid
),
None,
)
def add_attribute(self, attribute):
# Assign a handle to this attribute
attribute.handle = self.next_handle()
attribute.end_group_handle = attribute.handle # TODO: keep track of descriptors in the group
attribute.end_group_handle = (
attribute.handle
) # TODO: keep track of descriptors in the group
# Add this attribute to the list
self.attributes.append(attribute)
def add_service(self, service):
def add_service(self, service: Service):
# Add the service attribute to the DB
self.add_attribute(service)
@@ -95,16 +209,9 @@ class Server(EventEmitter):
# Add all characteristics
for characteristic in service.characteristics:
# Add a Characteristic Declaration (Vol 3, Part G - 3.3.1 Characteristic Declaration)
declaration_bytes = struct.pack(
'<BH',
characteristic.properties,
self.next_handle() + 1, # The value will be the next attribute after this declaration
) + characteristic.uuid.to_pdu_bytes()
characteristic_declaration = Attribute(
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
Attribute.READABLE,
declaration_bytes
# Add a Characteristic Declaration
characteristic_declaration = CharacteristicDeclaration(
characteristic, self.next_handle() + 1
)
self.add_attribute(characteristic_declaration)
@@ -118,17 +225,26 @@ class Server(EventEmitter):
# If the characteristic supports subscriptions, add a CCCD descriptor
# unless there is one already
if (
characteristic.properties & (Characteristic.NOTIFY | Characteristic.INDICATE) and
characteristic.get_descriptor(GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR) is None
characteristic.properties
& (Characteristic.NOTIFY | Characteristic.INDICATE)
and characteristic.get_descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR
)
is None
):
self.add_attribute(
# pylint: disable=line-too-long
Descriptor(
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
Attribute.READABLE | Attribute.WRITEABLE,
CharacteristicValue(
read=lambda connection, characteristic=characteristic: self.read_cccd(connection, characteristic),
write=lambda connection, value, characteristic=characteristic: self.write_cccd(connection, characteristic, value)
)
read=lambda connection, characteristic=characteristic: self.read_cccd(
connection, characteristic
),
write=lambda connection, value, characteristic=characteristic: self.write_cccd(
connection, characteristic, value
),
),
)
)
@@ -155,23 +271,36 @@ class Server(EventEmitter):
return cccd or bytes([0, 0])
def write_cccd(self, connection, characteristic, value):
logger.debug(f'Subscription update for connection=0x{connection.handle:04X}, handle=0x{characteristic.handle:04X}: {value.hex()}')
logger.debug(
f'Subscription update for connection=0x{connection.handle:04X}, '
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
)
# Sanity check
if len(value) != 2:
logger.warn('CCCD value not 2 bytes long')
logger.warning('CCCD value not 2 bytes long')
return
cccds = self.subscribers.setdefault(connection.handle, {})
cccds[characteristic.handle] = value
logger.debug(f'CCCDs: {cccds}')
notify_enabled = (value[0] & 0x01 != 0)
indicate_enabled = (value[0] & 0x02 != 0)
characteristic.emit('subscription', connection, notify_enabled, indicate_enabled)
self.emit('characteristic_subscription', connection, characteristic, notify_enabled, indicate_enabled)
notify_enabled = value[0] & 0x01 != 0
indicate_enabled = value[0] & 0x02 != 0
characteristic.emit(
'subscription', connection, notify_enabled, indicate_enabled
)
self.emit(
'characteristic_subscription',
connection,
characteristic,
notify_enabled,
indicate_enabled,
)
def send_response(self, connection, response):
logger.debug(f'GATT Response from server: [0x{connection.handle:04X}] {response}')
logger.debug(
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
)
self.send_gatt_pdu(connection.handle, response.to_bytes())
async def notify_subscriber(self, connection, attribute, value=None, force=False):
@@ -183,14 +312,20 @@ class Server(EventEmitter):
return
cccd = subscribers.get(attribute.handle)
if not cccd:
logger.debug(f'not notifying, no subscribers for handle {attribute.handle:04X}')
logger.debug(
f'not notifying, no subscribers for handle {attribute.handle:04X}'
)
return
if len(cccd) != 2 or (cccd[0] & 0x01 == 0):
logger.debug(f'not notifying, cccd={cccd.hex()}')
return
# Get or encode the value
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
value = (
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > connection.att_mtu - 3:
@@ -198,10 +333,11 @@ class Server(EventEmitter):
# Notify
notification = ATT_Handle_Value_Notification(
attribute_handle = attribute.handle,
attribute_value = value
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(
f'GATT Notify from server: [0x{connection.handle:04X}] {notification}'
)
logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}')
self.send_gatt_pdu(connection.handle, bytes(notification))
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
@@ -213,14 +349,20 @@ class Server(EventEmitter):
return
cccd = subscribers.get(attribute.handle)
if not cccd:
logger.debug(f'not indicating, no subscribers for handle {attribute.handle:04X}')
logger.debug(
f'not indicating, no subscribers for handle {attribute.handle:04X}'
)
return
if len(cccd) != 2 or (cccd[0] & 0x02 == 0):
logger.debug(f'not indicating, cccd={cccd.hex()}')
return
# Get or encode the value
value = attribute.read_value(connection) if value is None else attribute.encode_value(value)
value = (
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
# Truncate if needed
if len(value) > connection.att_mtu - 3:
@@ -228,31 +370,39 @@ class Server(EventEmitter):
# Indicate
indication = ATT_Handle_Value_Indication(
attribute_handle = attribute.handle,
attribute_value = value
attribute_handle=attribute.handle, attribute_value=value
)
logger.debug(
f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}'
)
logger.debug(f'GATT Indicate from server: [0x{connection.handle:04X}] {indication}')
# Wait until we can send (only one pending indication at a time per connection)
async with self.indication_semaphores[connection.handle]:
assert(self.pending_confirmations[connection.handle] is None)
assert self.pending_confirmations[connection.handle] is None
# Create a future value to hold the eventual response
self.pending_confirmations[connection.handle] = asyncio.get_running_loop().create_future()
self.pending_confirmations[
connection.handle
] = asyncio.get_running_loop().create_future()
try:
self.send_gatt_pdu(connection.handle, indication.to_bytes())
await asyncio.wait_for(self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT)
except asyncio.TimeoutError:
await asyncio.wait_for(
self.pending_confirmations[connection.handle], GATT_REQUEST_TIMEOUT
)
except asyncio.TimeoutError as error:
logger.warning(color('!!! GATT Indicate timeout', 'red'))
raise TimeoutError(f'GATT timeout for {indication.name}')
raise TimeoutError(f'GATT timeout for {indication.name}') from error
finally:
self.pending_confirmations[connection.handle] = None
async def notify_or_indicate_subscribers(self, indicate, attribute, value=None, force=False):
async def notify_or_indicate_subscribers(
self, indicate, attribute, value=None, force=False
):
# Get all the connections for which there's at least one subscription
connections = [
connection for connection in [
connection
for connection in [
self.device.lookup_connection(connection_handle)
for (connection_handle, subscribers) in self.subscribers.items()
if force or subscribers.get(attribute.handle)
@@ -263,10 +413,12 @@ class Server(EventEmitter):
# Indicate or notify for each connection
if connections:
coroutine = self.indicate_subscriber if indicate else self.notify_subscriber
await asyncio.wait([
await asyncio.wait(
[
asyncio.create_task(coroutine(connection, attribute, value, force))
for connection in connections
])
]
)
async def notify_subscribers(self, attribute, value=None, force=False):
return await self.notify_or_indicate_subscribers(False, attribute, value, force)
@@ -294,7 +446,7 @@ class Server(EventEmitter):
response = ATT_Error_Response(
request_opcode_in_error=att_pdu.op_code,
attribute_handle_in_error=error.att_handle,
error_code = error.error_code
error_code=error.error_code,
)
self.send_response(connection, response)
except Exception as error:
@@ -302,7 +454,7 @@ class Server(EventEmitter):
response = ATT_Error_Response(
request_opcode_in_error=att_pdu.op_code,
attribute_handle_in_error=0x0000,
error_code = ATT_UNLIKELY_ERROR_ERROR
error_code=ATT_UNLIKELY_ERROR_ERROR,
)
self.send_response(connection, response)
raise error
@@ -313,7 +465,13 @@ class Server(EventEmitter):
self.on_att_request(connection, att_pdu)
else:
# Just ignore
logger.warning(f'{color("--- Ignoring GATT Request from [0x{connection.handle:04X}]:", "red")} {att_pdu}')
logger.warning(
color(
f'--- Ignoring GATT Request from [0x{connection.handle:04X}]: ',
'red',
)
+ str(att_pdu)
)
#######################################################
# ATT handlers
@@ -322,11 +480,16 @@ class Server(EventEmitter):
'''
Handler for requests without a more specific handler
'''
logger.warning(f'{color(f"--- Unsupported ATT Request from [0x{connection.handle:04X}]:", "red")} {pdu}')
logger.warning(
color(
f'--- Unsupported ATT Request from [0x{connection.handle:04X}]: ', 'red'
)
+ str(pdu)
)
response = ATT_Error_Response(
request_opcode_in_error=pdu.op_code,
attribute_handle_in_error=0x0000,
error_code = ATT_REQUEST_NOT_SUPPORTED_ERROR
error_code=ATT_REQUEST_NOT_SUPPORTED_ERROR,
)
self.send_response(connection, response)
@@ -334,7 +497,9 @@ class Server(EventEmitter):
'''
See Bluetooth spec Vol 3, Part F - 3.4.2.1 Exchange MTU Request
'''
self.send_response(connection, ATT_Exchange_MTU_Response(server_rx_mtu = self.max_mtu))
self.send_response(
connection, ATT_Exchange_MTU_Response(server_rx_mtu=self.max_mtu)
)
# Compute the final MTU
if request.client_rx_mtu >= ATT_DEFAULT_MTU:
@@ -351,12 +516,18 @@ class Server(EventEmitter):
'''
# Check the request parameters
if request.starting_handle == 0 or request.starting_handle > request.ending_handle:
self.send_response(connection, ATT_Error_Response(
if (
request.starting_handle == 0
or request.starting_handle > request.ending_handle
):
self.send_response(
connection,
ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code = ATT_INVALID_HANDLE_ERROR
))
error_code=ATT_INVALID_HANDLE_ERROR,
),
)
return
# Build list of returned attributes
@@ -364,9 +535,10 @@ class Server(EventEmitter):
attributes = []
uuid_size = 0
for attribute in (
attribute for attribute in self.attributes if
attribute.handle >= request.starting_handle and
attribute.handle <= request.ending_handle
attribute
for attribute in self.attributes
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
):
# TODO: check permissions
@@ -394,13 +566,13 @@ class Server(EventEmitter):
]
response = ATT_Find_Information_Response(
format=1 if len(attributes[0].type.to_pdu_bytes()) == 2 else 2,
information_data = b''.join(information_data_list)
information_data=b''.join(information_data_list),
)
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
@@ -414,12 +586,13 @@ class Server(EventEmitter):
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
attribute.handle >= request.starting_handle and
attribute.handle <= request.ending_handle and
attribute.type == request.attribute_type and
attribute.read_value(connection) == request.attribute_value and
pdu_space_available >= 4
attribute
for attribute in self.attributes
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type
and attribute.read_value(connection) == request.attribute_value
and pdu_space_available >= 4
):
# TODO: check permissions
@@ -431,17 +604,19 @@ class Server(EventEmitter):
if attributes:
handles_information_list = []
for attribute in attributes:
if attribute.type in {
if attribute.type in (
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE
}:
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
):
# Part of a group
group_end_handle = attribute.end_group_handle
else:
# Not part of a group
group_end_handle = attribute.handle
handles_information_list.append(struct.pack('<HH', attribute.handle, group_end_handle))
handles_information_list.append(
struct.pack('<HH', attribute.handle, group_end_handle)
)
response = ATT_Find_By_Type_Value_Response(
handles_information_list=b''.join(handles_information_list)
)
@@ -449,7 +624,7 @@ class Server(EventEmitter):
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
@@ -462,11 +637,12 @@ class Server(EventEmitter):
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
attribute.type == request.attribute_type and
attribute.handle >= request.starting_handle and
attribute.handle <= request.ending_handle and
pdu_space_available
attribute
for attribute in self.attributes
if attribute.type == request.attribute_type
and attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and pdu_space_available
):
# TODO: check permissions
@@ -490,16 +666,17 @@ class Server(EventEmitter):
pdu_space_available -= entry_size
if attributes:
attribute_data_list = [struct.pack('<H', handle) + value for handle, value in attributes]
attribute_data_list = [
struct.pack('<H', handle) + value for handle, value in attributes
]
response = ATT_Read_By_Type_Response(
length = entry_size,
attribute_data_list = b''.join(attribute_data_list)
length=entry_size, attribute_data_list=b''.join(attribute_data_list)
)
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
@@ -513,14 +690,12 @@ class Server(EventEmitter):
# TODO: check permissions
value = attribute.read_value(connection)
value_size = min(connection.att_mtu - 1, len(value))
response = ATT_Read_Response(
attribute_value = value[:value_size]
)
response = ATT_Read_Response(attribute_value=value[:value_size])
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code = ATT_INVALID_HANDLE_ERROR
error_code=ATT_INVALID_HANDLE_ERROR,
)
self.send_response(connection, response)
@@ -536,24 +711,28 @@ class Server(EventEmitter):
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code = ATT_INVALID_OFFSET_ERROR
error_code=ATT_INVALID_OFFSET_ERROR,
)
elif len(value) <= connection.att_mtu - 1:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code = ATT_ATTRIBUTE_NOT_LONG_ERROR
error_code=ATT_ATTRIBUTE_NOT_LONG_ERROR,
)
else:
part_size = min(connection.att_mtu - 1, len(value) - request.value_offset)
part_size = min(
connection.att_mtu - 1, len(value) - request.value_offset
)
response = ATT_Read_Blob_Response(
part_attribute_value = value[request.value_offset:request.value_offset + part_size]
part_attribute_value=value[
request.value_offset : request.value_offset + part_size
]
)
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code = ATT_INVALID_HANDLE_ERROR
error_code=ATT_INVALID_HANDLE_ERROR,
)
self.send_response(connection, response)
@@ -561,15 +740,15 @@ class Server(EventEmitter):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
'''
if request.attribute_group_type not in {
if request.attribute_group_type not in (
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_INCLUDE_ATTRIBUTE_TYPE
}:
GATT_INCLUDE_ATTRIBUTE_TYPE,
):
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code = ATT_UNSUPPORTED_GROUP_TYPE_ERROR
error_code=ATT_UNSUPPORTED_GROUP_TYPE_ERROR,
)
self.send_response(connection, response)
return
@@ -577,11 +756,12 @@ class Server(EventEmitter):
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
attribute for attribute in self.attributes if
attribute.type == request.attribute_group_type and
attribute.handle >= request.starting_handle and
attribute.handle <= request.ending_handle and
pdu_space_available
attribute
for attribute in self.attributes
if attribute.type == request.attribute_group_type
and attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and pdu_space_available
):
# Check the attribute value size
attribute_value = attribute.read_value(connection)
@@ -599,7 +779,9 @@ class Server(EventEmitter):
break
# Add the attribute to the list
attributes.append((attribute.handle, attribute.end_group_handle, attribute_value))
attributes.append(
(attribute.handle, attribute.end_group_handle, attribute_value)
)
pdu_space_available -= entry_size
if attributes:
@@ -609,13 +791,13 @@ class Server(EventEmitter):
]
response = ATT_Read_By_Group_Type_Response(
length=len(attribute_data_list[0]),
attribute_data_list = b''.join(attribute_data_list)
attribute_data_list=b''.join(attribute_data_list),
)
else:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.starting_handle,
error_code = ATT_ATTRIBUTE_NOT_FOUND_ERROR
error_code=ATT_ATTRIBUTE_NOT_FOUND_ERROR,
)
self.send_response(connection, response)
@@ -628,22 +810,28 @@ class Server(EventEmitter):
# Check that the attribute exists
attribute = self.get_attribute(request.attribute_handle)
if attribute is None:
self.send_response(connection, ATT_Error_Response(
self.send_response(
connection,
ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code = ATT_INVALID_HANDLE_ERROR
))
error_code=ATT_INVALID_HANDLE_ERROR,
),
)
return
# TODO: check permissions
# Check the request parameters
if len(request.attribute_value) > GATT_MAX_ATTRIBUTE_VALUE_SIZE:
self.send_response(connection, ATT_Error_Response(
self.send_response(
connection,
ATT_Error_Response(
request_opcode_in_error=request.op_code,
attribute_handle_in_error=request.attribute_handle,
error_code = ATT_INVALID_ATTRIBUTE_LENGTH_ERROR
))
error_code=ATT_INVALID_ATTRIBUTE_LENGTH_ERROR,
),
)
return
# Accept the value
@@ -674,13 +862,15 @@ class Server(EventEmitter):
except Exception as error:
logger.warning(f'!!! ignoring exception: {error}')
def on_att_handle_value_confirmation(self, connection, confirmation):
def on_att_handle_value_confirmation(self, connection, _confirmation):
'''
See Bluetooth spec Vol 3, Part F - 3.4.7.3 Handle Value Confirmation
'''
if self.pending_confirmations[connection.handle] is None:
# Not expected!
logger.warning('!!! unexpected confirmation, there is no pending indication')
logger.warning(
'!!! unexpected confirmation, there is no pending indication'
)
return
self.pending_confirmations[connection.handle].set_result(None)

File diff suppressed because it is too large Load Diff

View File

@@ -18,10 +18,9 @@
import logging
from colors import color
from bumble.smp import SMP_CID, SMP_Command
from .att import ATT_CID, ATT_PDU
from .smp import SMP_CID, SMP_Command
from .core import name_or_number
from .gatt import ATT_PDU, ATT_CID
from .l2cap import (
L2CAP_PDU,
L2CAP_CONNECTION_REQUEST,
@@ -29,20 +28,17 @@ from .l2cap import (
L2CAP_SIGNALING_CID,
L2CAP_LE_SIGNALING_CID,
L2CAP_Control_Frame,
L2CAP_Connection_Response
L2CAP_Connection_Response,
)
from .hci import (
HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT,
HCI_AclDataPacketAssembler
HCI_AclDataPacketAssembler,
)
from .rfcomm import RFCOMM_Frame, RFCOMM_PSM
from .sdp import SDP_PDU, SDP_PSM
from .avdtp import (
MessageAssembler as AVDTP_MessageAssembler,
AVDTP_PSM
)
from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM
# -----------------------------------------------------------------------------
# Logging
@@ -69,6 +65,7 @@ class PacketTracer:
self.psms = {} # PSM, by source_cid
self.peer = None # ACL stream in the other direction
# pylint: disable=too-many-nested-blocks
def on_acl_pdu(self, pdu):
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
@@ -78,7 +75,7 @@ class PacketTracer:
elif l2cap_pdu.cid == SMP_CID:
smp_command = SMP_Command.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(smp_command)
elif l2cap_pdu.cid == L2CAP_SIGNALING_CID or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID:
elif l2cap_pdu.cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(control_frame)
@@ -86,16 +83,26 @@ class PacketTracer:
if control_frame.code == L2CAP_CONNECTION_REQUEST:
self.psms[control_frame.source_cid] = control_frame.psm
elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
if control_frame.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
if (
control_frame.result
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
):
if self.peer:
if psm := self.peer.psms.get(control_frame.source_cid):
# Found a pending connection
self.psms[control_frame.destination_cid] = psm
# For AVDTP connections, create a packet assembler for each direction
# For AVDTP connections, create a packet assembler for
# each direction
if psm == AVDTP_PSM:
self.avdtp_assemblers[control_frame.source_cid] = AVDTP_MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[control_frame.destination_cid] = AVDTP_MessageAssembler(self.peer.on_avdtp_message)
self.avdtp_assemblers[
control_frame.source_cid
] = AVDTP_MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
control_frame.destination_cid
] = AVDTP_MessageAssembler(
self.peer.on_avdtp_message
)
else:
# Try to find the PSM associated with this PDU
@@ -107,18 +114,26 @@ class PacketTracer:
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(rfcomm_frame)
elif psm == AVDTP_PSM:
self.analyzer.emit(f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM=AVDTP]: {l2cap_pdu.payload.hex()}')
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
)
assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
if assembler:
assembler.on_pdu(l2cap_pdu.payload)
else:
psm_string = name_or_number(PSM_NAMES, psm)
self.analyzer.emit(f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM={psm_string}]: {l2cap_pdu.payload.hex()}')
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM={psm_string}]: {l2cap_pdu.payload.hex()}'
)
else:
self.analyzer.emit(l2cap_pdu)
def on_avdtp_message(self, transaction_label, message):
self.analyzer.emit(f'{color("AVDTP", "green")} [{transaction_label}] {message}')
self.analyzer.emit(
f'{color("AVDTP", "green")} [{transaction_label}] {message}'
)
def feed_packet(self, packet):
self.packet_assembler.feed_packet(packet)
@@ -131,7 +146,10 @@ class PacketTracer:
self.peer = None # Analyzer in the other direction
def start_acl_stream(self, connection_handle):
logger.info(f'[{self.label}] +++ Creating ACL stream for connection 0x{connection_handle:04X}')
logger.info(
f'[{self.label}] +++ Creating ACL stream for connection '
f'0x{connection_handle:04X}'
)
stream = PacketTracer.AclStream(self)
self.acl_streams[connection_handle] = stream
@@ -144,7 +162,10 @@ class PacketTracer:
def end_acl_stream(self, connection_handle):
if connection_handle in self.acl_streams:
logger.info(f'[{self.label}] --- Removing ACL stream for connection 0x{connection_handle:04X}')
logger.info(
f'[{self.label}] --- Removing ACL stream for connection '
f'0x{connection_handle:04X}'
)
del self.acl_streams[connection_handle]
# Let the other forwarder know so it can cleanup its stream as well
@@ -176,9 +197,13 @@ class PacketTracer:
self,
host_to_controller_label=color('HOST->CONTROLLER', 'blue'),
controller_to_host_label=color('CONTROLLER->HOST', 'cyan'),
emit_message=logger.info
emit_message=logger.info,
):
self.host_to_controller_analyzer = PacketTracer.Analyzer(host_to_controller_label, emit_message)
self.controller_to_host_analyzer = PacketTracer.Analyzer(controller_to_host_label, emit_message)
self.host_to_controller_analyzer = PacketTracer.Analyzer(
host_to_controller_label, emit_message
)
self.controller_to_host_analyzer = PacketTracer.Analyzer(
controller_to_host_label, emit_message
)
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer

View File

@@ -43,7 +43,7 @@ class HfpProtocol:
def feed(self, data):
# Convert the data to a string if needed
if type(data) == bytes:
if isinstance(data, bytes):
data = data.decode('utf-8')
logger.debug(f'<<< Data received: {data}')
@@ -79,16 +79,16 @@ class HfpProtocol:
async def initialize_service(self):
# Perform Service Level Connection Initialization
self.send_command_line('AT+BRSF=2072') # Retrieve Supported Features
line = await(self.next_line())
line = await(self.next_line())
await (self.next_line())
await (self.next_line())
self.send_command_line('AT+CIND=?')
line = await(self.next_line())
line = await(self.next_line())
await (self.next_line())
await (self.next_line())
self.send_command_line('AT+CIND?')
line = await(self.next_line())
line = await(self.next_line())
await (self.next_line())
await (self.next_line())
self.send_command_line('AT+CMER=3,0,0,1')
line = await(self.next_line())
await (self.next_line())

View File

@@ -16,16 +16,60 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import collections
import logging
from pyee import EventEmitter
import struct
from colors import color
from .hci import *
from .l2cap import *
from .att import *
from .gatt import *
from .smp import *
from .core import ConnectionParameters
from bumble.l2cap import L2CAP_PDU
from .hci import (
HCI_ACL_DATA_PACKET,
HCI_COMMAND_COMPLETE_EVENT,
HCI_COMMAND_PACKET,
HCI_EVENT_PACKET,
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND,
HCI_RESET_COMMAND,
HCI_SUCCESS,
HCI_SUPPORTED_COMMANDS_FLAGS,
HCI_VERSION_BLUETOOTH_CORE_4_0,
HCI_AclDataPacket,
HCI_AclDataPacketAssembler,
HCI_Constant,
HCI_Error,
HCI_LE_Long_Term_Key_Request_Negative_Reply_Command,
HCI_LE_Long_Term_Key_Request_Reply_Command,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_Read_Local_Supported_Features_Command,
HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_LE_Remote_Connection_Parameter_Request_Reply_Command,
HCI_LE_Set_Event_Mask_Command,
HCI_LE_Write_Suggested_Default_Data_Length_Command,
HCI_Link_Key_Request_Negative_Reply_Command,
HCI_Link_Key_Request_Reply_Command,
HCI_PIN_Code_Request_Negative_Reply_Command,
HCI_Packet,
HCI_Read_Buffer_Size_Command,
HCI_Read_Local_Supported_Commands_Command,
HCI_Read_Local_Version_Information_Command,
HCI_Reset_Command,
HCI_Set_Event_Mask_Command,
)
from .core import (
BT_BR_EDR_TRANSPORT,
BT_CENTRAL_ROLE,
BT_LE_TRANSPORT,
ConnectionPHY,
ConnectionParameters,
)
from .utils import AbortableEventEmitter
# -----------------------------------------------------------------------------
# Logging
@@ -36,11 +80,15 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1
HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
# fmt: on
# -----------------------------------------------------------------------------
class Connection:
@@ -61,12 +109,13 @@ class Connection:
# -----------------------------------------------------------------------------
class Host(EventEmitter):
class Host(AbortableEventEmitter):
def __init__(self, controller_source=None, controller_sink=None):
super().__init__()
self.hci_sink = None
self.ready = False # True when we can accept incoming packets
self.reset_done = False
self.connections = {} # Connections, by connection handle
self.pending_command = None
self.pending_response = None
@@ -92,73 +141,122 @@ class Host(EventEmitter):
if controller_sink:
self.set_packet_sink(controller_sink)
async def flush(self):
# Make sure no command is pending
await self.command_semaphore.acquire()
# Flush current host state, then release command semaphore
self.emit('flush')
self.command_semaphore.release()
async def reset(self):
if self.ready:
self.ready = False
await self.flush()
await self.send_command(HCI_Reset_Command(), check_result=True)
self.ready = True
response = await self.send_command(HCI_Read_Local_Supported_Commands_Command(), check_result=True)
response = await self.send_command(
HCI_Read_Local_Supported_Commands_Command(), check_result=True
)
self.local_supported_commands = response.return_parameters.supported_commands
if self.supports_command(HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await self.send_command(HCI_LE_Read_Local_Supported_Features_Command(), check_result=True)
self.local_le_features = struct.unpack('<Q', response.return_parameters.le_features)[0]
response = await self.send_command(
HCI_LE_Read_Local_Supported_Features_Command(), check_result=True
)
self.local_le_features = struct.unpack(
'<Q', response.return_parameters.le_features
)[0]
if self.supports_command(HCI_READ_LOCAL_VERSION_INFORMATION_COMMAND):
response = await self.send_command(HCI_Read_Local_Version_Information_Command(), check_result=True)
response = await self.send_command(
HCI_Read_Local_Version_Information_Command(), check_result=True
)
self.local_version = response.return_parameters
await self.send_command(HCI_Set_Event_Mask_Command(event_mask = bytes.fromhex('FFFFFFFFFFFFFF3F')))
await self.send_command(
HCI_Set_Event_Mask_Command(event_mask=bytes.fromhex('FFFFFFFFFFFFFF3F'))
)
if self.local_version is not None and self.local_version.hci_version <= HCI_VERSION_BLUETOOTH_CORE_4_0:
# Some older controllers don't like event masks with bits they don't understand
if (
self.local_version is not None
and self.local_version.hci_version <= HCI_VERSION_BLUETOOTH_CORE_4_0
):
# Some older controllers don't like event masks with bits they don't
# understand
le_event_mask = bytes.fromhex('1F00000000000000')
else:
le_event_mask = bytes.fromhex('FFFFF00000000000')
await self.send_command(HCI_LE_Set_Event_Mask_Command(le_event_mask = le_event_mask))
await self.send_command(
HCI_LE_Set_Event_Mask_Command(le_event_mask=le_event_mask)
)
if self.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(HCI_Read_Buffer_Size_Command(), check_result=True)
self.hc_acl_data_packet_length = response.return_parameters.hc_acl_data_packet_length
self.hc_total_num_acl_data_packets = response.return_parameters.hc_total_num_acl_data_packets
response = await self.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
self.hc_acl_data_packet_length = (
response.return_parameters.hc_acl_data_packet_length
)
self.hc_total_num_acl_data_packets = (
response.return_parameters.hc_total_num_acl_data_packets
)
logger.debug(
f'HCI ACL flow control: hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
'HCI ACL flow control: '
f'hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
f'hc_total_num_acl_data_packets={self.hc_total_num_acl_data_packets}'
)
if self.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(HCI_LE_Read_Buffer_Size_Command(), check_result=True)
self.hc_le_acl_data_packet_length = response.return_parameters.hc_le_acl_data_packet_length
self.hc_total_num_le_acl_data_packets = response.return_parameters.hc_total_num_le_acl_data_packets
response = await self.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
self.hc_le_acl_data_packet_length = (
response.return_parameters.hc_le_acl_data_packet_length
)
self.hc_total_num_le_acl_data_packets = (
response.return_parameters.hc_total_num_le_acl_data_packets
)
logger.debug(
f'HCI LE ACL flow control: hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
f'hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}'
'HCI LE ACL flow control: '
f'hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
'hc_total_num_le_acl_data_packets='
f'{self.hc_total_num_le_acl_data_packets}'
)
if (
response.return_parameters.hc_le_acl_data_packet_length == 0 or
response.return_parameters.hc_total_num_le_acl_data_packets == 0
response.return_parameters.hc_le_acl_data_packet_length == 0
or response.return_parameters.hc_total_num_le_acl_data_packets == 0
):
# LE and Classic share the same values
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length
self.hc_total_num_le_acl_data_packets = self.hc_total_num_acl_data_packets
self.hc_total_num_le_acl_data_packets = (
self.hc_total_num_acl_data_packets
)
if (
self.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND) and
self.supports_command(HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND)
):
response = await self.send_command(HCI_LE_Read_Suggested_Default_Data_Length_Command())
if self.supports_command(
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
) and self.supports_command(HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response = await self.send_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
if (
suggested_max_tx_octets != self.suggested_max_tx_octets or
suggested_max_tx_time != self.suggested_max_tx_time
suggested_max_tx_octets != self.suggested_max_tx_octets
or suggested_max_tx_time != self.suggested_max_tx_time
):
await self.send_command(HCI_LE_Write_Suggested_Default_Data_Length_Command(
await self.send_command(
HCI_LE_Write_Suggested_Default_Data_Length_Command(
suggested_max_tx_octets=self.suggested_max_tx_octets,
suggested_max_tx_time = self.suggested_max_tx_time
))
suggested_max_tx_time=self.suggested_max_tx_time,
)
)
self.reset_done = True
@@ -196,21 +294,25 @@ class Host(EventEmitter):
# Check the return parameters if required
if check_result:
if type(response.return_parameters) is int:
if isinstance(response.return_parameters, int):
status = response.return_parameters
elif type(response.return_parameters) is bytes:
elif isinstance(response.return_parameters, bytes):
# return parameters first field is a one byte status code
status = response.return_parameters[0]
else:
status = response.return_parameters.status
if status != HCI_SUCCESS:
logger.warning(f'{command.name} failed ({HCI_Constant.error_name(status)})')
logger.warning(
f'{command.name} failed ({HCI_Constant.error_name(status)})'
)
raise HCI_Error(status)
return response
except Exception as error:
logger.warning(f'{color("!!! Exception while sending HCI packet:", "red")} {error}')
logger.warning(
f'{color("!!! Exception while sending HCI packet:", "red")} {error}'
)
raise error
finally:
self.pending_command = None
@@ -238,9 +340,11 @@ class Host(EventEmitter):
pb_flag=pb_flag,
bc_flag=0,
data_total_length=data_total_length,
data = l2cap_pdu[offset:offset + data_total_length]
data=l2cap_pdu[offset : offset + data_total_length],
)
logger.debug(
f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}'
)
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}')
self.queue_acl_packet(acl_packet)
pb_flag = 1
offset += data_total_length
@@ -251,11 +355,17 @@ class Host(EventEmitter):
self.check_acl_packet_queue()
if len(self.acl_packet_queue):
logger.debug(f'{self.acl_packets_in_flight} ACL packets in flight, {len(self.acl_packet_queue)} in queue')
logger.debug(
f'{self.acl_packets_in_flight} ACL packets in flight, '
f'{len(self.acl_packet_queue)} in queue'
)
def check_acl_packet_queue(self):
# Send all we can (TODO: support different LE/Classic limits)
while len(self.acl_packet_queue) > 0 and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets:
while (
len(self.acl_packet_queue) > 0
and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets
):
packet = self.acl_packet_queue.pop()
self.send_hci_packet(packet)
self.acl_packets_in_flight += 1
@@ -267,7 +377,9 @@ class Host(EventEmitter):
if value == command:
# Check if the flag is set
if octet < len(self.local_supported_commands) and flag_position < 8:
return (self.local_supported_commands[octet] & (1 << flag_position)) != 0
return (
self.local_supported_commands[octet] & (1 << flag_position)
) != 0
return False
@@ -289,15 +401,17 @@ class Host(EventEmitter):
@property
def supported_le_features(self):
return [feature for feature in range(64) if self.local_le_features & (1 << feature)]
return [
feature for feature in range(64) if self.local_le_features & (1 << feature)
]
# Packet Sink protocol (packets coming from the controller via HCI)
def on_packet(self, packet):
hci_packet = HCI_Packet.from_bytes(packet)
if self.ready or (
hci_packet.hci_packet_type == HCI_EVENT_PACKET and
hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT and
hci_packet.command_opcode == HCI_RESET_COMMAND
hci_packet.hci_packet_type == HCI_EVENT_PACKET
and hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT
and hci_packet.command_opcode == HCI_RESET_COMMAND
):
self.on_hci_packet(hci_packet)
else:
@@ -336,7 +450,11 @@ class Host(EventEmitter):
if self.pending_response:
# Check that it is what we were expecting
if self.pending_command.op_code != event.command_opcode:
logger.warning(f'!!! command result mismatch, expected 0x{self.pending_command.op_code:X} but got 0x{event.command_opcode:X}')
logger.warning(
'!!! command result mismatch, expected '
f'0x{self.pending_command.op_code:X} but got '
f'0x{event.command_opcode:X}'
)
self.pending_response.set_result(event)
else:
@@ -350,9 +468,11 @@ class Host(EventEmitter):
def on_hci_command_complete_event(self, event):
if event.command_opcode == 0:
# This is used just for the Num_HCI_Command_Packets field, not related to an actual command
# This is used just for the Num_HCI_Command_Packets field, not related to
# an actual command
logger.debug('no-command event')
else:
return None
return self.on_command_processed(event)
def on_hci_command_status_event(self, event):
@@ -364,7 +484,12 @@ class Host(EventEmitter):
self.acl_packets_in_flight -= total_packets
self.check_acl_packet_queue()
else:
logger.warning(color(f'!!! {total_packets} completed but only {self.acl_packets_in_flight} in flight'))
logger.warning(
color(
'!!! {total_packets} completed but only '
f'{self.acl_packets_in_flight} in flight'
)
)
self.acl_packets_in_flight = 0
# Classic only
@@ -381,18 +506,27 @@ class Host(EventEmitter):
# Check if this is a cancellation
if event.status == HCI_SUCCESS:
# Create/update the connection
logger.debug(f'### CONNECTION: [0x{event.connection_handle:04X}] {event.peer_address} as {HCI_Constant.role_name(event.role)}')
logger.debug(
f'### CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.peer_address} as {HCI_Constant.role_name(event.role)}'
)
connection = self.connections.get(event.connection_handle)
if connection is None:
connection = Connection(self, event.connection_handle, event.role, event.peer_address, BT_LE_TRANSPORT)
connection = Connection(
self,
event.connection_handle,
event.role,
event.peer_address,
BT_LE_TRANSPORT,
)
self.connections[event.connection_handle] = connection
# Notify the client
connection_parameters = ConnectionParameters(
event.connection_interval,
event.peripheral_latency,
event.supervision_timeout
event.supervision_timeout,
)
self.emit(
'connection',
@@ -401,13 +535,15 @@ class Host(EventEmitter):
event.peer_address,
None,
event.role,
connection_parameters
connection_parameters,
)
else:
logger.debug(f'### CONNECTION FAILED: {event.status}')
# Notify the listeners
self.emit('connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status)
self.emit(
'connection_failure', BT_LE_TRANSPORT, event.peer_address, event.status
)
def on_hci_le_enhanced_connection_complete_event(self, event):
# Just use the same implementation as for the non-enhanced event for now
@@ -416,11 +552,20 @@ class Host(EventEmitter):
def on_hci_connection_complete_event(self, event):
if event.status == HCI_SUCCESS:
# Create/update the connection
logger.debug(f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] {event.bd_addr}')
logger.debug(
f'### BR/EDR CONNECTION: [0x{event.connection_handle:04X}] '
f'{event.bd_addr}'
)
connection = self.connections.get(event.connection_handle)
if connection is None:
connection = Connection(self, event.connection_handle, BT_CENTRAL_ROLE, event.bd_addr, BT_BR_EDR_TRANSPORT)
connection = Connection(
self,
event.connection_handle,
BT_CENTRAL_ROLE,
event.bd_addr,
BT_BR_EDR_TRANSPORT,
)
self.connections[event.connection_handle] = connection
# Notify the client
@@ -431,13 +576,15 @@ class Host(EventEmitter):
event.bd_addr,
None,
BT_CENTRAL_ROLE,
None
None,
)
else:
logger.debug(f'### BR/EDR CONNECTION FAILED: {event.status}')
# Notify the client
self.emit('connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status)
self.emit(
'connection_failure', BT_BR_EDR_TRANSPORT, event.bd_addr, event.status
)
def on_hci_disconnection_complete_event(self, event):
# Find the connection
@@ -446,7 +593,12 @@ class Host(EventEmitter):
return
if event.status == HCI_SUCCESS:
logger.debug(f'### DISCONNECTION: [0x{event.connection_handle:04X}] {connection.peer_address} as {HCI_Constant.role_name(connection.role)}, reason={event.reason}')
logger.debug(
f'### DISCONNECTION: [0x{event.connection_handle:04X}] '
f'{connection.peer_address} as '
f'{HCI_Constant.role_name(connection.role)}, '
f'reason={event.reason}'
)
del self.connections[event.connection_handle]
# Notify the listeners
@@ -467,11 +619,15 @@ class Host(EventEmitter):
connection_parameters = ConnectionParameters(
event.connection_interval,
event.peripheral_latency,
event.supervision_timeout
event.supervision_timeout,
)
self.emit(
'connection_parameters_update', connection.handle, connection_parameters
)
self.emit('connection_parameters_update', connection.handle, connection_parameters)
else:
self.emit('connection_parameters_update_failure', connection.handle, event.status)
self.emit(
'connection_parameters_update_failure', connection.handle, event.status
)
def on_hci_le_phy_update_complete_event(self, event):
if (connection := self.connections.get(event.connection_handle)) is None:
@@ -504,10 +660,10 @@ class Host(EventEmitter):
connection_handle=event.connection_handle,
interval_min=event.interval_min,
interval_max=event.interval_max,
latency = event.latency,
max_latency=event.max_latency,
timeout=event.timeout,
min_ce_length=0,
max_ce_length = 0
max_ce_length=0,
)
)
@@ -521,15 +677,19 @@ class Host(EventEmitter):
logger.debug('no long term key provider')
long_term_key = None
else:
long_term_key = await self.long_term_key_provider(
long_term_key = await self.abort_on(
'flush',
# pylint: disable-next=not-callable
self.long_term_key_provider(
connection.handle,
event.random_number,
event.encryption_diversifier
event.encryption_diversifier,
),
)
if long_term_key:
response = HCI_LE_Long_Term_Key_Request_Reply_Command(
connection_handle=event.connection_handle,
long_term_key = long_term_key
long_term_key=long_term_key,
)
else:
response = HCI_LE_Long_Term_Key_Request_Negative_Reply_Command(
@@ -548,10 +708,16 @@ class Host(EventEmitter):
def on_hci_role_change_event(self, event):
if event.status == HCI_SUCCESS:
logger.debug(f'role change for {event.bd_addr}: {HCI_Constant.role_name(event.new_role)}')
logger.debug(
f'role change for {event.bd_addr}: '
f'{HCI_Constant.role_name(event.new_role)}'
)
# TODO: lookup the connection and update the role
else:
logger.debug(f'role change for {event.bd_addr} failed: {HCI_Constant.error_name(event.status)}')
logger.debug(
f'role change for {event.bd_addr} failed: '
f'{HCI_Constant.error_name(event.status)}'
)
def on_hci_le_data_length_change_event(self, event):
self.emit(
@@ -560,7 +726,7 @@ class Host(EventEmitter):
event.max_tx_octets,
event.max_tx_time,
event.max_rx_octets,
event.max_rx_time
event.max_rx_time,
)
def on_hci_authentication_complete_event(self, event):
@@ -568,21 +734,35 @@ class Host(EventEmitter):
if event.status == HCI_SUCCESS:
self.emit('connection_authentication', event.connection_handle)
else:
self.emit('connection_authentication_failure', event.connection_handle, event.status)
self.emit(
'connection_authentication_failure',
event.connection_handle,
event.status,
)
def on_hci_encryption_change_event(self, event):
# Notify the client
if event.status == HCI_SUCCESS:
self.emit('connection_encryption_change', event.connection_handle, event.encryption_enabled)
self.emit(
'connection_encryption_change',
event.connection_handle,
event.encryption_enabled,
)
else:
self.emit('connection_encryption_failure', event.connection_handle, event.status)
self.emit(
'connection_encryption_failure', event.connection_handle, event.status
)
def on_hci_encryption_key_refresh_complete_event(self, event):
# Notify the client
if event.status == HCI_SUCCESS:
self.emit('connection_encryption_key_refresh', event.connection_handle)
else:
self.emit('connection_encryption_key_refresh_failure', event.connection_handle, event.status)
self.emit(
'connection_encryption_key_refresh_failure',
event.connection_handle,
event.status,
)
def on_hci_link_supervision_timeout_changed_event(self, event):
pass
@@ -594,19 +774,23 @@ class Host(EventEmitter):
pass
def on_hci_link_key_notification_event(self, event):
logger.debug(f'link key for {event.bd_addr}: {event.link_key.hex()}, type={HCI_Constant.link_key_type_name(event.key_type)}')
logger.debug(
f'link key for {event.bd_addr}: {event.link_key.hex()}, '
f'type={HCI_Constant.link_key_type_name(event.key_type)}'
)
self.emit('link_key', event.bd_addr, event.link_key, event.key_type)
def on_hci_simple_pairing_complete_event(self, event):
logger.debug(f'simple pairing complete for {event.bd_addr}: status={HCI_Constant.status_name(event.status)}')
logger.debug(
f'simple pairing complete for {event.bd_addr}: '
f'status={HCI_Constant.status_name(event.status)}'
)
def on_hci_pin_code_request_event(self, event):
# For now, just refuse all requests
# TODO: delegate the decision
self.send_command_sync(
HCI_PIN_Code_Request_Negative_Reply_Command(
bd_addr = event.bd_addr
)
HCI_PIN_Code_Request_Negative_Reply_Command(bd_addr=event.bd_addr)
)
def on_hci_link_key_request_event(self, event):
@@ -615,11 +799,14 @@ class Host(EventEmitter):
logger.debug('no link key provider')
link_key = None
else:
link_key = await self.link_key_provider(event.bd_addr)
link_key = await self.abort_on(
'flush',
# pylint: disable-next=not-callable
self.link_key_provider(event.bd_addr),
)
if link_key:
response = HCI_Link_Key_Request_Reply_Command(
bd_addr = event.bd_addr,
link_key = link_key
bd_addr=event.bd_addr, link_key=link_key
)
else:
response = HCI_Link_Key_Request_Negative_Reply_Command(
@@ -637,15 +824,21 @@ class Host(EventEmitter):
pass
def on_hci_user_confirmation_request_event(self, event):
self.emit('authentication_user_confirmation_request', event.bd_addr, event.numeric_value)
self.emit(
'authentication_user_confirmation_request',
event.bd_addr,
event.numeric_value,
)
def on_hci_user_passkey_request_event(self, event):
self.emit('authentication_user_passkey_request', event.bd_addr)
def on_hci_user_passkey_notification_event(self, event):
self.emit('authentication_user_passkey_notification', event.bd_addr, event.passkey)
self.emit(
'authentication_user_passkey_notification', event.bd_addr, event.passkey
)
def on_hci_inquiry_complete_event(self, event):
def on_hci_inquiry_complete_event(self, _event):
self.emit('inquiry_complete')
def on_hci_inquiry_result_with_rssi_event(self, event):
@@ -655,7 +848,7 @@ class Host(EventEmitter):
response.bd_addr,
response.class_of_device,
b'',
response.rssi
response.rssi,
)
def on_hci_extended_inquiry_result_event(self, event):
@@ -664,7 +857,7 @@ class Host(EventEmitter):
event.bd_addr,
event.class_of_device,
event.extended_inquiry_response,
event.rssi
event.rssi,
)
def on_hci_remote_name_request_complete_event(self, event):
@@ -674,4 +867,8 @@ class Host(EventEmitter):
self.emit('remote_name', event.bd_addr, event.remote_name)
def on_hci_remote_host_supported_features_notification_event(self, event):
self.emit('remote_host_supported_features', event.bd_addr, event.host_supported_features)
self.emit(
'remote_host_supported_features',
event.bd_addr,
event.host_supported_features,
)

View File

@@ -20,6 +20,7 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import json
@@ -75,7 +76,9 @@ class PairingKeys:
@staticmethod
def key_from_dict(keys_dict, key_name):
key_dict = keys_dict.get(key_name)
if key_dict is not None:
if key_dict is None:
return None
return PairingKeys.Key.from_dict(key_dict)
@staticmethod
@@ -120,9 +123,9 @@ class PairingKeys:
def print(self, prefix=''):
keys_dict = self.to_dict()
for (property, value) in keys_dict.items():
if type(value) is dict:
print(f'{prefix}{color(property, "cyan")}:')
for (container_property, value) in keys_dict.items():
if isinstance(value, dict):
print(f'{prefix}{color(container_property, "cyan")}:')
for (key_property, key_value) in value.items():
print(f'{prefix} {color(key_property, "green")}: {key_value}')
else:
@@ -137,12 +140,16 @@ class KeyStore:
async def update(self, name, keys):
pass
async def get(self, name):
async def get(self, _name):
return PairingKeys()
async def get_all(self):
return []
async def delete_all(self):
all_keys = await self.get_all()
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
async def get_resolving_keys(self):
all_keys = await self.get_all()
resolving_keys = []
@@ -188,10 +195,13 @@ class JsonKeyStore(KeyStore):
if filename is None:
# Use a default for the current user
# Import here because this may not exist on all platforms
# pylint: disable=import-outside-toplevel
import appdirs
self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR),
self.KEYS_DIR
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
)
json_filename = f'{self.namespace}.json'.lower().replace(':', '-')
self.filename = os.path.join(self.directory_name, json_filename)
@@ -214,7 +224,7 @@ class JsonKeyStore(KeyStore):
async def load(self):
try:
with open(self.filename, 'r') as json_file:
with open(self.filename, 'r', encoding='utf-8') as json_file:
return json.load(json_file)
except FileNotFoundError:
return {}
@@ -226,7 +236,7 @@ class JsonKeyStore(KeyStore):
# Save to a temporary file
temp_filename = self.filename + '.tmp'
with open(temp_filename, 'w') as output:
with open(temp_filename, 'w', encoding='utf-8') as output:
json.dump(db, output, sort_keys=True, indent=4)
# Atomically replace the previous file
@@ -257,7 +267,16 @@ class JsonKeyStore(KeyStore):
if namespace is None:
return []
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()]
return [
(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()
]
async def delete_all(self):
db = await self.load()
db.pop(self.namespace, None)
await self.save(db)
async def get(self, name):
db = await self.load()

File diff suppressed because it is too large Load Diff

View File

@@ -17,15 +17,16 @@
# -----------------------------------------------------------------------------
import logging
import asyncio
import websockets
from functools import partial
from colors import color
import websockets
from bumble.hci import (
Address,
HCI_SUCCESS,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_CONNECTION_TIMEOUT_ERROR
HCI_CONNECTION_TIMEOUT_ERROR,
)
# -----------------------------------------------------------------------------
@@ -47,7 +48,8 @@ def parse_parameters(params_str):
# -----------------------------------------------------------------------------
# TODO: add more support for various LL exchanges (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
# TODO: add more support for various LL exchanges
# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
# -----------------------------------------------------------------------------
class LocalLink:
'''
@@ -103,23 +105,31 @@ class LocalLink:
return
# Connect to the first controller with a matching address
if peripheral_controller := self.find_controller(le_create_connection_command.peer_address):
central_controller.on_link_peripheral_connection_complete(le_create_connection_command, HCI_SUCCESS)
if peripheral_controller := self.find_controller(
le_create_connection_command.peer_address
):
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, HCI_SUCCESS
)
peripheral_controller.on_link_central_connected(central_address)
return
# No peripheral found
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
)
def connect(self, central_address, le_create_connection_command):
logger.debug(f'$$$ CONNECTION {central_address} -> {le_create_connection_command.peer_address}')
logger.debug(
f'$$$ CONNECTION {central_address} -> '
f'{le_create_connection_command.peer_address}'
)
self.pending_connection = (central_address, le_create_connection_command)
asyncio.get_running_loop().call_soon(self.on_connection_complete)
def on_disconnection_complete(self, central_address, peripheral_address, disconnect_command):
def on_disconnection_complete(
self, central_address, peripheral_address, disconnect_command
):
# Find the controller that initiated the disconnection
if not (central_controller := self.find_controller(central_address)):
logger.warning('!!! Initiating controller not found')
@@ -127,16 +137,26 @@ class LocalLink:
# Disconnect from the first controller with a matching address
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_central_disconnected(central_address, disconnect_command.reason)
peripheral_controller.on_link_central_disconnected(
central_address, disconnect_command.reason
)
central_controller.on_link_peripheral_disconnection_complete(disconnect_command, HCI_SUCCESS)
central_controller.on_link_peripheral_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(f'$$$ DISCONNECTION {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}')
logger.debug(
f'$$$ DISCONNECTION {central_address} -> '
f'{peripheral_address}: reason = {disconnect_command.reason}'
)
args = [central_address, peripheral_address, disconnect_command]
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)
def on_connection_encrypted(self, central_address, peripheral_address, rand, ediv, ltk):
# pylint: disable=too-many-arguments
def on_connection_encrypted(
self, central_address, peripheral_address, rand, ediv, ltk
):
logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')
if central_controller := self.find_controller(central_address):
@@ -152,6 +172,7 @@ class RemoteLink:
A Link implementation that communicates with other virtual controllers via a
WebSocket relay
'''
def __init__(self, uri):
self.controller = None
self.uri = uri
@@ -160,7 +181,9 @@ class RemoteLink:
self.rpc_result = None
self.pending_connection = None
self.central_connections = set() # List of addresses that we have connected to
self.peripheral_connections = set() # List of addresses that have connected to us
self.peripheral_connections = (
set()
) # List of addresses that have connected to us
# Connect and run asynchronously
asyncio.create_task(self.run_connection())
@@ -192,11 +215,14 @@ class RemoteLink:
try:
await item
except Exception as error:
logger.warning(f'{color("!!! Exception in async handler:", "red")} {error}')
logger.warning(
f'{color("!!! Exception in async handler:", "red")} {error}'
)
async def run_connection(self):
# Connect to the relay
logger.debug(f'connecting to {self.uri}')
# pylint: disable-next=no-member
websocket = await websockets.connect(self.uri)
self.websocket.set_result(websocket)
logger.debug(f'connected to {self.uri}')
@@ -227,7 +253,9 @@ class RemoteLink:
self.central_connections.remove(address)
if address in self.peripheral_connections:
self.controller.on_link_central_disconnected(address, HCI_CONNECTION_TIMEOUT_ERROR)
self.controller.on_link_central_disconnected(
address, HCI_CONNECTION_TIMEOUT_ERROR
)
self.peripheral_connections.remove(address)
async def on_unreachable_received(self, target):
@@ -244,7 +272,9 @@ class RemoteLink:
async def on_advertisement_message_received(self, sender, advertisement):
try:
self.controller.on_link_advertising_data(Address(sender), bytes.fromhex(advertisement))
self.controller.on_link_advertising_data(
Address(sender), bytes.fromhex(advertisement)
)
except Exception:
logger.exception('exception')
@@ -263,11 +293,11 @@ class RemoteLink:
self.controller.on_link_central_connected(Address(sender))
# Accept the connection by responding to it
await self.send_targetted_message(sender, 'connected')
await self.send_targeted_message(sender, 'connected')
async def on_connected_message_received(self, sender, _):
if not self.pending_connection:
logger.warn('received a connection ack, but no connection is pending')
logger.warning('received a connection ack, but no connection is pending')
return
# Remember the connection
@@ -275,7 +305,9 @@ class RemoteLink:
# Notify the controller
logger.debug(f'connected to peripheral {self.pending_connection.peer_address}')
self.controller.on_link_peripheral_connection_complete(self.pending_connection, HCI_SUCCESS)
self.controller.on_link_peripheral_connection_complete(
self.pending_connection, HCI_SUCCESS
)
async def on_disconnect_message_received(self, sender, message):
# Notify the controller
@@ -287,7 +319,7 @@ class RemoteLink:
if sender in self.peripheral_connections:
self.peripheral_connections.remove(sender)
async def on_encrypted_message_received(self, sender, message):
async def on_encrypted_message_received(self, sender, _):
# TODO parse params to get real args
self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16))
@@ -296,7 +328,7 @@ class RemoteLink:
websocket = await self.websocket
# Create a future value to hold the eventual result
assert(self.rpc_result is None)
assert self.rpc_result is None
self.rpc_result = asyncio.get_running_loop().create_future()
# Send the command
@@ -309,7 +341,7 @@ class RemoteLink:
# TODO: parse the result
async def send_targetted_message(self, target, message):
async def send_targeted_message(self, target, message):
# Ensure we have a connection
websocket = await self.websocket
@@ -326,35 +358,61 @@ class RemoteLink:
self.execute(self.notify_address_changed)
async def send_advertising_data_to_relay(self, data):
await self.send_targetted_message('*', f'advertisement:{data.hex()}')
await self.send_targeted_message('*', f'advertisement:{data.hex()}')
def send_advertising_data(self, sender_address, data):
def send_advertising_data(self, _, data):
self.execute(partial(self.send_advertising_data_to_relay, data))
async def send_acl_data_to_relay(self, peer_address, data):
await self.send_targetted_message(peer_address, f'acl:{data.hex()}')
await self.send_targeted_message(peer_address, f'acl:{data.hex()}')
def send_acl_data(self, sender_address, peer_address, data):
def send_acl_data(self, _, peer_address, data):
self.execute(partial(self.send_acl_data_to_relay, peer_address, data))
async def send_connection_request_to_relay(self, peer_address):
await self.send_targetted_message(peer_address, 'connect')
await self.send_targeted_message(peer_address, 'connect')
def connect(self, central_address, le_create_connection_command):
def connect(self, _, le_create_connection_command):
if self.pending_connection:
logger.warn('connection already pending')
logger.warning('connection already pending')
return
self.pending_connection = le_create_connection_command
self.execute(partial(self.send_connection_request_to_relay, str(le_create_connection_command.peer_address)))
self.execute(
partial(
self.send_connection_request_to_relay,
str(le_create_connection_command.peer_address),
)
)
def on_disconnection_complete(self, disconnect_command):
self.controller.on_link_peripheral_disconnection_complete(disconnect_command, HCI_SUCCESS)
self.controller.on_link_peripheral_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(f'disconnect {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}')
self.execute(partial(self.send_targetted_message, peripheral_address, f'disconnect:reason={disconnect_command.reason}'))
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, disconnect_command)
logger.debug(
f'disconnect {central_address} -> '
f'{peripheral_address}: reason = {disconnect_command.reason}'
)
self.execute(
partial(
self.send_targeted_message,
peripheral_address,
f'disconnect:reason={disconnect_command.reason}',
)
)
asyncio.get_running_loop().call_soon(
self.on_disconnection_complete, disconnect_command
)
def on_connection_encrypted(self, central_address, peripheral_address, rand, ediv, ltk):
asyncio.get_running_loop().call_soon(self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk)
self.execute(partial(self.send_targetted_message, peripheral_address, f'encrypted:ltk={ltk.hex()}'))
def on_connection_encrypted(self, _, peripheral_address, rand, ediv, ltk):
asyncio.get_running_loop().call_soon(
self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
)
self.execute(
partial(
self.send_targeted_message,
peripheral_address,
f'encrypted:ltk={ltk.hex()}',
)
)

View File

@@ -0,0 +1,171 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
import logging
from typing import List
from ..core import AdvertisingData
from ..gatt import (
GATT_ASHA_SERVICE,
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
GATT_ASHA_VOLUME_CHARACTERISTIC,
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
TemplateService,
Characteristic,
CharacteristicValue,
)
from ..device import Device
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class AshaService(TemplateService):
UUID = GATT_ASHA_SERVICE
OPCODE_START = 1
OPCODE_STOP = 2
OPCODE_STATUS = 3
PROTOCOL_VERSION = 0x01
RESERVED_FOR_FUTURE_USE = [00, 00]
FEATURE_MAP = [0x01] # [LE CoC audio output streaming supported]
SUPPORTED_CODEC_ID = [0x02, 0x01] # Codec IDs [G.722 at 16 kHz]
RENDER_DELAY = [00, 00]
def __init__(self, capability: int, hisyncid: List[int], device: Device, psm=0):
self.hisyncid = hisyncid
self.capability = capability # Device Capabilities [Left, Monaural]
self.device = device
self.emitted_data_name = 'ASHA_data_' + str(self.capability)
self.audio_out_data = b''
self.psm = psm # a non-zero psm is mainly for testing purpose
# Handler for volume control
def on_volume_write(_connection, value):
logger.info(f'--- VOLUME Write:{value[0]}')
# Handler for audio control commands
def on_audio_control_point_write(_connection, value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == AshaService.OPCODE_START:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
logger.info(
f'### START: codec={value[1]}, '
f'audio_type={audio_type}, '
f'volume={value[3]}, '
f'otherstate={value[4]}'
)
elif opcode == AshaService.OPCODE_STOP:
logger.info('### STOP')
elif opcode == AshaService.OPCODE_STATUS:
logger.info(f'### STATUS: connected={value[1]}')
# TODO Respond with a status
# asyncio.create_task(device.notify_subscribers(audio_status_characteristic,
# force=True))
self.read_only_properties_characteristic = Characteristic(
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes(
[
AshaService.PROTOCOL_VERSION, # Version
self.capability,
]
)
+ bytes(self.hisyncid)
+ bytes(AshaService.FEATURE_MAP)
+ bytes(AshaService.RENDER_DELAY)
+ bytes(AshaService.RESERVED_FOR_FUTURE_USE)
+ bytes(AshaService.SUPPORTED_CODEC_ID),
)
self.audio_control_point_characteristic = Characteristic(
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write),
)
self.audio_status_characteristic = Characteristic(
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([0]),
)
self.volume_characteristic = Characteristic(
GATT_ASHA_VOLUME_CHARACTERISTIC,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write),
)
# Register an L2CAP CoC server
def on_coc(channel):
def on_data(data):
logging.debug(f'<<< data received:{data}')
self.emit(self.emitted_data_name, data)
self.audio_out_data += data
channel.sink = on_data
# let the server find a free PSM
self.psm = self.device.register_l2cap_channel_server(self.psm, on_coc, 8)
self.le_psm_out_characteristic = Characteristic(
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
struct.pack('<H', self.psm),
)
characteristics = [
self.read_only_properties_characteristic,
self.audio_control_point_characteristic,
self.audio_status_characteristic,
self.volume_characteristic,
self.le_psm_out_characteristic,
]
super().__init__(characteristics)
def get_advertising_data(self):
# Advertisement only uses 4 least significant bytes of the HiSyncId.
return bytes(
AdvertisingData(
[
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(GATT_ASHA_SERVICE)
+ bytes(
[
AshaService.PROTOCOL_VERSION,
self.capability,
]
)
+ bytes(self.hisyncid[:4]),
),
]
)
)

View File

@@ -23,7 +23,7 @@ from ..gatt import (
TemplateService,
Characteristic,
CharacteristicValue,
PackedCharacteristicAdapter
PackedCharacteristicAdapter,
)
@@ -38,9 +38,9 @@ class BatteryService(TemplateService):
GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
CharacteristicValue(read=read_battery_level)
CharacteristicValue(read=read_battery_level),
),
format=BatteryService.BATTERY_LEVEL_FORMAT
pack_format=BatteryService.BATTERY_LEVEL_FORMAT,
)
super().__init__([self.battery_level_characteristic])
@@ -52,10 +52,11 @@ class BatteryServiceProxy(ProfileServiceProxy):
def __init__(self, service_proxy):
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_BATTERY_LEVEL_CHARACTERISTIC):
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BATTERY_LEVEL_CHARACTERISTIC
):
self.battery_level = PackedCharacteristicAdapter(
characteristics[0],
format=BatteryService.BATTERY_LEVEL_FORMAT
characteristics[0], pack_format=BatteryService.BATTERY_LEVEL_FORMAT
)
else:
self.battery_level = None

View File

@@ -17,7 +17,7 @@
# Imports
# -----------------------------------------------------------------------------
import struct
from typing import Tuple
from typing import Optional, Tuple
from ..gatt_client import ProfileServiceProxy
from ..gatt import (
@@ -33,7 +33,7 @@ from ..gatt import (
TemplateService,
Characteristic,
DelegatedCharacteristicAdapter,
UTF8CharacteristicAdapter
UTF8CharacteristicAdapter,
)
@@ -52,49 +52,48 @@ class DeviceInformationService(TemplateService):
def __init__(
self,
manufacturer_name: str = None,
model_number: str = None,
serial_number: str = None,
hardware_revision: str = None,
firmware_revision: str = None,
software_revision: str = None,
system_id: Tuple[int, int] = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: bytes = None
manufacturer_name: Optional[str] = None,
model_number: Optional[str] = None,
serial_number: Optional[str] = None,
hardware_revision: Optional[str] = None,
firmware_revision: Optional[str] = None,
software_revision: Optional[str] = None,
system_id: Optional[Tuple[int, int]] = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: Optional[bytes] = None
# TODO: pnp_id
):
characteristics = [
Characteristic(
uuid,
Characteristic.READ,
Characteristic.READABLE,
field
)
Characteristic(uuid, Characteristic.READ, Characteristic.READABLE, field)
for (field, uuid) in (
(manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
(model_number, GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
(serial_number, GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
(hardware_revision, GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC),
(firmware_revision, GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC),
(software_revision, GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC)
(software_revision, GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC),
)
if field is not None
]
if system_id is not None:
characteristics.append(Characteristic(
characteristics.append(
Characteristic(
GATT_SYSTEM_ID_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
self.pack_system_id(*system_id)
))
self.pack_system_id(*system_id),
)
)
if ieee_regulatory_certification_data_list is not None:
characteristics.append(Characteristic(
characteristics.append(
Characteristic(
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
ieee_regulatory_certification_data_list
))
ieee_regulatory_certification_data_list,
)
)
super().__init__(characteristics)
@@ -112,7 +111,7 @@ class DeviceInformationServiceProxy(ProfileServiceProxy):
('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
('hardware_revision', GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC),
('firmware_revision', GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC),
('software_revision', GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC)
('software_revision', GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC),
):
if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
characteristic = UTF8CharacteristicAdapter(characteristics[0])
@@ -120,16 +119,20 @@ class DeviceInformationServiceProxy(ProfileServiceProxy):
characteristic = None
self.__setattr__(field, characteristic)
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_SYSTEM_ID_CHARACTERISTIC):
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_SYSTEM_ID_CHARACTERISTIC
):
self.system_id = DelegatedCharacteristicAdapter(
characteristics[0],
encode=lambda v: DeviceInformationService.pack_system_id(*v),
decode=DeviceInformationService.unpack_system_id
decode=DeviceInformationService.unpack_system_id,
)
else:
self.system_id = None
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC):
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC
):
self.ieee_regulatory_certification_data_list = characteristics[0]
else:
self.ieee_regulatory_certification_data_list = None

View File

@@ -30,7 +30,7 @@ from ..gatt import (
Characteristic,
CharacteristicValue,
DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter
PackedCharacteristicAdapter,
)
@@ -42,12 +42,12 @@ class HeartRateService(TemplateService):
RESET_ENERGY_EXPENDED = 0x01
class BodySensorLocation(IntEnum):
OTHER = 0,
CHEST = 1,
WRIST = 2,
FINGER = 3,
HAND = 4,
EAR_LOBE = 5,
OTHER = (0,)
CHEST = (1,)
WRIST = (2,)
FINGER = (3,)
HAND = (4,)
EAR_LOBE = (5,)
FOOT = 6
class HeartRateMeasurement:
@@ -56,12 +56,14 @@ class HeartRateService(TemplateService):
heart_rate,
sensor_contact_detected=None,
energy_expended=None,
rr_intervals=None
rr_intervals=None,
):
if heart_rate < 0 or heart_rate > 0xFFFF:
raise ValueError('heart_rate out of range')
if energy_expended is not None and (energy_expended < 0 or energy_expended > 0xFFFF):
if energy_expended is not None and (
energy_expended < 0 or energy_expended > 0xFFFF
):
raise ValueError('energy_expended out of range')
if rr_intervals:
@@ -87,7 +89,7 @@ class HeartRateService(TemplateService):
offset += 1
if flags & (1 << 2):
sensor_contact_detected = (flags & (1 << 1) != 0)
sensor_contact_detected = flags & (1 << 1) != 0
else:
sensor_contact_detected = None
@@ -119,38 +121,43 @@ class HeartRateService(TemplateService):
flags |= ((1 if self.sensor_contact_detected else 0) << 1) | (1 << 2)
if self.energy_expended is not None:
flags |= (1 << 3)
flags |= 1 << 3
data += struct.pack('<H', self.energy_expended)
if self.rr_intervals:
flags |= (1 << 4)
data += b''.join([
flags |= 1 << 4
data += b''.join(
[
struct.pack('<H', int(rr_interval * 1024))
for rr_interval in self.rr_intervals
])
]
)
return bytes([flags]) + data
def __str__(self):
return f'HeartRateMeasurement(heart_rate={self.heart_rate},'\
f' sensor_contact_detected={self.sensor_contact_detected},'\
f' energy_expended={self.energy_expended},'\
return (
f'HeartRateMeasurement(heart_rate={self.heart_rate},'
f' sensor_contact_detected={self.sensor_contact_detected},'
f' energy_expended={self.energy_expended},'
f' rr_intervals={self.rr_intervals})'
)
def __init__(
self,
read_heart_rate_measurement,
body_sensor_location=None,
reset_energy_expended=None
reset_energy_expended=None,
):
self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter(
Characteristic(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Characteristic.NOTIFY,
0,
CharacteristicValue(read=read_heart_rate_measurement)
CharacteristicValue(read=read_heart_rate_measurement),
),
encode=lambda value: bytes(value)
# pylint: disable=unnecessary-lambda
encode=lambda value: bytes(value),
)
characteristics = [self.heart_rate_measurement_characteristic]
@@ -159,11 +166,12 @@ class HeartRateService(TemplateService):
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([int(body_sensor_location)])
bytes([int(body_sensor_location)]),
)
characteristics.append(self.body_sensor_location_characteristic)
if reset_energy_expended:
def write_heart_rate_control_point_value(connection, value):
if value == self.RESET_ENERGY_EXPENDED:
if reset_energy_expended is not None:
@@ -176,9 +184,9 @@ class HeartRateService(TemplateService):
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE,
Characteristic.WRITEABLE,
CharacteristicValue(write=write_heart_rate_control_point_value)
CharacteristicValue(write=write_heart_rate_control_point_value),
),
format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
characteristics.append(self.heart_rate_control_point_characteristic)
@@ -192,30 +200,38 @@ class HeartRateServiceProxy(ProfileServiceProxy):
def __init__(self, service_proxy):
self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC):
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
):
self.heart_rate_measurement = DelegatedCharacteristicAdapter(
characteristics[0],
decode=HeartRateService.HeartRateMeasurement.from_bytes
decode=HeartRateService.HeartRateMeasurement.from_bytes,
)
else:
self.heart_rate_measurement = None
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC):
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC
):
self.body_sensor_location = DelegatedCharacteristicAdapter(
characteristics[0],
decode=lambda value: HeartRateService.BodySensorLocation(value[0])
decode=lambda value: HeartRateService.BodySensorLocation(value[0]),
)
else:
self.body_sensor_location = None
if characteristics := service_proxy.get_characteristics_by_uuid(GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC):
if characteristics := service_proxy.get_characteristics_by_uuid(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC
):
self.heart_rate_control_point = PackedCharacteristicAdapter(
characteristics[0],
format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
else:
self.heart_rate_control_point = None
async def reset_energy_expended(self):
if self.heart_rate_control_point is not None:
return await self.heart_rate_control_point.write_value(HeartRateService.RESET_ENERGY_EXPENDED)
return await self.heart_rate_control_point.write_value(
HeartRateService.RESET_ENERGY_EXPENDED
)

0
bumble/profiles/py.typed Normal file
View File

0
bumble/py.typed Normal file
View File

View File

@@ -21,7 +21,8 @@ import asyncio
from colors import color
from pyee import EventEmitter
from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError, ConnectionError
from . import core
from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError
# -----------------------------------------------------------------------------
# Logging
@@ -32,6 +33,8 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
RFCOMM_PSM = 0x0003
@@ -98,19 +101,21 @@ RFCOMM_DEFAULT_PREFERRED_MTU = 1280
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# fmt: on
# -----------------------------------------------------------------------------
def fcs(buffer):
fcs = 0xFF
def compute_fcs(buffer):
result = 0xFF
for byte in buffer:
fcs = CRC_TABLE[fcs ^ byte]
return 0xFF - fcs
result = CRC_TABLE[result ^ byte]
return 0xFF - result
# -----------------------------------------------------------------------------
class RFCOMM_Frame:
def __init__(self, type, c_r, dlci, p_f, information = b'', with_credits = False):
self.type = type
def __init__(self, frame_type, c_r, dlci, p_f, information=b'', with_credits=False):
self.type = frame_type
self.c_r = c_r
self.dlci = dlci
self.p_f = p_f
@@ -125,18 +130,18 @@ class RFCOMM_Frame:
# 1-byte length indicator
self.length = bytes([(length << 1) | 1])
self.address = (dlci << 2) | (c_r << 1) | 1
self.control = type | (p_f << 4)
if type == RFCOMM_UIH_FRAME:
self.fcs = fcs(bytes([self.address, self.control]))
self.control = frame_type | (p_f << 4)
if frame_type == RFCOMM_UIH_FRAME:
self.fcs = compute_fcs(bytes([self.address, self.control]))
else:
self.fcs = fcs(bytes([self.address, self.control]) + self.length)
self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)
def type_name(self):
return RFCOMM_FRAME_TYPE_NAMES[self.type]
@staticmethod
def parse_mcc(data):
type = data[0] >> 2
mcc_type = data[0] >> 2
c_r = (data[0] >> 1) & 1
length = data[1]
if data[1] & 1:
@@ -146,11 +151,14 @@ class RFCOMM_Frame:
length = (data[3] << 7) & (length >> 1)
value = data[3 : 3 + length]
return (type, c_r, value)
return (mcc_type, c_r, value)
@staticmethod
def make_mcc(type, c_r, data):
return bytes([(type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1]) + data
def make_mcc(mcc_type, c_r, data):
return (
bytes([(mcc_type << 2 | c_r << 1 | 1) & 0xFF, (len(data) & 0x7F) << 1 | 1])
+ data
)
@staticmethod
def sabm(c_r, dlci):
@@ -170,14 +178,16 @@ class RFCOMM_Frame:
@staticmethod
def uih(c_r, dlci, information, p_f=0):
return RFCOMM_Frame(RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits = (p_f == 1))
return RFCOMM_Frame(
RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
)
@staticmethod
def from_bytes(data):
# Extract fields
dlci = (data[0] >> 2) & 0x3F
c_r = (data[0] >> 1) & 0x01
type = data[1] & 0xEF
frame_type = data[1] & 0xEF
p_f = (data[1] >> 4) & 0x01
length = data[2]
if length & 0x01:
@@ -189,23 +199,44 @@ class RFCOMM_Frame:
fcs = data[-1]
# Construct the frame and check the CRC
frame = RFCOMM_Frame(type, c_r, dlci, p_f, information)
frame = RFCOMM_Frame(frame_type, c_r, dlci, p_f, information)
if frame.fcs != fcs:
logger.warn(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
logger.warning(f'FCS mismatch: got {fcs:02X}, expected {frame.fcs:02X}')
raise ValueError('fcs mismatch')
return frame
def __bytes__(self):
return bytes([self.address, self.control]) + self.length + self.information + bytes([self.fcs])
return (
bytes([self.address, self.control])
+ self.length
+ self.information
+ bytes([self.fcs])
)
def __str__(self):
return f'{color(self.type_name(), "yellow")}(c/r={self.c_r},dlci={self.dlci},p/f={self.p_f},length={len(self.information)},fcs=0x{self.fcs:02X})'
return (
f'{color(self.type_name(), "yellow")}'
f'(c/r={self.c_r},'
f'dlci={self.dlci},'
f'p/f={self.p_f},'
f'length={len(self.information)},'
f'fcs=0x{self.fcs:02X})'
)
# -----------------------------------------------------------------------------
class RFCOMM_MCC_PN:
def __init__(self, dlci, cl, priority, ack_timer, max_frame_size, max_retransmissions, window_size):
def __init__(
self,
dlci,
cl,
priority,
ack_timer,
max_frame_size,
max_retransmissions,
window_size,
):
self.dlci = dlci
self.cl = cl
self.priority = priority
@@ -223,11 +254,12 @@ class RFCOMM_MCC_PN:
ack_timer=data[3],
max_frame_size=data[4] | data[5] << 8,
max_retransmissions=data[6],
window_size = data[7]
window_size=data[7],
)
def __bytes__(self):
return bytes([
return bytes(
[
self.dlci & 0xFF,
self.cl & 0xFF,
self.priority & 0xFF,
@@ -235,11 +267,20 @@ class RFCOMM_MCC_PN:
self.max_frame_size & 0xFF,
(self.max_frame_size >> 8) & 0xFF,
self.max_retransmissions & 0xFF,
self.window_size & 0xFF
])
self.window_size & 0xFF,
]
)
def __str__(self):
return f'PN(dlci={self.dlci},cl={self.cl},priority={self.priority},ack_timer={self.ack_timer},max_frame_size={self.max_frame_size},max_retransmissions={self.max_retransmissions},window_size={self.window_size})'
return (
f'PN(dlci={self.dlci},'
f'cl={self.cl},'
f'priority={self.priority},'
f'ack_timer={self.ack_timer},'
f'max_frame_size={self.max_frame_size},'
f'max_retransmissions={self.max_retransmissions},'
f'window_size={self.window_size})'
)
# -----------------------------------------------------------------------------
@@ -260,17 +301,31 @@ class RFCOMM_MCC_MSC:
rtc=data[1] >> 2 & 1,
rtr=data[1] >> 3 & 1,
ic=data[1] >> 6 & 1,
dv = data[1] >> 7 & 1
dv=data[1] >> 7 & 1,
)
def __bytes__(self):
return bytes([
return bytes(
[
(self.dlci << 2) | 3,
1 | self.fc << 1 | self.rtc << 2 | self.rtr << 3 | self.ic << 6 | self.dv << 7
])
1
| self.fc << 1
| self.rtc << 2
| self.rtr << 3
| self.ic << 6
| self.dv << 7,
]
)
def __str__(self):
return f'MSC(dlci={self.dlci},fc={self.fc},rtc={self.rtc},rtr={self.rtr},ic={self.ic},dv={self.dv})'
return (
f'MSC(dlci={self.dlci},'
f'fc={self.fc},'
f'rtc={self.rtc},'
f'rtr={self.rtr},'
f'ic={self.ic},'
f'dv={self.dv})'
)
# -----------------------------------------------------------------------------
@@ -289,7 +344,7 @@ class DLC(EventEmitter):
CONNECTED: 'CONNECTED',
DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED',
RESET: 'RESET'
RESET: 'RESET',
}
def __init__(self, multiplexer, dlci, max_frame_size, initial_tx_credits):
@@ -304,17 +359,22 @@ class DLC(EventEmitter):
self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.INITIATOR else 0
self.sink = None
self.connection_result = None
# Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs
self.mtu = min(max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead)
self.mtu = min(
max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead
)
@staticmethod
def state_name(state):
return DLC.STATE_NAMES[state]
def change_state(self, new_state):
logger.debug(f'{self} state change -> {color(self.state_name(new_state), "magenta")}')
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "magenta")}'
)
self.state = new_state
def send_frame(self, frame):
@@ -324,58 +384,40 @@ class DLC(EventEmitter):
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame)
def on_sabm_frame(self, frame):
def on_sabm_frame(self, _frame):
if self.state != DLC.CONNECTING:
logger.warn(color('!!! received SABM when not in CONNECTING state', 'red'))
logger.warning(
color('!!! received SABM when not in CONNECTING state', 'red')
)
return
self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(
dlci = self.dlci,
fc = 0,
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 1, data = bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.CONNECTED)
self.emit('open')
def on_ua_frame(self, frame):
def on_ua_frame(self, _frame):
if self.state != DLC.CONNECTING:
logger.warn(color('!!! received SABM when not in CONNECTING state', 'red'))
logger.warning(
color('!!! received SABM when not in CONNECTING state', 'red')
)
return
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(
dlci = self.dlci,
fc = 0,
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 1, data = bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.CONNECTED)
self.multiplexer.on_dlc_open_complete(self)
@@ -384,7 +426,7 @@ class DLC(EventEmitter):
# TODO: handle all states
pass
def on_disc_frame(self, frame):
def on_disc_frame(self, _frame):
# TODO: handle all states
self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
@@ -392,21 +434,28 @@ class DLC(EventEmitter):
data = frame.information
if frame.p_f == 1:
# With credits
credits = frame.information[0]
self.tx_credits += credits
received_credits = frame.information[0]
self.tx_credits += received_credits
logger.debug(f'<<< Credits [{self.dlci}]: received {credits}, total={self.tx_credits}')
logger.debug(
f'<<< Credits [{self.dlci}]: '
f'received {credits}, total={self.tx_credits}'
)
data = data[1:]
logger.debug(f'{color("<<< Data", "yellow")} [{self.dlci}] {len(data)} bytes, rx_credits={self.rx_credits}: {data.hex()}')
logger.debug(
f'{color("<<< Data", "yellow")} '
f'[{self.dlci}] {len(data)} bytes, '
f'rx_credits={self.rx_credits}: {data.hex()}'
)
if len(data) and self.sink:
self.sink(data)
self.sink(data) # pylint: disable=not-callable
# Update the credits
if self.rx_credits > 0:
self.rx_credits -= 1
else:
logger.warn(color('!!! received frame with no rx credits', 'red'))
logger.warning(color('!!! received frame with no rx credits', 'red'))
# Check if there's anything to send (including credits)
self.process_tx()
@@ -418,42 +467,26 @@ class DLC(EventEmitter):
if c_r:
# Command
logger.debug(f'<<< MCC MSC Command: {msc}')
msc = RFCOMM_MCC_MSC(
dlci = self.dlci,
fc = 0,
rtc = 1,
rtr = 1,
ic = 0,
dv = 1
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc)
)
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_MSC_TYPE, c_r = 0, data = bytes(msc))
logger.debug(f'>>> MCC MSC Response: {msc}')
self.send_frame(
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
else:
# Response
logger.debug(f'<<< MCC MSC Response: {msc}')
def connect(self):
if not self.state == DLC.INIT:
if self.state != DLC.INIT:
raise InvalidStateError('invalid state')
self.change_state(DLC.CONNECTING)
self.connection_result = asyncio.get_running_loop().create_future()
self.send_frame(
RFCOMM_Frame.sabm(
c_r = self.c_r,
dlci = self.dlci
)
)
self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))
def accept(self):
if not self.state == DLC.INIT:
if self.state != DLC.INIT:
raise InvalidStateError('invalid state')
pn = RFCOMM_MCC_PN(
@@ -463,23 +496,17 @@ class DLC(EventEmitter):
ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions=0,
window_size = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
)
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_PN_TYPE, c_r = 0, data = bytes(pn))
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
logger.debug(f'>>> PN Response: {pn}')
self.send_frame(
RFCOMM_Frame.uih(
c_r = self.c_r,
dlci = 0,
information = mcc
)
)
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.CONNECTING)
def rx_credits_needed(self):
if self.rx_credits <= self.rx_threshold:
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
else:
return 0
def process_tx(self):
@@ -491,7 +518,7 @@ class DLC(EventEmitter):
chunk = bytes([rx_credits_needed]) + self.tx_buffer[: self.mtu - 1]
self.tx_buffer = self.tx_buffer[len(chunk) - 1 :]
self.rx_credits += rx_credits_needed
tx_credit_spent = (len(chunk) > 1)
tx_credit_spent = len(chunk) > 1
else:
chunk = self.tx_buffer[: self.mtu]
self.tx_buffer = self.tx_buffer[len(chunk) :]
@@ -503,13 +530,17 @@ class DLC(EventEmitter):
self.tx_credits -= 1
# Send the frame
logger.debug(f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, rx_credits={self.rx_credits}, tx_credits={self.tx_credits}')
logger.debug(
f'>>> sending {len(chunk)} bytes with {rx_credits_needed} credits, '
f'rx_credits={self.rx_credits}, '
f'tx_credits={self.tx_credits}'
)
self.send_frame(
RFCOMM_Frame.uih(
c_r=self.c_r,
dlci=self.dlci,
information=chunk,
p_f = 1 if rx_credits_needed > 0 else 0
p_f=1 if rx_credits_needed > 0 else 0,
)
)
@@ -518,8 +549,8 @@ class DLC(EventEmitter):
# Stream protocol
def write(self, data):
# We can only send bytes
if type(data) != bytes:
if type(data) == str:
if not isinstance(data, bytes):
if isinstance(data, str):
# Automatically convert strings to bytes using UTF-8
data = data.encode('utf-8')
else:
@@ -558,7 +589,7 @@ class Multiplexer(EventEmitter):
OPENING: 'OPENING',
DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED',
RESET: 'RESET'
RESET: 'RESET',
}
def __init__(self, l2cap_channel, role):
@@ -580,7 +611,9 @@ class Multiplexer(EventEmitter):
return Multiplexer.STATE_NAMES[state]
def change_state(self, new_state):
logger.debug(f'{self} state change -> {color(self.state_name(new_state), "cyan")}')
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
)
self.state = new_state
def send_frame(self, frame):
@@ -596,14 +629,14 @@ class Multiplexer(EventEmitter):
self.on_frame(frame)
else:
if frame.type == RFCOMM_DM_FRAME:
# DM responses are for a DLCI, but since we only create the dlc when we receive
# a PN response (because we need the parameters), we handle DM frames at the Multiplexer
# level
# DM responses are for a DLCI, but since we only create the dlc when we
# receive a PN response (because we need the parameters), we handle DM
# frames at the Multiplexer level
self.on_dm_frame(frame)
else:
dlc = self.dlcs.get(frame.dlci)
if dlc is None:
logger.warn(f'no dlc for DLCI {frame.dlci}')
logger.warning(f'no dlc for DLCI {frame.dlci}')
return
dlc.on_frame(frame)
@@ -611,14 +644,14 @@ class Multiplexer(EventEmitter):
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame)
def on_sabm_frame(self, frame):
def on_sabm_frame(self, _frame):
if self.state != Multiplexer.INIT:
logger.debug('not in INIT state, ignoring SABM')
return
self.change_state(Multiplexer.CONNECTED)
self.send_frame(RFCOMM_Frame.ua(c_r=1, dlci=0))
def on_ua_frame(self, frame):
def on_ua_frame(self, _frame):
if self.state == Multiplexer.CONNECTING:
self.change_state(Multiplexer.CONNECTED)
if self.connection_result:
@@ -630,30 +663,34 @@ class Multiplexer(EventEmitter):
self.disconnection_result.set_result(None)
self.disconnection_result = None
def on_dm_frame(self, frame):
def on_dm_frame(self, _frame):
if self.state == Multiplexer.OPENING:
self.change_state(Multiplexer.CONNECTED)
if self.open_result:
self.open_result.set_exception(ConnectionError(
ConnectionError.CONNECTION_REFUSED,
self.open_result.set_exception(
core.ConnectionError(
core.ConnectionError.CONNECTION_REFUSED,
BT_BR_EDR_TRANSPORT,
self.l2cap_channel.connection.peer_address,
'rfcomm'
))
'rfcomm',
)
)
else:
logger.warn(f'unexpected state for DM: {self}')
logger.warning(f'unexpected state for DM: {self}')
def on_disc_frame(self, frame):
def on_disc_frame(self, _frame):
self.change_state(Multiplexer.DISCONNECTED)
self.send_frame(RFCOMM_Frame.ua(c_r = 0 if self.role == Multiplexer.INITIATOR else 1, dlci = 0))
self.send_frame(
RFCOMM_Frame.ua(c_r=0 if self.role == Multiplexer.INITIATOR else 1, dlci=0)
)
def on_uih_frame(self, frame):
(type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
(mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
if type == RFCOMM_MCC_PN_TYPE:
if mcc_type == RFCOMM_MCC_PN_TYPE:
pn = RFCOMM_MCC_PN.from_bytes(value)
self.on_mcc_pn(c_r, pn)
elif type == RFCOMM_MCC_MSC_TYPE:
elif mcc_type == RFCOMM_MCC_MSC_TYPE:
mcs = RFCOMM_MCC_MSC.from_bytes(value)
self.on_mcc_msc(c_r, mcs)
@@ -669,7 +706,7 @@ class Multiplexer(EventEmitter):
if pn.dlci & 1:
# Not expected, this is an initiator-side number
# TODO: error out
logger.warn(f'invalid DLCI: {pn.dlci}')
logger.warning(f'invalid DLCI: {pn.dlci}')
else:
if self.acceptor:
channel_number = pn.dlci >> 1
@@ -688,7 +725,7 @@ class Multiplexer(EventEmitter):
self.send_frame(RFCOMM_Frame.dm(c_r=1, dlci=pn.dlci))
else:
# No acceptor?? shouldn't happen
logger.warn(color('!!! no acceptor registered', 'red'))
logger.warning(color('!!! no acceptor registered', 'red'))
else:
# Response
logger.debug(f'>>> PN Response: {pn}')
@@ -697,12 +734,12 @@ class Multiplexer(EventEmitter):
self.dlcs[pn.dlci] = dlc
dlc.connect()
else:
logger.warn('ignoring PN response')
logger.warning('ignoring PN response')
def on_mcc_msc(self, c_r, msc):
dlc = self.dlcs.get(msc.dlci)
if dlc is None:
logger.warn(f'no dlc for DLCI {msc.dlci}')
logger.warning(f'no dlc for DLCI {msc.dlci}')
return
dlc.on_mcc_msc(c_r, msc)
@@ -721,14 +758,18 @@ class Multiplexer(EventEmitter):
self.disconnection_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.DISCONNECTING)
self.send_frame(RFCOMM_Frame.disc(c_r = 1 if self.role == Multiplexer.INITIATOR else 0, dlci = 0))
self.send_frame(
RFCOMM_Frame.disc(
c_r=1 if self.role == Multiplexer.INITIATOR else 0, dlci=0
)
)
await self.disconnection_result
async def open_dlc(self, channel):
if self.state != Multiplexer.CONNECTED:
if self.state == Multiplexer.OPENING:
raise InvalidStateError('open already in progress')
else:
raise InvalidStateError('not connected')
pn = RFCOMM_MCC_PN(
@@ -738,9 +779,9 @@ class Multiplexer(EventEmitter):
ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions=0,
window_size = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
)
mcc = RFCOMM_Frame.make_mcc(type = RFCOMM_MCC_PN_TYPE, c_r = 1, data = bytes(pn))
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
logger.debug(f'>>> Sending MCC: {pn}')
self.open_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.OPENING)
@@ -748,7 +789,7 @@ class Multiplexer(EventEmitter):
RFCOMM_Frame.uih(
c_r=1 if self.role == Multiplexer.INITIATOR else 0,
dlci=0,
information = mcc
information=mcc,
)
)
result = await self.open_result
@@ -776,9 +817,11 @@ class Client:
async def start(self):
# Create a new L2CAP connection
try:
self.l2cap_channel = await self.device.l2cap_channel_manager.connect(self.connection, RFCOMM_PSM)
self.l2cap_channel = await self.device.l2cap_channel_manager.connect(
self.connection, RFCOMM_PSM
)
except ProtocolError as error:
logger.warn(f'L2CAP connection failed: {error}')
logger.warning(f'L2CAP connection failed: {error}')
raise
# Create a mutliplexer to manage DLCs with the server
@@ -811,7 +854,9 @@ class Server(EventEmitter):
def listen(self, acceptor):
# Find a free channel number
for channel in range(RFCOMM_DYNAMIC_CHANNEL_NUMBER_START, RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1):
for channel in range(
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START, RFCOMM_DYNAMIC_CHANNEL_NUMBER_END + 1
):
if channel not in self.acceptors:
self.acceptors[channel] = acceptor
return channel

View File

@@ -15,10 +15,12 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import logging
import struct
from colors import color
import colors
from typing import Dict, Type
from . import core
from .core import InvalidStateError
@@ -33,6 +35,9 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
SDP_CONTINUATION_WATCHDOG = 64 # Maximum number of continuations we're willing to do
SDP_PSM = 0x0001
@@ -112,6 +117,10 @@ SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
# To be used in searches where an attribute ID list allows a range to be specified
SDP_ALL_ATTRIBUTES_RANGE = (0x0000FFFF, 4) # Express this as tuple so we can convey the desired encoding size
# fmt: on
# pylint: enable=line-too-long
# pylint: disable=invalid-name
# -----------------------------------------------------------------------------
class DataElement:
@@ -134,27 +143,42 @@ class DataElement:
BOOLEAN: 'BOOLEAN',
SEQUENCE: 'SEQUENCE',
ALTERNATIVE: 'ALTERNATIVE',
URL: 'URL'
URL: 'URL',
}
type_constructors = {
NIL: lambda x: DataElement(DataElement.NIL, None),
UNSIGNED_INTEGER: lambda x, y: DataElement(DataElement.UNSIGNED_INTEGER, DataElement.unsigned_integer_from_bytes(x), value_size=y),
SIGNED_INTEGER: lambda x, y: DataElement(DataElement.SIGNED_INTEGER, DataElement.signed_integer_from_bytes(x), value_size=y),
UUID: lambda x: DataElement(DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))),
UNSIGNED_INTEGER: lambda x, y: DataElement(
DataElement.UNSIGNED_INTEGER,
DataElement.unsigned_integer_from_bytes(x),
value_size=y,
),
SIGNED_INTEGER: lambda x, y: DataElement(
DataElement.SIGNED_INTEGER,
DataElement.signed_integer_from_bytes(x),
value_size=y,
),
UUID: lambda x: DataElement(
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(DataElement.SEQUENCE, DataElement.list_from_bytes(x)),
ALTERNATIVE: lambda x: DataElement(DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)),
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8'))
SEQUENCE: lambda x: DataElement(
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
),
ALTERNATIVE: lambda x: DataElement(
DataElement.ALTERNATIVE, DataElement.list_from_bytes(x)
),
URL: lambda x: DataElement(DataElement.URL, x.decode('utf8')),
}
def __init__(self, type, value, value_size=None):
self.type = type
def __init__(self, element_type, value, value_size=None):
self.type = element_type
self.value = value
self.value_size = value_size
self.bytes = None # Used a cache when parsing from bytes so we can emit a byte-for-byte replica
if type == DataElement.UNSIGNED_INTEGER or type == DataElement.SIGNED_INTEGER:
# Used as a cache when parsing from bytes so we can emit a byte-for-byte replica
self.bytes = None
if element_type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
if value_size is None:
raise ValueError('integer types must have a value size specified')
@@ -222,26 +246,32 @@ class DataElement:
def unsigned_integer_from_bytes(data):
if len(data) == 1:
return data[0]
elif len(data) == 2:
if len(data) == 2:
return struct.unpack('>H', data)[0]
elif len(data) == 4:
if len(data) == 4:
return struct.unpack('>I', data)[0]
elif len(data) == 8:
if len(data) == 8:
return struct.unpack('>Q', data)[0]
else:
raise ValueError(f'invalid integer length {len(data)}')
@staticmethod
def signed_integer_from_bytes(data):
if len(data) == 1:
return struct.unpack('b', data)[0]
elif len(data) == 2:
if len(data) == 2:
return struct.unpack('>h', data)[0]
elif len(data) == 4:
if len(data) == 4:
return struct.unpack('>i', data)[0]
elif len(data) == 8:
if len(data) == 8:
return struct.unpack('>q', data)[0]
else:
raise ValueError(f'invalid integer length {len(data)}')
@staticmethod
@@ -260,11 +290,11 @@ class DataElement:
@staticmethod
def from_bytes(data):
type = data[0] >> 3
element_type = data[0] >> 3
size_index = data[0] & 7
value_offset = 0
if size_index == 0:
if type == DataElement.NIL:
if element_type == DataElement.NIL:
value_size = 0
else:
value_size = 1
@@ -287,15 +317,20 @@ class DataElement:
value_offset = 4
value_data = data[1 + value_offset : 1 + value_offset + value_size]
constructor = DataElement.type_constructors.get(type)
constructor = DataElement.type_constructors.get(element_type)
if constructor:
if type == DataElement.UNSIGNED_INTEGER or type == DataElement.SIGNED_INTEGER:
if element_type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
):
result = constructor(value_data, value_size)
else:
result = constructor(value_data)
else:
result = DataElement(type, value_data)
result.bytes = data[:1 + value_offset + value_size] # Keep a copy so we can re-serialize to an exact replica
result = DataElement(element_type, value_data)
result.bytes = data[
: 1 + value_offset + value_size
] # Keep a copy so we can re-serialize to an exact replica
return result
def to_bytes(self):
@@ -311,7 +346,8 @@ class DataElement:
elif self.type == DataElement.UNSIGNED_INTEGER:
if self.value < 0:
raise ValueError('UNSIGNED_INTEGER cannot be negative')
elif self.value_size == 1:
if self.value_size == 1:
data = struct.pack('B', self.value)
elif self.value_size == 2:
data = struct.pack('>H', self.value)
@@ -334,11 +370,11 @@ class DataElement:
raise ValueError('invalid value_size')
elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value)))
elif self.type == DataElement.TEXT_STRING or self.type == DataElement.URL:
elif self.type in (DataElement.TEXT_STRING, DataElement.URL):
data = self.value.encode('utf8')
elif self.type == DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
elif self.type == DataElement.SEQUENCE or self.type == DataElement.ALTERNATIVE:
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
data = b''.join([bytes(element) for element in self.value])
else:
data = self.value
@@ -349,9 +385,11 @@ class DataElement:
if size != 0:
raise ValueError('NIL must be empty')
size_index = 0
elif (self.type == DataElement.UNSIGNED_INTEGER or
self.type == DataElement.SIGNED_INTEGER or
self.type == DataElement.UUID):
elif self.type in (
DataElement.UNSIGNED_INTEGER,
DataElement.SIGNED_INTEGER,
DataElement.UUID,
):
if size <= 1:
size_index = 0
elif size == 2:
@@ -364,10 +402,12 @@ class DataElement:
size_index = 4
else:
raise ValueError('invalid data size')
elif (self.type == DataElement.TEXT_STRING or
self.type == DataElement.SEQUENCE or
self.type == DataElement.ALTERNATIVE or
self.type == DataElement.URL):
elif self.type in (
DataElement.TEXT_STRING,
DataElement.SEQUENCE,
DataElement.ALTERNATIVE,
DataElement.URL,
):
if size <= 0xFF:
size_index = 5
size_bytes = bytes([size])
@@ -392,11 +432,19 @@ class DataElement:
type_name = name_or_number(self.TYPE_NAMES, self.type)
if self.type == DataElement.NIL:
value_string = ''
elif self.type == DataElement.SEQUENCE or self.type == DataElement.ALTERNATIVE:
elif self.type in (DataElement.SEQUENCE, DataElement.ALTERNATIVE):
container_separator = '\n' if pretty else ''
element_separator = '\n' if pretty else ','
value_string = f'[{container_separator}{element_separator.join([element.to_string(pretty, indentation + 1 if pretty else 0) for element in self.value])}{container_separator}{prefix}]'
elif self.type == DataElement.UNSIGNED_INTEGER or self.type == DataElement.SIGNED_INTEGER:
elements = [
element.to_string(pretty, indentation + 1 if pretty else 0)
for element in self.value
]
value_string = (
f'[{container_separator}'
f'{element_separator.join(elements)}'
f'{container_separator}{prefix}]'
)
elif self.type in (DataElement.UNSIGNED_INTEGER, DataElement.SIGNED_INTEGER):
value_string = f'{self.value}#{self.value_size}'
elif isinstance(self.value, DataElement):
value_string = self.value.to_string(pretty, indentation)
@@ -410,8 +458,8 @@ class DataElement:
# -----------------------------------------------------------------------------
class ServiceAttribute:
def __init__(self, id, value):
self.id = id
def __init__(self, attribute_id, value):
self.id = attribute_id
self.value = value
@staticmethod
@@ -420,7 +468,7 @@ class ServiceAttribute:
for i in range(0, len(elements) // 2):
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
if attribute_id.type != DataElement.UNSIGNED_INTEGER:
logger.warn('attribute ID element is not an integer')
logger.warning('attribute ID element is not an integer')
continue
attribute_list.append(ServiceAttribute(attribute_id.value, attribute_value))
@@ -428,29 +476,40 @@ class ServiceAttribute:
@staticmethod
def find_attribute_in_list(attribute_list, attribute_id):
return next((attribute.value for attribute in attribute_list if attribute.id == attribute_id), None)
return next(
(
attribute.value
for attribute in attribute_list
if attribute.id == attribute_id
),
None,
)
@staticmethod
def id_name(id):
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id)
def id_name(id_code):
return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)
@staticmethod
def is_uuid_in_value(uuid, value):
# Find if a uuid matches a value, either directly or recursing into sequences
if value.type == DataElement.UUID:
return value.value == uuid
elif value.type == DataElement.SEQUENCE:
if value.type == DataElement.SEQUENCE:
for element in value.value:
if ServiceAttribute.is_uuid_in_value(uuid, element):
return True
return False
else:
return False
def to_string(self, color=False):
if color:
return f'Attribute(id={colors.color(self.id_name(self.id),"magenta")},value={self.value})'
else:
def to_string(self, with_colors=False):
if with_colors:
return (
f'Attribute(id={colors.color(self.id_name(self.id),"magenta")},'
f'value={self.value})'
)
return f'Attribute(id={self.id_name(self.id)},value={self.value})'
def __str__(self):
@@ -462,11 +521,14 @@ class SDP_PDU:
'''
See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
'''
sdp_pdu_classes = {}
sdp_pdu_classes: Dict[int, Type[SDP_PDU]] = {}
name = None
pdu_id = 0
@staticmethod
def from_bytes(pdu):
pdu_id, transaction_id, parameters_length = struct.unpack_from('>BHH', pdu, 0)
pdu_id, transaction_id, _parameters_length = struct.unpack_from('>BHH', pdu, 0)
cls = SDP_PDU.sdp_pdu_classes.get(pdu_id)
if cls is None:
@@ -484,7 +546,9 @@ class SDP_PDU:
@staticmethod
def parse_service_record_handle_list_preceded_by_count(data, offset):
count = struct.unpack_from('>H', data, offset - 2)[0]
handle_list = [struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)]
handle_list = [
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
]
return offset + count * 4, handle_list
@staticmethod
@@ -532,7 +596,10 @@ class SDP_PDU:
HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
parameters = HCI_Object.dict_to_bytes(kwargs, self.fields)
pdu = struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters)) + parameters
pdu = (
struct.pack('>BHH', self.pdu_id, transaction_id, len(parameters))
+ parameters
)
self.pdu = pdu
self.transaction_id = transaction_id
@@ -555,9 +622,7 @@ class SDP_PDU:
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([
('error_code', {'size': 2, 'mapper': SDP_PDU.error_name})
])
@SDP_PDU.subclass([('error_code', {'size': 2, 'mapper': SDP_PDU.error_name})])
class SDP_ErrorResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
@@ -565,11 +630,13 @@ class SDP_ErrorResponse(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([
@SDP_PDU.subclass(
[
('service_search_pattern', DataElement.parse_from_bytes),
('maximum_service_record_count', '>2'),
('continuation_state', '*')
])
('continuation_state', '*'),
]
)
class SDP_ServiceSearchRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
@@ -577,12 +644,17 @@ class SDP_ServiceSearchRequest(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([
@SDP_PDU.subclass(
[
('total_service_record_count', '>2'),
('current_service_record_count', '>2'),
('service_record_handle_list', SDP_PDU.parse_service_record_handle_list_preceded_by_count),
('continuation_state', '*')
])
(
'service_record_handle_list',
SDP_PDU.parse_service_record_handle_list_preceded_by_count,
),
('continuation_state', '*'),
]
)
class SDP_ServiceSearchResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
@@ -590,12 +662,14 @@ class SDP_ServiceSearchResponse(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([
@SDP_PDU.subclass(
[
('service_record_handle', '>4'),
('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*')
])
('continuation_state', '*'),
]
)
class SDP_ServiceAttributeRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
@@ -603,11 +677,13 @@ class SDP_ServiceAttributeRequest(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([
@SDP_PDU.subclass(
[
('attribute_list_byte_count', '>2'),
('attribute_list', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*')
])
('continuation_state', '*'),
]
)
class SDP_ServiceAttributeResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
@@ -615,12 +691,14 @@ class SDP_ServiceAttributeResponse(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([
@SDP_PDU.subclass(
[
('service_search_pattern', DataElement.parse_from_bytes),
('maximum_attribute_byte_count', '>2'),
('attribute_id_list', DataElement.parse_from_bytes),
('continuation_state', '*')
])
('continuation_state', '*'),
]
)
class SDP_ServiceSearchAttributeRequest(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
@@ -628,11 +706,13 @@ class SDP_ServiceSearchAttributeRequest(SDP_PDU):
# -----------------------------------------------------------------------------
@SDP_PDU.subclass([
@SDP_PDU.subclass(
[
('attribute_lists_byte_count', '>2'),
('attribute_lists', SDP_PDU.parse_bytes_preceded_by_length),
('continuation_state', '*')
])
('continuation_state', '*'),
]
)
class SDP_ServiceSearchAttributeResponse(SDP_PDU):
'''
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
@@ -659,7 +739,9 @@ class Client:
if self.pending_request is not None:
raise InvalidStateError('request already pending')
service_search_pattern = DataElement.sequence([DataElement.uuid(uuid) for uuid in uuids])
service_search_pattern = DataElement.sequence(
[DataElement.uuid(uuid) for uuid in uuids]
)
# Request and accumulate until there's no more continuation
service_record_handle_list = []
@@ -671,7 +753,7 @@ class Client:
transaction_id=0, # Transaction ID TODO: pick a real value
service_search_pattern=service_search_pattern,
maximum_service_record_count=0xFFFF,
continuation_state = continuation_state
continuation_state=continuation_state,
)
)
response = SDP_PDU.from_bytes(response_pdu)
@@ -689,11 +771,15 @@ class Client:
if self.pending_request is not None:
raise InvalidStateError('request already pending')
service_search_pattern = DataElement.sequence([DataElement.uuid(uuid) for uuid in uuids])
service_search_pattern = DataElement.sequence(
[DataElement.uuid(uuid) for uuid in uuids]
)
attribute_id_list = DataElement.sequence(
[
DataElement.unsigned_integer(attribute_id[0], value_size=attribute_id[1])
if type(attribute_id) is tuple
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
)
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids
]
@@ -710,7 +796,7 @@ class Client:
service_search_pattern=service_search_pattern,
maximum_attribute_byte_count=0xFFFF,
attribute_id_list=attribute_id_list,
continuation_state = continuation_state
continuation_state=continuation_state,
)
)
response = SDP_PDU.from_bytes(response_pdu)
@@ -725,7 +811,7 @@ class Client:
# Parse the result into attribute lists
attribute_lists_sequences = DataElement.from_bytes(accumulator)
if attribute_lists_sequences.type != DataElement.SEQUENCE:
logger.warn('unexpected data type')
logger.warning('unexpected data type')
return []
return [
@@ -740,8 +826,10 @@ class Client:
attribute_id_list = DataElement.sequence(
[
DataElement.unsigned_integer(attribute_id[0], value_size=attribute_id[1])
if type(attribute_id) is tuple
DataElement.unsigned_integer(
attribute_id[0], value_size=attribute_id[1]
)
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids
]
@@ -758,7 +846,7 @@ class Client:
service_record_handle=service_record_handle,
maximum_attribute_byte_count=0xFFFF,
attribute_id_list=attribute_id_list,
continuation_state = continuation_state
continuation_state=continuation_state,
)
)
response = SDP_PDU.from_bytes(response_pdu)
@@ -773,7 +861,7 @@ class Client:
# Parse the result into a list of attributes
attribute_list_sequence = DataElement.from_bytes(accumulator)
if attribute_list_sequence.type != DataElement.SEQUENCE:
logger.warn('unexpected data type')
logger.warning('unexpected data type')
return []
return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
@@ -786,6 +874,7 @@ class Server:
def __init__(self, device):
self.device = device
self.service_records = {} # Service records maps, by record handle
self.channel = None
self.current_response = None
def register(self, l2cap_channel_manager):
@@ -820,11 +909,10 @@ class Server:
try:
sdp_pdu = SDP_PDU.from_bytes(pdu)
except Exception as error:
logger.warn(color(f'failed to parse SDP Request PDU: {error}', 'red'))
logger.warning(color(f'failed to parse SDP Request PDU: {error}', 'red'))
self.send_response(
SDP_ErrorResponse(
transaction_id = 0,
error_code = SDP_INVALID_REQUEST_SYNTAX_ERROR
transaction_id=0, error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR
)
)
@@ -841,7 +929,7 @@ class Server:
self.send_response(
SDP_ErrorResponse(
transaction_id=sdp_pdu.transaction_id,
error_code = SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR
error_code=SDP_INSUFFICIENT_RESOURCES_TO_SATISFY_REQUEST_ERROR,
)
)
else:
@@ -849,7 +937,7 @@ class Server:
self.send_response(
SDP_ErrorResponse(
transaction_id=sdp_pdu.transaction_id,
error_code = SDP_INVALID_REQUEST_SYNTAX_ERROR
error_code=SDP_INVALID_REQUEST_SYNTAX_ERROR,
)
)
@@ -877,11 +965,12 @@ class Server:
id_range_start = attribute_id.value
id_range_end = attribute_id.value
attributes += [
attribute for attribute in service
attribute
for attribute in service
if attribute.id >= id_range_start and attribute.id <= id_range_end
]
# Return the maching attributes, sorted by attribute id
# Return the matching attributes, sorted by attribute id
attributes.sort(key=lambda x: x.id)
attribute_list = DataElement.sequence([])
for attribute in attributes:
@@ -897,7 +986,7 @@ class Server:
self.send_response(
SDP_ErrorResponse(
transaction_id=request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
)
)
return
@@ -910,30 +999,38 @@ class Server:
service_record_handles = list(matching_services.keys())
# Only return up to the maximum requested
service_record_handles_subset = service_record_handles[:request.maximum_service_record_count]
service_record_handles_subset = service_record_handles[
: request.maximum_service_record_count
]
# Serialize to a byte array, and remember the total count
logger.debug(f'Service Record Handles: {service_record_handles}')
self.current_response = (
len(service_record_handles),
service_record_handles_subset
service_record_handles_subset,
)
# Respond, keeping any unsent handles for later
service_record_handles = self.current_response[1][:request.maximum_service_record_count]
service_record_handles = self.current_response[1][
: request.maximum_service_record_count
]
self.current_response = (
self.current_response[0],
self.current_response[1][request.maximum_service_record_count:]
self.current_response[1][request.maximum_service_record_count :],
)
continuation_state = (
Server.CONTINUATION_STATE if self.current_response[1] else bytes([0])
)
service_record_handle_list = b''.join(
[struct.pack('>I', handle) for handle in service_record_handles]
)
continuation_state = Server.CONTINUATION_STATE if self.current_response[1] else bytes([0])
service_record_handle_list = b''.join([struct.pack('>I', handle) for handle in service_record_handles])
self.send_response(
SDP_ServiceSearchResponse(
transaction_id=request.transaction_id,
total_service_record_count=self.current_response[0],
current_service_record_count=len(service_record_handles),
service_record_handle_list=service_record_handle_list,
continuation_state = continuation_state
continuation_state=continuation_state,
)
)
@@ -944,7 +1041,7 @@ class Server:
self.send_response(
SDP_ErrorResponse(
transaction_id=request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
)
)
return
@@ -958,26 +1055,30 @@ class Server:
self.send_response(
SDP_ErrorResponse(
transaction_id=request.transaction_id,
error_code = SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR
error_code=SDP_INVALID_SERVICE_RECORD_HANDLE_ERROR,
)
)
return
# Get the attributes for the service
attribute_list = Server.get_service_attributes(service, request.attribute_id_list.value)
attribute_list = Server.get_service_attributes(
service, request.attribute_id_list.value
)
# Serialize to a byte array
logger.debug(f'Attributes: {attribute_list}')
self.current_response = bytes(attribute_list)
# Respond, keeping any pending chunks for later
attribute_list, continuation_state = self.get_next_response_payload(request.maximum_attribute_byte_count)
attribute_list, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count
)
self.send_response(
SDP_ServiceAttributeResponse(
transaction_id=request.transaction_id,
attribute_list_byte_count=len(attribute_list),
attribute_list=attribute_list,
continuation_state = continuation_state
continuation_state=continuation_state,
)
)
@@ -988,7 +1089,7 @@ class Server:
self.send_response(
SDP_ErrorResponse(
transaction_id=request.transaction_id,
error_code = SDP_INVALID_CONTINUATION_STATE_ERROR
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
)
)
else:
@@ -996,12 +1097,16 @@ class Server:
self.current_response = None
# Find the matching services
matching_services = self.match_services(request.service_search_pattern).values()
matching_services = self.match_services(
request.service_search_pattern
).values()
# Filter the required attributes
attribute_lists = DataElement.sequence([])
for service in matching_services:
attribute_list = Server.get_service_attributes(service, request.attribute_id_list.value)
attribute_list = Server.get_service_attributes(
service, request.attribute_id_list.value
)
if attribute_list.value:
attribute_lists.value.append(attribute_list)
@@ -1010,12 +1115,14 @@ class Server:
self.current_response = bytes(attribute_lists)
# Respond, keeping any pending chunks for later
attribute_lists, continuation_state = self.get_next_response_payload(request.maximum_attribute_byte_count)
attribute_lists, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count
)
self.send_response(
SDP_ServiceSearchAttributeResponse(
transaction_id=request.transaction_id,
attribute_lists_byte_count=len(attribute_lists),
attribute_lists=attribute_lists,
continuation_state = continuation_state
continuation_state=continuation_state,
)
)

File diff suppressed because it is too large Load Diff

View File

@@ -35,47 +35,75 @@ async def open_transport(name):
Where <parameters> depend on the type (and may be empty for some types).
The supported types are: serial,udp,tcp,pty,usb
'''
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements
scheme, *spec = name.split(':', 1)
if scheme == 'serial' and spec:
from .serial import open_serial_transport
return await open_serial_transport(spec[0])
elif scheme == 'udp' and spec:
if scheme == 'udp' and spec:
from .udp import open_udp_transport
return await open_udp_transport(spec[0])
elif scheme == 'tcp-client' and spec:
if scheme == 'tcp-client' and spec:
from .tcp_client import open_tcp_client_transport
return await open_tcp_client_transport(spec[0])
elif scheme == 'tcp-server' and spec:
if scheme == 'tcp-server' and spec:
from .tcp_server import open_tcp_server_transport
return await open_tcp_server_transport(spec[0])
elif scheme == 'ws-client' and spec:
if scheme == 'ws-client' and spec:
from .ws_client import open_ws_client_transport
return await open_ws_client_transport(spec[0])
elif scheme == 'ws-server' and spec:
if scheme == 'ws-server' and spec:
from .ws_server import open_ws_server_transport
return await open_ws_server_transport(spec[0])
elif scheme == 'pty':
if scheme == 'pty':
from .pty import open_pty_transport
return await open_pty_transport(spec[0] if spec else None)
elif scheme == 'file':
if scheme == 'file':
from .file import open_file_transport
return await open_file_transport(spec[0] if spec else None)
elif scheme == 'vhci':
if scheme == 'vhci':
from .vhci import open_vhci_transport
return await open_vhci_transport(spec[0] if spec else None)
elif scheme == 'hci-socket':
if scheme == 'hci-socket':
from .hci_socket import open_hci_socket_transport
return await open_hci_socket_transport(spec[0] if spec else None)
elif scheme == 'usb':
if scheme == 'usb':
from .usb import open_usb_transport
return await open_usb_transport(spec[0] if spec else None)
elif scheme == 'pyusb':
if scheme == 'pyusb':
from .pyusb import open_pyusb_transport
return await open_pyusb_transport(spec[0] if spec else None)
elif scheme == 'android-emulator':
if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport
return await open_android_emulator_transport(spec[0] if spec else None)
else:
raise ValueError('unknown transport scheme')
@@ -91,5 +119,5 @@ async def open_transport_or_link(name):
link.close()
return LinkTransport(controller, AsyncPipeSink(controller))
else:
return await open_transport(name)

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,9 +20,11 @@ import grpc
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink
from .emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
from .emulated_bluetooth_packets_pb2 import HCIPacket
from .emulated_bluetooth_vhci_pb2_grpc import VhciForwardingServiceStub
# pylint: disable-next=no-name-in-module
from .emulated_bluetooth_packets_pb2 import HCIPacket
# -----------------------------------------------------------------------------
# Logging
@@ -59,12 +61,7 @@ async def open_android_emulator_transport(spec):
return bytes([packet.type]) + packet.packet
async def write(self, packet):
await self.hci_device.write(
HCIPacket(
type = packet[0],
packet = packet[1:]
)
)
await self.hci_device.write(HCIPacket(type=packet[0], packet=packet[1:]))
# Parse the parameters
mode = 'host'
@@ -100,7 +97,7 @@ async def open_android_emulator_transport(spec):
transport = PumpedTransport(
PumpedPacketSource(hci_device.read),
PumpedPacketSink(hci_device.write),
channel.close
channel.close,
)
transport.start()

View File

@@ -36,7 +36,7 @@ HCI_PACKET_INFO = {
hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
hci.HCI_EVENT_PACKET: (1, 1, 'B')
hci.HCI_EVENT_PACKET: (1, 1, 'B'),
}
@@ -65,8 +65,12 @@ class PacketPump:
# -----------------------------------------------------------------------------
class PacketParser:
'''
In-line parser that accepts data and emits 'on_packet' when a full packet has been parsed
In-line parser that accepts data and emits 'on_packet' when a full packet has been
parsed
'''
# pylint: disable=attribute-defined-outside-init
NEED_TYPE = 0
NEED_LENGTH = 1
NEED_BODY = 2
@@ -95,13 +99,17 @@ class PacketParser:
if self.bytes_needed == 0:
if self.state == PacketParser.NEED_TYPE:
packet_type = self.packet[0]
self.packet_info = HCI_PACKET_INFO.get(packet_type) or self.extended_packet_info.get(packet_type)
self.packet_info = HCI_PACKET_INFO.get(
packet_type
) or self.extended_packet_info.get(packet_type)
if self.packet_info is None:
raise ValueError(f'invalid packet type {packet_type}')
self.state = PacketParser.NEED_LENGTH
self.bytes_needed = self.packet_info[0] + self.packet_info[1]
elif self.state == PacketParser.NEED_LENGTH:
body_length = struct.unpack_from(self.packet_info[2], self.packet, 1 + self.packet_info[1])[0]
body_length = struct.unpack_from(
self.packet_info[2], self.packet, 1 + self.packet_info[1]
)[0]
self.bytes_needed = body_length
self.state = PacketParser.NEED_BODY
@@ -111,7 +119,9 @@ class PacketParser:
try:
self.sink.on_packet(bytes(self.packet))
except Exception as error:
logger.warning(color(f'!!! Exception in on_packet: {error}', 'red'))
logger.warning(
color(f'!!! Exception in on_packet: {error}', 'red')
)
self.reset()
def set_packet_sink(self, sink):
@@ -187,6 +197,7 @@ class AsyncPipeSink:
'''
Sink that forwards packets asynchronously to another sink
'''
def __init__(self, sink):
self.sink = sink
self.loop = asyncio.get_running_loop()
@@ -270,7 +281,7 @@ class PumpedPacketSource(ParserSource):
logger.debug('source pump task done')
break
except Exception as error:
logger.warn(f'exception while waiting for packet: {error}')
logger.warning(f'exception while waiting for packet: {error}')
self.terminated.set_result(error)
break
@@ -301,7 +312,7 @@ class PumpedPacketSink:
logger.debug('sink pump task done')
break
except Exception as error:
logger.warn(f'exception while sending packet: {error}')
logger.warning(f'exception while sending packet: {error}')
break
self.pump_task = asyncio.create_task(pump_packets())

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,31 +16,24 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: emulated_bluetooth_packets.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3'
)
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3')
_HCIPACKET = DESCRIPTOR.message_types_by_name['HCIPacket']
_HCIPACKET_PACKETTYPE = _HCIPACKET.enum_types_by_name['PacketType']
HCIPacket = _reflection.GeneratedProtocolMessageType('HCIPacket', (_message.Message,), {
'DESCRIPTOR' : _HCIPACKET,
'__module__' : 'emulated_bluetooth_packets_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.HCIPacket)
})
_sym_db.RegisterMessage(HCIPacket)
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, 'emulated_bluetooth_packets_pb2', globals()
)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None

View File

@@ -0,0 +1,41 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class HCIPacket(_message.Message):
__slots__ = ["packet", "type"]
class PacketType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = []
PACKET_FIELD_NUMBER: _ClassVar[int]
PACKET_TYPE_ACL: HCIPacket.PacketType
PACKET_TYPE_EVENT: HCIPacket.PacketType
PACKET_TYPE_HCI_COMMAND: HCIPacket.PacketType
PACKET_TYPE_ISO: HCIPacket.PacketType
PACKET_TYPE_SCO: HCIPacket.PacketType
PACKET_TYPE_UNSPECIFIED: HCIPacket.PacketType
TYPE_FIELD_NUMBER: _ClassVar[int]
packet: bytes
type: HCIPacket.PacketType
def __init__(
self,
type: _Optional[_Union[HCIPacket.PacketType, str]] = ...,
packet: _Optional[bytes] = ...,
) -> None: ...

View File

@@ -0,0 +1,17 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,11 +16,11 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: emulated_bluetooth.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
@@ -29,19 +29,12 @@ _sym_db = _symbol_database.Default()
from . import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x65mulated_bluetooth.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto\"\x19\n\x07RawData\x12\x0e\n\x06packet\x18\x01 \x01(\x0c\x32\xcb\x02\n\x18\x45mulatedBluetoothService\x12\x64\n\x12registerClassicPhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12`\n\x0eregisterBlePhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12g\n\x11registerHCIDevice\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42\"\n\x1e\x63om.android.emulator.bluetoothP\x01\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x18\x65mulated_bluetooth.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto\"\x19\n\x07RawData\x12\x0e\n\x06packet\x18\x01 \x01(\x0c\x32\xcb\x02\n\x18\x45mulatedBluetoothService\x12\x64\n\x12registerClassicPhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12`\n\x0eregisterBlePhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12g\n\x11registerHCIDevice\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42\"\n\x1e\x63om.android.emulator.bluetoothP\x01\x62\x06proto3'
)
_RAWDATA = DESCRIPTOR.message_types_by_name['RawData']
RawData = _reflection.GeneratedProtocolMessageType('RawData', (_message.Message,), {
'DESCRIPTOR' : _RAWDATA,
'__module__' : 'emulated_bluetooth_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.RawData)
})
_sym_db.RegisterMessage(RawData)
_EMULATEDBLUETOOTHSERVICE = DESCRIPTOR.services_by_name['EmulatedBluetoothService']
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'emulated_bluetooth_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None

View File

@@ -0,0 +1,26 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import emulated_bluetooth_packets_pb2 as _emulated_bluetooth_packets_pb2
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional
DESCRIPTOR: _descriptor.FileDescriptor
class RawData(_message.Message):
__slots__ = ["packet"]
PACKET_FIELD_NUMBER: _ClassVar[int]
packet: bytes
def __init__(self, packet: _Optional[bytes] = ...) -> None: ...

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -138,7 +138,8 @@ def add_EmulatedBluetoothServiceServicer_to_server(servicer, server):
),
}
generic_handler = grpc.method_handlers_generic_handler(
'android.emulation.bluetooth.EmulatedBluetoothService', rpc_method_handlers)
'android.emulation.bluetooth.EmulatedBluetoothService', rpc_method_handlers
)
server.add_generic_rpc_handlers((generic_handler,))
@@ -156,7 +157,8 @@ class EmulatedBluetoothService(object):
"""
@staticmethod
def registerClassicPhy(request_iterator,
def registerClassicPhy(
request_iterator,
target,
options=(),
channel_credentials=None,
@@ -165,15 +167,27 @@ class EmulatedBluetoothService(object):
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy',
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
target,
'/android.emulation.bluetooth.EmulatedBluetoothService/registerClassicPhy',
emulated__bluetooth__pb2.RawData.SerializeToString,
emulated__bluetooth__pb2.RawData.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
@staticmethod
def registerBlePhy(request_iterator,
def registerBlePhy(
request_iterator,
target,
options=(),
channel_credentials=None,
@@ -182,15 +196,27 @@ class EmulatedBluetoothService(object):
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy',
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
target,
'/android.emulation.bluetooth.EmulatedBluetoothService/registerBlePhy',
emulated__bluetooth__pb2.RawData.SerializeToString,
emulated__bluetooth__pb2.RawData.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
@staticmethod
def registerHCIDevice(request_iterator,
def registerHCIDevice(
request_iterator,
target,
options=(),
channel_credentials=None,
@@ -199,9 +225,20 @@ class EmulatedBluetoothService(object):
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice',
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
target,
'/android.emulation.bluetooth.EmulatedBluetoothService/registerHCIDevice',
emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
emulated__bluetooth__packets__pb2.HCIPacket.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,24 +16,27 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: emulated_bluetooth_vhci.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
from . import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1d\x65mulated_bluetooth_vhci.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto2y\n\x15VhciForwardingService\x12`\n\nattachVhci\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\x1d\x65mulated_bluetooth_vhci.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto2y\n\x15VhciForwardingService\x12`\n\nattachVhci\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3'
)
_VHCIFORWARDINGSERVICE = DESCRIPTOR.services_by_name['VhciForwardingService']
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, 'emulated_bluetooth_vhci_pb2', globals()
)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None

View File

@@ -0,0 +1,19 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import emulated_bluetooth_packets_pb2 as _emulated_bluetooth_packets_pb2
from google.protobuf import descriptor as _descriptor
from typing import ClassVar as _ClassVar
DESCRIPTOR: _descriptor.FileDescriptor

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -82,7 +82,8 @@ def add_VhciForwardingServiceServicer_to_server(servicer, server):
),
}
generic_handler = grpc.method_handlers_generic_handler(
'android.emulation.bluetooth.VhciForwardingService', rpc_method_handlers)
'android.emulation.bluetooth.VhciForwardingService', rpc_method_handlers
)
server.add_generic_rpc_handlers((generic_handler,))
@@ -97,7 +98,8 @@ class VhciForwardingService(object):
"""
@staticmethod
def attachVhci(request_iterator,
def attachVhci(
request_iterator,
target,
options=(),
channel_credentials=None,
@@ -106,9 +108,20 @@ class VhciForwardingService(object):
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_stream(request_iterator, target, '/android.emulation.bluetooth.VhciForwardingService/attachVhci',
metadata=None,
):
return grpc.experimental.stream_stream(
request_iterator,
target,
'/android.emulation.bluetooth.VhciForwardingService/attachVhci',
emulated__bluetooth__packets__pb2.HCIPacket.SerializeToString,
emulated__bluetooth__packets__pb2.HCIPacket.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)

View File

@@ -30,8 +30,9 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def open_file_transport(spec):
'''
Open a File transport (typically not for a real file, but for a PTY or other unix virtual files).
The parameter string is the path of the file to open
Open a File transport (typically not for a real file, but for a PTY or other unix
virtual files).
The parameter string is the path of the file to open.
'''
# Open the file
@@ -39,14 +40,12 @@ async def open_file_transport(spec):
# Setup reading
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(
lambda: StreamPacketSource(),
file
StreamPacketSource, file
)
# Setup writing
write_transport, _ = await asyncio.get_running_loop().connect_write_pipe(
lambda: asyncio.BaseProtocol(),
file
asyncio.BaseProtocol, file
)
packet_sink = StreamPacketSink(write_transport)
@@ -57,4 +56,3 @@ async def open_file_transport(spec):
file.close()
return FileTransport(packet_source, packet_sink)

View File

@@ -40,15 +40,21 @@ async def open_hci_socket_transport(spec):
or a 0-based integer to indicate the adapter number.
'''
HCI_CHANNEL_USER = 1
HCI_CHANNEL_USER = 1 # pylint: disable=invalid-name
# Create a raw HCI socket
try:
hci_socket = socket.socket(socket.AF_BLUETOOTH, socket.SOCK_RAW | socket.SOCK_NONBLOCK, socket.BTPROTO_HCI)
except AttributeError:
hci_socket = socket.socket(
socket.AF_BLUETOOTH,
socket.SOCK_RAW | socket.SOCK_NONBLOCK,
socket.BTPROTO_HCI,
)
except AttributeError as error:
# Not supported on this platform
logger.info("HCI sockets not supported on this platform")
raise Exception('Bluetooth HCI sockets not supported on this platform')
raise Exception(
'Bluetooth HCI sockets not supported on this platform'
) from error
# Compute the adapter index
if spec is None:
@@ -62,20 +68,37 @@ async def open_hci_socket_transport(spec):
try:
ctypes.cdll.LoadLibrary('libc.so.6')
libc = ctypes.CDLL('libc.so.6', use_errno=True)
except OSError:
except OSError as error:
logger.info("HCI sockets not supported on this platform")
raise Exception('Bluetooth HCI sockets not supported on this platform')
raise Exception(
'Bluetooth HCI sockets not supported on this platform'
) from error
libc.bind.argtypes = (ctypes.c_int, ctypes.POINTER(ctypes.c_char), ctypes.c_int)
libc.bind.restype = ctypes.c_int
bind_address = struct.pack('<HHH', socket.AF_BLUETOOTH, adapter_index, HCI_CHANNEL_USER)
if libc.bind(hci_socket.fileno(), ctypes.create_string_buffer(bind_address), len(bind_address)) != 0:
bind_address = struct.pack(
# pylint: disable=no-member
'<HHH',
socket.AF_BLUETOOTH,
adapter_index,
HCI_CHANNEL_USER,
)
if (
libc.bind(
hci_socket.fileno(),
ctypes.create_string_buffer(bind_address),
len(bind_address),
)
!= 0
):
raise IOError(ctypes.get_errno(), os.strerror(ctypes.get_errno()))
class HciSocketSource(ParserSource):
def __init__(self, socket):
def __init__(self, hci_socket):
super().__init__()
self.socket = socket
asyncio.get_running_loop().add_reader(socket.fileno(), self.recv_until_would_block)
self.socket = hci_socket
asyncio.get_running_loop().add_reader(
self.socket.fileno(), self.recv_until_would_block
)
def recv_until_would_block(self):
logger.debug('recv until would block +++')
@@ -92,8 +115,8 @@ async def open_hci_socket_transport(spec):
asyncio.get_running_loop().remove_reader(self.socket.fileno())
class HciSocketSink:
def __init__(self, socket):
self.socket = socket
def __init__(self, hci_socket):
self.socket = hci_socket
self.packets = collections.deque()
self.writer_added = False
@@ -112,9 +135,14 @@ async def open_hci_socket_transport(spec):
break
if self.packets:
# There's still something to send, ensure that we are monitoring the socket
# There's still something to send, ensure that we are monitoring the
# socket
if not self.writer_added:
asyncio.get_running_loop().add_writer(socket.fileno(), self.send_until_would_block)
asyncio.get_running_loop().add_writer(
# pylint: disable=no-member
self.socket.fileno(),
self.send_until_would_block,
)
self.writer_added = True
else:
# Nothing left to send, stop monitoring the socket
@@ -131,9 +159,9 @@ async def open_hci_socket_transport(spec):
asyncio.get_running_loop().remove_writer(self.socket.fileno())
class HciSocketTransport(Transport):
def __init__(self, socket, source, sink):
def __init__(self, hci_socket, source, sink):
super().__init__(source, sink)
self.socket = socket
self.socket = hci_socket
async def close(self):
logger.debug('closing HCI socket transport')

View File

@@ -47,13 +47,11 @@ async def open_pty_transport(spec):
tty.setraw(replica)
read_transport, packet_source = await asyncio.get_running_loop().connect_read_pipe(
lambda: StreamPacketSource(),
io.open(primary, 'rb', closefd=False)
StreamPacketSource, io.open(primary, 'rb', closefd=False)
)
write_transport, _ = await asyncio.get_running_loop().connect_write_pipe(
lambda: asyncio.BaseProtocol(),
io.open(primary, 'wb', closefd=False)
asyncio.BaseProtocol, io.open(primary, 'wb', closefd=False)
)
packet_sink = StreamPacketSink(write_transport)

View File

View File

@@ -17,10 +17,11 @@
# -----------------------------------------------------------------------------
import asyncio
import logging
import usb.core
import usb.util
import threading
import time
import usb.core
import usb.util
from colors import color
from .common import Transport, ParserSource
@@ -48,6 +49,7 @@ async def open_pyusb_transport(spec):
04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901
'''
# pylint: disable=invalid-name
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_ENDPOINT_EVENTS_IN = 0x81
@@ -80,9 +82,17 @@ async def open_pyusb_transport(spec):
if packet_type == hci.HCI_ACL_DATA_PACKET:
self.device.write(USB_ENDPOINT_ACL_OUT, packet[1:])
elif packet_type == hci.HCI_COMMAND_PACKET:
self.device.ctrl_transfer(USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, 0, 0, 0, packet[1:])
self.device.ctrl_transfer(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
packet[1:],
)
else:
logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
logger.warning(
color(f'unsupported packet type {packet_type}', 'red')
)
except usb.core.USBTimeoutError:
logger.warning('USB Write Timeout')
except usb.core.USBError as error:
@@ -100,7 +110,7 @@ async def open_pyusb_transport(spec):
def run(self):
while self.stop_event is None:
time.sleep(1)
self.loop.call_soon_threadsafe(lambda: self.stop_event.set())
self.loop.call_soon_threadsafe(self.stop_event.set)
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, device, sco_enabled):
@@ -108,14 +118,13 @@ async def open_pyusb_transport(spec):
self.device = device
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.dequeue_task = None
self.event_thread = threading.Thread(
target=self.run,
args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET)
target=self.run, args=(USB_ENDPOINT_EVENTS_IN, hci.HCI_EVENT_PACKET)
)
self.event_thread.stop_event = None
self.acl_thread = threading.Thread(
target=self.run,
args=(USB_ENDPOINT_ACL_IN, hci.HCI_ACL_DATA_PACKET)
target=self.run, args=(USB_ENDPOINT_ACL_IN, hci.HCI_ACL_DATA_PACKET)
)
self.acl_thread.stop_event = None
@@ -124,12 +133,12 @@ async def open_pyusb_transport(spec):
if sco_enabled:
self.sco_thread = threading.Thread(
target=self.run,
args=(USB_ENDPOINT_SCO_IN, hci.HCI_SYNCHRONOUS_DATA_PACKET)
args=(USB_ENDPOINT_SCO_IN, hci.HCI_SYNCHRONOUS_DATA_PACKET),
)
self.sco_thread.stop_event = None
def data_received(self, packet):
self.parser.feed_data(packet)
def data_received(self, data):
self.parser.feed_data(data)
def enqueue(self, packet):
self.queue.put_nowait(packet)
@@ -173,16 +182,17 @@ async def open_pyusb_transport(spec):
except usb.core.USBTimeoutError:
continue
except usb.core.USBError:
# Don't log this: because pyusb doesn't really support multiple threads
# reading at the same time, we can get occasional USBError(errno=5)
# Input/Output errors reported, but they seem to be harmless.
# Don't log this: because pyusb doesn't really support multiple
# threads reading at the same time, we can get occasional
# USBError(errno=5) Input/Output errors reported, but they seem to
# be harmless.
# Until support for async or multi-thread support is added to pyusb,
# we'll just live with this as is...
# logger.warning(f'USB read error: {error}')
time.sleep(1) # Sleep one second to avoid busy looping
stop_event = current_thread.stop_event
self.loop.call_soon_threadsafe(lambda: stop_event.set())
self.loop.call_soon_threadsafe(stop_event.set)
class UsbTransport(Transport):
def __init__(self, device, source, sink):
@@ -194,18 +204,28 @@ async def open_pyusb_transport(spec):
await self.sink.stop()
usb.util.release_interface(self.device, 0)
usb_find = usb.core.find
try:
import libusb_package
except ImportError:
logger.debug('libusb_package is not available')
else:
usb_find = libusb_package.find
# Find the device according to the spec moniker
if ':' in spec:
vendor_id, product_id = spec.split(':')
device = usb.core.find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16))
device = usb_find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16))
else:
device_index = int(spec)
devices = list(usb.core.find(
devices = list(
usb_find(
find_all=1,
bDeviceClass=USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
bDeviceSubClass=USB_DEVICE_SUBCLASS_RF_CONTROLLER,
bDeviceProtocol = USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
))
bDeviceProtocol=USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
)
if len(devices) > device_index:
device = devices[device_index]
else:
@@ -232,6 +252,7 @@ async def open_pyusb_transport(spec):
# Select an alternate setting for SCO, if available
sco_enabled = False
# pylint: disable=line-too-long
# NOTE: this is disabled for now, because SCO with alternate settings is broken,
# see: https://github.com/libusb/libusb/issues/36
#

View File

@@ -60,13 +60,12 @@ async def open_serial_transport(spec):
device = spec
serial_transport, packet_source = await serial_asyncio.create_serial_connection(
asyncio.get_running_loop(),
lambda: StreamPacketSource(),
StreamPacketSource,
device,
baudrate=speed,
rtscts=rtscts,
dsrdtr=dsrdtr
dsrdtr=dsrdtr,
)
packet_sink = StreamPacketSink(serial_transport)
return Transport(packet_source, packet_sink)

View File

@@ -37,13 +37,13 @@ async def open_tcp_client_transport(spec):
'''
class TcpPacketSource(StreamPacketSource):
def connection_lost(self, error):
logger.debug(f'connection lost: {error}')
self.terminated.set_result(error)
def connection_lost(self, exc):
logger.debug(f'connection lost: {exc}')
self.terminated.set_result(exc)
remote_host, remote_port = spec.split(':')
tcp_transport, packet_source = await asyncio.get_running_loop().create_connection(
lambda: TcpPacketSource(),
TcpPacketSource,
host=remote_host,
port=int(remote_port),
)

View File

@@ -49,8 +49,8 @@ async def open_tcp_server_transport(spec):
# Called when a new connection is established
def connection_made(self, transport):
peername = transport.get_extra_info('peername')
logger.debug('connection from {}'.format(peername))
peer_name = transport.get_extra_info('peer_name')
logger.debug(f'connection from {peer_name}')
self.packet_sink.transport = transport
# Called when the client is disconnected

View File

@@ -53,10 +53,13 @@ async def open_udp_transport(spec):
local, remote = spec.split(',')
local_host, local_port = local.split(':')
remote_host, remote_port = remote.split(':')
udp_transport, packet_source = await asyncio.get_running_loop().create_datagram_endpoint(
lambda: UdpPacketSource(),
(
udp_transport,
packet_source,
) = await asyncio.get_running_loop().create_datagram_endpoint(
UdpPacketSource,
local_addr=(local_host, int(local_port)),
remote_addr=(remote_host, int(remote_port))
remote_addr=(remote_host, int(remote_port)),
)
packet_sink = UdpPacketSink(udp_transport)

View File

@@ -17,9 +17,12 @@
# -----------------------------------------------------------------------------
import asyncio
import logging
import usb1
import threading
import collections
import ctypes
import platform
import usb1
from colors import color
from .common import Transport, ParserSource
@@ -33,6 +36,30 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def load_libusb():
'''
Attempt to load the libusb-1.0 C library from libusb_package in site-packages.
If the library exists, we create a DLL object and initialize the usb1 backend.
This only needs to be done once, but before a usb1.USBContext is created.
If the library does not exists, do nothing and usb1 will search default system paths
when usb1.USBContext is created.
'''
try:
import libusb_package
except ImportError:
logger.debug('libusb_package is not available')
else:
if libusb_path := libusb_package.get_library_path():
logger.debug(f'loading libusb library at {libusb_path}')
dll_loader = (
ctypes.WinDLL if platform.system() == 'Windows' else ctypes.CDLL
)
libusb_dll = dll_loader(
str(libusb_path), use_errno=True, use_last_error=True
)
usb1.loadLibrary(libusb_dll)
async def open_usb_transport(spec):
'''
Open a USB transport.
@@ -44,21 +71,26 @@ async def open_usb_transport(spec):
With <index> as the 0-based index to select amongst all the devices that appear
to be supporting Bluetooth HCI (0 being the first one), or
Where <vendor> and <product> are the vendor ID and product ID in hexadecimal. The
/<serial-number> suffix or #<index> suffix max be specified when more than one device with
the same vendor and product identifiers are present.
/<serial-number> suffix or #<index> suffix max be specified when more than one
device with the same vendor and product identifiers are present.
In addition, if the moniker ends with the symbol "!", the device will be used in "forced" mode:
the first USB interface of the device will be used, regardless of the interface class/subclass.
This may be useful for some devices that use a custom class/subclass but may nonetheless work as-is.
In addition, if the moniker ends with the symbol "!", the device will be used in
"forced" mode:
the first USB interface of the device will be used, regardless of the interface
class/subclass.
This may be useful for some devices that use a custom class/subclass but may
nonetheless work as-is.
Examples:
0 --> the first BT USB dongle
04b4:f901 --> the BT USB dongle with vendor=04b4 and product=f901
04b4:f901#2 --> the third USB device with vendor=04b4 and product=f901
04b4:f901/00E04C239987 --> the BT USB dongle with vendor=04b4 and product=f901 and serial number 00E04C239987
04b4:f901/00E04C239987 --> the BT USB dongle with vendor=04b4 and product=f901 and
serial number 00E04C239987
usb:0B05:17CB! --> the BT USB dongle vendor=0B05 and product=17CB, in "forced" mode.
'''
# pylint: disable=invalid-name
USB_RECIPIENT_DEVICE = 0x00
USB_REQUEST_TYPE_CLASS = 0x01 << 5
USB_DEVICE_CLASS_DEVICE = 0x00
@@ -72,7 +104,7 @@ async def open_usb_transport(spec):
USB_BT_HCI_CLASS_TUPLE = (
USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
USB_DEVICE_SUBCLASS_RF_CONTROLLER,
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
READ_SIZE = 1024
@@ -109,12 +141,15 @@ async def open_usb_transport(spec):
status = transfer.getStatus()
# logger.debug(f'<<< USB out transfer callback: status={status}')
# pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED:
self.loop.call_soon_threadsafe(self.on_packet_sent_)
elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
else:
logger.warning(color(f'!!! out transfer not completed: status={status}', 'red'))
logger.warning(
color(f'!!! out transfer not completed: status={status}', 'red')
)
def on_packet_sent_(self):
if self.packets:
@@ -129,32 +164,38 @@ async def open_usb_transport(spec):
packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET:
self.transfer.setBulk(
self.acl_out,
packet[1:],
callback=self.on_packet_sent
self.acl_out, packet[1:], callback=self.on_packet_sent
)
logger.debug('submit ACL')
self.transfer.submit()
elif packet_type == hci.HCI_COMMAND_PACKET:
self.transfer.setControl(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, 0, 0, 0,
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
packet[1:],
callback=self.on_packet_sent
callback=self.on_packet_sent,
)
logger.debug('submit COMMAND')
self.transfer.submit()
else:
logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
async def close(self):
def close(self):
self.closed = True
async def terminate(self):
if not self.closed:
self.close()
# Empty the packet queue so that we don't send any more data
self.packets.clear()
# If we have a transfer in flight, cancel it
if self.transfer.isSubmitted():
# Try to cancel the transfer, but that may fail because it may have already completed
# Try to cancel the transfer, but that may fail because it may have
# already completed
try:
self.transfer.cancel()
@@ -173,12 +214,15 @@ async def open_usb_transport(spec):
self.events_in = events_in
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.dequeue_task = None
self.closed = False
self.event_loop_done = self.loop.create_future()
self.cancel_done = {
hci.HCI_EVENT_PACKET: self.loop.create_future(),
hci.HCI_ACL_DATA_PACKET: self.loop.create_future()
hci.HCI_ACL_DATA_PACKET: self.loop.create_future(),
}
self.events_in_transfer = None
self.acl_in_transfer = None
# Create a thread to process events
self.event_thread = threading.Thread(target=self.run)
@@ -190,7 +234,7 @@ async def open_usb_transport(spec):
self.events_in,
READ_SIZE,
callback=self.on_packet_received,
user_data=hci.HCI_EVENT_PACKET
user_data=hci.HCI_EVENT_PACKET,
)
self.events_in_transfer.submit()
@@ -199,7 +243,7 @@ async def open_usb_transport(spec):
self.acl_in,
READ_SIZE,
callback=self.on_packet_received,
user_data=hci.HCI_ACL_DATA_PACKET
user_data=hci.HCI_ACL_DATA_PACKET,
)
self.acl_in_transfer.submit()
@@ -209,16 +253,28 @@ async def open_usb_transport(spec):
def on_packet_received(self, transfer):
packet_type = transfer.getUserData()
status = transfer.getStatus()
# logger.debug(f'<<< USB IN transfer callback: status={status} packet_type={packet_type} length={transfer.getActualLength()}')
# logger.debug(
# f'<<< USB IN transfer callback: status={status} '
# f'packet_type={packet_type} '
# f'length={transfer.getActualLength()}'
# )
# pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED:
packet = bytes([packet_type]) + transfer.getBuffer()[:transfer.getActualLength()]
packet = (
bytes([packet_type])
+ transfer.getBuffer()[: transfer.getActualLength()]
)
self.loop.call_soon_threadsafe(self.queue.put_nowait, packet)
elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.cancel_done[packet_type].set_result, None)
self.loop.call_soon_threadsafe(
self.cancel_done[packet_type].set_result, None
)
return
else:
logger.warning(color(f'!!! transfer not completed: status={status}', 'red'))
logger.warning(
color(f'!!! transfer not completed: status={status}', 'red')
)
# Re-submit the transfer so we can receive more data
transfer.submit()
@@ -233,7 +289,11 @@ async def open_usb_transport(spec):
def run(self):
logger.debug('starting USB event loop')
while self.events_in_transfer.isSubmitted() or self.acl_in_transfer.isSubmitted():
while (
self.events_in_transfer.isSubmitted()
or self.acl_in_transfer.isSubmitted()
):
# pylint: disable=no-member
try:
self.context.handleEvents()
except usb1.USBErrorInterrupted:
@@ -242,22 +302,33 @@ async def open_usb_transport(spec):
logger.debug('USB event loop done')
self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)
async def close(self):
def close(self):
self.closed = True
async def terminate(self):
if not self.closed:
self.close()
self.dequeue_task.cancel()
# Cancel the transfers
for transfer in (self.events_in_transfer, self.acl_in_transfer):
if transfer.isSubmitted():
# Try to cancel the transfer, but that may fail because it may have already completed
# Try to cancel the transfer, but that may fail because it may have
# already completed
packet_type = transfer.getUserData()
try:
transfer.cancel()
logger.debug(f'waiting for IN[{packet_type}] transfer cancellation to be done...')
logger.debug(
f'waiting for IN[{packet_type}] transfer cancellation '
'to be done...'
)
await self.cancel_done[packet_type]
logger.debug(f'IN[{packet_type}] transfer cancellation done')
except usb1.USBError:
logger.debug(f'IN[{packet_type}] transfer likely already completed')
logger.debug(
f'IN[{packet_type}] transfer likely already completed'
)
# Wait for the thread to terminate
await self.event_loop_done
@@ -281,13 +352,16 @@ async def open_usb_transport(spec):
sink.start()
async def close(self):
await self.source.close()
await self.sink.close()
self.source.close()
self.sink.close()
await self.source.terminate()
await self.sink.terminate()
self.device.releaseInterface(self.interface)
self.device.close()
self.context.close()
# Find the device according to the spec moniker
load_libusb()
context = usb1.USBContext()
context.open()
try:
@@ -315,9 +389,9 @@ async def open_usb_transport(spec):
except usb1.USBError:
device_serial_number = None
if (
device.getVendorID() == int(vendor_id, 16) and
device.getProductID() == int(product_id, 16) and
(serial_number is None or serial_number == device_serial_number)
device.getVendorID() == int(vendor_id, 16)
and device.getProductID() == int(product_id, 16)
and (serial_number is None or serial_number == device_serial_number)
):
if device_index == 0:
found = device
@@ -328,8 +402,11 @@ async def open_usb_transport(spec):
# Look for a compatible device by index
def device_is_bluetooth_hci(device):
# Check if the device class indicates a match
if (device.getDeviceClass(), device.getDeviceSubClass(), device.getDeviceProtocol()) == \
USB_BT_HCI_CLASS_TUPLE:
if (
device.getDeviceClass(),
device.getDeviceSubClass(),
device.getDeviceProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
# If the device class is 'Device', look for a matching interface
@@ -337,8 +414,11 @@ async def open_usb_transport(spec):
for configuration in device:
for interface in configuration:
for setting in interface:
if (setting.getClass(), setting.getSubClass(), setting.getProtocol()) == \
USB_BT_HCI_CLASS_TUPLE:
if (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
return False
@@ -360,14 +440,20 @@ async def open_usb_transport(spec):
# Look for the first interface with the right class and endpoints
def find_endpoints(device):
# pylint: disable-next=too-many-nested-blocks
for (configuration_index, configuration) in enumerate(device):
interface = None
for interface in configuration:
setting = None
for setting in interface:
if (
not forced_mode and
(setting.getClass(), setting.getSubClass(), setting.getProtocol()) != USB_BT_HCI_CLASS_TUPLE
not forced_mode
and (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
)
!= USB_BT_HCI_CLASS_TUPLE
):
continue
@@ -382,22 +468,34 @@ async def open_usb_transport(spec):
acl_in = address
elif acl_out is None:
acl_out = address
elif attributes & 0x03 == USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT:
elif (
attributes & 0x03
== USB_ENDPOINT_TRANSFER_TYPE_INTERRUPT
):
if address & USB_ENDPOINT_IN and events_in is None:
events_in = address
# Return if we found all 3 endpoints
if acl_in is not None and acl_out is not None and events_in is not None:
if (
acl_in is not None
and acl_out is not None
and events_in is not None
):
return (
configuration_index + 1,
setting.getNumber(),
setting.getAlternateSetting(),
acl_in,
acl_out,
events_in
events_in,
)
else:
logger.debug(f'skipping configuration {configuration_index + 1} / interface {setting.getNumber()}')
logger.debug(
f'skipping configuration {configuration_index + 1} / '
f'interface {setting.getNumber()}'
)
return None
endpoints = find_endpoints(found)
if endpoints is None:
@@ -414,14 +512,14 @@ async def open_usb_transport(spec):
device = found.open()
# Detach the kernel driver if supported and needed
# Auto-detach the kernel driver if supported
# pylint: disable=no-member
if usb1.hasCapability(usb1.CAP_SUPPORTS_DETACH_KERNEL_DRIVER):
try:
if device.kernelDriverActive(interface):
logger.debug("detaching kernel driver")
device.detachKernelDriver(interface)
except usb1.USBError:
pass
logger.debug('auto-detaching kernel driver')
device.setAutoDetachKernelDriver(True)
except usb1.USBError as error:
logger.warning(f'unable to auto-detach kernel driver: {error}')
# Set the configuration if needed
try:

View File

@@ -33,7 +33,7 @@ async def open_vhci_transport(spec):
path at /dev/vhci), or the path of a VHCI device
'''
HCI_VENDOR_PKT = 0xff
HCI_VENDOR_PKT = 0xFF
HCI_BREDR = 0x00 # Controller type
# Open the VHCI device
@@ -56,4 +56,3 @@ async def open_vhci_transport(spec):
transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR]))
return transport

View File

@@ -43,7 +43,7 @@ async def open_ws_client_transport(spec):
transport = PumpedTransport(
PumpedPacketSource(websocket.recv),
PumpedPacketSink(websocket.send),
websocket.close
websocket.close,
)
transport.start()
return transport

View File

@@ -44,27 +44,33 @@ async def open_ws_server_transport(spec):
source = ParserSource()
sink = PumpedPacketSink(self.send_packet)
self.connection = asyncio.get_running_loop().create_future()
self.server = None
super().__init__(source, sink)
async def serve(self, local_host, local_port):
self.sink.start()
# pylint: disable-next=no-member
self.server = await websockets.serve(
ws_handler=self.on_connection,
host=local_host if local_host != '_' else None,
port = int(local_port)
port=int(local_port),
)
logger.debug(f'websocket server ready on port {local_port}')
async def on_connection(self, connection):
logger.debug(f'new connection on {connection.local_address} from {connection.remote_address}')
logger.debug(
f'new connection on {connection.local_address} '
f'from {connection.remote_address}'
)
self.connection.set_result(connection)
# pylint: disable=no-member
try:
async for packet in connection:
if type(packet) is bytes:
if isinstance(packet, bytes):
self.source.parser.feed_data(packet)
else:
logger.warn('discarding packet: not a BINARY frame')
logger.warning('discarding packet: not a BINARY frame')
except websockets.WebSocketException as error:
logger.debug(f'exception while receiving packet: {error}')

View File

@@ -19,6 +19,8 @@ import asyncio
import logging
import traceback
import collections
import sys
from typing import Awaitable
from functools import wraps
from colors import color
from pyee import EventEmitter
@@ -34,6 +36,7 @@ logger = logging.getLogger(__name__)
def setup_event_forwarding(emitter, forwarder, event_name):
def emit(*args, **kwargs):
forwarder.emit(event_name, *args, **kwargs)
emitter.on(event_name, emit)
@@ -44,6 +47,8 @@ def composite_listener(cls):
registers/deregisters all methods named `on_<event_name>` as a listener for
the <event_name> event with an emitter.
"""
# pylint: disable=protected-access
def register(self, emitter):
for method_name in dir(cls):
if method_name.startswith('on_'):
@@ -60,7 +65,36 @@ def composite_listener(cls):
# -----------------------------------------------------------------------------
class CompositeEventEmitter(EventEmitter):
class AbortableEventEmitter(EventEmitter):
def abort_on(self, event: str, awaitable: Awaitable):
"""
Set a coroutine or future to abort when an event occur.
"""
future = asyncio.ensure_future(awaitable)
if future.done():
return future
def on_event(*_):
msg = f'abort: {event} event occurred.'
if isinstance(future, asyncio.Task):
# python < 3.9 does not support passing a message on `Task.cancel`
if sys.version_info < (3, 9, 0):
future.cancel()
else:
future.cancel(msg)
else:
future.set_exception(asyncio.CancelledError(msg))
def on_done(_):
self.remove_listener(event, on_event)
self.on(event, on_event)
future.add_done_callback(on_done)
return future
# -----------------------------------------------------------------------------
class CompositeEventEmitter(AbortableEventEmitter):
def __init__(self):
super().__init__()
self._listener = None
@@ -71,6 +105,7 @@ class CompositeEventEmitter(EventEmitter):
@listener.setter
def listener(self, listener):
# pylint: disable=protected-access
if self._listener:
# Call the deregistration methods for each base class that has them
for cls in self._listener.__class__.mro():
@@ -110,7 +145,9 @@ class AsyncRunner:
try:
await item
except Exception as error:
logger.warning(f'{color("!!! Exception in work queue:", "red")} {error}')
logger.warning(
f'{color("!!! Exception in work queue:", "red")} {error}'
)
# Shared default queue
default_queue = WorkQueue()
@@ -131,7 +168,10 @@ class AsyncRunner:
try:
await coroutine
except Exception:
logger.warning(f'{color("!!! Exception in wrapper:", "red")} {traceback.format_exc()}')
logger.warning(
f'{color("!!! Exception in wrapper:", "red")} '
f'{traceback.format_exc()}'
)
asyncio.create_task(run())
else:
@@ -150,7 +190,15 @@ class FlowControlAsyncPipe:
paused (by calling a function passed in when the pipe is created) if the
amount of queued data exceeds a specified threshold.
"""
def __init__(self, pause_source, resume_source, write_to_sink=None, drain_sink=None, threshold=0):
def __init__(
self,
pause_source,
resume_source,
write_to_sink=None,
drain_sink=None,
threshold=0,
):
self.pause_source = pause_source
self.resume_source = resume_source
self.write_to_sink = write_to_sink

View File

@@ -7,6 +7,8 @@ nav:
- Getting Started: getting_started.md
- Development:
- Python Environments: development/python_environments.md
- Contributing: development/contributing.md
- Code Style: development/code_style.md
- Use Cases:
- Overview: use_cases/index.md
- Use Case 1: use_cases/use_case_1.md

View File

@@ -16,4 +16,3 @@ Bumble Python API
### HCI_Disconnect_Command
::: bumble.hci.HCI_Disconnect_Command

View File

@@ -28,5 +28,3 @@ a host that send custom HCI commands that the controller may not understand.
(through which the communication with other virtual controllers will be mediated).
NOTE: this assumes you're running a Link Relay on port `10723`.

View File

@@ -11,4 +11,3 @@ These include:
* [Golden Gate Bridge](gg_bridge.md) - a bridge between GATT and UDP to use with the Golden Gate "stack tool"
* [Show](show.md) - Parse a file with HCI packets and print the details of each packet in a human readable form
* [Link Relay](link_relay.md) - WebSocket relay for virtual RemoteLink instances to communicate with each other.

View File

@@ -31,4 +31,3 @@ The WebSocket path used by a connecting client indicates which virtual "chat roo
It is possible to connect to a "chat room" in a relay as an observer, rather than a virtual controller. In this case, a text-based console can be used to observe what is going on in the "chat room". Tools like [`wscat`](https://github.com/websockets/wscat#readme) or [`websocat`](https://github.com/vi/websocat) can be used for that.
Example: `wscat --connect ws://localhost:10723/test`

View File

@@ -0,0 +1,43 @@
CODE STYLE
==========
The Python code style used in this project follows the [Black code style](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html).
# Formatting
For now, we are configuring the `black` formatter with the option to leave quotes unchanged.
The preferred quote style is single quotes, which isn't a configurable option for `Black`, so we are not enforcing it. This may change in the future.
## Ignoring Commit for Git Blame
The adoption of `Black` as a formatter came in late in the project, with already a large code base. As a result, a large number of files were changed in a single commit, which gets in the way of tracing authorship with `git blame`. The file `git-blame-ignore-revs` contains the commit hash of when that mass-formatting event occurred, which you can use to skip it in a `git blame` analysis:
!!! example "Ignoring a commit with `git blame`"
```
$ git blame --ignore-revs-file .git-blame-ignore-revs
```
# Linting
The project includes a `pylint` configuration (see the `pyproject.toml` file for details).
The `pre-commit` checks only enforce that there are no errors. But we strongly recommend that you run the linter with warnings enabled at least, and possibly the "Refactor" ('R') and "Convention" ('C') categories as well.
To run the linter, use the `project.lint` invoke command.
!!! example "Running the linter with default options"
With the default settings, Errors and Warnings are enabled, but Refactor and Convention categories are not.
```
$ invoke project.lint
```
!!! example "Running the linter with all categories"
```
$ invoke project.lint --disable=""
```
# Editor/IDE Integration
## Visual Studio Code
The project includes a `.vscode/settings.json` file that specifies the `black` formatter and enables an editor ruler at 88 columns.
You may want to configure your own environment to "format on save" with `black` if you find that useful. We are not making that choice at the workspace level.

View File

@@ -0,0 +1,11 @@
CONTRIBUTING TO THE PROJECT
===========================
To contribute some code to the project, you will need to submit a GitHub Pull Request (a.k.a PR). Please familiarize yourself with how that works (see [GitHub Pull Requests](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests))
You should follow the project's [code style](code_style.md), and pre-check your code before submitting a PR. The GitHub project is set up with some [Actions](https://github.com/features/actions) that will check that a PR passes at least the basic tests and complies with the coding style, but it is still recommended to check that for yourself before submitting a PR.
To run the basic checks (essentially: running the tests, the linter, and the formatter), use the `project.pre-commit` `invoke` command, and address any issues found:
```
$ invoke project.pre-commit
```

View File

@@ -67,6 +67,8 @@ $ python -m pip install git+https://github.com/google/bumble.git@27c0551
When you work on the Bumble code itself, and run some of the tests or example apps, or import the
module in your own code, you typically either install the package from source in "development mode" as described above, or you may choose to skip the install phase.
If you plan on contributing to the project, please read the [contributing](development/contributing.md) section.
## Without Installing
If you prefer not to install the package (even in development mode), you can load the module directly from its location in the project.
A simple way to do that is to set your `PYTHONPATH` to

View File

@@ -163,4 +163,3 @@ Future features to be considered include:
* Bindings for languages other than Python
* RPC interface to expose most of the API for remote use
* (...suggest anything you want...)

View File

@@ -86,4 +86,3 @@ Use the `--format snoop` option to specify that the file is in that specific for
```shell
$ bumble-show --format snoop btsnoop_hci.log
```

View File

@@ -204,5 +204,3 @@ With the [VHCI transport](../transports/vhci.md) you can attach a Bumble virtual
### Using a Simulated UART HCI
### Bridge to a Remote Controller

View File

@@ -11,4 +11,3 @@ To do that, use the following command:
sudo nvram bluetoothHostControllerSwitchBehavior="never"
```
A reboot shouldn't be necessary after that. See [Tech Note 2295](https://developer.apple.com/library/archive/technotes/tn2295/_index.html)

View File

@@ -42,6 +42,11 @@ This may be useful for some devices that use a custom class/subclass but may non
The library includes two different implementations of the USB transport, implemented using different python bindings for `libusb`.
Using the transport prefix `pyusb:` instead of `usb:` selects the implementation based on [PyUSB](https://pypi.org/project/pyusb/), using the synchronous API of `libusb`, whereas the default implementation is based on [libusb1](https://pypi.org/project/libusb1/), using the asynchronous API of `libusb`. In order to use the alternative PyUSB-based implementation, you need to ensure that you have installed that python module, as it isn't installed by default as a dependency of Bumble.
## Libusb
The `libusb-1.0` shared library is required to use both `usb` and `pyusb` transports. This library should be installed automatically with Bumble, as part of the `libusb_package` Python package.
If your OS or architecture is not supported by `libusb_package`, you can install a system-wide library with `brew install libusb` for Mac or `apt install libusb-1.0-0` for Linux.
## Listing Available USB Devices
### With `usb_probe`

View File

@@ -80,6 +80,7 @@ async def main():
await my_work_queue2.run()
print("MAIN: end (should never get here)")
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -20,7 +20,7 @@ import sys
import os
import logging
from colors import color
from bumble.device import Device, Peer
from bumble.device import Device
from bumble.transport import open_transport
from bumble.profiles.battery_service import BatteryServiceProxy
@@ -55,7 +55,9 @@ async def main():
# Subscribe to and read the battery level
if battery_service.battery_level:
await battery_service.battery_level.subscribe(
lambda value: print(f'{color("Battery Level Update:", "green")} {value}')
lambda value: print(
f'{color("Battery Level Update:", "green")} {value}'
)
)
value = await battery_service.battery_level.read_value()
print(f'{color("Initial Battery Level:", "green")} {value}')

View File

@@ -44,11 +44,19 @@ async def main():
# Set the advertising data
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble Battery', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(battery_service.uuid)),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340))
])
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Battery', 'utf-8'),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(battery_service.uuid),
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
]
)
)
# Go!
@@ -58,7 +66,9 @@ async def main():
# Notify every 3 seconds
while True:
await asyncio.sleep(3.0)
await device.notify_subscribers(battery_service.battery_level_characteristic)
await device.notify_subscribers(
battery_service.battery_level_characteristic
)
# -----------------------------------------------------------------------------

View File

@@ -28,7 +28,9 @@ from bumble.transport import open_transport
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) != 3:
print('Usage: device_information_client.py <transport-spec> <bluetooth-address>')
print(
'Usage: device_information_client.py <transport-spec> <bluetooth-address>'
)
print('example: device_information_client.py usb:0 E1:CA:72:48:C4:E8')
return
@@ -49,7 +51,9 @@ async def main():
# Discover the Device Information service
peer = Peer(connection)
print('=== Discovering Device Information Service')
device_information_service = await peer.discover_service_and_create_proxy(DeviceInformationServiceProxy)
device_information_service = await peer.discover_service_and_create_proxy(
DeviceInformationServiceProxy
)
# Check that the service was found
if device_information_service is None:
@@ -58,21 +62,51 @@ async def main():
# Read and print the fields
if device_information_service.manufacturer_name is not None:
print(color('Manufacturer Name: ', 'green'), await device_information_service.manufacturer_name.read_value())
print(
color('Manufacturer Name: ', 'green'),
await device_information_service.manufacturer_name.read_value(),
)
if device_information_service.model_number is not None:
print(color('Model Number: ', 'green'), await device_information_service.model_number.read_value())
print(
color('Model Number: ', 'green'),
await device_information_service.model_number.read_value(),
)
if device_information_service.serial_number is not None:
print(color('Serial Number: ', 'green'), await device_information_service.serial_number.read_value())
print(
color('Serial Number: ', 'green'),
await device_information_service.serial_number.read_value(),
)
if device_information_service.hardware_revision is not None:
print(color('Hardware Revision: ', 'green'), await device_information_service.hardware_revision.read_value())
print(
color('Hardware Revision: ', 'green'),
await device_information_service.hardware_revision.read_value(),
)
if device_information_service.firmware_revision is not None:
print(color('Firmware Revision: ', 'green'), await device_information_service.firmware_revision.read_value())
print(
color('Firmware Revision: ', 'green'),
await device_information_service.firmware_revision.read_value(),
)
if device_information_service.software_revision is not None:
print(color('Software Revision: ', 'green'), await device_information_service.software_revision.read_value())
print(
color('Software Revision: ', 'green'),
await device_information_service.software_revision.read_value(),
)
if device_information_service.system_id is not None:
print(color('System ID: ', 'green'), await device_information_service.system_id.read_value())
if device_information_service.ieee_regulatory_certification_data_list is not None:
print(color('Regulatory Certification:', 'green'), (await device_information_service.ieee_regulatory_certification_data_list.read_value()).hex())
print(
color('System ID: ', 'green'),
await device_information_service.system_id.read_value(),
)
if (
device_information_service.ieee_regulatory_certification_data_list
is not None
):
print(
color('Regulatory Certification:', 'green'),
(
# pylint: disable-next=line-too-long
await device_information_service.ieee_regulatory_certification_data_list.read_value()
).hex(),
)
# -----------------------------------------------------------------------------

View File

@@ -44,16 +44,21 @@ async def main():
serial_number='7654321',
hardware_revision='1.1.3',
software_revision='2.5.6',
system_id = (0x123456, 0x8877665544)
system_id=(0x123456, 0x8877665544),
)
device.add_service(device_information_service)
# Set the advertising data
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble Device', 'utf-8')),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340))
])
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Device', 'utf-8'),
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
]
)
)
# Go!
@@ -61,6 +66,7 @@ async def main():
await device.start_advertising(auto_restart=True)
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -20,7 +20,7 @@ import sys
import os
import logging
from colors import color
from bumble.device import Device, Peer
from bumble.device import Device
from bumble.transport import open_transport
from bumble.profiles.heart_rate_service import HeartRateServiceProxy
@@ -61,7 +61,9 @@ async def main():
# Subscribe to the heart rate measurement
if heart_rate_service.heart_rate_measurement:
await heart_rate_service.heart_rate_measurement.subscribe(
lambda value: print(f'{color("Heart Rate Measurement:", "green")} {value}')
lambda value: print(
f'{color("Heart Rate Measurement:", "green")} {value}'
)
)
await peer.sustain()

View File

@@ -55,29 +55,47 @@ async def main():
serial_number='7654321',
hardware_revision='1.1.3',
software_revision='2.5.6',
system_id = (0x123456, 0x8877665544)
system_id=(0x123456, 0x8877665544),
)
heart_rate_service = HeartRateService(
read_heart_rate_measurement=lambda _: HeartRateService.HeartRateMeasurement(
heart_rate=100 + int(50 * math.sin(time.time() * math.pi / 60)),
sensor_contact_detected=random.choice((True, False, None)),
energy_expended = random.choice((int((time.time() - energy_start_time) * 100), None)),
rr_intervals = random.choice(((random.randint(900, 1100) / 1000, random.randint(900, 1100) / 1000), None))
energy_expended=random.choice(
(int((time.time() - energy_start_time) * 100), None)
),
rr_intervals=random.choice(
(
(
random.randint(900, 1100) / 1000,
random.randint(900, 1100) / 1000,
),
None,
)
),
),
body_sensor_location=HeartRateService.BodySensorLocation.WRIST,
reset_energy_expended=lambda _: reset_energy_expended()
reset_energy_expended=lambda _: reset_energy_expended(),
)
device.add_services([device_information_service, heart_rate_service])
# Set the advertising data
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble Heart', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(heart_rate_service.uuid)),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340))
])
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Heart', 'utf-8'),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(heart_rate_service.uuid),
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
]
)
)
# Go!
@@ -87,7 +105,9 @@ async def main():
# Notify every 3 seconds
while True:
await asyncio.sleep(3.0)
await device.notify_subscribers(heart_rate_service.heart_rate_measurement_characteristic)
await device.notify_subscribers(
heart_rate_service.heart_rate_measurement_characteristic
)
# -----------------------------------------------------------------------------

View File

@@ -20,8 +20,8 @@ import sys
import os
import logging
import struct
import websockets
import json
import websockets
from colors import color
from bumble.core import AdvertisingData
@@ -43,7 +43,7 @@ from bumble.gatt import (
GATT_PROTOCOL_MODE_CHARACTERISTIC,
GATT_HID_INFORMATION_CHARACTERISTIC,
GATT_HID_CONTROL_POINT_CHARACTERISTIC,
GATT_REPORT_REFERENCE_DESCRIPTOR
GATT_REPORT_REFERENCE_DESCRIPTOR,
)
# -----------------------------------------------------------------------------
@@ -58,44 +58,80 @@ HID_OUTPUT_REPORT = 0x02
HID_FEATURE_REPORT = 0x03
# Report Map
HID_KEYBOARD_REPORT_MAP = bytes([
0x05, 0x01, # Usage Page (Generic Desktop Ctrls)
0x09, 0x06, # Usage (Keyboard)
0xA1, 0x01, # Collection (Application)
0x85, 0x01, # . Report ID (1)
0x05, 0x07, # . Usage Page (Kbrd/Keypad)
0x19, 0xE0, # . Usage Minimum (0xE0)
0x29, 0xE7, # . Usage Maximum (0xE7)
0x15, 0x00, # . Logical Minimum (0)
0x25, 0x01, # . Logical Maximum (1)
0x75, 0x01, # . Report Size (1)
0x95, 0x08, # . Report Count (8)
0x81, 0x02, # . Input (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95, 0x01, # . Report Count (1)
0x75, 0x08, # . Report Size (8)
0x81, 0x01, # . Input (Const,Array,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95, 0x06, # . Report Count (6)
0x75, 0x08, # . Report Size (8)
0x15, 0x00, # . Logical Minimum (0x00)
0x25, 0x94, # . Logical Maximum (0x94)
0x05, 0x07, # . Usage Page (Kbrd/Keypad)
0x19, 0x00, # . Usage Minimum (0x00)
0x29, 0x94, # . Usage Maximum (0x94)
0x81, 0x00, # . Input (Data,Array,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95, 0x05, # . Report Count (5)
0x75, 0x01, # . Report Size (1)
0x05, 0x08, # . Usage Page (LEDs)
0x19, 0x01, # . Usage Minimum (Num Lock)
0x29, 0x05, # . Usage Maximum (Kana)
0x91, 0x02, # . Output (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
0x95, 0x01, # . Report Count (1)
0x75, 0x03, # . Report Size (3)
0x91, 0x01, # . Output (Const,Array,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
0xC0 # End Collection
])
HID_KEYBOARD_REPORT_MAP = bytes(
# pylint: disable=line-too-long
[
0x05,
0x01, # Usage Page (Generic Desktop Controls)
0x09,
0x06, # Usage (Keyboard)
0xA1,
0x01, # Collection (Application)
0x85,
0x01, # . Report ID (1)
0x05,
0x07, # . Usage Page (Keyboard/Keypad)
0x19,
0xE0, # . Usage Minimum (0xE0)
0x29,
0xE7, # . Usage Maximum (0xE7)
0x15,
0x00, # . Logical Minimum (0)
0x25,
0x01, # . Logical Maximum (1)
0x75,
0x01, # . Report Size (1)
0x95,
0x08, # . Report Count (8)
0x81,
0x02, # . Input (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95,
0x01, # . Report Count (1)
0x75,
0x08, # . Report Size (8)
0x81,
0x01, # . Input (Const,Array,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95,
0x06, # . Report Count (6)
0x75,
0x08, # . Report Size (8)
0x15,
0x00, # . Logical Minimum (0x00)
0x25,
0x94, # . Logical Maximum (0x94)
0x05,
0x07, # . Usage Page (Keyboard/Keypad)
0x19,
0x00, # . Usage Minimum (0x00)
0x29,
0x94, # . Usage Maximum (0x94)
0x81,
0x00, # . Input (Data,Array,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95,
0x05, # . Report Count (5)
0x75,
0x01, # . Report Size (1)
0x05,
0x08, # . Usage Page (LEDs)
0x19,
0x01, # . Usage Minimum (Num Lock)
0x29,
0x05, # . Usage Maximum (Kana)
0x91,
0x02, # . Output (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
0x95,
0x01, # . Report Count (1)
0x75,
0x03, # . Report Size (3)
0x91,
0x01, # . Output (Const,Array,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
0xC0, # End Collection
]
)
# -----------------------------------------------------------------------------
# pylint: disable=invalid-overridden-method
class ServerListener(Device.Listener, Connection.Listener):
def __init__(self, device):
self.device = device
@@ -111,7 +147,7 @@ class ServerListener(Device.Listener, Connection.Listener):
# -----------------------------------------------------------------------------
def on_hid_control_point_write(connection, value):
def on_hid_control_point_write(_connection, value):
print(f'Control Point Write: {value}')
@@ -133,31 +169,41 @@ async def keyboard_host(device, peer_address):
return
await peer.discover_characteristics()
protocol_mode_characteristics = peer.get_characteristics_by_uuid(GATT_PROTOCOL_MODE_CHARACTERISTIC)
protocol_mode_characteristics = peer.get_characteristics_by_uuid(
GATT_PROTOCOL_MODE_CHARACTERISTIC
)
if not protocol_mode_characteristics:
print(color('!!! No Protocol Mode characteristic', 'red'))
return
protocol_mode_characteristic = protocol_mode_characteristics[0]
hid_information_characteristics = peer.get_characteristics_by_uuid(GATT_HID_INFORMATION_CHARACTERISTIC)
hid_information_characteristics = peer.get_characteristics_by_uuid(
GATT_HID_INFORMATION_CHARACTERISTIC
)
if not hid_information_characteristics:
print(color('!!! No HID Information characteristic', 'red'))
return
hid_information_characteristic = hid_information_characteristics[0]
report_map_characteristics = peer.get_characteristics_by_uuid(GATT_REPORT_MAP_CHARACTERISTIC)
report_map_characteristics = peer.get_characteristics_by_uuid(
GATT_REPORT_MAP_CHARACTERISTIC
)
if not report_map_characteristics:
print(color('!!! No Report Map characteristic', 'red'))
return
report_map_characteristic = report_map_characteristics[0]
control_point_characteristics = peer.get_characteristics_by_uuid(GATT_HID_CONTROL_POINT_CHARACTERISTIC)
control_point_characteristics = peer.get_characteristics_by_uuid(
GATT_HID_CONTROL_POINT_CHARACTERISTIC
)
if not control_point_characteristics:
print(color('!!! No Control Point characteristic', 'red'))
return
# control_point_characteristic = control_point_characteristics[0]
report_characteristics = peer.get_characteristics_by_uuid(GATT_REPORT_CHARACTERISTIC)
report_characteristics = peer.get_characteristics_by_uuid(
GATT_REPORT_CHARACTERISTIC
)
if not report_characteristics:
print(color('!!! No Report characteristic', 'red'))
return
@@ -165,13 +211,20 @@ async def keyboard_host(device, peer_address):
print(color('REPORT:', 'yellow'), characteristic)
if characteristic.properties & Characteristic.NOTIFY:
await peer.discover_descriptors(characteristic)
report_reference_descriptor = characteristic.get_descriptor(GATT_REPORT_REFERENCE_DESCRIPTOR)
report_reference_descriptor = characteristic.get_descriptor(
GATT_REPORT_REFERENCE_DESCRIPTOR
)
if report_reference_descriptor:
report_reference = await peer.read_value(report_reference_descriptor)
print(color(' Report Reference:', 'blue'), report_reference.hex())
else:
report_reference = bytes([0, 0])
await peer.subscribe(characteristic, lambda value, param=f'[{i}] {report_reference.hex()}': on_report(param, value))
await peer.subscribe(
characteristic,
lambda value, param=f'[{i}] {report_reference.hex()}': on_report(
param, value
),
)
protocol_mode = await peer.read_value(protocol_mode_characteristic)
print(f'Protocol Mode: {protocol_mode.hex()}')
@@ -192,23 +245,34 @@ async def keyboard_device(device, command):
Characteristic.READABLE | Characteristic.WRITEABLE,
bytes([0, 0, 0, 0, 0, 0, 0, 0]),
[
Descriptor(GATT_REPORT_REFERENCE_DESCRIPTOR, Descriptor.READABLE, bytes([0x01, HID_INPUT_REPORT]))
]
Descriptor(
GATT_REPORT_REFERENCE_DESCRIPTOR,
Descriptor.READABLE,
bytes([0x01, HID_INPUT_REPORT]),
)
],
)
# Create an 'output report' characteristic to receive keyboard reports from the host
output_report_characteristic = Characteristic(
GATT_REPORT_CHARACTERISTIC,
Characteristic.READ | Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.READ
| Characteristic.WRITE
| Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.READABLE | Characteristic.WRITEABLE,
bytes([0]),
[
Descriptor(GATT_REPORT_REFERENCE_DESCRIPTOR, Descriptor.READABLE, bytes([0x01, HID_OUTPUT_REPORT]))
]
Descriptor(
GATT_REPORT_REFERENCE_DESCRIPTOR,
Descriptor.READABLE,
bytes([0x01, HID_OUTPUT_REPORT]),
)
],
)
# Add the services to the GATT sever
device.add_services([
device.add_services(
[
Service(
GATT_DEVICE_INFORMATION_SERVICE,
[
@@ -216,9 +280,9 @@ async def keyboard_device(device, command):
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
'Bumble'
'Bumble',
)
]
],
),
Service(
GATT_HUMAN_INTERFACE_DEVICE_SERVICE,
@@ -227,29 +291,31 @@ async def keyboard_device(device, command):
GATT_PROTOCOL_MODE_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([HID_REPORT_PROTOCOL])
bytes([HID_REPORT_PROTOCOL]),
),
Characteristic(
GATT_HID_INFORMATION_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([0x11, 0x01, 0x00, 0x03]) # bcdHID=1.1, bCountryCode=0x00, Flags=RemoteWake|NormallyConnectable
# bcdHID=1.1, bCountryCode=0x00,
# Flags=RemoteWake|NormallyConnectable
bytes([0x11, 0x01, 0x00, 0x03]),
),
Characteristic(
GATT_HID_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_hid_control_point_write)
CharacteristicValue(write=on_hid_control_point_write),
),
Characteristic(
GATT_REPORT_MAP_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
HID_KEYBOARD_REPORT_MAP
HID_KEYBOARD_REPORT_MAP,
),
input_report_characteristic,
output_report_characteristic
]
output_report_characteristic,
],
),
Service(
GATT_BATTERY_SERVICE,
@@ -258,11 +324,12 @@ async def keyboard_device(device, command):
GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([100])
bytes([100]),
)
],
),
]
)
])
# Debug print
for attribute in device.gatt_server.attributes:
@@ -270,13 +337,20 @@ async def keyboard_device(device, command):
# Set the advertising data
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble Keyboard', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(GATT_HUMAN_INTERFACE_DEVICE_SERVICE)),
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes('Bumble Keyboard', 'utf-8'),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(GATT_HUMAN_INTERFACE_DEVICE_SERVICE),
),
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x03C1)),
(AdvertisingData.FLAGS, bytes([0x05]))
])
(AdvertisingData.FLAGS, bytes([0x05])),
]
)
)
# Attach a listener
@@ -288,7 +362,7 @@ async def keyboard_device(device, command):
if command == 'web':
# Start a Websocket server to receive events from a web page
async def serve(websocket, path):
async def serve(websocket, _path):
while True:
try:
message = await websocket.recv()
@@ -301,16 +375,24 @@ async def keyboard_device(device, command):
key = parsed['key']
if len(key) == 1:
code = ord(key)
if code >= ord('a') and code <= ord('z'):
if ord('a') <= code <= ord('z'):
hid_code = 0x04 + code - ord('a')
input_report_characteristic.value = bytes([0, 0, hid_code, 0, 0, 0, 0, 0])
await device.notify_subscribers(input_report_characteristic)
input_report_characteristic.value = bytes(
[0, 0, hid_code, 0, 0, 0, 0, 0]
)
await device.notify_subscribers(
input_report_characteristic
)
elif message_type == 'keyup':
input_report_characteristic.value = bytes.fromhex('0000000000000000')
input_report_characteristic.value = bytes.fromhex(
'0000000000000000'
)
await device.notify_subscribers(input_report_characteristic)
except websockets.exceptions.ConnectionClosedOK:
pass
# pylint: disable-next=no-member
await websockets.serve(serve, 'localhost', 8989)
await asyncio.get_event_loop().create_future()
else:
@@ -321,7 +403,9 @@ async def keyboard_device(device, command):
# Keypress for the letter
keycode = 0x04 + letter - 0x61
input_report_characteristic.value = bytes([0, 0, keycode, 0, 0, 0, 0, 0])
input_report_characteristic.value = bytes(
[0, 0, keycode, 0, 0, 0, 0, 0]
)
await device.notify_subscribers(input_report_characteristic)
# Key release
@@ -332,13 +416,20 @@ async def keyboard_device(device, command):
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 4:
print('Usage: python keyboard.py <device-config> <transport-spec> <command>')
print(' where <command> is one of:')
print(' connect <address> (run a keyboard host, connecting to a keyboard)')
print(' web (run a keyboard with keypress input from a web page, see keyboard.html')
print(' sim (run a keyboard simulation, emitting a canned sequence of keystrokes')
print(
'Usage: python keyboard.py <device-config> <transport-spec> <command>'
' where <command> is one of:\n'
' connect <address> (run a keyboard host, connecting to a keyboard)\n'
' web (run a keyboard with keypress input from a web page, '
'see keyboard.html\n'
)
print(
' sim (run a keyboard simulation, emitting a canned sequence of keystrokes'
)
print('example: python keyboard.py keyboard.json usb:0 sim')
print('example: python keyboard.py keyboard.json usb:0 connect A0:A1:A2:A3:A4:A5')
print(
'example: python keyboard.py keyboard.json usb:0 connect A0:A1:A2:A3:A4:A5'
)
return
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
@@ -349,7 +440,7 @@ async def main():
if command == 'connect':
# Run as a Keyboard host
await keyboard_host(device, sys.argv[4])
elif command in {'sim', 'web'}:
elif command in ('sim', 'web'):
# Run as a keyboard device
await keyboard_device(device, command)

View File

@@ -27,12 +27,9 @@ from bumble.core import (
BT_BR_EDR_TRANSPORT,
BT_AVDTP_PROTOCOL_ID,
BT_AUDIO_SINK_SERVICE,
BT_L2CAP_PROTOCOL_ID
)
from bumble.avdtp import (
Protocol as AVDTP_Protocol,
find_avdtp_service_with_connection
BT_L2CAP_PROTOCOL_ID,
)
from bumble.avdtp import Protocol as AVDTP_Protocol
from bumble.a2dp import make_audio_source_service_sdp_records
from bumble.sdp import (
Client as SDP_Client,
@@ -40,7 +37,7 @@ from bumble.sdp import (
DataElement,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
)
@@ -48,11 +45,14 @@ from bumble.sdp import (
def sdp_records():
service_record_handle = 0x00010001
return {
service_record_handle: make_audio_source_service_sdp_records(service_record_handle)
service_record_handle: make_audio_source_service_sdp_records(
service_record_handle
)
}
# -----------------------------------------------------------------------------
# pylint: disable-next=too-many-nested-blocks
async def find_a2dp_service(device, connection):
# Connect to the SDP Server
sdp_client = SDP_Client(device)
@@ -64,8 +64,8 @@ async def find_a2dp_service(device, connection):
[
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
]
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
],
)
print(color('==================================', 'blue'))
@@ -78,8 +78,7 @@ async def find_a2dp_service(device, connection):
# Service classes
service_class_id_list = ServiceAttribute.find_attribute_in_list(
attribute_list,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
attribute_list, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID
)
if service_class_id_list:
if service_class_id_list.value:
@@ -89,8 +88,7 @@ async def find_a2dp_service(device, connection):
# Protocol info
protocol_descriptor_list = ServiceAttribute.find_attribute_in_list(
attribute_list,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
attribute_list, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID
)
if protocol_descriptor_list:
print(color(' Protocol:', 'green'))
@@ -103,27 +101,38 @@ async def find_a2dp_service(device, connection):
if len(protocol_descriptor.value) >= 2:
avdtp_version_major = protocol_descriptor.value[1].value >> 8
avdtp_version_minor = protocol_descriptor.value[1].value & 0xFF
print(f'{color(" AVDTP Version:", "cyan")} {avdtp_version_major}.{avdtp_version_minor}')
print(
f'{color(" AVDTP Version:", "cyan")} '
f'{avdtp_version_major}.{avdtp_version_minor}'
)
service_version = (avdtp_version_major, avdtp_version_minor)
# Profile info
bluetooth_profile_descriptor_list = ServiceAttribute.find_attribute_in_list(
attribute_list,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
attribute_list, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
)
if bluetooth_profile_descriptor_list:
if bluetooth_profile_descriptor_list.value:
if bluetooth_profile_descriptor_list.value[0].type == DataElement.SEQUENCE:
bluetooth_profile_descriptors = bluetooth_profile_descriptor_list.value
if (
bluetooth_profile_descriptor_list.value[0].type
== DataElement.SEQUENCE
):
bluetooth_profile_descriptors = (
bluetooth_profile_descriptor_list.value
)
else:
# Sometimes, instead of a list of lists, we just find a list. Fix that
# Sometimes, instead of a list of lists, we just find a list.
# Fix that.
bluetooth_profile_descriptors = [bluetooth_profile_descriptor_list]
print(color(' Profiles:', 'green'))
for bluetooth_profile_descriptor in bluetooth_profile_descriptors:
version_major = bluetooth_profile_descriptor.value[1].value >> 8
version_minor = bluetooth_profile_descriptor.value[1].value & 0xFF
print(f' {bluetooth_profile_descriptor.value[0].value} - version {version_major}.{version_minor}')
print(
f' {bluetooth_profile_descriptor.value[0].value}'
f' - version {version_major}.{version_minor}'
)
await sdp_client.disconnect()
return service_version
@@ -147,7 +156,8 @@ async def main():
# Start the controller
await device.power_on()
# Setup the SDP to expose a SRC service, in case the remote device queries us back
# Setup the SDP to expose a SRC service, in case the remote device queries us
# back
device.sdp_service_records = sdp_records()
# Connect to a peer

View File

@@ -20,7 +20,6 @@ import sys
import os
import logging
from colors import color
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.core import BT_BR_EDR_TRANSPORT
@@ -28,7 +27,7 @@ from bumble.avdtp import (
AVDTP_AUDIO_MEDIA_TYPE,
Protocol,
Listener,
MediaCodecCapabilities
MediaCodecCapabilities,
)
from bumble.a2dp import (
make_audio_sink_service_sdp_records,
@@ -39,19 +38,19 @@ from bumble.a2dp import (
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE,
SbcMediaCodecInformation
SbcMediaCodecInformation,
)
Context = {
'output': None
}
Context = {'output': None}
# -----------------------------------------------------------------------------
def sdp_records():
service_record_handle = 0x00010001
return {
service_record_handle: make_audio_sink_service_sdp_records(service_record_handle)
service_record_handle: make_audio_sink_service_sdp_records(
service_record_handle
)
}
@@ -67,14 +66,17 @@ def codec_capabilities():
SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE
SBC_JOINT_STEREO_CHANNEL_MODE,
],
block_lengths=[4, 8, 12, 16],
subbands=[4, 8],
allocation_methods = [SBC_LOUDNESS_ALLOCATION_METHOD, SBC_SNR_ALLOCATION_METHOD],
allocation_methods=[
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_SNR_ALLOCATION_METHOD,
],
minimum_bitpool_value=2,
maximum_bitpool_value = 53
)
maximum_bitpool_value=53,
),
)
@@ -89,8 +91,8 @@ def on_avdtp_connection(server):
def on_rtp_packet(packet):
header = packet.payload[0]
fragmented = header >> 7
start = (header >> 6) & 0x01
last = (header >> 5) & 0x01
# start = (header >> 6) & 0x01
# last = (header >> 5) & 0x01
number_of_frames = header & 0x0F
if fragmented:
@@ -104,7 +106,10 @@ def on_rtp_packet(packet):
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 4:
print('Usage: run_a2dp_sink.py <device-config> <transport-spec> <sbc-file> [<bt-addr>]')
print(
'Usage: run_a2dp_sink.py <device-config> <transport-spec> <sbc-file> '
'[<bt-addr>]'
)
print('example: run_a2dp_sink.py classic1.json usb:0 output.sbc')
return
@@ -133,7 +138,9 @@ async def main():
# Connect to the source
target_address = sys.argv[4]
print(f'=== Connecting to {target_address}...')
connection = await device.connect(target_address, transport=BT_BR_EDR_TRANSPORT)
connection = await device.connect(
target_address, transport=BT_BR_EDR_TRANSPORT
)
print(f'=== Connected to {connection.peer_address}!')
# Request authentication

View File

@@ -30,7 +30,7 @@ from bumble.avdtp import (
MediaCodecCapabilities,
MediaPacketPump,
Protocol,
Listener
Listener,
)
from bumble.a2dp import (
SBC_JOINT_STEREO_CHANNEL_MODE,
@@ -38,7 +38,7 @@ from bumble.a2dp import (
make_audio_source_service_sdp_records,
A2DP_SBC_CODEC_TYPE,
SbcMediaCodecInformation,
SbcPacketSource
SbcPacketSource,
)
@@ -46,13 +46,16 @@ from bumble.a2dp import (
def sdp_records():
service_record_handle = 0x00010001
return {
service_record_handle: make_audio_source_service_sdp_records(service_record_handle)
service_record_handle: make_audio_source_service_sdp_records(
service_record_handle
)
}
# -----------------------------------------------------------------------------
def codec_capabilities():
# NOTE: this shouldn't be hardcoded, but should be inferred from the input file instead
# NOTE: this shouldn't be hardcoded, but should be inferred from the input file
# instead
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
@@ -63,14 +66,16 @@ def codec_capabilities():
subbands=8,
allocation_method=SBC_LOUDNESS_ALLOCATION_METHOD,
minimum_bitpool_value=2,
maximum_bitpool_value = 53
)
maximum_bitpool_value=53,
),
)
# -----------------------------------------------------------------------------
def on_avdtp_connection(read_function, protocol):
packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.mtu, codec_capabilities())
packet_source = SbcPacketSource(
read_function, protocol.l2cap_channel.mtu, codec_capabilities()
)
packet_pump = MediaPacketPump(packet_source.packets)
protocol.add_source(packet_source.codec_capabilities, packet_pump)
@@ -83,14 +88,18 @@ async def stream_packets(read_function, protocol):
print('@@@', endpoint)
# Select a sink
sink = protocol.find_remote_sink_by_codec(AVDTP_AUDIO_MEDIA_TYPE, A2DP_SBC_CODEC_TYPE)
sink = protocol.find_remote_sink_by_codec(
AVDTP_AUDIO_MEDIA_TYPE, A2DP_SBC_CODEC_TYPE
)
if sink is None:
print(color('!!! no SBC sink found', 'red'))
return
print(f'### Selected sink: {sink.seid}')
# Stream the packets
packet_source = SbcPacketSource(read_function, protocol.l2cap_channel.mtu, codec_capabilities())
packet_source = SbcPacketSource(
read_function, protocol.l2cap_channel.mtu, codec_capabilities()
)
packet_pump = MediaPacketPump(packet_source.packets)
source = protocol.add_source(packet_source.codec_capabilities, packet_pump)
stream = await protocol.create_stream(source, sink)
@@ -107,8 +116,13 @@ async def stream_packets(read_function, protocol):
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 4:
print('Usage: run_a2dp_source.py <device-config> <transport-spec> <sbc-file> [<bluetooth-address>]')
print('example: run_a2dp_source.py classic1.json usb:0 test.sbc E1:CA:72:48:C4:E8')
print(
'Usage: run_a2dp_source.py <device-config> <transport-spec> <sbc-file> '
'[<bluetooth-address>]'
)
print(
'example: run_a2dp_source.py classic1.json usb:0 test.sbc E1:CA:72:48:C4:E8'
)
return
print('<<< connecting to HCI...')
@@ -126,7 +140,8 @@ async def main():
await device.power_on()
with open(sys.argv[3], 'rb') as sbc_file:
# NOTE: this should be using asyncio file reading, but blocking reads are good enough for testing
# NOTE: this should be using asyncio file reading, but blocking reads are
# good enough for testing
async def read(byte_count):
return sbc_file.read(byte_count)
@@ -134,7 +149,9 @@ async def main():
# Connect to a peer
target_address = sys.argv[4]
print(f'=== Connecting to {target_address}...')
connection = await device.connect(target_address, transport=BT_BR_EDR_TRANSPORT)
connection = await device.connect(
target_address, transport=BT_BR_EDR_TRANSPORT
)
print(f'=== Connected to {connection.peer_address}!')
# Request authentication
@@ -148,7 +165,9 @@ async def main():
print('*** Encryption on')
# Look for an A2DP service
avdtp_version = await find_avdtp_service_with_connection(device, connection)
avdtp_version = await find_avdtp_service_with_connection(
device, connection
)
if not avdtp_version:
print(color('!!! no A2DP service found'))
return
@@ -161,7 +180,9 @@ async def main():
else:
# Create a listener to wait for AVDTP connections
listener = Listener(Listener.create_registrar(device), version=(1, 2))
listener.on('connection', lambda protocol: on_avdtp_connection(read, protocol))
listener.on(
'connection', lambda protocol: on_avdtp_connection(read, protocol)
)
# Become connectable and wait for a connection
await device.set_discoverable(True)

View File

@@ -16,21 +16,21 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import sys
import os
from bumble.device import AdvertisingType, Device
from bumble.hci import Address
from bumble.hci import *
from bumble.controller import *
from bumble.device import *
from bumble.transport import *
from bumble.host import *
from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 3:
print('Usage: run_advertiser.py <config-file> <transport-spec> [type] [address]')
print(
'Usage: run_advertiser.py <config-file> <transport-spec> [type] [address]'
)
print('example: run_advertiser.py device1.json usb:0')
return
@@ -56,6 +56,7 @@ async def main():
await device.start_advertising(advertising_type=advertising_type, target=target)
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

Some files were not shown because too many files have changed in this diff Show More