Compare commits

...

43 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod
43e632f83c Merge pull request #244 from google/gbg/hci-source-termination-mode
add sink method for lost transports
2023-08-18 10:17:11 -07:00
Gilles Boccon-Gibod
623298b0e9 emit flush event when transport lost 2023-08-18 09:59:15 -07:00
Gilles Boccon-Gibod
6e8c44b5e6 Merge pull request #249 from zxzxwu/player
Support SBC in speaker.app
2023-08-18 09:55:23 -07:00
Josh Wu
ec4dcc174e Support SBC in speaker.app 2023-08-18 17:13:11 +08:00
Charlie Boutier
b247aca3b4 pandora_server: add support to accept bumble config file 2023-08-17 14:24:56 -07:00
Gilles Boccon-Gibod
f4add16aea Merge pull request #241 from hchataing/hfp-hf
hfp: Implement initiate SLC procedure for HFP-HF
2023-08-14 10:32:55 -07:00
Gilles Boccon-Gibod
2bfec3c4ed add sink method for lost transports 2023-08-12 10:54:20 -07:00
Henri Chataing
9963b51c04 hfp: Implement initiate SLC procedure for HFP-HF 2023-08-10 08:37:54 -07:00
Gilles Boccon-Gibod
fe28473ba8 Merge pull request #234 from zxzxwu/addr
Support address resolution offload
2023-08-08 21:30:13 -07:00
Gilles Boccon-Gibod
53d66bc74a Merge pull request #237 from marshallpierce/mp/company-ids
Faster company id table
2023-08-08 21:29:45 -07:00
Marshall Pierce
e2c1ad5342 Faster company id table
Following up on the [loose end from the initial
PR](https://github.com/google/bumble/pull/207#discussion_r1278015116),
we can avoid accessing the Python company id map at runtime by doing
code gen ahead of time.

Using an example to do the code gen avoids even the small build slowdown
from invoking the code gen logic in build.rs, but more importantly,
means that it's still a totally boring normal build that won't require
any IDE setup, etc, to work for everyone. Since the company ID list
changes rarely, and there's a test to ensure it always matches, this
seems like a good trade.
2023-08-04 10:12:52 -06:00
Josh Wu
6399c5fb04 Auto add device to resolving list after pairing 2023-08-03 20:51:00 +08:00
Josh Wu
784cf4f26a Add a flag to enable LE address resolution 2023-08-03 20:50:57 +08:00
Josh Wu
0301b1a999 Pandora: Configure identity address type 2023-08-02 11:31:07 -07:00
Lucas Abel
3ab2cd5e71 pandora: decrease all info logs to debug 2023-08-02 10:56:41 -07:00
uael
6ea669531a pandora: add tcp option to transport configuration
* Add a fallback to `tcp` when `transport` is not set.
* Default the `tcp` transport to the default rootcanal HCI address.
2023-08-01 08:51:12 -07:00
Josh Wu
cbbada4748 SMP: Delegate distributed address type 2023-08-01 08:38:03 -07:00
Gilles Boccon-Gibod
152b8d1233 Merge pull request #230 from google/gbg/hci-object-array
add support for field arrays in hci packet definitions
2023-08-01 07:44:31 -07:00
Gilles Boccon-Gibod
bdad225033 add support for field arrays in hci packet definitions 2023-07-30 22:19:10 -07:00
Gilles Boccon-Gibod
8eeb58e467 Merge pull request #207 from marshallpierce/mp/rust-poc
Proof-of-concept Rust wrapper
2023-07-28 20:14:23 -07:00
Marshall Pierce
91971433d2 PR feedback 2023-07-28 14:34:02 -06:00
Gilles Boccon-Gibod
a0a4bd457f Merge pull request #227 from google/gbg/py11
compatibility with python 11
2023-07-28 12:54:30 -07:00
Gilles Boccon-Gibod
4ffc050eed restore python < 11 compat 2023-07-27 16:37:27 -07:00
Gilles Boccon-Gibod
60678419a0 compatibility with python 11 2023-07-27 14:55:28 -07:00
Gilles Boccon-Gibod
648dcc9305 use type object instead of type strings 2023-07-27 13:19:37 -07:00
Josh Wu
190529184e L2CAP: Import device.Connection for typing 2023-07-27 09:07:55 -07:00
Josh Wu
46eb81466d Add more argement hints in L2CAP 2023-07-27 09:07:55 -07:00
Josh Wu
9c70c487b9 Add type hint to L2CAP module 2023-07-27 09:07:55 -07:00
Josh Wu
43234d7c3e Use with-patch to mock SMP session 2023-07-27 08:00:36 -07:00
Josh Wu
dbf878dc3f SMP: Remove PairingMethod.__str__ 2023-07-27 08:00:36 -07:00
Josh Wu
f6c0bd88d7 SMP: Do not send phase 2 commands in CTKD 2023-07-27 08:00:36 -07:00
Josh Wu
8440b7fbf1 SMP: Refactor pairing method as enum 2023-07-27 08:00:36 -07:00
Gilles Boccon-Gibod
808ab54135 Merge pull request #221 from google/gbg/core-classes
add new device class major/minor identifiers
2023-07-25 09:49:05 -07:00
Gilles Boccon-Gibod
52b29ad680 add new device class major/minor identifiers 2023-07-24 17:41:57 -07:00
Gilles Boccon-Gibod
d41bf9c587 Merge pull request #216 from google/gbg/host-buffer-size-command
accept Host Buffer Size Command in the controller
2023-07-24 09:05:10 -07:00
Gilles Boccon-Gibod
b758825164 add flow control command 2023-07-22 13:04:39 -07:00
Gilles Boccon-Gibod
779dfe5473 accept Host Buffer Size Command in the controller 2023-07-21 19:36:26 -07:00
Marshall Pierce
afb21220e2 Proof-of-concept Rust wrapper
This contains Rust wrappers around enough of the Python API to implement Rust versions of the `battery_client` and `run_scanner` examples. The goal is to gather feedback on the approach, and of course to show that it is possible.

The module structure mirrors that of the Python. The Rust API is not optimally Rust-y, but given the constraints of everything having to delegate to Python, it's at least usable.

Notably, this does not yet solve the packaging problem: users must have an appropriate virtualenv, libpython, etc. [PyOxidizer](https://github.com/indygreg/PyOxidizer) may be a viable path there.
2023-07-20 10:50:15 -06:00
Gilles Boccon-Gibod
f9a4c7518e Merge pull request #214 from marshallpierce/mp/scanner-rssi
Add a space after RSSI
2023-07-14 10:52:54 -07:00
Marshall Pierce
bad2fdf69f Add a space after RSSI
The other data elements have a space, so I'm guessing that RSSI
is intended to as well. Perhaps there's some subtle reason why
it should have a space, though, in which case feel free to
close this.

Output now looks like this:

```
>>> 58:D3:49:E7:40:DA/P [PUBLIC]:
  RSSI: -67
  [Flags]: LE General,BR/EDR C,BR/EDR H
  [TX Power Level]: 4
  [Manufacturer Specific Data]: company=Apple, Inc., data=0f08c00af4392b00040c10020f04
```
2023-07-13 12:47:45 -06:00
Lucas Abel
a84df469cd pairing: handle user errors from all delegate calls 2023-07-12 11:03:21 -07:00
Gilles Boccon-Gibod
03e33e39bd Merge pull request #211 from google/gbg/fix-ws-transport-doc
fix doc for ws-client ws-server transports
2023-07-12 07:06:32 -07:00
Gilles Boccon-Gibod
753fb69272 fix doc for ws-client ws-server transports 2023-07-12 06:06:20 -07:00
58 changed files with 8459 additions and 820 deletions

View File

@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]
fail-fast: false
steps:
@@ -41,3 +41,30 @@ jobs:
run: |
inv build
inv build.mkdocs
build-rust:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ "3.8", "3.9", "3.10" ]
fail-fast: false
steps:
- name: Check out from Git
uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[build,test,development,documentation]"
- name: Install Rust toolchain
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
components: clippy,rustfmt
- name: Rust Lints
run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings
- name: Rust Build
run: cd rust && cargo build --all-targets
- name: Rust Tests
run: cd rust && cargo test

1
.gitignore vendored
View File

@@ -9,3 +9,4 @@ __pycache__
# generated by setuptools_scm
bumble/_version.py
.vscode/launch.json
/.idea

View File

@@ -1,8 +1,10 @@
import asyncio
import click
import logging
import json
from bumble.pandora import PandoraDevice, serve
from typing import Dict, Any
BUMBLE_SERVER_GRPC_PORT = 7999
ROOTCANAL_PORT_CUTTLEFISH = 7300
@@ -18,13 +20,30 @@ ROOTCANAL_PORT_CUTTLEFISH = 7300
help='HCI transport',
default=f'tcp-client:127.0.0.1:<rootcanal-port>',
)
def main(grpc_port: int, rootcanal_port: int, transport: str) -> None:
@click.option(
'--config',
help='Bumble json configuration file',
)
def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> None:
if '<rootcanal-port>' in transport:
transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
device = PandoraDevice({'transport': transport})
bumble_config = retrieve_config(config)
if 'transport' not in bumble_config.keys():
bumble_config.update({'transport': transport})
device = PandoraDevice(bumble_config)
logging.basicConfig(level=logging.DEBUG)
asyncio.run(serve(device, port=grpc_port))
def retrieve_config(config: str) -> Dict[str, Any]:
if not config:
return {}
with open(config, 'r') as f:
return json.load(f)
if __name__ == '__main__':
main() # pylint: disable=no-value-for-parameter

View File

@@ -228,10 +228,11 @@ class FfplayOutput(QueuedOutput):
subprocess: Optional[asyncio.subprocess.Process]
ffplay_task: Optional[asyncio.Task]
def __init__(self) -> None:
super().__init__(AacAudioExtractor())
def __init__(self, codec: str) -> None:
super().__init__(AudioExtractor.create(codec))
self.subprocess = None
self.ffplay_task = None
self.codec = codec
async def start(self):
if self.started:
@@ -240,7 +241,7 @@ class FfplayOutput(QueuedOutput):
await super().start()
self.subprocess = await asyncio.create_subprocess_shell(
'ffplay -acodec aac pipe:0',
f'ffplay -f {self.codec} pipe:0',
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
@@ -419,7 +420,7 @@ class Speaker:
self.outputs = []
for output in outputs:
if output == '@ffplay':
self.outputs.append(FfplayOutput())
self.outputs.append(FfplayOutput(codec))
continue
# Default to FileOutput
@@ -708,17 +709,6 @@ def speaker(
):
"""Run the speaker."""
# ffplay only works with AAC for now
if codec != 'aac' and '@ffplay' in output:
print(
color(
f'{codec} not supported with @ffplay output, '
'@ffplay output will be skipped',
'yellow',
)
)
output = list(filter(lambda x: x != '@ffplay', output))
if '@ffplay' in output:
# Check if ffplay is installed
try:

85
bumble/at.py Normal file
View File

@@ -0,0 +1,85 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
def tokenize_parameters(buffer: bytes) -> List[bytes]:
"""Split input parameters into tokens.
Removes space characters outside of double quote blocks:
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
are ignored [..], unless they are embedded in numeric or string constants"
Raises ValueError in case of invalid input string."""
tokens = []
in_quotes = False
token = bytearray()
for b in buffer:
char = bytearray([b])
if in_quotes:
token.extend(char)
if char == b'\"':
in_quotes = False
tokens.append(token[1:-1])
token = bytearray()
else:
if char == b' ':
pass
elif char == b',' or char == b')':
tokens.append(token)
tokens.append(char)
token = bytearray()
elif char == b'(':
if len(token) > 0:
raise ValueError("open_paren following regular character")
tokens.append(char)
elif char == b'"':
if len(token) > 0:
raise ValueError("quote following regular character")
in_quotes = True
token.extend(char)
else:
token.extend(char)
tokens.append(token)
return [bytes(token) for token in tokens if len(token) > 0]
def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
"""Parse the parameters using the comma and parenthesis separators.
Raises ValueError in case of invalid input string."""
tokens = tokenize_parameters(buffer)
accumulator: List[list] = [[]]
current: Union[bytes, list] = bytes()
for token in tokens:
if token == b',':
accumulator[-1].append(current)
current = bytes()
elif token == b'(':
accumulator.append([])
elif token == b')':
if len(accumulator) < 2:
raise ValueError("close_paren without matching open_paren")
accumulator[-1].append(current)
current = accumulator.pop()
else:
current = token
accumulator[-1].append(current)
if len(accumulator) > 1:
raise ValueError("missing close_paren")
return accumulator[0]

View File

@@ -188,6 +188,8 @@ class Controller:
if link:
link.add_controller(self)
self.terminated = asyncio.get_running_loop().create_future()
@property
def host(self):
return self.hci_sink
@@ -288,10 +290,9 @@ class Controller:
if self.host:
self.host.on_packet(packet.to_bytes())
# This method allow the controller to emulate the same API as a transport source
# This method allows the controller to emulate the same API as a transport source
async def wait_for_termination(self):
# For now, just wait forever
await asyncio.get_running_loop().create_future()
await self.terminated
############################################################
# Link connections
@@ -654,7 +655,7 @@ class Controller:
def on_hci_create_connection_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.1.5 Create Connection command
See Bluetooth spec Vol 4, Part E - 7.1.5 Create Connection command
'''
if self.link is None:
@@ -685,7 +686,7 @@ class Controller:
def on_hci_disconnect_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.1.6 Disconnect Command
See Bluetooth spec Vol 4, Part E - 7.1.6 Disconnect Command
'''
# First, say that the disconnection is pending
self.send_hci_packet(
@@ -719,7 +720,7 @@ class Controller:
def on_hci_accept_connection_request_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.1.8 Accept Connection Request command
See Bluetooth spec Vol 4, Part E - 7.1.8 Accept Connection Request command
'''
if self.link is None:
@@ -735,7 +736,7 @@ class Controller:
def on_hci_switch_role_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.2.8 Switch Role command
See Bluetooth spec Vol 4, Part E - 7.2.8 Switch Role command
'''
if self.link is None:
@@ -751,21 +752,21 @@ class Controller:
def on_hci_set_event_mask_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.1 Set Event Mask Command
See Bluetooth spec Vol 4, Part E - 7.3.1 Set Event Mask Command
'''
self.event_mask = command.event_mask
return bytes([HCI_SUCCESS])
def on_hci_reset_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.2 Reset Command
See Bluetooth spec Vol 4, Part E - 7.3.2 Reset Command
'''
# TODO: cleanup what needs to be reset
return bytes([HCI_SUCCESS])
def on_hci_write_local_name_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.11 Write Local Name Command
See Bluetooth spec Vol 4, Part E - 7.3.11 Write Local Name Command
'''
local_name = command.local_name
if len(local_name):
@@ -780,7 +781,7 @@ class Controller:
def on_hci_read_local_name_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.12 Read Local Name Command
See Bluetooth spec Vol 4, Part E - 7.3.12 Read Local Name Command
'''
local_name = bytes(self.local_name, 'utf-8')[:248]
if len(local_name) < 248:
@@ -790,19 +791,19 @@ class Controller:
def on_hci_read_class_of_device_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.25 Read Class of Device Command
See Bluetooth spec Vol 4, Part E - 7.3.25 Read Class of Device Command
'''
return bytes([HCI_SUCCESS, 0, 0, 0])
def on_hci_write_class_of_device_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.26 Write Class of Device Command
See Bluetooth spec Vol 4, Part E - 7.3.26 Write Class of Device Command
'''
return bytes([HCI_SUCCESS])
def on_hci_read_synchronous_flow_control_enable_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.36 Read Synchronous Flow Control Enable
See Bluetooth spec Vol 4, Part E - 7.3.36 Read Synchronous Flow Control Enable
Command
'''
if self.sync_flow_control:
@@ -813,7 +814,7 @@ class Controller:
def on_hci_write_synchronous_flow_control_enable_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.37 Write Synchronous Flow Control Enable
See Bluetooth spec Vol 4, Part E - 7.3.37 Write Synchronous Flow Control Enable
Command
'''
ret = HCI_SUCCESS
@@ -825,41 +826,59 @@ class Controller:
ret = HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR
return bytes([ret])
def on_hci_set_controller_to_host_flow_control_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.3.38 Set Controller To Host Flow Control
Command
'''
# For now we just accept the command but ignore the values.
# TODO: respect the passed in values.
return bytes([HCI_SUCCESS])
def on_hci_host_buffer_size_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.3.39 Host Buffer Size Command
'''
# For now we just accept the command but ignore the values.
# TODO: respect the passed in values.
return bytes([HCI_SUCCESS])
def on_hci_write_extended_inquiry_response_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.59 Write Simple Pairing Mode Command
See Bluetooth spec Vol 4, Part E - 7.3.56 Write Extended Inquiry Response
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_write_simple_pairing_mode_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.59 Write Simple Pairing Mode Command
See Bluetooth spec Vol 4, Part E - 7.3.59 Write Simple Pairing Mode Command
'''
return bytes([HCI_SUCCESS])
def on_hci_set_event_mask_page_2_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.69 Set Event Mask Page 2 Command
See Bluetooth spec Vol 4, Part E - 7.3.69 Set Event Mask Page 2 Command
'''
self.event_mask_page_2 = command.event_mask_page_2
return bytes([HCI_SUCCESS])
def on_hci_read_le_host_support_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.78 Write LE Host Support Command
See Bluetooth spec Vol 4, Part E - 7.3.78 Write LE Host Support Command
'''
return bytes([HCI_SUCCESS, 1, 0])
def on_hci_write_le_host_support_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.79 Write LE Host Support Command
See Bluetooth spec Vol 4, Part E - 7.3.79 Write LE Host Support Command
'''
# TODO / Just ignore for now
return bytes([HCI_SUCCESS])
def on_hci_write_authenticated_payload_timeout_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.3.94 Write Authenticated Payload Timeout
See Bluetooth spec Vol 4, Part E - 7.3.94 Write Authenticated Payload Timeout
Command
'''
# TODO
@@ -867,7 +886,7 @@ class Controller:
def on_hci_read_local_version_information_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.4.1 Read Local Version Information Command
See Bluetooth spec Vol 4, Part E - 7.4.1 Read Local Version Information Command
'''
return struct.pack(
'<BBHBHH',
@@ -881,19 +900,19 @@ class Controller:
def on_hci_read_local_supported_commands_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.4.2 Read Local Supported Commands Command
See Bluetooth spec Vol 4, Part E - 7.4.2 Read Local Supported Commands Command
'''
return bytes([HCI_SUCCESS]) + self.supported_commands
def on_hci_read_local_supported_features_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.4.3 Read Local Supported Features Command
See Bluetooth spec Vol 4, Part E - 7.4.3 Read Local Supported Features Command
'''
return bytes([HCI_SUCCESS]) + self.lmp_features
def on_hci_read_bd_addr_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.4.6 Read BD_ADDR Command
See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command
'''
bd_addr = (
self._public_address.to_bytes()
@@ -904,14 +923,14 @@ class Controller:
def on_hci_le_set_event_mask_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.1 LE Set Event Mask Command
See Bluetooth spec Vol 4, Part E - 7.8.1 LE Set Event Mask Command
'''
self.le_event_mask = command.le_event_mask
return bytes([HCI_SUCCESS])
def on_hci_le_read_buffer_size_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.2 LE Read Buffer Size Command
See Bluetooth spec Vol 4, Part E - 7.8.2 LE Read Buffer Size Command
'''
return struct.pack(
'<BHB',
@@ -922,49 +941,49 @@ class Controller:
def on_hci_le_read_local_supported_features_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.3 LE Read Local Supported Features
See Bluetooth spec Vol 4, Part E - 7.8.3 LE Read Local Supported Features
Command
'''
return bytes([HCI_SUCCESS]) + self.le_features
def on_hci_le_set_random_address_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.4 LE Set Random Address Command
See Bluetooth spec Vol 4, Part E - 7.8.4 LE Set Random Address Command
'''
self.random_address = command.random_address
return bytes([HCI_SUCCESS])
def on_hci_le_set_advertising_parameters_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.5 LE Set Advertising Parameters Command
See Bluetooth spec Vol 4, Part E - 7.8.5 LE Set Advertising Parameters Command
'''
self.advertising_parameters = command
return bytes([HCI_SUCCESS])
def on_hci_le_read_advertising_physical_channel_tx_power_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.6 LE Read Advertising Physical Channel
See Bluetooth spec Vol 4, Part E - 7.8.6 LE Read Advertising Physical Channel
Tx Power Command
'''
return bytes([HCI_SUCCESS, self.advertising_channel_tx_power])
def on_hci_le_set_advertising_data_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.7 LE Set Advertising Data Command
See Bluetooth spec Vol 4, Part E - 7.8.7 LE Set Advertising Data Command
'''
self.advertising_data = command.advertising_data
return bytes([HCI_SUCCESS])
def on_hci_le_set_scan_response_data_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.8 LE Set Scan Response Data Command
See Bluetooth spec Vol 4, Part E - 7.8.8 LE Set Scan Response Data Command
'''
self.le_scan_response_data = command.scan_response_data
return bytes([HCI_SUCCESS])
def on_hci_le_set_advertising_enable_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.9 LE Set Advertising Enable Command
See Bluetooth spec Vol 4, Part E - 7.8.9 LE Set Advertising Enable Command
'''
if command.advertising_enable:
self.start_advertising()
@@ -975,7 +994,7 @@ class Controller:
def on_hci_le_set_scan_parameters_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.10 LE Set Scan Parameters Command
See Bluetooth spec Vol 4, Part E - 7.8.10 LE Set Scan Parameters Command
'''
self.le_scan_type = command.le_scan_type
self.le_scan_interval = command.le_scan_interval
@@ -986,7 +1005,7 @@ class Controller:
def on_hci_le_set_scan_enable_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.11 LE Set Scan Enable Command
See Bluetooth spec Vol 4, Part E - 7.8.11 LE Set Scan Enable Command
'''
self.le_scan_enable = command.le_scan_enable
self.filter_duplicates = command.filter_duplicates
@@ -994,7 +1013,7 @@ class Controller:
def on_hci_le_create_connection_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.12 LE Create Connection Command
See Bluetooth spec Vol 4, Part E - 7.8.12 LE Create Connection Command
'''
if not self.link:
@@ -1027,40 +1046,40 @@ class Controller:
def on_hci_le_create_connection_cancel_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.13 LE Create Connection Cancel Command
See Bluetooth spec Vol 4, Part E - 7.8.13 LE Create Connection Cancel Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_read_filter_accept_list_size_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.14 LE Read Filter Accept List Size
See Bluetooth spec Vol 4, Part E - 7.8.14 LE Read Filter Accept List Size
Command
'''
return bytes([HCI_SUCCESS, self.filter_accept_list_size])
def on_hci_le_clear_filter_accept_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.15 LE Clear Filter Accept List Command
See Bluetooth spec Vol 4, Part E - 7.8.15 LE Clear Filter Accept List Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_add_device_to_filter_accept_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.16 LE Add Device To Filter Accept List
See Bluetooth spec Vol 4, Part E - 7.8.16 LE Add Device To Filter Accept List
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_remove_device_from_filter_accept_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.17 LE Remove Device From Filter Accept
See Bluetooth spec Vol 4, Part E - 7.8.17 LE Remove Device From Filter Accept
List Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_read_remote_features_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.21 LE Read Remote Features Command
See Bluetooth spec Vol 4, Part E - 7.8.21 LE Read Remote Features Command
'''
# First, say that the command is pending
@@ -1083,13 +1102,13 @@ class Controller:
def on_hci_le_rand_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.23 LE Rand Command
See Bluetooth spec Vol 4, Part E - 7.8.23 LE Rand Command
'''
return bytes([HCI_SUCCESS]) + struct.pack('Q', random.randint(0, 1 << 64))
def on_hci_le_enable_encryption_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.24 LE Enable Encryption Command
See Bluetooth spec Vol 4, Part E - 7.8.24 LE Enable Encryption Command
'''
# Check the parameters
@@ -1122,13 +1141,13 @@ class Controller:
def on_hci_le_read_supported_states_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.27 LE Read Supported States Command
See Bluetooth spec Vol 4, Part E - 7.8.27 LE Read Supported States Command
'''
return bytes([HCI_SUCCESS]) + self.le_states
def on_hci_le_read_suggested_default_data_length_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.34 LE Read Suggested Default Data Length
See Bluetooth spec Vol 4, Part E - 7.8.34 LE Read Suggested Default Data Length
Command
'''
return struct.pack(
@@ -1140,7 +1159,7 @@ class Controller:
def on_hci_le_write_suggested_default_data_length_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.35 LE Write Suggested Default Data Length
See Bluetooth spec Vol 4, Part E - 7.8.35 LE Write Suggested Default Data Length
Command
'''
self.suggested_max_tx_octets, self.suggested_max_tx_time = struct.unpack(
@@ -1150,33 +1169,33 @@ class Controller:
def on_hci_le_read_local_p_256_public_key_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.36 LE Read P-256 Public Key Command
See Bluetooth spec Vol 4, Part E - 7.8.36 LE Read P-256 Public Key Command
'''
# TODO create key and send HCI_LE_Read_Local_P-256_Public_Key_Complete event
return bytes([HCI_SUCCESS])
def on_hci_le_add_device_to_resolving_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.38 LE Add Device To Resolving List
See Bluetooth spec Vol 4, Part E - 7.8.38 LE Add Device To Resolving List
Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_clear_resolving_list_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.40 LE Clear Resolving List Command
See Bluetooth spec Vol 4, Part E - 7.8.40 LE Clear Resolving List Command
'''
return bytes([HCI_SUCCESS])
def on_hci_le_read_resolving_list_size_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.41 LE Read Resolving List Size Command
See Bluetooth spec Vol 4, Part E - 7.8.41 LE Read Resolving List Size Command
'''
return bytes([HCI_SUCCESS, self.resolving_list_size])
def on_hci_le_set_address_resolution_enable_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.44 LE Set Address Resolution Enable
See Bluetooth spec Vol 4, Part E - 7.8.44 LE Set Address Resolution Enable
Command
'''
ret = HCI_SUCCESS
@@ -1190,7 +1209,7 @@ class Controller:
def on_hci_le_set_resolvable_private_address_timeout_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.45 LE Set Resolvable Private Address
See Bluetooth spec Vol 4, Part E - 7.8.45 LE Set Resolvable Private Address
Timeout Command
'''
self.le_rpa_timeout = command.rpa_timeout
@@ -1198,7 +1217,7 @@ class Controller:
def on_hci_le_read_maximum_data_length_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.46 LE Read Maximum Data Length Command
See Bluetooth spec Vol 4, Part E - 7.8.46 LE Read Maximum Data Length Command
'''
return struct.pack(
'<BHHHH',
@@ -1211,7 +1230,7 @@ class Controller:
def on_hci_le_read_phy_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.47 LE Read PHY Command
See Bluetooth spec Vol 4, Part E - 7.8.47 LE Read PHY Command
'''
return struct.pack(
'<BHBB',
@@ -1223,7 +1242,7 @@ class Controller:
def on_hci_le_set_default_phy_command(self, command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.48 LE Set Default PHY Command
See Bluetooth spec Vol 4, Part E - 7.8.48 LE Set Default PHY Command
'''
self.default_phy = {
'all_phys': command.all_phys,
@@ -1234,6 +1253,6 @@ class Controller:
def on_hci_le_read_transmit_power_command(self, _command):
'''
See Bluetooth spec Vol 2, Part E - 7.8.74 LE Read Transmit Power Command
See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command
'''
return struct.pack('<BBB', HCI_SUCCESS, 0, 0)

View File

@@ -17,7 +17,7 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
from typing import List, Optional, Tuple, Union, cast
from typing import List, Optional, Tuple, Union, cast, Dict
from .company_ids import COMPANY_IDENTIFIERS
@@ -53,7 +53,7 @@ def bit_flags_to_strings(bits, bit_flag_names):
return names
def name_or_number(dictionary, number, width=2):
def name_or_number(dictionary: Dict[int, str], number: int, width: int = 2) -> str:
name = dictionary.get(number)
if name is not None:
return name
@@ -562,11 +562,82 @@ class DeviceClass:
PERIPHERAL_HANDHELD_GESTURAL_INPUT_DEVICE_MINOR_DEVICE_CLASS: 'Handheld gestural input device'
}
WEARABLE_UNCATEGORIZED_MINOR_DEVICE_CLASS = 0x00
WEARABLE_WRISTWATCH_MINOR_DEVICE_CLASS = 0x01
WEARABLE_PAGER_MINOR_DEVICE_CLASS = 0x02
WEARABLE_JACKET_MINOR_DEVICE_CLASS = 0x03
WEARABLE_HELMET_MINOR_DEVICE_CLASS = 0x04
WEARABLE_GLASSES_MINOR_DEVICE_CLASS = 0x05
WEARABLE_MINOR_DEVICE_CLASS_NAMES = {
WEARABLE_UNCATEGORIZED_MINOR_DEVICE_CLASS: 'Uncategorized',
WEARABLE_WRISTWATCH_MINOR_DEVICE_CLASS: 'Wristwatch',
WEARABLE_PAGER_MINOR_DEVICE_CLASS: 'Pager',
WEARABLE_JACKET_MINOR_DEVICE_CLASS: 'Jacket',
WEARABLE_HELMET_MINOR_DEVICE_CLASS: 'Helmet',
WEARABLE_GLASSES_MINOR_DEVICE_CLASS: 'Glasses',
}
TOY_UNCATEGORIZED_MINOR_DEVICE_CLASS = 0x00
TOY_ROBOT_MINOR_DEVICE_CLASS = 0x01
TOY_VEHICLE_MINOR_DEVICE_CLASS = 0x02
TOY_DOLL_ACTION_FIGURE_MINOR_DEVICE_CLASS = 0x03
TOY_CONTROLLER_MINOR_DEVICE_CLASS = 0x04
TOY_GAME_MINOR_DEVICE_CLASS = 0x05
TOY_MINOR_DEVICE_CLASS_NAMES = {
TOY_UNCATEGORIZED_MINOR_DEVICE_CLASS: 'Uncategorized',
TOY_ROBOT_MINOR_DEVICE_CLASS: 'Robot',
TOY_VEHICLE_MINOR_DEVICE_CLASS: 'Vehicle',
TOY_DOLL_ACTION_FIGURE_MINOR_DEVICE_CLASS: 'Doll/Action figure',
TOY_CONTROLLER_MINOR_DEVICE_CLASS: 'Controller',
TOY_GAME_MINOR_DEVICE_CLASS: 'Game',
}
HEALTH_UNDEFINED_MINOR_DEVICE_CLASS = 0x00
HEALTH_BLOOD_PRESSURE_MONITOR_MINOR_DEVICE_CLASS = 0x01
HEALTH_THERMOMETER_MINOR_DEVICE_CLASS = 0x02
HEALTH_WEIGHING_SCALE_MINOR_DEVICE_CLASS = 0x03
HEALTH_GLUCOSE_METER_MINOR_DEVICE_CLASS = 0x04
HEALTH_PULSE_OXIMETER_MINOR_DEVICE_CLASS = 0x05
HEALTH_HEART_PULSE_RATE_MONITOR_MINOR_DEVICE_CLASS = 0x06
HEALTH_HEALTH_DATA_DISPLAY_MINOR_DEVICE_CLASS = 0x07
HEALTH_STEP_COUNTER_MINOR_DEVICE_CLASS = 0x08
HEALTH_BODY_COMPOSITION_ANALYZER_MINOR_DEVICE_CLASS = 0x09
HEALTH_PEAK_FLOW_MONITOR_MINOR_DEVICE_CLASS = 0x0A
HEALTH_MEDICATION_MONITOR_MINOR_DEVICE_CLASS = 0x0B
HEALTH_KNEE_PROSTHESIS_MINOR_DEVICE_CLASS = 0x0C
HEALTH_ANKLE_PROSTHESIS_MINOR_DEVICE_CLASS = 0x0D
HEALTH_GENERIC_HEALTH_MANAGER_MINOR_DEVICE_CLASS = 0x0E
HEALTH_PERSONAL_MOBILITY_DEVICE_MINOR_DEVICE_CLASS = 0x0F
HEALTH_MINOR_DEVICE_CLASS_NAMES = {
HEALTH_UNDEFINED_MINOR_DEVICE_CLASS: 'Undefined',
HEALTH_BLOOD_PRESSURE_MONITOR_MINOR_DEVICE_CLASS: 'Blood Pressure Monitor',
HEALTH_THERMOMETER_MINOR_DEVICE_CLASS: 'Thermometer',
HEALTH_WEIGHING_SCALE_MINOR_DEVICE_CLASS: 'Weighing Scale',
HEALTH_GLUCOSE_METER_MINOR_DEVICE_CLASS: 'Glucose Meter',
HEALTH_PULSE_OXIMETER_MINOR_DEVICE_CLASS: 'Pulse Oximeter',
HEALTH_HEART_PULSE_RATE_MONITOR_MINOR_DEVICE_CLASS: 'Heart/Pulse Rate Monitor',
HEALTH_HEALTH_DATA_DISPLAY_MINOR_DEVICE_CLASS: 'Health Data Display',
HEALTH_STEP_COUNTER_MINOR_DEVICE_CLASS: 'Step Counter',
HEALTH_BODY_COMPOSITION_ANALYZER_MINOR_DEVICE_CLASS: 'Body Composition Analyzer',
HEALTH_PEAK_FLOW_MONITOR_MINOR_DEVICE_CLASS: 'Peak Flow Monitor',
HEALTH_MEDICATION_MONITOR_MINOR_DEVICE_CLASS: 'Medication Monitor',
HEALTH_KNEE_PROSTHESIS_MINOR_DEVICE_CLASS: 'Knee Prosthesis',
HEALTH_ANKLE_PROSTHESIS_MINOR_DEVICE_CLASS: 'Ankle Prosthesis',
HEALTH_GENERIC_HEALTH_MANAGER_MINOR_DEVICE_CLASS: 'Generic Health Manager',
HEALTH_PERSONAL_MOBILITY_DEVICE_MINOR_DEVICE_CLASS: 'Personal Mobility Device',
}
MINOR_DEVICE_CLASS_NAMES = {
COMPUTER_MAJOR_DEVICE_CLASS: COMPUTER_MINOR_DEVICE_CLASS_NAMES,
PHONE_MAJOR_DEVICE_CLASS: PHONE_MINOR_DEVICE_CLASS_NAMES,
AUDIO_VIDEO_MAJOR_DEVICE_CLASS: AUDIO_VIDEO_MINOR_DEVICE_CLASS_NAMES,
PERIPHERAL_MAJOR_DEVICE_CLASS: PERIPHERAL_MINOR_DEVICE_CLASS_NAMES
PERIPHERAL_MAJOR_DEVICE_CLASS: PERIPHERAL_MINOR_DEVICE_CLASS_NAMES,
WEARABLE_MAJOR_DEVICE_CLASS: WEARABLE_MINOR_DEVICE_CLASS_NAMES,
TOY_MAJOR_DEVICE_CLASS: TOY_MINOR_DEVICE_CLASS_NAMES,
HEALTH_MAJOR_DEVICE_CLASS: HEALTH_MINOR_DEVICE_CLASS_NAMES,
}
# fmt: on

View File

@@ -86,6 +86,7 @@ from .hci import (
HCI_LE_Extended_Create_Connection_Command,
HCI_LE_Rand_Command,
HCI_LE_Read_PHY_Command,
HCI_LE_Set_Address_Resolution_Enable_Command,
HCI_LE_Set_Advertising_Data_Command,
HCI_LE_Set_Advertising_Enable_Command,
HCI_LE_Set_Advertising_Parameters_Command,
@@ -778,6 +779,7 @@ class DeviceConfiguration:
self.irk = bytes(16) # This really must be changed for any level of security
self.keystore = None
self.gatt_services: List[Dict[str, Any]] = []
self.address_resolution_offload = False
def load_from_dict(self, config: Dict[str, Any]) -> None:
# Load simple properties
@@ -1029,6 +1031,7 @@ class Device(CompositeEventEmitter):
self.discoverable = config.discoverable
self.connectable = config.connectable
self.classic_accept_any = config.classic_accept_any
self.address_resolution_offload = config.address_resolution_offload
for service in config.gatt_services:
characteristics = []
@@ -1256,31 +1259,16 @@ class Device(CompositeEventEmitter):
)
# Load the address resolving list
if self.keystore and self.host.supports_command(
HCI_LE_CLEAR_RESOLVING_LIST_COMMAND
):
await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg]
if self.keystore:
await self.refresh_resolving_list()
resolving_keys = await self.keystore.get_resolving_keys()
for irk, address in resolving_keys:
await self.send_command(
HCI_LE_Add_Device_To_Resolving_List_Command(
peer_identity_address_type=address.address_type,
peer_identity_address=address,
peer_irk=irk,
local_irk=self.irk,
) # type: ignore[call-arg]
)
# Enable address resolution
# await self.send_command(
# HCI_LE_Set_Address_Resolution_Enable_Command(
# address_resolution_enable=1)
# )
# )
# Create a host-side address resolver
self.address_resolver = smp.AddressResolver(resolving_keys)
# Enable address resolution
if self.address_resolution_offload:
await self.send_command(
HCI_LE_Set_Address_Resolution_Enable_Command(
address_resolution_enable=1
) # type: ignore[call-arg]
)
if self.classic_enabled:
await self.send_command(
@@ -1310,6 +1298,26 @@ class Device(CompositeEventEmitter):
await self.host.flush()
self.powered_on = False
async def refresh_resolving_list(self) -> None:
assert self.keystore is not None
resolving_keys = await self.keystore.get_resolving_keys()
# Create a host-side address resolver
self.address_resolver = smp.AddressResolver(resolving_keys)
if self.address_resolution_offload:
await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg]
for irk, address in resolving_keys:
await self.send_command(
HCI_LE_Add_Device_To_Resolving_List_Command(
peer_identity_address_type=address.address_type,
peer_identity_address=address,
peer_irk=irk,
local_irk=self.irk,
) # type: ignore[call-arg]
)
def supports_le_feature(self, feature):
return self.host.supports_le_feature(feature)
@@ -2851,18 +2859,22 @@ class Device(CompositeEventEmitter):
method = methods[peer_io_capability][io_capability]
async def reply() -> None:
if await connection.abort_on('disconnection', method()):
await self.host.send_command(
HCI_User_Confirmation_Request_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address
)
)
else:
await self.host.send_command(
HCI_User_Confirmation_Request_Negative_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address
try:
if await connection.abort_on('disconnection', method()):
await self.host.send_command(
HCI_User_Confirmation_Request_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address
)
)
return
except Exception as error:
logger.warning(f'exception while confirming: {error}')
await self.host.send_command(
HCI_User_Confirmation_Request_Negative_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address
)
)
AsyncRunner.spawn(reply())
@@ -2874,21 +2886,25 @@ class Device(CompositeEventEmitter):
pairing_config = self.pairing_config_factory(connection)
async def reply() -> None:
number = await connection.abort_on(
'disconnection', pairing_config.delegate.get_number()
try:
number = await connection.abort_on(
'disconnection', pairing_config.delegate.get_number()
)
if number is not None:
await self.host.send_command(
HCI_User_Passkey_Request_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address, numeric_value=number
)
)
return
except Exception as error:
logger.warning(f'exception while asking for pass-key: {error}')
await self.host.send_command(
HCI_User_Passkey_Request_Negative_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address
)
)
if number is not None:
await self.host.send_command(
HCI_User_Passkey_Request_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address, numeric_value=number
)
)
else:
await self.host.send_command(
HCI_User_Passkey_Request_Negative_Reply_Command( # type: ignore[call-arg]
bd_addr=connection.peer_address
)
)
AsyncRunner.spawn(reply())

