Compare commits

..

1 Commits

Author SHA1 Message Date
Lucas Abel ec35f5b118 pandora: add annotations import 2023-11-06 08:30:12 +00:00
226 changed files with 3674 additions and 25703 deletions
+3 -3
View File
@@ -16,7 +16,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] python-version: ["3.8", "3.9", "3.10", "3.11"]
fail-fast: false fail-fast: false
steps: steps:
@@ -29,11 +29,11 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v3 uses: actions/setup-python@v3
with: with:
python-version: ${{ matrix.python-version }} python-version: '3.10'
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install ".[build,test,development,pandora]" python -m pip install ".[build,test,development]"
- name: Check - name: Check
run: | run: |
invoke project.pre-commit invoke project.pre-commit
+1 -1
View File
@@ -32,7 +32,7 @@ jobs:
- name: Install - name: Install
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install .[avatar,pandora] python -m pip install .[avatar]
- name: Rootcanal - name: Rootcanal
run: nohup python -m rootcanal > rootcanal.log & run: nohup python -m rootcanal > rootcanal.log &
- name: Test - name: Test
+6 -8
View File
@@ -16,7 +16,7 @@ jobs:
strategy: strategy:
matrix: matrix:
os: ['ubuntu-latest', 'macos-latest', 'windows-latest'] os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] python-version: ["3.8", "3.9", "3.10", "3.11"]
fail-fast: false fail-fast: false
steps: steps:
@@ -46,8 +46,8 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
python-version: [ "3.8", "3.9", "3.10", "3.11", "3.12" ] python-version: [ "3.8", "3.9", "3.10", "3.11" ]
rust-version: [ "1.76.0", "stable" ] rust-version: [ "1.70.0", "stable" ]
fail-fast: false fail-fast: false
steps: steps:
- name: Check out from Git - name: Check out from Git
@@ -56,7 +56,7 @@ jobs:
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
- name: Install Python dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
python -m pip install ".[build,test,development,documentation]" python -m pip install ".[build,test,development,documentation]"
@@ -65,17 +65,15 @@ jobs:
with: with:
components: clippy,rustfmt components: clippy,rustfmt
toolchain: ${{ matrix.rust-version }} toolchain: ${{ matrix.rust-version }}
- name: Install Rust dependencies
run: cargo install cargo-all-features # allows building/testing combinations of features
- name: Check License Headers - name: Check License Headers
run: cd rust && cargo run --features dev-tools --bin file-header check-all run: cd rust && cargo run --features dev-tools --bin file-header check-all
- name: Rust Build - name: Rust Build
run: cd rust && cargo build --all-targets && cargo build-all-features --all-targets run: cd rust && cargo build --all-targets && cargo build --all-features --all-targets
# Lints after build so what clippy needs is already built # Lints after build so what clippy needs is already built
- name: Rust Lints - name: Rust Lints
run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings && cargo clippy --all-features --all-targets -- --deny warnings run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings && cargo clippy --all-features --all-targets -- --deny warnings
- name: Rust Tests - name: Rust Tests
run: cd rust && cargo test-all-features run: cd rust && cargo test
# At some point, hook up publishing the binary. For now, just make sure it builds. # At some point, hook up publishing the binary. For now, just make sure it builds.
# Once we're ready to publish binaries, this should be built with `--release`. # Once we're ready to publish binaries, this should be built with `--release`.
- name: Build Bumble CLI - name: Build Bumble CLI
-7
View File
@@ -6,14 +6,7 @@ dist/
docs/mkdocs/site docs/mkdocs/site
test-results.xml test-results.xml
__pycache__ __pycache__
# Vim
.*.sw*
# generated by setuptools_scm # generated by setuptools_scm
bumble/_version.py bumble/_version.py
.vscode/launch.json .vscode/launch.json
.vscode/settings.json
/.idea /.idea
venv/
.venv/
# snoop logs
out/
-16
View File
@@ -1,7 +1,6 @@
{ {
"cSpell.words": [ "cSpell.words": [
"Abortable", "Abortable",
"aiohttp",
"altsetting", "altsetting",
"ansiblue", "ansiblue",
"ansicyan", "ansicyan",
@@ -10,13 +9,10 @@
"ansired", "ansired",
"ansiyellow", "ansiyellow",
"appendleft", "appendleft",
"ascs",
"ASHA", "ASHA",
"asyncio", "asyncio",
"ATRAC", "ATRAC",
"avctp",
"avdtp", "avdtp",
"avrcp",
"bitpool", "bitpool",
"bitstruct", "bitstruct",
"BSCP", "BSCP",
@@ -25,10 +21,7 @@
"cccds", "cccds",
"cmac", "cmac",
"CONNECTIONLESS", "CONNECTIONLESS",
"csip",
"csis",
"csrcs", "csrcs",
"CVSD",
"datagram", "datagram",
"DATALINK", "DATALINK",
"delayreport", "delayreport",
@@ -36,8 +29,6 @@
"deregistration", "deregistration",
"dhkey", "dhkey",
"diversifier", "diversifier",
"endianness",
"ESCO",
"Fitbit", "Fitbit",
"GATTLINK", "GATTLINK",
"HANDSFREE", "HANDSFREE",
@@ -45,17 +36,14 @@
"keyup", "keyup",
"levelname", "levelname",
"libc", "libc",
"liblc",
"libusb", "libusb",
"MITM", "MITM",
"MSBC",
"NDIS", "NDIS",
"netsim", "netsim",
"NONBLOCK", "NONBLOCK",
"NONCONN", "NONCONN",
"OXIMETER", "OXIMETER",
"popleft", "popleft",
"PRAND",
"protobuf", "protobuf",
"psms", "psms",
"pyee", "pyee",
@@ -67,7 +55,6 @@
"SEID", "SEID",
"seids", "seids",
"SERV", "SERV",
"SIRK",
"ssrc", "ssrc",
"strerror", "strerror",
"subband", "subband",
@@ -77,11 +64,8 @@
"substates", "substates",
"tobytes", "tobytes",
"tsep", "tsep",
"UNMUTE",
"unmuted",
"usbmodem", "usbmodem",
"vhci", "vhci",
"wasmtime",
"websockets", "websockets",
"xcursor", "xcursor",
"ycursor" "ycursor"
-407
View File
@@ -1,407 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import logging
import os
from typing import cast, Dict, Optional, Tuple
import click
import pyee
from bumble.colors import color
import bumble.company_ids
import bumble.core
import bumble.device
import bumble.gatt
import bumble.hci
import bumble.profiles.bap
import bumble.profiles.pbp
import bumble.transport
import bumble.utils
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
AURACAST_DEFAULT_DEVICE_NAME = "Bumble Auracast"
AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address("F0:F1:F2:F3:F4:F5")
# -----------------------------------------------------------------------------
# Discover Broadcasts
# -----------------------------------------------------------------------------
class BroadcastDiscoverer:
@dataclasses.dataclass
class Broadcast(pyee.EventEmitter):
name: str
sync: bumble.device.PeriodicAdvertisingSync
rssi: int = 0
public_broadcast_announcement: Optional[
bumble.profiles.pbp.PublicBroadcastAnnouncement
] = None
broadcast_audio_announcement: Optional[
bumble.profiles.bap.BroadcastAudioAnnouncement
] = None
basic_audio_announcement: Optional[
bumble.profiles.bap.BasicAudioAnnouncement
] = None
appearance: Optional[bumble.core.Appearance] = None
biginfo: Optional[bumble.device.BIGInfoAdvertisement] = None
manufacturer_data: Optional[Tuple[str, bytes]] = None
def __post_init__(self) -> None:
super().__init__()
self.sync.on('establishment', self.on_sync_establishment)
self.sync.on('loss', self.on_sync_loss)
self.sync.on('periodic_advertisement', self.on_periodic_advertisement)
self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement)
self.establishment_timeout_task = asyncio.create_task(
self.wait_for_establishment()
)
async def wait_for_establishment(self) -> None:
await asyncio.sleep(5.0)
if self.sync.state == bumble.device.PeriodicAdvertisingSync.State.PENDING:
print(
color(
'!!! Periodic advertisement sync not established in time, '
'canceling',
'red',
)
)
await self.sync.terminate()
def update(self, advertisement: bumble.device.Advertisement) -> None:
self.rssi = advertisement.rssi
for service_data in advertisement.data.get_all(
bumble.core.AdvertisingData.SERVICE_DATA
):
assert isinstance(service_data, tuple)
service_uuid, data = service_data
assert isinstance(data, bytes)
if (
service_uuid
== bumble.gatt.GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE
):
self.public_broadcast_announcement = (
bumble.profiles.pbp.PublicBroadcastAnnouncement.from_bytes(data)
)
continue
if (
service_uuid
== bumble.gatt.GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE
):
self.broadcast_audio_announcement = (
bumble.profiles.bap.BroadcastAudioAnnouncement.from_bytes(data)
)
continue
self.appearance = advertisement.data.get( # type: ignore[assignment]
bumble.core.AdvertisingData.APPEARANCE
)
if manufacturer_data := advertisement.data.get(
bumble.core.AdvertisingData.MANUFACTURER_SPECIFIC_DATA
):
assert isinstance(manufacturer_data, tuple)
company_id = cast(int, manufacturer_data[0])
data = cast(bytes, manufacturer_data[1])
self.manufacturer_data = (
bumble.company_ids.COMPANY_IDENTIFIERS.get(
company_id, f'0x{company_id:04X}'
),
data,
)
def print(self) -> None:
print(
color('Broadcast:', 'yellow'),
self.sync.advertiser_address,
color(self.sync.state.name, 'green'),
)
print(f' {color("Name", "cyan")}: {self.name}')
if self.appearance:
print(f' {color("Appearance", "cyan")}: {str(self.appearance)}')
print(f' {color("RSSI", "cyan")}: {self.rssi}')
print(f' {color("SID", "cyan")}: {self.sync.sid}')
if self.manufacturer_data:
print(
f' {color("Manufacturer Data", "cyan")}: '
f'{self.manufacturer_data[0]} -> {self.manufacturer_data[1].hex()}'
)
if self.broadcast_audio_announcement:
print(
f' {color("Broadcast ID", "cyan")}: '
f'{self.broadcast_audio_announcement.broadcast_id}'
)
if self.public_broadcast_announcement:
print(
f' {color("Features", "cyan")}: '
f'{self.public_broadcast_announcement.features}'
)
print(
f' {color("Metadata", "cyan")}: '
f'{self.public_broadcast_announcement.metadata}'
)
if self.basic_audio_announcement:
print(color(' Audio:', 'cyan'))
print(
color(' Presentation Delay:', 'magenta'),
self.basic_audio_announcement.presentation_delay,
)
for subgroup in self.basic_audio_announcement.subgroups:
print(color(' Subgroup:', 'magenta'))
print(color(' Codec ID:', 'yellow'))
print(
color(' Coding Format: ', 'green'),
subgroup.codec_id.coding_format.name,
)
print(
color(' Company ID: ', 'green'),
subgroup.codec_id.company_id,
)
print(
color(' Vendor Specific Codec ID:', 'green'),
subgroup.codec_id.vendor_specific_codec_id,
)
print(
color(' Codec Config:', 'yellow'),
subgroup.codec_specific_configuration,
)
print(color(' Metadata: ', 'yellow'), subgroup.metadata)
for bis in subgroup.bis:
print(color(f' BIS [{bis.index}]:', 'yellow'))
print(
color(' Codec Config:', 'green'),
bis.codec_specific_configuration,
)
if self.biginfo:
print(color(' BIG:', 'cyan'))
print(
color(' Number of BIS:', 'magenta'),
self.biginfo.num_bis,
)
print(
color(' PHY: ', 'magenta'),
self.biginfo.phy.name,
)
print(
color(' Framed: ', 'magenta'),
self.biginfo.framed,
)
print(
color(' Encrypted: ', 'magenta'),
self.biginfo.encrypted,
)
def on_sync_establishment(self) -> None:
self.establishment_timeout_task.cancel()
self.emit('change')
def on_sync_loss(self) -> None:
self.basic_audio_announcement = None
self.biginfo = None
self.emit('change')
def on_periodic_advertisement(
self, advertisement: bumble.device.PeriodicAdvertisement
) -> None:
if advertisement.data is None:
return
for service_data in advertisement.data.get_all(
bumble.core.AdvertisingData.SERVICE_DATA
):
assert isinstance(service_data, tuple)
service_uuid, data = service_data
assert isinstance(data, bytes)
if service_uuid == bumble.gatt.GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE:
self.basic_audio_announcement = (
bumble.profiles.bap.BasicAudioAnnouncement.from_bytes(data)
)
break
self.emit('change')
def on_biginfo_advertisement(
self, advertisement: bumble.device.BIGInfoAdvertisement
) -> None:
self.biginfo = advertisement
self.emit('change')
def __init__(
self,
device: bumble.device.Device,
filter_duplicates: bool,
sync_timeout: float,
):
self.device = device
self.filter_duplicates = filter_duplicates
self.sync_timeout = sync_timeout
self.broadcasts: Dict[bumble.hci.Address, BroadcastDiscoverer.Broadcast] = {}
self.status_message = ''
device.on('advertisement', self.on_advertisement)
async def run(self) -> None:
self.status_message = color('Scanning...', 'green')
await self.device.start_scanning(
active=False,
filter_duplicates=False,
)
def refresh(self) -> None:
# Clear the screen from the top
print('\033[H')
print('\033[0J')
print('\033[H')
# Print the status message
print(self.status_message)
print("==========================================")
# Print all broadcasts
for broadcast in self.broadcasts.values():
broadcast.print()
print('------------------------------------------')
# Clear the screen to the bottom
print('\033[0J')
def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None:
if (
broadcast_name := advertisement.data.get(
bumble.core.AdvertisingData.BROADCAST_NAME
)
) is None:
return
assert isinstance(broadcast_name, str)
if broadcast := self.broadcasts.get(advertisement.address):
broadcast.update(advertisement)
self.refresh()
return
bumble.utils.AsyncRunner.spawn(
self.on_new_broadcast(broadcast_name, advertisement)
)
async def on_new_broadcast(
self, name: str, advertisement: bumble.device.Advertisement
) -> None:
periodic_advertising_sync = await self.device.create_periodic_advertising_sync(
advertiser_address=advertisement.address,
sid=advertisement.sid,
sync_timeout=self.sync_timeout,
filter_duplicates=self.filter_duplicates,
)
broadcast = self.Broadcast(
name,
periodic_advertising_sync,
)
broadcast.on('change', self.refresh)
broadcast.update(advertisement)
self.broadcasts[advertisement.address] = broadcast
periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast))
self.status_message = color(
f'+Found {len(self.broadcasts)} broadcasts', 'green'
)
self.refresh()
def on_broadcast_loss(self, broadcast: Broadcast) -> None:
del self.broadcasts[broadcast.sync.advertiser_address]
bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate())
self.status_message = color(
f'-Found {len(self.broadcasts)} broadcasts', 'green'
)
self.refresh()
async def run_discover_broadcasts(
filter_duplicates: bool, sync_timeout: float, transport: str
) -> None:
async with await bumble.transport.open_transport(transport) as (
hci_source,
hci_sink,
):
device = bumble.device.Device.with_hci(
AURACAST_DEFAULT_DEVICE_NAME,
AURACAST_DEFAULT_DEVICE_ADDRESS,
hci_source,
hci_sink,
)
await device.power_on()
discoverer = BroadcastDiscoverer(device, filter_duplicates, sync_timeout)
await discoverer.run()
await hci_source.terminated
# -----------------------------------------------------------------------------
# Main
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
def auracast(
ctx,
):
ctx.ensure_object(dict)
@auracast.command('discover-broadcasts')
@click.option(
'--filter-duplicates', is_flag=True, default=False, help='Filter duplicates'
)
@click.option(
'--sync-timeout',
metavar='SYNC_TIMEOUT',
type=float,
default=5.0,
help='Sync timeout (in seconds)',
)
@click.argument('transport')
@click.pass_context
def discover_broadcasts(ctx, filter_duplicates, sync_timeout, transport):
"""Discover public broadcasts"""
asyncio.run(run_discover_broadcasts(filter_duplicates, sync_timeout, transport))
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
auracast()
# -----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter
+169 -708
View File
File diff suppressed because it is too large Load Diff
-63
View File
@@ -1,63 +0,0 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import click
from bumble.colors import color
from bumble.hci import Address
from bumble.helpers import generate_irk, verify_rpa_with_irk
@click.group()
def cli():
'''
This is a tool for generating IRK, RPA,
and verifying IRK/RPA pairs
'''
@click.command()
def gen_irk() -> None:
print(generate_irk().hex())
@click.command()
@click.argument("irk", type=str)
def gen_rpa(irk: str) -> None:
irk_bytes = bytes.fromhex(irk)
rpa = Address.generate_private_address(irk_bytes)
print(rpa.to_string(with_type_qualifier=False))
@click.command()
@click.argument("irk", type=str)
@click.argument("rpa", type=str)
def verify_rpa(irk: str, rpa: str) -> None:
address = Address(rpa)
irk_bytes = bytes.fromhex(irk)
if verify_rpa_with_irk(address, irk_bytes):
print(color("Verified", "green"))
else:
print(color("Not Verified", "red"))
def main():
cli.add_command(gen_irk)
cli.add_command(gen_rpa)
cli.add_command(verify_rpa)
cli()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()
+4 -4
View File
@@ -777,7 +777,7 @@ class ConsoleApp:
if not service: if not service:
continue continue
values = [ values = [
await attribute.read_value(connection) attribute.read_value(connection)
for connection in self.device.connections.values() for connection in self.device.connections.values()
] ]
if not values: if not values:
@@ -796,11 +796,11 @@ class ConsoleApp:
if not characteristic: if not characteristic:
continue continue
values = [ values = [
await attribute.read_value(connection) attribute.read_value(connection)
for connection in self.device.connections.values() for connection in self.device.connections.values()
] ]
if not values: if not values:
values = [await attribute.read_value(None)] values = [attribute.read_value(None)]
# TODO: future optimization: convert CCCD value to human readable string # TODO: future optimization: convert CCCD value to human readable string
@@ -944,7 +944,7 @@ class ConsoleApp:
# send data to any subscribers # send data to any subscribers
if isinstance(attribute, Characteristic): if isinstance(attribute, Characteristic):
await attribute.write_value(None, value) attribute.write_value(None, value)
if attribute.has_properties(Characteristic.NOTIFY): if attribute.has_properties(Characteristic.NOTIFY):
await self.device.gatt_server.notify_subscribers(attribute) await self.device.gatt_server.notify_subscribers(attribute)
if attribute.has_properties(Characteristic.INDICATE): if attribute.has_properties(Characteristic.INDICATE):
+9 -81
View File
@@ -18,39 +18,30 @@
import asyncio import asyncio
import os import os
import logging import logging
import time
import click import click
from bumble.company_ids import COMPANY_IDENTIFIERS from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.colors import color from bumble.colors import color
from bumble.core import name_or_number from bumble.core import name_or_number
from bumble.hci import ( from bumble.hci import (
map_null_terminated_utf8_string, map_null_terminated_utf8_string,
LeFeature,
HCI_SUCCESS, HCI_SUCCESS,
HCI_LE_SUPPORTED_FEATURES_NAMES,
HCI_VERSION_NAMES, HCI_VERSION_NAMES,
LMP_VERSION_NAMES, LMP_VERSION_NAMES,
HCI_Command, HCI_Command,
HCI_Command_Complete_Event, HCI_Command_Complete_Event,
HCI_Command_Status_Event, HCI_Command_Status_Event,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_Read_Buffer_Size_Command,
HCI_READ_BD_ADDR_COMMAND, HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command, HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND, HCI_READ_LOCAL_NAME_COMMAND,
HCI_Read_Local_Name_Command, HCI_Read_Local_Name_Command,
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND, HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Data_Length_Command, HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND, HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command, HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND, HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command, HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_Read_Local_Version_Information_Command,
) )
from bumble.host import Host from bumble.host import Host
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
@@ -66,7 +57,7 @@ def command_succeeded(response):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_classic_info(host: Host) -> None: async def get_classic_info(host):
if host.supports_command(HCI_READ_BD_ADDR_COMMAND): if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command()) response = await host.send_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response): if command_succeeded(response):
@@ -87,7 +78,7 @@ async def get_classic_info(host: Host) -> None:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_le_info(host: Host) -> None: async def get_le_info(host):
print() print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND): if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
@@ -126,50 +117,13 @@ async def get_le_info(host: Host) -> None:
'\n', '\n',
) )
if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response = await host.send_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
if command_succeeded(response):
print(
color('Suggested Default Data Length:', 'yellow'),
f'{response.return_parameters.suggested_max_tx_octets}/'
f'{response.return_parameters.suggested_max_tx_time}',
'\n',
)
print(color('LE Features:', 'yellow')) print(color('LE Features:', 'yellow'))
for feature in host.supported_le_features: for feature in host.supported_le_features:
print(f' {LeFeature(feature).name}') print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_acl_flow_control_info(host: Host) -> None: async def async_main(transport):
print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
print(
color('ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_acl_data_packets} '
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
)
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
)
# -----------------------------------------------------------------------------
async def async_main(latency_probes, transport):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(transport) as (hci_source, hci_sink): async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
@@ -177,23 +131,6 @@ async def async_main(latency_probes, transport):
host = Host(hci_source, hci_sink) host = Host(hci_source, hci_sink)
await host.reset() await host.reset()
# Measure the latency if requested
latencies = []
if latency_probes:
for _ in range(latency_probes):
start = time.time()
await host.send_command(HCI_Read_Local_Version_Information_Command())
latencies.append(1000 * (time.time() - start))
print(
color('HCI Command Latency:', 'yellow'),
(
f'min={min(latencies):.2f}, '
f'max={max(latencies):.2f}, '
f'average={sum(latencies)/len(latencies):.2f}'
),
'\n',
)
# Print version # Print version
print(color('Version:', 'yellow')) print(color('Version:', 'yellow'))
print( print(
@@ -217,28 +154,19 @@ async def async_main(latency_probes, transport):
# Get the LE info # Get the LE info
await get_le_info(host) await get_le_info(host)
# Print the ACL flow control info
await get_acl_flow_control_info(host)
# Print the list of commands supported by the controller # Print the list of commands supported by the controller
print() print()
print(color('Supported Commands:', 'yellow')) print(color('Supported Commands:', 'yellow'))
for command in host.supported_commands: for command in host.supported_commands:
print(f' {HCI_Command.command_name(command)}') print(' ', HCI_Command.command_name(command))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@click.command() @click.command()
@click.option(
'--latency-probes',
metavar='N',
type=int,
help='Send N commands to measure HCI transport latency statistics',
)
@click.argument('transport') @click.argument('transport')
def main(latency_probes, transport): def main(transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(async_main(latency_probes, transport)) asyncio.run(async_main(transport))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
-205
View File
@@ -1,205 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import time
from typing import Optional
from bumble.colors import color
from bumble.hci import (
HCI_READ_LOOPBACK_MODE_COMMAND,
HCI_Read_Loopback_Mode_Command,
HCI_WRITE_LOOPBACK_MODE_COMMAND,
HCI_Write_Loopback_Mode_Command,
LoopbackMode,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
import click
class Loopback:
"""Send and receive ACL data packets in local loopback mode"""
def __init__(self, packet_size: int, packet_count: int, transport: str):
self.transport = transport
self.packet_size = packet_size
self.packet_count = packet_count
self.connection_handle: Optional[int] = None
self.connection_event = asyncio.Event()
self.done = asyncio.Event()
self.expected_cid = 0
self.bytes_received = 0
self.start_timestamp = 0.0
self.last_timestamp = 0.0
def on_connection(self, connection_handle: int, *args):
"""Retrieve connection handle from new connection event"""
if not self.connection_event.is_set():
# save first connection handle for ACL
# subsequent connections are SCO
self.connection_handle = connection_handle
self.connection_event.set()
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
"""Calculate packet receive speed"""
now = time.time()
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
assert connection_handle == self.connection_handle
assert cid == self.expected_cid
self.expected_cid += 1
if cid == 0:
self.start_timestamp = now
else:
elapsed_since_start = now - self.start_timestamp
elapsed_since_last = now - self.last_timestamp
self.bytes_received += len(pdu)
instant_rx_speed = len(pdu) / elapsed_since_last
average_rx_speed = self.bytes_received / elapsed_since_start
print(
color(
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f' average={average_rx_speed:.4f}',
'cyan',
)
)
self.last_timestamp = now
if self.expected_cid == self.packet_count:
print(color('@@@ Received last packet', 'green'))
self.done.set()
async def run(self):
"""Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green'))
async with await open_transport_or_link(self.transport) as (
hci_source,
hci_sink,
):
print(color('>>> Connected', 'green'))
host = Host(hci_source, hci_sink)
await host.reset()
# make sure data can fit in one l2cap pdu
l2cap_header_size = 4
max_packet_size = (
host.acl_packet_queue
if host.acl_packet_queue
else host.le_acl_packet_queue
).max_packet_size - l2cap_header_size
if self.packet_size > max_packet_size:
print(
color(
f'!!! Packet size ({self.packet_size}) larger than max supported'
f' size ({max_packet_size})',
'red',
)
)
return
if not host.supports_command(
HCI_WRITE_LOOPBACK_MODE_COMMAND
) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
print(color('!!! Loopback mode not supported', 'red'))
return
# set event callbacks
host.on('connection', self.on_connection)
host.on('l2cap_pdu', self.on_l2cap_pdu)
loopback_mode = LoopbackMode.LOCAL
print(color('### Setting loopback mode', 'blue'))
await host.send_command(
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
check_result=True,
)
print(color('### Checking loopback mode', 'blue'))
response = await host.send_command(
HCI_Read_Loopback_Mode_Command(), check_result=True
)
if response.return_parameters.loopback_mode != loopback_mode:
print(color('!!! Loopback mode mismatch', 'red'))
return
await self.connection_event.wait()
print(color('### Connected', 'cyan'))
print(color('=== Start sending', 'magenta'))
start_time = time.time()
bytes_sent = 0
for cid in range(0, self.packet_count):
# using the cid as an incremental index
host.send_l2cap_pdu(
self.connection_handle, cid, bytes(self.packet_size)
)
print(
color(
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
)
)
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
await asyncio.sleep(0) # yield to allow packet receive
await self.done.wait()
print(color('=== Done!', 'magenta'))
elapsed = time.time() - start_time
average_tx_speed = bytes_sent / elapsed
print(
color(
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
f' in {elapsed:.2f} seconds)',
'green',
)
)
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--packet-size',
'-s',
metavar='SIZE',
type=click.IntRange(8, 4096),
default=500,
help='Packet size',
)
@click.option(
'--packet-count',
'-c',
metavar='COUNT',
type=click.IntRange(1, 65535),
default=10,
help='Packet count',
)
@click.argument('transport')
def main(packet_size, packet_count, transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
loopback = Loopback(packet_size, packet_count, transport)
asyncio.run(loopback.run())
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()
+24 -32
View File
@@ -49,16 +49,14 @@ class ServerBridge:
self.tcp_port = tcp_port self.tcp_port = tcp_port
async def start(self, device: Device) -> None: async def start(self, device: Device) -> None:
# Listen for incoming L2CAP channel connections # Listen for incoming L2CAP CoC connections
device.create_l2cap_server( device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec( spec=l2cap.LeCreditBasedChannelSpec(
psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits
), ),
handler=self.on_channel, handler=self.on_coc,
)
print(
color(f'### Listening for channel connection on PSM {self.psm}', 'yellow')
) )
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
def on_ble_connection(connection): def on_ble_connection(connection):
def on_ble_disconnection(reason): def on_ble_disconnection(reason):
@@ -75,7 +73,7 @@ class ServerBridge:
await device.start_advertising(auto_restart=True) await device.start_advertising(auto_restart=True)
# Called when a new L2CAP connection is established # Called when a new L2CAP connection is established
def on_channel(self, l2cap_channel): def on_coc(self, l2cap_channel):
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel) print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
class Pipe: class Pipe:
@@ -85,7 +83,7 @@ class ServerBridge:
self.l2cap_channel = l2cap_channel self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close) l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_channel_sdu l2cap_channel.sink = self.on_coc_sdu
async def connect_to_tcp(self): async def connect_to_tcp(self):
# Connect to the TCP server # Connect to the TCP server
@@ -130,7 +128,7 @@ class ServerBridge:
if self.tcp_transport is not None: if self.tcp_transport is not None:
self.tcp_transport.close() self.tcp_transport.close()
def on_channel_sdu(self, sdu): def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan')) print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
if self.tcp_transport is None: if self.tcp_transport is None:
print(color('!!! TCP socket not open, dropping', 'red')) print(color('!!! TCP socket not open, dropping', 'red'))
@@ -185,7 +183,7 @@ class ClientBridge:
peer_name = writer.get_extra_info('peer_name') peer_name = writer.get_extra_info('peer_name')
print(color(f'<<< TCP connection from {peer_name}', 'magenta')) print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
def on_channel_sdu(sdu): def on_coc_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan')) print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
l2cap_to_tcp_pipe.write(sdu) l2cap_to_tcp_pipe.write(sdu)
@@ -211,7 +209,7 @@ class ClientBridge:
writer.close() writer.close()
return return
l2cap_channel.sink = on_channel_sdu l2cap_channel.sink = on_coc_sdu
l2cap_channel.on('close', on_l2cap_close) l2cap_channel.on('close', on_l2cap_close)
# Start a flow control pipe from L2CAP to TCP # Start a flow control pipe from L2CAP to TCP
@@ -276,29 +274,23 @@ async def run(device_config, hci_transport, bridge):
@click.pass_context @click.pass_context
@click.option('--device-config', help='Device configuration file', required=True) @click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True) @click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP', type=int, default=1234) @click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
@click.option( @click.option(
'--l2cap-max-credits', '--l2cap-coc-max-credits',
help='Maximum L2CAP Credits', help='Maximum L2CAP CoC Credits',
type=click.IntRange(1, 65535), type=click.IntRange(1, 65535),
default=128, default=128,
) )
@click.option( @click.option(
'--l2cap-mtu', '--l2cap-coc-mtu',
help='L2CAP MTU', help='L2CAP CoC MTU',
type=click.IntRange( type=click.IntRange(23, 65535),
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU, default=1022,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU,
),
default=1024,
) )
@click.option( @click.option(
'--l2cap-mps', '--l2cap-coc-mps',
help='L2CAP MPS', help='L2CAP CoC MPS',
type=click.IntRange( type=click.IntRange(23, 65533),
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS,
),
default=1024, default=1024,
) )
def cli( def cli(
@@ -306,17 +298,17 @@ def cli(
device_config, device_config,
hci_transport, hci_transport,
psm, psm,
l2cap_max_credits, l2cap_coc_max_credits,
l2cap_mtu, l2cap_coc_mtu,
l2cap_mps, l2cap_coc_mps,
): ):
context.ensure_object(dict) context.ensure_object(dict)
context.obj['device_config'] = device_config context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport context.obj['hci_transport'] = hci_transport
context.obj['psm'] = psm context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_max_credits context.obj['max_credits'] = l2cap_coc_max_credits
context.obj['mtu'] = l2cap_mtu context.obj['mtu'] = l2cap_coc_mtu
context.obj['mps'] = l2cap_mps context.obj['mps'] = l2cap_coc_mps
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
-577
View File
@@ -1,577 +0,0 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import datetime
import enum
import functools
from importlib import resources
import json
import os
import logging
import pathlib
from typing import Optional, List, cast
import weakref
import struct
import ctypes
import wasmtime
import wasmtime.loader
import liblc3 # type: ignore
import logging
import click
import aiohttp.web
import bumble
from bumble.core import AdvertisingData
from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, AdvertisingParameters
from bumble.transport import open_transport
from bumble.profiles import bap
from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654
def _sink_pac_record() -> bap.PacRecord:
return bap.PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=(
bap.SupportedSamplingFrequency.FREQ_8000
| bap.SupportedSamplingFrequency.FREQ_16000
| bap.SupportedSamplingFrequency.FREQ_24000
| bap.SupportedSamplingFrequency.FREQ_32000
| bap.SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_count=[1, 2],
min_octets_per_codec_frame=26,
max_octets_per_codec_frame=240,
supported_max_codec_frames_per_sdu=2,
),
)
def _source_pac_record() -> bap.PacRecord:
return bap.PacRecord(
coding_format=CodingFormat(CodecID.LC3),
codec_specific_capabilities=bap.CodecSpecificCapabilities(
supported_sampling_frequencies=(
bap.SupportedSamplingFrequency.FREQ_8000
| bap.SupportedSamplingFrequency.FREQ_16000
| bap.SupportedSamplingFrequency.FREQ_24000
| bap.SupportedSamplingFrequency.FREQ_32000
| bap.SupportedSamplingFrequency.FREQ_48000
),
supported_frame_durations=(
bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED
),
supported_audio_channel_count=[1],
min_octets_per_codec_frame=30,
max_octets_per_codec_frame=100,
supported_max_codec_frames_per_sdu=1,
),
)
# -----------------------------------------------------------------------------
# WASM - liblc3
# -----------------------------------------------------------------------------
store = wasmtime.loader.store
_memory = cast(wasmtime.Memory, liblc3.memory)
STACK_POINTER = _memory.data_len(store)
_memory.grow(store, 1)
# Mapping wasmtime memory to linear address
memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address(
ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore
)
class Liblc3PcmFormat(enum.IntEnum):
S16 = 0
S24 = 1
S24_3LE = 2
FLOAT = 3
MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000)
MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000)
DECODER_STACK_POINTER = STACK_POINTER
ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2
DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2
ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192
DEFAULT_PCM_SAMPLE_RATE = 48000
DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16
DEFAULT_PCM_BYTES_PER_SAMPLE = 2
encoders: List[int] = []
decoders: List[int] = []
def setup_encoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
encoders[:num_channels] = [
liblc3.lc3_setup_encoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Input sample rate
ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i,
)
for i in range(num_channels)
]
def setup_decoders(
sample_rate_hz: int, frame_duration_us: int, num_channels: int
) -> None:
logger.info(
f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels"
)
decoders[:num_channels] = [
liblc3.lc3_setup_decoder(
frame_duration_us,
sample_rate_hz,
DEFAULT_PCM_SAMPLE_RATE, # Output sample rate
DECODER_STACK_POINTER + MAX_DECODER_SIZE * i,
)
for i in range(num_channels)
]
def decode(
frame_duration_us: int,
num_channels: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = DECODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
input_bytes_per_frame = input_buffer_size // num_channels
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
* num_channels
)
for i in range(num_channels):
res = liblc3.lc3_decode(
decoders[i],
input_buffer_offset + input_bytes_per_frame * i,
input_bytes_per_frame,
DEFAULT_PCM_FORMAT,
output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE,
num_channels, # Stride
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
def encode(
sdu_length: int,
num_channels: int,
stride: int,
input_bytes: bytes,
) -> bytes:
if not input_bytes:
return b''
input_buffer_offset = ENCODE_BUFFER_STACK_POINTER
input_buffer_size = len(input_bytes)
# Copy into wasm
memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore
output_buffer_offset = input_buffer_offset + input_buffer_size
output_buffer_size = sdu_length
output_frame_size = output_buffer_size // num_channels
for i in range(num_channels):
res = liblc3.lc3_encode(
encoders[i],
DEFAULT_PCM_FORMAT,
input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i,
stride,
output_frame_size,
output_buffer_offset + output_frame_size * i,
)
if res != 0:
logging.error(f"Parsing failed, res={res}")
# Extract decoded data from the output buffer
return bytes(
memory[output_buffer_offset : output_buffer_offset + output_buffer_size]
)
async def lc3_source_task(
filename: str,
sdu_length: int,
frame_duration_us: int,
device: Device,
cis_handle: int,
) -> None:
with open(filename, 'rb') as f:
header = f.read(44)
assert header[8:12] == b'WAVE'
pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = (
struct.unpack("<HIIHH", header[22:36])
)
assert pcm_sample_rate == DEFAULT_PCM_SAMPLE_RATE
assert bits_per_sample == DEFAULT_PCM_BYTES_PER_SAMPLE * 8
frame_bytes = (
liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE)
* DEFAULT_PCM_BYTES_PER_SAMPLE
)
packet_sequence_number = 0
while True:
next_round = datetime.datetime.now() + datetime.timedelta(
microseconds=frame_duration_us
)
pcm_data = f.read(frame_bytes)
sdu = encode(sdu_length, pcm_num_channel, pcm_num_channel, pcm_data)
iso_packet = HCI_IsoDataPacket(
connection_handle=cis_handle,
data_total_length=sdu_length + 4,
packet_sequence_number=packet_sequence_number,
pb_flag=0b10,
packet_status_flag=0,
iso_sdu_length=sdu_length,
iso_sdu_fragment=sdu,
)
device.host.send_hci_packet(iso_packet)
packet_sequence_number += 1
sleep_time = next_round - datetime.datetime.now()
await asyncio.sleep(sleep_time.total_seconds())
# -----------------------------------------------------------------------------
class UiServer:
speaker: weakref.ReferenceType[Speaker]
port: int
def __init__(self, speaker: Speaker, port: int) -> None:
self.speaker = weakref.ref(speaker)
self.port = port
self.channel_socket = None
async def start_http(self) -> None:
"""Start the UI HTTP server."""
app = aiohttp.web.Application()
app.add_routes(
[
aiohttp.web.get('/', self.get_static),
aiohttp.web.get('/index.html', self.get_static),
aiohttp.web.get('/channel', self.get_channel),
]
)
runner = aiohttp.web.AppRunner(app)
await runner.setup()
site = aiohttp.web.TCPSite(runner, 'localhost', self.port)
print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green'))
await site.start()
async def get_static(self, request):
path = request.path
if path == '/':
path = '/index.html'
if path.endswith('.html'):
content_type = 'text/html'
elif path.endswith('.js'):
content_type = 'text/javascript'
elif path.endswith('.css'):
content_type = 'text/css'
elif path.endswith('.svg'):
content_type = 'image/svg+xml'
else:
content_type = 'text/plain'
text = (
resources.files("bumble.apps.lea_unicast")
.joinpath(pathlib.Path(path).relative_to('/'))
.read_text(encoding="utf-8")
)
return aiohttp.web.Response(text=text, content_type=content_type)
async def get_channel(self, request):
ws = aiohttp.web.WebSocketResponse()
await ws.prepare(request)
# Process messages until the socket is closed.
self.channel_socket = ws
async for message in ws:
if message.type == aiohttp.WSMsgType.TEXT:
logger.debug(f'<<< received message: {message.data}')
await self.on_message(message.data)
elif message.type == aiohttp.WSMsgType.ERROR:
logger.debug(
f'channel connection closed with exception {ws.exception()}'
)
self.channel_socket = None
logger.debug('--- channel connection closed')
return ws
async def on_message(self, message_str: str):
# Parse the message as JSON
message = json.loads(message_str)
# Dispatch the message
message_type = message['type']
message_params = message.get('params', {})
handler = getattr(self, f'on_{message_type}_message')
if handler:
await handler(**message_params)
async def on_hello_message(self):
await self.send_message(
'hello',
bumble_version=bumble.__version__,
codec=self.speaker().codec,
streamState=self.speaker().stream_state.name,
)
if connection := self.speaker().connection:
await self.send_message(
'connection',
peer_address=connection.peer_address.to_string(False),
peer_name=connection.peer_name,
)
async def send_message(self, message_type: str, **kwargs) -> None:
if self.channel_socket is None:
return
message = {'type': message_type, 'params': kwargs}
await self.channel_socket.send_json(message)
async def send_audio(self, data: bytes) -> None:
if self.channel_socket is None:
return
try:
await self.channel_socket.send_bytes(data)
except Exception as error:
logger.warning(f'exception while sending audio packet: {error}')
# -----------------------------------------------------------------------------
class Speaker:
def __init__(
self,
device_config_path: Optional[str],
ui_port: int,
transport: str,
lc3_input_file_path: str,
):
self.device_config_path = device_config_path
self.transport = transport
self.lc3_input_file_path = lc3_input_file_path
# Create an HTTP server for the UI
self.ui_server = UiServer(speaker=self, port=ui_port)
async def run(self) -> None:
await self.ui_server.start_http()
async with await open_transport(self.transport) as hci_transport:
# Create a device
if self.device_config_path:
device_config = DeviceConfiguration.from_file(self.device_config_path)
else:
device_config = DeviceConfiguration(
name="Bumble LE Headphone",
class_of_device=0x244418,
keystore="JsonKeyStore",
advertising_interval_min=25,
advertising_interval_max=25,
address=Address('F1:F2:F3:F4:F5:F6'),
)
device_config.le_enabled = True
device_config.cis_enabled = True
self.device = Device.from_config_with_hci(
device_config, hci_transport.source, hci_transport.sink
)
self.device.add_service(
bap.PublishedAudioCapabilitiesService(
supported_source_context=bap.ContextType(0xFFFF),
available_source_context=bap.ContextType(0xFFFF),
supported_sink_context=bap.ContextType(0xFFFF), # All context types
available_sink_context=bap.ContextType(0xFFFF), # All context types
sink_audio_locations=(
bap.AudioLocation.FRONT_LEFT | bap.AudioLocation.FRONT_RIGHT
),
sink_pac=[_sink_pac_record()],
source_audio_locations=bap.AudioLocation.FRONT_LEFT,
source_pac=[_source_pac_record()],
)
)
ascs = bap.AudioStreamControlService(
self.device, sink_ase_id=[1], source_ase_id=[2]
)
self.device.add_service(ascs)
advertising_data = bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes(device_config.name, 'utf-8'),
),
(
AdvertisingData.FLAGS,
bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(bap.PublishedAudioCapabilitiesService.UUID),
),
]
)
) + bytes(bap.UnicastServerAdvertisingData())
def on_pdu(pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine):
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
pcm = decode(
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
pdu.iso_sdu_fragment,
)
self.device.abort_on('disconnection', self.ui_server.send_audio(pcm))
def on_ase_state_change(ase: bap.AseStateMachine) -> None:
if ase.state == bap.AseStateMachine.State.STREAMING:
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
assert ase.cis_link
if ase.role == bap.AudioRole.SOURCE:
ase.cis_link.abort_on(
'disconnection',
lc3_source_task(
filename=self.lc3_input_file_path,
sdu_length=(
codec_config.codec_frames_per_sdu
* codec_config.octets_per_codec_frame
),
frame_duration_us=codec_config.frame_duration.us,
device=self.device,
cis_handle=ase.cis_link.handle,
),
)
else:
ase.cis_link.sink = functools.partial(on_pdu, ase=ase)
elif ase.state == bap.AseStateMachine.State.CODEC_CONFIGURED:
codec_config = ase.codec_specific_configuration
assert isinstance(codec_config, bap.CodecSpecificConfiguration)
if ase.role == bap.AudioRole.SOURCE:
setup_encoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
)
else:
setup_decoders(
codec_config.sampling_frequency.hz,
codec_config.frame_duration.us,
codec_config.audio_channel_allocation.channel_count,
)
for ase in ascs.ase_state_machines.values():
ase.on('state_change', functools.partial(on_ase_state_change, ase=ase))
await self.device.power_on()
await self.device.create_advertising_set(
advertising_data=advertising_data,
auto_restart=True,
advertising_parameters=AdvertisingParameters(
primary_advertising_interval_min=100,
primary_advertising_interval_max=100,
),
)
await hci_transport.source.terminated
@click.command()
@click.option(
'--ui-port',
'ui_port',
metavar='HTTP_PORT',
default=DEFAULT_UI_PORT,
show_default=True,
help='HTTP port for the UI server',
)
@click.option('--device-config', metavar='FILENAME', help='Device configuration file')
@click.argument('transport')
@click.argument('lc3_file')
def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) -> None:
"""Run the speaker."""
asyncio.run(Speaker(device_config, ui_port, transport, lc3_file).run())
# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
speaker()
# -----------------------------------------------------------------------------
if __name__ == "__main__":
main() # pylint: disable=no-value-for-parameter
-68
View File
@@ -1,68 +0,0 @@
<html data-bs-theme="dark">
<head>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
<script src="https://unpkg.com/pcm-player"></script>
</head>
<body>
<nav class="navbar navbar-dark bg-primary">
<div class="container">
<span class="navbar-brand mb-0 h1">Bumble Unicast Server</span>
</div>
</nav>
<br>
<div class="container">
<button type="button" class="btn btn-danger" id="connect-audio" onclick="connectAudio()">Connect Audio</button>
<button class="btn btn-primary" type="button" disabled>
<span class="spinner-border spinner-border-sm" id="ws-status-spinner" aria-hidden="true"></span>
<span role="status" id="ws-status">WebSocket Connecting...</span>
</button>
</div>
<script>
let player = null;
const wsStatus = document.getElementById("ws-status");
const wsStatusSpinner = document.getElementById("ws-status-spinner");
const socket = new WebSocket('ws://127.0.0.1:7654/channel');
socket.binaryType = "arraybuffer";
socket.onmessage = function (message) {
if (typeof message.data === 'string' || message.data instanceof String) {
console.log(`channel MESSAGE: ${message.data}`);
} else {
console.log(typeof (message.data))
// BINARY audio data.
if (player == null) return;
player.feed(message.data);
}
};
socket.onopen = (message) => {
wsStatusSpinner.remove();
wsStatus.textContent = "WebSocket Connected";
}
socket.onclose = (message) => {
wsStatus.textContent = "WebSocket Disconnected";
}
function connectAudio() {
player = new PCMPlayer({
inputCodec: 'Int16',
channels: 2,
sampleRate: 48000,
flushTime: 10,
});
const button = document.getElementById("connect-audio")
button.disabled = true;
button.textContent = "Audio Connected";
}
</script>
</div>
</body>
</html>
Binary file not shown.
+10 -67
View File
@@ -24,16 +24,10 @@ from prompt_toolkit.shortcuts import PromptSession
from bumble.colors import color from bumble.colors import color
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
from bumble.pairing import OobData, PairingDelegate, PairingConfig from bumble.pairing import PairingDelegate, PairingConfig
from bumble.smp import OobContext, OobLegacyContext
from bumble.smp import error_name as smp_error_name from bumble.smp import error_name as smp_error_name
from bumble.keys import JsonKeyStore from bumble.keys import JsonKeyStore
from bumble.core import ( from bumble.core import ProtocolError
AdvertisingData,
ProtocolError,
BT_LE_TRANSPORT,
BT_BR_EDR_TRANSPORT,
)
from bumble.gatt import ( from bumble.gatt import (
GATT_DEVICE_NAME_CHARACTERISTIC, GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_GENERIC_ACCESS_SERVICE, GATT_GENERIC_ACCESS_SERVICE,
@@ -52,13 +46,11 @@ from bumble.att import (
class Waiter: class Waiter:
instance = None instance = None
def __init__(self, linger=False): def __init__(self):
self.done = asyncio.get_running_loop().create_future() self.done = asyncio.get_running_loop().create_future()
self.linger = linger
def terminate(self): def terminate(self):
if not self.linger: self.done.set_result(None)
self.done.set_result(None)
async def wait_until_terminated(self): async def wait_until_terminated(self):
return await self.done return await self.done
@@ -68,7 +60,7 @@ class Waiter:
class Delegate(PairingDelegate): class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, do_prompt): def __init__(self, mode, connection, capability_string, do_prompt):
super().__init__( super().__init__(
io_capability={ {
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY, 'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY, 'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT, 'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
@@ -293,9 +285,7 @@ async def pair(
mitm, mitm,
bond, bond,
ctkd, ctkd,
linger,
io, io,
oob,
prompt, prompt,
request, request,
print_keys, print_keys,
@@ -304,7 +294,7 @@ async def pair(
hci_transport, hci_transport,
address_or_name, address_or_name,
): ):
Waiter.instance = Waiter(linger=linger) Waiter.instance = Waiter()
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink): async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
@@ -353,51 +343,16 @@ async def pair(
await device.keystore.print(prefix=color('@@@ ', 'blue')) await device.keystore.print(prefix=color('@@@ ', 'blue'))
print(color('@@@-----------------------------------', 'blue')) print(color('@@@-----------------------------------', 'blue'))
# Create an OOB context if needed
if oob:
our_oob_context = OobContext()
shared_data = (
None
if oob == '-'
else OobData.from_ad(AdvertisingData.from_bytes(bytes.fromhex(oob)))
)
legacy_context = OobLegacyContext()
oob_contexts = PairingConfig.OobConfig(
our_context=our_oob_context,
peer_data=shared_data,
legacy_context=legacy_context,
)
oob_data = OobData(
address=device.random_address,
shared_data=shared_data,
legacy_context=legacy_context,
)
print(color('@@@-----------------------------------', 'yellow'))
print(color('@@@ OOB Data:', 'yellow'))
print(color(f'@@@ {our_oob_context.share()}', 'yellow'))
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
print(color(f'@@@ HEX: ({bytes(oob_data.to_ad()).hex()})', 'yellow'))
print(color('@@@-----------------------------------', 'yellow'))
else:
oob_contexts = None
# Set up a pairing config factory # Set up a pairing config factory
device.pairing_config_factory = lambda connection: PairingConfig( device.pairing_config_factory = lambda connection: PairingConfig(
sc=sc, sc, mitm, bond, Delegate(mode, connection, io, prompt)
mitm=mitm,
bonding=bond,
oob=oob_contexts,
delegate=Delegate(mode, connection, io, prompt),
) )
# Connect to a peer or wait for a connection # Connect to a peer or wait for a connection
device.on('connection', lambda connection: on_connection(connection, request)) device.on('connection', lambda connection: on_connection(connection, request))
if address_or_name is not None: if address_or_name is not None:
print(color(f'=== Connecting to {address_or_name}...', 'green')) print(color(f'=== Connecting to {address_or_name}...', 'green'))
connection = await device.connect( connection = await device.connect(address_or_name)
address_or_name,
transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
)
if not request: if not request:
try: try:
@@ -405,9 +360,10 @@ async def pair(
await connection.pair() await connection.pair()
else: else:
await connection.authenticate() await connection.authenticate()
return
except ProtocolError as error: except ProtocolError as error:
print(color(f'Pairing failed: {error}', 'red')) print(color(f'Pairing failed: {error}', 'red'))
return
else: else:
if mode == 'le': if mode == 'le':
# Advertise so that peers can find us and connect # Advertise so that peers can find us and connect
@@ -457,7 +413,6 @@ class LogHandler(logging.Handler):
help='Enable CTKD', help='Enable CTKD',
show_default=True, show_default=True,
) )
@click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
@click.option( @click.option(
'--io', '--io',
type=click.Choice( type=click.Choice(
@@ -466,14 +421,6 @@ class LogHandler(logging.Handler):
default='display+keyboard', default='display+keyboard',
show_default=True, show_default=True,
) )
@click.option(
'--oob',
metavar='<oob-data-hex>',
help=(
'Use OOB pairing with this data from the peer '
'(use "-" to enable OOB without peer data)'
),
)
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request') @click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
@click.option( @click.option(
'--request', is_flag=True, help='Request that the connecting peer initiate pairing' '--request', is_flag=True, help='Request that the connecting peer initiate pairing'
@@ -493,9 +440,7 @@ def main(
mitm, mitm,
bond, bond,
ctkd, ctkd,
linger,
io, io,
oob,
prompt, prompt,
request, request,
print_keys, print_keys,
@@ -518,9 +463,7 @@ def main(
mitm, mitm,
bond, bond,
ctkd, ctkd,
linger,
io, io,
oob,
prompt, prompt,
request, request,
print_keys, print_keys,
-511
View File
@@ -1,511 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import time
from typing import Optional
import click
from bumble.colors import color
from bumble.device import Device, DeviceConfiguration, Connection
from bumble import core
from bumble import hci
from bumble import rfcomm
from bumble import transport
from bumble import utils
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEFAULT_RFCOMM_UUID = "E6D55659-C8B4-4B85-96BB-B1143AF6D3AE"
DEFAULT_MTU = 4096
DEFAULT_CLIENT_TCP_PORT = 9544
DEFAULT_SERVER_TCP_PORT = 9545
TRACE_MAX_SIZE = 48
# -----------------------------------------------------------------------------
class Tracer:
"""
Trace data buffers transmitted from one endpoint to another, with stats.
"""
def __init__(self, channel_name: str) -> None:
self.channel_name = channel_name
self.last_ts: float = 0.0
def trace_data(self, data: bytes) -> None:
now = time.time()
elapsed_s = now - self.last_ts if self.last_ts else 0
elapsed_ms = int(elapsed_s * 1000)
instant_throughput_kbps = ((len(data) / elapsed_s) / 1000) if elapsed_s else 0.0
hex_str = data[:TRACE_MAX_SIZE].hex() + (
"..." if len(data) > TRACE_MAX_SIZE else ""
)
print(
f"[{self.channel_name}] {len(data):4} bytes "
f"(+{elapsed_ms:4}ms, {instant_throughput_kbps: 7.2f}kB/s) "
f" {hex_str}"
)
self.last_ts = now
# -----------------------------------------------------------------------------
class ServerBridge:
"""
RFCOMM server bridge: waits for a peer to connect an RFCOMM channel.
The RFCOMM channel may be associated with a UUID published in an SDP service
description, or simply be on a system-assigned channel number.
When the connection is made, the bridge connects a TCP socket to a remote host and
bridges the data in both directions, with flow control.
When the RFCOMM channel is closed, the bridge disconnects the TCP socket
and waits for a new channel to be connected.
"""
READ_CHUNK_SIZE = 4096
def __init__(
self, channel: int, uuid: str, trace: bool, tcp_host: str, tcp_port: int
) -> None:
self.device: Optional[Device] = None
self.channel = channel
self.uuid = uuid
self.tcp_host = tcp_host
self.tcp_port = tcp_port
self.rfcomm_channel: Optional[rfcomm.DLC] = None
self.tcp_tracer: Optional[Tracer]
self.rfcomm_tracer: Optional[Tracer]
if trace:
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
self.rfcomm_tracer = Tracer(color("TCP->RFCOMM", "magenta"))
else:
self.rfcomm_tracer = None
self.tcp_tracer = None
async def start(self, device: Device) -> None:
self.device = device
# Create and register a server
rfcomm_server = rfcomm.Server(self.device)
# Listen for incoming DLC connections
self.channel = rfcomm_server.listen(self.on_rfcomm_channel, self.channel)
# Setup the SDP to advertise this channel
service_record_handle = 0x00010001
self.device.sdp_service_records = {
service_record_handle: rfcomm.make_service_sdp_records(
service_record_handle, self.channel, core.UUID(self.uuid)
)
}
# We're ready for a connection
self.device.on("connection", self.on_connection)
await self.set_available(True)
print(
color(
(
f"### Listening for RFCOMM connection on {device.public_address}, "
f"channel {self.channel}"
),
"yellow",
)
)
async def set_available(self, available: bool):
# Become discoverable and connectable
assert self.device
await self.device.set_connectable(available)
await self.device.set_discoverable(available)
def on_connection(self, connection):
print(color(f"@@@ Bluetooth connection: {connection}", "blue"))
connection.on("disconnection", self.on_disconnection)
# Don't accept new connections until we're disconnected
utils.AsyncRunner.spawn(self.set_available(False))
def on_disconnection(self, reason: int):
print(
color("@@@ Bluetooth disconnection:", "red"),
hci.HCI_Constant.error_name(reason),
)
# We're ready for a new connection
utils.AsyncRunner.spawn(self.set_available(True))
# Called when an RFCOMM channel is established
@utils.AsyncRunner.run_in_task()
async def on_rfcomm_channel(self, rfcomm_channel):
print(color("*** RFCOMM channel:", "cyan"), rfcomm_channel)
# Connect to the TCP server
print(
color(
f"### Connecting to TCP {self.tcp_host}:{self.tcp_port}",
"yellow",
)
)
try:
reader, writer = await asyncio.open_connection(self.tcp_host, self.tcp_port)
except OSError:
print(color("!!! Connection failed", "red"))
await rfcomm_channel.disconnect()
return
# Pipe data from RFCOMM to TCP
def on_rfcomm_channel_closed():
print(color("*** RFCOMM channel closed", "cyan"))
writer.close()
def write_rfcomm_data(data):
if self.rfcomm_tracer:
self.rfcomm_tracer.trace_data(data)
writer.write(data)
rfcomm_channel.sink = write_rfcomm_data
rfcomm_channel.on("close", on_rfcomm_channel_closed)
# Pipe data from TCP to RFCOMM
while True:
try:
data = await reader.read(self.READ_CHUNK_SIZE)
if len(data) == 0:
print(color("### TCP end of stream", "yellow"))
if rfcomm_channel.state == rfcomm.DLC.State.CONNECTED:
await rfcomm_channel.disconnect()
return
if self.tcp_tracer:
self.tcp_tracer.trace_data(data)
rfcomm_channel.write(data)
await rfcomm_channel.drain()
except Exception as error:
print(f"!!! Exception: {error}")
break
writer.close()
await writer.wait_closed()
print(color("~~~ Bye bye", "magenta"))
# -----------------------------------------------------------------------------
class ClientBridge:
"""
RFCOMM client bridge: connects to a BR/EDR device, then waits for an inbound
TCP connection on a specified port number. When a TCP client connects, an
RFCOMM connection to the device is established, and the data is bridged in both
directions, with flow control.
When the TCP connection is closed by the client, the RFCOMM channel is
disconnected, but the connection to the device remains, ready for a new TCP client
to connect.
"""
READ_CHUNK_SIZE = 4096
def __init__(
self,
channel: int,
uuid: str,
trace: bool,
address: str,
tcp_host: str,
tcp_port: int,
encrypt: bool,
):
self.channel = channel
self.uuid = uuid
self.trace = trace
self.address = address
self.tcp_host = tcp_host
self.tcp_port = tcp_port
self.encrypt = encrypt
self.device: Optional[Device] = None
self.connection: Optional[Connection] = None
self.rfcomm_client: Optional[rfcomm.Client]
self.rfcomm_mux: Optional[rfcomm.Multiplexer]
self.tcp_connected: bool = False
self.tcp_tracer: Optional[Tracer]
self.rfcomm_tracer: Optional[Tracer]
if trace:
self.tcp_tracer = Tracer(color("RFCOMM->TCP", "cyan"))
self.rfcomm_tracer = Tracer(color("TCP->RFCOMM", "magenta"))
else:
self.rfcomm_tracer = None
self.tcp_tracer = None
async def connect(self) -> None:
if self.connection:
return
print(color(f"@@@ Connecting to Bluetooth {self.address}", "blue"))
assert self.device
self.connection = await self.device.connect(
self.address, transport=core.BT_BR_EDR_TRANSPORT
)
print(color(f"@@@ Bluetooth connection: {self.connection}", "blue"))
self.connection.on("disconnection", self.on_disconnection)
if self.encrypt:
print(color("@@@ Encrypting Bluetooth connection", "blue"))
await self.connection.encrypt()
print(color("@@@ Bluetooth connection encrypted", "blue"))
self.rfcomm_client = rfcomm.Client(self.connection)
try:
self.rfcomm_mux = await self.rfcomm_client.start()
except BaseException as e:
print(color("!!! Failed to setup RFCOMM connection", "red"), e)
raise
async def start(self, device: Device) -> None:
self.device = device
await device.set_connectable(False)
await device.set_discoverable(False)
# Called when a TCP connection is established
async def on_tcp_connection(reader, writer):
print(color("<<< TCP connection", "magenta"))
if self.tcp_connected:
print(
color("!!! TCP connection already active, rejecting new one", "red")
)
writer.close()
return
self.tcp_connected = True
try:
await self.pipe(reader, writer)
except BaseException as error:
print(color("!!! Exception while piping data:", "red"), error)
return
finally:
writer.close()
await writer.wait_closed()
self.tcp_connected = False
await asyncio.start_server(
on_tcp_connection,
host=self.tcp_host if self.tcp_host != "_" else None,
port=self.tcp_port,
)
print(
color(
f"### Listening for TCP connections on port {self.tcp_port}", "magenta"
)
)
async def pipe(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
# Resolve the channel number from the UUID if needed
if self.channel == 0:
await self.connect()
assert self.connection
channel = await rfcomm.find_rfcomm_channel_with_uuid(
self.connection, self.uuid
)
if channel:
print(color(f"### Found RFCOMM channel {channel}", "yellow"))
else:
print(color(f"!!! RFCOMM channel with UUID {self.uuid} not found"))
return
else:
channel = self.channel
# Connect a new RFCOMM channel
await self.connect()
assert self.rfcomm_mux
print(color(f"*** Opening RFCOMM channel {channel}", "green"))
try:
rfcomm_channel = await self.rfcomm_mux.open_dlc(channel)
print(color(f"*** RFCOMM channel open: {rfcomm_channel}", "green"))
except Exception as error:
print(color(f"!!! RFCOMM open failed: {error}", "red"))
return
# Pipe data from RFCOMM to TCP
def on_rfcomm_channel_closed():
print(color("*** RFCOMM channel closed", "green"))
def write_rfcomm_data(data):
if self.trace:
self.rfcomm_tracer.trace_data(data)
writer.write(data)
rfcomm_channel.on("close", on_rfcomm_channel_closed)
rfcomm_channel.sink = write_rfcomm_data
# Pipe data from TCP to RFCOMM
while True:
try:
data = await reader.read(self.READ_CHUNK_SIZE)
if len(data) == 0:
print(color("### TCP end of stream", "yellow"))
if rfcomm_channel.state == rfcomm.DLC.State.CONNECTED:
await rfcomm_channel.disconnect()
self.tcp_connected = False
return
if self.tcp_tracer:
self.tcp_tracer.trace_data(data)
rfcomm_channel.write(data)
await rfcomm_channel.drain()
except Exception as error:
print(f"!!! Exception: {error}")
break
print(color("~~~ Bye bye", "magenta"))
def on_disconnection(self, reason: int) -> None:
print(
color("@@@ Bluetooth disconnection:", "red"),
hci.HCI_Constant.error_name(reason),
)
self.connection = None
# -----------------------------------------------------------------------------
async def run(device_config, hci_transport, bridge):
print("<<< connecting to HCI...")
async with await transport.open_transport_or_link(hci_transport) as (
hci_source,
hci_sink,
):
print("<<< connected")
if device_config:
device = Device.from_config_file_with_hci(
device_config, hci_source, hci_sink
)
else:
device = Device.from_config_with_hci(
DeviceConfiguration(), hci_source, hci_sink
)
device.classic_enabled = True
# Let's go
await device.power_on()
try:
await bridge.start(device)
# Wait until the transport terminates
await hci_source.wait_for_termination()
except core.ConnectionError as error:
print(color(f"!!! Bluetooth connection failed: {error}", "red"))
except Exception as error:
print(f"Exception while running bridge: {error}")
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
@click.option(
"--device-config",
metavar="CONFIG_FILE",
help="Device configuration file",
)
@click.option(
"--hci-transport", metavar="TRANSPORT_NAME", help="HCI transport", required=True
)
@click.option("--trace", is_flag=True, help="Trace bridged data to stdout")
@click.option(
"--channel",
metavar="CHANNEL_NUMER",
help="RFCOMM channel number",
type=int,
default=0,
)
@click.option(
"--uuid",
metavar="UUID",
help="UUID for the RFCOMM channel",
default=DEFAULT_RFCOMM_UUID,
)
def cli(
context,
device_config,
hci_transport,
trace,
channel,
uuid,
):
context.ensure_object(dict)
context.obj["device_config"] = device_config
context.obj["hci_transport"] = hci_transport
context.obj["trace"] = trace
context.obj["channel"] = channel
context.obj["uuid"] = uuid
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.option("--tcp-host", help="TCP host", default="localhost")
@click.option("--tcp-port", help="TCP port", default=DEFAULT_SERVER_TCP_PORT)
def server(context, tcp_host, tcp_port):
bridge = ServerBridge(
context.obj["channel"],
context.obj["uuid"],
context.obj["trace"],
tcp_host,
tcp_port,
)
asyncio.run(run(context.obj["device_config"], context.obj["hci_transport"], bridge))
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.argument("bluetooth-address")
@click.option("--tcp-host", help="TCP host", default="_")
@click.option("--tcp-port", help="TCP port", default=DEFAULT_CLIENT_TCP_PORT)
@click.option("--encrypt", is_flag=True, help="Encrypt the connection")
def client(context, bluetooth_address, tcp_host, tcp_port, encrypt):
bridge = ClientBridge(
context.obj["channel"],
context.obj["uuid"],
context.obj["trace"],
bluetooth_address,
tcp_host,
tcp_port,
encrypt,
)
asyncio.run(run(context.obj["device_config"], context.obj["hci_transport"], bridge))
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get("BUMBLE_LOGLEVEL", "WARNING").upper())
if __name__ == "__main__":
cli(obj={}) # pylint: disable=no-value-for-parameter
+8 -44
View File
@@ -26,7 +26,7 @@ from bumble.transport import open_transport_or_link
from bumble.keys import JsonKeyStore from bumble.keys import JsonKeyStore
from bumble.smp import AddressResolver from bumble.smp import AddressResolver
from bumble.device import Advertisement from bumble.device import Advertisement
from bumble.hci import Address, HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY from bumble.hci import HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -66,15 +66,10 @@ class AdvertisementPrinter:
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[ address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
address.address_type address.address_type
] ]
if address.address_type in ( if address.is_public:
Address.RANDOM_IDENTITY_ADDRESS, type_color = 'cyan'
Address.PUBLIC_IDENTITY_ADDRESS,
):
type_color = 'yellow'
else: else:
if address.is_public: if address.is_static:
type_color = 'cyan'
elif address.is_static:
type_color = 'green' type_color = 'green'
address_qualifier = '(static)' address_qualifier = '(static)'
elif address.is_resolvable: elif address.is_resolvable:
@@ -121,7 +116,6 @@ async def scan(
phy, phy,
filter_duplicates, filter_duplicates,
raw, raw,
irks,
keystore_file, keystore_file,
device_config, device_config,
transport, transport,
@@ -146,21 +140,9 @@ async def scan(
if device.keystore: if device.keystore:
resolving_keys = await device.keystore.get_resolving_keys() resolving_keys = await device.keystore.get_resolving_keys()
resolver = AddressResolver(resolving_keys)
else: else:
resolving_keys = [] resolver = None
for irk_and_address in irks:
if ':' not in irk_and_address:
raise ValueError('invalid IRK:ADDRESS value')
irk_hex, address_str = irk_and_address.split(':', 1)
resolving_keys.append(
(
bytes.fromhex(irk_hex),
Address(address_str, Address.RANDOM_DEVICE_ADDRESS),
)
)
resolver = AddressResolver(resolving_keys) if resolving_keys else None
printer = AdvertisementPrinter(min_rssi, resolver) printer = AdvertisementPrinter(min_rssi, resolver)
if raw: if raw:
@@ -205,24 +187,8 @@ async def scan(
default=False, default=False,
help='Listen for raw advertising reports instead of processed ones', help='Listen for raw advertising reports instead of processed ones',
) )
@click.option( @click.option('--keystore-file', help='Keystore file to use when resolving addresses')
'--irk', @click.option('--device-config', help='Device config file for the scanning device')
metavar='<IRK_HEX>:<ADDRESS>',
help=(
'Use this IRK for resolving private addresses ' '(may be used more than once)'
),
multiple=True,
)
@click.option(
'--keystore-file',
metavar='FILE_PATH',
help='Keystore file to use when resolving addresses',
)
@click.option(
'--device-config',
metavar='FILE_PATH',
help='Device config file for the scanning device',
)
@click.argument('transport') @click.argument('transport')
def main( def main(
min_rssi, min_rssi,
@@ -232,7 +198,6 @@ def main(
phy, phy,
filter_duplicates, filter_duplicates,
raw, raw,
irk,
keystore_file, keystore_file,
device_config, device_config,
transport, transport,
@@ -247,7 +212,6 @@ def main(
phy, phy,
filter_duplicates, filter_duplicates,
raw, raw,
irk,
keystore_file, keystore_file,
device_config, device_config,
transport, transport,
+15 -60
View File
@@ -15,11 +15,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import datetime
import logging
import os
import struct import struct
import click import click
from bumble.colors import color from bumble.colors import color
@@ -28,14 +24,6 @@ from bumble.transport.common import PacketReader
from bumble.helpers import PacketTracer from bumble.helpers import PacketTracer
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class SnoopPacketReader: class SnoopPacketReader:
''' '''
@@ -48,18 +36,12 @@ class SnoopPacketReader:
DATALINK_BSCP = 1003 DATALINK_BSCP = 1003
DATALINK_H5 = 1004 DATALINK_H5 = 1004
IDENTIFICATION_PATTERN = b'btsnoop\0'
TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1)
TIMESTAMP_DELTA = 0x00E03AB44A676000
ONE_MICROSECOND = datetime.timedelta(microseconds=1)
def __init__(self, source): def __init__(self, source):
self.source = source self.source = source
self.at_end = False
# Read the header # Read the header
identification_pattern = source.read(8) identification_pattern = source.read(8)
if identification_pattern != self.IDENTIFICATION_PATTERN: if identification_pattern.hex().lower() != '6274736e6f6f7000':
raise ValueError( raise ValueError(
'not a valid snoop file, unexpected identification pattern' 'not a valid snoop file, unexpected identification pattern'
) )
@@ -73,32 +55,19 @@ class SnoopPacketReader:
# Read the record header # Read the record header
header = self.source.read(24) header = self.source.read(24)
if len(header) < 24: if len(header) < 24:
self.at_end = True return (0, None)
return (None, 0, None)
# Parse the header
( (
original_length, original_length,
included_length, included_length,
packet_flags, packet_flags,
_cumulative_drops, _cumulative_drops,
timestamp, _timestamp_seconds,
) = struct.unpack('>IIIIQ', header) _timestamp_microsecond,
) = struct.unpack('>IIIIII', header)
# Skip truncated packets # Abort on truncated packets
if original_length != included_length: if original_length != included_length:
print( return (0, None)
color(
f"!!! truncated packet ({included_length}/{original_length})", "red"
)
)
self.source.read(included_length)
return (None, 0, None)
# Convert the timestamp to a datetime object.
ts_dt = self.TIMESTAMP_ANCHOR + datetime.timedelta(
microseconds=timestamp - self.TIMESTAMP_DELTA
)
if self.data_link_type == self.DATALINK_H1: if self.data_link_type == self.DATALINK_H1:
# The packet is un-encapsulated, look at the flags to figure out its type # The packet is un-encapsulated, look at the flags to figure out its type
@@ -120,17 +89,7 @@ class SnoopPacketReader:
bytes([packet_type]) + self.source.read(included_length), bytes([packet_type]) + self.source.read(included_length),
) )
return (ts_dt, packet_flags & 1, self.source.read(included_length)) return (packet_flags & 1, self.source.read(included_length))
# -----------------------------------------------------------------------------
class Printer:
def __init__(self):
self.index = 0
def print(self, message: str) -> None:
self.index += 1
print(f"[{self.index:8}]{message}")
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -163,28 +122,24 @@ def main(format, vendors, filename):
packet_reader = PacketReader(input) packet_reader = PacketReader(input)
def read_next_packet(): def read_next_packet():
return (None, 0, packet_reader.next_packet()) return (0, packet_reader.next_packet())
else: else:
packet_reader = SnoopPacketReader(input) packet_reader = SnoopPacketReader(input)
read_next_packet = packet_reader.next_packet read_next_packet = packet_reader.next_packet
printer = Printer() tracer = PacketTracer(emit_message=print)
tracer = PacketTracer(emit_message=printer.print)
while not packet_reader.at_end: while True:
try: try:
(timestamp, direction, packet) = read_next_packet() (direction, packet) = read_next_packet()
if packet: if packet is None:
tracer.trace(hci.HCI_Packet.from_bytes(packet), direction, timestamp) break
else: tracer.trace(hci.HCI_Packet.from_bytes(packet), direction)
printer.print(color("[TRUNCATED]", "red"))
except Exception as error: except Exception as error:
logger.exception()
print(color(f'!!! {error}', 'red')) print(color(f'!!! {error}', 'red'))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
main() # pylint: disable=no-value-for-parameter main() # pylint: disable=no-value-for-parameter
-1
View File
@@ -76,7 +76,6 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
DEFAULT_UI_PORT = 7654 DEFAULT_UI_PORT = 7654
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AudioExtractor: class AudioExtractor:
@staticmethod @staticmethod
-1
View File
@@ -24,7 +24,6 @@ from bumble.device import Device
from bumble.keys import JsonKeyStore from bumble.keys import JsonKeyStore
from bumble.transport import open_transport from bumble.transport import open_transport
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def unbond_with_keystore(keystore, address): async def unbond_with_keystore(keystore, address):
if address is None: if address is None:
+72 -97
View File
@@ -15,13 +15,9 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import struct import struct
import logging import logging
from collections.abc import AsyncGenerator from collections import namedtuple
from typing import List, Callable, Awaitable
from .company_ids import COMPANY_IDENTIFIERS from .company_ids import COMPANY_IDENTIFIERS
from .sdp import ( from .sdp import (
@@ -184,12 +180,8 @@ def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3))
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence( DataElement.sequence(
[ [
DataElement.sequence( DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
[ DataElement.unsigned_integer_16(version_int),
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
)
] ]
), ),
), ),
@@ -238,12 +230,8 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence( DataElement.sequence(
[ [
DataElement.sequence( DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
[ DataElement.unsigned_integer_16(version_int),
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
)
] ]
), ),
), ),
@@ -251,20 +239,24 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass class SbcMediaCodecInformation(
class SbcMediaCodecInformation: namedtuple(
'SbcMediaCodecInformation',
[
'sampling_frequency',
'channel_mode',
'block_length',
'subbands',
'allocation_method',
'minimum_bitpool_value',
'maximum_bitpool_value',
],
)
):
''' '''
A2DP spec - 4.3.2 Codec Specific Information Elements A2DP spec - 4.3.2 Codec Specific Information Elements
''' '''
sampling_frequency: int
channel_mode: int
block_length: int
subbands: int
allocation_method: int
minimum_bitpool_value: int
maximum_bitpool_value: int
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1} SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
CHANNEL_MODE_BITS = { CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3, SBC_MONO_CHANNEL_MODE: 1 << 3,
@@ -280,7 +272,7 @@ class SbcMediaCodecInformation:
} }
@staticmethod @staticmethod
def from_bytes(data: bytes) -> SbcMediaCodecInformation: def from_bytes(data: bytes) -> 'SbcMediaCodecInformation':
sampling_frequency = (data[0] >> 4) & 0x0F sampling_frequency = (data[0] >> 4) & 0x0F
channel_mode = (data[0] >> 0) & 0x0F channel_mode = (data[0] >> 0) & 0x0F
block_length = (data[1] >> 4) & 0x0F block_length = (data[1] >> 4) & 0x0F
@@ -301,14 +293,14 @@ class SbcMediaCodecInformation:
@classmethod @classmethod
def from_discrete_values( def from_discrete_values(
cls, cls,
sampling_frequency: int, sampling_frequency,
channel_mode: int, channel_mode,
block_length: int, block_length,
subbands: int, subbands,
allocation_method: int, allocation_method,
minimum_bitpool_value: int, minimum_bitpool_value,
maximum_bitpool_value: int, maximum_bitpool_value,
) -> SbcMediaCodecInformation: ):
return SbcMediaCodecInformation( return SbcMediaCodecInformation(
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency], sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode], channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
@@ -322,14 +314,14 @@ class SbcMediaCodecInformation:
@classmethod @classmethod
def from_lists( def from_lists(
cls, cls,
sampling_frequencies: List[int], sampling_frequencies,
channel_modes: List[int], channel_modes,
block_lengths: List[int], block_lengths,
subbands: List[int], subbands,
allocation_methods: List[int], allocation_methods,
minimum_bitpool_value: int, minimum_bitpool_value,
maximum_bitpool_value: int, maximum_bitpool_value,
) -> SbcMediaCodecInformation: ):
return SbcMediaCodecInformation( return SbcMediaCodecInformation(
sampling_frequency=sum( sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
@@ -356,7 +348,7 @@ class SbcMediaCodecInformation:
] ]
) )
def __str__(self) -> str: def __str__(self):
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO'] channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness'] allocation_methods = ['SNR', 'Loudness']
return '\n'.join( return '\n'.join(
@@ -375,19 +367,16 @@ class SbcMediaCodecInformation:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass class AacMediaCodecInformation(
class AacMediaCodecInformation: namedtuple(
'AacMediaCodecInformation',
['object_type', 'sampling_frequency', 'channels', 'rfa', 'vbr', 'bitrate'],
)
):
''' '''
A2DP spec - 4.5.2 Codec Specific Information Elements A2DP spec - 4.5.2 Codec Specific Information Elements
''' '''
object_type: int
sampling_frequency: int
channels: int
rfa: int
vbr: int
bitrate: int
OBJECT_TYPE_BITS = { OBJECT_TYPE_BITS = {
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7, MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6, MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
@@ -411,7 +400,7 @@ class AacMediaCodecInformation:
CHANNELS_BITS = {1: 1 << 1, 2: 1} CHANNELS_BITS = {1: 1 << 1, 2: 1}
@staticmethod @staticmethod
def from_bytes(data: bytes) -> AacMediaCodecInformation: def from_bytes(data: bytes) -> 'AacMediaCodecInformation':
object_type = data[0] object_type = data[0]
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F) sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
channels = (data[2] >> 2) & 0x03 channels = (data[2] >> 2) & 0x03
@@ -424,13 +413,8 @@ class AacMediaCodecInformation:
@classmethod @classmethod
def from_discrete_values( def from_discrete_values(
cls, cls, object_type, sampling_frequency, channels, vbr, bitrate
object_type: int, ):
sampling_frequency: int,
channels: int,
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
return AacMediaCodecInformation( return AacMediaCodecInformation(
object_type=cls.OBJECT_TYPE_BITS[object_type], object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency], sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
@@ -441,14 +425,7 @@ class AacMediaCodecInformation:
) )
@classmethod @classmethod
def from_lists( def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate):
cls,
object_types: List[int],
sampling_frequencies: List[int],
channels: List[int],
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
return AacMediaCodecInformation( return AacMediaCodecInformation(
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types), object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency=sum( sampling_frequency=sum(
@@ -472,7 +449,7 @@ class AacMediaCodecInformation:
] ]
) )
def __str__(self) -> str: def __str__(self):
object_types = [ object_types = [
'MPEG_2_AAC_LC', 'MPEG_2_AAC_LC',
'MPEG_4_AAC_LC', 'MPEG_4_AAC_LC',
@@ -497,26 +474,26 @@ class AacMediaCodecInformation:
) )
@dataclasses.dataclass
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class VendorSpecificMediaCodecInformation: class VendorSpecificMediaCodecInformation:
''' '''
A2DP spec - 4.7.2 Codec Specific Information Elements A2DP spec - 4.7.2 Codec Specific Information Elements
''' '''
vendor_id: int
codec_id: int
value: bytes
@staticmethod @staticmethod
def from_bytes(data: bytes) -> VendorSpecificMediaCodecInformation: def from_bytes(data):
(vendor_id, codec_id) = struct.unpack_from('<IH', data, 0) (vendor_id, codec_id) = struct.unpack_from('<IH', data, 0)
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:]) return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])
def __bytes__(self) -> bytes: def __init__(self, vendor_id, codec_id, value):
self.vendor_id = vendor_id
self.codec_id = codec_id
self.value = value
def __bytes__(self):
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value) return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
def __str__(self) -> str: def __str__(self):
# pylint: disable=line-too-long # pylint: disable=line-too-long
return '\n'.join( return '\n'.join(
[ [
@@ -529,27 +506,29 @@ class VendorSpecificMediaCodecInformation:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass
class SbcFrame: class SbcFrame:
sampling_frequency: int def __init__(
block_count: int self, sampling_frequency, block_count, channel_mode, subband_count, payload
channel_mode: int ):
subband_count: int self.sampling_frequency = sampling_frequency
payload: bytes self.block_count = block_count
self.channel_mode = channel_mode
self.subband_count = subband_count
self.payload = payload
@property @property
def sample_count(self) -> int: def sample_count(self):
return self.subband_count * self.block_count return self.subband_count * self.block_count
@property @property
def bitrate(self) -> int: def bitrate(self):
return 8 * ((len(self.payload) * self.sampling_frequency) // self.sample_count) return 8 * ((len(self.payload) * self.sampling_frequency) // self.sample_count)
@property @property
def duration(self) -> float: def duration(self):
return self.sample_count / self.sampling_frequency return self.sample_count / self.sampling_frequency
def __str__(self) -> str: def __str__(self):
return ( return (
f'SBC(sf={self.sampling_frequency},' f'SBC(sf={self.sampling_frequency},'
f'cm={self.channel_mode},' f'cm={self.channel_mode},'
@@ -561,12 +540,12 @@ class SbcFrame:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class SbcParser: class SbcParser:
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None: def __init__(self, read):
self.read = read self.read = read
@property @property
def frames(self) -> AsyncGenerator[SbcFrame, None]: def frames(self):
async def generate_frames() -> AsyncGenerator[SbcFrame, None]: async def generate_frames():
while True: while True:
# Read 4 bytes of header # Read 4 bytes of header
header = await self.read(4) header = await self.read(4)
@@ -610,9 +589,7 @@ class SbcParser:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class SbcPacketSource: class SbcPacketSource:
def __init__( def __init__(self, read, mtu, codec_capabilities):
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities
) -> None:
self.read = read self.read = read
self.mtu = mtu self.mtu = mtu
self.codec_capabilities = codec_capabilities self.codec_capabilities = codec_capabilities
@@ -652,9 +629,7 @@ class SbcPacketSource:
# Prepare for next packets # Prepare for next packets
sequence_number += 1 sequence_number += 1
sequence_number &= 0xFFFF
timestamp += sum((frame.sample_count for frame in frames)) timestamp += sum((frame.sample_count for frame in frames))
timestamp &= 0xFFFFFFFF
frames = [frame] frames = [frame]
frames_size = len(frame.payload) frames_size = len(frame.payload)
else: else:
+12 -54
View File
@@ -25,21 +25,9 @@
from __future__ import annotations from __future__ import annotations
import enum import enum
import functools import functools
import inspect
import struct import struct
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
Union,
TYPE_CHECKING,
)
from pyee import EventEmitter from pyee import EventEmitter
from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
from bumble.core import UUID, name_or_number, ProtocolError from bumble.core import UUID, name_or_number, ProtocolError
from bumble.hci import HCI_Object, key_with_value from bumble.hci import HCI_Object, key_with_value
@@ -655,7 +643,7 @@ class ATT_Write_Command(ATT_PDU):
@ATT_PDU.subclass( @ATT_PDU.subclass(
[ [
('attribute_handle', HANDLE_FIELD_SPEC), ('attribute_handle', HANDLE_FIELD_SPEC),
('attribute_value', '*'), ('attribute_value', '*')
# ('authentication_signature', 'TODO') # ('authentication_signature', 'TODO')
] ]
) )
@@ -734,38 +722,12 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class AttributeValue: class ConnectionValue(Protocol):
''' def read(self, connection) -> bytes:
Attribute value where reading and/or writing is delegated to functions ...
passed as arguments to the constructor.
'''
def __init__( def write(self, connection, value: bytes) -> None:
self, ...
read: Union[
Callable[[Optional[Connection]], bytes],
Callable[[Optional[Connection]], Awaitable[bytes]],
None,
] = None,
write: Union[
Callable[[Optional[Connection], bytes], None],
Callable[[Optional[Connection], bytes], Awaitable[None]],
None,
] = None,
):
self._read = read
self._write = write
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
return self._read(connection) if self._read else b''
def write(
self, connection: Optional[Connection], value: bytes
) -> Union[Awaitable[None], None]:
if self._write:
return self._write(connection, value)
return None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -808,13 +770,13 @@ class Attribute(EventEmitter):
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
value: Union[bytes, AttributeValue] value: Union[str, bytes, ConnectionValue]
def __init__( def __init__(
self, self,
attribute_type: Union[str, bytes, UUID], attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions], permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, AttributeValue] = b'', value: Union[str, bytes, ConnectionValue] = b'',
) -> None: ) -> None:
EventEmitter.__init__(self) EventEmitter.__init__(self)
self.handle = 0 self.handle = 0
@@ -844,7 +806,7 @@ class Attribute(EventEmitter):
def decode_value(self, value_bytes: bytes) -> Any: def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes return value_bytes
async def read_value(self, connection: Optional[Connection]) -> bytes: def read_value(self, connection: Optional[Connection]) -> bytes:
if ( if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION) (self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None and connection is not None
@@ -870,8 +832,6 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'read'): if hasattr(self.value, 'read'):
try: try:
value = self.value.read(connection) value = self.value.read(connection)
if inspect.isawaitable(value):
value = await value
except ATT_Error as error: except ATT_Error as error:
raise ATT_Error( raise ATT_Error(
error_code=error.error_code, att_handle=self.handle error_code=error.error_code, att_handle=self.handle
@@ -881,7 +841,7 @@ class Attribute(EventEmitter):
return self.encode_value(value) return self.encode_value(value)
async def write_value(self, connection: Connection, value_bytes: bytes) -> None: def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if ( if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption: ) and not connection.encryption:
@@ -904,9 +864,7 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'write'): if hasattr(self.value, 'write'):
try: try:
result = self.value.write(connection, value) self.value.write(connection, value) # pylint: disable=not-callable
if inspect.isawaitable(result):
await result
except ATT_Error as error: except ATT_Error as error:
raise ATT_Error( raise ATT_Error(
error_code=error.error_code, att_handle=self.handle error_code=error.error_code, att_handle=self.handle
-520
View File
@@ -1,520 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import struct
from typing import Dict, Type, Union, Tuple
from bumble.utils import OpenIntEnum
# -----------------------------------------------------------------------------
class Frame:
class SubunitType(enum.IntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.4
MONITOR = 0x00
AUDIO = 0x01
PRINTER = 0x02
DISC = 0x03
TAPE_RECORDER_OR_PLAYER = 0x04
TUNER = 0x05
CA = 0x06
CAMERA = 0x07
PANEL = 0x09
BULLETIN_BOARD = 0x0A
VENDOR_UNIQUE = 0x1C
EXTENDED = 0x1E
UNIT = 0x1F
class OperationCode(OpenIntEnum):
# 0x00 - 0x0F: Unit and subunit commands
VENDOR_DEPENDENT = 0x00
RESERVE = 0x01
PLUG_INFO = 0x02
# 0x10 - 0x3F: Unit commands
DIGITAL_OUTPUT = 0x10
DIGITAL_INPUT = 0x11
CHANNEL_USAGE = 0x12
OUTPUT_PLUG_SIGNAL_FORMAT = 0x18
INPUT_PLUG_SIGNAL_FORMAT = 0x19
GENERAL_BUS_SETUP = 0x1F
CONNECT_AV = 0x20
DISCONNECT_AV = 0x21
CONNECTIONS = 0x22
CONNECT = 0x24
DISCONNECT = 0x25
UNIT_INFO = 0x30
SUBUNIT_INFO = 0x31
# 0x40 - 0x7F: Subunit commands
PASS_THROUGH = 0x7C
GUI_UPDATE = 0x7D
PUSH_GUI_DATA = 0x7E
USER_ACTION = 0x7F
# 0xA0 - 0xBF: Unit and subunit commands
VERSION = 0xB0
POWER = 0xB2
subunit_type: SubunitType
subunit_id: int
opcode: OperationCode
operands: bytes
@staticmethod
def subclass(subclass):
# Infer the opcode from the class name
if subclass.__name__.endswith("CommandFrame"):
short_name = subclass.__name__.replace("CommandFrame", "")
category_class = CommandFrame
elif subclass.__name__.endswith("ResponseFrame"):
short_name = subclass.__name__.replace("ResponseFrame", "")
category_class = ResponseFrame
else:
raise ValueError(f"invalid subclass name {subclass.__name__}")
uppercase_indexes = [
i for i in range(len(short_name)) if short_name[i].isupper()
]
uppercase_indexes.append(len(short_name))
words = [
short_name[uppercase_indexes[i] : uppercase_indexes[i + 1]].upper()
for i in range(len(uppercase_indexes) - 1)
]
opcode_name = "_".join(words)
opcode = Frame.OperationCode[opcode_name]
category_class.subclasses[opcode] = subclass
return subclass
@staticmethod
def from_bytes(data: bytes) -> Frame:
if data[0] >> 4 != 0:
raise ValueError("first 4 bits must be 0s")
ctype_or_response = data[0] & 0xF
subunit_type = Frame.SubunitType(data[1] >> 3)
subunit_id = data[1] & 7
if subunit_type == Frame.SubunitType.EXTENDED:
# Not supported
raise NotImplementedError("extended subunit types not supported")
if subunit_id < 5:
opcode_offset = 2
elif subunit_id == 5:
# Extended to the next byte
extension = data[2]
if extension == 0:
raise ValueError("extended subunit ID value reserved")
if extension == 0xFF:
subunit_id = 5 + 254 + data[3]
opcode_offset = 4
else:
subunit_id = 5 + extension
opcode_offset = 3
elif subunit_id == 6:
raise ValueError("reserved subunit ID")
opcode = Frame.OperationCode(data[opcode_offset])
operands = data[opcode_offset + 1 :]
# Look for a registered subclass
if ctype_or_response < 8:
# Command
ctype = CommandFrame.CommandType(ctype_or_response)
if c_subclass := CommandFrame.subclasses.get(opcode):
return c_subclass(
ctype,
subunit_type,
subunit_id,
*c_subclass.parse_operands(operands),
)
return CommandFrame(ctype, subunit_type, subunit_id, opcode, operands)
else:
# Response
response = ResponseFrame.ResponseCode(ctype_or_response)
if r_subclass := ResponseFrame.subclasses.get(opcode):
return r_subclass(
response,
subunit_type,
subunit_id,
*r_subclass.parse_operands(operands),
)
return ResponseFrame(response, subunit_type, subunit_id, opcode, operands)
def to_bytes(
self,
ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode],
) -> bytes:
# TODO: support extended subunit types and ids.
return (
bytes(
[
ctype_or_response,
self.subunit_type << 3 | self.subunit_id,
self.opcode,
]
)
+ self.operands
)
def to_string(self, extra: str) -> str:
return (
f"{self.__class__.__name__}({extra}"
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"opcode={self.opcode.name}, "
f"operands={self.operands.hex()})"
)
def __init__(
self,
subunit_type: SubunitType,
subunit_id: int,
opcode: OperationCode,
operands: bytes,
) -> None:
self.subunit_type = subunit_type
self.subunit_id = subunit_id
self.opcode = opcode
self.operands = operands
# -----------------------------------------------------------------------------
class CommandFrame(Frame):
class CommandType(OpenIntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.1
CONTROL = 0x00
STATUS = 0x01
SPECIFIC_INQUIRY = 0x02
NOTIFY = 0x03
GENERAL_INQUIRY = 0x04
subclasses: Dict[Frame.OperationCode, Type[CommandFrame]] = {}
ctype: CommandType
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
raise NotImplementedError
def __init__(
self,
ctype: CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
opcode: Frame.OperationCode,
operands: bytes,
) -> None:
super().__init__(subunit_type, subunit_id, opcode, operands)
self.ctype = ctype
def __bytes__(self):
return self.to_bytes(self.ctype)
def __str__(self):
return self.to_string(f"ctype={self.ctype.name}, ")
# -----------------------------------------------------------------------------
class ResponseFrame(Frame):
class ResponseCode(OpenIntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.2
NOT_IMPLEMENTED = 0x08
ACCEPTED = 0x09
REJECTED = 0x0A
IN_TRANSITION = 0x0B
IMPLEMENTED_OR_STABLE = 0x0C
CHANGED = 0x0D
INTERIM = 0x0F
subclasses: Dict[Frame.OperationCode, Type[ResponseFrame]] = {}
response: ResponseCode
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
raise NotImplementedError
def __init__(
self,
response: ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
opcode: Frame.OperationCode,
operands: bytes,
) -> None:
super().__init__(subunit_type, subunit_id, opcode, operands)
self.response = response
def __bytes__(self):
return self.to_bytes(self.response)
def __str__(self):
return self.to_string(f"response={self.response.name}, ")
# -----------------------------------------------------------------------------
class VendorDependentFrame:
company_id: int
vendor_dependent_data: bytes
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
return (
struct.unpack(">I", b"\x00" + operands[:3])[0],
operands[3:],
)
def make_operands(self) -> bytes:
return struct.pack(">I", self.company_id)[1:] + self.vendor_dependent_data
def __init__(self, company_id: int, vendor_dependent_data: bytes):
self.company_id = company_id
self.vendor_dependent_data = vendor_dependent_data
# -----------------------------------------------------------------------------
@Frame.subclass
class VendorDependentCommandFrame(VendorDependentFrame, CommandFrame):
def __init__(
self,
ctype: CommandFrame.CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
company_id: int,
vendor_dependent_data: bytes,
) -> None:
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
CommandFrame.__init__(
self,
ctype,
subunit_type,
subunit_id,
Frame.OperationCode.VENDOR_DEPENDENT,
self.make_operands(),
)
def __str__(self):
return (
f"VendorDependentCommandFrame(ctype={self.ctype.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"company_id=0x{self.company_id:06X}, "
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
)
# -----------------------------------------------------------------------------
@Frame.subclass
class VendorDependentResponseFrame(VendorDependentFrame, ResponseFrame):
def __init__(
self,
response: ResponseFrame.ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
company_id: int,
vendor_dependent_data: bytes,
) -> None:
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
ResponseFrame.__init__(
self,
response,
subunit_type,
subunit_id,
Frame.OperationCode.VENDOR_DEPENDENT,
self.make_operands(),
)
def __str__(self):
return (
f"VendorDependentResponseFrame(response={self.response.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"company_id=0x{self.company_id:06X}, "
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
)
# -----------------------------------------------------------------------------
class PassThroughFrame:
"""
See AV/C Panel Subunit Specification 1.1 - 9.4 PASS THROUGH control command
"""
class StateFlag(enum.IntEnum):
PRESSED = 0
RELEASED = 1
class OperationId(OpenIntEnum):
SELECT = 0x00
UP = 0x01
DOWN = 0x01
LEFT = 0x03
RIGHT = 0x04
RIGHT_UP = 0x05
RIGHT_DOWN = 0x06
LEFT_UP = 0x07
LEFT_DOWN = 0x08
ROOT_MENU = 0x09
SETUP_MENU = 0x0A
CONTENTS_MENU = 0x0B
FAVORITE_MENU = 0x0C
EXIT = 0x0D
NUMBER_0 = 0x20
NUMBER_1 = 0x21
NUMBER_2 = 0x22
NUMBER_3 = 0x23
NUMBER_4 = 0x24
NUMBER_5 = 0x25
NUMBER_6 = 0x26
NUMBER_7 = 0x27
NUMBER_8 = 0x28
NUMBER_9 = 0x29
DOT = 0x2A
ENTER = 0x2B
CLEAR = 0x2C
CHANNEL_UP = 0x30
CHANNEL_DOWN = 0x31
PREVIOUS_CHANNEL = 0x32
SOUND_SELECT = 0x33
INPUT_SELECT = 0x34
DISPLAY_INFORMATION = 0x35
HELP = 0x36
PAGE_UP = 0x37
PAGE_DOWN = 0x38
POWER = 0x40
VOLUME_UP = 0x41
VOLUME_DOWN = 0x42
MUTE = 0x43
PLAY = 0x44
STOP = 0x45
PAUSE = 0x46
RECORD = 0x47
REWIND = 0x48
FAST_FORWARD = 0x49
EJECT = 0x4A
FORWARD = 0x4B
BACKWARD = 0x4C
ANGLE = 0x50
SUBPICTURE = 0x51
F1 = 0x71
F2 = 0x72
F3 = 0x73
F4 = 0x74
F5 = 0x75
VENDOR_UNIQUE = 0x7E
state_flag: StateFlag
operation_id: OperationId
operation_data: bytes
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
return (
PassThroughFrame.StateFlag(operands[0] >> 7),
PassThroughFrame.OperationId(operands[0] & 0x7F),
operands[1 : 1 + operands[1]],
)
def make_operands(self):
return (
bytes([self.state_flag << 7 | self.operation_id, len(self.operation_data)])
+ self.operation_data
)
def __init__(
self,
state_flag: StateFlag,
operation_id: OperationId,
operation_data: bytes,
) -> None:
if len(operation_data) > 255:
raise ValueError("operation data must be <= 255 bytes")
self.state_flag = state_flag
self.operation_id = operation_id
self.operation_data = operation_data
# -----------------------------------------------------------------------------
@Frame.subclass
class PassThroughCommandFrame(PassThroughFrame, CommandFrame):
def __init__(
self,
ctype: CommandFrame.CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
state_flag: PassThroughFrame.StateFlag,
operation_id: PassThroughFrame.OperationId,
operation_data: bytes,
) -> None:
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
CommandFrame.__init__(
self,
ctype,
subunit_type,
subunit_id,
Frame.OperationCode.PASS_THROUGH,
self.make_operands(),
)
def __str__(self):
return (
f"PassThroughCommandFrame(ctype={self.ctype.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"state_flag={self.state_flag.name}, "
f"operation_id={self.operation_id.name}, "
f"operation_data={self.operation_data.hex()})"
)
# -----------------------------------------------------------------------------
@Frame.subclass
class PassThroughResponseFrame(PassThroughFrame, ResponseFrame):
def __init__(
self,
response: ResponseFrame.ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
state_flag: PassThroughFrame.StateFlag,
operation_id: PassThroughFrame.OperationId,
operation_data: bytes,
) -> None:
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
ResponseFrame.__init__(
self,
response,
subunit_type,
subunit_id,
Frame.OperationCode.PASS_THROUGH,
self.make_operands(),
)
def __str__(self):
return (
f"PassThroughResponseFrame(response={self.response.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"state_flag={self.state_flag.name}, "
f"operation_id={self.operation_id.name}, "
f"operation_data={self.operation_data.hex()})"
)
-291
View File
@@ -1,291 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from enum import IntEnum
import logging
import struct
from typing import Callable, cast, Dict, Optional
from bumble.colors import color
from bumble import avc
from bumble import l2cap
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
AVCTP_PSM = 0x0017
AVCTP_BROWSING_PSM = 0x001B
# -----------------------------------------------------------------------------
class MessageAssembler:
Callback = Callable[[int, bool, bool, int, bytes], None]
transaction_label: int
pid: int
c_r: int
ipid: int
payload: bytes
number_of_packets: int
packets_received: int
def __init__(self, callback: Callback) -> None:
self.callback = callback
self.reset()
def reset(self) -> None:
self.packets_received = 0
self.transaction_label = -1
self.pid = -1
self.c_r = -1
self.ipid = -1
self.payload = b''
self.number_of_packets = 0
self.packet_count = 0
def on_pdu(self, pdu: bytes) -> None:
self.packets_received += 1
transaction_label = pdu[0] >> 4
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
c_r = (pdu[0] >> 1) & 1
ipid = pdu[0] & 1
if c_r == 0 and ipid != 0:
logger.warning("invalid IPID in command frame")
self.reset()
return
pid_offset = 1
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START):
if self.transaction_label >= 0:
# We are already in a transaction
logger.warning("received START or SINGLE fragment while in transaction")
self.reset()
self.packets_received = 1
if packet_type == Protocol.PacketType.START:
self.number_of_packets = pdu[1]
pid_offset = 2
pid = struct.unpack_from(">H", pdu, pid_offset)[0]
self.payload += pdu[pid_offset + 2 :]
if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END):
if transaction_label != self.transaction_label:
logger.warning("transaction label does not match")
self.reset()
return
if pid != self.pid:
logger.warning("PID does not match")
self.reset()
return
if c_r != self.c_r:
logger.warning("C/R does not match")
self.reset()
return
if self.packets_received > self.number_of_packets:
logger.warning("too many fragments in transaction")
self.reset()
return
if packet_type == Protocol.PacketType.END:
if self.packets_received != self.number_of_packets:
logger.warning("premature END")
self.reset()
return
else:
self.transaction_label = transaction_label
self.c_r = c_r
self.ipid = ipid
self.pid = pid
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END):
self.on_message_complete()
def on_message_complete(self):
try:
self.callback(
self.transaction_label,
self.c_r == 0,
self.ipid != 0,
self.pid,
self.payload,
)
except Exception as error:
logger.exception(color(f"!!! exception in callback: {error}", "red"))
self.reset()
# -----------------------------------------------------------------------------
class Protocol:
CommandHandler = Callable[[int, avc.CommandFrame], None]
command_handlers: Dict[int, CommandHandler] # Command handlers, by PID
ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]
response_handlers: Dict[int, ResponseHandler] # Response handlers, by PID
next_transaction_label: int
message_assembler: MessageAssembler
class PacketType(IntEnum):
SINGLE = 0b00
START = 0b01
CONTINUE = 0b10
END = 0b11
def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None:
self.command_handlers = {}
self.response_handlers = {}
self.l2cap_channel = l2cap_channel
self.message_assembler = MessageAssembler(self.on_message)
# Register to receive PDUs from the channel
l2cap_channel.sink = self.on_pdu
l2cap_channel.on("open", self.on_l2cap_channel_open)
l2cap_channel.on("close", self.on_l2cap_channel_close)
def on_l2cap_channel_open(self):
logger.debug(color("<<< AVCTP channel open", "magenta"))
def on_l2cap_channel_close(self):
logger.debug(color("<<< AVCTP channel closed", "magenta"))
def on_pdu(self, pdu: bytes) -> None:
self.message_assembler.on_pdu(pdu)
def on_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
) -> None:
logger.debug(
f"<<< AVCTP Message: pid={pid}, "
f"transaction_label={transaction_label}, "
f"is_command={is_command}, "
f"ipid={ipid}, "
f"payload={payload.hex()}"
)
# Check for invalid PID responses.
if ipid:
logger.debug(f"received IPID for PID={pid}")
# Find the appropriate handler.
if is_command:
if pid not in self.command_handlers:
logger.warning(f"no command handler for PID {pid}")
self.send_ipid(transaction_label, pid)
return
command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload))
self.command_handlers[pid](transaction_label, command_frame)
else:
if pid not in self.response_handlers:
logger.warning(f"no response handler for PID {pid}")
return
# By convention, for an ipid, send a None payload to the response handler.
if ipid:
response_frame = None
else:
response_frame = cast(avc.ResponseFrame, avc.Frame.from_bytes(payload))
self.response_handlers[pid](transaction_label, response_frame)
def send_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
):
# TODO: fragment large messages
packet_type = Protocol.PacketType.SINGLE
pdu = (
struct.pack(
">BH",
transaction_label << 4
| packet_type << 2
| (0 if is_command else 1) << 1
| (1 if ipid else 0),
pid,
)
+ payload
)
self.l2cap_channel.send_pdu(pdu)
def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
logger.debug(
">>> AVCTP command: "
f"transaction_label={transaction_label}, "
f"pid={pid}, "
f"payload={payload.hex()}"
)
self.send_message(transaction_label, True, False, pid, payload)
def send_response(self, transaction_label: int, pid: int, payload: bytes):
logger.debug(
">>> AVCTP response: "
f"transaction_label={transaction_label}, "
f"pid={pid}, "
f"payload={payload.hex()}"
)
self.send_message(transaction_label, False, False, pid, payload)
def send_ipid(self, transaction_label: int, pid: int) -> None:
logger.debug(
">>> AVCTP ipid: " f"transaction_label={transaction_label}, " f"pid={pid}"
)
self.send_message(transaction_label, False, True, pid, b'')
def register_command_handler(
self, pid: int, handler: Protocol.CommandHandler
) -> None:
self.command_handlers[pid] = handler
def unregister_command_handler(
self, pid: int, handler: Protocol.CommandHandler
) -> None:
if pid not in self.command_handlers or self.command_handlers[pid] != handler:
raise ValueError("command handler not registered")
del self.command_handlers[pid]
def register_response_handler(
self, pid: int, handler: Protocol.ResponseHandler
) -> None:
self.response_handlers[pid] = handler
def unregister_response_handler(
self, pid: int, handler: Protocol.ResponseHandler
) -> None:
if pid not in self.response_handlers or self.response_handlers[pid] != handler:
raise ValueError("response handler not registered")
del self.response_handlers[pid]
+13 -23
View File
@@ -241,10 +241,7 @@ async def find_avdtp_service_with_sdp_client(
) )
if profile_descriptor_list: if profile_descriptor_list:
for profile_descriptor in profile_descriptor_list.value: for profile_descriptor in profile_descriptor_list.value:
if ( if len(profile_descriptor.value) >= 2:
profile_descriptor.type == sdp.DataElement.SEQUENCE
and len(profile_descriptor.value) >= 2
):
avdtp_version_major = profile_descriptor.value[1].value >> 8 avdtp_version_major = profile_descriptor.value[1].value >> 8
avdtp_version_minor = profile_descriptor.value[1].value & 0xFF avdtp_version_minor = profile_descriptor.value[1].value & 0xFF
return (avdtp_version_major, avdtp_version_minor) return (avdtp_version_major, avdtp_version_minor)
@@ -253,15 +250,15 @@ async def find_avdtp_service_with_sdp_client(
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def find_avdtp_service_with_connection( async def find_avdtp_service_with_connection(
connection: device.Connection, device: device.Device, connection: device.Connection
) -> Optional[Tuple[int, int]]: ) -> Optional[Tuple[int, int]]:
''' '''
Find an AVDTP service, for a connection, and return its version, Find an AVDTP service, for a connection, and return its version,
or None if none is found or None if none is found
''' '''
sdp_client = sdp.Client(connection) sdp_client = sdp.Client(device)
await sdp_client.connect() await sdp_client.connect(connection)
service_version = await find_avdtp_service_with_sdp_client(sdp_client) service_version = await find_avdtp_service_with_sdp_client(sdp_client)
await sdp_client.disconnect() await sdp_client.disconnect()
@@ -325,8 +322,8 @@ class MediaPacket:
self.padding = padding self.padding = padding
self.extension = extension self.extension = extension
self.marker = marker self.marker = marker
self.sequence_number = sequence_number & 0xFFFF self.sequence_number = sequence_number
self.timestamp = timestamp & 0xFFFFFFFF self.timestamp = timestamp
self.ssrc = ssrc self.ssrc = ssrc
self.csrc_list = csrc_list self.csrc_list = csrc_list
self.payload_type = payload_type self.payload_type = payload_type
@@ -341,12 +338,7 @@ class MediaPacket:
| len(self.csrc_list), | len(self.csrc_list),
self.marker << 7 | self.payload_type, self.marker << 7 | self.payload_type,
] ]
) + struct.pack( ) + struct.pack('>HII', self.sequence_number, self.timestamp, self.ssrc)
'>HII',
self.sequence_number,
self.timestamp,
self.ssrc,
)
for csrc in self.csrc_list: for csrc in self.csrc_list:
header += struct.pack('>I', csrc) header += struct.pack('>I', csrc)
return header + self.payload return header + self.payload
@@ -519,8 +511,7 @@ class MessageAssembler:
try: try:
self.callback(self.transaction_label, message) self.callback(self.transaction_label, message)
except Exception as error: except Exception as error:
logger.exception(color(f'!!! exception in callback: {error}', 'red')) logger.warning(color(f'!!! exception in callback: {error}'))
self.reset() self.reset()
@@ -1475,10 +1466,10 @@ class Protocol(EventEmitter):
f'[{transaction_label}] {message}' f'[{transaction_label}] {message}'
) )
max_fragment_size = ( max_fragment_size = (
self.l2cap_channel.peer_mtu - 3 self.l2cap_channel.mtu - 3
) # Enough space for a 3-byte start packet header ) # Enough space for a 3-byte start packet header
payload = message.payload payload = message.payload
if len(payload) + 2 <= self.l2cap_channel.peer_mtu: if len(payload) + 2 <= self.l2cap_channel.mtu:
# Fits in a single packet # Fits in a single packet
packet_type = self.PacketType.SINGLE_PACKET packet_type = self.PacketType.SINGLE_PACKET
else: else:
@@ -1550,10 +1541,9 @@ class Protocol(EventEmitter):
assert False # Should never reach this assert False # Should never reach this
async def get_capabilities(self, seid: int) -> Union[ async def get_capabilities(
Get_Capabilities_Response, self, seid: int
Get_All_Capabilities_Response, ) -> Union[Get_Capabilities_Response, Get_All_Capabilities_Response,]:
]:
if self.version > (1, 2): if self.version > (1, 2):
return await self.send_command(Get_All_Capabilities_Command(seid)) return await self.send_command(Get_All_Capabilities_Command(seid))
-1918
View File
File diff suppressed because it is too large Load Diff
+36 -451
View File
@@ -19,7 +19,6 @@ from __future__ import annotations
import logging import logging
import asyncio import asyncio
import dataclasses
import itertools import itertools
import random import random
import struct import struct
@@ -43,7 +42,6 @@ from bumble.hci import (
HCI_LE_1M_PHY, HCI_LE_1M_PHY,
HCI_SUCCESS, HCI_SUCCESS,
HCI_UNKNOWN_HCI_COMMAND_ERROR, HCI_UNKNOWN_HCI_COMMAND_ERROR,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR, HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
HCI_VERSION_BLUETOOTH_CORE_5_0, HCI_VERSION_BLUETOOTH_CORE_5_0,
Address, Address,
@@ -55,21 +53,17 @@ from bumble.hci import (
HCI_Connection_Request_Event, HCI_Connection_Request_Event,
HCI_Disconnection_Complete_Event, HCI_Disconnection_Complete_Event,
HCI_Encryption_Change_Event, HCI_Encryption_Change_Event,
HCI_Synchronous_Connection_Complete_Event,
HCI_LE_Advertising_Report_Event, HCI_LE_Advertising_Report_Event,
HCI_LE_CIS_Established_Event,
HCI_LE_CIS_Request_Event,
HCI_LE_Connection_Complete_Event, HCI_LE_Connection_Complete_Event,
HCI_LE_Read_Remote_Features_Complete_Event, HCI_LE_Read_Remote_Features_Complete_Event,
HCI_Number_Of_Completed_Packets_Event, HCI_Number_Of_Completed_Packets_Event,
HCI_Packet, HCI_Packet,
HCI_Role_Change_Event, HCI_Role_Change_Event,
) )
from typing import Optional, Union, Dict, Any, TYPE_CHECKING from typing import Optional, Union, Dict, TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.link import LocalLink from bumble.transport.common import TransportSink, TransportSource
from bumble.transport.common import TransportSink
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -85,27 +79,15 @@ class DataObject:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass
class CisLink:
handle: int
cis_id: int
cig_id: int
acl_connection: Optional[Connection] = None
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class Connection: class Connection:
controller: Controller def __init__(self, controller, handle, role, peer_address, link, transport):
handle: int self.controller = controller
role: int self.handle = handle
peer_address: Address self.role = role
link: Any self.peer_address = peer_address
transport: int self.link = link
link_type: int
def __post_init__(self):
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport
def on_hci_acl_data_packet(self, packet): def on_hci_acl_data_packet(self, packet):
self.assembler.feed_packet(packet) self.assembler.feed_packet(packet)
@@ -124,27 +106,25 @@ class Connection:
class Controller: class Controller:
def __init__( def __init__(
self, self,
name: str, name,
host_source=None, host_source=None,
host_sink: Optional[TransportSink] = None, host_sink: Optional[TransportSink] = None,
link: Optional[LocalLink] = None, link=None,
public_address: Optional[Union[bytes, str, Address]] = None, public_address: Optional[Union[bytes, str, Address]] = None,
): ):
self.name = name self.name = name
self.hci_sink = None self.hci_sink = None
self.link = link self.link = link
self.central_connections: Dict[Address, Connection] = ( self.central_connections: Dict[
{} Address, Connection
) # Connections where this controller is the central ] = {} # Connections where this controller is the central
self.peripheral_connections: Dict[Address, Connection] = ( self.peripheral_connections: Dict[
{} Address, Connection
) # Connections where this controller is the peripheral ] = {} # Connections where this controller is the peripheral
self.classic_connections: Dict[Address, Connection] = ( self.classic_connections: Dict[
{} Address, Connection
) # Connections in BR/EDR ] = {} # Connections in BR/EDR
self.central_cis_links: Dict[int, CisLink] = {} # CIS links by handle
self.peripheral_cis_links: Dict[int, CisLink] = {} # CIS links by handle
self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0 self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0
self.hci_revision = 0 self.hci_revision = 0
@@ -154,14 +134,12 @@ class Controller:
'0000000060000000' '0000000060000000'
) # BR/EDR Not Supported, LE Supported (Controller) ) # BR/EDR Not Supported, LE Supported (Controller)
self.manufacturer_name = 0xFFFF self.manufacturer_name = 0xFFFF
self.hc_data_packet_length = 27
self.hc_total_num_data_packets = 64
self.hc_le_data_packet_length = 27 self.hc_le_data_packet_length = 27
self.hc_total_num_le_data_packets = 64 self.hc_total_num_le_data_packets = 64
self.event_mask = 0 self.event_mask = 0
self.event_mask_page_2 = 0 self.event_mask_page_2 = 0
self.supported_commands = bytes.fromhex( self.supported_commands = bytes.fromhex(
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000' '2000800000c000000000e40000002822000000000000040000f7ffff7f000000'
'30f0f9ff01008004000000000000000000000000000000000000000000000000' '30f0f9ff01008004000000000000000000000000000000000000000000000000'
) )
self.le_event_mask = 0 self.le_event_mask = 0
@@ -323,7 +301,7 @@ class Controller:
############################################################ ############################################################
# Link connections # Link connections
############################################################ ############################################################
def allocate_connection_handle(self) -> int: def allocate_connection_handle(self):
handle = 0 handle = 0
max_handle = 0 max_handle = 0
for connection in itertools.chain( for connection in itertools.chain(
@@ -335,13 +313,6 @@ class Controller:
if connection.handle == handle: if connection.handle == handle:
# Already used, continue searching after the current max # Already used, continue searching after the current max
handle = max_handle + 1 handle = max_handle + 1
for cis_handle in itertools.chain(
self.central_cis_links.keys(), self.peripheral_cis_links.keys()
):
max_handle = max(max_handle, cis_handle)
if cis_handle == handle:
# Already used, continue searching after the current max
handle = max_handle + 1
return handle return handle
def find_le_connection_by_address(self, address): def find_le_connection_by_address(self, address):
@@ -386,13 +357,12 @@ class Controller:
if connection is None: if connection is None:
connection_handle = self.allocate_connection_handle() connection_handle = self.allocate_connection_handle()
connection = Connection( connection = Connection(
controller=self, self,
handle=connection_handle, connection_handle,
role=BT_PERIPHERAL_ROLE, BT_PERIPHERAL_ROLE,
peer_address=peer_address, peer_address,
link=self.link, self.link,
transport=BT_LE_TRANSPORT, BT_LE_TRANSPORT,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
) )
self.peripheral_connections[peer_address] = connection self.peripheral_connections[peer_address] = connection
logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}') logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}')
@@ -446,13 +416,12 @@ class Controller:
if connection is None: if connection is None:
connection_handle = self.allocate_connection_handle() connection_handle = self.allocate_connection_handle()
connection = Connection( connection = Connection(
controller=self, self,
handle=connection_handle, connection_handle,
role=BT_CENTRAL_ROLE, BT_CENTRAL_ROLE,
peer_address=peer_address, peer_address,
link=self.link, self.link,
transport=BT_LE_TRANSPORT, BT_LE_TRANSPORT,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
) )
self.central_connections[peer_address] = connection self.central_connections[peer_address] = connection
logger.debug( logger.debug(
@@ -569,104 +538,6 @@ class Controller:
) )
self.send_hci_packet(HCI_LE_Advertising_Report_Event([report])) self.send_hci_packet(HCI_LE_Advertising_Report_Event([report]))
def on_link_cis_request(
self, central_address: Address, cig_id: int, cis_id: int
) -> None:
'''
Called when an incoming CIS request occurs from a central on the link
'''
connection = self.peripheral_connections.get(central_address)
assert connection
pending_cis_link = CisLink(
handle=self.allocate_connection_handle(),
cis_id=cis_id,
cig_id=cig_id,
acl_connection=connection,
)
self.peripheral_cis_links[pending_cis_link.handle] = pending_cis_link
self.send_hci_packet(
HCI_LE_CIS_Request_Event(
acl_connection_handle=connection.handle,
cis_connection_handle=pending_cis_link.handle,
cig_id=cig_id,
cis_id=cis_id,
)
)
def on_link_cis_established(self, cig_id: int, cis_id: int) -> None:
'''
Called when an incoming CIS established.
'''
cis_link = next(
cis_link
for cis_link in itertools.chain(
self.central_cis_links.values(), self.peripheral_cis_links.values()
)
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
)
self.send_hci_packet(
HCI_LE_CIS_Established_Event(
status=HCI_SUCCESS,
connection_handle=cis_link.handle,
# CIS parameters are ignored.
cig_sync_delay=0,
cis_sync_delay=0,
transport_latency_c_to_p=0,
transport_latency_p_to_c=0,
phy_c_to_p=0,
phy_p_to_c=0,
nse=0,
bn_c_to_p=0,
bn_p_to_c=0,
ft_c_to_p=0,
ft_p_to_c=0,
max_pdu_c_to_p=0,
max_pdu_p_to_c=0,
iso_interval=0,
)
)
def on_link_cis_disconnected(self, cig_id: int, cis_id: int) -> None:
'''
Called when a CIS disconnected.
'''
if cis_link := next(
(
cis_link
for cis_link in self.peripheral_cis_links.values()
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
),
None,
):
# Remove peripheral CIS on disconnection.
self.peripheral_cis_links.pop(cis_link.handle)
elif cis_link := next(
(
cis_link
for cis_link in self.central_cis_links.values()
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
),
None,
):
# Keep central CIS on disconnection. They should be removed by HCI_LE_Remove_CIG_Command.
cis_link.acl_connection = None
else:
return
self.send_hci_packet(
HCI_Disconnection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=cis_link.handle,
reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
)
)
############################################################ ############################################################
# Classic link connections # Classic link connections
############################################################ ############################################################
@@ -695,7 +566,6 @@ class Controller:
peer_address=peer_address, peer_address=peer_address,
link=self.link, link=self.link,
transport=BT_BR_EDR_TRANSPORT, transport=BT_BR_EDR_TRANSPORT,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
) )
self.classic_connections[peer_address] = connection self.classic_connections[peer_address] = connection
logger.debug( logger.debug(
@@ -749,42 +619,6 @@ class Controller:
) )
) )
def on_classic_sco_connection_complete(
self, peer_address: Address, status: int, link_type: int
):
if status == HCI_SUCCESS:
# Allocate (or reuse) a connection handle
connection_handle = self.allocate_connection_handle()
connection = Connection(
controller=self,
handle=connection_handle,
# Role doesn't matter in SCO.
role=BT_CENTRAL_ROLE,
peer_address=peer_address,
link=self.link,
transport=BT_BR_EDR_TRANSPORT,
link_type=link_type,
)
self.classic_connections[peer_address] = connection
logger.debug(f'New SCO connection handle: 0x{connection_handle:04X}')
else:
connection_handle = 0
self.send_hci_packet(
HCI_Synchronous_Connection_Complete_Event(
status=status,
connection_handle=connection_handle,
bd_addr=peer_address,
link_type=link_type,
# TODO: Provide SCO connection parameters.
transmission_interval=0,
retransmission_window=0,
rx_packet_length=0,
tx_packet_length=0,
air_mode=0,
)
)
############################################################ ############################################################
# Advertising support # Advertising support
############################################################ ############################################################
@@ -887,17 +721,6 @@ class Controller:
else: else:
# Remove the connection # Remove the connection
del self.classic_connections[connection.peer_address] del self.classic_connections[connection.peer_address]
elif cis_link := (
self.central_cis_links.get(handle) or self.peripheral_cis_links.get(handle)
):
if self.link:
self.link.disconnect_cis(
initiator_controller=self,
peer_address=cis_link.acl_connection.peer_address,
cig_id=cis_link.cig_id,
cis_id=cis_link.cis_id,
)
# Spec requires handle to be kept after disconnection.
def on_hci_accept_connection_request_command(self, command): def on_hci_accept_connection_request_command(self, command):
''' '''
@@ -915,68 +738,6 @@ class Controller:
) )
self.link.classic_accept_connection(self, command.bd_addr, command.role) self.link.classic_accept_connection(self, command.bd_addr, command.role)
def on_hci_enhanced_setup_synchronous_connection_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.1.45 Enhanced Setup Synchronous Connection command
'''
if self.link is None:
return
if not (
connection := self.find_classic_connection_by_handle(
command.connection_handle
)
):
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
return
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_SUCCESS,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
self.link.classic_sco_connect(
self, connection.peer_address, HCI_Connection_Complete_Event.ESCO_LINK_TYPE
)
def on_hci_enhanced_accept_synchronous_connection_request_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.1.46 Enhanced Accept Synchronous Connection Request command
'''
if self.link is None:
return
if not (connection := self.find_classic_connection_by_address(command.bd_addr)):
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
return
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_SUCCESS,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
self.link.classic_accept_sco_connection(
self, connection.peer_address, HCI_Connection_Complete_Event.ESCO_LINK_TYPE
)
def on_hci_switch_role_command(self, command): def on_hci_switch_role_command(self, command):
''' '''
See Bluetooth spec Vol 4, Part E - 7.2.8 Switch Role command See Bluetooth spec Vol 4, Part E - 7.2.8 Switch Role command
@@ -1151,41 +912,7 @@ class Controller:
''' '''
See Bluetooth spec Vol 4, Part E - 7.4.3 Read Local Supported Features Command See Bluetooth spec Vol 4, Part E - 7.4.3 Read Local Supported Features Command
''' '''
return bytes([HCI_SUCCESS]) + self.lmp_features[:8] return bytes([HCI_SUCCESS]) + self.lmp_features
def on_hci_read_local_extended_features_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.4.4 Read Local Extended Features Command
'''
if command.page_number * 8 > len(self.lmp_features):
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
return (
bytes(
[
# Status
HCI_SUCCESS,
# Page number
command.page_number,
# Max page number
len(self.lmp_features) // 8 - 1,
]
)
# Features of the current page
+ self.lmp_features[command.page_number * 8 : (command.page_number + 1) * 8]
)
def on_hci_read_buffer_size_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.4.5 Read Buffer Size Command
'''
return struct.pack(
'<BHBHH',
HCI_SUCCESS,
self.hc_data_packet_length,
0,
self.hc_total_num_data_packets,
0,
)
def on_hci_read_bd_addr_command(self, _command): def on_hci_read_bd_addr_command(self, _command):
''' '''
@@ -1273,9 +1000,6 @@ class Controller:
''' '''
See Bluetooth spec Vol 4, Part E - 7.8.10 LE Set Scan Parameters Command See Bluetooth spec Vol 4, Part E - 7.8.10 LE Set Scan Parameters Command
''' '''
if self.le_scan_enable:
return bytes([HCI_COMMAND_DISALLOWED_ERROR])
self.le_scan_type = command.le_scan_type self.le_scan_type = command.le_scan_type
self.le_scan_interval = command.le_scan_interval self.le_scan_interval = command.le_scan_interval
self.le_scan_window = command.le_scan_window self.le_scan_window = command.le_scan_window
@@ -1362,18 +1086,6 @@ class Controller:
See Bluetooth spec Vol 4, Part E - 7.8.21 LE Read Remote Features Command See Bluetooth spec Vol 4, Part E - 7.8.21 LE Read Remote Features Command
''' '''
handle = command.connection_handle
if not self.find_connection_by_handle(handle):
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
return
# First, say that the command is pending # First, say that the command is pending
self.send_hci_packet( self.send_hci_packet(
HCI_Command_Status_Event( HCI_Command_Status_Event(
@@ -1387,7 +1099,7 @@ class Controller:
self.send_hci_packet( self.send_hci_packet(
HCI_LE_Read_Remote_Features_Complete_Event( HCI_LE_Read_Remote_Features_Complete_Event(
status=HCI_SUCCESS, status=HCI_SUCCESS,
connection_handle=handle, connection_handle=0,
le_features=bytes.fromhex('dd40000000000000'), le_features=bytes.fromhex('dd40000000000000'),
) )
) )
@@ -1543,135 +1255,8 @@ class Controller:
} }
return bytes([HCI_SUCCESS]) return bytes([HCI_SUCCESS])
def on_hci_le_read_maximum_advertising_data_length_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.57 LE Read Maximum Advertising Data
Length Command
'''
return struct.pack('<BH', HCI_SUCCESS, 0x0672)
def on_hci_le_read_number_of_supported_advertising_sets_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.58 LE Read Number of Supported
Advertising Set Command
'''
return struct.pack('<BB', HCI_SUCCESS, 0xF0)
def on_hci_le_read_transmit_power_command(self, _command): def on_hci_le_read_transmit_power_command(self, _command):
''' '''
See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command
''' '''
return struct.pack('<BBB', HCI_SUCCESS, 0, 0) return struct.pack('<BBB', HCI_SUCCESS, 0, 0)
def on_hci_le_set_cig_parameters_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.97 LE Set CIG Parameter Command
'''
# Remove old CIG implicitly.
for handle, cis_link in self.central_cis_links.items():
if cis_link.cig_id == command.cig_id:
self.central_cis_links.pop(handle)
handles = []
for cis_id in command.cis_id:
handle = self.allocate_connection_handle()
handles.append(handle)
self.central_cis_links[handle] = CisLink(
cis_id=cis_id,
cig_id=command.cig_id,
handle=handle,
)
return struct.pack(
'<BBB', HCI_SUCCESS, command.cig_id, len(handles)
) + b''.join([struct.pack('<H', handle) for handle in handles])
def on_hci_le_create_cis_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.99 LE Create CIS Command
'''
if not self.link:
return
for cis_handle, acl_handle in zip(
command.cis_connection_handle, command.acl_connection_handle
):
if not (connection := self.find_connection_by_handle(acl_handle)):
logger.error(f'Cannot find connection with handle={acl_handle}')
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
if not (cis_link := self.central_cis_links.get(cis_handle)):
logger.error(f'Cannot find CIS with handle={cis_handle}')
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
cis_link.acl_connection = connection
self.link.create_cis(
self,
peripheral_address=connection.peer_address,
cig_id=cis_link.cig_id,
cis_id=cis_link.cis_id,
)
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
def on_hci_le_remove_cig_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.100 LE Remove CIG Command
'''
status = HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR
for cis_handle, cis_link in self.central_cis_links.items():
if cis_link.cig_id == command.cig_id:
self.central_cis_links.pop(cis_handle)
status = HCI_SUCCESS
return struct.pack('<BH', status, command.cig_id)
def on_hci_le_accept_cis_request_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.101 LE Accept CIS Request Command
'''
if not self.link:
return
if not (
pending_cis_link := self.peripheral_cis_links.get(command.connection_handle)
):
logger.error(f'Cannot find CIS with handle={command.connection_handle}')
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
assert pending_cis_link.acl_connection
self.link.accept_cis(
peripheral_controller=self,
central_address=pending_cis_link.acl_connection.peer_address,
cig_id=pending_cis_link.cig_id,
cis_id=pending_cis_link.cis_id,
)
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
def on_hci_le_setup_iso_data_path_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.109 LE Setup ISO Data Path Command
'''
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
def on_hci_le_remove_iso_data_path_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.110 LE Remove ISO Data Path Command
'''
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
+125 -714
View File
@@ -16,14 +16,10 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import dataclasses
import enum
import struct import struct
from typing import List, Optional, Tuple, Union, cast, Dict from typing import List, Optional, Tuple, Union, cast, Dict
from typing_extensions import Self
from bumble.company_ids import COMPANY_IDENTIFIERS from .company_ids import COMPANY_IDENTIFIERS
from bumble.utils import OpenIntEnum
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -100,16 +96,12 @@ class BaseError(Exception):
namespace = f'{self.error_namespace}/' namespace = f'{self.error_namespace}/'
else: else:
namespace = '' namespace = ''
have_name = self.error_name != '' error_text = {
have_code = self.error_code is not None (True, True): f'{self.error_name} [0x{self.error_code:X}]',
if have_name and have_code: (True, False): self.error_name,
error_text = f'{self.error_name} [0x{self.error_code:X}]' (False, True): f'0x{self.error_code:X}',
elif have_name and not have_code: (False, False): '',
error_text = self.error_name }[(self.error_name != '', self.error_code is not None)]
elif not have_name and have_code:
error_text = f'0x{self.error_code:X}'
else:
error_text = '<unspecified>'
return f'{type(self).__name__}({namespace}{error_text})' return f'{type(self).__name__}({namespace}{error_text})'
@@ -326,7 +318,7 @@ BT_HIDP_PROTOCOL_ID = UUID.from_16_bits(0x0011, 'HIDP')
BT_HARDCOPY_CONTROL_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x0012, 'HardcopyControlChannel') BT_HARDCOPY_CONTROL_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x0012, 'HardcopyControlChannel')
BT_HARDCOPY_DATA_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x0014, 'HardcopyDataChannel') BT_HARDCOPY_DATA_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x0014, 'HardcopyDataChannel')
BT_HARDCOPY_NOTIFICATION_PROTOCOL_ID = UUID.from_16_bits(0x0016, 'HardcopyNotification') BT_HARDCOPY_NOTIFICATION_PROTOCOL_ID = UUID.from_16_bits(0x0016, 'HardcopyNotification')
BT_AVCTP_PROTOCOL_ID = UUID.from_16_bits(0x0017, 'AVCTP') BT_AVTCP_PROTOCOL_ID = UUID.from_16_bits(0x0017, 'AVCTP')
BT_AVDTP_PROTOCOL_ID = UUID.from_16_bits(0x0019, 'AVDTP') BT_AVDTP_PROTOCOL_ID = UUID.from_16_bits(0x0019, 'AVDTP')
BT_CMTP_PROTOCOL_ID = UUID.from_16_bits(0x001B, 'CMTP') BT_CMTP_PROTOCOL_ID = UUID.from_16_bits(0x001B, 'CMTP')
BT_MCAP_CONTROL_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x001E, 'MCAPControlChannel') BT_MCAP_CONTROL_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x001E, 'MCAPControlChannel')
@@ -695,569 +687,11 @@ class DeviceClass:
return name_or_number(class_names, minor_device_class) return name_or_number(class_names, minor_device_class)
# -----------------------------------------------------------------------------
# Appearance
# -----------------------------------------------------------------------------
class Appearance:
class Category(OpenIntEnum):
UNKNOWN = 0x0000
PHONE = 0x0001
COMPUTER = 0x0002
WATCH = 0x0003
CLOCK = 0x0004
DISPLAY = 0x0005
REMOTE_CONTROL = 0x0006
EYE_GLASSES = 0x0007
TAG = 0x0008
KEYRING = 0x0009
MEDIA_PLAYER = 0x000A
BARCODE_SCANNER = 0x000B
THERMOMETER = 0x000C
HEART_RATE_SENSOR = 0x000D
BLOOD_PRESSURE = 0x000E
HUMAN_INTERFACE_DEVICE = 0x000F
GLUCOSE_METER = 0x0010
RUNNING_WALKING_SENSOR = 0x0011
CYCLING = 0x0012
CONTROL_DEVICE = 0x0013
NETWORK_DEVICE = 0x0014
SENSOR = 0x0015
LIGHT_FIXTURES = 0x0016
FAN = 0x0017
HVAC = 0x0018
AIR_CONDITIONING = 0x0019
HUMIDIFIER = 0x001A
HEATING = 0x001B
ACCESS_CONTROL = 0x001C
MOTORIZED_DEVICE = 0x001D
POWER_DEVICE = 0x001E
LIGHT_SOURCE = 0x001F
WINDOW_COVERING = 0x0020
AUDIO_SINK = 0x0021
AUDIO_SOURCE = 0x0022
MOTORIZED_VEHICLE = 0x0023
DOMESTIC_APPLIANCE = 0x0024
WEARABLE_AUDIO_DEVICE = 0x0025
AIRCRAFT = 0x0026
AV_EQUIPMENT = 0x0027
DISPLAY_EQUIPMENT = 0x0028
HEARING_AID = 0x0029
GAMING = 0x002A
SIGNAGE = 0x002B
PULSE_OXIMETER = 0x0031
WEIGHT_SCALE = 0x0032
PERSONAL_MOBILITY_DEVICE = 0x0033
CONTINUOUS_GLUCOSE_MONITOR = 0x0034
INSULIN_PUMP = 0x0035
MEDICATION_DELIVERY = 0x0036
SPIROMETER = 0x0037
OUTDOOR_SPORTS_ACTIVITY = 0x0051
class UnknownSubcategory(OpenIntEnum):
GENERIC_UNKNOWN = 0x00
class PhoneSubcategory(OpenIntEnum):
GENERIC_PHONE = 0x00
class ComputerSubcategory(OpenIntEnum):
GENERIC_COMPUTER = 0x00
DESKTOP_WORKSTATION = 0x01
SERVER_CLASS_COMPUTER = 0x02
LAPTOP = 0x03
HANDHELD_PC_PDA = 0x04
PALM_SIZE_PC_PDA = 0x05
WEARABLE_COMPUTER = 0x06
TABLET = 0x07
DOCKING_STATION = 0x08
ALL_IN_ONE = 0x09
BLADE_SERVER = 0x0A
CONVERTIBLE = 0x0B
DETACHABLE = 0x0C
IOT_GATEWAY = 0x0D
MINI_PC = 0x0E
STICK_PC = 0x0F
class WatchSubcategory(OpenIntEnum):
GENENERIC_WATCH = 0x00
SPORTS_WATCH = 0x01
SMARTWATCH = 0x02
class ClockSubcategory(OpenIntEnum):
GENERIC_CLOCK = 0x00
class DisplaySubcategory(OpenIntEnum):
GENERIC_DISPLAY = 0x00
class RemoteControlSubcategory(OpenIntEnum):
GENERIC_REMOTE_CONTROL = 0x00
class EyeglassesSubcategory(OpenIntEnum):
GENERIC_EYEGLASSES = 0x00
class TagSubcategory(OpenIntEnum):
GENERIC_TAG = 0x00
class KeyringSubcategory(OpenIntEnum):
GENERIC_KEYRING = 0x00
class MediaPlayerSubcategory(OpenIntEnum):
GENERIC_MEDIA_PLAYER = 0x00
class BarcodeScannerSubcategory(OpenIntEnum):
GENERIC_BARCODE_SCANNER = 0x00
class ThermometerSubcategory(OpenIntEnum):
GENERIC_THERMOMETER = 0x00
EAR_THERMOMETER = 0x01
class HeartRateSensorSubcategory(OpenIntEnum):
GENERIC_HEART_RATE_SENSOR = 0x00
HEART_RATE_BELT = 0x01
class BloodPressureSubcategory(OpenIntEnum):
GENERIC_BLOOD_PRESSURE = 0x00
ARM_BLOOD_PRESSURE = 0x01
WRIST_BLOOD_PRESSURE = 0x02
class HumanInterfaceDeviceSubcategory(OpenIntEnum):
GENERIC_HUMAN_INTERFACE_DEVICE = 0x00
KEYBOARD = 0x01
MOUSE = 0x02
JOYSTICK = 0x03
GAMEPAD = 0x04
DIGITIZER_TABLET = 0x05
CARD_READER = 0x06
DIGITAL_PEN = 0x07
BARCODE_SCANNER = 0x08
TOUCHPAD = 0x09
PRESENTATION_REMOTE = 0x0A
class GlucoseMeterSubcategory(OpenIntEnum):
GENERIC_GLUCOSE_METER = 0x00
class RunningWalkingSensorSubcategory(OpenIntEnum):
GENERIC_RUNNING_WALKING_SENSOR = 0x00
IN_SHOE_RUNNING_WALKING_SENSOR = 0x01
ON_SHOW_RUNNING_WALKING_SENSOR = 0x02
ON_HIP_RUNNING_WALKING_SENSOR = 0x03
class CyclingSubcategory(OpenIntEnum):
GENERIC_CYCLING = 0x00
CYCLING_COMPUTER = 0x01
SPEED_SENSOR = 0x02
CADENCE_SENSOR = 0x03
POWER_SENSOR = 0x04
SPEED_AND_CADENCE_SENSOR = 0x05
class ControlDeviceSubcategory(OpenIntEnum):
GENERIC_CONTROL_DEVICE = 0x00
SWITCH = 0x01
MULTI_SWITCH = 0x02
BUTTON = 0x03
SLIDER = 0x04
ROTARY_SWITCH = 0x05
TOUCH_PANEL = 0x06
SINGLE_SWITCH = 0x07
DOUBLE_SWITCH = 0x08
TRIPLE_SWITCH = 0x09
BATTERY_SWITCH = 0x0A
ENERGY_HARVESTING_SWITCH = 0x0B
PUSH_BUTTON = 0x0C
class NetworkDeviceSubcategory(OpenIntEnum):
GENERIC_NETWORK_DEVICE = 0x00
ACCESS_POINT = 0x01
MESH_DEVICE = 0x02
MESH_NETWORK_PROXY = 0x03
class SensorSubcategory(OpenIntEnum):
GENERIC_SENSOR = 0x00
MOTION_SENSOR = 0x01
AIR_QUALITY_SENSOR = 0x02
TEMPERATURE_SENSOR = 0x03
HUMIDITY_SENSOR = 0x04
LEAK_SENSOR = 0x05
SMOKE_SENSOR = 0x06
OCCUPANCY_SENSOR = 0x07
CONTACT_SENSOR = 0x08
CARBON_MONOXIDE_SENSOR = 0x09
CARBON_DIOXIDE_SENSOR = 0x0A
AMBIENT_LIGHT_SENSOR = 0x0B
ENERGY_SENSOR = 0x0C
COLOR_LIGHT_SENSOR = 0x0D
RAIN_SENSOR = 0x0E
FIRE_SENSOR = 0x0F
WIND_SENSOR = 0x10
PROXIMITY_SENSOR = 0x11
MULTI_SENSOR = 0x12
FLUSH_MOUNTED_SENSOR = 0x13
CEILING_MOUNTED_SENSOR = 0x14
WALL_MOUNTED_SENSOR = 0x15
MULTISENSOR = 0x16
ENERGY_METER = 0x17
FLAME_DETECTOR = 0x18
VEHICLE_TIRE_PRESSURE_SENSOR = 0x19
class LightFixturesSubcategory(OpenIntEnum):
GENERIC_LIGHT_FIXTURES = 0x00
WALL_LIGHT = 0x01
CEILING_LIGHT = 0x02
FLOOR_LIGHT = 0x03
CABINET_LIGHT = 0x04
DESK_LIGHT = 0x05
TROFFER_LIGHT = 0x06
PENDANT_LIGHT = 0x07
IN_GROUND_LIGHT = 0x08
FLOOD_LIGHT = 0x09
UNDERWATER_LIGHT = 0x0A
BOLLARD_WITH_LIGHT = 0x0B
PATHWAY_LIGHT = 0x0C
GARDEN_LIGHT = 0x0D
POLE_TOP_LIGHT = 0x0E
SPOTLIGHT = 0x0F
LINEAR_LIGHT = 0x10
STREET_LIGHT = 0x11
SHELVES_LIGHT = 0x12
BAY_LIGHT = 0x013
EMERGENCY_EXIT_LIGHT = 0x14
LIGHT_CONTROLLER = 0x15
LIGHT_DRIVER = 0x16
BULB = 0x17
LOW_BAY_LIGHT = 0x18
HIGH_BAY_LIGHT = 0x19
class FanSubcategory(OpenIntEnum):
GENERIC_FAN = 0x00
CEILING_FAN = 0x01
AXIAL_FAN = 0x02
EXHAUST_FAN = 0x03
PEDESTAL_FAN = 0x04
DESK_FAN = 0x05
WALL_FAN = 0x06
class HvacSubcategory(OpenIntEnum):
GENERIC_HVAC = 0x00
THERMOSTAT = 0x01
HUMIDIFIER = 0x02
DEHUMIDIFIER = 0x03
HEATER = 0x04
RADIATOR = 0x05
BOILER = 0x06
HEAT_PUMP = 0x07
INFRARED_HEATER = 0x08
RADIANT_PANEL_HEATER = 0x09
FAN_HEATER = 0x0A
AIR_CURTAIN = 0x0B
class AirConditioningSubcategory(OpenIntEnum):
GENERIC_AIR_CONDITIONING = 0x00
class HumidifierSubcategory(OpenIntEnum):
GENERIC_HUMIDIFIER = 0x00
class HeatingSubcategory(OpenIntEnum):
GENERIC_HEATING = 0x00
RADIATOR = 0x01
BOILER = 0x02
HEAT_PUMP = 0x03
INFRARED_HEATER = 0x04
RADIANT_PANEL_HEATER = 0x05
FAN_HEATER = 0x06
AIR_CURTAIN = 0x07
class AccessControlSubcategory(OpenIntEnum):
GENERIC_ACCESS_CONTROL = 0x00
ACCESS_DOOR = 0x01
GARAGE_DOOR = 0x02
EMERGENCY_EXIT_DOOR = 0x03
ACCESS_LOCK = 0x04
ELEVATOR = 0x05
WINDOW = 0x06
ENTRANCE_GATE = 0x07
DOOR_LOCK = 0x08
LOCKER = 0x09
class MotorizedDeviceSubcategory(OpenIntEnum):
GENERIC_MOTORIZED_DEVICE = 0x00
MOTORIZED_GATE = 0x01
AWNING = 0x02
BLINDS_OR_SHADES = 0x03
CURTAINS = 0x04
SCREEN = 0x05
class PowerDeviceSubcategory(OpenIntEnum):
GENERIC_POWER_DEVICE = 0x00
POWER_OUTLET = 0x01
POWER_STRIP = 0x02
PLUG = 0x03
POWER_SUPPLY = 0x04
LED_DRIVER = 0x05
FLUORESCENT_LAMP_GEAR = 0x06
HID_LAMP_GEAR = 0x07
CHARGE_CASE = 0x08
POWER_BANK = 0x09
class LightSourceSubcategory(OpenIntEnum):
GENERIC_LIGHT_SOURCE = 0x00
INCANDESCENT_LIGHT_BULB = 0x01
LED_LAMP = 0x02
HID_LAMP = 0x03
FLUORESCENT_LAMP = 0x04
LED_ARRAY = 0x05
MULTI_COLOR_LED_ARRAY = 0x06
LOW_VOLTAGE_HALOGEN = 0x07
ORGANIC_LIGHT_EMITTING_DIODE = 0x08
class WindowCoveringSubcategory(OpenIntEnum):
GENERIC_WINDOW_COVERING = 0x00
WINDOW_SHADES = 0x01
WINDOW_BLINDS = 0x02
WINDOW_AWNING = 0x03
WINDOW_CURTAIN = 0x04
EXTERIOR_SHUTTER = 0x05
EXTERIOR_SCREEN = 0x06
class AudioSinkSubcategory(OpenIntEnum):
GENERIC_AUDIO_SINK = 0x00
STANDALONE_SPEAKER = 0x01
SOUNDBAR = 0x02
BOOKSHELF_SPEAKER = 0x03
STANDMOUNTED_SPEAKER = 0x04
SPEAKERPHONE = 0x05
class AudioSourceSubcategory(OpenIntEnum):
GENERIC_AUDIO_SOURCE = 0x00
MICROPHONE = 0x01
ALARM = 0x02
BELL = 0x03
HORN = 0x04
BROADCASTING_DEVICE = 0x05
SERVICE_DESK = 0x06
KIOSK = 0x07
BROADCASTING_ROOM = 0x08
AUDITORIUM = 0x09
class MotorizedVehicleSubcategory(OpenIntEnum):
GENERIC_MOTORIZED_VEHICLE = 0x00
CAR = 0x01
LARGE_GOODS_VEHICLE = 0x02
TWO_WHEELED_VEHICLE = 0x03
MOTORBIKE = 0x04
SCOOTER = 0x05
MOPED = 0x06
THREE_WHEELED_VEHICLE = 0x07
LIGHT_VEHICLE = 0x08
QUAD_BIKE = 0x09
MINIBUS = 0x0A
BUS = 0x0B
TROLLEY = 0x0C
AGRICULTURAL_VEHICLE = 0x0D
CAMPER_CARAVAN = 0x0E
RECREATIONAL_VEHICLE_MOTOR_HOME = 0x0F
class DomesticApplianceSubcategory(OpenIntEnum):
GENERIC_DOMESTIC_APPLIANCE = 0x00
REFRIGERATOR = 0x01
FREEZER = 0x02
OVEN = 0x03
MICROWAVE = 0x04
TOASTER = 0x05
WASHING_MACHINE = 0x06
DRYER = 0x07
COFFEE_MAKER = 0x08
CLOTHES_IRON = 0x09
CURLING_IRON = 0x0A
HAIR_DRYER = 0x0B
VACUUM_CLEANER = 0x0C
ROBOTIC_VACUUM_CLEANER = 0x0D
RICE_COOKER = 0x0E
CLOTHES_STEAMER = 0x0F
class WearableAudioDeviceSubcategory(OpenIntEnum):
GENERIC_WEARABLE_AUDIO_DEVICE = 0x00
EARBUD = 0x01
HEADSET = 0x02
HEADPHONES = 0x03
NECK_BAND = 0x04
class AircraftSubcategory(OpenIntEnum):
GENERIC_AIRCRAFT = 0x00
LIGHT_AIRCRAFT = 0x01
MICROLIGHT = 0x02
PARAGLIDER = 0x03
LARGE_PASSENGER_AIRCRAFT = 0x04
class AvEquipmentSubcategory(OpenIntEnum):
GENERIC_AV_EQUIPMENT = 0x00
AMPLIFIER = 0x01
RECEIVER = 0x02
RADIO = 0x03
TUNER = 0x04
TURNTABLE = 0x05
CD_PLAYER = 0x06
DVD_PLAYER = 0x07
BLUERAY_PLAYER = 0x08
OPTICAL_DISC_PLAYER = 0x09
SET_TOP_BOX = 0x0A
class DisplayEquipmentSubcategory(OpenIntEnum):
GENERIC_DISPLAY_EQUIPMENT = 0x00
TELEVISION = 0x01
MONITOR = 0x02
PROJECTOR = 0x03
class HearingAidSubcategory(OpenIntEnum):
GENERIC_HEARING_AID = 0x00
IN_EAR_HEARING_AID = 0x01
BEHIND_EAR_HEARING_AID = 0x02
COCHLEAR_IMPLANT = 0x03
class GamingSubcategory(OpenIntEnum):
GENERIC_GAMING = 0x00
HOME_VIDEO_GAME_CONSOLE = 0x01
PORTABLE_HANDHELD_CONSOLE = 0x02
class SignageSubcategory(OpenIntEnum):
GENERIC_SIGNAGE = 0x00
DIGITAL_SIGNAGE = 0x01
ELECTRONIC_LABEL = 0x02
class PulseOximeterSubcategory(OpenIntEnum):
GENERIC_PULSE_OXIMETER = 0x00
FINGERTIP_PULSE_OXIMETER = 0x01
WRIST_WORN_PULSE_OXIMETER = 0x02
class WeightScaleSubcategory(OpenIntEnum):
GENERIC_WEIGHT_SCALE = 0x00
class PersonalMobilityDeviceSubcategory(OpenIntEnum):
GENERIC_PERSONAL_MOBILITY_DEVICE = 0x00
POWERED_WHEELCHAIR = 0x01
MOBILITY_SCOOTER = 0x02
class ContinuousGlucoseMonitorSubcategory(OpenIntEnum):
GENERIC_CONTINUOUS_GLUCOSE_MONITOR = 0x00
class InsulinPumpSubcategory(OpenIntEnum):
GENERIC_INSULIN_PUMP = 0x00
INSULIN_PUMP_DURABLE_PUMP = 0x01
INSULIN_PUMP_PATCH_PUMP = 0x02
INSULIN_PEN = 0x03
class MedicationDeliverySubcategory(OpenIntEnum):
GENERIC_MEDICATION_DELIVERY = 0x00
class SpirometerSubcategory(OpenIntEnum):
GENERIC_SPIROMETER = 0x00
HANDHELD_SPIROMETER = 0x01
class OutdoorSportsActivitySubcategory(OpenIntEnum):
GENERIC_OUTDOOR_SPORTS_ACTIVITY = 0x00
LOCATION_DISPLAY = 0x01
LOCATION_AND_NAVIGATION_DISPLAY = 0x02
LOCATION_POD = 0x03
LOCATION_AND_NAVIGATION_POD = 0x04
class _OpenSubcategory(OpenIntEnum):
GENERIC = 0x00
SUBCATEGORY_CLASSES = {
Category.UNKNOWN: UnknownSubcategory,
Category.PHONE: PhoneSubcategory,
Category.COMPUTER: ComputerSubcategory,
Category.WATCH: WatchSubcategory,
Category.CLOCK: ClockSubcategory,
Category.DISPLAY: DisplaySubcategory,
Category.REMOTE_CONTROL: RemoteControlSubcategory,
Category.EYE_GLASSES: EyeglassesSubcategory,
Category.TAG: TagSubcategory,
Category.KEYRING: KeyringSubcategory,
Category.MEDIA_PLAYER: MediaPlayerSubcategory,
Category.BARCODE_SCANNER: BarcodeScannerSubcategory,
Category.THERMOMETER: ThermometerSubcategory,
Category.HEART_RATE_SENSOR: HeartRateSensorSubcategory,
Category.BLOOD_PRESSURE: BloodPressureSubcategory,
Category.HUMAN_INTERFACE_DEVICE: HumanInterfaceDeviceSubcategory,
Category.GLUCOSE_METER: GlucoseMeterSubcategory,
Category.RUNNING_WALKING_SENSOR: RunningWalkingSensorSubcategory,
Category.CYCLING: CyclingSubcategory,
Category.CONTROL_DEVICE: ControlDeviceSubcategory,
Category.NETWORK_DEVICE: NetworkDeviceSubcategory,
Category.SENSOR: SensorSubcategory,
Category.LIGHT_FIXTURES: LightFixturesSubcategory,
Category.FAN: FanSubcategory,
Category.HVAC: HvacSubcategory,
Category.AIR_CONDITIONING: AirConditioningSubcategory,
Category.HUMIDIFIER: HumidifierSubcategory,
Category.HEATING: HeatingSubcategory,
Category.ACCESS_CONTROL: AccessControlSubcategory,
Category.MOTORIZED_DEVICE: MotorizedDeviceSubcategory,
Category.POWER_DEVICE: PowerDeviceSubcategory,
Category.LIGHT_SOURCE: LightSourceSubcategory,
Category.WINDOW_COVERING: WindowCoveringSubcategory,
Category.AUDIO_SINK: AudioSinkSubcategory,
Category.AUDIO_SOURCE: AudioSourceSubcategory,
Category.MOTORIZED_VEHICLE: MotorizedVehicleSubcategory,
Category.DOMESTIC_APPLIANCE: DomesticApplianceSubcategory,
Category.WEARABLE_AUDIO_DEVICE: WearableAudioDeviceSubcategory,
Category.AIRCRAFT: AircraftSubcategory,
Category.AV_EQUIPMENT: AvEquipmentSubcategory,
Category.DISPLAY_EQUIPMENT: DisplayEquipmentSubcategory,
Category.HEARING_AID: HearingAidSubcategory,
Category.GAMING: GamingSubcategory,
Category.SIGNAGE: SignageSubcategory,
Category.PULSE_OXIMETER: PulseOximeterSubcategory,
Category.WEIGHT_SCALE: WeightScaleSubcategory,
Category.PERSONAL_MOBILITY_DEVICE: PersonalMobilityDeviceSubcategory,
Category.CONTINUOUS_GLUCOSE_MONITOR: ContinuousGlucoseMonitorSubcategory,
Category.INSULIN_PUMP: InsulinPumpSubcategory,
Category.MEDICATION_DELIVERY: MedicationDeliverySubcategory,
Category.SPIROMETER: SpirometerSubcategory,
Category.OUTDOOR_SPORTS_ACTIVITY: OutdoorSportsActivitySubcategory,
}
category: Category
subcategory: enum.IntEnum
@classmethod
def from_int(cls, appearance: int) -> Self:
category = cls.Category(appearance >> 6)
return cls(category, appearance & 0x3F)
def __init__(self, category: Category, subcategory: int) -> None:
self.category = category
if subcategory_class := self.SUBCATEGORY_CLASSES.get(category):
self.subcategory = subcategory_class(subcategory)
else:
self.subcategory = self._OpenSubcategory(subcategory)
def __int__(self) -> int:
return self.category << 6 | self.subcategory
def __repr__(self) -> str:
return (
'Appearance('
f'category={self.category.name}, '
f'subcategory={self.subcategory.name}'
')'
)
def __str__(self) -> str:
return f'{self.category.name}/{self.subcategory.name}'
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Advertising Data # Advertising Data
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
AdvertisingDataObject = Union[ AdvertisingObject = Union[
List[UUID], List[UUID], Tuple[UUID, bytes], bytes, str, int, Tuple[int, int], Tuple[int, bytes]
Tuple[UUID, bytes],
bytes,
str,
int,
Tuple[int, int],
Tuple[int, bytes],
Appearance,
] ]
@@ -1265,115 +699,109 @@ class AdvertisingData:
# fmt: off # fmt: off
# pylint: disable=line-too-long # pylint: disable=line-too-long
FLAGS = 0x01 # This list is only partial, it still needs to be filled in from the spec
INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x02 FLAGS = 0x01
COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x03 INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x02
INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = 0x04 COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS = 0x03
COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = 0x05 INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = 0x04
INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = 0x06 COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS = 0x05
COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = 0x07 INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = 0x06
SHORTENED_LOCAL_NAME = 0x08 COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS = 0x07
COMPLETE_LOCAL_NAME = 0x09 SHORTENED_LOCAL_NAME = 0x08
TX_POWER_LEVEL = 0x0A COMPLETE_LOCAL_NAME = 0x09
CLASS_OF_DEVICE = 0x0D TX_POWER_LEVEL = 0x0A
SIMPLE_PAIRING_HASH_C = 0x0E CLASS_OF_DEVICE = 0x0D
SIMPLE_PAIRING_HASH_C_192 = 0x0E SIMPLE_PAIRING_HASH_C = 0x0E
SIMPLE_PAIRING_RANDOMIZER_R = 0x0F SIMPLE_PAIRING_HASH_C_192 = 0x0E
SIMPLE_PAIRING_RANDOMIZER_R_192 = 0x0F SIMPLE_PAIRING_RANDOMIZER_R = 0x0F
DEVICE_ID = 0x10 SIMPLE_PAIRING_RANDOMIZER_R_192 = 0x0F
SECURITY_MANAGER_TK_VALUE = 0x10 DEVICE_ID = 0x10
SECURITY_MANAGER_OUT_OF_BAND_FLAGS = 0x11 SECURITY_MANAGER_TK_VALUE = 0x10
PERIPHERAL_CONNECTION_INTERVAL_RANGE = 0x12 SECURITY_MANAGER_OUT_OF_BAND_FLAGS = 0x11
LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS = 0x14 PERIPHERAL_CONNECTION_INTERVAL_RANGE = 0x12
LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS = 0x15 LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS = 0x14
SERVICE_DATA = 0x16 LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS = 0x15
SERVICE_DATA_16_BIT_UUID = 0x16 SERVICE_DATA = 0x16
PUBLIC_TARGET_ADDRESS = 0x17 SERVICE_DATA_16_BIT_UUID = 0x16
RANDOM_TARGET_ADDRESS = 0x18 PUBLIC_TARGET_ADDRESS = 0x17
APPEARANCE = 0x19 RANDOM_TARGET_ADDRESS = 0x18
ADVERTISING_INTERVAL = 0x1A APPEARANCE = 0x19
LE_BLUETOOTH_DEVICE_ADDRESS = 0x1B ADVERTISING_INTERVAL = 0x1A
LE_ROLE = 0x1C LE_BLUETOOTH_DEVICE_ADDRESS = 0x1B
SIMPLE_PAIRING_HASH_C_256 = 0x1D LE_ROLE = 0x1C
SIMPLE_PAIRING_RANDOMIZER_R_256 = 0x1E SIMPLE_PAIRING_HASH_C_256 = 0x1D
LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS = 0x1F SIMPLE_PAIRING_RANDOMIZER_R_256 = 0x1E
SERVICE_DATA_32_BIT_UUID = 0x20 LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS = 0x1F
SERVICE_DATA_128_BIT_UUID = 0x21 SERVICE_DATA_32_BIT_UUID = 0x20
LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE = 0x22 SERVICE_DATA_128_BIT_UUID = 0x21
LE_SECURE_CONNECTIONS_RANDOM_VALUE = 0x23 LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE = 0x22
URI = 0x24 LE_SECURE_CONNECTIONS_RANDOM_VALUE = 0x23
INDOOR_POSITIONING = 0x25 URI = 0x24
TRANSPORT_DISCOVERY_DATA = 0x26 INDOOR_POSITIONING = 0x25
LE_SUPPORTED_FEATURES = 0x27 TRANSPORT_DISCOVERY_DATA = 0x26
CHANNEL_MAP_UPDATE_INDICATION = 0x28 LE_SUPPORTED_FEATURES = 0x27
PB_ADV = 0x29 CHANNEL_MAP_UPDATE_INDICATION = 0x28
MESH_MESSAGE = 0x2A PB_ADV = 0x29
MESH_BEACON = 0x2B MESH_MESSAGE = 0x2A
BIGINFO = 0x2C MESH_BEACON = 0x2B
BROADCAST_CODE = 0x2D BIGINFO = 0x2C
RESOLVABLE_SET_IDENTIFIER = 0x2E BROADCAST_CODE = 0x2D
ADVERTISING_INTERVAL_LONG = 0x2F RESOLVABLE_SET_IDENTIFIER = 0x2E
BROADCAST_NAME = 0x30 ADVERTISING_INTERVAL_LONG = 0x2F
ENCRYPTED_ADVERTISING_DATA = 0X31 THREE_D_INFORMATION_DATA = 0x3D
PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION = 0X32 MANUFACTURER_SPECIFIC_DATA = 0xFF
ELECTRONIC_SHELF_LABEL = 0X34
THREE_D_INFORMATION_DATA = 0x3D
MANUFACTURER_SPECIFIC_DATA = 0xFF
AD_TYPE_NAMES = { AD_TYPE_NAMES = {
FLAGS: 'FLAGS', FLAGS: 'FLAGS',
INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS: 'INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS', INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS: 'INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS',
COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS: 'COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS', COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS: 'COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS',
INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS: 'INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS', INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS: 'INCOMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS',
COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS: 'COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS', COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS: 'COMPLETE_LIST_OF_32_BIT_SERVICE_CLASS_UUIDS',
INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS: 'INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS', INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS: 'INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS',
COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS: 'COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS', COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS: 'COMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS',
SHORTENED_LOCAL_NAME: 'SHORTENED_LOCAL_NAME', SHORTENED_LOCAL_NAME: 'SHORTENED_LOCAL_NAME',
COMPLETE_LOCAL_NAME: 'COMPLETE_LOCAL_NAME', COMPLETE_LOCAL_NAME: 'COMPLETE_LOCAL_NAME',
TX_POWER_LEVEL: 'TX_POWER_LEVEL', TX_POWER_LEVEL: 'TX_POWER_LEVEL',
CLASS_OF_DEVICE: 'CLASS_OF_DEVICE', CLASS_OF_DEVICE: 'CLASS_OF_DEVICE',
SIMPLE_PAIRING_HASH_C: 'SIMPLE_PAIRING_HASH_C', SIMPLE_PAIRING_HASH_C: 'SIMPLE_PAIRING_HASH_C',
SIMPLE_PAIRING_HASH_C_192: 'SIMPLE_PAIRING_HASH_C_192', SIMPLE_PAIRING_HASH_C_192: 'SIMPLE_PAIRING_HASH_C_192',
SIMPLE_PAIRING_RANDOMIZER_R: 'SIMPLE_PAIRING_RANDOMIZER_R', SIMPLE_PAIRING_RANDOMIZER_R: 'SIMPLE_PAIRING_RANDOMIZER_R',
SIMPLE_PAIRING_RANDOMIZER_R_192: 'SIMPLE_PAIRING_RANDOMIZER_R_192', SIMPLE_PAIRING_RANDOMIZER_R_192: 'SIMPLE_PAIRING_RANDOMIZER_R_192',
DEVICE_ID: 'DEVICE_ID', DEVICE_ID: 'DEVICE_ID',
SECURITY_MANAGER_TK_VALUE: 'SECURITY_MANAGER_TK_VALUE', SECURITY_MANAGER_TK_VALUE: 'SECURITY_MANAGER_TK_VALUE',
SECURITY_MANAGER_OUT_OF_BAND_FLAGS: 'SECURITY_MANAGER_OUT_OF_BAND_FLAGS', SECURITY_MANAGER_OUT_OF_BAND_FLAGS: 'SECURITY_MANAGER_OUT_OF_BAND_FLAGS',
PERIPHERAL_CONNECTION_INTERVAL_RANGE: 'PERIPHERAL_CONNECTION_INTERVAL_RANGE', PERIPHERAL_CONNECTION_INTERVAL_RANGE: 'PERIPHERAL_CONNECTION_INTERVAL_RANGE',
LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS: 'LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS', LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS: 'LIST_OF_16_BIT_SERVICE_SOLICITATION_UUIDS',
LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS: 'LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS', LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS: 'LIST_OF_128_BIT_SERVICE_SOLICITATION_UUIDS',
SERVICE_DATA_16_BIT_UUID: 'SERVICE_DATA_16_BIT_UUID', SERVICE_DATA: 'SERVICE_DATA',
PUBLIC_TARGET_ADDRESS: 'PUBLIC_TARGET_ADDRESS', SERVICE_DATA_16_BIT_UUID: 'SERVICE_DATA_16_BIT_UUID',
RANDOM_TARGET_ADDRESS: 'RANDOM_TARGET_ADDRESS', PUBLIC_TARGET_ADDRESS: 'PUBLIC_TARGET_ADDRESS',
APPEARANCE: 'APPEARANCE', RANDOM_TARGET_ADDRESS: 'RANDOM_TARGET_ADDRESS',
ADVERTISING_INTERVAL: 'ADVERTISING_INTERVAL', APPEARANCE: 'APPEARANCE',
LE_BLUETOOTH_DEVICE_ADDRESS: 'LE_BLUETOOTH_DEVICE_ADDRESS', ADVERTISING_INTERVAL: 'ADVERTISING_INTERVAL',
LE_ROLE: 'LE_ROLE', LE_BLUETOOTH_DEVICE_ADDRESS: 'LE_BLUETOOTH_DEVICE_ADDRESS',
SIMPLE_PAIRING_HASH_C_256: 'SIMPLE_PAIRING_HASH_C_256', LE_ROLE: 'LE_ROLE',
SIMPLE_PAIRING_RANDOMIZER_R_256: 'SIMPLE_PAIRING_RANDOMIZER_R_256', SIMPLE_PAIRING_HASH_C_256: 'SIMPLE_PAIRING_HASH_C_256',
LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS: 'LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS', SIMPLE_PAIRING_RANDOMIZER_R_256: 'SIMPLE_PAIRING_RANDOMIZER_R_256',
SERVICE_DATA_32_BIT_UUID: 'SERVICE_DATA_32_BIT_UUID', LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS: 'LIST_OF_32_BIT_SERVICE_SOLICITATION_UUIDS',
SERVICE_DATA_128_BIT_UUID: 'SERVICE_DATA_128_BIT_UUID', SERVICE_DATA_32_BIT_UUID: 'SERVICE_DATA_32_BIT_UUID',
LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE: 'LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE', SERVICE_DATA_128_BIT_UUID: 'SERVICE_DATA_128_BIT_UUID',
LE_SECURE_CONNECTIONS_RANDOM_VALUE: 'LE_SECURE_CONNECTIONS_RANDOM_VALUE', LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE: 'LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE',
URI: 'URI', LE_SECURE_CONNECTIONS_RANDOM_VALUE: 'LE_SECURE_CONNECTIONS_RANDOM_VALUE',
INDOOR_POSITIONING: 'INDOOR_POSITIONING', URI: 'URI',
TRANSPORT_DISCOVERY_DATA: 'TRANSPORT_DISCOVERY_DATA', INDOOR_POSITIONING: 'INDOOR_POSITIONING',
LE_SUPPORTED_FEATURES: 'LE_SUPPORTED_FEATURES', TRANSPORT_DISCOVERY_DATA: 'TRANSPORT_DISCOVERY_DATA',
CHANNEL_MAP_UPDATE_INDICATION: 'CHANNEL_MAP_UPDATE_INDICATION', LE_SUPPORTED_FEATURES: 'LE_SUPPORTED_FEATURES',
PB_ADV: 'PB_ADV', CHANNEL_MAP_UPDATE_INDICATION: 'CHANNEL_MAP_UPDATE_INDICATION',
MESH_MESSAGE: 'MESH_MESSAGE', PB_ADV: 'PB_ADV',
MESH_BEACON: 'MESH_BEACON', MESH_MESSAGE: 'MESH_MESSAGE',
BIGINFO: 'BIGINFO', MESH_BEACON: 'MESH_BEACON',
BROADCAST_CODE: 'BROADCAST_CODE', BIGINFO: 'BIGINFO',
RESOLVABLE_SET_IDENTIFIER: 'RESOLVABLE_SET_IDENTIFIER', BROADCAST_CODE: 'BROADCAST_CODE',
ADVERTISING_INTERVAL_LONG: 'ADVERTISING_INTERVAL_LONG', RESOLVABLE_SET_IDENTIFIER: 'RESOLVABLE_SET_IDENTIFIER',
BROADCAST_NAME: 'BROADCAST_NAME', ADVERTISING_INTERVAL_LONG: 'ADVERTISING_INTERVAL_LONG',
ENCRYPTED_ADVERTISING_DATA: 'ENCRYPTED_ADVERTISING_DATA', THREE_D_INFORMATION_DATA: 'THREE_D_INFORMATION_DATA',
PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION: 'PERIODIC_ADVERTISING_RESPONSE_TIMING_INFORMATION', MANUFACTURER_SPECIFIC_DATA: 'MANUFACTURER_SPECIFIC_DATA'
ELECTRONIC_SHELF_LABEL: 'ELECTRONIC_SHELF_LABEL',
THREE_D_INFORMATION_DATA: 'THREE_D_INFORMATION_DATA',
MANUFACTURER_SPECIFIC_DATA: 'MANUFACTURER_SPECIFIC_DATA'
} }
LE_LIMITED_DISCOVERABLE_MODE_FLAG = 0x01 LE_LIMITED_DISCOVERABLE_MODE_FLAG = 0x01
@@ -1392,8 +820,8 @@ class AdvertisingData:
ad_structures = [] ad_structures = []
self.ad_structures = ad_structures[:] self.ad_structures = ad_structures[:]
@classmethod @staticmethod
def from_bytes(cls, data: bytes) -> AdvertisingData: def from_bytes(data):
instance = AdvertisingData() instance = AdvertisingData()
instance.append(data) instance.append(data)
return instance return instance
@@ -1482,11 +910,7 @@ class AdvertisingData:
ad_data_str = f'company={company_name}, data={ad_data[2:].hex()}' ad_data_str = f'company={company_name}, data={ad_data[2:].hex()}'
elif ad_type == AdvertisingData.APPEARANCE: elif ad_type == AdvertisingData.APPEARANCE:
ad_type_str = 'Appearance' ad_type_str = 'Appearance'
appearance = Appearance.from_int(struct.unpack_from('<H', ad_data, 0)[0]) ad_data_str = ad_data.hex()
ad_data_str = str(appearance)
elif ad_type == AdvertisingData.BROADCAST_NAME:
ad_type_str = 'Broadcast Name'
ad_data_str = ad_data.decode('utf-8')
else: else:
ad_type_str = AdvertisingData.AD_TYPE_NAMES.get(ad_type, f'0x{ad_type:02X}') ad_type_str = AdvertisingData.AD_TYPE_NAMES.get(ad_type, f'0x{ad_type:02X}')
ad_data_str = ad_data.hex() ad_data_str = ad_data.hex()
@@ -1495,7 +919,7 @@ class AdvertisingData:
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
@staticmethod @staticmethod
def ad_data_to_object(ad_type: int, ad_data: bytes) -> AdvertisingDataObject: def ad_data_to_object(ad_type: int, ad_data: bytes) -> AdvertisingObject:
if ad_type in ( if ad_type in (
AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.COMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
@@ -1530,14 +954,16 @@ class AdvertisingData:
AdvertisingData.SHORTENED_LOCAL_NAME, AdvertisingData.SHORTENED_LOCAL_NAME,
AdvertisingData.COMPLETE_LOCAL_NAME, AdvertisingData.COMPLETE_LOCAL_NAME,
AdvertisingData.URI, AdvertisingData.URI,
AdvertisingData.BROADCAST_NAME,
): ):
return ad_data.decode("utf-8") return ad_data.decode("utf-8")
if ad_type in (AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS): if ad_type in (AdvertisingData.TX_POWER_LEVEL, AdvertisingData.FLAGS):
return cast(int, struct.unpack('B', ad_data)[0]) return cast(int, struct.unpack('B', ad_data)[0])
if ad_type in (AdvertisingData.ADVERTISING_INTERVAL,): if ad_type in (
AdvertisingData.APPEARANCE,
AdvertisingData.ADVERTISING_INTERVAL,
):
return cast(int, struct.unpack('<H', ad_data)[0]) return cast(int, struct.unpack('<H', ad_data)[0])
if ad_type == AdvertisingData.CLASS_OF_DEVICE: if ad_type == AdvertisingData.CLASS_OF_DEVICE:
@@ -1549,14 +975,9 @@ class AdvertisingData:
if ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA: if ad_type == AdvertisingData.MANUFACTURER_SPECIFIC_DATA:
return (cast(int, struct.unpack_from('<H', ad_data, 0)[0]), ad_data[2:]) return (cast(int, struct.unpack_from('<H', ad_data, 0)[0]), ad_data[2:])
if ad_type == AdvertisingData.APPEARANCE:
return Appearance.from_int(
cast(int, struct.unpack_from('<H', ad_data, 0)[0])
)
return ad_data return ad_data
def append(self, data: bytes) -> None: def append(self, data):
offset = 0 offset = 0
while offset + 1 < len(data): while offset + 1 < len(data):
length = data[offset] length = data[offset]
@@ -1567,27 +988,27 @@ class AdvertisingData:
self.ad_structures.append((ad_type, ad_data)) self.ad_structures.append((ad_type, ad_data))
offset += length offset += length
def get_all(self, type_id: int, raw: bool = False) -> List[AdvertisingDataObject]: def get_all(self, type_id: int, raw: bool = False) -> List[AdvertisingObject]:
''' '''
Get Advertising Data Structure(s) with a given type Get Advertising Data Structure(s) with a given type
Returns a (possibly empty) list of matches. Returns a (possibly empty) list of matches.
''' '''
def process_ad_data(ad_data: bytes) -> AdvertisingDataObject: def process_ad_data(ad_data: bytes) -> AdvertisingObject:
return ad_data if raw else self.ad_data_to_object(type_id, ad_data) return ad_data if raw else self.ad_data_to_object(type_id, ad_data)
return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id] return [process_ad_data(ad[1]) for ad in self.ad_structures if ad[0] == type_id]
def get(self, type_id: int, raw: bool = False) -> Optional[AdvertisingDataObject]: def get(self, type_id: int, raw: bool = False) -> Optional[AdvertisingObject]:
''' '''
Get Advertising Data Structure(s) with a given type Get Advertising Data Structure(s) with a given type
Returns the first entry, or None if no structure matches. Returns the first entry, or None if no structure matches.
''' '''
all_objects = self.get_all(type_id, raw=raw) all = self.get_all(type_id, raw=raw)
return all_objects[0] if all_objects else None return all[0] if all else None
def __bytes__(self): def __bytes__(self):
return b''.join( return b''.join(
@@ -1630,13 +1051,3 @@ class ConnectionPHY:
def __str__(self): def __str__(self):
return f'ConnectionPHY(tx_phy={self.tx_phy}, rx_phy={self.rx_phy})' return f'ConnectionPHY(tx_phy={self.tx_phy}, rx_phy={self.rx_phy})'
# -----------------------------------------------------------------------------
# LE Role
# -----------------------------------------------------------------------------
class LeRole(enum.IntEnum):
PERIPHERAL_ONLY = 0x00
CENTRAL_ONLY = 0x01
BOTH_PERIPHERAL_PREFERRED = 0x02
BOTH_CENTRAL_PREFERRED = 0x03
+66 -92
View File
@@ -21,8 +21,6 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import logging import logging
import operator import operator
@@ -31,13 +29,11 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.asymmetric.ec import ( from cryptography.hazmat.primitives.asymmetric.ec import (
generate_private_key, generate_private_key,
ECDH, ECDH,
EllipticCurvePrivateKey,
EllipticCurvePublicNumbers, EllipticCurvePublicNumbers,
EllipticCurvePrivateNumbers, EllipticCurvePrivateNumbers,
SECP256R1, SECP256R1,
) )
from cryptography.hazmat.primitives import cmac from cryptography.hazmat.primitives import cmac
from typing import Tuple
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -50,18 +46,16 @@ logger = logging.getLogger(__name__)
# Classes # Classes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class EccKey: class EccKey:
def __init__(self, private_key: EllipticCurvePrivateKey) -> None: def __init__(self, private_key):
self.private_key = private_key self.private_key = private_key
@classmethod @classmethod
def generate(cls) -> EccKey: def generate(cls):
private_key = generate_private_key(SECP256R1()) private_key = generate_private_key(SECP256R1())
return cls(private_key) return cls(private_key)
@classmethod @classmethod
def from_private_key_bytes( def from_private_key_bytes(cls, d_bytes, x_bytes, y_bytes):
cls, d_bytes: bytes, x_bytes: bytes, y_bytes: bytes
) -> EccKey:
d = int.from_bytes(d_bytes, byteorder='big', signed=False) d = int.from_bytes(d_bytes, byteorder='big', signed=False)
x = int.from_bytes(x_bytes, byteorder='big', signed=False) x = int.from_bytes(x_bytes, byteorder='big', signed=False)
y = int.from_bytes(y_bytes, byteorder='big', signed=False) y = int.from_bytes(y_bytes, byteorder='big', signed=False)
@@ -71,7 +65,7 @@ class EccKey:
return cls(private_key) return cls(private_key)
@property @property
def x(self) -> bytes: def x(self):
return ( return (
self.private_key.public_key() self.private_key.public_key()
.public_numbers() .public_numbers()
@@ -79,14 +73,14 @@ class EccKey:
) )
@property @property
def y(self) -> bytes: def y(self):
return ( return (
self.private_key.public_key() self.private_key.public_key()
.public_numbers() .public_numbers()
.y.to_bytes(32, byteorder='big') .y.to_bytes(32, byteorder='big')
) )
def dh(self, public_key_x: bytes, public_key_y: bytes) -> bytes: def dh(self, public_key_x, public_key_y):
x = int.from_bytes(public_key_x, byteorder='big', signed=False) x = int.from_bytes(public_key_x, byteorder='big', signed=False)
y = int.from_bytes(public_key_y, byteorder='big', signed=False) y = int.from_bytes(public_key_y, byteorder='big', signed=False)
public_key = EllipticCurvePublicNumbers(x, y, SECP256R1()).public_key() public_key = EllipticCurvePublicNumbers(x, y, SECP256R1()).public_key()
@@ -99,33 +93,14 @@ class EccKey:
# Functions # Functions
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def generate_prand() -> bytes: def xor(x, y):
'''Generates random 3 bytes, with the 2 most significant bits of 0b01.
See Bluetooth spec, Vol 6, Part E - Table 1.2.
'''
prand_bytes = secrets.token_bytes(6)
return prand_bytes[:2] + bytes([(prand_bytes[2] & 0b01111111) | 0b01000000])
# -----------------------------------------------------------------------------
def xor(x: bytes, y: bytes) -> bytes:
assert len(x) == len(y) assert len(x) == len(y)
return bytes(map(operator.xor, x, y)) return bytes(map(operator.xor, x, y))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def reverse(input: bytes) -> bytes: def r():
'''
Returns bytes of input in reversed endianness.
'''
return input[::-1]
# -----------------------------------------------------------------------------
def r() -> bytes:
''' '''
Generate 16 bytes of random data Generate 16 bytes of random data
''' '''
@@ -133,20 +108,20 @@ def r() -> bytes:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def e(key: bytes, data: bytes) -> bytes: def e(key, data):
''' '''
AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output. AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output.
See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e
''' '''
cipher = Cipher(algorithms.AES(reverse(key)), modes.ECB()) cipher = Cipher(algorithms.AES(bytes(reversed(key))), modes.ECB())
encryptor = cipher.encryptor() encryptor = cipher.encryptor()
return reverse(encryptor.update(reverse(data))) return bytes(reversed(encryptor.update(bytes(reversed(data)))))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def ah(k: bytes, r: bytes) -> bytes: # pylint: disable=redefined-outer-name def ah(k, r): # pylint: disable=redefined-outer-name
''' '''
See Bluetooth spec Vol 3, Part H - 2.2.2 Random Address Hash function ah See Bluetooth spec Vol 3, Part H - 2.2.2 Random Address Hash function ah
''' '''
@@ -157,16 +132,7 @@ def ah(k: bytes, r: bytes) -> bytes: # pylint: disable=redefined-outer-name
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def c1( def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-name
k: bytes,
r: bytes,
preq: bytes,
pres: bytes,
iat: int,
rat: int,
ia: bytes,
ra: bytes,
) -> bytes: # pylint: disable=redefined-outer-name
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for
LE Legacy Pairing LE Legacy Pairing
@@ -178,7 +144,7 @@ def c1(
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def s1(k: bytes, r1: bytes, r2: bytes) -> bytes: def s1(k, r1, r2):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy
Pairing Pairing
@@ -188,7 +154,7 @@ def s1(k: bytes, r1: bytes, r2: bytes) -> bytes:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def aes_cmac(m: bytes, k: bytes) -> bytes: def aes_cmac(m, k):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC
@@ -200,16 +166,20 @@ def aes_cmac(m: bytes, k: bytes) -> bytes:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def f4(u: bytes, v: bytes, x: bytes, z: bytes) -> bytes: def f4(u, v, x, z):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value
Generation Function f4 Generation Function f4
''' '''
return reverse(aes_cmac(reverse(u) + reverse(v) + z, reverse(x))) return bytes(
reversed(
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def f5(w: bytes, n1: bytes, n2: bytes, a1: bytes, a2: bytes) -> Tuple[bytes, bytes]: def f5(w, n1, n2, a1, a2):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation
Function f5 Function f5
@@ -217,83 +187,87 @@ def f5(w: bytes, n1: bytes, n2: bytes, a1: bytes, a2: bytes) -> Tuple[bytes, byt
NOTE: this returns a tuple: (MacKey, LTK) in little-endian byte order NOTE: this returns a tuple: (MacKey, LTK) in little-endian byte order
''' '''
salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE') salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE')
t = aes_cmac(reverse(w), salt) t = aes_cmac(bytes(reversed(w)), salt)
key_id = bytes([0x62, 0x74, 0x6C, 0x65]) key_id = bytes([0x62, 0x74, 0x6C, 0x65])
return ( return (
reverse( bytes(
aes_cmac( reversed(
bytes([0]) aes_cmac(
+ key_id bytes([0])
+ reverse(n1) + key_id
+ reverse(n2) + bytes(reversed(n1))
+ reverse(a1) + bytes(reversed(n2))
+ reverse(a2) + bytes(reversed(a1))
+ bytes([1, 0]), + bytes(reversed(a2))
t, + bytes([1, 0]),
t,
)
) )
), ),
reverse( bytes(
aes_cmac( reversed(
bytes([1]) aes_cmac(
+ key_id bytes([1])
+ reverse(n1) + key_id
+ reverse(n2) + bytes(reversed(n1))
+ reverse(a1) + bytes(reversed(n2))
+ reverse(a2) + bytes(reversed(a1))
+ bytes([1, 0]), + bytes(reversed(a2))
t, + bytes([1, 0]),
t,
)
) )
), ),
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def f6( def f6(w, n1, n2, r, io_cap, a1, a2): # pylint: disable=redefined-outer-name
w: bytes, n1: bytes, n2: bytes, r: bytes, io_cap: bytes, a1: bytes, a2: bytes
) -> bytes: # pylint: disable=redefined-outer-name
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value
Generation Function f6 Generation Function f6
''' '''
return reverse( return bytes(
aes_cmac( reversed(
reverse(n1) aes_cmac(
+ reverse(n2) bytes(reversed(n1))
+ reverse(r) + bytes(reversed(n2))
+ reverse(io_cap) + bytes(reversed(r))
+ reverse(a1) + bytes(reversed(io_cap))
+ reverse(a2), + bytes(reversed(a1))
reverse(w), + bytes(reversed(a2)),
bytes(reversed(w)),
)
) )
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def g2(u: bytes, v: bytes, x: bytes, y: bytes) -> int: def g2(u, v, x, y):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison
Value Generation Function g2 Value Generation Function g2
''' '''
return int.from_bytes( return int.from_bytes(
aes_cmac( aes_cmac(
reverse(u) + reverse(v) + reverse(y), bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)),
reverse(x), bytes(reversed(x)),
)[-4:], )[-4:],
byteorder='big', byteorder='big',
) )
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def h6(w: bytes, key_id: bytes) -> bytes: def h6(w, key_id):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.10 Link key conversion function h6 See Bluetooth spec, Vol 3, Part H - 2.2.10 Link key conversion function h6
''' '''
return reverse(aes_cmac(key_id, reverse(w))) return aes_cmac(key_id, w)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def h7(salt: bytes, w: bytes) -> bytes: def h7(salt, w):
''' '''
See Bluetooth spec, Vol 3, Part H - 2.2.11 Link key conversion function h7 See Bluetooth spec, Vol 3, Part H - 2.2.11 Link key conversion function h7
''' '''
return reverse(aes_cmac(reverse(w), salt)) return aes_cmac(w, salt)
+365 -1888
View File
File diff suppressed because it is too large Load Diff
+32 -28
View File
@@ -19,17 +19,12 @@ like loading firmware after a cold start.
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations import abc
import logging import logging
import pathlib import pathlib
import platform import platform
from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING from . import rtk
from . import rtk, intel
from .common import Driver
if TYPE_CHECKING:
from bumble.host import Host
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -37,31 +32,40 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Driver(abc.ABC):
"""Base class for drivers."""
@staticmethod
async def for_host(_host):
"""Return a driver instance for a host.
Args:
host: Host object for which a driver should be created.
Returns:
A Driver instance if a driver should be instantiated for this host, or
None if no driver instance of this class is needed.
"""
return None
@abc.abstractmethod
async def init_controller(self):
"""Initialize the controller."""
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Functions # Functions
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_driver_for_host(host: Host) -> Optional[Driver]: async def get_driver_for_host(host):
"""Probe diver classes until one returns a valid instance for a host, or none is """Probe all known diver classes until one returns a valid instance for a host,
found. or none is found.
If a "driver" HCI metadata entry is present, only that driver class will be probed.
""" """
driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver, "intel": intel.Driver} if driver := await rtk.Driver.for_host(host):
probe_list: Iterable[str] logger.debug("Instantiated RTK driver")
if driver_name := host.hci_metadata.get("driver"): return driver
# Only probe a single driver
probe_list = [driver_name]
else:
# Probe all drivers
probe_list = driver_classes.keys()
for driver_name in probe_list:
if driver_class := driver_classes.get(driver_name):
logger.debug(f"Probing driver class: {driver_name}")
if driver := await driver_class.for_host(host):
logger.debug(f"Instantiated {driver_name} driver")
return driver
else:
logger.debug(f"Skipping unknown driver class: {driver_name}")
return None return None
-45
View File
@@ -1,45 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common types for drivers.
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import abc
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Driver(abc.ABC):
"""Base class for drivers."""
@staticmethod
async def for_host(_host):
"""Return a driver instance for a host.
Args:
host: Host object for which a driver should be created.
Returns:
A Driver instance if a driver should be instantiated for this host, or
None if no driver instance of this class is needed.
"""
return None
@abc.abstractmethod
async def init_controller(self):
"""Initialize the controller."""
-102
View File
@@ -1,102 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
from bumble.drivers import common
from bumble.hci import (
hci_vendor_command_op_code, # type: ignore
HCI_Command,
HCI_Reset_Command,
)
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constant
# -----------------------------------------------------------------------------
INTEL_USB_PRODUCTS = {
# Intel AX210
(0x8087, 0x0032),
# Intel BE200
(0x8087, 0x0036),
}
# -----------------------------------------------------------------------------
# HCI Commands
# -----------------------------------------------------------------------------
HCI_INTEL_DDC_CONFIG_WRITE_COMMAND = hci_vendor_command_op_code(0xFC8B) # type: ignore
HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD = [0x03, 0xE4, 0x02, 0x00]
HCI_Command.register_commands(globals())
@HCI_Command.command( # type: ignore
fields=[("params", "*")],
return_parameters_fields=[
("params", "*"),
],
)
class Hci_Intel_DDC_Config_Write_Command(HCI_Command):
pass
class Driver(common.Driver):
def __init__(self, host):
self.host = host
@staticmethod
def check(host):
driver = host.hci_metadata.get("driver")
if driver == "intel":
return True
vendor_id = host.hci_metadata.get("vendor_id")
product_id = host.hci_metadata.get("product_id")
if vendor_id is None or product_id is None:
logger.debug("USB metadata not sufficient")
return False
if (vendor_id, product_id) not in INTEL_USB_PRODUCTS:
logger.debug(
f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
)
return False
return True
@classmethod
async def for_host(cls, host, force=False): # type: ignore
# Only instantiate this driver if explicitly selected
if not force and not cls.check(host):
return None
return cls(host)
async def init_controller(self):
self.host.ready = True
await self.host.send_command(HCI_Reset_Command(), check_result=True)
await self.host.send_command(
Hci_Intel_DDC_Config_Write_Command(
params=HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD
)
)
+4 -11
View File
@@ -41,7 +41,7 @@ from bumble.hci import (
HCI_Reset_Command, HCI_Reset_Command,
HCI_Read_Local_Version_Information_Command, HCI_Read_Local_Version_Information_Command,
) )
from bumble.drivers import common
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -285,7 +285,7 @@ class Firmware:
) )
class Driver(common.Driver): class Driver:
@dataclass @dataclass
class DriverInfo: class DriverInfo:
rom: int rom: int
@@ -470,12 +470,8 @@ class Driver(common.Driver):
logger.debug("USB metadata not found") logger.debug("USB metadata not found")
return False return False
if host.hci_metadata.get('driver') == 'rtk': vendor_id = host.hci_metadata.get("vendor_id", None)
# Forced driver product_id = host.hci_metadata.get("product_id", None)
return True
vendor_id = host.hci_metadata.get("vendor_id")
product_id = host.hci_metadata.get("product_id")
if vendor_id is None or product_id is None: if vendor_id is None or product_id is None:
logger.debug("USB metadata not sufficient") logger.debug("USB metadata not sufficient")
return False return False
@@ -490,9 +486,6 @@ class Driver(common.Driver):
@classmethod @classmethod
async def driver_info_for_host(cls, host): async def driver_info_for_host(cls, host):
await host.send_command(HCI_Reset_Command(), check_result=True)
host.ready = True # Needed to let the host know the controller is ready.
response = await host.send_command( response = await host.send_command(
HCI_Read_Local_Version_Information_Command(), check_result=True HCI_Read_Local_Version_Information_Command(), check_result=True
) )
-1
View File
@@ -36,7 +36,6 @@ logger = logging.getLogger(__name__)
# Classes # Classes
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class GenericAccessService(Service): class GenericAccessService(Service):
def __init__(self, device_name, appearance=(0, 0)): def __init__(self, device_name, appearance=(0, 0)):
+59 -186
View File
@@ -23,28 +23,16 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio
import enum import enum
import functools import functools
import logging import logging
import struct import struct
from typing import ( from typing import Optional, Sequence, Iterable, List, Union
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Union,
TYPE_CHECKING,
)
from bumble.colors import color from .colors import color
from bumble.core import UUID from .core import UUID, get_dict_key_by_value
from bumble.att import Attribute, AttributeValue from .att import Attribute
if TYPE_CHECKING:
from bumble.gatt_client import AttributeProxy
from bumble.device import Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -105,35 +93,20 @@ GATT_RECONNECTION_CONFIGURATION_SERVICE = UUID.from_16_bits(0x1829, 'Reconne
GATT_INSULIN_DELIVERY_SERVICE = UUID.from_16_bits(0x183A, 'Insulin Delivery') GATT_INSULIN_DELIVERY_SERVICE = UUID.from_16_bits(0x183A, 'Insulin Delivery')
GATT_BINARY_SENSOR_SERVICE = UUID.from_16_bits(0x183B, 'Binary Sensor') GATT_BINARY_SENSOR_SERVICE = UUID.from_16_bits(0x183B, 'Binary Sensor')
GATT_EMERGENCY_CONFIGURATION_SERVICE = UUID.from_16_bits(0x183C, 'Emergency Configuration') GATT_EMERGENCY_CONFIGURATION_SERVICE = UUID.from_16_bits(0x183C, 'Emergency Configuration')
GATT_AUTHORIZATION_CONTROL_SERVICE = UUID.from_16_bits(0x183D, 'Authorization Control')
GATT_PHYSICAL_ACTIVITY_MONITOR_SERVICE = UUID.from_16_bits(0x183E, 'Physical Activity Monitor') GATT_PHYSICAL_ACTIVITY_MONITOR_SERVICE = UUID.from_16_bits(0x183E, 'Physical Activity Monitor')
GATT_ELAPSED_TIME_SERVICE = UUID.from_16_bits(0x183F, 'Elapsed Time')
GATT_GENERIC_HEALTH_SENSOR_SERVICE = UUID.from_16_bits(0x1840, 'Generic Health Sensor')
GATT_AUDIO_INPUT_CONTROL_SERVICE = UUID.from_16_bits(0x1843, 'Audio Input Control') GATT_AUDIO_INPUT_CONTROL_SERVICE = UUID.from_16_bits(0x1843, 'Audio Input Control')
GATT_VOLUME_CONTROL_SERVICE = UUID.from_16_bits(0x1844, 'Volume Control') GATT_VOLUME_CONTROL_SERVICE = UUID.from_16_bits(0x1844, 'Volume Control')
GATT_VOLUME_OFFSET_CONTROL_SERVICE = UUID.from_16_bits(0x1845, 'Volume Offset Control') GATT_VOLUME_OFFSET_CONTROL_SERVICE = UUID.from_16_bits(0x1845, 'Volume Offset Control')
GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification') GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification Service')
GATT_DEVICE_TIME_SERVICE = UUID.from_16_bits(0x1847, 'Device Time') GATT_DEVICE_TIME_SERVICE = UUID.from_16_bits(0x1847, 'Device Time')
GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control') GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control Service')
GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control') GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control Service')
GATT_CONSTANT_TONE_EXTENSION_SERVICE = UUID.from_16_bits(0x184A, 'Constant Tone Extension') GATT_CONSTANT_TONE_EXTENSION_SERVICE = UUID.from_16_bits(0x184A, 'Constant Tone Extension')
GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer') GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer Service')
GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer') GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer Service')
GATT_MICROPHONE_CONTROL_SERVICE = UUID.from_16_bits(0x184D, 'Microphone Control') GATT_MICROPHONE_CONTROL_SERVICE = UUID.from_16_bits(0x184D, 'Microphone Control')
GATT_AUDIO_STREAM_CONTROL_SERVICE = UUID.from_16_bits(0x184E, 'Audio Stream Control')
GATT_BROADCAST_AUDIO_SCAN_SERVICE = UUID.from_16_bits(0x184F, 'Broadcast Audio Scan')
GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE = UUID.from_16_bits(0x1850, 'Published Audio Capabilities')
GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1851, 'Basic Audio Announcement')
GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1852, 'Broadcast Audio Announcement')
GATT_COMMON_AUDIO_SERVICE = UUID.from_16_bits(0x1853, 'Common Audio')
GATT_HEARING_ACCESS_SERVICE = UUID.from_16_bits(0x1854, 'Hearing Access')
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE = UUID.from_16_bits(0x1855, 'Telephony and Media Audio')
GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1856, 'Public Broadcast Announcement')
GATT_ELECTRONIC_SHELF_LABEL_SERVICE = UUID.from_16_bits(0X1857, 'Electronic Shelf Label')
GATT_GAMING_AUDIO_SERVICE = UUID.from_16_bits(0x1858, 'Gaming Audio')
GATT_MESH_PROXY_SOLICITATION_SERVICE = UUID.from_16_bits(0x1859, 'Mesh Audio Solicitation')
# Attribute Types # Types
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2800, 'Primary Service') GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2800, 'Primary Service')
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2801, 'Secondary Service') GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2801, 'Secondary Service')
GATT_INCLUDE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2802, 'Include') GATT_INCLUDE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2802, 'Include')
@@ -156,8 +129,6 @@ GATT_ENVIRONMENTAL_SENSING_MEASUREMENT_DESCRIPTOR = UUID.from_16_bits(0x290C,
GATT_ENVIRONMENTAL_SENSING_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290D, 'Environmental Sensing Trigger Setting') GATT_ENVIRONMENTAL_SENSING_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290D, 'Environmental Sensing Trigger Setting')
GATT_TIME_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290E, 'Time Trigger Setting') GATT_TIME_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290E, 'Time Trigger Setting')
GATT_COMPLETE_BR_EDR_TRANSPORT_BLOCK_DATA_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Complete BR-EDR Transport Block Data') GATT_COMPLETE_BR_EDR_TRANSPORT_BLOCK_DATA_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Complete BR-EDR Transport Block Data')
GATT_OBSERVATION_SCHEDULE_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Observation Schedule')
GATT_VALID_RANGE_AND_ACCURACY_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Valid Range And Accuracy')
# Device Information Service # Device Information Service
GATT_SYSTEM_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A23, 'System ID') GATT_SYSTEM_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A23, 'System ID')
@@ -185,96 +156,6 @@ GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A39, 'Heart
# Battery Service # Battery Service
GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level') GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level')
# Telephony And Media Audio Service (TMAS)
GATT_TMAP_ROLE_CHARACTERISTIC = UUID.from_16_bits(0x2B51, 'TMAP Role')
# Audio Input Control Service (AICS)
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B77, 'Audio Input State')
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC = UUID.from_16_bits(0x2B78, 'Gain Settings Attribute')
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC = UUID.from_16_bits(0x2B79, 'Audio Input Type')
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC = UUID.from_16_bits(0x2B7A, 'Audio Input Status')
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B7B, 'Audio Input Control Point')
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC = UUID.from_16_bits(0x2B7C, 'Audio Input Description')
# Volume Control Service (VCS)
GATT_VOLUME_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B7D, 'Volume State')
GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B7E, 'Volume Control Point')
GATT_VOLUME_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2B7F, 'Volume Flags')
# Volume Offset Control Service (VOCS)
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B80, 'Volume Offset State')
GATT_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2B81, 'Audio Location')
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B82, 'Volume Offset Control Point')
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC = UUID.from_16_bits(0x2B83, 'Audio Output Description')
# Coordinated Set Identification Service (CSIS)
GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC = UUID.from_16_bits(0x2B84, 'Set Identity Resolving Key')
GATT_COORDINATED_SET_SIZE_CHARACTERISTIC = UUID.from_16_bits(0x2B85, 'Coordinated Set Size')
GATT_SET_MEMBER_LOCK_CHARACTERISTIC = UUID.from_16_bits(0x2B86, 'Set Member Lock')
GATT_SET_MEMBER_RANK_CHARACTERISTIC = UUID.from_16_bits(0x2B87, 'Set Member Rank')
# Media Control Service (MCS)
GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2B93, 'Media Player Name')
GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B94, 'Media Player Icon Object ID')
GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC = UUID.from_16_bits(0x2B95, 'Media Player Icon URL')
GATT_TRACK_CHANGED_CHARACTERISTIC = UUID.from_16_bits(0x2B96, 'Track Changed')
GATT_TRACK_TITLE_CHARACTERISTIC = UUID.from_16_bits(0x2B97, 'Track Title')
GATT_TRACK_DURATION_CHARACTERISTIC = UUID.from_16_bits(0x2B98, 'Track Duration')
GATT_TRACK_POSITION_CHARACTERISTIC = UUID.from_16_bits(0x2B99, 'Track Position')
GATT_PLAYBACK_SPEED_CHARACTERISTIC = UUID.from_16_bits(0x2B9A, 'Playback Speed')
GATT_SEEKING_SPEED_CHARACTERISTIC = UUID.from_16_bits(0x2B9B, 'Seeking Speed')
GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9C, 'Current Track Segments Object ID')
GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9D, 'Current Track Object ID')
GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9E, 'Next Track Object ID')
GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9F, 'Parent Group Object ID')
GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA0, 'Current Group Object ID')
GATT_PLAYING_ORDER_CHARACTERISTIC = UUID.from_16_bits(0x2BA1, 'Playing Order')
GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA2, 'Playing Orders Supported')
GATT_MEDIA_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BA3, 'Media State')
GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA4, 'Media Control Point')
GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA5, 'Media Control Point Opcodes Supported')
GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA6, 'Search Results Object ID')
GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA7, 'Search Control Point')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control Id')
# Telephone Bearer Service (TBS)
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB4, 'Bearer Provider Name')
GATT_BEARER_UCI_CHARACTERISTIC = UUID.from_16_bits(0x2BB5, 'Bearer UCI')
GATT_BEARER_TECHNOLOGY_CHARACTERISTIC = UUID.from_16_bits(0x2BB6, 'Bearer Technology')
GATT_BEARER_URI_SCHEMES_SUPPORTED_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2BB7, 'Bearer URI Schemes Supported List')
GATT_BEARER_SIGNAL_STRENGTH_CHARACTERISTIC = UUID.from_16_bits(0x2BB8, 'Bearer Signal Strength')
GATT_BEARER_SIGNAL_STRENGTH_REPORTING_INTERVAL_CHARACTERISTIC = UUID.from_16_bits(0x2BB9, 'Bearer Signal Strength Reporting Interval')
GATT_BEARER_LIST_CURRENT_CALLS_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Bearer List Current Calls')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBB, 'Content Control ID')
GATT_STATUS_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2BBC, 'Status Flags')
GATT_INCOMING_CALL_TARGET_BEARER_URI_CHARACTERISTIC = UUID.from_16_bits(0x2BBD, 'Incoming Call Target Bearer URI')
GATT_CALL_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BBE, 'Call State')
GATT_CALL_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BBF, 'Call Control Point')
GATT_CALL_CONTROL_POINT_OPTIONAL_OPCODES_CHARACTERISTIC = UUID.from_16_bits(0x2BC0, 'Call Control Point Optional Opcodes')
GATT_TERMINATION_REASON_CHARACTERISTIC = UUID.from_16_bits(0x2BC1, 'Termination Reason')
GATT_INCOMING_CALL_CHARACTERISTIC = UUID.from_16_bits(0x2BC2, 'Incoming Call')
GATT_CALL_FRIENDLY_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Call Friendly Name')
# Microphone Control Service (MICS)
GATT_MUTE_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Mute')
# Audio Stream Control Service (ASCS)
GATT_SINK_ASE_CHARACTERISTIC = UUID.from_16_bits(0x2BC4, 'Sink ASE')
GATT_SOURCE_ASE_CHARACTERISTIC = UUID.from_16_bits(0x2BC5, 'Source ASE')
GATT_ASE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BC6, 'ASE Control Point')
# Broadcast Audio Scan Service (BASS)
GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BC7, 'Broadcast Audio Scan Control Point')
GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BC8, 'Broadcast Receive State')
# Published Audio Capabilities Service (PACS)
GATT_SINK_PAC_CHARACTERISTIC = UUID.from_16_bits(0x2BC9, 'Sink PAC')
GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCA, 'Sink Audio Location')
GATT_SOURCE_PAC_CHARACTERISTIC = UUID.from_16_bits(0x2BCB, 'Source PAC')
GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCC, 'Source Audio Location')
GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts')
GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts')
# ASHA Service # ASHA Service
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid') GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties') GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
@@ -296,9 +177,6 @@ GATT_BOOT_KEYBOARD_INPUT_REPORT_CHARACTERISTIC = UUID.from_16_bi
GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bits(0x2A2B, 'Current Time') GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bits(0x2A2B, 'Current Time')
GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report') GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report')
GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution') GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B29, 'Client Supported Features')
GATT_DATABASE_HASH_CHARACTERISTIC = UUID.from_16_bits(0x2B2A, 'Database Hash')
GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B3A, 'Server Supported Features')
# fmt: on # fmt: on
# pylint: enable=line-too-long # pylint: enable=line-too-long
@@ -342,11 +220,9 @@ class Service(Attribute):
uuid = UUID(uuid) uuid = UUID(uuid)
super().__init__( super().__init__(
( GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE if primary
if primary else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE,
else GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE
),
Attribute.READABLE, Attribute.READABLE,
uuid.to_pdu_bytes(), uuid.to_pdu_bytes(),
) )
@@ -382,12 +258,9 @@ class TemplateService(Service):
UUID: UUID UUID: UUID
def __init__( def __init__(
self, self, characteristics: List[Characteristic], primary: bool = True
characteristics: List[Characteristic],
primary: bool = True,
included_services: List[Service] = [],
) -> None: ) -> None:
super().__init__(self.UUID, characteristics, primary, included_services) super().__init__(self.UUID, characteristics, primary)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -536,43 +409,56 @@ class CharacteristicDeclaration(Attribute):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CharacteristicValue(AttributeValue): class CharacteristicValue:
"""Same as AttributeValue, for backward compatibility""" '''
Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(self, read=None, write=None):
self._read = read
self._write = write
def read(self, connection):
return self._read(connection) if self._read else b''
def write(self, connection, value):
if self._write:
self._write(connection, value)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CharacteristicAdapter: class CharacteristicAdapter:
''' '''
An adapter that can adapt Characteristic and AttributeProxy objects An adapter that can adapt any object with `read_value` and `write_value`
by wrapping their `read_value()` and `write_value()` methods with ones that methods (like Characteristic and CharacteristicProxy objects) by wrapping
return/accept encoded/decoded values. those methods with ones that return/accept encoded/decoded values.
Objects with async methods are considered proxies, so the adaptation is one
For proxies (i.e used by a GATT client), the adaptation is one where the return where the return value of `read_value` is decoded and the value passed to
value of `read_value()` is decoded and the value passed to `write_value()` is `write_value` is encoded. Other objects are considered local characteristics
encoded. The `subscribe()` method, is wrapped with one where the values are decoded so the adaptation is one where the return value of `read_value` is encoded
before being passed to the subscriber. and the value passed to `write_value` is decoded.
If the characteristic has a `subscribe` method, it is wrapped with one where
For local values (i.e hosted by a GATT server) the adaptation is one where the the values are decoded before being passed to the subscriber.
return value of `read_value()` is encoded and the value passed to `write_value()`
is decoded.
''' '''
read_value: Callable def __init__(self, characteristic):
write_value: Callable
def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
self.wrapped_characteristic = characteristic self.wrapped_characteristic = characteristic
self.subscribers: Dict[Callable, Callable] = ( self.subscribers = {} # Map from subscriber to proxy subscriber
{}
) # Map from subscriber to proxy subscriber
if isinstance(characteristic, Characteristic): if asyncio.iscoroutinefunction(
self.read_value = self.read_encoded_value characteristic.read_value
self.write_value = self.write_encoded_value ) and asyncio.iscoroutinefunction(characteristic.write_value):
else:
self.read_value = self.read_decoded_value self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value self.write_value = self.write_decoded_value
else:
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
if hasattr(self.wrapped_characteristic, 'subscribe'):
self.subscribe = self.wrapped_subscribe self.subscribe = self.wrapped_subscribe
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
self.unsubscribe = self.wrapped_unsubscribe self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name): def __getattr__(self, name):
@@ -591,13 +477,11 @@ class CharacteristicAdapter:
else: else:
setattr(self.wrapped_characteristic, name, value) setattr(self.wrapped_characteristic, name, value)
async def read_encoded_value(self, connection): def read_encoded_value(self, connection):
return self.encode_value( return self.encode_value(self.wrapped_characteristic.read_value(connection))
await self.wrapped_characteristic.read_value(connection)
)
async def write_encoded_value(self, connection, value): def write_encoded_value(self, connection, value):
return await self.wrapped_characteristic.write_value( return self.wrapped_characteristic.write_value(
connection, self.decode_value(value) connection, self.decode_value(value)
) )
@@ -732,24 +616,13 @@ class Descriptor(Attribute):
''' '''
def __str__(self) -> str: def __str__(self) -> str:
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
value = self.value.read(None)
if isinstance(value, bytes):
value_str = value.hex()
else:
value_str = '<async>'
else:
value_str = '<...>'
return ( return (
f'Descriptor(handle=0x{self.handle:04X}, ' f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, ' f'type={self.type}, '
f'value={value_str})' f'value={self.read_value(None).hex()})'
) )
# -----------------------------------------------------------------------------
class ClientCharacteristicConfigurationBits(enum.IntFlag): class ClientCharacteristicConfigurationBits(enum.IntFlag):
''' '''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit
+24 -74
View File
@@ -38,7 +38,6 @@ from typing import (
Any, Any,
Iterable, Iterable,
Type, Type,
Set,
TYPE_CHECKING, TYPE_CHECKING,
) )
@@ -90,22 +89,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def show_services(services: Iterable[ServiceProxy]) -> None:
for service in services:
print(color(str(service), 'cyan'))
for characteristic in service.characteristics:
print(color(' ' + str(characteristic), 'magenta'))
for descriptor in characteristic.descriptors:
print(color(' ' + str(descriptor), 'green'))
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Proxies # Proxies
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -145,7 +128,7 @@ class ServiceProxy(AttributeProxy):
included_services: List[ServiceProxy] included_services: List[ServiceProxy]
@staticmethod @staticmethod
def from_client(service_class, client: Client, service_uuid: UUID): def from_client(service_class, client, service_uuid):
# The service and its characteristics are considered to have already been # The service and its characteristics are considered to have already been
# discovered # discovered
services = client.get_services_by_uuid(service_uuid) services = client.get_services_by_uuid(service_uuid)
@@ -223,11 +206,11 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.subscribe(self, subscriber, prefer_notify) return await self.client.subscribe(self, subscriber, prefer_notify)
async def unsubscribe(self, subscriber=None, force=False): async def unsubscribe(self, subscriber=None):
if subscriber in self.subscribers: if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber) subscriber = self.subscribers.pop(subscriber)
return await self.client.unsubscribe(self, subscriber, force) return await self.client.unsubscribe(self, subscriber)
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
@@ -263,12 +246,8 @@ class ProfileServiceProxy:
class Client: class Client:
services: List[ServiceProxy] services: List[ServiceProxy]
cached_values: Dict[int, Tuple[datetime, bytes]] cached_values: Dict[int, Tuple[datetime, bytes]]
notification_subscribers: Dict[ notification_subscribers: Dict[int, Callable[[bytes], Any]]
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]] indication_subscribers: Dict[int, Callable[[bytes], Any]]
]
indication_subscribers: Dict[
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
pending_response: Optional[asyncio.futures.Future[ATT_PDU]] pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
pending_request: Optional[ATT_PDU] pending_request: Optional[ATT_PDU]
@@ -278,8 +257,10 @@ class Client:
self.request_semaphore = asyncio.Semaphore(1) self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None self.pending_request = None
self.pending_response = None self.pending_response = None
self.notification_subscribers = {} # Subscriber set, by attribute handle self.notification_subscribers = (
self.indication_subscribers = {} # Subscriber set, by attribute handle {}
) # Notification subscribers, by attribute handle
self.indication_subscribers = {} # Indication subscribers, by attribute handle
self.services = [] self.services = []
self.cached_values = {} self.cached_values = {}
@@ -368,7 +349,9 @@ class Client:
if c.uuid == uuid if c.uuid == uuid
] ]
def get_attribute_grouping(self, attribute_handle: int) -> Optional[ def get_attribute_grouping(
self, attribute_handle: int
) -> Optional[
Union[ Union[
ServiceProxy, ServiceProxy,
Tuple[ServiceProxy, CharacteristicProxy], Tuple[ServiceProxy, CharacteristicProxy],
@@ -699,8 +682,8 @@ class Client:
async def discover_descriptors( async def discover_descriptors(
self, self,
characteristic: Optional[CharacteristicProxy] = None, characteristic: Optional[CharacteristicProxy] = None,
start_handle: Optional[int] = None, start_handle=None,
end_handle: Optional[int] = None, end_handle=None,
) -> List[DescriptorProxy]: ) -> List[DescriptorProxy]:
''' '''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
@@ -806,12 +789,7 @@ class Client:
return attributes return attributes
async def subscribe( async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
) -> None:
# If we haven't already discovered the descriptors for this characteristic, # If we haven't already discovered the descriptors for this characteristic,
# do it now # do it now
if not characteristic.descriptors_discovered: if not characteristic.descriptors_discovered:
@@ -848,7 +826,6 @@ class Client:
subscriber_set = subscribers.setdefault(characteristic.handle, set()) subscriber_set = subscribers.setdefault(characteristic.handle, set())
if subscriber is not None: if subscriber is not None:
subscriber_set.add(subscriber) subscriber_set.add(subscriber)
# Add the characteristic as a subscriber, which will result in the # Add the characteristic as a subscriber, which will result in the
# characteristic emitting an 'update' event when a notification or indication # characteristic emitting an 'update' event when a notification or indication
# is received # is received
@@ -856,18 +833,7 @@ class Client:
await self.write_value(cccd, struct.pack('<H', bits), with_response=True) await self.write_value(cccd, struct.pack('<H', bits), with_response=True)
async def unsubscribe( async def unsubscribe(self, characteristic, subscriber=None):
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
force: bool = False,
) -> None:
'''
Unsubscribe from a characteristic.
If `force` is True, this will write zeros to the CCCD when there are no
subscribers left, even if there were already no registered subscribers.
'''
# If we haven't already discovered the descriptors for this characteristic, # If we haven't already discovered the descriptors for this characteristic,
# do it now # do it now
if not characteristic.descriptors_discovered: if not characteristic.descriptors_discovered:
@@ -881,45 +847,31 @@ class Client:
logger.warning('unsubscribing from characteristic with no CCCD descriptor') logger.warning('unsubscribing from characteristic with no CCCD descriptor')
return return
# Check if the characteristic has subscribers
if not (
characteristic.handle in self.notification_subscribers
or characteristic.handle in self.indication_subscribers
):
if not force:
return
# Remove the subscriber(s)
if subscriber is not None: if subscriber is not None:
# Remove matching subscriber from subscriber sets # Remove matching subscriber from subscriber sets
for subscriber_set in ( for subscriber_set in (
self.notification_subscribers, self.notification_subscribers,
self.indication_subscribers, self.indication_subscribers,
): ):
if ( subscribers = subscriber_set.get(characteristic.handle, [])
subscribers := subscriber_set.get(characteristic.handle) if subscriber in subscribers:
) and subscriber in subscribers:
subscribers.remove(subscriber) subscribers.remove(subscriber)
# Cleanup if we removed the last one # Cleanup if we removed the last one
if not subscribers: if not subscribers:
del subscriber_set[characteristic.handle] del subscriber_set[characteristic.handle]
else: else:
# Remove all subscribers for this attribute from the sets # Remove all subscribers for this attribute from the sets!
self.notification_subscribers.pop(characteristic.handle, None) self.notification_subscribers.pop(characteristic.handle, None)
self.indication_subscribers.pop(characteristic.handle, None) self.indication_subscribers.pop(characteristic.handle, None)
# Update the CCCD if not self.notification_subscribers and not self.indication_subscribers:
if not (
characteristic.handle in self.notification_subscribers
or characteristic.handle in self.indication_subscribers
):
# No more subscribers left # No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True) await self.write_value(cccd, b'\x00\x00', with_response=True)
async def read_value( async def read_value(
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
) -> bytes: ) -> Any:
''' '''
See Vol 3, Part G - 4.8.1 Read Characteristic Value See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -1082,7 +1034,7 @@ class Client:
logger.warning('!!! unexpected response, there is no pending request') logger.warning('!!! unexpected response, there is no pending request')
return return
# The response should match the pending request unless it is # Sanity check: the response should match the pending request unless it is
# an error response # an error response
if att_pdu.op_code != ATT_ERROR_RESPONSE: if att_pdu.op_code != ATT_ERROR_RESPONSE:
expected_response_name = self.pending_request.name.replace( expected_response_name = self.pending_request.name.replace(
@@ -1115,7 +1067,7 @@ class Client:
def on_att_handle_value_notification(self, notification): def on_att_handle_value_notification(self, notification):
# Call all subscribers # Call all subscribers
subscribers = self.notification_subscribers.get( subscribers = self.notification_subscribers.get(
notification.attribute_handle, set() notification.attribute_handle, []
) )
if not subscribers: if not subscribers:
logger.warning('!!! received notification with no subscriber') logger.warning('!!! received notification with no subscriber')
@@ -1129,9 +1081,7 @@ class Client:
def on_att_handle_value_indication(self, indication): def on_att_handle_value_indication(self, indication):
# Call all subscribers # Call all subscribers
subscribers = self.indication_subscribers.get( subscribers = self.indication_subscribers.get(indication.attribute_handle, [])
indication.attribute_handle, set()
)
if not subscribers: if not subscribers:
logger.warning('!!! received indication with no subscriber') logger.warning('!!! received indication with no subscriber')
+26 -34
View File
@@ -31,9 +31,9 @@ import struct
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
from pyee import EventEmitter from pyee import EventEmitter
from bumble.colors import color from .colors import color
from bumble.core import UUID from .core import UUID
from bumble.att import ( from .att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR, ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR, ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID, ATT_CID,
@@ -60,7 +60,7 @@ from bumble.att import (
ATT_Write_Response, ATT_Write_Response,
Attribute, Attribute,
) )
from bumble.gatt import ( from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_MAX_ATTRIBUTE_VALUE_SIZE, GATT_MAX_ATTRIBUTE_VALUE_SIZE,
@@ -74,7 +74,6 @@ from bumble.gatt import (
Descriptor, Descriptor,
Service, Service,
) )
from bumble.utils import AsyncRunner
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Device, Connection from bumble.device import Device, Connection
@@ -328,7 +327,7 @@ class Server(EventEmitter):
f'handle=0x{characteristic.handle:04X}: {value.hex()}' f'handle=0x{characteristic.handle:04X}: {value.hex()}'
) )
# Check parameters # Sanity check
if len(value) != 2: if len(value) != 2:
logger.warning('CCCD value not 2 bytes long') logger.warning('CCCD value not 2 bytes long')
return return
@@ -380,7 +379,7 @@ class Server(EventEmitter):
# Get or encode the value # Get or encode the value
value = ( value = (
await attribute.read_value(connection) attribute.read_value(connection)
if value is None if value is None
else attribute.encode_value(value) else attribute.encode_value(value)
) )
@@ -423,7 +422,7 @@ class Server(EventEmitter):
# Get or encode the value # Get or encode the value
value = ( value = (
await attribute.read_value(connection) attribute.read_value(connection)
if value is None if value is None
else attribute.encode_value(value) else attribute.encode_value(value)
) )
@@ -445,9 +444,9 @@ class Server(EventEmitter):
assert self.pending_confirmations[connection.handle] is None assert self.pending_confirmations[connection.handle] is None
# Create a future value to hold the eventual response # Create a future value to hold the eventual response
pending_confirmation = self.pending_confirmations[connection.handle] = ( pending_confirmation = self.pending_confirmations[
asyncio.get_running_loop().create_future() connection.handle
) ] = asyncio.get_running_loop().create_future()
try: try:
self.send_gatt_pdu(connection.handle, indication.to_bytes()) self.send_gatt_pdu(connection.handle, indication.to_bytes())
@@ -651,8 +650,7 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
@AsyncRunner.run_in_task() def on_att_find_by_type_value_request(self, connection, request):
async def on_att_find_by_type_value_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
''' '''
@@ -660,13 +658,13 @@ class Server(EventEmitter):
# Build list of returned attributes # Build list of returned attributes
pdu_space_available = connection.att_mtu - 2 pdu_space_available = connection.att_mtu - 2
attributes = [] attributes = []
async for attribute in ( for attribute in (
attribute attribute
for attribute in self.attributes for attribute in self.attributes
if attribute.handle >= request.starting_handle if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type and attribute.type == request.attribute_type
and (await attribute.read_value(connection)) == request.attribute_value and attribute.read_value(connection) == request.attribute_value
and pdu_space_available >= 4 and pdu_space_available >= 4
): ):
# TODO: check permissions # TODO: check permissions
@@ -704,8 +702,7 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
@AsyncRunner.run_in_task() def on_att_read_by_type_request(self, connection, request):
async def on_att_read_by_type_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
''' '''
@@ -728,7 +725,7 @@ class Server(EventEmitter):
and pdu_space_available and pdu_space_available
): ):
try: try:
attribute_value = await attribute.read_value(connection) attribute_value = attribute.read_value(connection)
except ATT_Error as error: except ATT_Error as error:
# If the first attribute is unreadable, return an error # If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point # Otherwise return attributes up to this point
@@ -770,15 +767,14 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
@AsyncRunner.run_in_task() def on_att_read_request(self, connection, request):
async def on_att_read_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
''' '''
if attribute := self.get_attribute(request.attribute_handle): if attribute := self.get_attribute(request.attribute_handle):
try: try:
value = await attribute.read_value(connection) value = attribute.read_value(connection)
except ATT_Error as error: except ATT_Error as error:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
@@ -796,15 +792,14 @@ class Server(EventEmitter):
) )
self.send_response(connection, response) self.send_response(connection, response)
@AsyncRunner.run_in_task() def on_att_read_blob_request(self, connection, request):
async def on_att_read_blob_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
''' '''
if attribute := self.get_attribute(request.attribute_handle): if attribute := self.get_attribute(request.attribute_handle):
try: try:
value = await attribute.read_value(connection) value = attribute.read_value(connection)
except ATT_Error as error: except ATT_Error as error:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
@@ -841,8 +836,7 @@ class Server(EventEmitter):
) )
self.send_response(connection, response) self.send_response(connection, response)
@AsyncRunner.run_in_task() def on_att_read_by_group_type_request(self, connection, request):
async def on_att_read_by_group_type_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
''' '''
@@ -870,7 +864,7 @@ class Server(EventEmitter):
): ):
# No need to catch permission errors here, since these attributes # No need to catch permission errors here, since these attributes
# must all be world-readable # must all be world-readable
attribute_value = await attribute.read_value(connection) attribute_value = attribute.read_value(connection)
# Check the attribute value size # Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251) max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size: if len(attribute_value) > max_attribute_size:
@@ -909,8 +903,7 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
@AsyncRunner.run_in_task() def on_att_write_request(self, connection, request):
async def on_att_write_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
''' '''
@@ -943,13 +936,12 @@ class Server(EventEmitter):
return return
# Accept the value # Accept the value
await attribute.write_value(connection, request.attribute_value) attribute.write_value(connection, request.attribute_value)
# Done # Done
self.send_response(connection, ATT_Write_Response()) self.send_response(connection, ATT_Write_Response())
@AsyncRunner.run_in_task() def on_att_write_command(self, connection, request):
async def on_att_write_command(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
''' '''
@@ -967,9 +959,9 @@ class Server(EventEmitter):
# Accept the value # Accept the value
try: try:
await attribute.write_value(connection, request.attribute_value) attribute.write_value(connection, request.attribute_value)
except Exception as error: except Exception as error:
logger.exception(f'!!! ignoring exception: {error}') logger.warning(f'!!! ignoring exception: {error}')
def on_att_handle_value_confirmation(self, connection, _confirmation): def on_att_handle_value_confirmation(self, connection, _confirmation):
''' '''
+755 -1542
View File
File diff suppressed because it is too large Load Diff
+62 -158
View File
@@ -15,46 +15,30 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import Callable, MutableMapping
import datetime
from typing import cast, Any, Optional
import logging import logging
from bumble import avc from .colors import color
from bumble import avctp from .att import ATT_CID, ATT_PDU
from bumble import avdtp from .smp import SMP_CID, SMP_Command
from bumble import avrcp from .core import name_or_number
from bumble import crypto from .l2cap import (
from bumble import rfcomm
from bumble import sdp
from bumble.colors import color
from bumble.att import ATT_CID, ATT_PDU
from bumble.smp import SMP_CID, SMP_Command
from bumble.core import name_or_number
from bumble.l2cap import (
L2CAP_PDU, L2CAP_PDU,
L2CAP_CONNECTION_REQUEST, L2CAP_CONNECTION_REQUEST,
L2CAP_CONNECTION_RESPONSE, L2CAP_CONNECTION_RESPONSE,
L2CAP_SIGNALING_CID, L2CAP_SIGNALING_CID,
L2CAP_LE_SIGNALING_CID, L2CAP_LE_SIGNALING_CID,
L2CAP_Control_Frame, L2CAP_Control_Frame,
L2CAP_Connection_Request,
L2CAP_Connection_Response, L2CAP_Connection_Response,
) )
from bumble.hci import ( from .hci import (
Address,
HCI_EVENT_PACKET, HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET, HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT, HCI_DISCONNECTION_COMPLETE_EVENT,
HCI_AclDataPacketAssembler, HCI_AclDataPacketAssembler,
HCI_Packet,
HCI_Event,
HCI_AclDataPacket,
HCI_Disconnection_Complete_Event,
) )
from .rfcomm import RFCOMM_Frame, RFCOMM_PSM
from .sdp import SDP_PDU, SDP_PSM
from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -64,36 +48,26 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
PSM_NAMES = { PSM_NAMES = {
rfcomm.RFCOMM_PSM: 'RFCOMM', RFCOMM_PSM: 'RFCOMM',
sdp.SDP_PSM: 'SDP', SDP_PSM: 'SDP',
avdtp.AVDTP_PSM: 'AVDTP', AVDTP_PSM: 'AVDTP'
avctp.AVCTP_PSM: 'AVCTP',
# TODO: add more PSM values # TODO: add more PSM values
} }
AVCTP_PID_NAMES = {avrcp.AVRCP_PID: 'AVRCP'}
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class PacketTracer: class PacketTracer:
class AclStream: class AclStream:
psms: MutableMapping[int, int] def __init__(self, analyzer):
peer: Optional[PacketTracer.AclStream]
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
avctp_assemblers: MutableMapping[int, avctp.MessageAssembler]
def __init__(self, analyzer: PacketTracer.Analyzer) -> None:
self.analyzer = analyzer self.analyzer = analyzer
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
self.avctp_assemblers = {} # AVCTP assemblers, by source_cid
self.psms = {} # PSM, by source_cid self.psms = {} # PSM, by source_cid
self.peer = None self.peer = None # ACL stream in the other direction
# pylint: disable=too-many-nested-blocks # pylint: disable=too-many-nested-blocks
def on_acl_pdu(self, pdu: bytes) -> None: def on_acl_pdu(self, pdu):
l2cap_pdu = L2CAP_PDU.from_bytes(pdu) l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
self.analyzer.emit(l2cap_pdu)
if l2cap_pdu.cid == ATT_CID: if l2cap_pdu.cid == ATT_CID:
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload) att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
@@ -107,59 +81,46 @@ class PacketTracer:
# Check if this signals a new channel # Check if this signals a new channel
if control_frame.code == L2CAP_CONNECTION_REQUEST: if control_frame.code == L2CAP_CONNECTION_REQUEST:
connection_request = cast(L2CAP_Connection_Request, control_frame) self.psms[control_frame.source_cid] = control_frame.psm
self.psms[connection_request.source_cid] = connection_request.psm
elif control_frame.code == L2CAP_CONNECTION_RESPONSE: elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
connection_response = cast(L2CAP_Connection_Response, control_frame)
if ( if (
connection_response.result control_frame.result
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
): ):
if self.peer and ( if self.peer:
psm := self.peer.psms.get(connection_response.source_cid) if psm := self.peer.psms.get(control_frame.source_cid):
): # Found a pending connection
# Found a pending connection self.psms[control_frame.destination_cid] = psm
self.psms[connection_response.destination_cid] = psm
# For AVDTP connections, create a packet assembler for
# each direction
if psm == AVDTP_PSM:
self.avdtp_assemblers[
control_frame.source_cid
] = AVDTP_MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
control_frame.destination_cid
] = AVDTP_MessageAssembler(
self.peer.on_avdtp_message
)
# For AVDTP connections, create a packet assembler for
# each direction
if psm == avdtp.AVDTP_PSM:
self.avdtp_assemblers[
connection_response.source_cid
] = avdtp.MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
connection_response.destination_cid
] = avdtp.MessageAssembler(self.peer.on_avdtp_message)
elif psm == avctp.AVCTP_PSM:
self.avctp_assemblers[
connection_response.source_cid
] = avctp.MessageAssembler(self.on_avctp_message)
self.peer.avctp_assemblers[
connection_response.destination_cid
] = avctp.MessageAssembler(self.peer.on_avctp_message)
else: else:
# Try to find the PSM associated with this PDU # Try to find the PSM associated with this PDU
if self.peer and (psm := self.peer.psms.get(l2cap_pdu.cid)): if self.peer and (psm := self.peer.psms.get(l2cap_pdu.cid)):
if psm == sdp.SDP_PSM: if psm == SDP_PSM:
sdp_pdu = sdp.SDP_PDU.from_bytes(l2cap_pdu.payload) sdp_pdu = SDP_PDU.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(sdp_pdu) self.analyzer.emit(sdp_pdu)
elif psm == rfcomm.RFCOMM_PSM: elif psm == RFCOMM_PSM:
rfcomm_frame = rfcomm.RFCOMM_Frame.from_bytes(l2cap_pdu.payload) rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(rfcomm_frame) self.analyzer.emit(rfcomm_frame)
elif psm == avdtp.AVDTP_PSM: elif psm == AVDTP_PSM:
self.analyzer.emit( self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, ' f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}' f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
) )
if avdtp_assembler := self.avdtp_assemblers.get(l2cap_pdu.cid): assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
avdtp_assembler.on_pdu(l2cap_pdu.payload) if assembler:
elif psm == avctp.AVCTP_PSM: assembler.on_pdu(l2cap_pdu.payload)
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVCTP]: {l2cap_pdu.payload.hex()}'
)
if avctp_assembler := self.avctp_assemblers.get(l2cap_pdu.cid):
avctp_assembler.on_pdu(l2cap_pdu.payload)
else: else:
psm_string = name_or_number(PSM_NAMES, psm) psm_string = name_or_number(PSM_NAMES, psm)
self.analyzer.emit( self.analyzer.emit(
@@ -169,49 +130,22 @@ class PacketTracer:
else: else:
self.analyzer.emit(l2cap_pdu) self.analyzer.emit(l2cap_pdu)
def on_avdtp_message( def on_avdtp_message(self, transaction_label, message):
self, transaction_label: int, message: avdtp.Message
) -> None:
self.analyzer.emit( self.analyzer.emit(
f'{color("AVDTP", "green")} [{transaction_label}] {message}' f'{color("AVDTP", "green")} [{transaction_label}] {message}'
) )
def on_avctp_message( def feed_packet(self, packet):
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
):
if pid == avrcp.AVRCP_PID:
avc_frame = avc.Frame.from_bytes(payload)
details = str(avc_frame)
else:
details = payload.hex()
c_r = 'Command' if is_command else 'Response'
self.analyzer.emit(
f'{color("AVCTP", "green")} '
f'{c_r}[{transaction_label}][{name_or_number(AVCTP_PID_NAMES, pid)}] '
f'{"#" if ipid else ""}'
f'{details}'
)
def feed_packet(self, packet: HCI_AclDataPacket) -> None:
self.packet_assembler.feed_packet(packet) self.packet_assembler.feed_packet(packet)
class Analyzer: class Analyzer:
acl_streams: MutableMapping[int, PacketTracer.AclStream] def __init__(self, label, emit_message):
peer: PacketTracer.Analyzer
def __init__(self, label: str, emit_message: Callable[..., None]) -> None:
self.label = label self.label = label
self.emit_message = emit_message self.emit_message = emit_message
self.acl_streams = {} # ACL streams, by connection handle self.acl_streams = {} # ACL streams, by connection handle
self.packet_timestamp: Optional[datetime.datetime] = None self.peer = None # Analyzer in the other direction
def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream: def start_acl_stream(self, connection_handle):
logger.info( logger.info(
f'[{self.label}] +++ Creating ACL stream for connection ' f'[{self.label}] +++ Creating ACL stream for connection '
f'0x{connection_handle:04X}' f'0x{connection_handle:04X}'
@@ -226,7 +160,7 @@ class PacketTracer:
return stream return stream
def end_acl_stream(self, connection_handle: int) -> None: def end_acl_stream(self, connection_handle):
if connection_handle in self.acl_streams: if connection_handle in self.acl_streams:
logger.info( logger.info(
f'[{self.label}] --- Removing ACL stream for connection ' f'[{self.label}] --- Removing ACL stream for connection '
@@ -237,52 +171,34 @@ class PacketTracer:
# Let the other forwarder know so it can cleanup its stream as well # Let the other forwarder know so it can cleanup its stream as well
self.peer.end_acl_stream(connection_handle) self.peer.end_acl_stream(connection_handle)
def on_packet( def on_packet(self, packet):
self, timestamp: Optional[datetime.datetime], packet: HCI_Packet
) -> None:
self.packet_timestamp = timestamp
self.emit(packet) self.emit(packet)
if packet.hci_packet_type == HCI_ACL_DATA_PACKET: if packet.hci_packet_type == HCI_ACL_DATA_PACKET:
acl_packet = cast(HCI_AclDataPacket, packet)
# Look for an existing stream for this handle, create one if it is the # Look for an existing stream for this handle, create one if it is the
# first ACL packet for that connection handle # first ACL packet for that connection handle
if ( if (stream := self.acl_streams.get(packet.connection_handle)) is None:
stream := self.acl_streams.get(acl_packet.connection_handle) stream = self.start_acl_stream(packet.connection_handle)
) is None: stream.feed_packet(packet)
stream = self.start_acl_stream(acl_packet.connection_handle)
stream.feed_packet(acl_packet)
elif packet.hci_packet_type == HCI_EVENT_PACKET: elif packet.hci_packet_type == HCI_EVENT_PACKET:
event_packet = cast(HCI_Event, packet) if packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
if event_packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT: self.end_acl_stream(packet.connection_handle)
self.end_acl_stream(
cast(HCI_Disconnection_Complete_Event, packet).connection_handle
)
def emit(self, message: Any) -> None: def emit(self, message):
if self.packet_timestamp: self.emit_message(f'[{self.label}] {message}')
prefix = f"[{self.packet_timestamp.strftime('%Y-%m-%d %H:%M:%S.%f')}]"
else:
prefix = ""
self.emit_message(f'{prefix}[{self.label}] {message}')
def trace( def trace(self, packet, direction=0):
self,
packet: HCI_Packet,
direction: int = 0,
timestamp: Optional[datetime.datetime] = None,
) -> None:
if direction == 0: if direction == 0:
self.host_to_controller_analyzer.on_packet(timestamp, packet) self.host_to_controller_analyzer.on_packet(packet)
else: else:
self.controller_to_host_analyzer.on_packet(timestamp, packet) self.controller_to_host_analyzer.on_packet(packet)
def __init__( def __init__(
self, self,
host_to_controller_label: str = color('HOST->CONTROLLER', 'blue'), host_to_controller_label=color('HOST->CONTROLLER', 'blue'),
controller_to_host_label: str = color('CONTROLLER->HOST', 'cyan'), controller_to_host_label=color('CONTROLLER->HOST', 'cyan'),
emit_message: Callable[..., None] = logger.info, emit_message=logger.info,
) -> None: ):
self.host_to_controller_analyzer = PacketTracer.Analyzer( self.host_to_controller_analyzer = PacketTracer.Analyzer(
host_to_controller_label, emit_message host_to_controller_label, emit_message
) )
@@ -291,15 +207,3 @@ class PacketTracer:
) )
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer
def generate_irk() -> bytes:
return crypto.r()
def verify_rpa_with_irk(rpa: Address, irk: bytes) -> bool:
rpa_bytes = bytes(rpa)
prand_given = rpa_bytes[3:]
hash_given = rpa_bytes[:3]
hash_local = crypto.ah(irk, prand_given)
return hash_local[:3] == hash_given
+131 -1407
View File
File diff suppressed because it is too large Load Diff
+62 -285
View File
@@ -18,18 +18,18 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
import logging import logging
import asyncio
import enum import enum
import struct
from abc import ABC, abstractmethod
from pyee import EventEmitter from pyee import EventEmitter
from typing import Optional, Callable, TYPE_CHECKING from typing import Optional, Tuple, Callable, Dict, Union, TYPE_CHECKING
from typing_extensions import override
from bumble import l2cap, device from . import core, l2cap # type: ignore
from bumble.colors import color from .colors import color # type: ignore
from bumble.core import InvalidStateError, ProtocolError from .core import BT_BR_EDR_TRANSPORT, InvalidStateError, ProtocolError # type: ignore
from .hci import Address
if TYPE_CHECKING:
from bumble.device import Device, Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -48,7 +48,6 @@ HID_INTERRUPT_PSM = 0x0013
class Message: class Message:
message_type: MessageType message_type: MessageType
# Report types # Report types
class ReportType(enum.IntEnum): class ReportType(enum.IntEnum):
OTHER_REPORT = 0x00 OTHER_REPORT = 0x00
@@ -62,7 +61,6 @@ class Message:
NOT_READY = 0x01 NOT_READY = 0x01
ERR_INVALID_REPORT_ID = 0x02 ERR_INVALID_REPORT_ID = 0x02
ERR_UNSUPPORTED_REQUEST = 0x03 ERR_UNSUPPORTED_REQUEST = 0x03
ERR_INVALID_PARAMETER = 0x04
ERR_UNKNOWN = 0x0E ERR_UNKNOWN = 0x0E
ERR_FATAL = 0x0F ERR_FATAL = 0x0F
@@ -104,14 +102,13 @@ class GetReportMessage(Message):
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
packet_bytes = bytearray() packet_bytes = bytearray()
packet_bytes.append(self.report_id) packet_bytes.append(self.report_id)
if self.buffer_size == 0: packet_bytes.extend(
[(self.buffer_size & 0xFF), ((self.buffer_size >> 8) & 0xFF)]
)
if self.report_type == Message.ReportType.OTHER_REPORT:
return self.header(self.report_type) + packet_bytes return self.header(self.report_type) + packet_bytes
else: else:
return ( return self.header(0x08 | self.report_type) + packet_bytes
self.header(0x08 | self.report_type)
+ packet_bytes
+ struct.pack("<H", self.buffer_size)
)
@dataclass @dataclass
@@ -124,16 +121,6 @@ class SetReportMessage(Message):
return self.header(self.report_type) + self.data return self.header(self.report_type) + self.data
@dataclass
class SendControlData(Message):
report_type: int
data: bytes
message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes:
return self.header(self.report_type) + self.data
@dataclass @dataclass
class GetProtocolMessage(Message): class GetProtocolMessage(Message):
message_type = Message.MessageType.GET_PROTOCOL message_type = Message.MessageType.GET_PROTOCOL
@@ -175,47 +162,31 @@ class VirtualCableUnplug(Message):
return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG) return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)
# Device sends input report, host sends output report.
@dataclass @dataclass
class SendData(Message): class SendData(Message):
data: bytes data: bytes
report_type: int
message_type = Message.MessageType.DATA message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return self.header(self.report_type) + self.data return self.header(Message.ReportType.OUTPUT_REPORT) + self.data
@dataclass
class SendHandshakeMessage(Message):
result_code: int
message_type = Message.MessageType.HANDSHAKE
def __bytes__(self) -> bytes:
return self.header(self.result_code)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class HID(ABC, EventEmitter): class Host(EventEmitter):
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None l2cap_ctrl_channel: Optional[l2cap.ClassicChannel]
l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None l2cap_intr_channel: Optional[l2cap.ClassicChannel]
connection: Optional[device.Connection] = None
class Role(enum.IntEnum): def __init__(self, device: Device, connection: Connection) -> None:
HOST = 0x00
DEVICE = 0x01
def __init__(self, device: device.Device, role: Role) -> None:
super().__init__() super().__init__()
self.remote_device_bd_address: Optional[Address] = None
self.device = device self.device = device
self.role = role self.connection = connection
self.l2cap_ctrl_channel = None
self.l2cap_intr_channel = None
# Register ourselves with the L2CAP channel manager # Register ourselves with the L2CAP channel manager
device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection) device.register_l2cap_server(HID_CONTROL_PSM, self.on_connection)
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection) device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_connection)
device.on('connection', self.on_device_connection)
async def connect_control_channel(self) -> None: async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel # Create a new L2CAP connection - control channel
@@ -259,18 +230,9 @@ class HID(ABC, EventEmitter):
self.l2cap_ctrl_channel = None self.l2cap_ctrl_channel = None
await channel.disconnect() await channel.disconnect()
def on_device_connection(self, connection: device.Connection) -> None: def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
self.connection = connection
self.remote_device_bd_address = connection.peer_address
connection.on('disconnection', self.on_device_disconnection)
def on_device_disconnection(self, reason: int) -> None:
self.connection = None
def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug(f'+++ New L2CAP connection: {l2cap_channel}') logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel)) l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None: def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
if l2cap_channel.psm == HID_CONTROL_PSM: if l2cap_channel.psm == HID_CONTROL_PSM:
@@ -281,220 +243,37 @@ class HID(ABC, EventEmitter):
self.l2cap_intr_channel.sink = self.on_intr_pdu self.l2cap_intr_channel.sink = self.on_intr_pdu
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}') logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
if l2cap_channel.psm == HID_CONTROL_PSM:
self.l2cap_ctrl_channel = None
else:
self.l2cap_intr_channel = None
logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')
@abstractmethod
def on_ctrl_pdu(self, pdu: bytes) -> None:
pass
def on_intr_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
self.emit("interrupt_data", pdu)
def send_pdu_on_ctrl(self, msg: bytes) -> None:
assert self.l2cap_ctrl_channel
self.l2cap_ctrl_channel.send_pdu(msg)
def send_pdu_on_intr(self, msg: bytes) -> None:
assert self.l2cap_intr_channel
self.l2cap_intr_channel.send_pdu(msg)
def send_data(self, data: bytes) -> None:
if self.role == HID.Role.HOST:
report_type = Message.ReportType.OUTPUT_REPORT
else:
report_type = Message.ReportType.INPUT_REPORT
msg = SendData(data, report_type)
hid_message = bytes(msg)
if self.l2cap_intr_channel is not None:
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
self.send_pdu_on_intr(hid_message)
def virtual_cable_unplug(self) -> None:
msg = VirtualCableUnplug()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
# -----------------------------------------------------------------------------
class Device(HID):
class GetSetReturn(enum.IntEnum):
FAILURE = 0x00
REPORT_ID_NOT_FOUND = 0x01
ERR_UNSUPPORTED_REQUEST = 0x02
ERR_UNKNOWN = 0x03
ERR_INVALID_PARAMETER = 0x04
SUCCESS = 0xFF
class GetSetStatus:
def __init__(self) -> None:
self.data = bytearray()
self.status = 0
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE)
get_report_cb: Optional[Callable[[int, int, int], None]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
get_protocol_cb: Optional[Callable[[], None]] = None
set_protocol_cb: Optional[Callable[[int], None]] = None
@override
def on_ctrl_pdu(self, pdu: bytes) -> None: def on_ctrl_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}') logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
param = pdu[0] & 0x0F # Here we will receive all kinds of packets, parse and then call respective callbacks
message_type = pdu[0] >> 4 message_type = pdu[0] >> 4
param = pdu[0] & 0x0F
if message_type == Message.MessageType.GET_REPORT: if message_type == Message.MessageType.HANDSHAKE:
logger.debug('<<< HID GET REPORT') logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
self.handle_get_report(pdu) self.emit('handshake', Message.Handshake(param))
elif message_type == Message.MessageType.SET_REPORT:
logger.debug('<<< HID SET REPORT')
self.handle_set_report(pdu)
elif message_type == Message.MessageType.GET_PROTOCOL:
logger.debug('<<< HID GET PROTOCOL')
self.handle_get_protocol(pdu)
elif message_type == Message.MessageType.SET_PROTOCOL:
logger.debug('<<< HID SET PROTOCOL')
self.handle_set_protocol(pdu)
elif message_type == Message.MessageType.DATA: elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA') logger.debug('<<< HID CONTROL DATA')
self.emit('control_data', pdu) self.emit('data', pdu)
elif message_type == Message.MessageType.CONTROL: elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.SUSPEND: if param == Message.ControlCommand.SUSPEND:
logger.debug('<<< HID SUSPEND') logger.debug('<<< HID SUSPEND')
self.emit('suspend') self.emit('suspend', pdu)
elif param == Message.ControlCommand.EXIT_SUSPEND: elif param == Message.ControlCommand.EXIT_SUSPEND:
logger.debug('<<< HID EXIT SUSPEND') logger.debug('<<< HID EXIT SUSPEND')
self.emit('exit_suspend') self.emit('exit_suspend', pdu)
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG: elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG') logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug') self.emit('virtual_cable_unplug')
else: else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED') logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else: else:
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED') logger.debug('<<< HID CONTROL DATA')
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST) self.emit('data', pdu)
def send_handshake_message(self, result_code: int) -> None: def on_intr_pdu(self, pdu: bytes) -> None:
msg = SendHandshakeMessage(result_code) logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
hid_message = bytes(msg) self.emit("data", pdu)
logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def send_control_data(self, report_type: int, data: bytes):
msg = SendControlData(report_type=report_type, data=data)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def handle_get_report(self, pdu: bytes):
if self.get_report_cb is None:
logger.debug("GetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
buffer_flag = (pdu[0] & 0x08) >> 3
report_id = pdu[1]
logger.debug(f"buffer_flag: {buffer_flag}")
if buffer_flag == 1:
buffer_size = (pdu[3] << 8) | pdu[2]
else:
buffer_size = 0
ret = self.get_report_cb(report_id, report_type, buffer_size)
assert ret is not None
if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS:
data = bytearray()
data.append(report_id)
data.extend(ret.data)
if len(data) < self.l2cap_ctrl_channel.peer_mtu: # type: ignore[union-attr]
self.send_control_data(report_type=report_type, data=data)
else:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
self.get_report_cb = cb
logger.debug("GetReport callback registered successfully")
def handle_set_report(self, pdu: bytes):
if self.set_report_cb is None:
logger.debug("SetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
report_id = pdu[1]
report_data = pdu[2:]
report_size = len(report_data) + 1
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb(
self, cb: Callable[[int, int, int, bytes], None]
) -> None:
self.set_report_cb = cb
logger.debug("SetReport callback registered successfully")
def handle_get_protocol(self, pdu: bytes):
if self.get_protocol_cb is None:
logger.debug("GetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.get_protocol_cb()
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully")
def handle_set_protocol(self, pdu: bytes):
if self.set_protocol_cb is None:
logger.debug("SetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.set_protocol_cb(pdu[0] & 0x01)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
self.set_protocol_cb = cb
logger.debug("SetProtocol callback registered successfully")
# -----------------------------------------------------------------------------
class Host(HID):
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.HOST)
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None: def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
msg = GetReportMessage( msg = GetReportMessage(
@@ -504,52 +283,50 @@ class Host(HID):
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}') logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message) self.send_pdu_on_ctrl(hid_message)
def set_report(self, report_type: int, data: bytes) -> None: def set_report(self, report_type: int, data: bytes):
msg = SetReportMessage(report_type=report_type, data=data) msg = SetReportMessage(report_type=report_type, data=data)
hid_message = bytes(msg) hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}') logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message) self.send_pdu_on_ctrl(hid_message)
def get_protocol(self) -> None: def get_protocol(self):
msg = GetProtocolMessage() msg = GetProtocolMessage()
hid_message = bytes(msg) hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}') logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message) self.send_pdu_on_ctrl(hid_message)
def set_protocol(self, protocol_mode: int) -> None: def set_protocol(self, protocol_mode: int):
msg = SetProtocolMessage(protocol_mode=protocol_mode) msg = SetProtocolMessage(protocol_mode=protocol_mode)
hid_message = bytes(msg) hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}') logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message) self.send_pdu_on_ctrl(hid_message)
def suspend(self) -> None: def send_pdu_on_ctrl(self, msg: bytes) -> None:
self.l2cap_ctrl_channel.send_pdu(msg) # type: ignore
def send_pdu_on_intr(self, msg: bytes) -> None:
self.l2cap_intr_channel.send_pdu(msg) # type: ignore
def send_data(self, data):
msg = SendData(data)
hid_message = bytes(msg)
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
self.send_pdu_on_intr(hid_message)
def suspend(self):
msg = Suspend() msg = Suspend()
hid_message = bytes(msg) hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}') logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message) self.send_pdu_on_ctrl(msg)
def exit_suspend(self) -> None: def exit_suspend(self):
msg = ExitSuspend() msg = ExitSuspend()
hid_message = bytes(msg) hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}') logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message) self.send_pdu_on_ctrl(msg)
@override def virtual_cable_unplug(self):
def on_ctrl_pdu(self, pdu: bytes) -> None: msg = VirtualCableUnplug()
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}') hid_message = bytes(msg)
param = pdu[0] & 0x0F logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
message_type = pdu[0] >> 4 self.send_pdu_on_ctrl(msg)
if message_type == Message.MessageType.HANDSHAKE:
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
self.emit('handshake', Message.Handshake(param))
elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA')
self.emit('control_data', pdu)
elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug')
else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else:
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
+232 -512
View File
File diff suppressed because it is too large Load Diff
+8 -11
View File
@@ -25,8 +25,7 @@ import asyncio
import logging import logging
import os import os
import json import json
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from typing_extensions import Self
from .colors import color from .colors import color
from .hci import Address from .hci import Address
@@ -129,10 +128,10 @@ class PairingKeys:
def print(self, prefix=''): def print(self, prefix=''):
keys_dict = self.to_dict() keys_dict = self.to_dict()
for container_property, value in keys_dict.items(): for (container_property, value) in keys_dict.items():
if isinstance(value, dict): if isinstance(value, dict):
print(f'{prefix}{color(container_property, "cyan")}:') print(f'{prefix}{color(container_property, "cyan")}:')
for key_property, key_value in value.items(): for (key_property, key_value) in value.items():
print(f'{prefix} {color(key_property, "green")}: {key_value}') print(f'{prefix} {color(key_property, "green")}: {key_value}')
else: else:
print(f'{prefix}{color(container_property, "cyan")}: {value}') print(f'{prefix}{color(container_property, "cyan")}: {value}')
@@ -159,7 +158,7 @@ class KeyStore:
async def get_resolving_keys(self): async def get_resolving_keys(self):
all_keys = await self.get_all() all_keys = await self.get_all()
resolving_keys = [] resolving_keys = []
for name, keys in all_keys: for (name, keys) in all_keys:
if keys.irk is not None: if keys.irk is not None:
if keys.address_type is None: if keys.address_type is None:
address_type = Address.RANDOM_DEVICE_ADDRESS address_type = Address.RANDOM_DEVICE_ADDRESS
@@ -172,7 +171,7 @@ class KeyStore:
async def print(self, prefix=''): async def print(self, prefix=''):
entries = await self.get_all() entries = await self.get_all()
separator = '' separator = ''
for name, keys in entries: for (name, keys) in entries:
print(separator + prefix + color(name, 'yellow')) print(separator + prefix + color(name, 'yellow'))
keys.print(prefix=prefix + ' ') keys.print(prefix=prefix + ' ')
separator = '\n' separator = '\n'
@@ -254,10 +253,8 @@ class JsonKeyStore(KeyStore):
logger.debug(f'JSON keystore: {self.filename}') logger.debug(f'JSON keystore: {self.filename}')
@classmethod @staticmethod
def from_device( def from_device(device: Device, filename=None) -> Optional[JsonKeyStore]:
cls: Type[Self], device: Device, filename: Optional[str] = None
) -> Self:
if not filename: if not filename:
# Extract the filename from the config if there is one # Extract the filename from the config if there is one
if device.config.keystore is not None: if device.config.keystore is not None:
@@ -273,7 +270,7 @@ class JsonKeyStore(KeyStore):
else: else:
namespace = JsonKeyStore.DEFAULT_NAMESPACE namespace = JsonKeyStore.DEFAULT_NAMESPACE
return cls(namespace, filename) return JsonKeyStore(namespace, filename)
async def load(self): async def load(self):
# Try to open the file, without failing. If the file does not exist, it # Try to open the file, without failing. If the file does not exist, it
+14 -33
View File
@@ -70,7 +70,6 @@ L2CAP_LE_SIGNALING_CID = 0x05
L2CAP_MIN_LE_MTU = 23 L2CAP_MIN_LE_MTU = 23
L2CAP_MIN_BR_EDR_MTU = 48 L2CAP_MIN_BR_EDR_MTU = 48
L2CAP_MAX_BR_EDR_MTU = 65535
L2CAP_DEFAULT_MTU = 2048 # Default value for the MTU we are willing to accept L2CAP_DEFAULT_MTU = 2048 # Default value for the MTU we are willing to accept
@@ -150,10 +149,9 @@ L2CAP_INVALID_CID_IN_REQUEST_REASON = 0x0002
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535 L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23 L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23 L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533 L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048 L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2046
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048 L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256 L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256
@@ -174,7 +172,7 @@ L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE = 0x01
@dataclasses.dataclass @dataclasses.dataclass
class ClassicChannelSpec: class ClassicChannelSpec:
psm: Optional[int] = None psm: Optional[int] = None
mtu: int = L2CAP_DEFAULT_MTU mtu: int = L2CAP_MIN_BR_EDR_MTU
@dataclasses.dataclass @dataclasses.dataclass
@@ -190,11 +188,8 @@ class LeCreditBasedChannelSpec:
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
): ):
raise ValueError('max credits out of range') raise ValueError('max credits out of range')
if ( if self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU raise ValueError('MTU too small')
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
):
raise ValueError('MTU out of range')
if ( if (
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
@@ -209,7 +204,7 @@ class L2CAP_PDU:
@staticmethod @staticmethod
def from_bytes(data: bytes) -> L2CAP_PDU: def from_bytes(data: bytes) -> L2CAP_PDU:
# Check parameters # Sanity check
if len(data) < 4: if len(data) < 4:
raise ValueError('not enough data for L2CAP header') raise ValueError('not enough data for L2CAP header')
@@ -396,9 +391,6 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST
''' '''
psm: int
source_cid: int
@staticmethod @staticmethod
def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]: def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]:
psm_length = 2 psm_length = 2
@@ -440,11 +432,6 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame):
See Bluetooth spec @ Vol 3, Part A - 4.3 CONNECTION RESPONSE See Bluetooth spec @ Vol 3, Part A - 4.3 CONNECTION RESPONSE
''' '''
source_cid: int
destination_cid: int
status: int
result: int
CONNECTION_SUCCESSFUL = 0x0000 CONNECTION_SUCCESSFUL = 0x0000
CONNECTION_PENDING = 0x0001 CONNECTION_PENDING = 0x0001
CONNECTION_REFUSED_PSM_NOT_SUPPORTED = 0x0002 CONNECTION_REFUSED_PSM_NOT_SUPPORTED = 0x0002
@@ -750,8 +737,6 @@ class ClassicChannel(EventEmitter):
sink: Optional[Callable[[bytes], Any]] sink: Optional[Callable[[bytes], Any]]
state: State state: State
connection: Connection connection: Connection
mtu: int
peer_mtu: int
def __init__( def __init__(
self, self,
@@ -768,7 +753,6 @@ class ClassicChannel(EventEmitter):
self.signaling_cid = signaling_cid self.signaling_cid = signaling_cid
self.state = self.State.CLOSED self.state = self.State.CLOSED
self.mtu = mtu self.mtu = mtu
self.peer_mtu = L2CAP_MIN_BR_EDR_MTU
self.psm = psm self.psm = psm
self.source_cid = source_cid self.source_cid = source_cid
self.destination_cid = 0 self.destination_cid = 0
@@ -833,9 +817,7 @@ class ClassicChannel(EventEmitter):
# Wait for the connection to succeed or fail # Wait for the connection to succeed or fail
try: try:
return await self.connection.abort_on( return await self.connection_result
'disconnection', self.connection_result
)
finally: finally:
self.connection_result = None self.connection_result = None
@@ -867,7 +849,7 @@ class ClassicChannel(EventEmitter):
[ [
( (
L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE, L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE,
struct.pack('<H', self.mtu), struct.pack('<H', L2CAP_DEFAULT_MTU),
) )
] ]
) )
@@ -932,8 +914,8 @@ class ClassicChannel(EventEmitter):
options = L2CAP_Control_Frame.decode_configuration_options(request.options) options = L2CAP_Control_Frame.decode_configuration_options(request.options)
for option in options: for option in options:
if option[0] == L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE: if option[0] == L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE:
self.peer_mtu = struct.unpack('<H', option[1])[0] self.mtu = struct.unpack('<H', option[1])[0]
logger.debug(f'peer MTU = {self.peer_mtu}') logger.debug(f'MTU = {self.mtu}')
self.send_control_frame( self.send_control_frame(
L2CAP_Configure_Response( L2CAP_Configure_Response(
@@ -1032,7 +1014,7 @@ class ClassicChannel(EventEmitter):
return ( return (
f'Channel({self.source_cid}->{self.destination_cid}, ' f'Channel({self.source_cid}->{self.destination_cid}, '
f'PSM={self.psm}, ' f'PSM={self.psm}, '
f'MTU={self.mtu}/{self.peer_mtu}, ' f'MTU={self.mtu}, '
f'state={self.state.name})' f'state={self.state.name})'
) )
@@ -1654,13 +1636,12 @@ class ChannelManager:
def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None: def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None:
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu) pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
pdu_bytes = bytes(pdu)
logger.debug( logger.debug(
f'{color(">>> Sending L2CAP PDU", "blue")} ' f'{color(">>> Sending L2CAP PDU", "blue")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) ' f'on connection [0x{connection.handle:04X}] (CID={cid}) '
f'{connection.peer_address}: {len(pdu_bytes)} bytes, {pdu_str}' f'{connection.peer_address}: {pdu_str}'
) )
self.host.send_l2cap_pdu(connection.handle, cid, pdu_bytes) self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None: def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID): if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
@@ -1937,7 +1918,7 @@ class ChannelManager:
supervision_timeout=request.timeout, supervision_timeout=request.timeout,
min_ce_length=0, min_ce_length=0,
max_ce_length=0, max_ce_length=0,
) ) # type: ignore[call-arg]
) )
else: else:
self.send_control_frame( self.send_control_frame(
@@ -2228,7 +2209,7 @@ class ChannelManager:
# Connect # Connect
try: try:
await channel.connect() await channel.connect()
except BaseException as e: except Exception as e:
del connection_channels[source_cid] del connection_channels[source_cid]
raise e raise e
+1 -109
View File
@@ -26,13 +26,9 @@ from bumble.hci import (
HCI_SUCCESS, HCI_SUCCESS,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR, HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_CONNECTION_TIMEOUT_ERROR, HCI_CONNECTION_TIMEOUT_ERROR,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
HCI_PAGE_TIMEOUT_ERROR, HCI_PAGE_TIMEOUT_ERROR,
HCI_Connection_Complete_Event, HCI_Connection_Complete_Event,
) )
from bumble import controller
from typing import Optional, Set
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -61,8 +57,6 @@ class LocalLink:
Link bus for controllers to communicate with each other Link bus for controllers to communicate with each other
''' '''
controllers: Set[controller.Controller]
def __init__(self): def __init__(self):
self.controllers = set() self.controllers = set()
self.pending_connection = None self.pending_connection = None
@@ -85,9 +79,7 @@ class LocalLink:
return controller return controller
return None return None
def find_classic_controller( def find_classic_controller(self, address):
self, address: Address
) -> Optional[controller.Controller]:
for controller in self.controllers: for controller in self.controllers:
if controller.public_address == address: if controller.public_address == address:
return controller return controller
@@ -196,60 +188,6 @@ class LocalLink:
if peripheral_controller := self.find_controller(peripheral_address): if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk) peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
def create_cis(
self,
central_controller: controller.Controller,
peripheral_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
)
if peripheral_controller := self.find_controller(peripheral_address):
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_request,
central_controller.random_address,
cig_id,
cis_id,
)
def accept_cis(
self,
peripheral_controller: controller.Controller,
central_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
)
if central_controller := self.find_controller(central_address):
asyncio.get_running_loop().call_soon(
central_controller.on_link_cis_established, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_established, cig_id, cis_id
)
def disconnect_cis(
self,
initiator_controller: controller.Controller,
peer_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
)
if peer_controller := self.find_controller(peer_address):
asyncio.get_running_loop().call_soon(
initiator_controller.on_link_cis_disconnected, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peer_controller.on_link_cis_disconnected, cig_id, cis_id
)
############################################################ ############################################################
# Classic handlers # Classic handlers
############################################################ ############################################################
@@ -333,52 +271,6 @@ class LocalLink:
initiator_controller.public_address, int(not (initiator_new_role)) initiator_controller.public_address, int(not (initiator_new_role))
) )
def classic_sco_connect(
self,
initiator_controller: controller.Controller,
responder_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
# Initiator controller should handle it.
assert responder_controller
responder_controller.on_classic_connection_request(
initiator_controller.public_address,
link_type,
)
def classic_accept_sco_connection(
self,
responder_controller: controller.Controller,
initiator_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
)
initiator_controller = self.find_classic_controller(initiator_address)
if initiator_controller is None:
responder_controller.on_classic_sco_connection_complete(
responder_controller.public_address,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
link_type,
)
return
async def task():
initiator_controller.on_classic_sco_connection_complete(
responder_controller.public_address, HCI_SUCCESS, link_type
)
asyncio.create_task(task())
responder_controller.on_classic_sco_connection_complete(
initiator_controller.public_address, HCI_SUCCESS, link_type
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class RemoteLink: class RemoteLink:
+1 -67
View File
@@ -15,9 +15,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations
import enum import enum
from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
from .hci import ( from .hci import (
@@ -37,60 +35,7 @@ from .smp import (
SMP_ID_KEY_DISTRIBUTION_FLAG, SMP_ID_KEY_DISTRIBUTION_FLAG,
SMP_SIGN_KEY_DISTRIBUTION_FLAG, SMP_SIGN_KEY_DISTRIBUTION_FLAG,
SMP_LINK_KEY_DISTRIBUTION_FLAG, SMP_LINK_KEY_DISTRIBUTION_FLAG,
OobContext,
OobLegacyContext,
OobSharedData,
) )
from .core import AdvertisingData, LeRole
# -----------------------------------------------------------------------------
@dataclass
class OobData:
"""OOB data that can be sent from one device to another."""
address: Optional[Address] = None
role: Optional[LeRole] = None
shared_data: Optional[OobSharedData] = None
legacy_context: Optional[OobLegacyContext] = None
@classmethod
def from_ad(cls, ad: AdvertisingData) -> OobData:
instance = cls()
shared_data_c: Optional[bytes] = None
shared_data_r: Optional[bytes] = None
for ad_type, ad_data in ad.ad_structures:
if ad_type == AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS:
instance.address = Address(ad_data)
elif ad_type == AdvertisingData.LE_ROLE:
instance.role = LeRole(ad_data[0])
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE:
shared_data_c = ad_data
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE:
shared_data_r = ad_data
elif ad_type == AdvertisingData.SECURITY_MANAGER_TK_VALUE:
instance.legacy_context = OobLegacyContext(tk=ad_data)
if shared_data_c and shared_data_r:
instance.shared_data = OobSharedData(c=shared_data_c, r=shared_data_r)
return instance
def to_ad(self) -> AdvertisingData:
ad_structures = []
if self.address is not None:
ad_structures.append(
(AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS, bytes(self.address))
)
if self.role is not None:
ad_structures.append((AdvertisingData.LE_ROLE, bytes([self.role])))
if self.shared_data is not None:
ad_structures.extend(self.shared_data.to_ad().ad_structures)
if self.legacy_context is not None:
ad_structures.append(
(AdvertisingData.SECURITY_MANAGER_TK_VALUE, self.legacy_context.tk)
)
return AdvertisingData(ad_structures)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -228,14 +173,6 @@ class PairingConfig:
PUBLIC = Address.PUBLIC_DEVICE_ADDRESS PUBLIC = Address.PUBLIC_DEVICE_ADDRESS
RANDOM = Address.RANDOM_DEVICE_ADDRESS RANDOM = Address.RANDOM_DEVICE_ADDRESS
@dataclass
class OobConfig:
"""Config for OOB pairing."""
our_context: Optional[OobContext]
peer_data: Optional[OobSharedData]
legacy_context: Optional[OobLegacyContext]
def __init__( def __init__(
self, self,
sc: bool = True, sc: bool = True,
@@ -243,20 +180,17 @@ class PairingConfig:
bonding: bool = True, bonding: bool = True,
delegate: Optional[PairingDelegate] = None, delegate: Optional[PairingDelegate] = None,
identity_address_type: Optional[AddressType] = None, identity_address_type: Optional[AddressType] = None,
oob: Optional[OobConfig] = None,
) -> None: ) -> None:
self.sc = sc self.sc = sc
self.mitm = mitm self.mitm = mitm
self.bonding = bonding self.bonding = bonding
self.delegate = delegate or PairingDelegate() self.delegate = delegate or PairingDelegate()
self.identity_address_type = identity_address_type self.identity_address_type = identity_address_type
self.oob = oob
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
f'PairingConfig(sc={self.sc}, ' f'PairingConfig(sc={self.sc}, '
f'mitm={self.mitm}, bonding={self.bonding}, ' f'mitm={self.mitm}, bonding={self.bonding}, '
f'identity_address_type={self.identity_address_type}, ' f'identity_address_type={self.identity_address_type}, '
f'delegate[{self.delegate.io_capability}]), ' f'delegate[{self.delegate.io_capability}])'
f'oob[{self.oob}])'
) )
+33 -180
View File
@@ -28,18 +28,14 @@ from bumble.core import (
BT_PERIPHERAL_ROLE, BT_PERIPHERAL_ROLE,
UUID, UUID,
AdvertisingData, AdvertisingData,
Appearance,
ConnectionError, ConnectionError,
) )
from bumble.device import ( from bumble.device import (
DEVICE_DEFAULT_SCAN_INTERVAL, DEVICE_DEFAULT_SCAN_INTERVAL,
DEVICE_DEFAULT_SCAN_WINDOW, DEVICE_DEFAULT_SCAN_WINDOW,
Advertisement, Advertisement,
AdvertisingParameters,
AdvertisingEventProperties,
AdvertisingType, AdvertisingType,
Device, Device,
Phy,
) )
from bumble.gatt import Service from bumble.gatt import Service
from bumble.hci import ( from bumble.hci import (
@@ -51,12 +47,9 @@ from bumble.hci import (
from google.protobuf import any_pb2 # pytype: disable=pyi-error from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from pandora.host_grpc_aio import HostServicer from pandora.host_grpc_aio import HostServicer
from pandora import host_pb2
from pandora.host_pb2 import ( from pandora.host_pb2 import (
NOT_CONNECTABLE, NOT_CONNECTABLE,
NOT_DISCOVERABLE, NOT_DISCOVERABLE,
DISCOVERABLE_LIMITED,
DISCOVERABLE_GENERAL,
PRIMARY_1M, PRIMARY_1M,
PRIMARY_CODED, PRIMARY_CODED,
SECONDARY_1M, SECONDARY_1M,
@@ -72,7 +65,6 @@ from pandora.host_pb2 import (
ConnectResponse, ConnectResponse,
DataTypes, DataTypes,
DisconnectRequest, DisconnectRequest,
DiscoverabilityMode,
InquiryResponse, InquiryResponse,
PrimaryPhy, PrimaryPhy,
ReadLocalAddressResponse, ReadLocalAddressResponse,
@@ -102,25 +94,6 @@ SECONDARY_PHY_MAP: Dict[int, SecondaryPhy] = {
3: SECONDARY_CODED, 3: SECONDARY_CODED,
} }
PRIMARY_PHY_TO_BUMBLE_PHY_MAP: Dict[PrimaryPhy, Phy] = {
PRIMARY_1M: Phy.LE_1M,
PRIMARY_CODED: Phy.LE_CODED,
}
SECONDARY_PHY_TO_BUMBLE_PHY_MAP: Dict[SecondaryPhy, Phy] = {
SECONDARY_NONE: Phy.LE_1M,
SECONDARY_1M: Phy.LE_1M,
SECONDARY_2M: Phy.LE_2M,
SECONDARY_CODED: Phy.LE_CODED,
}
OWN_ADDRESS_MAP: Dict[host_pb2.OwnAddressType, bumble.hci.OwnAddressType] = {
host_pb2.PUBLIC: bumble.hci.OwnAddressType.PUBLIC,
host_pb2.RANDOM: bumble.hci.OwnAddressType.RANDOM,
host_pb2.RESOLVABLE_OR_PUBLIC: bumble.hci.OwnAddressType.RESOLVABLE_OR_PUBLIC,
host_pb2.RESOLVABLE_OR_RANDOM: bumble.hci.OwnAddressType.RESOLVABLE_OR_RANDOM,
}
class HostService(HostServicer): class HostService(HostServicer):
waited_connections: Set[int] waited_connections: Set[int]
@@ -288,9 +261,9 @@ class HostService(HostServicer):
self.log.debug(f"WaitDisconnection: {connection_handle}") self.log.debug(f"WaitDisconnection: {connection_handle}")
if connection := self.device.lookup_connection(connection_handle): if connection := self.device.lookup_connection(connection_handle):
disconnection_future: asyncio.Future[None] = ( disconnection_future: asyncio.Future[
asyncio.get_running_loop().create_future() None
) ] = asyncio.get_running_loop().create_future()
def on_disconnection(_: None) -> None: def on_disconnection(_: None) -> None:
disconnection_future.set_result(None) disconnection_future.set_result(None)
@@ -308,118 +281,14 @@ class HostService(HostServicer):
async def Advertise( async def Advertise(
self, request: AdvertiseRequest, context: grpc.ServicerContext self, request: AdvertiseRequest, context: grpc.ServicerContext
) -> AsyncGenerator[AdvertiseResponse, None]: ) -> AsyncGenerator[AdvertiseResponse, None]:
try: if not request.legacy:
if request.legacy: raise NotImplementedError(
async for rsp in self.legacy_advertise(request, context): "TODO: add support for extended advertising in Bumble"
yield rsp
else:
async for rsp in self.extended_advertise(request, context):
yield rsp
finally:
pass
async def extended_advertise(
self, request: AdvertiseRequest, context: grpc.ServicerContext
) -> AsyncGenerator[AdvertiseResponse, None]:
advertising_data = bytes(self.unpack_data_types(request.data))
scan_response_data = bytes(self.unpack_data_types(request.scan_response_data))
scannable = len(scan_response_data) != 0
advertising_event_properties = AdvertisingEventProperties(
is_connectable=request.connectable,
is_scannable=scannable,
is_directed=request.target is not None,
is_high_duty_cycle_directed_connectable=False,
is_legacy=False,
is_anonymous=False,
include_tx_power=False,
)
peer_address = Address.ANY
if request.target:
# Need to reverse bytes order since Bumble Address is using MSB.
target_bytes = bytes(reversed(request.target))
if request.target_variant() == "public":
peer_address = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
else:
peer_address = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
advertising_parameters = AdvertisingParameters(
advertising_event_properties=advertising_event_properties,
own_address_type=OWN_ADDRESS_MAP[request.own_address_type],
peer_address=peer_address,
primary_advertising_phy=PRIMARY_PHY_TO_BUMBLE_PHY_MAP[request.primary_phy],
secondary_advertising_phy=SECONDARY_PHY_TO_BUMBLE_PHY_MAP[
request.secondary_phy
],
)
if advertising_interval := request.interval:
advertising_parameters.primary_advertising_interval_min = int(
advertising_interval
) )
advertising_parameters.primary_advertising_interval_max = int( if request.interval:
advertising_interval raise NotImplementedError("TODO: add support for `request.interval`")
) if request.interval_range:
if interval_range := request.interval_range: raise NotImplementedError("TODO: add support for `request.interval_range`")
advertising_parameters.primary_advertising_interval_max += int(
interval_range
)
advertising_set = await self.device.create_advertising_set(
advertising_parameters=advertising_parameters,
advertising_data=advertising_data,
scan_response_data=scan_response_data,
)
pending_connection: asyncio.Future[bumble.device.Connection] = (
asyncio.get_running_loop().create_future()
)
if request.connectable:
def on_connection(connection: bumble.device.Connection) -> None:
if (
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
pending_connection.set_result(connection)
self.device.on('connection', on_connection)
try:
# Advertise until RPC is canceled
while True:
if not advertising_set.enabled:
self.log.debug('Advertise (extended)')
await advertising_set.start()
if not request.connectable:
await asyncio.sleep(1)
continue
connection = await pending_connection
pending_connection = asyncio.get_running_loop().create_future()
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
yield AdvertiseResponse(connection=Connection(cookie=cookie))
await asyncio.sleep(1)
finally:
try:
self.log.debug('Stop Advertise (extended)')
await advertising_set.stop()
await advertising_set.remove()
except Exception:
pass
async def legacy_advertise(
self, request: AdvertiseRequest, context: grpc.ServicerContext
) -> AsyncGenerator[AdvertiseResponse, None]:
if advertising_interval := request.interval:
self.device.config.advertising_interval_min = int(advertising_interval)
self.device.config.advertising_interval_max = int(advertising_interval)
if interval_range := request.interval_range:
self.device.config.advertising_interval_max += int(interval_range)
if request.primary_phy: if request.primary_phy:
raise NotImplementedError("TODO: add support for `request.primary_phy`") raise NotImplementedError("TODO: add support for `request.primary_phy`")
if request.secondary_phy: if request.secondary_phy:
@@ -487,10 +356,14 @@ class HostService(HostServicer):
target_bytes = bytes(reversed(request.target)) target_bytes = bytes(reversed(request.target))
if request.target_variant() == "public": if request.target_variant() == "public":
target = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS) target = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY advertising_type = (
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
) # FIXME: HIGH_DUTY ?
else: else:
target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS) target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY advertising_type = (
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
) # FIXME: HIGH_DUTY ?
if request.connectable: if request.connectable:
@@ -517,9 +390,9 @@ class HostService(HostServicer):
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
pending_connection: asyncio.Future[bumble.device.Connection] = ( pending_connection: asyncio.Future[
asyncio.get_running_loop().create_future() bumble.device.Connection
) ] = asyncio.get_running_loop().create_future()
self.log.debug('Wait for LE connection...') self.log.debug('Wait for LE connection...')
connection = await pending_connection connection = await pending_connection
@@ -548,15 +421,10 @@ class HostService(HostServicer):
self, request: ScanRequest, context: grpc.ServicerContext self, request: ScanRequest, context: grpc.ServicerContext
) -> AsyncGenerator[ScanningResponse, None]: ) -> AsyncGenerator[ScanningResponse, None]:
# TODO: modify `start_scanning` to accept floats instead of int for ms values # TODO: modify `start_scanning` to accept floats instead of int for ms values
self.log.debug('Scan') if request.phys:
raise NotImplementedError("TODO: add support for `request.phys`")
scanning_phys = [] self.log.debug('Scan')
if PRIMARY_1M in request.phys:
scanning_phys.append(int(Phy.LE_1M))
if PRIMARY_CODED in request.phys:
scanning_phys.append(int(Phy.LE_CODED))
if not scanning_phys:
scanning_phys = [int(Phy.LE_1M), int(Phy.LE_CODED)]
scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue() scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
handler = self.device.on('advertisement', scan_queue.put_nowait) handler = self.device.on('advertisement', scan_queue.put_nowait)
@@ -564,15 +432,12 @@ class HostService(HostServicer):
legacy=request.legacy, legacy=request.legacy,
active=not request.passive, active=not request.passive,
own_address_type=request.own_address_type, own_address_type=request.own_address_type,
scan_interval=( scan_interval=int(request.interval)
int(request.interval) if request.interval
if request.interval else DEVICE_DEFAULT_SCAN_INTERVAL,
else DEVICE_DEFAULT_SCAN_INTERVAL scan_window=int(request.window)
), if request.window
scan_window=( else DEVICE_DEFAULT_SCAN_WINDOW,
int(request.window) if request.window else DEVICE_DEFAULT_SCAN_WINDOW
),
scanning_phys=scanning_phys,
) )
try: try:
@@ -785,11 +650,9 @@ class HostService(HostServicer):
*struct.pack('<H', dt.peripheral_connection_interval_min), *struct.pack('<H', dt.peripheral_connection_interval_min),
*struct.pack( *struct.pack(
'<H', '<H',
( dt.peripheral_connection_interval_max
dt.peripheral_connection_interval_max if dt.peripheral_connection_interval_max
if dt.peripheral_connection_interval_max else dt.peripheral_connection_interval_min,
else dt.peripheral_connection_interval_min
),
), ),
] ]
), ),
@@ -871,16 +734,6 @@ class HostService(HostServicer):
) )
) )
flag_map = {
NOT_DISCOVERABLE: 0x00,
DISCOVERABLE_LIMITED: AdvertisingData.LE_LIMITED_DISCOVERABLE_MODE_FLAG,
DISCOVERABLE_GENERAL: AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG,
}
if dt.le_discoverability_mode:
flags = flag_map[dt.le_discoverability_mode]
ad_structures.append((AdvertisingData.FLAGS, flags.to_bytes(1, 'big')))
return AdvertisingData(ad_structures) return AdvertisingData(ad_structures)
def pack_data_types(self, ad: AdvertisingData) -> DataTypes: def pack_data_types(self, ad: AdvertisingData) -> DataTypes:
@@ -989,8 +842,8 @@ class HostService(HostServicer):
dt.random_target_addresses.extend( dt.random_target_addresses.extend(
[data[i * 6 :: i * 6 + 6] for i in range(int(len(data) / 6))] [data[i * 6 :: i * 6 + 6] for i in range(int(len(data) / 6))]
) )
if appearance := cast(Appearance, ad.get(AdvertisingData.APPEARANCE)): if i := cast(int, ad.get(AdvertisingData.APPEARANCE)):
dt.appearance = int(appearance) dt.appearance = i
if i := cast(int, ad.get(AdvertisingData.ADVERTISING_INTERVAL)): if i := cast(int, ad.get(AdvertisingData.ADVERTISING_INTERVAL)):
dt.advertising_interval = i dt.advertising_interval = i
if s := cast(str, ad.get(AdvertisingData.URI)): if s := cast(str, ad.get(AdvertisingData.URI)):
+7 -7
View File
@@ -110,7 +110,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty())) event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event) self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event assert answer.event == event
assert answer.answer_variant() == 'confirm' and answer.confirm is not None assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm return answer.confirm
@@ -125,7 +125,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(numeric_comparison=number)) event = self.add_origin(PairingEvent(numeric_comparison=number))
self.service.event_queue.put_nowait(event) self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event assert answer.event == event
assert answer.answer_variant() == 'confirm' and answer.confirm is not None assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm return answer.confirm
@@ -140,7 +140,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty())) event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event) self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event assert answer.event == event
if answer.answer_variant() is None: if answer.answer_variant() is None:
return None return None
@@ -157,7 +157,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty())) event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event) self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event assert answer.event == event
if answer.answer_variant() is None: if answer.answer_variant() is None:
return None return None
@@ -383,9 +383,9 @@ class SecurityService(SecurityServicer):
connection.transport connection.transport
] == request.level_variant() ] == request.level_variant()
wait_for_security: asyncio.Future[str] = ( wait_for_security: asyncio.Future[
asyncio.get_running_loop().create_future() str
) ] = asyncio.get_running_loop().create_future()
authenticate_task: Optional[asyncio.Future[None]] = None authenticate_task: Optional[asyncio.Future[None]] = None
pair_task: Optional[asyncio.Future[None]] = None pair_task: Optional[asyncio.Future[None]] = None
+2 -2
View File
@@ -18,7 +18,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import struct import struct
import logging import logging
from typing import List, Optional from typing import List
from bumble import l2cap from bumble import l2cap
from ..core import AdvertisingData from ..core import AdvertisingData
@@ -67,7 +67,7 @@ class AshaService(TemplateService):
self.emit('volume', connection, value[0]) self.emit('volume', connection, value[0])
# Handler for audio control commands # Handler for audio control commands
def on_audio_control_point_write(connection: Optional[Connection], value): def on_audio_control_point_write(connection: Connection, value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}') logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0] opcode = value[0]
if opcode == AshaService.OPCODE_START: if opcode == AshaService.OPCODE_START:
File diff suppressed because it is too large Load Diff
-52
View File
@@ -1,52 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from bumble import gatt
from bumble import gatt_client
from bumble.profiles import csip
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class CommonAudioServiceService(gatt.TemplateService):
UUID = gatt.GATT_COMMON_AUDIO_SERVICE
def __init__(
self,
coordinated_set_identification_service: csip.CoordinatedSetIdentificationService,
) -> None:
self.coordinated_set_identification_service = (
coordinated_set_identification_service
)
super().__init__(
characteristics=[],
included_services=[coordinated_set_identification_service],
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class CommonAudioServiceServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CommonAudioServiceService
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
-257
View File
@@ -1,257 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import struct
from typing import Optional, Tuple
from bumble import core
from bumble import crypto
from bumble import device
from bumble import gatt
from bumble import gatt_client
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
SET_IDENTITY_RESOLVING_KEY_LENGTH = 16
class SirkType(enum.IntEnum):
'''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''
ENCRYPTED = 0x00
PLAINTEXT = 0x01
class MemberLock(enum.IntEnum):
'''Coordinated Set Identification Service - 5.3 Set Member Lock.'''
UNLOCKED = 0x01
LOCKED = 0x02
# -----------------------------------------------------------------------------
# Crypto Toolbox
# -----------------------------------------------------------------------------
def s1(m: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.3 s1 SALT generation function.
'''
return crypto.aes_cmac(m[::-1], bytes(16))[::-1]
def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.4 k1 derivation function.
'''
t = crypto.aes_cmac(n[::-1], salt[::-1])
return crypto.aes_cmac(p[::-1], t)[::-1]
def sef(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.5 SIRK encryption function sef.
SIRK decryption function sdf shares the same algorithm. The only difference is that argument r is:
* Plaintext in encryption
* Cipher in decryption
'''
return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)
def sih(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih.
'''
return crypto.e(k, r + bytes(13))[:3]
def generate_rsi(sirk: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation.
'''
prand = crypto.generate_prand()
return sih(sirk, prand) + prand
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationService(gatt.TemplateService):
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
set_identity_resolving_key: bytes
set_identity_resolving_key_characteristic: gatt.Characteristic
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
set_member_lock_characteristic: Optional[gatt.Characteristic] = None
set_member_rank_characteristic: Optional[gatt.Characteristic] = None
def __init__(
self,
set_identity_resolving_key: bytes,
set_identity_resolving_key_type: SirkType,
coordinated_set_size: Optional[int] = None,
set_member_lock: Optional[MemberLock] = None,
set_member_rank: Optional[int] = None,
) -> None:
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
raise ValueError(
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
)
characteristics = []
self.set_identity_resolving_key = set_identity_resolving_key
self.set_identity_resolving_key_type = set_identity_resolving_key_type
self.set_identity_resolving_key_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(read=self.on_sirk_read),
)
characteristics.append(self.set_identity_resolving_key_characteristic)
if coordinated_set_size is not None:
self.coordinated_set_size_characteristic = gatt.Characteristic(
uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=struct.pack('B', coordinated_set_size),
)
characteristics.append(self.coordinated_set_size_characteristic)
if set_member_lock is not None:
self.set_member_lock_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
| gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITEABLE,
value=struct.pack('B', set_member_lock),
)
characteristics.append(self.set_member_lock_characteristic)
if set_member_rank is not None:
self.set_member_rank_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=struct.pack('B', set_member_rank),
)
characteristics.append(self.set_member_rank_characteristic)
super().__init__(characteristics)
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
sirk_bytes = self.set_identity_resolving_key
else:
assert connection
if connection.transport == core.BT_LE_TRANSPORT:
key = await connection.device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await connection.device.get_link_key(connection.peer_address)
if not key:
raise RuntimeError('LTK or LinkKey is not present')
sirk_bytes = sef(key, self.set_identity_resolving_key)
return bytes([self.set_identity_resolving_key_type]) + sirk_bytes
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER,
generate_rsi(self.set_identity_resolving_key),
),
]
)
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CoordinatedSetIdentificationService
set_identity_resolving_key: gatt_client.CharacteristicProxy
coordinated_set_size: Optional[gatt_client.CharacteristicProxy] = None
set_member_lock: Optional[gatt_client.CharacteristicProxy] = None
set_member_rank: Optional[gatt_client.CharacteristicProxy] = None
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.set_identity_resolving_key = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC
)[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC
):
self.coordinated_set_size = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC
):
self.set_member_lock = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
):
self.set_member_rank = characteristics[0]
async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
'''Reads SIRK and decrypts if encrypted.'''
response = await self.set_identity_resolving_key.read_value()
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
raise RuntimeError('Invalid SIRK value')
sirk_type = SirkType(response[0])
if sirk_type == SirkType.PLAINTEXT:
sirk = response[1:]
else:
connection = self.service_proxy.client.connection
device = connection.device
if connection.transport == core.BT_LE_TRANSPORT:
key = await device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await device.get_link_key(connection.peer_address)
if not key:
raise RuntimeError('LTK or LinkKey is not present')
sirk = sef(key, response[1:])
return (sirk_type, sirk)
+5 -14
View File
@@ -19,8 +19,8 @@
import struct import struct
from typing import Optional, Tuple from typing import Optional, Tuple
from bumble.gatt_client import ServiceProxy, ProfileServiceProxy, CharacteristicProxy from ..gatt_client import ProfileServiceProxy
from bumble.gatt import ( from ..gatt import (
GATT_DEVICE_INFORMATION_SERVICE, GATT_DEVICE_INFORMATION_SERVICE,
GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC, GATT_FIRMWARE_REVISION_STRING_CHARACTERISTIC,
GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC, GATT_HARDWARE_REVISION_STRING_CHARACTERISTIC,
@@ -59,7 +59,7 @@ class DeviceInformationService(TemplateService):
firmware_revision: Optional[str] = None, firmware_revision: Optional[str] = None,
software_revision: Optional[str] = None, software_revision: Optional[str] = None,
system_id: Optional[Tuple[int, int]] = None, # (OUI, Manufacturer ID) system_id: Optional[Tuple[int, int]] = None, # (OUI, Manufacturer ID)
ieee_regulatory_certification_data_list: Optional[bytes] = None, ieee_regulatory_certification_data_list: Optional[bytes] = None
# TODO: pnp_id # TODO: pnp_id
): ):
characteristics = [ characteristics = [
@@ -104,19 +104,10 @@ class DeviceInformationService(TemplateService):
class DeviceInformationServiceProxy(ProfileServiceProxy): class DeviceInformationServiceProxy(ProfileServiceProxy):
SERVICE_CLASS = DeviceInformationService SERVICE_CLASS = DeviceInformationService
manufacturer_name: Optional[UTF8CharacteristicAdapter] def __init__(self, service_proxy):
model_number: Optional[UTF8CharacteristicAdapter]
serial_number: Optional[UTF8CharacteristicAdapter]
hardware_revision: Optional[UTF8CharacteristicAdapter]
firmware_revision: Optional[UTF8CharacteristicAdapter]
software_revision: Optional[UTF8CharacteristicAdapter]
system_id: Optional[DelegatedCharacteristicAdapter]
ieee_regulatory_certification_data_list: Optional[CharacteristicProxy]
def __init__(self, service_proxy: ServiceProxy):
self.service_proxy = service_proxy self.service_proxy = service_proxy
for field, uuid in ( for (field, uuid) in (
('manufacturer_name', GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC), ('manufacturer_name', GATT_MANUFACTURER_NAME_STRING_CHARACTERISTIC),
('model_number', GATT_MODEL_NUMBER_STRING_CHARACTERISTIC), ('model_number', GATT_MODEL_NUMBER_STRING_CHARACTERISTIC),
('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC), ('serial_number', GATT_SERIAL_NUMBER_STRING_CHARACTERISTIC),
-49
View File
@@ -1,49 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
from typing import List
from typing_extensions import Self
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class Metadata:
@dataclasses.dataclass
class Entry:
tag: int
data: bytes
entries: List[Entry]
@classmethod
def from_bytes(cls, data: bytes) -> Self:
entries = []
offset = 0
length = len(data)
while length >= 2:
entry_length = data[offset]
entry_tag = data[offset + 1]
entry_data = data[offset + 2 : offset + 2 + entry_length - 1]
entries.append(cls.Entry(entry_tag, entry_data))
length -= entry_length
offset += entry_length
return cls(entries)
-46
View File
@@ -1,46 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import enum
from typing_extensions import Self
from bumble.profiles import le_audio
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class PublicBroadcastAnnouncement:
class Features(enum.IntFlag):
ENCRYPTED = 1 << 0
STANDARD_QUALITY_CONFIGURATION = 1 << 1
HIGH_QUALITY_CONFIGURATION = 1 << 2
features: Features
metadata: le_audio.Metadata
@classmethod
def from_bytes(cls, data: bytes) -> Self:
features = cls.Features(data[0])
metadata_length = data[1]
metadata_ltv = data[1 : 1 + metadata_length]
return cls(
features=features, metadata=le_audio.Metadata.from_bytes(metadata_ltv)
)
-228
View File
@@ -1,228 +0,0 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
from bumble import att
from bumble import device
from bumble import gatt
from bumble import gatt_client
from typing import Optional
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
MIN_VOLUME = 0
MAX_VOLUME = 255
class ErrorCode(enum.IntEnum):
'''
See Volume Control Service 1.6. Application error codes.
'''
INVALID_CHANGE_COUNTER = 0x80
OPCODE_NOT_SUPPORTED = 0x81
class VolumeFlags(enum.IntFlag):
'''
See Volume Control Service 3.3. Volume Flags.
'''
VOLUME_SETTING_PERSISTED = 0x01
# RFU
class VolumeControlPointOpcode(enum.IntEnum):
'''
See Volume Control Service Table 3.3: Volume Control Point procedure requirements.
'''
# fmt: off
RELATIVE_VOLUME_DOWN = 0x00
RELATIVE_VOLUME_UP = 0x01
UNMUTE_RELATIVE_VOLUME_DOWN = 0x02
UNMUTE_RELATIVE_VOLUME_UP = 0x03
SET_ABSOLUTE_VOLUME = 0x04
UNMUTE = 0x05
MUTE = 0x06
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class VolumeControlService(gatt.TemplateService):
UUID = gatt.GATT_VOLUME_CONTROL_SERVICE
volume_state: gatt.Characteristic
volume_control_point: gatt.Characteristic
volume_flags: gatt.Characteristic
volume_setting: int
muted: int
change_counter: int
def __init__(
self,
step_size: int = 16,
volume_setting: int = 0,
muted: int = 0,
change_counter: int = 0,
volume_flags: int = 0,
) -> None:
self.step_size = step_size
self.volume_setting = volume_setting
self.muted = muted
self.change_counter = change_counter
self.volume_state = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_STATE_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
),
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(read=self._on_read_volume_state),
)
self.volume_control_point = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(write=self._on_write_volume_control_point),
)
self.volume_flags = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=bytes([volume_flags]),
)
super().__init__(
[
self.volume_state,
self.volume_control_point,
self.volume_flags,
]
)
@property
def volume_state_bytes(self) -> bytes:
return bytes([self.volume_setting, self.muted, self.change_counter])
@volume_state_bytes.setter
def volume_state_bytes(self, new_value: bytes) -> None:
self.volume_setting, self.muted, self.change_counter = new_value
def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes:
return self.volume_state_bytes
def _on_write_volume_control_point(
self, connection: Optional[device.Connection], value: bytes
) -> None:
assert connection
opcode = VolumeControlPointOpcode(value[0])
change_counter = value[1]
if change_counter != self.change_counter:
raise att.ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
handler = getattr(self, '_on_' + opcode.name.lower())
if handler(*value[2:]):
self.change_counter = (self.change_counter + 1) % 256
connection.abort_on(
'disconnection',
connection.device.notify_subscribers(
attribute=self.volume_state,
value=self.volume_state_bytes,
),
)
self.emit(
'volume_state', self.volume_setting, self.muted, self.change_counter
)
def _on_relative_volume_down(self) -> bool:
old_volume = self.volume_setting
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
return self.volume_setting != old_volume
def _on_relative_volume_up(self) -> bool:
old_volume = self.volume_setting
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
return self.volume_setting != old_volume
def _on_unmute_relative_volume_down(self) -> bool:
old_volume, old_muted_state = self.volume_setting, self.muted
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
self.muted = 0
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
def _on_unmute_relative_volume_up(self) -> bool:
old_volume, old_muted_state = self.volume_setting, self.muted
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
self.muted = 0
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
def _on_set_absolute_volume(self, volume_setting: int) -> bool:
old_volume_setting = self.volume_setting
self.volume_setting = volume_setting
return old_volume_setting != self.volume_setting
def _on_unmute(self) -> bool:
old_muted_state = self.muted
self.muted = 0
return self.muted != old_muted_state
def _on_mute(self) -> bool:
old_muted_state = self.muted
self.muted = 1
return self.muted != old_muted_state
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class VolumeControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = VolumeControlService
volume_control_point: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.volume_state = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_STATE_CHARACTERISTIC
)[0],
'BBB',
)
self.volume_control_point = service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC
)[0]
self.volume_flags = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC
)[0],
'B',
)
+193 -354
View File
@@ -19,17 +19,12 @@ from __future__ import annotations
import logging import logging
import asyncio import asyncio
import collections
import dataclasses
import enum import enum
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from typing_extensions import Self
from pyee import EventEmitter from pyee import EventEmitter
from bumble import core from . import core, l2cap
from bumble import l2cap
from bumble import sdp
from .colors import color from .colors import color
from .core import ( from .core import (
UUID, UUID,
@@ -39,6 +34,15 @@ from .core import (
InvalidStateError, InvalidStateError,
ProtocolError, ProtocolError,
) )
from .sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_PUBLIC_BROWSE_ROOT,
DataElement,
ServiceAttribute,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Device, Connection from bumble.device import Device, Connection
@@ -55,20 +59,28 @@ logger = logging.getLogger(__name__)
# fmt: off # fmt: off
RFCOMM_PSM = 0x0003 RFCOMM_PSM = 0x0003
DEFAULT_RX_QUEUE_SIZE = 32
class FrameType(enum.IntEnum):
SABM = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
UA = 0x63 # Control field [0,1,1,0,_,0,1,1] LSB-first
DM = 0x0F # Control field [1,1,1,1,_,0,0,0] LSB-first
DISC = 0x43 # Control field [0,1,0,_,0,0,1,1] LSB-first
UIH = 0xEF # Control field [1,1,1,_,1,1,1,1] LSB-first
UI = 0x03 # Control field [0,0,0,_,0,0,1,1] LSB-first
class MccType(enum.IntEnum): # Frame types
PN = 0x20 RFCOMM_SABM_FRAME = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
MSC = 0x38 RFCOMM_UA_FRAME = 0x63 # Control field [0,1,1,0,_,0,1,1] LSB-first
RFCOMM_DM_FRAME = 0x0F # Control field [1,1,1,1,_,0,0,0] LSB-first
RFCOMM_DISC_FRAME = 0x43 # Control field [0,1,0,_,0,0,1,1] LSB-first
RFCOMM_UIH_FRAME = 0xEF # Control field [1,1,1,_,1,1,1,1] LSB-first
RFCOMM_UI_FRAME = 0x03 # Control field [0,0,0,_,0,0,1,1] LSB-first
RFCOMM_FRAME_TYPE_NAMES = {
RFCOMM_SABM_FRAME: 'SABM',
RFCOMM_UA_FRAME: 'UA',
RFCOMM_DM_FRAME: 'DM',
RFCOMM_DISC_FRAME: 'DISC',
RFCOMM_UIH_FRAME: 'UIH',
RFCOMM_UI_FRAME: 'UI'
}
# MCC Types
RFCOMM_MCC_PN_TYPE = 0x20
RFCOMM_MCC_MSC_TYPE = 0x38
# FCS CRC # FCS CRC
CRC_TABLE = bytes([ CRC_TABLE = bytes([
@@ -106,11 +118,8 @@ CRC_TABLE = bytes([
0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF 0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
]) ])
RFCOMM_DEFAULT_L2CAP_MTU = 2048 RFCOMM_DEFAULT_INITIAL_RX_CREDITS = 7
RFCOMM_DEFAULT_INITIAL_CREDITS = 7 RFCOMM_DEFAULT_PREFERRED_MTU = 1280
RFCOMM_DEFAULT_MAX_CREDITS = 32
RFCOMM_DEFAULT_CREDIT_THRESHOLD = RFCOMM_DEFAULT_MAX_CREDITS // 2
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1 RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30 RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
@@ -121,33 +130,29 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def make_service_sdp_records( def make_service_sdp_records(
service_record_handle: int, channel: int, uuid: Optional[UUID] = None service_record_handle: int, channel: int, uuid: Optional[UUID] = None
) -> List[sdp.ServiceAttribute]: ) -> List[ServiceAttribute]:
""" """
Create SDP records for an RFComm service given a channel number and an Create SDP records for an RFComm service given a channel number and an
optional UUID. A Service Class Attribute is included only if the UUID is not None. optional UUID. A Service Class Attribute is included only if the UUID is not None.
""" """
records = [ records = [
sdp.ServiceAttribute( ServiceAttribute(
sdp.SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
sdp.DataElement.unsigned_integer_32(service_record_handle), DataElement.unsigned_integer_32(service_record_handle),
), ),
sdp.ServiceAttribute( ServiceAttribute(
sdp.SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence( DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
[sdp.DataElement.uuid(sdp.SDP_PUBLIC_BROWSE_ROOT)]
),
), ),
sdp.ServiceAttribute( ServiceAttribute(
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence( DataElement.sequence(
[ [
sdp.DataElement.sequence( DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
[sdp.DataElement.uuid(BT_L2CAP_PROTOCOL_ID)] DataElement.sequence(
),
sdp.DataElement.sequence(
[ [
sdp.DataElement.uuid(BT_RFCOMM_PROTOCOL_ID), DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
sdp.DataElement.unsigned_integer_8(channel), DataElement.unsigned_integer_8(channel),
] ]
), ),
] ]
@@ -157,81 +162,15 @@ def make_service_sdp_records(
if uuid: if uuid:
records.append( records.append(
sdp.ServiceAttribute( ServiceAttribute(
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence([sdp.DataElement.uuid(uuid)]), DataElement.sequence([DataElement.uuid(uuid)]),
) )
) )
return records return records
# -----------------------------------------------------------------------------
async def find_rfcomm_channels(connection: Connection) -> Dict[int, List[UUID]]:
"""Searches all RFCOMM channels and their associated UUID from SDP service records.
Args:
connection: ACL connection to make SDP search.
Returns:
Dictionary mapping from channel number to service class UUID list.
"""
results = {}
async with sdp.Client(connection) as sdp_client:
search_result = await sdp_client.search_attributes(
uuids=[core.BT_RFCOMM_PROTOCOL_ID],
attribute_ids=[
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
],
)
for attribute_lists in search_result:
service_classes: List[UUID] = []
channel: Optional[int] = None
for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
protocol_descriptor_list = attribute.value.value
channel = protocol_descriptor_list[1].value[1].value
elif attribute.id == sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID:
service_class_id_list = attribute.value.value
service_classes = [
service_class.value for service_class in service_class_id_list
]
if not service_classes or not channel:
logger.warning(f"Bad result {attribute_lists}.")
else:
results[channel] = service_classes
return results
# -----------------------------------------------------------------------------
async def find_rfcomm_channel_with_uuid(
connection: Connection, uuid: str | UUID
) -> Optional[int]:
"""Searches an RFCOMM channel associated with given UUID from service records.
Args:
connection: ACL connection to make SDP search.
uuid: UUID of service record to search for.
Returns:
RFCOMM channel number if found, otherwise None.
"""
if isinstance(uuid, str):
uuid = UUID(uuid)
return next(
(
channel
for channel, class_id_list in (
await find_rfcomm_channels(connection)
).items()
if uuid in class_id_list
),
None,
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def compute_fcs(buffer: bytes) -> int: def compute_fcs(buffer: bytes) -> int:
result = 0xFF result = 0xFF
@@ -244,7 +183,7 @@ def compute_fcs(buffer: bytes) -> int:
class RFCOMM_Frame: class RFCOMM_Frame:
def __init__( def __init__(
self, self,
frame_type: FrameType, frame_type: int,
c_r: int, c_r: int,
dlci: int, dlci: int,
p_f: int, p_f: int,
@@ -267,11 +206,14 @@ class RFCOMM_Frame:
self.length = bytes([(length << 1) | 1]) self.length = bytes([(length << 1) | 1])
self.address = (dlci << 2) | (c_r << 1) | 1 self.address = (dlci << 2) | (c_r << 1) | 1
self.control = frame_type | (p_f << 4) self.control = frame_type | (p_f << 4)
if frame_type == FrameType.UIH: if frame_type == RFCOMM_UIH_FRAME:
self.fcs = compute_fcs(bytes([self.address, self.control])) self.fcs = compute_fcs(bytes([self.address, self.control]))
else: else:
self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length) self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)
def type_name(self) -> str:
return RFCOMM_FRAME_TYPE_NAMES[self.type]
@staticmethod @staticmethod
def parse_mcc(data) -> Tuple[int, bool, bytes]: def parse_mcc(data) -> Tuple[int, bool, bytes]:
mcc_type = data[0] >> 2 mcc_type = data[0] >> 2
@@ -295,24 +237,24 @@ class RFCOMM_Frame:
@staticmethod @staticmethod
def sabm(c_r: int, dlci: int): def sabm(c_r: int, dlci: int):
return RFCOMM_Frame(FrameType.SABM, c_r, dlci, 1) return RFCOMM_Frame(RFCOMM_SABM_FRAME, c_r, dlci, 1)
@staticmethod @staticmethod
def ua(c_r: int, dlci: int): def ua(c_r: int, dlci: int):
return RFCOMM_Frame(FrameType.UA, c_r, dlci, 1) return RFCOMM_Frame(RFCOMM_UA_FRAME, c_r, dlci, 1)
@staticmethod @staticmethod
def dm(c_r: int, dlci: int): def dm(c_r: int, dlci: int):
return RFCOMM_Frame(FrameType.DM, c_r, dlci, 1) return RFCOMM_Frame(RFCOMM_DM_FRAME, c_r, dlci, 1)
@staticmethod @staticmethod
def disc(c_r: int, dlci: int): def disc(c_r: int, dlci: int):
return RFCOMM_Frame(FrameType.DISC, c_r, dlci, 1) return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1)
@staticmethod @staticmethod
def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0): def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0):
return RFCOMM_Frame( return RFCOMM_Frame(
FrameType.UIH, c_r, dlci, p_f, information, with_credits=(p_f == 1) RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
) )
@staticmethod @staticmethod
@@ -320,7 +262,7 @@ class RFCOMM_Frame:
# Extract fields # Extract fields
dlci = (data[0] >> 2) & 0x3F dlci = (data[0] >> 2) & 0x3F
c_r = (data[0] >> 1) & 0x01 c_r = (data[0] >> 1) & 0x01
frame_type = FrameType(data[1] & 0xEF) frame_type = data[1] & 0xEF
p_f = (data[1] >> 4) & 0x01 p_f = (data[1] >> 4) & 0x01
length = data[2] length = data[2]
if length & 0x01: if length & 0x01:
@@ -349,7 +291,7 @@ class RFCOMM_Frame:
def __str__(self) -> str: def __str__(self) -> str:
return ( return (
f'{color(self.type.name, "yellow")}' f'{color(self.type_name(), "yellow")}'
f'(c/r={self.c_r},' f'(c/r={self.c_r},'
f'dlci={self.dlci},' f'dlci={self.dlci},'
f'p/f={self.p_f},' f'p/f={self.p_f},'
@@ -359,7 +301,6 @@ class RFCOMM_Frame:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_PN: class RFCOMM_MCC_PN:
dlci: int dlci: int
cl: int cl: int
@@ -367,13 +308,25 @@ class RFCOMM_MCC_PN:
ack_timer: int ack_timer: int
max_frame_size: int max_frame_size: int
max_retransmissions: int max_retransmissions: int
initial_credits: int window_size: int
def __post_init__(self) -> None: def __init__(
if self.initial_credits < 1 or self.initial_credits > 7: self,
logger.warning( dlci: int,
f'Initial credits {self.initial_credits} is out of range [1, 7].' cl: int,
) priority: int,
ack_timer: int,
max_frame_size: int,
max_retransmissions: int,
window_size: int,
) -> None:
self.dlci = dlci
self.cl = cl
self.priority = priority
self.ack_timer = ack_timer
self.max_frame_size = max_frame_size
self.max_retransmissions = max_retransmissions
self.window_size = window_size
@staticmethod @staticmethod
def from_bytes(data: bytes) -> RFCOMM_MCC_PN: def from_bytes(data: bytes) -> RFCOMM_MCC_PN:
@@ -384,7 +337,7 @@ class RFCOMM_MCC_PN:
ack_timer=data[3], ack_timer=data[3],
max_frame_size=data[4] | data[5] << 8, max_frame_size=data[4] | data[5] << 8,
max_retransmissions=data[6], max_retransmissions=data[6],
initial_credits=data[7] & 0x07, window_size=data[7],
) )
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
@@ -397,14 +350,23 @@ class RFCOMM_MCC_PN:
self.max_frame_size & 0xFF, self.max_frame_size & 0xFF,
(self.max_frame_size >> 8) & 0xFF, (self.max_frame_size >> 8) & 0xFF,
self.max_retransmissions & 0xFF, self.max_retransmissions & 0xFF,
# Only 3 bits are meaningful. self.window_size & 0xFF,
self.initial_credits & 0x07,
] ]
) )
def __str__(self) -> str:
return (
f'PN(dlci={self.dlci},'
f'cl={self.cl},'
f'priority={self.priority},'
f'ack_timer={self.ack_timer},'
f'max_frame_size={self.max_frame_size},'
f'max_retransmissions={self.max_retransmissions},'
f'window_size={self.window_size})'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_MSC: class RFCOMM_MCC_MSC:
dlci: int dlci: int
fc: int fc: int
@@ -413,6 +375,16 @@ class RFCOMM_MCC_MSC:
ic: int ic: int
dv: int dv: int
def __init__(
self, dlci: int, fc: int, rtc: int, rtr: int, ic: int, dv: int
) -> None:
self.dlci = dlci
self.fc = fc
self.rtc = rtc
self.rtr = rtr
self.ic = ic
self.dv = dv
@staticmethod @staticmethod
def from_bytes(data: bytes) -> RFCOMM_MCC_MSC: def from_bytes(data: bytes) -> RFCOMM_MCC_MSC:
return RFCOMM_MCC_MSC( return RFCOMM_MCC_MSC(
@@ -437,6 +409,16 @@ class RFCOMM_MCC_MSC:
] ]
) )
def __str__(self) -> str:
return (
f'MSC(dlci={self.dlci},'
f'fc={self.fc},'
f'rtc={self.rtc},'
f'rtr={self.rtr},'
f'ic={self.ic},'
f'dv={self.dv})'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class DLC(EventEmitter): class DLC(EventEmitter):
@@ -448,58 +430,35 @@ class DLC(EventEmitter):
DISCONNECTED = 0x04 DISCONNECTED = 0x04
RESET = 0x05 RESET = 0x05
connection_result: Optional[asyncio.Future]
sink: Optional[Callable[[bytes], None]]
def __init__( def __init__(
self, self,
multiplexer: Multiplexer, multiplexer: Multiplexer,
dlci: int, dlci: int,
tx_max_frame_size: int, max_frame_size: int,
tx_initial_credits: int, initial_tx_credits: int,
rx_max_frame_size: int,
rx_initial_credits: int,
) -> None: ) -> None:
super().__init__() super().__init__()
self.multiplexer = multiplexer self.multiplexer = multiplexer
self.dlci = dlci self.dlci = dlci
self.rx_max_frame_size = rx_max_frame_size self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
self.rx_initial_credits = rx_initial_credits self.rx_threshold = self.rx_credits // 2
self.rx_max_credits = RFCOMM_DEFAULT_MAX_CREDITS self.tx_credits = initial_tx_credits
self.rx_credits = rx_initial_credits
self.rx_credits_threshold = RFCOMM_DEFAULT_CREDIT_THRESHOLD
self.tx_max_frame_size = tx_max_frame_size
self.tx_credits = tx_initial_credits
self.tx_buffer = b'' self.tx_buffer = b''
self.state = DLC.State.INIT self.state = DLC.State.INIT
self.role = multiplexer.role self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0 self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
self.connection_result: Optional[asyncio.Future] = None self.sink = None
self.disconnection_result: Optional[asyncio.Future] = None self.connection_result = None
self.drained = asyncio.Event()
self.drained.set()
# Queued packets when sink is not set.
self._enqueued_rx_packets: collections.deque[bytes] = collections.deque(
maxlen=DEFAULT_RX_QUEUE_SIZE
)
self._sink: Optional[Callable[[bytes], None]] = None
# Compute the MTU # Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs max_overhead = 4 + 1 # header with 2-byte length + fcs
self.mtu = min( self.mtu = min(
tx_max_frame_size, self.multiplexer.l2cap_channel.peer_mtu - max_overhead max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead
) )
@property
def sink(self) -> Optional[Callable[[bytes], None]]:
return self._sink
@sink.setter
def sink(self, sink: Optional[Callable[[bytes], None]]) -> None:
self._sink = sink
# Dump queued packets to sink
if sink:
for packet in self._enqueued_rx_packets:
sink(packet) # pylint: disable=not-callable
self._enqueued_rx_packets.clear()
def change_state(self, new_state: State) -> None: def change_state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {color(new_state.name, "magenta")}') logger.debug(f'{self} state change -> {color(new_state.name, "magenta")}')
self.state = new_state self.state = new_state
@@ -508,7 +467,7 @@ class DLC(EventEmitter):
self.multiplexer.send_frame(frame) self.multiplexer.send_frame(frame)
def on_frame(self, frame: RFCOMM_Frame) -> None: def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type.name}_frame'.lower()) handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame) handler(frame)
def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None: def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
@@ -522,7 +481,9 @@ class DLC(EventEmitter):
# Exchange the modem status with the peer # Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1) msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc)) mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
logger.debug(f'>>> MCC MSC Command: {msc}') logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc)) self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
@@ -530,35 +491,22 @@ class DLC(EventEmitter):
self.emit('open') self.emit('open')
def on_ua_frame(self, _frame: RFCOMM_Frame) -> None: def on_ua_frame(self, _frame: RFCOMM_Frame) -> None:
if self.state == DLC.State.CONNECTING: if self.state != DLC.State.CONNECTING:
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.State.CONNECTED)
if self.connection_result:
self.connection_result.set_result(None)
self.connection_result = None
self.multiplexer.on_dlc_open_complete(self)
elif self.state == DLC.State.DISCONNECTING:
self.change_state(DLC.State.DISCONNECTED)
if self.disconnection_result:
self.disconnection_result.set_result(None)
self.disconnection_result = None
self.multiplexer.on_dlc_disconnection(self)
self.emit('close')
else:
logger.warning( logger.warning(
color( color('!!! received SABM when not in CONNECTING state', 'red')
(
'!!! received UA frame when not in '
'CONNECTING or DISCONNECTING state'
),
'red',
)
) )
return
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.State.CONNECTED)
self.multiplexer.on_dlc_open_complete(self)
def on_dm_frame(self, frame: RFCOMM_Frame) -> None: def on_dm_frame(self, frame: RFCOMM_Frame) -> None:
# TODO: handle all states # TODO: handle all states
@@ -586,22 +534,14 @@ class DLC(EventEmitter):
f'[{self.dlci}] {len(data)} bytes, ' f'[{self.dlci}] {len(data)} bytes, '
f'rx_credits={self.rx_credits}: {data.hex()}' f'rx_credits={self.rx_credits}: {data.hex()}'
) )
if data: if len(data) and self.sink:
if self._sink: self.sink(data) # pylint: disable=not-callable
self._sink(data) # pylint: disable=not-callable
else:
self._enqueued_rx_packets.append(data)
if (
self._enqueued_rx_packets.maxlen
and len(self._enqueued_rx_packets) >= self._enqueued_rx_packets.maxlen
):
logger.warning(f'DLC [{self.dlci}] received packet queue is full')
# Update the credits # Update the credits
if self.rx_credits > 0: if self.rx_credits > 0:
self.rx_credits -= 1 self.rx_credits -= 1
else: else:
logger.warning(color('!!! received frame with no rx credits', 'red')) logger.warning(color('!!! received frame with no rx credits', 'red'))
# Check if there's anything to send (including credits) # Check if there's anything to send (including credits)
self.process_tx() self.process_tx()
@@ -614,7 +554,9 @@ class DLC(EventEmitter):
# Command # Command
logger.debug(f'<<< MCC MSC Command: {msc}') logger.debug(f'<<< MCC MSC Command: {msc}')
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1) msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=0, data=bytes(msc)) mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc)
)
logger.debug(f'>>> MCC MSC Response: {msc}') logger.debug(f'>>> MCC MSC Response: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc)) self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
else: else:
@@ -629,19 +571,6 @@ class DLC(EventEmitter):
self.connection_result = asyncio.get_running_loop().create_future() self.connection_result = asyncio.get_running_loop().create_future()
self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci)) self.send_frame(RFCOMM_Frame.sabm(c_r=self.c_r, dlci=self.dlci))
async def disconnect(self) -> None:
if self.state != DLC.State.CONNECTED:
raise InvalidStateError('invalid state')
self.disconnection_result = asyncio.get_running_loop().create_future()
self.change_state(DLC.State.DISCONNECTING)
self.send_frame(
RFCOMM_Frame.disc(
c_r=1 if self.role == Multiplexer.Role.INITIATOR else 0, dlci=self.dlci
)
)
await self.disconnection_result
def accept(self) -> None: def accept(self) -> None:
if self.state != DLC.State.INIT: if self.state != DLC.State.INIT:
raise InvalidStateError('invalid state') raise InvalidStateError('invalid state')
@@ -651,18 +580,18 @@ class DLC(EventEmitter):
cl=0xE0, cl=0xE0,
priority=7, priority=7,
ack_timer=0, ack_timer=0,
max_frame_size=self.rx_max_frame_size, max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions=0, max_retransmissions=0,
initial_credits=self.rx_initial_credits, window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
) )
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.PN, c_r=0, data=bytes(pn)) mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
logger.debug(f'>>> PN Response: {pn}') logger.debug(f'>>> PN Response: {pn}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc)) self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.State.CONNECTING) self.change_state(DLC.State.CONNECTING)
def rx_credits_needed(self) -> int: def rx_credits_needed(self) -> int:
if self.rx_credits <= self.rx_credits_threshold: if self.rx_credits <= self.rx_threshold:
return self.rx_max_credits - self.rx_credits return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
return 0 return 0
@@ -702,8 +631,6 @@ class DLC(EventEmitter):
) )
rx_credits_needed = 0 rx_credits_needed = 0
if not self.tx_buffer:
self.drained.set()
# Stream protocol # Stream protocol
def write(self, data: Union[bytes, str]) -> None: def write(self, data: Union[bytes, str]) -> None:
@@ -716,34 +643,14 @@ class DLC(EventEmitter):
raise ValueError('write only accept bytes or strings') raise ValueError('write only accept bytes or strings')
self.tx_buffer += data self.tx_buffer += data
self.drained.clear()
self.process_tx() self.process_tx()
async def drain(self) -> None: def drain(self) -> None:
await self.drained.wait() # TODO
pass
def abort(self) -> None:
logger.debug(f'aborting DLC: {self}')
if self.connection_result:
self.connection_result.cancel()
self.connection_result = None
if self.disconnection_result:
self.disconnection_result.cancel()
self.disconnection_result = None
self.change_state(DLC.State.RESET)
self.emit('close')
def __str__(self) -> str: def __str__(self) -> str:
return ( return f'DLC(dlci={self.dlci},state={self.state.name})'
f'DLC(dlci={self.dlci}, '
f'state={self.state.name}, '
f'rx_max_frame_size={self.rx_max_frame_size}, '
f'rx_credits={self.rx_credits}, '
f'rx_max_credits={self.rx_max_credits}, '
f'tx_max_frame_size={self.tx_max_frame_size}, '
f'tx_credits={self.tx_credits}'
')'
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -764,7 +671,7 @@ class Multiplexer(EventEmitter):
connection_result: Optional[asyncio.Future] connection_result: Optional[asyncio.Future]
disconnection_result: Optional[asyncio.Future] disconnection_result: Optional[asyncio.Future]
open_result: Optional[asyncio.Future] open_result: Optional[asyncio.Future]
acceptor: Optional[Callable[[int], Optional[Tuple[int, int]]]] acceptor: Optional[Callable[[int], bool]]
dlcs: Dict[int, DLC] dlcs: Dict[int, DLC]
def __init__(self, l2cap_channel: l2cap.ClassicChannel, role: Role) -> None: def __init__(self, l2cap_channel: l2cap.ClassicChannel, role: Role) -> None:
@@ -776,15 +683,11 @@ class Multiplexer(EventEmitter):
self.connection_result = None self.connection_result = None
self.disconnection_result = None self.disconnection_result = None
self.open_result = None self.open_result = None
self.open_pn: Optional[RFCOMM_MCC_PN] = None
self.open_rx_max_credits = 0
self.acceptor = None self.acceptor = None
# Become a sink for the L2CAP channel # Become a sink for the L2CAP channel
l2cap_channel.sink = self.on_pdu l2cap_channel.sink = self.on_pdu
l2cap_channel.on('close', self.on_l2cap_channel_close)
def change_state(self, new_state: State) -> None: def change_state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}') logger.debug(f'{self} state change -> {color(new_state.name, "cyan")}')
self.state = new_state self.state = new_state
@@ -801,7 +704,7 @@ class Multiplexer(EventEmitter):
if frame.dlci == 0: if frame.dlci == 0:
self.on_frame(frame) self.on_frame(frame)
else: else:
if frame.type == FrameType.DM: if frame.type == RFCOMM_DM_FRAME:
# DM responses are for a DLCI, but since we only create the dlc when we # DM responses are for a DLCI, but since we only create the dlc when we
# receive a PN response (because we need the parameters), we handle DM # receive a PN response (because we need the parameters), we handle DM
# frames at the Multiplexer level # frames at the Multiplexer level
@@ -814,7 +717,7 @@ class Multiplexer(EventEmitter):
dlc.on_frame(frame) dlc.on_frame(frame)
def on_frame(self, frame: RFCOMM_Frame) -> None: def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type.name}_frame'.lower()) handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame) handler(frame)
def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None: def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
@@ -848,7 +751,6 @@ class Multiplexer(EventEmitter):
'rfcomm', 'rfcomm',
) )
) )
self.open_result = None
else: else:
logger.warning(f'unexpected state for DM: {self}') logger.warning(f'unexpected state for DM: {self}')
@@ -863,10 +765,10 @@ class Multiplexer(EventEmitter):
def on_uih_frame(self, frame: RFCOMM_Frame) -> None: def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
(mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information) (mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
if mcc_type == MccType.PN: if mcc_type == RFCOMM_MCC_PN_TYPE:
pn = RFCOMM_MCC_PN.from_bytes(value) pn = RFCOMM_MCC_PN.from_bytes(value)
self.on_mcc_pn(c_r, pn) self.on_mcc_pn(c_r, pn)
elif mcc_type == MccType.MSC: elif mcc_type == RFCOMM_MCC_MSC_TYPE:
mcs = RFCOMM_MCC_MSC.from_bytes(value) mcs = RFCOMM_MCC_MSC.from_bytes(value)
self.on_mcc_msc(c_r, mcs) self.on_mcc_msc(c_r, mcs)
@@ -886,16 +788,9 @@ class Multiplexer(EventEmitter):
else: else:
if self.acceptor: if self.acceptor:
channel_number = pn.dlci >> 1 channel_number = pn.dlci >> 1
if dlc_params := self.acceptor(channel_number): if self.acceptor(channel_number):
# Create a new DLC # Create a new DLC
dlc = DLC( dlc = DLC(self, pn.dlci, pn.max_frame_size, pn.window_size)
self,
dlci=pn.dlci,
tx_max_frame_size=pn.max_frame_size,
tx_initial_credits=pn.initial_credits,
rx_max_frame_size=dlc_params[0],
rx_initial_credits=dlc_params[1],
)
self.dlcs[pn.dlci] = dlc self.dlcs[pn.dlci] = dlc
# Re-emit the handshake completion event # Re-emit the handshake completion event
@@ -913,17 +808,8 @@ class Multiplexer(EventEmitter):
# Response # Response
logger.debug(f'>>> PN Response: {pn}') logger.debug(f'>>> PN Response: {pn}')
if self.state == Multiplexer.State.OPENING: if self.state == Multiplexer.State.OPENING:
assert self.open_pn dlc = DLC(self, pn.dlci, pn.max_frame_size, pn.window_size)
dlc = DLC(
self,
dlci=pn.dlci,
tx_max_frame_size=pn.max_frame_size,
tx_initial_credits=pn.initial_credits,
rx_max_frame_size=self.open_pn.max_frame_size,
rx_initial_credits=self.open_pn.initial_credits,
)
self.dlcs[pn.dlci] = dlc self.dlcs[pn.dlci] = dlc
self.open_pn = None
dlc.connect() dlc.connect()
else: else:
logger.warning('ignoring PN response') logger.warning('ignoring PN response')
@@ -957,31 +843,24 @@ class Multiplexer(EventEmitter):
) )
await self.disconnection_result await self.disconnection_result
async def open_dlc( async def open_dlc(self, channel: int) -> DLC:
self,
channel: int,
max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
initial_credits: int = RFCOMM_DEFAULT_INITIAL_CREDITS,
) -> DLC:
if self.state != Multiplexer.State.CONNECTED: if self.state != Multiplexer.State.CONNECTED:
if self.state == Multiplexer.State.OPENING: if self.state == Multiplexer.State.OPENING:
raise InvalidStateError('open already in progress') raise InvalidStateError('open already in progress')
raise InvalidStateError('not connected') raise InvalidStateError('not connected')
self.open_pn = RFCOMM_MCC_PN( pn = RFCOMM_MCC_PN(
dlci=channel << 1, dlci=channel << 1,
cl=0xF0, cl=0xF0,
priority=7, priority=7,
ack_timer=0, ack_timer=0,
max_frame_size=max_frame_size, max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions=0, max_retransmissions=0,
initial_credits=initial_credits, window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
) )
mcc = RFCOMM_Frame.make_mcc( mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
mcc_type=MccType.PN, c_r=1, data=bytes(self.open_pn) logger.debug(f'>>> Sending MCC: {pn}')
)
logger.debug(f'>>> Sending MCC: {self.open_pn}')
self.open_result = asyncio.get_running_loop().create_future() self.open_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.State.OPENING) self.change_state(Multiplexer.State.OPENING)
self.send_frame( self.send_frame(
@@ -991,31 +870,15 @@ class Multiplexer(EventEmitter):
information=mcc, information=mcc,
) )
) )
return await self.open_result result = await self.open_result
self.open_result = None
return result
def on_dlc_open_complete(self, dlc: DLC) -> None: def on_dlc_open_complete(self, dlc: DLC) -> None:
logger.debug(f'DLC [{dlc.dlci}] open complete') logger.debug(f'DLC [{dlc.dlci}] open complete')
self.change_state(Multiplexer.State.CONNECTED) self.change_state(Multiplexer.State.CONNECTED)
if self.open_result: if self.open_result:
self.open_result.set_result(dlc) self.open_result.set_result(dlc)
self.open_result = None
def on_dlc_disconnection(self, dlc: DLC) -> None:
logger.debug(f'DLC [{dlc.dlci}] disconnection')
self.dlcs.pop(dlc.dlci, None)
def on_l2cap_channel_close(self) -> None:
logger.debug('L2CAP channel closed, cleaning up')
if self.open_result:
self.open_result.cancel()
self.open_result = None
if self.disconnection_result:
self.disconnection_result.cancel()
self.disconnection_result = None
for dlc in self.dlcs.values():
dlc.abort()
def __str__(self) -> str: def __str__(self) -> str:
return f'Multiplexer(state={self.state.name})' return f'Multiplexer(state={self.state.name})'
@@ -1026,11 +889,9 @@ class Client:
multiplexer: Optional[Multiplexer] multiplexer: Optional[Multiplexer]
l2cap_channel: Optional[l2cap.ClassicChannel] l2cap_channel: Optional[l2cap.ClassicChannel]
def __init__( def __init__(self, device: Device, connection: Connection) -> None:
self, connection: Connection, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU self.device = device
) -> None:
self.connection = connection self.connection = connection
self.l2cap_mtu = l2cap_mtu
self.l2cap_channel = None self.l2cap_channel = None
self.multiplexer = None self.multiplexer = None
@@ -1038,14 +899,14 @@ class Client:
# Create a new L2CAP connection # Create a new L2CAP connection
try: try:
self.l2cap_channel = await self.connection.create_l2cap_channel( self.l2cap_channel = await self.connection.create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=self.l2cap_mtu) spec=l2cap.ClassicChannelSpec(RFCOMM_PSM)
) )
except ProtocolError as error: except ProtocolError as error:
logger.warning(f'L2CAP connection failed: {error}') logger.warning(f'L2CAP connection failed: {error}')
raise raise
assert self.l2cap_channel is not None assert self.l2cap_channel is not None
# Create a multiplexer to manage DLCs with the server # Create a mutliplexer to manage DLCs with the server
self.multiplexer = Multiplexer(self.l2cap_channel, Multiplexer.Role.INITIATOR) self.multiplexer = Multiplexer(self.l2cap_channel, Multiplexer.Role.INITIATOR)
# Connect the multiplexer # Connect the multiplexer
@@ -1061,40 +922,25 @@ class Client:
self.multiplexer = None self.multiplexer = None
# Close the L2CAP channel # Close the L2CAP channel
if self.l2cap_channel: # TODO
await self.l2cap_channel.disconnect()
self.l2cap_channel = None
async def __aenter__(self) -> Multiplexer:
return await self.start()
async def __aexit__(self, *args) -> None:
await self.shutdown()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Server(EventEmitter): class Server(EventEmitter):
def __init__( acceptors: Dict[int, Callable[[DLC], None]]
self, device: Device, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
) -> None: def __init__(self, device: Device) -> None:
super().__init__() super().__init__()
self.device = device self.device = device
self.acceptors: Dict[int, Callable[[DLC], None]] = {} self.multiplexer = None
self.dlc_configs: Dict[int, Tuple[int, int]] = {} self.acceptors = {}
# Register ourselves with the L2CAP channel manager # Register ourselves with the L2CAP channel manager
self.l2cap_server = device.create_l2cap_server( device.create_l2cap_server(
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=l2cap_mtu), spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM), handler=self.on_connection
handler=self.on_connection,
) )
def listen( def listen(self, acceptor: Callable[[DLC], None], channel: int = 0) -> int:
self,
acceptor: Callable[[DLC], None],
channel: int = 0,
max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
initial_credits: int = RFCOMM_DEFAULT_INITIAL_CREDITS,
) -> int:
if channel: if channel:
if channel in self.acceptors: if channel in self.acceptors:
# Busy # Busy
@@ -1114,8 +960,6 @@ class Server(EventEmitter):
return 0 return 0
self.acceptors[channel] = acceptor self.acceptors[channel] = acceptor
self.dlc_configs[channel] = (max_frame_size, initial_credits)
return channel return channel
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None: def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
@@ -1133,18 +977,13 @@ class Server(EventEmitter):
# Notify # Notify
self.emit('start', multiplexer) self.emit('start', multiplexer)
def accept_dlc(self, channel_number: int) -> Optional[Tuple[int, int]]: def accept_dlc(self, channel_number: int) -> bool:
return self.dlc_configs.get(channel_number) return channel_number in self.acceptors
def on_dlc(self, dlc: DLC) -> None: def on_dlc(self, dlc: DLC) -> None:
logger.debug(f'@@@ new DLC connected: {dlc}') logger.debug(f'@@@ new DLC connected: {dlc}')
# Let the acceptor know # Let the acceptor know
if acceptor := self.acceptors.get(dlc.dlci >> 1): acceptor = self.acceptors.get(dlc.dlci >> 1)
if acceptor:
acceptor(dlc) acceptor(dlc)
def __enter__(self) -> Self:
return self
def __exit__(self, *args) -> None:
self.l2cap_server.close()
+15 -29
View File
@@ -19,7 +19,6 @@ from __future__ import annotations
import logging import logging
import struct import struct
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
from typing_extensions import Self
from . import core, l2cap from . import core, l2cap
from .colors import color from .colors import color
@@ -98,8 +97,7 @@ SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID = 0X000B
SDP_ICON_URL_ATTRIBUTE_ID = 0X000C SDP_ICON_URL_ATTRIBUTE_ID = 0X000C
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D
# Attribute Identifier (cf. Assigned Numbers for Service Discovery)
# Profile-specific Attribute Identifiers (cf. Assigned Numbers for Service Discovery)
# used by AVRCP, HFP and A2DP # used by AVRCP, HFP and A2DP
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311 SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311
@@ -117,8 +115,7 @@ SDP_ATTRIBUTE_ID_NAMES = {
SDP_DOCUMENTATION_URL_ATTRIBUTE_ID: 'SDP_DOCUMENTATION_URL_ATTRIBUTE_ID', SDP_DOCUMENTATION_URL_ATTRIBUTE_ID: 'SDP_DOCUMENTATION_URL_ATTRIBUTE_ID',
SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID: 'SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID', SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID: 'SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID',
SDP_ICON_URL_ATTRIBUTE_ID: 'SDP_ICON_URL_ATTRIBUTE_ID', SDP_ICON_URL_ATTRIBUTE_ID: 'SDP_ICON_URL_ATTRIBUTE_ID',
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID', SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID'
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID: 'SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID',
} }
SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot') SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
@@ -763,13 +760,13 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
class Client: class Client:
channel: Optional[l2cap.ClassicChannel] channel: Optional[l2cap.ClassicChannel]
def __init__(self, connection: Connection) -> None: def __init__(self, device: Device) -> None:
self.connection = connection self.device = device
self.pending_request = None self.pending_request = None
self.channel = None self.channel = None
async def connect(self) -> None: async def connect(self, connection: Connection) -> None:
self.channel = await self.connection.create_l2cap_channel( self.channel = await connection.create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(SDP_PSM) spec=l2cap.ClassicChannelSpec(SDP_PSM)
) )
@@ -825,13 +822,11 @@ class Client:
) )
attribute_id_list = DataElement.sequence( attribute_id_list = DataElement.sequence(
[ [
( DataElement.unsigned_integer(
DataElement.unsigned_integer( attribute_id[0], value_size=attribute_id[1]
attribute_id[0], value_size=attribute_id[1]
)
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
) )
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids for attribute_id in attribute_ids
] ]
) )
@@ -883,13 +878,11 @@ class Client:
attribute_id_list = DataElement.sequence( attribute_id_list = DataElement.sequence(
[ [
( DataElement.unsigned_integer(
DataElement.unsigned_integer( attribute_id[0], value_size=attribute_id[1]
attribute_id[0], value_size=attribute_id[1]
)
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
) )
if isinstance(attribute_id, tuple)
else DataElement.unsigned_integer_16(attribute_id)
for attribute_id in attribute_ids for attribute_id in attribute_ids
] ]
) )
@@ -925,13 +918,6 @@ class Client:
return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value) return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
async def __aenter__(self) -> Self:
await self.connect()
return self
async def __aexit__(self, *args) -> None:
await self.disconnect()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Server: class Server:
@@ -997,7 +983,7 @@ class Server:
try: try:
handler(sdp_pdu) handler(sdp_pdu)
except Exception as error: except Exception as error:
logger.exception(f'{color("!!! Exception in handler:", "red")} {error}') logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
self.send_response( self.send_response(
SDP_ErrorResponse( SDP_ErrorResponse(
transaction_id=sdp_pdu.transaction_id, transaction_id=sdp_pdu.transaction_id,
+76 -222
View File
@@ -27,7 +27,6 @@ import logging
import asyncio import asyncio
import enum import enum
import secrets import secrets
from dataclasses import dataclass
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@@ -54,7 +53,6 @@ from .core import (
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
BT_CENTRAL_ROLE, BT_CENTRAL_ROLE,
BT_LE_TRANSPORT, BT_LE_TRANSPORT,
AdvertisingData,
ProtocolError, ProtocolError,
name_or_number, name_or_number,
) )
@@ -187,8 +185,8 @@ SMP_KEYPRESS_AUTHREQ = 0b00010000
SMP_CT2_AUTHREQ = 0b00100000 SMP_CT2_AUTHREQ = 0b00100000
# Crypto salt # Crypto salt
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('000000000000000000000000746D7031') SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('00000000000000000000000000000000746D7031')
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032') SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032')
# fmt: on # fmt: on
# pylint: enable=line-too-long # pylint: enable=line-too-long
@@ -565,54 +563,6 @@ class PairingMethod(enum.IntEnum):
CTKD_OVER_CLASSIC = 4 CTKD_OVER_CLASSIC = 4
# -----------------------------------------------------------------------------
class OobContext:
"""Cryptographic context for LE SC OOB pairing."""
ecc_key: crypto.EccKey
r: bytes
def __init__(
self, ecc_key: Optional[crypto.EccKey] = None, r: Optional[bytes] = None
) -> None:
self.ecc_key = crypto.EccKey.generate() if ecc_key is None else ecc_key
self.r = crypto.r() if r is None else r
def share(self) -> OobSharedData:
pkx = self.ecc_key.x[::-1]
return OobSharedData(c=crypto.f4(pkx, pkx, self.r, bytes(1)), r=self.r)
# -----------------------------------------------------------------------------
class OobLegacyContext:
"""Cryptographic context for LE Legacy OOB pairing."""
tk: bytes
def __init__(self, tk: Optional[bytes] = None) -> None:
self.tk = crypto.r() if tk is None else tk
# -----------------------------------------------------------------------------
@dataclass
class OobSharedData:
"""Shareable data for LE SC OOB pairing."""
c: bytes
r: bytes
def to_ad(self) -> AdvertisingData:
return AdvertisingData(
[
(AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE, self.c),
(AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE, self.r),
]
)
def __str__(self) -> str:
return f'OOB(C={self.c.hex()}, R={self.r.hex()})'
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Session: class Session:
# I/O Capability to pairing method decision matrix # I/O Capability to pairing method decision matrix
@@ -677,13 +627,6 @@ class Session:
}, },
} }
ea: bytes
eb: bytes
ltk: bytes
preq: bytes
pres: bytes
tk: bytes
def __init__( def __init__(
self, self,
manager: Manager, manager: Manager,
@@ -693,10 +636,17 @@ class Session:
) -> None: ) -> None:
self.manager = manager self.manager = manager
self.connection = connection self.connection = connection
self.preq: Optional[bytes] = None
self.pres: Optional[bytes] = None
self.ea = None
self.eb = None
self.tk = bytes(16)
self.r = bytes(16)
self.stk = None self.stk = None
self.ltk = None
self.ltk_ediv = 0 self.ltk_ediv = 0
self.ltk_rand = bytes(8) self.ltk_rand = bytes(8)
self.link_key: Optional[bytes] = None self.link_key = None
self.initiator_key_distribution: int = 0 self.initiator_key_distribution: int = 0
self.responder_key_distribution: int = 0 self.responder_key_distribution: int = 0
self.peer_random_value: Optional[bytes] = None self.peer_random_value: Optional[bytes] = None
@@ -709,7 +659,7 @@ class Session:
self.peer_bd_addr: Optional[Address] = None self.peer_bd_addr: Optional[Address] = None
self.peer_signature_key = None self.peer_signature_key = None
self.peer_expected_distributions: List[Type[SMP_Command]] = [] self.peer_expected_distributions: List[Type[SMP_Command]] = []
self.dh_key = b'' self.dh_key = None
self.confirm_value = None self.confirm_value = None
self.passkey: Optional[int] = None self.passkey: Optional[int] = None
self.passkey_ready = asyncio.Event() self.passkey_ready = asyncio.Event()
@@ -737,9 +687,9 @@ class Session:
# Create a future that can be used to wait for the session to complete # Create a future that can be used to wait for the session to complete
if self.is_initiator: if self.is_initiator:
self.pairing_result: Optional[asyncio.Future[None]] = ( self.pairing_result: Optional[
asyncio.get_running_loop().create_future() asyncio.Future[None]
) ] = asyncio.get_running_loop().create_future()
else: else:
self.pairing_result = None self.pairing_result = None
@@ -762,8 +712,8 @@ class Session:
self.io_capability = pairing_config.delegate.io_capability self.io_capability = pairing_config.delegate.io_capability
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
# OOB # OOB (not supported yet)
self.oob_data_flag = 0 if pairing_config.oob is None else 1 self.oob = False
# Set up addresses # Set up addresses
self_address = connection.self_address self_address = connection.self_address
@@ -779,35 +729,9 @@ class Session:
self.ia = bytes(peer_address) self.ia = bytes(peer_address)
self.iat = 1 if peer_address.is_random else 0 self.iat = 1 if peer_address.is_random else 0
# Select the ECC key, TK and r initial value
if pairing_config.oob:
self.peer_oob_data = pairing_config.oob.peer_data
if pairing_config.sc:
if pairing_config.oob.our_context is None:
raise ValueError(
"oob pairing config requires a context when sc is True"
)
self.r = pairing_config.oob.our_context.r
self.ecc_key = pairing_config.oob.our_context.ecc_key
if pairing_config.oob.legacy_context is not None:
self.tk = pairing_config.oob.legacy_context.tk
else:
if pairing_config.oob.legacy_context is None:
raise ValueError(
"oob pairing config requires a legacy context when sc is False"
)
self.r = bytes(16)
self.ecc_key = manager.ecc_key
self.tk = pairing_config.oob.legacy_context.tk
else:
self.peer_oob_data = None
self.r = bytes(16)
self.ecc_key = manager.ecc_key
self.tk = bytes(16)
@property @property
def pkx(self) -> Tuple[bytes, bytes]: def pkx(self) -> Tuple[bytes, bytes]:
return (self.ecc_key.x[::-1], self.peer_public_key_x) return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x)
@property @property
def pka(self) -> bytes: def pka(self) -> bytes:
@@ -844,10 +768,7 @@ class Session:
return None return None
def decide_pairing_method( def decide_pairing_method(
self, self, auth_req: int, initiator_io_capability: int, responder_io_capability: int
auth_req: int,
initiator_io_capability: int,
responder_io_capability: int,
) -> None: ) -> None:
if self.connection.transport == BT_BR_EDR_TRANSPORT: if self.connection.transport == BT_BR_EDR_TRANSPORT:
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
@@ -988,7 +909,7 @@ class Session:
command = SMP_Pairing_Request_Command( command = SMP_Pairing_Request_Command(
io_capability=self.io_capability, io_capability=self.io_capability,
oob_data_flag=self.oob_data_flag, oob_data_flag=0,
auth_req=self.auth_req, auth_req=self.auth_req,
maximum_encryption_key_size=16, maximum_encryption_key_size=16,
initiator_key_distribution=self.initiator_key_distribution, initiator_key_distribution=self.initiator_key_distribution,
@@ -1000,7 +921,7 @@ class Session:
def send_pairing_response_command(self) -> None: def send_pairing_response_command(self) -> None:
response = SMP_Pairing_Response_Command( response = SMP_Pairing_Response_Command(
io_capability=self.io_capability, io_capability=self.io_capability,
oob_data_flag=self.oob_data_flag, oob_data_flag=0,
auth_req=self.auth_req, auth_req=self.auth_req,
maximum_encryption_key_size=16, maximum_encryption_key_size=16,
initiator_key_distribution=self.initiator_key_distribution, initiator_key_distribution=self.initiator_key_distribution,
@@ -1061,8 +982,8 @@ class Session:
def send_public_key_command(self) -> None: def send_public_key_command(self) -> None:
self.send_command( self.send_command(
SMP_Pairing_Public_Key_Command( SMP_Pairing_Public_Key_Command(
public_key_x=self.ecc_key.x[::-1], public_key_x=bytes(reversed(self.manager.ecc_key.x)),
public_key_y=self.ecc_key.y[::-1], public_key_y=bytes(reversed(self.manager.ecc_key.y)),
) )
) )
@@ -1090,7 +1011,7 @@ class Session:
# We can now encrypt the connection with the short term key, so that we can # We can now encrypt the connection with the short term key, so that we can
# distribute the long term and/or other keys over an encrypted connection # distribute the long term and/or other keys over an encrypted connection
self.manager.device.host.send_command_sync( self.manager.device.host.send_command_sync(
HCI_LE_Enable_Encryption_Command( HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg]
connection_handle=self.connection.handle, connection_handle=self.connection.handle,
random_number=bytes(8), random_number=bytes(8),
encrypted_diversifier=0, encrypted_diversifier=0,
@@ -1098,56 +1019,18 @@ class Session:
) )
) )
@classmethod async def derive_ltk(self) -> None:
def derive_ltk(cls, link_key: bytes, ct2: bool) -> bytes: link_key = await self.manager.device.get_link_key(self.connection.peer_address)
'''Derives Long Term Key from Link Key. assert link_key is not None
Args:
link_key: BR/EDR Link Key bytes in little-endian.
ct2: whether ct2 is supported on both devices.
Returns:
LE Long Tern Key bytes in little-endian.
'''
ilk = ( ilk = (
crypto.h7(salt=SMP_CTKD_H7_BRLE_SALT, w=link_key) crypto.h7(salt=SMP_CTKD_H7_BRLE_SALT, w=link_key)
if ct2 if self.ct2
else crypto.h6(link_key, b'tmp2') else crypto.h6(link_key, b'tmp2')
) )
return crypto.h6(ilk, b'brle') self.ltk = crypto.h6(ilk, b'brle')
@classmethod
def derive_link_key(cls, ltk: bytes, ct2: bool) -> bytes:
'''Derives Link Key from Long Term Key.
Args:
ltk: LE Long Term Key bytes in little-endian.
ct2: whether ct2 is supported on both devices.
Returns:
BR/EDR Link Key bytes in little-endian.
'''
ilk = (
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=ltk)
if ct2
else crypto.h6(ltk, b'tmp1')
)
return crypto.h6(ilk, b'lebr')
async def get_link_key_and_derive_ltk(self) -> None:
'''Retrieves BR/EDR Link Key from storage and derive it to LE LTK.'''
self.link_key = await self.manager.device.get_link_key(
self.connection.peer_address
)
if self.link_key is None:
logging.warning(
'Try to derive LTK but host does not have the LK. Send a SMP_PAIRING_FAILED but the procedure will not be paused!'
)
self.send_pairing_failed(
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR
)
else:
self.ltk = self.derive_ltk(self.link_key, self.ct2)
def distribute_keys(self) -> None: def distribute_keys(self) -> None:
# Distribute the keys as required # Distribute the keys as required
if self.is_initiator: if self.is_initiator:
# CTKD: Derive LTK from LinkKey # CTKD: Derive LTK from LinkKey
@@ -1156,7 +1039,7 @@ class Session:
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
): ):
self.ctkd_task = self.connection.abort_on( self.ctkd_task = self.connection.abort_on(
'disconnection', self.get_link_key_and_derive_ltk() 'disconnection', self.derive_ltk()
) )
elif not self.sc: elif not self.sc:
# Distribute the LTK, EDIV and RAND # Distribute the LTK, EDIV and RAND
@@ -1186,7 +1069,12 @@ class Session:
# CTKD, calculate BR/EDR link key # CTKD, calculate BR/EDR link key
if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG: if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
self.link_key = self.derive_link_key(self.ltk, self.ct2) ilk = (
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=self.ltk)
if self.ct2
else crypto.h6(self.ltk, b'tmp1')
)
self.link_key = crypto.h6(ilk, b'lebr')
else: else:
# CTKD: Derive LTK from LinkKey # CTKD: Derive LTK from LinkKey
@@ -1195,7 +1083,7 @@ class Session:
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
): ):
self.ctkd_task = self.connection.abort_on( self.ctkd_task = self.connection.abort_on(
'disconnection', self.get_link_key_and_derive_ltk() 'disconnection', self.derive_ltk()
) )
# Distribute the LTK, EDIV and RAND # Distribute the LTK, EDIV and RAND
elif not self.sc: elif not self.sc:
@@ -1225,7 +1113,12 @@ class Session:
# CTKD, calculate BR/EDR link key # CTKD, calculate BR/EDR link key
if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG: if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
self.link_key = self.derive_link_key(self.ltk, self.ct2) ilk = (
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=self.ltk)
if self.ct2
else crypto.h6(self.ltk, b'tmp1')
)
self.link_key = crypto.h6(ilk, b'lebr')
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None: def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
# Set our expectations for what to wait for in the key distribution phase # Set our expectations for what to wait for in the key distribution phase
@@ -1403,7 +1296,7 @@ class Session:
try: try:
handler(command) handler(command)
except Exception as error: except Exception as error:
logger.exception(f'{color("!!! Exception in handler:", "red")} {error}') logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
response = SMP_Pairing_Failed_Command( response = SMP_Pairing_Failed_Command(
reason=SMP_UNSPECIFIED_REASON_ERROR reason=SMP_UNSPECIFIED_REASON_ERROR
) )
@@ -1440,28 +1333,15 @@ class Session:
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0) self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0) self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0)
# Infer the pairing method # Check for OOB
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or ( if command.oob_data_flag != 0:
not self.sc and (self.oob_data_flag != 0 and command.oob_data_flag != 0) self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
): return
# Use OOB
self.pairing_method = PairingMethod.OOB
if not self.sc and self.tk is None:
# For legacy OOB, TK is required.
logger.warning("legacy OOB without TK")
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
if command.oob_data_flag == 0:
# The peer doesn't have OOB data, use r=0
self.r = bytes(16)
else:
# Decide which pairing method to use from the IO capability
self.decide_pairing_method(
command.auth_req,
command.io_capability,
self.io_capability,
)
# Decide which pairing method to use
self.decide_pairing_method(
command.auth_req, command.io_capability, self.io_capability
)
logger.debug(f'pairing method: {self.pairing_method.name}') logger.debug(f'pairing method: {self.pairing_method.name}')
# Key distribution # Key distribution
@@ -1510,26 +1390,15 @@ class Session:
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0) self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0) self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
# Infer the pairing method # Check for OOB
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or ( if self.sc and command.oob_data_flag:
not self.sc and (self.oob_data_flag != 0 and command.oob_data_flag != 0) self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
): return
# Use OOB
self.pairing_method = PairingMethod.OOB
if not self.sc and self.tk is None:
# For legacy OOB, TK is required.
logger.warning("legacy OOB without TK")
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
if command.oob_data_flag == 0:
# The peer doesn't have OOB data, use r=0
self.r = bytes(16)
else:
# Decide which pairing method to use from the IO capability
self.decide_pairing_method(
command.auth_req, self.io_capability, command.io_capability
)
# Decide which pairing method to use
self.decide_pairing_method(
command.auth_req, self.io_capability, command.io_capability
)
logger.debug(f'pairing method: {self.pairing_method.name}') logger.debug(f'pairing method: {self.pairing_method.name}')
# Key distribution # Key distribution
@@ -1680,13 +1549,12 @@ class Session:
if self.passkey_step < 20: if self.passkey_step < 20:
self.send_pairing_confirm_command() self.send_pairing_confirm_command()
return return
elif self.pairing_method != PairingMethod.OOB: else:
return return
else: else:
if self.pairing_method in ( if self.pairing_method in (
PairingMethod.JUST_WORKS, PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON, PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
): ):
self.send_pairing_random_command() self.send_pairing_random_command()
elif self.pairing_method == PairingMethod.PASSKEY: elif self.pairing_method == PairingMethod.PASSKEY:
@@ -1723,7 +1591,6 @@ class Session:
if self.pairing_method in ( if self.pairing_method in (
PairingMethod.JUST_WORKS, PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON, PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
): ):
ra = bytes(16) ra = bytes(16)
rb = ra rb = ra
@@ -1732,6 +1599,7 @@ class Session:
ra = self.passkey.to_bytes(16, byteorder='little') ra = self.passkey.to_bytes(16, byteorder='little')
rb = ra rb = ra
else: else:
# OOB not implemented yet
return return
assert self.preq and self.pres assert self.preq and self.pres
@@ -1783,33 +1651,18 @@ class Session:
self.peer_public_key_y = command.public_key_y self.peer_public_key_y = command.public_key_y
# Compute the DH key # Compute the DH key
self.dh_key = self.ecc_key.dh( self.dh_key = bytes(
command.public_key_x[::-1], reversed(
command.public_key_y[::-1], self.manager.ecc_key.dh(
)[::-1] bytes(reversed(command.public_key_x)),
bytes(reversed(command.public_key_y)),
)
)
)
logger.debug(f'DH key: {self.dh_key.hex()}') logger.debug(f'DH key: {self.dh_key.hex()}')
if self.pairing_method == PairingMethod.OOB:
# Check against shared OOB data
if self.peer_oob_data:
confirm_verifier = crypto.f4(
self.peer_public_key_x,
self.peer_public_key_x,
self.peer_oob_data.r,
bytes(1),
)
if not self.check_expected_value(
self.peer_oob_data.c,
confirm_verifier,
SMP_CONFIRM_VALUE_FAILED_ERROR,
):
return
if self.is_initiator: if self.is_initiator:
if self.pairing_method == PairingMethod.OOB: self.send_pairing_confirm_command()
self.send_pairing_random_command()
else:
self.send_pairing_confirm_command()
else: else:
if self.pairing_method == PairingMethod.PASSKEY: if self.pairing_method == PairingMethod.PASSKEY:
self.display_or_input_passkey() self.display_or_input_passkey()
@@ -1820,7 +1673,6 @@ class Session:
if self.pairing_method in ( if self.pairing_method in (
PairingMethod.JUST_WORKS, PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON, PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
): ):
# We can now send the confirmation value # We can now send the confirmation value
self.send_pairing_confirm_command() self.send_pairing_confirm_command()
@@ -1849,6 +1701,7 @@ class Session:
else: else:
self.send_pairing_dhkey_check_command() self.send_pairing_dhkey_check_command()
else: else:
assert self.ltk
self.start_encryption(self.ltk) self.start_encryption(self.ltk)
def on_smp_pairing_failed_command( def on_smp_pairing_failed_command(
@@ -1898,7 +1751,6 @@ class Manager(EventEmitter):
sessions: Dict[int, Session] sessions: Dict[int, Session]
pairing_config_factory: Callable[[Connection], PairingConfig] pairing_config_factory: Callable[[Connection], PairingConfig]
session_proxy: Type[Session] session_proxy: Type[Session]
_ecc_key: Optional[crypto.EccKey]
def __init__( def __init__(
self, self,
@@ -1993,8 +1845,10 @@ class Manager(EventEmitter):
) -> None: ) -> None:
# Store the keys in the key store # Store the keys in the key store
if self.device.keystore and identity_address is not None: if self.device.keystore and identity_address is not None:
# Make sure on_pairing emits after key update. self.device.abort_on(
await self.device.update_keys(str(identity_address), keys) 'flush', self.device.update_keys(str(identity_address), keys)
)
# Notify the device # Notify the device
self.device.on_pairing(session.connection, identity_address, keys, session.sc) self.device.on_pairing(session.connection, identity_address, keys, session.sc)
+22 -50
View File
@@ -18,7 +18,6 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import logging import logging
import os import os
from typing import Optional
from .common import Transport, AsyncPipeSink, SnoopingTransport from .common import Transport, AsyncPipeSink, SnoopingTransport
from ..snoop import create_snooper from ..snoop import create_snooper
@@ -53,16 +52,8 @@ def _wrap_transport(transport: Transport) -> Transport:
async def open_transport(name: str) -> Transport: async def open_transport(name: str) -> Transport:
""" """
Open a transport by name. Open a transport by name.
The name must be <type>:<metadata><parameters> The name must be <type>:<parameters>
Where <parameters> depend on the type (and may be empty for some types), and Where <parameters> depend on the type (and may be empty for some types).
<metadata> is either omitted, or a ,-separated list of <key>=<value> pairs,
enclosed in [].
If there are not metadata or parameter, the : after the <type> may be omitted.
Examples:
* usb:0
* usb:[driver=rtk]0
* android-netsim
The supported types are: The supported types are:
* serial * serial
* udp * udp
@@ -80,105 +71,87 @@ async def open_transport(name: str) -> Transport:
* android-netsim * android-netsim
""" """
scheme, *tail = name.split(':', 1) return _wrap_transport(await _open_transport(name))
spec = tail[0] if tail else None
metadata = None
if spec:
# Metadata may precede the spec
if spec.startswith('['):
metadata_str, *tail = spec[1:].split(']')
spec = tail[0] if tail else None
metadata = dict([entry.split('=') for entry in metadata_str.split(',')])
transport = await _open_transport(scheme, spec)
if metadata:
transport.source.metadata = { # type: ignore[attr-defined]
**metadata,
**getattr(transport.source, 'metadata', {}),
}
# pylint: disable=line-too-long
logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined]
return _wrap_transport(transport)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def _open_transport(scheme: str, spec: Optional[str]) -> Transport: async def _open_transport(name: str) -> Transport:
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
scheme, *spec = name.split(':', 1)
if scheme == 'serial' and spec: if scheme == 'serial' and spec:
from .serial import open_serial_transport from .serial import open_serial_transport
return await open_serial_transport(spec) return await open_serial_transport(spec[0])
if scheme == 'udp' and spec: if scheme == 'udp' and spec:
from .udp import open_udp_transport from .udp import open_udp_transport
return await open_udp_transport(spec) return await open_udp_transport(spec[0])
if scheme == 'tcp-client' and spec: if scheme == 'tcp-client' and spec:
from .tcp_client import open_tcp_client_transport from .tcp_client import open_tcp_client_transport
return await open_tcp_client_transport(spec) return await open_tcp_client_transport(spec[0])
if scheme == 'tcp-server' and spec: if scheme == 'tcp-server' and spec:
from .tcp_server import open_tcp_server_transport from .tcp_server import open_tcp_server_transport
return await open_tcp_server_transport(spec) return await open_tcp_server_transport(spec[0])
if scheme == 'ws-client' and spec: if scheme == 'ws-client' and spec:
from .ws_client import open_ws_client_transport from .ws_client import open_ws_client_transport
return await open_ws_client_transport(spec) return await open_ws_client_transport(spec[0])
if scheme == 'ws-server' and spec: if scheme == 'ws-server' and spec:
from .ws_server import open_ws_server_transport from .ws_server import open_ws_server_transport
return await open_ws_server_transport(spec) return await open_ws_server_transport(spec[0])
if scheme == 'pty': if scheme == 'pty':
from .pty import open_pty_transport from .pty import open_pty_transport
return await open_pty_transport(spec) return await open_pty_transport(spec[0] if spec else None)
if scheme == 'file': if scheme == 'file':
from .file import open_file_transport from .file import open_file_transport
assert spec is not None assert spec is not None
return await open_file_transport(spec) return await open_file_transport(spec[0])
if scheme == 'vhci': if scheme == 'vhci':
from .vhci import open_vhci_transport from .vhci import open_vhci_transport
return await open_vhci_transport(spec) return await open_vhci_transport(spec[0] if spec else None)
if scheme == 'hci-socket': if scheme == 'hci-socket':
from .hci_socket import open_hci_socket_transport from .hci_socket import open_hci_socket_transport
return await open_hci_socket_transport(spec) return await open_hci_socket_transport(spec[0] if spec else None)
if scheme == 'usb': if scheme == 'usb':
from .usb import open_usb_transport from .usb import open_usb_transport
assert spec assert spec is not None
return await open_usb_transport(spec) return await open_usb_transport(spec[0])
if scheme == 'pyusb': if scheme == 'pyusb':
from .pyusb import open_pyusb_transport from .pyusb import open_pyusb_transport
assert spec assert spec is not None
return await open_pyusb_transport(spec) return await open_pyusb_transport(spec[0])
if scheme == 'android-emulator': if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport from .android_emulator import open_android_emulator_transport
return await open_android_emulator_transport(spec) return await open_android_emulator_transport(spec[0] if spec else None)
if scheme == 'android-netsim': if scheme == 'android-netsim':
from .android_netsim import open_android_netsim_transport from .android_netsim import open_android_netsim_transport
return await open_android_netsim_transport(spec) return await open_android_netsim_transport(spec[0] if spec else None)
raise ValueError('unknown transport scheme') raise ValueError('unknown transport scheme')
@@ -197,13 +170,12 @@ async def open_transport_or_link(name: str) -> Transport:
""" """
if name.startswith('link-relay:'): if name.startswith('link-relay:'):
logger.warning('Link Relay has been deprecated.')
from ..controller import Controller from ..controller import Controller
from ..link import RemoteLink # lazy import from ..link import RemoteLink # lazy import
link = RemoteLink(name[11:]) link = RemoteLink(name[11:])
await link.wait_until_connected() await link.wait_until_connected()
controller = Controller('remote', link=link) # type:ignore[arg-type] controller = Controller('remote', link=link)
class LinkTransport(Transport): class LinkTransport(Transport):
async def close(self): async def close(self):
+1 -1
View File
@@ -69,7 +69,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
mode = 'host' mode = 'host'
server_host = 'localhost' server_host = 'localhost'
server_port = '8554' server_port = '8554'
if spec: if spec is not None:
params = spec.split(',') params = spec.split(',')
for param in params: for param in params:
if param.startswith('mode='): if param.startswith('mode='):
+6 -11
View File
@@ -21,7 +21,7 @@ import struct
import asyncio import asyncio
import logging import logging
import io import io
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict from typing import ContextManager, Tuple, Optional, Protocol, Dict
from bumble import hci from bumble import hci
from bumble.colors import color from bumble.colors import color
@@ -42,7 +42,6 @@ HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'), hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'), hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
hci.HCI_EVENT_PACKET: (1, 1, 'B'), hci.HCI_EVENT_PACKET: (1, 1, 'B'),
hci.HCI_ISO_DATA_PACKET: (2, 2, 'H'),
} }
@@ -59,13 +58,15 @@ class TransportLostError(Exception):
# Typing Protocols # Typing Protocols
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class TransportSink(Protocol): class TransportSink(Protocol):
def on_packet(self, packet: bytes) -> None: ... def on_packet(self, packet: bytes) -> None:
...
class TransportSource(Protocol): class TransportSource(Protocol):
terminated: asyncio.Future[None] terminated: asyncio.Future[None]
def set_packet_sink(self, sink: TransportSink) -> None: ... def set_packet_sink(self, sink: TransportSink) -> None:
...
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -149,7 +150,7 @@ class PacketParser:
try: try:
self.sink.on_packet(bytes(self.packet)) self.sink.on_packet(bytes(self.packet))
except Exception as error: except Exception as error:
logger.exception( logger.warning(
color(f'!!! Exception in on_packet: {error}', 'red') color(f'!!! Exception in on_packet: {error}', 'red')
) )
self.reset() self.reset()
@@ -166,13 +167,11 @@ class PacketReader:
def __init__(self, source: io.BufferedReader) -> None: def __init__(self, source: io.BufferedReader) -> None:
self.source = source self.source = source
self.at_end = False
def next_packet(self) -> Optional[bytes]: def next_packet(self) -> Optional[bytes]:
# Get the packet type # Get the packet type
packet_type = self.source.read(1) packet_type = self.source.read(1)
if len(packet_type) != 1: if len(packet_type) != 1:
self.at_end = True
return None return None
# Get the packet info based on its type # Get the packet info based on its type
@@ -425,10 +424,6 @@ class SnoopingTransport(Transport):
class Source: class Source:
sink: TransportSink sink: TransportSink
@property
def metadata(self) -> dict[str, Any]:
return getattr(self.source, 'metadata', {})
def __init__(self, source: TransportSource, snooper: Snooper): def __init__(self, source: TransportSource, snooper: Snooper):
self.source = source self.source = source
self.snooper = snooper self.snooper = snooper
+4 -1
View File
@@ -59,7 +59,10 @@ async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
) from error ) from error
# Compute the adapter index # Compute the adapter index
adapter_index = int(spec) if spec else 0 if spec is None:
adapter_index = 0
else:
adapter_index = int(spec)
# Bind the socket # Bind the socket
# NOTE: since Python doesn't support binding with the required address format (yet), # NOTE: since Python doesn't support binding with the required address format (yet),
+2 -110
View File
@@ -23,24 +23,11 @@ import time
import usb.core import usb.core
import usb.util import usb.util
from typing import Optional
from usb.core import Device as UsbDevice
from usb.core import USBError
from usb.util import CTRL_TYPE_CLASS, CTRL_RECIPIENT_OTHER
from usb.legacy import REQ_SET_FEATURE, REQ_CLEAR_FEATURE, CLASS_HUB
from .common import Transport, ParserSource from .common import Transport, ParserSource
from .. import hci from .. import hci
from ..colors import color from ..colors import color
# -----------------------------------------------------------------------------
# Constant
# -----------------------------------------------------------------------------
USB_PORT_FEATURE_POWER = 8
POWER_CYCLE_DELAY = 1
RESET_DELAY = 3
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -126,10 +113,9 @@ async def open_pyusb_transport(spec: str) -> Transport:
self.loop.call_soon_threadsafe(self.stop_event.set) self.loop.call_soon_threadsafe(self.stop_event.set)
class UsbPacketSource(asyncio.Protocol, ParserSource): class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, device, metadata, sco_enabled): def __init__(self, device, sco_enabled):
super().__init__() super().__init__()
self.device = device self.device = device
self.metadata = metadata
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
self.dequeue_task = None self.dequeue_task = None
@@ -227,22 +213,9 @@ async def open_pyusb_transport(spec: str) -> Transport:
usb_find = libusb_package.find usb_find = libusb_package.find
# Find the device according to the spec moniker # Find the device according to the spec moniker
power_cycle = False
if spec.startswith('!'):
power_cycle = True
spec = spec[1:]
if ':' in spec: if ':' in spec:
vendor_id, product_id = spec.split(':') vendor_id, product_id = spec.split(':')
device = usb_find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16)) device = usb_find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16))
elif '-' in spec:
def device_path(device):
if device.port_numbers:
return f'{device.bus}-{".".join(map(str, device.port_numbers))}'
else:
return str(device.bus)
device = usb_find(custom_match=lambda device: device_path(device) == spec)
else: else:
device_index = int(spec) device_index = int(spec)
devices = list( devices = list(
@@ -262,17 +235,6 @@ async def open_pyusb_transport(spec: str) -> Transport:
raise ValueError('device not found') raise ValueError('device not found')
logger.debug(f'USB Device: {device}') logger.debug(f'USB Device: {device}')
# Power Cycle the device
if power_cycle:
try:
device = await _power_cycle(device) # type: ignore
except Exception as e:
logging.debug(e)
logging.info(f"Unable to power cycle {hex(device.idVendor)} {hex(device.idProduct)}") # type: ignore
# Collect the metadata
device_metadata = {'vendor_id': device.idVendor, 'product_id': device.idProduct}
# Detach the kernel driver if needed # Detach the kernel driver if needed
if device.is_kernel_driver_active(0): if device.is_kernel_driver_active(0):
logger.debug("detaching kernel driver") logger.debug("detaching kernel driver")
@@ -327,79 +289,9 @@ async def open_pyusb_transport(spec: str) -> Transport:
# except usb.USBError: # except usb.USBError:
# logger.warning('failed to set alternate setting') # logger.warning('failed to set alternate setting')
packet_source = UsbPacketSource(device, device_metadata, sco_enabled) packet_source = UsbPacketSource(device, sco_enabled)
packet_sink = UsbPacketSink(device) packet_sink = UsbPacketSink(device)
packet_source.start() packet_source.start()
packet_sink.start() packet_sink.start()
return UsbTransport(device, packet_source, packet_sink) return UsbTransport(device, packet_source, packet_sink)
async def _power_cycle(device: UsbDevice) -> UsbDevice:
"""
For devices connected to compatible USB hubs: Performs a power cycle on a given USB device.
This involves temporarily disabling its port on the hub and then re-enabling it.
"""
device_path = f'{device.bus}-{".".join(map(str, device.port_numbers))}' # type: ignore
hub = _find_hub_by_device_path(device_path)
if hub:
try:
device_port = device.port_numbers[-1] # type: ignore
_set_port_status(hub, device_port, False)
await asyncio.sleep(POWER_CYCLE_DELAY)
_set_port_status(hub, device_port, True)
await asyncio.sleep(RESET_DELAY)
# Device needs to be find again otherwise it will appear as disconnected
return usb.core.find(idVendor=device.idVendor, idProduct=device.idProduct) # type: ignore
except USBError as e:
logger.error(f"Adjustment needed: Please revise the udev rule for device {hex(device.idVendor)}:{hex(device.idProduct)} for proper recognition.") # type: ignore
logger.error(e)
return device
def _set_port_status(device: UsbDevice, port: int, on: bool):
"""Sets the power status of a specific port on a USB hub."""
device.ctrl_transfer(
bmRequestType=CTRL_TYPE_CLASS | CTRL_RECIPIENT_OTHER,
bRequest=REQ_SET_FEATURE if on else REQ_CLEAR_FEATURE,
wIndex=port,
wValue=USB_PORT_FEATURE_POWER,
)
def _find_device_by_path(sys_path: str) -> Optional[UsbDevice]:
"""Finds a USB device based on its system path."""
bus_num, *port_parts = sys_path.split('-')
ports = [int(port) for port in port_parts[0].split('.')]
devices = usb.core.find(find_all=True, bus=int(bus_num))
if devices:
for device in devices:
if device.bus == int(bus_num) and list(device.port_numbers) == ports: # type: ignore
return device
return None
def _find_hub_by_device_path(sys_path: str) -> Optional[UsbDevice]:
"""Finds the USB hub associated with a specific device path."""
hub_sys_path = sys_path.rsplit('.', 1)[0]
hub_device = _find_device_by_path(hub_sys_path)
if hub_device is None:
return None
else:
return hub_device if _is_hub(hub_device) else None
def _is_hub(device: UsbDevice) -> bool:
"""Checks if a USB device is a hub"""
if device.bDeviceClass == CLASS_HUB: # type: ignore
return True
for config in device:
for interface in config:
if interface.bInterfaceClass == CLASS_HUB: # type: ignore
return True
return False
+5 -25
View File
@@ -18,7 +18,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import logging import logging
import socket
from .common import Transport, StreamPacketSource from .common import Transport, StreamPacketSource
@@ -29,13 +28,6 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# A pass-through function to ease mock testing.
async def _create_server(*args, **kw_args):
await asyncio.get_running_loop().create_server(*args, **kw_args)
async def open_tcp_server_transport(spec: str) -> Transport: async def open_tcp_server_transport(spec: str) -> Transport:
''' '''
Open a TCP server transport. Open a TCP server transport.
@@ -46,22 +38,7 @@ async def open_tcp_server_transport(spec: str) -> Transport:
Example: _:9001 Example: _:9001
''' '''
local_host, local_port = spec.split(':')
return await _open_tcp_server_transport_impl(
host=local_host if local_host != '_' else None, port=int(local_port)
)
async def open_tcp_server_transport_with_socket(sock: socket.socket) -> Transport:
'''
Open a TCP server transport with an existing socket.
One reason to use this variant is to let python pick an unused port.
'''
return await _open_tcp_server_transport_impl(sock=sock)
async def _open_tcp_server_transport_impl(**kwargs) -> Transport:
class TcpServerTransport(Transport): class TcpServerTransport(Transport):
async def close(self): async def close(self):
await super().close() await super().close()
@@ -100,10 +77,13 @@ async def _open_tcp_server_transport_impl(**kwargs) -> Transport:
else: else:
logger.debug('no client, dropping packet') logger.debug('no client, dropping packet')
local_host, local_port = spec.split(':')
packet_source = StreamPacketSource() packet_source = StreamPacketSource()
packet_sink = TcpServerPacketSink() packet_sink = TcpServerPacketSink()
await _create_server( await asyncio.get_running_loop().create_server(
lambda: TcpServerProtocol(packet_source, packet_sink), **kwargs lambda: TcpServerProtocol(packet_source, packet_sink),
host=local_host if local_host != '_' else None,
port=int(local_port),
) )
return TcpServerTransport(packet_source, packet_sink) return TcpServerTransport(packet_source, packet_sink)
+63 -70
View File
@@ -24,10 +24,9 @@ import platform
import usb1 import usb1
from bumble.transport.common import Transport, ParserSource from .common import Transport, ParserSource
from bumble import hci from .. import hci
from bumble.colors import color from ..colors import color
from bumble.utils import AsyncRunner
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -108,13 +107,13 @@ async def open_usb_transport(spec: str) -> Transport:
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER, USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
) )
READ_SIZE = 4096 READ_SIZE = 1024
class UsbPacketSink: class UsbPacketSink:
def __init__(self, device, acl_out): def __init__(self, device, acl_out):
self.device = device self.device = device
self.acl_out = acl_out self.acl_out = acl_out
self.acl_out_transfer = device.getTransfer() self.transfer = device.getTransfer()
self.packets = collections.deque() # Queue of packets waiting to be sent self.packets = collections.deque() # Queue of packets waiting to be sent
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.cancel_done = self.loop.create_future() self.cancel_done = self.loop.create_future()
@@ -138,20 +137,21 @@ async def open_usb_transport(spec: str) -> Transport:
# The queue was previously empty, re-prime the pump # The queue was previously empty, re-prime the pump
self.process_queue() self.process_queue()
def transfer_callback(self, transfer): def on_packet_sent(self, transfer):
status = transfer.getStatus() status = transfer.getStatus()
# logger.debug(f'<<< USB out transfer callback: status={status}')
# pylint: disable=no-member # pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED: if status == usb1.TRANSFER_COMPLETED:
self.loop.call_soon_threadsafe(self.on_packet_sent) self.loop.call_soon_threadsafe(self.on_packet_sent_)
elif status == usb1.TRANSFER_CANCELLED: elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None) self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
else: else:
logger.warning( logger.warning(
color(f'!!! OUT transfer not completed: status={status}', 'red') color(f'!!! out transfer not completed: status={status}', 'red')
) )
def on_packet_sent(self): def on_packet_sent_(self):
if self.packets: if self.packets:
self.packets.popleft() self.packets.popleft()
self.process_queue() self.process_queue()
@@ -163,20 +163,22 @@ async def open_usb_transport(spec: str) -> Transport:
packet = self.packets[0] packet = self.packets[0]
packet_type = packet[0] packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET: if packet_type == hci.HCI_ACL_DATA_PACKET:
self.acl_out_transfer.setBulk( self.transfer.setBulk(
self.acl_out, packet[1:], callback=self.transfer_callback self.acl_out, packet[1:], callback=self.on_packet_sent
) )
self.acl_out_transfer.submit() logger.debug('submit ACL')
self.transfer.submit()
elif packet_type == hci.HCI_COMMAND_PACKET: elif packet_type == hci.HCI_COMMAND_PACKET:
self.acl_out_transfer.setControl( self.transfer.setControl(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS, USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0, 0,
0, 0,
0, 0,
packet[1:], packet[1:],
callback=self.transfer_callback, callback=self.on_packet_sent,
) )
self.acl_out_transfer.submit() logger.debug('submit COMMAND')
self.transfer.submit()
else: else:
logger.warning(color(f'unsupported packet type {packet_type}', 'red')) logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
@@ -191,11 +193,11 @@ async def open_usb_transport(spec: str) -> Transport:
self.packets.clear() self.packets.clear()
# If we have a transfer in flight, cancel it # If we have a transfer in flight, cancel it
if self.acl_out_transfer.isSubmitted(): if self.transfer.isSubmitted():
# Try to cancel the transfer, but that may fail because it may have # Try to cancel the transfer, but that may fail because it may have
# already completed # already completed
try: try:
self.acl_out_transfer.cancel() self.transfer.cancel()
logger.debug('waiting for OUT transfer cancellation to be done...') logger.debug('waiting for OUT transfer cancellation to be done...')
await self.cancel_done await self.cancel_done
@@ -204,22 +206,27 @@ async def open_usb_transport(spec: str) -> Transport:
logger.debug('OUT transfer likely already completed') logger.debug('OUT transfer likely already completed')
class UsbPacketSource(asyncio.Protocol, ParserSource): class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, device, metadata, acl_in, events_in): def __init__(self, context, device, metadata, acl_in, events_in):
super().__init__() super().__init__()
self.context = context
self.device = device self.device = device
self.metadata = metadata self.metadata = metadata
self.acl_in = acl_in self.acl_in = acl_in
self.acl_in_transfer = None
self.events_in = events_in self.events_in = events_in
self.events_in_transfer = None
self.loop = asyncio.get_running_loop() self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
self.dequeue_task = None self.dequeue_task = None
self.closed = False
self.event_loop_done = self.loop.create_future()
self.cancel_done = { self.cancel_done = {
hci.HCI_EVENT_PACKET: self.loop.create_future(), hci.HCI_EVENT_PACKET: self.loop.create_future(),
hci.HCI_ACL_DATA_PACKET: self.loop.create_future(), hci.HCI_ACL_DATA_PACKET: self.loop.create_future(),
} }
self.closed = False self.events_in_transfer = None
self.acl_in_transfer = None
# Create a thread to process events
self.event_thread = threading.Thread(target=self.run)
def start(self): def start(self):
# Set up transfer objects for input # Set up transfer objects for input
@@ -227,7 +234,7 @@ async def open_usb_transport(spec: str) -> Transport:
self.events_in_transfer.setInterrupt( self.events_in_transfer.setInterrupt(
self.events_in, self.events_in,
READ_SIZE, READ_SIZE,
callback=self.transfer_callback, callback=self.on_packet_received,
user_data=hci.HCI_EVENT_PACKET, user_data=hci.HCI_EVENT_PACKET,
) )
self.events_in_transfer.submit() self.events_in_transfer.submit()
@@ -236,23 +243,22 @@ async def open_usb_transport(spec: str) -> Transport:
self.acl_in_transfer.setBulk( self.acl_in_transfer.setBulk(
self.acl_in, self.acl_in,
READ_SIZE, READ_SIZE,
callback=self.transfer_callback, callback=self.on_packet_received,
user_data=hci.HCI_ACL_DATA_PACKET, user_data=hci.HCI_ACL_DATA_PACKET,
) )
self.acl_in_transfer.submit() self.acl_in_transfer.submit()
self.dequeue_task = self.loop.create_task(self.dequeue()) self.dequeue_task = self.loop.create_task(self.dequeue())
self.event_thread.start()
@property def on_packet_received(self, transfer):
def usb_transfer_submitted(self):
return (
self.events_in_transfer.isSubmitted()
or self.acl_in_transfer.isSubmitted()
)
def transfer_callback(self, transfer):
packet_type = transfer.getUserData() packet_type = transfer.getUserData()
status = transfer.getStatus() status = transfer.getStatus()
# logger.debug(
# f'<<< USB IN transfer callback: status={status} '
# f'packet_type={packet_type} '
# f'length={transfer.getActualLength()}'
# )
# pylint: disable=no-member # pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED: if status == usb1.TRANSFER_COMPLETED:
@@ -261,18 +267,18 @@ async def open_usb_transport(spec: str) -> Transport:
+ transfer.getBuffer()[: transfer.getActualLength()] + transfer.getBuffer()[: transfer.getActualLength()]
) )
self.loop.call_soon_threadsafe(self.queue.put_nowait, packet) self.loop.call_soon_threadsafe(self.queue.put_nowait, packet)
# Re-submit the transfer so we can receive more data
transfer.submit()
elif status == usb1.TRANSFER_CANCELLED: elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe( self.loop.call_soon_threadsafe(
self.cancel_done[packet_type].set_result, None self.cancel_done[packet_type].set_result, None
) )
return
else: else:
logger.warning( logger.warning(
color(f'!!! IN transfer not completed: status={status}', 'red') color(f'!!! transfer not completed: status={status}', 'red')
) )
self.loop.call_soon_threadsafe(self.on_transport_lost)
# Re-submit the transfer so we can receive more data
transfer.submit()
async def dequeue(self): async def dequeue(self):
while not self.closed: while not self.closed:
@@ -282,6 +288,21 @@ async def open_usb_transport(spec: str) -> Transport:
return return
self.parser.feed_data(packet) self.parser.feed_data(packet)
def run(self):
logger.debug('starting USB event loop')
while (
self.events_in_transfer.isSubmitted()
or self.acl_in_transfer.isSubmitted()
):
# pylint: disable=no-member
try:
self.context.handleEvents()
except usb1.USBErrorInterrupted:
pass
logger.debug('USB event loop done')
self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)
def close(self): def close(self):
self.closed = True self.closed = True
@@ -310,14 +331,15 @@ async def open_usb_transport(spec: str) -> Transport:
f'IN[{packet_type}] transfer likely already completed' f'IN[{packet_type}] transfer likely already completed'
) )
# Wait for the thread to terminate
await self.event_loop_done
class UsbTransport(Transport): class UsbTransport(Transport):
def __init__(self, context, device, interface, setting, source, sink): def __init__(self, context, device, interface, setting, source, sink):
super().__init__(source, sink) super().__init__(source, sink)
self.context = context self.context = context
self.device = device self.device = device
self.interface = interface self.interface = interface
self.loop = asyncio.get_running_loop()
self.event_loop_done = self.loop.create_future()
# Get exclusive access # Get exclusive access
device.claimInterface(interface) device.claimInterface(interface)
@@ -330,22 +352,6 @@ async def open_usb_transport(spec: str) -> Transport:
source.start() source.start()
sink.start() sink.start()
# Create a thread to process events
self.event_thread = threading.Thread(target=self.run)
self.event_thread.start()
def run(self):
logger.debug('starting USB event loop')
while self.source.usb_transfer_submitted:
# pylint: disable=no-member
try:
self.context.handleEvents()
except usb1.USBErrorInterrupted:
pass
logger.debug('USB event loop done')
self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)
async def close(self): async def close(self):
self.source.close() self.source.close()
self.sink.close() self.sink.close()
@@ -355,9 +361,6 @@ async def open_usb_transport(spec: str) -> Transport:
self.device.close() self.device.close()
self.context.close() self.context.close()
# Wait for the thread to terminate
await self.event_loop_done
# Find the device according to the spec moniker # Find the device according to the spec moniker
load_libusb() load_libusb()
context = usb1.USBContext() context = usb1.USBContext()
@@ -396,16 +399,6 @@ async def open_usb_transport(spec: str) -> Transport:
break break
device_index -= 1 device_index -= 1
device.close() device.close()
elif '-' in spec:
def device_path(device):
return f'{device.getBusNumber()}-{".".join(map(str, device.getPortNumberList()))}'
for device in context.getDeviceIterator(skip_on_error=True):
if device_path(device) == spec:
found = device
break
device.close()
else: else:
# Look for a compatible device by index # Look for a compatible device by index
def device_is_bluetooth_hci(device): def device_is_bluetooth_hci(device):
@@ -449,7 +442,7 @@ async def open_usb_transport(spec: str) -> Transport:
# Look for the first interface with the right class and endpoints # Look for the first interface with the right class and endpoints
def find_endpoints(device): def find_endpoints(device):
# pylint: disable-next=too-many-nested-blocks # pylint: disable-next=too-many-nested-blocks
for configuration_index, configuration in enumerate(device): for (configuration_index, configuration) in enumerate(device):
interface = None interface = None
for interface in configuration: for interface in configuration:
setting = None setting = None
@@ -547,7 +540,7 @@ async def open_usb_transport(spec: str) -> Transport:
except usb1.USBError: except usb1.USBError:
logger.warning('failed to set configuration') logger.warning('failed to set configuration')
source = UsbPacketSource(device, device_metadata, acl_in, events_in) source = UsbPacketSource(context, device, device_metadata, acl_in, events_in)
sink = UsbPacketSink(device, acl_out) sink = UsbPacketSink(device, acl_out)
return UsbTransport(context, device, interface, setting, source, sink) return UsbTransport(context, device, interface, setting, source, sink)
except usb1.USBError as error: except usb1.USBError as error:
+36 -79
View File
@@ -17,10 +17,9 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import collections
import enum
import functools
import logging import logging
import traceback
import collections
import sys import sys
import warnings import warnings
from typing import ( from typing import (
@@ -35,7 +34,7 @@ from typing import (
Union, Union,
overload, overload,
) )
from functools import wraps, partial
from pyee import EventEmitter from pyee import EventEmitter
from .colors import color from .colors import color
@@ -117,12 +116,12 @@ class EventWatcher:
self.handlers = [] self.handlers = []
@overload @overload
def on( def on(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
self, emitter: EventEmitter, event: str ...
) -> Callable[[_Handler], _Handler]: ...
@overload @overload
def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler: ... def on(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
...
def on( def on(
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
@@ -132,26 +131,23 @@ class EventWatcher:
Args: Args:
emitter: EventEmitter to watch emitter: EventEmitter to watch
event: Event name event: Event name
handler: (Optional) Event handler. When nothing is passed, this method handler: (Optional) Event handler. When nothing is passed, this method works as a decorator.
works as a decorator.
''' '''
def wrapper(wrapped: _Handler) -> _Handler: def wrapper(f: _Handler) -> _Handler:
self.handlers.append((emitter, event, wrapped)) self.handlers.append((emitter, event, f))
emitter.on(event, wrapped) emitter.on(event, f)
return wrapped return f
return wrapper if handler is None else wrapper(handler) return wrapper if handler is None else wrapper(handler)
@overload @overload
def once( def once(self, emitter: EventEmitter, event: str) -> Callable[[_Handler], _Handler]:
self, emitter: EventEmitter, event: str ...
) -> Callable[[_Handler], _Handler]: ...
@overload @overload
def once( def once(self, emitter: EventEmitter, event: str, handler: _Handler) -> _Handler:
self, emitter: EventEmitter, event: str, handler: _Handler ...
) -> _Handler: ...
def once( def once(
self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None self, emitter: EventEmitter, event: str, handler: Optional[_Handler] = None
@@ -161,14 +157,13 @@ class EventWatcher:
Args: Args:
emitter: EventEmitter to watch emitter: EventEmitter to watch
event: Event name event: Event name
handler: (Optional) Event handler. When nothing passed, this method works handler: (Optional) Event handler. When nothing passed, this method works as a decorator.
as a decorator.
''' '''
def wrapper(wrapped: _Handler) -> _Handler: def wrapper(f: _Handler) -> _Handler:
self.handlers.append((emitter, event, wrapped)) self.handlers.append((emitter, event, f))
emitter.once(event, wrapped) emitter.once(event, f)
return wrapped return f
return wrapper if handler is None else wrapper(handler) return wrapper if handler is None else wrapper(handler)
@@ -228,13 +223,13 @@ class CompositeEventEmitter(AbortableEventEmitter):
if self._listener: if self._listener:
# Call the deregistration methods for each base class that has them # Call the deregistration methods for each base class that has them
for cls in self._listener.__class__.mro(): for cls in self._listener.__class__.mro():
if '_bumble_register_composite' in cls.__dict__: if hasattr(cls, '_bumble_register_composite'):
cls._bumble_deregister_composite(self._listener, self) cls._bumble_deregister_composite(listener, self)
self._listener = listener self._listener = listener
if listener: if listener:
# Call the registration methods for each base class that has them # Call the registration methods for each base class that has them
for cls in listener.__class__.mro(): for cls in listener.__class__.mro():
if '_bumble_deregister_composite' in cls.__dict__: if hasattr(cls, '_bumble_deregister_composite'):
cls._bumble_register_composite(listener, self) cls._bumble_register_composite(listener, self)
@@ -281,18 +276,21 @@ class AsyncRunner:
""" """
def decorator(func): def decorator(func):
@functools.wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
coroutine = func(*args, **kwargs) coroutine = func(*args, **kwargs)
if queue is None: if queue is None:
# Spawn the coroutine as a task # Create a task to run the coroutine
async def run(): async def run():
try: try:
await coroutine await coroutine
except Exception: except Exception:
logger.exception(color("!!! Exception in wrapper:", "red")) logger.warning(
f'{color("!!! Exception in wrapper:", "red")} '
f'{traceback.format_exc()}'
)
AsyncRunner.spawn(run()) asyncio.create_task(run())
else: else:
# Queue the coroutine to be awaited by the work queue # Queue the coroutine to be awaited by the work queue
queue.enqueue(coroutine) queue.enqueue(coroutine)
@@ -415,35 +413,30 @@ class FlowControlAsyncPipe:
self.check_pump() self.check_pump()
# -----------------------------------------------------------------------------
async def async_call(function, *args, **kwargs): async def async_call(function, *args, **kwargs):
""" """
Immediately calls the function with provided args and kwargs, wrapping it in an Immediately calls the function with provided args and kwargs, wrapping it in an async function.
async function. Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject a running loop.
Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject
a running loop.
result = await async_call(some_function, ...) result = await async_call(some_function, ...)
""" """
return function(*args, **kwargs) return function(*args, **kwargs)
# -----------------------------------------------------------------------------
def wrap_async(function): def wrap_async(function):
""" """
Wraps the provided function in an async function. Wraps the provided function in an async function.
""" """
return functools.partial(async_call, function) return partial(async_call, function)
# -----------------------------------------------------------------------------
def deprecated(msg: str): def deprecated(msg: str):
""" """
Throw deprecation warning before execution. Throw deprecation warning before execution
""" """
def wrapper(function): def wrapper(function):
@functools.wraps(function) @wraps(function)
def inner(*args, **kwargs): def inner(*args, **kwargs):
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning)
return function(*args, **kwargs) return function(*args, **kwargs)
@@ -451,39 +444,3 @@ def deprecated(msg: str):
return inner return inner
return wrapper return wrapper
# -----------------------------------------------------------------------------
def experimental(msg: str):
"""
Throws a future warning before execution.
"""
def wrapper(function):
@functools.wraps(function)
def inner(*args, **kwargs):
warnings.warn(msg, FutureWarning)
return function(*args, **kwargs)
return inner
return wrapper
# -----------------------------------------------------------------------------
class OpenIntEnum(enum.IntEnum):
"""
Subclass of enum.IntEnum that can hold integer values outside the set of
predefined values. This is convenient for implementing protocols where some
integer constants may be added over time.
"""
@classmethod
def _missing_(cls, value):
if not isinstance(value, int):
return None
obj = int.__new__(cls, value)
obj._value_ = value
obj._name_ = f"{cls.__name__}[{value}]"
return obj
-1
View File
@@ -70,7 +70,6 @@ nav:
- Extras: - Extras:
- extras/index.md - extras/index.md
- Android Remote HCI: extras/android_remote_hci.md - Android Remote HCI: extras/android_remote_hci.md
- Android BT Bench: extras/android_bt_bench.md
- Hive: - Hive:
- hive/index.md - hive/index.md
- Speaker: hive/web/speaker/speaker.html - Speaker: hive/web/speaker/speaker.html
+9 -30
View File
@@ -7,36 +7,16 @@ throughput and/or latency between two devices.
# General Usage # General Usage
``` ```
Usage: bumble-bench [OPTIONS] COMMAND [ARGS]... Usage: bench.py [OPTIONS] COMMAND [ARGS]...
Options: Options:
--device-config FILENAME Device configuration file --device-config FILENAME Device configuration file
--role [sender|receiver|ping|pong] --role [sender|receiver|ping|pong]
--mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server] --mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server]
--att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517] --att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517]
--extended-data-length TEXT Request a data length upon connection, -s, --packet-size SIZE Packet size (server role) [8<=x<=4096]
specified as tx_octets/tx_time -c, --packet-count COUNT Packet count (server role)
--rfcomm-channel INTEGER RFComm channel to use -sd, --start-delay SECONDS Start delay (server role)
--rfcomm-uuid TEXT RFComm service UUID to use (ignored if
--rfcomm-channel is not 0)
--l2cap-psm INTEGER L2CAP PSM to use
--l2cap-mtu INTEGER L2CAP MTU to use
--l2cap-mps INTEGER L2CAP MPS to use
--l2cap-max-credits INTEGER L2CAP maximum number of credits allowed for
the peer
-s, --packet-size SIZE Packet size (client or ping role)
[8<=x<=4096]
-c, --packet-count COUNT Packet count (client or ping role)
-sd, --start-delay SECONDS Start delay (client or ping role)
--repeat N Repeat the run N times (client and ping
roles)(0, which is the fault, to run just
once)
--repeat-delay SECONDS Delay, in seconds, between repeats
--pace MILLISECONDS Wait N milliseconds between packets (0,
which is the fault, to send as fast as
possible)
--linger Don't exit at the end of a run (server and
pong roles)
--help Show this message and exit. --help Show this message and exit.
Commands: Commands:
@@ -55,18 +35,17 @@ Options:
--connection-interval, --ci CONNECTION_INTERVAL --connection-interval, --ci CONNECTION_INTERVAL
Connection interval (in ms) Connection interval (in ms)
--phy [1m|2m|coded] PHY to use --phy [1m|2m|coded] PHY to use
--authenticate Authenticate (RFComm only)
--encrypt Encrypt the connection (RFComm only)
--help Show this message and exit. --help Show this message and exit.
``` ```
To test once device against another, one of the two devices must be running
To test once device against another, one of the two devices must be running
the ``peripheral`` command and the other the ``central`` command. The device the ``peripheral`` command and the other the ``central`` command. The device
running the ``peripheral`` command will accept connections from the device running the ``peripheral`` command will accept connections from the device
running the ``central`` command. running the ``central`` command.
When using Bluetooth LE (all modes except for ``rfcomm-server`` and ``rfcomm-client``utils), When using Bluetooth LE (all modes except for ``rfcomm-server`` and ``rfcomm-client``utils),
the default addresses configured in the tool should be sufficient. But when using the default addresses configured in the tool should be sufficient. But when using
Bluetooth Classic, the address of the Peripheral must be specified on the Central Bluetooth Classic, the address of the Peripheral must be specified on the Central
using the ``--peripheral`` option. The address will be printed by the Peripheral when using the ``--peripheral`` option. The address will be printed by the Peripheral when
it starts. it starts.
@@ -104,7 +83,7 @@ the other on `usb:1`, and two consoles/terminals. We will run a command in each.
$ bumble-bench central usb:1 $ bumble-bench central usb:1
``` ```
In this default configuration, the Central runs a Sender, as a GATT client, In this default configuration, the Central runs a Sender, as a GATT client,
connecting to the Peripheral running a Receiver, as a GATT server. connecting to the Peripheral running a Receiver, as a GATT server.
!!! example "L2CAP Throughput" !!! example "L2CAP Throughput"
@@ -12,25 +12,12 @@ a host that send custom HCI commands that the controller may not understand.
``` ```
python hci_bridge.py <host-transport-spec> <controller-transport-spec> [command-short-circuit-list] python hci_bridge.py <host-transport-spec> <controller-transport-spec> [command-short-circuit-list]
``` ```
The command-short-circuit-list field is specified by a series of comma separated Opcode Group
Field (OGF) : OpCode Command Field (OCF) pairs. The OGF/OCF values are specified in the Blutooth
core specification.
For the commands that are listed in the short-circuit-list, the HCI bridge will always generate
a Command Complete Event for the specified op code. The return parameter will be HCI_SUCCESS.
This feature can only be used for commands that return Command Complete. Other events will not be
generated by the HCI bridge tool.
!!! example "UDP to Serial" !!! example "UDP to Serial"
``` ```
python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078 python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078
``` ```
In this example, the short circuit list is specified to respond to the Vendor-specific Opcode Group
Field (0x3f) commands 0x70, 0x74, 0x77, 0x78 with Command Complete. The short circuit list can be
used where the Host uses some HCI commands that are not supported/implemented by the Controller.
!!! example "PTY to Link Relay" !!! example "PTY to Link Relay"
``` ```
python hci_bridge.py serial:emulated_uart_pty,1000000 link-relay:ws://127.0.0.1:10723/test python hci_bridge.py serial:emulated_uart_pty,1000000 link-relay:ws://127.0.0.1:10723/test
@@ -41,4 +28,3 @@ a host that send custom HCI commands that the controller may not understand.
(through which the communication with other virtual controllers will be mediated). (through which the communication with other virtual controllers will be mediated).
NOTE: this assumes you're running a Link Relay on port `10723`. NOTE: this assumes you're running a Link Relay on port `10723`.
-9
View File
@@ -5,15 +5,6 @@ Some Bluetooth controllers require a driver to function properly.
This may include, for instance, loading a Firmware image or patch, This may include, for instance, loading a Firmware image or patch,
loading a configuration. loading a configuration.
By default, drivers will be automatically probed to determine if they should be
used with particular HCI controller.
When the transport for an HCI controller is instantiated from a transport name,
a driver may also be forced by specifying ``driver=<driver-name>`` in the optional
metadata portion of the transport name. For example,
``usb:[driver=rtk]0`` indicates that the ``rtk`` driver should be used with the
first USB device, even if a normal probe would not have selected it based on the
USB vendor ID and product ID.
Drivers included in the module are: Drivers included in the module are:
* [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles. * [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles.
+2 -5
View File
@@ -1,16 +1,13 @@
REALTEK DRIVER REALTEK DRIVER
============== ==============
This driver supports loading firmware images and optional config data to This driver supports loading firmware images and optional config data to
USB dongles with a Realtek chipset. USB dongles with a Realtek chipset.
A number of USB dongles are supported, but likely not all. A number of USB dongles are supported, but likely not all.
When using a USB dongle, the USB product ID and vendor ID are used When using a USB dongle, the USB product ID and manufacturer ID are used
to find whether a matching set of firmware image and config data to find whether a matching set of firmware image and config data
is needed for that specific model. If a match exists, the driver will try is needed for that specific model. If a match exists, the driver will try
load the firmware image and, if needed, config data. load the firmware image and, if needed, config data.
Alternatively, the metadata property ``driver=rtk`` may be specified in a transport
name to force that driver to be used (ex: ``usb:[driver=rtk]0`` instead of just
``usb:0`` for the first USB device).
The driver will look for those files by name, in order, in: The driver will look for those files by name, in order, in:
* The directory specified by the environment variable `BUMBLE_RTK_FIRMWARE_DIR` * The directory specified by the environment variable `BUMBLE_RTK_FIRMWARE_DIR`
@@ -1,64 +0,0 @@
ANDROID BENCH APP
=================
This Android app that is compatible with the Bumble `bench` command line app.
This app can be used to test the throughput and latency between two Android
devices, or between an Android device and another device running the Bumble
`bench` app.
Only the RFComm Client, RFComm Server, L2CAP Client and L2CAP Server modes are
supported.
Building
--------
You can build the app by running `./gradlew build` (use `gradlew.bat` on Windows) from the `BtBench` top level directory.
You can also build with Android Studio: open the `BtBench` project. You can build and/or debug from there.
If the build succeeds, you can find the app APKs (debug and release) at:
* [Release] ``app/build/outputs/apk/release/app-release-unsigned.apk``
* [Debug] ``app/build/outputs/apk/debug/app-debug.apk``
Running
-------
### Starting the app
You can start the app from the Android launcher, from Android Studio, or with `adb`
#### Launching from the launcher
Just tap the app icon on the launcher, check the parameters, and tap
one of the benchmark action buttons.
#### Launching with `adb`
Using the `am` command, you can start the activity, and pass it arguments so that you can
automatically start the benchmark test, and/or set the parameters.
| Parameter Name | Parameter Type | Description
|------------------------|----------------|------------
| autostart | String | Benchmark to start. (rfcomm-client, rfcomm-server, l2cap-client or l2cap-server)
| packet-count | Integer | Number of packets to send (rfcomm-client and l2cap-client only)
| packet-size | Integer | Number of bytes per packet (rfcomm-client and l2cap-client only)
| peer-bluetooth-address | Integer | Peer Bluetooth address to connect to (rfcomm-client and l2cap-client | only)
!!! tip "Launching from adb with auto-start"
In this example, we auto-start the Rfcomm Server bench action.
```bash
$ adb shell am start -n com.github.google.bumble.btbench/.MainActivity --es autostart rfcomm-server
```
!!! tip "Launching from adb with auto-start and some parameters"
In this example, we auto-start the Rfcomm Client bench action, set the packet count to 100,
and the packet size to 1024, and connect to DA:4C:10:DE:17:02
```bash
$ adb shell am start -n com.github.google.bumble.btbench/.MainActivity --es autostart rfcomm-client --ei packet-count 100 --ei packet-size 1024 --es peer-bluetooth-address DA:4C:10:DE:17:02
```
#### Selecting a Peer Bluetooth Address
The app's main activity has a "Peer Bluetooth Address" setting where you can change the address.
!!! note "Bluetooth Address for L2CAP vs RFComm"
For BLE (L2CAP mode), the address of a device typically changes regularly (it is randomized for privacy), whereas the Bluetooth Classic addresses will remain the same (RFComm mode).
If two devices are paired and bonded, then they will each "see" a non-changing address for each other even with BLE (Resolvable Private Address)
+13 -53
View File
@@ -1,19 +1,19 @@
ANDROID REMOTE HCI APP ANDROID REMOTE HCI APP
====================== ======================
This application allows using an android phone's built-in Bluetooth controller with This application allows using an android phone's built-in Bluetooth controller with
a Bumble host stack running outside the phone (typically a development laptop or desktop). a Bumble host stack running outside the phone (typically a development laptop or desktop).
The app runs an HCI proxy between a TCP socket on the "outside" and the Bluetooth HCI HAL The app runs an HCI proxy between a TCP socket on the "outside" and the Bluetooth HCI HAL
on the "inside". (See [this page](https://source.android.com/docs/core/connect/bluetooth) for a high level on the "inside". (See [this page](https://source.android.com/docs/core/connect/bluetooth) for a high level
description of the Android Bluetooth HCI HAL). description of the Android Bluetooth HCI HAL).
The HCI packets received on the TCP socket are forwarded to the phone's controller, and the The HCI packets received on the TCP socket are forwarded to the phone's controller, and the
packets coming from the controller are forwarded to the TCP socket. packets coming from the controller are forwarded to the TCP socket.
Building Building
-------- --------
You can build the app by running `./gradlew build` (use `gradlew.bat` on Windows) from the `extras/android/RemoteHCI` top level directory. You can build the app by running `./gradlew build` (use `gradlew.bat` on Windows) from the `RemoteHCI` top level directory.
You can also build with Android Studio: open the `RemoteHCI` project. You can build and/or debug from there. You can also build with Android Studio: open the `RemoteHCI` project. You can build and/or debug from there.
If the build succeeds, you can find the app APKs (debug and release) at: If the build succeeds, you can find the app APKs (debug and release) at:
@@ -25,23 +25,9 @@ If the build succeeds, you can find the app APKs (debug and release) at:
Running Running
------- -------
!!! note
In the following examples, it is assumed that shell commands are executed while in the
app's root directory, `extras/android/RemoteHCI`. If you are in a different directory,
adjust the relative paths accordingly.
### Preconditions ### Preconditions
When the proxy starts (tapping the "Start" button in the app's main activity, or running the proxy When the proxy starts (tapping the "Start" button in the app's main activity), it will try to
from an `adb shell` command line), it will try to bind to the Bluetooth HAL. bind to the Bluetooth HAL. This requires disabling SELinux temporarily, and being the only HAL client.
This requires that there is no other HAL client, and requires certain privileges.
For running as a regular app, this requires disabling SELinux temporarily.
For running as a command-line executable, this just requires a root shell.
#### Root Shell
!!! tip "Restart `adb` as root"
```bash
$ adb root
```
#### Disabling SELinux #### Disabling SELinux
Binding to the Bluetooth HCI HAL requires certain SELinux permissions that can't simply be changed Binding to the Bluetooth HCI HAL requires certain SELinux permissions that can't simply be changed
@@ -70,8 +56,8 @@ development phone).
This state will also reset to the normal SELinux enforcement when you reboot. This state will also reset to the normal SELinux enforcement when you reboot.
#### Stopping the bluetooth process #### Stopping the bluetooth process
Since the Bluetooth HAL service can only accept one client, and that in normal conditions Since the Bluetooth HAL service can only accept one client, and that in normal conditions
that client is the Android's bluetooth stack, it is required to first shut down the that client is the Android's bluetooth stack, it is required to first shut down the
Android bluetooth stack process. Android bluetooth stack process.
!!! tip "Checking if the Bluetooth process is running" !!! tip "Checking if the Bluetooth process is running"
@@ -93,33 +79,7 @@ Airplane Mode, then rebooting. The bluetooth process should, in theory, not rest
$ adb shell cmd bluetooth_manager disable $ adb shell cmd bluetooth_manager disable
``` ```
### Running as a command line app ### Starting the app
You push the built APK to a temporary location on the phone's filesystem, then launch the command
line executable with an `adb shell` command.
!!! tip "Pushing the executable"
```bash
$ adb push app/build/outputs/apk/release/app-release-unsigned.apk /data/local/tmp/remotehci.apk
```
Do this every time you rebuild. Alternatively, you can push the `debug` APK instead:
```bash
$ adb push app/build/outputs/apk/debug/app-debug.apk /data/local/tmp/remotehci.apk
```
!!! tip "Start the proxy from the command line"
```bash
adb shell "CLASSPATH=/data/local/tmp/remotehci.apk app_process /system/bin com.github.google.bumble.remotehci.CommandLineInterface"
```
This will run the proxy, listening on the default TCP port.
If you want a different port, pass it as a command line parameter
!!! tip "Start the proxy from the command line with a specific TCP port"
```bash
adb shell "CLASSPATH=/data/local/tmp/remotehci.apk app_process /system/bin com.github.google.bumble.remotehci.CommandLineInterface 12345"
```
### Running as a normal app
You can start the app from the Android launcher, from Android Studio, or with `adb` You can start the app from the Android launcher, from Android Studio, or with `adb`
#### Launching from the launcher #### Launching from the launcher
@@ -143,11 +103,11 @@ automatically start the proxy, and/or set the port number.
#### Selecting a TCP port #### Selecting a TCP port
The RemoteHCI app's main activity has a "TCP Port" setting where you can change the port on The RemoteHCI app's main activity has a "TCP Port" setting where you can change the port on
which the proxy is accepting connections. If the default value isn't suitable, you can which the proxy is accepting connections. If the default value isn't suitable, you can
change it there (you can also use the special value 0 to let the OS assign a port number for you). change it there (you can also use the special value 0 to let the OS assign a port number for you).
### Connecting to the proxy ### Connecting to the proxy
To connect the Bumble stack to the proxy, you need to be able to reach the phone's network To connect the Bumble stack to the proxy, you need to be able to reach the phone's network
stack. This can be done over the phone's WiFi connection, or, alternatively, using an `adb` stack. This can be done over the phone's WiFi connection, or, alternatively, using an `adb`
TCP forward (which should be faster than over WiFi). TCP forward (which should be faster than over WiFi).
@@ -156,7 +116,7 @@ TCP forward (which should be faster than over WiFi).
```bash ```bash
$ adb forward tcp:<outside-port> tcp:<inside-port> $ adb forward tcp:<outside-port> tcp:<inside-port>
``` ```
Where ``<outside-port>`` is the port number for a listening socket on your laptop or Where ``<outside-port>`` is the port number for a listening socket on your laptop or
desktop machine, and <inside-port> is the TCP port selected in the app's user interface. desktop machine, and <inside-port> is the TCP port selected in the app's user interface.
Those two ports may be the same, of course. Those two ports may be the same, of course.
For example, with the default TCP port 9993: For example, with the default TCP port 9993:
@@ -165,7 +125,7 @@ TCP forward (which should be faster than over WiFi).
``` ```
Once you've ensured that you can reach the proxy's TCP port on the phone, either directly or Once you've ensured that you can reach the proxy's TCP port on the phone, either directly or
via an `adb` forward, you can then use it as a Bumble transport, using the transport name: via an `adb` forward, you can then use it as a Bumble transport, using the transport name:
``tcp-client:<host>:<port>`` syntax. ``tcp-client:<host>:<port>`` syntax.
!!! example "Connecting a Bumble client" !!! example "Connecting a Bumble client"
+1 -9
View File
@@ -8,12 +8,4 @@ Android Remote HCI
Allows using an Android phone's built-in Bluetooth controller with a Bumble Allows using an Android phone's built-in Bluetooth controller with a Bumble
stack running on a development machine. stack running on a development machine.
See [Android Remote HCI](android_remote_hci.md) for details. See [Android Remote HCI](android_remote_hci.md) for details.
Android BT Bench
----------------
An Android app that is compatible with the Bumble `bench` command line app.
This app can be used to test the throughput and latency between two Android
devices, or between an Android device and another device running the Bumble
`bench` app.
-6
View File
@@ -10,7 +10,6 @@ The moniker for a USB transport is either:
* `usb:<vendor>:<product>` * `usb:<vendor>:<product>`
* `usb:<vendor>:<product>/<serial-number>` * `usb:<vendor>:<product>/<serial-number>`
* `usb:<vendor>:<product>#<index>` * `usb:<vendor>:<product>#<index>`
* `usb:<bus>-<port_numbers>`
with `<index>` as a 0-based index (0 being the first one) to select amongst all the matching devices when there are more than one. with `<index>` as a 0-based index (0 being the first one) to select amongst all the matching devices when there are more than one.
In the `usb:<index>` form, matching devices are the ones supporting Bluetooth HCI, as declared by their Class, Subclass and Protocol. In the `usb:<index>` form, matching devices are the ones supporting Bluetooth HCI, as declared by their Class, Subclass and Protocol.
@@ -18,8 +17,6 @@ In the `usb:<vendor>:<product>#<index>` form, matching devices are the ones with
`<vendor>` and `<product>` are a vendor ID and product ID in hexadecimal. `<vendor>` and `<product>` are a vendor ID and product ID in hexadecimal.
with `<port_numbers>` as a list of all port numbers from root separated with dots `.`
In addition, if the moniker ends with the symbol "!", the device will be used in "forced" mode: In addition, if the moniker ends with the symbol "!", the device will be used in "forced" mode:
the first USB interface of the device will be used, regardless of the interface class/subclass. the first USB interface of the device will be used, regardless of the interface class/subclass.
This may be useful for some devices that use a custom class/subclass but may nonetheless work as-is. This may be useful for some devices that use a custom class/subclass but may nonetheless work as-is.
@@ -40,9 +37,6 @@ This may be useful for some devices that use a custom class/subclass but may non
`usb:0B05:17CB!` `usb:0B05:17CB!`
The BT USB dongle vendor=0B05 and product=17CB, in "forced" mode. The BT USB dongle vendor=0B05 and product=17CB, in "forced" mode.
`usb:3-3.4.1`
The BT USB dongle on bus 3 on port path 3, 4, 1.
## Alternative ## Alternative
The library includes two different implementations of the USB transport, implemented using different python bindings for `libusb`. The library includes two different implementations of the USB transport, implemented using different python bindings for `libusb`.
+1 -2
View File
@@ -25,7 +25,6 @@ from bumble.utils import AsyncRunner
my_work_queue1 = AsyncRunner.WorkQueue() my_work_queue1 = AsyncRunner.WorkQueue()
my_work_queue2 = AsyncRunner.WorkQueue(create_task=False) my_work_queue2 = AsyncRunner.WorkQueue(create_task=False)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@AsyncRunner.run_in_task() @AsyncRunner.run_in_task()
async def func1(x, y): async def func1(x, y):
@@ -61,7 +60,7 @@ async def func4(x, y):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main():
print("MAIN: start, loop=", asyncio.get_running_loop()) print("MAIN: start, loop=", asyncio.get_running_loop())
print("MAIN: invoke func1") print("MAIN: invoke func1")
func1(1, 2) func1(1, 2)
-274
View File
@@ -1,274 +0,0 @@
<html>
<head>
<style>
* {
font-family: sans-serif;
}
</style>
</head>
<body>
Server Port <input id="port" type="text" value="8989"></input> <button id="connectButton" onclick="connect()">Connect</button><br>
<div id="socketState"></div>
<br>
<div id="buttons"></div><br>
<hr>
<button onclick="onGetPlayStatusButtonClicked()">Get Play Status</button><br>
<div id="getPlayStatusResponseTable"></div>
<hr>
<button onclick="onGetElementAttributesButtonClicked()">Get Element Attributes</button><br>
<div id="getElementAttributesResponseTable"></div>
<hr>
<table>
<tr>
<b>VOLUME</b>:
<button onclick="onVolumeDownButtonClicked()">-</button>
<button onclick="onVolumeUpButtonClicked()">+</button>&nbsp;
<span id="volumeText"></span><br>
</tr>
<tr>
<td><b>PLAYBACK STATUS</b></td><td><span id="playbackStatusText"></span></td>
</tr>
<tr>
<td><b>POSITION</b></td><td><span id="positionText"></span></td>
</tr>
<tr>
<td><b>TRACK</b></td><td><span id="trackText"></span></td>
</tr>
<tr>
<td><b>ADDRESSED PLAYER</b></td><td><span id="addressedPlayerText"></span></td>
</tr>
<tr>
<td><b>UID COUNTER</b></td><td><span id="uidCounterText"></span></td>
</tr>
<tr>
<td><b>SUPPORTED EVENTS</b></td><td><span id="supportedEventsText"></span></td>
</tr>
<tr>
<td><b>PLAYER SETTINGS</b></td><td><div id="playerSettingsTable"></div></td>
</tr>
</table>
<script>
const portInput = document.getElementById("port")
const connectButton = document.getElementById("connectButton")
const socketState = document.getElementById("socketState")
const volumeText = document.getElementById("volumeText")
const positionText = document.getElementById("positionText")
const trackText = document.getElementById("trackText")
const playbackStatusText = document.getElementById("playbackStatusText")
const addressedPlayerText = document.getElementById("addressedPlayerText")
const uidCounterText = document.getElementById("uidCounterText")
const supportedEventsText = document.getElementById("supportedEventsText")
const playerSettingsTable = document.getElementById("playerSettingsTable")
const getPlayStatusResponseTable = document.getElementById("getPlayStatusResponseTable")
const getElementAttributesResponseTable = document.getElementById("getElementAttributesResponseTable")
let socket
let volume = 0
const keyNames = [
"SELECT",
"UP",
"DOWN",
"LEFT",
"RIGHT",
"RIGHT_UP",
"RIGHT_DOWN",
"LEFT_UP",
"LEFT_DOWN",
"ROOT_MENU",
"SETUP_MENU",
"CONTENTS_MENU",
"FAVORITE_MENU",
"EXIT",
"NUMBER_0",
"NUMBER_1",
"NUMBER_2",
"NUMBER_3",
"NUMBER_4",
"NUMBER_5",
"NUMBER_6",
"NUMBER_7",
"NUMBER_8",
"NUMBER_9",
"DOT",
"ENTER",
"CLEAR",
"CHANNEL_UP",
"CHANNEL_DOWN",
"PREVIOUS_CHANNEL",
"SOUND_SELECT",
"INPUT_SELECT",
"DISPLAY_INFORMATION",
"HELP",
"PAGE_UP",
"PAGE_DOWN",
"POWER",
"VOLUME_UP",
"VOLUME_DOWN",
"MUTE",
"PLAY",
"STOP",
"PAUSE",
"RECORD",
"REWIND",
"FAST_FORWARD",
"EJECT",
"FORWARD",
"BACKWARD",
"ANGLE",
"SUBPICTURE",
"F1",
"F2",
"F3",
"F4",
"F5",
]
document.addEventListener('keydown', onKeyDown)
document.addEventListener('keyup', onKeyUp)
const buttons = document.getElementById("buttons")
keyNames.forEach(name => {
const button = document.createElement("BUTTON")
button.appendChild(document.createTextNode(name))
button.addEventListener("mousedown", event => {
send({type: 'send-key-down', key: name})
})
button.addEventListener("mouseup", event => {
send({type: 'send-key-up', key: name})
})
buttons.appendChild(button)
})
updateVolume(0)
function connect() {
socket = new WebSocket(`ws://localhost:${portInput.value}`);
socket.onopen = _ => {
socketState.innerText = 'OPEN'
connectButton.disabled = true
}
socket.onclose = _ => {
socketState.innerText = 'CLOSED'
connectButton.disabled = false
}
socket.onerror = (error) => {
socketState.innerText = 'ERROR'
console.log(`ERROR: ${error}`)
connectButton.disabled = false
}
socket.onmessage = (message) => {
onMessage(JSON.parse(message.data))
}
}
function send(message) {
if (socket && socket.readyState == WebSocket.OPEN) {
socket.send(JSON.stringify(message))
}
}
function hmsText(position) {
const h_1 = 1000 * 60 * 60
const h = Math.floor(position / h_1)
position -= h * h_1
const m_1 = 1000 * 60
const m = Math.floor(position / m_1)
position -= m * m_1
const s_1 = 1000
const s = Math.floor(position / s_1)
position -= s * s_1
return `${h}:${m.toString().padStart(2, "0")}:${s.toString().padStart(2, "0")}:${position}`
}
function setTableHead(table, columns) {
let thead = table.createTHead()
let row = thead.insertRow()
for (let column of columns) {
let th = document.createElement("th")
let text = document.createTextNode(column)
th.appendChild(text)
row.appendChild(th)
}
}
function createTable(rows) {
const table = document.createElement("table")
if (rows.length != 0) {
columns = Object.keys(rows[0])
setTableHead(table, columns)
}
for (let element of rows) {
let row = table.insertRow()
for (key in element) {
let cell = row.insertCell()
let text = document.createTextNode(element[key])
cell.appendChild(text)
}
}
return table
}
function onMessage(message) {
console.log(message)
if (message.type == "set-volume") {
updateVolume(message.params.volume)
} else if (message.type == "supported-events") {
supportedEventsText.innerText = JSON.stringify(message.params.events)
} else if (message.type == "playback-position-changed") {
positionText.innerText = hmsText(message.params.position)
} else if (message.type == "playback-status-changed") {
playbackStatusText.innerText = message.params.status
} else if (message.type == "player-settings-changed") {
playerSettingsTable.replaceChildren(message.params.settings)
} else if (message.type == "track-changed") {
trackText.innerText = message.params.identifier
} else if (message.type == "addressed-player-changed") {
addressedPlayerText.innerText = JSON.stringify(message.params.player)
} else if (message.type == "uids-changed") {
uidCounterText.innerText = message.params.uid_counter
} else if (message.type == "get-play-status-response") {
getPlayStatusResponseTable.replaceChildren(message.params)
} else if (message.type == "get-element-attributes-response") {
getElementAttributesResponseTable.replaceChildren(createTable(message.params))
}
}
function updateVolume(newVolume) {
volume = newVolume
volumeText.innerText = `${volume} (${Math.round(100*volume/0x7F)}%)`
}
function onKeyDown(event) {
console.log(event)
send({ type: 'send-key-down', key: event.key })
}
function onKeyUp(event) {
console.log(event)
send({ type: 'send-key-up', key: event.key })
}
function onVolumeUpButtonClicked() {
updateVolume(Math.min(volume + 5, 0x7F))
send({ type: 'set-volume', volume })
}
function onVolumeDownButtonClicked() {
updateVolume(Math.max(volume - 5, 0))
send({ type: 'set-volume', volume })
}
function onGetPlayStatusButtonClicked() {
send({ type: 'get-play-status', volume })
}
function onGetElementAttributesButtonClicked() {
send({ type: 'get-element-attributes' })
}
</script>
</body>
</html>
+3 -9
View File
@@ -21,29 +21,23 @@ import os
import logging import logging
from bumble.colors import color from bumble.colors import color
from bumble.device import Device from bumble.device import Device
from bumble.hci import Address
from bumble.transport import open_transport from bumble.transport import open_transport
from bumble.profiles.battery_service import BatteryServiceProxy from bumble.profiles.battery_service import BatteryServiceProxy
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main():
if len(sys.argv) != 3: if len(sys.argv) != 3:
print('Usage: battery_client.py <transport-spec> <bluetooth-address>') print('Usage: battery_client.py <transport-spec> <bluetooth-address>')
print('example: battery_client.py usb:0 E1:CA:72:48:C4:E8') print('example: battery_client.py usb:0 E1:CA:72:48:C4:E8')
return return
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport(sys.argv[1]) as hci_transport: async with await open_transport(sys.argv[1]) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
# Create and start a device # Create and start a device
device = Device.with_hci( device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
'Bumble',
Address('F0:F1:F2:F3:F4:F5'),
hci_transport.source,
hci_transport.sink,
)
await device.power_on() await device.power_on()
# Connect to the peer # Connect to the peer
+3 -5
View File
@@ -29,16 +29,14 @@ from bumble.profiles.battery_service import BatteryService
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main():
if len(sys.argv) != 3: if len(sys.argv) != 3:
print('Usage: python battery_server.py <device-config> <transport-spec>') print('Usage: python battery_server.py <device-config> <transport-spec>')
print('example: python battery_server.py device1.json usb:0') print('example: python battery_server.py device1.json usb:0')
return return
async with await open_transport_or_link(sys.argv[2]) as hci_transport: async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
device = Device.from_config_file_with_hci( device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
sys.argv[1], hci_transport.source, hci_transport.sink
)
# Add a Battery Service to the GATT sever # Add a Battery Service to the GATT sever
battery_service = BatteryService(lambda _: random.randint(0, 100)) battery_service = BatteryService(lambda _: random.randint(0, 100))
+3 -9
View File
@@ -21,13 +21,12 @@ import os
import logging import logging
from bumble.colors import color from bumble.colors import color
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.hci import Address
from bumble.profiles.device_information_service import DeviceInformationServiceProxy from bumble.profiles.device_information_service import DeviceInformationServiceProxy
from bumble.transport import open_transport from bumble.transport import open_transport
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main():
if len(sys.argv) != 3: if len(sys.argv) != 3:
print( print(
'Usage: device_information_client.py <transport-spec> <bluetooth-address>' 'Usage: device_information_client.py <transport-spec> <bluetooth-address>'
@@ -36,16 +35,11 @@ async def main() -> None:
return return
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport(sys.argv[1]) as hci_transport: async with await open_transport(sys.argv[1]) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
# Create and start a device # Create and start a device
device = Device.with_hci( device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
'Bumble',
Address('F0:F1:F2:F3:F4:F5'),
hci_transport.source,
hci_transport.sink,
)
await device.power_on() await device.power_on()
# Connect to the peer # Connect to the peer
+4 -6
View File
@@ -28,16 +28,14 @@ from bumble.profiles.device_information_service import DeviceInformationService
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main():
if len(sys.argv) != 3: if len(sys.argv) != 3:
print('Usage: python device_info_server.py <device-config> <transport-spec>') print('Usage: python device_info_server.py <device-config> <transport-spec>')
print('example: python device_info_server.py device1.json usb:0') print('example: python device_info_server.py device1.json usb:0')
return return
async with await open_transport_or_link(sys.argv[2]) as hci_transport: async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
device = Device.from_config_file_with_hci( device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
sys.argv[1], hci_transport.source, hci_transport.sink
)
# Add a Device Information Service to the GATT sever # Add a Device Information Service to the GATT sever
device_information_service = DeviceInformationService( device_information_service = DeviceInformationService(
@@ -66,7 +64,7 @@ async def main() -> None:
# Go! # Go!
await device.power_on() await device.power_on()
await device.start_advertising(auto_restart=True) await device.start_advertising(auto_restart=True)
await hci_transport.source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+3 -9
View File
@@ -21,29 +21,23 @@ import os
import logging import logging
from bumble.colors import color from bumble.colors import color
from bumble.device import Device from bumble.device import Device
from bumble.hci import Address
from bumble.transport import open_transport from bumble.transport import open_transport
from bumble.profiles.heart_rate_service import HeartRateServiceProxy from bumble.profiles.heart_rate_service import HeartRateServiceProxy
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main():
if len(sys.argv) != 3: if len(sys.argv) != 3:
print('Usage: heart_rate_client.py <transport-spec> <bluetooth-address>') print('Usage: heart_rate_client.py <transport-spec> <bluetooth-address>')
print('example: heart_rate_client.py usb:0 E1:CA:72:48:C4:E8') print('example: heart_rate_client.py usb:0 E1:CA:72:48:C4:E8')
return return
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport(sys.argv[1]) as hci_transport: async with await open_transport(sys.argv[1]) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
# Create and start a device # Create and start a device
device = Device.with_hci( device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
'Bumble',
Address('F0:F1:F2:F3:F4:F5'),
hci_transport.source,
hci_transport.sink,
)
await device.power_on() await device.power_on()
# Connect to the peer # Connect to the peer
+3 -5
View File
@@ -33,16 +33,14 @@ from bumble.utils import AsyncRunner
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main():
if len(sys.argv) != 3: if len(sys.argv) != 3:
print('Usage: python heart_rate_server.py <device-config> <transport-spec>') print('Usage: python heart_rate_server.py <device-config> <transport-spec>')
print('example: python heart_rate_server.py device1.json usb:0') print('example: python heart_rate_server.py device1.json usb:0')
return return
async with await open_transport_or_link(sys.argv[2]) as hci_transport: async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
device = Device.from_config_file_with_hci( device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
sys.argv[1], hci_transport.source, hci_transport.sink
)
# Keep track of accumulated expended energy # Keep track of accumulated expended energy
energy_start_time = time.time() energy_start_time = time.time()
-350
View File
@@ -1,350 +0,0 @@
<html data-bs-theme="dark">
<head>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous">
<script src="https://unpkg.com/pcm-player"></script>
</head>
<body>
<nav class="navbar navbar-dark bg-primary">
<div class="container">
<span class="navbar-brand mb-0 h1">Bumble HFP Audio Gateway</span>
</div>
</nav>
<br>
<div class="container">
<label class="form-label">Send AT Response</label>
<div class="input-group mb-3">
<input type="text" class="form-control" placeholder="AT Response" aria-label="AT response" id="at_response">
<button class="btn btn-primary" type="button"
onclick="send_at_response(document.getElementById('at_response').value)">Send</button>
</div>
<div class="row">
<div class="col-3">
<label class="form-label">Speaker Volume</label>
<div class="input-group mb-3 col-auto">
<input type="text" class="form-control" placeholder="0 - 15" aria-label="Speaker Volume"
id="speaker_volume">
<button class="btn btn-primary" type="button"
onclick="send_at_response(`+VGS: ${document.getElementById('speaker_volume').value}`)">Set</button>
</div>
</div>
<div class="col-3">
<label class="form-label">Mic Volume</label>
<div class="input-group mb-3 col-auto">
<input type="text" class="form-control" placeholder="0 - 15" aria-label="Mic Volume"
id="mic_volume">
<button class="btn btn-primary" type="button"
onclick="send_at_response(`+VGM: ${document.getElementById('mic_volume').value}`)">Set</button>
</div>
</div>
<div class="col-3">
<label class="form-label">Browser Gain</label>
<input type="range" class="form-range" id="browser-gain" min="0" max="2" value="1" step="0.1" onchange="setGain()">
</div>
</div>
<div class="row">
<div class="col-auto">
<div class="input-group mb-3">
<span class="input-group-text">Codec</span>
<select class="form-select" id="codec">
<option selected value="1">CVSD</option>
<option value="2">MSBC</option>
</select>
</div>
</div>
<div class="col-auto">
<button class="btn btn-primary" onclick="negotiate_codec()">Negotiate Codec</button>
</div>
<div class="col-auto">
<button class="btn btn-primary" onclick="connect_sco()">Connect SCO</button>
</div>
<div class="col-auto">
<button class="btn btn-primary" onclick="disconnect_sco()">Disconnect SCO</button>
</div>
<div class="col-auto">
<button class="btn btn-danger" onclick="connectAudio()">Connect Audio</button>
</div>
</div>
<hr>
<div class="row">
<h4>AG Indicators</h2>
<div class="col-3">
<label class="form-label">call</label>
<div class="input-group mb-3 col-auto">
<select class="form-select" id="call">
<option selected value="0">Inactive</option>
<option value="1">Active</option>
</select>
<button class="btn btn-primary" type="button" onclick="update_ag_indicator('call')">Set</button>
</div>
</div>
<div class="col-3">
<label class="form-label">callsetup</label>
<div class="input-group mb-3 col-auto">
<select class="form-select" id="callsetup">
<option selected value="0">Idle</option>
<option value="1">Incoming</option>
<option value="2">Outgoing</option>
<option value="3">Remote Alerted</option>
</select>
<button class="btn btn-primary" type="button"
onclick="update_ag_indicator('callsetup')">Set</button>
</div>
</div>
<div class="col-3">
<label class="form-label">callheld</label>
<div class="input-group mb-3 col-auto">
<select class="form-select" id="callsetup">
<option selected value="0">0</option>
<option value="1">1</option>
<option value="2">2</option>
</select>
<button class="btn btn-primary" type="button"
onclick="update_ag_indicator('callheld')">Set</button>
</div>
</div>
<div class="col-3">
<label class="form-label">signal</label>
<div class="input-group mb-3 col-auto">
<select class="form-select" id="signal">
<option selected value="0">0</option>
<option value="1">1</option>
<option value="2">2</option>
<option value="3">3</option>
<option value="4">4</option>
<option value="5">5</option>
</select>
<button class="btn btn-primary" type="button"
onclick="update_ag_indicator('signal')">Set</button>
</div>
</div>
<div class="col-3">
<label class="form-label">roam</label>
<div class="input-group mb-3 col-auto">
<select class="form-select" id="roam">
<option selected value="0">0</option>
<option value="1">1</option>
</select>
<button class="btn btn-primary" type="button" onclick="update_ag_indicator('roam')">Set</button>
</div>
</div>
<div class="col-3">
<label class="form-label">battchg</label>
<div class="input-group mb-3 col-auto">
<select class="form-select" id="battchg">
<option selected value="0">0</option>
<option value="1">1</option>
<option value="2">2</option>
<option value="3">3</option>
<option value="4">4</option>
<option value="5">5</option>
</select>
<button class="btn btn-primary" type="button"
onclick="update_ag_indicator('battchg')">Set</button>
</div>
</div>
<div class="col-3">
<label class="form-label">service</label>
<div class="input-group mb-3 col-auto">
<select class="form-select" id="service">
<option selected value="0">0</option>
<option value="1">1</option>
</select>
<button class="btn btn-primary" type="button"
onclick="update_ag_indicator('service')">Set</button>
</div>
</div>
</div>
<hr>
<button class="btn btn-primary" onclick="send_at_response('+BVRA: 1')">Start Voice Assistant</button>
<button class="btn btn-primary" onclick="send_at_response('+BVRA: 0')">Stop Voice Assistant</button>
<hr>
<h4>Calls</h4>
<div id="call-lists">
<template id="call-template">
<div class="row call-row">
<div class="input-group mb-3">
<label class="input-group-text">Index</label>
<input class="form-control call-index" value="1">
<label class="input-group-text">Number</label>
<input class="form-control call-number">
<label class="input-group-text">Direction</label>
<select class="form-select call-direction">
<option selected value="0">Originated</option>
<option value="1">Terminated</option>
</select>
<label class="input-group-text">Status</label>
<select class="form-select call-status">
<option value="0">ACTIVE</option>
<option value="1">HELD</option>
<option value="2">DIALING</option>
<option value="3">ALERTING</option>
<option value="4">INCOMING</option>
<option value="5">WAITING</option>
</select>
<button class="btn btn-primary call-remover"></button>
</div>
</div>
</template>
</div>
<button class="btn btn-primary" onclick="add_call()"> Add Call</button>
<button class="btn btn-primary" onclick="update_calls()">🗘 Update Calls</button>
<hr>
<div id="socketStateContainer" class="bg-body-tertiary p-3 rounded-2">
<h3>Log</h3>
<code id="log" style="white-space: pre-line;"></code>
</div>
</div>
<script>
let atResponseInput = document.getElementById("at_response")
let gainInput = document.getElementById('browser-gain')
let log = document.getElementById("log")
let socket = new WebSocket('ws://localhost:8888');
let sampleRate = 0;
let player;
socket.binaryType = "arraybuffer";
socket.onopen = _ => {
log.textContent += 'SOCKET OPEN\n'
}
socket.onclose = _ => {
log.textContent += 'SOCKET CLOSED\n'
}
socket.onerror = (error) => {
log.textContent += 'SOCKET ERROR\n'
console.log(`ERROR: ${error}`)
}
socket.onmessage = function (message) {
if (typeof message.data === 'string' || message.data instanceof String) {
log.textContent += `<-- ${event.data}\n`
const jsonMessage = JSON.parse(event.data)
if (jsonMessage.type == 'speaker_volume') {
document.getElementById('speaker_volume').value = jsonMessage.level;
} else if (jsonMessage.type == 'microphone_volume') {
document.getElementById('microphone_volume').value = jsonMessage.level;
} else if (jsonMessage.type == 'sco_state_change') {
sampleRate = jsonMessage.sample_rate;
console.log(sampleRate);
if (player != null) {
player = new PCMPlayer({
inputCodec: 'Int16',
channels: 1,
sampleRate: sampleRate,
flushTime: 7.5,
});
player.volume(gainInput.value);
}
}
} else {
// BINARY audio data.
if (player == null) return;
player.feed(message.data);
}
};
function send(message) {
if (socket && socket.readyState == WebSocket.OPEN) {
let jsonMessage = JSON.stringify(message)
log.textContent += `--> ${jsonMessage}\n`
socket.send(jsonMessage)
} else {
log.textContent += 'NOT CONNECTED\n'
}
}
function send_at_response(response) {
send({ type: 'at_response', response: response })
}
function update_ag_indicator(indicator) {
const value = document.getElementById(indicator).value
send({ type: 'ag_indicator', indicator: indicator, value: value })
}
function connect_sco() {
send({ type: 'connect_sco' })
}
function negotiate_codec() {
const codec = document.getElementById('codec').value
send({ type: 'negotiate_codec', codec: codec })
}
function disconnect_sco() {
send({ type: 'disconnect_sco' })
}
function add_call() {
let callLists = document.getElementById('call-lists');
let template = document.getElementById('call-template');
let newNode = document.importNode(template.content, true);
newNode.querySelector('.call-remover').onclick = function (event) {
event.target.closest('.call-row').remove();
}
callLists.appendChild(newNode);
}
function update_calls() {
let callLists = document.getElementById('call-lists');
send({
type: 'update_calls',
calls: Array.from(
callLists.querySelectorAll('.call-row')).map(
function (element) {
return {
index: element.querySelector('.call-index').value,
number: element.querySelector('.call-number').value,
direction: element.querySelector('.call-direction').value,
status: element.querySelector('.call-status').value,
}
}
),
}
)
}
function connectAudio() {
player = new PCMPlayer({
inputCodec: 'Int16',
channels: 1,
sampleRate: sampleRate,
flushTime: 7.5,
});
player.volume(gainInput.value);
}
function setGain() {
if (player != null) {
player.volume(gainInput.value);
}
}
</script>
</div>
</body>
</html>
+1 -2
View File
@@ -1,5 +1,4 @@
{ {
"name": "Bumble Phone", "name": "Bumble Phone",
"class_of_device": 6291980, "class_of_device": 6291980
"keystore": "JsonKeyStore"
} }
+54 -107
View File
@@ -1,132 +1,79 @@
<html data-bs-theme="dark"> <html>
<head>
<style>
* {
font-family: sans-serif;
}
<head> label {
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet" display: block;
integrity="sha384-T3c6CoIi6uLrA9TneNEoa7RxnatzjcDSCmG1MXxSR1GAsXEV/Dwwykc2MPK8M2HN" crossorigin="anonymous"> }
</head>
<body>
<nav class="navbar navbar-dark bg-primary">
<div class="container">
<span class="navbar-brand mb-0 h1">Bumble Handsfree</span>
</div>
</nav>
<br>
<div class="container">
<label class="form-label">Server Port</label>
<div class="input-group mb-3">
<input type="text" class="form-control" aria-label="Port Number" value="8989" id="port">
<button class="btn btn-primary" type="button" onclick="connect()">Connect</button>
</div>
<label class="form-label">Dial Phone Number</label>
<div class="input-group mb-3">
<input type="text" class="form-control" placeholder="Phone Number" aria-label="Phone Number"
id="dial_number">
<button class="btn btn-primary" type="button"
onclick="send_at_command(`ATD${dialNumberInput.value}`)">Dial</button>
</div>
<label class="form-label">Send AT Command</label>
<div class="input-group mb-3">
<input type="text" class="form-control" placeholder="AT Command" aria-label="AT command" id="at_command">
<button class="btn btn-primary" type="button"
onclick="send_at_command(document.getElementById('at_command').value)">Send</button>
</div>
<div class="row">
<div class="col-auto">
<label class="form-label">Battery Level</label>
<div class="input-group mb-3">
<input type="text" class="form-control" placeholder="0 - 100" aria-label="Battery Level"
id="battery_level">
<button class="btn btn-primary" type="button"
onclick="send_at_command(`AT+BIEV=2,${document.getElementById('battery_level').value}`)">Set</button>
</div>
</div>
<div class="col-auto">
<label class="form-label">Speaker Volume</label>
<div class="input-group mb-3 col-auto">
<input type="text" class="form-control" placeholder="0 - 15" aria-label="Speaker Volume"
id="speaker_volume">
<button class="btn btn-primary" type="button"
onclick="send_at_command(`AT+VGS=${document.getElementById('speaker_volume').value}`)">Set</button>
</div>
</div>
<div class="col-auto">
<label class="form-label">Mic Volume</label>
<div class="input-group mb-3 col-auto">
<input type="text" class="form-control" placeholder="0 - 15" aria-label="Mic Volume"
id="mic_volume">
<button class="btn btn-primary" type="button"
onclick="send_at_command(`AT+VGM=${document.getElementById('mic_volume').value}`)">Set</button>
</div>
</div>
</div>
<button class="btn btn-primary" onclick="send_at_command('ATA')">Answer</button>
<button class="btn btn-primary" onclick="send_at_command('AT+CHUP')">Hang Up</button>
<button class="btn btn-primary" onclick="send_at_command('AT+BLDN')">Redial</button>
<button class="btn btn-primary" onclick="send({ type: 'query_call'})">Get Call Status</button>
<br><br>
<button class="btn btn-primary" onclick="send_at_command('AT+BVRA=1')">Start Voice Assistant</button>
<button class="btn btn-primary" onclick="send_at_command('AT+BVRA=0')">Stop Voice Assistant</button>
input, label {
margin: .4rem 0;
}
</style>
</head>
<body>
Server Port <input id="port" type="text" value="8989"></input> <button onclick="connect()">Connect</button><br>
AT Command <input type="text" id="at_command" required size="10"> <button onclick="send_at_command()">Send</button><br>
Dial Phone Number <input type="text" id="dial_number" required size="10"> <button onclick="dial()">Dial</button><br>
<button onclick="answer()">Answer</button>
<button onclick="hangup()">Hang Up</button>
<button onclick="start_voice_assistant()">Start Voice Assistant</button>
<button onclick="stop_voice_assistant()">Stop Voice Assistant</button>
<hr> <hr>
<div id="socketState"></div>
<div id="socketStateContainer" class="bg-body-tertiary p-3 rounded-2"> <script>
<h3>Log</h3>
<code id="log" style="white-space: pre-line;"></code>
</div>
</div>
<script>
let portInput = document.getElementById("port") let portInput = document.getElementById("port")
let atCommandInput = document.getElementById("at_command") let atCommandInput = document.getElementById("at_command")
let log = document.getElementById("log") let dialNumberInput = document.getElementById("dial_number")
let socketState = document.getElementById("socketState")
let socket let socket
function connect() { function connect() {
socket = new WebSocket(`ws://localhost:${portInput.value}`); socket = new WebSocket(`ws://localhost:${portInput.value}`);
socket.onopen = _ => { socket.onopen = _ => {
log.textContent += 'OPEN\n' socketState.innerText = 'OPEN'
} }
socket.onclose = _ => { socket.onclose = _ => {
log.textContent += 'CLOSED\n' socketState.innerText = 'CLOSED'
} }
socket.onerror = (error) => { socket.onerror = (error) => {
log.textContent += 'ERROR\n' socketState.innerText = 'ERROR'
console.log(`ERROR: ${error}`) console.log(`ERROR: ${error}`)
} }
socket.onmessage = (event) => {
log.textContent += `<-- ${event.data}\n`
let volume_state = JSON.parse(event.data)
volumeSetting.value = volume_state.volume_setting
changeCounter.value = volume_state.change_counter
muted.checked = volume_state.muted ? true : false
}
} }
function send(message) { function send(message) {
if (socket && socket.readyState == WebSocket.OPEN) { if (socket && socket.readyState == WebSocket.OPEN) {
let jsonMessage = JSON.stringify(message) socket.send(JSON.stringify(message))
log.textContent += `--> ${jsonMessage}\n`
socket.send(jsonMessage)
} else {
log.textContent += 'NOT CONNECTED\n'
} }
} }
function send_at_command(command) { function send_at_command() {
send({ type: 'at_command', 'command': command }) send({ type:'at_command', command: atCommandInput.value })
} }
</script>
</div>
</body>
</html> function answer() {
send({ type:'at_command', command: 'ATA' })
}
function hangup() {
send({ type:'at_command', command: 'AT+CHUP' })
}
function dial() {
send({ type:'at_command', command: `ATD${dialNumberInput.value}` })
}
function start_voice_assistant() {
send(({ type:'at_command', command: 'AT+BVRA=1' }))
}
function stop_voice_assistant() {
send(({ type:'at_command', command: 'AT+BVRA=0' }))
}
</script>
</body>
</html>
-5
View File
@@ -1,5 +0,0 @@
{
"name": "Bumble HID Keyboard",
"class_of_device": 9664,
"keystore": "JsonKeyStore"
}
+3 -3
View File
@@ -40,9 +40,9 @@
} }
} }
function onMouseMove(event) { function onMouseMove(event) {
//console.log(event.movementX, event.movementY) //console.log(event.clientX, event.clientY)
mouseInfo.innerText = `MOUSE: x=${event.movementX}, y=${event.movementY}` mouseInfo.innerText = `MOUSE: x=${event.clientX}, y=${event.clientY}`
send({ type:'mousemove', x: event.movementX, y: event.movementY }) send({ type:'mousemove', x: event.clientX, y: event.clientY })
} }
function onKeyDown(event) { function onKeyDown(event) {
+3 -5
View File
@@ -416,7 +416,7 @@ async def keyboard_device(device, command):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main():
if len(sys.argv) < 4: if len(sys.argv) < 4:
print( print(
'Usage: python keyboard.py <device-config> <transport-spec> <command>' 'Usage: python keyboard.py <device-config> <transport-spec> <command>'
@@ -434,11 +434,9 @@ async def main() -> None:
) )
return return
async with await open_transport_or_link(sys.argv[2]) as hci_transport: async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
# Create a device to manage the host # Create a device to manage the host
device = Device.from_config_file_with_hci( device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
sys.argv[1], hci_transport.source, hci_transport.sink
)
command = sys.argv[3] command = sys.argv[3]
if command == 'connect': if command == 'connect':
-7
View File
@@ -1,7 +0,0 @@
{
"name": "Bumble-LEA",
"keystore": "JsonKeyStore",
"address": "F0:F1:F2:F3:F4:FA",
"class_of_device": 2376708,
"advertising_interval": 100
}
-9
View File
@@ -1,9 +0,0 @@
{
"name": "Bumble-LEA",
"keystore": "JsonKeyStore",
"address": "F0:F1:F2:F3:F4:FA",
"classic_enabled": true,
"cis_enabled": true,
"class_of_device": 2376708,
"advertising_interval": 100
}
+8 -10
View File
@@ -53,10 +53,10 @@ def sdp_records():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# pylint: disable-next=too-many-nested-blocks # pylint: disable-next=too-many-nested-blocks
async def find_a2dp_service(connection): async def find_a2dp_service(device, connection):
# Connect to the SDP Server # Connect to the SDP Server
sdp_client = SDP_Client(connection) sdp_client = SDP_Client(device)
await sdp_client.connect() await sdp_client.connect(connection)
# Search for services with an Audio Sink service class # Search for services with an Audio Sink service class
search_result = await sdp_client.search_attributes( search_result = await sdp_client.search_attributes(
@@ -139,20 +139,18 @@ async def find_a2dp_service(connection):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main():
if len(sys.argv) < 4: if len(sys.argv) < 4:
print('Usage: run_a2dp_info.py <device-config> <transport-spec> <bt-addr>') print('Usage: run_a2dp_info.py <device-config> <transport-spec> <bt-addr>')
print('example: run_a2dp_info.py classic1.json usb:0 14:7D:DA:4E:53:A8') print('example: run_a2dp_info.py classic1.json usb:0 14:7D:DA:4E:53:A8')
return return
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as hci_transport: async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
# Create a device # Create a device
device = Device.from_config_file_with_hci( device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
sys.argv[1], hci_transport.source, hci_transport.sink
)
device.classic_enabled = True device.classic_enabled = True
# Start the controller # Start the controller
@@ -179,7 +177,7 @@ async def main() -> None:
print('*** Encryption on') print('*** Encryption on')
# Look for an A2DP service # Look for an A2DP service
avdtp_version = await find_a2dp_service(connection) avdtp_version = await find_a2dp_service(device, connection)
if not avdtp_version: if not avdtp_version:
print(color('!!! no AVDTP service found')) print(color('!!! no AVDTP service found'))
return return
@@ -189,7 +187,7 @@ async def main() -> None:
client = await AVDTP_Protocol.connect(connection, avdtp_version) client = await AVDTP_Protocol.connect(connection, avdtp_version)
# Discover all endpoints on the remote device # Discover all endpoints on the remote device
endpoints = list(await client.discover_remote_endpoints()) endpoints = await client.discover_remote_endpoints()
print(f'@@@ Found {len(endpoints)} endpoints') print(f'@@@ Found {len(endpoints)} endpoints')
for endpoint in endpoints: for endpoint in endpoints:
print('@@@', endpoint) print('@@@', endpoint)
+5 -8
View File
@@ -19,7 +19,6 @@ import asyncio
import sys import sys
import os import os
import logging import logging
from typing import Any, Dict
from bumble.device import Device from bumble.device import Device
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
@@ -42,7 +41,7 @@ from bumble.a2dp import (
SbcMediaCodecInformation, SbcMediaCodecInformation,
) )
Context: Dict[Any, Any] = {'output': None} Context = {'output': None}
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -105,7 +104,7 @@ def on_rtp_packet(packet):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main():
if len(sys.argv) < 4: if len(sys.argv) < 4:
print( print(
'Usage: run_a2dp_sink.py <device-config> <transport-spec> <sbc-file> ' 'Usage: run_a2dp_sink.py <device-config> <transport-spec> <sbc-file> '
@@ -115,16 +114,14 @@ async def main() -> None:
return return
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as hci_transport: async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
with open(sys.argv[3], 'wb') as sbc_file: with open(sys.argv[3], 'wb') as sbc_file:
Context['output'] = sbc_file Context['output'] = sbc_file
# Create a device # Create a device
device = Device.from_config_file_with_hci( device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
sys.argv[1], hci_transport.source, hci_transport.sink
)
device.classic_enabled = True device.classic_enabled = True
# Setup the SDP to expose the sink service # Setup the SDP to expose the sink service
@@ -165,7 +162,7 @@ async def main() -> None:
await device.set_discoverable(True) await device.set_discoverable(True)
await device.set_connectable(True) await device.set_connectable(True)
await hci_transport.source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+9 -9
View File
@@ -74,7 +74,7 @@ def codec_capabilities():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def on_avdtp_connection(read_function, protocol): def on_avdtp_connection(read_function, protocol):
packet_source = SbcPacketSource( packet_source = SbcPacketSource(
read_function, protocol.l2cap_channel.peer_mtu, codec_capabilities() read_function, protocol.l2cap_channel.mtu, codec_capabilities()
) )
packet_pump = MediaPacketPump(packet_source.packets) packet_pump = MediaPacketPump(packet_source.packets)
protocol.add_source(packet_source.codec_capabilities, packet_pump) protocol.add_source(packet_source.codec_capabilities, packet_pump)
@@ -98,7 +98,7 @@ async def stream_packets(read_function, protocol):
# Stream the packets # Stream the packets
packet_source = SbcPacketSource( packet_source = SbcPacketSource(
read_function, protocol.l2cap_channel.peer_mtu, codec_capabilities() read_function, protocol.l2cap_channel.mtu, codec_capabilities()
) )
packet_pump = MediaPacketPump(packet_source.packets) packet_pump = MediaPacketPump(packet_source.packets)
source = protocol.add_source(packet_source.codec_capabilities, packet_pump) source = protocol.add_source(packet_source.codec_capabilities, packet_pump)
@@ -114,7 +114,7 @@ async def stream_packets(read_function, protocol):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main():
if len(sys.argv) < 4: if len(sys.argv) < 4:
print( print(
'Usage: run_a2dp_source.py <device-config> <transport-spec> <sbc-file> ' 'Usage: run_a2dp_source.py <device-config> <transport-spec> <sbc-file> '
@@ -126,13 +126,11 @@ async def main() -> None:
return return
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as hci_transport: async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
# Create a device # Create a device
device = Device.from_config_file_with_hci( device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
sys.argv[1], hci_transport.source, hci_transport.sink
)
device.classic_enabled = True device.classic_enabled = True
# Setup the SDP to expose the SRC service # Setup the SDP to expose the SRC service
@@ -167,7 +165,9 @@ async def main() -> None:
print('*** Encryption on') print('*** Encryption on')
# Look for an A2DP service # Look for an A2DP service
avdtp_version = await find_avdtp_service_with_connection(connection) avdtp_version = await find_avdtp_service_with_connection(
device, connection
)
if not avdtp_version: if not avdtp_version:
print(color('!!! no A2DP service found')) print(color('!!! no A2DP service found'))
return return
@@ -188,7 +188,7 @@ async def main() -> None:
await device.set_discoverable(True) await device.set_discoverable(True)
await device.set_connectable(True) await device.set_connectable(True)
await hci_transport.source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+5 -19
View File
@@ -19,16 +19,14 @@ import asyncio
import logging import logging
import sys import sys
import os import os
import struct
from bumble.core import AdvertisingData
from bumble.device import AdvertisingType, Device from bumble.device import AdvertisingType, Device
from bumble.hci import Address from bumble.hci import Address
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main():
if len(sys.argv) < 3: if len(sys.argv) < 3:
print( print(
'Usage: run_advertiser.py <config-file> <transport-spec> [type] [address]' 'Usage: run_advertiser.py <config-file> <transport-spec> [type] [address]'
@@ -50,25 +48,13 @@ async def main() -> None:
target = None target = None
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as hci_transport: async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
print('<<< connected') print('<<< connected')
device = Device.from_config_file_with_hci( device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
sys.argv[1], hci_transport.source, hci_transport.sink
)
if advertising_type.is_scannable:
device.scan_response_data = bytes(
AdvertisingData(
[
(AdvertisingData.APPEARANCE, struct.pack('<H', 0x0340)),
]
)
)
await device.power_on() await device.power_on()
await device.start_advertising(advertising_type=advertising_type, target=target) await device.start_advertising(advertising_type=advertising_type, target=target)
await hci_transport.source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+4 -6
View File
@@ -49,7 +49,7 @@ ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID(
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def main() -> None: async def main():
if len(sys.argv) != 4: if len(sys.argv) != 4:
print( print(
'Usage: python run_asha_sink.py <device-config> <transport-spec> ' 'Usage: python run_asha_sink.py <device-config> <transport-spec> '
@@ -60,10 +60,8 @@ async def main() -> None:
audio_out = open(sys.argv[3], 'wb') audio_out = open(sys.argv[3], 'wb')
async with await open_transport_or_link(sys.argv[2]) as hci_transport: async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
device = Device.from_config_file_with_hci( device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
sys.argv[1], hci_transport.source, hci_transport.sink
)
# Handler for audio control commands # Handler for audio control commands
def on_audio_control_point_write(_connection, value): def on_audio_control_point_write(_connection, value):
@@ -199,7 +197,7 @@ async def main() -> None:
await device.power_on() await device.power_on()
await device.start_advertising(auto_restart=True) await device.start_advertising(auto_restart=True)
await hci_transport.source.wait_for_termination() await hci_source.wait_for_termination()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
-410
View File
@@ -1,410 +0,0 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import json
import sys
import os
import logging
import websockets
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.core import BT_BR_EDR_TRANSPORT
from bumble import avc
from bumble import avrcp
from bumble import avdtp
from bumble import a2dp
from bumble import utils
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def sdp_records():
a2dp_sink_service_record_handle = 0x00010001
avrcp_controller_service_record_handle = 0x00010002
avrcp_target_service_record_handle = 0x00010003
# pylint: disable=line-too-long
return {
a2dp_sink_service_record_handle: a2dp.make_audio_sink_service_sdp_records(
a2dp_sink_service_record_handle
),
avrcp_controller_service_record_handle: avrcp.make_controller_service_sdp_records(
avrcp_controller_service_record_handle
),
avrcp_target_service_record_handle: avrcp.make_target_service_sdp_records(
avrcp_controller_service_record_handle
),
}
# -----------------------------------------------------------------------------
def codec_capabilities():
return avdtp.MediaCodecCapabilities(
media_type=avdtp.AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=a2dp.A2DP_SBC_CODEC_TYPE,
media_codec_information=a2dp.SbcMediaCodecInformation.from_lists(
sampling_frequencies=[48000, 44100, 32000, 16000],
channel_modes=[
a2dp.SBC_MONO_CHANNEL_MODE,
a2dp.SBC_DUAL_CHANNEL_MODE,
a2dp.SBC_STEREO_CHANNEL_MODE,
a2dp.SBC_JOINT_STEREO_CHANNEL_MODE,
],
block_lengths=[4, 8, 12, 16],
subbands=[4, 8],
allocation_methods=[
a2dp.SBC_LOUDNESS_ALLOCATION_METHOD,
a2dp.SBC_SNR_ALLOCATION_METHOD,
],
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),
)
# -----------------------------------------------------------------------------
def on_avdtp_connection(server):
# Add a sink endpoint to the server
sink = server.add_sink(codec_capabilities())
sink.on('rtp_packet', on_rtp_packet)
# -----------------------------------------------------------------------------
def on_rtp_packet(packet):
print(f'RTP: {packet}')
# -----------------------------------------------------------------------------
def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketServer):
async def get_supported_events():
events = await avrcp_protocol.get_supported_events()
print("SUPPORTED EVENTS:", events)
websocket_server.send_message(
{
"type": "supported-events",
"params": {"events": [event.name for event in events]},
}
)
if avrcp.EventId.TRACK_CHANGED in events:
utils.AsyncRunner.spawn(monitor_track_changed())
if avrcp.EventId.PLAYBACK_STATUS_CHANGED in events:
utils.AsyncRunner.spawn(monitor_playback_status())
if avrcp.EventId.PLAYBACK_POS_CHANGED in events:
utils.AsyncRunner.spawn(monitor_playback_position())
if avrcp.EventId.PLAYER_APPLICATION_SETTING_CHANGED in events:
utils.AsyncRunner.spawn(monitor_player_application_settings())
if avrcp.EventId.AVAILABLE_PLAYERS_CHANGED in events:
utils.AsyncRunner.spawn(monitor_available_players())
if avrcp.EventId.ADDRESSED_PLAYER_CHANGED in events:
utils.AsyncRunner.spawn(monitor_addressed_player())
if avrcp.EventId.UIDS_CHANGED in events:
utils.AsyncRunner.spawn(monitor_uids())
if avrcp.EventId.VOLUME_CHANGED in events:
utils.AsyncRunner.spawn(monitor_volume())
utils.AsyncRunner.spawn(get_supported_events())
async def monitor_track_changed():
async for identifier in avrcp_protocol.monitor_track_changed():
print("TRACK CHANGED:", identifier.hex())
websocket_server.send_message(
{"type": "track-changed", "params": {"identifier": identifier.hex()}}
)
async def monitor_playback_status():
async for playback_status in avrcp_protocol.monitor_playback_status():
print("PLAYBACK STATUS CHANGED:", playback_status.name)
websocket_server.send_message(
{
"type": "playback-status-changed",
"params": {"status": playback_status.name},
}
)
async def monitor_playback_position():
async for playback_position in avrcp_protocol.monitor_playback_position(
playback_interval=1
):
print("PLAYBACK POSITION CHANGED:", playback_position)
websocket_server.send_message(
{
"type": "playback-position-changed",
"params": {"position": playback_position},
}
)
async def monitor_player_application_settings():
async for settings in avrcp_protocol.monitor_player_application_settings():
print("PLAYER APPLICATION SETTINGS:", settings)
settings_as_dict = [
{"attribute": setting.attribute_id.name, "value": setting.value_id.name}
for setting in settings
]
websocket_server.send_message(
{
"type": "player-settings-changed",
"params": {"settings": settings_as_dict},
}
)
async def monitor_available_players():
async for _ in avrcp_protocol.monitor_available_players():
print("AVAILABLE PLAYERS CHANGED")
websocket_server.send_message(
{"type": "available-players-changed", "params": {}}
)
async def monitor_addressed_player():
async for player in avrcp_protocol.monitor_addressed_player():
print("ADDRESSED PLAYER CHANGED")
websocket_server.send_message(
{
"type": "addressed-player-changed",
"params": {
"player": {
"player_id": player.player_id,
"uid_counter": player.uid_counter,
}
},
}
)
async def monitor_uids():
async for uid_counter in avrcp_protocol.monitor_uids():
print("UIDS CHANGED")
websocket_server.send_message(
{
"type": "uids-changed",
"params": {
"uid_counter": uid_counter,
},
}
)
async def monitor_volume():
async for volume in avrcp_protocol.monitor_volume():
print("VOLUME CHANGED:", volume)
websocket_server.send_message(
{"type": "volume-changed", "params": {"volume": volume}}
)
# -----------------------------------------------------------------------------
class WebSocketServer:
def __init__(
self, avrcp_protocol: avrcp.Protocol, avrcp_delegate: Delegate
) -> None:
self.socket = None
self.delegate = None
self.avrcp_protocol = avrcp_protocol
self.avrcp_delegate = avrcp_delegate
async def start(self) -> None:
# pylint: disable-next=no-member
await websockets.serve(self.serve, 'localhost', 8989) # type: ignore
async def serve(self, socket, _path) -> None:
print('### WebSocket connected')
self.socket = socket
while True:
try:
message = await socket.recv()
print('Received: ', str(message))
parsed = json.loads(message)
message_type = parsed['type']
if message_type == 'send-key-down':
await self.on_send_key_down(parsed)
elif message_type == 'send-key-up':
await self.on_send_key_up(parsed)
elif message_type == 'set-volume':
await self.on_set_volume(parsed)
elif message_type == 'get-play-status':
await self.on_get_play_status()
elif message_type == 'get-element-attributes':
await self.on_get_element_attributes()
except websockets.exceptions.ConnectionClosedOK:
self.socket = None
break
async def on_send_key_down(self, message: dict) -> None:
key = avc.PassThroughFrame.OperationId[message["key"]]
await self.avrcp_protocol.send_key_event(key, True)
async def on_send_key_up(self, message: dict) -> None:
key = avc.PassThroughFrame.OperationId[message["key"]]
await self.avrcp_protocol.send_key_event(key, False)
async def on_set_volume(self, message: dict) -> None:
volume = message["volume"]
self.avrcp_delegate.volume = volume
self.avrcp_protocol.notify_volume_changed(volume)
async def on_get_play_status(self) -> None:
play_status = await self.avrcp_protocol.get_play_status()
self.send_message(
{
"type": "get-play-status-response",
"params": {
"song_length": play_status.song_length,
"song_position": play_status.song_position,
"play_status": play_status.play_status.name,
},
}
)
async def on_get_element_attributes(self) -> None:
attributes = await self.avrcp_protocol.get_element_attributes(
0,
[
avrcp.MediaAttributeId.TITLE,
avrcp.MediaAttributeId.ARTIST_NAME,
avrcp.MediaAttributeId.ALBUM_NAME,
avrcp.MediaAttributeId.TRACK_NUMBER,
avrcp.MediaAttributeId.TOTAL_NUMBER_OF_TRACKS,
avrcp.MediaAttributeId.GENRE,
avrcp.MediaAttributeId.PLAYING_TIME,
avrcp.MediaAttributeId.DEFAULT_COVER_ART,
],
)
self.send_message(
{
"type": "get-element-attributes-response",
"params": [
{
"attribute_id": attribute.attribute_id.name,
"attribute_value": attribute.attribute_value,
}
for attribute in attributes
],
}
)
def send_message(self, message: dict) -> None:
if self.socket is None:
print("no socket, dropping message")
return
serialized = json.dumps(message)
utils.AsyncRunner.spawn(self.socket.send(serialized))
# -----------------------------------------------------------------------------
class Delegate(avrcp.Delegate):
def __init__(self):
super().__init__(
[avrcp.EventId.VOLUME_CHANGED, avrcp.EventId.PLAYBACK_STATUS_CHANGED]
)
self.websocket_server = None
async def set_absolute_volume(self, volume: int) -> None:
await super().set_absolute_volume(volume)
if self.websocket_server is not None:
self.websocket_server.send_message(
{"type": "set-volume", "params": {"volume": volume}}
)
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) < 3:
print(
'Usage: run_avrcp_controller.py <device-config> <transport-spec> '
'<sbc-file> [<bt-addr>]'
)
print('example: run_avrcp_controller.py classic1.json usb:0')
return
print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as hci_transport:
print('<<< connected')
# Create a device
device = Device.from_config_file_with_hci(
sys.argv[1], hci_transport.source, hci_transport.sink
)
device.classic_enabled = True
# Setup the SDP to expose the sink service
device.sdp_service_records = sdp_records()
# Start the controller
await device.power_on()
# Create a listener to wait for AVDTP connections
listener = avdtp.Listener(avdtp.Listener.create_registrar(device))
listener.on('connection', on_avdtp_connection)
avrcp_delegate = Delegate()
avrcp_protocol = avrcp.Protocol(avrcp_delegate)
avrcp_protocol.listen(device)
websocket_server = WebSocketServer(avrcp_protocol, avrcp_delegate)
avrcp_delegate.websocket_server = websocket_server
avrcp_protocol.on(
"start", lambda: on_avrcp_start(avrcp_protocol, websocket_server)
)
await websocket_server.start()
if len(sys.argv) >= 5:
# Connect to the peer
target_address = sys.argv[4]
print(f'=== Connecting to {target_address}...')
connection = await device.connect(
target_address, transport=BT_BR_EDR_TRANSPORT
)
print(f'=== Connected to {connection.peer_address}!')
# Request authentication
print('*** Authenticating...')
await connection.authenticate()
print('*** Authenticated')
# Enable encryption
print('*** Enabling encryption...')
await connection.encrypt()
print('*** Encryption on')
server = await avdtp.Protocol.connect(connection)
listener.set_server(connection, server)
sink = server.add_sink(codec_capabilities())
sink.on('rtp_packet', on_rtp_packet)
await avrcp_protocol.connect(connection)
else:
# Start being discoverable and connectable
await device.set_discoverable(True)
await device.set_connectable(True)
await asyncio.get_event_loop().create_future()
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

Some files were not shown because too many files have changed in this diff Show More