mirror of
https://github.com/google/bumble.git
synced 2026-05-07 03:48:01 +00:00
Compare commits
129 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3f6f036270 | ||
|
|
859bb0609f | ||
|
|
5f2d24570e | ||
|
|
dbf94c8f3e | ||
|
|
b6adc29365 | ||
|
|
5caa7bfa90 | ||
|
|
f39d706fa0 | ||
|
|
c02c1f33d2 | ||
|
|
33435c2980 | ||
|
|
c08449d9db | ||
|
|
3c8718bb5b | ||
|
|
26e87f09fe | ||
|
|
7f5e0d190e | ||
|
|
efae307b3d | ||
|
|
26d38a855c | ||
|
|
7360a887d9 | ||
|
|
9756572c93 | ||
|
|
d6100755b1 | ||
|
|
a66eef6630 | ||
|
|
ae23ef7b9b | ||
|
|
f368b5e518 | ||
|
|
5293d32dc6 | ||
|
|
6d9a0bf4e1 | ||
|
|
3c7b5df7c5 | ||
|
|
70141c0439 | ||
|
|
dedc0aca54 | ||
|
|
7c019b574f | ||
|
|
9b485fd943 | ||
|
|
fdee8269ec | ||
|
|
0767f2d4ae | ||
|
|
c4a0846727 | ||
|
|
83ac70e426 | ||
|
|
01cce3525f | ||
|
|
b9d35aea47 | ||
|
|
079cf6b896 | ||
|
|
180655088c | ||
|
|
a1bade6f20 | ||
|
|
5d80e7fd80 | ||
|
|
2198692961 | ||
|
|
55d3fd90f5 | ||
|
|
afee659ca6 | ||
|
|
6fe7931d7d | ||
|
|
9023407ee4 | ||
|
|
54d961bbe5 | ||
|
|
cbd46adbcf | ||
|
|
745e107849 | ||
|
|
af466c2970 | ||
|
|
931e2de854 | ||
|
|
55eb7eb237 | ||
|
|
bade4502f9 | ||
|
|
9f952f202f | ||
|
|
1eb9d8d055 | ||
|
|
5a477eb391 | ||
|
|
86cda8771d | ||
|
|
c1ea0ddd35 | ||
|
|
f567711a6c | ||
|
|
509df4c676 | ||
|
|
b375ed07b4 | ||
|
|
69d62d3dd1 | ||
|
|
fe3fa3d505 | ||
|
|
27fcd43224 | ||
|
|
c3b2bb19d5 | ||
|
|
34287177b9 | ||
|
|
d238dd4059 | ||
|
|
865f3a249f | ||
|
|
7324d322fe | ||
|
|
af148b476d | ||
|
|
80d60aaf15 | ||
|
|
c80f89d20f | ||
|
|
a27f55a588 | ||
|
|
62e4670a39 | ||
|
|
99695bb264 | ||
|
|
eb54898106 | ||
|
|
4f5ee204d2 | ||
|
|
2552e21db1 | ||
|
|
6168f87e2f | ||
|
|
ca7d2ca4df | ||
|
|
60723323e9 | ||
|
|
3ce7b9255b | ||
|
|
97fcfc2fa0 | ||
|
|
19674e3758 | ||
|
|
1130e1db8f | ||
|
|
37c7f3a58a | ||
|
|
0a12b2bf2e | ||
|
|
d014acbe63 | ||
|
|
07f9997a49 | ||
|
|
b9f91f695a | ||
|
|
082d55af10 | ||
|
|
4c3fd5688d | ||
|
|
9d3d5495ce | ||
|
|
b3869f267c | ||
|
|
8715333706 | ||
|
|
b57096abe2 | ||
|
|
48685c8587 | ||
|
|
100bea6b41 | ||
|
|
63819bf9dd | ||
|
|
6e55390930 | ||
|
|
e3fdab4175 | ||
|
|
bbcd14dbf0 | ||
|
|
01dc0d574b | ||
|
|
5e959d638e | ||
|
|
8d908288c8 | ||
|
|
c88b32a406 | ||
|
|
5a72eefb89 | ||
|
|
430046944b | ||
|
|
21d23320eb | ||
|
|
d0990ee04d | ||
|
|
2d88e853e8 | ||
|
|
a060a70fba | ||
|
|
a06394ad4a | ||
|
|
a1414c2b5b | ||
|
|
b2864dac2d | ||
|
|
b78f895143 | ||
|
|
c4e9726828 | ||
|
|
d4b8e8348a | ||
|
|
19debaa52e | ||
|
|
73fe564321 | ||
|
|
a00abd65b3 | ||
|
|
f169ceaebb | ||
|
|
528af0d338 | ||
|
|
4b25eed869 | ||
|
|
fcd6bd7136 | ||
|
|
2bed50b353 | ||
|
|
1fe3778a74 | ||
|
|
5e31bcf23d | ||
|
|
fe429cb2eb | ||
|
|
c91695c23a | ||
|
|
55f99e6887 | ||
|
|
b190069f48 |
4
.github/workflows/code-check.yml
vendored
4
.github/workflows/code-check.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13.0"]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
@@ -33,7 +33,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install ".[build,test,development,pandora]"
|
||||
python -m pip install ".[build,test,development]"
|
||||
- name: Check
|
||||
run: |
|
||||
invoke project.pre-commit
|
||||
|
||||
6
.github/workflows/python-avatar.yml
vendored
6
.github/workflows/python-avatar.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
- name: Install
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install .[avatar,pandora]
|
||||
python -m pip install .[avatar]
|
||||
- name: Rootcanal
|
||||
run: nohup python -m rootcanal > rootcanal.log &
|
||||
- name: Test
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
run: cat rootcanal.log
|
||||
- name: Upload Mobly logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v3
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: mobly-logs
|
||||
name: mobly-logs-${{ strategy.job-index }}
|
||||
path: /tmp/logs/mobly/bumble.bumbles/
|
||||
|
||||
4
.github/workflows/python-build-test.yml
vendored
4
.github/workflows/python-build-test.yml
vendored
@@ -16,7 +16,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
fail-fast: false
|
||||
|
||||
steps:
|
||||
@@ -46,7 +46,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ]
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
|
||||
rust-version: [ "1.76.0", "stable" ]
|
||||
fail-fast: false
|
||||
steps:
|
||||
|
||||
7
.vscode/settings.json
vendored
7
.vscode/settings.json
vendored
@@ -14,9 +14,12 @@
|
||||
"ASHA",
|
||||
"asyncio",
|
||||
"ATRAC",
|
||||
"auracast",
|
||||
"avctp",
|
||||
"avdtp",
|
||||
"avrcp",
|
||||
"biginfo",
|
||||
"bigs",
|
||||
"bitpool",
|
||||
"bitstruct",
|
||||
"BSCP",
|
||||
@@ -36,6 +39,7 @@
|
||||
"deregistration",
|
||||
"dhkey",
|
||||
"diversifier",
|
||||
"ediv",
|
||||
"endianness",
|
||||
"ESCO",
|
||||
"Fitbit",
|
||||
@@ -47,6 +51,7 @@
|
||||
"libc",
|
||||
"liblc",
|
||||
"libusb",
|
||||
"maxs",
|
||||
"MITM",
|
||||
"MSBC",
|
||||
"NDIS",
|
||||
@@ -54,8 +59,10 @@
|
||||
"NONBLOCK",
|
||||
"NONCONN",
|
||||
"OXIMETER",
|
||||
"PDUS",
|
||||
"popleft",
|
||||
"PRAND",
|
||||
"prefs",
|
||||
"protobuf",
|
||||
"psms",
|
||||
"pyee",
|
||||
|
||||
726
apps/auracast.py
726
apps/auracast.py
@@ -1,4 +1,4 @@
|
||||
# Copyright 2024 Google LLC
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -16,29 +16,50 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import asyncio.subprocess
|
||||
import collections
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple
|
||||
import struct
|
||||
from typing import (
|
||||
cast,
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Coroutine,
|
||||
Deque,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import click
|
||||
import pyee
|
||||
|
||||
try:
|
||||
import lc3 # type: ignore # pylint: disable=E0401
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Try `python -m pip install \"git+https://github.com/google/liblc3.git\"`."
|
||||
) from e
|
||||
|
||||
from bumble.audio import io as audio_io
|
||||
from bumble.colors import color
|
||||
import bumble.company_ids
|
||||
import bumble.core
|
||||
from bumble import company_ids
|
||||
from bumble import core
|
||||
from bumble import gatt
|
||||
from bumble import hci
|
||||
from bumble.profiles import bap
|
||||
from bumble.profiles import le_audio
|
||||
from bumble.profiles import pbp
|
||||
from bumble.profiles import bass
|
||||
import bumble.device
|
||||
import bumble.gatt
|
||||
import bumble.hci
|
||||
import bumble.profiles.bap
|
||||
import bumble.profiles.bass
|
||||
import bumble.profiles.pbp
|
||||
import bumble.transport
|
||||
import bumble.utils
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -49,9 +70,34 @@ logger = logging.getLogger(__name__)
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast'
|
||||
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address('F0:F1:F2:F3:F4:F5')
|
||||
AURACAST_DEFAULT_DEVICE_ADDRESS = hci.Address('F0:F1:F2:F3:F4:F5')
|
||||
AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0
|
||||
AURACAST_DEFAULT_ATT_MTU = 256
|
||||
AURACAST_DEFAULT_FRAME_DURATION = 10000
|
||||
AURACAST_DEFAULT_SAMPLE_RATE = 48000
|
||||
AURACAST_DEFAULT_TRANSMIT_BITRATE = 80000
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# -----------------------------------------------------------------------------
|
||||
def codec_config_string(
|
||||
codec_config: bap.CodecSpecificConfiguration, indent: str
|
||||
) -> str:
|
||||
lines = []
|
||||
if codec_config.sampling_frequency is not None:
|
||||
lines.append(f'Sampling Frequency: {codec_config.sampling_frequency.hz} hz')
|
||||
if codec_config.frame_duration is not None:
|
||||
lines.append(f'Frame Duration: {codec_config.frame_duration.us} µs')
|
||||
if codec_config.octets_per_codec_frame is not None:
|
||||
lines.append(f'Frame Size: {codec_config.octets_per_codec_frame} bytes')
|
||||
if codec_config.codec_frames_per_sdu is not None:
|
||||
lines.append(f'Frames Per SDU: {codec_config.codec_frames_per_sdu}')
|
||||
if codec_config.audio_channel_allocation is not None:
|
||||
lines.append(
|
||||
f'Audio Location: {codec_config.audio_channel_allocation.name}'
|
||||
)
|
||||
return '\n'.join(indent + line for line in lines)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -60,19 +106,14 @@ AURACAST_DEFAULT_ATT_MTU = 256
|
||||
class BroadcastScanner(pyee.EventEmitter):
|
||||
@dataclasses.dataclass
|
||||
class Broadcast(pyee.EventEmitter):
|
||||
name: str
|
||||
name: str | None
|
||||
sync: bumble.device.PeriodicAdvertisingSync
|
||||
broadcast_id: int
|
||||
rssi: int = 0
|
||||
public_broadcast_announcement: Optional[
|
||||
bumble.profiles.pbp.PublicBroadcastAnnouncement
|
||||
] = None
|
||||
broadcast_audio_announcement: Optional[
|
||||
bumble.profiles.bap.BroadcastAudioAnnouncement
|
||||
] = None
|
||||
basic_audio_announcement: Optional[
|
||||
bumble.profiles.bap.BasicAudioAnnouncement
|
||||
] = None
|
||||
appearance: Optional[bumble.core.Appearance] = None
|
||||
public_broadcast_announcement: Optional[pbp.PublicBroadcastAnnouncement] = None
|
||||
broadcast_audio_announcement: Optional[bap.BroadcastAudioAnnouncement] = None
|
||||
basic_audio_announcement: Optional[bap.BasicAudioAnnouncement] = None
|
||||
appearance: Optional[core.Appearance] = None
|
||||
biginfo: Optional[bumble.device.BIGInfoAdvertisement] = None
|
||||
manufacturer_data: Optional[Tuple[str, bytes]] = None
|
||||
|
||||
@@ -86,42 +127,36 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
def update(self, advertisement: bumble.device.Advertisement) -> None:
|
||||
self.rssi = advertisement.rssi
|
||||
for service_data in advertisement.data.get_all(
|
||||
bumble.core.AdvertisingData.SERVICE_DATA
|
||||
core.AdvertisingData.SERVICE_DATA
|
||||
):
|
||||
assert isinstance(service_data, tuple)
|
||||
service_uuid, data = service_data
|
||||
assert isinstance(data, bytes)
|
||||
|
||||
if (
|
||||
service_uuid
|
||||
== bumble.gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE
|
||||
):
|
||||
if service_uuid == gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE:
|
||||
self.public_broadcast_announcement = (
|
||||
bumble.profiles.pbp.PublicBroadcastAnnouncement.from_bytes(data)
|
||||
pbp.PublicBroadcastAnnouncement.from_bytes(data)
|
||||
)
|
||||
continue
|
||||
|
||||
if (
|
||||
service_uuid
|
||||
== bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
|
||||
):
|
||||
if service_uuid == gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE:
|
||||
self.broadcast_audio_announcement = (
|
||||
bumble.profiles.bap.BroadcastAudioAnnouncement.from_bytes(data)
|
||||
bap.BroadcastAudioAnnouncement.from_bytes(data)
|
||||
)
|
||||
continue
|
||||
|
||||
self.appearance = advertisement.data.get( # type: ignore[assignment]
|
||||
bumble.core.AdvertisingData.APPEARANCE
|
||||
core.AdvertisingData.APPEARANCE
|
||||
)
|
||||
|
||||
if manufacturer_data := advertisement.data.get(
|
||||
bumble.core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA
|
||||
core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA
|
||||
):
|
||||
assert isinstance(manufacturer_data, tuple)
|
||||
company_id = cast(int, manufacturer_data[0])
|
||||
data = cast(bytes, manufacturer_data[1])
|
||||
self.manufacturer_data = (
|
||||
bumble.company_ids.COMPANY_IDENTIFIERS.get(
|
||||
company_ids.COMPANY_IDENTIFIERS.get(
|
||||
company_id, f'0x{company_id:04X}'
|
||||
),
|
||||
data,
|
||||
@@ -135,7 +170,8 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
self.sync.advertiser_address,
|
||||
color(self.sync.state.name, 'green'),
|
||||
)
|
||||
print(f' {color("Name", "cyan")}: {self.name}')
|
||||
if self.name is not None:
|
||||
print(f' {color("Name", "cyan")}: {self.name}')
|
||||
if self.appearance:
|
||||
print(f' {color("Appearance", "cyan")}: {str(self.appearance)}')
|
||||
print(f' {color("RSSI", "cyan")}: {self.rssi}')
|
||||
@@ -156,25 +192,24 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
if self.public_broadcast_announcement:
|
||||
print(
|
||||
f' {color("Features", "cyan")}: '
|
||||
f'{self.public_broadcast_announcement.features}'
|
||||
)
|
||||
print(
|
||||
f' {color("Metadata", "cyan")}: '
|
||||
f'{self.public_broadcast_announcement.metadata}'
|
||||
f'{self.public_broadcast_announcement.features.name}'
|
||||
)
|
||||
print(f' {color("Metadata", "cyan")}:')
|
||||
print(self.public_broadcast_announcement.metadata.pretty_print(' '))
|
||||
|
||||
if self.basic_audio_announcement:
|
||||
print(color(' Audio:', 'cyan'))
|
||||
print(
|
||||
color(' Presentation Delay:', 'magenta'),
|
||||
self.basic_audio_announcement.presentation_delay,
|
||||
"µs",
|
||||
)
|
||||
for subgroup in self.basic_audio_announcement.subgroups:
|
||||
print(color(' Subgroup:', 'magenta'))
|
||||
print(color(' Codec ID:', 'yellow'))
|
||||
print(
|
||||
color(' Coding Format: ', 'green'),
|
||||
subgroup.codec_id.coding_format.name,
|
||||
subgroup.codec_id.codec_id.name,
|
||||
)
|
||||
print(
|
||||
color(' Company ID: ', 'green'),
|
||||
@@ -184,17 +219,22 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
color(' Vendor Specific Codec ID:', 'green'),
|
||||
subgroup.codec_id.vendor_specific_codec_id,
|
||||
)
|
||||
print(color(' Codec Config:', 'yellow'))
|
||||
print(
|
||||
color(' Codec Config:', 'yellow'),
|
||||
subgroup.codec_specific_configuration,
|
||||
codec_config_string(
|
||||
subgroup.codec_specific_configuration, ' '
|
||||
),
|
||||
)
|
||||
print(color(' Metadata: ', 'yellow'), subgroup.metadata)
|
||||
print(color(' Metadata: ', 'yellow'))
|
||||
print(subgroup.metadata.pretty_print(' '))
|
||||
|
||||
for bis in subgroup.bis:
|
||||
print(color(f' BIS [{bis.index}]:', 'yellow'))
|
||||
print(color(' Codec Config:', 'green'))
|
||||
print(
|
||||
color(' Codec Config:', 'green'),
|
||||
bis.codec_specific_configuration,
|
||||
codec_config_string(
|
||||
bis.codec_specific_configuration, ' '
|
||||
),
|
||||
)
|
||||
|
||||
if self.biginfo:
|
||||
@@ -231,15 +271,15 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
return
|
||||
|
||||
for service_data in advertisement.data.get_all(
|
||||
bumble.core.AdvertisingData.SERVICE_DATA
|
||||
core.AdvertisingData.SERVICE_DATA
|
||||
):
|
||||
assert isinstance(service_data, tuple)
|
||||
service_uuid, data = service_data
|
||||
assert isinstance(data, bytes)
|
||||
|
||||
if service_uuid == bumble.gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE:
|
||||
if service_uuid == gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE:
|
||||
self.basic_audio_announcement = (
|
||||
bumble.profiles.bap.BasicAudioAnnouncement.from_bytes(data)
|
||||
bap.BasicAudioAnnouncement.from_bytes(data)
|
||||
)
|
||||
break
|
||||
|
||||
@@ -261,7 +301,7 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
self.device = device
|
||||
self.filter_duplicates = filter_duplicates
|
||||
self.sync_timeout = sync_timeout
|
||||
self.broadcasts: Dict[bumble.hci.Address, BroadcastScanner.Broadcast] = {}
|
||||
self.broadcasts = dict[hci.Address, BroadcastScanner.Broadcast]()
|
||||
device.on('advertisement', self.on_advertisement)
|
||||
|
||||
async def start(self) -> None:
|
||||
@@ -274,24 +314,46 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
await self.device.stop_scanning()
|
||||
|
||||
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
|
||||
if (
|
||||
broadcast_name := advertisement.data.get(
|
||||
bumble.core.AdvertisingData.BROADCAST_NAME
|
||||
if not (
|
||||
ads := advertisement.data.get_all(
|
||||
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID
|
||||
)
|
||||
) is None:
|
||||
) or not (
|
||||
broadcast_audio_announcement := next(
|
||||
(
|
||||
ad
|
||||
for ad in ads
|
||||
if isinstance(ad, tuple)
|
||||
and ad[0] == gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
|
||||
),
|
||||
None,
|
||||
)
|
||||
):
|
||||
return
|
||||
assert isinstance(broadcast_name, str)
|
||||
|
||||
broadcast_name = advertisement.data.get(core.AdvertisingData.BROADCAST_NAME)
|
||||
assert isinstance(broadcast_name, str) or broadcast_name is None
|
||||
assert isinstance(broadcast_audio_announcement[1], bytes)
|
||||
|
||||
if broadcast := self.broadcasts.get(advertisement.address):
|
||||
broadcast.update(advertisement)
|
||||
return
|
||||
|
||||
bumble.utils.AsyncRunner.spawn(
|
||||
self.on_new_broadcast(broadcast_name, advertisement)
|
||||
self.on_new_broadcast(
|
||||
broadcast_name,
|
||||
advertisement,
|
||||
bap.BroadcastAudioAnnouncement.from_bytes(
|
||||
broadcast_audio_announcement[1]
|
||||
).broadcast_id,
|
||||
)
|
||||
)
|
||||
|
||||
async def on_new_broadcast(
|
||||
self, name: str, advertisement: bumble.device.Advertisement
|
||||
self,
|
||||
name: str | None,
|
||||
advertisement: bumble.device.Advertisement,
|
||||
broadcast_id: int,
|
||||
) -> None:
|
||||
periodic_advertising_sync = await self.device.create_periodic_advertising_sync(
|
||||
advertiser_address=advertisement.address,
|
||||
@@ -299,10 +361,7 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
sync_timeout=self.sync_timeout,
|
||||
filter_duplicates=self.filter_duplicates,
|
||||
)
|
||||
broadcast = self.Broadcast(
|
||||
name,
|
||||
periodic_advertising_sync,
|
||||
)
|
||||
broadcast = self.Broadcast(name, periodic_advertising_sync, broadcast_id)
|
||||
broadcast.update(advertisement)
|
||||
self.broadcasts[advertisement.address] = broadcast
|
||||
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
|
||||
@@ -314,10 +373,11 @@ class BroadcastScanner(pyee.EventEmitter):
|
||||
self.emit('broadcast_loss', broadcast)
|
||||
|
||||
|
||||
class PrintingBroadcastScanner:
|
||||
class PrintingBroadcastScanner(pyee.EventEmitter):
|
||||
def __init__(
|
||||
self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout)
|
||||
self.scanner.on('new_broadcast', self.on_new_broadcast)
|
||||
self.scanner.on('broadcast_loss', self.on_broadcast_loss)
|
||||
@@ -452,27 +512,29 @@ async def run_assist(
|
||||
await peer.request_mtu(mtu)
|
||||
|
||||
# Get the BASS service
|
||||
bass = await peer.discover_service_and_create_proxy(
|
||||
bumble.profiles.bass.BroadcastAudioScanServiceProxy
|
||||
bass_client = await peer.discover_service_and_create_proxy(
|
||||
bass.BroadcastAudioScanServiceProxy
|
||||
)
|
||||
|
||||
# Check that the service was found
|
||||
if not bass:
|
||||
if not bass_client:
|
||||
print(color('!!! Broadcast Audio Scan Service not found', 'red'))
|
||||
return
|
||||
|
||||
# Subscribe to and read the broadcast receive state characteristics
|
||||
for i, broadcast_receive_state in enumerate(bass.broadcast_receive_states):
|
||||
for i, broadcast_receive_state in enumerate(
|
||||
bass_client.broadcast_receive_states
|
||||
):
|
||||
try:
|
||||
await broadcast_receive_state.subscribe(
|
||||
lambda value, i=i: print(
|
||||
f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}"
|
||||
)
|
||||
)
|
||||
except bumble.core.ProtocolError as error:
|
||||
except core.ProtocolError as error:
|
||||
print(
|
||||
color(
|
||||
f'!!! Failed to subscribe to Broadcast Receive State characteristic:',
|
||||
'!!! Failed to subscribe to Broadcast Receive State characteristic',
|
||||
'red',
|
||||
),
|
||||
error,
|
||||
@@ -488,7 +550,7 @@ async def run_assist(
|
||||
|
||||
if command == 'add-source':
|
||||
# Find the requested broadcast
|
||||
await bass.remote_scan_started()
|
||||
await bass_client.remote_scan_started()
|
||||
if broadcast_name:
|
||||
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
|
||||
else:
|
||||
@@ -508,15 +570,15 @@ async def run_assist(
|
||||
|
||||
# Add the source
|
||||
print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address)
|
||||
await bass.add_source(
|
||||
await bass_client.add_source(
|
||||
broadcast.sync.advertiser_address,
|
||||
broadcast.sync.sid,
|
||||
broadcast.broadcast_audio_announcement.broadcast_id,
|
||||
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE,
|
||||
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE,
|
||||
0xFFFF,
|
||||
[
|
||||
bumble.profiles.bass.SubgroupInfo(
|
||||
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
|
||||
bass.SubgroupInfo(
|
||||
bass.SubgroupInfo.ANY_BIS,
|
||||
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
|
||||
)
|
||||
],
|
||||
@@ -526,7 +588,7 @@ async def run_assist(
|
||||
await broadcast.sync.transfer(peer.connection)
|
||||
|
||||
# Notify the sink that we're done scanning.
|
||||
await bass.remote_scan_stopped()
|
||||
await bass_client.remote_scan_stopped()
|
||||
|
||||
await peer.sustain()
|
||||
return
|
||||
@@ -537,7 +599,7 @@ async def run_assist(
|
||||
return
|
||||
|
||||
# Find the requested broadcast
|
||||
await bass.remote_scan_started()
|
||||
await bass_client.remote_scan_started()
|
||||
if broadcast_name:
|
||||
print(color('Scanning for broadcast:', 'cyan'), broadcast_name)
|
||||
else:
|
||||
@@ -560,13 +622,13 @@ async def run_assist(
|
||||
color('Modifying source:', 'blue'),
|
||||
source_id,
|
||||
)
|
||||
await bass.modify_source(
|
||||
await bass_client.modify_source(
|
||||
source_id,
|
||||
bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
|
||||
bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE,
|
||||
0xFFFF,
|
||||
[
|
||||
bumble.profiles.bass.SubgroupInfo(
|
||||
bumble.profiles.bass.SubgroupInfo.ANY_BIS,
|
||||
bass.SubgroupInfo(
|
||||
bass.SubgroupInfo.ANY_BIS,
|
||||
bytes(broadcast.basic_audio_announcement.subgroups[0].metadata),
|
||||
)
|
||||
],
|
||||
@@ -581,7 +643,7 @@ async def run_assist(
|
||||
|
||||
# Remove the source
|
||||
print(color('Removing source:', 'blue'), source_id)
|
||||
await bass.remove_source(source_id)
|
||||
await bass_client.remove_source(source_id)
|
||||
await peer.sustain()
|
||||
return
|
||||
|
||||
@@ -601,14 +663,342 @@ async def run_pair(transport: str, address: str) -> None:
|
||||
print("+++ Paired")
|
||||
|
||||
|
||||
async def run_receive(
|
||||
transport: str,
|
||||
broadcast_id: Optional[int],
|
||||
output: str,
|
||||
broadcast_code: str | None,
|
||||
sync_timeout: float,
|
||||
subgroup_index: int,
|
||||
) -> None:
|
||||
# Run a pre-flight check for the output.
|
||||
try:
|
||||
if not audio_io.check_audio_output(output):
|
||||
return
|
||||
except ValueError as error:
|
||||
print(error)
|
||||
return
|
||||
|
||||
async with create_device(transport) as device:
|
||||
if not device.supports_le_periodic_advertising:
|
||||
print(color('Periodic advertising not supported', 'red'))
|
||||
return
|
||||
|
||||
scanner = BroadcastScanner(device, False, sync_timeout)
|
||||
scan_result: asyncio.Future[BroadcastScanner.Broadcast] = (
|
||||
asyncio.get_running_loop().create_future()
|
||||
)
|
||||
|
||||
def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None:
|
||||
if scan_result.done():
|
||||
return
|
||||
if broadcast_id is None or broadcast.broadcast_id == broadcast_id:
|
||||
scan_result.set_result(broadcast)
|
||||
|
||||
scanner.on('new_broadcast', on_new_broadcast)
|
||||
await scanner.start()
|
||||
print('Start scanning...')
|
||||
broadcast = await scan_result
|
||||
print('Advertisement found:')
|
||||
broadcast.print()
|
||||
basic_audio_announcement_scanned = asyncio.Event()
|
||||
|
||||
def on_change() -> None:
|
||||
if (
|
||||
broadcast.basic_audio_announcement
|
||||
and not basic_audio_announcement_scanned.is_set()
|
||||
):
|
||||
basic_audio_announcement_scanned.set()
|
||||
|
||||
broadcast.on('change', on_change)
|
||||
if not broadcast.basic_audio_announcement:
|
||||
print('Wait for Basic Audio Announcement...')
|
||||
await basic_audio_announcement_scanned.wait()
|
||||
print('Basic Audio Announcement found')
|
||||
broadcast.print()
|
||||
print('Stop scanning')
|
||||
await scanner.stop()
|
||||
print('Start sync to BIG')
|
||||
|
||||
assert broadcast.basic_audio_announcement
|
||||
subgroup = broadcast.basic_audio_announcement.subgroups[subgroup_index]
|
||||
configuration = subgroup.codec_specific_configuration
|
||||
assert configuration
|
||||
assert (sampling_frequency := configuration.sampling_frequency)
|
||||
assert (frame_duration := configuration.frame_duration)
|
||||
|
||||
big_sync = await device.create_big_sync(
|
||||
broadcast.sync,
|
||||
bumble.device.BigSyncParameters(
|
||||
big_sync_timeout=0x4000,
|
||||
bis=[bis.index for bis in subgroup.bis],
|
||||
broadcast_code=(
|
||||
bytes.fromhex(broadcast_code) if broadcast_code else None
|
||||
),
|
||||
),
|
||||
)
|
||||
num_bis = len(big_sync.bis_links)
|
||||
decoder = lc3.Decoder(
|
||||
frame_duration_us=frame_duration.us,
|
||||
sample_rate_hz=sampling_frequency.hz,
|
||||
num_channels=num_bis,
|
||||
)
|
||||
lc3_queues: list[Deque[bytes]] = [collections.deque() for i in range(num_bis)]
|
||||
packet_stats = [0, 0]
|
||||
|
||||
audio_output = await audio_io.create_audio_output(output)
|
||||
# This try should be replaced with contextlib.aclosing() when python 3.9 is no
|
||||
# longer needed.
|
||||
try:
|
||||
await audio_output.open(
|
||||
audio_io.PcmFormat(
|
||||
audio_io.PcmFormat.Endianness.LITTLE,
|
||||
audio_io.PcmFormat.SampleType.FLOAT32,
|
||||
sampling_frequency.hz,
|
||||
num_bis,
|
||||
)
|
||||
)
|
||||
|
||||
def sink(queue: Deque[bytes], packet: hci.HCI_IsoDataPacket):
|
||||
# TODO: re-assemble fragments and detect errors
|
||||
queue.append(packet.iso_sdu_fragment)
|
||||
|
||||
while all(lc3_queues):
|
||||
# This assumes SDUs contain one LC3 frame each, which may not
|
||||
# be correct for all cases. TODO: revisit this assumption.
|
||||
frame = b''.join([lc3_queue.popleft() for lc3_queue in lc3_queues])
|
||||
if not frame:
|
||||
print(color('!!! received empty frame', 'red'))
|
||||
continue
|
||||
|
||||
packet_stats[0] += len(frame)
|
||||
packet_stats[1] += 1
|
||||
print(
|
||||
f'\rRECEIVED: {packet_stats[0]} bytes in '
|
||||
f'{packet_stats[1]} packets',
|
||||
end='',
|
||||
)
|
||||
|
||||
try:
|
||||
pcm = decoder.decode(frame).tobytes()
|
||||
except lc3.BaseError as error:
|
||||
print(color(f'!!! LC3 decoding error: {error}'))
|
||||
continue
|
||||
|
||||
audio_output.write(pcm)
|
||||
|
||||
for i, bis_link in enumerate(big_sync.bis_links):
|
||||
print(f'Setup ISO for BIS {bis_link.handle}')
|
||||
bis_link.sink = functools.partial(sink, lc3_queues[i])
|
||||
await device.send_command(
|
||||
hci.HCI_LE_Setup_ISO_Data_Path_Command(
|
||||
connection_handle=bis_link.handle,
|
||||
data_path_direction=hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST,
|
||||
data_path_id=0,
|
||||
codec_id=hci.CodingFormat(codec_id=hci.CodecID.TRANSPARENT),
|
||||
controller_delay=0,
|
||||
codec_configuration=b'',
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
terminated = asyncio.Event()
|
||||
big_sync.on(big_sync.Event.TERMINATION, lambda _: terminated.set())
|
||||
await terminated.wait()
|
||||
finally:
|
||||
await audio_output.aclose()
|
||||
|
||||
|
||||
async def run_transmit(
|
||||
transport: str,
|
||||
broadcast_id: int,
|
||||
broadcast_code: str | None,
|
||||
broadcast_name: str,
|
||||
bitrate: int,
|
||||
manufacturer_data: tuple[int, bytes] | None,
|
||||
input: str,
|
||||
input_format: str,
|
||||
) -> None:
|
||||
# Run a pre-flight check for the input.
|
||||
try:
|
||||
if not audio_io.check_audio_input(input):
|
||||
return
|
||||
except ValueError as error:
|
||||
print(error)
|
||||
return
|
||||
|
||||
async with create_device(transport) as device:
|
||||
if not device.supports_le_periodic_advertising:
|
||||
print(color('Periodic advertising not supported', 'red'))
|
||||
return
|
||||
|
||||
basic_audio_announcement = bap.BasicAudioAnnouncement(
|
||||
presentation_delay=40000,
|
||||
subgroups=[
|
||||
bap.BasicAudioAnnouncement.Subgroup(
|
||||
codec_id=hci.CodingFormat(codec_id=hci.CodecID.LC3),
|
||||
codec_specific_configuration=bap.CodecSpecificConfiguration(
|
||||
sampling_frequency=bap.SamplingFrequency.FREQ_48000,
|
||||
frame_duration=bap.FrameDuration.DURATION_10000_US,
|
||||
octets_per_codec_frame=100,
|
||||
),
|
||||
metadata=le_audio.Metadata(
|
||||
[
|
||||
le_audio.Metadata.Entry(
|
||||
tag=le_audio.Metadata.Tag.LANGUAGE, data=b'eng'
|
||||
),
|
||||
le_audio.Metadata.Entry(
|
||||
tag=le_audio.Metadata.Tag.PROGRAM_INFO, data=b'Disco'
|
||||
),
|
||||
]
|
||||
),
|
||||
bis=[
|
||||
bap.BasicAudioAnnouncement.BIS(
|
||||
index=1,
|
||||
codec_specific_configuration=bap.CodecSpecificConfiguration(
|
||||
audio_channel_allocation=bap.AudioLocation.FRONT_LEFT
|
||||
),
|
||||
),
|
||||
bap.BasicAudioAnnouncement.BIS(
|
||||
index=2,
|
||||
codec_specific_configuration=bap.CodecSpecificConfiguration(
|
||||
audio_channel_allocation=bap.AudioLocation.FRONT_RIGHT
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
],
|
||||
)
|
||||
broadcast_audio_announcement = bap.BroadcastAudioAnnouncement(broadcast_id)
|
||||
|
||||
advertising_manufacturer_data = (
|
||||
b''
|
||||
if manufacturer_data is None
|
||||
else bytes(
|
||||
core.AdvertisingData(
|
||||
[
|
||||
(
|
||||
core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA,
|
||||
struct.pack('<H', manufacturer_data[0])
|
||||
+ manufacturer_data[1],
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
advertising_set = await device.create_advertising_set(
|
||||
advertising_parameters=bumble.device.AdvertisingParameters(
|
||||
advertising_event_properties=bumble.device.AdvertisingEventProperties(
|
||||
is_connectable=False
|
||||
),
|
||||
primary_advertising_interval_min=100,
|
||||
primary_advertising_interval_max=200,
|
||||
),
|
||||
advertising_data=(
|
||||
broadcast_audio_announcement.get_advertising_data()
|
||||
+ bytes(
|
||||
core.AdvertisingData(
|
||||
[(core.AdvertisingData.BROADCAST_NAME, broadcast_name.encode())]
|
||||
)
|
||||
)
|
||||
+ advertising_manufacturer_data
|
||||
),
|
||||
periodic_advertising_parameters=bumble.device.PeriodicAdvertisingParameters(
|
||||
periodic_advertising_interval_min=80,
|
||||
periodic_advertising_interval_max=160,
|
||||
),
|
||||
periodic_advertising_data=basic_audio_announcement.get_advertising_data(),
|
||||
auto_restart=True,
|
||||
auto_start=True,
|
||||
)
|
||||
|
||||
print('Start Periodic Advertising')
|
||||
await advertising_set.start_periodic()
|
||||
|
||||
audio_input = await audio_io.create_audio_input(input, input_format)
|
||||
pcm_format = await audio_input.open()
|
||||
# This try should be replaced with contextlib.aclosing() when python 3.9 is no
|
||||
# longer needed.
|
||||
try:
|
||||
if pcm_format.channels != 2:
|
||||
print("Only 2 channels PCM configurations are supported")
|
||||
return
|
||||
if pcm_format.sample_type == audio_io.PcmFormat.SampleType.INT16:
|
||||
pcm_bit_depth = 16
|
||||
elif pcm_format.sample_type == audio_io.PcmFormat.SampleType.FLOAT32:
|
||||
pcm_bit_depth = None
|
||||
else:
|
||||
print("Only INT16 and FLOAT32 sample types are supported")
|
||||
return
|
||||
|
||||
encoder = lc3.Encoder(
|
||||
frame_duration_us=AURACAST_DEFAULT_FRAME_DURATION,
|
||||
sample_rate_hz=AURACAST_DEFAULT_SAMPLE_RATE,
|
||||
num_channels=pcm_format.channels,
|
||||
input_sample_rate_hz=pcm_format.sample_rate,
|
||||
)
|
||||
lc3_frame_samples = encoder.get_frame_samples()
|
||||
lc3_frame_size = encoder.get_frame_bytes(bitrate)
|
||||
print(
|
||||
f'Encoding with {lc3_frame_samples} '
|
||||
f'PCM samples per {lc3_frame_size} byte frame'
|
||||
)
|
||||
|
||||
print('Setup BIG')
|
||||
big = await device.create_big(
|
||||
advertising_set,
|
||||
parameters=bumble.device.BigParameters(
|
||||
num_bis=pcm_format.channels,
|
||||
sdu_interval=AURACAST_DEFAULT_FRAME_DURATION,
|
||||
max_sdu=lc3_frame_size,
|
||||
max_transport_latency=65,
|
||||
rtn=4,
|
||||
broadcast_code=(
|
||||
bytes.fromhex(broadcast_code) if broadcast_code else None
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
iso_queues = [
|
||||
bumble.device.IsoPacketStream(big.bis_links[0], 64),
|
||||
bumble.device.IsoPacketStream(big.bis_links[1], 64),
|
||||
]
|
||||
|
||||
def on_flow():
|
||||
data_packet_queue = iso_queues[0].data_packet_queue
|
||||
print(
|
||||
f'\rPACKETS: pending={data_packet_queue.pending}, '
|
||||
f'queued={data_packet_queue.queued}, '
|
||||
f'completed={data_packet_queue.completed}',
|
||||
end='',
|
||||
)
|
||||
|
||||
iso_queues[0].data_packet_queue.on('flow', on_flow)
|
||||
|
||||
frame_count = 0
|
||||
async for pcm_frame in audio_input.frames(lc3_frame_samples):
|
||||
lc3_frame = encoder.encode(
|
||||
pcm_frame, num_bytes=2 * lc3_frame_size, bit_depth=pcm_bit_depth
|
||||
)
|
||||
|
||||
mid = len(lc3_frame) // 2
|
||||
await iso_queues[0].write(lc3_frame[:mid])
|
||||
await iso_queues[1].write(lc3_frame[mid:])
|
||||
|
||||
frame_count += 1
|
||||
finally:
|
||||
await audio_input.aclose()
|
||||
|
||||
|
||||
def run_async(async_command: Coroutine) -> None:
|
||||
try:
|
||||
asyncio.run(async_command)
|
||||
except bumble.core.ProtocolError as error:
|
||||
except core.ProtocolError as error:
|
||||
if error.error_namespace == 'att' and error.error_code in list(
|
||||
bumble.profiles.bass.ApplicationError
|
||||
bass.ApplicationError
|
||||
):
|
||||
message = bumble.profiles.bass.ApplicationError(error.error_code).name
|
||||
message = bass.ApplicationError(error.error_code).name
|
||||
else:
|
||||
message = str(error)
|
||||
|
||||
@@ -622,9 +1012,7 @@ def run_async(async_command: Coroutine) -> None:
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.group()
|
||||
@click.pass_context
|
||||
def auracast(
|
||||
ctx,
|
||||
):
|
||||
def auracast(ctx):
|
||||
ctx.ensure_object(dict)
|
||||
|
||||
|
||||
@@ -669,7 +1057,7 @@ def scan(ctx, filter_duplicates, sync_timeout, transport):
|
||||
@click.argument('address')
|
||||
@click.pass_context
|
||||
def assist(ctx, broadcast_name, source_id, command, transport, address):
|
||||
"""Scan for broadcasts on behalf of a audio server"""
|
||||
"""Scan for broadcasts on behalf of an audio server"""
|
||||
run_async(run_assist(broadcast_name, source_id, command, transport, address))
|
||||
|
||||
|
||||
@@ -682,6 +1070,166 @@ def pair(ctx, transport, address):
|
||||
run_async(run_pair(transport, address))
|
||||
|
||||
|
||||
@auracast.command('receive')
|
||||
@click.argument('transport')
|
||||
@click.argument(
|
||||
'broadcast_id',
|
||||
type=int,
|
||||
required=False,
|
||||
)
|
||||
@click.option(
|
||||
'--output',
|
||||
default='device',
|
||||
help=(
|
||||
"Audio output. "
|
||||
"'device' -> use the host's default sound output device, "
|
||||
"'device:<DEVICE_ID>' -> use one of the host's sound output device "
|
||||
"(specify 'device:?' to get a list of available sound output devices), "
|
||||
"'stdout' -> send audio to stdout, "
|
||||
"'file:<filename> -> write audio to a raw float32 PCM file, "
|
||||
"'ffplay' -> pipe the audio to ffplay"
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
'--broadcast-code',
|
||||
metavar='BROADCAST_CODE',
|
||||
type=str,
|
||||
help='Broadcast encryption code in hex format',
|
||||
)
|
||||
@click.option(
|
||||
'--sync-timeout',
|
||||
metavar='SYNC_TIMEOUT',
|
||||
type=float,
|
||||
default=AURACAST_DEFAULT_SYNC_TIMEOUT,
|
||||
help='Sync timeout (in seconds)',
|
||||
)
|
||||
@click.option(
|
||||
'--subgroup',
|
||||
metavar='SUBGROUP',
|
||||
type=int,
|
||||
default=0,
|
||||
help='Index of Subgroup',
|
||||
)
|
||||
@click.pass_context
|
||||
def receive(
|
||||
ctx,
|
||||
transport,
|
||||
broadcast_id,
|
||||
output,
|
||||
broadcast_code,
|
||||
sync_timeout,
|
||||
subgroup,
|
||||
):
|
||||
"""Receive a broadcast source"""
|
||||
run_async(
|
||||
run_receive(
|
||||
transport,
|
||||
broadcast_id,
|
||||
output,
|
||||
broadcast_code,
|
||||
sync_timeout,
|
||||
subgroup,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@auracast.command('transmit')
|
||||
@click.argument('transport')
|
||||
@click.option(
|
||||
'--input',
|
||||
required=True,
|
||||
help=(
|
||||
"Audio input. "
|
||||
"'device' -> use the host's default sound input device, "
|
||||
"'device:<DEVICE_ID>' -> use one of the host's sound input devices "
|
||||
"(specify 'device:?' to get a list of available sound input devices), "
|
||||
"'stdin' -> receive audio from stdin as int16 PCM, "
|
||||
"'file:<filename> -> read audio from a .wav or raw int16 PCM file. "
|
||||
"(The file: prefix may be omitted if the file path does not start with "
|
||||
"the substring 'device:' or 'file:' and is not 'stdin')"
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
'--input-format',
|
||||
metavar="FORMAT",
|
||||
default='auto',
|
||||
help=(
|
||||
"Audio input format. "
|
||||
"Use 'auto' for .wav files, or for the default setting with the devices. "
|
||||
"For other inputs, the format is specified as "
|
||||
"<sample-type>,<sample-rate>,<channels> (supported <sample-type>: 'int16le' "
|
||||
"for 16-bit signed integers with little-endian byte order or 'float32le' for "
|
||||
"32-bit floating point with little-endian byte order)"
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
'--broadcast-id',
|
||||
metavar='BROADCAST_ID',
|
||||
type=int,
|
||||
default=123456,
|
||||
help='Broadcast ID',
|
||||
)
|
||||
@click.option(
|
||||
'--broadcast-code',
|
||||
metavar='BROADCAST_CODE',
|
||||
help='Broadcast encryption code in hex format',
|
||||
)
|
||||
@click.option(
|
||||
'--broadcast-name',
|
||||
metavar='BROADCAST_NAME',
|
||||
default='Bumble Auracast',
|
||||
help='Broadcast name',
|
||||
)
|
||||
@click.option(
|
||||
'--bitrate',
|
||||
type=int,
|
||||
default=AURACAST_DEFAULT_TRANSMIT_BITRATE,
|
||||
help='Bitrate, per channel, in bps',
|
||||
)
|
||||
@click.option(
|
||||
'--manufacturer-data',
|
||||
metavar='VENDOR-ID:DATA-HEX',
|
||||
help='Manufacturer data (specify as <vendor-id>:<data-hex>)',
|
||||
)
|
||||
@click.pass_context
|
||||
def transmit(
|
||||
ctx,
|
||||
transport,
|
||||
broadcast_id,
|
||||
broadcast_code,
|
||||
manufacturer_data,
|
||||
broadcast_name,
|
||||
bitrate,
|
||||
input,
|
||||
input_format,
|
||||
):
|
||||
"""Transmit a broadcast source"""
|
||||
if manufacturer_data:
|
||||
vendor_id_str, data_hex = manufacturer_data.split(':')
|
||||
vendor_id = int(vendor_id_str)
|
||||
data = bytes.fromhex(data_hex)
|
||||
manufacturer_data_tuple = (vendor_id, data)
|
||||
else:
|
||||
manufacturer_data_tuple = None
|
||||
|
||||
if (input == 'device' or input.startswith('device:')) and input_format == 'auto':
|
||||
# Use a default format for device inputs
|
||||
input_format = 'int16le,48000,1'
|
||||
|
||||
run_async(
|
||||
run_transmit(
|
||||
transport=transport,
|
||||
broadcast_id=broadcast_id,
|
||||
broadcast_code=broadcast_code,
|
||||
broadcast_name=broadcast_name,
|
||||
bitrate=bitrate,
|
||||
manufacturer_data=manufacturer_data_tuple,
|
||||
input=input,
|
||||
input_format=input_format,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
auracast()
|
||||
|
||||
504
apps/bench.py
504
apps/bench.py
@@ -16,9 +16,11 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import os
|
||||
import statistics
|
||||
import struct
|
||||
import time
|
||||
|
||||
@@ -96,34 +98,6 @@ DEFAULT_RFCOMM_MTU = 2048
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# -----------------------------------------------------------------------------
|
||||
def parse_packet(packet):
|
||||
if len(packet) < 1:
|
||||
logging.info(
|
||||
color(f'!!! Packet too short (got {len(packet)} bytes, need >= 1)', 'red')
|
||||
)
|
||||
raise ValueError('packet too short')
|
||||
|
||||
try:
|
||||
packet_type = PacketType(packet[0])
|
||||
except ValueError:
|
||||
logging.info(color(f'!!! Invalid packet type 0x{packet[0]:02X}', 'red'))
|
||||
raise
|
||||
|
||||
return (packet_type, packet[1:])
|
||||
|
||||
|
||||
def parse_packet_sequence(packet_data):
|
||||
if len(packet_data) < 5:
|
||||
logging.info(
|
||||
color(
|
||||
f'!!!Packet too short (got {len(packet_data)} bytes, need >= 5)',
|
||||
'red',
|
||||
)
|
||||
)
|
||||
raise ValueError('packet too short')
|
||||
return struct.unpack_from('>bI', packet_data, 0)
|
||||
|
||||
|
||||
def le_phy_name(phy_id):
|
||||
return {HCI_LE_1M_PHY: '1M', HCI_LE_2M_PHY: '2M', HCI_LE_CODED_PHY: 'CODED'}.get(
|
||||
phy_id, HCI_Constant.le_phy_name(phy_id)
|
||||
@@ -194,17 +168,19 @@ def make_sdp_records(channel):
|
||||
}
|
||||
|
||||
|
||||
def log_stats(title, stats):
|
||||
def log_stats(title, stats, precision=2):
|
||||
stats_min = min(stats)
|
||||
stats_max = max(stats)
|
||||
stats_avg = sum(stats) / len(stats)
|
||||
stats_avg = statistics.mean(stats)
|
||||
stats_stdev = statistics.stdev(stats) if len(stats) >= 2 else 0
|
||||
logging.info(
|
||||
color(
|
||||
(
|
||||
f'### {title} stats: '
|
||||
f'min={stats_min:.2f}, '
|
||||
f'max={stats_max:.2f}, '
|
||||
f'average={stats_avg:.2f}'
|
||||
f'min={stats_min:.{precision}f}, '
|
||||
f'max={stats_max:.{precision}f}, '
|
||||
f'average={stats_avg:.{precision}f}, '
|
||||
f'stdev={stats_stdev:.{precision}f}'
|
||||
),
|
||||
'cyan',
|
||||
)
|
||||
@@ -222,13 +198,135 @@ async def switch_roles(connection, role):
|
||||
logging.info(f'{color("### Role switch failed:", "red")} {error}')
|
||||
|
||||
|
||||
class PacketType(enum.IntEnum):
|
||||
RESET = 0
|
||||
SEQUENCE = 1
|
||||
ACK = 2
|
||||
# -----------------------------------------------------------------------------
|
||||
# Packet
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class Packet:
|
||||
class PacketType(enum.IntEnum):
|
||||
RESET = 0
|
||||
SEQUENCE = 1
|
||||
ACK = 2
|
||||
|
||||
class PacketFlags(enum.IntFlag):
|
||||
LAST = 1
|
||||
|
||||
packet_type: PacketType
|
||||
flags: PacketFlags = PacketFlags(0)
|
||||
sequence: int = 0
|
||||
timestamp: int = 0
|
||||
payload: bytes = b""
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
if len(data) < 1:
|
||||
logging.warning(
|
||||
color(f'!!! Packet too short (got {len(data)} bytes, need >= 1)', 'red')
|
||||
)
|
||||
raise ValueError('packet too short')
|
||||
|
||||
try:
|
||||
packet_type = cls.PacketType(data[0])
|
||||
except ValueError:
|
||||
logging.warning(color(f'!!! Invalid packet type 0x{data[0]:02X}', 'red'))
|
||||
raise
|
||||
|
||||
if packet_type == cls.PacketType.RESET:
|
||||
return cls(packet_type)
|
||||
|
||||
flags = cls.PacketFlags(data[1])
|
||||
(sequence,) = struct.unpack_from("<I", data, 2)
|
||||
|
||||
if packet_type == cls.PacketType.ACK:
|
||||
if len(data) < 6:
|
||||
logging.warning(
|
||||
color(
|
||||
f'!!! Packet too short (got {len(data)} bytes, need >= 6)',
|
||||
'red',
|
||||
)
|
||||
)
|
||||
return cls(packet_type, flags, sequence)
|
||||
|
||||
if len(data) < 10:
|
||||
logging.warning(
|
||||
color(
|
||||
f'!!! Packet too short (got {len(data)} bytes, need >= 10)', 'red'
|
||||
)
|
||||
)
|
||||
raise ValueError('packet too short')
|
||||
|
||||
(timestamp,) = struct.unpack_from("<I", data, 6)
|
||||
return cls(packet_type, flags, sequence, timestamp, data[10:])
|
||||
|
||||
def __bytes__(self):
|
||||
if self.packet_type == self.PacketType.RESET:
|
||||
return bytes([self.packet_type])
|
||||
|
||||
if self.packet_type == self.PacketType.ACK:
|
||||
return struct.pack("<BBI", self.packet_type, self.flags, self.sequence)
|
||||
|
||||
return (
|
||||
struct.pack(
|
||||
"<BBII", self.packet_type, self.flags, self.sequence, self.timestamp
|
||||
)
|
||||
+ self.payload
|
||||
)
|
||||
|
||||
|
||||
PACKET_FLAG_LAST = 1
|
||||
# -----------------------------------------------------------------------------
|
||||
# Jitter Stats
|
||||
# -----------------------------------------------------------------------------
|
||||
class JitterStats:
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.packets = []
|
||||
self.receive_times = []
|
||||
self.jitter = []
|
||||
|
||||
def on_packet_received(self, packet):
|
||||
now = time.time()
|
||||
self.packets.append(packet)
|
||||
self.receive_times.append(now)
|
||||
|
||||
if packet.timestamp and len(self.packets) > 1:
|
||||
expected_time = (
|
||||
self.receive_times[0]
|
||||
+ (packet.timestamp - self.packets[0].timestamp) / 1000000
|
||||
)
|
||||
jitter = now - expected_time
|
||||
else:
|
||||
jitter = 0.0
|
||||
|
||||
self.jitter.append(jitter)
|
||||
return jitter
|
||||
|
||||
def show_stats(self):
|
||||
if len(self.jitter) < 3:
|
||||
return
|
||||
average = sum(self.jitter) / len(self.jitter)
|
||||
adjusted = [jitter - average for jitter in self.jitter]
|
||||
|
||||
log_stats('Jitter (signed)', adjusted, 3)
|
||||
log_stats('Jitter (absolute)', [abs(jitter) for jitter in adjusted], 3)
|
||||
|
||||
# Show a histogram
|
||||
bin_count = 20
|
||||
bins = [0] * bin_count
|
||||
interval_min = min(adjusted)
|
||||
interval_max = max(adjusted)
|
||||
interval_range = interval_max - interval_min
|
||||
bin_thresholds = [
|
||||
interval_min + i * (interval_range / bin_count) for i in range(bin_count)
|
||||
]
|
||||
for jitter in adjusted:
|
||||
for i in reversed(range(bin_count)):
|
||||
if jitter >= bin_thresholds[i]:
|
||||
bins[i] += 1
|
||||
break
|
||||
for i in range(bin_count):
|
||||
logging.info(f'@@@ >= {bin_thresholds[i]:.4f}: {bins[i]}')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -278,19 +376,37 @@ class Sender:
|
||||
await asyncio.sleep(self.tx_start_delay)
|
||||
|
||||
logging.info(color('=== Sending RESET', 'magenta'))
|
||||
await self.packet_io.send_packet(bytes([PacketType.RESET]))
|
||||
await self.packet_io.send_packet(
|
||||
bytes(Packet(packet_type=Packet.PacketType.RESET))
|
||||
)
|
||||
|
||||
self.start_time = time.time()
|
||||
self.bytes_sent = 0
|
||||
for tx_i in range(self.tx_packet_count):
|
||||
packet_flags = (
|
||||
PACKET_FLAG_LAST if tx_i == self.tx_packet_count - 1 else 0
|
||||
if self.pace > 0:
|
||||
# Wait until it is time to send the next packet
|
||||
target_time = self.start_time + (tx_i * self.pace / 1000)
|
||||
now = time.time()
|
||||
if now < target_time:
|
||||
await asyncio.sleep(target_time - now)
|
||||
else:
|
||||
await self.packet_io.drain()
|
||||
|
||||
packet = bytes(
|
||||
Packet(
|
||||
packet_type=Packet.PacketType.SEQUENCE,
|
||||
flags=(
|
||||
Packet.PacketFlags.LAST
|
||||
if tx_i == self.tx_packet_count - 1
|
||||
else 0
|
||||
),
|
||||
sequence=tx_i,
|
||||
timestamp=int((time.time() - self.start_time) * 1000000),
|
||||
payload=bytes(
|
||||
self.tx_packet_size - 10 - self.packet_io.overhead_size
|
||||
),
|
||||
)
|
||||
)
|
||||
packet = struct.pack(
|
||||
'>bbI',
|
||||
PacketType.SEQUENCE,
|
||||
packet_flags,
|
||||
tx_i,
|
||||
) + bytes(self.tx_packet_size - 6 - self.packet_io.overhead_size)
|
||||
logging.info(
|
||||
color(
|
||||
f'Sending packet {tx_i}: {self.tx_packet_size} bytes', 'yellow'
|
||||
@@ -299,14 +415,6 @@ class Sender:
|
||||
self.bytes_sent += len(packet)
|
||||
await self.packet_io.send_packet(packet)
|
||||
|
||||
if self.pace is None:
|
||||
continue
|
||||
|
||||
if self.pace > 0:
|
||||
await asyncio.sleep(self.pace / 1000)
|
||||
else:
|
||||
await self.packet_io.drain()
|
||||
|
||||
await self.done.wait()
|
||||
|
||||
run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
|
||||
@@ -318,13 +426,13 @@ class Sender:
|
||||
if self.repeat:
|
||||
logging.info(color('--- End of runs', 'blue'))
|
||||
|
||||
def on_packet_received(self, packet):
|
||||
def on_packet_received(self, data):
|
||||
try:
|
||||
packet_type, _ = parse_packet(packet)
|
||||
packet = Packet.from_bytes(data)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
if packet_type == PacketType.ACK:
|
||||
if packet.packet_type == Packet.PacketType.ACK:
|
||||
elapsed = time.time() - self.start_time
|
||||
average_tx_speed = self.bytes_sent / elapsed
|
||||
self.stats.append(average_tx_speed)
|
||||
@@ -347,52 +455,53 @@ class Receiver:
|
||||
last_timestamp: float
|
||||
|
||||
def __init__(self, packet_io, linger):
|
||||
self.reset()
|
||||
self.jitter_stats = JitterStats()
|
||||
self.packet_io = packet_io
|
||||
self.packet_io.packet_listener = self
|
||||
self.linger = linger
|
||||
self.done = asyncio.Event()
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.expected_packet_index = 0
|
||||
self.measurements = [(time.time(), 0)]
|
||||
self.total_bytes_received = 0
|
||||
self.jitter_stats.reset()
|
||||
|
||||
def on_packet_received(self, packet):
|
||||
def on_packet_received(self, data):
|
||||
try:
|
||||
packet_type, packet_data = parse_packet(packet)
|
||||
packet = Packet.from_bytes(data)
|
||||
except ValueError:
|
||||
logging.exception("invalid packet")
|
||||
return
|
||||
|
||||
if packet_type == PacketType.RESET:
|
||||
if packet.packet_type == Packet.PacketType.RESET:
|
||||
logging.info(color('=== Received RESET', 'magenta'))
|
||||
self.reset()
|
||||
return
|
||||
|
||||
try:
|
||||
packet_flags, packet_index = parse_packet_sequence(packet_data)
|
||||
except ValueError:
|
||||
return
|
||||
jitter = self.jitter_stats.on_packet_received(packet)
|
||||
logging.info(
|
||||
f'<<< Received packet {packet_index}: '
|
||||
f'flags=0x{packet_flags:02X}, '
|
||||
f'{len(packet) + self.packet_io.overhead_size} bytes'
|
||||
f'<<< Received packet {packet.sequence}: '
|
||||
f'flags={packet.flags}, '
|
||||
f'jitter={jitter:.4f}, '
|
||||
f'{len(data) + self.packet_io.overhead_size} bytes',
|
||||
)
|
||||
|
||||
if packet_index != self.expected_packet_index:
|
||||
if packet.sequence != self.expected_packet_index:
|
||||
logging.info(
|
||||
color(
|
||||
f'!!! Unexpected packet, expected {self.expected_packet_index} '
|
||||
f'but received {packet_index}'
|
||||
f'but received {packet.sequence}'
|
||||
)
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
elapsed_since_start = now - self.measurements[0][0]
|
||||
elapsed_since_last = now - self.measurements[-1][0]
|
||||
self.measurements.append((now, len(packet)))
|
||||
self.total_bytes_received += len(packet)
|
||||
instant_rx_speed = len(packet) / elapsed_since_last
|
||||
self.measurements.append((now, len(data)))
|
||||
self.total_bytes_received += len(data)
|
||||
instant_rx_speed = len(data) / elapsed_since_last
|
||||
average_rx_speed = self.total_bytes_received / elapsed_since_start
|
||||
window = self.measurements[-64:]
|
||||
windowed_rx_speed = sum(measurement[1] for measurement in window[1:]) / (
|
||||
@@ -408,15 +517,17 @@ class Receiver:
|
||||
)
|
||||
)
|
||||
|
||||
self.expected_packet_index = packet_index + 1
|
||||
self.expected_packet_index = packet.sequence + 1
|
||||
|
||||
if packet_flags & PACKET_FLAG_LAST:
|
||||
if packet.flags & Packet.PacketFlags.LAST:
|
||||
AsyncRunner.spawn(
|
||||
self.packet_io.send_packet(
|
||||
struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index)
|
||||
bytes(Packet(Packet.PacketType.ACK, packet.flags, packet.sequence))
|
||||
)
|
||||
)
|
||||
logging.info(color('@@@ Received last packet', 'green'))
|
||||
self.jitter_stats.show_stats()
|
||||
|
||||
if not self.linger:
|
||||
self.done.set()
|
||||
|
||||
@@ -448,9 +559,9 @@ class Ping:
|
||||
self.repeat_delay = repeat_delay
|
||||
self.pace = pace
|
||||
self.done = asyncio.Event()
|
||||
self.current_packet_index = 0
|
||||
self.ping_sent_time = 0.0
|
||||
self.latencies = []
|
||||
self.ping_times = []
|
||||
self.rtts = []
|
||||
self.next_expected_packet_index = 0
|
||||
self.min_stats = []
|
||||
self.max_stats = []
|
||||
self.avg_stats = []
|
||||
@@ -465,6 +576,7 @@ class Ping:
|
||||
|
||||
for run in range(self.repeat + 1):
|
||||
self.done.clear()
|
||||
self.ping_times = []
|
||||
|
||||
if run > 0 and self.repeat and self.repeat_delay:
|
||||
logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green'))
|
||||
@@ -475,98 +587,97 @@ class Ping:
|
||||
await asyncio.sleep(self.tx_start_delay)
|
||||
|
||||
logging.info(color('=== Sending RESET', 'magenta'))
|
||||
await self.packet_io.send_packet(bytes([PacketType.RESET]))
|
||||
await self.packet_io.send_packet(bytes(Packet(Packet.PacketType.RESET)))
|
||||
|
||||
self.current_packet_index = 0
|
||||
self.latencies = []
|
||||
await self.send_next_ping()
|
||||
start_time = time.time()
|
||||
self.next_expected_packet_index = 0
|
||||
for i in range(self.tx_packet_count):
|
||||
target_time = start_time + (i * self.pace / 1000)
|
||||
now = time.time()
|
||||
if now < target_time:
|
||||
await asyncio.sleep(target_time - now)
|
||||
now = time.time()
|
||||
|
||||
packet = bytes(
|
||||
Packet(
|
||||
packet_type=Packet.PacketType.SEQUENCE,
|
||||
flags=(
|
||||
Packet.PacketFlags.LAST
|
||||
if i == self.tx_packet_count - 1
|
||||
else 0
|
||||
),
|
||||
sequence=i,
|
||||
timestamp=int((now - start_time) * 1000000),
|
||||
payload=bytes(self.tx_packet_size - 10),
|
||||
)
|
||||
)
|
||||
logging.info(color(f'Sending packet {i}', 'yellow'))
|
||||
self.ping_times.append(now)
|
||||
await self.packet_io.send_packet(packet)
|
||||
|
||||
await self.done.wait()
|
||||
|
||||
min_latency = min(self.latencies)
|
||||
max_latency = max(self.latencies)
|
||||
avg_latency = sum(self.latencies) / len(self.latencies)
|
||||
min_rtt = min(self.rtts)
|
||||
max_rtt = max(self.rtts)
|
||||
avg_rtt = statistics.mean(self.rtts)
|
||||
stdev_rtt = statistics.stdev(self.rtts)
|
||||
logging.info(
|
||||
color(
|
||||
'@@@ Latencies: '
|
||||
f'min={min_latency:.2f}, '
|
||||
f'max={max_latency:.2f}, '
|
||||
f'average={avg_latency:.2f}'
|
||||
'@@@ RTTs: '
|
||||
f'min={min_rtt:.2f}, '
|
||||
f'max={max_rtt:.2f}, '
|
||||
f'average={avg_rtt:.2f}, '
|
||||
f'stdev={stdev_rtt:.2f}'
|
||||
)
|
||||
)
|
||||
|
||||
self.min_stats.append(min_latency)
|
||||
self.max_stats.append(max_latency)
|
||||
self.avg_stats.append(avg_latency)
|
||||
self.min_stats.append(min_rtt)
|
||||
self.max_stats.append(max_rtt)
|
||||
self.avg_stats.append(avg_rtt)
|
||||
|
||||
run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
|
||||
logging.info(color(f'=== {run_counter} Done!', 'magenta'))
|
||||
|
||||
if self.repeat:
|
||||
log_stats('Min Latency', self.min_stats)
|
||||
log_stats('Max Latency', self.max_stats)
|
||||
log_stats('Average Latency', self.avg_stats)
|
||||
log_stats('Min RTT', self.min_stats)
|
||||
log_stats('Max RTT', self.max_stats)
|
||||
log_stats('Average RTT', self.avg_stats)
|
||||
|
||||
if self.repeat:
|
||||
logging.info(color('--- End of runs', 'blue'))
|
||||
|
||||
async def send_next_ping(self):
|
||||
if self.pace:
|
||||
await asyncio.sleep(self.pace / 1000)
|
||||
|
||||
packet = struct.pack(
|
||||
'>bbI',
|
||||
PacketType.SEQUENCE,
|
||||
(
|
||||
PACKET_FLAG_LAST
|
||||
if self.current_packet_index == self.tx_packet_count - 1
|
||||
else 0
|
||||
),
|
||||
self.current_packet_index,
|
||||
) + bytes(self.tx_packet_size - 6)
|
||||
logging.info(color(f'Sending packet {self.current_packet_index}', 'yellow'))
|
||||
self.ping_sent_time = time.time()
|
||||
await self.packet_io.send_packet(packet)
|
||||
|
||||
def on_packet_received(self, packet):
|
||||
elapsed = time.time() - self.ping_sent_time
|
||||
|
||||
def on_packet_received(self, data):
|
||||
try:
|
||||
packet_type, packet_data = parse_packet(packet)
|
||||
packet = Packet.from_bytes(data)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
try:
|
||||
packet_flags, packet_index = parse_packet_sequence(packet_data)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
if packet_type == PacketType.ACK:
|
||||
latency = elapsed * 1000
|
||||
self.latencies.append(latency)
|
||||
if packet.packet_type == Packet.PacketType.ACK:
|
||||
elapsed = time.time() - self.ping_times[packet.sequence]
|
||||
rtt = elapsed * 1000
|
||||
self.rtts.append(rtt)
|
||||
logging.info(
|
||||
color(
|
||||
f'<<< Received ACK [{packet_index}], latency={latency:.2f}ms',
|
||||
f'<<< Received ACK [{packet.sequence}], RTT={rtt:.2f}ms',
|
||||
'green',
|
||||
)
|
||||
)
|
||||
|
||||
if packet_index == self.current_packet_index:
|
||||
self.current_packet_index += 1
|
||||
if packet.sequence == self.next_expected_packet_index:
|
||||
self.next_expected_packet_index += 1
|
||||
else:
|
||||
logging.info(
|
||||
color(
|
||||
f'!!! Unexpected packet, expected {self.current_packet_index} '
|
||||
f'but received {packet_index}'
|
||||
f'!!! Unexpected packet, '
|
||||
f'expected {self.next_expected_packet_index} '
|
||||
f'but received {packet.sequence}'
|
||||
)
|
||||
)
|
||||
|
||||
if packet_flags & PACKET_FLAG_LAST:
|
||||
if packet.flags & Packet.PacketFlags.LAST:
|
||||
self.done.set()
|
||||
return
|
||||
|
||||
AsyncRunner.spawn(self.send_next_ping())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Pong
|
||||
@@ -575,56 +686,59 @@ class Pong:
|
||||
expected_packet_index: int
|
||||
|
||||
def __init__(self, packet_io, linger):
|
||||
self.reset()
|
||||
self.jitter_stats = JitterStats()
|
||||
self.packet_io = packet_io
|
||||
self.packet_io.packet_listener = self
|
||||
self.linger = linger
|
||||
self.done = asyncio.Event()
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.expected_packet_index = 0
|
||||
self.jitter_stats.reset()
|
||||
|
||||
def on_packet_received(self, packet):
|
||||
def on_packet_received(self, data):
|
||||
try:
|
||||
packet_type, packet_data = parse_packet(packet)
|
||||
packet = Packet.from_bytes(data)
|
||||
except ValueError:
|
||||
return
|
||||
|
||||
if packet_type == PacketType.RESET:
|
||||
if packet.packet_type == Packet.PacketType.RESET:
|
||||
logging.info(color('=== Received RESET', 'magenta'))
|
||||
self.reset()
|
||||
return
|
||||
|
||||
try:
|
||||
packet_flags, packet_index = parse_packet_sequence(packet_data)
|
||||
except ValueError:
|
||||
return
|
||||
jitter = self.jitter_stats.on_packet_received(packet)
|
||||
logging.info(
|
||||
color(
|
||||
f'<<< Received packet {packet_index}: '
|
||||
f'flags=0x{packet_flags:02X}, {len(packet)} bytes',
|
||||
f'<<< Received packet {packet.sequence}: '
|
||||
f'flags={packet.flags}, {len(data)} bytes, '
|
||||
f'jitter={jitter:.4f}',
|
||||
'green',
|
||||
)
|
||||
)
|
||||
|
||||
if packet_index != self.expected_packet_index:
|
||||
if packet.sequence != self.expected_packet_index:
|
||||
logging.info(
|
||||
color(
|
||||
f'!!! Unexpected packet, expected {self.expected_packet_index} '
|
||||
f'but received {packet_index}'
|
||||
f'but received {packet.sequence}'
|
||||
)
|
||||
)
|
||||
|
||||
self.expected_packet_index = packet_index + 1
|
||||
self.expected_packet_index = packet.sequence + 1
|
||||
|
||||
AsyncRunner.spawn(
|
||||
self.packet_io.send_packet(
|
||||
struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index)
|
||||
bytes(Packet(Packet.PacketType.ACK, packet.flags, packet.sequence))
|
||||
)
|
||||
)
|
||||
|
||||
if packet_flags & PACKET_FLAG_LAST and not self.linger:
|
||||
self.done.set()
|
||||
if packet.flags & Packet.PacketFlags.LAST:
|
||||
self.jitter_stats.show_stats()
|
||||
|
||||
if not self.linger:
|
||||
self.done.set()
|
||||
|
||||
async def run(self):
|
||||
await self.done.wait()
|
||||
@@ -942,9 +1056,12 @@ class RfcommClient(StreamedPacketIO):
|
||||
channel = await bumble.rfcomm.find_rfcomm_channel_with_uuid(
|
||||
connection, self.uuid
|
||||
)
|
||||
logging.info(color(f'@@@ Channel number = {channel}', 'cyan'))
|
||||
if channel == 0:
|
||||
logging.info(color('!!! No RFComm service with this UUID found', 'red'))
|
||||
if channel:
|
||||
logging.info(color(f'@@@ Channel number = {channel}', 'cyan'))
|
||||
else:
|
||||
logging.warning(
|
||||
color('!!! No RFComm service with this UUID found', 'red')
|
||||
)
|
||||
await connection.disconnect()
|
||||
return
|
||||
|
||||
@@ -1054,6 +1171,8 @@ class RfcommServer(StreamedPacketIO):
|
||||
if self.credits_threshold is not None:
|
||||
dlc.rx_credits_threshold = self.credits_threshold
|
||||
|
||||
self.ready.set()
|
||||
|
||||
async def drain(self):
|
||||
assert self.dlc
|
||||
await self.dlc.drain()
|
||||
@@ -1068,7 +1187,7 @@ class Central(Connection.Listener):
|
||||
transport,
|
||||
peripheral_address,
|
||||
classic,
|
||||
role_factory,
|
||||
scenario_factory,
|
||||
mode_factory,
|
||||
connection_interval,
|
||||
phy,
|
||||
@@ -1081,7 +1200,7 @@ class Central(Connection.Listener):
|
||||
self.transport = transport
|
||||
self.peripheral_address = peripheral_address
|
||||
self.classic = classic
|
||||
self.role_factory = role_factory
|
||||
self.scenario_factory = scenario_factory
|
||||
self.mode_factory = mode_factory
|
||||
self.authenticate = authenticate
|
||||
self.encrypt = encrypt or authenticate
|
||||
@@ -1134,7 +1253,7 @@ class Central(Connection.Listener):
|
||||
DEFAULT_CENTRAL_NAME, central_address, hci_source, hci_sink
|
||||
)
|
||||
mode = self.mode_factory(self.device)
|
||||
role = self.role_factory(mode)
|
||||
scenario = self.scenario_factory(mode)
|
||||
self.device.classic_enabled = self.classic
|
||||
|
||||
# Set up a pairing config factory with minimal requirements.
|
||||
@@ -1215,7 +1334,7 @@ class Central(Connection.Listener):
|
||||
|
||||
await mode.on_connection(self.connection)
|
||||
|
||||
await role.run()
|
||||
await scenario.run()
|
||||
await asyncio.sleep(DEFAULT_LINGER_TIME)
|
||||
await self.connection.disconnect()
|
||||
|
||||
@@ -1246,7 +1365,7 @@ class Peripheral(Device.Listener, Connection.Listener):
|
||||
def __init__(
|
||||
self,
|
||||
transport,
|
||||
role_factory,
|
||||
scenario_factory,
|
||||
mode_factory,
|
||||
classic,
|
||||
extended_data_length,
|
||||
@@ -1254,11 +1373,11 @@ class Peripheral(Device.Listener, Connection.Listener):
|
||||
):
|
||||
self.transport = transport
|
||||
self.classic = classic
|
||||
self.role_factory = role_factory
|
||||
self.scenario_factory = scenario_factory
|
||||
self.mode_factory = mode_factory
|
||||
self.extended_data_length = extended_data_length
|
||||
self.role_switch = role_switch
|
||||
self.role = None
|
||||
self.scenario = None
|
||||
self.mode = None
|
||||
self.device = None
|
||||
self.connection = None
|
||||
@@ -1278,7 +1397,7 @@ class Peripheral(Device.Listener, Connection.Listener):
|
||||
)
|
||||
self.device.listener = self
|
||||
self.mode = self.mode_factory(self.device)
|
||||
self.role = self.role_factory(self.mode)
|
||||
self.scenario = self.scenario_factory(self.mode)
|
||||
self.device.classic_enabled = self.classic
|
||||
|
||||
# Set up a pairing config factory with minimal requirements.
|
||||
@@ -1315,7 +1434,7 @@ class Peripheral(Device.Listener, Connection.Listener):
|
||||
print_connection(self.connection)
|
||||
|
||||
await self.mode.on_connection(self.connection)
|
||||
await self.role.run()
|
||||
await self.scenario.run()
|
||||
await asyncio.sleep(DEFAULT_LINGER_TIME)
|
||||
|
||||
def on_connection(self, connection):
|
||||
@@ -1344,7 +1463,7 @@ class Peripheral(Device.Listener, Connection.Listener):
|
||||
def on_disconnection(self, reason):
|
||||
logging.info(color(f'!!! Disconnection: reason={reason}', 'red'))
|
||||
self.connection = None
|
||||
self.role.reset()
|
||||
self.scenario.reset()
|
||||
|
||||
if self.classic:
|
||||
AsyncRunner.spawn(self.device.set_discoverable(True))
|
||||
@@ -1426,13 +1545,13 @@ def create_mode_factory(ctx, default_mode):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def create_role_factory(ctx, default_role):
|
||||
role = ctx.obj['role']
|
||||
if role is None:
|
||||
role = default_role
|
||||
def create_scenario_factory(ctx, default_scenario):
|
||||
scenario = ctx.obj['scenario']
|
||||
if scenario is None:
|
||||
scenario = default_scenario
|
||||
|
||||
def create_role(packet_io):
|
||||
if role == 'sender':
|
||||
def create_scenario(packet_io):
|
||||
if scenario == 'send':
|
||||
return Sender(
|
||||
packet_io,
|
||||
start_delay=ctx.obj['start_delay'],
|
||||
@@ -1443,10 +1562,10 @@ def create_role_factory(ctx, default_role):
|
||||
packet_count=ctx.obj['packet_count'],
|
||||
)
|
||||
|
||||
if role == 'receiver':
|
||||
if scenario == 'receive':
|
||||
return Receiver(packet_io, ctx.obj['linger'])
|
||||
|
||||
if role == 'ping':
|
||||
if scenario == 'ping':
|
||||
return Ping(
|
||||
packet_io,
|
||||
start_delay=ctx.obj['start_delay'],
|
||||
@@ -1457,12 +1576,12 @@ def create_role_factory(ctx, default_role):
|
||||
packet_count=ctx.obj['packet_count'],
|
||||
)
|
||||
|
||||
if role == 'pong':
|
||||
if scenario == 'pong':
|
||||
return Pong(packet_io, ctx.obj['linger'])
|
||||
|
||||
raise ValueError('invalid role')
|
||||
raise ValueError('invalid scenario')
|
||||
|
||||
return create_role
|
||||
return create_scenario
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -1470,7 +1589,7 @@ def create_role_factory(ctx, default_role):
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.group()
|
||||
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
|
||||
@click.option('--role', type=click.Choice(['sender', 'receiver', 'ping', 'pong']))
|
||||
@click.option('--scenario', type=click.Choice(['send', 'receive', 'ping', 'pong']))
|
||||
@click.option(
|
||||
'--mode',
|
||||
type=click.Choice(
|
||||
@@ -1488,6 +1607,7 @@ def create_role_factory(ctx, default_role):
|
||||
'--att-mtu',
|
||||
metavar='MTU',
|
||||
type=click.IntRange(23, 517),
|
||||
default=517,
|
||||
help='GATT MTU (gatt-client mode)',
|
||||
)
|
||||
@click.option(
|
||||
@@ -1503,7 +1623,7 @@ def create_role_factory(ctx, default_role):
|
||||
'--rfcomm-channel',
|
||||
type=int,
|
||||
default=DEFAULT_RFCOMM_CHANNEL,
|
||||
help='RFComm channel to use',
|
||||
help='RFComm channel to use (specify 0 for channel discovery via SDP)',
|
||||
)
|
||||
@click.option(
|
||||
'--rfcomm-uuid',
|
||||
@@ -1563,9 +1683,9 @@ def create_role_factory(ctx, default_role):
|
||||
'--packet-size',
|
||||
'-s',
|
||||
metavar='SIZE',
|
||||
type=click.IntRange(8, 8192),
|
||||
type=click.IntRange(10, 8192),
|
||||
default=500,
|
||||
help='Packet size (client or ping role)',
|
||||
help='Packet size (send or ping scenario)',
|
||||
)
|
||||
@click.option(
|
||||
'--packet-count',
|
||||
@@ -1573,7 +1693,7 @@ def create_role_factory(ctx, default_role):
|
||||
metavar='COUNT',
|
||||
type=int,
|
||||
default=10,
|
||||
help='Packet count (client or ping role)',
|
||||
help='Packet count (send or ping scenario)',
|
||||
)
|
||||
@click.option(
|
||||
'--start-delay',
|
||||
@@ -1581,7 +1701,7 @@ def create_role_factory(ctx, default_role):
|
||||
metavar='SECONDS',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Start delay (client or ping role)',
|
||||
help='Start delay (send or ping scenario)',
|
||||
)
|
||||
@click.option(
|
||||
'--repeat',
|
||||
@@ -1589,7 +1709,7 @@ def create_role_factory(ctx, default_role):
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
'Repeat the run N times (client and ping roles)'
|
||||
'Repeat the run N times (send and ping scenario)'
|
||||
'(0, which is the fault, to run just once) '
|
||||
),
|
||||
)
|
||||
@@ -1613,13 +1733,13 @@ def create_role_factory(ctx, default_role):
|
||||
@click.option(
|
||||
'--linger',
|
||||
is_flag=True,
|
||||
help="Don't exit at the end of a run (server and pong roles)",
|
||||
help="Don't exit at the end of a run (receive and pong scenarios)",
|
||||
)
|
||||
@click.pass_context
|
||||
def bench(
|
||||
ctx,
|
||||
device_config,
|
||||
role,
|
||||
scenario,
|
||||
mode,
|
||||
att_mtu,
|
||||
extended_data_length,
|
||||
@@ -1645,7 +1765,7 @@ def bench(
|
||||
):
|
||||
ctx.ensure_object(dict)
|
||||
ctx.obj['device_config'] = device_config
|
||||
ctx.obj['role'] = role
|
||||
ctx.obj['scenario'] = scenario
|
||||
ctx.obj['mode'] = mode
|
||||
ctx.obj['att_mtu'] = att_mtu
|
||||
ctx.obj['rfcomm_channel'] = rfcomm_channel
|
||||
@@ -1699,7 +1819,7 @@ def central(
|
||||
ctx, transport, peripheral_address, connection_interval, phy, authenticate, encrypt
|
||||
):
|
||||
"""Run as a central (initiates the connection)"""
|
||||
role_factory = create_role_factory(ctx, 'sender')
|
||||
scenario_factory = create_scenario_factory(ctx, 'send')
|
||||
mode_factory = create_mode_factory(ctx, 'gatt-client')
|
||||
classic = ctx.obj['classic']
|
||||
|
||||
@@ -1708,7 +1828,7 @@ def central(
|
||||
transport,
|
||||
peripheral_address,
|
||||
classic,
|
||||
role_factory,
|
||||
scenario_factory,
|
||||
mode_factory,
|
||||
connection_interval,
|
||||
phy,
|
||||
@@ -1726,13 +1846,13 @@ def central(
|
||||
@click.pass_context
|
||||
def peripheral(ctx, transport):
|
||||
"""Run as a peripheral (waits for a connection)"""
|
||||
role_factory = create_role_factory(ctx, 'receiver')
|
||||
scenario_factory = create_scenario_factory(ctx, 'receive')
|
||||
mode_factory = create_mode_factory(ctx, 'gatt-server')
|
||||
|
||||
async def run_peripheral():
|
||||
await Peripheral(
|
||||
transport,
|
||||
role_factory,
|
||||
scenario_factory,
|
||||
mode_factory,
|
||||
ctx.obj['classic'],
|
||||
ctx.obj['extended_data_length'],
|
||||
@@ -1743,7 +1863,11 @@ def peripheral(ctx, transport):
|
||||
|
||||
|
||||
def main():
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
logging.basicConfig(
|
||||
level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper(),
|
||||
format="[%(asctime)s.%(msecs)03d] %(levelname)s:%(name)s:%(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
bench()
|
||||
|
||||
|
||||
|
||||
@@ -37,6 +37,8 @@ from bumble.hci import (
|
||||
HCI_Command_Status_Event,
|
||||
HCI_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_Read_Buffer_Size_Command,
|
||||
HCI_LE_READ_BUFFER_SIZE_V2_COMMAND,
|
||||
HCI_LE_Read_Buffer_Size_V2_Command,
|
||||
HCI_READ_BD_ADDR_COMMAND,
|
||||
HCI_Read_BD_ADDR_Command,
|
||||
HCI_READ_LOCAL_NAME_COMMAND,
|
||||
@@ -75,7 +77,7 @@ async def get_classic_info(host: Host) -> None:
|
||||
if command_succeeded(response):
|
||||
print()
|
||||
print(
|
||||
color('Classic Address:', 'yellow'),
|
||||
color('Public Address:', 'yellow'),
|
||||
response.return_parameters.bd_addr.to_string(False),
|
||||
)
|
||||
|
||||
@@ -147,7 +149,7 @@ async def get_le_info(host: Host) -> None:
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_acl_flow_control_info(host: Host) -> None:
|
||||
async def get_flow_control_info(host: Host) -> None:
|
||||
print()
|
||||
|
||||
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
|
||||
@@ -160,14 +162,28 @@ async def get_acl_flow_control_info(host: Host) -> None:
|
||||
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
|
||||
)
|
||||
|
||||
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
|
||||
)
|
||||
print(
|
||||
color('LE ACL Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.total_num_le_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.le_acl_data_packet_length}',
|
||||
)
|
||||
print(
|
||||
color('LE ISO Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.total_num_iso_data_packets} '
|
||||
f'packets of size {response.return_parameters.iso_data_packet_length}',
|
||||
)
|
||||
elif host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_LE_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
print(
|
||||
color('LE ACL Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
|
||||
f'{response.return_parameters.total_num_le_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.le_acl_data_packet_length}',
|
||||
)
|
||||
|
||||
|
||||
@@ -274,8 +290,8 @@ async def async_main(latency_probes, transport):
|
||||
# Get the LE info
|
||||
await get_le_info(host)
|
||||
|
||||
# Print the ACL flow control info
|
||||
await get_acl_flow_control_info(host)
|
||||
# Print the flow control info
|
||||
await get_flow_control_info(host)
|
||||
|
||||
# Get codec info
|
||||
await get_codecs_info(host)
|
||||
|
||||
@@ -29,7 +29,9 @@ from bumble.gatt import Service
|
||||
from bumble.profiles.device_information_service import DeviceInformationServiceProxy
|
||||
from bumble.profiles.battery_service import BatteryServiceProxy
|
||||
from bumble.profiles.gap import GenericAccessServiceProxy
|
||||
from bumble.profiles.pacs import PublishedAudioCapabilitiesServiceProxy
|
||||
from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy
|
||||
from bumble.profiles.vcs import VolumeControlServiceProxy
|
||||
from bumble.transport import open_transport_or_link
|
||||
|
||||
|
||||
@@ -126,14 +128,52 @@ async def show_tmas(
|
||||
print(color('### Telephony And Media Audio Service', 'yellow'))
|
||||
|
||||
if tmas.role:
|
||||
print(
|
||||
color(' Role:', 'green'),
|
||||
await tmas.role.read_value(),
|
||||
)
|
||||
role = await tmas.role.read_value()
|
||||
print(color(' Role:', 'green'), role)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def show_pacs(pacs: PublishedAudioCapabilitiesServiceProxy) -> None:
|
||||
print(color('### Published Audio Capabilities Service', 'yellow'))
|
||||
|
||||
contexts = await pacs.available_audio_contexts.read_value()
|
||||
print(color(' Available Audio Contexts:', 'green'), contexts)
|
||||
|
||||
contexts = await pacs.supported_audio_contexts.read_value()
|
||||
print(color(' Supported Audio Contexts:', 'green'), contexts)
|
||||
|
||||
if pacs.sink_pac:
|
||||
pac = await pacs.sink_pac.read_value()
|
||||
print(color(' Sink PAC: ', 'green'), pac)
|
||||
|
||||
if pacs.sink_audio_locations:
|
||||
audio_locations = await pacs.sink_audio_locations.read_value()
|
||||
print(color(' Sink Audio Locations: ', 'green'), audio_locations)
|
||||
|
||||
if pacs.source_pac:
|
||||
pac = await pacs.source_pac.read_value()
|
||||
print(color(' Source PAC: ', 'green'), pac)
|
||||
|
||||
if pacs.source_audio_locations:
|
||||
audio_locations = await pacs.source_audio_locations.read_value()
|
||||
print(color(' Source Audio Locations: ', 'green'), audio_locations)
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def show_vcs(vcs: VolumeControlServiceProxy) -> None:
|
||||
print(color('### Volume Control Service', 'yellow'))
|
||||
|
||||
volume_state = await vcs.volume_state.read_value()
|
||||
print(color(' Volume State:', 'green'), volume_state)
|
||||
|
||||
volume_flags = await vcs.volume_flags.read_value()
|
||||
print(color(' Volume Flags:', 'green'), volume_flags)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
|
||||
try:
|
||||
@@ -161,6 +201,12 @@ async def show_device_info(peer, done: Optional[asyncio.Future]) -> None:
|
||||
if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy):
|
||||
await try_show(show_tmas, tmas)
|
||||
|
||||
if pacs := peer.create_service_proxy(PublishedAudioCapabilitiesServiceProxy):
|
||||
await try_show(show_pacs, pacs)
|
||||
|
||||
if vcs := peer.create_service_proxy(VolumeControlServiceProxy):
|
||||
await try_show(show_vcs, vcs)
|
||||
|
||||
if done is not None:
|
||||
done.set_result(None)
|
||||
except asyncio.CancelledError:
|
||||
|
||||
@@ -83,7 +83,7 @@ async def async_main():
|
||||
return_parameters=bytes([hci.HCI_SUCCESS]),
|
||||
)
|
||||
# Return a packet with 'respond to sender' set to True
|
||||
return (response.to_bytes(), True)
|
||||
return (bytes(response), True)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -16,23 +16,22 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import enum
|
||||
import functools
|
||||
from importlib import resources
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import Optional, List, cast
|
||||
import weakref
|
||||
import struct
|
||||
import wave
|
||||
|
||||
import ctypes
|
||||
import wasmtime
|
||||
import wasmtime.loader
|
||||
import liblc3 # type: ignore
|
||||
try:
|
||||
import lc3 # type: ignore # pylint: disable=E0401
|
||||
except ImportError as e:
|
||||
raise ImportError("Try `python -m pip install \".[lc3]\"`.") from e
|
||||
|
||||
import click
|
||||
import aiohttp.web
|
||||
@@ -40,11 +39,12 @@ import aiohttp.web
|
||||
import bumble
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.colors import color
|
||||
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
|
||||
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters, CisLink
|
||||
from bumble.transport import open_transport
|
||||
from bumble.profiles import ascs, bap, pacs
|
||||
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -54,6 +54,7 @@ logger = logging.getLogger(__name__)
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
DEFAULT_UI_PORT = 7654
|
||||
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
|
||||
|
||||
|
||||
def _sink_pac_record() -> pacs.PacRecord:
|
||||
@@ -100,153 +101,8 @@ def _source_pac_record() -> pacs.PacRecord:
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# WASM - liblc3
|
||||
# -----------------------------------------------------------------------------
|
||||
store = wasmtime.loader.store
|
||||
_memory = cast(wasmtime.Memory, liblc3.memory)
|
||||
STACK_POINTER = _memory.data_len(store)
|
||||
_memory.grow(store, 1)
|
||||
# Mapping wasmtime memory to linear address
|
||||
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
|
||||
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
|
||||
)
|
||||
|
||||
|
||||
class Liblc3PcmFormat(enum.IntEnum):
|
||||
S16 = 0
|
||||
S24 = 1
|
||||
S24_3LE = 2
|
||||
FLOAT = 3
|
||||
|
||||
|
||||
MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
|
||||
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)
|
||||
|
||||
DECODER_STACK_POINTER = STACK_POINTER
|
||||
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
|
||||
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
|
||||
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
|
||||
DEFAULT_PCM_SAMPLE_RATE = 48000
|
||||
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
|
||||
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
|
||||
|
||||
|
||||
encoders: List[int] = []
|
||||
decoders: List[int] = []
|
||||
|
||||
|
||||
def setup_encoders(
|
||||
sample_rate_hz: int, frame_duration_us: int, num_channels: int
|
||||
) -> None:
|
||||
logger.info(
|
||||
f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
|
||||
)
|
||||
encoders[:num_channels] = [
|
||||
liblc3.lc3_setup_encoder(
|
||||
frame_duration_us,
|
||||
sample_rate_hz,
|
||||
DEFAULT_PCM_SAMPLE_RATE, # Input sample rate
|
||||
ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
|
||||
)
|
||||
for i in range(num_channels)
|
||||
]
|
||||
|
||||
|
||||
def setup_decoders(
|
||||
sample_rate_hz: int, frame_duration_us: int, num_channels: int
|
||||
) -> None:
|
||||
logger.info(
|
||||
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
|
||||
)
|
||||
decoders[:num_channels] = [
|
||||
liblc3.lc3_setup_decoder(
|
||||
frame_duration_us,
|
||||
sample_rate_hz,
|
||||
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
|
||||
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
|
||||
)
|
||||
for i in range(num_channels)
|
||||
]
|
||||
|
||||
|
||||
def decode(
|
||||
frame_duration_us: int,
|
||||
num_channels: int,
|
||||
input_bytes: bytes,
|
||||
) -> bytes:
|
||||
if not input_bytes:
|
||||
return b''
|
||||
|
||||
input_buffer_offset = DECODE_BUFFER_STACK_POINTER
|
||||
input_buffer_size = len(input_bytes)
|
||||
input_bytes_per_frame = input_buffer_size // num_channels
|
||||
|
||||
# Copy into wasm
|
||||
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
|
||||
|
||||
output_buffer_offset = input_buffer_offset + input_buffer_size
|
||||
output_buffer_size = (
|
||||
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
|
||||
* DEFAULT_PCM_BYTES_PER_SAMPLE
|
||||
* num_channels
|
||||
)
|
||||
|
||||
for i in range(num_channels):
|
||||
res = liblc3.lc3_decode(
|
||||
decoders[i],
|
||||
input_buffer_offset + input_bytes_per_frame * i,
|
||||
input_bytes_per_frame,
|
||||
DEFAULT_PCM_FORMAT,
|
||||
output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
|
||||
num_channels, # Stride
|
||||
)
|
||||
|
||||
if res != 0:
|
||||
logging.error(f"Parsing failed, res={res}")
|
||||
|
||||
# Extract decoded data from the output buffer
|
||||
return bytes(
|
||||
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
|
||||
)
|
||||
|
||||
|
||||
def encode(
|
||||
sdu_length: int,
|
||||
num_channels: int,
|
||||
stride: int,
|
||||
input_bytes: bytes,
|
||||
) -> bytes:
|
||||
if not input_bytes:
|
||||
return b''
|
||||
|
||||
input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
|
||||
input_buffer_size = len(input_bytes)
|
||||
|
||||
# Copy into wasm
|
||||
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
|
||||
|
||||
output_buffer_offset = input_buffer_offset + input_buffer_size
|
||||
output_buffer_size = sdu_length
|
||||
output_frame_size = output_buffer_size // num_channels
|
||||
|
||||
for i in range(num_channels):
|
||||
res = liblc3.lc3_encode(
|
||||
encoders[i],
|
||||
DEFAULT_PCM_FORMAT,
|
||||
input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
|
||||
stride,
|
||||
output_frame_size,
|
||||
output_buffer_offset + output_frame_size * i,
|
||||
)
|
||||
|
||||
if res != 0:
|
||||
logging.error(f"Parsing failed, res={res}")
|
||||
|
||||
# Extract decoded data from the output buffer
|
||||
return bytes(
|
||||
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
|
||||
)
|
||||
decoder: lc3.Decoder | None = None
|
||||
encoding_config: bap.CodecSpecificConfiguration | None = None
|
||||
|
||||
|
||||
async def lc3_source_task(
|
||||
@@ -254,44 +110,49 @@ async def lc3_source_task(
|
||||
sdu_length: int,
|
||||
frame_duration_us: int,
|
||||
device: Device,
|
||||
cis_handle: int,
|
||||
cis_link: CisLink,
|
||||
) -> None:
|
||||
with open(filename, 'rb') as f:
|
||||
header = f.read(44)
|
||||
assert header[8:12] == b'WAVE'
|
||||
logger.info(
|
||||
"lc3_source_task filename=%s, sdu_length=%d, frame_duration=%.1f",
|
||||
filename,
|
||||
sdu_length,
|
||||
frame_duration_us / 1000,
|
||||
)
|
||||
with wave.open(filename, 'rb') as wav:
|
||||
bits_per_sample = wav.getsampwidth() * 8
|
||||
|
||||
pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = (
|
||||
struct.unpack("<HIIHH", header[22:36])
|
||||
)
|
||||
assert pcm_sample_rate == DEFAULT_PCM_SAMPLE_RATE
|
||||
assert bits_per_sample == DEFAULT_PCM_BYTES_PER_SAMPLE * 8
|
||||
|
||||
frame_bytes = (
|
||||
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
|
||||
* DEFAULT_PCM_BYTES_PER_SAMPLE
|
||||
)
|
||||
packet_sequence_number = 0
|
||||
encoder: lc3.Encoder | None = None
|
||||
|
||||
while True:
|
||||
next_round = datetime.datetime.now() + datetime.timedelta(
|
||||
microseconds=frame_duration_us
|
||||
)
|
||||
pcm_data = f.read(frame_bytes)
|
||||
sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data)
|
||||
if not encoder:
|
||||
if (
|
||||
encoding_config
|
||||
and (frame_duration := encoding_config.frame_duration)
|
||||
and (sampling_frequency := encoding_config.sampling_frequency)
|
||||
and (
|
||||
audio_channel_allocation := encoding_config.audio_channel_allocation
|
||||
)
|
||||
):
|
||||
logger.info("Use %s", encoding_config)
|
||||
encoder = lc3.Encoder(
|
||||
frame_duration_us=frame_duration.us,
|
||||
sample_rate_hz=sampling_frequency.hz,
|
||||
num_channels=audio_channel_allocation.channel_count,
|
||||
input_sample_rate_hz=wav.getframerate(),
|
||||
)
|
||||
else:
|
||||
sdu = encoder.encode(
|
||||
pcm=wav.readframes(encoder.get_frame_samples()),
|
||||
num_bytes=sdu_length,
|
||||
bit_depth=bits_per_sample,
|
||||
)
|
||||
cis_link.write(sdu)
|
||||
|
||||
iso_packet = HCI_IsoDataPacket(
|
||||
connection_handle=cis_handle,
|
||||
data_total_length=sdu_length + 4,
|
||||
packet_sequence_number=packet_sequence_number,
|
||||
pb_flag=0b10,
|
||||
packet_status_flag=0,
|
||||
iso_sdu_length=sdu_length,
|
||||
iso_sdu_fragment=sdu,
|
||||
)
|
||||
device.host.send_hci_packet(iso_packet)
|
||||
packet_sequence_number += 1
|
||||
sleep_time = next_round - datetime.datetime.now()
|
||||
await asyncio.sleep(sleep_time.total_seconds())
|
||||
await asyncio.sleep(sleep_time.total_seconds() * 0.9)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -410,7 +271,7 @@ class Speaker:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_config_path: Optional[str],
|
||||
device_config_path: str | None,
|
||||
ui_port: int,
|
||||
transport: str,
|
||||
lc3_input_file_path: str,
|
||||
@@ -437,6 +298,7 @@ class Speaker:
|
||||
advertising_interval_min=25,
|
||||
advertising_interval_max=25,
|
||||
address=Address('F1:F2:F3:F4:F5:F6'),
|
||||
identity_address_type=Address.RANDOM_DEVICE_ADDRESS,
|
||||
)
|
||||
|
||||
device_config.le_enabled = True
|
||||
@@ -486,20 +348,31 @@ class Speaker:
|
||||
|
||||
def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine):
|
||||
codec_config = ase.codec_specific_configuration
|
||||
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
pcm = decode(
|
||||
codec_config.frame_duration.us,
|
||||
codec_config.audio_channel_allocation.channel_count,
|
||||
pdu.iso_sdu_fragment,
|
||||
if (
|
||||
not isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
or codec_config.frame_duration is None
|
||||
or codec_config.audio_channel_allocation is None
|
||||
or decoder is None
|
||||
or not pdu.iso_sdu_fragment
|
||||
):
|
||||
return
|
||||
pcm = decoder.decode(
|
||||
pdu.iso_sdu_fragment, bit_depth=DEFAULT_PCM_BYTES_PER_SAMPLE * 8
|
||||
)
|
||||
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
|
||||
|
||||
def on_ase_state_change(ase: ascs.AseStateMachine) -> None:
|
||||
codec_config = ase.codec_specific_configuration
|
||||
if ase.state == ascs.AseStateMachine.State.STREAMING:
|
||||
codec_config = ase.codec_specific_configuration
|
||||
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
assert ase.cis_link
|
||||
if ase.role == ascs.AudioRole.SOURCE:
|
||||
if (
|
||||
not isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
or ase.cis_link is None
|
||||
or codec_config.octets_per_codec_frame is None
|
||||
or codec_config.frame_duration is None
|
||||
or codec_config.codec_frames_per_sdu is None
|
||||
):
|
||||
return
|
||||
ase.cis_link.abort_on(
|
||||
'disconnection',
|
||||
lc3_source_task(
|
||||
@@ -510,25 +383,30 @@ class Speaker:
|
||||
),
|
||||
frame_duration_us=codec_config.frame_duration.us,
|
||||
device=self.device,
|
||||
cis_handle=ase.cis_link.handle,
|
||||
cis_link=ase.cis_link,
|
||||
),
|
||||
)
|
||||
else:
|
||||
if not ase.cis_link:
|
||||
return
|
||||
ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
|
||||
elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED:
|
||||
codec_config = ase.codec_specific_configuration
|
||||
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
if (
|
||||
not isinstance(codec_config, bap.CodecSpecificConfiguration)
|
||||
or codec_config.sampling_frequency is None
|
||||
or codec_config.frame_duration is None
|
||||
or codec_config.audio_channel_allocation is None
|
||||
):
|
||||
return
|
||||
if ase.role == ascs.AudioRole.SOURCE:
|
||||
setup_encoders(
|
||||
codec_config.sampling_frequency.hz,
|
||||
codec_config.frame_duration.us,
|
||||
codec_config.audio_channel_allocation.channel_count,
|
||||
)
|
||||
global encoding_config
|
||||
encoding_config = codec_config
|
||||
else:
|
||||
setup_decoders(
|
||||
codec_config.sampling_frequency.hz,
|
||||
codec_config.frame_duration.us,
|
||||
codec_config.audio_channel_allocation.channel_count,
|
||||
global decoder
|
||||
decoder = lc3.Decoder(
|
||||
frame_duration_us=codec_config.frame_duration.us,
|
||||
sample_rate_hz=codec_config.sampling_frequency.hz,
|
||||
num_channels=codec_config.audio_channel_allocation.channel_count,
|
||||
)
|
||||
|
||||
for ase in ascs_service.ase_state_machines.values():
|
||||
@@ -567,7 +445,7 @@ def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) ->
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
speaker()
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
21
apps/pair.py
21
apps/pair.py
@@ -373,7 +373,9 @@ async def pair(
|
||||
shared_data = (
|
||||
None
|
||||
if oob == '-'
|
||||
else OobData.from_ad(AdvertisingData.from_bytes(bytes.fromhex(oob)))
|
||||
else OobData.from_ad(
|
||||
AdvertisingData.from_bytes(bytes.fromhex(oob))
|
||||
).shared_data
|
||||
)
|
||||
legacy_context = OobLegacyContext()
|
||||
oob_contexts = PairingConfig.OobConfig(
|
||||
@@ -381,16 +383,19 @@ async def pair(
|
||||
peer_data=shared_data,
|
||||
legacy_context=legacy_context,
|
||||
)
|
||||
oob_data = OobData(
|
||||
address=device.random_address,
|
||||
shared_data=shared_data,
|
||||
legacy_context=legacy_context,
|
||||
)
|
||||
print(color('@@@-----------------------------------', 'yellow'))
|
||||
print(color('@@@ OOB Data:', 'yellow'))
|
||||
print(color(f'@@@ {our_oob_context.share()}', 'yellow'))
|
||||
if shared_data is None:
|
||||
oob_data = OobData(
|
||||
address=device.random_address, shared_data=our_oob_context.share()
|
||||
)
|
||||
print(
|
||||
color(
|
||||
f'@@@ SHARE: {bytes(oob_data.to_ad()).hex()}',
|
||||
'yellow',
|
||||
)
|
||||
)
|
||||
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
|
||||
print(color(f'@@@ HEX: ({bytes(oob_data.to_ad()).hex()})', 'yellow'))
|
||||
print(color('@@@-----------------------------------', 'yellow'))
|
||||
else:
|
||||
oob_contexts = None
|
||||
|
||||
@@ -237,6 +237,7 @@ class ClientBridge:
|
||||
address: str,
|
||||
tcp_host: str,
|
||||
tcp_port: int,
|
||||
authenticate: bool,
|
||||
encrypt: bool,
|
||||
):
|
||||
self.channel = channel
|
||||
@@ -245,6 +246,7 @@ class ClientBridge:
|
||||
self.address = address
|
||||
self.tcp_host = tcp_host
|
||||
self.tcp_port = tcp_port
|
||||
self.authenticate = authenticate
|
||||
self.encrypt = encrypt
|
||||
self.device: Optional[Device] = None
|
||||
self.connection: Optional[Connection] = None
|
||||
@@ -274,6 +276,11 @@ class ClientBridge:
|
||||
print(color(f"@@@ Bluetooth connection: {self.connection}", "blue"))
|
||||
self.connection.on("disconnection", self.on_disconnection)
|
||||
|
||||
if self.authenticate:
|
||||
print(color("@@@ Authenticating Bluetooth connection", "blue"))
|
||||
await self.connection.authenticate()
|
||||
print(color("@@@ Bluetooth connection authenticated", "blue"))
|
||||
|
||||
if self.encrypt:
|
||||
print(color("@@@ Encrypting Bluetooth connection", "blue"))
|
||||
await self.connection.encrypt()
|
||||
@@ -491,8 +498,9 @@ def server(context, tcp_host, tcp_port):
|
||||
@click.argument("bluetooth-address")
|
||||
@click.option("--tcp-host", help="TCP host", default="_")
|
||||
@click.option("--tcp-port", help="TCP port", default=DEFAULT_CLIENT_TCP_PORT)
|
||||
@click.option("--authenticate", is_flag=True, help="Authenticate the connection")
|
||||
@click.option("--encrypt", is_flag=True, help="Encrypt the connection")
|
||||
def client(context, bluetooth_address, tcp_host, tcp_port, encrypt):
|
||||
def client(context, bluetooth_address, tcp_host, tcp_port, authenticate, encrypt):
|
||||
bridge = ClientBridge(
|
||||
context.obj["channel"],
|
||||
context.obj["uuid"],
|
||||
@@ -500,6 +508,7 @@ def client(context, bluetooth_address, tcp_host, tcp_port, encrypt):
|
||||
bluetooth_address,
|
||||
tcp_host,
|
||||
tcp_port,
|
||||
authenticate,
|
||||
encrypt,
|
||||
)
|
||||
asyncio.run(run(context.obj["device_config"], context.obj["hci_transport"], bridge))
|
||||
|
||||
12
apps/show.py
12
apps/show.py
@@ -144,18 +144,18 @@ class Printer:
|
||||
help='Format of the input file',
|
||||
)
|
||||
@click.option(
|
||||
'--vendors',
|
||||
'--vendor',
|
||||
type=click.Choice(['android', 'zephyr']),
|
||||
multiple=True,
|
||||
help='Support vendor-specific commands (list one or more)',
|
||||
)
|
||||
@click.argument('filename')
|
||||
# pylint: disable=redefined-builtin
|
||||
def main(format, vendors, filename):
|
||||
for vendor in vendors:
|
||||
if vendor == 'android':
|
||||
def main(format, vendor, filename):
|
||||
for vendor_name in vendor:
|
||||
if vendor_name == 'android':
|
||||
import bumble.vendor.android.hci
|
||||
elif vendor == 'zephyr':
|
||||
elif vendor_name == 'zephyr':
|
||||
import bumble.vendor.zephyr.hci
|
||||
|
||||
input = open(filename, 'rb')
|
||||
@@ -180,7 +180,7 @@ def main(format, vendors, filename):
|
||||
else:
|
||||
printer.print(color("[TRUNCATED]", "red"))
|
||||
except Exception as error:
|
||||
logger.exception()
|
||||
logger.exception('')
|
||||
print(color(f'!!! {error}', 'red'))
|
||||
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ if TYPE_CHECKING:
|
||||
# pylint: disable=line-too-long
|
||||
|
||||
ATT_CID = 0x04
|
||||
ATT_PSM = 0x001F
|
||||
|
||||
ATT_ERROR_RESPONSE = 0x01
|
||||
ATT_EXCHANGE_MTU_REQUEST = 0x02
|
||||
@@ -291,9 +292,6 @@ class ATT_PDU:
|
||||
def init_from_bytes(self, pdu, offset):
|
||||
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
||||
|
||||
def to_bytes(self):
|
||||
return self.pdu
|
||||
|
||||
@property
|
||||
def is_command(self):
|
||||
return ((self.op_code >> 6) & 1) == 1
|
||||
@@ -303,7 +301,7 @@ class ATT_PDU:
|
||||
return ((self.op_code >> 7) & 1) == 1
|
||||
|
||||
def __bytes__(self):
|
||||
return self.to_bytes()
|
||||
return self.pdu
|
||||
|
||||
def __str__(self):
|
||||
result = color(self.name, 'yellow')
|
||||
@@ -759,13 +757,13 @@ class AttributeValue:
|
||||
def __init__(
|
||||
self,
|
||||
read: Union[
|
||||
Callable[[Optional[Connection]], bytes],
|
||||
Callable[[Optional[Connection]], Awaitable[bytes]],
|
||||
Callable[[Optional[Connection]], Any],
|
||||
Callable[[Optional[Connection]], Awaitable[Any]],
|
||||
None,
|
||||
] = None,
|
||||
write: Union[
|
||||
Callable[[Optional[Connection], bytes], None],
|
||||
Callable[[Optional[Connection], bytes], Awaitable[None]],
|
||||
Callable[[Optional[Connection], Any], None],
|
||||
Callable[[Optional[Connection], Any], Awaitable[None]],
|
||||
None,
|
||||
] = None,
|
||||
):
|
||||
@@ -824,13 +822,13 @@ class Attribute(EventEmitter):
|
||||
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
|
||||
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
|
||||
|
||||
value: Union[bytes, AttributeValue]
|
||||
value: Any
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attribute_type: Union[str, bytes, UUID],
|
||||
permissions: Union[str, Attribute.Permissions],
|
||||
value: Union[str, bytes, AttributeValue] = b'',
|
||||
value: Any = b'',
|
||||
) -> None:
|
||||
EventEmitter.__init__(self)
|
||||
self.handle = 0
|
||||
@@ -848,11 +846,7 @@ class Attribute(EventEmitter):
|
||||
else:
|
||||
self.type = attribute_type
|
||||
|
||||
# Convert the value to a byte array
|
||||
if isinstance(value, str):
|
||||
self.value = bytes(value, 'utf-8')
|
||||
else:
|
||||
self.value = value
|
||||
self.value = value
|
||||
|
||||
def encode_value(self, value: Any) -> bytes:
|
||||
return value
|
||||
@@ -895,6 +889,8 @@ class Attribute(EventEmitter):
|
||||
else:
|
||||
value = self.value
|
||||
|
||||
self.emit('read', connection, value)
|
||||
|
||||
return self.encode_value(value)
|
||||
|
||||
async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,6 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
setup()
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
553
bumble/audio/io.py
Normal file
553
bumble/audio/io.py
Normal file
@@ -0,0 +1,553 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import abc
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import (
|
||||
AsyncGenerator,
|
||||
BinaryIO,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
import sys
|
||||
import wave
|
||||
|
||||
from bumble.colors import color
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import sounddevice # type: ignore[import-untyped]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class PcmFormat:
|
||||
class Endianness(enum.Enum):
|
||||
LITTLE = 0
|
||||
BIG = 1
|
||||
|
||||
class SampleType(enum.Enum):
|
||||
FLOAT32 = 0
|
||||
INT16 = 1
|
||||
|
||||
endianness: Endianness
|
||||
sample_type: SampleType
|
||||
sample_rate: int
|
||||
channels: int
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, format_str: str) -> PcmFormat:
|
||||
endianness = cls.Endianness.LITTLE # Others not yet supported.
|
||||
sample_type_str, sample_rate_str, channels_str = format_str.split(',')
|
||||
if sample_type_str == 'int16le':
|
||||
sample_type = cls.SampleType.INT16
|
||||
elif sample_type_str == 'float32le':
|
||||
sample_type = cls.SampleType.FLOAT32
|
||||
else:
|
||||
raise ValueError(f'sample type {sample_type_str} not supported')
|
||||
sample_rate = int(sample_rate_str)
|
||||
channels = int(channels_str)
|
||||
|
||||
return cls(endianness, sample_type, sample_rate, channels)
|
||||
|
||||
@property
|
||||
def bytes_per_sample(self) -> int:
|
||||
return 2 if self.sample_type == self.SampleType.INT16 else 4
|
||||
|
||||
|
||||
def check_audio_output(output: str) -> bool:
|
||||
if output == 'device' or output.startswith('device:'):
|
||||
try:
|
||||
import sounddevice
|
||||
except ImportError as exc:
|
||||
raise ValueError(
|
||||
'audio output not available (sounddevice python module not installed)'
|
||||
) from exc
|
||||
except OSError as exc:
|
||||
raise ValueError(
|
||||
'audio output not available '
|
||||
'(sounddevice python module failed to load: '
|
||||
f'{exc})'
|
||||
) from exc
|
||||
|
||||
if output == 'device':
|
||||
# Default device
|
||||
return True
|
||||
|
||||
# Specific device
|
||||
device = output[7:]
|
||||
if device == '?':
|
||||
print(color('Audio Devices:', 'yellow'))
|
||||
for device_info in [
|
||||
device_info
|
||||
for device_info in sounddevice.query_devices()
|
||||
if device_info['max_output_channels'] > 0
|
||||
]:
|
||||
device_index = device_info['index']
|
||||
is_default = (
|
||||
color(' [default]', 'green')
|
||||
if sounddevice.default.device[1] == device_index
|
||||
else ''
|
||||
)
|
||||
print(
|
||||
f'{color(device_index, "cyan")}: {device_info["name"]}{is_default}'
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
device_info = sounddevice.query_devices(int(device))
|
||||
except sounddevice.PortAudioError as exc:
|
||||
raise ValueError('No such audio device') from exc
|
||||
|
||||
if device_info['max_output_channels'] < 1:
|
||||
raise ValueError(
|
||||
f'Device {device} ({device_info["name"]}) does not have an output'
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def create_audio_output(output: str) -> AudioOutput:
|
||||
if output == 'stdout':
|
||||
return StreamAudioOutput(sys.stdout.buffer)
|
||||
|
||||
if output == 'device' or output.startswith('device:'):
|
||||
device_name = '' if output == 'device' else output[7:]
|
||||
return SoundDeviceAudioOutput(device_name)
|
||||
|
||||
if output == 'ffplay':
|
||||
return SubprocessAudioOutput(
|
||||
command=(
|
||||
'ffplay -probesize 32 -fflags nobuffer -analyzeduration 0 '
|
||||
'-ar {sample_rate} '
|
||||
'-ch_layout {channel_layout} '
|
||||
'-f f32le pipe:0'
|
||||
)
|
||||
)
|
||||
|
||||
if output.startswith('file:'):
|
||||
return FileAudioOutput(output[5:])
|
||||
|
||||
raise ValueError('unsupported audio output')
|
||||
|
||||
|
||||
class AudioOutput(abc.ABC):
|
||||
"""Audio output to which PCM samples can be written."""
|
||||
|
||||
async def open(self, pcm_format: PcmFormat) -> None:
|
||||
"""Start the output."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def write(self, pcm_samples: bytes) -> None:
|
||||
"""Write PCM samples. Must not block."""
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close the output."""
|
||||
|
||||
|
||||
class ThreadedAudioOutput(AudioOutput):
|
||||
"""Base class for AudioOutput classes that may need to call blocking functions.
|
||||
|
||||
The actual writing is performed in a thread, so as to ensure that calling write()
|
||||
does not block the caller.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._thread_pool = ThreadPoolExecutor(1)
|
||||
self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self._write_task = asyncio.create_task(self._write_loop())
|
||||
|
||||
async def _write_loop(self) -> None:
|
||||
while True:
|
||||
pcm_samples = await self._pcm_samples.get()
|
||||
await asyncio.get_running_loop().run_in_executor(
|
||||
self._thread_pool, self._write, pcm_samples
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _write(self, pcm_samples: bytes) -> None:
|
||||
"""This method does the actual writing and can block."""
|
||||
|
||||
def write(self, pcm_samples: bytes) -> None:
|
||||
self._pcm_samples.put_nowait(pcm_samples)
|
||||
|
||||
def _close(self) -> None:
|
||||
"""This method does the actual closing and can block."""
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await asyncio.get_running_loop().run_in_executor(self._thread_pool, self._close)
|
||||
self._write_task.cancel()
|
||||
self._thread_pool.shutdown()
|
||||
|
||||
|
||||
class SoundDeviceAudioOutput(ThreadedAudioOutput):
|
||||
def __init__(self, device_name: str) -> None:
|
||||
super().__init__()
|
||||
self._device = int(device_name) if device_name else None
|
||||
self._stream: sounddevice.RawOutputStream | None = None
|
||||
|
||||
async def open(self, pcm_format: PcmFormat) -> None:
|
||||
import sounddevice # pylint: disable=import-error
|
||||
|
||||
self._stream = sounddevice.RawOutputStream(
|
||||
samplerate=pcm_format.sample_rate,
|
||||
device=self._device,
|
||||
channels=pcm_format.channels,
|
||||
dtype='float32',
|
||||
)
|
||||
self._stream.start()
|
||||
|
||||
def _write(self, pcm_samples: bytes) -> None:
|
||||
if self._stream is None:
|
||||
return
|
||||
|
||||
try:
|
||||
self._stream.write(pcm_samples)
|
||||
except Exception as error:
|
||||
print(f'Sound device error: {error}')
|
||||
raise
|
||||
|
||||
def _close(self):
|
||||
self._stream.stop()
|
||||
self._stream = None
|
||||
|
||||
|
||||
class StreamAudioOutput(ThreadedAudioOutput):
|
||||
"""AudioOutput where PCM samples are written to a stream that may block."""
|
||||
|
||||
def __init__(self, stream: BinaryIO) -> None:
|
||||
super().__init__()
|
||||
self._stream = stream
|
||||
|
||||
def _write(self, pcm_samples: bytes) -> None:
|
||||
self._stream.write(pcm_samples)
|
||||
self._stream.flush()
|
||||
|
||||
|
||||
class FileAudioOutput(StreamAudioOutput):
|
||||
"""AudioOutput where PCM samples are written to a file."""
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
self._file = open(filename, "wb")
|
||||
super().__init__(self._file)
|
||||
|
||||
async def shutdown(self):
|
||||
self._file.close()
|
||||
return await super().shutdown()
|
||||
|
||||
|
||||
class SubprocessAudioOutput(AudioOutput):
|
||||
"""AudioOutput where audio samples are written to a subprocess via stdin."""
|
||||
|
||||
def __init__(self, command: str) -> None:
|
||||
self._command = command
|
||||
self._subprocess: asyncio.subprocess.Process | None
|
||||
|
||||
async def open(self, pcm_format: PcmFormat) -> None:
|
||||
if pcm_format.channels == 1:
|
||||
channel_layout = 'mono'
|
||||
elif pcm_format.channels == 2:
|
||||
channel_layout = 'stereo'
|
||||
else:
|
||||
raise ValueError(f'{pcm_format.channels} channels not supported')
|
||||
|
||||
command = self._command.format(
|
||||
sample_rate=pcm_format.sample_rate, channel_layout=channel_layout
|
||||
)
|
||||
self._subprocess = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
def write(self, pcm_samples: bytes) -> None:
|
||||
if self._subprocess is None or self._subprocess.stdin is None:
|
||||
return
|
||||
|
||||
self._subprocess.stdin.write(pcm_samples)
|
||||
|
||||
async def aclose(self):
|
||||
if self._subprocess:
|
||||
self._subprocess.terminate()
|
||||
|
||||
|
||||
def check_audio_input(input: str) -> bool:
|
||||
if input == 'device' or input.startswith('device:'):
|
||||
try:
|
||||
import sounddevice # pylint: disable=import-error
|
||||
except ImportError as exc:
|
||||
raise ValueError(
|
||||
'audio input not available (sounddevice python module not installed)'
|
||||
) from exc
|
||||
except OSError as exc:
|
||||
raise ValueError(
|
||||
'audio input not available '
|
||||
'(sounddevice python module failed to load: '
|
||||
f'{exc})'
|
||||
) from exc
|
||||
|
||||
if input == 'device':
|
||||
# Default device
|
||||
return True
|
||||
|
||||
# Specific device
|
||||
device = input[7:]
|
||||
if device == '?':
|
||||
print(color('Audio Devices:', 'yellow'))
|
||||
for device_info in [
|
||||
device_info
|
||||
for device_info in sounddevice.query_devices()
|
||||
if device_info['max_input_channels'] > 0
|
||||
]:
|
||||
device_index = device_info["index"]
|
||||
is_mono = device_info['max_input_channels'] == 1
|
||||
max_channels = color(f'[{"mono" if is_mono else "stereo"}]', 'cyan')
|
||||
is_default = (
|
||||
color(' [default]', 'green')
|
||||
if sounddevice.default.device[0] == device_index
|
||||
else ''
|
||||
)
|
||||
print(
|
||||
f'{color(device_index, "cyan")}: {device_info["name"]}'
|
||||
f' {max_channels}{is_default}'
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
device_info = sounddevice.query_devices(int(device))
|
||||
except sounddevice.PortAudioError as exc:
|
||||
raise ValueError('No such audio device') from exc
|
||||
|
||||
if device_info['max_input_channels'] < 1:
|
||||
raise ValueError(
|
||||
f'Device {device} ({device_info["name"]}) does not have an input'
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def create_audio_input(input: str, input_format: str) -> AudioInput:
|
||||
pcm_format: PcmFormat | None
|
||||
if input_format == 'auto':
|
||||
pcm_format = None
|
||||
else:
|
||||
pcm_format = PcmFormat.from_str(input_format)
|
||||
|
||||
if input == 'stdin':
|
||||
if not pcm_format:
|
||||
raise ValueError('input format details required for stdin')
|
||||
return StreamAudioInput(sys.stdin.buffer, pcm_format)
|
||||
|
||||
if input == 'device' or input.startswith('device:'):
|
||||
if not pcm_format:
|
||||
raise ValueError('input format details required for device')
|
||||
device_name = '' if input == 'device' else input[7:]
|
||||
return SoundDeviceAudioInput(device_name, pcm_format)
|
||||
|
||||
# If there's no file: prefix, check if we can assume it is a file.
|
||||
if pathlib.Path(input).is_file():
|
||||
input = 'file:' + input
|
||||
|
||||
if input.startswith('file:'):
|
||||
filename = input[5:]
|
||||
if filename.endswith('.wav'):
|
||||
if input_format != 'auto':
|
||||
raise ValueError(".wav file only supported with 'auto' format")
|
||||
return WaveAudioInput(filename)
|
||||
|
||||
if pcm_format is None:
|
||||
raise ValueError('input format details required for raw PCM files')
|
||||
return FileAudioInput(filename, pcm_format)
|
||||
|
||||
raise ValueError('input not supported')
|
||||
|
||||
|
||||
class AudioInput(abc.ABC):
|
||||
"""Audio input that produces PCM samples."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def open(self) -> PcmFormat:
|
||||
"""Open the input."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def frames(self, frame_size: int) -> AsyncGenerator[bytes]:
|
||||
"""Generate one frame of PCM samples. Must not block."""
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Close the input."""
|
||||
|
||||
|
||||
class ThreadedAudioInput(AudioInput):
|
||||
"""Base class for AudioInput implementation where reading samples may block."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._thread_pool = ThreadPoolExecutor(1)
|
||||
self._pcm_samples: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _read(self, frame_size: int) -> bytes:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _open(self) -> PcmFormat:
|
||||
pass
|
||||
|
||||
def _close(self) -> None:
|
||||
pass
|
||||
|
||||
async def open(self) -> PcmFormat:
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
self._thread_pool, self._open
|
||||
)
|
||||
|
||||
async def frames(self, frame_size: int) -> AsyncGenerator[bytes]:
|
||||
while pcm_sample := await asyncio.get_running_loop().run_in_executor(
|
||||
self._thread_pool, self._read, frame_size
|
||||
):
|
||||
yield pcm_sample
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await asyncio.get_running_loop().run_in_executor(self._thread_pool, self._close)
|
||||
self._thread_pool.shutdown()
|
||||
|
||||
|
||||
class WaveAudioInput(ThreadedAudioInput):
|
||||
"""Audio input that reads PCM samples from a .wav file."""
|
||||
|
||||
def __init__(self, filename: str) -> None:
|
||||
super().__init__()
|
||||
self._filename = filename
|
||||
self._wav: wave.Wave_read | None = None
|
||||
self._bytes_read = 0
|
||||
|
||||
def _open(self) -> PcmFormat:
|
||||
self._wav = wave.open(self._filename, 'rb')
|
||||
if self._wav.getsampwidth() != 2:
|
||||
raise ValueError('sample width not supported')
|
||||
return PcmFormat(
|
||||
PcmFormat.Endianness.LITTLE,
|
||||
PcmFormat.SampleType.INT16,
|
||||
self._wav.getframerate(),
|
||||
self._wav.getnchannels(),
|
||||
)
|
||||
|
||||
def _read(self, frame_size: int) -> bytes:
|
||||
if not self._wav:
|
||||
return b''
|
||||
|
||||
pcm_samples = self._wav.readframes(frame_size)
|
||||
if not pcm_samples and self._bytes_read:
|
||||
# Loop around.
|
||||
self._wav.rewind()
|
||||
self._bytes_read = 0
|
||||
pcm_samples = self._wav.readframes(frame_size)
|
||||
|
||||
self._bytes_read += len(pcm_samples)
|
||||
return pcm_samples
|
||||
|
||||
def _close(self) -> None:
|
||||
if self._wav:
|
||||
self._wav.close()
|
||||
|
||||
|
||||
class StreamAudioInput(ThreadedAudioInput):
|
||||
"""AudioInput where samples are read from a raw PCM stream that may block."""
|
||||
|
||||
def __init__(self, stream: BinaryIO, pcm_format: PcmFormat) -> None:
|
||||
super().__init__()
|
||||
self._stream = stream
|
||||
self._pcm_format = pcm_format
|
||||
|
||||
def _open(self) -> PcmFormat:
|
||||
return self._pcm_format
|
||||
|
||||
def _read(self, frame_size: int) -> bytes:
|
||||
return self._stream.read(
|
||||
frame_size * self._pcm_format.channels * self._pcm_format.bytes_per_sample
|
||||
)
|
||||
|
||||
|
||||
class FileAudioInput(StreamAudioInput):
|
||||
"""AudioInput where PCM samples are read from a raw PCM file."""
|
||||
|
||||
def __init__(self, filename: str, pcm_format: PcmFormat) -> None:
|
||||
self._stream = open(filename, "rb")
|
||||
super().__init__(self._stream, pcm_format)
|
||||
|
||||
def _close(self) -> None:
|
||||
self._stream.close()
|
||||
|
||||
|
||||
class SoundDeviceAudioInput(ThreadedAudioInput):
|
||||
def __init__(self, device_name: str, pcm_format: PcmFormat) -> None:
|
||||
super().__init__()
|
||||
self._device = int(device_name) if device_name else None
|
||||
self._pcm_format = pcm_format
|
||||
self._stream: sounddevice.RawInputStream | None = None
|
||||
|
||||
def _open(self) -> PcmFormat:
|
||||
import sounddevice # pylint: disable=import-error
|
||||
|
||||
self._stream = sounddevice.RawInputStream(
|
||||
samplerate=self._pcm_format.sample_rate,
|
||||
device=self._device,
|
||||
channels=self._pcm_format.channels,
|
||||
dtype='int16',
|
||||
)
|
||||
self._stream.start()
|
||||
|
||||
return PcmFormat(
|
||||
PcmFormat.Endianness.LITTLE,
|
||||
PcmFormat.SampleType.INT16,
|
||||
self._pcm_format.sample_rate,
|
||||
2,
|
||||
)
|
||||
|
||||
def _read(self, frame_size: int) -> bytes:
|
||||
if not self._stream:
|
||||
return b''
|
||||
pcm_buffer, overflowed = self._stream.read(frame_size)
|
||||
if overflowed:
|
||||
logger.warning("input overflow")
|
||||
|
||||
# Convert the buffer to stereo if needed
|
||||
if self._pcm_format.channels == 1:
|
||||
stereo_buffer = bytearray()
|
||||
for i in range(frame_size):
|
||||
sample = pcm_buffer[i * 2 : i * 2 + 2]
|
||||
stereo_buffer += sample + sample
|
||||
return stereo_buffer
|
||||
|
||||
return bytes(pcm_buffer)
|
||||
|
||||
def _close(self):
|
||||
self._stream.stop()
|
||||
self._stream = None
|
||||
@@ -134,6 +134,8 @@ class Frame:
|
||||
opcode_offset = 3
|
||||
elif subunit_id == 6:
|
||||
raise core.InvalidPacketError("reserved subunit ID")
|
||||
else:
|
||||
raise core.InvalidPacketError("invalid subunit ID")
|
||||
|
||||
opcode = Frame.OperationCode(data[opcode_offset])
|
||||
operands = data[opcode_offset + 1 :]
|
||||
|
||||
@@ -154,15 +154,17 @@ class Controller:
|
||||
'0000000060000000'
|
||||
) # BR/EDR Not Supported, LE Supported (Controller)
|
||||
self.manufacturer_name = 0xFFFF
|
||||
self.hc_data_packet_length = 27
|
||||
self.hc_total_num_data_packets = 64
|
||||
self.hc_le_data_packet_length = 27
|
||||
self.hc_total_num_le_data_packets = 64
|
||||
self.acl_data_packet_length = 27
|
||||
self.total_num_acl_data_packets = 64
|
||||
self.le_acl_data_packet_length = 27
|
||||
self.total_num_le_acl_data_packets = 64
|
||||
self.iso_data_packet_length = 960
|
||||
self.total_num_iso_data_packets = 64
|
||||
self.event_mask = 0
|
||||
self.event_mask_page_2 = 0
|
||||
self.supported_commands = bytes.fromhex(
|
||||
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
|
||||
'30f0f9ff01008004000000000000000000000000000000000000000000000000'
|
||||
'30f0f9ff01008004002000000000000000000000000000000000000000000000'
|
||||
)
|
||||
self.le_event_mask = 0
|
||||
self.advertising_parameters = None
|
||||
@@ -314,7 +316,7 @@ class Controller:
|
||||
f'{color("CONTROLLER -> HOST", "green")}: {packet}'
|
||||
)
|
||||
if self.host:
|
||||
self.host.on_packet(packet.to_bytes())
|
||||
self.host.on_packet(bytes(packet))
|
||||
|
||||
# This method allows the controller to emulate the same API as a transport source
|
||||
async def wait_for_termination(self):
|
||||
@@ -1181,9 +1183,9 @@ class Controller:
|
||||
return struct.pack(
|
||||
'<BHBHH',
|
||||
HCI_SUCCESS,
|
||||
self.hc_data_packet_length,
|
||||
self.acl_data_packet_length,
|
||||
0,
|
||||
self.hc_total_num_data_packets,
|
||||
self.total_num_acl_data_packets,
|
||||
0,
|
||||
)
|
||||
|
||||
@@ -1192,7 +1194,7 @@ class Controller:
|
||||
See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command
|
||||
'''
|
||||
bd_addr = (
|
||||
self._public_address.to_bytes()
|
||||
bytes(self._public_address)
|
||||
if self._public_address is not None
|
||||
else bytes(6)
|
||||
)
|
||||
@@ -1212,8 +1214,21 @@ class Controller:
|
||||
return struct.pack(
|
||||
'<BHB',
|
||||
HCI_SUCCESS,
|
||||
self.hc_le_data_packet_length,
|
||||
self.hc_total_num_le_data_packets,
|
||||
self.le_acl_data_packet_length,
|
||||
self.total_num_le_acl_data_packets,
|
||||
)
|
||||
|
||||
def on_hci_le_read_buffer_size_v2_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.2 LE Read Buffer Size Command
|
||||
'''
|
||||
return struct.pack(
|
||||
'<BHBHB',
|
||||
HCI_SUCCESS,
|
||||
self.le_acl_data_packet_length,
|
||||
self.total_num_le_acl_data_packets,
|
||||
self.iso_data_packet_length,
|
||||
self.total_num_iso_data_packets,
|
||||
)
|
||||
|
||||
def on_hci_le_read_local_supported_features_command(self, _command):
|
||||
@@ -1543,6 +1558,41 @@ class Controller:
|
||||
}
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_advertising_set_random_address_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.52 LE Set Advertising Set Random Address
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_extended_advertising_parameters_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.53 LE Set Extended Advertising Parameters
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS, 0])
|
||||
|
||||
def on_hci_le_set_extended_advertising_data_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.54 LE Set Extended Advertising Data
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_extended_scan_response_data_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.55 LE Set Extended Scan Response Data
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_extended_advertising_enable_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.56 LE Set Extended Advertising Enable
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_read_maximum_advertising_data_length_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.57 LE Read Maximum Advertising Data
|
||||
@@ -1557,6 +1607,27 @@ class Controller:
|
||||
'''
|
||||
return struct.pack('<BB', HCI_SUCCESS, 0xF0)
|
||||
|
||||
def on_hci_le_set_periodic_advertising_parameters_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.61 LE Set Periodic Advertising Parameters
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_periodic_advertising_data_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.62 LE Set Periodic Advertising Data
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_set_periodic_advertising_enable_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.63 LE Set Periodic Advertising Enable
|
||||
Command
|
||||
'''
|
||||
return bytes([HCI_SUCCESS])
|
||||
|
||||
def on_hci_le_read_transmit_power_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command
|
||||
|
||||
@@ -1501,7 +1501,10 @@ class AdvertisingData:
|
||||
ad_data_str = f'"{ad_data.decode("utf-8")}"'
|
||||
elif ad_type == AdvertisingData.COMPLETE_LOCAL_NAME:
|
||||
ad_type_str = 'Complete Local Name'
|
||||
ad_data_str = f'"{ad_data.decode("utf-8")}"'
|
||||
try:
|
||||
ad_data_str = f'"{ad_data.decode("utf-8")}"'
|
||||
except UnicodeDecodeError:
|
||||
ad_data_str = ad_data.hex()
|
||||
elif ad_type == AdvertisingData.TX_POWER_LEVEL:
|
||||
ad_type_str = 'TX Power Level'
|
||||
ad_data_str = str(ad_data[0])
|
||||
|
||||
1983
bumble/device.py
1983
bumble/device.py
File diff suppressed because it is too large
Load Diff
@@ -20,6 +20,8 @@ Common types for drivers.
|
||||
# -----------------------------------------------------------------------------
|
||||
import abc
|
||||
|
||||
from bumble import core
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
|
||||
@@ -11,18 +11,33 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Support for Intel USB controllers.
|
||||
Loosely based on the Fuchsia OS implementation.
|
||||
"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import collections
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import platform
|
||||
import struct
|
||||
from typing import Any, Deque, Optional, TYPE_CHECKING
|
||||
|
||||
from bumble import core
|
||||
from bumble.drivers import common
|
||||
from bumble.hci import (
|
||||
hci_vendor_command_op_code, # type: ignore
|
||||
HCI_Command,
|
||||
HCI_Reset_Command,
|
||||
)
|
||||
from bumble import hci
|
||||
from bumble import utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.host import Host
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -34,39 +49,328 @@ logger = logging.getLogger(__name__)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
INTEL_USB_PRODUCTS = {
|
||||
# Intel AX210
|
||||
(0x8087, 0x0032),
|
||||
# Intel BE200
|
||||
(0x8087, 0x0036),
|
||||
(0x8087, 0x0032), # AX210
|
||||
(0x8087, 0x0036), # BE200
|
||||
}
|
||||
|
||||
INTEL_FW_IMAGE_NAMES = [
|
||||
"ibt-0040-0041",
|
||||
"ibt-0040-1020",
|
||||
"ibt-0040-1050",
|
||||
"ibt-0040-2120",
|
||||
"ibt-0040-4150",
|
||||
"ibt-0041-0041",
|
||||
"ibt-0180-0041",
|
||||
"ibt-0180-1050",
|
||||
"ibt-0180-4150",
|
||||
"ibt-0291-0291",
|
||||
"ibt-1040-0041",
|
||||
"ibt-1040-1020",
|
||||
"ibt-1040-1050",
|
||||
"ibt-1040-2120",
|
||||
"ibt-1040-4150",
|
||||
]
|
||||
|
||||
INTEL_FIRMWARE_DIR_ENV = "BUMBLE_INTEL_FIRMWARE_DIR"
|
||||
INTEL_LINUX_FIRMWARE_DIR = "/lib/firmware/intel"
|
||||
|
||||
_MAX_FRAGMENT_SIZE = 252
|
||||
_POST_RESET_DELAY = 0.2
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# HCI Commands
|
||||
# -----------------------------------------------------------------------------
|
||||
HCI_INTEL_DDC_CONFIG_WRITE_COMMAND = hci_vendor_command_op_code(0xFC8B) # type: ignore
|
||||
HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD = [0x03, 0xE4, 0x02, 0x00]
|
||||
HCI_INTEL_WRITE_DEVICE_CONFIG_COMMAND = hci.hci_vendor_command_op_code(0x008B)
|
||||
HCI_INTEL_READ_VERSION_COMMAND = hci.hci_vendor_command_op_code(0x0005)
|
||||
HCI_INTEL_RESET_COMMAND = hci.hci_vendor_command_op_code(0x0001)
|
||||
HCI_INTEL_SECURE_SEND_COMMAND = hci.hci_vendor_command_op_code(0x0009)
|
||||
HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND = hci.hci_vendor_command_op_code(0x000E)
|
||||
|
||||
HCI_Command.register_commands(globals())
|
||||
hci.HCI_Command.register_commands(globals())
|
||||
|
||||
|
||||
@HCI_Command.command( # type: ignore
|
||||
fields=[("params", "*")],
|
||||
@hci.HCI_Command.command(
|
||||
fields=[
|
||||
("param0", 1),
|
||||
],
|
||||
return_parameters_fields=[
|
||||
("params", "*"),
|
||||
("status", hci.STATUS_SPEC),
|
||||
("tlv", "*"),
|
||||
],
|
||||
)
|
||||
class Hci_Intel_DDC_Config_Write_Command(HCI_Command):
|
||||
class HCI_Intel_Read_Version_Command(hci.HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
@hci.HCI_Command.command(
|
||||
fields=[("data_type", 1), ("data", "*")],
|
||||
return_parameters_fields=[
|
||||
("status", 1),
|
||||
],
|
||||
)
|
||||
class Hci_Intel_Secure_Send_Command(hci.HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
@hci.HCI_Command.command(
|
||||
fields=[
|
||||
("reset_type", 1),
|
||||
("patch_enable", 1),
|
||||
("ddc_reload", 1),
|
||||
("boot_option", 1),
|
||||
("boot_address", 4),
|
||||
],
|
||||
return_parameters_fields=[
|
||||
("data", "*"),
|
||||
],
|
||||
)
|
||||
class HCI_Intel_Reset_Command(hci.HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
@hci.HCI_Command.command(
|
||||
fields=[("data", "*")],
|
||||
return_parameters_fields=[
|
||||
("status", hci.STATUS_SPEC),
|
||||
("params", "*"),
|
||||
],
|
||||
)
|
||||
class Hci_Intel_Write_Device_Config_Command(hci.HCI_Command):
|
||||
pass
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Functions
|
||||
# -----------------------------------------------------------------------------
|
||||
def intel_firmware_dir() -> pathlib.Path:
|
||||
"""
|
||||
Returns:
|
||||
A path to a subdir of the project data dir for Intel firmware.
|
||||
The directory is created if it doesn't exist.
|
||||
"""
|
||||
from bumble.drivers import project_data_dir
|
||||
|
||||
p = project_data_dir() / "firmware" / "intel"
|
||||
p.mkdir(parents=True, exist_ok=True)
|
||||
return p
|
||||
|
||||
|
||||
def _find_binary_path(file_name: str) -> pathlib.Path | None:
|
||||
# First check if an environment variable is set
|
||||
if INTEL_FIRMWARE_DIR_ENV in os.environ:
|
||||
if (
|
||||
path := pathlib.Path(os.environ[INTEL_FIRMWARE_DIR_ENV]) / file_name
|
||||
).is_file():
|
||||
logger.debug(f"{file_name} found in env dir")
|
||||
return path
|
||||
|
||||
# When the environment variable is set, don't look elsewhere
|
||||
return None
|
||||
|
||||
# Then, look where the firmware download tool writes by default
|
||||
if (path := intel_firmware_dir() / file_name).is_file():
|
||||
logger.debug(f"{file_name} found in project data dir")
|
||||
return path
|
||||
|
||||
# Then, look in the package's driver directory
|
||||
if (path := pathlib.Path(__file__).parent / "intel_fw" / file_name).is_file():
|
||||
logger.debug(f"{file_name} found in package dir")
|
||||
return path
|
||||
|
||||
# On Linux, check the system's FW directory
|
||||
if (
|
||||
platform.system() == "Linux"
|
||||
and (path := pathlib.Path(INTEL_LINUX_FIRMWARE_DIR) / file_name).is_file()
|
||||
):
|
||||
logger.debug(f"{file_name} found in Linux system FW dir")
|
||||
return path
|
||||
|
||||
# Finally look in the current directory
|
||||
if (path := pathlib.Path.cwd() / file_name).is_file():
|
||||
logger.debug(f"{file_name} found in CWD")
|
||||
return path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _parse_tlv(data: bytes) -> list[tuple[ValueType, Any]]:
|
||||
result: list[tuple[ValueType, Any]] = []
|
||||
while len(data) >= 2:
|
||||
value_type = ValueType(data[0])
|
||||
value_length = data[1]
|
||||
value = data[2 : 2 + value_length]
|
||||
typed_value: Any
|
||||
|
||||
if value_type == ValueType.END:
|
||||
break
|
||||
|
||||
if value_type in (ValueType.CNVI, ValueType.CNVR):
|
||||
(v,) = struct.unpack("<I", value)
|
||||
typed_value = (
|
||||
(((v >> 0) & 0xF) << 12)
|
||||
| (((v >> 4) & 0xF) << 0)
|
||||
| (((v >> 8) & 0xF) << 4)
|
||||
| (((v >> 24) & 0xF) << 8)
|
||||
)
|
||||
elif value_type == ValueType.HARDWARE_INFO:
|
||||
(v,) = struct.unpack("<I", value)
|
||||
typed_value = HardwareInfo(
|
||||
HardwarePlatform((v >> 8) & 0xFF), HardwareVariant((v >> 16) & 0x3F)
|
||||
)
|
||||
elif value_type in (
|
||||
ValueType.USB_VENDOR_ID,
|
||||
ValueType.USB_PRODUCT_ID,
|
||||
ValueType.DEVICE_REVISION,
|
||||
):
|
||||
(typed_value,) = struct.unpack("<H", value)
|
||||
elif value_type == ValueType.CURRENT_MODE_OF_OPERATION:
|
||||
typed_value = ModeOfOperation(value[0])
|
||||
elif value_type in (
|
||||
ValueType.BUILD_TYPE,
|
||||
ValueType.BUILD_NUMBER,
|
||||
ValueType.SECURE_BOOT,
|
||||
ValueType.OTP_LOCK,
|
||||
ValueType.API_LOCK,
|
||||
ValueType.DEBUG_LOCK,
|
||||
ValueType.SECURE_BOOT_ENGINE_TYPE,
|
||||
):
|
||||
typed_value = value[0]
|
||||
elif value_type == ValueType.TIMESTAMP:
|
||||
typed_value = Timestamp(value[0], value[1])
|
||||
elif value_type == ValueType.FIRMWARE_BUILD:
|
||||
typed_value = FirmwareBuild(value[0], Timestamp(value[1], value[2]))
|
||||
elif value_type == ValueType.BLUETOOTH_ADDRESS:
|
||||
typed_value = hci.Address(
|
||||
value, address_type=hci.Address.PUBLIC_DEVICE_ADDRESS
|
||||
)
|
||||
else:
|
||||
typed_value = value
|
||||
|
||||
result.append((value_type, typed_value))
|
||||
data = data[2 + value_length :]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class DriverError(core.BaseBumbleError):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"IntelDriverError({self.message})"
|
||||
|
||||
|
||||
class ValueType(utils.OpenIntEnum):
|
||||
END = 0x00
|
||||
CNVI = 0x10
|
||||
CNVR = 0x11
|
||||
HARDWARE_INFO = 0x12
|
||||
DEVICE_REVISION = 0x16
|
||||
CURRENT_MODE_OF_OPERATION = 0x1C
|
||||
USB_VENDOR_ID = 0x17
|
||||
USB_PRODUCT_ID = 0x18
|
||||
TIMESTAMP = 0x1D
|
||||
BUILD_TYPE = 0x1E
|
||||
BUILD_NUMBER = 0x1F
|
||||
SECURE_BOOT = 0x28
|
||||
OTP_LOCK = 0x2A
|
||||
API_LOCK = 0x2B
|
||||
DEBUG_LOCK = 0x2C
|
||||
FIRMWARE_BUILD = 0x2D
|
||||
SECURE_BOOT_ENGINE_TYPE = 0x2F
|
||||
BLUETOOTH_ADDRESS = 0x30
|
||||
|
||||
|
||||
class HardwarePlatform(utils.OpenIntEnum):
|
||||
INTEL_37 = 0x37
|
||||
|
||||
|
||||
class HardwareVariant(utils.OpenIntEnum):
|
||||
# This is a just a partial list.
|
||||
# Add other constants here as new hardware is encountered and tested.
|
||||
TYPHOON_PEAK = 0x17
|
||||
GALE_PEAK = 0x1C
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HardwareInfo:
|
||||
platform: HardwarePlatform
|
||||
variant: HardwareVariant
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Timestamp:
|
||||
week: int
|
||||
year: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FirmwareBuild:
|
||||
build_number: int
|
||||
timestamp: Timestamp
|
||||
|
||||
|
||||
class ModeOfOperation(utils.OpenIntEnum):
|
||||
BOOTLOADER = 0x01
|
||||
INTERMEDIATE = 0x02
|
||||
OPERATIONAL = 0x03
|
||||
|
||||
|
||||
class SecureBootEngineType(utils.OpenIntEnum):
|
||||
RSA = 0x00
|
||||
ECDSA = 0x01
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BootParams:
|
||||
css_header_offset: int
|
||||
css_header_size: int
|
||||
pki_offset: int
|
||||
pki_size: int
|
||||
sig_offset: int
|
||||
sig_size: int
|
||||
write_offset: int
|
||||
|
||||
|
||||
_BOOT_PARAMS = {
|
||||
SecureBootEngineType.RSA: BootParams(0, 128, 128, 256, 388, 256, 964),
|
||||
SecureBootEngineType.ECDSA: BootParams(644, 128, 772, 96, 868, 96, 964),
|
||||
}
|
||||
|
||||
|
||||
class Driver(common.Driver):
|
||||
def __init__(self, host):
|
||||
def __init__(self, host: Host) -> None:
|
||||
self.host = host
|
||||
self.max_in_flight_firmware_load_commands = 1
|
||||
self.pending_firmware_load_commands: Deque[hci.HCI_Command] = (
|
||||
collections.deque()
|
||||
)
|
||||
self.can_send_firmware_load_command = asyncio.Event()
|
||||
self.can_send_firmware_load_command.set()
|
||||
self.firmware_load_complete = asyncio.Event()
|
||||
self.reset_complete = asyncio.Event()
|
||||
|
||||
# Parse configuration options from the driver name.
|
||||
self.ddc_addon: Optional[bytes] = None
|
||||
self.ddc_override: Optional[bytes] = None
|
||||
driver = host.hci_metadata.get("driver")
|
||||
if driver is not None and driver.startswith("intel/"):
|
||||
for key, value in [
|
||||
key_eq_value.split(":") for key_eq_value in driver[6:].split("+")
|
||||
]:
|
||||
if key == "ddc_addon":
|
||||
self.ddc_addon = bytes.fromhex(value)
|
||||
elif key == "ddc_override":
|
||||
self.ddc_override = bytes.fromhex(value)
|
||||
|
||||
@staticmethod
|
||||
def check(host):
|
||||
def check(host: Host) -> bool:
|
||||
driver = host.hci_metadata.get("driver")
|
||||
if driver == "intel":
|
||||
if driver == "intel" or driver is not None and driver.startswith("intel/"):
|
||||
return True
|
||||
|
||||
vendor_id = host.hci_metadata.get("vendor_id")
|
||||
@@ -85,18 +389,283 @@ class Driver(common.Driver):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
async def for_host(cls, host, force=False): # type: ignore
|
||||
async def for_host(cls, host: Host, force: bool = False):
|
||||
# Only instantiate this driver if explicitly selected
|
||||
if not force and not cls.check(host):
|
||||
return None
|
||||
|
||||
return cls(host)
|
||||
|
||||
async def init_controller(self):
|
||||
def on_packet(self, packet: bytes) -> None:
|
||||
"""Handler for event packets that are received from an ACL channel"""
|
||||
event = hci.HCI_Event.from_bytes(packet)
|
||||
|
||||
if not isinstance(event, hci.HCI_Command_Complete_Event):
|
||||
self.host.on_hci_event_packet(event)
|
||||
return
|
||||
|
||||
if not event.return_parameters == hci.HCI_SUCCESS:
|
||||
raise DriverError("HCI_Command_Complete_Event error")
|
||||
|
||||
if self.max_in_flight_firmware_load_commands != event.num_hci_command_packets:
|
||||
logger.debug(
|
||||
"max_in_flight_firmware_load_commands update: "
|
||||
f"{event.num_hci_command_packets}"
|
||||
)
|
||||
self.max_in_flight_firmware_load_commands = event.num_hci_command_packets
|
||||
logger.debug(f"event: {event}")
|
||||
self.pending_firmware_load_commands.popleft()
|
||||
in_flight = len(self.pending_firmware_load_commands)
|
||||
logger.debug(f"event received, {in_flight} still in flight")
|
||||
if in_flight < self.max_in_flight_firmware_load_commands:
|
||||
self.can_send_firmware_load_command.set()
|
||||
|
||||
async def send_firmware_load_command(self, command: hci.HCI_Command) -> None:
|
||||
# Wait until we can send.
|
||||
await self.can_send_firmware_load_command.wait()
|
||||
|
||||
# Send the command and adjust counters.
|
||||
self.host.send_hci_packet(command)
|
||||
self.pending_firmware_load_commands.append(command)
|
||||
in_flight = len(self.pending_firmware_load_commands)
|
||||
if in_flight >= self.max_in_flight_firmware_load_commands:
|
||||
logger.debug(f"max commands in flight reached [{in_flight}]")
|
||||
self.can_send_firmware_load_command.clear()
|
||||
|
||||
async def send_firmware_data(self, data_type: int, data: bytes) -> None:
|
||||
while data:
|
||||
fragment_size = min(len(data), _MAX_FRAGMENT_SIZE)
|
||||
fragment = data[:fragment_size]
|
||||
data = data[fragment_size:]
|
||||
|
||||
await self.send_firmware_load_command(
|
||||
Hci_Intel_Secure_Send_Command(data_type=data_type, data=fragment)
|
||||
)
|
||||
|
||||
async def load_firmware(self) -> None:
|
||||
self.host.ready = True
|
||||
await self.host.send_command(HCI_Reset_Command(), check_result=True)
|
||||
await self.host.send_command(
|
||||
Hci_Intel_DDC_Config_Write_Command(
|
||||
params=HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD
|
||||
device_info = await self.read_device_info()
|
||||
logger.debug(
|
||||
"device info: \n%s",
|
||||
"\n".join(
|
||||
[
|
||||
f" {value_type.name}: {value}"
|
||||
for value_type, value in device_info.items()
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
# Check if the firmware is already loaded.
|
||||
if (
|
||||
device_info.get(ValueType.CURRENT_MODE_OF_OPERATION)
|
||||
== ModeOfOperation.OPERATIONAL
|
||||
):
|
||||
logger.debug("firmware already loaded")
|
||||
return
|
||||
|
||||
# We only support some platforms and variants.
|
||||
hardware_info = device_info.get(ValueType.HARDWARE_INFO)
|
||||
if hardware_info is None:
|
||||
raise DriverError("hardware info missing")
|
||||
if hardware_info.platform != HardwarePlatform.INTEL_37:
|
||||
raise DriverError("hardware platform not supported")
|
||||
if hardware_info.variant not in (
|
||||
HardwareVariant.TYPHOON_PEAK,
|
||||
HardwareVariant.GALE_PEAK,
|
||||
):
|
||||
raise DriverError("hardware variant not supported")
|
||||
|
||||
# Compute the firmware name.
|
||||
if ValueType.CNVI not in device_info or ValueType.CNVR not in device_info:
|
||||
raise DriverError("insufficient device info, missing CNVI or CNVR")
|
||||
|
||||
firmware_base_name = (
|
||||
"ibt-"
|
||||
f"{device_info[ValueType.CNVI]:04X}-"
|
||||
f"{device_info[ValueType.CNVR]:04X}"
|
||||
)
|
||||
logger.debug(f"FW base name: {firmware_base_name}")
|
||||
|
||||
firmware_name = f"{firmware_base_name}.sfi"
|
||||
firmware_path = _find_binary_path(firmware_name)
|
||||
if not firmware_path:
|
||||
logger.warning(f"Firmware file {firmware_name} not found")
|
||||
logger.warning("See https://google.github.io/bumble/drivers/intel.html")
|
||||
return None
|
||||
logger.debug(f"loading firmware from {firmware_path}")
|
||||
firmware_image = firmware_path.read_bytes()
|
||||
|
||||
engine_type = device_info.get(ValueType.SECURE_BOOT_ENGINE_TYPE)
|
||||
if engine_type is None:
|
||||
raise DriverError("secure boot engine type missing")
|
||||
if engine_type not in _BOOT_PARAMS:
|
||||
raise DriverError("secure boot engine type not supported")
|
||||
|
||||
boot_params = _BOOT_PARAMS[engine_type]
|
||||
if len(firmware_image) < boot_params.write_offset:
|
||||
raise DriverError("firmware image too small")
|
||||
|
||||
# Register to receive vendor events.
|
||||
def on_vendor_event(event: hci.HCI_Vendor_Event):
|
||||
logger.debug(f"vendor event: {event}")
|
||||
event_type = event.parameters[0]
|
||||
if event_type == 0x02:
|
||||
# Boot event
|
||||
logger.debug("boot complete")
|
||||
self.reset_complete.set()
|
||||
elif event_type == 0x06:
|
||||
# Firmware load event
|
||||
logger.debug("download complete")
|
||||
self.firmware_load_complete.set()
|
||||
else:
|
||||
logger.debug(f"ignoring vendor event type {event_type}")
|
||||
|
||||
self.host.on("vendor_event", on_vendor_event)
|
||||
|
||||
# We need to temporarily intercept packets from the controller,
|
||||
# because they are formatted as HCI event packets but are received
|
||||
# on the ACL channel, so the host parser would get confused.
|
||||
saved_on_packet = self.host.on_packet
|
||||
self.host.on_packet = self.on_packet # type: ignore
|
||||
self.firmware_load_complete.clear()
|
||||
|
||||
# Send the CSS header
|
||||
data = firmware_image[
|
||||
boot_params.css_header_offset : boot_params.css_header_offset
|
||||
+ boot_params.css_header_size
|
||||
]
|
||||
await self.send_firmware_data(0x00, data)
|
||||
|
||||
# Send the PKI header
|
||||
data = firmware_image[
|
||||
boot_params.pki_offset : boot_params.pki_offset + boot_params.pki_size
|
||||
]
|
||||
await self.send_firmware_data(0x03, data)
|
||||
|
||||
# Send the Signature header
|
||||
data = firmware_image[
|
||||
boot_params.sig_offset : boot_params.sig_offset + boot_params.sig_size
|
||||
]
|
||||
await self.send_firmware_data(0x02, data)
|
||||
|
||||
# Send the rest of the image.
|
||||
# The payload consists of command objects, which are sent when they add up
|
||||
# to a multiple of 4 bytes.
|
||||
boot_address = 0
|
||||
offset = boot_params.write_offset
|
||||
fragment_size = 0
|
||||
while offset + 3 < len(firmware_image):
|
||||
(command_opcode,) = struct.unpack_from(
|
||||
"<H", firmware_image, offset + fragment_size
|
||||
)
|
||||
command_size = firmware_image[offset + fragment_size + 2]
|
||||
if command_opcode == HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND:
|
||||
(boot_address,) = struct.unpack_from(
|
||||
"<I", firmware_image, offset + fragment_size + 3
|
||||
)
|
||||
logger.debug(
|
||||
"found HCI_INTEL_WRITE_BOOT_PARAMS_COMMAND, "
|
||||
f"boot_address={boot_address}"
|
||||
)
|
||||
fragment_size += 3 + command_size
|
||||
if fragment_size % 4 == 0:
|
||||
await self.send_firmware_data(
|
||||
0x01, firmware_image[offset : offset + fragment_size]
|
||||
)
|
||||
logger.debug(f"sent {fragment_size} bytes")
|
||||
offset += fragment_size
|
||||
fragment_size = 0
|
||||
|
||||
# Wait for the firmware loading to be complete.
|
||||
logger.debug("waiting for firmware to be loaded")
|
||||
await self.firmware_load_complete.wait()
|
||||
logger.debug("firmware loaded")
|
||||
|
||||
# Restore the original packet handler.
|
||||
self.host.on_packet = saved_on_packet # type: ignore
|
||||
|
||||
# Reset
|
||||
self.reset_complete.clear()
|
||||
self.host.send_hci_packet(
|
||||
HCI_Intel_Reset_Command(
|
||||
reset_type=0x00,
|
||||
patch_enable=0x01,
|
||||
ddc_reload=0x00,
|
||||
boot_option=0x01,
|
||||
boot_address=boot_address,
|
||||
)
|
||||
)
|
||||
logger.debug("waiting for reset completion")
|
||||
await self.reset_complete.wait()
|
||||
logger.debug("reset complete")
|
||||
|
||||
# Load the device config if there is one.
|
||||
if self.ddc_override:
|
||||
logger.debug("loading overridden DDC")
|
||||
await self.load_device_config(self.ddc_override)
|
||||
else:
|
||||
ddc_name = f"{firmware_base_name}.ddc"
|
||||
ddc_path = _find_binary_path(ddc_name)
|
||||
if ddc_path:
|
||||
logger.debug(f"loading DDC from {ddc_path}")
|
||||
ddc_data = ddc_path.read_bytes()
|
||||
await self.load_device_config(ddc_data)
|
||||
if self.ddc_addon:
|
||||
logger.debug("loading DDC addon")
|
||||
await self.load_device_config(self.ddc_addon)
|
||||
|
||||
async def load_device_config(self, ddc_data: bytes) -> None:
|
||||
while ddc_data:
|
||||
ddc_len = 1 + ddc_data[0]
|
||||
ddc_payload = ddc_data[:ddc_len]
|
||||
await self.host.send_command(
|
||||
Hci_Intel_Write_Device_Config_Command(data=ddc_payload)
|
||||
)
|
||||
ddc_data = ddc_data[ddc_len:]
|
||||
|
||||
async def reboot_bootloader(self) -> None:
|
||||
self.host.send_hci_packet(
|
||||
HCI_Intel_Reset_Command(
|
||||
reset_type=0x01,
|
||||
patch_enable=0x01,
|
||||
ddc_reload=0x01,
|
||||
boot_option=0x00,
|
||||
boot_address=0,
|
||||
)
|
||||
)
|
||||
await asyncio.sleep(_POST_RESET_DELAY)
|
||||
|
||||
async def read_device_info(self) -> dict[ValueType, Any]:
|
||||
self.host.ready = True
|
||||
response = await self.host.send_command(hci.HCI_Reset_Command())
|
||||
if not (
|
||||
isinstance(response, hci.HCI_Command_Complete_Event)
|
||||
and response.return_parameters
|
||||
in (hci.HCI_UNKNOWN_HCI_COMMAND_ERROR, hci.HCI_SUCCESS)
|
||||
):
|
||||
# When the controller is in operational mode, the response is a
|
||||
# successful response.
|
||||
# When the controller is in bootloader mode,
|
||||
# HCI_UNKNOWN_HCI_COMMAND_ERROR is the expected response. Anything
|
||||
# else is a failure.
|
||||
logger.warning(f"unexpected response: {response}")
|
||||
raise DriverError("unexpected HCI response")
|
||||
|
||||
# Read the firmware version.
|
||||
response = await self.host.send_command(
|
||||
HCI_Intel_Read_Version_Command(param0=0xFF)
|
||||
)
|
||||
if not isinstance(response, hci.HCI_Command_Complete_Event):
|
||||
raise DriverError("unexpected HCI response")
|
||||
|
||||
if response.return_parameters.status != 0: # type: ignore
|
||||
raise DriverError("HCI_Intel_Read_Version_Command error")
|
||||
|
||||
tlvs = _parse_tlv(response.return_parameters.tlv) # type: ignore
|
||||
|
||||
# Convert the list to a dict. That's Ok here because we only expect each type
|
||||
# to appear just once.
|
||||
return dict(tlvs)
|
||||
|
||||
async def init_controller(self):
|
||||
await self.load_firmware()
|
||||
|
||||
@@ -28,23 +28,26 @@ import functools
|
||||
import logging
|
||||
import struct
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
SupportsBytes,
|
||||
Type,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble.core import BaseBumbleError, UUID
|
||||
from bumble.core import BaseBumbleError, InvalidOperationError, UUID
|
||||
from bumble.att import Attribute, AttributeValue
|
||||
from bumble.utils import ByteSerializable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.gatt_client import AttributeProxy
|
||||
from bumble.device import Connection
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -275,6 +278,13 @@ GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCC, 'Sou
|
||||
GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts')
|
||||
GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts')
|
||||
|
||||
# Gaming Audio Service (GMAS)
|
||||
GATT_GMAP_ROLE_CHARACTERISTIC = UUID.from_16_bits(0x2C00, 'GMAP Role')
|
||||
GATT_UGG_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C01, 'UGG Features')
|
||||
GATT_UGT_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C02, 'UGT Features')
|
||||
GATT_BGS_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C03, 'BGS Features')
|
||||
GATT_BGR_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2C04, 'BGR Features')
|
||||
|
||||
# Hearing Access Service
|
||||
GATT_HEARING_AID_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2BDA, 'Hearing Aid Features')
|
||||
GATT_HEARING_AID_PRESET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BDB, 'Hearing Aid Preset Control Point')
|
||||
@@ -304,6 +314,7 @@ GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bi
|
||||
GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B29, 'Client Supported Features')
|
||||
GATT_DATABASE_HASH_CHARACTERISTIC = UUID.from_16_bits(0x2B2A, 'Database Hash')
|
||||
GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B3A, 'Server Supported Features')
|
||||
GATT_LE_GATT_SECURITY_LEVELS_CHARACTERISTIC = UUID.from_16_bits(0x2BF5, 'E GATT Security Levels')
|
||||
|
||||
# fmt: on
|
||||
# pylint: enable=line-too-long
|
||||
@@ -312,8 +323,6 @@ GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bi
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def show_services(services: Iterable[Service]) -> None:
|
||||
for service in services:
|
||||
print(color(str(service), 'cyan'))
|
||||
@@ -343,7 +352,7 @@ class Service(Attribute):
|
||||
def __init__(
|
||||
self,
|
||||
uuid: Union[str, UUID],
|
||||
characteristics: List[Characteristic],
|
||||
characteristics: Iterable[Characteristic],
|
||||
primary=True,
|
||||
included_services: Iterable[Service] = (),
|
||||
) -> None:
|
||||
@@ -362,7 +371,7 @@ class Service(Attribute):
|
||||
)
|
||||
self.uuid = uuid
|
||||
self.included_services = list(included_services)
|
||||
self.characteristics = characteristics[:]
|
||||
self.characteristics = list(characteristics)
|
||||
self.primary = primary
|
||||
|
||||
def get_advertising_data(self) -> Optional[bytes]:
|
||||
@@ -393,7 +402,7 @@ class TemplateService(Service):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
characteristics: List[Characteristic],
|
||||
characteristics: Iterable[Characteristic],
|
||||
primary: bool = True,
|
||||
included_services: Iterable[Service] = (),
|
||||
) -> None:
|
||||
@@ -410,7 +419,7 @@ class IncludedServiceDeclaration(Attribute):
|
||||
|
||||
def __init__(self, service: Service) -> None:
|
||||
declaration_bytes = struct.pack(
|
||||
'<HH2s', service.handle, service.end_group_handle, service.uuid.to_bytes()
|
||||
'<HH2s', service.handle, service.end_group_handle, bytes(service.uuid)
|
||||
)
|
||||
super().__init__(
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE, Attribute.READABLE, declaration_bytes
|
||||
@@ -490,7 +499,7 @@ class Characteristic(Attribute):
|
||||
uuid: Union[str, bytes, UUID],
|
||||
properties: Characteristic.Properties,
|
||||
permissions: Union[str, Attribute.Permissions],
|
||||
value: Union[str, bytes, CharacteristicValue] = b'',
|
||||
value: Any = b'',
|
||||
descriptors: Sequence[Descriptor] = (),
|
||||
):
|
||||
super().__init__(uuid, permissions, value)
|
||||
@@ -525,7 +534,11 @@ class CharacteristicDeclaration(Attribute):
|
||||
|
||||
characteristic: Characteristic
|
||||
|
||||
def __init__(self, characteristic: Characteristic, value_handle: int) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
characteristic: Characteristic,
|
||||
value_handle: int,
|
||||
) -> None:
|
||||
declaration_bytes = (
|
||||
struct.pack('<BH', characteristic.properties, value_handle)
|
||||
+ characteristic.uuid.to_pdu_bytes()
|
||||
@@ -665,10 +678,14 @@ class DelegatedCharacteristicAdapter(CharacteristicAdapter):
|
||||
self.decode = decode
|
||||
|
||||
def encode_value(self, value):
|
||||
return self.encode(value) if self.encode else value
|
||||
if self.encode is None:
|
||||
raise InvalidOperationError('delegated adapter does not have an encoder')
|
||||
return self.encode(value)
|
||||
|
||||
def decode_value(self, value):
|
||||
return self.decode(value) if self.decode else value
|
||||
if self.decode is None:
|
||||
raise InvalidOperationError('delegate adapter does not have a decoder')
|
||||
return self.decode(value)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -705,7 +722,7 @@ class MappedCharacteristicAdapter(PackedCharacteristicAdapter):
|
||||
'''
|
||||
Adapter that packs/unpacks characteristic values according to a standard
|
||||
Python `struct` format.
|
||||
The adapted `read_value` and `write_value` methods return/accept aa dictionary which
|
||||
The adapted `read_value` and `write_value` methods return/accept a dictionary which
|
||||
is packed/unpacked according to format, with the arguments extracted from the
|
||||
dictionary by key, in the same order as they occur in the `keys` parameter.
|
||||
'''
|
||||
@@ -735,6 +752,24 @@ class UTF8CharacteristicAdapter(CharacteristicAdapter):
|
||||
return value.decode('utf-8')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class SerializableCharacteristicAdapter(CharacteristicAdapter):
|
||||
'''
|
||||
Adapter that converts any class to/from bytes using the class'
|
||||
`to_bytes` and `__bytes__` methods, respectively.
|
||||
'''
|
||||
|
||||
def __init__(self, characteristic, cls: Type[ByteSerializable]):
|
||||
super().__init__(characteristic)
|
||||
self.cls = cls
|
||||
|
||||
def encode_value(self, value: SupportsBytes) -> bytes:
|
||||
return bytes(value)
|
||||
|
||||
def decode_value(self, value: bytes) -> Any:
|
||||
return self.cls.from_bytes(value)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Descriptor(Attribute):
|
||||
'''
|
||||
@@ -769,3 +804,23 @@ class ClientCharacteristicConfigurationBits(enum.IntFlag):
|
||||
DEFAULT = 0x0000
|
||||
NOTIFICATION = 0x0001
|
||||
INDICATION = 0x0002
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ClientSupportedFeatures(enum.IntFlag):
|
||||
'''
|
||||
See Vol 3, Part G - 7.2 - Table 7.6: Client Supported Features bit assignments.
|
||||
'''
|
||||
|
||||
ROBUST_CACHING = 0x01
|
||||
ENHANCED_ATT_BEARER = 0x02
|
||||
MULTIPLE_HANDLE_VALUE_NOTIFICATIONS = 0x04
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ServerSupportedFeatures(enum.IntFlag):
|
||||
'''
|
||||
See Vol 3, Part G - 7.4 - Table 7.11: Server Supported Features bit assignments.
|
||||
'''
|
||||
|
||||
EATT_SUPPORTED = 0x01
|
||||
|
||||
@@ -78,6 +78,7 @@ from .gatt import (
|
||||
GATT_INCLUDE_ATTRIBUTE_TYPE,
|
||||
Characteristic,
|
||||
ClientCharacteristicConfigurationBits,
|
||||
InvalidServiceError,
|
||||
TemplateService,
|
||||
)
|
||||
|
||||
@@ -162,12 +163,23 @@ class ServiceProxy(AttributeProxy):
|
||||
self.uuid = uuid
|
||||
self.characteristics = []
|
||||
|
||||
async def discover_characteristics(self, uuids=()):
|
||||
async def discover_characteristics(self, uuids=()) -> list[CharacteristicProxy]:
|
||||
return await self.client.discover_characteristics(uuids, self)
|
||||
|
||||
def get_characteristics_by_uuid(self, uuid):
|
||||
def get_characteristics_by_uuid(self, uuid: UUID) -> list[CharacteristicProxy]:
|
||||
"""Get all the characteristics with a specified UUID."""
|
||||
return self.client.get_characteristics_by_uuid(uuid, self)
|
||||
|
||||
def get_required_characteristic_by_uuid(self, uuid: UUID) -> CharacteristicProxy:
|
||||
"""
|
||||
Get the first characteristic with a specified UUID.
|
||||
|
||||
If no characteristic with that UUID is found, an InvalidServiceError is raised.
|
||||
"""
|
||||
if not (characteristics := self.get_characteristics_by_uuid(uuid)):
|
||||
raise InvalidServiceError(f'{uuid} characteristic not found')
|
||||
return characteristics[0]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'Service(handle=0x{self.handle:04X}, uuid={self.uuid})'
|
||||
|
||||
@@ -292,7 +304,7 @@ class Client:
|
||||
logger.debug(
|
||||
f'GATT Command from client: [0x{self.connection.handle:04X}] {command}'
|
||||
)
|
||||
self.send_gatt_pdu(command.to_bytes())
|
||||
self.send_gatt_pdu(bytes(command))
|
||||
|
||||
async def send_request(self, request: ATT_PDU):
|
||||
logger.debug(
|
||||
@@ -310,7 +322,7 @@ class Client:
|
||||
self.pending_request = request
|
||||
|
||||
try:
|
||||
self.send_gatt_pdu(request.to_bytes())
|
||||
self.send_gatt_pdu(bytes(request))
|
||||
response = await asyncio.wait_for(
|
||||
self.pending_response, GATT_REQUEST_TIMEOUT
|
||||
)
|
||||
@@ -328,7 +340,7 @@ class Client:
|
||||
f'GATT Confirmation from client: [0x{self.connection.handle:04X}] '
|
||||
f'{confirmation}'
|
||||
)
|
||||
self.send_gatt_pdu(confirmation.to_bytes())
|
||||
self.send_gatt_pdu(bytes(confirmation))
|
||||
|
||||
async def request_mtu(self, mtu: int) -> int:
|
||||
# Check the range
|
||||
@@ -898,6 +910,12 @@ class Client:
|
||||
) and subscriber in subscribers:
|
||||
subscribers.remove(subscriber)
|
||||
|
||||
# The characteristic itself is added as subscriber. If it is the
|
||||
# last remaining subscriber, we remove it, such that the clean up
|
||||
# works correctly. Otherwise the CCCD never is set back to 0.
|
||||
if len(subscribers) == 1 and characteristic in subscribers:
|
||||
subscribers.remove(characteristic)
|
||||
|
||||
# Cleanup if we removed the last one
|
||||
if not subscribers:
|
||||
del subscriber_set[characteristic.handle]
|
||||
|
||||
@@ -28,7 +28,17 @@ import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import struct
|
||||
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
|
||||
from typing import (
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Type,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
from pyee import EventEmitter
|
||||
|
||||
from bumble.colors import color
|
||||
@@ -68,6 +78,7 @@ from bumble.gatt import (
|
||||
GATT_REQUEST_TIMEOUT,
|
||||
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
Characteristic,
|
||||
CharacteristicAdapter,
|
||||
CharacteristicDeclaration,
|
||||
CharacteristicValue,
|
||||
IncludedServiceDeclaration,
|
||||
@@ -353,7 +364,7 @@ class Server(EventEmitter):
|
||||
logger.debug(
|
||||
f'GATT Response from server: [0x{connection.handle:04X}] {response}'
|
||||
)
|
||||
self.send_gatt_pdu(connection.handle, response.to_bytes())
|
||||
self.send_gatt_pdu(connection.handle, bytes(response))
|
||||
|
||||
async def notify_subscriber(
|
||||
self,
|
||||
@@ -450,7 +461,7 @@ class Server(EventEmitter):
|
||||
)
|
||||
|
||||
try:
|
||||
self.send_gatt_pdu(connection.handle, indication.to_bytes())
|
||||
self.send_gatt_pdu(connection.handle, bytes(indication))
|
||||
await asyncio.wait_for(pending_confirmation, GATT_REQUEST_TIMEOUT)
|
||||
except asyncio.TimeoutError as error:
|
||||
logger.warning(color('!!! GATT Indicate timeout', 'red'))
|
||||
|
||||
1011
bumble/hci.py
1011
bumble/hci.py
File diff suppressed because it is too large
Load Diff
@@ -141,7 +141,7 @@ class HfFeature(enum.IntFlag):
|
||||
"""
|
||||
HF supported features (AT+BRSF=) (normative).
|
||||
|
||||
Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
|
||||
Hands-Free Profile v1.9, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
|
||||
"""
|
||||
|
||||
EC_NR = 0x001 # Echo Cancel & Noise reduction
|
||||
@@ -155,14 +155,14 @@ class HfFeature(enum.IntFlag):
|
||||
HF_INDICATORS = 0x100
|
||||
ESCO_S4_SETTINGS_SUPPORTED = 0x200
|
||||
ENHANCED_VOICE_RECOGNITION_STATUS = 0x400
|
||||
VOICE_RECOGNITION_TEST = 0x800
|
||||
VOICE_RECOGNITION_TEXT = 0x800
|
||||
|
||||
|
||||
class AgFeature(enum.IntFlag):
|
||||
"""
|
||||
AG supported features (+BRSF:) (normative).
|
||||
|
||||
Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
|
||||
Hands-Free Profile v1.9, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
|
||||
"""
|
||||
|
||||
THREE_WAY_CALLING = 0x001
|
||||
@@ -178,7 +178,7 @@ class AgFeature(enum.IntFlag):
|
||||
HF_INDICATORS = 0x400
|
||||
ESCO_S4_SETTINGS_SUPPORTED = 0x800
|
||||
ENHANCED_VOICE_RECOGNITION_STATUS = 0x1000
|
||||
VOICE_RECOGNITION_TEST = 0x2000
|
||||
VOICE_RECOGNITION_TEXT = 0x2000
|
||||
|
||||
|
||||
class AudioCodec(enum.IntEnum):
|
||||
@@ -1390,6 +1390,7 @@ class AgProtocol(pyee.EventEmitter):
|
||||
|
||||
def _on_bac(self, *args) -> None:
|
||||
self.supported_audio_codecs = [AudioCodec(int(value)) for value in args]
|
||||
self.emit('supported_audio_codecs', self.supported_audio_codecs)
|
||||
self.send_ok()
|
||||
|
||||
def _on_bcs(self, codec: bytes) -> None:
|
||||
@@ -1618,7 +1619,7 @@ class ProfileVersion(enum.IntEnum):
|
||||
"""
|
||||
Profile version (normative).
|
||||
|
||||
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements.
|
||||
Hands-Free Profile v1.8, 6.3 SDP Interoperability Requirements.
|
||||
"""
|
||||
|
||||
V1_5 = 0x0105
|
||||
@@ -1632,7 +1633,7 @@ class HfSdpFeature(enum.IntFlag):
|
||||
"""
|
||||
HF supported features (normative).
|
||||
|
||||
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements.
|
||||
Hands-Free Profile v1.9, 6.3 SDP Interoperability Requirements.
|
||||
"""
|
||||
|
||||
EC_NR = 0x01 # Echo Cancel & Noise reduction
|
||||
@@ -1640,16 +1641,17 @@ class HfSdpFeature(enum.IntFlag):
|
||||
CLI_PRESENTATION_CAPABILITY = 0x04
|
||||
VOICE_RECOGNITION_ACTIVATION = 0x08
|
||||
REMOTE_VOLUME_CONTROL = 0x10
|
||||
WIDE_BAND = 0x20 # Wide band speech
|
||||
WIDE_BAND_SPEECH = 0x20
|
||||
ENHANCED_VOICE_RECOGNITION_STATUS = 0x40
|
||||
VOICE_RECOGNITION_TEST = 0x80
|
||||
VOICE_RECOGNITION_TEXT = 0x80
|
||||
SUPER_WIDE_BAND = 0x100
|
||||
|
||||
|
||||
class AgSdpFeature(enum.IntFlag):
|
||||
"""
|
||||
AG supported features (normative).
|
||||
|
||||
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements.
|
||||
Hands-Free Profile v1.9, 6.3 SDP Interoperability Requirements.
|
||||
"""
|
||||
|
||||
THREE_WAY_CALLING = 0x01
|
||||
@@ -1657,9 +1659,10 @@ class AgSdpFeature(enum.IntFlag):
|
||||
VOICE_RECOGNITION_FUNCTION = 0x04
|
||||
IN_BAND_RING_TONE_CAPABILITY = 0x08
|
||||
VOICE_TAG = 0x10 # Attach a number to voice tag
|
||||
WIDE_BAND = 0x20 # Wide band speech
|
||||
WIDE_BAND_SPEECH = 0x20
|
||||
ENHANCED_VOICE_RECOGNITION_STATUS = 0x40
|
||||
VOICE_RECOGNITION_TEST = 0x80
|
||||
VOICE_RECOGNITION_TEXT = 0x80
|
||||
SUPER_WIDE_BAND_SPEED_SPEECH = 0x100
|
||||
|
||||
|
||||
def make_hf_sdp_records(
|
||||
@@ -1692,11 +1695,11 @@ def make_hf_sdp_records(
|
||||
in configuration.supported_hf_features
|
||||
):
|
||||
hf_supported_features |= HfSdpFeature.ENHANCED_VOICE_RECOGNITION_STATUS
|
||||
if HfFeature.VOICE_RECOGNITION_TEST in configuration.supported_hf_features:
|
||||
hf_supported_features |= HfSdpFeature.VOICE_RECOGNITION_TEST
|
||||
if HfFeature.VOICE_RECOGNITION_TEXT in configuration.supported_hf_features:
|
||||
hf_supported_features |= HfSdpFeature.VOICE_RECOGNITION_TEXT
|
||||
|
||||
if AudioCodec.MSBC in configuration.supported_audio_codecs:
|
||||
hf_supported_features |= HfSdpFeature.WIDE_BAND
|
||||
hf_supported_features |= HfSdpFeature.WIDE_BAND_SPEECH
|
||||
|
||||
return [
|
||||
sdp.ServiceAttribute(
|
||||
@@ -1772,14 +1775,14 @@ def make_ag_sdp_records(
|
||||
in configuration.supported_ag_features
|
||||
):
|
||||
ag_supported_features |= AgSdpFeature.ENHANCED_VOICE_RECOGNITION_STATUS
|
||||
if AgFeature.VOICE_RECOGNITION_TEST in configuration.supported_ag_features:
|
||||
ag_supported_features |= AgSdpFeature.VOICE_RECOGNITION_TEST
|
||||
if AgFeature.VOICE_RECOGNITION_TEXT in configuration.supported_ag_features:
|
||||
ag_supported_features |= AgSdpFeature.VOICE_RECOGNITION_TEXT
|
||||
if AgFeature.IN_BAND_RING_TONE_CAPABILITY in configuration.supported_ag_features:
|
||||
ag_supported_features |= AgSdpFeature.IN_BAND_RING_TONE_CAPABILITY
|
||||
if AgFeature.VOICE_RECOGNITION_FUNCTION in configuration.supported_ag_features:
|
||||
ag_supported_features |= AgSdpFeature.VOICE_RECOGNITION_FUNCTION
|
||||
if AudioCodec.MSBC in configuration.supported_audio_codecs:
|
||||
ag_supported_features |= AgSdpFeature.WIDE_BAND
|
||||
ag_supported_features |= AgSdpFeature.WIDE_BAND_SPEECH
|
||||
|
||||
return [
|
||||
sdp.ServiceAttribute(
|
||||
|
||||
429
bumble/host.py
429
bumble/host.py
@@ -1,4 +1,4 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
# Copyright 2021-2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -34,6 +34,8 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
import pyee
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble.l2cap import L2CAP_PDU
|
||||
from bumble.snoop import Snooper
|
||||
@@ -59,7 +61,19 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AclPacketQueue:
|
||||
class DataPacketQueue(pyee.EventEmitter):
|
||||
"""
|
||||
Flow-control queue for host->controller data packets (ACL, ISO).
|
||||
|
||||
The queue holds packets associated with a connection handle. The packets
|
||||
are sent to the controller, up to a maximum total number of packets in flight.
|
||||
A packet is considered to be "in flight" when it has been sent to the controller
|
||||
but not completed yet. Packets are no longer "in flight" when the controller
|
||||
declares them as completed.
|
||||
|
||||
The queue emits a 'flow' event whenever one or more packets are completed.
|
||||
"""
|
||||
|
||||
max_packet_size: int
|
||||
|
||||
def __init__(
|
||||
@@ -68,40 +82,105 @@ class AclPacketQueue:
|
||||
max_in_flight: int,
|
||||
send: Callable[[hci.HCI_Packet], None],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.max_packet_size = max_packet_size
|
||||
self.max_in_flight = max_in_flight
|
||||
self.in_flight = 0
|
||||
self.send = send
|
||||
self.packets: Deque[hci.HCI_AclDataPacket] = collections.deque()
|
||||
self._in_flight = 0 # Total number of packets in flight across all connections
|
||||
self._in_flight_per_connection: dict[int, int] = collections.defaultdict(
|
||||
int
|
||||
) # Number of packets in flight per connection
|
||||
self._send = send
|
||||
self._packets: Deque[tuple[hci.HCI_Packet, int]] = collections.deque()
|
||||
self._queued = 0
|
||||
self._completed = 0
|
||||
|
||||
def enqueue(self, packet: hci.HCI_AclDataPacket) -> None:
|
||||
self.packets.appendleft(packet)
|
||||
self.check_queue()
|
||||
@property
|
||||
def queued(self) -> int:
|
||||
"""Total number of packets queued since creation."""
|
||||
return self._queued
|
||||
|
||||
if self.packets:
|
||||
@property
|
||||
def completed(self) -> int:
|
||||
"""Total number of packets completed since creation."""
|
||||
return self._completed
|
||||
|
||||
@property
|
||||
def pending(self) -> int:
|
||||
"""Number of packets that have been queued but not completed."""
|
||||
return self._queued - self._completed
|
||||
|
||||
def enqueue(self, packet: hci.HCI_Packet, connection_handle: int) -> None:
|
||||
"""Enqueue a packet associated with a connection"""
|
||||
self._packets.appendleft((packet, connection_handle))
|
||||
self._queued += 1
|
||||
self._check_queue()
|
||||
|
||||
if self._packets:
|
||||
logger.debug(
|
||||
f'{self.in_flight} ACL packets in flight, '
|
||||
f'{len(self.packets)} in queue'
|
||||
f'{self._in_flight} packets in flight, '
|
||||
f'{len(self._packets)} in queue'
|
||||
)
|
||||
|
||||
def check_queue(self) -> None:
|
||||
while self.packets and self.in_flight < self.max_in_flight:
|
||||
packet = self.packets.pop()
|
||||
self.send(packet)
|
||||
self.in_flight += 1
|
||||
def flush(self, connection_handle: int) -> None:
|
||||
"""
|
||||
Remove all packets associated with a connection.
|
||||
|
||||
def on_packets_completed(self, packet_count: int) -> None:
|
||||
if packet_count > self.in_flight:
|
||||
All packets associated with the connection that are in flight are implicitly
|
||||
marked as completed, but no 'flow' event is emitted.
|
||||
"""
|
||||
|
||||
packets_to_keep = [
|
||||
(packet, handle)
|
||||
for (packet, handle) in self._packets
|
||||
if handle != connection_handle
|
||||
]
|
||||
if flushed_count := len(self._packets) - len(packets_to_keep):
|
||||
self._completed += flushed_count
|
||||
self._packets = collections.deque(packets_to_keep)
|
||||
|
||||
if connection_handle in self._in_flight_per_connection:
|
||||
in_flight = self._in_flight_per_connection[connection_handle]
|
||||
self._completed += in_flight
|
||||
self._in_flight -= in_flight
|
||||
del self._in_flight_per_connection[connection_handle]
|
||||
|
||||
def _check_queue(self) -> None:
|
||||
while self._packets and self._in_flight < self.max_in_flight:
|
||||
packet, connection_handle = self._packets.pop()
|
||||
self._send(packet)
|
||||
self._in_flight += 1
|
||||
self._in_flight_per_connection[connection_handle] += 1
|
||||
|
||||
def on_packets_completed(self, packet_count: int, connection_handle: int) -> None:
|
||||
"""Mark one or more packets associated with a connection as completed."""
|
||||
if connection_handle not in self._in_flight_per_connection:
|
||||
logger.warning(
|
||||
color(
|
||||
'!!! {packet_count} completed but only '
|
||||
f'{self.in_flight} in flight'
|
||||
)
|
||||
f'received completion for unknown connection {connection_handle}'
|
||||
)
|
||||
packet_count = self.in_flight
|
||||
return
|
||||
|
||||
self.in_flight -= packet_count
|
||||
self.check_queue()
|
||||
in_flight_for_connection = self._in_flight_per_connection[connection_handle]
|
||||
if packet_count <= in_flight_for_connection:
|
||||
self._in_flight_per_connection[connection_handle] -= packet_count
|
||||
else:
|
||||
logger.warning(
|
||||
f'{packet_count} completed for {connection_handle} '
|
||||
f'but only {in_flight_for_connection} in flight'
|
||||
)
|
||||
self._in_flight_per_connection[connection_handle] = 0
|
||||
|
||||
if packet_count <= self._in_flight:
|
||||
self._in_flight -= packet_count
|
||||
self._completed += packet_count
|
||||
else:
|
||||
logger.warning(
|
||||
f'{packet_count} completed but only {self._in_flight} in flight'
|
||||
)
|
||||
self._in_flight = 0
|
||||
self._completed = self._queued
|
||||
|
||||
self._check_queue()
|
||||
self.emit('flow')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -114,7 +193,7 @@ class Connection:
|
||||
self.peer_address = peer_address
|
||||
self.assembler = hci.HCI_AclDataPacketAssembler(self.on_acl_pdu)
|
||||
self.transport = transport
|
||||
acl_packet_queue: Optional[AclPacketQueue] = (
|
||||
acl_packet_queue: Optional[DataPacketQueue] = (
|
||||
host.le_acl_packet_queue
|
||||
if transport == BT_LE_TRANSPORT
|
||||
else host.acl_packet_queue
|
||||
@@ -129,28 +208,37 @@ class Connection:
|
||||
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
|
||||
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f'Connection(transport={self.transport}, peer_address={self.peer_address})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class ScoLink:
|
||||
peer_address: hci.Address
|
||||
handle: int
|
||||
connection_handle: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class CisLink:
|
||||
peer_address: hci.Address
|
||||
class IsoLink:
|
||||
handle: int
|
||||
packet_queue: DataPacketQueue = dataclasses.field(repr=False)
|
||||
packet_sequence_number: int = 0
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Host(AbortableEventEmitter):
|
||||
connections: Dict[int, Connection]
|
||||
cis_links: Dict[int, CisLink]
|
||||
cis_links: Dict[int, IsoLink]
|
||||
bis_links: Dict[int, IsoLink]
|
||||
sco_links: Dict[int, ScoLink]
|
||||
acl_packet_queue: Optional[AclPacketQueue] = None
|
||||
le_acl_packet_queue: Optional[AclPacketQueue] = None
|
||||
bigs: dict[int, set[int]] = {} # BIG Handle to BIS Handles
|
||||
acl_packet_queue: Optional[DataPacketQueue] = None
|
||||
le_acl_packet_queue: Optional[DataPacketQueue] = None
|
||||
iso_packet_queue: Optional[DataPacketQueue] = None
|
||||
hci_sink: Optional[TransportSink] = None
|
||||
hci_metadata: Dict[str, Any]
|
||||
long_term_key_provider: Optional[
|
||||
@@ -169,6 +257,7 @@ class Host(AbortableEventEmitter):
|
||||
self.ready = False # True when we can accept incoming packets
|
||||
self.connections = {} # Connections, by connection handle
|
||||
self.cis_links = {} # CIS links, by connection handle
|
||||
self.bis_links = {} # BIS links, by connection handle
|
||||
self.sco_links = {} # SCO links, by connection handle
|
||||
self.pending_command = None
|
||||
self.pending_response: Optional[asyncio.Future[Any]] = None
|
||||
@@ -199,7 +288,7 @@ class Host(AbortableEventEmitter):
|
||||
check_address_type: bool = False,
|
||||
) -> Optional[Connection]:
|
||||
for connection in self.connections.values():
|
||||
if connection.peer_address.to_bytes() == bd_addr.to_bytes():
|
||||
if bytes(connection.peer_address) == bytes(bd_addr):
|
||||
if (
|
||||
check_address_type
|
||||
and connection.peer_address.address_type != bd_addr.address_type
|
||||
@@ -387,6 +476,12 @@ class Host(AbortableEventEmitter):
|
||||
hci.HCI_LE_TRANSMIT_POWER_REPORTING_EVENT,
|
||||
hci.HCI_LE_BIGINFO_ADVERTISING_REPORT_EVENT,
|
||||
hci.HCI_LE_SUBRATE_CHANGE_EVENT,
|
||||
hci.HCI_LE_CS_READ_REMOTE_SUPPORTED_CAPABILITIES_COMPLETE_EVENT,
|
||||
hci.HCI_LE_CS_PROCEDURE_ENABLE_COMPLETE_EVENT,
|
||||
hci.HCI_LE_CS_SECURITY_ENABLE_COMPLETE_EVENT,
|
||||
hci.HCI_LE_CS_CONFIG_COMPLETE_EVENT,
|
||||
hci.HCI_LE_CS_SUBEVENT_RESULT_EVENT,
|
||||
hci.HCI_LE_CS_SUBEVENT_RESULT_CONTINUE_EVENT,
|
||||
]
|
||||
)
|
||||
|
||||
@@ -411,39 +506,70 @@ class Host(AbortableEventEmitter):
|
||||
f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}'
|
||||
)
|
||||
|
||||
self.acl_packet_queue = AclPacketQueue(
|
||||
self.acl_packet_queue = DataPacketQueue(
|
||||
max_packet_size=hc_acl_data_packet_length,
|
||||
max_in_flight=hc_total_num_acl_data_packets,
|
||||
send=self.send_hci_packet,
|
||||
)
|
||||
|
||||
hc_le_acl_data_packet_length = 0
|
||||
hc_total_num_le_acl_data_packets = 0
|
||||
if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
le_acl_data_packet_length = 0
|
||||
total_num_le_acl_data_packets = 0
|
||||
iso_data_packet_length = 0
|
||||
total_num_iso_data_packets = 0
|
||||
if self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_V2_COMMAND):
|
||||
response = await self.send_command(
|
||||
hci.HCI_LE_Read_Buffer_Size_V2_Command(), check_result=True
|
||||
)
|
||||
le_acl_data_packet_length = (
|
||||
response.return_parameters.le_acl_data_packet_length
|
||||
)
|
||||
total_num_le_acl_data_packets = (
|
||||
response.return_parameters.total_num_le_acl_data_packets
|
||||
)
|
||||
iso_data_packet_length = response.return_parameters.iso_data_packet_length
|
||||
total_num_iso_data_packets = (
|
||||
response.return_parameters.total_num_iso_data_packets
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'HCI LE flow control: '
|
||||
f'le_acl_data_packet_length={le_acl_data_packet_length},'
|
||||
f'total_num_le_acl_data_packets={total_num_le_acl_data_packets}'
|
||||
f'iso_data_packet_length={iso_data_packet_length},'
|
||||
f'total_num_iso_data_packets={total_num_iso_data_packets}'
|
||||
)
|
||||
elif self.supports_command(hci.HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await self.send_command(
|
||||
hci.HCI_LE_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
hc_le_acl_data_packet_length = (
|
||||
response.return_parameters.hc_le_acl_data_packet_length
|
||||
le_acl_data_packet_length = (
|
||||
response.return_parameters.le_acl_data_packet_length
|
||||
)
|
||||
hc_total_num_le_acl_data_packets = (
|
||||
response.return_parameters.hc_total_num_le_acl_data_packets
|
||||
total_num_le_acl_data_packets = (
|
||||
response.return_parameters.total_num_le_acl_data_packets
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'HCI LE ACL flow control: '
|
||||
f'hc_le_acl_data_packet_length={hc_le_acl_data_packet_length},'
|
||||
f'hc_total_num_le_acl_data_packets={hc_total_num_le_acl_data_packets}'
|
||||
f'le_acl_data_packet_length={le_acl_data_packet_length},'
|
||||
f'total_num_le_acl_data_packets={total_num_le_acl_data_packets}'
|
||||
)
|
||||
|
||||
if hc_le_acl_data_packet_length == 0 or hc_total_num_le_acl_data_packets == 0:
|
||||
if le_acl_data_packet_length == 0 or total_num_le_acl_data_packets == 0:
|
||||
# LE and Classic share the same queue
|
||||
self.le_acl_packet_queue = self.acl_packet_queue
|
||||
else:
|
||||
# Create a separate queue for LE
|
||||
self.le_acl_packet_queue = AclPacketQueue(
|
||||
max_packet_size=hc_le_acl_data_packet_length,
|
||||
max_in_flight=hc_total_num_le_acl_data_packets,
|
||||
self.le_acl_packet_queue = DataPacketQueue(
|
||||
max_packet_size=le_acl_data_packet_length,
|
||||
max_in_flight=total_num_le_acl_data_packets,
|
||||
send=self.send_hci_packet,
|
||||
)
|
||||
|
||||
if iso_data_packet_length and total_num_iso_data_packets:
|
||||
self.iso_packet_queue = DataPacketQueue(
|
||||
max_packet_size=iso_data_packet_length,
|
||||
max_in_flight=total_num_iso_data_packets,
|
||||
send=self.send_hci_packet,
|
||||
)
|
||||
|
||||
@@ -552,7 +678,7 @@ class Host(AbortableEventEmitter):
|
||||
|
||||
return response
|
||||
except Exception as error:
|
||||
logger.warning(
|
||||
logger.exception(
|
||||
f'{color("!!! Exception while sending command:", "red")} {error}'
|
||||
)
|
||||
raise error
|
||||
@@ -595,11 +721,78 @@ class Host(AbortableEventEmitter):
|
||||
data=l2cap_pdu[offset : offset + data_total_length],
|
||||
)
|
||||
logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}')
|
||||
packet_queue.enqueue(acl_packet)
|
||||
packet_queue.enqueue(acl_packet, connection_handle)
|
||||
pb_flag = 1
|
||||
offset += data_total_length
|
||||
bytes_remaining -= data_total_length
|
||||
|
||||
def get_data_packet_queue(self, connection_handle: int) -> DataPacketQueue | None:
|
||||
if connection := self.connections.get(connection_handle):
|
||||
return connection.acl_packet_queue
|
||||
|
||||
if iso_link := self.cis_links.get(connection_handle) or self.bis_links.get(
|
||||
connection_handle
|
||||
):
|
||||
return iso_link.packet_queue
|
||||
|
||||
return None
|
||||
|
||||
def send_iso_sdu(self, connection_handle: int, sdu: bytes) -> None:
|
||||
if not (
|
||||
iso_link := self.cis_links.get(connection_handle)
|
||||
or self.bis_links.get(connection_handle)
|
||||
):
|
||||
logger.warning(f"no ISO link for connection handle {connection_handle}")
|
||||
return
|
||||
|
||||
if iso_link.packet_queue is None:
|
||||
logger.warning("ISO link has no data packet queue")
|
||||
return
|
||||
|
||||
bytes_remaining = len(sdu)
|
||||
offset = 0
|
||||
while bytes_remaining:
|
||||
is_first_fragment = offset == 0
|
||||
header_length = 4 if is_first_fragment else 0
|
||||
assert iso_link.packet_queue.max_packet_size > header_length
|
||||
fragment_length = min(
|
||||
bytes_remaining, iso_link.packet_queue.max_packet_size - header_length
|
||||
)
|
||||
is_last_fragment = bytes_remaining == fragment_length
|
||||
iso_sdu_fragment = sdu[offset : offset + fragment_length]
|
||||
iso_link.packet_queue.enqueue(
|
||||
(
|
||||
hci.HCI_IsoDataPacket(
|
||||
connection_handle=connection_handle,
|
||||
data_total_length=header_length + fragment_length,
|
||||
packet_sequence_number=iso_link.packet_sequence_number,
|
||||
pb_flag=0b10 if is_last_fragment else 0b00,
|
||||
packet_status_flag=0,
|
||||
iso_sdu_length=len(sdu),
|
||||
iso_sdu_fragment=iso_sdu_fragment,
|
||||
)
|
||||
if is_first_fragment
|
||||
else hci.HCI_IsoDataPacket(
|
||||
connection_handle=connection_handle,
|
||||
data_total_length=fragment_length,
|
||||
pb_flag=0b11 if is_last_fragment else 0b01,
|
||||
iso_sdu_fragment=iso_sdu_fragment,
|
||||
)
|
||||
),
|
||||
connection_handle,
|
||||
)
|
||||
|
||||
offset += fragment_length
|
||||
bytes_remaining -= fragment_length
|
||||
|
||||
iso_link.packet_sequence_number = (iso_link.packet_sequence_number + 1) & 0xFFFF
|
||||
|
||||
def remove_big(self, big_handle: int) -> None:
|
||||
if big := self.bigs.pop(big_handle, None):
|
||||
for connection_handle in big:
|
||||
if bis_link := self.bis_links.pop(connection_handle, None):
|
||||
bis_link.packet_queue.flush(bis_link.handle)
|
||||
|
||||
def supports_command(self, op_code: int) -> bool:
|
||||
return (
|
||||
self.local_supported_commands
|
||||
@@ -727,16 +920,17 @@ class Host(AbortableEventEmitter):
|
||||
def on_hci_command_status_event(self, event):
|
||||
return self.on_command_processed(event)
|
||||
|
||||
def on_hci_number_of_completed_packets_event(self, event):
|
||||
def on_hci_number_of_completed_packets_event(
|
||||
self, event: hci.HCI_Number_Of_Completed_Packets_Event
|
||||
) -> None:
|
||||
for connection_handle, num_completed_packets in zip(
|
||||
event.connection_handles, event.num_completed_packets
|
||||
):
|
||||
if connection := self.connections.get(connection_handle):
|
||||
connection.acl_packet_queue.on_packets_completed(num_completed_packets)
|
||||
elif not (
|
||||
self.cis_links.get(connection_handle)
|
||||
or self.sco_links.get(connection_handle)
|
||||
):
|
||||
if queue := self.get_data_packet_queue(connection_handle):
|
||||
queue.on_packets_completed(num_completed_packets, connection_handle)
|
||||
continue
|
||||
|
||||
if connection_handle not in self.sco_links:
|
||||
logger.warning(
|
||||
'received packet completion event for unknown handle '
|
||||
f'0x{connection_handle:04X}'
|
||||
@@ -854,11 +1048,7 @@ class Host(AbortableEventEmitter):
|
||||
return
|
||||
|
||||
if event.status == hci.HCI_SUCCESS:
|
||||
logger.debug(
|
||||
f'### DISCONNECTION: [0x{handle:04X}] '
|
||||
f'{connection.peer_address} '
|
||||
f'reason={event.reason}'
|
||||
)
|
||||
logger.debug(f'### DISCONNECTION: {connection}, reason={event.reason}')
|
||||
|
||||
# Notify the listeners
|
||||
self.emit('disconnection', handle, event.reason)
|
||||
@@ -869,6 +1059,12 @@ class Host(AbortableEventEmitter):
|
||||
or self.cis_links.pop(handle, 0)
|
||||
or self.sco_links.pop(handle, 0)
|
||||
)
|
||||
|
||||
# Flush the data queues
|
||||
self.acl_packet_queue.flush(handle)
|
||||
self.le_acl_packet_queue.flush(handle)
|
||||
if self.iso_packet_queue:
|
||||
self.iso_packet_queue.flush(handle)
|
||||
else:
|
||||
logger.debug(f'### DISCONNECTION FAILED: {event.status}')
|
||||
|
||||
@@ -953,12 +1149,94 @@ class Host(AbortableEventEmitter):
|
||||
event.cis_id,
|
||||
)
|
||||
|
||||
def on_hci_le_create_big_complete_event(self, event):
|
||||
self.bigs[event.big_handle] = set(event.connection_handle)
|
||||
if self.iso_packet_queue is None:
|
||||
logger.warning("BIS established but ISO packets not supported")
|
||||
|
||||
for connection_handle in event.connection_handle:
|
||||
self.bis_links[connection_handle] = IsoLink(
|
||||
connection_handle, self.iso_packet_queue
|
||||
)
|
||||
|
||||
self.emit(
|
||||
'big_establishment',
|
||||
event.status,
|
||||
event.big_handle,
|
||||
event.connection_handle,
|
||||
event.big_sync_delay,
|
||||
event.transport_latency_big,
|
||||
event.phy,
|
||||
event.nse,
|
||||
event.bn,
|
||||
event.pto,
|
||||
event.irc,
|
||||
event.max_pdu,
|
||||
event.iso_interval,
|
||||
)
|
||||
|
||||
def on_hci_le_big_sync_established_event(self, event):
|
||||
self.bigs[event.big_handle] = set(event.connection_handle)
|
||||
for connection_handle in event.connection_handle:
|
||||
self.bis_links[connection_handle] = IsoLink(
|
||||
connection_handle, self.iso_packet_queue
|
||||
)
|
||||
|
||||
self.emit(
|
||||
'big_sync_establishment',
|
||||
event.status,
|
||||
event.big_handle,
|
||||
event.transport_latency_big,
|
||||
event.nse,
|
||||
event.bn,
|
||||
event.pto,
|
||||
event.irc,
|
||||
event.max_pdu,
|
||||
event.iso_interval,
|
||||
event.connection_handle,
|
||||
)
|
||||
|
||||
def on_hci_le_big_sync_lost_event(self, event):
|
||||
self.remove_big(event.big_handle)
|
||||
self.emit('big_sync_lost', event.big_handle, event.reason)
|
||||
|
||||
def on_hci_le_terminate_big_complete_event(self, event):
|
||||
self.remove_big(event.big_handle)
|
||||
self.emit('big_termination', event.reason, event.big_handle)
|
||||
|
||||
def on_hci_le_periodic_advertising_sync_transfer_received_event(self, event):
|
||||
self.emit(
|
||||
'periodic_advertising_sync_transfer',
|
||||
event.status,
|
||||
event.connection_handle,
|
||||
event.sync_handle,
|
||||
event.advertising_sid,
|
||||
event.advertiser_address,
|
||||
event.advertiser_phy,
|
||||
event.periodic_advertising_interval,
|
||||
event.advertiser_clock_accuracy,
|
||||
)
|
||||
|
||||
def on_hci_le_periodic_advertising_sync_transfer_received_v2_event(self, event):
|
||||
self.emit(
|
||||
'periodic_advertising_sync_transfer',
|
||||
event.status,
|
||||
event.connection_handle,
|
||||
event.sync_handle,
|
||||
event.advertising_sid,
|
||||
event.advertiser_address,
|
||||
event.advertiser_phy,
|
||||
event.periodic_advertising_interval,
|
||||
event.advertiser_clock_accuracy,
|
||||
)
|
||||
|
||||
def on_hci_le_cis_established_event(self, event):
|
||||
# The remaining parameters are unused for now.
|
||||
if event.status == hci.HCI_SUCCESS:
|
||||
self.cis_links[event.connection_handle] = CisLink(
|
||||
handle=event.connection_handle,
|
||||
peer_address=hci.Address.ANY,
|
||||
if self.iso_packet_queue is None:
|
||||
logger.warning("CIS established but ISO packets not supported")
|
||||
self.cis_links[event.connection_handle] = IsoLink(
|
||||
handle=event.connection_handle, packet_queue=self.iso_packet_queue
|
||||
)
|
||||
self.emit('cis_establishment', event.connection_handle)
|
||||
else:
|
||||
@@ -1028,7 +1306,7 @@ class Host(AbortableEventEmitter):
|
||||
|
||||
self.sco_links[event.connection_handle] = ScoLink(
|
||||
peer_address=event.bd_addr,
|
||||
handle=event.connection_handle,
|
||||
connection_handle=event.connection_handle,
|
||||
)
|
||||
|
||||
# Notify the client
|
||||
@@ -1248,3 +1526,24 @@ class Host(AbortableEventEmitter):
|
||||
event.connection_handle,
|
||||
int.from_bytes(event.le_features, 'little'),
|
||||
)
|
||||
|
||||
def on_hci_le_cs_read_remote_supported_capabilities_complete_event(self, event):
|
||||
self.emit('cs_remote_supported_capabilities', event)
|
||||
|
||||
def on_hci_le_cs_security_enable_complete_event(self, event):
|
||||
self.emit('cs_security', event)
|
||||
|
||||
def on_hci_le_cs_config_complete_event(self, event):
|
||||
self.emit('cs_config', event)
|
||||
|
||||
def on_hci_le_cs_procedure_enable_complete_event(self, event):
|
||||
self.emit('cs_procedure', event)
|
||||
|
||||
def on_hci_le_cs_subevent_result_event(self, event):
|
||||
self.emit('cs_subevent_result', event)
|
||||
|
||||
def on_hci_le_cs_subevent_result_continue_event(self, event):
|
||||
self.emit('cs_subevent_result_continue', event)
|
||||
|
||||
def on_hci_vendor_event(self, event):
|
||||
self.emit('vendor_event', event)
|
||||
|
||||
@@ -225,7 +225,7 @@ class L2CAP_PDU:
|
||||
|
||||
return L2CAP_PDU(l2cap_pdu_cid, l2cap_pdu_payload)
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
def __bytes__(self) -> bytes:
|
||||
header = struct.pack('<HH', len(self.payload), self.cid)
|
||||
return header + self.payload
|
||||
|
||||
@@ -233,9 +233,6 @@ class L2CAP_PDU:
|
||||
self.cid = cid
|
||||
self.payload = payload
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.to_bytes()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'{color("L2CAP", "green")} [CID={self.cid}]: {self.payload.hex()}'
|
||||
|
||||
@@ -333,11 +330,8 @@ class L2CAP_Control_Frame:
|
||||
def init_from_bytes(self, pdu, offset):
|
||||
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
||||
|
||||
def to_bytes(self) -> bytes:
|
||||
return self.pdu
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.to_bytes()
|
||||
return self.pdu
|
||||
|
||||
def __str__(self) -> str:
|
||||
result = f'{color(self.name, "yellow")} [ID={self.identifier}]'
|
||||
@@ -779,7 +773,6 @@ class ClassicChannel(EventEmitter):
|
||||
self.psm = psm
|
||||
self.source_cid = source_cid
|
||||
self.destination_cid = 0
|
||||
self.response = None
|
||||
self.connection_result = None
|
||||
self.disconnection_result = None
|
||||
self.sink = None
|
||||
@@ -789,27 +782,15 @@ class ClassicChannel(EventEmitter):
|
||||
self.state = new_state
|
||||
|
||||
def send_pdu(self, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||
if self.state != self.State.OPEN:
|
||||
raise InvalidStateError('channel not open')
|
||||
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
||||
|
||||
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
|
||||
self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
|
||||
|
||||
async def send_request(self, request: SupportsBytes) -> bytes:
|
||||
# Check that there isn't already a request pending
|
||||
if self.response:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.state != self.State.OPEN:
|
||||
raise InvalidStateError('channel not open')
|
||||
|
||||
self.response = asyncio.get_running_loop().create_future()
|
||||
self.send_pdu(request)
|
||||
return await self.response
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
if self.response:
|
||||
self.response.set_result(pdu)
|
||||
self.response = None
|
||||
elif self.sink:
|
||||
if self.sink:
|
||||
# pylint: disable=not-callable
|
||||
self.sink(pdu)
|
||||
else:
|
||||
@@ -1911,6 +1892,7 @@ class ChannelManager:
|
||||
data = sum(1 << cid for cid in self.fixed_channels).to_bytes(8, 'little')
|
||||
else:
|
||||
result = L2CAP_Information_Response.NOT_SUPPORTED
|
||||
data = b''
|
||||
|
||||
self.send_control_frame(
|
||||
connection,
|
||||
|
||||
@@ -122,6 +122,8 @@ class LocalLink:
|
||||
elif transport == BT_BR_EDR_TRANSPORT:
|
||||
destination_controller = self.find_classic_controller(destination_address)
|
||||
source_address = sender_controller.public_address
|
||||
else:
|
||||
raise ValueError("unsupported transport type")
|
||||
|
||||
if destination_controller is not None:
|
||||
destination_controller.on_link_acl_data(source_address, transport, data)
|
||||
|
||||
@@ -139,16 +139,19 @@ class PairingDelegate:
|
||||
io_capability: IoCapability
|
||||
local_initiator_key_distribution: KeyDistribution
|
||||
local_responder_key_distribution: KeyDistribution
|
||||
maximum_encryption_key_size: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
io_capability: IoCapability = NO_OUTPUT_NO_INPUT,
|
||||
local_initiator_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
|
||||
local_responder_key_distribution: KeyDistribution = DEFAULT_KEY_DISTRIBUTION,
|
||||
maximum_encryption_key_size: int = 16,
|
||||
) -> None:
|
||||
self.io_capability = io_capability
|
||||
self.local_initiator_key_distribution = local_initiator_key_distribution
|
||||
self.local_responder_key_distribution = local_responder_key_distribution
|
||||
self.maximum_encryption_key_size = maximum_encryption_key_size
|
||||
|
||||
@property
|
||||
def classic_io_capability(self) -> int:
|
||||
|
||||
@@ -39,7 +39,6 @@ from bumble.device import (
|
||||
AdvertisingEventProperties,
|
||||
AdvertisingType,
|
||||
Device,
|
||||
Phy,
|
||||
)
|
||||
from bumble.gatt import Service
|
||||
from bumble.hci import (
|
||||
@@ -47,6 +46,7 @@ from bumble.hci import (
|
||||
HCI_PAGE_TIMEOUT_ERROR,
|
||||
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
Address,
|
||||
Phy,
|
||||
)
|
||||
from google.protobuf import any_pb2 # pytype: disable=pyi-error
|
||||
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import struct
|
||||
|
||||
@@ -28,10 +29,11 @@ from bumble.device import Connection
|
||||
from bumble.att import ATT_Error
|
||||
from bumble.gatt import (
|
||||
Characteristic,
|
||||
DelegatedCharacteristicAdapter,
|
||||
SerializableCharacteristicAdapter,
|
||||
PackedCharacteristicAdapter,
|
||||
TemplateService,
|
||||
CharacteristicValue,
|
||||
PackedCharacteristicAdapter,
|
||||
UTF8CharacteristicAdapter,
|
||||
GATT_AUDIO_INPUT_CONTROL_SERVICE,
|
||||
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
|
||||
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
|
||||
@@ -95,7 +97,7 @@ class AudioInputStatus(OpenIntEnum):
|
||||
Cf. 3.4 Audio Input Status
|
||||
'''
|
||||
|
||||
INATIVE = 0x00
|
||||
INACTIVE = 0x00
|
||||
ACTIVE = 0x01
|
||||
|
||||
|
||||
@@ -104,7 +106,7 @@ class AudioInputControlPointOpCode(OpenIntEnum):
|
||||
Cf. 3.5.1 Audio Input Control Point procedure requirements
|
||||
'''
|
||||
|
||||
SET_GAIN_SETTING = 0x00
|
||||
SET_GAIN_SETTING = 0x01
|
||||
UNMUTE = 0x02
|
||||
MUTE = 0x03
|
||||
SET_MANUAL_GAIN_MODE = 0x04
|
||||
@@ -154,9 +156,6 @@ class AudioInputState:
|
||||
attribute=self.attribute_value, value=bytes(self)
|
||||
)
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
return bytes(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GainSettingsProperties:
|
||||
@@ -173,7 +172,7 @@ class GainSettingsProperties:
|
||||
(gain_settings_unit, gain_settings_minimum, gain_settings_maximum) = (
|
||||
struct.unpack('BBB', data)
|
||||
)
|
||||
GainSettingsProperties(
|
||||
return GainSettingsProperties(
|
||||
gain_settings_unit, gain_settings_minimum, gain_settings_maximum
|
||||
)
|
||||
|
||||
@@ -186,9 +185,6 @@ class GainSettingsProperties:
|
||||
]
|
||||
)
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
return bytes(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioInputControlPoint:
|
||||
@@ -239,7 +235,7 @@ class AudioInputControlPoint:
|
||||
or gain_settings_operand
|
||||
> self.gain_settings_properties.gain_settings_maximum
|
||||
):
|
||||
logger.error("gain_seetings value out of range")
|
||||
logger.error("gain_settings value out of range")
|
||||
raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE)
|
||||
|
||||
if self.audio_input_state.gain_settings != gain_settings_operand:
|
||||
@@ -321,21 +317,14 @@ class AudioInputDescription:
|
||||
audio_input_description: str = "Bluetooth"
|
||||
attribute_value: Optional[CharacteristicValue] = None
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
return cls(audio_input_description=data.decode('utf-8'))
|
||||
def on_read(self, _connection: Optional[Connection]) -> str:
|
||||
return self.audio_input_description
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.audio_input_description.encode('utf-8')
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
return self.audio_input_description.encode('utf-8')
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
async def on_write(self, connection: Optional[Connection], value: str) -> None:
|
||||
assert connection
|
||||
assert self.attribute_value
|
||||
|
||||
self.audio_input_description = value.decode('utf-8')
|
||||
self.audio_input_description = value
|
||||
await connection.device.notify_subscribers(
|
||||
attribute=self.attribute_value, value=value
|
||||
)
|
||||
@@ -375,26 +364,29 @@ class AICSService(TemplateService):
|
||||
self.audio_input_state, self.gain_settings_properties
|
||||
)
|
||||
|
||||
self.audio_input_state_characteristic = DelegatedCharacteristicAdapter(
|
||||
self.audio_input_state_characteristic = SerializableCharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_AUDIO_INPUT_STATE_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ
|
||||
| Characteristic.Properties.NOTIFY,
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=CharacteristicValue(read=self.audio_input_state.on_read),
|
||||
value=self.audio_input_state,
|
||||
),
|
||||
encode=lambda value: bytes(value),
|
||||
AudioInputState,
|
||||
)
|
||||
self.audio_input_state.attribute_value = (
|
||||
self.audio_input_state_characteristic.value
|
||||
)
|
||||
|
||||
self.gain_settings_properties_characteristic = DelegatedCharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=CharacteristicValue(read=self.gain_settings_properties.on_read),
|
||||
self.gain_settings_properties_characteristic = (
|
||||
SerializableCharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=self.gain_settings_properties,
|
||||
),
|
||||
GainSettingsProperties,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -402,7 +394,7 @@ class AICSService(TemplateService):
|
||||
uuid=GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=audio_input_type,
|
||||
value=bytes(audio_input_type, 'utf-8'),
|
||||
)
|
||||
|
||||
self.audio_input_status_characteristic = Characteristic(
|
||||
@@ -412,18 +404,14 @@ class AICSService(TemplateService):
|
||||
value=bytes([self.audio_input_status]),
|
||||
)
|
||||
|
||||
self.audio_input_control_point_characteristic = DelegatedCharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.WRITE,
|
||||
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=CharacteristicValue(
|
||||
write=self.audio_input_control_point.on_write
|
||||
),
|
||||
)
|
||||
self.audio_input_control_point_characteristic = Characteristic(
|
||||
uuid=GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.WRITE,
|
||||
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=CharacteristicValue(write=self.audio_input_control_point.on_write),
|
||||
)
|
||||
|
||||
self.audio_input_description_characteristic = DelegatedCharacteristicAdapter(
|
||||
self.audio_input_description_characteristic = UTF8CharacteristicAdapter(
|
||||
Characteristic(
|
||||
uuid=GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ
|
||||
@@ -463,58 +451,35 @@ class AICSServiceProxy(ProfileServiceProxy):
|
||||
def __init__(self, service_proxy: ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
self.audio_input_state = SerializableCharacteristicAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError("Audio Input State Characteristic not found")
|
||||
self.audio_input_state = DelegatedCharacteristicAdapter(
|
||||
characteristic=characteristics[0], decode=AudioInputState.from_bytes
|
||||
),
|
||||
AudioInputState,
|
||||
)
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
self.gain_settings_properties = SerializableCharacteristicAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Gain Settings Attribute Characteristic not found"
|
||||
)
|
||||
self.gain_settings_properties = PackedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
'BBB',
|
||||
),
|
||||
GainSettingsProperties,
|
||||
)
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Audio Input Status Characteristic not found"
|
||||
)
|
||||
self.audio_input_status = PackedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC
|
||||
),
|
||||
'B',
|
||||
)
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
self.audio_input_control_point = (
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Audio Input Control Point Characteristic not found"
|
||||
)
|
||||
self.audio_input_control_point = characteristics[0]
|
||||
)
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
self.audio_input_description = UTF8CharacteristicAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Audio Input Description Characteristic not found"
|
||||
)
|
||||
self.audio_input_description = characteristics[0]
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import struct
|
||||
@@ -258,8 +259,8 @@ class AseReasonCode(enum.IntEnum):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class AudioRole(enum.IntEnum):
|
||||
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
|
||||
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
|
||||
SINK = device.CisLink.Direction.CONTROLLER_TO_HOST
|
||||
SOURCE = device.CisLink.Direction.HOST_TO_CONTROLLER
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -354,16 +355,7 @@ class AseStateMachine(gatt.Characteristic):
|
||||
cis_link.on('disconnection', self.on_cis_disconnection)
|
||||
|
||||
async def post_cis_established():
|
||||
await self.service.device.send_command(
|
||||
hci.HCI_LE_Setup_ISO_Data_Path_Command(
|
||||
connection_handle=cis_link.handle,
|
||||
data_path_direction=self.role,
|
||||
data_path_id=0x00, # Fixed HCI
|
||||
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
|
||||
controller_delay=0,
|
||||
codec_configuration=b'',
|
||||
)
|
||||
)
|
||||
await cis_link.setup_data_path(direction=self.role)
|
||||
if self.role == AudioRole.SINK:
|
||||
self.state = self.State.STREAMING
|
||||
await self.service.device.notify_subscribers(self, self.value)
|
||||
@@ -511,12 +503,8 @@ class AseStateMachine(gatt.Characteristic):
|
||||
self.state = self.State.RELEASING
|
||||
|
||||
async def remove_cis_async():
|
||||
await self.service.device.send_command(
|
||||
hci.HCI_LE_Remove_ISO_Data_Path_Command(
|
||||
connection_handle=self.cis_link.handle,
|
||||
data_path_direction=self.role,
|
||||
)
|
||||
)
|
||||
if self.cis_link:
|
||||
await self.cis_link.remove_data_path(self.role)
|
||||
self.state = self.State.IDLE
|
||||
await self.service.device.notify_subscribers(self, self.value)
|
||||
|
||||
|
||||
@@ -288,8 +288,8 @@ class AshaServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
'psm_characteristic',
|
||||
),
|
||||
):
|
||||
if not (
|
||||
characteristics := self.service_proxy.get_characteristics_by_uuid(uuid)
|
||||
):
|
||||
raise gatt.InvalidServiceError(f"Missing {uuid} Characteristic")
|
||||
setattr(self, attribute_name, characteristics[0])
|
||||
setattr(
|
||||
self,
|
||||
attribute_name,
|
||||
self.service_proxy.get_required_characteristic_by_uuid(uuid),
|
||||
)
|
||||
|
||||
@@ -102,6 +102,7 @@ class ContextType(enum.IntFlag):
|
||||
|
||||
# fmt: off
|
||||
PROHIBITED = 0x0000
|
||||
UNSPECIFIED = 0x0001
|
||||
CONVERSATIONAL = 0x0002
|
||||
MEDIA = 0x0004
|
||||
GAME = 0x0008
|
||||
@@ -264,7 +265,7 @@ class UnicastServerAdvertisingData:
|
||||
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
|
||||
struct.pack(
|
||||
'<2sBIB',
|
||||
gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE.to_bytes(),
|
||||
bytes(gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE),
|
||||
self.announcement_type,
|
||||
self.available_audio_contexts,
|
||||
len(self.metadata),
|
||||
@@ -350,6 +351,7 @@ class CodecSpecificCapabilities:
|
||||
supported_max_codec_frames_per_sdu = value
|
||||
|
||||
# It is expected here that if some fields are missing, an error should be raised.
|
||||
# pylint: disable=possibly-used-before-assignment,used-before-assignment
|
||||
return CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=supported_sampling_frequencies,
|
||||
supported_frame_durations=supported_frame_durations,
|
||||
@@ -396,18 +398,21 @@ class CodecSpecificConfiguration:
|
||||
OCTETS_PER_FRAME = 0x04
|
||||
CODEC_FRAMES_PER_SDU = 0x05
|
||||
|
||||
sampling_frequency: SamplingFrequency
|
||||
frame_duration: FrameDuration
|
||||
audio_channel_allocation: AudioLocation
|
||||
octets_per_codec_frame: int
|
||||
codec_frames_per_sdu: int
|
||||
sampling_frequency: SamplingFrequency | None = None
|
||||
frame_duration: FrameDuration | None = None
|
||||
audio_channel_allocation: AudioLocation | None = None
|
||||
octets_per_codec_frame: int | None = None
|
||||
codec_frames_per_sdu: int | None = None
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> CodecSpecificConfiguration:
|
||||
offset = 0
|
||||
# Allowed default values.
|
||||
audio_channel_allocation = AudioLocation.NOT_ALLOWED
|
||||
codec_frames_per_sdu = 1
|
||||
sampling_frequency: SamplingFrequency | None = None
|
||||
frame_duration: FrameDuration | None = None
|
||||
audio_channel_allocation: AudioLocation | None = None
|
||||
octets_per_codec_frame: int | None = None
|
||||
codec_frames_per_sdu: int | None = None
|
||||
|
||||
while offset < len(data):
|
||||
length, type = struct.unpack_from('BB', data, offset)
|
||||
offset += 2
|
||||
@@ -425,7 +430,6 @@ class CodecSpecificConfiguration:
|
||||
elif type == CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU:
|
||||
codec_frames_per_sdu = value
|
||||
|
||||
# It is expected here that if some fields are missing, an error should be raised.
|
||||
return CodecSpecificConfiguration(
|
||||
sampling_frequency=sampling_frequency,
|
||||
frame_duration=frame_duration,
|
||||
@@ -435,23 +439,43 @@ class CodecSpecificConfiguration:
|
||||
)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack(
|
||||
'<BBBBBBBBIBBHBBB',
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY,
|
||||
self.sampling_frequency,
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.FRAME_DURATION,
|
||||
self.frame_duration,
|
||||
5,
|
||||
CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION,
|
||||
self.audio_channel_allocation,
|
||||
3,
|
||||
CodecSpecificConfiguration.Type.OCTETS_PER_FRAME,
|
||||
self.octets_per_codec_frame,
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU,
|
||||
self.codec_frames_per_sdu,
|
||||
return b''.join(
|
||||
[
|
||||
struct.pack(fmt, length, tag, value)
|
||||
for fmt, length, tag, value in [
|
||||
(
|
||||
'<BBB',
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY,
|
||||
self.sampling_frequency,
|
||||
),
|
||||
(
|
||||
'<BBB',
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.FRAME_DURATION,
|
||||
self.frame_duration,
|
||||
),
|
||||
(
|
||||
'<BBI',
|
||||
5,
|
||||
CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION,
|
||||
self.audio_channel_allocation,
|
||||
),
|
||||
(
|
||||
'<BBH',
|
||||
3,
|
||||
CodecSpecificConfiguration.Type.OCTETS_PER_FRAME,
|
||||
self.octets_per_codec_frame,
|
||||
),
|
||||
(
|
||||
'<BBB',
|
||||
2,
|
||||
CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU,
|
||||
self.codec_frames_per_sdu,
|
||||
),
|
||||
]
|
||||
if value is not None
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -463,6 +487,24 @@ class BroadcastAudioAnnouncement:
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
return cls(int.from_bytes(data[:3], 'little'))
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.broadcast_id.to_bytes(3, 'little')
|
||||
|
||||
def get_advertising_data(self) -> bytes:
|
||||
return bytes(
|
||||
core.AdvertisingData(
|
||||
[
|
||||
(
|
||||
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
|
||||
(
|
||||
bytes(gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE)
|
||||
+ bytes(self)
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BasicAudioAnnouncement:
|
||||
@@ -471,26 +513,37 @@ class BasicAudioAnnouncement:
|
||||
index: int
|
||||
codec_specific_configuration: CodecSpecificConfiguration
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CodecInfo:
|
||||
coding_format: hci.CodecID
|
||||
company_id: int
|
||||
vendor_specific_codec_id: int
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self:
|
||||
coding_format = hci.CodecID(data[0])
|
||||
company_id = int.from_bytes(data[1:3], 'little')
|
||||
vendor_specific_codec_id = int.from_bytes(data[3:5], 'little')
|
||||
return cls(coding_format, company_id, vendor_specific_codec_id)
|
||||
def __bytes__(self) -> bytes:
|
||||
codec_specific_configuration_bytes = bytes(
|
||||
self.codec_specific_configuration
|
||||
)
|
||||
return (
|
||||
bytes([self.index, len(codec_specific_configuration_bytes)])
|
||||
+ codec_specific_configuration_bytes
|
||||
)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Subgroup:
|
||||
codec_id: BasicAudioAnnouncement.CodecInfo
|
||||
codec_id: hci.CodingFormat
|
||||
codec_specific_configuration: CodecSpecificConfiguration
|
||||
metadata: le_audio.Metadata
|
||||
bis: List[BasicAudioAnnouncement.BIS]
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
metadata_bytes = bytes(self.metadata)
|
||||
codec_specific_configuration_bytes = bytes(
|
||||
self.codec_specific_configuration
|
||||
)
|
||||
return (
|
||||
bytes([len(self.bis)])
|
||||
+ bytes(self.codec_id)
|
||||
+ bytes([len(codec_specific_configuration_bytes)])
|
||||
+ codec_specific_configuration_bytes
|
||||
+ bytes([len(metadata_bytes)])
|
||||
+ metadata_bytes
|
||||
+ b''.join(map(bytes, self.bis))
|
||||
)
|
||||
|
||||
presentation_delay: int
|
||||
subgroups: List[BasicAudioAnnouncement.Subgroup]
|
||||
|
||||
@@ -502,7 +555,7 @@ class BasicAudioAnnouncement:
|
||||
for _ in range(data[3]):
|
||||
num_bis = data[offset]
|
||||
offset += 1
|
||||
codec_id = cls.CodecInfo.from_bytes(data[offset : offset + 5])
|
||||
codec_id = hci.CodingFormat.from_bytes(data[offset : offset + 5])
|
||||
offset += 5
|
||||
codec_specific_configuration_length = data[offset]
|
||||
offset += 1
|
||||
@@ -546,3 +599,25 @@ class BasicAudioAnnouncement:
|
||||
)
|
||||
|
||||
return cls(presentation_delay, subgroups)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return (
|
||||
self.presentation_delay.to_bytes(3, 'little')
|
||||
+ bytes([len(self.subgroups)])
|
||||
+ b''.join(map(bytes, self.subgroups))
|
||||
)
|
||||
|
||||
def get_advertising_data(self) -> bytes:
|
||||
return bytes(
|
||||
core.AdvertisingData(
|
||||
[
|
||||
(
|
||||
core.AdvertisingData.SERVICE_DATA_16_BIT_UUID,
|
||||
(
|
||||
bytes(gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE)
|
||||
+ bytes(self)
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
@@ -276,10 +276,7 @@ class BroadcastReceiveState:
|
||||
subgroups: List[SubgroupInfo]
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Optional[BroadcastReceiveState]:
|
||||
if not data:
|
||||
return None
|
||||
|
||||
def from_bytes(cls, data: bytes) -> BroadcastReceiveState:
|
||||
source_id = data[0]
|
||||
_, source_address = hci.Address.parse_address_preceded_by_type(data, 2)
|
||||
source_adv_sid = data[8]
|
||||
@@ -362,29 +359,20 @@ class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
self.broadcast_audio_scan_control_point = (
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Broadcast Audio Scan Control Point characteristic not found"
|
||||
)
|
||||
self.broadcast_audio_scan_control_point = characteristics[0]
|
||||
)
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise gatt.InvalidServiceError(
|
||||
"Broadcast Receive State characteristic not found"
|
||||
)
|
||||
self.broadcast_receive_states = [
|
||||
gatt.DelegatedCharacteristicAdapter(
|
||||
characteristic, decode=BroadcastReceiveState.from_bytes
|
||||
characteristic,
|
||||
decode=lambda x: BroadcastReceiveState.from_bytes(x) if x else None,
|
||||
)
|
||||
for characteristic in service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC
|
||||
)
|
||||
for characteristic in characteristics
|
||||
]
|
||||
|
||||
async def send_control_point_operation(
|
||||
|
||||
@@ -64,7 +64,10 @@ class DeviceInformationService(TemplateService):
|
||||
):
|
||||
characteristics = [
|
||||
Characteristic(
|
||||
uuid, Characteristic.Properties.READ, Characteristic.READABLE, field
|
||||
uuid,
|
||||
Characteristic.Properties.READ,
|
||||
Characteristic.READABLE,
|
||||
bytes(field, 'utf-8'),
|
||||
)
|
||||
for (field, uuid) in (
|
||||
(manufacturer_name, GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
|
||||
|
||||
166
bumble/profiles/gatt_service.py
Normal file
166
bumble/profiles/gatt_service.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Copyright 2021-2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from bumble import att
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble import crypto
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble import device
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class GenericAttributeProfileService(gatt.TemplateService):
|
||||
'''See Vol 3, Part G - 7 - DEFINED GENERIC ATTRIBUTE PROFILE SERVICE.'''
|
||||
|
||||
UUID = gatt.GATT_GENERIC_ATTRIBUTE_SERVICE
|
||||
|
||||
client_supported_features_characteristic: gatt.Characteristic | None = None
|
||||
server_supported_features_characteristic: gatt.Characteristic | None = None
|
||||
database_hash_characteristic: gatt.Characteristic | None = None
|
||||
service_changed_characteristic: gatt.Characteristic | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_supported_features: gatt.ServerSupportedFeatures | None = None,
|
||||
database_hash_enabled: bool = True,
|
||||
service_change_enabled: bool = True,
|
||||
) -> None:
|
||||
|
||||
if server_supported_features is not None:
|
||||
self.server_supported_features_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
value=bytes([server_supported_features]),
|
||||
)
|
||||
|
||||
if database_hash_enabled:
|
||||
self.database_hash_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_DATABASE_HASH_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
value=gatt.CharacteristicValue(read=self.get_database_hash),
|
||||
)
|
||||
|
||||
if service_change_enabled:
|
||||
self.service_changed_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SERVICE_CHANGED_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.INDICATE,
|
||||
permissions=gatt.Characteristic.Permissions(0),
|
||||
value=b'',
|
||||
)
|
||||
|
||||
if (database_hash_enabled and service_change_enabled) or (
|
||||
server_supported_features
|
||||
and (
|
||||
server_supported_features & gatt.ServerSupportedFeatures.EATT_SUPPORTED
|
||||
)
|
||||
): # TODO: Support Multiple Handle Value Notifications
|
||||
self.client_supported_features_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC,
|
||||
properties=(
|
||||
gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.WRITE
|
||||
),
|
||||
permissions=(
|
||||
gatt.Characteristic.Permissions.READABLE
|
||||
| gatt.Characteristic.Permissions.WRITEABLE
|
||||
),
|
||||
value=bytes(1),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
characteristics=[
|
||||
c
|
||||
for c in (
|
||||
self.service_changed_characteristic,
|
||||
self.client_supported_features_characteristic,
|
||||
self.database_hash_characteristic,
|
||||
self.server_supported_features_characteristic,
|
||||
)
|
||||
if c is not None
|
||||
],
|
||||
primary=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_attribute_data(cls, attribute: att.Attribute) -> bytes:
|
||||
if attribute.type in (
|
||||
gatt.GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
gatt.GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
|
||||
gatt.GATT_INCLUDE_ATTRIBUTE_TYPE,
|
||||
gatt.GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
gatt.GATT_CHARACTERISTIC_EXTENDED_PROPERTIES_DESCRIPTOR,
|
||||
):
|
||||
return (
|
||||
struct.pack("<H", attribute.handle)
|
||||
+ attribute.type.to_bytes()
|
||||
+ attribute.value
|
||||
)
|
||||
elif attribute.type in (
|
||||
gatt.GATT_CHARACTERISTIC_USER_DESCRIPTION_DESCRIPTOR,
|
||||
gatt.GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
gatt.GATT_SERVER_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
gatt.GATT_CHARACTERISTIC_PRESENTATION_FORMAT_DESCRIPTOR,
|
||||
gatt.GATT_CHARACTERISTIC_AGGREGATE_FORMAT_DESCRIPTOR,
|
||||
):
|
||||
return struct.pack("<H", attribute.handle) + attribute.type.to_bytes()
|
||||
|
||||
return b''
|
||||
|
||||
def get_database_hash(self, connection: device.Connection | None) -> bytes:
|
||||
assert connection
|
||||
|
||||
m = b''.join(
|
||||
[
|
||||
self.get_attribute_data(attribute)
|
||||
for attribute in connection.device.gatt_server.attributes
|
||||
]
|
||||
)
|
||||
|
||||
return crypto.aes_cmac(m=m, k=bytes(16))
|
||||
|
||||
|
||||
class GenericAttributeProfileServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = GenericAttributeProfileService
|
||||
|
||||
client_supported_features_characteristic: gatt_client.CharacteristicProxy | None = (
|
||||
None
|
||||
)
|
||||
server_supported_features_characteristic: gatt_client.CharacteristicProxy | None = (
|
||||
None
|
||||
)
|
||||
database_hash_characteristic: gatt_client.CharacteristicProxy | None = None
|
||||
service_changed_characteristic: gatt_client.CharacteristicProxy | None = None
|
||||
|
||||
_CHARACTERISTICS = {
|
||||
gatt.GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC: 'client_supported_features_characteristic',
|
||||
gatt.GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC: 'server_supported_features_characteristic',
|
||||
gatt.GATT_DATABASE_HASH_CHARACTERISTIC: 'database_hash_characteristic',
|
||||
gatt.GATT_SERVICE_CHANGED_CHARACTERISTIC: 'service_changed_characteristic',
|
||||
}
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
for uuid, attribute_name in self._CHARACTERISTICS.items():
|
||||
if characteristics := self.service_proxy.get_characteristics_by_uuid(uuid):
|
||||
setattr(self, attribute_name, characteristics[0])
|
||||
193
bumble/profiles/gmap.py
Normal file
193
bumble/profiles/gmap.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""LE Audio - Gaming Audio Profile"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
from typing import Optional
|
||||
|
||||
from bumble.gatt import (
|
||||
TemplateService,
|
||||
DelegatedCharacteristicAdapter,
|
||||
Characteristic,
|
||||
GATT_GAMING_AUDIO_SERVICE,
|
||||
GATT_GMAP_ROLE_CHARACTERISTIC,
|
||||
GATT_UGG_FEATURES_CHARACTERISTIC,
|
||||
GATT_UGT_FEATURES_CHARACTERISTIC,
|
||||
GATT_BGS_FEATURES_CHARACTERISTIC,
|
||||
GATT_BGR_FEATURES_CHARACTERISTIC,
|
||||
)
|
||||
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
|
||||
from enum import IntFlag
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class GmapRole(IntFlag):
|
||||
UNICAST_GAME_GATEWAY = 1 << 0
|
||||
UNICAST_GAME_TERMINAL = 1 << 1
|
||||
BROADCAST_GAME_SENDER = 1 << 2
|
||||
BROADCAST_GAME_RECEIVER = 1 << 3
|
||||
|
||||
|
||||
class UggFeatures(IntFlag):
|
||||
UGG_MULTIPLEX = 1 << 0
|
||||
UGG_96_KBPS_SOURCE = 1 << 1
|
||||
UGG_MULTISINK = 1 << 2
|
||||
|
||||
|
||||
class UgtFeatures(IntFlag):
|
||||
UGT_SOURCE = 1 << 0
|
||||
UGT_80_KBPS_SOURCE = 1 << 1
|
||||
UGT_SINK = 1 << 2
|
||||
UGT_64_KBPS_SINK = 1 << 3
|
||||
UGT_MULTIPLEX = 1 << 4
|
||||
UGT_MULTISINK = 1 << 5
|
||||
UGT_MULTISOURCE = 1 << 6
|
||||
|
||||
|
||||
class BgsFeatures(IntFlag):
|
||||
BGS_96_KBPS = 1 << 0
|
||||
|
||||
|
||||
class BgrFeatures(IntFlag):
|
||||
BGR_MULTISINK = 1 << 0
|
||||
BGR_MULTIPLEX = 1 << 1
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class GamingAudioService(TemplateService):
|
||||
UUID = GATT_GAMING_AUDIO_SERVICE
|
||||
|
||||
gmap_role: Characteristic
|
||||
ugg_features: Optional[Characteristic] = None
|
||||
ugt_features: Optional[Characteristic] = None
|
||||
bgs_features: Optional[Characteristic] = None
|
||||
bgr_features: Optional[Characteristic] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gmap_role: GmapRole,
|
||||
ugg_features: Optional[UggFeatures] = None,
|
||||
ugt_features: Optional[UgtFeatures] = None,
|
||||
bgs_features: Optional[BgsFeatures] = None,
|
||||
bgr_features: Optional[BgrFeatures] = None,
|
||||
) -> None:
|
||||
characteristics = []
|
||||
|
||||
ugg_features = UggFeatures(0) if ugg_features is None else ugg_features
|
||||
ugt_features = UgtFeatures(0) if ugt_features is None else ugt_features
|
||||
bgs_features = BgsFeatures(0) if bgs_features is None else bgs_features
|
||||
bgr_features = BgrFeatures(0) if bgr_features is None else bgr_features
|
||||
|
||||
self.gmap_role = Characteristic(
|
||||
uuid=GATT_GMAP_ROLE_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', gmap_role),
|
||||
)
|
||||
characteristics.append(self.gmap_role)
|
||||
|
||||
if gmap_role & GmapRole.UNICAST_GAME_GATEWAY:
|
||||
self.ugg_features = Characteristic(
|
||||
uuid=GATT_UGG_FEATURES_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', ugg_features),
|
||||
)
|
||||
characteristics.append(self.ugg_features)
|
||||
|
||||
if gmap_role & GmapRole.UNICAST_GAME_TERMINAL:
|
||||
self.ugt_features = Characteristic(
|
||||
uuid=GATT_UGT_FEATURES_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', ugt_features),
|
||||
)
|
||||
characteristics.append(self.ugt_features)
|
||||
|
||||
if gmap_role & GmapRole.BROADCAST_GAME_SENDER:
|
||||
self.bgs_features = Characteristic(
|
||||
uuid=GATT_BGS_FEATURES_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', bgs_features),
|
||||
)
|
||||
characteristics.append(self.bgs_features)
|
||||
|
||||
if gmap_role & GmapRole.BROADCAST_GAME_RECEIVER:
|
||||
self.bgr_features = Characteristic(
|
||||
uuid=GATT_BGR_FEATURES_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.READ,
|
||||
permissions=Characteristic.Permissions.READABLE,
|
||||
value=struct.pack('B', bgr_features),
|
||||
)
|
||||
characteristics.append(self.bgr_features)
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class GamingAudioServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = GamingAudioService
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.gmap_role = DelegatedCharacteristicAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_GMAP_ROLE_CHARACTERISTIC
|
||||
),
|
||||
decode=lambda value: GmapRole(value[0]),
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_UGG_FEATURES_CHARACTERISTIC
|
||||
):
|
||||
self.ugg_features = DelegatedCharacteristicAdapter(
|
||||
characteristic=characteristics[0],
|
||||
decode=lambda value: UggFeatures(value[0]),
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_UGT_FEATURES_CHARACTERISTIC
|
||||
):
|
||||
self.ugt_features = DelegatedCharacteristicAdapter(
|
||||
characteristic=characteristics[0],
|
||||
decode=lambda value: UgtFeatures(value[0]),
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_BGS_FEATURES_CHARACTERISTIC
|
||||
):
|
||||
self.bgs_features = DelegatedCharacteristicAdapter(
|
||||
characteristic=characteristics[0],
|
||||
decode=lambda value: BgsFeatures(value[0]),
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_BGR_FEATURES_CHARACTERISTIC
|
||||
):
|
||||
self.bgr_features = DelegatedCharacteristicAdapter(
|
||||
characteristic=characteristics[0],
|
||||
decode=lambda value: BgrFeatures(value[0]),
|
||||
)
|
||||
@@ -30,6 +30,7 @@ from ..gatt import (
|
||||
TemplateService,
|
||||
Characteristic,
|
||||
CharacteristicValue,
|
||||
SerializableCharacteristicAdapter,
|
||||
DelegatedCharacteristicAdapter,
|
||||
PackedCharacteristicAdapter,
|
||||
)
|
||||
@@ -150,15 +151,14 @@ class HeartRateService(TemplateService):
|
||||
body_sensor_location=None,
|
||||
reset_energy_expended=None,
|
||||
):
|
||||
self.heart_rate_measurement_characteristic = DelegatedCharacteristicAdapter(
|
||||
self.heart_rate_measurement_characteristic = SerializableCharacteristicAdapter(
|
||||
Characteristic(
|
||||
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC,
|
||||
Characteristic.Properties.NOTIFY,
|
||||
0,
|
||||
CharacteristicValue(read=read_heart_rate_measurement),
|
||||
),
|
||||
# pylint: disable=unnecessary-lambda
|
||||
encode=lambda value: bytes(value),
|
||||
HeartRateService.HeartRateMeasurement,
|
||||
)
|
||||
characteristics = [self.heart_rate_measurement_characteristic]
|
||||
|
||||
@@ -204,9 +204,8 @@ class HeartRateServiceProxy(ProfileServiceProxy):
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_HEART_RATE_MEASUREMENT_CHARACTERISTIC
|
||||
):
|
||||
self.heart_rate_measurement = DelegatedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
decode=HeartRateService.HeartRateMeasurement.from_bytes,
|
||||
self.heart_rate_measurement = SerializableCharacteristicAdapter(
|
||||
characteristics[0], HeartRateService.HeartRateMeasurement
|
||||
)
|
||||
else:
|
||||
self.heart_rate_measurement = None
|
||||
|
||||
@@ -17,23 +17,35 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import enum
|
||||
import struct
|
||||
from typing import List, Type
|
||||
from typing import Any, List, Type
|
||||
from typing_extensions import Self
|
||||
|
||||
from bumble.profiles import bap
|
||||
from bumble import utils
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class AudioActiveState(utils.OpenIntEnum):
|
||||
NO_AUDIO_DATA_TRANSMITTED = 0x00
|
||||
AUDIO_DATA_TRANSMITTED = 0x01
|
||||
|
||||
|
||||
class AssistedListeningStream(utils.OpenIntEnum):
|
||||
UNSPECIFIED_AUDIO_ENHANCEMENT = 0x00
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Metadata:
|
||||
'''Bluetooth Assigned Numbers, Section 6.12.6 - Metadata LTV structures.
|
||||
|
||||
As Metadata fields may extend, and Spec doesn't forbid duplication, we don't parse
|
||||
Metadata into a key-value style dataclass here. Rather, we encourage users to parse
|
||||
again outside the lib.
|
||||
As Metadata fields may extend, and the spec may not guarantee the uniqueness of
|
||||
tags, we don't automatically parse the Metadata data into specific classes.
|
||||
Users of this class may decode the data by themselves, or use the Entry.decode
|
||||
method.
|
||||
'''
|
||||
|
||||
class Tag(utils.OpenIntEnum):
|
||||
@@ -57,6 +69,44 @@ class Metadata:
|
||||
tag: Metadata.Tag
|
||||
data: bytes
|
||||
|
||||
def decode(self) -> Any:
|
||||
"""
|
||||
Decode the data into an object, if possible.
|
||||
|
||||
If no specific object class exists to represent the data, the raw data
|
||||
bytes are returned.
|
||||
"""
|
||||
|
||||
if self.tag in (
|
||||
Metadata.Tag.PREFERRED_AUDIO_CONTEXTS,
|
||||
Metadata.Tag.STREAMING_AUDIO_CONTEXTS,
|
||||
):
|
||||
return bap.ContextType(struct.unpack("<H", self.data)[0])
|
||||
|
||||
if self.tag in (
|
||||
Metadata.Tag.PROGRAM_INFO,
|
||||
Metadata.Tag.PROGRAM_INFO_URI,
|
||||
Metadata.Tag.BROADCAST_NAME,
|
||||
):
|
||||
return self.data.decode("utf-8")
|
||||
|
||||
if self.tag == Metadata.Tag.LANGUAGE:
|
||||
return self.data.decode("ascii")
|
||||
|
||||
if self.tag == Metadata.Tag.CCID_LIST:
|
||||
return list(self.data)
|
||||
|
||||
if self.tag == Metadata.Tag.PARENTAL_RATING:
|
||||
return self.data[0]
|
||||
|
||||
if self.tag == Metadata.Tag.AUDIO_ACTIVE_STATE:
|
||||
return AudioActiveState(self.data[0])
|
||||
|
||||
if self.tag == Metadata.Tag.ASSISTED_LISTENING_STREAM:
|
||||
return AssistedListeningStream(self.data[0])
|
||||
|
||||
return self.data
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||
return cls(tag=Metadata.Tag(data[0]), data=data[1:])
|
||||
@@ -66,6 +116,29 @@ class Metadata:
|
||||
|
||||
entries: List[Entry] = dataclasses.field(default_factory=list)
|
||||
|
||||
def pretty_print(self, indent: str) -> str:
|
||||
"""Convenience method to generate a string with one key-value pair per line."""
|
||||
|
||||
max_key_length = 0
|
||||
keys = []
|
||||
values = []
|
||||
for entry in self.entries:
|
||||
key = entry.tag.name
|
||||
max_key_length = max(max_key_length, len(key))
|
||||
keys.append(key)
|
||||
decoded = entry.decode()
|
||||
if isinstance(decoded, enum.Enum):
|
||||
values.append(decoded.name)
|
||||
elif isinstance(decoded, bytes):
|
||||
values.append(decoded.hex())
|
||||
else:
|
||||
values.append(str(decoded))
|
||||
|
||||
return '\n'.join(
|
||||
f'{indent}{key}: {" " * (max_key_length-len(key))}{value}'
|
||||
for key, value in zip(keys, values)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls: Type[Self], data: bytes) -> Self:
|
||||
entries = []
|
||||
@@ -81,3 +154,13 @@ class Metadata:
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return b''.join([bytes(entry) for entry in self.entries])
|
||||
|
||||
def __str__(self) -> str:
|
||||
entries_str = []
|
||||
for entry in self.entries:
|
||||
decoded = entry.decode()
|
||||
entries_str.append(
|
||||
f'{entry.tag.name}: '
|
||||
f'{decoded.hex() if isinstance(decoded, bytes) else decoded!r}'
|
||||
)
|
||||
return f'Metadata(entries={", ".join(entry_str for entry_str in entries_str)})'
|
||||
|
||||
@@ -72,6 +72,19 @@ class PacRecord:
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def list_from_bytes(cls, data: bytes) -> list[PacRecord]:
|
||||
"""Parse a serialized list of records preceded by a one byte list length."""
|
||||
record_count = data[0]
|
||||
records = []
|
||||
offset = 1
|
||||
for _ in range(record_count):
|
||||
record = PacRecord.from_bytes(data[offset:])
|
||||
offset += len(bytes(record))
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
capabilities_bytes = bytes(self.codec_specific_capabilities)
|
||||
metadata_bytes = bytes(self.metadata)
|
||||
@@ -172,39 +185,58 @@ class PublishedAudioCapabilitiesService(gatt.TemplateService):
|
||||
class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = PublishedAudioCapabilitiesService
|
||||
|
||||
sink_pac: Optional[gatt_client.CharacteristicProxy] = None
|
||||
sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
|
||||
source_pac: Optional[gatt_client.CharacteristicProxy] = None
|
||||
source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None
|
||||
available_audio_contexts: gatt_client.CharacteristicProxy
|
||||
supported_audio_contexts: gatt_client.CharacteristicProxy
|
||||
sink_pac: Optional[gatt.DelegatedCharacteristicAdapter] = None
|
||||
sink_audio_locations: Optional[gatt.DelegatedCharacteristicAdapter] = None
|
||||
source_pac: Optional[gatt.DelegatedCharacteristicAdapter] = None
|
||||
source_audio_locations: Optional[gatt.DelegatedCharacteristicAdapter] = None
|
||||
available_audio_contexts: gatt.DelegatedCharacteristicAdapter
|
||||
supported_audio_contexts: gatt.DelegatedCharacteristicAdapter
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.available_audio_contexts = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
|
||||
)[0]
|
||||
self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
|
||||
)[0]
|
||||
self.available_audio_contexts = gatt.DelegatedCharacteristicAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC
|
||||
),
|
||||
decode=lambda x: tuple(map(ContextType, struct.unpack('<HH', x))),
|
||||
)
|
||||
|
||||
self.supported_audio_contexts = gatt.DelegatedCharacteristicAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC
|
||||
),
|
||||
decode=lambda x: tuple(map(ContextType, struct.unpack('<HH', x))),
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SINK_PAC_CHARACTERISTIC
|
||||
):
|
||||
self.sink_pac = characteristics[0]
|
||||
self.sink_pac = gatt.DelegatedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
decode=PacRecord.list_from_bytes,
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SOURCE_PAC_CHARACTERISTIC
|
||||
):
|
||||
self.source_pac = characteristics[0]
|
||||
self.source_pac = gatt.DelegatedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
decode=PacRecord.list_from_bytes,
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC
|
||||
):
|
||||
self.sink_audio_locations = characteristics[0]
|
||||
self.sink_audio_locations = gatt.DelegatedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
decode=lambda x: AudioLocation(struct.unpack('<I', x)[0]),
|
||||
)
|
||||
|
||||
if characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
|
||||
):
|
||||
self.source_audio_locations = characteristics[0]
|
||||
self.source_audio_locations = gatt.DelegatedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
decode=lambda x: AudioLocation(struct.unpack('<I', x)[0]),
|
||||
)
|
||||
|
||||
@@ -25,7 +25,6 @@ from bumble.gatt import (
|
||||
TemplateService,
|
||||
Characteristic,
|
||||
DelegatedCharacteristicAdapter,
|
||||
InvalidServiceError,
|
||||
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE,
|
||||
GATT_TMAP_ROLE_CHARACTERISTIC,
|
||||
)
|
||||
@@ -74,15 +73,10 @@ class TelephonyAndMediaAudioServiceProxy(ProfileServiceProxy):
|
||||
def __init__(self, service_proxy: ServiceProxy):
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
if not (
|
||||
characteristics := service_proxy.get_characteristics_by_uuid(
|
||||
GATT_TMAP_ROLE_CHARACTERISTIC
|
||||
)
|
||||
):
|
||||
raise InvalidServiceError('TMAP Role characteristic not found')
|
||||
|
||||
self.role = DelegatedCharacteristicAdapter(
|
||||
characteristics[0],
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_TMAP_ROLE_CHARACTERISTIC
|
||||
),
|
||||
decode=lambda value: Role(
|
||||
struct.unpack_from('<H', value, 0)[0],
|
||||
),
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2021-2024 Google LLC
|
||||
# Copyright 2021-2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -17,14 +17,16 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import dataclasses
|
||||
import enum
|
||||
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from bumble import att
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
|
||||
from typing import Optional, Sequence
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
@@ -67,6 +69,20 @@ class VolumeControlPointOpcode(enum.IntEnum):
|
||||
MUTE = 0x06
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class VolumeState:
|
||||
volume_setting: int
|
||||
mute: int
|
||||
change_counter: int
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> VolumeState:
|
||||
return cls(data[0], data[1], data[2])
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return bytes([self.volume_setting, self.mute, self.change_counter])
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -126,16 +142,8 @@ class VolumeControlService(gatt.TemplateService):
|
||||
included_services=list(included_services),
|
||||
)
|
||||
|
||||
@property
|
||||
def volume_state_bytes(self) -> bytes:
|
||||
return bytes([self.volume_setting, self.muted, self.change_counter])
|
||||
|
||||
@volume_state_bytes.setter
|
||||
def volume_state_bytes(self, new_value: bytes) -> None:
|
||||
self.volume_setting, self.muted, self.change_counter = new_value
|
||||
|
||||
def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes:
|
||||
return self.volume_state_bytes
|
||||
return bytes(VolumeState(self.volume_setting, self.muted, self.change_counter))
|
||||
|
||||
def _on_write_volume_control_point(
|
||||
self, connection: Optional[device.Connection], value: bytes
|
||||
@@ -153,14 +161,9 @@ class VolumeControlService(gatt.TemplateService):
|
||||
self.change_counter = (self.change_counter + 1) % 256
|
||||
connection.abort_on(
|
||||
'disconnection',
|
||||
connection.device.notify_subscribers(
|
||||
attribute=self.volume_state,
|
||||
value=self.volume_state_bytes,
|
||||
),
|
||||
)
|
||||
self.emit(
|
||||
'volume_state', self.volume_setting, self.muted, self.change_counter
|
||||
connection.device.notify_subscribers(attribute=self.volume_state),
|
||||
)
|
||||
self.emit('volume_state_change')
|
||||
|
||||
def _on_relative_volume_down(self) -> bool:
|
||||
old_volume = self.volume_setting
|
||||
@@ -207,24 +210,26 @@ class VolumeControlServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = VolumeControlService
|
||||
|
||||
volume_control_point: gatt_client.CharacteristicProxy
|
||||
volume_state: gatt.SerializableCharacteristicAdapter
|
||||
volume_flags: gatt.DelegatedCharacteristicAdapter
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.volume_state = gatt.PackedCharacteristicAdapter(
|
||||
service_proxy.get_characteristics_by_uuid(
|
||||
self.volume_state = gatt.SerializableCharacteristicAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
gatt.GATT_VOLUME_STATE_CHARACTERISTIC
|
||||
)[0],
|
||||
'BBB',
|
||||
),
|
||||
VolumeState,
|
||||
)
|
||||
|
||||
self.volume_control_point = service_proxy.get_characteristics_by_uuid(
|
||||
self.volume_control_point = service_proxy.get_required_characteristic_by_uuid(
|
||||
gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC
|
||||
)[0]
|
||||
|
||||
self.volume_flags = gatt.PackedCharacteristicAdapter(
|
||||
service_proxy.get_characteristics_by_uuid(
|
||||
gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC
|
||||
)[0],
|
||||
'B',
|
||||
)
|
||||
|
||||
self.volume_flags = gatt.DelegatedCharacteristicAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC
|
||||
),
|
||||
decode=lambda data: VolumeFlags(data[0]),
|
||||
)
|
||||
299
bumble/profiles/vocs.py
Normal file
299
bumble/profiles/vocs.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from bumble.device import Connection
|
||||
from bumble.att import ATT_Error
|
||||
from bumble.gatt import (
|
||||
Characteristic,
|
||||
DelegatedCharacteristicAdapter,
|
||||
TemplateService,
|
||||
CharacteristicValue,
|
||||
SerializableCharacteristicAdapter,
|
||||
UTF8CharacteristicAdapter,
|
||||
GATT_VOLUME_OFFSET_CONTROL_SERVICE,
|
||||
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC,
|
||||
GATT_AUDIO_LOCATION_CHARACTERISTIC,
|
||||
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC,
|
||||
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC,
|
||||
)
|
||||
from bumble.gatt_client import ProfileServiceProxy, ServiceProxy
|
||||
from bumble.utils import OpenIntEnum
|
||||
from bumble.profiles.bap import AudioLocation
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
MIN_VOLUME_OFFSET = -255
|
||||
MAX_VOLUME_OFFSET = 255
|
||||
CHANGE_COUNTER_MAX_VALUE = 0xFF
|
||||
|
||||
|
||||
class SetVolumeOffsetOpCode(OpenIntEnum):
|
||||
SET_VOLUME_OFFSET = 0x01
|
||||
|
||||
|
||||
class ErrorCode(OpenIntEnum):
|
||||
"""
|
||||
See Volume Offset Control Service 1.6. Application error codes.
|
||||
"""
|
||||
|
||||
INVALID_CHANGE_COUNTER = 0x80
|
||||
OPCODE_NOT_SUPPORTED = 0x81
|
||||
VALUE_OUT_OF_RANGE = 0x82
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class VolumeOffsetState:
|
||||
volume_offset: int = 0
|
||||
change_counter: int = 0
|
||||
attribute_value: Optional[CharacteristicValue] = None
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack('<hB', self.volume_offset, self.change_counter)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
volume_offset, change_counter = struct.unpack('<hB', data)
|
||||
return cls(volume_offset, change_counter)
|
||||
|
||||
def increment_change_counter(self) -> None:
|
||||
self.change_counter = (self.change_counter + 1) % (CHANGE_COUNTER_MAX_VALUE + 1)
|
||||
|
||||
async def notify_subscribers_via_connection(self, connection: Connection) -> None:
|
||||
assert self.attribute_value is not None
|
||||
await connection.device.notify_subscribers(attribute=self.attribute_value)
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
return bytes(self)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VocsAudioLocation:
|
||||
audio_location: AudioLocation = AudioLocation.NOT_ALLOWED
|
||||
attribute_value: Optional[CharacteristicValue] = None
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack('<I', self.audio_location)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
audio_location = AudioLocation(struct.unpack('<I', data)[0])
|
||||
return cls(audio_location)
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
return bytes(self)
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
assert connection
|
||||
assert self.attribute_value
|
||||
|
||||
self.audio_location = AudioLocation(int.from_bytes(value, 'little'))
|
||||
await connection.device.notify_subscribers(attribute=self.attribute_value)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VolumeOffsetControlPoint:
|
||||
volume_offset_state: VolumeOffsetState
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
assert connection
|
||||
|
||||
opcode = value[0]
|
||||
if opcode != SetVolumeOffsetOpCode.SET_VOLUME_OFFSET:
|
||||
raise ATT_Error(ErrorCode.OPCODE_NOT_SUPPORTED)
|
||||
|
||||
change_counter, volume_offset = struct.unpack('<Bh', value[1:])
|
||||
await self._set_volume_offset(connection, change_counter, volume_offset)
|
||||
|
||||
async def _set_volume_offset(
|
||||
self,
|
||||
connection: Connection,
|
||||
change_counter_operand: int,
|
||||
volume_offset_operand: int,
|
||||
) -> None:
|
||||
change_counter = self.volume_offset_state.change_counter
|
||||
|
||||
if change_counter != change_counter_operand:
|
||||
raise ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
|
||||
|
||||
if not MIN_VOLUME_OFFSET <= volume_offset_operand <= MAX_VOLUME_OFFSET:
|
||||
raise ATT_Error(ErrorCode.VALUE_OUT_OF_RANGE)
|
||||
|
||||
self.volume_offset_state.volume_offset = volume_offset_operand
|
||||
self.volume_offset_state.increment_change_counter()
|
||||
await self.volume_offset_state.notify_subscribers_via_connection(connection)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioOutputDescription:
|
||||
audio_output_description: str = ''
|
||||
attribute_value: Optional[CharacteristicValue] = None
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
return cls(audio_output_description=data.decode('utf-8'))
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.audio_output_description.encode('utf-8')
|
||||
|
||||
def on_read(self, _connection: Optional[Connection]) -> bytes:
|
||||
return bytes(self)
|
||||
|
||||
async def on_write(self, connection: Optional[Connection], value: bytes) -> None:
|
||||
assert connection
|
||||
assert self.attribute_value
|
||||
|
||||
self.audio_output_description = value.decode('utf-8')
|
||||
await connection.device.notify_subscribers(attribute=self.attribute_value)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class VolumeOffsetControlService(TemplateService):
|
||||
UUID = GATT_VOLUME_OFFSET_CONTROL_SERVICE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
volume_offset_state: Optional[VolumeOffsetState] = None,
|
||||
audio_location: Optional[VocsAudioLocation] = None,
|
||||
audio_output_description: Optional[AudioOutputDescription] = None,
|
||||
) -> None:
|
||||
|
||||
self.volume_offset_state = (
|
||||
VolumeOffsetState() if volume_offset_state is None else volume_offset_state
|
||||
)
|
||||
|
||||
self.audio_location = (
|
||||
VocsAudioLocation() if audio_location is None else audio_location
|
||||
)
|
||||
|
||||
self.audio_output_description = (
|
||||
AudioOutputDescription()
|
||||
if audio_output_description is None
|
||||
else audio_output_description
|
||||
)
|
||||
|
||||
self.volume_offset_control_point: VolumeOffsetControlPoint = (
|
||||
VolumeOffsetControlPoint(self.volume_offset_state)
|
||||
)
|
||||
|
||||
self.volume_offset_state_characteristic = Characteristic(
|
||||
uuid=GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC,
|
||||
properties=(
|
||||
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY
|
||||
),
|
||||
permissions=Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=CharacteristicValue(read=self.volume_offset_state.on_read),
|
||||
)
|
||||
|
||||
self.audio_location_characteristic = Characteristic(
|
||||
uuid=GATT_AUDIO_LOCATION_CHARACTERISTIC,
|
||||
properties=(
|
||||
Characteristic.Properties.READ
|
||||
| Characteristic.Properties.NOTIFY
|
||||
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||
),
|
||||
permissions=(
|
||||
Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION
|
||||
),
|
||||
value=CharacteristicValue(
|
||||
read=self.audio_location.on_read,
|
||||
write=self.audio_location.on_write,
|
||||
),
|
||||
)
|
||||
self.audio_location.attribute_value = self.audio_location_characteristic.value
|
||||
|
||||
self.volume_offset_control_point_characteristic = Characteristic(
|
||||
uuid=GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC,
|
||||
properties=Characteristic.Properties.WRITE,
|
||||
permissions=Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
|
||||
value=CharacteristicValue(write=self.volume_offset_control_point.on_write),
|
||||
)
|
||||
|
||||
self.audio_output_description_characteristic = Characteristic(
|
||||
uuid=GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC,
|
||||
properties=(
|
||||
Characteristic.Properties.READ
|
||||
| Characteristic.Properties.NOTIFY
|
||||
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE
|
||||
),
|
||||
permissions=(
|
||||
Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION
|
||||
),
|
||||
value=CharacteristicValue(
|
||||
read=self.audio_output_description.on_read,
|
||||
write=self.audio_output_description.on_write,
|
||||
),
|
||||
)
|
||||
self.audio_output_description.attribute_value = (
|
||||
self.audio_output_description_characteristic.value
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
characteristics=[
|
||||
self.volume_offset_state_characteristic, # type: ignore
|
||||
self.audio_location_characteristic, # type: ignore
|
||||
self.volume_offset_control_point_characteristic, # type: ignore
|
||||
self.audio_output_description_characteristic, # type: ignore
|
||||
],
|
||||
primary=False,
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class VolumeOffsetControlServiceProxy(ProfileServiceProxy):
|
||||
SERVICE_CLASS = VolumeOffsetControlService
|
||||
|
||||
def __init__(self, service_proxy: ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
|
||||
self.volume_offset_state = SerializableCharacteristicAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC
|
||||
),
|
||||
VolumeOffsetState,
|
||||
)
|
||||
|
||||
self.audio_location = DelegatedCharacteristicAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_AUDIO_LOCATION_CHARACTERISTIC
|
||||
),
|
||||
encode=lambda value: bytes([int(value)]),
|
||||
decode=lambda data: AudioLocation(data[0]),
|
||||
)
|
||||
|
||||
self.volume_offset_control_point = (
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC
|
||||
)
|
||||
)
|
||||
|
||||
self.audio_output_description = UTF8CharacteristicAdapter(
|
||||
service_proxy.get_required_characteristic_by_uuid(
|
||||
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC
|
||||
)
|
||||
)
|
||||
326
bumble/sdp.py
326
bumble/sdp.py
@@ -16,15 +16,21 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import struct
|
||||
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
|
||||
from typing import Iterable, NewType, Optional, Union, Sequence, Type, TYPE_CHECKING
|
||||
from typing_extensions import Self
|
||||
|
||||
from . import core, l2cap
|
||||
from .colors import color
|
||||
from .core import InvalidStateError, InvalidArgumentError, InvalidPacketError
|
||||
from .hci import HCI_Object, name_or_number, key_with_value
|
||||
from bumble import core, l2cap
|
||||
from bumble.colors import color
|
||||
from bumble.core import (
|
||||
InvalidStateError,
|
||||
InvalidArgumentError,
|
||||
InvalidPacketError,
|
||||
ProtocolError,
|
||||
)
|
||||
from bumble.hci import HCI_Object, name_or_number, key_with_value
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .device import Device, Connection
|
||||
@@ -124,7 +130,7 @@ SDP_ATTRIBUTE_ID_NAMES = {
|
||||
SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
|
||||
|
||||
# To be used in searches where an attribute ID list allows a range to be specified
|
||||
SDP_ALL_ATTRIBUTES_RANGE = (0x0000FFFF, 4) # Express this as tuple so we can convey the desired encoding size
|
||||
SDP_ALL_ATTRIBUTES_RANGE = (0x0000, 0xFFFF)
|
||||
|
||||
# fmt: on
|
||||
# pylint: enable=line-too-long
|
||||
@@ -242,11 +248,11 @@ class DataElement:
|
||||
return DataElement(DataElement.BOOLEAN, value)
|
||||
|
||||
@staticmethod
|
||||
def sequence(value: List[DataElement]) -> DataElement:
|
||||
def sequence(value: Iterable[DataElement]) -> DataElement:
|
||||
return DataElement(DataElement.SEQUENCE, value)
|
||||
|
||||
@staticmethod
|
||||
def alternative(value: List[DataElement]) -> DataElement:
|
||||
def alternative(value: Iterable[DataElement]) -> DataElement:
|
||||
return DataElement(DataElement.ALTERNATIVE, value)
|
||||
|
||||
@staticmethod
|
||||
@@ -344,9 +350,6 @@ class DataElement:
|
||||
] # Keep a copy so we can re-serialize to an exact replica
|
||||
return result
|
||||
|
||||
def to_bytes(self):
|
||||
return bytes(self)
|
||||
|
||||
def __bytes__(self):
|
||||
# Return early if we have a cache
|
||||
if self.bytes:
|
||||
@@ -434,6 +437,8 @@ class DataElement:
|
||||
if size != 1:
|
||||
raise InvalidArgumentError('boolean must be 1 byte')
|
||||
size_index = 0
|
||||
else:
|
||||
raise RuntimeError("internal error - self.type not supported")
|
||||
|
||||
self.bytes = bytes([self.type << 3 | size_index]) + size_bytes + data
|
||||
return self.bytes
|
||||
@@ -474,7 +479,9 @@ class ServiceAttribute:
|
||||
self.value = value
|
||||
|
||||
@staticmethod
|
||||
def list_from_data_elements(elements: List[DataElement]) -> List[ServiceAttribute]:
|
||||
def list_from_data_elements(
|
||||
elements: Sequence[DataElement],
|
||||
) -> list[ServiceAttribute]:
|
||||
attribute_list = []
|
||||
for i in range(0, len(elements) // 2):
|
||||
attribute_id, attribute_value = elements[2 * i : 2 * (i + 1)]
|
||||
@@ -487,7 +494,7 @@ class ServiceAttribute:
|
||||
|
||||
@staticmethod
|
||||
def find_attribute_in_list(
|
||||
attribute_list: List[ServiceAttribute], attribute_id: int
|
||||
attribute_list: Iterable[ServiceAttribute], attribute_id: int
|
||||
) -> Optional[DataElement]:
|
||||
return next(
|
||||
(
|
||||
@@ -535,7 +542,12 @@ class SDP_PDU:
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.2 PROTOCOL DATA UNIT FORMAT
|
||||
'''
|
||||
|
||||
sdp_pdu_classes: Dict[int, Type[SDP_PDU]] = {}
|
||||
RESPONSE_PDU_IDS = {
|
||||
SDP_SERVICE_SEARCH_REQUEST: SDP_SERVICE_SEARCH_RESPONSE,
|
||||
SDP_SERVICE_ATTRIBUTE_REQUEST: SDP_SERVICE_ATTRIBUTE_RESPONSE,
|
||||
SDP_SERVICE_SEARCH_ATTRIBUTE_REQUEST: SDP_SERVICE_SEARCH_ATTRIBUTE_RESPONSE,
|
||||
}
|
||||
sdp_pdu_classes: dict[int, Type[SDP_PDU]] = {}
|
||||
name = None
|
||||
pdu_id = 0
|
||||
|
||||
@@ -559,7 +571,7 @@ class SDP_PDU:
|
||||
@staticmethod
|
||||
def parse_service_record_handle_list_preceded_by_count(
|
||||
data: bytes, offset: int
|
||||
) -> Tuple[int, List[int]]:
|
||||
) -> tuple[int, list[int]]:
|
||||
count = struct.unpack_from('>H', data, offset - 2)[0]
|
||||
handle_list = [
|
||||
struct.unpack_from('>I', data, offset + x * 4)[0] for x in range(count)
|
||||
@@ -621,11 +633,8 @@ class SDP_PDU:
|
||||
def init_from_bytes(self, pdu, offset):
|
||||
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
||||
|
||||
def to_bytes(self):
|
||||
return self.pdu
|
||||
|
||||
def __bytes__(self):
|
||||
return self.to_bytes()
|
||||
return self.pdu
|
||||
|
||||
def __str__(self):
|
||||
result = f'{color(self.name, "blue")} [TID={self.transaction_id}]'
|
||||
@@ -643,6 +652,8 @@ class SDP_ErrorResponse(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.4.1 SDP_ErrorResponse PDU
|
||||
'''
|
||||
|
||||
error_code: int
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@SDP_PDU.subclass(
|
||||
@@ -679,7 +690,7 @@ class SDP_ServiceSearchResponse(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.5.2 SDP_ServiceSearchResponse PDU
|
||||
'''
|
||||
|
||||
service_record_handle_list: List[int]
|
||||
service_record_handle_list: list[int]
|
||||
total_service_record_count: int
|
||||
current_service_record_count: int
|
||||
continuation_state: bytes
|
||||
@@ -756,31 +767,99 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
|
||||
See Bluetooth spec @ Vol 3, Part B - 4.7.2 SDP_ServiceSearchAttributeResponse PDU
|
||||
'''
|
||||
|
||||
attribute_list_byte_count: int
|
||||
attribute_list: bytes
|
||||
attribute_lists_byte_count: int
|
||||
attribute_lists: bytes
|
||||
continuation_state: bytes
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Client:
|
||||
channel: Optional[l2cap.ClassicChannel]
|
||||
|
||||
def __init__(self, connection: Connection) -> None:
|
||||
def __init__(self, connection: Connection, mtu: int = 0) -> None:
|
||||
self.connection = connection
|
||||
self.pending_request = None
|
||||
self.channel = None
|
||||
self.channel: Optional[l2cap.ClassicChannel] = None
|
||||
self.mtu = mtu
|
||||
self.request_semaphore = asyncio.Semaphore(1)
|
||||
self.pending_request: Optional[SDP_PDU] = None
|
||||
self.pending_response: Optional[asyncio.futures.Future[SDP_PDU]] = None
|
||||
self.next_transaction_id = 0
|
||||
|
||||
async def connect(self) -> None:
|
||||
self.channel = await self.connection.create_l2cap_channel(
|
||||
spec=l2cap.ClassicChannelSpec(SDP_PSM)
|
||||
spec=(
|
||||
l2cap.ClassicChannelSpec(SDP_PSM, self.mtu)
|
||||
if self.mtu
|
||||
else l2cap.ClassicChannelSpec(SDP_PSM)
|
||||
)
|
||||
)
|
||||
self.channel.sink = self.on_pdu
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self.channel:
|
||||
await self.channel.disconnect()
|
||||
self.channel = None
|
||||
|
||||
async def search_services(self, uuids: List[core.UUID]) -> List[int]:
|
||||
def make_transaction_id(self) -> int:
|
||||
transaction_id = self.next_transaction_id
|
||||
self.next_transaction_id = (self.next_transaction_id + 1) & 0xFFFF
|
||||
return transaction_id
|
||||
|
||||
def on_pdu(self, pdu: bytes) -> None:
|
||||
if not self.pending_request:
|
||||
logger.warning('received response with no pending request')
|
||||
return
|
||||
assert self.pending_response is not None
|
||||
|
||||
response = SDP_PDU.from_bytes(pdu)
|
||||
|
||||
# Check that the transaction ID is what we expect
|
||||
if self.pending_request.transaction_id != response.transaction_id:
|
||||
logger.warning(
|
||||
f"received response with transaction ID {response.transaction_id} "
|
||||
f"but expected {self.pending_request.transaction_id}"
|
||||
)
|
||||
return
|
||||
|
||||
# Check if the response is an error
|
||||
if isinstance(response, SDP_ErrorResponse):
|
||||
self.pending_response.set_exception(
|
||||
ProtocolError(error_code=response.error_code)
|
||||
)
|
||||
return
|
||||
|
||||
# Check that the type of the response matches the request
|
||||
if response.pdu_id != SDP_PDU.RESPONSE_PDU_IDS.get(self.pending_request.pdu_id):
|
||||
logger.warning("response type mismatch")
|
||||
return
|
||||
|
||||
self.pending_response.set_result(response)
|
||||
|
||||
async def send_request(self, request: SDP_PDU) -> SDP_PDU:
|
||||
assert self.channel is not None
|
||||
async with self.request_semaphore:
|
||||
assert self.pending_request is None
|
||||
assert self.pending_response is None
|
||||
|
||||
# Create a future value to hold the eventual response
|
||||
self.pending_response = asyncio.get_running_loop().create_future()
|
||||
self.pending_request = request
|
||||
|
||||
try:
|
||||
self.channel.send_pdu(bytes(request))
|
||||
return await self.pending_response
|
||||
finally:
|
||||
self.pending_request = None
|
||||
self.pending_response = None
|
||||
|
||||
async def search_services(self, uuids: Iterable[core.UUID]) -> list[int]:
|
||||
"""
|
||||
Search for services by UUID.
|
||||
|
||||
Args:
|
||||
uuids: service the UUIDs to search for.
|
||||
|
||||
Returns:
|
||||
A list of matching service record handles.
|
||||
"""
|
||||
if self.pending_request is not None:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.channel is None:
|
||||
@@ -795,16 +874,16 @@ class Client:
|
||||
continuation_state = bytes([0])
|
||||
watchdog = SDP_CONTINUATION_WATCHDOG
|
||||
while watchdog > 0:
|
||||
response_pdu = await self.channel.send_request(
|
||||
response = await self.send_request(
|
||||
SDP_ServiceSearchRequest(
|
||||
transaction_id=0, # Transaction ID TODO: pick a real value
|
||||
transaction_id=self.make_transaction_id(),
|
||||
service_search_pattern=service_search_pattern,
|
||||
maximum_service_record_count=0xFFFF,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
response = SDP_PDU.from_bytes(response_pdu)
|
||||
logger.debug(f'<<< Response: {response}')
|
||||
assert isinstance(response, SDP_ServiceSearchResponse)
|
||||
service_record_handle_list += response.service_record_handle_list
|
||||
continuation_state = response.continuation_state
|
||||
if len(continuation_state) == 1 and continuation_state[0] == 0:
|
||||
@@ -815,8 +894,21 @@ class Client:
|
||||
return service_record_handle_list
|
||||
|
||||
async def search_attributes(
|
||||
self, uuids: List[core.UUID], attribute_ids: List[Union[int, Tuple[int, int]]]
|
||||
) -> List[List[ServiceAttribute]]:
|
||||
self,
|
||||
uuids: Iterable[core.UUID],
|
||||
attribute_ids: Iterable[Union[int, tuple[int, int]]],
|
||||
) -> list[list[ServiceAttribute]]:
|
||||
"""
|
||||
Search for attributes by UUID and attribute IDs.
|
||||
|
||||
Args:
|
||||
uuids: the service UUIDs to search for.
|
||||
attribute_ids: list of attribute IDs or (start, end) attribute ID ranges.
|
||||
(use (0, 0xFFFF) to include all attributes)
|
||||
|
||||
Returns:
|
||||
A list of list of attributes, one list per matching service.
|
||||
"""
|
||||
if self.pending_request is not None:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.channel is None:
|
||||
@@ -828,8 +920,8 @@ class Client:
|
||||
attribute_id_list = DataElement.sequence(
|
||||
[
|
||||
(
|
||||
DataElement.unsigned_integer(
|
||||
attribute_id[0], value_size=attribute_id[1]
|
||||
DataElement.unsigned_integer_32(
|
||||
attribute_id[0] << 16 | attribute_id[1]
|
||||
)
|
||||
if isinstance(attribute_id, tuple)
|
||||
else DataElement.unsigned_integer_16(attribute_id)
|
||||
@@ -843,17 +935,17 @@ class Client:
|
||||
continuation_state = bytes([0])
|
||||
watchdog = SDP_CONTINUATION_WATCHDOG
|
||||
while watchdog > 0:
|
||||
response_pdu = await self.channel.send_request(
|
||||
response = await self.send_request(
|
||||
SDP_ServiceSearchAttributeRequest(
|
||||
transaction_id=0, # Transaction ID TODO: pick a real value
|
||||
transaction_id=self.make_transaction_id(),
|
||||
service_search_pattern=service_search_pattern,
|
||||
maximum_attribute_byte_count=0xFFFF,
|
||||
attribute_id_list=attribute_id_list,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
response = SDP_PDU.from_bytes(response_pdu)
|
||||
logger.debug(f'<<< Response: {response}')
|
||||
assert isinstance(response, SDP_ServiceSearchAttributeResponse)
|
||||
accumulator += response.attribute_lists
|
||||
continuation_state = response.continuation_state
|
||||
if len(continuation_state) == 1 and continuation_state[0] == 0:
|
||||
@@ -876,8 +968,18 @@ class Client:
|
||||
async def get_attributes(
|
||||
self,
|
||||
service_record_handle: int,
|
||||
attribute_ids: List[Union[int, Tuple[int, int]]],
|
||||
) -> List[ServiceAttribute]:
|
||||
attribute_ids: Iterable[Union[int, tuple[int, int]]],
|
||||
) -> list[ServiceAttribute]:
|
||||
"""
|
||||
Get attributes for a service.
|
||||
|
||||
Args:
|
||||
service_record_handle: the handle for a service
|
||||
attribute_ids: list or attribute IDs or (start, end) attribute ID handles.
|
||||
|
||||
Returns:
|
||||
A list of attributes.
|
||||
"""
|
||||
if self.pending_request is not None:
|
||||
raise InvalidStateError('request already pending')
|
||||
if self.channel is None:
|
||||
@@ -886,8 +988,8 @@ class Client:
|
||||
attribute_id_list = DataElement.sequence(
|
||||
[
|
||||
(
|
||||
DataElement.unsigned_integer(
|
||||
attribute_id[0], value_size=attribute_id[1]
|
||||
DataElement.unsigned_integer_32(
|
||||
attribute_id[0] << 16 | attribute_id[1]
|
||||
)
|
||||
if isinstance(attribute_id, tuple)
|
||||
else DataElement.unsigned_integer_16(attribute_id)
|
||||
@@ -901,17 +1003,17 @@ class Client:
|
||||
continuation_state = bytes([0])
|
||||
watchdog = SDP_CONTINUATION_WATCHDOG
|
||||
while watchdog > 0:
|
||||
response_pdu = await self.channel.send_request(
|
||||
response = await self.send_request(
|
||||
SDP_ServiceAttributeRequest(
|
||||
transaction_id=0, # Transaction ID TODO: pick a real value
|
||||
transaction_id=self.make_transaction_id(),
|
||||
service_record_handle=service_record_handle,
|
||||
maximum_attribute_byte_count=0xFFFF,
|
||||
attribute_id_list=attribute_id_list,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
response = SDP_PDU.from_bytes(response_pdu)
|
||||
logger.debug(f'<<< Response: {response}')
|
||||
assert isinstance(response, SDP_ServiceAttributeResponse)
|
||||
accumulator += response.attribute_list
|
||||
continuation_state = response.continuation_state
|
||||
if len(continuation_state) == 1 and continuation_state[0] == 0:
|
||||
@@ -937,17 +1039,17 @@ class Client:
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Server:
|
||||
CONTINUATION_STATE = bytes([0x01, 0x43])
|
||||
CONTINUATION_STATE = bytes([0x01, 0x00])
|
||||
channel: Optional[l2cap.ClassicChannel]
|
||||
Service = NewType('Service', List[ServiceAttribute])
|
||||
service_records: Dict[int, Service]
|
||||
current_response: Union[None, bytes, Tuple[int, List[int]]]
|
||||
Service = NewType('Service', list[ServiceAttribute])
|
||||
service_records: dict[int, Service]
|
||||
current_response: Union[None, bytes, tuple[int, list[int]]]
|
||||
|
||||
def __init__(self, device: Device) -> None:
|
||||
self.device = device
|
||||
self.service_records = {} # Service records maps, by record handle
|
||||
self.channel = None
|
||||
self.current_response = None
|
||||
self.current_response = None # Current response data, used for continuations
|
||||
|
||||
def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
|
||||
l2cap_channel_manager.create_classic_server(
|
||||
@@ -958,7 +1060,7 @@ class Server:
|
||||
logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')
|
||||
self.channel.send_pdu(response)
|
||||
|
||||
def match_services(self, search_pattern: DataElement) -> Dict[int, Service]:
|
||||
def match_services(self, search_pattern: DataElement) -> dict[int, Service]:
|
||||
# Find the services for which the attributes in the pattern is a subset of the
|
||||
# service's attribute values (NOTE: the value search recurses into sequences)
|
||||
matching_services = {}
|
||||
@@ -1015,6 +1117,31 @@ class Server:
|
||||
)
|
||||
)
|
||||
|
||||
def check_continuation(
|
||||
self,
|
||||
continuation_state: bytes,
|
||||
transaction_id: int,
|
||||
) -> Optional[bool]:
|
||||
# Check if this is a valid continuation
|
||||
if len(continuation_state) > 1:
|
||||
if (
|
||||
self.current_response is None
|
||||
or continuation_state != self.CONTINUATION_STATE
|
||||
):
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=transaction_id,
|
||||
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
|
||||
)
|
||||
)
|
||||
return None
|
||||
return True
|
||||
|
||||
# Cleanup any partial response leftover
|
||||
self.current_response = None
|
||||
|
||||
return False
|
||||
|
||||
def get_next_response_payload(self, maximum_size):
|
||||
if len(self.current_response) > maximum_size:
|
||||
payload = self.current_response[:maximum_size]
|
||||
@@ -1029,7 +1156,7 @@ class Server:
|
||||
|
||||
@staticmethod
|
||||
def get_service_attributes(
|
||||
service: Service, attribute_ids: List[DataElement]
|
||||
service: Service, attribute_ids: Iterable[DataElement]
|
||||
) -> DataElement:
|
||||
attributes = []
|
||||
for attribute_id in attribute_ids:
|
||||
@@ -1057,30 +1184,24 @@ class Server:
|
||||
|
||||
def on_sdp_service_search_request(self, request: SDP_ServiceSearchRequest) -> None:
|
||||
# Check if this is a continuation
|
||||
if len(request.continuation_state) > 1:
|
||||
if self.current_response is None:
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
|
||||
)
|
||||
)
|
||||
return
|
||||
else:
|
||||
# Cleanup any partial response leftover
|
||||
self.current_response = None
|
||||
if (
|
||||
continuation := self.check_continuation(
|
||||
request.continuation_state, request.transaction_id
|
||||
)
|
||||
) is None:
|
||||
return
|
||||
|
||||
if not continuation:
|
||||
# Find the matching services
|
||||
matching_services = self.match_services(request.service_search_pattern)
|
||||
service_record_handles = list(matching_services.keys())
|
||||
logger.debug(f'Service Record Handles: {service_record_handles}')
|
||||
|
||||
# Only return up to the maximum requested
|
||||
service_record_handles_subset = service_record_handles[
|
||||
: request.maximum_service_record_count
|
||||
]
|
||||
|
||||
# Serialize to a byte array, and remember the total count
|
||||
logger.debug(f'Service Record Handles: {service_record_handles}')
|
||||
self.current_response = (
|
||||
len(service_record_handles),
|
||||
service_record_handles_subset,
|
||||
@@ -1088,15 +1209,21 @@ class Server:
|
||||
|
||||
# Respond, keeping any unsent handles for later
|
||||
assert isinstance(self.current_response, tuple)
|
||||
service_record_handles = self.current_response[1][
|
||||
: request.maximum_service_record_count
|
||||
assert self.channel is not None
|
||||
total_service_record_count, service_record_handles = self.current_response
|
||||
maximum_service_record_count = (self.channel.peer_mtu - 11) // 4
|
||||
service_record_handles_remaining = service_record_handles[
|
||||
maximum_service_record_count:
|
||||
]
|
||||
service_record_handles = service_record_handles[:maximum_service_record_count]
|
||||
self.current_response = (
|
||||
self.current_response[0],
|
||||
self.current_response[1][request.maximum_service_record_count :],
|
||||
total_service_record_count,
|
||||
service_record_handles_remaining,
|
||||
)
|
||||
continuation_state = (
|
||||
Server.CONTINUATION_STATE if self.current_response[1] else bytes([0])
|
||||
Server.CONTINUATION_STATE
|
||||
if service_record_handles_remaining
|
||||
else bytes([0])
|
||||
)
|
||||
service_record_handle_list = b''.join(
|
||||
[struct.pack('>I', handle) for handle in service_record_handles]
|
||||
@@ -1104,7 +1231,7 @@ class Server:
|
||||
self.send_response(
|
||||
SDP_ServiceSearchResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
total_service_record_count=self.current_response[0],
|
||||
total_service_record_count=total_service_record_count,
|
||||
current_service_record_count=len(service_record_handles),
|
||||
service_record_handle_list=service_record_handle_list,
|
||||
continuation_state=continuation_state,
|
||||
@@ -1115,19 +1242,14 @@ class Server:
|
||||
self, request: SDP_ServiceAttributeRequest
|
||||
) -> None:
|
||||
# Check if this is a continuation
|
||||
if len(request.continuation_state) > 1:
|
||||
if self.current_response is None:
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
|
||||
)
|
||||
)
|
||||
return
|
||||
else:
|
||||
# Cleanup any partial response leftover
|
||||
self.current_response = None
|
||||
if (
|
||||
continuation := self.check_continuation(
|
||||
request.continuation_state, request.transaction_id
|
||||
)
|
||||
) is None:
|
||||
return
|
||||
|
||||
if not continuation:
|
||||
# Check that the service exists
|
||||
service = self.service_records.get(request.service_record_handle)
|
||||
if service is None:
|
||||
@@ -1149,14 +1271,18 @@ class Server:
|
||||
self.current_response = bytes(attribute_list)
|
||||
|
||||
# Respond, keeping any pending chunks for later
|
||||
assert self.channel is not None
|
||||
maximum_attribute_byte_count = min(
|
||||
request.maximum_attribute_byte_count, self.channel.peer_mtu - 9
|
||||
)
|
||||
attribute_list_response, continuation_state = self.get_next_response_payload(
|
||||
request.maximum_attribute_byte_count
|
||||
maximum_attribute_byte_count
|
||||
)
|
||||
self.send_response(
|
||||
SDP_ServiceAttributeResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
attribute_list_byte_count=len(attribute_list_response),
|
||||
attribute_list=attribute_list,
|
||||
attribute_list=attribute_list_response,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
@@ -1165,18 +1291,14 @@ class Server:
|
||||
self, request: SDP_ServiceSearchAttributeRequest
|
||||
) -> None:
|
||||
# Check if this is a continuation
|
||||
if len(request.continuation_state) > 1:
|
||||
if self.current_response is None:
|
||||
self.send_response(
|
||||
SDP_ErrorResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
error_code=SDP_INVALID_CONTINUATION_STATE_ERROR,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Cleanup any partial response leftover
|
||||
self.current_response = None
|
||||
if (
|
||||
continuation := self.check_continuation(
|
||||
request.continuation_state, request.transaction_id
|
||||
)
|
||||
) is None:
|
||||
return
|
||||
|
||||
if not continuation:
|
||||
# Find the matching services
|
||||
matching_services = self.match_services(
|
||||
request.service_search_pattern
|
||||
@@ -1196,14 +1318,18 @@ class Server:
|
||||
self.current_response = bytes(attribute_lists)
|
||||
|
||||
# Respond, keeping any pending chunks for later
|
||||
assert self.channel is not None
|
||||
maximum_attribute_byte_count = min(
|
||||
request.maximum_attribute_byte_count, self.channel.peer_mtu - 9
|
||||
)
|
||||
attribute_lists_response, continuation_state = self.get_next_response_payload(
|
||||
request.maximum_attribute_byte_count
|
||||
maximum_attribute_byte_count
|
||||
)
|
||||
self.send_response(
|
||||
SDP_ServiceSearchAttributeResponse(
|
||||
transaction_id=request.transaction_id,
|
||||
attribute_lists_byte_count=len(attribute_lists_response),
|
||||
attribute_lists=attribute_lists,
|
||||
attribute_lists=attribute_lists_response,
|
||||
continuation_state=continuation_state,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -298,11 +298,8 @@ class SMP_Command:
|
||||
def init_from_bytes(self, pdu: bytes, offset: int) -> None:
|
||||
return HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
|
||||
|
||||
def to_bytes(self):
|
||||
return self.pdu
|
||||
|
||||
def __bytes__(self):
|
||||
return self.to_bytes()
|
||||
return self.pdu
|
||||
|
||||
def __str__(self):
|
||||
result = color(self.name, 'yellow')
|
||||
@@ -698,6 +695,7 @@ class Session:
|
||||
self.ltk_ediv = 0
|
||||
self.ltk_rand = bytes(8)
|
||||
self.link_key: Optional[bytes] = None
|
||||
self.maximum_encryption_key_size: int = 0
|
||||
self.initiator_key_distribution: int = 0
|
||||
self.responder_key_distribution: int = 0
|
||||
self.peer_random_value: Optional[bytes] = None
|
||||
@@ -744,6 +742,10 @@ class Session:
|
||||
else:
|
||||
self.pairing_result = None
|
||||
|
||||
self.maximum_encryption_key_size = (
|
||||
pairing_config.delegate.maximum_encryption_key_size
|
||||
)
|
||||
|
||||
# Key Distribution (default values before negotiation)
|
||||
self.initiator_key_distribution = (
|
||||
pairing_config.delegate.local_initiator_key_distribution
|
||||
@@ -996,7 +998,7 @@ class Session:
|
||||
io_capability=self.io_capability,
|
||||
oob_data_flag=self.oob_data_flag,
|
||||
auth_req=self.auth_req,
|
||||
maximum_encryption_key_size=16,
|
||||
maximum_encryption_key_size=self.maximum_encryption_key_size,
|
||||
initiator_key_distribution=self.initiator_key_distribution,
|
||||
responder_key_distribution=self.responder_key_distribution,
|
||||
)
|
||||
@@ -1008,7 +1010,7 @@ class Session:
|
||||
io_capability=self.io_capability,
|
||||
oob_data_flag=self.oob_data_flag,
|
||||
auth_req=self.auth_req,
|
||||
maximum_encryption_key_size=16,
|
||||
maximum_encryption_key_size=self.maximum_encryption_key_size,
|
||||
initiator_key_distribution=self.initiator_key_distribution,
|
||||
responder_key_distribution=self.responder_key_distribution,
|
||||
)
|
||||
@@ -1324,7 +1326,7 @@ class Session:
|
||||
self.connection.abort_on('disconnection', self.on_pairing())
|
||||
|
||||
def on_connection_encryption_change(self) -> None:
|
||||
if self.connection.is_encrypted:
|
||||
if self.connection.is_encrypted and not self.completed:
|
||||
if self.is_responder:
|
||||
# The responder distributes its keys first, the initiator later
|
||||
self.distribute_keys()
|
||||
@@ -1839,7 +1841,7 @@ class Session:
|
||||
if self.is_initiator:
|
||||
if self.pairing_method == PairingMethod.OOB:
|
||||
self.send_pairing_random_command()
|
||||
else:
|
||||
elif self.pairing_method == PairingMethod.PASSKEY:
|
||||
self.send_pairing_confirm_command()
|
||||
else:
|
||||
if self.pairing_method == PairingMethod.PASSKEY:
|
||||
@@ -1949,7 +1951,7 @@ class Manager(EventEmitter):
|
||||
f'{connection.peer_address}: {command}'
|
||||
)
|
||||
cid = SMP_BR_CID if connection.transport == BT_BR_EDR_TRANSPORT else SMP_CID
|
||||
connection.send_l2cap_pdu(cid, command.to_bytes())
|
||||
connection.send_l2cap_pdu(cid, bytes(command))
|
||||
|
||||
def on_smp_security_request_command(
|
||||
self, connection: Connection, request: SMP_Security_Request_Command
|
||||
|
||||
@@ -370,11 +370,13 @@ class PumpedPacketSource(ParserSource):
|
||||
self.parser.feed_data(packet)
|
||||
except asyncio.CancelledError:
|
||||
logger.debug('source pump task done')
|
||||
self.terminated.set_result(None)
|
||||
if not self.terminated.done():
|
||||
self.terminated.set_result(None)
|
||||
break
|
||||
except Exception as error:
|
||||
logger.warning(f'exception while waiting for packet: {error}')
|
||||
self.terminated.set_exception(error)
|
||||
if not self.terminated.done():
|
||||
self.terminated.set_exception(error)
|
||||
break
|
||||
|
||||
self.pump_task = asyncio.create_task(pump_packets())
|
||||
|
||||
@@ -149,7 +149,10 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
|
||||
if status != usb1.TRANSFER_COMPLETED:
|
||||
logger.warning(
|
||||
color(f'!!! OUT transfer not completed: status={status}', 'red')
|
||||
color(
|
||||
f'!!! OUT transfer not completed: status={status}',
|
||||
'red',
|
||||
)
|
||||
)
|
||||
|
||||
async def process_queue(self):
|
||||
@@ -275,7 +278,10 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
color(f'!!! IN transfer not completed: status={status}', 'red')
|
||||
color(
|
||||
f'!!! IN[{packet_type}] transfer not completed: status={status}',
|
||||
'red',
|
||||
)
|
||||
)
|
||||
self.loop.call_soon_threadsafe(self.on_transport_lost)
|
||||
|
||||
|
||||
@@ -24,17 +24,19 @@ import logging
|
||||
import sys
|
||||
import warnings
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Set,
|
||||
TypeVar,
|
||||
List,
|
||||
Tuple,
|
||||
Callable,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
List,
|
||||
Optional,
|
||||
Protocol,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from pyee import EventEmitter
|
||||
|
||||
@@ -445,7 +447,7 @@ def deprecated(msg: str):
|
||||
def wrapper(function):
|
||||
@functools.wraps(function)
|
||||
def inner(*args, **kwargs):
|
||||
warnings.warn(msg, DeprecationWarning)
|
||||
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
||||
return function(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
@@ -462,7 +464,7 @@ def experimental(msg: str):
|
||||
def wrapper(function):
|
||||
@functools.wraps(function)
|
||||
def inner(*args, **kwargs):
|
||||
warnings.warn(msg, FutureWarning)
|
||||
warnings.warn(msg, FutureWarning, stacklevel=2)
|
||||
return function(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
@@ -487,3 +489,16 @@ class OpenIntEnum(enum.IntEnum):
|
||||
obj._value_ = value
|
||||
obj._name_ = f"{cls.__name__}[{value}]"
|
||||
return obj
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ByteSerializable(Protocol):
|
||||
"""
|
||||
Type protocol for classes that can be instantiated from bytes and serialized
|
||||
to bytes.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> Self: ...
|
||||
|
||||
def __bytes__(self) -> bytes: ...
|
||||
|
||||
33
bumble/vendor/android/hci.py
vendored
33
bumble/vendor/android/hci.py
vendored
@@ -16,6 +16,7 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
from bumble.hci import (
|
||||
name_or_number,
|
||||
@@ -24,7 +25,9 @@ from bumble.hci import (
|
||||
HCI_Constant,
|
||||
HCI_Object,
|
||||
HCI_Command,
|
||||
HCI_Vendor_Event,
|
||||
HCI_Event,
|
||||
HCI_Extended_Event,
|
||||
HCI_VENDOR_EVENT,
|
||||
STATUS_SPEC,
|
||||
)
|
||||
|
||||
@@ -48,7 +51,6 @@ HCI_DYNAMIC_AUDIO_BUFFER_COMMAND = hci_vendor_command_op_code(0x15F)
|
||||
HCI_BLUETOOTH_QUALITY_REPORT_EVENT = 0x58
|
||||
|
||||
HCI_Command.register_commands(globals())
|
||||
HCI_Vendor_Event.register_subevents(globals())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -279,7 +281,29 @@ class HCI_Dynamic_Audio_Buffer_Command(HCI_Command):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Vendor_Event.event(
|
||||
class HCI_Android_Vendor_Event(HCI_Extended_Event):
|
||||
event_code: int = HCI_VENDOR_EVENT
|
||||
subevent_classes: Dict[int, Type[HCI_Extended_Event]] = {}
|
||||
|
||||
@classmethod
|
||||
def subclass_from_parameters(
|
||||
cls, parameters: bytes
|
||||
) -> Optional[HCI_Extended_Event]:
|
||||
subevent_code = parameters[0]
|
||||
if subevent_code == HCI_BLUETOOTH_QUALITY_REPORT_EVENT:
|
||||
quality_report_id = parameters[1]
|
||||
if quality_report_id in (0x01, 0x02, 0x03, 0x04, 0x07, 0x08, 0x09):
|
||||
return HCI_Bluetooth_Quality_Report_Event.from_parameters(parameters)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
HCI_Android_Vendor_Event.register_subevents(globals())
|
||||
HCI_Event.add_vendor_factory(HCI_Android_Vendor_Event.subclass_from_parameters)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Extended_Event.event(
|
||||
fields=[
|
||||
('quality_report_id', 1),
|
||||
('packet_types', 1),
|
||||
@@ -308,10 +332,11 @@ class HCI_Dynamic_Audio_Buffer_Command(HCI_Command):
|
||||
('tx_last_subevent_packets', 4),
|
||||
('crc_error_packets', 4),
|
||||
('rx_duplicate_packets', 4),
|
||||
('rx_unreceived_packets', 4),
|
||||
('vendor_specific_parameters', '*'),
|
||||
]
|
||||
)
|
||||
class HCI_Bluetooth_Quality_Report_Event(HCI_Vendor_Event):
|
||||
class HCI_Bluetooth_Quality_Report_Event(HCI_Android_Vendor_Event):
|
||||
# pylint: disable=line-too-long
|
||||
'''
|
||||
See https://source.android.com/docs/core/connect/bluetooth/hci_requirements#bluetooth-quality-report-sub-event
|
||||
|
||||
@@ -39,12 +39,14 @@ nav:
|
||||
- Drivers:
|
||||
- drivers/index.md
|
||||
- Realtek: drivers/realtek.md
|
||||
- Intel: drivers/intel.md
|
||||
- API:
|
||||
- Guide: api/guide.md
|
||||
- Examples: api/examples.md
|
||||
- Reference: api/reference.md
|
||||
- Apps & Tools:
|
||||
- apps_and_tools/index.md
|
||||
- Auracast: apps_and_tools/auracast.md
|
||||
- Console: apps_and_tools/console.md
|
||||
- Bench: apps_and_tools/bench.md
|
||||
- Speaker: apps_and_tools/speaker.md
|
||||
@@ -108,8 +110,8 @@ markdown_extensions:
|
||||
- pymdownx.details
|
||||
- pymdownx.superfences
|
||||
- pymdownx.emoji:
|
||||
emoji_index: !!python/name:materialx.emoji.twemoji
|
||||
emoji_generator: !!python/name:materialx.emoji.to_svg
|
||||
emoji_index: !!python/name:material.extensions.emoji.twemoji
|
||||
emoji_generator: !!python/name:material.extensions.emoji.to_svg
|
||||
- pymdownx.tabbed:
|
||||
alternate_style: true
|
||||
- codehilite:
|
||||
|
||||
@@ -11,32 +11,44 @@ Usage: bumble-bench [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Options:
|
||||
--device-config FILENAME Device configuration file
|
||||
--role [sender|receiver|ping|pong]
|
||||
--scenario [send|receive|ping|pong]
|
||||
--mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server]
|
||||
--att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517]
|
||||
--extended-data-length TEXT Request a data length upon connection,
|
||||
specified as tx_octets/tx_time
|
||||
--rfcomm-channel INTEGER RFComm channel to use
|
||||
--role-switch [central|peripheral]
|
||||
Request role switch upon connection (central
|
||||
or peripheral)
|
||||
--rfcomm-channel INTEGER RFComm channel to use (specify 0 for channel
|
||||
discovery via SDP)
|
||||
--rfcomm-uuid TEXT RFComm service UUID to use (ignored if
|
||||
--rfcomm-channel is not 0)
|
||||
--rfcomm-l2cap-mtu INTEGER RFComm L2CAP MTU
|
||||
--rfcomm-max-frame-size INTEGER
|
||||
RFComm maximum frame size
|
||||
--rfcomm-initial-credits INTEGER
|
||||
RFComm initial credits
|
||||
--rfcomm-max-credits INTEGER RFComm max credits
|
||||
--rfcomm-credits-threshold INTEGER
|
||||
RFComm credits threshold
|
||||
--l2cap-psm INTEGER L2CAP PSM to use
|
||||
--l2cap-mtu INTEGER L2CAP MTU to use
|
||||
--l2cap-mps INTEGER L2CAP MPS to use
|
||||
--l2cap-max-credits INTEGER L2CAP maximum number of credits allowed for
|
||||
the peer
|
||||
-s, --packet-size SIZE Packet size (client or ping role)
|
||||
[8<=x<=4096]
|
||||
-c, --packet-count COUNT Packet count (client or ping role)
|
||||
-sd, --start-delay SECONDS Start delay (client or ping role)
|
||||
--repeat N Repeat the run N times (client and ping
|
||||
roles)(0, which is the fault, to run just
|
||||
-s, --packet-size SIZE Packet size (send or ping scenario)
|
||||
[8<=x<=8192]
|
||||
-c, --packet-count COUNT Packet count (send or ping scenario)
|
||||
-sd, --start-delay SECONDS Start delay (send or ping scenario)
|
||||
--repeat N Repeat the run N times (send and ping
|
||||
scenario)(0, which is the fault, to run just
|
||||
once)
|
||||
--repeat-delay SECONDS Delay, in seconds, between repeats
|
||||
--pace MILLISECONDS Wait N milliseconds between packets (0,
|
||||
which is the fault, to send as fast as
|
||||
possible)
|
||||
--linger Don't exit at the end of a run (server and
|
||||
pong roles)
|
||||
--linger Don't exit at the end of a run (receive and
|
||||
pong scenarios)
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
@@ -71,19 +83,19 @@ using the ``--peripheral`` option. The address will be printed by the Peripheral
|
||||
it starts.
|
||||
|
||||
Independently of whether the device is the Central or Peripheral, each device selects a
|
||||
``mode`` and and ``role`` to run as. The ``mode`` and ``role`` of the Central and Peripheral
|
||||
``mode`` and and ``scenario`` to run as. The ``mode`` and ``scenario`` of the Central and Peripheral
|
||||
must be compatible.
|
||||
|
||||
Device 1 mode | Device 2 mode
|
||||
Device 1 scenario | Device 2 scenario
|
||||
------------------|------------------
|
||||
``gatt-client`` | ``gatt-server``
|
||||
``l2cap-client`` | ``l2cap-server``
|
||||
``rfcomm-client`` | ``rfcomm-server``
|
||||
|
||||
Device 1 role | Device 2 role
|
||||
--------------|--------------
|
||||
``sender`` | ``receiver``
|
||||
``ping`` | ``pong``
|
||||
Device 1 scenario | Device 2 scenario
|
||||
------------------|--------------
|
||||
``send`` | ``receive``
|
||||
``ping`` | ``pong``
|
||||
|
||||
|
||||
# Examples
|
||||
@@ -92,7 +104,7 @@ In the following examples, we have two USB Bluetooth controllers, one on `usb:0`
|
||||
the other on `usb:1`, and two consoles/terminals. We will run a command in each.
|
||||
|
||||
!!! example "GATT Throughput"
|
||||
Using the default mode and role for the Central and Peripheral.
|
||||
Using the default mode and scenario for the Central and Peripheral.
|
||||
|
||||
In the first console/terminal:
|
||||
```
|
||||
@@ -137,12 +149,12 @@ the other on `usb:1`, and two consoles/terminals. We will run a command in each.
|
||||
!!! example "Ping/Pong Latency"
|
||||
In the first console/terminal:
|
||||
```
|
||||
$ bumble-bench --role pong peripheral usb:0
|
||||
$ bumble-bench --scenario pong peripheral usb:0
|
||||
```
|
||||
|
||||
In the second console/terminal:
|
||||
```
|
||||
$ bumble-bench --role ping central usb:1
|
||||
$ bumble-bench --scenario ping central usb:1
|
||||
```
|
||||
|
||||
!!! example "Reversed modes with GATT and custom connection interval"
|
||||
@@ -167,13 +179,13 @@ the other on `usb:1`, and two consoles/terminals. We will run a command in each.
|
||||
$ bumble-bench --mode l2cap-server central --phy 2m usb:1
|
||||
```
|
||||
|
||||
!!! example "Reversed roles with L2CAP"
|
||||
!!! example "Reversed scenarios with L2CAP"
|
||||
In the first console/terminal:
|
||||
```
|
||||
$ bumble-bench --mode l2cap-client --role sender peripheral usb:0
|
||||
$ bumble-bench --mode l2cap-client --scenario send peripheral usb:0
|
||||
```
|
||||
|
||||
In the second console/terminal:
|
||||
```
|
||||
$ bumble-bench --mode l2cap-server --role receiver central usb:1
|
||||
$ bumble-bench --mode l2cap-server --scenario receive central usb:1
|
||||
```
|
||||
|
||||
@@ -4,12 +4,13 @@ APPS & TOOLS
|
||||
Included in the project are a few apps and tools, built on top of the core libraries.
|
||||
These include:
|
||||
|
||||
* [Console](console.md) - an interactive text-based console
|
||||
* [Bench](bench.md) - Speed and Latency benchmarking between two devices (LE and Classic)
|
||||
* [Pair](pair.md) - Pair/bond two devices (LE and Classic)
|
||||
* [Unbond](unbond.md) - Remove a previously established bond
|
||||
* [HCI Bridge](hci_bridge.md) - a HCI transport bridge to connect two HCI transports and filter/snoop the HCI packets
|
||||
* [Golden Gate Bridge](gg_bridge.md) - a bridge between GATT and UDP to use with the Golden Gate "stack tool"
|
||||
* [Show](show.md) - Parse a file with HCI packets and print the details of each packet in a human readable form
|
||||
* [Auracast](auracast.md) - Commands to broadcast, receive and/or control LE Audio.
|
||||
* [Console](console.md) - An interactive text-based console.
|
||||
* [Bench](bench.md) - Speed and Latency benchmarking between two devices (LE and Classic).
|
||||
* [Pair](pair.md) - Pair/bond two devices (LE and Classic).
|
||||
* [Unbond](unbond.md) - Remove a previously established bond.
|
||||
* [HCI Bridge](hci_bridge.md) - An HCI transport bridge to connect two HCI transports and filter/snoop the HCI packets.
|
||||
* [Golden Gate Bridge](gg_bridge.md) - Bridge between GATT and UDP to use with the Golden Gate "stack tool".
|
||||
* [Show](show.md) - Parse a file with HCI packets and print the details of each packet in a human readable form.
|
||||
* [Speaker](speaker.md) - Virtual Bluetooth speaker, with a command line and browser-based UI.
|
||||
* [Link Relay](link_relay.md) - WebSocket relay for virtual RemoteLink instances to communicate with each other.
|
||||
|
||||
@@ -16,4 +16,5 @@ USB vendor ID and product ID.
|
||||
|
||||
Drivers included in the module are:
|
||||
|
||||
* [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles.
|
||||
* [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles.
|
||||
* [Intel](intel.md): Loading of Firmware and Config for Intel USB controllers.
|
||||
73
docs/mkdocs/src/drivers/intel.md
Normal file
73
docs/mkdocs/src/drivers/intel.md
Normal file
@@ -0,0 +1,73 @@
|
||||
INTEL DRIVER
|
||||
==============
|
||||
|
||||
This driver supports loading firmware images and optional config data to
|
||||
Intel USB controllers.
|
||||
A number of USB dongles are supported, but likely not all.
|
||||
The initial implementation has been tested on BE200 and AX210 controllers.
|
||||
When using a USB controller, the USB product ID and vendor ID are used
|
||||
to find whether a matching set of firmware image and config data
|
||||
is needed for that specific model. If a match exists, the driver will try
|
||||
load the firmware image and, if needed, config data.
|
||||
Alternatively, the metadata property ``driver=intel`` may be specified in a transport
|
||||
name to force that driver to be used (ex: ``usb:[driver=intel]0`` instead of just
|
||||
``usb:0`` for the first USB device).
|
||||
The driver will look for the firmware and config files by name, in order, in:
|
||||
|
||||
* The directory specified by the environment variable `BUMBLE_INTEL_FIRMWARE_DIR`
|
||||
if set.
|
||||
* The directory `<package-dir>/drivers/intel_fw` where `<package-dir>` is the directory
|
||||
where the `bumble` package is installed.
|
||||
* The current directory.
|
||||
|
||||
It is also possible to override or extend the config data with parameters passed via the
|
||||
transport name. The driver name `intel` may be suffixed with `/<param:value>[+<param:value>]...`
|
||||
The supported params are:
|
||||
* `ddc_addon`: configuration data to add to the data loaded from the config data file
|
||||
* `ddc_override`: configuration data to use instead of the data loaded from the config data file.
|
||||
|
||||
With both `dcc_addon` and `dcc_override`, the param value is a hex-encoded byte array containing
|
||||
the config data (same format as the config file).
|
||||
Example transport name:
|
||||
`usb:[driver=intel/dcc_addon:03E40200]0`
|
||||
|
||||
|
||||
Obtaining Firmware Images and Config Data
|
||||
-----------------------------------------
|
||||
|
||||
Firmware images and config data may be obtained from a variety of online
|
||||
sources.
|
||||
To facilitate finding a downloading the, the utility program `bumble-intel-fw-download`
|
||||
may be used.
|
||||
|
||||
```
|
||||
Usage: bumble-intel-fw-download [OPTIONS]
|
||||
|
||||
Download Intel firmware images and configs.
|
||||
|
||||
Options:
|
||||
--output-dir TEXT Output directory where the files will be saved.
|
||||
Defaults to the OS-specific app data dir, which the
|
||||
driver will check when trying to find firmware
|
||||
--source [linux-kernel] [default: linux-kernel]
|
||||
--single TEXT Only download a single image set, by its base name
|
||||
--force Overwrite files if they already exist
|
||||
--help Show this message and exit.
|
||||
```
|
||||
|
||||
Utility
|
||||
-------
|
||||
|
||||
The `bumble-intel-util` utility may be used to interact with an Intel USB controller.
|
||||
|
||||
```
|
||||
Usage: bumble-intel-util [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Options:
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
bootloader Reboot in bootloader mode.
|
||||
info Get the firmware info.
|
||||
load Load a firmware image.
|
||||
```
|
||||
@@ -3,17 +3,15 @@ GETTING STARTED WITH BUMBLE
|
||||
|
||||
# Prerequisites
|
||||
|
||||
You need Python 3.8 or above. Python >= 3.9 is recommended, but 3.8 should be sufficient if
|
||||
necessary (there may be some optional functionality that will not work on some platforms with
|
||||
python 3.8).
|
||||
You need Python 3.9 or above.
|
||||
Visit the [Python site](https://www.python.org/) for instructions on how to install Python
|
||||
for your platform.
|
||||
Throughout the documentation, when shell commands are shown, it is assumed that you can
|
||||
invoke Python as
|
||||
```
|
||||
$ python
|
||||
$ python3
|
||||
```
|
||||
If invoking python is different on your platform (it may be `python3` for example, or just `py` or `py.exe`),
|
||||
If invoking python is different on your platform (it may be `python` for example, or just `py` or `py.exe`),
|
||||
adjust accordingly.
|
||||
|
||||
You may be simply using Bumble as a module for your own application or as a dependency to your own
|
||||
@@ -32,12 +30,18 @@ manager, or from source.
|
||||
python environment, or in a virtual environment, such as a `venv`, `pyenv` or `conda` environment.
|
||||
See the [Python Environments page](development/python_environments.md) page for details.
|
||||
|
||||
### Install from PyPI
|
||||
|
||||
```
|
||||
$ python3 -m pip install bumble
|
||||
```
|
||||
|
||||
### Install From Source
|
||||
|
||||
Install with `pip`. Run in a command shell in the directory where you downloaded the source
|
||||
distribution
|
||||
```
|
||||
$ python -m pip install -e .
|
||||
$ python3 -m pip install -e .
|
||||
```
|
||||
|
||||
### Install from GitHub
|
||||
@@ -46,21 +50,21 @@ You can install directly from GitHub without first downloading the repo.
|
||||
|
||||
Install the latest commit from the main branch with `pip`:
|
||||
```
|
||||
$ python -m pip install git+https://github.com/google/bumble.git
|
||||
$ python3 -m pip install git+https://github.com/google/bumble.git
|
||||
```
|
||||
|
||||
You can specify a specific tag.
|
||||
|
||||
Install tag `v0.0.1` with `pip`:
|
||||
```
|
||||
$ python -m pip install git+https://github.com/google/bumble.git@v0.0.1
|
||||
$ python3 -m pip install git+https://github.com/google/bumble.git@v0.0.1
|
||||
```
|
||||
|
||||
You can also specify a specific commit.
|
||||
|
||||
Install commit `27c0551` with `pip`:
|
||||
```
|
||||
$ python -m pip install git+https://github.com/google/bumble.git@27c0551
|
||||
$ python3 -m pip install git+https://github.com/google/bumble.git@27c0551
|
||||
```
|
||||
|
||||
# Working On The Bumble Code
|
||||
@@ -80,21 +84,21 @@ directory of the project.
|
||||
|
||||
```bash
|
||||
$ export PYTHONPATH=.
|
||||
$ python apps/console.py serial:/dev/tty.usbmodem0006839912171
|
||||
$ python3 apps/console.py serial:/dev/tty.usbmodem0006839912171
|
||||
```
|
||||
|
||||
or running an example, with the working directory set to the `examples` subdirectory
|
||||
```bash
|
||||
$ cd examples
|
||||
$ export PYTHONPATH=..
|
||||
$ python run_scanner.py usb:0
|
||||
$ python3 run_scanner.py usb:0
|
||||
```
|
||||
|
||||
Or course, `export PYTHONPATH` only needs to be invoked once, not before each app/script execution.
|
||||
|
||||
Setting `PYTHONPATH` locally with each command would look something like:
|
||||
```
|
||||
$ PYTHONPATH=. python examples/run_advertiser.py examples/device1.json serial:/dev/tty.usbmodem0006839912171
|
||||
$ PYTHONPATH=. python3 examples/run_advertiser.py examples/device1.json serial:/dev/tty.usbmodem0006839912171
|
||||
```
|
||||
|
||||
# Where To Go Next
|
||||
|
||||
@@ -31,7 +31,7 @@ Some of the configurations that may be useful:
|
||||
|
||||
See the [use cases page](use_cases/index.md) for more use cases.
|
||||
|
||||
The project is implemented in Python (Python >= 3.8 is required). A number of APIs for functionality that is inherently I/O bound is implemented in terms of python coroutines with async IO. This means that all of the concurrent tasks run in the same thread, which makes everything much simpler and more predictable.
|
||||
The project is implemented in Python (Python >= 3.9 is required). A number of APIs for functionality that is inherently I/O bound is implemented in terms of python coroutines with async IO. This means that all of the concurrent tasks run in the same thread, which makes everything much simpler and more predictable.
|
||||
|
||||

|
||||
|
||||
|
||||
@@ -35,11 +35,11 @@ the command line.
|
||||
visit [this Android Studio user guide page](https://developer.android.com/studio/run/emulator-commandline)
|
||||
|
||||
The `-packet-streamer-endpoint <endpoint>` command line option may be used to enable
|
||||
Bluetooth emulation and tell the emulator which virtual controller to connect to.
|
||||
Bluetooth emulation and tell the emulator which virtual controller to connect to.
|
||||
|
||||
## Connecting to Netsim
|
||||
|
||||
If the emulator doesn't have Bluetooth emulation enabled by default, use the
|
||||
If the emulator doesn't have Bluetooth emulation enabled by default, use the
|
||||
`-packet-streamer-endpoint default` option to tell it to connect to Netsim.
|
||||
If Netsim is not running, the emulator will start it automatically.
|
||||
|
||||
@@ -60,17 +60,17 @@ the Bumble `android-netsim` transport in `host` mode (the default).
|
||||
|
||||
!!! example "Run the example GATT server connected to the emulator via Netsim"
|
||||
``` shell
|
||||
$ python run_gatt_server.py device1.json android-netsim
|
||||
$ python3 run_gatt_server.py device1.json android-netsim
|
||||
```
|
||||
|
||||
By default, the Bumble `android-netsim` transport will try to automatically discover
|
||||
the port number on which the netsim process is exposing its gRPC server interface. If
|
||||
that discovery process fails, or if you want to specify the interface manually, you
|
||||
that discovery process fails, or if you want to specify the interface manually, you
|
||||
can pass a `hostname` and `port` as parameters to the transport, as: `android-netsim:<host>:<port>`.
|
||||
|
||||
!!! example "Run the example GATT server connected to the emulator via Netsim on a localhost, port 8877"
|
||||
``` shell
|
||||
$ python run_gatt_server.py device1.json android-netsim:localhost:8877
|
||||
$ python3 run_gatt_server.py device1.json android-netsim:localhost:8877
|
||||
```
|
||||
|
||||
### Multiple Instances
|
||||
@@ -84,7 +84,7 @@ For example: `android-netsim:localhost:8877,name=bumble1`
|
||||
This is an advanced use case, which may not be officially supported, but should work in recent
|
||||
versions of the emulator.
|
||||
|
||||
The first step is to run the Bumble HCI bridge, specifying netsim as the "host" end of the
|
||||
The first step is to run the Bumble HCI bridge, specifying netsim as the "host" end of the
|
||||
bridge, and another controller (typically a USB Bluetooth dongle, but any other supported
|
||||
transport can work as well) as the "controller" end of the bridge.
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
PLATFORMS
|
||||
=========
|
||||
|
||||
Most of the code included in the project should run on any platform that supports Python >= 3.8. Not all features are supported on all platforms (for example, USB dongle support is only available on platforms where the python USB library is functional).
|
||||
Most of the code included in the project should run on any platform that supports Python >= 3.9. Not all features are supported on all platforms (for example, USB dongle support is only available on platforms where the python USB library is functional).
|
||||
|
||||
For platform-specific information, see the following pages:
|
||||
|
||||
|
||||
@@ -4,6 +4,6 @@ channels:
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- pip=23
|
||||
- python=3.8
|
||||
- python=3.9
|
||||
- pip:
|
||||
- --editable .[development,documentation,test]
|
||||
|
||||
9
examples/cs_initiator.json
Normal file
9
examples/cs_initiator.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"name": "Bumble CS Initiator",
|
||||
"address": "F0:F1:F2:F3:F4:F5",
|
||||
"advertising_interval": 100,
|
||||
"keystore": "JsonKeyStore",
|
||||
"irk": "865F81FF5A8B486EAAE29A27AD9F77DC",
|
||||
"identity_address_type": 1,
|
||||
"channel_sounding_enabled": true
|
||||
}
|
||||
9
examples/cs_reflector.json
Normal file
9
examples/cs_reflector.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"name": "Bumble CS Reflector",
|
||||
"address": "F0:F1:F2:F3:F4:F6",
|
||||
"advertising_interval": 100,
|
||||
"keystore": "JsonKeyStore",
|
||||
"irk": "0c7d74db03a1c98e7be691f76141d53d",
|
||||
"identity_address_type": 1,
|
||||
"channel_sounding_enabled": true
|
||||
}
|
||||
@@ -282,7 +282,7 @@ async def keyboard_device(device, command):
|
||||
GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC,
|
||||
Characteristic.Properties.READ,
|
||||
Characteristic.READABLE,
|
||||
'Bumble',
|
||||
bytes('Bumble', 'utf-8'),
|
||||
)
|
||||
],
|
||||
),
|
||||
|
||||
69
examples/mobly/bench/one_device_bench_test.py
Normal file
69
examples/mobly/bench/one_device_bench_test.py
Normal file
@@ -0,0 +1,69 @@
|
||||
from mobly import base_test
|
||||
from mobly import test_runner
|
||||
from mobly.controllers import android_device
|
||||
|
||||
|
||||
class OneDeviceBenchTest(base_test.BaseTestClass):
|
||||
|
||||
def setup_class(self):
|
||||
self.ads = self.register_controller(android_device)
|
||||
self.dut = self.ads[0]
|
||||
self.dut.load_snippet("bench", "com.github.google.bumble.btbench")
|
||||
|
||||
def test_rfcomm_client_ping(self):
|
||||
runner = self.dut.bench.runRfcommClient(
|
||||
"ping", "DC:E5:5B:E5:51:2C", 100, 970, 100
|
||||
)
|
||||
print("### Initial status:", runner)
|
||||
final_status = self.dut.bench.waitForRunnerCompletion(runner["id"])
|
||||
print("### Final status:", final_status)
|
||||
|
||||
def test_rfcomm_client_send(self):
|
||||
runner = self.dut.bench.runRfcommClient(
|
||||
"send", "DC:E5:5B:E5:51:2C", 100, 970, 0
|
||||
)
|
||||
print("### Initial status:", runner)
|
||||
final_status = self.dut.bench.waitForRunnerCompletion(runner["id"])
|
||||
print("### Final status:", final_status)
|
||||
|
||||
def test_l2cap_client_ping(self):
|
||||
runner = self.dut.bench.runL2capClient(
|
||||
"ping", "4B:2A:67:76:2B:E3", 128, True, 100, 970, 100, "HIGH"
|
||||
)
|
||||
print("### Initial status:", runner)
|
||||
final_status = self.dut.bench.waitForRunnerCompletion(runner["id"])
|
||||
print("### Final status:", final_status)
|
||||
|
||||
def test_l2cap_client_send(self):
|
||||
runner = self.dut.bench.runL2capClient(
|
||||
"send",
|
||||
"F1:F1:F1:F1:F1:F1",
|
||||
128,
|
||||
True,
|
||||
100,
|
||||
970,
|
||||
0,
|
||||
"HIGH",
|
||||
10000,
|
||||
)
|
||||
print("### Initial status:", runner)
|
||||
final_status = self.dut.bench.waitForRunnerCompletion(runner["id"])
|
||||
print("### Final status:", final_status)
|
||||
|
||||
def test_gatt_client_send(self):
|
||||
runner = self.dut.bench.runGattClient(
|
||||
"send", "F1:F1:F1:F1:F1:F1", 128, True, 100, 970, 100, "HIGH"
|
||||
)
|
||||
print("### Initial status:", runner)
|
||||
final_status = self.dut.bench.waitForRunnerCompletion(runner["id"])
|
||||
print("### Final status:", final_status)
|
||||
|
||||
def test_gatt_server_receive(self):
|
||||
runner = self.dut.bench.runGattServer("receive")
|
||||
print("### Initial status:", runner)
|
||||
final_status = self.dut.bench.waitForRunnerCompletion(runner["id"])
|
||||
print("### Final status:", final_status)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_runner.main()
|
||||
9
examples/mobly/bench/sample_config.yml
Normal file
9
examples/mobly/bench/sample_config.yml
Normal file
@@ -0,0 +1,9 @@
|
||||
TestBeds:
|
||||
- Name: BenchTestBed
|
||||
Controllers:
|
||||
AndroidDevice:
|
||||
- serial: emulator-5554
|
||||
local_bt_address: 94:45:60:5E:03:B0
|
||||
|
||||
#- serial: 23071FDEE001F7
|
||||
# local_bt_address: DC:E5:5B:E5:51:2C
|
||||
38
examples/mobly/bench/two_devices_bench_test.py
Normal file
38
examples/mobly/bench/two_devices_bench_test.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import time
|
||||
|
||||
from mobly import base_test
|
||||
from mobly import test_runner
|
||||
from mobly.controllers import android_device
|
||||
|
||||
|
||||
class TwoDevicesBenchTest(base_test.BaseTestClass):
|
||||
def setup_class(self):
|
||||
self.ads = self.register_controller(android_device)
|
||||
self.dut1 = self.ads[0]
|
||||
self.dut1.load_snippet("bench", "com.github.google.bumble.btbench")
|
||||
self.dut2 = self.ads[1]
|
||||
self.dut2.load_snippet("bench", "com.github.google.bumble.btbench")
|
||||
|
||||
def test_rfcomm_client_send_receive(self):
|
||||
print("### Starting Receiver")
|
||||
receiver = self.dut2.bench.runRfcommServer("receive")
|
||||
receiver_id = receiver["id"]
|
||||
print("--- Receiver status:", receiver)
|
||||
while not receiver["model"]["running"]:
|
||||
print("--- Waiting for Receiver to be running...")
|
||||
time.sleep(1)
|
||||
receiver = self.dut2.bench.getRunner(receiver_id)
|
||||
|
||||
print("### Starting Sender")
|
||||
sender = self.dut1.bench.runRfcommClient(
|
||||
"send", "DC:E5:5B:E5:51:2C", 100, 970, 100
|
||||
)
|
||||
print("--- Sender status:", sender)
|
||||
|
||||
print("--- Waiting for Sender to complete...")
|
||||
sender_result = self.dut1.bench.waitForRunnerCompletion(sender["id"])
|
||||
print("--- Sender result:", sender_result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_runner.main()
|
||||
154
examples/run_channel_sounding.py
Normal file
154
examples/run_channel_sounding.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
import functools
|
||||
|
||||
from bumble import core
|
||||
from bumble import hci
|
||||
from bumble.device import Connection, Device, ChannelSoundingCapabilities
|
||||
from bumble.transport import open_transport_or_link
|
||||
|
||||
# From https://cs.android.com/android/platform/superproject/main/+/main:packages/modules/Bluetooth/system/gd/hci/distance_measurement_manager.cc.
|
||||
CS_TONE_ANTENNA_CONFIG_MAPPING_TABLE = [
|
||||
[0, 4, 5, 6],
|
||||
[1, 7, 7, 7],
|
||||
[2, 7, 7, 7],
|
||||
[3, 7, 7, 7],
|
||||
]
|
||||
CS_PREFERRED_PEER_ANTENNA_MAPPING_TABLE = [1, 1, 1, 1, 3, 7, 15, 3]
|
||||
CS_ANTENNA_PERMUTATION_ARRAY = [
|
||||
[1, 2, 3, 4],
|
||||
[2, 1, 3, 4],
|
||||
[1, 3, 2, 4],
|
||||
[3, 1, 2, 4],
|
||||
[3, 2, 1, 4],
|
||||
[2, 3, 1, 4],
|
||||
[1, 2, 4, 3],
|
||||
[2, 1, 4, 3],
|
||||
[1, 4, 2, 3],
|
||||
[4, 1, 2, 3],
|
||||
[4, 2, 1, 3],
|
||||
[2, 4, 1, 3],
|
||||
[1, 4, 3, 2],
|
||||
[4, 1, 3, 2],
|
||||
[1, 3, 4, 2],
|
||||
[3, 1, 4, 2],
|
||||
[3, 4, 1, 2],
|
||||
[4, 3, 1, 2],
|
||||
[4, 2, 3, 1],
|
||||
[2, 4, 3, 1],
|
||||
[4, 3, 2, 1],
|
||||
[3, 4, 2, 1],
|
||||
[3, 2, 4, 1],
|
||||
[2, 3, 4, 1],
|
||||
]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def main() -> None:
|
||||
if len(sys.argv) < 3:
|
||||
print(
|
||||
'Usage: run_channel_sounding.py <config-file> <transport-spec-for-device>'
|
||||
'[target_address](If missing, run as reflector)'
|
||||
)
|
||||
print('example: run_channel_sounding.py cs_reflector.json usb:0')
|
||||
print(
|
||||
'example: run_channel_sounding.py cs_initiator.json usb:0 F0:F1:F2:F3:F4:F5'
|
||||
)
|
||||
return
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(sys.argv[2]) as hci_transport:
|
||||
print('<<< connected')
|
||||
|
||||
device = Device.from_config_file_with_hci(
|
||||
sys.argv[1], hci_transport.source, hci_transport.sink
|
||||
)
|
||||
await device.power_on()
|
||||
assert (local_cs_capabilities := device.cs_capabilities)
|
||||
|
||||
if len(sys.argv) == 3:
|
||||
print('<<< Start Advertising')
|
||||
await device.start_advertising(
|
||||
own_address_type=hci.OwnAddressType.RANDOM, auto_restart=True
|
||||
)
|
||||
|
||||
def on_cs_capabilities(
|
||||
connection: Connection, capabilities: ChannelSoundingCapabilities
|
||||
):
|
||||
del capabilities
|
||||
print('<<< Set CS Settings')
|
||||
asyncio.create_task(device.set_default_cs_settings(connection))
|
||||
|
||||
device.on(
|
||||
'connection',
|
||||
lambda connection: connection.on(
|
||||
'channel_sounding_capabilities',
|
||||
functools.partial(on_cs_capabilities, connection),
|
||||
),
|
||||
)
|
||||
else:
|
||||
target_address = hci.Address(sys.argv[3])
|
||||
|
||||
print(f'<<< Connecting to {target_address}')
|
||||
connection = await device.connect(
|
||||
target_address, transport=core.BT_LE_TRANSPORT
|
||||
)
|
||||
print('<<< ACL Connected')
|
||||
if not (await device.get_long_term_key(connection.handle, b'', 0)):
|
||||
print('<<< No bond, start pairing')
|
||||
await connection.pair()
|
||||
print('<<< Pairing complete')
|
||||
|
||||
print('<<< Encrypting Connection')
|
||||
await connection.encrypt()
|
||||
|
||||
print('<<< Getting remote CS Capabilities...')
|
||||
remote_capabilities = await device.get_remote_cs_capabilities(connection)
|
||||
print('<<< Set CS Settings...')
|
||||
await device.set_default_cs_settings(connection)
|
||||
print('<<< Set CS Config...')
|
||||
config = await device.create_cs_config(connection)
|
||||
print('<<< Enable CS Security...')
|
||||
await device.enable_cs_security(connection)
|
||||
tone_antenna_config_selection = CS_TONE_ANTENNA_CONFIG_MAPPING_TABLE[
|
||||
local_cs_capabilities.num_antennas_supported - 1
|
||||
][remote_capabilities.num_antennas_supported - 1]
|
||||
print('<<< Set CS Procedure Parameters...')
|
||||
await device.set_cs_procedure_parameters(
|
||||
connection=connection,
|
||||
config=config,
|
||||
tone_antenna_config_selection=tone_antenna_config_selection,
|
||||
preferred_peer_antenna=CS_PREFERRED_PEER_ANTENNA_MAPPING_TABLE[
|
||||
tone_antenna_config_selection
|
||||
],
|
||||
)
|
||||
print('<<< Enable CS Procedure...')
|
||||
await device.enable_cs_procedure(connection=connection, config=config)
|
||||
|
||||
await hci_transport.source.terminated
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
|
||||
asyncio.run(main())
|
||||
@@ -64,6 +64,7 @@ async def main() -> None:
|
||||
[(AdvertisingData.COMPLETE_LOCAL_NAME, "Bumble 2".encode("utf-8"))]
|
||||
)
|
||||
|
||||
# pylint: disable=possibly-used-before-assignment
|
||||
if device.host.number_of_supported_advertising_sets >= 2:
|
||||
set2 = await device.create_advertising_set(
|
||||
random_address=Address("F0:F0:F0:F0:F0:F1"),
|
||||
|
||||
@@ -127,7 +127,7 @@ async def main() -> None:
|
||||
'486F64C6-4B5F-4B3B-8AFF-EDE134A8446A',
|
||||
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
|
||||
Characteristic.READABLE,
|
||||
'hello',
|
||||
bytes('hello', 'utf-8'),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
319
examples/run_gatt_with_adapters.py
Normal file
319
examples/run_gatt_with_adapters.py
Normal file
@@ -0,0 +1,319 @@
|
||||
# Copyright 2024 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import struct
|
||||
import sys
|
||||
from typing import Any, List, Union
|
||||
|
||||
from bumble.device import Device, Peer
|
||||
from bumble import transport
|
||||
from bumble import gatt
|
||||
from bumble import hci
|
||||
from bumble import core
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
SERVICE_UUID = core.UUID("50DB505C-8AC4-4738-8448-3B1D9CC09CC5")
|
||||
CHARACTERISTIC_UUID_BASE = "D901B45B-4916-412E-ACCA-0000000000"
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class CustomSerializableClass:
|
||||
x: int
|
||||
y: int
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes) -> CustomSerializableClass:
|
||||
return cls(*struct.unpack(">II", data))
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return struct.pack(">II", self.x, self.y)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclasses.dataclass
|
||||
class CustomClass:
|
||||
a: int
|
||||
b: int
|
||||
|
||||
@classmethod
|
||||
def decode(cls, data: bytes) -> CustomClass:
|
||||
return cls(*struct.unpack(">II", data))
|
||||
|
||||
def encode(self) -> bytes:
|
||||
return struct.pack(">II", self.a, self.b)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def client(device: Device, address: hci.Address) -> None:
|
||||
print(f'=== Connecting to {address}...')
|
||||
connection = await device.connect(address)
|
||||
print('=== Connected')
|
||||
|
||||
# Discover all characteristics.
|
||||
peer = Peer(connection)
|
||||
print("*** Discovering services and characteristics...")
|
||||
await peer.discover_all()
|
||||
print("*** Discovery complete")
|
||||
|
||||
service = peer.get_services_by_uuid(SERVICE_UUID)[0]
|
||||
characteristics = []
|
||||
for index in range(1, 9):
|
||||
characteristics.append(
|
||||
service.get_characteristics_by_uuid(
|
||||
core.UUID(CHARACTERISTIC_UUID_BASE + f"{index:02X}")
|
||||
)[0]
|
||||
)
|
||||
|
||||
# Read all characteristics as raw bytes.
|
||||
for characteristic in characteristics:
|
||||
value = await characteristic.read_value()
|
||||
print(f"### {characteristic} = {value!r} ({value.hex()})")
|
||||
|
||||
# Static characteristic with a bytes value.
|
||||
c1 = characteristics[0]
|
||||
c1_value = await c1.read_value()
|
||||
print(f"@@@ C1 {c1} value = {c1_value!r} (type={type(c1_value)})")
|
||||
await c1.write_value("happy π day".encode("utf-8"))
|
||||
|
||||
# Static characteristic with a string value.
|
||||
c2 = gatt.UTF8CharacteristicAdapter(characteristics[1])
|
||||
c2_value = await c2.read_value()
|
||||
print(f"@@@ C2 {c2} value = {c2_value} (type={type(c2_value)})")
|
||||
await c2.write_value("happy π day")
|
||||
|
||||
# Static characteristic with a tuple value.
|
||||
c3 = gatt.PackedCharacteristicAdapter(characteristics[2], ">III")
|
||||
c3_value = await c3.read_value()
|
||||
print(f"@@@ C3 {c3} value = {c3_value} (type={type(c3_value)})")
|
||||
await c3.write_value((2001, 2002, 2003))
|
||||
|
||||
# Static characteristic with a named tuple value.
|
||||
c4 = gatt.MappedCharacteristicAdapter(
|
||||
characteristics[3], ">III", ["f1", "f2", "f3"]
|
||||
)
|
||||
c4_value = await c4.read_value()
|
||||
print(f"@@@ C4 {c4} value = {c4_value} (type={type(c4_value)})")
|
||||
await c4.write_value({"f1": 4001, "f2": 4002, "f3": 4003})
|
||||
|
||||
# Static characteristic with a serializable value.
|
||||
c5 = gatt.SerializableCharacteristicAdapter(
|
||||
characteristics[4], CustomSerializableClass
|
||||
)
|
||||
c5_value = await c5.read_value()
|
||||
print(f"@@@ C5 {c5} value = {c5_value} (type={type(c5_value)})")
|
||||
await c5.write_value(CustomSerializableClass(56, 57))
|
||||
|
||||
# Static characteristic with a delegated value.
|
||||
c6 = gatt.DelegatedCharacteristicAdapter(
|
||||
characteristics[5], encode=CustomClass.encode, decode=CustomClass.decode
|
||||
)
|
||||
c6_value = await c6.read_value()
|
||||
print(f"@@@ C6 {c6} value = {c6_value} (type={type(c6_value)})")
|
||||
await c6.write_value(CustomClass(6, 7))
|
||||
|
||||
# Dynamic characteristic with a bytes value.
|
||||
c7 = characteristics[6]
|
||||
c7_value = await c7.read_value()
|
||||
print(f"@@@ C7 {c7} value = {c7_value!r} (type={type(c7_value)})")
|
||||
await c7.write_value(bytes.fromhex("01020304"))
|
||||
|
||||
# Dynamic characteristic with a string value.
|
||||
c8 = gatt.UTF8CharacteristicAdapter(characteristics[7])
|
||||
c8_value = await c8.read_value()
|
||||
print(f"@@@ C8 {c8} value = {c8_value} (type={type(c8_value)})")
|
||||
await c8.write_value("howdy")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def dynamic_read(selector: str) -> Union[bytes, str]:
|
||||
if selector == "bytes":
|
||||
print("$$$ Returning random bytes")
|
||||
return random.randbytes(7)
|
||||
elif selector == "string":
|
||||
print("$$$ Returning random string")
|
||||
return random.randbytes(7).hex()
|
||||
|
||||
raise ValueError("invalid selector")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def dynamic_write(selector: str, value: Any) -> None:
|
||||
print(f"$$$ Received[{selector}]: {value} (type={type(value)})")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def on_characteristic_read(characteristic: gatt.Characteristic, value: Any) -> None:
|
||||
"""Event listener invoked when a characteristic is read."""
|
||||
print(f"<<< READ: {characteristic} -> {value} ({type(value)})")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def on_characteristic_write(characteristic: gatt.Characteristic, value: Any) -> None:
|
||||
"""Event listener invoked when a characteristic is written."""
|
||||
print(f"<<< WRITE: {characteristic} <- {value} ({type(value)})")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def main() -> None:
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: run_gatt_with_adapters.py <transport-spec> [<bluetooth-address>]")
|
||||
print("example: run_gatt_with_adapters.py usb:0 E1:CA:72:48:C4:E8")
|
||||
return
|
||||
|
||||
async with await transport.open_transport(sys.argv[1]) as hci_transport:
|
||||
# Create a device to manage the host
|
||||
device = Device.with_hci(
|
||||
"Bumble",
|
||||
hci.Address("F0:F1:F2:F3:F4:F5"),
|
||||
hci_transport.source,
|
||||
hci_transport.sink,
|
||||
)
|
||||
|
||||
# Static characteristic with a bytes value.
|
||||
c1 = gatt.Characteristic(
|
||||
CHARACTERISTIC_UUID_BASE + "01",
|
||||
gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.WRITE,
|
||||
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
|
||||
b'hello',
|
||||
)
|
||||
|
||||
# Static characteristic with a string value.
|
||||
c2 = gatt.UTF8CharacteristicAdapter(
|
||||
gatt.Characteristic(
|
||||
CHARACTERISTIC_UUID_BASE + "02",
|
||||
gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.WRITE,
|
||||
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
|
||||
'hello',
|
||||
)
|
||||
)
|
||||
|
||||
# Static characteristic with a tuple value.
|
||||
c3 = gatt.PackedCharacteristicAdapter(
|
||||
gatt.Characteristic(
|
||||
CHARACTERISTIC_UUID_BASE + "03",
|
||||
gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.WRITE,
|
||||
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
|
||||
(1007, 1008, 1009),
|
||||
),
|
||||
">III",
|
||||
)
|
||||
|
||||
# Static characteristic with a named tuple value.
|
||||
c4 = gatt.MappedCharacteristicAdapter(
|
||||
gatt.Characteristic(
|
||||
CHARACTERISTIC_UUID_BASE + "04",
|
||||
gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.WRITE,
|
||||
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
|
||||
{"f1": 3007, "f2": 3008, "f3": 3009},
|
||||
),
|
||||
">III",
|
||||
["f1", "f2", "f3"],
|
||||
)
|
||||
|
||||
# Static characteristic with a serializable value.
|
||||
c5 = gatt.SerializableCharacteristicAdapter(
|
||||
gatt.Characteristic(
|
||||
CHARACTERISTIC_UUID_BASE + "05",
|
||||
gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.WRITE,
|
||||
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
|
||||
CustomSerializableClass(11, 12),
|
||||
),
|
||||
CustomSerializableClass,
|
||||
)
|
||||
|
||||
# Static characteristic with a delegated value.
|
||||
c6 = gatt.DelegatedCharacteristicAdapter(
|
||||
gatt.Characteristic(
|
||||
CHARACTERISTIC_UUID_BASE + "06",
|
||||
gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.WRITE,
|
||||
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
|
||||
CustomClass(1, 2),
|
||||
),
|
||||
encode=CustomClass.encode,
|
||||
decode=CustomClass.decode,
|
||||
)
|
||||
|
||||
# Dynamic characteristic with a bytes value.
|
||||
c7 = gatt.Characteristic(
|
||||
CHARACTERISTIC_UUID_BASE + "07",
|
||||
gatt.Characteristic.Properties.READ | gatt.Characteristic.Properties.WRITE,
|
||||
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
|
||||
gatt.CharacteristicValue(
|
||||
read=lambda connection: dynamic_read("bytes"),
|
||||
write=lambda connection, value: dynamic_write("bytes", value),
|
||||
),
|
||||
)
|
||||
|
||||
# Dynamic characteristic with a string value.
|
||||
c8 = gatt.UTF8CharacteristicAdapter(
|
||||
gatt.Characteristic(
|
||||
CHARACTERISTIC_UUID_BASE + "08",
|
||||
gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.WRITE,
|
||||
gatt.Characteristic.READABLE | gatt.Characteristic.WRITEABLE,
|
||||
gatt.CharacteristicValue(
|
||||
read=lambda connection: dynamic_read("string"),
|
||||
write=lambda connection, value: dynamic_write("string", value),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
characteristics: List[
|
||||
Union[gatt.Characteristic, gatt.CharacteristicAdapter]
|
||||
] = [c1, c2, c3, c4, c5, c6, c7, c8]
|
||||
|
||||
# Listen for read and write events.
|
||||
for characteristic in characteristics:
|
||||
characteristic.on(
|
||||
"read",
|
||||
lambda _, value, c=characteristic: on_characteristic_read(c, value),
|
||||
)
|
||||
characteristic.on(
|
||||
"write",
|
||||
lambda _, value, c=characteristic: on_characteristic_write(c, value),
|
||||
)
|
||||
|
||||
device.add_service(gatt.Service(SERVICE_UUID, characteristics)) # type: ignore
|
||||
|
||||
# Get things going
|
||||
await device.power_on()
|
||||
|
||||
# Connect to a peer
|
||||
if len(sys.argv) > 2:
|
||||
await client(device, hci.Address(sys.argv[2]))
|
||||
else:
|
||||
await device.start_advertising(auto_restart=True)
|
||||
|
||||
await hci_transport.source.wait_for_termination()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
asyncio.run(main())
|
||||
@@ -21,9 +21,9 @@ import sys
|
||||
import os
|
||||
import io
|
||||
import logging
|
||||
import websockets
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from typing import Optional
|
||||
import websockets
|
||||
|
||||
import bumble.core
|
||||
from bumble.device import Device, ScoLink
|
||||
@@ -82,6 +82,10 @@ def on_microphone_volume(level: int):
|
||||
send_message(type='microphone_volume', level=level)
|
||||
|
||||
|
||||
def on_supported_audio_codecs(codecs: Iterable[hfp.AudioCodec]):
|
||||
send_message(type='supported_audio_codecs', codecs=[codec.name for codec in codecs])
|
||||
|
||||
|
||||
def on_sco_state_change(codec: int):
|
||||
if codec == hfp.AudioCodec.CVSD:
|
||||
sample_rate = 8000
|
||||
@@ -207,6 +211,7 @@ async def main() -> None:
|
||||
ag_protocol = hfp.AgProtocol(dlc, configuration)
|
||||
ag_protocol.on('speaker_volume', on_speaker_volume)
|
||||
ag_protocol.on('microphone_volume', on_microphone_volume)
|
||||
ag_protocol.on('supported_audio_codecs', on_supported_audio_codecs)
|
||||
on_hfp_state_change(True)
|
||||
dlc.multiplexer.l2cap_channel.on(
|
||||
'close', lambda: on_hfp_state_change(False)
|
||||
@@ -241,7 +246,7 @@ async def main() -> None:
|
||||
# Pick the first one
|
||||
channel, version, hf_sdp_features = hfp_record
|
||||
print(f'HF version: {version}')
|
||||
print(f'HF features: {hf_sdp_features}')
|
||||
print(f'HF features: {hf_sdp_features.name}')
|
||||
|
||||
# Request authentication
|
||||
print('*** Authenticating...')
|
||||
|
||||
@@ -57,6 +57,9 @@ def on_dlc(dlc: rfcomm.DLC, configuration: hfp.HfConfiguration):
|
||||
esco_parameters = hfp.ESCO_PARAMETERS[
|
||||
hfp.DefaultCodecParameters.ESCO_CVSD_S4
|
||||
]
|
||||
else:
|
||||
raise RuntimeError("unknown active codec")
|
||||
|
||||
connection.abort_on(
|
||||
'disconnection',
|
||||
connection.device.send_command(
|
||||
|
||||
@@ -161,7 +161,13 @@ async def main() -> None:
|
||||
else:
|
||||
file_output = open(f'{datetime.datetime.now().isoformat()}.lc3', 'wb')
|
||||
codec_configuration = ase.codec_specific_configuration
|
||||
assert isinstance(codec_configuration, CodecSpecificConfiguration)
|
||||
if (
|
||||
not isinstance(codec_configuration, CodecSpecificConfiguration)
|
||||
or codec_configuration.sampling_frequency is None
|
||||
or codec_configuration.audio_channel_allocation is None
|
||||
or codec_configuration.frame_duration is None
|
||||
):
|
||||
return
|
||||
# Write a LC3 header.
|
||||
file_output.write(
|
||||
bytes([0x1C, 0xCC]) # Header.
|
||||
|
||||
@@ -42,7 +42,7 @@ from bumble.profiles.bap import (
|
||||
from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService
|
||||
from bumble.profiles.cap import CommonAudioServiceService
|
||||
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
|
||||
from bumble.profiles.vcp import VolumeControlService
|
||||
from bumble.profiles.vcs import VolumeControlService
|
||||
|
||||
from bumble.transport import open_transport_or_link
|
||||
|
||||
@@ -117,13 +117,17 @@ async def main() -> None:
|
||||
|
||||
ws: Optional[websockets.WebSocketServerProtocol] = None
|
||||
|
||||
def on_volume_state(volume_setting: int, muted: int, change_counter: int):
|
||||
def on_volume_state_change():
|
||||
if ws:
|
||||
asyncio.create_task(
|
||||
ws.send(dumps_volume_state(volume_setting, muted, change_counter))
|
||||
ws.send(
|
||||
dumps_volume_state(
|
||||
vcs.volume_setting, vcs.muted, vcs.change_counter
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
vcs.on('volume_state', on_volume_state)
|
||||
vcs.on('volume_state_change', on_volume_state_change)
|
||||
|
||||
advertising_data = (
|
||||
bytes(
|
||||
@@ -170,16 +174,10 @@ async def main() -> None:
|
||||
ws = websocket
|
||||
async for message in websocket:
|
||||
volume_state = json.loads(message)
|
||||
vcs.volume_state_bytes = bytes(
|
||||
[
|
||||
volume_state['volume_setting'],
|
||||
volume_state['muted'],
|
||||
volume_state['change_counter'],
|
||||
]
|
||||
)
|
||||
await device.notify_subscribers(
|
||||
vcs.volume_state, vcs.volume_state_bytes
|
||||
)
|
||||
vcs.volume_setting = volume_state['volume_setting']
|
||||
vcs.muted = volume_state['muted']
|
||||
vcs.change_counter = volume_state['change_counter']
|
||||
await device.notify_subscribers(vcs.volume_state)
|
||||
ws = None
|
||||
|
||||
await websockets.serve(serve, 'localhost', 8989)
|
||||
|
||||
@@ -10,7 +10,7 @@ android {
|
||||
|
||||
defaultConfig {
|
||||
applicationId = "com.github.google.bumble.btbench"
|
||||
minSdk = 30
|
||||
minSdk = 33
|
||||
targetSdk = 34
|
||||
versionCode = 1
|
||||
versionName = "1.0"
|
||||
@@ -60,6 +60,8 @@ dependencies {
|
||||
implementation(libs.ui.graphics)
|
||||
implementation(libs.ui.tooling.preview)
|
||||
implementation(libs.material3)
|
||||
implementation(libs.mobly.snippet)
|
||||
implementation(libs.androidx.core)
|
||||
testImplementation(libs.junit)
|
||||
androidTestImplementation(libs.androidx.test.ext.junit)
|
||||
androidTestImplementation(libs.espresso.core)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
package="com.github.google.bumble.btbench">
|
||||
<uses-sdk android:minSdkVersion="30" android:targetSdkVersion="34" />
|
||||
<uses-sdk android:minSdkVersion="33" android:targetSdkVersion="34" />
|
||||
<!-- Request legacy Bluetooth permissions on older devices. -->
|
||||
<uses-permission android:name="android.permission.BLUETOOTH" android:maxSdkVersion="30" />
|
||||
<uses-permission android:name="android.permission.BLUETOOTH_ADMIN" android:maxSdkVersion="30" />
|
||||
@@ -9,6 +9,8 @@
|
||||
<uses-permission android:name="android.permission.BLUETOOTH_SCAN" android:usesPermissionFlags="neverForLocation"/>
|
||||
<uses-permission android:name="android.permission.BLUETOOTH_ADVERTISE" />
|
||||
<uses-permission android:name="android.permission.BLUETOOTH_CONNECT" />
|
||||
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
|
||||
<uses-permission android:name="android.permission.INTERNET" />
|
||||
|
||||
<uses-feature android:name="android.hardware.bluetooth" android:required="true"/>
|
||||
<uses-feature android:name="android.hardware.bluetooth_le" android:required="true"/>
|
||||
@@ -23,6 +25,9 @@
|
||||
android:supportsRtl="true"
|
||||
android:theme="@style/Theme.BTBench"
|
||||
>
|
||||
<meta-data
|
||||
android:name="mobly-snippets"
|
||||
android:value="com.github.google.bumble.btbench.AutomationSnippet"/>
|
||||
<activity
|
||||
android:name=".MainActivity"
|
||||
android:exported="true"
|
||||
@@ -35,5 +40,7 @@
|
||||
</activity>
|
||||
<!-- <profileable android:shell="true"/>-->
|
||||
</application>
|
||||
|
||||
</manifest>
|
||||
<instrumentation
|
||||
android:name="com.google.android.mobly.snippet.SnippetRunner"
|
||||
android:targetPackage="com.github.google.bumble.btbench" />
|
||||
</manifest>
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
package com.github.google.bumble.btbench
|
||||
|
||||
import android.annotation.SuppressLint
|
||||
import android.bluetooth.BluetoothAdapter
|
||||
import android.bluetooth.le.AdvertiseCallback
|
||||
import android.bluetooth.le.AdvertiseData
|
||||
import android.bluetooth.le.AdvertiseSettings
|
||||
import android.bluetooth.le.AdvertiseSettings.ADVERTISE_MODE_LOW_LATENCY
|
||||
import android.os.Build
|
||||
import java.util.logging.Logger
|
||||
|
||||
private val Log = Logger.getLogger("btbench.advertiser")
|
||||
|
||||
class Advertiser(private val bluetoothAdapter: BluetoothAdapter) : AdvertiseCallback() {
|
||||
@SuppressLint("MissingPermission")
|
||||
fun start() {
|
||||
val advertiseSettingsBuilder = AdvertiseSettings.Builder()
|
||||
.setAdvertiseMode(ADVERTISE_MODE_LOW_LATENCY)
|
||||
.setConnectable(true)
|
||||
advertiseSettingsBuilder.setDiscoverable(true)
|
||||
val advertiseSettings = advertiseSettingsBuilder.build()
|
||||
val advertiseData = AdvertiseData.Builder().build()
|
||||
val scanData = AdvertiseData.Builder().setIncludeDeviceName(true).build()
|
||||
bluetoothAdapter.bluetoothLeAdvertiser.startAdvertising(advertiseSettings, advertiseData, scanData, this)
|
||||
}
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
fun stop() {
|
||||
bluetoothAdapter.bluetoothLeAdvertiser.stopAdvertising(this)
|
||||
}
|
||||
|
||||
override fun onStartFailure(errorCode: Int) {
|
||||
Log.warning("failed to start advertising: $errorCode")
|
||||
}
|
||||
|
||||
override fun onStartSuccess(settingsInEffect: AdvertiseSettings) {
|
||||
Log.info("advertising started: $settingsInEffect")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,383 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.github.google.bumble.btbench;
|
||||
|
||||
import android.bluetooth.BluetoothAdapter;
|
||||
import android.bluetooth.BluetoothManager;
|
||||
import android.content.Context;
|
||||
|
||||
import androidx.test.core.app.ApplicationProvider;
|
||||
|
||||
import com.google.android.mobly.snippet.Snippet;
|
||||
import com.google.android.mobly.snippet.rpc.Rpc;
|
||||
import com.google.android.mobly.snippet.rpc.RpcOptional;
|
||||
|
||||
import org.json.JSONArray;
|
||||
import org.json.JSONException;
|
||||
import org.json.JSONObject;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.security.InvalidParameterException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.UUID;
|
||||
|
||||
class Runner {
|
||||
public UUID mId;
|
||||
private final Mode mMode;
|
||||
private final String mModeName;
|
||||
private final String mScenario;
|
||||
private final AppViewModel mModel;
|
||||
|
||||
Runner(Mode mode, String modeName, String scenario, AppViewModel model) {
|
||||
this.mId = UUID.randomUUID();
|
||||
this.mMode = mode;
|
||||
this.mModeName = modeName;
|
||||
this.mScenario = scenario;
|
||||
this.mModel = model;
|
||||
}
|
||||
|
||||
public JSONObject toJson() throws JSONException {
|
||||
JSONObject result = new JSONObject();
|
||||
result.put("id", mId.toString());
|
||||
result.put("mode", mModeName);
|
||||
result.put("scenario", mScenario);
|
||||
result.put("model", AutomationSnippet.modelToJson(mModel));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
public void stop() {
|
||||
mModel.abort();
|
||||
}
|
||||
|
||||
public void waitForCompletion() {
|
||||
mMode.waitForCompletion();
|
||||
}
|
||||
}
|
||||
|
||||
public class AutomationSnippet implements Snippet {
|
||||
private static final String TAG = "btbench.snippet";
|
||||
private final BluetoothAdapter mBluetoothAdapter;
|
||||
private final Context mContext;
|
||||
private final ArrayList<Runner> mRunners = new ArrayList<>();
|
||||
|
||||
public AutomationSnippet() throws IOException {
|
||||
mContext = ApplicationProvider.getApplicationContext();
|
||||
BluetoothManager bluetoothManager = mContext.getSystemService(BluetoothManager.class);
|
||||
mBluetoothAdapter = bluetoothManager.getAdapter();
|
||||
if (mBluetoothAdapter == null) {
|
||||
throw new IOException("bluetooth not supported");
|
||||
}
|
||||
if (!mBluetoothAdapter.isEnabled()) {
|
||||
throw new IOException("bluetooth not enabled");
|
||||
}
|
||||
}
|
||||
|
||||
private Runner runScenario(AppViewModel model, String mode, String scenario) {
|
||||
Mode runnable;
|
||||
switch (mode) {
|
||||
case "rfcomm-client":
|
||||
runnable = new RfcommClient(model, mBluetoothAdapter,
|
||||
(PacketIO packetIO) -> createIoClient(model, scenario,
|
||||
packetIO));
|
||||
break;
|
||||
|
||||
case "rfcomm-server":
|
||||
runnable = new RfcommServer(model, mBluetoothAdapter,
|
||||
(PacketIO packetIO) -> createIoClient(model, scenario,
|
||||
packetIO));
|
||||
break;
|
||||
|
||||
case "l2cap-client":
|
||||
runnable = new L2capClient(model, mBluetoothAdapter, mContext,
|
||||
(PacketIO packetIO) -> createIoClient(model, scenario,
|
||||
packetIO));
|
||||
break;
|
||||
|
||||
case "l2cap-server":
|
||||
runnable = new L2capServer(model, mBluetoothAdapter,
|
||||
(PacketIO packetIO) -> createIoClient(model, scenario,
|
||||
packetIO));
|
||||
break;
|
||||
|
||||
case "gatt-client":
|
||||
runnable = new GattClient(model, mBluetoothAdapter, mContext,
|
||||
(PacketIO packetIO) -> createIoClient(model, scenario,
|
||||
packetIO));
|
||||
break;
|
||||
|
||||
case "gatt-server":
|
||||
runnable = new GattServer(model, mBluetoothAdapter, mContext,
|
||||
(PacketIO packetIO) -> createIoClient(model, scenario,
|
||||
packetIO));
|
||||
break;
|
||||
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
|
||||
model.setMode(mode);
|
||||
model.setScenario(scenario);
|
||||
runnable.run();
|
||||
Runner runner = new Runner(runnable, mode, scenario, model);
|
||||
mRunners.add(runner);
|
||||
return runner;
|
||||
}
|
||||
|
||||
private IoClient createIoClient(AppViewModel model, String scenario, PacketIO packetIO) {
|
||||
switch (scenario) {
|
||||
case "send":
|
||||
return new Sender(model, packetIO);
|
||||
|
||||
case "receive":
|
||||
return new Receiver(model, packetIO);
|
||||
|
||||
case "ping":
|
||||
return new Pinger(model, packetIO);
|
||||
|
||||
case "pong":
|
||||
return new Ponger(model, packetIO);
|
||||
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public static JSONObject modelToJson(AppViewModel model) throws JSONException {
|
||||
JSONObject result = new JSONObject();
|
||||
result.put("status", model.getStatus());
|
||||
result.put("running", model.getRunning());
|
||||
result.put("peer_bluetooth_address", model.getPeerBluetoothAddress());
|
||||
result.put("mode", model.getMode());
|
||||
result.put("scenario", model.getScenario());
|
||||
result.put("sender_packet_size", model.getSenderPacketSize());
|
||||
result.put("sender_packet_count", model.getSenderPacketCount());
|
||||
result.put("sender_packet_interval", model.getSenderPacketInterval());
|
||||
result.put("packets_sent", model.getPacketsSent());
|
||||
result.put("packets_received", model.getPacketsReceived());
|
||||
result.put("l2cap_psm", model.getL2capPsm());
|
||||
result.put("use_2m_phy", model.getUse2mPhy());
|
||||
result.put("connection_priority", model.getConnectionPriority());
|
||||
result.put("mtu", model.getMtu());
|
||||
result.put("rx_phy", model.getRxPhy());
|
||||
result.put("tx_phy", model.getTxPhy());
|
||||
result.put("startup_delay", model.getStartupDelay());
|
||||
if (model.getStatus().equals("OK")) {
|
||||
JSONObject stats = new JSONObject();
|
||||
result.put("stats", stats);
|
||||
stats.put("throughput", model.getThroughput());
|
||||
JSONObject rttStats = new JSONObject();
|
||||
stats.put("rtt", rttStats);
|
||||
rttStats.put("compound", model.getStats());
|
||||
} else {
|
||||
result.put("last_error", model.getLastError());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private Runner findRunner(String runnerId) {
|
||||
for (Runner runner : mRunners) {
|
||||
if (runner.mId.toString().equals(runnerId)) {
|
||||
return runner;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
@Rpc(description = "Run a scenario in RFComm Client mode")
|
||||
public JSONObject runRfcommClient(String scenario, String peerBluetoothAddress, int packetCount,
|
||||
int packetSize, int packetInterval,
|
||||
@RpcOptional Integer startupDelay) throws JSONException {
|
||||
// We only support "send" and "ping" for this mode for now
|
||||
if (!(scenario.equals("send") || scenario.equals("ping"))) {
|
||||
throw new InvalidParameterException(
|
||||
"only 'send' and 'ping' are supported for this mode");
|
||||
}
|
||||
|
||||
AppViewModel model = new AppViewModel();
|
||||
model.setPeerBluetoothAddress(peerBluetoothAddress);
|
||||
model.setSenderPacketCount(packetCount);
|
||||
model.setSenderPacketSize(packetSize);
|
||||
model.setSenderPacketInterval(packetInterval);
|
||||
if (startupDelay != null) {
|
||||
model.setStartupDelay(startupDelay);
|
||||
}
|
||||
|
||||
Runner runner = runScenario(model, "rfcomm-client", scenario);
|
||||
assert runner != null;
|
||||
return runner.toJson();
|
||||
}
|
||||
|
||||
@Rpc(description = "Run a scenario in RFComm Server mode")
|
||||
public JSONObject runRfcommServer(String scenario,
|
||||
@RpcOptional Integer startupDelay) throws JSONException {
|
||||
// We only support "receive" and "pong" for this mode for now
|
||||
if (!(scenario.equals("receive") || scenario.equals("pong"))) {
|
||||
throw new InvalidParameterException(
|
||||
"only 'receive' and 'pong' are supported for this mode");
|
||||
}
|
||||
|
||||
AppViewModel model = new AppViewModel();
|
||||
if (startupDelay != null) {
|
||||
model.setStartupDelay(startupDelay);
|
||||
}
|
||||
|
||||
Runner runner = runScenario(model, "rfcomm-server", scenario);
|
||||
assert runner != null;
|
||||
return runner.toJson();
|
||||
}
|
||||
|
||||
@Rpc(description = "Run a scenario in L2CAP Client mode")
|
||||
public JSONObject runL2capClient(String scenario, String peerBluetoothAddress, int psm,
|
||||
boolean use_2m_phy, int packetCount, int packetSize,
|
||||
int packetInterval, @RpcOptional String connectionPriority,
|
||||
@RpcOptional Integer startupDelay) throws JSONException {
|
||||
// We only support "send" and "ping" for this mode for now
|
||||
if (!(scenario.equals("send") || scenario.equals("ping"))) {
|
||||
throw new InvalidParameterException(
|
||||
"only 'send' and 'ping' are supported for this mode");
|
||||
}
|
||||
|
||||
AppViewModel model = new AppViewModel();
|
||||
model.setPeerBluetoothAddress(peerBluetoothAddress);
|
||||
model.setL2capPsm(psm);
|
||||
model.setUse2mPhy(use_2m_phy);
|
||||
model.setSenderPacketCount(packetCount);
|
||||
model.setSenderPacketSize(packetSize);
|
||||
model.setSenderPacketInterval(packetInterval);
|
||||
if (connectionPriority != null) {
|
||||
model.setConnectionPriority(connectionPriority);
|
||||
}
|
||||
if (startupDelay != null) {
|
||||
model.setStartupDelay(startupDelay);
|
||||
}
|
||||
Runner runner = runScenario(model, "l2cap-client", scenario);
|
||||
assert runner != null;
|
||||
return runner.toJson();
|
||||
}
|
||||
|
||||
@Rpc(description = "Run a scenario in L2CAP Server mode")
|
||||
public JSONObject runL2capServer(String scenario,
|
||||
@RpcOptional Integer startupDelay) throws JSONException {
|
||||
// We only support "receive" and "pong" for this mode for now
|
||||
if (!(scenario.equals("receive") || scenario.equals("pong"))) {
|
||||
throw new InvalidParameterException(
|
||||
"only 'receive' and 'pong' are supported for this mode");
|
||||
}
|
||||
|
||||
AppViewModel model = new AppViewModel();
|
||||
if (startupDelay != null) {
|
||||
model.setStartupDelay(startupDelay);
|
||||
}
|
||||
|
||||
Runner runner = runScenario(model, "l2cap-server", scenario);
|
||||
assert runner != null;
|
||||
return runner.toJson();
|
||||
}
|
||||
|
||||
@Rpc(description = "Run a scenario in GATT Client mode")
|
||||
public JSONObject runGattClient(String scenario, String peerBluetoothAddress,
|
||||
boolean use_2m_phy, int packetCount, int packetSize,
|
||||
int packetInterval, @RpcOptional String connectionPriority,
|
||||
@RpcOptional Integer startupDelay) throws JSONException {
|
||||
// We only support "send" and "ping" for this mode for now
|
||||
if (!(scenario.equals("send") || scenario.equals("ping"))) {
|
||||
throw new InvalidParameterException(
|
||||
"only 'send' and 'ping' are supported for this mode");
|
||||
}
|
||||
|
||||
AppViewModel model = new AppViewModel();
|
||||
model.setPeerBluetoothAddress(peerBluetoothAddress);
|
||||
model.setUse2mPhy(use_2m_phy);
|
||||
model.setSenderPacketCount(packetCount);
|
||||
model.setSenderPacketSize(packetSize);
|
||||
model.setSenderPacketInterval(packetInterval);
|
||||
if (connectionPriority != null) {
|
||||
model.setConnectionPriority(connectionPriority);
|
||||
}
|
||||
if (startupDelay != null) {
|
||||
model.setStartupDelay(startupDelay);
|
||||
}
|
||||
Runner runner = runScenario(model, "gatt-client", scenario);
|
||||
assert runner != null;
|
||||
return runner.toJson();
|
||||
}
|
||||
|
||||
@Rpc(description = "Run a scenario in GATT Server mode")
|
||||
public JSONObject runGattServer(String scenario,
|
||||
@RpcOptional Integer startupDelay) throws JSONException {
|
||||
// We only support "receive" and "pong" for this mode for now
|
||||
if (!(scenario.equals("receive") || scenario.equals("pong"))) {
|
||||
throw new InvalidParameterException(
|
||||
"only 'receive' and 'pong' are supported for this mode");
|
||||
}
|
||||
|
||||
AppViewModel model = new AppViewModel();
|
||||
if (startupDelay != null) {
|
||||
model.setStartupDelay(startupDelay);
|
||||
}
|
||||
|
||||
Runner runner = runScenario(model, "gatt-server", scenario);
|
||||
assert runner != null;
|
||||
return runner.toJson();
|
||||
}
|
||||
|
||||
@Rpc(description = "Stop a Runner")
|
||||
public JSONObject stopRunner(String runnerId) throws JSONException {
|
||||
Runner runner = findRunner(runnerId);
|
||||
if (runner == null) {
|
||||
return new JSONObject();
|
||||
}
|
||||
runner.stop();
|
||||
return runner.toJson();
|
||||
}
|
||||
|
||||
@Rpc(description = "Wait for a Runner to complete")
|
||||
public JSONObject waitForRunnerCompletion(String runnerId) throws JSONException {
|
||||
Runner runner = findRunner(runnerId);
|
||||
if (runner == null) {
|
||||
return new JSONObject();
|
||||
}
|
||||
runner.waitForCompletion();
|
||||
return runner.toJson();
|
||||
}
|
||||
|
||||
@Rpc(description = "Get a Runner by ID")
|
||||
public JSONObject getRunner(String runnerId) throws JSONException {
|
||||
Runner runner = findRunner(runnerId);
|
||||
if (runner == null) {
|
||||
return new JSONObject();
|
||||
}
|
||||
return runner.toJson();
|
||||
}
|
||||
|
||||
@Rpc(description = "Get all Runners")
|
||||
public JSONObject getRunners() throws JSONException {
|
||||
JSONObject result = new JSONObject();
|
||||
JSONArray runners = new JSONArray();
|
||||
result.put("runners", runners);
|
||||
for (Runner runner : mRunners) {
|
||||
runners.put(runner.toJson());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void shutdown() {
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package com.github.google.bumble.btbench
|
||||
|
||||
import android.annotation.SuppressLint
|
||||
import android.bluetooth.BluetoothAdapter
|
||||
import android.bluetooth.BluetoothDevice
|
||||
import android.bluetooth.BluetoothGatt
|
||||
import android.bluetooth.BluetoothGattCallback
|
||||
import android.bluetooth.BluetoothManager
|
||||
import android.bluetooth.BluetoothProfile
|
||||
import android.content.Context
|
||||
import android.os.Build
|
||||
import androidx.core.content.ContextCompat
|
||||
import java.util.logging.Logger
|
||||
|
||||
private val Log = Logger.getLogger("btbench.connection")
|
||||
|
||||
open class Connection(
|
||||
private val viewModel: AppViewModel,
|
||||
private val bluetoothAdapter: BluetoothAdapter,
|
||||
private val context: Context
|
||||
) : BluetoothGattCallback() {
|
||||
var remoteDevice: BluetoothDevice? = null
|
||||
var gatt: BluetoothGatt? = null
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
open fun connect() {
|
||||
val addressIsPublic = viewModel.peerBluetoothAddress.endsWith("/P")
|
||||
val address = viewModel.peerBluetoothAddress.take(17)
|
||||
remoteDevice = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
|
||||
bluetoothAdapter.getRemoteLeDevice(
|
||||
address,
|
||||
if (addressIsPublic) {
|
||||
BluetoothDevice.ADDRESS_TYPE_PUBLIC
|
||||
} else {
|
||||
BluetoothDevice.ADDRESS_TYPE_RANDOM
|
||||
}
|
||||
)
|
||||
} else {
|
||||
bluetoothAdapter.getRemoteDevice(address)
|
||||
}
|
||||
|
||||
gatt = remoteDevice?.connectGatt(
|
||||
context,
|
||||
false,
|
||||
this,
|
||||
BluetoothDevice.TRANSPORT_LE,
|
||||
if (viewModel.use2mPhy) BluetoothDevice.PHY_LE_2M_MASK else BluetoothDevice.PHY_LE_1M_MASK
|
||||
)
|
||||
}
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
open fun disconnect() {
|
||||
gatt?.disconnect()
|
||||
}
|
||||
|
||||
override fun onMtuChanged(gatt: BluetoothGatt, mtu: Int, status: Int) {
|
||||
Log.info("MTU update: mtu=$mtu status=$status")
|
||||
viewModel.mtu = mtu
|
||||
}
|
||||
|
||||
override fun onPhyUpdate(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) {
|
||||
Log.info("PHY update: tx=$txPhy, rx=$rxPhy, status=$status")
|
||||
viewModel.txPhy = txPhy
|
||||
viewModel.rxPhy = rxPhy
|
||||
}
|
||||
|
||||
override fun onPhyRead(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) {
|
||||
Log.info("PHY: tx=$txPhy, rx=$rxPhy, status=$status")
|
||||
viewModel.txPhy = txPhy
|
||||
viewModel.rxPhy = rxPhy
|
||||
}
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
override fun onConnectionStateChange(
|
||||
gatt: BluetoothGatt?, status: Int, newState: Int
|
||||
) {
|
||||
if (status != BluetoothGatt.GATT_SUCCESS) {
|
||||
Log.warning("onConnectionStateChange status=$status")
|
||||
}
|
||||
|
||||
if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) {
|
||||
if (viewModel.use2mPhy) {
|
||||
Log.info("requesting 2M PHY")
|
||||
gatt.setPreferredPhy(
|
||||
BluetoothDevice.PHY_LE_2M_MASK,
|
||||
BluetoothDevice.PHY_LE_2M_MASK,
|
||||
BluetoothDevice.PHY_OPTION_NO_PREFERRED
|
||||
)
|
||||
}
|
||||
gatt.readPhy()
|
||||
|
||||
// Request an MTU update, even though we don't use GATT, because Android
|
||||
// won't request a larger link layer maximum data length otherwise.
|
||||
gatt.requestMtu(517)
|
||||
|
||||
// Request a specific connection priority
|
||||
val connectionPriority = when (viewModel.connectionPriority) {
|
||||
"BALANCED" -> BluetoothGatt.CONNECTION_PRIORITY_BALANCED
|
||||
"LOW_POWER" -> BluetoothGatt.CONNECTION_PRIORITY_LOW_POWER
|
||||
"HIGH" -> BluetoothGatt.CONNECTION_PRIORITY_HIGH
|
||||
"DCK" -> BluetoothGatt.CONNECTION_PRIORITY_DCK
|
||||
else -> 0
|
||||
}
|
||||
if (!gatt.requestConnectionPriority(connectionPriority)) {
|
||||
Log.warning("requestConnectionPriority failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.github.google.bumble.btbench
|
||||
|
||||
import java.util.UUID
|
||||
|
||||
var CCCD_UUID = UUID.fromString("00002902-0000-1000-8000-00805F9B34FB")
|
||||
|
||||
val BENCH_SERVICE_UUID = UUID.fromString("50DB505C-8AC4-4738-8448-3B1D9CC09CC5")
|
||||
val BENCH_TX_UUID = UUID.fromString("E789C754-41A1-45F4-A948-A0A1A90DBA53")
|
||||
val BENCH_RX_UUID = UUID.fromString("016A2CC7-E14B-4819-935F-1F56EAE4098D")
|
||||
@@ -0,0 +1,224 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.github.google.bumble.btbench
|
||||
|
||||
import android.annotation.SuppressLint
|
||||
import android.bluetooth.BluetoothAdapter
|
||||
import android.bluetooth.BluetoothGatt
|
||||
import android.bluetooth.BluetoothGattCharacteristic
|
||||
import android.bluetooth.BluetoothGattDescriptor
|
||||
import android.bluetooth.BluetoothProfile
|
||||
import android.content.Context
|
||||
import java.io.IOException
|
||||
import java.util.UUID
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.concurrent.Semaphore
|
||||
import java.util.logging.Logger
|
||||
import kotlin.concurrent.thread
|
||||
|
||||
private val Log = Logger.getLogger("btbench.gatt-client")
|
||||
|
||||
|
||||
class GattClientConnection(
|
||||
viewModel: AppViewModel,
|
||||
bluetoothAdapter: BluetoothAdapter,
|
||||
context: Context
|
||||
) : Connection(viewModel, bluetoothAdapter, context), PacketIO {
|
||||
override var packetSink: PacketSink? = null
|
||||
private val discoveryDone: CountDownLatch = CountDownLatch(1)
|
||||
private val writeSemaphore: Semaphore = Semaphore(1)
|
||||
var rxCharacteristic: BluetoothGattCharacteristic? = null
|
||||
var txCharacteristic: BluetoothGattCharacteristic? = null
|
||||
|
||||
override fun connect() {
|
||||
super.connect()
|
||||
|
||||
// Check if we're already connected and have discovered the services
|
||||
if (gatt?.getService(BENCH_SERVICE_UUID) != null) {
|
||||
Log.fine("already connected")
|
||||
onServicesDiscovered(gatt, BluetoothGatt.GATT_SUCCESS)
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
override fun onConnectionStateChange(
|
||||
gatt: BluetoothGatt?, status: Int, newState: Int
|
||||
) {
|
||||
super.onConnectionStateChange(gatt, status, newState)
|
||||
if (status != BluetoothGatt.GATT_SUCCESS) {
|
||||
Log.warning("onConnectionStateChange status=$status")
|
||||
discoveryDone.countDown()
|
||||
return
|
||||
}
|
||||
if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) {
|
||||
if (!gatt.discoverServices()) {
|
||||
Log.warning("discoverServices could not start")
|
||||
discoveryDone.countDown()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
override fun onServicesDiscovered(gatt: BluetoothGatt?, status: Int) {
|
||||
Log.fine("onServicesDiscovered")
|
||||
|
||||
if (status != BluetoothGatt.GATT_SUCCESS) {
|
||||
Log.warning("failed to discover services: ${status}")
|
||||
discoveryDone.countDown()
|
||||
return
|
||||
}
|
||||
|
||||
// Find the service
|
||||
val service = gatt!!.getService(BENCH_SERVICE_UUID)
|
||||
if (service == null) {
|
||||
Log.warning("GATT Service not found")
|
||||
discoveryDone.countDown()
|
||||
return
|
||||
}
|
||||
|
||||
// Find the RX and TX characteristics
|
||||
rxCharacteristic = service.getCharacteristic(BENCH_RX_UUID)
|
||||
if (rxCharacteristic == null) {
|
||||
Log.warning("GATT RX Characteristics not found")
|
||||
discoveryDone.countDown()
|
||||
return
|
||||
}
|
||||
txCharacteristic = service.getCharacteristic(BENCH_TX_UUID)
|
||||
if (txCharacteristic == null) {
|
||||
Log.warning("GATT TX Characteristics not found")
|
||||
discoveryDone.countDown()
|
||||
return
|
||||
}
|
||||
|
||||
// Subscribe to the RX characteristic
|
||||
Log.fine("subscribing to RX")
|
||||
gatt.setCharacteristicNotification(rxCharacteristic, true)
|
||||
val cccdDescriptor = rxCharacteristic!!.getDescriptor(CCCD_UUID)
|
||||
gatt.writeDescriptor(cccdDescriptor, BluetoothGattDescriptor.ENABLE_NOTIFICATION_VALUE);
|
||||
|
||||
Log.info("GATT discovery complete")
|
||||
discoveryDone.countDown()
|
||||
}
|
||||
|
||||
override fun onCharacteristicWrite(
|
||||
gatt: BluetoothGatt?,
|
||||
characteristic: BluetoothGattCharacteristic?,
|
||||
status: Int
|
||||
) {
|
||||
// Now we can write again
|
||||
writeSemaphore.release()
|
||||
|
||||
if (status != BluetoothGatt.GATT_SUCCESS) {
|
||||
Log.warning("onCharacteristicWrite failed: $status")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
override fun onCharacteristicChanged(
|
||||
gatt: BluetoothGatt,
|
||||
characteristic: BluetoothGattCharacteristic,
|
||||
value: ByteArray
|
||||
) {
|
||||
if (characteristic.uuid == BENCH_RX_UUID && packetSink != null) {
|
||||
val packet = Packet.from(value)
|
||||
packetSink!!.onPacket(packet)
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
override fun sendPacket(packet: Packet) {
|
||||
if (txCharacteristic == null) {
|
||||
Log.warning("No TX characteristic, dropping")
|
||||
return
|
||||
}
|
||||
|
||||
// Wait until we can write
|
||||
writeSemaphore.acquire()
|
||||
|
||||
// Write the data
|
||||
val data = packet.toBytes()
|
||||
val clampedData = if (data.size > 512) {
|
||||
// Clamp the data to the maximum allowed characteristic data size
|
||||
data.copyOf(512)
|
||||
} else {
|
||||
data
|
||||
}
|
||||
gatt?.writeCharacteristic(
|
||||
txCharacteristic!!,
|
||||
clampedData,
|
||||
BluetoothGattCharacteristic.WRITE_TYPE_NO_RESPONSE
|
||||
)
|
||||
}
|
||||
|
||||
override
|
||||
fun disconnect() {
|
||||
super.disconnect()
|
||||
discoveryDone.countDown()
|
||||
}
|
||||
|
||||
fun waitForDiscoveryCompletion() {
|
||||
discoveryDone.await()
|
||||
}
|
||||
}
|
||||
|
||||
class GattClient(
|
||||
private val viewModel: AppViewModel,
|
||||
bluetoothAdapter: BluetoothAdapter,
|
||||
context: Context,
|
||||
private val createIoClient: (packetIo: PacketIO) -> IoClient
|
||||
) : Mode {
|
||||
private var connection: GattClientConnection =
|
||||
GattClientConnection(viewModel, bluetoothAdapter, context)
|
||||
private var clientThread: Thread? = null
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
override fun run() {
|
||||
viewModel.running = true
|
||||
|
||||
clientThread = thread(name = "GattClient") {
|
||||
connection.connect()
|
||||
|
||||
viewModel.aborter = {
|
||||
connection.disconnect()
|
||||
}
|
||||
|
||||
// Discover the rx and tx characteristics
|
||||
connection.waitForDiscoveryCompletion()
|
||||
if (connection.rxCharacteristic == null || connection.txCharacteristic == null) {
|
||||
connection.disconnect()
|
||||
viewModel.running = false
|
||||
return@thread
|
||||
}
|
||||
|
||||
val ioClient = createIoClient(connection)
|
||||
|
||||
try {
|
||||
ioClient.run()
|
||||
viewModel.status = "OK"
|
||||
} catch (error: IOException) {
|
||||
Log.info("run ended abruptly")
|
||||
viewModel.status = "ABORTED"
|
||||
viewModel.lastError = "IO_ERROR"
|
||||
} finally {
|
||||
connection.disconnect()
|
||||
viewModel.running = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun waitForCompletion() {
|
||||
clientThread?.join()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,243 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.github.google.bumble.btbench
|
||||
|
||||
import android.annotation.SuppressLint
|
||||
import android.bluetooth.BluetoothAdapter
|
||||
import android.bluetooth.BluetoothDevice
|
||||
import android.bluetooth.BluetoothGatt
|
||||
import android.bluetooth.BluetoothGattCharacteristic
|
||||
import android.bluetooth.BluetoothGattDescriptor
|
||||
import android.bluetooth.BluetoothGattServer
|
||||
import android.bluetooth.BluetoothGattServerCallback
|
||||
import android.bluetooth.BluetoothGattService
|
||||
import android.bluetooth.BluetoothManager
|
||||
import android.bluetooth.BluetoothStatusCodes
|
||||
import android.content.Context
|
||||
import androidx.core.content.ContextCompat
|
||||
import java.io.IOException
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.concurrent.LinkedBlockingQueue
|
||||
import java.util.concurrent.Semaphore
|
||||
import java.util.logging.Logger
|
||||
import kotlin.concurrent.thread
|
||||
import kotlin.experimental.and
|
||||
|
||||
private val Log = Logger.getLogger("btbench.gatt-server")
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
class GattServer(
|
||||
private val viewModel: AppViewModel,
|
||||
private val bluetoothAdapter: BluetoothAdapter,
|
||||
context: Context,
|
||||
private val createIoClient: (packetIo: PacketIO) -> IoClient
|
||||
) : Mode, PacketIO, BluetoothGattServerCallback() {
|
||||
override var packetSink: PacketSink? = null
|
||||
private val gattServer: BluetoothGattServer
|
||||
private val rxCharacteristic: BluetoothGattCharacteristic?
|
||||
private val txCharacteristic: BluetoothGattCharacteristic?
|
||||
private val notifySemaphore: Semaphore = Semaphore(1)
|
||||
private val ready: CountDownLatch = CountDownLatch(1)
|
||||
private var peerDevice: BluetoothDevice? = null
|
||||
private var clientThread: Thread? = null
|
||||
private var sinkQueue: LinkedBlockingQueue<Packet>? = null
|
||||
|
||||
init {
|
||||
val bluetoothManager = ContextCompat.getSystemService(context, BluetoothManager::class.java)
|
||||
gattServer = bluetoothManager!!.openGattServer(context, this)
|
||||
val benchService = gattServer.getService(BENCH_SERVICE_UUID)
|
||||
if (benchService == null) {
|
||||
rxCharacteristic = BluetoothGattCharacteristic(
|
||||
BENCH_RX_UUID,
|
||||
BluetoothGattCharacteristic.PROPERTY_NOTIFY,
|
||||
0
|
||||
)
|
||||
txCharacteristic = BluetoothGattCharacteristic(
|
||||
BENCH_TX_UUID,
|
||||
BluetoothGattCharacteristic.PROPERTY_WRITE_NO_RESPONSE,
|
||||
BluetoothGattCharacteristic.PERMISSION_WRITE
|
||||
)
|
||||
val rxCCCD = BluetoothGattDescriptor(
|
||||
CCCD_UUID,
|
||||
BluetoothGattDescriptor.PERMISSION_READ or BluetoothGattDescriptor.PERMISSION_WRITE
|
||||
)
|
||||
rxCharacteristic.addDescriptor(rxCCCD)
|
||||
|
||||
val service =
|
||||
BluetoothGattService(BENCH_SERVICE_UUID, BluetoothGattService.SERVICE_TYPE_PRIMARY)
|
||||
service.addCharacteristic(rxCharacteristic)
|
||||
service.addCharacteristic(txCharacteristic)
|
||||
|
||||
gattServer.addService(service)
|
||||
} else {
|
||||
rxCharacteristic = benchService.getCharacteristic(BENCH_RX_UUID)
|
||||
txCharacteristic = benchService.getCharacteristic(BENCH_TX_UUID)
|
||||
}
|
||||
}
|
||||
|
||||
override fun onCharacteristicWriteRequest(
|
||||
device: BluetoothDevice?,
|
||||
requestId: Int,
|
||||
characteristic: BluetoothGattCharacteristic?,
|
||||
preparedWrite: Boolean,
|
||||
responseNeeded: Boolean,
|
||||
offset: Int,
|
||||
value: ByteArray?
|
||||
) {
|
||||
Log.info("onCharacteristicWriteRequest")
|
||||
if (characteristic != null && characteristic.uuid == BENCH_TX_UUID) {
|
||||
if (packetSink == null) {
|
||||
Log.warning("no sink, dropping")
|
||||
} else if (offset != 0) {
|
||||
Log.warning("offset != 0")
|
||||
} else if (value == null) {
|
||||
Log.warning("no value")
|
||||
} else {
|
||||
// Deliver the packet in a separate thread so that we don't block this
|
||||
// callback.
|
||||
sinkQueue?.put(Packet.from(value))
|
||||
}
|
||||
}
|
||||
|
||||
if (responseNeeded) {
|
||||
gattServer.sendResponse(device, requestId, BluetoothGatt.GATT_SUCCESS, offset, value)
|
||||
}
|
||||
}
|
||||
|
||||
override fun onNotificationSent(device: BluetoothDevice?, status: Int) {
|
||||
if (status == BluetoothGatt.GATT_SUCCESS) {
|
||||
notifySemaphore.release()
|
||||
}
|
||||
}
|
||||
|
||||
override fun onDescriptorWriteRequest(
|
||||
device: BluetoothDevice?,
|
||||
requestId: Int,
|
||||
descriptor: BluetoothGattDescriptor?,
|
||||
preparedWrite: Boolean,
|
||||
responseNeeded: Boolean,
|
||||
offset: Int,
|
||||
value: ByteArray?
|
||||
) {
|
||||
if (descriptor?.uuid == CCCD_UUID && descriptor?.characteristic?.uuid == BENCH_RX_UUID) {
|
||||
if (offset == 0 && value?.size == 2) {
|
||||
if (value[0].and(1).toInt() != 0) {
|
||||
// Subscription
|
||||
Log.fine("peer subscribed to RX")
|
||||
peerDevice = device
|
||||
ready.countDown()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (responseNeeded) {
|
||||
gattServer.sendResponse(device, requestId, BluetoothGatt.GATT_SUCCESS, offset, value)
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
override fun sendPacket(packet: Packet) {
|
||||
if (peerDevice == null) {
|
||||
Log.warning("no peer device, cannot send")
|
||||
return
|
||||
}
|
||||
if (rxCharacteristic == null) {
|
||||
Log.warning("no RX characteristic, cannot send")
|
||||
return
|
||||
}
|
||||
|
||||
// Wait until we can notify
|
||||
notifySemaphore.acquire()
|
||||
|
||||
// Send the packet via a notification
|
||||
val result = gattServer.notifyCharacteristicChanged(
|
||||
peerDevice!!,
|
||||
rxCharacteristic,
|
||||
false,
|
||||
packet.toBytes()
|
||||
)
|
||||
if (result != BluetoothStatusCodes.SUCCESS) {
|
||||
Log.warning("notifyCharacteristicChanged failed: $result")
|
||||
notifySemaphore.release()
|
||||
}
|
||||
}
|
||||
|
||||
override fun run() {
|
||||
viewModel.running = true
|
||||
|
||||
// Start advertising
|
||||
Log.fine("starting advertiser")
|
||||
val advertiser = Advertiser(bluetoothAdapter)
|
||||
advertiser.start()
|
||||
|
||||
clientThread = thread(name = "GattServer") {
|
||||
// Wait for a subscriber
|
||||
Log.info("waiting for RX subscriber")
|
||||
viewModel.aborter = {
|
||||
ready.countDown()
|
||||
}
|
||||
ready.await()
|
||||
if (peerDevice == null) {
|
||||
Log.warning("server interrupted")
|
||||
viewModel.running = false
|
||||
gattServer.close()
|
||||
return@thread
|
||||
}
|
||||
Log.info("RX subscriber accepted")
|
||||
|
||||
// Stop advertising
|
||||
Log.info("stopping advertiser")
|
||||
advertiser.stop()
|
||||
|
||||
sinkQueue = LinkedBlockingQueue()
|
||||
val sinkWriterThread = thread(name = "SinkWriter") {
|
||||
while (true) {
|
||||
try {
|
||||
val packet = sinkQueue!!.take()
|
||||
if (packetSink == null) {
|
||||
Log.warning("no sink, dropping packet")
|
||||
continue
|
||||
}
|
||||
packetSink!!.onPacket(packet)
|
||||
} catch (error: InterruptedException) {
|
||||
Log.warning("sink writer interrupted")
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val ioClient = createIoClient(this)
|
||||
|
||||
try {
|
||||
ioClient.run()
|
||||
viewModel.status = "OK"
|
||||
} catch (error: IOException) {
|
||||
Log.info("run ended abruptly")
|
||||
viewModel.status = "ABORTED"
|
||||
viewModel.lastError = "IO_ERROR"
|
||||
} finally {
|
||||
sinkWriterThread.interrupt()
|
||||
sinkWriterThread.join()
|
||||
gattServer.close()
|
||||
viewModel.running = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun waitForCompletion() {
|
||||
clientThread?.join()
|
||||
Log.info("server thread completed")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.github.google.bumble.btbench
|
||||
|
||||
interface IoClient {
|
||||
fun run()
|
||||
fun abort()
|
||||
}
|
||||
@@ -16,86 +16,30 @@ package com.github.google.bumble.btbench
|
||||
|
||||
import android.annotation.SuppressLint
|
||||
import android.bluetooth.BluetoothAdapter
|
||||
import android.bluetooth.BluetoothDevice
|
||||
import android.bluetooth.BluetoothGatt
|
||||
import android.bluetooth.BluetoothGattCallback
|
||||
import android.bluetooth.BluetoothProfile
|
||||
import android.content.Context
|
||||
import android.os.Build
|
||||
import java.util.logging.Logger
|
||||
|
||||
private val Log = Logger.getLogger("btbench.l2cap-client")
|
||||
|
||||
class L2capClient(
|
||||
private val viewModel: AppViewModel,
|
||||
private val bluetoothAdapter: BluetoothAdapter,
|
||||
private val context: Context
|
||||
) {
|
||||
bluetoothAdapter: BluetoothAdapter,
|
||||
context: Context,
|
||||
private val createIoClient: (packetIo: PacketIO) -> IoClient
|
||||
) : Mode {
|
||||
private var connection: Connection = Connection(viewModel, bluetoothAdapter, context)
|
||||
private var socketClient: SocketClient? = null
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
fun run() {
|
||||
override fun run() {
|
||||
viewModel.running = true
|
||||
val addressIsPublic = viewModel.peerBluetoothAddress.endsWith("/P")
|
||||
val address = viewModel.peerBluetoothAddress.take(17)
|
||||
val remoteDevice = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
|
||||
bluetoothAdapter.getRemoteLeDevice(
|
||||
address,
|
||||
if (addressIsPublic) {
|
||||
BluetoothDevice.ADDRESS_TYPE_PUBLIC
|
||||
} else {
|
||||
BluetoothDevice.ADDRESS_TYPE_RANDOM
|
||||
}
|
||||
)
|
||||
} else {
|
||||
bluetoothAdapter.getRemoteDevice(address)
|
||||
}
|
||||
|
||||
val gatt = remoteDevice.connectGatt(
|
||||
context,
|
||||
false,
|
||||
object : BluetoothGattCallback() {
|
||||
override fun onMtuChanged(gatt: BluetoothGatt, mtu: Int, status: Int) {
|
||||
Log.info("MTU update: mtu=$mtu status=$status")
|
||||
viewModel.mtu = mtu
|
||||
}
|
||||
|
||||
override fun onPhyUpdate(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) {
|
||||
Log.info("PHY update: tx=$txPhy, rx=$rxPhy, status=$status")
|
||||
viewModel.txPhy = txPhy
|
||||
viewModel.rxPhy = rxPhy
|
||||
}
|
||||
|
||||
override fun onPhyRead(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) {
|
||||
Log.info("PHY: tx=$txPhy, rx=$rxPhy, status=$status")
|
||||
viewModel.txPhy = txPhy
|
||||
viewModel.rxPhy = rxPhy
|
||||
}
|
||||
|
||||
override fun onConnectionStateChange(
|
||||
gatt: BluetoothGatt?, status: Int, newState: Int
|
||||
) {
|
||||
if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) {
|
||||
if (viewModel.use2mPhy) {
|
||||
gatt.setPreferredPhy(
|
||||
BluetoothDevice.PHY_LE_2M_MASK,
|
||||
BluetoothDevice.PHY_LE_2M_MASK,
|
||||
BluetoothDevice.PHY_OPTION_NO_PREFERRED
|
||||
)
|
||||
}
|
||||
gatt.readPhy()
|
||||
|
||||
// Request an MTU update, even though we don't use GATT, because Android
|
||||
// won't request a larger link layer maximum data length otherwise.
|
||||
gatt.requestMtu(517)
|
||||
}
|
||||
}
|
||||
},
|
||||
BluetoothDevice.TRANSPORT_LE,
|
||||
if (viewModel.use2mPhy) BluetoothDevice.PHY_LE_2M_MASK else BluetoothDevice.PHY_LE_1M_MASK
|
||||
)
|
||||
|
||||
val socket = remoteDevice.createInsecureL2capChannel(viewModel.l2capPsm)
|
||||
|
||||
val client = SocketClient(viewModel, socket)
|
||||
client.run()
|
||||
connection.connect()
|
||||
val socket = connection.remoteDevice!!.createInsecureL2capChannel(viewModel.l2capPsm)
|
||||
socketClient = SocketClient(viewModel, socket, createIoClient)
|
||||
socketClient!!.run()
|
||||
}
|
||||
}
|
||||
|
||||
override fun waitForCompletion() {
|
||||
socketClient?.waitForCompletion()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,35 +27,29 @@ import kotlin.concurrent.thread
|
||||
|
||||
private val Log = Logger.getLogger("btbench.l2cap-server")
|
||||
|
||||
class L2capServer(private val viewModel: AppViewModel, private val bluetoothAdapter: BluetoothAdapter) {
|
||||
class L2capServer(
|
||||
private val viewModel: AppViewModel,
|
||||
private val bluetoothAdapter: BluetoothAdapter,
|
||||
private val createIoClient: (packetIo: PacketIO) -> IoClient
|
||||
) : Mode {
|
||||
private var socketServer: SocketServer? = null
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
fun run() {
|
||||
override fun run() {
|
||||
// Advertise so that the peer can find us and connect.
|
||||
val callback = object: AdvertiseCallback() {
|
||||
override fun onStartFailure(errorCode: Int) {
|
||||
Log.warning("failed to start advertising: $errorCode")
|
||||
}
|
||||
|
||||
override fun onStartSuccess(settingsInEffect: AdvertiseSettings) {
|
||||
Log.info("advertising started: $settingsInEffect")
|
||||
}
|
||||
}
|
||||
val advertiseSettingsBuilder = AdvertiseSettings.Builder()
|
||||
.setAdvertiseMode(ADVERTISE_MODE_LOW_LATENCY)
|
||||
.setConnectable(true)
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
|
||||
advertiseSettingsBuilder.setDiscoverable(true)
|
||||
}
|
||||
val advertiseSettings = advertiseSettingsBuilder.build()
|
||||
val advertiseData = AdvertiseData.Builder().build()
|
||||
val scanData = AdvertiseData.Builder().setIncludeDeviceName(true).build()
|
||||
val advertiser = bluetoothAdapter.bluetoothLeAdvertiser
|
||||
|
||||
val advertiser = Advertiser(bluetoothAdapter)
|
||||
val serverSocket = bluetoothAdapter.listenUsingInsecureL2capChannel()
|
||||
viewModel.l2capPsm = serverSocket.psm
|
||||
Log.info("psm = $serverSocket.psm")
|
||||
|
||||
val server = SocketServer(viewModel, serverSocket)
|
||||
server.run({ advertiser.stopAdvertising(callback) }, { advertiser.startAdvertising(advertiseSettings, advertiseData, scanData, callback) })
|
||||
socketServer = SocketServer(viewModel, serverSocket, createIoClient)
|
||||
socketServer!!.run(
|
||||
{ advertiser.stop() },
|
||||
{ advertiser.start() }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
override fun waitForCompletion() {
|
||||
socketServer?.waitForCompletion()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,9 +17,12 @@ package com.github.google.bumble.btbench
|
||||
import android.Manifest
|
||||
import android.annotation.SuppressLint
|
||||
import android.bluetooth.BluetoothAdapter
|
||||
import android.bluetooth.BluetoothDevice
|
||||
import android.bluetooth.BluetoothManager
|
||||
import android.content.BroadcastReceiver
|
||||
import android.content.Context
|
||||
import android.content.Intent
|
||||
import android.content.IntentFilter
|
||||
import android.content.pm.PackageManager
|
||||
import android.os.Build
|
||||
import android.os.Bundle
|
||||
@@ -34,12 +37,15 @@ import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.selection.selectable
|
||||
import androidx.compose.foundation.selection.selectableGroup
|
||||
import androidx.compose.foundation.text.KeyboardActions
|
||||
import androidx.compose.foundation.text.KeyboardOptions
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.Divider
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.RadioButton
|
||||
import androidx.compose.material3.Slider
|
||||
import androidx.compose.material3.Surface
|
||||
import androidx.compose.material3.Switch
|
||||
@@ -54,6 +60,7 @@ import androidx.compose.ui.focus.FocusRequester
|
||||
import androidx.compose.ui.focus.focusRequester
|
||||
import androidx.compose.ui.platform.LocalFocusManager
|
||||
import androidx.compose.ui.platform.LocalSoftwareKeyboardController
|
||||
import androidx.compose.ui.semantics.Role
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.input.ImeAction
|
||||
import androidx.compose.ui.text.input.KeyboardType
|
||||
@@ -62,6 +69,7 @@ import androidx.compose.ui.unit.dp
|
||||
import androidx.compose.ui.unit.sp
|
||||
import androidx.core.content.ContextCompat
|
||||
import com.github.google.bumble.btbench.ui.theme.BTBenchTheme
|
||||
import java.io.IOException
|
||||
import java.util.logging.Logger
|
||||
|
||||
private val Log = Logger.getLogger("bumble.main-activity")
|
||||
@@ -69,6 +77,10 @@ private val Log = Logger.getLogger("bumble.main-activity")
|
||||
const val PEER_BLUETOOTH_ADDRESS_PREF_KEY = "peer_bluetooth_address"
|
||||
const val SENDER_PACKET_COUNT_PREF_KEY = "sender_packet_count"
|
||||
const val SENDER_PACKET_SIZE_PREF_KEY = "sender_packet_size"
|
||||
const val SENDER_PACKET_INTERVAL_PREF_KEY = "sender_packet_interval"
|
||||
const val SCENARIO_PREF_KEY = "scenario"
|
||||
const val MODE_PREF_KEY = "mode"
|
||||
const val CONNECTION_PRIORITY_PREF_KEY = "connection_priority"
|
||||
|
||||
class MainActivity : ComponentActivity() {
|
||||
private val appViewModel = AppViewModel()
|
||||
@@ -77,6 +89,47 @@ class MainActivity : ComponentActivity() {
|
||||
super.onCreate(savedInstanceState)
|
||||
appViewModel.loadPreferences(getPreferences(Context.MODE_PRIVATE))
|
||||
checkPermissions()
|
||||
registerReceivers()
|
||||
}
|
||||
|
||||
private fun registerReceivers() {
|
||||
val pairingRequestIntentFilter = IntentFilter(BluetoothDevice.ACTION_PAIRING_REQUEST)
|
||||
registerReceiver(object: BroadcastReceiver() {
|
||||
@SuppressLint("MissingPermission")
|
||||
override fun onReceive(context: Context, intent: Intent) {
|
||||
Log.info("ACTION_PAIRING_REQUEST")
|
||||
val extras = intent.extras
|
||||
if (extras != null) {
|
||||
for (key in extras.keySet()) {
|
||||
Log.info("$key: ${extras.get(key)}")
|
||||
}
|
||||
}
|
||||
val device: BluetoothDevice? = intent.getParcelableExtra(BluetoothDevice.EXTRA_DEVICE)
|
||||
if (device != null) {
|
||||
if (checkSelfPermission(Manifest.permission.BLUETOOTH_PRIVILEGED) == PackageManager.PERMISSION_GRANTED) {
|
||||
Log.info("confirming pairing")
|
||||
device.setPairingConfirmation(true)
|
||||
} else {
|
||||
Log.info("we don't have BLUETOOTH_PRIVILEGED, not confirming")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}, pairingRequestIntentFilter)
|
||||
|
||||
val bondStateChangedIntentFilter = IntentFilter(BluetoothDevice.ACTION_BOND_STATE_CHANGED)
|
||||
registerReceiver(object: BroadcastReceiver() {
|
||||
@SuppressLint("MissingPermission")
|
||||
override fun onReceive(context: Context, intent: Intent) {
|
||||
Log.info("ACTION_BOND_STATE_CHANGED")
|
||||
val extras = intent.extras
|
||||
if (extras != null) {
|
||||
for (key in extras.keySet()) {
|
||||
Log.info("$key: ${extras.get(key)}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}, bondStateChangedIntentFilter)
|
||||
}
|
||||
|
||||
private fun checkPermissions() {
|
||||
@@ -137,12 +190,7 @@ class MainActivity : ComponentActivity() {
|
||||
initBluetooth()
|
||||
setContent {
|
||||
MainView(
|
||||
appViewModel,
|
||||
::becomeDiscoverable,
|
||||
::runRfcommClient,
|
||||
::runRfcommServer,
|
||||
::runL2capClient,
|
||||
::runL2capServer,
|
||||
appViewModel, ::becomeDiscoverable, ::runScenario
|
||||
)
|
||||
}
|
||||
|
||||
@@ -159,37 +207,61 @@ class MainActivity : ComponentActivity() {
|
||||
if (packetSize > 0) {
|
||||
appViewModel.senderPacketSize = packetSize
|
||||
}
|
||||
val packetInterval = intent.getIntExtra("packet-interval", 0)
|
||||
if (packetInterval > 0) {
|
||||
appViewModel.senderPacketInterval = packetInterval
|
||||
}
|
||||
appViewModel.updateSenderPacketSizeSlider()
|
||||
intent.getStringExtra("scenario")?.let {
|
||||
when (it) {
|
||||
"send" -> appViewModel.scenario = SEND_SCENARIO
|
||||
"receive" -> appViewModel.scenario = RECEIVE_SCENARIO
|
||||
"ping" -> appViewModel.scenario = PING_SCENARIO
|
||||
"pong" -> appViewModel.scenario = PONG_SCENARIO
|
||||
}
|
||||
}
|
||||
intent.getStringExtra("mode")?.let {
|
||||
when (it) {
|
||||
"rfcomm-client" -> appViewModel.mode = RFCOMM_CLIENT_MODE
|
||||
"rfcomm-server" -> appViewModel.mode = RFCOMM_SERVER_MODE
|
||||
"l2cap-client" -> appViewModel.mode = L2CAP_CLIENT_MODE
|
||||
"l2cap-server" -> appViewModel.mode = L2CAP_SERVER_MODE
|
||||
"gatt-client" -> appViewModel.mode = GATT_CLIENT_MODE
|
||||
"gatt-server" -> appViewModel.mode = GATT_SERVER_MODE
|
||||
}
|
||||
}
|
||||
intent.getStringExtra("autostart")?.let {
|
||||
when (it) {
|
||||
"rfcomm-client" -> runRfcommClient()
|
||||
"rfcomm-server" -> runRfcommServer()
|
||||
"l2cap-client" -> runL2capClient()
|
||||
"l2cap-server" -> runL2capServer()
|
||||
"run-scenario" -> runScenario()
|
||||
"scan-start" -> runScan(true)
|
||||
"stop-start" -> runScan(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun runRfcommClient() {
|
||||
val rfcommClient = bluetoothAdapter?.let { RfcommClient(appViewModel, it) }
|
||||
rfcommClient?.run()
|
||||
}
|
||||
private fun runScenario() {
|
||||
if (bluetoothAdapter == null) {
|
||||
throw IOException("bluetooth not enabled")
|
||||
}
|
||||
|
||||
private fun runRfcommServer() {
|
||||
val rfcommServer = bluetoothAdapter?.let { RfcommServer(appViewModel, it) }
|
||||
rfcommServer?.run()
|
||||
}
|
||||
val runner = when (appViewModel.mode) {
|
||||
RFCOMM_CLIENT_MODE -> RfcommClient(appViewModel, bluetoothAdapter!!, ::createIoClient)
|
||||
RFCOMM_SERVER_MODE -> RfcommServer(appViewModel, bluetoothAdapter!!, ::createIoClient)
|
||||
L2CAP_CLIENT_MODE -> L2capClient(
|
||||
appViewModel, bluetoothAdapter!!, baseContext, ::createIoClient
|
||||
)
|
||||
|
||||
private fun runL2capClient() {
|
||||
val l2capClient = bluetoothAdapter?.let { L2capClient(appViewModel, it, baseContext) }
|
||||
l2capClient?.run()
|
||||
}
|
||||
L2CAP_SERVER_MODE -> L2capServer(appViewModel, bluetoothAdapter!!, ::createIoClient)
|
||||
GATT_CLIENT_MODE -> GattClient(
|
||||
appViewModel, bluetoothAdapter!!, baseContext, ::createIoClient
|
||||
)
|
||||
GATT_SERVER_MODE -> GattServer(
|
||||
appViewModel, bluetoothAdapter!!, baseContext, ::createIoClient
|
||||
)
|
||||
|
||||
private fun runL2capServer() {
|
||||
val l2capServer = bluetoothAdapter?.let { L2capServer(appViewModel, it) }
|
||||
l2capServer?.run()
|
||||
else -> throw IllegalStateException()
|
||||
}
|
||||
runner.run()
|
||||
}
|
||||
|
||||
private fun runScan(startScan: Boolean) {
|
||||
@@ -197,6 +269,17 @@ class MainActivity : ComponentActivity() {
|
||||
scan?.run(startScan)
|
||||
}
|
||||
|
||||
private fun createIoClient(packetIo: PacketIO): IoClient {
|
||||
return when (appViewModel.scenario) {
|
||||
SEND_SCENARIO -> Sender(appViewModel, packetIo)
|
||||
RECEIVE_SCENARIO -> Receiver(appViewModel, packetIo)
|
||||
PING_SCENARIO -> Pinger(appViewModel, packetIo)
|
||||
PONG_SCENARIO -> Ponger(appViewModel, packetIo)
|
||||
else -> throw IllegalStateException()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
fun becomeDiscoverable() {
|
||||
val discoverableIntent = Intent(BluetoothAdapter.ACTION_REQUEST_DISCOVERABLE)
|
||||
@@ -210,10 +293,7 @@ class MainActivity : ComponentActivity() {
|
||||
fun MainView(
|
||||
appViewModel: AppViewModel,
|
||||
becomeDiscoverable: () -> Unit,
|
||||
runRfcommClient: () -> Unit,
|
||||
runRfcommServer: () -> Unit,
|
||||
runL2capClient: () -> Unit,
|
||||
runL2capServer: () -> Unit,
|
||||
runScenario: () -> Unit,
|
||||
) {
|
||||
BTBenchTheme {
|
||||
val scrollState = rememberScrollState()
|
||||
@@ -239,7 +319,9 @@ fun MainView(
|
||||
Text(text = "Peer Bluetooth Address")
|
||||
},
|
||||
value = appViewModel.peerBluetoothAddress,
|
||||
modifier = Modifier.fillMaxWidth().focusRequester(focusRequester),
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.focusRequester(focusRequester),
|
||||
keyboardOptions = KeyboardOptions.Default.copy(
|
||||
keyboardType = KeyboardType.Ascii, imeAction = ImeAction.Done
|
||||
),
|
||||
@@ -249,14 +331,18 @@ fun MainView(
|
||||
keyboardActions = KeyboardActions(onDone = {
|
||||
keyboardController?.hide()
|
||||
focusManager.clearFocus()
|
||||
})
|
||||
}),
|
||||
enabled = (appViewModel.mode == RFCOMM_CLIENT_MODE || appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == GATT_CLIENT_MODE)
|
||||
)
|
||||
Divider()
|
||||
TextField(label = {
|
||||
Text(text = "L2CAP PSM")
|
||||
},
|
||||
TextField(
|
||||
label = {
|
||||
Text(text = "L2CAP PSM")
|
||||
},
|
||||
value = appViewModel.l2capPsm.toString(),
|
||||
modifier = Modifier.fillMaxWidth().focusRequester(focusRequester),
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.focusRequester(focusRequester),
|
||||
keyboardOptions = KeyboardOptions.Default.copy(
|
||||
keyboardType = KeyboardType.Number, imeAction = ImeAction.Done
|
||||
),
|
||||
@@ -271,7 +357,8 @@ fun MainView(
|
||||
keyboardActions = KeyboardActions(onDone = {
|
||||
keyboardController?.hide()
|
||||
focusManager.clearFocus()
|
||||
})
|
||||
}),
|
||||
enabled = (appViewModel.mode == L2CAP_CLIENT_MODE)
|
||||
)
|
||||
Divider()
|
||||
Slider(
|
||||
@@ -290,44 +377,156 @@ fun MainView(
|
||||
)
|
||||
Text(text = "Packet Size: " + appViewModel.senderPacketSize.toString())
|
||||
Divider()
|
||||
ActionButton(
|
||||
text = "Become Discoverable", onClick = becomeDiscoverable, true
|
||||
TextField(
|
||||
label = {
|
||||
Text(text = "Packet Interval (ms)")
|
||||
},
|
||||
value = appViewModel.senderPacketInterval.toString(),
|
||||
modifier = Modifier
|
||||
.fillMaxWidth()
|
||||
.focusRequester(focusRequester),
|
||||
keyboardOptions = KeyboardOptions.Default.copy(
|
||||
keyboardType = KeyboardType.Number, imeAction = ImeAction.Done
|
||||
),
|
||||
onValueChange = {
|
||||
if (it.isNotEmpty()) {
|
||||
val interval = it.toIntOrNull()
|
||||
if (interval != null) {
|
||||
appViewModel.updateSenderPacketInterval(interval)
|
||||
}
|
||||
}
|
||||
},
|
||||
keyboardActions = KeyboardActions(onDone = {
|
||||
keyboardController?.hide()
|
||||
focusManager.clearFocus()
|
||||
}),
|
||||
enabled = (appViewModel.scenario == PING_SCENARIO || appViewModel.scenario == SEND_SCENARIO)
|
||||
)
|
||||
Divider()
|
||||
Row(
|
||||
horizontalArrangement = Arrangement.SpaceBetween,
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Text(text = "2M PHY")
|
||||
Spacer(modifier = Modifier.padding(start = 8.dp))
|
||||
Switch(
|
||||
Switch(enabled = (appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == L2CAP_SERVER_MODE || appViewModel.mode == GATT_CLIENT_MODE || appViewModel.mode == GATT_SERVER_MODE),
|
||||
checked = appViewModel.use2mPhy,
|
||||
onCheckedChange = { appViewModel.use2mPhy = it }
|
||||
)
|
||||
|
||||
onCheckedChange = { appViewModel.use2mPhy = it })
|
||||
Column(Modifier.selectableGroup()) {
|
||||
listOf(
|
||||
"BALANCED", "LOW", "HIGH", "DCK"
|
||||
).forEach { text ->
|
||||
Row(
|
||||
Modifier
|
||||
.selectable(
|
||||
selected = (text == appViewModel.connectionPriority),
|
||||
onClick = { appViewModel.updateConnectionPriority(text) },
|
||||
role = Role.RadioButton,
|
||||
)
|
||||
.padding(horizontal = 16.dp),
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
RadioButton(
|
||||
selected = (text == appViewModel.connectionPriority),
|
||||
onClick = null,
|
||||
enabled = (appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == L2CAP_SERVER_MODE || appViewModel.mode == GATT_CLIENT_MODE || appViewModel.mode == GATT_SERVER_MODE)
|
||||
)
|
||||
Text(
|
||||
text = text,
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
modifier = Modifier.padding(start = 16.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Row {
|
||||
Column(Modifier.selectableGroup()) {
|
||||
listOf(
|
||||
RFCOMM_CLIENT_MODE,
|
||||
RFCOMM_SERVER_MODE,
|
||||
L2CAP_CLIENT_MODE,
|
||||
L2CAP_SERVER_MODE,
|
||||
GATT_CLIENT_MODE,
|
||||
GATT_SERVER_MODE
|
||||
).forEach { text ->
|
||||
Row(
|
||||
Modifier
|
||||
.selectable(
|
||||
selected = (text == appViewModel.mode),
|
||||
onClick = { appViewModel.updateMode(text) },
|
||||
role = Role.RadioButton
|
||||
)
|
||||
.padding(horizontal = 16.dp),
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
RadioButton(
|
||||
selected = (text == appViewModel.mode), onClick = null
|
||||
)
|
||||
Text(
|
||||
text = text,
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
modifier = Modifier.padding(start = 16.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
Column(Modifier.selectableGroup()) {
|
||||
listOf(
|
||||
SEND_SCENARIO, RECEIVE_SCENARIO, PING_SCENARIO, PONG_SCENARIO
|
||||
).forEach { text ->
|
||||
Row(
|
||||
Modifier
|
||||
.selectable(
|
||||
selected = (text == appViewModel.scenario),
|
||||
onClick = { appViewModel.updateScenario(text) },
|
||||
role = Role.RadioButton
|
||||
)
|
||||
.padding(horizontal = 16.dp),
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
RadioButton(
|
||||
selected = (text == appViewModel.scenario), onClick = null
|
||||
)
|
||||
Text(
|
||||
text = text,
|
||||
style = MaterialTheme.typography.bodyLarge,
|
||||
modifier = Modifier.padding(start = 16.dp)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Row {
|
||||
ActionButton(
|
||||
text = "RFCOMM Client", onClick = runRfcommClient, !appViewModel.running
|
||||
text = "Start", onClick = runScenario, enabled = !appViewModel.running
|
||||
)
|
||||
ActionButton(
|
||||
text = "RFCOMM Server", onClick = runRfcommServer, !appViewModel.running
|
||||
)
|
||||
}
|
||||
Row {
|
||||
ActionButton(
|
||||
text = "L2CAP Client", onClick = runL2capClient, !appViewModel.running
|
||||
text = "Stop", onClick = appViewModel::abort, enabled = appViewModel.running
|
||||
)
|
||||
ActionButton(
|
||||
text = "L2CAP Server", onClick = runL2capServer, !appViewModel.running
|
||||
text = "Become Discoverable", onClick = becomeDiscoverable, true
|
||||
)
|
||||
}
|
||||
Divider()
|
||||
if (appViewModel.mtu != 0) {
|
||||
Text(
|
||||
text = "MTU: ${appViewModel.mtu}"
|
||||
)
|
||||
}
|
||||
if (appViewModel.rxPhy != 0) {
|
||||
Text(
|
||||
text = "PHY: tx=${appViewModel.txPhy}, rx=${appViewModel.rxPhy}"
|
||||
)
|
||||
}
|
||||
Text(
|
||||
text = if (appViewModel.mtu != 0) "MTU: ${appViewModel.mtu}" else ""
|
||||
)
|
||||
Text(
|
||||
text = if (appViewModel.rxPhy != 0 || appViewModel.txPhy != 0) "PHY: tx=${appViewModel.txPhy}, rx=${appViewModel.rxPhy}" else ""
|
||||
text = "Status: ${appViewModel.status}"
|
||||
)
|
||||
if (appViewModel.lastError.isNotEmpty()) {
|
||||
Text(
|
||||
text = "Last Error: ${appViewModel.lastError}"
|
||||
)
|
||||
}
|
||||
Text(
|
||||
text = "Packets Sent: ${appViewModel.packetsSent}"
|
||||
)
|
||||
@@ -337,9 +536,8 @@ fun MainView(
|
||||
Text(
|
||||
text = "Throughput: ${appViewModel.throughput}"
|
||||
)
|
||||
Divider()
|
||||
ActionButton(
|
||||
text = "Abort", onClick = appViewModel::abort, appViewModel.running
|
||||
Text(
|
||||
text = "Stats: ${appViewModel.stats}"
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -351,4 +549,4 @@ fun ActionButton(text: String, onClick: () -> Unit, enabled: Boolean) {
|
||||
Button(onClick = onClick, enabled = enabled) {
|
||||
Text(text = text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.github.google.bumble.btbench
|
||||
|
||||
interface Mode {
|
||||
fun run()
|
||||
fun waitForCompletion()
|
||||
}
|
||||
@@ -25,15 +25,35 @@ import java.util.UUID
|
||||
|
||||
val DEFAULT_RFCOMM_UUID: UUID = UUID.fromString("E6D55659-C8B4-4B85-96BB-B1143AF6D3AE")
|
||||
const val DEFAULT_PEER_BLUETOOTH_ADDRESS = "AA:BB:CC:DD:EE:FF"
|
||||
const val DEFAULT_STARTUP_DELAY = 3000
|
||||
const val DEFAULT_SENDER_PACKET_COUNT = 100
|
||||
const val DEFAULT_SENDER_PACKET_SIZE = 1024
|
||||
const val DEFAULT_SENDER_PACKET_INTERVAL = 100
|
||||
const val DEFAULT_PSM = 128
|
||||
|
||||
const val L2CAP_CLIENT_MODE = "L2CAP Client"
|
||||
const val L2CAP_SERVER_MODE = "L2CAP Server"
|
||||
const val RFCOMM_CLIENT_MODE = "RFCOMM Client"
|
||||
const val RFCOMM_SERVER_MODE = "RFCOMM Server"
|
||||
const val GATT_CLIENT_MODE = "GATT Client"
|
||||
const val GATT_SERVER_MODE = "GATT Server"
|
||||
|
||||
const val SEND_SCENARIO = "Send"
|
||||
const val RECEIVE_SCENARIO = "Receive"
|
||||
const val PING_SCENARIO = "Ping"
|
||||
const val PONG_SCENARIO = "Pong"
|
||||
|
||||
class AppViewModel : ViewModel() {
|
||||
private var preferences: SharedPreferences? = null
|
||||
var status by mutableStateOf("")
|
||||
var lastError by mutableStateOf("")
|
||||
var mode by mutableStateOf(RFCOMM_SERVER_MODE)
|
||||
var scenario by mutableStateOf(RECEIVE_SCENARIO)
|
||||
var peerBluetoothAddress by mutableStateOf(DEFAULT_PEER_BLUETOOTH_ADDRESS)
|
||||
var startupDelay by mutableIntStateOf(DEFAULT_STARTUP_DELAY)
|
||||
var l2capPsm by mutableIntStateOf(DEFAULT_PSM)
|
||||
var use2mPhy by mutableStateOf(true)
|
||||
var connectionPriority by mutableStateOf("BALANCED")
|
||||
var mtu by mutableIntStateOf(0)
|
||||
var rxPhy by mutableIntStateOf(0)
|
||||
var txPhy by mutableIntStateOf(0)
|
||||
@@ -41,9 +61,11 @@ class AppViewModel : ViewModel() {
|
||||
var senderPacketSizeSlider by mutableFloatStateOf(0.0F)
|
||||
var senderPacketCount by mutableIntStateOf(DEFAULT_SENDER_PACKET_COUNT)
|
||||
var senderPacketSize by mutableIntStateOf(DEFAULT_SENDER_PACKET_SIZE)
|
||||
var senderPacketInterval by mutableIntStateOf(DEFAULT_SENDER_PACKET_INTERVAL)
|
||||
var packetsSent by mutableIntStateOf(0)
|
||||
var packetsReceived by mutableIntStateOf(0)
|
||||
var throughput by mutableIntStateOf(0)
|
||||
var stats by mutableStateOf("")
|
||||
var running by mutableStateOf(false)
|
||||
var aborter: (() -> Unit)? = null
|
||||
|
||||
@@ -66,6 +88,26 @@ class AppViewModel : ViewModel() {
|
||||
senderPacketSize = savedSenderPacketSize
|
||||
}
|
||||
updateSenderPacketSizeSlider()
|
||||
|
||||
val savedSenderPacketInterval = preferences.getInt(SENDER_PACKET_INTERVAL_PREF_KEY, -1)
|
||||
if (savedSenderPacketInterval != -1) {
|
||||
senderPacketInterval = savedSenderPacketInterval
|
||||
}
|
||||
|
||||
val savedMode = preferences.getString(MODE_PREF_KEY, null)
|
||||
if (savedMode != null) {
|
||||
mode = savedMode
|
||||
}
|
||||
|
||||
val savedScenario = preferences.getString(SCENARIO_PREF_KEY, null)
|
||||
if (savedScenario != null) {
|
||||
scenario = savedScenario
|
||||
}
|
||||
|
||||
val savedConnectionPriority = preferences.getString(CONNECTION_PRIORITY_PREF_KEY, null)
|
||||
if (savedConnectionPriority != null) {
|
||||
connectionPriority = savedConnectionPriority
|
||||
}
|
||||
}
|
||||
|
||||
fun updatePeerBluetoothAddress(peerBluetoothAddress: String) {
|
||||
@@ -164,6 +206,50 @@ class AppViewModel : ViewModel() {
|
||||
}
|
||||
}
|
||||
|
||||
fun updateSenderPacketInterval(senderPacketInterval: Int) {
|
||||
this.senderPacketInterval = senderPacketInterval
|
||||
with(preferences!!.edit()) {
|
||||
putInt(SENDER_PACKET_INTERVAL_PREF_KEY, senderPacketInterval)
|
||||
apply()
|
||||
}
|
||||
}
|
||||
|
||||
fun updateScenario(scenario: String) {
|
||||
this.scenario = scenario
|
||||
with(preferences!!.edit()) {
|
||||
putString(SCENARIO_PREF_KEY, scenario)
|
||||
apply()
|
||||
}
|
||||
}
|
||||
|
||||
fun updateMode(mode: String) {
|
||||
this.mode = mode
|
||||
with(preferences!!.edit()) {
|
||||
putString(MODE_PREF_KEY, mode)
|
||||
apply()
|
||||
}
|
||||
}
|
||||
|
||||
fun updateConnectionPriority(connectionPriority: String) {
|
||||
this.connectionPriority = connectionPriority
|
||||
with(preferences!!.edit()) {
|
||||
putString(CONNECTION_PRIORITY_PREF_KEY, connectionPriority)
|
||||
apply()
|
||||
}
|
||||
}
|
||||
|
||||
fun clear() {
|
||||
status = ""
|
||||
lastError = ""
|
||||
mtu = 0
|
||||
rxPhy = 0
|
||||
txPhy = 0
|
||||
packetsSent = 0
|
||||
packetsReceived = 0
|
||||
throughput = 0
|
||||
stats = ""
|
||||
}
|
||||
|
||||
fun abort() {
|
||||
aborter?.let { it() }
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ package com.github.google.bumble.btbench
|
||||
import android.bluetooth.BluetoothSocket
|
||||
import java.io.IOException
|
||||
import java.nio.ByteBuffer
|
||||
import java.nio.ByteOrder
|
||||
import java.util.logging.Logger
|
||||
import kotlin.math.min
|
||||
|
||||
@@ -37,11 +38,16 @@ abstract class Packet(val type: Int, val payload: ByteArray = ByteArray(0)) {
|
||||
RESET -> ResetPacket()
|
||||
SEQUENCE -> SequencePacket(
|
||||
data[1].toInt(),
|
||||
ByteBuffer.wrap(data, 2, 4).getInt(),
|
||||
data.sliceArray(6..<data.size)
|
||||
ByteBuffer.wrap(data, 2, 4).order(ByteOrder.LITTLE_ENDIAN).getInt(),
|
||||
ByteBuffer.wrap(data, 6, 4).order(ByteOrder.LITTLE_ENDIAN).getInt(),
|
||||
data.sliceArray(10..<data.size)
|
||||
)
|
||||
|
||||
ACK -> AckPacket(
|
||||
data[1].toInt(),
|
||||
ByteBuffer.wrap(data, 2, 4).order(ByteOrder.LITTLE_ENDIAN).getInt()
|
||||
)
|
||||
|
||||
ACK -> AckPacket(data[1].toInt(), ByteBuffer.wrap(data, 2, 4).getInt())
|
||||
else -> GenericPacket(data[0].toInt(), data.sliceArray(1..<data.size))
|
||||
}
|
||||
}
|
||||
@@ -57,16 +63,24 @@ class ResetPacket : Packet(RESET)
|
||||
|
||||
class AckPacket(val flags: Int, val sequenceNumber: Int) : Packet(ACK) {
|
||||
override fun toBytes(): ByteArray {
|
||||
return ByteBuffer.allocate(1 + 1 + 4).put(type.toByte()).put(flags.toByte())
|
||||
return ByteBuffer.allocate(6).order(
|
||||
ByteOrder.LITTLE_ENDIAN
|
||||
).put(type.toByte()).put(flags.toByte())
|
||||
.putInt(sequenceNumber).array()
|
||||
}
|
||||
}
|
||||
|
||||
class SequencePacket(val flags: Int, val sequenceNumber: Int, payload: ByteArray) :
|
||||
class SequencePacket(
|
||||
val flags: Int,
|
||||
val sequenceNumber: Int,
|
||||
val timestamp: Int,
|
||||
payload: ByteArray
|
||||
) :
|
||||
Packet(SEQUENCE, payload) {
|
||||
override fun toBytes(): ByteArray {
|
||||
return ByteBuffer.allocate(1 + 1 + 4 + payload.size).put(type.toByte()).put(flags.toByte())
|
||||
.putInt(sequenceNumber).put(payload).array()
|
||||
return ByteBuffer.allocate(10 + payload.size).order(ByteOrder.LITTLE_ENDIAN)
|
||||
.put(type.toByte()).put(flags.toByte())
|
||||
.putInt(sequenceNumber).putInt(timestamp).put(payload).array()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,13 +88,13 @@ abstract class PacketSink {
|
||||
fun onPacket(packet: Packet) {
|
||||
when (packet) {
|
||||
is ResetPacket -> onResetPacket()
|
||||
is AckPacket -> onAckPacket()
|
||||
is AckPacket -> onAckPacket(packet)
|
||||
is SequencePacket -> onSequencePacket(packet)
|
||||
}
|
||||
}
|
||||
|
||||
abstract fun onResetPacket()
|
||||
abstract fun onAckPacket()
|
||||
abstract fun onAckPacket(packet: AckPacket)
|
||||
abstract fun onSequencePacket(packet: SequencePacket)
|
||||
}
|
||||
|
||||
@@ -175,4 +189,4 @@ class SocketDataSource(
|
||||
} while (true)
|
||||
Log.info("end of stream")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.github.google.bumble.btbench
|
||||
|
||||
import java.util.concurrent.Semaphore
|
||||
import java.util.logging.Logger
|
||||
import kotlin.time.Duration.Companion.milliseconds
|
||||
import kotlin.time.TimeSource
|
||||
|
||||
private val Log = Logger.getLogger("btbench.pinger")
|
||||
|
||||
class Pinger(private val viewModel: AppViewModel, private val packetIO: PacketIO) : IoClient,
|
||||
PacketSink() {
|
||||
private val pingTimes: ArrayList<TimeSource.Monotonic.ValueTimeMark> = ArrayList()
|
||||
private val rtts: ArrayList<Long> = ArrayList()
|
||||
private val done = Semaphore(0)
|
||||
|
||||
init {
|
||||
packetIO.packetSink = this
|
||||
}
|
||||
|
||||
override fun run() {
|
||||
viewModel.clear()
|
||||
|
||||
Log.info("startup delay: ${viewModel.startupDelay}")
|
||||
Thread.sleep(viewModel.startupDelay.toLong());
|
||||
Log.info("running")
|
||||
|
||||
Log.info("sending reset")
|
||||
packetIO.sendPacket(ResetPacket())
|
||||
|
||||
val packetCount = viewModel.senderPacketCount
|
||||
val packetSize = viewModel.senderPacketSize
|
||||
|
||||
val startTime = TimeSource.Monotonic.markNow()
|
||||
for (i in 0..<packetCount) {
|
||||
var now = TimeSource.Monotonic.markNow()
|
||||
if (viewModel.senderPacketInterval > 0) {
|
||||
val targetTime = startTime + (i * viewModel.senderPacketInterval).milliseconds
|
||||
val delay = targetTime - now
|
||||
if (delay.isPositive()) {
|
||||
Log.info("sleeping ${delay.inWholeMilliseconds} ms")
|
||||
Thread.sleep(delay.inWholeMilliseconds)
|
||||
now = TimeSource.Monotonic.markNow()
|
||||
}
|
||||
}
|
||||
pingTimes.add(TimeSource.Monotonic.markNow())
|
||||
packetIO.sendPacket(
|
||||
SequencePacket(
|
||||
if (i < packetCount - 1) 0 else Packet.LAST_FLAG,
|
||||
i,
|
||||
(now - startTime).inWholeMicroseconds.toInt(),
|
||||
ByteArray(packetSize - 10)
|
||||
)
|
||||
)
|
||||
viewModel.packetsSent = i + 1
|
||||
}
|
||||
|
||||
// Wait for the last ACK
|
||||
Log.info("waiting for last ACK")
|
||||
done.acquire()
|
||||
Log.info("got last ACK")
|
||||
}
|
||||
|
||||
override fun abort() {
|
||||
done.release()
|
||||
}
|
||||
|
||||
override fun onResetPacket() {
|
||||
}
|
||||
|
||||
override fun onAckPacket(packet: AckPacket) {
|
||||
val now = TimeSource.Monotonic.markNow()
|
||||
viewModel.packetsReceived += 1
|
||||
if (packet.sequenceNumber < pingTimes.size) {
|
||||
val rtt = (now - pingTimes[packet.sequenceNumber]).inWholeMilliseconds
|
||||
rtts.add(rtt)
|
||||
Log.info("received ACK ${packet.sequenceNumber}, RTT=$rtt")
|
||||
} else {
|
||||
Log.warning("received ACK with unexpected sequence ${packet.sequenceNumber}")
|
||||
}
|
||||
|
||||
if (packet.flags and Packet.LAST_FLAG != 0) {
|
||||
Log.info("last packet received")
|
||||
val stats = "RTTs: min=${rtts.min()}, max=${rtts.max()}, avg=${rtts.sum() / rtts.size}"
|
||||
Log.info(stats)
|
||||
viewModel.stats = stats
|
||||
done.release()
|
||||
}
|
||||
}
|
||||
|
||||
override fun onSequencePacket(packet: SequencePacket) {
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
// Copyright 2024 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package com.github.google.bumble.btbench
|
||||
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.logging.Logger
|
||||
import kotlin.time.TimeSource
|
||||
|
||||
private val Log = Logger.getLogger("btbench.receiver")
|
||||
|
||||
class Ponger(private val viewModel: AppViewModel, private val packetIO: PacketIO) : IoClient, PacketSink() {
|
||||
private var startTime: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow()
|
||||
private var lastPacketTime: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow()
|
||||
private var expectedSequenceNumber: Int = 0
|
||||
private val done = CountDownLatch(1)
|
||||
|
||||
init {
|
||||
packetIO.packetSink = this
|
||||
}
|
||||
|
||||
override fun run() {
|
||||
viewModel.clear()
|
||||
done.await()
|
||||
}
|
||||
|
||||
override fun abort() {}
|
||||
|
||||
override fun onResetPacket() {
|
||||
startTime = TimeSource.Monotonic.markNow()
|
||||
lastPacketTime = startTime
|
||||
expectedSequenceNumber = 0
|
||||
viewModel.packetsSent = 0
|
||||
viewModel.packetsReceived = 0
|
||||
viewModel.stats = ""
|
||||
}
|
||||
|
||||
override fun onAckPacket(packet: AckPacket) {
|
||||
}
|
||||
|
||||
override fun onSequencePacket(packet: SequencePacket) {
|
||||
val now = TimeSource.Monotonic.markNow()
|
||||
lastPacketTime = now
|
||||
viewModel.packetsReceived += 1
|
||||
|
||||
if (packet.sequenceNumber != expectedSequenceNumber) {
|
||||
Log.warning("unexpected packet sequence number (expected ${expectedSequenceNumber}, got ${packet.sequenceNumber})")
|
||||
}
|
||||
expectedSequenceNumber += 1
|
||||
|
||||
packetIO.sendPacket(AckPacket(packet.flags, packet.sequenceNumber))
|
||||
viewModel.packetsSent += 1
|
||||
|
||||
if (packet.flags and Packet.LAST_FLAG != 0) {
|
||||
Log.info("received last packet")
|
||||
done.countDown()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,21 +14,30 @@
|
||||
|
||||
package com.github.google.bumble.btbench
|
||||
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.logging.Logger
|
||||
import kotlin.time.DurationUnit
|
||||
import kotlin.time.TimeSource
|
||||
|
||||
private val Log = Logger.getLogger("btbench.receiver")
|
||||
|
||||
class Receiver(private val viewModel: AppViewModel, private val packetIO: PacketIO) : PacketSink() {
|
||||
class Receiver(private val viewModel: AppViewModel, private val packetIO: PacketIO) : IoClient, PacketSink() {
|
||||
private var startTime: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow()
|
||||
private var lastPacketTime: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow()
|
||||
private var bytesReceived = 0
|
||||
private val done = CountDownLatch(1)
|
||||
|
||||
init {
|
||||
packetIO.packetSink = this
|
||||
}
|
||||
|
||||
override fun run() {
|
||||
viewModel.clear()
|
||||
done.await()
|
||||
}
|
||||
|
||||
override fun abort() {}
|
||||
|
||||
override fun onResetPacket() {
|
||||
startTime = TimeSource.Monotonic.markNow()
|
||||
lastPacketTime = startTime
|
||||
@@ -36,9 +45,10 @@ class Receiver(private val viewModel: AppViewModel, private val packetIO: Packet
|
||||
viewModel.throughput = 0
|
||||
viewModel.packetsSent = 0
|
||||
viewModel.packetsReceived = 0
|
||||
viewModel.stats = ""
|
||||
}
|
||||
|
||||
override fun onAckPacket() {
|
||||
override fun onAckPacket(packet: AckPacket) {
|
||||
|
||||
}
|
||||
|
||||
@@ -55,6 +65,7 @@ class Receiver(private val viewModel: AppViewModel, private val packetIO: Packet
|
||||
Log.info("throughput: $throughput")
|
||||
viewModel.throughput = throughput
|
||||
packetIO.sendPacket(AckPacket(packet.flags, packet.sequenceNumber))
|
||||
done.countDown()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,22 +16,30 @@ package com.github.google.bumble.btbench
|
||||
|
||||
import android.annotation.SuppressLint
|
||||
import android.bluetooth.BluetoothAdapter
|
||||
import java.io.IOException
|
||||
import java.util.logging.Logger
|
||||
import kotlin.concurrent.thread
|
||||
|
||||
private val Log = Logger.getLogger("btbench.rfcomm-client")
|
||||
|
||||
class RfcommClient(private val viewModel: AppViewModel, val bluetoothAdapter: BluetoothAdapter) {
|
||||
class RfcommClient(
|
||||
private val viewModel: AppViewModel,
|
||||
private val bluetoothAdapter: BluetoothAdapter,
|
||||
private val createIoClient: (packetIo: PacketIO) -> IoClient
|
||||
) : Mode {
|
||||
private var socketClient: SocketClient? = null
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
fun run() {
|
||||
override fun run() {
|
||||
val address = viewModel.peerBluetoothAddress.take(17)
|
||||
val remoteDevice = bluetoothAdapter.getRemoteDevice(address)
|
||||
val socket = remoteDevice.createInsecureRfcommSocketToServiceRecord(
|
||||
DEFAULT_RFCOMM_UUID
|
||||
)
|
||||
|
||||
val client = SocketClient(viewModel, socket)
|
||||
client.run()
|
||||
socketClient = SocketClient(viewModel, socket, createIoClient)
|
||||
socketClient!!.run()
|
||||
}
|
||||
|
||||
override fun waitForCompletion() {
|
||||
socketClient?.waitForCompletion()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,20 +16,27 @@ package com.github.google.bumble.btbench
|
||||
|
||||
import android.annotation.SuppressLint
|
||||
import android.bluetooth.BluetoothAdapter
|
||||
import java.io.IOException
|
||||
import java.util.logging.Logger
|
||||
import kotlin.concurrent.thread
|
||||
|
||||
private val Log = Logger.getLogger("btbench.rfcomm-server")
|
||||
|
||||
class RfcommServer(private val viewModel: AppViewModel, val bluetoothAdapter: BluetoothAdapter) {
|
||||
class RfcommServer(
|
||||
private val viewModel: AppViewModel,
|
||||
private val bluetoothAdapter: BluetoothAdapter,
|
||||
private val createIoClient: (packetIo: PacketIO) -> IoClient
|
||||
) : Mode {
|
||||
private var socketServer: SocketServer? = null
|
||||
|
||||
@SuppressLint("MissingPermission")
|
||||
fun run() {
|
||||
override fun run() {
|
||||
val serverSocket = bluetoothAdapter.listenUsingInsecureRfcommWithServiceRecord(
|
||||
"BumbleBench", DEFAULT_RFCOMM_UUID
|
||||
)
|
||||
|
||||
val server = SocketServer(viewModel, serverSocket)
|
||||
server.run({}, {})
|
||||
socketServer = SocketServer(viewModel, serverSocket, createIoClient)
|
||||
socketServer!!.run({}, {})
|
||||
}
|
||||
}
|
||||
|
||||
override fun waitForCompletion() {
|
||||
socketServer?.waitForCompletion()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,4 +35,4 @@ class Scan(val bluetoothAdapter: BluetoothAdapter) {
|
||||
bluetoothLeScanner?.stopScan(scanCallback)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,12 +16,14 @@ package com.github.google.bumble.btbench
|
||||
|
||||
import java.util.concurrent.Semaphore
|
||||
import java.util.logging.Logger
|
||||
import kotlin.time.Duration.Companion.milliseconds
|
||||
import kotlin.time.DurationUnit
|
||||
import kotlin.time.TimeSource
|
||||
|
||||
private val Log = Logger.getLogger("btbench.sender")
|
||||
|
||||
class Sender(private val viewModel: AppViewModel, private val packetIO: PacketIO) : PacketSink() {
|
||||
class Sender(private val viewModel: AppViewModel, private val packetIO: PacketIO) : IoClient,
|
||||
PacketSink() {
|
||||
private var startTime: TimeSource.Monotonic.ValueTimeMark = TimeSource.Monotonic.markNow()
|
||||
private var bytesSent = 0
|
||||
private val done = Semaphore(0)
|
||||
@@ -30,10 +32,12 @@ class Sender(private val viewModel: AppViewModel, private val packetIO: PacketIO
|
||||
packetIO.packetSink = this
|
||||
}
|
||||
|
||||
fun run() {
|
||||
viewModel.packetsSent = 0
|
||||
viewModel.packetsReceived = 0
|
||||
viewModel.throughput = 0
|
||||
override fun run() {
|
||||
viewModel.clear()
|
||||
|
||||
Log.info("startup delay: ${viewModel.startupDelay}")
|
||||
Thread.sleep(viewModel.startupDelay.toLong());
|
||||
Log.info("running")
|
||||
|
||||
Log.info("sending reset")
|
||||
packetIO.sendPacket(ResetPacket())
|
||||
@@ -42,20 +46,32 @@ class Sender(private val viewModel: AppViewModel, private val packetIO: PacketIO
|
||||
|
||||
val packetCount = viewModel.senderPacketCount
|
||||
val packetSize = viewModel.senderPacketSize
|
||||
for (i in 0..<packetCount - 1) {
|
||||
packetIO.sendPacket(SequencePacket(0, i, ByteArray(packetSize - 6)))
|
||||
for (i in 0..<packetCount) {
|
||||
var now = TimeSource.Monotonic.markNow()
|
||||
if (viewModel.senderPacketInterval > 0) {
|
||||
val targetTime = startTime + (i * viewModel.senderPacketInterval).milliseconds
|
||||
val delay = targetTime - now
|
||||
if (delay.isPositive()) {
|
||||
Log.info("sleeping ${delay.inWholeMilliseconds} ms")
|
||||
Thread.sleep(delay.inWholeMilliseconds)
|
||||
}
|
||||
now = TimeSource.Monotonic.markNow()
|
||||
}
|
||||
val flags = when (i) {
|
||||
packetCount - 1 -> Packet.LAST_FLAG
|
||||
else -> 0
|
||||
}
|
||||
packetIO.sendPacket(
|
||||
SequencePacket(
|
||||
flags,
|
||||
i,
|
||||
(now - startTime).inWholeMicroseconds.toInt(),
|
||||
ByteArray(packetSize - 10)
|
||||
)
|
||||
)
|
||||
bytesSent += packetSize
|
||||
viewModel.packetsSent = i + 1
|
||||
}
|
||||
packetIO.sendPacket(
|
||||
SequencePacket(
|
||||
Packet.LAST_FLAG,
|
||||
packetCount - 1,
|
||||
ByteArray(packetSize - 6)
|
||||
)
|
||||
)
|
||||
bytesSent += packetSize
|
||||
viewModel.packetsSent = packetCount
|
||||
|
||||
// Wait for the ACK
|
||||
Log.info("waiting for ACK")
|
||||
@@ -63,14 +79,14 @@ class Sender(private val viewModel: AppViewModel, private val packetIO: PacketIO
|
||||
Log.info("got ACK")
|
||||
}
|
||||
|
||||
fun abort() {
|
||||
override fun abort() {
|
||||
done.release()
|
||||
}
|
||||
|
||||
override fun onResetPacket() {
|
||||
}
|
||||
|
||||
override fun onAckPacket() {
|
||||
override fun onAckPacket(packet: AckPacket) {
|
||||
Log.info("received ACK")
|
||||
val elapsed = TimeSource.Monotonic.markNow() - startTime
|
||||
val throughput = (bytesSent / elapsed.toDouble(DurationUnit.SECONDS)).toInt()
|
||||
@@ -81,4 +97,4 @@ class Sender(private val viewModel: AppViewModel, private val packetIO: PacketIO
|
||||
|
||||
override fun onSequencePacket(packet: SequencePacket) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,16 +22,20 @@ import kotlin.concurrent.thread
|
||||
|
||||
private val Log = Logger.getLogger("btbench.socket-client")
|
||||
|
||||
private const val DEFAULT_STARTUP_DELAY = 3000
|
||||
class SocketClient(
|
||||
private val viewModel: AppViewModel,
|
||||
private val socket: BluetoothSocket,
|
||||
private val createIoClient: (packetIo: PacketIO) -> IoClient
|
||||
) {
|
||||
private var clientThread: Thread? = null
|
||||
|
||||
class SocketClient(private val viewModel: AppViewModel, private val socket: BluetoothSocket) {
|
||||
@SuppressLint("MissingPermission")
|
||||
fun run() {
|
||||
viewModel.running = true
|
||||
val socketDataSink = SocketDataSink(socket)
|
||||
val streamIO = StreamedPacketIO(socketDataSink)
|
||||
val socketDataSource = SocketDataSource(socket, streamIO::onData)
|
||||
val sender = Sender(viewModel, streamIO)
|
||||
val ioClient = createIoClient(streamIO)
|
||||
|
||||
fun cleanup() {
|
||||
socket.close()
|
||||
@@ -39,9 +43,9 @@ class SocketClient(private val viewModel: AppViewModel, private val socket: Blue
|
||||
viewModel.running = false
|
||||
}
|
||||
|
||||
thread(name = "SocketClient") {
|
||||
clientThread = thread(name = "SocketClient") {
|
||||
viewModel.aborter = {
|
||||
sender.abort()
|
||||
ioClient.abort()
|
||||
socket.close()
|
||||
}
|
||||
Log.info("connecting to remote")
|
||||
@@ -49,27 +53,37 @@ class SocketClient(private val viewModel: AppViewModel, private val socket: Blue
|
||||
socket.connect()
|
||||
} catch (error: IOException) {
|
||||
Log.warning("connection failed")
|
||||
viewModel.status = "ABORTED"
|
||||
viewModel.lastError = "CONNECTION_FAILED"
|
||||
cleanup()
|
||||
return@thread
|
||||
}
|
||||
Log.info("connected")
|
||||
|
||||
thread {
|
||||
val sourceThread = thread {
|
||||
socketDataSource.receive()
|
||||
socket.close()
|
||||
sender.abort()
|
||||
ioClient.abort()
|
||||
}
|
||||
|
||||
Log.info("Startup delay: $DEFAULT_STARTUP_DELAY")
|
||||
Thread.sleep(DEFAULT_STARTUP_DELAY.toLong());
|
||||
Log.info("Starting to send")
|
||||
|
||||
try {
|
||||
sender.run()
|
||||
ioClient.run()
|
||||
socket.close()
|
||||
viewModel.status = "OK"
|
||||
} catch (error: IOException) {
|
||||
Log.info("run ended abruptly")
|
||||
viewModel.status = "ABORTED"
|
||||
viewModel.lastError = "IO_ERROR"
|
||||
}
|
||||
|
||||
Log.info("waiting for source thread to finish")
|
||||
sourceThread.join()
|
||||
|
||||
cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun waitForCompletion() {
|
||||
clientThread?.join()
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user