Compare commits

...

75 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
2478d45673 more windows compat fixes 2023-09-12 14:52:42 -07:00
Gilles Boccon-Gibod
1bc7d94111 windows NamedTemporaryFile compatibility 2023-09-12 14:33:12 -07:00
Gilles Boccon-Gibod
6432414cd5 run tests on windows and mac in addition to linux 2023-09-12 13:50:15 -07:00
Gilles Boccon-Gibod
179064ba15 run pre-commit tests with all supported Python versions 2023-09-12 13:42:33 -07:00
William Escande
783b2d70a5 Add connection parameter update from peripheral 2023-09-12 11:08:04 -07:00
zxzxwu
80824f3fc1 Merge pull request #280 from zxzxwu/device_typing
Add terminated to TransportSource protocol
2023-09-12 20:46:35 +08:00
Gilles Boccon-Gibod
56139c622f Merge pull request #258 from mogenson/vsc_tx_power
Add support for Zephyr HCI VSC set TX power command
2023-09-11 21:34:11 -07:00
Michael Mogenson
da02f6a39b Add HCI Zephyr vendor commands to read and write TX power
Create platforms/zephyr/hci.py with definitions of vendor HCI commands
to read and write TX power.

Add documentation for how to prepare an nRF52840 dongle with a Zephyr
HCI USB firmware application that includes dynamic TX power support and
how to send a write TX power vendor HCI command from Bumble.
2023-09-11 10:06:10 -04:00
Josh Wu
548d5597c0 Transport: Add termination protocol signature 2023-09-11 14:36:40 +08:00
zxzxwu
7fd65d2412 Merge pull request #279 from zxzxwu/typo
Fix typo
2023-09-11 03:02:11 +08:00
Josh Wu
05a54a4af9 Fix typo 2023-09-10 20:32:58 +08:00
Gilles Boccon-Gibod
1e00c8f456 Merge pull request #276 from google/gbg/add-zephyr-zip-to-docs
add zephyr binary to docs
2023-09-08 18:07:15 -07:00
Gilles Boccon-Gibod
90d165aa01 add zephyr binary 2023-09-08 14:17:15 -07:00
zxzxwu
01603ca9e4 Merge pull request #271 from zxzxwu/device_typing
Typing transport and relateds
2023-09-09 00:55:59 +08:00
Gilles Boccon-Gibod
a1b6eb61f2 Merge pull request #269 from google/gbg/android_vendor_hci
add support for vendor HCI commands and events
2023-09-08 08:50:49 -07:00
zxzxwu
25f300d3ec Merge pull request #270 from zxzxwu/typo
Fix typos
2023-09-08 17:32:33 +08:00
Josh Wu
41fe63df06 Fix typos 2023-09-08 16:30:06 +08:00
Josh Wu
b312170d5f Typing transport 2023-09-08 15:27:01 +08:00
David Duarte
cf7f2e8f44 Make platformdirs import lazy
platformdirs is not available in Android
2023-09-07 21:13:29 -07:00
Gilles Boccon-Gibod
d292083ed1 Merge pull request #272 from zxzxwu/gfp
Bring HfpProtocol back
2023-09-07 13:03:36 -07:00
Gilles Boccon-Gibod
9b11142b45 Merge pull request #267 from google/gbg/rfcomm-with-uuid
rfcomm with UUID
2023-09-07 13:01:56 -07:00
Hui Peng
acdbc4d7b9 Raise an exception when an L2cap connection fails 2023-09-07 19:24:38 +02:00
Josh Wu
838d10a09d Add HFP tests 2023-09-07 23:20:16 +08:00
Josh Wu
3852aa056b Bring HfpProtocol back 2023-09-07 23:20:09 +08:00
Gilles Boccon-Gibod
ae77e4528f add support for vendor HCI commands and events 2023-09-06 20:00:15 -07:00
Gilles Boccon-Gibod
9303f4fc5b Merge pull request #262 from whitevegagabriel/l2cap
Port l2cap_bridge sample to Rust
2023-09-06 17:13:12 -07:00
Gilles Boccon-Gibod
8be9f4cb0e add doc and fix types 2023-09-06 17:05:30 -07:00
Gilles Boccon-Gibod
1ea12b1bf7 rebase 2023-09-06 17:05:24 -07:00
Gilles Boccon-Gibod
65e6d68355 add tcp server 2023-09-06 16:49:21 -07:00
Gabriel White-Vega
9732eb8836 Address PR feedback 2023-09-06 09:47:08 -04:00
Gabriel White-Vega
5ae668bc70 Port l2cap_bridge sample to Rust
- Added Rust wrappers where relevant
- Edited a couple logs in python l2cap_bridge to be more symmetrical
- Created cli subcommand for running the rustified l2cap bridge
2023-09-05 16:03:02 -04:00
Gilles Boccon-Gibod
fd4d1bcca3 Merge pull request #261 from marshallpierce/mp/rust-realtek-tools
Rust tools for working with Realtek firmware
2023-09-05 10:55:29 -07:00
Gilles Boccon-Gibod
0a251c9f8e Merge pull request #265 from mogenson/grpcio-update
Update grpcio and pip package versions
2023-08-31 14:53:54 -07:00
Michael Mogenson
351d77be59 Update grpcio and pip package versions
The current grpcio version 1.51.1 fails to build on aarch64 based MacOS
computers. Update the version of the grpcio and grpcio-tools packages to
the latest 1.57.0 version. There are binary wheels available for this
version from PyPi for aarch64 MacOS.

Also update the pip version for the Conda environment. It seems a newer
version of pip is required to detect and install these wheels.

Testing:

invoke test passes and I can start the bumble-pandora-server
successfully.
2023-08-31 14:01:14 -04:00
Marshall Pierce
0e2fc80509 Rust tools for working with Realtek firmware
Further adventures in porting tools to Rust to flesh out the supported
API.

These tools didn't feel like `example`s, so I made a top level `bumble`
CLI tool that hosts them all as subcommands. I also moved the usb probe
not-really-an-`example` into it as well. I'm open to suggestions on how
best to organize the subcommands to make them intuitive to explore with
`--help`, and how to leave room for other future tools.

I also adopted the per-OS project data dir for a default firmware
location so that users can download once and then use those .bin files
from anywhere without having to sprinkle .bin files in project
directories or reaching inside the python package dir hierarchy.
2023-08-30 15:37:35 -06:00
Gilles Boccon-Gibod
8f3fdecb93 Merge pull request #263 from zxzxwu/pdu
Typing packet transmission flow
2023-08-30 11:15:12 -07:00
Josh Wu
249a205d8e Typing packet transmission flow 2023-08-30 01:47:46 +08:00
Gilles Boccon-Gibod
7485801222 Merge pull request #256 from zxzxwu/sdp-type-fix
Typing SDP and add tests
2023-08-28 08:41:02 -07:00
Gilles Boccon-Gibod
4678e59737 Merge pull request #250 from google/gbg/new-rtk-dongles
add entry to the list of supported USB devices
2023-08-28 08:40:40 -07:00
Gilles Boccon-Gibod
952d351c00 Merge pull request #247 from google/gbg/wasm-with-ws
wasm with ws
2023-08-28 08:40:18 -07:00
Josh Wu
901eb55b0e Add SDP self tests 2023-08-24 01:27:07 +08:00
Josh Wu
727586e40e Typing SDP 2023-08-23 14:52:44 +08:00
Gilles Boccon-Gibod
3aa678a58e Merge pull request #253 from zxzxwu/rfcomm_type_fix
Adding more typing in rfcomm.py
2023-08-22 09:47:38 -07:00
Gilles Boccon-Gibod
fc7c1a8113 Merge pull request #255 from zxzxwu/player
Remove accidentally added files
2023-08-22 07:34:31 -07:00
Josh Wu
f62a0bbe75 Remove accidentally added files 2023-08-22 22:12:41 +08:00
Josh Wu
7341172739 Use __future__.annotations for typing 2023-08-22 14:44:15 +08:00
Gilles Boccon-Gibod
91b9fbe450 Merge pull request #240 from zxzxwu/ssp
Handle SSP Complete events
2023-08-21 18:01:28 -07:00
Josh Wu
e6b566b848 RFCOMM: Refactor role to enum 2023-08-21 15:16:34 +08:00
Josh Wu
2527a711dc Refactor RFCOMM states to enum 2023-08-21 15:12:52 +08:00
Josh Wu
5fba6b1cae Complete typing in RFCOMM 2023-08-21 15:12:52 +08:00
Gilles Boccon-Gibod
43e632f83c Merge pull request #244 from google/gbg/hci-source-termination-mode
add sink method for lost transports
2023-08-18 10:17:11 -07:00
Gilles Boccon-Gibod
623298b0e9 emit flush event when transport lost 2023-08-18 09:59:15 -07:00
Gilles Boccon-Gibod
85a61dc39d add entry to the list of supported USB devices 2023-08-18 09:56:06 -07:00
Gilles Boccon-Gibod
6e8c44b5e6 Merge pull request #249 from zxzxwu/player
Support SBC in speaker.app
2023-08-18 09:55:23 -07:00
Josh Wu
ec4dcc174e Support SBC in speaker.app 2023-08-18 17:13:11 +08:00
Charlie Boutier
b247aca3b4 pandora_server: add support to accept bumble config file 2023-08-17 14:24:56 -07:00
Gilles Boccon-Gibod
6226bfd196 fix typo after refactor 2023-08-17 09:51:56 -07:00
Gilles Boccon-Gibod
71e11b7cf8 format 2023-08-15 15:20:48 -07:00
Gilles Boccon-Gibod
800c62fdb6 add readme for web examples 2023-08-15 15:17:38 -07:00
Gilles Boccon-Gibod
640b9cd53a refactor pyiodide support and add examples 2023-08-15 13:36:58 -07:00
Gilles Boccon-Gibod
f4add16aea Merge pull request #241 from hchataing/hfp-hf
hfp: Implement initiate SLC procedure for HFP-HF
2023-08-14 10:32:55 -07:00
Gilles Boccon-Gibod
2bfec3c4ed add sink method for lost transports 2023-08-12 10:54:20 -07:00
Henri Chataing
9963b51c04 hfp: Implement initiate SLC procedure for HFP-HF 2023-08-10 08:37:54 -07:00
Josh Wu
2af3494d8c Handle SSP Complete events 2023-08-10 10:58:41 +08:00
Gilles Boccon-Gibod
fe28473ba8 Merge pull request #234 from zxzxwu/addr
Support address resolution offload
2023-08-08 21:30:13 -07:00
Gilles Boccon-Gibod
53d66bc74a Merge pull request #237 from marshallpierce/mp/company-ids
Faster company id table
2023-08-08 21:29:45 -07:00
Marshall Pierce
e2c1ad5342 Faster company id table
Following up on the [loose end from the initial
PR](https://github.com/google/bumble/pull/207#discussion_r1278015116),
we can avoid accessing the Python company id map at runtime by doing
code gen ahead of time.

Using an example to do the code gen avoids even the small build slowdown
from invoking the code gen logic in build.rs, but more importantly,
means that it's still a totally boring normal build that won't require
any IDE setup, etc, to work for everyone. Since the company ID list
changes rarely, and there's a test to ensure it always matches, this
seems like a good trade.
2023-08-04 10:12:52 -06:00
Josh Wu
6399c5fb04 Auto add device to resolving list after pairing 2023-08-03 20:51:00 +08:00
Josh Wu
784cf4f26a Add a flag to enable LE address resolution 2023-08-03 20:50:57 +08:00
Josh Wu
0301b1a999 Pandora: Configure identity address type 2023-08-02 11:31:07 -07:00
Lucas Abel
3ab2cd5e71 pandora: decrease all info logs to debug 2023-08-02 10:56:41 -07:00
uael
6ea669531a pandora: add tcp option to transport configuration
* Add a fallback to `tcp` when `transport` is not set.
* Default the `tcp` transport to the default rootcanal HCI address.
2023-08-01 08:51:12 -07:00
Josh Wu
cbbada4748 SMP: Delegate distributed address type 2023-08-01 08:38:03 -07:00
Gilles Boccon-Gibod
152b8d1233 Merge pull request #230 from google/gbg/hci-object-array
add support for field arrays in hci packet definitions
2023-08-01 07:44:31 -07:00
Gilles Boccon-Gibod
bdad225033 add support for field arrays in hci packet definitions 2023-07-30 22:19:10 -07:00
114 changed files with 9936 additions and 1549 deletions

View File

@@ -14,6 +14,10 @@ jobs:
   check:
     name: Check Code
     runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.8", "3.9", "3.10", "3.11"]
+      fail-fast: false
     steps:
       - name: Check out from Git

View File

@@ -12,10 +12,10 @@ permissions:
 jobs:
   build:
-    runs-on: ubuntu-latest
+    runs-on: ${{ matrix.os }}
     strategy:
       matrix:
+        os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
         python-version: ["3.8", "3.9", "3.10", "3.11"]
       fail-fast: false
@@ -41,11 +41,13 @@ jobs:
         run: |
           inv build
           inv build.mkdocs
   build-rust:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [ "3.8", "3.9", "3.10" ]
+        python-version: [ "3.8", "3.9", "3.10", "3.11" ]
+        rust-version: [ "1.70.0", "stable" ]
       fail-fast: false
     steps:
       - name: Check out from Git
@@ -62,9 +64,15 @@ jobs:
         uses: actions-rust-lang/setup-rust-toolchain@v1
         with:
           components: clippy,rustfmt
-      - name: Rust Lints
-        run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings
+          toolchain: ${{ matrix.rust-version }}
       - name: Rust Build
-        run: cd rust && cargo build --all-targets
+        run: cd rust && cargo build --all-targets && cargo build --all-features --all-targets
+      # Lints after build so what clippy needs is already built
+      - name: Rust Lints
+        run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings && cargo clippy --all-features --all-targets -- --deny warnings
       - name: Rust Tests
         run: cd rust && cargo test
+      # At some point, hook up publishing the binary. For now, just make sure it builds.
+      # Once we're ready to publish binaries, this should be built with `--release`.
+      - name: Build Bumble CLI
+        run: cd rust && cargo build --features bumble-tools --bin bumble

View File

@@ -105,7 +105,7 @@ class ServerBridge:
             asyncio.create_task(self.pipe.l2cap_channel.disconnect())

         def data_received(self, data):
-            print(f'<<< Received on TCP: {len(data)}')
+            print(color(f'<<< [TCP DATA]: {len(data)} bytes', 'blue'))
             self.pipe.l2cap_channel.write(data)

         try:
@@ -123,6 +123,7 @@ class ServerBridge:
         await self.l2cap_channel.disconnect()

     def on_l2cap_close(self):
+        print(color('*** L2CAP channel closed', 'red'))
         self.l2cap_channel = None
         if self.tcp_transport is not None:
             self.tcp_transport.close()

View File

@@ -1,8 +1,10 @@
 import asyncio
 import click
 import logging
+import json

 from bumble.pandora import PandoraDevice, serve
+from typing import Dict, Any

 BUMBLE_SERVER_GRPC_PORT = 7999
 ROOTCANAL_PORT_CUTTLEFISH = 7300
@@ -18,13 +20,30 @@ ROOTCANAL_PORT_CUTTLEFISH = 7300
     help='HCI transport',
     default=f'tcp-client:127.0.0.1:<rootcanal-port>',
 )
-def main(grpc_port: int, rootcanal_port: int, transport: str) -> None:
+@click.option(
+    '--config',
+    help='Bumble json configuration file',
+)
+def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> None:
     if '<rootcanal-port>' in transport:
         transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
-    device = PandoraDevice({'transport': transport})
+
+    bumble_config = retrieve_config(config)
+    if 'transport' not in bumble_config.keys():
+        bumble_config.update({'transport': transport})
+    device = PandoraDevice(bumble_config)

     logging.basicConfig(level=logging.DEBUG)
     asyncio.run(serve(device, port=grpc_port))

+def retrieve_config(config: str) -> Dict[str, Any]:
+    if not config:
+        return {}
+
+    with open(config, 'r') as f:
+        return json.load(f)

 if __name__ == '__main__':
     main()  # pylint: disable=no-value-for-parameter
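The merge logic in this diff gives a `transport` key from the config file precedence over the CLI `--transport` value, which is only used as a fallback. A minimal sketch of that precedence rule, with illustrative transport strings (the file path and values are made up for the example):

```python
import json
import tempfile
from typing import Any, Dict


def retrieve_config(path: str) -> Dict[str, Any]:
    # Mirrors the helper in the diff: an empty path means an empty config.
    if not path:
        return {}
    with open(path, 'r') as f:
        return json.load(f)


def effective_config(config_path: str, cli_transport: str) -> Dict[str, Any]:
    # A 'transport' key in the file wins; the CLI option is the fallback.
    config = retrieve_config(config_path)
    config.setdefault('transport', cli_transport)
    return config


# A config file that sets its own transport overrides the CLI value.
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump({'transport': 'serial:/dev/ttyUSB0'}, f)
    path = f.name

print(effective_config(path, 'tcp-client:127.0.0.1:7300')['transport'])
# serial:/dev/ttyUSB0
print(effective_config('', 'tcp-client:127.0.0.1:7300')['transport'])
# tcp-client:127.0.0.1:7300
```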

View File

@@ -102,9 +102,21 @@ class SnoopPacketReader:
     default='h4',
     help='Format of the input file',
 )
+@click.option(
+    '--vendors',
+    type=click.Choice(['android', 'zephyr']),
+    multiple=True,
+    help='Support vendor-specific commands (list one or more)',
+)
 @click.argument('filename')
 # pylint: disable=redefined-builtin
-def main(format, filename):
+def main(format, vendors, filename):
+    for vendor in vendors:
+        if vendor == 'android':
+            import bumble.vendor.android.hci
+        elif vendor == 'zephyr':
+            import bumble.vendor.zephyr.hci
+
     input = open(filename, 'rb')
     if format == 'h4':
         packet_reader = PacketReader(input)
@@ -124,7 +136,6 @@ def main(format, filename):
         if packet is None:
             break
         tracer.trace(hci.HCI_Packet.from_bytes(packet), direction)
-
     except Exception as error:
         print(color(f'!!! {error}', 'red'))
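The `--vendors` option imports the selected vendor modules purely for their side effects (in Bumble's case, registering extra vendor HCI command decoders). A generic sketch of that import-for-side-effect pattern; the registry helper and its name are illustrative, not part of Bumble's API:

```python
import importlib
from typing import Dict, List


def load_vendor_modules(vendors: List[str], registry: Dict[str, str]):
    """Import each selected vendor module, relying on import side effects
    (e.g. module-level registration of extra protocol decoders)."""
    return [importlib.import_module(registry[vendor]) for vendor in vendors]


# Illustrative registry mirroring the --vendors choices in the diff above.
VENDOR_MODULES = {
    'android': 'bumble.vendor.android.hci',
    'zephyr': 'bumble.vendor.zephyr.hci',
}
```

Calling `load_vendor_modules(['zephyr'], VENDOR_MODULES)` would then pull in only the decoders the user asked for, keeping startup cheap when no vendor support is needed.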

View File

@@ -56,7 +56,7 @@ body, h1, h2, h3, h4, h5, h6 {
     border-radius: 4px;
     padding: 4px;
     margin: 6px;
-    margin-left: 0px;
+    margin-left: 0;
 }

 th, td {
@@ -65,7 +65,7 @@ th, td {
 }

 .properties td:nth-child(even) {
-    background-color: #D6EEEE;
+    background-color: #d6eeee;
     font-family: monospace;
 }

View File

@@ -2,7 +2,7 @@
 <html>
 <head>
     <title>Bumble Speaker</title>
-    <script type="text/javascript" src="speaker.js"></script>
+    <script src="speaker.js"></script>
     <link rel="stylesheet" href="speaker.css">
 </head>
 <body>

View File

@@ -228,10 +228,11 @@ class FfplayOutput(QueuedOutput):
     subprocess: Optional[asyncio.subprocess.Process]
     ffplay_task: Optional[asyncio.Task]

-    def __init__(self) -> None:
-        super().__init__(AacAudioExtractor())
+    def __init__(self, codec: str) -> None:
+        super().__init__(AudioExtractor.create(codec))
         self.subprocess = None
         self.ffplay_task = None
+        self.codec = codec

     async def start(self):
         if self.started:
@@ -240,7 +241,7 @@ class FfplayOutput(QueuedOutput):
         await super().start()

         self.subprocess = await asyncio.create_subprocess_shell(
-            'ffplay -acodec aac pipe:0',
+            f'ffplay -f {self.codec} pipe:0',
             stdin=asyncio.subprocess.PIPE,
             stdout=asyncio.subprocess.PIPE,
             stderr=asyncio.subprocess.PIPE,
@@ -419,7 +420,7 @@ class Speaker:
         self.outputs = []
         for output in outputs:
             if output == '@ffplay':
-                self.outputs.append(FfplayOutput())
+                self.outputs.append(FfplayOutput(codec))
                 continue

             # Default to FileOutput
@@ -708,17 +709,6 @@ def speaker(
 ):
     """Run the speaker."""

-    # ffplay only works with AAC for now
-    if codec != 'aac' and '@ffplay' in output:
-        print(
-            color(
-                f'{codec} not supported with @ffplay output, '
-                '@ffplay output will be skipped',
-                'yellow',
-            )
-        )
-        output = list(filter(lambda x: x != '@ffplay', output))
-
     if '@ffplay' in output:
         # Check if ffplay is installed
         try:
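The `FfplayOutput` change above spawns a stdin-reading player as an asyncio subprocess and pipes decoded audio into it. A runnable sketch of that shape, using `cat` as a stand-in for `ffplay -f <codec> pipe:0` so the example works without ffplay installed:

```python
import asyncio


async def pipe_to_player(command: str, chunks) -> bytes:
    # Same shape as FfplayOutput.start() in the diff: spawn a player that
    # reads audio from stdin. 'cat' echoes the bytes back so we can verify
    # what the player would have received.
    proc = await asyncio.create_subprocess_shell(
        command,
        stdin=asyncio.subprocess.PIPE,
        stdout=asyncio.subprocess.PIPE,
    )
    stdout, _ = await proc.communicate(b''.join(chunks))
    return stdout


echoed = asyncio.run(pipe_to_player('cat', [b'frame1', b'frame2']))
print(echoed)  # b'frame1frame2'
```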

bumble/at.py Normal file
View File

@@ -0,0 +1,85 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union


def tokenize_parameters(buffer: bytes) -> List[bytes]:
    """Split input parameters into tokens.

    Removes space characters outside of double quote blocks:
    T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
    are ignored [..], unless they are embedded in numeric or string constants"

    Raises ValueError in case of invalid input string."""
    tokens = []
    in_quotes = False
    token = bytearray()
    for b in buffer:
        char = bytearray([b])
        if in_quotes:
            token.extend(char)
            if char == b'\"':
                in_quotes = False
                tokens.append(token[1:-1])
                token = bytearray()
        else:
            if char == b' ':
                pass
            elif char == b',' or char == b')':
                tokens.append(token)
                tokens.append(char)
                token = bytearray()
            elif char == b'(':
                if len(token) > 0:
                    raise ValueError("open_paren following regular character")
                tokens.append(char)
            elif char == b'"':
                if len(token) > 0:
                    raise ValueError("quote following regular character")
                in_quotes = True
                token.extend(char)
            else:
                token.extend(char)
    tokens.append(token)
    return [bytes(token) for token in tokens if len(token) > 0]


def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
    """Parse the parameters using the comma and parenthesis separators.

    Raises ValueError in case of invalid input string."""
    tokens = tokenize_parameters(buffer)
    accumulator: List[list] = [[]]
    current: Union[bytes, list] = bytes()
    for token in tokens:
        if token == b',':
            accumulator[-1].append(current)
            current = bytes()
        elif token == b'(':
            accumulator.append([])
        elif token == b')':
            if len(accumulator) < 2:
                raise ValueError("close_paren without matching open_paren")
            accumulator[-1].append(current)
            current = accumulator.pop()
        else:
            current = token
    accumulator[-1].append(current)
    if len(accumulator) > 1:
        raise ValueError("missing close_paren")
    return accumulator[0]
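To illustrate how `parse_parameters` nests parenthesized groups, here is its grouping pass restated in isolation, fed an already-tokenized parameter list (the tokens `tokenize_parameters(b'"ab",(1,2)')` would produce):

```python
def group(tokens):
    # Restatement of the grouping loop in parse_parameters above:
    # commas flush the current value, parens push/pop a nesting level.
    stack = [[]]
    current = b''
    for token in tokens:
        if token == b',':
            stack[-1].append(current)
            current = b''
        elif token == b'(':
            stack.append([])
        elif token == b')':
            if len(stack) < 2:
                raise ValueError("close_paren without matching open_paren")
            stack[-1].append(current)
            current = stack.pop()
        else:
            current = token
    stack[-1].append(current)
    if len(stack) > 1:
        raise ValueError("missing close_paren")
    return stack[0]


print(group([b'ab', b',', b'(', b'1', b',', b'2', b')']))
# [b'ab', [b'1', b'2']]
```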

View File

@@ -15,6 +15,8 @@
 # -----------------------------------------------------------------------------
 # Imports
 # -----------------------------------------------------------------------------
+from __future__ import annotations
+
 import logging
 import asyncio
 import itertools
@@ -58,8 +60,10 @@ from bumble.hci import (
     HCI_Packet,
     HCI_Role_Change_Event,
 )
-from typing import Optional, Union, Dict
+from typing import Optional, Union, Dict, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from bumble.transport.common import TransportSink, TransportSource

 # -----------------------------------------------------------------------------
 # Logging
@@ -104,7 +108,7 @@ class Controller:
         self,
         name,
         host_source=None,
-        host_sink=None,
+        host_sink: Optional[TransportSink] = None,
         link=None,
         public_address: Optional[Union[bytes, str, Address]] = None,
     ):
@@ -188,6 +192,8 @@ class Controller:
         if link:
             link.add_controller(self)

+        self.terminated = asyncio.get_running_loop().create_future()
+
     @property
     def host(self):
         return self.hci_sink
@@ -288,10 +294,9 @@ class Controller:
         if self.host:
             self.host.on_packet(packet.to_bytes())

-    # This method allow the controller to emulate the same API as a transport source
+    # This method allows the controller to emulate the same API as a transport source
     async def wait_for_termination(self):
-        # For now, just wait forever
-        await asyncio.get_running_loop().create_future()
+        await self.terminated

     ############################################################
     # Link connections
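The change above replaces "wait forever" with a future that gets resolved when the controller terminates, so `wait_for_termination` can actually return. A minimal, self-contained sketch of that future-based termination pattern (class and method names here are illustrative):

```python
import asyncio


class FakeSource:
    """Transport-source-like object using the terminated-future pattern."""

    def __init__(self):
        # Resolved exactly once, when the source shuts down.
        self.terminated = asyncio.get_running_loop().create_future()

    def close(self):
        if not self.terminated.done():
            self.terminated.set_result(None)

    async def wait_for_termination(self):
        # Returns as soon as close() resolves the future.
        await self.terminated


async def main():
    source = FakeSource()
    # Simulate the transport going away shortly after startup.
    asyncio.get_running_loop().call_later(0.01, source.close)
    await source.wait_for_termination()
    print('terminated')


asyncio.run(main())
```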

View File

@@ -78,7 +78,13 @@ def get_dict_key_by_value(dictionary, value):
 class BaseError(Exception):
     """Base class for errors with an error code, error name and namespace"""

-    def __init__(self, error_code, error_namespace='', error_name='', details=''):
+    def __init__(
+        self,
+        error_code: int | None,
+        error_namespace: str = '',
+        error_name: str = '',
+        details: str = '',
+    ):
         super().__init__()
         self.error_code = error_code
         self.error_namespace = error_namespace
@@ -90,12 +96,14 @@ class BaseError(Exception):
             namespace = f'{self.error_namespace}/'
         else:
             namespace = ''
-        if self.error_name:
-            name = f'{self.error_name} [0x{self.error_code:X}]'
-        else:
-            name = f'0x{self.error_code:X}'
+        error_text = {
+            (True, True): f'{self.error_name} [0x{self.error_code:X}]',
+            (True, False): self.error_name,
+            (False, True): f'0x{self.error_code:X}',
+            (False, False): '',
+        }[(self.error_name != '', self.error_code is not None)]
-        return f'{type(self).__name__}({namespace}{name})'
+        return f'{type(self).__name__}({namespace}{error_text})'

 class ProtocolError(BaseError):

@@ -134,6 +142,10 @@ class ConnectionError(BaseError):  # pylint: disable=redefined-builtin
         self.peer_address = peer_address

+class ConnectionParameterUpdateError(BaseError):
+    """Connection Parameter Update Error"""

 # -----------------------------------------------------------------------------
 # UUID
 #
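The `BaseError.__str__` rewrite above dispatches on the `(has_name, has_code)` pair through a dict. A standalone sketch of that dispatch; this variant wraps each entry in a lambda so only the selected case is formatted (the diff builds all four strings eagerly, which this sketch deliberately avoids):

```python
from typing import Optional


def error_text(error_name: str, error_code: Optional[int]) -> str:
    # Dispatch on (has_name, has_code); lambdas defer evaluation so
    # formatting a missing error_code is never attempted.
    return {
        (True, True): lambda: f'{error_name} [0x{error_code:X}]',
        (True, False): lambda: error_name,
        (False, True): lambda: f'0x{error_code:X}',
        (False, False): lambda: '',
    }[(error_name != '', error_code is not None)]()


print(error_text('ConnectionError', 0x3E))  # ConnectionError [0x3E]
print(error_text('', 0x3E))                 # 0x3E
print(error_text('Timeout', None))          # Timeout
```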

View File

@@ -23,22 +23,18 @@
 # -----------------------------------------------------------------------------
 import logging
 import operator
-import platform

-if platform.system() != 'Emscripten':
-    import secrets
-    from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
-    from cryptography.hazmat.primitives.asymmetric.ec import (
-        generate_private_key,
-        ECDH,
-        EllipticCurvePublicNumbers,
-        EllipticCurvePrivateNumbers,
-        SECP256R1,
-    )
-    from cryptography.hazmat.primitives import cmac
-else:
-    # TODO: implement stubs
-    pass
+import secrets
+from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+from cryptography.hazmat.primitives.asymmetric.ec import (
+    generate_private_key,
+    ECDH,
+    EllipticCurvePublicNumbers,
+    EllipticCurvePrivateNumbers,
+    SECP256R1,
+)
+from cryptography.hazmat.primitives import cmac

 # -----------------------------------------------------------------------------
 # Logging

View File

@@ -23,7 +23,18 @@ import asyncio
import logging import logging
from contextlib import asynccontextmanager, AsyncExitStack from contextlib import asynccontextmanager, AsyncExitStack
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Tuple,
Type,
Union,
TYPE_CHECKING,
)
from .colors import color from .colors import color
from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU from .att import ATT_CID, ATT_DEFAULT_MTU, ATT_PDU
@@ -86,6 +97,7 @@ from .hci import (
HCI_LE_Extended_Create_Connection_Command, HCI_LE_Extended_Create_Connection_Command,
HCI_LE_Rand_Command, HCI_LE_Rand_Command,
HCI_LE_Read_PHY_Command, HCI_LE_Read_PHY_Command,
HCI_LE_Set_Address_Resolution_Enable_Command,
HCI_LE_Set_Advertising_Data_Command, HCI_LE_Set_Advertising_Data_Command,
HCI_LE_Set_Advertising_Enable_Command, HCI_LE_Set_Advertising_Enable_Command,
HCI_LE_Set_Advertising_Parameters_Command, HCI_LE_Set_Advertising_Parameters_Command,
@@ -129,6 +141,7 @@ from .core import (
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
BT_PERIPHERAL_ROLE, BT_PERIPHERAL_ROLE,
AdvertisingData, AdvertisingData,
ConnectionParameterUpdateError,
CommandTimeoutError, CommandTimeoutError,
ConnectionPHY, ConnectionPHY,
InvalidStateError, InvalidStateError,
@@ -151,6 +164,9 @@ from . import sdp
from . import l2cap from . import l2cap
from . import core from . import core
if TYPE_CHECKING:
from .transport.common import TransportSource, TransportSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -651,7 +667,7 @@ class Connection(CompositeEventEmitter):
def is_incomplete(self) -> bool: def is_incomplete(self) -> bool:
return self.handle is None return self.handle is None
def send_l2cap_pdu(self, cid, pdu): def send_l2cap_pdu(self, cid: int, pdu: bytes) -> None:
self.device.send_l2cap_pdu(self.handle, cid, pdu) self.device.send_l2cap_pdu(self.handle, cid, pdu)
def create_l2cap_connector(self, psm): def create_l2cap_connector(self, psm):
@@ -708,6 +724,7 @@ class Connection(CompositeEventEmitter):
connection_interval_max, connection_interval_max,
max_latency, max_latency,
supervision_timeout, supervision_timeout,
use_l2cap=False,
): ):
return await self.device.update_connection_parameters( return await self.device.update_connection_parameters(
self, self,
@@ -715,6 +732,7 @@ class Connection(CompositeEventEmitter):
connection_interval_max, connection_interval_max,
max_latency, max_latency,
supervision_timeout, supervision_timeout,
use_l2cap=use_l2cap,
) )
async def set_phy(self, tx_phys=None, rx_phys=None, phy_options=None): async def set_phy(self, tx_phys=None, rx_phys=None, phy_options=None):
@@ -778,6 +796,7 @@ class DeviceConfiguration:
self.irk = bytes(16) # This really must be changed for any level of security self.irk = bytes(16) # This really must be changed for any level of security
self.keystore = None self.keystore = None
self.gatt_services: List[Dict[str, Any]] = [] self.gatt_services: List[Dict[str, Any]] = []
self.address_resolution_offload = False
def load_from_dict(self, config: Dict[str, Any]) -> None: def load_from_dict(self, config: Dict[str, Any]) -> None:
# Load simple properties # Load simple properties
@@ -940,7 +959,13 @@ class Device(CompositeEventEmitter):
pass pass
@classmethod @classmethod
def with_hci(cls, name, address, hci_source, hci_sink): def with_hci(
cls,
name: str,
address: Address,
hci_source: TransportSource,
hci_sink: TransportSink,
) -> Device:
''' '''
Create a Device instance with a Host configured to communicate with a controller Create a Device instance with a Host configured to communicate with a controller
through an HCI source/sink through an HCI source/sink
@@ -949,18 +974,25 @@ class Device(CompositeEventEmitter):
return cls(name=name, address=address, host=host) return cls(name=name, address=address, host=host)
@classmethod @classmethod
def from_config_file(cls, filename): def from_config_file(cls, filename: str) -> Device:
config = DeviceConfiguration() config = DeviceConfiguration()
config.load_from_file(filename) config.load_from_file(filename)
return cls(config=config) return cls(config=config)
@classmethod @classmethod
def from_config_with_hci(cls, config, hci_source, hci_sink): def from_config_with_hci(
cls,
config: DeviceConfiguration,
hci_source: TransportSource,
hci_sink: TransportSink,
) -> Device:
host = Host(controller_source=hci_source, controller_sink=hci_sink) host = Host(controller_source=hci_source, controller_sink=hci_sink)
return cls(config=config, host=host) return cls(config=config, host=host)
@classmethod @classmethod
def from_config_file_with_hci(cls, filename, hci_source, hci_sink): def from_config_file_with_hci(
cls, filename: str, hci_source: TransportSource, hci_sink: TransportSink
) -> Device:
config = DeviceConfiguration() config = DeviceConfiguration()
config.load_from_file(filename) config.load_from_file(filename)
return cls.from_config_with_hci(config, hci_source, hci_sink) return cls.from_config_with_hci(config, hci_source, hci_sink)
@@ -1029,6 +1061,7 @@ class Device(CompositeEventEmitter):
self.discoverable = config.discoverable self.discoverable = config.discoverable
self.connectable = config.connectable self.connectable = config.connectable
self.classic_accept_any = config.classic_accept_any self.classic_accept_any = config.classic_accept_any
self.address_resolution_offload = config.address_resolution_offload
for service in config.gatt_services: for service in config.gatt_services:
characteristics = [] characteristics = []
@@ -1093,7 +1126,7 @@ class Device(CompositeEventEmitter):
    return self._host

@host.setter
def host(self, host: Host) -> None:
    # Unsubscribe from events from the current host
    if self._host:
        for event_name in device_host_event_handlers:
@@ -1180,7 +1213,7 @@ class Device(CompositeEventEmitter):
        connection, psm, max_credits, mtu, mps
    )

def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
    self.host.send_l2cap_pdu(connection_handle, cid, pdu)

async def send_command(self, command, check_result=False):
@@ -1256,31 +1289,16 @@ class Device(CompositeEventEmitter):
)

# Load the address resolving list
if self.keystore:
    await self.refresh_resolving_list()

# Enable address resolution
if self.address_resolution_offload:
    await self.send_command(
        HCI_LE_Set_Address_Resolution_Enable_Command(
            address_resolution_enable=1
        )  # type: ignore[call-arg]
    )
if self.classic_enabled:
    await self.send_command(
@@ -1310,6 +1328,26 @@ class Device(CompositeEventEmitter):
    await self.host.flush()
    self.powered_on = False
async def refresh_resolving_list(self) -> None:
assert self.keystore is not None
resolving_keys = await self.keystore.get_resolving_keys()
# Create a host-side address resolver
self.address_resolver = smp.AddressResolver(resolving_keys)
if self.address_resolution_offload:
await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg]
for irk, address in resolving_keys:
await self.send_command(
HCI_LE_Add_Device_To_Resolving_List_Command(
peer_identity_address_type=address.address_type,
peer_identity_address=address,
peer_irk=irk,
local_irk=self.irk,
) # type: ignore[call-arg]
)
def supports_le_feature(self, feature):
    return self.host.supports_le_feature(feature)
@@ -2075,11 +2113,30 @@ class Device(CompositeEventEmitter):
    supervision_timeout,
    min_ce_length=0,
    max_ce_length=0,
    use_l2cap=False,
) -> None:
    '''
    NOTE: the name of the parameters may look odd, but it just follows the names
    used in the Bluetooth spec.
    '''
if use_l2cap:
if connection.role != BT_PERIPHERAL_ROLE:
raise InvalidStateError(
'only peripheral can update connection parameters with l2cap'
)
l2cap_result = (
await self.l2cap_channel_manager.update_connection_parameters(
connection,
connection_interval_min,
connection_interval_max,
max_latency,
supervision_timeout,
)
)
        if l2cap_result != l2cap.L2CAP_CONNECTION_PARAMETERS_ACCEPTED_RESULT:
            raise ConnectionParameterUpdateError(l2cap_result)
        return
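The peripheral-initiated path above goes through an L2CAP Connection Parameter Update Request rather than an HCI command. As a self-contained sketch, independent of Bumble's `l2cap_channel_manager`, the request PDU (signaling code 0x12 per the Core Spec, Vol 3, Part A) packs four little-endian 16-bit fields behind a signaling header; the helper name here is hypothetical:

```python
import struct

L2CAP_CONNECTION_PARAMETER_UPDATE_REQUEST = 0x12  # signaling code, Core Spec Vol 3, Part A

def make_update_request(identifier, interval_min, interval_max, latency, timeout):
    # Payload: four little-endian 16-bit values (interval min/max, latency,
    # supervision timeout), prefixed with the signaling header
    # (code, identifier, payload length).
    payload = struct.pack('<HHHH', interval_min, interval_max, latency, timeout)
    header = struct.pack(
        '<BBH', L2CAP_CONNECTION_PARAMETER_UPDATE_REQUEST, identifier, len(payload)
    )
    return header + payload
```

The central answers with a Connection Parameter Update Response carrying the accepted/rejected result that the code above checks against `L2CAP_CONNECTION_PARAMETERS_ACCEPTED_RESULT`.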
    result = await self.send_command(
        HCI_LE_Connection_Update_Command(
            connection_handle=connection.handle,
@@ -2089,7 +2146,7 @@ class Device(CompositeEventEmitter):
            supervision_timeout=supervision_timeout,
            min_ce_length=min_ce_length,
            max_ce_length=max_ce_length,
        )  # type: ignore[call-arg]
    )
    if result.status != HCI_Command_Status_Event.PENDING:
        raise HCI_StatusError(result)
@@ -2230,9 +2287,11 @@ class Device(CompositeEventEmitter):
def request_pairing(self, connection):
    return self.smp_manager.request_pairing(connection)

async def get_long_term_key(
    self, connection_handle: int, rand: bytes, ediv: int
) -> Optional[bytes]:
    if (connection := self.lookup_connection(connection_handle)) is None:
        return None

    # Start by looking for the key in an SMP session
    ltk = self.smp_manager.get_long_term_key(connection, rand, ediv)
@@ -2252,19 +2311,24 @@ class Device(CompositeEventEmitter):
    if connection.role == BT_PERIPHERAL_ROLE and keys.ltk_peripheral:
        return keys.ltk_peripheral.value
    return None
async def get_link_key(self, address: Address) -> Optional[bytes]:
    if self.keystore is None:
        return None

    # Look for the key in the keystore
    keys = await self.keystore.get(str(address))
    if keys is None:
        logger.debug(f'no keys found for {address}')
        return None

    logger.debug('found keys in the key store')
    if keys.link_key is None:
        logger.warning('no link key')
        return None

    return keys.link_key.value
# [Classic only]
async def authenticate(self, connection):
@@ -2383,6 +2447,18 @@ class Device(CompositeEventEmitter):
    'connection_encryption_failure', on_encryption_failure
)
async def update_keys(self, address: str, keys: PairingKeys) -> None:
if self.keystore is None:
return
try:
await self.keystore.update(address, keys)
await self.refresh_resolving_list()
except Exception as error:
logger.warning(f'!!! error while storing keys: {error}')
else:
self.emit('key_store_update')
# [Classic only]
async def switch_role(self, connection: Connection, role: int):
    pending_role_change = asyncio.get_running_loop().create_future()
@@ -2477,13 +2553,7 @@ class Device(CompositeEventEmitter):
        value=link_key, authenticated=authenticated
    )
self.abort_on('flush', self.update_keys(str(bd_addr), pairing_keys))

if connection := self.find_connection_by_bd_addr(
    bd_addr, transport=BT_BR_EDR_TRANSPORT
@@ -2735,20 +2805,6 @@ class Device(CompositeEventEmitter):
    )
    connection.emit('connection_authentication_failure', error)
@host_event_handler
@with_connection_from_address
def on_ssp_complete(self, connection):
# On Secure Simple Pairing complete, in case:
# - Connection isn't already authenticated
# - AND we are not the initiator of the authentication
# We must trigger authentication to know if we are truly authenticated
if not connection.authenticating and not connection.authenticated:
logger.debug(
f'*** Trigger Connection Authentication: [0x{connection.handle:04X}] '
f'{connection.peer_address}'
)
asyncio.create_task(connection.authenticate())
# [Classic only]
@host_event_handler
@with_connection_from_address
@@ -3103,6 +3159,18 @@ class Device(CompositeEventEmitter):
    connection.emit('role_change_failure', error)
    self.emit('role_change_failure', address, error)
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_classic_pairing(self, connection: Connection) -> None:
connection.emit('classic_pairing')
# [Classic only]
@host_event_handler
@with_connection_from_address
def on_classic_pairing_failure(self, connection: Connection, status) -> None:
connection.emit('classic_pairing_failure', status)
def on_pairing_start(self, connection: Connection) -> None:
    connection.emit('pairing_start')
@@ -3151,7 +3219,7 @@ class Device(CompositeEventEmitter):
@host_event_handler
@with_connection_from_handle
def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes):
    self.l2cap_channel_manager.on_pdu(connection, cid, pdu)

def __str__(self):


@@ -21,6 +21,8 @@ like loading firmware after a cold start.
# -----------------------------------------------------------------------------
import abc
import logging
import pathlib
import platform

from . import rtk

@@ -66,3 +68,24 @@ async def get_driver_for_host(host):
        return driver
    return None
def project_data_dir() -> pathlib.Path:
"""
Returns:
A path to an OS-specific directory for bumble data. The directory is created if
it doesn't exist.
"""
import platformdirs
if platform.system() == 'Darwin':
# platformdirs doesn't handle macOS right: it doesn't assemble a bundle id
# out of author & project
return platformdirs.user_data_path(
appname='com.google.bumble', ensure_exists=True
)
else:
# windows and linux don't use the com qualifier
return platformdirs.user_data_path(
appname='bumble', appauthor='google', ensure_exists=True
)


@@ -34,10 +34,9 @@ import weakref
from bumble.hci import (
    hci_vendor_command_op_code,
    STATUS_SPEC,
    HCI_SUCCESS,
    HCI_Command,
    HCI_Reset_Command,
    HCI_Read_Local_Version_Information_Command,
@@ -125,6 +124,7 @@ RTK_USB_PRODUCTS = {
    (0x2550, 0x8761),
    (0x2B89, 0x8761),
    (0x7392, 0xC611),
    (0x0BDA, 0x877B),
    # Realtek 8821AE
    (0x0B05, 0x17DC),
    (0x13D3, 0x3414),
@@ -178,8 +178,10 @@ RTK_USB_PRODUCTS = {
# -----------------------------------------------------------------------------
# HCI Commands
# -----------------------------------------------------------------------------
HCI_RTK_READ_ROM_VERSION_COMMAND = hci_vendor_command_op_code(0x6D)
HCI_RTK_DOWNLOAD_COMMAND = hci_vendor_command_op_code(0x20)
HCI_RTK_DROP_FIRMWARE_COMMAND = hci_vendor_command_op_code(0x66)

HCI_Command.register_commands(globals())
@HCI_Command.command(return_parameters_fields=[("status", STATUS_SPEC), ("version", 1)])
@@ -187,10 +189,6 @@ class HCI_RTK_Read_ROM_Version_Command(HCI_Command):
    pass
HCI_RTK_DOWNLOAD_COMMAND = hci_command_op_code(0x3F, 0x20)
HCI_COMMAND_NAMES[HCI_RTK_DOWNLOAD_COMMAND] = "HCI_RTK_DOWNLOAD_COMMAND"
@HCI_Command.command(
    fields=[("index", 1), ("payload", RTK_FRAGMENT_LENGTH)],
    return_parameters_fields=[("status", STATUS_SPEC), ("index", 1)],
)
@@ -199,10 +197,6 @@ class HCI_RTK_Download_Command(HCI_Command):
    pass
HCI_RTK_DROP_FIRMWARE_COMMAND = hci_command_op_code(0x3F, 0x66)
HCI_COMMAND_NAMES[HCI_RTK_DROP_FIRMWARE_COMMAND] = "HCI_RTK_DROP_FIRMWARE_COMMAND"
@HCI_Command.command()
class HCI_RTK_Drop_Firmware_Command(HCI_Command):
    pass
@@ -445,6 +439,11 @@ class Driver:
    # When the environment variable is set, don't look elsewhere
    return None
# Then, look where the firmware download tool writes by default
if (path := rtk_firmware_dir() / file_name).is_file():
logger.debug(f"{file_name} found in project data dir")
return path
# Then, look in the package's driver directory
if (path := pathlib.Path(__file__).parent / "rtk_fw" / file_name).is_file():
    logger.debug(f"{file_name} found in package dir")
@@ -645,3 +644,16 @@ class Driver:
    await self.download_firmware()
    await self.host.send_command(HCI_Reset_Command(), check_result=True)
    logger.info(f"loaded FW image {self.driver_info.fw_name}")
def rtk_firmware_dir() -> pathlib.Path:
"""
Returns:
A path to a subdir of the project data dir for Realtek firmware.
The directory is created if it doesn't exist.
"""
from bumble.drivers import project_data_dir
p = project_data_dir() / "firmware" / "realtek"
p.mkdir(parents=True, exist_ok=True)
return p
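The firmware lookup above is a first-match-wins search over an ordered list of directories (environment variable dir, project data dir, package dir). A minimal stand-alone sketch of that ordering, with a hypothetical helper name:

```python
import pathlib

def find_file(file_name, search_dirs):
    # Mirrors the driver's lookup order: return the first directory
    # in which the file exists; None if no directory has it.
    for directory in search_dirs:
        candidate = pathlib.Path(directory) / file_name
        if candidate.is_file():
            return candidate
    return None
```

Putting the user-writable data dir ahead of the package dir lets a downloaded firmware image shadow whatever ships with the package.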


@@ -16,11 +16,11 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import collections
import functools
import logging
import struct
from typing import Any, Dict, Callable, Optional, Type, Union

from .colors import color
from .core import (
@@ -47,6 +47,10 @@ def hci_command_op_code(ogf, ocf):
    return ogf << 10 | ocf
def hci_vendor_command_op_code(ocf):
return hci_command_op_code(HCI_VENDOR_OGF, ocf)
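The op code helpers above encode the HCI convention of a 6-bit Opcode Group Field (OGF) in the high bits and a 10-bit Opcode Command Field (OCF) in the low bits; vendor commands all share OGF 0x3F. A self-contained copy of the two helpers makes the encoding concrete:

```python
HCI_VENDOR_OGF = 0x3F  # vendor-specific OGF, as defined in the module above

def hci_command_op_code(ogf, ocf):
    # 16-bit op code: OGF in bits 10..15, OCF in bits 0..9
    return ogf << 10 | ocf

def hci_vendor_command_op_code(ocf):
    return hci_command_op_code(HCI_VENDOR_OGF, ocf)
```

So the Realtek ROM-version command (OCF 0x6D) gets op code 0xFC6D, which is why vendor op codes in logs typically start with 0xFC or 0xFD.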
def key_with_value(dictionary, target_value):
    for key, value in dictionary.items():
        if value == target_value:
@@ -101,6 +105,8 @@ def phy_list_to_bits(phys):
# fmt: off
# pylint: disable=line-too-long

HCI_VENDOR_OGF = 0x3F

# HCI Version
HCI_VERSION_BLUETOOTH_CORE_1_0B = 0
HCI_VERSION_BLUETOOTH_CORE_1_1 = 1
@@ -206,10 +212,8 @@ HCI_INQUIRY_RESPONSE_NOTIFICATION_EVENT = 0X56
HCI_AUTHENTICATED_PAYLOAD_TIMEOUT_EXPIRED_EVENT = 0X57
HCI_SAM_STATUS_CHANGE_EVENT = 0X58

HCI_VENDOR_EVENT = 0xFF

# HCI Subevent Codes
HCI_LE_CONNECTION_COMPLETE_EVENT = 0x01
@@ -248,10 +252,6 @@ HCI_LE_TRANSMIT_POWER_REPORTING_EVENT = 0X21
HCI_LE_BIGINFO_ADVERTISING_REPORT_EVENT = 0X22
HCI_LE_SUBRATE_CHANGE_EVENT = 0X23
HCI_SUBEVENT_NAMES = {
event_code: event_name for (event_name, event_code) in globals().items()
if event_name.startswith('HCI_LE_') and event_name.endswith('_EVENT') and event_code != HCI_LE_META_EVENT
}
# HCI Command
HCI_INQUIRY_COMMAND = hci_command_op_code(0x01, 0x0001)
@@ -557,10 +557,6 @@ HCI_LE_SET_DATA_RELATED_ADDRESS_CHANGES_COMMAND = hci_c
HCI_LE_SET_DEFAULT_SUBRATE_COMMAND = hci_command_op_code(0x08, 0x007D)
HCI_LE_SUBRATE_REQUEST_COMMAND = hci_command_op_code(0x08, 0x007E)
HCI_COMMAND_NAMES = {
command_code: command_name for (command_name, command_code) in globals().items()
if command_name.startswith('HCI_') and command_name.endswith('_COMMAND')
}
# HCI Error Codes
# See Bluetooth spec Vol 2, Part D - 1.3 LIST OF ERROR CODES
@@ -1445,8 +1441,14 @@ class HCI_Object:
@staticmethod
def init_from_fields(hci_object, fields, values):
    if isinstance(values, dict):
        for field in fields:
            if isinstance(field, list):
                # The field is an array, up-level the array field names
                for sub_field_name, _ in field:
                    setattr(hci_object, sub_field_name, values[sub_field_name])
            else:
                field_name = field[0]
                setattr(hci_object, field_name, values[field_name])
    else:
        for field_name, field_value in zip(fields, values):
            setattr(hci_object, field_name, field_value)
@@ -1456,133 +1458,161 @@ class HCI_Object:
    parsed = HCI_Object.dict_from_bytes(data, offset, fields)
    HCI_Object.init_from_fields(hci_object, parsed.keys(), parsed.values())
@staticmethod
def parse_field(data, offset, field_type):
# The field_type may be a dictionary with a mapper, parser, and/or size
if isinstance(field_type, dict):
if 'size' in field_type:
field_type = field_type['size']
elif 'parser' in field_type:
field_type = field_type['parser']
# Parse the field
if field_type == '*':
# The rest of the bytes
field_value = data[offset:]
return (field_value, len(field_value))
if field_type == 1:
# 8-bit unsigned
return (data[offset], 1)
if field_type == -1:
# 8-bit signed
return (struct.unpack_from('b', data, offset)[0], 1)
if field_type == 2:
# 16-bit unsigned
return (struct.unpack_from('<H', data, offset)[0], 2)
if field_type == '>2':
# 16-bit unsigned big-endian
return (struct.unpack_from('>H', data, offset)[0], 2)
if field_type == -2:
# 16-bit signed
return (struct.unpack_from('<h', data, offset)[0], 2)
if field_type == 3:
# 24-bit unsigned
padded = data[offset : offset + 3] + bytes([0])
return (struct.unpack('<I', padded)[0], 3)
if field_type == 4:
# 32-bit unsigned
return (struct.unpack_from('<I', data, offset)[0], 4)
if field_type == '>4':
# 32-bit unsigned big-endian
return (struct.unpack_from('>I', data, offset)[0], 4)
if isinstance(field_type, int) and 4 < field_type <= 256:
# Byte array (from 5 up to 256 bytes)
return (data[offset : offset + field_type], field_type)
if callable(field_type):
new_offset, field_value = field_type(data, offset)
return (field_value, new_offset - offset)
raise ValueError(f'unknown field type {field_type}')
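`parse_field` above returns a `(value, size)` pair so callers can advance their own offset. A trimmed-down stand-alone version covering just a few of the field types shows the contract (this is a sketch for illustration, not the full implementation):

```python
import struct

def parse_field(data, offset, field_type):
    # Returns (value, consumed_byte_count) for a subset of field types.
    if field_type == 1:
        # 8-bit unsigned
        return (data[offset], 1)
    if field_type == 2:
        # 16-bit unsigned, little-endian
        return (struct.unpack_from('<H', data, offset)[0], 2)
    if field_type == 3:
        # 24-bit unsigned: pad to 4 bytes, then unpack as a 32-bit value
        padded = data[offset:offset + 3] + bytes([0])
        return (struct.unpack('<I', padded)[0], 3)
    if field_type == '*':
        # The rest of the bytes
        rest = data[offset:]
        return (rest, len(rest))
    raise ValueError(f'unknown field type {field_type}')
```

The 24-bit case illustrates the padding trick: `struct` has no 3-byte integer format, so the value is zero-extended to 4 bytes before unpacking.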
@staticmethod
def dict_from_bytes(data, offset, fields):
    result = collections.OrderedDict()
    for field in fields:
        if isinstance(field, list):
            # This is an array field, starting with a 1-byte item count.
            item_count = data[offset]
            offset += 1
            for _ in range(item_count):
                for sub_field_name, sub_field_type in field:
                    value, size = HCI_Object.parse_field(
                        data, offset, sub_field_type
                    )
                    result.setdefault(sub_field_name, []).append(value)
                    offset += size
            continue

        field_name, field_type = field
        field_value, field_size = HCI_Object.parse_field(data, offset, field_type)
        result[field_name] = field_value
        offset += field_size

    return result
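The array branch above decodes a count-prefixed record list: one byte of item count, then `item_count` interleaved records of the sub-fields, with each sub-field collected into its own list. A self-contained sketch restricted to 1-byte sub-fields (a simplifying assumption; the real code dispatches through `parse_field`):

```python
def parse_count_prefixed_array(data, offset, sub_fields):
    # sub_fields: list of (name, size) pairs, all assumed 1 byte wide here.
    # Wire layout: [count][rec0.f0][rec0.f1]...[rec1.f0][rec1.f1]...
    result = {name: [] for name, _ in sub_fields}
    count = data[offset]
    offset += 1
    for _ in range(count):
        for name, size in sub_fields:
            result[name].append(data[offset])
            offset += size
    return result, offset
```

Note that the decoded shape is column-oriented (one list per sub-field name), which is why `init_from_fields` "up-levels" the array field names onto the object.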
@staticmethod
def serialize_field(field_value, field_type):
# The field_type may be a dictionary with a mapper, parser, serializer,
# and/or size
serializer = None
if isinstance(field_type, dict):
if 'serializer' in field_type:
serializer = field_type['serializer']
if 'size' in field_type:
field_type = field_type['size']
# Serialize the field
if serializer:
field_bytes = serializer(field_value)
elif field_type == 1:
# 8-bit unsigned
field_bytes = bytes([field_value])
elif field_type == -1:
# 8-bit signed
field_bytes = struct.pack('b', field_value)
elif field_type == 2:
# 16-bit unsigned
field_bytes = struct.pack('<H', field_value)
elif field_type == '>2':
# 16-bit unsigned big-endian
field_bytes = struct.pack('>H', field_value)
elif field_type == -2:
# 16-bit signed
field_bytes = struct.pack('<h', field_value)
elif field_type == 3:
# 24-bit unsigned
field_bytes = struct.pack('<I', field_value)[0:3]
elif field_type == 4:
# 32-bit unsigned
field_bytes = struct.pack('<I', field_value)
elif field_type == '>4':
# 32-bit unsigned big-endian
field_bytes = struct.pack('>I', field_value)
elif field_type == '*':
if isinstance(field_value, int):
if 0 <= field_value <= 255:
field_bytes = bytes([field_value])
else:
raise ValueError('value too large for *-typed field')
else:
field_bytes = bytes(field_value)
elif isinstance(field_value, (bytes, bytearray)) or hasattr(
field_value, 'to_bytes'
):
field_bytes = bytes(field_value)
if isinstance(field_type, int) and 4 < field_type <= 256:
# Truncate or pad with zeros if the field is too long or too short
if len(field_bytes) < field_type:
field_bytes += bytes(field_type - len(field_bytes))
elif len(field_bytes) > field_type:
field_bytes = field_bytes[:field_type]
else:
raise ValueError(f"don't know how to serialize type {type(field_value)}")
return field_bytes
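`serialize_field` is the inverse of `parse_field`: one field value in, its wire bytes out. A trimmed-down stand-alone version for a few sizes (a sketch, mirroring the branches above):

```python
import struct

def serialize_field(field_value, field_type):
    if field_type == 1:
        # 8-bit unsigned
        return bytes([field_value])
    if field_type == 2:
        # 16-bit unsigned, little-endian
        return struct.pack('<H', field_value)
    if field_type == 3:
        # 24-bit unsigned: pack as 32-bit, keep the low 3 bytes
        return struct.pack('<I', field_value)[0:3]
    raise ValueError(f'unknown field type {field_type}')
```

The 24-bit branch is again the padding trick in reverse: pack 4 little-endian bytes and truncate the high byte.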
@staticmethod
def dict_to_bytes(hci_object, fields):
    result = bytearray()
    for field in fields:
        if isinstance(field, list):
            # The field is an array. The serialized form starts with a 1-byte
            # item count. We use the length of the first array field as the
            # array count, since all array fields have the same number of items.
            item_count = len(hci_object[field[0][0]])
            result += bytes([item_count]) + b''.join(
                b''.join(
                    HCI_Object.serialize_field(
                        hci_object[sub_field_name][i], sub_field_type
                    )
                    for sub_field_name, sub_field_type in field
                )
                for i in range(item_count)
            )
            continue

        (field_name, field_type) = field
        result += HCI_Object.serialize_field(hci_object[field_name], field_type)

    return bytes(result)
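The array-serialization branch above interleaves the column-oriented sub-field lists back into per-item records behind the 1-byte count. A minimal sketch with 1-byte sub-fields (same simplifying assumption as the parsing sketch) makes the round trip with the count-prefixed layout visible:

```python
def serialize_count_prefixed_array(values_by_field, sub_fields):
    # values_by_field: {name: [v0, v1, ...]}; all lists have the same length.
    # Emits: [count][rec0.f0][rec0.f1]...[rec1.f0][rec1.f1]...
    item_count = len(values_by_field[sub_fields[0][0]])
    out = bytearray([item_count])
    for i in range(item_count):
        for name, size in sub_fields:
            out.append(values_by_field[name][i])
    return bytes(out)
```

Using the first sub-field's list length as the count is safe precisely because every sub-field list must have the same number of items, as the comment in the real code notes.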
@@ -1617,48 +1647,73 @@ class HCI_Object:
    return str(value)

@staticmethod
def stringify_field(
    field_name, field_type, field_value, indentation, value_mappers
):
    value_mapper = None
    if isinstance(field_type, dict):
        # Get the value mapper from the specifier
        value_mapper = field_type.get('mapper')

    # Check if there's a matching mapper passed
    if value_mappers:
        value_mapper = value_mappers.get(field_name, value_mapper)

    # Map the value if we have a mapper
    if value_mapper is not None:
        field_value = value_mapper(field_value)

    # Get the string representation of the value
    return HCI_Object.format_field_value(
        field_value, indentation=indentation + '  '
    )

@staticmethod
def format_fields(hci_object, fields, indentation='', value_mappers=None):
    if not fields:
        return ''

    # Build array of formatted key:value pairs
    field_strings = []
    for field in fields:
        if isinstance(field, list):
            for sub_field in field:
                sub_field_name, sub_field_type = sub_field
                item_count = len(hci_object[sub_field_name])
                for i in range(item_count):
                    field_strings.append(
                        (
                            f'{sub_field_name}[{i}]',
                            HCI_Object.stringify_field(
                                sub_field_name,
                                sub_field_type,
                                hci_object[sub_field_name][i],
                                indentation,
                                value_mappers,
                            ),
                        ),
                    )
            continue

        field_name, field_type = field
        field_value = hci_object[field_name]
        field_strings.append(
            (
                field_name,
                HCI_Object.stringify_field(
                    field_name, field_type, field_value, indentation, value_mappers
                ),
            ),
        )

    # Measure the widest field name
    max_field_name_length = max(len(s[0]) for s in field_strings)
    sep = ':'
    return '\n'.join(
        f'{indentation}'
        f'{color(f"{field_name + sep:{1 + max_field_name_length}}", "cyan")} {field_value}'
        for field_name, field_value in field_strings
    )
def __bytes__(self):
    return self.to_bytes()
@@ -1859,7 +1914,7 @@ class HCI_Packet:
hci_packet_type: int

@staticmethod
def from_bytes(packet: bytes) -> HCI_Packet:
    packet_type = packet[0]

    if packet_type == HCI_COMMAND_PACKET:
@@ -1901,6 +1956,7 @@ class HCI_Command(HCI_Packet):
'''
hci_packet_type = HCI_COMMAND_PACKET

command_names: Dict[int, str] = {}
command_classes: Dict[int, Type[HCI_Command]] = {}

@staticmethod
@@ -1911,9 +1967,9 @@ class HCI_Command(HCI_Packet):
def inner(cls):
    cls.name = cls.__name__.upper()
    cls.op_code = key_with_value(cls.command_names, cls.name)
    if cls.op_code is None:
        raise KeyError(f'command {cls.name} not found in command_names')
    cls.fields = fields
    cls.return_parameters_fields = return_parameters_fields
@@ -1933,7 +1989,19 @@ class HCI_Command(HCI_Packet):
    return inner

@staticmethod
def command_map(symbols: Dict[str, Any]) -> Dict[int, str]:
    return {
        command_code: command_name
        for (command_name, command_code) in symbols.items()
        if command_name.startswith('HCI_') and command_name.endswith('_COMMAND')
    }

@classmethod
def register_commands(cls, symbols: Dict[str, Any]) -> None:
    cls.command_names.update(cls.command_map(symbols))
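`command_map` replaces the old module-level `HCI_COMMAND_NAMES` dict: any module (including vendor modules like the Realtek driver) can call `HCI_Command.register_commands(globals())` to register its `HCI_*_COMMAND` constants. The filter itself is easy to demonstrate in isolation:

```python
def command_map(symbols):
    # Same naming filter as HCI_Command.command_map above:
    # keep only HCI_*_COMMAND constants, inverted to {code: name}.
    return {
        code: name
        for (name, code) in symbols.items()
        if name.startswith('HCI_') and name.endswith('_COMMAND')
    }
```

Passing `globals()` means every module opts its own constants in, so vendor op codes no longer need to mutate a dict owned by the core `hci` module.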
@staticmethod
def from_bytes(packet: bytes) -> HCI_Command:
    op_code, length = struct.unpack_from('<HB', packet, 1)
    parameters = packet[4:]
    if len(parameters) != length:
@@ -1952,11 +2020,11 @@ class HCI_Command(HCI_Packet):
        HCI_Object.init_from_bytes(self, parameters, 0, fields)
        return self

    return cls.from_parameters(parameters)  # type: ignore

@staticmethod
def command_name(op_code):
    name = HCI_Command.command_names.get(op_code)
    if name is not None:
        return name
    return f'[OGF=0x{op_code >> 10:02x}, OCF=0x{op_code & 0x3FF:04x}]'
@@ -1965,6 +2033,16 @@ class HCI_Command(HCI_Packet):
     def create_return_parameters(cls, **kwargs):
         return HCI_Object(cls.return_parameters_fields, **kwargs)

+    @classmethod
+    def parse_return_parameters(cls, parameters):
+        if not cls.return_parameters_fields:
+            return None
+        return_parameters = HCI_Object.from_bytes(
+            parameters, 0, cls.return_parameters_fields
+        )
+        return_parameters.fields = cls.return_parameters_fields
+        return return_parameters
+
     def __init__(self, op_code, parameters=None, **kwargs):
         super().__init__(HCI_Command.command_name(op_code))
         if (fields := getattr(self, 'fields', None)) and kwargs:
@@ -1994,6 +2072,9 @@ class HCI_Command(HCI_Packet):
         return result


+HCI_Command.register_commands(globals())
+
+
 # -----------------------------------------------------------------------------
 @HCI_Command.command(
     [
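The `register_commands(globals())` pattern added in this hunk can be sketched standalone: module-level constants following the `HCI_*_COMMAND` naming convention are harvested into a code-to-name map. The two opcode constants below use the real spec values, but the map and helper here are illustrative stand-ins for `HCI_Command.command_names`, not the library's actual implementation.

```python
# Standalone sketch of the HCI_Command.register_commands() pattern:
# module-level constants named HCI_*_COMMAND are harvested into a
# code -> name map keyed by opcode.
HCI_RESET_COMMAND = 0x0C03
HCI_READ_LOCAL_NAME_COMMAND = 0x0C14

def command_map(symbols):
    # Keep only integer symbols that follow the HCI_*_COMMAND convention
    return {
        code: name
        for (name, code) in symbols.items()
        if name.startswith('HCI_')
        and name.endswith('_COMMAND')
        and isinstance(code, int)
    }

command_names = {}
command_names.update(command_map(globals()))
print(command_names[0x0C03])  # HCI_RESET_COMMAND
```

Keeping the map on the class (rather than a module global) is what lets vendor modules later add their own opcodes without patching this file.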
@@ -3769,9 +3850,7 @@ class HCI_LE_Set_Extended_Advertising_Parameters_Command(HCI_Command):
             'advertising_data',
             {
                 'parser': HCI_Object.parse_length_prefixed_bytes,
-                'serializer': functools.partial(
-                    HCI_Object.serialize_length_prefixed_bytes
-                ),
+                'serializer': HCI_Object.serialize_length_prefixed_bytes,
             },
         ),
     ]
@@ -3819,9 +3898,7 @@ class HCI_LE_Set_Extended_Advertising_Data_Command(HCI_Command):
             'scan_response_data',
             {
                 'parser': HCI_Object.parse_length_prefixed_bytes,
-                'serializer': functools.partial(
-                    HCI_Object.serialize_length_prefixed_bytes
-                ),
+                'serializer': HCI_Object.serialize_length_prefixed_bytes,
             },
         ),
     ]
@@ -3849,73 +3926,21 @@ class HCI_LE_Set_Extended_Scan_Response_Data_Command(HCI_Command):
 # -----------------------------------------------------------------------------
-@HCI_Command.command(fields=None)
+@HCI_Command.command(
+    [
+        ('enable', 1),
+        [
+            ('advertising_handles', 1),
+            ('durations', 2),
+            ('max_extended_advertising_events', 1),
+        ],
+    ]
+)
 class HCI_LE_Set_Extended_Advertising_Enable_Command(HCI_Command):
     '''
     See Bluetooth spec @ 7.8.56 LE Set Extended Advertising Enable Command
     '''

-    @classmethod
-    def from_parameters(cls, parameters):
-        enable = parameters[0]
-        num_sets = parameters[1]
-        advertising_handles = []
-        durations = []
-        max_extended_advertising_events = []
-        offset = 2
-        for _ in range(num_sets):
-            advertising_handles.append(parameters[offset])
-            durations.append(struct.unpack_from('<H', parameters, offset + 1)[0])
-            max_extended_advertising_events.append(parameters[offset + 3])
-            offset += 4
-
-        return cls(
-            enable, advertising_handles, durations, max_extended_advertising_events
-        )
-
-    def __init__(
-        self, enable, advertising_handles, durations, max_extended_advertising_events
-    ):
-        super().__init__(HCI_LE_SET_EXTENDED_ADVERTISING_ENABLE_COMMAND)
-        self.enable = enable
-        self.advertising_handles = advertising_handles
-        self.durations = durations
-        self.max_extended_advertising_events = max_extended_advertising_events
-        self.parameters = bytes([enable, len(advertising_handles)]) + b''.join(
-            [
-                struct.pack(
-                    '<BHB',
-                    advertising_handles[i],
-                    durations[i],
-                    max_extended_advertising_events[i],
-                )
-                for i in range(len(advertising_handles))
-            ]
-        )
-
-    def __str__(self):
-        fields = [('enable:', self.enable)]
-        for i, advertising_handle in enumerate(self.advertising_handles):
-            fields.append(
-                (f'advertising_handle[{i}]: ', advertising_handle)
-            )
-            fields.append((f'duration[{i}]: ', self.durations[i]))
-            fields.append(
-                (
-                    f'max_extended_advertising_events[{i}]:',
-                    self.max_extended_advertising_events[i],
-                )
-            )
-
-        return (
-            color(self.name, 'green')
-            + ':\n'
-            + '\n'.join(
-                [color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields]
-            )
-        )


 # -----------------------------------------------------------------------------
 @HCI_Command.command(
@@ -4066,7 +4091,10 @@ class HCI_LE_Set_Extended_Scan_Parameters_Command(HCI_Command):
             color(self.name, 'green')
             + ':\n'
             + '\n'.join(
-                [color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields]
+                [
+                    color(' ' + field[0], 'cyan') + ' ' + str(field[1])
+                    for field in fields
+                ]
             )
         )
@@ -4242,7 +4270,10 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command):
             color(self.name, 'green')
             + ':\n'
             + '\n'.join(
-                [color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields]
+                [
+                    color(' ' + field[0], 'cyan') + ' ' + str(field[1])
+                    for field in fields
+                ]
             )
         )
@@ -4299,8 +4330,8 @@ class HCI_Event(HCI_Packet):
     '''

     hci_packet_type = HCI_EVENT_PACKET
+    event_names: Dict[int, str] = {}
     event_classes: Dict[int, Type[HCI_Event]] = {}
-    meta_event_classes: Dict[int, Type[HCI_LE_Meta_Event]] = {}

     @staticmethod
     def event(fields=()):
@@ -4310,9 +4341,9 @@
         def inner(cls):
             cls.name = cls.__name__.upper()
-            cls.event_code = key_with_value(HCI_EVENT_NAMES, cls.name)
+            cls.event_code = key_with_value(cls.event_names, cls.name)
             if cls.event_code is None:
-                raise KeyError('event not found in HCI_EVENT_NAMES')
+                raise KeyError(f'event {cls.name} not found in event_names')
             cls.fields = fields

             # Patch the __init__ method to fix the event_code
@@ -4328,12 +4359,30 @@
         return inner

+    @staticmethod
+    def event_map(symbols: Dict[str, Any]) -> Dict[int, str]:
+        return {
+            event_code: event_name
+            for (event_name, event_code) in symbols.items()
+            if event_name.startswith('HCI_')
+            and not event_name.startswith('HCI_LE_')
+            and event_name.endswith('_EVENT')
+        }
+
+    @staticmethod
+    def event_name(event_code):
+        return name_or_number(HCI_Event.event_names, event_code)
+
+    @staticmethod
+    def register_events(symbols: Dict[str, Any]) -> None:
+        HCI_Event.event_names.update(HCI_Event.event_map(symbols))
+
     @staticmethod
     def registered(event_class):
         event_class.name = event_class.__name__.upper()
-        event_class.event_code = key_with_value(HCI_EVENT_NAMES, event_class.name)
+        event_class.event_code = key_with_value(HCI_Event.event_names, event_class.name)
         if event_class.event_code is None:
-            raise KeyError('event not found in HCI_EVENT_NAMES')
+            raise KeyError(f'event {event_class.name} not found in event_names')

         # Register a factory for this class
         HCI_Event.event_classes[event_class.event_code] = event_class
@@ -4341,22 +4390,28 @@
         return event_class

     @staticmethod
-    def from_bytes(packet):
+    def from_bytes(packet: bytes) -> HCI_Event:
         event_code = packet[1]
         length = packet[2]
         parameters = packet[3:]
         if len(parameters) != length:
             raise ValueError('invalid packet length')

+        cls: Type[HCI_Event | HCI_LE_Meta_Event] | None
         if event_code == HCI_LE_META_EVENT:
             # We do this dispatch here and not in the subclass in order to avoid call
             # loops
             subevent_code = parameters[0]
-            cls = HCI_Event.meta_event_classes.get(subevent_code)
+            cls = HCI_LE_Meta_Event.subevent_classes.get(subevent_code)
             if cls is None:
                 # No class registered, just use a generic class instance
                 return HCI_LE_Meta_Event(subevent_code, parameters)
+        elif event_code == HCI_VENDOR_EVENT:
+            subevent_code = parameters[0]
+            cls = HCI_Vendor_Event.subevent_classes.get(subevent_code)
+            if cls is None:
+                # No class registered, just use a generic class instance
+                return HCI_Vendor_Event(subevent_code, parameters)
         else:
             cls = HCI_Event.event_classes.get(event_code)
             if cls is None:
@@ -4364,7 +4419,7 @@
                 return HCI_Event(event_code, parameters)

         # Invoke the factory to create a new instance
-        return cls.from_parameters(parameters)
+        return cls.from_parameters(parameters)  # type: ignore

     @classmethod
     def from_parameters(cls, parameters):
@@ -4374,10 +4429,6 @@
             HCI_Object.init_from_bytes(self, parameters, 0, fields)
             return self

-    @staticmethod
-    def event_name(event_code):
-        return name_or_number(HCI_EVENT_NAMES, event_code)
-
     def __init__(self, event_code, parameters=None, **kwargs):
         super().__init__(HCI_Event.event_name(event_code))
         if (fields := getattr(self, 'fields', None)) and kwargs:
@@ -4404,71 +4455,111 @@
         return result


+HCI_Event.register_events(globals())
+
+
 # -----------------------------------------------------------------------------
-class HCI_LE_Meta_Event(HCI_Event):
+class HCI_Extended_Event(HCI_Event):
     '''
-    See Bluetooth spec @ 7.7.65 LE Meta Event
+    HCI_Event subclass for events that have a subevent code.
     '''

-    @staticmethod
-    def event(fields=()):
+    subevent_names: Dict[int, str] = {}
+    subevent_classes: Dict[int, Type[HCI_Extended_Event]]
+
+    @classmethod
+    def event(cls, fields=()):
         '''
         Decorator used to declare and register subclasses
         '''

         def inner(cls):
             cls.name = cls.__name__.upper()
-            cls.subevent_code = key_with_value(HCI_SUBEVENT_NAMES, cls.name)
+            cls.subevent_code = key_with_value(cls.subevent_names, cls.name)
             if cls.subevent_code is None:
-                raise KeyError('subevent not found in HCI_SUBEVENT_NAMES')
+                raise KeyError(f'subevent {cls.name} not found in subevent_names')
             cls.fields = fields

             # Patch the __init__ method to fix the subevent_code
+            original_init = cls.__init__
+
             def init(self, parameters=None, **kwargs):
-                return HCI_LE_Meta_Event.__init__(
-                    self, cls.subevent_code, parameters, **kwargs
-                )
+                return original_init(self, cls.subevent_code, parameters, **kwargs)

             cls.__init__ = init

             # Register a factory for this class
-            HCI_Event.meta_event_classes[cls.subevent_code] = cls
+            cls.subevent_classes[cls.subevent_code] = cls

             return cls

         return inner

+    @classmethod
+    def subevent_name(cls, subevent_code):
+        subevent_name = cls.subevent_names.get(subevent_code)
+        if subevent_name is not None:
+            return subevent_name
+        return f'{cls.__name__.upper()}[0x{subevent_code:02X}]'
+
+    @staticmethod
+    def subevent_map(symbols: Dict[str, Any]) -> Dict[int, str]:
+        return {
+            subevent_code: subevent_name
+            for (subevent_name, subevent_code) in symbols.items()
+            if subevent_name.startswith('HCI_') and subevent_name.endswith('_EVENT')
+        }
+
+    @classmethod
+    def register_subevents(cls, symbols: Dict[str, Any]) -> None:
+        cls.subevent_names.update(cls.subevent_map(symbols))
+
     @classmethod
     def from_parameters(cls, parameters):
         self = cls.__new__(cls)
-        HCI_LE_Meta_Event.__init__(self, self.subevent_code, parameters)
+        HCI_Extended_Event.__init__(self, self.subevent_code, parameters)
         if fields := getattr(self, 'fields', None):
             HCI_Object.init_from_bytes(self, parameters, 1, fields)
         return self

-    @staticmethod
-    def subevent_name(subevent_code):
-        return name_or_number(HCI_SUBEVENT_NAMES, subevent_code)
-
     def __init__(self, subevent_code, parameters, **kwargs):
         self.subevent_code = subevent_code
         if parameters is None and (fields := getattr(self, 'fields', None)) and kwargs:
             parameters = bytes([subevent_code]) + HCI_Object.dict_to_bytes(
                 kwargs, fields
             )
-        super().__init__(HCI_LE_META_EVENT, parameters, **kwargs)
+        super().__init__(self.event_code, parameters, **kwargs)

         # Override the name in order to adopt the subevent name instead
         self.name = self.subevent_name(subevent_code)

-    def __str__(self):
-        result = color(self.subevent_name(self.subevent_code), 'magenta')
-        if fields := getattr(self, 'fields', None):
-            result += ':\n' + HCI_Object.format_fields(self.__dict__, fields, '  ')
-        else:
-            if self.parameters:
-                result += f': {self.parameters.hex()}'
-        return result
+
+# -----------------------------------------------------------------------------
+class HCI_LE_Meta_Event(HCI_Extended_Event):
+    '''
+    See Bluetooth spec @ 7.7.65 LE Meta Event
+    '''
+
+    event_code: int = HCI_LE_META_EVENT
+    subevent_classes = {}
+
+    @staticmethod
+    def subevent_map(symbols: Dict[str, Any]) -> Dict[int, str]:
+        return {
+            subevent_code: subevent_name
+            for (subevent_name, subevent_code) in symbols.items()
+            if subevent_name.startswith('HCI_LE_') and subevent_name.endswith('_EVENT')
+        }
+
+
+HCI_LE_Meta_Event.register_subevents(globals())
+
+
+# -----------------------------------------------------------------------------
+class HCI_Vendor_Event(HCI_Extended_Event):
+    event_code: int = HCI_VENDOR_EVENT
+    subevent_classes = {}


 # -----------------------------------------------------------------------------
@@ -4582,7 +4673,7 @@ class HCI_LE_Advertising_Report_Event(HCI_LE_Meta_Event):
         return f'{color(self.subevent_name(self.subevent_code), "magenta")}:\n{reports}'


-HCI_Event.meta_event_classes[
+HCI_LE_Meta_Event.subevent_classes[
     HCI_LE_ADVERTISING_REPORT_EVENT
 ] = HCI_LE_Advertising_Report_Event
@@ -4836,7 +4927,7 @@ class HCI_LE_Extended_Advertising_Report_Event(HCI_LE_Meta_Event):
         return f'{color(self.subevent_name(self.subevent_code), "magenta")}:\n{reports}'


-HCI_Event.meta_event_classes[
+HCI_LE_Meta_Event.subevent_classes[
     HCI_LE_EXTENDED_ADVERTISING_REPORT_EVENT
 ] = HCI_LE_Extended_Advertising_Report_Event
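The dispatch introduced by these hunks is two-level: `HCI_Event.from_bytes` keys regular events by event code, while LE meta events (and now vendor events) are keyed by the subevent code carried in the first parameter byte. A toy sketch, using the real spec values 0x3E (LE Meta Event) and 0x0E (Command Complete) but stand-in registries instead of the library's class maps:

```python
# Illustrative sketch of the two-level event dispatch in HCI_Event.from_bytes.
HCI_LE_META_EVENT = 0x3E

# Toy registries standing in for event_classes / subevent_classes
event_classes = {0x0E: 'HCI_Command_Complete_Event'}
subevent_classes = {0x01: 'HCI_LE_Connection_Complete_Event'}

def dispatch(event_code, parameters):
    if event_code == HCI_LE_META_EVENT:
        # Second level: the first parameter byte selects the subevent class,
        # falling back to a generic meta event when nothing is registered
        return subevent_classes.get(parameters[0], 'HCI_LE_Meta_Event')
    return event_classes.get(event_code, 'HCI_Event')

print(dispatch(0x3E, bytes([0x01])))  # HCI_LE_Connection_Complete_Event
```

Doing the subevent lookup in the base class (rather than in `HCI_LE_Meta_Event` itself) is, per the comment in the hunk, what avoids call loops between the two classes.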
@@ -5077,6 +5168,7 @@ class HCI_Command_Complete_Event(HCI_Event):
     '''

     return_parameters = b''
+    command_opcode: int

     def map_return_parameters(self, return_parameters):
         '''Map simple 'status' return parameters to their named constant form'''
@@ -5109,11 +5201,11 @@
             self.return_parameters = self.return_parameters[0]
         else:
             cls = HCI_Command.command_classes.get(self.command_opcode)
-            if cls and cls.return_parameters_fields:
-                self.return_parameters = HCI_Object.from_bytes(
-                    self.return_parameters, 0, cls.return_parameters_fields
-                )
-                self.return_parameters.fields = cls.return_parameters_fields
+            if cls:
+                # Try to parse the return parameters bytes into an object.
+                return_parameters = cls.parse_return_parameters(self.return_parameters)
+                if return_parameters is not None:
+                    self.return_parameters = return_parameters

         return self
@@ -5205,7 +5297,7 @@ class HCI_Number_Of_Completed_Packets_Event(HCI_Event):
     def __str__(self):
         lines = [
             color(self.name, 'magenta') + ':',
             color(' number_of_handles: ', 'cyan')
             + f'{len(self.connection_handles)}',
         ]
         for i, connection_handle in enumerate(self.connection_handles):
@@ -5596,7 +5688,7 @@ class HCI_Remote_Host_Supported_Features_Notification_Event(HCI_Event):

 # -----------------------------------------------------------------------------
-class HCI_AclDataPacket:
+class HCI_AclDataPacket(HCI_Packet):
     '''
     See Bluetooth spec @ 5.4.2 HCI ACL Data Packets
     '''
@@ -5604,7 +5696,7 @@
     hci_packet_type = HCI_ACL_DATA_PACKET

     @staticmethod
-    def from_bytes(packet):
+    def from_bytes(packet: bytes) -> HCI_AclDataPacket:
         # Read the header
         h, data_total_length = struct.unpack_from('<HH', packet, 1)
         connection_handle = h & 0xFFF
@@ -5646,12 +5738,14 @@
 # -----------------------------------------------------------------------------
 class HCI_AclDataPacketAssembler:
-    def __init__(self, callback):
+    current_data: Optional[bytes]
+
+    def __init__(self, callback: Callable[[bytes], Any]) -> None:
         self.callback = callback
         self.current_data = None
         self.l2cap_pdu_length = 0

-    def feed_packet(self, packet):
+    def feed_packet(self, packet: HCI_AclDataPacket) -> None:
         if packet.pb_flag in (
             HCI_ACL_PB_FIRST_NON_FLUSHABLE,
             HCI_ACL_PB_FIRST_FLUSHABLE,
@@ -5665,6 +5759,7 @@ class HCI_AclDataPacketAssembler:
             return

+        assert self.current_data is not None
         self.current_data += packet.data
         if len(self.current_data) == self.l2cap_pdu_length + 4:
             # The packet is complete, invoke the callback
             logger.debug(f'<<< ACL PDU: {self.current_data.hex()}')

@@ -1,4 +1,4 @@
-# Copyright 2021-2022 Google LLC
+# Copyright 2023 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,19 +15,51 @@
 # -----------------------------------------------------------------------------
 # Imports
 # -----------------------------------------------------------------------------
+import collections.abc
 import logging
 import asyncio
-import collections
-from typing import Union
+import dataclasses
+import enum
+import traceback
+import warnings
+from typing import Dict, List, Union, Set, TYPE_CHECKING

+from . import at
 from . import rfcomm
-from .colors import color
+from bumble.colors import color
+from bumble.core import (
+    ProtocolError,
+    BT_GENERIC_AUDIO_SERVICE,
+    BT_HANDSFREE_SERVICE,
+    BT_L2CAP_PROTOCOL_ID,
+    BT_RFCOMM_PROTOCOL_ID,
+)
+from bumble.sdp import (
+    DataElement,
+    ServiceAttribute,
+    SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
+    SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
+    SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
+    SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
+    SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
+)

 # -----------------------------------------------------------------------------
 # Logging
 # -----------------------------------------------------------------------------
 logger = logging.getLogger(__name__)

+
+# -----------------------------------------------------------------------------
+# Error
+# -----------------------------------------------------------------------------
+class HfpProtocolError(ProtocolError):
+    def __init__(self, error_name: str = '', details: str = ''):
+        super().__init__(None, 'hfp', error_name, details)
+
+
 # -----------------------------------------------------------------------------
 # Protocol Support
@@ -41,6 +73,7 @@ class HfpProtocol:
     lines_available: asyncio.Event

     def __init__(self, dlc: rfcomm.DLC) -> None:
+        warnings.warn("See HfProtocol", DeprecationWarning)
         self.dlc = dlc
         self.buffer = ''
         self.lines = collections.deque()
@@ -83,19 +116,706 @@ class HfpProtocol:
         logger.debug(color(f'<<< {line}', 'green'))
         return line

-    async def initialize_service(self) -> None:
-        # Perform Service Level Connection Initialization
-        self.send_command_line('AT+BRSF=2072')  # Retrieve Supported Features
-        await (self.next_line())
-        await (self.next_line())
-
-        self.send_command_line('AT+CIND=?')
-        await (self.next_line())
-        await (self.next_line())
-
-        self.send_command_line('AT+CIND?')
-        await (self.next_line())
-        await (self.next_line())
-
-        self.send_command_line('AT+CMER=3,0,0,1')
-        await (self.next_line())
+
+# -----------------------------------------------------------------------------
+# Normative protocol definitions
+# -----------------------------------------------------------------------------
+
+# HF supported features (AT+BRSF=) (normative).
+# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
+# and 3GPP 27.007
class HfFeature(enum.IntFlag):
EC_NR = 0x001 # Echo Cancel & Noise reduction
THREE_WAY_CALLING = 0x002
CLI_PRESENTATION_CAPABILITY = 0x004
VOICE_RECOGNITION_ACTIVATION = 0x008
REMOTE_VOLUME_CONTROL = 0x010
ENHANCED_CALL_STATUS = 0x020
ENHANCED_CALL_CONTROL = 0x040
CODEC_NEGOTIATION = 0x080
HF_INDICATORS = 0x100
ESCO_S4_SETTINGS_SUPPORTED = 0x200
ENHANCED_VOICE_RECOGNITION_STATUS = 0x400
VOICE_RECOGNITION_TEST = 0x800
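Because `HfFeature` is an `enum.IntFlag`, individual feature bits combine with `|` into the single integer the HF later sends as `AT+BRSF=<value>` (the `HfProtocol` constructor sums them). A condensed copy of a few of the flags above, to show the arithmetic:

```python
import enum

# Condensed subset of the HfFeature flags defined above (values from HFP v1.8)
class HfFeature(enum.IntFlag):
    EC_NR = 0x001
    THREE_WAY_CALLING = 0x002
    CODEC_NEGOTIATION = 0x080
    HF_INDICATORS = 0x100

# Bits OR together into the value carried by AT+BRSF=
features = HfFeature.CODEC_NEGOTIATION | HfFeature.HF_INDICATORS
print(f"AT+BRSF={int(features)}")  # AT+BRSF=384

# Membership tests work bitwise, which is what supports_hf_feature() relies on
assert HfFeature.CODEC_NEGOTIATION in features
```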
# AG supported features (+BRSF:) (normative).
# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
# and 3GPP 27.007
class AgFeature(enum.IntFlag):
THREE_WAY_CALLING = 0x001
EC_NR = 0x002 # Echo Cancel & Noise reduction
VOICE_RECOGNITION_FUNCTION = 0x004
IN_BAND_RING_TONE_CAPABILITY = 0x008
VOICE_TAG = 0x010 # Attach a number to voice tag
REJECT_CALL = 0x020 # Ability to reject a call
ENHANCED_CALL_STATUS = 0x040
ENHANCED_CALL_CONTROL = 0x080
EXTENDED_ERROR_RESULT_CODES = 0x100
CODEC_NEGOTIATION = 0x200
HF_INDICATORS = 0x400
ESCO_S4_SETTINGS_SUPPORTED = 0x800
ENHANCED_VOICE_RECOGNITION_STATUS = 0x1000
VOICE_RECOGNITION_TEST = 0x2000
# Audio Codec IDs (normative).
# Hands-Free Profile v1.8, 10 Appendix B
class AudioCodec(enum.IntEnum):
CVSD = 0x01 # Support for CVSD audio codec
MSBC = 0x02 # Support for mSBC audio codec
# HF Indicators (normative).
# Bluetooth Assigned Numbers, 6.10.1 HF Indicators
class HfIndicator(enum.IntEnum):
ENHANCED_SAFETY = 0x01 # Enhanced safety feature
BATTERY_LEVEL = 0x02 # Battery level feature
# Call Hold supported operations (normative).
# AT Commands Reference Guide, 3.5.2.3.12 +CHLD - Call Holding Services
class CallHoldOperation(enum.IntEnum):
RELEASE_ALL_HELD_CALLS = 0 # Release all held calls
RELEASE_ALL_ACTIVE_CALLS = 1 # Release all active calls, accept other
HOLD_ALL_ACTIVE_CALLS = 2 # Place all active calls on hold, accept other
ADD_HELD_CALL = 3 # Adds a held call to conversation
# Response Hold status (normative).
# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
# and 3GPP 27.007
class ResponseHoldStatus(enum.IntEnum):
INC_CALL_HELD = 0 # Put incoming call on hold
HELD_CALL_ACC = 1 # Accept a held incoming call
HELD_CALL_REJ = 2 # Reject a held incoming call
# Values for the Call Setup AG indicator (normative).
# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
# and 3GPP 27.007
class CallSetupAgIndicator(enum.IntEnum):
NOT_IN_CALL_SETUP = 0
INCOMING_CALL_PROCESS = 1
OUTGOING_CALL_SETUP = 2
REMOTE_ALERTED = 3 # Remote party alerted in an outgoing call
# Values for the Call Held AG indicator (normative).
# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
# and 3GPP 27.007
class CallHeldAgIndicator(enum.IntEnum):
NO_CALLS_HELD = 0
# Call is placed on hold or active/held calls swapped
# (The AG has both an active AND a held call)
CALL_ON_HOLD_AND_ACTIVE_CALL = 1
CALL_ON_HOLD_NO_ACTIVE_CALL = 2 # Call on hold, no active call
# Call Info direction (normative).
# AT Commands Reference Guide, 3.5.2.3.15 +CLCC - List Current Calls
class CallInfoDirection(enum.IntEnum):
MOBILE_ORIGINATED_CALL = 0
MOBILE_TERMINATED_CALL = 1
# Call Info status (normative).
# AT Commands Reference Guide, 3.5.2.3.15 +CLCC - List Current Calls
class CallInfoStatus(enum.IntEnum):
ACTIVE = 0
HELD = 1
DIALING = 2
ALERTING = 3
INCOMING = 4
WAITING = 5
# Call Info mode (normative).
# AT Commands Reference Guide, 3.5.2.3.15 +CLCC - List Current Calls
class CallInfoMode(enum.IntEnum):
VOICE = 0
DATA = 1
FAX = 2
UNKNOWN = 9
# -----------------------------------------------------------------------------
# Hands-Free Control Interoperability Requirements
# -----------------------------------------------------------------------------
# Response codes.
RESPONSE_CODES = [
"+APLSIRI",
"+BAC",
"+BCC",
"+BCS",
"+BIA",
"+BIEV",
"+BIND",
"+BINP",
"+BLDN",
"+BRSF",
"+BTRH",
"+BVRA",
"+CCWA",
"+CHLD",
"+CHUP",
"+CIND",
"+CLCC",
"+CLIP",
"+CMEE",
"+CMER",
"+CNUM",
"+COPS",
"+IPHONEACCEV",
"+NREC",
"+VGM",
"+VGS",
"+VTS",
"+XAPL",
"A",
"D",
]
# Unsolicited responses and statuses.
UNSOLICITED_CODES = [
"+APLSIRI",
"+BCS",
"+BIND",
"+BSIR",
"+BTRH",
"+BVRA",
"+CCWA",
"+CIEV",
"+CLIP",
"+VGM",
"+VGS",
"BLACKLISTED",
"BUSY",
"DELAYED",
"NO ANSWER",
"NO CARRIER",
"RING",
]
# Status codes
STATUS_CODES = [
"+CME ERROR",
"BLACKLISTED",
"BUSY",
"DELAYED",
"ERROR",
"NO ANSWER",
"NO CARRIER",
"OK",
]
@dataclasses.dataclass
class Configuration:
supported_hf_features: List[HfFeature]
supported_hf_indicators: List[HfIndicator]
supported_audio_codecs: List[AudioCodec]
class AtResponseType(enum.Enum):
"""Indicate if a response is expected from an AT command, and if multiple
responses are accepted."""
NONE = 0
SINGLE = 1
MULTIPLE = 2
class AtResponse:
code: str
parameters: list
def __init__(self, response: bytearray):
code_and_parameters = response.split(b':')
parameters = (
code_and_parameters[1] if len(code_and_parameters) > 1 else bytearray()
)
self.code = code_and_parameters[0].decode()
self.parameters = at.parse_parameters(parameters)
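`AtResponse` splits a raw response into the code before the first `:` and hands the rest to `at.parse_parameters`. A standalone sketch of that split; the parameter handling here is a deliberately simplified stand-in for bumble's `at` module (which also deals with quoted strings and nested lists):

```python
# Standalone sketch of the code/parameter split performed by AtResponse.
# The comma-split below is a simplified stand-in for at.parse_parameters.
def split_response(response: bytes):
    code, _, parameters = response.partition(b':')
    params = [p.strip() for p in parameters.split(b',')] if parameters else []
    return code.decode(), params

code, params = split_response(b'+BRSF:3079')
print(code, params)  # +BRSF [b'3079']
```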
@dataclasses.dataclass
class AgIndicatorState:
description: str
index: int
supported_values: Set[int]
current_status: int
@dataclasses.dataclass
class HfIndicatorState:
supported: bool = False
enabled: bool = False
class HfProtocol:
"""Implementation for the Hands-Free side of the Hands-Free profile.
Reference specification Hands-Free Profile v1.8"""
supported_hf_features: int
supported_audio_codecs: List[AudioCodec]
supported_ag_features: int
supported_ag_call_hold_operations: List[CallHoldOperation]
ag_indicators: List[AgIndicatorState]
hf_indicators: Dict[HfIndicator, HfIndicatorState]
dlc: rfcomm.DLC
command_lock: asyncio.Lock
if TYPE_CHECKING:
response_queue: asyncio.Queue[AtResponse]
unsolicited_queue: asyncio.Queue[AtResponse]
else:
response_queue: asyncio.Queue
unsolicited_queue: asyncio.Queue
read_buffer: bytearray
def __init__(self, dlc: rfcomm.DLC, configuration: Configuration):
# Configure internal state.
self.dlc = dlc
self.command_lock = asyncio.Lock()
self.response_queue = asyncio.Queue()
self.unsolicited_queue = asyncio.Queue()
self.read_buffer = bytearray()
# Build local features.
self.supported_hf_features = sum(configuration.supported_hf_features)
self.supported_audio_codecs = configuration.supported_audio_codecs
self.hf_indicators = {
indicator: HfIndicatorState()
for indicator in configuration.supported_hf_indicators
}
# Clear remote features.
self.supported_ag_features = 0
self.supported_ag_call_hold_operations = []
self.ag_indicators = []
# Bind the AT reader to the RFCOMM channel.
self.dlc.sink = self._read_at
def supports_hf_feature(self, feature: HfFeature) -> bool:
return (self.supported_hf_features & feature) != 0
def supports_ag_feature(self, feature: AgFeature) -> bool:
return (self.supported_ag_features & feature) != 0
# Read AT messages from the RFCOMM channel.
# Enqueue AT commands, responses, unsolicited responses to their
# respective queues, and set the corresponding event.
def _read_at(self, data: bytes):
# Append to the read buffer.
self.read_buffer.extend(data)
# Locate header and trailer.
header = self.read_buffer.find(b'\r\n')
trailer = self.read_buffer.find(b'\r\n', header + 2)
if header == -1 or trailer == -1:
return
# Isolate the AT response code and parameters.
raw_response = self.read_buffer[header + 2 : trailer]
response = AtResponse(raw_response)
logger.debug(f"<<< {raw_response.decode()}")
# Consume the response bytes.
self.read_buffer = self.read_buffer[trailer + 2 :]
# Forward the received code to the correct queue.
if self.command_lock.locked() and (
response.code in STATUS_CODES or response.code in RESPONSE_CODES
):
self.response_queue.put_nowait(response)
elif response.code in UNSOLICITED_CODES:
self.unsolicited_queue.put_nowait(response)
else:
logger.warning(f"dropping unexpected response with code '{response.code}'")
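The buffer scan in `_read_at` relies on each AT response being framed as `\r\n<payload>\r\n`, with frames possibly split across RFCOMM writes. A self-contained sketch of that framing logic, generalized to drain every complete frame from the buffer (names here are illustrative, not the class's API):

```python
# Sketch of the "\r\n<payload>\r\n" framing used by _read_at: scan for a
# header/trailer pair, extract the payload between them, keep the remainder.
def extract_responses(buffer: bytearray):
    responses = []
    while True:
        header = buffer.find(b'\r\n')
        trailer = buffer.find(b'\r\n', header + 2)
        if header == -1 or trailer == -1:
            # Incomplete frame: keep the unconsumed bytes for the next read
            return responses, buffer
        responses.append(bytes(buffer[header + 2 : trailer]))
        buffer = buffer[trailer + 2 :]

responses, rest = extract_responses(bytearray(b'\r\nOK\r\n\r\n+CIEV: 1,'))
print(responses, rest)  # [b'OK'] bytearray(b'\r\n+CIEV: 1,')
```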
# Send an AT command and wait for the peer response.
# Collect the AT responses sent by the peer, up to the final status code.
# Raises asyncio.TimeoutError if the status is not received
# after a timeout (default 1 second).
# Raises ProtocolError if the status is not OK.
async def execute_command(
self,
cmd: str,
timeout: float = 1.0,
response_type: AtResponseType = AtResponseType.NONE,
) -> Union[None, AtResponse, List[AtResponse]]:
async with self.command_lock:
logger.debug(f">>> {cmd}")
self.dlc.write(cmd + '\r')
responses: List[AtResponse] = []
while True:
result = await asyncio.wait_for(
self.response_queue.get(), timeout=timeout
)
if result.code == 'OK':
if response_type == AtResponseType.SINGLE and len(responses) != 1:
raise HfpProtocolError("NO ANSWER")
if response_type == AtResponseType.MULTIPLE:
return responses
if response_type == AtResponseType.SINGLE:
return responses[0]
return None
if result.code in STATUS_CODES:
raise HfpProtocolError(result.code)
responses.append(result)
# 4.2.1 Service Level Connection Initialization.
async def initiate_slc(self):
# 4.2.1.1 Supported features exchange
# First, in the initialization procedure, the HF shall send the
# AT+BRSF=<HF supported features> command to the AG to both notify
# the AG of the supported features in the HF, as well as to retrieve the
# supported features in the AG using the +BRSF result code.
response = await self.execute_command(
f"AT+BRSF={self.supported_hf_features}", response_type=AtResponseType.SINGLE
)
self.supported_ag_features = int(response.parameters[0])
logger.info(f"supported AG features: {self.supported_ag_features}")
for feature in AgFeature:
if self.supports_ag_feature(feature):
logger.info(f" - {feature.name}")
# 4.2.1.2 Codec Negotiation
# Secondly, in the initialization procedure, if the HF supports the
# Codec Negotiation feature, it shall check if the AT+BRSF command
# response from the AG has indicated that it supports the Codec
# Negotiation feature.
if self.supports_hf_feature(
HfFeature.CODEC_NEGOTIATION
) and self.supports_ag_feature(AgFeature.CODEC_NEGOTIATION):
# If both the HF and AG do support the Codec Negotiation feature
# then the HF shall send the AT+BAC=<HF available codecs> command to
# the AG to notify the AG of the available codecs in the HF.
codecs = [str(c) for c in self.supported_audio_codecs]
await self.execute_command(f"AT+BAC={','.join(codecs)}")
# 4.2.1.3 AG Indicators
# After having retrieved the supported features in the AG, the HF shall
# determine which indicators are supported by the AG, as well as the
# ordering of the supported indicators. This is because, according to
# the 3GPP 27.007 specification [2], the AG may support additional
# indicators not provided for by the Hands-Free Profile, and because the
# ordering of the indicators is implementation specific. The HF uses
# the AT+CIND=? Test command to retrieve information about the supported
# indicators and their ordering.
response = await self.execute_command(
"AT+CIND=?", response_type=AtResponseType.SINGLE
)
self.ag_indicators = []
for index, indicator in enumerate(response.parameters):
description = indicator[0].decode()
supported_values = []
for value in indicator[1]:
value = value.split(b'-')
value = [int(v) for v in value]
value_min = value[0]
value_max = value[1] if len(value) > 1 else value[0]
supported_values.extend([v for v in range(value_min, value_max + 1)])
self.ag_indicators.append(
AgIndicatorState(description, index, set(supported_values), 0)
)
# Once the HF has the necessary supported indicator and ordering
# information, it shall retrieve the current status of the indicators
# in the AG using the AT+CIND? Read command.
response = await self.execute_command(
"AT+CIND?", response_type=AtResponseType.SINGLE
)
for index, indicator in enumerate(response.parameters):
self.ag_indicators[index].current_status = int(indicator)
# After having retrieved the status of the indicators in the AG, the HF
# shall then enable the "Indicators status update" function in the AG by
# issuing the AT+CMER command, to which the AG shall respond with OK.
await self.execute_command("AT+CMER=3,,,1")
if self.supports_hf_feature(
HfFeature.THREE_WAY_CALLING
) and self.supports_ag_feature(AgFeature.THREE_WAY_CALLING):
# After the HF has enabled the “Indicators status update” function in
# the AG, and if the “Call waiting and 3-way calling” bit was set in the
# supported features bitmap by both the HF and the AG, the HF shall
# issue the AT+CHLD=? test command to retrieve the information about how
# the call hold and multiparty services are supported in the AG. The HF
# shall not issue the AT+CHLD=? test command in case either the HF or
# the AG does not support the "Three-way calling" feature.
response = await self.execute_command(
"AT+CHLD=?", response_type=AtResponseType.SINGLE
)
self.supported_ag_call_hold_operations = [
CallHoldOperation(int(operation))
for operation in response.parameters[0]
if b'x' not in operation
]
# 4.2.1.4 HF Indicators
# If the HF supports the HF indicator feature, it shall check the +BRSF
# response to see if the AG also supports the HF Indicator feature.
if self.supports_hf_feature(
HfFeature.HF_INDICATORS
) and self.supports_ag_feature(AgFeature.HF_INDICATORS):
# If both the HF and AG support the HF Indicator feature, then the HF
# shall send the AT+BIND=<HF supported HF indicators> command to the AG
# to notify the AG of the supported indicators assigned numbers in the
# HF. The AG shall respond with OK
indicators = [str(i) for i in self.hf_indicators.keys()]
await self.execute_command(f"AT+BIND={','.join(indicators)}")
# After having provided the AG with the HF indicators it supports,
# the HF shall send the AT+BIND=? to request HF indicators supported
# by the AG. The AG shall reply with the +BIND response listing all
# HF indicators that it supports followed by an OK.
response = await self.execute_command(
"AT+BIND=?", response_type=AtResponseType.SINGLE
)
logger.info("supported HF indicators:")
for indicator in response.parameters[0]:
indicator = HfIndicator(int(indicator))
logger.info(f" - {indicator.name}")
if indicator in self.hf_indicators:
self.hf_indicators[indicator].supported = True
# Once the HF receives the supported HF indicators list from the AG,
# the HF shall send the AT+BIND? command to determine which HF
# indicators are enabled. The AG shall respond with one or more
# +BIND responses. The AG shall terminate the list with OK.
# (See Section 4.36.1.3).
responses = await self.execute_command(
"AT+BIND?", response_type=AtResponseType.MULTIPLE
)
logger.info("enabled HF indicators:")
for response in responses:
indicator = HfIndicator(int(response.parameters[0]))
enabled = int(response.parameters[1]) != 0
logger.info(f" - {indicator.name}: {enabled}")
if indicator in self.hf_indicators:
self.hf_indicators[indicator].enabled = enabled
logger.info("SLC setup completed")
# 4.11.2 Audio Connection Setup by HF
async def setup_audio_connection(self):
# When the HF triggers the establishment of the Codec Connection it
# shall send the AT command AT+BCC to the AG. The AG shall respond with
# OK if it will start the Codec Connection procedure, and with ERROR
# if it cannot start the Codec Connection procedure.
await self.execute_command("AT+BCC")
# 4.11.3 Codec Connection Setup
async def setup_codec_connection(self, codec_id: int):
# The AG shall send a +BCS=<Codec ID> unsolicited response to the HF.
# The HF shall then respond to the incoming unsolicited response with
# the AT command AT+BCS=<Codec ID>. The ID shall be the same as in the
# unsolicited response code as long as the ID is supported.
# If the received ID is not available, the HF shall respond with
# AT+BAC with its available codecs.
if codec_id not in self.supported_audio_codecs:
codecs = [str(c) for c in self.supported_audio_codecs]
await self.execute_command(f"AT+BAC={','.join(codecs)}")
return
await self.execute_command(f"AT+BCS={codec_id}")
# After sending the OK response, the AG shall open the
# Synchronous Connection with the settings that are determined by the
# ID. The HF shall be ready to accept the synchronous connection
# establishment as soon as it has sent the AT commands AT+BCS=<Codec ID>.
logger.info("codec connection setup completed")
# 4.13.1 Answer Incoming Call from the HF In-Band Ringing
async def answer_incoming_call(self):
# The user accepts the incoming voice call by using the proper means
# provided by the HF. The HF shall then send the ATA command
# (see Section 4.34) to the AG. The AG shall then begin the procedure for
# accepting the incoming call.
await self.execute_command("ATA")
# 4.14.1 Reject an Incoming Call from the HF
async def reject_incoming_call(self):
# The user rejects the incoming call by using the User Interface on the
# Hands-Free unit. The HF shall then send the AT+CHUP command
# (see Section 4.34) to the AG. This may happen at any time during the
# procedures described in Sections 4.13.1 and 4.13.2.
await self.execute_command("AT+CHUP")
# 4.15.1 Terminate a Call Process from the HF
async def terminate_call(self):
# The user may abort the ongoing call process using whatever means
# provided by the Hands-Free unit. The HF shall send AT+CHUP command
# (see Section 4.34) to the AG, and the AG shall then start the
# procedure to terminate or interrupt the current call procedure.
# The AG shall then send the OK indication followed by the +CIEV result
# code, with the value indicating (call=0).
await self.execute_command("AT+CHUP")
async def update_ag_indicator(self, index: int, value: int):
self.ag_indicators[index].current_status = value
logger.info(
f"AG indicator updated: {self.ag_indicators[index].description}, {value}"
)
async def handle_unsolicited(self):
"""Handle unsolicited result codes sent by the audio gateway."""
result = await self.unsolicited_queue.get()
if result.code == "+BCS":
await self.setup_codec_connection(int(result.parameters[0]))
elif result.code == "+CIEV":
await self.update_ag_indicator(
int(result.parameters[0]), int(result.parameters[1])
)
else:
logger.info(f"unhandled unsolicited response: {result.code}")
async def run(self):
"""Main rountine for the Hands-Free side of the HFP protocol.
Initiates the service level connection then loops handling
unsolicited AG responses."""
try:
await self.initiate_slc()
while True:
await self.handle_unsolicited()
except Exception:
logger.error("HFP-HF protocol failed with the following error:")
logger.error(traceback.format_exc())
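The value-range expansion performed while parsing the `AT+CIND=?` response above (each element is either a single value or a `min-max` range, encoded as bytes) can be sketched as a standalone helper. `parse_supported_values` is a hypothetical name introduced here for illustration; it is not part of this module's API:

```python
def parse_supported_values(spec: bytes) -> set:
    """Expand an AT+CIND=? value spec such as b'0-5' or b'0,1' into a set of ints."""
    values = set()
    for element in spec.split(b','):
        # int() accepts bytes directly, as the parsing loop above relies on.
        bounds = [int(v) for v in element.split(b'-')]
        value_min = bounds[0]
        value_max = bounds[1] if len(bounds) > 1 else bounds[0]
        values.update(range(value_min, value_max + 1))
    return values

print(sorted(parse_supported_values(b'0-3')))  # [0, 1, 2, 3]
```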
# -----------------------------------------------------------------------------
# Normative SDP definitions
# -----------------------------------------------------------------------------
# Profile version (normative).
# Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements
class ProfileVersion(enum.IntEnum):
V1_5 = 0x0105
V1_6 = 0x0106
V1_7 = 0x0107
V1_8 = 0x0108
V1_9 = 0x0109
# HF supported features (normative).
# Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements
class HfSdpFeature(enum.IntFlag):
EC_NR = 0x01 # Echo Cancel & Noise reduction
THREE_WAY_CALLING = 0x02
CLI_PRESENTATION_CAPABILITY = 0x04
VOICE_RECOGNITION_ACTIVATION = 0x08
REMOTE_VOLUME_CONTROL = 0x10
WIDE_BAND = 0x20 # Wide band speech
ENHANCED_VOICE_RECOGNITION_STATUS = 0x40
VOICE_RECOGNITION_TEST = 0x80
# AG supported features (normative).
# Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements
class AgSdpFeature(enum.IntFlag):
THREE_WAY_CALLING = 0x01
EC_NR = 0x02 # Echo Cancel & Noise reduction
VOICE_RECOGNITION_FUNCTION = 0x04
IN_BAND_RING_TONE_CAPABILITY = 0x08
VOICE_TAG = 0x10 # Attach a number to voice tag
WIDE_BAND = 0x20 # Wide band speech
ENHANCED_VOICE_RECOGNITION_STATUS = 0x40
VOICE_RECOGNITION_TEST = 0x80
def sdp_records(
service_record_handle: int, rfcomm_channel: int, configuration: Configuration
) -> List[ServiceAttribute]:
"""Generate the SDP record for HFP Hands-Free support.
The record exposes the features supported in the input configuration,
and the allocated RFCOMM channel."""
hf_supported_features = 0
if HfFeature.EC_NR in configuration.supported_hf_features:
hf_supported_features |= HfSdpFeature.EC_NR
if HfFeature.THREE_WAY_CALLING in configuration.supported_hf_features:
hf_supported_features |= HfSdpFeature.THREE_WAY_CALLING
if HfFeature.CLI_PRESENTATION_CAPABILITY in configuration.supported_hf_features:
hf_supported_features |= HfSdpFeature.CLI_PRESENTATION_CAPABILITY
if HfFeature.VOICE_RECOGNITION_ACTIVATION in configuration.supported_hf_features:
hf_supported_features |= HfSdpFeature.VOICE_RECOGNITION_ACTIVATION
if HfFeature.REMOTE_VOLUME_CONTROL in configuration.supported_hf_features:
hf_supported_features |= HfSdpFeature.REMOTE_VOLUME_CONTROL
if (
HfFeature.ENHANCED_VOICE_RECOGNITION_STATUS
in configuration.supported_hf_features
):
hf_supported_features |= HfSdpFeature.ENHANCED_VOICE_RECOGNITION_STATUS
if HfFeature.VOICE_RECOGNITION_TEST in configuration.supported_hf_features:
hf_supported_features |= HfSdpFeature.VOICE_RECOGNITION_TEST
if AudioCodec.MSBC in configuration.supported_audio_codecs:
hf_supported_features |= HfSdpFeature.WIDE_BAND
return [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.uuid(BT_GENERIC_AUDIO_SERVICE),
]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(rfcomm_channel),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.unsigned_integer_16(ProfileVersion.V1_8),
]
)
]
),
),
ServiceAttribute(
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(hf_supported_features),
),
]
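The feature bitmaps used throughout this file (the `AT+BRSF` value and the SDP SupportedFeatures attribute) are plain ORed flag sets. A minimal sketch of the pattern, using simplified stand-in flag values rather than the module's real definitions:

```python
from enum import IntFlag

# Simplified stand-ins for illustration only; the actual HfFeature values
# are defined elsewhere in this module.
class HfFeature(IntFlag):
    EC_NR = 0x001
    THREE_WAY_CALLING = 0x002
    CODEC_NEGOTIATION = 0x200

# The HF advertises its features as the sum (equivalently, bitwise OR) of the
# enabled flags — the integer sent in "AT+BRSF=<features>".
supported_hf_features = sum([HfFeature.THREE_WAY_CALLING, HfFeature.CODEC_NEGOTIATION])

def supports_hf_feature(feature: HfFeature) -> bool:
    # A feature is supported if its bit is set in the combined bitmap.
    return (supported_hf_features & feature) != 0

print(supports_hf_feature(HfFeature.CODEC_NEGOTIATION))  # True
print(supports_hf_feature(HfFeature.EC_NR))              # False
```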

@@ -15,23 +15,24 @@
 # -----------------------------------------------------------------------------
 # Imports
 # -----------------------------------------------------------------------------
+from __future__ import annotations
 import asyncio
 import collections
 import logging
 import struct
+from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable
 
 from bumble.colors import color
 from bumble.l2cap import L2CAP_PDU
 from bumble.snoop import Snooper
 from bumble import drivers
-from typing import Optional
 
 from .hci import (
     Address,
     HCI_ACL_DATA_PACKET,
-    HCI_COMMAND_COMPLETE_EVENT,
     HCI_COMMAND_PACKET,
+    HCI_COMMAND_COMPLETE_EVENT,
     HCI_EVENT_PACKET,
     HCI_LE_READ_BUFFER_SIZE_COMMAND,
     HCI_LE_READ_LOCAL_SUPPORTED_FEATURES_COMMAND,
@@ -45,8 +46,11 @@ from .hci import (
     HCI_VERSION_BLUETOOTH_CORE_4_0,
     HCI_AclDataPacket,
     HCI_AclDataPacketAssembler,
+    HCI_Command,
+    HCI_Command_Complete_Event,
     HCI_Constant,
     HCI_Error,
+    HCI_Event,
     HCI_LE_Long_Term_Key_Request_Negative_Reply_Command,
     HCI_LE_Long_Term_Key_Request_Reply_Command,
     HCI_LE_Read_Buffer_Size_Command,
@@ -63,16 +67,19 @@ from .hci import (
     HCI_Read_Local_Version_Information_Command,
     HCI_Reset_Command,
     HCI_Set_Event_Mask_Command,
+    map_null_terminated_utf8_string,
 )
 from .core import (
     BT_BR_EDR_TRANSPORT,
+    BT_CENTRAL_ROLE,
     BT_LE_TRANSPORT,
     ConnectionPHY,
     ConnectionParameters,
+    InvalidStateError,
 )
 from .utils import AbortableEventEmitter
+from .transport.common import TransportLostError
+
+if TYPE_CHECKING:
+    from .transport.common import TransportSink, TransportSource
 
 # -----------------------------------------------------------------------------
@@ -96,27 +103,38 @@ HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
 # -----------------------------------------------------------------------------
 class Connection:
-    def __init__(self, host, handle, peer_address, transport):
+    def __init__(self, host: Host, handle: int, peer_address: Address, transport: int):
         self.host = host
         self.handle = handle
         self.peer_address = peer_address
         self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
         self.transport = transport
 
-    def on_hci_acl_data_packet(self, packet):
+    def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None:
         self.assembler.feed_packet(packet)
 
-    def on_acl_pdu(self, pdu):
+    def on_acl_pdu(self, pdu: bytes) -> None:
         l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
         self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
 
 # -----------------------------------------------------------------------------
 class Host(AbortableEventEmitter):
-    def __init__(self, controller_source=None, controller_sink=None):
+    connections: Dict[int, Connection]
+    acl_packet_queue: collections.deque[HCI_AclDataPacket]
+    hci_sink: TransportSink
+    long_term_key_provider: Optional[
+        Callable[[int, bytes, int], Awaitable[Optional[bytes]]]
+    ]
+    link_key_provider: Optional[Callable[[Address], Awaitable[Optional[bytes]]]]
+
+    def __init__(
+        self,
+        controller_source: Optional[TransportSource] = None,
+        controller_sink: Optional[TransportSink] = None,
+    ) -> None:
         super().__init__()
-        self.hci_sink = None
         self.hci_metadata = None
         self.ready = False  # True when we can accept incoming packets
         self.reset_done = False
@@ -296,7 +314,7 @@ class Host(AbortableEventEmitter):
         self.reset_done = True
 
     @property
-    def controller(self):
+    def controller(self) -> TransportSink:
         return self.hci_sink
 
     @controller.setter
@@ -305,13 +323,12 @@ class Host(AbortableEventEmitter):
         if controller:
             controller.set_packet_sink(self)
 
-    def set_packet_sink(self, sink):
+    def set_packet_sink(self, sink: TransportSink) -> None:
         self.hci_sink = sink
 
-    def send_hci_packet(self, packet):
+    def send_hci_packet(self, packet: HCI_Packet) -> None:
         if self.snooper:
             self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
         self.hci_sink.on_packet(bytes(packet))
 
     async def send_command(self, command, check_result=False):
@@ -349,7 +366,7 @@ class Host(AbortableEventEmitter):
             return response
         except Exception as error:
             logger.warning(
-                f'{color("!!! Exception while sending HCI packet:", "red")} {error}'
+                f'{color("!!! Exception while sending command:", "red")} {error}'
             )
             raise error
         finally:
@@ -357,13 +374,13 @@ class Host(AbortableEventEmitter):
             self.pending_response = None
 
     # Use this method to send a command from a task
-    def send_command_sync(self, command):
-        async def send_command(command):
+    def send_command_sync(self, command: HCI_Command) -> None:
+        async def send_command(command: HCI_Command) -> None:
             await self.send_command(command)
 
         asyncio.create_task(send_command(command))
 
-    def send_l2cap_pdu(self, connection_handle, cid, pdu):
+    def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
         l2cap_pdu = bytes(L2CAP_PDU(cid, pdu))
 
         # Send the data to the controller via ACL packets
@@ -388,7 +405,7 @@ class Host(AbortableEventEmitter):
             offset += data_total_length
             bytes_remaining -= data_total_length
 
-    def queue_acl_packet(self, acl_packet):
+    def queue_acl_packet(self, acl_packet: HCI_AclDataPacket) -> None:
         self.acl_packet_queue.appendleft(acl_packet)
         self.check_acl_packet_queue()
@@ -398,7 +415,7 @@ class Host(AbortableEventEmitter):
                 f'{len(self.acl_packet_queue)} in queue'
             )
 
-    def check_acl_packet_queue(self):
+    def check_acl_packet_queue(self) -> None:
         # Send all we can (TODO: support different LE/Classic limits)
         while (
             len(self.acl_packet_queue) > 0
@@ -444,47 +461,53 @@ class Host(AbortableEventEmitter):
         ]
 
     # Packet Sink protocol (packets coming from the controller via HCI)
-    def on_packet(self, packet):
+    def on_packet(self, packet: bytes) -> None:
         hci_packet = HCI_Packet.from_bytes(packet)
         if self.ready or (
-            hci_packet.hci_packet_type == HCI_EVENT_PACKET
-            and hci_packet.event_code == HCI_COMMAND_COMPLETE_EVENT
+            isinstance(hci_packet, HCI_Command_Complete_Event)
             and hci_packet.command_opcode == HCI_RESET_COMMAND
         ):
             self.on_hci_packet(hci_packet)
         else:
             logger.debug('reset not done, ignoring packet from controller')
 
-    def on_hci_packet(self, packet):
+    def on_transport_lost(self):
+        # Called by the source when the transport has been lost.
+        if self.pending_response:
+            self.pending_response.set_exception(TransportLostError('transport lost'))
+
+        self.emit('flush')
+
+    def on_hci_packet(self, packet: HCI_Packet) -> None:
         logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}')
 
         if self.snooper:
             self.snooper.snoop(bytes(packet), Snooper.Direction.CONTROLLER_TO_HOST)
 
         # If the packet is a command, invoke the handler for this packet
-        if packet.hci_packet_type == HCI_COMMAND_PACKET:
+        if isinstance(packet, HCI_Command):
             self.on_hci_command_packet(packet)
-        elif packet.hci_packet_type == HCI_EVENT_PACKET:
+        elif isinstance(packet, HCI_Event):
             self.on_hci_event_packet(packet)
-        elif packet.hci_packet_type == HCI_ACL_DATA_PACKET:
+        elif isinstance(packet, HCI_AclDataPacket):
            self.on_hci_acl_data_packet(packet)
         else:
             logger.warning(f'!!! unknown packet type {packet.hci_packet_type}')
 
-    def on_hci_command_packet(self, command):
+    def on_hci_command_packet(self, command: HCI_Command) -> None:
         logger.warning(f'!!! unexpected command packet: {command}')
 
-    def on_hci_event_packet(self, event):
+    def on_hci_event_packet(self, event: HCI_Event) -> None:
         handler_name = f'on_{event.name.lower()}'
         handler = getattr(self, handler_name, self.on_hci_event)
         handler(event)
 
-    def on_hci_acl_data_packet(self, packet):
+    def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None:
        # Look for the connection to which this data belongs
        if connection := self.connections.get(packet.connection_handle):
            connection.on_hci_acl_data_packet(packet)
 
-    def on_l2cap_pdu(self, connection, cid, pdu):
+    def on_l2cap_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
         self.emit('l2cap_pdu', connection.handle, cid, pdu)
 
     def on_command_processed(self, event):
@@ -822,6 +845,10 @@ class Host(AbortableEventEmitter):
             f'simple pairing complete for {event.bd_addr}: '
             f'status={HCI_Constant.status_name(event.status)}'
         )
+        if event.status == HCI_SUCCESS:
+            self.emit('classic_pairing', event.bd_addr)
+        else:
+            self.emit('classic_pairing_failure', event.bd_addr, event.status)
 
     def on_hci_pin_code_request_event(self, event):
         self.emit('pin_code_request', event.bd_addr)

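The dispatch change in the diff above replaces raw packet-type integer comparisons with `isinstance` checks on parsed packet objects, plus a `getattr`-based per-event handler lookup. A minimal, self-contained sketch of that pattern (hypothetical, simplified class names; not bumble's actual API):

```python
# Simplified stand-ins for parsed HCI packet classes (illustration only).
class HciPacket: ...
class HciCommand(HciPacket): ...
class HciEvent(HciPacket):
    name = 'inquiry_complete_event'

class Dispatcher:
    def on_packet(self, packet: HciPacket) -> str:
        # Route by object type rather than by a packet-type integer field.
        if isinstance(packet, HciCommand):
            return 'command'
        if isinstance(packet, HciEvent):
            # Per-event handlers are looked up by name, with a default fallback.
            handler = getattr(self, f'on_{packet.name}', self.on_unhandled)
            return handler(packet)
        return 'unknown'

    def on_inquiry_complete_event(self, event: HciEvent) -> str:
        return 'inquiry complete'

    def on_unhandled(self, event: HciEvent) -> str:
        return 'unhandled'

print(Dispatcher().on_packet(HciEvent()))  # prints "inquiry complete"
```

One benefit of this style is that type checkers can narrow the packet type inside each branch, which is what makes the added annotations in the diff useful.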
@@ -33,6 +33,7 @@ from typing import (
     Union,
     Deque,
     Iterable,
+    SupportsBytes,
     TYPE_CHECKING,
 )
@@ -47,6 +48,7 @@ from .hci import (
 if TYPE_CHECKING:
     from bumble.device import Connection
+    from bumble.host import Host
 
 # -----------------------------------------------------------------------------
 # Logging
@@ -728,7 +730,7 @@ class Channel(EventEmitter):
     def __init__(
         self,
-        manager: 'ChannelManager',
+        manager: ChannelManager,
         connection: Connection,
         signaling_cid: int,
         psm: int,
@@ -755,13 +757,13 @@ class Channel(EventEmitter):
         )
         self.state = new_state
 
-    def send_pdu(self, pdu) -> None:
+    def send_pdu(self, pdu: SupportsBytes | bytes) -> None:
         self.manager.send_pdu(self.connection, self.destination_cid, pdu)
 
-    def send_control_frame(self, frame) -> None:
+    def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
         self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
 
-    async def send_request(self, request) -> bytes:
+    async def send_request(self, request: SupportsBytes) -> bytes:
         # Check that there isn't already a request pending
         if self.response:
             raise InvalidStateError('request already pending')
@@ -772,7 +774,7 @@ class Channel(EventEmitter):
         self.send_pdu(request)
         return await self.response
 
-    def on_pdu(self, pdu) -> None:
+    def on_pdu(self, pdu: bytes) -> None:
         if self.response:
             self.response.set_result(pdu)
             self.response = None
@@ -1041,7 +1043,7 @@ class LeConnectionOrientedChannel(EventEmitter):
     def __init__(
         self,
-        manager: 'ChannelManager',
+        manager: ChannelManager,
         connection: Connection,
         le_psm: int,
         source_cid: int,
@@ -1096,10 +1098,10 @@ class LeConnectionOrientedChannel(EventEmitter):
         elif new_state == self.DISCONNECTED:
             self.emit('close')
 
-    def send_pdu(self, pdu) -> None:
+    def send_pdu(self, pdu: SupportsBytes | bytes) -> None:
         self.manager.send_pdu(self.connection, self.destination_cid, pdu)
 
-    def send_control_frame(self, frame) -> None:
+    def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
         self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame)
 
     async def connect(self) -> LeConnectionOrientedChannel:
@@ -1154,7 +1156,7 @@ class LeConnectionOrientedChannel(EventEmitter):
         if self.state == self.CONNECTED:
             self.change_state(self.DISCONNECTED)
 
-    def on_pdu(self, pdu) -> None:
+    def on_pdu(self, pdu: bytes) -> None:
         if self.sink is None:
             logger.warning('received pdu without a sink')
             return
@@ -1384,6 +1386,8 @@ class ChannelManager:
     ]
     le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
     fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
+    _host: Optional[Host]
+    connection_parameters_update_response: Optional[asyncio.Future[int]]
 
     def __init__(
         self,
@@ -1405,13 +1409,15 @@ class ChannelManager:
         self.le_coc_requests = {}  # LE CoC connection requests, by identifier
         self.extended_features = extended_features
         self.connectionless_mtu = connectionless_mtu
+        self.connection_parameters_update_response = None
 
     @property
-    def host(self):
+    def host(self) -> Host:
+        assert self._host
         return self._host
 
     @host.setter
-    def host(self, host):
+    def host(self, host: Host) -> None:
         if self._host is not None:
             self._host.remove_listener('disconnection', self.on_disconnection)
         self._host = host
@@ -1565,7 +1571,7 @@ class ChannelManager:
         if connection_handle in self.identifiers:
             del self.identifiers[connection_handle]
 
-    def send_pdu(self, connection, cid: int, pdu) -> None:
+    def send_pdu(self, connection, cid: int, pdu: SupportsBytes | bytes) -> None:
         pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
         logger.debug(
             f'{color(">>> Sending L2CAP PDU", "blue")} '
@@ -1574,7 +1580,7 @@ class ChannelManager:
         )
         self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
 
-    def on_pdu(self, connection: Connection, cid: int, pdu) -> None:
+    def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
         if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
             # Parse the L2CAP payload into a Control Frame object
             control_frame = L2CAP_Control_Frame.from_bytes(pdu)
@@ -1596,7 +1602,7 @@ class ChannelManager:
             channel.on_pdu(pdu)
 
     def send_control_frame(
-        self, connection: Connection, cid: int, control_frame
+        self, connection: Connection, cid: int, control_frame: L2CAP_Control_Frame
     ) -> None:
         logger.debug(
             f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} '
@@ -1605,7 +1611,9 @@ class ChannelManager:
         )
         self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame))
 
-    def on_control_frame(self, connection: Connection, cid: int, control_frame) -> None:
+    def on_control_frame(
+        self, connection: Connection, cid: int, control_frame: L2CAP_Control_Frame
+    ) -> None:
         logger.debug(
             f'{color("<<< Received L2CAP Signaling Control Frame", "green")} '
             f'on connection [0x{connection.handle:04X}] (CID={cid}) '
@@ -1859,11 +1867,45 @@ class ChannelManager:
             ),
         )
 
+    async def update_connection_parameters(
+        self,
+        connection: Connection,
+        interval_min: int,
+        interval_max: int,
+        latency: int,
+        timeout: int,
+    ) -> int:
+        # Check that there isn't already a request pending
+        if self.connection_parameters_update_response:
+            raise InvalidStateError('request already pending')
+        self.connection_parameters_update_response = (
+            asyncio.get_running_loop().create_future()
+        )
+        self.send_control_frame(
+            connection,
+            L2CAP_LE_SIGNALING_CID,
+            L2CAP_Connection_Parameter_Update_Request(
+                interval_min=interval_min,
+                interval_max=interval_max,
+                latency=latency,
+                timeout=timeout,
+            ),
+        )
+        return await self.connection_parameters_update_response
+
     def on_l2cap_connection_parameter_update_response(
         self, connection: Connection, cid: int, response
     ) -> None:
-        # TODO: check response
-        pass
+        if self.connection_parameters_update_response:
+            self.connection_parameters_update_response.set_result(response.result)
+            self.connection_parameters_update_response = None
+        else:
+            logger.warning(
+                color(
+                    'received l2cap_connection_parameter_update_response without a pending request',
'red',
)
)
def on_l2cap_le_credit_based_connection_request( def on_l2cap_le_credit_based_connection_request(
self, connection: Connection, cid: int, request self, connection: Connection, cid: int, request
@@ -2072,7 +2114,8 @@ class ChannelManager:
# Connect # Connect
try: try:
await channel.connect() await channel.connect()
except Exception: except Exception as e:
del connection_channels[source_cid] del connection_channels[source_cid]
raise e
return channel return channel
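The new `update_connection_parameters` coroutine and `on_l2cap_connection_parameter_update_response` handler above implement a pending-future request/response pattern. A minimal self-contained sketch of that pattern (class and method names here are simplified stand-ins, not the Bumble API; sending the actual L2CAP request PDU is elided):

```python
import asyncio


class ParameterUpdater:
    """Sketch of the pending-future pattern used by
    ChannelManager.update_connection_parameters (names simplified)."""

    def __init__(self):
        self._pending = None

    async def update(self, interval_min, interval_max, latency, timeout):
        # Only one request may be in flight at a time.
        if self._pending is not None:
            raise RuntimeError('request already pending')
        self._pending = asyncio.get_running_loop().create_future()
        # A real implementation would send an
        # L2CAP_Connection_Parameter_Update_Request here.
        return await self._pending

    def on_response(self, result):
        # Called when the peer's response PDU arrives.
        if self._pending is not None:
            self._pending.set_result(result)
            self._pending = None


async def main():
    updater = ParameterUpdater()
    task = asyncio.ensure_future(updater.update(6, 12, 0, 100))
    await asyncio.sleep(0)  # let the request coroutine start and suspend
    updater.on_response(0)  # 0 == parameters accepted
    return await task


print(asyncio.run(main()))  # → 0
```

Clearing `self._pending` inside `on_response` is what lets the next `update` call proceed, mirroring how the diff resets `connection_parameters_update_response` to `None`.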


@@ -19,6 +19,7 @@ import enum
 from typing import Optional, Tuple

 from .hci import (
+    Address,
     HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
     HCI_DISPLAY_ONLY_IO_CAPABILITY,
     HCI_DISPLAY_YES_NO_IO_CAPABILITY,
@@ -168,21 +169,28 @@ class PairingDelegate:
 class PairingConfig:
     """Configuration for the Pairing protocol."""

+    class AddressType(enum.IntEnum):
+        PUBLIC = Address.PUBLIC_DEVICE_ADDRESS
+        RANDOM = Address.RANDOM_DEVICE_ADDRESS
+
     def __init__(
         self,
         sc: bool = True,
         mitm: bool = True,
         bonding: bool = True,
         delegate: Optional[PairingDelegate] = None,
+        identity_address_type: Optional[AddressType] = None,
     ) -> None:
         self.sc = sc
         self.mitm = mitm
         self.bonding = bonding
         self.delegate = delegate or PairingDelegate()
+        self.identity_address_type = identity_address_type

     def __str__(self) -> str:
         return (
             f'PairingConfig(sc={self.sc}, '
             f'mitm={self.mitm}, bonding={self.bonding}, '
+            f'identity_address_type={self.identity_address_type}, '
             f'delegate[{self.delegate.io_capability}])'
         )


@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from bumble.pairing import PairingDelegate
+from bumble.pairing import PairingConfig, PairingDelegate
 from dataclasses import dataclass
 from typing import Any, Dict
@@ -20,6 +20,7 @@ from typing import Any, Dict
 @dataclass
 class Config:
     io_capability: PairingDelegate.IoCapability = PairingDelegate.NO_OUTPUT_NO_INPUT
+    identity_address_type: PairingConfig.AddressType = PairingConfig.AddressType.RANDOM
     pairing_sc_enable: bool = True
     pairing_mitm_enable: bool = True
     pairing_bonding_enable: bool = True
@@ -35,6 +36,12 @@ class Config:
             'io_capability', 'no_output_no_input'
         ).upper()
         self.io_capability = getattr(PairingDelegate, io_capability_name)
+        identity_address_type_name: str = config.get(
+            'identity_address_type', 'random'
+        ).upper()
+        self.identity_address_type = getattr(
+            PairingConfig.AddressType, identity_address_type_name
+        )
         self.pairing_sc_enable = config.get('pairing_sc_enable', True)
         self.pairing_mitm_enable = config.get('pairing_mitm_enable', True)
         self.pairing_bonding_enable = config.get('pairing_bonding_enable', True)
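The config parsing above resolves a lower-case config string to an enum member by name (`getattr` on the enum class). A small self-contained sketch of that lookup, with a stand-in `AddressType` enum whose values merely imitate the `Address.PUBLIC_DEVICE_ADDRESS` / `Address.RANDOM_DEVICE_ADDRESS` constants from `bumble.hci`:

```python
import enum


class AddressType(enum.IntEnum):
    # Stand-in values; the real ones come from bumble.hci.Address.
    PUBLIC = 0x00
    RANDOM = 0x01


def parse_identity_address_type(config: dict) -> AddressType:
    # Same name-based lookup as the diff: upper-case the config string,
    # then resolve it as an attribute of the enum class.
    name = config.get('identity_address_type', 'random').upper()
    return getattr(AddressType, name)


print(parse_identity_address_type({}))  # AddressType.RANDOM
print(parse_identity_address_type({'identity_address_type': 'public'}))  # AddressType.PUBLIC
```

An unknown config value raises `AttributeError`, which surfaces configuration typos early rather than silently defaulting.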


@@ -34,6 +34,10 @@ from bumble.sdp import (
 from typing import Any, Dict, List, Optional

+# Default rootcanal HCI TCP address
+ROOTCANAL_HCI_ADDRESS = "localhost:6402"
+
 class PandoraDevice:
     """
     Small wrapper around a Bumble device and it's HCI transport.
@@ -53,7 +57,9 @@ class PandoraDevice:
     def __init__(self, config: Dict[str, Any]) -> None:
         self.config = config
         self.device = _make_device(config)
-        self._hci_name = config.get('transport', '')
+        self._hci_name = config.get(
+            'transport', f"tcp-client:{config.get('tcp', ROOTCANAL_HCI_ADDRESS)}"
+        )
         self._hci = None

     @property
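The new transport default can be sketched as a plain function (a stand-in for the `PandoraDevice.__init__` logic above; `ROOTCANAL_HCI_ADDRESS` and the `'transport'`/`'tcp'` keys are from the diff, the function name is hypothetical):

```python
ROOTCANAL_HCI_ADDRESS = "localhost:6402"


def hci_transport_name(config: dict) -> str:
    # Mirror the new default: fall back to a rootcanal TCP client
    # transport when no explicit 'transport' is configured, optionally
    # pointing at a custom 'tcp' address.
    return config.get(
        'transport', f"tcp-client:{config.get('tcp', ROOTCANAL_HCI_ADDRESS)}"
    )


print(hci_transport_name({}))                        # tcp-client:localhost:6402
print(hci_transport_name({'tcp': '10.0.0.5:6402'}))  # tcp-client:10.0.0.5:6402
print(hci_transport_name({'transport': 'usb:0'}))    # usb:0
```

An explicit `'transport'` entry always wins; the `'tcp'` key only shapes the fallback.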


@@ -112,7 +112,7 @@ class HostService(HostServicer):
     async def FactoryReset(
         self, request: empty_pb2.Empty, context: grpc.ServicerContext
     ) -> empty_pb2.Empty:
-        self.log.info('FactoryReset')
+        self.log.debug('FactoryReset')

         # delete all bonds
         if self.device.keystore is not None:
@@ -126,7 +126,7 @@ class HostService(HostServicer):
     async def Reset(
         self, request: empty_pb2.Empty, context: grpc.ServicerContext
     ) -> empty_pb2.Empty:
-        self.log.info('Reset')
+        self.log.debug('Reset')

         # clear service.
         self.waited_connections.clear()
@@ -139,7 +139,7 @@ class HostService(HostServicer):
     async def ReadLocalAddress(
         self, request: empty_pb2.Empty, context: grpc.ServicerContext
     ) -> ReadLocalAddressResponse:
-        self.log.info('ReadLocalAddress')
+        self.log.debug('ReadLocalAddress')
         return ReadLocalAddressResponse(
             address=bytes(reversed(bytes(self.device.public_address)))
         )
@@ -152,7 +152,7 @@ class HostService(HostServicer):
         address = Address(
             bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS
         )
-        self.log.info(f"Connect to {address}")
+        self.log.debug(f"Connect to {address}")

         try:
             connection = await self.device.connect(
@@ -167,7 +167,7 @@ class HostService(HostServicer):
                 return ConnectResponse(connection_already_exists=empty_pb2.Empty())
             raise e

-        self.log.info(f"Connect to {address} done (handle={connection.handle})")
+        self.log.debug(f"Connect to {address} done (handle={connection.handle})")
         cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
         return ConnectResponse(connection=Connection(cookie=cookie))
@@ -186,7 +186,7 @@ class HostService(HostServicer):
         if address in (Address.NIL, Address.ANY):
             raise ValueError('Invalid address')
-        self.log.info(f"WaitConnection from {address}...")
+        self.log.debug(f"WaitConnection from {address}...")

         connection = self.device.find_connection_by_bd_addr(
             address, transport=BT_BR_EDR_TRANSPORT
@@ -201,7 +201,7 @@ class HostService(HostServicer):
         # save connection has waited and respond.
         self.waited_connections.add(id(connection))

-        self.log.info(
+        self.log.debug(
             f"WaitConnection from {address} done (handle={connection.handle})"
         )
@@ -216,7 +216,7 @@ class HostService(HostServicer):
         if address in (Address.NIL, Address.ANY):
             raise ValueError('Invalid address')
-        self.log.info(f"ConnectLE to {address}...")
+        self.log.debug(f"ConnectLE to {address}...")

         try:
             connection = await self.device.connect(
@@ -233,7 +233,7 @@ class HostService(HostServicer):
                 return ConnectLEResponse(connection_already_exists=empty_pb2.Empty())
             raise e

-        self.log.info(f"ConnectLE to {address} done (handle={connection.handle})")
+        self.log.debug(f"ConnectLE to {address} done (handle={connection.handle})")
         cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
         return ConnectLEResponse(connection=Connection(cookie=cookie))
@@ -243,12 +243,12 @@ class HostService(HostServicer):
         self, request: DisconnectRequest, context: grpc.ServicerContext
     ) -> empty_pb2.Empty:
         connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
-        self.log.info(f"Disconnect: {connection_handle}")
+        self.log.debug(f"Disconnect: {connection_handle}")

-        self.log.info("Disconnecting...")
+        self.log.debug("Disconnecting...")
         if connection := self.device.lookup_connection(connection_handle):
             await connection.disconnect(HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR)
-        self.log.info("Disconnected")
+        self.log.debug("Disconnected")

         return empty_pb2.Empty()
@@ -257,7 +257,7 @@ class HostService(HostServicer):
         self, request: WaitDisconnectionRequest, context: grpc.ServicerContext
     ) -> empty_pb2.Empty:
         connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
-        self.log.info(f"WaitDisconnection: {connection_handle}")
+        self.log.debug(f"WaitDisconnection: {connection_handle}")

         if connection := self.device.lookup_connection(connection_handle):
             disconnection_future: asyncio.Future[
@@ -270,7 +270,7 @@ class HostService(HostServicer):
             connection.on('disconnection', on_disconnection)
             try:
                 await disconnection_future
-                self.log.info("Disconnected")
+                self.log.debug("Disconnected")
             finally:
                 connection.remove_listener('disconnection', on_disconnection)  # type: ignore
@@ -378,7 +378,7 @@ class HostService(HostServicer):
         try:
             while True:
                 if not self.device.is_advertising:
-                    self.log.info('Advertise')
+                    self.log.debug('Advertise')
                     await self.device.start_advertising(
                         target=target,
                         advertising_type=advertising_type,
@@ -393,10 +393,10 @@ class HostService(HostServicer):
                         bumble.device.Connection
                     ] = asyncio.get_running_loop().create_future()

-                    self.log.info('Wait for LE connection...')
+                    self.log.debug('Wait for LE connection...')
                     connection = await pending_connection

-                    self.log.info(
+                    self.log.debug(
                         f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})"
                     )
@@ -410,7 +410,7 @@ class HostService(HostServicer):
             self.device.remove_listener('connection', on_connection)  # type: ignore

             try:
-                self.log.info('Stop advertising')
+                self.log.debug('Stop advertising')
                 await self.device.abort_on('flush', self.device.stop_advertising())
             except:
                 pass
@@ -423,7 +423,7 @@ class HostService(HostServicer):
         if request.phys:
             raise NotImplementedError("TODO: add support for `request.phys`")

-        self.log.info('Scan')
+        self.log.debug('Scan')
         scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
         handler = self.device.on('advertisement', scan_queue.put_nowait)
@@ -470,7 +470,7 @@ class HostService(HostServicer):
         finally:
             self.device.remove_listener('advertisement', handler)  # type: ignore

             try:
-                self.log.info('Stop scanning')
+                self.log.debug('Stop scanning')
                 await self.device.abort_on('flush', self.device.stop_scanning())
             except:
                 pass
@@ -479,7 +479,7 @@ class HostService(HostServicer):
     async def Inquiry(
         self, request: empty_pb2.Empty, context: grpc.ServicerContext
     ) -> AsyncGenerator[InquiryResponse, None]:
-        self.log.info('Inquiry')
+        self.log.debug('Inquiry')

         inquiry_queue: asyncio.Queue[
             Optional[Tuple[Address, int, AdvertisingData, int]]
@@ -510,7 +510,7 @@ class HostService(HostServicer):
             self.device.remove_listener('inquiry_complete', complete_handler)  # type: ignore
             self.device.remove_listener('inquiry_result', result_handler)  # type: ignore

             try:
-                self.log.info('Stop inquiry')
+                self.log.debug('Stop inquiry')
                 await self.device.abort_on('flush', self.device.stop_discovery())
             except:
                 pass
@@ -519,7 +519,7 @@ class HostService(HostServicer):
     async def SetDiscoverabilityMode(
         self, request: SetDiscoverabilityModeRequest, context: grpc.ServicerContext
     ) -> empty_pb2.Empty:
-        self.log.info("SetDiscoverabilityMode")
+        self.log.debug("SetDiscoverabilityMode")
         await self.device.set_discoverable(request.mode != NOT_DISCOVERABLE)
         return empty_pb2.Empty()
@@ -527,7 +527,7 @@ class HostService(HostServicer):
     async def SetConnectabilityMode(
         self, request: SetConnectabilityModeRequest, context: grpc.ServicerContext
     ) -> empty_pb2.Empty:
-        self.log.info("SetConnectabilityMode")
+        self.log.debug("SetConnectabilityMode")
         await self.device.set_connectable(request.mode != NOT_CONNECTABLE)
         return empty_pb2.Empty()


@@ -99,7 +99,7 @@ class PairingDelegate(BasePairingDelegate):
         return ev

     async def confirm(self, auto: bool = False) -> bool:
-        self.log.info(
+        self.log.debug(
             f"Pairing event: `just_works` (io_capability: {self.io_capability})"
         )
@@ -114,7 +114,7 @@ class PairingDelegate(BasePairingDelegate):
         return answer.confirm

     async def compare_numbers(self, number: int, digits: int = 6) -> bool:
-        self.log.info(
+        self.log.debug(
             f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})"
         )
@@ -129,7 +129,7 @@ class PairingDelegate(BasePairingDelegate):
         return answer.confirm

     async def get_number(self) -> Optional[int]:
-        self.log.info(
+        self.log.debug(
             f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
         )
@@ -146,7 +146,7 @@ class PairingDelegate(BasePairingDelegate):
         return answer.passkey

     async def get_string(self, max_length: int) -> Optional[str]:
-        self.log.info(
+        self.log.debug(
             f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
         )
@@ -177,7 +177,7 @@ class PairingDelegate(BasePairingDelegate):
         ):
             return

-        self.log.info(
+        self.log.debug(
             f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})"
         )
@@ -232,6 +232,7 @@ class SecurityService(SecurityServicer):
             sc=config.pairing_sc_enable,
             mitm=config.pairing_mitm_enable,
             bonding=config.pairing_bonding_enable,
+            identity_address_type=config.identity_address_type,
             delegate=PairingDelegate(
                 connection,
                 self,
@@ -247,7 +248,7 @@ class SecurityService(SecurityServicer):
     async def OnPairing(
         self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
     ) -> AsyncGenerator[PairingEvent, None]:
-        self.log.info('OnPairing')
+        self.log.debug('OnPairing')

         if self.event_queue is not None:
             raise RuntimeError('already streaming pairing events')
@@ -273,7 +274,7 @@ class SecurityService(SecurityServicer):
         self, request: SecureRequest, context: grpc.ServicerContext
     ) -> SecureResponse:
         connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
-        self.log.info(f"Secure: {connection_handle}")
+        self.log.debug(f"Secure: {connection_handle}")

         connection = self.device.lookup_connection(connection_handle)
         assert connection
@@ -291,7 +292,7 @@ class SecurityService(SecurityServicer):
         # trigger pairing if needed
         if self.need_pairing(connection, level):
             try:
-                self.log.info('Pair...')
+                self.log.debug('Pair...')

                 if (
                     connection.transport == BT_LE_TRANSPORT
@@ -309,7 +310,7 @@ class SecurityService(SecurityServicer):
                 else:
                     await connection.pair()

-                self.log.info('Paired')
+                self.log.debug('Paired')
             except asyncio.CancelledError:
                 self.log.warning("Connection died during encryption")
                 return SecureResponse(connection_died=empty_pb2.Empty())
@@ -320,9 +321,9 @@ class SecurityService(SecurityServicer):
         # trigger authentication if needed
         if self.need_authentication(connection, level):
             try:
-                self.log.info('Authenticate...')
+                self.log.debug('Authenticate...')
                 await connection.authenticate()
-                self.log.info('Authenticated')
+                self.log.debug('Authenticated')
             except asyncio.CancelledError:
                 self.log.warning("Connection died during authentication")
                 return SecureResponse(connection_died=empty_pb2.Empty())
@@ -333,9 +334,9 @@ class SecurityService(SecurityServicer):
         # trigger encryption if needed
         if self.need_encryption(connection, level):
             try:
-                self.log.info('Encrypt...')
+                self.log.debug('Encrypt...')
                 await connection.encrypt()
-                self.log.info('Encrypted')
+                self.log.debug('Encrypted')
             except asyncio.CancelledError:
                 self.log.warning("Connection died during encryption")
                 return SecureResponse(connection_died=empty_pb2.Empty())
@@ -353,7 +354,7 @@ class SecurityService(SecurityServicer):
         self, request: WaitSecurityRequest, context: grpc.ServicerContext
     ) -> WaitSecurityResponse:
         connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
-        self.log.info(f"WaitSecurity: {connection_handle}")
+        self.log.debug(f"WaitSecurity: {connection_handle}")

         connection = self.device.lookup_connection(connection_handle)
         assert connection
@@ -390,7 +391,7 @@ class SecurityService(SecurityServicer):
         def set_failure(name: str) -> Callable[..., None]:
             def wrapper(*args: Any) -> None:
-                self.log.info(f'Wait for security: error `{name}`: {args}')
+                self.log.debug(f'Wait for security: error `{name}`: {args}')
                 wait_for_security.set_result(name)

             return wrapper
@@ -398,13 +399,13 @@ class SecurityService(SecurityServicer):
         def try_set_success(*_: Any) -> None:
             assert connection
             if self.reached_security_level(connection, level):
-                self.log.info('Wait for security: done')
+                self.log.debug('Wait for security: done')
                 wait_for_security.set_result('success')

         def on_encryption_change(*_: Any) -> None:
             assert connection
             if self.reached_security_level(connection, level):
-                self.log.info('Wait for security: done')
+                self.log.debug('Wait for security: done')
                 wait_for_security.set_result('success')
             elif (
                 connection.transport == BT_BR_EDR_TRANSPORT
@@ -422,6 +423,8 @@ class SecurityService(SecurityServicer):
             'pairing': try_set_success,
             'connection_authentication': try_set_success,
             'connection_encryption_change': on_encryption_change,
+            'classic_pairing': try_set_success,
+            'classic_pairing_failure': set_failure('pairing_failure'),
         }

         # register event handlers
@@ -432,7 +435,7 @@ class SecurityService(SecurityServicer):
         if self.reached_security_level(connection, level):
             return WaitSecurityResponse(success=empty_pb2.Empty())

-        self.log.info('Wait for security...')
+        self.log.debug('Wait for security...')
         kwargs = {}
         kwargs[await wait_for_security] = empty_pb2.Empty()
@@ -442,12 +445,12 @@ class SecurityService(SecurityServicer):
         # wait for `authenticate` to finish if any
         if authenticate_task is not None:
-            self.log.info('Wait for authentication...')
+            self.log.debug('Wait for authentication...')
             try:
                 await authenticate_task  # type: ignore
             except:
                 pass
-            self.log.info('Authenticated')
+            self.log.debug('Authenticated')

         return WaitSecurityResponse(**kwargs)
@@ -503,7 +506,7 @@ class SecurityStorageService(SecurityStorageServicer):
         self, request: IsBondedRequest, context: grpc.ServicerContext
     ) -> wrappers_pb2.BoolValue:
         address = utils.address_from_request(request, request.WhichOneof("address"))
-        self.log.info(f"IsBonded: {address}")
+        self.log.debug(f"IsBonded: {address}")

         if self.device.keystore is not None:
             is_bonded = await self.device.keystore.get(str(address)) is not None
@@ -517,7 +520,7 @@ class SecurityStorageService(SecurityStorageServicer):
         self, request: DeleteBondRequest, context: grpc.ServicerContext
     ) -> empty_pb2.Empty:
         address = utils.address_from_request(request, request.WhichOneof("address"))
-        self.log.info(f"DeleteBond: {address}")
+        self.log.debug(f"DeleteBond: {address}")

         if self.device.keystore is not None:
             with suppress(KeyError):
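The WaitSecurity flow above resolves a future with a result *name* ('success' or an error name produced by `set_failure`), and that name directly keys the gRPC response field via `kwargs[await wait_for_security] = empty_pb2.Empty()`. A reduced self-contained sketch of the pattern (using `object()` in place of `empty_pb2.Empty()`, and a plain callback in place of the event handlers):

```python
import asyncio


async def wait_security_response() -> dict:
    # The future resolves to a result name; the name becomes the
    # keyword used to build the response.
    wait_for_security = asyncio.get_running_loop().create_future()

    def on_event(name: str) -> None:
        # In the diff, handlers like the new 'classic_pairing' /
        # 'classic_pairing_failure' entries call set_result like this.
        wait_for_security.set_result(name)

    on_event('success')
    kwargs = {}
    kwargs[await wait_for_security] = object()  # empty_pb2.Empty() in the source
    return kwargs


print(list(asyncio.run(wait_security_response())))  # → ['success']
```

The trick keeps the success and failure paths symmetric: every handler just names the oneof field it wants set.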


@@ -15,15 +15,37 @@
 # -----------------------------------------------------------------------------
 # Imports
 # -----------------------------------------------------------------------------
+from __future__ import annotations
 import logging
 import asyncio
+import enum
+from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING

 from pyee import EventEmitter
-from typing import Optional, Tuple, Callable, Dict, Union

 from . import core, l2cap
 from .colors import color
-from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError
+from .core import (
+    UUID,
+    BT_RFCOMM_PROTOCOL_ID,
+    BT_BR_EDR_TRANSPORT,
+    BT_L2CAP_PROTOCOL_ID,
+    InvalidStateError,
+    ProtocolError,
+)
+from .sdp import (
+    SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
+    SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
+    SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
+    SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
+    SDP_PUBLIC_BROWSE_ROOT,
+    DataElement,
+    ServiceAttribute,
+)
+
+if TYPE_CHECKING:
+    from bumble.device import Device, Connection

 # -----------------------------------------------------------------------------
 # Logging
@@ -105,6 +127,50 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
 # fmt: on

+# -----------------------------------------------------------------------------
+def make_service_sdp_records(
+    service_record_handle: int, channel: int, uuid: Optional[UUID] = None
+) -> List[ServiceAttribute]:
+    """
+    Create SDP records for an RFComm service given a channel number and an
+    optional UUID. A Service Class Attribute is included only if the UUID is not None.
+    """
+    records = [
+        ServiceAttribute(
+            SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
+            DataElement.unsigned_integer_32(service_record_handle),
+        ),
+        ServiceAttribute(
+            SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
+            DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
+        ),
+        ServiceAttribute(
+            SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
+            DataElement.sequence(
+                [
+                    DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
+                    DataElement.sequence(
+                        [
+                            DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
+                            DataElement.unsigned_integer_8(channel),
+                        ]
+                    ),
+                ]
+            ),
+        ),
+    ]
+    if uuid:
+        records.append(
+            ServiceAttribute(
+                SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
+                DataElement.sequence([DataElement.uuid(uuid)]),
+            )
+        )
+    return records
+
 # -----------------------------------------------------------------------------
 def compute_fcs(buffer: bytes) -> int:
     result = 0xFF
@@ -149,9 +215,9 @@ class RFCOMM_Frame:
         return RFCOMM_FRAME_TYPE_NAMES[self.type]

     @staticmethod
-    def parse_mcc(data) -> Tuple[int, int, bytes]:
+    def parse_mcc(data) -> Tuple[int, bool, bytes]:
         mcc_type = data[0] >> 2
-        c_r = (data[0] >> 1) & 1
+        c_r = bool((data[0] >> 1) & 1)
         length = data[1]
         if data[1] & 1:
             length >>= 1
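The `parse_mcc` change above tightens the C/R field to a `bool`. A self-contained sketch of the parser as a free function, assuming the method ends by returning the type, C/R flag, and value slice (the hunk is truncated before the actual return statement, so the slice is an assumption):

```python
from typing import Tuple


def parse_mcc(data: bytes) -> Tuple[int, bool, bytes]:
    # MCC header: type in the top 6 bits of byte 0, C/R in bit 1;
    # byte 1 holds the length, shifted right when the E/A bit is set.
    mcc_type = data[0] >> 2
    c_r = bool((data[0] >> 1) & 1)
    length = data[1]
    if data[1] & 1:
        length >>= 1
    # Assumed tail: return the value bytes following the 2-byte header.
    return mcc_type, c_r, data[2 : 2 + length]


print(parse_mcc(bytes([0x22, 0x03, 0xAA])))  # → (8, True, b'\xaa')
```

Returning `bool` instead of `int` for C/R lets callers write `if c_r:` without relying on truthiness of a bit value, which is what the `Tuple[int, bool, bytes]` annotation now documents.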
@@ -192,7 +258,7 @@ class RFCOMM_Frame:
         )

     @staticmethod
-    def from_bytes(data: bytes):
+    def from_bytes(data: bytes) -> RFCOMM_Frame:
         # Extract fields
         dlci = (data[0] >> 2) & 0x3F
         c_r = (data[0] >> 1) & 0x01
@@ -215,7 +281,7 @@ class RFCOMM_Frame:
         return frame

-    def __bytes__(self):
+    def __bytes__(self) -> bytes:
         return (
             bytes([self.address, self.control])
             + self.length
@@ -223,7 +289,7 @@ class RFCOMM_Frame:
             + bytes([self.fcs])
         )

-    def __str__(self):
+    def __str__(self) -> str:
         return (
             f'{color(self.type_name(), "yellow")}'
             f'(c/r={self.c_r},'
@@ -253,7 +319,7 @@ class RFCOMM_MCC_PN:
         max_frame_size: int,
         max_retransmissions: int,
         window_size: int,
-    ):
+    ) -> None:
         self.dlci = dlci
         self.cl = cl
         self.priority = priority
@@ -263,7 +329,7 @@ class RFCOMM_MCC_PN:
         self.window_size = window_size

     @staticmethod
-    def from_bytes(data: bytes):
+    def from_bytes(data: bytes) -> RFCOMM_MCC_PN:
         return RFCOMM_MCC_PN(
             dlci=data[0],
             cl=data[1],
@@ -274,7 +340,7 @@ class RFCOMM_MCC_PN:
             window_size=data[7],
         )

-    def __bytes__(self):
+    def __bytes__(self) -> bytes:
         return bytes(
             [
                 self.dlci & 0xFF,
@@ -288,7 +354,7 @@ class RFCOMM_MCC_PN:
             ]
         )

-    def __str__(self):
+    def __str__(self) -> str:
         return (
             f'PN(dlci={self.dlci},'
             f'cl={self.cl},'
@@ -309,7 +375,9 @@ class RFCOMM_MCC_MSC:
     ic: int
     dv: int

-    def __init__(self, dlci: int, fc: int, rtc: int, rtr: int, ic: int, dv: int):
+    def __init__(
+        self, dlci: int, fc: int, rtc: int, rtr: int, ic: int, dv: int
+    ) -> None:
         self.dlci = dlci
         self.fc = fc
         self.rtc = rtc
@@ -318,7 +386,7 @@ class RFCOMM_MCC_MSC:
         self.dv = dv

     @staticmethod
-    def from_bytes(data: bytes):
+    def from_bytes(data: bytes) -> RFCOMM_MCC_MSC:
         return RFCOMM_MCC_MSC(
             dlci=data[0] >> 2,
             fc=data[1] >> 1 & 1,
@@ -328,7 +396,7 @@ class RFCOMM_MCC_MSC:
             dv=data[1] >> 7 & 1,
         )

-    def __bytes__(self):
+    def __bytes__(self) -> bytes:
         return bytes(
             [
                 (self.dlci << 2) | 3,
@@ -341,7 +409,7 @@ class RFCOMM_MCC_MSC:
             ]
         )

-    def __str__(self):
+    def __str__(self) -> str:
         return (
             f'MSC(dlci={self.dlci},'
             f'fc={self.fc},'
@@ -354,29 +422,24 @@ class RFCOMM_MCC_MSC:

 # -----------------------------------------------------------------------------
 class DLC(EventEmitter):
-    # States
-    INIT = 0x00
-    CONNECTING = 0x01
-    CONNECTED = 0x02
-    DISCONNECTING = 0x03
-    DISCONNECTED = 0x04
-    RESET = 0x05
-
-    STATE_NAMES = {
-        INIT: 'INIT',
-        CONNECTING: 'CONNECTING',
-        CONNECTED: 'CONNECTED',
-        DISCONNECTING: 'DISCONNECTING',
-        DISCONNECTED: 'DISCONNECTED',
-        RESET: 'RESET',
-    }
+    class State(enum.IntEnum):
+        INIT = 0x00
+        CONNECTING = 0x01
+        CONNECTED = 0x02
+        DISCONNECTING = 0x03
+        DISCONNECTED = 0x04
+        RESET = 0x05

     connection_result: Optional[asyncio.Future]
     sink: Optional[Callable[[bytes], None]]

     def __init__(
-        self, multiplexer, dlci: int, max_frame_size: int, initial_tx_credits: int
-    ):
+        self,
+        multiplexer: Multiplexer,
+        dlci: int,
+        max_frame_size: int,
+        initial_tx_credits: int,
+    ) -> None:
         super().__init__()
         self.multiplexer = multiplexer
         self.dlci = dlci
@@ -384,9 +447,9 @@ class DLC(EventEmitter):
         self.rx_threshold = self.rx_credits // 2
         self.tx_credits = initial_tx_credits
         self.tx_buffer = b''
-        self.state = DLC.INIT
+        self.state = DLC.State.INIT
         self.role = multiplexer.role
-        self.c_r = 1 if self.role == Multiplexer.INITIATOR else 0
+        self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
         self.sink = None
         self.connection_result = None
@@ -396,14 +459,8 @@ class DLC(EventEmitter):
             max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead
         )

-    @staticmethod
-    def state_name(state: int) -> str:
-        return DLC.STATE_NAMES[state]
-
-    def change_state(self, new_state: int) -> None:
-        logger.debug(
-            f'{self} state change -> {color(self.state_name(new_state), "magenta")}'
-        )
+    def change_state(self, new_state: State) -> None:
+        logger.debug(f'{self} state change -> {color(new_state.name, "magenta")}')
         self.state = new_state

     def send_frame(self, frame: RFCOMM_Frame) -> None:
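The refactor above replaces the module-level state constants and the parallel `STATE_NAMES` dict with a nested `enum.IntEnum`, which keeps the integer values wire-compatible while providing `.name` for free. A minimal sketch of the pattern (class names are illustrative, not Bumble's internals):

```python
import enum

class DLC:
    class State(enum.IntEnum):
        INIT = 0x00
        CONNECTING = 0x01
        CONNECTED = 0x02

    def __init__(self) -> None:
        self.state = DLC.State.INIT

    def change_state(self, new_state: 'DLC.State') -> None:
        # .name replaces the old STATE_NAMES lookup table
        print(f'state change -> {new_state.name}')
        self.state = new_state

dlc = DLC()
dlc.change_state(DLC.State.CONNECTING)
assert dlc.state == 0x01  # an IntEnum still compares equal to its int value
assert dlc.state.name == 'CONNECTING'
```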
@@ -413,8 +470,8 @@ class DLC(EventEmitter):
         handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
         handler(frame)

-    def on_sabm_frame(self, _frame) -> None:
-        if self.state != DLC.CONNECTING:
+    def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
+        if self.state != DLC.State.CONNECTING:
             logger.warning(
                 color('!!! received SABM when not in CONNECTING state', 'red')
             )
@@ -430,11 +487,11 @@ class DLC(EventEmitter):
         logger.debug(f'>>> MCC MSC Command: {msc}')
         self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))

-        self.change_state(DLC.CONNECTED)
+        self.change_state(DLC.State.CONNECTED)
         self.emit('open')

-    def on_ua_frame(self, _frame) -> None:
-        if self.state != DLC.CONNECTING:
+    def on_ua_frame(self, _frame: RFCOMM_Frame) -> None:
+        if self.state != DLC.State.CONNECTING:
             logger.warning(
                 color('!!! received SABM when not in CONNECTING state', 'red')
             )
@@ -448,14 +505,14 @@ class DLC(EventEmitter):
         logger.debug(f'>>> MCC MSC Command: {msc}')
         self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))

-        self.change_state(DLC.CONNECTED)
+        self.change_state(DLC.State.CONNECTED)
         self.multiplexer.on_dlc_open_complete(self)

-    def on_dm_frame(self, frame) -> None:
+    def on_dm_frame(self, frame: RFCOMM_Frame) -> None:
         # TODO: handle all states
         pass

-    def on_disc_frame(self, _frame) -> None:
+    def on_disc_frame(self, _frame: RFCOMM_Frame) -> None:
         # TODO: handle all states
         self.send_frame(RFCOMM_Frame.ua(c_r=1 - self.c_r, dlci=self.dlci))
@@ -489,10 +546,10 @@ class DLC(EventEmitter):
         # Check if there's anything to send (including credits)
         self.process_tx()

-    def on_ui_frame(self, frame) -> None:
+    def on_ui_frame(self, frame: RFCOMM_Frame) -> None:
         pass

-    def on_mcc_msc(self, c_r, msc) -> None:
+    def on_mcc_msc(self, c_r: bool, msc: RFCOMM_MCC_MSC) -> None:
         if c_r:
             # Command
             logger.debug(f'<<< MCC MSC Command: {msc}')
@@ -507,15 +564,15 @@ class DLC(EventEmitter):
             logger.debug(f'<<< MCC MSC Response: {msc}')

     def connect(self) -> None:
-        if self.state != DLC.INIT:
+        if self.state != DLC.State.INIT:
             raise InvalidStateError('invalid state')

-        self.change_state(DLC.CONNECTING)
+        self.change_state(DLC.State.CONNECTING)
         self.connection_result = asyncio.get_running_loop().create_future()
         self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))

     def accept(self) -> None:
-        if self.state != DLC.INIT:
+        if self.state != DLC.State.INIT:
             raise InvalidStateError('invalid state')

         pn = RFCOMM_MCC_PN(
@@ -530,7 +587,7 @@ class DLC(EventEmitter):
         mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
         logger.debug(f'>>> PN Response: {pn}')
         self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
-        self.change_state(DLC.CONNECTING)
+        self.change_state(DLC.State.CONNECTING)

     def rx_credits_needed(self) -> int:
         if self.rx_credits <= self.rx_threshold:
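RFCOMM DLCs use credit-based flow control: each frame a peer sends consumes one credit, and `rx_credits_needed` above decides how many credits to grant back once the receive window drops to the threshold (`rx_credits // 2` per the constructor). An illustrative sketch of that top-up rule; the class name and the exact replenish formula are assumptions for demonstration, not Bumble's internals:

```python
class CreditWindow:
    def __init__(self, initial_credits: int) -> None:
        self.rx_credits = initial_credits
        self.rx_threshold = initial_credits // 2  # mirrors the DLC constructor

    def on_frame_received(self) -> None:
        # Each received frame consumes one receive credit
        self.rx_credits -= 1

    def rx_credits_needed(self) -> int:
        # Once at or below the threshold, top the window back up (assumed rule)
        if self.rx_credits <= self.rx_threshold:
            return 2 * self.rx_threshold - self.rx_credits
        return 0

window = CreditWindow(8)
for _ in range(4):
    window.on_frame_received()
print(window.rx_credits_needed())  # 4: grant 4 credits back to the peer
```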
@@ -592,34 +649,24 @@ class DLC(EventEmitter):
         # TODO
         pass

-    def __str__(self):
-        return f'DLC(dlci={self.dlci},state={self.state_name(self.state)})'
+    def __str__(self) -> str:
+        return f'DLC(dlci={self.dlci},state={self.state.name})'


 # -----------------------------------------------------------------------------
 class Multiplexer(EventEmitter):
-    # Roles
-    INITIATOR = 0x00
-    RESPONDER = 0x01
-
-    # States
-    INIT = 0x00
-    CONNECTING = 0x01
-    CONNECTED = 0x02
-    OPENING = 0x03
-    DISCONNECTING = 0x04
-    DISCONNECTED = 0x05
-    RESET = 0x06
-
-    STATE_NAMES = {
-        INIT: 'INIT',
-        CONNECTING: 'CONNECTING',
-        CONNECTED: 'CONNECTED',
-        OPENING: 'OPENING',
-        DISCONNECTING: 'DISCONNECTING',
-        DISCONNECTED: 'DISCONNECTED',
-        RESET: 'RESET',
-    }
+    class Role(enum.IntEnum):
+        INITIATOR = 0x00
+        RESPONDER = 0x01
+
+    class State(enum.IntEnum):
+        INIT = 0x00
+        CONNECTING = 0x01
+        CONNECTED = 0x02
+        OPENING = 0x03
+        DISCONNECTING = 0x04
+        DISCONNECTED = 0x05
+        RESET = 0x06

     connection_result: Optional[asyncio.Future]
     disconnection_result: Optional[asyncio.Future]
@@ -627,11 +674,11 @@ class Multiplexer(EventEmitter):
     acceptor: Optional[Callable[[int], bool]]
     dlcs: Dict[int, DLC]

-    def __init__(self, l2cap_channel: l2cap.Channel, role: int) -> None:
+    def __init__(self, l2cap_channel: l2cap.Channel, role: Role) -> None:
         super().__init__()
         self.role = role
         self.l2cap_channel = l2cap_channel
-        self.state = Multiplexer.INIT
+        self.state = Multiplexer.State.INIT
         self.dlcs = {}  # DLCs, by DLCI
         self.connection_result = None
         self.disconnection_result = None
@@ -641,14 +688,8 @@ class Multiplexer(EventEmitter):
         # Become a sink for the L2CAP channel
         l2cap_channel.sink = self.on_pdu

-    @staticmethod
-    def state_name(state: int):
-        return Multiplexer.STATE_NAMES[state]
-
-    def change_state(self, new_state: int) -> None:
-        logger.debug(
-            f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
-        )
+    def change_state(self, new_state: State) -> None:
+        logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
         self.state = new_state

     def send_frame(self, frame: RFCOMM_Frame) -> None:
@@ -679,28 +720,28 @@ class Multiplexer(EventEmitter):
         handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
         handler(frame)

-    def on_sabm_frame(self, _frame) -> None:
-        if self.state != Multiplexer.INIT:
+    def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
+        if self.state != Multiplexer.State.INIT:
             logger.debug('not in INIT state, ignoring SABM')
             return

-        self.change_state(Multiplexer.CONNECTED)
+        self.change_state(Multiplexer.State.CONNECTED)
         self.send_frame(RFCOMM_Frame.ua(c_r=1, dlci=0))

-    def on_ua_frame(self, _frame) -> None:
-        if self.state == Multiplexer.CONNECTING:
-            self.change_state(Multiplexer.CONNECTED)
+    def on_ua_frame(self, _frame: RFCOMM_Frame) -> None:
+        if self.state == Multiplexer.State.CONNECTING:
+            self.change_state(Multiplexer.State.CONNECTED)
             if self.connection_result:
                 self.connection_result.set_result(0)
                 self.connection_result = None
-        elif self.state == Multiplexer.DISCONNECTING:
-            self.change_state(Multiplexer.DISCONNECTED)
+        elif self.state == Multiplexer.State.DISCONNECTING:
+            self.change_state(Multiplexer.State.DISCONNECTED)
             if self.disconnection_result:
                 self.disconnection_result.set_result(None)
                 self.disconnection_result = None

-    def on_dm_frame(self, _frame) -> None:
-        if self.state == Multiplexer.OPENING:
-            self.change_state(Multiplexer.CONNECTED)
+    def on_dm_frame(self, _frame: RFCOMM_Frame) -> None:
+        if self.state == Multiplexer.State.OPENING:
+            self.change_state(Multiplexer.State.CONNECTED)
             if self.open_result:
                 self.open_result.set_exception(
                     core.ConnectionError(
@@ -713,10 +754,12 @@ class Multiplexer(EventEmitter):
         else:
             logger.warning(f'unexpected state for DM: {self}')

-    def on_disc_frame(self, _frame) -> None:
-        self.change_state(Multiplexer.DISCONNECTED)
+    def on_disc_frame(self, _frame: RFCOMM_Frame) -> None:
+        self.change_state(Multiplexer.State.DISCONNECTED)
         self.send_frame(
-            RFCOMM_Frame.ua(c_r=0 if self.role == Multiplexer.INITIATOR else 1, dlci=0)
+            RFCOMM_Frame.ua(
+                c_r=0 if self.role == Multiplexer.Role.INITIATOR else 1, dlci=0
+            )
         )

     def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
@@ -729,11 +772,11 @@ class Multiplexer(EventEmitter):
             mcs = RFCOMM_MCC_MSC.from_bytes(value)
             self.on_mcc_msc(c_r, mcs)

-    def on_ui_frame(self, frame) -> None:
+    def on_ui_frame(self, frame: RFCOMM_Frame) -> None:
         pass

-    def on_mcc_pn(self, c_r, pn) -> None:
-        if c_r == 1:
+    def on_mcc_pn(self, c_r: bool, pn: RFCOMM_MCC_PN) -> None:
+        if c_r:
             # Command
             logger.debug(f'<<< PN Command: {pn}')
@@ -764,14 +807,14 @@ class Multiplexer(EventEmitter):
         else:
             # Response
             logger.debug(f'>>> PN Response: {pn}')
-            if self.state == Multiplexer.OPENING:
+            if self.state == Multiplexer.State.OPENING:
                 dlc = DLC(self, pn.dlci, pn.max_frame_size, pn.window_size)
                 self.dlcs[pn.dlci] = dlc
                 dlc.connect()
             else:
                 logger.warning('ignoring PN response')

-    def on_mcc_msc(self, c_r, msc) -> None:
+    def on_mcc_msc(self, c_r: bool, msc: RFCOMM_MCC_MSC) -> None:
         dlc = self.dlcs.get(msc.dlci)
         if dlc is None:
             logger.warning(f'no dlc for DLCI {msc.dlci}')
@@ -779,30 +822,30 @@ class Multiplexer(EventEmitter):
         dlc.on_mcc_msc(c_r, msc)

     async def connect(self) -> None:
-        if self.state != Multiplexer.INIT:
+        if self.state != Multiplexer.State.INIT:
             raise InvalidStateError('invalid state')

-        self.change_state(Multiplexer.CONNECTING)
+        self.change_state(Multiplexer.State.CONNECTING)
         self.connection_result = asyncio.get_running_loop().create_future()
         self.send_frame(RFCOMM_Frame.sabm(c_r=1, dlci=0))
         return await self.connection_result

     async def disconnect(self) -> None:
-        if self.state != Multiplexer.CONNECTED:
+        if self.state != Multiplexer.State.CONNECTED:
             return

         self.disconnection_result = asyncio.get_running_loop().create_future()
-        self.change_state(Multiplexer.DISCONNECTING)
+        self.change_state(Multiplexer.State.DISCONNECTING)
         self.send_frame(
             RFCOMM_Frame.disc(
-                c_r=1 if self.role == Multiplexer.INITIATOR else 0, dlci=0
+                c_r=1 if self.role == Multiplexer.Role.INITIATOR else 0, dlci=0
             )
         )
         await self.disconnection_result

     async def open_dlc(self, channel: int) -> DLC:
-        if self.state != Multiplexer.CONNECTED:
-            if self.state == Multiplexer.OPENING:
+        if self.state != Multiplexer.State.CONNECTED:
+            if self.state == Multiplexer.State.OPENING:
                 raise InvalidStateError('open already in progress')
             raise InvalidStateError('not connected')
@@ -819,10 +862,10 @@ class Multiplexer(EventEmitter):
         mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
         logger.debug(f'>>> Sending MCC: {pn}')
         self.open_result = asyncio.get_running_loop().create_future()
-        self.change_state(Multiplexer.OPENING)
+        self.change_state(Multiplexer.State.OPENING)
         self.send_frame(
             RFCOMM_Frame.uih(
-                c_r=1 if self.role == Multiplexer.INITIATOR else 0,
+                c_r=1 if self.role == Multiplexer.Role.INITIATOR else 0,
                 dlci=0,
                 information=mcc,
             )
@@ -831,14 +874,14 @@ class Multiplexer(EventEmitter):
         self.open_result = None
         return result

-    def on_dlc_open_complete(self, dlc: DLC):
+    def on_dlc_open_complete(self, dlc: DLC) -> None:
         logger.debug(f'DLC [{dlc.dlci}] open complete')
-        self.change_state(Multiplexer.CONNECTED)
+        self.change_state(Multiplexer.State.CONNECTED)
         if self.open_result:
             self.open_result.set_result(dlc)

-    def __str__(self):
-        return f'Multiplexer(state={self.state_name(self.state)})'
+    def __str__(self) -> str:
+        return f'Multiplexer(state={self.state.name})'


 # -----------------------------------------------------------------------------
@@ -846,7 +889,7 @@ class Client:
     multiplexer: Optional[Multiplexer]
     l2cap_channel: Optional[l2cap.Channel]

-    def __init__(self, device, connection) -> None:
+    def __init__(self, device: Device, connection: Connection) -> None:
         self.device = device
         self.connection = connection
         self.l2cap_channel = None
@@ -864,7 +907,7 @@ class Client:
         assert self.l2cap_channel is not None

         # Create a mutliplexer to manage DLCs with the server
-        self.multiplexer = Multiplexer(self.l2cap_channel, Multiplexer.INITIATOR)
+        self.multiplexer = Multiplexer(self.l2cap_channel, Multiplexer.Role.INITIATOR)

         # Connect the multiplexer
         await self.multiplexer.connect()
@@ -886,7 +929,7 @@ class Client:
 class Server(EventEmitter):
     acceptors: Dict[int, Callable[[DLC], None]]

-    def __init__(self, device) -> None:
+    def __init__(self, device: Device) -> None:
         super().__init__()
         self.device = device
         self.multiplexer = None
@@ -925,7 +968,7 @@ class Server(EventEmitter):
         logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')

         # Create a new multiplexer for the channel
-        multiplexer = Multiplexer(l2cap_channel, Multiplexer.RESPONDER)
+        multiplexer = Multiplexer(l2cap_channel, Multiplexer.Role.RESPONDER)
         multiplexer.acceptor = self.accept_dlc
         multiplexer.on('dlc', self.on_dlc)


@@ -18,13 +18,16 @@
 from __future__ import annotations
 import logging
 import struct
-from typing import Dict, List, Type
+from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING

-from . import core
+from . import core, l2cap
 from .colors import color
 from .core import InvalidStateError
 from .hci import HCI_Object, name_or_number, key_with_value

+if TYPE_CHECKING:
+    from .device import Device, Connection
+

 # -----------------------------------------------------------------------------
 # Logging
 # -----------------------------------------------------------------------------
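The `if TYPE_CHECKING:` block added above is the standard idiom for importing names used only in annotations without creating a runtime import cycle; combined with `from __future__ import annotations`, every annotation is stored lazily as a string, so the guarded import is never needed at runtime. A minimal sketch (the `Decimal` import stands in for a cycle-prone module like `device`):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; never executed at runtime
    from decimal import Decimal

def total(amount: Decimal) -> Decimal:
    # The annotation is a plain string at runtime, so no import is required
    return amount

assert TYPE_CHECKING is False  # the guard is off at runtime
assert total.__annotations__['amount'] == 'Decimal'
```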
@@ -94,6 +97,10 @@ SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID = 0X000B
 SDP_ICON_URL_ATTRIBUTE_ID = 0X000C
 SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D

+# Attribute Identifier (cf. Assigned Numbers for Service Discovery)
+# used by AVRCP, HFP and A2DP
+SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311
+
 SDP_ATTRIBUTE_ID_NAMES = {
     SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID: 'SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID',
     SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID: 'SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID',
@@ -462,7 +469,7 @@ class ServiceAttribute:
         self.value = value

     @staticmethod
-    def list_from_data_elements(elements):
+    def list_from_data_elements(elements: List[DataElement]) -> List[ServiceAttribute]:
         attribute_list = []
         for i in range(0, len(elements) // 2):
             attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
@@ -474,7 +481,9 @@ class ServiceAttribute:
         return attribute_list

     @staticmethod
-    def find_attribute_in_list(attribute_list, attribute_id):
+    def find_attribute_in_list(
+        attribute_list: List[ServiceAttribute], attribute_id: int
+    ) -> Optional[DataElement]:
         return next(
             (
                 attribute.value
@@ -489,7 +498,7 @@ class ServiceAttribute:
         return name_or_number(SDP_ATTRIBUTE_ID_NAMES, id_code)

     @staticmethod
-    def is_uuid_in_value(uuid, value):
+    def is_uuid_in_value(uuid: core.UUID, value: DataElement) -> bool:
         # Find if a uuid matches a value, either directly or recursing into sequences
         if value.type == DataElement.UUID:
             return value.value == uuid
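`find_attribute_in_list` now advertises `Optional[DataElement]` because it relies on `next(generator, default)` returning the default (`None`) when nothing matches, instead of raising `StopIteration`. The idiom in isolation (function and data here are illustrative):

```python
from __future__ import annotations

from typing import List, Optional, Tuple

def find_value(pairs: List[Tuple[int, str]], wanted_id: int) -> Optional[str]:
    # next() with a default avoids StopIteration when no pair matches
    return next((value for (attr_id, value) in pairs if attr_id == wanted_id), None)

attributes = [(0x0000, 'handle'), (0x0001, 'class-id-list')]
assert find_value(attributes, 0x0001) == 'class-id-list'
assert find_value(attributes, 0x0311) is None
```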
@@ -543,7 +552,9 @@ class SDP_PDU:
         return self

     @staticmethod
-    def parse_service_record_handle_list_preceded_by_count(data, offset):
+    def parse_service_record_handle_list_preceded_by_count(
+        data: bytes, offset: int
+    ) -> Tuple[int, List[int]]:
         count = struct.unpack_from('>H', data, offset - 2)[0]
         handle_list = [
             struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
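The newly typed `parse_service_record_handle_list_preceded_by_count` reads a 16-bit big-endian count that sits just before an array of 32-bit service record handles. The same parsing in a standalone form (it mirrors the diff's logic, but the function name is illustrative, not the Bumble API):

```python
import struct
from typing import List, Tuple

def parse_handle_list(data: bytes, offset: int) -> Tuple[int, List[int]]:
    # A 16-bit big-endian count precedes the handle array
    count = struct.unpack_from('>H', data, offset - 2)[0]
    # Each handle is a 32-bit big-endian integer
    handles = [struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)]
    return count, handles

buffer = struct.pack('>H', 2) + struct.pack('>I', 0x00010001) + struct.pack('>I', 0x00010002)
print(parse_handle_list(buffer, 2))  # (2, [65537, 65538])
```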
@@ -641,6 +652,10 @@ class SDP_ServiceSearchRequest(SDP_PDU):
     See Bluetooth spec @ Vol 3, Part B - 4.5.1 SDP_ServiceSearchRequest PDU
     '''

+    service_search_pattern: DataElement
+    maximum_service_record_count: int
+    continuation_state: bytes
+

 # -----------------------------------------------------------------------------
 @SDP_PDU.subclass(
@@ -659,6 +674,11 @@ class SDP_ServiceSearchResponse(SDP_PDU):
     See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
     '''

+    service_record_handle_list: List[int]
+    total_service_record_count: int
+    current_service_record_count: int
+    continuation_state: bytes
+

 # -----------------------------------------------------------------------------
 @SDP_PDU.subclass(
@@ -674,6 +694,11 @@ class SDP_ServiceAttributeRequest(SDP_PDU):
     See Bluetooth spec @ Vol 3, Part B - 4.6.1 SDP_ServiceAttributeRequest PDU
     '''

+    service_record_handle: int
+    maximum_attribute_byte_count: int
+    attribute_id_list: DataElement
+    continuation_state: bytes
+

 # -----------------------------------------------------------------------------
 @SDP_PDU.subclass(
@@ -688,6 +713,10 @@ class SDP_ServiceAttributeResponse(SDP_PDU):
     See Bluetooth spec @ Vol 3, Part B - 4.6.2 SDP_ServiceAttributeResponse PDU
     '''

+    attribute_list_byte_count: int
+    attribute_list: bytes
+    continuation_state: bytes
+

 # -----------------------------------------------------------------------------
 @SDP_PDU.subclass(
@@ -703,6 +732,11 @@ class SDP_ServiceSearchAttributeRequest(SDP_PDU):
     See Bluetooth spec @ Vol 3, Part B - 4.7.1 SDP_ServiceSearchAttributeRequest PDU
     '''

+    service_search_pattern: DataElement
+    maximum_attribute_byte_count: int
+    attribute_id_list: DataElement
+    continuation_state: bytes
+

 # -----------------------------------------------------------------------------
 @SDP_PDU.subclass(
@@ -717,26 +751,34 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
     See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
     '''

+    attribute_list_byte_count: int
+    attribute_list: bytes
+    continuation_state: bytes
+

 # -----------------------------------------------------------------------------
 class Client:
-    def __init__(self, device):
+    channel: Optional[l2cap.Channel]
+
+    def __init__(self, device: Device) -> None:
         self.device = device
         self.pending_request = None
         self.channel = None

-    async def connect(self, connection):
+    async def connect(self, connection: Connection) -> None:
         result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM)
         self.channel = result

-    async def disconnect(self):
+    async def disconnect(self) -> None:
         if self.channel:
             await self.channel.disconnect()
             self.channel = None

-    async def search_services(self, uuids):
+    async def search_services(self, uuids: List[core.UUID]) -> List[int]:
         if self.pending_request is not None:
             raise InvalidStateError('request already pending')
+        if self.channel is None:
+            raise InvalidStateError('L2CAP not connected')

         service_search_pattern = DataElement.sequence(
             [DataElement.uuid(uuid) for uuid in uuids]
@@ -766,9 +808,13 @@ class Client:

         return service_record_handle_list

-    async def search_attributes(self, uuids, attribute_ids):
+    async def search_attributes(
+        self, uuids: List[core.UUID], attribute_ids: List[Union[int, Tuple[int, int]]]
+    ) -> List[List[ServiceAttribute]]:
         if self.pending_request is not None:
             raise InvalidStateError('request already pending')
+        if self.channel is None:
+            raise InvalidStateError('L2CAP not connected')

         service_search_pattern = DataElement.sequence(
             [DataElement.uuid(uuid) for uuid in uuids]
@@ -819,9 +865,15 @@ class Client:
             if sequence.type == DataElement.SEQUENCE
         ]

-    async def get_attributes(self, service_record_handle, attribute_ids):
+    async def get_attributes(
+        self,
+        service_record_handle: int,
+        attribute_ids: List[Union[int, Tuple[int, int]]],
+    ) -> List[ServiceAttribute]:
         if self.pending_request is not None:
             raise InvalidStateError('request already pending')
+        if self.channel is None:
+            raise InvalidStateError('L2CAP not connected')

         attribute_id_list = DataElement.sequence(
             [
@@ -869,21 +921,25 @@ class Client:

 # -----------------------------------------------------------------------------
 class Server:
     CONTINUATION_STATE = bytes([0x01, 0x43])
+    channel: Optional[l2cap.Channel]
+    Service = NewType('Service', List[ServiceAttribute])
+    service_records: Dict[int, Service]
+    current_response: Union[None, bytes, Tuple[int, List[int]]]

-    def __init__(self, device):
+    def __init__(self, device: Device) -> None:
         self.device = device
         self.service_records = {}  # Service records maps, by record handle
         self.channel = None
         self.current_response = None

-    def register(self, l2cap_channel_manager):
+    def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
         l2cap_channel_manager.register_server(SDP_PSM, self.on_connection)

     def send_response(self, response):
         logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
         self.channel.send_pdu(response)

-    def match_services(self, search_pattern):
+    def match_services(self, search_pattern: DataElement) -> Dict[int, Service]:
         # Find the services for which the attributes in the pattern is a subset of the
         # service's attribute values (NOTE: the value search recurses into sequences)
         matching_services = {}
@@ -953,7 +1009,9 @@ class Server:
         return (payload, continuation_state)

     @staticmethod
-    def get_service_attributes(service, attribute_ids):
+    def get_service_attributes(
+        service: Service, attribute_ids: List[DataElement]
+    ) -> DataElement:
         attributes = []
         for attribute_id in attribute_ids:
             if attribute_id.value_size == 4:
@@ -978,10 +1036,10 @@ class Server:
return attribute_list return attribute_list
def on_sdp_service_search_request(self, request): def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if len(request.continuation_state) > 1:
if not self.current_response: if self.current_response is None:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
@@ -1010,6 +1068,7 @@ class Server:
) )
# Respond, keeping any unsent handles for later # Respond, keeping any unsent handles for later
assert isinstance(self.current_response, tuple)
service_record_handles = self.current_response[1][ service_record_handles = self.current_response[1][
: request.maximum_service_record_count : request.maximum_service_record_count
] ]
@@ -1033,10 +1092,12 @@ class Server:
) )
) )
def on_sdp_service_attribute_request(self, request): def on_sdp_service_attribute_request(
self, request: SDP_ServiceAttributeRequest
) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if len(request.continuation_state) > 1:
if not self.current_response: if self.current_response is None:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
@@ -1069,22 +1130,24 @@ class Server:
self.current_response = bytes(attribute_list) self.current_response = bytes(attribute_list)
# Respond, keeping any pending chunks for later # Respond, keeping any pending chunks for later
attribute_list, continuation_state = self.get_next_response_payload( attribute_list_response, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count request.maximum_attribute_byte_count
) )
self.send_response( self.send_response(
SDP_ServiceAttributeResponse( SDP_ServiceAttributeResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
attribute_list_byte_count=len(attribute_list), attribute_list_byte_count=len(attribute_list_response),
attribute_list=attribute_list, attribute_list=attribute_list,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
) )
def on_sdp_service_search_attribute_request(self, request): def on_sdp_service_search_attribute_request(
self, request: SDP_ServiceSearchAttributeRequest
) -> None:
# Check if this is a continuation # Check if this is a continuation
if len(request.continuation_state) > 1: if len(request.continuation_state) > 1:
if not self.current_response: if self.current_response is None:
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
@@ -1114,13 +1177,13 @@ class Server:
self.current_response = bytes(attribute_lists) self.current_response = bytes(attribute_lists)
# Respond, keeping any pending chunks for later # Respond, keeping any pending chunks for later
attribute_lists, continuation_state = self.get_next_response_payload( attribute_lists_response, continuation_state = self.get_next_response_payload(
request.maximum_attribute_byte_count request.maximum_attribute_byte_count
) )
self.send_response( self.send_response(
SDP_ServiceSearchAttributeResponse( SDP_ServiceSearchAttributeResponse(
transaction_id=request.transaction_id, transaction_id=request.transaction_id,
attribute_lists_byte_count=len(attribute_lists), attribute_lists_byte_count=len(attribute_lists_response),
attribute_lists=attribute_lists, attribute_lists=attribute_lists,
continuation_state=continuation_state, continuation_state=continuation_state,
) )
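The `current_response` / continuation-state mechanism shared by the three request handlers above can be sketched in isolation. This is a simplified stand-in for `get_next_response_payload` (the two-byte token mirrors `CONTINUATION_STATE`, but `next_chunk` itself is a hypothetical helper, not the actual method):

```python
def next_chunk(pending: bytes, maximum_byte_count: int):
    """Split off one response chunk; return (chunk, remainder, continuation_state)."""
    chunk = pending[:maximum_byte_count]
    remainder = pending[maximum_byte_count:]
    # A single 0x00 byte is a zero-length continuation state: the response is
    # complete. Anything else is an opaque, server-chosen token that the client
    # echoes back in its next request to fetch the remainder.
    state = bytes([0x01, 0x43]) if remainder else bytes([0x00])
    return chunk, remainder, state

chunk, remainder, state = next_chunk(b'abcdefgh', 5)
```

The server keeps `remainder` as its pending response; a follow-up request carrying the non-empty state retrieves the next chunk.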


@@ -993,6 +993,19 @@ class Session:
             )
         )

+    def send_identity_address_command(self) -> None:
+        identity_address = {
+            None: self.connection.self_address,
+            Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
+            Address.RANDOM_DEVICE_ADDRESS: self.manager.device.random_address,
+        }[self.pairing_config.identity_address_type]
+        self.send_command(
+            SMP_Identity_Address_Information_Command(
+                addr_type=identity_address.address_type,
+                bd_addr=identity_address,
+            )
+        )
+
     def start_encryption(self, key: bytes) -> None:
         # We can now encrypt the connection with the short term key, so that we can
         # distribute the long term and/or other keys over an encrypted connection
@@ -1016,6 +1029,7 @@ class Session:
         self.ltk = crypto.h6(ilk, b'brle')

     def distribute_keys(self) -> None:
         # Distribute the keys as required
         if self.is_initiator:
             # CTKD: Derive LTK from LinkKey
@@ -1045,12 +1059,7 @@ class Session:
                         identity_resolving_key=self.manager.device.irk
                     )
                 )
-                self.send_command(
-                    SMP_Identity_Address_Information_Command(
-                        addr_type=self.connection.self_address.address_type,
-                        bd_addr=self.connection.self_address,
-                    )
-                )
+                self.send_identity_address_command()

             # Distribute CSRK
             csrk = bytes(16)  # FIXME: testing
@@ -1094,12 +1103,7 @@ class Session:
                         identity_resolving_key=self.manager.device.irk
                     )
                 )
-                self.send_command(
-                    SMP_Identity_Address_Information_Command(
-                        addr_type=self.connection.self_address.address_type,
-                        bd_addr=self.connection.self_address,
-                    )
-                )
+                self.send_identity_address_command()

             # Distribute CSRK
             csrk = bytes(16)  # FIXME: testing
@@ -1268,7 +1272,7 @@ class Session:
                 keys.link_key = PairingKeys.Key(
                     value=self.link_key, authenticated=authenticated
                 )
-        self.manager.on_pairing(self, peer_address, keys)
+        await self.manager.on_pairing(self, peer_address, keys)

     def on_pairing_failure(self, reason: int) -> None:
         logger.warning(f'pairing failure ({error_name(reason)})')
@@ -1823,20 +1827,14 @@ class Manager(EventEmitter):
     def on_session_start(self, session: Session) -> None:
         self.device.on_pairing_start(session.connection)

-    def on_pairing(
+    async def on_pairing(
        self, session: Session, identity_address: Optional[Address], keys: PairingKeys
     ) -> None:
         # Store the keys in the key store
         if self.device.keystore and identity_address is not None:
-
-            async def store_keys():
-                try:
-                    assert self.device.keystore
-                    await self.device.keystore.update(str(identity_address), keys)
-                except Exception as error:
-                    logger.warning(f'!!! error while storing keys: {error}')
-
-            self.device.abort_on('flush', store_keys())
+            self.device.abort_on(
+                'flush', self.device.update_keys(str(identity_address), keys)
+            )

         # Notify the device
         self.device.on_pairing(session.connection, identity_address, keys, session.sc)
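The new `send_identity_address_command` selects the distributed identity address with a dict lookup keyed on the configured identity address type, with `None` falling back to the connection's own address. A standalone sketch of that dispatch (the constants and argument names here are illustrative stand-ins for `Address.PUBLIC_DEVICE_ADDRESS` and friends):

```python
# Illustrative stand-ins for Address.PUBLIC_DEVICE_ADDRESS / RANDOM_DEVICE_ADDRESS
PUBLIC_DEVICE_ADDRESS = 0x00
RANDOM_DEVICE_ADDRESS = 0x01

def select_identity_address(
    identity_address_type, self_address, public_address, random_address
):
    # None means: distribute whatever address this connection is already using
    return {
        None: self_address,
        PUBLIC_DEVICE_ADDRESS: public_address,
        RANDOM_DEVICE_ADDRESS: random_address,
    }[identity_address_type]
```

The dict dispatch keeps the three cases in one place instead of an if/elif chain, and raises `KeyError` on an unsupported type rather than silently falling through.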


@@ -20,7 +20,6 @@ import logging
 import os

 from .common import Transport, AsyncPipeSink, SnoopingTransport
-from ..controller import Controller
 from ..snoop import create_snooper

 # -----------------------------------------------------------------------------
@@ -69,6 +68,7 @@ async def open_transport(name: str) -> Transport:
       * usb
       * pyusb
       * android-emulator
+      * android-netsim
     """
     return _wrap_transport(await _open_transport(name))
@@ -118,7 +118,8 @@ async def _open_transport(name: str) -> Transport:
     if scheme == 'file':
         from .file import open_file_transport

-        return await open_file_transport(spec[0] if spec else None)
+        assert spec is not None
+        return await open_file_transport(spec[0])

     if scheme == 'vhci':
         from .vhci import open_vhci_transport
@@ -133,12 +134,14 @@ async def _open_transport(name: str) -> Transport:
     if scheme == 'usb':
         from .usb import open_usb_transport

-        return await open_usb_transport(spec[0] if spec else None)
+        assert spec is not None
+        return await open_usb_transport(spec[0])

     if scheme == 'pyusb':
         from .pyusb import open_pyusb_transport

-        return await open_pyusb_transport(spec[0] if spec else None)
+        assert spec is not None
+        return await open_pyusb_transport(spec[0])

     if scheme == 'android-emulator':
         from .android_emulator import open_android_emulator_transport
@@ -167,6 +170,7 @@ async def open_transport_or_link(name: str) -> Transport:
     """
     if name.startswith('link-relay:'):
+        from ..controller import Controller
         from ..link import RemoteLink  # lazy import

         link = RemoteLink(name[11:])
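Transport names throughout this module follow a `scheme:spec` convention (`serial:/dev/ttyUSB0`, `tcp-client:127.0.0.1:9000`, bare `vhci`, and so on). A minimal sketch of that split, not the actual `_open_transport` parsing:

```python
def parse_transport_name(name):
    """Split a 'scheme:spec' transport name; spec is None when absent."""
    scheme, _, rest = name.partition(':')
    # str.partition only splits on the FIRST ':', so specs that themselves
    # contain ':' (host:port pairs, for example) survive intact.
    return scheme, (rest or None)

scheme, spec = parse_transport_name('tcp-client:127.0.0.1:9000')
```

This is why the `assert spec is not None` guards in the diff above are needed: schemes like `file`, `usb`, and `pyusb` require a spec, while `vhci` can stand alone.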


@@ -18,7 +18,7 @@
 import logging

 import grpc.aio

-from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink
+from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport

 # pylint: disable=no-name-in-module
 from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
 # -----------------------------------------------------------------------------
-async def open_android_emulator_transport(spec):
+async def open_android_emulator_transport(spec: str | None) -> Transport:
     '''
     Open a transport connection to an Android emulator via its gRPC interface.
     The parameter string has this syntax:
@@ -66,7 +66,7 @@ async def open_android_emulator_transport(spec):
     # Parse the parameters
     mode = 'host'
     server_host = 'localhost'
-    server_port = 8554
+    server_port = '8554'
     if spec is not None:
         params = spec.split(',')
         for param in params:
@@ -82,6 +82,7 @@ async def open_android_emulator_transport(spec):
     logger.debug(f'connecting to gRPC server at {server_address}')
     channel = grpc.aio.insecure_channel(server_address)

+    service: EmulatedBluetoothServiceStub | VhciForwardingServiceStub
     if mode == 'host':
         # Connect as a host
         service = EmulatedBluetoothServiceStub(channel)


@@ -121,7 +121,9 @@ def publish_grpc_port(grpc_port) -> bool:

 # -----------------------------------------------------------------------------
-async def open_android_netsim_controller_transport(server_host, server_port):
+async def open_android_netsim_controller_transport(
+    server_host: str | None, server_port: int
+) -> Transport:
     if not server_port:
         raise ValueError('invalid port')
     if server_host == '_' or not server_host:


@@ -20,11 +20,12 @@ import contextlib
 import struct
 import asyncio
 import logging
-from typing import ContextManager
+import io
+from typing import ContextManager, Tuple, Optional, Protocol, Dict

-from .. import hci
-from ..colors import color
-from ..snoop import Snooper
+from bumble import hci
+from bumble.colors import color
+from bumble.snoop import Snooper

 # -----------------------------------------------------------------------------
@@ -36,7 +37,7 @@ logger = logging.getLogger(__name__)
 # Information needed to parse HCI packets with a generic parser:
 # For each packet type, the info represents:
 #   (length-size, length-offset, unpack-type)
-HCI_PACKET_INFO = {
+HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
     hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
     hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
     hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
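Each `(length-size, length-offset, unpack-type)` triple is enough to frame a packet generically: skip the type byte and the header prefix, unpack the body length, and compute the total size. A standalone sketch of that framing (the packet-type codes follow the Bluetooth UART transport layer; `frame_packet` is a hypothetical helper, not part of the module):

```python
import struct

# Mirrors the HCI_PACKET_INFO shape: (length-size, length-offset, unpack-type)
PACKET_INFO = {
    0x01: (1, 2, 'B'),  # HCI Command: 2-byte opcode, then 1-byte length
    0x02: (2, 2, 'H'),  # HCI ACL Data: 2-byte handle, then 2-byte length
    0x03: (1, 2, 'B'),  # HCI Synchronous Data: 2-byte handle, then 1-byte length
}

def frame_packet(buffer: bytes):
    """Split one complete HCI packet off the front of `buffer`."""
    length_size, length_offset, unpack_type = PACKET_INFO[buffer[0]]
    # The body length sits after the type byte and the header prefix
    (body_length,) = struct.unpack_from(unpack_type, buffer, 1 + length_offset)
    total = 1 + length_offset + length_size + body_length
    return buffer[:total], buffer[total:]

# An HCI Reset command: opcode 0x0C03, zero parameter bytes
packet, rest = frame_packet(bytes([0x01, 0x03, 0x0C, 0x00]) + b'extra')
```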
@@ -45,33 +46,54 @@ HCI_PACKET_INFO = {

 # -----------------------------------------------------------------------------
-class PacketPump:
-    '''
-    Pump HCI packets from a reader to a sink
-    '''
-
-    def __init__(self, reader, sink):
+# Errors
+# -----------------------------------------------------------------------------
+class TransportLostError(Exception):
+    """
+    The Transport has been lost/disconnected.
+    """
+
+
+# -----------------------------------------------------------------------------
+# Typing Protocols
+# -----------------------------------------------------------------------------
+class TransportSink(Protocol):
+    def on_packet(self, packet: bytes) -> None:
+        ...
+
+
+class TransportSource(Protocol):
+    terminated: asyncio.Future[None]
+
+    def set_packet_sink(self, sink: TransportSink) -> None:
+        ...
+
+
+# -----------------------------------------------------------------------------
+class PacketPump:
+    """
+    Pump HCI packets from a reader to a sink.
+    """
+
+    def __init__(self, reader: AsyncPacketReader, sink: TransportSink) -> None:
         self.reader = reader
         self.sink = sink

-    async def run(self):
+    async def run(self) -> None:
         while True:
             try:
-                # Get a packet from the source
-                packet = hci.HCI_Packet.from_bytes(await self.reader.next_packet())
-
                 # Deliver the packet to the sink
-                self.sink.on_packet(packet)
+                self.sink.on_packet(await self.reader.next_packet())
             except Exception as error:
                 logger.warning(f'!!! {error}')
 # -----------------------------------------------------------------------------
 class PacketParser:
-    '''
-    In-line parser that accepts data and emits 'on_packet' when a full packet has been
-    parsed
-    '''
+    """
+    In-line parser that accepts data and emits 'on_packet' when a full packet has been
+    parsed.
+    """

     # pylint: disable=attribute-defined-outside-init
@@ -79,18 +101,22 @@ class PacketParser:
     NEED_LENGTH = 1
     NEED_BODY = 2

-    def __init__(self, sink=None):
+    sink: Optional[TransportSink]
+    extended_packet_info: Dict[int, Tuple[int, int, str]]
+    packet_info: Optional[Tuple[int, int, str]] = None
+
+    def __init__(self, sink: Optional[TransportSink] = None) -> None:
         self.sink = sink
         self.extended_packet_info = {}
         self.reset()

-    def reset(self):
+    def reset(self) -> None:
         self.state = PacketParser.NEED_TYPE
         self.bytes_needed = 1
         self.packet = bytearray()
         self.packet_info = None

-    def feed_data(self, data):
+    def feed_data(self, data: bytes) -> None:
         data_offset = 0
         data_left = len(data)
         while data_left and self.bytes_needed:
@@ -111,6 +137,7 @@ class PacketParser:
                     self.state = PacketParser.NEED_LENGTH
                     self.bytes_needed = self.packet_info[0] + self.packet_info[1]
                 elif self.state == PacketParser.NEED_LENGTH:
+                    assert self.packet_info is not None
                     body_length = struct.unpack_from(
                         self.packet_info[2], self.packet, 1 + self.packet_info[1]
                     )[0]
@@ -128,20 +155,20 @@ class PacketParser:
             )
             self.reset()

-    def set_packet_sink(self, sink):
+    def set_packet_sink(self, sink: TransportSink) -> None:
         self.sink = sink
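The parser's state machine (NEED_TYPE → NEED_LENGTH → NEED_BODY) can be condensed into a buffer-based sketch that makes the same framing decisions `feed_data` does: accumulate bytes until a full header is available, unpack the body length, then emit the complete packet. `TinyParser` is illustrative, not the real class:

```python
import struct

# Same (length-size, length-offset, unpack-type) shape as HCI_PACKET_INFO
PACKET_INFO = {0x01: (1, 2, 'B'), 0x02: (2, 2, 'H'), 0x03: (1, 2, 'B')}

class TinyParser:
    """Minimal in-line parser: feed arbitrary chunks, collect whole packets."""

    def __init__(self):
        self.buffer = bytearray()
        self.packets = []

    def feed_data(self, data: bytes) -> None:
        self.buffer.extend(data)
        while self.buffer:
            length_size, length_offset, unpack_type = PACKET_INFO[self.buffer[0]]
            header_end = 1 + length_offset + length_size
            if len(self.buffer) < header_end:
                return  # need more bytes before the length field is readable
            (body_length,) = struct.unpack_from(
                unpack_type, self.buffer, 1 + length_offset
            )
            total = header_end + body_length
            if len(self.buffer) < total:
                return  # need more bytes for the body
            self.packets.append(bytes(self.buffer[:total]))
            del self.buffer[:total]

parser = TinyParser()
parser.feed_data(bytes([0x01, 0x03]))        # partial header: nothing emitted
parser.feed_data(bytes([0x0C, 0x00, 0x01]))  # completes one packet, starts another
```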
 # -----------------------------------------------------------------------------
 class PacketReader:
-    '''
-    Reader that reads HCI packets from a sync source
-    '''
+    """
+    Reader that reads HCI packets from a sync source.
+    """

-    def __init__(self, source):
+    def __init__(self, source: io.BufferedReader) -> None:
         self.source = source

-    def next_packet(self):
+    def next_packet(self) -> Optional[bytes]:
         # Get the packet type
         packet_type = self.source.read(1)
         if len(packet_type) != 1:
@@ -150,7 +177,7 @@ class PacketReader:
         # Get the packet info based on its type
         packet_info = HCI_PACKET_INFO.get(packet_type[0])
         if packet_info is None:
-            raise ValueError(f'invalid packet type {packet_type} found')
+            raise ValueError(f'invalid packet type {packet_type[0]} found')

         # Read the header (that includes the length)
         header_size = packet_info[0] + packet_info[1]
@@ -169,21 +196,21 @@ class PacketReader:

 # -----------------------------------------------------------------------------
 class AsyncPacketReader:
-    '''
-    Reader that reads HCI packets from an async source
-    '''
+    """
+    Reader that reads HCI packets from an async source.
+    """

-    def __init__(self, source):
+    def __init__(self, source: asyncio.StreamReader) -> None:
         self.source = source

-    async def next_packet(self):
+    async def next_packet(self) -> bytes:
         # Get the packet type
         packet_type = await self.source.readexactly(1)

         # Get the packet info based on its type
         packet_info = HCI_PACKET_INFO.get(packet_type[0])
         if packet_info is None:
-            raise ValueError(f'invalid packet type {packet_type} found')
+            raise ValueError(f'invalid packet type {packet_type[0]} found')

         # Read the header (that includes the length)
         header_size = packet_info[0] + packet_info[1]
@@ -198,15 +225,15 @@ class AsyncPacketReader:

 # -----------------------------------------------------------------------------
 class AsyncPipeSink:
-    '''
-    Sink that forwards packets asynchronously to another sink
-    '''
+    """
+    Sink that forwards packets asynchronously to another sink.
+    """

-    def __init__(self, sink):
+    def __init__(self, sink: TransportSink) -> None:
         self.sink = sink
         self.loop = asyncio.get_running_loop()

-    def on_packet(self, packet):
+    def on_packet(self, packet: bytes) -> None:
         self.loop.call_soon(self.sink.on_packet, packet)
@@ -216,35 +243,48 @@ class ParserSource:
     Base class designed to be subclassed by transport-specific source classes
     """

-    def __init__(self):
+    terminated: asyncio.Future[None]
+    parser: PacketParser
+
+    def __init__(self) -> None:
         self.parser = PacketParser()
         self.terminated = asyncio.get_running_loop().create_future()

-    def set_packet_sink(self, sink):
+    def set_packet_sink(self, sink: TransportSink) -> None:
         self.parser.set_packet_sink(sink)

-    async def wait_for_termination(self):
+    def on_transport_lost(self) -> None:
+        self.terminated.set_result(None)
+        if self.parser.sink:
+            if hasattr(self.parser.sink, 'on_transport_lost'):
+                self.parser.sink.on_transport_lost()
+
+    async def wait_for_termination(self) -> None:
+        """
+        Convenience method for backward compatibility. Prefer using the `terminated`
+        attribute instead.
+        """
         return await self.terminated

-    def close(self):
+    def close(self) -> None:
         pass


 # -----------------------------------------------------------------------------
 class StreamPacketSource(asyncio.Protocol, ParserSource):
-    def data_received(self, data):
+    def data_received(self, data: bytes) -> None:
         self.parser.feed_data(data)


 # -----------------------------------------------------------------------------
 class StreamPacketSink:
-    def __init__(self, transport):
+    def __init__(self, transport: asyncio.WriteTransport) -> None:
         self.transport = transport

-    def on_packet(self, packet):
+    def on_packet(self, packet: bytes) -> None:
         self.transport.write(packet)

-    def close(self):
+    def close(self) -> None:
         self.transport.close()
@@ -264,7 +304,7 @@ class Transport:
     ...
     """

-    def __init__(self, source, sink):
+    def __init__(self, source: TransportSource, sink: TransportSink) -> None:
         self.source = source
         self.sink = sink
@@ -278,19 +318,23 @@ class Transport:
         return iter((self.source, self.sink))

     async def close(self) -> None:
-        self.source.close()
-        self.sink.close()
+        if hasattr(self.source, 'close'):
+            self.source.close()
+        if hasattr(self.sink, 'close'):
+            self.sink.close()
 # -----------------------------------------------------------------------------
 class PumpedPacketSource(ParserSource):
-    def __init__(self, receive):
+    pump_task: Optional[asyncio.Task[None]]
+
+    def __init__(self, receive) -> None:
         super().__init__()
         self.receive_function = receive
         self.pump_task = None

-    def start(self):
-        async def pump_packets():
+    def start(self) -> None:
+        async def pump_packets() -> None:
             while True:
                 try:
                     packet = await self.receive_function()
@@ -300,12 +344,12 @@ class PumpedPacketSource(ParserSource):
                     break
                 except Exception as error:
                     logger.warning(f'exception while waiting for packet: {error}')
-                    self.terminated.set_result(error)
+                    self.terminated.set_exception(error)
                     break

         self.pump_task = asyncio.create_task(pump_packets())

-    def close(self):
+    def close(self) -> None:
         if self.pump_task:
             self.pump_task.cancel()
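The `set_result` → `set_exception` change above matters for consumers awaiting `terminated`: an orderly shutdown resolves the future, while a receive failure now propagates as a raised exception instead of being returned as a value. A standalone sketch of that pump-task pattern, using a hypothetical `receive` coroutine rather than the real class:

```python
import asyncio

async def pump(receive, on_packet, terminated: asyncio.Future) -> None:
    """Poll `receive` forever; report termination through the future."""
    while True:
        try:
            on_packet(await receive())
        except asyncio.CancelledError:
            terminated.set_result(None)       # orderly shutdown
            break
        except Exception as error:
            terminated.set_exception(error)   # failure surfaces when awaited
            break

async def main():
    packets = [b'one', b'two']

    async def receive():
        if packets:
            return packets.pop(0)
        raise ConnectionError('transport lost')

    received = []
    terminated = asyncio.get_running_loop().create_future()
    asyncio.create_task(pump(receive, received.append, terminated))
    try:
        await terminated                      # raises when the pump failed
    except ConnectionError:
        pass
    return received

received = asyncio.run(main())
```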
@@ -317,7 +361,7 @@ class PumpedPacketSink:
         self.packet_queue = asyncio.Queue()
         self.pump_task = None

-    def on_packet(self, packet):
+    def on_packet(self, packet: bytes) -> None:
         self.packet_queue.put_nowait(packet)

     def start(self):
@@ -342,15 +386,23 @@ class PumpedPacketSink:

 # -----------------------------------------------------------------------------
 class PumpedTransport(Transport):
-    def __init__(self, source, sink, close_function):
+    source: PumpedPacketSource
+    sink: PumpedPacketSink
+
+    def __init__(
+        self,
+        source: PumpedPacketSource,
+        sink: PumpedPacketSink,
+        close_function,
+    ) -> None:
         super().__init__(source, sink)
         self.close_function = close_function

-    def start(self):
+    def start(self) -> None:
         self.source.start()
         self.sink.start()

-    async def close(self):
+    async def close(self) -> None:
         await super().close()
         await self.close_function()
@@ -375,31 +427,38 @@ class SnoopingTransport(Transport):
         raise RuntimeError('unexpected code path')  # Satisfy the type checker

     class Source:
-        def __init__(self, source, snooper):
+        sink: TransportSink
+
+        def __init__(self, source: TransportSource, snooper: Snooper):
             self.source = source
             self.snooper = snooper
-            self.sink = None
+            self.terminated = source.terminated

-        def set_packet_sink(self, sink):
+        def set_packet_sink(self, sink: TransportSink) -> None:
             self.sink = sink
             self.source.set_packet_sink(self)

-        def on_packet(self, packet):
+        def on_packet(self, packet: bytes) -> None:
             self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
             if self.sink:
                 self.sink.on_packet(packet)

     class Sink:
-        def __init__(self, sink, snooper):
+        def __init__(self, sink: TransportSink, snooper: Snooper) -> None:
             self.sink = sink
             self.snooper = snooper

-        def on_packet(self, packet):
+        def on_packet(self, packet: bytes) -> None:
             self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
             if self.sink:
                 self.sink.on_packet(packet)

-    def __init__(self, transport, snooper, close_snooper=None):
+    def __init__(
+        self,
+        transport: Transport,
+        snooper: Snooper,
+        close_snooper=None,
+    ) -> None:
         super().__init__(
             self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
         )
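The `SnoopingTransport.Sink` wrapper above is a plain tee: record the packet, then forward it unchanged. A self-contained sketch of the same shape, using a list in place of a real `Snooper` (all names here are illustrative):

```python
class SnoopingSink:
    """Record every host-to-controller packet before forwarding it."""

    def __init__(self, sink, snoop_log: list):
        self.sink = sink
        self.snoop_log = snoop_log

    def on_packet(self, packet: bytes) -> None:
        self.snoop_log.append(packet)  # snoop first, so drops downstream still log
        if self.sink:
            self.sink.on_packet(packet)

log, delivered = [], []

class Recorder:
    def on_packet(self, packet: bytes) -> None:
        delivered.append(packet)

SnoopingSink(Recorder(), log).on_packet(b'\x01\x03\x0c\x00')
```

Because the wrapper implements the same `on_packet` interface as the sink it wraps, it can be slotted transparently between any transport source and sink, which is exactly how `SnoopingTransport` composes with `Transport`.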


@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)

 # -----------------------------------------------------------------------------
-async def open_file_transport(spec):
+async def open_file_transport(spec: str) -> Transport:
     '''
     Open a File transport (typically not for a real file, but for a PTY or other unix
     virtual files).


@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)

 # -----------------------------------------------------------------------------
-async def open_hci_socket_transport(spec):
+async def open_hci_socket_transport(spec: str | None) -> Transport:
     '''
     Open an HCI Socket (only available on some platforms).
     The parameter string is either empty (to use the first/default Bluetooth adapter)
@@ -47,7 +47,7 @@ async def open_hci_socket_transport(spec):
         hci_socket = socket.socket(
             socket.AF_BLUETOOTH,
             socket.SOCK_RAW | socket.SOCK_NONBLOCK,
-            socket.BTPROTO_HCI,
+            socket.BTPROTO_HCI,  # type: ignore
         )
     except AttributeError as error:
         # Not supported on this platform


@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)

 # -----------------------------------------------------------------------------
-async def open_pty_transport(spec):
+async def open_pty_transport(spec: str | None) -> Transport:
     '''
     Open a PTY transport.
     The parameter string may be empty, or a path name where a symbolic link


@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)

 # -----------------------------------------------------------------------------
-async def open_pyusb_transport(spec):
+async def open_pyusb_transport(spec: str) -> Transport:
     '''
     Open a USB transport. [Implementation based on PyUSB]
     The parameter string has this syntax:


@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)

 # -----------------------------------------------------------------------------
-async def open_serial_transport(spec):
+async def open_serial_transport(spec: str) -> Transport:
     '''
     Open a serial port transport.
     The parameter string has this syntax:


@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)

 # -----------------------------------------------------------------------------
-async def open_tcp_client_transport(spec):
+async def open_tcp_client_transport(spec: str) -> Transport:
     '''
     Open a TCP client transport.
     The parameter string has this syntax:
@@ -39,7 +39,7 @@ async def open_tcp_client_transport(spec):
     class TcpPacketSource(StreamPacketSource):
         def connection_lost(self, exc):
             logger.debug(f'connection lost: {exc}')
-            self.terminated.set_result(exc)
+            self.on_transport_lost()

     remote_host, remote_port = spec.split(':')
     tcp_transport, packet_source = await asyncio.get_running_loop().create_connection(


@@ -15,6 +15,7 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
+from __future__ import annotations
import asyncio
import logging
@@ -27,7 +28,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
-async def open_tcp_server_transport(spec):
+async def open_tcp_server_transport(spec: str) -> Transport:
'''
Open a TCP server transport.
The parameter string has this syntax:
@@ -42,7 +43,7 @@ async def open_tcp_server_transport(spec):
async def close(self):
await super().close()
-class TcpServerProtocol:
+class TcpServerProtocol(asyncio.BaseProtocol):
def __init__(self, packet_source, packet_sink):
self.packet_source = packet_source
self.packet_sink = packet_sink


@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
-async def open_udp_transport(spec):
+async def open_udp_transport(spec: str) -> Transport:
'''
Open a UDP transport.
The parameter string has this syntax:


@@ -60,7 +60,7 @@ def load_libusb():
usb1.loadLibrary(libusb_dll)
-async def open_usb_transport(spec):
+async def open_usb_transport(spec: str) -> Transport:
'''
Open a USB transport.
The moniker string has this syntax:


@@ -17,6 +17,7 @@
# -----------------------------------------------------------------------------
import logging
+from .common import Transport
from .file import open_file_transport
# -----------------------------------------------------------------------------
@@ -26,7 +27,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
-async def open_vhci_transport(spec):
+async def open_vhci_transport(spec: str | None) -> Transport:
'''
Open a VHCI transport (only available on some platforms).
The parameter string is either empty (to use the default VHCI device
@@ -42,15 +43,15 @@ async def open_vhci_transport(spec):
# Override the source's `data_received` method so that we can
# filter out the vendor packet that is received just after the
# initial open
-def vhci_data_received(data):
+def vhci_data_received(data: bytes) -> None:
if len(data) > 0 and data[0] == HCI_VENDOR_PKT:
if len(data) == 4:
hci_index = data[2] << 8 | data[3]
logger.info(f'HCI index {hci_index}')
else:
-transport.source.parser.feed_data(data)
+transport.source.parser.feed_data(data)  # type: ignore
-transport.source.data_received = vhci_data_received
+transport.source.data_received = vhci_data_received  # type: ignore
# Write the initial config
transport.sink.on_packet(bytes([HCI_VENDOR_PKT, HCI_BREDR]))
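The override above swallows the 4-byte vendor packet that the VHCI device emits right after opening, decoding the controller's HCI index from its last two bytes (big-endian). A minimal standalone sketch of that dispatch, assuming a plain `feed` callback in place of Bumble's parser and the conventional 0xFF vendor packet type:

```python
HCI_VENDOR_PKT = 0xFF  # conventional vendor packet indicator (assumed here)

def make_data_received(feed):
    """Build a data_received callback that filters the initial vendor packet."""
    def data_received(data: bytes):
        if len(data) > 0 and data[0] == HCI_VENDOR_PKT:
            if len(data) == 4:
                # Bytes 2..3 carry the controller's HCI index, big-endian
                return ('hci_index', data[2] << 8 | data[3])
            return None  # other vendor packets are dropped
        feed(data)  # regular HCI traffic goes to the parser
        return ('fed', data)
    return data_received
```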


@@ -16,9 +16,9 @@
# Imports
# -----------------------------------------------------------------------------
import logging
-import websockets
+import websockets.client
-from .common import PumpedPacketSource, PumpedPacketSink, PumpedTransport
+from .common import PumpedPacketSource, PumpedPacketSink, PumpedTransport, Transport
# -----------------------------------------------------------------------------
# Logging
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
-async def open_ws_client_transport(spec):
+async def open_ws_client_transport(spec: str) -> Transport:
'''
Open a WebSocket client transport.
The parameter string has this syntax:
@@ -38,7 +38,7 @@ async def open_ws_client_transport(spec):
remote_host, remote_port = spec.split(':')
uri = f'ws://{remote_host}:{remote_port}'
-websocket = await websockets.connect(uri)
+websocket = await websockets.client.connect(uri)
transport = PumpedTransport(
PumpedPacketSource(websocket.recv),


@@ -15,7 +15,6 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
-import asyncio
import logging
import websockets
@@ -28,7 +27,7 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
-async def open_ws_server_transport(spec):
+async def open_ws_server_transport(spec: str) -> Transport:
'''
Open a WebSocket server transport.
The parameter string has this syntax:
@@ -43,7 +42,7 @@ async def open_ws_server_transport(spec):
def __init__(self):
source = ParserSource()
sink = PumpedPacketSink(self.send_packet)
-self.connection = asyncio.get_running_loop().create_future()
+self.connection = None
self.server = None
super().__init__(source, sink)
@@ -63,7 +62,7 @@ async def open_ws_server_transport(spec):
f'new connection on {connection.local_address} '
f'from {connection.remote_address}'
)
-self.connection.set_result(connection)
+self.connection = connection
# pylint: disable=no-member
try:
async for packet in connection:
@@ -74,12 +73,14 @@ async def open_ws_server_transport(spec):
except websockets.WebSocketException as error:
logger.debug(f'exception while receiving packet: {error}')
-# Wait for a new connection
-self.connection = asyncio.get_running_loop().create_future()
+# We're now disconnected
+self.connection = None
async def send_packet(self, packet):
-connection = await self.connection
-return await connection.send(packet)
+if self.connection is None:
+logger.debug('no connection, dropping packet')
+return
+return await self.connection.send(packet)
local_host, local_port = spec.split(':')
transport = WsServerTransport()

bumble/vendor/__init__.py (new file, empty)

bumble/vendor/android/__init__.py (new file, empty)

bumble/vendor/android/hci.py (new file, 318 lines)

@@ -0,0 +1,318 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
from bumble.hci import (
name_or_number,
hci_vendor_command_op_code,
Address,
HCI_Constant,
HCI_Object,
HCI_Command,
HCI_Vendor_Event,
STATUS_SPEC,
)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Android Vendor Specific Commands and Events.
# Only a subset of the commands are implemented here currently.
#
# pylint: disable-next=line-too-long
# See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#chip-capabilities-and-configuration
HCI_LE_GET_VENDOR_CAPABILITIES_COMMAND = hci_vendor_command_op_code(0x153)
HCI_LE_APCF_COMMAND = hci_vendor_command_op_code(0x157)
HCI_GET_CONTROLLER_ACTIVITY_ENERGY_INFO_COMMAND = hci_vendor_command_op_code(0x159)
HCI_A2DP_HARDWARE_OFFLOAD_COMMAND = hci_vendor_command_op_code(0x15D)
HCI_BLUETOOTH_QUALITY_REPORT_COMMAND = hci_vendor_command_op_code(0x15E)
HCI_DYNAMIC_AUDIO_BUFFER_COMMAND = hci_vendor_command_op_code(0x15F)
HCI_BLUETOOTH_QUALITY_REPORT_EVENT = 0x58
HCI_Command.register_commands(globals())
HCI_Vendor_Event.register_subevents(globals())
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('max_advt_instances', 1),
('offloaded_resolution_of_private_address', 1),
('total_scan_results_storage', 2),
('max_irk_list_sz', 1),
('filtering_support', 1),
('max_filter', 1),
('activity_energy_info_support', 1),
('version_supported', 2),
('total_num_of_advt_tracked', 2),
('extended_scan_support', 1),
('debug_logging_supported', 1),
('le_address_generation_offloading_support', 1),
('a2dp_source_offload_capability_mask', 4),
('bluetooth_quality_report_support', 1),
('dynamic_audio_buffer_support', 4),
]
)
class HCI_LE_Get_Vendor_Capabilities_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#vendor-specific-capabilities
'''
@classmethod
def parse_return_parameters(cls, parameters):
# There are many versions of this data structure, so we need to parse until
# there are no more bytes to parse, and leave any un-parsed parameters set to
# None (older versions)
nones = {field: None for field, _ in cls.return_parameters_fields}
return_parameters = HCI_Object(cls.return_parameters_fields, **nones)
try:
offset = 0
for field in cls.return_parameters_fields:
field_name, field_type = field
field_value, field_size = HCI_Object.parse_field(
parameters, offset, field_type
)
setattr(return_parameters, field_name, field_value)
offset += field_size
except struct.error:
pass
return return_parameters
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_LE_APCF_Command.opcode_name(x),
},
),
('payload', '*'),
],
return_parameters_fields=[
('status', STATUS_SPEC),
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_LE_APCF_Command.opcode_name(x),
},
),
('payload', '*'),
],
)
class HCI_LE_APCF_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_apcf_command
NOTE: the subcommand-specific payloads are left as opaque byte arrays in this
implementation. A future enhancement may define subcommand-specific data structures.
'''
# APCF Subcommands
# TODO: use the OpenIntEnum class (when upcoming PR is merged)
APCF_ENABLE = 0x00
APCF_SET_FILTERING_PARAMETERS = 0x01
APCF_BROADCASTER_ADDRESS = 0x02
APCF_SERVICE_UUID = 0x03
APCF_SERVICE_SOLICITATION_UUID = 0x04
APCF_LOCAL_NAME = 0x05
APCF_MANUFACTURER_DATA = 0x06
APCF_SERVICE_DATA = 0x07
APCF_TRANSPORT_DISCOVERY_SERVICE = 0x08
APCF_AD_TYPE_FILTER = 0x09
APCF_READ_EXTENDED_FEATURES = 0xFF
OPCODE_NAMES = {
APCF_ENABLE: 'APCF_ENABLE',
APCF_SET_FILTERING_PARAMETERS: 'APCF_SET_FILTERING_PARAMETERS',
APCF_BROADCASTER_ADDRESS: 'APCF_BROADCASTER_ADDRESS',
APCF_SERVICE_UUID: 'APCF_SERVICE_UUID',
APCF_SERVICE_SOLICITATION_UUID: 'APCF_SERVICE_SOLICITATION_UUID',
APCF_LOCAL_NAME: 'APCF_LOCAL_NAME',
APCF_MANUFACTURER_DATA: 'APCF_MANUFACTURER_DATA',
APCF_SERVICE_DATA: 'APCF_SERVICE_DATA',
APCF_TRANSPORT_DISCOVERY_SERVICE: 'APCF_TRANSPORT_DISCOVERY_SERVICE',
APCF_AD_TYPE_FILTER: 'APCF_AD_TYPE_FILTER',
APCF_READ_EXTENDED_FEATURES: 'APCF_READ_EXTENDED_FEATURES',
}
@classmethod
def opcode_name(cls, opcode):
return name_or_number(cls.OPCODE_NAMES, opcode)
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('total_tx_time_ms', 4),
('total_rx_time_ms', 4),
('total_idle_time_ms', 4),
('total_energy_used', 4),
],
)
class HCI_Get_Controller_Activity_Energy_Info_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#le_get_controller_activity_energy_info
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_A2DP_Hardware_Offload_Command.opcode_name(x),
},
),
('payload', '*'),
],
return_parameters_fields=[
('status', STATUS_SPEC),
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_A2DP_Hardware_Offload_Command.opcode_name(x),
},
),
('payload', '*'),
],
)
class HCI_A2DP_Hardware_Offload_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#a2dp-hardware-offload-support
NOTE: the subcommand-specific payloads are left as opaque byte arrays in this
implementation. A future enhancement may define subcommand-specific data structures.
'''
# A2DP Hardware Offload Subcommands
# TODO: use the OpenIntEnum class (when upcoming PR is merged)
START_A2DP_OFFLOAD = 0x01
STOP_A2DP_OFFLOAD = 0x02
OPCODE_NAMES = {
START_A2DP_OFFLOAD: 'START_A2DP_OFFLOAD',
STOP_A2DP_OFFLOAD: 'STOP_A2DP_OFFLOAD',
}
@classmethod
def opcode_name(cls, opcode):
return name_or_number(cls.OPCODE_NAMES, opcode)
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_Dynamic_Audio_Buffer_Command.opcode_name(x),
},
),
('payload', '*'),
],
return_parameters_fields=[
('status', STATUS_SPEC),
(
'opcode',
{
'size': 1,
'mapper': lambda x: HCI_Dynamic_Audio_Buffer_Command.opcode_name(x),
},
),
('payload', '*'),
],
)
class HCI_Dynamic_Audio_Buffer_Command(HCI_Command):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#dynamic-audio-buffer-command
NOTE: the subcommand-specific payloads are left as opaque byte arrays in this
implementation. A future enhancement may define subcommand-specific data structures.
'''
# Dynamic Audio Buffer Subcommands
# TODO: use the OpenIntEnum class (when upcoming PR is merged)
GET_AUDIO_BUFFER_TIME_CAPABILITY = 0x01
OPCODE_NAMES = {
GET_AUDIO_BUFFER_TIME_CAPABILITY: 'GET_AUDIO_BUFFER_TIME_CAPABILITY',
}
@classmethod
def opcode_name(cls, opcode):
return name_or_number(cls.OPCODE_NAMES, opcode)
# -----------------------------------------------------------------------------
@HCI_Vendor_Event.event(
fields=[
('quality_report_id', 1),
('packet_types', 1),
('connection_handle', 2),
('connection_role', {'size': 1, 'mapper': HCI_Constant.role_name}),
('tx_power_level', -1),
('rssi', -1),
('snr', 1),
('unused_afh_channel_count', 1),
('afh_select_unideal_channel_count', 1),
('lsto', 2),
('connection_piconet_clock', 4),
('retransmission_count', 4),
('no_rx_count', 4),
('nak_count', 4),
('last_tx_ack_timestamp', 4),
('flow_off_count', 4),
('last_flow_on_timestamp', 4),
('buffer_overflow_bytes', 4),
('buffer_underflow_bytes', 4),
('bdaddr', Address.parse_address),
('cal_failed_item_count', 1),
('tx_total_packets', 4),
('tx_unacked_packets', 4),
('tx_flushed_packets', 4),
('tx_last_subevent_packets', 4),
('crc_error_packets', 4),
('rx_duplicate_packets', 4),
('vendor_specific_parameters', '*'),
]
)
class HCI_Bluetooth_Quality_Report_Event(HCI_Vendor_Event):
# pylint: disable=line-too-long
'''
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#bluetooth-quality-report-sub-event
'''
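The `parse_return_parameters` override above handles the many versions of the vendor capabilities structure by parsing fields until the buffer runs out and leaving the rest as `None`. The same pattern can be sketched standalone (the field list is a shortened stand-in for the real layout):

```python
import struct

# (name, size in bytes) for a shortened stand-in of the capabilities layout;
# newer controller versions append fields at the end.
FIELDS = [
    ('status', 1),
    ('max_advt_instances', 1),
    ('offloaded_resolution_of_private_address', 1),
    ('total_scan_results_storage', 2),
]

def parse_versioned(parameters: bytes) -> dict:
    # Start with every field as None, then fill in what the buffer provides
    result = {name: None for name, _ in FIELDS}
    offset = 0
    try:
        for name, size in FIELDS:
            (value,) = struct.unpack_from({1: '<B', 2: '<H'}[size], parameters, offset)
            result[name] = value
            offset += size
    except struct.error:
        pass  # older version: remaining fields stay None
    return result
```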

bumble/vendor/zephyr/__init__.py (new file, empty)

bumble/vendor/zephyr/hci.py (new file, 88 lines)

@@ -0,0 +1,88 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from bumble.hci import (
hci_vendor_command_op_code,
HCI_Command,
STATUS_SPEC,
)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# Zephyr RTOS Vendor Specific Commands and Events.
# Only a subset of the commands are implemented here currently.
#
# pylint: disable-next=line-too-long
# See https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
HCI_WRITE_TX_POWER_LEVEL_COMMAND = hci_vendor_command_op_code(0x000E)
HCI_READ_TX_POWER_LEVEL_COMMAND = hci_vendor_command_op_code(0x000F)
HCI_Command.register_commands(globals())
# -----------------------------------------------------------------------------
class TX_Power_Level_Command:
'''
Base class for read and write TX power level HCI commands
'''
TX_POWER_HANDLE_TYPE_ADV = 0x00
TX_POWER_HANDLE_TYPE_SCAN = 0x01
TX_POWER_HANDLE_TYPE_CONN = 0x02
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[('handle_type', 1), ('connection_handle', 2), ('tx_power_level', -1)],
return_parameters_fields=[
('status', STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('selected_tx_power_level', -1),
],
)
class HCI_Write_Tx_Power_Level_Command(HCI_Command, TX_Power_Level_Command):
'''
Write TX power level. See BT_HCI_OP_VS_WRITE_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
Power level is in dB. Connection handle for TX_POWER_HANDLE_TYPE_ADV and
TX_POWER_HANDLE_TYPE_SCAN should be zero.
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
fields=[('handle_type', 1), ('connection_handle', 2)],
return_parameters_fields=[
('status', STATUS_SPEC),
('handle_type', 1),
('connection_handle', 2),
('tx_power_level', -1),
],
)
class HCI_Read_Tx_Power_Level_Command(HCI_Command, TX_Power_Level_Command):
'''
Read TX power level. See BT_HCI_OP_VS_READ_TX_POWER_LEVEL in
https://github.com/zephyrproject-rtos/zephyr/blob/main/include/zephyr/bluetooth/hci_vs.h
Power level is in dB. Connection handle for TX_POWER_HANDLE_TYPE_ADV and
TX_POWER_HANDLE_TYPE_SCAN should be zero.
'''
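The two opcodes above combine a vendor OCF with the vendor OGF (0x3F). Assuming `hci_vendor_command_op_code` does the standard HCI opcode packing, the arithmetic looks like this:

```python
HCI_VENDOR_OGF = 0x3F  # OGF reserved for vendor-specific commands

def vendor_op_code(ocf: int) -> int:
    # An HCI opcode is a 6-bit OGF in the top bits over a 10-bit OCF
    return (HCI_VENDOR_OGF << 10) | ocf

WRITE_TX_POWER_OP_CODE = vendor_op_code(0x000E)  # 0xFC0E
READ_TX_POWER_OP_CODE = vendor_op_code(0x000F)   # 0xFC0F
```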


@@ -64,6 +64,7 @@ nav:
- Linux: platforms/linux.md
- Windows: platforms/windows.md
- Android: platforms/android.md
+- Zephyr: platforms/zephyr.md
- Examples:
- Overview: examples/index.md

Binary file not shown.


@@ -9,3 +9,4 @@ For platform-specific information, see the following pages:
* :material-linux: Linux - see the [Linux platform page](linux.md)
* :material-microsoft-windows: Windows - see the [Windows platform page](windows.md)
* :material-android: Android - see the [Android platform page](android.md)
+* :material-memory: Zephyr - see the [Zephyr platform page](zephyr.md)


@@ -0,0 +1,51 @@
:material-memory: ZEPHYR PLATFORM
=================================
Set TX Power on nRF52840
------------------------
The Nordic nRF52840 supports Zephyr's vendor-specific HCI command for setting TX
power during advertising, connection, or scanning. With the example
[HCI USB](https://docs.zephyrproject.org/latest/samples/bluetooth/hci_usb/README.html)
application, an
[nRF52840 dongle](https://www.nordicsemi.com/Products/Development-hardware/nRF52840-Dongle)
can be used as a Bumble controller.
To add dynamic TX power support to the HCI USB application, add the following to
`zephyr/samples/bluetooth/hci_usb/prj.conf` and build.
```
CONFIG_BT_CTLR_ADVANCED_FEATURES=y
CONFIG_BT_CTLR_CONN_RSSI=y
CONFIG_BT_CTLR_TX_PWR_DYNAMIC_CONTROL=y
```
Alternatively, a prebuilt firmware application can be downloaded here:
[hci_usb.zip](../downloads/zephyr/hci_usb.zip).
Put the nRF52840 dongle into bootloader mode by pressing the RESET button. The
LED should pulse red. Load the firmware application with the `nrfutil` tool:
```
nrfutil dfu usb-serial -pkg hci_usb.zip -p /dev/ttyACM0
```
The vendor specific HCI commands to read and write TX power are defined in
`bumble/vendor/zephyr/hci.py` and may be used as such:
```python
from bumble.vendor.zephyr.hci import HCI_Write_Tx_Power_Level_Command
from bumble.hci import HCI_SUCCESS
# set advertising power to -4 dB
response = await host.send_command(
HCI_Write_Tx_Power_Level_Command(
handle_type=HCI_Write_Tx_Power_Level_Command.TX_POWER_HANDLE_TYPE_ADV,
connection_handle=0,
tx_power_level=-4,
)
)
if response.return_parameters.status == HCI_SUCCESS:
print(f"TX power set to {response.return_parameters.selected_tx_power_level}")
```


@@ -3,7 +3,7 @@ channels:
- defaults
- conda-forge
dependencies:
-- pip=20
+- pip=23
- python=3.8
- pip:
- --editable .[development,documentation,test]


@@ -30,7 +30,7 @@ from bumble.core import (
BT_RFCOMM_PROTOCOL_ID,
BT_BR_EDR_TRANSPORT,
)
-from bumble.rfcomm import Client
+from bumble import rfcomm, hfp
from bumble.sdp import (
Client as SDP_Client,
DataElement,
@@ -39,7 +39,9 @@ from bumble.sdp import (
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
-from bumble.hfp import HfpProtocol
+logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
@@ -181,7 +183,7 @@ async def main():
# Create a client and start it
print('@@@ Starting to RFCOMM client...')
-rfcomm_client = Client(device, connection)
+rfcomm_client = rfcomm.Client(device, connection)
rfcomm_mux = await rfcomm_client.start()
print('@@@ Started')
@@ -196,7 +198,7 @@ async def main():
return
# Protocol loop (just for testing at this point)
-protocol = HfpProtocol(session)
+protocol = hfp.HfpProtocol(session)
while True:
line = await protocol.next_line()


@@ -21,82 +21,22 @@ import os
import logging
import json
import websockets
+from typing import Optional
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.rfcomm import Server as RfcommServer
-from bumble.sdp import (
-DataElement,
-ServiceAttribute,
-SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
-SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
-SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
-SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
-)
-from bumble.core import (
-BT_GENERIC_AUDIO_SERVICE,
-BT_HANDSFREE_SERVICE,
-BT_L2CAP_PROTOCOL_ID,
-BT_RFCOMM_PROTOCOL_ID,
-)
-from bumble.hfp import HfpProtocol
+from bumble import hfp
+from bumble.hfp import HfProtocol
-# -----------------------------------------------------------------------------
-def make_sdp_records(rfcomm_channel):
-return {
-0x00010001: [
-ServiceAttribute(
-SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
-DataElement.unsigned_integer_32(0x00010001),
-),
-ServiceAttribute(
-SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
-DataElement.sequence(
-[
-DataElement.uuid(BT_HANDSFREE_SERVICE),
-DataElement.uuid(BT_GENERIC_AUDIO_SERVICE),
-]
-),
-),
-ServiceAttribute(
-SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
-DataElement.sequence(
-[
-DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
-DataElement.sequence(
-[
-DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
-DataElement.unsigned_integer_8(rfcomm_channel),
-]
-),
-]
-),
-),
-ServiceAttribute(
-SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
-DataElement.sequence(
-[
-DataElement.sequence(
-[
-DataElement.uuid(BT_HANDSFREE_SERVICE),
-DataElement.unsigned_integer_16(0x0105),
-]
-)
-]
-),
-),
-]
-}
# -----------------------------------------------------------------------------
class UiServer:
-protocol = None
+protocol: Optional[HfProtocol] = None
async def start(self):
-# Start a Websocket server to receive events from a web page
+"""Start a Websocket server to receive events from a web page."""
async def serve(websocket, _path):
while True:
try:
@@ -107,7 +47,7 @@ class UiServer:
message_type = parsed['type']
if message_type == 'at_command':
if self.protocol is not None:
-self.protocol.send_command_line(parsed['command'])
+await self.protocol.execute_command(parsed['command'])
except websockets.exceptions.ConnectionClosedOK:
pass
@@ -117,19 +57,11 @@ class UiServer:
# -----------------------------------------------------------------------------
-async def protocol_loop(protocol):
-await protocol.initialize_service()
-while True:
-await (protocol.next_line())
-# -----------------------------------------------------------------------------
-def on_dlc(dlc):
+def on_dlc(dlc, configuration: hfp.Configuration):
print('*** DLC connected', dlc)
-protocol = HfpProtocol(dlc)
+protocol = HfProtocol(dlc, configuration)
UiServer.protocol = protocol
-asyncio.create_task(protocol_loop(protocol))
+asyncio.create_task(protocol.run())
# -----------------------------------------------------------------------------
@@ -143,6 +75,27 @@ async def main():
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
print('<<< connected')
+# Hands-Free profile configuration.
+# TODO: load configuration from file.
+configuration = hfp.Configuration(
+supported_hf_features=[
+hfp.HfFeature.THREE_WAY_CALLING,
+hfp.HfFeature.REMOTE_VOLUME_CONTROL,
+hfp.HfFeature.ENHANCED_CALL_STATUS,
+hfp.HfFeature.ENHANCED_CALL_CONTROL,
+hfp.HfFeature.CODEC_NEGOTIATION,
+hfp.HfFeature.HF_INDICATORS,
+hfp.HfFeature.ESCO_S4_SETTINGS_SUPPORTED,
+],
+supported_hf_indicators=[
+hfp.HfIndicator.BATTERY_LEVEL,
+],
+supported_audio_codecs=[
+hfp.AudioCodec.CVSD,
+hfp.AudioCodec.MSBC,
+],
+)
# Create a device
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
device.classic_enabled = True
@@ -151,11 +104,13 @@ async def main():
rfcomm_server = RfcommServer(device)
# Listen for incoming DLC connections
-channel_number = rfcomm_server.listen(on_dlc)
+channel_number = rfcomm_server.listen(lambda dlc: on_dlc(dlc, configuration))
print(f'### Listening for connection on channel {channel_number}')
# Advertise the HFP RFComm channel in the SDP
-device.sdp_service_records = make_sdp_records(channel_number)
+device.sdp_service_records = {
+0x00010001: hfp.sdp_records(0x00010001, channel_number, configuration)
+}
# Let's go!
await device.power_on()


@@ -20,83 +20,109 @@ import sys
import os
import logging
+from bumble.core import UUID
from bumble.device import Device
from bumble.transport import open_transport_or_link
-from bumble.core import BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID, UUID
from bumble.rfcomm import Server
-from bumble.sdp import (
-DataElement,
-ServiceAttribute,
-SDP_PUBLIC_BROWSE_ROOT,
-SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
-SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
-SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
-SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
-)
+from bumble.utils import AsyncRunner
+from bumble.rfcomm import make_service_sdp_records
# -----------------------------------------------------------------------------
-def sdp_records(channel):
+def sdp_records(channel, uuid):
+service_record_handle = 0x00010001
return {
-0x00010001: [
-ServiceAttribute(
-SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
-DataElement.unsigned_integer_32(0x00010001),
-),
-ServiceAttribute(
-SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
-DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
-),
-ServiceAttribute(
-SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
-DataElement.sequence(
-[DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))]
-),
-),
-ServiceAttribute(
-SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
-DataElement.sequence(
-[
-DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
-DataElement.sequence(
-[
-DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
-DataElement.unsigned_integer_8(channel),
-]
-),
-]
-),
-),
-]
+service_record_handle: make_service_sdp_records(
+service_record_handle, channel, UUID(uuid)
+)
}
# -----------------------------------------------------------------------------
def on_rfcomm_session(rfcomm_session, tcp_server):
    print('*** RFComm session connected', rfcomm_session)
    tcp_server.attach_session(rfcomm_session)


# -----------------------------------------------------------------------------
class TcpServerProtocol(asyncio.Protocol):
    def __init__(self, server):
        self.server = server

    def connection_made(self, transport):
        peer_name = transport.get_extra_info('peername')
        print(f'<<< TCP Server: connection from {peer_name}')
        if self.server:
            self.server.tcp_transport = transport
        else:
            transport.close()

    def connection_lost(self, exc):
        print('<<< TCP Server: connection lost')
        if self.server:
            self.server.tcp_transport = None

    def data_received(self, data):
        print(f'<<< TCP Server: data received: {len(data)} bytes - {data.hex()}')
        if self.server:
            self.server.tcp_data_received(data)


# -----------------------------------------------------------------------------
class TcpServer:
    def __init__(self, port):
        self.rfcomm_session = None
        self.tcp_transport = None
        AsyncRunner.spawn(self.run(port))

    def attach_session(self, rfcomm_session):
        if self.rfcomm_session:
            self.rfcomm_session.sink = None
        self.rfcomm_session = rfcomm_session
        rfcomm_session.sink = self.rfcomm_data_received

    def rfcomm_data_received(self, data):
        print(f'<<< RFCOMM Data: {data.hex()}')
        if self.tcp_transport:
            self.tcp_transport.write(data)
        else:
            print('!!! no TCP connection, dropping data')

    def tcp_data_received(self, data):
        if self.rfcomm_session:
            self.rfcomm_session.write(data)
        else:
            print('!!! no RFComm session, dropping data')

    async def run(self, port):
        print(f'$$$ Starting TCP server on port {port}')
        server = await asyncio.get_running_loop().create_server(
            lambda: TcpServerProtocol(self), '127.0.0.1', port
        )
        async with server:
            await server.serve_forever()
# -----------------------------------------------------------------------------
async def main():
    if len(sys.argv) < 4:
        print(
            'Usage: run_rfcomm_server.py <device-config> <transport-spec> '
            '<tcp-port> [<uuid>]'
        )
        print('example: run_rfcomm_server.py classic2.json usb:0 8888')
        return

    tcp_port = int(sys.argv[3])
    if len(sys.argv) >= 5:
        uuid = sys.argv[4]
    else:
        uuid = 'E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'

    print('<<< connecting to HCI...')
    async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
        print('<<< connected')

@@ -105,15 +131,20 @@ async def main():
        device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
        device.classic_enabled = True

        # Create a TCP server
        tcp_server = TcpServer(tcp_port)

        # Create and register an RFComm server
        rfcomm_server = Server(device)

        # Listen for incoming DLC connections
        channel_number = rfcomm_server.listen(
            lambda session: on_rfcomm_session(session, tcp_server)
        )
        print(f'### Listening for RFComm connections on channel {channel_number}')

        # Setup the SDP to advertise this channel
        device.sdp_service_records = sdp_records(channel_number, uuid)

        # Start the controller
        await device.power_on()
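The `TcpServerProtocol`/`create_server` wiring above is standard asyncio. A self-contained echo variant of the same pattern (plain stdlib, no Bumble), showing how the protocol factory is invoked once per incoming connection:

```python
import asyncio

class EchoProtocol(asyncio.Protocol):
    # Same shape as TcpServerProtocol above, minus the RFComm bridging.
    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data):
        # Echo everything back to the client.
        self.transport.write(data)

async def serve_once():
    loop = asyncio.get_running_loop()
    # Port 0 asks the OS for any free port.
    server = await loop.create_server(EchoProtocol, '127.0.0.1', 0)
    port = server.sockets[0].getsockname()[1]

    reader, writer = await asyncio.open_connection('127.0.0.1', port)
    writer.write(b'ping')
    await writer.drain()
    echoed = await reader.readexactly(4)

    writer.close()
    await writer.wait_closed()
    server.close()
    await server.wait_closed()
    return echoed

echoed = asyncio.run(serve_once())  # b'ping'
```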

rust/CHANGELOG.md (new file)

@@ -0,0 +1,7 @@
# Next

- Code-gen company ID table

# 0.1.0

- Initial release

rust/Cargo.lock (generated; diff suppressed because it is too large)

rust/Cargo.toml

@@ -10,12 +10,12 @@ documentation = "https://docs.rs/crate/bumble"
authors = ["Marshall Pierce <marshallpierce@google.com>"]
keywords = ["bluetooth", "ble"]
categories = ["api-bindings", "network-programming"]
rust-version = "1.70.0"

[dependencies]
pyo3 = { version = "0.18.3", features = ["macros"] }
pyo3-asyncio = { version = "0.18.0", features = ["tokio-runtime"] }
tokio = { version = "1.28.2", features = ["macros", "signal"] }
nom = "7.1.3"
strum = "0.25.0"
strum_macros = "0.25.0"
@@ -24,6 +24,17 @@ itertools = "0.11.0"
lazy_static = "1.4.0"
thiserror = "1.0.41"

# CLI
anyhow = { version = "1.0.71", optional = true }
clap = { version = "4.3.3", features = ["derive"], optional = true }
directories = { version = "5.0.1", optional = true }
env_logger = { version = "0.10.0", optional = true }
futures = { version = "0.3.28", optional = true }
log = { version = "0.4.19", optional = true }
owo-colors = { version = "3.5.0", optional = true }
reqwest = { version = "0.11.20", features = ["blocking"], optional = true }
rusb = { version = "0.9.2", optional = true }

[dev-dependencies]
tokio = { version = "1.28.2", features = ["full"] }
tempfile = "3.6.0"
@@ -31,12 +42,25 @@ nix = "0.26.2"
anyhow = "1.0.71"
pyo3 = { version = "0.18.3", features = ["macros", "anyhow"] }
pyo3-asyncio = { version = "0.18.0", features = ["tokio-runtime", "attributes", "testing"] }
rusb = "0.9.2"
rand = "0.8.5"
clap = { version = "4.3.3", features = ["derive"] }
owo-colors = "3.5.0"
log = "0.4.19"
env_logger = "0.10.0"

[package.metadata.docs.rs]
rustdoc-args = ["--generate-link-to-definition"]

[[bin]]
name = "gen-assigned-numbers"
path = "tools/gen_assigned_numbers.rs"
required-features = ["bumble-codegen"]

[[bin]]
name = "bumble"
path = "src/main.rs"
required-features = ["bumble-tools"]

# test entry point that uses pyo3_asyncio's test harness
[[test]]
@@ -46,4 +70,8 @@ harness = false

[features]
anyhow = ["pyo3/anyhow"]
pyo3-asyncio-attributes = ["pyo3-asyncio/attributes"]
bumble-codegen = ["dep:anyhow"]
# separate feature for CLI so that dependencies don't spend time building these
bumble-tools = ["dep:clap", "anyhow", "dep:anyhow", "dep:directories", "pyo3-asyncio-attributes", "dep:owo-colors", "dep:reqwest", "dep:rusb", "dep:log", "dep:env_logger", "dep:futures"]
default = []

rust/README.md

@@ -5,7 +5,8 @@ Rust wrappers around the [Bumble](https://github.com/google/bumble) Python API.
Method calls are mapped to the equivalent Python, and return types adapted where
relevant.

See the CLI in `src/main.rs` or the `examples` directory for how to use the
Bumble API.

# Usage
@@ -27,6 +28,15 @@ PYTHONPATH=..:~/.virtualenvs/bumble/lib/python3.10/site-packages/ \
Run the corresponding `battery_server` Python example, and launch an emulator in
Android Studio (currently, Canary is required) to run netsim.

# CLI

Explore the available subcommands:

```
PYTHONPATH=..:[virtualenv site-packages] \
    cargo run --features bumble-tools --bin bumble -- --help
```

# Development

Run the tests:
@@ -39,4 +49,18 @@ Check lints:
```
cargo clippy --all-targets
```

## Code gen

To have the fastest startup while keeping the build simple, code gen for
assigned numbers is done with the `gen_assigned_numbers` tool. It should
be re-run whenever the Python assigned numbers are changed. To ensure that the
generated code is kept up to date, the Rust data is compared to the Python
in tests at `pytests/assigned_numbers.rs`.

To regenerate the assigned number tables based on the Python codebase:

```
PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features bumble-codegen
```

rust/pytests/assigned_numbers.rs

@@ -0,0 +1,30 @@
use bumble::wrapper::{self, core::Uuid16};
use pyo3::{intern, prelude::*, types::PyDict};
use std::collections;

#[pyo3_asyncio::tokio::test]
async fn company_ids_matches_python() -> PyResult<()> {
    let ids_from_python = Python::with_gil(|py| {
        PyModule::import(py, intern!(py, "bumble.company_ids"))?
            .getattr(intern!(py, "COMPANY_IDENTIFIERS"))?
            .downcast::<PyDict>()?
            .into_iter()
            .map(|(k, v)| {
                Ok((
                    Uuid16::from_be_bytes(k.extract::<u16>()?.to_be_bytes()),
                    v.str()?.to_str()?.to_string(),
                ))
            })
            .collect::<PyResult<collections::HashMap<_, _>>>()
    })?;

    assert_eq!(
        wrapper::assigned_numbers::COMPANY_IDS
            .iter()
            .map(|(id, name)| (*id, name.to_string()))
            .collect::<collections::HashMap<_, _>>(),
        ids_from_python,
        "Company ids do not match -- re-run gen_assigned_numbers?"
    );
    Ok(())
}


@@ -17,4 +17,5 @@ async fn main() -> pyo3::PyResult<()> {
    pyo3_asyncio::testing::main().await
}

mod assigned_numbers;
mod wrapper;


@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use bumble::wrapper::{drivers::rtk::DriverInfo, transport::Transport};
use nix::sys::stat::Mode;
use pyo3::PyResult;

#[pyo3_asyncio::tokio::test]
async fn fifo_transport_can_open() -> PyResult<()> {
@@ -31,7 +31,7 @@ async fn fifo_transport_can_open() -> PyResult<()> {
}

#[pyo3_asyncio::tokio::test]
async fn realtek_driver_info_all_drivers() -> PyResult<()> {
    assert_eq!(12, DriverInfo::all_drivers()?.len());
    Ok(())
}


@@ -0,0 +1,4 @@
This dir contains sample firmware images in the format used for Realtek chips,
but with repetitions of the section length as a little-endian 32-bit int
for the patch data instead of actual firmware, since we only need the structure
to test parsing.
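Such placeholder sections could be generated with a few lines of Python (a sketch under the assumption that section lengths are 4-byte aligned; `fake_patch_section` is not part of the repo):

```python
import struct

def fake_patch_section(length):
    # Fill the section with its own length repeated as a little-endian
    # 32-bit int, so only the container structure is meaningful.
    assert length % 4 == 0
    return struct.pack('<I', length) * (length // 4)

section = fake_patch_section(8)  # b'\x08\x00\x00\x00\x08\x00\x00\x00'
```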


@@ -0,0 +1,15 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub(crate) mod rtk;


@@ -0,0 +1,265 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Realtek firmware tools
use crate::{Download, Source};
use anyhow::anyhow;
use bumble::wrapper::{
drivers::rtk::{Driver, DriverInfo, Firmware},
host::{DriverFactory, Host},
transport::Transport,
};
use owo_colors::{colors::css, OwoColorize};
use pyo3::PyResult;
use std::{fs, path};
pub(crate) async fn download(dl: Download) -> PyResult<()> {
let data_dir = dl
.output_dir
.or_else(|| {
directories::ProjectDirs::from("com", "google", "bumble")
.map(|pd| pd.data_local_dir().join("firmware").join("realtek"))
})
.unwrap_or_else(|| {
eprintln!("Could not determine standard data directory");
path::PathBuf::from(".")
});
fs::create_dir_all(&data_dir)?;
let (base_url, uses_bin_suffix) = match dl.source {
Source::LinuxKernel => ("https://git.kernel.org/pub/scm/linux/kernel/git/firmware/linux-firmware.git/plain/rtl_bt", true),
Source::RealtekOpensource => ("https://github.com/Realtek-OpenSource/android_hardware_realtek/raw/rtk1395/bt/rtkbt/Firmware/BT", false),
Source::LinuxFromScratch => ("https://anduin.linuxfromscratch.org/sources/linux-firmware/rtl_bt", true),
};
println!("Downloading");
println!("{} {}", "FROM:".green(), base_url);
println!("{} {}", "TO:".green(), data_dir.to_string_lossy());
let url_for_file = |file_name: &str| {
let url_suffix = if uses_bin_suffix {
file_name
} else {
file_name.trim_end_matches(".bin")
};
let mut url = base_url.to_string();
url.push('/');
url.push_str(url_suffix);
url
};
let to_download = if let Some(single) = dl.single {
vec![(
format!("{single}_fw.bin"),
Some(format!("{single}_config.bin")),
false,
)]
} else {
DriverInfo::all_drivers()?
.iter()
.map(|di| Ok((di.firmware_name()?, di.config_name()?, di.config_needed()?)))
.collect::<PyResult<Vec<_>>>()?
};
let client = SimpleClient::new();
for (fw_filename, config_filename, config_needed) in to_download {
println!("{}", "---".yellow());
let fw_path = data_dir.join(&fw_filename);
let config_path = config_filename.as_ref().map(|f| data_dir.join(f));
if fw_path.exists() && !dl.overwrite {
println!(
"{}",
format!("{} already exists, skipping", fw_path.to_string_lossy())
.fg::<css::Orange>()
);
continue;
}
if let Some(cp) = config_path.as_ref() {
if cp.exists() && !dl.overwrite {
println!(
"{}",
format!("{} already exists, skipping", cp.to_string_lossy())
.fg::<css::Orange>()
);
continue;
}
}
let fw_contents = match client.get(&url_for_file(&fw_filename)).await {
Ok(data) => {
println!("Downloaded {}: {} bytes", fw_filename, data.len());
data
}
Err(e) => {
eprintln!(
"{} {} {:?}",
"Failed to download".red(),
fw_filename.red(),
e
);
continue;
}
};
let config_contents = if let Some(cn) = &config_filename {
match client.get(&url_for_file(cn)).await {
Ok(data) => {
println!("Downloaded {}: {} bytes", cn, data.len());
Some(data)
}
Err(e) => {
if config_needed {
eprintln!("{} {} {:?}", "Failed to download".red(), cn.red(), e);
continue;
} else {
eprintln!(
"{}",
format!("No config available as {cn}").fg::<css::Orange>()
);
None
}
}
}
} else {
None
};
fs::write(&fw_path, &fw_contents)?;
if !dl.no_parse && config_filename.is_some() {
println!("{} {}", "Parsing:".cyan(), &fw_filename);
match Firmware::parse(&fw_contents).map_err(|e| anyhow!("Parse error: {:?}", e)) {
Ok(fw) => dump_firmware_desc(&fw),
Err(e) => {
eprintln!(
"{} {:?}",
"Could not parse firmware:".fg::<css::Orange>(),
e
);
}
}
}
if let Some((cp, cd)) = config_path
.as_ref()
.and_then(|p| config_contents.map(|c| (p, c)))
{
fs::write(cp, &cd)?;
}
}
Ok(())
}
pub(crate) fn parse(firmware_path: &path::Path) -> PyResult<()> {
let contents = fs::read(firmware_path)?;
let fw = Firmware::parse(&contents)
// squish the error into a string to avoid the error type requiring that the input be
// 'static
.map_err(|e| anyhow!("Parse error: {:?}", e))?;
dump_firmware_desc(&fw);
Ok(())
}
pub(crate) async fn info(transport: &str, force: bool) -> PyResult<()> {
let transport = Transport::open(transport).await?;
let mut host = Host::new(transport.source()?, transport.sink()?)?;
host.reset(DriverFactory::None).await?;
if !force && !Driver::check(&host).await? {
println!("USB device not supported by this RTK driver");
} else if let Some(driver_info) = Driver::driver_info_for_host(&host).await? {
println!("Driver:");
println!(" {:10} {:04X}", "ROM:", driver_info.rom()?);
println!(" {:10} {}", "Firmware:", driver_info.firmware_name()?);
println!(
" {:10} {}",
"Config:",
driver_info.config_name()?.unwrap_or_default()
);
} else {
println!("Firmware already loaded or no supported driver for this device.")
}
Ok(())
}
pub(crate) async fn load(transport: &str, force: bool) -> PyResult<()> {
let transport = Transport::open(transport).await?;
let mut host = Host::new(transport.source()?, transport.sink()?)?;
host.reset(DriverFactory::None).await?;
match Driver::for_host(&host, force).await? {
None => {
eprintln!("Firmware already loaded or no supported driver for this device.");
}
Some(mut d) => d.download_firmware().await?,
};
Ok(())
}
pub(crate) async fn drop(transport: &str) -> PyResult<()> {
let transport = Transport::open(transport).await?;
let mut host = Host::new(transport.source()?, transport.sink()?)?;
host.reset(DriverFactory::None).await?;
Driver::drop_firmware(&mut host).await?;
Ok(())
}
fn dump_firmware_desc(fw: &Firmware) {
println!(
"Firmware: version=0x{:08X} project_id=0x{:04X}",
fw.version(),
fw.project_id()
);
for p in fw.patches() {
println!(
" Patch: chip_id=0x{:04X}, {} bytes, SVN Version={:08X}",
p.chip_id(),
p.contents().len(),
p.svn_version()
)
}
}
struct SimpleClient {
client: reqwest::Client,
}
impl SimpleClient {
fn new() -> Self {
Self {
client: reqwest::Client::new(),
}
}
async fn get(&self, url: &str) -> anyhow::Result<Vec<u8>> {
let resp = self.client.get(url).send().await?;
if !resp.status().is_success() {
return Err(anyhow!("Bad status: {}", resp.status()));
}
let bytes = resp.bytes().await?;
Ok(bytes.as_ref().to_vec())
}
}
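The `url_for_file` closure's suffix handling can be sketched in Python (hypothetical mirror URL and firmware name; `removesuffix` plays the role of `trim_end_matches`):

```python
def url_for_file(base_url, file_name, uses_bin_suffix):
    # Some mirrors store the firmware without a ".bin" suffix, so it is
    # stripped from the local file name before building the URL.
    suffix = file_name if uses_bin_suffix else file_name.removesuffix('.bin')
    return f'{base_url}/{suffix}'

with_suffix = url_for_file('https://mirror.example/rtl_bt', 'rtl_demo_fw.bin', True)
without_suffix = url_for_file('https://mirror.example/BT', 'rtl_demo_fw.bin', False)
```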

rust/src/cli/l2cap/client_bridge.rs

@@ -0,0 +1,191 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// L2CAP CoC client bridge: connects to a BLE device, then waits for an inbound
/// TCP connection on a specified port number. When a TCP client connects, an
/// L2CAP CoC channel connection to the BLE device is established, and the data
/// is bridged in both directions, with flow control.
/// When the TCP connection is closed by the client, the L2CAP CoC channel is
/// disconnected, but the connection to the BLE device remains, ready for a new
/// TCP client to connect.
/// When the L2CAP CoC channel is closed, the TCP connection is closed as well.
use crate::cli::l2cap::{
proxy_l2cap_rx_to_tcp_tx, proxy_tcp_rx_to_l2cap_tx, run_future_with_current_task_locals,
BridgeData,
};
use bumble::wrapper::{
device::{Connection, Device},
hci::HciConstant,
};
use futures::executor::block_on;
use owo_colors::OwoColorize;
use pyo3::{PyResult, Python};
use std::{net::SocketAddr, sync::Arc};
use tokio::{
join,
net::{TcpListener, TcpStream},
sync::{mpsc, Mutex},
};
pub struct Args {
pub psm: u16,
pub max_credits: Option<u16>,
pub mtu: Option<u16>,
pub mps: Option<u16>,
pub bluetooth_address: String,
pub tcp_host: String,
pub tcp_port: u16,
}
pub async fn start(args: &Args, device: &mut Device) -> PyResult<()> {
println!(
"{}",
format!("### Connecting to {}...", args.bluetooth_address).yellow()
);
let mut ble_connection = device.connect(&args.bluetooth_address).await?;
ble_connection.on_disconnection(|_py, reason| {
let disconnection_info = match HciConstant::error_name(reason) {
Ok(info_string) => info_string,
Err(py_err) => format!("failed to get disconnection error name ({})", py_err),
};
println!(
"{} {}",
"@@@ Bluetooth disconnection: ".red(),
disconnection_info,
);
Ok(())
})?;
// Start the TCP server.
let listener = TcpListener::bind(format!("{}:{}", args.tcp_host, args.tcp_port))
.await
.expect("failed to bind tcp to address");
println!(
"{}",
format!(
"### Listening for TCP connections on port {}",
args.tcp_port
)
.magenta()
);
let psm = args.psm;
let max_credits = args.max_credits;
let mtu = args.mtu;
let mps = args.mps;
let ble_connection = Arc::new(Mutex::new(ble_connection));
// Ensure Python event loop is available to l2cap `disconnect`
let _ = run_future_with_current_task_locals(async move {
while let Ok((tcp_stream, addr)) = listener.accept().await {
let ble_connection = ble_connection.clone();
let _ = run_future_with_current_task_locals(proxy_data_between_tcp_and_l2cap(
ble_connection,
tcp_stream,
addr,
psm,
max_credits,
mtu,
mps,
));
}
Ok(())
});
Ok(())
}
async fn proxy_data_between_tcp_and_l2cap(
ble_connection: Arc<Mutex<Connection>>,
tcp_stream: TcpStream,
addr: SocketAddr,
psm: u16,
max_credits: Option<u16>,
mtu: Option<u16>,
mps: Option<u16>,
) -> PyResult<()> {
println!("{}", format!("<<< TCP connection from {}", addr).magenta());
println!(
"{}",
format!(">>> Opening L2CAP channel on PSM = {}", psm).yellow()
);
let mut l2cap_channel = match ble_connection
.lock()
.await
.open_l2cap_channel(psm, max_credits, mtu, mps)
.await
{
Ok(channel) => channel,
Err(e) => {
println!("{}", format!("!!! Connection failed: {e}").red());
// TCP stream will get dropped after returning, automatically shutting it down.
return Err(e);
}
};
let channel_info = l2cap_channel
.debug_string()
.unwrap_or_else(|e| format!("failed to get l2cap channel info ({e})"));
println!("{}{}", "*** L2CAP channel: ".cyan(), channel_info);
let (l2cap_to_tcp_tx, l2cap_to_tcp_rx) = mpsc::channel::<BridgeData>(10);
// Set l2cap callback (`set_sink`) for when data is received.
let l2cap_to_tcp_tx_clone = l2cap_to_tcp_tx.clone();
l2cap_channel
.set_sink(move |_py, sdu| {
block_on(l2cap_to_tcp_tx_clone.send(BridgeData::Data(sdu.into())))
.expect("failed to channel data to tcp");
Ok(())
})
.expect("failed to set sink for l2cap connection");
// Set l2cap callback for when the channel is closed.
l2cap_channel
.on_close(move |_py| {
println!("{}", "*** L2CAP channel closed".red());
block_on(l2cap_to_tcp_tx.send(BridgeData::CloseSignal))
.expect("failed to channel close signal to tcp");
Ok(())
})
.expect("failed to set on_close callback for l2cap channel");
let l2cap_channel = Arc::new(Mutex::new(Some(l2cap_channel)));
let (tcp_reader, tcp_writer) = tcp_stream.into_split();
// Do tcp stuff when something happens on the l2cap channel.
let handle_l2cap_data_future =
proxy_l2cap_rx_to_tcp_tx(l2cap_to_tcp_rx, tcp_writer, l2cap_channel.clone());
// Do l2cap stuff when something happens on tcp.
let handle_tcp_data_future = proxy_tcp_rx_to_l2cap_tx(tcp_reader, l2cap_channel.clone(), true);
let (handle_l2cap_result, handle_tcp_result) =
join!(handle_l2cap_data_future, handle_tcp_data_future);
if let Err(e) = handle_l2cap_result {
println!("!!! Error: {e}");
}
if let Err(e) = handle_tcp_result {
println!("!!! Error: {e}");
}
Python::with_gil(|_| {
// Must hold GIL at least once while/after dropping for Python heap object to ensure
// de-allocation.
drop(l2cap_channel);
});
Ok(())
}

rust/src/cli/l2cap/mod.rs (new file)

@@ -0,0 +1,190 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Rust version of the Python `l2cap_bridge.py` found under the `apps` folder.
use crate::L2cap;
use anyhow::anyhow;
use bumble::wrapper::{device::Device, l2cap::LeConnectionOrientedChannel, transport::Transport};
use owo_colors::{colors::css::Orange, OwoColorize};
use pyo3::{PyObject, PyResult, Python};
use std::{future::Future, path::PathBuf, sync::Arc};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::tcp::{OwnedReadHalf, OwnedWriteHalf},
sync::{mpsc::Receiver, Mutex},
};
mod client_bridge;
mod server_bridge;
pub(crate) async fn run(
command: L2cap,
device_config: PathBuf,
transport: String,
psm: u16,
max_credits: Option<u16>,
mtu: Option<u16>,
mps: Option<u16>,
) -> PyResult<()> {
println!("<<< connecting to HCI...");
let transport = Transport::open(transport).await?;
println!("<<< connected");
let mut device =
Device::from_config_file_with_hci(&device_config, transport.source()?, transport.sink()?)?;
device.power_on().await?;
match command {
L2cap::Server { tcp_host, tcp_port } => {
let args = server_bridge::Args {
psm,
max_credits,
mtu,
mps,
tcp_host,
tcp_port,
};
server_bridge::start(&args, &mut device).await?
}
L2cap::Client {
bluetooth_address,
tcp_host,
tcp_port,
} => {
let args = client_bridge::Args {
psm,
max_credits,
mtu,
mps,
bluetooth_address,
tcp_host,
tcp_port,
};
client_bridge::start(&args, &mut device).await?
}
};
// wait until user kills the process
tokio::signal::ctrl_c().await?;
Ok(())
}
/// Used for channeling data from Python callbacks to a Rust consumer.
enum BridgeData {
Data(Vec<u8>),
CloseSignal,
}
async fn proxy_l2cap_rx_to_tcp_tx(
mut l2cap_data_receiver: Receiver<BridgeData>,
mut tcp_writer: OwnedWriteHalf,
l2cap_channel: Arc<Mutex<Option<LeConnectionOrientedChannel>>>,
) -> anyhow::Result<()> {
while let Some(bridge_data) = l2cap_data_receiver.recv().await {
match bridge_data {
BridgeData::Data(sdu) => {
println!("{}", format!("<<< [L2CAP SDU]: {} bytes", sdu.len()).cyan());
tcp_writer
.write_all(sdu.as_ref())
.await
.map_err(|_| anyhow!("Failed to write to tcp stream"))?;
tcp_writer
.flush()
.await
.map_err(|_| anyhow!("Failed to flush tcp stream"))?;
}
BridgeData::CloseSignal => {
l2cap_channel.lock().await.take();
tcp_writer
.shutdown()
.await
.map_err(|_| anyhow!("Failed to shut down write half of tcp stream"))?;
return Ok(());
}
}
}
Ok(())
}
async fn proxy_tcp_rx_to_l2cap_tx(
mut tcp_reader: OwnedReadHalf,
l2cap_channel: Arc<Mutex<Option<LeConnectionOrientedChannel>>>,
drain_l2cap_after_write: bool,
) -> PyResult<()> {
let mut buf = [0; 4096];
loop {
match tcp_reader.read(&mut buf).await {
Ok(len) => {
if len == 0 {
println!("{}", "!!! End of stream".fg::<Orange>());
if let Some(mut channel) = l2cap_channel.lock().await.take() {
channel.disconnect().await.map_err(|e| {
eprintln!("Failed to call disconnect on l2cap channel: {e}");
e
})?;
}
return Ok(());
}
println!("{}", format!("<<< [TCP DATA]: {len} bytes").blue());
match l2cap_channel.lock().await.as_mut() {
None => {
println!("{}", "!!! L2CAP channel not connected, dropping".red());
return Ok(());
}
Some(channel) => {
channel.write(&buf[..len])?;
if drain_l2cap_after_write {
channel.drain().await?;
}
}
}
}
Err(e) => {
println!("{}", format!("!!! TCP connection lost: {}", e).red());
if let Some(mut channel) = l2cap_channel.lock().await.take() {
let _ = channel.disconnect().await.map_err(|e| {
eprintln!("Failed to call disconnect on l2cap channel: {e}");
});
}
return Err(e.into());
}
}
}
}
/// Copies the current thread's TaskLocals into a Python "awaitable" and encapsulates it in a Rust
/// future, running it as a Python Task.
/// `TaskLocals` stores the current event loop, and allows the user to copy the current Python
/// context if necessary. In this case, the python event loop is used when calling `disconnect` on
/// an l2cap connection, or else the call will fail.
pub fn run_future_with_current_task_locals<F>(
fut: F,
) -> PyResult<impl Future<Output = PyResult<PyObject>> + Send>
where
F: Future<Output = PyResult<()>> + Send + 'static,
{
Python::with_gil(|py| {
let locals = pyo3_asyncio::tokio::get_current_locals(py)?;
let future = pyo3_asyncio::tokio::scope(locals.clone(), fut);
pyo3_asyncio::tokio::future_into_py_with_locals(py, locals, future)
.and_then(pyo3_asyncio::tokio::into_future)
})
}
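The `BridgeData` enum multiplexes payloads and a close notification over one channel. A Python sketch of the same consumer loop as `proxy_l2cap_rx_to_tcp_tx`, with a `queue.Queue` standing in for the tokio mpsc channel (names here are illustrative, not part of the repo):

```python
import queue

CLOSE = object()  # stands in for BridgeData::CloseSignal

def drain_bridge(q):
    # Forward Data payloads until the close signal arrives, then stop,
    # mirroring the match over BridgeData above.
    received = bytearray()
    while True:
        msg = q.get()
        if msg is CLOSE:
            break
        received.extend(msg)
    return bytes(received)

q = queue.Queue()
q.put(b'\x01\x02')
q.put(b'\x03')
q.put(CLOSE)
payload = drain_bridge(q)  # b'\x01\x02\x03'
```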

rust/src/cli/l2cap/server_bridge.rs

@@ -0,0 +1,205 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/// L2CAP CoC server bridge: waits for a peer to connect an L2CAP CoC channel
/// on a specified PSM. When the connection is made, the bridge connects a TCP
/// socket to a remote host and bridges the data in both directions, with flow
/// control.
/// When the L2CAP CoC channel is closed, the bridge disconnects the TCP socket
/// and waits for a new L2CAP CoC channel to be connected.
/// When the TCP connection is closed by the TCP server, the L2CAP connection is closed as well.
use crate::cli::l2cap::{
proxy_l2cap_rx_to_tcp_tx, proxy_tcp_rx_to_l2cap_tx, run_future_with_current_task_locals,
BridgeData,
};
use bumble::wrapper::{device::Device, hci::HciConstant, l2cap::LeConnectionOrientedChannel};
use futures::executor::block_on;
use owo_colors::OwoColorize;
use pyo3::{PyResult, Python};
use std::{sync::Arc, time::Duration};
use tokio::{
join,
net::TcpStream,
select,
sync::{mpsc, Mutex},
};
pub struct Args {
pub psm: u16,
pub max_credits: Option<u16>,
pub mtu: Option<u16>,
pub mps: Option<u16>,
pub tcp_host: String,
pub tcp_port: u16,
}
pub async fn start(args: &Args, device: &mut Device) -> PyResult<()> {
let host = args.tcp_host.clone();
let port = args.tcp_port;
device.register_l2cap_channel_server(
args.psm,
move |_py, l2cap_channel| {
let channel_info = l2cap_channel
.debug_string()
.unwrap_or_else(|e| format!("failed to get l2cap channel info ({e})"));
println!("{} {channel_info}", "*** L2CAP channel:".cyan());
let host = host.clone();
// Ensure Python event loop is available to l2cap `disconnect`
let _ = run_future_with_current_task_locals(proxy_data_between_l2cap_and_tcp(
l2cap_channel,
host,
port,
));
Ok(())
},
args.max_credits,
args.mtu,
args.mps,
)?;
println!(
"{}",
format!("### Listening for CoC connection on PSM {}", args.psm).yellow()
);
device.on_connection(|_py, mut connection| {
let connection_info = connection
.debug_string()
.unwrap_or_else(|e| format!("failed to get connection info ({e})"));
println!(
"{} {}",
"@@@ Bluetooth connection: ".green(),
connection_info,
);
connection.on_disconnection(|_py, reason| {
let disconnection_info = match HciConstant::error_name(reason) {
Ok(info_string) => info_string,
Err(py_err) => format!("failed to get disconnection error name ({})", py_err),
};
println!(
"{} {}",
"@@@ Bluetooth disconnection: ".red(),
disconnection_info,
);
Ok(())
})?;
Ok(())
})?;
device.start_advertising(false).await?;
Ok(())
}
async fn proxy_data_between_l2cap_and_tcp(
mut l2cap_channel: LeConnectionOrientedChannel,
tcp_host: String,
tcp_port: u16,
) -> PyResult<()> {
let (l2cap_to_tcp_tx, mut l2cap_to_tcp_rx) = mpsc::channel::<BridgeData>(10);
// Set callback (`set_sink`) for when l2cap data is received.
let l2cap_to_tcp_tx_clone = l2cap_to_tcp_tx.clone();
l2cap_channel
.set_sink(move |_py, sdu| {
block_on(l2cap_to_tcp_tx_clone.send(BridgeData::Data(sdu.into())))
.expect("failed to channel data to tcp");
Ok(())
})
.expect("failed to set sink for l2cap connection");
// Set l2cap callback for when the channel is closed.
l2cap_channel
.on_close(move |_py| {
println!("{}", "*** L2CAP channel closed".red());
block_on(l2cap_to_tcp_tx.send(BridgeData::CloseSignal))
.expect("failed to channel close signal to tcp");
Ok(())
})
.expect("failed to set on_close callback for l2cap channel");
println!(
"{}",
format!("### Connecting to TCP {tcp_host}:{tcp_port}...").yellow()
);
let l2cap_channel = Arc::new(Mutex::new(Some(l2cap_channel)));
let tcp_stream = match TcpStream::connect(format!("{tcp_host}:{tcp_port}")).await {
Ok(stream) => {
println!("{}", "### Connected".green());
Some(stream)
}
Err(err) => {
println!("{}", format!("!!! Connection failed: {err}").red());
if let Some(mut channel) = l2cap_channel.lock().await.take() {
// Bumble might enter an invalid state if disconnection request is received from
// l2cap client before receiving a disconnection response from the same client,
// blocking this async call from returning.
// See: https://github.com/google/bumble/issues/257
select! {
res = channel.disconnect() => {
let _ = res.map_err(|e| eprintln!("Failed to call disconnect on l2cap channel: {e}"));
},
_ = tokio::time::sleep(Duration::from_secs(1)) => eprintln!("Timed out while calling disconnect on l2cap channel."),
}
}
None
}
};
match tcp_stream {
None => {
while let Some(bridge_data) = l2cap_to_tcp_rx.recv().await {
match bridge_data {
BridgeData::Data(sdu) => {
println!("{}", format!("<<< [L2CAP SDU]: {} bytes", sdu.len()).cyan());
println!("{}", "!!! TCP socket not open, dropping".red())
}
BridgeData::CloseSignal => break,
}
}
}
Some(tcp_stream) => {
let (tcp_reader, tcp_writer) = tcp_stream.into_split();
// Do tcp stuff when something happens on the l2cap channel.
let handle_l2cap_data_future =
proxy_l2cap_rx_to_tcp_tx(l2cap_to_tcp_rx, tcp_writer, l2cap_channel.clone());
// Do l2cap stuff when something happens on tcp.
let handle_tcp_data_future =
proxy_tcp_rx_to_l2cap_tx(tcp_reader, l2cap_channel.clone(), false);
let (handle_l2cap_result, handle_tcp_result) =
join!(handle_l2cap_data_future, handle_tcp_data_future);
if let Err(e) = handle_l2cap_result {
println!("!!! Error: {e}");
}
if let Err(e) = handle_tcp_result {
println!("!!! Error: {e}");
}
}
};
Python::with_gil(|_| {
// Must hold GIL at least once while/after dropping for Python heap object to ensure
// de-allocation.
drop(l2cap_channel);
});
Ok(())
}
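The bridge above fans both SDU payloads and an explicit close signal into one channel, so the consumer loop can exit cleanly when the L2CAP side goes away. The same pattern can be sketched with std-only types (names here are illustrative stand-ins, not the crate's API):

```rust
use std::sync::mpsc;

// Illustrative stand-in for the crate's BridgeData: one channel carries both
// payload bytes and a close signal.
enum BridgeData {
    Data(Vec<u8>),
    CloseSignal,
}

// Collect SDUs until the close signal arrives; anything sent after the close
// signal is never read, mirroring how the TCP side stops bridging on close.
fn drain(rx: mpsc::Receiver<BridgeData>) -> Vec<Vec<u8>> {
    let mut sdus = Vec::new();
    while let Ok(msg) = rx.recv() {
        match msg {
            BridgeData::Data(sdu) => sdus.push(sdu),
            BridgeData::CloseSignal => break, // peer closed: stop bridging
        }
    }
    sdus
}

fn main() {
    let (tx, rx) = mpsc::channel();
    tx.send(BridgeData::Data(vec![1, 2, 3])).unwrap();
    tx.send(BridgeData::CloseSignal).unwrap();
    tx.send(BridgeData::Data(vec![9])).unwrap(); // ignored after close
    assert_eq!(drain(rx), vec![vec![1, 2, 3]]);
}
```

In the real bridge the sender side runs inside Python callbacks (`set_sink`, `on_close`), which is why `block_on` is used there instead of an async send.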

rust/src/cli/mod.rs (new file)

@@ -0,0 +1,19 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub(crate) mod firmware;
pub(crate) mod usb;
pub(crate) mod l2cap;


@@ -23,7 +23,6 @@
//! whether it is a Bluetooth device that uses a non-standard Class, or some other
//! type of device (there's no way to tell).
-use clap::Parser as _;
use itertools::Itertools as _;
use owo_colors::{OwoColorize, Style};
use rusb::{Device, DeviceDescriptor, Direction, TransferType, UsbContext};
@@ -31,15 +30,12 @@ use std::{
collections::{HashMap, HashSet},
time::Duration,
};
const USB_DEVICE_CLASS_DEVICE: u8 = 0x00;
const USB_DEVICE_CLASS_WIRELESS_CONTROLLER: u8 = 0xE0;
const USB_DEVICE_SUBCLASS_RF_CONTROLLER: u8 = 0x01;
const USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER: u8 = 0x01;
-fn main() -> anyhow::Result<()> {
-let cli = Cli::parse();
+pub(crate) fn probe(verbose: bool) -> anyhow::Result<()> {
let mut bt_dev_count = 0;
let mut device_serials_by_id: HashMap<(u16, u16), HashSet<String>> = HashMap::new();
for device in rusb::devices()?.iter() {
@@ -159,7 +155,7 @@ fn main() -> anyhow::Result<()> {
println!("{:26}{}", " Product:".green(), p);
}
-if cli.verbose {
+if verbose {
print_device_details(&device, &device_desc)?;
}
@@ -332,11 +328,3 @@ impl From<&DeviceDescriptor> for ClassInfo {
)
}
}
-#[derive(clap::Parser)]
-#[command(author, version, about, long_about = None)]
-struct Cli {
-/// Show additional info for each USB device
-#[arg(long, default_value_t = false)]
-verbose: bool,
-}


@@ -0,0 +1,17 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Device drivers
pub(crate) mod rtk;


@@ -0,0 +1,253 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Drivers for Realtek controllers
use nom::{bytes, combinator, error, multi, number, sequence};
/// Realtek firmware file contents
pub struct Firmware {
version: u32,
project_id: u8,
patches: Vec<Patch>,
}
impl Firmware {
/// Parse a `*_fw.bin` file
pub fn parse(input: &[u8]) -> Result<Self, nom::Err<error::Error<&[u8]>>> {
let extension_sig = [0x51, 0x04, 0xFD, 0x77];
let (_rem, (_tag, fw_version, patch_count, payload)) =
combinator::all_consuming(combinator::map_parser(
// ignore the sig suffix
sequence::terminated(
bytes::complete::take(
// underflow will show up as parse failure
input.len().saturating_sub(extension_sig.len()),
),
bytes::complete::tag(extension_sig.as_slice()),
),
sequence::tuple((
bytes::complete::tag(b"Realtech"),
// version
number::complete::le_u32,
// patch count
combinator::map(number::complete::le_u16, |c| c as usize),
// everything else except suffix
combinator::rest,
)),
))(input)?;
// ignore remaining input, since patch offsets are relative to the complete input
let (_rem, (chip_ids, patch_lengths, patch_offsets)) = sequence::tuple((
// chip id
multi::many_m_n(patch_count, patch_count, number::complete::le_u16),
// patch length
multi::many_m_n(patch_count, patch_count, number::complete::le_u16),
// patch offset
multi::many_m_n(patch_count, patch_count, number::complete::le_u32),
))(payload)?;
let patches = chip_ids
.into_iter()
.zip(patch_lengths.into_iter())
.zip(patch_offsets.into_iter())
.map(|((chip_id, patch_length), patch_offset)| {
combinator::map(
sequence::preceded(
bytes::complete::take(patch_offset),
// ignore trailing 4-byte suffix
sequence::terminated(
// patch including svn version, but not suffix
combinator::consumed(sequence::preceded(
// patch before svn version or version suffix
// prefix length underflow will show up as parse failure
bytes::complete::take(patch_length.saturating_sub(8)),
// svn version
number::complete::le_u32,
)),
// dummy suffix, overwritten with firmware version
bytes::complete::take(4_usize),
),
),
|(patch_contents_before_version, svn_version): (&[u8], u32)| {
let mut contents = patch_contents_before_version.to_vec();
// replace what would have been the trailing dummy suffix with fw version
contents.extend_from_slice(&fw_version.to_le_bytes());
Patch {
contents,
svn_version,
chip_id,
}
},
)(input)
.map(|(_rem, output)| output)
})
.collect::<Result<Vec<_>, _>>()?;
// look for project id from the end
let mut offset = payload.len();
let mut project_id: Option<u8> = None;
while offset >= 2 {
// Won't panic, since offset >= 2
let chunk = &payload[offset - 2..offset];
let length: usize = chunk[0].into();
let opcode = chunk[1];
offset -= 2;
if opcode == 0xFF {
break;
}
if length == 0 {
// report what nom likely would have done, if nom were good at parsing backwards
return Err(nom::Err::Error(error::Error::new(
chunk,
error::ErrorKind::Verify,
)));
}
if opcode == 0 && length == 1 {
project_id = offset
.checked_sub(1)
.and_then(|index| payload.get(index))
.copied();
break;
}
offset -= length;
}
match project_id {
Some(project_id) => Ok(Firmware {
project_id,
version: fw_version,
patches,
}),
None => {
// we ran out of file without finding a project id
Err(nom::Err::Error(error::Error::new(
payload,
error::ErrorKind::Eof,
)))
}
}
}
/// Firmware version
pub fn version(&self) -> u32 {
self.version
}
/// Project id
pub fn project_id(&self) -> u8 {
self.project_id
}
/// Patches
pub fn patches(&self) -> &[Patch] {
&self.patches
}
}
/// Patch in a [Firmware]
pub struct Patch {
chip_id: u16,
contents: Vec<u8>,
svn_version: u32,
}
impl Patch {
/// Chip id
pub fn chip_id(&self) -> u16 {
self.chip_id
}
/// Contents of the patch, including the 4-byte firmware version suffix
pub fn contents(&self) -> &[u8] {
&self.contents
}
/// SVN version
pub fn svn_version(&self) -> u32 {
self.svn_version
}
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::anyhow;
use std::{fs, io, path};
#[test]
fn parse_firmware_rtl8723b() -> anyhow::Result<()> {
let fw = Firmware::parse(&firmware_contents("rtl8723b_fw_structure.bin")?)
.map_err(|e| anyhow!("{:?}", e))?;
let fw_version = 0x0E2F9F73;
assert_eq!(fw_version, fw.version());
assert_eq!(0x0001, fw.project_id());
assert_eq!(
vec![(0x0001, 0x00002BBF, 22368,), (0x0002, 0x00002BBF, 22496,),],
patch_summaries(fw, fw_version)
);
Ok(())
}
#[test]
fn parse_firmware_rtl8761bu() -> anyhow::Result<()> {
let fw = Firmware::parse(&firmware_contents("rtl8761bu_fw_structure.bin")?)
.map_err(|e| anyhow!("{:?}", e))?;
let fw_version = 0xDFC6D922;
assert_eq!(fw_version, fw.version());
assert_eq!(0x000E, fw.project_id());
assert_eq!(
vec![(0x0001, 0x00005060, 14048,), (0x0002, 0xD6D525A4, 30204,),],
patch_summaries(fw, fw_version)
);
Ok(())
}
fn firmware_contents(filename: &str) -> io::Result<Vec<u8>> {
fs::read(
path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("resources/test/firmware/realtek")
.join(filename),
)
}
/// Return a tuple of (chip id, svn version, contents len)
fn patch_summaries(fw: Firmware, fw_version: u32) -> Vec<(u16, u32, usize)> {
fw.patches()
.iter()
.map(|p| {
let contents = p.contents();
let mut dummy_contents = dummy_contents(contents.len());
dummy_contents.extend_from_slice(&p.svn_version().to_le_bytes());
dummy_contents.extend_from_slice(&fw_version.to_le_bytes());
assert_eq!(&dummy_contents, contents);
(p.chip_id(), p.svn_version(), contents.len())
})
.collect::<Vec<_>>()
}
fn dummy_contents(len: usize) -> Vec<u8> {
let mut vec = (len as u32).to_le_bytes().as_slice().repeat(len / 4 + 1);
assert!(vec.len() >= len);
// leave room for svn version and firmware version
vec.truncate(len - 8);
vec
}
}
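The project-id lookup in `Firmware::parse` walks (length, opcode) pairs backwards from the end of the payload, skipping `length` value bytes per entry, until it finds an opcode-0x00/length-1 entry whose single value byte is the project id. That reverse scan can be exercised in isolation with a standalone sketch that mirrors (but does not reuse) the crate's logic:

```rust
// Illustrative sketch of the backwards (length, opcode) scan; entries are
// laid out back-to-front as [value bytes.., length, opcode].
fn find_project_id(payload: &[u8]) -> Option<u8> {
    let mut offset = payload.len();
    while offset >= 2 {
        let length: usize = payload[offset - 2].into();
        let opcode = payload[offset - 1];
        offset -= 2;
        if opcode == 0xFF || length == 0 {
            // 0xFF terminates the list; a zero length would never advance
            return None;
        }
        if opcode == 0x00 && length == 1 {
            // the value byte sits immediately before the (length, opcode) pair
            return offset.checked_sub(1).and_then(|i| payload.get(i)).copied();
        }
        offset = offset.checked_sub(length)?;
    }
    None
}

fn main() {
    // last entry first: (0xAA 0xBB, len 2, opcode 5), then (0x0E, len 1, opcode 0)
    let payload = [0x0E, 0x01, 0x00, 0xAA, 0xBB, 0x02, 0x05];
    assert_eq!(find_project_id(&payload), Some(0x0E));
    assert_eq!(find_project_id(&[0x01, 0xFF]), None); // 0xFF ends the list
}
```

Unlike the crate version, this sketch returns `None` on a malformed entry rather than a `nom` error, and uses `checked_sub` so an oversized length cannot underflow.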

rust/src/internal/mod.rs (new file)

@@ -0,0 +1,20 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! It's not clear where to put Rust code that isn't simply a wrapper around Python. Until we have
//! a good answer for what to do there, the idea is to put it in this (non-public) module, and
//! `pub use` it into the relevant areas of the `wrapper` module so that it's still easy for users
//! to discover.
pub(crate) mod drivers;


@@ -29,3 +29,5 @@
pub mod wrapper;
pub mod adv;
+pub(crate) mod internal;

rust/src/main.rs (new file)

@@ -0,0 +1,271 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! CLI tools for Bumble
#![deny(missing_docs, unsafe_code)]
use bumble::wrapper::logging::{bumble_env_logging_level, py_logging_basic_config};
use clap::Parser as _;
use pyo3::PyResult;
use std::{fmt, path};
mod cli;
#[pyo3_asyncio::tokio::main]
async fn main() -> PyResult<()> {
env_logger::builder()
.filter_level(log::LevelFilter::Info)
.init();
py_logging_basic_config(bumble_env_logging_level("INFO"))?;
let cli: Cli = Cli::parse();
match cli.subcommand {
Subcommand::Firmware { subcommand: fw } => match fw {
Firmware::Realtek { subcommand: rtk } => match rtk {
Realtek::Download(dl) => {
cli::firmware::rtk::download(dl).await?;
}
Realtek::Drop { transport } => cli::firmware::rtk::drop(&transport).await?,
Realtek::Info { transport, force } => {
cli::firmware::rtk::info(&transport, force).await?;
}
Realtek::Load { transport, force } => {
cli::firmware::rtk::load(&transport, force).await?
}
Realtek::Parse { firmware_path } => cli::firmware::rtk::parse(&firmware_path)?,
},
},
Subcommand::L2cap {
subcommand,
device_config,
transport,
psm,
l2cap_coc_max_credits,
l2cap_coc_mtu,
l2cap_coc_mps,
} => {
cli::l2cap::run(
subcommand,
device_config,
transport,
psm,
l2cap_coc_max_credits,
l2cap_coc_mtu,
l2cap_coc_mps,
)
.await?
}
Subcommand::Usb { subcommand } => match subcommand {
Usb::Probe(probe) => cli::usb::probe(probe.verbose)?,
},
}
Ok(())
}
#[derive(clap::Parser)]
struct Cli {
#[clap(subcommand)]
subcommand: Subcommand,
}
#[derive(clap::Subcommand, Debug, Clone)]
enum Subcommand {
/// Manage device firmware
Firmware {
#[clap(subcommand)]
subcommand: Firmware,
},
/// L2cap client/server operations
L2cap {
#[command(subcommand)]
subcommand: L2cap,
/// Device configuration file.
///
/// See, for instance, `examples/device1.json` in the Python project.
#[arg(long)]
device_config: path::PathBuf,
/// Bumble transport spec.
///
/// <https://google.github.io/bumble/transports/index.html>
#[arg(long)]
transport: String,
/// PSM for L2CAP Connection-oriented Channel.
///
/// Must be in the range [0, 65535].
#[arg(long)]
psm: u16,
/// Maximum L2CAP CoC Credits. When not specified, lets Bumble set the default.
///
/// Must be in the range [1, 65535].
#[arg(long, value_parser = clap::value_parser!(u16).range(1..))]
l2cap_coc_max_credits: Option<u16>,
/// L2CAP CoC MTU. When not specified, lets Bumble set the default.
///
/// Must be in the range [23, 65535].
#[arg(long, value_parser = clap::value_parser!(u16).range(23..))]
l2cap_coc_mtu: Option<u16>,
/// L2CAP CoC MPS. When not specified, lets Bumble set the default.
///
/// Must be in the range [23, 65535].
#[arg(long, value_parser = clap::value_parser!(u16).range(23..))]
l2cap_coc_mps: Option<u16>,
},
/// USB operations
Usb {
#[clap(subcommand)]
subcommand: Usb,
},
}
#[derive(clap::Subcommand, Debug, Clone)]
enum Firmware {
/// Manage Realtek chipset firmware
Realtek {
#[clap(subcommand)]
subcommand: Realtek,
},
}
#[derive(clap::Subcommand, Debug, Clone)]
enum Realtek {
/// Download Realtek firmware
Download(Download),
/// Drop firmware from a USB device
Drop {
/// Bumble transport spec. Must be for a USB device.
///
/// <https://google.github.io/bumble/transports/index.html>
#[arg(long)]
transport: String,
},
/// Show driver info for a USB device
Info {
/// Bumble transport spec. Must be for a USB device.
///
/// <https://google.github.io/bumble/transports/index.html>
#[arg(long)]
transport: String,
/// Try to resolve driver info even if USB info is not available, or if the USB
/// (vendor,product) tuple is not in the list of known compatible RTK USB dongles.
#[arg(long, default_value_t = false)]
force: bool,
},
/// Load firmware onto a USB device
Load {
/// Bumble transport spec. Must be for a USB device.
///
/// <https://google.github.io/bumble/transports/index.html>
#[arg(long)]
transport: String,
/// Load firmware even if the USB info doesn't match.
#[arg(long, default_value_t = false)]
force: bool,
},
/// Parse a firmware file
Parse {
/// Firmware file to parse
firmware_path: path::PathBuf,
},
}
#[derive(clap::Args, Debug, Clone)]
struct Download {
/// Directory to download to. Defaults to an OS-specific path specific to the Bumble tool.
#[arg(long)]
output_dir: Option<path::PathBuf>,
/// Source to download from
#[arg(long, default_value_t = Source::LinuxKernel)]
source: Source,
/// Only download a single image
#[arg(long, value_name = "base name")]
single: Option<String>,
/// Overwrite existing files
#[arg(long, default_value_t = false)]
overwrite: bool,
/// Don't print the parse results for the downloaded file names
#[arg(long)]
no_parse: bool,
}
#[derive(Debug, Clone, clap::ValueEnum)]
enum Source {
LinuxKernel,
RealtekOpensource,
LinuxFromScratch,
}
impl fmt::Display for Source {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Source::LinuxKernel => write!(f, "linux-kernel"),
Source::RealtekOpensource => write!(f, "realtek-opensource"),
Source::LinuxFromScratch => write!(f, "linux-from-scratch"),
}
}
}
#[derive(clap::Subcommand, Debug, Clone)]
enum L2cap {
/// Starts an L2CAP server
Server {
/// TCP host that the l2cap server will connect to.
/// Data is bridged like so:
/// TCP server <-> (TCP client / **L2CAP server**) <-> (L2CAP client / TCP server) <-> TCP client
#[arg(long, default_value = "localhost")]
tcp_host: String,
/// TCP port that the server will connect to.
///
/// Must be in the range [1, 65535].
#[arg(long, default_value_t = 9544)]
tcp_port: u16,
},
/// Starts an L2CAP client
Client {
/// L2cap server address that this l2cap client will connect to.
bluetooth_address: String,
/// TCP host that the l2cap client will bind to and listen for incoming TCP connections.
/// Data is bridged like so:
/// TCP client <-> (TCP server / **L2CAP client**) <-> (L2CAP server / TCP client) <-> TCP server
#[arg(long, default_value = "localhost")]
tcp_host: String,
/// TCP port that the client will connect to.
///
/// Must be in the range [1, 65535].
#[arg(long, default_value_t = 9543)]
tcp_port: u16,
},
}
#[derive(clap::Subcommand, Debug, Clone)]
enum Usb {
/// Probe the USB bus for Bluetooth devices
Probe(Probe),
}
#[derive(clap::Args, Debug, Clone)]
struct Probe {
/// Show additional info for each USB device
#[arg(long, default_value_t = false)]
verbose: bool,
}

File diff suppressed because it is too large


@@ -14,40 +14,8 @@
//! Assigned numbers from the Bluetooth spec.
-use crate::wrapper::core::Uuid16;
-use lazy_static::lazy_static;
-use pyo3::{
-intern,
-types::{PyDict, PyModule},
-PyResult, Python,
-};
-use std::collections;
+mod company_ids;
mod services;
+pub use company_ids::COMPANY_IDS;
pub use services::SERVICE_IDS;
-lazy_static! {
-/// Assigned company IDs
-pub static ref COMPANY_IDS: collections::HashMap<Uuid16, String> = load_company_ids()
-.expect("Could not load company ids -- are Bumble's Python sources available?");
-}
-fn load_company_ids() -> PyResult<collections::HashMap<Uuid16, String>> {
-// this takes about 4ms on a fast machine -- slower than constructing in rust, but not slow
-// enough to worry about
-Python::with_gil(|py| {
-PyModule::import(py, intern!(py, "bumble.company_ids"))?
-.getattr(intern!(py, "COMPANY_IDENTIFIERS"))?
-.downcast::<PyDict>()?
-.into_iter()
-.map(|(k, v)| {
-Ok((
-Uuid16::from_be_bytes(k.extract::<u16>()?.to_be_bytes()),
-v.str()?.to_str()?.to_string(),
-))
-})
-.collect::<PyResult<collections::HashMap<_, _>>>()
-})
-}


@@ -59,7 +59,7 @@ impl AdvertisingData {
}
/// 16-bit UUID
-#[derive(PartialEq, Eq, Hash)]
+#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub struct Uuid16 {
/// Big-endian bytes
uuid: [u8; 2],


@@ -19,13 +19,19 @@ use crate::{
wrapper::{
core::AdvertisingData,
gatt_client::{ProfileServiceProxy, ServiceProxy},
-hci::Address,
+hci::{Address, HciErrorCode},
+host::Host,
+l2cap::LeConnectionOrientedChannel,
transport::{Sink, Source},
-ClosureCallback,
+ClosureCallback, PyDictExt, PyObjectExt,
},
};
-use pyo3::types::PyDict;
-use pyo3::{intern, types::PyModule, PyObject, PyResult, Python, ToPyObject};
+use pyo3::{
+intern,
+types::{PyDict, PyModule},
+IntoPy, PyObject, PyResult, Python, ToPyObject,
+};
+use pyo3_asyncio::tokio::into_future;
use std::path;
/// A device that can send/receive HCI frames.
@@ -65,7 +71,7 @@ impl Device {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "power_on"))
-.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
+.and_then(|coroutine| into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
@@ -76,12 +82,28 @@ impl Device {
Python::with_gil(|py| {
self.0
.call_method1(py, intern!(py, "connect"), (peer_addr,))
-.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
+.and_then(|coroutine| into_future(coroutine.as_ref(py)))
})?
.await
.map(Connection)
}
+/// Register a callback to be called for each incoming connection.
+pub fn on_connection(
+&mut self,
+callback: impl Fn(Python, Connection) -> PyResult<()> + Send + 'static,
+) -> PyResult<()> {
+let boxed = ClosureCallback::new(move |py, args, _kwargs| {
+callback(py, Connection(args.get_item(0)?.into()))
+});
+Python::with_gil(|py| {
+self.0
+.call_method1(py, intern!(py, "add_listener"), ("connection", boxed))
+})
+.map(|_| ())
+}
/// Start scanning
pub async fn start_scanning(&self, filter_duplicates: bool) -> PyResult<()> {
Python::with_gil(|py| {
@@ -89,7 +111,7 @@ impl Device {
kwargs.set_item("filter_duplicates", filter_duplicates)?;
self.0
.call_method(py, intern!(py, "start_scanning"), (), Some(kwargs))
-.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
+.and_then(|coroutine| into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
@@ -123,6 +145,15 @@ impl Device {
.map(|_| ())
}
+/// Returns the host used by the device, if any
+pub fn host(&mut self) -> PyResult<Option<Host>> {
+Python::with_gil(|py| {
+self.0
+.getattr(py, intern!(py, "host"))
+.map(|obj| obj.into_option(Host::from))
+})
+}
/// Start advertising the data set with [Device.set_advertisement].
pub async fn start_advertising(&mut self, auto_restart: bool) -> PyResult<()> {
Python::with_gil(|py| {
@@ -131,7 +162,7 @@ impl Device {
self.0
.call_method(py, intern!(py, "start_advertising"), (), Some(kwargs))
-.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
+.and_then(|coroutine| into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
@@ -142,16 +173,114 @@ impl Device {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "stop_advertising"))
-.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
+.and_then(|coroutine| into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
+/// Registers an L2CAP connection oriented channel server. When a client connects to the server,
+/// the `server` callback is passed a handle to the established channel. When optional arguments
+/// are not specified, the Python module specifies the defaults.
+pub fn register_l2cap_channel_server(
+&mut self,
+psm: u16,
+server: impl Fn(Python, LeConnectionOrientedChannel) -> PyResult<()> + Send + 'static,
+max_credits: Option<u16>,
+mtu: Option<u16>,
+mps: Option<u16>,
+) -> PyResult<()> {
+Python::with_gil(|py| {
+let boxed = ClosureCallback::new(move |py, args, _kwargs| {
+server(
+py,
+LeConnectionOrientedChannel::from(args.get_item(0)?.into()),
+)
+});
+let kwargs = PyDict::new(py);
+kwargs.set_item("psm", psm)?;
+kwargs.set_item("server", boxed.into_py(py))?;
+kwargs.set_opt_item("max_credits", max_credits)?;
+kwargs.set_opt_item("mtu", mtu)?;
+kwargs.set_opt_item("mps", mps)?;
+self.0.call_method(
+py,
+intern!(py, "register_l2cap_channel_server"),
+(),
+Some(kwargs),
+)
+})?;
+Ok(())
+}
}
/// A connection to a remote device.
pub struct Connection(PyObject);
+impl Connection {
+/// Open an L2CAP channel using this connection. When optional arguments are not specified, the
+/// Python module specifies the defaults.
+pub async fn open_l2cap_channel(
+&mut self,
+psm: u16,
+max_credits: Option<u16>,
+mtu: Option<u16>,
+mps: Option<u16>,
+) -> PyResult<LeConnectionOrientedChannel> {
+Python::with_gil(|py| {
+let kwargs = PyDict::new(py);
+kwargs.set_item("psm", psm)?;
+kwargs.set_opt_item("max_credits", max_credits)?;
+kwargs.set_opt_item("mtu", mtu)?;
+kwargs.set_opt_item("mps", mps)?;
+self.0
+.call_method(py, intern!(py, "open_l2cap_channel"), (), Some(kwargs))
+.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
+})?
+.await
+.map(LeConnectionOrientedChannel::from)
+}
+/// Disconnect from device with provided reason. When optional arguments are not specified, the
+/// Python module specifies the defaults.
+pub async fn disconnect(&mut self, reason: Option<HciErrorCode>) -> PyResult<()> {
+Python::with_gil(|py| {
+let kwargs = PyDict::new(py);
+kwargs.set_opt_item("reason", reason)?;
+self.0
+.call_method(py, intern!(py, "disconnect"), (), Some(kwargs))
+.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
+})?
+.await
+.map(|_| ())
+}
+/// Register a callback to be called on disconnection.
+pub fn on_disconnection(
+&mut self,
+callback: impl Fn(Python, HciErrorCode) -> PyResult<()> + Send + 'static,
+) -> PyResult<()> {
+let boxed = ClosureCallback::new(move |py, args, _kwargs| {
+callback(py, args.get_item(0)?.extract()?)
+});
+Python::with_gil(|py| {
+self.0
+.call_method1(py, intern!(py, "add_listener"), ("disconnection", boxed))
+})
+.map(|_| ())
+}
+/// Returns some information about the connection as a [String].
+pub fn debug_string(&self) -> PyResult<String> {
+Python::with_gil(|py| {
+let str_obj = self.0.call_method0(py, intern!(py, "__str__"))?;
+str_obj.gil_ref(py).extract()
+})
+}
+}
/// The other end of a connection
pub struct Peer(PyObject);
@@ -173,7 +302,7 @@ impl Peer {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "discover_services"))
-.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
+.and_then(|coroutine| into_future(coroutine.as_ref(py)))
})?
.await
.and_then(|list| {
@@ -207,13 +336,7 @@ impl Peer {
let class = module.getattr(P::PROXY_CLASS_NAME)?;
self.0
.call_method1(py, intern!(py, "create_service_proxy"), (class,))
-.map(|obj| {
-if obj.is_none(py) {
-None
-} else {
-Some(P::wrap(obj))
-}
-})
+.map(|obj| obj.into_option(P::wrap))
})
}
}


@@ -0,0 +1,17 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Device drivers
pub mod rtk;


@@ -0,0 +1,141 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Drivers for Realtek controllers
use crate::wrapper::{host::Host, PyObjectExt};
use pyo3::{intern, types::PyModule, PyObject, PyResult, Python, ToPyObject};
use pyo3_asyncio::tokio::into_future;
pub use crate::internal::drivers::rtk::{Firmware, Patch};
/// Driver for a Realtek controller
pub struct Driver(PyObject);
impl Driver {
/// Locate the driver for the provided host.
pub async fn for_host(host: &Host, force: bool) -> PyResult<Option<Self>> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.drivers.rtk"))?
.getattr(intern!(py, "Driver"))?
.call_method1(intern!(py, "for_host"), (&host.obj, force))
.and_then(into_future)
})?
.await
.map(|obj| obj.into_option(Self))
}
/// Check if the host has a known driver.
pub async fn check(host: &Host) -> PyResult<bool> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.drivers.rtk"))?
.getattr(intern!(py, "Driver"))?
.call_method1(intern!(py, "check"), (&host.obj,))
.and_then(|obj| obj.extract::<bool>())
})
}
/// Find the [DriverInfo] for the host, if one matches
pub async fn driver_info_for_host(host: &Host) -> PyResult<Option<DriverInfo>> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.drivers.rtk"))?
.getattr(intern!(py, "Driver"))?
.call_method1(intern!(py, "driver_info_for_host"), (&host.obj,))
.and_then(into_future)
})?
.await
.map(|obj| obj.into_option(DriverInfo))
}
/// Send a command to the device to drop its firmware.
pub async fn drop_firmware(host: &mut Host) -> PyResult<()> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.drivers.rtk"))?
.getattr(intern!(py, "Driver"))?
.call_method1(intern!(py, "drop_firmware"), (&host.obj,))
.and_then(into_future)
})?
.await
.map(|_| ())
}
/// Load firmware onto the device.
pub async fn download_firmware(&mut self) -> PyResult<()> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "download_firmware"))
.and_then(|coroutine| into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
}
/// Metadata about a known driver & applicable device
pub struct DriverInfo(PyObject);
impl DriverInfo {
/// Returns a list of all drivers that Bumble knows how to handle.
pub fn all_drivers() -> PyResult<Vec<DriverInfo>> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.drivers.rtk"))?
.getattr(intern!(py, "Driver"))?
.getattr(intern!(py, "DRIVER_INFOS"))?
.iter()?
.map(|r| r.map(|h| DriverInfo(h.to_object(py))))
.collect::<PyResult<Vec<_>>>()
})
}
/// The firmware file name to load from the filesystem, e.g. `foo_fw.bin`.
pub fn firmware_name(&self) -> PyResult<String> {
Python::with_gil(|py| {
self.0
.getattr(py, intern!(py, "fw_name"))?
.as_ref(py)
.extract::<String>()
})
}
/// The config file name, if any, to load from the filesystem, e.g. `foo_config.bin`.
pub fn config_name(&self) -> PyResult<Option<String>> {
Python::with_gil(|py| {
let obj = self.0.getattr(py, intern!(py, "config_name"))?;
let handle = obj.as_ref(py);
if handle.is_none() {
Ok(None)
} else {
handle
.extract::<String>()
.map(|s| if s.is_empty() { None } else { Some(s) })
}
})
}
/// Whether or not config is required.
pub fn config_needed(&self) -> PyResult<bool> {
Python::with_gil(|py| {
self.0
.getattr(py, intern!(py, "config_needed"))?
.as_ref(py)
.extract::<bool>()
})
}
/// ROM id
pub fn rom(&self) -> PyResult<u32> {
Python::with_gil(|py| self.0.getattr(py, intern!(py, "rom"))?.as_ref(py).extract())
}
}


@@ -15,7 +15,40 @@
//! HCI
use itertools::Itertools as _;
-use pyo3::{exceptions::PyException, intern, types::PyModule, PyErr, PyObject, PyResult, Python};
+use pyo3::{
+    exceptions::PyException, intern, types::PyModule, FromPyObject, PyAny, PyErr, PyObject,
+    PyResult, Python, ToPyObject,
+};
/// HCI error code.
pub struct HciErrorCode(u8);
impl<'source> FromPyObject<'source> for HciErrorCode {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
Ok(HciErrorCode(ob.extract()?))
}
}
impl ToPyObject for HciErrorCode {
fn to_object(&self, py: Python<'_>) -> PyObject {
self.0.to_object(py)
}
}
/// Provides helpers for interacting with HCI
pub struct HciConstant;
impl HciConstant {
/// Human-readable error name
pub fn error_name(status: HciErrorCode) -> PyResult<String> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.hci"))?
.getattr(intern!(py, "HCI_Constant"))?
.call_method1(intern!(py, "error_name"), (status.0,))?
.extract()
})
}
}
/// A Bluetooth address
pub struct Address(pub(crate) PyObject);

rust/src/wrapper/host.rs Normal file

@@ -0,0 +1,71 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Host-side types
use crate::wrapper::transport::{Sink, Source};
use pyo3::{intern, prelude::PyModule, types::PyDict, PyObject, PyResult, Python};
/// Host HCI commands
pub struct Host {
pub(crate) obj: PyObject,
}
impl Host {
/// Create a Host that wraps the provided obj
pub(crate) fn from(obj: PyObject) -> Self {
Self { obj }
}
/// Create a new Host
pub fn new(source: Source, sink: Sink) -> PyResult<Self> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.host"))?
.getattr(intern!(py, "Host"))?
.call((source.0, sink.0), None)
.map(|any| Self { obj: any.into() })
})
}
/// Send a reset command and perform other reset tasks.
pub async fn reset(&mut self, driver_factory: DriverFactory) -> PyResult<()> {
Python::with_gil(|py| {
let kwargs = match driver_factory {
DriverFactory::None => {
let kw = PyDict::new(py);
kw.set_item("driver_factory", py.None())?;
Some(kw)
}
DriverFactory::Auto => {
// leave the default in place
None
}
};
self.obj
.call_method(py, intern!(py, "reset"), (), kwargs)
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
}
/// Driver factory to use when initializing a host
#[derive(Debug, Clone)]
pub enum DriverFactory {
/// Do not load drivers
None,
/// Load appropriate driver, if any is found
Auto,
}

rust/src/wrapper/l2cap.rs Normal file

@@ -0,0 +1,92 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! L2CAP
use crate::wrapper::{ClosureCallback, PyObjectExt};
use pyo3::{intern, PyObject, PyResult, Python};
/// L2CAP connection-oriented channel
pub struct LeConnectionOrientedChannel(PyObject);
impl LeConnectionOrientedChannel {
/// Create a LeConnectionOrientedChannel that wraps the provided obj.
pub(crate) fn from(obj: PyObject) -> Self {
Self(obj)
}
/// Queues data to be automatically sent across this channel.
pub fn write(&mut self, data: &[u8]) -> PyResult<()> {
Python::with_gil(|py| self.0.call_method1(py, intern!(py, "write"), (data,))).map(|_| ())
}
/// Wait for queued data to be sent on this channel.
pub async fn drain(&mut self) -> PyResult<()> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "drain"))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
/// Register a callback to be called when the channel is closed.
pub fn on_close(
&mut self,
callback: impl Fn(Python) -> PyResult<()> + Send + 'static,
) -> PyResult<()> {
let boxed = ClosureCallback::new(move |py, _args, _kwargs| callback(py));
Python::with_gil(|py| {
self.0
.call_method1(py, intern!(py, "add_listener"), ("close", boxed))
})
.map(|_| ())
}
/// Register a callback to be called when the channel receives data.
pub fn set_sink(
&mut self,
callback: impl Fn(Python, &[u8]) -> PyResult<()> + Send + 'static,
) -> PyResult<()> {
let boxed = ClosureCallback::new(move |py, args, _kwargs| {
callback(py, args.get_item(0)?.extract()?)
});
Python::with_gil(|py| self.0.setattr(py, intern!(py, "sink"), boxed)).map(|_| ())
}
/// Disconnect the L2CAP channel.
/// Must be called from a thread with a Python event loop, which should be true on
/// `tokio::main` and `async_std::main`.
///
/// For more info, see https://awestlake87.github.io/pyo3-asyncio/master/doc/pyo3_asyncio/#event-loop-references-and-contextvars.
pub async fn disconnect(&mut self) -> PyResult<()> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "disconnect"))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
/// Returns some information about the channel as a [String].
pub fn debug_string(&self) -> PyResult<String> {
Python::with_gil(|py| {
let str_obj = self.0.call_method0(py, intern!(py, "__str__"))?;
str_obj.gil_ref(py).extract()
})
}
}


@@ -31,14 +31,17 @@ pub use pyo3_asyncio;
pub mod assigned_numbers;
pub mod core;
pub mod device;
+pub mod drivers;
pub mod gatt_client;
pub mod hci;
+pub mod host;
+pub mod l2cap;
pub mod logging;
pub mod profile;
pub mod transport;
/// Convenience extensions to [PyObject]
-pub trait PyObjectExt {
+pub trait PyObjectExt: Sized {
    /// Get a GIL-bound reference
    fn gil_ref<'py>(&'py self, py: Python<'py>) -> &'py PyAny;
@@ -49,6 +52,17 @@ pub trait PyObjectExt {
    {
        Python::with_gil(|py| self.gil_ref(py).extract::<T>())
    }

    /// If the Python object is a Python `None`, return a Rust `None`, otherwise `Some` with the mapped type
    fn into_option<T>(self, map_obj: impl Fn(Self) -> T) -> Option<T> {
        Python::with_gil(|py| {
            if self.gil_ref(py).is_none() {
                None
            } else {
                Some(map_obj(self))
            }
        })
    }
}
impl PyObjectExt for PyObject {
@@ -57,6 +71,21 @@ impl PyObjectExt for PyObject {
    }
}

/// Convenience extensions to [PyDict]
pub trait PyDictExt {
    /// Set item in dict only if value is Some, otherwise do nothing.
    fn set_opt_item<K: ToPyObject, V: ToPyObject>(&self, key: K, value: Option<V>) -> PyResult<()>;
}

impl PyDictExt for PyDict {
    fn set_opt_item<K: ToPyObject, V: ToPyObject>(&self, key: K, value: Option<V>) -> PyResult<()> {
        if let Some(value) = value {
            self.set_item(key, value)?
        }
        Ok(())
    }
}
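The two helpers above adapt Python's None-or-value convention to Rust. Seen from the Python side, `set_opt_item` is just a conditional assignment; a minimal sketch (the helper name mirrors the Rust trait method, everything else here is illustrative, not bumble API):

```python
def set_opt_item(d: dict, key, value) -> None:
    # Mirror of PyDictExt::set_opt_item: only set the key when the value is
    # present, so optional kwargs are omitted rather than passed as None.
    if value is not None:
        d[key] = value

kwargs = {}
set_opt_item(kwargs, 'driver_factory', None)  # omitted
set_opt_item(kwargs, 'timeout', 5)            # included
```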
/// Wrapper to make Rust closures ([Fn] implementations) callable from Python.
///
/// The Python callable form returns a Python `None`.


@@ -14,7 +14,10 @@
//! GATT profiles
-use crate::wrapper::gatt_client::{CharacteristicProxy, ProfileServiceProxy};
+use crate::wrapper::{
+    gatt_client::{CharacteristicProxy, ProfileServiceProxy},
+    PyObjectExt,
+};
use pyo3::{intern, PyObject, PyResult, Python};
/// Exposes the battery GATT service
@@ -26,13 +29,7 @@ impl BatteryServiceProxy {
        Python::with_gil(|py| {
            self.0
                .getattr(py, intern!(py, "battery_level"))
-               .map(|level| {
-                   if level.is_none(py) {
-                       None
-                   } else {
-                       Some(CharacteristicProxy(level))
-                   }
-               })
+               .map(|level| level.into_option(CharacteristicProxy))
        })
    }
}


@@ -0,0 +1,97 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! This tool generates Rust code with assigned number tables from the equivalent Python.
use pyo3::{
intern,
types::{PyDict, PyModule},
PyResult, Python,
};
use std::{collections, env, fs, path};
fn main() -> anyhow::Result<()> {
pyo3::prepare_freethreaded_python();
let mut dir = path::Path::new(&env::var("CARGO_MANIFEST_DIR")?).to_path_buf();
dir.push("src/wrapper/assigned_numbers");
company_ids(&dir)?;
Ok(())
}
fn company_ids(base_dir: &path::Path) -> anyhow::Result<()> {
let mut sorted_ids = load_company_ids()?.into_iter().collect::<Vec<_>>();
sorted_ids.sort_by_key(|(id, _name)| *id);
let mut contents = String::new();
contents.push_str(LICENSE_HEADER);
contents.push_str("\n\n");
contents.push_str(
"// auto-generated by gen_assigned_numbers, do not edit
use crate::wrapper::core::Uuid16;
use lazy_static::lazy_static;
use std::collections;
lazy_static! {
/// Assigned company IDs
pub static ref COMPANY_IDS: collections::HashMap<Uuid16, &'static str> = [
",
);
for (id, name) in sorted_ids {
contents.push_str(&format!(" ({id}_u16, r#\"{name}\"#),\n"))
}
contents.push_str(
" ]
.into_iter()
.map(|(id, name)| (Uuid16::from_be_bytes(id.to_be_bytes()), name))
.collect();
}
",
);
let mut company_ids = base_dir.to_path_buf();
company_ids.push("company_ids.rs");
fs::write(&company_ids, contents)?;
Ok(())
}
fn load_company_ids() -> PyResult<collections::HashMap<u16, String>> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.company_ids"))?
.getattr(intern!(py, "COMPANY_IDENTIFIERS"))?
.downcast::<PyDict>()?
.into_iter()
.map(|(k, v)| Ok((k.extract::<u16>()?, v.str()?.to_str()?.to_string())))
.collect::<PyResult<collections::HashMap<_, _>>>()
})
}
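The generator's output format can be illustrated with a small Python analogue of the loop above (the helper name is hypothetical; the line format matches the `({id}_u16, r#"{name}"#)` entries the Rust tool emits):

```python
def emit_company_id_lines(company_ids: dict) -> list:
    # Sort by numeric ID, then render each entry as a Rust tuple literal.
    # Raw strings (r#"..."#) keep names containing quotes valid in Rust.
    return [
        f'    ({cid}_u16, r#"{name}"#),'
        for cid, name in sorted(company_ids.items())
    ]

lines = emit_company_id_lines({0x004C: 'Apple, Inc.', 0x0006: 'Microsoft'})
```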
const LICENSE_HEADER: &str = r#"// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License."#;


@@ -32,17 +32,18 @@ package_dir =
include_package_data = True
install_requires =
    aiohttp ~= 3.8; platform_system!='Emscripten'
-   appdirs >= 1.4
+   appdirs >= 1.4; platform_system!='Emscripten'
-   bt-test-interfaces >= 0.0.2
+   bt-test-interfaces >= 0.0.2; platform_system!='Emscripten'
    click == 8.1.3; platform_system!='Emscripten'
-   cryptography == 35; platform_system!='Emscripten'
+   cryptography == 39; platform_system!='Emscripten'
-   grpcio == 1.51.1; platform_system!='Emscripten'
+   grpcio == 1.57.0; platform_system!='Emscripten'
-   humanize >= 4.6.0
+   humanize >= 4.6.0; platform_system!='Emscripten'
    libusb1 >= 2.0.1; platform_system!='Emscripten'
    libusb-package == 1.0.26.1; platform_system!='Emscripten'
+   platformdirs == 3.10.0; platform_system!='Emscripten'
    prompt_toolkit >= 3.0.16; platform_system!='Emscripten'
-   prettytable >= 3.6.0
+   prettytable >= 3.6.0; platform_system!='Emscripten'
-   protobuf >= 3.12.4
+   protobuf >= 3.12.4; platform_system!='Emscripten'
    pyee >= 8.2.2
    pyserial-asyncio >= 0.5; platform_system!='Emscripten'
    pyserial >= 3.5; platform_system!='Emscripten'
@@ -81,7 +82,7 @@ test =
    coverage >= 6.4
development =
    black == 22.10
-   grpcio-tools >= 1.51.1
+   grpcio-tools >= 1.57.0
    invoke >= 1.7.3
    mypy == 1.2.0
    nox >= 2022


@@ -1,28 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<title>Audio WAV Player</title>
</head>
<body>
<h1>Audio WAV Player</h1>
<audio id="audioPlayer" controls>
<source src="" type="audio/wav">
</audio>
<script>
const audioPlayer = document.getElementById('audioPlayer');
const ws = new WebSocket('ws://localhost:8080');
let mediaSource = new MediaSource();
audioPlayer.src = URL.createObjectURL(mediaSource);
mediaSource.addEventListener('sourceopen', function(event) {
const sourceBuffer = mediaSource.addSourceBuffer('audio/wav');
ws.onmessage = function(event) {
sourceBuffer.appendBuffer(event.data);
};
});
</script>
</body>
</html>


@@ -177,3 +177,33 @@ project_tasks.add_task(lint)
project_tasks.add_task(format_code, name="format")
project_tasks.add_task(check_types, name="check-types")
project_tasks.add_task(pre_commit)
# -----------------------------------------------------------------------------
# Web
# -----------------------------------------------------------------------------
web_tasks = Collection()
ns.add_collection(web_tasks, name="web")
# -----------------------------------------------------------------------------
@task
def serve(ctx, port=8000):
"""
Run a simple HTTP server for the examples under the `web` directory.
"""
import http.server
address = ("", port)
class Handler(http.server.SimpleHTTPRequestHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, directory="web", **kwargs)
server = http.server.HTTPServer(address, Handler)
print(f"Now serving on port {port} 🕸️")
server.serve_forever()
# -----------------------------------------------------------------------------
web_tasks.add_task(serve)

tests/__init__.py Normal file

@@ -0,0 +1,13 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

tests/at_test.py Normal file

@@ -0,0 +1,35 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from bumble import at
def test_tokenize_parameters():
    assert at.tokenize_parameters(b'1, 2, 3') == [b'1', b',', b'2', b',', b'3']
    assert at.tokenize_parameters(b'"1, 2, 3"') == [b'1, 2, 3']
    assert at.tokenize_parameters(b'(1, "2, 3")') == [b'(', b'1', b',', b'2, 3', b')']


def test_parse_parameters():
    assert at.parse_parameters(b'1, 2, 3') == [b'1', b'2', b'3']
    assert at.parse_parameters(b'1,, 3') == [b'1', b'', b'3']
    assert at.parse_parameters(b'"1, 2, 3"') == [b'1, 2, 3']
    assert at.parse_parameters(b'1, (2, (3))') == [b'1', [b'2', [b'3']]]
    assert at.parse_parameters(b'1, (2, "3, 4"), 5') == [b'1', [b'2', b'3, 4'], b'5']


# -----------------------------------------------------------------------------
if __name__ == '__main__':
    test_tokenize_parameters()
    test_parse_parameters()
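The tokenization rules exercised above can be sketched in a few lines. This is a hedged reimplementation matching the test expectations, not the actual `bumble.at` code: separators (`,`, `(`, `)`) become their own tokens, whitespace is skipped, and a double-quoted span collapses to a single token with the quotes stripped.

```python
def tokenize_parameters(buffer: bytes) -> list:
    # Sketch of AT parameter tokenization, mirroring the assertions above.
    tokens = []
    i = 0
    while i < len(buffer):
        byte = buffer[i:i + 1]
        if byte == b' ':
            # whitespace between tokens is not significant
            i += 1
        elif byte == b'"':
            # quoted string: everything up to the closing quote is one token
            end = buffer.index(b'"', i + 1)
            tokens.append(buffer[i + 1:end])
            i = end + 1
        elif byte in (b',', b'(', b')'):
            # separators are tokens in their own right
            tokens.append(byte)
            i += 1
        else:
            # plain run of characters up to the next separator
            start = i
            while i < len(buffer) and buffer[i:i + 1] not in (b' ', b'"', b',', b'(', b')'):
                i += 1
            tokens.append(buffer[start:i])
    return tokens
```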


@@ -46,6 +46,7 @@ from bumble.hci import (
    HCI_LE_Set_Advertising_Parameters_Command,
    HCI_LE_Set_Default_PHY_Command,
    HCI_LE_Set_Event_Mask_Command,
+   HCI_LE_Set_Extended_Advertising_Enable_Command,
    HCI_LE_Set_Extended_Scan_Parameters_Command,
    HCI_LE_Set_Random_Address_Command,
    HCI_LE_Set_Scan_Enable_Command,
@@ -422,6 +423,25 @@ def test_HCI_LE_Set_Extended_Scan_Parameters_Command():
    basic_check(command)
# -----------------------------------------------------------------------------
def test_HCI_LE_Set_Extended_Advertising_Enable_Command():
    command = HCI_Packet.from_bytes(
        bytes.fromhex('0139200e010301050008020600090307000a')
    )
    assert command.enable == 1
    assert command.advertising_handles == [1, 2, 3]
    assert command.durations == [5, 6, 7]
    assert command.max_extended_advertising_events == [8, 9, 10]

    command = HCI_LE_Set_Extended_Advertising_Enable_Command(
        enable=1,
        advertising_handles=[1, 2, 3],
        durations=[5, 6, 7],
        max_extended_advertising_events=[8, 9, 10],
    )
    basic_check(command)
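The test packet above decodes by hand into the asserted fields. A hedged sketch of the parameter layout (assuming the standard little-endian HCI encoding: a 4-byte command header, then enable, a set count, and one handle/duration/max-events entry per set):

```python
import struct

def parse_ext_adv_enable_params(params: bytes):
    # Parameters of HCI_LE_Set_Extended_Advertising_Enable:
    # enable (1 byte), num_sets (1 byte), then per advertising set:
    # advertising_handle (1 byte), duration (2 bytes, little-endian),
    # max_extended_advertising_events (1 byte).
    enable, num_sets = params[0], params[1]
    handles, durations, max_events = [], [], []
    offset = 2
    for _ in range(num_sets):
        handle, duration, max_ev = struct.unpack_from('<BHB', params, offset)
        handles.append(handle)
        durations.append(duration)
        max_events.append(max_ev)
        offset += 4
    return enable, handles, durations, max_events

# The test packet, minus the 4-byte HCI command header (01 3920 0e):
params = bytes.fromhex('010301050008020600090307000a')
```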
# -----------------------------------------------------------------------------
def test_address():
    a = Address('C4:F2:17:1A:1D:BB')
@@ -478,6 +498,7 @@ def run_test_commands():
    test_HCI_LE_Read_Remote_Features_Command()
    test_HCI_LE_Set_Default_PHY_Command()
    test_HCI_LE_Set_Extended_Scan_Parameters_Command()
+   test_HCI_LE_Set_Extended_Advertising_Enable_Command()
# -----------------------------------------------------------------------------

tests/hfp_test.py Normal file

@@ -0,0 +1,100 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import pytest
from typing import Tuple
from .test_utils import TwoDevices
from bumble import hfp
from bumble import rfcomm
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
async def make_hfp_connections(
hf_config: hfp.Configuration,
) -> Tuple[hfp.HfProtocol, hfp.HfpProtocol]:
# Setup devices
devices = TwoDevices()
await devices.setup_connection()
# Setup RFCOMM channel
wait_dlc = asyncio.get_running_loop().create_future()
rfcomm_channel = rfcomm.Server(devices.devices[0]).listen(
lambda dlc: wait_dlc.set_result(dlc)
)
assert devices.connections[0]
assert devices.connections[1]
client_mux = await rfcomm.Client(devices.devices[1], devices.connections[1]).start()
client_dlc = await client_mux.open_dlc(rfcomm_channel)
server_dlc = await wait_dlc
# Setup HFP connection
hf = hfp.HfProtocol(client_dlc, hf_config)
ag = hfp.HfpProtocol(server_dlc)
return hf, ag
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_slc():
hf_config = hfp.Configuration(
supported_hf_features=[], supported_hf_indicators=[], supported_audio_codecs=[]
)
hf, ag = await make_hfp_connections(hf_config)
async def ag_loop():
while line := await ag.next_line():
if line.startswith('AT+BRSF'):
ag.send_response_line('+BRSF: 0')
elif line.startswith('AT+CIND=?'):
ag.send_response_line(
'+CIND: ("call",(0,1)),("callsetup",(0-3)),("service",(0-1)),'
'("signal",(0-5)),("roam",(0,1)),("battchg",(0-5)),'
'("callheld",(0-2))'
)
elif line.startswith('AT+CIND?'):
ag.send_response_line('+CIND: 0,0,1,4,1,5,0')
ag.send_response_line('OK')
ag_task = asyncio.create_task(ag_loop())
await hf.initiate_slc()
ag_task.cancel()
# -----------------------------------------------------------------------------
async def run():
await test_slc()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())
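The `+CIND:` status line the test AG sends maps positionally onto the indicators declared in its `+CIND=?` response. A small sketch of reading those values (the helper name is hypothetical, not bumble API):

```python
def parse_cind_status(line: str) -> list:
    # '+CIND: 0,0,1,4,1,5,0' -> [0, 0, 1, 4, 1, 5, 0]; each position
    # corresponds to one indicator (call, callsetup, service, ...).
    prefix = '+CIND: '
    assert line.startswith(prefix)
    return [int(field) for field in line[len(prefix):].split(',')]

values = parse_cind_status('+CIND: 0,0,1,4,1,5,0')
```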


@@ -18,6 +18,8 @@
import asyncio
import json
import logging
+import pathlib
+import pytest
import tempfile
import os
@@ -83,87 +85,95 @@ JSON3 = """
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def test_basic(): @pytest.fixture
with tempfile.NamedTemporaryFile(mode="r+", encoding='utf-8') as file: def temporary_file():
keystore = JsonKeyStore('my_namespace', file.name) file = tempfile.NamedTemporaryFile(delete=False)
file.close()
yield file.name
pathlib.Path(file.name).unlink()
# -----------------------------------------------------------------------------
async def test_basic(temporary_file):
with open(temporary_file, mode='w', encoding='utf-8') as file:
file.write("{}") file.write("{}")
file.flush() file.flush()
keys = await keystore.get_all() keystore = JsonKeyStore('my_namespace', temporary_file)
assert len(keys) == 0
keys = PairingKeys() keys = await keystore.get_all()
await keystore.update('foo', keys) assert len(keys) == 0
foo = await keystore.get('foo')
assert foo is not None
assert foo.ltk is None
ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
keys.ltk = PairingKeys.Key(ltk)
await keystore.update('foo', keys)
foo = await keystore.get('foo')
assert foo is not None
assert foo.ltk is not None
assert foo.ltk.value == ltk
file.flush() keys = PairingKeys()
with open(file.name, "r", encoding="utf-8") as json_file: await keystore.update('foo', keys)
json_data = json.load(json_file) foo = await keystore.get('foo')
assert 'my_namespace' in json_data assert foo is not None
assert 'foo' in json_data['my_namespace'] assert foo.ltk is None
assert 'ltk' in json_data['my_namespace']['foo'] ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
keys.ltk = PairingKeys.Key(ltk)
await keystore.update('foo', keys)
foo = await keystore.get('foo')
assert foo is not None
assert foo.ltk is not None
assert foo.ltk.value == ltk
with open(file.name, "r", encoding="utf-8") as json_file:
json_data = json.load(json_file)
assert 'my_namespace' in json_data
assert 'foo' in json_data['my_namespace']
assert 'ltk' in json_data['my_namespace']['foo']
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def test_parsing(): async def test_parsing(temporary_file):
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file: with open(temporary_file, mode='w', encoding='utf-8') as file:
keystore = JsonKeyStore('my_namespace', file.name)
file.write(JSON1) file.write(JSON1)
file.flush() file.flush()
foo = await keystore.get('14:7D:DA:4E:53:A8/P') keystore = JsonKeyStore('my_namespace', file.name)
assert foo is not None foo = await keystore.get('14:7D:DA:4E:53:A8/P')
assert foo.ltk.value == bytes.fromhex('d1897ee10016eb1a08e4e037fd54c683') assert foo is not None
assert foo.ltk.value == bytes.fromhex('d1897ee10016eb1a08e4e037fd54c683')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def test_default_namespace(): async def test_default_namespace(temporary_file):
with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file: with open(temporary_file, mode='w', encoding='utf-8') as file:
keystore = JsonKeyStore(None, file.name)
file.write(JSON1) file.write(JSON1)
file.flush() file.flush()
+        keystore = JsonKeyStore(None, file.name)
         all_keys = await keystore.get_all()
         assert len(all_keys) == 1
         name, keys = all_keys[0]
         assert name == '14:7D:DA:4E:53:A8/P'
         assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')

-    with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
-        keystore = JsonKeyStore(None, file.name)
+    with open(temporary_file, mode='w', encoding='utf-8') as file:
         file.write(JSON2)
         file.flush()
+        keystore = JsonKeyStore(None, file.name)
         keys = PairingKeys()
         ltk = bytes([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
         keys.ltk = PairingKeys.Key(ltk)
         await keystore.update('foo', keys)
-        file.flush()
         with open(file.name, "r", encoding="utf-8") as json_file:
             json_data = json.load(json_file)
         assert '__DEFAULT__' in json_data
         assert 'foo' in json_data['__DEFAULT__']
         assert 'ltk' in json_data['__DEFAULT__']['foo']

-    with tempfile.NamedTemporaryFile(mode="w", encoding='utf-8') as file:
-        keystore = JsonKeyStore(None, file.name)
+    with open(temporary_file, mode='w', encoding='utf-8') as file:
         file.write(JSON3)
         file.flush()
+        keystore = JsonKeyStore(None, file.name)
         all_keys = await keystore.get_all()
         assert len(all_keys) == 1
         name, keys = all_keys[0]
         assert name == '14:7D:DA:4E:53:A8/P'
         assert keys.irk.value == bytes.fromhex('e7b2543b206e4e46b44f9e51dad22bd1')

 # -----------------------------------------------------------------------------
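The switch above from `tempfile.NamedTemporaryFile` to a plain `open()` on a pre-created path matches the "windows NamedTemporaryFile compatibility" commit in this compare: on Windows, a file created by `NamedTemporaryFile` cannot be opened a second time by name while the original handle is still open, which breaks `JsonKeyStore` re-opening the file. A minimal stdlib sketch of the portable pattern (the helper name `round_trip_json` is illustrative, not part of the test suite):

```python
import json
import os
import tempfile

def round_trip_json(text: str) -> dict:
    """Write `text` to a named temp file, then re-open it by name.

    tempfile.NamedTemporaryFile is avoided because, on Windows, the file
    cannot be opened again by name while the first handle is open.
    """
    fd, path = tempfile.mkstemp(suffix='.json')
    os.close(fd)  # keep only the path; no handle stays open
    try:
        with open(path, mode='w', encoding='utf-8') as file:
            file.write(text)
            file.flush()
        # Re-opening by name now works on all platforms
        with open(path, mode='r', encoding='utf-8') as json_file:
            return json.load(json_file)
    finally:
        os.remove(path)  # mkstemp does not delete automatically

print(round_trip_json('{"__DEFAULT__": {"foo": {}}}'))
```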


@@ -21,13 +21,9 @@ import os
 import random

 import pytest

-from bumble.controller import Controller
-from bumble.link import LocalLink
-from bumble.device import Device
-from bumble.host import Host
-from bumble.transport import AsyncPipeSink
 from bumble.core import ProtocolError
 from bumble.l2cap import L2CAP_Connection_Request
+from .test_utils import TwoDevices

 # -----------------------------------------------------------------------------
@@ -37,60 +33,6 @@ logger = logging.getLogger(__name__)

 # -----------------------------------------------------------------------------
-class TwoDevices:
-    def __init__(self):
-        self.connections = [None, None]
-
-        self.link = LocalLink()
-        self.controllers = [
-            Controller('C1', link=self.link),
-            Controller('C2', link=self.link),
-        ]
-        self.devices = [
-            Device(
-                address='F0:F1:F2:F3:F4:F5',
-                host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
-            ),
-            Device(
-                address='F5:F4:F3:F2:F1:F0',
-                host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
-            ),
-        ]
-
-        self.paired = [None, None]
-
-    def on_connection(self, which, connection):
-        self.connections[which] = connection
-
-    def on_paired(self, which, keys):
-        self.paired[which] = keys
-
-
-# -----------------------------------------------------------------------------
-async def setup_connection():
-    # Create two devices, each with a controller, attached to the same link
-    two_devices = TwoDevices()
-
-    # Attach listeners
-    two_devices.devices[0].on(
-        'connection', lambda connection: two_devices.on_connection(0, connection)
-    )
-    two_devices.devices[1].on(
-        'connection', lambda connection: two_devices.on_connection(1, connection)
-    )
-
-    # Start
-    await two_devices.devices[0].power_on()
-    await two_devices.devices[1].power_on()
-
-    # Connect the two devices
-    await two_devices.devices[0].connect(two_devices.devices[1].random_address)
-
-    # Check the post conditions
-    assert two_devices.connections[0] is not None
-    assert two_devices.connections[1] is not None
-
-    return two_devices

 # -----------------------------------------------------------------------------
@@ -132,7 +74,8 @@ def test_helpers():
 # -----------------------------------------------------------------------------
 @pytest.mark.asyncio
 async def test_basic_connection():
-    devices = await setup_connection()
+    devices = TwoDevices()
+    await devices.setup_connection()

     psm = 1234

     # Check that if there's no one listening, we can't connect
@@ -184,7 +127,8 @@ async def test_basic_connection():

 # -----------------------------------------------------------------------------
 async def transfer_payload(max_credits, mtu, mps):
-    devices = await setup_connection()
+    devices = TwoDevices()
+    await devices.setup_connection()

     received = []
@@ -226,7 +170,8 @@ async def test_transfer():

 # -----------------------------------------------------------------------------
 @pytest.mark.asyncio
 async def test_bidirectional_transfer():
-    devices = await setup_connection()
+    devices = TwoDevices()
+    await devices.setup_connection()

     client_received = []
     server_received = []
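Each test now replaces the module-local `setup_connection()` free function with the shared `TwoDevices` helper from `test_utils` and its `setup_connection()` method. A stdlib-only sketch of the shape of that refactoring (class and field names here mirror the diff, but `FakeDevice` is illustrative; the real fixture powers on Bumble devices over a `LocalLink`):

```python
import asyncio

class FakeDevice:
    """Stand-in for a Bumble Device: just carries an address."""
    def __init__(self, address: str) -> None:
        self.address = address

class TwoDevices:
    """Shared test fixture: owns both devices and their connection state."""
    def __init__(self) -> None:
        self.devices = [
            FakeDevice('F0:F1:F2:F3:F4:F5'),
            FakeDevice('F5:F4:F3:F2:F1:F0'),
        ]
        self.connections = [None, None]

    async def setup_connection(self) -> None:
        # The real fixture powers on both devices and connects them over
        # a LocalLink; here we only mark both ends as connected.
        await asyncio.sleep(0)  # yield to the event loop, as real setup would
        self.connections = [self.devices[1].address, self.devices[0].address]
        assert self.connections[0] is not None
        assert self.connections[1] is not None

async def demo() -> None:
    # Same two-line pattern the diff introduces in each test
    devices = TwoDevices()
    await devices.setup_connection()
    assert devices.connections[0] == 'F5:F4:F3:F2:F1:F0'

asyncio.run(demo())
```

Moving the setup into a method lets every test file share one fixture instead of each defining its own copy, which is exactly what the removed `TwoDevices` class above was.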


@@ -15,15 +15,30 @@
 # -----------------------------------------------------------------------------
 # Imports
 # -----------------------------------------------------------------------------
-from bumble.core import UUID
-from bumble.sdp import DataElement
+import asyncio
+import logging
+import os
+
+from bumble.core import UUID, BT_L2CAP_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID
+from bumble.sdp import (
+    DataElement,
+    ServiceAttribute,
+    Client,
+    Server,
+    SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
+    SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
+    SDP_PUBLIC_BROWSE_ROOT,
+    SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
+    SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
+)
+from .test_utils import TwoDevices

 # -----------------------------------------------------------------------------
 # pylint: disable=invalid-name
 # -----------------------------------------------------------------------------

 # -----------------------------------------------------------------------------
-def basic_check(x):
+def basic_check(x: DataElement) -> None:
     serialized = bytes(x)
     if len(serialized) < 500:
         print('Original:', x)
@@ -41,7 +56,7 @@ def basic_check(x):

 # -----------------------------------------------------------------------------
-def test_data_elements():
+def test_data_elements() -> None:
     e = DataElement(DataElement.NIL, None)
     basic_check(e)
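`basic_check` serializes a `DataElement` and verifies it survives a round trip. A stdlib sketch of the same property-check pattern, using a toy type-length-value codec in place of Bumble's real `DataElement` wire format (the `encode`/`decode` helpers below are illustrative assumptions, not the SDP encoding):

```python
import struct

def encode(tag: int, payload: bytes) -> bytes:
    # Toy TLV encoding: 1-byte tag, 2-byte big-endian length, then the value
    return struct.pack('>BH', tag, len(payload)) + payload

def decode(data: bytes) -> tuple[int, bytes]:
    tag, length = struct.unpack_from('>BH', data)
    return tag, data[3:3 + length]

def basic_check(tag: int, payload: bytes) -> None:
    serialized = encode(tag, payload)
    parsed = decode(serialized)
    # The round trip must reproduce the original element exactly
    assert parsed == (tag, payload)

basic_check(0x19, b'\x00\x01')
```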
@@ -157,5 +172,108 @@ def test_data_elements():

 # -----------------------------------------------------------------------------
-if __name__ == '__main__':
+def sdp_records():
+    return {
+        0x00010001: [
+            ServiceAttribute(
+                SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
+                DataElement.unsigned_integer_32(0x00010001),
+            ),
+            ServiceAttribute(
+                SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
+                DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
+            ),
+            ServiceAttribute(
+                SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
+                DataElement.sequence(
+                    [DataElement.uuid(UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'))]
+                ),
+            ),
+            ServiceAttribute(
+                SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
+                DataElement.sequence(
+                    [
+                        DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
+                    ]
+                ),
+            ),
+        ]
+    }
+
+
+# -----------------------------------------------------------------------------
+async def test_service_search():
+    # Setup connections
+    devices = TwoDevices()
+    await devices.setup_connection()
+    assert devices.connections[0]
+    assert devices.connections[1]
+
+    # Register SDP service
+    devices.devices[0].sdp_server.service_records.update(sdp_records())
+
+    # Search for service
+    client = Client(devices.devices[1])
+    await client.connect(devices.connections[1])
+    services = await client.search_services(
+        [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')]
+    )
+
+    # Then
+    assert services[0] == 0x00010001
+
+
+# -----------------------------------------------------------------------------
+async def test_service_attribute():
+    # Setup connections
+    devices = TwoDevices()
+    await devices.setup_connection()
+
+    # Register SDP service
+    devices.devices[0].sdp_server.service_records.update(sdp_records())
+
+    # Search for service
+    client = Client(devices.devices[1])
+    await client.connect(devices.connections[1])
+    attributes = await client.get_attributes(
+        0x00010001, [SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID]
+    )
+
+    # Then
+    assert attributes[0].value.value == sdp_records()[0x00010001][0].value.value
+
+
+# -----------------------------------------------------------------------------
+async def test_service_search_attribute():
+    # Setup connections
+    devices = TwoDevices()
+    await devices.setup_connection()
+
+    # Register SDP service
+    devices.devices[0].sdp_server.service_records.update(sdp_records())
+
+    # Search for service
+    client = Client(devices.devices[1])
+    await client.connect(devices.connections[1])
+    attributes = await client.search_attributes(
+        [UUID('E6D55659-C8B4-4B85-96BB-B1143AF6D3AE')], [(0x0000FFFF, 8)]
+    )
+
+    # Then
+    for expect, actual in zip(attributes, sdp_records().values()):
+        assert expect.id == actual.id
+        assert expect.value == actual.value
+
+
+# -----------------------------------------------------------------------------
+async def run():
     test_data_elements()
+    await test_service_attribute()
+    await test_service_search()
+    await test_service_search_attribute()
+
+
+# -----------------------------------------------------------------------------
+if __name__ == '__main__':
+    logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
+    asyncio.run(run())
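In `test_service_search_attribute`, the `0x0000FFFF` value is an SDP attribute ID range: per the SDP specification, a 32-bit range packs the first attribute ID in the high 16 bits and the last in the low 16 bits, so `0x0000FFFF` requests every attribute from `0x0000` to `0xFFFF`. (How Bumble interprets the second tuple element, `8`, is not shown in this diff and is left aside here.) A small sketch of the decoding, with an illustrative helper name:

```python
def decode_attribute_id_range(value: int) -> tuple[int, int]:
    """Split a 32-bit SDP attribute ID range into (start, end).

    Per the SDP spec, an attribute list in a ServiceSearchAttributeRequest
    may contain 32-bit ranges: high 16 bits = first attribute ID,
    low 16 bits = last attribute ID.
    """
    start = (value >> 16) & 0xFFFF
    end = value & 0xFFFF
    assert start <= end, 'invalid attribute ID range'
    return start, end

# 0x0000FFFF, as used in test_service_search_attribute, covers every ID
print(decode_attribute_id_range(0x0000FFFF))  # → (0, 65535)
```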

Some files were not shown because too many files have changed in this diff.