View File

@@ -283,8 +283,7 @@ class IncludedServiceDeclaration(Attribute):
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
f'group_starting_handle=0x{self.service.handle:04X}, '
f'group_ending_handle=0x{self.service.end_group_handle:04X}, '
f'uuid={self.service.uuid}, '
f'{self.service.properties!s})'
f'uuid={self.service.uuid})'
)
@@ -309,31 +308,33 @@ class Characteristic(Attribute):
AUTHENTICATED_SIGNED_WRITES = 0x40
EXTENDED_PROPERTIES = 0x80
@staticmethod
def from_string(properties_str: str) -> Characteristic.Properties:
property_names: List[str] = []
for property in Characteristic.Properties:
if property.name is None:
raise TypeError()
property_names.append(property.name)
def string_to_property(property_string) -> Characteristic.Properties:
for property in zip(Characteristic.Properties, property_names):
if property_string == property[1]:
return property[0]
raise TypeError(f"Unable to convert {property_string} to Property")
@classmethod
def from_string(cls, properties_str: str) -> Characteristic.Properties:
try:
return functools.reduce(
lambda x, y: x | string_to_property(y),
properties_str.split(","),
lambda x, y: x | cls[y],
properties_str.replace("|", ",").split(","),
Characteristic.Properties(0),
)
except TypeError:
except (TypeError, KeyError):
# The check for `p.name is not None` here is needed because for InFlag
# enums, the .name property can be None, when the enum value is 0,
# so the type hint for .name is Optional[str].
enum_list: List[str] = [p.name for p in cls if p.name is not None]
enum_list_str = ",".join(enum_list)
raise TypeError(
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by commas: {','.join(property_names)}\nGot: {properties_str}"
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
)
def __str__(self):
# NOTE: we override this method to offer a consistent result between python
# versions: the value returned by IntFlag.__str__() changed in version 11.
return '|'.join(
flag.name
for flag in Characteristic.Properties
if self.value & flag.value and flag.name is not None
)
# For backwards compatibility these are defined here
# For new code, please use Characteristic.Properties.X
BROADCAST = Properties.BROADCAST
@@ -373,7 +374,7 @@ class Characteristic(Attribute):
f'Characteristic(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
f'uuid={self.uuid}, '
f'{self.properties!s})'
f'{self.properties})'
)
@@ -401,7 +402,7 @@ class CharacteristicDeclaration(Attribute):
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
f'value_handle=0x{self.value_handle:04X}, '
f'uuid={self.characteristic.uuid}, '
f'{self.characteristic.properties!s})'
f'{self.characteristic.properties})'
)

View File

