Compare commits

..

1 Commits

Author SHA1 Message Date
Charlie Boutier 7237619d3b A2DP example: Codec selection based on file type
Currently support SBC and AAC
2025-05-08 14:24:42 -07:00
258 changed files with 17264 additions and 28796 deletions
+4 -6
View File
@@ -6,8 +6,6 @@ on:
branches: [ main ] branches: [ main ]
pull_request: pull_request:
branches: [ main ] branches: [ main ]
workflow_dispatch:
branches: [main]
permissions: permissions:
contents: read contents: read
@@ -18,24 +16,24 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.10", "3.11", "3.12", "3.13.0", "3.14"] python-version: ["3.9", "3.10", "3.11", "3.12", "3.13.0"]
fail-fast: false fail-fast: false
steps: steps:
- name: Check out from Git - name: Check out from Git
uses: actions/checkout@v6 uses: actions/checkout@v3
- name: Get history and tags for SCM versioning to work - name: Get history and tags for SCM versioning to work
run: | run: |
git fetch --prune --unshallow git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/* git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v6 uses: actions/setup-python@v3
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install ".[build,test,development]" python -m pip install ".[build,examples,test,development]"
- name: Check - name: Check
run: | run: |
invoke project.pre-commit invoke project.pre-commit
+1 -3
View File
@@ -17,8 +17,6 @@ on:
pull_request: pull_request:
# The branches below must be a subset of the branches above # The branches below must be a subset of the branches above
branches: [ main ] branches: [ main ]
workflow_dispatch:
branches: [main]
schedule: schedule:
- cron: '39 21 * * 4' - cron: '39 21 * * 4'
@@ -40,7 +38,7 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v6 uses: actions/checkout@v3
# Initializes the CodeQL tools for scanning. # Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL - name: Initialize CodeQL
+2 -6
View File
@@ -7,10 +7,6 @@ on:
branches: [ main ] branches: [ main ]
paths: paths:
- 'extras/android/BtBench/**' - 'extras/android/BtBench/**'
workflow_dispatch:
branches: [main]
paths:
- 'extras/android/BtBench/**'
permissions: permissions:
contents: read contents: read
@@ -22,10 +18,10 @@ jobs:
steps: steps:
- name: Check out from Git - name: Check out from Git
uses: actions/checkout@v6 uses: actions/checkout@v3
- name: Set up JDK - name: Set up JDK
uses: actions/setup-java@v5 uses: actions/setup-java@v4
with: with:
distribution: 'zulu' distribution: 'zulu'
java-version: 17 java-version: 17
+3 -5
View File
@@ -5,8 +5,6 @@ on:
branches: [ main ] branches: [ main ]
pull_request: pull_request:
branches: [ main ] branches: [ main ]
workflow_dispatch:
branches: [main]
permissions: permissions:
contents: read contents: read
@@ -26,9 +24,9 @@ jobs:
21/24, 22/24, 23/24, 24/24, 21/24, 22/24, 23/24, 24/24,
] ]
steps: steps:
- uses: actions/checkout@v6 - uses: actions/checkout@v3
- name: Set Up Python 3.11 - name: Set Up Python 3.11
uses: actions/setup-python@v6 uses: actions/setup-python@v4
with: with:
python-version: 3.11 python-version: 3.11
- name: Install - name: Install
@@ -46,7 +44,7 @@ jobs:
run: cat rootcanal.log run: cat rootcanal.log
- name: Upload Mobly logs - name: Upload Mobly logs
if: always() if: always()
uses: actions/upload-artifact@v6 uses: actions/upload-artifact@v4
with: with:
name: mobly-logs-${{ strategy.job-index }} name: mobly-logs-${{ strategy.job-index }}
path: /tmp/logs/mobly/bumble.bumbles/ path: /tmp/logs/mobly/bumble.bumbles/
+9 -12
View File
@@ -6,8 +6,6 @@ on:
branches: [ main ] branches: [ main ]
pull_request: pull_request:
branches: [ main ] branches: [ main ]
workflow_dispatch:
branches: [main]
permissions: permissions:
contents: read contents: read
@@ -18,18 +16,18 @@ jobs:
strategy: strategy:
matrix: matrix:
os: ['ubuntu-latest', 'macos-latest', 'windows-latest'] os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
fail-fast: false fail-fast: false
steps: steps:
- name: Check out from Git - name: Check out from Git
uses: actions/checkout@v6 uses: actions/checkout@v3
- name: Get history and tags for SCM versioning to work - name: Get history and tags for SCM versioning to work
run: | run: |
git fetch --prune --unshallow git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/* git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v6 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install dependencies - name: Install dependencies
@@ -48,15 +46,14 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
# Rust runtime doesn't support 3.14 yet. python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python-version: ["3.10", "3.11", "3.12", "3.13"] rust-version: [ "1.76.0", "stable" ]
rust-version: [ "1.80.0", "1.91.0" ]
fail-fast: false fail-fast: false
steps: steps:
- name: Check out from Git - name: Check out from Git
uses: actions/checkout@v6 uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v6 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install Python dependencies - name: Install Python dependencies
@@ -69,11 +66,11 @@ jobs:
components: clippy,rustfmt components: clippy,rustfmt
toolchain: ${{ matrix.rust-version }} toolchain: ${{ matrix.rust-version }}
- name: Install Rust dependencies - name: Install Rust dependencies
run: cargo install cargo-all-features --version 1.11.0 --locked # allows building/testing combinations of features run: cargo install cargo-all-features # allows building/testing combinations of features
- name: Check License Headers - name: Check License Headers
run: cd rust && cargo run --features dev-tools --bin file-header check-all run: cd rust && cargo run --features dev-tools --bin file-header check-all
- name: Rust Build - name: Rust Build
run: cd rust && cargo build --all-targets && cargo build-all-features run: cd rust && cargo build --all-targets && cargo build-all-features --all-targets
# Lints after build so what clippy needs is already built # Lints after build so what clippy needs is already built
- name: Rust Lints - name: Rust Lints
run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings && cargo clippy --all-features --all-targets -- --deny warnings run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings && cargo clippy --all-features --all-targets -- --deny warnings
+3 -3
View File
@@ -14,13 +14,13 @@ jobs:
steps: steps:
- name: Check out from Git - name: Check out from Git
uses: actions/checkout@v6 uses: actions/checkout@v3
- name: Get history and tags for SCM versioning to work - name: Get history and tags for SCM versioning to work
run: | run: |
git fetch --prune --unshallow git fetch --prune --unshallow
git fetch --depth=1 origin +refs/tags/*:refs/tags/* git fetch --depth=1 origin +refs/tags/*:refs/tags/*
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v6 uses: actions/setup-python@v3
with: with:
python-version: '3.10' python-version: '3.10'
- name: Install dependencies - name: Install dependencies
@@ -31,7 +31,7 @@ jobs:
run: python -m build run: python -m build
- name: Publish package to PyPI - name: Publish package to PyPI
if: github.event_name == 'release' && startsWith(github.ref, 'refs/tags') if: github.event_name == 'release' && startsWith(github.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # release/v1.13 uses: pypa/gh-action-pypi-publish@release/v1
with: with:
user: __token__ user: __token__
password: ${{ secrets.PYPI_API_TOKEN }} password: ${{ secrets.PYPI_API_TOKEN }}
-3
View File
@@ -17,6 +17,3 @@ venv/
.venv/ .venv/
# snoop logs # snoop logs
out/ out/
# macOS
.DS_Store
._*
+1 -6
View File
@@ -102,10 +102,5 @@
"." "."
], ],
"python.testing.unittestEnabled": false, "python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true, "python.testing.pytestEnabled": true
"python-envs.defaultEnvManager": "ms-python.python:system",
"python-envs.pythonProjects": [],
"nrf-connect.applications": [
"${workspaceFolder}/extras/zephyr/hci_usb"
]
} }
+1 -1
View File
@@ -50,7 +50,7 @@ Bumble is easiest to use with a dedicated USB dongle.
This is because internal Bluetooth interfaces tend to be locked down by the operating system. This is because internal Bluetooth interfaces tend to be locked down by the operating system.
You can use the [usb_probe](/docs/mkdocs/src/apps_and_tools/usb_probe.md) tool (all platforms) or `lsusb` (Linux or macOS) to list the available USB devices on your system. You can use the [usb_probe](/docs/mkdocs/src/apps_and_tools/usb_probe.md) tool (all platforms) or `lsusb` (Linux or macOS) to list the available USB devices on your system.
See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices. Also, if you are on a mac, see [these instructions](docs/mkdocs/src/platforms/macos.md). See the [USB Transport](/docs/mkdocs/src/transports/usb.md) page for details on how to refer to USB devices. Also, if your are on a mac, see [these instructions](docs/mkdocs/src/platforms/macos.md).
## License ## License
+3
View File
@@ -12,6 +12,9 @@ Apps
## `show.py` ## `show.py`
Parse a file with HCI packets and print the details of each packet in a human readable form Parse a file with HCI packets and print the details of each packet in a human readable form
## `link_relay.py`
Simple WebSocket relay for virtual RemoteLink instances to communicate with each other through.
## `hci_bridge.py` ## `hci_bridge.py`
This app acts as a simple bridge between two HCI transports, with a host on one side and This app acts as a simple bridge between two HCI transports, with a host on one side and
a controller on the other. All the HCI packets bridged between the two are printed on the console a controller on the other. All the HCI packets bridged between the two are printed on the console
+515 -765
View File
File diff suppressed because it is too large Load Diff
+45 -488
View File
@@ -19,45 +19,33 @@ import asyncio
import dataclasses import dataclasses
import enum import enum
import logging import logging
import os
import statistics import statistics
import struct import struct
import time import time
import click import click
import bumble.core
import bumble.logging
import bumble.rfcomm
from bumble import l2cap from bumble import l2cap
from bumble.colors import color
from bumble.core import ( from bumble.core import (
PhysicalTransport,
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
BT_RFCOMM_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID,
UUID, UUID,
CommandTimeoutError, CommandTimeoutError,
ConnectionPHY,
PhysicalTransport,
)
from bumble.device import (
CigParameters,
CisLink,
Connection,
ConnectionParametersPreferences,
Device,
Peer,
) )
from bumble.colors import color
from bumble.device import Connection, ConnectionParametersPreferences, Device, Peer
from bumble.gatt import Characteristic, CharacteristicValue, Service from bumble.gatt import Characteristic, CharacteristicValue, Service
from bumble.hci import ( from bumble.hci import (
HCI_LE_1M_PHY, HCI_LE_1M_PHY,
HCI_LE_2M_PHY, HCI_LE_2M_PHY,
HCI_LE_CODED_PHY, HCI_LE_CODED_PHY,
Role,
HCI_Constant, HCI_Constant,
HCI_Error, HCI_Error,
HCI_IsoDataPacket,
HCI_StatusError, HCI_StatusError,
Role,
) )
from bumble.pairing import PairingConfig
from bumble.sdp import ( from bumble.sdp import (
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
@@ -67,8 +55,12 @@ from bumble.sdp import (
DataElement, DataElement,
ServiceAttribute, ServiceAttribute,
) )
from bumble.transport import open_transport from bumble.transport import open_transport_or_link
import bumble.rfcomm
import bumble.core
from bumble.utils import AsyncRunner from bumble.utils import AsyncRunner
from bumble.pairing import PairingConfig
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -83,28 +75,17 @@ DEFAULT_CENTRAL_ADDRESS = 'F0:F0:F0:F0:F0:F0'
DEFAULT_CENTRAL_NAME = 'Speed Central' DEFAULT_CENTRAL_NAME = 'Speed Central'
DEFAULT_PERIPHERAL_ADDRESS = 'F1:F1:F1:F1:F1:F1' DEFAULT_PERIPHERAL_ADDRESS = 'F1:F1:F1:F1:F1:F1'
DEFAULT_PERIPHERAL_NAME = 'Speed Peripheral' DEFAULT_PERIPHERAL_NAME = 'Speed Peripheral'
DEFAULT_ADVERTISING_INTERVAL = 100
SPEED_SERVICE_UUID = '50DB505C-8AC4-4738-8448-3B1D9CC09CC5' SPEED_SERVICE_UUID = '50DB505C-8AC4-4738-8448-3B1D9CC09CC5'
SPEED_TX_UUID = 'E789C754-41A1-45F4-A948-A0A1A90DBA53' SPEED_TX_UUID = 'E789C754-41A1-45F4-A948-A0A1A90DBA53'
SPEED_RX_UUID = '016A2CC7-E14B-4819-935F-1F56EAE4098D' SPEED_RX_UUID = '016A2CC7-E14B-4819-935F-1F56EAE4098D'
DEFAULT_RFCOMM_UUID = 'E6D55659-C8B4-4B85-96BB-B1143AF6D3AE' DEFAULT_RFCOMM_UUID = 'E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'
DEFAULT_L2CAP_PSM = 128 DEFAULT_L2CAP_PSM = 128
DEFAULT_L2CAP_MAX_CREDITS = 128 DEFAULT_L2CAP_MAX_CREDITS = 128
DEFAULT_L2CAP_MTU = 1024 DEFAULT_L2CAP_MTU = 1024
DEFAULT_L2CAP_MPS = 1024 DEFAULT_L2CAP_MPS = 1024
DEFAULT_ISO_MAX_SDU_C_TO_P = 251
DEFAULT_ISO_MAX_SDU_P_TO_C = 251
DEFAULT_ISO_SDU_INTERVAL_C_TO_P = 10000
DEFAULT_ISO_SDU_INTERVAL_P_TO_C = 10000
DEFAULT_ISO_MAX_TRANSPORT_LATENCY_C_TO_P = 35
DEFAULT_ISO_MAX_TRANSPORT_LATENCY_P_TO_C = 35
DEFAULT_ISO_RTN_C_TO_P = 3
DEFAULT_ISO_RTN_P_TO_C = 3
DEFAULT_LINGER_TIME = 1.0 DEFAULT_LINGER_TIME = 1.0
DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0 DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0
@@ -121,14 +102,14 @@ def le_phy_name(phy_id):
) )
def print_connection_phy(phy: ConnectionPHY) -> None: def print_connection_phy(phy):
logging.info( logging.info(
color('@@@ PHY: ', 'yellow') + f'TX:{le_phy_name(phy.tx_phy)}/' color('@@@ PHY: ', 'yellow') + f'TX:{le_phy_name(phy.tx_phy)}/'
f'RX:{le_phy_name(phy.rx_phy)}' f'RX:{le_phy_name(phy.rx_phy)}'
) )
def print_connection(connection: Connection) -> None: def print_connection(connection):
params = [] params = []
if connection.transport == PhysicalTransport.LE: if connection.transport == PhysicalTransport.LE:
params.append( params.append(
@@ -153,34 +134,6 @@ def print_connection(connection: Connection) -> None:
logging.info(color('@@@ Connection: ', 'yellow') + ' '.join(params)) logging.info(color('@@@ Connection: ', 'yellow') + ' '.join(params))
def print_cis_link(cis_link: CisLink) -> None:
logging.info(color("@@@ CIS established", "green"))
logging.info(color('@@@ ISO interval: ', 'green') + f"{cis_link.iso_interval}ms")
logging.info(color('@@@ NSE: ', 'green') + f"{cis_link.nse}")
logging.info(color('@@@ Central->Peripheral:', 'green'))
if cis_link.phy_c_to_p is not None:
logging.info(
color('@@@ PHY: ', 'green') + f"{cis_link.phy_c_to_p.name}"
)
logging.info(
color('@@@ Latency: ', 'green') + f"{cis_link.transport_latency_c_to_p}µs"
)
logging.info(color('@@@ BN: ', 'green') + f"{cis_link.bn_c_to_p}")
logging.info(color('@@@ FT: ', 'green') + f"{cis_link.ft_c_to_p}")
logging.info(color('@@@ Max PDU: ', 'green') + f"{cis_link.max_pdu_c_to_p}")
logging.info(color('@@@ Peripheral->Central:', 'green'))
if cis_link.phy_p_to_c is not None:
logging.info(
color('@@@ PHY: ', 'green') + f"{cis_link.phy_p_to_c.name}"
)
logging.info(
color('@@@ Latency: ', 'green') + f"{cis_link.transport_latency_p_to_c}µs"
)
logging.info(color('@@@ BN: ', 'green') + f"{cis_link.bn_p_to_c}")
logging.info(color('@@@ FT: ', 'green') + f"{cis_link.ft_p_to_c}")
logging.info(color('@@@ Max PDU: ', 'green') + f"{cis_link.max_pdu_p_to_c}")
def make_sdp_records(channel): def make_sdp_records(channel):
return { return {
0x00010001: [ 0x00010001: [
@@ -244,51 +197,6 @@ async def switch_roles(connection, role):
logging.info(f'{color("### Role switch failed:", "red")} {error}') logging.info(f'{color("### Role switch failed:", "red")} {error}')
async def pre_power_on(device: Device, classic: bool) -> None:
device.classic_enabled = classic
# Set up a pairing config factory with minimal requirements.
device.config.keystore = "JsonKeyStore"
device.pairing_config_factory = lambda _: PairingConfig(
sc=False, mitm=False, bonding=False
)
async def post_power_on(
device: Device,
le_scan: tuple[int, int] | None,
le_advertise: int | None,
classic_page_scan: bool,
classic_inquiry_scan: bool,
) -> None:
if classic_page_scan:
logging.info(color("*** Enabling page scan", "blue"))
await device.set_connectable(True)
if classic_inquiry_scan:
logging.info(color("*** Enabling inquiry scan", "blue"))
await device.set_discoverable(True)
if le_scan:
scan_window, scan_interval = le_scan
logging.info(
color(
f"*** Starting LE scanning [{scan_window}ms/{scan_interval}ms]",
"blue",
)
)
await device.start_scanning(
scan_interval=scan_interval, scan_window=scan_window
)
if le_advertise:
logging.info(color(f"*** Starting LE advertising [{le_advertise}ms]", "blue"))
await device.start_advertising(
advertising_interval_min=le_advertise,
advertising_interval_max=le_advertise,
auto_restart=True,
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Packet # Packet
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -489,7 +397,7 @@ class Sender:
flags=( flags=(
Packet.PacketFlags.LAST Packet.PacketFlags.LAST
if tx_i == self.tx_packet_count - 1 if tx_i == self.tx_packet_count - 1
else Packet.PacketFlags(0) else 0
), ),
sequence=tx_i, sequence=tx_i,
timestamp=int((time.time() - self.start_time) * 1000000), timestamp=int((time.time() - self.start_time) * 1000000),
@@ -506,8 +414,7 @@ class Sender:
self.bytes_sent += len(packet) self.bytes_sent += len(packet)
await self.packet_io.send_packet(packet) await self.packet_io.send_packet(packet)
if self.packet_io.can_receive(): await self.done.wait()
await self.done.wait()
run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else '' run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
logging.info(color(f'=== {run_counter} Done!', 'magenta')) logging.info(color(f'=== {run_counter} Done!', 'magenta'))
@@ -537,9 +444,6 @@ class Sender:
) )
self.done.set() self.done.set()
def is_sender(self):
return True
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Receiver # Receiver
@@ -587,8 +491,7 @@ class Receiver:
logging.info( logging.info(
color( color(
f'!!! Unexpected packet, expected {self.expected_packet_index} ' f'!!! Unexpected packet, expected {self.expected_packet_index} '
f'but received {packet.sequence}', f'but received {packet.sequence}'
'red',
) )
) )
@@ -631,9 +534,6 @@ class Receiver:
await self.done.wait() await self.done.wait()
logging.info(color('=== Done!', 'magenta')) logging.info(color('=== Done!', 'magenta'))
def is_sender(self):
return False
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Ping # Ping
@@ -769,8 +669,7 @@ class Ping:
color( color(
f'!!! Unexpected packet, ' f'!!! Unexpected packet, '
f'expected {self.next_expected_packet_index} ' f'expected {self.next_expected_packet_index} '
f'but received {packet.sequence}', f'but received {packet.sequence}'
'red',
) )
) )
@@ -778,9 +677,6 @@ class Ping:
self.done.set() self.done.set()
return return
def is_sender(self):
return True
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Pong # Pong
@@ -825,8 +721,7 @@ class Pong:
logging.info( logging.info(
color( color(
f'!!! Unexpected packet, expected {self.expected_packet_index} ' f'!!! Unexpected packet, expected {self.expected_packet_index} '
f'but received {packet.sequence}', f'but received {packet.sequence}'
'red',
) )
) )
@@ -848,9 +743,6 @@ class Pong:
await self.done.wait() await self.done.wait()
logging.info(color('=== Done!', 'magenta')) logging.info(color('=== Done!', 'magenta'))
def is_sender(self):
return False
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# GattClient # GattClient
@@ -1014,9 +906,6 @@ class StreamedPacketIO:
# pylint: disable-next=not-callable # pylint: disable-next=not-callable
self.io_sink(struct.pack('>H', len(packet)) + packet) self.io_sink(struct.pack('>H', len(packet)) + packet)
def can_receive(self):
return True
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# L2capClient # L2capClient
@@ -1288,96 +1177,6 @@ class RfcommServer(StreamedPacketIO):
await self.dlc.drain() await self.dlc.drain()
# -----------------------------------------------------------------------------
# IsoClient
# -----------------------------------------------------------------------------
class IsoClient(StreamedPacketIO):
def __init__(
self,
device: Device,
) -> None:
super().__init__()
self.device = device
self.ready = asyncio.Event()
self.cis_link: CisLink | None = None
async def on_connection(
self, connection: Connection, cis_link: CisLink, sender: bool
) -> None:
connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
self.cis_link = cis_link
self.io_sink = cis_link.write
await cis_link.setup_data_path(
cis_link.Direction.HOST_TO_CONTROLLER
if sender
else cis_link.Direction.CONTROLLER_TO_HOST
)
cis_link.sink = self.on_iso_packet
self.ready.set()
def on_iso_packet(self, iso_packet: HCI_IsoDataPacket) -> None:
self.on_packet(iso_packet.iso_sdu_fragment)
def on_disconnection(self, _):
pass
async def drain(self):
if self.cis_link is None:
return
await self.cis_link.drain()
def can_receive(self):
return False
# -----------------------------------------------------------------------------
# IsoServer
# -----------------------------------------------------------------------------
class IsoServer(StreamedPacketIO):
def __init__(
self,
device: Device,
):
super().__init__()
self.device = device
self.cis_link: CisLink | None = None
self.ready = asyncio.Event()
logging.info(
color(
'### Listening for ISO connection',
'yellow',
)
)
async def on_connection(
self, connection: Connection, cis_link: CisLink, sender: bool
) -> None:
connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
self.io_sink = cis_link.write
await cis_link.setup_data_path(
cis_link.Direction.HOST_TO_CONTROLLER
if sender
else cis_link.Direction.CONTROLLER_TO_HOST
)
cis_link.sink = self.on_iso_packet
self.ready.set()
def on_iso_packet(self, iso_packet: HCI_IsoDataPacket) -> None:
self.on_packet(iso_packet.iso_sdu_fragment)
def on_disconnection(self, _):
pass
async def drain(self):
if self.cis_link is None:
return
await self.cis_link.drain()
def can_receive(self):
return False
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Central # Central
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1386,52 +1185,26 @@ class Central(Connection.Listener):
self, self,
transport, transport,
peripheral_address, peripheral_address,
classic,
scenario_factory, scenario_factory,
mode_factory, mode_factory,
connection_interval, connection_interval,
phy, phy,
authenticate, authenticate,
encrypt, encrypt,
iso,
iso_sdu_interval_c_to_p,
iso_sdu_interval_p_to_c,
iso_max_sdu_c_to_p,
iso_max_sdu_p_to_c,
iso_max_transport_latency_c_to_p,
iso_max_transport_latency_p_to_c,
iso_rtn_c_to_p,
iso_rtn_p_to_c,
classic,
extended_data_length, extended_data_length,
role_switch, role_switch,
le_scan,
le_advertise,
classic_page_scan,
classic_inquiry_scan,
): ):
super().__init__() super().__init__()
self.transport = transport self.transport = transport
self.peripheral_address = peripheral_address self.peripheral_address = peripheral_address
self.classic = classic self.classic = classic
self.iso = iso
self.iso_sdu_interval_c_to_p = iso_sdu_interval_c_to_p
self.iso_sdu_interval_p_to_c = iso_sdu_interval_p_to_c
self.iso_max_sdu_c_to_p = iso_max_sdu_c_to_p
self.iso_max_sdu_p_to_c = iso_max_sdu_p_to_c
self.iso_max_transport_latency_c_to_p = iso_max_transport_latency_c_to_p
self.iso_max_transport_latency_p_to_c = iso_max_transport_latency_p_to_c
self.iso_rtn_c_to_p = iso_rtn_c_to_p
self.iso_rtn_p_to_c = iso_rtn_p_to_c
self.scenario_factory = scenario_factory self.scenario_factory = scenario_factory
self.mode_factory = mode_factory self.mode_factory = mode_factory
self.authenticate = authenticate self.authenticate = authenticate
self.encrypt = encrypt or authenticate self.encrypt = encrypt or authenticate
self.extended_data_length = extended_data_length self.extended_data_length = extended_data_length
self.role_switch = role_switch self.role_switch = role_switch
self.le_scan = le_scan
self.le_advertise = le_advertise
self.classic_page_scan = classic_page_scan
self.classic_inquiry_scan = classic_inquiry_scan
self.device = None self.device = None
self.connection = None self.connection = None
@@ -1468,7 +1241,7 @@ class Central(Connection.Listener):
async def run(self): async def run(self):
logging.info(color('>>> Connecting to HCI...', 'green')) logging.info(color('>>> Connecting to HCI...', 'green'))
async with await open_transport(self.transport) as ( async with await open_transport_or_link(self.transport) as (
hci_source, hci_source,
hci_sink, hci_sink,
): ):
@@ -1481,22 +1254,18 @@ class Central(Connection.Listener):
mode = self.mode_factory(self.device) mode = self.mode_factory(self.device)
scenario = self.scenario_factory(mode) scenario = self.scenario_factory(mode)
self.device.classic_enabled = self.classic self.device.classic_enabled = self.classic
self.device.cis_enabled = self.iso
# Set up a pairing config factory with minimal requirements. # Set up a pairing config factory with minimal requirements.
self.device.config.keystore = "JsonKeyStore"
self.device.pairing_config_factory = lambda _: PairingConfig( self.device.pairing_config_factory = lambda _: PairingConfig(
sc=False, mitm=False, bonding=False sc=False, mitm=False, bonding=False
) )
await pre_power_on(self.device, self.classic)
await self.device.power_on() await self.device.power_on()
await post_power_on(
self.device, if self.classic:
self.le_scan, await self.device.set_discoverable(False)
self.le_advertise, await self.device.set_connectable(False)
self.classic_page_scan,
self.classic_inquiry_scan,
)
logging.info( logging.info(
color(f'### Connecting to {self.peripheral_address}...', 'cyan') color(f'### Connecting to {self.peripheral_address}...', 'cyan')
@@ -1571,72 +1340,7 @@ class Central(Connection.Listener):
) )
) )
# Setup ISO streams. await mode.on_connection(self.connection)
if self.iso:
if scenario.is_sender():
sdu_interval_c_to_p = (
self.iso_sdu_interval_c_to_p or DEFAULT_ISO_SDU_INTERVAL_C_TO_P
)
sdu_interval_p_to_c = self.iso_sdu_interval_p_to_c or 0
max_transport_latency_c_to_p = (
self.iso_max_transport_latency_c_to_p
or DEFAULT_ISO_MAX_TRANSPORT_LATENCY_C_TO_P
)
max_transport_latency_p_to_c = (
self.iso_max_transport_latency_p_to_c or 0
)
max_sdu_c_to_p = (
self.iso_max_sdu_c_to_p or DEFAULT_ISO_MAX_SDU_C_TO_P
)
max_sdu_p_to_c = self.iso_max_sdu_p_to_c or 0
rtn_c_to_p = self.iso_rtn_c_to_p or DEFAULT_ISO_RTN_C_TO_P
rtn_p_to_c = self.iso_rtn_p_to_c or 0
else:
sdu_interval_p_to_c = (
self.iso_sdu_interval_p_to_c or DEFAULT_ISO_SDU_INTERVAL_P_TO_C
)
sdu_interval_c_to_p = self.iso_sdu_interval_c_to_p or 0
max_transport_latency_p_to_c = (
self.iso_max_transport_latency_p_to_c
or DEFAULT_ISO_MAX_TRANSPORT_LATENCY_P_TO_C
)
max_transport_latency_c_to_p = (
self.iso_max_transport_latency_c_to_p or 0
)
max_sdu_p_to_c = (
self.iso_max_sdu_p_to_c or DEFAULT_ISO_MAX_SDU_P_TO_C
)
max_sdu_c_to_p = self.iso_max_sdu_c_to_p or 0
rtn_p_to_c = self.iso_rtn_p_to_c or DEFAULT_ISO_RTN_P_TO_C
rtn_c_to_p = self.iso_rtn_c_to_p or 0
cis_handles = await self.device.setup_cig(
CigParameters(
cig_id=1,
sdu_interval_c_to_p=sdu_interval_c_to_p,
sdu_interval_p_to_c=sdu_interval_p_to_c,
max_transport_latency_c_to_p=max_transport_latency_c_to_p,
max_transport_latency_p_to_c=max_transport_latency_p_to_c,
cis_parameters=[
CigParameters.CisParameters(
cis_id=2,
max_sdu_c_to_p=max_sdu_c_to_p,
max_sdu_p_to_c=max_sdu_p_to_c,
rtn_c_to_p=rtn_c_to_p,
rtn_p_to_c=rtn_p_to_c,
)
],
)
)
cis_link = (
await self.device.create_cis([(cis_handles[0], self.connection)])
)[0]
print_cis_link(cis_link)
await mode.on_connection(
self.connection, cis_link, scenario.is_sender()
)
else:
await mode.on_connection(self.connection)
await scenario.run() await scenario.run()
await asyncio.sleep(DEFAULT_LINGER_TIME) await asyncio.sleep(DEFAULT_LINGER_TIME)
@@ -1672,38 +1376,24 @@ class Peripheral(Device.Listener, Connection.Listener):
scenario_factory, scenario_factory,
mode_factory, mode_factory,
classic, classic,
iso,
extended_data_length, extended_data_length,
role_switch, role_switch,
le_scan,
le_advertise,
classic_page_scan,
classic_inquiry_scan,
): ):
self.transport = transport self.transport = transport
self.classic = classic self.classic = classic
self.iso = iso
self.scenario_factory = scenario_factory self.scenario_factory = scenario_factory
self.mode_factory = mode_factory self.mode_factory = mode_factory
self.extended_data_length = extended_data_length self.extended_data_length = extended_data_length
self.role_switch = role_switch self.role_switch = role_switch
self.le_scan = le_scan
self.classic_page_scan = classic_page_scan
self.classic_inquiry_scan = classic_inquiry_scan
self.scenario = None self.scenario = None
self.mode = None self.mode = None
self.device = None self.device = None
self.connection = None self.connection = None
self.connected = asyncio.Event() self.connected = asyncio.Event()
if le_advertise:
self.le_advertise = le_advertise
else:
self.le_advertise = 0 if classic else DEFAULT_ADVERTISING_INTERVAL
async def run(self): async def run(self):
logging.info(color('>>> Connecting to HCI...', 'green')) logging.info(color('>>> Connecting to HCI...', 'green'))
async with await open_transport(self.transport) as ( async with await open_transport_or_link(self.transport) as (
hci_source, hci_source,
hci_sink, hci_sink,
): ):
@@ -1717,22 +1407,20 @@ class Peripheral(Device.Listener, Connection.Listener):
self.mode = self.mode_factory(self.device) self.mode = self.mode_factory(self.device)
self.scenario = self.scenario_factory(self.mode) self.scenario = self.scenario_factory(self.mode)
self.device.classic_enabled = self.classic self.device.classic_enabled = self.classic
self.device.cis_enabled = self.iso
# Set up a pairing config factory with minimal requirements. # Set up a pairing config factory with minimal requirements.
self.device.config.keystore = "JsonKeyStore"
self.device.pairing_config_factory = lambda _: PairingConfig( self.device.pairing_config_factory = lambda _: PairingConfig(
sc=False, mitm=False, bonding=False sc=False, mitm=False, bonding=False
) )
await pre_power_on(self.device, self.classic)
await self.device.power_on() await self.device.power_on()
await post_power_on(
self.device, if self.classic:
self.le_scan, await self.device.set_discoverable(True)
self.le_advertise, await self.device.set_connectable(True)
self.classic or self.classic_page_scan, else:
self.classic or self.classic_inquiry_scan, await self.device.start_advertising(auto_restart=True)
)
if self.classic: if self.classic:
logging.info( logging.info(
@@ -1754,21 +1442,7 @@ class Peripheral(Device.Listener, Connection.Listener):
logging.info(color('### Connected', 'cyan')) logging.info(color('### Connected', 'cyan'))
print_connection(self.connection) print_connection(self.connection)
if self.iso: await self.mode.on_connection(self.connection)
async def on_cis_request(cis_link: CisLink) -> None:
logging.info(color("@@@ Accepting CIS", "green"))
await self.device.accept_cis_request(cis_link)
print_cis_link(cis_link)
await self.mode.on_connection(
self.connection, cis_link, self.scenario.is_sender()
)
self.connection.on(self.connection.EVENT_CIS_REQUEST, on_cis_request)
else:
await self.mode.on_connection(self.connection)
await self.scenario.run() await self.scenario.run()
await asyncio.sleep(DEFAULT_LINGER_TIME) await asyncio.sleep(DEFAULT_LINGER_TIME)
@@ -1777,14 +1451,10 @@ class Peripheral(Device.Listener, Connection.Listener):
self.connection = connection self.connection = connection
self.connected.set() self.connected.set()
# Stop being discoverable and connectable if possible # Stop being discoverable and connectable
if self.classic: if self.classic:
if not self.classic_inquiry_scan: AsyncRunner.spawn(self.device.set_discoverable(False))
logging.info(color("*** Stopping inquiry scan", "blue")) AsyncRunner.spawn(self.device.set_connectable(False))
AsyncRunner.spawn(self.device.set_discoverable(False))
if not self.classic_page_scan:
logging.info(color("*** Stopping page scan", "blue"))
AsyncRunner.spawn(self.device.set_connectable(False))
# Request a new data length if needed # Request a new data length if needed
if not self.classic and self.extended_data_length: if not self.classic and self.extended_data_length:
@@ -1805,9 +1475,7 @@ class Peripheral(Device.Listener, Connection.Listener):
self.scenario.reset() self.scenario.reset()
if self.classic: if self.classic:
logging.info(color("*** Enabling inquiry scan", "blue"))
AsyncRunner.spawn(self.device.set_discoverable(True)) AsyncRunner.spawn(self.device.set_discoverable(True))
logging.info(color("*** Enabling page scan", "blue"))
AsyncRunner.spawn(self.device.set_connectable(True)) AsyncRunner.spawn(self.device.set_connectable(True))
def on_connection_parameters_update(self): def on_connection_parameters_update(self):
@@ -1880,12 +1548,6 @@ def create_mode_factory(ctx, default_mode):
credits_threshold=ctx.obj['rfcomm_credits_threshold'], credits_threshold=ctx.obj['rfcomm_credits_threshold'],
) )
if mode == 'iso-server':
return IsoServer(device)
if mode == 'iso-client':
return IsoClient(device)
raise ValueError('invalid mode') raise ValueError('invalid mode')
return create_mode return create_mode
@@ -1913,9 +1575,6 @@ def create_scenario_factory(ctx, default_scenario):
return Receiver(packet_io, ctx.obj['linger']) return Receiver(packet_io, ctx.obj['linger'])
if scenario == 'ping': if scenario == 'ping':
if isinstance(packet_io, (IsoClient, IsoServer)):
raise ValueError('ping not supported with ISO')
return Ping( return Ping(
packet_io, packet_io,
start_delay=ctx.obj['start_delay'], start_delay=ctx.obj['start_delay'],
@@ -1927,9 +1586,6 @@ def create_scenario_factory(ctx, default_scenario):
) )
if scenario == 'pong': if scenario == 'pong':
if isinstance(packet_io, (IsoClient, IsoServer)):
raise ValueError('pong not supported with ISO')
return Pong(packet_io, ctx.obj['linger']) return Pong(packet_io, ctx.obj['linger'])
raise ValueError('invalid scenario') raise ValueError('invalid scenario')
@@ -1953,8 +1609,6 @@ def create_scenario_factory(ctx, default_scenario):
'l2cap-server', 'l2cap-server',
'rfcomm-client', 'rfcomm-client',
'rfcomm-server', 'rfcomm-server',
'iso-client',
'iso-server',
] ]
), ),
) )
@@ -1967,7 +1621,6 @@ def create_scenario_factory(ctx, default_scenario):
) )
@click.option( @click.option(
'--extended-data-length', '--extended-data-length',
metavar='<TX-OCTETS>/<TX-TIME>',
help='Request a data length upon connection, specified as tx_octets/tx_time', help='Request a data length upon connection, specified as tx_octets/tx_time',
) )
@click.option( @click.option(
@@ -1975,26 +1628,6 @@ def create_scenario_factory(ctx, default_scenario):
type=click.Choice(['central', 'peripheral']), type=click.Choice(['central', 'peripheral']),
help='Request role switch upon connection (central or peripheral)', help='Request role switch upon connection (central or peripheral)',
) )
@click.option(
'--le-scan',
metavar='<WINDOW>/<INTERVAL>',
help='Perform an LE scan with a given window and interval (milliseconds)',
)
@click.option(
'--le-advertise',
metavar='<INTERVAL>',
help='Advertise with a given interval (milliseconds)',
)
@click.option(
'--classic-page-scan',
is_flag=True,
help='Enable Classic page scanning',
)
@click.option(
'--classic-inquiry-scan',
is_flag=True,
help='Enable Classic enquiry scanning',
)
@click.option( @click.option(
'--rfcomm-channel', '--rfcomm-channel',
type=int, type=int,
@@ -2120,10 +1753,6 @@ def bench(
att_mtu, att_mtu,
extended_data_length, extended_data_length,
role_switch, role_switch,
le_scan,
le_advertise,
classic_page_scan,
classic_inquiry_scan,
packet_size, packet_size,
packet_count, packet_count,
start_delay, start_delay,
@@ -2172,12 +1801,7 @@ def bench(
else None else None
) )
ctx.obj['role_switch'] = role_switch ctx.obj['role_switch'] = role_switch
ctx.obj['le_scan'] = [float(x) for x in le_scan.split('/')] if le_scan else None
ctx.obj['le_advertise'] = float(le_advertise) if le_advertise else None
ctx.obj['classic_page_scan'] = classic_page_scan
ctx.obj['classic_inquiry_scan'] = classic_inquiry_scan
ctx.obj['classic'] = mode in ('rfcomm-client', 'rfcomm-server') ctx.obj['classic'] = mode in ('rfcomm-client', 'rfcomm-server')
ctx.obj['iso'] = mode in ('iso-client', 'iso-server')
@bench.command() @bench.command()
@@ -2199,94 +1823,28 @@ def bench(
@click.option('--phy', type=click.Choice(['1m', '2m', 'coded']), help='PHY to use') @click.option('--phy', type=click.Choice(['1m', '2m', 'coded']), help='PHY to use')
@click.option('--authenticate', is_flag=True, help='Authenticate (RFComm only)') @click.option('--authenticate', is_flag=True, help='Authenticate (RFComm only)')
@click.option('--encrypt', is_flag=True, help='Encrypt the connection (RFComm only)') @click.option('--encrypt', is_flag=True, help='Encrypt the connection (RFComm only)')
@click.option(
'--iso-sdu-interval-c-to-p',
type=int,
help='ISO SDU central -> peripheral (microseconds)',
)
@click.option(
'--iso-sdu-interval-p-to-c',
type=int,
help='ISO SDU interval peripheral -> central (microseconds)',
)
@click.option(
'--iso-max-sdu-c-to-p',
type=int,
help='ISO max SDU central -> peripheral',
)
@click.option(
'--iso-max-sdu-p-to-c',
type=int,
help='ISO max SDU peripheral -> central',
)
@click.option(
'--iso-max-transport-latency-c-to-p',
type=int,
help='ISO max transport latency central -> peripheral (milliseconds)',
)
@click.option(
'--iso-max-transport-latency-p-to-c',
type=int,
help='ISO max transport latency peripheral -> central (milliseconds)',
)
@click.option(
'--iso-rtn-c-to-p',
type=int,
help='ISO RTN central -> peripheral (integer count)',
)
@click.option(
'--iso-rtn-p-to-c',
type=int,
help='ISO RTN peripheral -> central (integer count)',
)
@click.pass_context @click.pass_context
def central( def central(
ctx, ctx, transport, peripheral_address, connection_interval, phy, authenticate, encrypt
transport,
peripheral_address,
connection_interval,
phy,
authenticate,
encrypt,
iso_sdu_interval_c_to_p,
iso_sdu_interval_p_to_c,
iso_max_sdu_c_to_p,
iso_max_sdu_p_to_c,
iso_max_transport_latency_c_to_p,
iso_max_transport_latency_p_to_c,
iso_rtn_c_to_p,
iso_rtn_p_to_c,
): ):
"""Run as a central (initiates the connection)""" """Run as a central (initiates the connection)"""
scenario_factory = create_scenario_factory(ctx, 'send') scenario_factory = create_scenario_factory(ctx, 'send')
mode_factory = create_mode_factory(ctx, 'gatt-client') mode_factory = create_mode_factory(ctx, 'gatt-client')
classic = ctx.obj['classic']
async def run_central(): async def run_central():
await Central( await Central(
transport, transport,
peripheral_address, peripheral_address,
classic,
scenario_factory, scenario_factory,
mode_factory, mode_factory,
connection_interval, connection_interval,
phy, phy,
authenticate, authenticate,
encrypt or authenticate, encrypt or authenticate,
ctx.obj['iso'],
iso_sdu_interval_c_to_p,
iso_sdu_interval_p_to_c,
iso_max_sdu_c_to_p,
iso_max_sdu_p_to_c,
iso_max_transport_latency_c_to_p,
iso_max_transport_latency_p_to_c,
iso_rtn_c_to_p,
iso_rtn_p_to_c,
ctx.obj['classic'],
ctx.obj['extended_data_length'], ctx.obj['extended_data_length'],
ctx.obj['role_switch'], ctx.obj['role_switch'],
ctx.obj['le_scan'],
ctx.obj['le_advertise'],
ctx.obj['classic_page_scan'],
ctx.obj['classic_inquiry_scan'],
).run() ).run()
asyncio.run(run_central()) asyncio.run(run_central())
@@ -2306,20 +1864,19 @@ def peripheral(ctx, transport):
scenario_factory, scenario_factory,
mode_factory, mode_factory,
ctx.obj['classic'], ctx.obj['classic'],
ctx.obj['iso'],
ctx.obj['extended_data_length'], ctx.obj['extended_data_length'],
ctx.obj['role_switch'], ctx.obj['role_switch'],
ctx.obj['le_scan'],
ctx.obj['le_advertise'],
ctx.obj['classic_page_scan'],
ctx.obj['classic_inquiry_scan'],
).run() ).run()
asyncio.run(run_peripheral()) asyncio.run(run_peripheral())
def main(): def main():
bumble.logging.setup_basic_logging('INFO') logging.basicConfig(
level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper(),
format="[%(asctime)s.%(msecs)03d] %(levelname)s:%(name)s:%(message)s",
datefmt="%H:%M:%S",
)
bench() bench()
-1
View File
@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import click import click
from bumble.colors import color from bumble.colors import color
from bumble.hci import Address from bumble.hci import Address
from bumble.helpers import generate_irk, verify_rpa_with_irk from bumble.helpers import generate_irk, verify_rpa_with_irk
+36 -28
View File
@@ -23,54 +23,58 @@ import asyncio
import logging import logging
import os import os
import re import re
import humanize
from typing import Optional, Union
from collections import OrderedDict from collections import OrderedDict
import click import click
import humanize
from prettytable import PrettyTable from prettytable import PrettyTable
from prompt_toolkit import Application from prompt_toolkit import Application
from prompt_toolkit.completion import Completer, Completion, NestedCompleter
from prompt_toolkit.data_structures import Point
from prompt_toolkit.filters import Condition
from prompt_toolkit.formatted_text import ANSI
from prompt_toolkit.history import FileHistory from prompt_toolkit.history import FileHistory
from prompt_toolkit.completion import Completer, Completion, NestedCompleter
from prompt_toolkit.key_binding import KeyBindings from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.formatted_text import ANSI
from prompt_toolkit.styles import Style
from prompt_toolkit.filters import Condition
from prompt_toolkit.widgets import TextArea, Frame
from prompt_toolkit.widgets.toolbars import FormattedTextToolbar
from prompt_toolkit.data_structures import Point
from prompt_toolkit.layout import ( from prompt_toolkit.layout import (
Layout,
HSplit,
Window,
CompletionsMenu, CompletionsMenu,
Float,
FormattedTextControl,
FloatContainer,
ConditionalContainer, ConditionalContainer,
Dimension, Dimension,
Float,
FloatContainer,
FormattedTextControl,
HSplit,
Layout,
Window,
) )
from prompt_toolkit.styles import Style
from prompt_toolkit.widgets import Frame, TextArea
from prompt_toolkit.widgets.toolbars import FormattedTextToolbar
from bumble import __version__
import bumble.core import bumble.core
from bumble import __version__, colors from bumble import colors
from bumble.core import UUID, AdvertisingData from bumble.core import UUID, AdvertisingData, PhysicalTransport
from bumble.device import ( from bumble.device import (
Connection,
ConnectionParametersPreferences, ConnectionParametersPreferences,
ConnectionPHY, ConnectionPHY,
Device, Device,
Connection,
Peer, Peer,
) )
from bumble.gatt import Characteristic, CharacteristicDeclaration, Descriptor, Service from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link
from bumble.gatt import Characteristic, Service, CharacteristicDeclaration, Descriptor
from bumble.gatt_client import CharacteristicProxy from bumble.gatt_client import CharacteristicProxy
from bumble.hci import ( from bumble.hci import (
Address,
HCI_Constant,
HCI_LE_1M_PHY, HCI_LE_1M_PHY,
HCI_LE_2M_PHY, HCI_LE_2M_PHY,
HCI_LE_CODED_PHY, HCI_LE_CODED_PHY,
Address,
HCI_Constant,
) )
from bumble.transport import open_transport
from bumble.utils import AsyncRunner
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
@@ -125,8 +129,8 @@ def parse_phys(phys):
# Console App # Console App
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ConsoleApp: class ConsoleApp:
connected_peer: Peer | None connected_peer: Optional[Peer]
connection_phy: ConnectionPHY | None connection_phy: Optional[ConnectionPHY]
def __init__(self): def __init__(self):
self.known_addresses = set() self.known_addresses = set()
@@ -287,7 +291,7 @@ class ConsoleApp:
async def run_async(self, device_config, transport): async def run_async(self, device_config, transport):
rssi_monitoring_task = asyncio.create_task(self.rssi_monitor_loop()) rssi_monitoring_task = asyncio.create_task(self.rssi_monitor_loop())
async with await open_transport(transport) as (hci_source, hci_sink): async with await open_transport_or_link(transport) as (hci_source, hci_sink):
if device_config: if device_config:
self.device = Device.from_config_file_with_hci( self.device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink device_config, hci_source, hci_sink
@@ -519,7 +523,7 @@ class ConsoleApp:
self.show_attributes(attributes) self.show_attributes(attributes)
def find_remote_characteristic(self, param) -> CharacteristicProxy | None: def find_remote_characteristic(self, param) -> Optional[CharacteristicProxy]:
if not self.connected_peer: if not self.connected_peer:
return None return None
parts = param.split('.') parts = param.split('.')
@@ -541,7 +545,9 @@ class ConsoleApp:
return None return None
def find_local_attribute(self, param) -> Characteristic | Descriptor | None: def find_local_attribute(
self, param
) -> Optional[Union[Characteristic, Descriptor]]:
parts = param.split('.') parts = param.split('.')
if len(parts) == 3: if len(parts) == 3:
service_uuid = UUID(parts[0]) service_uuid = UUID(parts[0])
@@ -1093,7 +1099,9 @@ class DeviceListener(Device.Listener, Connection.Listener):
if self.app.connected_peer.connection.is_encrypted if self.app.connected_peer.connection.is_encrypted
else 'not encrypted' else 'not encrypted'
) )
self.app.append_to_output(f'connection encryption change: {encryption_state}') self.app.append_to_output(
'connection encryption change: ' f'{encryption_state}'
)
def on_connection_data_length_change(self): def on_connection_data_length_change(self):
self.app.append_to_output( self.app.append_to_output(
+147 -172
View File
@@ -16,120 +16,130 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import os
import logging
import time import time
import click import click
import bumble.logging
from bumble.colors import color
from bumble.company_ids import COMPANY_IDENTIFIERS from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.colors import color
from bumble.core import name_or_number from bumble.core import name_or_number
from bumble.hci import ( from bumble.hci import (
HCI_LE_READ_BUFFER_SIZE_COMMAND, map_null_terminated_utf8_string,
HCI_LE_READ_BUFFER_SIZE_V2_COMMAND, CodecID,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND, LeFeature,
HCI_LE_READ_MINIMUM_SUPPORTED_CONNECTION_INTERVAL_COMMAND, HCI_SUCCESS,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND, HCI_VERSION_NAMES,
HCI_READ_BD_ADDR_COMMAND, LMP_VERSION_NAMES,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_Command, HCI_Command,
HCI_LE_Read_Buffer_Size_Command, HCI_Command_Complete_Event,
HCI_LE_Read_Buffer_Size_V2_Command, HCI_Command_Status_Event,
HCI_LE_Read_Maximum_Data_Length_Command, HCI_READ_BUFFER_SIZE_COMMAND,
HCI_LE_Read_Minimum_Supported_Connection_Interval_Command,
HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_Read_BD_ADDR_Command,
HCI_Read_Buffer_Size_Command, HCI_Read_Buffer_Size_Command,
HCI_LE_READ_BUFFER_SIZE_V2_COMMAND,
HCI_LE_Read_Buffer_Size_V2_Command,
HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_Read_Local_Name_Command, HCI_Read_Local_Name_Command,
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_Read_Local_Supported_Codecs_Command, HCI_Read_Local_Supported_Codecs_Command,
HCI_Read_Local_Supported_Codecs_V2_Command, HCI_Read_Local_Supported_Codecs_V2_Command,
HCI_Read_Local_Version_Information_Command, HCI_Read_Local_Version_Information_Command,
HCI_Read_Voice_Setting_Command,
LeFeature,
SpecificationVersion,
VoiceSetting,
map_null_terminated_utf8_string,
) )
from bumble.host import Host from bumble.host import Host
from bumble.transport import open_transport from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
def command_succeeded(response):
if isinstance(response, HCI_Command_Status_Event):
return response.status == HCI_SUCCESS
if isinstance(response, HCI_Command_Complete_Event):
return response.return_parameters.status == HCI_SUCCESS
return False
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_classic_info(host: Host) -> None: async def get_classic_info(host: Host) -> None:
if host.supports_command(HCI_READ_BD_ADDR_COMMAND): if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response1 = await host.send_sync_command(HCI_Read_BD_ADDR_Command()) response = await host.send_command(HCI_Read_BD_ADDR_Command())
print() if command_succeeded(response):
print( print()
color('Public Address:', 'yellow'), print(
response1.bd_addr.to_string(False), color('Public Address:', 'yellow'),
) response.return_parameters.bd_addr.to_string(False),
)
if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND): if host.supports_command(HCI_READ_LOCAL_NAME_COMMAND):
response2 = await host.send_sync_command(HCI_Read_Local_Name_Command()) response = await host.send_command(HCI_Read_Local_Name_Command())
print() if command_succeeded(response):
print( print()
color('Local Name:', 'yellow'), print(
map_null_terminated_utf8_string(response2.local_name), color('Local Name:', 'yellow'),
) map_null_terminated_utf8_string(response.return_parameters.local_name),
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_le_info(host: Host) -> None: async def get_le_info(host: Host) -> None:
print() print()
print( if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
color('LE Number Of Supported Advertising Sets:', 'yellow'), response = await host.send_command(
host.number_of_supported_advertising_sets, HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command()
'\n', )
) if command_succeeded(response):
print(
color('LE Number Of Supported Advertising Sets:', 'yellow'),
response.return_parameters.num_supported_advertising_sets,
'\n',
)
print( if host.supports_command(HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND):
color('LE Maximum Advertising Data Length:', 'yellow'), response = await host.send_command(
host.maximum_advertising_data_length, HCI_LE_Read_Maximum_Advertising_Data_Length_Command()
'\n', )
) if command_succeeded(response):
print(
color('LE Maximum Advertising Data Length:', 'yellow'),
response.return_parameters.max_advertising_data_length,
'\n',
)
if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND): if host.supports_command(HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND):
response1 = await host.send_sync_command( response = await host.send_command(HCI_LE_Read_Maximum_Data_Length_Command())
HCI_LE_Read_Maximum_Data_Length_Command() if command_succeeded(response):
) print(
print( color('Maximum Data Length:', 'yellow'),
color('LE Maximum Data Length:', 'yellow'), (
( f'tx:{response.return_parameters.supported_max_tx_octets}/'
f'tx:{response1.supported_max_tx_octets}/' f'{response.return_parameters.supported_max_tx_time}, '
f'{response1.supported_max_tx_time}, ' f'rx:{response.return_parameters.supported_max_rx_octets}/'
f'rx:{response1.supported_max_rx_octets}/' f'{response.return_parameters.supported_max_rx_time}'
f'{response1.supported_max_rx_time}' ),
), '\n',
) )
if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND): if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response2 = await host.send_sync_command( response = await host.send_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command() HCI_LE_Read_Suggested_Default_Data_Length_Command()
) )
print( if command_succeeded(response):
color('LE Suggested Default Data Length:', 'yellow'),
f'{response2.suggested_max_tx_octets}/'
f'{response2.suggested_max_tx_time}',
'\n',
)
if host.supports_command(HCI_LE_READ_MINIMUM_SUPPORTED_CONNECTION_INTERVAL_COMMAND):
response3 = await host.send_sync_command(
HCI_LE_Read_Minimum_Supported_Connection_Interval_Command()
)
print(
color('LE Minimum Supported Connection Interval:', 'yellow'),
f'{response3.minimum_supported_connection_interval * 125} µs',
)
for group in range(len(response3.group_min)):
print( print(
f' Group {group}: ' color('Suggested Default Data Length:', 'yellow'),
f'{response3.group_min[group] * 125} µs to ' f'{response.return_parameters.suggested_max_tx_octets}/'
f'{response3.group_max[group] * 125} µs ' f'{response.return_parameters.suggested_max_tx_time}',
'by increments of '
f'{response3.group_stride[group] * 125} µs',
'\n', '\n',
) )
@@ -143,31 +153,37 @@ async def get_flow_control_info(host: Host) -> None:
print() print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND): if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response1 = await host.send_sync_command(HCI_Read_Buffer_Size_Command()) response = await host.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
print( print(
color('ACL Flow Control:', 'yellow'), color('ACL Flow Control:', 'yellow'),
f'{response1.hc_total_num_acl_data_packets} ' f'{response.return_parameters.hc_total_num_acl_data_packets} '
f'packets of size {response1.hc_acl_data_packet_length}', f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
) )
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND): if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
response2 = await host.send_sync_command(HCI_LE_Read_Buffer_Size_V2_Command()) response = await host.send_command(
HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
)
print( print(
color('LE ACL Flow Control:', 'yellow'), color('LE ACL Flow Control:', 'yellow'),
f'{response2.total_num_le_acl_data_packets} ' f'{response.return_parameters.total_num_le_acl_data_packets} '
f'packets of size {response2.le_acl_data_packet_length}', f'packets of size {response.return_parameters.le_acl_data_packet_length}',
) )
print( print(
color('LE ISO Flow Control:', 'yellow'), color('LE ISO Flow Control:', 'yellow'),
f'{response2.total_num_iso_data_packets} ' f'{response.return_parameters.total_num_iso_data_packets} '
f'packets of size {response2.iso_data_packet_length}', f'packets of size {response.return_parameters.iso_data_packet_length}',
) )
elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND): elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response3 = await host.send_sync_command(HCI_LE_Read_Buffer_Size_Command()) response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
print( print(
color('LE ACL Flow Control:', 'yellow'), color('LE ACL Flow Control:', 'yellow'),
f'{response3.total_num_le_acl_data_packets} ' f'{response.return_parameters.total_num_le_acl_data_packets} '
f'packets of size {response3.le_acl_data_packet_length}', f'packets of size {response.return_parameters.le_acl_data_packet_length}',
) )
@@ -176,95 +192,78 @@ async def get_codecs_info(host: Host) -> None:
print() print()
if host.supports_command(HCI_Read_Local_Supported_Codecs_V2_Command.op_code): if host.supports_command(HCI_Read_Local_Supported_Codecs_V2_Command.op_code):
response1 = await host.send_sync_command( response = await host.send_command(
HCI_Read_Local_Supported_Codecs_V2_Command() HCI_Read_Local_Supported_Codecs_V2_Command(), check_result=True
) )
print(color('Codecs:', 'yellow')) print(color('Codecs:', 'yellow'))
for codec_id, transport in zip( for codec_id, transport in zip(
response1.standard_codec_ids, response.return_parameters.standard_codec_ids,
response1.standard_codec_transports, response.return_parameters.standard_codec_transports,
): ):
print(f' {codec_id.name} - {transport.name}') transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
transport
).name
codec_name = CodecID(codec_id).name
print(f' {codec_name} - {transport_name}')
for vendor_codec_id, vendor_transport in zip( for codec_id, transport in zip(
response1.vendor_specific_codec_ids, response.return_parameters.vendor_specific_codec_ids,
response1.vendor_specific_codec_transports, response.return_parameters.vendor_specific_codec_transports,
): ):
company = name_or_number(COMPANY_IDENTIFIERS, vendor_codec_id >> 16) transport_name = HCI_Read_Local_Supported_Codecs_V2_Command.Transport(
print(f' {company} / {vendor_codec_id & 0xFFFF} - {vendor_transport.name}') transport
).name
company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
print(f' {company} / {codec_id & 0xFFFF} - {transport_name}')
if not response1.standard_codec_ids: if not response.return_parameters.standard_codec_ids:
print(' No standard codecs') print(' No standard codecs')
if not response1.vendor_specific_codec_ids: if not response.return_parameters.vendor_specific_codec_ids:
print(' No Vendor-specific codecs') print(' No Vendor-specific codecs')
if host.supports_command(HCI_Read_Local_Supported_Codecs_Command.op_code): if host.supports_command(HCI_Read_Local_Supported_Codecs_Command.op_code):
response2 = await host.send_sync_command( response = await host.send_command(
HCI_Read_Local_Supported_Codecs_Command() HCI_Read_Local_Supported_Codecs_Command(), check_result=True
) )
print(color('Codecs (BR/EDR):', 'yellow')) print(color('Codecs (BR/EDR):', 'yellow'))
for codec_id in response2.standard_codec_ids: for codec_id in response.return_parameters.standard_codec_ids:
print(f' {codec_id.name}') codec_name = CodecID(codec_id).name
print(f' {codec_name}')
for vendor_codec_id in response2.vendor_specific_codec_ids: for codec_id in response.return_parameters.vendor_specific_codec_ids:
company = name_or_number(COMPANY_IDENTIFIERS, vendor_codec_id >> 16) company = name_or_number(COMPANY_IDENTIFIERS, codec_id >> 16)
print(f' {company} / {vendor_codec_id & 0xFFFF}') print(f' {company} / {codec_id & 0xFFFF}')
if not response2.standard_codec_ids: if not response.return_parameters.standard_codec_ids:
print(' No standard codecs') print(' No standard codecs')
if not response2.vendor_specific_codec_ids: if not response.return_parameters.vendor_specific_codec_ids:
print(' No Vendor-specific codecs') print(' No Vendor-specific codecs')
if host.supports_command(HCI_Read_Voice_Setting_Command.op_code):
response3 = await host.send_sync_command(HCI_Read_Voice_Setting_Command())
voice_setting = VoiceSetting.from_int(response3.voice_setting)
print(color('Voice Setting:', 'yellow'))
print(f' Air Coding Format: {voice_setting.air_coding_format.name}')
print(f' Linear PCM Bit Position: {voice_setting.linear_pcm_bit_position}')
print(f' Input Sample Size: {voice_setting.input_sample_size.name}')
print(f' Input Data Format: {voice_setting.input_data_format.name}')
print(f' Input Coding Format: {voice_setting.input_coding_format.name}')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_main( async def async_main(latency_probes, transport):
latency_probes, latency_probe_interval, latency_probe_command, transport
):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport(transport) as (hci_source, hci_sink): async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
host = Host(hci_source, hci_sink) host = Host(hci_source, hci_sink)
await host.reset() await host.reset()
# Measure the latency if requested # Measure the latency if requested
# (we add an extra probe at the start, that we ignore, just to ensure that
# the transport is primed)
latencies = [] latencies = []
if latency_probes: if latency_probes:
if latency_probe_command: for _ in range(latency_probes):
probe_hci_command = HCI_Command.from_bytes(
bytes.fromhex(latency_probe_command)
)
else:
probe_hci_command = HCI_Read_Local_Version_Information_Command()
for iteration in range(1 + latency_probes):
if latency_probe_interval:
await asyncio.sleep(latency_probe_interval / 1000)
start = time.time() start = time.time()
await host.send_command(probe_hci_command) await host.send_command(HCI_Read_Local_Version_Information_Command())
if iteration: latencies.append(1000 * (time.time() - start))
latencies.append(1000 * (time.time() - start))
print( print(
color('HCI Command Latency:', 'yellow'), color('HCI Command Latency:', 'yellow'),
( (
f'min={min(latencies):.2f}, ' f'min={min(latencies):.2f}, '
f'max={max(latencies):.2f}, ' f'max={max(latencies):.2f}, '
f'average={sum(latencies) / len(latencies):.2f},' f'average={sum(latencies)/len(latencies):.2f}'
), ),
[f'{latency:.4}' for latency in latencies],
'\n', '\n',
) )
@@ -276,20 +275,14 @@ async def async_main(
) )
print( print(
color(' HCI Version: ', 'green'), color(' HCI Version: ', 'green'),
SpecificationVersion(host.local_version.hci_version).name, name_or_number(HCI_VERSION_NAMES, host.local_version.hci_version),
)
print(
color(' HCI Subversion:', 'green'),
f'0x{host.local_version.hci_subversion:04x}',
) )
print(color(' HCI Subversion:', 'green'), host.local_version.hci_subversion)
print( print(
color(' LMP Version: ', 'green'), color(' LMP Version: ', 'green'),
SpecificationVersion(host.local_version.lmp_version).name, name_or_number(LMP_VERSION_NAMES, host.local_version.lmp_version),
)
print(
color(' LMP Subversion:', 'green'),
f'0x{host.local_version.lmp_subversion:04x}',
) )
print(color(' LMP Subversion:', 'green'), host.local_version.lmp_subversion)
# Get the Classic info # Get the Classic info
await get_classic_info(host) await get_classic_info(host)
@@ -318,28 +311,10 @@ async def async_main(
type=int, type=int,
help='Send N commands to measure HCI transport latency statistics', help='Send N commands to measure HCI transport latency statistics',
) )
@click.option(
'--latency-probe-interval',
metavar='INTERVAL',
type=int,
help='Interval between latency probes (milliseconds)',
)
@click.option(
'--latency-probe-command',
metavar='COMMAND_HEX',
help=(
'Probe command (HCI Command packet bytes, in hex. Use 0177FC00 for'
' a loopback test with the HCI remote proxy app)'
),
)
@click.argument('transport') @click.argument('transport')
def main(latency_probes, latency_probe_interval, latency_probe_command, transport): def main(latency_probes, transport):
bumble.logging.setup_basic_logging() logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run( asyncio.run(async_main(latency_probes, transport))
async_main(
latency_probes, latency_probe_interval, latency_probe_command, transport
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+59 -170
View File
@@ -16,149 +16,79 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import statistics import logging
import struct import os
import time import time
from typing import Optional
import click
import bumble.logging
from bumble.colors import color from bumble.colors import color
from bumble.hci import ( from bumble.hci import (
HCI_READ_LOOPBACK_MODE_COMMAND, HCI_READ_LOOPBACK_MODE_COMMAND,
HCI_WRITE_LOOPBACK_MODE_COMMAND,
Address,
HCI_Read_Loopback_Mode_Command, HCI_Read_Loopback_Mode_Command,
HCI_SynchronousDataPacket, HCI_WRITE_LOOPBACK_MODE_COMMAND,
HCI_Write_Loopback_Mode_Command, HCI_Write_Loopback_Mode_Command,
LoopbackMode, LoopbackMode,
) )
from bumble.host import Host from bumble.host import Host
from bumble.transport import open_transport from bumble.transport import open_transport_or_link
import click
class Loopback: class Loopback:
"""Send and receive ACL data packets in local loopback mode""" """Send and receive ACL data packets in local loopback mode"""
def __init__( def __init__(self, packet_size: int, packet_count: int, transport: str):
self,
packet_size: int,
packet_count: int,
connection_type: str,
mode: str,
interval: int,
transport: str,
):
self.transport = transport self.transport = transport
self.packet_size = packet_size self.packet_size = packet_size
self.packet_count = packet_count self.packet_count = packet_count
self.connection_handle: int | None = None self.connection_handle: Optional[int] = None
self.connection_type = connection_type
self.connection_event = asyncio.Event() self.connection_event = asyncio.Event()
self.mode = mode
self.interval = interval
self.done = asyncio.Event() self.done = asyncio.Event()
self.expected_counter = 0 self.expected_cid = 0
self.bytes_received = 0 self.bytes_received = 0
self.start_timestamp = 0.0 self.start_timestamp = 0.0
self.last_timestamp = 0.0 self.last_timestamp = 0.0
self.send_timestamps: list[float] = []
self.rtts: list[float] = []
def on_connection(self, connection_handle: int, *args): def on_connection(self, connection_handle: int, *args):
"""Retrieve connection handle from new connection event""" """Retrieve connection handle from new connection event"""
if not self.connection_event.is_set(): if not self.connection_event.is_set():
# The first connection handle is of type ACL, # save first connection handle for ACL
# subsequent connections are of type SCO # subsequent connections are SCO
if self.connection_type == "sco" and self.connection_handle is None:
self.connection_handle = connection_handle
return
self.connection_handle = connection_handle self.connection_handle = connection_handle
self.connection_event.set() self.connection_event.set()
def on_sco_connection( def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
self,
address: Address,
connection_handle: int,
link_type,
rx_packet_length: int,
tx_packet_length: int,
air_mode: int,
) -> None:
self.on_connection(connection_handle)
def on_packet(self, connection_handle: int, packet: bytes):
"""Calculate packet receive speed""" """Calculate packet receive speed"""
now = time.time() now = time.time()
(counter,) = struct.unpack_from("H", packet, 0) print(f'<<< Received packet {cid}: {len(pdu)} bytes')
rtt = now - self.send_timestamps[counter]
self.rtts.append(rtt)
print(f'<<< Received packet {counter}: {len(packet)} bytes, RTT={rtt:.4f}')
assert connection_handle == self.connection_handle assert connection_handle == self.connection_handle
assert counter == self.expected_counter assert cid == self.expected_cid
self.expected_counter += 1 self.expected_cid += 1
if counter == 0: if cid == 0:
self.start_timestamp = now self.start_timestamp = now
else: else:
elapsed_since_start = now - self.start_timestamp elapsed_since_start = now - self.start_timestamp
elapsed_since_last = now - self.last_timestamp elapsed_since_last = now - self.last_timestamp
self.bytes_received += len(packet) self.bytes_received += len(pdu)
instant_rx_speed = len(packet) / elapsed_since_last instant_rx_speed = len(pdu) / elapsed_since_last
average_rx_speed = self.bytes_received / elapsed_since_start average_rx_speed = self.bytes_received / elapsed_since_start
if self.mode == 'throughput': print(
print( color(
color( f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f'@@@ RX speed: instant={instant_rx_speed:.4f},' f' average={average_rx_speed:.4f}',
f' average={average_rx_speed:.4f},', 'cyan',
'cyan',
)
) )
)
self.last_timestamp = now self.last_timestamp = now
if self.expected_counter == self.packet_count: if self.expected_cid == self.packet_count:
print(color('@@@ Received last packet', 'green')) print(color('@@@ Received last packet', 'green'))
self.done.set() self.done.set()
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes): async def run(self):
self.on_packet(connection_handle, pdu)
def on_sco_packet(self, connection_handle: int, packet) -> None:
self.on_packet(connection_handle, packet)
async def send_acl_packet(self, host: Host, packet: bytes) -> None:
assert self.connection_handle
host.send_l2cap_pdu(self.connection_handle, 0, packet)
async def send_sco_packet(self, host: Host, packet: bytes) -> None:
assert self.connection_handle
host.send_hci_packet(
HCI_SynchronousDataPacket(
connection_handle=self.connection_handle,
packet_status=HCI_SynchronousDataPacket.Status.CORRECTLY_RECEIVED_DATA,
data_total_length=len(packet),
data=packet,
)
)
async def send_loop(self, host: Host, sender) -> None:
for counter in range(0, self.packet_count):
print(
color(
f'>>> Sending {self.connection_type.upper()} '
f'packet {counter}: {self.packet_size} bytes',
'yellow',
)
)
self.send_timestamps.append(time.time())
await sender(host, struct.pack("H", counter) + bytes(self.packet_size - 2))
await asyncio.sleep(self.interval / 1000 if self.mode == "rtt" else 0)
async def run(self) -> None:
"""Run a loopback throughput test""" """Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green')) print(color('>>> Connecting to HCI...', 'green'))
async with await open_transport(self.transport) as ( async with await open_transport_or_link(self.transport) as (
hci_source, hci_source,
hci_sink, hci_sink,
): ):
@@ -170,15 +100,11 @@ class Loopback:
# make sure data can fit in one l2cap pdu # make sure data can fit in one l2cap pdu
l2cap_header_size = 4 l2cap_header_size = 4
packet_queue = ( max_packet_size = (
host.acl_packet_queue host.acl_packet_queue
if host.acl_packet_queue if host.acl_packet_queue
else host.le_acl_packet_queue else host.le_acl_packet_queue
) ).max_packet_size - l2cap_header_size
if packet_queue is None:
print(color('!!! No packet queue', 'red'))
return
max_packet_size = packet_queue.max_packet_size - l2cap_header_size
if self.packet_size > max_packet_size: if self.packet_size > max_packet_size:
print( print(
color( color(
@@ -196,62 +122,56 @@ class Loopback:
return return
# set event callbacks # set event callbacks
host.on('classic_connection', self.on_connection) host.on('connection', self.on_connection)
host.on('le_connection', self.on_connection)
host.on('sco_connection', self.on_sco_connection)
host.on('l2cap_pdu', self.on_l2cap_pdu) host.on('l2cap_pdu', self.on_l2cap_pdu)
host.on('sco_packet', self.on_sco_packet)
loopback_mode = LoopbackMode.LOCAL loopback_mode = LoopbackMode.LOCAL
print(color('### Setting loopback mode', 'blue')) print(color('### Setting loopback mode', 'blue'))
await host.send_sync_command( await host.send_command(
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL), HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
check_result=True,
) )
print(color('### Checking loopback mode', 'blue')) print(color('### Checking loopback mode', 'blue'))
response = await host.send_sync_command(HCI_Read_Loopback_Mode_Command()) response = await host.send_command(
if response.loopback_mode != loopback_mode: HCI_Read_Loopback_Mode_Command(), check_result=True
)
if response.return_parameters.loopback_mode != loopback_mode:
print(color('!!! Loopback mode mismatch', 'red')) print(color('!!! Loopback mode mismatch', 'red'))
return return
await self.connection_event.wait() await self.connection_event.wait()
assert self.connection_handle is not None
print(color('### Connected', 'cyan')) print(color('### Connected', 'cyan'))
print(color('=== Start sending', 'magenta')) print(color('=== Start sending', 'magenta'))
start_time = time.time() start_time = time.time()
if self.connection_type == "acl": bytes_sent = 0
sender = self.send_acl_packet for cid in range(0, self.packet_count):
elif self.connection_type == "sco": # using the cid as an incremental index
sender = self.send_sco_packet host.send_l2cap_pdu(
else: self.connection_handle, cid, bytes(self.packet_size)
raise ValueError(f'Unknown connection type: {self.connection_type}') )
await self.send_loop(host, sender) print(
color(
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
)
)
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
await asyncio.sleep(0) # yield to allow packet receive
await self.done.wait() await self.done.wait()
print(color('=== Done!', 'magenta')) print(color('=== Done!', 'magenta'))
bytes_sent = self.packet_size * self.packet_count
elapsed = time.time() - start_time elapsed = time.time() - start_time
average_tx_speed = bytes_sent / elapsed average_tx_speed = bytes_sent / elapsed
if self.mode == 'throughput': print(
print( color(
color( f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
f'@@@ TX speed: average={average_tx_speed:.4f} ' f' in {elapsed:.2f} seconds)',
f'({bytes_sent} bytes in {elapsed:.2f} seconds)', 'green',
'green',
)
)
if self.mode == 'rtt':
print(
color(
f'RTTs: min={min(self.rtts):.4f}, '
f'max={max(self.rtts):.4f}, '
f'avg={statistics.mean(self.rtts):.4f}',
'blue',
)
) )
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -272,43 +192,12 @@ class Loopback:
default=10, default=10,
help='Packet count', help='Packet count',
) )
@click.option(
'--connection-type',
'-t',
metavar='TYPE',
type=click.Choice(['acl', 'sco']),
default='acl',
help='Connection type',
)
@click.option(
'--mode',
'-m',
metavar='MODE',
type=click.Choice(['throughput', 'rtt']),
default='throughput',
help='Test mode',
)
@click.option(
'--interval',
type=int,
default=100,
help='Inter-packet interval (ms) [RTT mode only]',
)
@click.argument('transport') @click.argument('transport')
def main(packet_size, packet_count, connection_type, mode, interval, transport): def main(packet_size, packet_count, transport):
bumble.logging.setup_basic_logging() logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if connection_type == "sco" and packet_size > 255: loopback = Loopback(packet_size, packet_count, transport)
print("ERROR: the maximum packet size for SCO is 255") asyncio.run(loopback.run())
return
async def run():
loopback = Loopback(
packet_size, packet_count, connection_type, mode, interval, transport
)
await loopback.run()
asyncio.run(run())
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+5 -4
View File
@@ -15,13 +15,14 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging
import asyncio import asyncio
import sys import sys
import os
import bumble.logging
from bumble.controller import Controller from bumble.controller import Controller
from bumble.link import LocalLink from bumble.link import LocalLink
from bumble.transport import open_transport from bumble.transport import open_transport_or_link
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -41,7 +42,7 @@ async def async_main():
transports = [] transports = []
controllers = [] controllers = []
for index, transport_name in enumerate(sys.argv[1:]): for index, transport_name in enumerate(sys.argv[1:]):
transport = await open_transport(transport_name) transport = await open_transport_or_link(transport_name)
transports.append(transport) transports.append(transport)
controller = Controller( controller = Controller(
f'C{index}', f'C{index}',
@@ -61,7 +62,7 @@ async def async_main():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def main(): def main():
bumble.logging.setup_basic_logging() logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main()) asyncio.run(async_main())
+10 -8
View File
@@ -16,22 +16,23 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
from collections.abc import Callable, Iterable import os
import logging
from typing import Callable, Iterable, Optional
import click import click
import bumble.logging
from bumble.colors import color
from bumble.core import ProtocolError from bumble.core import ProtocolError
from bumble.colors import color
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.gatt import Service from bumble.gatt import Service
from bumble.profiles.battery_service import BatteryServiceProxy
from bumble.profiles.device_information_service import DeviceInformationServiceProxy from bumble.profiles.device_information_service import DeviceInformationServiceProxy
from bumble.profiles.battery_service import BatteryServiceProxy
from bumble.profiles.gap import GenericAccessServiceProxy from bumble.profiles.gap import GenericAccessServiceProxy
from bumble.profiles.pacs import PublishedAudioCapabilitiesServiceProxy from bumble.profiles.pacs import PublishedAudioCapabilitiesServiceProxy
from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy
from bumble.profiles.vcs import VolumeControlServiceProxy from bumble.profiles.vcs import VolumeControlServiceProxy
from bumble.transport import open_transport from bumble.transport import open_transport_or_link
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -174,7 +175,7 @@ async def show_vcs(vcs: VolumeControlServiceProxy) -> None:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def show_device_info(peer, done: asyncio.Future | None) -> None: async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
try: try:
# Discover all services # Discover all services
print(color('### Discovering Services and Characteristics', 'magenta')) print(color('### Discovering Services and Characteristics', 'magenta'))
@@ -214,7 +215,8 @@ async def show_device_info(peer, done: asyncio.Future | None) -> None:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_main(device_config, encrypt, transport, address_or_name): async def async_main(device_config, encrypt, transport, address_or_name):
async with await open_transport(transport) as (hci_source, hci_sink): async with await open_transport_or_link(transport) as (hci_source, hci_sink):
# Create a device # Create a device
if device_config: if device_config:
device = Device.from_config_file_with_hci( device = Device.from_config_file_with_hci(
@@ -265,7 +267,7 @@ def main(device_config, encrypt, transport, address_or_name):
Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified, Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified,
wait for an incoming connection. wait for an incoming connection.
""" """
bumble.logging.setup_basic_logging() logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main(device_config, encrypt, transport, address_or_name)) asyncio.run(async_main(device_config, encrypt, transport, address_or_name))
+6 -5
View File
@@ -16,15 +16,15 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import os
import logging
import click import click
import bumble.core import bumble.core
import bumble.logging
from bumble.colors import color from bumble.colors import color
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.gatt import show_services from bumble.gatt import show_services
from bumble.transport import open_transport from bumble.transport import open_transport_or_link
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -60,7 +60,8 @@ async def dump_gatt_db(peer, done):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_main(device_config, encrypt, transport, address_or_name): async def async_main(device_config, encrypt, transport, address_or_name):
async with await open_transport(transport) as (hci_source, hci_sink): async with await open_transport_or_link(transport) as (hci_source, hci_sink):
# Create a device # Create a device
if device_config: if device_config:
device = Device.from_config_file_with_hci( device = Device.from_config_file_with_hci(
@@ -111,7 +112,7 @@ def main(device_config, encrypt, transport, address_or_name):
Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified, Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified,
wait for an incoming connection. wait for an incoming connection.
""" """
bumble.logging.setup_basic_logging() logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main(device_config, encrypt, transport, address_or_name)) asyncio.run(async_main(device_config, encrypt, transport, address_or_name))
+10 -9
View File
@@ -16,19 +16,20 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import os
import struct import struct
import logging
import click import click
import bumble.logging
from bumble import l2cap from bumble import l2cap
from bumble.colors import color from bumble.colors import color
from bumble.core import AdvertisingData
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.gatt import Characteristic, CharacteristicValue, Service from bumble.core import AdvertisingData
from bumble.hci import HCI_Constant from bumble.gatt import Service, Characteristic, CharacteristicValue
from bumble.transport import open_transport
from bumble.utils import AsyncRunner from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link
from bumble.hci import HCI_Constant
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
@@ -324,7 +325,7 @@ async def run(
receive_port, receive_port,
): ):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport(hci_transport) as (hci_source, hci_sink): async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
# Instantiate a bridge object # Instantiate a bridge object
@@ -352,7 +353,7 @@ async def run(
await bridge.start() await bridge.start()
# Wait until the source terminates # Wait until the source terminates
await hci_source.terminated await hci_source.wait_for_termination()
@click.command() @click.command()
@@ -382,7 +383,6 @@ def main(
receive_host, receive_host,
receive_port, receive_port,
): ):
bumble.logging.setup_basic_logging('WARNING')
asyncio.run( asyncio.run(
run( run(
hci_transport, hci_transport,
@@ -397,5 +397,6 @@ def main(
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__': if __name__ == '__main__':
main() main()
+6 -9
View File
@@ -12,15 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
import asyncio
import os
import sys import sys
import bumble.logging
from bumble import hci, transport from bumble import hci, transport
from bumble.bridge import HCI_Bridge from bumble.bridge import HCI_Bridge
@@ -47,14 +46,14 @@ async def async_main():
return return
print('>>> connecting to HCI...') print('>>> connecting to HCI...')
async with await transport.open_transport(sys.argv[1]) as ( async with await transport.open_transport_or_link(sys.argv[1]) as (
hci_host_source, hci_host_source,
hci_host_sink, hci_host_sink,
): ):
print('>>> connected') print('>>> connected')
print('>>> connecting to HCI...') print('>>> connecting to HCI...')
async with await transport.open_transport(sys.argv[2]) as ( async with await transport.open_transport_or_link(sys.argv[2]) as (
hci_controller_source, hci_controller_source,
hci_controller_sink, hci_controller_sink,
): ):
@@ -81,9 +80,7 @@ async def async_main():
response = hci.HCI_Command_Complete_Event( response = hci.HCI_Command_Complete_Event(
num_hci_command_packets=1, num_hci_command_packets=1,
command_opcode=hci_packet.op_code, command_opcode=hci_packet.op_code,
return_parameters=hci.HCI_StatusReturnParameters( return_parameters=bytes([hci.HCI_SUCCESS]),
status=hci.HCI_ErrorCode.SUCCESS
),
) )
# Return a packet with 'respond to sender' set to True # Return a packet with 'respond to sender' set to True
return (bytes(response), True) return (bytes(response), True)
@@ -103,7 +100,7 @@ async def async_main():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def main(): def main():
bumble.logging.setup_basic_logging() logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main()) asyncio.run(async_main())
+7 -7
View File
@@ -16,16 +16,16 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import logging
import os
import click import click
import bumble.logging
from bumble import l2cap from bumble import l2cap
from bumble.colors import color from bumble.colors import color
from bumble.transport import open_transport_or_link
from bumble.device import Device from bumble.device import Device
from bumble.hci import HCI_Constant
from bumble.transport import open_transport
from bumble.utils import FlowControlAsyncPipe from bumble.utils import FlowControlAsyncPipe
from bumble.hci import HCI_Constant
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -258,7 +258,7 @@ class ClientBridge:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def run(device_config, hci_transport, bridge): async def run(device_config, hci_transport, bridge):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport(hci_transport) as (hci_source, hci_sink): async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink) device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
@@ -268,7 +268,7 @@ async def run(device_config, hci_transport, bridge):
await bridge.start(device) await bridge.start(device)
# Wait until the transport terminates # Wait until the transport terminates
await hci_source.terminated await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -356,6 +356,6 @@ def client(context, bluetooth_address, tcp_host, tcp_port):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__': if __name__ == '__main__':
bumble.logging.setup_basic_logging('WARNING')
cli(obj={}) # pylint: disable=no-value-for-parameter cli(obj={}) # pylint: disable=no-value-for-parameter
+22 -17
View File
@@ -20,30 +20,31 @@ from __future__ import annotations
import asyncio import asyncio
import datetime import datetime
import functools import functools
from importlib import resources
import json import json
import os
import logging import logging
import pathlib import pathlib
import wave
import weakref import weakref
from importlib import resources import wave
try: try:
import lc3 # type: ignore # pylint: disable=E0401 import lc3 # type: ignore # pylint: disable=E0401
except ImportError as e: except ImportError as e:
raise ImportError("Try `python -m pip install \".[lc3]\"`.") from e raise ImportError("Try `python -m pip install \".[lc3]\"`.") from e
import aiohttp.web
import click import click
import aiohttp.web
import bumble import bumble
import bumble.logging from bumble import utils
from bumble import data_types, utils
from bumble.colors import color
from bumble.core import AdvertisingData from bumble.core import AdvertisingData
from bumble.device import AdvertisingParameters, CisLink, Device, DeviceConfiguration from bumble.colors import color
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket from bumble.device import Device, DeviceConfiguration, AdvertisingParameters, CisLink
from bumble.profiles import ascs, bap, pacs
from bumble.transport import open_transport from bumble.transport import open_transport
from bumble.profiles import ascs, bap, pacs
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -268,6 +269,7 @@ class UiServer:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Speaker: class Speaker:
def __init__( def __init__(
self, self,
device_config_path: str | None, device_config_path: str | None,
@@ -298,7 +300,6 @@ class Speaker:
advertising_interval_max=25, advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'), address=Address('F1:F2:F3:F4:F5:F6'),
identity_address_type=Address.RANDOM_DEVICE_ADDRESS, identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
eatt_enabled=True,
) )
device_config.le_enabled = True device_config.le_enabled = True
@@ -330,13 +331,17 @@ class Speaker:
advertising_data = bytes( advertising_data = bytes(
AdvertisingData( AdvertisingData(
[ [
data_types.CompleteLocalName(device_config.name), (
data_types.Flags( AdvertisingData.COMPLETE_LOCAL_NAME,
AdvertisingData.Flags.LE_GENERAL_DISCOVERABLE_MODE bytes(device_config.name, 'utf-8'),
| AdvertisingData.Flags.BR_EDR_NOT_SUPPORTED
), ),
data_types.IncompleteListOf16BitServiceUUIDs( (
[pacs.PublishedAudioCapabilitiesService.UUID] AdvertisingData.FLAGS,
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(pacs.PublishedAudioCapabilitiesService.UUID),
), ),
] ]
) )
@@ -444,7 +449,7 @@ def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) ->
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def main(): def main():
bumble.logging.setup_basic_logging() logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
speaker() speaker()
View File
+289
View File
@@ -0,0 +1,289 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------
# Imports
# ----------------------------------------------------------------------------
import sys
import logging
import json
import asyncio
import argparse
import uuid
import os
from urllib.parse import urlparse
import websockets
from bumble.colors import color
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------
# Constants
# ----------------------------------------------------------------------------
DEFAULT_RELAY_PORT = 10723
# ----------------------------------------------------------------------------
# Utils
# ----------------------------------------------------------------------------
def error_to_json(error):
return json.dumps({'error': error})
def error_to_result(error):
return f'result:{error_to_json(error)}'
async def broadcast_message(message, connections):
# Send to all the connections
tasks = [connection.send_message(message) for connection in connections]
if tasks:
await asyncio.gather(*tasks)
# ----------------------------------------------------------------------------
# Connection class
# ----------------------------------------------------------------------------
class Connection:
"""
A Connection represents a client connected to the relay over a websocket
"""
def __init__(self, room, websocket):
self.room = room
self.websocket = websocket
self.address = str(uuid.uuid4())
async def send_message(self, message):
try:
logger.debug(color(f'->{self.address}: {message}', 'yellow'))
return await self.websocket.send(message)
except websockets.exceptions.WebSocketException as error:
logger.info(f'! client "{self}" disconnected: {error}')
await self.cleanup()
async def send_error(self, error):
return await self.send_message(f'result:{error_to_json(error)}')
async def receive_message(self):
try:
message = await self.websocket.recv()
logger.debug(color(f'<-{self.address}: {message}', 'blue'))
return message
except websockets.exceptions.WebSocketException as error:
logger.info(color(f'! client "{self}" disconnected: {error}', 'red'))
await self.cleanup()
async def cleanup(self):
if self.room:
await self.room.remove_connection(self)
def set_address(self, address):
logger.info(f'Connection address changed: {self.address} -> {address}')
self.address = address
def __str__(self):
return (
f'Connection(address="{self.address}", '
f'client={self.websocket.remote_address[0]}:'
f'{self.websocket.remote_address[1]})'
)
# ----------------------------------------------------------------------------
# Room class
# ----------------------------------------------------------------------------
class Room:
"""
A Room is a collection of bridged connections
"""
def __init__(self, relay, name):
self.relay = relay
self.name = name
self.observers = []
self.connections = []
async def add_connection(self, connection):
logger.info(f'New participant in {self.name}: {connection}')
self.connections.append(connection)
await self.broadcast_message(connection, f'joined:{connection.address}')
async def remove_connection(self, connection):
if connection in self.connections:
self.connections.remove(connection)
await self.broadcast_message(connection, f'left:{connection.address}')
def find_connections_by_address(self, address):
return [c for c in self.connections if c.address == address]
async def bridge_connection(self, connection):
while True:
# Wait for a message
message = await connection.receive_message()
# Skip empty messages
if message is None:
return
# Parse the message to decide how to handle it
if message.startswith('@'):
# This is a targeted message
await self.on_targeted_message(connection, message)
elif message.startswith('/'):
# This is an RPC request
await self.on_rpc_request(connection, message)
else:
await connection.send_message(
f'result:{error_to_json("error: invalid message")}'
)
async def broadcast_message(self, sender, message):
'''
Send to all connections in the room except back to the sender
'''
await broadcast_message(message, [c for c in self.connections if c != sender])
async def on_rpc_request(self, connection, message):
command, *params = message.split(' ', 1)
if handler := getattr(
self, f'on_{command[1:].lower().replace("-","_")}_command', None
):
try:
result = await handler(connection, params)
except Exception as error:
result = error_to_result(error)
else:
result = error_to_result('unknown command')
await connection.send_message(result or 'result:{}')
async def on_targeted_message(self, connection, message):
target, *payload = message.split(' ', 1)
if not payload:
return error_to_json('missing arguments')
payload = payload[0]
target = target[1:]
# Determine what targets to send to
if target == '*':
# Send to all connections in the room except the connection from which the
# message was received
connections = [c for c in self.connections if c != connection]
else:
connections = self.find_connections_by_address(target)
if not connections:
# Unicast with no recipient, let the sender know
await connection.send_message(f'unreachable:{target}')
# Send to targets
await broadcast_message(f'message:{connection.address}/{payload}', connections)
async def on_set_address_command(self, connection, params):
if not params:
return error_to_result('missing address')
current_address = connection.address
new_address = params[0]
connection.set_address(new_address)
await self.broadcast_message(
connection, f'address-changed:from={current_address},to={new_address}'
)
# ----------------------------------------------------------------------------
class Relay:
"""
A relay accepts connections with the following url: ws://<hostname>/<room>.
Participants in a room can communicate with each other
"""
def __init__(self, port):
self.port = port
self.rooms = {}
self.observers = []
def start(self):
logger.info(f'Starting Relay on port {self.port}')
# pylint: disable-next=no-member
return websockets.serve(self.serve, '0.0.0.0', self.port, ping_interval=None)
async def serve_as_controller(self, connection):
pass
async def serve(self, websocket, path):
logger.debug(f'New connection with path {path}')
# Parse the path
parsed = urlparse(path)
# Check if this is a controller client
if parsed.path == '/':
return await self.serve_as_controller(Connection('', websocket))
# Find or create a room for this connection
room_name = parsed.path[1:].split('/')[0]
if room_name not in self.rooms:
self.rooms[room_name] = Room(self, room_name)
room = self.rooms[room_name]
# Add the connection to the room
connection = Connection(room, websocket)
await room.add_connection(connection)
# Bridge until the connection is closed
await room.bridge_connection(connection)
# ----------------------------------------------------------------------------
def main():
# Check the Python version
if sys.version_info < (3, 6, 1):
print('ERROR: Python 3.6.1 or higher is required')
sys.exit(1)
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
# Parse arguments
arg_parser = argparse.ArgumentParser(description='Bumble Link Relay')
arg_parser.add_argument('--log-level', default='INFO', help='logger level')
arg_parser.add_argument('--log-config', help='logger config file (YAML)')
arg_parser.add_argument(
'--port', type=int, default=DEFAULT_RELAY_PORT, help='Port to listen on'
)
args = arg_parser.parse_args()
# Setup logger
if args.log_config:
from logging import config # pylint: disable=import-outside-toplevel
config.fileConfig(args.log_config)
else:
logging.basicConfig(level=getattr(logging, args.log_level.upper()))
# Start a relay
relay = Relay(args.port)
asyncio.get_event_loop().run_until_complete(relay.start())
asyncio.get_event_loop().run_forever()
# ----------------------------------------------------------------------------
if __name__ == '__main__':
main()
+21
View File
@@ -0,0 +1,21 @@
[loggers]
keys=root
[handlers]
keys=stream_handler
[formatters]
keys=formatter
[logger_root]
level=DEBUG
handlers=stream_handler
[handler_stream_handler]
class=StreamHandler
level=DEBUG
formatter=formatter
args=(sys.stderr,)
[formatter_formatter]
format=%(asctime)s %(name)-12s %(levelname)-8s %(message)s
+100 -99
View File
@@ -15,46 +15,43 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio import asyncio
import logging
import os import os
from typing import ClassVar import logging
import struct
import click import click
from prompt_toolkit.shortcuts import PromptSession from prompt_toolkit.shortcuts import PromptSession
from bumble import data_types, smp
from bumble.a2dp import make_audio_sink_service_sdp_records from bumble.a2dp import make_audio_sink_service_sdp_records
from bumble.att import (
ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
ATT_INSUFFICIENT_ENCRYPTION_ERROR,
ATT_Error,
)
from bumble.colors import color from bumble.colors import color
from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link
from bumble.pairing import OobData, PairingDelegate, PairingConfig
from bumble.smp import OobContext, OobLegacyContext
from bumble.smp import error_name as smp_error_name
from bumble.keys import JsonKeyStore
from bumble.core import ( from bumble.core import (
UUID,
AdvertisingData, AdvertisingData,
Appearance, Appearance,
DataType,
PhysicalTransport,
ProtocolError, ProtocolError,
PhysicalTransport,
UUID,
) )
from bumble.device import Connection, Device, Peer
from bumble.gatt import ( from bumble.gatt import (
GATT_DEVICE_NAME_CHARACTERISTIC, GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_GENERIC_ACCESS_SERVICE, GATT_GENERIC_ACCESS_SERVICE,
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
GATT_HEART_RATE_SERVICE, GATT_HEART_RATE_SERVICE,
Characteristic, GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
Service, Service,
Characteristic,
) )
from bumble.hci import OwnAddressType from bumble.hci import OwnAddressType
from bumble.keys import JsonKeyStore from bumble.att import (
from bumble.pairing import OobData, PairingConfig, PairingDelegate ATT_Error,
from bumble.smp import OobContext, OobLegacyContext ATT_INSUFFICIENT_AUTHENTICATION_ERROR,
from bumble.transport import open_transport ATT_INSUFFICIENT_ENCRYPTION_ERROR,
)
from bumble.utils import AsyncRunner from bumble.utils import AsyncRunner
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -65,7 +62,7 @@ POST_PAIRING_DELAY = 1
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Waiter: class Waiter:
instance: ClassVar[Waiter | None] = None instance = None
def __init__(self, linger=False): def __init__(self, linger=False):
self.done = asyncio.get_running_loop().create_future() self.done = asyncio.get_running_loop().create_future()
@@ -319,41 +316,40 @@ async def on_classic_pairing(connection):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@AsyncRunner.run_in_task() @AsyncRunner.run_in_task()
async def on_pairing_failure(connection: Connection, reason: smp.ErrorCode): async def on_pairing_failure(connection, reason):
print(color('***-----------------------------------', 'red')) print(color('***-----------------------------------', 'red'))
print(color(f'*** Pairing failed: {reason.name}', 'red')) print(color(f'*** Pairing failed: {smp_error_name(reason)}', 'red'))
print(color('***-----------------------------------', 'red')) print(color('***-----------------------------------', 'red'))
await connection.disconnect() await connection.disconnect()
if Waiter.instance: Waiter.instance.terminate()
Waiter.instance.terminate()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def pair( async def pair(
mode: str, mode,
sc: bool, sc,
mitm: bool, mitm,
bond: bool, bond,
ctkd: bool, ctkd,
advertising_address: str, advertising_address,
identity_address: str, identity_address,
linger: bool, linger,
io: str, io,
oob: str, oob,
prompt: bool, prompt,
request: bool, request,
print_keys: bool, print_keys,
keystore_file: str, keystore_file,
advertise_service_uuids: str, advertise_service_uuids,
advertise_appearance: str, advertise_appearance,
device_config: str, device_config,
hci_transport: str, hci_transport,
address_or_name: str, address_or_name,
): ):
Waiter.instance = Waiter(linger=linger) Waiter.instance = Waiter(linger=linger)
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport(hci_transport) as (hci_source, hci_sink): async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
# Create a device to manage the host # Create a device to manage the host
@@ -406,20 +402,14 @@ async def pair(
# Create an OOB context if needed # Create an OOB context if needed
if oob: if oob:
our_oob_context = OobContext() our_oob_context = OobContext()
legacy_context: OobLegacyContext | None shared_data = (
if oob == '-': None
shared_data = None if oob == '-'
legacy_context = OobLegacyContext() else OobData.from_ad(
else:
oob_data = OobData.from_ad(
AdvertisingData.from_bytes(bytes.fromhex(oob)) AdvertisingData.from_bytes(bytes.fromhex(oob))
) ).shared_data
shared_data = oob_data.shared_data )
legacy_context = oob_data.legacy_context legacy_context = OobLegacyContext()
if legacy_context is None and not sc:
print(color('OOB pairing in legacy mode requires TK', 'red'))
return
oob_contexts = PairingConfig.OobConfig( oob_contexts = PairingConfig.OobConfig(
our_context=our_oob_context, our_context=our_oob_context,
peer_data=shared_data, peer_data=shared_data,
@@ -429,9 +419,7 @@ async def pair(
print(color('@@@ OOB Data:', 'yellow')) print(color('@@@ OOB Data:', 'yellow'))
if shared_data is None: if shared_data is None:
oob_data = OobData( oob_data = OobData(
address=device.random_address, address=device.random_address, shared_data=our_oob_context.share()
shared_data=our_oob_context.share(),
legacy_context=(None if sc else legacy_context),
) )
print( print(
color( color(
@@ -439,8 +427,7 @@ async def pair(
'yellow', 'yellow',
) )
) )
if legacy_context: print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
print(color('@@@-----------------------------------', 'yellow')) print(color('@@@-----------------------------------', 'yellow'))
else: else:
oob_contexts = None oob_contexts = None
@@ -511,29 +498,39 @@ async def pair(
if mode == 'dual': if mode == 'dual':
flags |= AdvertisingData.Flags.SIMULTANEOUS_LE_BR_EDR_CAPABLE flags |= AdvertisingData.Flags.SIMULTANEOUS_LE_BR_EDR_CAPABLE
advertising_data_types: list[DataType] = [ ad_structs = [
data_types.Flags(flags), (
data_types.CompleteLocalName('Bumble'), AdvertisingData.FLAGS,
bytes([flags]),
),
(AdvertisingData.COMPLETE_LOCAL_NAME, 'Bumble'.encode()),
] ]
if service_uuids_16: if service_uuids_16:
advertising_data_types.append( ad_structs.append(
data_types.IncompleteListOf16BitServiceUUIDs(service_uuids_16) (
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
b"".join(bytes(uuid) for uuid in service_uuids_16),
)
) )
if service_uuids_32: if service_uuids_32:
advertising_data_types.append( ad_structs.append(
data_types.IncompleteListOf32BitServiceUUIDs(service_uuids_32) (
AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS,
b"".join(bytes(uuid) for uuid in service_uuids_32),
)
) )
if service_uuids_128: if service_uuids_128:
advertising_data_types.append( ad_structs.append(
data_types.IncompleteListOf128BitServiceUUIDs(service_uuids_128) (
AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
b"".join(bytes(uuid) for uuid in service_uuids_128),
)
) )
if advertise_appearance: if advertise_appearance:
advertise_appearance = advertise_appearance.upper() advertise_appearance = advertise_appearance.upper()
try: try:
appearance = data_types.Appearance.from_int( advertise_appearance_int = int(advertise_appearance)
int(advertise_appearance)
)
except ValueError: except ValueError:
category, subcategory = advertise_appearance.split('/') category, subcategory = advertise_appearance.split('/')
try: try:
@@ -551,12 +548,16 @@ async def pair(
except ValueError: except ValueError:
print(color(f'Invalid subcategory {subcategory}', 'red')) print(color(f'Invalid subcategory {subcategory}', 'red'))
return return
appearance = data_types.Appearance( advertise_appearance_int = int(
category_enum, subcategory_enum Appearance(category_enum, subcategory_enum)
) )
ad_structs.append(
advertising_data_types.append(appearance) (
device.advertising_data = bytes(AdvertisingData(advertising_data_types)) AdvertisingData.APPEARANCE,
struct.pack('<H', advertise_appearance_int),
)
)
device.advertising_data = bytes(AdvertisingData(ad_structs))
await device.start_advertising( await device.start_advertising(
auto_restart=True, auto_restart=True,
own_address_type=( own_address_type=(
@@ -665,25 +666,25 @@ class LogHandler(logging.Handler):
@click.argument('hci_transport') @click.argument('hci_transport')
@click.argument('address-or-name', required=False) @click.argument('address-or-name', required=False)
def main( def main(
mode: str, mode,
sc: bool, sc,
mitm: bool, mitm,
bond: bool, bond,
ctkd: bool, ctkd,
advertising_address: str, advertising_address,
identity_address: str, identity_address,
linger: bool, linger,
io: str, io,
oob: str, oob,
prompt: bool, prompt,
request: bool, request,
print_keys: bool, print_keys,
keystore_file: str, keystore_file,
advertise_service_uuid: str, advertise_service_uuid,
advertise_appearance: str, advertise_appearance,
device_config: str, device_config,
hci_transport: str, hci_transport,
address_or_name: str, address_or_name,
): ):
# Setup logging # Setup logging
log_handler = LogHandler() log_handler = LogHandler()
+7 -8
View File
@@ -1,11 +1,10 @@
import asyncio import asyncio
import json
import logging
from typing import Any
import click import click
import logging
import json
from bumble.pandora import Config, PandoraDevice, serve from bumble.pandora import PandoraDevice, Config, serve
from typing import Dict, Any
BUMBLE_SERVER_GRPC_PORT = 7999 BUMBLE_SERVER_GRPC_PORT = 7999
ROOTCANAL_PORT_CUTTLEFISH = 7300 ROOTCANAL_PORT_CUTTLEFISH = 7300
@@ -19,7 +18,7 @@ ROOTCANAL_PORT_CUTTLEFISH = 7300
@click.option( @click.option(
'--transport', '--transport',
help='HCI transport', help='HCI transport',
default='tcp-client:127.0.0.1:<rootcanal-port>', default=f'tcp-client:127.0.0.1:<rootcanal-port>',
) )
@click.option( @click.option(
'--config', '--config',
@@ -40,11 +39,11 @@ def main(grpc_port: int, rootcanal_port: int, transport: str, config: str) -> No
asyncio.run(serve(device, config=server_config, port=grpc_port)) asyncio.run(serve(device, config=server_config, port=grpc_port))
def retrieve_config(config: str) -> dict[str, Any]: def retrieve_config(config: str) -> Dict[str, Any]:
if not config: if not config:
return {} return {}
with open(config) as f: with open(config, 'r') as f:
return json.load(f) return json.load(f)
+33 -27
View File
@@ -16,49 +16,55 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import asyncio.subprocess
import os
import logging import logging
from typing import Optional, Union
import click import click
import bumble.logging
from bumble.a2dp import ( from bumble.a2dp import (
make_audio_source_service_sdp_records,
A2DP_SBC_CODEC_TYPE,
A2DP_MPEG_2_4_AAC_CODEC_TYPE, A2DP_MPEG_2_4_AAC_CODEC_TYPE,
A2DP_NON_A2DP_CODEC_TYPE, A2DP_NON_A2DP_CODEC_TYPE,
A2DP_SBC_CODEC_TYPE,
AacFrame, AacFrame,
AacMediaCodecInformation,
AacPacketSource,
AacParser, AacParser,
OpusMediaCodecInformation, AacPacketSource,
OpusPacket, AacMediaCodecInformation,
OpusPacketSource,
OpusParser,
SbcFrame, SbcFrame,
SbcMediaCodecInformation,
SbcPacketSource,
SbcParser, SbcParser,
make_audio_source_service_sdp_records, SbcPacketSource,
SbcMediaCodecInformation,
OpusPacket,
OpusParser,
OpusPacketSource,
OpusMediaCodecInformation,
) )
from bumble.avrcp import Protocol as AvrcpProtocol
from bumble.avdtp import ( from bumble.avdtp import (
find_avdtp_service_with_connection,
AVDTP_AUDIO_MEDIA_TYPE, AVDTP_AUDIO_MEDIA_TYPE,
AVDTP_DELAY_REPORTING_SERVICE_CATEGORY, AVDTP_DELAY_REPORTING_SERVICE_CATEGORY,
MediaCodecCapabilities, MediaCodecCapabilities,
MediaPacketPump, MediaPacketPump,
find_avdtp_service_with_connection, Protocol as AvdtpProtocol,
) )
from bumble.avdtp import Protocol as AvdtpProtocol
from bumble.avrcp import Protocol as AvrcpProtocol
from bumble.colors import color from bumble.colors import color
from bumble.core import AdvertisingData, DeviceClass, PhysicalTransport from bumble.core import (
from bumble.core import ConnectionError as BumbleConnectionError AdvertisingData,
ConnectionError as BumbleConnectionError,
DeviceClass,
PhysicalTransport,
)
from bumble.device import Connection, Device, DeviceConfiguration from bumble.device import Connection, Device, DeviceConfiguration
from bumble.hci import HCI_CONNECTION_ALREADY_EXISTS_ERROR, Address, HCI_Constant from bumble.hci import Address, HCI_CONNECTION_ALREADY_EXISTS_ERROR, HCI_Constant
from bumble.pairing import PairingConfig from bumble.pairing import PairingConfig
from bumble.transport import open_transport from bumble.transport import open_transport
from bumble.utils import AsyncRunner from bumble.utils import AsyncRunner
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -189,7 +195,7 @@ class Player:
def __init__( def __init__(
self, self,
transport: str, transport: str,
device_config: str | None, device_config: Optional[str],
authenticate: bool, authenticate: bool,
encrypt: bool, encrypt: bool,
) -> None: ) -> None:
@@ -197,8 +203,8 @@ class Player:
self.device_config = device_config self.device_config = device_config
self.authenticate = authenticate self.authenticate = authenticate
self.encrypt = encrypt self.encrypt = encrypt
self.avrcp_protocol: AvrcpProtocol | None = None self.avrcp_protocol: Optional[AvrcpProtocol] = None
self.done: asyncio.Event | None self.done: Optional[asyncio.Event]
async def run(self, workload) -> None: async def run(self, workload) -> None:
self.done = asyncio.Event() self.done = asyncio.Event()
@@ -313,7 +319,7 @@ class Player:
codec_type: int, codec_type: int,
vendor_id: int, vendor_id: int,
codec_id: int, codec_id: int,
packet_source: SbcPacketSource | AacPacketSource | OpusPacketSource, packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource],
codec_capabilities: MediaCodecCapabilities, codec_capabilities: MediaCodecCapabilities,
): ):
# Discover all endpoints on the remote device # Discover all endpoints on the remote device
@@ -379,11 +385,11 @@ class Player:
print(f">>> {color(address.to_string(False), 'yellow')}:") print(f">>> {color(address.to_string(False), 'yellow')}:")
print(f" Device Class (raw): {class_of_device:06X}") print(f" Device Class (raw): {class_of_device:06X}")
major_class_name = DeviceClass.major_device_class_name(major_device_class) major_class_name = DeviceClass.major_device_class_name(major_device_class)
print(f" Device Major Class: {major_class_name}") print(" Device Major Class: " f"{major_class_name}")
minor_class_name = DeviceClass.minor_device_class_name( minor_class_name = DeviceClass.minor_device_class_name(
major_device_class, minor_device_class major_device_class, minor_device_class
) )
print(f" Device Minor Class: {minor_class_name}") print(" Device Minor Class: " f"{minor_class_name}")
print( print(
" Device Services: " " Device Services: "
f"{', '.join(DeviceClass.service_class_labels(service_classes))}" f"{', '.join(DeviceClass.service_class_labels(service_classes))}"
@@ -418,7 +424,7 @@ class Player:
async def play( async def play(
self, self,
device: Device, device: Device,
address: str | None, address: Optional[str],
audio_format: str, audio_format: str,
audio_file: str, audio_file: str,
) -> None: ) -> None:
@@ -447,7 +453,7 @@ class Player:
return input_file.read(byte_count) return input_file.read(byte_count)
# Obtain the codec capabilities from the stream # Obtain the codec capabilities from the stream
packet_source: SbcPacketSource | AacPacketSource | OpusPacketSource packet_source: Union[SbcPacketSource, AacPacketSource, OpusPacketSource]
vendor_id = 0 vendor_id = 0
codec_id = 0 codec_id = 0
if audio_format == "sbc": if audio_format == "sbc":
@@ -593,7 +599,7 @@ def play(context, address, audio_format, audio_file):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def main(): def main():
bumble.logging.setup_basic_logging("WARNING") logging.basicConfig(level=os.environ.get("BUMBLE_LOGLEVEL", "WARNING").upper())
player_cli() player_cli()
+23 -16
View File
@@ -16,14 +16,21 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import logging
import os
import time import time
from typing import Optional
import click import click
import bumble.logging
from bumble import core, hci, rfcomm, transport, utils
from bumble.colors import color from bumble.colors import color
from bumble.device import Connection, Device, DeviceConfiguration from bumble.device import Device, DeviceConfiguration, Connection
from bumble import core
from bumble import hci
from bumble import rfcomm
from bumble import transport
from bumble import utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
@@ -81,14 +88,14 @@ class ServerBridge:
def __init__( def __init__(
self, channel: int, uuid: str, trace: bool, tcp_host: str, tcp_port: int self, channel: int, uuid: str, trace: bool, tcp_host: str, tcp_port: int
) -> None: ) -> None:
self.device: Device | None = None self.device: Optional[Device] = None
self.channel = channel self.channel = channel
self.uuid = uuid self.uuid = uuid
self.tcp_host = tcp_host self.tcp_host = tcp_host
self.tcp_port = tcp_port self.tcp_port = tcp_port
self.rfcomm_channel: rfcomm.DLC | None = None self.rfcomm_channel: Optional[rfcomm.DLC] = None
self.tcp_tracer: Tracer | None self.tcp_tracer: Optional[Tracer]
self.rfcomm_tracer: Tracer | None self.rfcomm_tracer: Optional[Tracer]
if trace: if trace:
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan")) self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
@@ -241,14 +248,14 @@ class ClientBridge:
self.tcp_port = tcp_port self.tcp_port = tcp_port
self.authenticate = authenticate self.authenticate = authenticate
self.encrypt = encrypt self.encrypt = encrypt
self.device: Device | None = None self.device: Optional[Device] = None
self.connection: Connection | None = None self.connection: Optional[Connection] = None
self.rfcomm_client: rfcomm.Client | None self.rfcomm_client: Optional[rfcomm.Client]
self.rfcomm_mux: rfcomm.Multiplexer | None self.rfcomm_mux: Optional[rfcomm.Multiplexer]
self.tcp_connected: bool = False self.tcp_connected: bool = False
self.tcp_tracer: Tracer | None self.tcp_tracer: Optional[Tracer]
self.rfcomm_tracer: Tracer | None self.rfcomm_tracer: Optional[Tracer]
if trace: if trace:
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan")) self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
@@ -399,7 +406,7 @@ class ClientBridge:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def run(device_config, hci_transport, bridge): async def run(device_config, hci_transport, bridge):
print("<<< connecting to HCI...") print("<<< connecting to HCI...")
async with await transport.open_transport(hci_transport) as ( async with await transport.open_transport_or_link(hci_transport) as (
hci_source, hci_source,
hci_sink, hci_sink,
): ):
@@ -421,7 +428,7 @@ async def run(device_config, hci_transport, bridge):
await bridge.start(device) await bridge.start(device)
# Wait until the transport terminates # Wait until the transport terminates
await hci_source.terminated await hci_source.wait_for_termination()
except core.ConnectionError as error: except core.ConnectionError as error:
print(color(f"!!! Bluetooth connection failed: {error}", "red")) print(color(f"!!! Bluetooth connection failed: {error}", "red"))
except Exception as error: except Exception as error:
@@ -508,6 +515,6 @@ def client(context, bluetooth_address, tcp_host, tcp_port, authenticate, encrypt
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get("BUMBLE_LOGLEVEL", "WARNING").upper())
if __name__ == "__main__": if __name__ == "__main__":
bumble.logging.setup_basic_logging("WARNING")
cli(obj={}) # pylint: disable=no-value-for-parameter cli(obj={}) # pylint: disable=no-value-for-parameter
+15 -28
View File
@@ -16,17 +16,17 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import os
import logging
import click import click
import bumble.logging
from bumble import data_types
from bumble.colors import color from bumble.colors import color
from bumble.device import Advertisement, Device, DeviceConfiguration from bumble.device import Device
from bumble.hci import HCI_LE_1M_PHY, HCI_LE_CODED_PHY, Address, HCI_Constant from bumble.transport import open_transport_or_link
from bumble.keys import JsonKeyStore from bumble.keys import JsonKeyStore
from bumble.smp import AddressResolver from bumble.smp import AddressResolver
from bumble.transport import open_transport from bumble.device import Advertisement
from bumble.hci import Address, HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -95,22 +95,13 @@ class AdvertisementPrinter:
else: else:
phy_info = '' phy_info = ''
details = separator.join(
[
data_type.to_string(use_label=True)
for data_type in data_types.data_types_from_advertising_data(
advertisement.data
)
]
)
print( print(
f'>>> {color(address, address_color)} ' f'>>> {color(address, address_color)} '
f'[{color(address_type_string, type_color)}]{address_qualifier}' f'[{color(address_type_string, type_color)}]{address_qualifier}'
f'{resolution_qualifier}:{separator}' f'{resolution_qualifier}:{separator}'
f'{phy_info}' f'{phy_info}'
f'RSSI:{advertisement.rssi:4} {rssi_bar}{separator}' f'RSSI:{advertisement.rssi:4} {rssi_bar}{separator}'
f'{details}\n' f'{advertisement.data.to_string(separator)}\n'
) )
def on_advertisement(self, advertisement): def on_advertisement(self, advertisement):
@@ -136,7 +127,7 @@ async def scan(
transport, transport,
): ):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport(transport) as (hci_source, hci_sink): async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
if device_config: if device_config:
@@ -144,14 +135,8 @@ async def scan(
device_config, hci_source, hci_sink device_config, hci_source, hci_sink
) )
else: else:
device = Device.from_config_with_hci( device = Device.with_hci(
DeviceConfiguration( 'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink
name='Bumble',
address=Address('F0:F1:F2:F3:F4:F5'),
keystore='JsonKeyStore',
),
hci_source,
hci_sink,
) )
await device.power_on() await device.power_on()
@@ -196,7 +181,7 @@ async def scan(
scanning_phys=scanning_phys, scanning_phys=scanning_phys,
) )
await hci_source.terminated await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -223,7 +208,9 @@ async def scan(
@click.option( @click.option(
'--irk', '--irk',
metavar='<IRK_HEX>:<ADDRESS>', metavar='<IRK_HEX>:<ADDRESS>',
help=('Use this IRK for resolving private addresses (may be used more than once)'), help=(
'Use this IRK for resolving private addresses ' '(may be used more than once)'
),
multiple=True, multiple=True,
) )
@click.option( @click.option(
@@ -250,7 +237,7 @@ def main(
device_config, device_config,
transport, transport,
): ):
bumble.logging.setup_basic_logging('WARNING') logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run( asyncio.run(
scan( scan(
min_rssi, min_rssi,
+7 -8
View File
@@ -16,17 +16,17 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import datetime import datetime
import importlib
import logging import logging
import os
import struct import struct
import click import click
import bumble.logging
from bumble import hci
from bumble.colors import color from bumble.colors import color
from bumble.helpers import PacketTracer from bumble import hci
from bumble.transport.common import PacketReader from bumble.transport.common import PacketReader
from bumble.helpers import PacketTracer
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -154,10 +154,9 @@ class Printer:
def main(format, vendor, filename): def main(format, vendor, filename):
for vendor_name in vendor: for vendor_name in vendor:
if vendor_name == 'android': if vendor_name == 'android':
# Prevent being deleted by linter. import bumble.vendor.android.hci
importlib.import_module('bumble.vendor.android.hci')
elif vendor_name == 'zephyr': elif vendor_name == 'zephyr':
importlib.import_module('bumble.vendor.zephyr.hci') import bumble.vendor.zephyr.hci
input = open(filename, 'rb') input = open(filename, 'rb')
if format == 'h4': if format == 'h4':
@@ -187,5 +186,5 @@ def main(format, vendor, filename):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
bumble.logging.setup_basic_logging('WARNING') logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
main() # pylint: disable=no-value-for-parameter main() # pylint: disable=no-value-for-parameter
-1
View File
@@ -15,7 +15,6 @@
<tr><td>Codec</td><td><span id="codecText"></span></td></tr> <tr><td>Codec</td><td><span id="codecText"></span></td></tr>
<tr><td>Packets</td><td><span id="packetsReceivedText"></span></td></tr> <tr><td>Packets</td><td><span id="packetsReceivedText"></span></td></tr>
<tr><td>Bytes</td><td><span id="bytesReceivedText"></span></td></tr> <tr><td>Bytes</td><td><span id="bytesReceivedText"></span></td></tr>
<tr><td>Bitrate</td><td><span id="bitrate"></span></td></tr>
</table> </table>
</td> </td>
<td> <td>
+62 -113
View File
@@ -7,19 +7,17 @@ let connectionText;
let codecText; let codecText;
let packetsReceivedText; let packetsReceivedText;
let bytesReceivedText; let bytesReceivedText;
let bitrateText;
let streamStateText; let streamStateText;
let connectionStateText; let connectionStateText;
let controlsDiv; let controlsDiv;
let audioOnButton; let audioOnButton;
let audioDecoder; let mediaSource;
let audioCodec; let sourceBuffer;
let audioElement;
let audioContext; let audioContext;
let audioAnalyzer; let audioAnalyzer;
let audioFrequencyBinCount; let audioFrequencyBinCount;
let audioFrequencyData; let audioFrequencyData;
let nextAudioStartPosition = 0;
let audioStartTime = 0;
let packetsReceived = 0; let packetsReceived = 0;
let bytesReceived = 0; let bytesReceived = 0;
let audioState = "stopped"; let audioState = "stopped";
@@ -31,17 +29,20 @@ let bandwidthCanvas;
let bandwidthCanvasContext; let bandwidthCanvasContext;
let bandwidthBinCount; let bandwidthBinCount;
let bandwidthBins = []; let bandwidthBins = [];
let bitrateSamples = [];
const FFT_WIDTH = 800; const FFT_WIDTH = 800;
const FFT_HEIGHT = 256; const FFT_HEIGHT = 256;
const BANDWIDTH_WIDTH = 500; const BANDWIDTH_WIDTH = 500;
const BANDWIDTH_HEIGHT = 100; const BANDWIDTH_HEIGHT = 100;
const BITRATE_WINDOW = 30;
function hexToBytes(hex) {
return Uint8Array.from(hex.match(/.{1,2}/g).map((byte) => parseInt(byte, 16)));
}
function init() { function init() {
initUI(); initUI();
initAudioContext(); initMediaSource();
initAudioElement();
initAnalyzer(); initAnalyzer();
connect(); connect();
@@ -55,7 +56,6 @@ function initUI() {
codecText = document.getElementById("codecText"); codecText = document.getElementById("codecText");
packetsReceivedText = document.getElementById("packetsReceivedText"); packetsReceivedText = document.getElementById("packetsReceivedText");
bytesReceivedText = document.getElementById("bytesReceivedText"); bytesReceivedText = document.getElementById("bytesReceivedText");
bitrateText = document.getElementById("bitrate");
streamStateText = document.getElementById("streamStateText"); streamStateText = document.getElementById("streamStateText");
connectionStateText = document.getElementById("connectionStateText"); connectionStateText = document.getElementById("connectionStateText");
audioSupportMessageText = document.getElementById("audioSupportMessageText"); audioSupportMessageText = document.getElementById("audioSupportMessageText");
@@ -67,9 +67,17 @@ function initUI() {
requestAnimationFrame(onAnimationFrame); requestAnimationFrame(onAnimationFrame);
} }
function initAudioContext() { function initMediaSource() {
audioContext = new AudioContext(); mediaSource = new MediaSource();
audioContext.onstatechange = () => console.log("AudioContext state:", audioContext.state); mediaSource.onsourceopen = onMediaSourceOpen;
mediaSource.onsourceclose = onMediaSourceClose;
mediaSource.onsourceended = onMediaSourceEnd;
}
function initAudioElement() {
audioElement = document.getElementById("audio");
audioElement.src = URL.createObjectURL(mediaSource);
// audioElement.controls = true;
} }
function initAnalyzer() { function initAnalyzer() {
@@ -86,16 +94,24 @@ function initAnalyzer() {
bandwidthCanvasContext = bandwidthCanvas.getContext('2d'); bandwidthCanvasContext = bandwidthCanvas.getContext('2d');
bandwidthCanvasContext.fillStyle = "rgb(255, 255, 255)"; bandwidthCanvasContext.fillStyle = "rgb(255, 255, 255)";
bandwidthCanvasContext.fillRect(0, 0, BANDWIDTH_WIDTH, BANDWIDTH_HEIGHT); bandwidthCanvasContext.fillRect(0, 0, BANDWIDTH_WIDTH, BANDWIDTH_HEIGHT);
}
function startAnalyzer() {
// FFT
if (audioElement.captureStream !== undefined) {
audioContext = new AudioContext();
audioAnalyzer = audioContext.createAnalyser();
audioAnalyzer.fftSize = 128;
audioFrequencyBinCount = audioAnalyzer.frequencyBinCount;
audioFrequencyData = new Uint8Array(audioFrequencyBinCount);
const stream = audioElement.captureStream();
const source = audioContext.createMediaStreamSource(stream);
source.connect(audioAnalyzer);
}
// Bandwidth
bandwidthBinCount = BANDWIDTH_WIDTH / 2; bandwidthBinCount = BANDWIDTH_WIDTH / 2;
bandwidthBins = []; bandwidthBins = [];
bitrateSamples = [];
audioAnalyzer = audioContext.createAnalyser();
audioAnalyzer.fftSize = 128;
audioFrequencyBinCount = audioAnalyzer.frequencyBinCount;
audioFrequencyData = new Uint8Array(audioFrequencyBinCount);
audioAnalyzer.connect(audioContext.destination)
} }
function setConnectionText(message) { function setConnectionText(message) {
@@ -132,8 +148,7 @@ function onAnimationFrame() {
bandwidthCanvasContext.fillRect(0, 0, BANDWIDTH_WIDTH, BANDWIDTH_HEIGHT); bandwidthCanvasContext.fillRect(0, 0, BANDWIDTH_WIDTH, BANDWIDTH_HEIGHT);
bandwidthCanvasContext.fillStyle = `rgb(100, 100, 100)`; bandwidthCanvasContext.fillStyle = `rgb(100, 100, 100)`;
for (let t = 0; t < bandwidthBins.length; t++) { for (let t = 0; t < bandwidthBins.length; t++) {
const bytesReceived = bandwidthBins[t] const lineHeight = (bandwidthBins[t] / 1000) * BANDWIDTH_HEIGHT;
const lineHeight = (bytesReceived / 1000) * BANDWIDTH_HEIGHT;
bandwidthCanvasContext.fillRect(t * 2, BANDWIDTH_HEIGHT - lineHeight, 2, lineHeight); bandwidthCanvasContext.fillRect(t * 2, BANDWIDTH_HEIGHT - lineHeight, 2, lineHeight);
} }
@@ -141,14 +156,28 @@ function onAnimationFrame() {
requestAnimationFrame(onAnimationFrame); requestAnimationFrame(onAnimationFrame);
} }
function onMediaSourceOpen() {
console.log(this.readyState);
sourceBuffer = mediaSource.addSourceBuffer("audio/aac");
}
function onMediaSourceClose() {
console.log(this.readyState);
}
function onMediaSourceEnd() {
console.log(this.readyState);
}
async function startAudio() { async function startAudio() {
try { try {
console.log("starting audio..."); console.log("starting audio...");
audioOnButton.disabled = true; audioOnButton.disabled = true;
audioState = "starting"; audioState = "starting";
audioContext.resume(); await audioElement.play();
console.log("audio started"); console.log("audio started");
audioState = "playing"; audioState = "playing";
startAnalyzer();
} catch(error) { } catch(error) {
console.error(`play failed: ${error}`); console.error(`play failed: ${error}`);
audioState = "stopped"; audioState = "stopped";
@@ -156,47 +185,12 @@ async function startAudio() {
} }
} }
function onDecodedAudio(audioData) { function onAudioPacket(packet) {
const bufferSource = audioContext.createBufferSource() if (audioState != "stopped") {
// Queue the audio packet.
const now = audioContext.currentTime; sourceBuffer.appendBuffer(packet);
let nextAudioStartTime = audioStartTime + (nextAudioStartPosition / audioData.sampleRate);
if (nextAudioStartTime < now) {
console.log("starting new audio time base")
audioStartTime = now;
nextAudioStartTime = now;
nextAudioStartPosition = 0;
} else {
console.log(`audio buffer scheduled in ${nextAudioStartTime - now}`)
} }
const audioBuffer = audioContext.createBuffer(
audioData.numberOfChannels,
audioData.numberOfFrames,
audioData.sampleRate
);
for (let channel = 0; channel < audioData.numberOfChannels; channel++) {
audioData.copyTo(
audioBuffer.getChannelData(channel),
{
planeIndex: channel,
format: "f32-planar"
}
)
}
bufferSource.buffer = audioBuffer;
bufferSource.connect(audioAnalyzer)
bufferSource.start(nextAudioStartTime);
nextAudioStartPosition += audioData.numberOfFrames;
}
function onCodecError(error) {
console.log("Codec error:", error)
}
async function onAudioPacket(packet) {
packetsReceived += 1; packetsReceived += 1;
packetsReceivedText.innerText = packetsReceived; packetsReceivedText.innerText = packetsReceived;
bytesReceived += packet.byteLength; bytesReceived += packet.byteLength;
@@ -206,48 +200,6 @@ async function onAudioPacket(packet) {
if (bandwidthBins.length > bandwidthBinCount) { if (bandwidthBins.length > bandwidthBinCount) {
bandwidthBins.shift(); bandwidthBins.shift();
} }
bitrateSamples[bitrateSamples.length] = {ts: Date.now(), bytes: packet.byteLength}
if (bitrateSamples.length > BITRATE_WINDOW) {
bitrateSamples.shift();
}
if (bitrateSamples.length >= 2) {
const windowBytes = bitrateSamples.reduce((accumulator, x) => accumulator + x.bytes, 0) - bitrateSamples[0].bytes;
const elapsed = bitrateSamples[bitrateSamples.length-1].ts - bitrateSamples[0].ts;
const bitrate = Math.floor(8 * windowBytes / elapsed)
bitrateText.innerText = `${bitrate} kb/s`
}
if (audioState == "stopped") {
return;
}
if (audioDecoder === undefined) {
let audioConfig;
if (audioCodec == 'aac') {
audioConfig = {
codec: 'mp4a.40.2',
sampleRate: 44100, // ignored
numberOfChannels: 2, // ignored
}
} else if (audioCodec == 'opus') {
audioConfig = {
codec: 'opus',
sampleRate: 48000, // ignored
numberOfChannels: 2, // ignored
}
}
audioDecoder = new AudioDecoder({ output: onDecodedAudio, error: onCodecError });
audioDecoder.configure(audioConfig)
}
const encodedAudio = new EncodedAudioChunk({
type: "key",
data: packet,
timestamp: 0,
transfer: [packet],
});
audioDecoder.decode(encodedAudio);
} }
function onChannelOpen() { function onChannelOpen() {
@@ -297,19 +249,16 @@ function onChannelMessage(message) {
} }
} }
async function onHelloMessage(params) { function onHelloMessage(params) {
codecText.innerText = params.codec; codecText.innerText = params.codec;
if (params.codec != "aac") {
if (params.codec == "aac" || params.codec == "opus") { audioOnButton.disabled = true;
audioCodec = params.codec audioSupportMessageText.innerText = "Only AAC can be played, audio will be disabled";
audioSupportMessageText.style.display = "inline-block";
} else {
audioSupportMessageText.innerText = ""; audioSupportMessageText.innerText = "";
audioSupportMessageText.style.display = "none"; audioSupportMessageText.style.display = "none";
} else {
audioOnButton.disabled = true;
audioSupportMessageText.innerText = "Only AAC and Opus can be played, audio will be disabled";
audioSupportMessageText.style.display = "inline-block";
} }
if (params.streamState) { if (params.streamState) {
setStreamState(params.streamState); setStreamState(params.streamState);
} }
+42 -145
View File
@@ -16,48 +16,47 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import asyncio.subprocess import asyncio.subprocess
from importlib import resources
import enum import enum
import json import json
import os
import logging import logging
import pathlib import pathlib
import subprocess import subprocess
from typing import Dict, List, Optional
import weakref import weakref
from importlib import resources
import aiohttp
import click import click
import aiohttp
from aiohttp import web from aiohttp import web
import bumble import bumble
import bumble.logging from bumble.colors import color
from bumble.a2dp import ( from bumble.core import PhysicalTransport, CommandTimeoutError
A2DP_MPEG_2_4_AAC_CODEC_TYPE, from bumble.device import Connection, Device, DeviceConfiguration
A2DP_NON_A2DP_CODEC_TYPE, from bumble.hci import HCI_StatusError
A2DP_SBC_CODEC_TYPE, from bumble.pairing import PairingConfig
AacMediaCodecInformation, from bumble.sdp import ServiceAttribute
OpusMediaCodecInformation, from bumble.transport import open_transport
SbcMediaCodecInformation,
make_audio_sink_service_sdp_records,
)
from bumble.avdtp import ( from bumble.avdtp import (
AVDTP_AUDIO_MEDIA_TYPE, AVDTP_AUDIO_MEDIA_TYPE,
Listener, Listener,
MediaCodecCapabilities, MediaCodecCapabilities,
Protocol, Protocol,
) )
from bumble.codecs import AacAudioRtpPacket from bumble.a2dp import (
from bumble.colors import color make_audio_sink_service_sdp_records,
from bumble.core import CommandTimeoutError, PhysicalTransport A2DP_SBC_CODEC_TYPE,
from bumble.device import Connection, Device, DeviceConfiguration A2DP_MPEG_2_4_AAC_CODEC_TYPE,
from bumble.hci import HCI_StatusError SbcMediaCodecInformation,
from bumble.pairing import PairingConfig AacMediaCodecInformation,
from bumble.rtp import MediaPacket )
from bumble.sdp import ServiceAttribute
from bumble.transport import open_transport
from bumble.utils import AsyncRunner from bumble.utils import AsyncRunner
from bumble.codecs import AacAudioRtpPacket
from bumble.rtp import MediaPacket
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -79,8 +78,6 @@ class AudioExtractor:
return AacAudioExtractor() return AacAudioExtractor()
if codec == 'sbc': if codec == 'sbc':
return SbcAudioExtractor() return SbcAudioExtractor()
if codec == 'opus':
return OpusAudioExtractor()
def extract_audio(self, packet: MediaPacket) -> bytes: def extract_audio(self, packet: MediaPacket) -> bytes:
raise NotImplementedError() raise NotImplementedError()
@@ -105,13 +102,6 @@ class SbcAudioExtractor:
return packet.payload[1:] return packet.payload[1:]
# -----------------------------------------------------------------------------
class OpusAudioExtractor:
def extract_audio(self, packet: MediaPacket) -> bytes:
# TODO: parse fields
return packet.payload[1:]
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Output: class Output:
async def start(self) -> None: async def start(self) -> None:
@@ -155,7 +145,7 @@ class QueuedOutput(Output):
packets: asyncio.Queue packets: asyncio.Queue
extractor: AudioExtractor extractor: AudioExtractor
packet_pump_task: asyncio.Task | None packet_pump_task: Optional[asyncio.Task]
started: bool started: bool
def __init__(self, extractor): def __init__(self, extractor):
@@ -229,8 +219,8 @@ class WebSocketOutput(QueuedOutput):
class FfplayOutput(QueuedOutput): class FfplayOutput(QueuedOutput):
MAX_QUEUE_SIZE = 32768 MAX_QUEUE_SIZE = 32768
subprocess: asyncio.subprocess.Process | None subprocess: Optional[asyncio.subprocess.Process]
ffplay_task: asyncio.Task | None ffplay_task: Optional[asyncio.Task]
def __init__(self, codec: str) -> None: def __init__(self, codec: str) -> None:
super().__init__(AudioExtractor.create(codec)) super().__init__(AudioExtractor.create(codec))
@@ -245,7 +235,7 @@ class FfplayOutput(QueuedOutput):
await super().start() await super().start()
self.subprocess = await asyncio.create_subprocess_shell( self.subprocess = await asyncio.create_subprocess_shell(
f'ffplay -probesize 32 -f {self.codec} pipe:0', f'ffplay -f {self.codec} pipe:0',
stdin=asyncio.subprocess.PIPE, stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
@@ -409,24 +399,10 @@ class Speaker:
STARTED = 2 STARTED = 2
SUSPENDED = 3 SUSPENDED = 3
def __init__( def __init__(self, device_config, transport, codec, discover, outputs, ui_port):
self,
device_config,
transport,
codec,
sampling_frequencies,
bitrate,
vbr,
discover,
outputs,
ui_port,
):
self.device_config = device_config self.device_config = device_config
self.transport = transport self.transport = transport
self.codec = codec self.codec = codec
self.sampling_frequencies = sampling_frequencies
self.bitrate = bitrate
self.vbr = vbr
self.discover = discover self.discover = discover
self.ui_port = ui_port self.ui_port = ui_port
self.device = None self.device = None
@@ -447,7 +423,7 @@ class Speaker:
# Create an HTTP server for the UI # Create an HTTP server for the UI
self.ui_server = UiServer(speaker=self, port=ui_port) self.ui_server = UiServer(speaker=self, port=ui_port)
def sdp_records(self) -> dict[int, list[ServiceAttribute]]: def sdp_records(self) -> Dict[int, List[ServiceAttribute]]:
service_record_handle = 0x00010001 service_record_handle = 0x00010001
return { return {
service_record_handle: make_audio_sink_service_sdp_records( service_record_handle: make_audio_sink_service_sdp_records(
@@ -462,56 +438,32 @@ class Speaker:
if self.codec == 'sbc': if self.codec == 'sbc':
return self.sbc_codec_capabilities() return self.sbc_codec_capabilities()
if self.codec == 'opus':
return self.opus_codec_capabilities()
raise RuntimeError('unsupported codec') raise RuntimeError('unsupported codec')
def aac_codec_capabilities(self) -> MediaCodecCapabilities: def aac_codec_capabilities(self) -> MediaCodecCapabilities:
supported_sampling_frequencies = AacMediaCodecInformation.SamplingFrequency(0)
for sampling_frequency in self.sampling_frequencies or [
8000,
11025,
12000,
16000,
22050,
24000,
32000,
44100,
48000,
]:
supported_sampling_frequencies |= (
AacMediaCodecInformation.SamplingFrequency.from_int(sampling_frequency)
)
return MediaCodecCapabilities( return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE, media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_MPEG_2_4_AAC_CODEC_TYPE, media_codec_type=A2DP_MPEG_2_4_AAC_CODEC_TYPE,
media_codec_information=AacMediaCodecInformation( media_codec_information=AacMediaCodecInformation(
object_type=AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC, object_type=AacMediaCodecInformation.ObjectType.MPEG_2_AAC_LC,
sampling_frequency=supported_sampling_frequencies, sampling_frequency=AacMediaCodecInformation.SamplingFrequency.SF_48000
| AacMediaCodecInformation.SamplingFrequency.SF_44100,
channels=AacMediaCodecInformation.Channels.MONO channels=AacMediaCodecInformation.Channels.MONO
| AacMediaCodecInformation.Channels.STEREO, | AacMediaCodecInformation.Channels.STEREO,
vbr=1 if self.vbr else 0, vbr=1,
bitrate=self.bitrate or 256000, bitrate=256000,
), ),
) )
def sbc_codec_capabilities(self) -> MediaCodecCapabilities: def sbc_codec_capabilities(self) -> MediaCodecCapabilities:
supported_sampling_frequencies = SbcMediaCodecInformation.SamplingFrequency(0)
for sampling_frequency in self.sampling_frequencies or [
16000,
32000,
44100,
48000,
]:
supported_sampling_frequencies |= (
SbcMediaCodecInformation.SamplingFrequency.from_int(sampling_frequency)
)
return MediaCodecCapabilities( return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE, media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_SBC_CODEC_TYPE, media_codec_type=A2DP_SBC_CODEC_TYPE,
media_codec_information=SbcMediaCodecInformation( media_codec_information=SbcMediaCodecInformation(
sampling_frequency=supported_sampling_frequencies, sampling_frequency=SbcMediaCodecInformation.SamplingFrequency.SF_48000
| SbcMediaCodecInformation.SamplingFrequency.SF_44100
| SbcMediaCodecInformation.SamplingFrequency.SF_32000
| SbcMediaCodecInformation.SamplingFrequency.SF_16000,
channel_mode=SbcMediaCodecInformation.ChannelMode.MONO channel_mode=SbcMediaCodecInformation.ChannelMode.MONO
| SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL | SbcMediaCodecInformation.ChannelMode.DUAL_CHANNEL
| SbcMediaCodecInformation.ChannelMode.STEREO | SbcMediaCodecInformation.ChannelMode.STEREO
@@ -529,25 +481,6 @@ class Speaker:
), ),
) )
def opus_codec_capabilities(self) -> MediaCodecCapabilities:
supported_sampling_frequencies = OpusMediaCodecInformation.SamplingFrequency(0)
for sampling_frequency in self.sampling_frequencies or [48000]:
supported_sampling_frequencies |= (
OpusMediaCodecInformation.SamplingFrequency.from_int(sampling_frequency)
)
return MediaCodecCapabilities(
media_type=AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=A2DP_NON_A2DP_CODEC_TYPE,
media_codec_information=OpusMediaCodecInformation(
frame_size=OpusMediaCodecInformation.FrameSize.FS_10MS
| OpusMediaCodecInformation.FrameSize.FS_20MS,
channel_mode=OpusMediaCodecInformation.ChannelMode.MONO
| OpusMediaCodecInformation.ChannelMode.STEREO
| OpusMediaCodecInformation.ChannelMode.DUAL_MONO,
sampling_frequency=supported_sampling_frequencies,
),
)
async def dispatch_to_outputs(self, function): async def dispatch_to_outputs(self, function):
for output in self.outputs: for output in self.outputs:
await function(output) await function(output)
@@ -726,7 +659,7 @@ class Speaker:
print("Waiting for connection...") print("Waiting for connection...")
await self.advertise() await self.advertise()
await hci_source.terminated await hci_source.wait_for_termination()
for output in self.outputs: for output in self.outputs:
await output.stop() await output.stop()
@@ -742,26 +675,7 @@ def speaker_cli(ctx, device_config):
@click.command() @click.command()
@click.option( @click.option(
'--codec', '--codec', type=click.Choice(['sbc', 'aac']), default='aac', show_default=True
type=click.Choice(['sbc', 'aac', 'opus']),
default='aac',
show_default=True,
)
@click.option(
'--sampling-frequency',
metavar='SAMPLING-FREQUENCY',
type=int,
multiple=True,
help='Enable a sampling frequency (may be specified more than once)',
)
@click.option(
'--bitrate',
metavar='BITRATE',
type=int,
help='Supported bitrate (AAC only)',
)
@click.option(
'--vbr/--no-vbr', is_flag=True, default=True, help='Enable VBR (AAC only)'
) )
@click.option( @click.option(
'--discover', is_flag=True, help='Discover remote endpoints once connected' '--discover', is_flag=True, help='Discover remote endpoints once connected'
@@ -792,16 +706,7 @@ def speaker_cli(ctx, device_config):
@click.option('--device-config', metavar='FILENAME', help='Device configuration file') @click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.argument('transport') @click.argument('transport')
def speaker( def speaker(
transport, transport, codec, connect_address, discover, output, ui_port, device_config
codec,
sampling_frequency,
bitrate,
vbr,
connect_address,
discover,
output,
ui_port,
device_config,
): ):
"""Run the speaker.""" """Run the speaker."""
@@ -816,23 +721,15 @@ def speaker(
output = list(filter(lambda x: x != '@ffplay', output)) output = list(filter(lambda x: x != '@ffplay', output))
asyncio.run( asyncio.run(
Speaker( Speaker(device_config, transport, codec, discover, output, ui_port).run(
device_config, connect_address
transport, )
codec,
sampling_frequency,
bitrate,
vbr,
discover,
output,
ui_port,
).run(connect_address)
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def main(): def main():
bumble.logging.setup_basic_logging('WARNING') logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
speaker() speaker()
+3 -3
View File
@@ -16,10 +16,10 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import os
import logging
import click import click
import bumble.logging
from bumble.device import Device from bumble.device import Device
from bumble.keys import JsonKeyStore from bumble.keys import JsonKeyStore
from bumble.transport import open_transport from bumble.transport import open_transport
@@ -68,7 +68,7 @@ def main(keystore_file, hci_transport, device_config, address):
instantiated. instantiated.
If no address is passed, the existing pairing keys for all addresses are printed. If no address is passed, the existing pairing keys for all addresses are printed.
""" """
bumble.logging.setup_basic_logging() logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
if not keystore_file and not hci_transport: if not keystore_file and not hci_transport:
print('either --keystore-file or --hci-transport must be specified.') print('either --keystore-file or --hci-transport must be specified.')
+7 -23
View File
@@ -26,15 +26,15 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from typing import Any import os
import logging
import click import click
import usb1 import usb1
import bumble.logging
from bumble.colors import color from bumble.colors import color
from bumble.transport.usb import load_libusb from bumble.transport.usb import load_libusb
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -111,14 +111,9 @@ def show_device_details(device):
if (endpoint.getAddress() & USB_ENDPOINT_IN == 0) if (endpoint.getAddress() & USB_ENDPOINT_IN == 0)
else 'IN' else 'IN'
) )
endpoint_details = (
f', Max Packet Size = {endpoint.getMaxPacketSize()}'
if endpoint_type == 'ISOCHRONOUS'
else ''
)
print( print(
f' Endpoint 0x{endpoint.getAddress():02X}: ' f' Endpoint 0x{endpoint.getAddress():02X}: '
f'{endpoint_type} {endpoint_direction}{endpoint_details}' f'{endpoint_type} {endpoint_direction}'
) )
@@ -173,16 +168,13 @@ def is_bluetooth_hci(device):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@click.command() @click.command()
@click.option('--verbose', is_flag=True, default=False, help='Print more details') @click.option('--verbose', is_flag=True, default=False, help='Print more details')
@click.option('--hci-only', is_flag=True, default=False, help='only show HCI device') def main(verbose):
@click.option('--manufacturer', help='filter by manufacturer') logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
@click.option('--product', help='filter by product')
def main(verbose: bool, manufacturer: str, product: str, hci_only: bool):
bumble.logging.setup_basic_logging('WARNING')
load_libusb() load_libusb()
with usb1.USBContext() as context: with usb1.USBContext() as context:
bluetooth_device_count = 0 bluetooth_device_count = 0
devices: dict[tuple[Any, Any], list[str | None]] = {} devices = {}
for device in context.getDeviceIterator(skip_on_error=True): for device in context.getDeviceIterator(skip_on_error=True):
device_class = device.getDeviceClass() device_class = device.getDeviceClass()
@@ -244,14 +236,6 @@ def main(verbose: bool, manufacturer: str, product: str, hci_only: bool):
f'{basic_transport_name}/{device_serial_number}' f'{basic_transport_name}/{device_serial_number}'
) )
# Filter
if product and device_product != product:
continue
if manufacturer and device_manufacturer != manufacturer:
continue
if not is_bluetooth_hci(device) and hci_only:
continue
# Print the results # Print the results
print( print(
color( color(
+54 -80
View File
@@ -17,37 +17,37 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator
import dataclasses import dataclasses
import enum import enum
import logging import logging
import struct import struct
from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Awaitable, Callable
from typing import ClassVar from typing_extensions import ClassVar, Self
from typing_extensions import Self
from bumble import utils
from bumble.codecs import AacAudioRtpPacket from bumble.codecs import AacAudioRtpPacket
from bumble.company_ids import COMPANY_IDENTIFIERS from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.sdp import (
DataElement,
ServiceAttribute,
SDP_PUBLIC_BROWSE_ROOT,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
)
from bumble.core import ( from bumble.core import (
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
BT_AUDIO_SINK_SERVICE,
BT_AUDIO_SOURCE_SERVICE,
BT_AVDTP_PROTOCOL_ID,
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
BT_AUDIO_SOURCE_SERVICE,
BT_AUDIO_SINK_SERVICE,
BT_AVDTP_PROTOCOL_ID,
BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE,
name_or_number, name_or_number,
) )
from bumble.rtp import MediaPacket from bumble.rtp import MediaPacket
from bumble.sdp import (
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_PUBLIC_BROWSE_ROOT,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement,
ServiceAttribute,
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -60,18 +60,19 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# fmt: off # fmt: off
class CodecType(utils.OpenIntEnum): A2DP_SBC_CODEC_TYPE = 0x00
SBC = 0x00 A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = 0x01
MPEG_1_2_AUDIO = 0x01 A2DP_MPEG_2_4_AAC_CODEC_TYPE = 0x02
MPEG_2_4_AAC = 0x02 A2DP_ATRAC_FAMILY_CODEC_TYPE = 0x03
ATRAC_FAMILY = 0x03 A2DP_NON_A2DP_CODEC_TYPE = 0xFF
NON_A2DP = 0xFF
A2DP_SBC_CODEC_TYPE = CodecType.SBC A2DP_CODEC_TYPE_NAMES = {
A2DP_MPEG_1_2_AUDIO_CODEC_TYPE = CodecType.MPEG_1_2_AUDIO A2DP_SBC_CODEC_TYPE: 'A2DP_SBC_CODEC_TYPE',
A2DP_MPEG_2_4_AAC_CODEC_TYPE = CodecType.MPEG_2_4_AAC A2DP_MPEG_1_2_AUDIO_CODEC_TYPE: 'A2DP_MPEG_1_2_AUDIO_CODEC_TYPE',
A2DP_ATRAC_FAMILY_CODEC_TYPE = CodecType.ATRAC_FAMILY A2DP_MPEG_2_4_AAC_CODEC_TYPE: 'A2DP_MPEG_2_4_AAC_CODEC_TYPE',
A2DP_NON_A2DP_CODEC_TYPE = CodecType.NON_A2DP A2DP_ATRAC_FAMILY_CODEC_TYPE: 'A2DP_ATRAC_FAMILY_CODEC_TYPE',
A2DP_NON_A2DP_CODEC_TYPE: 'A2DP_NON_A2DP_CODEC_TYPE'
}
SBC_SYNC_WORD = 0x9C SBC_SYNC_WORD = 0x9C
@@ -88,6 +89,13 @@ SBC_DUAL_CHANNEL_MODE = 0x01
SBC_STEREO_CHANNEL_MODE = 0x02 SBC_STEREO_CHANNEL_MODE = 0x02
SBC_JOINT_STEREO_CHANNEL_MODE = 0x03 SBC_JOINT_STEREO_CHANNEL_MODE = 0x03
SBC_CHANNEL_MODE_NAMES = {
SBC_MONO_CHANNEL_MODE: 'SBC_MONO_CHANNEL_MODE',
SBC_DUAL_CHANNEL_MODE: 'SBC_DUAL_CHANNEL_MODE',
SBC_STEREO_CHANNEL_MODE: 'SBC_STEREO_CHANNEL_MODE',
SBC_JOINT_STEREO_CHANNEL_MODE: 'SBC_JOINT_STEREO_CHANNEL_MODE'
}
SBC_BLOCK_LENGTHS = [4, 8, 12, 16] SBC_BLOCK_LENGTHS = [4, 8, 12, 16]
SBC_SUBBANDS = [4, 8] SBC_SUBBANDS = [4, 8]
@@ -95,6 +103,11 @@ SBC_SUBBANDS = [4, 8]
SBC_SNR_ALLOCATION_METHOD = 0x00 SBC_SNR_ALLOCATION_METHOD = 0x00
SBC_LOUDNESS_ALLOCATION_METHOD = 0x01 SBC_LOUDNESS_ALLOCATION_METHOD = 0x01
SBC_ALLOCATION_METHOD_NAMES = {
SBC_SNR_ALLOCATION_METHOD: 'SBC_SNR_ALLOCATION_METHOD',
SBC_LOUDNESS_ALLOCATION_METHOD: 'SBC_LOUDNESS_ALLOCATION_METHOD'
}
SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15 SBC_MAX_FRAMES_IN_RTP_PAYLOAD = 15
MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [ MPEG_2_4_AAC_SAMPLING_FREQUENCIES = [
@@ -117,6 +130,13 @@ MPEG_4_AAC_LC_OBJECT_TYPE = 0x01
MPEG_4_AAC_LTP_OBJECT_TYPE = 0x02 MPEG_4_AAC_LTP_OBJECT_TYPE = 0x02
MPEG_4_AAC_SCALABLE_OBJECT_TYPE = 0x03 MPEG_4_AAC_SCALABLE_OBJECT_TYPE = 0x03
MPEG_2_4_OBJECT_TYPE_NAMES = {
MPEG_2_AAC_LC_OBJECT_TYPE: 'MPEG_2_AAC_LC_OBJECT_TYPE',
MPEG_4_AAC_LC_OBJECT_TYPE: 'MPEG_4_AAC_LC_OBJECT_TYPE',
MPEG_4_AAC_LTP_OBJECT_TYPE: 'MPEG_4_AAC_LTP_OBJECT_TYPE',
MPEG_4_AAC_SCALABLE_OBJECT_TYPE: 'MPEG_4_AAC_SCALABLE_OBJECT_TYPE'
}
OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15 OPUS_MAX_FRAMES_IN_RTP_PAYLOAD = 15
@@ -240,49 +260,9 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
] ]
# -----------------------------------------------------------------------------
class MediaCodecInformation:
'''Base Media Codec Information.'''
@classmethod
def create(
cls, media_codec_type: int, data: bytes
) -> MediaCodecInformation | bytes:
match media_codec_type:
case CodecType.SBC:
return SbcMediaCodecInformation.from_bytes(data)
case CodecType.MPEG_2_4_AAC:
return AacMediaCodecInformation.from_bytes(data)
case CodecType.NON_A2DP:
vendor_media_codec_information = (
VendorSpecificMediaCodecInformation.from_bytes(data)
)
if (
vendor_class_map := A2DP_VENDOR_MEDIA_CODEC_INFORMATION_CLASSES.get(
vendor_media_codec_information.vendor_id
)
) and (
media_codec_information_class := vendor_class_map.get(
vendor_media_codec_information.codec_id
)
):
return media_codec_information_class.from_bytes(
vendor_media_codec_information.value
)
return vendor_media_codec_information
@classmethod
def from_bytes(cls, data: bytes) -> Self:
del data # Unused.
raise NotImplementedError
def __bytes__(self) -> bytes:
raise NotImplementedError
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass @dataclasses.dataclass
class SbcMediaCodecInformation(MediaCodecInformation): class SbcMediaCodecInformation:
''' '''
A2DP spec - 4.3.2 Codec Specific Information Elements A2DP spec - 4.3.2 Codec Specific Information Elements
''' '''
@@ -366,7 +346,7 @@ class SbcMediaCodecInformation(MediaCodecInformation):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass @dataclasses.dataclass
class AacMediaCodecInformation(MediaCodecInformation): class AacMediaCodecInformation:
''' '''
A2DP spec - 4.5.2 Codec Specific Information Elements A2DP spec - 4.5.2 Codec Specific Information Elements
''' '''
@@ -448,7 +428,7 @@ class AacMediaCodecInformation(MediaCodecInformation):
@dataclasses.dataclass @dataclasses.dataclass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class VendorSpecificMediaCodecInformation(MediaCodecInformation): class VendorSpecificMediaCodecInformation:
''' '''
A2DP spec - 4.7.2 Codec Specific Information Elements A2DP spec - 4.7.2 Codec Specific Information Elements
''' '''
@@ -472,7 +452,7 @@ class VendorSpecificMediaCodecInformation(MediaCodecInformation):
'VendorSpecificMediaCodecInformation(', 'VendorSpecificMediaCodecInformation(',
f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})', f' vendor_id: {self.vendor_id:08X} ({name_or_number(COMPANY_IDENTIFIERS, self.vendor_id & 0xFFFF)})',
f' codec_id: {self.codec_id:04X}', f' codec_id: {self.codec_id:04X}',
f' value: {self.value.hex()})', f' value: {self.value.hex()}' ')',
] ]
) )
@@ -499,12 +479,6 @@ class OpusMediaCodecInformation(VendorSpecificMediaCodecInformation):
class SamplingFrequency(enum.IntFlag): class SamplingFrequency(enum.IntFlag):
SF_48000 = 1 << 0 SF_48000 = 1 << 0
@classmethod
def from_int(cls, sampling_frequency: int) -> Self:
if sampling_frequency != 48000:
raise ValueError("no such sampling frequency")
return cls(1)
VENDOR_ID: ClassVar[int] = 0x000000E0 VENDOR_ID: ClassVar[int] = 0x000000E0
CODEC_ID: ClassVar[int] = 0x0001 CODEC_ID: ClassVar[int] = 0x0001
@@ -668,7 +642,7 @@ class SbcPacketSource:
# Prepare for next packets # Prepare for next packets
sequence_number += 1 sequence_number += 1
sequence_number &= 0xFFFF sequence_number &= 0xFFFF
sample_count += sum(frame.sample_count for frame in frames) sample_count += sum((frame.sample_count for frame in frames))
frames = [frame] frames = [frame]
frames_size = len(frame.payload) frames_size = len(frame.payload)
else: else:
+36 -37
View File
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Union
from bumble import core from bumble import core
@@ -20,14 +21,14 @@ class AtParsingError(core.InvalidPacketError):
"""Error raised when parsing AT commands fails.""" """Error raised when parsing AT commands fails."""
def tokenize_parameters(buffer: bytes) -> list[bytes]: def tokenize_parameters(buffer: bytes) -> List[bytes]:
"""Split input parameters into tokens. """Split input parameters into tokens.
Removes space characters outside of double quote blocks: Removes space characters outside of double quote blocks:
T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0) T-rec-V-25 - 5.2.1 Command line general format: "Space characters (IA5 2/0)
are ignored [..], unless they are embedded in numeric or string constants" are ignored [..], unless they are embedded in numeric or string constants"
Raises AtParsingError in case of invalid input string.""" Raises AtParsingError in case of invalid input string."""
tokens: list[bytearray] = [] tokens = []
in_quotes = False in_quotes = False
token = bytearray() token = bytearray()
for b in buffer: for b in buffer:
@@ -35,56 +36,54 @@ def tokenize_parameters(buffer: bytes) -> list[bytes]:
if in_quotes: if in_quotes:
token.extend(char) token.extend(char)
if char == b'"': if char == b'\"':
in_quotes = False in_quotes = False
tokens.append(token[1:-1]) tokens.append(token[1:-1])
token = bytearray() token = bytearray()
else: else:
match char: if char == b' ':
case b' ': pass
pass elif char == b',' or char == b')':
case b',' | b')': tokens.append(token)
tokens.append(token) tokens.append(char)
tokens.append(char) token = bytearray()
token = bytearray() elif char == b'(':
case b'(': if len(token) > 0:
if len(token) > 0: raise AtParsingError("open_paren following regular character")
raise AtParsingError("open_paren following regular character") tokens.append(char)
tokens.append(char) elif char == b'"':
case b'"': if len(token) > 0:
if len(token) > 0: raise AtParsingError("quote following regular character")
raise AtParsingError("quote following regular character") in_quotes = True
in_quotes = True token.extend(char)
token.extend(char) else:
case _: token.extend(char)
token.extend(char)
tokens.append(token) tokens.append(token)
return [bytes(token) for token in tokens if len(token) > 0] return [bytes(token) for token in tokens if len(token) > 0]
def parse_parameters(buffer: bytes) -> list[bytes | list]: def parse_parameters(buffer: bytes) -> List[Union[bytes, list]]:
"""Parse the parameters using the comma and parenthesis separators. """Parse the parameters using the comma and parenthesis separators.
Raises AtParsingError in case of invalid input string.""" Raises AtParsingError in case of invalid input string."""
tokens = tokenize_parameters(buffer) tokens = tokenize_parameters(buffer)
accumulator: list[list] = [[]] accumulator: List[list] = [[]]
current: bytes | list = b'' current: Union[bytes, list] = bytes()
for token in tokens: for token in tokens:
match token: if token == b',':
case b',': accumulator[-1].append(current)
accumulator[-1].append(current) current = bytes()
current = b'' elif token == b'(':
case b'(': accumulator.append([])
accumulator.append([]) elif token == b')':
case b')': if len(accumulator) < 2:
if len(accumulator) < 2: raise AtParsingError("close_paren without matching open_paren")
raise AtParsingError("close_paren without matching open_paren") accumulator[-1].append(current)
accumulator[-1].append(current) current = accumulator.pop()
current = accumulator.pop() else:
case _: current = token
current = token
accumulator[-1].append(current) accumulator[-1].append(current)
if len(accumulator) > 1: if len(accumulator) > 1:
+305 -430
View File
File diff suppressed because it is too large Load Diff
+13 -11
View File
@@ -17,17 +17,20 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import abc
import asyncio import asyncio
import concurrent.futures import abc
from concurrent.futures import ThreadPoolExecutor
import dataclasses import dataclasses
import enum import enum
import logging import logging
import pathlib import pathlib
from typing import (
AsyncGenerator,
BinaryIO,
TYPE_CHECKING,
)
import sys import sys
import wave import wave
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, BinaryIO
from bumble.colors import color from bumble.colors import color
@@ -177,7 +180,7 @@ class ThreadedAudioOutput(AudioOutput):
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._thread_pool = concurrent.futures.ThreadPoolExecutor(1) self._thread_pool = ThreadPoolExecutor(1)
self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue() self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue()
self._write_task = asyncio.create_task(self._write_loop()) self._write_task = asyncio.create_task(self._write_loop())
@@ -227,8 +230,8 @@ class SoundDeviceAudioOutput(ThreadedAudioOutput):
try: try:
self._stream.write(pcm_samples) self._stream.write(pcm_samples)
except Exception: except Exception as error:
logger.exception('Sound device error') print(f'Sound device error: {error}')
raise raise
def _close(self): def _close(self):
@@ -406,7 +409,7 @@ class ThreadedAudioInput(AudioInput):
"""Base class for AudioInput implementation where reading samples may block.""" """Base class for AudioInput implementation where reading samples may block."""
def __init__(self) -> None: def __init__(self) -> None:
self._thread_pool = concurrent.futures.ThreadPoolExecutor(1) self._thread_pool = ThreadPoolExecutor(1)
self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue() self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue()
@abc.abstractmethod @abc.abstractmethod
@@ -546,6 +549,5 @@ class SoundDeviceAudioInput(ThreadedAudioInput):
return bytes(pcm_buffer) return bytes(pcm_buffer)
def _close(self): def _close(self):
if self._stream: self._stream.stop()
self._stream.stop() self._stream = None
self._stream = None
+10 -9
View File
@@ -16,11 +16,12 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import enum import enum
import struct import struct
from typing import Dict, Type, Union, Tuple
from bumble import core, utils from bumble import core
from bumble import utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -165,7 +166,7 @@ class Frame:
def to_bytes( def to_bytes(
self, self,
ctype_or_response: CommandFrame.CommandType | ResponseFrame.ResponseCode, ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode],
) -> bytes: ) -> bytes:
# TODO: support extended subunit types and ids. # TODO: support extended subunit types and ids.
return ( return (
@@ -212,11 +213,11 @@ class CommandFrame(Frame):
NOTIFY = 0x03 NOTIFY = 0x03
GENERAL_INQUIRY = 0x04 GENERAL_INQUIRY = 0x04
subclasses: dict[Frame.OperationCode, type[CommandFrame]] = {} subclasses: Dict[Frame.OperationCode, Type[CommandFrame]] = {}
ctype: CommandType ctype: CommandType
@staticmethod @staticmethod
def parse_operands(operands: bytes) -> tuple: def parse_operands(operands: bytes) -> Tuple:
raise NotImplementedError raise NotImplementedError
def __init__( def __init__(
@@ -250,11 +251,11 @@ class ResponseFrame(Frame):
CHANGED = 0x0D CHANGED = 0x0D
INTERIM = 0x0F INTERIM = 0x0F
subclasses: dict[Frame.OperationCode, type[ResponseFrame]] = {} subclasses: Dict[Frame.OperationCode, Type[ResponseFrame]] = {}
response: ResponseCode response: ResponseCode
@staticmethod @staticmethod
def parse_operands(operands: bytes) -> tuple: def parse_operands(operands: bytes) -> Tuple:
raise NotImplementedError raise NotImplementedError
def __init__( def __init__(
@@ -281,7 +282,7 @@ class VendorDependentFrame:
vendor_dependent_data: bytes vendor_dependent_data: bytes
@staticmethod @staticmethod
def parse_operands(operands: bytes) -> tuple: def parse_operands(operands: bytes) -> Tuple:
return ( return (
struct.unpack(">I", b"\x00" + operands[:3])[0], struct.unpack(">I", b"\x00" + operands[:3])[0],
operands[3:], operands[3:],
@@ -431,7 +432,7 @@ class PassThroughFrame:
operation_data: bytes operation_data: bytes
@staticmethod @staticmethod
def parse_operands(operands: bytes) -> tuple: def parse_operands(operands: bytes) -> Tuple:
return ( return (
PassThroughFrame.StateFlag(operands[0] >> 7), PassThroughFrame.StateFlag(operands[0] >> 7),
PassThroughFrame.OperationId(operands[0] & 0x7F), PassThroughFrame.OperationId(operands[0] & 0x7F),
+21 -15
View File
@@ -16,14 +16,15 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
from enum import IntEnum
import logging import logging
import struct import struct
from collections.abc import Callable from typing import Callable, cast, Dict, Optional
from enum import IntEnum
from bumble import core, l2cap
from bumble.colors import color from bumble.colors import color
from bumble import avc
from bumble import core
from bumble import l2cap
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -136,18 +137,18 @@ class MessageAssembler:
self.pid, self.pid,
self.payload, self.payload,
) )
except Exception: except Exception as error:
logger.exception(color("!!! exception in callback", "red")) logger.exception(color(f"!!! exception in callback: {error}", "red"))
self.reset() self.reset()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Protocol: class Protocol:
CommandHandler = Callable[[int, bytes], None] CommandHandler = Callable[[int, avc.CommandFrame], None]
command_handlers: dict[int, CommandHandler] # Command handlers, by PID command_handlers: Dict[int, CommandHandler] # Command handlers, by PID
ResponseHandler = Callable[[int, bytes | None], None] ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]
response_handlers: dict[int, ResponseHandler] # Response handlers, by PID response_handlers: Dict[int, ResponseHandler] # Response handlers, by PID
next_transaction_label: int next_transaction_label: int
message_assembler: MessageAssembler message_assembler: MessageAssembler
@@ -204,15 +205,20 @@ class Protocol:
self.send_ipid(transaction_label, pid) self.send_ipid(transaction_label, pid)
return return
self.command_handlers[pid](transaction_label, payload) command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload))
self.command_handlers[pid](transaction_label, command_frame)
else: else:
if pid not in self.response_handlers: if pid not in self.response_handlers:
logger.warning(f"no response handler for PID {pid}") logger.warning(f"no response handler for PID {pid}")
return return
# By convention, for an ipid, send a None payload to the response handler. # By convention, for an ipid, send a None payload to the response handler.
response_payload = None if ipid else payload if ipid:
self.response_handlers[pid](transaction_label, response_payload) response_frame = None
else:
response_frame = cast(avc.ResponseFrame, avc.Frame.from_bytes(payload))
self.response_handlers[pid](transaction_label, response_frame)
def send_message( def send_message(
self, self,
@@ -235,7 +241,7 @@ class Protocol:
) )
+ payload + payload
) )
self.l2cap_channel.write(pdu) self.l2cap_channel.send_pdu(pdu)
def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None: def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
logger.debug( logger.debug(
@@ -257,7 +263,7 @@ class Protocol:
def send_ipid(self, transaction_label: int, pid: int) -> None: def send_ipid(self, transaction_label: int, pid: int) -> None:
logger.debug( logger.debug(
f">>> AVCTP ipid: transaction_label={transaction_label}, pid={pid}" ">>> AVCTP ipid: " f"transaction_label={transaction_label}, " f"pid={pid}"
) )
self.send_message(transaction_label, False, True, pid, b'') self.send_message(transaction_label, False, True, pid, b'')
+634 -737
View File
File diff suppressed because it is too large Load Diff
+860 -1787
View File
File diff suppressed because it is too large Load Diff
+2 -10
View File
@@ -37,12 +37,7 @@ class HCI_Bridge:
def on_packet(self, packet): def on_packet(self, packet):
# Convert the packet bytes to an object # Convert the packet bytes to an object
try: hci_packet = HCI_Packet.from_bytes(packet)
hci_packet = HCI_Packet.from_bytes(packet)
except Exception:
logger.warning('forwarding unparsed packet as-is')
self.hci_sink.on_packet(packet)
return
# Filter the packet # Filter the packet
if self.packet_filter is not None: if self.packet_filter is not None:
@@ -55,10 +50,7 @@ class HCI_Bridge:
return return
# Analyze the packet # Analyze the packet
try: self.trace(hci_packet)
self.trace(hci_packet)
except Exception:
logger.exception('Exception while tracing packet')
# Bridge the packet # Bridge the packet
self.hci_sink.on_packet(packet) self.hci_sink.on_packet(packet)
+15 -17
View File
@@ -16,9 +16,7 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing_extensions import Self from typing_extensions import Self
from bumble import core from bumble import core
@@ -163,23 +161,23 @@ class AacAudioRtpPacket:
cls, reader: BitReader, channel_configuration: int, audio_object_type: int cls, reader: BitReader, channel_configuration: int, audio_object_type: int
) -> Self: ) -> Self:
# GASpecificConfig - ISO/EIC 14496-3 Table 4.1 # GASpecificConfig - ISO/EIC 14496-3 Table 4.1
reader.read(1) # frame_length_flag frame_length_flag = reader.read(1)
depends_on_core_coder = reader.read(1) depends_on_core_coder = reader.read(1)
if depends_on_core_coder: if depends_on_core_coder:
reader.read(14) # core_coder_delay core_coder_delay = reader.read(14)
extension_flag = reader.read(1) extension_flag = reader.read(1)
if not channel_configuration: if not channel_configuration:
raise core.InvalidPacketError('program_config_element not supported') raise core.InvalidPacketError('program_config_element not supported')
if audio_object_type in (6, 20): if audio_object_type in (6, 20):
reader.read(3) # layer_nr layer_nr = reader.read(3)
if extension_flag: if extension_flag:
if audio_object_type == 22: if audio_object_type == 22:
reader.read(5) # num_of_sub_frame num_of_sub_frame = reader.read(5)
reader.read(11) # layer_length layer_length = reader.read(11)
if audio_object_type in (17, 19, 20, 23): if audio_object_type in (17, 19, 20, 23):
reader.read(1) # aac_section_data_resilience_flags aac_section_data_resilience_flags = reader.read(1)
reader.read(1) # aac_scale_factor_data_resilience_flags aac_scale_factor_data_resilience_flags = reader.read(1)
reader.read(1) # aac_spectral_data_resilience_flags aac_spectral_data_resilience_flags = reader.read(1)
extension_flag_3 = reader.read(1) extension_flag_3 = reader.read(1)
if extension_flag_3 == 1: if extension_flag_3 == 1:
raise core.InvalidPacketError('extensionFlag3 == 1 not supported') raise core.InvalidPacketError('extensionFlag3 == 1 not supported')
@@ -364,10 +362,10 @@ class AacAudioRtpPacket:
if audio_mux_version_a != 0: if audio_mux_version_a != 0:
raise core.InvalidPacketError('audioMuxVersionA != 0 not supported') raise core.InvalidPacketError('audioMuxVersionA != 0 not supported')
if audio_mux_version == 1: if audio_mux_version == 1:
AacAudioRtpPacket.read_latm_value(reader) # tara_buffer_fullness tara_buffer_fullness = AacAudioRtpPacket.read_latm_value(reader)
# stream_cnt = 0 stream_cnt = 0
reader.read(1) # all_streams_same_time_framing all_streams_same_time_framing = reader.read(1)
reader.read(6) # num_sub_frames num_sub_frames = reader.read(6)
num_program = reader.read(4) num_program = reader.read(4)
if num_program != 0: if num_program != 0:
raise core.InvalidPacketError('num_program != 0 not supported') raise core.InvalidPacketError('num_program != 0 not supported')
@@ -391,9 +389,9 @@ class AacAudioRtpPacket:
reader.skip(asc_len) reader.skip(asc_len)
frame_length_type = reader.read(3) frame_length_type = reader.read(3)
if frame_length_type == 0: if frame_length_type == 0:
reader.read(8) # latm_buffer_fullness latm_buffer_fullness = reader.read(8)
elif frame_length_type == 1: elif frame_length_type == 1:
reader.read(9) # frame_length frame_length = reader.read(9)
else: else:
raise core.InvalidPacketError( raise core.InvalidPacketError(
f'frame_length_type {frame_length_type} not supported' f'frame_length_type {frame_length_type} not supported'
@@ -413,7 +411,7 @@ class AacAudioRtpPacket:
break break
crc_check_present = reader.read(1) crc_check_present = reader.read(1)
if crc_check_present: if crc_check_present:
reader.read(8) # crc_checksum crc_checksum = reader.read(8)
return cls(other_data_present, other_data_len_bits, audio_specific_config) return cls(other_data_present, other_data_len_bits, audio_specific_config)
+9 -8
View File
@@ -13,6 +13,7 @@
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
from functools import partial from functools import partial
from typing import List, Optional, Union
class ColorError(ValueError): class ColorError(ValueError):
@@ -37,7 +38,7 @@ STYLES = (
) )
ColorSpec = str | int ColorSpec = Union[str, int]
def _join(*values: ColorSpec) -> str: def _join(*values: ColorSpec) -> str:
@@ -55,16 +56,16 @@ def _color_code(spec: ColorSpec, base: int) -> str:
elif isinstance(spec, int) and 0 <= spec <= 255: elif isinstance(spec, int) and 0 <= spec <= 255:
return _join(base + 8, 5, spec) return _join(base + 8, 5, spec)
else: else:
raise ColorError(f'Invalid color spec "{spec}"') raise ColorError('Invalid color spec "%s"' % spec)
def color( def color(
s: str, s: str,
fg: ColorSpec | None = None, fg: Optional[ColorSpec] = None,
bg: ColorSpec | None = None, bg: Optional[ColorSpec] = None,
style: str | None = None, style: Optional[str] = None,
) -> str: ) -> str:
codes: list[ColorSpec] = [] codes: List[ColorSpec] = []
if fg: if fg:
codes.append(_color_code(fg, 30)) codes.append(_color_code(fg, 30))
@@ -75,10 +76,10 @@ def color(
if style_part in STYLES: if style_part in STYLES:
codes.append(STYLES.index(style_part)) codes.append(STYLES.index(style_part))
else: else:
raise ColorError(f'Invalid style "{style_part}"') raise ColorError('Invalid style "%s"' % style_part)
if codes: if codes:
return f'\x1b[{_join(*codes)}m{s}\x1b[0m' return '\x1b[{0}m{1}\x1b[0m'.format(_join(*codes), s)
else: else:
return s return s
+843 -1929
View File
File diff suppressed because it is too large Load Diff
+361 -680
View File
File diff suppressed because it is too large Load Diff
+95 -12
View File
@@ -1,6 +1,6 @@
# Copyright 2021-2025 Google LLC # Copyright 2021-2022 Google LLC
# #
# Licensed under the Apache License, Version 2.0 (the "License") # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
@@ -12,6 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# -----------------------------------------------------------------------------
# Crypto support
#
# See Bluetooth spec Vol 3, Part H - 2.2 CRYPTOGRAPHIC TOOLBOX
# -----------------------------------------------------------------------------
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -19,17 +25,20 @@ from __future__ import annotations
import logging import logging
import operator import operator
import secrets import secrets
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.asymmetric.ec import (
generate_private_key,
ECDH,
EllipticCurvePrivateKey,
EllipticCurvePublicNumbers,
EllipticCurvePrivateNumbers,
SECP256R1,
)
from cryptography.hazmat.primitives import cmac
from typing import Tuple
try:
from bumble.crypto.cryptography import EccKey, aes_cmac, e
except ImportError:
logging.getLogger(__name__).debug(
"Unable to import cryptography, using built-in primitives."
)
from bumble.crypto.builtin import EccKey, aes_cmac, e # type: ignore[assignment]
_EccKey = EccKey # For the linter only
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -37,6 +46,55 @@ _EccKey = EccKey # For the linter only
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class EccKey:
def __init__(self, private_key: EllipticCurvePrivateKey) -> None:
self.private_key = private_key
@classmethod
def generate(cls) -> EccKey:
private_key = generate_private_key(SECP256R1())
return cls(private_key)
@classmethod
def from_private_key_bytes(
cls, d_bytes: bytes, x_bytes: bytes, y_bytes: bytes
) -> EccKey:
d = int.from_bytes(d_bytes, byteorder='big', signed=False)
x = int.from_bytes(x_bytes, byteorder='big', signed=False)
y = int.from_bytes(y_bytes, byteorder='big', signed=False)
private_key = EllipticCurvePrivateNumbers(
d, EllipticCurvePublicNumbers(x, y, SECP256R1())
).private_key()
return cls(private_key)
@property
def x(self) -> bytes:
return (
self.private_key.public_key()
.public_numbers()
.x.to_bytes(32, byteorder='big')
)
@property
def y(self) -> bytes:
return (
self.private_key.public_key()
.public_numbers()
.y.to_bytes(32, byteorder='big')
)
def dh(self, public_key_x: bytes, public_key_y: bytes) -> bytes:
x = int.from_bytes(public_key_x, byteorder='big', signed=False)
y = int.from_bytes(public_key_y, byteorder='big', signed=False)
public_key = EllipticCurvePublicNumbers(x, y, SECP256R1()).public_key()
shared_key = self.private_key.exchange(ECDH(), public_key)
return shared_key
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Functions # Functions
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -74,6 +132,19 @@ def r() -> bytes:
return secrets.token_bytes(16) return secrets.token_bytes(16)
# -----------------------------------------------------------------------------
def e(key: bytes, data: bytes) -> bytes:
'''
AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output.
See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e
'''
cipher = Cipher(algorithms.AES(reverse(key)), modes.ECB())
encryptor = cipher.encryptor()
return reverse(encryptor.update(reverse(data)))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def ah(k: bytes, r: bytes) -> bytes: # pylint: disable=redefined-outer-name def ah(k: bytes, r: bytes) -> bytes: # pylint: disable=redefined-outer-name
''' '''
@@ -116,6 +187,18 @@ def s1(k: bytes, r1: bytes, r2: bytes) -> bytes:
return e(k, r2[0:8] + r1[0:8]) return e(k, r2[0:8] + r1[0:8])
# -----------------------------------------------------------------------------
def aes_cmac(m: bytes, k: bytes) -> bytes:
'''
See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC
NOTE: the input and output of this internal function are in big-endian byte order
'''
mac = cmac.CMAC(algorithms.AES(k))
mac.update(m)
return mac.finalize()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def f4(u: bytes, v: bytes, x: bytes, z: bytes) -> bytes: def f4(u: bytes, v: bytes, x: bytes, z: bytes) -> bytes:
''' '''
@@ -126,7 +209,7 @@ def f4(u: bytes, v: bytes, x: bytes, z: bytes) -> bytes:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def f5(w: bytes, n1: bytes, n2: bytes, a1: bytes, a2: bytes) -> tuple[bytes, bytes]: def f5(w: bytes, n1: bytes, n2: bytes, a1: bytes, a2: bytes) -> Tuple[bytes, bytes]:
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation
Function f5 Function f5
-646
View File
@@ -1,646 +0,0 @@
# Copyright 2021-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License")
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The implementation is modified from:
# * AES - https://github.com/ricmoo/pyaes by Richard Moore under MIT License
# * CMAC - https://github.com/pycrypto/pycrypto by contributors under pycrypto License.
# -----------------------------------------------------------------------------
# Built-in implementation of cryptography primitives.
#
# Note: It's very dangerous to use this library in the real world.
# -----------------------------------------------------------------------------
from __future__ import annotations
import copy
import dataclasses
import functools
import secrets
import struct
from bumble import core
def _compact_word(word: bytes) -> int:
return int.from_bytes(word, "big")
def _shift_bytes(bs: bytes, xor_lsb: int = 0) -> bytes:
return ((int.from_bytes(bs, "big") << 1) ^ xor_lsb).to_bytes(len(bs) + 1, "big")[1:]
def _xor(a: bytes, b: bytes) -> bytes:
return bytes(x ^ y for x, y in zip(a, b))
# Based *largely* on the Rijndael implementation
# See: http://csrc.nist.gov/publications/FIPS/FIPS197/FIPS-197.pdf
class _AES:
'''Encapsulates the AES block cipher.
You generally should not need this. Use the AESModeOfOperation classes
below instead.'''
# fmt: off
# Number of rounds by key size
_NUMBER_OF_ROUNDS = {16: 10, 24: 12, 32: 14}
# Round constant words
_RCON = [ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36, 0x6c, 0xd8, 0xab, 0x4d, 0x9a, 0x2f, 0x5e, 0xbc, 0x63, 0xc6, 0x97, 0x35, 0x6a, 0xd4, 0xb3, 0x7d, 0xfa, 0xef, 0xc5, 0x91 ]
# S-box and Inverse S-box (S is for Substitution)
_S = [ 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16 ]
_S_INV =[ 0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb, 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb, 0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e, 0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25, 0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92, 0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84, 0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06, 0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b, 0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73, 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e, 0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b, 0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4, 0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f, 0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef, 0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61, 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d ]
# Transformations for encryption
_T1 = [ 0xc66363a5, 0xf87c7c84, 0xee777799, 0xf67b7b8d, 0xfff2f20d, 0xd66b6bbd, 0xde6f6fb1, 0x91c5c554, 0x60303050, 0x02010103, 0xce6767a9, 0x562b2b7d, 0xe7fefe19, 0xb5d7d762, 0x4dababe6, 0xec76769a, 0x8fcaca45, 0x1f82829d, 0x89c9c940, 0xfa7d7d87, 0xeffafa15, 0xb25959eb, 0x8e4747c9, 0xfbf0f00b, 0x41adadec, 0xb3d4d467, 0x5fa2a2fd, 0x45afafea, 0x239c9cbf, 0x53a4a4f7, 0xe4727296, 0x9bc0c05b, 0x75b7b7c2, 0xe1fdfd1c, 0x3d9393ae, 0x4c26266a, 0x6c36365a, 0x7e3f3f41, 0xf5f7f702, 0x83cccc4f, 0x6834345c, 0x51a5a5f4, 0xd1e5e534, 0xf9f1f108, 0xe2717193, 0xabd8d873, 0x62313153, 0x2a15153f, 0x0804040c, 0x95c7c752, 0x46232365, 0x9dc3c35e, 0x30181828, 0x379696a1, 0x0a05050f, 0x2f9a9ab5, 0x0e070709, 0x24121236, 0x1b80809b, 0xdfe2e23d, 0xcdebeb26, 0x4e272769, 0x7fb2b2cd, 0xea75759f, 0x1209091b, 0x1d83839e, 0x582c2c74, 0x341a1a2e, 0x361b1b2d, 0xdc6e6eb2, 0xb45a5aee, 0x5ba0a0fb, 0xa45252f6, 0x763b3b4d, 0xb7d6d661, 0x7db3b3ce, 0x5229297b, 0xdde3e33e, 0x5e2f2f71, 0x13848497, 0xa65353f5, 0xb9d1d168, 0x00000000, 0xc1eded2c, 0x40202060, 0xe3fcfc1f, 0x79b1b1c8, 0xb65b5bed, 0xd46a6abe, 0x8dcbcb46, 0x67bebed9, 0x7239394b, 0x944a4ade, 0x984c4cd4, 0xb05858e8, 0x85cfcf4a, 0xbbd0d06b, 0xc5efef2a, 0x4faaaae5, 0xedfbfb16, 0x864343c5, 0x9a4d4dd7, 0x66333355, 0x11858594, 0x8a4545cf, 0xe9f9f910, 0x04020206, 0xfe7f7f81, 0xa05050f0, 0x783c3c44, 0x259f9fba, 0x4ba8a8e3, 0xa25151f3, 0x5da3a3fe, 0x804040c0, 0x058f8f8a, 0x3f9292ad, 0x219d9dbc, 0x70383848, 0xf1f5f504, 0x63bcbcdf, 0x77b6b6c1, 0xafdada75, 0x42212163, 0x20101030, 0xe5ffff1a, 0xfdf3f30e, 0xbfd2d26d, 0x81cdcd4c, 0x180c0c14, 0x26131335, 0xc3ecec2f, 0xbe5f5fe1, 0x359797a2, 0x884444cc, 0x2e171739, 0x93c4c457, 0x55a7a7f2, 0xfc7e7e82, 0x7a3d3d47, 0xc86464ac, 0xba5d5de7, 0x3219192b, 0xe6737395, 0xc06060a0, 0x19818198, 0x9e4f4fd1, 0xa3dcdc7f, 0x44222266, 0x542a2a7e, 0x3b9090ab, 0x0b888883, 0x8c4646ca, 0xc7eeee29, 0x6bb8b8d3, 0x2814143c, 0xa7dede79, 0xbc5e5ee2, 0x160b0b1d, 0xaddbdb76, 0xdbe0e03b, 0x64323256, 0x743a3a4e, 0x140a0a1e, 0x924949db, 0x0c06060a, 0x4824246c, 0xb85c5ce4, 0x9fc2c25d, 0xbdd3d36e, 0x43acacef, 0xc46262a6, 0x399191a8, 0x319595a4, 0xd3e4e437, 0xf279798b, 0xd5e7e732, 0x8bc8c843, 0x6e373759, 0xda6d6db7, 0x018d8d8c, 0xb1d5d564, 0x9c4e4ed2, 0x49a9a9e0, 0xd86c6cb4, 0xac5656fa, 0xf3f4f407, 0xcfeaea25, 0xca6565af, 0xf47a7a8e, 0x47aeaee9, 0x10080818, 0x6fbabad5, 0xf0787888, 0x4a25256f, 0x5c2e2e72, 0x381c1c24, 0x57a6a6f1, 0x73b4b4c7, 0x97c6c651, 0xcbe8e823, 0xa1dddd7c, 0xe874749c, 0x3e1f1f21, 0x964b4bdd, 0x61bdbddc, 0x0d8b8b86, 0x0f8a8a85, 0xe0707090, 0x7c3e3e42, 0x71b5b5c4, 0xcc6666aa, 0x904848d8, 0x06030305, 0xf7f6f601, 0x1c0e0e12, 0xc26161a3, 0x6a35355f, 0xae5757f9, 0x69b9b9d0, 0x17868691, 0x99c1c158, 0x3a1d1d27, 0x279e9eb9, 0xd9e1e138, 0xebf8f813, 0x2b9898b3, 0x22111133, 0xd26969bb, 0xa9d9d970, 0x078e8e89, 0x339494a7, 0x2d9b9bb6, 0x3c1e1e22, 0x15878792, 0xc9e9e920, 0x87cece49, 0xaa5555ff, 0x50282878, 0xa5dfdf7a, 0x038c8c8f, 0x59a1a1f8, 0x09898980, 0x1a0d0d17, 0x65bfbfda, 0xd7e6e631, 0x844242c6, 0xd06868b8, 0x824141c3, 0x299999b0, 0x5a2d2d77, 0x1e0f0f11, 0x7bb0b0cb, 0xa85454fc, 0x6dbbbbd6, 0x2c16163a ]
_T2 = [ 0xa5c66363, 0x84f87c7c, 0x99ee7777, 0x8df67b7b, 0x0dfff2f2, 0xbdd66b6b, 0xb1de6f6f, 0x5491c5c5, 0x50603030, 0x03020101, 0xa9ce6767, 0x7d562b2b, 0x19e7fefe, 0x62b5d7d7, 0xe64dabab, 0x9aec7676, 0x458fcaca, 0x9d1f8282, 0x4089c9c9, 0x87fa7d7d, 0x15effafa, 0xebb25959, 0xc98e4747, 0x0bfbf0f0, 0xec41adad, 0x67b3d4d4, 0xfd5fa2a2, 0xea45afaf, 0xbf239c9c, 0xf753a4a4, 0x96e47272, 0x5b9bc0c0, 0xc275b7b7, 0x1ce1fdfd, 0xae3d9393, 0x6a4c2626, 0x5a6c3636, 0x417e3f3f, 0x02f5f7f7, 0x4f83cccc, 0x5c683434, 0xf451a5a5, 0x34d1e5e5, 0x08f9f1f1, 0x93e27171, 0x73abd8d8, 0x53623131, 0x3f2a1515, 0x0c080404, 0x5295c7c7, 0x65462323, 0x5e9dc3c3, 0x28301818, 0xa1379696, 0x0f0a0505, 0xb52f9a9a, 0x090e0707, 0x36241212, 0x9b1b8080, 0x3ddfe2e2, 0x26cdebeb, 0x694e2727, 0xcd7fb2b2, 0x9fea7575, 0x1b120909, 0x9e1d8383, 0x74582c2c, 0x2e341a1a, 0x2d361b1b, 0xb2dc6e6e, 0xeeb45a5a, 0xfb5ba0a0, 0xf6a45252, 0x4d763b3b, 0x61b7d6d6, 0xce7db3b3, 0x7b522929, 0x3edde3e3, 0x715e2f2f, 0x97138484, 0xf5a65353, 0x68b9d1d1, 0x00000000, 0x2cc1eded, 0x60402020, 0x1fe3fcfc, 0xc879b1b1, 0xedb65b5b, 0xbed46a6a, 0x468dcbcb, 0xd967bebe, 0x4b723939, 0xde944a4a, 0xd4984c4c, 0xe8b05858, 0x4a85cfcf, 0x6bbbd0d0, 0x2ac5efef, 0xe54faaaa, 0x16edfbfb, 0xc5864343, 0xd79a4d4d, 0x55663333, 0x94118585, 0xcf8a4545, 0x10e9f9f9, 0x06040202, 0x81fe7f7f, 0xf0a05050, 0x44783c3c, 0xba259f9f, 0xe34ba8a8, 0xf3a25151, 0xfe5da3a3, 0xc0804040, 0x8a058f8f, 0xad3f9292, 0xbc219d9d, 0x48703838, 0x04f1f5f5, 0xdf63bcbc, 0xc177b6b6, 0x75afdada, 0x63422121, 0x30201010, 0x1ae5ffff, 0x0efdf3f3, 0x6dbfd2d2, 0x4c81cdcd, 0x14180c0c, 0x35261313, 0x2fc3ecec, 0xe1be5f5f, 0xa2359797, 0xcc884444, 0x392e1717, 0x5793c4c4, 0xf255a7a7, 0x82fc7e7e, 0x477a3d3d, 0xacc86464, 0xe7ba5d5d, 0x2b321919, 0x95e67373, 0xa0c06060, 0x98198181, 0xd19e4f4f, 0x7fa3dcdc, 0x66442222, 0x7e542a2a, 0xab3b9090, 0x830b8888, 0xca8c4646, 0x29c7eeee, 0xd36bb8b8, 0x3c281414, 0x79a7dede, 0xe2bc5e5e, 0x1d160b0b, 0x76addbdb, 0x3bdbe0e0, 0x56643232, 0x4e743a3a, 0x1e140a0a, 0xdb924949, 0x0a0c0606, 0x6c482424, 0xe4b85c5c, 0x5d9fc2c2, 0x6ebdd3d3, 0xef43acac, 0xa6c46262, 0xa8399191, 0xa4319595, 0x37d3e4e4, 0x8bf27979, 0x32d5e7e7, 0x438bc8c8, 0x596e3737, 0xb7da6d6d, 0x8c018d8d, 0x64b1d5d5, 0xd29c4e4e, 0xe049a9a9, 0xb4d86c6c, 0xfaac5656, 0x07f3f4f4, 0x25cfeaea, 0xafca6565, 0x8ef47a7a, 0xe947aeae, 0x18100808, 0xd56fbaba, 0x88f07878, 0x6f4a2525, 0x725c2e2e, 0x24381c1c, 0xf157a6a6, 0xc773b4b4, 0x5197c6c6, 0x23cbe8e8, 0x7ca1dddd, 0x9ce87474, 0x213e1f1f, 0xdd964b4b, 0xdc61bdbd, 0x860d8b8b, 0x850f8a8a, 0x90e07070, 0x427c3e3e, 0xc471b5b5, 0xaacc6666, 0xd8904848, 0x05060303, 0x01f7f6f6, 0x121c0e0e, 0xa3c26161, 0x5f6a3535, 0xf9ae5757, 0xd069b9b9, 0x91178686, 0x5899c1c1, 0x273a1d1d, 0xb9279e9e, 0x38d9e1e1, 0x13ebf8f8, 0xb32b9898, 0x33221111, 0xbbd26969, 0x70a9d9d9, 0x89078e8e, 0xa7339494, 0xb62d9b9b, 0x223c1e1e, 0x92158787, 0x20c9e9e9, 0x4987cece, 0xffaa5555, 0x78502828, 0x7aa5dfdf, 0x8f038c8c, 0xf859a1a1, 0x80098989, 0x171a0d0d, 0xda65bfbf, 0x31d7e6e6, 0xc6844242, 0xb8d06868, 0xc3824141, 0xb0299999, 0x775a2d2d, 0x111e0f0f, 0xcb7bb0b0, 0xfca85454, 0xd66dbbbb, 0x3a2c1616 ]
_T3 = [ 0x63a5c663, 0x7c84f87c, 0x7799ee77, 0x7b8df67b, 0xf20dfff2, 0x6bbdd66b, 0x6fb1de6f, 0xc55491c5, 0x30506030, 0x01030201, 0x67a9ce67, 0x2b7d562b, 0xfe19e7fe, 0xd762b5d7, 0xabe64dab, 0x769aec76, 0xca458fca, 0x829d1f82, 0xc94089c9, 0x7d87fa7d, 0xfa15effa, 0x59ebb259, 0x47c98e47, 0xf00bfbf0, 0xadec41ad, 0xd467b3d4, 0xa2fd5fa2, 0xafea45af, 0x9cbf239c, 0xa4f753a4, 0x7296e472, 0xc05b9bc0, 0xb7c275b7, 0xfd1ce1fd, 0x93ae3d93, 0x266a4c26, 0x365a6c36, 0x3f417e3f, 0xf702f5f7, 0xcc4f83cc, 0x345c6834, 0xa5f451a5, 0xe534d1e5, 0xf108f9f1, 0x7193e271, 0xd873abd8, 0x31536231, 0x153f2a15, 0x040c0804, 0xc75295c7, 0x23654623, 0xc35e9dc3, 0x18283018, 0x96a13796, 0x050f0a05, 0x9ab52f9a, 0x07090e07, 0x12362412, 0x809b1b80, 0xe23ddfe2, 0xeb26cdeb, 0x27694e27, 0xb2cd7fb2, 0x759fea75, 0x091b1209, 0x839e1d83, 0x2c74582c, 0x1a2e341a, 0x1b2d361b, 0x6eb2dc6e, 0x5aeeb45a, 0xa0fb5ba0, 0x52f6a452, 0x3b4d763b, 0xd661b7d6, 0xb3ce7db3, 0x297b5229, 0xe33edde3, 0x2f715e2f, 0x84971384, 0x53f5a653, 0xd168b9d1, 0x00000000, 0xed2cc1ed, 0x20604020, 0xfc1fe3fc, 0xb1c879b1, 0x5bedb65b, 0x6abed46a, 0xcb468dcb, 0xbed967be, 0x394b7239, 0x4ade944a, 0x4cd4984c, 0x58e8b058, 0xcf4a85cf, 0xd06bbbd0, 0xef2ac5ef, 0xaae54faa, 0xfb16edfb, 0x43c58643, 0x4dd79a4d, 0x33556633, 0x85941185, 0x45cf8a45, 0xf910e9f9, 0x02060402, 0x7f81fe7f, 0x50f0a050, 0x3c44783c, 0x9fba259f, 0xa8e34ba8, 0x51f3a251, 0xa3fe5da3, 0x40c08040, 0x8f8a058f, 0x92ad3f92, 0x9dbc219d, 0x38487038, 0xf504f1f5, 0xbcdf63bc, 0xb6c177b6, 0xda75afda, 0x21634221, 0x10302010, 0xff1ae5ff, 0xf30efdf3, 0xd26dbfd2, 0xcd4c81cd, 0x0c14180c, 0x13352613, 0xec2fc3ec, 0x5fe1be5f, 0x97a23597, 0x44cc8844, 0x17392e17, 0xc45793c4, 0xa7f255a7, 0x7e82fc7e, 0x3d477a3d, 0x64acc864, 0x5de7ba5d, 0x192b3219, 0x7395e673, 0x60a0c060, 0x81981981, 0x4fd19e4f, 0xdc7fa3dc, 0x22664422, 0x2a7e542a, 0x90ab3b90, 0x88830b88, 0x46ca8c46, 0xee29c7ee, 0xb8d36bb8, 0x143c2814, 0xde79a7de, 0x5ee2bc5e, 0x0b1d160b, 0xdb76addb, 0xe03bdbe0, 0x32566432, 0x3a4e743a, 0x0a1e140a, 0x49db9249, 0x060a0c06, 0x246c4824, 0x5ce4b85c, 0xc25d9fc2, 0xd36ebdd3, 0xacef43ac, 0x62a6c462, 0x91a83991, 0x95a43195, 0xe437d3e4, 0x798bf279, 0xe732d5e7, 0xc8438bc8, 0x37596e37, 0x6db7da6d, 0x8d8c018d, 0xd564b1d5, 0x4ed29c4e, 0xa9e049a9, 0x6cb4d86c, 0x56faac56, 0xf407f3f4, 0xea25cfea, 0x65afca65, 0x7a8ef47a, 0xaee947ae, 0x08181008, 0xbad56fba, 0x7888f078, 0x256f4a25, 0x2e725c2e, 0x1c24381c, 0xa6f157a6, 0xb4c773b4, 0xc65197c6, 0xe823cbe8, 0xdd7ca1dd, 0x749ce874, 0x1f213e1f, 0x4bdd964b, 0xbddc61bd, 0x8b860d8b, 0x8a850f8a, 0x7090e070, 0x3e427c3e, 0xb5c471b5, 0x66aacc66, 0x48d89048, 0x03050603, 0xf601f7f6, 0x0e121c0e, 0x61a3c261, 0x355f6a35, 0x57f9ae57, 0xb9d069b9, 0x86911786, 0xc15899c1, 0x1d273a1d, 0x9eb9279e, 0xe138d9e1, 0xf813ebf8, 0x98b32b98, 0x11332211, 0x69bbd269, 0xd970a9d9, 0x8e89078e, 0x94a73394, 0x9bb62d9b, 0x1e223c1e, 0x87921587, 0xe920c9e9, 0xce4987ce, 0x55ffaa55, 0x28785028, 0xdf7aa5df, 0x8c8f038c, 0xa1f859a1, 0x89800989, 0x0d171a0d, 0xbfda65bf, 0xe631d7e6, 0x42c68442, 0x68b8d068, 0x41c38241, 0x99b02999, 0x2d775a2d, 0x0f111e0f, 0xb0cb7bb0, 0x54fca854, 0xbbd66dbb, 0x163a2c16 ]
_T4 = [ 0x6363a5c6, 0x7c7c84f8, 0x777799ee, 0x7b7b8df6, 0xf2f20dff, 0x6b6bbdd6, 0x6f6fb1de, 0xc5c55491, 0x30305060, 0x01010302, 0x6767a9ce, 0x2b2b7d56, 0xfefe19e7, 0xd7d762b5, 0xababe64d, 0x76769aec, 0xcaca458f, 0x82829d1f, 0xc9c94089, 0x7d7d87fa, 0xfafa15ef, 0x5959ebb2, 0x4747c98e, 0xf0f00bfb, 0xadadec41, 0xd4d467b3, 0xa2a2fd5f, 0xafafea45, 0x9c9cbf23, 0xa4a4f753, 0x727296e4, 0xc0c05b9b, 0xb7b7c275, 0xfdfd1ce1, 0x9393ae3d, 0x26266a4c, 0x36365a6c, 0x3f3f417e, 0xf7f702f5, 0xcccc4f83, 0x34345c68, 0xa5a5f451, 0xe5e534d1, 0xf1f108f9, 0x717193e2, 0xd8d873ab, 0x31315362, 0x15153f2a, 0x04040c08, 0xc7c75295, 0x23236546, 0xc3c35e9d, 0x18182830, 0x9696a137, 0x05050f0a, 0x9a9ab52f, 0x0707090e, 0x12123624, 0x80809b1b, 0xe2e23ddf, 0xebeb26cd, 0x2727694e, 0xb2b2cd7f, 0x75759fea, 0x09091b12, 0x83839e1d, 0x2c2c7458, 0x1a1a2e34, 0x1b1b2d36, 0x6e6eb2dc, 0x5a5aeeb4, 0xa0a0fb5b, 0x5252f6a4, 0x3b3b4d76, 0xd6d661b7, 0xb3b3ce7d, 0x29297b52, 0xe3e33edd, 0x2f2f715e, 0x84849713, 0x5353f5a6, 0xd1d168b9, 0x00000000, 0xeded2cc1, 0x20206040, 0xfcfc1fe3, 0xb1b1c879, 0x5b5bedb6, 0x6a6abed4, 0xcbcb468d, 0xbebed967, 0x39394b72, 0x4a4ade94, 0x4c4cd498, 0x5858e8b0, 0xcfcf4a85, 0xd0d06bbb, 0xefef2ac5, 0xaaaae54f, 0xfbfb16ed, 0x4343c586, 0x4d4dd79a, 0x33335566, 0x85859411, 0x4545cf8a, 0xf9f910e9, 0x02020604, 0x7f7f81fe, 0x5050f0a0, 0x3c3c4478, 0x9f9fba25, 0xa8a8e34b, 0x5151f3a2, 0xa3a3fe5d, 0x4040c080, 0x8f8f8a05, 0x9292ad3f, 0x9d9dbc21, 0x38384870, 0xf5f504f1, 0xbcbcdf63, 0xb6b6c177, 0xdada75af, 0x21216342, 0x10103020, 0xffff1ae5, 0xf3f30efd, 0xd2d26dbf, 0xcdcd4c81, 0x0c0c1418, 0x13133526, 0xecec2fc3, 0x5f5fe1be, 0x9797a235, 0x4444cc88, 0x1717392e, 0xc4c45793, 0xa7a7f255, 0x7e7e82fc, 0x3d3d477a, 0x6464acc8, 0x5d5de7ba, 0x19192b32, 0x737395e6, 0x6060a0c0, 0x81819819, 0x4f4fd19e, 0xdcdc7fa3, 0x22226644, 0x2a2a7e54, 0x9090ab3b, 0x8888830b, 0x4646ca8c, 0xeeee29c7, 0xb8b8d36b, 0x14143c28, 0xdede79a7, 0x5e5ee2bc, 0x0b0b1d16, 0xdbdb76ad, 0xe0e03bdb, 0x32325664, 0x3a3a4e74, 0x0a0a1e14, 0x4949db92, 0x06060a0c, 0x24246c48, 0x5c5ce4b8, 0xc2c25d9f, 0xd3d36ebd, 0xacacef43, 0x6262a6c4, 0x9191a839, 0x9595a431, 0xe4e437d3, 0x79798bf2, 0xe7e732d5, 0xc8c8438b, 0x3737596e, 0x6d6db7da, 0x8d8d8c01, 0xd5d564b1, 0x4e4ed29c, 0xa9a9e049, 0x6c6cb4d8, 0x5656faac, 0xf4f407f3, 0xeaea25cf, 0x6565afca, 0x7a7a8ef4, 0xaeaee947, 0x08081810, 0xbabad56f, 0x787888f0, 0x25256f4a, 0x2e2e725c, 0x1c1c2438, 0xa6a6f157, 0xb4b4c773, 0xc6c65197, 0xe8e823cb, 0xdddd7ca1, 0x74749ce8, 0x1f1f213e, 0x4b4bdd96, 0xbdbddc61, 0x8b8b860d, 0x8a8a850f, 0x707090e0, 0x3e3e427c, 0xb5b5c471, 0x6666aacc, 0x4848d890, 0x03030506, 0xf6f601f7, 0x0e0e121c, 0x6161a3c2, 0x35355f6a, 0x5757f9ae, 0xb9b9d069, 0x86869117, 0xc1c15899, 0x1d1d273a, 0x9e9eb927, 0xe1e138d9, 0xf8f813eb, 0x9898b32b, 0x11113322, 0x6969bbd2, 0xd9d970a9, 0x8e8e8907, 0x9494a733, 0x9b9bb62d, 0x1e1e223c, 0x87879215, 0xe9e920c9, 0xcece4987, 0x5555ffaa, 0x28287850, 0xdfdf7aa5, 0x8c8c8f03, 0xa1a1f859, 0x89898009, 0x0d0d171a, 0xbfbfda65, 0xe6e631d7, 0x4242c684, 0x6868b8d0, 0x4141c382, 0x9999b029, 0x2d2d775a, 0x0f0f111e, 0xb0b0cb7b, 0x5454fca8, 0xbbbbd66d, 0x16163a2c ]
# Transformations for decryption
_T5 = [ 0x51f4a750, 0x7e416553, 0x1a17a4c3, 0x3a275e96, 0x3bab6bcb, 0x1f9d45f1, 0xacfa58ab, 0x4be30393, 0x2030fa55, 0xad766df6, 0x88cc7691, 0xf5024c25, 0x4fe5d7fc, 0xc52acbd7, 0x26354480, 0xb562a38f, 0xdeb15a49, 0x25ba1b67, 0x45ea0e98, 0x5dfec0e1, 0xc32f7502, 0x814cf012, 0x8d4697a3, 0x6bd3f9c6, 0x038f5fe7, 0x15929c95, 0xbf6d7aeb, 0x955259da, 0xd4be832d, 0x587421d3, 0x49e06929, 0x8ec9c844, 0x75c2896a, 0xf48e7978, 0x99583e6b, 0x27b971dd, 0xbee14fb6, 0xf088ad17, 0xc920ac66, 0x7dce3ab4, 0x63df4a18, 0xe51a3182, 0x97513360, 0x62537f45, 0xb16477e0, 0xbb6bae84, 0xfe81a01c, 0xf9082b94, 0x70486858, 0x8f45fd19, 0x94de6c87, 0x527bf8b7, 0xab73d323, 0x724b02e2, 0xe31f8f57, 0x6655ab2a, 0xb2eb2807, 0x2fb5c203, 0x86c57b9a, 0xd33708a5, 0x302887f2, 0x23bfa5b2, 0x02036aba, 0xed16825c, 0x8acf1c2b, 0xa779b492, 0xf307f2f0, 0x4e69e2a1, 0x65daf4cd, 0x0605bed5, 0xd134621f, 0xc4a6fe8a, 0x342e539d, 0xa2f355a0, 0x058ae132, 0xa4f6eb75, 0x0b83ec39, 0x4060efaa, 0x5e719f06, 0xbd6e1051, 0x3e218af9, 0x96dd063d, 0xdd3e05ae, 0x4de6bd46, 0x91548db5, 0x71c45d05, 0x0406d46f, 0x605015ff, 0x1998fb24, 0xd6bde997, 0x894043cc, 0x67d99e77, 0xb0e842bd, 0x07898b88, 0xe7195b38, 0x79c8eedb, 0xa17c0a47, 0x7c420fe9, 0xf8841ec9, 0x00000000, 0x09808683, 0x322bed48, 0x1e1170ac, 0x6c5a724e, 0xfd0efffb, 0x0f853856, 0x3daed51e, 0x362d3927, 0x0a0fd964, 0x685ca621, 0x9b5b54d1, 0x24362e3a, 0x0c0a67b1, 0x9357e70f, 0xb4ee96d2, 0x1b9b919e, 0x80c0c54f, 0x61dc20a2, 0x5a774b69, 0x1c121a16, 0xe293ba0a, 0xc0a02ae5, 0x3c22e043, 0x121b171d, 0x0e090d0b, 0xf28bc7ad, 0x2db6a8b9, 0x141ea9c8, 0x57f11985, 0xaf75074c, 0xee99ddbb, 0xa37f60fd, 0xf701269f, 0x5c72f5bc, 0x44663bc5, 0x5bfb7e34, 0x8b432976, 0xcb23c6dc, 0xb6edfc68, 0xb8e4f163, 0xd731dcca, 0x42638510, 0x13972240, 0x84c61120, 0x854a247d, 0xd2bb3df8, 0xaef93211, 0xc729a16d, 0x1d9e2f4b, 0xdcb230f3, 0x0d8652ec, 0x77c1e3d0, 0x2bb3166c, 0xa970b999, 0x119448fa, 0x47e96422, 0xa8fc8cc4, 0xa0f03f1a, 0x567d2cd8, 0x223390ef, 0x87494ec7, 0xd938d1c1, 0x8ccaa2fe, 0x98d40b36, 0xa6f581cf, 0xa57ade28, 0xdab78e26, 0x3fadbfa4, 0x2c3a9de4, 0x5078920d, 0x6a5fcc9b, 0x547e4662, 0xf68d13c2, 0x90d8b8e8, 0x2e39f75e, 0x82c3aff5, 0x9f5d80be, 0x69d0937c, 0x6fd52da9, 0xcf2512b3, 0xc8ac993b, 0x10187da7, 0xe89c636e, 0xdb3bbb7b, 0xcd267809, 0x6e5918f4, 0xec9ab701, 0x834f9aa8, 0xe6956e65, 0xaaffe67e, 0x21bccf08, 0xef15e8e6, 0xbae79bd9, 0x4a6f36ce, 0xea9f09d4, 0x29b07cd6, 0x31a4b2af, 0x2a3f2331, 0xc6a59430, 0x35a266c0, 0x744ebc37, 0xfc82caa6, 0xe090d0b0, 0x33a7d815, 0xf104984a, 0x41ecdaf7, 0x7fcd500e, 0x1791f62f, 0x764dd68d, 0x43efb04d, 0xccaa4d54, 0xe49604df, 0x9ed1b5e3, 0x4c6a881b, 0xc12c1fb8, 0x4665517f, 0x9d5eea04, 0x018c355d, 0xfa877473, 0xfb0b412e, 0xb3671d5a, 0x92dbd252, 0xe9105633, 0x6dd64713, 0x9ad7618c, 0x37a10c7a, 0x59f8148e, 0xeb133c89, 0xcea927ee, 0xb761c935, 0xe11ce5ed, 0x7a47b13c, 0x9cd2df59, 0x55f2733f, 0x1814ce79, 0x73c737bf, 0x53f7cdea, 0x5ffdaa5b, 0xdf3d6f14, 0x7844db86, 0xcaaff381, 0xb968c43e, 0x3824342c, 0xc2a3405f, 0x161dc372, 0xbce2250c, 0x283c498b, 0xff0d9541, 0x39a80171, 0x080cb3de, 0xd8b4e49c, 0x6456c190, 0x7bcb8461, 0xd532b670, 0x486c5c74, 0xd0b85742 ]
_T6 = [ 0x5051f4a7, 0x537e4165, 0xc31a17a4, 0x963a275e, 0xcb3bab6b, 0xf11f9d45, 0xabacfa58, 0x934be303, 0x552030fa, 0xf6ad766d, 0x9188cc76, 0x25f5024c, 0xfc4fe5d7, 0xd7c52acb, 0x80263544, 0x8fb562a3, 0x49deb15a, 0x6725ba1b, 0x9845ea0e, 0xe15dfec0, 0x02c32f75, 0x12814cf0, 0xa38d4697, 0xc66bd3f9, 0xe7038f5f, 0x9515929c, 0xebbf6d7a, 0xda955259, 0x2dd4be83, 0xd3587421, 0x2949e069, 0x448ec9c8, 0x6a75c289, 0x78f48e79, 0x6b99583e, 0xdd27b971, 0xb6bee14f, 0x17f088ad, 0x66c920ac, 0xb47dce3a, 0x1863df4a, 0x82e51a31, 0x60975133, 0x4562537f, 0xe0b16477, 0x84bb6bae, 0x1cfe81a0, 0x94f9082b, 0x58704868, 0x198f45fd, 0x8794de6c, 0xb7527bf8, 0x23ab73d3, 0xe2724b02, 0x57e31f8f, 0x2a6655ab, 0x07b2eb28, 0x032fb5c2, 0x9a86c57b, 0xa5d33708, 0xf2302887, 0xb223bfa5, 0xba02036a, 0x5ced1682, 0x2b8acf1c, 0x92a779b4, 0xf0f307f2, 0xa14e69e2, 0xcd65daf4, 0xd50605be, 0x1fd13462, 0x8ac4a6fe, 0x9d342e53, 0xa0a2f355, 0x32058ae1, 0x75a4f6eb, 0x390b83ec, 0xaa4060ef, 0x065e719f, 0x51bd6e10, 0xf93e218a, 0x3d96dd06, 0xaedd3e05, 0x464de6bd, 0xb591548d, 0x0571c45d, 0x6f0406d4, 0xff605015, 0x241998fb, 0x97d6bde9, 0xcc894043, 0x7767d99e, 0xbdb0e842, 0x8807898b, 0x38e7195b, 0xdb79c8ee, 0x47a17c0a, 0xe97c420f, 0xc9f8841e, 0x00000000, 0x83098086, 0x48322bed, 0xac1e1170, 0x4e6c5a72, 0xfbfd0eff, 0x560f8538, 0x1e3daed5, 0x27362d39, 0x640a0fd9, 0x21685ca6, 0xd19b5b54, 0x3a24362e, 0xb10c0a67, 0x0f9357e7, 0xd2b4ee96, 0x9e1b9b91, 0x4f80c0c5, 0xa261dc20, 0x695a774b, 0x161c121a, 0x0ae293ba, 0xe5c0a02a, 0x433c22e0, 0x1d121b17, 0x0b0e090d, 0xadf28bc7, 0xb92db6a8, 0xc8141ea9, 0x8557f119, 0x4caf7507, 0xbbee99dd, 0xfda37f60, 0x9ff70126, 0xbc5c72f5, 0xc544663b, 0x345bfb7e, 0x768b4329, 0xdccb23c6, 0x68b6edfc, 0x63b8e4f1, 0xcad731dc, 0x10426385, 0x40139722, 0x2084c611, 0x7d854a24, 0xf8d2bb3d, 0x11aef932, 0x6dc729a1, 0x4b1d9e2f, 0xf3dcb230, 0xec0d8652, 0xd077c1e3, 0x6c2bb316, 0x99a970b9, 0xfa119448, 0x2247e964, 0xc4a8fc8c, 0x1aa0f03f, 0xd8567d2c, 0xef223390, 0xc787494e, 0xc1d938d1, 0xfe8ccaa2, 0x3698d40b, 0xcfa6f581, 0x28a57ade, 0x26dab78e, 0xa43fadbf, 0xe42c3a9d, 0x0d507892, 0x9b6a5fcc, 0x62547e46, 0xc2f68d13, 0xe890d8b8, 0x5e2e39f7, 0xf582c3af, 0xbe9f5d80, 0x7c69d093, 0xa96fd52d, 0xb3cf2512, 0x3bc8ac99, 0xa710187d, 0x6ee89c63, 0x7bdb3bbb, 0x09cd2678, 0xf46e5918, 0x01ec9ab7, 0xa8834f9a, 0x65e6956e, 0x7eaaffe6, 0x0821bccf, 0xe6ef15e8, 0xd9bae79b, 0xce4a6f36, 0xd4ea9f09, 0xd629b07c, 0xaf31a4b2, 0x312a3f23, 0x30c6a594, 0xc035a266, 0x37744ebc, 0xa6fc82ca, 0xb0e090d0, 0x1533a7d8, 0x4af10498, 0xf741ecda, 0x0e7fcd50, 0x2f1791f6, 0x8d764dd6, 0x4d43efb0, 0x54ccaa4d, 0xdfe49604, 0xe39ed1b5, 0x1b4c6a88, 0xb8c12c1f, 0x7f466551, 0x049d5eea, 0x5d018c35, 0x73fa8774, 0x2efb0b41, 0x5ab3671d, 0x5292dbd2, 0x33e91056, 0x136dd647, 0x8c9ad761, 0x7a37a10c, 0x8e59f814, 0x89eb133c, 0xeecea927, 0x35b761c9, 0xede11ce5, 0x3c7a47b1, 0x599cd2df, 0x3f55f273, 0x791814ce, 0xbf73c737, 0xea53f7cd, 0x5b5ffdaa, 0x14df3d6f, 0x867844db, 0x81caaff3, 0x3eb968c4, 0x2c382434, 0x5fc2a340, 0x72161dc3, 0x0cbce225, 0x8b283c49, 0x41ff0d95, 0x7139a801, 0xde080cb3, 0x9cd8b4e4, 0x906456c1, 0x617bcb84, 0x70d532b6, 0x74486c5c, 0x42d0b857 ]
_T7 = [ 0xa75051f4, 0x65537e41, 0xa4c31a17, 0x5e963a27, 0x6bcb3bab, 0x45f11f9d, 0x58abacfa, 0x03934be3, 0xfa552030, 0x6df6ad76, 0x769188cc, 0x4c25f502, 0xd7fc4fe5, 0xcbd7c52a, 0x44802635, 0xa38fb562, 0x5a49deb1, 0x1b6725ba, 0x0e9845ea, 0xc0e15dfe, 0x7502c32f, 0xf012814c, 0x97a38d46, 0xf9c66bd3, 0x5fe7038f, 0x9c951592, 0x7aebbf6d, 0x59da9552, 0x832dd4be, 0x21d35874, 0x692949e0, 0xc8448ec9, 0x896a75c2, 0x7978f48e, 0x3e6b9958, 0x71dd27b9, 0x4fb6bee1, 0xad17f088, 0xac66c920, 0x3ab47dce, 0x4a1863df, 0x3182e51a, 0x33609751, 0x7f456253, 0x77e0b164, 0xae84bb6b, 0xa01cfe81, 0x2b94f908, 0x68587048, 0xfd198f45, 0x6c8794de, 0xf8b7527b, 0xd323ab73, 0x02e2724b, 0x8f57e31f, 0xab2a6655, 0x2807b2eb, 0xc2032fb5, 0x7b9a86c5, 0x08a5d337, 0x87f23028, 0xa5b223bf, 0x6aba0203, 0x825ced16, 0x1c2b8acf, 0xb492a779, 0xf2f0f307, 0xe2a14e69, 0xf4cd65da, 0xbed50605, 0x621fd134, 0xfe8ac4a6, 0x539d342e, 0x55a0a2f3, 0xe132058a, 0xeb75a4f6, 0xec390b83, 0xefaa4060, 0x9f065e71, 0x1051bd6e, 0x8af93e21, 0x063d96dd, 0x05aedd3e, 0xbd464de6, 0x8db59154, 0x5d0571c4, 0xd46f0406, 0x15ff6050, 0xfb241998, 0xe997d6bd, 0x43cc8940, 0x9e7767d9, 0x42bdb0e8, 0x8b880789, 0x5b38e719, 0xeedb79c8, 0x0a47a17c, 0x0fe97c42, 0x1ec9f884, 0x00000000, 0x86830980, 0xed48322b, 0x70ac1e11, 0x724e6c5a, 0xfffbfd0e, 0x38560f85, 0xd51e3dae, 0x3927362d, 0xd9640a0f, 0xa621685c, 0x54d19b5b, 0x2e3a2436, 0x67b10c0a, 0xe70f9357, 0x96d2b4ee, 0x919e1b9b, 0xc54f80c0, 0x20a261dc, 0x4b695a77, 0x1a161c12, 0xba0ae293, 0x2ae5c0a0, 0xe0433c22, 0x171d121b, 0x0d0b0e09, 0xc7adf28b, 0xa8b92db6, 0xa9c8141e, 0x198557f1, 0x074caf75, 0xddbbee99, 0x60fda37f, 0x269ff701, 0xf5bc5c72, 0x3bc54466, 0x7e345bfb, 0x29768b43, 0xc6dccb23, 0xfc68b6ed, 0xf163b8e4, 0xdccad731, 0x85104263, 0x22401397, 0x112084c6, 0x247d854a, 0x3df8d2bb, 0x3211aef9, 0xa16dc729, 0x2f4b1d9e, 0x30f3dcb2, 0x52ec0d86, 0xe3d077c1, 0x166c2bb3, 0xb999a970, 0x48fa1194, 0x642247e9, 0x8cc4a8fc, 0x3f1aa0f0, 0x2cd8567d, 0x90ef2233, 0x4ec78749, 0xd1c1d938, 0xa2fe8cca, 0x0b3698d4, 0x81cfa6f5, 0xde28a57a, 0x8e26dab7, 0xbfa43fad, 0x9de42c3a, 0x920d5078, 0xcc9b6a5f, 0x4662547e, 0x13c2f68d, 0xb8e890d8, 0xf75e2e39, 0xaff582c3, 0x80be9f5d, 0x937c69d0, 0x2da96fd5, 0x12b3cf25, 0x993bc8ac, 0x7da71018, 0x636ee89c, 0xbb7bdb3b, 0x7809cd26, 0x18f46e59, 0xb701ec9a, 0x9aa8834f, 0x6e65e695, 0xe67eaaff, 0xcf0821bc, 0xe8e6ef15, 0x9bd9bae7, 0x36ce4a6f, 0x09d4ea9f, 0x7cd629b0, 0xb2af31a4, 0x23312a3f, 0x9430c6a5, 0x66c035a2, 0xbc37744e, 0xcaa6fc82, 0xd0b0e090, 0xd81533a7, 0x984af104, 0xdaf741ec, 0x500e7fcd, 0xf62f1791, 0xd68d764d, 0xb04d43ef, 0x4d54ccaa, 0x04dfe496, 0xb5e39ed1, 0x881b4c6a, 0x1fb8c12c, 0x517f4665, 0xea049d5e, 0x355d018c, 0x7473fa87, 0x412efb0b, 0x1d5ab367, 0xd25292db, 0x5633e910, 0x47136dd6, 0x618c9ad7, 0x0c7a37a1, 0x148e59f8, 0x3c89eb13, 0x27eecea9, 0xc935b761, 0xe5ede11c, 0xb13c7a47, 0xdf599cd2, 0x733f55f2, 0xce791814, 0x37bf73c7, 0xcdea53f7, 0xaa5b5ffd, 0x6f14df3d, 0xdb867844, 0xf381caaf, 0xc43eb968, 0x342c3824, 0x405fc2a3, 0xc372161d, 0x250cbce2, 0x498b283c, 0x9541ff0d, 0x017139a8, 0xb3de080c, 0xe49cd8b4, 0xc1906456, 0x84617bcb, 0xb670d532, 0x5c74486c, 0x5742d0b8 ]
_T8 = [ 0xf4a75051, 0x4165537e, 0x17a4c31a, 0x275e963a, 0xab6bcb3b, 0x9d45f11f, 0xfa58abac, 0xe303934b, 0x30fa5520, 0x766df6ad, 0xcc769188, 0x024c25f5, 0xe5d7fc4f, 0x2acbd7c5, 0x35448026, 0x62a38fb5, 0xb15a49de, 0xba1b6725, 0xea0e9845, 0xfec0e15d, 0x2f7502c3, 0x4cf01281, 0x4697a38d, 0xd3f9c66b, 0x8f5fe703, 0x929c9515, 0x6d7aebbf, 0x5259da95, 0xbe832dd4, 0x7421d358, 0xe0692949, 0xc9c8448e, 0xc2896a75, 0x8e7978f4, 0x583e6b99, 0xb971dd27, 0xe14fb6be, 0x88ad17f0, 0x20ac66c9, 0xce3ab47d, 0xdf4a1863, 0x1a3182e5, 0x51336097, 0x537f4562, 0x6477e0b1, 0x6bae84bb, 0x81a01cfe, 0x082b94f9, 0x48685870, 0x45fd198f, 0xde6c8794, 0x7bf8b752, 0x73d323ab, 0x4b02e272, 0x1f8f57e3, 0x55ab2a66, 0xeb2807b2, 0xb5c2032f, 0xc57b9a86, 0x3708a5d3, 0x2887f230, 0xbfa5b223, 0x036aba02, 0x16825ced, 0xcf1c2b8a, 0x79b492a7, 0x07f2f0f3, 0x69e2a14e, 0xdaf4cd65, 0x05bed506, 0x34621fd1, 0xa6fe8ac4, 0x2e539d34, 0xf355a0a2, 0x8ae13205, 0xf6eb75a4, 0x83ec390b, 0x60efaa40, 0x719f065e, 0x6e1051bd, 0x218af93e, 0xdd063d96, 0x3e05aedd, 0xe6bd464d, 0x548db591, 0xc45d0571, 0x06d46f04, 0x5015ff60, 0x98fb2419, 0xbde997d6, 0x4043cc89, 0xd99e7767, 0xe842bdb0, 0x898b8807, 0x195b38e7, 0xc8eedb79, 0x7c0a47a1, 0x420fe97c, 0x841ec9f8, 0x00000000, 0x80868309, 0x2bed4832, 0x1170ac1e, 0x5a724e6c, 0x0efffbfd, 0x8538560f, 0xaed51e3d, 0x2d392736, 0x0fd9640a, 0x5ca62168, 0x5b54d19b, 0x362e3a24, 0x0a67b10c, 0x57e70f93, 0xee96d2b4, 0x9b919e1b, 0xc0c54f80, 0xdc20a261, 0x774b695a, 0x121a161c, 0x93ba0ae2, 0xa02ae5c0, 0x22e0433c, 0x1b171d12, 0x090d0b0e, 0x8bc7adf2, 0xb6a8b92d, 0x1ea9c814, 0xf1198557, 0x75074caf, 0x99ddbbee, 0x7f60fda3, 0x01269ff7, 0x72f5bc5c, 0x663bc544, 0xfb7e345b, 0x4329768b, 0x23c6dccb, 0xedfc68b6, 0xe4f163b8, 0x31dccad7, 0x63851042, 0x97224013, 0xc6112084, 0x4a247d85, 0xbb3df8d2, 0xf93211ae, 0x29a16dc7, 0x9e2f4b1d, 0xb230f3dc, 0x8652ec0d, 0xc1e3d077, 0xb3166c2b, 0x70b999a9, 0x9448fa11, 0xe9642247, 0xfc8cc4a8, 0xf03f1aa0, 0x7d2cd856, 0x3390ef22, 0x494ec787, 0x38d1c1d9, 0xcaa2fe8c, 0xd40b3698, 0xf581cfa6, 0x7ade28a5, 0xb78e26da, 0xadbfa43f, 0x3a9de42c, 0x78920d50, 0x5fcc9b6a, 0x7e466254, 0x8d13c2f6, 0xd8b8e890, 0x39f75e2e, 0xc3aff582, 0x5d80be9f, 0xd0937c69, 0xd52da96f, 0x2512b3cf, 0xac993bc8, 0x187da710, 0x9c636ee8, 0x3bbb7bdb, 0x267809cd, 0x5918f46e, 0x9ab701ec, 0x4f9aa883, 0x956e65e6, 0xffe67eaa, 0xbccf0821, 0x15e8e6ef, 0xe79bd9ba, 0x6f36ce4a, 0x9f09d4ea, 0xb07cd629, 0xa4b2af31, 0x3f23312a, 0xa59430c6, 0xa266c035, 0x4ebc3774, 0x82caa6fc, 0x90d0b0e0, 0xa7d81533, 0x04984af1, 0xecdaf741, 0xcd500e7f, 0x91f62f17, 0x4dd68d76, 0xefb04d43, 0xaa4d54cc, 0x9604dfe4, 0xd1b5e39e, 0x6a881b4c, 0x2c1fb8c1, 0x65517f46, 0x5eea049d, 0x8c355d01, 0x877473fa, 0x0b412efb, 0x671d5ab3, 0xdbd25292, 0x105633e9, 0xd647136d, 0xd7618c9a, 0xa10c7a37, 0xf8148e59, 0x133c89eb, 0xa927eece, 0x61c935b7, 0x1ce5ede1, 0x47b13c7a, 0xd2df599c, 0xf2733f55, 0x14ce7918, 0xc737bf73, 0xf7cdea53, 0xfdaa5b5f, 0x3d6f14df, 0x44db8678, 0xaff381ca, 0x68c43eb9, 0x24342c38, 0xa3405fc2, 0x1dc37216, 0xe2250cbc, 0x3c498b28, 0x0d9541ff, 0xa8017139, 0x0cb3de08, 0xb4e49cd8, 0x56c19064, 0xcb84617b, 0x32b670d5, 0x6c5c7448, 0xb85742d0 ]
# Transformations for decryption key expansion
_U1 = [ 0x00000000, 0x0e090d0b, 0x1c121a16, 0x121b171d, 0x3824342c, 0x362d3927, 0x24362e3a, 0x2a3f2331, 0x70486858, 0x7e416553, 0x6c5a724e, 0x62537f45, 0x486c5c74, 0x4665517f, 0x547e4662, 0x5a774b69, 0xe090d0b0, 0xee99ddbb, 0xfc82caa6, 0xf28bc7ad, 0xd8b4e49c, 0xd6bde997, 0xc4a6fe8a, 0xcaaff381, 0x90d8b8e8, 0x9ed1b5e3, 0x8ccaa2fe, 0x82c3aff5, 0xa8fc8cc4, 0xa6f581cf, 0xb4ee96d2, 0xbae79bd9, 0xdb3bbb7b, 0xd532b670, 0xc729a16d, 0xc920ac66, 0xe31f8f57, 0xed16825c, 0xff0d9541, 0xf104984a, 0xab73d323, 0xa57ade28, 0xb761c935, 0xb968c43e, 0x9357e70f, 0x9d5eea04, 0x8f45fd19, 0x814cf012, 0x3bab6bcb, 0x35a266c0, 0x27b971dd, 0x29b07cd6, 0x038f5fe7, 0x0d8652ec, 0x1f9d45f1, 0x119448fa, 0x4be30393, 0x45ea0e98, 0x57f11985, 0x59f8148e, 0x73c737bf, 0x7dce3ab4, 0x6fd52da9, 0x61dc20a2, 0xad766df6, 0xa37f60fd, 0xb16477e0, 0xbf6d7aeb, 0x955259da, 0x9b5b54d1, 0x894043cc, 0x87494ec7, 0xdd3e05ae, 0xd33708a5, 0xc12c1fb8, 0xcf2512b3, 0xe51a3182, 0xeb133c89, 0xf9082b94, 0xf701269f, 0x4de6bd46, 0x43efb04d, 0x51f4a750, 0x5ffdaa5b, 0x75c2896a, 0x7bcb8461, 0x69d0937c, 0x67d99e77, 0x3daed51e, 0x33a7d815, 0x21bccf08, 0x2fb5c203, 0x058ae132, 0x0b83ec39, 0x1998fb24, 0x1791f62f, 0x764dd68d, 0x7844db86, 0x6a5fcc9b, 0x6456c190, 0x4e69e2a1, 0x4060efaa, 0x527bf8b7, 0x5c72f5bc, 0x0605bed5, 0x080cb3de, 0x1a17a4c3, 0x141ea9c8, 0x3e218af9, 0x302887f2, 0x223390ef, 0x2c3a9de4, 0x96dd063d, 0x98d40b36, 0x8acf1c2b, 0x84c61120, 0xaef93211, 0xa0f03f1a, 0xb2eb2807, 0xbce2250c, 0xe6956e65, 0xe89c636e, 0xfa877473, 0xf48e7978, 0xdeb15a49, 0xd0b85742, 0xc2a3405f, 0xccaa4d54, 0x41ecdaf7, 0x4fe5d7fc, 0x5dfec0e1, 0x53f7cdea, 0x79c8eedb, 0x77c1e3d0, 0x65daf4cd, 0x6bd3f9c6, 0x31a4b2af, 0x3fadbfa4, 0x2db6a8b9, 0x23bfa5b2, 0x09808683, 0x07898b88, 0x15929c95, 0x1b9b919e, 0xa17c0a47, 0xaf75074c, 0xbd6e1051, 0xb3671d5a, 0x99583e6b, 0x97513360, 0x854a247d, 0x8b432976, 0xd134621f, 0xdf3d6f14, 0xcd267809, 0xc32f7502, 0xe9105633, 0xe7195b38, 0xf5024c25, 0xfb0b412e, 0x9ad7618c, 0x94de6c87, 0x86c57b9a, 0x88cc7691, 0xa2f355a0, 0xacfa58ab, 0xbee14fb6, 0xb0e842bd, 0xea9f09d4, 0xe49604df, 0xf68d13c2, 0xf8841ec9, 0xd2bb3df8, 0xdcb230f3, 0xcea927ee, 0xc0a02ae5, 0x7a47b13c, 0x744ebc37, 0x6655ab2a, 0x685ca621, 0x42638510, 0x4c6a881b, 0x5e719f06, 0x5078920d, 0x0a0fd964, 0x0406d46f, 0x161dc372, 0x1814ce79, 0x322bed48, 0x3c22e043, 0x2e39f75e, 0x2030fa55, 0xec9ab701, 0xe293ba0a, 0xf088ad17, 0xfe81a01c, 0xd4be832d, 0xdab78e26, 0xc8ac993b, 0xc6a59430, 0x9cd2df59, 0x92dbd252, 0x80c0c54f, 0x8ec9c844, 0xa4f6eb75, 0xaaffe67e, 0xb8e4f163, 0xb6edfc68, 0x0c0a67b1, 0x02036aba, 0x10187da7, 0x1e1170ac, 0x342e539d, 0x3a275e96, 0x283c498b, 0x26354480, 0x7c420fe9, 0x724b02e2, 0x605015ff, 0x6e5918f4, 0x44663bc5, 0x4a6f36ce, 0x587421d3, 0x567d2cd8, 0x37a10c7a, 0x39a80171, 0x2bb3166c, 0x25ba1b67, 0x0f853856, 0x018c355d, 0x13972240, 0x1d9e2f4b, 0x47e96422, 0x49e06929, 0x5bfb7e34, 0x55f2733f, 0x7fcd500e, 0x71c45d05, 0x63df4a18, 0x6dd64713, 0xd731dcca, 0xd938d1c1, 0xcb23c6dc, 0xc52acbd7, 0xef15e8e6, 0xe11ce5ed, 0xf307f2f0, 0xfd0efffb, 0xa779b492, 0xa970b999, 0xbb6bae84, 0xb562a38f, 0x9f5d80be, 0x91548db5, 0x834f9aa8, 0x8d4697a3 ]
_U2 = [ 0x00000000, 0x0b0e090d, 0x161c121a, 0x1d121b17, 0x2c382434, 0x27362d39, 0x3a24362e, 0x312a3f23, 0x58704868, 0x537e4165, 0x4e6c5a72, 0x4562537f, 0x74486c5c, 0x7f466551, 0x62547e46, 0x695a774b, 0xb0e090d0, 0xbbee99dd, 0xa6fc82ca, 0xadf28bc7, 0x9cd8b4e4, 0x97d6bde9, 0x8ac4a6fe, 0x81caaff3, 0xe890d8b8, 0xe39ed1b5, 0xfe8ccaa2, 0xf582c3af, 0xc4a8fc8c, 0xcfa6f581, 0xd2b4ee96, 0xd9bae79b, 0x7bdb3bbb, 0x70d532b6, 0x6dc729a1, 0x66c920ac, 0x57e31f8f, 0x5ced1682, 0x41ff0d95, 0x4af10498, 0x23ab73d3, 0x28a57ade, 0x35b761c9, 0x3eb968c4, 0x0f9357e7, 0x049d5eea, 0x198f45fd, 0x12814cf0, 0xcb3bab6b, 0xc035a266, 0xdd27b971, 0xd629b07c, 0xe7038f5f, 0xec0d8652, 0xf11f9d45, 0xfa119448, 0x934be303, 0x9845ea0e, 0x8557f119, 0x8e59f814, 0xbf73c737, 0xb47dce3a, 0xa96fd52d, 0xa261dc20, 0xf6ad766d, 0xfda37f60, 0xe0b16477, 0xebbf6d7a, 0xda955259, 0xd19b5b54, 0xcc894043, 0xc787494e, 0xaedd3e05, 0xa5d33708, 0xb8c12c1f, 0xb3cf2512, 0x82e51a31, 0x89eb133c, 0x94f9082b, 0x9ff70126, 0x464de6bd, 0x4d43efb0, 0x5051f4a7, 0x5b5ffdaa, 0x6a75c289, 0x617bcb84, 0x7c69d093, 0x7767d99e, 0x1e3daed5, 0x1533a7d8, 0x0821bccf, 0x032fb5c2, 0x32058ae1, 0x390b83ec, 0x241998fb, 0x2f1791f6, 0x8d764dd6, 0x867844db, 0x9b6a5fcc, 0x906456c1, 0xa14e69e2, 0xaa4060ef, 0xb7527bf8, 0xbc5c72f5, 0xd50605be, 0xde080cb3, 0xc31a17a4, 0xc8141ea9, 0xf93e218a, 0xf2302887, 0xef223390, 0xe42c3a9d, 0x3d96dd06, 0x3698d40b, 0x2b8acf1c, 0x2084c611, 0x11aef932, 0x1aa0f03f, 0x07b2eb28, 0x0cbce225, 0x65e6956e, 0x6ee89c63, 0x73fa8774, 0x78f48e79, 0x49deb15a, 0x42d0b857, 0x5fc2a340, 0x54ccaa4d, 0xf741ecda, 0xfc4fe5d7, 0xe15dfec0, 0xea53f7cd, 0xdb79c8ee, 0xd077c1e3, 0xcd65daf4, 0xc66bd3f9, 0xaf31a4b2, 0xa43fadbf, 0xb92db6a8, 0xb223bfa5, 0x83098086, 0x8807898b, 0x9515929c, 0x9e1b9b91, 0x47a17c0a, 0x4caf7507, 0x51bd6e10, 0x5ab3671d, 0x6b99583e, 0x60975133, 0x7d854a24, 0x768b4329, 0x1fd13462, 0x14df3d6f, 0x09cd2678, 0x02c32f75, 0x33e91056, 0x38e7195b, 0x25f5024c, 0x2efb0b41, 0x8c9ad761, 0x8794de6c, 0x9a86c57b, 0x9188cc76, 0xa0a2f355, 0xabacfa58, 0xb6bee14f, 0xbdb0e842, 0xd4ea9f09, 0xdfe49604, 0xc2f68d13, 0xc9f8841e, 0xf8d2bb3d, 0xf3dcb230, 0xeecea927, 0xe5c0a02a, 0x3c7a47b1, 0x37744ebc, 0x2a6655ab, 0x21685ca6, 0x10426385, 0x1b4c6a88, 0x065e719f, 0x0d507892, 0x640a0fd9, 0x6f0406d4, 0x72161dc3, 0x791814ce, 0x48322bed, 0x433c22e0, 0x5e2e39f7, 0x552030fa, 0x01ec9ab7, 0x0ae293ba, 0x17f088ad, 0x1cfe81a0, 0x2dd4be83, 0x26dab78e, 0x3bc8ac99, 0x30c6a594, 0x599cd2df, 0x5292dbd2, 0x4f80c0c5, 0x448ec9c8, 0x75a4f6eb, 0x7eaaffe6, 0x63b8e4f1, 0x68b6edfc, 0xb10c0a67, 0xba02036a, 0xa710187d, 0xac1e1170, 0x9d342e53, 0x963a275e, 0x8b283c49, 0x80263544, 0xe97c420f, 0xe2724b02, 0xff605015, 0xf46e5918, 0xc544663b, 0xce4a6f36, 0xd3587421, 0xd8567d2c, 0x7a37a10c, 0x7139a801, 0x6c2bb316, 0x6725ba1b, 0x560f8538, 0x5d018c35, 0x40139722, 0x4b1d9e2f, 0x2247e964, 0x2949e069, 0x345bfb7e, 0x3f55f273, 0x0e7fcd50, 0x0571c45d, 0x1863df4a, 0x136dd647, 0xcad731dc, 0xc1d938d1, 0xdccb23c6, 0xd7c52acb, 0xe6ef15e8, 0xede11ce5, 0xf0f307f2, 0xfbfd0eff, 0x92a779b4, 0x99a970b9, 0x84bb6bae, 0x8fb562a3, 0xbe9f5d80, 0xb591548d, 0xa8834f9a, 0xa38d4697 ]
_U3 = [ 0x00000000, 0x0d0b0e09, 0x1a161c12, 0x171d121b, 0x342c3824, 0x3927362d, 0x2e3a2436, 0x23312a3f, 0x68587048, 0x65537e41, 0x724e6c5a, 0x7f456253, 0x5c74486c, 0x517f4665, 0x4662547e, 0x4b695a77, 0xd0b0e090, 0xddbbee99, 0xcaa6fc82, 0xc7adf28b, 0xe49cd8b4, 0xe997d6bd, 0xfe8ac4a6, 0xf381caaf, 0xb8e890d8, 0xb5e39ed1, 0xa2fe8cca, 0xaff582c3, 0x8cc4a8fc, 0x81cfa6f5, 0x96d2b4ee, 0x9bd9bae7, 0xbb7bdb3b, 0xb670d532, 0xa16dc729, 0xac66c920, 0x8f57e31f, 0x825ced16, 0x9541ff0d, 0x984af104, 0xd323ab73, 0xde28a57a, 0xc935b761, 0xc43eb968, 0xe70f9357, 0xea049d5e, 0xfd198f45, 0xf012814c, 0x6bcb3bab, 0x66c035a2, 0x71dd27b9, 0x7cd629b0, 0x5fe7038f, 0x52ec0d86, 0x45f11f9d, 0x48fa1194, 0x03934be3, 0x0e9845ea, 0x198557f1, 0x148e59f8, 0x37bf73c7, 0x3ab47dce, 0x2da96fd5, 0x20a261dc, 0x6df6ad76, 0x60fda37f, 0x77e0b164, 0x7aebbf6d, 0x59da9552, 0x54d19b5b, 0x43cc8940, 0x4ec78749, 0x05aedd3e, 0x08a5d337, 0x1fb8c12c, 0x12b3cf25, 0x3182e51a, 0x3c89eb13, 0x2b94f908, 0x269ff701, 0xbd464de6, 0xb04d43ef, 0xa75051f4, 0xaa5b5ffd, 0x896a75c2, 0x84617bcb, 0x937c69d0, 0x9e7767d9, 0xd51e3dae, 0xd81533a7, 0xcf0821bc, 0xc2032fb5, 0xe132058a, 0xec390b83, 0xfb241998, 0xf62f1791, 0xd68d764d, 0xdb867844, 0xcc9b6a5f, 0xc1906456, 0xe2a14e69, 0xefaa4060, 0xf8b7527b, 0xf5bc5c72, 0xbed50605, 0xb3de080c, 0xa4c31a17, 0xa9c8141e, 0x8af93e21, 0x87f23028, 0x90ef2233, 0x9de42c3a, 0x063d96dd, 0x0b3698d4, 0x1c2b8acf, 0x112084c6, 0x3211aef9, 0x3f1aa0f0, 0x2807b2eb, 0x250cbce2, 0x6e65e695, 0x636ee89c, 0x7473fa87, 0x7978f48e, 0x5a49deb1, 0x5742d0b8, 0x405fc2a3, 0x4d54ccaa, 0xdaf741ec, 0xd7fc4fe5, 0xc0e15dfe, 0xcdea53f7, 0xeedb79c8, 0xe3d077c1, 0xf4cd65da, 0xf9c66bd3, 0xb2af31a4, 0xbfa43fad, 0xa8b92db6, 0xa5b223bf, 0x86830980, 0x8b880789, 0x9c951592, 0x919e1b9b, 0x0a47a17c, 0x074caf75, 0x1051bd6e, 0x1d5ab367, 0x3e6b9958, 0x33609751, 0x247d854a, 0x29768b43, 0x621fd134, 0x6f14df3d, 0x7809cd26, 0x7502c32f, 0x5633e910, 0x5b38e719, 0x4c25f502, 0x412efb0b, 0x618c9ad7, 0x6c8794de, 0x7b9a86c5, 0x769188cc, 0x55a0a2f3, 0x58abacfa, 0x4fb6bee1, 0x42bdb0e8, 0x09d4ea9f, 0x04dfe496, 0x13c2f68d, 0x1ec9f884, 0x3df8d2bb, 0x30f3dcb2, 0x27eecea9, 0x2ae5c0a0, 0xb13c7a47, 0xbc37744e, 0xab2a6655, 0xa621685c, 0x85104263, 0x881b4c6a, 0x9f065e71, 0x920d5078, 0xd9640a0f, 0xd46f0406, 0xc372161d, 0xce791814, 0xed48322b, 0xe0433c22, 0xf75e2e39, 0xfa552030, 0xb701ec9a, 0xba0ae293, 0xad17f088, 0xa01cfe81, 0x832dd4be, 0x8e26dab7, 0x993bc8ac, 0x9430c6a5, 0xdf599cd2, 0xd25292db, 0xc54f80c0, 0xc8448ec9, 0xeb75a4f6, 0xe67eaaff, 0xf163b8e4, 0xfc68b6ed, 0x67b10c0a, 0x6aba0203, 0x7da71018, 0x70ac1e11, 0x539d342e, 0x5e963a27, 0x498b283c, 0x44802635, 0x0fe97c42, 0x02e2724b, 0x15ff6050, 0x18f46e59, 0x3bc54466, 0x36ce4a6f, 0x21d35874, 0x2cd8567d, 0x0c7a37a1, 0x017139a8, 0x166c2bb3, 0x1b6725ba, 0x38560f85, 0x355d018c, 0x22401397, 0x2f4b1d9e, 0x642247e9, 0x692949e0, 0x7e345bfb, 0x733f55f2, 0x500e7fcd, 0x5d0571c4, 0x4a1863df, 0x47136dd6, 0xdccad731, 0xd1c1d938, 0xc6dccb23, 0xcbd7c52a, 0xe8e6ef15, 0xe5ede11c, 0xf2f0f307, 0xfffbfd0e, 0xb492a779, 0xb999a970, 0xae84bb6b, 0xa38fb562, 0x80be9f5d, 0x8db59154, 0x9aa8834f, 0x97a38d46 ]
_U4 = [ 0x00000000, 0x090d0b0e, 0x121a161c, 0x1b171d12, 0x24342c38, 0x2d392736, 0x362e3a24, 0x3f23312a, 0x48685870, 0x4165537e, 0x5a724e6c, 0x537f4562, 0x6c5c7448, 0x65517f46, 0x7e466254, 0x774b695a, 0x90d0b0e0, 0x99ddbbee, 0x82caa6fc, 0x8bc7adf2, 0xb4e49cd8, 0xbde997d6, 0xa6fe8ac4, 0xaff381ca, 0xd8b8e890, 0xd1b5e39e, 0xcaa2fe8c, 0xc3aff582, 0xfc8cc4a8, 0xf581cfa6, 0xee96d2b4, 0xe79bd9ba, 0x3bbb7bdb, 0x32b670d5, 0x29a16dc7, 0x20ac66c9, 0x1f8f57e3, 0x16825ced, 0x0d9541ff, 0x04984af1, 0x73d323ab, 0x7ade28a5, 0x61c935b7, 0x68c43eb9, 0x57e70f93, 0x5eea049d, 0x45fd198f, 0x4cf01281, 0xab6bcb3b, 0xa266c035, 0xb971dd27, 0xb07cd629, 0x8f5fe703, 0x8652ec0d, 0x9d45f11f, 0x9448fa11, 0xe303934b, 0xea0e9845, 0xf1198557, 0xf8148e59, 0xc737bf73, 0xce3ab47d, 0xd52da96f, 0xdc20a261, 0x766df6ad, 0x7f60fda3, 0x6477e0b1, 0x6d7aebbf, 0x5259da95, 0x5b54d19b, 0x4043cc89, 0x494ec787, 0x3e05aedd, 0x3708a5d3, 0x2c1fb8c1, 0x2512b3cf, 0x1a3182e5, 0x133c89eb, 0x082b94f9, 0x01269ff7, 0xe6bd464d, 0xefb04d43, 0xf4a75051, 0xfdaa5b5f, 0xc2896a75, 0xcb84617b, 0xd0937c69, 0xd99e7767, 0xaed51e3d, 0xa7d81533, 0xbccf0821, 0xb5c2032f, 0x8ae13205, 0x83ec390b, 0x98fb2419, 0x91f62f17, 0x4dd68d76, 0x44db8678, 0x5fcc9b6a, 0x56c19064, 0x69e2a14e, 0x60efaa40, 0x7bf8b752, 0x72f5bc5c, 0x05bed506, 0x0cb3de08, 0x17a4c31a, 0x1ea9c814, 0x218af93e, 0x2887f230, 0x3390ef22, 0x3a9de42c, 0xdd063d96, 0xd40b3698, 0xcf1c2b8a, 0xc6112084, 0xf93211ae, 0xf03f1aa0, 0xeb2807b2, 0xe2250cbc, 0x956e65e6, 0x9c636ee8, 0x877473fa, 0x8e7978f4, 0xb15a49de, 0xb85742d0, 0xa3405fc2, 0xaa4d54cc, 0xecdaf741, 0xe5d7fc4f, 0xfec0e15d, 0xf7cdea53, 0xc8eedb79, 0xc1e3d077, 0xdaf4cd65, 0xd3f9c66b, 0xa4b2af31, 0xadbfa43f, 0xb6a8b92d, 0xbfa5b223, 0x80868309, 0x898b8807, 0x929c9515, 0x9b919e1b, 0x7c0a47a1, 0x75074caf, 0x6e1051bd, 0x671d5ab3, 0x583e6b99, 0x51336097, 0x4a247d85, 0x4329768b, 0x34621fd1, 0x3d6f14df, 0x267809cd, 0x2f7502c3, 0x105633e9, 0x195b38e7, 0x024c25f5, 0x0b412efb, 0xd7618c9a, 0xde6c8794, 0xc57b9a86, 0xcc769188, 0xf355a0a2, 0xfa58abac, 0xe14fb6be, 0xe842bdb0, 0x9f09d4ea, 0x9604dfe4, 0x8d13c2f6, 0x841ec9f8, 0xbb3df8d2, 0xb230f3dc, 0xa927eece, 0xa02ae5c0, 0x47b13c7a, 0x4ebc3774, 0x55ab2a66, 0x5ca62168, 0x63851042, 0x6a881b4c, 0x719f065e, 0x78920d50, 0x0fd9640a, 0x06d46f04, 0x1dc37216, 0x14ce7918, 0x2bed4832, 0x22e0433c, 0x39f75e2e, 0x30fa5520, 0x9ab701ec, 0x93ba0ae2, 0x88ad17f0, 0x81a01cfe, 0xbe832dd4, 0xb78e26da, 0xac993bc8, 0xa59430c6, 0xd2df599c, 0xdbd25292, 0xc0c54f80, 0xc9c8448e, 0xf6eb75a4, 0xffe67eaa, 0xe4f163b8, 0xedfc68b6, 0x0a67b10c, 0x036aba02, 0x187da710, 0x1170ac1e, 0x2e539d34, 0x275e963a, 0x3c498b28, 0x35448026, 0x420fe97c, 0x4b02e272, 0x5015ff60, 0x5918f46e, 0x663bc544, 0x6f36ce4a, 0x7421d358, 0x7d2cd856, 0xa10c7a37, 0xa8017139, 0xb3166c2b, 0xba1b6725, 0x8538560f, 0x8c355d01, 0x97224013, 0x9e2f4b1d, 0xe9642247, 0xe0692949, 0xfb7e345b, 0xf2733f55, 0xcd500e7f, 0xc45d0571, 0xdf4a1863, 0xd647136d, 0x31dccad7, 0x38d1c1d9, 0x23c6dccb, 0x2acbd7c5, 0x15e8e6ef, 0x1ce5ede1, 0x07f2f0f3, 0x0efffbfd, 0x79b492a7, 0x70b999a9, 0x6bae84bb, 0x62a38fb5, 0x5d80be9f, 0x548db591, 0x4f9aa883, 0x4697a38d ]
# fmt: on
def __init__(self, key: bytes) -> None:
if len(key) not in (16, 24, 32):
raise core.InvalidArgumentError(f'Invalid key size {len(key)}')
rounds = self._NUMBER_OF_ROUNDS[len(key)]
# Encryption round keys
self._ke = [[0] * 4 for i in range(rounds + 1)]
# Decryption round keys
self._kd = [[0] * 4 for i in range(rounds + 1)]
round_key_count = (rounds + 1) * 4
kc = len(key) // 4
# Convert the key into ints
tk = [struct.unpack('>i', key[i : i + 4])[0] for i in range(0, len(key), 4)]
# Copy values into round key arrays
for i in range(0, kc):
self._ke[i // 4][i % 4] = tk[i]
self._kd[rounds - (i // 4)][i % 4] = tk[i]
# Key expansion (FIPS-197 section 5.2)
r_con_pointer = 0
t = kc
while t < round_key_count:
tt = tk[kc - 1]
tk[0] ^= (
(self._S[(tt >> 16) & 0xFF] << 24)
^ (self._S[(tt >> 8) & 0xFF] << 16)
^ (self._S[tt & 0xFF] << 8)
^ self._S[(tt >> 24) & 0xFF]
^ (self._RCON[r_con_pointer] << 24)
)
r_con_pointer += 1
if kc != 8:
for i in range(1, kc):
tk[i] ^= tk[i - 1]
# Key expansion for 256-bit keys is "slightly different" (FIPS-197)
else:
for i in range(1, kc // 2):
tk[i] ^= tk[i - 1]
tt = tk[kc // 2 - 1]
tk[kc // 2] ^= (
self._S[tt & 0xFF]
^ (self._S[(tt >> 8) & 0xFF] << 8)
^ (self._S[(tt >> 16) & 0xFF] << 16)
^ (self._S[(tt >> 24) & 0xFF] << 24)
)
for i in range(kc // 2 + 1, kc):
tk[i] ^= tk[i - 1]
# Copy values into round key arrays
j = 0
while j < kc and t < round_key_count:
self._ke[t // 4][t % 4] = tk[j]
self._kd[rounds - (t // 4)][t % 4] = tk[j]
j += 1
t += 1
# Inverse-Cipher-ify the decryption round key (FIPS-197 section 5.3)
for r in range(1, rounds):
for j in range(0, 4):
tt = self._kd[r][j]
self._kd[r][j] = (
self._U1[(tt >> 24) & 0xFF]
^ self._U2[(tt >> 16) & 0xFF]
^ self._U3[(tt >> 8) & 0xFF]
^ self._U4[tt & 0xFF]
)
def encrypt(self, plaintext: bytes) -> bytes:
"""Encrypt a block of plain text using the AES block cipher."""
if len(plaintext) != 16:
raise core.InvalidArgumentError(f'wrong block length {len(plaintext)}')
rounds = len(self._ke) - 1
(s1, s2, s3) = [1, 2, 3]
a = [0, 0, 0, 0]
# Convert plaintext to (ints ^ key)
t = [
(_compact_word(plaintext[4 * i : 4 * i + 4]) ^ self._ke[0][i])
for i in range(0, 4)
]
# Apply round transforms
for r in range(1, rounds):
for i in range(0, 4):
a[i] = (
self._T1[(t[i] >> 24) & 0xFF]
^ self._T2[(t[(i + s1) % 4] >> 16) & 0xFF]
^ self._T3[(t[(i + s2) % 4] >> 8) & 0xFF]
^ self._T4[t[(i + s3) % 4] & 0xFF]
^ self._ke[r][i]
)
t = copy.copy(a)
# The last round is special
result = []
for i in range(0, 4):
tt = self._ke[rounds][i]
result.append((self._S[(t[i] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
result.append((self._S[(t[(i + s1) % 4] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
result.append((self._S[(t[(i + s2) % 4] >> 8) & 0xFF] ^ (tt >> 8)) & 0xFF)
result.append((self._S[t[(i + s3) % 4] & 0xFF] ^ tt) & 0xFF)
return bytes(result)
def decrypt(self, cipher_text: bytes) -> bytes:
"""Decrypt a block of cipher text using the AES block cipher."""
if len(cipher_text) != 16:
raise core.InvalidArgumentError(f'wrong block length {len(cipher_text)}')
rounds = len(self._kd) - 1
(s1, s2, s3) = [3, 2, 1]
a = [0, 0, 0, 0]
# Convert ciphertext to (ints ^ key)
t = [
(_compact_word(cipher_text[4 * i : 4 * i + 4]) ^ self._kd[0][i])
for i in range(0, 4)
]
# Apply round transforms
for r in range(1, rounds):
for i in range(0, 4):
a[i] = (
self._T5[(t[i] >> 24) & 0xFF]
^ self._T6[(t[(i + s1) % 4] >> 16) & 0xFF]
^ self._T7[(t[(i + s2) % 4] >> 8) & 0xFF]
^ self._T8[t[(i + s3) % 4] & 0xFF]
^ self._kd[r][i]
)
t = copy.copy(a)
# The last round is special
result = []
for i in range(0, 4):
tt = self._kd[rounds][i]
result.append((self._S_INV[(t[i] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
result.append(
(self._S_INV[(t[(i + s1) % 4] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF
)
result.append(
(self._S_INV[(t[(i + s2) % 4] >> 8) & 0xFF] ^ (tt >> 8)) & 0xFF
)
result.append((self._S_INV[t[(i + s3) % 4] & 0xFF] ^ tt) & 0xFF)
return bytes(result)
class _ECB:
def __init__(self, key: bytes):
self._aes = _AES(key)
def encrypt(self, plaintext: bytes) -> bytes:
return b"".join(
[
self._aes.encrypt(
plaintext[offset : offset + 16].ljust(16, b"\x00") # Pad 0.
)
for offset in range(0, len(plaintext), 16)
]
)
def decrypt(self, cipher_text: bytes) -> bytes:
return b"".join(
[
self._aes.encrypt(cipher_text[offset : offset + 16])
for offset in range(0, len(cipher_text), 16)
]
)
class _CBC:
def __init__(self, key: bytes, iv: bytes = bytes(16)) -> None:
if len(iv) != 16:
raise core.InvalidArgumentError(
f'initialization vector must be 16 bytes, get {len(iv)}'
)
else:
self._last_cipher_block = iv
self._aes = _AES(key)
def encrypt(self, plaintext: bytes) -> bytes:
cipher_text = b""
for offset in range(0, len(plaintext), 16):
pre_cipher_block = _xor(
plaintext[offset : offset + 16], self._last_cipher_block
)
self._last_cipher_block = self._aes.encrypt(pre_cipher_block)
cipher_text += self._last_cipher_block
return cipher_text
def decrypt(self, cipher_text: bytes) -> bytes:
plaintext = b""
for offset in range(0, len(cipher_text), 16):
plaintext += _xor(
self._aes.decrypt(cipher_text[offset : offset + 16]),
self._last_cipher_block,
)
self._last_cipher_block = cipher_text[offset : offset + 16]
return plaintext
class _CMAC:
def __init__(
self,
key: bytes,
msg: bytes = bytes(16),
mac_len: int = 16,
update_after_digest: bool = False,
) -> None:
self.digest_size = mac_len
self._key = key
self._block_size = bs = 16
self._mac_tag: bytes | None = None
self._update_after_digest = update_after_digest
# Section 5.3 of NIST SP 800 38B and Appendix B
if bs == 8:
const_Rb = 0x1B
self._max_size = 8 * (2**21)
elif bs == 16:
const_Rb = 0x87
self._max_size = 16 * (2**48)
else:
raise core.InvalidArgumentError(
f"CMAC requires a cipher with a block size of 8 or 16 bytes, not {bs}"
)
# Compute sub-keys
zero_block = bytes(bs)
self._ecb = _ECB(key)
L = self._ecb.encrypt(zero_block)
if L[0] & 0x80:
self._k1 = _shift_bytes(L, const_Rb)
else:
self._k1 = _shift_bytes(L)
if self._k1[0] & 0x80:
self._k2 = _shift_bytes(self._k1, const_Rb)
else:
self._k2 = _shift_bytes(self._k1)
# Initialize CBC cipher with zero IV
self._cbc = _CBC(key, zero_block)
# Cache for outstanding data to authenticate
self._cache = bytearray(bs)
self._cache_n = 0
# Last piece of cipher text produced
self._last_ct = zero_block
# Last block that was encrypted with AES
self._last_pt: bytes | None = None
# Counter for total message size
self._data_size = 0
if msg:
self.update(msg)
def update(self, msg: bytes) -> _CMAC:
"""Authenticate the next chunk of message.
Args:
data (byte string/byte array/memoryview): The next chunk of data
"""
if self._mac_tag is not None and not self._update_after_digest:
raise core.InvalidStateError(
"update() cannot be called after digest() or verify()"
)
self._data_size += len(msg)
bs = self._block_size
if self._cache_n > 0:
filler = min(bs - self._cache_n, len(msg))
self._cache[self._cache_n : self._cache_n + filler] = msg[:filler]
self._cache_n += filler
if self._cache_n < bs:
return self
msg = msg[filler:]
self._update(self._cache)
self._cache_n = 0
remain = len(msg) % bs
if remain > 0:
self._update(msg[:-remain])
self._cache[:remain] = msg[-remain:]
else:
self._update(msg)
self._cache_n = remain
return self
def _update(self, data_block: bytes) -> None:
"""Update a block aligned to the block boundary"""
bs = self._block_size
assert len(data_block) % bs == 0
if len(data_block) == 0:
return
ct = self._cbc.encrypt(data_block)
if len(data_block) == bs:
second_last = self._last_ct
else:
second_last = ct[-bs * 2 : -bs]
self._last_ct = ct[-bs:]
self._last_pt = _xor(second_last, data_block[-bs:])
def digest(self) -> bytes:
bs = self._block_size
if self._mac_tag is not None and not self._update_after_digest:
return self._mac_tag
if self._data_size > self._max_size:
raise core.InvalidArgumentError("MAC is unsafe for this message")
if self._cache_n == 0 and self._data_size > 0 and self._last_pt:
# Last block was full
pt = _xor(self._last_pt, self._k1)
else:
# Last block is partial (or message length is zero)
partial = self._cache[:]
partial[self._cache_n :] = b'\x80' + b'\x00' * (bs - self._cache_n - 1)
pt = _xor(_xor(self._last_ct, partial), self._k2)
self._mac_tag = self._ecb.encrypt(pt)[: self.digest_size]
return self._mac_tag
# Define the original Point class for clarity and conversion purposes
@dataclasses.dataclass
class _Point:
"""Represents a point on the elliptic curve in affine coordinates."""
curve: _EllipticCurve
x: int = 0
y: int = 0
infinite: bool = False
@dataclasses.dataclass(frozen=True)
class _JacobianPoint:
"""Represents a point on the elliptic curve in Jacobian coordinates."""
curve: _EllipticCurve
x: int = 1 # For point at infinity (1:1:0)
y: int = 1
z: int = 0 # z = 0 indicates point at infinity
@classmethod
def point_at_infinity(cls, curve: _EllipticCurve) -> _JacobianPoint:
return _JacobianPoint(curve=curve, x=1, y=1, z=0)
@classmethod
def from_affine(cls, affine_point: _Point) -> _JacobianPoint:
if affine_point.infinite:
return _JacobianPoint.point_at_infinity(affine_point.curve)
# A simple conversion is (x, y, 1)
return _JacobianPoint(
curve=affine_point.curve, x=affine_point.x, y=affine_point.y, z=1
)
def to_affine(self) -> _Point:
if self.z == 0:
return _Point(infinite=True, curve=self.curve)
p = self.curve.p
inv_z = pow(self.z, -1, p)
affine_x = (self.x * inv_z**2) % p
affine_y = (self.y * inv_z**3) % p
return _Point(curve=self.curve, x=affine_x, y=affine_y, infinite=False)
def double(self) -> _JacobianPoint:
if self.z == 0 or self.y == 0:
return _JacobianPoint.point_at_infinity(self.curve)
s = 4 * self.x * self.y**2
m = 3 * self.x**2 + self.curve.a * self.z**4
x2 = m**2 - 2 * s
y2 = m * (s - x2) - 8 * self.y**4
z2 = 2 * self.y * self.z
p = self.curve.p
return _JacobianPoint(curve=self.curve, x=x2 % p, y=y2 % p, z=z2 % p)
def __add__(self, other: _JacobianPoint) -> _JacobianPoint:
if self.z == 0 and other.z == 0:
return _JacobianPoint.point_at_infinity(self.curve)
elif self.z == 0:
return other
elif other.z == 0:
return self
x1 = self.x
y1 = self.y
z1 = self.z
x2 = other.x
y2 = other.y
z2 = other.z
p = self.curve.p
u1 = (x1 * z2**2) % p
u2 = (x2 * z1**2) % p
s1 = (y1 * z2**3) % p
s2 = (y2 * z1**3) % p
if u1 == u2:
if s1 != s2:
return _JacobianPoint.point_at_infinity(self.curve)
else:
return self.double()
else:
h = u2 - u1
r = s2 - s1
h3 = h**3 % p
u1h2 = (u1 * h**2) % p
x3 = r**2 - h3 - 2 * u1h2
y3 = r * (u1h2 - x3) - s1 * h3
z3 = h * z1 * z2
return _JacobianPoint(self.curve, x3 % p, y3 % p, z3 % p)
def __mul__(self, k: int) -> _JacobianPoint:
addend = self
result = _JacobianPoint.point_at_infinity(self.curve)
while k > 0:
if k % 2 != 0:
result = result + addend
addend = addend.double()
k = k >> 1
return result
def __rmul__(self, k: int) -> _JacobianPoint:
return self * k
@dataclasses.dataclass
class _EllipticCurve:
p: int
a: int
b: int
n: int
g_x: int
g_y: int
_generator_jacobian: _JacobianPoint = dataclasses.field(init=False)
def __post_init__(self):
self._generator_jacobian = _JacobianPoint(
curve=self, x=self.g_x, y=self.g_y, z=1
)
@dataclasses.dataclass
class PrivateKey:
key: int
curve: _EllipticCurve
def generate_private_key(self) -> PrivateKey:
"""Generates a random private key."""
return self.PrivateKey(key=secrets.randbelow(self.n), curve=self)
def generate_public_key(self, private_key: int) -> _Point:
"""Generates a public key from a private key using Jacobian coordinates for scalar multiplication."""
public_key_jacobian = self._generator_jacobian * private_key
return public_key_jacobian.to_affine()
def ecdh_shared_secret(self, private_key: int, other_public_key: _Point) -> bytes:
"""Computes the shared secret using ECDH."""
other_public_key_jacobian = _JacobianPoint.from_affine(other_public_key)
shared_point_jacobian = other_public_key_jacobian * private_key
shared_point_affine = shared_point_jacobian.to_affine()
if shared_point_affine.infinite:
raise core.InvalidPacketError(
"Shared secret calculation resulted in the point at infinite"
)
return shared_point_affine.x.to_bytes(32, 'big')
@classmethod
def SECP256R1(cls) -> _EllipticCurve:
p = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF
a = 0xFFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFC
b = 0x5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B
n = 0xFFFFFFFF00000000FFFFFFFFFFFFFFFFBCE6FAADA7179E84F3B9CAC2FC632551 # Curve order
g_x = 0x6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296
g_y = 0x4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5
return _EllipticCurve(p=p, a=a, b=b, n=n, g_x=g_x, g_y=g_y)
class EccKey:
def __init__(self, private_key: _EllipticCurve.PrivateKey) -> None:
self.private_key = private_key
@functools.cached_property
def x(self) -> bytes:
return self.private_key.curve.generate_public_key(
self.private_key.key
).x.to_bytes(32, byteorder='big')
@functools.cached_property
def y(self) -> bytes:
return self.private_key.curve.generate_public_key(
self.private_key.key
).y.to_bytes(32, byteorder='big')
def dh(self, public_key_x: bytes, public_key_y: bytes) -> bytes:
x = int.from_bytes(public_key_x, byteorder='big', signed=False)
y = int.from_bytes(public_key_y, byteorder='big', signed=False)
return self.private_key.curve.ecdh_shared_secret(
self.private_key.key,
_Point(x=x, y=y, curve=self.private_key.curve),
)
@classmethod
def generate(cls) -> EccKey:
return EccKey(_EllipticCurve.SECP256R1().generate_private_key())
@classmethod
def from_private_key_bytes(cls, d_bytes: bytes) -> EccKey:
d = int.from_bytes(d_bytes, byteorder='big', signed=False)
return EccKey(_EllipticCurve.PrivateKey(d, _EllipticCurve.SECP256R1()))
def e(key: bytes, data: bytes) -> bytes:
'''
AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output.
See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e
'''
return _ECB(key[::-1]).encrypt(data[::-1])[::-1]
def aes_cmac(m: bytes, k: bytes) -> bytes:
'''
See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC
NOTE: the input and output of this internal function are in big-endian byte order
'''
return _CMAC(key=k, msg=m).digest()
-82
View File
@@ -1,82 +0,0 @@
# Copyright 2021-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License")
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import functools
from cryptography.hazmat.primitives import ciphers, cmac
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.ciphers import algorithms, modes
def e(key: bytes, data: bytes) -> bytes:
'''
AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output.
See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e
'''
cipher = ciphers.Cipher(algorithms.AES(key[::-1]), modes.ECB())
encryptor = cipher.encryptor()
return encryptor.update(data[::-1])[::-1]
class EccKey:
def __init__(self, private_key: ec.EllipticCurvePrivateKey) -> None:
self.private_key = private_key
@classmethod
def generate(cls) -> EccKey:
return EccKey(ec.generate_private_key(ec.SECP256R1()))
@classmethod
def from_private_key_bytes(cls, d_bytes: bytes) -> EccKey:
d = int.from_bytes(d_bytes, byteorder='big', signed=False)
return EccKey(ec.derive_private_key(d, ec.SECP256R1()))
@functools.cached_property
def x(self) -> bytes:
return (
self.private_key.public_key()
.public_numbers()
.x.to_bytes(32, byteorder='big')
)
@functools.cached_property
def y(self) -> bytes:
return (
self.private_key.public_key()
.public_numbers()
.y.to_bytes(32, byteorder='big')
)
def dh(self, public_key_x: bytes, public_key_y: bytes) -> bytes:
x = int.from_bytes(public_key_x, byteorder='big', signed=False)
y = int.from_bytes(public_key_y, byteorder='big', signed=False)
return self.private_key.exchange(
ec.ECDH(),
ec.EllipticCurvePublicNumbers(x, y, ec.SECP256R1()).public_key(),
)
def aes_cmac(m: bytes, k: bytes) -> bytes:
'''
See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC
NOTE: the input and output of this internal function are in big-endian byte order
'''
mac = cmac.CMAC(algorithms.AES(k))
mac.update(m)
return mac.finalize()
-1026
View File
File diff suppressed because it is too large Load Diff
+3 -2
View File
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Union
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
@@ -166,12 +167,12 @@ class G722Decoder:
# The initial value in BLOCK 3H # The initial value in BLOCK 3H
self._band[1].det = 8 self._band[1].det = 8
def decode_frame(self, encoded_data: bytes | bytearray) -> bytearray: def decode_frame(self, encoded_data: Union[bytes, bytearray]) -> bytearray:
result_array = bytearray(len(encoded_data) * 4) result_array = bytearray(len(encoded_data) * 4)
self.g722_decode(result_array, encoded_data) self.g722_decode(result_array, encoded_data)
return result_array return result_array
def g722_decode(self, result_array, encoded_data: bytes | bytearray) -> int: def g722_decode(self, result_array, encoded_data: Union[bytes, bytearray]) -> int:
"""Decode the data frame using g722 decoder.""" """Decode the data frame using g722 decoder."""
result_length = 0 result_length = 0
+1247 -1971
View File
File diff suppressed because it is too large Load Diff
+4 -10
View File
@@ -20,14 +20,12 @@ like loading firmware after a cold start.
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import logging import logging
import pathlib import pathlib
import platform import platform
from collections.abc import Iterable from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING
from typing import TYPE_CHECKING
from bumble.drivers import intel, rtk from bumble.drivers import rtk, intel
from bumble.drivers.common import Driver from bumble.drivers.common import Driver
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -42,18 +40,14 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Functions # Functions
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_driver_for_host(host: Host) -> Driver | None: async def get_driver_for_host(host: Host) -> Optional[Driver]:
"""Probe diver classes until one returns a valid instance for a host, or none is """Probe diver classes until one returns a valid instance for a host, or none is
found. found.
If a "driver" HCI metadata entry is present, only that driver class will be probed. If a "driver" HCI metadata entry is present, only that driver class will be probed.
""" """
driver_classes: dict[str, type[Driver]] = {"rtk": rtk.Driver, "intel": intel.Driver} driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver, "intel": intel.Driver}
probe_list: Iterable[str] probe_list: Iterable[str]
if driver_name := host.hci_metadata.get("driver"): if driver_name := host.hci_metadata.get("driver"):
# The "driver" metadata may include runtime options after a '/' (for example
# "intel/ddc=..."). Keep only the base driver name (the portion before the
# first slash) so it matches a key in driver_classes (e.g. "intel").
driver_name = driver_name.split("/")[0]
# Only probe a single driver # Only probe a single driver
probe_list = [driver_name] probe_list = [driver_name]
else: else:
+2
View File
@@ -20,6 +20,8 @@ Common types for drivers.
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import abc import abc
from bumble import core
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Classes # Classes
+118 -148
View File
@@ -20,7 +20,6 @@ Loosely based on the Fuchsia OS implementation.
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import collections import collections
import dataclasses import dataclasses
@@ -29,10 +28,12 @@ import os
import pathlib import pathlib
import platform import platform
import struct import struct
from typing import TYPE_CHECKING, Any from typing import Any, Deque, Optional, TYPE_CHECKING
from bumble import core, hci, utils from bumble import core
from bumble.drivers import common from bumble.drivers import common
from bumble import hci
from bumble import utils
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.host import Host from bumble.host import Host
@@ -49,7 +50,6 @@ logger = logging.getLogger(__name__)
INTEL_USB_PRODUCTS = { INTEL_USB_PRODUCTS = {
(0x8087, 0x0032), # AX210 (0x8087, 0x0032), # AX210
(0x8087, 0x0033), # AX211
(0x8087, 0x0036), # BE200 (0x8087, 0x0036), # BE200
} }
@@ -89,54 +89,54 @@ HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND = hci.hci_vendor_command_op_code(0x000E)
hci.HCI_Command.register_commands(globals()) hci.HCI_Command.register_commands(globals())
@dataclasses.dataclass @hci.HCI_Command.command(
class HCI_Intel_Read_Version_ReturnParameters(hci.HCI_StatusReturnParameters): fields=[
tlv: bytes = hci.field(metadata=hci.metadata('*')) ("param0", 1),
],
return_parameters_fields=[
("status", hci.STATUS_SPEC),
("tlv", "*"),
],
)
class HCI_Intel_Read_Version_Command(hci.HCI_Command):
pass
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Read_Version_ReturnParameters) @hci.HCI_Command.command(
@dataclasses.dataclass fields=[("data_type", 1), ("data", "*")],
class HCI_Intel_Read_Version_Command( return_parameters_fields=[
hci.HCI_SyncCommand[HCI_Intel_Read_Version_ReturnParameters] ("status", 1),
): ],
param0: int = dataclasses.field(metadata=hci.metadata(1)) )
class Hci_Intel_Secure_Send_Command(hci.HCI_Command):
pass
@hci.HCI_SyncCommand.sync_command(hci.HCI_StatusReturnParameters) @hci.HCI_Command.command(
@dataclasses.dataclass fields=[
class Hci_Intel_Secure_Send_Command( ("reset_type", 1),
hci.HCI_SyncCommand[hci.HCI_StatusReturnParameters] ("patch_enable", 1),
): ("ddc_reload", 1),
data_type: int = dataclasses.field(metadata=hci.metadata(1)) ("boot_option", 1),
data: bytes = dataclasses.field(metadata=hci.metadata("*")) ("boot_address", 4),
],
return_parameters_fields=[
("data", "*"),
],
)
class HCI_Intel_Reset_Command(hci.HCI_Command):
pass
@dataclasses.dataclass @hci.HCI_Command.command(
class HCI_Intel_Reset_ReturnParameters(hci.HCI_ReturnParameters): fields=[("data", "*")],
data: bytes = hci.field(metadata=hci.metadata('*')) return_parameters_fields=[
("status", hci.STATUS_SPEC),
("params", "*"),
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Reset_ReturnParameters) ],
@dataclasses.dataclass )
class HCI_Intel_Reset_Command(hci.HCI_SyncCommand[HCI_Intel_Reset_ReturnParameters]): class Hci_Intel_Write_Device_Config_Command(hci.HCI_Command):
reset_type: int = dataclasses.field(metadata=hci.metadata(1)) pass
patch_enable: int = dataclasses.field(metadata=hci.metadata(1))
ddc_reload: int = dataclasses.field(metadata=hci.metadata(1))
boot_option: int = dataclasses.field(metadata=hci.metadata(1))
boot_address: int = dataclasses.field(metadata=hci.metadata(4))
@dataclasses.dataclass
class HCI_Intel_Write_Device_Config_ReturnParameters(hci.HCI_StatusReturnParameters):
params: bytes = hci.field(metadata=hci.metadata('*'))
@hci.HCI_SyncCommand.sync_command(HCI_Intel_Write_Device_Config_ReturnParameters)
@dataclasses.dataclass
class HCI_Intel_Write_Device_Config_Command(
hci.HCI_SyncCommand[HCI_Intel_Write_Device_Config_ReturnParameters]
):
data: bytes = dataclasses.field(metadata=hci.metadata("*"))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -201,51 +201,50 @@ def _parse_tlv(data: bytes) -> list[tuple[ValueType, Any]]:
value = data[2 : 2 + value_length] value = data[2 : 2 + value_length]
typed_value: Any typed_value: Any
match value_type: if value_type == ValueType.END:
case ValueType.END: break
break
case ValueType.CNVI | ValueType.CNVR: if value_type in (ValueType.CNVI, ValueType.CNVR):
(v,) = struct.unpack("<I", value) (v,) = struct.unpack("<I", value)
typed_value = ( typed_value = (
(((v >> 0) & 0xF) << 12) (((v >> 0) & 0xF) << 12)
| (((v >> 4) & 0xF) << 0) | (((v >> 4) & 0xF) << 0)
| (((v >> 8) & 0xF) << 4) | (((v >> 8) & 0xF) << 4)
| (((v >> 24) & 0xF) << 8) | (((v >> 24) & 0xF) << 8)
) )
case ValueType.HARDWARE_INFO: elif value_type == ValueType.HARDWARE_INFO:
(v,) = struct.unpack("<I", value) (v,) = struct.unpack("<I", value)
typed_value = HardwareInfo( typed_value = HardwareInfo(
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F) HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
) )
case ( elif value_type in (
ValueType.USB_VENDOR_ID ValueType.USB_VENDOR_ID,
| ValueType.USB_PRODUCT_ID ValueType.USB_PRODUCT_ID,
| ValueType.DEVICE_REVISION ValueType.DEVICE_REVISION,
): ):
(typed_value,) = struct.unpack("<H", value) (typed_value,) = struct.unpack("<H", value)
case ValueType.CURRENT_MODE_OF_OPERATION: elif value_type == ValueType.CURRENT_MODE_OF_OPERATION:
typed_value = ModeOfOperation(value[0]) typed_value = ModeOfOperation(value[0])
case ( elif value_type in (
ValueType.BUILD_TYPE ValueType.BUILD_TYPE,
| ValueType.BUILD_NUMBER ValueType.BUILD_NUMBER,
| ValueType.SECURE_BOOT ValueType.SECURE_BOOT,
| ValueType.OTP_LOCK ValueType.OTP_LOCK,
| ValueType.API_LOCK ValueType.API_LOCK,
| ValueType.DEBUG_LOCK ValueType.DEBUG_LOCK,
| ValueType.SECURE_BOOT_ENGINE_TYPE ValueType.SECURE_BOOT_ENGINE_TYPE,
): ):
typed_value = value[0] typed_value = value[0]
case ValueType.TIMESTAMP: elif value_type == ValueType.TIMESTAMP:
typed_value = Timestamp(value[0], value[1]) typed_value = Timestamp(value[0], value[1])
case ValueType.FIRMWARE_BUILD: elif value_type == ValueType.FIRMWARE_BUILD:
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2])) typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
case ValueType.BLUETOOTH_ADDRESS: elif value_type == ValueType.BLUETOOTH_ADDRESS:
typed_value = hci.Address( typed_value = hci.Address(
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
) )
case _: else:
typed_value = value typed_value = value
result.append((value_type, typed_value)) result.append((value_type, typed_value))
data = data[2 + value_length :] data = data[2 + value_length :]
@@ -294,7 +293,6 @@ class HardwareVariant(utils.OpenIntEnum):
# This is a just a partial list. # This is a just a partial list.
# Add other constants here as new hardware is encountered and tested. # Add other constants here as new hardware is encountered and tested.
TYPHOON_PEAK = 0x17 TYPHOON_PEAK = 0x17
GARFIELD_PEAK = 0x19
GALE_PEAK = 0x1C GALE_PEAK = 0x1C
@@ -348,7 +346,7 @@ class Driver(common.Driver):
def __init__(self, host: Host) -> None: def __init__(self, host: Host) -> None:
self.host = host self.host = host
self.max_in_flight_firmware_load_commands = 1 self.max_in_flight_firmware_load_commands = 1
self.pending_firmware_load_commands: collections.deque[hci.HCI_Command] = ( self.pending_firmware_load_commands: Deque[hci.HCI_Command] = (
collections.deque() collections.deque()
) )
self.can_send_firmware_load_command = asyncio.Event() self.can_send_firmware_load_command = asyncio.Event()
@@ -357,8 +355,8 @@ class Driver(common.Driver):
self.reset_complete = asyncio.Event() self.reset_complete = asyncio.Event()
# Parse configuration options from the driver name. # Parse configuration options from the driver name.
self.ddc_addon: bytes | None = None self.ddc_addon: Optional[bytes] = None
self.ddc_override: bytes | None = None self.ddc_override: Optional[bytes] = None
driver = host.hci_metadata.get("driver") driver = host.hci_metadata.get("driver")
if driver is not None and driver.startswith("intel/"): if driver is not None and driver.startswith("intel/"):
for key, value in [ for key, value in [
@@ -384,7 +382,7 @@ class Driver(common.Driver):
if (vendor_id, product_id) not in INTEL_USB_PRODUCTS: if (vendor_id, product_id) not in INTEL_USB_PRODUCTS:
logger.debug( logger.debug(
f"USB device ({vendor_id:04X}, {product_id:04X}) not in known list" f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
) )
return False return False
@@ -406,7 +404,7 @@ class Driver(common.Driver):
self.host.on_hci_event_packet(event) self.host.on_hci_event_packet(event)
return return
if not event.return_parameters.status == hci.HCI_SUCCESS: if not event.return_parameters == hci.HCI_SUCCESS:
raise DriverError("HCI_Command_Complete_Event error") raise DriverError("HCI_Command_Complete_Event error")
if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets: if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets:
@@ -463,10 +461,6 @@ class Driver(common.Driver):
== ModeOfOperation.OPERATIONAL == ModeOfOperation.OPERATIONAL
): ):
logger.debug("firmware already loaded") logger.debug("firmware already loaded")
# If the firmeare is already loaded, still attempt to load any
# device configuration (DDC). DDC can be applied independently of a
# firmware reload and may contain runtime overrides or patches.
await self.load_ddc_if_any()
return return
# We only support some platforms and variants. # We only support some platforms and variants.
@@ -477,7 +471,6 @@ class Driver(common.Driver):
raise DriverError("hardware platform not supported") raise DriverError("hardware platform not supported")
if hardware_info.variant not in ( if hardware_info.variant not in (
HardwareVariant.TYPHOON_PEAK, HardwareVariant.TYPHOON_PEAK,
HardwareVariant.GARFIELD_PEAK,
HardwareVariant.GALE_PEAK, HardwareVariant.GALE_PEAK,
): ):
raise DriverError("hardware variant not supported") raise DriverError("hardware variant not supported")
@@ -487,7 +480,9 @@ class Driver(common.Driver):
raise DriverError("insufficient device info, missing CNVI or CNVR") raise DriverError("insufficient device info, missing CNVI or CNVR")
firmware_base_name = ( firmware_base_name = (
f"ibt-{device_info[ValueType.CNVI]:04X}-{device_info[ValueType.CNVR]:04X}" "ibt-"
f"{device_info[ValueType.CNVI]:04X}-"
f"{device_info[ValueType.CNVR]:04X}"
) )
logger.debug(f"FW base name: {firmware_base_name}") logger.debug(f"FW base name: {firmware_base_name}")
@@ -604,39 +599,17 @@ class Driver(common.Driver):
await self.reset_complete.wait() await self.reset_complete.wait()
logger.debug("reset complete") logger.debug("reset complete")
await self.load_ddc_if_any(firmware_base_name) # Load the device config if there is one.
async def load_ddc_if_any(self, firmware_base_name: str | None = None) -> None:
"""
Check for and load any Device Data Configuration (DDC) blobs.
Args:
firmware_base_name: Base name of the selected firmware (e.g. "ibt-XXXX-YYYY").
If None, don't attempt to look up a .ddc file that
corresponds to the firmware image.
Priority:
1. If a ddc_override was provided via driver metadata, use it (highest priority).
2. Otherwise, if firmware_base_name is provided, attempt to find a .ddc file
that corresponds to the selected firmware image.
3. Finally, if a ddc_addon was provided, append/load it after the primary DDC.
"""
# If an explicit DDC override was supplied, use it and skip file lookup.
if self.ddc_override: if self.ddc_override:
logger.debug("loading overridden DDC") logger.debug("loading overridden DDC")
await self.load_device_config(self.ddc_override) await self.load_device_config(self.ddc_override)
else: else:
# Only attempt .ddc file lookup if a firmware_base_name was provided. ddc_name = f"{firmware_base_name}.ddc"
if firmware_base_name is None: ddc_path = _find_binary_path(ddc_name)
logger.debug( if ddc_path:
"no firmware_base_name provided; skipping .ddc file lookup" logger.debug(f"loading DDC from {ddc_path}")
) ddc_data = ddc_path.read_bytes()
else: await self.load_device_config(ddc_data)
ddc_name = f"{firmware_base_name}.ddc"
ddc_path = _find_binary_path(ddc_name)
if ddc_path:
logger.debug(f"loading DDC from {ddc_path}")
ddc_data = ddc_path.read_bytes()
await self.load_device_config(ddc_data)
if self.ddc_addon: if self.ddc_addon:
logger.debug("loading DDC addon") logger.debug("loading DDC addon")
await self.load_device_config(self.ddc_addon) await self.load_device_config(self.ddc_addon)
@@ -645,8 +618,8 @@ class Driver(common.Driver):
while ddc_data: while ddc_data:
ddc_len = 1 + ddc_data[0] ddc_len = 1 + ddc_data[0]
ddc_payload = ddc_data[:ddc_len] ddc_payload = ddc_data[:ddc_len]
await self.host.send_sync_command( await self.host.send_command(
HCI_Intel_Write_Device_Config_Command(data=ddc_payload) Hci_Intel_Write_Device_Config_Command(data=ddc_payload)
) )
ddc_data = ddc_data[ddc_len:] ddc_data = ddc_data[ddc_len:]
@@ -664,34 +637,31 @@ class Driver(common.Driver):
async def read_device_info(self) -> dict[ValueType, Any]: async def read_device_info(self) -> dict[ValueType, Any]:
self.host.ready = True self.host.ready = True
response1 = await self.host.send_sync_command_raw(hci.HCI_Reset_Command()) response = await self.host.send_command(hci.HCI_Reset_Command())
if not isinstance( if not (
response1.return_parameters, hci.HCI_StatusReturnParameters isinstance(response, hci.HCI_Command_Complete_Event)
) or response1.return_parameters.status not in ( and response.return_parameters
hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS)
hci.HCI_SUCCESS,
): ):
# When the controller is in operational mode, the response is a # When the controller is in operational mode, the response is a
# successful response. # successful response.
# When the controller is in bootloader mode, # When the controller is in bootloader mode,
# HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything # HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything
# else is a failure. # else is a failure.
logger.warning(f"unexpected response: {response1}") logger.warning(f"unexpected response: {response}")
raise DriverError("unexpected HCI response") raise DriverError("unexpected HCI response")
# Read the firmware version. # Read the firmware version.
response2 = await self.host.send_sync_command_raw( response = await self.host.send_command(
HCI_Intel_Read_Version_Command(param0=0xFF) HCI_Intel_Read_Version_Command(param0=0xFF)
) )
if ( if not isinstance(response, hci.HCI_Command_Complete_Event):
not isinstance( raise DriverError("unexpected HCI response")
response2.return_parameters, HCI_Intel_Read_Version_ReturnParameters
) if response.return_parameters.status != 0: # type: ignore
or response2.return_parameters.status != 0
):
raise DriverError("HCI_Intel_Read_Version_Command error") raise DriverError("HCI_Intel_Read_Version_Command error")
tlvs = _parse_tlv(response2.return_parameters.tlv) # type: ignore tlvs = _parse_tlv(response.return_parameters.tlv) # type: ignore
# Convert the list to a dict. That's Ok here because we only expect each type # Convert the list to a dict. That's Ok here because we only expect each type
# to appear just once. # to appear just once.
+61 -132
View File
@@ -16,8 +16,11 @@ Support for Realtek USB dongles.
Based on various online bits of information, including the Linux kernel. Based on various online bits of information, including the Linux kernel.
(see `drivers/bluetooth/btrtl.c`) (see `drivers/bluetooth/btrtl.c`)
""" """
from __future__ import annotations
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from dataclasses import dataclass
import asyncio import asyncio
import enum import enum
import logging import logging
@@ -26,20 +29,21 @@ import os
import pathlib import pathlib
import platform import platform
import struct import struct
from typing import Tuple
import weakref import weakref
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from bumble import core, hci from bumble import core
from bumble.hci import (
hci_vendor_command_op_code,
STATUS_SPEC,
HCI_SUCCESS,
HCI_Command,
HCI_Reset_Command,
HCI_Read_Local_Version_Information_Command,
)
from bumble.drivers import common from bumble.drivers import common
if TYPE_CHECKING:
from bumble.host import Host
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -82,7 +86,6 @@ class RtlProjectId(enum.IntEnum):
PROJECT_ID_8852A = 18 PROJECT_ID_8852A = 18
PROJECT_ID_8852B = 20 PROJECT_ID_8852B = 20
PROJECT_ID_8852C = 25 PROJECT_ID_8852C = 25
PROJECT_ID_8761C = 51
RTK_PROJECT_ID_TO_ROM = { RTK_PROJECT_ID_TO_ROM = {
@@ -98,7 +101,6 @@ RTK_PROJECT_ID_TO_ROM = {
18: RTK_ROM_LMP_8852A, 18: RTK_ROM_LMP_8852A,
20: RTK_ROM_LMP_8852A, 20: RTK_ROM_LMP_8852A,
25: RTK_ROM_LMP_8852A, 25: RTK_ROM_LMP_8852A,
51: RTK_ROM_LMP_8761A,
} }
# List of USB (VendorID, ProductID) for Realtek-based devices. # List of USB (VendorID, ProductID) for Realtek-based devices.
@@ -122,19 +124,12 @@ RTK_USB_PRODUCTS = {
# Realtek 8761BUV # Realtek 8761BUV
(0x0B05, 0x190E), (0x0B05, 0x190E),
(0x0BDA, 0x8771), (0x0BDA, 0x8771),
(0x0BDA, 0x877B),
(0x0BDA, 0xA728),
(0x0BDA, 0xA729),
(0x2230, 0x0016), (0x2230, 0x0016),
(0x2357, 0x0604), (0x2357, 0x0604),
(0x2550, 0x8761), (0x2550, 0x8761),
(0x2B89, 0x8761), (0x2B89, 0x8761),
(0x2C0A, 0x8761),
(0x7392, 0xC611), (0x7392, 0xC611),
# Realtek 8761CUV (0x0BDA, 0x877B),
(0x0B05, 0x1BF6),
(0x0BDA, 0xC761),
(0x7392, 0xF611),
# Realtek 8821AE # Realtek 8821AE
(0x0B05, 0x17DC), (0x0B05, 0x17DC),
(0x13D3, 0x3414), (0x13D3, 0x3414),
@@ -188,42 +183,27 @@ RTK_USB_PRODUCTS = {
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# HCI Commands # HCI Commands
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
HCI_RTK_READ_ROM_VERSION_COMMAND = hci.hci_vendor_command_op_code(0x6D) HCI_RTK_READ_ROM_VERSION_COMMAND = hci_vendor_command_op_code(0x6D)
HCI_RTK_DOWNLOAD_COMMAND = hci.hci_vendor_command_op_code(0x20) HCI_RTK_DOWNLOAD_COMMAND = hci_vendor_command_op_code(0x20)
HCI_RTK_DROP_FIRMWARE_COMMAND = hci.hci_vendor_command_op_code(0x66) HCI_RTK_DROP_FIRMWARE_COMMAND = hci_vendor_command_op_code(0x66)
hci.HCI_Command.register_commands(globals()) HCI_Command.register_commands(globals())
@dataclass @HCI_Command.command(return_parameters_fields=[("status", STATUS_SPEC), ("version", 1)])
class HCI_RTK_Read_ROM_Version_ReturnParameters(hci.HCI_StatusReturnParameters): class HCI_RTK_Read_ROM_Version_Command(HCI_Command):
version: int = field(metadata=hci.metadata(1))
@hci.HCI_SyncCommand.sync_command(HCI_RTK_Read_ROM_Version_ReturnParameters)
@dataclass
class HCI_RTK_Read_ROM_Version_Command(
hci.HCI_SyncCommand[HCI_RTK_Read_ROM_Version_ReturnParameters]
):
pass pass
@dataclass @HCI_Command.command(
class HCI_RTK_Download_ReturnParameters(hci.HCI_StatusReturnParameters): fields=[("index", 1), ("payload", RTK_FRAGMENT_LENGTH)],
index: int = field(metadata=hci.metadata(1)) return_parameters_fields=[("status", STATUS_SPEC), ("index", 1)],
)
class HCI_RTK_Download_Command(HCI_Command):
pass
@hci.HCI_SyncCommand.sync_command(HCI_RTK_Download_ReturnParameters) @HCI_Command.command()
@dataclass class HCI_RTK_Drop_Firmware_Command(HCI_Command):
class HCI_RTK_Download_Command(hci.HCI_SyncCommand[HCI_RTK_Download_ReturnParameters]):
index: int = field(metadata=hci.metadata(1))
payload: bytes = field(metadata=hci.metadata(RTK_FRAGMENT_LENGTH))
@hci.HCI_SyncCommand.sync_command(hci.HCI_GenericReturnParameters)
@dataclass
class HCI_RTK_Drop_Firmware_Command(
hci.HCI_SyncCommand[hci.HCI_GenericReturnParameters]
):
pass pass
@@ -314,7 +294,7 @@ class Driver(common.Driver):
@dataclass @dataclass
class DriverInfo: class DriverInfo:
rom: int rom: int
hci: tuple[int, int] hci: Tuple[int, int]
config_needed: bool config_needed: bool
has_rom_version: bool has_rom_version: bool
has_msft_ext: bool = False has_msft_ext: bool = False
@@ -388,15 +368,6 @@ class Driver(common.Driver):
fw_name="rtl8761bu_fw.bin", fw_name="rtl8761bu_fw.bin",
config_name="rtl8761bu_config.bin", config_name="rtl8761bu_config.bin",
), ),
# 8761CU
DriverInfo(
rom=RTK_ROM_LMP_8761A,
hci=(0x0E, 0x00),
config_needed=False,
has_rom_version=True,
fw_name="rtl8761cu_fw.bin",
config_name="rtl8761cu_config.bin",
),
# 8822C # 8822C
DriverInfo( DriverInfo(
rom=RTK_ROM_LMP_8822B, rom=RTK_ROM_LMP_8822B,
@@ -454,17 +425,9 @@ class Driver(common.Driver):
@staticmethod @staticmethod
def find_driver_info(hci_version, hci_subversion, lmp_subversion): def find_driver_info(hci_version, hci_subversion, lmp_subversion):
for driver_info in Driver.DRIVER_INFOS: for driver_info in Driver.DRIVER_INFOS:
if driver_info.rom == lmp_subversion and ( if driver_info.rom == lmp_subversion and driver_info.hci == (
driver_info.hci hci_subversion,
== ( hci_version,
hci_subversion,
hci_version,
)
or driver_info.hci
== (
hci_subversion,
0x0,
)
): ):
return driver_info return driver_info
@@ -509,7 +472,7 @@ class Driver(common.Driver):
return None return None
@staticmethod @staticmethod
def check(host: Host) -> bool: def check(host):
if not host.hci_metadata: if not host.hci_metadata:
logger.debug("USB metadata not found") logger.debug("USB metadata not found")
return False return False
@@ -526,51 +489,29 @@ class Driver(common.Driver):
if (vendor_id, product_id) not in RTK_USB_PRODUCTS: if (vendor_id, product_id) not in RTK_USB_PRODUCTS:
logger.debug( logger.debug(
f"USB device ({vendor_id:04X}, {product_id:04X}) not in known list" f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
) )
return False return False
return True return True
@staticmethod
async def get_loaded_firmware_version(host: Host) -> int | None:
response1 = await host.send_sync_command_raw(HCI_RTK_Read_ROM_Version_Command())
if (
not isinstance(
response1.return_parameters, HCI_RTK_Read_ROM_Version_ReturnParameters
)
or response1.return_parameters.status != hci.HCI_SUCCESS
):
return None
response2 = await host.send_sync_command(
hci.HCI_Read_Local_Version_Information_Command()
)
return response2.hci_subversion << 16 | response2.lmp_subversion
@classmethod @classmethod
async def driver_info_for_host(cls, host: Host) -> DriverInfo | None: async def driver_info_for_host(cls, host):
try: try:
await host.send_sync_command( await host.send_command(
hci.HCI_Reset_Command(), HCI_Reset_Command(),
check_result=True,
response_timeout=cls.POST_RESET_DELAY, response_timeout=cls.POST_RESET_DELAY,
) )
host.ready = True # Needed to let the host know the controller is ready. host.ready = True # Needed to let the host know the controller is ready.
except asyncio.exceptions.TimeoutError: except asyncio.exceptions.TimeoutError:
logger.warning("timeout waiting for hci reset, retrying") logger.warning("timeout waiting for hci reset, retrying")
await host.send_sync_command(hci.HCI_Reset_Command()) await host.send_command(HCI_Reset_Command(), check_result=True)
host.ready = True host.ready = True
response = await host.send_sync_command_raw( command = HCI_Read_Local_Version_Information_Command()
hci.HCI_Read_Local_Version_Information_Command() response = await host.send_command(command, check_result=True)
) if response.command_opcode != command.op_code:
if (
not isinstance(
response.return_parameters,
hci.HCI_Read_Local_Version_Information_ReturnParameters,
)
or response.return_parameters.status != hci.HCI_SUCCESS
):
logger.error("failed to probe local version information") logger.error("failed to probe local version information")
return None return None
@@ -595,7 +536,7 @@ class Driver(common.Driver):
return driver_info return driver_info
@classmethod @classmethod
async def for_host(cls, host: Host, force: bool = False): async def for_host(cls, host, force=False):
# Check that a driver is needed for this host # Check that a driver is needed for this host
if not force and not cls.check(host): if not force and not cls.check(host):
return None return None
@@ -650,35 +591,28 @@ class Driver(common.Driver):
# TODO: load the firmware # TODO: load the firmware
async def download_for_rtl8723b(self) -> int | None: async def download_for_rtl8723b(self):
if self.driver_info.has_rom_version: if self.driver_info.has_rom_version:
response1 = await self.host.send_sync_command_raw( response = await self.host.send_command(
HCI_RTK_Read_ROM_Version_Command() HCI_RTK_Read_ROM_Version_Command(), check_result=True
) )
if ( if response.return_parameters.status != HCI_SUCCESS:
not isinstance(
response1.return_parameters,
HCI_RTK_Read_ROM_Version_ReturnParameters,
)
or response1.return_parameters.status != hci.HCI_SUCCESS
):
logger.warning("can't get ROM version") logger.warning("can't get ROM version")
return None return
rom_version = response1.return_parameters.version rom_version = response.return_parameters.version
logger.debug(f"ROM version before download: {rom_version:04X}") logger.debug(f"ROM version before download: {rom_version:04X}")
else: else:
rom_version = 0 rom_version = 0
firmware = Firmware(self.firmware) firmware = Firmware(self.firmware)
logger.debug(f"firmware: project_id=0x{firmware.project_id:04X}") logger.debug(f"firmware: project_id=0x{firmware.project_id:04X}")
logger.debug(f"firmware: version=0x{firmware.version:04X}")
for patch in firmware.patches: for patch in firmware.patches:
if patch[0] == rom_version + 1: if patch[0] == rom_version + 1:
logger.debug(f"using patch {patch[0]}") logger.debug(f"using patch {patch[0]}")
break break
else: else:
logger.warning("no valid patch found for rom version {rom_version}") logger.warning("no valid patch found for rom version {rom_version}")
return None return
# Append the config if there is one. # Append the config if there is one.
if self.config: if self.config:
@@ -699,28 +633,23 @@ class Driver(common.Driver):
fragment_offset = fragment_index * RTK_FRAGMENT_LENGTH fragment_offset = fragment_index * RTK_FRAGMENT_LENGTH
fragment = payload[fragment_offset : fragment_offset + RTK_FRAGMENT_LENGTH] fragment = payload[fragment_offset : fragment_offset + RTK_FRAGMENT_LENGTH]
logger.debug(f"downloading fragment {fragment_index}") logger.debug(f"downloading fragment {fragment_index}")
await self.host.send_sync_command( await self.host.send_command(
HCI_RTK_Download_Command(index=download_index, payload=fragment) HCI_RTK_Download_Command(
index=download_index, payload=fragment, check_result=True
)
) )
logger.debug("download complete!") logger.debug("download complete!")
# Read the version again # Read the version again
response2 = await self.host.send_sync_command_raw( response = await self.host.send_command(
HCI_RTK_Read_ROM_Version_Command() HCI_RTK_Read_ROM_Version_Command(), check_result=True
) )
if ( if response.return_parameters.status != HCI_SUCCESS:
not isinstance(
response2.return_parameters, HCI_RTK_Read_ROM_Version_ReturnParameters
)
or response2.return_parameters.status != hci.HCI_SUCCESS
):
logger.warning("can't get ROM version") logger.warning("can't get ROM version")
else: else:
rom_version = response2.return_parameters.version rom_version = response.return_parameters.version
logger.debug(f"ROM version after download: {rom_version:02X}") logger.debug(f"ROM version after download: {rom_version:04X}")
return firmware.version
async def download_firmware(self): async def download_firmware(self):
if self.driver_info.rom == RTK_ROM_LMP_8723A: if self.driver_info.rom == RTK_ROM_LMP_8723A:
@@ -739,7 +668,7 @@ class Driver(common.Driver):
async def init_controller(self): async def init_controller(self):
await self.download_firmware() await self.download_firmware()
await self.host.send_sync_command(hci.HCI_Reset_Command()) await self.host.send_command(HCI_Reset_Command(), check_result=True)
logger.info(f"loaded FW image {self.driver_info.fw_name}") logger.info(f"loaded FW image {self.driver_info.fw_name}")
+60
View File
@@ -0,0 +1,60 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
import struct
from bumble.gatt import (
Service,
Characteristic,
GATT_GENERIC_ACCESS_SERVICE,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC,
)
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class GenericAccessService(Service):
def __init__(self, device_name, appearance=(0, 0)):
device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
device_name.encode('utf-8')[:248],
)
appearance_characteristic = Characteristic(
GATT_APPEARANCE_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READABLE,
struct.pack('<H', (appearance[0] << 6) | appearance[1]),
)
super().__init__(
GATT_GENERIC_ACCESS_SERVICE,
[device_name_characteristic, appearance_characteristic],
)
+19 -16
View File
@@ -23,17 +23,15 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import enum import enum
import functools import functools
import logging import logging
import struct import struct
from collections.abc import Iterable, Sequence from typing import Iterable, List, Optional, Sequence, TypeVar, Union
from typing import ClassVar, TypeVar
from bumble.att import Attribute, AttributeValue, AttributeValueV2
from bumble.colors import color from bumble.colors import color
from bumble.core import UUID, BaseBumbleError from bumble.core import BaseBumbleError, UUID
from bumble.att import Attribute, AttributeValue
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Typing # Typing
@@ -228,6 +226,7 @@ GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x
GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA5, 'Media Control Point Opcodes Supported') GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA5, 'Media Control Point Opcodes Supported')
GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA6, 'Search Results Object ID') GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA6, 'Search Results Object ID')
GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA7, 'Search Control Point') GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA7, 'Search Control Point')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control Id')
# Telephone Bearer Service (TBS) # Telephone Bearer Service (TBS)
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB3, 'Bearer Provider Name') GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB3, 'Bearer Provider Name')
@@ -351,12 +350,12 @@ class Service(Attribute):
''' '''
uuid: UUID uuid: UUID
characteristics: list[Characteristic] characteristics: List[Characteristic]
included_services: list[Service] included_services: List[Service]
def __init__( def __init__(
self, self,
uuid: str | UUID, uuid: Union[str, UUID],
characteristics: Iterable[Characteristic], characteristics: Iterable[Characteristic],
primary=True, primary=True,
included_services: Iterable[Service] = (), included_services: Iterable[Service] = (),
@@ -379,7 +378,7 @@ class Service(Attribute):
self.characteristics = list(characteristics) self.characteristics = list(characteristics)
self.primary = primary self.primary = primary
def get_advertising_data(self) -> bytes | None: def get_advertising_data(self) -> Optional[bytes]:
""" """
Get Service specific advertising data Get Service specific advertising data
Defined by each Service, default value is empty Defined by each Service, default value is empty
@@ -403,7 +402,7 @@ class TemplateService(Service):
to expose their UUID as a class property to expose their UUID as a class property
''' '''
UUID: ClassVar[UUID] UUID: UUID
def __init__( def __init__(
self, self,
@@ -475,7 +474,7 @@ class Characteristic(Attribute[_T]):
# The check for `p.name is not None` here is needed because for InFlag # The check for `p.name is not None` here is needed because for InFlag
# enums, the .name property can be None, when the enum value is 0, # enums, the .name property can be None, when the enum value is 0,
# so the type hint for .name is Optional[str]. # so the type hint for .name is Optional[str].
enum_list: list[str] = [p.name for p in cls if p.name is not None] enum_list: List[str] = [p.name for p in cls if p.name is not None]
enum_list_str = ",".join(enum_list) enum_list_str = ",".join(enum_list)
raise TypeError( raise TypeError(
f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}" f"Characteristic.Properties::from_string() error:\nExpected a string containing any of the keys, separated by , or |: {enum_list_str}\nGot: {properties_str}"
@@ -503,10 +502,10 @@ class Characteristic(Attribute[_T]):
def __init__( def __init__(
self, self,
uuid: str | bytes | UUID, uuid: Union[str, bytes, UUID],
properties: Characteristic.Properties, properties: Characteristic.Properties,
permissions: str | Attribute.Permissions, permissions: Union[str, Attribute.Permissions],
value: AttributeValue[_T] | _T | None = None, value: Union[AttributeValue[_T], _T, None] = None,
descriptors: Sequence[Descriptor] = (), descriptors: Sequence[Descriptor] = (),
): ):
super().__init__(uuid, permissions, value) super().__init__(uuid, permissions, value)
@@ -579,8 +578,12 @@ class Descriptor(Attribute):
def __str__(self) -> str: def __str__(self) -> str:
if isinstance(self.value, bytes): if isinstance(self.value, bytes):
value_str = self.value.hex() value_str = self.value.hex()
elif isinstance(self.value, (AttributeValue, AttributeValueV2)): elif isinstance(self.value, CharacteristicValue):
value_str = '<dynamic>' value = self.value.read(None)
if isinstance(value, bytes):
value_str = value.hex()
else:
value_str = '<async>'
else: else:
value_str = '<...>' value_str = '<...>'
return ( return (
+21 -12
View File
@@ -20,15 +20,23 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import struct import struct
from collections.abc import Callable, Iterable from typing import (
from typing import Any, Generic, Literal, TypeVar Any,
Callable,
Generic,
Iterable,
Literal,
Optional,
Type,
TypeVar,
)
from bumble import utils
from bumble.core import InvalidOperationError from bumble.core import InvalidOperationError
from bumble.gatt import Characteristic from bumble.gatt import Characteristic
from bumble.gatt_client import CharacteristicProxy from bumble.gatt_client import CharacteristicProxy
from bumble import utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Typing # Typing
@@ -75,8 +83,8 @@ class DelegatedCharacteristicAdapter(CharacteristicAdapter[_T]):
def __init__( def __init__(
self, self,
characteristic: Characteristic, characteristic: Characteristic,
encode: Callable[[_T], bytes] | None = None, encode: Optional[Callable[[_T], bytes]] = None,
decode: Callable[[bytes], _T] | None = None, decode: Optional[Callable[[bytes], _T]] = None,
): ):
super().__init__(characteristic) super().__init__(characteristic)
self.encode = encode self.encode = encode
@@ -102,8 +110,8 @@ class DelegatedCharacteristicProxyAdapter(CharacteristicProxyAdapter[_T]):
def __init__( def __init__(
self, self,
characteristic_proxy: CharacteristicProxy, characteristic_proxy: CharacteristicProxy,
encode: Callable[[_T], bytes] | None = None, encode: Optional[Callable[[_T], bytes]] = None,
decode: Callable[[bytes], _T] | None = None, decode: Optional[Callable[[bytes], _T]] = None,
): ):
super().__init__(characteristic_proxy) super().__init__(characteristic_proxy)
self.encode = encode self.encode = encode
@@ -262,7 +270,7 @@ class SerializableCharacteristicAdapter(CharacteristicAdapter[_T2]):
`to_bytes` and `__bytes__` methods, respectively. `to_bytes` and `__bytes__` methods, respectively.
''' '''
def __init__(self, characteristic: Characteristic, cls: type[_T2]) -> None: def __init__(self, characteristic: Characteristic, cls: Type[_T2]) -> None:
super().__init__(characteristic) super().__init__(characteristic)
self.cls = cls self.cls = cls
@@ -281,7 +289,7 @@ class SerializableCharacteristicProxyAdapter(CharacteristicProxyAdapter[_T2]):
''' '''
def __init__( def __init__(
self, characteristic_proxy: CharacteristicProxy, cls: type[_T2] self, characteristic_proxy: CharacteristicProxy, cls: Type[_T2]
) -> None: ) -> None:
super().__init__(characteristic_proxy) super().__init__(characteristic_proxy)
self.cls = cls self.cls = cls
@@ -303,7 +311,7 @@ class EnumCharacteristicAdapter(CharacteristicAdapter[_T3]):
def __init__( def __init__(
self, self,
characteristic: Characteristic, characteristic: Characteristic,
cls: type[_T3], cls: Type[_T3],
length: int, length: int,
byteorder: Literal['little', 'big'] = 'little', byteorder: Literal['little', 'big'] = 'little',
): ):
@@ -339,7 +347,7 @@ class EnumCharacteristicProxyAdapter(CharacteristicProxyAdapter[_T3]):
def __init__( def __init__(
self, self,
characteristic_proxy: CharacteristicProxy, characteristic_proxy: CharacteristicProxy,
cls: type[_T3], cls: Type[_T3],
length: int, length: int,
byteorder: Literal['little', 'big'] = 'little', byteorder: Literal['little', 'big'] = 'little',
): ):
@@ -362,4 +370,5 @@ class EnumCharacteristicProxyAdapter(CharacteristicProxyAdapter[_T3]):
def decode_value(self, value: bytes) -> _T3: def decode_value(self, value: bytes) -> _T3:
int_value = int.from_bytes(value, self.byteorder) int_value = int.from_bytes(value, self.byteorder)
a = self.cls(int_value)
return self.cls(int_value) return self.cls(int_value)
+160 -194
View File
@@ -24,47 +24,72 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import functools
import logging import logging
import struct import struct
from collections.abc import Callable, Iterable
from datetime import datetime from datetime import datetime
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
ClassVar, Callable,
Dict,
Generic, Generic,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
Type,
TypeVar, TypeVar,
overload, TYPE_CHECKING,
) )
from typing_extensions import Self
from bumble import att, core, l2cap, utils
from bumble.colors import color from bumble.colors import color
from bumble.hci import HCI_Constant
from bumble.att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
ATT_DEFAULT_MTU,
ATT_ERROR_RESPONSE,
ATT_INVALID_OFFSET_ERROR,
ATT_PDU,
ATT_RESPONSES,
ATT_Exchange_MTU_Request,
ATT_Find_By_Type_Value_Request,
ATT_Find_Information_Request,
ATT_Handle_Value_Confirmation,
ATT_Read_Blob_Request,
ATT_Read_By_Group_Type_Request,
ATT_Read_By_Type_Request,
ATT_Read_Request,
ATT_Write_Command,
ATT_Write_Request,
ATT_Error,
)
from bumble import utils
from bumble import core
from bumble.core import UUID, InvalidStateError from bumble.core import UUID, InvalidStateError
from bumble.gatt import ( from bumble.gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_INCLUDE_ATTRIBUTE_TYPE,
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
GATT_REQUEST_TIMEOUT, GATT_REQUEST_TIMEOUT,
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE, GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
GATT_INCLUDE_ATTRIBUTE_TYPE,
Characteristic, Characteristic,
ClientCharacteristicConfigurationBits, ClientCharacteristicConfigurationBits,
InvalidServiceError, InvalidServiceError,
TemplateService, TemplateService,
) )
from bumble.hci import HCI_Constant
if TYPE_CHECKING:
from bumble import device as device_module
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Typing # Typing
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if TYPE_CHECKING:
from bumble.device import Connection
_T = TypeVar('_T') _T = TypeVar('_T')
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -124,8 +149,8 @@ class AttributeProxy(utils.EventEmitter, Generic[_T]):
class ServiceProxy(AttributeProxy): class ServiceProxy(AttributeProxy):
uuid: UUID uuid: UUID
characteristics: list[CharacteristicProxy[bytes]] characteristics: List[CharacteristicProxy[bytes]]
included_services: list[ServiceProxy] included_services: List[ServiceProxy]
@staticmethod @staticmethod
def from_client(service_class, client: Client, service_uuid: UUID): def from_client(service_class, client: Client, service_uuid: UUID):
@@ -174,8 +199,8 @@ class ServiceProxy(AttributeProxy):
class CharacteristicProxy(AttributeProxy[_T]): class CharacteristicProxy(AttributeProxy[_T]):
properties: Characteristic.Properties properties: Characteristic.Properties
descriptors: list[DescriptorProxy] descriptors: List[DescriptorProxy]
subscribers: dict[Any, Callable[[_T], Any]] subscribers: Dict[Any, Callable[[_T], Any]]
EVENT_UPDATE = "update" EVENT_UPDATE = "update"
@@ -194,7 +219,7 @@ class CharacteristicProxy(AttributeProxy[_T]):
self.descriptors_discovered = False self.descriptors_discovered = False
self.subscribers = {} # Map from subscriber to proxy subscriber self.subscribers = {} # Map from subscriber to proxy subscriber
def get_descriptor(self, descriptor_type: UUID) -> DescriptorProxy | None: def get_descriptor(self, descriptor_type: UUID) -> Optional[DescriptorProxy]:
for descriptor in self.descriptors: for descriptor in self.descriptors:
if descriptor.type == descriptor_type: if descriptor.type == descriptor_type:
return descriptor return descriptor
@@ -206,7 +231,7 @@ class CharacteristicProxy(AttributeProxy[_T]):
async def subscribe( async def subscribe(
self, self,
subscriber: Callable[[_T], Any] | None = None, subscriber: Optional[Callable[[_T], Any]] = None,
prefer_notify: bool = True, prefer_notify: bool = True,
) -> None: ) -> None:
if subscriber is not None: if subscriber is not None:
@@ -252,10 +277,10 @@ class ProfileServiceProxy:
Base class for profile-specific service proxies Base class for profile-specific service proxies
''' '''
SERVICE_CLASS: ClassVar[type[TemplateService]] SERVICE_CLASS: Type[TemplateService]
@classmethod @classmethod
def from_client(cls, client: Client) -> Self | None: def from_client(cls, client: Client) -> Optional[ProfileServiceProxy]:
return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID) return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID)
@@ -263,17 +288,19 @@ class ProfileServiceProxy:
# GATT Client # GATT Client
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Client: class Client:
services: list[ServiceProxy] services: List[ServiceProxy]
cached_values: dict[int, tuple[datetime, bytes]] cached_values: Dict[int, Tuple[datetime, bytes]]
notification_subscribers: dict[ notification_subscribers: Dict[
int, set[CharacteristicProxy | Callable[[bytes], Any]] int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
] ]
indication_subscribers: dict[int, set[CharacteristicProxy | Callable[[bytes], Any]]] indication_subscribers: Dict[
pending_response: asyncio.futures.Future[att.ATT_PDU] | None int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
pending_request: att.ATT_PDU | None ]
pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
pending_request: Optional[ATT_PDU]
def __init__(self, bearer: att.Bearer) -> None: def __init__(self, connection: Connection) -> None:
self.bearer = bearer self.connection = connection
self.mtu_exchange_done = False self.mtu_exchange_done = False
self.request_semaphore = asyncio.Semaphore(1) self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None self.pending_request = None
@@ -283,76 +310,21 @@ class Client:
self.services = [] self.services = []
self.cached_values = {} self.cached_values = {}
if att.is_enhanced_bearer(bearer): connection.on(connection.EVENT_DISCONNECTION, self.on_disconnection)
bearer.on(bearer.EVENT_CLOSE, self.on_disconnection)
self._bearer_id = (
f'[0x{bearer.connection.handle:04X}|CID=0x{bearer.source_cid:04X}]'
)
self.connection = bearer.connection
else:
bearer.on(bearer.EVENT_DISCONNECTION, self.on_disconnection)
self._bearer_id = f'[0x{bearer.handle:04X}]'
self.connection = bearer
@overload
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
) -> Client: ...
@overload
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
count: int = 1,
) -> list[Client]: ...
@classmethod
async def connect_eatt(
cls,
connection: device_module.Connection,
spec: l2cap.LeCreditBasedChannelSpec | None = None,
count: int = 1,
) -> list[Client] | Client:
channels = await connection.device.l2cap_channel_manager.create_enhanced_credit_based_channels(
connection,
spec or l2cap.LeCreditBasedChannelSpec(psm=att.EATT_PSM),
count,
)
def on_pdu(client: Client, pdu: bytes):
client.on_gatt_pdu(att.ATT_PDU.from_bytes(pdu))
clients = [cls(channel) for channel in channels]
for channel, client in zip(channels, clients):
channel.sink = functools.partial(on_pdu, client)
channel.att_mtu = att.ATT_DEFAULT_MTU
return clients[0] if count == 1 else clients
@property
def mtu(self) -> int:
return self.bearer.att_mtu
@mtu.setter
def mtu(self, value: int) -> None:
self.bearer.on_att_mtu_update(value)
def send_gatt_pdu(self, pdu: bytes) -> None: def send_gatt_pdu(self, pdu: bytes) -> None:
if att.is_enhanced_bearer(self.bearer): self.connection.send_l2cap_pdu(ATT_CID, pdu)
self.bearer.write(pdu)
else:
self.bearer.send_l2cap_pdu(att.ATT_CID, pdu)
async def send_command(self, command: att.ATT_PDU) -> None: async def send_command(self, command: ATT_PDU) -> None:
logger.debug(f'GATT Command from client: {self._bearer_id} {command}') logger.debug(
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
)
self.send_gatt_pdu(bytes(command)) self.send_gatt_pdu(bytes(command))
async def send_request(self, request: att.ATT_PDU): async def send_request(self, request: ATT_PDU):
logger.debug(f'GATT Request from client: {self._bearer_id} {request}') logger.debug(
f'GATT Request from client: [0x{self.connection.handle:04X}] {request}'
)
# Wait until we can send (only one pending command at a time for the connection) # Wait until we can send (only one pending command at a time for the connection)
response = None response = None
@@ -378,42 +350,41 @@ class Client:
return response return response
def send_confirmation( def send_confirmation(self, confirmation: ATT_Handle_Value_Confirmation) -> None:
self, confirmation: att.ATT_Handle_Value_Confirmation logger.debug(
) -> None: f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
logger.debug(f'GATT Confirmation from client: {self._bearer_id} {confirmation}') f'{confirmation}'
)
self.send_gatt_pdu(bytes(confirmation)) self.send_gatt_pdu(bytes(confirmation))
async def request_mtu(self, mtu: int) -> int: async def request_mtu(self, mtu: int) -> int:
# Check the range # Check the range
if mtu < att.ATT_DEFAULT_MTU: if mtu < ATT_DEFAULT_MTU:
raise core.InvalidArgumentError(f'MTU must be >= {att.ATT_DEFAULT_MTU}') raise core.InvalidArgumentError(f'MTU must be >= {ATT_DEFAULT_MTU}')
if mtu > 0xFFFF: if mtu > 0xFFFF:
raise core.InvalidArgumentError('MTU must be <= 0xFFFF') raise core.InvalidArgumentError('MTU must be <= 0xFFFF')
# We can only send one request per connection # We can only send one request per connection
if self.mtu_exchange_done: if self.mtu_exchange_done:
return self.mtu return self.connection.att_mtu
# Send the request # Send the request
self.mtu_exchange_done = True self.mtu_exchange_done = True
response = await self.send_request( response = await self.send_request(ATT_Exchange_MTU_Request(client_rx_mtu=mtu))
att.ATT_Exchange_MTU_Request(client_rx_mtu=mtu) if response.op_code == ATT_ERROR_RESPONSE:
) raise ATT_Error(error_code=response.error_code, message=response)
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE:
raise att.ATT_Error(error_code=response.error_code, message=response)
# Compute the final MTU # Compute the final MTU
self.mtu = min(mtu, response.server_rx_mtu) self.connection.att_mtu = min(mtu, response.server_rx_mtu)
return self.mtu return self.connection.att_mtu
def get_services_by_uuid(self, uuid: UUID) -> list[ServiceProxy]: def get_services_by_uuid(self, uuid: UUID) -> List[ServiceProxy]:
return [service for service in self.services if service.uuid == uuid] return [service for service in self.services if service.uuid == uuid]
def get_characteristics_by_uuid( def get_characteristics_by_uuid(
self, uuid: UUID, service: ServiceProxy | None = None self, uuid: UUID, service: Optional[ServiceProxy] = None
) -> list[CharacteristicProxy[bytes]]: ) -> List[CharacteristicProxy[bytes]]:
services = [service] if service else self.services services = [service] if service else self.services
return [ return [
c c
@@ -421,14 +392,13 @@ class Client:
if c.uuid == uuid if c.uuid == uuid
] ]
def get_attribute_grouping( def get_attribute_grouping(self, attribute_handle: int) -> Optional[
self, attribute_handle: int Union[
) -> ( ServiceProxy,
ServiceProxy Tuple[ServiceProxy, CharacteristicProxy],
| tuple[ServiceProxy, CharacteristicProxy] Tuple[ServiceProxy, CharacteristicProxy, DescriptorProxy],
| tuple[ServiceProxy, CharacteristicProxy, DescriptorProxy] ]
| None ]:
):
""" """
Get the attribute(s) associated with an attribute handle Get the attribute(s) associated with an attribute handle
""" """
@@ -459,7 +429,7 @@ class Client:
if not already_known: if not already_known:
self.services.append(service) self.services.append(service)
async def discover_services(self, uuids: Iterable[UUID] = ()) -> list[ServiceProxy]: async def discover_services(self, uuids: Iterable[UUID] = ()) -> List[ServiceProxy]:
''' '''
See Vol 3, Part G - 4.4.1 Discover All Primary Services See Vol 3, Part G - 4.4.1 Discover All Primary Services
''' '''
@@ -467,7 +437,7 @@ class Client:
services = [] services = []
while starting_handle < 0xFFFF: while starting_handle < 0xFFFF:
response = await self.send_request( response = await self.send_request(
att.ATT_Read_By_Group_Type_Request( ATT_Read_By_Group_Type_Request(
starting_handle=starting_handle, starting_handle=starting_handle,
ending_handle=0xFFFF, ending_handle=0xFFFF,
attribute_group_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, attribute_group_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
@@ -478,14 +448,14 @@ class Client:
return [] return []
# Check if we reached the end of the iteration # Check if we reached the end of the iteration
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning( logger.warning(
'!!! unexpected error while discovering services: ' '!!! unexpected error while discovering services: '
f'{HCI_Constant.error_name(response.error_code)}' f'{HCI_Constant.error_name(response.error_code)}'
) )
raise att.ATT_Error( raise ATT_Error(
error_code=response.error_code, error_code=response.error_code,
message='Unexpected error while discovering services', message='Unexpected error while discovering services',
) )
@@ -531,7 +501,7 @@ class Client:
return services return services
async def discover_service(self, uuid: str | UUID) -> list[ServiceProxy]: async def discover_service(self, uuid: Union[str, UUID]) -> List[ServiceProxy]:
''' '''
See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID See Vol 3, Part G - 4.4.2 Discover Primary Service by Service UUID
''' '''
@@ -544,7 +514,7 @@ class Client:
services = [] services = []
while starting_handle < 0xFFFF: while starting_handle < 0xFFFF:
response = await self.send_request( response = await self.send_request(
att.ATT_Find_By_Type_Value_Request( ATT_Find_By_Type_Value_Request(
starting_handle=starting_handle, starting_handle=starting_handle,
ending_handle=0xFFFF, ending_handle=0xFFFF,
attribute_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE, attribute_type=GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
@@ -556,8 +526,8 @@ class Client:
return [] return []
# Check if we reached the end of the iteration # Check if we reached the end of the iteration
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning( logger.warning(
'!!! unexpected error while discovering services: ' '!!! unexpected error while discovering services: '
@@ -602,7 +572,7 @@ class Client:
async def discover_included_services( async def discover_included_services(
self, service: ServiceProxy self, service: ServiceProxy
) -> list[ServiceProxy]: ) -> List[ServiceProxy]:
''' '''
See Vol 3, Part G - 4.5.1 Find Included Services See Vol 3, Part G - 4.5.1 Find Included Services
''' '''
@@ -610,10 +580,10 @@ class Client:
starting_handle = service.handle starting_handle = service.handle
ending_handle = service.end_group_handle ending_handle = service.end_group_handle
included_services: list[ServiceProxy] = [] included_services: List[ServiceProxy] = []
while starting_handle <= ending_handle: while starting_handle <= ending_handle:
response = await self.send_request( response = await self.send_request(
att.ATT_Read_By_Type_Request( ATT_Read_By_Type_Request(
starting_handle=starting_handle, starting_handle=starting_handle,
ending_handle=ending_handle, ending_handle=ending_handle,
attribute_type=GATT_INCLUDE_ATTRIBUTE_TYPE, attribute_type=GATT_INCLUDE_ATTRIBUTE_TYPE,
@@ -624,14 +594,14 @@ class Client:
return [] return []
# Check if we reached the end of the iteration # Check if we reached the end of the iteration
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning( logger.warning(
'!!! unexpected error while discovering included services: ' '!!! unexpected error while discovering included services: '
f'{HCI_Constant.error_name(response.error_code)}' f'{HCI_Constant.error_name(response.error_code)}'
) )
raise att.ATT_Error( raise ATT_Error(
error_code=response.error_code, error_code=response.error_code,
message='Unexpected error while discovering included services', message='Unexpected error while discovering included services',
) )
@@ -665,8 +635,8 @@ class Client:
return included_services return included_services
async def discover_characteristics( async def discover_characteristics(
self, uuids, service: ServiceProxy | None self, uuids, service: Optional[ServiceProxy]
) -> list[CharacteristicProxy[bytes]]: ) -> List[CharacteristicProxy[bytes]]:
''' '''
See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2 See Vol 3, Part G - 4.6.1 Discover All Characteristics of a Service and 4.6.2
Discover Characteristics by UUID Discover Characteristics by UUID
@@ -679,15 +649,15 @@ class Client:
services = [service] if service else self.services services = [service] if service else self.services
# Perform characteristic discovery for each service # Perform characteristic discovery for each service
discovered_characteristics: list[CharacteristicProxy[bytes]] = [] discovered_characteristics: List[CharacteristicProxy[bytes]] = []
for service in services: for service in services:
starting_handle = service.handle starting_handle = service.handle
ending_handle = service.end_group_handle ending_handle = service.end_group_handle
characteristics: list[CharacteristicProxy[bytes]] = [] characteristics: List[CharacteristicProxy[bytes]] = []
while starting_handle <= ending_handle: while starting_handle <= ending_handle:
response = await self.send_request( response = await self.send_request(
att.ATT_Read_By_Type_Request( ATT_Read_By_Type_Request(
starting_handle=starting_handle, starting_handle=starting_handle,
ending_handle=ending_handle, ending_handle=ending_handle,
attribute_type=GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, attribute_type=GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
@@ -698,14 +668,14 @@ class Client:
return [] return []
# Check if we reached the end of the iteration # Check if we reached the end of the iteration
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning( logger.warning(
'!!! unexpected error while discovering characteristics: ' '!!! unexpected error while discovering characteristics: '
f'{HCI_Constant.error_name(response.error_code)}' f'{HCI_Constant.error_name(response.error_code)}'
) )
raise att.ATT_Error( raise ATT_Error(
error_code=response.error_code, error_code=response.error_code,
message='Unexpected error while discovering characteristics', message='Unexpected error while discovering characteristics',
) )
@@ -752,10 +722,10 @@ class Client:
async def discover_descriptors( async def discover_descriptors(
self, self,
characteristic: CharacteristicProxy | None = None, characteristic: Optional[CharacteristicProxy] = None,
start_handle: int | None = None, start_handle: Optional[int] = None,
end_handle: int | None = None, end_handle: Optional[int] = None,
) -> list[DescriptorProxy]: ) -> List[DescriptorProxy]:
''' '''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
''' '''
@@ -768,10 +738,10 @@ class Client:
else: else:
return [] return []
descriptors: list[DescriptorProxy] = [] descriptors: List[DescriptorProxy] = []
while starting_handle <= ending_handle: while starting_handle <= ending_handle:
response = await self.send_request( response = await self.send_request(
att.ATT_Find_Information_Request( ATT_Find_Information_Request(
starting_handle=starting_handle, ending_handle=ending_handle starting_handle=starting_handle, ending_handle=ending_handle
) )
) )
@@ -780,8 +750,8 @@ class Client:
return [] return []
# Check if we reached the end of the iteration # Check if we reached the end of the iteration
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning( logger.warning(
'!!! unexpected error while discovering descriptors: ' '!!! unexpected error while discovering descriptors: '
@@ -817,7 +787,7 @@ class Client:
return descriptors return descriptors
async def discover_attributes(self) -> list[AttributeProxy[bytes]]: async def discover_attributes(self) -> List[AttributeProxy[bytes]]:
''' '''
Discover all attributes, regardless of type Discover all attributes, regardless of type
''' '''
@@ -826,7 +796,7 @@ class Client:
attributes = [] attributes = []
while True: while True:
response = await self.send_request( response = await self.send_request(
att.ATT_Find_Information_Request( ATT_Find_Information_Request(
starting_handle=starting_handle, ending_handle=ending_handle starting_handle=starting_handle, ending_handle=ending_handle
) )
) )
@@ -834,8 +804,8 @@ class Client:
return [] return []
# Check if we reached the end of the iteration # Check if we reached the end of the iteration
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning( logger.warning(
'!!! unexpected error while discovering attributes: ' '!!! unexpected error while discovering attributes: '
@@ -863,7 +833,7 @@ class Client:
async def subscribe( async def subscribe(
self, self,
characteristic: CharacteristicProxy, characteristic: CharacteristicProxy,
subscriber: Callable[[Any], Any] | None = None, subscriber: Optional[Callable[[Any], Any]] = None,
prefer_notify: bool = True, prefer_notify: bool = True,
) -> None: ) -> None:
# If we haven't already discovered the descriptors for this characteristic, # If we haven't already discovered the descriptors for this characteristic,
@@ -913,7 +883,7 @@ class Client:
async def unsubscribe( async def unsubscribe(
self, self,
characteristic: CharacteristicProxy, characteristic: CharacteristicProxy,
subscriber: Callable[[Any], Any] | None = None, subscriber: Optional[Callable[[Any], Any]] = None,
force: bool = False, force: bool = False,
) -> None: ) -> None:
''' '''
@@ -978,7 +948,7 @@ class Client:
await self.write_value(cccd, b'\x00\x00', with_response=True) await self.write_value(cccd, b'\x00\x00', with_response=True)
async def read_value( async def read_value(
self, attribute: int | AttributeProxy, no_long_read: bool = False self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
) -> bytes: ) -> bytes:
''' '''
See Vol 3, Part G - 4.8.1 Read Characteristic Value See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -989,41 +959,39 @@ class Client:
# Send a request to read # Send a request to read
attribute_handle = attribute if isinstance(attribute, int) else attribute.handle attribute_handle = attribute if isinstance(attribute, int) else attribute.handle
response = await self.send_request( response = await self.send_request(
att.ATT_Read_Request(attribute_handle=attribute_handle) ATT_Read_Request(attribute_handle=attribute_handle)
) )
if response is None: if response is None:
raise TimeoutError('read timeout') raise TimeoutError('read timeout')
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
raise att.ATT_Error(error_code=response.error_code, message=response) raise ATT_Error(error_code=response.error_code, message=response)
# If the value is the max size for the MTU, try to read more unless the caller # If the value is the max size for the MTU, try to read more unless the caller
# specifically asked not to do that # specifically asked not to do that
attribute_value = response.attribute_value attribute_value = response.attribute_value
if not no_long_read and len(attribute_value) == self.mtu - 1: if not no_long_read and len(attribute_value) == self.connection.att_mtu - 1:
logger.debug('using READ BLOB to get the rest of the value') logger.debug('using READ BLOB to get the rest of the value')
offset = len(attribute_value) offset = len(attribute_value)
while True: while True:
response = await self.send_request( response = await self.send_request(
att.ATT_Read_Blob_Request( ATT_Read_Blob_Request(
attribute_handle=attribute_handle, value_offset=offset attribute_handle=attribute_handle, value_offset=offset
) )
) )
if response is None: if response is None:
raise TimeoutError('read timeout') raise TimeoutError('read timeout')
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code in ( if response.error_code in (
att.ATT_ATTRIBUTE_NOT_LONG_ERROR, ATT_ATTRIBUTE_NOT_LONG_ERROR,
att.ATT_INVALID_OFFSET_ERROR, ATT_INVALID_OFFSET_ERROR,
): ):
break break
raise att.ATT_Error( raise ATT_Error(error_code=response.error_code, message=response)
error_code=response.error_code, message=response
)
part = response.part_attribute_value part = response.part_attribute_value
attribute_value += part attribute_value += part
if len(part) < self.mtu - 1: if len(part) < self.connection.att_mtu - 1:
break break
offset += len(part) offset += len(part)
@@ -1033,8 +1001,8 @@ class Client:
return attribute_value return attribute_value
async def read_characteristics_by_uuid( async def read_characteristics_by_uuid(
self, uuid: UUID, service: ServiceProxy | None self, uuid: UUID, service: Optional[ServiceProxy]
) -> list[bytes]: ) -> List[bytes]:
''' '''
See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID See Vol 3, Part G - 4.8.2 Read Using Characteristic UUID
''' '''
@@ -1049,7 +1017,7 @@ class Client:
characteristics_values = [] characteristics_values = []
while starting_handle <= ending_handle: while starting_handle <= ending_handle:
response = await self.send_request( response = await self.send_request(
att.ATT_Read_By_Type_Request( ATT_Read_By_Type_Request(
starting_handle=starting_handle, starting_handle=starting_handle,
ending_handle=ending_handle, ending_handle=ending_handle,
attribute_type=uuid, attribute_type=uuid,
@@ -1060,8 +1028,8 @@ class Client:
return [] return []
# Check if we reached the end of the iteration # Check if we reached the end of the iteration
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != att.ATT_ATTRIBUTE_NOT_FOUND_ERROR: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end # Unexpected end
logger.warning( logger.warning(
'!!! unexpected error while reading characteristics: ' '!!! unexpected error while reading characteristics: '
@@ -1091,7 +1059,7 @@ class Client:
async def write_value( async def write_value(
self, self,
attribute: int | AttributeProxy, attribute: Union[int, AttributeProxy],
value: bytes, value: bytes,
with_response: bool = False, with_response: bool = False,
) -> None: ) -> None:
@@ -1106,27 +1074,28 @@ class Client:
attribute_handle = attribute if isinstance(attribute, int) else attribute.handle attribute_handle = attribute if isinstance(attribute, int) else attribute.handle
if with_response: if with_response:
response = await self.send_request( response = await self.send_request(
att.ATT_Write_Request( ATT_Write_Request(
attribute_handle=attribute_handle, attribute_value=value attribute_handle=attribute_handle, attribute_value=value
) )
) )
if response.op_code == att.Opcode.ATT_ERROR_RESPONSE: if response.op_code == ATT_ERROR_RESPONSE:
raise att.ATT_Error(error_code=response.error_code, message=response) raise ATT_Error(error_code=response.error_code, message=response)
else: else:
await self.send_command( await self.send_command(
att.ATT_Write_Command( ATT_Write_Command(
attribute_handle=attribute_handle, attribute_value=value attribute_handle=attribute_handle, attribute_value=value
) )
) )
def on_disconnection(self, *args) -> None: def on_disconnection(self, _) -> None:
del args # unused.
if self.pending_response and not self.pending_response.done(): if self.pending_response and not self.pending_response.done():
self.pending_response.cancel() self.pending_response.cancel()
def on_gatt_pdu(self, att_pdu: att.ATT_PDU) -> None: def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None:
logger.debug(f'GATT Response to client: {self._bearer_id} {att_pdu}') logger.debug(
if att_pdu.op_code in att.ATT_RESPONSES: f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}'
)
if att_pdu.op_code in ATT_RESPONSES:
if self.pending_request is None: if self.pending_request is None:
# Not expected! # Not expected!
logger.warning('!!! unexpected response, there is no pending request') logger.warning('!!! unexpected response, there is no pending request')
@@ -1134,7 +1103,7 @@ class Client:
# The response should match the pending request unless it is # The response should match the pending request unless it is
# an error response # an error response
if att_pdu.op_code != att.Opcode.ATT_ERROR_RESPONSE: if att_pdu.op_code != ATT_ERROR_RESPONSE:
expected_response_name = self.pending_request.name.replace( expected_response_name = self.pending_request.name.replace(
'_REQUEST', '_RESPONSE' '_REQUEST', '_RESPONSE'
) )
@@ -1155,15 +1124,14 @@ class Client:
else: else:
logger.warning( logger.warning(
color( color(
'--- Ignoring GATT Response from ' f'{self._bearer_id}: ', '--- Ignoring GATT Response from '
f'[0x{self.connection.handle:04X}]: ',
'red', 'red',
) )
+ str(att_pdu) + str(att_pdu)
) )
def on_att_handle_value_notification( def on_att_handle_value_notification(self, notification):
self, notification: att.ATT_Handle_Value_Notification
):
# Call all subscribers # Call all subscribers
subscribers = self.notification_subscribers.get( subscribers = self.notification_subscribers.get(
notification.attribute_handle, set() notification.attribute_handle, set()
@@ -1178,9 +1146,7 @@ class Client:
else: else:
subscriber.emit(subscriber.EVENT_UPDATE, notification.attribute_value) subscriber.emit(subscriber.EVENT_UPDATE, notification.attribute_value)
def on_att_handle_value_indication( def on_att_handle_value_indication(self, indication):
self, indication: att.ATT_Handle_Value_Indication
):
# Call all subscribers # Call all subscribers
subscribers = self.indication_subscribers.get( subscribers = self.indication_subscribers.get(
indication.attribute_handle, set() indication.attribute_handle, set()
@@ -1196,7 +1162,7 @@ class Client:
subscriber.emit(subscriber.EVENT_UPDATE, indication.attribute_value) subscriber.emit(subscriber.EVENT_UPDATE, indication.attribute_value)
# Confirm that we received the indication # Confirm that we received the indication
self.send_confirmation(att.ATT_Handle_Value_Confirmation()) self.send_confirmation(ATT_Handle_Value_Confirmation())
def cache_value(self, attribute_handle: int, value: bytes) -> None: def cache_value(self, attribute_handle: int, value: bytes) -> None:
self.cached_values[attribute_handle] = ( self.cached_values[attribute_handle] = (
+253 -417
View File
File diff suppressed because it is too large Load Diff
+3394 -3828
View File
File diff suppressed because it is too large Load Diff
+35 -27
View File
@@ -17,36 +17,44 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import datetime
import logging
from collections.abc import Callable, MutableMapping from collections.abc import Callable, MutableMapping
from typing import Any, cast import datetime
from typing import cast, Any, Optional
import logging
from bumble import avc, avctp, avdtp, avrcp, crypto, rfcomm, sdp from bumble import avc
from bumble.att import ATT_CID, ATT_PDU from bumble import avctp
from bumble import avdtp
from bumble import avrcp
from bumble import crypto
from bumble import rfcomm
from bumble import sdp
from bumble.colors import color from bumble.colors import color
from bumble.att import ATT_CID, ATT_PDU
from bumble.smp import SMP_CID, SMP_Command
from bumble.core import name_or_number from bumble.core import name_or_number
from bumble.hci import (
HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT,
HCI_EVENT_PACKET,
Address,
HCI_AclDataPacket,
HCI_AclDataPacketAssembler,
HCI_Disconnection_Complete_Event,
HCI_Event,
HCI_Packet,
)
from bumble.l2cap import ( from bumble.l2cap import (
L2CAP_LE_SIGNALING_CID,
L2CAP_PDU, L2CAP_PDU,
L2CAP_CONNECTION_REQUEST,
L2CAP_CONNECTION_RESPONSE,
L2CAP_SIGNALING_CID, L2CAP_SIGNALING_CID,
CommandCode, L2CAP_LE_SIGNALING_CID,
L2CAP_Control_Frame,
L2CAP_Connection_Request, L2CAP_Connection_Request,
L2CAP_Connection_Response, L2CAP_Connection_Response,
L2CAP_Control_Frame,
) )
from bumble.smp import SMP_CID, SMP_Command from bumble.hci import (
Address,
HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT,
HCI_AclDataPacketAssembler,
HCI_Packet,
HCI_Event,
HCI_AclDataPacket,
HCI_Disconnection_Complete_Event,
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -70,7 +78,7 @@ AVCTP_PID_NAMES = {avrcp.AVRCP_PID: 'AVRCP'}
class PacketTracer: class PacketTracer:
class AclStream: class AclStream:
psms: MutableMapping[int, int] psms: MutableMapping[int, int]
peer: PacketTracer.AclStream | None peer: Optional[PacketTracer.AclStream]
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler] avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
avctp_assemblers: MutableMapping[int, avctp.MessageAssembler] avctp_assemblers: MutableMapping[int, avctp.MessageAssembler]
@@ -98,14 +106,14 @@ class PacketTracer:
self.analyzer.emit(control_frame) self.analyzer.emit(control_frame)
# Check if this signals a new channel # Check if this signals a new channel
if control_frame.code == CommandCode.L2CAP_CONNECTION_REQUEST: if control_frame.code == L2CAP_CONNECTION_REQUEST:
connection_request = cast(L2CAP_Connection_Request, control_frame) connection_request = cast(L2CAP_Connection_Request, control_frame)
self.psms[connection_request.source_cid] = connection_request.psm self.psms[connection_request.source_cid] = connection_request.psm
elif control_frame.code == CommandCode.L2CAP_CONNECTION_RESPONSE: elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
connection_response = cast(L2CAP_Connection_Response, control_frame) connection_response = cast(L2CAP_Connection_Response, control_frame)
if ( if (
connection_response.result connection_response.result
== L2CAP_Connection_Response.Result.CONNECTION_SUCCESSFUL == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
): ):
if self.peer and ( if self.peer and (
psm := self.peer.psms.get(connection_response.source_cid) psm := self.peer.psms.get(connection_response.source_cid)
@@ -201,7 +209,7 @@ class PacketTracer:
self.label = label self.label = label
self.emit_message = emit_message self.emit_message = emit_message
self.acl_streams = {} # ACL streams, by connection handle self.acl_streams = {} # ACL streams, by connection handle
self.packet_timestamp: datetime.datetime | None = None self.packet_timestamp: Optional[datetime.datetime] = None
def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream: def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream:
logger.info( logger.info(
@@ -230,7 +238,7 @@ class PacketTracer:
self.peer.end_acl_stream(connection_handle) self.peer.end_acl_stream(connection_handle)
def on_packet( def on_packet(
self, timestamp: datetime.datetime | None, packet: HCI_Packet self, timestamp: Optional[datetime.datetime], packet: HCI_Packet
) -> None: ) -> None:
self.packet_timestamp = timestamp self.packet_timestamp = timestamp
self.emit(packet) self.emit(packet)
@@ -262,7 +270,7 @@ class PacketTracer:
self, self,
packet: HCI_Packet, packet: HCI_Packet,
direction: int = 0, direction: int = 0,
timestamp: datetime.datetime | None = None, timestamp: Optional[datetime.datetime] = None,
) -> None: ) -> None:
if direction == 0: if direction == 0:
self.host_to_controller_analyzer.on_packet(timestamp, packet) self.host_to_controller_analyzer.on_packet(timestamp, packet)
+187 -151
View File
@@ -17,36 +17,50 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio
import collections import collections
import collections.abc import collections.abc
import logging
import asyncio
import dataclasses import dataclasses
import enum import enum
import logging
import re
import traceback import traceback
from collections.abc import Iterable import re
from typing import Any, ClassVar, Literal, overload from typing import (
Dict,
List,
Union,
Set,
Any,
Optional,
Type,
Tuple,
ClassVar,
Iterable,
TYPE_CHECKING,
)
from typing_extensions import Self from typing_extensions import Self
from bumble import at, device, rfcomm, sdp, utils from bumble import at
from bumble import device
from bumble import rfcomm
from bumble import sdp
from bumble import utils
from bumble.colors import color from bumble.colors import color
from bumble.core import ( from bumble.core import (
ProtocolError,
BT_GENERIC_AUDIO_SERVICE, BT_GENERIC_AUDIO_SERVICE,
BT_HANDSFREE_AUDIO_GATEWAY_SERVICE,
BT_HANDSFREE_SERVICE, BT_HANDSFREE_SERVICE,
BT_HANDSFREE_AUDIO_GATEWAY_SERVICE,
BT_L2CAP_PROTOCOL_ID, BT_L2CAP_PROTOCOL_ID,
BT_RFCOMM_PROTOCOL_ID, BT_RFCOMM_PROTOCOL_ID,
ProtocolError,
) )
from bumble.hci import ( from bumble.hci import (
CodecID,
CodingFormat,
HCI_Enhanced_Setup_Synchronous_Connection_Command, HCI_Enhanced_Setup_Synchronous_Connection_Command,
PcmDataFormat, CodingFormat,
CodecID,
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -69,8 +83,6 @@ class HfpProtocolError(ProtocolError):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HfpProtocol: class HfpProtocol:
MAX_BUFFER_SIZE: ClassVar[int] = 65536
dlc: rfcomm.DLC dlc: rfcomm.DLC
buffer: str buffer: str
lines: collections.deque lines: collections.deque
@@ -84,22 +96,13 @@ class HfpProtocol:
dlc.sink = self.feed dlc.sink = self.feed
def feed(self, data: bytes | str) -> None: def feed(self, data: Union[bytes, str]) -> None:
# Convert the data to a string if needed # Convert the data to a string if needed
if isinstance(data, bytes): if isinstance(data, bytes):
data = data.decode('utf-8', errors='replace') data = data.decode('utf-8')
logger.debug(f'<<< Data received: {data}') logger.debug(f'<<< Data received: {data}')
# Drop incoming data if it would overflow the buffer; keep existing
# partial packet state intact so a future clean packet can still parse.
if len(self.buffer) + len(data) > self.MAX_BUFFER_SIZE:
logger.warning(
'HFP buffer overflow (>%d bytes), dropping incoming data',
self.MAX_BUFFER_SIZE,
)
return
# Add to the buffer and look for lines # Add to the buffer and look for lines
self.buffer += data self.buffer += data
while (separator := self.buffer.find('\r')) >= 0: while (separator := self.buffer.find('\r')) >= 0:
@@ -178,7 +181,7 @@ class AgFeature(enum.IntFlag):
VOICE_RECOGNITION_TEXT = 0x2000 VOICE_RECOGNITION_TEXT = 0x2000
class AudioCodec(utils.OpenIntEnum): class AudioCodec(enum.IntEnum):
""" """
Audio Codec IDs (normative). Audio Codec IDs (normative).
@@ -190,7 +193,7 @@ class AudioCodec(utils.OpenIntEnum):
LC3_SWB = 0x03 # Support for LC3-SWB audio codec LC3_SWB = 0x03 # Support for LC3-SWB audio codec
class HfIndicator(utils.OpenIntEnum): class HfIndicator(enum.IntEnum):
""" """
HF Indicators (normative). HF Indicators (normative).
@@ -219,7 +222,7 @@ class CallHoldOperation(enum.Enum):
) )
class ResponseHoldStatus(utils.OpenIntEnum): class ResponseHoldStatus(enum.IntEnum):
""" """
Response Hold status (normative). Response Hold status (normative).
@@ -247,7 +250,7 @@ class AgIndicator(enum.Enum):
BATTERY_CHARGE = 'battchg' BATTERY_CHARGE = 'battchg'
class CallSetupAgIndicator(utils.OpenIntEnum): class CallSetupAgIndicator(enum.IntEnum):
""" """
Values for the Call Setup AG indicator (normative). Values for the Call Setup AG indicator (normative).
@@ -260,7 +263,7 @@ class CallSetupAgIndicator(utils.OpenIntEnum):
REMOTE_ALERTED = 3 # Remote party alerted in an outgoing call REMOTE_ALERTED = 3 # Remote party alerted in an outgoing call
class CallHeldAgIndicator(utils.OpenIntEnum): class CallHeldAgIndicator(enum.IntEnum):
""" """
Values for the Call Held AG indicator (normative). Values for the Call Held AG indicator (normative).
@@ -274,7 +277,7 @@ class CallHeldAgIndicator(utils.OpenIntEnum):
CALL_ON_HOLD_NO_ACTIVE_CALL = 2 # Call on hold, no active call CALL_ON_HOLD_NO_ACTIVE_CALL = 2 # Call on hold, no active call
class CallInfoDirection(utils.OpenIntEnum): class CallInfoDirection(enum.IntEnum):
""" """
Call Info direction (normative). Call Info direction (normative).
@@ -285,7 +288,7 @@ class CallInfoDirection(utils.OpenIntEnum):
MOBILE_TERMINATED_CALL = 1 MOBILE_TERMINATED_CALL = 1
class CallInfoStatus(utils.OpenIntEnum): class CallInfoStatus(enum.IntEnum):
""" """
Call Info status (normative). Call Info status (normative).
@@ -300,7 +303,7 @@ class CallInfoStatus(utils.OpenIntEnum):
WAITING = 5 WAITING = 5
class CallInfoMode(utils.OpenIntEnum): class CallInfoMode(enum.IntEnum):
""" """
Call Info mode (normative). Call Info mode (normative).
@@ -313,7 +316,7 @@ class CallInfoMode(utils.OpenIntEnum):
UNKNOWN = 9 UNKNOWN = 9
class CallInfoMultiParty(utils.OpenIntEnum): class CallInfoMultiParty(enum.IntEnum):
""" """
Call Info Multi-Party state (normative). Call Info Multi-Party state (normative).
@@ -337,8 +340,8 @@ class CallInfo:
status: CallInfoStatus status: CallInfoStatus
mode: CallInfoMode mode: CallInfoMode
multi_party: CallInfoMultiParty multi_party: CallInfoMultiParty
number: str | None = None number: Optional[str] = None
type: int | None = None type: Optional[int] = None
@dataclasses.dataclass @dataclasses.dataclass
@@ -366,13 +369,13 @@ class CallLineIdentification:
number: str number: str
type: int type: int
subaddr: str | None = None subaddr: Optional[str] = None
satype: int | None = None satype: Optional[int] = None
alpha: str | None = None alpha: Optional[str] = None
cli_validity: int | None = None cli_validity: Optional[int] = None
@classmethod @classmethod
def parse_from(cls, parameters: list[bytes]) -> Self: def parse_from(cls: Type[Self], parameters: List[bytes]) -> Self:
return cls( return cls(
number=parameters[0].decode(), number=parameters[0].decode(),
type=int(parameters[1]), type=int(parameters[1]),
@@ -400,7 +403,7 @@ class CallLineIdentification:
) )
class VoiceRecognitionState(utils.OpenIntEnum): class VoiceRecognitionState(enum.IntEnum):
""" """
vrec values provided in AT+BVRA command. vrec values provided in AT+BVRA command.
@@ -413,7 +416,7 @@ class VoiceRecognitionState(utils.OpenIntEnum):
ENHANCED_READY = 2 ENHANCED_READY = 2
class CmeError(utils.OpenIntEnum): class CmeError(enum.IntEnum):
""" """
CME ERROR codes (partial listed). CME ERROR codes (partial listed).
@@ -432,6 +435,61 @@ class CmeError(utils.OpenIntEnum):
# Hands-Free Control Interoperability Requirements # Hands-Free Control Interoperability Requirements
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Response codes.
RESPONSE_CODES = {
"+APLSIRI",
"+BAC",
"+BCC",
"+BCS",
"+BIA",
"+BIEV",
"+BIND",
"+BINP",
"+BLDN",
"+BRSF",
"+BTRH",
"+BVRA",
"+CCWA",
"+CHLD",
"+CHUP",
"+CIND",
"+CLCC",
"+CLIP",
"+CMEE",
"+CMER",
"+CNUM",
"+COPS",
"+IPHONEACCEV",
"+NREC",
"+VGM",
"+VGS",
"+VTS",
"+XAPL",
"A",
"D",
}
# Unsolicited responses and statuses.
UNSOLICITED_CODES = {
"+APLSIRI",
"+BCS",
"+BIND",
"+BSIR",
"+BTRH",
"+BVRA",
"+CCWA",
"+CIEV",
"+CLIP",
"+VGM",
"+VGS",
"BLACKLISTED",
"BUSY",
"DELAYED",
"NO ANSWER",
"NO CARRIER",
"RING",
}
# Status codes # Status codes
STATUS_CODES = { STATUS_CODES = {
"+CME ERROR", "+CME ERROR",
@@ -447,9 +505,9 @@ STATUS_CODES = {
@dataclasses.dataclass @dataclasses.dataclass
class HfConfiguration: class HfConfiguration:
supported_hf_features: collections.abc.Sequence[HfFeature] supported_hf_features: List[HfFeature]
supported_hf_indicators: collections.abc.Sequence[HfIndicator] supported_hf_indicators: List[HfIndicator]
supported_audio_codecs: collections.abc.Sequence[AudioCodec] supported_audio_codecs: List[AudioCodec]
@dataclasses.dataclass @dataclasses.dataclass
@@ -477,7 +535,7 @@ class AtResponse:
parameters: list parameters: list
@classmethod @classmethod
def parse_from(cls: type[Self], buffer: bytearray) -> Self: def parse_from(cls: Type[Self], buffer: bytearray) -> Self:
code_and_parameters = buffer.split(b':') code_and_parameters = buffer.split(b':')
parameters = ( parameters = (
code_and_parameters[1] if len(code_and_parameters) > 1 else bytearray() code_and_parameters[1] if len(code_and_parameters) > 1 else bytearray()
@@ -505,7 +563,7 @@ class AtCommand:
) )
@classmethod @classmethod
def parse_from(cls: type[Self], buffer: bytearray) -> Self: def parse_from(cls: Type[Self], buffer: bytearray) -> Self:
if not (match := cls._PARSE_PATTERN.fullmatch(buffer.decode())): if not (match := cls._PARSE_PATTERN.fullmatch(buffer.decode())):
if buffer.startswith(b'ATA'): if buffer.startswith(b'ATA'):
return cls(code='A', sub_code=AtCommand.SubCode.NONE, parameters=[]) return cls(code='A', sub_code=AtCommand.SubCode.NONE, parameters=[])
@@ -540,9 +598,9 @@ class AgIndicatorState:
""" """
indicator: AgIndicator indicator: AgIndicator
supported_values: set[int] supported_values: Set[int]
current_status: int current_status: int
index: int | None = None index: Optional[int] = None
enabled: bool = True enabled: bool = True
@property @property
@@ -555,17 +613,17 @@ class AgIndicatorState:
supported_values_text = ( supported_values_text = (
f'({",".join(str(v) for v in self.supported_values)})' f'({",".join(str(v) for v in self.supported_values)})'
) )
return f'("{self.indicator.value}",{supported_values_text})' return f'(\"{self.indicator.value}\",{supported_values_text})'
@classmethod @classmethod
def call(cls: type[Self]) -> Self: def call(cls: Type[Self]) -> Self:
"""Default call indicator state.""" """Default call indicator state."""
return cls( return cls(
indicator=AgIndicator.CALL, supported_values={0, 1}, current_status=0 indicator=AgIndicator.CALL, supported_values={0, 1}, current_status=0
) )
@classmethod @classmethod
def callsetup(cls: type[Self]) -> Self: def callsetup(cls: Type[Self]) -> Self:
"""Default callsetup indicator state.""" """Default callsetup indicator state."""
return cls( return cls(
indicator=AgIndicator.CALL_SETUP, indicator=AgIndicator.CALL_SETUP,
@@ -574,7 +632,7 @@ class AgIndicatorState:
) )
@classmethod @classmethod
def callheld(cls: type[Self]) -> Self: def callheld(cls: Type[Self]) -> Self:
"""Default call indicator state.""" """Default call indicator state."""
return cls( return cls(
indicator=AgIndicator.CALL_HELD, indicator=AgIndicator.CALL_HELD,
@@ -583,14 +641,14 @@ class AgIndicatorState:
) )
@classmethod @classmethod
def service(cls: type[Self]) -> Self: def service(cls: Type[Self]) -> Self:
"""Default service indicator state.""" """Default service indicator state."""
return cls( return cls(
indicator=AgIndicator.SERVICE, supported_values={0, 1}, current_status=0 indicator=AgIndicator.SERVICE, supported_values={0, 1}, current_status=0
) )
@classmethod @classmethod
def signal(cls: type[Self]) -> Self: def signal(cls: Type[Self]) -> Self:
"""Default signal indicator state.""" """Default signal indicator state."""
return cls( return cls(
indicator=AgIndicator.SIGNAL, indicator=AgIndicator.SIGNAL,
@@ -599,14 +657,14 @@ class AgIndicatorState:
) )
@classmethod @classmethod
def roam(cls: type[Self]) -> Self: def roam(cls: Type[Self]) -> Self:
"""Default roam indicator state.""" """Default roam indicator state."""
return cls( return cls(
indicator=AgIndicator.CALL, supported_values={0, 1}, current_status=0 indicator=AgIndicator.CALL, supported_values={0, 1}, current_status=0
) )
@classmethod @classmethod
def battchg(cls: type[Self]) -> Self: def battchg(cls: Type[Self]) -> Self:
"""Default battery charge indicator state.""" """Default battery charge indicator state."""
return cls( return cls(
indicator=AgIndicator.BATTERY_CHARGE, indicator=AgIndicator.BATTERY_CHARGE,
@@ -674,19 +732,22 @@ class HfProtocol(utils.EventEmitter):
"""Termination signal for run() loop.""" """Termination signal for run() loop."""
supported_hf_features: int supported_hf_features: int
supported_audio_codecs: list[AudioCodec] supported_audio_codecs: List[AudioCodec]
supported_ag_features: int supported_ag_features: int
supported_ag_call_hold_operations: list[CallHoldOperation] supported_ag_call_hold_operations: List[CallHoldOperation]
ag_indicators: list[AgIndicatorState] ag_indicators: List[AgIndicatorState]
hf_indicators: dict[HfIndicator, HfIndicatorState] hf_indicators: Dict[HfIndicator, HfIndicatorState]
dlc: rfcomm.DLC dlc: rfcomm.DLC
command_lock: asyncio.Lock command_lock: asyncio.Lock
pending_command: str | None = None if TYPE_CHECKING:
response_queue: asyncio.Queue[AtResponse] response_queue: asyncio.Queue[AtResponse]
unsolicited_queue: asyncio.Queue[AtResponse | None] unsolicited_queue: asyncio.Queue[Optional[AtResponse]]
else:
response_queue: asyncio.Queue
unsolicited_queue: asyncio.Queue
read_buffer: bytearray read_buffer: bytearray
active_codec: AudioCodec active_codec: AudioCodec
@@ -708,7 +769,7 @@ class HfProtocol(utils.EventEmitter):
# Build local features. # Build local features.
self.supported_hf_features = sum(configuration.supported_hf_features) self.supported_hf_features = sum(configuration.supported_hf_features)
self.supported_audio_codecs = list(configuration.supported_audio_codecs) self.supported_audio_codecs = configuration.supported_audio_codecs
self.hf_indicators = { self.hf_indicators = {
indicator: HfIndicatorState(indicator=indicator) indicator: HfIndicatorState(indicator=indicator)
@@ -759,46 +820,23 @@ class HfProtocol(utils.EventEmitter):
self.read_buffer = self.read_buffer[trailer + 2 :] self.read_buffer = self.read_buffer[trailer + 2 :]
# Forward the received code to the correct queue. # Forward the received code to the correct queue.
if self.pending_command and ( if self.command_lock.locked() and (
response.code in STATUS_CODES or response.code in self.pending_command response.code in STATUS_CODES or response.code in RESPONSE_CODES
): ):
self.response_queue.put_nowait(response) self.response_queue.put_nowait(response)
else: elif response.code in UNSOLICITED_CODES:
self.unsolicited_queue.put_nowait(response) self.unsolicited_queue.put_nowait(response)
else:
@overload logger.warning(
async def execute_command( f"dropping unexpected response with code '{response.code}'"
self, )
cmd: str,
timeout: float = 1.0,
*,
response_type: Literal[AtResponseType.NONE] = AtResponseType.NONE,
) -> None: ...
@overload
async def execute_command(
self,
cmd: str,
timeout: float = 1.0,
*,
response_type: Literal[AtResponseType.SINGLE],
) -> AtResponse: ...
@overload
async def execute_command(
self,
cmd: str,
timeout: float = 1.0,
*,
response_type: Literal[AtResponseType.MULTIPLE],
) -> list[AtResponse]: ...
async def execute_command( async def execute_command(
self, self,
cmd: str, cmd: str,
timeout: float = 1.0, timeout: float = 1.0,
response_type: AtResponseType = AtResponseType.NONE, response_type: AtResponseType = AtResponseType.NONE,
) -> None | AtResponse | list[AtResponse]: ) -> Union[None, AtResponse, List[AtResponse]]:
""" """
Sends an AT command and wait for the peer response. Sends an AT command and wait for the peer response.
Wait for the AT responses sent by the peer, to the status code. Wait for the AT responses sent by the peer, to the status code.
@@ -812,34 +850,27 @@ class HfProtocol(utils.EventEmitter):
asyncio.TimeoutError: the status is not received after a timeout (default 1 second). asyncio.TimeoutError: the status is not received after a timeout (default 1 second).
ProtocolError: the status is not OK. ProtocolError: the status is not OK.
""" """
try: async with self.command_lock:
async with self.command_lock: logger.debug(f">>> {cmd}")
self.pending_command = cmd self.dlc.write(cmd + '\r')
logger.debug(f">>> {cmd}") responses: List[AtResponse] = []
self.dlc.write(cmd + '\r')
responses: list[AtResponse] = []
while True: while True:
result = await asyncio.wait_for( result = await asyncio.wait_for(
self.response_queue.get(), timeout=timeout self.response_queue.get(), timeout=timeout
) )
if result.code == 'OK': if result.code == 'OK':
if ( if response_type == AtResponseType.SINGLE and len(responses) != 1:
response_type == AtResponseType.SINGLE raise HfpProtocolError("NO ANSWER")
and len(responses) != 1
):
raise HfpProtocolError("NO ANSWER")
if response_type == AtResponseType.MULTIPLE: if response_type == AtResponseType.MULTIPLE:
return responses return responses
if response_type == AtResponseType.SINGLE: if response_type == AtResponseType.SINGLE:
return responses[0] return responses[0]
return None return None
if result.code in STATUS_CODES: if result.code in STATUS_CODES:
raise HfpProtocolError(result.code) raise HfpProtocolError(result.code)
responses.append(result) responses.append(result)
finally:
self.pending_command = None
async def initiate_slc(self): async def initiate_slc(self):
"""4.2.1 Service Level Connection Initialization.""" """4.2.1 Service Level Connection Initialization."""
@@ -1042,7 +1073,7 @@ class HfProtocol(utils.EventEmitter):
# code, with the value indicating (call=0). # code, with the value indicating (call=0).
await self.execute_command("AT+CHUP") await self.execute_command("AT+CHUP")
async def query_current_calls(self) -> list[CallInfo]: async def query_current_calls(self) -> List[CallInfo]:
"""4.32.1 Query List of Current Calls in AG. """4.32.1 Query List of Current Calls in AG.
Return: Return:
@@ -1051,6 +1082,7 @@ class HfProtocol(utils.EventEmitter):
responses = await self.execute_command( responses = await self.execute_command(
"AT+CLCC", response_type=AtResponseType.MULTIPLE "AT+CLCC", response_type=AtResponseType.MULTIPLE
) )
assert isinstance(responses, list)
calls = [] calls = []
for response in responses: for response in responses:
@@ -1172,27 +1204,27 @@ class AgProtocol(utils.EventEmitter):
EVENT_MICROPHONE_VOLUME = "microphone_volume" EVENT_MICROPHONE_VOLUME = "microphone_volume"
supported_hf_features: int supported_hf_features: int
supported_hf_indicators: set[HfIndicator] supported_hf_indicators: Set[HfIndicator]
supported_audio_codecs: list[AudioCodec] supported_audio_codecs: List[AudioCodec]
supported_ag_features: int supported_ag_features: int
supported_ag_call_hold_operations: list[CallHoldOperation] supported_ag_call_hold_operations: List[CallHoldOperation]
ag_indicators: list[AgIndicatorState] ag_indicators: List[AgIndicatorState]
hf_indicators: collections.OrderedDict[HfIndicator, HfIndicatorState] hf_indicators: collections.OrderedDict[HfIndicator, HfIndicatorState]
dlc: rfcomm.DLC dlc: rfcomm.DLC
read_buffer: bytearray read_buffer: bytearray
active_codec: AudioCodec active_codec: AudioCodec
calls: list[CallInfo] calls: List[CallInfo]
indicator_report_enabled: bool indicator_report_enabled: bool
inband_ringtone_enabled: bool inband_ringtone_enabled: bool
cme_error_enabled: bool cme_error_enabled: bool
cli_notification_enabled: bool cli_notification_enabled: bool
call_waiting_enabled: bool call_waiting_enabled: bool
_remained_slc_setup_features: set[HfFeature] _remained_slc_setup_features: Set[HfFeature]
def __init__(self, dlc: rfcomm.DLC, configuration: AgConfiguration) -> None: def __init__(self, dlc: rfcomm.DLC, configuration: AgConfiguration) -> None:
super().__init__() super().__init__()
@@ -1335,7 +1367,7 @@ class AgProtocol(utils.EventEmitter):
logger.warning(f'AG indicator {indicator} is disabled') logger.warning(f'AG indicator {indicator} is disabled')
indicator_state.current_status = value indicator_state.current_status = value
self.send_response(f'+CIEV: {index + 1},{value}') self.send_response(f'+CIEV: {index+1},{value}')
async def negotiate_codec(self, codec: AudioCodec) -> None: async def negotiate_codec(self, codec: AudioCodec) -> None:
"""Starts codec negotiation.""" """Starts codec negotiation."""
@@ -1395,13 +1427,13 @@ class AgProtocol(utils.EventEmitter):
self.emit(self.EVENT_VOICE_RECOGNITION, VoiceRecognitionState(int(vrec))) self.emit(self.EVENT_VOICE_RECOGNITION, VoiceRecognitionState(int(vrec)))
def _on_chld(self, operation_code: bytes) -> None: def _on_chld(self, operation_code: bytes) -> None:
call_index: int | None = None call_index: Optional[int] = None
if len(operation_code) > 1: if len(operation_code) > 1:
call_index = int(operation_code[1:]) call_index = int(operation_code[1:])
operation_code = operation_code[:1] + b'x' operation_code = operation_code[:1] + b'x'
try: try:
operation = CallHoldOperation(operation_code.decode()) operation = CallHoldOperation(operation_code.decode())
except Exception: except:
logger.error(f'Invalid operation: {operation_code.decode()}') logger.error(f'Invalid operation: {operation_code.decode()}')
self.send_cme_error(CmeError.OPERATION_NOT_SUPPORTED) self.send_cme_error(CmeError.OPERATION_NOT_SUPPORTED)
return return
@@ -1465,8 +1497,8 @@ class AgProtocol(utils.EventEmitter):
def _on_cmer( def _on_cmer(
self, self,
mode: bytes, mode: bytes,
keypad: bytes | None = None, keypad: Optional[bytes] = None,
display: bytes | None = None, display: Optional[bytes] = None,
indicator: bytes = b'', indicator: bytes = b'',
) -> None: ) -> None:
if ( if (
@@ -1573,7 +1605,7 @@ class AgProtocol(utils.EventEmitter):
def _on_clcc(self) -> None: def _on_clcc(self) -> None:
for call in self.calls: for call in self.calls:
number_text = f',"{call.number}"' if call.number is not None else '' number_text = f',\"{call.number}\"' if call.number is not None else ''
type_text = f',{call.type}' if call.type is not None else '' type_text = f',{call.type}' if call.type is not None else ''
response = ( response = (
f'+CLCC: {call.index}' f'+CLCC: {call.index}'
@@ -1607,7 +1639,7 @@ class AgProtocol(utils.EventEmitter):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ProfileVersion(utils.OpenIntEnum): class ProfileVersion(enum.IntEnum):
""" """
Profile version (normative). Profile version (normative).
@@ -1662,7 +1694,7 @@ def make_hf_sdp_records(
rfcomm_channel: int, rfcomm_channel: int,
configuration: HfConfiguration, configuration: HfConfiguration,
version: ProfileVersion = ProfileVersion.V1_8, version: ProfileVersion = ProfileVersion.V1_8,
) -> list[sdp.ServiceAttribute]: ) -> List[sdp.ServiceAttribute]:
""" """
Generates the SDP record for HFP Hands-Free support. Generates the SDP record for HFP Hands-Free support.
@@ -1748,7 +1780,7 @@ def make_ag_sdp_records(
rfcomm_channel: int, rfcomm_channel: int,
configuration: AgConfiguration, configuration: AgConfiguration,
version: ProfileVersion = ProfileVersion.V1_8, version: ProfileVersion = ProfileVersion.V1_8,
) -> list[sdp.ServiceAttribute]: ) -> List[sdp.ServiceAttribute]:
""" """
Generates the SDP record for HFP Audio-Gateway support. Generates the SDP record for HFP Audio-Gateway support.
@@ -1828,7 +1860,7 @@ def make_ag_sdp_records(
async def find_hf_sdp_record( async def find_hf_sdp_record(
connection: device.Connection, connection: device.Connection,
) -> tuple[int, ProfileVersion, HfSdpFeature] | None: ) -> Optional[Tuple[int, ProfileVersion, HfSdpFeature]]:
"""Searches a Hands-Free SDP record from remote device. """Searches a Hands-Free SDP record from remote device.
Args: Args:
@@ -1848,9 +1880,9 @@ async def find_hf_sdp_record(
], ],
) )
for attribute_lists in search_result: for attribute_lists in search_result:
channel: int | None = None channel: Optional[int] = None
version: ProfileVersion | None = None version: Optional[ProfileVersion] = None
features: HfSdpFeature | None = None features: Optional[HfSdpFeature] = None
for attribute in attribute_lists: for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]]. # The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
@@ -1880,7 +1912,7 @@ async def find_hf_sdp_record(
async def find_ag_sdp_record( async def find_ag_sdp_record(
connection: device.Connection, connection: device.Connection,
) -> tuple[int, ProfileVersion, AgSdpFeature] | None: ) -> Optional[Tuple[int, ProfileVersion, AgSdpFeature]]:
"""Searches an Audio-Gateway SDP record from remote device. """Searches an Audio-Gateway SDP record from remote device.
Args: Args:
@@ -1899,9 +1931,9 @@ async def find_ag_sdp_record(
], ],
) )
for attribute_lists in search_result: for attribute_lists in search_result:
channel: int | None = None channel: Optional[int] = None
version: ProfileVersion | None = None version: Optional[ProfileVersion] = None
features: AgSdpFeature | None = None features: Optional[AgSdpFeature] = None
for attribute in attribute_lists: for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]]. # The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
@@ -1955,8 +1987,12 @@ class EscoParameters:
output_coding_format: CodingFormat = CodingFormat(CodecID.LINEAR_PCM) output_coding_format: CodingFormat = CodingFormat(CodecID.LINEAR_PCM)
input_coded_data_size: int = 16 input_coded_data_size: int = 16
output_coded_data_size: int = 16 output_coded_data_size: int = 16
input_pcm_data_format: PcmDataFormat = PcmDataFormat.TWOS_COMPLEMENT input_pcm_data_format: (
output_pcm_data_format: PcmDataFormat = PcmDataFormat.TWOS_COMPLEMENT HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat
) = HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat.TWOS_COMPLEMENT
output_pcm_data_format: (
HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat
) = HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat.TWOS_COMPLEMENT
input_pcm_sample_payload_msb_position: int = 0 input_pcm_sample_payload_msb_position: int = 0
output_pcm_sample_payload_msb_position: int = 0 output_pcm_sample_payload_msb_position: int = 0
input_data_path: HCI_Enhanced_Setup_Synchronous_Connection_Command.DataPath = ( input_data_path: HCI_Enhanced_Setup_Synchronous_Connection_Command.DataPath = (
@@ -1974,7 +2010,7 @@ class EscoParameters:
transmit_codec_frame_size: int = 60 transmit_codec_frame_size: int = 60
receive_codec_frame_size: int = 60 receive_codec_frame_size: int = 60
def asdict(self) -> dict[str, Any]: def asdict(self) -> Dict[str, Any]:
# dataclasses.asdict() will recursively deep-copy the entire object, # dataclasses.asdict() will recursively deep-copy the entire object,
# which is expensive and breaks CodingFormat object, so let it simply copy here. # which is expensive and breaks CodingFormat object, so let it simply copy here.
return self.__dict__ return self.__dict__
@@ -2055,7 +2091,6 @@ _ESCO_PARAMETERS_MSBC_T1 = EscoParameters(
max_latency=0x0008, max_latency=0x0008,
packet_type=( packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3 HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3 | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5 | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5 | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
@@ -2071,6 +2106,7 @@ _ESCO_PARAMETERS_MSBC_T2 = EscoParameters(
max_latency=0x000D, max_latency=0x000D,
packet_type=( packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3 HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3 | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5 | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5 | HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
+27 -33
View File
@@ -16,20 +16,22 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import enum
import logging
import struct
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
import logging
import enum
import struct
from abc import ABC, abstractmethod
from typing import Optional, Callable
from typing_extensions import override from typing_extensions import override
from bumble import device, l2cap, utils from bumble import l2cap
from bumble import device
from bumble import utils
from bumble.core import InvalidStateError, ProtocolError from bumble.core import InvalidStateError, ProtocolError
from bumble.hci import Address from bumble.hci import Address
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -195,9 +197,9 @@ class SendHandshakeMessage(Message):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HID(ABC, utils.EventEmitter): class HID(ABC, utils.EventEmitter):
l2cap_ctrl_channel: l2cap.ClassicChannel | None = None l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
l2cap_intr_channel: l2cap.ClassicChannel | None = None l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
connection: device.Connection | None = None connection: Optional[device.Connection] = None
EVENT_INTERRUPT_DATA = "interrupt_data" EVENT_INTERRUPT_DATA = "interrupt_data"
EVENT_CONTROL_DATA = "control_data" EVENT_CONTROL_DATA = "control_data"
@@ -212,46 +214,38 @@ class HID(ABC, utils.EventEmitter):
def __init__(self, device: device.Device, role: Role) -> None: def __init__(self, device: device.Device, role: Role) -> None:
super().__init__() super().__init__()
self.remote_device_bd_address: Address | None = None self.remote_device_bd_address: Optional[Address] = None
self.device = device self.device = device
self.role = role self.role = role
# Register ourselves with the L2CAP channel manager # Register ourselves with the L2CAP channel manager
device.create_l2cap_server( device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
l2cap.ClassicChannelSpec(HID_CONTROL_PSM), self.on_l2cap_connection device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
)
device.create_l2cap_server(
l2cap.ClassicChannelSpec(HID_INTERRUPT_PSM), self.on_l2cap_connection
)
device.on(device.EVENT_CONNECTION, self.on_device_connection) device.on(device.EVENT_CONNECTION, self.on_device_connection)
async def connect_control_channel(self) -> None: async def connect_control_channel(self) -> None:
if not self.connection:
raise InvalidStateError("Connection is not established!")
# Create a new L2CAP connection - control channel # Create a new L2CAP connection - control channel
try: try:
channel = await self.connection.create_l2cap_channel( channel = await self.device.l2cap_channel_manager.connect(
l2cap.ClassicChannelSpec(HID_CONTROL_PSM) self.connection, HID_CONTROL_PSM
) )
channel.sink = self.on_ctrl_pdu channel.sink = self.on_ctrl_pdu
self.l2cap_ctrl_channel = channel self.l2cap_ctrl_channel = channel
except ProtocolError: except ProtocolError:
logging.exception('L2CAP connection failed.') logging.exception(f'L2CAP connection failed.')
raise raise
async def connect_interrupt_channel(self) -> None: async def connect_interrupt_channel(self) -> None:
if not self.connection:
raise InvalidStateError("Connection is not established!")
# Create a new L2CAP connection - interrupt channel # Create a new L2CAP connection - interrupt channel
try: try:
channel = await self.connection.create_l2cap_channel( channel = await self.device.l2cap_channel_manager.connect(
l2cap.ClassicChannelSpec(HID_INTERRUPT_PSM) self.connection, HID_INTERRUPT_PSM
) )
channel.sink = self.on_intr_pdu channel.sink = self.on_intr_pdu
self.l2cap_intr_channel = channel self.l2cap_intr_channel = channel
except ProtocolError: except ProtocolError:
logging.exception('L2CAP connection failed.') logging.exception(f'L2CAP connection failed.')
raise raise
async def disconnect_interrupt_channel(self) -> None: async def disconnect_interrupt_channel(self) -> None:
@@ -312,11 +306,11 @@ class HID(ABC, utils.EventEmitter):
def send_pdu_on_ctrl(self, msg: bytes) -> None: def send_pdu_on_ctrl(self, msg: bytes) -> None:
assert self.l2cap_ctrl_channel assert self.l2cap_ctrl_channel
self.l2cap_ctrl_channel.write(msg) self.l2cap_ctrl_channel.send_pdu(msg)
def send_pdu_on_intr(self, msg: bytes) -> None: def send_pdu_on_intr(self, msg: bytes) -> None:
assert self.l2cap_intr_channel assert self.l2cap_intr_channel
self.l2cap_intr_channel.write(msg) self.l2cap_intr_channel.send_pdu(msg)
def send_data(self, data: bytes) -> None: def send_data(self, data: bytes) -> None:
if self.role == HID.Role.HOST: if self.role == HID.Role.HOST:
@@ -353,10 +347,10 @@ class Device(HID):
data: bytes = b'' data: bytes = b''
status: int = 0 status: int = 0
get_report_cb: Callable[[int, int, int], GetSetStatus] | None = None get_report_cb: Optional[Callable[[int, int, int], GetSetStatus]] = None
set_report_cb: Callable[[int, int, int, bytes], GetSetStatus] | None = None set_report_cb: Optional[Callable[[int, int, int, bytes], GetSetStatus]] = None
get_protocol_cb: Callable[[], GetSetStatus] | None = None get_protocol_cb: Optional[Callable[[], GetSetStatus]] = None
set_protocol_cb: Callable[[int], GetSetStatus] | None = None set_protocol_cb: Optional[Callable[[int], GetSetStatus]] = None
def __init__(self, device: device.Device) -> None: def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE) super().__init__(device, HID.Role.DEVICE)
+334 -797
View File
File diff suppressed because it is too large Load Diff
+51 -53
View File
@@ -21,19 +21,16 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import dataclasses import dataclasses
import json
import logging import logging
import os import os
import pathlib import json
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Any
from typing_extensions import Self from typing_extensions import Self
from bumble import hci
from bumble.colors import color from bumble.colors import color
from bumble import hci
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Device from bumble.device import Device
@@ -52,8 +49,8 @@ class PairingKeys:
class Key: class Key:
value: bytes value: bytes
authenticated: bool = False authenticated: bool = False
ediv: int | None = None ediv: Optional[int] = None
rand: bytes | None = None rand: Optional[bytes] = None
@classmethod @classmethod
def from_dict(cls, key_dict: dict[str, Any]) -> PairingKeys.Key: def from_dict(cls, key_dict: dict[str, Any]) -> PairingKeys.Key:
@@ -75,17 +72,17 @@ class PairingKeys:
return key_dict return key_dict
address_type: hci.AddressType | None = None address_type: Optional[hci.AddressType] = None
ltk: Key | None = None ltk: Optional[Key] = None
ltk_central: Key | None = None ltk_central: Optional[Key] = None
ltk_peripheral: Key | None = None ltk_peripheral: Optional[Key] = None
irk: Key | None = None irk: Optional[Key] = None
csrk: Key | None = None csrk: Optional[Key] = None
link_key: Key | None = None # Classic link_key: Optional[Key] = None # Classic
link_key_type: int | None = None # Classic link_key_type: Optional[int] = None # Classic
@classmethod @classmethod
def key_from_dict(cls, keys_dict: dict[str, Any], key_name: str) -> Key | None: def key_from_dict(cls, keys_dict: dict[str, Any], key_name: str) -> Optional[Key]:
key_dict = keys_dict.get(key_name) key_dict = keys_dict.get(key_name)
if key_dict is None: if key_dict is None:
return None return None
@@ -157,10 +154,10 @@ class KeyStore:
async def update(self, name: str, keys: PairingKeys) -> None: async def update(self, name: str, keys: PairingKeys) -> None:
pass pass
async def get(self, _name: str) -> PairingKeys | None: async def get(self, _name: str) -> Optional[PairingKeys]:
return None return None
async def get_all(self) -> list[tuple[str, PairingKeys]]: async def get_all(self) -> List[Tuple[str, PairingKeys]]:
return [] return []
async def delete_all(self) -> None: async def delete_all(self) -> None:
@@ -249,30 +246,33 @@ class JsonKeyStore(KeyStore):
DEFAULT_NAMESPACE = '__DEFAULT__' DEFAULT_NAMESPACE = '__DEFAULT__'
DEFAULT_BASE_NAME = "keys" DEFAULT_BASE_NAME = "keys"
def __init__( def __init__(self, namespace, filename=None):
self, namespace: str | None = None, filename: str | None = None self.namespace = namespace if namespace is not None else self.DEFAULT_NAMESPACE
) -> None:
self.namespace = namespace or self.DEFAULT_NAMESPACE
if filename: if filename is None:
self.filename = pathlib.Path(filename).resolve() # Use a default for the current user
self.directory_name = self.filename.parent
# Import here because this may not exist on all platforms
# pylint: disable=import-outside-toplevel
import appdirs
self.directory_name = os.path.join(
appdirs.user_data_dir(self.APP_NAME, self.APP_AUTHOR), self.KEYS_DIR
)
base_name = self.DEFAULT_BASE_NAME if namespace is None else self.namespace
json_filename = (
f'{base_name}.json'.lower().replace(':', '-').replace('/p', '-p')
)
self.filename = os.path.join(self.directory_name, json_filename)
else: else:
import platformdirs # Deferred import self.filename = filename
self.directory_name = os.path.dirname(os.path.abspath(self.filename))
base_dir = platformdirs.user_data_path(self.APP_NAME, self.APP_AUTHOR) logger.debug(f'JSON keystore: {self.filename}')
self.directory_name = base_dir / self.KEYS_DIR
base_name = self.namespace if namespace else self.DEFAULT_BASE_NAME
safe_name = base_name.lower().replace(':', '-').replace('/', '-')
self.filename = self.directory_name / f"{safe_name}.json"
logger.debug('JSON keystore: %s', self.filename)
@classmethod @classmethod
def from_device( def from_device(
cls: type[Self], device: Device, filename: str | None = None cls: Type[Self], device: Device, filename: Optional[str] = None
) -> Self: ) -> Self:
if not filename: if not filename:
# Extract the filename from the config if there is one # Extract the filename from the config if there is one
@@ -291,13 +291,11 @@ class JsonKeyStore(KeyStore):
return cls(namespace, filename) return cls(namespace, filename)
async def load( async def load(self):
self,
) -> tuple[dict[str, dict[str, dict[str, Any]]], dict[str, dict[str, Any]]]:
# Try to open the file, without failing. If the file does not exist, it # Try to open the file, without failing. If the file does not exist, it
# will be created upon saving. # will be created upon saving.
try: try:
with open(self.filename, encoding='utf-8') as json_file: with open(self.filename, 'r', encoding='utf-8') as json_file:
db = json.load(json_file) db = json.load(json_file)
except FileNotFoundError: except FileNotFoundError:
db = {} db = {}
@@ -312,17 +310,17 @@ class JsonKeyStore(KeyStore):
return next(iter(db.items())) return next(iter(db.items()))
# Finally, just create an empty key map for the namespace # Finally, just create an empty key map for the namespace
key_map: dict[str, dict[str, Any]] = {} key_map = {}
db[self.namespace] = key_map db[self.namespace] = key_map
return (db, key_map) return (db, key_map)
async def save(self, db: dict[str, dict[str, dict[str, Any]]]) -> None: async def save(self, db):
# Create the directory if it doesn't exist # Create the directory if it doesn't exist
if not self.directory_name.exists(): if not os.path.exists(self.directory_name):
self.directory_name.mkdir(parents=True, exist_ok=True) os.makedirs(self.directory_name, exist_ok=True)
# Save to a temporary file # Save to a temporary file
temp_filename = self.filename.with_name(self.filename.name + ".tmp") temp_filename = self.filename + '.tmp'
with open(temp_filename, 'w', encoding='utf-8') as output: with open(temp_filename, 'w', encoding='utf-8') as output:
json.dump(db, output, sort_keys=True, indent=4) json.dump(db, output, sort_keys=True, indent=4)
@@ -334,21 +332,21 @@ class JsonKeyStore(KeyStore):
del key_map[name] del key_map[name]
await self.save(db) await self.save(db)
async def update(self, name: str, keys: PairingKeys) -> None: async def update(self, name, keys):
db, key_map = await self.load() db, key_map = await self.load()
key_map.setdefault(name, {}).update(keys.to_dict()) key_map.setdefault(name, {}).update(keys.to_dict())
await self.save(db) await self.save(db)
async def get_all(self) -> list[tuple[str, PairingKeys]]: async def get_all(self):
_, key_map = await self.load() _, key_map = await self.load()
return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()] return [(name, PairingKeys.from_dict(keys)) for (name, keys) in key_map.items()]
async def delete_all(self) -> None: async def delete_all(self):
db, key_map = await self.load() db, key_map = await self.load()
key_map.clear() key_map.clear()
await self.save(db) await self.save(db)
async def get(self, name: str) -> PairingKeys | None: async def get(self, name: str) -> Optional[PairingKeys]:
_, key_map = await self.load() _, key_map = await self.load()
if name not in key_map: if name not in key_map:
return None return None
@@ -358,7 +356,7 @@ class JsonKeyStore(KeyStore):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class MemoryKeyStore(KeyStore): class MemoryKeyStore(KeyStore):
all_keys: dict[str, PairingKeys] all_keys: Dict[str, PairingKeys]
def __init__(self) -> None: def __init__(self) -> None:
self.all_keys = {} self.all_keys = {}
@@ -370,8 +368,8 @@ class MemoryKeyStore(KeyStore):
async def update(self, name: str, keys: PairingKeys) -> None: async def update(self, name: str, keys: PairingKeys) -> None:
self.all_keys[name] = keys self.all_keys[name] = keys
async def get(self, name: str) -> PairingKeys | None: async def get(self, name: str) -> Optional[PairingKeys]:
return self.all_keys.get(name) return self.all_keys.get(name)
async def get_all(self) -> list[tuple[str, PairingKeys]]: async def get_all(self) -> List[Tuple[str, PairingKeys]]:
return list(self.all_keys.items()) return list(self.all_keys.items())
+680 -1489
View File
File diff suppressed because it is too large Load Diff
+558 -63
View File
@@ -11,20 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations
import asyncio
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
from typing import TYPE_CHECKING import asyncio
from functools import partial
from bumble import core, hci, ll, lmp from bumble.core import (
PhysicalTransport,
InvalidStateError,
)
from bumble.colors import color
from bumble.hci import (
Address,
Role,
HCI_SUCCESS,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_CONNECTION_TIMEOUT_ERROR,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
HCI_PAGE_TIMEOUT_ERROR,
HCI_Connection_Complete_Event,
)
from bumble import controller
if TYPE_CHECKING: from typing import Optional, Set
from bumble import controller
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -32,6 +44,18 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def parse_parameters(params_str):
result = {}
for param_str in params_str.split(','):
if '=' in param_str:
key, value = param_str.split('=')
result[key] = value
return result
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# TODO: add more support for various LL exchanges # TODO: add more support for various LL exchanges
# (see Vol 6, Part B - 2.4 DATA CHANNEL PDU) # (see Vol 6, Part B - 2.4 DATA CHANNEL PDU)
@@ -41,38 +65,41 @@ class LocalLink:
Link bus for controllers to communicate with each other Link bus for controllers to communicate with each other
''' '''
controllers: set[controller.Controller] controllers: Set[controller.Controller]
def __init__(self): def __init__(self):
self.controllers = set() self.controllers = set()
self.pending_connection = None
self.pending_classic_connection = None self.pending_classic_connection = None
############################################################ ############################################################
# Common utils # Common utils
############################################################ ############################################################
def add_controller(self, controller: controller.Controller): def add_controller(self, controller):
logger.debug(f'new controller: {controller}') logger.debug(f'new controller: {controller}')
self.controllers.add(controller) self.controllers.add(controller)
def remove_controller(self, controller: controller.Controller): def remove_controller(self, controller):
self.controllers.remove(controller) self.controllers.remove(controller)
def find_le_controller(self, address: hci.Address) -> controller.Controller | None: def find_controller(self, address):
for controller in self.controllers: for controller in self.controllers:
for connection in controller.le_connections.values(): if controller.random_address == address:
if connection.self_address == address: return controller
return controller
return None return None
def find_classic_controller( def find_classic_controller(
self, address: hci.Address self, address: Address
) -> controller.Controller | None: ) -> Optional[controller.Controller]:
for controller in self.controllers: for controller in self.controllers:
if controller.public_address == address: if controller.public_address == address:
return controller return controller
return None return None
def get_pending_connection(self):
return self.pending_connection
############################################################ ############################################################
# LE handlers # LE handlers
############################################################ ############################################################
@@ -80,70 +107,538 @@ class LocalLink:
def on_address_changed(self, controller): def on_address_changed(self, controller):
pass pass
def send_acl_data( def send_advertising_data(self, sender_address, data):
self, # Send the advertising data to all controllers, except the sender
sender_controller: controller.Controller, for controller in self.controllers:
destination_address: hci.Address, if controller.random_address != sender_address:
transport: core.PhysicalTransport, controller.on_link_advertising_data(sender_address, data)
data: bytes,
): def send_acl_data(self, sender_controller, destination_address, transport, data):
# Send the data to the first controller with a matching address # Send the data to the first controller with a matching address
if transport == core.PhysicalTransport.LE: if transport == PhysicalTransport.LE:
destination_controller = self.find_le_controller(destination_address) destination_controller = self.find_controller(destination_address)
source_address = sender_controller.random_address source_address = sender_controller.random_address
elif transport == core.PhysicalTransport.BR_EDR: elif transport == PhysicalTransport.BR_EDR:
destination_controller = self.find_classic_controller(destination_address) destination_controller = self.find_classic_controller(destination_address)
source_address = sender_controller.public_address source_address = sender_controller.public_address
else: else:
raise ValueError("unsupported transport type") raise ValueError("unsupported transport type")
if destination_controller is not None: if destination_controller is not None:
asyncio.get_running_loop().call_soon( destination_controller.on_link_acl_data(source_address, transport, data)
lambda: destination_controller.on_link_acl_data(
source_address, transport, data
)
)
def send_advertising_pdu( def on_connection_complete(self):
self, # Check that we expect this call
sender_controller: controller.Controller, if not self.pending_connection:
packet: ll.AdvertisingPdu, logger.warning('on_connection_complete with no pending connection')
): return
loop = asyncio.get_running_loop()
for c in self.controllers:
if c != sender_controller:
loop.call_soon(c.on_ll_advertising_pdu, packet)
def send_ll_control_pdu( central_address, le_create_connection_command = self.pending_connection
self, self.pending_connection = None
sender_address: hci.Address,
receiver_address: hci.Address, # Find the controller that initiated the connection
packet: ll.ControlPdu, if not (central_controller := self.find_controller(central_address)):
): logger.warning('!!! Initiating controller not found')
if not (receiver_controller := self.find_le_controller(receiver_address)): return
raise core.InvalidArgumentError(
f"Unable to find controller for address {receiver_address}" # Connect to the first controller with a matching address
if peripheral_controller := self.find_controller(
le_create_connection_command.peer_address
):
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, HCI_SUCCESS
) )
asyncio.get_running_loop().call_soon( peripheral_controller.on_link_central_connected(central_address)
lambda: receiver_controller.on_ll_control_pdu(sender_address, packet) return
# No peripheral found
central_controller.on_link_peripheral_connection_complete(
le_create_connection_command, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR
) )
def connect(self, central_address, le_create_connection_command):
logger.debug(
f'$$$ CONNECTION {central_address} -> '
f'{le_create_connection_command.peer_address}'
)
self.pending_connection = (central_address, le_create_connection_command)
asyncio.get_running_loop().call_soon(self.on_connection_complete)
def on_disconnection_complete(
self, central_address, peripheral_address, disconnect_command
):
# Find the controller that initiated the disconnection
if not (central_controller := self.find_controller(central_address)):
logger.warning('!!! Initiating controller not found')
return
# Disconnect from the first controller with a matching address
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_central_disconnected(
central_address, disconnect_command.reason
)
central_controller.on_link_peripheral_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(
f'$$$ DISCONNECTION {central_address} -> '
f'{peripheral_address}: reason = {disconnect_command.reason}'
)
args = [central_address, peripheral_address, disconnect_command]
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)
# pylint: disable=too-many-arguments
def on_connection_encrypted(
self, central_address, peripheral_address, rand, ediv, ltk
):
logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}')
if central_controller := self.find_controller(central_address):
central_controller.on_link_encrypted(peripheral_address, rand, ediv, ltk)
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
def create_cis(
self,
central_controller: controller.Controller,
peripheral_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
)
if peripheral_controller := self.find_controller(peripheral_address):
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_request,
central_controller.random_address,
cig_id,
cis_id,
)
def accept_cis(
self,
peripheral_controller: controller.Controller,
central_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
)
if central_controller := self.find_controller(central_address):
asyncio.get_running_loop().call_soon(
central_controller.on_link_cis_established, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_established, cig_id, cis_id
)
def disconnect_cis(
self,
initiator_controller: controller.Controller,
peer_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
)
if peer_controller := self.find_controller(peer_address):
asyncio.get_running_loop().call_soon(
initiator_controller.on_link_cis_disconnected, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peer_controller.on_link_cis_disconnected, cig_id, cis_id
)
############################################################ ############################################################
# Classic handlers # Classic handlers
############################################################ ############################################################
def send_lmp_packet( def classic_connect(self, initiator_controller, responder_address):
self, logger.debug(
sender_controller: controller.Controller, f'[Classic] {initiator_controller.public_address} connects to {responder_address}'
receiver_address: hci.Address, )
packet: lmp.Packet, responder_controller = self.find_classic_controller(responder_address)
): if responder_controller is None:
if not (receiver_controller := self.find_classic_controller(receiver_address)): initiator_controller.on_classic_connection_complete(
raise core.InvalidArgumentError( responder_address, HCI_PAGE_TIMEOUT_ERROR
f"Unable to find controller for address {receiver_address}"
) )
asyncio.get_running_loop().call_soon( return
lambda: receiver_controller.on_lmp_packet( self.pending_classic_connection = (initiator_controller, responder_controller)
sender_controller.public_address, packet
responder_controller.on_classic_connection_request(
initiator_controller.public_address,
HCI_Connection_Complete_Event.ACL_LINK_TYPE,
)
def classic_accept_connection(
self, responder_controller, initiator_address, responder_role
):
logger.debug(
f'[Classic] {responder_controller.public_address} accepts to connect {initiator_address}'
)
initiator_controller = self.find_classic_controller(initiator_address)
if initiator_controller is None:
responder_controller.on_classic_connection_complete(
responder_controller.public_address, HCI_PAGE_TIMEOUT_ERROR
)
return
async def task():
if responder_role != Role.PERIPHERAL:
initiator_controller.on_classic_role_change(
responder_controller.public_address, int(not (responder_role))
)
initiator_controller.on_classic_connection_complete(
responder_controller.public_address, HCI_SUCCESS
)
asyncio.create_task(task())
responder_controller.on_classic_role_change(
initiator_controller.public_address, responder_role
)
responder_controller.on_classic_connection_complete(
initiator_controller.public_address, HCI_SUCCESS
)
self.pending_classic_connection = None
def classic_disconnect(self, initiator_controller, responder_address, reason):
logger.debug(
f'[Classic] {initiator_controller.public_address} disconnects {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
async def task():
initiator_controller.on_classic_disconnected(responder_address, reason)
asyncio.create_task(task())
responder_controller.on_classic_disconnected(
initiator_controller.public_address, reason
)
def classic_switch_role(
self, initiator_controller, responder_address, initiator_new_role
):
responder_controller = self.find_classic_controller(responder_address)
if responder_controller is None:
return
async def task():
initiator_controller.on_classic_role_change(
responder_address, initiator_new_role
)
asyncio.create_task(task())
responder_controller.on_classic_role_change(
initiator_controller.public_address, int(not (initiator_new_role))
)
def classic_sco_connect(
self,
initiator_controller: controller.Controller,
responder_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
# Initiator controller should handle it.
assert responder_controller
responder_controller.on_classic_connection_request(
initiator_controller.public_address,
link_type,
)
def classic_accept_sco_connection(
self,
responder_controller: controller.Controller,
initiator_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
)
initiator_controller = self.find_classic_controller(initiator_address)
if initiator_controller is None:
responder_controller.on_classic_sco_connection_complete(
responder_controller.public_address,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
link_type,
)
return
async def task():
initiator_controller.on_classic_sco_connection_complete(
responder_controller.public_address, HCI_SUCCESS, link_type
)
asyncio.create_task(task())
responder_controller.on_classic_sco_connection_complete(
initiator_controller.public_address, HCI_SUCCESS, link_type
)
# -----------------------------------------------------------------------------
class RemoteLink:
'''
A Link implementation that communicates with other virtual controllers via a
WebSocket relay
'''
def __init__(self, uri):
self.controller = None
self.uri = uri
self.execution_queue = asyncio.Queue()
self.websocket = asyncio.get_running_loop().create_future()
self.rpc_result = None
self.pending_connection = None
self.central_connections = set() # List of addresses that we have connected to
self.peripheral_connections = (
set()
) # List of addresses that have connected to us
# Connect and run asynchronously
asyncio.create_task(self.run_connection())
asyncio.create_task(self.run_executor_loop())
def add_controller(self, controller):
if self.controller:
raise InvalidStateError('controller already set')
self.controller = controller
def remove_controller(self, controller):
if self.controller != controller:
raise InvalidStateError('controller mismatch')
self.controller = None
def get_pending_connection(self):
return self.pending_connection
def get_pending_classic_connection(self):
return self.pending_classic_connection
async def wait_until_connected(self):
await self.websocket
def execute(self, async_function):
self.execution_queue.put_nowait(async_function())
async def run_executor_loop(self):
logger.debug('executor loop starting')
while True:
item = await self.execution_queue.get()
try:
await item
except Exception as error:
logger.warning(
f'{color("!!! Exception in async handler:", "red")} {error}'
)
async def run_connection(self):
import websockets # lazy import
# Connect to the relay
logger.debug(f'connecting to {self.uri}')
# pylint: disable-next=no-member
websocket = await websockets.connect(self.uri)
self.websocket.set_result(websocket)
logger.debug(f'connected to {self.uri}')
while True:
message = await websocket.recv()
logger.debug(f'received message: {message}')
keyword, *payload = message.split(':', 1)
handler_name = f'on_{keyword}_received'
handler = getattr(self, handler_name, None)
if handler:
await handler(payload[0] if payload else None)
def close(self):
if self.websocket.done():
logger.debug('closing websocket')
websocket = self.websocket.result()
asyncio.create_task(websocket.close())
async def on_result_received(self, result):
if self.rpc_result:
self.rpc_result.set_result(result)
async def on_left_received(self, address):
if address in self.central_connections:
self.controller.on_link_peripheral_disconnected(Address(address))
self.central_connections.remove(address)
if address in self.peripheral_connections:
self.controller.on_link_central_disconnected(
address, HCI_CONNECTION_TIMEOUT_ERROR
)
self.peripheral_connections.remove(address)
async def on_unreachable_received(self, target):
await self.on_left_received(target)
async def on_message_received(self, message):
sender, *payload = message.split('/', 1)
if payload:
keyword, *payload = payload[0].split(':', 1)
handler_name = f'on_{keyword}_message_received'
handler = getattr(self, handler_name, None)
if handler:
await handler(sender, payload[0] if payload else None)
async def on_advertisement_message_received(self, sender, advertisement):
try:
self.controller.on_link_advertising_data(
Address(sender), bytes.fromhex(advertisement)
)
except Exception:
logger.exception('exception')
async def on_acl_message_received(self, sender, acl_data):
try:
self.controller.on_link_acl_data(Address(sender), bytes.fromhex(acl_data))
except Exception:
logger.exception('exception')
async def on_connect_message_received(self, sender, _):
# Remember the connection
self.peripheral_connections.add(sender)
# Notify the controller
logger.debug(f'connection from central {sender}')
self.controller.on_link_central_connected(Address(sender))
# Accept the connection by responding to it
await self.send_targeted_message(sender, 'connected')
async def on_connected_message_received(self, sender, _):
if not self.pending_connection:
logger.warning('received a connection ack, but no connection is pending')
return
# Remember the connection
self.central_connections.add(sender)
# Notify the controller
logger.debug(f'connected to peripheral {self.pending_connection.peer_address}')
self.controller.on_link_peripheral_connection_complete(
self.pending_connection, HCI_SUCCESS
)
async def on_disconnect_message_received(self, sender, message):
# Notify the controller
params = parse_parameters(message)
reason = int(params.get('reason', str(HCI_CONNECTION_TIMEOUT_ERROR)))
self.controller.on_link_central_disconnected(Address(sender), reason)
# Forget the connection
if sender in self.peripheral_connections:
self.peripheral_connections.remove(sender)
async def on_encrypted_message_received(self, sender, _):
# TODO parse params to get real args
self.controller.on_link_encrypted(Address(sender), bytes(8), 0, bytes(16))
async def send_rpc_command(self, command):
# Ensure we have a connection
websocket = await self.websocket
# Create a future value to hold the eventual result
assert self.rpc_result is None
self.rpc_result = asyncio.get_running_loop().create_future()
# Send the command
await websocket.send(command)
# Wait for the result
rpc_result = await self.rpc_result
self.rpc_result = None
logger.debug(f'rpc_result: {rpc_result}')
# TODO: parse the result
async def send_targeted_message(self, target, message):
# Ensure we have a connection
websocket = await self.websocket
# Send the message
await websocket.send(f'@{target} {message}')
async def notify_address_changed(self):
await self.send_rpc_command(f'/set-address {self.controller.random_address}')
def on_address_changed(self, controller):
logger.info(f'address changed for {controller}: {controller.random_address}')
# Notify the relay of the change
self.execute(self.notify_address_changed)
async def send_advertising_data_to_relay(self, data):
await self.send_targeted_message('*', f'advertisement:{data.hex()}')
def send_advertising_data(self, _, data):
self.execute(partial(self.send_advertising_data_to_relay, data))
async def send_acl_data_to_relay(self, peer_address, data):
await self.send_targeted_message(peer_address, f'acl:{data.hex()}')
def send_acl_data(self, _, peer_address, _transport, data):
# TODO: handle different transport
self.execute(partial(self.send_acl_data_to_relay, peer_address, data))
async def send_connection_request_to_relay(self, peer_address):
await self.send_targeted_message(peer_address, 'connect')
def connect(self, _, le_create_connection_command):
if self.pending_connection:
logger.warning('connection already pending')
return
self.pending_connection = le_create_connection_command
self.execute(
partial(
self.send_connection_request_to_relay,
str(le_create_connection_command.peer_address),
)
)
def on_disconnection_complete(self, disconnect_command):
self.controller.on_link_peripheral_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(
f'disconnect {central_address} -> '
f'{peripheral_address}: reason = {disconnect_command.reason}'
)
self.execute(
partial(
self.send_targeted_message,
peripheral_address,
f'disconnect:reason={disconnect_command.reason}',
)
)
asyncio.get_running_loop().call_soon(
self.on_disconnection_complete, disconnect_command
)
def on_connection_encrypted(self, _, peripheral_address, rand, ediv, ltk):
asyncio.get_running_loop().call_soon(
self.controller.on_link_encrypted, peripheral_address, rand, ediv, ltk
)
self.execute(
partial(
self.send_targeted_message,
peripheral_address,
f'encrypted:ltk={ltk.hex()}',
) )
) )
-221
View File
@@ -1,221 +0,0 @@
# Copyright 2021-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
from typing import ClassVar
from bumble import hci
# -----------------------------------------------------------------------------
# Advertising PDU
# -----------------------------------------------------------------------------
class AdvertisingPdu:
"""Base Advertising Physical Channel PDU class.
See Core Spec 6.0, Volume 6, Part B, 2.3. Advertising physical channel PDU.
Currently these messages don't really follow the LL spec, because LL protocol is
context-aware and we don't have real physical transport.
"""
@dataclasses.dataclass
class ConnectInd(AdvertisingPdu):
initiator_address: hci.Address
advertiser_address: hci.Address
interval: int
latency: int
timeout: int
@dataclasses.dataclass
class AdvInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
@dataclasses.dataclass
class AdvDirectInd(AdvertisingPdu):
advertiser_address: hci.Address
target_address: hci.Address
@dataclasses.dataclass
class AdvNonConnInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
@dataclasses.dataclass
class AdvExtInd(AdvertisingPdu):
advertiser_address: hci.Address
data: bytes
target_address: hci.Address | None = None
adi: int | None = None
tx_power: int | None = None
# -----------------------------------------------------------------------------
# LL Control PDU
# -----------------------------------------------------------------------------
class ControlPdu:
"""Base LL Control PDU Class.
See Core Spec 6.0, Volume 6, Part B, 2.4.2. LL Control PDU.
Currently these messages don't really follow the LL spec, because LL protocol is
context-aware and we don't have real physical transport.
"""
class Opcode(hci.SpecableEnum):
LL_CONNECTION_UPDATE_IND = 0x00
LL_CHANNEL_MAP_IND = 0x01
LL_TERMINATE_IND = 0x02
LL_ENC_REQ = 0x03
LL_ENC_RSP = 0x04
LL_START_ENC_REQ = 0x05
LL_START_ENC_RSP = 0x06
LL_UNKNOWN_RSP = 0x07
LL_FEATURE_REQ = 0x08
LL_FEATURE_RSP = 0x09
LL_PAUSE_ENC_REQ = 0x0A
LL_PAUSE_ENC_RSP = 0x0B
LL_VERSION_IND = 0x0C
LL_REJECT_IND = 0x0D
LL_PERIPHERAL_FEATURE_REQ = 0x0E
LL_CONNECTION_PARAM_REQ = 0x0F
LL_CONNECTION_PARAM_RSP = 0x10
LL_REJECT_EXT_IND = 0x11
LL_PING_REQ = 0x12
LL_PING_RSP = 0x13
LL_LENGTH_REQ = 0x14
LL_LENGTH_RSP = 0x15
LL_PHY_REQ = 0x16
LL_PHY_RSP = 0x17
LL_PHY_UPDATE_IND = 0x18
LL_MIN_USED_CHANNELS_IND = 0x19
LL_CTE_REQ = 0x1A
LL_CTE_RSP = 0x1B
LL_PERIODIC_SYNC_IND = 0x1C
LL_CLOCK_ACCURACY_REQ = 0x1D
LL_CLOCK_ACCURACY_RSP = 0x1E
LL_CIS_REQ = 0x1F
LL_CIS_RSP = 0x20
LL_CIS_IND = 0x21
LL_CIS_TERMINATE_IND = 0x22
LL_POWER_CONTROL_REQ = 0x23
LL_POWER_CONTROL_RSP = 0x24
LL_POWER_CHANGE_IND = 0x25
LL_SUBRATE_REQ = 0x26
LL_SUBRATE_IND = 0x27
LL_CHANNEL_REPORTING_IND = 0x28
LL_CHANNEL_STATUS_IND = 0x29
LL_PERIODIC_SYNC_WR_IND = 0x2A
LL_FEATURE_EXT_REQ = 0x2B
LL_FEATURE_EXT_RSP = 0x2C
LL_CS_SEC_RSP = 0x2D
LL_CS_CAPABILITIES_REQ = 0x2E
LL_CS_CAPABILITIES_RSP = 0x2F
LL_CS_CONFIG_REQ = 0x30
LL_CS_CONFIG_RSP = 0x31
LL_CS_REQ = 0x32
LL_CS_RSP = 0x33
LL_CS_IND = 0x34
LL_CS_TERMINATE_REQ = 0x35
LL_CS_FAE_REQ = 0x36
LL_CS_FAE_RSP = 0x37
LL_CS_CHANNEL_MAP_IND = 0x38
LL_CS_SEC_REQ = 0x39
LL_CS_TERMINATE_RSP = 0x3A
LL_FRAME_SPACE_REQ = 0x3B
LL_FRAME_SPACE_RSP = 0x3C
opcode: ClassVar[Opcode]
@dataclasses.dataclass
class TerminateInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_TERMINATE_IND
error_code: int
@dataclasses.dataclass
class EncReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_ENC_REQ
rand: bytes
ediv: int
ltk: bytes
@dataclasses.dataclass
class CisReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisRsp(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_REQ
cig_id: int
cis_id: int
@dataclasses.dataclass
class CisTerminateInd(ControlPdu):
opcode = ControlPdu.Opcode.LL_CIS_TERMINATE_IND
cig_id: int
cis_id: int
error_code: int
@dataclasses.dataclass
class FeatureReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_FEATURE_REQ
feature_set: bytes
@dataclasses.dataclass
class FeatureRsp(ControlPdu):
opcode = ControlPdu.Opcode.LL_FEATURE_RSP
feature_set: bytes
@dataclasses.dataclass
class PeripheralFeatureReq(ControlPdu):
opcode = ControlPdu.Opcode.LL_PERIPHERAL_FEATURE_REQ
feature_set: bytes
-359
View File
@@ -1,359 +0,0 @@
# Copyright 2021-2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import struct
from dataclasses import dataclass, field
from typing import TypeVar
from bumble import hci, utils
class Opcode(utils.OpenIntEnum):
'''
See Bluetooth spec @ Vol 2, Part C - 5.1 PDU summary.
Follow the alphabetical order defined there.
'''
# fmt: off
LMP_ACCEPTED = 3
LMP_ACCEPTED_EXT = 127 << 8 + 1
LMP_AU_RAND = 11
LMP_AUTO_RATE = 35
LMP_CHANNEL_CLASSIFICATION = 127 << 8 + 17
LMP_CHANNEL_CLASSIFICATION_REQ = 127 << 8 + 16
LMP_CLK_ADJ = 127 << 8 + 5
LMP_CLK_ADJ_ACK = 127 << 8 + 6
LMP_CLK_ADJ_REQ = 127 << 8 + 7
LMP_CLKOFFSET_REQ = 5
LMP_CLKOFFSET_RES = 6
LMP_COMB_KEY = 9
LMP_DECR_POWER_REQ = 32
LMP_DETACH = 7
LMP_DHKEY_CHECK = 65
LMP_ENCAPSULATED_HEADER = 61
LMP_ENCAPSULATED_PAYLOAD = 62
LMP_ENCRYPTION_KEY_SIZE_MASK_REQ= 58
LMP_ENCRYPTION_KEY_SIZE_MASK_RES= 59
LMP_ENCRYPTION_KEY_SIZE_REQ = 16
LMP_ENCRYPTION_MODE_REQ = 15
LMP_ESCO_LINK_REQ = 127 << 8 + 12
LMP_FEATURES_REQ = 39
LMP_FEATURES_REQ_EXT = 127 << 8 + 3
LMP_FEATURES_RES = 40
LMP_FEATURES_RES_EXT = 127 << 8 + 4
LMP_HOLD = 20
LMP_HOLD_REQ = 21
LMP_HOST_CONNECTION_REQ = 51
LMP_IN_RAND = 8
LMP_INCR_POWER_REQ = 31
LMP_IO_CAPABILITY_REQ = 127 << 8 + 25
LMP_IO_CAPABILITY_RES = 127 << 8 + 26
LMP_KEYPRESS_NOTIFICATION = 127 << 8 + 30
LMP_MAX_POWER = 33
LMP_MAX_SLOT = 45
LMP_MAX_SLOT_REQ = 46
LMP_MIN_POWER = 34
LMP_NAME_REQ = 1
LMP_NAME_RES = 2
LMP_NOT_ACCEPTED = 4
LMP_NOT_ACCEPTED_EXT = 127 << 8 + 2
LMP_NUMERIC_COMPARISON_FAILED = 127 << 8 + 27
LMP_OOB_FAILED = 127 << 8 + 29
LMP_PACKET_TYPE_TABLE_REQ = 127 << 8 + 11
LMP_PAGE_MODE_REQ = 53
LMP_PAGE_SCAN_MODE_REQ = 54
LMP_PASSKEY_FAILED = 127 << 8 + 28
LMP_PAUSE_ENCRYPTION_AES_REQ = 66
LMP_PAUSE_ENCRYPTION_REQ = 127 << 8 + 23
LMP_PING_REQ = 127 << 8 + 33
LMP_PING_RES = 127 << 8 + 34
LMP_POWER_CONTROL_REQ = 127 << 8 + 31
LMP_POWER_CONTROL_RES = 127 << 8 + 32
LMP_PREFERRED_RATE = 36
LMP_QUALITY_OF_SERVICE = 41
LMP_QUALITY_OF_SERVICE_REQ = 42
LMP_REMOVE_ESCO_LINK_REQ = 127 << 8 + 13
LMP_REMOVE_SCO_LINK_REQ = 44
LMP_RESUME_ENCRYPTION_REQ = 127 << 8 + 24
LMP_SAM_DEFINE_MAP = 127 << 8 + 36
LMP_SAM_SET_TYPE0 = 127 << 8 + 35
LMP_SAM_SWITCH = 127 << 8 + 37
LMP_SCO_LINK_REQ = 43
LMP_SET_AFH = 60
LMP_SETUP_COMPLETE = 49
LMP_SIMPLE_PAIRING_CONFIRM = 63
LMP_SIMPLE_PAIRING_NUMBER = 64
LMP_SLOT_OFFSET = 52
LMP_SNIFF_REQ = 23
LMP_SNIFF_SUBRATING_REQ = 127 << 8 + 21
LMP_SNIFF_SUBRATING_RES = 127 << 8 + 22
LMP_SRES = 12
LMP_START_ENCRYPTION_REQ = 17
LMP_STOP_ENCRYPTION_REQ = 18
LMP_SUPERVISION_TIMEOUT = 55
LMP_SWITCH_REQ = 19
LMP_TEMP_KEY = 14
LMP_TEMP_RAND = 13
LMP_TEST_ACTIVATE = 56
LMP_TEST_CONTROL = 57
LMP_TIMING_ACCURACY_REQ = 47
LMP_TIMING_ACCURACY_RES = 48
LMP_UNIT_KEY = 10
LMP_UNSNIFF_REQ = 24
LMP_USE_SEMI_PERMANENT_KEY = 50
LMP_VERSION_REQ = 37
LMP_VERSION_RES = 38
# fmt: on
@classmethod
def parse_from(cls, data: bytes, offset: int = 0) -> tuple[int, Opcode]:
opcode = data[offset]
if opcode in (124, 127):
opcode = struct.unpack('>H', data)[0]
return offset + 2, Opcode(opcode)
return offset + 1, Opcode(opcode)
def __bytes__(self) -> bytes:
if self.value >> 8:
return struct.pack('>H', self.value)
return bytes([self.value])
@classmethod
def type_metadata(cls):
return hci.metadata(
{
'serializer': bytes,
'parser': lambda data, offset: (Opcode.parse_from(data, offset)),
}
)
class Packet:
'''
See Bluetooth spec @ Vol 2, Part C - 5.1 PDU summary
'''
subclasses: dict[int, type[Packet]] = {}
opcode: Opcode
fields: hci.Fields = ()
_payload: bytes = b''
_Packet = TypeVar("_Packet", bound="Packet")
@classmethod
def subclass(cls, subclass: type[_Packet]) -> type[_Packet]:
# Register a factory for this class
cls.subclasses[subclass.opcode] = subclass
subclass.fields = hci.HCI_Object.fields_from_dataclass(subclass)
return subclass
@classmethod
def from_bytes(cls, data: bytes) -> Packet:
offset, opcode = Opcode.parse_from(data)
if not (subclass := cls.subclasses.get(opcode)):
instance = Packet()
instance.opcode = opcode
else:
instance = subclass(
**hci.HCI_Object.dict_from_bytes(data, offset, subclass.fields)
)
instance.payload = data[offset:]
return instance
@property
def payload(self) -> bytes:
if self._payload is None:
self._payload = hci.HCI_Object.dict_to_bytes(self.__dict__, self.fields)
return self._payload
@payload.setter
def payload(self, value: bytes) -> None:
self._payload = value
def __bytes__(self) -> bytes:
return bytes(self.opcode) + self.payload
@Packet.subclass
@dataclass
class LmpAccepted(Packet):
opcode = Opcode.LMP_ACCEPTED
response_opcode: Opcode = field(metadata=Opcode.type_metadata())
@Packet.subclass
@dataclass
class LmpNotAccepted(Packet):
opcode = Opcode.LMP_NOT_ACCEPTED
response_opcode: Opcode = field(metadata=Opcode.type_metadata())
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpAcceptedExt(Packet):
opcode = Opcode.LMP_ACCEPTED_EXT
response_opcode: Opcode = field(metadata=Opcode.type_metadata())
@Packet.subclass
@dataclass
class LmpNotAcceptedExt(Packet):
opcode = Opcode.LMP_NOT_ACCEPTED_EXT
response_opcode: Opcode = field(metadata=Opcode.type_metadata())
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpAuRand(Packet):
opcode = Opcode.LMP_AU_RAND
random_number: bytes = field(metadata=hci.metadata(16))
@Packet.subclass
@dataclass
class LmpDetach(Packet):
opcode = Opcode.LMP_DETACH
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpEscoLinkReq(Packet):
opcode = Opcode.LMP_ESCO_LINK_REQ
esco_handle: int = field(metadata=hci.metadata(1))
esco_lt_addr: int = field(metadata=hci.metadata(1))
timing_control_flags: int = field(metadata=hci.metadata(1))
d_esco: int = field(metadata=hci.metadata(1))
t_esco: int = field(metadata=hci.metadata(1))
w_esco: int = field(metadata=hci.metadata(1))
esco_packet_type_c_to_p: int = field(metadata=hci.metadata(1))
esco_packet_type_p_to_c: int = field(metadata=hci.metadata(1))
packet_length_c_to_p: int = field(metadata=hci.metadata(2))
packet_length_p_to_c: int = field(metadata=hci.metadata(2))
air_mode: int = field(metadata=hci.metadata(1))
negotiation_state: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpHostConnectionReq(Packet):
opcode = Opcode.LMP_HOST_CONNECTION_REQ
@Packet.subclass
@dataclass
class LmpRemoveEscoLinkReq(Packet):
opcode = Opcode.LMP_REMOVE_ESCO_LINK_REQ
esco_handle: int = field(metadata=hci.metadata(1))
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpRemoveScoLinkReq(Packet):
opcode = Opcode.LMP_REMOVE_SCO_LINK_REQ
sco_handle: int = field(metadata=hci.metadata(1))
error_code: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpScoLinkReq(Packet):
opcode = Opcode.LMP_SCO_LINK_REQ
sco_handle: int = field(metadata=hci.metadata(1))
timing_control_flags: int = field(metadata=hci.metadata(1))
d_sco: int = field(metadata=hci.metadata(1))
t_sco: int = field(metadata=hci.metadata(1))
sco_packet: int = field(metadata=hci.metadata(1))
air_mode: int = field(metadata=hci.metadata(1))
@Packet.subclass
@dataclass
class LmpSwitchReq(Packet):
opcode = Opcode.LMP_SWITCH_REQ
switch_instant: int = field(metadata=hci.metadata(4), default=0)
@Packet.subclass
@dataclass
class LmpNameReq(Packet):
opcode = Opcode.LMP_NAME_REQ
name_offset: int = field(metadata=hci.metadata(2))
@Packet.subclass
@dataclass
class LmpNameRes(Packet):
opcode = Opcode.LMP_NAME_RES
name_offset: int = field(metadata=hci.metadata(2))
name_length: int = field(metadata=hci.metadata(3))
name_fregment: bytes = field(metadata=hci.metadata('*'))
@Packet.subclass
@dataclass
class LmpFeaturesReq(Packet):
opcode = Opcode.LMP_FEATURES_REQ
features: bytes = field(metadata=hci.metadata(8))
@Packet.subclass
@dataclass
class LmpFeaturesRes(Packet):
opcode = Opcode.LMP_FEATURES_RES
features: bytes = field(metadata=hci.metadata(8))
@Packet.subclass
@dataclass
class LmpFeaturesReqExt(Packet):
opcode = Opcode.LMP_FEATURES_REQ_EXT
features_page: int = field(metadata=hci.metadata(1))
features: bytes = field(metadata=hci.metadata(8))
@Packet.subclass
@dataclass
class LmpFeaturesResExt(Packet):
opcode = Opcode.LMP_FEATURES_RES_EXT
features_page: int = field(metadata=hci.metadata(1))
max_features_page: int = field(metadata=hci.metadata(1))
features: bytes = field(metadata=hci.metadata(8))
-65
View File
@@ -1,65 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import functools
import logging
import os
from bumble import colors
# -----------------------------------------------------------------------------
class ColorFormatter(logging.Formatter):
_colorizers = {
logging.DEBUG: functools.partial(colors.color, fg="white"),
logging.INFO: functools.partial(colors.color, fg="green"),
logging.WARNING: functools.partial(colors.color, fg="yellow"),
logging.ERROR: functools.partial(colors.color, fg="red"),
logging.CRITICAL: functools.partial(colors.color, fg="black", bg="red"),
}
_formatters = {
level: logging.Formatter(
fmt=colorizer("{asctime}.{msecs:03.0f} {levelname:.1} {name}: ")
+ "{message}",
datefmt="%H:%M:%S",
style="{",
)
for level, colorizer in _colorizers.items()
}
def format(self, record: logging.LogRecord) -> str:
return self._formatters[record.levelno].format(record)
def setup_basic_logging(default_level: str = "INFO") -> None:
"""
Set up basic logging with logging.basicConfig, configured with a simple formatter
that prints out the date and log level in color.
If the BUMBLE_LOGLEVEL environment variable is set to the name of a log level, it
is used. Otherwise the default_level argument is used.
Args:
default_level: default logging level
"""
handler = logging.StreamHandler()
handler.setFormatter(ColorFormatter())
logging.basicConfig(
level=os.environ.get("BUMBLE_LOGLEVEL", default_level).upper(),
handlers=[handler],
)
+51 -45
View File
@@ -16,18 +16,32 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import enum import enum
import secrets
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple
from bumble import hci, smp from bumble.hci import (
from bumble.core import AdvertisingData, LeRole Address,
HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
HCI_DISPLAY_ONLY_IO_CAPABILITY,
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
HCI_KEYBOARD_ONLY_IO_CAPABILITY,
)
from bumble.smp import ( from bumble.smp import (
SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
SMP_KEYBOARD_ONLY_IO_CAPABILITY,
SMP_DISPLAY_ONLY_IO_CAPABILITY,
SMP_DISPLAY_YES_NO_IO_CAPABILITY,
SMP_KEYBOARD_DISPLAY_IO_CAPABILITY,
SMP_ENC_KEY_DISTRIBUTION_FLAG,
SMP_ID_KEY_DISTRIBUTION_FLAG,
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
SMP_LINK_KEY_DISTRIBUTION_FLAG,
OobContext, OobContext,
OobLegacyContext, OobLegacyContext,
OobSharedData, OobSharedData,
) )
from bumble.core import AdvertisingData, LeRole
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -35,19 +49,19 @@ from bumble.smp import (
class OobData: class OobData:
"""OOB data that can be sent from one device to another.""" """OOB data that can be sent from one device to another."""
address: hci.Address | None = None address: Optional[Address] = None
role: LeRole | None = None role: Optional[LeRole] = None
shared_data: OobSharedData | None = None shared_data: Optional[OobSharedData] = None
legacy_context: OobLegacyContext | None = None legacy_context: Optional[OobLegacyContext] = None
@classmethod @classmethod
def from_ad(cls, ad: AdvertisingData) -> OobData: def from_ad(cls, ad: AdvertisingData) -> OobData:
instance = cls() instance = cls()
shared_data_c: bytes | None = None shared_data_c: Optional[bytes] = None
shared_data_r: bytes | None = None shared_data_r: Optional[bytes] = None
for ad_type, ad_data in ad.ad_structures: for ad_type, ad_data in ad.ad_structures:
if ad_type == AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS: if ad_type == AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS:
instance.address = hci.Address(ad_data) instance.address = Address(ad_data)
elif ad_type == AdvertisingData.LE_ROLE: elif ad_type == AdvertisingData.LE_ROLE:
instance.role = LeRole(ad_data[0]) instance.role = LeRole(ad_data[0])
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE: elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE:
@@ -87,11 +101,11 @@ class PairingDelegate:
# These are defined abstractly, and can be mapped to specific Classic pairing # These are defined abstractly, and can be mapped to specific Classic pairing
# and/or SMP constants. # and/or SMP constants.
class IoCapability(enum.IntEnum): class IoCapability(enum.IntEnum):
NO_OUTPUT_NO_INPUT = smp.IoCapability.NO_INPUT_NO_OUTPUT NO_OUTPUT_NO_INPUT = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
KEYBOARD_INPUT_ONLY = smp.IoCapability.KEYBOARD_ONLY KEYBOARD_INPUT_ONLY = SMP_KEYBOARD_ONLY_IO_CAPABILITY
DISPLAY_OUTPUT_ONLY = smp.IoCapability.DISPLAY_ONLY DISPLAY_OUTPUT_ONLY = SMP_DISPLAY_ONLY_IO_CAPABILITY
DISPLAY_OUTPUT_AND_YES_NO_INPUT = smp.IoCapability.DISPLAY_YES_NO DISPLAY_OUTPUT_AND_YES_NO_INPUT = SMP_DISPLAY_YES_NO_IO_CAPABILITY
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = smp.IoCapability.KEYBOARD_DISPLAY DISPLAY_OUTPUT_AND_KEYBOARD_INPUT = SMP_KEYBOARD_DISPLAY_IO_CAPABILITY
# Direct names for backward compatibility. # Direct names for backward compatibility.
NO_OUTPUT_NO_INPUT = IoCapability.NO_OUTPUT_NO_INPUT NO_OUTPUT_NO_INPUT = IoCapability.NO_OUTPUT_NO_INPUT
@@ -102,10 +116,10 @@ class PairingDelegate:
# Key Distribution [LE only] # Key Distribution [LE only]
class KeyDistribution(enum.IntFlag): class KeyDistribution(enum.IntFlag):
DISTRIBUTE_ENCRYPTION_KEY = smp.KeyDistribution.ENC_KEY DISTRIBUTE_ENCRYPTION_KEY = SMP_ENC_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_IDENTITY_KEY = smp.KeyDistribution.ID_KEY DISTRIBUTE_IDENTITY_KEY = SMP_ID_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_SIGNING_KEY = smp.KeyDistribution.SIGN_KEY DISTRIBUTE_SIGNING_KEY = SMP_SIGN_KEY_DISTRIBUTION_FLAG
DISTRIBUTE_LINK_KEY = smp.KeyDistribution.LINK_KEY DISTRIBUTE_LINK_KEY = SMP_LINK_KEY_DISTRIBUTION_FLAG
DEFAULT_KEY_DISTRIBUTION: KeyDistribution = ( DEFAULT_KEY_DISTRIBUTION: KeyDistribution = (
KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY KeyDistribution.DISTRIBUTE_ENCRYPTION_KEY
@@ -115,11 +129,11 @@ class PairingDelegate:
# Default mapping from abstract to Classic I/O capabilities. # Default mapping from abstract to Classic I/O capabilities.
# Subclasses may override this if they prefer a different mapping. # Subclasses may override this if they prefer a different mapping.
CLASSIC_IO_CAPABILITIES_MAP = { CLASSIC_IO_CAPABILITIES_MAP = {
NO_OUTPUT_NO_INPUT: hci.IoCapability.NO_INPUT_NO_OUTPUT, NO_OUTPUT_NO_INPUT: HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY,
KEYBOARD_INPUT_ONLY: hci.IoCapability.KEYBOARD_ONLY, KEYBOARD_INPUT_ONLY: HCI_KEYBOARD_ONLY_IO_CAPABILITY,
DISPLAY_OUTPUT_ONLY: hci.IoCapability.DISPLAY_ONLY, DISPLAY_OUTPUT_ONLY: HCI_DISPLAY_ONLY_IO_CAPABILITY,
DISPLAY_OUTPUT_AND_YES_NO_INPUT: hci.IoCapability.DISPLAY_YES_NO, DISPLAY_OUTPUT_AND_YES_NO_INPUT: HCI_DISPLAY_YES_NO_IO_CAPABILITY,
DISPLAY_OUTPUT_AND_KEYBOARD_INPUT: hci.IoCapability.DISPLAY_YES_NO, DISPLAY_OUTPUT_AND_KEYBOARD_INPUT: HCI_DISPLAY_YES_NO_IO_CAPABILITY,
} }
io_capability: IoCapability io_capability: IoCapability
@@ -145,7 +159,7 @@ class PairingDelegate:
# pylint: disable=line-too-long # pylint: disable=line-too-long
return self.CLASSIC_IO_CAPABILITIES_MAP.get( return self.CLASSIC_IO_CAPABILITIES_MAP.get(
self.io_capability, hci.IoCapability.NO_INPUT_NO_OUTPUT self.io_capability, HCI_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
) )
@property @property
@@ -171,14 +185,14 @@ class PairingDelegate:
"""Compare two numbers.""" """Compare two numbers."""
return True return True
async def get_number(self) -> int | None: async def get_number(self) -> Optional[int]:
""" """
Return an optional number as an answer to a passkey request. Return an optional number as an answer to a passkey request.
Returning `None` will result in a negative reply. Returning `None` will result in a negative reply.
""" """
return 0 return 0
async def get_string(self, max_length: int) -> str | None: async def get_string(self, max_length: int) -> Optional[str]:
""" """
Return a string whose utf-8 encoding is up to max_length bytes. Return a string whose utf-8 encoding is up to max_length bytes.
""" """
@@ -191,7 +205,7 @@ class PairingDelegate:
# [LE only] # [LE only]
async def key_distribution_response( async def key_distribution_response(
self, peer_initiator_key_distribution: int, peer_responder_key_distribution: int self, peer_initiator_key_distribution: int, peer_responder_key_distribution: int
) -> tuple[int, int]: ) -> Tuple[int, int]:
""" """
Return the key distribution response in an SMP protocol context. Return the key distribution response in an SMP protocol context.
@@ -208,39 +222,31 @@ class PairingDelegate:
), ),
) )
async def generate_passkey(self) -> int:
"""
Return a passkey value between 0 and 999999 (inclusive).
"""
# By default, generate a random passkey.
return secrets.randbelow(1000000)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PairingConfig: class PairingConfig:
"""Configuration for the Pairing protocol.""" """Configuration for the Pairing protocol."""
class AddressType(enum.IntEnum): class AddressType(enum.IntEnum):
PUBLIC = hci.Address.PUBLIC_DEVICE_ADDRESS PUBLIC = Address.PUBLIC_DEVICE_ADDRESS
RANDOM = hci.Address.RANDOM_DEVICE_ADDRESS RANDOM = Address.RANDOM_DEVICE_ADDRESS
@dataclass @dataclass
class OobConfig: class OobConfig:
"""Config for OOB pairing.""" """Config for OOB pairing."""
our_context: OobContext | None our_context: Optional[OobContext]
peer_data: OobSharedData | None peer_data: Optional[OobSharedData]
legacy_context: OobLegacyContext | None legacy_context: Optional[OobLegacyContext]
def __init__( def __init__(
self, self,
sc: bool = True, sc: bool = True,
mitm: bool = True, mitm: bool = True,
bonding: bool = True, bonding: bool = True,
delegate: PairingDelegate | None = None, delegate: Optional[PairingDelegate] = None,
identity_address_type: AddressType | None = None, identity_address_type: Optional[AddressType] = None,
oob: OobConfig | None = None, oob: Optional[OobConfig] = None,
) -> None: ) -> None:
self.sc = sc self.sc = sc
self.mitm = mitm self.mitm = mitm
+10 -11
View File
@@ -19,22 +19,21 @@ This module implement the Pandora Bluetooth test APIs for the Bumble stack.
__version__ = "0.0.1" __version__ = "0.0.1"
from collections.abc import Callable
import grpc import grpc
import grpc.aio import grpc.aio
from pandora.host_grpc_aio import add_HostServicer_to_server
from pandora.l2cap_grpc_aio import add_L2CAPServicer_to_server
from pandora.security_grpc_aio import (
add_SecurityServicer_to_server,
add_SecurityStorageServicer_to_server,
)
from bumble.pandora.config import Config from bumble.pandora.config import Config
from bumble.pandora.device import PandoraDevice from bumble.pandora.device import PandoraDevice
from bumble.pandora.host import HostService from bumble.pandora.host import HostService
from bumble.pandora.l2cap import L2CAPService from bumble.pandora.l2cap import L2CAPService
from bumble.pandora.security import SecurityService, SecurityStorageService from bumble.pandora.security import SecurityService, SecurityStorageService
from pandora.host_grpc_aio import add_HostServicer_to_server
from pandora.l2cap_grpc_aio import add_L2CAPServicer_to_server
from pandora.security_grpc_aio import (
add_SecurityServicer_to_server,
add_SecurityStorageServicer_to_server,
)
from typing import Callable, List, Optional
# public symbols # public symbols
__all__ = [ __all__ = [
@@ -46,11 +45,11 @@ __all__ = [
# Add servicers hooks. # Add servicers hooks.
_SERVICERS_HOOKS: list[Callable[[PandoraDevice, Config, grpc.aio.Server], None]] = [] _SERVICERS_HOOKS: List[Callable[[PandoraDevice, Config, grpc.aio.Server], None]] = []
def register_servicer_hook( def register_servicer_hook(
hook: Callable[[PandoraDevice, Config, grpc.aio.Server], None], hook: Callable[[PandoraDevice, Config, grpc.aio.Server], None]
) -> None: ) -> None:
_SERVICERS_HOOKS.append(hook) _SERVICERS_HOOKS.append(hook)
@@ -58,7 +57,7 @@ def register_servicer_hook(
async def serve( async def serve(
bumble: PandoraDevice, bumble: PandoraDevice,
config: Config = Config(), config: Config = Config(),
grpc_server: grpc.aio.Server | None = None, grpc_server: Optional[grpc.aio.Server] = None,
port: int = 0, port: int = 0,
) -> None: ) -> None:
# initialize a gRPC server if not provided. # initialize a gRPC server if not provided.
+3 -5
View File
@@ -13,11 +13,9 @@
# limitations under the License. # limitations under the License.
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from bumble.pairing import PairingConfig, PairingDelegate from bumble.pairing import PairingConfig, PairingDelegate
from dataclasses import dataclass
from typing import Any, Dict
@dataclass @dataclass
@@ -34,7 +32,7 @@ class Config:
PairingDelegate.DEFAULT_KEY_DISTRIBUTION PairingDelegate.DEFAULT_KEY_DISTRIBUTION
) )
def load_from_dict(self, config: dict[str, Any]) -> None: def load_from_dict(self, config: Dict[str, Any]) -> None:
io_capability_name: str = config.get( io_capability_name: str = config.get(
'io_capability', 'no_output_no_input' 'io_capability', 'no_output_no_input'
).upper() ).upper()
+9 -12
View File
@@ -15,9 +15,6 @@
"""Generic & dependency free Bumble (reference) device.""" """Generic & dependency free Bumble (reference) device."""
from __future__ import annotations from __future__ import annotations
from typing import Any
from bumble import transport from bumble import transport
from bumble.core import ( from bumble.core import (
BT_GENERIC_AUDIO_SERVICE, BT_GENERIC_AUDIO_SERVICE,
@@ -35,6 +32,8 @@ from bumble.sdp import (
DataElement, DataElement,
ServiceAttribute, ServiceAttribute,
) )
from typing import Any, Dict, List, Optional
# Default rootcanal HCI TCP address # Default rootcanal HCI TCP address
ROOTCANAL_HCI_ADDRESS = "localhost:6402" ROOTCANAL_HCI_ADDRESS = "localhost:6402"
@@ -50,13 +49,13 @@ class PandoraDevice:
# Bumble device instance & configuration. # Bumble device instance & configuration.
device: Device device: Device
config: dict[str, Any] config: Dict[str, Any]
# HCI transport name & instance. # HCI transport name & instance.
_hci_name: str _hci_name: str
_hci: transport.Transport | None # type: ignore[name-defined] _hci: Optional[transport.Transport] # type: ignore[name-defined]
def __init__(self, config: dict[str, Any]) -> None: def __init__(self, config: Dict[str, Any]) -> None:
self.config = config self.config = config
self.device = _make_device(config) self.device = _make_device(config)
self._hci_name = config.get( self._hci_name = config.get(
@@ -74,9 +73,7 @@ class PandoraDevice:
# open HCI transport & set device host. # open HCI transport & set device host.
self._hci = await transport.open_transport(self._hci_name) self._hci = await transport.open_transport(self._hci_name)
self.device.host = Host( self.device.host = Host(controller_source=self._hci.source, controller_sink=self._hci.sink) # type: ignore[no-untyped-call]
controller_source=self._hci.source, controller_sink=self._hci.sink
) # type: ignore[no-untyped-call]
# power-on. # power-on.
await self.device.power_on() await self.device.power_on()
@@ -98,14 +95,14 @@ class PandoraDevice:
await self.close() await self.close()
await self.open() await self.open()
def info(self) -> dict[str, str] | None: def info(self) -> Optional[Dict[str, str]]:
return { return {
'public_bd_address': str(self.device.public_address), 'public_bd_address': str(self.device.public_address),
'random_address': str(self.device.random_address), 'random_address': str(self.device.random_address),
} }
def _make_device(config: dict[str, Any]) -> Device: def _make_device(config: Dict[str, Any]) -> Device:
"""Initialize an idle Bumble device instance.""" """Initialize an idle Bumble device instance."""
# initialize bumble device. # initialize bumble device.
@@ -120,7 +117,7 @@ def _make_device(config: dict[str, Any]) -> Device:
# TODO(b/267540823): remove when Pandora A2dp is supported # TODO(b/267540823): remove when Pandora A2dp is supported
def _make_sdp_records(rfcomm_channel: int) -> dict[int, list[ServiceAttribute]]: def _make_sdp_records(rfcomm_channel: int) -> Dict[int, List[ServiceAttribute]]:
return { return {
0x00010001: [ 0x00010001: [
ServiceAttribute( ServiceAttribute(
+71 -81
View File
@@ -13,26 +13,51 @@
# limitations under the License. # limitations under the License.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import logging import bumble.device
import struct
from collections.abc import AsyncGenerator
from typing import cast
import grpc import grpc
import grpc.aio import grpc.aio
from google.protobuf import ( import logging
any_pb2, # pytype: disable=pyi-error import struct
empty_pb2, # pytype: disable=pyi-error
import bumble.utils
from bumble.pandora import utils
from bumble.pandora.config import Config
from bumble.core import (
PhysicalTransport,
UUID,
AdvertisingData,
Appearance,
ConnectionError,
) )
from pandora import host_pb2 from bumble.device import (
DEVICE_DEFAULT_SCAN_INTERVAL,
DEVICE_DEFAULT_SCAN_WINDOW,
Advertisement,
AdvertisingParameters,
AdvertisingEventProperties,
AdvertisingType,
Device,
)
from bumble.gatt import Service
from bumble.hci import (
HCI_CONNECTION_ALREADY_EXISTS_ERROR,
HCI_PAGE_TIMEOUT_ERROR,
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
Address,
Phy,
Role,
OwnAddressType,
)
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from pandora.host_grpc_aio import HostServicer from pandora.host_grpc_aio import HostServicer
from pandora import host_pb2
from pandora.host_pb2 import ( from pandora.host_pb2 import (
DISCOVERABLE_GENERAL,
DISCOVERABLE_LIMITED,
NOT_CONNECTABLE, NOT_CONNECTABLE,
NOT_DISCOVERABLE, NOT_DISCOVERABLE,
DISCOVERABLE_LIMITED,
DISCOVERABLE_GENERAL,
PRIMARY_1M, PRIMARY_1M,
PRIMARY_CODED, PRIMARY_CODED,
SECONDARY_1M, SECONDARY_1M,
@@ -48,6 +73,7 @@ from pandora.host_pb2 import (
ConnectResponse, ConnectResponse,
DataTypes, DataTypes,
DisconnectRequest, DisconnectRequest,
DiscoverabilityMode,
InquiryResponse, InquiryResponse,
PrimaryPhy, PrimaryPhy,
ReadLocalAddressResponse, ReadLocalAddressResponse,
@@ -60,39 +86,9 @@ from pandora.host_pb2 import (
WaitConnectionResponse, WaitConnectionResponse,
WaitDisconnectionRequest, WaitDisconnectionRequest,
) )
from typing import AsyncGenerator, Dict, List, Optional, Set, Tuple, cast
import bumble.device PRIMARY_PHY_MAP: Dict[int, PrimaryPhy] = {
import bumble.utils
from bumble.core import (
UUID,
AdvertisingData,
Appearance,
ConnectionError,
PhysicalTransport,
)
from bumble.device import (
DEVICE_DEFAULT_SCAN_INTERVAL,
DEVICE_DEFAULT_SCAN_WINDOW,
Advertisement,
AdvertisingEventProperties,
AdvertisingParameters,
AdvertisingType,
Device,
)
from bumble.gatt import Service
from bumble.hci import (
HCI_CONNECTION_ALREADY_EXISTS_ERROR,
HCI_PAGE_TIMEOUT_ERROR,
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
Address,
OwnAddressType,
Phy,
Role,
)
from bumble.pandora import utils
from bumble.pandora.config import Config
PRIMARY_PHY_MAP: dict[int, PrimaryPhy] = {
# Default value reported by Bumble for legacy Advertising reports. # Default value reported by Bumble for legacy Advertising reports.
# FIXME(uael): `None` might be a better value, but Bumble need to change accordingly. # FIXME(uael): `None` might be a better value, but Bumble need to change accordingly.
0: PRIMARY_1M, 0: PRIMARY_1M,
@@ -100,26 +96,26 @@ PRIMARY_PHY_MAP: dict[int, PrimaryPhy] = {
3: PRIMARY_CODED, 3: PRIMARY_CODED,
} }
SECONDARY_PHY_MAP: dict[int, SecondaryPhy] = { SECONDARY_PHY_MAP: Dict[int, SecondaryPhy] = {
0: SECONDARY_NONE, 0: SECONDARY_NONE,
1: SECONDARY_1M, 1: SECONDARY_1M,
2: SECONDARY_2M, 2: SECONDARY_2M,
3: SECONDARY_CODED, 3: SECONDARY_CODED,
} }
PRIMARY_PHY_TO_BUMBLE_PHY_MAP: dict[PrimaryPhy, Phy] = { PRIMARY_PHY_TO_BUMBLE_PHY_MAP: Dict[PrimaryPhy, Phy] = {
PRIMARY_1M: Phy.LE_1M, PRIMARY_1M: Phy.LE_1M,
PRIMARY_CODED: Phy.LE_CODED, PRIMARY_CODED: Phy.LE_CODED,
} }
SECONDARY_PHY_TO_BUMBLE_PHY_MAP: dict[SecondaryPhy, Phy] = { SECONDARY_PHY_TO_BUMBLE_PHY_MAP: Dict[SecondaryPhy, Phy] = {
SECONDARY_NONE: Phy.LE_1M, SECONDARY_NONE: Phy.LE_1M,
SECONDARY_1M: Phy.LE_1M, SECONDARY_1M: Phy.LE_1M,
SECONDARY_2M: Phy.LE_2M, SECONDARY_2M: Phy.LE_2M,
SECONDARY_CODED: Phy.LE_CODED, SECONDARY_CODED: Phy.LE_CODED,
} }
OWN_ADDRESS_MAP: dict[host_pb2.OwnAddressType, OwnAddressType] = { OWN_ADDRESS_MAP: Dict[host_pb2.OwnAddressType, OwnAddressType] = {
host_pb2.PUBLIC: OwnAddressType.PUBLIC, host_pb2.PUBLIC: OwnAddressType.PUBLIC,
host_pb2.RANDOM: OwnAddressType.RANDOM, host_pb2.RANDOM: OwnAddressType.RANDOM,
host_pb2.RESOLVABLE_OR_PUBLIC: OwnAddressType.RESOLVABLE_OR_PUBLIC, host_pb2.RESOLVABLE_OR_PUBLIC: OwnAddressType.RESOLVABLE_OR_PUBLIC,
@@ -128,7 +124,7 @@ OWN_ADDRESS_MAP: dict[host_pb2.OwnAddressType, OwnAddressType] = {
class HostService(HostServicer): class HostService(HostServicer):
waited_connections: set[int] waited_connections: Set[int]
def __init__( def __init__(
self, grpc_server: grpc.aio.Server, device: Device, config: Config self, grpc_server: grpc.aio.Server, device: Device, config: Config
@@ -305,9 +301,7 @@ class HostService(HostServicer):
await disconnection_future await disconnection_future
self.log.debug("Disconnected") self.log.debug("Disconnected")
finally: finally:
connection.remove_listener( connection.remove_listener(connection.EVENT_DISCONNECTION, on_disconnection) # type: ignore
connection.EVENT_DISCONNECTION, on_disconnection
) # type: ignore
return empty_pb2.Empty() return empty_pb2.Empty()
@@ -544,7 +538,7 @@ class HostService(HostServicer):
await bumble.utils.cancel_on_event( await bumble.utils.cancel_on_event(
self.device, 'flush', self.device.stop_advertising() self.device, 'flush', self.device.stop_advertising()
) )
except Exception: except:
pass pass
@utils.rpc @utils.rpc
@@ -614,7 +608,7 @@ class HostService(HostServicer):
await bumble.utils.cancel_on_event( await bumble.utils.cancel_on_event(
self.device, 'flush', self.device.stop_scanning() self.device, 'flush', self.device.stop_scanning()
) )
except Exception: except:
pass pass
@utils.rpc @utils.rpc
@@ -624,7 +618,7 @@ class HostService(HostServicer):
self.log.debug('Inquiry') self.log.debug('Inquiry')
inquiry_queue: asyncio.Queue[ inquiry_queue: asyncio.Queue[
tuple[Address, int, AdvertisingData, int] | None Optional[Tuple[Address, int, AdvertisingData, int]]
] = asyncio.Queue() ] = asyncio.Queue()
complete_handler = self.device.on( complete_handler = self.device.on(
self.device.EVENT_INQUIRY_COMPLETE, lambda: inquiry_queue.put_nowait(None) self.device.EVENT_INQUIRY_COMPLETE, lambda: inquiry_queue.put_nowait(None)
@@ -649,18 +643,14 @@ class HostService(HostServicer):
) )
finally: finally:
self.device.remove_listener( self.device.remove_listener(self.device.EVENT_INQUIRY_COMPLETE, complete_handler) # type: ignore
self.device.EVENT_INQUIRY_COMPLETE, complete_handler self.device.remove_listener(self.device.EVENT_INQUIRY_RESULT, result_handler) # type: ignore
) # type: ignore
self.device.remove_listener(
self.device.EVENT_INQUIRY_RESULT, result_handler
) # type: ignore
try: try:
self.log.debug('Stop inquiry') self.log.debug('Stop inquiry')
await bumble.utils.cancel_on_event( await bumble.utils.cancel_on_event(
self.device, 'flush', self.device.stop_discovery() self.device, 'flush', self.device.stop_discovery()
) )
except Exception: except:
pass pass
@utils.rpc @utils.rpc
@@ -680,10 +670,10 @@ class HostService(HostServicer):
return empty_pb2.Empty() return empty_pb2.Empty()
def unpack_data_types(self, dt: DataTypes) -> AdvertisingData: def unpack_data_types(self, dt: DataTypes) -> AdvertisingData:
ad_structures: list[tuple[int, bytes]] = [] ad_structures: List[Tuple[int, bytes]] = []
uuids: list[str] uuids: List[str]
datas: dict[str, bytes] datas: Dict[str, bytes]
def uuid128_from_str(uuid: str) -> bytes: def uuid128_from_str(uuid: str) -> bytes:
"""Decode a 128-bit uuid encoded as XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX """Decode a 128-bit uuid encoded as XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX
@@ -897,50 +887,50 @@ class HostService(HostServicer):
def pack_data_types(self, ad: AdvertisingData) -> DataTypes: def pack_data_types(self, ad: AdvertisingData) -> DataTypes:
dt = DataTypes() dt = DataTypes()
uuids: list[UUID] uuids: List[UUID]
s: str s: str
i: int i: int
ij: tuple[int, int] ij: Tuple[int, int]
uuid_data: tuple[UUID, bytes] uuid_data: Tuple[UUID, bytes]
data: bytes data: bytes
if uuids := cast( if uuids := cast(
list[UUID], List[UUID],
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS), ad.get(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS),
): ):
dt.incomplete_service_class_uuids16.extend( dt.incomplete_service_class_uuids16.extend(
list(map(lambda x: x.to_hex_str('-'), uuids)) list(map(lambda x: x.to_hex_str('-'), uuids))
) )
if uuids := cast( if uuids := cast(
list[UUID], List[UUID],
ad.get(AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS), ad.get(AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS),
): ):
dt.complete_service_class_uuids16.extend( dt.complete_service_class_uuids16.extend(
list(map(lambda x: x.to_hex_str('-'), uuids)) list(map(lambda x: x.to_hex_str('-'), uuids))
) )
if uuids := cast( if uuids := cast(
list[UUID], List[UUID],
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS), ad.get(AdvertisingData.INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS),
): ):
dt.incomplete_service_class_uuids32.extend( dt.incomplete_service_class_uuids32.extend(
list(map(lambda x: x.to_hex_str('-'), uuids)) list(map(lambda x: x.to_hex_str('-'), uuids))
) )
if uuids := cast( if uuids := cast(
list[UUID], List[UUID],
ad.get(AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS), ad.get(AdvertisingData.COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS),
): ):
dt.complete_service_class_uuids32.extend( dt.complete_service_class_uuids32.extend(
list(map(lambda x: x.to_hex_str('-'), uuids)) list(map(lambda x: x.to_hex_str('-'), uuids))
) )
if uuids := cast( if uuids := cast(
list[UUID], List[UUID],
ad.get(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS), ad.get(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS),
): ):
dt.incomplete_service_class_uuids128.extend( dt.incomplete_service_class_uuids128.extend(
list(map(lambda x: x.to_hex_str('-'), uuids)) list(map(lambda x: x.to_hex_str('-'), uuids))
) )
if uuids := cast( if uuids := cast(
list[UUID], List[UUID],
ad.get(AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS), ad.get(AdvertisingData.COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS),
): ):
dt.complete_service_class_uuids128.extend( dt.complete_service_class_uuids128.extend(
@@ -955,42 +945,42 @@ class HostService(HostServicer):
if i := cast(int, ad.get(AdvertisingData.CLASS_OF_DEVICE)): if i := cast(int, ad.get(AdvertisingData.CLASS_OF_DEVICE)):
dt.class_of_device = i dt.class_of_device = i
if ij := cast( if ij := cast(
tuple[int, int], Tuple[int, int],
ad.get(AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE), ad.get(AdvertisingData.PERIPHERAL_CONNECTION_INTERVAL_RANGE),
): ):
dt.peripheral_connection_interval_min = ij[0] dt.peripheral_connection_interval_min = ij[0]
dt.peripheral_connection_interval_max = ij[1] dt.peripheral_connection_interval_max = ij[1]
if uuids := cast( if uuids := cast(
list[UUID], List[UUID],
ad.get(AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS), ad.get(AdvertisingData.LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS),
): ):
dt.service_solicitation_uuids16.extend( dt.service_solicitation_uuids16.extend(
list(map(lambda x: x.to_hex_str('-'), uuids)) list(map(lambda x: x.to_hex_str('-'), uuids))
) )
if uuids := cast( if uuids := cast(
list[UUID], List[UUID],
ad.get(AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS), ad.get(AdvertisingData.LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS),
): ):
dt.service_solicitation_uuids32.extend( dt.service_solicitation_uuids32.extend(
list(map(lambda x: x.to_hex_str('-'), uuids)) list(map(lambda x: x.to_hex_str('-'), uuids))
) )
if uuids := cast( if uuids := cast(
list[UUID], List[UUID],
ad.get(AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS), ad.get(AdvertisingData.LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS),
): ):
dt.service_solicitation_uuids128.extend( dt.service_solicitation_uuids128.extend(
list(map(lambda x: x.to_hex_str('-'), uuids)) list(map(lambda x: x.to_hex_str('-'), uuids))
) )
if uuid_data := cast( if uuid_data := cast(
tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_16_BIT_UUID) Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_16_BIT_UUID)
): ):
dt.service_data_uuid16[uuid_data[0].to_hex_str('-')] = uuid_data[1] dt.service_data_uuid16[uuid_data[0].to_hex_str('-')] = uuid_data[1]
if uuid_data := cast( if uuid_data := cast(
tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_32_BIT_UUID) Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_32_BIT_UUID)
): ):
dt.service_data_uuid32[uuid_data[0].to_hex_str('-')] = uuid_data[1] dt.service_data_uuid32[uuid_data[0].to_hex_str('-')] = uuid_data[1]
if uuid_data := cast( if uuid_data := cast(
tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_128_BIT_UUID) Tuple[UUID, bytes], ad.get(AdvertisingData.SERVICE_DATA_128_BIT_UUID)
): ):
dt.service_data_uuid128[uuid_data[0].to_hex_str('-')] = uuid_data[1] dt.service_data_uuid128[uuid_data[0].to_hex_str('-')] = uuid_data[1]
if data := cast(bytes, ad.get(AdvertisingData.PUBLIC_TARGET_ADDRESS, raw=True)): if data := cast(bytes, ad.get(AdvertisingData.PUBLIC_TARGET_ADDRESS, raw=True)):
+27 -27
View File
@@ -12,21 +12,31 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import grpc
import json import json
import logging import logging
from asyncio import Future
from asyncio import Queue as AsyncQueue
from collections.abc import AsyncGenerator
from dataclasses import dataclass
import grpc from asyncio import Queue as AsyncQueue, Future
from bumble.pandora import utils
from bumble.pandora.config import Config
from bumble.core import OutOfResourcesError, InvalidArgumentError
from bumble.device import Device
from bumble.l2cap import (
ClassicChannel,
ClassicChannelServer,
ClassicChannelSpec,
LeCreditBasedChannel,
LeCreditBasedChannelServer,
LeCreditBasedChannelSpec,
)
from google.protobuf import any_pb2, empty_pb2 # pytype: disable=pyi-error from google.protobuf import any_pb2, empty_pb2 # pytype: disable=pyi-error
from pandora.l2cap_grpc_aio import L2CAPServicer # pytype: disable=pyi-error from pandora.l2cap_grpc_aio import L2CAPServicer # pytype: disable=pyi-error
from pandora.l2cap_pb2 import ( from pandora.l2cap_pb2 import ( # pytype: disable=pyi-error
COMMAND_NOT_UNDERSTOOD, COMMAND_NOT_UNDERSTOOD,
INVALID_CID_IN_REQUEST, INVALID_CID_IN_REQUEST,
Channel as PandoraChannel,
ConnectRequest, ConnectRequest,
ConnectResponse, ConnectResponse,
CreditBasedChannelRequest, CreditBasedChannelRequest,
@@ -41,22 +51,10 @@ from pandora.l2cap_pb2 import (
WaitDisconnectionRequest, WaitDisconnectionRequest,
WaitDisconnectionResponse, WaitDisconnectionResponse,
) )
from pandora.l2cap_pb2 import Channel as PandoraChannel # pytype: disable=pyi-error from typing import AsyncGenerator, Dict, Optional, Union
from dataclasses import dataclass
from bumble.core import InvalidArgumentError, OutOfResourcesError L2capChannel = Union[ClassicChannel, LeCreditBasedChannel]
from bumble.device import Device
from bumble.l2cap import (
ClassicChannel,
ClassicChannelServer,
ClassicChannelSpec,
LeCreditBasedChannel,
LeCreditBasedChannelServer,
LeCreditBasedChannelSpec,
)
from bumble.pandora import utils
from bumble.pandora.config import Config
L2capChannel = ClassicChannel | LeCreditBasedChannel
@dataclass @dataclass
@@ -72,7 +70,7 @@ class L2CAPService(L2CAPServicer):
) )
self.device = device self.device = device
self.config = config self.config = config
self.channels: dict[bytes, ChannelContext] = {} self.channels: Dict[bytes, ChannelContext] = {}
def register_event(self, l2cap_channel: L2capChannel) -> ChannelContext: def register_event(self, l2cap_channel: L2capChannel) -> ChannelContext:
close_future = asyncio.get_running_loop().create_future() close_future = asyncio.get_running_loop().create_future()
@@ -107,8 +105,10 @@ class L2CAPService(L2CAPServicer):
oneof = request.WhichOneof('type') oneof = request.WhichOneof('type')
self.log.debug(f'WaitConnection channel request type: {oneof}.') self.log.debug(f'WaitConnection channel request type: {oneof}.')
channel_type = getattr(request, oneof) channel_type = getattr(request, oneof)
spec: ClassicChannelSpec | LeCreditBasedChannelSpec | None = None spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None
l2cap_server: ClassicChannelServer | LeCreditBasedChannelServer | None = None l2cap_server: Optional[
Union[ClassicChannelServer, LeCreditBasedChannelServer]
] = None
if isinstance(channel_type, CreditBasedChannelRequest): if isinstance(channel_type, CreditBasedChannelRequest):
spec = LeCreditBasedChannelSpec( spec = LeCreditBasedChannelSpec(
psm=channel_type.spsm, psm=channel_type.spsm,
@@ -215,7 +215,7 @@ class L2CAPService(L2CAPServicer):
oneof = request.WhichOneof('type') oneof = request.WhichOneof('type')
self.log.debug(f'Channel request type: {oneof}.') self.log.debug(f'Channel request type: {oneof}.')
channel_type = getattr(request, oneof) channel_type = getattr(request, oneof)
spec: ClassicChannelSpec | LeCreditBasedChannelSpec | None = None spec: Optional[Union[ClassicChannelSpec, LeCreditBasedChannelSpec]] = None
if isinstance(channel_type, CreditBasedChannelRequest): if isinstance(channel_type, CreditBasedChannelRequest):
spec = LeCreditBasedChannelSpec( spec = LeCreditBasedChannelSpec(
psm=channel_type.spsm, psm=channel_type.spsm,
@@ -278,7 +278,7 @@ class L2CAPService(L2CAPServicer):
if not l2cap_channel: if not l2cap_channel:
return SendResponse(error=COMMAND_NOT_UNDERSTOOD) return SendResponse(error=COMMAND_NOT_UNDERSTOOD)
if isinstance(l2cap_channel, ClassicChannel): if isinstance(l2cap_channel, ClassicChannel):
l2cap_channel.write(request.data) l2cap_channel.send_pdu(request.data)
else: else:
l2cap_channel.write(request.data) l2cap_channel.write(request.data)
return SendResponse(success=empty_pb2.Empty()) return SendResponse(success=empty_pb2.Empty())
+33 -35
View File
@@ -13,19 +13,27 @@
# limitations under the License. # limitations under the License.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import logging from collections.abc import Awaitable
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
from typing import Any
import grpc import grpc
from google.protobuf import ( import logging
any_pb2, # pytype: disable=pyi-error
empty_pb2, # pytype: disable=pyi-error from bumble.pandora import utils
wrappers_pb2, # pytype: disable=pyi-error from bumble.pandora.config import Config
from bumble import hci
from bumble.core import (
PhysicalTransport,
ProtocolError,
InvalidArgumentError,
) )
import bumble.utils
from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error, Role
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
from pandora.host_pb2 import Connection from pandora.host_pb2 import Connection
from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer
from pandora.security_pb2 import ( from pandora.security_pb2 import (
@@ -49,24 +57,14 @@ from pandora.security_pb2 import (
WaitSecurityRequest, WaitSecurityRequest,
WaitSecurityResponse, WaitSecurityResponse,
) )
from typing import Any, AsyncGenerator, AsyncIterator, Callable, Dict, Optional, Union
import bumble.utils
from bumble import hci
from bumble.core import InvalidArgumentError, PhysicalTransport, ProtocolError
from bumble.device import Connection as BumbleConnection
from bumble.device import Device
from bumble.hci import HCI_Error, Role
from bumble.pairing import PairingConfig
from bumble.pairing import PairingDelegate as BasePairingDelegate
from bumble.pandora import utils
from bumble.pandora.config import Config
class PairingDelegate(BasePairingDelegate): class PairingDelegate(BasePairingDelegate):
def __init__( def __init__(
self, self,
connection: BumbleConnection, connection: BumbleConnection,
service: SecurityService, service: "SecurityService",
io_capability: BasePairingDelegate.IoCapability = BasePairingDelegate.NO_OUTPUT_NO_INPUT, io_capability: BasePairingDelegate.IoCapability = BasePairingDelegate.NO_OUTPUT_NO_INPUT,
local_initiator_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION, local_initiator_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
local_responder_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION, local_responder_key_distribution: BasePairingDelegate.KeyDistribution = BasePairingDelegate.DEFAULT_KEY_DISTRIBUTION,
@@ -132,7 +130,7 @@ class PairingDelegate(BasePairingDelegate):
assert answer.answer_variant() == 'confirm' and answer.confirm is not None assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm return answer.confirm
async def get_number(self) -> int | None: async def get_number(self) -> Optional[int]:
self.log.debug( self.log.debug(
f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})" f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
) )
@@ -149,7 +147,7 @@ class PairingDelegate(BasePairingDelegate):
assert answer.answer_variant() == 'passkey' assert answer.answer_variant() == 'passkey'
return answer.passkey return answer.passkey
async def get_string(self, max_length: int) -> str | None: async def get_string(self, max_length: int) -> Optional[str]:
self.log.debug( self.log.debug(
f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})" f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
) )
@@ -197,8 +195,8 @@ class SecurityService(SecurityServicer):
self.log = utils.BumbleServerLoggerAdapter( self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(), {'service_name': 'Security', 'device': device} logging.getLogger(), {'service_name': 'Security', 'device': device}
) )
self.event_queue: asyncio.Queue[PairingEvent] | None = None self.event_queue: Optional[asyncio.Queue[PairingEvent]] = None
self.event_answer: AsyncIterator[PairingEventAnswer] | None = None self.event_answer: Optional[AsyncIterator[PairingEventAnswer]] = None
self.device = device self.device = device
self.config = config self.config = config
@@ -233,7 +231,7 @@ class SecurityService(SecurityServicer):
if level == LEVEL2: if level == LEVEL2:
return connection.encryption != 0 and connection.authenticated return connection.encryption != 0 and connection.authenticated
link_key_type: int | None = None link_key_type: Optional[int] = None
if (keystore := connection.device.keystore) and ( if (keystore := connection.device.keystore) and (
keys := await keystore.get(str(connection.peer_address)) keys := await keystore.get(str(connection.peer_address))
): ):
@@ -246,16 +244,16 @@ class SecurityService(SecurityServicer):
and connection.authenticated and connection.authenticated
and link_key_type and link_key_type
in ( in (
hci.LinkKeyType.AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192, hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_192_TYPE,
hci.LinkKeyType.AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256, hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
) )
) )
if level == LEVEL4: if level == LEVEL4:
return ( return (
connection.encryption == hci.HCI_Encryption_Change_Event.Enabled.AES_CCM connection.encryption == hci.HCI_Encryption_Change_Event.AES_CCM
and connection.authenticated and connection.authenticated
and link_key_type and link_key_type
== hci.LinkKeyType.AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256 == hci.HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE
) )
raise InvalidArgumentError(f"Unexpected level {level}") raise InvalidArgumentError(f"Unexpected level {level}")
@@ -412,8 +410,8 @@ class SecurityService(SecurityServicer):
wait_for_security: asyncio.Future[str] = ( wait_for_security: asyncio.Future[str] = (
asyncio.get_running_loop().create_future() asyncio.get_running_loop().create_future()
) )
authenticate_task: asyncio.Future[None] | None = None authenticate_task: Optional[asyncio.Future[None]] = None
pair_task: asyncio.Future[None] | None = None pair_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None: async def authenticate() -> None:
if (encryption := connection.encryption) != 0: if (encryption := connection.encryption) != 0:
@@ -457,9 +455,9 @@ class SecurityService(SecurityServicer):
def pair(*_: Any) -> None: def pair(*_: Any) -> None:
if self.need_pairing(connection, level): if self.need_pairing(connection, level):
bumble.utils.AsyncRunner.spawn(connection.pair()) pair_task = asyncio.create_task(connection.pair())
listeners: dict[str, Callable[..., None | Awaitable[None]]] = { listeners: Dict[str, Callable[..., Union[None, Awaitable[None]]]] = {
'disconnection': set_failure('connection_died'), 'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'), 'pairing_failure': set_failure('pairing_failure'),
'connection_authentication_failure': set_failure('authentication_failure'), 'connection_authentication_failure': set_failure('authentication_failure'),
@@ -502,7 +500,7 @@ class SecurityService(SecurityServicer):
return WaitSecurityResponse(**kwargs) return WaitSecurityResponse(**kwargs)
async def reached_security_level( async def reached_security_level(
self, connection: BumbleConnection, level: SecurityLevel | LESecurityLevel self, connection: BumbleConnection, level: Union[SecurityLevel, LESecurityLevel]
) -> bool: ) -> bool:
self.log.debug( self.log.debug(
str( str(
+8 -10
View File
@@ -13,21 +13,18 @@
# limitations under the License. # limitations under the License.
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib
import functools import functools
import grpc
import inspect import inspect
import logging import logging
from collections.abc import Generator, MutableMapping
from typing import Any
import grpc
from google.protobuf.message import Message # pytype: disable=pyi-error
from bumble.device import Device from bumble.device import Device
from bumble.hci import Address, AddressType from bumble.hci import Address, AddressType
from google.protobuf.message import Message # pytype: disable=pyi-error
from typing import Any, Dict, Generator, MutableMapping, Optional, Tuple
ADDRESS_TYPES: dict[str, AddressType] = { ADDRESS_TYPES: Dict[str, AddressType] = {
"public": Address.PUBLIC_DEVICE_ADDRESS, "public": Address.PUBLIC_DEVICE_ADDRESS,
"random": Address.RANDOM_DEVICE_ADDRESS, "random": Address.RANDOM_DEVICE_ADDRESS,
"public_identity": Address.PUBLIC_IDENTITY_ADDRESS, "public_identity": Address.PUBLIC_IDENTITY_ADDRESS,
@@ -35,7 +32,7 @@ ADDRESS_TYPES: dict[str, AddressType] = {
} }
def address_from_request(request: Message, field: str | None) -> Address: def address_from_request(request: Message, field: Optional[str]) -> Address:
if field is None: if field is None:
return Address.ANY return Address.ANY
return Address(bytes(reversed(getattr(request, field))), ADDRESS_TYPES[field]) return Address(bytes(reversed(getattr(request, field))), ADDRESS_TYPES[field])
@@ -46,7 +43,7 @@ class BumbleServerLoggerAdapter(logging.LoggerAdapter): # type: ignore
def process( def process(
self, msg: str, kwargs: MutableMapping[str, Any] self, msg: str, kwargs: MutableMapping[str, Any]
) -> tuple[str, MutableMapping[str, Any]]: ) -> Tuple[str, MutableMapping[str, Any]]:
assert self.extra assert self.extra
service_name = self.extra['service_name'] service_name = self.extra['service_name']
assert isinstance(service_name, str) assert isinstance(service_name, str)
@@ -96,7 +93,8 @@ def rpc(func: Any) -> Any:
@functools.wraps(func) @functools.wraps(func)
def gen_wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any: def gen_wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:
with exception_to_rpc_error(context): with exception_to_rpc_error(context):
yield from func(self, request, context) for v in func(self, request, context):
yield v
@functools.wraps(func) @functools.wraps(func)
def wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any: def wrapper(self: Any, request: Any, context: grpc.ServicerContext) -> Any:
+25 -21
View File
@@ -18,26 +18,26 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import logging import logging
import struct import struct
from dataclasses import dataclass
from bumble import utils from dataclasses import dataclass
from bumble.att import ATT_Error from typing import Optional
from bumble.device import Connection from bumble.device import Connection
from bumble.att import ATT_Error
from bumble.gatt import ( from bumble.gatt import (
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
GATT_AUDIO_INPUT_CONTROL_SERVICE,
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
Attribute, Attribute,
Characteristic, Characteristic,
CharacteristicValue,
TemplateService, TemplateService,
CharacteristicValue,
GATT_AUDIO_INPUT_CONTROL_SERVICE,
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC,
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
) )
from bumble.gatt_adapters import ( from bumble.gatt_adapters import (
CharacteristicProxy, CharacteristicProxy,
@@ -48,6 +48,7 @@ from bumble.gatt_adapters import (
UTF8CharacteristicProxyAdapter, UTF8CharacteristicProxyAdapter,
) )
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
from bumble import utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -128,7 +129,7 @@ class AudioInputState:
mute: Mute = Mute.NOT_MUTED mute: Mute = Mute.NOT_MUTED
gain_mode: GainMode = GainMode.MANUAL gain_mode: GainMode = GainMode.MANUAL
change_counter: int = 0 change_counter: int = 0
attribute: Attribute | None = None attribute: Optional[Attribute] = None
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return bytes( return bytes(
@@ -197,7 +198,9 @@ class AudioInputControlPoint:
audio_input_state: AudioInputState audio_input_state: AudioInputState
gain_settings_properties: GainSettingsProperties gain_settings_properties: GainSettingsProperties
async def on_write(self, connection: Connection, value: bytes) -> None: async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
assert connection
opcode = AudioInputControlPointOpCode(value[0]) opcode = AudioInputControlPointOpCode(value[0])
if opcode == AudioInputControlPointOpCode.SET_GAIN_SETTING: if opcode == AudioInputControlPointOpCode.SET_GAIN_SETTING:
@@ -315,12 +318,13 @@ class AudioInputDescription:
''' '''
audio_input_description: str = "Bluetooth" audio_input_description: str = "Bluetooth"
attribute: Attribute | None = None attribute: Optional[Attribute] = None
def on_read(self, _connection: Connection) -> str: def on_read(self, _connection: Optional[Connection]) -> str:
return self.audio_input_description return self.audio_input_description
async def on_write(self, connection: Connection, value: str) -> None: async def on_write(self, connection: Optional[Connection], value: str) -> None:
assert connection
assert self.attribute assert self.attribute
self.audio_input_description = value self.audio_input_description = value
@@ -338,11 +342,11 @@ class AICSService(TemplateService):
def __init__( def __init__(
self, self,
audio_input_state: AudioInputState | None = None, audio_input_state: Optional[AudioInputState] = None,
gain_settings_properties: GainSettingsProperties | None = None, gain_settings_properties: Optional[GainSettingsProperties] = None,
audio_input_type: str = "local", audio_input_type: str = "local",
audio_input_status: AudioInputStatus | None = None, audio_input_status: Optional[AudioInputStatus] = None,
audio_input_description: AudioInputDescription | None = None, audio_input_description: Optional[AudioInputDescription] = None,
): ):
self.audio_input_state = ( self.audio_input_state = (
AudioInputState() if audio_input_state is None else audio_input_state AudioInputState() if audio_input_state is None else audio_input_state
-401
View File
@@ -1,401 +0,0 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Apple Media Service (AMS).
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import enum
import logging
from collections.abc import Iterable
from bumble import utils
from bumble.device import Peer
from bumble.gatt import (
GATT_AMS_ENTITY_ATTRIBUTE_CHARACTERISTIC,
GATT_AMS_ENTITY_UPDATE_CHARACTERISTIC,
GATT_AMS_REMOTE_COMMAND_CHARACTERISTIC,
GATT_AMS_SERVICE,
Characteristic,
TemplateService,
)
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Protocol
# -----------------------------------------------------------------------------
class RemoteCommandId(utils.OpenIntEnum):
PLAY = 0
PAUSE = 1
TOGGLE_PLAY_PAUSE = 2
NEXT_TRACK = 3
PREVIOUS_TRACK = 4
VOLUME_UP = 5
VOLUME_DOWN = 6
ADVANCE_REPEAT_MODE = 7
ADVANCE_SHUFFLE_MODE = 8
SKIP_FORWARD = 9
SKIP_BACKWARD = 10
LIKE_TRACK = 11
DISLIKE_TRACK = 12
BOOKMARK_TRACK = 13
class EntityId(utils.OpenIntEnum):
PLAYER = 0
QUEUE = 1
TRACK = 2
class ActionId(utils.OpenIntEnum):
POSITIVE = 0
NEGATIVE = 1
class EntityUpdateFlags(enum.IntFlag):
TRUNCATED = 1
class PlayerAttributeId(utils.OpenIntEnum):
NAME = 0
PLAYBACK_INFO = 1
VOLUME = 2
class QueueAttributeId(utils.OpenIntEnum):
INDEX = 0
COUNT = 1
SHUFFLE_MODE = 2
REPEAT_MODE = 3
class ShuffleMode(utils.OpenIntEnum):
OFF = 0
ONE = 1
ALL = 2
class RepeatMode(utils.OpenIntEnum):
OFF = 0
ONE = 1
ALL = 2
class TrackAttributeId(utils.OpenIntEnum):
ARTIST = 0
ALBUM = 1
TITLE = 2
DURATION = 3
class PlaybackState(utils.OpenIntEnum):
PAUSED = 0
PLAYING = 1
REWINDING = 2
FAST_FORWARDING = 3
@dataclasses.dataclass
class PlaybackInfo:
playback_state: PlaybackState = PlaybackState.PAUSED
playback_rate: float = 1.0
elapsed_time: float = 0.0
# -----------------------------------------------------------------------------
# GATT Server-side
# -----------------------------------------------------------------------------
class Ams(TemplateService):
UUID = GATT_AMS_SERVICE
remote_command_characteristic: Characteristic
entity_update_characteristic: Characteristic
entity_attribute_characteristic: Characteristic
def __init__(self) -> None:
# TODO not the final implementation
self.remote_command_characteristic = Characteristic(
GATT_AMS_REMOTE_COMMAND_CHARACTERISTIC,
Characteristic.Properties.NOTIFY
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.Permissions.WRITEABLE,
)
# TODO not the final implementation
self.entity_update_characteristic = Characteristic(
GATT_AMS_ENTITY_UPDATE_CHARACTERISTIC,
Characteristic.Properties.NOTIFY | Characteristic.Properties.WRITE,
Characteristic.Permissions.WRITEABLE,
)
# TODO not the final implementation
self.entity_attribute_characteristic = Characteristic(
GATT_AMS_ENTITY_ATTRIBUTE_CHARACTERISTIC,
Characteristic.Properties.READ
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.Permissions.WRITEABLE | Characteristic.Permissions.READABLE,
)
super().__init__(
[
self.remote_command_characteristic,
self.entity_update_characteristic,
self.entity_attribute_characteristic,
]
)
# -----------------------------------------------------------------------------
# GATT Client-side
# -----------------------------------------------------------------------------
class AmsProxy(ProfileServiceProxy):
SERVICE_CLASS = Ams
# NOTE: these don't use adapters, because the format for write and notifications
# are different.
remote_command: CharacteristicProxy[bytes]
entity_update: CharacteristicProxy[bytes]
entity_attribute: CharacteristicProxy[bytes]
def __init__(self, service_proxy: ServiceProxy):
self.remote_command = service_proxy.get_required_characteristic_by_uuid(
GATT_AMS_REMOTE_COMMAND_CHARACTERISTIC
)
self.entity_update = service_proxy.get_required_characteristic_by_uuid(
GATT_AMS_ENTITY_UPDATE_CHARACTERISTIC
)
self.entity_attribute = service_proxy.get_required_characteristic_by_uuid(
GATT_AMS_ENTITY_ATTRIBUTE_CHARACTERISTIC
)
class AmsClient(utils.EventEmitter):
EVENT_SUPPORTED_COMMANDS = "supported_commands"
EVENT_PLAYER_NAME = "player_name"
EVENT_PLAYER_PLAYBACK_INFO = "player_playback_info"
EVENT_PLAYER_VOLUME = "player_volume"
EVENT_QUEUE_COUNT = "queue_count"
EVENT_QUEUE_INDEX = "queue_index"
EVENT_QUEUE_SHUFFLE_MODE = "queue_shuffle_mode"
EVENT_QUEUE_REPEAT_MODE = "queue_repeat_mode"
EVENT_TRACK_ARTIST = "track_artist"
EVENT_TRACK_ALBUM = "track_album"
EVENT_TRACK_TITLE = "track_title"
EVENT_TRACK_DURATION = "track_duration"
supported_commands: set[RemoteCommandId]
player_name: str = ""
player_playback_info: PlaybackInfo = PlaybackInfo(PlaybackState.PAUSED, 0.0, 0.0)
player_volume: float = 1.0
queue_count: int = 0
queue_index: int = 0
queue_shuffle_mode: ShuffleMode = ShuffleMode.OFF
queue_repeat_mode: RepeatMode = RepeatMode.OFF
track_artist: str = ""
track_album: str = ""
track_title: str = ""
track_duration: float = 0.0
def __init__(self, ams_proxy: AmsProxy) -> None:
super().__init__()
self._ams_proxy = ams_proxy
self._started = False
self._read_attribute_semaphore = asyncio.Semaphore()
self.supported_commands = set()
@classmethod
async def for_peer(cls, peer: Peer) -> AmsClient | None:
ams_proxy = await peer.discover_service_and_create_proxy(AmsProxy)
if ams_proxy is None:
return None
return cls(ams_proxy)
async def start(self) -> None:
logger.debug("subscribing to remote command characteristic")
await self._ams_proxy.remote_command.subscribe(
self._on_remote_command_notification
)
logger.debug("subscribing to entity update characteristic")
await self._ams_proxy.entity_update.subscribe(
lambda data: utils.AsyncRunner.spawn(
self._on_entity_update_notification(data)
)
)
self._started = True
async def stop(self) -> None:
await self._ams_proxy.remote_command.unsubscribe(
self._on_remote_command_notification
)
await self._ams_proxy.entity_update.unsubscribe(
self._on_entity_update_notification
)
self._started = False
async def observe(
self,
entity: EntityId,
attributes: Iterable[PlayerAttributeId | QueueAttributeId | TrackAttributeId],
) -> None:
await self._ams_proxy.entity_update.write_value(
bytes([entity] + list(attributes)), with_response=True
)
async def command(self, command: RemoteCommandId) -> None:
await self._ams_proxy.remote_command.write_value(
bytes([command]), with_response=True
)
async def play(self) -> None:
await self.command(RemoteCommandId.PLAY)
async def pause(self) -> None:
await self.command(RemoteCommandId.PAUSE)
async def toggle_play_pause(self) -> None:
await self.command(RemoteCommandId.TOGGLE_PLAY_PAUSE)
async def next_track(self) -> None:
await self.command(RemoteCommandId.NEXT_TRACK)
async def previous_track(self) -> None:
await self.command(RemoteCommandId.PREVIOUS_TRACK)
async def volume_up(self) -> None:
await self.command(RemoteCommandId.VOLUME_UP)
async def volume_down(self) -> None:
await self.command(RemoteCommandId.VOLUME_DOWN)
async def advance_repeat_mode(self) -> None:
await self.command(RemoteCommandId.ADVANCE_REPEAT_MODE)
async def advance_shuffle_mode(self) -> None:
await self.command(RemoteCommandId.ADVANCE_SHUFFLE_MODE)
async def skip_forward(self) -> None:
await self.command(RemoteCommandId.SKIP_FORWARD)
async def skip_backward(self) -> None:
await self.command(RemoteCommandId.SKIP_BACKWARD)
async def like_track(self) -> None:
await self.command(RemoteCommandId.LIKE_TRACK)
async def dislike_track(self) -> None:
await self.command(RemoteCommandId.DISLIKE_TRACK)
async def bookmark_track(self) -> None:
await self.command(RemoteCommandId.BOOKMARK_TRACK)
def _on_remote_command_notification(self, data: bytes) -> None:
supported_commands = [RemoteCommandId(command) for command in data]
logger.debug(
f"supported commands: {[command.name for command in supported_commands]}"
)
for command in supported_commands:
self.supported_commands.add(command)
self.emit(self.EVENT_SUPPORTED_COMMANDS)
async def _on_entity_update_notification(self, data: bytes) -> None:
entity = EntityId(data[0])
flags = EntityUpdateFlags(data[2])
value = data[3:]
if flags & EntityUpdateFlags.TRUNCATED:
logger.debug("truncated attribute, fetching full value")
# Write the entity and attribute we're interested in
# (protected by a semaphore, so that we only read one attribute at a time)
async with self._read_attribute_semaphore:
await self._ams_proxy.entity_attribute.write_value(
data[:2], with_response=True
)
value = await self._ams_proxy.entity_attribute.read_value()
if entity == EntityId.PLAYER:
player_attribute = PlayerAttributeId(data[1])
if player_attribute == PlayerAttributeId.NAME:
self.player_name = value.decode()
self.emit(self.EVENT_PLAYER_NAME)
elif player_attribute == PlayerAttributeId.PLAYBACK_INFO:
playback_state_str, playback_rate_str, elapsed_time_str = (
value.decode().split(",")
)
self.player_playback_info = PlaybackInfo(
PlaybackState(int(playback_state_str)),
float(playback_rate_str),
float(elapsed_time_str),
)
self.emit(self.EVENT_PLAYER_PLAYBACK_INFO)
elif player_attribute == PlayerAttributeId.VOLUME:
self.player_volume = float(value.decode())
self.emit(self.EVENT_PLAYER_VOLUME)
else:
logger.warning(f"received unknown player attribute {player_attribute}")
elif entity == EntityId.QUEUE:
queue_attribute = QueueAttributeId(data[1])
if queue_attribute == QueueAttributeId.COUNT:
self.queue_count = int(value)
self.emit(self.EVENT_QUEUE_COUNT)
elif queue_attribute == QueueAttributeId.INDEX:
self.queue_index = int(value)
self.emit(self.EVENT_QUEUE_INDEX)
elif queue_attribute == QueueAttributeId.REPEAT_MODE:
self.queue_repeat_mode = RepeatMode(int(value))
self.emit(self.EVENT_QUEUE_REPEAT_MODE)
elif queue_attribute == QueueAttributeId.SHUFFLE_MODE:
self.queue_shuffle_mode = ShuffleMode(int(value))
self.emit(self.EVENT_QUEUE_SHUFFLE_MODE)
else:
logger.warning(f"received unknown queue attribute {queue_attribute}")
elif entity == EntityId.TRACK:
track_attribute = TrackAttributeId(data[1])
if track_attribute == TrackAttributeId.ARTIST:
self.track_artist = value.decode()
self.emit(self.EVENT_TRACK_ARTIST)
elif track_attribute == TrackAttributeId.ALBUM:
self.track_album = value.decode()
self.emit(self.EVENT_TRACK_ALBUM)
elif track_attribute == TrackAttributeId.TITLE:
self.track_title = value.decode()
self.emit(self.EVENT_TRACK_TITLE)
elif track_attribute == TrackAttributeId.DURATION:
self.track_duration = float(value.decode())
self.emit(self.EVENT_TRACK_DURATION)
else:
logger.warning(f"received unknown track attribute {track_attribute}")
else:
logger.warning(f"received unknown attribute ID {data[1]}")
+18 -17
View File
@@ -20,28 +20,29 @@ Apple Notification Center Service (ANCS).
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import dataclasses import dataclasses
import datetime import datetime
import enum import enum
import logging import logging
import struct import struct
from collections.abc import Sequence from typing import Optional, Sequence, Union
from bumble import utils
from bumble.att import ATT_Error from bumble.att import ATT_Error
from bumble.device import Peer from bumble.device import Peer
from bumble.gatt import ( from bumble.gatt import (
Characteristic,
GATT_ANCS_SERVICE,
GATT_ANCS_NOTIFICATION_SOURCE_CHARACTERISTIC,
GATT_ANCS_CONTROL_POINT_CHARACTERISTIC, GATT_ANCS_CONTROL_POINT_CHARACTERISTIC,
GATT_ANCS_DATA_SOURCE_CHARACTERISTIC, GATT_ANCS_DATA_SOURCE_CHARACTERISTIC,
GATT_ANCS_NOTIFICATION_SOURCE_CHARACTERISTIC,
GATT_ANCS_SERVICE,
Characteristic,
TemplateService, TemplateService,
) )
from bumble.gatt_adapters import SerializableCharacteristicProxyAdapter
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy
from bumble.gatt_adapters import SerializableCharacteristicProxyAdapter
from bumble import utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
@@ -116,7 +117,7 @@ class NotificationAttributeId(utils.OpenIntEnum):
@dataclasses.dataclass @dataclasses.dataclass
class NotificationAttribute: class NotificationAttribute:
attribute_id: NotificationAttributeId attribute_id: NotificationAttributeId
value: str | int | datetime.datetime value: Union[str, int, datetime.datetime]
@dataclasses.dataclass @dataclasses.dataclass
@@ -242,10 +243,10 @@ class AncsProxy(ProfileServiceProxy):
class AncsClient(utils.EventEmitter): class AncsClient(utils.EventEmitter):
_expected_response_command_id: CommandId | None _expected_response_command_id: Optional[CommandId]
_expected_response_notification_uid: int | None _expected_response_notification_uid: Optional[int]
_expected_response_app_identifier: str | None _expected_response_app_identifier: Optional[str]
_expected_app_identifier: str | None _expected_app_identifier: Optional[str]
_expected_response_tuples: int _expected_response_tuples: int
_response_accumulator: bytes _response_accumulator: bytes
@@ -255,12 +256,12 @@ class AncsClient(utils.EventEmitter):
super().__init__() super().__init__()
self._ancs_proxy = ancs_proxy self._ancs_proxy = ancs_proxy
self._command_semaphore = asyncio.Semaphore() self._command_semaphore = asyncio.Semaphore()
self._response: asyncio.Future | None = None self._response: Optional[asyncio.Future] = None
self._reset_response() self._reset_response()
self._started = False self._started = False
@classmethod @classmethod
async def for_peer(cls, peer: Peer) -> AncsClient | None: async def for_peer(cls, peer: Peer) -> Optional[AncsClient]:
ancs_proxy = await peer.discover_service_and_create_proxy(AncsProxy) ancs_proxy = await peer.discover_service_and_create_proxy(AncsProxy)
if ancs_proxy is None: if ancs_proxy is None:
return None return None
@@ -316,7 +317,7 @@ class AncsClient(utils.EventEmitter):
# Not enough data yet. # Not enough data yet.
return return
attributes: list[NotificationAttribute | AppAttribute] = [] attributes: list[Union[NotificationAttribute, AppAttribute]] = []
if command_id == CommandId.GET_NOTIFICATION_ATTRIBUTES: if command_id == CommandId.GET_NOTIFICATION_ATTRIBUTES:
(notification_uid,) = struct.unpack_from( (notification_uid,) = struct.unpack_from(
@@ -342,7 +343,7 @@ class AncsClient(utils.EventEmitter):
str_value = attribute_data[3 : 3 + attribute_data_length].decode( str_value = attribute_data[3 : 3 + attribute_data_length].decode(
"utf-8" "utf-8"
) )
value: str | int | datetime.datetime value: Union[str, int, datetime.datetime]
if attribute_id == NotificationAttributeId.MESSAGE_SIZE: if attribute_id == NotificationAttributeId.MESSAGE_SIZE:
value = int(str_value) value = int(str_value)
elif attribute_id == NotificationAttributeId.DATE: elif attribute_id == NotificationAttributeId.DATE:
@@ -415,7 +416,7 @@ class AncsClient(utils.EventEmitter):
self, self,
notification_uid: int, notification_uid: int,
attributes: Sequence[ attributes: Sequence[
NotificationAttributeId | tuple[NotificationAttributeId, int] Union[NotificationAttributeId, tuple[NotificationAttributeId, int]]
], ],
) -> list[NotificationAttribute]: ) -> list[NotificationAttribute]:
if not self._started: if not self._started:
+167 -171
View File
@@ -19,16 +19,18 @@
from __future__ import annotations from __future__ import annotations
import enum import enum
import functools
import logging import logging
import struct import struct
from collections.abc import Sequence from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from dataclasses import dataclass, field
from typing import Any, TypeVar
from bumble import colors, device, gatt, gatt_client, hci, utils from bumble import utils
from bumble.profiles import le_audio from bumble import colors
from bumble.profiles.bap import CodecSpecificConfiguration from bumble.profiles.bap import CodecSpecificConfiguration
from bumble.profiles import le_audio
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import hci
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -46,11 +48,11 @@ class ASE_Operation:
See Audio Stream Control Service - 5 ASE Control operations. See Audio Stream Control Service - 5 ASE Control operations.
''' '''
classes: dict[int, type[ASE_Operation]] = {} classes: Dict[int, Type[ASE_Operation]] = {}
op_code: Opcode op_code: int
name: str name: str
fields: Sequence[Any] | None = None fields: Optional[Sequence[Any]] = None
ase_id: Sequence[int] ase_id: List[int]
class Opcode(enum.IntEnum): class Opcode(enum.IntEnum):
# fmt: off # fmt: off
@@ -63,30 +65,51 @@ class ASE_Operation:
UPDATE_METADATA = 0x07 UPDATE_METADATA = 0x07
RELEASE = 0x08 RELEASE = 0x08
@classmethod @staticmethod
def from_bytes(cls, pdu: bytes) -> ASE_Operation: def from_bytes(pdu: bytes) -> ASE_Operation:
op_code = pdu[0] op_code = pdu[0]
clazz = ASE_Operation.classes[op_code] cls = ASE_Operation.classes.get(op_code)
return clazz( if cls is None:
**hci.HCI_Object.dict_from_bytes(pdu, offset=1, fields=clazz.fields) instance = ASE_Operation(pdu)
) instance.name = ASE_Operation.Opcode(op_code).name
instance.op_code = op_code
return instance
self = cls.__new__(cls)
ASE_Operation.__init__(self, pdu)
if self.fields is not None:
self.init_from_bytes(pdu, 1)
return self
_OP = TypeVar("_OP", bound="ASE_Operation") @staticmethod
def subclass(fields):
def inner(cls: Type[ASE_Operation]):
try:
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
cls.name = operation.name
cls.op_code = operation
except:
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
cls.fields = fields
@classmethod # Register a factory for this class
def subclass(cls, clazz: type[_OP]) -> type[_OP]: ASE_Operation.classes[cls.op_code] = cls
clazz.name = f"ASE_{clazz.op_code.name.upper()}"
clazz.fields = hci.HCI_Object.fields_from_dataclass(clazz)
# Register a factory for this class
ASE_Operation.classes[clazz.op_code] = clazz
return clazz
@functools.cached_property return cls
def pdu(self) -> bytes:
return bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes( return inner
self.__dict__, self.fields
) def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
if self.fields is not None and kwargs:
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
kwargs, self.fields
)
self.pdu = pdu
def init_from_bytes(self, pdu: bytes, offset: int):
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return self.pdu return self.pdu
@@ -101,128 +124,105 @@ class ASE_Operation:
return result return result
@ASE_Operation.subclass @ASE_Operation.subclass(
@dataclass [
[
('ase_id', 1),
('target_latency', 1),
('target_phy', 1),
('codec_id', hci.CodingFormat.parse_from_bytes),
('codec_specific_configuration', 'v'),
],
]
)
class ASE_Config_Codec(ASE_Operation): class ASE_Config_Codec(ASE_Operation):
''' '''
See Audio Stream Control Service 5.1 - Config Codec Operation See Audio Stream Control Service 5.1 - Config Codec Operation
''' '''
op_code = ASE_Operation.Opcode.CONFIG_CODEC target_latency: List[int]
target_phy: List[int]
ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True)) codec_id: List[hci.CodingFormat]
target_latency: Sequence[int] = field(metadata=hci.metadata(1)) codec_specific_configuration: List[bytes]
target_phy: Sequence[int] = field(metadata=hci.metadata(1))
codec_id: Sequence[hci.CodingFormat] = field(
metadata=hci.metadata(hci.CodingFormat.parse_from_bytes)
)
codec_specific_configuration: Sequence[bytes] = field(
metadata=hci.metadata('v', list_end=True)
)
@ASE_Operation.subclass @ASE_Operation.subclass(
@dataclass [
[
('ase_id', 1),
('cig_id', 1),
('cis_id', 1),
('sdu_interval', 3),
('framing', 1),
('phy', 1),
('max_sdu', 2),
('retransmission_number', 1),
('max_transport_latency', 2),
('presentation_delay', 3),
],
]
)
class ASE_Config_QOS(ASE_Operation): class ASE_Config_QOS(ASE_Operation):
''' '''
See Audio Stream Control Service 5.2 - Config Qos Operation See Audio Stream Control Service 5.2 - Config Qos Operation
''' '''
op_code = ASE_Operation.Opcode.CONFIG_QOS cig_id: List[int]
cis_id: List[int]
ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True)) sdu_interval: List[int]
cig_id: Sequence[int] = field(metadata=hci.metadata(1)) framing: List[int]
cis_id: Sequence[int] = field(metadata=hci.metadata(1)) phy: List[int]
sdu_interval: Sequence[int] = field(metadata=hci.metadata(3)) max_sdu: List[int]
framing: Sequence[int] = field(metadata=hci.metadata(1)) retransmission_number: List[int]
phy: Sequence[int] = field(metadata=hci.metadata(1)) max_transport_latency: List[int]
max_sdu: Sequence[int] = field(metadata=hci.metadata(2)) presentation_delay: List[int]
retransmission_number: Sequence[int] = field(metadata=hci.metadata(1))
max_transport_latency: Sequence[int] = field(metadata=hci.metadata(2))
presentation_delay: Sequence[int] = field(metadata=hci.metadata(3, list_end=True))
@ASE_Operation.subclass @ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
@dataclass
class ASE_Enable(ASE_Operation): class ASE_Enable(ASE_Operation):
''' '''
See Audio Stream Control Service 5.3 - Enable Operation See Audio Stream Control Service 5.3 - Enable Operation
''' '''
op_code = ASE_Operation.Opcode.ENABLE metadata: bytes
ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True))
metadata: Sequence[bytes] = field(metadata=hci.metadata('v', list_end=True))
@ASE_Operation.subclass @ASE_Operation.subclass([[('ase_id', 1)]])
@dataclass
class ASE_Receiver_Start_Ready(ASE_Operation): class ASE_Receiver_Start_Ready(ASE_Operation):
''' '''
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
''' '''
op_code = ASE_Operation.Opcode.RECEIVER_START_READY
ase_id: Sequence[int] = field( @ASE_Operation.subclass([[('ase_id', 1)]])
metadata=hci.metadata(1, list_begin=True, list_end=True)
)
@ASE_Operation.subclass
@dataclass
class ASE_Disable(ASE_Operation): class ASE_Disable(ASE_Operation):
''' '''
See Audio Stream Control Service 5.5 - Disable Operation See Audio Stream Control Service 5.5 - Disable Operation
''' '''
op_code = ASE_Operation.Opcode.DISABLE
ase_id: Sequence[int] = field( @ASE_Operation.subclass([[('ase_id', 1)]])
metadata=hci.metadata(1, list_begin=True, list_end=True)
)
@ASE_Operation.subclass
@dataclass
class ASE_Receiver_Stop_Ready(ASE_Operation): class ASE_Receiver_Stop_Ready(ASE_Operation):
''' '''
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
''' '''
op_code = ASE_Operation.Opcode.RECEIVER_STOP_READY
ase_id: Sequence[int] = field( @ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
metadata=hci.metadata(1, list_begin=True, list_end=True)
)
@ASE_Operation.subclass
@dataclass
class ASE_Update_Metadata(ASE_Operation): class ASE_Update_Metadata(ASE_Operation):
''' '''
See Audio Stream Control Service 5.7 - Update Metadata Operation See Audio Stream Control Service 5.7 - Update Metadata Operation
''' '''
op_code = ASE_Operation.Opcode.UPDATE_METADATA metadata: List[bytes]
ase_id: Sequence[int] = field(metadata=hci.metadata(1, list_begin=True))
metadata: Sequence[bytes] = field(metadata=hci.metadata('v', list_end=True))
@ASE_Operation.subclass @ASE_Operation.subclass([[('ase_id', 1)]])
@dataclass
class ASE_Release(ASE_Operation): class ASE_Release(ASE_Operation):
''' '''
See Audio Stream Control Service 5.8 - Release Operation See Audio Stream Control Service 5.8 - Release Operation
''' '''
op_code = ASE_Operation.Opcode.RELEASE
ase_id: Sequence[int] = field(
metadata=hci.metadata(1, list_begin=True, list_end=True)
)
class AseResponseCode(enum.IntEnum): class AseResponseCode(enum.IntEnum):
# fmt: off # fmt: off
@@ -278,7 +278,7 @@ class AseStateMachine(gatt.Characteristic):
EVENT_STATE_CHANGE = "state_change" EVENT_STATE_CHANGE = "state_change"
cis_link: device.CisLink | None = None cis_link: Optional[device.CisLink] = None
# Additional parameters in CODEC_CONFIGURED State # Additional parameters in CODEC_CONFIGURED State
preferred_framing = 0 # Unframed PDU supported preferred_framing = 0 # Unframed PDU supported
@@ -290,7 +290,7 @@ class AseStateMachine(gatt.Characteristic):
preferred_presentation_delay_min = 0 preferred_presentation_delay_min = 0
preferred_presentation_delay_max = 0 preferred_presentation_delay_max = 0
codec_id = hci.CodingFormat(hci.CodecID.LC3) codec_id = hci.CodingFormat(hci.CodecID.LC3)
codec_specific_configuration: CodecSpecificConfiguration | bytes = b'' codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
# Additional parameters in QOS_CONFIGURED State # Additional parameters in QOS_CONFIGURED State
cig_id = 0 cig_id = 0
@@ -338,16 +338,22 @@ class AseStateMachine(gatt.Characteristic):
self.service.device.EVENT_CIS_ESTABLISHMENT, self.on_cis_establishment self.service.device.EVENT_CIS_ESTABLISHMENT, self.on_cis_establishment
) )
def on_cis_request(self, cis_link: device.CisLink) -> None: def on_cis_request(
self,
acl_connection: device.Connection,
cis_handle: int,
cig_id: int,
cis_id: int,
) -> None:
if ( if (
cis_link.cig_id == self.cig_id cig_id == self.cig_id
and cis_link.cis_id == self.cis_id and cis_id == self.cis_id
and self.state == self.State.ENABLING and self.state == self.State.ENABLING
): ):
utils.cancel_on_event( utils.cancel_on_event(
cis_link.acl_connection, acl_connection,
'flush', 'flush',
self.service.device.accept_cis_request(cis_link), self.service.device.accept_cis_request(cis_handle),
) )
def on_cis_establishment(self, cis_link: device.CisLink) -> None: def on_cis_establishment(self, cis_link: device.CisLink) -> None:
@@ -378,7 +384,7 @@ class AseStateMachine(gatt.Characteristic):
target_phy: int, target_phy: int,
codec_id: hci.CodingFormat, codec_id: hci.CodingFormat,
codec_specific_configuration: bytes, codec_specific_configuration: bytes,
) -> tuple[AseResponseCode, AseReasonCode]: ) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in ( if self.state not in (
self.State.IDLE, self.State.IDLE,
self.State.CODEC_CONFIGURED, self.State.CODEC_CONFIGURED,
@@ -414,7 +420,7 @@ class AseStateMachine(gatt.Characteristic):
retransmission_number: int, retransmission_number: int,
max_transport_latency: int, max_transport_latency: int,
presentation_delay: int, presentation_delay: int,
) -> tuple[AseResponseCode, AseReasonCode]: ) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in ( if self.state not in (
AseStateMachine.State.CODEC_CONFIGURED, AseStateMachine.State.CODEC_CONFIGURED,
AseStateMachine.State.QOS_CONFIGURED, AseStateMachine.State.QOS_CONFIGURED,
@@ -438,7 +444,7 @@ class AseStateMachine(gatt.Characteristic):
return (AseResponseCode.SUCCESS, AseReasonCode.NONE) return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_enable(self, metadata: bytes) -> tuple[AseResponseCode, AseReasonCode]: def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.QOS_CONFIGURED: if self.state != AseStateMachine.State.QOS_CONFIGURED:
return ( return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
@@ -447,20 +453,10 @@ class AseStateMachine(gatt.Characteristic):
self.metadata = le_audio.Metadata.from_bytes(metadata) self.metadata = le_audio.Metadata.from_bytes(metadata)
self.state = self.State.ENABLING self.state = self.State.ENABLING
# CIS could be established before enable.
if cis_link := next(
(
cis_link
for cis_link in self.service.device.cis_links.values()
if cis_link.cig_id == self.cig_id and cis_link.cis_id == self.cis_id
),
None,
):
self.on_cis_establishment(cis_link)
return (AseResponseCode.SUCCESS, AseReasonCode.NONE) return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_start_ready(self) -> tuple[AseResponseCode, AseReasonCode]: def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.ENABLING: if self.state != AseStateMachine.State.ENABLING:
return ( return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
@@ -469,7 +465,7 @@ class AseStateMachine(gatt.Characteristic):
self.state = self.State.STREAMING self.state = self.State.STREAMING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE) return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_disable(self) -> tuple[AseResponseCode, AseReasonCode]: def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in ( if self.state not in (
AseStateMachine.State.ENABLING, AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING, AseStateMachine.State.STREAMING,
@@ -484,7 +480,7 @@ class AseStateMachine(gatt.Characteristic):
self.state = self.State.DISABLING self.state = self.State.DISABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE) return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_stop_ready(self) -> tuple[AseResponseCode, AseReasonCode]: def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if ( if (
self.role != AudioRole.SOURCE self.role != AudioRole.SOURCE
or self.state != AseStateMachine.State.DISABLING or self.state != AseStateMachine.State.DISABLING
@@ -498,7 +494,7 @@ class AseStateMachine(gatt.Characteristic):
def on_update_metadata( def on_update_metadata(
self, metadata: bytes self, metadata: bytes
) -> tuple[AseResponseCode, AseReasonCode]: ) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in ( if self.state not in (
AseStateMachine.State.ENABLING, AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING, AseStateMachine.State.STREAMING,
@@ -510,7 +506,7 @@ class AseStateMachine(gatt.Characteristic):
self.metadata = le_audio.Metadata.from_bytes(metadata) self.metadata = le_audio.Metadata.from_bytes(metadata)
return (AseResponseCode.SUCCESS, AseReasonCode.NONE) return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_release(self) -> tuple[AseResponseCode, AseReasonCode]: def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state == AseStateMachine.State.IDLE: if self.state == AseStateMachine.State.IDLE:
return ( return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
@@ -520,7 +516,7 @@ class AseStateMachine(gatt.Characteristic):
async def remove_cis_async(): async def remove_cis_async():
if self.cis_link: if self.cis_link:
await self.cis_link.remove_data_path([self.role]) await self.cis_link.remove_data_path(self.role)
self.state = self.State.IDLE self.state = self.State.IDLE
await self.service.device.notify_subscribers(self, self.value) await self.service.device.notify_subscribers(self, self.value)
@@ -594,7 +590,7 @@ class AseStateMachine(gatt.Characteristic):
# Readonly. Do nothing in the setter. # Readonly. Do nothing in the setter.
pass pass
def on_read(self, _: device.Connection) -> bytes: def on_read(self, _: Optional[device.Connection]) -> bytes:
return self.value return self.value
def __str__(self) -> str: def __str__(self) -> str:
@@ -608,9 +604,9 @@ class AseStateMachine(gatt.Characteristic):
class AudioStreamControlService(gatt.TemplateService): class AudioStreamControlService(gatt.TemplateService):
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
ase_state_machines: dict[int, AseStateMachine] ase_state_machines: Dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic[bytes] ase_control_point: gatt.Characteristic[bytes]
_active_client: device.Connection | None = None _active_client: Optional[device.Connection] = None
def __init__( def __init__(
self, self,
@@ -653,9 +649,7 @@ class AudioStreamControlService(gatt.TemplateService):
ase.state = AseStateMachine.State.IDLE ase.state = AseStateMachine.State.IDLE
self._active_client = None self._active_client = None
def on_write_ase_control_point( def on_write_ase_control_point(self, connection, data):
self, connection: device.Connection, data: bytes
) -> None:
if not self._active_client and connection: if not self._active_client and connection:
self._active_client = connection self._active_client = connection
connection.once('disconnection', self._on_client_disconnected) connection.once('disconnection', self._on_client_disconnected)
@@ -664,44 +658,46 @@ class AudioStreamControlService(gatt.TemplateService):
responses = [] responses = []
logger.debug(f'*** ASCS Write {operation} ***') logger.debug(f'*** ASCS Write {operation} ***')
match operation: if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
case ASE_Config_Codec(): for ase_id, *args in zip(
for ase_id, *args in zip( operation.ase_id,
operation.ase_id, operation.target_latency,
operation.target_latency, operation.target_phy,
operation.target_phy, operation.codec_id,
operation.codec_id, operation.codec_specific_configuration,
operation.codec_specific_configuration,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case ASE_Config_QOS():
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case ASE_Enable() | ASE_Update_Metadata():
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
case (
ASE_Receiver_Start_Ready()
| ASE_Disable()
| ASE_Receiver_Stop_Ready()
| ASE_Release()
): ):
for ase_id in operation.ase_id: responses.append(self.on_operation(operation.op_code, ase_id, args))
responses.append(self.on_operation(operation.op_code, ase_id, [])) elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.ENABLE,
ASE_Operation.Opcode.UPDATE_METADATA,
):
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.RECEIVER_START_READY,
ASE_Operation.Opcode.DISABLE,
ASE_Operation.Opcode.RECEIVER_STOP_READY,
ASE_Operation.Opcode.RELEASE,
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
control_point_notification = bytes( control_point_notification = bytes(
[operation.op_code, len(responses)] [operation.op_code, len(responses)]
@@ -727,8 +723,8 @@ class AudioStreamControlService(gatt.TemplateService):
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy): class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AudioStreamControlService SERVICE_CLASS = AudioStreamControlService
sink_ase: list[gatt_client.CharacteristicProxy[bytes]] sink_ase: List[gatt_client.CharacteristicProxy[bytes]]
source_ase: list[gatt_client.CharacteristicProxy[bytes]] source_ase: List[gatt_client.CharacteristicProxy[bytes]]
ase_control_point: gatt_client.CharacteristicProxy[bytes] ase_control_point: gatt_client.CharacteristicProxy[bytes]
def __init__(self, service_proxy: gatt_client.ServiceProxy): def __init__(self, service_proxy: gatt_client.ServiceProxy):
+22 -19
View File
@@ -17,14 +17,16 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import enum import enum
import logging
import struct import struct
from collections.abc import Callable import logging
from typing import Any from typing import List, Optional, Callable, Union, Any
from bumble import data_types, gatt, gatt_client, l2cap, utils from bumble import l2cap
from bumble import utils
from bumble import gatt
from bumble import gatt_client
from bumble.core import AdvertisingData from bumble.core import AdvertisingData
from bumble.device import Connection, Device from bumble.device import Device, Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -91,20 +93,20 @@ class AshaService(gatt.TemplateService):
EVENT_DISCONNECTED = "disconnected" EVENT_DISCONNECTED = "disconnected"
EVENT_VOLUME_CHANGED = "volume_changed" EVENT_VOLUME_CHANGED = "volume_changed"
audio_sink: Callable[[bytes], Any] | None audio_sink: Optional[Callable[[bytes], Any]]
active_codec: Codec | None = None active_codec: Optional[Codec] = None
audio_type: AudioType | None = None audio_type: Optional[AudioType] = None
volume: int | None = None volume: Optional[int] = None
other_state: int | None = None other_state: Optional[int] = None
connection: Connection | None = None connection: Optional[Connection] = None
def __init__( def __init__(
self, self,
capability: int, capability: int,
hisyncid: list[int] | bytes, hisyncid: Union[List[int], bytes],
device: Device, device: Device,
psm: int = 0, psm: int = 0,
audio_sink: Callable[[bytes], Any] | None = None, audio_sink: Optional[Callable[[bytes], Any]] = None,
feature_map: int = FeatureMap.LE_COC_AUDIO_OUTPUT_STREAMING_SUPPORTED, feature_map: int = FeatureMap.LE_COC_AUDIO_OUTPUT_STREAMING_SUPPORTED,
protocol_version: int = 0x01, protocol_version: int = 0x01,
render_delay_milliseconds: int = 0, render_delay_milliseconds: int = 0,
@@ -186,18 +188,19 @@ class AshaService(gatt.TemplateService):
return bytes( return bytes(
AdvertisingData( AdvertisingData(
[ [
data_types.ServiceData16BitUUID( (
gatt.GATT_ASHA_SERVICE, AdvertisingData.SERVICE_DATA_16_BIT_UUID,
bytes([self.protocol_version, self.capability]) bytes(gatt.GATT_ASHA_SERVICE)
+ bytes([self.protocol_version, self.capability])
+ self.hisyncid[:4], + self.hisyncid[:4],
) ),
] ]
) )
) )
# Handler for audio control commands # Handler for audio control commands
async def _on_audio_control_point_write( async def _on_audio_control_point_write(
self, connection: Connection, value: bytes self, connection: Optional[Connection], value: bytes
) -> None: ) -> None:
_logger.debug(f'--- AUDIO CONTROL POINT Write:{value.hex()}') _logger.debug(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0] opcode = value[0]
@@ -244,7 +247,7 @@ class AshaService(gatt.TemplateService):
) )
# Handler for volume control # Handler for volume control
def _on_volume_write(self, connection: Connection, value: bytes) -> None: def _on_volume_write(self, connection: Optional[Connection], value: bytes) -> None:
_logger.debug(f'--- VOLUME Write:{value[0]}') _logger.debug(f'--- VOLUME Write:{value[0]}')
self.volume = value[0] self.volume = value[0]
self.emit(self.EVENT_VOLUME_CHANGED) self.emit(self.EVENT_VOLUME_CHANGED)
+38 -26
View File
@@ -18,18 +18,22 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
from collections.abc import Sequence
import dataclasses import dataclasses
import enum import enum
import struct
import functools import functools
import logging import logging
import struct from typing import List
from collections.abc import Sequence
from typing_extensions import Self from typing_extensions import Self
from bumble import core, data_types, gatt, hci, utils from bumble import core
from bumble import hci
from bumble import gatt
from bumble import utils
from bumble.profiles import le_audio from bumble.profiles import le_audio
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -257,10 +261,11 @@ class UnicastServerAdvertisingData:
return bytes( return bytes(
core.AdvertisingData( core.AdvertisingData(
[ [
data_types.ServiceData16BitUUID( (
gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE, core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
struct.pack( struct.pack(
'<BIB', '<2sBIB',
bytes(gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE),
self.announcement_type, self.announcement_type,
self.available_audio_contexts, self.available_audio_contexts,
len(self.metadata), len(self.metadata),
@@ -277,7 +282,7 @@ class UnicastServerAdvertisingData:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def bits_to_channel_counts(data: int) -> list[int]: def bits_to_channel_counts(data: int) -> List[int]:
pos = 0 pos = 0
counts = [] counts = []
while data != 0: while data != 0:
@@ -333,18 +338,17 @@ class CodecSpecificCapabilities:
value = int.from_bytes(data[offset : offset + length - 1], 'little') value = int.from_bytes(data[offset : offset + length - 1], 'little')
offset += length - 1 offset += length - 1
match type: if type == CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY:
case CodecSpecificCapabilities.Type.SAMPLING_FREQUENCY: supported_sampling_frequencies = SupportedSamplingFrequency(value)
supported_sampling_frequencies = SupportedSamplingFrequency(value) elif type == CodecSpecificCapabilities.Type.FRAME_DURATION:
case CodecSpecificCapabilities.Type.FRAME_DURATION: supported_frame_durations = SupportedFrameDuration(value)
supported_frame_durations = SupportedFrameDuration(value) elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT:
case CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT: supported_audio_channel_count = bits_to_channel_counts(value)
supported_audio_channel_count = bits_to_channel_counts(value) elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME:
case CodecSpecificCapabilities.Type.OCTETS_PER_FRAME: min_octets_per_sample = value & 0xFFFF
min_octets_per_sample = value & 0xFFFF max_octets_per_sample = value >> 16
max_octets_per_sample = value >> 16 elif type == CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU:
case CodecSpecificCapabilities.Type.CODEC_FRAMES_PER_SDU: supported_max_codec_frames_per_sdu = value
supported_max_codec_frames_per_sdu = value
# It is expected here that if some fields are missing, an error should be raised. # It is expected here that if some fields are missing, an error should be raised.
# pylint: disable=possibly-used-before-assignment,used-before-assignment # pylint: disable=possibly-used-before-assignment,used-before-assignment
@@ -490,8 +494,12 @@ class BroadcastAudioAnnouncement:
return bytes( return bytes(
core.AdvertisingData( core.AdvertisingData(
[ [
data_types.ServiceData16BitUUID( (
gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE, bytes(self) core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
(
bytes(gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE)
+ bytes(self)
),
) )
] ]
) )
@@ -519,7 +527,7 @@ class BasicAudioAnnouncement:
codec_id: hci.CodingFormat codec_id: hci.CodingFormat
codec_specific_configuration: CodecSpecificConfiguration codec_specific_configuration: CodecSpecificConfiguration
metadata: le_audio.Metadata metadata: le_audio.Metadata
bis: list[BasicAudioAnnouncement.BIS] bis: List[BasicAudioAnnouncement.BIS]
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
metadata_bytes = bytes(self.metadata) metadata_bytes = bytes(self.metadata)
@@ -537,7 +545,7 @@ class BasicAudioAnnouncement:
) )
presentation_delay: int presentation_delay: int
subgroups: list[BasicAudioAnnouncement.Subgroup] subgroups: List[BasicAudioAnnouncement.Subgroup]
@classmethod @classmethod
def from_bytes(cls, data: bytes) -> Self: def from_bytes(cls, data: bytes) -> Self:
@@ -603,8 +611,12 @@ class BasicAudioAnnouncement:
return bytes( return bytes(
core.AdvertisingData( core.AdvertisingData(
[ [
data_types.ServiceData16BitUUID( (
gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE, bytes(self) core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
(
bytes(gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE)
+ bytes(self)
),
) )
] ]
) )
+10 -11
View File
@@ -17,14 +17,18 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import dataclasses import dataclasses
import logging import logging
import struct import struct
from collections.abc import Sequence from typing import ClassVar, Optional, Sequence
from typing import ClassVar
from bumble import core, device, gatt, gatt_adapters, gatt_client, hci, utils from bumble import core
from bumble import device
from bumble import gatt
from bumble import gatt_adapters
from bumble import gatt_client
from bumble import hci
from bumble import utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -338,12 +342,7 @@ class BroadcastAudioScanService(gatt.TemplateService):
b"12", # TEST b"12", # TEST
) )
super().__init__( super().__init__([self.battery_level_characteristic])
[
self.broadcast_audio_scan_control_point_characteristic,
self.broadcast_receive_state_characteristic,
]
)
def on_broadcast_audio_scan_control_point_write( def on_broadcast_audio_scan_control_point_write(
self, connection: device.Connection, value: bytes self, connection: device.Connection, value: bytes
@@ -357,7 +356,7 @@ class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy[bytes] broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy[bytes]
broadcast_receive_states: list[ broadcast_receive_states: list[
gatt_client.CharacteristicProxy[BroadcastReceiveState | None] gatt_client.CharacteristicProxy[Optional[BroadcastReceiveState]]
] ]
def __init__(self, service_proxy: gatt_client.ServiceProxy): def __init__(self, service_proxy: gatt_client.ServiceProxy):
+35 -24
View File
@@ -16,28 +16,37 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from collections.abc import Callable from typing import Optional
from bumble import device, gatt, gatt_adapters, gatt_client from bumble.gatt_client import ProfileServiceProxy
from bumble.gatt import (
GATT_BATTERY_SERVICE,
GATT_BATTERY_LEVEL_CHARACTERISTIC,
TemplateService,
Characteristic,
CharacteristicValue,
)
from bumble.gatt_client import CharacteristicProxy
from bumble.gatt_adapters import (
PackedCharacteristicAdapter,
PackedCharacteristicProxyAdapter,
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class BatteryService(gatt.TemplateService): class BatteryService(TemplateService):
UUID = gatt.GATT_BATTERY_SERVICE UUID = GATT_BATTERY_SERVICE
BATTERY_LEVEL_FORMAT = 'B' BATTERY_LEVEL_FORMAT = 'B'
battery_level_characteristic: gatt.Characteristic[int] battery_level_characteristic: Characteristic[int]
def __init__(self, read_battery_level: Callable[[device.Connection], int]) -> None: def __init__(self, read_battery_level):
self.battery_level_characteristic = gatt_adapters.PackedCharacteristicAdapter( self.battery_level_characteristic = PackedCharacteristicAdapter(
gatt.Characteristic( Characteristic(
gatt.GATT_BATTERY_LEVEL_CHARACTERISTIC, GATT_BATTERY_LEVEL_CHARACTERISTIC,
properties=( Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
gatt.Characteristic.Properties.READ Characteristic.READABLE,
| gatt.Characteristic.Properties.NOTIFY CharacteristicValue(read=read_battery_level),
),
permissions=gatt.Characteristic.READABLE,
value=gatt.CharacteristicValue(read=read_battery_level),
), ),
pack_format=BatteryService.BATTERY_LEVEL_FORMAT, pack_format=BatteryService.BATTERY_LEVEL_FORMAT,
) )
@@ -45,17 +54,19 @@ class BatteryService(gatt.TemplateService):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class BatteryServiceProxy(gatt_client.ProfileServiceProxy): class BatteryServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = BatteryService SERVICE_CLASS = BatteryService
battery_level: gatt_client.CharacteristicProxy[int] battery_level: Optional[CharacteristicProxy[int]]
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None: def __init__(self, service_proxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
self.battery_level = gatt_adapters.PackedCharacteristicProxyAdapter( if characteristics := service_proxy.get_characteristics_by_uuid(
service_proxy.get_required_characteristic_by_uuid( GATT_BATTERY_LEVEL_CHARACTERISTIC
gatt.GATT_BATTERY_LEVEL_CHARACTERISTIC ):
), self.battery_level = PackedCharacteristicProxyAdapter(
pack_format=BatteryService.BATTERY_LEVEL_FORMAT, characteristics[0], pack_format=BatteryService.BATTERY_LEVEL_FORMAT
) )
else:
self.battery_level = None
+2 -1
View File
@@ -18,7 +18,8 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
from bumble import gatt, gatt_client from bumble import gatt
from bumble import gatt_client
from bumble.profiles import csip from bumble.profiles import csip
+20 -13
View File
@@ -17,11 +17,16 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import enum import enum
import struct import struct
from typing import Optional, Tuple
from bumble import core
from bumble import crypto
from bumble import device
from bumble import gatt
from bumble import gatt_client
from bumble import core, crypto, device, gatt, gatt_client
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
@@ -95,17 +100,17 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
set_identity_resolving_key: bytes set_identity_resolving_key: bytes
set_identity_resolving_key_characteristic: gatt.Characteristic[bytes] set_identity_resolving_key_characteristic: gatt.Characteristic[bytes]
coordinated_set_size_characteristic: gatt.Characteristic[bytes] | None = None coordinated_set_size_characteristic: Optional[gatt.Characteristic[bytes]] = None
set_member_lock_characteristic: gatt.Characteristic[bytes] | None = None set_member_lock_characteristic: Optional[gatt.Characteristic[bytes]] = None
set_member_rank_characteristic: gatt.Characteristic[bytes] | None = None set_member_rank_characteristic: Optional[gatt.Characteristic[bytes]] = None
def __init__( def __init__(
self, self,
set_identity_resolving_key: bytes, set_identity_resolving_key: bytes,
set_identity_resolving_key_type: SirkType, set_identity_resolving_key_type: SirkType,
coordinated_set_size: int | None = None, coordinated_set_size: Optional[int] = None,
set_member_lock: MemberLock | None = None, set_member_lock: Optional[MemberLock] = None,
set_member_rank: int | None = None, set_member_rank: Optional[int] = None,
) -> None: ) -> None:
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH: if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
raise core.InvalidArgumentError( raise core.InvalidArgumentError(
@@ -159,10 +164,12 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
super().__init__(characteristics) super().__init__(characteristics)
async def on_sirk_read(self, connection: device.Connection) -> bytes: async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT: if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
sirk_bytes = self.set_identity_resolving_key sirk_bytes = self.set_identity_resolving_key
else: else:
assert connection
if connection.transport == core.PhysicalTransport.LE: if connection.transport == core.PhysicalTransport.LE:
key = await connection.device.get_long_term_key( key = await connection.device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0 connection_handle=connection.handle, rand=b'', ediv=0
@@ -197,9 +204,9 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CoordinatedSetIdentificationService SERVICE_CLASS = CoordinatedSetIdentificationService
set_identity_resolving_key: gatt_client.CharacteristicProxy[bytes] set_identity_resolving_key: gatt_client.CharacteristicProxy[bytes]
coordinated_set_size: gatt_client.CharacteristicProxy[bytes] | None = None coordinated_set_size: Optional[gatt_client.CharacteristicProxy[bytes]] = None
set_member_lock: gatt_client.CharacteristicProxy[bytes] | None = None set_member_lock: Optional[gatt_client.CharacteristicProxy[bytes]] = None
set_member_rank: gatt_client.CharacteristicProxy[bytes] | None = None set_member_rank: Optional[gatt_client.CharacteristicProxy[bytes]] = None
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None: def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy self.service_proxy = service_proxy
@@ -223,7 +230,7 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
): ):
self.set_member_rank = characteristics[0] self.set_member_rank = characteristics[0]
async def read_set_identity_resolving_key(self) -> tuple[SirkType, bytes]: async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
'''Reads SIRK and decrypts if encrypted.''' '''Reads SIRK and decrypts if encrypted.'''
response = await self.set_identity_resolving_key.read_value() response = await self.set_identity_resolving_key.read_value()
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1: if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
+19 -18
View File
@@ -17,6 +17,7 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import struct import struct
from typing import Optional, Tuple
from bumble.gatt import ( from bumble.gatt import (
GATT_DEVICE_INFORMATION_SERVICE, GATT_DEVICE_INFORMATION_SERVICE,
@@ -24,12 +25,12 @@ from bumble.gatt import (
GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC, GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC,
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
GATT_MODEL_NUMBER_STRING_CHARACTERISTIC, GATT_MODEL_NUMBER_STRING_CHARACTERISTIC,
GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC,
GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC, GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC,
GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC, GATT_SOFTWARE_REVISION_STRING_CHARACTERISTIC,
GATT_SYSTEM_ID_CHARACTERISTIC, GATT_SYSTEM_ID_CHARACTERISTIC,
Characteristic, GATT_REGULATORY_CERTIFICATION_DATA_LIST_CHARACTERISTIC,
TemplateService, TemplateService,
Characteristic,
) )
from bumble.gatt_adapters import ( from bumble.gatt_adapters import (
DelegatedCharacteristicProxyAdapter, DelegatedCharacteristicProxyAdapter,
@@ -53,14 +54,14 @@ class DeviceInformationService(TemplateService):
def __init__( def __init__(
self, self,
manufacturer_name: str | None = None, manufacturer_name: Optional[str] = None,
model_number: str | None = None, model_number: Optional[str] = None,
serial_number: str | None = None, serial_number: Optional[str] = None,
hardware_revision: str | None = None, hardware_revision: Optional[str] = None,
firmware_revision: str | None = None, firmware_revision: Optional[str] = None,
software_revision: str | None = None, software_revision: Optional[str] = None,
system_id: tuple[int, int] | None = None, # (OUI, Manufacturer ID) system_id: Optional[Tuple[int, int]] = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: bytes | None = None, ieee_regulatory_certification_data_list: Optional[bytes] = None,
# TODO: pnp_id # TODO: pnp_id
): ):
characteristics: list[Characteristic[bytes]] = [ characteristics: list[Characteristic[bytes]] = [
@@ -108,14 +109,14 @@ class DeviceInformationService(TemplateService):
class DeviceInformationServiceProxy(ProfileServiceProxy): class DeviceInformationServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = DeviceInformationService SERVICE_CLASS = DeviceInformationService
manufacturer_name: CharacteristicProxy[str] | None manufacturer_name: Optional[CharacteristicProxy[str]]
model_number: CharacteristicProxy[str] | None model_number: Optional[CharacteristicProxy[str]]
serial_number: CharacteristicProxy[str] | None serial_number: Optional[CharacteristicProxy[str]]
hardware_revision: CharacteristicProxy[str] | None hardware_revision: Optional[CharacteristicProxy[str]]
firmware_revision: CharacteristicProxy[str] | None firmware_revision: Optional[CharacteristicProxy[str]]
software_revision: CharacteristicProxy[str] | None software_revision: Optional[CharacteristicProxy[str]]
system_id: CharacteristicProxy[tuple[int, int]] | None system_id: Optional[CharacteristicProxy[tuple[int, int]]]
ieee_regulatory_certification_data_list: CharacteristicProxy[bytes] | None ieee_regulatory_certification_data_list: Optional[CharacteristicProxy[bytes]]
def __init__(self, service_proxy: ServiceProxy): def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
+16 -16
View File
@@ -19,14 +19,15 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import logging import logging
import struct import struct
from typing import Optional, Tuple, Union
from bumble.core import Appearance from bumble.core import Appearance
from bumble.gatt import ( from bumble.gatt import (
GATT_APPEARANCE_CHARACTERISTIC,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_GENERIC_ACCESS_SERVICE,
Characteristic,
TemplateService, TemplateService,
Characteristic,
GATT_GENERIC_ACCESS_SERVICE,
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_APPEARANCE_CHARACTERISTIC,
) )
from bumble.gatt_adapters import ( from bumble.gatt_adapters import (
DelegatedCharacteristicProxyAdapter, DelegatedCharacteristicProxyAdapter,
@@ -53,17 +54,16 @@ class GenericAccessService(TemplateService):
appearance_characteristic: Characteristic[bytes] appearance_characteristic: Characteristic[bytes]
def __init__( def __init__(
self, device_name: str, appearance: Appearance | tuple[int, int] | int = 0 self, device_name: str, appearance: Union[Appearance, Tuple[int, int], int] = 0
): ):
match appearance: if isinstance(appearance, int):
case int(): appearance_int = appearance
appearance_int = appearance elif isinstance(appearance, tuple):
case tuple(): appearance_int = (appearance[0] << 6) | appearance[1]
appearance_int = (appearance[0] << 6) | appearance[1] elif isinstance(appearance, Appearance):
case Appearance(): appearance_int = int(appearance)
appearance_int = int(appearance) else:
case _: raise TypeError()
raise TypeError()
self.device_name_characteristic = Characteristic( self.device_name_characteristic = Characteristic(
GATT_DEVICE_NAME_CHARACTERISTIC, GATT_DEVICE_NAME_CHARACTERISTIC,
@@ -88,8 +88,8 @@ class GenericAccessService(TemplateService):
class GenericAccessServiceProxy(ProfileServiceProxy): class GenericAccessServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = GenericAccessService SERVICE_CLASS = GenericAccessService
device_name: CharacteristicProxy[str] | None device_name: Optional[CharacteristicProxy[str]]
appearance: CharacteristicProxy[Appearance] | None appearance: Optional[CharacteristicProxy[Appearance]]
def __init__(self, service_proxy: ServiceProxy): def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
+8 -2
View File
@@ -17,7 +17,10 @@ from __future__ import annotations
import struct import struct
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from bumble import att, crypto, gatt, gatt_client from bumble import att
from bumble import gatt
from bumble import gatt_client
from bumble import crypto
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble import device from bumble import device
@@ -40,6 +43,7 @@ class GenericAttributeProfileService(gatt.TemplateService):
database_hash_enabled: bool = True, database_hash_enabled: bool = True,
service_change_enabled: bool = True, service_change_enabled: bool = True,
) -> None: ) -> None:
if server_supported_features is not None: if server_supported_features is not None:
self.server_supported_features_characteristic = gatt.Characteristic( self.server_supported_features_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC, uuid=gatt.GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC,
@@ -123,7 +127,9 @@ class GenericAttributeProfileService(gatt.TemplateService):
return b'' return b''
def get_database_hash(self, connection: device.Connection) -> bytes: def get_database_hash(self, connection: device.Connection | None) -> bytes:
assert connection
m = b''.join( m = b''.join(
[ [
self.get_attribute_data(attribute) self.get_attribute_data(attribute)
+18 -17
View File
@@ -18,20 +18,21 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import struct import struct
from enum import IntFlag from typing import Optional
from bumble.gatt import ( from bumble.gatt import (
GATT_BGR_FEATURES_CHARACTERISTIC, TemplateService,
GATT_BGS_FEATURES_CHARACTERISTIC, Characteristic,
GATT_GAMING_AUDIO_SERVICE, GATT_GAMING_AUDIO_SERVICE,
GATT_GMAP_ROLE_CHARACTERISTIC, GATT_GMAP_ROLE_CHARACTERISTIC,
GATT_UGG_FEATURES_CHARACTERISTIC, GATT_UGG_FEATURES_CHARACTERISTIC,
GATT_UGT_FEATURES_CHARACTERISTIC, GATT_UGT_FEATURES_CHARACTERISTIC,
Characteristic, GATT_BGS_FEATURES_CHARACTERISTIC,
TemplateService, GATT_BGR_FEATURES_CHARACTERISTIC,
) )
from bumble.gatt_adapters import DelegatedCharacteristicProxyAdapter from bumble.gatt_adapters import DelegatedCharacteristicProxyAdapter
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy, ServiceProxy
from enum import IntFlag
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -76,18 +77,18 @@ class GamingAudioService(TemplateService):
UUID = GATT_GAMING_AUDIO_SERVICE UUID = GATT_GAMING_AUDIO_SERVICE
gmap_role: Characteristic gmap_role: Characteristic
ugg_features: Characteristic | None = None ugg_features: Optional[Characteristic] = None
ugt_features: Characteristic | None = None ugt_features: Optional[Characteristic] = None
bgs_features: Characteristic | None = None bgs_features: Optional[Characteristic] = None
bgr_features: Characteristic | None = None bgr_features: Optional[Characteristic] = None
def __init__( def __init__(
self, self,
gmap_role: GmapRole, gmap_role: GmapRole,
ugg_features: UggFeatures | None = None, ugg_features: Optional[UggFeatures] = None,
ugt_features: UgtFeatures | None = None, ugt_features: Optional[UgtFeatures] = None,
bgs_features: BgsFeatures | None = None, bgs_features: Optional[BgsFeatures] = None,
bgr_features: BgrFeatures | None = None, bgr_features: Optional[BgrFeatures] = None,
) -> None: ) -> None:
characteristics = [] characteristics = []
@@ -149,10 +150,10 @@ class GamingAudioService(TemplateService):
class GamingAudioServiceProxy(ProfileServiceProxy): class GamingAudioServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = GamingAudioService SERVICE_CLASS = GamingAudioService
ugg_features: CharacteristicProxy[UggFeatures] | None = None ugg_features: Optional[CharacteristicProxy[UggFeatures]] = None
ugt_features: CharacteristicProxy[UgtFeatures] | None = None ugt_features: Optional[CharacteristicProxy[UgtFeatures]] = None
bgs_features: CharacteristicProxy[BgsFeatures] | None = None bgs_features: Optional[CharacteristicProxy[BgsFeatures]] = None
bgr_features: CharacteristicProxy[BgrFeatures] | None = None bgr_features: Optional[CharacteristicProxy[BgrFeatures]] = None
def __init__(self, service_proxy: ServiceProxy) -> None: def __init__(self, service_proxy: ServiceProxy) -> None:
self.service_proxy = service_proxy self.service_proxy = service_proxy
+109 -125
View File
@@ -16,15 +16,16 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import logging import functools
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any import logging
from typing import Any, Dict, List, Optional, Set, Union
from bumble import att, gatt, gatt_adapters, gatt_client, utils from bumble import att, gatt, gatt_adapters, gatt_client
from bumble.core import InvalidArgumentError, InvalidStateError from bumble.core import InvalidArgumentError, InvalidStateError
from bumble.device import Connection, Device from bumble.device import Device, Connection
from bumble import utils
from bumble.hci import Address from bumble.hci import Address
@@ -145,7 +146,7 @@ class PresetChangedOperation:
return bytes([self.prev_index]) + bytes(self.preset_record) return bytes([self.prev_index]) + bytes(self.preset_record)
change_id: ChangeId change_id: ChangeId
additional_parameters: Generic | int additional_parameters: Union[Generic, int]
def to_bytes(self, is_last: bool) -> bytes: def to_bytes(self, is_last: bool) -> bytes:
if isinstance(self.additional_parameters, PresetChangedOperation.Generic): if isinstance(self.additional_parameters, PresetChangedOperation.Generic):
@@ -227,25 +228,23 @@ class HearingAccessService(gatt.TemplateService):
hearing_aid_preset_control_point: gatt.Characteristic[bytes] hearing_aid_preset_control_point: gatt.Characteristic[bytes]
active_preset_index_characteristic: gatt.Characteristic[bytes] active_preset_index_characteristic: gatt.Characteristic[bytes]
active_preset_index: int active_preset_index: int
active_preset_index_per_device: dict[Address, int] active_preset_index_per_device: Dict[Address, int]
device: Device device: Device
server_features: HearingAidFeatures server_features: HearingAidFeatures
preset_records: dict[int, PresetRecord] # key is the preset index preset_records: Dict[int, PresetRecord] # key is the preset index
read_presets_request_in_progress: bool read_presets_request_in_progress: bool
other_server_in_binaural_set: HearingAccessService | None = None preset_changed_operations_history_per_device: Dict[
Address, List[PresetChangedOperation]
preset_changed_operations_history_per_device: dict[
Address, list[PresetChangedOperation]
] ]
# Keep an updated list of connected client to send notification to # Keep an updated list of connected client to send notification to
currently_connected_clients: set[Connection] currently_connected_clients: Set[Connection]
def __init__( def __init__(
self, device: Device, features: HearingAidFeatures, presets: list[PresetRecord] self, device: Device, features: HearingAidFeatures, presets: List[PresetRecord]
) -> None: ) -> None:
self.active_preset_index_per_device = {} self.active_preset_index_per_device = {}
self.read_presets_request_in_progress = False self.read_presets_request_in_progress = False
@@ -271,21 +270,14 @@ class HearingAccessService(gatt.TemplateService):
def on_connection(connection: Connection) -> None: def on_connection(connection: Connection) -> None:
@connection.on(connection.EVENT_DISCONNECTION) @connection.on(connection.EVENT_DISCONNECTION)
def on_disconnection(_reason) -> None: def on_disconnection(_reason) -> None:
self.currently_connected_clients.discard(connection) self.currently_connected_clients.remove(connection)
@connection.on(connection.EVENT_CONNECTION_ATT_MTU_UPDATE)
def on_mtu_update(*_: Any) -> None:
self.on_incoming_connection(connection)
@connection.on(connection.EVENT_CONNECTION_ENCRYPTION_CHANGE)
def on_encryption_change(*_: Any) -> None:
self.on_incoming_connection(connection)
@connection.on(connection.EVENT_PAIRING) @connection.on(connection.EVENT_PAIRING)
def on_pairing(*_: Any) -> None: def on_pairing(*_: Any) -> None:
self.on_incoming_connection(connection) self.on_incoming_paired_connection(connection)
self.on_incoming_connection(connection) if connection.peer_resolvable_address:
self.on_incoming_paired_connection(connection)
self.hearing_aid_features_characteristic = gatt.Characteristic( self.hearing_aid_features_characteristic = gatt.Characteristic(
uuid=gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC, uuid=gatt.GATT_HEARING_AID_FEATURES_CHARACTERISTIC,
@@ -322,30 +314,9 @@ class HearingAccessService(gatt.TemplateService):
] ]
) )
def on_incoming_connection(self, connection: Connection): def on_incoming_paired_connection(self, connection: Connection):
'''Setup initial operations to handle a remote bonded HAP device''' '''Setup initial operations to handle a remote bonded HAP device'''
# TODO Should we filter on HAP device only ? # TODO Should we filter on HAP device only ?
if not connection.is_encrypted:
logging.debug(f'HAS: {connection.peer_address} is not encrypted')
return
if not connection.peer_resolvable_address:
logging.debug(f'HAS: {connection.peer_address} is not paired')
return
if connection.att_mtu < 49:
logging.debug(
f'HAS: {connection.peer_address} invalid MTU={connection.att_mtu}'
)
return
if connection.peer_address in self.currently_connected_clients:
logging.debug(
f'HAS: Already connected to {connection.peer_address} nothing to do'
)
return
self.currently_connected_clients.add(connection) self.currently_connected_clients.add(connection)
if ( if (
connection.peer_address connection.peer_address
@@ -362,10 +333,11 @@ class HearingAccessService(gatt.TemplateService):
# Update the active preset index if needed # Update the active preset index if needed
await self.notify_active_preset_for_connection(connection) await self.notify_active_preset_for_connection(connection)
connection.cancel_on_disconnection(on_connection_async()) utils.cancel_on_event(connection, 'disconnection', on_connection_async())
def _on_read_active_preset_index(self, connection: Connection) -> bytes: def _on_read_active_preset_index(
del connection # Unused self, __connection__: Optional[Connection]
) -> bytes:
return bytes([self.active_preset_index]) return bytes([self.active_preset_index])
# TODO this need to be triggered when device is unbonded # TODO this need to be triggered when device is unbonded
@@ -373,13 +345,18 @@ class HearingAccessService(gatt.TemplateService):
self.preset_changed_operations_history_per_device.pop(addr) self.preset_changed_operations_history_per_device.pop(addr)
async def _on_write_hearing_aid_preset_control_point( async def _on_write_hearing_aid_preset_control_point(
self, connection: Connection, value: bytes self, connection: Optional[Connection], value: bytes
): ):
assert connection
opcode = HearingAidPresetControlPointOpcode(value[0]) opcode = HearingAidPresetControlPointOpcode(value[0])
handler = getattr(self, '_on_' + opcode.name.lower()) handler = getattr(self, '_on_' + opcode.name.lower())
await handler(connection, value) await handler(connection, value)
async def _on_read_presets_request(self, connection: Connection, value: bytes): async def _on_read_presets_request(
self, connection: Optional[Connection], value: bytes
):
assert connection
if connection.att_mtu < 49: # 2.5. GATT sub-procedure requirements if connection.att_mtu < 49: # 2.5. GATT sub-procedure requirements
logging.warning(f'HAS require MTU >= 49: {connection}') logging.warning(f'HAS require MTU >= 49: {connection}')
@@ -400,19 +377,17 @@ class HearingAccessService(gatt.TemplateService):
self.preset_records[key] self.preset_records[key]
for key in sorted(self.preset_records.keys()) for key in sorted(self.preset_records.keys())
if self.preset_records[key].index >= start_index if self.preset_records[key].index >= start_index
][:num_presets] ]
del presets[num_presets:]
if len(presets) == 0: if len(presets) == 0:
raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE) raise att.ATT_Error(att.ErrorCode.OUT_OF_RANGE)
utils.AsyncRunner.spawn(self._read_preset_response(connection, presets)) utils.AsyncRunner.spawn(self._read_preset_response(connection, presets))
async def _read_preset_response( async def _read_preset_response(
self, connection: Connection, presets: list[PresetRecord] self, connection: Connection, presets: List[PresetRecord]
): ):
# If the ATT bearer is terminated before all notifications or indications are # If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Read Presets Request operation aborted and shall not either continue or restart the operation when the client reconnects.
# sent, then the server shall consider the Read Presets Request operation
# aborted and shall not either continue or restart the operation when the client
# reconnects.
try: try:
for i, preset in enumerate(presets): for i, preset in enumerate(presets):
await connection.device.indicate_subscriber( await connection.device.indicate_subscriber(
@@ -433,7 +408,7 @@ class HearingAccessService(gatt.TemplateService):
async def generic_update(self, op: PresetChangedOperation) -> None: async def generic_update(self, op: PresetChangedOperation) -> None:
'''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent''' '''Server API to perform a generic update. It is the responsibility of the caller to modify the preset_records to match the PresetChangedOperation being sent'''
await self._notify_preset_operations(op) await self._notifyPresetOperations(op)
async def delete_preset(self, index: int) -> None: async def delete_preset(self, index: int) -> None:
'''Server API to delete a preset. It should not be the current active preset''' '''Server API to delete a preset. It should not be the current active preset'''
@@ -442,14 +417,14 @@ class HearingAccessService(gatt.TemplateService):
raise InvalidStateError('Cannot delete active preset') raise InvalidStateError('Cannot delete active preset')
del self.preset_records[index] del self.preset_records[index]
await self._notify_preset_operations(PresetChangedOperationDeleted(index)) await self._notifyPresetOperations(PresetChangedOperationDeleted(index))
async def available_preset(self, index: int) -> None: async def available_preset(self, index: int) -> None:
'''Server API to make a preset available''' '''Server API to make a preset available'''
preset = self.preset_records[index] preset = self.preset_records[index]
preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE preset.properties.is_available = PresetRecord.Property.IsAvailable.IS_AVAILABLE
await self._notify_preset_operations(PresetChangedOperationAvailable(index)) await self._notifyPresetOperations(PresetChangedOperationAvailable(index))
async def unavailable_preset(self, index: int) -> None: async def unavailable_preset(self, index: int) -> None:
'''Server API to make a preset unavailable. It should not be the current active preset''' '''Server API to make a preset unavailable. It should not be the current active preset'''
@@ -461,7 +436,7 @@ class HearingAccessService(gatt.TemplateService):
preset.properties.is_available = ( preset.properties.is_available = (
PresetRecord.Property.IsAvailable.IS_UNAVAILABLE PresetRecord.Property.IsAvailable.IS_UNAVAILABLE
) )
await self._notify_preset_operations(PresetChangedOperationUnavailable(index)) await self._notifyPresetOperations(PresetChangedOperationUnavailable(index))
async def _preset_changed_operation(self, connection: Connection) -> None: async def _preset_changed_operation(self, connection: Connection) -> None:
'''Send all PresetChangedOperation saved for a given connection''' '''Send all PresetChangedOperation saved for a given connection'''
@@ -476,31 +451,30 @@ class HearingAccessService(gatt.TemplateService):
return op.additional_parameters return op.additional_parameters
op_list.sort(key=get_op_index) op_list.sort(key=get_op_index)
# If the ATT bearer is terminated before all notifications or indications are # If the ATT bearer is terminated before all notifications or indications are sent, then the server shall consider the Preset Changed operation aborted and shall continue the operation when the client reconnects.
# sent, then the server shall consider the Preset Changed operation aborted and while len(op_list) > 0:
# shall continue the operation when the client reconnects.
while op_list:
try: try:
await connection.device.indicate_subscriber( await connection.device.indicate_subscriber(
connection, connection,
self.hearing_aid_preset_control_point, self.hearing_aid_preset_control_point,
value=op_list[0].to_bytes(len(op_list) == 1), value=op_list[0].to_bytes(len(op_list) == 1),
force=True, # TODO GATT notification subscription should be persistent
) )
# Remove item once sent, and keep the non sent item in the list # Remove item once sent, and keep the non sent item in the list
op_list.pop(0) op_list.pop(0)
except TimeoutError: except TimeoutError:
break break
async def _notify_preset_operations(self, op: PresetChangedOperation) -> None: async def _notifyPresetOperations(self, op: PresetChangedOperation) -> None:
for history_list in self.preset_changed_operations_history_per_device.values(): for historyList in self.preset_changed_operations_history_per_device.values():
history_list.append(op) historyList.append(op)
for connection in self.currently_connected_clients: for connection in self.currently_connected_clients:
await self._preset_changed_operation(connection) await self._preset_changed_operation(connection)
async def _on_write_preset_name(self, connection: Connection, value: bytes): async def _on_write_preset_name(
del connection # Unused self, connection: Optional[Connection], value: bytes
):
assert connection
if self.read_presets_request_in_progress: if self.read_presets_request_in_progress:
raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS) raise att.ATT_Error(att.ErrorCode.PROCEDURE_ALREADY_IN_PROGRESS)
@@ -548,7 +522,10 @@ class HearingAccessService(gatt.TemplateService):
for connection in self.currently_connected_clients: for connection in self.currently_connected_clients:
await self.notify_active_preset_for_connection(connection) await self.notify_active_preset_for_connection(connection)
async def set_active_preset(self, value: bytes) -> None: async def set_active_preset(
self, connection: Optional[Connection], value: bytes
) -> None:
assert connection
index = value[1] index = value[1]
preset = self.preset_records.get(index, None) preset = self.preset_records.get(index, None)
if ( if (
@@ -565,85 +542,86 @@ class HearingAccessService(gatt.TemplateService):
self.active_preset_index = index self.active_preset_index = index
await self.notify_active_preset() await self.notify_active_preset()
async def _on_set_active_preset(self, connection: Connection, value: bytes): async def _on_set_active_preset(
del connection # Unused self, connection: Optional[Connection], value: bytes
await self.set_active_preset(value) ):
await self.set_active_preset(connection, value)
async def set_next_or_previous_preset(self, is_previous: bool) -> None: async def set_next_or_previous_preset(
self, connection: Optional[Connection], is_previous
):
'''Set the next or the previous preset as active''' '''Set the next or the previous preset as active'''
assert connection
if self.active_preset_index == 0x00: if self.active_preset_index == 0x00:
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE) raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
presets = sorted( first_preset: Optional[PresetRecord] = None # To loop to first preset
[ next_preset: Optional[PresetRecord] = None
record for index, record in sorted(self.preset_records.items(), reverse=is_previous):
for record in self.preset_records.values() if not record.is_available():
if record.is_available() continue
], if first_preset == None:
key=lambda record: record.index, first_preset = record
) if is_previous:
current_preset = self.preset_records[self.active_preset_index] if index >= self.active_preset_index:
current_preset_pos = presets.index(current_preset) continue
if is_previous: elif index <= self.active_preset_index:
new_preset = presets[(current_preset_pos - 1) % len(presets)] continue
else: next_preset = record
new_preset = presets[(current_preset_pos + 1) % len(presets)] break
if current_preset == new_preset: # If no other preset are available if not first_preset: # If no other preset are available
raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE) raise att.ATT_Error(ErrorCode.PRESET_OPERATION_NOT_POSSIBLE)
self.active_preset_index = new_preset.index if next_preset:
self.active_preset_index = next_preset.index
else:
self.active_preset_index = first_preset.index
await self.notify_active_preset() await self.notify_active_preset()
async def _on_set_next_preset(self, connection: Connection, value: bytes) -> None: async def _on_set_next_preset(
del connection, value # Unused. self, connection: Optional[Connection], __value__: bytes
await self.set_next_or_previous_preset(False) ) -> None:
await self.set_next_or_previous_preset(connection, False)
async def _on_set_previous_preset( async def _on_set_previous_preset(
self, connection: Connection, value: bytes self, connection: Optional[Connection], __value__: bytes
) -> None: ) -> None:
del connection, value # Unused. await self.set_next_or_previous_preset(connection, True)
await self.set_next_or_previous_preset(True)
async def _on_set_active_preset_synchronized_locally( async def _on_set_active_preset_synchronized_locally(
self, connection: Connection, value: bytes self, connection: Optional[Connection], value: bytes
): ):
del connection # Unused.
if ( if (
self.server_features.preset_synchronization_support self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
): ):
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED) raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
await self.set_active_preset(value) await self.set_active_preset(connection, value)
if self.other_server_in_binaural_set: # TODO (low priority) inform other server of the change
await self.other_server_in_binaural_set.set_active_preset(value)
async def _on_set_next_preset_synchronized_locally( async def _on_set_next_preset_synchronized_locally(
self, connection: Connection, value: bytes self, connection: Optional[Connection], __value__: bytes
): ):
del connection, value # Unused.
if ( if (
self.server_features.preset_synchronization_support self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
): ):
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED) raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
await self.set_next_or_previous_preset(False) await self.set_next_or_previous_preset(connection, False)
if self.other_server_in_binaural_set: # TODO (low priority) inform other server of the change
await self.other_server_in_binaural_set.set_next_or_previous_preset(False)
async def _on_set_previous_preset_synchronized_locally( async def _on_set_previous_preset_synchronized_locally(
self, connection: Connection, value: bytes self, connection: Optional[Connection], __value__: bytes
): ):
del connection, value # Unused.
if ( if (
self.server_features.preset_synchronization_support self.server_features.preset_synchronization_support
== PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_NOT_SUPPORTED == PresetSynchronizationSupport.PRESET_SYNCHRONIZATION_IS_SUPPORTED
): ):
raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED) raise att.ATT_Error(ErrorCode.PRESET_SYNCHRONIZATION_NOT_SUPPORTED)
await self.set_next_or_previous_preset(True) await self.set_next_or_previous_preset(connection, True)
if self.other_server_in_binaural_set: # TODO (low priority) inform other server of the change
await self.other_server_in_binaural_set.set_next_or_previous_preset(True)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -653,13 +631,11 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = HearingAccessService SERVICE_CLASS = HearingAccessService
hearing_aid_preset_control_point: gatt_client.CharacteristicProxy hearing_aid_preset_control_point: gatt_client.CharacteristicProxy
preset_control_point_indications: asyncio.Queue[bytes] preset_control_point_indications: asyncio.Queue
active_preset_index_notification: asyncio.Queue[bytes] active_preset_index_notification: asyncio.Queue
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None: def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy self.service_proxy = service_proxy
self.preset_control_point_indications = asyncio.Queue()
self.active_preset_index_notification = asyncio.Queue()
self.server_features = gatt_adapters.PackedCharacteristicProxyAdapter( self.server_features = gatt_adapters.PackedCharacteristicProxyAdapter(
service_proxy.get_characteristics_by_uuid( service_proxy.get_characteristics_by_uuid(
@@ -681,12 +657,20 @@ class HearingAccessServiceProxy(gatt_client.ProfileServiceProxy):
'B', 'B',
) )
async def setup_subscription(self) -> None: async def setup_subscription(self):
self.preset_control_point_indications = asyncio.Queue()
self.active_preset_index_notification = asyncio.Queue()
def on_active_preset_index_notification(data: bytes):
self.active_preset_index_notification.put_nowait(data)
def on_preset_control_point_indication(data: bytes):
self.preset_control_point_indications.put_nowait(data)
await self.hearing_aid_preset_control_point.subscribe( await self.hearing_aid_preset_control_point.subscribe(
self.preset_control_point_indications.put_nowait, functools.partial(on_preset_control_point_indication), prefer_notify=False
prefer_notify=False,
) )
await self.active_preset_index.subscribe( await self.active_preset_index.subscribe(
self.active_preset_index_notification.put_nowait functools.partial(on_active_preset_index_notification)
) )
+121 -130
View File
@@ -17,31 +17,41 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
from enum import IntEnum
import dataclasses
import enum
import struct import struct
from collections.abc import Callable, Sequence from typing import Optional
from typing import Any
from typing_extensions import Self from bumble import core
from bumble.att import ATT_Error
from bumble import att, core, device, gatt, gatt_adapters, gatt_client, utils from bumble.gatt import (
GATT_HEART_RATE_SERVICE,
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
TemplateService,
Characteristic,
CharacteristicValue,
)
from bumble.gatt_adapters import (
DelegatedCharacteristicAdapter,
PackedCharacteristicAdapter,
SerializableCharacteristicAdapter,
)
from bumble.gatt_client import CharacteristicProxy, ProfileServiceProxy
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HeartRateService(gatt.TemplateService): class HeartRateService(TemplateService):
UUID = gatt.GATT_HEART_RATE_SERVICE UUID = GATT_HEART_RATE_SERVICE
HEART_RATE_CONTROL_POINT_FORMAT = 'B' HEART_RATE_CONTROL_POINT_FORMAT = 'B'
CONTROL_POINT_NOT_SUPPORTED = 0x80 CONTROL_POINT_NOT_SUPPORTED = 0x80
RESET_ENERGY_EXPENDED = 0x01 RESET_ENERGY_EXPENDED = 0x01
heart_rate_measurement_characteristic: gatt.Characteristic[HeartRateMeasurement] heart_rate_measurement_characteristic: Characteristic[HeartRateMeasurement]
body_sensor_location_characteristic: gatt.Characteristic[BodySensorLocation] body_sensor_location_characteristic: Characteristic[BodySensorLocation]
heart_rate_control_point_characteristic: gatt.Characteristic[int] heart_rate_control_point_characteristic: Characteristic[int]
class BodySensorLocation(utils.OpenIntEnum): class BodySensorLocation(IntEnum):
OTHER = 0 OTHER = 0
CHEST = 1 CHEST = 1
WRIST = 2 WRIST = 2
@@ -50,90 +60,82 @@ class HeartRateService(gatt.TemplateService):
EAR_LOBE = 5 EAR_LOBE = 5
FOOT = 6 FOOT = 6
@dataclasses.dataclass
class HeartRateMeasurement: class HeartRateMeasurement:
heart_rate: int def __init__(
sensor_contact_detected: bool | None = None self,
energy_expended: int | None = None heart_rate,
rr_intervals: Sequence[float] | None = None sensor_contact_detected=None,
energy_expended=None,
class Flag(enum.IntFlag): rr_intervals=None,
INT16_HEART_RATE = 1 << 0 ):
SENSOR_CONTACT_DETECTED = 1 << 1 if heart_rate < 0 or heart_rate > 0xFFFF:
SENSOR_CONTACT_SUPPORTED = 1 << 2
ENERGY_EXPENDED_STATUS = 1 << 3
RR_INTERVAL = 1 << 4
def __post_init__(self) -> None:
if self.heart_rate < 0 or self.heart_rate > 0xFFFF:
raise core.InvalidArgumentError('heart_rate out of range') raise core.InvalidArgumentError('heart_rate out of range')
if self.energy_expended is not None and ( if energy_expended is not None and (
self.energy_expended < 0 or self.energy_expended > 0xFFFF energy_expended < 0 or energy_expended > 0xFFFF
): ):
raise core.InvalidArgumentError('energy_expended out of range') raise core.InvalidArgumentError('energy_expended out of range')
if self.rr_intervals: if rr_intervals:
for rr_interval in self.rr_intervals: for rr_interval in rr_intervals:
if rr_interval < 0 or rr_interval * 1024 > 0xFFFF: if rr_interval < 0 or rr_interval * 1024 > 0xFFFF:
raise core.InvalidArgumentError('rr_intervals out of range') raise core.InvalidArgumentError('rr_intervals out of range')
self.heart_rate = heart_rate
self.sensor_contact_detected = sensor_contact_detected
self.energy_expended = energy_expended
self.rr_intervals = rr_intervals
@classmethod @classmethod
def from_bytes(cls, data: bytes) -> Self: def from_bytes(cls, data):
flags = data[0] flags = data[0]
offset = 1 offset = 1
if flags & cls.Flag.INT16_HEART_RATE: if flags & 1:
heart_rate = struct.unpack_from('<H', data, offset)[0] hr = struct.unpack_from('<H', data, offset)[0]
offset += 2 offset += 2
else: else:
heart_rate = struct.unpack_from('B', data, offset)[0] hr = struct.unpack_from('B', data, offset)[0]
offset += 1 offset += 1
if flags & cls.Flag.SENSOR_CONTACT_SUPPORTED: if flags & (1 << 2):
sensor_contact_detected = flags & cls.Flag.SENSOR_CONTACT_DETECTED != 0 sensor_contact_detected = flags & (1 << 1) != 0
else: else:
sensor_contact_detected = None sensor_contact_detected = None
if flags & cls.Flag.ENERGY_EXPENDED_STATUS: if flags & (1 << 3):
energy_expended = struct.unpack_from('<H', data, offset)[0] energy_expended = struct.unpack_from('<H', data, offset)[0]
offset += 2 offset += 2
else: else:
energy_expended = None energy_expended = None
rr_intervals: Sequence[float] | None = None if flags & (1 << 4):
if flags & cls.Flag.RR_INTERVAL:
rr_intervals = tuple( rr_intervals = tuple(
struct.unpack_from('<H', data, i)[0] / 1024 struct.unpack_from('<H', data, offset + i * 2)[0] / 1024
for i in range(offset, len(data), 2) for i in range((len(data) - offset) // 2)
) )
else:
rr_intervals = ()
return cls( return cls(hr, sensor_contact_detected, energy_expended, rr_intervals)
heart_rate=heart_rate,
sensor_contact_detected=sensor_contact_detected,
energy_expended=energy_expended,
rr_intervals=rr_intervals,
)
def __bytes__(self) -> bytes: def __bytes__(self):
flags = 0
if self.heart_rate < 256: if self.heart_rate < 256:
flags = 0
data = struct.pack('B', self.heart_rate) data = struct.pack('B', self.heart_rate)
else: else:
flags |= self.Flag.INT16_HEART_RATE flags = 1
data = struct.pack('<H', self.heart_rate) data = struct.pack('<H', self.heart_rate)
if self.sensor_contact_detected is not None: if self.sensor_contact_detected is not None:
flags |= self.Flag.SENSOR_CONTACT_SUPPORTED flags |= ((1 if self.sensor_contact_detected else 0) << 1) | (1 << 2)
if self.sensor_contact_detected:
flags |= self.Flag.SENSOR_CONTACT_DETECTED
if self.energy_expended is not None: if self.energy_expended is not None:
flags |= self.Flag.ENERGY_EXPENDED_STATUS flags |= 1 << 3
data += struct.pack('<H', self.energy_expended) data += struct.pack('<H', self.energy_expended)
if self.rr_intervals is not None: if self.rr_intervals:
flags |= self.Flag.RR_INTERVAL flags |= 1 << 4
data += b''.join( data += b''.join(
[ [
struct.pack('<H', int(rr_interval * 1024)) struct.pack('<H', int(rr_interval * 1024))
@@ -143,67 +145,57 @@ class HeartRateService(gatt.TemplateService):
return bytes([flags]) + data return bytes([flags]) + data
def __str__(self):
return (
f'HeartRateMeasurement(heart_rate={self.heart_rate},'
f' sensor_contact_detected={self.sensor_contact_detected},'
f' energy_expended={self.energy_expended},'
f' rr_intervals={self.rr_intervals})'
)
def __init__( def __init__(
self, self,
read_heart_rate_measurement: Callable[ read_heart_rate_measurement,
[device.Connection], HeartRateMeasurement body_sensor_location=None,
], reset_energy_expended=None,
body_sensor_location: HeartRateService.BodySensorLocation | None = None,
reset_energy_expended: Callable[[device.Connection], Any] | None = None,
): ):
self.heart_rate_measurement_characteristic = ( self.heart_rate_measurement_characteristic = SerializableCharacteristicAdapter(
gatt_adapters.SerializableCharacteristicAdapter( Characteristic(
gatt.Characteristic( GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
uuid=gatt.GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC, Characteristic.Properties.NOTIFY,
properties=gatt.Characteristic.Properties.NOTIFY, 0,
permissions=gatt.Characteristic.Permissions(0), CharacteristicValue(read=read_heart_rate_measurement),
value=gatt.CharacteristicValue(read=read_heart_rate_measurement), ),
), HeartRateService.HeartRateMeasurement,
HeartRateService.HeartRateMeasurement,
)
) )
characteristics: list[gatt.Characteristic] = [ characteristics = [self.heart_rate_measurement_characteristic]
self.heart_rate_measurement_characteristic
]
if body_sensor_location is not None: if body_sensor_location is not None:
self.body_sensor_location_characteristic = ( self.body_sensor_location_characteristic = Characteristic(
gatt_adapters.EnumCharacteristicAdapter( GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC,
gatt.Characteristic( Characteristic.Properties.READ,
uuid=gatt.GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC, Characteristic.READABLE,
properties=gatt.Characteristic.Properties.READ, bytes([int(body_sensor_location)]),
permissions=gatt.Characteristic.READABLE,
value=body_sensor_location,
),
cls=self.BodySensorLocation,
length=1,
)
) )
characteristics.append(self.body_sensor_location_characteristic) characteristics.append(self.body_sensor_location_characteristic)
if reset_energy_expended: if reset_energy_expended:
def write_heart_rate_control_point_value( def write_heart_rate_control_point_value(connection, value):
connection: device.Connection, value: bytes
) -> None:
if value == self.RESET_ENERGY_EXPENDED: if value == self.RESET_ENERGY_EXPENDED:
if reset_energy_expended is not None: if reset_energy_expended is not None:
reset_energy_expended(connection) reset_energy_expended(connection)
else: else:
raise att.ATT_Error(self.CONTROL_POINT_NOT_SUPPORTED) raise ATT_Error(self.CONTROL_POINT_NOT_SUPPORTED)
self.heart_rate_control_point_characteristic = ( self.heart_rate_control_point_characteristic = PackedCharacteristicAdapter(
gatt_adapters.PackedCharacteristicAdapter( Characteristic(
gatt.Characteristic( GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC,
uuid=gatt.GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC, Characteristic.Properties.WRITE,
properties=gatt.Characteristic.Properties.WRITE, Characteristic.WRITEABLE,
permissions=gatt.Characteristic.WRITEABLE, CharacteristicValue(write=write_heart_rate_control_point_value),
value=gatt.CharacteristicValue( ),
write=write_heart_rate_control_point_value pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
),
),
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
) )
characteristics.append(self.heart_rate_control_point_characteristic) characteristics.append(self.heart_rate_control_point_characteristic)
@@ -211,51 +203,50 @@ class HeartRateService(gatt.TemplateService):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HeartRateServiceProxy(gatt_client.ProfileServiceProxy): class HeartRateServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = HeartRateService SERVICE_CLASS = HeartRateService
heart_rate_measurement: gatt_client.CharacteristicProxy[ heart_rate_measurement: Optional[
HeartRateService.HeartRateMeasurement CharacteristicProxy[HeartRateService.HeartRateMeasurement]
] ]
body_sensor_location: ( body_sensor_location: Optional[
gatt_client.CharacteristicProxy[HeartRateService.BodySensorLocation] | None CharacteristicProxy[HeartRateService.BodySensorLocation]
) ]
heart_rate_control_point: gatt_client.CharacteristicProxy[int] | None heart_rate_control_point: Optional[CharacteristicProxy[int]]
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None: def __init__(self, service_proxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
self.heart_rate_measurement = ( if characteristics := service_proxy.get_characteristics_by_uuid(
gatt_adapters.SerializableCharacteristicProxyAdapter( GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
service_proxy.get_required_characteristic_by_uuid( ):
gatt.GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC self.heart_rate_measurement = SerializableCharacteristicAdapter(
), characteristics[0], HeartRateService.HeartRateMeasurement
HeartRateService.HeartRateMeasurement,
) )
) else:
self.heart_rate_measurement = None
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC GATT_BODY_SENSOR_LOCATION_CHARACTERISTIC
): ):
self.body_sensor_location = gatt_adapters.EnumCharacteristicProxyAdapter( self.body_sensor_location = DelegatedCharacteristicAdapter(
characteristics[0], cls=HeartRateService.BodySensorLocation, length=1 characteristics[0],
decode=lambda value: HeartRateService.BodySensorLocation(value[0]),
) )
else: else:
self.body_sensor_location = None self.body_sensor_location = None
if characteristics := service_proxy.get_characteristics_by_uuid( if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC
): ):
self.heart_rate_control_point = ( self.heart_rate_control_point = PackedCharacteristicAdapter(
gatt_adapters.PackedCharacteristicProxyAdapter( characteristics[0],
characteristics[0], pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
pack_format=HeartRateService.HEART_RATE_CONTROL_POINT_FORMAT,
)
) )
else: else:
self.heart_rate_control_point = None self.heart_rate_control_point = None
async def reset_energy_expended(self) -> None: async def reset_energy_expended(self):
if self.heart_rate_control_point is not None: if self.heart_rate_control_point is not None:
return await self.heart_rate_control_point.write_value( return await self.heart_rate_control_point.write_value(
HeartRateService.RESET_ENERGY_EXPENDED HeartRateService.RESET_ENERGY_EXPENDED

Some files were not shown because too many files have changed in this diff Show More