Compare commits

..

2 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod 368e7eff05 update tests to look at side effects instead of internals 2022-08-12 12:32:36 -07:00
Gilles Boccon-Gibod 55b813bbf5 don't use a lambda as a subscriber 2022-08-12 12:06:08 -07:00
521 changed files with 9956 additions and 81906 deletions
-2
View File
@@ -1,2 +0,0 @@
# Migrate code style to Black
135df0dcc01ab765f432e19b1a5202d29bd55545
-39
View File
@@ -1,39 +0,0 @@
# Check the code against the formatter and linter
name: Code format and lint check
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
permissions:
contents: read
jobs:
check:
name: Check Code
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
fail-fast: false
steps:
- name: Check out from Git
uses: actions/checkout@v3
- name: Get history and tags for SCM versioning to work
run: |
git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[build,test,development,pandora]"
- name: Check
run: |
invoke project.pre-commit
-43
View File
@@ -1,43 +0,0 @@
name: Python Avatar
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
permissions:
contents: read
jobs:
test:
name: Avatar [${{ matrix.shard }}]
runs-on: ubuntu-latest
strategy:
matrix:
shard: [
1/24, 2/24, 3/24, 4/24,
5/24, 6/24, 7/24, 8/24,
9/24, 10/24, 11/24, 12/24,
13/24, 14/24, 15/24, 16/24,
17/24, 18/24, 19/24, 20/24,
21/24, 22/24, 23/24, 24/24,
]
steps:
- uses: actions/checkout@v3
- name: Set Up Python 3.11
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Install
run: |
python -m pip install --upgrade pip
python -m pip install .[avatar,pandora]
- name: Rootcanal
run: nohup python -m rootcanal > rootcanal.log &
- name: Test
run: |
avatar --list | grep -Ev '^=' > test-names.txt
timeout 5m avatar --test-beds bumble.bumbles --tests $(split test-names.txt -n l/${{ matrix.shard }})
- name: Rootcanal Logs
run: cat rootcanal.log
+10 -53
View File
@@ -12,12 +12,8 @@ permissions:
jobs: jobs:
build: build:
runs-on: ${{ matrix.os }}
strategy: runs-on: ubuntu-latest
matrix:
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
fail-fast: false
steps: steps:
- name: Check out from Git - name: Check out from Git
@@ -25,58 +21,19 @@ jobs:
- name: Get history and tags for SCM versioning to work - name: Get history and tags for SCM versioning to work
run: | run: |
git fetch --prune --unshallow git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/* git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python 3.10
uses: actions/setup-python@v4 uses: actions/setup-python@v3
with: with:
python-version: ${{ matrix.python-version }} python-version: "3.10"
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install ".[build,test,development,documentation]" python -m pip install ".[test,development,documentation]"
- name: Test - name: Test with pytest
run: | run: |
invoke test pytest
- name: Build - name: Build
run: | run: |
inv build inv build
inv build.mkdocs inv mkdocs
build-rust:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
rust-version: [ "1.76.0", "stable" ]
fail-fast: false
steps:
- name: Check out from Git
uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[build,test,development,documentation]"
- name: Install Rust toolchain
uses: actions-rust-lang/setup-rust-toolchain@v1
with:
components: clippy,rustfmt
toolchain: ${{ matrix.rust-version }}
- name: Install Rust dependencies
run: cargo install cargo-all-features # allows building/testing combinations of features
- name: Check License Headers
run: cd rust && cargo run --features dev-tools --bin file-header check-all
- name: Rust Build
run: cd rust && cargo build --all-targets && cargo build-all-features --all-targets
# Lints after build so what clippy needs is already built
- name: Rust Lints
run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings && cargo clippy --all-features --all-targets -- --deny warnings
- name: Rust Tests
run: cd rust && cargo test-all-features
# At some point, hook up publishing the binary. For now, just make sure it builds.
# Once we're ready to publish binaries, this should be built with `--release`.
- name: Build Bumble CLI
run: cd rust && cargo build --features bumble-tools --bin bumble
+4 -12
View File
@@ -3,17 +3,9 @@ build/
dist/ dist/
*.egg-info/ *.egg-info/
*~ *~
bumble/__pycache__
docs/mkdocs/site docs/mkdocs/site
tests/__pycache__
test-results.xml test-results.xml
__pycache__ bumble/transport/__pycache__
# Vim bumble/profiles/__pycache__
.*.sw*
# generated by setuptools_scm
bumble/_version.py
.vscode/launch.json
.vscode/settings.json
/.idea
venv/
.venv/
# snoop logs
out/
-99
View File
@@ -1,99 +0,0 @@
{
"cSpell.words": [
"Abortable",
"aiohttp",
"altsetting",
"ansiblue",
"ansicyan",
"ansigreen",
"ansimagenta",
"ansired",
"ansiyellow",
"appendleft",
"ascs",
"ASHA",
"asyncio",
"ATRAC",
"avctp",
"avdtp",
"avrcp",
"bitpool",
"bitstruct",
"BSCP",
"BTPROTO",
"CCCD",
"cccds",
"cmac",
"CONNECTIONLESS",
"csip",
"csis",
"csrcs",
"CVSD",
"datagram",
"DATALINK",
"delayreport",
"deregisters",
"deregistration",
"dhkey",
"diversifier",
"endianness",
"ESCO",
"Fitbit",
"GATTLINK",
"HANDSFREE",
"keydown",
"keyup",
"levelname",
"libc",
"liblc",
"libusb",
"MITM",
"MSBC",
"NDIS",
"netsim",
"NONBLOCK",
"NONCONN",
"OXIMETER",
"popleft",
"PRAND",
"protobuf",
"psms",
"pyee",
"Pyodide",
"pyusb",
"rfcomm",
"ROHC",
"rssi",
"SEID",
"seids",
"SERV",
"SIRK",
"ssrc",
"strerror",
"subband",
"subbands",
"subevent",
"Subrating",
"substates",
"tobytes",
"tsep",
"UNMUTE",
"unmuted",
"usbmodem",
"vhci",
"wasmtime",
"websockets",
"xcursor",
"ycursor"
],
"[python]": {
"editor.rulers": [88]
},
"python.formatting.provider": "black",
"pylint.importStrategy": "useBundled",
"python.testing.pytestArgs": [
"."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
+1 -20
View File
@@ -199,23 +199,4 @@
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
---
Files: bumble/colors.py
Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+5 -13
View File
@@ -9,16 +9,16 @@
Bluetooth Stack for Apps, Emulation, Test and Experimentation Bluetooth Stack for Apps, Emulation, Test and Experimentation
============================================================= =============================================================
<img src="docs/mkdocs/src/images/logo_framed.png" alt="Logo" width="200" height="200"/> <img src="docs/mkdocs/src/images/logo_framed.png" alt="drawing" width="200" height="200"/>
Bumble is a full-featured Bluetooth stack written entirely in Python. It supports most of the common Bluetooth Low Energy (BLE) and Bluetooth Classic (BR/EDR) protocols and profiles, including GAP, L2CAP, ATT, GATT, SMP, SDP, RFCOMM, HFP, HID and A2DP. The stack can be used with physical radios via HCI over USB, UART, or the Linux VHCI, as well as virtual radios, including the virtual Bluetooth support of the Android emulator. Bumble is a full-featured Bluetooth stack written entirely in Python. It supports most of the common Bluetooth Low Energy (BLE) and Bluetooth Classic (BR/EDR) protocols and profiles, including GAP, L2CAP, ATT, GATT, SMP, SDP, RFCOMM, HFP, HID and A2DP. The stack can be used with physical radios via HCI over USB, UART, or the Linux VHCI, as well as virtual radios, including the virtual Bluetooth support of the Android emulator.
## Documentation ## Documentation
Browse the pre-built [Online Documentation](https://google.github.io/bumble/), Browse the pre-built [Online Documentation](https://google.github.io/bumble/),
or see the documentation source under `docs/mkdocs/src`, or build the static HTML site from the markdown text with: or see the documentation source under `docs/mkdocs/src`, or build the static HTML site from the markdown text with:
``` ```
mkdocs build -f docs/mkdocs/mkdocs.yml mkdocs build -f docs/mkdocs/mkdocs.yml
``` ```
## Usage ## Usage
@@ -29,7 +29,7 @@ For a quick start to using Bumble, see the [Getting Started](docs/mkdocs/src/get
### Dependencies ### Dependencies
To install package dependencies needed to run the bumble examples, execute the following commands: To install package dependencies needed to run the bumble examples execute the following commands:
``` ```
python -m pip install --upgrade pip python -m pip install --upgrade pip
@@ -38,20 +38,12 @@ python -m pip install ".[test,development,documentation]"
### Examples ### Examples
Refer to the [Examples Documentation](examples/README.md) for details on the included example scripts and how to run them. Refer to the [Example Documentation](examples/README.md) for details on the included example scripts and how to run them.
The complete [list of Examples](/docs/mkdocs/src/examples/index.md), and what they are designed to do is here. The complete [list of Examples](/docs/mkdocs/src/examples/index.md), and what they are designed to do is here.
There are also a set of [Apps and Tools](docs/mkdocs/src/apps_and_tools/index.md) that show the utility of Bumble. There are also a set of [Apps and Tools](docs/mkdocs/src/apps_and_tools/index.md) that show the utility of Bumble.
### Using Bumble With a USB Dongle
Bumble is easiest to use with a dedicated USB dongle.
This is because internal Bluetooth interfaces tend to be locked down by the operating system.
You can use the [usb_probe](/docs/mkdocs/src/apps_and_tools/usb_probe.md) tool (all platforms) or `lsusb` (Linux or macOS) to list the available USB devices on your system.
See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices. Also, if your are on a mac, see [these instructions](docs/mkdocs/src/platforms/macos.md).
## License ## License
Licensed under the [Apache 2.0](LICENSE) License. Licensed under the [Apache 2.0](LICENSE) License.
+2
View File
@@ -47,3 +47,5 @@ NOTE: this assumes you're running a Link Relay on port `10723`.
## `console.py` ## `console.py`
A simple text-based-ui interactive Bluetooth device with GATT client capabilities. A simple text-based-ui interactive Bluetooth device with GATT client capabilities.
-407
View File
@@ -1,407 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import logging
import os
from typing import cast, Dict, Optional, Tuple
import click
import pyee
from bumble.colors import color
import bumble.company_ids
import bumble.core
import bumble.device
import bumble.gatt
import bumble.hci
import bumble.profiles.bap
import bumble.profiles.pbp
import bumble.transport
import bumble.utils
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
AURACAST_DEFAULT_DEVICE_NAME = "Bumble Auracast"
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address("F0:F1:F2:F3:F4:F5")
# -----------------------------------------------------------------------------
# Discover Broadcasts
# -----------------------------------------------------------------------------
class BroadcastDiscoverer:
@dataclasses.dataclass
class Broadcast(pyee.EventEmitter):
name: str
sync: bumble.device.PeriodicAdvertisingSync
rssi: int = 0
public_broadcast_announcement: Optional[
bumble.profiles.pbp.PublicBroadcastAnnouncement
] = None
broadcast_audio_announcement: Optional[
bumble.profiles.bap.BroadcastAudioAnnouncement
] = None
basic_audio_announcement: Optional[
bumble.profiles.bap.BasicAudioAnnouncement
] = None
appearance: Optional[bumble.core.Appearance] = None
biginfo: Optional[bumble.device.BIGInfoAdvertisement] = None
manufacturer_data: Optional[Tuple[str, bytes]] = None
def __post_init__(self) -> None:
super().__init__()
self.sync.on('establishment', self.on_sync_establishment)
self.sync.on('loss', self.on_sync_loss)
self.sync.on('periodic_advertisement', self.on_periodic_advertisement)
self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement)
self.establishment_timeout_task = asyncio.create_task(
self.wait_for_establishment()
)
async def wait_for_establishment(self) -> None:
await asyncio.sleep(5.0)
if self.sync.state == bumble.device.PeriodicAdvertisingSync.State.PENDING:
print(
color(
'!!! Periodic advertisement sync not established in time, '
'canceling',
'red',
)
)
await self.sync.terminate()
def update(self, advertisement: bumble.device.Advertisement) -> None:
self.rssi = advertisement.rssi
for service_data in advertisement.data.get_all(
bumble.core.AdvertisingData.SERVICE_DATA
):
assert isinstance(service_data, tuple)
service_uuid, data = service_data
assert isinstance(data, bytes)
if (
service_uuid
== bumble.gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE
):
self.public_broadcast_announcement = (
bumble.profiles.pbp.PublicBroadcastAnnouncement.from_bytes(data)
)
continue
if (
service_uuid
== bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
):
self.broadcast_audio_announcement = (
bumble.profiles.bap.BroadcastAudioAnnouncement.from_bytes(data)
)
continue
self.appearance = advertisement.data.get( # type: ignore[assignment]
bumble.core.AdvertisingData.APPEARANCE
)
if manufacturer_data := advertisement.data.get(
bumble.core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA
):
assert isinstance(manufacturer_data, tuple)
company_id = cast(int, manufacturer_data[0])
data = cast(bytes, manufacturer_data[1])
self.manufacturer_data = (
bumble.company_ids.COMPANY_IDENTIFIERS.get(
company_id, f'0x{company_id:04X}'
),
data,
)
def print(self) -> None:
print(
color('Broadcast:', 'yellow'),
self.sync.advertiser_address,
color(self.sync.state.name, 'green'),
)
print(f' {color("Name", "cyan")}: {self.name}')
if self.appearance:
print(f' {color("Appearance", "cyan")}: {str(self.appearance)}')
print(f' {color("RSSI", "cyan")}: {self.rssi}')
print(f' {color("SID", "cyan")}: {self.sync.sid}')
if self.manufacturer_data:
print(
f' {color("Manufacturer Data", "cyan")}: '
f'{self.manufacturer_data[0]} -> {self.manufacturer_data[1].hex()}'
)
if self.broadcast_audio_announcement:
print(
f' {color("Broadcast ID", "cyan")}: '
f'{self.broadcast_audio_announcement.broadcast_id}'
)
if self.public_broadcast_announcement:
print(
f' {color("Features", "cyan")}: '
f'{self.public_broadcast_announcement.features}'
)
print(
f' {color("Metadata", "cyan")}: '
f'{self.public_broadcast_announcement.metadata}'
)
if self.basic_audio_announcement:
print(color(' Audio:', 'cyan'))
print(
color(' Presentation Delay:', 'magenta'),
self.basic_audio_announcement.presentation_delay,
)
for subgroup in self.basic_audio_announcement.subgroups:
print(color(' Subgroup:', 'magenta'))
print(color(' Codec ID:', 'yellow'))
print(
color(' Coding Format: ', 'green'),
subgroup.codec_id.coding_format.name,
)
print(
color(' Company ID: ', 'green'),
subgroup.codec_id.company_id,
)
print(
color(' Vendor Specific Codec ID:', 'green'),
subgroup.codec_id.vendor_specific_codec_id,
)
print(
color(' Codec Config:', 'yellow'),
subgroup.codec_specific_configuration,
)
print(color(' Metadata: ', 'yellow'), subgroup.metadata)
for bis in subgroup.bis:
print(color(f' BIS [{bis.index}]:', 'yellow'))
print(
color(' Codec Config:', 'green'),
bis.codec_specific_configuration,
)
if self.biginfo:
print(color(' BIG:', 'cyan'))
print(
color(' Number of BIS:', 'magenta'),
self.biginfo.num_bis,
)
print(
color(' PHY: ', 'magenta'),
self.biginfo.phy.name,
)
print(
color(' Framed: ', 'magenta'),
self.biginfo.framed,
)
print(
color(' Encrypted: ', 'magenta'),
self.biginfo.encrypted,
)
def on_sync_establishment(self) -> None:
self.establishment_timeout_task.cancel()
self.emit('change')
def on_sync_loss(self) -> None:
self.basic_audio_announcement = None
self.biginfo = None
self.emit('change')
def on_periodic_advertisement(
self, advertisement: bumble.device.PeriodicAdvertisement
) -> None:
if advertisement.data is None:
return
for service_data in advertisement.data.get_all(
bumble.core.AdvertisingData.SERVICE_DATA
):
assert isinstance(service_data, tuple)
service_uuid, data = service_data
assert isinstance(data, bytes)
if service_uuid == bumble.gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE:
self.basic_audio_announcement = (
bumble.profiles.bap.BasicAudioAnnouncement.from_bytes(data)
)
break
self.emit('change')
def on_biginfo_advertisement(
self, advertisement: bumble.device.BIGInfoAdvertisement
) -> None:
self.biginfo = advertisement
self.emit('change')
def __init__(
self,
device: bumble.device.Device,
filter_duplicates: bool,
sync_timeout: float,
):
self.device = device
self.filter_duplicates = filter_duplicates
self.sync_timeout = sync_timeout
self.broadcasts: Dict[bumble.hci.Address, BroadcastDiscoverer.Broadcast] = {}
self.status_message = ''
device.on('advertisement', self.on_advertisement)
async def run(self) -> None:
self.status_message = color('Scanning...', 'green')
await self.device.start_scanning(
active=False,
filter_duplicates=False,
)
def refresh(self) -> None:
# Clear the screen from the top
print('\033[H')
print('\033[0J')
print('\033[H')
# Print the status message
print(self.status_message)
print("==========================================")
# Print all broadcasts
for broadcast in self.broadcasts.values():
broadcast.print()
print('------------------------------------------')
# Clear the screen to the bottom
print('\033[0J')
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
if (
broadcast_name := advertisement.data.get(
bumble.core.AdvertisingData.BROADCAST_NAME
)
) is None:
return
assert isinstance(broadcast_name, str)
if broadcast := self.broadcasts.get(advertisement.address):
broadcast.update(advertisement)
self.refresh()
return
bumble.utils.AsyncRunner.spawn(
self.on_new_broadcast(broadcast_name, advertisement)
)
async def on_new_broadcast(
self, name: str, advertisement: bumble.device.Advertisement
) -> None:
periodic_advertising_sync = await self.device.create_periodic_advertising_sync(
advertiser_address=advertisement.address,
sid=advertisement.sid,
sync_timeout=self.sync_timeout,
filter_duplicates=self.filter_duplicates,
)
broadcast = self.Broadcast(
name,
periodic_advertising_sync,
)
broadcast.on('change', self.refresh)
broadcast.update(advertisement)
self.broadcasts[advertisement.address] = broadcast
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
self.status_message = color(
f'+Found {len(self.broadcasts)} broadcasts', 'green'
)
self.refresh()
def on_broadcast_loss(self, broadcast: Broadcast) -> None:
del self.broadcasts[broadcast.sync.advertiser_address]
bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate())
self.status_message = color(
f'-Found {len(self.broadcasts)} broadcasts', 'green'
)
self.refresh()
async def run_discover_broadcasts(
filter_duplicates: bool, sync_timeout: float, transport: str
) -> None:
async with await bumble.transport.open_transport(transport) as (
hci_source,
hci_sink,
):
device = bumble.device.Device.with_hci(
AURACAST_DEFAULT_DEVICE_NAME,
AURACAST_DEFAULT_DEVICE_ADDRESS,
hci_source,
hci_sink,
)
await device.power_on()
discoverer = BroadcastDiscoverer(device, filter_duplicates, sync_timeout)
await discoverer.run()
await hci_source.terminated
# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
def auracast(
ctx,
):
ctx.ensure_object(dict)
@auracast.command('discover-broadcasts')
@click.option(
'--filter-duplicates', is_flag=True, default=False, help='Filter duplicates'
)
@click.option(
'--sync-timeout',
metavar='SYNC_TIMEOUT',
type=float,
default=5.0,
help='Sync timeout (in seconds)',
)
@click.argument('transport')
@click.pass_context
def discover_broadcasts(ctx, filter_duplicates, sync_timeout, transport):
"""Discover public broadcasts"""
asyncio.run(run_discover_broadcasts(filter_duplicates, sync_timeout, transport))
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
auracast()
# -----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter
-1752
View File
File diff suppressed because it is too large Load Diff
-63
View File
@@ -1,63 +0,0 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import click
from bumble.colors import color
from bumble.hci import Address
from bumble.helpers import generate_irk, verify_rpa_with_irk
@click.group()
def cli():
'''
This is a tool for generating IRK, RPA,
and verifying IRK/RPA pairs
'''
@click.command()
def gen_irk() -> None:
print(generate_irk().hex())
@click.command()
@click.argument("irk", type=str)
def gen_rpa(irk: str) -> None:
irk_bytes = bytes.fromhex(irk)
rpa = Address.generate_private_address(irk_bytes)
print(rpa.to_string(with_type_qualifier=False))
@click.command()
@click.argument("irk", type=str)
@click.argument("rpa", type=str)
def verify_rpa(irk: str, rpa: str) -> None:
address = Address(rpa)
irk_bytes = bytes.fromhex(irk)
if verify_rpa_with_irk(address, irk_bytes):
print(color("Verified", "green"))
else:
print(color("Not Verified", "red"))
def main():
cli.add_command(gen_irk)
cli.add_command(gen_rpa)
cli.add_command(verify_rpa)
cli()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()
+178 -737
View File
File diff suppressed because it is too large Load Diff
+20 -161
View File
@@ -18,158 +18,52 @@
import asyncio import asyncio
import os import os
import logging import logging
import time
import click import click
from colors import color
from bumble.company_ids import COMPANY_IDENTIFIERS from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.colors import color
from bumble.core import name_or_number from bumble.core import name_or_number
from bumble.hci import ( from bumble.hci import (
map_null_terminated_utf8_string, map_null_terminated_utf8_string,
LeFeature, HCI_LE_SUPPORTED_FEATURES_NAMES,
HCI_SUCCESS, HCI_SUCCESS,
HCI_VERSION_NAMES, HCI_VERSION_NAMES,
LMP_VERSION_NAMES, LMP_VERSION_NAMES,
HCI_Command, HCI_Command,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_Read_Buffer_Size_Command,
HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command, HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND, HCI_READ_BD_ADDR_COMMAND,
HCI_Read_Local_Name_Command, HCI_Read_Local_Name_Command,
HCI_LE_READ_BUFFER_SIZE_COMMAND, HCI_READ_LOCAL_NAME_COMMAND
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_Read_Local_Version_Information_Command,
) )
from bumble.host import Host from bumble.host import Host
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def command_succeeded(response): async def get_classic_info(host):
if isinstance(response, HCI_Command_Status_Event):
return response.status == HCI_SUCCESS
if isinstance(response, HCI_Command_Complete_Event):
return response.return_parameters.status == HCI_SUCCESS
return False
# -----------------------------------------------------------------------------
async def get_classic_info(host: Host) -> None:
if host.supports_command(HCI_READ_BD_ADDR_COMMAND): if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command()) response = await host.send_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response): if response.return_parameters.status == HCI_SUCCESS:
print() print()
print( print(color('Classic Address:', 'yellow'), response.return_parameters.bd_addr)
color('Classic Address:', 'yellow'),
response.return_parameters.bd_addr.to_string(False),
)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND): if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
response = await host.send_command(HCI_Read_Local_Name_Command()) response = await host.send_command(HCI_Read_Local_Name_Command())
if command_succeeded(response): if response.return_parameters.status == HCI_SUCCESS:
print() print()
print( print(color('Local Name:', 'yellow'), map_null_terminated_utf8_string(response.return_parameters.local_name))
color('Local Name:', 'yellow'),
map_null_terminated_utf8_string(response.return_parameters.local_name),
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_le_info(host: Host) -> None: async def get_le_info(host):
print() print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
response = await host.send_command(
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
)
if command_succeeded(response):
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response.return_parameters.num_supported_advertising_sets,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND):
response = await host.send_command(
HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
)
if command_succeeded(response):
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response.return_parameters.max_advertising_data_length,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND):
response = await host.send_command(HCI_LE_Read_Maximum_Data_Length_Command())
if command_succeeded(response):
print(
color('Maximum Data Length:', 'yellow'),
(
f'tx:{response.return_parameters.supported_max_tx_octets}/'
f'{response.return_parameters.supported_max_tx_time}, '
f'rx:{response.return_parameters.supported_max_rx_octets}/'
f'{response.return_parameters.supported_max_rx_time}'
),
'\n',
)
if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response = await host.send_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
if command_succeeded(response):
print(
color('Suggested Default Data Length:', 'yellow'),
f'{response.return_parameters.suggested_max_tx_octets}/'
f'{response.return_parameters.suggested_max_tx_time}',
'\n',
)
print(color('LE Features:', 'yellow')) print(color('LE Features:', 'yellow'))
for feature in host.supported_le_features: for feature in host.supported_le_features:
print(f' {LeFeature(feature).name}') print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_acl_flow_control_info(host: Host) -> None: async def async_main(transport):
print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
print(
color('ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_acl_data_packets} '
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
)
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
)
# -----------------------------------------------------------------------------
async def async_main(latency_probes, transport):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(transport) as (hci_source, hci_sink): async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
@@ -177,38 +71,12 @@ async def async_main(latency_probes, transport):
host = Host(hci_source, hci_sink) host = Host(hci_source, hci_sink)
await host.reset() await host.reset()
# Measure the latency if requested
latencies = []
if latency_probes:
for _ in range(latency_probes):
start = time.time()
await host.send_command(HCI_Read_Local_Version_Information_Command())
latencies.append(1000 * (time.time() - start))
print(
color('HCI Command Latency:', 'yellow'),
(
f'min={min(latencies):.2f}, '
f'max={max(latencies):.2f}, '
f'average={sum(latencies)/len(latencies):.2f}'
),
'\n',
)
# Print version # Print version
print(color('Version:', 'yellow')) print(color('Version:', 'yellow'))
print( print(color(' Manufacturer: ', 'green'), name_or_number(COMPANY_IDENTIFIERS, host.local_version.company_identifier))
color(' Manufacturer: ', 'green'), print(color(' HCI Version: ', 'green'), name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version))
name_or_number(COMPANY_IDENTIFIERS, host.local_version.company_identifier),
)
print(
color(' HCI Version: ', 'green'),
name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version),
)
print(color(' HCI Subversion:', 'green'), host.local_version.hci_subversion) print(color(' HCI Subversion:', 'green'), host.local_version.hci_subversion)
print( print(color(' LMP Version: ', 'green'), name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version))
color(' LMP Version: ', 'green'),
name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version),
)
print(color(' LMP Subversion:', 'green'), host.local_version.lmp_subversion) print(color(' LMP Subversion:', 'green'), host.local_version.lmp_subversion)
# Get the Classic info # Get the Classic info
@@ -217,28 +85,19 @@ async def async_main(latency_probes, transport):
# Get the LE info # Get the LE info
await get_le_info(host) await get_le_info(host)
# Print the ACL flow control info
await get_acl_flow_control_info(host)
# Print the list of commands supported by the controller # Print the list of commands supported by the controller
print() print()
print(color('Supported Commands:', 'yellow')) print(color('Supported Commands:', 'yellow'))
for command in host.supported_commands: for command in host.supported_commands:
print(f' {HCI_Command.command_name(command)}') print(' ', HCI_Command.command_name(command))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@click.command() @click.command()
@click.option(
'--latency-probes',
metavar='N',
type=int,
help='Send N commands to measure HCI transport latency statistics',
)
@click.argument('transport') @click.argument('transport')
def main(latency_probes, transport): def main(transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(async_main(latency_probes, transport)) asyncio.run(async_main(transport))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
-205
View File
@@ -1,205 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import time
from typing import Optional
from bumble.colors import color
from bumble.hci import (
HCI_READ_LOOPBACK_MODE_COMMAND,
HCI_Read_Loopback_Mode_Command,
HCI_WRITE_LOOPBACK_MODE_COMMAND,
HCI_Write_Loopback_Mode_Command,
LoopbackMode,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
import click
class Loopback:
"""Send and receive ACL data packets in local loopback mode"""
def __init__(self, packet_size: int, packet_count: int, transport: str):
self.transport = transport
self.packet_size = packet_size
self.packet_count = packet_count
self.connection_handle: Optional[int] = None
self.connection_event = asyncio.Event()
self.done = asyncio.Event()
self.expected_cid = 0
self.bytes_received = 0
self.start_timestamp = 0.0
self.last_timestamp = 0.0
def on_connection(self, connection_handle: int, *args):
"""Retrieve connection handle from new connection event"""
if not self.connection_event.is_set():
# save first connection handle for ACL
# subsequent connections are SCO
self.connection_handle = connection_handle
self.connection_event.set()
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
"""Calculate packet receive speed"""
now = time.time()
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
assert connection_handle == self.connection_handle
assert cid == self.expected_cid
self.expected_cid += 1
if cid == 0:
self.start_timestamp = now
else:
elapsed_since_start = now - self.start_timestamp
elapsed_since_last = now - self.last_timestamp
self.bytes_received += len(pdu)
instant_rx_speed = len(pdu) / elapsed_since_last
average_rx_speed = self.bytes_received / elapsed_since_start
print(
color(
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f' average={average_rx_speed:.4f}',
'cyan',
)
)
self.last_timestamp = now
if self.expected_cid == self.packet_count:
print(color('@@@ Received last packet', 'green'))
self.done.set()
async def run(self):
"""Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green'))
async with await open_transport_or_link(self.transport) as (
hci_source,
hci_sink,
):
print(color('>>> Connected', 'green'))
host = Host(hci_source, hci_sink)
await host.reset()
# make sure data can fit in one l2cap pdu
l2cap_header_size = 4
max_packet_size = (
host.acl_packet_queue
if host.acl_packet_queue
else host.le_acl_packet_queue
).max_packet_size - l2cap_header_size
if self.packet_size > max_packet_size:
print(
color(
f'!!! Packet size ({self.packet_size}) larger than max supported'
f' size ({max_packet_size})',
'red',
)
)
return
if not host.supports_command(
HCI_WRITE_LOOPBACK_MODE_COMMAND
) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
print(color('!!! Loopback mode not supported', 'red'))
return
# set event callbacks
host.on('connection', self.on_connection)
host.on('l2cap_pdu', self.on_l2cap_pdu)
loopback_mode = LoopbackMode.LOCAL
print(color('### Setting loopback mode', 'blue'))
await host.send_command(
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
check_result=True,
)
print(color('### Checking loopback mode', 'blue'))
response = await host.send_command(
HCI_Read_Loopback_Mode_Command(), check_result=True
)
if response.return_parameters.loopback_mode != loopback_mode:
print(color('!!! Loopback mode mismatch', 'red'))
return
await self.connection_event.wait()
print(color('### Connected', 'cyan'))
print(color('=== Start sending', 'magenta'))
start_time = time.time()
bytes_sent = 0
for cid in range(0, self.packet_count):
# using the cid as an incremental index
host.send_l2cap_pdu(
self.connection_handle, cid, bytes(self.packet_size)
)
print(
color(
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
)
)
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
await asyncio.sleep(0) # yield to allow packet receive
await self.done.wait()
print(color('=== Done!', 'magenta'))
elapsed = time.time() - start_time
average_tx_speed = bytes_sent / elapsed
print(
color(
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
f' in {elapsed:.2f} seconds)',
'green',
)
)
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--packet-size',
'-s',
metavar='SIZE',
type=click.IntRange(8, 4096),
default=500,
help='Packet size',
)
@click.option(
'--packet-count',
'-c',
metavar='COUNT',
type=click.IntRange(1, 65535),
default=10,
help='Packet count',
)
@click.argument('transport')
def main(packet_size, packet_count, transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
loopback = Loopback(packet_size, packet_count, transport)
asyncio.run(loopback.run())
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()
+4 -12
View File
@@ -28,14 +28,11 @@ from bumble.transport import open_transport_or_link
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_main(): async def async_main():
if len(sys.argv) != 3: if len(sys.argv) != 3:
print( print('Usage: controllers.py <hci-transport-1> <hci-transport-2> [<hci-transport-3> ...]')
'Usage: controllers.py <hci-transport-1> <hci-transport-2> '
'[<hci-transport-3> ...]'
)
print('example: python controllers.py pty:ble1 pty:ble2') print('example: python controllers.py pty:ble1 pty:ble2')
return return
# Create a local link to attach the controllers to # Create a loccal link to attach the controllers to
link = LocalLink() link = LocalLink()
# Create a transport and controller for all requested names # Create a transport and controller for all requested names
@@ -44,12 +41,7 @@ async def async_main():
for index, transport_name in enumerate(sys.argv[1:]): for index, transport_name in enumerate(sys.argv[1:]):
transport = await open_transport_or_link(transport_name) transport = await open_transport_or_link(transport_name)
transports.append(transport) transports.append(transport)
controller = Controller( controller = Controller(f'C{index}', host_source = transport.source, host_sink = transport.sink, link = link)
f'C{index}',
host_source=transport.source,
host_sink=transport.sink,
link=link,
)
controllers.append(controller) controllers.append(controller)
# Wait until the user interrupts # Wait until the user interrupts
@@ -62,7 +54,7 @@ async def async_main():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def main(): def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main()) asyncio.run(async_main())
+8 -17
View File
@@ -19,9 +19,9 @@ import asyncio
import os import os
import logging import logging
import click import click
from colors import color
import bumble.core from bumble.core import ProtocolError, TimeoutError
from bumble.colors import color
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.gatt import show_services from bumble.gatt import show_services
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
@@ -49,9 +49,9 @@ async def dump_gatt_db(peer, done):
try: try:
value = await attribute.read_value() value = await attribute.read_value()
print(color(f'{value.hex()}', 'green')) print(color(f'{value.hex()}', 'green'))
except bumble.core.ProtocolError as error: except ProtocolError as error:
print(color(error, 'red')) print(color(error, 'red'))
except bumble.core.TimeoutError: except TimeoutError:
print(color('read timeout', 'red')) print(color('read timeout', 'red'))
if done is not None: if done is not None:
@@ -64,13 +64,9 @@ async def async_main(device_config, encrypt, transport, address_or_name):
# Create a device # Create a device
if device_config: if device_config:
device = Device.from_config_file_with_hci( device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
device_config, hci_source, hci_sink
)
else: else:
device = Device.with_hci( device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
await device.power_on() await device.power_on()
if address_or_name: if address_or_name:
@@ -85,12 +81,7 @@ async def async_main(device_config, encrypt, transport, address_or_name):
else: else:
# Wait for a connection # Wait for a connection
done = asyncio.get_running_loop().create_future() done = asyncio.get_running_loop().create_future()
device.on( device.on('connection', lambda connection: asyncio.create_task(dump_gatt_db(Peer(connection), done)))
'connection',
lambda connection: asyncio.create_task(
dump_gatt_db(Peer(connection), done)
),
)
await device.start_advertising(auto_restart=True) await device.start_advertising(auto_restart=True)
print(color('### Waiting for connection...', 'blue')) print(color('### Waiting for connection...', 'blue'))
@@ -108,7 +99,7 @@ def main(device_config, encrypt, transport, address_or_name):
Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified, Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified,
wait for an incoming connection. wait for an incoming connection.
""" """
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main(device_config, encrypt, transport, address_or_name)) asyncio.run(async_main(device_config, encrypt, transport, address_or_name))
+69 -260
View File
@@ -17,15 +17,13 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import os import os
import struct
import logging import logging
import click import click
from colors import color
from bumble import l2cap
from bumble.colors import color
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.core import AdvertisingData from bumble.core import AdvertisingData
from bumble.gatt import Service, Characteristic, CharacteristicValue from bumble.gatt import Service, Characteristic
from bumble.utils import AsyncRunner from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.hci import HCI_Constant from bumble.hci import HCI_Constant
@@ -34,73 +32,24 @@ from bumble.hci import HCI_Constant
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
GG_GATTLINK_SERVICE_UUID = 'ABBAFF00-E56A-484C-B832-8B17CF6CBFE8' GG_GATTLINK_SERVICE_UUID = 'ABBAFF00-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_RX_CHARACTERISTIC_UUID = 'ABBAFF01-E56A-484C-B832-8B17CF6CBFE8' GG_GATTLINK_RX_CHARACTERISTIC_UUID = 'ABBAFF01-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_TX_CHARACTERISTIC_UUID = 'ABBAFF02-E56A-484C-B832-8B17CF6CBFE8' GG_GATTLINK_TX_CHARACTERISTIC_UUID = 'ABBAFF02-E56A-484C-B832-8B17CF6CBFE8'
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = ( GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID = 'ABBAFF03-E56A-484C-B832-8B17CF6CBFE8'
'ABBAFF03-E56A-484C-B832-8B17CF6CBFE8'
)
GG_PREFERRED_MTU = 256 GG_PREFERRED_MTU = 256
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class GattlinkL2capEndpoint: class GattlinkHubBridge(Device.Listener):
def __init__(self): def __init__(self):
self.l2cap_channel = None self.peer = None
self.l2cap_packet = b'' self.rx_socket = None
self.l2cap_packet_size = 0 self.tx_socket = None
# Called when an L2CAP SDU has been received
def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
while len(sdu):
if self.l2cap_packet_size == 0:
# Expect a new packet
self.l2cap_packet_size = sdu[0] + 1
sdu = sdu[1:]
else:
bytes_needed = self.l2cap_packet_size - len(self.l2cap_packet)
chunk = min(bytes_needed, len(sdu))
self.l2cap_packet += sdu[:chunk]
sdu = sdu[chunk:]
if len(self.l2cap_packet) == self.l2cap_packet_size:
self.on_l2cap_packet(self.l2cap_packet)
self.l2cap_packet = b''
self.l2cap_packet_size = 0
# -----------------------------------------------------------------------------
class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device, peer_address):
super().__init__()
self.device = device
self.peer_address = peer_address
self.peer = None
self.tx_socket = None
self.rx_characteristic = None self.rx_characteristic = None
self.tx_characteristic = None self.tx_characteristic = None
self.l2cap_psm_characteristic = None
device.listener = self
async def start(self):
# Connect to the peer
print(f'=== Connecting to {self.peer_address}...')
await self.device.connect(self.peer_address)
async def connect_l2cap(self, psm):
print(color(f'### Connecting with L2CAP on PSM = {psm}', 'yellow'))
try:
self.l2cap_channel = await self.peer.connection.open_l2cap_channel(psm)
print(color('*** Connected', 'yellow'), self.l2cap_channel)
self.l2cap_channel.sink = self.on_coc_sdu
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
@AsyncRunner.run_in_task() @AsyncRunner.run_in_task()
# pylint: disable=invalid-overridden-method
async def on_connection(self, connection): async def on_connection(self, connection):
print(f'=== Connected to {connection}') print(f'=== Connected to {connection}')
self.peer = Peer(connection) self.peer = Peer(connection)
@@ -131,226 +80,115 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
self.rx_characteristic = characteristic self.rx_characteristic = characteristic
elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID: elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID:
self.tx_characteristic = characteristic self.tx_characteristic = characteristic
elif (
characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID
):
self.l2cap_psm_characteristic = characteristic
print('RX:', self.rx_characteristic) print('RX:', self.rx_characteristic)
print('TX:', self.tx_characteristic) print('TX:', self.tx_characteristic)
print('PSM:', self.l2cap_psm_characteristic)
if self.l2cap_psm_characteristic: # Subscribe to TX
# Subscribe to and then read the PSM value if self.tx_characteristic:
await self.peer.subscribe(
self.l2cap_psm_characteristic, self.on_l2cap_psm_received
)
psm_bytes = await self.peer.read_value(self.l2cap_psm_characteristic)
psm = struct.unpack('<H', psm_bytes)[0]
await self.connect_l2cap(psm)
elif self.tx_characteristic:
# Subscribe to TX
await self.peer.subscribe(self.tx_characteristic, self.on_tx_received) await self.peer.subscribe(self.tx_characteristic, self.on_tx_received)
print(color('=== Subscribed to Gattlink TX', 'yellow')) print(color('=== Subscribed to Gattlink TX', 'yellow'))
else: else:
print(color('!!! No Gattlink TX or PSM found', 'red')) print(color('!!! Gattlink TX not found', 'red'))
def on_connection_failure(self, error): def on_connection_failure(self, error):
print(color(f'!!! Connection failed: {error}')) print(color(f'!!! Connection failed: {error}'))
def on_disconnection(self, reason): def on_disconnection(self, reason):
print( print(color(f'!!! Disconnected from {self.peer}, reason={HCI_Constant.error_name(reason)}', 'red'))
color(
f'!!! Disconnected from {self.peer}, '
f'reason={HCI_Constant.error_name(reason)}',
'red',
)
)
self.tx_characteristic = None self.tx_characteristic = None
self.rx_characteristic = None self.rx_characteristic = None
self.peer = None self.peer = None
# Called when an L2CAP packet has been received
def on_l2cap_packet(self, packet):
print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(packet)
# Called by the GATT client when a notification is received # Called by the GATT client when a notification is received
def on_tx_received(self, value): def on_tx_received(self, value):
print(color(f'<<< [GATT TX]: {len(value)} bytes', 'cyan')) print(color('>>> TX:', 'magenta'), value.hex())
if self.tx_socket: if self.tx_socket:
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(value) self.tx_socket.sendto(value)
# Called by asyncio when the UDP socket is created
def on_l2cap_psm_received(self, value):
psm = struct.unpack('<H', value)[0]
asyncio.create_task(self.connect_l2cap(psm))
# Called by asyncio when the UDP socket is created # Called by asyncio when the UDP socket is created
def connection_made(self, transport): def connection_made(self, transport):
pass pass
# Called by asyncio when a UDP datagram is received # Called by asyncio when a UDP datagram is received
def datagram_received(self, data, _address): def datagram_received(self, data, address):
print(color(f'<<< [UDP]: {len(data)} bytes', 'green')) print(color('<<< RX:', 'magenta'), data.hex())
if self.l2cap_channel: # TODO: use a queue instead of creating a task everytime
print(color('>>> [L2CAP]', 'yellow')) if self.peer and self.rx_characteristic:
self.l2cap_channel.write(bytes([len(data) - 1]) + data)
elif self.peer and self.rx_characteristic:
print(color('>>> [GATT RX]', 'yellow'))
asyncio.create_task(self.peer.write_value(self.rx_characteristic, data)) asyncio.create_task(self.peer.write_value(self.rx_characteristic, data))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener): class GattlinkNodeBridge(Device.Listener):
def __init__(self, device: Device): def __init__(self):
super().__init__() self.peer = None
self.device = device self.rx_socket = None
self.peer = None
self.tx_socket = None self.tx_socket = None
self.tx_subscriber = None
self.rx_characteristic = None
self.transport = None
# Register as a listener
device.listener = self
# Listen for incoming L2CAP CoC connections
psm = 0xFB
device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(
psm=0xFB,
),
handler=self.on_coc,
)
print(f'### Listening for CoC connection on PSM {psm}')
# Setup the Gattlink service
self.rx_characteristic = Characteristic(
GG_GATTLINK_RX_CHARACTERISTIC_UUID,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=self.on_rx_write),
)
self.tx_characteristic = Characteristic(
GG_GATTLINK_TX_CHARACTERISTIC_UUID,
Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
)
self.tx_characteristic.on('subscription', self.on_tx_subscription)
self.psm_characteristic = Characteristic(
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([psm, 0]),
)
gattlink_service = Service(
GG_GATTLINK_SERVICE_UUID,
[self.rx_characteristic, self.tx_characteristic, self.psm_characteristic],
)
device.add_services([gattlink_service])
device.advertising_data = bytes(
AdvertisingData(
[
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
bytes(
reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))
),
),
]
)
)
async def start(self):
await self.device.start_advertising()
# Called by asyncio when the UDP socket is created # Called by asyncio when the UDP socket is created
def connection_made(self, transport): def connection_made(self, transport):
self.transport = transport pass
# Called by asyncio when a UDP datagram is received # Called by asyncio when a UDP datagram is received
def datagram_received(self, data, _address): def datagram_received(self, data, address):
print(color(f'<<< [UDP]: {len(data)} bytes', 'green')) print(color('<<< RX:', 'magenta'), data.hex())
if self.l2cap_channel: # TODO: use a queue instead of creating a task everytime
print(color('>>> [L2CAP]', 'yellow')) if self.peer and self.rx_characteristic:
self.l2cap_channel.write(bytes([len(data) - 1]) + data) asyncio.create_task(self.peer.write_value(self.rx_characteristic, data))
elif self.tx_subscriber:
print(color('>>> [GATT TX]', 'yellow'))
self.tx_characteristic.value = data
asyncio.create_task(self.device.notify_subscribers(self.tx_characteristic))
# Called when a write to the RX characteristic has been received
def on_rx_write(self, _connection, data):
print(color(f'<<< [GATT RX]: {len(data)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(data)
# Called when the subscription to the TX characteristic has changed
def on_tx_subscription(self, peer, enabled):
print(
f'### [GATT TX] subscription from {peer}: '
f'{"enabled" if enabled else "disabled"}'
)
if enabled:
self.tx_subscriber = peer
else:
self.tx_subscriber = None
# Called when an L2CAP packet is received
def on_l2cap_packet(self, packet):
print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(packet)
# Called when a new connection is established
def on_coc(self, channel):
print('*** CoC Connection', channel)
self.l2cap_channel = channel
channel.sink = self.on_coc_sdu
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def run( async def run(hci_transport, device_address, send_host, send_port, receive_host, receive_port):
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink): async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
# Instantiate a bridge object # Instantiate a bridge object
device = Device.with_hci('Bumble GG', device_address, hci_source, hci_sink) bridge = GattlinkNodeBridge()
# Instantiate a bridge object
if role_or_peer_address == 'node':
bridge = GattlinkNodeBridge(device)
else:
bridge = GattlinkHubBridge(device, role_or_peer_address)
# Create a UDP to RX bridge (receive from UDP, send to RX) # Create a UDP to RX bridge (receive from UDP, send to RX)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
await loop.create_datagram_endpoint( await loop.create_datagram_endpoint(
lambda: bridge, local_addr=(receive_host, receive_port) lambda: bridge,
local_addr=(receive_host, receive_port)
) )
# Create a UDP to TX bridge (receive from TX, send to UDP) # Create a UDP to TX bridge (receive from TX, send to UDP)
bridge.tx_socket, _ = await loop.create_datagram_endpoint( bridge.tx_socket, _ = await loop.create_datagram_endpoint(
asyncio.DatagramProtocol, lambda: asyncio.DatagramProtocol(),
remote_addr=(send_host, send_port), remote_addr=(send_host, send_port)
) )
# Create a device to manage the host, with a custom listener
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
device.listener = bridge
await device.power_on() await device.power_on()
await bridge.start()
# Connect to the peer
# print(f'=== Connecting to {device_address}...')
# await device.connect(device_address)
# TODO move to class
gattlink_service = Service(
GG_GATTLINK_SERVICE_UUID,
[
Characteristic(
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
Characteristic.READ,
Characteristic.READABLE,
bytes([193, 0])
)
]
)
device.add_services([gattlink_service])
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, bytes(reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))))
])
)
await device.start_advertising()
# Wait until the source terminates # Wait until the source terminates
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
@@ -359,44 +197,15 @@ async def run(
@click.command() @click.command()
@click.argument('hci_transport') @click.argument('hci_transport')
@click.argument('device_address') @click.argument('device_address')
@click.argument('role_or_peer_address') @click.option('-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to')
@click.option(
'-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to'
)
@click.option('-sp', '--send-port', type=int, default=9001, help='UDP port to send to') @click.option('-sp', '--send-port', type=int, default=9001, help='UDP port to send to')
@click.option( @click.option('-rh', '--receive-host', type=str, default='127.0.0.1', help='UDP host to receive on')
'-rh', @click.option('-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on')
'--receive-host', def main(hci_transport, device_address, send_host, send_port, receive_host, receive_port):
type=str, logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
default='127.0.0.1', asyncio.run(run(hci_transport, device_address, send_host, send_port, receive_host, receive_port))
help='UDP host to receive on',
)
@click.option(
'-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on'
)
def main(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
):
asyncio.run(
run(
hci_transport,
device_address,
role_or_peer_address,
send_host,
send_port,
receive_host,
receive_port,
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__': if __name__ == '__main__':
main() main()
+11 -31
View File
@@ -34,29 +34,16 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_main(): async def async_main():
if len(sys.argv) < 3: if len(sys.argv) < 3:
print( print('Usage: hci_bridge.py <host-transport-spec> <controller-transport-spec> [command-short-circuit-list]')
'Usage: hci_bridge.py <host-transport-spec> <controller-transport-spec> ' print('example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078')
'[command-short-circuit-list]'
)
print(
'example: python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 '
'serial:/dev/tty.usbmodem0006839912171,1000000 '
'0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078'
)
return return
print('>>> connecting to HCI...') print('>>> connecting to HCI...')
async with await transport.open_transport_or_link(sys.argv[1]) as ( async with await transport.open_transport_or_link(sys.argv[1]) as (hci_host_source, hci_host_sink):
hci_host_source,
hci_host_sink,
):
print('>>> connected') print('>>> connected')
print('>>> connecting to HCI...') print('>>> connecting to HCI...')
async with await transport.open_transport_or_link(sys.argv[2]) as ( async with await transport.open_transport_or_link(sys.argv[2]) as (hci_controller_source, hci_controller_sink):
hci_controller_source,
hci_controller_sink,
):
print('>>> connected') print('>>> connected')
command_short_circuits = [] command_short_circuits = []
@@ -64,43 +51,36 @@ async def async_main():
for op_code_str in sys.argv[3].split(','): for op_code_str in sys.argv[3].split(','):
if ':' in op_code_str: if ':' in op_code_str:
ogf, ocf = op_code_str.split(':') ogf, ocf = op_code_str.split(':')
command_short_circuits.append( command_short_circuits.append(hci.hci_command_op_code(int(ogf, 16), int(ocf, 16)))
hci.hci_command_op_code(int(ogf, 16), int(ocf, 16))
)
else: else:
command_short_circuits.append(int(op_code_str, 16)) command_short_circuits.append(int(op_code_str, 16))
def host_to_controller_filter(hci_packet): def host_to_controller_filter(hci_packet):
if ( if hci_packet.hci_packet_type == hci.HCI_COMMAND_PACKET and hci_packet.op_code in command_short_circuits:
hci_packet.hci_packet_type == hci.HCI_COMMAND_PACKET
and hci_packet.op_code in command_short_circuits
):
# Respond with a success response # Respond with a success response
logger.debug('short-circuiting packet') logger.debug('short-circuiting packet')
response = hci.HCI_Command_Complete_Event( response = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1, num_hci_command_packets = 1,
command_opcode=hci_packet.op_code, command_opcode = hci_packet.op_code,
return_parameters=bytes([hci.HCI_SUCCESS]), return_parameters = bytes([hci.HCI_SUCCESS])
) )
# Return a packet with 'respond to sender' set to True # Return a packet with 'respond to sender' set to True
return (response.to_bytes(), True) return (response.to_bytes(), True)
return None
_ = HCI_Bridge( _ = HCI_Bridge(
hci_host_source, hci_host_source,
hci_host_sink, hci_host_sink,
hci_controller_source, hci_controller_source,
hci_controller_sink, hci_controller_sink,
host_to_controller_filter, host_to_controller_filter,
None, None
) )
await asyncio.get_running_loop().create_future() await asyncio.get_running_loop().create_future()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def main(): def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main()) asyncio.run(async_main())
-361
View File
@@ -1,361 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import click
from bumble import l2cap
from bumble.colors import color
from bumble.transport import open_transport_or_link
from bumble.device import Device
from bumble.utils import FlowControlAsyncPipe
from bumble.hci import HCI_Constant
# -----------------------------------------------------------------------------
class ServerBridge:
"""
L2CAP CoC server bridge: waits for a peer to connect an L2CAP CoC channel
on a specified PSM. When the connection is made, the bridge connects a TCP
socket to a remote host and bridges the data in both directions, with flow
control.
When the L2CAP CoC channel is closed, the bridge disconnects the TCP socket
and waits for a new L2CAP CoC channel to be connected.
When the TCP connection is closed by the TCP server, XXXX
"""
def __init__(self, psm, max_credits, mtu, mps, tcp_host, tcp_port):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device: Device) -> None:
# Listen for incoming L2CAP channel connections
device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(
psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits
),
handler=self.on_channel,
)
print(
color(f'### Listening for channel connection on PSM {self.psm}', 'yellow')
)
def on_ble_connection(connection):
def on_ble_disconnection(reason):
print(
color('@@@ Bluetooth disconnection:', 'red'),
HCI_Constant.error_name(reason),
)
print(color('@@@ Bluetooth connection:', 'green'), connection)
connection.on('disconnection', on_ble_disconnection)
device.on('connection', on_ble_connection)
await device.start_advertising(auto_restart=True)
# Called when a new L2CAP connection is established
def on_channel(self, l2cap_channel):
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
class Pipe:
def __init__(self, bridge, l2cap_channel):
self.bridge = bridge
self.tcp_transport = None
self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_channel_sdu
async def connect_to_tcp(self):
# Connect to the TCP server
print(
color(
f'### Connecting to TCP {self.bridge.tcp_host}:'
f'{self.bridge.tcp_port}...',
'yellow',
)
)
class TcpClientProtocol(asyncio.Protocol):
def __init__(self, pipe):
self.pipe = pipe
def connection_lost(self, exc):
print(color(f'!!! TCP connection lost: {exc}', 'red'))
if self.pipe.l2cap_channel is not None:
asyncio.create_task(self.pipe.l2cap_channel.disconnect())
def data_received(self, data):
print(color(f'<<< [TCP DATA]: {len(data)} bytes', 'blue'))
self.pipe.l2cap_channel.write(data)
try:
(
self.tcp_transport,
_,
) = await asyncio.get_running_loop().create_connection(
lambda: TcpClientProtocol(self),
host=self.bridge.tcp_host,
port=self.bridge.tcp_port,
)
print(color('### Connected', 'green'))
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
await self.l2cap_channel.disconnect()
def on_l2cap_close(self):
print(color('*** L2CAP channel closed', 'red'))
self.l2cap_channel = None
if self.tcp_transport is not None:
self.tcp_transport.close()
def on_channel_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
if self.tcp_transport is None:
print(color('!!! TCP socket not open, dropping', 'red'))
return
self.tcp_transport.write(sdu)
pipe = Pipe(self, l2cap_channel)
asyncio.create_task(pipe.connect_to_tcp())
# -----------------------------------------------------------------------------
class ClientBridge:
"""
L2CAP CoC client bridge: connects to a BLE device, then waits for an inbound
TCP connection on a specified port number. When a TCP client connects, an
L2CAP CoC channel connection to the BLE device is established, and the data
is bridged in both directions, with flow control.
When the TCP connection is closed by the client, the L2CAP CoC channel is
disconnected, but the connection to the BLE device remains, ready for a new
TCP client to connect.
When the L2CAP CoC channel is closed, XXXX
"""
READ_CHUNK_SIZE = 4096
def __init__(self, psm, max_credits, mtu, mps, address, tcp_host, tcp_port):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.address = address
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device):
print(color(f'### Connecting to {self.address}...', 'yellow'))
connection = await device.connect(self.address)
print(color('### Connected', 'green'))
# Called when the BLE connection is disconnected
def on_ble_disconnection(reason):
print(
color('@@@ Bluetooth disconnection:', 'red'),
HCI_Constant.error_name(reason),
)
connection.on('disconnection', on_ble_disconnection)
# Called when a TCP connection is established
async def on_tcp_connection(reader, writer):
peer_name = writer.get_extra_info('peer_name')
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
def on_channel_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
l2cap_to_tcp_pipe.write(sdu)
def on_l2cap_close():
print(color('*** L2CAP channel closed', 'red'))
l2cap_to_tcp_pipe.stop()
writer.close()
# Connect a new L2CAP channel
print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
try:
l2cap_channel = await connection.create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(
psm=self.psm,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
)
)
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
writer.close()
return
l2cap_channel.sink = on_channel_sdu
l2cap_channel.on('close', on_l2cap_close)
# Start a flow control pipe from L2CAP to TCP
l2cap_to_tcp_pipe = FlowControlAsyncPipe(
l2cap_channel.pause_reading,
l2cap_channel.resume_reading,
writer.write,
writer.drain,
)
l2cap_to_tcp_pipe.start()
# Pipe data from TCP to L2CAP
while True:
try:
data = await reader.read(self.READ_CHUNK_SIZE)
if len(data) == 0:
print(color('!!! End of stream', 'red'))
await l2cap_channel.disconnect()
return
print(color(f'<<< [TCP DATA]: {len(data)} bytes', 'blue'))
l2cap_channel.write(data)
await l2cap_channel.drain()
except Exception as error:
print(f'!!! Exception: {error}')
break
writer.close()
print(color('~~~ Bye bye', 'magenta'))
await asyncio.start_server(
on_tcp_connection,
host=self.tcp_host if self.tcp_host != '_' else None,
port=self.tcp_port,
)
print(
color(
f'### Listening for TCP connections on port {self.tcp_port}', 'magenta'
)
)
# -----------------------------------------------------------------------------
async def run(device_config, hci_transport, bridge):
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
# Let's go
await device.power_on()
await bridge.start(device)
# Wait until the transport terminates
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
@click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP', type=int, default=1234)
@click.option(
'--l2cap-max-credits',
help='Maximum L2CAP Credits',
type=click.IntRange(1, 65535),
default=128,
)
@click.option(
'--l2cap-mtu',
help='L2CAP MTU',
type=click.IntRange(
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU,
),
default=1024,
)
@click.option(
'--l2cap-mps',
help='L2CAP MPS',
type=click.IntRange(
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS,
),
default=1024,
)
def cli(
context,
device_config,
hci_transport,
psm,
l2cap_max_credits,
l2cap_mtu,
l2cap_mps,
):
context.ensure_object(dict)
context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport
context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_max_credits
context.obj['mtu'] = l2cap_mtu
context.obj['mps'] = l2cap_mps
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.option('--tcp-host', help='TCP host', default='localhost')
@click.option('--tcp-port', help='TCP port', default=9544)
def server(context, tcp_host, tcp_port):
bridge = ServerBridge(
context.obj['psm'],
context.obj['max_credits'],
context.obj['mtu'],
context.obj['mps'],
tcp_host,
tcp_port,
)
asyncio.run(run(context.obj['device_config'], context.obj['hci_transport'], bridge))
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.argument('bluetooth-address')
@click.option('--tcp-host', help='TCP host', default='_')
@click.option('--tcp-port', help='TCP port', default=9543)
def client(context, bluetooth_address, tcp_host, tcp_port):
bridge = ClientBridge(
context.obj['psm'],
context.obj['max_credits'],
context.obj['mtu'],
context.obj['mps'],
bluetooth_address,
tcp_host,
tcp_port,
)
asyncio.run(run(context.obj['device_config'], context.obj['hci_transport'], bridge))
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__':
cli(obj={}) # pylint: disable=no-value-for-parameter
-577
View File
@@ -1,577 +0,0 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import datetime
import enum
import functools
from importlib import resources
import json
import os
import logging
import pathlib
from typing import Optional, List, cast
import weakref
import struct
import ctypes
import wasmtime
import wasmtime.loader
import liblc3 # type: ignore
import logging
import click
import aiohttp.web
import bumble
from bumble.core import AdvertisingData
from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
from bumble.transport import open_transport
from bumble.profiles import bap
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654
def _sink_pac_record() -> bap.PacRecord:
return bap.PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=(
bap.SupportedSamplingFrequency.FREQ_8000
| bap.SupportedSamplingFrequency.FREQ_16000
| bap.SupportedSamplingFrequency.FREQ_24000
| bap.SupportedSamplingFrequency.FREQ_32000
| bap.SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_count=[1, 2],
min_octets_per_codec_frame=26,
max_octets_per_codec_frame=240,
supported_max_codec_frames_per_sdu=2,
),
)
def _source_pac_record() -> bap.PacRecord:
return bap.PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=(
bap.SupportedSamplingFrequency.FREQ_8000
| bap.SupportedSamplingFrequency.FREQ_16000
| bap.SupportedSamplingFrequency.FREQ_24000
| bap.SupportedSamplingFrequency.FREQ_32000
| bap.SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_count=[1],
min_octets_per_codec_frame=30,
max_octets_per_codec_frame=100,
supported_max_codec_frames_per_sdu=1,
),
)
# -----------------------------------------------------------------------------
# WASM - liblc3
# -----------------------------------------------------------------------------
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Mapping wasmtime memory to linear address
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
)
class Liblc3PcmFormat(enum.IntEnum):
S16 = 0
S24 = 1
S24_3LE = 2
FLOAT = 3
MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)
DECODER_STACK_POINTER = STACK_POINTER
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
encoders: List[int] = []
decoders: List[int] = []
def setup_encoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
encoders[:num_channels] = [
liblc3.lc3_setup_encoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Input sample rate
ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
)
for i in range(num_channels)
]
def setup_decoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
decoders[:num_channels] = [
liblc3.lc3_setup_decoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
)
for i in range(num_channels)
]
def decode(
frame_duration_us: int,
num_channels: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = DECODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
input_bytes_per_frame = input_buffer_size // num_channels
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
* num_channels
)
for i in range(num_channels):
res = liblc3.lc3_decode(
decoders[i],
input_buffer_offset + input_bytes_per_frame * i,
input_bytes_per_frame,
DEFAULT_PCM_FORMAT,
output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
num_channels, # Stride
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
def encode(
sdu_length: int,
num_channels: int,
stride: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = sdu_length
output_frame_size = output_buffer_size // num_channels
for i in range(num_channels):
res = liblc3.lc3_encode(
encoders[i],
DEFAULT_PCM_FORMAT,
input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
stride,
output_frame_size,
output_buffer_offset + output_frame_size * i,
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
async def lc3_source_task(
filename: str,
sdu_length: int,
frame_duration_us: int,
device: Device,
cis_handle: int,
) -> None:
with open(filename, 'rb') as f:
header = f.read(44)
assert header[8:12] == b'WAVE'
pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = (
struct.unpack("<HIIHH", header[22:36])
)
assert pcm_sample_rate == DEFAULT_PCM_SAMPLE_RATE
assert bits_per_sample == DEFAULT_PCM_BYTES_PER_SAMPLE * 8
frame_bytes = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
)
packet_sequence_number = 0
while True:
next_round = datetime.datetime.now() + datetime.timedelta(
microseconds=frame_duration_us
)
pcm_data = f.read(frame_bytes)
sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data)
iso_packet = HCI_IsoDataPacket(
connection_handle=cis_handle,
data_total_length=sdu_length + 4,
packet_sequence_number=packet_sequence_number,
pb_flag=0b10,
packet_status_flag=0,
iso_sdu_length=sdu_length,
iso_sdu_fragment=sdu,
)
device.host.send_hci_packet(iso_packet)
packet_sequence_number += 1
sleep_time = next_round - datetime.datetime.now()
await asyncio.sleep(sleep_time.total_seconds())
# -----------------------------------------------------------------------------
class UiServer:
speaker: weakref.ReferenceType[Speaker]
port: int
def __init__(self, speaker: Speaker, port: int) -> None:
self.speaker = weakref.ref(speaker)
self.port = port
self.channel_socket = None
async def start_http(self) -> None:
"""Start the UI HTTP server."""
app = aiohttp.web.Application()
app.add_routes(
[
aiohttp.web.get('/', self.get_static),
aiohttp.web.get('/index.html', self.get_static),
aiohttp.web.get('/channel', self.get_channel),
]
)
runner = aiohttp.web.AppRunner(app)
await runner.setup()
site = aiohttp.web.TCPSite(runner, 'localhost', self.port)
print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green'))
await site.start()
async def get_static(self, request):
path = request.path
if path == '/':
path = '/index.html'
if path.endswith('.html'):
content_type = 'text/html'
elif path.endswith('.js'):
content_type = 'text/javascript'
elif path.endswith('.css'):
content_type = 'text/css'
elif path.endswith('.svg'):
content_type = 'image/svg+xml'
else:
content_type = 'text/plain'
text = (
resources.files("bumble.apps.lea_unicast")
.joinpath(pathlib.Path(path).relative_to('/'))
.read_text(encoding="utf-8")
)
return aiohttp.web.Response(text=text, content_type=content_type)
async def get_channel(self, request):
ws = aiohttp.web.WebSocketResponse()
await ws.prepare(request)
# Process messages until the socket is closed.
self.channel_socket = ws
async for message in ws:
if message.type == aiohttp.WSMsgType.TEXT:
logger.debug(f'<<< received message: {message.data}')
await self.on_message(message.data)
elif message.type == aiohttp.WSMsgType.ERROR:
logger.debug(
f'channel connection closed with exception {ws.exception()}'
)
self.channel_socket = None
logger.debug('--- channel connection closed')
return ws
async def on_message(self, message_str: str):
# Parse the message as JSON
message = json.loads(message_str)
# Dispatch the message
message_type = message['type']
message_params = message.get('params', {})
handler = getattr(self, f'on_{message_type}_message')
if handler:
await handler(**message_params)
async def on_hello_message(self):
await self.send_message(
'hello',
bumble_version=bumble.__version__,
codec=self.speaker().codec,
streamState=self.speaker().stream_state.name,
)
if connection := self.speaker().connection:
await self.send_message(
'connection',
peer_address=connection.peer_address.to_string(False),
peer_name=connection.peer_name,
)
async def send_message(self, message_type: str, **kwargs) -> None:
if self.channel_socket is None:
return
message = {'type': message_type, 'params': kwargs}
await self.channel_socket.send_json(message)
async def send_audio(self, data: bytes) -> None:
if self.channel_socket is None:
return
try:
await self.channel_socket.send_bytes(data)
except Exception as error:
logger.warning(f'exception while sending audio packet: {error}')
# -----------------------------------------------------------------------------
class Speaker:
def __init__(
self,
device_config_path: Optional[str],
ui_port: int,
transport: str,
lc3_input_file_path: str,
):
self.device_config_path = device_config_path
self.transport = transport
self.lc3_input_file_path = lc3_input_file_path
# Create an HTTP server for the UI
self.ui_server = UiServer(speaker=self, port=ui_port)
async def run(self) -> None:
await self.ui_server.start_http()
async with await open_transport(self.transport) as hci_transport:
# Create a device
if self.device_config_path:
device_config = DeviceConfiguration.from_file(self.device_config_path)
else:
device_config = DeviceConfiguration(
name="Bumble LE Headphone",
class_of_device=0x244418,
keystore="JsonKeyStore",
advertising_interval_min=25,
advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'),
)
device_config.le_enabled = True
device_config.cis_enabled = True
self.device = Device.from_config_with_hci(
device_config, hci_transport.source, hci_transport.sink
)
self.device.add_service(
bap.PublishedAudioCapabilitiesService(
supported_source_context=bap.ContextType(0xFFFF),
available_source_context=bap.ContextType(0xFFFF),
supported_sink_context=bap.ContextType(0xFFFF), # All context types
available_sink_context=bap.ContextType(0xFFFF), # All context types
sink_audio_locations=(
bap.AudioLocation.FRONT_LEFT | bap.AudioLocation.FRONT_RIGHT
),
sink_pac=[_sink_pac_record()],
source_audio_locations=bap.AudioLocation.FRONT_LEFT,
source_pac=[_source_pac_record()],
)
)
ascs = bap.AudioStreamControlService(
self.device, sink_ase_id=[1], source_ase_id=[2]
)
self.device.add_service(ascs)
advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes(device_config.name, 'utf-8'),
),
(
AdvertisingData.FLAGS,
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(bap.PublishedAudioCapabilitiesService.UUID),
),
]
)
) + bytes(bap.UnicastServerAdvertisingData())
def on_pdu(pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine):
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
pcm = decode(
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
pdu.iso_sdu_fragment,
)
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
def on_ase_state_change(ase: bap.AseStateMachine) -> None:
if ase.state == bap.AseStateMachine.State.STREAMING:
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
assert ase.cis_link
if ase.role == bap.AudioRole.SOURCE:
ase.cis_link.abort_on(
'disconnection',
lc3_source_task(
filename=self.lc3_input_file_path,
sdu_length=(
codec_config.codec_frames_per_sdu
* codec_config.octets_per_codec_frame
),
frame_duration_us=codec_config.frame_duration.us,
device=self.device,
cis_handle=ase.cis_link.handle,
),
)
else:
ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
elif ase.state == bap.AseStateMachine.State.CODEC_CONFIGURED:
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
if ase.role == bap.AudioRole.SOURCE:
setup_encoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
)
else:
setup_decoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
)
for ase in ascs.ase_state_machines.values():
ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))
await self.device.power_on()
await self.device.create_advertising_set(
advertising_data=advertising_data,
auto_restart=True,
advertising_parameters=AdvertisingParameters(
primary_advertising_interval_min=100,
primary_advertising_interval_max=100,
),
)
await hci_transport.source.terminated
@click.command()
@click.option(
'--ui-port',
'ui_port',
metavar='HTTP_PORT',
default=DEFAULT_UI_PORT,
show_default=True,
help='HTTP port for the UI server',
)
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.argument('transport')
@click.argument('lc3_file')
def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) -> None:
"""Run the speaker."""
asyncio.run(Speaker(device_config, ui_port, transport, lc3_file).run())
# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
speaker()
# -----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter
-68
View File
@@ -1,68 +0,0 @@
<html data-bs-theme="dark">
<head>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
<script src="https://unpkg.com/pcm-player"></script>
</head>
<body>
<nav class="navbar navbar-dark bg-primary">
<div class="container">
<span class="navbar-brand mb-0 h1">Bumble Unicast Server</span>
</div>
</nav>
<br>
<div class="container">
<button type="button" class="btn btn-danger" id="connect-audio" onclick="connectAudio()">Connect Audio</button>
<button class="btn btn-primary" type="button" disabled>
<span class="spinner-border spinner-border-sm" id="ws-status-spinner" aria-hidden="true"></span>
<span role="status" id="ws-status">WebSocket Connecting...</span>
</button>
</div>
<script>
let player = null;
const wsStatus = document.getElementById("ws-status");
const wsStatusSpinner = document.getElementById("ws-status-spinner");
const socket = new WebSocket('ws://127.0.0.1:7654/channel');
socket.binaryType = "arraybuffer";
socket.onmessage = function (message) {
if (typeof message.data === 'string' || message.data instanceof String) {
console.log(`channel MESSAGE: ${message.data}`);
} else {
console.log(typeof (message.data))
// BINARY audio data.
if (player == null) return;
player.feed(message.data);
}
};
socket.onopen = (message) => {
wsStatusSpinner.remove();
wsStatus.textContent = "WebSocket Connected";
}
socket.onclose = (message) => {
wsStatus.textContent = "WebSocket Disconnected";
}
function connectAudio() {
player = new PCMPlayer({
inputCodec: 'Int16',
channels: 2,
sampleRate: 48000,
flushTime: 10,
});
const button = document.getElementById("connect-audio")
button.disabled = true;
button.textContent = "Audio Connected";
}
</script>
</div>
</body>
</html>
Binary file not shown.
+23 -36
View File
@@ -16,6 +16,7 @@
# Imports # Imports
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------
import sys import sys
import websockets
import logging import logging
import json import json
import asyncio import asyncio
@@ -23,9 +24,7 @@ import argparse
import uuid import uuid
import os import os
from urllib.parse import urlparse from urllib.parse import urlparse
import websockets from colors import color
from bumble.colors import color
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -66,9 +65,9 @@ class Connection:
""" """
def __init__(self, room, websocket): def __init__(self, room, websocket):
self.room = room self.room = room
self.websocket = websocket self.websocket = websocket
self.address = str(uuid.uuid4()) self.address = str(uuid.uuid4())
async def send_message(self, message): async def send_message(self, message):
try: try:
@@ -99,11 +98,7 @@ class Connection:
self.address = address self.address = address
def __str__(self): def __str__(self):
return ( return f'Connection(address="{self.address}", client={self.websocket.remote_address[0]}:{self.websocket.remote_address[1]})'
f'Connection(address="{self.address}", '
f'client={self.websocket.remote_address[0]}:'
f'{self.websocket.remote_address[1]})'
)
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------
@@ -115,9 +110,9 @@ class Room:
""" """
def __init__(self, relay, name): def __init__(self, relay, name):
self.relay = relay self.relay = relay
self.name = name self.name = name
self.observers = [] self.observers = []
self.connections = [] self.connections = []
async def add_connection(self, connection): async def add_connection(self, connection):
@@ -144,15 +139,13 @@ class Room:
# Parse the message to decide how to handle it # Parse the message to decide how to handle it
if message.startswith('@'): if message.startswith('@'):
# This is a targeted message # This is a targetted message
await self.on_targeted_message(connection, message) await self.on_targetted_message(connection, message)
elif message.startswith('/'): elif message.startswith('/'):
# This is an RPC request # This is an RPC request
await self.on_rpc_request(connection, message) await self.on_rpc_request(connection, message)
else: else:
await connection.send_message( await connection.send_message(f'result:{error_to_json("error: invalid message")}')
f'result:{error_to_json("error: invalid message")}'
)
async def broadcast_message(self, sender, message): async def broadcast_message(self, sender, message):
''' '''
@@ -162,9 +155,7 @@ class Room:
async def on_rpc_request(self, connection, message): async def on_rpc_request(self, connection, message):
command, *params = message.split(' ', 1) command, *params = message.split(' ', 1)
if handler := getattr( if handler := getattr(self, f'on_{command[1:].lower().replace("-","_")}_command', None):
self, f'on_{command[1:].lower().replace("-","_")}_command', None
):
try: try:
result = await handler(connection, params) result = await handler(connection, params)
except Exception as error: except Exception as error:
@@ -174,7 +165,7 @@ class Room:
await connection.send_message(result or 'result:{}') await connection.send_message(result or 'result:{}')
async def on_targeted_message(self, connection, message): async def on_targetted_message(self, connection, message):
target, *payload = message.split(' ', 1) target, *payload = message.split(' ', 1)
if not payload: if not payload:
return error_to_json('missing arguments') return error_to_json('missing arguments')
@@ -183,8 +174,7 @@ class Room:
# Determine what targets to send to # Determine what targets to send to
if target == '*': if target == '*':
# Send to all connections in the room except the connection from which the # Send to all connections in the room except the connection from which the message was received
# message was received
connections = [c for c in self.connections if c != connection] connections = [c for c in self.connections if c != connection]
else: else:
connections = self.find_connections_by_address(target) connections = self.find_connections_by_address(target)
@@ -202,9 +192,7 @@ class Room:
current_address = connection.address current_address = connection.address
new_address = params[0] new_address = params[0]
connection.set_address(new_address) connection.set_address(new_address)
await self.broadcast_message( await self.broadcast_message(connection, f'address-changed:from={current_address},to={new_address}')
connection, f'address-changed:from={current_address},to={new_address}'
)
# ---------------------------------------------------------------------------- # ----------------------------------------------------------------------------
@@ -222,10 +210,9 @@ class Relay:
def start(self): def start(self):
logger.info(f'Starting Relay on port {self.port}') logger.info(f'Starting Relay on port {self.port}')
# pylint: disable-next=no-member
return websockets.serve(self.serve, '0.0.0.0', self.port, ping_interval=None) return websockets.serve(self.serve, '0.0.0.0', self.port, ping_interval=None)
async def serve_as_controller(self, connection): async def serve_as_controller(connection):
pass pass
async def serve(self, websocket, path): async def serve(self, websocket, path):
@@ -259,24 +246,24 @@ def main():
print('ERROR: Python 3.6.1 or higher is required') print('ERROR: Python 3.6.1 or higher is required')
sys.exit(1) sys.exit(1)
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Parse arguments # Parse arguments
arg_parser = argparse.ArgumentParser(description='Bumble Link Relay') arg_parser = argparse.ArgumentParser(description='Bumble Link Relay')
arg_parser.add_argument('--log-level', default='INFO', help='logger level') arg_parser.add_argument('--log-level', default='INFO', help='logger level')
arg_parser.add_argument('--log-config', help='logger config file (YAML)') arg_parser.add_argument('--log-config', help='logger config file (YAML)')
arg_parser.add_argument( arg_parser.add_argument('--port',
'--port', type=int, default=DEFAULT_RELAY_PORT, help='Port to listen on' type = int,
) default = DEFAULT_RELAY_PORT,
help = 'Port to listen on')
args = arg_parser.parse_args() args = arg_parser.parse_args()
# Setup logger # Setup logger
if args.log_config: if args.log_config:
from logging import config # pylint: disable=import-outside-toplevel from logging import config
config.fileConfig(args.log_config) config.fileConfig(args.log_config)
else: else:
logging.basicConfig(level=getattr(logging, args.log_level.upper())) logging.basicConfig(level = getattr(logging, args.log_level.upper()))
# Start a relay # Start a relay
relay = Relay(args.port) relay = Relay(args.port)
+118 -314
View File
@@ -19,79 +19,44 @@ import asyncio
import os import os
import logging import logging
import click import click
from prompt_toolkit.shortcuts import PromptSession import aioconsole
from colors import color
from bumble.colors import color
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.pairing import OobData, PairingDelegate, PairingConfig from bumble.smp import PairingDelegate, PairingConfig
from bumble.smp import OobContext, OobLegacyContext
from bumble.smp import error_name as smp_error_name from bumble.smp import error_name as smp_error_name
from bumble.keys import JsonKeyStore from bumble.keys import JsonKeyStore
from bumble.core import ( from bumble.core import ProtocolError
AdvertisingData,
ProtocolError,
BT_LE_TRANSPORT,
BT_BR_EDR_TRANSPORT,
)
from bumble.gatt import ( from bumble.gatt import (
GATT_DEVICE_NAME_CHARACTERISTIC, GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_GENERIC_ACCESS_SERVICE, GATT_GENERIC_ACCESS_SERVICE,
Service, Service,
Characteristic, Characteristic,
CharacteristicValue, CharacteristicValue
) )
from bumble.att import ( from bumble.att import (
ATT_Error, ATT_Error,
ATT_INSUFFICIENT_AUTHENTICATION_ERROR, ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
ATT_INSUFFICIENT_ENCRYPTION_ERROR, ATT_INSUFFICIENT_ENCRYPTION_ERROR
) )
# -----------------------------------------------------------------------------
class Waiter:
instance = None
def __init__(self, linger=False):
self.done = asyncio.get_running_loop().create_future()
self.linger = linger
def terminate(self):
if not self.linger:
self.done.set_result(None)
async def wait_until_terminated(self):
return await self.done
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Delegate(PairingDelegate): class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, do_prompt): def __init__(self, mode, connection, capability_string, prompt):
super().__init__( super().__init__({
io_capability={ 'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY, 'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY, 'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT, 'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT,
'display+yes/no': PairingDelegate.DISPLAY_OUTPUT_AND_YES_NO_INPUT, 'none': PairingDelegate.NO_OUTPUT_NO_INPUT
'none': PairingDelegate.NO_OUTPUT_NO_INPUT, }[capability_string.lower()])
}[capability_string.lower()]
)
self.mode = mode self.mode = mode
self.peer = Peer(connection) self.peer = Peer(connection)
self.peer_name = None self.peer_name = None
self.do_prompt = do_prompt self.prompt = prompt
def print(self, message):
print(color(message, 'yellow'))
async def prompt(self, message):
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
session = PromptSession(message)
response = await session.prompt_async()
return response.lower().strip()
async def update_peer_name(self): async def update_peer_name(self):
if self.peer_name is not None: if self.peer_name is not None:
@@ -106,103 +71,87 @@ class Delegate(PairingDelegate):
self.peer_name = '[?]' self.peer_name = '[?]'
async def accept(self): async def accept(self):
if self.do_prompt: if self.prompt:
await self.update_peer_name() await self.update_peer_name()
# Prompt for acceptance # Wait a bit to allow some of the log lines to print before we prompt
self.print('###-----------------------------------') await asyncio.sleep(1)
self.print(f'### Pairing request from {self.peer_name}')
self.print('###-----------------------------------')
while True:
response = await self.prompt('>>> Accept? ')
# Prompt for acceptance
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing request from {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
while True:
response = await aioconsole.ainput(color('>>> Accept? ', 'yellow'))
response = response.lower().strip()
if response == 'yes': if response == 'yes':
return True return True
elif response == 'no':
if response == 'no':
return False return False
else:
# Accept silently # Accept silently
return True return True
async def compare_numbers(self, number, digits): async def compare_numbers(self, number, digits):
await self.update_peer_name() await self.update_peer_name()
# Prompt for a numeric comparison # Wait a bit to allow some of the log lines to print before we prompt
self.print('###-----------------------------------') await asyncio.sleep(1)
self.print(f'### Pairing with {self.peer_name}')
self.print('###-----------------------------------')
while True:
response = await self.prompt(
f'>>> Does the other device display {number:0{digits}}? '
)
# Prompt for a numeric comparison
print(color('###-----------------------------------', 'yellow'))
print(color(f'### Pairing with {self.peer_name}', 'yellow'))
print(color('###-----------------------------------', 'yellow'))
while True:
response = await aioconsole.ainput(color(f'>>> Does the other device display {number:0{digits}}? ', 'yellow'))
response = response.lower().strip()
if response == 'yes': if response == 'yes':
return True return True
elif response == 'no':
if response == 'no':
return False return False
async def get_number(self): async def get_number(self):
await self.update_peer_name() await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Prompt for a PIN # Prompt for a PIN
while True: while True:
try: try:
self.print('###-----------------------------------') print(color('###-----------------------------------', 'yellow'))
self.print(f'### Pairing with {self.peer_name}') print(color(f'### Pairing with {self.peer_name}', 'yellow'))
self.print('###-----------------------------------') print(color('###-----------------------------------', 'yellow'))
return int(await self.prompt('>>> Enter PIN: ')) return int(await aioconsole.ainput(color('>>> Enter PIN: ', 'yellow')))
except ValueError: except ValueError:
pass pass
async def display_number(self, number, digits): async def display_number(self, number, digits):
await self.update_peer_name() await self.update_peer_name()
# Wait a bit to allow some of the log lines to print before we prompt
await asyncio.sleep(1)
# Display a PIN code # Display a PIN code
self.print('###-----------------------------------') print(color('###-----------------------------------', 'yellow'))
self.print(f'### Pairing with {self.peer_name}') print(color(f'### Pairing with {self.peer_name}', 'yellow'))
self.print(f'### PIN: {number:0{digits}}') print(color(f'### PIN: {number:0{digits}}', 'yellow'))
self.print('###-----------------------------------') print(color('###-----------------------------------', 'yellow'))
async def get_string(self, max_length: int):
await self.update_peer_name()
# Prompt a PIN (for legacy pairing in classic)
self.print('###-----------------------------------')
self.print(f'### Pairing with {self.peer_name}')
self.print('###-----------------------------------')
count = 0
while True:
response = await self.prompt('>>> Enter PIN (1-6 chars):')
if len(response) == 0:
count += 1
if count > 3:
self.print('too many tries, stopping the pairing')
return None
self.print('no PIN was entered, try again')
continue
return response
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_peer_name(peer, mode): async def get_peer_name(peer, mode):
if mode == 'classic': if mode == 'classic':
return await peer.request_name() return await peer.request_name()
else:
# Try to get the peer name from GATT
services = await peer.discover_service(GATT_GENERIC_ACCESS_SERVICE)
if not services:
return None
# Try to get the peer name from GATT values = await peer.read_characteristics_by_uuid(GATT_DEVICE_NAME_CHARACTERISTIC, services[0])
services = await peer.discover_service(GATT_GENERIC_ACCESS_SERVICE) if values:
if not services: return values[0].decode('utf-8')
return None
values = await peer.read_characteristics_by_uuid(
GATT_DEVICE_NAME_CHARACTERISTIC, services[0]
)
if values:
return values[0].decode('utf-8')
return None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -215,12 +164,12 @@ def read_with_error(connection):
if AUTHENTICATION_ERROR_RETURNED[0]: if AUTHENTICATION_ERROR_RETURNED[0]:
return bytes([1]) return bytes([1])
else:
AUTHENTICATION_ERROR_RETURNED[0] = True AUTHENTICATION_ERROR_RETURNED[0] = True
raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR) raise ATT_Error(ATT_INSUFFICIENT_AUTHENTICATION_ERROR)
def write_with_error(connection, _value): def write_with_error(connection, value):
if not connection.is_encrypted: if not connection.is_encrypted:
raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR) raise ATT_Error(ATT_INSUFFICIENT_ENCRYPTION_ERROR)
@@ -234,14 +183,14 @@ def on_connection(connection, request):
print(color(f'<<< Connection: {connection}', 'green')) print(color(f'<<< Connection: {connection}', 'green'))
# Listen for pairing events # Listen for pairing events
connection.on('pairing_start', on_pairing_start) connection.on('pairing_start', on_pairing_start)
connection.on('pairing', lambda keys: on_pairing(connection.peer_address, keys)) connection.on('pairing', on_pairing)
connection.on('pairing_failure', on_pairing_failure) connection.on('pairing_failure', on_pairing_failure)
# Listen for encryption changes # Listen for encryption changes
connection.on( connection.on(
'connection_encryption_change', 'connection_encryption_change',
lambda: on_connection_encryption_change(connection), lambda: on_connection_encryption_change(connection)
) )
# Request pairing if needed # Request pairing if needed
@@ -253,12 +202,7 @@ def on_connection(connection, request):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_connection_encryption_change(connection): def on_connection_encryption_change(connection):
print(color('@@@-----------------------------------', 'blue')) print(color('@@@-----------------------------------', 'blue'))
print( print(color(f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted', 'blue'))
color(
f'@@@ Connection is {"" if connection.is_encrypted else "not"}encrypted',
'blue',
)
)
print(color('@@@-----------------------------------', 'blue')) print(color('@@@-----------------------------------', 'blue'))
@@ -270,12 +214,11 @@ def on_pairing_start():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_pairing(address, keys): def on_pairing(keys):
print(color('***-----------------------------------', 'cyan')) print(color('***-----------------------------------', 'cyan'))
print(color(f'*** Paired! (peer identity={address})', 'cyan')) print(color('*** Paired!', 'cyan'))
keys.print(prefix=color('*** ', 'cyan')) keys.print(prefix=color('*** ', 'cyan'))
print(color('***-----------------------------------', 'cyan')) print(color('***-----------------------------------', 'cyan'))
Waiter.instance.terminate()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -283,7 +226,6 @@ def on_pairing_failure(reason):
print(color('***-----------------------------------', 'red')) print(color('***-----------------------------------', 'red'))
print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red')) print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
print(color('***-----------------------------------', 'red')) print(color('***-----------------------------------', 'red'))
Waiter.instance.terminate()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -292,20 +234,15 @@ async def pair(
sc, sc,
mitm, mitm,
bond, bond,
ctkd,
linger,
io, io,
oob,
prompt, prompt,
request, request,
print_keys, print_keys,
keystore_file, keystore_file,
device_config, device_config,
hci_transport, hci_transport,
address_or_name, address_or_name
): ):
Waiter.instance = Waiter(linger=linger)
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink): async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
@@ -313,38 +250,9 @@ async def pair(
# Create a device to manage the host # Create a device to manage the host
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink) device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
# Expose a GATT characteristic that can be used to trigger pairing by
# responding with an authentication error when read
if mode == 'le':
device.le_enabled = True
device.add_service(
Service(
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
[
Characteristic(
'552957FB-CF1F-4A31-9535-E78847E1A714',
Characteristic.Properties.READ
| Characteristic.Properties.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE,
CharacteristicValue(
read=read_with_error, write=write_with_error
),
)
],
)
)
# Select LE or Classic
if mode == 'classic':
device.classic_enabled = True
device.classic_smp_enabled = ctkd
# Get things going
await device.power_on()
# Set a custom keystore if specified on the command line # Set a custom keystore if specified on the command line
if keystore_file: if keystore_file:
device.keystore = JsonKeyStore.from_device(device, filename=keystore_file) device.keystore = JsonKeyStore(namespace=None, filename=keystore_file)
# Print the existing keys before pairing # Print the existing keys before pairing
if print_keys and device.keystore: if print_keys and device.keystore:
@@ -353,51 +261,44 @@ async def pair(
await device.keystore.print(prefix=color('@@@ ', 'blue')) await device.keystore.print(prefix=color('@@@ ', 'blue'))
print(color('@@@-----------------------------------', 'blue')) print(color('@@@-----------------------------------', 'blue'))
# Create an OOB context if needed # Expose a GATT characteristic that can be used to trigger pairing by
if oob: # responding with an authentication error when read
our_oob_context = OobContext() if mode == 'le':
shared_data = ( device.add_service(
None Service(
if oob == '-' '50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
else OobData.from_ad(AdvertisingData.from_bytes(bytes.fromhex(oob))) [
Characteristic(
'552957FB-CF1F-4A31-9535-E78847E1A714',
Characteristic.READ | Characteristic.WRITE,
Characteristic.READABLE | Characteristic.WRITEABLE,
CharacteristicValue(read=read_with_error, write=write_with_error)
)
]
)
) )
legacy_context = OobLegacyContext()
oob_contexts = PairingConfig.OobConfig( # Select LE or Classic
our_context=our_oob_context, if mode == 'classic':
peer_data=shared_data, device.classic_enabled = True
legacy_context=legacy_context, device.le_enabled = False
)
oob_data = OobData( # Get things going
address=device.random_address, await device.power_on()
shared_data=shared_data,
legacy_context=legacy_context,
)
print(color('@@@-----------------------------------', 'yellow'))
print(color('@@@ OOB Data:', 'yellow'))
print(color(f'@@@ {our_oob_context.share()}', 'yellow'))
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
print(color(f'@@@ HEX: ({bytes(oob_data.to_ad()).hex()})', 'yellow'))
print(color('@@@-----------------------------------', 'yellow'))
else:
oob_contexts = None
# Set up a pairing config factory # Set up a pairing config factory
device.pairing_config_factory = lambda connection: PairingConfig( device.pairing_config_factory = lambda connection: PairingConfig(
sc=sc, sc,
mitm=mitm, mitm,
bonding=bond, bond,
oob=oob_contexts, Delegate(mode, connection, io, prompt)
delegate=Delegate(mode, connection, io, prompt),
) )
# Connect to a peer or wait for a connection # Connect to a peer or wait for a connection
device.on('connection', lambda connection: on_connection(connection, request)) device.on('connection', lambda connection: on_connection(connection, request))
if address_or_name is not None: if address_or_name is not None:
print(color(f'=== Connecting to {address_or_name}...', 'green')) print(color(f'=== Connecting to {address_or_name}...', 'green'))
connection = await device.connect( connection = await device.connect(address_or_name)
address_or_name,
transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
)
if not request: if not request:
try: try:
@@ -405,131 +306,34 @@ async def pair(
await connection.pair() await connection.pair()
else: else:
await connection.authenticate() await connection.authenticate()
return
except ProtocolError as error: except ProtocolError as error:
print(color(f'Pairing failed: {error}', 'red')) print(color(f'Pairing failed: {error}', 'red'))
return
else: else:
if mode == 'le': # Advertise so that peers can find us and connect
# Advertise so that peers can find us and connect await device.start_advertising(auto_restart=True)
await device.start_advertising(auto_restart=True)
else:
# Become discoverable and connectable
await device.set_discoverable(True)
await device.set_connectable(True)
# Run until the user asks to exit await hci_source.wait_for_termination()
await Waiter.instance.wait_until_terminated()
# -----------------------------------------------------------------------------
class LogHandler(logging.Handler):
def __init__(self):
super().__init__()
self.setFormatter(logging.Formatter('%(levelname)s:%(name)s:%(message)s'))
def emit(self, record):
message = self.format(record)
print(message)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@click.command() @click.command()
@click.option( @click.option('--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True)
'--mode', type=click.Choice(['le', 'classic']), default='le', show_default=True @click.option('--sc', type=bool, default=True, help='Use the Secure Connections protocol', show_default=True)
) @click.option('--mitm', type=bool, default=True, help='Request MITM protection', show_default=True)
@click.option( @click.option('--bond', type=bool, default=True, help='Enable bonding', show_default=True)
'--sc', @click.option('--io', type=click.Choice(['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']), default='display+keyboard', show_default=True)
type=bool,
default=True,
help='Use the Secure Connections protocol',
show_default=True,
)
@click.option(
'--mitm', type=bool, default=True, help='Request MITM protection', show_default=True
)
@click.option(
'--bond', type=bool, default=True, help='Enable bonding', show_default=True
)
@click.option(
'--ctkd',
type=bool,
default=True,
help='Enable CTKD',
show_default=True,
)
@click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
@click.option(
'--io',
type=click.Choice(
['keyboard', 'display', 'display+keyboard', 'display+yes/no', 'none']
),
default='display+keyboard',
show_default=True,
)
@click.option(
'--oob',
metavar='<oob-data-hex>',
help=(
'Use OOB pairing with this data from the peer '
'(use "-" to enable OOB without peer data)'
),
)
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request') @click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
@click.option( @click.option('--request', is_flag=True, help='Request that the connecting peer initiate pairing')
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
)
@click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing') @click.option('--print-keys', is_flag=True, help='Print the bond keys before pairing')
@click.option( @click.option('--keystore-file', help='File in which to store the pairing keys')
'--keystore-file',
metavar='<filename>',
help='File in which to store the pairing keys',
)
@click.argument('device-config') @click.argument('device-config')
@click.argument('hci_transport') @click.argument('hci_transport')
@click.argument('address-or-name', required=False) @click.argument('address-or-name', required=False)
def main( def main(mode, sc, mitm, bond, io, prompt, request, print_keys, keystore_file, device_config, hci_transport, address_or_name):
mode, logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
sc, asyncio.run(pair(mode, sc, mitm, bond, io, prompt, request, print_keys, keystore_file, device_config, hci_transport, address_or_name))
mitm,
bond,
ctkd,
linger,
io,
oob,
prompt,
request,
print_keys,
keystore_file,
device_config,
hci_transport,
address_or_name,
):
# Setup logging
log_handler = LogHandler()
root_logger = logging.getLogger()
root_logger.addHandler(log_handler)
root_logger.setLevel(os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Pair
asyncio.run(
pair(
mode,
sc,
mitm,
bond,
ctkd,
linger,
io,
oob,
prompt,
request,
print_keys,
keystore_file,
device_config,
hci_transport,
address_or_name,
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
-51
View File
@@ -1,51 +0,0 @@
import asyncio
import click
import logging
import json
from bumble.pandora import PandoraDevice, Config, serve
from typing import Dict, Any
BUMBLE_SERVER_GRPC_PORT = 7999
ROOTCANAL_PORT_CUTTLEFISH = 7300
@click.command()
@click.option('--grpc-port', help='gRPC port to serve', default=BUMBLE_SERVER_GRPC_PORT)
@click.option(
'--rootcanal-port', help='Rootcanal TCP port', default=ROOTCANAL_PORT_CUTTLEFISH
)
@click.option(
'--transport',
help='HCI transport',
default=f'tcp-client:127.0.0.1:<rootcanal-port>',
)
@click.option(
'--config',
help='Bumble json configuration file',
)
def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> None:
if '<rootcanal-port>' in transport:
transport = transport.replace('<rootcanal-port>', str(rootcanal_port))
bumble_config = retrieve_config(config)
bumble_config.setdefault('transport', transport)
device = PandoraDevice(bumble_config)
server_config = Config()
server_config.load_from_dict(bumble_config.get('server', {}))
logging.basicConfig(level=logging.DEBUG)
asyncio.run(serve(device, config=server_config, port=grpc_port))
def retrieve_config(config: str) -> Dict[str, Any]:
if not config:
return {}
with open(config, 'r') as f:
return json.load(f)
if __name__ == '__main__':
main() # pylint: disable=no-value-for-parameter
-511
View File
@@ -1,511 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import time
from typing import Optional
import click
from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, Connection
from bumble import core
from bumble import hci
from bumble import rfcomm
from bumble import transport
from bumble import utils
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_RFCOMM_UUID = "E6D55659-C8B4-4B85-96BB-B1143AF6D3AE"
DEFAULT_MTU = 4096
DEFAULT_CLIENT_TCP_PORT = 9544
DEFAULT_SERVER_TCP_PORT = 9545
TRACE_MAX_SIZE = 48
# -----------------------------------------------------------------------------
class Tracer:
"""
Trace data buffers transmitted from one endpoint to another, with stats.
"""
def __init__(self, channel_name: str) -> None:
self.channel_name = channel_name
self.last_ts: float = 0.0
def trace_data(self, data: bytes) -> None:
now = time.time()
elapsed_s = now - self.last_ts if self.last_ts else 0
elapsed_ms = int(elapsed_s * 1000)
instant_throughput_kbps = ((len(data) / elapsed_s) / 1000) if elapsed_s else 0.0
hex_str = data[:TRACE_MAX_SIZE].hex() + (
"..." if len(data) > TRACE_MAX_SIZE else ""
)
print(
f"[{self.channel_name}] {len(data):4} bytes "
f"(+{elapsed_ms:4}ms, {instant_throughput_kbps: 7.2f}kB/s) "
f" {hex_str}"
)
self.last_ts = now
# -----------------------------------------------------------------------------
class ServerBridge:
"""
RFCOMM server bridge: waits for a peer to connect an RFCOMM channel.
The RFCOMM channel may be associated with a UUID published in an SDP service
description, or simply be on a system-assigned channel number.
When the connection is made, the bridge connects a TCP socket to a remote host and
bridges the data in both directions, with flow control.
When the RFCOMM channel is closed, the bridge disconnects the TCP socket
and waits for a new channel to be connected.
"""
READ_CHUNK_SIZE = 4096
def __init__(
self, channel: int, uuid: str, trace: bool, tcp_host: str, tcp_port: int
) -> None:
self.device: Optional[Device] = None
self.channel = channel
self.uuid = uuid
self.tcp_host = tcp_host
self.tcp_port = tcp_port
self.rfcomm_channel: Optional[rfcomm.DLC] = None
self.tcp_tracer: Optional[Tracer]
self.rfcomm_tracer: Optional[Tracer]
if trace:
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
self.rfcomm_tracer = Tracer(color("TCP->RFCOMM", "magenta"))
else:
self.rfcomm_tracer = None
self.tcp_tracer = None
async def start(self, device: Device) -> None:
self.device = device
# Create and register a server
rfcomm_server = rfcomm.Server(self.device)
# Listen for incoming DLC connections
self.channel = rfcomm_server.listen(self.on_rfcomm_channel, self.channel)
# Setup the SDP to advertise this channel
service_record_handle = 0x00010001
self.device.sdp_service_records = {
service_record_handle: rfcomm.make_service_sdp_records(
service_record_handle, self.channel, core.UUID(self.uuid)
)
}
# We're ready for a connection
self.device.on("connection", self.on_connection)
await self.set_available(True)
print(
color(
(
f"### Listening for RFCOMM connection on {device.public_address}, "
f"channel {self.channel}"
),
"yellow",
)
)
async def set_available(self, available: bool):
# Become discoverable and connectable
assert self.device
await self.device.set_connectable(available)
await self.device.set_discoverable(available)
def on_connection(self, connection):
print(color(f"@@@ Bluetooth connection: {connection}", "blue"))
connection.on("disconnection", self.on_disconnection)
# Don't accept new connections until we're disconnected
utils.AsyncRunner.spawn(self.set_available(False))
def on_disconnection(self, reason: int):
print(
color("@@@ Bluetooth disconnection:", "red"),
hci.HCI_Constant.error_name(reason),
)
# We're ready for a new connection
utils.AsyncRunner.spawn(self.set_available(True))
# Called when an RFCOMM channel is established
@utils.AsyncRunner.run_in_task()
async def on_rfcomm_channel(self, rfcomm_channel):
print(color("*** RFCOMM channel:", "cyan"), rfcomm_channel)
# Connect to the TCP server
print(
color(
f"### Connecting to TCP {self.tcp_host}:{self.tcp_port}",
"yellow",
)
)
try:
reader, writer = await asyncio.open_connection(self.tcp_host, self.tcp_port)
except OSError:
print(color("!!! Connection failed", "red"))
await rfcomm_channel.disconnect()
return
# Pipe data from RFCOMM to TCP
def on_rfcomm_channel_closed():
print(color("*** RFCOMM channel closed", "cyan"))
writer.close()
def write_rfcomm_data(data):
if self.rfcomm_tracer:
self.rfcomm_tracer.trace_data(data)
writer.write(data)
rfcomm_channel.sink = write_rfcomm_data
rfcomm_channel.on("close", on_rfcomm_channel_closed)
# Pipe data from TCP to RFCOMM
while True:
try:
data = await reader.read(self.READ_CHUNK_SIZE)
if len(data) == 0:
print(color("### TCP end of stream", "yellow"))
if rfcomm_channel.state == rfcomm.DLC.State.CONNECTED:
await rfcomm_channel.disconnect()
return
if self.tcp_tracer:
self.tcp_tracer.trace_data(data)
rfcomm_channel.write(data)
await rfcomm_channel.drain()
except Exception as error:
print(f"!!! Exception: {error}")
break
writer.close()
await writer.wait_closed()
print(color("~~~ Bye bye", "magenta"))
# -----------------------------------------------------------------------------
class ClientBridge:
"""
RFCOMM client bridge: connects to a BR/EDR device, then waits for an inbound
TCP connection on a specified port number. When a TCP client connects, an
RFCOMM connection to the device is established, and the data is bridged in both
directions, with flow control.
When the TCP connection is closed by the client, the RFCOMM channel is
disconnected, but the connection to the device remains, ready for a new TCP client
to connect.
"""
READ_CHUNK_SIZE = 4096
def __init__(
self,
channel: int,
uuid: str,
trace: bool,
address: str,
tcp_host: str,
tcp_port: int,
encrypt: bool,
):
self.channel = channel
self.uuid = uuid
self.trace = trace
self.address = address
self.tcp_host = tcp_host
self.tcp_port = tcp_port
self.encrypt = encrypt
self.device: Optional[Device] = None
self.connection: Optional[Connection] = None
self.rfcomm_client: Optional[rfcomm.Client]
self.rfcomm_mux: Optional[rfcomm.Multiplexer]
self.tcp_connected: bool = False
self.tcp_tracer: Optional[Tracer]
self.rfcomm_tracer: Optional[Tracer]
if trace:
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
self.rfcomm_tracer = Tracer(color("TCP->RFCOMM", "magenta"))
else:
self.rfcomm_tracer = None
self.tcp_tracer = None
async def connect(self) -> None:
if self.connection:
return
print(color(f"@@@ Connecting to Bluetooth {self.address}", "blue"))
assert self.device
self.connection = await self.device.connect(
self.address, transport=core.BT_BR_EDR_TRANSPORT
)
print(color(f"@@@ Bluetooth connection: {self.connection}", "blue"))
self.connection.on("disconnection", self.on_disconnection)
if self.encrypt:
print(color("@@@ Encrypting Bluetooth connection", "blue"))
await self.connection.encrypt()
print(color("@@@ Bluetooth connection encrypted", "blue"))
self.rfcomm_client = rfcomm.Client(self.connection)
try:
self.rfcomm_mux = await self.rfcomm_client.start()
except BaseException as e:
print(color("!!! Failed to setup RFCOMM connection", "red"), e)
raise
async def start(self, device: Device) -> None:
self.device = device
await device.set_connectable(False)
await device.set_discoverable(False)
# Called when a TCP connection is established
async def on_tcp_connection(reader, writer):
print(color("<<< TCP connection", "magenta"))
if self.tcp_connected:
print(
color("!!! TCP connection already active, rejecting new one", "red")
)
writer.close()
return
self.tcp_connected = True
try:
await self.pipe(reader, writer)
except BaseException as error:
print(color("!!! Exception while piping data:", "red"), error)
return
finally:
writer.close()
await writer.wait_closed()
self.tcp_connected = False
await asyncio.start_server(
on_tcp_connection,
host=self.tcp_host if self.tcp_host != "_" else None,
port=self.tcp_port,
)
print(
color(
f"### Listening for TCP connections on port {self.tcp_port}", "magenta"
)
)
async def pipe(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
# Resolve the channel number from the UUID if needed
if self.channel == 0:
await self.connect()
assert self.connection
channel = await rfcomm.find_rfcomm_channel_with_uuid(
self.connection, self.uuid
)
if channel:
print(color(f"### Found RFCOMM channel {channel}", "yellow"))
else:
print(color(f"!!! RFCOMM channel with UUID {self.uuid} not found"))
return
else:
channel = self.channel
# Connect a new RFCOMM channel
await self.connect()
assert self.rfcomm_mux
print(color(f"*** Opening RFCOMM channel {channel}", "green"))
try:
rfcomm_channel = await self.rfcomm_mux.open_dlc(channel)
print(color(f"*** RFCOMM channel open: {rfcomm_channel}", "green"))
except Exception as error:
print(color(f"!!! RFCOMM open failed: {error}", "red"))
return
# Pipe data from RFCOMM to TCP
def on_rfcomm_channel_closed():
print(color("*** RFCOMM channel closed", "green"))
def write_rfcomm_data(data):
if self.trace:
self.rfcomm_tracer.trace_data(data)
writer.write(data)
rfcomm_channel.on("close", on_rfcomm_channel_closed)
rfcomm_channel.sink = write_rfcomm_data
# Pipe data from TCP to RFCOMM
while True:
try:
data = await reader.read(self.READ_CHUNK_SIZE)
if len(data) == 0:
print(color("### TCP end of stream", "yellow"))
if rfcomm_channel.state == rfcomm.DLC.State.CONNECTED:
await rfcomm_channel.disconnect()
self.tcp_connected = False
return
if self.tcp_tracer:
self.tcp_tracer.trace_data(data)
rfcomm_channel.write(data)
await rfcomm_channel.drain()
except Exception as error:
print(f"!!! Exception: {error}")
break
print(color("~~~ Bye bye", "magenta"))
def on_disconnection(self, reason: int) -> None:
print(
color("@@@ Bluetooth disconnection:", "red"),
hci.HCI_Constant.error_name(reason),
)
self.connection = None
# -----------------------------------------------------------------------------
async def run(device_config, hci_transport, bridge):
print("<<< connecting to HCI...")
async with await transport.open_transport_or_link(hci_transport) as (
hci_source,
hci_sink,
):
print("<<< connected")
if device_config:
device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else:
device = Device.from_config_with_hci(
DeviceConfiguration(), hci_source, hci_sink
)
device.classic_enabled = True
# Let's go
await device.power_on()
try:
await bridge.start(device)
# Wait until the transport terminates
await hci_source.wait_for_termination()
except core.ConnectionError as error:
print(color(f"!!! Bluetooth connection failed: {error}", "red"))
except Exception as error:
print(f"Exception while running bridge: {error}")
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
@click.option(
"--device-config",
metavar="CONFIG_FILE",
help="Device configuration file",
)
@click.option(
"--hci-transport", metavar="TRANSPORT_NAME", help="HCI transport", required=True
)
@click.option("--trace", is_flag=True, help="Trace bridged data to stdout")
@click.option(
"--channel",
metavar="CHANNEL_NUMER",
help="RFCOMM channel number",
type=int,
default=0,
)
@click.option(
"--uuid",
metavar="UUID",
help="UUID for the RFCOMM channel",
default=DEFAULT_RFCOMM_UUID,
)
def cli(
context,
device_config,
hci_transport,
trace,
channel,
uuid,
):
context.ensure_object(dict)
context.obj["device_config"] = device_config
context.obj["hci_transport"] = hci_transport
context.obj["trace"] = trace
context.obj["channel"] = channel
context.obj["uuid"] = uuid
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.option("--tcp-host", help="TCP host", default="localhost")
@click.option("--tcp-port", help="TCP port", default=DEFAULT_SERVER_TCP_PORT)
def server(context, tcp_host, tcp_port):
bridge = ServerBridge(
context.obj["channel"],
context.obj["uuid"],
context.obj["trace"],
tcp_host,
tcp_port,
)
asyncio.run(run(context.obj["device_config"], context.obj["hci_transport"], bridge))
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.argument("bluetooth-address")
@click.option("--tcp-host", help="TCP host", default="_")
@click.option("--tcp-port", help="TCP port", default=DEFAULT_CLIENT_TCP_PORT)
@click.option("--encrypt", is_flag=True, help="Encrypt the connection")
def client(context, bluetooth_address, tcp_host, tcp_port, encrypt):
bridge = ClientBridge(
context.obj["channel"],
context.obj["uuid"],
context.obj["trace"],
bluetooth_address,
tcp_host,
tcp_port,
encrypt,
)
asyncio.run(run(context.obj["device_config"], context.obj["hci_transport"], bridge))
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get("BUMBLE_LOGLEVEL", "WARNING").upper())
if __name__ == "__main__":
cli(obj={}) # pylint: disable=no-value-for-parameter
+40 -143
View File
@@ -19,20 +19,20 @@ import asyncio
import os import os
import logging import logging
import click import click
from colors import color
from bumble.colors import color
from bumble.device import Device from bumble.device import Device
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.keys import JsonKeyStore from bumble.keys import JsonKeyStore
from bumble.smp import AddressResolver from bumble.smp import AddressResolver
from bumble.device import Advertisement from bumble.hci import HCI_LE_Advertising_Report_Event
from bumble.hci import Address, HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY from bumble.core import AdvertisingData
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_rssi_bar(rssi): def make_rssi_bar(rssi):
DISPLAY_MIN_RSSI = -105 DISPLAY_MIN_RSSI = -105
DISPLAY_MAX_RSSI = -30 DISPLAY_MAX_RSSI = -30
DEFAULT_RSSI_BAR_WIDTH = 30 DEFAULT_RSSI_BAR_WIDTH = 30
blocks = ['', '', '', '', '', '', '', ''] blocks = ['', '', '', '', '', '', '', '']
@@ -48,33 +48,23 @@ class AdvertisementPrinter:
self.min_rssi = min_rssi self.min_rssi = min_rssi
self.resolver = resolver self.resolver = resolver
def print_advertisement(self, advertisement): def print_advertisement(self, address, address_color, ad_data, rssi):
address = advertisement.address if self.min_rssi is not None and rssi < self.min_rssi:
address_color = 'yellow' if advertisement.is_connectable else 'red'
if self.min_rssi is not None and advertisement.rssi < self.min_rssi:
return return
address_qualifier = '' address_qualifier = ''
resolution_qualifier = '' resolution_qualifier = ''
if self.resolver and advertisement.address.is_resolvable: if self.resolver and address.is_resolvable:
resolved = self.resolver.resolve(advertisement.address) resolved = self.resolver.resolve(address)
if resolved is not None: if resolved is not None:
resolution_qualifier = f'(resolved from {advertisement.address})' resolution_qualifier = f'(resolved from {address})'
address = resolved address = resolved
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[ address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[address.address_type]
address.address_type if address.is_public:
] type_color = 'cyan'
if address.address_type in (
Address.RANDOM_IDENTITY_ADDRESS,
Address.PUBLIC_IDENTITY_ADDRESS,
):
type_color = 'yellow'
else: else:
if address.is_public: if address.is_static:
type_color = 'cyan'
elif address.is_static:
type_color = 'green' type_color = 'green'
address_qualifier = '(static)' address_qualifier = '(static)'
elif address.is_resolvable: elif address.is_resolvable:
@@ -84,32 +74,18 @@ class AdvertisementPrinter:
type_color = 'blue' type_color = 'blue'
address_qualifier = '(non-resolvable)' address_qualifier = '(non-resolvable)'
rssi_bar = make_rssi_bar(rssi)
separator = '\n ' separator = '\n '
rssi_bar = make_rssi_bar(advertisement.rssi) print(f'>>> {color(address, address_color)} [{color(address_type_string, type_color)}]{address_qualifier}{resolution_qualifier}:{separator}RSSI:{rssi:4} {rssi_bar}{separator}{ad_data.to_string(separator)}\n')
if not advertisement.is_legacy:
phy_info = (
f'PHY: {HCI_Constant.le_phy_name(advertisement.primary_phy)}/'
f'{HCI_Constant.le_phy_name(advertisement.secondary_phy)} '
f'{separator}'
)
else:
phy_info = ''
print( def on_advertisement(self, address, ad_data, rssi, connectable):
f'>>> {color(address, address_color)} ' address_color = 'yellow' if connectable else 'red'
f'[{color(address_type_string, type_color)}]{address_qualifier}' self.print_advertisement(address, address_color, ad_data, rssi)
f'{resolution_qualifier}:{separator}'
f'{phy_info}'
f'RSSI:{advertisement.rssi:4} {rssi_bar}{separator}'
f'{advertisement.data.to_string(separator)}\n'
)
def on_advertisement(self, advertisement): def on_advertising_report(self, address, ad_data, rssi, event_type):
self.print_advertisement(advertisement) print(f'{color("EVENT", "green")}: {HCI_LE_Advertising_Report_Event.event_type_name(event_type)}')
ad_data = AdvertisingData.from_bytes(ad_data)
def on_advertising_report(self, report): self.print_advertisement(address, 'yellow', ad_data, rssi)
print(f'{color("EVENT", "green")}: {report.event_type_string()}')
self.print_advertisement(Advertisement.from_advertising_report(report))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -118,49 +94,30 @@ async def scan(
passive, passive,
scan_interval, scan_interval,
scan_window, scan_window,
phy,
filter_duplicates, filter_duplicates,
raw, raw,
irks,
keystore_file, keystore_file,
device_config, device_config,
transport, transport
): ):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(transport) as (hci_source, hci_sink): async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
if device_config: if device_config:
device = Device.from_config_file_with_hci( device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
device_config, hci_source, hci_sink
)
else: else:
device = Device.with_hci( device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
)
await device.power_on()
if keystore_file: if keystore_file:
device.keystore = JsonKeyStore.from_device(device, filename=keystore_file) keystore = JsonKeyStore(namespace=None, filename=keystore_file)
device.keystore = keystore
else:
resolver = None
if device.keystore: if device.keystore:
resolving_keys = await device.keystore.get_resolving_keys() resolving_keys = await device.keystore.get_resolving_keys()
else: resolver = AddressResolver(resolving_keys)
resolving_keys = []
for irk_and_address in irks:
if ':' not in irk_and_address:
raise ValueError('invalid IRK:ADDRESS value')
irk_hex, address_str = irk_and_address.split(':', 1)
resolving_keys.append(
(
bytes.fromhex(irk_hex),
Address(address_str, Address.RANDOM_DEVICE_ADDRESS),
)
)
resolver = AddressResolver(resolving_keys) if resolving_keys else None
printer = AdvertisementPrinter(min_rssi, resolver) printer = AdvertisementPrinter(min_rssi, resolver)
if raw: if raw:
@@ -168,17 +125,12 @@ async def scan(
else: else:
device.on('advertisement', printer.on_advertisement) device.on('advertisement', printer.on_advertisement)
if phy is None: await device.power_on()
scanning_phys = [HCI_LE_1M_PHY, HCI_LE_CODED_PHY]
else:
scanning_phys = [{'1m': HCI_LE_1M_PHY, 'coded': HCI_LE_CODED_PHY}[phy]]
await device.start_scanning( await device.start_scanning(
active=(not passive), active=(not passive),
scan_interval=scan_interval, scan_interval=scan_interval,
scan_window=scan_window, scan_window=scan_window,
filter_duplicates=filter_duplicates, filter_duplicates=filter_duplicates
scanning_phys=scanning_phys,
) )
await hci_source.wait_for_termination() await hci_source.wait_for_termination()
@@ -190,69 +142,14 @@ async def scan(
@click.option('--passive', is_flag=True, default=False, help='Perform passive scanning') @click.option('--passive', is_flag=True, default=False, help='Perform passive scanning')
@click.option('--scan-interval', type=int, default=60, help='Scan interval') @click.option('--scan-interval', type=int, default=60, help='Scan interval')
@click.option('--scan-window', type=int, default=60, help='Scan window') @click.option('--scan-window', type=int, default=60, help='Scan window')
@click.option( @click.option('--filter-duplicates', type=bool, default=True, help='Filter duplicates at the controller level')
'--phy', type=click.Choice(['1m', 'coded']), help='Only scan on the specified PHY' @click.option('--raw', is_flag=True, default=False, help='Listen for raw advertising reports instead of processed ones')
) @click.option('--keystore-file', help='Keystore file to use when resolving addresses')
@click.option( @click.option('--device-config', help='Device config file for the scanning device')
'--filter-duplicates',
type=bool,
default=True,
help='Filter duplicates at the controller level',
)
@click.option(
'--raw',
is_flag=True,
default=False,
help='Listen for raw advertising reports instead of processed ones',
)
@click.option(
'--irk',
metavar='<IRK_HEX>:<ADDRESS>',
help=(
'Use this IRK for resolving private addresses ' '(may be used more than once)'
),
multiple=True,
)
@click.option(
'--keystore-file',
metavar='FILE_PATH',
help='Keystore file to use when resolving addresses',
)
@click.option(
'--device-config',
metavar='FILE_PATH',
help='Device config file for the scanning device',
)
@click.argument('transport') @click.argument('transport')
def main( def main(min_rssi, passive, scan_interval, scan_window, filter_duplicates, raw, keystore_file, device_config, transport):
min_rssi, logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
passive, asyncio.run(scan(min_rssi, passive, scan_interval, scan_window, filter_duplicates, raw, keystore_file, device_config, transport))
scan_interval,
scan_window,
phy,
filter_duplicates,
raw,
irk,
keystore_file,
device_config,
transport,
):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(
scan(
min_rssi,
passive,
scan_interval,
scan_window,
phy,
filter_duplicates,
raw,
irk,
keystore_file,
device_config,
transport,
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+31 -101
View File
@@ -15,90 +15,54 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import datetime
import logging
import os
import struct import struct
import click import click
from colors import color
from bumble.colors import color
from bumble import hci from bumble import hci
from bumble.transport.common import PacketReader from bumble.transport.common import PacketReader
from bumble.helpers import PacketTracer from bumble.helpers import PacketTracer
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class SnoopPacketReader: class SnoopPacketReader:
''' '''
Reader that reads HCI packets from a "snoop" file (based on RFC 1761, but not Reader that reads HCI packets from a "snoop" file (based on RFC 1761, but not exactly the same...)
exactly the same...)
''' '''
DATALINK_H1 = 1001 DATALINK_H1 = 1001
DATALINK_H4 = 1002 DATALINK_H4 = 1002
DATALINK_BSCP = 1003 DATALINK_BSCP = 1003
DATALINK_H5 = 1004 DATALINK_H5 = 1004
IDENTIFICATION_PATTERN = b'btsnoop\0'
TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1)
TIMESTAMP_DELTA = 0x00E03AB44A676000
ONE_MICROSECOND = datetime.timedelta(microseconds=1)
def __init__(self, source): def __init__(self, source):
self.source = source self.source = source
self.at_end = False
# Read the header # Read the header
identification_pattern = source.read(8) identification_pattern = source.read(8)
if identification_pattern != self.IDENTIFICATION_PATTERN: if identification_pattern.hex().lower() != '6274736e6f6f7000':
raise ValueError( raise ValueError('not a valid snoop file, unexpected identification pattern')
'not a valid snoop file, unexpected identification pattern' (self.version_number, self.data_link_type) = struct.unpack('>II', source.read(8))
) if self.data_link_type != self.DATALINK_H4 and self.data_link_type != self.DATALINK_H1:
(self.version_number, self.data_link_type) = struct.unpack(
'>II', source.read(8)
)
if self.data_link_type not in (self.DATALINK_H4, self.DATALINK_H1):
raise ValueError(f'datalink type {self.data_link_type} not supported') raise ValueError(f'datalink type {self.data_link_type} not supported')
def next_packet(self): def next_packet(self):
# Read the record header # Read the record header
header = self.source.read(24) header = self.source.read(24)
if len(header) < 24: if len(header) < 24:
self.at_end = True return (0, None)
return (None, 0, None)
# Parse the header
( (
original_length, original_length,
included_length, included_length,
packet_flags, packet_flags,
_cumulative_drops, cumulative_drops,
timestamp, timestamp_seconds,
) = struct.unpack('>IIIIQ', header) timestamp_microsecond
) = struct.unpack('>IIIIII', header)
# Skip truncated packets # Abort on truncated packets
if original_length != included_length: if original_length != included_length:
print( return (0, None)
color(
f"!!! truncated packet ({included_length}/{original_length})", "red"
)
)
self.source.read(included_length)
return (None, 0, None)
# Convert the timestamp to a datetime object.
ts_dt = self.TIMESTAMP_ANCHOR + datetime.timedelta(
microseconds=timestamp - self.TIMESTAMP_DELTA
)
if self.data_link_type == self.DATALINK_H1: if self.data_link_type == self.DATALINK_H1:
# The packet is un-encapsulated, look at the flags to figure out its type # The packet is un-encapsulated, look at the flags to figure out its type
@@ -115,76 +79,42 @@ class SnoopPacketReader:
else: else:
packet_type = hci.HCI_ACL_DATA_PACKET packet_type = hci.HCI_ACL_DATA_PACKET
return ( return (packet_flags & 1, bytes([packet_type]) + self.source.read(included_length))
packet_flags & 1, else:
bytes([packet_type]) + self.source.read(included_length), return (packet_flags & 1, self.source.read(included_length))
)
return (ts_dt, packet_flags & 1, self.source.read(included_length))
# -----------------------------------------------------------------------------
class Printer:
def __init__(self):
self.index = 0
def print(self, message: str) -> None:
self.index += 1
print(f"[{self.index:8}]{message}")
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Main # Main
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@click.command() @click.command()
@click.option( @click.option('--format', type=click.Choice(['h4', 'snoop']), default='h4', help='Format of the input file')
'--format',
type=click.Choice(['h4', 'snoop']),
default='h4',
help='Format of the input file',
)
@click.option(
'--vendors',
type=click.Choice(['android', 'zephyr']),
multiple=True,
help='Support vendor-specific commands (list one or more)',
)
@click.argument('filename') @click.argument('filename')
# pylint: disable=redefined-builtin def show(format, filename):
def main(format, vendors, filename):
for vendor in vendors:
if vendor == 'android':
import bumble.vendor.android.hci
elif vendor == 'zephyr':
import bumble.vendor.zephyr.hci
input = open(filename, 'rb') input = open(filename, 'rb')
if format == 'h4': if format == 'h4':
packet_reader = PacketReader(input) packet_reader = PacketReader(input)
def read_next_packet(): def read_next_packet():
return (None, 0, packet_reader.next_packet()) (0, packet_reader.next_packet())
else: else:
packet_reader = SnoopPacketReader(input) packet_reader = SnoopPacketReader(input)
read_next_packet = packet_reader.next_packet read_next_packet = packet_reader.next_packet
printer = Printer() tracer = PacketTracer(emit_message=print)
tracer = PacketTracer(emit_message=printer.print)
while not packet_reader.at_end: while True:
try: try:
(timestamp, direction, packet) = read_next_packet() (direction, packet) = read_next_packet()
if packet: if packet is None:
tracer.trace(hci.HCI_Packet.from_bytes(packet), direction, timestamp) break
else: tracer.trace(hci.HCI_Packet.from_bytes(packet), direction)
printer.print(color("[TRUNCATED]", "red"))
except Exception as error: except Exception as error:
logger.exception()
print(color(f'!!! {error}', 'red')) print(color(f'!!! {error}', 'red'))
pass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) show()
main() # pylint: disable=no-value-for-parameter
View File
-42
View File
@@ -1,42 +0,0 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?> <!-- Created with Vectornator for iOS (http://vectornator.io/) --><!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
<svg height="100%" style="fill-rule:nonzero;clip-rule:evenodd;stroke-linecap:round;stroke-linejoin:round;" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" xml:space="preserve" width="100%" xmlns:vectornator="http://vectornator.io" version="1.1" viewBox="0 0 745 744.634">
<metadata>
<vectornator:setting key="DimensionsVisible" value="1"/>
<vectornator:setting key="PencilOnly" value="0"/>
<vectornator:setting key="SnapToPoints" value="0"/>
<vectornator:setting key="OutlineMode" value="0"/>
<vectornator:setting key="CMYKEnabledKey" value="0"/>
<vectornator:setting key="RulersVisible" value="1"/>
<vectornator:setting key="SnapToEdges" value="0"/>
<vectornator:setting key="GuidesVisible" value="1"/>
<vectornator:setting key="DisplayWhiteBackground" value="0"/>
<vectornator:setting key="doHistoryDisabled" value="0"/>
<vectornator:setting key="SnapToGuides" value="1"/>
<vectornator:setting key="TimeLapseWatermarkDisabled" value="0"/>
<vectornator:setting key="Units" value="Pixels"/>
<vectornator:setting key="DynamicGuides" value="0"/>
<vectornator:setting key="IsolateActiveLayer" value="0"/>
<vectornator:setting key="SnapToGrid" value="0"/>
</metadata>
<defs/>
<g id="Layer 1" vectornator:layerName="Layer 1">
<path stroke="#000000" stroke-width="18.6464" d="M368.753+729.441L58.8847+550.539L58.8848+192.734L368.753+13.8313L678.621+192.734L678.621+550.539L368.753+729.441Z" fill="#0082fc" stroke-linecap="butt" fill-opacity="0.307489" opacity="1" stroke-linejoin="round"/>
<g opacity="1">
<g opacity="1">
<path stroke="#000000" stroke-width="20" d="M292.873+289.256L442.872+289.256L442.872+539.254L292.873+539.254L292.873+289.256Z" fill="#fcd100" stroke-linecap="butt" opacity="1" stroke-linejoin="round"/>
<path stroke="#000000" stroke-width="20" d="M292.873+289.256C292.873+247.835+326.452+214.257+367.873+214.257C409.294+214.257+442.872+247.835+442.872+289.256C442.872+330.677+409.294+364.256+367.873+364.256C326.452+364.256+292.873+330.677+292.873+289.256Z" fill="#fcd100" stroke-linecap="butt" opacity="1" stroke-linejoin="round"/>
<path stroke="#000000" stroke-width="20" d="M292.873+539.254C292.873+497.833+326.452+464.255+367.873+464.255C409.294+464.255+442.872+497.833+442.872+539.254C442.872+580.675+409.294+614.254+367.873+614.254C326.452+614.254+292.873+580.675+292.873+539.254Z" fill="#fcd100" stroke-linecap="butt" opacity="1" stroke-linejoin="round"/>
<path stroke="#0082fc" stroke-width="0.1" d="M302.873+289.073L432.872+289.073L432.872+539.072L302.873+539.072L302.873+289.073Z" fill="#fcd100" stroke-linecap="butt" opacity="1" stroke-linejoin="round"/>
</g>
<path stroke="#000000" stroke-width="0.1" d="M103.161+309.167L226.956+443.903L366.671+309.604L103.161+309.167Z" fill="#0082fc" stroke-linecap="round" opacity="1" stroke-linejoin="round"/>
<path stroke="#000000" stroke-width="0.1" d="M383.411+307.076L508.887+440.112L650.5+307.507L383.411+307.076Z" fill="#0082fc" stroke-linecap="round" opacity="1" stroke-linejoin="round"/>
<path stroke="#000000" stroke-width="20" d="M522.045+154.808L229.559+448.882L83.8397+300.104L653.666+302.936L511.759+444.785L223.101+156.114" fill="none" stroke-linecap="round" opacity="1" stroke-linejoin="round"/>
<path stroke="#000000" stroke-width="61.8698" d="M295.857+418.738L438.9+418.738" fill="none" stroke-linecap="butt" opacity="1" stroke-linejoin="round"/>
<path stroke="#000000" stroke-width="61.8698" d="M295.857+521.737L438.9+521.737" fill="none" stroke-linecap="butt" opacity="1" stroke-linejoin="round"/>
<g opacity="1">
<path stroke="#0082fc" stroke-width="0.1" d="M367.769+667.024L367.821+616.383L403.677+616.336C383.137+626.447+368.263+638.69+367.769+667.024Z" fill="#000000" stroke-linecap="butt" opacity="1" stroke-linejoin="round"/>
<path stroke="#0082fc" stroke-width="0.1" d="M367.836+667.024L367.784+616.383L331.928+616.336C352.468+626.447+367.341+638.69+367.836+667.024Z" fill="#000000" stroke-linecap="butt" opacity="1" stroke-linejoin="round"/>
</g>
</g>
</g>
</svg>

Before

Width:  |  Height:  |  Size: 4.1 KiB

-76
View File
@@ -1,76 +0,0 @@
body, h1, h2, h3, h4, h5, h6 {
font-family: sans-serif;
}
#controlsDiv {
margin: 6px;
}
#connectionText {
background-color: rgb(239, 89, 75);
border: none;
border-radius: 4px;
padding: 8px;
display: inline-block;
margin: 4px;
}
#startButton {
padding: 4px;
margin: 6px;
}
#fftCanvas {
border-radius: 16px;
margin: 6px;
}
#bandwidthCanvas {
border: grey;
border-style: solid;
border-radius: 8px;
margin: 6px;
}
#streamStateText {
background-color: rgb(93, 165, 93);
border: none;
border-radius: 8px;
padding: 10px 20px;
display: inline-block;
margin: 6px;
}
#connectionStateText {
background-color: rgb(112, 146, 206);
border: none;
border-radius: 8px;
padding: 10px 20px;
display: inline-block;
margin: 6px;
}
#propertiesTable {
border: grey;
border-style: solid;
border-radius: 4px;
padding: 4px;
margin: 6px;
margin-left: 0;
}
th, td {
padding-left: 6px;
padding-right: 6px;
}
.properties td:nth-child(even) {
background-color: #d6eeee;
font-family: monospace;
}
.properties td:nth-child(odd) {
font-weight: bold;
}
.properties tr td:nth-child(2) { width: 150px; }
-34
View File
@@ -1,34 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<title>Bumble Speaker</title>
<script src="speaker.js"></script>
<link rel="stylesheet" href="speaker.css">
</head>
<body>
<h1><img src="logo.svg" width=100 height=100 style="vertical-align:middle" alt=""/>Bumble Virtual Speaker</h1>
<div id="connectionText"></div>
<div id="speaker">
<table><tr>
<td>
<table id="propertiesTable" class="properties">
<tr><td>Codec</td><td><span id="codecText"></span></td></tr>
<tr><td>Packets</td><td><span id="packetsReceivedText"></span></td></tr>
<tr><td>Bytes</td><td><span id="bytesReceivedText"></span></td></tr>
</table>
</td>
<td>
<canvas id="bandwidthCanvas" width="500", height="100">Bandwidth Graph</canvas>
</td>
</tr></table>
<span id="streamStateText">IDLE</span>
<span id="connectionStateText">NOT CONNECTED</span>
<div id="controlsDiv">
<button id="audioOnButton">Audio On</button>
<span id="audioSupportMessageText"></span>
</div>
<canvas id="fftCanvas" width="1024", height="300">Audio Frequencies Animation</canvas>
<audio id="audio"></audio>
</div>
</body>
</html>
-315
View File
@@ -1,315 +0,0 @@
(function () {
'use strict';
const channelUrl = ((window.location.protocol === "https:") ? "wss://" : "ws://") + window.location.host + "/channel";
let channelSocket;
let connectionText;
let codecText;
let packetsReceivedText;
let bytesReceivedText;
let streamStateText;
let connectionStateText;
let controlsDiv;
let audioOnButton;
let mediaSource;
let sourceBuffer;
let audioElement;
let audioContext;
let audioAnalyzer;
let audioFrequencyBinCount;
let audioFrequencyData;
let packetsReceived = 0;
let bytesReceived = 0;
let audioState = "stopped";
let streamState = "IDLE";
let audioSupportMessageText;
let fftCanvas;
let fftCanvasContext;
let bandwidthCanvas;
let bandwidthCanvasContext;
let bandwidthBinCount;
let bandwidthBins = [];
const FFT_WIDTH = 800;
const FFT_HEIGHT = 256;
const BANDWIDTH_WIDTH = 500;
const BANDWIDTH_HEIGHT = 100;
function hexToBytes(hex) {
return Uint8Array.from(hex.match(/.{1,2}/g).map((byte) => parseInt(byte, 16)));
}
function init() {
initUI();
initMediaSource();
initAudioElement();
initAnalyzer();
connect();
}
function initUI() {
controlsDiv = document.getElementById("controlsDiv");
controlsDiv.style.visibility = "hidden";
connectionText = document.getElementById("connectionText");
audioOnButton = document.getElementById("audioOnButton");
codecText = document.getElementById("codecText");
packetsReceivedText = document.getElementById("packetsReceivedText");
bytesReceivedText = document.getElementById("bytesReceivedText");
streamStateText = document.getElementById("streamStateText");
connectionStateText = document.getElementById("connectionStateText");
audioSupportMessageText = document.getElementById("audioSupportMessageText");
audioOnButton.onclick = () => startAudio();
setConnectionText("");
requestAnimationFrame(onAnimationFrame);
}
function initMediaSource() {
mediaSource = new MediaSource();
mediaSource.onsourceopen = onMediaSourceOpen;
mediaSource.onsourceclose = onMediaSourceClose;
mediaSource.onsourceended = onMediaSourceEnd;
}
function initAudioElement() {
audioElement = document.getElementById("audio");
audioElement.src = URL.createObjectURL(mediaSource);
// audioElement.controls = true;
}
function initAnalyzer() {
fftCanvas = document.getElementById("fftCanvas");
fftCanvas.width = FFT_WIDTH
fftCanvas.height = FFT_HEIGHT
fftCanvasContext = fftCanvas.getContext('2d');
fftCanvasContext.fillStyle = "rgb(0, 0, 0)";
fftCanvasContext.fillRect(0, 0, FFT_WIDTH, FFT_HEIGHT);
bandwidthCanvas = document.getElementById("bandwidthCanvas");
bandwidthCanvas.width = BANDWIDTH_WIDTH
bandwidthCanvas.height = BANDWIDTH_HEIGHT
bandwidthCanvasContext = bandwidthCanvas.getContext('2d');
bandwidthCanvasContext.fillStyle = "rgb(255, 255, 255)";
bandwidthCanvasContext.fillRect(0, 0, BANDWIDTH_WIDTH, BANDWIDTH_HEIGHT);
}
function startAnalyzer() {
// FFT
if (audioElement.captureStream !== undefined) {
audioContext = new AudioContext();
audioAnalyzer = audioContext.createAnalyser();
audioAnalyzer.fftSize = 128;
audioFrequencyBinCount = audioAnalyzer.frequencyBinCount;
audioFrequencyData = new Uint8Array(audioFrequencyBinCount);
const stream = audioElement.captureStream();
const source = audioContext.createMediaStreamSource(stream);
source.connect(audioAnalyzer);
}
// Bandwidth
bandwidthBinCount = BANDWIDTH_WIDTH / 2;
bandwidthBins = [];
}
function setConnectionText(message) {
connectionText.innerText = message;
if (message.length == 0) {
connectionText.style.display = "none";
} else {
connectionText.style.display = "inline-block";
}
}
function setStreamState(state) {
streamState = state;
streamStateText.innerText = streamState;
}
function onAnimationFrame() {
// FFT
if (audioAnalyzer !== undefined) {
audioAnalyzer.getByteFrequencyData(audioFrequencyData);
fftCanvasContext.fillStyle = "rgb(0, 0, 0)";
fftCanvasContext.fillRect(0, 0, FFT_WIDTH, FFT_HEIGHT);
const barCount = audioFrequencyBinCount;
const barWidth = (FFT_WIDTH / audioFrequencyBinCount) - 1;
for (let bar = 0; bar < barCount; bar++) {
const barHeight = audioFrequencyData[bar];
fftCanvasContext.fillStyle = `rgb(${barHeight / 256 * 200 + 50}, 50, ${50 + 2 * bar})`;
fftCanvasContext.fillRect(bar * (barWidth + 1), FFT_HEIGHT - barHeight, barWidth, barHeight);
}
}
// Bandwidth
bandwidthCanvasContext.fillStyle = "rgb(255, 255, 255)";
bandwidthCanvasContext.fillRect(0, 0, BANDWIDTH_WIDTH, BANDWIDTH_HEIGHT);
bandwidthCanvasContext.fillStyle = `rgb(100, 100, 100)`;
for (let t = 0; t < bandwidthBins.length; t++) {
const lineHeight = (bandwidthBins[t] / 1000) * BANDWIDTH_HEIGHT;
bandwidthCanvasContext.fillRect(t * 2, BANDWIDTH_HEIGHT - lineHeight, 2, lineHeight);
}
// Display again at the next frame
requestAnimationFrame(onAnimationFrame);
}
function onMediaSourceOpen() {
console.log(this.readyState);
sourceBuffer = mediaSource.addSourceBuffer("audio/aac");
}
function onMediaSourceClose() {
console.log(this.readyState);
}
function onMediaSourceEnd() {
console.log(this.readyState);
}
async function startAudio() {
try {
console.log("starting audio...");
audioOnButton.disabled = true;
audioState = "starting";
await audioElement.play();
console.log("audio started");
audioState = "playing";
startAnalyzer();
} catch(error) {
console.error(`play failed: ${error}`);
audioState = "stopped";
audioOnButton.disabled = false;
}
}
function onAudioPacket(packet) {
if (audioState != "stopped") {
// Queue the audio packet.
sourceBuffer.appendBuffer(packet);
}
packetsReceived += 1;
packetsReceivedText.innerText = packetsReceived;
bytesReceived += packet.byteLength;
bytesReceivedText.innerText = bytesReceived;
bandwidthBins[bandwidthBins.length] = packet.byteLength;
if (bandwidthBins.length > bandwidthBinCount) {
bandwidthBins.shift();
}
}
function onChannelOpen() {
console.log('channel OPEN');
setConnectionText("");
controlsDiv.style.visibility = "visible";
// Handshake with the backend.
sendMessage({
type: "hello"
});
}
function onChannelClose() {
console.log('channel CLOSED');
setConnectionText("Connection to CLI app closed, restart it and reload this page.");
controlsDiv.style.visibility = "hidden";
}
function onChannelError(error) {
console.log(`channel ERROR: ${error}`);
setConnectionText(`Connection to CLI app error ({${error}}), restart it and reload this page.`);
controlsDiv.style.visibility = "hidden";
}
function onChannelMessage(message) {
if (typeof message.data === 'string' || message.data instanceof String) {
// JSON message.
const jsonMessage = JSON.parse(message.data);
console.log(`channel MESSAGE: ${message.data}`);
// Dispatch the message.
const handlerName = `on${jsonMessage.type.charAt(0).toUpperCase()}${jsonMessage.type.slice(1)}Message`
const handler = messageHandlers[handlerName];
if (handler !== undefined) {
const params = jsonMessage.params;
if (params === undefined) {
params = {};
}
handler(params);
} else {
console.warn(`unhandled message: ${jsonMessage.type}`)
}
} else {
// BINARY audio data.
onAudioPacket(message.data);
}
}
function onHelloMessage(params) {
codecText.innerText = params.codec;
if (params.codec != "aac") {
audioOnButton.disabled = true;
audioSupportMessageText.innerText = "Only AAC can be played, audio will be disabled";
audioSupportMessageText.style.display = "inline-block";
} else {
audioSupportMessageText.innerText = "";
audioSupportMessageText.style.display = "none";
}
if (params.streamState) {
setStreamState(params.streamState);
}
}
function onStartMessage(params) {
setStreamState("STARTED");
}
function onStopMessage(params) {
setStreamState("STOPPED");
}
function onSuspendMessage(params) {
setStreamState("SUSPENDED");
}
function onConnectionMessage(params) {
connectionStateText.innerText = `CONNECTED: ${params.peer_name} (${params.peer_address})`;
}
function onDisconnectionMessage(params) {
connectionStateText.innerText = "DISCONNECTED";
}
function sendMessage(message) {
channelSocket.send(JSON.stringify(message));
}
function connect() {
console.log("connecting to CLI app");
channelSocket = new WebSocket(channelUrl);
channelSocket.binaryType = "arraybuffer";
channelSocket.onopen = onChannelOpen;
channelSocket.onclose = onChannelClose;
channelSocket.onerror = onChannelError;
channelSocket.onmessage = onChannelMessage;
}
const messageHandlers = {
onHelloMessage,
onStartMessage,
onStopMessage,
onSuspendMessage,
onConnectionMessage,
onDisconnectionMessage
}
window.onload = (event) => {
init();
}
}());
-738
View File
@@ -1,738 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import asyncio.subprocess
from importlib import resources
import enum
import json
import os
import logging
import pathlib
import subprocess
from typing import Dict, List, Optional
import weakref
import click
import aiohttp
from aiohttp import web
import bumble
from bumble.colors import color
from bumble.core import BT_BR_EDR_TRANSPORT, CommandTimeoutError
from bumble.device import Connection, Device, DeviceConfiguration
from bumble.hci import HCI_StatusError
from bumble.pairing import PairingConfig
from bumble.sdp import ServiceAttribute
from bumble.transport import open_transport
from bumble.avdtp import (
AVDTP_AUDIO_MEDIA_TYPE,
Listener,
MediaCodecCapabilities,
MediaPacket,
Protocol,
)
from bumble.a2dp import (
MPEG_2_AAC_LC_OBJECT_TYPE,
make_audio_sink_service_sdp_records,
A2DP_SBC_CODEC_TYPE,
A2DP_MPEG_2_4_AAC_CODEC_TYPE,
SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE,
SBC_SNR_ALLOCATION_METHOD,
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE,
SbcMediaCodecInformation,
AacMediaCodecInformation,
)
from bumble.utils import AsyncRunner
from bumble.codecs import AacAudioRtpPacket
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654
# -----------------------------------------------------------------------------
class AudioExtractor:
@staticmethod
def create(codec: str):
if codec == 'aac':
return AacAudioExtractor()
if codec == 'sbc':
return SbcAudioExtractor()
def extract_audio(self, packet: MediaPacket) -> bytes:
raise NotImplementedError()
# -----------------------------------------------------------------------------
class AacAudioExtractor:
def extract_audio(self, packet: MediaPacket) -> bytes:
return AacAudioRtpPacket(packet.payload).to_adts()
# -----------------------------------------------------------------------------
class SbcAudioExtractor:
def extract_audio(self, packet: MediaPacket) -> bytes:
# header = packet.payload[0]
# fragmented = header >> 7
# start = (header >> 6) & 0x01
# last = (header >> 5) & 0x01
# number_of_frames = header & 0x0F
# TODO: support fragmented payloads
return packet.payload[1:]
# -----------------------------------------------------------------------------
class Output:
async def start(self) -> None:
pass
async def stop(self) -> None:
pass
async def suspend(self) -> None:
pass
async def on_connection(self, connection: Connection) -> None:
pass
async def on_disconnection(self, reason: int) -> None:
pass
def on_rtp_packet(self, packet: MediaPacket) -> None:
pass
# -----------------------------------------------------------------------------
class FileOutput(Output):
filename: str
codec: str
extractor: AudioExtractor
def __init__(self, filename, codec):
self.filename = filename
self.codec = codec
self.file = open(filename, 'wb')
self.extractor = AudioExtractor.create(codec)
def on_rtp_packet(self, packet: MediaPacket) -> None:
self.file.write(self.extractor.extract_audio(packet))
# -----------------------------------------------------------------------------
class QueuedOutput(Output):
MAX_QUEUE_SIZE = 32768
packets: asyncio.Queue
extractor: AudioExtractor
packet_pump_task: Optional[asyncio.Task]
started: bool
def __init__(self, extractor):
self.extractor = extractor
self.packets = asyncio.Queue()
self.packet_pump_task = None
self.started = False
async def start(self):
if self.started:
return
self.packet_pump_task = asyncio.create_task(self.pump_packets())
async def pump_packets(self):
while True:
packet = await self.packets.get()
await self.on_audio_packet(packet)
async def on_audio_packet(self, packet: bytes) -> None:
pass
def on_rtp_packet(self, packet: MediaPacket) -> None:
if self.packets.qsize() > self.MAX_QUEUE_SIZE:
logger.debug("queue full, dropping")
return
self.packets.put_nowait(self.extractor.extract_audio(packet))
# -----------------------------------------------------------------------------
class WebSocketOutput(QueuedOutput):
def __init__(self, codec, send_audio, send_message):
super().__init__(AudioExtractor.create(codec))
self.send_audio = send_audio
self.send_message = send_message
async def on_connection(self, connection: Connection) -> None:
try:
await connection.request_remote_name()
except HCI_StatusError:
pass
peer_name = '' if connection.peer_name is None else connection.peer_name
peer_address = connection.peer_address.to_string(False)
await self.send_message(
'connection',
peer_address=peer_address,
peer_name=peer_name,
)
async def on_disconnection(self, reason) -> None:
await self.send_message('disconnection')
async def on_audio_packet(self, packet: bytes) -> None:
await self.send_audio(packet)
async def start(self):
await super().start()
await self.send_message('start')
async def stop(self):
await super().stop()
await self.send_message('stop')
async def suspend(self):
await super().suspend()
await self.send_message('suspend')
# -----------------------------------------------------------------------------
class FfplayOutput(QueuedOutput):
MAX_QUEUE_SIZE = 32768
subprocess: Optional[asyncio.subprocess.Process]
ffplay_task: Optional[asyncio.Task]
def __init__(self, codec: str) -> None:
super().__init__(AudioExtractor.create(codec))
self.subprocess = None
self.ffplay_task = None
self.codec = codec
async def start(self):
if self.started:
return
await super().start()
self.subprocess = await asyncio.create_subprocess_shell(
f'ffplay -f {self.codec} pipe:0',
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
self.ffplay_task = asyncio.create_task(self.monitor_ffplay())
async def stop(self):
# TODO
pass
async def suspend(self):
# TODO
pass
async def monitor_ffplay(self):
async def read_stream(name, stream):
while True:
data = await stream.read()
logger.debug(f'{name}:', data)
await asyncio.wait(
[
asyncio.create_task(
read_stream('[ffplay stdout]', self.subprocess.stdout)
),
asyncio.create_task(
read_stream('[ffplay stderr]', self.subprocess.stderr)
),
asyncio.create_task(self.subprocess.wait()),
]
)
logger.debug("FFPLAY done")
async def on_audio_packet(self, packet):
try:
self.subprocess.stdin.write(packet)
except Exception:
logger.warning('!!!! exception while sending audio to ffplay pipe')
# -----------------------------------------------------------------------------
class UiServer:
speaker: weakref.ReferenceType[Speaker]
port: int
def __init__(self, speaker: Speaker, port: int) -> None:
self.speaker = weakref.ref(speaker)
self.port = port
self.channel_socket = None
async def start_http(self) -> None:
"""Start the UI HTTP server."""
app = web.Application()
app.add_routes(
[
web.get('/', self.get_static),
web.get('/speaker.html', self.get_static),
web.get('/speaker.js', self.get_static),
web.get('/speaker.css', self.get_static),
web.get('/logo.svg', self.get_static),
web.get('/channel', self.get_channel),
]
)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, 'localhost', self.port)
print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green'))
await site.start()
async def get_static(self, request):
path = request.path
if path == '/':
path = '/speaker.html'
if path.endswith('.html'):
content_type = 'text/html'
elif path.endswith('.js'):
content_type = 'text/javascript'
elif path.endswith('.css'):
content_type = 'text/css'
elif path.endswith('.svg'):
content_type = 'image/svg+xml'
else:
content_type = 'text/plain'
text = (
resources.files("bumble.apps.speaker")
.joinpath(pathlib.Path(path).relative_to('/'))
.read_text(encoding="utf-8")
)
return aiohttp.web.Response(text=text, content_type=content_type)
async def get_channel(self, request):
ws = web.WebSocketResponse()
await ws.prepare(request)
# Process messages until the socket is closed.
self.channel_socket = ws
async for message in ws:
if message.type == aiohttp.WSMsgType.TEXT:
logger.debug(f'<<< received message: {message.data}')
await self.on_message(message.data)
elif message.type == aiohttp.WSMsgType.ERROR:
logger.debug(
f'channel connection closed with exception {ws.exception()}'
)
self.channel_socket = None
logger.debug('--- channel connection closed')
return ws
async def on_message(self, message_str: str):
# Parse the message as JSON
message = json.loads(message_str)
# Dispatch the message
message_type = message['type']
message_params = message.get('params', {})
handler = getattr(self, f'on_{message_type}_message')
if handler:
await handler(**message_params)
async def on_hello_message(self):
await self.send_message(
'hello',
bumble_version=bumble.__version__,
codec=self.speaker().codec,
streamState=self.speaker().stream_state.name,
)
if connection := self.speaker().connection:
await self.send_message(
'connection',
peer_address=connection.peer_address.to_string(False),
peer_name=connection.peer_name,
)
async def send_message(self, message_type: str, **kwargs) -> None:
if self.channel_socket is None:
return
message = {'type': message_type, 'params': kwargs}
await self.channel_socket.send_json(message)
async def send_audio(self, data: bytes) -> None:
if self.channel_socket is None:
return
try:
await self.channel_socket.send_bytes(data)
except Exception as error:
logger.warning(f'exception while sending audio packet: {error}')
# -----------------------------------------------------------------------------
class Speaker:
class StreamState(enum.Enum):
IDLE = 0
STOPPED = 1
STARTED = 2
SUSPENDED = 3
def __init__(self, device_config, transport, codec, discover, outputs, ui_port):
self.device_config = device_config
self.transport = transport
self.codec = codec
self.discover = discover
self.ui_port = ui_port
self.device = None
self.connection = None
self.listener = None
self.packets_received = 0
self.bytes_received = 0
self.stream_state = Speaker.StreamState.IDLE
self.outputs = []
for output in outputs:
if output == '@ffplay':
self.outputs.append(FfplayOutput(codec))
continue
# Default to FileOutput
self.outputs.append(FileOutput(output, codec))
# Create an HTTP server for the UI
self.ui_server = UiServer(speaker=self, port=ui_port)
def sdp_records(self) -> Dict[int, List[ServiceAttribute]]:
service_record_handle = 0x00010001
return {
service_record_handle: make_audio_sink_service_sdp_records(
service_record_handle
)
}
def codec_capabilities(self) -> MediaCodecCapabilities:
if self.codec == 'aac':
return self.aac_codec_capabilities()
if self.codec == 'sbc':
return self.sbc_codec_capabilities()
raise RuntimeError('unsupported codec')
def aac_codec_capabilities(self) -> MediaCodecCapabilities:
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_MPEG_2_4_AAC_CODEC_TYPE,
media_codec_information=AacMediaCodecInformation.from_lists(
object_types=[MPEG_2_AAC_LC_OBJECT_TYPE],
sampling_frequencies=[48000, 44100],
channels=[1, 2],
vbr=1,
bitrate=256000,
),
)
def sbc_codec_capabilities(self) -> MediaCodecCapabilities:
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation.from_lists(
sampling_frequencies=[48000, 44100, 32000, 16000],
channel_modes=[
SBC_MONO_CHANNEL_MODE,
SBC_DUAL_CHANNEL_MODE,
SBC_STEREO_CHANNEL_MODE,
SBC_JOINT_STEREO_CHANNEL_MODE,
],
block_lengths=[4, 8, 12, 16],
subbands=[4, 8],
allocation_methods=[
SBC_LOUDNESS_ALLOCATION_METHOD,
SBC_SNR_ALLOCATION_METHOD,
],
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),
)
async def dispatch_to_outputs(self, function):
for output in self.outputs:
await function(output)
def on_bluetooth_connection(self, connection):
print(f'Connection: {connection}')
self.connection = connection
connection.on('disconnection', self.on_bluetooth_disconnection)
AsyncRunner.spawn(
self.dispatch_to_outputs(lambda output: output.on_connection(connection))
)
def on_bluetooth_disconnection(self, reason):
print(f'Disconnection ({reason})')
self.connection = None
AsyncRunner.spawn(self.advertise())
AsyncRunner.spawn(
self.dispatch_to_outputs(lambda output: output.on_disconnection(reason))
)
def on_avdtp_connection(self, protocol):
print('Audio Stream Open')
# Add a sink endpoint to the server
sink = protocol.add_sink(self.codec_capabilities())
sink.on('start', self.on_sink_start)
sink.on('stop', self.on_sink_stop)
sink.on('suspend', self.on_sink_suspend)
sink.on('configuration', lambda: self.on_sink_configuration(sink.configuration))
sink.on('rtp_packet', self.on_rtp_packet)
sink.on('rtp_channel_open', self.on_rtp_channel_open)
sink.on('rtp_channel_close', self.on_rtp_channel_close)
# Listen for close events
protocol.on('close', self.on_avdtp_close)
# Discover all endpoints on the remote device is requested
if self.discover:
AsyncRunner.spawn(self.discover_remote_endpoints(protocol))
def on_avdtp_close(self):
print("Audio Stream Closed")
def on_sink_start(self):
print("Sink Started\u001b[0K")
self.stream_state = self.StreamState.STARTED
AsyncRunner.spawn(self.dispatch_to_outputs(lambda output: output.start()))
def on_sink_stop(self):
print("Sink Stopped\u001b[0K")
self.stream_state = self.StreamState.STOPPED
AsyncRunner.spawn(self.dispatch_to_outputs(lambda output: output.stop()))
def on_sink_suspend(self):
print("Sink Suspended\u001b[0K")
self.stream_state = self.StreamState.SUSPENDED
AsyncRunner.spawn(self.dispatch_to_outputs(lambda output: output.suspend()))
def on_sink_configuration(self, config):
print("Sink Configuration:")
print('\n'.join([" " + str(capability) for capability in config]))
def on_rtp_channel_open(self):
print("RTP Channel Open")
def on_rtp_channel_close(self):
print("RTP Channel Closed")
self.stream_state = self.StreamState.IDLE
def on_rtp_packet(self, packet):
self.packets_received += 1
self.bytes_received += len(packet.payload)
print(
f'[{self.bytes_received} bytes in {self.packets_received} packets] {packet}',
end='\r',
)
for output in self.outputs:
output.on_rtp_packet(packet)
async def advertise(self):
await self.device.set_discoverable(True)
await self.device.set_connectable(True)
async def connect(self, address):
# Connect to the source
print(f'=== Connecting to {address}...')
connection = await self.device.connect(address, transport=BT_BR_EDR_TRANSPORT)
print(f'=== Connected to {connection.peer_address}')
# Request authentication
print('*** Authenticating...')
await connection.authenticate()
print('*** Authenticated')
# Enable encryption
print('*** Enabling encryption...')
await connection.encrypt()
print('*** Encryption on')
protocol = await Protocol.connect(connection)
self.listener.set_server(connection, protocol)
self.on_avdtp_connection(protocol)
async def discover_remote_endpoints(self, protocol):
endpoints = await protocol.discover_remote_endpoints()
print(f'@@@ Found {len(endpoints)} endpoints')
for endpoint in endpoints:
print('@@@', endpoint)
async def run(self, connect_address):
await self.ui_server.start_http()
self.outputs.append(
WebSocketOutput(
self.codec, self.ui_server.send_audio, self.ui_server.send_message
)
)
async with await open_transport(self.transport) as (hci_source, hci_sink):
# Create a device
device_config = DeviceConfiguration()
if self.device_config:
device_config.load_from_file(self.device_config)
else:
device_config.name = "Bumble Speaker"
device_config.class_of_device = 0x240414
device_config.keystore = "JsonKeyStore"
device_config.classic_enabled = True
device_config.le_enabled = False
self.device = Device.from_config_with_hci(
device_config, hci_source, hci_sink
)
# Setup the SDP to expose the sink service
self.device.sdp_service_records = self.sdp_records()
# Don't require MITM when pairing.
self.device.pairing_config_factory = lambda connection: PairingConfig(
mitm=False
)
# Start the controller
await self.device.power_on()
# Print some of the config/properties
print("Speaker Name:", color(device_config.name, 'yellow'))
print(
"Speaker Bluetooth Address:",
color(
self.device.public_address.to_string(with_type_qualifier=False),
'yellow',
),
)
# Listen for Bluetooth connections
self.device.on('connection', self.on_bluetooth_connection)
# Create a listener to wait for AVDTP connections
self.listener = Listener.for_device(self.device)
self.listener.on('connection', self.on_avdtp_connection)
print(f'Speaker ready to play, codec={color(self.codec, "cyan")}')
if connect_address:
# Connect to the source
try:
await self.connect(connect_address)
except CommandTimeoutError:
print(color("Connection timed out", "red"))
return
else:
# Start being discoverable and connectable
print("Waiting for connection...")
await self.advertise()
await hci_source.wait_for_termination()
for output in self.outputs:
await output.stop()
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
def speaker_cli(ctx, device_config):
ctx.ensure_object(dict)
ctx.obj['device_config'] = device_config
@click.command()
@click.option(
'--codec', type=click.Choice(['sbc', 'aac']), default='aac', show_default=True
)
@click.option(
'--discover', is_flag=True, help='Discover remote endpoints once connected'
)
@click.option(
'--output',
multiple=True,
metavar='NAME',
help=(
'Send audio to this named output '
'(may be used more than once for multiple outputs)'
),
)
@click.option(
'--ui-port',
'ui_port',
metavar='HTTP_PORT',
default=DEFAULT_UI_PORT,
show_default=True,
help='HTTP port for the UI server',
)
@click.option(
'--connect',
'connect_address',
metavar='ADDRESS_OR_NAME',
help='Address or name to connect to',
)
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.argument('transport')
def speaker(
transport, codec, connect_address, discover, output, ui_port, device_config
):
"""Run the speaker."""
if '@ffplay' in output:
# Check if ffplay is installed
try:
subprocess.run(['ffplay', '-version'], capture_output=True, check=True)
except FileNotFoundError:
print(
color('ffplay not installed, @ffplay output will be disabled', 'yellow')
)
output = list(filter(lambda x: x != '@ffplay', output))
asyncio.run(
Speaker(device_config, transport, codec, discover, output, ui_port).run(
connect_address
)
)
# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
speaker()
# -----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter
+22 -41
View File
@@ -22,59 +22,40 @@ import click
from bumble.device import Device from bumble.device import Device
from bumble.keys import JsonKeyStore from bumble.keys import JsonKeyStore
from bumble.transport import open_transport
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def unbond_with_keystore(keystore, address): async def unbond(keystore_file, device_config, address):
if address is None: # Create a device to manage the host
return await keystore.print() device = Device.from_config_file(device_config)
try: # Get all entries in the keystore
await keystore.delete(address)
except KeyError:
print('!!! pairing not found')
# -----------------------------------------------------------------------------
async def unbond(keystore_file, device_config, hci_transport, address):
# With a keystore file, we can instantiate the keystore directly
if keystore_file: if keystore_file:
return await unbond_with_keystore(JsonKeyStore(None, keystore_file), address) keystore = JsonKeyStore(None, keystore_file)
else:
keystore = device.keystore
# Without a keystore file, we need to obtain the keystore from the device if keystore is None:
async with await open_transport(hci_transport) as (hci_source, hci_sink): print('no keystore')
# Create a device to manage the host return
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
# Power-on the device to ensure we have a key store if address is None:
await device.power_on() await keystore.print()
else:
return await unbond_with_keystore(device.keystore, address) try:
await keystore.delete(address)
except KeyError:
print('!!! pairing not found')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@click.command() @click.command()
@click.option('--keystore-file', help='File in which the pairing keys are stored') @click.option('--keystore-file', help='File in which to store the pairing keys')
@click.option('--hci-transport', help='HCI transport for the controller') @click.argument('device-config')
@click.argument('device-config', required=False)
@click.argument('address', required=False) @click.argument('address', required=False)
def main(keystore_file, hci_transport, device_config, address): def main(keystore_file, device_config, address):
""" logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
Remove pairing keys for a device, given its address. asyncio.run(unbond(keystore_file, device_config, address))
If no keystore file is specified, the --hci-transport option must be used to
connect to a controller, so that the keystore for that controller can be
instantiated.
If no address is passed, the existing pairing keys for all addresses are printed.
"""
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
if not keystore_file and not hci_transport:
print('either --keystore-file or --hci-transport must be specified.')
return
asyncio.run(unbond(keystore_file, device_config, hci_transport, address))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
-278
View File
@@ -1,278 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# This tool lists all the USB devices, with details about each device.
# For each device, the different possible Bumble transport strings that can
# refer to it are listed. If the device is known to be a Bluetooth HCI device,
# its identifier is printed in reverse colors, and the transport names in cyan color.
# For other devices, regardless of their type, the transport names are printed
# in red. Whether that device is actually a Bluetooth device or not depends on
# whether it is a Bluetooth device that uses a non-standard Class, or some other
# type of device (there's no way to tell).
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import os
import logging
import click
import usb1
from bumble.colors import color
from bumble.transport.usb import load_libusb
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
USB_DEVICE_CLASS_DEVICE = 0x00
USB_DEVICE_CLASS_WIRELESS_CONTROLLER = 0xE0
USB_DEVICE_SUBCLASS_RF_CONTROLLER = 0x01
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER = 0x01
USB_DEVICE_CLASSES = {
0x00: 'Device',
0x01: 'Audio',
0x02: 'Communications and CDC Control',
0x03: 'Human Interface Device',
0x05: 'Physical',
0x06: 'Still Imaging',
0x07: 'Printer',
0x08: 'Mass Storage',
0x09: 'Hub',
0x0A: 'CDC Data',
0x0B: 'Smart Card',
0x0D: 'Content Security',
0x0E: 'Video',
0x0F: 'Personal Healthcare',
0x10: 'Audio/Video',
0x11: 'Billboard',
0x12: 'USB Type-C Bridge',
0x3C: 'I3C',
0xDC: 'Diagnostic',
USB_DEVICE_CLASS_WIRELESS_CONTROLLER: (
'Wireless Controller',
{
0x01: {
0x01: 'Bluetooth',
0x02: 'UWB',
0x03: 'Remote NDIS',
0x04: 'Bluetooth AMP',
}
},
),
0xEF: 'Miscellaneous',
0xFE: 'Application Specific',
0xFF: 'Vendor Specific',
}
USB_ENDPOINT_IN = 0x80
USB_ENDPOINT_TYPES = ['CONTROL', 'ISOCHRONOUS', 'BULK', 'INTERRUPT']
USB_BT_HCI_CLASS_TUPLE = (
USB_DEVICE_CLASS_WIRELESS_CONTROLLER,
USB_DEVICE_SUBCLASS_RF_CONTROLLER,
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
# -----------------------------------------------------------------------------
def show_device_details(device):
for configuration in device:
print(f' Configuration {configuration.getConfigurationValue()}')
for interface in configuration:
for setting in interface:
alternate_setting = setting.getAlternateSetting()
suffix = (
f'/{alternate_setting}' if interface.getNumSettings() > 1 else ''
)
(class_string, subclass_string) = get_class_info(
setting.getClass(), setting.getSubClass(), setting.getProtocol()
)
details = f'({class_string}, {subclass_string})'
print(f' Interface: {setting.getNumber()}{suffix} {details}')
for endpoint in setting:
endpoint_type = USB_ENDPOINT_TYPES[endpoint.getAttributes() & 3]
endpoint_direction = (
'OUT'
if (endpoint.getAddress() & USB_ENDPOINT_IN == 0)
else 'IN'
)
print(
f' Endpoint 0x{endpoint.getAddress():02X}: '
f'{endpoint_type} {endpoint_direction}'
)
# -----------------------------------------------------------------------------
def get_class_info(cls, subclass, protocol):
class_info = USB_DEVICE_CLASSES.get(cls)
protocol_string = ''
if class_info is None:
class_string = f'0x{cls:02X}'
else:
if isinstance(class_info, tuple):
class_string = class_info[0]
subclass_info = class_info[1].get(subclass)
if subclass_info:
protocol_string = subclass_info.get(protocol)
if protocol_string is not None:
protocol_string = f' [{protocol_string}]'
else:
class_string = class_info
subclass_string = f'{subclass}/{protocol}{protocol_string}'
return (class_string, subclass_string)
# -----------------------------------------------------------------------------
def is_bluetooth_hci(device):
# Check if the device class indicates a match
if (
device.getDeviceClass(),
device.getDeviceSubClass(),
device.getDeviceProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
# If the device class is 'Device', look for a matching interface
if device.getDeviceClass() == USB_DEVICE_CLASS_DEVICE:
for configuration in device:
for interface in configuration:
for setting in interface:
if (
setting.getClass(),
setting.getSubClass(),
setting.getProtocol(),
) == USB_BT_HCI_CLASS_TUPLE:
return True
return False
# -----------------------------------------------------------------------------
@click.command()
@click.option('--verbose', is_flag=True, default=False, help='Print more details')
def main(verbose):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
load_libusb()
with usb1.USBContext() as context:
bluetooth_device_count = 0
devices = {}
for device in context.getDeviceIterator(skip_on_error=True):
device_class = device.getDeviceClass()
device_subclass = device.getDeviceSubClass()
device_protocol = device.getDeviceProtocol()
device_id = (device.getVendorID(), device.getProductID())
(device_class_string, device_subclass_string) = get_class_info(
device_class, device_subclass, device_protocol
)
try:
device_serial_number = device.getSerialNumber()
except usb1.USBError:
device_serial_number = None
try:
device_manufacturer = device.getManufacturer()
except usb1.USBError:
device_manufacturer = None
try:
device_product = device.getProduct()
except usb1.USBError:
device_product = None
device_is_bluetooth_hci = is_bluetooth_hci(device)
if device_is_bluetooth_hci:
bluetooth_device_count += 1
fg_color = 'black'
bg_color = 'yellow'
else:
fg_color = 'yellow'
bg_color = 'black'
# Compute the different ways this can be referenced as a Bumble transport
bumble_transport_names = []
basic_transport_name = (
f'usb:{device.getVendorID():04X}:{device.getProductID():04X}'
)
if device_is_bluetooth_hci:
bumble_transport_names.append(f'usb:{bluetooth_device_count - 1}')
if device_id not in devices:
bumble_transport_names.append(basic_transport_name)
else:
bumble_transport_names.append(
f'{basic_transport_name}#{len(devices[device_id])}'
)
if device_serial_number is not None:
if (
device_id not in devices
or device_serial_number not in devices[device_id]
):
bumble_transport_names.append(
f'{basic_transport_name}/{device_serial_number}'
)
# Print the results
print(
color(
f'ID {device.getVendorID():04X}:{device.getProductID():04X}',
fg=fg_color,
bg=bg_color,
)
)
if bumble_transport_names:
print(
color(' Bumble Transport Names:', 'blue'),
' or '.join(
color(x, 'cyan' if device_is_bluetooth_hci else 'red')
for x in bumble_transport_names
),
)
print(
color(' Bus/Device: ', 'green'),
f'{device.getBusNumber():03}/{device.getDeviceAddress():03}',
)
print(color(' Class: ', 'green'), device_class_string)
print(color(' Subclass/Protocol: ', 'green'), device_subclass_string)
if device_serial_number is not None:
print(color(' Serial: ', 'green'), device_serial_number)
if device_manufacturer is not None:
print(color(' Manufacturer: ', 'green'), device_manufacturer)
if device_product is not None:
print(color(' Product: ', 'green'), device_product)
if verbose:
show_device_details(device)
print()
devices.setdefault(device_id, []).append(device_serial_number)
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main() # pylint: disable=no-value-for-parameter
-4
View File
@@ -1,4 +0,0 @@
try:
from ._version import version as __version__
except ImportError:
__version__ = "unknown version"
+246 -357
View File
@@ -15,13 +15,11 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import struct import struct
import bitstruct
import logging import logging
from collections.abc import AsyncGenerator from collections import namedtuple
from typing import List, Callable, Awaitable from colors import color
from .company_ids import COMPANY_IDENTIFIERS from .company_ids import COMPANY_IDENTIFIERS
from .sdp import ( from .sdp import (
@@ -32,7 +30,7 @@ from .sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID
) )
from .core import ( from .core import (
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
@@ -40,7 +38,7 @@ from .core import (
BT_AUDIO_SINK_SERVICE, BT_AUDIO_SINK_SERVICE,
BT_AVDTP_PROTOCOL_ID, BT_AVDTP_PROTOCOL_ID,
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE, BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
name_or_number, name_or_number
) )
@@ -53,7 +51,6 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
A2DP_SBC_CODEC_TYPE = 0x00 A2DP_SBC_CODEC_TYPE = 0x00
A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = 0x01 A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = 0x01
@@ -130,272 +127,212 @@ MPEG_2_4_OBJECT_TYPE_NAMES = {
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE' MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
} }
# fmt: on
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def flags_to_list(flags, values): def flags_to_list(flags, values):
result = [] result = []
for i, value in enumerate(values): for i in range(len(values)):
if flags & (1 << (len(values) - i - 1)): if flags & (1 << (len(values) - i - 1)):
result.append(value) result.append(values[i])
return result return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)): def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)):
# pylint: disable=import-outside-toplevel
from .avdtp import AVDTP_PSM from .avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1] version_int = version[0] << 8 | version[1]
return [ return [
ServiceAttribute( ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(service_record_handle)),
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.unsigned_integer_32(service_record_handle), DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)
), ])),
ServiceAttribute( ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.uuid(BT_AUDIO_SOURCE_SERVICE)
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]), ])),
), ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
ServiceAttribute( DataElement.sequence([
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.sequence([DataElement.uuid(BT_AUDIO_SOURCE_SERVICE)]), DataElement.unsigned_integer_16(AVDTP_PSM)
), ]),
ServiceAttribute( DataElement.sequence([
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.sequence( DataElement.unsigned_integer_16(version_int)
[ ])
DataElement.sequence( ])),
[ ServiceAttribute(SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(BT_L2CAP_PROTOCOL_ID), DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(AVDTP_PSM), DataElement.unsigned_integer_16(version_int)
] ])),
),
DataElement.sequence(
[
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
)
]
),
),
] ]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)): def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
# pylint: disable=import-outside-toplevel
from .avdtp import AVDTP_PSM from .avdtp import AVDTP_PSM
version_int = version[0] << 8 | version[1] version_int = version[0] << 8 | version[1]
return [ return [
ServiceAttribute( ServiceAttribute(SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, DataElement.unsigned_integer_32(service_record_handle)),
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, ServiceAttribute(SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.unsigned_integer_32(service_record_handle), DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)
), ])),
ServiceAttribute( ServiceAttribute(SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.sequence([
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, DataElement.uuid(BT_AUDIO_SINK_SERVICE)
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]), ])),
), ServiceAttribute(SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
ServiceAttribute( DataElement.sequence([
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.sequence([DataElement.uuid(BT_AUDIO_SINK_SERVICE)]), DataElement.unsigned_integer_16(AVDTP_PSM)
), ]),
ServiceAttribute( DataElement.sequence([
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.sequence( DataElement.unsigned_integer_16(version_int)
[ ])
DataElement.sequence( ])),
[ ServiceAttribute(SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence([
DataElement.uuid(BT_L2CAP_PROTOCOL_ID), DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(AVDTP_PSM), DataElement.unsigned_integer_16(version_int)
] ])),
),
DataElement.sequence(
[
DataElement.uuid(BT_AVDTP_PROTOCOL_ID),
DataElement.unsigned_integer_16(version_int),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
)
]
),
),
] ]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass class SbcMediaCodecInformation(
class SbcMediaCodecInformation: namedtuple(
'SbcMediaCodecInformation',
[
'sampling_frequency',
'channel_mode',
'block_length',
'subbands',
'allocation_method',
'minimum_bitpool_value',
'maximum_bitpool_value'
]
)
):
''' '''
A2DP spec - 4.3.2 Codec Specific Information Elements A2DP spec - 4.3.2 Codec Specific Information Elements
''' '''
sampling_frequency: int BIT_FIELDS = 'u4u4u4u2u2u8u8'
channel_mode: int SAMPLING_FREQUENCY_BITS = {
block_length: int 16000: 1 << 3,
subbands: int 32000: 1 << 2,
allocation_method: int 44100: 1 << 1,
minimum_bitpool_value: int 48000: 1
maximum_bitpool_value: int }
CHANNEL_MODE_BITS = {
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1} SBC_MONO_CHANNEL_MODE: 1 << 3,
CHANNEL_MODE_BITS = { SBC_DUAL_CHANNEL_MODE: 1 << 2,
SBC_MONO_CHANNEL_MODE: 1 << 3, SBC_STEREO_CHANNEL_MODE: 1 << 1,
SBC_DUAL_CHANNEL_MODE: 1 << 2, SBC_JOINT_STEREO_CHANNEL_MODE: 1
SBC_STEREO_CHANNEL_MODE: 1 << 1, }
SBC_JOINT_STEREO_CHANNEL_MODE: 1, BLOCK_LENGTH_BITS = {
4: 1 << 3,
8: 1 << 2,
12: 1 << 1,
16: 1
}
SUBBANDS_BITS = {
4: 1 << 1,
8: 1
} }
BLOCK_LENGTH_BITS = {4: 1 << 3, 8: 1 << 2, 12: 1 << 1, 16: 1}
SUBBANDS_BITS = {4: 1 << 1, 8: 1}
ALLOCATION_METHOD_BITS = { ALLOCATION_METHOD_BITS = {
SBC_SNR_ALLOCATION_METHOD: 1 << 1, SBC_SNR_ALLOCATION_METHOD: 1 << 1,
SBC_LOUDNESS_ALLOCATION_METHOD: 1, SBC_LOUDNESS_ALLOCATION_METHOD: 1
} }
@staticmethod @staticmethod
def from_bytes(data: bytes) -> SbcMediaCodecInformation: def from_bytes(data):
sampling_frequency = (data[0] >> 4) & 0x0F return SbcMediaCodecInformation(*bitstruct.unpack(SbcMediaCodecInformation.BIT_FIELDS, data))
channel_mode = (data[0] >> 0) & 0x0F
block_length = (data[1] >> 4) & 0x0F
subbands = (data[1] >> 2) & 0x03
allocation_method = (data[1] >> 0) & 0x03
minimum_bitpool_value = (data[2] >> 0) & 0xFF
maximum_bitpool_value = (data[3] >> 0) & 0xFF
return SbcMediaCodecInformation(
sampling_frequency,
channel_mode,
block_length,
subbands,
allocation_method,
minimum_bitpool_value,
maximum_bitpool_value,
)
@classmethod @classmethod
def from_discrete_values( def from_discrete_values(
cls, cls,
sampling_frequency: int, sampling_frequency,
channel_mode: int, channel_mode,
block_length: int, block_length,
subbands: int, subbands,
allocation_method: int, allocation_method,
minimum_bitpool_value: int, minimum_bitpool_value,
maximum_bitpool_value: int, maximum_bitpool_value
) -> SbcMediaCodecInformation: ):
return SbcMediaCodecInformation( return SbcMediaCodecInformation(
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency], sampling_frequency = cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode], channel_mode = cls.CHANNEL_MODE_BITS[channel_mode],
block_length=cls.BLOCK_LENGTH_BITS[block_length], block_length = cls.BLOCK_LENGTH_BITS[block_length],
subbands=cls.SUBBANDS_BITS[subbands], subbands = cls.SUBBANDS_BITS[subbands],
allocation_method=cls.ALLOCATION_METHOD_BITS[allocation_method], allocation_method = cls.ALLOCATION_METHOD_BITS[allocation_method],
minimum_bitpool_value=minimum_bitpool_value, minimum_bitpool_value = minimum_bitpool_value,
maximum_bitpool_value=maximum_bitpool_value, maximum_bitpool_value = maximum_bitpool_value
) )
@classmethod @classmethod
def from_lists( def from_lists(
cls, cls,
sampling_frequencies: List[int], sampling_frequencies,
channel_modes: List[int], channel_modes,
block_lengths: List[int], block_lengths,
subbands: List[int], subbands,
allocation_methods: List[int], allocation_methods,
minimum_bitpool_value: int, minimum_bitpool_value,
maximum_bitpool_value: int, maximum_bitpool_value
) -> SbcMediaCodecInformation: ):
return SbcMediaCodecInformation( return SbcMediaCodecInformation(
sampling_frequency=sum( sampling_frequency = sum(cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies),
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies channel_mode = sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes),
), block_length = sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths),
channel_mode=sum(cls.CHANNEL_MODE_BITS[x] for x in channel_modes), subbands = sum(cls.SUBBANDS_BITS[x] for x in subbands),
block_length=sum(cls.BLOCK_LENGTH_BITS[x] for x in block_lengths), allocation_method = sum(cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods),
subbands=sum(cls.SUBBANDS_BITS[x] for x in subbands), minimum_bitpool_value = minimum_bitpool_value,
allocation_method=sum( maximum_bitpool_value = maximum_bitpool_value
cls.ALLOCATION_METHOD_BITS[x] for x in allocation_methods
),
minimum_bitpool_value=minimum_bitpool_value,
maximum_bitpool_value=maximum_bitpool_value,
) )
def __bytes__(self) -> bytes: def __bytes__(self):
return bytes( return bitstruct.pack(self.BIT_FIELDS, *self)
[
(self.sampling_frequency << 4) | self.channel_mode,
(self.block_length << 4)
| (self.subbands << 2)
| self.allocation_method,
self.minimum_bitpool_value,
self.maximum_bitpool_value,
]
)
def __str__(self) -> str: def __str__(self):
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO'] channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness'] allocation_methods = ['SNR', 'Loudness']
return '\n'.join( return '\n'.join([
# pylint: disable=line-too-long 'SbcMediaCodecInformation(',
[ f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}',
'SbcMediaCodecInformation(', f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, SBC_SAMPLING_FREQUENCIES)])}', f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}',
f' channel_mode: {",".join([str(x) for x in flags_to_list(self.channel_mode, channel_modes)])}', f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}',
f' block_length: {",".join([str(x) for x in flags_to_list(self.block_length, SBC_BLOCK_LENGTHS)])}', f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}',
f' subbands: {",".join([str(x) for x in flags_to_list(self.subbands, SBC_SUBBANDS)])}', f' minimum_bitpool_value: {self.minimum_bitpool_value}',
f' allocation_method: {",".join([str(x) for x in flags_to_list(self.allocation_method, allocation_methods)])}', f' maximum_bitpool_value: {self.maximum_bitpool_value}'
f' minimum_bitpool_value: {self.minimum_bitpool_value}', ')'
f' maximum_bitpool_value: {self.maximum_bitpool_value}' ')', ])
]
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass class AacMediaCodecInformation(
class AacMediaCodecInformation: namedtuple(
'AacMediaCodecInformation',
[
'object_type',
'sampling_frequency',
'channels',
'vbr',
'bitrate'
]
)
):
''' '''
A2DP spec - 4.5.2 Codec Specific Information Elements A2DP spec - 4.5.2 Codec Specific Information Elements
''' '''
object_type: int BIT_FIELDS = 'u8u12u2p2u1u23'
sampling_frequency: int
channels: int
rfa: int
vbr: int
bitrate: int
OBJECT_TYPE_BITS = { OBJECT_TYPE_BITS = {
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7, MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6, MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5, MPEG_4_AAC_LTP_OBJECT_TYPE: 1 << 5,
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4, MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 1 << 4
} }
SAMPLING_FREQUENCY_BITS = { SAMPLING_FREQUENCY_BITS = {
8000: 1 << 11, 8000: 1 << 11,
11025: 1 << 10, 11025: 1 << 10,
12000: 1 << 9, 12000: 1 << 9,
16000: 1 << 8, 16000: 1 << 8,
@@ -406,167 +343,137 @@ class AacMediaCodecInformation:
48000: 1 << 3, 48000: 1 << 3,
64000: 1 << 2, 64000: 1 << 2,
88200: 1 << 1, 88200: 1 << 1,
96000: 1, 96000: 1
}
CHANNELS_BITS = {
1: 1 << 1,
2: 1
} }
CHANNELS_BITS = {1: 1 << 1, 2: 1}
@staticmethod @staticmethod
def from_bytes(data: bytes) -> AacMediaCodecInformation: def from_bytes(data):
object_type = data[0] return AacMediaCodecInformation(*bitstruct.unpack(AacMediaCodecInformation.BIT_FIELDS, data))
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
channels = (data[2] >> 2) & 0x03
rfa = 0
vbr = (data[3] >> 7) & 0x01
bitrate = ((data[3] & 0x7F) << 16) | (data[4] << 8) | data[5]
return AacMediaCodecInformation(
object_type, sampling_frequency, channels, rfa, vbr, bitrate
)
@classmethod @classmethod
def from_discrete_values( def from_discrete_values(
cls, cls,
object_type: int, object_type,
sampling_frequency: int, sampling_frequency,
channels: int, channels,
vbr: int, vbr,
bitrate: int, bitrate
) -> AacMediaCodecInformation: ):
return AacMediaCodecInformation( return AacMediaCodecInformation(
object_type=cls.OBJECT_TYPE_BITS[object_type], object_type = cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency], sampling_frequency = cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channels=cls.CHANNELS_BITS[channels], channels = cls.CHANNELS_BITS[channels],
rfa=0, vbr = vbr,
vbr=vbr, bitrate = bitrate
bitrate=bitrate,
) )
@classmethod @classmethod
def from_lists( def from_lists(
cls, cls,
object_types: List[int], object_types,
sampling_frequencies: List[int], sampling_frequencies,
channels: List[int], channels,
vbr: int, vbr,
bitrate: int, bitrate
) -> AacMediaCodecInformation: ):
return AacMediaCodecInformation( return AacMediaCodecInformation(
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types), object_type = sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency=sum( sampling_frequency = sum(cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies),
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies channels = sum(cls.CHANNELS_BITS[x] for x in channels),
), vbr = vbr,
channels=sum(cls.CHANNELS_BITS[x] for x in channels), bitrate = bitrate
rfa=0,
vbr=vbr,
bitrate=bitrate,
) )
def __bytes__(self) -> bytes: def __bytes__(self):
return bytes( return bitstruct.pack(self.BIT_FIELDS, *self)
[
self.object_type & 0xFF,
(self.sampling_frequency >> 4) & 0xFF,
(((self.sampling_frequency & 0x0F) << 4) | (self.channels << 2)) & 0xFF,
((self.vbr << 7) | ((self.bitrate >> 16) & 0x7F)) & 0xFF,
((self.bitrate >> 8) & 0xFF) & 0xFF,
self.bitrate & 0xFF,
]
)
def __str__(self) -> str: def __str__(self):
object_types = [ object_types = ['MPEG_2_AAC_LC', 'MPEG_4_AAC_LC', 'MPEG_4_AAC_LTP', 'MPEG_4_AAC_SCALABLE', '[4]', '[5]', '[6]', '[7]']
'MPEG_2_AAC_LC',
'MPEG_4_AAC_LC',
'MPEG_4_AAC_LTP',
'MPEG_4_AAC_SCALABLE',
'[4]',
'[5]',
'[6]',
'[7]',
]
channels = [1, 2] channels = [1, 2]
# pylint: disable=line-too-long return '\n'.join([
return '\n'.join( 'AacMediaCodecInformation(',
[ f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}',
'AacMediaCodecInformation(', f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}',
f' object_type: {",".join([str(x) for x in flags_to_list(self.object_type, object_types)])}', f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}',
f' sampling_frequency: {",".join([str(x) for x in flags_to_list(self.sampling_frequency, MPEG_2_4_AAC_SAMPLING_FREQUENCIES)])}', f' vbr: {self.vbr}',
f' channels: {",".join([str(x) for x in flags_to_list(self.channels, channels)])}', f' bitrate: {self.bitrate}'
f' vbr: {self.vbr}', ')'
f' bitrate: {self.bitrate}' ')', ])
]
)
@dataclasses.dataclass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class VendorSpecificMediaCodecInformation: class VendorSpecificMediaCodecInformation:
''' '''
A2DP spec - 4.7.2 Codec Specific Information Elements A2DP spec - 4.7.2 Codec Specific Information Elements
''' '''
vendor_id: int
codec_id: int
value: bytes
@staticmethod @staticmethod
def from_bytes(data: bytes) -> VendorSpecificMediaCodecInformation: def from_bytes(data):
(vendor_id, codec_id) = struct.unpack_from('<IH', data, 0) (vendor_id, codec_id) = struct.unpack_from('<IH', data, 0)
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:]) return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])
def __bytes__(self) -> bytes: def __init__(self, vendor_id, codec_id, value):
self.vendor_id = vendor_id
self.codec_id = codec_id
self.value = value
def __bytes__(self):
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value) return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
def __str__(self) -> str: def __str__(self):
# pylint: disable=line-too-long return '\n'.join([
return '\n'.join( 'VendorSpecificMediaCodecInformation(',
[ f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})',
'VendorSpecificMediaCodecInformation(', f' codec_id: {self.codec_id:04X}',
f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})', f' value: {self.value.hex()}'
f' codec_id: {self.codec_id:04X}', ')'
f' value: {self.value.hex()}' ')', ])
]
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass
class SbcFrame: class SbcFrame:
sampling_frequency: int def __init__(
block_count: int self,
channel_mode: int sampling_frequency,
subband_count: int block_count,
payload: bytes channel_mode,
subband_count,
payload
):
self.sampling_frequency = sampling_frequency
self.block_count = block_count
self.channel_mode = channel_mode
self.subband_count = subband_count
self.payload = payload
@property @property
def sample_count(self) -> int: def sample_count(self):
return self.subband_count * self.block_count return self.subband_count * self.block_count
@property @property
def bitrate(self) -> int: def bitrate(self):
return 8 * ((len(self.payload) * self.sampling_frequency) // self.sample_count) return 8 * ((len(self.payload) * self.sampling_frequency) // self.sample_count)
@property @property
def duration(self) -> float: def duration(self):
return self.sample_count / self.sampling_frequency return self.sample_count / self.sampling_frequency
def __str__(self) -> str: def __str__(self):
return ( return f'SBC(sf={self.sampling_frequency},cm={self.channel_mode},br={self.bitrate},sc={self.sample_count},size={len(self.payload)})'
f'SBC(sf={self.sampling_frequency},'
f'cm={self.channel_mode},'
f'br={self.bitrate},'
f'sc={self.sample_count},'
f'size={len(self.payload)})'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class SbcParser: class SbcParser:
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None: def __init__(self, read):
self.read = read self.read = read
@property @property
def frames(self) -> AsyncGenerator[SbcFrame, None]: def frames(self):
async def generate_frames() -> AsyncGenerator[SbcFrame, None]: async def generate_frames():
while True: while True:
# Read 4 bytes of header # Read 4 bytes of header
header = await self.read(4) header = await self.read(4)
@@ -580,53 +487,44 @@ class SbcParser:
# Extract some of the header fields # Extract some of the header fields
sampling_frequency = SBC_SAMPLING_FREQUENCIES[(header[1] >> 6) & 3] sampling_frequency = SBC_SAMPLING_FREQUENCIES[(header[1] >> 6) & 3]
blocks = 4 * (1 + ((header[1] >> 4) & 3)) blocks = 4 * (1 + ((header[1] >> 4) & 3))
channel_mode = (header[1] >> 2) & 3 channel_mode = (header[1] >> 2) & 3
channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2 channels = 1 if channel_mode == SBC_MONO_CHANNEL_MODE else 2
subbands = 8 if ((header[1]) & 1) else 4 subbands = 8 if ((header[1]) & 1) else 4
bitpool = header[2] bitpool = header[2]
# Compute the frame length # Compute the frame length
frame_length = 4 + (4 * subbands * channels) // 8 frame_length = 4 + (4 * subbands * channels) // 8
if channel_mode in (SBC_MONO_CHANNEL_MODE, SBC_DUAL_CHANNEL_MODE): if channel_mode in (SBC_MONO_CHANNEL_MODE, SBC_DUAL_CHANNEL_MODE):
frame_length += (blocks * channels * bitpool) // 8 frame_length += (blocks * channels * bitpool) // 8
else: else:
frame_length += ( frame_length += ((1 if channel_mode == SBC_JOINT_STEREO_CHANNEL_MODE else 0) * subbands + blocks * bitpool) // 8
(1 if channel_mode == SBC_JOINT_STEREO_CHANNEL_MODE else 0)
* subbands
+ blocks * bitpool
) // 8
# Read the rest of the frame # Read the rest of the frame
payload = header + await self.read(frame_length - 4) payload = header + await self.read(frame_length - 4)
# Emit the next frame # Emit the next frame
yield SbcFrame( yield SbcFrame(sampling_frequency, blocks, channel_mode, subbands, payload)
sampling_frequency, blocks, channel_mode, subbands, payload
)
return generate_frames() return generate_frames()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class SbcPacketSource: class SbcPacketSource:
def __init__( def __init__(self, read, mtu, codec_capabilities):
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities self.read = read
) -> None: self.mtu = mtu
self.read = read
self.mtu = mtu
self.codec_capabilities = codec_capabilities self.codec_capabilities = codec_capabilities
@property @property
def packets(self): def packets(self):
async def generate_packets(): async def generate_packets():
# pylint: disable=import-outside-toplevel
from .avdtp import MediaPacket # Import here to avoid a circular reference from .avdtp import MediaPacket # Import here to avoid a circular reference
sequence_number = 0 sequence_number = 0
timestamp = 0 timestamp = 0
frames = [] frames = []
frames_size = 0 frames_size = 0
max_rtp_payload = self.mtu - 12 - 1 max_rtp_payload = self.mtu - 12 - 1
# NOTE: this doesn't support frame fragments # NOTE: this doesn't support frame fragments
@@ -634,27 +532,18 @@ class SbcPacketSource:
async for frame in sbc_parser.frames: async for frame in sbc_parser.frames:
print(frame) print(frame)
if ( if frames_size + len(frame.payload) > max_rtp_payload or len(frames) == 16:
frames_size + len(frame.payload) > max_rtp_payload
or len(frames) == 16
):
# Need to flush what has been accumulated so far # Need to flush what has been accumulated so far
# Emit a packet # Emit a packet
sbc_payload = bytes([len(frames)]) + b''.join( sbc_payload = bytes([len(frames)]) + b''.join([frame.payload for frame in frames])
[frame.payload for frame in frames] packet = MediaPacket(2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload)
)
packet = MediaPacket(
2, 0, 0, 0, sequence_number, timestamp, 0, [], 96, sbc_payload
)
packet.timestamp_seconds = timestamp / frame.sampling_frequency packet.timestamp_seconds = timestamp / frame.sampling_frequency
yield packet yield packet
# Prepare for next packets # Prepare for next packets
sequence_number += 1 sequence_number += 1
sequence_number &= 0xFFFF timestamp += sum([frame.sample_count for frame in frames])
timestamp += sum((frame.sample_count for frame in frames))
timestamp &= 0xFFFFFFFF
frames = [frame] frames = [frame]
frames_size = len(frame.payload) frames_size = len(frame.payload)
else: else:
-85
View File
@@ -1,85 +0,0 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
def tokenize_parameters(buffer: bytes) -> List[bytes]:
"""Split input parameters into tokens.
Removes space characters outside of double quote blocks:
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
are ignored [..], unless they are embedded in numeric or string constants"
Raises ValueError in case of invalid input string."""
tokens = []
in_quotes = False
token = bytearray()
for b in buffer:
char = bytearray([b])
if in_quotes:
token.extend(char)
if char == b'\"':
in_quotes = False
tokens.append(token[1:-1])
token = bytearray()
else:
if char == b' ':
pass
elif char == b',' or char == b')':
tokens.append(token)
tokens.append(char)
token = bytearray()
elif char == b'(':
if len(token) > 0:
raise ValueError("open_paren following regular character")
tokens.append(char)
elif char == b'"':
if len(token) > 0:
raise ValueError("quote following regular character")
in_quotes = True
token.extend(char)
else:
token.extend(char)
tokens.append(token)
return [bytes(token) for token in tokens if len(token) > 0]
def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
"""Parse the parameters using the comma and parenthesis separators.
Raises ValueError in case of invalid input string."""
tokens = tokenize_parameters(buffer)
accumulator: List[list] = [[]]
current: Union[bytes, list] = bytes()
for token in tokens:
if token == b',':
accumulator[-1].append(current)
current = bytes()
elif token == b'(':
accumulator.append([])
elif token == b')':
if len(accumulator) < 2:
raise ValueError("close_paren without matching open_paren")
accumulator[-1].append(current)
current = accumulator.pop()
else:
current = token
accumulator[-1].append(current)
if len(accumulator) > 1:
raise ValueError("missing close_paren")
return accumulator[0]
+164 -368
View File
@@ -22,38 +22,15 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from colors import color
import enum
import functools
import inspect
import struct
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
Union,
TYPE_CHECKING,
)
from pyee import EventEmitter from pyee import EventEmitter
from bumble.core import UUID, name_or_number, ProtocolError from .core import *
from bumble.hci import HCI_Object, key_with_value from .hci import *
from bumble.colors import color
if TYPE_CHECKING:
from bumble.device import Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
ATT_CID = 0x04 ATT_CID = 0x04
ATT_ERROR_RESPONSE = 0x01 ATT_ERROR_RESPONSE = 0x01
@@ -186,31 +163,30 @@ ATT_ERROR_NAMES = {
ATT_DEFAULT_MTU = 23 ATT_DEFAULT_MTU = 23
HANDLE_FIELD_SPEC = {'size': 2, 'mapper': lambda x: f'0x{x:04X}'} HANDLE_FIELD_SPEC = {'size': 2, 'mapper': lambda x: f'0x{x:04X}'}
# pylint: disable-next=unnecessary-lambda-assignment,unnecessary-lambda UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y) # noqa: E731
UUID_2_16_FIELD_SPEC = lambda x, y: UUID.parse_uuid(x, y)
# pylint: disable-next=unnecessary-lambda-assignment,unnecessary-lambda
UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731 UUID_2_FIELD_SPEC = lambda x, y: UUID.parse_uuid_2(x, y) # noqa: E731
# fmt: on
# pylint: enable=line-too-long # -----------------------------------------------------------------------------
# pylint: disable=invalid-name # Utils
# -----------------------------------------------------------------------------
def key_with_value(dictionary, target_value):
for key, value in dictionary.items():
if value == target_value:
return key
return None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Exceptions # Exceptions
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ATT_Error(ProtocolError): class ATT_Error(Exception):
def __init__(self, error_code, att_handle=0x0000, message=''): def __init__(self, error_code, att_handle=0x0000):
super().__init__( self.error_code = error_code
error_code,
error_namespace='att',
error_name=ATT_PDU.error_name(error_code),
)
self.att_handle = att_handle self.att_handle = att_handle
self.message = message
def __str__(self): def __str__(self):
return f'ATT_Error(error={self.error_name}, handle={self.att_handle:04X}): {self.message}' return f'ATT_Error({ATT_PDU.error_name(self.error_code)})'
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -220,10 +196,8 @@ class ATT_PDU:
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.3 ATTRIBUTE PDU See Bluetooth spec @ Vol 3, Part F - 3.3 ATTRIBUTE PDU
''' '''
pdu_classes = {}
pdu_classes: Dict[int, Type[ATT_PDU]] = {}
op_code = 0 op_code = 0
name: str
@staticmethod @staticmethod
def from_bytes(pdu): def from_bytes(pdu):
@@ -300,13 +274,11 @@ class ATT_PDU:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass( @ATT_PDU.subclass([
[ ('request_opcode_in_error', {'size': 1, 'mapper': ATT_PDU.pdu_name}),
('request_opcode_in_error', {'size': 1, 'mapper': ATT_PDU.pdu_name}), ('attribute_handle_in_error', HANDLE_FIELD_SPEC),
('attribute_handle_in_error', HANDLE_FIELD_SPEC), ('error_code', {'size': 1, 'mapper': ATT_PDU.error_name})
('error_code', {'size': 1, 'mapper': ATT_PDU.error_name}), ])
]
)
class ATT_Error_Response(ATT_PDU): class ATT_Error_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response See Bluetooth spec @ Vol 3, Part F - 3.4.1.1 Error Response
@@ -314,7 +286,9 @@ class ATT_Error_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('client_rx_mtu', 2)]) @ATT_PDU.subclass([
('client_rx_mtu', 2)
])
class ATT_Exchange_MTU_Request(ATT_PDU): class ATT_Exchange_MTU_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.2.1 Exchange MTU Request See Bluetooth spec @ Vol 3, Part F - 3.4.2.1 Exchange MTU Request
@@ -322,7 +296,9 @@ class ATT_Exchange_MTU_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('server_rx_mtu', 2)]) @ATT_PDU.subclass([
('server_rx_mtu', 2)
])
class ATT_Exchange_MTU_Response(ATT_PDU): class ATT_Exchange_MTU_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.2.2 Exchange MTU Response See Bluetooth spec @ Vol 3, Part F - 3.4.2.2 Exchange MTU Response
@@ -330,9 +306,10 @@ class ATT_Exchange_MTU_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass( @ATT_PDU.subclass([
[('starting_handle', HANDLE_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC)] ('starting_handle', HANDLE_FIELD_SPEC),
) ('ending_handle', HANDLE_FIELD_SPEC)
])
class ATT_Find_Information_Request(ATT_PDU): class ATT_Find_Information_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.1 Find Information Request See Bluetooth spec @ Vol 3, Part F - 3.4.3.1 Find Information Request
@@ -340,7 +317,10 @@ class ATT_Find_Information_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('format', 1), ('information_data', '*')]) @ATT_PDU.subclass([
('format', 1),
('information_data', '*')
])
class ATT_Find_Information_Response(ATT_PDU): class ATT_Find_Information_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.2 Find Information Response See Bluetooth spec @ Vol 3, Part F - 3.4.3.2 Find Information Response
@@ -352,7 +332,7 @@ class ATT_Find_Information_Response(ATT_PDU):
uuid_size = 2 if self.format == 1 else 16 uuid_size = 2 if self.format == 1 else 16
while offset + uuid_size <= len(self.information_data): while offset + uuid_size <= len(self.information_data):
handle = struct.unpack_from('<H', self.information_data, offset)[0] handle = struct.unpack_from('<H', self.information_data, offset)[0]
uuid = self.information_data[2 + offset : 2 + offset + uuid_size] uuid = self.information_data[2 + offset:2 + offset + uuid_size]
self.information.append((handle, uuid)) self.information.append((handle, uuid))
offset += 2 + uuid_size offset += 2 + uuid_size
@@ -366,33 +346,20 @@ class ATT_Find_Information_Response(ATT_PDU):
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields( result += ':\n' + HCI_Object.format_fields(self.__dict__, [
self.__dict__, ('format', 1),
[ ('information', {'mapper': lambda x: ', '.join([f'0x{handle:04X}:{uuid.hex()}' for handle, uuid in x])})
('format', 1), ], ' ')
(
'information',
{
'mapper': lambda x: ', '.join(
[f'0x{handle:04X}:{uuid.hex()}' for handle, uuid in x]
)
},
),
],
' ',
)
return result return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass( @ATT_PDU.subclass([
[ ('starting_handle', HANDLE_FIELD_SPEC),
('starting_handle', HANDLE_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC), ('attribute_type', UUID_2_FIELD_SPEC),
('attribute_type', UUID_2_FIELD_SPEC), ('attribute_value', '*')
('attribute_value', '*'), ])
]
)
class ATT_Find_By_Type_Value_Request(ATT_PDU): class ATT_Find_By_Type_Value_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.3 Find By Type Value Request See Bluetooth spec @ Vol 3, Part F - 3.4.3.3 Find By Type Value Request
@@ -400,7 +367,9 @@ class ATT_Find_By_Type_Value_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('handles_information_list', '*')]) @ATT_PDU.subclass([
('handles_information_list', '*')
])
class ATT_Find_By_Type_Value_Response(ATT_PDU): class ATT_Find_By_Type_Value_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.3.4 Find By Type Value Response See Bluetooth spec @ Vol 3, Part F - 3.4.3.4 Find By Type Value Response
@@ -410,9 +379,7 @@ class ATT_Find_By_Type_Value_Response(ATT_PDU):
self.handles_information = [] self.handles_information = []
offset = 0 offset = 0
while offset + 4 <= len(self.handles_information_list): while offset + 4 <= len(self.handles_information_list):
found_attribute_handle, group_end_handle = struct.unpack_from( found_attribute_handle, group_end_handle = struct.unpack_from('<HH', self.handles_information_list, offset)
'<HH', self.handles_information_list, offset
)
self.handles_information.append((found_attribute_handle, group_end_handle)) self.handles_information.append((found_attribute_handle, group_end_handle))
offset += 4 offset += 4
@@ -426,34 +393,18 @@ class ATT_Find_By_Type_Value_Response(ATT_PDU):
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields( result += ':\n' + HCI_Object.format_fields(self.__dict__, [
self.__dict__, ('handles_information', {'mapper': lambda x: ', '.join([f'0x{handle1:04X}-0x{handle2:04X}' for handle1, handle2 in x])})
[ ], ' ')
(
'handles_information',
{
'mapper': lambda x: ', '.join(
[
f'0x{handle1:04X}-0x{handle2:04X}'
for handle1, handle2 in x
]
)
},
)
],
' ',
)
return result return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass( @ATT_PDU.subclass([
[ ('starting_handle', HANDLE_FIELD_SPEC),
('starting_handle', HANDLE_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC), ('attribute_type', UUID_2_16_FIELD_SPEC)
('attribute_type', UUID_2_16_FIELD_SPEC), ])
]
)
class ATT_Read_By_Type_Request(ATT_PDU): class ATT_Read_By_Type_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.1 Read By Type Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.1 Read By Type Request
@@ -461,7 +412,10 @@ class ATT_Read_By_Type_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('length', 1), ('attribute_data_list', '*')]) @ATT_PDU.subclass([
('length', 1),
('attribute_data_list', '*')
])
class ATT_Read_By_Type_Response(ATT_PDU): class ATT_Read_By_Type_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.2 Read By Type Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.2 Read By Type Response
@@ -470,15 +424,9 @@ class ATT_Read_By_Type_Response(ATT_PDU):
def parse_attribute_data_list(self): def parse_attribute_data_list(self):
self.attributes = [] self.attributes = []
offset = 0 offset = 0
while self.length != 0 and offset + self.length <= len( while self.length != 0 and offset + self.length <= len(self.attribute_data_list):
self.attribute_data_list attribute_handle, = struct.unpack_from('<H', self.attribute_data_list, offset)
): attribute_value = self.attribute_data_list[offset + 2:offset + self.length]
(attribute_handle,) = struct.unpack_from(
'<H', self.attribute_data_list, offset
)
attribute_value = self.attribute_data_list[
offset + 2 : offset + self.length
]
self.attributes.append((attribute_handle, attribute_value)) self.attributes.append((attribute_handle, attribute_value))
offset += self.length offset += self.length
@@ -492,26 +440,17 @@ class ATT_Read_By_Type_Response(ATT_PDU):
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields( result += ':\n' + HCI_Object.format_fields(self.__dict__, [
self.__dict__, ('length', 1),
[ ('attributes', {'mapper': lambda x: ', '.join([f'0x{handle:04X}:{value.hex()}' for handle, value in x])})
('length', 1), ], ' ')
(
'attributes',
{
'mapper': lambda x: ', '.join(
[f'0x{handle:04X}:{value.hex()}' for handle, value in x]
)
},
),
],
' ',
)
return result return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC)]) @ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC)
])
class ATT_Read_Request(ATT_PDU): class ATT_Read_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.3 Read Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.3 Read Request
@@ -519,7 +458,9 @@ class ATT_Read_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('attribute_value', '*')]) @ATT_PDU.subclass([
('attribute_value', '*')
])
class ATT_Read_Response(ATT_PDU): class ATT_Read_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.4 Read Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.4 Read Response
@@ -527,7 +468,10 @@ class ATT_Read_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('value_offset', 2)]) @ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('value_offset', 2)
])
class ATT_Read_Blob_Request(ATT_PDU): class ATT_Read_Blob_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.5 Read Blob Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.5 Read Blob Request
@@ -535,7 +479,9 @@ class ATT_Read_Blob_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('part_attribute_value', '*')]) @ATT_PDU.subclass([
('part_attribute_value', '*')
])
class ATT_Read_Blob_Response(ATT_PDU): class ATT_Read_Blob_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.6 Read Blob Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.6 Read Blob Response
@@ -543,7 +489,9 @@ class ATT_Read_Blob_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('set_of_handles', '*')]) @ATT_PDU.subclass([
('set_of_handles', '*')
])
class ATT_Read_Multiple_Request(ATT_PDU): class ATT_Read_Multiple_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.7 Read Multiple Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.7 Read Multiple Request
@@ -551,7 +499,9 @@ class ATT_Read_Multiple_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('set_of_values', '*')]) @ATT_PDU.subclass([
('set_of_values', '*')
])
class ATT_Read_Multiple_Response(ATT_PDU): class ATT_Read_Multiple_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.8 Read Multiple Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.8 Read Multiple Response
@@ -559,13 +509,11 @@ class ATT_Read_Multiple_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass( @ATT_PDU.subclass([
[ ('starting_handle', HANDLE_FIELD_SPEC),
('starting_handle', HANDLE_FIELD_SPEC), ('ending_handle', HANDLE_FIELD_SPEC),
('ending_handle', HANDLE_FIELD_SPEC), ('attribute_group_type', UUID_2_16_FIELD_SPEC)
('attribute_group_type', UUID_2_16_FIELD_SPEC), ])
]
)
class ATT_Read_By_Group_Type_Request(ATT_PDU): class ATT_Read_By_Group_Type_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.9 Read by Group Type Request See Bluetooth spec @ Vol 3, Part F - 3.4.4.9 Read by Group Type Request
@@ -573,7 +521,10 @@ class ATT_Read_By_Group_Type_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('length', 1), ('attribute_data_list', '*')]) @ATT_PDU.subclass([
('length', 1),
('attribute_data_list', '*')
])
class ATT_Read_By_Group_Type_Response(ATT_PDU): class ATT_Read_By_Group_Type_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.4.10 Read by Group Type Response See Bluetooth spec @ Vol 3, Part F - 3.4.4.10 Read by Group Type Response
@@ -582,18 +533,10 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU):
def parse_attribute_data_list(self): def parse_attribute_data_list(self):
self.attributes = [] self.attributes = []
offset = 0 offset = 0
while self.length != 0 and offset + self.length <= len( while self.length != 0 and offset + self.length <= len(self.attribute_data_list):
self.attribute_data_list attribute_handle, end_group_handle = struct.unpack_from('<HH', self.attribute_data_list, offset)
): attribute_value = self.attribute_data_list[offset + 4:offset + self.length]
attribute_handle, end_group_handle = struct.unpack_from( self.attributes.append((attribute_handle, end_group_handle, attribute_value))
'<HH', self.attribute_data_list, offset
)
attribute_value = self.attribute_data_list[
offset + 4 : offset + self.length
]
self.attributes.append(
(attribute_handle, end_group_handle, attribute_value)
)
offset += self.length offset += self.length
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@@ -606,29 +549,18 @@ class ATT_Read_By_Group_Type_Response(ATT_PDU):
def __str__(self): def __str__(self):
result = color(self.name, 'yellow') result = color(self.name, 'yellow')
result += ':\n' + HCI_Object.format_fields( result += ':\n' + HCI_Object.format_fields(self.__dict__, [
self.__dict__, ('length', 1),
[ ('attributes', {'mapper': lambda x: ', '.join([f'0x{handle:04X}-0x{end:04X}:{value.hex()}' for handle, end, value in x])})
('length', 1), ], ' ')
(
'attributes',
{
'mapper': lambda x: ', '.join(
[
f'0x{handle:04X}-0x{end:04X}:{value.hex()}'
for handle, end, value in x
]
)
},
),
],
' ',
)
return result return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')]) @ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Write_Request(ATT_PDU): class ATT_Write_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.1 Write Request See Bluetooth spec @ Vol 3, Part F - 3.4.5.1 Write Request
@@ -644,7 +576,10 @@ class ATT_Write_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')]) @ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Write_Command(ATT_PDU): class ATT_Write_Command(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.3 Write Command See Bluetooth spec @ Vol 3, Part F - 3.4.5.3 Write Command
@@ -652,13 +587,11 @@ class ATT_Write_Command(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass( @ATT_PDU.subclass([
[ ('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')
('attribute_value', '*'), # ('authentication_signature', 'TODO')
# ('authentication_signature', 'TODO') ])
]
)
class ATT_Signed_Write_Command(ATT_PDU): class ATT_Signed_Write_Command(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.5.4 Signed Write Command See Bluetooth spec @ Vol 3, Part F - 3.4.5.4 Signed Write Command
@@ -666,13 +599,11 @@ class ATT_Signed_Write_Command(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass( @ATT_PDU.subclass([
[ ('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_handle', HANDLE_FIELD_SPEC), ('value_offset', 2),
('value_offset', 2), ('part_attribute_value', '*')
('part_attribute_value', '*'), ])
]
)
class ATT_Prepare_Write_Request(ATT_PDU): class ATT_Prepare_Write_Request(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.6.1 Prepare Write Request See Bluetooth spec @ Vol 3, Part F - 3.4.6.1 Prepare Write Request
@@ -680,13 +611,11 @@ class ATT_Prepare_Write_Request(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass( @ATT_PDU.subclass([
[ ('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_handle', HANDLE_FIELD_SPEC), ('value_offset', 2),
('value_offset', 2), ('part_attribute_value', '*')
('part_attribute_value', '*'), ])
]
)
class ATT_Prepare_Write_Response(ATT_PDU): class ATT_Prepare_Write_Response(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.6.2 Prepare Write Response See Bluetooth spec @ Vol 3, Part F - 3.4.6.2 Prepare Write Response
@@ -710,7 +639,10 @@ class ATT_Execute_Write_Response(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')]) @ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Handle_Value_Notification(ATT_PDU): class ATT_Handle_Value_Notification(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.7.1 Handle Value Notification See Bluetooth spec @ Vol 3, Part F - 3.4.7.1 Handle Value Notification
@@ -718,7 +650,10 @@ class ATT_Handle_Value_Notification(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@ATT_PDU.subclass([('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_value', '*')]) @ATT_PDU.subclass([
('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*')
])
class ATT_Handle_Value_Indication(ATT_PDU): class ATT_Handle_Value_Indication(ATT_PDU):
''' '''
See Bluetooth spec @ Vol 3, Part F - 3.4.7.2 Handle Value Indication See Bluetooth spec @ Vol 3, Part F - 3.4.7.2 Handle Value Indication
@@ -733,200 +668,61 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
''' '''
# -----------------------------------------------------------------------------
class AttributeValue:
'''
Attribute value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(
self,
read: Union[
Callable[[Optional[Connection]], bytes],
Callable[[Optional[Connection]], Awaitable[bytes]],
None,
] = None,
write: Union[
Callable[[Optional[Connection], bytes], None],
Callable[[Optional[Connection], bytes], Awaitable[None]],
None,
] = None,
):
self._read = read
self._write = write
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
return self._read(connection) if self._read else b''
def write(
self, connection: Optional[Connection], value: bytes
) -> Union[Awaitable[None], None]:
if self._write:
return self._write(connection, value)
return None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Attribute(EventEmitter): class Attribute(EventEmitter):
class Permissions(enum.IntFlag): # Permission flags
READABLE = 0x01 READABLE = 0x01
WRITEABLE = 0x02 WRITEABLE = 0x02
READ_REQUIRES_ENCRYPTION = 0x04 READ_REQUIRES_ENCRYPTION = 0x04
WRITE_REQUIRES_ENCRYPTION = 0x08 WRITE_REQUIRES_ENCRYPTION = 0x08
READ_REQUIRES_AUTHENTICATION = 0x10 READ_REQUIRES_AUTHENTICATION = 0x10
WRITE_REQUIRES_AUTHENTICATION = 0x20 WRITE_REQUIRES_AUTHENTICATION = 0x20
READ_REQUIRES_AUTHORIZATION = 0x40 READ_REQUIRES_AUTHORIZATION = 0x40
WRITE_REQUIRES_AUTHORIZATION = 0x80 WRITE_REQUIRES_AUTHORIZATION = 0x80
@classmethod def __init__(self, attribute_type, permissions, value = b''):
def from_string(cls, permissions_str: str) -> Attribute.Permissions:
try:
return functools.reduce(
lambda x, y: x | Attribute.Permissions[y],
permissions_str.replace('|', ',').split(","),
Attribute.Permissions(0),
)
except TypeError as exc:
# The check for `p.name is not None` here is needed because for InFlag
# enums, the .name property can be None, when the enum value is 0,
# so the type hint for .name is Optional[str].
enum_list: List[str] = [p.name for p in cls if p.name is not None]
enum_list_str = ",".join(enum_list)
raise TypeError(
f"Attribute::permissions error:\nExpected a string containing any of the keys, separated by commas: {enum_list_str }\nGot: {permissions_str}"
) from exc
# Permission flags(legacy-use only)
READABLE = Permissions.READABLE
WRITEABLE = Permissions.WRITEABLE
READ_REQUIRES_ENCRYPTION = Permissions.READ_REQUIRES_ENCRYPTION
WRITE_REQUIRES_ENCRYPTION = Permissions.WRITE_REQUIRES_ENCRYPTION
READ_REQUIRES_AUTHENTICATION = Permissions.READ_REQUIRES_AUTHENTICATION
WRITE_REQUIRES_AUTHENTICATION = Permissions.WRITE_REQUIRES_AUTHENTICATION
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
value: Union[bytes, AttributeValue]
def __init__(
self,
attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, AttributeValue] = b'',
) -> None:
EventEmitter.__init__(self) EventEmitter.__init__(self)
self.handle = 0 self.handle = 0
self.end_group_handle = 0 self.end_group_handle = 0
if isinstance(permissions, str): self.permissions = permissions
self.permissions = Attribute.Permissions.from_string(permissions)
else:
self.permissions = permissions
# Convert the type to a UUID object if it isn't already # Convert the type to a UUID object if it isn't already
if isinstance(attribute_type, str): if type(attribute_type) is str:
self.type = UUID(attribute_type) self.type = UUID(attribute_type)
elif isinstance(attribute_type, bytes): elif type(attribute_type) is bytes:
self.type = UUID.from_bytes(attribute_type) self.type = UUID.from_bytes(attribute_type)
else: else:
self.type = attribute_type self.type = attribute_type
# Convert the value to a byte array # Convert the value to a byte array
if isinstance(value, str): if type(value) is str:
self.value = bytes(value, 'utf-8') self.value = bytes(value, 'utf-8')
else: else:
self.value = value self.value = value
def encode_value(self, value: Any) -> bytes: def read_value(self, connection):
return value if read := getattr(self.value, 'read', None):
def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
async def read_value(self, connection: Optional[Connection]) -> bytes:
if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
and not connection.encryption
):
raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
)
if (
(self.permissions & self.READ_REQUIRES_AUTHENTICATION)
and connection is not None
and not connection.authenticated
):
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
)
if self.permissions & self.READ_REQUIRES_AUTHORIZATION:
# TODO: handle authorization better
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
)
if hasattr(self.value, 'read'):
try: try:
value = self.value.read(connection) return read(connection)
if inspect.isawaitable(value):
value = await value
except ATT_Error as error: except ATT_Error as error:
raise ATT_Error( raise ATT_Error(error_code=error.error_code, att_handle=self.handle)
error_code=error.error_code, att_handle=self.handle
) from error
else: else:
value = self.value return self.value
return self.encode_value(value) def write_value(self, connection, value):
if write := getattr(self.value, 'write', None):
async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption:
raise ATT_Error(
error_code=ATT_INSUFFICIENT_ENCRYPTION_ERROR, att_handle=self.handle
)
if (
self.permissions & self.WRITE_REQUIRES_AUTHENTICATION
) and not connection.authenticated:
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHENTICATION_ERROR, att_handle=self.handle
)
if self.permissions & self.WRITE_REQUIRES_AUTHORIZATION:
# TODO: handle authorization better
raise ATT_Error(
error_code=ATT_INSUFFICIENT_AUTHORIZATION_ERROR, att_handle=self.handle
)
value = self.decode_value(value_bytes)
if hasattr(self.value, 'write'):
try: try:
result = self.value.write(connection, value) write(connection, value)
if inspect.isawaitable(result):
await result
except ATT_Error as error: except ATT_Error as error:
raise ATT_Error( raise ATT_Error(error_code=error.error_code, att_handle=self.handle)
error_code=error.error_code, att_handle=self.handle
) from error
else: else:
self.value = value self.value = value
self.emit('write', connection, value) self.emit('write', connection, value)
def __repr__(self): def __repr__(self):
if isinstance(self.value, bytes): if len(self.value) > 0:
value_str = self.value.hex()
else:
value_str = str(self.value)
if value_str:
value_string = f', value={self.value.hex()}' value_string = f', value={self.value.hex()}'
else: else:
value_string = '' value_string = ''
return ( return f'Attribute(handle=0x{self.handle:04X}, type={self.type}, permissions={self.permissions}{value_string})'
f'Attribute(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'permissions={self.permissions}{value_string})'
)
-520
View File
@@ -1,520 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import struct
from typing import Dict, Type, Union, Tuple
from bumble.utils import OpenIntEnum
# -----------------------------------------------------------------------------
class Frame:
class SubunitType(enum.IntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.4
MONITOR = 0x00
AUDIO = 0x01
PRINTER = 0x02
DISC = 0x03
TAPE_RECORDER_OR_PLAYER = 0x04
TUNER = 0x05
CA = 0x06
CAMERA = 0x07
PANEL = 0x09
BULLETIN_BOARD = 0x0A
VENDOR_UNIQUE = 0x1C
EXTENDED = 0x1E
UNIT = 0x1F
class OperationCode(OpenIntEnum):
# 0x00 - 0x0F: Unit and subunit commands
VENDOR_DEPENDENT = 0x00
RESERVE = 0x01
PLUG_INFO = 0x02
# 0x10 - 0x3F: Unit commands
DIGITAL_OUTPUT = 0x10
DIGITAL_INPUT = 0x11
CHANNEL_USAGE = 0x12
OUTPUT_PLUG_SIGNAL_FORMAT = 0x18
INPUT_PLUG_SIGNAL_FORMAT = 0x19
GENERAL_BUS_SETUP = 0x1F
CONNECT_AV = 0x20
DISCONNECT_AV = 0x21
CONNECTIONS = 0x22
CONNECT = 0x24
DISCONNECT = 0x25
UNIT_INFO = 0x30
SUBUNIT_INFO = 0x31
# 0x40 - 0x7F: Subunit commands
PASS_THROUGH = 0x7C
GUI_UPDATE = 0x7D
PUSH_GUI_DATA = 0x7E
USER_ACTION = 0x7F
# 0xA0 - 0xBF: Unit and subunit commands
VERSION = 0xB0
POWER = 0xB2
subunit_type: SubunitType
subunit_id: int
opcode: OperationCode
operands: bytes
@staticmethod
def subclass(subclass):
# Infer the opcode from the class name
if subclass.__name__.endswith("CommandFrame"):
short_name = subclass.__name__.replace("CommandFrame", "")
category_class = CommandFrame
elif subclass.__name__.endswith("ResponseFrame"):
short_name = subclass.__name__.replace("ResponseFrame", "")
category_class = ResponseFrame
else:
raise ValueError(f"invalid subclass name {subclass.__name__}")
uppercase_indexes = [
i for i in range(len(short_name)) if short_name[i].isupper()
]
uppercase_indexes.append(len(short_name))
words = [
short_name[uppercase_indexes[i] : uppercase_indexes[i + 1]].upper()
for i in range(len(uppercase_indexes) - 1)
]
opcode_name = "_".join(words)
opcode = Frame.OperationCode[opcode_name]
category_class.subclasses[opcode] = subclass
return subclass
@staticmethod
def from_bytes(data: bytes) -> Frame:
if data[0] >> 4 != 0:
raise ValueError("first 4 bits must be 0s")
ctype_or_response = data[0] & 0xF
subunit_type = Frame.SubunitType(data[1] >> 3)
subunit_id = data[1] & 7
if subunit_type == Frame.SubunitType.EXTENDED:
# Not supported
raise NotImplementedError("extended subunit types not supported")
if subunit_id < 5:
opcode_offset = 2
elif subunit_id == 5:
# Extended to the next byte
extension = data[2]
if extension == 0:
raise ValueError("extended subunit ID value reserved")
if extension == 0xFF:
subunit_id = 5 + 254 + data[3]
opcode_offset = 4
else:
subunit_id = 5 + extension
opcode_offset = 3
elif subunit_id == 6:
raise ValueError("reserved subunit ID")
opcode = Frame.OperationCode(data[opcode_offset])
operands = data[opcode_offset + 1 :]
# Look for a registered subclass
if ctype_or_response < 8:
# Command
ctype = CommandFrame.CommandType(ctype_or_response)
if c_subclass := CommandFrame.subclasses.get(opcode):
return c_subclass(
ctype,
subunit_type,
subunit_id,
*c_subclass.parse_operands(operands),
)
return CommandFrame(ctype, subunit_type, subunit_id, opcode, operands)
else:
# Response
response = ResponseFrame.ResponseCode(ctype_or_response)
if r_subclass := ResponseFrame.subclasses.get(opcode):
return r_subclass(
response,
subunit_type,
subunit_id,
*r_subclass.parse_operands(operands),
)
return ResponseFrame(response, subunit_type, subunit_id, opcode, operands)
def to_bytes(
self,
ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode],
) -> bytes:
# TODO: support extended subunit types and ids.
return (
bytes(
[
ctype_or_response,
self.subunit_type << 3 | self.subunit_id,
self.opcode,
]
)
+ self.operands
)
def to_string(self, extra: str) -> str:
return (
f"{self.__class__.__name__}({extra}"
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"opcode={self.opcode.name}, "
f"operands={self.operands.hex()})"
)
def __init__(
self,
subunit_type: SubunitType,
subunit_id: int,
opcode: OperationCode,
operands: bytes,
) -> None:
self.subunit_type = subunit_type
self.subunit_id = subunit_id
self.opcode = opcode
self.operands = operands
# -----------------------------------------------------------------------------
class CommandFrame(Frame):
class CommandType(OpenIntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.1
CONTROL = 0x00
STATUS = 0x01
SPECIFIC_INQUIRY = 0x02
NOTIFY = 0x03
GENERAL_INQUIRY = 0x04
subclasses: Dict[Frame.OperationCode, Type[CommandFrame]] = {}
ctype: CommandType
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
raise NotImplementedError
def __init__(
self,
ctype: CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
opcode: Frame.OperationCode,
operands: bytes,
) -> None:
super().__init__(subunit_type, subunit_id, opcode, operands)
self.ctype = ctype
def __bytes__(self):
return self.to_bytes(self.ctype)
def __str__(self):
return self.to_string(f"ctype={self.ctype.name}, ")
# -----------------------------------------------------------------------------
class ResponseFrame(Frame):
class ResponseCode(OpenIntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.2
NOT_IMPLEMENTED = 0x08
ACCEPTED = 0x09
REJECTED = 0x0A
IN_TRANSITION = 0x0B
IMPLEMENTED_OR_STABLE = 0x0C
CHANGED = 0x0D
INTERIM = 0x0F
subclasses: Dict[Frame.OperationCode, Type[ResponseFrame]] = {}
response: ResponseCode
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
raise NotImplementedError
def __init__(
self,
response: ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
opcode: Frame.OperationCode,
operands: bytes,
) -> None:
super().__init__(subunit_type, subunit_id, opcode, operands)
self.response = response
def __bytes__(self):
return self.to_bytes(self.response)
def __str__(self):
return self.to_string(f"response={self.response.name}, ")
# -----------------------------------------------------------------------------
class VendorDependentFrame:
company_id: int
vendor_dependent_data: bytes
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
return (
struct.unpack(">I", b"\x00" + operands[:3])[0],
operands[3:],
)
def make_operands(self) -> bytes:
return struct.pack(">I", self.company_id)[1:] + self.vendor_dependent_data
def __init__(self, company_id: int, vendor_dependent_data: bytes):
self.company_id = company_id
self.vendor_dependent_data = vendor_dependent_data
# -----------------------------------------------------------------------------
@Frame.subclass
class VendorDependentCommandFrame(VendorDependentFrame, CommandFrame):
def __init__(
self,
ctype: CommandFrame.CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
company_id: int,
vendor_dependent_data: bytes,
) -> None:
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
CommandFrame.__init__(
self,
ctype,
subunit_type,
subunit_id,
Frame.OperationCode.VENDOR_DEPENDENT,
self.make_operands(),
)
def __str__(self):
return (
f"VendorDependentCommandFrame(ctype={self.ctype.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"company_id=0x{self.company_id:06X}, "
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
)
# -----------------------------------------------------------------------------
@Frame.subclass
class VendorDependentResponseFrame(VendorDependentFrame, ResponseFrame):
def __init__(
self,
response: ResponseFrame.ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
company_id: int,
vendor_dependent_data: bytes,
) -> None:
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
ResponseFrame.__init__(
self,
response,
subunit_type,
subunit_id,
Frame.OperationCode.VENDOR_DEPENDENT,
self.make_operands(),
)
def __str__(self):
return (
f"VendorDependentResponseFrame(response={self.response.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"company_id=0x{self.company_id:06X}, "
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
)
# -----------------------------------------------------------------------------
class PassThroughFrame:
"""
See AV/C Panel Subunit Specification 1.1 - 9.4 PASS THROUGH control command
"""
class StateFlag(enum.IntEnum):
PRESSED = 0
RELEASED = 1
class OperationId(OpenIntEnum):
SELECT = 0x00
UP = 0x01
DOWN = 0x01
LEFT = 0x03
RIGHT = 0x04
RIGHT_UP = 0x05
RIGHT_DOWN = 0x06
LEFT_UP = 0x07
LEFT_DOWN = 0x08
ROOT_MENU = 0x09
SETUP_MENU = 0x0A
CONTENTS_MENU = 0x0B
FAVORITE_MENU = 0x0C
EXIT = 0x0D
NUMBER_0 = 0x20
NUMBER_1 = 0x21
NUMBER_2 = 0x22
NUMBER_3 = 0x23
NUMBER_4 = 0x24
NUMBER_5 = 0x25
NUMBER_6 = 0x26
NUMBER_7 = 0x27
NUMBER_8 = 0x28
NUMBER_9 = 0x29
DOT = 0x2A
ENTER = 0x2B
CLEAR = 0x2C
CHANNEL_UP = 0x30
CHANNEL_DOWN = 0x31
PREVIOUS_CHANNEL = 0x32
SOUND_SELECT = 0x33
INPUT_SELECT = 0x34
DISPLAY_INFORMATION = 0x35
HELP = 0x36
PAGE_UP = 0x37
PAGE_DOWN = 0x38
POWER = 0x40
VOLUME_UP = 0x41
VOLUME_DOWN = 0x42
MUTE = 0x43
PLAY = 0x44
STOP = 0x45
PAUSE = 0x46
RECORD = 0x47
REWIND = 0x48
FAST_FORWARD = 0x49
EJECT = 0x4A
FORWARD = 0x4B
BACKWARD = 0x4C
ANGLE = 0x50
SUBPICTURE = 0x51
F1 = 0x71
F2 = 0x72
F3 = 0x73
F4 = 0x74
F5 = 0x75
VENDOR_UNIQUE = 0x7E
state_flag: StateFlag
operation_id: OperationId
operation_data: bytes
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
return (
PassThroughFrame.StateFlag(operands[0] >> 7),
PassThroughFrame.OperationId(operands[0] & 0x7F),
operands[1 : 1 + operands[1]],
)
def make_operands(self):
return (
bytes([self.state_flag << 7 | self.operation_id, len(self.operation_data)])
+ self.operation_data
)
def __init__(
self,
state_flag: StateFlag,
operation_id: OperationId,
operation_data: bytes,
) -> None:
if len(operation_data) > 255:
raise ValueError("operation data must be <= 255 bytes")
self.state_flag = state_flag
self.operation_id = operation_id
self.operation_data = operation_data
# -----------------------------------------------------------------------------
@Frame.subclass
class PassThroughCommandFrame(PassThroughFrame, CommandFrame):
def __init__(
self,
ctype: CommandFrame.CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
state_flag: PassThroughFrame.StateFlag,
operation_id: PassThroughFrame.OperationId,
operation_data: bytes,
) -> None:
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
CommandFrame.__init__(
self,
ctype,
subunit_type,
subunit_id,
Frame.OperationCode.PASS_THROUGH,
self.make_operands(),
)
def __str__(self):
return (
f"PassThroughCommandFrame(ctype={self.ctype.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"state_flag={self.state_flag.name}, "
f"operation_id={self.operation_id.name}, "
f"operation_data={self.operation_data.hex()})"
)
# -----------------------------------------------------------------------------
@Frame.subclass
class PassThroughResponseFrame(PassThroughFrame, ResponseFrame):
def __init__(
self,
response: ResponseFrame.ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
state_flag: PassThroughFrame.StateFlag,
operation_id: PassThroughFrame.OperationId,
operation_data: bytes,
) -> None:
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
ResponseFrame.__init__(
self,
response,
subunit_type,
subunit_id,
Frame.OperationCode.PASS_THROUGH,
self.make_operands(),
)
def __str__(self):
return (
f"PassThroughResponseFrame(response={self.response.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"state_flag={self.state_flag.name}, "
f"operation_id={self.operation_id.name}, "
f"operation_data={self.operation_data.hex()})"
)
-291
View File
@@ -1,291 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from enum import IntEnum
import logging
import struct
from typing import Callable, cast, Dict, Optional
from bumble.colors import color
from bumble import avc
from bumble import l2cap
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
AVCTP_PSM = 0x0017
AVCTP_BROWSING_PSM = 0x001B
# -----------------------------------------------------------------------------
class MessageAssembler:
Callback = Callable[[int, bool, bool, int, bytes], None]
transaction_label: int
pid: int
c_r: int
ipid: int
payload: bytes
number_of_packets: int
packets_received: int
def __init__(self, callback: Callback) -> None:
self.callback = callback
self.reset()
def reset(self) -> None:
self.packets_received = 0
self.transaction_label = -1
self.pid = -1
self.c_r = -1
self.ipid = -1
self.payload = b''
self.number_of_packets = 0
self.packet_count = 0
def on_pdu(self, pdu: bytes) -> None:
self.packets_received += 1
transaction_label = pdu[0] >> 4
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
c_r = (pdu[0] >> 1) & 1
ipid = pdu[0] & 1
if c_r == 0 and ipid != 0:
logger.warning("invalid IPID in command frame")
self.reset()
return
pid_offset = 1
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START):
if self.transaction_label >= 0:
# We are already in a transaction
logger.warning("received START or SINGLE fragment while in transaction")
self.reset()
self.packets_received = 1
if packet_type == Protocol.PacketType.START:
self.number_of_packets = pdu[1]
pid_offset = 2
pid = struct.unpack_from(">H", pdu, pid_offset)[0]
self.payload += pdu[pid_offset + 2 :]
if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END):
if transaction_label != self.transaction_label:
logger.warning("transaction label does not match")
self.reset()
return
if pid != self.pid:
logger.warning("PID does not match")
self.reset()
return
if c_r != self.c_r:
logger.warning("C/R does not match")
self.reset()
return
if self.packets_received > self.number_of_packets:
logger.warning("too many fragments in transaction")
self.reset()
return
if packet_type == Protocol.PacketType.END:
if self.packets_received != self.number_of_packets:
logger.warning("premature END")
self.reset()
return
else:
self.transaction_label = transaction_label
self.c_r = c_r
self.ipid = ipid
self.pid = pid
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END):
self.on_message_complete()
def on_message_complete(self):
try:
self.callback(
self.transaction_label,
self.c_r == 0,
self.ipid != 0,
self.pid,
self.payload,
)
except Exception as error:
logger.exception(color(f"!!! exception in callback: {error}", "red"))
self.reset()
# -----------------------------------------------------------------------------
class Protocol:
CommandHandler = Callable[[int, avc.CommandFrame], None]
command_handlers: Dict[int, CommandHandler] # Command handlers, by PID
ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]
response_handlers: Dict[int, ResponseHandler] # Response handlers, by PID
next_transaction_label: int
message_assembler: MessageAssembler
class PacketType(IntEnum):
SINGLE = 0b00
START = 0b01
CONTINUE = 0b10
END = 0b11
def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None:
self.command_handlers = {}
self.response_handlers = {}
self.l2cap_channel = l2cap_channel
self.message_assembler = MessageAssembler(self.on_message)
# Register to receive PDUs from the channel
l2cap_channel.sink = self.on_pdu
l2cap_channel.on("open", self.on_l2cap_channel_open)
l2cap_channel.on("close", self.on_l2cap_channel_close)
def on_l2cap_channel_open(self):
logger.debug(color("<<< AVCTP channel open", "magenta"))
def on_l2cap_channel_close(self):
logger.debug(color("<<< AVCTP channel closed", "magenta"))
def on_pdu(self, pdu: bytes) -> None:
self.message_assembler.on_pdu(pdu)
def on_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
) -> None:
logger.debug(
f"<<< AVCTP Message: pid={pid}, "
f"transaction_label={transaction_label}, "
f"is_command={is_command}, "
f"ipid={ipid}, "
f"payload={payload.hex()}"
)
# Check for invalid PID responses.
if ipid:
logger.debug(f"received IPID for PID={pid}")
# Find the appropriate handler.
if is_command:
if pid not in self.command_handlers:
logger.warning(f"no command handler for PID {pid}")
self.send_ipid(transaction_label, pid)
return
command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload))
self.command_handlers[pid](transaction_label, command_frame)
else:
if pid not in self.response_handlers:
logger.warning(f"no response handler for PID {pid}")
return
# By convention, for an ipid, send a None payload to the response handler.
if ipid:
response_frame = None
else:
response_frame = cast(avc.ResponseFrame, avc.Frame.from_bytes(payload))
self.response_handlers[pid](transaction_label, response_frame)
def send_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
):
# TODO: fragment large messages
packet_type = Protocol.PacketType.SINGLE
pdu = (
struct.pack(
">BH",
transaction_label << 4
| packet_type << 2
| (0 if is_command else 1) << 1
| (1 if ipid else 0),
pid,
)
+ payload
)
self.l2cap_channel.send_pdu(pdu)
def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
logger.debug(
">>> AVCTP command: "
f"transaction_label={transaction_label}, "
f"pid={pid}, "
f"payload={payload.hex()}"
)
self.send_message(transaction_label, True, False, pid, payload)
def send_response(self, transaction_label: int, pid: int, payload: bytes):
logger.debug(
">>> AVCTP response: "
f"transaction_label={transaction_label}, "
f"pid={pid}, "
f"payload={payload.hex()}"
)
self.send_message(transaction_label, False, False, pid, payload)
def send_ipid(self, transaction_label: int, pid: int) -> None:
logger.debug(
">>> AVCTP ipid: " f"transaction_label={transaction_label}, " f"pid={pid}"
)
self.send_message(transaction_label, False, True, pid, b'')
def register_command_handler(
self, pid: int, handler: Protocol.CommandHandler
) -> None:
self.command_handlers[pid] = handler
def unregister_command_handler(
self, pid: int, handler: Protocol.CommandHandler
) -> None:
if pid not in self.command_handlers or self.command_handlers[pid] != handler:
raise ValueError("command handler not registered")
del self.command_handlers[pid]
def register_response_handler(
self, pid: int, handler: Protocol.ResponseHandler
) -> None:
self.response_handlers[pid] = handler
def unregister_response_handler(
self, pid: int, handler: Protocol.ResponseHandler
) -> None:
if pid not in self.response_handlers or self.response_handlers[pid] != handler:
raise ValueError("response handler not registered")
del self.response_handlers[pid]
+410 -735
View File
File diff suppressed because it is too large Load Diff
-1918
View File
File diff suppressed because it is too large Load Diff
+7 -7
View File
@@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
class HCI_Bridge: class HCI_Bridge:
class Forwarder: class Forwarder:
def __init__(self, hci_sink, sender_hci_sink, packet_filter, trace): def __init__(self, hci_sink, sender_hci_sink, packet_filter, trace):
self.hci_sink = hci_sink self.hci_sink = hci_sink
self.sender_hci_sink = sender_hci_sink self.sender_hci_sink = sender_hci_sink
self.packet_filter = packet_filter self.packet_filter = packet_filter
self.trace = trace self.trace = trace
def on_packet(self, packet): def on_packet(self, packet):
# Convert the packet bytes to an object # Convert the packet bytes to an object
@@ -61,15 +61,15 @@ class HCI_Bridge:
hci_host_sink, hci_host_sink,
hci_controller_source, hci_controller_source,
hci_controller_sink, hci_controller_sink,
host_to_controller_filter=None, host_to_controller_filter = None,
controller_to_host_filter=None, controller_to_host_filter = None
): ):
tracer = PacketTracer(emit_message=logger.info) tracer = PacketTracer(emit_message=logger.info)
host_to_controller_forwarder = HCI_Bridge.Forwarder( host_to_controller_forwarder = HCI_Bridge.Forwarder(
hci_controller_sink, hci_controller_sink,
hci_host_sink, hci_host_sink,
host_to_controller_filter, host_to_controller_filter,
lambda packet: tracer.trace(packet, 0), lambda packet: tracer.trace(packet, 0)
) )
hci_host_source.set_packet_sink(host_to_controller_forwarder) hci_host_source.set_packet_sink(host_to_controller_forwarder)
@@ -77,6 +77,6 @@ class HCI_Bridge:
hci_host_sink, hci_host_sink,
hci_controller_sink, hci_controller_sink,
controller_to_host_filter, controller_to_host_filter,
lambda packet: tracer.trace(packet, 1), lambda packet: tracer.trace(packet, 1)
) )
hci_controller_source.set_packet_sink(controller_to_host_forwarder) hci_controller_source.set_packet_sink(controller_to_host_forwarder)
-381
View File
@@ -1,381 +0,0 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass
# -----------------------------------------------------------------------------
class BitReader:
"""Simple but not optimized bit stream reader."""
data: bytes
bytes_position: int
bit_position: int
cache: int
bits_cached: int
def __init__(self, data: bytes):
self.data = data
self.byte_position = 0
self.bit_position = 0
self.cache = 0
self.bits_cached = 0
def read(self, bits: int) -> int:
""" "Read up to 32 bits."""
if bits > 32:
raise ValueError('maximum read size is 32')
if self.bits_cached >= bits:
# We have enough bits.
self.bits_cached -= bits
self.bit_position += bits
return (self.cache >> self.bits_cached) & ((1 << bits) - 1)
# Read more cache, up to 32 bits
feed_bytes = self.data[self.byte_position : self.byte_position + 4]
feed_size = len(feed_bytes)
feed_int = int.from_bytes(feed_bytes, byteorder='big')
if 8 * feed_size + self.bits_cached < bits:
raise ValueError('trying to read past the data')
self.byte_position += feed_size
# Combine the new cache and the old cache
cache = self.cache & ((1 << self.bits_cached) - 1)
new_bits = bits - self.bits_cached
self.bits_cached = 8 * feed_size - new_bits
result = (feed_int >> self.bits_cached) | (cache << new_bits)
self.cache = feed_int
self.bit_position += bits
return result
def read_bytes(self, count: int):
if self.bit_position + 8 * count > 8 * len(self.data):
raise ValueError('not enough data')
if self.bit_position % 8:
# Not byte aligned
result = bytearray(count)
for i in range(count):
result[i] = self.read(8)
return bytes(result)
# Byte aligned
self.byte_position = self.bit_position // 8
self.bits_cached = 0
self.cache = 0
offset = self.bit_position // 8
self.bit_position += 8 * count
return self.data[offset : offset + count]
def bits_left(self) -> int:
return (8 * len(self.data)) - self.bit_position
def skip(self, bits: int) -> None:
# Slow, but simple...
while bits:
if bits > 32:
self.read(32)
bits -= 32
else:
self.read(bits)
break
# -----------------------------------------------------------------------------
class AacAudioRtpPacket:
"""AAC payload encapsulated in an RTP packet payload"""
@staticmethod
def latm_value(reader: BitReader) -> int:
bytes_for_value = reader.read(2)
value = 0
for _ in range(bytes_for_value + 1):
value = value * 256 + reader.read(8)
return value
@staticmethod
def program_config_element(reader: BitReader):
raise ValueError('program_config_element not supported')
@dataclass
class GASpecificConfig:
def __init__(
self, reader: BitReader, channel_configuration: int, audio_object_type: int
) -> None:
# GASpecificConfig - ISO/EIC 14496-3 Table 4.1
frame_length_flag = reader.read(1)
depends_on_core_coder = reader.read(1)
if depends_on_core_coder:
self.core_coder_delay = reader.read(14)
extension_flag = reader.read(1)
if not channel_configuration:
AacAudioRtpPacket.program_config_element(reader)
if audio_object_type in (6, 20):
self.layer_nr = reader.read(3)
if extension_flag:
if audio_object_type == 22:
num_of_sub_frame = reader.read(5)
layer_length = reader.read(11)
if audio_object_type in (17, 19, 20, 23):
aac_section_data_resilience_flags = reader.read(1)
aac_scale_factor_data_resilience_flags = reader.read(1)
aac_spectral_data_resilience_flags = reader.read(1)
extension_flag_3 = reader.read(1)
if extension_flag_3 == 1:
raise ValueError('extensionFlag3 == 1 not supported')
@staticmethod
def audio_object_type(reader: BitReader):
# GetAudioObjectType - ISO/EIC 14496-3 Table 1.16
audio_object_type = reader.read(5)
if audio_object_type == 31:
audio_object_type = 32 + reader.read(6)
return audio_object_type
@dataclass
class AudioSpecificConfig:
audio_object_type: int
sampling_frequency_index: int
sampling_frequency: int
channel_configuration: int
sbr_present_flag: int
ps_present_flag: int
extension_audio_object_type: int
extension_sampling_frequency_index: int
extension_sampling_frequency: int
extension_channel_configuration: int
SAMPLING_FREQUENCIES = [
96000,
88200,
64000,
48000,
44100,
32000,
24000,
22050,
16000,
12000,
11025,
8000,
7350,
]
def __init__(self, reader: BitReader) -> None:
# AudioSpecificConfig - ISO/EIC 14496-3 Table 1.15
self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
self.sampling_frequency_index = reader.read(4)
if self.sampling_frequency_index == 0xF:
self.sampling_frequency = reader.read(24)
else:
self.sampling_frequency = self.SAMPLING_FREQUENCIES[
self.sampling_frequency_index
]
self.channel_configuration = reader.read(4)
self.sbr_present_flag = -1
self.ps_present_flag = -1
if self.audio_object_type in (5, 29):
self.extension_audio_object_type = 5
self.sbc_present_flag = 1
if self.audio_object_type == 29:
self.ps_present_flag = 1
self.extension_sampling_frequency_index = reader.read(4)
if self.extension_sampling_frequency_index == 0xF:
self.extension_sampling_frequency = reader.read(24)
else:
self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[
self.extension_sampling_frequency_index
]
self.audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
if self.audio_object_type == 22:
self.extension_channel_configuration = reader.read(4)
else:
self.extension_audio_object_type = 0
if self.audio_object_type in (1, 2, 3, 4, 6, 7, 17, 19, 20, 21, 22, 23):
ga_specific_config = AacAudioRtpPacket.GASpecificConfig(
reader, self.channel_configuration, self.audio_object_type
)
else:
raise ValueError(
f'audioObjectType {self.audio_object_type} not supported'
)
# if self.extension_audio_object_type != 5 and bits_to_decode >= 16:
# sync_extension_type = reader.read(11)
# if sync_extension_type == 0x2B7:
# self.extension_audio_object_type = AacAudioRtpPacket.audio_object_type(reader)
# if self.extension_audio_object_type == 5:
# self.sbr_present_flag = reader.read(1)
# if self.sbr_present_flag:
# self.extension_sampling_frequency_index = reader.read(4)
# if self.extension_sampling_frequency_index == 0xF:
# self.extension_sampling_frequency = reader.read(24)
# else:
# self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[self.extension_sampling_frequency_index]
# if bits_to_decode >= 12:
# sync_extension_type = reader.read(11)
# if sync_extension_type == 0x548:
# self.ps_present_flag = reader.read(1)
# elif self.extension_audio_object_type == 22:
# self.sbr_present_flag = reader.read(1)
# if self.sbr_present_flag:
# self.extension_sampling_frequency_index = reader.read(4)
# if self.extension_sampling_frequency_index == 0xF:
# self.extension_sampling_frequency = reader.read(24)
# else:
# self.extension_sampling_frequency = self.SAMPLING_FREQUENCIES[self.extension_sampling_frequency_index]
# self.extension_channel_configuration = reader.read(4)
@dataclass
class StreamMuxConfig:
other_data_present: int
other_data_len_bits: int
audio_specific_config: AacAudioRtpPacket.AudioSpecificConfig
def __init__(self, reader: BitReader) -> None:
# StreamMuxConfig - ISO/EIC 14496-3 Table 1.42
audio_mux_version = reader.read(1)
if audio_mux_version == 1:
audio_mux_version_a = reader.read(1)
else:
audio_mux_version_a = 0
if audio_mux_version_a != 0:
raise ValueError('audioMuxVersionA != 0 not supported')
if audio_mux_version == 1:
tara_buffer_fullness = AacAudioRtpPacket.latm_value(reader)
stream_cnt = 0
all_streams_same_time_framing = reader.read(1)
num_sub_frames = reader.read(6)
num_program = reader.read(4)
if num_program != 0:
raise ValueError('num_program != 0 not supported')
num_layer = reader.read(3)
if num_layer != 0:
raise ValueError('num_layer != 0 not supported')
if audio_mux_version == 0:
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
reader
)
else:
asc_len = AacAudioRtpPacket.latm_value(reader)
marker = reader.bit_position
self.audio_specific_config = AacAudioRtpPacket.AudioSpecificConfig(
reader
)
audio_specific_config_len = reader.bit_position - marker
if asc_len < audio_specific_config_len:
raise ValueError('audio_specific_config_len > asc_len')
asc_len -= audio_specific_config_len
reader.skip(asc_len)
frame_length_type = reader.read(3)
if frame_length_type == 0:
latm_buffer_fullness = reader.read(8)
elif frame_length_type == 1:
frame_length = reader.read(9)
else:
raise ValueError(f'frame_length_type {frame_length_type} not supported')
self.other_data_present = reader.read(1)
if self.other_data_present:
if audio_mux_version == 1:
self.other_data_len_bits = AacAudioRtpPacket.latm_value(reader)
else:
self.other_data_len_bits = 0
while True:
self.other_data_len_bits *= 256
other_data_len_esc = reader.read(1)
self.other_data_len_bits += reader.read(8)
if other_data_len_esc == 0:
break
crc_check_present = reader.read(1)
if crc_check_present:
crc_checksum = reader.read(8)
@dataclass
class AudioMuxElement:
payload: bytes
stream_mux_config: AacAudioRtpPacket.StreamMuxConfig
def __init__(self, reader: BitReader, mux_config_present: int):
if mux_config_present == 0:
raise ValueError('muxConfigPresent == 0 not supported')
# AudioMuxElement - ISO/EIC 14496-3 Table 1.41
use_same_stream_mux = reader.read(1)
if use_same_stream_mux:
raise ValueError('useSameStreamMux == 1 not supported')
self.stream_mux_config = AacAudioRtpPacket.StreamMuxConfig(reader)
# We only support:
# allStreamsSameTimeFraming == 1
# audioMuxVersionA == 0,
# numProgram == 0
# numSubFrames == 0
# numLayer == 0
mux_slot_length_bytes = 0
while True:
tmp = reader.read(8)
mux_slot_length_bytes += tmp
if tmp != 255:
break
self.payload = reader.read_bytes(mux_slot_length_bytes)
if self.stream_mux_config.other_data_present:
reader.skip(self.stream_mux_config.other_data_len_bits)
# ByteAlign
while reader.bit_position % 8:
reader.read(1)
def __init__(self, data: bytes) -> None:
# Parse the bit stream
reader = BitReader(data)
self.audio_mux_element = self.AudioMuxElement(reader, mux_config_present=1)
def to_adts(self):
# pylint: disable=line-too-long
sampling_frequency_index = (
self.audio_mux_element.stream_mux_config.audio_specific_config.sampling_frequency_index
)
channel_configuration = (
self.audio_mux_element.stream_mux_config.audio_specific_config.channel_configuration
)
frame_size = len(self.audio_mux_element.payload)
return (
bytes(
[
0xFF,
0xF1, # 0xF9 (MPEG2)
0x40
| (sampling_frequency_index << 2)
| (channel_configuration >> 2),
((channel_configuration & 0x3) << 6) | ((frame_size + 7) >> 11),
((frame_size + 7) >> 3) & 0xFF,
(((frame_size + 7) << 5) & 0xFF) | 0x1F,
0xFC,
]
)
+ self.audio_mux_element.payload
)
-103
View File
@@ -1,103 +0,0 @@
# Copyright (c) 2012 Giorgos Verigakis <verigak@gmail.com>
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from functools import partial
from typing import List, Optional, Union
# ANSI color names. There is also a "default"
COLORS = ('black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white')
# ANSI style names
STYLES = (
'none',
'bold',
'faint',
'italic',
'underline',
'blink',
'blink2',
'negative',
'concealed',
'crossed',
)
ColorSpec = Union[str, int]
def _join(*values: ColorSpec) -> str:
return ';'.join(str(v) for v in values)
def _color_code(spec: ColorSpec, base: int) -> str:
if isinstance(spec, str):
spec = spec.strip().lower()
if spec == 'default':
return _join(base + 9)
elif spec in COLORS:
return _join(base + COLORS.index(spec))
elif isinstance(spec, int) and 0 <= spec <= 255:
return _join(base + 8, 5, spec)
else:
raise ValueError('Invalid color spec "%s"' % spec)
def color(
s: str,
fg: Optional[ColorSpec] = None,
bg: Optional[ColorSpec] = None,
style: Optional[str] = None,
) -> str:
codes: List[ColorSpec] = []
if fg:
codes.append(_color_code(fg, 30))
if bg:
codes.append(_color_code(bg, 40))
if style:
for style_part in style.split('+'):
if style_part in STYLES:
codes.append(STYLES.index(style_part))
else:
raise ValueError('Invalid style "%s"' % style_part)
if codes:
return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)
else:
return s
# Foreground color shortcuts
black = partial(color, fg='black')
red = partial(color, fg='red')
green = partial(color, fg='green')
yellow = partial(color, fg='yellow')
blue = partial(color, fg='blue')
magenta = partial(color, fg='magenta')
cyan = partial(color, fg='cyan')
white = partial(color, fg='white')
# Style shortcuts
bold = partial(color, style='bold')
none = partial(color, style='none')
faint = partial(color, style='faint')
italic = partial(color, style='italic')
underline = partial(color, style='underline')
blink = partial(color, style='blink')
blink2 = partial(color, style='blink2')
negative = partial(color, style='negative')
concealed = partial(color, style='concealed')
crossed = partial(color, style='crossed')
+198 -839
View File
File diff suppressed because it is too large Load Diff
+295 -1077
View File
File diff suppressed because it is too large Load Diff
+237 -1027
View File
File diff suppressed because it is too large Load Diff
+86 -142
View File
@@ -21,24 +21,27 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import logging import logging
import operator import operator
import platform
import secrets if platform.system() != 'Emscripten':
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes import secrets
from cryptography.hazmat.primitives.asymmetric.ec import ( from cryptography.hazmat.primitives.ciphers import (
generate_private_key, Cipher,
ECDH, algorithms,
EllipticCurvePrivateKey, modes
EllipticCurvePublicNumbers, )
EllipticCurvePrivateNumbers, from cryptography.hazmat.primitives.asymmetric.ec import (
SECP256R1, generate_private_key,
) ECDH,
from cryptography.hazmat.primitives import cmac EllipticCurvePublicNumbers,
from typing import Tuple EllipticCurvePrivateNumbers,
SECP256R1
)
from cryptography.hazmat.primitives import cmac
else:
# TODO: implement stubs
pass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -50,43 +53,31 @@ logger = logging.getLogger(__name__)
# Classes # Classes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class EccKey: class EccKey:
def __init__(self, private_key: EllipticCurvePrivateKey) -> None: def __init__(self, private_key):
self.private_key = private_key self.private_key = private_key
@classmethod @classmethod
def generate(cls) -> EccKey: def generate(cls):
private_key = generate_private_key(SECP256R1()) private_key = generate_private_key(SECP256R1())
return cls(private_key) return cls(private_key)
@classmethod @classmethod
def from_private_key_bytes( def from_private_key_bytes(cls, d_bytes, x_bytes, y_bytes):
cls, d_bytes: bytes, x_bytes: bytes, y_bytes: bytes
) -> EccKey:
d = int.from_bytes(d_bytes, byteorder='big', signed=False) d = int.from_bytes(d_bytes, byteorder='big', signed=False)
x = int.from_bytes(x_bytes, byteorder='big', signed=False) x = int.from_bytes(x_bytes, byteorder='big', signed=False)
y = int.from_bytes(y_bytes, byteorder='big', signed=False) y = int.from_bytes(y_bytes, byteorder='big', signed=False)
private_key = EllipticCurvePrivateNumbers( private_key = EllipticCurvePrivateNumbers(d, EllipticCurvePublicNumbers(x, y, SECP256R1())).private_key()
d, EllipticCurvePublicNumbers(x, y, SECP256R1())
).private_key()
return cls(private_key) return cls(private_key)
@property @property
def x(self) -> bytes: def x(self):
return ( return self.private_key.public_key().public_numbers().x.to_bytes(32, byteorder='big')
self.private_key.public_key()
.public_numbers()
.x.to_bytes(32, byteorder='big')
)
@property @property
def y(self) -> bytes: def y(self):
return ( return self.private_key.public_key().public_numbers().y.to_bytes(32, byteorder='big')
self.private_key.public_key()
.public_numbers()
.y.to_bytes(32, byteorder='big')
)
def dh(self, public_key_x: bytes, public_key_y: bytes) -> bytes: def dh(self, public_key_x, public_key_y):
x = int.from_bytes(public_key_x, byteorder='big', signed=False) x = int.from_bytes(public_key_x, byteorder='big', signed=False)
y = int.from_bytes(public_key_y, byteorder='big', signed=False) y = int.from_bytes(public_key_y, byteorder='big', signed=False)
public_key = EllipticCurvePublicNumbers(x, y, SECP256R1()).public_key() public_key = EllipticCurvePublicNumbers(x, y, SECP256R1()).public_key()
@@ -99,33 +90,14 @@ class EccKey:
# Functions # Functions
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def generate_prand() -> bytes: def xor(x, y):
'''Generates random 3 bytes, with the 2 most significant bits of 0b01. assert(len(x) == len(y))
See Bluetooth spec, Vol 6, Part E - Table 1.2.
'''
prand_bytes = secrets.token_bytes(6)
return prand_bytes[:2] + bytes([(prand_bytes[2] & 0b01111111) | 0b01000000])
# -----------------------------------------------------------------------------
def xor(x: bytes, y: bytes) -> bytes:
assert len(x) == len(y)
return bytes(map(operator.xor, x, y)) return bytes(map(operator.xor, x, y))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def reverse(input: bytes) -> bytes: def r():
'''
Returns bytes of input in reversed endianness.
'''
return input[::-1]
# -----------------------------------------------------------------------------
def r() -> bytes:
''' '''
Generate 16 bytes of random data Generate 16 bytes of random data
''' '''
@@ -133,20 +105,20 @@ def r() -> bytes:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def e(key: bytes, data: bytes) -> bytes: def e(key, data):
''' '''
AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output. AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output.
See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e
''' '''
cipher = Cipher(algorithms.AES(reverse(key)), modes.ECB()) cipher = Cipher(algorithms.AES(bytes(reversed(key))), modes.ECB())
encryptor = cipher.encryptor() encryptor = cipher.encryptor()
return reverse(encryptor.update(reverse(data))) return bytes(reversed(encryptor.update(bytes(reversed(data)))))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def ah(k: bytes, r: bytes) -> bytes: # pylint: disable=redefined-outer-name def ah(k, r):
''' '''
See Bluetooth spec Vol 3, Part H - 2.2.2 Random Address Hash function ah See Bluetooth spec Vol 3, Part H - 2.2.2 Random Address Hash function ah
''' '''
@@ -157,19 +129,9 @@ def ah(k: bytes, r: bytes) -> bytes: # pylint: disable=redefined-outer-name
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def c1( def c1(k, r, preq, pres, iat, rat, ia, ra):
k: bytes,
r: bytes,
preq: bytes,
pres: bytes,
iat: int,
rat: int,
ia: bytes,
ra: bytes,
) -> bytes: # pylint: disable=redefined-outer-name
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for LE Legacy Pairing
LE Legacy Pairing
''' '''
p1 = bytes([iat, rat]) + preq + pres p1 = bytes([iat, rat]) + preq + pres
@@ -178,17 +140,16 @@ def c1(
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def s1(k: bytes, r1: bytes, r2: bytes) -> bytes: def s1(k, r1, r2):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy Pairing
Pairing
''' '''
return e(k, r2[0:8] + r1[0:8]) return e(k, r2[0:8] + r1[0:8])
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def aes_cmac(m: bytes, k: bytes) -> bytes: def aes_cmac(m, k):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC
@@ -200,100 +161,83 @@ def aes_cmac(m: bytes, k: bytes) -> bytes:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def f4(u: bytes, v: bytes, x: bytes, z: bytes) -> bytes: def f4(u, v, x, z):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value Generation Function f4
Generation Function f4
''' '''
return reverse(aes_cmac(reverse(u) + reverse(v) + z, reverse(x))) return bytes(reversed(aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def f5(w: bytes, n1: bytes, n2: bytes, a1: bytes, a2: bytes) -> Tuple[bytes, bytes]: def f5(w, n1, n2, a1, a2):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation Function f5
Function f5
NOTE: this returns a tuple: (MacKey, LTK) in little-endian byte order NOTE: this returns a tuple: (MacKey, LTK) in little-endian byte order
''' '''
salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE') salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE')
t = aes_cmac(reverse(w), salt) t = aes_cmac(bytes(reversed(w)), salt)
key_id = bytes([0x62, 0x74, 0x6C, 0x65]) key_id = bytes([0x62, 0x74, 0x6c, 0x65])
return ( return (
reverse( bytes(reversed(aes_cmac(
aes_cmac( bytes([0]) +
bytes([0]) key_id +
+ key_id bytes(reversed(n1)) +
+ reverse(n1) bytes(reversed(n2)) +
+ reverse(n2) bytes(reversed(a1)) +
+ reverse(a1) bytes(reversed(a2)) +
+ reverse(a2) bytes([1, 0]),
+ bytes([1, 0]), t
t, ))),
) bytes(reversed(aes_cmac(
), bytes([1]) +
reverse( key_id +
aes_cmac( bytes(reversed(n1)) +
bytes([1]) bytes(reversed(n2)) +
+ key_id bytes(reversed(a1)) +
+ reverse(n1) bytes(reversed(a2)) +
+ reverse(n2) bytes([1, 0]),
+ reverse(a1) t
+ reverse(a2) )))
+ bytes([1, 0]),
t,
)
),
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def f6( def f6(w, n1, n2, r, io_cap, a1, a2):
w: bytes, n1: bytes, n2: bytes, r: bytes, io_cap: bytes, a1: bytes, a2: bytes
) -> bytes: # pylint: disable=redefined-outer-name
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value Generation Function f6
Generation Function f6
''' '''
return reverse( return bytes(reversed(aes_cmac(
aes_cmac( bytes(reversed(n1)) +
reverse(n1) bytes(reversed(n2)) +
+ reverse(n2) bytes(reversed(r)) +
+ reverse(r) bytes(reversed(io_cap)) +
+ reverse(io_cap) bytes(reversed(a1)) +
+ reverse(a1) bytes(reversed(a2)),
+ reverse(a2), bytes(reversed(w))
reverse(w), )))
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def g2(u: bytes, v: bytes, x: bytes, y: bytes) -> int: def g2(u, v, x, y):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison Value Generation Function g2
Value Generation Function g2
''' '''
return int.from_bytes( return int.from_bytes(
aes_cmac( aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)), bytes(reversed(x)))[-4:],
reverse(u) + reverse(v) + reverse(y), byteorder='big'
reverse(x),
)[-4:],
byteorder='big',
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def h6(w: bytes, key_id: bytes) -> bytes: def h6(w, key_id):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.10 Link key conversion function h6 See Bluetooth spec, Vol 3, Part H - 2.2.10 Link key conversion function h6
''' '''
return reverse(aes_cmac(key_id, reverse(w))) return aes_cmac(key_id, w)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def h7(salt: bytes, w: bytes) -> bytes: def h7(salt, w):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.11 Link key conversion function h7 See Bluetooth spec, Vol 3, Part H - 2.2.11 Link key conversion function h7
''' '''
return reverse(aes_cmac(reverse(w), salt)) return aes_cmac(w, salt)
-416
View File
@@ -1,416 +0,0 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
WL = [-60, -30, 58, 172, 334, 538, 1198, 3042]
RL42 = [0, 7, 6, 5, 4, 3, 2, 1, 7, 6, 5, 4, 3, 2, 1, 0]
ILB = [
2048,
2093,
2139,
2186,
2233,
2282,
2332,
2383,
2435,
2489,
2543,
2599,
2656,
2714,
2774,
2834,
2896,
2960,
3025,
3091,
3158,
3228,
3298,
3371,
3444,
3520,
3597,
3676,
3756,
3838,
3922,
4008,
]
WH = [0, -214, 798]
RH2 = [2, 1, 2, 1]
# Values in QM2/QM4/QM6 left shift three bits than original g722 specification.
QM2 = [-7408, -1616, 7408, 1616]
QM4 = [
0,
-20456,
-12896,
-8968,
-6288,
-4240,
-2584,
-1200,
20456,
12896,
8968,
6288,
4240,
2584,
1200,
0,
]
QM6 = [
-136,
-136,
-136,
-136,
-24808,
-21904,
-19008,
-16704,
-14984,
-13512,
-12280,
-11192,
-10232,
-9360,
-8576,
-7856,
-7192,
-6576,
-6000,
-5456,
-4944,
-4464,
-4008,
-3576,
-3168,
-2776,
-2400,
-2032,
-1688,
-1360,
-1040,
-728,
24808,
21904,
19008,
16704,
14984,
13512,
12280,
11192,
10232,
9360,
8576,
7856,
7192,
6576,
6000,
5456,
4944,
4464,
4008,
3576,
3168,
2776,
2400,
2032,
1688,
1360,
1040,
728,
432,
136,
-432,
-136,
]
QMF_COEFFS = [3, -11, 12, 32, -210, 951, 3876, -805, 362, -156, 53, -11]
# fmt: on
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class G722Decoder(object):
"""G.722 decoder with bitrate 64kbit/s.
For the Blocks in the sub-band decoders, please refer to the G.722
specification for the required information. G722 specification:
https://www.itu.int/rec/T-REC-G.722-201209-I
"""
def __init__(self):
self._x = [0] * 24
self._band = [Band(), Band()]
# The initial value in BLOCK 3L
self._band[0].det = 32
# The initial value in BLOCK 3H
self._band[1].det = 8
def decode_frame(self, encoded_data) -> bytearray:
result_array = bytearray(len(encoded_data) * 4)
self.g722_decode(result_array, encoded_data)
return result_array
def g722_decode(self, result_array, encoded_data) -> int:
"""Decode the data frame using g722 decoder."""
result_length = 0
for code in encoded_data:
higher_bits = (code >> 6) & 0x03
lower_bits = code & 0x3F
rlow = self.lower_sub_band_decoder(lower_bits)
rhigh = self.higher_sub_band_decoder(higher_bits)
# Apply the receive QMF
self._x[:22] = self._x[2:]
self._x[22] = rlow + rhigh
self._x[23] = rlow - rhigh
xout2 = sum(self._x[2 * i] * QMF_COEFFS[i] for i in range(12))
xout1 = sum(self._x[2 * i + 1] * QMF_COEFFS[11 - i] for i in range(12))
result_length = self.update_decoded_result(
xout1, result_length, result_array
)
result_length = self.update_decoded_result(
xout2, result_length, result_array
)
return result_length
def update_decoded_result(self, xout, byte_length, byte_array) -> int:
result = (int)(xout >> 11)
bytes_result = result.to_bytes(2, 'little', signed=True)
byte_array[byte_length] = bytes_result[0]
byte_array[byte_length + 1] = bytes_result[1]
return byte_length + 2
def lower_sub_band_decoder(self, lower_bits) -> int:
"""Lower sub-band decoder for last six bits."""
# Block 5L
# INVQBL
wd1 = lower_bits
wd2 = QM6[wd1]
wd1 >>= 2
wd2 = (self._band[0].det * wd2) >> 15
# RECONS
rlow = self._band[0].s + wd2
# Block 6L
# LIMIT
if rlow > 16383:
rlow = 16383
elif rlow < -16384:
rlow = -16384
# Block 2L
# INVQAL
wd2 = QM4[wd1]
dlowt = (self._band[0].det * wd2) >> 15
# Block 3L
# LOGSCL
wd2 = RL42[wd1]
wd1 = (self._band[0].nb * 127) >> 7
wd1 += WL[wd2]
if wd1 < 0:
wd1 = 0
elif wd1 > 18432:
wd1 = 18432
self._band[0].nb = wd1
# SCALEL
wd1 = (self._band[0].nb >> 6) & 31
wd2 = 8 - (self._band[0].nb >> 11)
if wd2 < 0:
wd3 = ILB[wd1] << -wd2
else:
wd3 = ILB[wd1] >> wd2
self._band[0].det = wd3 << 2
# Block 4L
self._band[0].block4(dlowt)
return rlow
def higher_sub_band_decoder(self, higher_bits) -> int:
"""Higher sub-band decoder for first two bits."""
# Block 2H
# INVQAH
wd2 = QM2[higher_bits]
dhigh = (self._band[1].det * wd2) >> 15
# Block 5H
# RECONS
rhigh = dhigh + self._band[1].s
# Block 6H
# LIMIT
if rhigh > 16383:
rhigh = 16383
elif rhigh < -16384:
rhigh = -16384
# Block 3H
# LOGSCH
wd2 = RH2[higher_bits]
wd1 = (self._band[1].nb * 127) >> 7
wd1 += WH[wd2]
if wd1 < 0:
wd1 = 0
elif wd1 > 22528:
wd1 = 22528
self._band[1].nb = wd1
# SCALEH
wd1 = (self._band[1].nb >> 6) & 31
wd2 = 10 - (self._band[1].nb >> 11)
if wd2 < 0:
wd3 = ILB[wd1] << -wd2
else:
wd3 = ILB[wd1] >> wd2
self._band[1].det = wd3 << 2
# Block 4H
self._band[1].block4(dhigh)
return rhigh
# -----------------------------------------------------------------------------
class Band(object):
"""Structure for G722 decode proccessing."""
s: int = 0
nb: int = 0
det: int = 0
def __init__(self):
self._sp = 0
self._sz = 0
self._r = [0] * 3
self._a = [0] * 3
self._ap = [0] * 3
self._p = [0] * 3
self._d = [0] * 7
self._b = [0] * 7
self._bp = [0] * 7
self._sg = [0] * 7
def saturate(self, amp: int) -> int:
if amp > 32767:
return 32767
elif amp < -32768:
return -32768
else:
return amp
def block4(self, d: int) -> None:
"""Block4 for both lower and higher sub-band decoder."""
wd1 = 0
wd2 = 0
wd3 = 0
# RECONS
self._d[0] = d
self._r[0] = self.saturate(self.s + d)
# PARREC
self._p[0] = self.saturate(self._sz + d)
# UPPOL2
for i in range(3):
self._sg[i] = (self._p[i]) >> 15
wd1 = self.saturate((self._a[1]) << 2)
wd2 = -wd1 if self._sg[0] == self._sg[1] else wd1
if wd2 > 32767:
wd2 = 32767
wd3 = 128 if self._sg[0] == self._sg[2] else -128
wd3 += wd2 >> 7
wd3 += (self._a[2] * 32512) >> 15
if wd3 > 12288:
wd3 = 12288
elif wd3 < -12288:
wd3 = -12288
self._ap[2] = wd3
# UPPOL1
self._sg[0] = (self._p[0]) >> 15
self._sg[1] = (self._p[1]) >> 15
wd1 = 192 if self._sg[0] == self._sg[1] else -192
wd2 = (self._a[1] * 32640) >> 15
self._ap[1] = self.saturate(wd1 + wd2)
wd3 = self.saturate(15360 - self._ap[2])
if self._ap[1] > wd3:
self._ap[1] = wd3
elif self._ap[1] < -wd3:
self._ap[1] = -wd3
# UPZERO
wd1 = 0 if d == 0 else 128
self._sg[0] = d >> 15
for i in range(1, 7):
self._sg[i] = (self._d[i]) >> 15
wd2 = wd1 if self._sg[i] == self._sg[0] else -wd1
wd3 = (self._b[i] * 32640) >> 15
self._bp[i] = self.saturate(wd2 + wd3)
# DELAYA
for i in range(6, 0, -1):
self._d[i] = self._d[i - 1]
self._b[i] = self._bp[i]
for i in range(2, 0, -1):
self._r[i] = self._r[i - 1]
self._p[i] = self._p[i - 1]
self._a[i] = self._ap[i]
# FILTEP
self._sp = 0
for i in range(1, 3):
wd1 = self.saturate(self._r[i] + self._r[i])
self._sp += (self._a[i] * wd1) >> 15
self._sp = self.saturate(self._sp)
# FILTEZ
self._sz = 0
for i in range(6, 0, -1):
wd1 = self.saturate(self._d[i] + self._d[i])
self._sz += (self._b[i] * wd1) >> 15
self._sz = self.saturate(self._sz)
# PREDIC
self.s = self.saturate(self._sp + self._sz)
+615 -3927
View File
File diff suppressed because it is too large Load Diff
-87
View File
@@ -1,87 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Drivers that can be used to customize the interaction between a host and a controller,
like loading firmware after a cold start.
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import logging
import pathlib
import platform
from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING
from . import rtk, intel
from .common import Driver
if TYPE_CHECKING:
from bumble.host import Host
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Functions
# -----------------------------------------------------------------------------
async def get_driver_for_host(host: Host) -> Optional[Driver]:
"""Probe diver classes until one returns a valid instance for a host, or none is
found.
If a "driver" HCI metadata entry is present, only that driver class will be probed.
"""
driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver, "intel": intel.Driver}
probe_list: Iterable[str]
if driver_name := host.hci_metadata.get("driver"):
# Only probe a single driver
probe_list = [driver_name]
else:
# Probe all drivers
probe_list = driver_classes.keys()
for driver_name in probe_list:
if driver_class := driver_classes.get(driver_name):
logger.debug(f"Probing driver class: {driver_name}")
if driver := await driver_class.for_host(host):
logger.debug(f"Instantiated {driver_name} driver")
return driver
else:
logger.debug(f"Skipping unknown driver class: {driver_name}")
return None
def project_data_dir() -> pathlib.Path:
"""
Returns:
A path to an OS-specific directory for bumble data. The directory is created if
it doesn't exist.
"""
import platformdirs
if platform.system() == 'Darwin':
# platformdirs doesn't handle macOS right: it doesn't assemble a bundle id
# out of author & project
return platformdirs.user_data_path(
appname='com.google.bumble', ensure_exists=True
)
else:
# windows and linux don't use the com qualifier
return platformdirs.user_data_path(
appname='bumble', appauthor='google', ensure_exists=True
)
-45
View File
@@ -1,45 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common types for drivers.
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import abc
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Driver(abc.ABC):
"""Base class for drivers."""
@staticmethod
async def for_host(_host):
"""Return a driver instance for a host.
Args:
host: Host object for which a driver should be created.
Returns:
A Driver instance if a driver should be instantiated for this host, or
None if no driver instance of this class is needed.
"""
return None
@abc.abstractmethod
async def init_controller(self):
"""Initialize the controller."""
-102
View File
@@ -1,102 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
from bumble.drivers import common
from bumble.hci import (
hci_vendor_command_op_code, # type: ignore
HCI_Command,
HCI_Reset_Command,
)
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constant
# -----------------------------------------------------------------------------
INTEL_USB_PRODUCTS = {
# Intel AX210
(0x8087, 0x0032),
# Intel BE200
(0x8087, 0x0036),
}
# -----------------------------------------------------------------------------
# HCI Commands
# -----------------------------------------------------------------------------
HCI_INTEL_DDC_CONFIG_WRITE_COMMAND = hci_vendor_command_op_code(0xFC8B) # type: ignore
HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD = [0x03, 0xE4, 0x02, 0x00]
HCI_Command.register_commands(globals())
@HCI_Command.command( # type: ignore
fields=[("params", "*")],
return_parameters_fields=[
("params", "*"),
],
)
class Hci_Intel_DDC_Config_Write_Command(HCI_Command):
pass
class Driver(common.Driver):
def __init__(self, host):
self.host = host
@staticmethod
def check(host):
driver = host.hci_metadata.get("driver")
if driver == "intel":
return True
vendor_id = host.hci_metadata.get("vendor_id")
product_id = host.hci_metadata.get("product_id")
if vendor_id is None or product_id is None:
logger.debug("USB metadata not sufficient")
return False
if (vendor_id, product_id) not in INTEL_USB_PRODUCTS:
logger.debug(
f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
)
return False
return True
@classmethod
async def for_host(cls, host, force=False): # type: ignore
# Only instantiate this driver if explicitly selected
if not force and not cls.check(host):
return None
return cls(host)
async def init_controller(self):
self.host.ready = True
await self.host.send_command(HCI_Reset_Command(), check_result=True)
await self.host.send_command(
Hci_Intel_DDC_Config_Write_Command(
params=HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD
)
)
-666
View File
@@ -1,666 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Support for Realtek USB dongles.
Based on various online bits of information, including the Linux kernel.
(see `drivers/bluetooth/btrtl.c`)
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from dataclasses import dataclass
import asyncio
import enum
import logging
import math
import os
import pathlib
import platform
import struct
from typing import Tuple
import weakref
from bumble.hci import (
hci_vendor_command_op_code,
STATUS_SPEC,
HCI_SUCCESS,
HCI_Command,
HCI_Reset_Command,
HCI_Read_Local_Version_Information_Command,
)
from bumble.drivers import common
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
RTK_ROM_LMP_8723A = 0x1200
RTK_ROM_LMP_8723B = 0x8723
RTK_ROM_LMP_8821A = 0x8821
RTK_ROM_LMP_8761A = 0x8761
RTK_ROM_LMP_8822B = 0x8822
RTK_ROM_LMP_8852A = 0x8852
RTK_CONFIG_MAGIC = 0x8723AB55
RTK_EPATCH_SIGNATURE = b"Realtech"
RTK_FRAGMENT_LENGTH = 252
RTK_FIRMWARE_DIR_ENV = "BUMBLE_RTK_FIRMWARE_DIR"
RTK_LINUX_FIRMWARE_DIR = "/lib/firmware/rtl_bt"
class RtlProjectId(enum.IntEnum):
PROJECT_ID_8723A = 0
PROJECT_ID_8723B = 1
PROJECT_ID_8821A = 2
PROJECT_ID_8761A = 3
PROJECT_ID_8822B = 8
PROJECT_ID_8723D = 9
PROJECT_ID_8821C = 10
PROJECT_ID_8822C = 13
PROJECT_ID_8761B = 14
PROJECT_ID_8852A = 18
PROJECT_ID_8852B = 20
PROJECT_ID_8852C = 25
RTK_PROJECT_ID_TO_ROM = {
0: RTK_ROM_LMP_8723A,
1: RTK_ROM_LMP_8723B,
2: RTK_ROM_LMP_8821A,
3: RTK_ROM_LMP_8761A,
8: RTK_ROM_LMP_8822B,
9: RTK_ROM_LMP_8723B,
10: RTK_ROM_LMP_8821A,
13: RTK_ROM_LMP_8822B,
14: RTK_ROM_LMP_8761A,
18: RTK_ROM_LMP_8852A,
20: RTK_ROM_LMP_8852A,
25: RTK_ROM_LMP_8852A,
}
# List of USB (VendorID, ProductID) for Realtek-based devices.
RTK_USB_PRODUCTS = {
# Realtek 8723AE
(0x0930, 0x021D),
(0x13D3, 0x3394),
# Realtek 8723BE
(0x0489, 0xE085),
(0x0489, 0xE08B),
(0x04F2, 0xB49F),
(0x13D3, 0x3410),
(0x13D3, 0x3416),
(0x13D3, 0x3459),
(0x13D3, 0x3494),
# Realtek 8723BU
(0x7392, 0xA611),
# Realtek 8723DE
(0x0BDA, 0xB009),
(0x2FF8, 0xB011),
# Realtek 8761BUV
(0x0B05, 0x190E),
(0x0BDA, 0x8771),
(0x2230, 0x0016),
(0x2357, 0x0604),
(0x2550, 0x8761),
(0x2B89, 0x8761),
(0x7392, 0xC611),
(0x0BDA, 0x877B),
# Realtek 8821AE
(0x0B05, 0x17DC),
(0x13D3, 0x3414),
(0x13D3, 0x3458),
(0x13D3, 0x3461),
(0x13D3, 0x3462),
# Realtek 8821CE
(0x0BDA, 0xB00C),
(0x0BDA, 0xC822),
(0x13D3, 0x3529),
# Realtek 8822BE
(0x0B05, 0x185C),
(0x13D3, 0x3526),
# Realtek 8822CE
(0x04C5, 0x161F),
(0x04CA, 0x4005),
(0x0B05, 0x18EF),
(0x0BDA, 0xB00C),
(0x0BDA, 0xC123),
(0x0BDA, 0xC822),
(0x0CB5, 0xC547),
(0x1358, 0xC123),
(0x13D3, 0x3548),
(0x13D3, 0x3549),
(0x13D3, 0x3553),
(0x13D3, 0x3555),
(0x2FF8, 0x3051),
# Realtek 8822CU
(0x13D3, 0x3549),
# Realtek 8852AE
(0x04C5, 0x165C),
(0x04CA, 0x4006),
(0x0BDA, 0x2852),
(0x0BDA, 0x385A),
(0x0BDA, 0x4852),
(0x0BDA, 0xC852),
(0x0CB8, 0xC549),
# Realtek 8852BE
(0x0BDA, 0x887B),
(0x0CB8, 0xC559),
(0x13D3, 0x3571),
# Realtek 8852CE
(0x04C5, 0x1675),
(0x04CA, 0x4007),
(0x0CB8, 0xC558),
(0x13D3, 0x3586),
(0x13D3, 0x3587),
(0x13D3, 0x3592),
}
# -----------------------------------------------------------------------------
# HCI Commands
# -----------------------------------------------------------------------------
HCI_RTK_READ_ROM_VERSION_COMMAND = hci_vendor_command_op_code(0x6D)
HCI_RTK_DOWNLOAD_COMMAND = hci_vendor_command_op_code(0x20)
HCI_RTK_DROP_FIRMWARE_COMMAND = hci_vendor_command_op_code(0x66)
HCI_Command.register_commands(globals())
@HCI_Command.command(return_parameters_fields=[("status", STATUS_SPEC), ("version", 1)])
class HCI_RTK_Read_ROM_Version_Command(HCI_Command):
pass
@HCI_Command.command(
fields=[("index", 1), ("payload", RTK_FRAGMENT_LENGTH)],
return_parameters_fields=[("status", STATUS_SPEC), ("index", 1)],
)
class HCI_RTK_Download_Command(HCI_Command):
pass
@HCI_Command.command()
class HCI_RTK_Drop_Firmware_Command(HCI_Command):
pass
# -----------------------------------------------------------------------------
class Firmware:
def __init__(self, firmware):
extension_sig = bytes([0x51, 0x04, 0xFD, 0x77])
if not firmware.startswith(RTK_EPATCH_SIGNATURE):
raise ValueError("Firmware does not start with epatch signature")
if not firmware.endswith(extension_sig):
raise ValueError("Firmware does not end with extension sig")
# The firmware should start with a 14 byte header.
epatch_header_size = 14
if len(firmware) < epatch_header_size:
raise ValueError("Firmware too short")
# Look for the "project ID", starting from the end.
offset = len(firmware) - len(extension_sig)
project_id = -1
while offset >= epatch_header_size:
length, opcode = firmware[offset - 2 : offset]
offset -= 2
if opcode == 0xFF:
# End
break
if length == 0:
raise ValueError("Invalid 0-length instruction")
if opcode == 0 and length == 1:
project_id = firmware[offset - 1]
break
offset -= length
if project_id < 0:
raise ValueError("Project ID not found")
self.project_id = project_id
# Read the patch tables info.
self.version, num_patches = struct.unpack("<IH", firmware[8:14])
self.patches = []
# The patches tables are laid out as:
# <ChipID_1><ChipID_2>...<ChipID_N> (16 bits each)
# <PatchLength_1><PatchLength_2>...<PatchLength_N> (16 bits each)
# <PatchOffset_1><PatchOffset_2>...<PatchOffset_N> (32 bits each)
if epatch_header_size + 8 * num_patches > len(firmware):
raise ValueError("Firmware too short")
chip_id_table_offset = epatch_header_size
patch_length_table_offset = chip_id_table_offset + 2 * num_patches
patch_offset_table_offset = chip_id_table_offset + 4 * num_patches
for patch_index in range(num_patches):
chip_id_offset = chip_id_table_offset + 2 * patch_index
(chip_id,) = struct.unpack_from("<H", firmware, chip_id_offset)
(patch_length,) = struct.unpack_from(
"<H", firmware, patch_length_table_offset + 2 * patch_index
)
(patch_offset,) = struct.unpack_from(
"<I", firmware, patch_offset_table_offset + 4 * patch_index
)
if patch_offset + patch_length > len(firmware):
raise ValueError("Firmware too short")
# Get the SVN version for the patch
(svn_version,) = struct.unpack_from(
"<I", firmware, patch_offset + patch_length - 8
)
# Create a payload with the patch, replacing the last 4 bytes with
# the firmware version.
self.patches.append(
(
chip_id,
firmware[patch_offset : patch_offset + patch_length - 4]
+ struct.pack("<I", self.version),
svn_version,
)
)
class Driver(common.Driver):
@dataclass
class DriverInfo:
rom: int
hci: Tuple[int, int]
config_needed: bool
has_rom_version: bool
has_msft_ext: bool = False
fw_name: str = ""
config_name: str = ""
DRIVER_INFOS = [
# 8723A
DriverInfo(
rom=RTK_ROM_LMP_8723A,
hci=(0x0B, 0x06),
config_needed=False,
has_rom_version=False,
fw_name="rtl8723a_fw.bin",
config_name="",
),
# 8723B
DriverInfo(
rom=RTK_ROM_LMP_8723B,
hci=(0x0B, 0x06),
config_needed=False,
has_rom_version=True,
fw_name="rtl8723b_fw.bin",
config_name="rtl8723b_config.bin",
),
# 8723D
DriverInfo(
rom=RTK_ROM_LMP_8723B,
hci=(0x0D, 0x08),
config_needed=True,
has_rom_version=True,
fw_name="rtl8723d_fw.bin",
config_name="rtl8723d_config.bin",
),
# 8821A
DriverInfo(
rom=RTK_ROM_LMP_8821A,
hci=(0x0A, 0x06),
config_needed=False,
has_rom_version=True,
fw_name="rtl8821a_fw.bin",
config_name="rtl8821a_config.bin",
),
# 8821C
DriverInfo(
rom=RTK_ROM_LMP_8821A,
hci=(0x0C, 0x08),
config_needed=False,
has_rom_version=True,
has_msft_ext=True,
fw_name="rtl8821c_fw.bin",
config_name="rtl8821c_config.bin",
),
# 8761A
DriverInfo(
rom=RTK_ROM_LMP_8761A,
hci=(0x0A, 0x06),
config_needed=False,
has_rom_version=True,
fw_name="rtl8761a_fw.bin",
config_name="rtl8761a_config.bin",
),
# 8761BU
DriverInfo(
rom=RTK_ROM_LMP_8761A,
hci=(0x0B, 0x0A),
config_needed=False,
has_rom_version=True,
fw_name="rtl8761bu_fw.bin",
config_name="rtl8761bu_config.bin",
),
# 8822C
DriverInfo(
rom=RTK_ROM_LMP_8822B,
hci=(0x0C, 0x0A),
config_needed=False,
has_rom_version=True,
has_msft_ext=True,
fw_name="rtl8822cu_fw.bin",
config_name="rtl8822cu_config.bin",
),
# 8822B
DriverInfo(
rom=RTK_ROM_LMP_8822B,
hci=(0x0B, 0x07),
config_needed=True,
has_rom_version=True,
has_msft_ext=True,
fw_name="rtl8822b_fw.bin",
config_name="rtl8822b_config.bin",
),
# 8852A
DriverInfo(
rom=RTK_ROM_LMP_8852A,
hci=(0x0A, 0x0B),
config_needed=False,
has_rom_version=True,
has_msft_ext=True,
fw_name="rtl8852au_fw.bin",
config_name="rtl8852au_config.bin",
),
# 8852B
DriverInfo(
rom=RTK_ROM_LMP_8852A,
hci=(0xB, 0xB),
config_needed=False,
has_rom_version=True,
has_msft_ext=True,
fw_name="rtl8852bu_fw.bin",
config_name="rtl8852bu_config.bin",
),
# 8852C
DriverInfo(
rom=RTK_ROM_LMP_8852A,
hci=(0x0C, 0x0C),
config_needed=False,
has_rom_version=True,
has_msft_ext=True,
fw_name="rtl8852cu_fw.bin",
config_name="rtl8852cu_config.bin",
),
]
POST_DROP_DELAY = 0.2
@staticmethod
def find_driver_info(hci_version, hci_subversion, lmp_subversion):
for driver_info in Driver.DRIVER_INFOS:
if driver_info.rom == lmp_subversion and driver_info.hci == (
hci_subversion,
hci_version,
):
return driver_info
return None
@staticmethod
def find_binary_path(file_name):
# First check if an environment variable is set
if RTK_FIRMWARE_DIR_ENV in os.environ:
if (
path := pathlib.Path(os.environ[RTK_FIRMWARE_DIR_ENV]) / file_name
).is_file():
logger.debug(f"{file_name} found in env dir")
return path
# When the environment variable is set, don't look elsewhere
return None
# Then, look where the firmware download tool writes by default
if (path := rtk_firmware_dir() / file_name).is_file():
logger.debug(f"{file_name} found in project data dir")
return path
# Then, look in the package's driver directory
if (path := pathlib.Path(__file__).parent / "rtk_fw" / file_name).is_file():
logger.debug(f"{file_name} found in package dir")
return path
# On Linux, check the system's FW directory
if (
platform.system() == "Linux"
and (path := pathlib.Path(RTK_LINUX_FIRMWARE_DIR) / file_name).is_file()
):
logger.debug(f"{file_name} found in Linux system FW dir")
return path
# Finally look in the current directory
if (path := pathlib.Path.cwd() / file_name).is_file():
logger.debug(f"{file_name} found in CWD")
return path
return None
@staticmethod
def check(host):
if not host.hci_metadata:
logger.debug("USB metadata not found")
return False
if host.hci_metadata.get('driver') == 'rtk':
# Forced driver
return True
vendor_id = host.hci_metadata.get("vendor_id")
product_id = host.hci_metadata.get("product_id")
if vendor_id is None or product_id is None:
logger.debug("USB metadata not sufficient")
return False
if (vendor_id, product_id) not in RTK_USB_PRODUCTS:
logger.debug(
f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
)
return False
return True
@classmethod
async def driver_info_for_host(cls, host):
await host.send_command(HCI_Reset_Command(), check_result=True)
host.ready = True # Needed to let the host know the controller is ready.
response = await host.send_command(
HCI_Read_Local_Version_Information_Command(), check_result=True
)
local_version = response.return_parameters
logger.debug(
f"looking for a driver: 0x{local_version.lmp_subversion:04X} "
f"(0x{local_version.hci_version:02X}, "
f"0x{local_version.hci_subversion:04X})"
)
driver_info = cls.find_driver_info(
local_version.hci_version,
local_version.hci_subversion,
local_version.lmp_subversion,
)
if driver_info is None:
# TODO: it seems that the Linux driver will send command (0x3f, 0x66)
# in this case and then re-read the local version, then re-match.
logger.debug("firmware already loaded or no known driver for this device")
return driver_info
@classmethod
async def for_host(cls, host, force=False):
# Check that a driver is needed for this host
if not force and not cls.check(host):
return None
# Get the driver info
driver_info = await cls.driver_info_for_host(host)
if driver_info is None:
return None
# Load the firmware
firmware_path = cls.find_binary_path(driver_info.fw_name)
if not firmware_path:
logger.warning(f"Firmware file {driver_info.fw_name} not found")
logger.warning("See https://google.github.io/bumble/drivers/realtek.html")
return None
with open(firmware_path, "rb") as firmware_file:
firmware = firmware_file.read()
# Load the config
config = None
if driver_info.config_name:
config_path = cls.find_binary_path(driver_info.config_name)
if config_path:
with open(config_path, "rb") as config_file:
config = config_file.read()
if driver_info.config_needed and not config:
logger.warning("Config needed, but no config file available")
return None
return cls(host, driver_info, firmware, config)
def __init__(self, host, driver_info, firmware, config):
self.host = weakref.proxy(host)
self.driver_info = driver_info
self.firmware = firmware
self.config = config
@staticmethod
async def drop_firmware(host):
host.send_hci_packet(HCI_RTK_Drop_Firmware_Command())
# Wait for the command to be effective (no response is sent)
await asyncio.sleep(Driver.POST_DROP_DELAY)
async def download_for_rtl8723a(self):
# Check that the firmware image does not include an epatch signature.
if RTK_EPATCH_SIGNATURE in self.firmware:
logger.warning(
"epatch signature found in firmware, it is probably the wrong firmware"
)
return
# TODO: load the firmware
async def download_for_rtl8723b(self):
if self.driver_info.has_rom_version:
response = await self.host.send_command(
HCI_RTK_Read_ROM_Version_Command(), check_result=True
)
if response.return_parameters.status != HCI_SUCCESS:
logger.warning("can't get ROM version")
return
rom_version = response.return_parameters.version
logger.debug(f"ROM version before download: {rom_version:04X}")
else:
rom_version = 0
firmware = Firmware(self.firmware)
logger.debug(f"firmware: project_id=0x{firmware.project_id:04X}")
for patch in firmware.patches:
if patch[0] == rom_version + 1:
logger.debug(f"using patch {patch[0]}")
break
else:
logger.warning("no valid patch found for rom version {rom_version}")
return
# Append the config if there is one.
if self.config:
payload = patch[1] + self.config
else:
payload = patch[1]
# Download the payload, one fragment at a time.
fragment_count = math.ceil(len(payload) / RTK_FRAGMENT_LENGTH)
for fragment_index in range(fragment_count):
# NOTE: the Linux driver somehow adds 1 to the index after it wraps around.
# That's odd, but we"ll do the same here.
download_index = fragment_index & 0x7F
if download_index >= 0x80:
download_index += 1
if fragment_index == fragment_count - 1:
download_index |= 0x80 # End marker.
fragment_offset = fragment_index * RTK_FRAGMENT_LENGTH
fragment = payload[fragment_offset : fragment_offset + RTK_FRAGMENT_LENGTH]
logger.debug(f"downloading fragment {fragment_index}")
await self.host.send_command(
HCI_RTK_Download_Command(
index=download_index, payload=fragment, check_result=True
)
)
logger.debug("download complete!")
# Read the version again
response = await self.host.send_command(
HCI_RTK_Read_ROM_Version_Command(), check_result=True
)
if response.return_parameters.status != HCI_SUCCESS:
logger.warning("can't get ROM version")
else:
rom_version = response.return_parameters.version
logger.debug(f"ROM version after download: {rom_version:04X}")
async def download_firmware(self):
if self.driver_info.rom == RTK_ROM_LMP_8723A:
return await self.download_for_rtl8723a()
if self.driver_info.rom in (
RTK_ROM_LMP_8723B,
RTK_ROM_LMP_8821A,
RTK_ROM_LMP_8761A,
RTK_ROM_LMP_8822B,
RTK_ROM_LMP_8852A,
):
return await self.download_for_rtl8723b()
raise ValueError("ROM not supported")
async def init_controller(self):
await self.download_firmware()
await self.host.send_command(HCI_Reset_Command(), check_result=True)
logger.info(f"loaded FW image {self.driver_info.fw_name}")
def rtk_firmware_dir() -> pathlib.Path:
"""
Returns:
A path to a subdir of the project data dir for Realtek firmware.
The directory is created if it doesn't exist.
"""
from bumble.drivers import project_data_dir
p = project_data_dir() / "firmware" / "realtek"
p.mkdir(parents=True, exist_ok=True)
return p
+10 -11
View File
@@ -23,7 +23,7 @@ from .gatt import (
Characteristic, Characteristic,
GATT_GENERIC_ACCESS_SERVICE, GATT_GENERIC_ACCESS_SERVICE,
GATT_DEVICE_NAME_CHARACTERISTIC, GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC, GATT_APPEARANCE_CHARACTERISTIC
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -36,25 +36,24 @@ logger = logging.getLogger(__name__)
# Classes # Classes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class GenericAccessService(Service): class GenericAccessService(Service):
def __init__(self, device_name, appearance=(0, 0)): def __init__(self, device_name, appearance = (0, 0)):
device_name_characteristic = Characteristic( device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC, GATT_DEVICE_NAME_CHARACTERISTIC,
Characteristic.Properties.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
device_name.encode('utf-8')[:248], device_name.encode('utf-8')[:248]
) )
appearance_characteristic = Characteristic( appearance_characteristic = Characteristic(
GATT_APPEARANCE_CHARACTERISTIC, GATT_APPEARANCE_CHARACTERISTIC,
Characteristic.Properties.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
struct.pack('<H', (appearance[0] << 6) | appearance[1]), struct.pack('<H', (appearance[0] << 6) | appearance[1])
) )
super().__init__( super().__init__(GATT_GENERIC_ACCESS_SERVICE, [
GATT_GENERIC_ACCESS_SERVICE, device_name_characteristic,
[device_name_characteristic, appearance_characteristic], appearance_characteristic
) ])
+116 -433
View File
@@ -22,30 +22,14 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations import asyncio
import enum import types
import functools
import logging import logging
import struct from colors import color
from typing import (
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Union,
TYPE_CHECKING,
)
from bumble.colors import color
from bumble.core import UUID
from bumble.att import Attribute, AttributeValue
if TYPE_CHECKING:
from bumble.gatt_client import AttributeProxy
from bumble.device import Connection
from .core import *
from .hci import *
from .att import *
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -55,9 +39,6 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off
# pylint: disable=line-too-long
GATT_REQUEST_TIMEOUT = 30 # seconds GATT_REQUEST_TIMEOUT = 30 # seconds
GATT_MAX_ATTRIBUTE_VALUE_SIZE = 512 GATT_MAX_ATTRIBUTE_VALUE_SIZE = 512
@@ -105,35 +86,20 @@ GATT_RECONNECTION_CONFIGURATION_SERVICE = UUID.from_16_bits(0x1829, 'Reconne
GATT_INSULIN_DELIVERY_SERVICE = UUID.from_16_bits(0x183A, 'Insulin Delivery') GATT_INSULIN_DELIVERY_SERVICE = UUID.from_16_bits(0x183A, 'Insulin Delivery')
GATT_BINARY_SENSOR_SERVICE = UUID.from_16_bits(0x183B, 'Binary Sensor') GATT_BINARY_SENSOR_SERVICE = UUID.from_16_bits(0x183B, 'Binary Sensor')
GATT_EMERGENCY_CONFIGURATION_SERVICE = UUID.from_16_bits(0x183C, 'Emergency Configuration') GATT_EMERGENCY_CONFIGURATION_SERVICE = UUID.from_16_bits(0x183C, 'Emergency Configuration')
GATT_AUTHORIZATION_CONTROL_SERVICE = UUID.from_16_bits(0x183D, 'Authorization Control')
GATT_PHYSICAL_ACTIVITY_MONITOR_SERVICE = UUID.from_16_bits(0x183E, 'Physical Activity Monitor') GATT_PHYSICAL_ACTIVITY_MONITOR_SERVICE = UUID.from_16_bits(0x183E, 'Physical Activity Monitor')
GATT_ELAPSED_TIME_SERVICE = UUID.from_16_bits(0x183F, 'Elapsed Time')
GATT_GENERIC_HEALTH_SENSOR_SERVICE = UUID.from_16_bits(0x1840, 'Generic Health Sensor')
GATT_AUDIO_INPUT_CONTROL_SERVICE = UUID.from_16_bits(0x1843, 'Audio Input Control') GATT_AUDIO_INPUT_CONTROL_SERVICE = UUID.from_16_bits(0x1843, 'Audio Input Control')
GATT_VOLUME_CONTROL_SERVICE = UUID.from_16_bits(0x1844, 'Volume Control') GATT_VOLUME_CONTROL_SERVICE = UUID.from_16_bits(0x1844, 'Volume Control')
GATT_VOLUME_OFFSET_CONTROL_SERVICE = UUID.from_16_bits(0x1845, 'Volume Offset Control') GATT_VOLUME_OFFSET_CONTROL_SERVICE = UUID.from_16_bits(0x1845, 'Volume Offset Control')
GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification') GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification Service')
GATT_DEVICE_TIME_SERVICE = UUID.from_16_bits(0x1847, 'Device Time') GATT_DEVICE_TIME_SERVICE = UUID.from_16_bits(0x1847, 'Device Time')
GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control') GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control Service')
GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control') GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control Service')
GATT_CONSTANT_TONE_EXTENSION_SERVICE = UUID.from_16_bits(0x184A, 'Constant Tone Extension') GATT_CONSTANT_TONE_EXTENSION_SERVICE = UUID.from_16_bits(0x184A, 'Constant Tone Extension')
GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer') GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer Service')
GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer') GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer Service')
GATT_MICROPHONE_CONTROL_SERVICE = UUID.from_16_bits(0x184D, 'Microphone Control') GATT_MICROPHONE_CONTROL_SERVICE = UUID.from_16_bits(0x184D, 'Microphone Control')
GATT_AUDIO_STREAM_CONTROL_SERVICE = UUID.from_16_bits(0x184E, 'Audio Stream Control')
GATT_BROADCAST_AUDIO_SCAN_SERVICE = UUID.from_16_bits(0x184F, 'Broadcast Audio Scan')
GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE = UUID.from_16_bits(0x1850, 'Published Audio Capabilities')
GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1851, 'Basic Audio Announcement')
GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1852, 'Broadcast Audio Announcement')
GATT_COMMON_AUDIO_SERVICE = UUID.from_16_bits(0x1853, 'Common Audio')
GATT_HEARING_ACCESS_SERVICE = UUID.from_16_bits(0x1854, 'Hearing Access')
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE = UUID.from_16_bits(0x1855, 'Telephony and Media Audio')
GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1856, 'Public Broadcast Announcement')
GATT_ELECTRONIC_SHELF_LABEL_SERVICE = UUID.from_16_bits(0X1857, 'Electronic Shelf Label')
GATT_GAMING_AUDIO_SERVICE = UUID.from_16_bits(0x1858, 'Gaming Audio')
GATT_MESH_PROXY_SOLICITATION_SERVICE = UUID.from_16_bits(0x1859, 'Mesh Audio Solicitation')
# Attribute Types # Types
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2800, 'Primary Service') GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2800, 'Primary Service')
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2801, 'Secondary Service') GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2801, 'Secondary Service')
GATT_INCLUDE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2802, 'Include') GATT_INCLUDE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2802, 'Include')
@@ -156,8 +122,6 @@ GATT_ENVIRONMENTAL_SENSING_MEASUREMENT_DESCRIPTOR = UUID.from_16_bits(0x290C,
GATT_ENVIRONMENTAL_SENSING_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290D, 'Environmental Sensing Trigger Setting') GATT_ENVIRONMENTAL_SENSING_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290D, 'Environmental Sensing Trigger Setting')
GATT_TIME_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290E, 'Time Trigger Setting') GATT_TIME_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290E, 'Time Trigger Setting')
GATT_COMPLETE_BR_EDR_TRANSPORT_BLOCK_DATA_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Complete BR-EDR Transport Block Data') GATT_COMPLETE_BR_EDR_TRANSPORT_BLOCK_DATA_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Complete BR-EDR Transport Block Data')
GATT_OBSERVATION_SCHEDULE_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Observation Schedule')
GATT_VALID_RANGE_AND_ACCURACY_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Valid Range And Accuracy')
# Device Information Service # Device Information Service
GATT_SYSTEM_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A23, 'System ID') GATT_SYSTEM_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A23, 'System ID')
@@ -185,104 +149,6 @@ GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A39, 'Heart
# Battery Service # Battery Service
GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level') GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level')
# Telephony And Media Audio Service (TMAS)
GATT_TMAP_ROLE_CHARACTERISTIC = UUID.from_16_bits(0x2B51, 'TMAP Role')
# Audio Input Control Service (AICS)
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B77, 'Audio Input State')
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC = UUID.from_16_bits(0x2B78, 'Gain Settings Attribute')
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC = UUID.from_16_bits(0x2B79, 'Audio Input Type')
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC = UUID.from_16_bits(0x2B7A, 'Audio Input Status')
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B7B, 'Audio Input Control Point')
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC = UUID.from_16_bits(0x2B7C, 'Audio Input Description')
# Volume Control Service (VCS)
GATT_VOLUME_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B7D, 'Volume State')
GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B7E, 'Volume Control Point')
GATT_VOLUME_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2B7F, 'Volume Flags')
# Volume Offset Control Service (VOCS)
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B80, 'Volume Offset State')
GATT_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2B81, 'Audio Location')
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B82, 'Volume Offset Control Point')
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC = UUID.from_16_bits(0x2B83, 'Audio Output Description')
# Coordinated Set Identification Service (CSIS)
GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC = UUID.from_16_bits(0x2B84, 'Set Identity Resolving Key')
GATT_COORDINATED_SET_SIZE_CHARACTERISTIC = UUID.from_16_bits(0x2B85, 'Coordinated Set Size')
GATT_SET_MEMBER_LOCK_CHARACTERISTIC = UUID.from_16_bits(0x2B86, 'Set Member Lock')
GATT_SET_MEMBER_RANK_CHARACTERISTIC = UUID.from_16_bits(0x2B87, 'Set Member Rank')
# Media Control Service (MCS)
GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2B93, 'Media Player Name')
GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B94, 'Media Player Icon Object ID')
GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC = UUID.from_16_bits(0x2B95, 'Media Player Icon URL')
GATT_TRACK_CHANGED_CHARACTERISTIC = UUID.from_16_bits(0x2B96, 'Track Changed')
GATT_TRACK_TITLE_CHARACTERISTIC = UUID.from_16_bits(0x2B97, 'Track Title')
GATT_TRACK_DURATION_CHARACTERISTIC = UUID.from_16_bits(0x2B98, 'Track Duration')
GATT_TRACK_POSITION_CHARACTERISTIC = UUID.from_16_bits(0x2B99, 'Track Position')
GATT_PLAYBACK_SPEED_CHARACTERISTIC = UUID.from_16_bits(0x2B9A, 'Playback Speed')
GATT_SEEKING_SPEED_CHARACTERISTIC = UUID.from_16_bits(0x2B9B, 'Seeking Speed')
GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9C, 'Current Track Segments Object ID')
GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9D, 'Current Track Object ID')
GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9E, 'Next Track Object ID')
GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9F, 'Parent Group Object ID')
GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA0, 'Current Group Object ID')
GATT_PLAYING_ORDER_CHARACTERISTIC = UUID.from_16_bits(0x2BA1, 'Playing Order')
GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA2, 'Playing Orders Supported')
GATT_MEDIA_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BA3, 'Media State')
GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA4, 'Media Control Point')
GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA5, 'Media Control Point Opcodes Supported')
GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA6, 'Search Results Object ID')
GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA7, 'Search Control Point')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control Id')
# Telephone Bearer Service (TBS)
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB4, 'Bearer Provider Name')
GATT_BEARER_UCI_CHARACTERISTIC = UUID.from_16_bits(0x2BB5, 'Bearer UCI')
GATT_BEARER_TECHNOLOGY_CHARACTERISTIC = UUID.from_16_bits(0x2BB6, 'Bearer Technology')
GATT_BEARER_URI_SCHEMES_SUPPORTED_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2BB7, 'Bearer URI Schemes Supported List')
GATT_BEARER_SIGNAL_STRENGTH_CHARACTERISTIC = UUID.from_16_bits(0x2BB8, 'Bearer Signal Strength')
GATT_BEARER_SIGNAL_STRENGTH_REPORTING_INTERVAL_CHARACTERISTIC = UUID.from_16_bits(0x2BB9, 'Bearer Signal Strength Reporting Interval')
GATT_BEARER_LIST_CURRENT_CALLS_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Bearer List Current Calls')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBB, 'Content Control ID')
GATT_STATUS_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2BBC, 'Status Flags')
GATT_INCOMING_CALL_TARGET_BEARER_URI_CHARACTERISTIC = UUID.from_16_bits(0x2BBD, 'Incoming Call Target Bearer URI')
GATT_CALL_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BBE, 'Call State')
GATT_CALL_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BBF, 'Call Control Point')
GATT_CALL_CONTROL_POINT_OPTIONAL_OPCODES_CHARACTERISTIC = UUID.from_16_bits(0x2BC0, 'Call Control Point Optional Opcodes')
GATT_TERMINATION_REASON_CHARACTERISTIC = UUID.from_16_bits(0x2BC1, 'Termination Reason')
GATT_INCOMING_CALL_CHARACTERISTIC = UUID.from_16_bits(0x2BC2, 'Incoming Call')
GATT_CALL_FRIENDLY_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Call Friendly Name')
# Microphone Control Service (MICS)
GATT_MUTE_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Mute')
# Audio Stream Control Service (ASCS)
GATT_SINK_ASE_CHARACTERISTIC = UUID.from_16_bits(0x2BC4, 'Sink ASE')
GATT_SOURCE_ASE_CHARACTERISTIC = UUID.from_16_bits(0x2BC5, 'Source ASE')
GATT_ASE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BC6, 'ASE Control Point')
# Broadcast Audio Scan Service (BASS)
GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BC7, 'Broadcast Audio Scan Control Point')
GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BC8, 'Broadcast Receive State')
# Published Audio Capabilities Service (PACS)
GATT_SINK_PAC_CHARACTERISTIC = UUID.from_16_bits(0x2BC9, 'Sink PAC')
GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCA, 'Sink Audio Location')
GATT_SOURCE_PAC_CHARACTERISTIC = UUID.from_16_bits(0x2BCB, 'Source PAC')
GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCC, 'Source Audio Location')
GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts')
GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts')
# ASHA Service
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID('f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint')
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID('38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus')
GATT_ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT')
# Misc # Misc
GATT_DEVICE_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2A00, 'Device Name') GATT_DEVICE_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2A00, 'Device Name')
GATT_APPEARANCE_CHARACTERISTIC = UUID.from_16_bits(0x2A01, 'Appearance') GATT_APPEARANCE_CHARACTERISTIC = UUID.from_16_bits(0x2A01, 'Appearance')
@@ -296,20 +162,13 @@ GATT_BOOT_KEYBOARD_INPUT_REPORT_CHARACTERISTIC = UUID.from_16_bi
GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bits(0x2A2B, 'Current Time') GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bits(0x2A2B, 'Current Time')
GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report') GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report')
GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution') GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B29, 'Client Supported Features')
GATT_DATABASE_HASH_CHARACTERISTIC = UUID.from_16_bits(0x2B2A, 'Database Hash')
GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B3A, 'Server Supported Features')
# fmt: on
# pylint: enable=line-too-long
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def show_services(services):
def show_services(services: Iterable[Service]) -> None:
for service in services: for service in services:
print(color(str(service), 'cyan')) print(color(str(service), 'cyan'))
@@ -326,50 +185,23 @@ class Service(Attribute):
See Vol 3, Part G - 3.1 SERVICE DEFINITION See Vol 3, Part G - 3.1 SERVICE DEFINITION
''' '''
uuid: UUID def __init__(self, uuid, characteristics, primary=True):
characteristics: List[Characteristic]
included_services: List[Service]
def __init__(
self,
uuid: Union[str, UUID],
characteristics: List[Characteristic],
primary=True,
included_services: List[Service] = [],
) -> None:
# Convert the uuid to a UUID object if it isn't already # Convert the uuid to a UUID object if it isn't already
if isinstance(uuid, str): if type(uuid) is str:
uuid = UUID(uuid) uuid = UUID(uuid)
super().__init__( super().__init__(
( GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE if primary else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
if primary
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE
),
Attribute.READABLE, Attribute.READABLE,
uuid.to_pdu_bytes(), uuid.to_pdu_bytes()
) )
self.uuid = uuid self.uuid = uuid
self.included_services = included_services[:] self.included_services = []
self.characteristics = characteristics[:] self.characteristics = characteristics[:]
self.primary = primary self.primary = primary
def get_advertising_data(self) -> Optional[bytes]: def __str__(self):
""" return f'Service(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}){"" if self.primary else "*"}'
Get Service specific advertising data
Defined by each Service, default value is empty
:return Service data for advertising
"""
return None
def __str__(self) -> str:
return (
f'Service(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
f'uuid={self.uuid})'
f'{"" if self.primary else "*"}'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -378,42 +210,10 @@ class TemplateService(Service):
Convenience abstract class that can be used by profile-specific subclasses that want Convenience abstract class that can be used by profile-specific subclasses that want
to expose their UUID as a class property to expose their UUID as a class property
''' '''
UUID = None
UUID: UUID def __init__(self, characteristics, primary=True):
super().__init__(self.UUID, characteristics, primary)
def __init__(
self,
characteristics: List[Characteristic],
primary: bool = True,
included_services: List[Service] = [],
) -> None:
super().__init__(self.UUID, characteristics, primary, included_services)
# -----------------------------------------------------------------------------
class IncludedServiceDeclaration(Attribute):
'''
See Vol 3, Part G - 3.2 INCLUDE DEFINITION
'''
service: Service
def __init__(self, service: Service) -> None:
declaration_bytes = struct.pack(
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
)
super().__init__(
GATT_INCLUDE_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
)
self.service = service
def __str__(self) -> str:
return (
f'IncludedServiceDefinition(handle=0x{self.handle:04X}, '
f'group_starting_handle=0x{self.service.handle:04X}, '
f'group_ending_handle=0x{self.service.end_group_handle:04X}, '
f'uuid={self.service.uuid})'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -422,192 +222,121 @@ class Characteristic(Attribute):
See Vol 3, Part G - 3.3 CHARACTERISTIC DEFINITION See Vol 3, Part G - 3.3 CHARACTERISTIC DEFINITION
''' '''
uuid: UUID # Property flags
properties: Characteristic.Properties BROADCAST = 0x01
READ = 0x02
WRITE_WITHOUT_RESPONSE = 0x04
WRITE = 0x08
NOTIFY = 0x10
INDICATE = 0X20
AUTHENTICATED_SIGNED_WRITES = 0X40
EXTENDED_PROPERTIES = 0X80
class Properties(enum.IntFlag): PROPERTY_NAMES = {
"""Property flags""" BROADCAST: 'BROADCAST',
READ: 'READ',
WRITE_WITHOUT_RESPONSE: 'WRITE_WITHOUT_RESPONSE',
WRITE: 'WRITE',
NOTIFY: 'NOTIFY',
INDICATE: 'INDICATE',
AUTHENTICATED_SIGNED_WRITES: 'AUTHENTICATED_SIGNED_WRITES',
EXTENDED_PROPERTIES: 'EXTENDED_PROPERTIES'
}
BROADCAST = 0x01 @staticmethod
READ = 0x02 def property_name(property):
WRITE_WITHOUT_RESPONSE = 0x04 return Characteristic.PROPERTY_NAMES.get(property, '')
WRITE = 0x08
NOTIFY = 0x10
INDICATE = 0x20
AUTHENTICATED_SIGNED_WRITES = 0x40
EXTENDED_PROPERTIES = 0x80
@classmethod @staticmethod
def from_string(cls, properties_str: str) -> Characteristic.Properties: def properties_as_string(properties):
try: return ','.join([
return functools.reduce( Characteristic.property_name(p) for p in Characteristic.PROPERTY_NAMES.keys()
lambda x, y: x | cls[y], if properties & p
properties_str.replace("|", ",").split(","), ])
Characteristic.Properties(0),
)
except (TypeError, KeyError):
# The check for `p.name is not None` here is needed because for InFlag
# enums, the .name property can be None, when the enum value is 0,
# so the type hint for .name is Optional[str].
enum_list: List[str] = [p.name for p in cls if p.name is not None]
enum_list_str = ",".join(enum_list)
raise TypeError(
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
)
def __str__(self) -> str: def __init__(self, uuid, properties, permissions, value = b'', descriptors = []):
# NOTE: we override this method to offer a consistent result between python
# versions: the value returned by IntFlag.__str__() changed in version 11.
return '|'.join(
flag.name
for flag in Characteristic.Properties
if self.value & flag.value and flag.name is not None
)
# For backwards compatibility these are defined here
# For new code, please use Characteristic.Properties.X
BROADCAST = Properties.BROADCAST
READ = Properties.READ
WRITE_WITHOUT_RESPONSE = Properties.WRITE_WITHOUT_RESPONSE
WRITE = Properties.WRITE
NOTIFY = Properties.NOTIFY
INDICATE = Properties.INDICATE
AUTHENTICATED_SIGNED_WRITES = Properties.AUTHENTICATED_SIGNED_WRITES
EXTENDED_PROPERTIES = Properties.EXTENDED_PROPERTIES
def __init__(
self,
uuid: Union[str, bytes, UUID],
properties: Characteristic.Properties,
permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, CharacteristicValue] = b'',
descriptors: Sequence[Descriptor] = (),
):
super().__init__(uuid, permissions, value) super().__init__(uuid, permissions, value)
self.uuid = self.type self.uuid = self.type
self.properties = properties self.properties = properties
self.descriptors = descriptors self.descriptors = descriptors
def get_descriptor(self, descriptor_type): def get_descriptor(self, descriptor_type):
for descriptor in self.descriptors: for descriptor in self.descriptors:
if descriptor.type == descriptor_type: if descriptor.uuid == descriptor_type:
return descriptor return descriptor
return None def __str__(self):
return f'Characteristic(handle=0x{self.handle:04X}, end=0x{self.end_group_handle:04X}, uuid={self.uuid}, properties={Characteristic.properties_as_string(self.properties)})'
def has_properties(self, properties: Characteristic.Properties) -> bool:
return self.properties & properties == properties
def __str__(self) -> str:
return (
f'Characteristic(handle=0x{self.handle:04X}, '
f'end=0x{self.end_group_handle:04X}, '
f'uuid={self.uuid}, '
f'{self.properties})'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CharacteristicDeclaration(Attribute): class CharacteristicValue:
''' '''
See Vol 3, Part G - 3.3.1 CHARACTERISTIC DECLARATION Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
''' '''
def __init__(self, read=None, write=None):
self._read = read
self._write = write
characteristic: Characteristic def read(self, connection):
return self._read(connection) if self._read else b''
def __init__(self, characteristic: Characteristic, value_handle: int) -> None: def write(self, connection, value):
declaration_bytes = ( if self._write:
struct.pack('<BH', characteristic.properties, value_handle) self._write(connection, value)
+ characteristic.uuid.to_pdu_bytes()
)
super().__init__(
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
)
self.value_handle = value_handle
self.characteristic = characteristic
def __str__(self) -> str:
return (
f'CharacteristicDeclaration(handle=0x{self.handle:04X}, '
f'value_handle=0x{self.value_handle:04X}, '
f'uuid={self.characteristic.uuid}, '
f'{self.characteristic.properties})'
)
# -----------------------------------------------------------------------------
class CharacteristicValue(AttributeValue):
"""Same as AttributeValue, for backward compatibility"""
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CharacteristicAdapter: class CharacteristicAdapter:
''' '''
An adapter that can adapt Characteristic and AttributeProxy objects An adapter that can adapt any object with `read_value` and `write_value`
by wrapping their `read_value()` and `write_value()` methods with ones that methods (like Characteristic and CharacteristicProxy objects) by wrapping
return/accept encoded/decoded values. those methods with ones that return/accept encoded/decoded values.
Objects with async methods are considered proxies, so the adaptation is one
For proxies (i.e used by a GATT client), the adaptation is one where the return where the return value of `read_value` is decoded and the value passed to
value of `read_value()` is decoded and the value passed to `write_value()` is `write_value` is encoded. Other objects are considered local characteristics
encoded. The `subscribe()` method, is wrapped with one where the values are decoded so the adaptation is one where the return value of `read_value` is encoded
before being passed to the subscriber. and the value passed to `write_value` is decoded.
If the characteristic has a `subscribe` method, it is wrapped with one where
For local values (i.e hosted by a GATT server) the adaptation is one where the the values are decoded before being passed to the subscriber.
return value of `read_value()` is encoded and the value passed to `write_value()`
is decoded.
''' '''
def __init__(self, characteristic):
read_value: Callable
write_value: Callable
def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
self.wrapped_characteristic = characteristic self.wrapped_characteristic = characteristic
self.subscribers: Dict[Callable, Callable] = (
{}
) # Map from subscriber to proxy subscriber
if isinstance(characteristic, Characteristic): if (
self.read_value = self.read_encoded_value asyncio.iscoroutinefunction(characteristic.read_value) and
self.write_value = self.write_encoded_value asyncio.iscoroutinefunction(characteristic.write_value)
else: ):
self.read_value = self.read_decoded_value self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value self.write_value = self.write_decoded_value
else:
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
if hasattr(self.wrapped_characteristic, 'subscribe'):
self.subscribe = self.wrapped_subscribe self.subscribe = self.wrapped_subscribe
self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.wrapped_characteristic, name) return getattr(self.wrapped_characteristic, name)
def __setattr__(self, name, value): def __setattr__(self, name, value):
if name in ( if name in {'wrapped_characteristic', 'read_value', 'write_value', 'subscribe'}:
'wrapped_characteristic',
'subscribers',
'read_value',
'write_value',
'subscribe',
'unsubscribe',
):
super().__setattr__(name, value) super().__setattr__(name, value)
else: else:
setattr(self.wrapped_characteristic, name, value) setattr(self.wrapped_characteristic, name, value)
async def read_encoded_value(self, connection): def read_encoded_value(self, connection):
return self.encode_value( return self.encode_value(self.wrapped_characteristic.read_value(connection))
await self.wrapped_characteristic.read_value(connection)
)
async def write_encoded_value(self, connection, value): def write_encoded_value(self, connection, value):
return await self.wrapped_characteristic.write_value( return self.wrapped_characteristic.write_value(connection, self.decode_value(value))
connection, self.decode_value(value)
)
async def read_decoded_value(self): async def read_decoded_value(self):
return self.decode_value(await self.wrapped_characteristic.read_value()) return self.decode_value(await self.wrapped_characteristic.read_value())
async def write_decoded_value(self, value, with_response=False): async def write_decoded_value(self, value):
return await self.wrapped_characteristic.write_value( return await self.wrapped_characteristic.write_value(self.encode_value(value))
self.encode_value(value), with_response
)
def encode_value(self, value): def encode_value(self, value):
return value return value
@@ -616,29 +345,11 @@ class CharacteristicAdapter:
return value return value
def wrapped_subscribe(self, subscriber=None): def wrapped_subscribe(self, subscriber=None):
if subscriber is not None: return self.wrapped_characteristic.subscribe(
if subscriber in self.subscribers: None if subscriber is None else lambda value: subscriber(self.decode_value(value))
# We already have a proxy subscriber )
subscriber = self.subscribers[subscriber]
else:
# Create and register a proxy that will decode the value
original_subscriber = subscriber
def on_change(value): def __str__(self):
original_subscriber(self.decode_value(value))
self.subscribers[subscriber] = on_change
subscriber = on_change
return self.wrapped_characteristic.subscribe(subscriber)
def wrapped_unsubscribe(self, subscriber=None):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)
return self.wrapped_characteristic.unsubscribe(subscriber)
def __str__(self) -> str:
wrapped = str(self.wrapped_characteristic) wrapped = str(self.wrapped_characteristic)
return f'{self.__class__.__name__}({wrapped})' return f'{self.__class__.__name__}({wrapped})'
@@ -648,7 +359,6 @@ class DelegatedCharacteristicAdapter(CharacteristicAdapter):
''' '''
Adapter that converts bytes values using an encode and a decode function. Adapter that converts bytes values using an encode and a decode function.
''' '''
def __init__(self, characteristic, encode=None, decode=None): def __init__(self, characteristic, encode=None, decode=None):
super().__init__(characteristic) super().__init__(characteristic)
self.encode = encode self.encode = encode
@@ -671,10 +381,9 @@ class PackedCharacteristicAdapter(CharacteristicAdapter):
they return/accept a tuple with the same number of elements as is required for they return/accept a tuple with the same number of elements as is required for
the format. the format.
''' '''
def __init__(self, characteristic, format):
def __init__(self, characteristic, pack_format):
super().__init__(characteristic) super().__init__(characteristic)
self.struct = struct.Struct(pack_format) self.struct = struct.Struct(format)
def pack(self, *values): def pack(self, *values):
return self.struct.pack(*values) return self.struct.pack(*values)
@@ -683,7 +392,7 @@ class PackedCharacteristicAdapter(CharacteristicAdapter):
return self.struct.unpack(buffer) return self.struct.unpack(buffer)
def encode_value(self, value): def encode_value(self, value):
return self.pack(*value if isinstance(value, tuple) else (value,)) return self.pack(*value if type(value) is tuple else (value,))
def decode_value(self, value): def decode_value(self, value):
unpacked = self.unpack(value) unpacked = self.unpack(value)
@@ -696,15 +405,13 @@ class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
Adapter that packs/unpacks characteristic values according to a standard Adapter that packs/unpacks characteristic values according to a standard
Python `struct` format. Python `struct` format.
The adapted `read_value` and `write_value` methods return/accept aa dictionary which The adapted `read_value` and `write_value` methods return/accept aa dictionary which
is packed/unpacked according to format, with the arguments extracted from the is packed/unpacked according to format, with the arguments extracted from the dictionary
dictionary by key, in the same order as they occur in the `keys` parameter. by key, in the same order as they occur in the `keys` parameter.
''' '''
def __init__(self, characteristic, format, keys):
def __init__(self, characteristic, pack_format, keys): super().__init__(characteristic, format)
super().__init__(characteristic, pack_format)
self.keys = keys self.keys = keys
# pylint: disable=arguments-differ
def pack(self, values): def pack(self, values):
return super().pack(*(values[key] for key in self.keys)) return super().pack(*(values[key] for key in self.keys))
@@ -717,11 +424,10 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
''' '''
Adapter that converts strings to/from bytes using UTF-8 encoding Adapter that converts strings to/from bytes using UTF-8 encoding
''' '''
def encode_value(self, value):
def encode_value(self, value: str) -> bytes:
return value.encode('utf-8') return value.encode('utf-8')
def decode_value(self, value: bytes) -> str: def decode_value(self, value):
return value.decode('utf-8') return value.decode('utf-8')
@@ -731,31 +437,8 @@ class Descriptor(Attribute):
See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations See Vol 3, Part G - 3.3.3 Characteristic Descriptor Declarations
''' '''
def __str__(self) -> str: def __init__(self, descriptor_type, permissions, value = b''):
if isinstance(self.value, bytes): super().__init__(descriptor_type, permissions, value)
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
value = self.value.read(None)
if isinstance(value, bytes):
value_str = value.hex()
else:
value_str = '<async>'
else:
value_str = '<...>'
return (
f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'value={value_str})'
)
def __str__(self):
# ----------------------------------------------------------------------------- return f'Descriptor(handle=0x{self.handle:04X}, type={self.type}, value={self.read_value(None).hex()})'
class ClientCharacteristicConfigurationBits(enum.IntFlag):
'''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit
field definition
'''
DEFAULT = 0x0000
NOTIFICATION = 0x0001
INDICATION = 0x0002
+182 -571
View File
File diff suppressed because it is too large Load Diff
+245 -532
View File
File diff suppressed because it is too large Load Diff
+1772 -4093
View File
File diff suppressed because it is too large Load Diff
+72 -198
View File
@@ -15,46 +15,32 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import Callable, MutableMapping
import datetime
from typing import cast, Any, Optional
import logging import logging
from colors import color
from bumble import avc from .core import name_or_number
from bumble import avctp from .gatt import ATT_PDU, ATT_CID
from bumble import avdtp from .l2cap import (
from bumble import avrcp
from bumble import crypto
from bumble import rfcomm
from bumble import sdp
from bumble.colors import color
from bumble.att import ATT_CID, ATT_PDU
from bumble.smp import SMP_CID, SMP_Command
from bumble.core import name_or_number
from bumble.l2cap import (
L2CAP_PDU, L2CAP_PDU,
L2CAP_CONNECTION_REQUEST, L2CAP_CONNECTION_REQUEST,
L2CAP_CONNECTION_RESPONSE, L2CAP_CONNECTION_RESPONSE,
L2CAP_SIGNALING_CID, L2CAP_SIGNALING_CID,
L2CAP_LE_SIGNALING_CID, L2CAP_LE_SIGNALING_CID,
L2CAP_Control_Frame, L2CAP_Control_Frame,
L2CAP_Connection_Request, L2CAP_Connection_Response
L2CAP_Connection_Response,
) )
from bumble.hci import ( from .hci import (
Address,
HCI_EVENT_PACKET, HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET, HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT, HCI_DISCONNECTION_COMPLETE_EVENT,
HCI_AclDataPacketAssembler, HCI_AclDataPacketAssembler
HCI_Packet, )
HCI_Event, from .rfcomm import RFCOMM_Frame, RFCOMM_PSM
HCI_AclDataPacket, from .sdp import SDP_PDU, SDP_PSM
HCI_Disconnection_Complete_Event, from .avdtp import (
MessageAssembler as AVDTP_MessageAssembler,
AVDTP_PSM
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -64,158 +50,83 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
PSM_NAMES = { PSM_NAMES = {
rfcomm.RFCOMM_PSM: 'RFCOMM', RFCOMM_PSM: 'RFCOMM',
sdp.SDP_PSM: 'SDP', SDP_PSM: 'SDP',
avdtp.AVDTP_PSM: 'AVDTP', AVDTP_PSM: 'AVDTP'
avctp.AVCTP_PSM: 'AVCTP',
# TODO: add more PSM values # TODO: add more PSM values
} }
AVCTP_PID_NAMES = {avrcp.AVRCP_PID: 'AVRCP'}
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PacketTracer: class PacketTracer:
class AclStream: class AclStream:
psms: MutableMapping[int, int] def __init__(self, analyzer):
peer: Optional[PacketTracer.AclStream] self.analyzer = analyzer
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
avctp_assemblers: MutableMapping[int, avctp.MessageAssembler]
def __init__(self, analyzer: PacketTracer.Analyzer) -> None:
self.analyzer = analyzer
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
self.avctp_assemblers = {} # AVCTP assemblers, by source_cid self.psms = {} # PSM, by source_cid
self.psms = {} # PSM, by source_cid self.peer = None # ACL stream in the other direction
self.peer = None
# pylint: disable=too-many-nested-blocks def on_acl_pdu(self, pdu):
def on_acl_pdu(self, pdu: bytes) -> None:
l2cap_pdu = L2CAP_PDU.from_bytes(pdu) l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
self.analyzer.emit(l2cap_pdu)
if l2cap_pdu.cid == ATT_CID: if l2cap_pdu.cid == ATT_CID:
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload) att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(att_pdu) self.analyzer.emit(att_pdu)
elif l2cap_pdu.cid == SMP_CID: elif l2cap_pdu.cid == L2CAP_SIGNALING_CID or l2cap_pdu.cid == L2CAP_LE_SIGNALING_CID:
smp_command = SMP_Command.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(smp_command)
elif l2cap_pdu.cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload) control_frame = L2CAP_Control_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(control_frame) self.analyzer.emit(control_frame)
# Check if this signals a new channel # Check if this signals a new channel
if control_frame.code == L2CAP_CONNECTION_REQUEST: if control_frame.code == L2CAP_CONNECTION_REQUEST:
connection_request = cast(L2CAP_Connection_Request, control_frame) self.psms[control_frame.source_cid] = control_frame.psm
self.psms[connection_request.source_cid] = connection_request.psm
elif control_frame.code == L2CAP_CONNECTION_RESPONSE: elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
connection_response = cast(L2CAP_Connection_Response, control_frame) if control_frame.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
if ( if self.peer:
connection_response.result if psm := self.peer.psms.get(control_frame.source_cid):
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL # Found a pending connection
): self.psms[control_frame.destination_cid] = psm
if self.peer and (
psm := self.peer.psms.get(connection_response.source_cid) # For AVDTP connections, create a packet assembler for each direction
): if psm == AVDTP_PSM:
# Found a pending connection self.avdtp_assemblers[control_frame.source_cid] = AVDTP_MessageAssembler(self.on_avdtp_message)
self.psms[connection_response.destination_cid] = psm self.peer.avdtp_assemblers[control_frame.destination_cid] = AVDTP_MessageAssembler(self.peer.on_avdtp_message)
# For AVDTP connections, create a packet assembler for
# each direction
if psm == avdtp.AVDTP_PSM:
self.avdtp_assemblers[
connection_response.source_cid
] = avdtp.MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
connection_response.destination_cid
] = avdtp.MessageAssembler(self.peer.on_avdtp_message)
elif psm == avctp.AVCTP_PSM:
self.avctp_assemblers[
connection_response.source_cid
] = avctp.MessageAssembler(self.on_avctp_message)
self.peer.avctp_assemblers[
connection_response.destination_cid
] = avctp.MessageAssembler(self.peer.on_avctp_message)
else: else:
# Try to find the PSM associated with this PDU # Try to find the PSM associated with this PDU
if self.peer and (psm := self.peer.psms.get(l2cap_pdu.cid)): if self.peer and (psm := self.peer.psms.get(l2cap_pdu.cid)):
if psm == sdp.SDP_PSM: if psm == SDP_PSM:
sdp_pdu = sdp.SDP_PDU.from_bytes(l2cap_pdu.payload) sdp_pdu = SDP_PDU.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(sdp_pdu) self.analyzer.emit(sdp_pdu)
elif psm == rfcomm.RFCOMM_PSM: elif psm == RFCOMM_PSM:
rfcomm_frame = rfcomm.RFCOMM_Frame.from_bytes(l2cap_pdu.payload) rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(rfcomm_frame) self.analyzer.emit(rfcomm_frame)
elif psm == avdtp.AVDTP_PSM: elif psm == AVDTP_PSM:
self.analyzer.emit( self.analyzer.emit(f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM=AVDTP]: {l2cap_pdu.payload.hex()}')
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, ' assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}' if assembler:
) assembler.on_pdu(l2cap_pdu.payload)
if avdtp_assembler := self.avdtp_assemblers.get(l2cap_pdu.cid):
avdtp_assembler.on_pdu(l2cap_pdu.payload)
elif psm == avctp.AVCTP_PSM:
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVCTP]: {l2cap_pdu.payload.hex()}'
)
if avctp_assembler := self.avctp_assemblers.get(l2cap_pdu.cid):
avctp_assembler.on_pdu(l2cap_pdu.payload)
else: else:
psm_string = name_or_number(PSM_NAMES, psm) psm_string = name_or_number(PSM_NAMES, psm)
self.analyzer.emit( self.analyzer.emit(f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, PSM={psm_string}]: {l2cap_pdu.payload.hex()}')
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM={psm_string}]: {l2cap_pdu.payload.hex()}'
)
else: else:
self.analyzer.emit(l2cap_pdu) self.analyzer.emit(l2cap_pdu)
def on_avdtp_message( def on_avdtp_message(self, transaction_label, message):
self, transaction_label: int, message: avdtp.Message self.analyzer.emit(f'{color("AVDTP", "green")} [{transaction_label}] {message}')
) -> None:
self.analyzer.emit(
f'{color("AVDTP", "green")} [{transaction_label}] {message}'
)
def on_avctp_message( def feed_packet(self, packet):
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
):
if pid == avrcp.AVRCP_PID:
avc_frame = avc.Frame.from_bytes(payload)
details = str(avc_frame)
else:
details = payload.hex()
c_r = 'Command' if is_command else 'Response'
self.analyzer.emit(
f'{color("AVCTP", "green")} '
f'{c_r}[{transaction_label}][{name_or_number(AVCTP_PID_NAMES, pid)}] '
f'{"#" if ipid else ""}'
f'{details}'
)
def feed_packet(self, packet: HCI_AclDataPacket) -> None:
self.packet_assembler.feed_packet(packet) self.packet_assembler.feed_packet(packet)
class Analyzer: class Analyzer:
acl_streams: MutableMapping[int, PacketTracer.AclStream] def __init__(self, label, emit_message):
peer: PacketTracer.Analyzer self.label = label
def __init__(self, label: str, emit_message: Callable[..., None]) -> None:
self.label = label
self.emit_message = emit_message self.emit_message = emit_message
self.acl_streams = {} # ACL streams, by connection handle self.acl_streams = {} # ACL streams, by connection handle
self.packet_timestamp: Optional[datetime.datetime] = None self.peer = None # Analyzer in the other direction
def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream: def start_acl_stream(self, connection_handle):
logger.info( logger.info(f'[{self.label}] +++ Creating ACL stream for connection 0x{connection_handle:04X}')
f'[{self.label}] +++ Creating ACL stream for connection '
f'0x{connection_handle:04X}'
)
stream = PacketTracer.AclStream(self) stream = PacketTracer.AclStream(self)
self.acl_streams[connection_handle] = stream self.acl_streams[connection_handle] = stream
@@ -226,80 +137,43 @@ class PacketTracer:
return stream return stream
def end_acl_stream(self, connection_handle: int) -> None: def end_acl_stream(self, connection_handle):
if connection_handle in self.acl_streams: if connection_handle in self.acl_streams:
logger.info( logger.info(f'[{self.label}] --- Removing ACL stream for connection 0x{connection_handle:04X}')
f'[{self.label}] --- Removing ACL stream for connection '
f'0x{connection_handle:04X}'
)
del self.acl_streams[connection_handle] del self.acl_streams[connection_handle]
# Let the other forwarder know so it can cleanup its stream as well # Let the other forwarder know so it can cleanup its stream as well
self.peer.end_acl_stream(connection_handle) self.peer.end_acl_stream(connection_handle)
def on_packet( def on_packet(self, packet):
self, timestamp: Optional[datetime.datetime], packet: HCI_Packet
) -> None:
self.packet_timestamp = timestamp
self.emit(packet) self.emit(packet)
if packet.hci_packet_type == HCI_ACL_DATA_PACKET: if packet.hci_packet_type == HCI_ACL_DATA_PACKET:
acl_packet = cast(HCI_AclDataPacket, packet)
# Look for an existing stream for this handle, create one if it is the # Look for an existing stream for this handle, create one if it is the
# first ACL packet for that connection handle # first ACL packet for that connection handle
if ( if (stream := self.acl_streams.get(packet.connection_handle)) is None:
stream := self.acl_streams.get(acl_packet.connection_handle) stream = self.start_acl_stream(packet.connection_handle)
) is None: stream.feed_packet(packet)
stream = self.start_acl_stream(acl_packet.connection_handle)
stream.feed_packet(acl_packet)
elif packet.hci_packet_type == HCI_EVENT_PACKET: elif packet.hci_packet_type == HCI_EVENT_PACKET:
event_packet = cast(HCI_Event, packet) if packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
if event_packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT: self.end_acl_stream(packet.connection_handle)
self.end_acl_stream(
cast(HCI_Disconnection_Complete_Event, packet).connection_handle
)
def emit(self, message: Any) -> None: def emit(self, message):
if self.packet_timestamp: self.emit_message(f'[{self.label}] {message}')
prefix = f"[{self.packet_timestamp.strftime('%Y-%m-%d %H:%M:%S.%f')}]"
else:
prefix = ""
self.emit_message(f'{prefix}[{self.label}] {message}')
def trace( def trace(self, packet, direction=0):
self,
packet: HCI_Packet,
direction: int = 0,
timestamp: Optional[datetime.datetime] = None,
) -> None:
if direction == 0: if direction == 0:
self.host_to_controller_analyzer.on_packet(timestamp, packet) self.host_to_controller_analyzer.on_packet(packet)
else: else:
self.controller_to_host_analyzer.on_packet(timestamp, packet) self.controller_to_host_analyzer.on_packet(packet)
def __init__( def __init__(
self, self,
host_to_controller_label: str = color('HOST->CONTROLLER', 'blue'), host_to_controller_label=color('HOST->CONTROLLER', 'blue'),
controller_to_host_label: str = color('CONTROLLER->HOST', 'cyan'), controller_to_host_label=color('CONTROLLER->HOST', 'cyan'),
emit_message: Callable[..., None] = logger.info, emit_message=logger.info
) -> None: ):
self.host_to_controller_analyzer = PacketTracer.Analyzer( self.host_to_controller_analyzer = PacketTracer.Analyzer(host_to_controller_label, emit_message)
host_to_controller_label, emit_message self.controller_to_host_analyzer = PacketTracer.Analyzer(controller_to_host_label, emit_message)
)
self.controller_to_host_analyzer = PacketTracer.Analyzer(
controller_to_host_label, emit_message
)
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer
def generate_irk() -> bytes:
return crypto.r()
def verify_rpa_with_irk(rpa: Address, irk: bytes) -> bool:
rpa_bytes = bytes(rpa)
prand_given = rpa_bytes[3:]
hash_given = rpa_bytes[:3]
hash_local = crypto.ah(irk, prand_given)
return hash_local[:3] == hash_given
+30 -2033
View File
File diff suppressed because it is too large Load Diff
-555
View File
@@ -1,555 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass
import logging
import enum
import struct
from abc import ABC, abstractmethod
from pyee import EventEmitter
from typing import Optional, Callable, TYPE_CHECKING
from typing_extensions import override
from bumble import l2cap, device
from bumble.colors import color
from bumble.core import InvalidStateError, ProtocolError
from .hci import Address
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: on
HID_CONTROL_PSM = 0x0011
HID_INTERRUPT_PSM = 0x0013
class Message:
message_type: MessageType
# Report types
class ReportType(enum.IntEnum):
OTHER_REPORT = 0x00
INPUT_REPORT = 0x01
OUTPUT_REPORT = 0x02
FEATURE_REPORT = 0x03
# Handshake parameters
class Handshake(enum.IntEnum):
SUCCESSFUL = 0x00
NOT_READY = 0x01
ERR_INVALID_REPORT_ID = 0x02
ERR_UNSUPPORTED_REQUEST = 0x03
ERR_INVALID_PARAMETER = 0x04
ERR_UNKNOWN = 0x0E
ERR_FATAL = 0x0F
# Message Type
class MessageType(enum.IntEnum):
HANDSHAKE = 0x00
CONTROL = 0x01
GET_REPORT = 0x04
SET_REPORT = 0x05
GET_PROTOCOL = 0x06
SET_PROTOCOL = 0x07
DATA = 0x0A
# Protocol modes
class ProtocolMode(enum.IntEnum):
BOOT_PROTOCOL = 0x00
REPORT_PROTOCOL = 0x01
# Control Operations
class ControlCommand(enum.IntEnum):
SUSPEND = 0x03
EXIT_SUSPEND = 0x04
VIRTUAL_CABLE_UNPLUG = 0x05
# Class Method to derive header
@classmethod
def header(cls, lower_bits: int = 0x00) -> bytes:
return bytes([(cls.message_type << 4) | lower_bits])
# HIDP messages
@dataclass
class GetReportMessage(Message):
report_type: int
report_id: int
buffer_size: int
message_type = Message.MessageType.GET_REPORT
def __bytes__(self) -> bytes:
packet_bytes = bytearray()
packet_bytes.append(self.report_id)
if self.buffer_size == 0:
return self.header(self.report_type) + packet_bytes
else:
return (
self.header(0x08 | self.report_type)
+ packet_bytes
+ struct.pack("<H", self.buffer_size)
)
@dataclass
class SetReportMessage(Message):
report_type: int
data: bytes
message_type = Message.MessageType.SET_REPORT
def __bytes__(self) -> bytes:
return self.header(self.report_type) + self.data
@dataclass
class SendControlData(Message):
report_type: int
data: bytes
message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes:
return self.header(self.report_type) + self.data
@dataclass
class GetProtocolMessage(Message):
message_type = Message.MessageType.GET_PROTOCOL
def __bytes__(self) -> bytes:
return self.header()
@dataclass
class SetProtocolMessage(Message):
protocol_mode: int
message_type = Message.MessageType.SET_PROTOCOL
def __bytes__(self) -> bytes:
return self.header(self.protocol_mode)
@dataclass
class Suspend(Message):
message_type = Message.MessageType.CONTROL
def __bytes__(self) -> bytes:
return self.header(Message.ControlCommand.SUSPEND)
@dataclass
class ExitSuspend(Message):
message_type = Message.MessageType.CONTROL
def __bytes__(self) -> bytes:
return self.header(Message.ControlCommand.EXIT_SUSPEND)
@dataclass
class VirtualCableUnplug(Message):
message_type = Message.MessageType.CONTROL
def __bytes__(self) -> bytes:
return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)
# Device sends input report, host sends output report.
@dataclass
class SendData(Message):
data: bytes
report_type: int
message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes:
return self.header(self.report_type) + self.data
@dataclass
class SendHandshakeMessage(Message):
result_code: int
message_type = Message.MessageType.HANDSHAKE
def __bytes__(self) -> bytes:
return self.header(self.result_code)
# -----------------------------------------------------------------------------
class HID(ABC, EventEmitter):
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
connection: Optional[device.Connection] = None
class Role(enum.IntEnum):
HOST = 0x00
DEVICE = 0x01
def __init__(self, device: device.Device, role: Role) -> None:
super().__init__()
self.remote_device_bd_address: Optional[Address] = None
self.device = device
self.role = role
# Register ourselves with the L2CAP channel manager
device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
device.on('connection', self.on_device_connection)
async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel
try:
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_CONTROL_PSM
)
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_ctrl_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
async def connect_interrupt_channel(self) -> None:
# Create a new L2CAP connection - interrupt channel
try:
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_INTERRUPT_PSM
)
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_intr_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_intr_channel.sink = self.on_intr_pdu
async def disconnect_interrupt_channel(self) -> None:
if self.l2cap_intr_channel is None:
raise InvalidStateError('invalid state')
channel = self.l2cap_intr_channel
self.l2cap_intr_channel = None
await channel.disconnect()
async def disconnect_control_channel(self) -> None:
if self.l2cap_ctrl_channel is None:
raise InvalidStateError('invalid state')
channel = self.l2cap_ctrl_channel
self.l2cap_ctrl_channel = None
await channel.disconnect()
def on_device_connection(self, connection: device.Connection) -> None:
self.connection = connection
self.remote_device_bd_address = connection.peer_address
connection.on('disconnection', self.on_device_disconnection)
def on_device_disconnection(self, reason: int) -> None:
self.connection = None
def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
if l2cap_channel.psm == HID_CONTROL_PSM:
self.l2cap_ctrl_channel = l2cap_channel
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
else:
self.l2cap_intr_channel = l2cap_channel
self.l2cap_intr_channel.sink = self.on_intr_pdu
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
if l2cap_channel.psm == HID_CONTROL_PSM:
self.l2cap_ctrl_channel = None
else:
self.l2cap_intr_channel = None
logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')
@abstractmethod
def on_ctrl_pdu(self, pdu: bytes) -> None:
pass
def on_intr_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
self.emit("interrupt_data", pdu)
def send_pdu_on_ctrl(self, msg: bytes) -> None:
assert self.l2cap_ctrl_channel
self.l2cap_ctrl_channel.send_pdu(msg)
def send_pdu_on_intr(self, msg: bytes) -> None:
assert self.l2cap_intr_channel
self.l2cap_intr_channel.send_pdu(msg)
def send_data(self, data: bytes) -> None:
if self.role == HID.Role.HOST:
report_type = Message.ReportType.OUTPUT_REPORT
else:
report_type = Message.ReportType.INPUT_REPORT
msg = SendData(data, report_type)
hid_message = bytes(msg)
if self.l2cap_intr_channel is not None:
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
self.send_pdu_on_intr(hid_message)
def virtual_cable_unplug(self) -> None:
msg = VirtualCableUnplug()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
# -----------------------------------------------------------------------------
class Device(HID):
class GetSetReturn(enum.IntEnum):
FAILURE = 0x00
REPORT_ID_NOT_FOUND = 0x01
ERR_UNSUPPORTED_REQUEST = 0x02
ERR_UNKNOWN = 0x03
ERR_INVALID_PARAMETER = 0x04
SUCCESS = 0xFF
class GetSetStatus:
def __init__(self) -> None:
self.data = bytearray()
self.status = 0
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE)
get_report_cb: Optional[Callable[[int, int, int], None]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
get_protocol_cb: Optional[Callable[[], None]] = None
set_protocol_cb: Optional[Callable[[int], None]] = None
@override
def on_ctrl_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
param = pdu[0] & 0x0F
message_type = pdu[0] >> 4
if message_type == Message.MessageType.GET_REPORT:
logger.debug('<<< HID GET REPORT')
self.handle_get_report(pdu)
elif message_type == Message.MessageType.SET_REPORT:
logger.debug('<<< HID SET REPORT')
self.handle_set_report(pdu)
elif message_type == Message.MessageType.GET_PROTOCOL:
logger.debug('<<< HID GET PROTOCOL')
self.handle_get_protocol(pdu)
elif message_type == Message.MessageType.SET_PROTOCOL:
logger.debug('<<< HID SET PROTOCOL')
self.handle_set_protocol(pdu)
elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA')
self.emit('control_data', pdu)
elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.SUSPEND:
logger.debug('<<< HID SUSPEND')
self.emit('suspend')
elif param == Message.ControlCommand.EXIT_SUSPEND:
logger.debug('<<< HID EXIT SUSPEND')
self.emit('exit_suspend')
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug')
else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else:
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def send_handshake_message(self, result_code: int) -> None:
msg = SendHandshakeMessage(result_code)
hid_message = bytes(msg)
logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def send_control_data(self, report_type: int, data: bytes):
msg = SendControlData(report_type=report_type, data=data)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def handle_get_report(self, pdu: bytes):
if self.get_report_cb is None:
logger.debug("GetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
buffer_flag = (pdu[0] & 0x08) >> 3
report_id = pdu[1]
logger.debug(f"buffer_flag: {buffer_flag}")
if buffer_flag == 1:
buffer_size = (pdu[3] << 8) | pdu[2]
else:
buffer_size = 0
ret = self.get_report_cb(report_id, report_type, buffer_size)
assert ret is not None
if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS:
data = bytearray()
data.append(report_id)
data.extend(ret.data)
if len(data) < self.l2cap_ctrl_channel.peer_mtu: # type: ignore[union-attr]
self.send_control_data(report_type=report_type, data=data)
else:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
self.get_report_cb = cb
logger.debug("GetReport callback registered successfully")
def handle_set_report(self, pdu: bytes):
if self.set_report_cb is None:
logger.debug("SetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
report_id = pdu[1]
report_data = pdu[2:]
report_size = len(report_data) + 1
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb(
self, cb: Callable[[int, int, int, bytes], None]
) -> None:
self.set_report_cb = cb
logger.debug("SetReport callback registered successfully")
def handle_get_protocol(self, pdu: bytes):
if self.get_protocol_cb is None:
logger.debug("GetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.get_protocol_cb()
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully")
def handle_set_protocol(self, pdu: bytes):
if self.set_protocol_cb is None:
logger.debug("SetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.set_protocol_cb(pdu[0] & 0x01)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
self.set_protocol_cb = cb
logger.debug("SetProtocol callback registered successfully")
# -----------------------------------------------------------------------------
class Host(HID):
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.HOST)
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
msg = GetReportMessage(
report_type=report_type, report_id=report_id, buffer_size=buffer_size
)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def set_report(self, report_type: int, data: bytes) -> None:
msg = SetReportMessage(report_type=report_type, data=data)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def get_protocol(self) -> None:
msg = GetProtocolMessage()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def set_protocol(self, protocol_mode: int) -> None:
msg = SetProtocolMessage(protocol_mode=protocol_mode)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def suspend(self) -> None:
msg = Suspend()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def exit_suspend(self) -> None:
msg = ExitSuspend()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
@override
def on_ctrl_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
param = pdu[0] & 0x0F
message_type = pdu[0] >> 4
if message_type == Message.MessageType.HANDSHAKE:
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
self.emit('handshake', Message.Handshake(param))
elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA')
self.emit('control_data', pdu)
elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug')
else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else:
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
+279 -852
View File
File diff suppressed because it is too large Load Diff
+84 -170
View File
@@ -20,20 +20,13 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import logging import logging
import os import os
import json import json
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type from colors import color
from typing_extensions import Self
from .colors import color
from .hci import Address from .hci import Address
if TYPE_CHECKING:
from .device import Device
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -45,10 +38,10 @@ logger = logging.getLogger(__name__)
class PairingKeys: class PairingKeys:
class Key: class Key:
def __init__(self, value, authenticated=False, ediv=None, rand=None): def __init__(self, value, authenticated=False, ediv=None, rand=None):
self.value = value self.value = value
self.authenticated = authenticated self.authenticated = authenticated
self.ediv = ediv self.ediv = ediv
self.rand = rand self.rand = rand
@classmethod @classmethod
def from_dict(cls, key_dict): def from_dict(cls, key_dict):
@@ -71,33 +64,31 @@ class PairingKeys:
return key_dict return key_dict
def __init__(self): def __init__(self):
self.address_type = None self.address_type = None
self.ltk = None self.ltk = None
self.ltk_central = None self.ltk_central = None
self.ltk_peripheral = None self.ltk_peripheral = None
self.irk = None self.irk = None
self.csrk = None self.csrk = None
self.link_key = None # Classic self.link_key = None # Classic
@staticmethod @staticmethod
def key_from_dict(keys_dict, key_name): def key_from_dict(keys_dict, key_name):
key_dict = keys_dict.get(key_name) key_dict = keys_dict.get(key_name)
if key_dict is None: if key_dict is not None:
return None return PairingKeys.Key.from_dict(key_dict)
return PairingKeys.Key.from_dict(key_dict)
@staticmethod @staticmethod
def from_dict(keys_dict): def from_dict(keys_dict):
keys = PairingKeys() keys = PairingKeys()
keys.address_type = keys_dict.get('address_type') keys.address_type = keys_dict.get('address_type')
keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk') keys.ltk = PairingKeys.key_from_dict(keys_dict, 'ltk')
keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central') keys.ltk_central = PairingKeys.key_from_dict(keys_dict, 'ltk_central')
keys.ltk_peripheral = PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral') keys.ltk_peripheral = PairingKeys.key_from_dict(keys_dict, 'ltk_peripheral')
keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk') keys.irk = PairingKeys.key_from_dict(keys_dict, 'irk')
keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk') keys.csrk = PairingKeys.key_from_dict(keys_dict, 'csrk')
keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key') keys.link_key = PairingKeys.key_from_dict(keys_dict, 'link_key')
return keys return keys
@@ -129,37 +120,33 @@ class PairingKeys:
def print(self, prefix=''): def print(self, prefix=''):
keys_dict = self.to_dict() keys_dict = self.to_dict()
for container_property, value in keys_dict.items(): for (property, value) in keys_dict.items():
if isinstance(value, dict): if type(value) is dict:
print(f'{prefix}{color(container_property, "cyan")}:') print(f'{prefix}{color(property, "cyan")}:')
for key_property, key_value in value.items(): for (key_property, key_value) in value.items():
print(f'{prefix} {color(key_property, "green")}: {key_value}') print(f'{prefix} {color(key_property, "green")}: {key_value}')
else: else:
print(f'{prefix}{color(container_property, "cyan")}: {value}') print(f'{prefix}{color(property, "cyan")}: {value}')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class KeyStore: class KeyStore:
async def delete(self, name: str): async def delete(self, name):
pass pass
async def update(self, name: str, keys: PairingKeys) -> None: async def update(self, name, keys):
pass pass
async def get(self, _name: str) -> Optional[PairingKeys]: async def get(self, name):
return None return PairingKeys()
async def get_all(self) -> List[Tuple[str, PairingKeys]]: async def get_all(self):
return [] return []
async def delete_all(self) -> None:
all_keys = await self.get_all()
await asyncio.gather(*(self.delete(name) for (name, _) in all_keys))
async def get_resolving_keys(self): async def get_resolving_keys(self):
all_keys = await self.get_all() all_keys = await self.get_all()
resolving_keys = [] resolving_keys = []
for name, keys in all_keys: for (name, keys) in all_keys:
if keys.irk is not None: if keys.irk is not None:
if keys.address_type is None: if keys.address_type is None:
address_type = Address.RANDOM_DEVICE_ADDRESS address_type = Address.RANDOM_DEVICE_ADDRESS
@@ -172,81 +159,41 @@ class KeyStore:
async def print(self, prefix=''): async def print(self, prefix=''):
entries = await self.get_all() entries = await self.get_all()
separator = '' separator = ''
for name, keys in entries: for (name, keys) in entries:
print(separator + prefix + color(name, 'yellow')) print(separator + prefix + color(name, 'yellow'))
keys.print(prefix=prefix + ' ') keys.print(prefix = prefix + ' ')
separator = '\n' separator = '\n'
@staticmethod @staticmethod
def create_for_device(device: Device) -> KeyStore: def create_for_device(device_config):
if device.config.keystore is None: if device_config.keystore is None:
return MemoryKeyStore() return None
keystore_type = device.config.keystore.split(':', 1)[0] keystore_type = device_config.keystore.split(':', 1)[0]
if keystore_type == 'JsonKeyStore': if keystore_type == 'JsonKeyStore':
return JsonKeyStore.from_device(device) return JsonKeyStore.from_device_config(device_config)
return MemoryKeyStore() return None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class JsonKeyStore(KeyStore): class JsonKeyStore(KeyStore):
""" APP_NAME = 'Bumble'
KeyStore implementation that is backed by a JSON file. APP_AUTHOR = 'Google'
KEYS_DIR = 'Pairing'
This implementation supports storing a hierarchy of key sets in a single file.
A key set is a representation of a PairingKeys object. Each key set is stored
in a map, with the address of paired peer as the key. Maps are themselves grouped
into namespaces, grouping pairing keys by controller addresses.
The JSON object model looks like:
{
"<namespace>": {
"peer-address": {
"address_type": <n>,
"irk" : {
"authenticated": <true/false>,
"value": "hex-encoded-key"
},
... other keys ...
},
... other peers ...
}
... other namespaces ...
}
A namespace is typically the BD_ADDR of a controller, since that is a convenient
unique identifier, but it may be something else.
A special namespace, called the "default" namespace, is used when instantiating this
class without a namespace. With the default namespace, reading from a file will
load an existing namespace if there is only one, which may be convenient for reading
from a file with a single key set and for which the namespace isn't known. If the
file does not include any existing key set, or if there are more than one and none
has the default name, a new one will be created with the name "__DEFAULT__".
"""
APP_NAME = 'Bumble'
APP_AUTHOR = 'Google'
KEYS_DIR = 'Pairing'
DEFAULT_NAMESPACE = '__DEFAULT__' DEFAULT_NAMESPACE = '__DEFAULT__'
DEFAULT_BASE_NAME = "keys"
def __init__(self, namespace, filename=None): def __init__(self, namespace, filename=None):
self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE
if filename is None: if filename is None:
# Use a default for the current user # Use a default for the current user
# Import here because this may not exist on all platforms
# pylint: disable=import-outside-toplevel
import appdirs import appdirs
self.directory_name = os.path.join( self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR),
) self.KEYS_DIR
base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace
json_filename = (
f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p')
) )
json_filename = f'{self.namespace}.json'.lower().replace(':', '-')
self.filename = os.path.join(self.directory_name, json_filename) self.filename = os.path.join(self.directory_name, json_filename)
else: else:
self.filename = filename self.filename = filename
@@ -254,49 +201,23 @@ class JsonKeyStore(KeyStore):
logger.debug(f'JSON keystore: {self.filename}') logger.debug(f'JSON keystore: {self.filename}')
@classmethod @staticmethod
def from_device( def from_device_config(device_config):
cls: Type[Self], device: Device, filename: Optional[str] = None params = device_config.keystore.split(':', 1)[1:]
) -> Self: namespace = str(device_config.address)
if not filename: if params:
# Extract the filename from the config if there is one filename = params[1]
if device.config.keystore is not None:
params = device.config.keystore.split(':', 1)[1:]
if params:
filename = params[0]
# Use a namespace based on the device address
if device.public_address not in (Address.ANY, Address.ANY_RANDOM):
namespace = str(device.public_address)
elif device.random_address != Address.ANY_RANDOM:
namespace = str(device.random_address)
else: else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE filename = None
return cls(namespace, filename) return JsonKeyStore(namespace, filename)
async def load(self): async def load(self):
# Try to open the file, without failing. If the file does not exist, it
# will be created upon saving.
try: try:
with open(self.filename, 'r', encoding='utf-8') as json_file: with open(self.filename, 'r') as json_file:
db = json.load(json_file) return json.load(json_file)
except FileNotFoundError: except FileNotFoundError:
db = {} return {}
# First, look for a namespace match
if self.namespace in db:
return (db, db[self.namespace])
# Then, if the namespace is the default namespace, and there's
# only one entry in the db, use that
if self.namespace == self.DEFAULT_NAMESPACE and len(db) == 1:
return next(iter(db.items()))
# Finally, just create an empty key map for the namespace
key_map = {}
db[self.namespace] = key_map
return (db, key_map)
async def save(self, db): async def save(self, db):
# Create the directory if it doesn't exist # Create the directory if it doesn't exist
@@ -305,55 +226,48 @@ class JsonKeyStore(KeyStore):
# Save to a temporary file # Save to a temporary file
temp_filename = self.filename + '.tmp' temp_filename = self.filename + '.tmp'
with open(temp_filename, 'w', encoding='utf-8') as output: with open(temp_filename, 'w') as output:
json.dump(db, output, sort_keys=True, indent=4) json.dump(db, output, sort_keys=True, indent=4)
# Atomically replace the previous file # Atomically replace the previous file
os.replace(temp_filename, self.filename) os.rename(temp_filename, self.filename)
async def delete(self, name: str) -> None: async def delete(self, name):
db, key_map = await self.load() db = await self.load()
del key_map[name]
namespace = db.get(self.namespace)
if namespace is None:
raise KeyError(name)
del namespace[name]
await self.save(db) await self.save(db)
async def update(self, name, keys): async def update(self, name, keys):
db, key_map = await self.load() db = await self.load()
key_map.setdefault(name, {}).update(keys.to_dict())
namespace = db.setdefault(self.namespace, {})
namespace[name] = keys.to_dict()
await self.save(db) await self.save(db)
async def get_all(self): async def get_all(self):
_, key_map = await self.load() db = await self.load()
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()]
async def delete_all(self): namespace = db.get(self.namespace)
db, key_map = await self.load() if namespace is None:
key_map.clear() return []
await self.save(db)
async def get(self, name: str) -> Optional[PairingKeys]: return [(name, PairingKeys.from_dict(keys)) for (name, keys) in namespace.items()]
_, key_map = await self.load()
if name not in key_map: async def get(self, name):
db = await self.load()
namespace = db.get(self.namespace)
if namespace is None:
return None return None
return PairingKeys.from_dict(key_map[name]) keys = namespace.get(name)
if keys is None:
return None
return PairingKeys.from_dict(keys)
# -----------------------------------------------------------------------------
class MemoryKeyStore(KeyStore):
all_keys: Dict[str, PairingKeys]
def __init__(self) -> None:
self.all_keys = {}
async def delete(self, name: str) -> None:
if name in self.all_keys:
del self.all_keys[name]
async def update(self, name: str, keys: PairingKeys) -> None:
self.all_keys[name] = keys
async def get(self, name: str) -> Optional[PairingKeys]:
return self.all_keys.get(name)
async def get_all(self) -> List[Tuple[str, PairingKeys]]:
return list(self.all_keys.items())
+445 -1600
View File
File diff suppressed because it is too large Load Diff
+52 -330
View File
@@ -17,22 +17,16 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
import asyncio import asyncio
import websockets
from functools import partial from functools import partial
from colors import color
from bumble.core import BT_PERIPHERAL_ROLE, BT_BR_EDR_TRANSPORT, BT_LE_TRANSPORT
from bumble.colors import color
from bumble.hci import ( from bumble.hci import (
Address, Address,
HCI_SUCCESS, HCI_SUCCESS,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_CONNECTION_TIMEOUT_ERROR, HCI_CONNECTION_TIMEOUT_ERROR
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
HCI_PAGE_TIMEOUT_ERROR,
HCI_Connection_Complete_Event,
) )
from bumble import controller
from typing import Optional, Set
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -53,24 +47,16 @@ def parse_parameters(params_str):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# TODO: add more support for various LL exchanges # TODO: add more support for various LL exchanges (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class LocalLink: class LocalLink:
''' '''
Link bus for controllers to communicate with each other Link bus for controllers to communicate with each other
''' '''
controllers: Set[controller.Controller]
def __init__(self): def __init__(self):
self.controllers = set() self.controllers = set()
self.pending_connection = None self.pending_connection = None
self.pending_classic_connection = None
############################################################
# Common utils
############################################################
def add_controller(self, controller): def add_controller(self, controller):
logger.debug(f'new controller: {controller}') logger.debug(f'new controller: {controller}')
@@ -85,41 +71,22 @@ class LocalLink:
return controller return controller
return None return None
def find_classic_controller( def on_address_changed(self, controller):
self, address: Address pass
) -> Optional[controller.Controller]:
for controller in self.controllers:
if controller.public_address == address:
return controller
return None
def get_pending_connection(self): def get_pending_connection(self):
return self.pending_connection return self.pending_connection
############################################################
# LE handlers
############################################################
def on_address_changed(self, controller):
pass
def send_advertising_data(self, sender_address, data): def send_advertising_data(self, sender_address, data):
# Send the advertising data to all controllers, except the sender # Send the advertising data to all controllers, except the sender
for controller in self.controllers: for controller in self.controllers:
if controller.random_address != sender_address: if controller.random_address != sender_address:
controller.on_link_advertising_data(sender_address, data) controller.on_link_advertising_data(sender_address, data)
def send_acl_data(self, sender_controller, destination_address, transport, data): def send_acl_data(self, sender_address, destination_address, data):
# Send the data to the first controller with a matching address # Send the data to the first controller with a matching address
if transport == BT_LE_TRANSPORT: if controller := self.find_controller(destination_address):
destination_controller = self.find_controller(destination_address) controller.on_link_acl_data(sender_address, data)
source_address = sender_controller.random_address
elif transport == BT_BR_EDR_TRANSPORT:
destination_controller = self.find_classic_controller(destination_address)
source_address = sender_controller.public_address
if destination_controller is not None:
destination_controller.on_link_acl_data(source_address, transport, data)
def on_connection_complete(self): def on_connection_complete(self):
# Check that we expect this call # Check that we expect this call
@@ -136,31 +103,23 @@ class LocalLink:
return return
# Connect to the first controller with a matching address # Connect to the first controller with a matching address
if peripheral_controller := self.find_controller( if peripheral_controller := self.find_controller(le_create_connection_command.peer_address):
le_create_connection_command.peer_address central_controller.on_link_peripheral_connection_complete(le_create_connection_command, HCI_SUCCESS)
):
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, HCI_SUCCESS
)
peripheral_controller.on_link_central_connected(central_address) peripheral_controller.on_link_central_connected(central_address)
return return
# No peripheral found # No peripheral found
central_controller.on_link_peripheral_connection_complete( central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR le_create_connection_command,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
) )
def connect(self, central_address, le_create_connection_command): def connect(self, central_address, le_create_connection_command):
logger.debug( logger.debug(f'$$$ CONNECTION {central_address} -> {le_create_connection_command.peer_address}')
f'$$$ CONNECTION {central_address} -> '
f'{le_create_connection_command.peer_address}'
)
self.pending_connection = (central_address, le_create_connection_command) self.pending_connection = (central_address, le_create_connection_command)
asyncio.get_running_loop().call_soon(self.on_connection_complete) asyncio.get_running_loop().call_soon(self.on_connection_complete)
def on_disconnection_complete( def on_disconnection_complete(self, central_address, peripheral_address, disconnect_command):
self, central_address, peripheral_address, disconnect_command
):
# Find the controller that initiated the disconnection # Find the controller that initiated the disconnection
if not (central_controller := self.find_controller(central_address)): if not (central_controller := self.find_controller(central_address)):
logger.warning('!!! Initiating controller not found') logger.warning('!!! Initiating controller not found')
@@ -168,26 +127,16 @@ class LocalLink:
# Disconnect from the first controller with a matching address # Disconnect from the first controller with a matching address
if peripheral_controller := self.find_controller(peripheral_address): if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_central_disconnected( peripheral_controller.on_link_central_disconnected(central_address, disconnect_command.reason)
central_address, disconnect_command.reason
)
central_controller.on_link_peripheral_disconnection_complete( central_controller.on_link_peripheral_disconnection_complete(disconnect_command, HCI_SUCCESS)
disconnect_command, HCI_SUCCESS
)
def disconnect(self, central_address, peripheral_address, disconnect_command): def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug( logger.debug(f'$$$ DISCONNECTION {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}')
f'$$$ DISCONNECTION {central_address} -> '
f'{peripheral_address}: reason = {disconnect_command.reason}'
)
args = [central_address, peripheral_address, disconnect_command] args = [central_address, peripheral_address, disconnect_command]
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args) asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)
# pylint: disable=too-many-arguments def on_connection_encrypted(self, central_address, peripheral_address, rand, ediv, ltk):
def on_connection_encrypted(
self, central_address, peripheral_address, rand, ediv, ltk
):
logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}') logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')
if central_controller := self.find_controller(central_address): if central_controller := self.find_controller(central_address):
@@ -196,189 +145,6 @@ class LocalLink:
if peripheral_controller := self.find_controller(peripheral_address): if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk) peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
def create_cis(
self,
central_controller: controller.Controller,
peripheral_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
)
if peripheral_controller := self.find_controller(peripheral_address):
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_request,
central_controller.random_address,
cig_id,
cis_id,
)
def accept_cis(
self,
peripheral_controller: controller.Controller,
central_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
)
if central_controller := self.find_controller(central_address):
asyncio.get_running_loop().call_soon(
central_controller.on_link_cis_established, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_established, cig_id, cis_id
)
def disconnect_cis(
self,
initiator_controller: controller.Controller,
peer_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
)
if peer_controller := self.find_controller(peer_address):
asyncio.get_running_loop().call_soon(
initiator_controller.on_link_cis_disconnected, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peer_controller.on_link_cis_disconnected, cig_id, cis_id
)
############################################################
# Classic handlers
############################################################
def classic_connect(self, initiator_controller, responder_address):
logger.debug(
f'[Classic] {initiator_controller.public_address} connects to {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
if responder_controller is None:
initiator_controller.on_classic_connection_complete(
responder_address, HCI_PAGE_TIMEOUT_ERROR
)
return
self.pending_classic_connection = (initiator_controller, responder_controller)
responder_controller.on_classic_connection_request(
initiator_controller.public_address,
HCI_Connection_Complete_Event.ACL_LINK_TYPE,
)
def classic_accept_connection(
self, responder_controller, initiator_address, responder_role
):
logger.debug(
f'[Classic] {responder_controller.public_address} accepts to connect {initiator_address}'
)
initiator_controller = self.find_classic_controller(initiator_address)
if initiator_controller is None:
responder_controller.on_classic_connection_complete(
responder_controller.public_address, HCI_PAGE_TIMEOUT_ERROR
)
return
async def task():
if responder_role != BT_PERIPHERAL_ROLE:
initiator_controller.on_classic_role_change(
responder_controller.public_address, int(not (responder_role))
)
initiator_controller.on_classic_connection_complete(
responder_controller.public_address, HCI_SUCCESS
)
asyncio.create_task(task())
responder_controller.on_classic_role_change(
initiator_controller.public_address, responder_role
)
responder_controller.on_classic_connection_complete(
initiator_controller.public_address, HCI_SUCCESS
)
self.pending_classic_connection = None
def classic_disconnect(self, initiator_controller, responder_address, reason):
logger.debug(
f'[Classic] {initiator_controller.public_address} disconnects {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
async def task():
initiator_controller.on_classic_disconnected(responder_address, reason)
asyncio.create_task(task())
responder_controller.on_classic_disconnected(
initiator_controller.public_address, reason
)
def classic_switch_role(
self, initiator_controller, responder_address, initiator_new_role
):
responder_controller = self.find_classic_controller(responder_address)
if responder_controller is None:
return
async def task():
initiator_controller.on_classic_role_change(
responder_address, initiator_new_role
)
asyncio.create_task(task())
responder_controller.on_classic_role_change(
initiator_controller.public_address, int(not (initiator_new_role))
)
def classic_sco_connect(
self,
initiator_controller: controller.Controller,
responder_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
# Initiator controller should handle it.
assert responder_controller
responder_controller.on_classic_connection_request(
initiator_controller.public_address,
link_type,
)
def classic_accept_sco_connection(
self,
responder_controller: controller.Controller,
initiator_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
)
initiator_controller = self.find_classic_controller(initiator_address)
if initiator_controller is None:
responder_controller.on_classic_sco_connection_complete(
responder_controller.public_address,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
link_type,
)
return
async def task():
initiator_controller.on_classic_sco_connection_complete(
responder_controller.public_address, HCI_SUCCESS, link_type
)
asyncio.create_task(task())
responder_controller.on_classic_sco_connection_complete(
initiator_controller.public_address, HCI_SUCCESS, link_type
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class RemoteLink: class RemoteLink:
@@ -386,18 +152,15 @@ class RemoteLink:
A Link implementation that communicates with other virtual controllers via a A Link implementation that communicates with other virtual controllers via a
WebSocket relay WebSocket relay
''' '''
def __init__(self, uri): def __init__(self, uri):
self.controller = None self.controller = None
self.uri = uri self.uri = uri
self.execution_queue = asyncio.Queue() self.execution_queue = asyncio.Queue()
self.websocket = asyncio.get_running_loop().create_future() self.websocket = asyncio.get_running_loop().create_future()
self.rpc_result = None self.rpc_result = None
self.pending_connection = None self.pending_connection = None
self.central_connections = set() # List of addresses that we have connected to self.central_connections = set() # List of addresses that we have connected to
self.peripheral_connections = ( self.peripheral_connections = set() # List of addresses that have connected to us
set()
) # List of addresses that have connected to us
# Connect and run asynchronously # Connect and run asynchronously
asyncio.create_task(self.run_connection()) asyncio.create_task(self.run_connection())
@@ -416,9 +179,6 @@ class RemoteLink:
def get_pending_connection(self): def get_pending_connection(self):
return self.pending_connection return self.pending_connection
def get_pending_classic_connection(self):
return self.pending_classic_connection
async def wait_until_connected(self): async def wait_until_connected(self):
await self.websocket await self.websocket
@@ -432,16 +192,11 @@ class RemoteLink:
try: try:
await item await item
except Exception as error: except Exception as error:
logger.warning( logger.warning(f'{color("!!! Exception in async handler:", "red")} {error}')
f'{color("!!! Exception in async handler:", "red")} {error}'
)
async def run_connection(self): async def run_connection(self):
import websockets # lazy import
# Connect to the relay # Connect to the relay
logger.debug(f'connecting to {self.uri}') logger.debug(f'connecting to {self.uri}')
# pylint: disable-next=no-member
websocket = await websockets.connect(self.uri) websocket = await websockets.connect(self.uri)
self.websocket.set_result(websocket) self.websocket.set_result(websocket)
logger.debug(f'connected to {self.uri}') logger.debug(f'connected to {self.uri}')
@@ -472,9 +227,7 @@ class RemoteLink:
self.central_connections.remove(address) self.central_connections.remove(address)
if address in self.peripheral_connections: if address in self.peripheral_connections:
self.controller.on_link_central_disconnected( self.controller.on_link_central_disconnected(address, HCI_CONNECTION_TIMEOUT_ERROR)
address, HCI_CONNECTION_TIMEOUT_ERROR
)
self.peripheral_connections.remove(address) self.peripheral_connections.remove(address)
async def on_unreachable_received(self, target): async def on_unreachable_received(self, target):
@@ -491,9 +244,7 @@ class RemoteLink:
async def on_advertisement_message_received(self, sender, advertisement): async def on_advertisement_message_received(self, sender, advertisement):
try: try:
self.controller.on_link_advertising_data( self.controller.on_link_advertising_data(Address(sender), bytes.fromhex(advertisement))
Address(sender), bytes.fromhex(advertisement)
)
except Exception: except Exception:
logger.exception('exception') logger.exception('exception')
@@ -512,11 +263,11 @@ class RemoteLink:
self.controller.on_link_central_connected(Address(sender)) self.controller.on_link_central_connected(Address(sender))
# Accept the connection by responding to it # Accept the connection by responding to it
await self.send_targeted_message(sender, 'connected') await self.send_targetted_message(sender, 'connected')
async def on_connected_message_received(self, sender, _): async def on_connected_message_received(self, sender, _):
if not self.pending_connection: if not self.pending_connection:
logger.warning('received a connection ack, but no connection is pending') logger.warn('received a connection ack, but no connection is pending')
return return
# Remember the connection # Remember the connection
@@ -524,9 +275,7 @@ class RemoteLink:
# Notify the controller # Notify the controller
logger.debug(f'connected to peripheral {self.pending_connection.peer_address}') logger.debug(f'connected to peripheral {self.pending_connection.peer_address}')
self.controller.on_link_peripheral_connection_complete( self.controller.on_link_peripheral_connection_complete(self.pending_connection, HCI_SUCCESS)
self.pending_connection, HCI_SUCCESS
)
async def on_disconnect_message_received(self, sender, message): async def on_disconnect_message_received(self, sender, message):
# Notify the controller # Notify the controller
@@ -538,7 +287,7 @@ class RemoteLink:
if sender in self.peripheral_connections: if sender in self.peripheral_connections:
self.peripheral_connections.remove(sender) self.peripheral_connections.remove(sender)
async def on_encrypted_message_received(self, sender, _): async def on_encrypted_message_received(self, sender, message):
# TODO parse params to get real args # TODO parse params to get real args
self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16)) self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16))
@@ -547,7 +296,7 @@ class RemoteLink:
websocket = await self.websocket websocket = await self.websocket
# Create a future value to hold the eventual result # Create a future value to hold the eventual result
assert self.rpc_result is None assert(self.rpc_result is None)
self.rpc_result = asyncio.get_running_loop().create_future() self.rpc_result = asyncio.get_running_loop().create_future()
# Send the command # Send the command
@@ -560,7 +309,7 @@ class RemoteLink:
# TODO: parse the result # TODO: parse the result
async def send_targeted_message(self, target, message): async def send_targetted_message(self, target, message):
# Ensure we have a connection # Ensure we have a connection
websocket = await self.websocket websocket = await self.websocket
@@ -577,62 +326,35 @@ class RemoteLink:
self.execute(self.notify_address_changed) self.execute(self.notify_address_changed)
async def send_advertising_data_to_relay(self, data): async def send_advertising_data_to_relay(self, data):
await self.send_targeted_message('*', f'advertisement:{data.hex()}') await self.send_targetted_message('*', f'advertisement:{data.hex()}')
def send_advertising_data(self, _, data): def send_advertising_data(self, sender_address, data):
self.execute(partial(self.send_advertising_data_to_relay, data)) self.execute(partial(self.send_advertising_data_to_relay, data))
async def send_acl_data_to_relay(self, peer_address, data): async def send_acl_data_to_relay(self, peer_address, data):
await self.send_targeted_message(peer_address, f'acl:{data.hex()}') await self.send_targetted_message(peer_address, f'acl:{data.hex()}')
def send_acl_data(self, _, peer_address, _transport, data): def send_acl_data(self, sender_address, peer_address, data):
# TODO: handle different transport
self.execute(partial(self.send_acl_data_to_relay, peer_address, data)) self.execute(partial(self.send_acl_data_to_relay, peer_address, data))
async def send_connection_request_to_relay(self, peer_address): async def send_connection_request_to_relay(self, peer_address):
await self.send_targeted_message(peer_address, 'connect') await self.send_targetted_message(peer_address, 'connect')
def connect(self, _, le_create_connection_command): def connect(self, central_address, le_create_connection_command):
if self.pending_connection: if self.pending_connection:
logger.warning('connection already pending') logger.warn('connection already pending')
return return
self.pending_connection = le_create_connection_command self.pending_connection = le_create_connection_command
self.execute( self.execute(partial(self.send_connection_request_to_relay, str(le_create_connection_command.peer_address)))
partial(
self.send_connection_request_to_relay,
str(le_create_connection_command.peer_address),
)
)
def on_disconnection_complete(self, disconnect_command): def on_disconnection_complete(self, disconnect_command):
self.controller.on_link_peripheral_disconnection_complete( self.controller.on_link_peripheral_disconnection_complete(disconnect_command, HCI_SUCCESS)
disconnect_command, HCI_SUCCESS
)
def disconnect(self, central_address, peripheral_address, disconnect_command): def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug( logger.debug(f'disconnect {central_address} -> {peripheral_address}: reason = {disconnect_command.reason}')
f'disconnect {central_address} -> ' self.execute(partial(self.send_targetted_message, peripheral_address, f'disconnect:reason={disconnect_command.reason}'))
f'{peripheral_address}: reason = {disconnect_command.reason}' asyncio.get_running_loop().call_soon(self.on_disconnection_complete, disconnect_command)
)
self.execute(
partial(
self.send_targeted_message,
peripheral_address,
f'disconnect:reason={disconnect_command.reason}',
)
)
asyncio.get_running_loop().call_soon(
self.on_disconnection_complete, disconnect_command
)
def on_connection_encrypted(self, _, peripheral_address, rand, ediv, ltk): def on_connection_encrypted(self, central_address, peripheral_address, rand, ediv, ltk):
asyncio.get_running_loop().call_soon( asyncio.get_running_loop().call_soon(self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk)
self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk self.execute(partial(self.send_targetted_message, peripheral_address, f'encrypted:ltk={ltk.hex()}'))
)
self.execute(
partial(
self.send_targeted_message,
peripheral_address,
f'encrypted:ltk={ltk.hex()}',
)
)
-262
View File
@@ -1,262 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
from dataclasses import dataclass
from typing import Optional, Tuple
from .hci import (
Address,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
HCI_DISPLAY_ONLY_IO_CAPABILITY,
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
HCI_KEYBOARD_ONLY_IO_CAPABILITY,
)
from .smp import (
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
SMP_KEYBOARD_ONLY_IO_CAPABILITY,
SMP_DISPLAY_ONLY_IO_CAPABILITY,
SMP_DISPLAY_YES_NO_IO_CAPABILITY,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY,
SMP_ENC_KEY_DISTRIBUTION_FLAG,
SMP_ID_KEY_DISTRIBUTION_FLAG,
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
SMP_LINK_KEY_DISTRIBUTION_FLAG,
OobContext,
OobLegacyContext,
OobSharedData,
)
from .core import AdvertisingData, LeRole
# -----------------------------------------------------------------------------
@dataclass
class OobData:
"""OOB data that can be sent from one device to another."""
address: Optional[Address] = None
role: Optional[LeRole] = None
shared_data: Optional[OobSharedData] = None
legacy_context: Optional[OobLegacyContext] = None
@classmethod
def from_ad(cls, ad: AdvertisingData) -> OobData:
instance = cls()
shared_data_c: Optional[bytes] = None
shared_data_r: Optional[bytes] = None
for ad_type, ad_data in ad.ad_structures:
if ad_type == AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS:
instance.address = Address(ad_data)
elif ad_type == AdvertisingData.LE_ROLE:
instance.role = LeRole(ad_data[0])
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE:
shared_data_c = ad_data
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE:
shared_data_r = ad_data
elif ad_type == AdvertisingData.SECURITY_MANAGER_TK_VALUE:
instance.legacy_context = OobLegacyContext(tk=ad_data)
if shared_data_c and shared_data_r:
instance.shared_data = OobSharedData(c=shared_data_c, r=shared_data_r)
return instance
def to_ad(self) -> AdvertisingData:
ad_structures = []
if self.address is not None:
ad_structures.append(
(AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS, bytes(self.address))
)
if self.role is not None:
ad_structures.append((AdvertisingData.LE_ROLE, bytes([self.role])))
if self.shared_data is not None:
ad_structures.extend(self.shared_data.to_ad().ad_structures)
if self.legacy_context is not None:
ad_structures.append(
(AdvertisingData.SECURITY_MANAGER_TK_VALUE, self.legacy_context.tk)
)
return AdvertisingData(ad_structures)
# -----------------------------------------------------------------------------
class PairingDelegate:
"""Abstract base class for Pairing Delegates."""
# I/O Capabilities.
# These are defined abstractly, and can be mapped to specific Classic pairing
# and/or SMP constants.
class IoCapability(enum.IntEnum):
NO_OUTPUT_NO_INPUT = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
KEYBOARD_INPUT_ONLY = SMP_KEYBOARD_ONLY_IO_CAPABILITY
DISPLAY_OUTPUT_ONLY = SMP_DISPLAY_ONLY_IO_CAPABILITY
DISPLAY_OUTPUT_AND_YES_NO_INPUT = SMP_DISPLAY_YES_NO_IO_CAPABILITY
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = SMP_KEYBOARD_DISPLAY_IO_CAPABILITY
# Direct names for backward compatibility.
NO_OUTPUT_NO_INPUT = IoCapability.NO_OUTPUT_NO_INPUT
KEYBOARD_INPUT_ONLY = IoCapability.KEYBOARD_INPUT_ONLY
DISPLAY_OUTPUT_ONLY = IoCapability.DISPLAY_OUTPUT_ONLY
DISPLAY_OUTPUT_AND_YES_NO_INPUT = IoCapability.DISPLAY_OUTPUT_AND_YES_NO_INPUT
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = IoCapability.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT
# Key Distribution [LE only]
class KeyDistribution(enum.IntFlag):
DISTRIBUTE_ENCRYPTION_KEY = SMP_ENC_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_IDENTITY_KEY = SMP_ID_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_SIGNING_KEY = SMP_SIGN_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_LINK_KEY = SMP_LINK_KEY_DISTRIBUTION_FLAG
DEFAULT_KEY_DISTRIBUTION: KeyDistribution = (
KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY
| KeyDistribution.DISTRIBUTE_IDENTITY_KEY
)
# Default mapping from abstract to Classic I/O capabilities.
# Subclasses may override this if they prefer a different mapping.
CLASSIC_IO_CAPABILITIES_MAP = {
NO_OUTPUT_NO_INPUT: HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
KEYBOARD_INPUT_ONLY: HCI_KEYBOARD_ONLY_IO_CAPABILITY,
DISPLAY_OUTPUT_ONLY: HCI_DISPLAY_ONLY_IO_CAPABILITY,
DISPLAY_OUTPUT_AND_YES_NO_INPUT: HCI_DISPLAY_YES_NO_IO_CAPABILITY,
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT: HCI_DISPLAY_YES_NO_IO_CAPABILITY,
}
io_capability: IoCapability
local_initiator_key_distribution: KeyDistribution
local_responder_key_distribution: KeyDistribution
def __init__(
self,
io_capability: IoCapability = NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
) -> None:
self.io_capability = io_capability
self.local_initiator_key_distribution = local_initiator_key_distribution
self.local_responder_key_distribution = local_responder_key_distribution
@property
def classic_io_capability(self) -> int:
"""Map the abstract I/O capability to a Classic constant."""
# pylint: disable=line-too-long
return self.CLASSIC_IO_CAPABILITIES_MAP.get(
self.io_capability, HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
)
@property
def smp_io_capability(self) -> int:
"""Map the abstract I/O capability to an SMP constant."""
# This is just a 1-1 direct mapping
return self.io_capability
async def accept(self) -> bool:
"""Accept or reject a Pairing request."""
return True
async def confirm(self, auto: bool = False) -> bool:
"""
Respond yes or no to a Pairing confirmation question.
The `auto` parameter stands for automatic confirmation.
"""
return True
# pylint: disable-next=unused-argument
async def compare_numbers(self, number: int, digits: int) -> bool:
"""Compare two numbers."""
return True
async def get_number(self) -> Optional[int]:
"""
Return an optional number as an answer to a passkey request.
Returning `None` will result in a negative reply.
"""
return 0
async def get_string(self, max_length: int) -> Optional[str]:
"""
Return a string whose utf-8 encoding is up to max_length bytes.
"""
return None
# pylint: disable-next=unused-argument
async def display_number(self, number: int, digits: int) -> None:
"""Display a number."""
# [LE only]
async def key_distribution_response(
self, peer_initiator_key_distribution: int, peer_responder_key_distribution: int
) -> Tuple[int, int]:
"""
Return the key distribution response in an SMP protocol context.
NOTE: since it is only used by the SMP protocol, this method's input and output
are directly as integers, using the SMP constants, rather than the abstract
KeyDistribution enums.
"""
return (
int(
peer_initiator_key_distribution & self.local_initiator_key_distribution
),
int(
peer_responder_key_distribution & self.local_responder_key_distribution
),
)
# -----------------------------------------------------------------------------
class PairingConfig:
"""Configuration for the Pairing protocol."""
class AddressType(enum.IntEnum):
PUBLIC = Address.PUBLIC_DEVICE_ADDRESS
RANDOM = Address.RANDOM_DEVICE_ADDRESS
@dataclass
class OobConfig:
"""Config for OOB pairing."""
our_context: Optional[OobContext]
peer_data: Optional[OobSharedData]
legacy_context: Optional[OobLegacyContext]
def __init__(
self,
sc: bool = True,
mitm: bool = True,
bonding: bool = True,
delegate: Optional[PairingDelegate] = None,
identity_address_type: Optional[AddressType] = None,
oob: Optional[OobConfig] = None,
) -> None:
self.sc = sc
self.mitm = mitm
self.bonding = bonding
self.delegate = delegate or PairingDelegate()
self.identity_address_type = identity_address_type
self.oob = oob
def __str__(self) -> str:
return (
f'PairingConfig(sc={self.sc}, '
f'mitm={self.mitm}, bonding={self.bonding}, '
f'identity_address_type={self.identity_address_type}, '
f'delegate[{self.delegate.io_capability}]), '
f'oob[{self.oob}])'
)
-105
View File
@@ -1,105 +0,0 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Bumble Pandora server.
This module implement the Pandora Bluetooth test APIs for the Bumble stack.
"""
__version__ = "0.0.1"
import grpc
import grpc.aio
from .config import Config
from .device import PandoraDevice
from .host import HostService
from .security import SecurityService, SecurityStorageService
from pandora.host_grpc_aio import add_HostServicer_to_server
from pandora.security_grpc_aio import (
add_SecurityServicer_to_server,
add_SecurityStorageServicer_to_server,
)
from typing import Callable, List, Optional
# public symbols
__all__ = [
'register_servicer_hook',
'serve',
'Config',
'PandoraDevice',
]
# Add servicers hooks.
_SERVICERS_HOOKS: List[Callable[[PandoraDevice, Config, grpc.aio.Server], None]] = []
def register_servicer_hook(
hook: Callable[[PandoraDevice, Config, grpc.aio.Server], None]
) -> None:
_SERVICERS_HOOKS.append(hook)
async def serve(
bumble: PandoraDevice,
config: Config = Config(),
grpc_server: Optional[grpc.aio.Server] = None,
port: int = 0,
) -> None:
# initialize a gRPC server if not provided.
server = grpc_server if grpc_server is not None else grpc.aio.server()
port = server.add_insecure_port(f'localhost:{port}')
try:
while True:
# load server config from dict.
config.load_from_dict(bumble.config.get('server', {}))
# add Pandora services to the gRPC server.
add_HostServicer_to_server(
HostService(server, bumble.device, config), server
)
add_SecurityServicer_to_server(
SecurityService(bumble.device, config), server
)
add_SecurityStorageServicer_to_server(
SecurityStorageService(bumble.device, config), server
)
# call hooks if any.
for hook in _SERVICERS_HOOKS:
hook(bumble, config, server)
# open device.
await bumble.open()
try:
# Pandora require classic devices to be discoverable & connectable.
if bumble.device.classic_enabled:
await bumble.device.set_discoverable(True)
await bumble.device.set_connectable(True)
# start & serve gRPC server.
await server.start()
await server.wait_for_termination()
finally:
# close device.
await bumble.close()
# re-initialize the gRPC server.
server = grpc.aio.server()
server.add_insecure_port(f'localhost:{port}')
finally:
# stop server.
await server.stop(None)
-56
View File
@@ -1,56 +0,0 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from bumble.pairing import PairingConfig, PairingDelegate
from dataclasses import dataclass
from typing import Any, Dict
@dataclass
class Config:
io_capability: PairingDelegate.IoCapability = PairingDelegate.NO_OUTPUT_NO_INPUT
identity_address_type: PairingConfig.AddressType = PairingConfig.AddressType.RANDOM
pairing_sc_enable: bool = True
pairing_mitm_enable: bool = True
pairing_bonding_enable: bool = True
smp_local_initiator_key_distribution: PairingDelegate.KeyDistribution = (
PairingDelegate.DEFAULT_KEY_DISTRIBUTION
)
smp_local_responder_key_distribution: PairingDelegate.KeyDistribution = (
PairingDelegate.DEFAULT_KEY_DISTRIBUTION
)
def load_from_dict(self, config: Dict[str, Any]) -> None:
io_capability_name: str = config.get(
'io_capability', 'no_output_no_input'
).upper()
self.io_capability = getattr(PairingDelegate, io_capability_name)
identity_address_type_name: str = config.get(
'identity_address_type', 'random'
).upper()
self.identity_address_type = getattr(
PairingConfig.AddressType, identity_address_type_name
)
self.pairing_sc_enable = config.get('pairing_sc_enable', True)
self.pairing_mitm_enable = config.get('pairing_mitm_enable', True)
self.pairing_bonding_enable = config.get('pairing_bonding_enable', True)
self.smp_local_initiator_key_distribution = config.get(
'smp_local_initiator_key_distribution',
PairingDelegate.DEFAULT_KEY_DISTRIBUTION,
)
self.smp_local_responder_key_distribution = config.get(
'smp_local_responder_key_distribution',
PairingDelegate.DEFAULT_KEY_DISTRIBUTION,
)
-164
View File
@@ -1,164 +0,0 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic & dependency free Bumble (reference) device."""
from __future__ import annotations
from bumble import transport
from bumble.core import (
BT_GENERIC_AUDIO_SERVICE,
BT_HANDSFREE_SERVICE,
BT_L2CAP_PROTOCOL_ID,
BT_RFCOMM_PROTOCOL_ID,
)
from bumble.device import Device, DeviceConfiguration
from bumble.host import Host
from bumble.sdp import (
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement,
ServiceAttribute,
)
from typing import Any, Dict, List, Optional
# Default rootcanal HCI TCP address
ROOTCANAL_HCI_ADDRESS = "localhost:6402"
class PandoraDevice:
"""
Small wrapper around a Bumble device and it's HCI transport.
Notes:
- The Bumble device is idle by default.
- Repetitive calls to `open`/`close` will result on new Bumble device instances.
"""
# Bumble device instance & configuration.
device: Device
config: Dict[str, Any]
# HCI transport name & instance.
_hci_name: str
_hci: Optional[transport.Transport] # type: ignore[name-defined]
def __init__(self, config: Dict[str, Any]) -> None:
self.config = config
self.device = _make_device(config)
self._hci_name = config.get(
'transport', f"tcp-client:{config.get('tcp', ROOTCANAL_HCI_ADDRESS)}"
)
self._hci = None
@property
def idle(self) -> bool:
return self._hci is None
async def open(self) -> None:
if self._hci is not None:
return
# open HCI transport & set device host.
self._hci = await transport.open_transport(self._hci_name)
self.device.host = Host(controller_source=self._hci.source, controller_sink=self._hci.sink) # type: ignore[no-untyped-call]
# power-on.
await self.device.power_on()
async def close(self) -> None:
if self._hci is None:
return
# flush & re-initialize device.
await self.device.host.flush()
self.device.host = None # type: ignore[assignment]
self.device = _make_device(self.config)
# close HCI transport.
await self._hci.close()
self._hci = None
async def reset(self) -> None:
await self.close()
await self.open()
def info(self) -> Optional[Dict[str, str]]:
return {
'public_bd_address': str(self.device.public_address),
'random_address': str(self.device.random_address),
}
def _make_device(config: Dict[str, Any]) -> Device:
"""Initialize an idle Bumble device instance."""
# initialize bumble device.
device_config = DeviceConfiguration()
device_config.load_from_dict(config)
device = Device(config=device_config, host=None)
# Add fake a2dp service to avoid Android disconnect
device.sdp_service_records = _make_sdp_records(1)
return device
# TODO(b/267540823): remove when Pandora A2dp is supported
def _make_sdp_records(rfcomm_channel: int) -> Dict[int, List[ServiceAttribute]]:
return {
0x00010001: [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(0x00010001),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.uuid(BT_GENERIC_AUDIO_SERVICE),
]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
[
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(rfcomm_channel),
]
),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_HANDSFREE_SERVICE),
DataElement.unsigned_integer_16(0x0105),
]
)
]
),
),
]
}
File diff suppressed because it is too large Load Diff
View File
-559
View File
@@ -1,559 +0,0 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import contextlib
import grpc
import logging
from . import utils
from .config import Config
from bumble import hci
from bumble.core import (
BT_BR_EDR_TRANSPORT,
BT_LE_TRANSPORT,
BT_PERIPHERAL_ROLE,
ProtocolError,
)
from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error
from bumble.utils import EventWatcher
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
from pandora.host_pb2 import Connection
from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer
from pandora.security_pb2 import (
LE_LEVEL1,
LE_LEVEL2,
LE_LEVEL3,
LE_LEVEL4,
LEVEL0,
LEVEL1,
LEVEL2,
LEVEL3,
LEVEL4,
DeleteBondRequest,
IsBondedRequest,
LESecurityLevel,
PairingEvent,
PairingEventAnswer,
SecureRequest,
SecureResponse,
SecurityLevel,
WaitSecurityRequest,
WaitSecurityResponse,
)
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Optional, Union
class PairingDelegate(BasePairingDelegate):
def __init__(
self,
connection: BumbleConnection,
service: "SecurityService",
io_capability: BasePairingDelegate.IoCapability = BasePairingDelegate.NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
) -> None:
self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(),
{'service_name': 'Security', 'device': connection.device},
)
self.connection = connection
self.service = service
super().__init__(
io_capability,
local_initiator_key_distribution,
local_responder_key_distribution,
)
async def accept(self) -> bool:
return True
def add_origin(self, ev: PairingEvent) -> PairingEvent:
if not self.connection.is_incomplete:
assert ev.connection
ev.connection.CopyFrom(
Connection(
cookie=any_pb2.Any(value=self.connection.handle.to_bytes(4, 'big'))
)
)
else:
# In BR/EDR, connection may not be complete,
# use address instead
assert self.connection.transport == BT_BR_EDR_TRANSPORT
ev.address = bytes(reversed(bytes(self.connection.peer_address)))
return ev
async def confirm(self, auto: bool = False) -> bool:
self.log.debug(
f"Pairing event: `just_works` (io_capability: {self.io_capability})"
)
if self.service.event_queue is None or self.service.event_answer is None:
return True
event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore
assert answer.event == event
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm
async def compare_numbers(self, number: int, digits: int = 6) -> bool:
self.log.debug(
f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})"
)
if self.service.event_queue is None or self.service.event_answer is None:
raise RuntimeError('security: unhandled number comparison request')
event = self.add_origin(PairingEvent(numeric_comparison=number))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore
assert answer.event == event
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm
async def get_number(self) -> Optional[int]:
self.log.debug(
f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
)
if self.service.event_queue is None or self.service.event_answer is None:
raise RuntimeError('security: unhandled number request')
event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore
assert answer.event == event
if answer.answer_variant() is None:
return None
assert answer.answer_variant() == 'passkey'
return answer.passkey
async def get_string(self, max_length: int) -> Optional[str]:
self.log.debug(
f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
)
if self.service.event_queue is None or self.service.event_answer is None:
raise RuntimeError('security: unhandled pin_code request')
event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore
assert answer.event == event
if answer.answer_variant() is None:
return None
assert answer.answer_variant() == 'pin'
if answer.pin is None:
return None
pin = answer.pin.decode('utf-8')
if not pin or len(pin) > max_length:
raise ValueError(f'Pin must be utf-8 encoded up to {max_length} bytes')
return pin
async def display_number(self, number: int, digits: int = 6) -> None:
if (
self.connection.transport == BT_BR_EDR_TRANSPORT
and self.io_capability == BasePairingDelegate.DISPLAY_OUTPUT_ONLY
):
return
self.log.debug(
f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})"
)
if self.service.event_queue is None:
raise RuntimeError('security: unhandled number display request')
event = self.add_origin(PairingEvent(passkey_entry_notification=number))
self.service.event_queue.put_nowait(event)
BR_LEVEL_REACHED: Dict[SecurityLevel, Callable[[BumbleConnection], bool]] = {
LEVEL0: lambda connection: True,
LEVEL1: lambda connection: connection.encryption == 0 or connection.authenticated,
LEVEL2: lambda connection: connection.encryption != 0 and connection.authenticated,
LEVEL3: lambda connection: connection.encryption != 0
and connection.authenticated
and connection.link_key_type
in (
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
),
LEVEL4: lambda connection: connection.encryption
== hci.HCI_Encryption_Change_Event.AES_CCM
and connection.authenticated
and connection.link_key_type
== hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
}
LE_LEVEL_REACHED: Dict[LESecurityLevel, Callable[[BumbleConnection], bool]] = {
LE_LEVEL1: lambda connection: True,
LE_LEVEL2: lambda connection: connection.encryption != 0,
LE_LEVEL3: lambda connection: connection.encryption != 0
and connection.authenticated,
LE_LEVEL4: lambda connection: connection.encryption != 0
and connection.authenticated
and connection.sc,
}
class SecurityService(SecurityServicer):
def __init__(self, device: Device, config: Config) -> None:
self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(), {'service_name': 'Security', 'device': device}
)
self.event_queue: Optional[asyncio.Queue[PairingEvent]] = None
self.event_answer: Optional[AsyncIterator[PairingEventAnswer]] = None
self.device = device
self.config = config
def pairing_config_factory(connection: BumbleConnection) -> PairingConfig:
return PairingConfig(
sc=config.pairing_sc_enable,
mitm=config.pairing_mitm_enable,
bonding=config.pairing_bonding_enable,
identity_address_type=(
PairingConfig.AddressType.PUBLIC
if connection.self_address.is_public
else config.identity_address_type
),
delegate=PairingDelegate(
connection,
self,
io_capability=config.io_capability,
local_initiator_key_distribution=config.smp_local_initiator_key_distribution,
local_responder_key_distribution=config.smp_local_responder_key_distribution,
),
)
self.device.pairing_config_factory = pairing_config_factory
@utils.rpc
async def OnPairing(
self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
) -> AsyncGenerator[PairingEvent, None]:
self.log.debug('OnPairing')
if self.event_queue is not None:
raise RuntimeError('already streaming pairing events')
if len(self.device.connections):
raise RuntimeError(
'the `OnPairing` method shall be initiated before establishing any connections.'
)
self.event_queue = asyncio.Queue()
self.event_answer = request
try:
while event := await self.event_queue.get():
yield event
finally:
self.event_queue = None
self.event_answer = None
@utils.rpc
async def Secure(
self, request: SecureRequest, context: grpc.ServicerContext
) -> SecureResponse:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.debug(f"Secure: {connection_handle}")
connection = self.device.lookup_connection(connection_handle)
assert connection
oneof = request.WhichOneof('level')
level = getattr(request, oneof)
assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[
connection.transport
] == oneof
# security level already reached
if self.reached_security_level(connection, level):
return SecureResponse(success=empty_pb2.Empty())
# trigger pairing if needed
if self.need_pairing(connection, level):
try:
self.log.debug('Pair...')
security_result = asyncio.get_running_loop().create_future()
with contextlib.closing(EventWatcher()) as watcher:
@watcher.on(connection, 'pairing')
def on_pairing(*_: Any) -> None:
security_result.set_result('success')
@watcher.on(connection, 'pairing_failure')
def on_pairing_failure(*_: Any) -> None:
security_result.set_result('pairing_failure')
@watcher.on(connection, 'disconnection')
def on_disconnection(*_: Any) -> None:
security_result.set_result('connection_died')
if (
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
connection.request_pairing()
else:
await connection.pair()
result = await security_result
self.log.debug(f'Pairing session complete, status={result}')
if result != 'success':
return SecureResponse(**{result: empty_pb2.Empty()})
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
except (HCI_Error, ProtocolError) as e:
self.log.warning(f"Pairing failure: {e}")
return SecureResponse(pairing_failure=empty_pb2.Empty())
# trigger authentication if needed
if self.need_authentication(connection, level):
try:
self.log.debug('Authenticate...')
await connection.authenticate()
self.log.debug('Authenticated')
except asyncio.CancelledError:
self.log.warning("Connection died during authentication")
return SecureResponse(connection_died=empty_pb2.Empty())
except (HCI_Error, ProtocolError) as e:
self.log.warning(f"Authentication failure: {e}")
return SecureResponse(authentication_failure=empty_pb2.Empty())
# trigger encryption if needed
if self.need_encryption(connection, level):
try:
self.log.debug('Encrypt...')
await connection.encrypt()
self.log.debug('Encrypted')
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
except (HCI_Error, ProtocolError) as e:
self.log.warning(f"Encryption failure: {e}")
return SecureResponse(encryption_failure=empty_pb2.Empty())
# security level has been reached ?
if self.reached_security_level(connection, level):
return SecureResponse(success=empty_pb2.Empty())
return SecureResponse(not_reached=empty_pb2.Empty())
@utils.rpc
async def WaitSecurity(
self, request: WaitSecurityRequest, context: grpc.ServicerContext
) -> WaitSecurityResponse:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
self.log.debug(f"WaitSecurity: {connection_handle}")
connection = self.device.lookup_connection(connection_handle)
assert connection
assert request.level
level = request.level
assert {BT_BR_EDR_TRANSPORT: 'classic', BT_LE_TRANSPORT: 'le'}[
connection.transport
] == request.level_variant()
wait_for_security: asyncio.Future[str] = (
asyncio.get_running_loop().create_future()
)
authenticate_task: Optional[asyncio.Future[None]] = None
pair_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None:
assert connection
if (encryption := connection.encryption) != 0:
self.log.debug('Disable encryption...')
try:
await connection.encrypt(enable=False)
except:
pass
self.log.debug('Disable encryption: done')
self.log.debug('Authenticate...')
await connection.authenticate()
self.log.debug('Authenticate: done')
if encryption != 0 and connection.encryption != encryption:
self.log.debug('Re-enable encryption...')
await connection.encrypt()
self.log.debug('Re-enable encryption: done')
def set_failure(name: str) -> Callable[..., None]:
def wrapper(*args: Any) -> None:
self.log.debug(f'Wait for security: error `{name}`: {args}')
wait_for_security.set_result(name)
return wrapper
def try_set_success(*_: Any) -> None:
assert connection
if self.reached_security_level(connection, level):
self.log.debug('Wait for security: done')
wait_for_security.set_result('success')
def on_encryption_change(*_: Any) -> None:
assert connection
if self.reached_security_level(connection, level):
self.log.debug('Wait for security: done')
wait_for_security.set_result('success')
elif (
connection.transport == BT_BR_EDR_TRANSPORT
and self.need_authentication(connection, level)
):
nonlocal authenticate_task
if authenticate_task is None:
authenticate_task = asyncio.create_task(authenticate())
def pair(*_: Any) -> None:
if self.need_pairing(connection, level):
pair_task = asyncio.create_task(connection.pair())
listeners: Dict[str, Callable[..., None]] = {
'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'),
'connection_authentication_failure': set_failure('authentication_failure'),
'connection_encryption_failure': set_failure('encryption_failure'),
'pairing': try_set_success,
'connection_authentication': try_set_success,
'connection_encryption_change': on_encryption_change,
'classic_pairing': try_set_success,
'classic_pairing_failure': set_failure('pairing_failure'),
'security_request': pair,
}
with contextlib.closing(EventWatcher()) as watcher:
# register event handlers
for event, listener in listeners.items():
watcher.on(connection, event, listener)
# security level already reached
if self.reached_security_level(connection, level):
return WaitSecurityResponse(success=empty_pb2.Empty())
self.log.debug('Wait for security...')
kwargs = {}
kwargs[await wait_for_security] = empty_pb2.Empty()
# wait for `authenticate` to finish if any
if authenticate_task is not None:
self.log.debug('Wait for authentication...')
try:
await authenticate_task # type: ignore
except:
pass
self.log.debug('Authenticated')
# wait for `pair` to finish if any
if pair_task is not None:
self.log.debug('Wait for authentication...')
try:
await pair_task # type: ignore
except:
pass
self.log.debug('paired')
return WaitSecurityResponse(**kwargs)
def reached_security_level(
self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
) -> bool:
self.log.debug(
str(
{
'level': level,
'encryption': connection.encryption,
'authenticated': connection.authenticated,
'sc': connection.sc,
'link_key_type': connection.link_key_type,
}
)
)
if isinstance(level, LESecurityLevel):
return LE_LEVEL_REACHED[level](connection)
return BR_LEVEL_REACHED[level](connection)
def need_pairing(self, connection: BumbleConnection, level: int) -> bool:
if connection.transport == BT_LE_TRANSPORT:
return level >= LE_LEVEL3 and not connection.authenticated
return False
def need_authentication(self, connection: BumbleConnection, level: int) -> bool:
if connection.transport == BT_LE_TRANSPORT:
return False
if level == LEVEL2 and connection.encryption != 0:
return not connection.authenticated
return level >= LEVEL2 and not connection.authenticated
def need_encryption(self, connection: BumbleConnection, level: int) -> bool:
# TODO(abel): need to support MITM
if connection.transport == BT_LE_TRANSPORT:
return level == LE_LEVEL2 and not connection.encryption
return level >= LEVEL2 and not connection.encryption
class SecurityStorageService(SecurityStorageServicer):
def __init__(self, device: Device, config: Config) -> None:
self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(), {'service_name': 'SecurityStorage', 'device': device}
)
self.device = device
self.config = config
@utils.rpc
async def IsBonded(
self, request: IsBondedRequest, context: grpc.ServicerContext
) -> wrappers_pb2.BoolValue:
address = utils.address_from_request(request, request.WhichOneof("address"))
self.log.debug(f"IsBonded: {address}")
if self.device.keystore is not None:
is_bonded = await self.device.keystore.get(str(address)) is not None
else:
is_bonded = False
return wrappers_pb2.BoolValue(value=is_bonded)
@utils.rpc
async def DeleteBond(
self, request: DeleteBondRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
address = utils.address_from_request(request, request.WhichOneof("address"))
self.log.debug(f"DeleteBond: {address}")
if self.device.keystore is not None:
with contextlib.suppress(KeyError):
await self.device.keystore.delete(str(address))
return empty_pb2.Empty()
-113
View File
@@ -1,113 +0,0 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import functools
import grpc
import inspect
import logging
from bumble.device import Device
from bumble.hci import Address
from google.protobuf.message import Message # pytype: disable=pyi-error
from typing import Any, Dict, Generator, MutableMapping, Optional, Tuple
ADDRESS_TYPES: Dict[str, int] = {
"public": Address.PUBLIC_DEVICE_ADDRESS,
"random": Address.RANDOM_DEVICE_ADDRESS,
"public_identity": Address.PUBLIC_IDENTITY_ADDRESS,
"random_static_identity": Address.RANDOM_IDENTITY_ADDRESS,
}
def address_from_request(request: Message, field: Optional[str]) -> Address:
if field is None:
return Address.ANY
return Address(bytes(reversed(getattr(request, field))), ADDRESS_TYPES[field])
class BumbleServerLoggerAdapter(logging.LoggerAdapter): # type: ignore
"""Formats logs from the PandoraClient."""
def process(
self, msg: str, kwargs: MutableMapping[str, Any]
) -> Tuple[str, MutableMapping[str, Any]]:
assert self.extra
service_name = self.extra['service_name']
assert isinstance(service_name, str)
device = self.extra['device']
assert isinstance(device, Device)
addr_bytes = bytes(
reversed(bytes(device.public_address))
) # pytype: disable=attribute-error
addr = ':'.join([f'{x:02X}' for x in addr_bytes[4:]])
return (f'[bumble.{service_name}:{addr}] {msg}', kwargs)
@contextlib.contextmanager
def exception_to_rpc_error(
context: grpc.ServicerContext,
) -> Generator[None, None, None]:
try:
yield None
except NotImplementedError as e:
context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore
context.set_details(str(e)) # type: ignore
except ValueError as e:
context.set_code(grpc.StatusCode.INVALID_ARGUMENT) # type: ignore
context.set_details(str(e)) # type: ignore
except RuntimeError as e:
context.set_code(grpc.StatusCode.ABORTED) # type: ignore
context.set_details(str(e)) # type: ignore
# Decorate an RPC servicer method with a wrapper that transform exceptions to gRPC errors.
def rpc(func: Any) -> Any:
@functools.wraps(func)
async def asyncgen_wrapper(
self: Any, request: Any, context: grpc.ServicerContext
) -> Any:
with exception_to_rpc_error(context):
async for v in func(self, request, context):
yield v
@functools.wraps(func)
async def async_wrapper(
self: Any, request: Any, context: grpc.ServicerContext
) -> Any:
with exception_to_rpc_error(context):
return await func(self, request, context)
@functools.wraps(func)
def gen_wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:
with exception_to_rpc_error(context):
for v in func(self, request, context):
yield v
@functools.wraps(func)
def wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:
with exception_to_rpc_error(context):
return func(self, request, context)
if inspect.isasyncgenfunction(func):
return asyncgen_wrapper
if inspect.iscoroutinefunction(func):
return async_wrapper
if inspect.isgenerator(func):
return gen_wrapper
return wrapper
-193
View File
@@ -1,193 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import struct
import logging
from typing import List, Optional
from bumble import l2cap
from ..core import AdvertisingData
from ..device import Device, Connection
from ..gatt import (
GATT_ASHA_SERVICE,
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
GATT_ASHA_VOLUME_CHARACTERISTIC,
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
TemplateService,
Characteristic,
CharacteristicValue,
)
from ..utils import AsyncRunner
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class AshaService(TemplateService):
UUID = GATT_ASHA_SERVICE
OPCODE_START = 1
OPCODE_STOP = 2
OPCODE_STATUS = 3
PROTOCOL_VERSION = 0x01
RESERVED_FOR_FUTURE_USE = [00, 00]
FEATURE_MAP = [0x01] # [LE CoC audio output streaming supported]
SUPPORTED_CODEC_ID = [0x02, 0x01] # Codec IDs [G.722 at 16 kHz]
RENDER_DELAY = [00, 00]
def __init__(self, capability: int, hisyncid: List[int], device: Device, psm=0):
self.hisyncid = hisyncid
self.capability = capability # Device Capabilities [Left, Monaural]
self.device = device
self.audio_out_data = b''
self.psm = psm # a non-zero psm is mainly for testing purpose
# Handler for volume control
def on_volume_write(connection, value):
logger.info(f'--- VOLUME Write:{value[0]}')
self.emit('volume', connection, value[0])
# Handler for audio control commands
def on_audio_control_point_write(connection: Optional[Connection], value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == AshaService.OPCODE_START:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
logger.info(
f'### START: codec={value[1]}, '
f'audio_type={audio_type}, '
f'volume={value[3]}, '
f'otherstate={value[4]}'
)
self.emit(
'start',
connection,
{
'codec': value[1],
'audiotype': value[2],
'volume': value[3],
'otherstate': value[4],
},
)
elif opcode == AshaService.OPCODE_STOP:
logger.info('### STOP')
self.emit('stop', connection)
elif opcode == AshaService.OPCODE_STATUS:
logger.info(f'### STATUS: connected={value[1]}')
# OPCODE_STATUS does not need audio status point update
if opcode != AshaService.OPCODE_STATUS:
AsyncRunner.spawn(
device.notify_subscribers(
self.audio_status_characteristic, force=True
)
)
self.read_only_properties_characteristic = Characteristic(
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
bytes(
[
AshaService.PROTOCOL_VERSION, # Version
self.capability,
]
)
+ bytes(self.hisyncid)
+ bytes(AshaService.FEATURE_MAP)
+ bytes(AshaService.RENDER_DELAY)
+ bytes(AshaService.RESERVED_FOR_FUTURE_USE)
+ bytes(AshaService.SUPPORTED_CODEC_ID),
)
self.audio_control_point_characteristic = Characteristic(
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write),
)
self.audio_status_characteristic = Characteristic(
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READABLE,
bytes([0]),
)
self.volume_characteristic = Characteristic(
GATT_ASHA_VOLUME_CHARACTERISTIC,
Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write),
)
# Register an L2CAP CoC server
def on_coc(channel):
def on_data(data):
logging.debug(f'<<< data received:{data}')
self.emit('data', channel.connection, data)
self.audio_out_data += data
channel.sink = on_data
# let the server find a free PSM
self.psm = device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(psm=self.psm, max_credits=8),
handler=on_coc,
).psm
self.le_psm_out_characteristic = Characteristic(
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', self.psm),
)
characteristics = [
self.read_only_properties_characteristic,
self.audio_control_point_characteristic,
self.audio_status_characteristic,
self.volume_characteristic,
self.le_psm_out_characteristic,
]
super().__init__(characteristics)
def get_advertising_data(self):
# Advertisement only uses 4 least significant bytes of the HiSyncId.
return bytes(
AdvertisingData(
[
(
AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes(GATT_ASHA_SERVICE)
+ bytes(
[
AshaService.PROTOCOL_VERSION,
self.capability,
]
)
+ bytes(self.hisyncid[:4]),
),
]
)
)
File diff suppressed because it is too large Load Diff
+7 -8
View File
@@ -23,7 +23,7 @@ from ..gatt import (
TemplateService, TemplateService,
Characteristic, Characteristic,
CharacteristicValue, CharacteristicValue,
PackedCharacteristicAdapter, PackedCharacteristicAdapter
) )
@@ -36,11 +36,11 @@ class BatteryService(TemplateService):
self.battery_level_characteristic = PackedCharacteristicAdapter( self.battery_level_characteristic = PackedCharacteristicAdapter(
Characteristic( Characteristic(
GATT_BATTERY_LEVEL_CHARACTERISTIC, GATT_BATTERY_LEVEL_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY, Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE, Characteristic.READABLE,
CharacteristicValue(read=read_battery_level), CharacteristicValue(read=read_battery_level)
), ),
pack_format=BatteryService.BATTERY_LEVEL_FORMAT, format=BatteryService.BATTERY_LEVEL_FORMAT
) )
super().__init__([self.battery_level_characteristic]) super().__init__([self.battery_level_characteristic])
@@ -52,11 +52,10 @@ class BatteryServiceProxy(ProfileServiceProxy):
def __init__(self, service_proxy): def __init__(self, service_proxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(GATT_BATTERY_LEVEL_CHARACTERISTIC):
GATT_BATTERY_LEVEL_CHARACTERISTIC
):
self.battery_level = PackedCharacteristicAdapter( self.battery_level = PackedCharacteristicAdapter(
characteristics[0], pack_format=BatteryService.BATTERY_LEVEL_FORMAT characteristics[0],
format=BatteryService.BATTERY_LEVEL_FORMAT
) )
else: else:
self.battery_level = None self.battery_level = None
-52
View File
@@ -1,52 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from bumble import gatt
from bumble import gatt_client
from bumble.profiles import csip
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class CommonAudioServiceService(gatt.TemplateService):
UUID = gatt.GATT_COMMON_AUDIO_SERVICE
def __init__(
self,
coordinated_set_identification_service: csip.CoordinatedSetIdentificationService,
) -> None:
self.coordinated_set_identification_service = (
coordinated_set_identification_service
)
super().__init__(
characteristics=[],
included_services=[coordinated_set_identification_service],
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class CommonAudioServiceServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CommonAudioServiceService
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
-257
View File
@@ -1,257 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import struct
from typing import Optional, Tuple
from bumble import core
from bumble import crypto
from bumble import device
from bumble import gatt
from bumble import gatt_client
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
SET_IDENTITY_RESOLVING_KEY_LENGTH = 16
class SirkType(enum.IntEnum):
'''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''
ENCRYPTED = 0x00
PLAINTEXT = 0x01
class MemberLock(enum.IntEnum):
'''Coordinated Set Identification Service - 5.3 Set Member Lock.'''
UNLOCKED = 0x01
LOCKED = 0x02
# -----------------------------------------------------------------------------
# Crypto Toolbox
# -----------------------------------------------------------------------------
def s1(m: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.3 s1 SALT generation function.
'''
return crypto.aes_cmac(m[::-1], bytes(16))[::-1]
def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.4 k1 derivation function.
'''
t = crypto.aes_cmac(n[::-1], salt[::-1])
return crypto.aes_cmac(p[::-1], t)[::-1]
def sef(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.5 SIRK encryption function sef.
SIRK decryption function sdf shares the same algorithm. The only difference is that argument r is:
* Plaintext in encryption
* Cipher in decryption
'''
return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)
def sih(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih.
'''
return crypto.e(k, r + bytes(13))[:3]
def generate_rsi(sirk: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation.
'''
prand = crypto.generate_prand()
return sih(sirk, prand) + prand
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationService(gatt.TemplateService):
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
set_identity_resolving_key: bytes
set_identity_resolving_key_characteristic: gatt.Characteristic
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
set_member_lock_characteristic: Optional[gatt.Characteristic] = None
set_member_rank_characteristic: Optional[gatt.Characteristic] = None
def __init__(
self,
set_identity_resolving_key: bytes,
set_identity_resolving_key_type: SirkType,
coordinated_set_size: Optional[int] = None,
set_member_lock: Optional[MemberLock] = None,
set_member_rank: Optional[int] = None,
) -> None:
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
raise ValueError(
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
)
characteristics = []
self.set_identity_resolving_key = set_identity_resolving_key
self.set_identity_resolving_key_type = set_identity_resolving_key_type
self.set_identity_resolving_key_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(read=self.on_sirk_read),
)
characteristics.append(self.set_identity_resolving_key_characteristic)
if coordinated_set_size is not None:
self.coordinated_set_size_characteristic = gatt.Characteristic(
uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=struct.pack('B', coordinated_set_size),
)
characteristics.append(self.coordinated_set_size_characteristic)
if set_member_lock is not None:
self.set_member_lock_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
| gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITEABLE,
value=struct.pack('B', set_member_lock),
)
characteristics.append(self.set_member_lock_characteristic)
if set_member_rank is not None:
self.set_member_rank_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=struct.pack('B', set_member_rank),
)
characteristics.append(self.set_member_rank_characteristic)
super().__init__(characteristics)
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
sirk_bytes = self.set_identity_resolving_key
else:
assert connection
if connection.transport == core.BT_LE_TRANSPORT:
key = await connection.device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await connection.device.get_link_key(connection.peer_address)
if not key:
raise RuntimeError('LTK or LinkKey is not present')
sirk_bytes = sef(key, self.set_identity_resolving_key)
return bytes([self.set_identity_resolving_key_type]) + sirk_bytes
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER,
generate_rsi(self.set_identity_resolving_key),
),
]
)
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CoordinatedSetIdentificationService
set_identity_resolving_key: gatt_client.CharacteristicProxy
coordinated_set_size: Optional[gatt_client.CharacteristicProxy] = None
set_member_lock: Optional[gatt_client.CharacteristicProxy] = None
set_member_rank: Optional[gatt_client.CharacteristicProxy] = None
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.set_identity_resolving_key = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC
)[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC
):
self.coordinated_set_size = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC
):
self.set_member_lock = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
):
self.set_member_rank = characteristics[0]
async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
'''Reads SIRK and decrypts if encrypted.'''
response = await self.set_identity_resolving_key.read_value()
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
raise RuntimeError('Invalid SIRK value')
sirk_type = SirkType(response[0])
if sirk_type == SirkType.PLAINTEXT:
sirk = response[1:]
else:
connection = self.service_proxy.client.connection
device = connection.device
if connection.transport == core.BT_LE_TRANSPORT:
key = await device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await device.get_link_key(connection.peer_address)
if not key:
raise RuntimeError('LTK or LinkKey is not present')
sirk = sef(key, response[1:])
return (sirk_type, sirk)
+39 -53
View File
@@ -17,10 +17,10 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import struct import struct
from typing import Optional, Tuple from typing import Tuple
from bumble.gatt_client import ServiceProxy, ProfileServiceProxy, CharacteristicProxy from ..gatt_client import ProfileServiceProxy
from bumble.gatt import ( from ..gatt import (
GATT_DEVICE_INFORMATION_SERVICE, GATT_DEVICE_INFORMATION_SERVICE,
GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC, GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC,
GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC, GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC,
@@ -33,7 +33,7 @@ from bumble.gatt import (
TemplateService, TemplateService,
Characteristic, Characteristic,
DelegatedCharacteristicAdapter, DelegatedCharacteristicAdapter,
UTF8CharacteristicAdapter, UTF8CharacteristicAdapter
) )
@@ -52,50 +52,49 @@ class DeviceInformationService(TemplateService):
def __init__( def __init__(
self, self,
manufacturer_name: Optional[str] = None, manufacturer_name: str = None,
model_number: Optional[str] = None, model_number: str = None,
serial_number: Optional[str] = None, serial_number: str = None,
hardware_revision: Optional[str] = None, hardware_revision: str = None,
firmware_revision: Optional[str] = None, firmware_revision: str = None,
software_revision: Optional[str] = None, software_revision: str = None,
system_id: Optional[Tuple[int, int]] = None, # (OUI, Manufacturer ID) system_id: Tuple[int, int] = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: Optional[bytes] = None, ieee_regulatory_certification_data_list: bytes = None
# TODO: pnp_id # TODO: pnp_id
): ):
characteristics = [ characteristics = [
Characteristic( Characteristic(
uuid, Characteristic.Properties.READ, Characteristic.READABLE, field uuid,
Characteristic.READ,
Characteristic.READABLE,
field
) )
for (field, uuid) in ( for (field, uuid) in (
(manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC), (manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
(model_number, GATT_MODEL_NUMBER_STRING_CHARACTERISTIC), (model_number, GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
(serial_number, GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC), (serial_number, GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
(hardware_revision, GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC), (hardware_revision, GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC),
(firmware_revision, GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC), (firmware_revision, GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC),
(software_revision, GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC), (software_revision, GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC)
) )
if field is not None if field is not None
] ]
if system_id is not None: if system_id is not None:
characteristics.append( characteristics.append(Characteristic(
Characteristic( GATT_SYSTEM_ID_CHARACTERISTIC,
GATT_SYSTEM_ID_CHARACTERISTIC, Characteristic.READ,
Characteristic.Properties.READ, Characteristic.READABLE,
Characteristic.READABLE, self.pack_system_id(*system_id)
self.pack_system_id(*system_id), ))
)
)
if ieee_regulatory_certification_data_list is not None: if ieee_regulatory_certification_data_list is not None:
characteristics.append( characteristics.append(Characteristic(
Characteristic( GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC,
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC, Characteristic.READ,
Characteristic.Properties.READ, Characteristic.READABLE,
Characteristic.READABLE, ieee_regulatory_certification_data_list
ieee_regulatory_certification_data_list, ))
)
)
super().__init__(characteristics) super().__init__(characteristics)
@@ -104,25 +103,16 @@ class DeviceInformationService(TemplateService):
class DeviceInformationServiceProxy(ProfileServiceProxy): class DeviceInformationServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = DeviceInformationService SERVICE_CLASS = DeviceInformationService
manufacturer_name: Optional[UTF8CharacteristicAdapter] def __init__(self, service_proxy):
model_number: Optional[UTF8CharacteristicAdapter]
serial_number: Optional[UTF8CharacteristicAdapter]
hardware_revision: Optional[UTF8CharacteristicAdapter]
firmware_revision: Optional[UTF8CharacteristicAdapter]
software_revision: Optional[UTF8CharacteristicAdapter]
system_id: Optional[DelegatedCharacteristicAdapter]
ieee_regulatory_certification_data_list: Optional[CharacteristicProxy]
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
for field, uuid in ( for (field, uuid) in (
('manufacturer_name', GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC), ('manufacturer_name', GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
('model_number', GATT_MODEL_NUMBER_STRING_CHARACTERISTIC), ('model_number', GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC), ('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
('hardware_revision', GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC), ('hardware_revision', GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC),
('firmware_revision', GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC), ('firmware_revision', GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC),
('software_revision', GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC), ('software_revision', GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC)
): ):
if characteristics := service_proxy.get_characteristics_by_uuid(uuid): if characteristics := service_proxy.get_characteristics_by_uuid(uuid):
characteristic = UTF8CharacteristicAdapter(characteristics[0]) characteristic = UTF8CharacteristicAdapter(characteristics[0])
@@ -130,20 +120,16 @@ class DeviceInformationServiceProxy(ProfileServiceProxy):
characteristic = None characteristic = None
self.__setattr__(field, characteristic) self.__setattr__(field, characteristic)
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(GATT_SYSTEM_ID_CHARACTERISTIC):
GATT_SYSTEM_ID_CHARACTERISTIC
):
self.system_id = DelegatedCharacteristicAdapter( self.system_id = DelegatedCharacteristicAdapter(
characteristics[0], characteristics[0],
encode=lambda v: DeviceInformationService.pack_system_id(*v), encode=lambda v: DeviceInformationService.pack_system_id(*v),
decode=DeviceInformationService.unpack_system_id, decode=DeviceInformationService.unpack_system_id
) )
else: else:
self.system_id = None self.system_id = None
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC):
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC
):
self.ieee_regulatory_certification_data_list = characteristics[0] self.ieee_regulatory_certification_data_list = characteristics[0]
else: else:
self.ieee_regulatory_certification_data_list = None self.ieee_regulatory_certification_data_list = None
+42 -58
View File
@@ -30,25 +30,25 @@ from ..gatt import (
Characteristic, Characteristic,
CharacteristicValue, CharacteristicValue,
DelegatedCharacteristicAdapter, DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter, PackedCharacteristicAdapter
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HeartRateService(TemplateService): class HeartRateService(TemplateService):
UUID = GATT_HEART_RATE_SERVICE UUID = GATT_HEART_RATE_SERVICE
HEART_RATE_CONTROL_POINT_FORMAT = 'B' HEART_RATE_CONTROL_POINT_FORMAT = 'B'
CONTROL_POINT_NOT_SUPPORTED = 0x80 CONTROL_POINT_NOT_SUPPORTED = 0x80
RESET_ENERGY_EXPENDED = 0x01 RESET_ENERGY_EXPENDED = 0x01
class BodySensorLocation(IntEnum): class BodySensorLocation(IntEnum):
OTHER = 0 OTHER = 0,
CHEST = 1 CHEST = 1,
WRIST = 2 WRIST = 2,
FINGER = 3 FINGER = 3,
HAND = 4 HAND = 4,
EAR_LOBE = 5 EAR_LOBE = 5,
FOOT = 6 FOOT = 6
class HeartRateMeasurement: class HeartRateMeasurement:
def __init__( def __init__(
@@ -56,14 +56,12 @@ class HeartRateService(TemplateService):
heart_rate, heart_rate,
sensor_contact_detected=None, sensor_contact_detected=None,
energy_expended=None, energy_expended=None,
rr_intervals=None, rr_intervals=None
): ):
if heart_rate < 0 or heart_rate > 0xFFFF: if heart_rate < 0 or heart_rate > 0xFFFF:
raise ValueError('heart_rate out of range') raise ValueError('heart_rate out of range')
if energy_expended is not None and ( if energy_expended is not None and (energy_expended < 0 or energy_expended > 0xFFFF):
energy_expended < 0 or energy_expended > 0xFFFF
):
raise ValueError('energy_expended out of range') raise ValueError('energy_expended out of range')
if rr_intervals: if rr_intervals:
@@ -71,10 +69,10 @@ class HeartRateService(TemplateService):
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF: if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
raise ValueError('rr_intervals out of range') raise ValueError('rr_intervals out of range')
self.heart_rate = heart_rate self.heart_rate = heart_rate
self.sensor_contact_detected = sensor_contact_detected self.sensor_contact_detected = sensor_contact_detected
self.energy_expended = energy_expended self.energy_expended = energy_expended
self.rr_intervals = rr_intervals self.rr_intervals = rr_intervals
@classmethod @classmethod
def from_bytes(cls, data): def from_bytes(cls, data):
@@ -89,7 +87,7 @@ class HeartRateService(TemplateService):
offset += 1 offset += 1
if flags & (1 << 2): if flags & (1 << 2):
sensor_contact_detected = flags & (1 << 1) != 0 sensor_contact_detected = (flags & (1 << 1) != 0)
else: else:
sensor_contact_detected = None sensor_contact_detected = None
@@ -121,57 +119,51 @@ class HeartRateService(TemplateService):
flags |= ((1 if self.sensor_contact_detected else 0) << 1) | (1 << 2) flags |= ((1 if self.sensor_contact_detected else 0) << 1) | (1 << 2)
if self.energy_expended is not None: if self.energy_expended is not None:
flags |= 1 << 3 flags |= (1 << 3)
data += struct.pack('<H', self.energy_expended) data += struct.pack('<H', self.energy_expended)
if self.rr_intervals: if self.rr_intervals:
flags |= 1 << 4 flags |= (1 << 4)
data += b''.join( data += b''.join([
[ struct.pack('<H', int(rr_interval * 1024))
struct.pack('<H', int(rr_interval * 1024)) for rr_interval in self.rr_intervals
for rr_interval in self.rr_intervals ])
]
)
return bytes([flags]) + data return bytes([flags]) + data
def __str__(self): def __str__(self):
return ( return f'HeartRateMeasurement(heart_rate={self.heart_rate},'\
f'HeartRateMeasurement(heart_rate={self.heart_rate},' f' sensor_contact_detected={self.sensor_contact_detected},'\
f' sensor_contact_detected={self.sensor_contact_detected},' f' energy_expended={self.energy_expended},'\
f' energy_expended={self.energy_expended},'
f' rr_intervals={self.rr_intervals})' f' rr_intervals={self.rr_intervals})'
)
def __init__( def __init__(
self, self,
read_heart_rate_measurement, read_heart_rate_measurement,
body_sensor_location=None, body_sensor_location=None,
reset_energy_expended=None, reset_energy_expended=None
): ):
self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter( self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter(
Characteristic( Characteristic(
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC, GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Characteristic.Properties.NOTIFY, Characteristic.NOTIFY,
0, 0,
CharacteristicValue(read=read_heart_rate_measurement), CharacteristicValue(read=read_heart_rate_measurement)
), ),
# pylint: disable=unnecessary-lambda encode=lambda value: bytes(value)
encode=lambda value: bytes(value),
) )
characteristics = [self.heart_rate_measurement_characteristic] characteristics = [self.heart_rate_measurement_characteristic]
if body_sensor_location is not None: if body_sensor_location is not None:
self.body_sensor_location_characteristic = Characteristic( self.body_sensor_location_characteristic = Characteristic(
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC, GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
Characteristic.Properties.READ, Characteristic.READ,
Characteristic.READABLE, Characteristic.READABLE,
bytes([int(body_sensor_location)]), bytes([int(body_sensor_location)])
) )
characteristics.append(self.body_sensor_location_characteristic) characteristics.append(self.body_sensor_location_characteristic)
if reset_energy_expended: if reset_energy_expended:
def write_heart_rate_control_point_value(connection, value): def write_heart_rate_control_point_value(connection, value):
if value == self.RESET_ENERGY_EXPENDED: if value == self.RESET_ENERGY_EXPENDED:
if reset_energy_expended is not None: if reset_energy_expended is not None:
@@ -182,11 +174,11 @@ class HeartRateService(TemplateService):
self.heart_rate_control_point_characteristic = PackedCharacteristicAdapter( self.heart_rate_control_point_characteristic = PackedCharacteristicAdapter(
Characteristic( Characteristic(
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC, GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE, Characteristic.WRITE,
Characteristic.WRITEABLE, Characteristic.WRITEABLE,
CharacteristicValue(write=write_heart_rate_control_point_value), CharacteristicValue(write=write_heart_rate_control_point_value)
), ),
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT, format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT
) )
characteristics.append(self.heart_rate_control_point_characteristic) characteristics.append(self.heart_rate_control_point_characteristic)
@@ -200,38 +192,30 @@ class HeartRateServiceProxy(ProfileServiceProxy):
def __init__(self, service_proxy): def __init__(self, service_proxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC):
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
):
self.heart_rate_measurement = DelegatedCharacteristicAdapter( self.heart_rate_measurement = DelegatedCharacteristicAdapter(
characteristics[0], characteristics[0],
decode=HeartRateService.HeartRateMeasurement.from_bytes, decode=HeartRateService.HeartRateMeasurement.from_bytes
) )
else: else:
self.heart_rate_measurement = None self.heart_rate_measurement = None
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC):
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC
):
self.body_sensor_location = DelegatedCharacteristicAdapter( self.body_sensor_location = DelegatedCharacteristicAdapter(
characteristics[0], characteristics[0],
decode=lambda value: HeartRateService.BodySensorLocation(value[0]), decode=lambda value: HeartRateService.BodySensorLocation(value[0])
) )
else: else:
self.body_sensor_location = None self.body_sensor_location = None
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC):
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC
):
self.heart_rate_control_point = PackedCharacteristicAdapter( self.heart_rate_control_point = PackedCharacteristicAdapter(
characteristics[0], characteristics[0],
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT, format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT
) )
else: else:
self.heart_rate_control_point = None self.heart_rate_control_point = None
async def reset_energy_expended(self): async def reset_energy_expended(self):
if self.heart_rate_control_point is not None: if self.heart_rate_control_point is not None:
return await self.heart_rate_control_point.write_value( return await self.heart_rate_control_point.write_value(HeartRateService.RESET_ENERGY_EXPENDED)
HeartRateService.RESET_ENERGY_EXPENDED
)
-49
View File
@@ -1,49 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
from typing import List
from typing_extensions import Self
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class Metadata:
@dataclasses.dataclass
class Entry:
tag: int
data: bytes
entries: List[Entry]
@classmethod
def from_bytes(cls, data: bytes) -> Self:
entries = []
offset = 0
length = len(data)
while length >= 2:
entry_length = data[offset]
entry_tag = data[offset + 1]
entry_data = data[offset + 2 : offset + 2 + entry_length - 1]
entries.append(cls.Entry(entry_tag, entry_data))
length -= entry_length
offset += entry_length
return cls(entries)
-46
View File
@@ -1,46 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import enum
from typing_extensions import Self
from bumble.profiles import le_audio
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class PublicBroadcastAnnouncement:
class Features(enum.IntFlag):
ENCRYPTED = 1 << 0
STANDARD_QUALITY_CONFIGURATION = 1 << 1
HIGH_QUALITY_CONFIGURATION = 1 << 2
features: Features
metadata: le_audio.Metadata
@classmethod
def from_bytes(cls, data: bytes) -> Self:
features = cls.Features(data[0])
metadata_length = data[1]
metadata_ltv = data[1 : 1 + metadata_length]
return cls(
features=features, metadata=le_audio.Metadata.from_bytes(metadata_ltv)
)
View File
-228
View File
@@ -1,228 +0,0 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
from bumble import att
from bumble import device
from bumble import gatt
from bumble import gatt_client
from typing import Optional
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
MIN_VOLUME = 0
MAX_VOLUME = 255
class ErrorCode(enum.IntEnum):
'''
See Volume Control Service 1.6. Application error codes.
'''
INVALID_CHANGE_COUNTER = 0x80
OPCODE_NOT_SUPPORTED = 0x81
class VolumeFlags(enum.IntFlag):
'''
See Volume Control Service 3.3. Volume Flags.
'''
VOLUME_SETTING_PERSISTED = 0x01
# RFU
class VolumeControlPointOpcode(enum.IntEnum):
'''
See Volume Control Service Table 3.3: Volume Control Point procedure requirements.
'''
# fmt: off
RELATIVE_VOLUME_DOWN = 0x00
RELATIVE_VOLUME_UP = 0x01
UNMUTE_RELATIVE_VOLUME_DOWN = 0x02
UNMUTE_RELATIVE_VOLUME_UP = 0x03
SET_ABSOLUTE_VOLUME = 0x04
UNMUTE = 0x05
MUTE = 0x06
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class VolumeControlService(gatt.TemplateService):
UUID = gatt.GATT_VOLUME_CONTROL_SERVICE
volume_state: gatt.Characteristic
volume_control_point: gatt.Characteristic
volume_flags: gatt.Characteristic
volume_setting: int
muted: int
change_counter: int
def __init__(
self,
step_size: int = 16,
volume_setting: int = 0,
muted: int = 0,
change_counter: int = 0,
volume_flags: int = 0,
) -> None:
self.step_size = step_size
self.volume_setting = volume_setting
self.muted = muted
self.change_counter = change_counter
self.volume_state = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_STATE_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
),
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(read=self._on_read_volume_state),
)
self.volume_control_point = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(write=self._on_write_volume_control_point),
)
self.volume_flags = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=bytes([volume_flags]),
)
super().__init__(
[
self.volume_state,
self.volume_control_point,
self.volume_flags,
]
)
@property
def volume_state_bytes(self) -> bytes:
return bytes([self.volume_setting, self.muted, self.change_counter])
@volume_state_bytes.setter
def volume_state_bytes(self, new_value: bytes) -> None:
self.volume_setting, self.muted, self.change_counter = new_value
def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes:
return self.volume_state_bytes
def _on_write_volume_control_point(
self, connection: Optional[device.Connection], value: bytes
) -> None:
assert connection
opcode = VolumeControlPointOpcode(value[0])
change_counter = value[1]
if change_counter != self.change_counter:
raise att.ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
handler = getattr(self, '_on_' + opcode.name.lower())
if handler(*value[2:]):
self.change_counter = (self.change_counter + 1) % 256
connection.abort_on(
'disconnection',
connection.device.notify_subscribers(
attribute=self.volume_state,
value=self.volume_state_bytes,
),
)
self.emit(
'volume_state', self.volume_setting, self.muted, self.change_counter
)
def _on_relative_volume_down(self) -> bool:
old_volume = self.volume_setting
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
return self.volume_setting != old_volume
def _on_relative_volume_up(self) -> bool:
old_volume = self.volume_setting
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
return self.volume_setting != old_volume
def _on_unmute_relative_volume_down(self) -> bool:
old_volume, old_muted_state = self.volume_setting, self.muted
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
self.muted = 0
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
def _on_unmute_relative_volume_up(self) -> bool:
old_volume, old_muted_state = self.volume_setting, self.muted
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
self.muted = 0
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
def _on_set_absolute_volume(self, volume_setting: int) -> bool:
old_volume_setting = self.volume_setting
self.volume_setting = volume_setting
return old_volume_setting != self.volume_setting
def _on_unmute(self) -> bool:
old_muted_state = self.muted
self.muted = 0
return self.muted != old_muted_state
def _on_mute(self) -> bool:
old_muted_state = self.muted
self.muted = 1
return self.muted != old_muted_state
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class VolumeControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = VolumeControlService
volume_control_point: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.volume_state = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_STATE_CHARACTERISTIC
)[0],
'BBB',
)
self.volume_control_point = service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC
)[0]
self.volume_flags = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC
)[0],
'B',
)
View File
+422 -732
View File
File diff suppressed because it is too large Load Diff
+233 -419
View File
File diff suppressed because it is too large Load Diff
+574 -1044
View File
File diff suppressed because it is too large Load Diff
-170
View File
@@ -1,170 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from contextlib import contextmanager
from enum import IntEnum
import logging
import struct
import datetime
from typing import BinaryIO, Generator
import os
from bumble.hci import HCI_COMMAND_PACKET, HCI_EVENT_PACKET
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Snooper:
"""
Base class for snooper implementations.
A snooper is an object that will be provided with HCI packets as they are
exchanged between a host and a controller.
"""
class Direction(IntEnum):
HOST_TO_CONTROLLER = 0
CONTROLLER_TO_HOST = 1
class DataLinkType(IntEnum):
H1 = 1001
H4 = 1002
HCI_BSCP = 1003
H5 = 1004
def snoop(self, hci_packet: bytes, direction: Direction) -> None:
"""Snoop on an HCI packet."""
# -----------------------------------------------------------------------------
class BtSnooper(Snooper):
"""
Snooper that saves HCI packets using the BTSnoop format, based on RFC 1761.
"""
IDENTIFICATION_PATTERN = b'btsnoop\0'
TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1)
TIMESTAMP_DELTA = 0x00E03AB44A676000
ONE_MS = datetime.timedelta(microseconds=1)
def __init__(self, output: BinaryIO):
self.output = output
# Write the header
self.output.write(
self.IDENTIFICATION_PATTERN + struct.pack('>LL', 1, self.DataLinkType.H4)
)
def snoop(self, hci_packet: bytes, direction: Snooper.Direction) -> None:
flags = int(direction)
packet_type = hci_packet[0]
if packet_type in (HCI_EVENT_PACKET, HCI_COMMAND_PACKET):
flags |= 0x10
# Compute the current timestamp
timestamp = (
int((datetime.datetime.utcnow() - self.TIMESTAMP_ANCHOR) / self.ONE_MS)
+ self.TIMESTAMP_DELTA
)
# Emit the record
self.output.write(
struct.pack(
'>IIIIQ',
len(hci_packet), # Original Length
len(hci_packet), # Included Length
flags, # Packet Flags
0, # Cumulative Drops
timestamp, # Timestamp
)
+ hci_packet
)
# -----------------------------------------------------------------------------
_SNOOPER_INSTANCE_COUNT = 0
@contextmanager
def create_snooper(spec: str) -> Generator[Snooper, None, None]:
"""
Create a snooper given a specification string.
The general syntax for the specification string is:
<snooper-type>:<type-specific-arguments>
Supported snooper types are:
btsnoop
The syntax for the type-specific arguments for this type is:
<io-type>:<io-type-specific-arguments>
Supported I/O types are:
file
The type-specific arguments for this I/O type is a string that is converted
to a file path using the python `str.format()` string formatting. The log
records will be written to that file if it can be opened/created.
The keyword args that may be referenced by the string pattern are:
now: the value of `datetime.now()`
utcnow: the value of `datetime.utcnow()`
pid: the current process ID.
instance: the instance ID in the current process.
Examples:
btsnoop:file:my_btsnoop.log
btsnoop:file:/tmp/bumble_{now:%Y-%m-%d-%H:%M:%S}_{pid}.log
"""
if ':' not in spec:
raise ValueError('snooper type prefix missing')
snooper_type, snooper_args = spec.split(':', maxsplit=1)
if snooper_type == 'btsnoop':
if ':' not in snooper_args:
raise ValueError('I/O type for btsnoop snooper type missing')
io_type, io_name = snooper_args.split(':', maxsplit=1)
if io_type == 'file':
# Process the file name string pattern.
global _SNOOPER_INSTANCE_COUNT
file_path = io_name.format(
now=datetime.datetime.now(),
utcnow=datetime.datetime.utcnow(),
pid=os.getpid(),
instance=_SNOOPER_INSTANCE_COUNT,
)
# Open the file
logger.debug(f'Snoop file: {file_path}')
with open(file_path, 'wb') as snoop_file:
_SNOOPER_INSTANCE_COUNT += 1
yield BtSnooper(snoop_file)
_SNOOPER_INSTANCE_COUNT -= 1
return
raise ValueError(f'I/O type {io_type} not supported')
raise ValueError(f'snooper type {snooper_type} not found')
+43 -162
View File
@@ -1,4 +1,4 @@
# Copyright 2021-2023 Google LLC # Copyright 2021-2022 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -15,13 +15,11 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from contextlib import asynccontextmanager
import logging import logging
import os
from typing import Optional
from .common import Transport, AsyncPipeSink, SnoopingTransport from .common import Transport, AsyncPipeSink
from ..snoop import create_snooper from ..link import RemoteLink
from ..controller import Controller
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -30,185 +28,68 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def _wrap_transport(transport: Transport) -> Transport: async def open_transport(name):
""" '''
Automatically wrap a Transport instance when a wrapping class can be inferred
from the environment.
If no wrapping class is applicable, the transport argument is returned as-is.
"""
# If BUMBLE_SNOOPER is set, try to automatically create a snooper.
if snooper_spec := os.getenv('BUMBLE_SNOOPER'):
try:
return SnoopingTransport.create_with(
transport, create_snooper(snooper_spec)
)
except Exception as exc:
logger.warning(f'Exception while creating snooper: {exc}')
return transport
# -----------------------------------------------------------------------------
async def open_transport(name: str) -> Transport:
"""
Open a transport by name. Open a transport by name.
The name must be <type>:<metadata><parameters> The name must be <type>:<parameters>
Where <parameters> depend on the type (and may be empty for some types), and Where <parameters> depend on the type (and may be empty for some types).
<metadata> is either omitted, or a ,-separated list of <key>=<value> pairs, The supported types are: serial,udp,tcp,pty,usb
enclosed in []. '''
If there are not metadata or parameter, the : after the <type> may be omitted. scheme, *spec = name.split(':', 1)
Examples:
* usb:0
* usb:[driver=rtk]0
* android-netsim
The supported types are:
* serial
* udp
* tcp-client
* tcp-server
* ws-client
* ws-server
* pty
* file
* vhci
* hci-socket
* usb
* pyusb
* android-emulator
* android-netsim
"""
scheme, *tail = name.split(':', 1)
spec = tail[0] if tail else None
metadata = None
if spec:
# Metadata may precede the spec
if spec.startswith('['):
metadata_str, *tail = spec[1:].split(']')
spec = tail[0] if tail else None
metadata = dict([entry.split('=') for entry in metadata_str.split(',')])
transport = await _open_transport(scheme, spec)
if metadata:
transport.source.metadata = { # type: ignore[attr-defined]
**metadata,
**getattr(transport.source, 'metadata', {}),
}
# pylint: disable=line-too-long
logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined]
return _wrap_transport(transport)
# -----------------------------------------------------------------------------
async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements
if scheme == 'serial' and spec: if scheme == 'serial' and spec:
from .serial import open_serial_transport from .serial import open_serial_transport
return await open_serial_transport(spec[0])
return await open_serial_transport(spec) elif scheme == 'udp' and spec:
if scheme == 'udp' and spec:
from .udp import open_udp_transport from .udp import open_udp_transport
return await open_udp_transport(spec[0])
return await open_udp_transport(spec) elif scheme == 'tcp-client' and spec:
if scheme == 'tcp-client' and spec:
from .tcp_client import open_tcp_client_transport from .tcp_client import open_tcp_client_transport
return await open_tcp_client_transport(spec[0])
return await open_tcp_client_transport(spec) elif scheme == 'tcp-server' and spec:
if scheme == 'tcp-server' and spec:
from .tcp_server import open_tcp_server_transport from .tcp_server import open_tcp_server_transport
return await open_tcp_server_transport(spec[0])
return await open_tcp_server_transport(spec) elif scheme == 'ws-client' and spec:
if scheme == 'ws-client' and spec:
from .ws_client import open_ws_client_transport from .ws_client import open_ws_client_transport
return await open_ws_client_transport(spec[0])
return await open_ws_client_transport(spec) elif scheme == 'ws-server' and spec:
if scheme == 'ws-server' and spec:
from .ws_server import open_ws_server_transport from .ws_server import open_ws_server_transport
return await open_ws_server_transport(spec[0])
return await open_ws_server_transport(spec) elif scheme == 'pty':
if scheme == 'pty':
from .pty import open_pty_transport from .pty import open_pty_transport
return await open_pty_transport(spec[0] if spec else None)
return await open_pty_transport(spec) elif scheme == 'file':
if scheme == 'file':
from .file import open_file_transport from .file import open_file_transport
return await open_file_transport(spec[0] if spec else None)
assert spec is not None elif scheme == 'vhci':
return await open_file_transport(spec)
if scheme == 'vhci':
from .vhci import open_vhci_transport from .vhci import open_vhci_transport
return await open_vhci_transport(spec[0] if spec else None)
return await open_vhci_transport(spec) elif scheme == 'hci-socket':
if scheme == 'hci-socket':
from .hci_socket import open_hci_socket_transport from .hci_socket import open_hci_socket_transport
return await open_hci_socket_transport(spec[0] if spec else None)
return await open_hci_socket_transport(spec) elif scheme == 'usb':
if scheme == 'usb':
from .usb import open_usb_transport from .usb import open_usb_transport
return await open_usb_transport(spec[0] if spec else None)
assert spec elif scheme == 'pyusb':
return await open_usb_transport(spec)
if scheme == 'pyusb':
from .pyusb import open_pyusb_transport from .pyusb import open_pyusb_transport
return await open_pyusb_transport(spec[0] if spec else None)
assert spec elif scheme == 'android-emulator':
return await open_pyusb_transport(spec)
if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport from .android_emulator import open_android_emulator_transport
return await open_android_emulator_transport(spec[0] if spec else None)
return await open_android_emulator_transport(spec) else:
raise ValueError('unknown transport scheme')
if scheme == 'android-netsim':
from .android_netsim import open_android_netsim_transport
return await open_android_netsim_transport(spec)
raise ValueError('unknown transport scheme')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_transport_or_link(name: str) -> Transport: async def open_transport_or_link(name):
"""
Open a transport or a link relay.
Args:
name:
Name of the transport or link relay to open.
When the name starts with "link-relay:", open a link relay (see RemoteLink
for details on what the arguments are).
For other namespaces, see `open_transport`.
"""
if name.startswith('link-relay:'): if name.startswith('link-relay:'):
logger.warning('Link Relay has been deprecated.')
from ..controller import Controller
from ..link import RemoteLink # lazy import
link = RemoteLink(name[11:]) link = RemoteLink(name[11:])
await link.wait_until_connected() await link.wait_until_connected()
controller = Controller('remote', link=link) # type:ignore[arg-type] controller = Controller('remote', link = link)
class LinkTransport(Transport): class LinkTransport(Transport):
async def close(self): async def close(self):
link.close() link.close()
return _wrap_transport(LinkTransport(controller, AsyncPipeSink(controller))) return LinkTransport(controller, AsyncPipeSink(controller))
else:
return await open_transport(name) return await open_transport(name)
+20 -23
View File
@@ -1,4 +1,4 @@
# Copyright 2021-2023 Google LLC # Copyright 2021-2022 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@@ -16,16 +16,12 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
import grpc.aio import grpc
from typing import Optional, Union from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink
from .emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
from .common import PumpedTransport, PumpedPacketSource, PumpedPacketSink, Transport from .emulated_bluetooth_packets_pb2 import HCIPacket
from .emulated_bluetooth_vhci_pb2_grpc import VhciForwardingServiceStub
# pylint: disable=no-name-in-module
from .grpc_protobuf.emulated_bluetooth_pb2_grpc import EmulatedBluetoothServiceStub
from .grpc_protobuf.emulated_bluetooth_packets_pb2 import HCIPacket
from .grpc_protobuf.emulated_bluetooth_vhci_pb2_grpc import VhciForwardingServiceStub
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -35,7 +31,7 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def open_android_emulator_transport(spec: Optional[str]) -> Transport: async def open_android_emulator_transport(spec):
''' '''
Open a transport connection to an Android emulator via its gRPC interface. Open a transport connection to an Android emulator via its gRPC interface.
The parameter string has this syntax: The parameter string has this syntax:
@@ -63,13 +59,18 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
return bytes([packet.type]) + packet.packet return bytes([packet.type]) + packet.packet
async def write(self, packet): async def write(self, packet):
await self.hci_device.write(HCIPacket(type=packet[0], packet=packet[1:])) await self.hci_device.write(
HCIPacket(
type = packet[0],
packet = packet[1:]
)
)
# Parse the parameters # Parse the parameters
mode = 'host' mode = 'host'
server_host = 'localhost' server_host = 'localhost'
server_port = '8554' server_port = 8554
if spec: if spec is not None:
params = spec.split(',') params = spec.split(',')
for param in params: for param in params:
if param.startswith('mode='): if param.startswith('mode='):
@@ -84,7 +85,6 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
logger.debug(f'connecting to gRPC server at {server_address}') logger.debug(f'connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address) channel = grpc.aio.insecure_channel(server_address)
service: Union[EmulatedBluetoothServiceStub, VhciForwardingServiceStub]
if mode == 'host': if mode == 'host':
# Connect as a host # Connect as a host
service = EmulatedBluetoothServiceStub(channel) service = EmulatedBluetoothServiceStub(channel)
@@ -97,13 +97,10 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
raise ValueError('invalid mode') raise ValueError('invalid mode')
# Create the transport object # Create the transport object
class EmulatorTransport(PumpedTransport): transport = PumpedTransport(
async def close(self): PumpedPacketSource(hci_device.read),
await super().close() PumpedPacketSink(hci_device.write),
await channel.close() channel.close
transport = EmulatorTransport(
PumpedPacketSource(hci_device.read), PumpedPacketSink(hci_device.write)
) )
transport.start() transport.start()
-446
View File
@@ -1,446 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import atexit
import logging
import os
import pathlib
import sys
from typing import Dict, Optional
import grpc.aio
from .common import (
ParserSource,
PumpedTransport,
PumpedPacketSource,
PumpedPacketSink,
Transport,
)
# pylint: disable=no-name-in-module
from .grpc_protobuf.packet_streamer_pb2_grpc import (
PacketStreamerStub,
PacketStreamerServicer,
add_PacketStreamerServicer_to_server,
)
from .grpc_protobuf.packet_streamer_pb2 import PacketRequest, PacketResponse
from .grpc_protobuf.hci_packet_pb2 import HCIPacket
from .grpc_protobuf.startup_pb2 import Chip, ChipInfo
from .grpc_protobuf.common_pb2 import ChipKind
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_NAME = 'bumble0'
DEFAULT_MANUFACTURER = 'Bumble'
# -----------------------------------------------------------------------------
def get_ini_dir() -> Optional[pathlib.Path]:
if sys.platform == 'darwin':
if tmpdir := os.getenv('TMPDIR', None):
return pathlib.Path(tmpdir)
if home := os.getenv('HOME', None):
return pathlib.Path(home) / 'Library/Caches/TemporaryItems'
elif sys.platform == 'linux':
if xdg_runtime_dir := os.environ.get('XDG_RUNTIME_DIR', None):
return pathlib.Path(xdg_runtime_dir)
elif sys.platform == 'win32':
if local_app_data_dir := os.environ.get('LOCALAPPDATA', None):
return pathlib.Path(local_app_data_dir) / 'Temp'
return None
# -----------------------------------------------------------------------------
def ini_file_name(instance_number: int) -> str:
suffix = f'_{instance_number}' if instance_number > 0 else ''
return f'netsim{suffix}.ini'
# -----------------------------------------------------------------------------
def find_grpc_port(instance_number: int) -> int:
if not (ini_dir := get_ini_dir()):
logger.debug('no known directory for .ini file')
return 0
ini_file = ini_dir / ini_file_name(instance_number)
logger.debug(f'Looking for .ini file at {ini_file}')
if ini_file.is_file():
with open(ini_file, 'r') as ini_file_data:
for line in ini_file_data.readlines():
if '=' in line:
key, value = line.split('=')
if key == 'grpc.port':
logger.debug(f'gRPC port = {value}')
return int(value)
logger.debug('no grpc.port property found in .ini file')
# Not found
return 0
# -----------------------------------------------------------------------------
def publish_grpc_port(grpc_port: int, instance_number: int) -> bool:
if not (ini_dir := get_ini_dir()):
logger.debug('no known directory for .ini file')
return False
if not ini_dir.is_dir():
logger.debug('ini directory does not exist')
return False
ini_file = ini_dir / ini_file_name(instance_number)
try:
ini_file.write_text(f'grpc.port={grpc_port}\n')
logger.debug(f"published gRPC port at {ini_file}")
def cleanup():
logger.debug("removing .ini file")
ini_file.unlink()
atexit.register(cleanup)
return True
except OSError:
logger.debug('failed to write to .ini file')
return False
# -----------------------------------------------------------------------------
async def open_android_netsim_controller_transport(
server_host: Optional[str], server_port: int, options: Dict[str, str]
) -> Transport:
if not server_port:
raise ValueError('invalid port')
if server_host == '_' or not server_host:
server_host = 'localhost'
instance_number = int(options.get('instance', "0"))
if not publish_grpc_port(server_port, instance_number):
logger.warning("unable to publish gRPC port")
class HciDevice:
def __init__(self, context, on_data_received):
self.context = context
self.on_data_received = on_data_received
self.name = None
self.loop = asyncio.get_running_loop()
self.done = self.loop.create_future()
self.task = self.loop.create_task(self.pump())
async def pump(self):
try:
await self.pump_loop()
except asyncio.CancelledError:
logger.debug('Pump task canceled')
self.done.set_result(None)
async def pump_loop(self):
while True:
request = await self.context.read()
if request == grpc.aio.EOF:
logger.debug('End of request stream')
self.done.set_result(None)
return
# If we're not initialized yet, wait for a init packet.
if self.name is None:
if request.WhichOneof('request_type') == 'initial_info':
logger.debug(f'Received initial info: {request}')
# We only accept BLUETOOTH
if request.initial_info.chip.kind != ChipKind.BLUETOOTH:
logger.warning('Unsupported chip type')
error = PacketResponse(error='Unsupported chip type')
await self.context.write(error)
return
self.name = request.initial_info.name
continue
# Expect a data packet
request_type = request.WhichOneof('request_type')
if request_type != 'hci_packet':
logger.warning(f'Unexpected request type: {request_type}')
error = PacketResponse(error='Unexpected request type')
await self.context.write(error)
continue
# Process the packet
data = (
bytes([request.hci_packet.packet_type]) + request.hci_packet.packet
)
logger.debug(f'<<< PACKET: {data.hex()}')
self.on_data_received(data)
async def send_packet(self, data):
return await self.context.write(
PacketResponse(
hci_packet=HCIPacket(packet_type=data[0], packet=data[1:])
)
)
def terminate(self):
self.task.cancel()
async def wait_for_termination(self):
await self.done
class Server(PacketStreamerServicer, ParserSource):
def __init__(self):
PacketStreamerServicer.__init__(self)
ParserSource.__init__(self)
self.device = None
# Create a gRPC server with `so_reuseport=0` so that if there's already
# a server listening on that port, we get an exception.
self.grpc_server = grpc.aio.server(options=(('grpc.so_reuseport', 0),))
add_PacketStreamerServicer_to_server(self, self.grpc_server)
self.grpc_server.add_insecure_port(f'{server_host}:{server_port}')
logger.debug(f'gRPC server listening on {server_host}:{server_port}')
async def start(self):
logger.debug('Starting gRPC server')
await self.grpc_server.start()
async def serve(self):
# Keep serving until terminated.
try:
await self.grpc_server.wait_for_termination()
logger.debug('gRPC server terminated')
except asyncio.CancelledError:
logger.debug('gRPC server cancelled')
await self.grpc_server.stop(None)
async def send_packet(self, packet):
if not self.device:
logger.debug('no device, dropping packet')
return
return await self.device.send_packet(packet)
async def StreamPackets(self, _request_iterator, context):
logger.debug('StreamPackets request')
# Check that we don't already have a device
if self.device:
logger.debug('busy, already serving a device')
return PacketResponse(error='Busy')
# Instantiate a new device
self.device = HciDevice(context, self.parser.feed_data)
# Wait for the device to terminate
logger.debug('Waiting for device to terminate')
try:
await self.device.wait_for_termination()
except asyncio.CancelledError:
logger.debug('Request canceled')
self.device.terminate()
logger.debug('Device terminated')
self.device = None
server = Server()
await server.start()
asyncio.get_running_loop().create_task(server.serve())
sink = PumpedPacketSink(server.send_packet)
sink.start()
return Transport(server, sink)
# -----------------------------------------------------------------------------
async def open_android_netsim_host_transport_with_address(
server_host: Optional[str],
server_port: int,
options: Optional[Dict[str, str]] = None,
):
if server_host == '_' or not server_host:
server_host = 'localhost'
if not server_port:
# Look for the gRPC config in a .ini file
instance_number = 0 if options is None else int(options.get('instance', '0'))
server_port = find_grpc_port(instance_number)
if not server_port:
raise RuntimeError('gRPC server port not found')
# Connect to the gRPC server
server_address = f'{server_host}:{server_port}'
logger.debug(f'Connecting to gRPC server at {server_address}')
channel = grpc.aio.insecure_channel(server_address)
return await open_android_netsim_host_transport_with_channel(
channel,
options,
)
# -----------------------------------------------------------------------------
async def open_android_netsim_host_transport_with_channel(
channel, options: Optional[Dict[str, str]] = None
):
# Wrapper for I/O operations
class HciDevice:
def __init__(self, name, manufacturer, hci_device):
self.name = name
self.manufacturer = manufacturer
self.hci_device = hci_device
async def start(self): # Send the startup info
chip_info = ChipInfo(
name=self.name,
chip=Chip(kind=ChipKind.BLUETOOTH, manufacturer=self.manufacturer),
)
logger.debug(f'Sending chip info to netsim: {chip_info}')
await self.hci_device.write(PacketRequest(initial_info=chip_info))
async def read(self):
response = await self.hci_device.read()
response_type = response.WhichOneof('response_type')
if response_type == 'error':
logger.warning(f'received error: {response.error}')
raise RuntimeError(response.error)
if response_type == 'hci_packet':
return (
bytes([response.hci_packet.packet_type])
+ response.hci_packet.packet
)
raise ValueError('unsupported response type')
async def write(self, packet):
await self.hci_device.write(
PacketRequest(
hci_packet=HCIPacket(packet_type=packet[0], packet=packet[1:])
)
)
name = DEFAULT_NAME if options is None else options.get('name', DEFAULT_NAME)
manufacturer = DEFAULT_MANUFACTURER
# Connect as a host
service = PacketStreamerStub(channel)
hci_device = HciDevice(
name=name,
manufacturer=manufacturer,
hci_device=service.StreamPackets(),
)
await hci_device.start()
# Create the transport object
class GrpcTransport(PumpedTransport):
async def close(self):
await super().close()
await channel.close()
transport = GrpcTransport(
PumpedPacketSource(hci_device.read),
PumpedPacketSink(hci_device.write),
)
transport.start()
return transport
# -----------------------------------------------------------------------------
async def open_android_netsim_transport(spec: Optional[str]) -> Transport:
'''
Open a transport connection as a client or server, implementing Android's `netsim`
simulator protocol over gRPC.
The parameter string has this syntax:
[<host>:<port>][<options>]
Where <options> is a ','-separated list of <name>=<value> pairs.
General options:
mode=host|controller (default: host)
Specifies whether the transport is used
to connect *to* a netsim server (netsim is the controller), or accept
connections *as* a netsim-compatible server.
instance=<n>
Specifies an instance number, with <n> > 0. This is used to determine which
.init file to use. In `host` mode, it is ignored when the <host>:<port>
specifier is present, since in that case no .ini file is used.
In `host` mode:
The <host>:<port> part is optional. When not specified, the transport
looks for a netsim .ini file, from which it will read the `grpc.backend.port`
property.
Options for this mode are:
name=<name>
The "chip" name, used to identify the "chip" instance. This
may be useful when several clients are connected, since each needs to use a
different name.
In `controller` mode:
The <host>:<port> part is required. <host> may be the address of a local network
interface, or '_' to accept connections on all local network interfaces.
Examples:
(empty string) --> connect to netsim on the port specified in the .ini file
localhost:8555 --> connect to netsim on localhost:8555
name=bumble1 --> connect to netsim, using `bumble1` as the "chip" name.
localhost:8555,name=bumble1 --> connect to netsim on localhost:8555, using
`bumble1` as the "chip" name.
_:8877,mode=controller --> accept connections as a controller on any interface
on port 8877.
'''
# Parse the parameters
params = spec.split(',') if spec else []
if params and ':' in params[0]:
# Explicit <host>:<port>
host, port_str = params[0].split(':')
port = int(port_str)
params_offset = 1
else:
host = None
port = 0
params_offset = 0
options: Dict[str, str] = {}
for param in params[params_offset:]:
if '=' not in param:
raise ValueError('invalid parameter, expected <name>=<value>')
option_name, option_value = param.split('=')
options[option_name] = option_value
mode = options.get('mode', 'host')
if mode == 'host':
return await open_android_netsim_host_transport_with_address(
host, port, options
)
if mode == 'controller':
if host is None:
raise ValueError('<host>:<port> missing')
return await open_android_netsim_controller_transport(host, port, options)
raise ValueError('invalid mode option')
+85 -230
View File
@@ -15,17 +15,12 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import contextlib
import struct import struct
import asyncio import asyncio
import logging import logging
import io from colors import color
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
from bumble import hci from .. import hci
from bumble.colors import color
from bumble.snoop import Snooper
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -37,109 +32,76 @@ logger = logging.getLogger(__name__)
# Information needed to parse HCI packets with a generic parser: # Information needed to parse HCI packets with a generic parser:
# For each packet type, the info represents: # For each packet type, the info represents:
# (length-size, length-offset, unpack-type) # (length-size, length-offset, unpack-type)
HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = { HCI_PACKET_INFO = {
hci.HCI_COMMAND_PACKET: (1, 2, 'B'), hci.HCI_COMMAND_PACKET: (1, 2, 'B'),
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'), hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'), hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
hci.HCI_EVENT_PACKET: (1, 1, 'B'), hci.HCI_EVENT_PACKET: (1, 1, 'B')
hci.HCI_ISO_DATA_PACKET: (2, 2, 'H'),
} }
# -----------------------------------------------------------------------------
# Errors
# -----------------------------------------------------------------------------
class TransportLostError(Exception):
"""
The Transport has been lost/disconnected.
"""
# -----------------------------------------------------------------------------
# Typing Protocols
# -----------------------------------------------------------------------------
class TransportSink(Protocol):
def on_packet(self, packet: bytes) -> None: ...
class TransportSource(Protocol):
terminated: asyncio.Future[None]
def set_packet_sink(self, sink: TransportSink) -> None: ...
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PacketPump: class PacketPump:
""" '''
Pump HCI packets from a reader to a sink. Pump HCI packets from a reader to a sink
""" '''
def __init__(self, reader: AsyncPacketReader, sink: TransportSink) -> None: def __init__(self, reader, sink):
self.reader = reader self.reader = reader
self.sink = sink self.sink = sink
async def run(self) -> None: async def run(self):
while True: while True:
try: try:
# Get a packet from the source
packet = hci.HCI_Packet.from_bytes(await self.reader.next_packet())
# Deliver the packet to the sink # Deliver the packet to the sink
self.sink.on_packet(await self.reader.next_packet()) self.sink.on_packet(packet)
except Exception as error: except Exception as error:
logger.warning(f'!!! {error}') logger.warning(f'!!! {error}')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PacketParser: class PacketParser:
""" '''
In-line parser that accepts data and emits 'on_packet' when a full packet has been In-line parser that accepts data and emits 'on_packet' when a full packet has been parsed
parsed. '''
""" NEED_TYPE = 0
# pylint: disable=attribute-defined-outside-init
NEED_TYPE = 0
NEED_LENGTH = 1 NEED_LENGTH = 1
NEED_BODY = 2 NEED_BODY = 2
sink: Optional[TransportSink] def __init__(self, sink = None):
extended_packet_info: Dict[int, Tuple[int, int, str]]
packet_info: Optional[Tuple[int, int, str]] = None
def __init__(self, sink: Optional[TransportSink] = None) -> None:
self.sink = sink self.sink = sink
self.extended_packet_info = {} self.extended_packet_info = {}
self.reset() self.reset()
def reset(self) -> None: def reset(self):
self.state = PacketParser.NEED_TYPE self.state = PacketParser.NEED_TYPE
self.bytes_needed = 1 self.bytes_needed = 1
self.packet = bytearray() self.packet = bytearray()
self.packet_info = None self.packet_info = None
def feed_data(self, data: bytes) -> None: def feed_data(self, data):
data_offset = 0 data_offset = 0
data_left = len(data) data_left = len(data)
while data_left and self.bytes_needed: while data_left and self.bytes_needed:
consumed = min(self.bytes_needed, data_left) consumed = min(self.bytes_needed, data_left)
self.packet.extend(data[data_offset : data_offset + consumed]) self.packet.extend(data[data_offset:data_offset + consumed])
data_offset += consumed data_offset += consumed
data_left -= consumed data_left -= consumed
self.bytes_needed -= consumed self.bytes_needed -= consumed
if self.bytes_needed == 0: if self.bytes_needed == 0:
if self.state == PacketParser.NEED_TYPE: if self.state == PacketParser.NEED_TYPE:
packet_type = self.packet[0] packet_type = self.packet[0]
self.packet_info = HCI_PACKET_INFO.get( self.packet_info = HCI_PACKET_INFO.get(packet_type) or self.extended_packet_info.get(packet_type)
packet_type
) or self.extended_packet_info.get(packet_type)
if self.packet_info is None: if self.packet_info is None:
raise ValueError(f'invalid packet type {packet_type}') raise ValueError(f'invalid packet type {packet_type}')
self.state = PacketParser.NEED_LENGTH self.state = PacketParser.NEED_LENGTH
self.bytes_needed = self.packet_info[0] + self.packet_info[1] self.bytes_needed = self.packet_info[0] + self.packet_info[1]
elif self.state == PacketParser.NEED_LENGTH: elif self.state == PacketParser.NEED_LENGTH:
assert self.packet_info is not None body_length = struct.unpack_from(self.packet_info[2], self.packet, 1 + self.packet_info[1])[0]
body_length = struct.unpack_from(
self.packet_info[2], self.packet, 1 + self.packet_info[1]
)[0]
self.bytes_needed = body_length self.bytes_needed = body_length
self.state = PacketParser.NEED_BODY self.state = PacketParser.NEED_BODY
@@ -149,36 +111,32 @@ class PacketParser:
try: try:
self.sink.on_packet(bytes(self.packet)) self.sink.on_packet(bytes(self.packet))
except Exception as error: except Exception as error:
logger.exception( logger.warning(color(f'!!! Exception in on_packet: {error}', 'red'))
color(f'!!! Exception in on_packet: {error}', 'red')
)
self.reset() self.reset()
def set_packet_sink(self, sink: TransportSink) -> None: def set_packet_sink(self, sink):
self.sink = sink self.sink = sink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PacketReader: class PacketReader:
""" '''
Reader that reads HCI packets from a sync source. Reader that reads HCI packets from a sync source
""" '''
def __init__(self, source: io.BufferedReader) -> None: def __init__(self, source):
self.source = source self.source = source
self.at_end = False
def next_packet(self) -> Optional[bytes]: def next_packet(self):
# Get the packet type # Get the packet type
packet_type = self.source.read(1) packet_type = self.source.read(1)
if len(packet_type) != 1: if len(packet_type) != 1:
self.at_end = True
return None return None
# Get the packet info based on its type # Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0]) packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None: if packet_info is None:
raise ValueError(f'invalid packet type {packet_type[0]} found') raise ValueError(f'invalid packet type {packet_type} found')
# Read the header (that includes the length) # Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1] header_size = packet_info[0] + packet_info[1]
@@ -197,21 +155,21 @@ class PacketReader:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AsyncPacketReader: class AsyncPacketReader:
""" '''
Reader that reads HCI packets from an async source. Reader that reads HCI packets from an async source
""" '''
def __init__(self, source: asyncio.StreamReader) -> None: def __init__(self, source):
self.source = source self.source = source
async def next_packet(self) -> bytes: async def next_packet(self):
# Get the packet type # Get the packet type
packet_type = await self.source.readexactly(1) packet_type = await self.source.readexactly(1)
# Get the packet info based on its type # Get the packet info based on its type
packet_info = HCI_PACKET_INFO.get(packet_type[0]) packet_info = HCI_PACKET_INFO.get(packet_type[0])
if packet_info is None: if packet_info is None:
raise ValueError(f'invalid packet type {packet_type[0]} found') raise ValueError(f'invalid packet type {packet_type} found')
# Read the header (that includes the length) # Read the header (that includes the length)
header_size = packet_info[0] + packet_info[1] header_size = packet_info[0] + packet_info[1]
@@ -226,15 +184,14 @@ class AsyncPacketReader:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AsyncPipeSink: class AsyncPipeSink:
""" '''
Sink that forwards packets asynchronously to another sink. Sink that forwards packets asynchronously to another sink
""" '''
def __init__(self, sink):
def __init__(self, sink: TransportSink) -> None:
self.sink = sink self.sink = sink
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
def on_packet(self, packet: bytes) -> None: def on_packet(self, packet):
self.loop.call_soon(self.sink.on_packet, packet) self.loop.call_soon(self.sink.on_packet, packet)
@@ -244,70 +201,43 @@ class ParserSource:
Base class designed to be subclassed by transport-specific source classes Base class designed to be subclassed by transport-specific source classes
""" """
terminated: asyncio.Future[None] def __init__(self):
parser: PacketParser self.parser = PacketParser()
def __init__(self) -> None:
self.parser = PacketParser()
self.terminated = asyncio.get_running_loop().create_future() self.terminated = asyncio.get_running_loop().create_future()
def set_packet_sink(self, sink: TransportSink) -> None: def set_packet_sink(self, sink):
self.parser.set_packet_sink(sink) self.parser.set_packet_sink(sink)
def on_transport_lost(self) -> None: async def wait_for_termination(self):
self.terminated.set_result(None)
if self.parser.sink:
if hasattr(self.parser.sink, 'on_transport_lost'):
self.parser.sink.on_transport_lost()
async def wait_for_termination(self) -> None:
"""
Convenience method for backward compatibility. Prefer using the `terminated`
attribute instead.
"""
return await self.terminated return await self.terminated
def close(self) -> None: def close(self):
pass pass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class StreamPacketSource(asyncio.Protocol, ParserSource): class StreamPacketSource(asyncio.Protocol, ParserSource):
def data_received(self, data: bytes) -> None: def data_received(self, data):
self.parser.feed_data(data) self.parser.feed_data(data)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class StreamPacketSink: class StreamPacketSink:
def __init__(self, transport: asyncio.WriteTransport) -> None: def __init__(self, transport):
self.transport = transport self.transport = transport
def on_packet(self, packet: bytes) -> None: def on_packet(self, packet):
self.transport.write(packet) self.transport.write(packet)
def close(self) -> None: def close(self):
self.transport.close() self.transport.close()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Transport: class Transport:
""" def __init__(self, source, sink):
Base class for all transports.
A Transport represents a source and a sink together.
An instance must be closed by calling close() when no longer used. Instances
implement the ContextManager protocol so that they may be used in a `async with`
statement.
An instance is iterable. The iterator yields, in order, its source and sink, so
that it may be used with a convenient call syntax like:
async with create_transport() as (source, sink):
...
"""
def __init__(self, source: TransportSource, sink: TransportSink) -> None:
self.source = source self.source = source
self.sink = sink self.sink = sink
async def __aenter__(self): async def __aenter__(self):
return self return self
@@ -318,40 +248,35 @@ class Transport:
def __iter__(self): def __iter__(self):
return iter((self.source, self.sink)) return iter((self.source, self.sink))
async def close(self) -> None: async def close(self):
if hasattr(self.source, 'close'): self.source.close()
self.source.close() self.sink.close()
if hasattr(self.sink, 'close'):
self.sink.close()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PumpedPacketSource(ParserSource): class PumpedPacketSource(ParserSource):
pump_task: Optional[asyncio.Task[None]] def __init__(self, receive):
def __init__(self, receive) -> None:
super().__init__() super().__init__()
self.receive_function = receive self.receive_function = receive
self.pump_task = None self.pump_task = None
def start(self) -> None: def start(self):
async def pump_packets() -> None: async def pump_packets():
while True: while True:
try: try:
packet = await self.receive_function() packet = await self.receive_function()
self.parser.feed_data(packet) self.parser.feed_data(packet)
except asyncio.CancelledError: except asyncio.exceptions.CancelledError:
logger.debug('source pump task done') logger.debug('source pump task done')
self.terminated.set_result(None)
break break
except Exception as error: except Exception as error:
logger.warning(f'exception while waiting for packet: {error}') logger.warn(f'exception while waiting for packet: {error}')
self.terminated.set_exception(error) self.terminated.set_result(error)
break break
self.pump_task = asyncio.create_task(pump_packets()) self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
def close(self) -> None: def close(self):
if self.pump_task: if self.pump_task:
self.pump_task.cancel() self.pump_task.cancel()
@@ -360,10 +285,10 @@ class PumpedPacketSource(ParserSource):
class PumpedPacketSink: class PumpedPacketSink:
def __init__(self, send): def __init__(self, send):
self.send_function = send self.send_function = send
self.packet_queue = asyncio.Queue() self.packet_queue = asyncio.Queue()
self.pump_task = None self.pump_task = None
def on_packet(self, packet: bytes) -> None: def on_packet(self, packet):
self.packet_queue.put_nowait(packet) self.packet_queue.put_nowait(packet)
def start(self): def start(self):
@@ -372,14 +297,14 @@ class PumpedPacketSink:
try: try:
packet = await self.packet_queue.get() packet = await self.packet_queue.get()
await self.send_function(packet) await self.send_function(packet)
except asyncio.CancelledError: except asyncio.exceptions.CancelledError:
logger.debug('sink pump task done') logger.debug('sink pump task done')
break break
except Exception as error: except Exception as error:
logger.warning(f'exception while sending packet: {error}') logger.warn(f'exception while sending packet: {error}')
break break
self.pump_task = asyncio.create_task(pump_packets()) self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
def close(self): def close(self):
if self.pump_task: if self.pump_task:
@@ -388,84 +313,14 @@ class PumpedPacketSink:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PumpedTransport(Transport): class PumpedTransport(Transport):
source: PumpedPacketSource def __init__(self, source, sink, close_function):
sink: PumpedPacketSink
def __init__(
self,
source: PumpedPacketSource,
sink: PumpedPacketSink,
) -> None:
super().__init__(source, sink) super().__init__(source, sink)
self.close_function = close_function
def start(self) -> None: def start(self):
self.source.start() self.source.start()
self.sink.start() self.sink.start()
# -----------------------------------------------------------------------------
class SnoopingTransport(Transport):
"""Transport wrapper that snoops on packets to/from a wrapped transport."""
@staticmethod
def create_with(
transport: Transport, snooper: ContextManager[Snooper]
) -> SnoopingTransport:
"""
Create an instance given a snooper that works as as context manager.
The returned instance will exit the snooper context when it is closed.
"""
with contextlib.ExitStack() as exit_stack:
return SnoopingTransport(
transport, exit_stack.enter_context(snooper), exit_stack.pop_all().close
)
raise RuntimeError('unexpected code path') # Satisfy the type checker
class Source:
sink: TransportSink
@property
def metadata(self) -> dict[str, Any]:
return getattr(self.source, 'metadata', {})
def __init__(self, source: TransportSource, snooper: Snooper):
self.source = source
self.snooper = snooper
self.terminated = source.terminated
def set_packet_sink(self, sink: TransportSink) -> None:
self.sink = sink
self.source.set_packet_sink(self)
def on_packet(self, packet: bytes) -> None:
self.snooper.snoop(packet, Snooper.Direction.CONTROLLER_TO_HOST)
if self.sink:
self.sink.on_packet(packet)
class Sink:
def __init__(self, sink: TransportSink, snooper: Snooper) -> None:
self.sink = sink
self.snooper = snooper
def on_packet(self, packet: bytes) -> None:
self.snooper.snoop(packet, Snooper.Direction.HOST_TO_CONTROLLER)
if self.sink:
self.sink.on_packet(packet)
def __init__(
self,
transport: Transport,
snooper: Snooper,
close_snooper=None,
) -> None:
super().__init__(
self.Source(transport.source, snooper), self.Sink(transport.sink, snooper)
)
self.transport = transport
self.close_snooper = close_snooper
async def close(self): async def close(self):
await self.transport.close() await super().close()
if self.close_snooper: await self.close_function()
self.close_snooper()
@@ -1,10 +1,25 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: emulated_bluetooth_packets.proto # source: emulated_bluetooth_packets.proto
"""Generated protocol buffer code.""" """Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports) # @@protoc_insertion_point(imports)
@@ -15,8 +30,17 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3') DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n emulated_bluetooth_packets.proto\x12\x1b\x61ndroid.emulation.bluetooth\"\xfb\x01\n\tHCIPacket\x12?\n\x04type\x18\x01 \x01(\x0e\x32\x31.android.emulation.bluetooth.HCIPacket.PacketType\x12\x0e\n\x06packet\x18\x02 \x01(\x0c\"\x9c\x01\n\nPacketType\x12\x1b\n\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17PACKET_TYPE_HCI_COMMAND\x10\x01\x12\x13\n\x0fPACKET_TYPE_ACL\x10\x02\x12\x13\n\x0fPACKET_TYPE_SCO\x10\x03\x12\x15\n\x11PACKET_TYPE_EVENT\x10\x04\x12\x13\n\x0fPACKET_TYPE_ISO\x10\x05\x42J\n\x1f\x63om.android.emulation.bluetoothP\x01\xf8\x01\x01\xa2\x02\x03\x41\x45\x42\xaa\x02\x1b\x41ndroid.Emulation.Bluetoothb\x06proto3')
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'emulated_bluetooth_packets_pb2', globals())
_HCIPACKET = DESCRIPTOR.message_types_by_name['HCIPacket']
_HCIPACKET_PACKETTYPE = _HCIPACKET.enum_types_by_name['PacketType']
HCIPacket = _reflection.GeneratedProtocolMessageType('HCIPacket', (_message.Message,), {
'DESCRIPTOR' : _HCIPACKET,
'__module__' : 'emulated_bluetooth_packets_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.HCIPacket)
})
_sym_db.RegisterMessage(HCIPacket)
if _descriptor._USE_C_DESCRIPTORS == False: if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None DESCRIPTOR._options = None
@@ -0,0 +1,53 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: emulated_bluetooth.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from . import emulated_bluetooth_packets_pb2 as emulated__bluetooth__packets__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18\x65mulated_bluetooth.proto\x12\x1b\x61ndroid.emulation.bluetooth\x1a emulated_bluetooth_packets.proto\"\x19\n\x07RawData\x12\x0e\n\x06packet\x18\x01 \x01(\x0c\x32\xcb\x02\n\x18\x45mulatedBluetoothService\x12\x64\n\x12registerClassicPhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12`\n\x0eregisterBlePhy\x12$.android.emulation.bluetooth.RawData\x1a$.android.emulation.bluetooth.RawData(\x01\x30\x01\x12g\n\x11registerHCIDevice\x12&.android.emulation.bluetooth.HCIPacket\x1a&.android.emulation.bluetooth.HCIPacket(\x01\x30\x01\x42\"\n\x1e\x63om.android.emulator.bluetoothP\x01\x62\x06proto3')
_RAWDATA = DESCRIPTOR.message_types_by_name['RawData']
RawData = _reflection.GeneratedProtocolMessageType('RawData', (_message.Message,), {
'DESCRIPTOR' : _RAWDATA,
'__module__' : 'emulated_bluetooth_pb2'
# @@protoc_insertion_point(class_scope:android.emulation.bluetooth.RawData)
})
_sym_db.RegisterMessage(RawData)
_EMULATEDBLUETOOTHSERVICE = DESCRIPTOR.services_by_name['EmulatedBluetoothService']
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\036com.android.emulator.bluetoothP\001'
_RAWDATA._serialized_start=91
_RAWDATA._serialized_end=116
_EMULATEDBLUETOOTHSERVICE._serialized_start=119
_EMULATEDBLUETOOTHSERVICE._serialized_end=450
# @@protoc_insertion_point(module_scope)

Some files were not shown because too many files have changed in this diff Show More