@@ -1445,8 +1445,14 @@ class HCI_Object:
@staticmethod
def init_from_fields(hci_object, fields, values):
if isinstance(values, dict):
for field_name, _ in fields:
setattr(hci_object, field_name, values[field_name])
for field in fields:
if isinstance(field, list):
# The field is an array, up-level the array field names
for sub_field_name, _ in field:
setattr(hci_object, sub_field_name, values[sub_field_name])
else:
field_name = field[0]
setattr(hci_object, field_name, values[field_name])
else:
for field_name, field_value in zip(fields, values):
setattr(hci_object, field_name, field_value)
@@ -1456,133 +1462,161 @@ class HCI_Object:
parsed = HCI_Object.dict_from_bytes(data, offset, fields)
HCI_Object.init_from_fields(hci_object, parsed.keys(), parsed.values())
@staticmethod
def parse_field(data, offset, field_type):
# The field_type may be a dictionary with a mapper, parser, and/or size
if isinstance(field_type, dict):
if 'size' in field_type:
field_type = field_type['size']
elif 'parser' in field_type:
field_type = field_type['parser']
# Parse the field
if field_type == '*':
# The rest of the bytes
field_value = data[offset:]
return (field_value, len(field_value))
if field_type == 1:
# 8-bit unsigned
return (data[offset], 1)
if field_type == -1:
# 8-bit signed
return (struct.unpack_from('b', data, offset)[0], 1)
if field_type == 2:
# 16-bit unsigned
return (struct.unpack_from('<H', data, offset)[0], 2)
if field_type == '>2':
# 16-bit unsigned big-endian
return (struct.unpack_from('>H', data, offset)[0], 2)
if field_type == -2:
# 16-bit signed
return (struct.unpack_from('<h', data, offset)[0], 2)
if field_type == 3:
# 24-bit unsigned
padded = data[offset : offset + 3] + bytes([0])
return (struct.unpack('<I', padded)[0], 3)
if field_type == 4:
# 32-bit unsigned
return (struct.unpack_from('<I', data, offset)[0], 4)
if field_type == '>4':
# 32-bit unsigned big-endian
return (struct.unpack_from('>I', data, offset)[0], 4)
if isinstance(field_type, int) and 4 < field_type <= 256:
# Byte array (from 5 up to 256 bytes)
return (data[offset : offset + field_type], field_type)
if callable(field_type):
new_offset, field_value = field_type(data, offset)
return (field_value, new_offset - offset)
raise ValueError(f'unknown field type {field_type}')
@staticmethod
def dict_from_bytes(data, offset, fields):
result = collections.OrderedDict()
for (field_name, field_type) in fields:
# The field_type may be a dictionary with a mapper, parser, and/or size
if isinstance(field_type, dict):
if 'size' in field_type:
field_type = field_type['size']
elif 'parser' in field_type:
field_type = field_type['parser']
# Parse the field
if field_type == '*':
# The rest of the bytes
field_value = data[offset:]
offset += len(field_value)
elif field_type == 1:
# 8-bit unsigned
field_value = data[offset]
for field in fields:
if isinstance(field, list):
# This is an array field, starting with a 1-byte item count.
item_count = data[offset]
offset += 1
elif field_type == -1:
# 8-bit signed
field_value = struct.unpack_from('b', data, offset)[0]
offset += 1
elif field_type == 2:
# 16-bit unsigned
field_value = struct.unpack_from('<H', data, offset)[0]
offset += 2
elif field_type == '>2':
# 16-bit unsigned big-endian
field_value = struct.unpack_from('>H', data, offset)[0]
offset += 2
elif field_type == -2:
# 16-bit signed
field_value = struct.unpack_from('<h', data, offset)[0]
offset += 2
elif field_type == 3:
# 24-bit unsigned
padded = data[offset : offset + 3] + bytes([0])
field_value = struct.unpack('<I', padded)[0]
offset += 3
elif field_type == 4:
# 32-bit unsigned
field_value = struct.unpack_from('<I', data, offset)[0]
offset += 4
elif field_type == '>4':
# 32-bit unsigned big-endian
field_value = struct.unpack_from('>I', data, offset)[0]
offset += 4
elif isinstance(field_type, int) and 4 < field_type <= 256:
# Byte array (from 5 up to 256 bytes)
field_value = data[offset : offset + field_type]
offset += field_type
elif callable(field_type):
offset, field_value = field_type(data, offset)
else:
raise ValueError(f'unknown field type {field_type}')
for _ in range(item_count):
for sub_field_name, sub_field_type in field:
value, size = HCI_Object.parse_field(
data, offset, sub_field_type
)
result.setdefault(sub_field_name, []).append(value)
offset += size
continue
field_name, field_type = field
field_value, field_size = HCI_Object.parse_field(data, offset, field_type)
result[field_name] = field_value
offset += field_size
return result
@staticmethod
def serialize_field(field_value, field_type):
# The field_type may be a dictionary with a mapper, parser, serializer,
# and/or size
serializer = None
if isinstance(field_type, dict):
if 'serializer' in field_type:
serializer = field_type['serializer']
if 'size' in field_type:
field_type = field_type['size']
# Serialize the field
if serializer:
field_bytes = serializer(field_value)
elif field_type == 1:
# 8-bit unsigned
field_bytes = bytes([field_value])
elif field_type == -1:
# 8-bit signed
field_bytes = struct.pack('b', field_value)
elif field_type == 2:
# 16-bit unsigned
field_bytes = struct.pack('<H', field_value)
elif field_type == '>2':
# 16-bit unsigned big-endian
field_bytes = struct.pack('>H', field_value)
elif field_type == -2:
# 16-bit signed
field_bytes = struct.pack('<h', field_value)
elif field_type == 3:
# 24-bit unsigned
field_bytes = struct.pack('<I', field_value)[0:3]
elif field_type == 4:
# 32-bit unsigned
field_bytes = struct.pack('<I', field_value)
elif field_type == '>4':
# 32-bit unsigned big-endian
field_bytes = struct.pack('>I', field_value)
elif field_type == '*':
if isinstance(field_value, int):
if 0 <= field_value <= 255:
field_bytes = bytes([field_value])
else:
raise ValueError('value too large for *-typed field')
else:
field_bytes = bytes(field_value)
elif isinstance(field_value, (bytes, bytearray)) or hasattr(
field_value, 'to_bytes'
):
field_bytes = bytes(field_value)
if isinstance(field_type, int) and 4 < field_type <= 256:
# Truncate or pad with zeros if the field is too long or too short
if len(field_bytes) < field_type:
field_bytes += bytes(field_type - len(field_bytes))
elif len(field_bytes) > field_type:
field_bytes = field_bytes[:field_type]
else:
raise ValueError(f"don't know how to serialize type {type(field_value)}")
return field_bytes
@staticmethod
def dict_to_bytes(hci_object, fields):
result = bytearray()
for (field_name, field_type) in fields:
# The field_type may be a dictionary with a mapper, parser, serializer,
# and/or size
serializer = None
if isinstance(field_type, dict):
if 'serializer' in field_type:
serializer = field_type['serializer']
if 'size' in field_type:
field_type = field_type['size']
# Serialize the field
field_value = hci_object[field_name]
if serializer:
field_bytes = serializer(field_value)
elif field_type == 1:
# 8-bit unsigned
field_bytes = bytes([field_value])
elif field_type == -1:
# 8-bit signed
field_bytes = struct.pack('b', field_value)
elif field_type == 2:
# 16-bit unsigned
field_bytes = struct.pack('<H', field_value)
elif field_type == '>2':
# 16-bit unsigned big-endian
field_bytes = struct.pack('>H', field_value)
elif field_type == -2:
# 16-bit signed
field_bytes = struct.pack('<h', field_value)
elif field_type == 3:
# 24-bit unsigned
field_bytes = struct.pack('<I', field_value)[0:3]
elif field_type == 4:
# 32-bit unsigned
field_bytes = struct.pack('<I', field_value)
elif field_type == '>4':
# 32-bit unsigned big-endian
field_bytes = struct.pack('>I', field_value)
elif field_type == '*':
if isinstance(field_value, int):
if 0 <= field_value <= 255:
field_bytes = bytes([field_value])
else:
raise ValueError('value too large for *-typed field')
else:
field_bytes = bytes(field_value)
elif isinstance(field_value, (bytes, bytearray)) or hasattr(
field_value, 'to_bytes'
):
field_bytes = bytes(field_value)
if isinstance(field_type, int) and 4 < field_type <= 256:
# Truncate or Pad with zeros if the field is too long or too short
if len(field_bytes) < field_type:
field_bytes += bytes(field_type - len(field_bytes))
elif len(field_bytes) > field_type:
field_bytes = field_bytes[:field_type]
else:
raise ValueError(
f"don't know how to serialize type {type(field_value)}"
for field in fields:
if isinstance(field, list):
# The field is an array. The serialized form starts with a 1-byte
# item count. We use the length of the first array field as the
# array count, since all array fields have the same number of items.
item_count = len(hci_object[field[0][0]])
result += bytes([item_count]) + b''.join(
b''.join(
HCI_Object.serialize_field(
hci_object[sub_field_name][i], sub_field_type
)
for sub_field_name, sub_field_type in field
)
for i in range(item_count)
)
continue
result += field_bytes
(field_name, field_type) = field
result += HCI_Object.serialize_field(hci_object[field_name], field_type)
return bytes(result)
@@ -1617,48 +1651,73 @@ class HCI_Object:
return str(value)
@staticmethod
def format_fields(hci_object, keys, indentation='', value_mappers=None):
if not keys:
return ''
def stringify_field(
field_name, field_type, field_value, indentation, value_mappers
):
value_mapper = None
if isinstance(field_type, dict):
# Get the value mapper from the specifier
value_mapper = field_type.get('mapper')
# Measure the widest field name
max_field_name_length = max(
(len(key[0] if isinstance(key, tuple) else key) for key in keys)
# Check if there's a matching mapper passed
if value_mappers:
value_mapper = value_mappers.get(field_name, value_mapper)
# Map the value if we have a mapper
if value_mapper is not None:
field_value = value_mapper(field_value)
# Get the string representation of the value
return HCI_Object.format_field_value(
field_value, indentation=indentation + ' '
)
@staticmethod
def format_fields(hci_object, fields, indentation='', value_mappers=None):
if not fields:
return ''
# Build array of formatted key:value pairs
fields = []
for key in keys:
value_mapper = None
if isinstance(key, tuple):
# The key has an associated specifier
key, specifier = key
field_strings = []
for field in fields:
if isinstance(field, list):
for sub_field in field:
sub_field_name, sub_field_type = sub_field
item_count = len(hci_object[sub_field_name])
for i in range(item_count):
field_strings.append(
(
f'{sub_field_name}[{i}]',
HCI_Object.stringify_field(
sub_field_name,
sub_field_type,
hci_object[sub_field_name][i],
indentation,
value_mappers,
),
),
)
continue
# Get the value mapper from the specifier
if isinstance(specifier, dict):
value_mapper = specifier.get('mapper')
# Get the value for the field
value = hci_object[key]
# Check if there's a matching mapper passed
if value_mappers:
value_mapper = value_mappers.get(key, value_mapper)
# Map the value if we have a mapper
if value_mapper is not None:
value = value_mapper(value)
# Get the string representation of the value
value_str = HCI_Object.format_field_value(
value, indentation=indentation + ' '
field_name, field_type = field
field_value = hci_object[field_name]
field_strings.append(
(
field_name,
HCI_Object.stringify_field(
field_name, field_type, field_value, indentation, value_mappers
),
),
)
# Add the field to the formatted result
key_str = color(f'{key + ":":{1 + max_field_name_length}}', 'cyan')
fields.append(f'{indentation}{key_str} {value_str}')
return '\n'.join(fields)
# Measure the widest field name
max_field_name_length = max(len(s[0]) for s in field_strings)
sep = ':'
return '\n'.join(
f'{indentation}'
f'{color(f"{field_name + sep:{1 + max_field_name_length}}", "cyan")} {field_value}'
for field_name, field_value in field_strings
)
def __bytes__(self):
return self.to_bytes()
@@ -3769,9 +3828,7 @@ class HCI_LE_Set_Extended_Advertising_Parameters_Command(HCI_Command):
'advertising_data',
{
'parser': HCI_Object.parse_length_prefixed_bytes,
'serializer': functools.partial(
HCI_Object.serialize_length_prefixed_bytes
),
'serializer': HCI_Object.serialize_length_prefixed_bytes,
},
),
]
@@ -3819,9 +3876,7 @@ class HCI_LE_Set_Extended_Advertising_Data_Command(HCI_Command):
'scan_response_data',
{
'parser': HCI_Object.parse_length_prefixed_bytes,
'serializer': functools.partial(
HCI_Object.serialize_length_prefixed_bytes
),
'serializer': HCI_Object.serialize_length_prefixed_bytes,
},
),
]
@@ -3849,73 +3904,21 @@ class HCI_LE_Set_Extended_Scan_Response_Data_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(fields=None)
@HCI_Command.command(
[
('enable', 1),
[
('advertising_handles', 1),
('durations', 2),
('max_extended_advertising_events', 1),
],
]
)
class HCI_LE_Set_Extended_Advertising_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.56 LE Set Extended Advertising Enable Command
'''
@classmethod
def from_parameters(cls, parameters):
enable = parameters[0]
num_sets = parameters[1]
advertising_handles = []
durations = []
max_extended_advertising_events = []
offset = 2
for _ in range(num_sets):
advertising_handles.append(parameters[offset])
durations.append(struct.unpack_from('<H', parameters, offset + 1)[0])
max_extended_advertising_events.append(parameters[offset + 3])
offset += 4
return cls(
enable, advertising_handles, durations, max_extended_advertising_events
)
def __init__(
self, enable, advertising_handles, durations, max_extended_advertising_events
):
super().__init__(HCI_LE_SET_EXTENDED_ADVERTISING_ENABLE_COMMAND)
self.enable = enable
self.advertising_handles = advertising_handles
self.durations = durations
self.max_extended_advertising_events = max_extended_advertising_events
self.parameters = bytes([enable, len(advertising_handles)]) + b''.join(
[
struct.pack(
'<BHB',
advertising_handles[i],
durations[i],
max_extended_advertising_events[i],
)
for i in range(len(advertising_handles))
]
)
def __str__(self):
fields = [('enable:', self.enable)]
for i, advertising_handle in enumerate(self.advertising_handles):
fields.append(
(f'advertising_handle[{i}]: ', advertising_handle)
)
fields.append((f'duration[{i}]: ', self.durations[i]))
fields.append(
(
f'max_extended_advertising_events[{i}]:',
self.max_extended_advertising_events[i],
)
)
return (
color(self.name, 'green')
+ ':\n'
+ '\n'.join(
[color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields]
)
)
# -----------------------------------------------------------------------------
@HCI_Command.command(
@@ -4066,7 +4069,10 @@ class HCI_LE_Set_Extended_Scan_Parameters_Command(HCI_Command):
color(self.name, 'green')
+ ':\n'
+ '\n'.join(
[color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields]
[
color(' ' + field[0], 'cyan') + ' ' + str(field[1])
for field in fields
]
)
)
@@ -4242,7 +4248,10 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command):
color(self.name, 'green')
+ ':\n'
+ '\n'.join(
[color(field[0], 'cyan') + ' ' + str(field[1]) for field in fields]
[
color(' ' + field[0], 'cyan') + ' ' + str(field[1])
for field in fields
]
)
)
@@ -5205,7 +5214,7 @@ class HCI_Number_Of_Completed_Packets_Event(HCI_Event):
def __str__(self):
lines = [
color(self.name, 'magenta') + ':',
color(' number_of_handles: ', 'cyan')
color(' number_of_handles: ', 'cyan')
+ f'{len(self.connection_handles)}',
]
for i, connection_handle in enumerate(self.connection_handles):

View File

@@ -1,4 +1,4 @@
# Copyright 2021-2022 Google LLC
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,11 +17,31 @@
# -----------------------------------------------------------------------------
import logging
import asyncio
import collections
from typing import Union
import dataclasses
import enum
import traceback
from typing import Dict, List, Union, Set
from . import at
from . import rfcomm
from .colors import color
from bumble.core import (
ProtocolError,
BT_GENERIC_AUDIO_SERVICE,
BT_HANDSFREE_SERVICE,
BT_L2CAP_PROTOCOL_ID,
BT_RFCOMM_PROTOCOL_ID,
)
from bumble.sdp import (
DataElement,
ServiceAttribute,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
)
# -----------------------------------------------------------------------------
# Logging
@@ -30,72 +50,700 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Protocol Support
# Normative protocol definitions
# -----------------------------------------------------------------------------
# HF supported features (AT+BRSF=) (normative).
# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
# and 3GPP 27.007
class HfFeature(enum.IntFlag):
EC_NR = 0x001 # Echo Cancel & Noise reduction
THREE_WAY_CALLING = 0x002
CLI_PRESENTATION_CAPABILITY = 0x004
VOICE_RECOGNITION_ACTIVATION = 0x008
REMOTE_VOLUME_CONTROL = 0x010
ENHANCED_CALL_STATUS = 0x020
ENHANCED_CALL_CONTROL = 0x040
CODEC_NEGOTIATION = 0x080
HF_INDICATORS = 0x100
ESCO_S4_SETTINGS_SUPPORTED = 0x200
ENHANCED_VOICE_RECOGNITION_STATUS = 0x400
VOICE_RECOGNITION_TEST = 0x800
# AG supported features (+BRSF:) (normative).
# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
# and 3GPP 27.007
class AgFeature(enum.IntFlag):
THREE_WAY_CALLING = 0x001
EC_NR = 0x002 # Echo Cancel & Noise reduction
VOICE_RECOGNITION_FUNCTION = 0x004
IN_BAND_RING_TONE_CAPABILITY = 0x008
VOICE_TAG = 0x010 # Attach a number to voice tag
REJECT_CALL = 0x020 # Ability to reject a call
ENHANCED_CALL_STATUS = 0x040
ENHANCED_CALL_CONTROL = 0x080
EXTENDED_ERROR_RESULT_CODES = 0x100
CODEC_NEGOTIATION = 0x200
HF_INDICATORS = 0x400
ESCO_S4_SETTINGS_SUPPORTED = 0x800
ENHANCED_VOICE_RECOGNITION_STATUS = 0x1000
VOICE_RECOGNITION_TEST = 0x2000
# Audio Codec IDs (normative).
# Hands-Free Profile v1.8, 10 Appendix B
class AudioCodec(enum.IntEnum):
CVSD = 0x01 # Support for CVSD audio codec
MSBC = 0x02 # Support for mSBC audio codec
# HF Indicators (normative).
# Bluetooth Assigned Numbers, 6.10.1 HF Indicators
class HfIndicator(enum.IntEnum):
ENHANCED_SAFETY = 0x01 # Enhanced safety feature
BATTERY_LEVEL = 0x02 # Battery level feature
# Call Hold supported operations (normative).
# AT Commands Reference Guide, 3.5.2.3.12 +CHLD - Call Holding Services
class CallHoldOperation(enum.IntEnum):
RELEASE_ALL_HELD_CALLS = 0 # Release all held calls
RELEASE_ALL_ACTIVE_CALLS = 1 # Release all active calls, accept other
HOLD_ALL_ACTIVE_CALLS = 2 # Place all active calls on hold, accept other
ADD_HELD_CALL = 3 # Adds a held call to conversation
# Response Hold status (normative).
# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
# and 3GPP 27.007
class ResponseHoldStatus(enum.IntEnum):
INC_CALL_HELD = 0 # Put incoming call on hold
HELD_CALL_ACC = 1 # Accept a held incoming call
HELD_CALL_REJ = 2 # Reject a held incoming call
# Values for the Call Setup AG indicator (normative).
# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
# and 3GPP 27.007
class CallSetupAgIndicator(enum.IntEnum):
NOT_IN_CALL_SETUP = 0
INCOMING_CALL_PROCESS = 1
OUTGOING_CALL_SETUP = 2
REMOTE_ALERTED = 3 # Remote party alerted in an outgoing call
# Values for the Call Held AG indicator (normative).
# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
# and 3GPP 27.007
class CallHeldAgIndicator(enum.IntEnum):
NO_CALLS_HELD = 0
# Call is placed on hold or active/held calls swapped
# (The AG has both an active AND a held call)
CALL_ON_HOLD_AND_ACTIVE_CALL = 1
CALL_ON_HOLD_NO_ACTIVE_CALL = 2 # Call on hold, no active call
# Call Info direction (normative).
# AT Commands Reference Guide, 3.5.2.3.15 +CLCC - List Current Calls
class CallInfoDirection(enum.IntEnum):
MOBILE_ORIGINATED_CALL = 0
MOBILE_TERMINATED_CALL = 1
# Call Info status (normative).
# AT Commands Reference Guide, 3.5.2.3.15 +CLCC - List Current Calls
class CallInfoStatus(enum.IntEnum):
ACTIVE = 0
HELD = 1
DIALING = 2
ALERTING = 3
INCOMING = 4
WAITING = 5
# Call Info mode (normative).
# AT Commands Reference Guide, 3.5.2.3.15 +CLCC - List Current Calls
class CallInfoMode(enum.IntEnum):
VOICE = 0
DATA = 1
FAX = 2
UNKNOWN = 9
# -----------------------------------------------------------------------------
class HfpProtocol:
# Hands-Free Control Interoperability Requirements
# -----------------------------------------------------------------------------
# Response codes.
RESPONSE_CODES = [
"+APLSIRI",
"+BAC",
"+BCC",
"+BCS",
"+BIA",
"+BIEV",
"+BIND",
"+BINP",
"+BLDN",
"+BRSF",
"+BTRH",
"+BVRA",
"+CCWA",
"+CHLD",
"+CHUP",
"+CIND",
"+CLCC",
"+CLIP",
"+CMEE",
"+CMER",
"+CNUM",
"+COPS",
"+IPHONEACCEV",
"+NREC",
"+VGM",
"+VGS",
"+VTS",
"+XAPL",
"A",
"D",
]
# Unsolicited responses and statuses.
UNSOLICITED_CODES = [
"+APLSIRI",
"+BCS",
"+BIND",
"+BSIR",
"+BTRH",
"+BVRA",
"+CCWA",
"+CIEV",
"+CLIP",
"+VGM",
"+VGS",
"BLACKLISTED",
"BUSY",
"DELAYED",
"NO ANSWER",
"NO CARRIER",
"RING",
]
# Status codes
STATUS_CODES = [
"+CME ERROR",
"BLACKLISTED",
"BUSY",
"DELAYED",
"ERROR",
"NO ANSWER",
"NO CARRIER",
"OK",
]
@dataclasses.dataclass
class Configuration:
supported_hf_features: List[HfFeature]
supported_hf_indicators: List[HfIndicator]
supported_audio_codecs: List[AudioCodec]
class AtResponseType(enum.Enum):
"""Indicate if a response is expected from an AT command, and if multiple
responses are accepted."""
NONE = 0
SINGLE = 1
MULTIPLE = 2
class AtResponse:
code: str
parameters: list
def __init__(self, response: bytearray):
code_and_parameters = response.split(b':')
parameters = (
code_and_parameters[1] if len(code_and_parameters) > 1 else bytearray()
)
self.code = code_and_parameters[0].decode()
self.parameters = at.parse_parameters(parameters)
@dataclasses.dataclass
class AgIndicatorState:
description: str
index: int
supported_values: Set[int]
current_status: int
@dataclasses.dataclass
class HfIndicatorState:
supported: bool = False
enabled: bool = False
class HfProtocol:
"""Implementation for the Hands-Free side of the Hands-Free profile.
Reference specification Hands-Free Profile v1.8"""
supported_hf_features: int
supported_audio_codecs: List[AudioCodec]
supported_ag_features: int
supported_ag_call_hold_operations: List[CallHoldOperation]
ag_indicators: List[AgIndicatorState]
hf_indicators: Dict[HfIndicator, HfIndicatorState]
dlc: rfcomm.DLC
buffer: str
lines: collections.deque
lines_available: asyncio.Event
command_lock: asyncio.Lock
response_queue: asyncio.Queue
unsolicited_queue: asyncio.Queue
read_buffer: bytearray
def __init__(self, dlc: rfcomm.DLC) -> None:
def __init__(self, dlc: rfcomm.DLC, configuration: Configuration):
# Configure internal state.
self.dlc = dlc
self.buffer = ''
self.lines = collections.deque()
self.lines_available = asyncio.Event()
self.command_lock = asyncio.Lock()
self.response_queue = asyncio.Queue()
self.unsolicited_queue = asyncio.Queue()
self.read_buffer = bytearray()
dlc.sink = self.feed
# Build local features.
self.supported_hf_features = sum(configuration.supported_hf_features)
self.supported_audio_codecs = configuration.supported_audio_codecs
def feed(self, data: Union[bytes, str]) -> None:
# Convert the data to a string if needed
if isinstance(data, bytes):
data = data.decode('utf-8')
self.hf_indicators = {
indicator: HfIndicatorState()
for indicator in configuration.supported_hf_indicators
}
logger.debug(f'<<< Data received: {data}')
# Clear remote features.
self.supported_ag_features = 0
self.supported_ag_call_hold_operations = []
self.ag_indicators = []
# Add to the buffer and look for lines
self.buffer += data
while (separator := self.buffer.find('\r')) >= 0:
line = self.buffer[:separator].strip()
self.buffer = self.buffer[separator + 1 :]
if len(line) > 0:
self.on_line(line)
# Bind the AT reader to the RFCOMM channel.
self.dlc.sink = self._read_at
def on_line(self, line: str) -> None:
self.lines.append(line)
self.lines_available.set()
def supports_hf_feature(self, feature: HfFeature) -> bool:
return (self.supported_hf_features & feature) != 0
def send_command_line(self, line: str) -> None:
logger.debug(color(f'>>> {line}', 'yellow'))
self.dlc.write(line + '\r')
def supports_ag_feature(self, feature: AgFeature) -> bool:
return (self.supported_ag_features & feature) != 0
def send_response_line(self, line: str) -> None:
logger.debug(color(f'>>> {line}', 'yellow'))
self.dlc.write('\r\n' + line + '\r\n')
# Read AT messages from the RFCOMM channel.
# Enqueue AT commands, responses, unsolicited responses to their
# respective queues, and set the corresponding event.
def _read_at(self, data: bytes):
# Append to the read buffer.
self.read_buffer.extend(data)
async def next_line(self) -> str:
await self.lines_available.wait()
line = self.lines.popleft()
if not self.lines:
self.lines_available.clear()
logger.debug(color(f'<<< {line}', 'green'))
return line
# Locate header and trailer.
header = self.read_buffer.find(b'\r\n')
trailer = self.read_buffer.find(b'\r\n', header + 2)
if header == -1 or trailer == -1:
return
async def initialize_service(self) -> None:
# Perform Service Level Connection Initialization
self.send_command_line('AT+BRSF=2072') # Retrieve Supported Features
await (self.next_line())
await (self.next_line())
# Isolate the AT response code and parameters.
raw_response = self.read_buffer[header + 2 : trailer]
response = AtResponse(raw_response)
logger.debug(f"<<< {raw_response.decode()}")
self.send_command_line('AT+CIND=?')
await (self.next_line())
await (self.next_line())
# Consume the response bytes.
self.read_buffer = self.read_buffer[trailer + 2 :]
self.send_command_line('AT+CIND?')
await (self.next_line())
await (self.next_line())
# Forward the received code to the correct queue.
if self.command_lock.locked() and (
response.code in STATUS_CODES or response.code in RESPONSE_CODES
):
self.response_queue.put_nowait(response)
elif response.code in UNSOLICITED_CODES:
self.unsolicited_queue.put_nowait(response)
else:
logger.warning(f"dropping unexpected response with code '{response.code}'")
self.send_command_line('AT+CMER=3,0,0,1')
await (self.next_line())
# Send an AT command and wait for the peer resposne.
# Wait for the AT responses sent by the peer, to the status code.
# Raises asyncio.TimeoutError if the status is not received
# after a timeout (default 1 second).
# Raises ProtocolError if the status is not OK.
async def execute_command(
self,
cmd: str,
timeout: float = 1.0,
response_type: AtResponseType = AtResponseType.NONE,
) -> Union[None, AtResponse, List[AtResponse]]:
async with self.command_lock:
logger.debug(f">>> {cmd}")
self.dlc.write(cmd + '\r')
responses: List[AtResponse] = []
while True:
result = await asyncio.wait_for(
self.response_queue.get(), timeout=timeout
)
if result.code == 'OK':
if response_type == AtResponseType.SINGLE and len(responses) != 1:
raise ProtocolError("NO ANSWER")
if response_type == AtResponseType.MULTIPLE:
return responses
if response_type == AtResponseType.SINGLE:
return responses[0]
return None
if result.code in STATUS_CODES:
raise ProtocolError(result.code)
responses.append(result)
# 4.2.1 Service Level Connection Initialization.
async def initiate_slc(self):
# 4.2.1.1 Supported features exchange
# First, in the initialization procedure, the HF shall send the
# AT+BRSF=<HF supported features> command to the AG to both notify
# the AG of the supported features in the HF, as well as to retrieve the
# supported features in the AG using the +BRSF result code.
response = await self.execute_command(
f"AT+BRSF={self.supported_hf_features}", response_type=AtResponseType.SINGLE
)
self.supported_ag_features = int(response.parameters[0])
logger.info(f"supported AG features: {self.supported_ag_features}")
for feature in AgFeature:
if self.supports_ag_feature(feature):
logger.info(f" - {feature.name}")
# 4.2.1.2 Codec Negotiation
# Secondly, in the initialization procedure, if the HF supports the
# Codec Negotiation feature, it shall check if the AT+BRSF command
# response from the AG has indicated that it supports the Codec
# Negotiation feature.
if self.supports_hf_feature(
HfFeature.CODEC_NEGOTIATION
) and self.supports_ag_feature(AgFeature.CODEC_NEGOTIATION):
# If both the HF and AG do support the Codec Negotiation feature
# then the HF shall send the AT+BAC=<HF available codecs> command to
# the AG to notify the AG of the available codecs in the HF.
codecs = [str(c) for c in self.supported_audio_codecs]
await self.execute_command(f"AT+BAC={','.join(codecs)}")
# 4.2.1.3 AG Indicators
# After having retrieved the supported features in the AG, the HF shall
# determine which indicators are supported by the AG, as well as the
# ordering of the supported indicators. This is because, according to
# the 3GPP 27.007 specification [2], the AG may support additional
# indicators not provided for by the Hands-Free Profile, and because the
# ordering of the indicators is implementation specific. The HF uses
# the AT+CIND=? Test command to retrieve information about the supported
# indicators and their ordering.
response = await self.execute_command(
"AT+CIND=?", response_type=AtResponseType.SINGLE
)
self.ag_indicators = []
for index, indicator in enumerate(response.parameters):
description = indicator[0].decode()
supported_values = []
for value in indicator[1]:
value = value.split(b'-')
value = [int(v) for v in value]
value_min = value[0]
value_max = value[1] if len(value) > 1 else value[0]
supported_values.extend([v for v in range(value_min, value_max + 1)])
self.ag_indicators.append(
AgIndicatorState(description, index, set(supported_values), 0)
)
# Once the HF has the necessary supported indicator and ordering
# information, it shall retrieve the current status of the indicators
# in the AG using the AT+CIND? Read command.
response = await self.execute_command(
"AT+CIND?", response_type=AtResponseType.SINGLE
)
for index, indicator in enumerate(response.parameters):
self.ag_indicators[index].current_status = int(indicator)
# After having retrieved the status of the indicators in the AG, the HF
# shall then enable the "Indicators status update" function in the AG by
# issuing the AT+CMER command, to which the AG shall respond with OK.
await self.execute_command("AT+CMER=3,,,1")
if self.supports_hf_feature(
HfFeature.THREE_WAY_CALLING
) and self.supports_ag_feature(HfFeature.THREE_WAY_CALLING):
# After the HF has enabled the “Indicators status update” function in
# the AG, and if the “Call waiting and 3-way calling” bit was set in the
# supported features bitmap by both the HF and the AG, the HF shall
# issue the AT+CHLD=? test command to retrieve the information about how
# the call hold and multiparty services are supported in the AG. The HF
# shall not issue the AT+CHLD=? test command in case either the HF or
# the AG does not support the "Three-way calling" feature.
response = await self.execute_command(
"AT+CHLD=?", response_type=AtResponseType.SINGLE
)
self.supported_ag_call_hold_operations = [
CallHoldOperation(int(operation))
for operation in response.parameters[0]
if not b'x' in operation
]
# 4.2.1.4 HF Indicators
# If the HF supports the HF indicator feature, it shall check the +BRSF
# response to see if the AG also supports the HF Indicator feature.
if self.supports_hf_feature(
HfFeature.HF_INDICATORS
) and self.supports_ag_feature(AgFeature.HF_INDICATORS):
# If both the HF and AG support the HF Indicator feature, then the HF
# shall send the AT+BIND=<HF supported HF indicators> command to the AG
# to notify the AG of the supported indicators assigned numbers in the
# HF. The AG shall respond with OK
indicators = [str(i) for i in self.hf_indicators.keys()]
await self.execute_command(f"AT+BIND={','.join(indicators)}")
# After having provided the AG with the HF indicators it supports,
# the HF shall send the AT+BIND=? to request HF indicators supported
# by the AG. The AG shall reply with the +BIND response listing all
# HF indicators that it supports followed by an OK.
response = await self.execute_command(
"AT+BIND=?", response_type=AtResponseType.SINGLE
)
logger.info("supported HF indicators:")
for indicator in response.parameters[0]:
indicator = HfIndicator(int(indicator))
logger.info(f" - {indicator.name}")
if indicator in self.hf_indicators:
self.hf_indicators[indicator].supported = True
# Once the HF receives the supported HF indicators list from the AG,
# the HF shall send the AT+BIND? command to determine which HF
# indicators are enabled. The AG shall respond with one or more
# +BIND responses. The AG shall terminate the list with OK.
# (See Section 4.36.1.3).
responses = await self.execute_command(
"AT+BIND?", response_type=AtResponseType.MULTIPLE
)
logger.info("enabled HF indicators:")
for response in responses:
indicator = HfIndicator(int(response.parameters[0]))
enabled = int(response.parameters[1]) != 0
logger.info(f" - {indicator.name}: {enabled}")
if indicator in self.hf_indicators:
self.hf_indicators[indicator].enabled = True
logger.info("SLC setup completed")
# 4.11.2 Audio Connection Setup by HF
async def setup_audio_connection(self):
# When the HF triggers the establishment of the Codec Connection it
# shall send the AT command AT+BCC to the AG. The AG shall respond with
# OK if it will start the Codec Connection procedure, and with ERROR
# if it cannot start the Codec Connection procedure.
await self.execute_command("AT+BCC")
# 4.11.3 Codec Connection Setup
async def setup_codec_connection(self, codec_id: int):
# The AG shall send a +BCS=<Codec ID> unsolicited response to the HF.
# The HF shall then respond to the incoming unsolicited response with
# the AT command AT+BCS=<Codec ID>. The ID shall be the same as in the
# unsolicited response code as long as the ID is supported.
# If the received ID is not available, the HF shall respond with
# AT+BAC with its available codecs.
if codec_id not in self.supported_audio_codecs:
codecs = [str(c) for c in self.supported_audio_codecs]
await self.execute_command(f"AT+BAC={','.join(codecs)}")
return
await self.execute_command(f"AT+BCS={codec_id}")
# After sending the OK response, the AG shall open the
# Synchronous Connection with the settings that are determined by the
# ID. The HF shall be ready to accept the synchronous connection
# establishment as soon as it has sent the AT commands AT+BCS=<Codec ID>.
logger.info("codec connection setup completed")
# 4.13.1 Answer Incoming Call from the HF In-Band Ringing
async def answer_incoming_call(self):
# The user accepts the incoming voice call by using the proper means
# provided by the HF. The HF shall then send the ATA command
# (see Section 4.34) to the AG. The AG shall then begin the procedure for
# accepting the incoming call.
await self.execute_command("ATA")
# 4.14.1 Reject an Incoming Call from the HF
async def reject_incoming_call(self):
# The user rejects the incoming call by using the User Interface on the
# Hands-Free unit. The HF shall then send the AT+CHUP command
# (see Section 4.34) to the AG. This may happen at any time during the
# procedures described in Sections 4.13.1 and 4.13.2.
await self.execute_command("AT+CHUP")
# 4.15.1 Terminate a Call Process from the HF
async def terminate_call(self):
# The user may abort the ongoing call process using whatever means
# provided by the Hands-Free unit. The HF shall send AT+CHUP command
# (see Section 4.34) to the AG, and the AG shall then start the
# procedure to terminate or interrupt the current call procedure.
# The AG shall then send the OK indication followed by the +CIEV result
# code, with the value indicating (call=0).
await self.execute_command("AT+CHUP")
async def update_ag_indicator(self, index: int, value: int):
self.ag_indicators[index].current_status = value
logger.info(
f"AG indicator updated: {self.ag_indicators[index].description}, {value}"
)
async def handle_unsolicited(self):
"""Handle unsolicited result codes sent by the audio gateway."""
result = await self.unsolicited_queue.get()
if result.code == "+BCS":
await self.setup_codec_connection(int(result.parameters[0]))
elif result.code == "+CIEV":
await self.update_ag_indicator(
int(result.parameters[0]), int(result.parameters[1])
)
else:
logging.info(f"unhandled unsolicited response {result.code}")
async def run(self):
"""Main rountine for the Hands-Free side of the HFP protocol.
Initiates the service level connection then loops handling
unsolicited AG responses."""
try:
await self.initiate_slc()
while True:
await self.handle_unsolicited()
except Exception:
logger.error("HFP-HF protocol failed with the following error:")
logger.error(traceback.format_exc())
# -----------------------------------------------------------------------------
# Normative SDP definitions
# -----------------------------------------------------------------------------
# Profile version (normative).
# Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements
class ProfileVersion(enum.IntEnum):
V1_5 = 0x0105
V1_6 = 0x0106
V1_7 = 0x0107
V1_8 = 0x0108
V1_9 = 0x0109
# HF supported features (normative).
# Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements
class HfSdpFeature(enum.IntFlag):
EC_NR = 0x01 # Echo Cancel & Noise reduction
THREE_WAY_CALLING = 0x02
CLI_PRESENTATION_CAPABILITY = 0x04
VOICE_RECOGNITION_ACTIVATION = 0x08
REMOTE_VOLUME_CONTROL = 0x10
WIDE_BAND = 0x20 # Wide band speech
ENHANCED_VOICE_RECOGNITION_STATUS = 0x40
VOICE_RECOGNITION_TEST = 0x80
# AG supported features (normative).
# Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements
class AgSdpFeature(enum.IntFlag):
THREE_WAY_CALLING = 0x01
EC_NR = 0x02 # Echo Cancel & Noise reduction
VOICE_RECOGNITION_FUNCTION = 0x04
IN_BAND_RING_TONE_CAPABILITY = 0x08
VOICE_TAG = 0x10 # Attach a number to voice tag
WIDE_BAND = 0x20 # Wide band speech
ENHANCED_VOICE_RECOGNITION_STATUS = 0x40
VOICE_RECOGNITION_TEST = 0x80
def sdp_records(
service_record_handle: int, rfcomm_channel: int, configuration: Configuration
) -> List[ServiceAttribute]:
"""Generate the SDP record for HFP Hands-Free support.
The record exposes the features supported in the input configuration,
and the allocated RFCOMM channel."""
hf_supported_features = 0
if HfFeature.EC_NR in configuration.supported_hf_features:
hf_supported_features |= HfSdpFeature.EC_NR
if HfFeature.THREE_WAY_CALLING in configuration.supported_hf_features:
hf_supported_features |= HfSdpFeature.THREE_WAY_CALLING
if HfFeature.CLI_PRESENTATION_CAPABILITY in configuration.supported_hf_features:
hf_supported_features |= HfSdpFeature.CLI_PRESENTATION_CAPABILITY
if HfFeature.VOICE_RECOGNITION_ACTIVATION in configuration.supported_hf_features:
hf_supported_features |= HfSdpFeature.VOICE_RECOGNITION_ACTIVATION
if HfFeature.REMOTE_VOLUME_CONTROL in configuration.supported_hf_features:
hf_supported_features |= HfSdpFeature.REMOTE_VOLUME_CONTROL
if (
HfFeature.ENHANCED_VOICE_RECOGNITION_STATUS
in configuration.supported_hf_features
):
hf_supported_features |= HfSdpFeature.ENHANCED_VOICE_RECOGNITION_STATUS
if HfFeature.VOICE_RECOGNITION_TEST in configuration.supported_hf_features:
hf_supported_features |= HfSdpFeature.VOICE_RECOGNITION_TEST
if AudioCodec.MSBC in configuration.supported_audio_codecs:
hf_supported_features |= HfSdpFeature.WIDE_BAND
return [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.uuid(BT_GENERIC_AUDIO_SERVICE),
]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(rfcomm_channel),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.unsigned_integer_16(ProfileVersion.V1_8),
]
)
]
),
),
ServiceAttribute(
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(hf_supported_features),
),
]

View File

@@ -20,13 +20,13 @@ import collections
import logging
import struct
from typing import Optional
from bumble.colors import color
from bumble.l2cap import L2CAP_PDU
from bumble.snoop import Snooper
from bumble import drivers
from typing import Optional
from .hci import (
Address,
HCI_ACL_DATA_PACKET,
@@ -63,16 +63,15 @@ from .hci import (
HCI_Read_Local_Version_Information_Command,
HCI_Reset_Command,
HCI_Set_Event_Mask_Command,
map_null_terminated_utf8_string,
)
from .core import (
BT_BR_EDR_TRANSPORT,
BT_CENTRAL_ROLE,
BT_LE_TRANSPORT,
ConnectionPHY,
ConnectionParameters,
)
from .utils import AbortableEventEmitter
from .transport.common import TransportLostError
# -----------------------------------------------------------------------------
@@ -349,7 +348,7 @@ class Host(AbortableEventEmitter):
return response
except Exception as error:
logger.warning(
f'{color("!!! Exception while sending HCI packet:", "red")} {error}'
f'{color("!!! Exception while sending command:", "red")} {error}'
)
raise error
finally:
@@ -455,6 +454,13 @@ class Host(AbortableEventEmitter):
else:
logger.debug('reset not done, ignoring packet from controller')
def on_transport_lost(self):
# Called by the source when the transport has been lost.
if self.pending_response:
self.pending_response.set_exception(TransportLostError('transport lost'))
self.emit('flush')
def on_hci_packet(self, packet):
logger.debug(f'{color("### CONTROLLER -> HOST", "green")}: {packet}')

View File

@@ -22,7 +22,19 @@ import struct
from collections import deque
from pyee import EventEmitter
from typing import Dict, Type
from typing import (
Dict,
Type,
List,
Optional,
Tuple,
Callable,
Any,
Union,
Deque,
Iterable,
TYPE_CHECKING,
)
from .colors import color
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
@@ -33,6 +45,9 @@ from .hci import (
name_or_number,
)
if TYPE_CHECKING:
from bumble.device import Connection
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -155,7 +170,7 @@ class L2CAP_PDU:
'''
@staticmethod
def from_bytes(data):
def from_bytes(data: bytes) -> L2CAP_PDU:
# Sanity check
if len(data) < 4:
raise ValueError('not enough data for L2CAP header')
@@ -165,18 +180,18 @@ class L2CAP_PDU:
return L2CAP_PDU(l2cap_pdu_cid, l2cap_pdu_payload)
def to_bytes(self):
def to_bytes(self) -> bytes:
header = struct.pack('<HH', len(self.payload), self.cid)
return header + self.payload
def __init__(self, cid, payload):
def __init__(self, cid: int, payload: bytes) -> None:
self.cid = cid
self.payload = payload
def __bytes__(self):
def __bytes__(self) -> bytes:
return self.to_bytes()
def __str__(self):
def __str__(self) -> str:
return f'{color("L2CAP", "green")} [CID={self.cid}]: {self.payload.hex()}'
@@ -188,10 +203,10 @@ class L2CAP_Control_Frame:
classes: Dict[int, Type[L2CAP_Control_Frame]] = {}
code = 0
name = None
name: str
@staticmethod
def from_bytes(pdu):
def from_bytes(pdu: bytes) -> L2CAP_Control_Frame:
code = pdu[0]
cls = L2CAP_Control_Frame.classes.get(code)
@@ -216,11 +231,11 @@ class L2CAP_Control_Frame:
return self
@staticmethod
def code_name(code):
def code_name(code: int) -> str:
return name_or_number(L2CAP_CONTROL_FRAME_NAMES, code)
@staticmethod
def decode_configuration_options(data):
def decode_configuration_options(data: bytes) -> List[Tuple[int, bytes]]:
options = []
while len(data) >= 2:
value_type = data[0]
@@ -232,7 +247,7 @@ class L2CAP_Control_Frame:
return options
@staticmethod
def encode_configuration_options(options):
def encode_configuration_options(options: List[Tuple[int, bytes]]) -> bytes:
return b''.join(
[bytes([option[0], len(option[1])]) + option[1] for option in options]
)
@@ -256,29 +271,30 @@ class L2CAP_Control_Frame:
return inner
def __init__(self, pdu=None, **kwargs):
def __init__(self, pdu=None, **kwargs) -> None:
self.identifier = kwargs.get('identifier', 0)
if hasattr(self, 'fields') and kwargs:
HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
data = HCI_Object.dict_to_bytes(kwargs, self.fields)
pdu = (
bytes([self.code, self.identifier])
+ struct.pack('<H', len(data))
+ data
)
if hasattr(self, 'fields'):
if kwargs:
HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
data = HCI_Object.dict_to_bytes(kwargs, self.fields)
pdu = (
bytes([self.code, self.identifier])
+ struct.pack('<H', len(data))
+ data
)
self.pdu = pdu
def init_from_bytes(self, pdu, offset):
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def to_bytes(self):
def to_bytes(self) -> bytes:
return self.pdu
def __bytes__(self):
def __bytes__(self) -> bytes:
return self.to_bytes()
def __str__(self):
def __str__(self) -> str:
result = f'{color(self.name, "yellow")} [ID={self.identifier}]'
if fields := getattr(self, 'fields', None):
result += ':\n' + HCI_Object.format_fields(self.__dict__, fields, ' ')
@@ -315,7 +331,7 @@ class L2CAP_Command_Reject(L2CAP_Control_Frame):
}
@staticmethod
def reason_name(reason):
def reason_name(reason: int) -> str:
return name_or_number(L2CAP_Command_Reject.REASON_NAMES, reason)
@@ -343,7 +359,7 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
'''
@staticmethod
def parse_psm(data, offset=0):
def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]:
psm_length = 2
psm = data[offset] | data[offset + 1] << 8
@@ -355,7 +371,7 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
return offset + psm_length, psm
@staticmethod
def serialize_psm(psm):
def serialize_psm(psm: int) -> bytes:
serialized = struct.pack('<H', psm & 0xFFFF)
psm >>= 16
while psm:
@@ -405,7 +421,7 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame):
}
@staticmethod
def result_name(result):
def result_name(result: int) -> str:
return name_or_number(L2CAP_Connection_Response.RESULT_NAMES, result)
@@ -452,7 +468,7 @@ class L2CAP_Configure_Response(L2CAP_Control_Frame):
}
@staticmethod
def result_name(result):
def result_name(result: int) -> str:
return name_or_number(L2CAP_Configure_Response.RESULT_NAMES, result)
@@ -529,7 +545,7 @@ class L2CAP_Information_Request(L2CAP_Control_Frame):
}
@staticmethod
def info_type_name(info_type):
def info_type_name(info_type: int) -> str:
return name_or_number(L2CAP_Information_Request.INFO_TYPE_NAMES, info_type)
@@ -556,7 +572,7 @@ class L2CAP_Information_Response(L2CAP_Control_Frame):
RESULT_NAMES = {SUCCESS: 'SUCCESS', NOT_SUPPORTED: 'NOT_SUPPORTED'}
@staticmethod
def result_name(result):
def result_name(result: int) -> str:
return name_or_number(L2CAP_Information_Response.RESULT_NAMES, result)
@@ -588,6 +604,8 @@ class L2CAP_LE_Credit_Based_Connection_Request(L2CAP_Control_Frame):
(CODE 0x14)
'''
source_cid: int
# -----------------------------------------------------------------------------
@L2CAP_Control_Frame.subclass(
@@ -640,7 +658,7 @@ class L2CAP_LE_Credit_Based_Connection_Response(L2CAP_Control_Frame):
}
@staticmethod
def result_name(result):
def result_name(result: int) -> str:
return name_or_number(
L2CAP_LE_Credit_Based_Connection_Response.RESULT_NAMES, result
)
@@ -701,7 +719,22 @@ class Channel(EventEmitter):
WAIT_CONTROL_IND: 'WAIT_CONTROL_IND',
}
def __init__(self, manager, connection, signaling_cid, psm, source_cid, mtu):
connection_result: Optional[asyncio.Future[None]]
disconnection_result: Optional[asyncio.Future[None]]
response: Optional[asyncio.Future[bytes]]
sink: Optional[Callable[[bytes], Any]]
state: int
connection: Connection
def __init__(
self,
manager: 'ChannelManager',
connection: Connection,
signaling_cid: int,
psm: int,
source_cid: int,
mtu: int,
) -> None:
super().__init__()
self.manager = manager
self.connection = connection
@@ -716,19 +749,19 @@ class Channel(EventEmitter):
self.disconnection_result = None
self.sink = None
def change_state(self, new_state):
def change_state(self, new_state: int) -> None:
logger.debug(
f'{self} state change -> {color(Channel.STATE_NAMES[new_state], "cyan")}'
)
self.state = new_state
def send_pdu(self, pdu):
def send_pdu(self, pdu) -> None:
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame):
def send_control_frame(self, frame) -> None:
self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
async def send_request(self, request):
async def send_request(self, request) -> bytes:
# Check that there isn't already a request pending
if self.response:
raise InvalidStateError('request already pending')
@@ -739,7 +772,7 @@ class Channel(EventEmitter):
self.send_pdu(request)
return await self.response
def on_pdu(self, pdu):
def on_pdu(self, pdu) -> None:
if self.response:
self.response.set_result(pdu)
self.response = None
@@ -751,7 +784,7 @@ class Channel(EventEmitter):
color('received pdu without a pending request or sink', 'red')
)
async def connect(self):
async def connect(self) -> None:
if self.state != Channel.CLOSED:
raise InvalidStateError('invalid state')
@@ -778,7 +811,7 @@ class Channel(EventEmitter):
finally:
self.connection_result = None
async def disconnect(self):
async def disconnect(self) -> None:
if self.state != Channel.OPEN:
raise InvalidStateError('invalid state')
@@ -796,12 +829,12 @@ class Channel(EventEmitter):
self.disconnection_result = asyncio.get_running_loop().create_future()
return await self.disconnection_result
def abort(self):
def abort(self) -> None:
if self.state == self.OPEN:
self.change_state(self.CLOSED)
self.emit('close')
def send_configure_request(self):
def send_configure_request(self) -> None:
options = L2CAP_Control_Frame.encode_configuration_options(
[
(
@@ -819,7 +852,7 @@ class Channel(EventEmitter):
)
)
def on_connection_request(self, request):
def on_connection_request(self, request) -> None:
self.destination_cid = request.source_cid
self.change_state(Channel.WAIT_CONNECT)
self.send_control_frame(
@@ -858,7 +891,7 @@ class Channel(EventEmitter):
)
self.connection_result = None
def on_configure_request(self, request):
def on_configure_request(self, request) -> None:
if self.state not in (
Channel.WAIT_CONFIG,
Channel.WAIT_CONFIG_REQ,
@@ -896,7 +929,7 @@ class Channel(EventEmitter):
elif self.state == Channel.WAIT_CONFIG_REQ_RSP:
self.change_state(Channel.WAIT_CONFIG_RSP)
def on_configure_response(self, response):
def on_configure_response(self, response) -> None:
if response.result == L2CAP_Configure_Response.SUCCESS:
if self.state == Channel.WAIT_CONFIG_REQ_RSP:
self.change_state(Channel.WAIT_CONFIG_REQ)
@@ -930,7 +963,7 @@ class Channel(EventEmitter):
)
# TODO: decide how to fail gracefully
def on_disconnection_request(self, request):
def on_disconnection_request(self, request) -> None:
if self.state in (Channel.OPEN, Channel.WAIT_DISCONNECT):
self.send_control_frame(
L2CAP_Disconnection_Response(
@@ -945,7 +978,7 @@ class Channel(EventEmitter):
else:
logger.warning(color('invalid state', 'red'))
def on_disconnection_response(self, response):
def on_disconnection_response(self, response) -> None:
if self.state != Channel.WAIT_DISCONNECT:
logger.warning(color('invalid state', 'red'))
return
@@ -964,7 +997,7 @@ class Channel(EventEmitter):
self.emit('close')
self.manager.on_channel_closed(self)
def __str__(self):
def __str__(self) -> str:
return (
f'Channel({self.source_cid}->{self.destination_cid}, '
f'PSM={self.psm}, '
@@ -995,25 +1028,32 @@ class LeConnectionOrientedChannel(EventEmitter):
CONNECTION_ERROR: 'CONNECTION_ERROR',
}
out_queue: Deque[bytes]
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
disconnection_result: Optional[asyncio.Future[None]]
out_sdu: Optional[bytes]
state: int
connection: Connection
@staticmethod
def state_name(state):
def state_name(state: int) -> str:
return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
def __init__(
self,
manager,
connection,
le_psm,
source_cid,
destination_cid,
mtu,
mps,
credits, # pylint: disable=redefined-builtin
peer_mtu,
peer_mps,
peer_credits,
connected,
):
manager: 'ChannelManager',
connection: Connection,
le_psm: int,
source_cid: int,
destination_cid: int,
mtu: int,
mps: int,
credits: int, # pylint: disable=redefined-builtin
peer_mtu: int,
peer_mps: int,
peer_credits: int,
connected: bool,
) -> None:
super().__init__()
self.manager = manager
self.connection = connection
@@ -1045,7 +1085,7 @@ class LeConnectionOrientedChannel(EventEmitter):
else:
self.state = LeConnectionOrientedChannel.INIT
def change_state(self, new_state):
def change_state(self, new_state: int) -> None:
logger.debug(
f'{self} state change -> {color(self.state_name(new_state), "cyan")}'
)
@@ -1056,13 +1096,13 @@ class LeConnectionOrientedChannel(EventEmitter):
elif new_state == self.DISCONNECTED:
self.emit('close')
def send_pdu(self, pdu):
def send_pdu(self, pdu) -> None:
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame):
def send_control_frame(self, frame) -> None:
self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame)
async def connect(self):
async def connect(self) -> LeConnectionOrientedChannel:
# Check that we're in the right state
if self.state != self.INIT:
raise InvalidStateError('not in a connectable state')
@@ -1090,7 +1130,7 @@ class LeConnectionOrientedChannel(EventEmitter):
# Wait for the connection to succeed or fail
return await self.connection_result
async def disconnect(self):
async def disconnect(self) -> None:
# Check that we're connected
if self.state != self.CONNECTED:
raise InvalidStateError('not connected')
@@ -1110,11 +1150,11 @@ class LeConnectionOrientedChannel(EventEmitter):
self.disconnection_result = asyncio.get_running_loop().create_future()
return await self.disconnection_result
def abort(self):
def abort(self) -> None:
if self.state == self.CONNECTED:
self.change_state(self.DISCONNECTED)
def on_pdu(self, pdu):
def on_pdu(self, pdu) -> None:
if self.sink is None:
logger.warning('received pdu without a sink')
return
@@ -1180,7 +1220,7 @@ class LeConnectionOrientedChannel(EventEmitter):
self.in_sdu = None
self.in_sdu_length = 0
def on_connection_response(self, response):
def on_connection_response(self, response) -> None:
# Look for a matching pending response result
if self.connection_result is None:
logger.warning(
@@ -1214,14 +1254,14 @@ class LeConnectionOrientedChannel(EventEmitter):
# Cleanup
self.connection_result = None
def on_credits(self, credits): # pylint: disable=redefined-builtin
def on_credits(self, credits: int) -> None: # pylint: disable=redefined-builtin
self.credits += credits
logger.debug(f'received {credits} credits, total = {self.credits}')
# Try to send more data if we have any queued up
self.process_output()
def on_disconnection_request(self, request):
def on_disconnection_request(self, request) -> None:
self.send_control_frame(
L2CAP_Disconnection_Response(
identifier=request.identifier,
@@ -1232,7 +1272,7 @@ class LeConnectionOrientedChannel(EventEmitter):
self.change_state(self.DISCONNECTED)
self.flush_output()
def on_disconnection_response(self, response):
def on_disconnection_response(self, response) -> None:
if self.state != self.DISCONNECTING:
logger.warning(color('invalid state', 'red'))
return
@@ -1249,11 +1289,11 @@ class LeConnectionOrientedChannel(EventEmitter):
self.disconnection_result.set_result(None)
self.disconnection_result = None
def flush_output(self):
def flush_output(self) -> None:
self.out_queue.clear()
self.out_sdu = None
def process_output(self):
def process_output(self) -> None:
while self.credits > 0:
if self.out_sdu is not None:
# Finish the current SDU
@@ -1296,7 +1336,7 @@ class LeConnectionOrientedChannel(EventEmitter):
self.drained.set()
return
def write(self, data):
def write(self, data: bytes) -> None:
if self.state != self.CONNECTED:
logger.warning('not connected, dropping data')
return
@@ -1311,18 +1351,18 @@ class LeConnectionOrientedChannel(EventEmitter):
# Send what we can
self.process_output()
async def drain(self):
async def drain(self) -> None:
await self.drained.wait()
def pause_reading(self):
def pause_reading(self) -> None:
# TODO: not implemented yet
pass
def resume_reading(self):
def resume_reading(self) -> None:
# TODO: not implemented yet
pass
def __str__(self):
def __str__(self) -> str:
return (
f'CoC({self.source_cid}->{self.destination_cid}, '
f'State={self.state_name(self.state)}, '
@@ -1335,9 +1375,21 @@ class LeConnectionOrientedChannel(EventEmitter):
# -----------------------------------------------------------------------------
class ChannelManager:
identifiers: Dict[int, int]
channels: Dict[int, Dict[int, Union[Channel, LeConnectionOrientedChannel]]]
servers: Dict[int, Callable[[Channel], Any]]
le_coc_channels: Dict[int, Dict[int, LeConnectionOrientedChannel]]
le_coc_servers: Dict[
int, Tuple[Callable[[LeConnectionOrientedChannel], Any], int, int, int]
]
le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
def __init__(
self, extended_features=(), connectionless_mtu=L2CAP_DEFAULT_CONNECTIONLESS_MTU
):
self,
extended_features: Iterable[int] = (),
connectionless_mtu: int = L2CAP_DEFAULT_CONNECTIONLESS_MTU,
) -> None:
self._host = None
self.identifiers = {} # Incrementing identifier values by connection
self.channels = {} # All channels, mapped by connection and source cid
@@ -1366,20 +1418,20 @@ class ChannelManager:
if host is not None:
host.on('disconnection', self.on_disconnection)
def find_channel(self, connection_handle, cid):
def find_channel(self, connection_handle: int, cid: int):
if connection_channels := self.channels.get(connection_handle):
return connection_channels.get(cid)
return None
def find_le_coc_channel(self, connection_handle, cid):
def find_le_coc_channel(self, connection_handle: int, cid: int):
if connection_channels := self.le_coc_channels.get(connection_handle):
return connection_channels.get(cid)
return None
@staticmethod
def find_free_br_edr_cid(channels):
def find_free_br_edr_cid(channels: Iterable[int]) -> int:
# Pick the smallest valid CID that's not already in the list
# (not necessarily the most efficient algorithm, but the list of CID is
# very small in practice)
@@ -1392,7 +1444,7 @@ class ChannelManager:
raise RuntimeError('no free CID available')
@staticmethod
def find_free_le_cid(channels):
def find_free_le_cid(channels: Iterable[int]) -> int:
# Pick the smallest valid CID that's not already in the list
# (not necessarily the most efficient algorithm, but the list of CID is
# very small in practice)
@@ -1405,7 +1457,7 @@ class ChannelManager:
raise RuntimeError('no free CID')
@staticmethod
def check_le_coc_parameters(max_credits, mtu, mps):
def check_le_coc_parameters(max_credits: int, mtu: int, mps: int) -> None:
if (
max_credits < 1
or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
@@ -1419,19 +1471,21 @@ class ChannelManager:
):
raise ValueError('MPS out of range')
def next_identifier(self, connection):
def next_identifier(self, connection: Connection) -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
self.identifiers[connection.handle] = identifier
return identifier
def register_fixed_channel(self, cid, handler):
def register_fixed_channel(
self, cid: int, handler: Callable[[int, bytes], Any]
) -> None:
self.fixed_channels[cid] = handler
def deregister_fixed_channel(self, cid):
def deregister_fixed_channel(self, cid: int) -> None:
if cid in self.fixed_channels:
del self.fixed_channels[cid]
def register_server(self, psm, server):
def register_server(self, psm: int, server: Callable[[Channel], Any]) -> int:
if psm == 0:
# Find a free PSM
for candidate in range(
@@ -1465,12 +1519,12 @@ class ChannelManager:
def register_le_coc_server(
self,
psm,
server,
max_credits=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
):
psm: int,
server: Callable[[LeConnectionOrientedChannel], Any],
max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
) -> int:
self.check_le_coc_parameters(max_credits, mtu, mps)
if psm == 0:
@@ -1498,7 +1552,7 @@ class ChannelManager:
return psm
def on_disconnection(self, connection_handle, _reason):
def on_disconnection(self, connection_handle: int, _reason: int) -> None:
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
if connection_handle in self.channels:
for _, channel in self.channels[connection_handle].items():
@@ -1511,7 +1565,7 @@ class ChannelManager:
if connection_handle in self.identifiers:
del self.identifiers[connection_handle]
def send_pdu(self, connection, cid, pdu):
def send_pdu(self, connection, cid: int, pdu) -> None:
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
logger.debug(
f'{color(">>> Sending L2CAP PDU", "blue")} '
@@ -1520,14 +1574,16 @@ class ChannelManager:
)
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
def on_pdu(self, connection, cid, pdu):
def on_pdu(self, connection: Connection, cid: int, pdu) -> None:
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
# Parse the L2CAP payload into a Control Frame object
control_frame = L2CAP_Control_Frame.from_bytes(pdu)
self.on_control_frame(connection, cid, control_frame)
elif cid in self.fixed_channels:
self.fixed_channels[cid](connection.handle, pdu)
handler = self.fixed_channels[cid]
assert handler is not None
handler(connection.handle, pdu)
else:
if (channel := self.find_channel(connection.handle, cid)) is None:
logger.warning(
@@ -1539,7 +1595,9 @@ class ChannelManager:
channel.on_pdu(pdu)
def send_control_frame(self, connection, cid, control_frame):
def send_control_frame(
self, connection: Connection, cid: int, control_frame
) -> None:
logger.debug(
f'{color(">>> Sending L2CAP Signaling Control Frame", "blue")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
@@ -1547,7 +1605,7 @@ class ChannelManager:
)
self.host.send_l2cap_pdu(connection.handle, cid, bytes(control_frame))
def on_control_frame(self, connection, cid, control_frame):
def on_control_frame(self, connection: Connection, cid: int, control_frame) -> None:
logger.debug(
f'{color("<<< Received L2CAP Signaling Control Frame", "green")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
@@ -1584,10 +1642,14 @@ class ChannelManager:
),
)
def on_l2cap_command_reject(self, _connection, _cid, packet):
def on_l2cap_command_reject(
self, _connection: Connection, _cid: int, packet
) -> None:
logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}')
def on_l2cap_connection_request(self, connection, cid, request):
def on_l2cap_connection_request(
self, connection: Connection, cid: int, request
) -> None:
# Check if there's a server for this PSM
server = self.servers.get(request.psm)
if server:
@@ -1639,7 +1701,9 @@ class ChannelManager:
),
)
def on_l2cap_connection_response(self, connection, cid, response):
def on_l2cap_connection_response(
self, connection: Connection, cid: int, response
) -> None:
if (
channel := self.find_channel(connection.handle, response.source_cid)
) is None:
@@ -1654,7 +1718,9 @@ class ChannelManager:
channel.on_connection_response(response)
def on_l2cap_configure_request(self, connection, cid, request):
def on_l2cap_configure_request(
self, connection: Connection, cid: int, request
) -> None:
if (
channel := self.find_channel(connection.handle, request.destination_cid)
) is None:
@@ -1669,7 +1735,9 @@ class ChannelManager:
channel.on_configure_request(request)
def on_l2cap_configure_response(self, connection, cid, response):
def on_l2cap_configure_response(
self, connection: Connection, cid: int, response
) -> None:
if (
channel := self.find_channel(connection.handle, response.source_cid)
) is None:
@@ -1684,7 +1752,9 @@ class ChannelManager:
channel.on_configure_response(response)
def on_l2cap_disconnection_request(self, connection, cid, request):
def on_l2cap_disconnection_request(
self, connection: Connection, cid: int, request
) -> None:
if (
channel := self.find_channel(connection.handle, request.destination_cid)
) is None:
@@ -1699,7 +1769,9 @@ class ChannelManager:
channel.on_disconnection_request(request)
def on_l2cap_disconnection_response(self, connection, cid, response):
def on_l2cap_disconnection_response(
self, connection: Connection, cid: int, response
) -> None:
if (
channel := self.find_channel(connection.handle, response.source_cid)
) is None:
@@ -1714,7 +1786,7 @@ class ChannelManager:
channel.on_disconnection_response(response)
def on_l2cap_echo_request(self, connection, cid, request):
def on_l2cap_echo_request(self, connection: Connection, cid: int, request) -> None:
logger.debug(f'<<< Echo request: data={request.data.hex()}')
self.send_control_frame(
connection,
@@ -1722,11 +1794,15 @@ class ChannelManager:
L2CAP_Echo_Response(identifier=request.identifier, data=request.data),
)
def on_l2cap_echo_response(self, _connection, _cid, response):
def on_l2cap_echo_response(
self, _connection: Connection, _cid: int, response
) -> None:
logger.debug(f'<<< Echo response: data={response.data.hex()}')
# TODO notify listeners
def on_l2cap_information_request(self, connection, cid, request):
def on_l2cap_information_request(
self, connection: Connection, cid: int, request
) -> None:
if request.info_type == L2CAP_Information_Request.CONNECTIONLESS_MTU:
result = L2CAP_Information_Response.SUCCESS
data = self.connectionless_mtu.to_bytes(2, 'little')
@@ -1750,7 +1826,9 @@ class ChannelManager:
),
)
def on_l2cap_connection_parameter_update_request(self, connection, cid, request):
def on_l2cap_connection_parameter_update_request(
self, connection: Connection, cid: int, request
):
if connection.role == BT_CENTRAL_ROLE:
self.send_control_frame(
connection,
@@ -1769,7 +1847,7 @@ class ChannelManager:
supervision_timeout=request.timeout,
min_ce_length=0,
max_ce_length=0,
)
) # type: ignore[call-arg]
)
else:
self.send_control_frame(
@@ -1781,11 +1859,15 @@ class ChannelManager:
),
)
def on_l2cap_connection_parameter_update_response(self, connection, cid, response):
def on_l2cap_connection_parameter_update_response(
self, connection: Connection, cid: int, response
) -> None:
# TODO: check response
pass
def on_l2cap_le_credit_based_connection_request(self, connection, cid, request):
def on_l2cap_le_credit_based_connection_request(
self, connection: Connection, cid: int, request
) -> None:
if request.le_psm in self.le_coc_servers:
(server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm]
@@ -1887,7 +1969,9 @@ class ChannelManager:
),
)
def on_l2cap_le_credit_based_connection_response(self, connection, _cid, response):
def on_l2cap_le_credit_based_connection_response(
self, connection: Connection, _cid: int, response
) -> None:
# Find the pending request by identifier
request = self.le_coc_requests.get(response.identifier)
if request is None:
@@ -1910,7 +1994,9 @@ class ChannelManager:
# Process the response
channel.on_connection_response(response)
def on_l2cap_le_flow_control_credit(self, connection, _cid, credit):
def on_l2cap_le_flow_control_credit(
self, connection: Connection, _cid: int, credit
) -> None:
channel = self.find_le_coc_channel(connection.handle, credit.cid)
if channel is None:
logger.warning(f'received credits for an unknown channel (cid={credit.cid}')
@@ -1918,13 +2004,15 @@ class ChannelManager:
channel.on_credits(credit.credits)
def on_channel_closed(self, channel):
def on_channel_closed(self, channel: Channel) -> None:
connection_channels = self.channels.get(channel.connection.handle)
if connection_channels:
if channel.source_cid in connection_channels:
del connection_channels[channel.source_cid]
async def open_le_coc(self, connection, psm, max_credits, mtu, mps):
async def open_le_coc(
self, connection: Connection, psm: int, max_credits: int, mtu: int, mps: int
) -> LeConnectionOrientedChannel:
self.check_le_coc_parameters(max_credits, mtu, mps)
# Find a free CID for the new channel
@@ -1965,7 +2053,7 @@ class ChannelManager:
return channel
async def connect(self, connection, psm):
async def connect(self, connection: Connection, psm: int) -> Channel:
# NOTE: this implementation hard-codes BR/EDR
# Find a free CID for a new channel

View File

@@ -19,6 +19,7 @@ import enum
from typing import Optional, Tuple
from .hci import (
Address,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
HCI_DISPLAY_ONLY_IO_CAPABILITY,
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
@@ -168,21 +169,28 @@ class PairingDelegate:
class PairingConfig:
"""Configuration for the Pairing protocol."""
class AddressType(enum.IntEnum):
PUBLIC = Address.PUBLIC_DEVICE_ADDRESS
RANDOM = Address.RANDOM_DEVICE_ADDRESS
def __init__(
self,
sc: bool = True,
mitm: bool = True,
bonding: bool = True,
delegate: Optional[PairingDelegate] = None,
identity_address_type: Optional[AddressType] = None,
) -> None:
self.sc = sc
self.mitm = mitm
self.bonding = bonding
self.delegate = delegate or PairingDelegate()
self.identity_address_type = identity_address_type
def __str__(self) -> str:
return (
f'PairingConfig(sc={self.sc}, '
f'mitm={self.mitm}, bonding={self.bonding}, '
f'identity_address_type={self.identity_address_type}, '
f'delegate[{self.delegate.io_capability}])'
)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from bumble.pairing import PairingDelegate
from bumble.pairing import PairingConfig, PairingDelegate
from dataclasses import dataclass
from typing import Any, Dict
@@ -20,6 +20,7 @@ from typing import Any, Dict
@dataclass
class Config:
io_capability: PairingDelegate.IoCapability = PairingDelegate.NO_OUTPUT_NO_INPUT
identity_address_type: PairingConfig.AddressType = PairingConfig.AddressType.RANDOM
pairing_sc_enable: bool = True
pairing_mitm_enable: bool = True
pairing_bonding_enable: bool = True
@@ -35,6 +36,12 @@ class Config:
'io_capability', 'no_output_no_input'
).upper()
self.io_capability = getattr(PairingDelegate, io_capability_name)
identity_address_type_name: str = config.get(
'identity_address_type', 'random'
).upper()
self.identity_address_type = getattr(
PairingConfig.AddressType, identity_address_type_name
)
self.pairing_sc_enable = config.get('pairing_sc_enable', True)
self.pairing_mitm_enable = config.get('pairing_mitm_enable', True)
self.pairing_bonding_enable = config.get('pairing_bonding_enable', True)

View File

@@ -34,6 +34,10 @@ from bumble.sdp import (
from typing import Any, Dict, List, Optional
# Default rootcanal HCI TCP address
ROOTCANAL_HCI_ADDRESS = "localhost:6402"
class PandoraDevice:
"""
Small wrapper around a Bumble device and it's HCI transport.
@@ -53,7 +57,9 @@ class PandoraDevice:
def __init__(self, config: Dict[str, Any]) -> None:
self.config = config
self.device = _make_device(config)
self._hci_name = config.get('transport', '')
self._hci_name = config.get(
'transport', f"tcp-client:{config.get('tcp', ROOTCANAL_HCI_ADDRESS)}"
)
self._hci = None
@property

View File

@@ -112,7 +112,7 @@ class HostService(HostServicer):
async def FactoryReset(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info('FactoryReset')
self.log.debug('FactoryReset')
# delete all bonds
if self.device.keystore is not None:
@@ -126,7 +126,7 @@ class HostService(HostServicer):
async def Reset(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info('Reset')
self.log.debug('Reset')
# clear service.
self.waited_connections.clear()
@@ -139,7 +139,7 @@ class HostService(HostServicer):
async def ReadLocalAddress(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> ReadLocalAddressResponse:
self.log.info('ReadLocalAddress')
self.log.debug('ReadLocalAddress')
return ReadLocalAddressResponse(
address=bytes(reversed(bytes(self.device.public_address)))
)
@@ -152,7 +152,7 @@ class HostService(HostServicer):
address = Address(
bytes(reversed(request.address)), address_type=Address.PUBLIC_DEVICE_ADDRESS
)
self.log.info(f"Connect to {address}")
self.log.debug(f"Connect to {address}")
try:
connection = await self.device.connect(
@@ -167,7 +167,7 @@ class HostService(HostServicer):
return ConnectResponse(connection_already_exists=empty_pb2.Empty())
raise e
self.log.info(f"Connect to {address} done (handle={connection.handle})")
self.log.debug(f"Connect to {address} done (handle={connection.handle})")
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
return ConnectResponse(connection=Connection(cookie=cookie))
@@ -186,7 +186,7 @@ class HostService(HostServicer):
if address in (Address.NIL, Address.ANY):
raise ValueError('Invalid address')
self.log.info(f"WaitConnection from {address}...")
self.log.debug(f"WaitConnection from {address}...")
connection = self.device.find_connection_by_bd_addr(
address, transport=BT_BR_EDR_TRANSPORT
@@ -201,7 +201,7 @@ class HostService(HostServicer):
# save connection has waited and respond.
self.waited_connections.add(id(connection))
self.log.info(
self.log.debug(
f"WaitConnection from {address} done (handle={connection.handle})"
)
@@ -216,7 +216,7 @@ class HostService(HostServicer):
if address in (Address.NIL, Address.ANY):
raise ValueError('Invalid address')
self.log.info(f"ConnectLE to {address}...")
self.log.debug(f"ConnectLE to {address}...")
try:
connection = await self.device.connect(
@@ -233,7 +233,7 @@ class HostService(HostServicer):
return ConnectLEResponse(connection_already_exists=empty_pb2.Empty())
raise e
self.log.info(f"ConnectLE to {address} done (handle={connection.handle})")
self.log.debug(f"ConnectLE to {address} done (handle={connection.handle})")
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
return ConnectLEResponse(connection=Connection(cookie=cookie))
@@ -243,12 +243,12 @@ class HostService(HostServicer):
self, request: DisconnectRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"Disconnect: {connection_handle}")
self.log.debug(f"Disconnect: {connection_handle}")
self.log.info("Disconnecting...")
self.log.debug("Disconnecting...")
if connection := self.device.lookup_connection(connection_handle):
await connection.disconnect(HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR)
self.log.info("Disconnected")
self.log.debug("Disconnected")
return empty_pb2.Empty()
@@ -257,7 +257,7 @@ class HostService(HostServicer):
self, request: WaitDisconnectionRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"WaitDisconnection: {connection_handle}")
self.log.debug(f"WaitDisconnection: {connection_handle}")
if connection := self.device.lookup_connection(connection_handle):
disconnection_future: asyncio.Future[
@@ -270,7 +270,7 @@ class HostService(HostServicer):
connection.on('disconnection', on_disconnection)
try:
await disconnection_future
self.log.info("Disconnected")
self.log.debug("Disconnected")
finally:
connection.remove_listener('disconnection', on_disconnection) # type: ignore
@@ -378,7 +378,7 @@ class HostService(HostServicer):
try:
while True:
if not self.device.is_advertising:
self.log.info('Advertise')
self.log.debug('Advertise')
await self.device.start_advertising(
target=target,
advertising_type=advertising_type,
@@ -393,10 +393,10 @@ class HostService(HostServicer):
bumble.device.Connection
] = asyncio.get_running_loop().create_future()
self.log.info('Wait for LE connection...')
self.log.debug('Wait for LE connection...')
connection = await pending_connection
self.log.info(
self.log.debug(
f"Advertise: Connected to {connection.peer_address} (handle={connection.handle})"
)
@@ -410,7 +410,7 @@ class HostService(HostServicer):
self.device.remove_listener('connection', on_connection) # type: ignore
try:
self.log.info('Stop advertising')
self.log.debug('Stop advertising')
await self.device.abort_on('flush', self.device.stop_advertising())
except:
pass
@@ -423,7 +423,7 @@ class HostService(HostServicer):
if request.phys:
raise NotImplementedError("TODO: add support for `request.phys`")
self.log.info('Scan')
self.log.debug('Scan')
scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
handler = self.device.on('advertisement', scan_queue.put_nowait)
@@ -470,7 +470,7 @@ class HostService(HostServicer):
finally:
self.device.remove_listener('advertisement', handler) # type: ignore
try:
self.log.info('Stop scanning')
self.log.debug('Stop scanning')
await self.device.abort_on('flush', self.device.stop_scanning())
except:
pass
@@ -479,7 +479,7 @@ class HostService(HostServicer):
async def Inquiry(
self, request: empty_pb2.Empty, context: grpc.ServicerContext
) -> AsyncGenerator[InquiryResponse, None]:
self.log.info('Inquiry')
self.log.debug('Inquiry')
inquiry_queue: asyncio.Queue[
Optional[Tuple[Address, int, AdvertisingData, int]]
@@ -510,7 +510,7 @@ class HostService(HostServicer):
self.device.remove_listener('inquiry_complete', complete_handler) # type: ignore
self.device.remove_listener('inquiry_result', result_handler) # type: ignore
try:
self.log.info('Stop inquiry')
self.log.debug('Stop inquiry')
await self.device.abort_on('flush', self.device.stop_discovery())
except:
pass
@@ -519,7 +519,7 @@ class HostService(HostServicer):
async def SetDiscoverabilityMode(
self, request: SetDiscoverabilityModeRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info("SetDiscoverabilityMode")
self.log.debug("SetDiscoverabilityMode")
await self.device.set_discoverable(request.mode != NOT_DISCOVERABLE)
return empty_pb2.Empty()
@@ -527,7 +527,7 @@ class HostService(HostServicer):
async def SetConnectabilityMode(
self, request: SetConnectabilityModeRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
self.log.info("SetConnectabilityMode")
self.log.debug("SetConnectabilityMode")
await self.device.set_connectable(request.mode != NOT_CONNECTABLE)
return empty_pb2.Empty()

View File

@@ -99,7 +99,7 @@ class PairingDelegate(BasePairingDelegate):
return ev
async def confirm(self, auto: bool = False) -> bool:
self.log.info(
self.log.debug(
f"Pairing event: `just_works` (io_capability: {self.io_capability})"
)
@@ -114,7 +114,7 @@ class PairingDelegate(BasePairingDelegate):
return answer.confirm
async def compare_numbers(self, number: int, digits: int = 6) -> bool:
self.log.info(
self.log.debug(
f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})"
)
@@ -129,7 +129,7 @@ class PairingDelegate(BasePairingDelegate):
return answer.confirm
async def get_number(self) -> Optional[int]:
self.log.info(
self.log.debug(
f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
)
@@ -146,7 +146,7 @@ class PairingDelegate(BasePairingDelegate):
return answer.passkey
async def get_string(self, max_length: int) -> Optional[str]:
self.log.info(
self.log.debug(
f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
)
@@ -177,7 +177,7 @@ class PairingDelegate(BasePairingDelegate):
):
return
self.log.info(
self.log.debug(
f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})"
)
@@ -232,6 +232,7 @@ class SecurityService(SecurityServicer):
sc=config.pairing_sc_enable,
mitm=config.pairing_mitm_enable,
bonding=config.pairing_bonding_enable,
identity_address_type=config.identity_address_type,
delegate=PairingDelegate(
connection,
self,
@@ -247,7 +248,7 @@ class SecurityService(SecurityServicer):
async def OnPairing(
self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
) -> AsyncGenerator[PairingEvent, None]:
self.log.info('OnPairing')
self.log.debug('OnPairing')
if self.event_queue is not None:
raise RuntimeError('already streaming pairing events')
@@ -273,7 +274,7 @@ class SecurityService(SecurityServicer):
self, request: SecureRequest, context: grpc.ServicerContext
) -> SecureResponse:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"Secure: {connection_handle}")
self.log.debug(f"Secure: {connection_handle}")
connection = self.device.lookup_connection(connection_handle)
assert connection
@@ -291,7 +292,7 @@ class SecurityService(SecurityServicer):
# trigger pairing if needed
if self.need_pairing(connection, level):
try:
self.log.info('Pair...')
self.log.debug('Pair...')
if (
connection.transport == BT_LE_TRANSPORT
@@ -309,7 +310,7 @@ class SecurityService(SecurityServicer):
else:
await connection.pair()
self.log.info('Paired')
self.log.debug('Paired')
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -320,9 +321,9 @@ class SecurityService(SecurityServicer):
# trigger authentication if needed
if self.need_authentication(connection, level):
try:
self.log.info('Authenticate...')
self.log.debug('Authenticate...')
await connection.authenticate()
self.log.info('Authenticated')
self.log.debug('Authenticated')
except asyncio.CancelledError:
self.log.warning("Connection died during authentication")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -333,9 +334,9 @@ class SecurityService(SecurityServicer):
# trigger encryption if needed
if self.need_encryption(connection, level):
try:
self.log.info('Encrypt...')
self.log.debug('Encrypt...')
await connection.encrypt()
self.log.info('Encrypted')
self.log.debug('Encrypted')
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -353,7 +354,7 @@ class SecurityService(SecurityServicer):
self, request: WaitSecurityRequest, context: grpc.ServicerContext
) -> WaitSecurityResponse:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.info(f"WaitSecurity: {connection_handle}")
self.log.debug(f"WaitSecurity: {connection_handle}")
connection = self.device.lookup_connection(connection_handle)
assert connection
@@ -390,7 +391,7 @@ class SecurityService(SecurityServicer):
def set_failure(name: str) -> Callable[..., None]:
def wrapper(*args: Any) -> None:
self.log.info(f'Wait for security: error `{name}`: {args}')
self.log.debug(f'Wait for security: error `{name}`: {args}')
wait_for_security.set_result(name)
return wrapper
@@ -398,13 +399,13 @@ class SecurityService(SecurityServicer):
def try_set_success(*_: Any) -> None:
assert connection
if self.reached_security_level(connection, level):
self.log.info('Wait for security: done')
self.log.debug('Wait for security: done')
wait_for_security.set_result('success')
def on_encryption_change(*_: Any) -> None:
assert connection
if self.reached_security_level(connection, level):
self.log.info('Wait for security: done')
self.log.debug('Wait for security: done')
wait_for_security.set_result('success')
elif (
connection.transport == BT_BR_EDR_TRANSPORT
@@ -432,7 +433,7 @@ class SecurityService(SecurityServicer):
if self.reached_security_level(connection, level):
return WaitSecurityResponse(success=empty_pb2.Empty())
self.log.info('Wait for security...')
self.log.debug('Wait for security...')
kwargs = {}
kwargs[await wait_for_security] = empty_pb2.Empty()
@@ -442,12 +443,12 @@ class SecurityService(SecurityServicer):
# wait for `authenticate` to finish if any
if authenticate_task is not None:
self.log.info('Wait for authentication...')
self.log.debug('Wait for authentication...')
try:
await authenticate_task # type: ignore
except:
pass
self.log.info('Authenticated')
self.log.debug('Authenticated')
return WaitSecurityResponse(**kwargs)
@@ -503,7 +504,7 @@ class SecurityStorageService(SecurityStorageServicer):
self, request: IsBondedRequest, context: grpc.ServicerContext
) -> wrappers_pb2.BoolValue:
address = utils.address_from_request(request, request.WhichOneof("address"))
self.log.info(f"IsBonded: {address}")
self.log.debug(f"IsBonded: {address}")
if self.device.keystore is not None:
is_bonded = await self.device.keystore.get(str(address)) is not None
@@ -517,7 +518,7 @@ class SecurityStorageService(SecurityStorageServicer):
self, request: DeleteBondRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
address = utils.address_from_request(request, request.WhichOneof("address"))
self.log.info(f"DeleteBond: {address}")
self.log.debug(f"DeleteBond: {address}")
if self.device.keystore is not None:
with suppress(KeyError):

View File

@@ -94,6 +94,10 @@ SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID = 0X000B
SDP_ICON_URL_ATTRIBUTE_ID = 0X000C
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D
# Attribute Identifier (cf. Assigned Numbers for Service Discovery)
# used by AVRCP, HFP and A2DP
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311
SDP_ATTRIBUTE_ID_NAMES = {
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID: 'SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID',
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID: 'SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID',

View File

@@ -25,6 +25,7 @@
from __future__ import annotations
import logging
import asyncio
import enum
import secrets
from typing import (
TYPE_CHECKING,
@@ -553,20 +554,16 @@ class AddressResolver:
# -----------------------------------------------------------------------------
class Session:
# Pairing methods
class PairingMethod(enum.IntEnum):
JUST_WORKS = 0
NUMERIC_COMPARISON = 1
PASSKEY = 2
OOB = 3
CTKD_OVER_CLASSIC = 4
PAIRING_METHOD_NAMES = {
JUST_WORKS: 'JUST_WORKS',
NUMERIC_COMPARISON: 'NUMERIC_COMPARISON',
PASSKEY: 'PASSKEY',
OOB: 'OOB',
}
# -----------------------------------------------------------------------------
class Session:
# I/O Capability to pairing method decision matrix
#
# See Bluetooth spec @ Vol 3, part H - Table 2.8: Mapping of IO Capabilities to Key
@@ -581,47 +578,50 @@ class Session:
# (False).
PAIRING_METHODS = {
SMP_DISPLAY_ONLY_IO_CAPABILITY: {
SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS,
SMP_DISPLAY_YES_NO_IO_CAPABILITY: JUST_WORKS,
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PASSKEY, True, False),
SMP_DISPLAY_ONLY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_DISPLAY_YES_NO_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, True, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PairingMethod.PASSKEY, True, False),
},
SMP_DISPLAY_YES_NO_IO_CAPABILITY: {
SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS,
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (JUST_WORKS, NUMERIC_COMPARISON),
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
SMP_DISPLAY_ONLY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
),
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, True, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (
(PASSKEY, True, False),
NUMERIC_COMPARISON,
(PairingMethod.PASSKEY, True, False),
PairingMethod.NUMERIC_COMPARISON,
),
},
SMP_KEYBOARD_ONLY_IO_CAPABILITY: {
SMP_DISPLAY_ONLY_IO_CAPABILITY: (PASSKEY, False, True),
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (PASSKEY, False, True),
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, False, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PASSKEY, False, True),
SMP_DISPLAY_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, False, True),
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (PairingMethod.PASSKEY, False, True),
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, False, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (PairingMethod.PASSKEY, False, True),
},
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: {
SMP_DISPLAY_ONLY_IO_CAPABILITY: JUST_WORKS,
SMP_DISPLAY_YES_NO_IO_CAPABILITY: JUST_WORKS,
SMP_KEYBOARD_ONLY_IO_CAPABILITY: JUST_WORKS,
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: JUST_WORKS,
SMP_DISPLAY_ONLY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_DISPLAY_YES_NO_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_KEYBOARD_ONLY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: PairingMethod.JUST_WORKS,
},
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: {
SMP_DISPLAY_ONLY_IO_CAPABILITY: (PASSKEY, False, True),
SMP_DISPLAY_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, False, True),
SMP_DISPLAY_YES_NO_IO_CAPABILITY: (
(PASSKEY, False, True),
NUMERIC_COMPARISON,
(PairingMethod.PASSKEY, False, True),
PairingMethod.NUMERIC_COMPARISON,
),
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PASSKEY, True, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: JUST_WORKS,
SMP_KEYBOARD_ONLY_IO_CAPABILITY: (PairingMethod.PASSKEY, True, False),
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY: PairingMethod.JUST_WORKS,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY: (
(PASSKEY, True, False),
NUMERIC_COMPARISON,
(PairingMethod.PASSKEY, True, False),
PairingMethod.NUMERIC_COMPARISON,
),
},
}
@@ -664,7 +664,7 @@ class Session:
self.passkey_ready = asyncio.Event()
self.passkey_step = 0
self.passkey_display = False
self.pairing_method = 0
self.pairing_method: PairingMethod = PairingMethod.JUST_WORKS
self.pairing_config = pairing_config
self.wait_before_continuing: Optional[asyncio.Future[None]] = None
self.completed = False
@@ -769,19 +769,23 @@ class Session:
def decide_pairing_method(
self, auth_req: int, initiator_io_capability: int, responder_io_capability: int
) -> None:
if self.connection.transport == BT_BR_EDR_TRANSPORT:
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
return
if (not self.mitm) and (auth_req & SMP_MITM_AUTHREQ == 0):
self.pairing_method = self.JUST_WORKS
self.pairing_method = PairingMethod.JUST_WORKS
return
details = self.PAIRING_METHODS[initiator_io_capability][responder_io_capability] # type: ignore[index]
if isinstance(details, tuple) and len(details) == 2:
# One entry for legacy pairing and one for secure connections
details = details[1 if self.sc else 0]
if isinstance(details, int):
if isinstance(details, PairingMethod):
# Just a method ID
self.pairing_method = details
else:
# PASSKEY method, with a method ID and display/input flags
assert isinstance(details[0], PairingMethod)
self.pairing_method = details[0]
self.passkey_display = details[1 if self.is_initiator else 2]
@@ -858,10 +862,13 @@ class Session:
self.tk = self.passkey.to_bytes(16, byteorder='little')
logger.debug(f'TK from passkey = {self.tk.hex()}')
self.connection.abort_on(
'disconnection',
self.pairing_config.delegate.display_number(self.passkey, digits=6),
)
try:
self.connection.abort_on(
'disconnection',
self.pairing_config.delegate.display_number(self.passkey, digits=6),
)
except Exception as error:
logger.warning(f'exception while displaying number: {error}')
def input_passkey(self, next_steps: Optional[Callable[[], None]] = None) -> None:
# Prompt the user for the passkey displayed on the peer
@@ -929,9 +936,12 @@ class Session:
if self.sc:
async def next_steps() -> None:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
):
z = 0
elif self.pairing_method == self.PASSKEY:
elif self.pairing_method == PairingMethod.PASSKEY:
# We need a passkey
await self.passkey_ready.wait()
assert self.passkey
@@ -983,6 +993,19 @@ class Session:
)
)
def send_identity_address_command(self) -> None:
identity_address = {
None: self.connection.self_address,
Address.PUBLIC_DEVICE_ADDRESS: self.manager.device.public_address,
Address.RANDOM_DEVICE_ADDRESS: self.manager.device.random_address,
}[self.pairing_config.identity_address_type]
self.send_command(
SMP_Identity_Address_Information_Command(
addr_type=identity_address.address_type,
bd_addr=identity_address,
)
)
def start_encryption(self, key: bytes) -> None:
# We can now encrypt the connection with the short term key, so that we can
# distribute the long term and/or other keys over an encrypted connection
@@ -1006,6 +1029,7 @@ class Session:
self.ltk = crypto.h6(ilk, b'brle')
def distribute_keys(self) -> None:
# Distribute the keys as required
if self.is_initiator:
# CTKD: Derive LTK from LinkKey
@@ -1035,12 +1059,7 @@ class Session:
identity_resolving_key=self.manager.device.irk
)
)
self.send_command(
SMP_Identity_Address_Information_Command(
addr_type=self.connection.self_address.address_type,
bd_addr=self.connection.self_address,
)
)
self.send_identity_address_command()
# Distribute CSRK
csrk = bytes(16) # FIXME: testing
@@ -1084,12 +1103,7 @@ class Session:
identity_resolving_key=self.manager.device.irk
)
)
self.send_command(
SMP_Identity_Address_Information_Command(
addr_type=self.connection.self_address.address_type,
bd_addr=self.connection.self_address,
)
)
self.send_identity_address_command()
# Distribute CSRK
csrk = bytes(16) # FIXME: testing
@@ -1224,7 +1238,7 @@ class Session:
# Create an object to hold the keys
keys = PairingKeys()
keys.address_type = peer_address.address_type
authenticated = self.pairing_method != self.JUST_WORKS
authenticated = self.pairing_method != PairingMethod.JUST_WORKS
if self.sc or self.connection.transport == BT_BR_EDR_TRANSPORT:
keys.ltk = PairingKeys.Key(value=self.ltk, authenticated=authenticated)
else:
@@ -1258,7 +1272,7 @@ class Session:
keys.link_key = PairingKeys.Key(
value=self.link_key, authenticated=authenticated
)
self.manager.on_pairing(self, peer_address, keys)
await self.manager.on_pairing(self, peer_address, keys)
def on_pairing_failure(self, reason: int) -> None:
logger.warning(f'pairing failure ({error_name(reason)})')
@@ -1300,7 +1314,11 @@ class Session:
self, command: SMP_Pairing_Request_Command
) -> None:
# Check if the request should proceed
accepted = await self.pairing_config.delegate.accept()
try:
accepted = await self.pairing_config.delegate.accept()
except Exception as error:
logger.warning(f'exception while accepting: {error}')
accepted = False
if not accepted:
logger.debug('pairing rejected by delegate')
self.send_pairing_failed(SMP_PAIRING_NOT_SUPPORTED_ERROR)
@@ -1323,9 +1341,7 @@ class Session:
self.decide_pairing_method(
command.auth_req, command.io_capability, self.io_capability
)
logger.debug(
f'pairing method: {self.PAIRING_METHOD_NAMES[self.pairing_method]}'
)
logger.debug(f'pairing method: {self.pairing_method.name}')
# Key distribution
(
@@ -1341,7 +1357,7 @@ class Session:
# Display a passkey if we need to
if not self.sc:
if self.pairing_method == self.PASSKEY and self.passkey_display:
if self.pairing_method == PairingMethod.PASSKEY and self.passkey_display:
self.display_passkey()
# Respond
@@ -1382,9 +1398,7 @@ class Session:
self.decide_pairing_method(
command.auth_req, self.io_capability, command.io_capability
)
logger.debug(
f'pairing method: {self.PAIRING_METHOD_NAMES[self.pairing_method]}'
)
logger.debug(f'pairing method: {self.pairing_method.name}')
# Key distribution
if (
@@ -1400,13 +1414,16 @@ class Session:
self.compute_peer_expected_distributions(self.responder_key_distribution)
# Start phase 2
if self.sc:
if self.pairing_method == self.PASSKEY:
if self.pairing_method == PairingMethod.CTKD_OVER_CLASSIC:
# Authentication is already done in SMP, so remote shall start keys distribution immediately
return
elif self.sc:
if self.pairing_method == PairingMethod.PASSKEY:
self.display_or_input_passkey()
self.send_public_key_command()
else:
if self.pairing_method == self.PASSKEY:
if self.pairing_method == PairingMethod.PASSKEY:
self.display_or_input_passkey(self.send_pairing_confirm_command)
else:
self.send_pairing_confirm_command()
@@ -1418,7 +1435,10 @@ class Session:
self.send_pairing_random_command()
else:
# If the method is PASSKEY, now is the time to input the code
if self.pairing_method == self.PASSKEY and not self.passkey_display:
if (
self.pairing_method == PairingMethod.PASSKEY
and not self.passkey_display
):
self.input_passkey(self.send_pairing_confirm_command)
else:
self.send_pairing_confirm_command()
@@ -1426,11 +1446,14 @@ class Session:
def on_smp_pairing_confirm_command_secure_connections(
self, _: SMP_Pairing_Confirm_Command
) -> None:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
):
if self.is_initiator:
self.r = crypto.r()
self.send_pairing_random_command()
elif self.pairing_method == self.PASSKEY:
elif self.pairing_method == PairingMethod.PASSKEY:
if self.is_initiator:
self.send_pairing_random_command()
else:
@@ -1486,13 +1509,16 @@ class Session:
def on_smp_pairing_random_command_secure_connections(
self, command: SMP_Pairing_Random_Command
) -> None:
if self.pairing_method == self.PASSKEY and self.passkey is None:
if self.pairing_method == PairingMethod.PASSKEY and self.passkey is None:
logger.warning('no passkey entered, ignoring command')
return
# pylint: disable=too-many-return-statements
if self.is_initiator:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
):
assert self.confirm_value
# Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4(
@@ -1502,7 +1528,7 @@ class Session:
self.confirm_value, confirm_verifier, SMP_CONFIRM_VALUE_FAILED_ERROR
):
return
elif self.pairing_method == self.PASSKEY:
elif self.pairing_method == PairingMethod.PASSKEY:
assert self.passkey and self.confirm_value
# Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4(
@@ -1525,9 +1551,12 @@ class Session:
else:
return
else:
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
):
self.send_pairing_random_command()
elif self.pairing_method == self.PASSKEY:
elif self.pairing_method == PairingMethod.PASSKEY:
assert self.passkey and self.confirm_value
# Check that the random value matches what was committed to earlier
confirm_verifier = crypto.f4(
@@ -1558,10 +1587,13 @@ class Session:
(mac_key, self.ltk) = crypto.f5(self.dh_key, self.na, self.nb, a, b)
# Compute the DH Key checks
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
):
ra = bytes(16)
rb = ra
elif self.pairing_method == self.PASSKEY:
elif self.pairing_method == PairingMethod.PASSKEY:
assert self.passkey
ra = self.passkey.to_bytes(16, byteorder='little')
rb = ra
@@ -1585,13 +1617,16 @@ class Session:
self.wait_before_continuing.set_result(None)
# Prompt the user for confirmation if needed
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
):
# Compute the 6-digit code
code = crypto.g2(self.pka, self.pkb, self.na, self.nb) % 1000000
# Ask for user confirmation
self.wait_before_continuing = asyncio.get_running_loop().create_future()
if self.pairing_method == self.JUST_WORKS:
if self.pairing_method == PairingMethod.JUST_WORKS:
self.prompt_user_for_confirmation(next_steps)
else:
self.prompt_user_for_numeric_comparison(code, next_steps)
@@ -1628,13 +1663,16 @@ class Session:
if self.is_initiator:
self.send_pairing_confirm_command()
else:
if self.pairing_method == self.PASSKEY:
if self.pairing_method == PairingMethod.PASSKEY:
self.display_or_input_passkey()
# Send our public key back to the initiator
self.send_public_key_command()
if self.pairing_method in (self.JUST_WORKS, self.NUMERIC_COMPARISON):
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
):
# We can now send the confirmation value
self.send_pairing_confirm_command()
@@ -1789,20 +1827,13 @@ class Manager(EventEmitter):
def on_session_start(self, session: Session) -> None:
self.device.on_pairing_start(session.connection)
def on_pairing(
async def on_pairing(
self, session: Session, identity_address: Optional[Address], keys: PairingKeys
) -> None:
# Store the keys in the key store
if self.device.keystore and identity_address is not None:
async def store_keys():
try:
assert self.device.keystore
await self.device.keystore.update(str(identity_address), keys)
except Exception as error:
logger.warning(f'!!! error while storing keys: {error}')
self.device.abort_on('flush', store_keys())
await self.device.keystore.update(str(identity_address), keys)
await self.device.refresh_resolving_list()
# Notify the device
self.device.on_pairing(session.connection, identity_address, keys, session.sc)

View File

@@ -44,11 +44,18 @@ HCI_PACKET_INFO = {
}
# -----------------------------------------------------------------------------
class TransportLostError(Exception):
"""
The Transport has been lost/disconnected.
"""
# -----------------------------------------------------------------------------
class PacketPump:
'''
Pump HCI packets from a reader to a sink
'''
"""
Pump HCI packets from a reader to a sink.
"""
def __init__(self, reader, sink):
self.reader = reader
@@ -68,10 +75,10 @@ class PacketPump:
# -----------------------------------------------------------------------------
class PacketParser:
'''
"""
In-line parser that accepts data and emits 'on_packet' when a full packet has been
parsed
'''
parsed.
"""
# pylint: disable=attribute-defined-outside-init
@@ -134,9 +141,9 @@ class PacketParser:
# -----------------------------------------------------------------------------
class PacketReader:
'''
Reader that reads HCI packets from a sync source
'''
"""
Reader that reads HCI packets from a sync source.
"""
def __init__(self, source):
self.source = source
@@ -169,9 +176,9 @@ class PacketReader:
# -----------------------------------------------------------------------------
class AsyncPacketReader:
'''
Reader that reads HCI packets from an async source
'''
"""
Reader that reads HCI packets from an async source.
"""
def __init__(self, source):
self.source = source
@@ -198,9 +205,9 @@ class AsyncPacketReader:
# -----------------------------------------------------------------------------
class AsyncPipeSink:
'''
Sink that forwards packets asynchronously to another sink
'''
"""
Sink that forwards packets asynchronously to another sink.
"""
def __init__(self, sink):
self.sink = sink
@@ -216,6 +223,9 @@ class ParserSource:
Base class designed to be subclassed by transport-specific source classes
"""
terminated: asyncio.Future
parser: PacketParser
def __init__(self):
self.parser = PacketParser()
self.terminated = asyncio.get_running_loop().create_future()
@@ -223,7 +233,19 @@ class ParserSource:
def set_packet_sink(self, sink):
self.parser.set_packet_sink(sink)
def on_transport_lost(self):
self.terminated.set_result(None)
if self.parser.sink:
try:
self.parser.sink.on_transport_lost()
except AttributeError:
pass
async def wait_for_termination(self):
"""
Convenience method for backward compatibility. Prefer using the `terminated`
attribute instead.
"""
return await self.terminated
def close(self):

View File

@@ -39,7 +39,7 @@ async def open_tcp_client_transport(spec):
class TcpPacketSource(StreamPacketSource):
def connection_lost(self, exc):
logger.debug(f'connection lost: {exc}')
self.terminated.set_result(exc)
self.on_transport_lost()
remote_host, remote_port = spec.split(':')
tcp_transport, packet_source = await asyncio.get_running_loop().create_connection(

View File

@@ -1,11 +1,11 @@
UDP TRANSPORT
=============
WEBSOCKET CLIENT TRANSPORT
==========================
The UDP transport is a UDP socket, receiving packets on a specified port number, and sending packets to a specified host and port number.
The WebSocket Client transport is WebSocket connection to a WebSocket server over which HCI packets
are sent and received.
## Moniker
The moniker syntax for a UDP transport is: `udp:<local-host>:<local-port>,<remote-host>:<remote-port>`.
The moniker syntax for a WebSocket Client transport is: `ws-client:<ws-url>`
!!! example
`udp:0.0.0.0:9000,127.0.0.1:9001`
UDP transport where packets are received on port `9000` and sent to `127.0.0.1` on port `9001`
`ws-client:ws://localhost:1234/some/path`

View File

@@ -1,11 +1,13 @@
UDP TRANSPORT
=============
WEBSOCKET SERVER TRANSPORT
==========================
The UDP transport is a UDP socket, receiving packets on a specified port number, and sending packets to a specified host and port number.
The WebSocket Server transport is WebSocket server that accepts connections from a WebSocket
client. HCI packets are sent and received over the connection.
## Moniker
The moniker syntax for a UDP transport is: `udp:<local-host>:<local-port>,<remote-host>:<remote-port>`.
The moniker syntax for a WebSocket Server transport is: `ws-server:<host>:<port>`,
where `<host>` may be the address of a local network interface, or `_`to accept connections on all local network interfaces. `<port>` is the TCP port number on which to accept connections.
!!! example
`udp:0.0.0.0:9000,127.0.0.1:9001`
UDP transport where packets are received on port `9000` and sent to `127.0.0.1` on port `9001`
`ws-server:_:9001`

View File

@@ -16,9 +16,11 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import collections
import sys
import os
import logging
from typing import Union
from bumble.colors import color
@@ -30,6 +32,7 @@ from bumble.core import (
BT_RFCOMM_PROTOCOL_ID,
BT_BR_EDR_TRANSPORT,
)
from bumble import rfcomm
from bumble.rfcomm import Client
from bumble.sdp import (
Client as SDP_Client,
@@ -39,7 +42,64 @@ from bumble.sdp import (
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
from bumble.hfp import HfpProtocol
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Protocol Support
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class HfpProtocol:
dlc: rfcomm.DLC
buffer: str
lines: collections.deque
lines_available: asyncio.Event
def __init__(self, dlc: rfcomm.DLC) -> None:
self.dlc = dlc
self.buffer = ''
self.lines = collections.deque()
self.lines_available = asyncio.Event()
dlc.sink = self.feed
def feed(self, data: Union[bytes, str]) -> None:
# Convert the data to a string if needed
if isinstance(data, bytes):
data = data.decode('utf-8')
logger.debug(f'<<< Data received: {data}')
# Add to the buffer and look for lines
self.buffer += data
while (separator := self.buffer.find('\r')) >= 0:
line = self.buffer[:separator].strip()
self.buffer = self.buffer[separator + 1 :]
if len(line) > 0:
self.on_line(line)
def on_line(self, line: str) -> None:
self.lines.append(line)
self.lines_available.set()
def send_command_line(self, line: str) -> None:
logger.debug(color(f'>>> {line}', 'yellow'))
self.dlc.write(line + '\r')
def send_response_line(self, line: str) -> None:
logger.debug(color(f'>>> {line}', 'yellow'))
self.dlc.write('\r\n' + line + '\r\n')
async def next_line(self) -> str:
await self.lines_available.wait()
line = self.lines.popleft()
if not self.lines:
self.lines_available.clear()
logger.debug(color(f'<<< {line}', 'green'))
return line
# -----------------------------------------------------------------------------

View File

@@ -21,82 +21,22 @@ import os
import logging
import json
import websockets
from typing import Optional
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.rfcomm import Server as RfcommServer
from bumble.sdp import (
DataElement,
ServiceAttribute,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
from bumble.core import (
BT_GENERIC_AUDIO_SERVICE,
BT_HANDSFREE_SERVICE,
BT_L2CAP_PROTOCOL_ID,
BT_RFCOMM_PROTOCOL_ID,
)
from bumble.hfp import HfpProtocol
# -----------------------------------------------------------------------------
def make_sdp_records(rfcomm_channel):
return {
0x00010001: [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(0x00010001),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.uuid(BT_GENERIC_AUDIO_SERVICE),
]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(rfcomm_channel),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.unsigned_integer_16(0x0105),
]
)
]
),
),
]
}
from bumble import hfp
from bumble.hfp import HfProtocol
# -----------------------------------------------------------------------------
class UiServer:
protocol = None
protocol: Optional[HfProtocol] = None
async def start(self):
# Start a Websocket server to receive events from a web page
"""Start a Websocket server to receive events from a web page."""
async def serve(websocket, _path):
while True:
try:
@@ -107,7 +47,7 @@ class UiServer:
message_type = parsed['type']
if message_type == 'at_command':
if self.protocol is not None:
self.protocol.send_command_line(parsed['command'])
await self.protocol.execute_command(parsed['command'])
except websockets.exceptions.ConnectionClosedOK:
pass
@@ -117,19 +57,11 @@ class UiServer:
# -----------------------------------------------------------------------------
async def protocol_loop(protocol):
await protocol.initialize_service()
while True:
await (protocol.next_line())
# -----------------------------------------------------------------------------
def on_dlc(dlc):
def on_dlc(dlc, configuration: hfp.Configuration):
print('*** DLC connected', dlc)
protocol = HfpProtocol(dlc)
protocol = HfProtocol(dlc, configuration)
UiServer.protocol = protocol
asyncio.create_task(protocol_loop(protocol))
asyncio.create_task(protocol.run())
# -----------------------------------------------------------------------------
@@ -143,6 +75,27 @@ async def main():
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
print('<<< connected')
# Hands-Free profile configuration.
# TODO: load configuration from file.
configuration = hfp.Configuration(
supported_hf_features=[
hfp.HfFeature.THREE_WAY_CALLING,
hfp.HfFeature.REMOTE_VOLUME_CONTROL,
hfp.HfFeature.ENHANCED_CALL_STATUS,
hfp.HfFeature.ENHANCED_CALL_CONTROL,
hfp.HfFeature.CODEC_NEGOTIATION,
hfp.HfFeature.HF_INDICATORS,
hfp.HfFeature.ESCO_S4_SETTINGS_SUPPORTED,
],
supported_hf_indicators=[
hfp.HfIndicator.BATTERY_LEVEL,
],
supported_audio_codecs=[
hfp.AudioCodec.CVSD,
hfp.AudioCodec.MSBC,
],
)
# Create a device
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
device.classic_enabled = True
@@ -151,11 +104,13 @@ async def main():
rfcomm_server = RfcommServer(device)
# Listen for incoming DLC connections
channel_number = rfcomm_server.listen(on_dlc)
channel_number = rfcomm_server.listen(lambda dlc: on_dlc(dlc, configuration))
print(f'### Listening for connection on channel {channel_number}')
# Advertise the HFP RFComm channel in the SDP
device.sdp_service_records = make_sdp_records(channel_number)
device.sdp_service_records = {
0x00010001: hfp.sdp_records(0x00010001, channel_number, configuration)
}
# Let's go!
await device.power_on()

View File

@@ -62,7 +62,7 @@ async def main():
print(
f'>>> {color(advertisement.address, address_color)} '
f'[{color(address_type_string, type_color)}]'
f'{address_qualifier}:{separator}RSSI:{advertisement.rssi}'
f'{address_qualifier}:{separator}RSSI: {advertisement.rssi}'
f'{separator}'
f'{advertisement.data.to_string(separator)}'
)

BIN
mmm.sbc Normal file

Binary file not shown.

2
rust/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
/target
/.idea

7
rust/CHANGELOG.md Normal file
View File

@@ -0,0 +1,7 @@
# Next
- Code-gen company ID table
# 0.1.0
- Initial release

1194
rust/Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

56
rust/Cargo.toml Normal file
View File

@@ -0,0 +1,56 @@
[package]
name = "bumble"
description = "Rust API for the Bumble Bluetooth stack"
version = "0.1.0"
edition = "2021"
license = "Apache-2.0"
homepage = "https://google.github.io/bumble/index.html"
repository = "https://github.com/google/bumble"
documentation = "https://docs.rs/crate/bumble"
authors = ["Marshall Pierce <marshallpierce@google.com>"]
keywords = ["bluetooth", "ble"]
categories = ["api-bindings", "network-programming"]
rust-version = "1.69.0"
[dependencies]
pyo3 = { version = "0.18.3", features = ["macros"] }
pyo3-asyncio = { version = "0.18.0", features = ["tokio-runtime"] }
tokio = { version = "1.28.2" }
nom = "7.1.3"
strum = "0.25.0"
strum_macros = "0.25.0"
hex = "0.4.3"
itertools = "0.11.0"
lazy_static = "1.4.0"
thiserror = "1.0.41"
anyhow = { version = "1.0.71", optional = true }
[dev-dependencies]
tokio = { version = "1.28.2", features = ["full"] }
tempfile = "3.6.0"
nix = "0.26.2"
anyhow = "1.0.71"
pyo3 = { version = "0.18.3", features = ["macros", "anyhow"] }
pyo3-asyncio = { version = "0.18.0", features = ["tokio-runtime", "attributes", "testing"] }
clap = { version = "4.3.3", features = ["derive"] }
owo-colors = "3.5.0"
log = "0.4.19"
env_logger = "0.10.0"
rusb = "0.9.2"
rand = "0.8.5"
[[bin]]
name = "gen-assigned-numbers"
path = "tools/gen_assigned_numbers.rs"
required-features = ["bumble-dev-tools"]
# test entry point that uses pyo3_asyncio's test harness
[[test]]
name = "pytests"
path = "pytests/pytests.rs"
harness = false
[features]
anyhow = ["pyo3/anyhow"]
pyo3-asyncio-attributes = ["pyo3-asyncio/attributes"]
bumble-dev-tools = ["dep:anyhow"]

56
rust/README.md Normal file
View File

@@ -0,0 +1,56 @@
# What is this?
Rust wrappers around the [Bumble](https://github.com/google/bumble) Python API.
Method calls are mapped to the equivalent Python, and return types adapted where
relevant.
See the `examples` directory for usage.
# Usage
Set up a virtualenv for Bumble, or otherwise have an isolated Python environment
for Bumble and its dependencies.
Due to Python being
[picky about how its sys path is set up](https://github.com/PyO3/pyo3/issues/1741,
it's necessary to explicitly point to the virtualenv's `site-packages`. Use
suitable virtualenv paths as appropriate for your OS, as seen here running
the `battery_client` example:
```
PYTHONPATH=..:~/.virtualenvs/bumble/lib/python3.10/site-packages/ \
cargo run --example battery_client -- \
--transport android-netsim --target-addr F0:F1:F2:F3:F4:F5
```
Run the corresponding `battery_server` Python example, and launch an emulator in
Android Studio (currently, Canary is required) to run netsim.
# Development
Run the tests:
```
PYTHONPATH=.. cargo test
```
Check lints:
```
cargo clippy --all-targets
```
## Code gen
To have the fastest startup while keeping the build simple, code gen for
assigned numbers is done with the `gen_assigned_numbers` tool. It should
be re-run whenever the Python assigned numbers are changed. To ensure that the
generated code is kept up to date, the Rust data is compared to the Python
in tests at `pytests/assigned_numbers.rs`.
To regenerate the assigned number tables based on the Python codebase:
```
PYTHONPATH=.. cargo run --bin gen-assigned-numbers --features bumble-dev-tools
```

View File

@@ -0,0 +1,112 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Counterpart to the Python example `battery_server.py`.
//!
//! Start an Android emulator from Android Studio, or otherwise have netsim running.
//!
//! Run the server from the project root:
//! ```
//! PYTHONPATH=. python examples/battery_server.py \
//! examples/device1.json android-netsim
//! ```
//!
//! Then run this example from the `rust` directory:
//!
//! ```
//! PYTHONPATH=..:/path/to/virtualenv/site-packages/ \
//! cargo run --example battery_client -- \
//! --transport android-netsim \
//! --target-addr F0:F1:F2:F3:F4:F5
//! ```
use bumble::wrapper::{
device::{Device, Peer},
profile::BatteryServiceProxy,
transport::Transport,
PyObjectExt,
};
use clap::Parser as _;
use log::info;
use owo_colors::OwoColorize;
use pyo3::prelude::*;
#[pyo3_asyncio::tokio::main]
async fn main() -> PyResult<()> {
env_logger::builder()
.filter_level(log::LevelFilter::Info)
.init();
let cli = Cli::parse();
let transport = Transport::open(cli.transport).await?;
let device = Device::with_hci(
"Bumble",
"F0:F1:F2:F3:F4:F5",
transport.source()?,
transport.sink()?,
)?;
device.power_on().await?;
let conn = device.connect(&cli.target_addr).await?;
let mut peer = Peer::new(conn)?;
for mut s in peer.discover_services().await? {
s.discover_characteristics().await?;
}
let battery_service = peer
.create_service_proxy::<BatteryServiceProxy>()?
.ok_or(anyhow::anyhow!("No battery service found"))?;
let mut battery_level_char = battery_service
.battery_level()?
.ok_or(anyhow::anyhow!("No battery level characteristic"))?;
info!(
"{} {}",
"Initial Battery Level:".green(),
battery_level_char
.read_value()
.await?
.extract_with_gil::<u32>()?
);
battery_level_char
.subscribe(|_py, args| {
info!(
"{} {:?}",
"Battery level update:".green(),
args.get_item(0)?.extract::<u32>()?,
);
Ok(())
})
.await?;
// wait until user kills the process
tokio::signal::ctrl_c().await?;
Ok(())
}
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
/// Bumble transport spec.
///
/// <https://google.github.io/bumble/transports/index.html>
#[arg(long)]
transport: String,
/// Address to connect to
#[arg(long)]
target_addr: String,
}

View File

@@ -0,0 +1,98 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::anyhow;
use bumble::{
adv::{AdvertisementDataBuilder, CommonDataType},
wrapper::{
device::Device,
logging::{bumble_env_logging_level, py_logging_basic_config},
transport::Transport,
},
};
use clap::Parser as _;
use pyo3::PyResult;
use rand::Rng;
use std::path;
#[pyo3_asyncio::tokio::main]
async fn main() -> PyResult<()> {
env_logger::builder()
.filter_level(log::LevelFilter::Info)
.init();
let cli = Cli::parse();
if cli.log_hci {
py_logging_basic_config(bumble_env_logging_level("DEBUG"))?;
}
let transport = Transport::open(cli.transport).await?;
let mut device = Device::from_config_file_with_hci(
&cli.device_config,
transport.source()?,
transport.sink()?,
)?;
let mut adv_data = AdvertisementDataBuilder::new();
adv_data
.append(
CommonDataType::CompleteLocalName,
"Bumble from Rust".as_bytes(),
)
.map_err(|e| anyhow!(e))?;
// Randomized TX power
adv_data
.append(
CommonDataType::TxPowerLevel,
&[rand::thread_rng().gen_range(-100_i8..=20) as u8],
)
.map_err(|e| anyhow!(e))?;
device.set_advertising_data(adv_data)?;
device.power_on().await?;
println!("Advertising...");
device.start_advertising(true).await?;
// wait until user kills the process
tokio::signal::ctrl_c().await?;
println!("Stopping...");
device.stop_advertising().await?;
Ok(())
}
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
/// Bumble device config.
///
/// See, for instance, `examples/device1.json` in the Python project.
#[arg(long)]
device_config: path::PathBuf,
/// Bumble transport spec.
///
/// <https://google.github.io/bumble/transports/index.html>
#[arg(long)]
transport: String,
/// Log HCI commands
#[arg(long)]
log_hci: bool,
}

185
rust/examples/scanner.rs Normal file
View File

@@ -0,0 +1,185 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Counterpart to the Python example `run_scanner.py`.
//!
//! Device deduplication is done here rather than relying on the controller's filtering to provide
//! for additional features, like the ability to make deduplication time-bounded.
use bumble::{
adv::CommonDataType,
wrapper::{
core::AdvertisementDataUnit, device::Device, hci::AddressType, transport::Transport,
},
};
use clap::Parser as _;
use itertools::Itertools;
use owo_colors::{OwoColorize, Style};
use pyo3::PyResult;
use std::{
collections,
sync::{Arc, Mutex},
time,
};
#[pyo3_asyncio::tokio::main]
async fn main() -> PyResult<()> {
env_logger::builder()
.filter_level(log::LevelFilter::Info)
.init();
let cli = Cli::parse();
let transport = Transport::open(cli.transport).await?;
let mut device = Device::with_hci(
"Bumble",
"F0:F1:F2:F3:F4:F5",
transport.source()?,
transport.sink()?,
)?;
// in practice, devices can send multiple advertisements from the same address, so we keep
// track of a timestamp for each set of data
let seen_advertisements = Arc::new(Mutex::new(collections::HashMap::<
Vec<u8>,
collections::HashMap<Vec<AdvertisementDataUnit>, time::Instant>,
>::new()));
let seen_adv_clone = seen_advertisements.clone();
device.on_advertisement(move |_py, adv| {
let rssi = adv.rssi()?;
let data_units = adv.data()?.data_units()?;
let addr = adv.address()?;
let show_adv = if cli.filter_duplicates {
let addr_bytes = addr.as_le_bytes()?;
let mut seen_adv_cache = seen_adv_clone.lock().unwrap();
let expiry_duration = time::Duration::from_secs(cli.dedup_expiry_secs);
let advs_from_addr = seen_adv_cache
.entry(addr_bytes)
.or_insert_with(collections::HashMap::new);
// we expect cache hits to be the norm, so we do a separate lookup to avoid cloning
// on every lookup with entry()
let show = if let Some(prev) = advs_from_addr.get_mut(&data_units) {
let expired = prev.elapsed() > expiry_duration;
*prev = time::Instant::now();
expired
} else {
advs_from_addr.insert(data_units.clone(), time::Instant::now());
true
};
// clean out anything we haven't seen in a while
advs_from_addr.retain(|_, instant| instant.elapsed() <= expiry_duration);
show
} else {
true
};
if !show_adv {
return Ok(());
}
let addr_style = if adv.is_connectable()? {
Style::new().yellow()
} else {
Style::new().red()
};
let (type_style, qualifier) = match adv.address()?.address_type()? {
AddressType::PublicIdentity | AddressType::PublicDevice => (Style::new().cyan(), ""),
_ => {
if addr.is_static()? {
(Style::new().green(), "(static)")
} else if addr.is_resolvable()? {
(Style::new().magenta(), "(resolvable)")
} else {
(Style::new().default_color(), "")
}
}
};
println!(
">>> {} [{:?}] {qualifier}:\n RSSI: {}",
addr.as_hex()?.style(addr_style),
addr.address_type()?.style(type_style),
rssi,
);
data_units.into_iter().for_each(|(code, data)| {
let matching = CommonDataType::for_type_code(code).collect::<Vec<_>>();
let code_str = if matching.is_empty() {
format!("0x{}", hex::encode_upper([code.into()]))
} else {
matching
.iter()
.map(|t| format!("{}", t))
.join(" / ")
.blue()
.to_string()
};
// use the first matching type's formatted data, if any
let data_str = matching
.iter()
.filter_map(|t| {
t.format_data(&data).map(|formatted| {
format!(
"{} {}",
formatted,
format!("(raw: 0x{})", hex::encode_upper(&data)).dimmed()
)
})
})
.next()
.unwrap_or_else(|| format!("0x{}", hex::encode_upper(&data)));
println!(" [{}]: {}", code_str, data_str)
});
Ok(())
})?;
device.power_on().await?;
// do our own dedup
device.start_scanning(false).await?;
// wait until user kills the process
tokio::signal::ctrl_c().await?;
Ok(())
}
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
/// Bumble transport spec.
///
/// <https://google.github.io/bumble/transports/index.html>
#[arg(long)]
transport: String,
/// Filter duplicate advertisements
#[arg(long, default_value_t = false)]
filter_duplicates: bool,
/// How long before a deduplicated advertisement that hasn't been seen in a while is considered
/// fresh again, in seconds
#[arg(long, default_value_t = 10, requires = "filter_duplicates")]
dedup_expiry_secs: u64,
}

342
rust/examples/usb_probe.rs Normal file
View File

@@ -0,0 +1,342 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Rust version of the Python `usb_probe.py`.
//!
//! This tool lists all the USB devices, with details about each device.
//! For each device, the different possible Bumble transport strings that can
//! refer to it are listed. If the device is known to be a Bluetooth HCI device,
//! its identifier is printed in reverse colors, and the transport names in cyan color.
//! For other devices, regardless of their type, the transport names are printed
//! in red. Whether that device is actually a Bluetooth device or not depends on
//! whether it is a Bluetooth device that uses a non-standard Class, or some other
//! type of device (there's no way to tell).
use clap::Parser as _;
use itertools::Itertools as _;
use owo_colors::{OwoColorize, Style};
use rusb::{Device, DeviceDescriptor, Direction, TransferType, UsbContext};
use std::{
collections::{HashMap, HashSet},
time::Duration,
};
const USB_DEVICE_CLASS_DEVICE: u8 = 0x00;
const USB_DEVICE_CLASS_WIRELESS_CONTROLLER: u8 = 0xE0;
const USB_DEVICE_SUBCLASS_RF_CONTROLLER: u8 = 0x01;
const USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER: u8 = 0x01;
fn main() -> anyhow::Result<()> {
let cli = Cli::parse();
let mut bt_dev_count = 0;
let mut device_serials_by_id: HashMap<(u16, u16), HashSet<String>> = HashMap::new();
for device in rusb::devices()?.iter() {
let device_desc = device.device_descriptor().unwrap();
let class_info = ClassInfo::from(&device_desc);
let handle = device.open()?;
let timeout = Duration::from_secs(1);
// some devices don't have languages
let lang = handle
.read_languages(timeout)
.ok()
.and_then(|langs| langs.into_iter().next());
let serial = lang.and_then(|l| {
handle
.read_serial_number_string(l, &device_desc, timeout)
.ok()
});
let mfg = lang.and_then(|l| {
handle
.read_manufacturer_string(l, &device_desc, timeout)
.ok()
});
let product = lang.and_then(|l| handle.read_product_string(l, &device_desc, timeout).ok());
let is_hci = is_bluetooth_hci(&device, &device_desc)?;
let addr_style = if is_hci {
bt_dev_count += 1;
Style::new().black().on_yellow()
} else {
Style::new().yellow().on_black()
};
let mut transport_names = Vec::new();
let basic_transport_name = format!(
"usb:{:04X}:{:04X}",
device_desc.vendor_id(),
device_desc.product_id()
);
if is_hci {
transport_names.push(format!("usb:{}", bt_dev_count - 1));
}
let device_id = (device_desc.vendor_id(), device_desc.product_id());
if !device_serials_by_id.contains_key(&device_id) {
transport_names.push(basic_transport_name.clone());
} else {
transport_names.push(format!(
"{}#{}",
basic_transport_name,
device_serials_by_id
.get(&device_id)
.map(|serials| serials.len())
.unwrap_or(0)
))
}
if let Some(s) = &serial {
if !device_serials_by_id
.get(&device_id)
.map(|serials| serials.contains(s))
.unwrap_or(false)
{
transport_names.push(format!("{}/{}", basic_transport_name, s))
}
}
println!(
"{}",
format!(
"ID {:04X}:{:04X}",
device_desc.vendor_id(),
device_desc.product_id()
)
.style(addr_style)
);
if !transport_names.is_empty() {
let style = if is_hci {
Style::new().cyan()
} else {
Style::new().red()
};
println!(
"{:26}{}",
" Bumble Transport Names:".blue(),
transport_names.iter().map(|n| n.style(style)).join(" or ")
)
}
println!(
"{:26}{:03}/{:03}",
" Bus/Device:".green(),
device.bus_number(),
device.address()
);
println!(
"{:26}{}",
" Class:".green(),
class_info.formatted_class_name()
);
println!(
"{:26}{}",
" Subclass/Protocol:".green(),
class_info.formatted_subclass_protocol()
);
if let Some(s) = serial {
println!("{:26}{}", " Serial:".green(), s);
device_serials_by_id
.entry(device_id)
.or_insert(HashSet::new())
.insert(s);
}
if let Some(m) = mfg {
println!("{:26}{}", " Manufacturer:".green(), m);
}
if let Some(p) = product {
println!("{:26}{}", " Product:".green(), p);
}
if cli.verbose {
print_device_details(&device, &device_desc)?;
}
println!();
}
Ok(())
}
fn is_bluetooth_hci<T: UsbContext>(
device: &Device<T>,
device_desc: &DeviceDescriptor,
) -> rusb::Result<bool> {
if device_desc.class_code() == USB_DEVICE_CLASS_WIRELESS_CONTROLLER
&& device_desc.sub_class_code() == USB_DEVICE_SUBCLASS_RF_CONTROLLER
&& device_desc.protocol_code() == USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
{
Ok(true)
} else if device_desc.class_code() == USB_DEVICE_CLASS_DEVICE {
for i in 0..device_desc.num_configurations() {
for interface in device.config_descriptor(i)?.interfaces() {
for d in interface.descriptors() {
if d.class_code() == USB_DEVICE_CLASS_WIRELESS_CONTROLLER
&& d.sub_class_code() == USB_DEVICE_SUBCLASS_RF_CONTROLLER
&& d.protocol_code() == USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER
{
return Ok(true);
}
}
}
}
Ok(false)
} else {
Ok(false)
}
}
fn print_device_details<T: UsbContext>(
device: &Device<T>,
device_desc: &DeviceDescriptor,
) -> anyhow::Result<()> {
for i in 0..device_desc.num_configurations() {
println!(" Configuration {}", i + 1);
for interface in device.config_descriptor(i)?.interfaces() {
let interface_descriptors: Vec<_> = interface.descriptors().collect();
for d in &interface_descriptors {
let class_info =
ClassInfo::new(d.class_code(), d.sub_class_code(), d.protocol_code());
println!(
" Interface: {}{} ({}, {})",
interface.number(),
if interface_descriptors.len() > 1 {
format!("/{}", d.setting_number())
} else {
String::new()
},
class_info.formatted_class_name(),
class_info.formatted_subclass_protocol()
);
for e in d.endpoint_descriptors() {
println!(
" Endpoint {:#04X}: {} {}",
e.address(),
match e.transfer_type() {
TransferType::Control => "CONTROL",
TransferType::Isochronous => "ISOCHRONOUS",
TransferType::Bulk => "BULK",
TransferType::Interrupt => "INTERRUPT",
},
match e.direction() {
Direction::In => "IN",
Direction::Out => "OUT",
}
)
}
}
}
}
Ok(())
}
struct ClassInfo {
class: u8,
sub_class: u8,
protocol: u8,
}
impl ClassInfo {
fn new(class: u8, sub_class: u8, protocol: u8) -> Self {
Self {
class,
sub_class,
protocol,
}
}
fn class_name(&self) -> Option<&str> {
match self.class {
0x00 => Some("Device"),
0x01 => Some("Audio"),
0x02 => Some("Communications and CDC Control"),
0x03 => Some("Human Interface Device"),
0x05 => Some("Physical"),
0x06 => Some("Still Imaging"),
0x07 => Some("Printer"),
0x08 => Some("Mass Storage"),
0x09 => Some("Hub"),
0x0A => Some("CDC Data"),
0x0B => Some("Smart Card"),
0x0D => Some("Content Security"),
0x0E => Some("Video"),
0x0F => Some("Personal Healthcare"),
0x10 => Some("Audio/Video"),
0x11 => Some("Billboard"),
0x12 => Some("USB Type-C Bridge"),
0x3C => Some("I3C"),
0xDC => Some("Diagnostic"),
USB_DEVICE_CLASS_WIRELESS_CONTROLLER => Some("Wireless Controller"),
0xEF => Some("Miscellaneous"),
0xFE => Some("Application Specific"),
0xFF => Some("Vendor Specific"),
_ => None,
}
}
fn protocol_name(&self) -> Option<&str> {
match self.class {
USB_DEVICE_CLASS_WIRELESS_CONTROLLER => match self.sub_class {
0x01 => match self.protocol {
0x01 => Some("Bluetooth"),
0x02 => Some("UWB"),
0x03 => Some("Remote NDIS"),
0x04 => Some("Bluetooth AMP"),
_ => None,
},
_ => None,
},
_ => None,
}
}
fn formatted_class_name(&self) -> String {
self.class_name()
.map(|s| s.to_string())
.unwrap_or_else(|| format!("{:#04X}", self.class))
}
fn formatted_subclass_protocol(&self) -> String {
format!(
"{}/{}{}",
self.sub_class,
self.protocol,
self.protocol_name()
.map(|s| format!(" [{}]", s))
.unwrap_or_else(String::new)
)
}
}
impl From<&DeviceDescriptor> for ClassInfo {
fn from(value: &DeviceDescriptor) -> Self {
Self::new(
value.class_code(),
value.sub_class_code(),
value.protocol_code(),
)
}
}
#[derive(clap::Parser)]
#[command(author, version, about, long_about = None)]
struct Cli {
/// Show additional info for each USB device
#[arg(long, default_value_t = false)]
verbose: bool,
}

View File

@@ -0,0 +1,30 @@
use bumble::wrapper::{self, core::Uuid16};
use pyo3::{intern, prelude::*, types::PyDict};
use std::collections;
#[pyo3_asyncio::tokio::test]
async fn company_ids_matches_python() -> PyResult<()> {
let ids_from_python = Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.company_ids"))?
.getattr(intern!(py, "COMPANY_IDENTIFIERS"))?
.downcast::<PyDict>()?
.into_iter()
.map(|(k, v)| {
Ok((
Uuid16::from_be_bytes(k.extract::<u16>()?.to_be_bytes()),
v.str()?.to_str()?.to_string(),
))
})
.collect::<PyResult<collections::HashMap<_, _>>>()
})?;
assert_eq!(
wrapper::assigned_numbers::COMPANY_IDS
.iter()
.map(|(id, name)| (*id, name.to_string()))
.collect::<collections::HashMap<_, _>>(),
ids_from_python,
"Company ids do not match -- re-run gen_assigned_numbers?"
);
Ok(())
}

21
rust/pytests/pytests.rs Normal file
View File

@@ -0,0 +1,21 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#[pyo3_asyncio::tokio::main]
async fn main() -> pyo3::PyResult<()> {
pyo3_asyncio::testing::main().await
}
mod assigned_numbers;
mod wrapper;

31
rust/pytests/wrapper.rs Normal file
View File

@@ -0,0 +1,31 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use bumble::wrapper::transport::Transport;
use nix::sys::stat::Mode;
use pyo3::PyResult;
#[pyo3_asyncio::tokio::test]
async fn fifo_transport_can_open() -> PyResult<()> {
let dir = tempfile::tempdir().unwrap();
let mut fifo = dir.path().to_path_buf();
fifo.push("bumble-transport-fifo");
nix::unistd::mkfifo(&fifo, Mode::S_IRWXU).unwrap();
let mut t = Transport::open(format!("file:{}", fifo.to_str().unwrap())).await?;
t.close().await?;
Ok(())
}

446
rust/src/adv.rs Normal file
View File

@@ -0,0 +1,446 @@
//! BLE advertisements.
use crate::wrapper::assigned_numbers::{COMPANY_IDS, SERVICE_IDS};
use crate::wrapper::core::{Uuid128, Uuid16, Uuid32};
use itertools::Itertools;
use nom::{combinator, multi, number};
use std::fmt;
use strum::IntoEnumIterator;
/// The numeric code for a common data type.
///
/// For known types, see [CommonDataType], or use this type directly for non-assigned codes.
#[derive(PartialEq, Eq, Debug, Clone, Copy, Hash)]
pub struct CommonDataTypeCode(u8);
impl From<CommonDataType> for CommonDataTypeCode {
fn from(value: CommonDataType) -> Self {
let byte = match value {
CommonDataType::Flags => 0x01,
CommonDataType::IncompleteListOf16BitServiceClassUuids => 0x02,
CommonDataType::CompleteListOf16BitServiceClassUuids => 0x03,
CommonDataType::IncompleteListOf32BitServiceClassUuids => 0x04,
CommonDataType::CompleteListOf32BitServiceClassUuids => 0x05,
CommonDataType::IncompleteListOf128BitServiceClassUuids => 0x06,
CommonDataType::CompleteListOf128BitServiceClassUuids => 0x07,
CommonDataType::ShortenedLocalName => 0x08,
CommonDataType::CompleteLocalName => 0x09,
CommonDataType::TxPowerLevel => 0x0A,
CommonDataType::ClassOfDevice => 0x0D,
CommonDataType::SimplePairingHashC192 => 0x0E,
CommonDataType::SimplePairingRandomizerR192 => 0x0F,
// These two both really have type code 0x10! D:
CommonDataType::DeviceId => 0x10,
CommonDataType::SecurityManagerTkValue => 0x10,
CommonDataType::SecurityManagerOutOfBandFlags => 0x11,
CommonDataType::PeripheralConnectionIntervalRange => 0x12,
CommonDataType::ListOf16BitServiceSolicitationUuids => 0x14,
CommonDataType::ListOf128BitServiceSolicitationUuids => 0x15,
CommonDataType::ServiceData16BitUuid => 0x16,
CommonDataType::PublicTargetAddress => 0x17,
CommonDataType::RandomTargetAddress => 0x18,
CommonDataType::Appearance => 0x19,
CommonDataType::AdvertisingInterval => 0x1A,
CommonDataType::LeBluetoothDeviceAddress => 0x1B,
CommonDataType::LeRole => 0x1C,
CommonDataType::SimplePairingHashC256 => 0x1D,
CommonDataType::SimplePairingRandomizerR256 => 0x1E,
CommonDataType::ListOf32BitServiceSolicitationUuids => 0x1F,
CommonDataType::ServiceData32BitUuid => 0x20,
CommonDataType::ServiceData128BitUuid => 0x21,
CommonDataType::LeSecureConnectionsConfirmationValue => 0x22,
CommonDataType::LeSecureConnectionsRandomValue => 0x23,
CommonDataType::Uri => 0x24,
CommonDataType::IndoorPositioning => 0x25,
CommonDataType::TransportDiscoveryData => 0x26,
CommonDataType::LeSupportedFeatures => 0x27,
CommonDataType::ChannelMapUpdateIndication => 0x28,
CommonDataType::PbAdv => 0x29,
CommonDataType::MeshMessage => 0x2A,
CommonDataType::MeshBeacon => 0x2B,
CommonDataType::BigInfo => 0x2C,
CommonDataType::BroadcastCode => 0x2D,
CommonDataType::ResolvableSetIdentifier => 0x2E,
CommonDataType::AdvertisingIntervalLong => 0x2F,
CommonDataType::ThreeDInformationData => 0x3D,
CommonDataType::ManufacturerSpecificData => 0xFF,
};
Self(byte)
}
}
impl From<u8> for CommonDataTypeCode {
fn from(value: u8) -> Self {
Self(value)
}
}
impl From<CommonDataTypeCode> for u8 {
fn from(value: CommonDataTypeCode) -> Self {
value.0
}
}
/// Data types for assigned type codes.
///
/// See Bluetooth Assigned Numbers § 2.3
#[derive(Debug, Clone, Copy, PartialEq, Eq, strum_macros::EnumIter)]
#[allow(missing_docs)]
pub enum CommonDataType {
Flags,
IncompleteListOf16BitServiceClassUuids,
CompleteListOf16BitServiceClassUuids,
IncompleteListOf32BitServiceClassUuids,
CompleteListOf32BitServiceClassUuids,
IncompleteListOf128BitServiceClassUuids,
CompleteListOf128BitServiceClassUuids,
ShortenedLocalName,
CompleteLocalName,
TxPowerLevel,
ClassOfDevice,
SimplePairingHashC192,
SimplePairingRandomizerR192,
DeviceId,
SecurityManagerTkValue,
SecurityManagerOutOfBandFlags,
PeripheralConnectionIntervalRange,
ListOf16BitServiceSolicitationUuids,
ListOf128BitServiceSolicitationUuids,
ServiceData16BitUuid,
PublicTargetAddress,
RandomTargetAddress,
Appearance,
AdvertisingInterval,
LeBluetoothDeviceAddress,
LeRole,
SimplePairingHashC256,
SimplePairingRandomizerR256,
ListOf32BitServiceSolicitationUuids,
ServiceData32BitUuid,
ServiceData128BitUuid,
LeSecureConnectionsConfirmationValue,
LeSecureConnectionsRandomValue,
Uri,
IndoorPositioning,
TransportDiscoveryData,
LeSupportedFeatures,
ChannelMapUpdateIndication,
PbAdv,
MeshMessage,
MeshBeacon,
BigInfo,
BroadcastCode,
ResolvableSetIdentifier,
AdvertisingIntervalLong,
ThreeDInformationData,
ManufacturerSpecificData,
}
impl CommonDataType {
/// Iterate over the zero, one, or more matching types for the provided code.
///
/// `0x10` maps to both Device Id and Security Manager TK Value, so multiple matching types
/// may exist for a single code.
pub fn for_type_code(code: CommonDataTypeCode) -> impl Iterator<Item = CommonDataType> {
Self::iter().filter(move |t| CommonDataTypeCode::from(*t) == code)
}
/// Apply type-specific human-oriented formatting to data, if any is applicable
pub fn format_data(&self, data: &[u8]) -> Option<String> {
match self {
Self::Flags => Some(Flags::matching(data).map(|f| format!("{:?}", f)).join(",")),
Self::CompleteListOf16BitServiceClassUuids
| Self::IncompleteListOf16BitServiceClassUuids
| Self::ListOf16BitServiceSolicitationUuids => {
combinator::complete(multi::many0(Uuid16::parse_le))(data)
.map(|(_res, uuids)| {
uuids
.into_iter()
.map(|uuid| {
SERVICE_IDS
.get(&uuid)
.map(|name| format!("{:?} ({name})", uuid))
.unwrap_or_else(|| format!("{:?}", uuid))
})
.join(", ")
})
.ok()
}
Self::CompleteListOf32BitServiceClassUuids
| Self::IncompleteListOf32BitServiceClassUuids
| Self::ListOf32BitServiceSolicitationUuids => {
combinator::complete(multi::many0(Uuid32::parse))(data)
.map(|(_res, uuids)| uuids.into_iter().map(|u| format!("{:?}", u)).join(", "))
.ok()
}
Self::CompleteListOf128BitServiceClassUuids
| Self::IncompleteListOf128BitServiceClassUuids
| Self::ListOf128BitServiceSolicitationUuids => {
combinator::complete(multi::many0(Uuid128::parse_le))(data)
.map(|(_res, uuids)| uuids.into_iter().map(|u| format!("{:?}", u)).join(", "))
.ok()
}
Self::ServiceData16BitUuid => Uuid16::parse_le(data)
.map(|(rem, uuid)| {
format!(
"service={:?}, data={}",
SERVICE_IDS
.get(&uuid)
.map(|name| format!("{:?} ({name})", uuid))
.unwrap_or_else(|| format!("{:?}", uuid)),
hex::encode_upper(rem)
)
})
.ok(),
Self::ServiceData32BitUuid => Uuid32::parse(data)
.map(|(rem, uuid)| format!("service={:?}, data={}", uuid, hex::encode_upper(rem)))
.ok(),
Self::ServiceData128BitUuid => Uuid128::parse_le(data)
.map(|(rem, uuid)| format!("service={:?}, data={}", uuid, hex::encode_upper(rem)))
.ok(),
Self::ShortenedLocalName | Self::CompleteLocalName => {
std::str::from_utf8(data).ok().map(|s| format!("\"{}\"", s))
}
Self::TxPowerLevel => {
let (_, tx) =
combinator::complete(number::complete::i8::<_, nom::error::Error<_>>)(data)
.ok()?;
Some(tx.to_string())
}
Self::ManufacturerSpecificData => {
let (rem, id) = Uuid16::parse_le(data).ok()?;
Some(format!(
"company={}, data=0x{}",
COMPANY_IDS
.get(&id)
.map(|s| s.to_string())
.unwrap_or_else(|| format!("{:?}", id)),
hex::encode_upper(rem)
))
}
_ => None,
}
}
}
impl fmt::Display for CommonDataType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CommonDataType::Flags => write!(f, "Flags"),
CommonDataType::IncompleteListOf16BitServiceClassUuids => {
write!(f, "Incomplete List of 16-bit Service Class UUIDs")
}
CommonDataType::CompleteListOf16BitServiceClassUuids => {
write!(f, "Complete List of 16-bit Service Class UUIDs")
}
CommonDataType::IncompleteListOf32BitServiceClassUuids => {
write!(f, "Incomplete List of 32-bit Service Class UUIDs")
}
CommonDataType::CompleteListOf32BitServiceClassUuids => {
write!(f, "Complete List of 32-bit Service Class UUIDs")
}
CommonDataType::ListOf16BitServiceSolicitationUuids => {
write!(f, "List of 16-bit Service Solicitation UUIDs")
}
CommonDataType::ListOf32BitServiceSolicitationUuids => {
write!(f, "List of 32-bit Service Solicitation UUIDs")
}
CommonDataType::ListOf128BitServiceSolicitationUuids => {
write!(f, "List of 128-bit Service Solicitation UUIDs")
}
CommonDataType::IncompleteListOf128BitServiceClassUuids => {
write!(f, "Incomplete List of 128-bit Service Class UUIDs")
}
CommonDataType::CompleteListOf128BitServiceClassUuids => {
write!(f, "Complete List of 128-bit Service Class UUIDs")
}
CommonDataType::ShortenedLocalName => write!(f, "Shortened Local Name"),
CommonDataType::CompleteLocalName => write!(f, "Complete Local Name"),
CommonDataType::TxPowerLevel => write!(f, "TX Power Level"),
CommonDataType::ClassOfDevice => write!(f, "Class of Device"),
CommonDataType::SimplePairingHashC192 => {
write!(f, "Simple Pairing Hash C-192")
}
CommonDataType::SimplePairingHashC256 => {
write!(f, "Simple Pairing Hash C 256")
}
CommonDataType::SimplePairingRandomizerR192 => {
write!(f, "Simple Pairing Randomizer R-192")
}
CommonDataType::SimplePairingRandomizerR256 => {
write!(f, "Simple Pairing Randomizer R 256")
}
CommonDataType::DeviceId => write!(f, "Device Id"),
CommonDataType::SecurityManagerTkValue => {
write!(f, "Security Manager TK Value")
}
CommonDataType::SecurityManagerOutOfBandFlags => {
write!(f, "Security Manager Out of Band Flags")
}
CommonDataType::PeripheralConnectionIntervalRange => {
write!(f, "Peripheral Connection Interval Range")
}
CommonDataType::ServiceData16BitUuid => {
write!(f, "Service Data 16-bit UUID")
}
CommonDataType::ServiceData32BitUuid => {
write!(f, "Service Data 32-bit UUID")
}
CommonDataType::ServiceData128BitUuid => {
write!(f, "Service Data 128-bit UUID")
}
CommonDataType::PublicTargetAddress => write!(f, "Public Target Address"),
CommonDataType::RandomTargetAddress => write!(f, "Random Target Address"),
CommonDataType::Appearance => write!(f, "Appearance"),
CommonDataType::AdvertisingInterval => write!(f, "Advertising Interval"),
CommonDataType::LeBluetoothDeviceAddress => {
write!(f, "LE Bluetooth Device Address")
}
CommonDataType::LeRole => write!(f, "LE Role"),
CommonDataType::LeSecureConnectionsConfirmationValue => {
write!(f, "LE Secure Connections Confirmation Value")
}
CommonDataType::LeSecureConnectionsRandomValue => {
write!(f, "LE Secure Connections Random Value")
}
CommonDataType::LeSupportedFeatures => write!(f, "LE Supported Features"),
CommonDataType::Uri => write!(f, "URI"),
CommonDataType::IndoorPositioning => write!(f, "Indoor Positioning"),
CommonDataType::TransportDiscoveryData => {
write!(f, "Transport Discovery Data")
}
CommonDataType::ChannelMapUpdateIndication => {
write!(f, "Channel Map Update Indication")
}
CommonDataType::PbAdv => write!(f, "PB-ADV"),
CommonDataType::MeshMessage => write!(f, "Mesh Message"),
CommonDataType::MeshBeacon => write!(f, "Mesh Beacon"),
CommonDataType::BigInfo => write!(f, "BIGIInfo"),
CommonDataType::BroadcastCode => write!(f, "Broadcast Code"),
CommonDataType::ResolvableSetIdentifier => {
write!(f, "Resolvable Set Identifier")
}
CommonDataType::AdvertisingIntervalLong => {
write!(f, "Advertising Interval Long")
}
CommonDataType::ThreeDInformationData => write!(f, "3D Information Data"),
CommonDataType::ManufacturerSpecificData => {
write!(f, "Manufacturer Specific Data")
}
}
}
}
/// Accumulates advertisement data to broadcast on a [crate::wrapper::device::Device].
#[derive(Debug, Clone, Default)]
pub struct AdvertisementDataBuilder {
encoded_data: Vec<u8>,
}
impl AdvertisementDataBuilder {
/// Returns a new, empty instance.
pub fn new() -> Self {
Self {
encoded_data: Vec::new(),
}
}
/// Append advertising data to the builder.
///
/// Returns an error if the data cannot be appended.
pub fn append(
&mut self,
type_code: impl Into<CommonDataTypeCode>,
data: &[u8],
) -> Result<(), AdvertisementDataBuilderError> {
self.encoded_data.push(
data.len()
.try_into()
.ok()
.and_then(|len: u8| len.checked_add(1))
.ok_or(AdvertisementDataBuilderError::DataTooLong)?,
);
self.encoded_data.push(type_code.into().0);
self.encoded_data.extend_from_slice(data);
Ok(())
}
pub(crate) fn into_bytes(self) -> Vec<u8> {
self.encoded_data
}
}
/// Errors that can occur when building advertisement data with [AdvertisementDataBuilder].
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
pub enum AdvertisementDataBuilderError {
/// The provided adv data is too long to be encoded
#[error("Data too long")]
DataTooLong,
}
#[derive(PartialEq, Eq, strum_macros::EnumIter)]
#[allow(missing_docs)]
/// Features in the Flags AD
pub enum Flags {
LeLimited,
LeDiscoverable,
NoBrEdr,
BrEdrController,
BrEdrHost,
}
impl fmt::Debug for Flags {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.short_name())
}
}
impl Flags {
/// Iterates over the flags that are present in the provided `flags` bytes.
pub fn matching(flags: &[u8]) -> impl Iterator<Item = Self> + '_ {
// The encoding is not clear from the spec: do we look at the first byte? or the last?
// In practice it's only one byte.
let first_byte = flags.first().unwrap_or(&0_u8);
Self::iter().filter(move |f| {
let mask = match f {
Flags::LeLimited => 0x01_u8,
Flags::LeDiscoverable => 0x02,
Flags::NoBrEdr => 0x04,
Flags::BrEdrController => 0x08,
Flags::BrEdrHost => 0x10,
};
mask & first_byte > 0
})
}
/// An abbreviated form of the flag name.
///
/// See [Flags::name] for the full name.
pub fn short_name(&self) -> &'static str {
match self {
Flags::LeLimited => "LE Limited",
Flags::LeDiscoverable => "LE General",
Flags::NoBrEdr => "No BR/EDR",
Flags::BrEdrController => "BR/EDR C",
Flags::BrEdrHost => "BR/EDR H",
}
}
/// The human-readable name of the flag.
///
/// See [Flags::short_name] for a shorter string for use if compactness is important.
pub fn name(&self) -> &'static str {
match self {
Flags::LeLimited => "LE Limited Discoverable Mode",
Flags::LeDiscoverable => "LE General Discoverable Mode",
Flags::NoBrEdr => "BR/EDR Not Supported",
Flags::BrEdrController => "Simultaneous LE and BR/EDR (Controller)",
Flags::BrEdrHost => "Simultaneous LE and BR/EDR (Host)",
}
}
}

31
rust/src/lib.rs Normal file
View File

@@ -0,0 +1,31 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Rust API for [Bumble](https://github.com/google/bumble).
//!
//! Bumble is a userspace Bluetooth stack that works with more or less anything that uses HCI. This
//! could be physical Bluetooth USB dongles, netsim, HCI proxied over a network from some device
//! elsewhere, etc.
//!
//! It also does not restrict what you can do with Bluetooth the way that OS Bluetooth APIs
//! typically do, making it good for prototyping, experimentation, test tools, etc.
//!
//! Bumble is primarily written in Python. Rust types that wrap the Python API, which is currently
//! the bulk of the code, are in the [wrapper] module.
#![deny(missing_docs, unsafe_code)]
pub mod wrapper;
pub mod adv;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Assigned numbers from the Bluetooth spec.
mod company_ids;
mod services;
pub use company_ids::COMPANY_IDS;
pub use services::SERVICE_IDS;

View File

@@ -0,0 +1,82 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Assigned service IDs
use crate::wrapper::core::Uuid16;
use lazy_static::lazy_static;
use std::collections;
lazy_static! {
/// Assigned service IDs
pub static ref SERVICE_IDS: collections::HashMap<Uuid16, &'static str> = [
(0x1800_u16, "Generic Access"),
(0x1801, "Generic Attribute"),
(0x1802, "Immediate Alert"),
(0x1803, "Link Loss"),
(0x1804, "TX Power"),
(0x1805, "Current Time"),
(0x1806, "Reference Time Update"),
(0x1807, "Next DST Change"),
(0x1808, "Glucose"),
(0x1809, "Health Thermometer"),
(0x180A, "Device Information"),
(0x180D, "Heart Rate"),
(0x180E, "Phone Alert Status"),
(0x180F, "Battery"),
(0x1810, "Blood Pressure"),
(0x1811, "Alert Notification"),
(0x1812, "Human Interface Device"),
(0x1813, "Scan Parameters"),
(0x1814, "Running Speed and Cadence"),
(0x1815, "Automation IO"),
(0x1816, "Cycling Speed and Cadence"),
(0x1818, "Cycling Power"),
(0x1819, "Location and Navigation"),
(0x181A, "Environmental Sensing"),
(0x181B, "Body Composition"),
(0x181C, "User Data"),
(0x181D, "Weight Scale"),
(0x181E, "Bond Management"),
(0x181F, "Continuous Glucose Monitoring"),
(0x1820, "Internet Protocol Support"),
(0x1821, "Indoor Positioning"),
(0x1822, "Pulse Oximeter"),
(0x1823, "HTTP Proxy"),
(0x1824, "Transport Discovery"),
(0x1825, "Object Transfer"),
(0x1826, "Fitness Machine"),
(0x1827, "Mesh Provisioning"),
(0x1828, "Mesh Proxy"),
(0x1829, "Reconnection Configuration"),
(0x183A, "Insulin Delivery"),
(0x183B, "Binary Sensor"),
(0x183C, "Emergency Configuration"),
(0x183E, "Physical Activity Monitor"),
(0x1843, "Audio Input Control"),
(0x1844, "Volume Control"),
(0x1845, "Volume Offset Control"),
(0x1846, "Coordinated Set Identification Service"),
(0x1847, "Device Time"),
(0x1848, "Media Control Service"),
(0x1849, "Generic Media Control Service"),
(0x184A, "Constant Tone Extension"),
(0x184B, "Telephone Bearer Service"),
(0x184C, "Generic Telephone Bearer Service"),
(0x184D, "Microphone Control"),
]
.into_iter()
.map(|(num, name)| (Uuid16::from_le_bytes(num.to_le_bytes()), name))
.collect();
}

196
rust/src/wrapper/core.rs Normal file
View File

@@ -0,0 +1,196 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Core types
use crate::adv::CommonDataTypeCode;
use lazy_static::lazy_static;
use nom::{bytes, combinator};
use pyo3::{intern, PyObject, PyResult, Python};
use std::fmt;
lazy_static! {
static ref BASE_UUID: [u8; 16] = hex::decode("0000000000001000800000805F9B34FB")
.unwrap()
.try_into()
.unwrap();
}
/// A type code and data pair from an advertisement
pub type AdvertisementDataUnit = (CommonDataTypeCode, Vec<u8>);
/// Contents of an advertisement
pub struct AdvertisingData(pub(crate) PyObject);
impl AdvertisingData {
/// Data units in the advertisement contents
pub fn data_units(&self) -> PyResult<Vec<AdvertisementDataUnit>> {
Python::with_gil(|py| {
let list = self.0.getattr(py, intern!(py, "ad_structures"))?;
list.as_ref(py)
.iter()?
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.map(|tuple| {
let type_code = tuple
.call_method1(intern!(py, "__getitem__"), (0,))?
.extract::<u8>()?
.into();
let data = tuple
.call_method1(intern!(py, "__getitem__"), (1,))?
.extract::<Vec<u8>>()?;
Ok((type_code, data))
})
.collect::<Result<Vec<_>, _>>()
})
}
}
/// 16-bit UUID
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub struct Uuid16 {
/// Big-endian bytes
uuid: [u8; 2],
}
impl Uuid16 {
/// Construct a UUID from little-endian bytes
pub fn from_le_bytes(mut bytes: [u8; 2]) -> Self {
bytes.reverse();
Self::from_be_bytes(bytes)
}
/// Construct a UUID from big-endian bytes
pub fn from_be_bytes(bytes: [u8; 2]) -> Self {
Self { uuid: bytes }
}
/// The UUID in big-endian bytes form
pub fn as_be_bytes(&self) -> [u8; 2] {
self.uuid
}
/// The UUID in little-endian bytes form
pub fn as_le_bytes(&self) -> [u8; 2] {
let mut uuid = self.uuid;
uuid.reverse();
uuid
}
pub(crate) fn parse_le(input: &[u8]) -> nom::IResult<&[u8], Self> {
combinator::map_res(bytes::complete::take(2_usize), |b: &[u8]| {
b.try_into().map(|mut uuid: [u8; 2]| {
uuid.reverse();
Self { uuid }
})
})(input)
}
}
impl fmt::Debug for Uuid16 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "UUID-16:{}", hex::encode_upper(self.uuid))
}
}
/// 32-bit UUID
#[derive(PartialEq, Eq, Hash)]
pub struct Uuid32 {
/// Big-endian bytes
uuid: [u8; 4],
}
impl Uuid32 {
/// The UUID in big-endian bytes form
pub fn as_bytes(&self) -> [u8; 4] {
self.uuid
}
pub(crate) fn parse(input: &[u8]) -> nom::IResult<&[u8], Self> {
combinator::map_res(bytes::complete::take(4_usize), |b: &[u8]| {
b.try_into().map(|mut uuid: [u8; 4]| {
uuid.reverse();
Self { uuid }
})
})(input)
}
}
impl fmt::Debug for Uuid32 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "UUID-32:{}", hex::encode_upper(self.uuid))
}
}
impl From<Uuid16> for Uuid32 {
fn from(value: Uuid16) -> Self {
let mut uuid = [0; 4];
uuid[2..].copy_from_slice(&value.uuid);
Self { uuid }
}
}
/// 128-bit UUID
#[derive(PartialEq, Eq, Hash)]
pub struct Uuid128 {
/// Big-endian bytes
uuid: [u8; 16],
}
impl Uuid128 {
/// The UUID in big-endian bytes form
pub fn as_bytes(&self) -> [u8; 16] {
self.uuid
}
pub(crate) fn parse_le(input: &[u8]) -> nom::IResult<&[u8], Self> {
combinator::map_res(bytes::complete::take(16_usize), |b: &[u8]| {
b.try_into().map(|mut uuid: [u8; 16]| {
uuid.reverse();
Self { uuid }
})
})(input)
}
}
impl fmt::Debug for Uuid128 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}-{}-{}-{}-{}",
hex::encode_upper(&self.uuid[..4]),
hex::encode_upper(&self.uuid[4..6]),
hex::encode_upper(&self.uuid[6..8]),
hex::encode_upper(&self.uuid[8..10]),
hex::encode_upper(&self.uuid[10..])
)
}
}
impl From<Uuid16> for Uuid128 {
fn from(value: Uuid16) -> Self {
let mut uuid = *BASE_UUID;
uuid[2..4].copy_from_slice(&value.uuid);
Self { uuid }
}
}
impl From<Uuid32> for Uuid128 {
fn from(value: Uuid32) -> Self {
let mut uuid = *BASE_UUID;
uuid[..4].copy_from_slice(&value.uuid);
Self { uuid }
}
}

248
rust/src/wrapper/device.rs Normal file
View File

@@ -0,0 +1,248 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Devices and connections to them
use crate::{
adv::AdvertisementDataBuilder,
wrapper::{
core::AdvertisingData,
gatt_client::{ProfileServiceProxy, ServiceProxy},
hci::Address,
transport::{Sink, Source},
ClosureCallback,
},
};
use pyo3::types::PyDict;
use pyo3::{intern, types::PyModule, PyObject, PyResult, Python, ToPyObject};
use std::path;
/// A device that can send/receive HCI frames.
#[derive(Clone)]
pub struct Device(PyObject);
impl Device {
/// Create a Device per the provided file configured to communicate with a controller through an HCI source/sink
pub fn from_config_file_with_hci(
device_config: &path::Path,
source: Source,
sink: Sink,
) -> PyResult<Self> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.device"))?
.getattr(intern!(py, "Device"))?
.call_method1(
intern!(py, "from_config_file_with_hci"),
(device_config, source.0, sink.0),
)
.map(|any| Self(any.into()))
})
}
/// Create a Device configured to communicate with a controller through an HCI source/sink
pub fn with_hci(name: &str, address: &str, source: Source, sink: Sink) -> PyResult<Self> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.device"))?
.getattr(intern!(py, "Device"))?
.call_method1(intern!(py, "with_hci"), (name, address, source.0, sink.0))
.map(|any| Self(any.into()))
})
}
/// Turn the device on
pub async fn power_on(&self) -> PyResult<()> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "power_on"))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
/// Connect to a peer
pub async fn connect(&self, peer_addr: &str) -> PyResult<Connection> {
Python::with_gil(|py| {
self.0
.call_method1(py, intern!(py, "connect"), (peer_addr,))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(Connection)
}
/// Start scanning
pub async fn start_scanning(&self, filter_duplicates: bool) -> PyResult<()> {
Python::with_gil(|py| {
let kwargs = PyDict::new(py);
kwargs.set_item("filter_duplicates", filter_duplicates)?;
self.0
.call_method(py, intern!(py, "start_scanning"), (), Some(kwargs))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
/// Register a callback to be called for each advertisement
pub fn on_advertisement(
&mut self,
callback: impl Fn(Python, Advertisement) -> PyResult<()> + Send + 'static,
) -> PyResult<()> {
let boxed = ClosureCallback::new(move |py, args, _kwargs| {
callback(py, Advertisement(args.get_item(0)?.into()))
});
Python::with_gil(|py| {
self.0
.call_method1(py, intern!(py, "add_listener"), ("advertisement", boxed))
})
.map(|_| ())
}
/// Set the advertisement data to be used when [Device::start_advertising] is called.
pub fn set_advertising_data(&mut self, adv_data: AdvertisementDataBuilder) -> PyResult<()> {
Python::with_gil(|py| {
self.0.setattr(
py,
intern!(py, "advertising_data"),
adv_data.into_bytes().as_slice(),
)
})
.map(|_| ())
}
/// Start advertising the data set with [Device.set_advertisement].
pub async fn start_advertising(&mut self, auto_restart: bool) -> PyResult<()> {
Python::with_gil(|py| {
let kwargs = PyDict::new(py);
kwargs.set_item("auto_restart", auto_restart)?;
self.0
.call_method(py, intern!(py, "start_advertising"), (), Some(kwargs))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
/// Stop advertising.
pub async fn stop_advertising(&mut self) -> PyResult<()> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "stop_advertising"))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
}
/// A connection to a remote device.
pub struct Connection(PyObject);
/// The other end of a connection
pub struct Peer(PyObject);
impl Peer {
/// Wrap a [Connection] in a Peer
pub fn new(conn: Connection) -> PyResult<Self> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.device"))?
.getattr(intern!(py, "Peer"))?
.call1((conn.0,))
.map(|obj| Self(obj.into()))
})
}
/// Populates the peer's cache of services.
///
/// Returns the discovered services.
pub async fn discover_services(&mut self) -> PyResult<Vec<ServiceProxy>> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "discover_services"))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.and_then(|list| {
Python::with_gil(|py| {
list.as_ref(py)
.iter()?
.map(|r| r.map(|h| ServiceProxy(h.to_object(py))))
.collect()
})
})
}
/// Returns a snapshot of the Services currently in the peer's cache
pub fn services(&self) -> PyResult<Vec<ServiceProxy>> {
Python::with_gil(|py| {
self.0
.getattr(py, intern!(py, "services"))?
.as_ref(py)
.iter()?
.map(|r| r.map(|h| ServiceProxy(h.to_object(py))))
.collect()
})
}
/// Build a [ProfileServiceProxy] for the specified type.
/// [Peer::discover_services] or some other means of populating the Peer's service cache must be
/// called first, or the required service won't be found.
pub fn create_service_proxy<P: ProfileServiceProxy>(&self) -> PyResult<Option<P>> {
Python::with_gil(|py| {
let module = py.import(P::PROXY_CLASS_MODULE)?;
let class = module.getattr(P::PROXY_CLASS_NAME)?;
self.0
.call_method1(py, intern!(py, "create_service_proxy"), (class,))
.map(|obj| {
if obj.is_none(py) {
None
} else {
Some(P::wrap(obj))
}
})
})
}
}
/// A BLE advertisement
pub struct Advertisement(PyObject);
impl Advertisement {
/// Address that sent the advertisement
pub fn address(&self) -> PyResult<Address> {
Python::with_gil(|py| self.0.getattr(py, intern!(py, "address")).map(Address))
}
/// Returns true if the advertisement is connectable
pub fn is_connectable(&self) -> PyResult<bool> {
Python::with_gil(|py| {
self.0
.getattr(py, intern!(py, "is_connectable"))?
.extract::<bool>(py)
})
}
/// RSSI of the advertisement
pub fn rssi(&self) -> PyResult<i8> {
Python::with_gil(|py| self.0.getattr(py, intern!(py, "rssi"))?.extract::<i8>(py))
}
/// Data in the advertisement
pub fn data(&self) -> PyResult<AdvertisingData> {
Python::with_gil(|py| self.0.getattr(py, intern!(py, "data")).map(AdvertisingData))
}
}

View File

@@ -0,0 +1,79 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! GATT client support
use crate::wrapper::ClosureCallback;
use pyo3::types::PyTuple;
use pyo3::{intern, PyObject, PyResult, Python};
/// A GATT service on a remote device
pub struct ServiceProxy(pub(crate) PyObject);
impl ServiceProxy {
/// Discover the characteristics in this service.
///
/// Populates an internal cache of characteristics in this service.
pub async fn discover_characteristics(&mut self) -> PyResult<()> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "discover_characteristics"))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
}
/// A GATT characteristic on a remote device
pub struct CharacteristicProxy(pub(crate) PyObject);
impl CharacteristicProxy {
/// Subscribe to changes to the characteristic, executing `callback` for each new value
pub async fn subscribe(
&mut self,
callback: impl Fn(Python, &PyTuple) -> PyResult<()> + Send + 'static,
) -> PyResult<()> {
let boxed = ClosureCallback::new(move |py, args, _kwargs| callback(py, args));
Python::with_gil(|py| {
self.0
.call_method1(py, intern!(py, "subscribe"), (boxed,))
.and_then(|obj| pyo3_asyncio::tokio::into_future(obj.as_ref(py)))
})?
.await
.map(|_| ())
}
/// Read the current value of the characteristic
pub async fn read_value(&self) -> PyResult<PyObject> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "read_value"))
.and_then(|obj| pyo3_asyncio::tokio::into_future(obj.as_ref(py)))
})?
.await
}
}
/// Equivalent to the Python `ProfileServiceProxy`.
pub trait ProfileServiceProxy {
/// The module containing the proxy class
const PROXY_CLASS_MODULE: &'static str;
/// The module class name
const PROXY_CLASS_NAME: &'static str;
/// Wrap a PyObject in the Rust wrapper type
fn wrap(obj: PyObject) -> Self;
}

112
rust/src/wrapper/hci.rs Normal file
View File

@@ -0,0 +1,112 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! HCI
use itertools::Itertools as _;
use pyo3::{exceptions::PyException, intern, types::PyModule, PyErr, PyObject, PyResult, Python};
/// A Bluetooth address
pub struct Address(pub(crate) PyObject);
impl Address {
/// The type of address
pub fn address_type(&self) -> PyResult<AddressType> {
Python::with_gil(|py| {
let addr_type = self
.0
.getattr(py, intern!(py, "address_type"))?
.extract::<u32>(py)?;
let module = PyModule::import(py, intern!(py, "bumble.hci"))?;
let klass = module.getattr(intern!(py, "Address"))?;
if addr_type
== klass
.getattr(intern!(py, "PUBLIC_DEVICE_ADDRESS"))?
.extract::<u32>()?
{
Ok(AddressType::PublicDevice)
} else if addr_type
== klass
.getattr(intern!(py, "RANDOM_DEVICE_ADDRESS"))?
.extract::<u32>()?
{
Ok(AddressType::RandomDevice)
} else if addr_type
== klass
.getattr(intern!(py, "PUBLIC_IDENTITY_ADDRESS"))?
.extract::<u32>()?
{
Ok(AddressType::PublicIdentity)
} else if addr_type
== klass
.getattr(intern!(py, "RANDOM_IDENTITY_ADDRESS"))?
.extract::<u32>()?
{
Ok(AddressType::RandomIdentity)
} else {
Err(PyErr::new::<PyException, _>("Invalid address type"))
}
})
}
/// True if the address is static
pub fn is_static(&self) -> PyResult<bool> {
Python::with_gil(|py| {
self.0
.getattr(py, intern!(py, "is_static"))?
.extract::<bool>(py)
})
}
/// True if the address is resolvable
pub fn is_resolvable(&self) -> PyResult<bool> {
Python::with_gil(|py| {
self.0
.getattr(py, intern!(py, "is_resolvable"))?
.extract::<bool>(py)
})
}
/// Address bytes in _little-endian_ format
pub fn as_le_bytes(&self) -> PyResult<Vec<u8>> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "to_bytes"))?
.extract::<Vec<u8>>(py)
})
}
/// Address bytes as big-endian colon-separated hex
pub fn as_hex(&self) -> PyResult<String> {
self.as_le_bytes().map(|bytes| {
bytes
.into_iter()
.rev()
.map(|byte| hex::encode_upper([byte]))
.join(":")
})
}
}
/// BT address types
#[allow(missing_docs)]
#[derive(PartialEq, Eq, Debug)]
pub enum AddressType {
PublicDevice,
RandomDevice,
PublicIdentity,
RandomIdentity,
}

View File

@@ -0,0 +1,27 @@
//! Bumble & Python logging
use pyo3::types::PyDict;
use pyo3::{intern, types::PyModule, PyResult, Python};
use std::env;
/// Returns the uppercased contents of the `BUMBLE_LOGLEVEL` env var, or `default` if it is not present or not UTF-8.
///
/// The result could be passed to [py_logging_basic_config] to configure Python's logging
/// accordingly.
pub fn bumble_env_logging_level(default: impl Into<String>) -> String {
env::var("BUMBLE_LOGLEVEL")
.unwrap_or_else(|_| default.into())
.to_ascii_uppercase()
}
/// Call `logging.basicConfig` with the provided logging level
pub fn py_logging_basic_config(log_level: impl Into<String>) -> PyResult<()> {
Python::with_gil(|py| {
let kwargs = PyDict::new(py);
kwargs.set_item("level", log_level.into())?;
PyModule::import(py, intern!(py, "logging"))?
.call_method(intern!(py, "basicConfig"), (), Some(kwargs))
.map(|_| ())
})
}

92
rust/src/wrapper/mod.rs Normal file
View File

@@ -0,0 +1,92 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Types that wrap the Python API.
//!
//! Because mutability, aliasing, etc is all hidden behind Python, the normal Rust rules about
//! only one mutable reference to one piece of memory, etc, may not hold since using `&mut self`
//! instead of `&self` is only guided by inspection of the Python source, not the compiler.
//!
//! The modules are generally structured to mirror the Python equivalents.
// Re-exported to make it easy for users to depend on the same `PyObject`, etc
pub use pyo3;
use pyo3::{
prelude::*,
types::{PyDict, PyTuple},
};
pub use pyo3_asyncio;
pub mod assigned_numbers;
pub mod core;
pub mod device;
pub mod gatt_client;
pub mod hci;
pub mod logging;
pub mod profile;
pub mod transport;
/// Convenience extensions to [PyObject]
pub trait PyObjectExt {
/// Get a GIL-bound reference
fn gil_ref<'py>(&'py self, py: Python<'py>) -> &'py PyAny;
/// Extract any [FromPyObject] implementation from this value
fn extract_with_gil<T>(&self) -> PyResult<T>
where
T: for<'a> FromPyObject<'a>,
{
Python::with_gil(|py| self.gil_ref(py).extract::<T>())
}
}
impl PyObjectExt for PyObject {
fn gil_ref<'py>(&'py self, py: Python<'py>) -> &'py PyAny {
self.as_ref(py)
}
}
/// Wrapper to make Rust closures ([Fn] implementations) callable from Python.
///
/// The Python callable form returns a Python `None`.
#[pyclass(name = "SubscribeCallback")]
pub(crate) struct ClosureCallback {
// can't use generics in a pyclass, so have to box
#[allow(clippy::type_complexity)]
callback: Box<dyn Fn(Python, &PyTuple, Option<&PyDict>) -> PyResult<()> + Send + 'static>,
}
impl ClosureCallback {
/// Create a new callback around the provided closure
pub fn new(
callback: impl Fn(Python, &PyTuple, Option<&PyDict>) -> PyResult<()> + Send + 'static,
) -> Self {
Self {
callback: Box::new(callback),
}
}
}
#[pymethods]
impl ClosureCallback {
#[pyo3(signature = (*args, **kwargs))]
fn __call__(
&self,
py: Python<'_>,
args: &PyTuple,
kwargs: Option<&PyDict>,
) -> PyResult<Py<PyAny>> {
(self.callback)(py, args, kwargs).map(|_| py.None())
}
}

View File

@@ -0,0 +1,47 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! GATT profiles
use crate::wrapper::gatt_client::{CharacteristicProxy, ProfileServiceProxy};
use pyo3::{intern, PyObject, PyResult, Python};
/// Exposes the battery GATT service
pub struct BatteryServiceProxy(PyObject);
impl BatteryServiceProxy {
/// Get the battery level, if available
pub fn battery_level(&self) -> PyResult<Option<CharacteristicProxy>> {
Python::with_gil(|py| {
self.0
.getattr(py, intern!(py, "battery_level"))
.map(|level| {
if level.is_none(py) {
None
} else {
Some(CharacteristicProxy(level))
}
})
})
}
}
impl ProfileServiceProxy for BatteryServiceProxy {
const PROXY_CLASS_MODULE: &'static str = "bumble.profiles.battery_service";
const PROXY_CLASS_NAME: &'static str = "BatteryServiceProxy";
fn wrap(obj: PyObject) -> Self {
Self(obj)
}
}

View File

@@ -0,0 +1,72 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! HCI packet transport
use pyo3::{intern, types::PyModule, PyObject, PyResult, Python};
/// A source/sink pair for HCI packet I/O.
///
/// See <https://google.github.io/bumble/transports/index.html>.
pub struct Transport(PyObject);
impl Transport {
/// Open a new Transport for the provided spec, e.g. `"usb:0"` or `"android-netsim"`.
pub async fn open(transport_spec: impl Into<String>) -> PyResult<Self> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.transport"))?
.call_method1(intern!(py, "open_transport"), (transport_spec.into(),))
.and_then(pyo3_asyncio::tokio::into_future)
})?
.await
.map(Self)
}
/// Close the transport.
pub async fn close(&mut self) -> PyResult<()> {
Python::with_gil(|py| {
self.0
.call_method0(py, intern!(py, "close"))
.and_then(|coroutine| pyo3_asyncio::tokio::into_future(coroutine.as_ref(py)))
})?
.await
.map(|_| ())
}
/// Returns the source half of the transport.
pub fn source(&self) -> PyResult<Source> {
Python::with_gil(|py| self.0.getattr(py, intern!(py, "source"))).map(Source)
}
/// Returns the sink half of the transport.
pub fn sink(&self) -> PyResult<Sink> {
Python::with_gil(|py| self.0.getattr(py, intern!(py, "sink"))).map(Sink)
}
}
impl Drop for Transport {
fn drop(&mut self) {
// can't await in a Drop impl, but we can at least spawn a task to do it
let obj = self.0.clone();
tokio::spawn(async move { Self(obj).close().await });
}
}
/// The source side of a [Transport].
#[derive(Clone)]
pub struct Source(pub(crate) PyObject);
/// The sink side of a [Transport].
#[derive(Clone)]
pub struct Sink(pub(crate) PyObject);

View File

@@ -0,0 +1,97 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! This tool generates Rust code with assigned number tables from the equivalent Python.
use pyo3::{
intern,
types::{PyDict, PyModule},
PyResult, Python,
};
use std::{collections, env, fs, path};
fn main() -> anyhow::Result<()> {
pyo3::prepare_freethreaded_python();
let mut dir = path::Path::new(&env::var("CARGO_MANIFEST_DIR")?).to_path_buf();
dir.push("src/wrapper/assigned_numbers");
company_ids(&dir)?;
Ok(())
}
fn company_ids(base_dir: &path::Path) -> anyhow::Result<()> {
let mut sorted_ids = load_company_ids()?.into_iter().collect::<Vec<_>>();
sorted_ids.sort_by_key(|(id, _name)| *id);
let mut contents = String::new();
contents.push_str(LICENSE_HEADER);
contents.push_str("\n\n");
contents.push_str(
"// auto-generated by gen_assigned_numbers, do not edit
use crate::wrapper::core::Uuid16;
use lazy_static::lazy_static;
use std::collections;
lazy_static! {
/// Assigned company IDs
pub static ref COMPANY_IDS: collections::HashMap<Uuid16, &'static str> = [
",
);
for (id, name) in sorted_ids {
contents.push_str(&format!(" ({id}_u16, r#\"{name}\"#),\n"))
}
contents.push_str(
" ]
.into_iter()
.map(|(id, name)| (Uuid16::from_be_bytes(id.to_be_bytes()), name))
.collect();
}
",
);
let mut company_ids = base_dir.to_path_buf();
company_ids.push("company_ids.rs");
fs::write(&company_ids, contents)?;
Ok(())
}
fn load_company_ids() -> PyResult<collections::HashMap<u16, String>> {
Python::with_gil(|py| {
PyModule::import(py, intern!(py, "bumble.company_ids"))?
.getattr(intern!(py, "COMPANY_IDENTIFIERS"))?
.downcast::<PyDict>()?
.into_iter()
.map(|(k, v)| Ok((k.extract::<u16>()?, v.str()?.to_str()?.to_string())))
.collect::<PyResult<collections::HashMap<_, _>>>()
})
}
const LICENSE_HEADER: &str = r#"// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License."#;

35
tests/at_test.py Normal file
View File

@@ -0,0 +1,35 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from bumble import at
def test_tokenize_parameters():
assert at.tokenize_parameters(b'1, 2, 3') == [b'1', b',', b'2', b',', b'3']
assert at.tokenize_parameters(b'"1, 2, 3"') == [b'1, 2, 3']
assert at.tokenize_parameters(b'(1, "2, 3")') == [b'(', b'1', b',', b'2, 3', b')']
def test_parse_parameters():
assert at.parse_parameters(b'1, 2, 3') == [b'1', b'2', b'3']
assert at.parse_parameters(b'1,, 3') == [b'1', b'', b'3']
assert at.parse_parameters(b'"1, 2, 3"') == [b'1, 2, 3']
assert at.parse_parameters(b'1, (2, (3))') == [b'1', [b'2', [b'3']]]
assert at.parse_parameters(b'1, (2, "3, 4"), 5') == [b'1', [b'2', b'3, 4'], b'5']
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_tokenize_parameters()
test_parse_parameters()

View File

@@ -803,14 +803,14 @@ async def test_mtu_exchange():
# -----------------------------------------------------------------------------
def test_char_property_to_string():
# single
assert str(Characteristic.Properties(0x01)) == "Properties.BROADCAST"
assert str(Characteristic.Properties.BROADCAST) == "Properties.BROADCAST"
assert str(Characteristic.Properties(0x01)) == "BROADCAST"
assert str(Characteristic.Properties.BROADCAST) == "BROADCAST"
# double
assert str(Characteristic.Properties(0x03)) == "Properties.READ|BROADCAST"
assert str(Characteristic.Properties(0x03)) == "BROADCAST|READ"
assert (
str(Characteristic.Properties.BROADCAST | Characteristic.Properties.READ)
== "Properties.READ|BROADCAST"
== "BROADCAST|READ"
)
@@ -831,6 +831,10 @@ def test_characteristic_property_from_string():
Characteristic.Properties.from_string("READ,BROADCAST")
== Characteristic.Properties.BROADCAST | Characteristic.Properties.READ
)
assert (
Characteristic.Properties.from_string("BROADCAST|READ")
== Characteristic.Properties.BROADCAST | Characteristic.Properties.READ
)
# -----------------------------------------------------------------------------
@@ -841,7 +845,7 @@ def test_characteristic_property_from_string_assert():
assert (
str(e_info.value)
== """Characteristic.Properties::from_string() error:
Expected a string containing any of the keys, separated by commas: BROADCAST,READ,WRITE_WITHOUT_RESPONSE,WRITE,NOTIFY,INDICATE,AUTHENTICATED_SIGNED_WRITES,EXTENDED_PROPERTIES
Expected a string containing any of the keys, separated by , or |: BROADCAST,READ,WRITE_WITHOUT_RESPONSE,WRITE,NOTIFY,INDICATE,AUTHENTICATED_SIGNED_WRITES,EXTENDED_PROPERTIES
Got: BROADCAST,HELLO"""
)
@@ -866,13 +870,13 @@ async def test_server_string():
assert (
str(server.gatt_server)
== """Service(handle=0x0001, end=0x0005, uuid=UUID-16:1800 (Generic Access))
CharacteristicDeclaration(handle=0x0002, value_handle=0x0003, uuid=UUID-16:2A00 (Device Name), Properties.READ)
Characteristic(handle=0x0003, end=0x0003, uuid=UUID-16:2A00 (Device Name), Properties.READ)
CharacteristicDeclaration(handle=0x0004, value_handle=0x0005, uuid=UUID-16:2A01 (Appearance), Properties.READ)
Characteristic(handle=0x0005, end=0x0005, uuid=UUID-16:2A01 (Appearance), Properties.READ)
CharacteristicDeclaration(handle=0x0002, value_handle=0x0003, uuid=UUID-16:2A00 (Device Name), READ)
Characteristic(handle=0x0003, end=0x0003, uuid=UUID-16:2A00 (Device Name), READ)
CharacteristicDeclaration(handle=0x0004, value_handle=0x0005, uuid=UUID-16:2A01 (Appearance), READ)
Characteristic(handle=0x0005, end=0x0005, uuid=UUID-16:2A01 (Appearance), READ)
Service(handle=0x0006, end=0x0009, uuid=3A657F47-D34F-46B3-B1EC-698E29B6B829)
CharacteristicDeclaration(handle=0x0007, value_handle=0x0008, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, Properties.NOTIFY|WRITE|READ)
Characteristic(handle=0x0008, end=0x0009, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, Properties.NOTIFY|WRITE|READ)
CharacteristicDeclaration(handle=0x0007, value_handle=0x0008, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, READ|WRITE|NOTIFY)
Characteristic(handle=0x0008, end=0x0009, uuid=FDB159DB-036C-49E3-B3DB-6325AC750806, READ|WRITE|NOTIFY)
Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration), value=0000)"""
)

View File

@@ -46,6 +46,7 @@ from bumble.hci import (
HCI_LE_Set_Advertising_Parameters_Command,
HCI_LE_Set_Default_PHY_Command,
HCI_LE_Set_Event_Mask_Command,
HCI_LE_Set_Extended_Advertising_Enable_Command,
HCI_LE_Set_Extended_Scan_Parameters_Command,
HCI_LE_Set_Random_Address_Command,
HCI_LE_Set_Scan_Enable_Command,
@@ -422,6 +423,25 @@ def test_HCI_LE_Set_Extended_Scan_Parameters_Command():
basic_check(command)
# -----------------------------------------------------------------------------
def test_HCI_LE_Set_Extended_Advertising_Enable_Command():
command = HCI_Packet.from_bytes(
bytes.fromhex('0139200e010301050008020600090307000a')
)
assert command.enable == 1
assert command.advertising_handles == [1, 2, 3]
assert command.durations == [5, 6, 7]
assert command.max_extended_advertising_events == [8, 9, 10]
command = HCI_LE_Set_Extended_Advertising_Enable_Command(
enable=1,
advertising_handles=[1, 2, 3],
durations=[5, 6, 7],
max_extended_advertising_events=[8, 9, 10],
)
basic_check(command)
# -----------------------------------------------------------------------------
def test_address():
a = Address('C4:F2:17:1A:1D:BB')
@@ -478,6 +498,7 @@ def run_test_commands():
test_HCI_LE_Read_Remote_Features_Command()
test_HCI_LE_Set_Default_PHY_Command()
test_HCI_LE_Set_Extended_Scan_Parameters_Command()
test_HCI_LE_Set_Extended_Advertising_Enable_Command()
# -----------------------------------------------------------------------------

View File

@@ -21,6 +21,8 @@ import logging
import os
import pytest
from unittest.mock import MagicMock, patch
from bumble.controller import Controller
from bumble.core import BT_BR_EDR_TRANSPORT, BT_PERIPHERAL_ROLE, BT_CENTRAL_ROLE
from bumble.link import LocalLink
@@ -34,6 +36,8 @@ from bumble.smp import (
SMP_CONFIRM_VALUE_FAILED_ERROR,
)
from bumble.core import ProtocolError
from bumble.hci import HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE
from bumble.keys import PairingKeys
# -----------------------------------------------------------------------------
@@ -64,13 +68,16 @@ class TwoDevices:
),
]
self.paired = [None, None]
self.paired = [
asyncio.get_event_loop().create_future(),
asyncio.get_event_loop().create_future(),
]
def on_connection(self, which, connection):
self.connections[which] = connection
def on_paired(self, which, keys):
self.paired[which] = keys
def on_paired(self, which: int, keys: PairingKeys):
self.paired[which].set_result(keys)
# -----------------------------------------------------------------------------
@@ -319,8 +326,8 @@ async def _test_self_smp_with_configs(pairing_config1, pairing_config2):
# Pair
await two_devices.devices[0].pair(connection)
assert connection.is_encrypted
assert two_devices.paired[0] is not None
assert two_devices.paired[1] is not None
assert await two_devices.paired[0] is not None
assert await two_devices.paired[1] is not None
# -----------------------------------------------------------------------------
@@ -473,6 +480,101 @@ async def test_self_smp_wrong_pin():
assert not paired
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_smp_over_classic():
# Create two devices, each with a controller, attached to the same link
two_devices = TwoDevices()
# Attach listeners
two_devices.devices[0].on(
'connection', lambda connection: two_devices.on_connection(0, connection)
)
two_devices.devices[1].on(
'connection', lambda connection: two_devices.on_connection(1, connection)
)
# Enable Classic connections
two_devices.devices[0].classic_enabled = True
two_devices.devices[1].classic_enabled = True
# Start
await two_devices.devices[0].power_on()
await two_devices.devices[1].power_on()
# Connect the two devices
await asyncio.gather(
two_devices.devices[0].connect(
two_devices.devices[1].public_address, transport=BT_BR_EDR_TRANSPORT
),
two_devices.devices[1].accept(two_devices.devices[0].public_address),
)
# Check the post conditions
assert two_devices.connections[0] is not None
assert two_devices.connections[1] is not None
# Mock connection
# TODO: Implement Classic SSP and encryption in link relayer
LINK_KEY = bytes.fromhex('287ad379dca402530a39f1f43047b835')
two_devices.devices[0].on_link_key(
two_devices.devices[1].public_address,
LINK_KEY,
HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
)
two_devices.devices[1].on_link_key(
two_devices.devices[0].public_address,
LINK_KEY,
HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
)
two_devices.connections[0].encryption = 1
two_devices.connections[1].encryption = 1
two_devices.connections[0].on(
'pairing', lambda keys: two_devices.on_paired(0, keys)
)
two_devices.connections[1].on(
'pairing', lambda keys: two_devices.on_paired(1, keys)
)
# Mock SMP
with patch('bumble.smp.Session', spec=True) as MockSmpSession:
MockSmpSession.send_pairing_confirm_command = MagicMock()
MockSmpSession.send_pairing_dhkey_check_command = MagicMock()
MockSmpSession.send_public_key_command = MagicMock()
MockSmpSession.send_pairing_random_command = MagicMock()
# Start CTKD
await two_devices.connections[0].pair()
await asyncio.gather(*two_devices.paired)
# Phase 2 commands should not be invoked
MockSmpSession.send_pairing_confirm_command.assert_not_called()
MockSmpSession.send_pairing_dhkey_check_command.assert_not_called()
MockSmpSession.send_public_key_command.assert_not_called()
MockSmpSession.send_pairing_random_command.assert_not_called()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_smp_public_address():
pairing_config = PairingConfig(
mitm=True,
sc=True,
bonding=True,
identity_address_type=PairingConfig.AddressType.PUBLIC,
delegate=PairingDelegate(
PairingDelegate.IoCapability.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
PairingDelegate.KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY
| PairingDelegate.KeyDistribution.DISTRIBUTE_IDENTITY_KEY
| PairingDelegate.KeyDistribution.DISTRIBUTE_SIGNING_KEY
| PairingDelegate.KeyDistribution.DISTRIBUTE_LINK_KEY,
),
)
await _test_self_smp_with_configs(pairing_config, pairing_config)
# -----------------------------------------------------------------------------
async def run_test_self():
await test_self_connection()
@@ -481,6 +583,8 @@ async def run_test_self():
await test_self_smp()
await test_self_smp_reject()
await test_self_smp_wrong_pin()
await test_self_smp_over_classic()
await test_self_smp_public_address()
# -----------------------------------------------------------------------------