Compare commits

..

1 Commits

Author SHA1 Message Date
uael
a7c87e7ad2 asha: import ASHA Pandora service from AOSP 2023-10-03 19:42:18 -07:00
305 changed files with 3882 additions and 29473 deletions

View File

@@ -29,7 +29,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip

View File

@@ -1,43 +0,0 @@
name: Python Avatar
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
permissions:
contents: read
jobs:
test:
name: Avatar [${{ matrix.shard }}]
runs-on: ubuntu-latest
strategy:
matrix:
shard: [
1/24, 2/24, 3/24, 4/24,
5/24, 6/24, 7/24, 8/24,
9/24, 10/24, 11/24, 12/24,
13/24, 14/24, 15/24, 16/24,
17/24, 18/24, 19/24, 20/24,
21/24, 22/24, 23/24, 24/24,
]
steps:
- uses: actions/checkout@v3
- name: Set Up Python 3.11
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Install
run: |
python -m pip install --upgrade pip
python -m pip install .[avatar]
- name: Rootcanal
run: nohup python -m rootcanal > rootcanal.log &
- name: Test
run: |
avatar --list | grep -Ev '^=' > test-names.txt
timeout 5m avatar --test-beds bumble.bumbles --tests $(split test-names.txt -n l/${{ matrix.shard }})
- name: Rootcanal Logs
run: cat rootcanal.log

View File

@@ -47,7 +47,7 @@ jobs:
strategy:
matrix:
python-version: [ "3.8", "3.9", "3.10", "3.11" ]
rust-version: [ "1.76.0", "stable" ]
rust-version: [ "1.70.0", "stable" ]
fail-fast: false
steps:
- name: Check out from Git
@@ -56,7 +56,7 @@ jobs:
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install Python dependencies
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[build,test,development,documentation]"
@@ -65,17 +65,15 @@ jobs:
with:
components: clippy,rustfmt
toolchain: ${{ matrix.rust-version }}
- name: Install Rust dependencies
run: cargo install cargo-all-features # allows building/testing combinations of features
- name: Check License Headers
run: cd rust && cargo run --features dev-tools --bin file-header check-all
- name: Rust Build
run: cd rust && cargo build --all-targets && cargo build-all-features --all-targets
run: cd rust && cargo build --all-targets && cargo build --all-features --all-targets
# Lints after build so what clippy needs is already built
- name: Rust Lints
run: cd rust && cargo fmt --check && cargo clippy --all-targets -- --deny warnings && cargo clippy --all-features --all-targets -- --deny warnings
- name: Rust Tests
run: cd rust && cargo test-all-features
run: cd rust && cargo test
# At some point, hook up publishing the binary. For now, just make sure it builds.
# Once we're ready to publish binaries, this should be built with `--release`.
- name: Build Bumble CLI

5
.gitignore vendored
View File

@@ -9,9 +9,4 @@ __pycache__
# generated by setuptools_scm
bumble/_version.py
.vscode/launch.json
.vscode/settings.json
/.idea
venv/
.venv/
# snoop logs
out/

13
.vscode/settings.json vendored
View File

@@ -12,9 +12,7 @@
"ASHA",
"asyncio",
"ATRAC",
"avctp",
"avdtp",
"avrcp",
"bitpool",
"bitstruct",
"BSCP",
@@ -23,10 +21,7 @@
"cccds",
"cmac",
"CONNECTIONLESS",
"csip",
"csis",
"csrcs",
"CVSD",
"datagram",
"DATALINK",
"delayreport",
@@ -34,8 +29,6 @@
"deregistration",
"dhkey",
"diversifier",
"endianness",
"ESCO",
"Fitbit",
"GATTLINK",
"HANDSFREE",
@@ -45,18 +38,15 @@
"libc",
"libusb",
"MITM",
"MSBC",
"NDIS",
"netsim",
"NONBLOCK",
"NONCONN",
"OXIMETER",
"popleft",
"PRAND",
"protobuf",
"psms",
"pyee",
"Pyodide",
"pyusb",
"rfcomm",
"ROHC",
@@ -64,7 +54,6 @@
"SEID",
"seids",
"SERV",
"SIRK",
"ssrc",
"strerror",
"subband",
@@ -74,8 +63,6 @@
"substates",
"tobytes",
"tsep",
"UNMUTE",
"unmuted",
"usbmodem",
"vhci",
"websockets",

File diff suppressed because it is too large Load Diff

View File

@@ -1,63 +0,0 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import click
from bumble.colors import color
from bumble.hci import Address
from bumble.helpers import generate_irk, verify_rpa_with_irk
@click.group()
def cli():
'''
This is a tool for generating IRK, RPA,
and verifying IRK/RPA pairs
'''
@click.command()
def gen_irk() -> None:
print(generate_irk().hex())
@click.command()
@click.argument("irk", type=str)
def gen_rpa(irk: str) -> None:
irk_bytes = bytes.fromhex(irk)
rpa = Address.generate_private_address(irk_bytes)
print(rpa.to_string(with_type_qualifier=False))
@click.command()
@click.argument("irk", type=str)
@click.argument("rpa", type=str)
def verify_rpa(irk: str, rpa: str) -> None:
address = Address(rpa)
irk_bytes = bytes.fromhex(irk)
if verify_rpa_with_irk(address, irk_bytes):
print(color("Verified", "green"))
else:
print(color("Not Verified", "red"))
def main():
cli.add_command(gen_irk)
cli.add_command(gen_rpa)
cli.add_command(verify_rpa)
cli()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()

View File

@@ -777,7 +777,7 @@ class ConsoleApp:
if not service:
continue
values = [
await attribute.read_value(connection)
attribute.read_value(connection)
for connection in self.device.connections.values()
]
if not values:
@@ -796,11 +796,11 @@ class ConsoleApp:
if not characteristic:
continue
values = [
await attribute.read_value(connection)
attribute.read_value(connection)
for connection in self.device.connections.values()
]
if not values:
values = [await attribute.read_value(None)]
values = [attribute.read_value(None)]
# TODO: future optimization: convert CCCD value to human readable string
@@ -944,7 +944,7 @@ class ConsoleApp:
# send data to any subscribers
if isinstance(attribute, Characteristic):
await attribute.write_value(None, value)
attribute.write_value(None, value)
if attribute.has_properties(Characteristic.NOTIFY):
await self.device.gatt_server.notify_subscribers(attribute)
if attribute.has_properties(Characteristic.INDICATE):

View File

@@ -18,39 +18,30 @@
import asyncio
import os
import logging
import time
import click
from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.colors import color
from bumble.core import name_or_number
from bumble.hci import (
map_null_terminated_utf8_string,
LeFeatureMask,
HCI_SUCCESS,
HCI_LE_SUPPORTED_FEATURES_NAMES,
HCI_VERSION_NAMES,
LMP_VERSION_NAMES,
HCI_Command,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_Read_Buffer_Size_Command,
HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_Read_Local_Name_Command,
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
HCI_LE_Read_Number_Of_Supported_Advertising_Sets_Command,
HCI_LE_READ_MAXIMUM_ADVERTISING_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_Read_Local_Version_Information_Command,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
@@ -66,7 +57,7 @@ def command_succeeded(response):
# -----------------------------------------------------------------------------
async def get_classic_info(host: Host) -> None:
async def get_classic_info(host):
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response):
@@ -87,7 +78,7 @@ async def get_classic_info(host: Host) -> None:
# -----------------------------------------------------------------------------
async def get_le_info(host: Host) -> None:
async def get_le_info(host):
print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
@@ -126,50 +117,13 @@ async def get_le_info(host: Host) -> None:
'\n',
)
if host.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND):
response = await host.send_command(
HCI_LE_Read_Suggested_Default_Data_Length_Command()
)
if command_succeeded(response):
print(
color('Suggested Default Data Length:', 'yellow'),
f'{response.return_parameters.suggested_max_tx_octets}/'
f'{response.return_parameters.suggested_max_tx_time}',
'\n',
)
print(color('LE Features:', 'yellow'))
for feature in host.supported_le_features:
print(LeFeatureMask(feature).name)
print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
# -----------------------------------------------------------------------------
async def get_acl_flow_control_info(host: Host) -> None:
print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
print(
color('ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_acl_data_packets} '
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
)
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
)
# -----------------------------------------------------------------------------
async def async_main(latency_probes, transport):
async def async_main(transport):
print('<<< connecting to HCI...')
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected')
@@ -177,23 +131,6 @@ async def async_main(latency_probes, transport):
host = Host(hci_source, hci_sink)
await host.reset()
# Measure the latency if requested
latencies = []
if latency_probes:
for _ in range(latency_probes):
start = time.time()
await host.send_command(HCI_Read_Local_Version_Information_Command())
latencies.append(1000 * (time.time() - start))
print(
color('HCI Command Latency:', 'yellow'),
(
f'min={min(latencies):.2f}, '
f'max={max(latencies):.2f}, '
f'average={sum(latencies)/len(latencies):.2f}'
),
'\n',
)
# Print version
print(color('Version:', 'yellow'))
print(
@@ -217,9 +154,6 @@ async def async_main(latency_probes, transport):
# Get the LE info
await get_le_info(host)
# Print the ACL flow control info
await get_acl_flow_control_info(host)
# Print the list of commands supported by the controller
print()
print(color('Supported Commands:', 'yellow'))
@@ -229,16 +163,10 @@ async def async_main(latency_probes, transport):
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--latency-probes',
metavar='N',
type=int,
help='Send N commands to measure HCI transport latency statistics',
)
@click.argument('transport')
def main(latency_probes, transport):
def main(transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(async_main(latency_probes, transport))
asyncio.run(async_main(transport))
# -----------------------------------------------------------------------------

View File

@@ -1,205 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import time
from typing import Optional
from bumble.colors import color
from bumble.hci import (
HCI_READ_LOOPBACK_MODE_COMMAND,
HCI_Read_Loopback_Mode_Command,
HCI_WRITE_LOOPBACK_MODE_COMMAND,
HCI_Write_Loopback_Mode_Command,
LoopbackMode,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
import click
class Loopback:
"""Send and receive ACL data packets in local loopback mode"""
def __init__(self, packet_size: int, packet_count: int, transport: str):
self.transport = transport
self.packet_size = packet_size
self.packet_count = packet_count
self.connection_handle: Optional[int] = None
self.connection_event = asyncio.Event()
self.done = asyncio.Event()
self.expected_cid = 0
self.bytes_received = 0
self.start_timestamp = 0.0
self.last_timestamp = 0.0
def on_connection(self, connection_handle: int, *args):
"""Retrieve connection handle from new connection event"""
if not self.connection_event.is_set():
# save first connection handle for ACL
# subsequent connections are SCO
self.connection_handle = connection_handle
self.connection_event.set()
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
"""Calculate packet receive speed"""
now = time.time()
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
assert connection_handle == self.connection_handle
assert cid == self.expected_cid
self.expected_cid += 1
if cid == 0:
self.start_timestamp = now
else:
elapsed_since_start = now - self.start_timestamp
elapsed_since_last = now - self.last_timestamp
self.bytes_received += len(pdu)
instant_rx_speed = len(pdu) / elapsed_since_last
average_rx_speed = self.bytes_received / elapsed_since_start
print(
color(
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f' average={average_rx_speed:.4f}',
'cyan',
)
)
self.last_timestamp = now
if self.expected_cid == self.packet_count:
print(color('@@@ Received last packet', 'green'))
self.done.set()
async def run(self):
"""Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green'))
async with await open_transport_or_link(self.transport) as (
hci_source,
hci_sink,
):
print(color('>>> Connected', 'green'))
host = Host(hci_source, hci_sink)
await host.reset()
# make sure data can fit in one l2cap pdu
l2cap_header_size = 4
max_packet_size = (
host.acl_packet_queue
if host.acl_packet_queue
else host.le_acl_packet_queue
).max_packet_size - l2cap_header_size
if self.packet_size > max_packet_size:
print(
color(
f'!!! Packet size ({self.packet_size}) larger than max supported'
f' size ({max_packet_size})',
'red',
)
)
return
if not host.supports_command(
HCI_WRITE_LOOPBACK_MODE_COMMAND
) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
print(color('!!! Loopback mode not supported', 'red'))
return
# set event callbacks
host.on('connection', self.on_connection)
host.on('l2cap_pdu', self.on_l2cap_pdu)
loopback_mode = LoopbackMode.LOCAL
print(color('### Setting loopback mode', 'blue'))
await host.send_command(
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
check_result=True,
)
print(color('### Checking loopback mode', 'blue'))
response = await host.send_command(
HCI_Read_Loopback_Mode_Command(), check_result=True
)
if response.return_parameters.loopback_mode != loopback_mode:
print(color('!!! Loopback mode mismatch', 'red'))
return
await self.connection_event.wait()
print(color('### Connected', 'cyan'))
print(color('=== Start sending', 'magenta'))
start_time = time.time()
bytes_sent = 0
for cid in range(0, self.packet_count):
# using the cid as an incremental index
host.send_l2cap_pdu(
self.connection_handle, cid, bytes(self.packet_size)
)
print(
color(
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
)
)
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
await asyncio.sleep(0) # yield to allow packet receive
await self.done.wait()
print(color('=== Done!', 'magenta'))
elapsed = time.time() - start_time
average_tx_speed = bytes_sent / elapsed
print(
color(
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
f' in {elapsed:.2f} seconds)',
'green',
)
)
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--packet-size',
'-s',
metavar='SIZE',
type=click.IntRange(8, 4096),
default=500,
help='Packet size',
)
@click.option(
'--packet-count',
'-c',
metavar='COUNT',
type=click.IntRange(1, 65535),
default=10,
help='Packet count',
)
@click.argument('transport')
def main(packet_size, packet_count, transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
loopback = Loopback(packet_size, packet_count, transport)
asyncio.run(loopback.run())
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()

View File

@@ -21,7 +21,6 @@ import struct
import logging
import click
from bumble import l2cap
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.core import AdvertisingData
@@ -205,7 +204,7 @@ class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
# -----------------------------------------------------------------------------
class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device: Device):
def __init__(self, device):
super().__init__()
self.device = device
self.peer = None
@@ -219,12 +218,7 @@ class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
# Listen for incoming L2CAP CoC connections
psm = 0xFB
device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(
psm=0xFB,
),
handler=self.on_coc,
)
device.register_l2cap_channel_server(0xFB, self.on_coc)
print(f'### Listening for CoC connection on PSM {psm}')
# Setup the Gattlink service

View File

@@ -20,7 +20,6 @@ import logging
import os
import click
from bumble import l2cap
from bumble.colors import color
from bumble.transport import open_transport_or_link
from bumble.device import Device
@@ -48,17 +47,16 @@ class ServerBridge:
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device: Device) -> None:
# Listen for incoming L2CAP channel connections
device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(
psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits
),
handler=self.on_channel,
)
print(
color(f'### Listening for channel connection on PSM {self.psm}', 'yellow')
async def start(self, device):
# Listen for incoming L2CAP CoC connections
device.register_l2cap_channel_server(
psm=self.psm,
server=self.on_coc,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
)
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
def on_ble_connection(connection):
def on_ble_disconnection(reason):
@@ -75,7 +73,7 @@ class ServerBridge:
await device.start_advertising(auto_restart=True)
# Called when a new L2CAP connection is established
def on_channel(self, l2cap_channel):
def on_coc(self, l2cap_channel):
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
class Pipe:
@@ -85,7 +83,7 @@ class ServerBridge:
self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_channel_sdu
l2cap_channel.sink = self.on_coc_sdu
async def connect_to_tcp(self):
# Connect to the TCP server
@@ -130,7 +128,7 @@ class ServerBridge:
if self.tcp_transport is not None:
self.tcp_transport.close()
def on_channel_sdu(self, sdu):
def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
if self.tcp_transport is None:
print(color('!!! TCP socket not open, dropping', 'red'))
@@ -185,7 +183,7 @@ class ClientBridge:
peer_name = writer.get_extra_info('peer_name')
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
def on_channel_sdu(sdu):
def on_coc_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
l2cap_to_tcp_pipe.write(sdu)
@@ -197,13 +195,11 @@ class ClientBridge:
# Connect a new L2CAP channel
print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
try:
l2cap_channel = await connection.create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(
psm=self.psm,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
)
l2cap_channel = await connection.open_l2cap_channel(
psm=self.psm,
max_credits=self.max_credits,
mtu=self.mtu,
mps=self.mps,
)
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
except Exception as error:
@@ -211,7 +207,7 @@ class ClientBridge:
writer.close()
return
l2cap_channel.sink = on_channel_sdu
l2cap_channel.sink = on_coc_sdu
l2cap_channel.on('close', on_l2cap_close)
# Start a flow control pipe from L2CAP to TCP
@@ -276,29 +272,23 @@ async def run(device_config, hci_transport, bridge):
@click.pass_context
@click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP', type=int, default=1234)
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
@click.option(
'--l2cap-max-credits',
help='Maximum L2CAP Credits',
'--l2cap-coc-max-credits',
help='Maximum L2CAP CoC Credits',
type=click.IntRange(1, 65535),
default=128,
)
@click.option(
'--l2cap-mtu',
help='L2CAP MTU',
type=click.IntRange(
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU,
),
default=1024,
'--l2cap-coc-mtu',
help='L2CAP CoC MTU',
type=click.IntRange(23, 65535),
default=1022,
)
@click.option(
'--l2cap-mps',
help='L2CAP MPS',
type=click.IntRange(
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS,
),
'--l2cap-coc-mps',
help='L2CAP CoC MPS',
type=click.IntRange(23, 65533),
default=1024,
)
def cli(
@@ -306,17 +296,17 @@ def cli(
device_config,
hci_transport,
psm,
l2cap_max_credits,
l2cap_mtu,
l2cap_mps,
l2cap_coc_max_credits,
l2cap_coc_mtu,
l2cap_coc_mps,
):
context.ensure_object(dict)
context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport
context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_max_credits
context.obj['mtu'] = l2cap_mtu
context.obj['mps'] = l2cap_mps
context.obj['max_credits'] = l2cap_coc_max_credits
context.obj['mtu'] = l2cap_coc_mtu
context.obj['mps'] = l2cap_coc_mps
# -----------------------------------------------------------------------------

View File

@@ -24,16 +24,10 @@ from prompt_toolkit.shortcuts import PromptSession
from bumble.colors import color
from bumble.device import Device, Peer
from bumble.transport import open_transport_or_link
from bumble.pairing import OobData, PairingDelegate, PairingConfig
from bumble.smp import OobContext, OobLegacyContext
from bumble.pairing import PairingDelegate, PairingConfig
from bumble.smp import error_name as smp_error_name
from bumble.keys import JsonKeyStore
from bumble.core import (
AdvertisingData,
ProtocolError,
BT_LE_TRANSPORT,
BT_BR_EDR_TRANSPORT,
)
from bumble.core import ProtocolError
from bumble.gatt import (
GATT_DEVICE_NAME_CHARACTERISTIC,
GATT_GENERIC_ACCESS_SERVICE,
@@ -52,13 +46,11 @@ from bumble.att import (
class Waiter:
instance = None
def __init__(self, linger=False):
def __init__(self):
self.done = asyncio.get_running_loop().create_future()
self.linger = linger
def terminate(self):
if not self.linger:
self.done.set_result(None)
self.done.set_result(None)
async def wait_until_terminated(self):
return await self.done
@@ -68,7 +60,7 @@ class Waiter:
class Delegate(PairingDelegate):
def __init__(self, mode, connection, capability_string, do_prompt):
super().__init__(
io_capability={
{
'keyboard': PairingDelegate.KEYBOARD_INPUT_ONLY,
'display': PairingDelegate.DISPLAY_OUTPUT_ONLY,
'display+keyboard': PairingDelegate.DISPLAY_OUTPUT_AND_KEYBOARD_INPUT,
@@ -293,9 +285,7 @@ async def pair(
mitm,
bond,
ctkd,
linger,
io,
oob,
prompt,
request,
print_keys,
@@ -304,7 +294,7 @@ async def pair(
hci_transport,
address_or_name,
):
Waiter.instance = Waiter(linger=linger)
Waiter.instance = Waiter()
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
@@ -316,7 +306,6 @@ async def pair(
# Expose a GATT characteristic that can be used to trigger pairing by
# responding with an authentication error when read
if mode == 'le':
device.le_enabled = True
device.add_service(
Service(
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
@@ -337,6 +326,7 @@ async def pair(
# Select LE or Classic
if mode == 'classic':
device.classic_enabled = True
device.le_enabled = False
device.classic_smp_enabled = ctkd
# Get things going
@@ -353,51 +343,16 @@ async def pair(
await device.keystore.print(prefix=color('@@@ ', 'blue'))
print(color('@@@-----------------------------------', 'blue'))
# Create an OOB context if needed
if oob:
our_oob_context = OobContext()
shared_data = (
None
if oob == '-'
else OobData.from_ad(AdvertisingData.from_bytes(bytes.fromhex(oob)))
)
legacy_context = OobLegacyContext()
oob_contexts = PairingConfig.OobConfig(
our_context=our_oob_context,
peer_data=shared_data,
legacy_context=legacy_context,
)
oob_data = OobData(
address=device.random_address,
shared_data=shared_data,
legacy_context=legacy_context,
)
print(color('@@@-----------------------------------', 'yellow'))
print(color('@@@ OOB Data:', 'yellow'))
print(color(f'@@@ {our_oob_context.share()}', 'yellow'))
print(color(f'@@@ TK={legacy_context.tk.hex()}', 'yellow'))
print(color(f'@@@ HEX: ({bytes(oob_data.to_ad()).hex()})', 'yellow'))
print(color('@@@-----------------------------------', 'yellow'))
else:
oob_contexts = None
# Set up a pairing config factory
device.pairing_config_factory = lambda connection: PairingConfig(
sc=sc,
mitm=mitm,
bonding=bond,
oob=oob_contexts,
delegate=Delegate(mode, connection, io, prompt),
sc, mitm, bond, Delegate(mode, connection, io, prompt)
)
# Connect to a peer or wait for a connection
device.on('connection', lambda connection: on_connection(connection, request))
if address_or_name is not None:
print(color(f'=== Connecting to {address_or_name}...', 'green'))
connection = await device.connect(
address_or_name,
transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
)
connection = await device.connect(address_or_name)
if not request:
try:
@@ -405,9 +360,10 @@ async def pair(
await connection.pair()
else:
await connection.authenticate()
return
except ProtocolError as error:
print(color(f'Pairing failed: {error}', 'red'))
return
else:
if mode == 'le':
# Advertise so that peers can find us and connect
@@ -457,7 +413,6 @@ class LogHandler(logging.Handler):
help='Enable CTKD',
show_default=True,
)
@click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
@click.option(
'--io',
type=click.Choice(
@@ -466,14 +421,6 @@ class LogHandler(logging.Handler):
default='display+keyboard',
show_default=True,
)
@click.option(
'--oob',
metavar='<oob-data-hex>',
help=(
'Use OOB pairing with this data from the peer '
'(use "-" to enable OOB without peer data)'
),
)
@click.option('--prompt', is_flag=True, help='Prompt to accept/reject pairing request')
@click.option(
'--request', is_flag=True, help='Request that the connecting peer initiate pairing'
@@ -493,9 +440,7 @@ def main(
mitm,
bond,
ctkd,
linger,
io,
oob,
prompt,
request,
print_keys,
@@ -518,9 +463,7 @@ def main(
mitm,
bond,
ctkd,
linger,
io,
oob,
prompt,
request,
print_keys,

View File

@@ -26,7 +26,7 @@ from bumble.transport import open_transport_or_link
from bumble.keys import JsonKeyStore
from bumble.smp import AddressResolver
from bumble.device import Advertisement
from bumble.hci import Address, HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
from bumble.hci import HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
# -----------------------------------------------------------------------------
@@ -66,15 +66,10 @@ class AdvertisementPrinter:
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
address.address_type
]
if address.address_type in (
Address.RANDOM_IDENTITY_ADDRESS,
Address.PUBLIC_IDENTITY_ADDRESS,
):
type_color = 'yellow'
if address.is_public:
type_color = 'cyan'
else:
if address.is_public:
type_color = 'cyan'
elif address.is_static:
if address.is_static:
type_color = 'green'
address_qualifier = '(static)'
elif address.is_resolvable:
@@ -121,7 +116,6 @@ async def scan(
phy,
filter_duplicates,
raw,
irks,
keystore_file,
device_config,
transport,
@@ -146,21 +140,9 @@ async def scan(
if device.keystore:
resolving_keys = await device.keystore.get_resolving_keys()
resolver = AddressResolver(resolving_keys)
else:
resolving_keys = []
for irk_and_address in irks:
if ':' not in irk_and_address:
raise ValueError('invalid IRK:ADDRESS value')
irk_hex, address_str = irk_and_address.split(':', 1)
resolving_keys.append(
(
bytes.fromhex(irk_hex),
Address(address_str, Address.RANDOM_DEVICE_ADDRESS),
)
)
resolver = AddressResolver(resolving_keys) if resolving_keys else None
resolver = None
printer = AdvertisementPrinter(min_rssi, resolver)
if raw:
@@ -205,24 +187,8 @@ async def scan(
default=False,
help='Listen for raw advertising reports instead of processed ones',
)
@click.option(
'--irk',
metavar='<IRK_HEX>:<ADDRESS>',
help=(
'Use this IRK for resolving private addresses ' '(may be used more than once)'
),
multiple=True,
)
@click.option(
'--keystore-file',
metavar='FILE_PATH',
help='Keystore file to use when resolving addresses',
)
@click.option(
'--device-config',
metavar='FILE_PATH',
help='Device config file for the scanning device',
)
@click.option('--keystore-file', help='Keystore file to use when resolving addresses')
@click.option('--device-config', help='Device config file for the scanning device')
@click.argument('transport')
def main(
min_rssi,
@@ -232,7 +198,6 @@ def main(
phy,
filter_duplicates,
raw,
irk,
keystore_file,
device_config,
transport,
@@ -247,7 +212,6 @@ def main(
phy,
filter_duplicates,
raw,
irk,
keystore_file,
device_config,
transport,

View File

@@ -15,11 +15,7 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import datetime
import logging
import os
import struct
import click
from bumble.colors import color
@@ -28,14 +24,6 @@ from bumble.transport.common import PacketReader
from bumble.helpers import PacketTracer
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class SnoopPacketReader:
'''
@@ -48,18 +36,12 @@ class SnoopPacketReader:
DATALINK_BSCP = 1003
DATALINK_H5 = 1004
IDENTIFICATION_PATTERN = b'btsnoop\0'
TIMESTAMP_ANCHOR = datetime.datetime(2000, 1, 1)
TIMESTAMP_DELTA = 0x00E03AB44A676000
ONE_MICROSECOND = datetime.timedelta(microseconds=1)
def __init__(self, source):
self.source = source
self.at_end = False
# Read the header
identification_pattern = source.read(8)
if identification_pattern != self.IDENTIFICATION_PATTERN:
if identification_pattern.hex().lower() != '6274736e6f6f7000':
raise ValueError(
'not a valid snoop file, unexpected identification pattern'
)
@@ -73,32 +55,19 @@ class SnoopPacketReader:
# Read the record header
header = self.source.read(24)
if len(header) < 24:
self.at_end = True
return (None, 0, None)
# Parse the header
return (0, None)
(
original_length,
included_length,
packet_flags,
_cumulative_drops,
timestamp,
) = struct.unpack('>IIIIQ', header)
_timestamp_seconds,
_timestamp_microsecond,
) = struct.unpack('>IIIIII', header)
# Skip truncated packets
# Abort on truncated packets
if original_length != included_length:
print(
color(
f"!!! truncated packet ({included_length}/{original_length})", "red"
)
)
self.source.read(included_length)
return (None, 0, None)
# Convert the timestamp to a datetime object.
ts_dt = self.TIMESTAMP_ANCHOR + datetime.timedelta(
microseconds=timestamp - self.TIMESTAMP_DELTA
)
return (0, None)
if self.data_link_type == self.DATALINK_H1:
# The packet is un-encapsulated, look at the flags to figure out its type
@@ -120,17 +89,7 @@ class SnoopPacketReader:
bytes([packet_type]) + self.source.read(included_length),
)
return (ts_dt, packet_flags & 1, self.source.read(included_length))
# -----------------------------------------------------------------------------
class Printer:
def __init__(self):
self.index = 0
def print(self, message: str) -> None:
self.index += 1
print(f"[{self.index:8}]{message}")
return (packet_flags & 1, self.source.read(included_length))
# -----------------------------------------------------------------------------
@@ -163,28 +122,24 @@ def main(format, vendors, filename):
packet_reader = PacketReader(input)
def read_next_packet():
return (None, 0, packet_reader.next_packet())
return (0, packet_reader.next_packet())
else:
packet_reader = SnoopPacketReader(input)
read_next_packet = packet_reader.next_packet
printer = Printer()
tracer = PacketTracer(emit_message=printer.print)
tracer = PacketTracer(emit_message=print)
while not packet_reader.at_end:
while True:
try:
(timestamp, direction, packet) = read_next_packet()
if packet:
tracer.trace(hci.HCI_Packet.from_bytes(packet), direction, timestamp)
else:
printer.print(color("[TRUNCATED]", "red"))
(direction, packet) = read_next_packet()
if packet is None:
break
tracer.trace(hci.HCI_Packet.from_bytes(packet), direction)
except Exception as error:
logger.exception()
print(color(f'!!! {error}', 'red'))
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
main() # pylint: disable=no-value-for-parameter

View File

@@ -641,7 +641,7 @@ class Speaker:
self.device.on('connection', self.on_bluetooth_connection)
# Create a listener to wait for AVDTP connections
self.listener = Listener.for_device(self.device)
self.listener = Listener(Listener.create_registrar(self.device))
self.listener.on('connection', self.on_avdtp_connection)
print(f'Speaker ready to play, codec={color(self.codec, "cyan")}')

View File

@@ -15,13 +15,9 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import dataclasses
import struct
import logging
from collections.abc import AsyncGenerator
from typing import List, Callable, Awaitable
from collections import namedtuple
from .company_ids import COMPANY_IDENTIFIERS
from .sdp import (
@@ -184,12 +180,8 @@ def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3))
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
)
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
),
),
@@ -238,12 +230,8 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
)
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
),
),
@@ -251,20 +239,24 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class SbcMediaCodecInformation:
class SbcMediaCodecInformation(
namedtuple(
'SbcMediaCodecInformation',
[
'sampling_frequency',
'channel_mode',
'block_length',
'subbands',
'allocation_method',
'minimum_bitpool_value',
'maximum_bitpool_value',
],
)
):
'''
A2DP spec - 4.3.2 Codec Specific Information Elements
'''
sampling_frequency: int
channel_mode: int
block_length: int
subbands: int
allocation_method: int
minimum_bitpool_value: int
maximum_bitpool_value: int
SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3,
@@ -280,7 +272,7 @@ class SbcMediaCodecInformation:
}
@staticmethod
def from_bytes(data: bytes) -> SbcMediaCodecInformation:
def from_bytes(data: bytes) -> 'SbcMediaCodecInformation':
sampling_frequency = (data[0] >> 4) & 0x0F
channel_mode = (data[0] >> 0) & 0x0F
block_length = (data[1] >> 4) & 0x0F
@@ -301,14 +293,14 @@ class SbcMediaCodecInformation:
@classmethod
def from_discrete_values(
cls,
sampling_frequency: int,
channel_mode: int,
block_length: int,
subbands: int,
allocation_method: int,
minimum_bitpool_value: int,
maximum_bitpool_value: int,
) -> SbcMediaCodecInformation:
sampling_frequency,
channel_mode,
block_length,
subbands,
allocation_method,
minimum_bitpool_value,
maximum_bitpool_value,
):
return SbcMediaCodecInformation(
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
@@ -322,14 +314,14 @@ class SbcMediaCodecInformation:
@classmethod
def from_lists(
cls,
sampling_frequencies: List[int],
channel_modes: List[int],
block_lengths: List[int],
subbands: List[int],
allocation_methods: List[int],
minimum_bitpool_value: int,
maximum_bitpool_value: int,
) -> SbcMediaCodecInformation:
sampling_frequencies,
channel_modes,
block_lengths,
subbands,
allocation_methods,
minimum_bitpool_value,
maximum_bitpool_value,
):
return SbcMediaCodecInformation(
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
@@ -356,7 +348,7 @@ class SbcMediaCodecInformation:
]
)
def __str__(self) -> str:
def __str__(self):
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness']
return '\n'.join(
@@ -375,19 +367,16 @@ class SbcMediaCodecInformation:
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class AacMediaCodecInformation:
class AacMediaCodecInformation(
namedtuple(
'AacMediaCodecInformation',
['object_type', 'sampling_frequency', 'channels', 'rfa', 'vbr', 'bitrate'],
)
):
'''
A2DP spec - 4.5.2 Codec Specific Information Elements
'''
object_type: int
sampling_frequency: int
channels: int
rfa: int
vbr: int
bitrate: int
OBJECT_TYPE_BITS = {
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
@@ -411,7 +400,7 @@ class AacMediaCodecInformation:
CHANNELS_BITS = {1: 1 << 1, 2: 1}
@staticmethod
def from_bytes(data: bytes) -> AacMediaCodecInformation:
def from_bytes(data: bytes) -> 'AacMediaCodecInformation':
object_type = data[0]
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
channels = (data[2] >> 2) & 0x03
@@ -424,13 +413,8 @@ class AacMediaCodecInformation:
@classmethod
def from_discrete_values(
cls,
object_type: int,
sampling_frequency: int,
channels: int,
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
cls, object_type, sampling_frequency, channels, vbr, bitrate
):
return AacMediaCodecInformation(
object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
@@ -441,14 +425,7 @@ class AacMediaCodecInformation:
)
@classmethod
def from_lists(
cls,
object_types: List[int],
sampling_frequencies: List[int],
channels: List[int],
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate):
return AacMediaCodecInformation(
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency=sum(
@@ -472,7 +449,7 @@ class AacMediaCodecInformation:
]
)
def __str__(self) -> str:
def __str__(self):
object_types = [
'MPEG_2_AAC_LC',
'MPEG_4_AAC_LC',
@@ -497,26 +474,26 @@ class AacMediaCodecInformation:
)
@dataclasses.dataclass
# -----------------------------------------------------------------------------
class VendorSpecificMediaCodecInformation:
'''
A2DP spec - 4.7.2 Codec Specific Information Elements
'''
vendor_id: int
codec_id: int
value: bytes
@staticmethod
def from_bytes(data: bytes) -> VendorSpecificMediaCodecInformation:
def from_bytes(data):
(vendor_id, codec_id) = struct.unpack_from('<IH', data, 0)
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])
def __bytes__(self) -> bytes:
def __init__(self, vendor_id, codec_id, value):
self.vendor_id = vendor_id
self.codec_id = codec_id
self.value = value
def __bytes__(self):
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)
def __str__(self) -> str:
def __str__(self):
# pylint: disable=line-too-long
return '\n'.join(
[
@@ -529,27 +506,29 @@ class VendorSpecificMediaCodecInformation:
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class SbcFrame:
sampling_frequency: int
block_count: int
channel_mode: int
subband_count: int
payload: bytes
def __init__(
self, sampling_frequency, block_count, channel_mode, subband_count, payload
):
self.sampling_frequency = sampling_frequency
self.block_count = block_count
self.channel_mode = channel_mode
self.subband_count = subband_count
self.payload = payload
@property
def sample_count(self) -> int:
def sample_count(self):
return self.subband_count * self.block_count
@property
def bitrate(self) -> int:
def bitrate(self):
return 8 * ((len(self.payload) * self.sampling_frequency) // self.sample_count)
@property
def duration(self) -> float:
def duration(self):
return self.sample_count / self.sampling_frequency
def __str__(self) -> str:
def __str__(self):
return (
f'SBC(sf={self.sampling_frequency},'
f'cm={self.channel_mode},'
@@ -561,12 +540,12 @@ class SbcFrame:
# -----------------------------------------------------------------------------
class SbcParser:
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
def __init__(self, read):
self.read = read
@property
def frames(self) -> AsyncGenerator[SbcFrame, None]:
async def generate_frames() -> AsyncGenerator[SbcFrame, None]:
def frames(self):
async def generate_frames():
while True:
# Read 4 bytes of header
header = await self.read(4)
@@ -610,9 +589,7 @@ class SbcParser:
# -----------------------------------------------------------------------------
class SbcPacketSource:
def __init__(
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities
) -> None:
def __init__(self, read, mtu, codec_capabilities):
self.read = read
self.mtu = mtu
self.codec_capabilities = codec_capabilities

View File

@@ -25,21 +25,9 @@
from __future__ import annotations
import enum
import functools
import inspect
import struct
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
Union,
TYPE_CHECKING,
)
from pyee import EventEmitter
from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
from bumble.core import UUID, name_or_number, ProtocolError
from bumble.hci import HCI_Object, key_with_value
@@ -734,38 +722,12 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
# -----------------------------------------------------------------------------
class AttributeValue:
'''
Attribute value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
class ConnectionValue(Protocol):
def read(self, connection) -> bytes:
...
def __init__(
self,
read: Union[
Callable[[Optional[Connection]], bytes],
Callable[[Optional[Connection]], Awaitable[bytes]],
None,
] = None,
write: Union[
Callable[[Optional[Connection], bytes], None],
Callable[[Optional[Connection], bytes], Awaitable[None]],
None,
] = None,
):
self._read = read
self._write = write
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
return self._read(connection) if self._read else b''
def write(
self, connection: Optional[Connection], value: bytes
) -> Union[Awaitable[None], None]:
if self._write:
return self._write(connection, value)
return None
def write(self, connection, value: bytes) -> None:
...
# -----------------------------------------------------------------------------
@@ -808,13 +770,13 @@ class Attribute(EventEmitter):
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
value: Union[bytes, AttributeValue]
value: Union[str, bytes, ConnectionValue]
def __init__(
self,
attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, AttributeValue] = b'',
value: Union[str, bytes, ConnectionValue] = b'',
) -> None:
EventEmitter.__init__(self)
self.handle = 0
@@ -844,7 +806,7 @@ class Attribute(EventEmitter):
def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
async def read_value(self, connection: Optional[Connection]) -> bytes:
def read_value(self, connection: Optional[Connection]) -> bytes:
if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
@@ -870,8 +832,6 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'read'):
try:
value = self.value.read(connection)
if inspect.isawaitable(value):
value = await value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
@@ -881,7 +841,7 @@ class Attribute(EventEmitter):
return self.encode_value(value)
async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption:
@@ -904,9 +864,7 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'write'):
try:
result = self.value.write(connection, value)
if inspect.isawaitable(result):
await result
self.value.write(connection, value) # pylint: disable=not-callable
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle

View File

@@ -1,520 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import struct
from typing import Dict, Type, Union, Tuple
from bumble.utils import OpenIntEnum
# -----------------------------------------------------------------------------
class Frame:
class SubunitType(enum.IntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.4
MONITOR = 0x00
AUDIO = 0x01
PRINTER = 0x02
DISC = 0x03
TAPE_RECORDER_OR_PLAYER = 0x04
TUNER = 0x05
CA = 0x06
CAMERA = 0x07
PANEL = 0x09
BULLETIN_BOARD = 0x0A
VENDOR_UNIQUE = 0x1C
EXTENDED = 0x1E
UNIT = 0x1F
class OperationCode(OpenIntEnum):
# 0x00 - 0x0F: Unit and subunit commands
VENDOR_DEPENDENT = 0x00
RESERVE = 0x01
PLUG_INFO = 0x02
# 0x10 - 0x3F: Unit commands
DIGITAL_OUTPUT = 0x10
DIGITAL_INPUT = 0x11
CHANNEL_USAGE = 0x12
OUTPUT_PLUG_SIGNAL_FORMAT = 0x18
INPUT_PLUG_SIGNAL_FORMAT = 0x19
GENERAL_BUS_SETUP = 0x1F
CONNECT_AV = 0x20
DISCONNECT_AV = 0x21
CONNECTIONS = 0x22
CONNECT = 0x24
DISCONNECT = 0x25
UNIT_INFO = 0x30
SUBUNIT_INFO = 0x31
# 0x40 - 0x7F: Subunit commands
PASS_THROUGH = 0x7C
GUI_UPDATE = 0x7D
PUSH_GUI_DATA = 0x7E
USER_ACTION = 0x7F
# 0xA0 - 0xBF: Unit and subunit commands
VERSION = 0xB0
POWER = 0xB2
subunit_type: SubunitType
subunit_id: int
opcode: OperationCode
operands: bytes
@staticmethod
def subclass(subclass):
# Infer the opcode from the class name
if subclass.__name__.endswith("CommandFrame"):
short_name = subclass.__name__.replace("CommandFrame", "")
category_class = CommandFrame
elif subclass.__name__.endswith("ResponseFrame"):
short_name = subclass.__name__.replace("ResponseFrame", "")
category_class = ResponseFrame
else:
raise ValueError(f"invalid subclass name {subclass.__name__}")
uppercase_indexes = [
i for i in range(len(short_name)) if short_name[i].isupper()
]
uppercase_indexes.append(len(short_name))
words = [
short_name[uppercase_indexes[i] : uppercase_indexes[i + 1]].upper()
for i in range(len(uppercase_indexes) - 1)
]
opcode_name = "_".join(words)
opcode = Frame.OperationCode[opcode_name]
category_class.subclasses[opcode] = subclass
return subclass
@staticmethod
def from_bytes(data: bytes) -> Frame:
if data[0] >> 4 != 0:
raise ValueError("first 4 bits must be 0s")
ctype_or_response = data[0] & 0xF
subunit_type = Frame.SubunitType(data[1] >> 3)
subunit_id = data[1] & 7
if subunit_type == Frame.SubunitType.EXTENDED:
# Not supported
raise NotImplementedError("extended subunit types not supported")
if subunit_id < 5:
opcode_offset = 2
elif subunit_id == 5:
# Extended to the next byte
extension = data[2]
if extension == 0:
raise ValueError("extended subunit ID value reserved")
if extension == 0xFF:
subunit_id = 5 + 254 + data[3]
opcode_offset = 4
else:
subunit_id = 5 + extension
opcode_offset = 3
elif subunit_id == 6:
raise ValueError("reserved subunit ID")
opcode = Frame.OperationCode(data[opcode_offset])
operands = data[opcode_offset + 1 :]
# Look for a registered subclass
if ctype_or_response < 8:
# Command
ctype = CommandFrame.CommandType(ctype_or_response)
if c_subclass := CommandFrame.subclasses.get(opcode):
return c_subclass(
ctype,
subunit_type,
subunit_id,
*c_subclass.parse_operands(operands),
)
return CommandFrame(ctype, subunit_type, subunit_id, opcode, operands)
else:
# Response
response = ResponseFrame.ResponseCode(ctype_or_response)
if r_subclass := ResponseFrame.subclasses.get(opcode):
return r_subclass(
response,
subunit_type,
subunit_id,
*r_subclass.parse_operands(operands),
)
return ResponseFrame(response, subunit_type, subunit_id, opcode, operands)
def to_bytes(
self,
ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode],
) -> bytes:
# TODO: support extended subunit types and ids.
return (
bytes(
[
ctype_or_response,
self.subunit_type << 3 | self.subunit_id,
self.opcode,
]
)
+ self.operands
)
def to_string(self, extra: str) -> str:
return (
f"{self.__class__.__name__}({extra}"
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"opcode={self.opcode.name}, "
f"operands={self.operands.hex()})"
)
def __init__(
self,
subunit_type: SubunitType,
subunit_id: int,
opcode: OperationCode,
operands: bytes,
) -> None:
self.subunit_type = subunit_type
self.subunit_id = subunit_id
self.opcode = opcode
self.operands = operands
# -----------------------------------------------------------------------------
class CommandFrame(Frame):
class CommandType(OpenIntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.1
CONTROL = 0x00
STATUS = 0x01
SPECIFIC_INQUIRY = 0x02
NOTIFY = 0x03
GENERAL_INQUIRY = 0x04
subclasses: Dict[Frame.OperationCode, Type[CommandFrame]] = {}
ctype: CommandType
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
raise NotImplementedError
def __init__(
self,
ctype: CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
opcode: Frame.OperationCode,
operands: bytes,
) -> None:
super().__init__(subunit_type, subunit_id, opcode, operands)
self.ctype = ctype
def __bytes__(self):
return self.to_bytes(self.ctype)
def __str__(self):
return self.to_string(f"ctype={self.ctype.name}, ")
# -----------------------------------------------------------------------------
class ResponseFrame(Frame):
class ResponseCode(OpenIntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.2
NOT_IMPLEMENTED = 0x08
ACCEPTED = 0x09
REJECTED = 0x0A
IN_TRANSITION = 0x0B
IMPLEMENTED_OR_STABLE = 0x0C
CHANGED = 0x0D
INTERIM = 0x0F
subclasses: Dict[Frame.OperationCode, Type[ResponseFrame]] = {}
response: ResponseCode
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
raise NotImplementedError
def __init__(
self,
response: ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
opcode: Frame.OperationCode,
operands: bytes,
) -> None:
super().__init__(subunit_type, subunit_id, opcode, operands)
self.response = response
def __bytes__(self):
return self.to_bytes(self.response)
def __str__(self):
return self.to_string(f"response={self.response.name}, ")
# -----------------------------------------------------------------------------
class VendorDependentFrame:
company_id: int
vendor_dependent_data: bytes
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
return (
struct.unpack(">I", b"\x00" + operands[:3])[0],
operands[3:],
)
def make_operands(self) -> bytes:
return struct.pack(">I", self.company_id)[1:] + self.vendor_dependent_data
def __init__(self, company_id: int, vendor_dependent_data: bytes):
self.company_id = company_id
self.vendor_dependent_data = vendor_dependent_data
# -----------------------------------------------------------------------------
@Frame.subclass
class VendorDependentCommandFrame(VendorDependentFrame, CommandFrame):
def __init__(
self,
ctype: CommandFrame.CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
company_id: int,
vendor_dependent_data: bytes,
) -> None:
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
CommandFrame.__init__(
self,
ctype,
subunit_type,
subunit_id,
Frame.OperationCode.VENDOR_DEPENDENT,
self.make_operands(),
)
def __str__(self):
return (
f"VendorDependentCommandFrame(ctype={self.ctype.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"company_id=0x{self.company_id:06X}, "
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
)
# -----------------------------------------------------------------------------
@Frame.subclass
class VendorDependentResponseFrame(VendorDependentFrame, ResponseFrame):
def __init__(
self,
response: ResponseFrame.ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
company_id: int,
vendor_dependent_data: bytes,
) -> None:
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
ResponseFrame.__init__(
self,
response,
subunit_type,
subunit_id,
Frame.OperationCode.VENDOR_DEPENDENT,
self.make_operands(),
)
def __str__(self):
return (
f"VendorDependentResponseFrame(response={self.response.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"company_id=0x{self.company_id:06X}, "
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
)
# -----------------------------------------------------------------------------
class PassThroughFrame:
"""
See AV/C Panel Subunit Specification 1.1 - 9.4 PASS THROUGH control command
"""
class StateFlag(enum.IntEnum):
PRESSED = 0
RELEASED = 1
class OperationId(OpenIntEnum):
SELECT = 0x00
UP = 0x01
DOWN = 0x01
LEFT = 0x03
RIGHT = 0x04
RIGHT_UP = 0x05
RIGHT_DOWN = 0x06
LEFT_UP = 0x07
LEFT_DOWN = 0x08
ROOT_MENU = 0x09
SETUP_MENU = 0x0A
CONTENTS_MENU = 0x0B
FAVORITE_MENU = 0x0C
EXIT = 0x0D
NUMBER_0 = 0x20
NUMBER_1 = 0x21
NUMBER_2 = 0x22
NUMBER_3 = 0x23
NUMBER_4 = 0x24
NUMBER_5 = 0x25
NUMBER_6 = 0x26
NUMBER_7 = 0x27
NUMBER_8 = 0x28
NUMBER_9 = 0x29
DOT = 0x2A
ENTER = 0x2B
CLEAR = 0x2C
CHANNEL_UP = 0x30
CHANNEL_DOWN = 0x31
PREVIOUS_CHANNEL = 0x32
SOUND_SELECT = 0x33
INPUT_SELECT = 0x34
DISPLAY_INFORMATION = 0x35
HELP = 0x36
PAGE_UP = 0x37
PAGE_DOWN = 0x38
POWER = 0x40
VOLUME_UP = 0x41
VOLUME_DOWN = 0x42
MUTE = 0x43
PLAY = 0x44
STOP = 0x45
PAUSE = 0x46
RECORD = 0x47
REWIND = 0x48
FAST_FORWARD = 0x49
EJECT = 0x4A
FORWARD = 0x4B
BACKWARD = 0x4C
ANGLE = 0x50
SUBPICTURE = 0x51
F1 = 0x71
F2 = 0x72
F3 = 0x73
F4 = 0x74
F5 = 0x75
VENDOR_UNIQUE = 0x7E
state_flag: StateFlag
operation_id: OperationId
operation_data: bytes
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
return (
PassThroughFrame.StateFlag(operands[0] >> 7),
PassThroughFrame.OperationId(operands[0] & 0x7F),
operands[1 : 1 + operands[1]],
)
def make_operands(self):
return (
bytes([self.state_flag << 7 | self.operation_id, len(self.operation_data)])
+ self.operation_data
)
def __init__(
self,
state_flag: StateFlag,
operation_id: OperationId,
operation_data: bytes,
) -> None:
if len(operation_data) > 255:
raise ValueError("operation data must be <= 255 bytes")
self.state_flag = state_flag
self.operation_id = operation_id
self.operation_data = operation_data
# -----------------------------------------------------------------------------
@Frame.subclass
class PassThroughCommandFrame(PassThroughFrame, CommandFrame):
def __init__(
self,
ctype: CommandFrame.CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
state_flag: PassThroughFrame.StateFlag,
operation_id: PassThroughFrame.OperationId,
operation_data: bytes,
) -> None:
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
CommandFrame.__init__(
self,
ctype,
subunit_type,
subunit_id,
Frame.OperationCode.PASS_THROUGH,
self.make_operands(),
)
def __str__(self):
return (
f"PassThroughCommandFrame(ctype={self.ctype.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"state_flag={self.state_flag.name}, "
f"operation_id={self.operation_id.name}, "
f"operation_data={self.operation_data.hex()})"
)
# -----------------------------------------------------------------------------
@Frame.subclass
class PassThroughResponseFrame(PassThroughFrame, ResponseFrame):
def __init__(
self,
response: ResponseFrame.ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
state_flag: PassThroughFrame.StateFlag,
operation_id: PassThroughFrame.OperationId,
operation_data: bytes,
) -> None:
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
ResponseFrame.__init__(
self,
response,
subunit_type,
subunit_id,
Frame.OperationCode.PASS_THROUGH,
self.make_operands(),
)
def __str__(self):
return (
f"PassThroughResponseFrame(response={self.response.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"state_flag={self.state_flag.name}, "
f"operation_id={self.operation_id.name}, "
f"operation_data={self.operation_data.hex()})"
)

View File

@@ -1,291 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from enum import IntEnum
import logging
import struct
from typing import Callable, cast, Dict, Optional
from bumble.colors import color
from bumble import avc
from bumble import l2cap
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
AVCTP_PSM = 0x0017
AVCTP_BROWSING_PSM = 0x001B
# -----------------------------------------------------------------------------
class MessageAssembler:
Callback = Callable[[int, bool, bool, int, bytes], None]
transaction_label: int
pid: int
c_r: int
ipid: int
payload: bytes
number_of_packets: int
packets_received: int
def __init__(self, callback: Callback) -> None:
self.callback = callback
self.reset()
def reset(self) -> None:
self.packets_received = 0
self.transaction_label = -1
self.pid = -1
self.c_r = -1
self.ipid = -1
self.payload = b''
self.number_of_packets = 0
self.packet_count = 0
def on_pdu(self, pdu: bytes) -> None:
self.packets_received += 1
transaction_label = pdu[0] >> 4
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
c_r = (pdu[0] >> 1) & 1
ipid = pdu[0] & 1
if c_r == 0 and ipid != 0:
logger.warning("invalid IPID in command frame")
self.reset()
return
pid_offset = 1
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START):
if self.transaction_label >= 0:
# We are already in a transaction
logger.warning("received START or SINGLE fragment while in transaction")
self.reset()
self.packets_received = 1
if packet_type == Protocol.PacketType.START:
self.number_of_packets = pdu[1]
pid_offset = 2
pid = struct.unpack_from(">H", pdu, pid_offset)[0]
self.payload += pdu[pid_offset + 2 :]
if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END):
if transaction_label != self.transaction_label:
logger.warning("transaction label does not match")
self.reset()
return
if pid != self.pid:
logger.warning("PID does not match")
self.reset()
return
if c_r != self.c_r:
logger.warning("C/R does not match")
self.reset()
return
if self.packets_received > self.number_of_packets:
logger.warning("too many fragments in transaction")
self.reset()
return
if packet_type == Protocol.PacketType.END:
if self.packets_received != self.number_of_packets:
logger.warning("premature END")
self.reset()
return
else:
self.transaction_label = transaction_label
self.c_r = c_r
self.ipid = ipid
self.pid = pid
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END):
self.on_message_complete()
def on_message_complete(self):
try:
self.callback(
self.transaction_label,
self.c_r == 0,
self.ipid != 0,
self.pid,
self.payload,
)
except Exception as error:
logger.exception(color(f"!!! exception in callback: {error}", "red"))
self.reset()
# -----------------------------------------------------------------------------
class Protocol:
CommandHandler = Callable[[int, avc.CommandFrame], None]
command_handlers: Dict[int, CommandHandler] # Command handlers, by PID
ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]
response_handlers: Dict[int, ResponseHandler] # Response handlers, by PID
next_transaction_label: int
message_assembler: MessageAssembler
class PacketType(IntEnum):
SINGLE = 0b00
START = 0b01
CONTINUE = 0b10
END = 0b11
def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None:
self.command_handlers = {}
self.response_handlers = {}
self.l2cap_channel = l2cap_channel
self.message_assembler = MessageAssembler(self.on_message)
# Register to receive PDUs from the channel
l2cap_channel.sink = self.on_pdu
l2cap_channel.on("open", self.on_l2cap_channel_open)
l2cap_channel.on("close", self.on_l2cap_channel_close)
def on_l2cap_channel_open(self):
logger.debug(color("<<< AVCTP channel open", "magenta"))
def on_l2cap_channel_close(self):
logger.debug(color("<<< AVCTP channel closed", "magenta"))
def on_pdu(self, pdu: bytes) -> None:
self.message_assembler.on_pdu(pdu)
def on_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
) -> None:
logger.debug(
f"<<< AVCTP Message: pid={pid}, "
f"transaction_label={transaction_label}, "
f"is_command={is_command}, "
f"ipid={ipid}, "
f"payload={payload.hex()}"
)
# Check for invalid PID responses.
if ipid:
logger.debug(f"received IPID for PID={pid}")
# Find the appropriate handler.
if is_command:
if pid not in self.command_handlers:
logger.warning(f"no command handler for PID {pid}")
self.send_ipid(transaction_label, pid)
return
command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload))
self.command_handlers[pid](transaction_label, command_frame)
else:
if pid not in self.response_handlers:
logger.warning(f"no response handler for PID {pid}")
return
# By convention, for an ipid, send a None payload to the response handler.
if ipid:
response_frame = None
else:
response_frame = cast(avc.ResponseFrame, avc.Frame.from_bytes(payload))
self.response_handlers[pid](transaction_label, response_frame)
def send_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
):
# TODO: fragment large messages
packet_type = Protocol.PacketType.SINGLE
pdu = (
struct.pack(
">BH",
transaction_label << 4
| packet_type << 2
| (0 if is_command else 1) << 1
| (1 if ipid else 0),
pid,
)
+ payload
)
self.l2cap_channel.send_pdu(pdu)
def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
logger.debug(
">>> AVCTP command: "
f"transaction_label={transaction_label}, "
f"pid={pid}, "
f"payload={payload.hex()}"
)
self.send_message(transaction_label, True, False, pid, payload)
def send_response(self, transaction_label: int, pid: int, payload: bytes):
logger.debug(
">>> AVCTP response: "
f"transaction_label={transaction_label}, "
f"pid={pid}, "
f"payload={payload.hex()}"
)
self.send_message(transaction_label, False, False, pid, payload)
def send_ipid(self, transaction_label: int, pid: int) -> None:
logger.debug(
">>> AVCTP ipid: " f"transaction_label={transaction_label}, " f"pid={pid}"
)
self.send_message(transaction_label, False, True, pid, b'')
def register_command_handler(
self, pid: int, handler: Protocol.CommandHandler
) -> None:
self.command_handlers[pid] = handler
def unregister_command_handler(
self, pid: int, handler: Protocol.CommandHandler
) -> None:
if pid not in self.command_handlers or self.command_handlers[pid] != handler:
raise ValueError("command handler not registered")
del self.command_handlers[pid]
def register_response_handler(
self, pid: int, handler: Protocol.ResponseHandler
) -> None:
self.response_handlers[pid] = handler
def unregister_response_handler(
self, pid: int, handler: Protocol.ResponseHandler
) -> None:
if pid not in self.response_handlers or self.response_handlers[pid] != handler:
raise ValueError("response handler not registered")
del self.response_handlers[pid]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -19,7 +19,6 @@ from __future__ import annotations
import logging
import asyncio
import dataclasses
import itertools
import random
import struct
@@ -43,7 +42,6 @@ from bumble.hci import (
HCI_LE_1M_PHY,
HCI_SUCCESS,
HCI_UNKNOWN_HCI_COMMAND_ERROR,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
HCI_VERSION_BLUETOOTH_CORE_5_0,
Address,
@@ -55,21 +53,17 @@ from bumble.hci import (
HCI_Connection_Request_Event,
HCI_Disconnection_Complete_Event,
HCI_Encryption_Change_Event,
HCI_Synchronous_Connection_Complete_Event,
HCI_LE_Advertising_Report_Event,
HCI_LE_CIS_Established_Event,
HCI_LE_CIS_Request_Event,
HCI_LE_Connection_Complete_Event,
HCI_LE_Read_Remote_Features_Complete_Event,
HCI_Number_Of_Completed_Packets_Event,
HCI_Packet,
HCI_Role_Change_Event,
)
from typing import Optional, Union, Dict, Any, TYPE_CHECKING
from typing import Optional, Union, Dict, TYPE_CHECKING
if TYPE_CHECKING:
from bumble.link import LocalLink
from bumble.transport.common import TransportSink
from bumble.transport.common import TransportSink, TransportSource
# -----------------------------------------------------------------------------
# Logging
@@ -85,27 +79,15 @@ class DataObject:
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class CisLink:
handle: int
cis_id: int
cig_id: int
acl_connection: Optional[Connection] = None
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class Connection:
controller: Controller
handle: int
role: int
peer_address: Address
link: Any
transport: int
link_type: int
def __post_init__(self):
def __init__(self, controller, handle, role, peer_address, link, transport):
self.controller = controller
self.handle = handle
self.role = role
self.peer_address = peer_address
self.link = link
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport
def on_hci_acl_data_packet(self, packet):
self.assembler.feed_packet(packet)
@@ -124,10 +106,10 @@ class Connection:
class Controller:
def __init__(
self,
name: str,
name,
host_source=None,
host_sink: Optional[TransportSink] = None,
link: Optional[LocalLink] = None,
link=None,
public_address: Optional[Union[bytes, str, Address]] = None,
):
self.name = name
@@ -143,8 +125,6 @@ class Controller:
self.classic_connections: Dict[
Address, Connection
] = {} # Connections in BR/EDR
self.central_cis_links: Dict[int, CisLink] = {} # CIS links by handle
self.peripheral_cis_links: Dict[int, CisLink] = {} # CIS links by handle
self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0
self.hci_revision = 0
@@ -154,14 +134,12 @@ class Controller:
'0000000060000000'
) # BR/EDR Not Supported, LE Supported (Controller)
self.manufacturer_name = 0xFFFF
self.hc_data_packet_length = 27
self.hc_total_num_data_packets = 64
self.hc_le_data_packet_length = 27
self.hc_total_num_le_data_packets = 64
self.event_mask = 0
self.event_mask_page_2 = 0
self.supported_commands = bytes.fromhex(
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
'2000800000c000000000e40000002822000000000000040000f7ffff7f000000'
'30f0f9ff01008004000000000000000000000000000000000000000000000000'
)
self.le_event_mask = 0
@@ -323,7 +301,7 @@ class Controller:
############################################################
# Link connections
############################################################
def allocate_connection_handle(self) -> int:
def allocate_connection_handle(self):
handle = 0
max_handle = 0
for connection in itertools.chain(
@@ -335,13 +313,6 @@ class Controller:
if connection.handle == handle:
# Already used, continue searching after the current max
handle = max_handle + 1
for cis_handle in itertools.chain(
self.central_cis_links.keys(), self.peripheral_cis_links.keys()
):
max_handle = max(max_handle, cis_handle)
if cis_handle == handle:
# Already used, continue searching after the current max
handle = max_handle + 1
return handle
def find_le_connection_by_address(self, address):
@@ -386,13 +357,12 @@ class Controller:
if connection is None:
connection_handle = self.allocate_connection_handle()
connection = Connection(
controller=self,
handle=connection_handle,
role=BT_PERIPHERAL_ROLE,
peer_address=peer_address,
link=self.link,
transport=BT_LE_TRANSPORT,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
self,
connection_handle,
BT_PERIPHERAL_ROLE,
peer_address,
self.link,
BT_LE_TRANSPORT,
)
self.peripheral_connections[peer_address] = connection
logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}')
@@ -446,13 +416,12 @@ class Controller:
if connection is None:
connection_handle = self.allocate_connection_handle()
connection = Connection(
controller=self,
handle=connection_handle,
role=BT_CENTRAL_ROLE,
peer_address=peer_address,
link=self.link,
transport=BT_LE_TRANSPORT,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
self,
connection_handle,
BT_CENTRAL_ROLE,
peer_address,
self.link,
BT_LE_TRANSPORT,
)
self.central_connections[peer_address] = connection
logger.debug(
@@ -569,104 +538,6 @@ class Controller:
)
self.send_hci_packet(HCI_LE_Advertising_Report_Event([report]))
def on_link_cis_request(
self, central_address: Address, cig_id: int, cis_id: int
) -> None:
'''
Called when an incoming CIS request occurs from a central on the link
'''
connection = self.peripheral_connections.get(central_address)
assert connection
pending_cis_link = CisLink(
handle=self.allocate_connection_handle(),
cis_id=cis_id,
cig_id=cig_id,
acl_connection=connection,
)
self.peripheral_cis_links[pending_cis_link.handle] = pending_cis_link
self.send_hci_packet(
HCI_LE_CIS_Request_Event(
acl_connection_handle=connection.handle,
cis_connection_handle=pending_cis_link.handle,
cig_id=cig_id,
cis_id=cis_id,
)
)
def on_link_cis_established(self, cig_id: int, cis_id: int) -> None:
'''
Called when an incoming CIS established.
'''
cis_link = next(
cis_link
for cis_link in itertools.chain(
self.central_cis_links.values(), self.peripheral_cis_links.values()
)
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
)
self.send_hci_packet(
HCI_LE_CIS_Established_Event(
status=HCI_SUCCESS,
connection_handle=cis_link.handle,
# CIS parameters are ignored.
cig_sync_delay=0,
cis_sync_delay=0,
transport_latency_c_to_p=0,
transport_latency_p_to_c=0,
phy_c_to_p=0,
phy_p_to_c=0,
nse=0,
bn_c_to_p=0,
bn_p_to_c=0,
ft_c_to_p=0,
ft_p_to_c=0,
max_pdu_c_to_p=0,
max_pdu_p_to_c=0,
iso_interval=0,
)
)
def on_link_cis_disconnected(self, cig_id: int, cis_id: int) -> None:
'''
Called when a CIS disconnected.
'''
if cis_link := next(
(
cis_link
for cis_link in self.peripheral_cis_links.values()
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
),
None,
):
# Remove peripheral CIS on disconnection.
self.peripheral_cis_links.pop(cis_link.handle)
elif cis_link := next(
(
cis_link
for cis_link in self.central_cis_links.values()
if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id
),
None,
):
# Keep central CIS on disconnection. They should be removed by HCI_LE_Remove_CIG_Command.
cis_link.acl_connection = None
else:
return
self.send_hci_packet(
HCI_Disconnection_Complete_Event(
status=HCI_SUCCESS,
connection_handle=cis_link.handle,
reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
)
)
############################################################
# Classic link connections
############################################################
@@ -695,7 +566,6 @@ class Controller:
peer_address=peer_address,
link=self.link,
transport=BT_BR_EDR_TRANSPORT,
link_type=HCI_Connection_Complete_Event.ACL_LINK_TYPE,
)
self.classic_connections[peer_address] = connection
logger.debug(
@@ -749,42 +619,6 @@ class Controller:
)
)
def on_classic_sco_connection_complete(
self, peer_address: Address, status: int, link_type: int
):
if status == HCI_SUCCESS:
# Allocate (or reuse) a connection handle
connection_handle = self.allocate_connection_handle()
connection = Connection(
controller=self,
handle=connection_handle,
# Role doesn't matter in SCO.
role=BT_CENTRAL_ROLE,
peer_address=peer_address,
link=self.link,
transport=BT_BR_EDR_TRANSPORT,
link_type=link_type,
)
self.classic_connections[peer_address] = connection
logger.debug(f'New SCO connection handle: 0x{connection_handle:04X}')
else:
connection_handle = 0
self.send_hci_packet(
HCI_Synchronous_Connection_Complete_Event(
status=status,
connection_handle=connection_handle,
bd_addr=peer_address,
link_type=link_type,
# TODO: Provide SCO connection parameters.
transmission_interval=0,
retransmission_window=0,
rx_packet_length=0,
tx_packet_length=0,
air_mode=0,
)
)
############################################################
# Advertising support
############################################################
@@ -887,17 +721,6 @@ class Controller:
else:
# Remove the connection
del self.classic_connections[connection.peer_address]
elif cis_link := (
self.central_cis_links.get(handle) or self.peripheral_cis_links.get(handle)
):
if self.link:
self.link.disconnect_cis(
initiator_controller=self,
peer_address=cis_link.acl_connection.peer_address,
cig_id=cis_link.cig_id,
cis_id=cis_link.cis_id,
)
# Spec requires handle to be kept after disconnection.
def on_hci_accept_connection_request_command(self, command):
'''
@@ -915,68 +738,6 @@ class Controller:
)
self.link.classic_accept_connection(self, command.bd_addr, command.role)
def on_hci_enhanced_setup_synchronous_connection_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.1.45 Enhanced Setup Synchronous Connection command
'''
if self.link is None:
return
if not (
connection := self.find_classic_connection_by_handle(
command.connection_handle
)
):
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
return
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_SUCCESS,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
self.link.classic_sco_connect(
self, connection.peer_address, HCI_Connection_Complete_Event.ESCO_LINK_TYPE
)
def on_hci_enhanced_accept_synchronous_connection_request_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.1.46 Enhanced Accept Synchronous Connection Request command
'''
if self.link is None:
return
if not (connection := self.find_classic_connection_by_address(command.bd_addr)):
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
return
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_SUCCESS,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
self.link.classic_accept_sco_connection(
self, connection.peer_address, HCI_Connection_Complete_Event.ESCO_LINK_TYPE
)
def on_hci_switch_role_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.2.8 Switch Role command
@@ -1151,41 +912,7 @@ class Controller:
'''
See Bluetooth spec Vol 4, Part E - 7.4.3 Read Local Supported Features Command
'''
return bytes([HCI_SUCCESS]) + self.lmp_features[:8]
def on_hci_read_local_extended_features_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.4.4 Read Local Extended Features Command
'''
if command.page_number * 8 > len(self.lmp_features):
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
return (
bytes(
[
# Status
HCI_SUCCESS,
# Page number
command.page_number,
# Max page number
len(self.lmp_features) // 8 - 1,
]
)
# Features of the current page
+ self.lmp_features[command.page_number * 8 : (command.page_number + 1) * 8]
)
def on_hci_read_buffer_size_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.4.5 Read Buffer Size Command
'''
return struct.pack(
'<BHBHH',
HCI_SUCCESS,
self.hc_data_packet_length,
0,
self.hc_total_num_data_packets,
0,
)
return bytes([HCI_SUCCESS]) + self.lmp_features
def on_hci_read_bd_addr_command(self, _command):
'''
@@ -1273,9 +1000,6 @@ class Controller:
'''
See Bluetooth spec Vol 4, Part E - 7.8.10 LE Set Scan Parameters Command
'''
if self.le_scan_enable:
return bytes([HCI_COMMAND_DISALLOWED_ERROR])
self.le_scan_type = command.le_scan_type
self.le_scan_interval = command.le_scan_interval
self.le_scan_window = command.le_scan_window
@@ -1362,18 +1086,6 @@ class Controller:
See Bluetooth spec Vol 4, Part E - 7.8.21 LE Read Remote Features Command
'''
handle = command.connection_handle
if not self.find_connection_by_handle(handle):
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
return
# First, say that the command is pending
self.send_hci_packet(
HCI_Command_Status_Event(
@@ -1387,7 +1099,7 @@ class Controller:
self.send_hci_packet(
HCI_LE_Read_Remote_Features_Complete_Event(
status=HCI_SUCCESS,
connection_handle=handle,
connection_handle=0,
le_features=bytes.fromhex('dd40000000000000'),
)
)
@@ -1543,135 +1255,8 @@ class Controller:
}
return bytes([HCI_SUCCESS])
def on_hci_le_read_maximum_advertising_data_length_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.57 LE Read Maximum Advertising Data
Length Command
'''
return struct.pack('<BH', HCI_SUCCESS, 0x0672)
def on_hci_le_read_number_of_supported_advertising_sets_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.58 LE Read Number of Supported
Advertising Set Command
'''
return struct.pack('<BB', HCI_SUCCESS, 0xF0)
def on_hci_le_read_transmit_power_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command
'''
return struct.pack('<BBB', HCI_SUCCESS, 0, 0)
def on_hci_le_set_cig_parameters_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.97 LE Set CIG Parameter Command
'''
# Remove old CIG implicitly.
for handle, cis_link in self.central_cis_links.items():
if cis_link.cig_id == command.cig_id:
self.central_cis_links.pop(handle)
handles = []
for cis_id in command.cis_id:
handle = self.allocate_connection_handle()
handles.append(handle)
self.central_cis_links[handle] = CisLink(
cis_id=cis_id,
cig_id=command.cig_id,
handle=handle,
)
return struct.pack(
'<BBB', HCI_SUCCESS, command.cig_id, len(handles)
) + b''.join([struct.pack('<H', handle) for handle in handles])
def on_hci_le_create_cis_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.99 LE Create CIS Command
'''
if not self.link:
return
for cis_handle, acl_handle in zip(
command.cis_connection_handle, command.acl_connection_handle
):
if not (connection := self.find_connection_by_handle(acl_handle)):
logger.error(f'Cannot find connection with handle={acl_handle}')
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
if not (cis_link := self.central_cis_links.get(cis_handle)):
logger.error(f'Cannot find CIS with handle={cis_handle}')
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
cis_link.acl_connection = connection
self.link.create_cis(
self,
peripheral_address=connection.peer_address,
cig_id=cis_link.cig_id,
cis_id=cis_link.cis_id,
)
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
def on_hci_le_remove_cig_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.100 LE Remove CIG Command
'''
status = HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR
for cis_handle, cis_link in self.central_cis_links.items():
if cis_link.cig_id == command.cig_id:
self.central_cis_links.pop(cis_handle)
status = HCI_SUCCESS
return struct.pack('<BH', status, command.cig_id)
def on_hci_le_accept_cis_request_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.101 LE Accept CIS Request Command
'''
if not self.link:
return
if not (
pending_cis_link := self.peripheral_cis_links.get(command.connection_handle)
):
logger.error(f'Cannot find CIS with handle={command.connection_handle}')
return bytes([HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR])
assert pending_cis_link.acl_connection
self.link.accept_cis(
peripheral_controller=self,
central_address=pending_cis_link.acl_connection.peer_address,
cig_id=pending_cis_link.cig_id,
cis_id=pending_cis_link.cis_id,
)
self.send_hci_packet(
HCI_Command_Status_Event(
status=HCI_COMMAND_STATUS_PENDING,
num_hci_command_packets=1,
command_opcode=command.op_code,
)
)
def on_hci_le_setup_iso_data_path_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.109 LE Setup ISO Data Path Command
'''
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
def on_hci_le_remove_iso_data_path_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.110 LE Remove ISO Data Path Command
'''
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)

View File

@@ -16,7 +16,6 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import struct
from typing import List, Optional, Tuple, Union, cast, Dict
@@ -97,16 +96,12 @@ class BaseError(Exception):
namespace = f'{self.error_namespace}/'
else:
namespace = ''
have_name = self.error_name != ''
have_code = self.error_code is not None
if have_name and have_code:
error_text = f'{self.error_name} [0x{self.error_code:X}]'
elif have_name and not have_code:
error_text = self.error_name
elif not have_name and have_code:
error_text = f'0x{self.error_code:X}'
else:
error_text = '<unspecified>'
error_text = {
(True, True): f'{self.error_name} [0x{self.error_code:X}]',
(True, False): self.error_name,
(False, True): f'0x{self.error_code:X}',
(False, False): '',
}[(self.error_name != '', self.error_code is not None)]
return f'{type(self).__name__}({namespace}{error_text})'
@@ -323,7 +318,7 @@ BT_HIDP_PROTOCOL_ID = UUID.from_16_bits(0x0011, 'HIDP')
BT_HARDCOPY_CONTROL_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x0012, 'HardcopyControlChannel')
BT_HARDCOPY_DATA_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x0014, 'HardcopyDataChannel')
BT_HARDCOPY_NOTIFICATION_PROTOCOL_ID = UUID.from_16_bits(0x0016, 'HardcopyNotification')
BT_AVCTP_PROTOCOL_ID = UUID.from_16_bits(0x0017, 'AVCTP')
BT_AVTCP_PROTOCOL_ID = UUID.from_16_bits(0x0017, 'AVCTP')
BT_AVDTP_PROTOCOL_ID = UUID.from_16_bits(0x0019, 'AVDTP')
BT_CMTP_PROTOCOL_ID = UUID.from_16_bits(0x001B, 'CMTP')
BT_MCAP_CONTROL_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x001E, 'MCAPControlChannel')
@@ -825,8 +820,8 @@ class AdvertisingData:
ad_structures = []
self.ad_structures = ad_structures[:]
@classmethod
def from_bytes(cls, data: bytes) -> AdvertisingData:
@staticmethod
def from_bytes(data):
instance = AdvertisingData()
instance.append(data)
return instance
@@ -982,7 +977,7 @@ class AdvertisingData:
return ad_data
def append(self, data: bytes) -> None:
def append(self, data):
offset = 0
while offset + 1 < len(data):
length = data[offset]
@@ -1056,13 +1051,3 @@ class ConnectionPHY:
def __str__(self):
return f'ConnectionPHY(tx_phy={self.tx_phy}, rx_phy={self.rx_phy})'
# -----------------------------------------------------------------------------
# LE Role
# -----------------------------------------------------------------------------
class LeRole(enum.IntEnum):
PERIPHERAL_ONLY = 0x00
CENTRAL_ONLY = 0x01
BOTH_PERIPHERAL_PREFERRED = 0x02
BOTH_CENTRAL_PREFERRED = 0x03

View File

@@ -21,8 +21,6 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import logging
import operator
@@ -31,13 +29,11 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.asymmetric.ec import (
generate_private_key,
ECDH,
EllipticCurvePrivateKey,
EllipticCurvePublicNumbers,
EllipticCurvePrivateNumbers,
SECP256R1,
)
from cryptography.hazmat.primitives import cmac
from typing import Tuple
# -----------------------------------------------------------------------------
@@ -50,18 +46,16 @@ logger = logging.getLogger(__name__)
# Classes
# -----------------------------------------------------------------------------
class EccKey:
def __init__(self, private_key: EllipticCurvePrivateKey) -> None:
def __init__(self, private_key):
self.private_key = private_key
@classmethod
def generate(cls) -> EccKey:
def generate(cls):
private_key = generate_private_key(SECP256R1())
return cls(private_key)
@classmethod
def from_private_key_bytes(
cls, d_bytes: bytes, x_bytes: bytes, y_bytes: bytes
) -> EccKey:
def from_private_key_bytes(cls, d_bytes, x_bytes, y_bytes):
d = int.from_bytes(d_bytes, byteorder='big', signed=False)
x = int.from_bytes(x_bytes, byteorder='big', signed=False)
y = int.from_bytes(y_bytes, byteorder='big', signed=False)
@@ -71,7 +65,7 @@ class EccKey:
return cls(private_key)
@property
def x(self) -> bytes:
def x(self):
return (
self.private_key.public_key()
.public_numbers()
@@ -79,14 +73,14 @@ class EccKey:
)
@property
def y(self) -> bytes:
def y(self):
return (
self.private_key.public_key()
.public_numbers()
.y.to_bytes(32, byteorder='big')
)
def dh(self, public_key_x: bytes, public_key_y: bytes) -> bytes:
def dh(self, public_key_x, public_key_y):
x = int.from_bytes(public_key_x, byteorder='big', signed=False)
y = int.from_bytes(public_key_y, byteorder='big', signed=False)
public_key = EllipticCurvePublicNumbers(x, y, SECP256R1()).public_key()
@@ -99,33 +93,14 @@ class EccKey:
# Functions
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
def generate_prand() -> bytes:
'''Generates random 3 bytes, with the 2 most significant bits of 0b01.
See Bluetooth spec, Vol 6, Part E - Table 1.2.
'''
prand_bytes = secrets.token_bytes(6)
return prand_bytes[:2] + bytes([(prand_bytes[2] & 0b01111111) | 0b01000000])
# -----------------------------------------------------------------------------
def xor(x: bytes, y: bytes) -> bytes:
def xor(x, y):
assert len(x) == len(y)
return bytes(map(operator.xor, x, y))
# -----------------------------------------------------------------------------
def reverse(input: bytes) -> bytes:
'''
Returns bytes of input in reversed endianness.
'''
return input[::-1]
# -----------------------------------------------------------------------------
def r() -> bytes:
def r():
'''
Generate 16 bytes of random data
'''
@@ -133,20 +108,20 @@ def r() -> bytes:
# -----------------------------------------------------------------------------
def e(key: bytes, data: bytes) -> bytes:
def e(key, data):
'''
AES-128 ECB, expecting byte-swapped inputs and producing a byte-swapped output.
See Bluetooth spec Vol 3, Part H - 2.2.1 Security function e
'''
cipher = Cipher(algorithms.AES(reverse(key)), modes.ECB())
cipher = Cipher(algorithms.AES(bytes(reversed(key))), modes.ECB())
encryptor = cipher.encryptor()
return reverse(encryptor.update(reverse(data)))
return bytes(reversed(encryptor.update(bytes(reversed(data)))))
# -----------------------------------------------------------------------------
def ah(k: bytes, r: bytes) -> bytes: # pylint: disable=redefined-outer-name
def ah(k, r): # pylint: disable=redefined-outer-name
'''
See Bluetooth spec Vol 3, Part H - 2.2.2 Random Address Hash function ah
'''
@@ -157,16 +132,7 @@ def ah(k: bytes, r: bytes) -> bytes: # pylint: disable=redefined-outer-name
# -----------------------------------------------------------------------------
def c1(
k: bytes,
r: bytes,
preq: bytes,
pres: bytes,
iat: int,
rat: int,
ia: bytes,
ra: bytes,
) -> bytes: # pylint: disable=redefined-outer-name
def c1(k, r, preq, pres, iat, rat, ia, ra): # pylint: disable=redefined-outer-name
'''
See Bluetooth spec, Vol 3, Part H - 2.2.3 Confirm value generation function c1 for
LE Legacy Pairing
@@ -178,7 +144,7 @@ def c1(
# -----------------------------------------------------------------------------
def s1(k: bytes, r1: bytes, r2: bytes) -> bytes:
def s1(k, r1, r2):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.4 Key generation function s1 for LE Legacy
Pairing
@@ -188,7 +154,7 @@ def s1(k: bytes, r1: bytes, r2: bytes) -> bytes:
# -----------------------------------------------------------------------------
def aes_cmac(m: bytes, k: bytes) -> bytes:
def aes_cmac(m, k):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.5 FunctionAES-CMAC
@@ -200,16 +166,20 @@ def aes_cmac(m: bytes, k: bytes) -> bytes:
# -----------------------------------------------------------------------------
def f4(u: bytes, v: bytes, x: bytes, z: bytes) -> bytes:
def f4(u, v, x, z):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.6 LE Secure Connections Confirm Value
Generation Function f4
'''
return reverse(aes_cmac(reverse(u) + reverse(v) + z, reverse(x)))
return bytes(
reversed(
aes_cmac(bytes(reversed(u)) + bytes(reversed(v)) + z, bytes(reversed(x)))
)
)
# -----------------------------------------------------------------------------
def f5(w: bytes, n1: bytes, n2: bytes, a1: bytes, a2: bytes) -> Tuple[bytes, bytes]:
def f5(w, n1, n2, a1, a2):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.7 LE Secure Connections Key Generation
Function f5
@@ -217,83 +187,87 @@ def f5(w: bytes, n1: bytes, n2: bytes, a1: bytes, a2: bytes) -> Tuple[bytes, byt
NOTE: this returns a tuple: (MacKey, LTK) in little-endian byte order
'''
salt = bytes.fromhex('6C888391AAF5A53860370BDB5A6083BE')
t = aes_cmac(reverse(w), salt)
t = aes_cmac(bytes(reversed(w)), salt)
key_id = bytes([0x62, 0x74, 0x6C, 0x65])
return (
reverse(
aes_cmac(
bytes([0])
+ key_id
+ reverse(n1)
+ reverse(n2)
+ reverse(a1)
+ reverse(a2)
+ bytes([1, 0]),
t,
bytes(
reversed(
aes_cmac(
bytes([0])
+ key_id
+ bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(a1))
+ bytes(reversed(a2))
+ bytes([1, 0]),
t,
)
)
),
reverse(
aes_cmac(
bytes([1])
+ key_id
+ reverse(n1)
+ reverse(n2)
+ reverse(a1)
+ reverse(a2)
+ bytes([1, 0]),
t,
bytes(
reversed(
aes_cmac(
bytes([1])
+ key_id
+ bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(a1))
+ bytes(reversed(a2))
+ bytes([1, 0]),
t,
)
)
),
)
# -----------------------------------------------------------------------------
def f6(
w: bytes, n1: bytes, n2: bytes, r: bytes, io_cap: bytes, a1: bytes, a2: bytes
) -> bytes: # pylint: disable=redefined-outer-name
def f6(w, n1, n2, r, io_cap, a1, a2): # pylint: disable=redefined-outer-name
'''
See Bluetooth spec, Vol 3, Part H - 2.2.8 LE Secure Connections Check Value
Generation Function f6
'''
return reverse(
aes_cmac(
reverse(n1)
+ reverse(n2)
+ reverse(r)
+ reverse(io_cap)
+ reverse(a1)
+ reverse(a2),
reverse(w),
return bytes(
reversed(
aes_cmac(
bytes(reversed(n1))
+ bytes(reversed(n2))
+ bytes(reversed(r))
+ bytes(reversed(io_cap))
+ bytes(reversed(a1))
+ bytes(reversed(a2)),
bytes(reversed(w)),
)
)
)
# -----------------------------------------------------------------------------
def g2(u: bytes, v: bytes, x: bytes, y: bytes) -> int:
def g2(u, v, x, y):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.9 LE Secure Connections Numeric Comparison
Value Generation Function g2
'''
return int.from_bytes(
aes_cmac(
reverse(u) + reverse(v) + reverse(y),
reverse(x),
bytes(reversed(u)) + bytes(reversed(v)) + bytes(reversed(y)),
bytes(reversed(x)),
)[-4:],
byteorder='big',
)
# -----------------------------------------------------------------------------
def h6(w: bytes, key_id: bytes) -> bytes:
def h6(w, key_id):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.10 Link key conversion function h6
'''
return reverse(aes_cmac(key_id, reverse(w)))
return aes_cmac(key_id, w)
# -----------------------------------------------------------------------------
def h7(salt: bytes, w: bytes) -> bytes:
def h7(salt, w):
'''
See Bluetooth spec, Vol 3, Part H - 2.2.11 Link key conversion function h7
'''
return reverse(aes_cmac(reverse(w), salt))
return aes_cmac(w, salt)

File diff suppressed because it is too large Load Diff

View File

@@ -19,17 +19,12 @@ like loading firmware after a cold start.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import abc
import logging
import pathlib
import platform
from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING
from . import rtk
from . import rtk, intel
from .common import Driver
if TYPE_CHECKING:
from bumble.host import Host
# -----------------------------------------------------------------------------
# Logging
@@ -37,31 +32,40 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Driver(abc.ABC):
"""Base class for drivers."""
@staticmethod
async def for_host(_host):
"""Return a driver instance for a host.
Args:
host: Host object for which a driver should be created.
Returns:
A Driver instance if a driver should be instantiated for this host, or
None if no driver instance of this class is needed.
"""
return None
@abc.abstractmethod
async def init_controller(self):
"""Initialize the controller."""
# -----------------------------------------------------------------------------
# Functions
# -----------------------------------------------------------------------------
async def get_driver_for_host(host: Host) -> Optional[Driver]:
"""Probe diver classes until one returns a valid instance for a host, or none is
found.
If a "driver" HCI metadata entry is present, only that driver class will be probed.
async def get_driver_for_host(host):
"""Probe all known diver classes until one returns a valid instance for a host,
or none is found.
"""
driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver, "intel": intel.Driver}
probe_list: Iterable[str]
if driver_name := host.hci_metadata.get("driver"):
# Only probe a single driver
probe_list = [driver_name]
else:
# Probe all drivers
probe_list = driver_classes.keys()
for driver_name in probe_list:
if driver_class := driver_classes.get(driver_name):
logger.debug(f"Probing driver class: {driver_name}")
if driver := await driver_class.for_host(host):
logger.debug(f"Instantiated {driver_name} driver")
return driver
else:
logger.debug(f"Skipping unknown driver class: {driver_name}")
if driver := await rtk.Driver.for_host(host):
logger.debug("Instantiated RTK driver")
return driver
return None

View File

@@ -1,45 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common types for drivers.
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import abc
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Driver(abc.ABC):
"""Base class for drivers."""
@staticmethod
async def for_host(_host):
"""Return a driver instance for a host.
Args:
host: Host object for which a driver should be created.
Returns:
A Driver instance if a driver should be instantiated for this host, or
None if no driver instance of this class is needed.
"""
return None
@abc.abstractmethod
async def init_controller(self):
"""Initialize the controller."""

View File

@@ -1,102 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import logging
from bumble.drivers import common
from bumble.hci import (
hci_vendor_command_op_code, # type: ignore
HCI_Command,
HCI_Reset_Command,
)
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constant
# -----------------------------------------------------------------------------
INTEL_USB_PRODUCTS = {
# Intel AX210
(0x8087, 0x0032),
# Intel BE200
(0x8087, 0x0036),
}
# -----------------------------------------------------------------------------
# HCI Commands
# -----------------------------------------------------------------------------
HCI_INTEL_DDC_CONFIG_WRITE_COMMAND = hci_vendor_command_op_code(0xFC8B) # type: ignore
HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD = [0x03, 0xE4, 0x02, 0x00]
HCI_Command.register_commands(globals())
@HCI_Command.command( # type: ignore
fields=[("params", "*")],
return_parameters_fields=[
("params", "*"),
],
)
class Hci_Intel_DDC_Config_Write_Command(HCI_Command):
pass
class Driver(common.Driver):
def __init__(self, host):
self.host = host
@staticmethod
def check(host):
driver = host.hci_metadata.get("driver")
if driver == "intel":
return True
vendor_id = host.hci_metadata.get("vendor_id")
product_id = host.hci_metadata.get("product_id")
if vendor_id is None or product_id is None:
logger.debug("USB metadata not sufficient")
return False
if (vendor_id, product_id) not in INTEL_USB_PRODUCTS:
logger.debug(
f"USB device ({vendor_id:04X}, {product_id:04X}) " "not in known list"
)
return False
return True
@classmethod
async def for_host(cls, host, force=False): # type: ignore
# Only instantiate this driver if explicitly selected
if not force and not cls.check(host):
return None
return cls(host)
async def init_controller(self):
self.host.ready = True
await self.host.send_command(HCI_Reset_Command(), check_result=True)
await self.host.send_command(
Hci_Intel_DDC_Config_Write_Command(
params=HCI_INTEL_DDC_CONFIG_WRITE_PAYLOAD
)
)

View File

@@ -41,7 +41,7 @@ from bumble.hci import (
HCI_Reset_Command,
HCI_Read_Local_Version_Information_Command,
)
from bumble.drivers import common
# -----------------------------------------------------------------------------
# Logging
@@ -285,7 +285,7 @@ class Firmware:
)
class Driver(common.Driver):
class Driver:
@dataclass
class DriverInfo:
rom: int
@@ -470,12 +470,8 @@ class Driver(common.Driver):
logger.debug("USB metadata not found")
return False
if host.hci_metadata.get('driver') == 'rtk':
# Forced driver
return True
vendor_id = host.hci_metadata.get("vendor_id")
product_id = host.hci_metadata.get("product_id")
vendor_id = host.hci_metadata.get("vendor_id", None)
product_id = host.hci_metadata.get("product_id", None)
if vendor_id is None or product_id is None:
logger.debug("USB metadata not sufficient")
return False
@@ -490,9 +486,6 @@ class Driver(common.Driver):
@classmethod
async def driver_info_for_host(cls, host):
await host.send_command(HCI_Reset_Command(), check_result=True)
host.ready = True # Needed to let the host know the controller is ready.
response = await host.send_command(
HCI_Read_Local_Version_Information_Command(), check_result=True
)

View File

@@ -23,28 +23,16 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import enum
import functools
import logging
import struct
from typing import (
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Union,
TYPE_CHECKING,
)
from typing import Optional, Sequence, Iterable, List, Union
from bumble.colors import color
from bumble.core import UUID
from bumble.att import Attribute, AttributeValue
if TYPE_CHECKING:
from bumble.gatt_client import AttributeProxy
from bumble.device import Connection
from .colors import color
from .core import UUID, get_dict_key_by_value
from .att import Attribute
# -----------------------------------------------------------------------------
@@ -105,35 +93,20 @@ GATT_RECONNECTION_CONFIGURATION_SERVICE = UUID.from_16_bits(0x1829, 'Reconne
GATT_INSULIN_DELIVERY_SERVICE = UUID.from_16_bits(0x183A, 'Insulin Delivery')
GATT_BINARY_SENSOR_SERVICE = UUID.from_16_bits(0x183B, 'Binary Sensor')
GATT_EMERGENCY_CONFIGURATION_SERVICE = UUID.from_16_bits(0x183C, 'Emergency Configuration')
GATT_AUTHORIZATION_CONTROL_SERVICE = UUID.from_16_bits(0x183D, 'Authorization Control')
GATT_PHYSICAL_ACTIVITY_MONITOR_SERVICE = UUID.from_16_bits(0x183E, 'Physical Activity Monitor')
GATT_ELAPSED_TIME_SERVICE = UUID.from_16_bits(0x183F, 'Elapsed Time')
GATT_GENERIC_HEALTH_SENSOR_SERVICE = UUID.from_16_bits(0x1840, 'Generic Health Sensor')
GATT_AUDIO_INPUT_CONTROL_SERVICE = UUID.from_16_bits(0x1843, 'Audio Input Control')
GATT_VOLUME_CONTROL_SERVICE = UUID.from_16_bits(0x1844, 'Volume Control')
GATT_VOLUME_OFFSET_CONTROL_SERVICE = UUID.from_16_bits(0x1845, 'Volume Offset Control')
GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification')
GATT_COORDINATED_SET_IDENTIFICATION_SERVICE = UUID.from_16_bits(0x1846, 'Coordinated Set Identification Service')
GATT_DEVICE_TIME_SERVICE = UUID.from_16_bits(0x1847, 'Device Time')
GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control')
GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control')
GATT_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1848, 'Media Control Service')
GATT_GENERIC_MEDIA_CONTROL_SERVICE = UUID.from_16_bits(0x1849, 'Generic Media Control Service')
GATT_CONSTANT_TONE_EXTENSION_SERVICE = UUID.from_16_bits(0x184A, 'Constant Tone Extension')
GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer')
GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer')
GATT_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184B, 'Telephone Bearer Service')
GATT_GENERIC_TELEPHONE_BEARER_SERVICE = UUID.from_16_bits(0x184C, 'Generic Telephone Bearer Service')
GATT_MICROPHONE_CONTROL_SERVICE = UUID.from_16_bits(0x184D, 'Microphone Control')
GATT_AUDIO_STREAM_CONTROL_SERVICE = UUID.from_16_bits(0x184E, 'Audio Stream Control')
GATT_BROADCAST_AUDIO_SCAN_SERVICE = UUID.from_16_bits(0x184F, 'Broadcast Audio Scan')
GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE = UUID.from_16_bits(0x1850, 'Published Audio Capabilities')
GATT_BASIC_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1851, 'Basic Audio Announcement')
GATT_BROADCAST_AUDIO_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1852, 'Broadcast Audio Announcement')
GATT_COMMON_AUDIO_SERVICE = UUID.from_16_bits(0x1853, 'Common Audio')
GATT_HEARING_ACCESS_SERVICE = UUID.from_16_bits(0x1854, 'Hearing Access')
GATT_TELEPHONY_AND_MEDIA_AUDIO_SERVICE = UUID.from_16_bits(0x1855, 'Telephony and Media Audio')
GATT_PUBLIC_BROADCAST_ANNOUNCEMENT_SERVICE = UUID.from_16_bits(0x1856, 'Public Broadcast Announcement')
GATT_ELECTRONIC_SHELF_LABEL_SERVICE = UUID.from_16_bits(0X1857, 'Electronic Shelf Label')
GATT_GAMING_AUDIO_SERVICE = UUID.from_16_bits(0x1858, 'Gaming Audio')
GATT_MESH_PROXY_SOLICITATION_SERVICE = UUID.from_16_bits(0x1859, 'Mesh Audio Solicitation')
# Attribute Types
# Types
GATT_PRIMARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2800, 'Primary Service')
GATT_SECONDARY_SERVICE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2801, 'Secondary Service')
GATT_INCLUDE_ATTRIBUTE_TYPE = UUID.from_16_bits(0x2802, 'Include')
@@ -156,8 +129,6 @@ GATT_ENVIRONMENTAL_SENSING_MEASUREMENT_DESCRIPTOR = UUID.from_16_bits(0x290C,
GATT_ENVIRONMENTAL_SENSING_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290D, 'Environmental Sensing Trigger Setting')
GATT_TIME_TRIGGER_DESCRIPTOR = UUID.from_16_bits(0x290E, 'Time Trigger Setting')
GATT_COMPLETE_BR_EDR_TRANSPORT_BLOCK_DATA_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Complete BR-EDR Transport Block Data')
GATT_OBSERVATION_SCHEDULE_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Observation Schedule')
GATT_VALID_RANGE_AND_ACCURACY_DESCRIPTOR = UUID.from_16_bits(0x290F, 'Valid Range And Accuracy')
# Device Information Service
GATT_SYSTEM_ID_CHARACTERISTIC = UUID.from_16_bits(0x2A23, 'System ID')
@@ -185,96 +156,6 @@ GATT_HEART_RATE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2A39, 'Heart
# Battery Service
GATT_BATTERY_LEVEL_CHARACTERISTIC = UUID.from_16_bits(0x2A19, 'Battery Level')
# Telephony And Media Audio Service (TMAS)
GATT_TMAP_ROLE_CHARACTERISTIC = UUID.from_16_bits(0x2B51, 'TMAP Role')
# Audio Input Control Service (AICS)
GATT_AUDIO_INPUT_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B77, 'Audio Input State')
GATT_GAIN_SETTINGS_ATTRIBUTE_CHARACTERISTIC = UUID.from_16_bits(0x2B78, 'Gain Settings Attribute')
GATT_AUDIO_INPUT_TYPE_CHARACTERISTIC = UUID.from_16_bits(0x2B79, 'Audio Input Type')
GATT_AUDIO_INPUT_STATUS_CHARACTERISTIC = UUID.from_16_bits(0x2B7A, 'Audio Input Status')
GATT_AUDIO_INPUT_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B7B, 'Audio Input Control Point')
GATT_AUDIO_INPUT_DESCRIPTION_CHARACTERISTIC = UUID.from_16_bits(0x2B7C, 'Audio Input Description')
# Volume Control Service (VCS)
GATT_VOLUME_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B7D, 'Volume State')
GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B7E, 'Volume Control Point')
GATT_VOLUME_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2B7F, 'Volume Flags')
# Volume Offset Control Service (VOCS)
GATT_VOLUME_OFFSET_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2B80, 'Volume Offset State')
GATT_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2B81, 'Audio Location')
GATT_VOLUME_OFFSET_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2B82, 'Volume Offset Control Point')
GATT_AUDIO_OUTPUT_DESCRIPTION_CHARACTERISTIC = UUID.from_16_bits(0x2B83, 'Audio Output Description')
# Coordinated Set Identification Service (CSIS)
GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC = UUID.from_16_bits(0x2B84, 'Set Identity Resolving Key')
GATT_COORDINATED_SET_SIZE_CHARACTERISTIC = UUID.from_16_bits(0x2B85, 'Coordinated Set Size')
GATT_SET_MEMBER_LOCK_CHARACTERISTIC = UUID.from_16_bits(0x2B86, 'Set Member Lock')
GATT_SET_MEMBER_RANK_CHARACTERISTIC = UUID.from_16_bits(0x2B87, 'Set Member Rank')
# Media Control Service (MCS)
GATT_MEDIA_PLAYER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2B93, 'Media Player Name')
GATT_MEDIA_PLAYER_ICON_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B94, 'Media Player Icon Object ID')
GATT_MEDIA_PLAYER_ICON_URL_CHARACTERISTIC = UUID.from_16_bits(0x2B95, 'Media Player Icon URL')
GATT_TRACK_CHANGED_CHARACTERISTIC = UUID.from_16_bits(0x2B96, 'Track Changed')
GATT_TRACK_TITLE_CHARACTERISTIC = UUID.from_16_bits(0x2B97, 'Track Title')
GATT_TRACK_DURATION_CHARACTERISTIC = UUID.from_16_bits(0x2B98, 'Track Duration')
GATT_TRACK_POSITION_CHARACTERISTIC = UUID.from_16_bits(0x2B99, 'Track Position')
GATT_PLAYBACK_SPEED_CHARACTERISTIC = UUID.from_16_bits(0x2B9A, 'Playback Speed')
GATT_SEEKING_SPEED_CHARACTERISTIC = UUID.from_16_bits(0x2B9B, 'Seeking Speed')
GATT_CURRENT_TRACK_SEGMENTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9C, 'Current Track Segments Object ID')
GATT_CURRENT_TRACK_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9D, 'Current Track Object ID')
GATT_NEXT_TRACK_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9E, 'Next Track Object ID')
GATT_PARENT_GROUP_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2B9F, 'Parent Group Object ID')
GATT_CURRENT_GROUP_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA0, 'Current Group Object ID')
GATT_PLAYING_ORDER_CHARACTERISTIC = UUID.from_16_bits(0x2BA1, 'Playing Order')
GATT_PLAYING_ORDERS_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA2, 'Playing Orders Supported')
GATT_MEDIA_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BA3, 'Media State')
GATT_MEDIA_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA4, 'Media Control Point')
GATT_MEDIA_CONTROL_POINT_OPCODES_SUPPORTED_CHARACTERISTIC = UUID.from_16_bits(0x2BA5, 'Media Control Point Opcodes Supported')
GATT_SEARCH_RESULTS_OBJECT_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BA6, 'Search Results Object ID')
GATT_SEARCH_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BA7, 'Search Control Point')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Content Control Id')
# Telephone Bearer Service (TBS)
GATT_BEARER_PROVIDER_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BB4, 'Bearer Provider Name')
GATT_BEARER_UCI_CHARACTERISTIC = UUID.from_16_bits(0x2BB5, 'Bearer UCI')
GATT_BEARER_TECHNOLOGY_CHARACTERISTIC = UUID.from_16_bits(0x2BB6, 'Bearer Technology')
GATT_BEARER_URI_SCHEMES_SUPPORTED_LIST_CHARACTERISTIC = UUID.from_16_bits(0x2BB7, 'Bearer URI Schemes Supported List')
GATT_BEARER_SIGNAL_STRENGTH_CHARACTERISTIC = UUID.from_16_bits(0x2BB8, 'Bearer Signal Strength')
GATT_BEARER_SIGNAL_STRENGTH_REPORTING_INTERVAL_CHARACTERISTIC = UUID.from_16_bits(0x2BB9, 'Bearer Signal Strength Reporting Interval')
GATT_BEARER_LIST_CURRENT_CALLS_CHARACTERISTIC = UUID.from_16_bits(0x2BBA, 'Bearer List Current Calls')
GATT_CONTENT_CONTROL_ID_CHARACTERISTIC = UUID.from_16_bits(0x2BBB, 'Content Control ID')
GATT_STATUS_FLAGS_CHARACTERISTIC = UUID.from_16_bits(0x2BBC, 'Status Flags')
GATT_INCOMING_CALL_TARGET_BEARER_URI_CHARACTERISTIC = UUID.from_16_bits(0x2BBD, 'Incoming Call Target Bearer URI')
GATT_CALL_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BBE, 'Call State')
GATT_CALL_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BBF, 'Call Control Point')
GATT_CALL_CONTROL_POINT_OPTIONAL_OPCODES_CHARACTERISTIC = UUID.from_16_bits(0x2BC0, 'Call Control Point Optional Opcodes')
GATT_TERMINATION_REASON_CHARACTERISTIC = UUID.from_16_bits(0x2BC1, 'Termination Reason')
GATT_INCOMING_CALL_CHARACTERISTIC = UUID.from_16_bits(0x2BC2, 'Incoming Call')
GATT_CALL_FRIENDLY_NAME_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Call Friendly Name')
# Microphone Control Service (MICS)
GATT_MUTE_CHARACTERISTIC = UUID.from_16_bits(0x2BC3, 'Mute')
# Audio Stream Control Service (ASCS)
GATT_SINK_ASE_CHARACTERISTIC = UUID.from_16_bits(0x2BC4, 'Sink ASE')
GATT_SOURCE_ASE_CHARACTERISTIC = UUID.from_16_bits(0x2BC5, 'Source ASE')
GATT_ASE_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BC6, 'ASE Control Point')
# Broadcast Audio Scan Service (BASS)
GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC = UUID.from_16_bits(0x2BC7, 'Broadcast Audio Scan Control Point')
GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC = UUID.from_16_bits(0x2BC8, 'Broadcast Receive State')
# Published Audio Capabilities Service (PACS)
GATT_SINK_PAC_CHARACTERISTIC = UUID.from_16_bits(0x2BC9, 'Sink PAC')
GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCA, 'Sink Audio Location')
GATT_SOURCE_PAC_CHARACTERISTIC = UUID.from_16_bits(0x2BCB, 'Source PAC')
GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC = UUID.from_16_bits(0x2BCC, 'Source Audio Location')
GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCD, 'Available Audio Contexts')
GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC = UUID.from_16_bits(0x2BCE, 'Supported Audio Contexts')
# ASHA Service
GATT_ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
@@ -296,9 +177,6 @@ GATT_BOOT_KEYBOARD_INPUT_REPORT_CHARACTERISTIC = UUID.from_16_bi
GATT_CURRENT_TIME_CHARACTERISTIC = UUID.from_16_bits(0x2A2B, 'Current Time')
GATT_BOOT_KEYBOARD_OUTPUT_REPORT_CHARACTERISTIC = UUID.from_16_bits(0x2A32, 'Boot Keyboard Output Report')
GATT_CENTRAL_ADDRESS_RESOLUTION__CHARACTERISTIC = UUID.from_16_bits(0x2AA6, 'Central Address Resolution')
GATT_CLIENT_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B29, 'Client Supported Features')
GATT_DATABASE_HASH_CHARACTERISTIC = UUID.from_16_bits(0x2B2A, 'Database Hash')
GATT_SERVER_SUPPORTED_FEATURES_CHARACTERISTIC = UUID.from_16_bits(0x2B3A, 'Server Supported Features')
# fmt: on
# pylint: enable=line-too-long
@@ -380,12 +258,9 @@ class TemplateService(Service):
UUID: UUID
def __init__(
self,
characteristics: List[Characteristic],
primary: bool = True,
included_services: List[Service] = [],
self, characteristics: List[Characteristic], primary: bool = True
) -> None:
super().__init__(self.UUID, characteristics, primary, included_services)
super().__init__(self.UUID, characteristics, primary)
# -----------------------------------------------------------------------------
@@ -534,43 +409,56 @@ class CharacteristicDeclaration(Attribute):
# -----------------------------------------------------------------------------
class CharacteristicValue(AttributeValue):
"""Same as AttributeValue, for backward compatibility"""
class CharacteristicValue:
'''
Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(self, read=None, write=None):
self._read = read
self._write = write
def read(self, connection):
return self._read(connection) if self._read else b''
def write(self, connection, value):
if self._write:
self._write(connection, value)
# -----------------------------------------------------------------------------
class CharacteristicAdapter:
'''
An adapter that can adapt Characteristic and AttributeProxy objects
by wrapping their `read_value()` and `write_value()` methods with ones that
return/accept encoded/decoded values.
For proxies (i.e used by a GATT client), the adaptation is one where the return
value of `read_value()` is decoded and the value passed to `write_value()` is
encoded. The `subscribe()` method, is wrapped with one where the values are decoded
before being passed to the subscriber.
For local values (i.e hosted by a GATT server) the adaptation is one where the
return value of `read_value()` is encoded and the value passed to `write_value()`
is decoded.
An adapter that can adapt any object with `read_value` and `write_value`
methods (like Characteristic and CharacteristicProxy objects) by wrapping
those methods with ones that return/accept encoded/decoded values.
Objects with async methods are considered proxies, so the adaptation is one
where the return value of `read_value` is decoded and the value passed to
`write_value` is encoded. Other objects are considered local characteristics
so the adaptation is one where the return value of `read_value` is encoded
and the value passed to `write_value` is decoded.
If the characteristic has a `subscribe` method, it is wrapped with one where
the values are decoded before being passed to the subscriber.
'''
read_value: Callable
write_value: Callable
def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
def __init__(self, characteristic):
self.wrapped_characteristic = characteristic
self.subscribers: Dict[
Callable, Callable
] = {} # Map from subscriber to proxy subscriber
self.subscribers = {} # Map from subscriber to proxy subscriber
if isinstance(characteristic, Characteristic):
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
else:
if asyncio.iscoroutinefunction(
characteristic.read_value
) and asyncio.iscoroutinefunction(characteristic.write_value):
self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
else:
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
if hasattr(self.wrapped_characteristic, 'subscribe'):
self.subscribe = self.wrapped_subscribe
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name):
@@ -589,13 +477,11 @@ class CharacteristicAdapter:
else:
setattr(self.wrapped_characteristic, name, value)
async def read_encoded_value(self, connection):
return self.encode_value(
await self.wrapped_characteristic.read_value(connection)
)
def read_encoded_value(self, connection):
return self.encode_value(self.wrapped_characteristic.read_value(connection))
async def write_encoded_value(self, connection, value):
return await self.wrapped_characteristic.write_value(
def write_encoded_value(self, connection, value):
return self.wrapped_characteristic.write_value(
connection, self.decode_value(value)
)
@@ -730,24 +616,13 @@ class Descriptor(Attribute):
'''
def __str__(self) -> str:
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
value = self.value.read(None)
if isinstance(value, bytes):
value_str = value.hex()
else:
value_str = '<async>'
else:
value_str = '<...>'
return (
f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'value={value_str})'
f'value={self.read_value(None).hex()})'
)
# -----------------------------------------------------------------------------
class ClientCharacteristicConfigurationBits(enum.IntFlag):
'''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit

View File

@@ -38,7 +38,6 @@ from typing import (
Any,
Iterable,
Type,
Set,
TYPE_CHECKING,
)
@@ -129,7 +128,7 @@ class ServiceProxy(AttributeProxy):
included_services: List[ServiceProxy]
@staticmethod
def from_client(service_class, client: Client, service_uuid: UUID):
def from_client(service_class, client, service_uuid):
# The service and its characteristics are considered to have already been
# discovered
services = client.get_services_by_uuid(service_uuid)
@@ -207,11 +206,11 @@ class CharacteristicProxy(AttributeProxy):
return await self.client.subscribe(self, subscriber, prefer_notify)
async def unsubscribe(self, subscriber=None, force=False):
async def unsubscribe(self, subscriber=None):
if subscriber in self.subscribers:
subscriber = self.subscribers.pop(subscriber)
return await self.client.unsubscribe(self, subscriber, force)
return await self.client.unsubscribe(self, subscriber)
def __str__(self) -> str:
return (
@@ -247,12 +246,8 @@ class ProfileServiceProxy:
class Client:
services: List[ServiceProxy]
cached_values: Dict[int, Tuple[datetime, bytes]]
notification_subscribers: Dict[
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
indication_subscribers: Dict[
int, Set[Union[CharacteristicProxy, Callable[[bytes], Any]]]
]
notification_subscribers: Dict[int, Callable[[bytes], Any]]
indication_subscribers: Dict[int, Callable[[bytes], Any]]
pending_response: Optional[asyncio.futures.Future[ATT_PDU]]
pending_request: Optional[ATT_PDU]
@@ -262,8 +257,10 @@ class Client:
self.request_semaphore = asyncio.Semaphore(1)
self.pending_request = None
self.pending_response = None
self.notification_subscribers = {} # Subscriber set, by attribute handle
self.indication_subscribers = {} # Subscriber set, by attribute handle
self.notification_subscribers = (
{}
) # Notification subscribers, by attribute handle
self.indication_subscribers = {} # Indication subscribers, by attribute handle
self.services = []
self.cached_values = {}
@@ -685,8 +682,8 @@ class Client:
async def discover_descriptors(
self,
characteristic: Optional[CharacteristicProxy] = None,
start_handle: Optional[int] = None,
end_handle: Optional[int] = None,
start_handle=None,
end_handle=None,
) -> List[DescriptorProxy]:
'''
See Vol 3, Part G - 4.7.1 Discover All Characteristic Descriptors
@@ -792,12 +789,7 @@ class Client:
return attributes
async def subscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
prefer_notify: bool = True,
) -> None:
async def subscribe(self, characteristic, subscriber=None, prefer_notify=True):
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
@@ -834,7 +826,6 @@ class Client:
subscriber_set = subscribers.setdefault(characteristic.handle, set())
if subscriber is not None:
subscriber_set.add(subscriber)
# Add the characteristic as a subscriber, which will result in the
# characteristic emitting an 'update' event when a notification or indication
# is received
@@ -842,18 +833,7 @@ class Client:
await self.write_value(cccd, struct.pack('<H', bits), with_response=True)
async def unsubscribe(
self,
characteristic: CharacteristicProxy,
subscriber: Optional[Callable[[bytes], Any]] = None,
force: bool = False,
) -> None:
'''
Unsubscribe from a characteristic.
If `force` is True, this will write zeros to the CCCD when there are no
subscribers left, even if there were already no registered subscribers.
'''
async def unsubscribe(self, characteristic, subscriber=None):
# If we haven't already discovered the descriptors for this characteristic,
# do it now
if not characteristic.descriptors_discovered:
@@ -867,45 +847,31 @@ class Client:
logger.warning('unsubscribing from characteristic with no CCCD descriptor')
return
# Check if the characteristic has subscribers
if not (
characteristic.handle in self.notification_subscribers
or characteristic.handle in self.indication_subscribers
):
if not force:
return
# Remove the subscriber(s)
if subscriber is not None:
# Remove matching subscriber from subscriber sets
for subscriber_set in (
self.notification_subscribers,
self.indication_subscribers,
):
if (
subscribers := subscriber_set.get(characteristic.handle)
) and subscriber in subscribers:
subscribers = subscriber_set.get(characteristic.handle, [])
if subscriber in subscribers:
subscribers.remove(subscriber)
# Cleanup if we removed the last one
if not subscribers:
del subscriber_set[characteristic.handle]
else:
# Remove all subscribers for this attribute from the sets
# Remove all subscribers for this attribute from the sets!
self.notification_subscribers.pop(characteristic.handle, None)
self.indication_subscribers.pop(characteristic.handle, None)
# Update the CCCD
if not (
characteristic.handle in self.notification_subscribers
or characteristic.handle in self.indication_subscribers
):
if not self.notification_subscribers and not self.indication_subscribers:
# No more subscribers left
await self.write_value(cccd, b'\x00\x00', with_response=True)
async def read_value(
self, attribute: Union[int, AttributeProxy], no_long_read: bool = False
) -> bytes:
) -> Any:
'''
See Vol 3, Part G - 4.8.1 Read Characteristic Value
@@ -1068,7 +1034,7 @@ class Client:
logger.warning('!!! unexpected response, there is no pending request')
return
# The response should match the pending request unless it is
# Sanity check: the response should match the pending request unless it is
# an error response
if att_pdu.op_code != ATT_ERROR_RESPONSE:
expected_response_name = self.pending_request.name.replace(
@@ -1101,7 +1067,7 @@ class Client:
def on_att_handle_value_notification(self, notification):
# Call all subscribers
subscribers = self.notification_subscribers.get(
notification.attribute_handle, set()
notification.attribute_handle, []
)
if not subscribers:
logger.warning('!!! received notification with no subscriber')
@@ -1115,9 +1081,7 @@ class Client:
def on_att_handle_value_indication(self, indication):
# Call all subscribers
subscribers = self.indication_subscribers.get(
indication.attribute_handle, set()
)
subscribers = self.indication_subscribers.get(indication.attribute_handle, [])
if not subscribers:
logger.warning('!!! received indication with no subscriber')

View File

@@ -31,9 +31,9 @@ import struct
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
from pyee import EventEmitter
from bumble.colors import color
from bumble.core import UUID
from bumble.att import (
from .colors import color
from .core import UUID
from .att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
@@ -60,7 +60,7 @@ from bumble.att import (
ATT_Write_Response,
Attribute,
)
from bumble.gatt import (
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
@@ -74,7 +74,6 @@ from bumble.gatt import (
Descriptor,
Service,
)
from bumble.utils import AsyncRunner
if TYPE_CHECKING:
from bumble.device import Device, Connection
@@ -328,7 +327,7 @@ class Server(EventEmitter):
f'handle=0x{characteristic.handle:04X}: {value.hex()}'
)
# Check parameters
# Sanity check
if len(value) != 2:
logger.warning('CCCD value not 2 bytes long')
return
@@ -380,7 +379,7 @@ class Server(EventEmitter):
# Get or encode the value
value = (
await attribute.read_value(connection)
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
@@ -423,7 +422,7 @@ class Server(EventEmitter):
# Get or encode the value
value = (
await attribute.read_value(connection)
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
@@ -651,8 +650,7 @@ class Server(EventEmitter):
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_find_by_type_value_request(self, connection, request):
def on_att_find_by_type_value_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
'''
@@ -660,13 +658,13 @@ class Server(EventEmitter):
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
attributes = []
async for attribute in (
for attribute in (
attribute
for attribute in self.attributes
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type
and (await attribute.read_value(connection)) == request.attribute_value
and attribute.read_value(connection) == request.attribute_value
and pdu_space_available >= 4
):
# TODO: check permissions
@@ -704,8 +702,7 @@ class Server(EventEmitter):
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_read_by_type_request(self, connection, request):
def on_att_read_by_type_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''
@@ -728,7 +725,7 @@ class Server(EventEmitter):
and pdu_space_available
):
try:
attribute_value = await attribute.read_value(connection)
attribute_value = attribute.read_value(connection)
except ATT_Error as error:
# If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point
@@ -770,15 +767,14 @@ class Server(EventEmitter):
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_read_request(self, connection, request):
def on_att_read_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
'''
if attribute := self.get_attribute(request.attribute_handle):
try:
value = await attribute.read_value(connection)
value = attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -796,15 +792,14 @@ class Server(EventEmitter):
)
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_read_blob_request(self, connection, request):
def on_att_read_blob_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
'''
if attribute := self.get_attribute(request.attribute_handle):
try:
value = await attribute.read_value(connection)
value = attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -841,8 +836,7 @@ class Server(EventEmitter):
)
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_read_by_group_type_request(self, connection, request):
def on_att_read_by_group_type_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
'''
@@ -870,7 +864,7 @@ class Server(EventEmitter):
):
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = await attribute.read_value(connection)
attribute_value = attribute.read_value(connection)
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
@@ -909,8 +903,7 @@ class Server(EventEmitter):
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_write_request(self, connection, request):
def on_att_write_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
'''
@@ -943,13 +936,12 @@ class Server(EventEmitter):
return
# Accept the value
await attribute.write_value(connection, request.attribute_value)
attribute.write_value(connection, request.attribute_value)
# Done
self.send_response(connection, ATT_Write_Response())
@AsyncRunner.run_in_task()
async def on_att_write_command(self, connection, request):
def on_att_write_command(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
'''
@@ -967,9 +959,9 @@ class Server(EventEmitter):
# Accept the value
try:
await attribute.write_value(connection, request.attribute_value)
attribute.write_value(connection, request.attribute_value)
except Exception as error:
logger.exception(f'!!! ignoring exception: {error}')
logger.warning(f'!!! ignoring exception: {error}')
def on_att_handle_value_confirmation(self, connection, _confirmation):
'''

File diff suppressed because it is too large Load Diff

View File

@@ -15,46 +15,30 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from collections.abc import Callable, MutableMapping
import datetime
from typing import cast, Any, Optional
import logging
from bumble import avc
from bumble import avctp
from bumble import avdtp
from bumble import avrcp
from bumble import crypto
from bumble import rfcomm
from bumble import sdp
from bumble.colors import color
from bumble.att import ATT_CID, ATT_PDU
from bumble.smp import SMP_CID, SMP_Command
from bumble.core import name_or_number
from bumble.l2cap import (
from .colors import color
from .att import ATT_CID, ATT_PDU
from .smp import SMP_CID, SMP_Command
from .core import name_or_number
from .l2cap import (
L2CAP_PDU,
L2CAP_CONNECTION_REQUEST,
L2CAP_CONNECTION_RESPONSE,
L2CAP_SIGNALING_CID,
L2CAP_LE_SIGNALING_CID,
L2CAP_Control_Frame,
L2CAP_Connection_Request,
L2CAP_Connection_Response,
)
from bumble.hci import (
Address,
from .hci import (
HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT,
HCI_AclDataPacketAssembler,
HCI_Packet,
HCI_Event,
HCI_AclDataPacket,
HCI_Disconnection_Complete_Event,
)
from .rfcomm import RFCOMM_Frame, RFCOMM_PSM
from .sdp import SDP_PDU, SDP_PSM
from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM
# -----------------------------------------------------------------------------
# Logging
@@ -64,36 +48,26 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
PSM_NAMES = {
rfcomm.RFCOMM_PSM: 'RFCOMM',
sdp.SDP_PSM: 'SDP',
avdtp.AVDTP_PSM: 'AVDTP',
avctp.AVCTP_PSM: 'AVCTP',
RFCOMM_PSM: 'RFCOMM',
SDP_PSM: 'SDP',
AVDTP_PSM: 'AVDTP'
# TODO: add more PSM values
}
AVCTP_PID_NAMES = {avrcp.AVRCP_PID: 'AVRCP'}
# -----------------------------------------------------------------------------
class PacketTracer:
class AclStream:
psms: MutableMapping[int, int]
peer: Optional[PacketTracer.AclStream]
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
avctp_assemblers: MutableMapping[int, avctp.MessageAssembler]
def __init__(self, analyzer: PacketTracer.Analyzer) -> None:
def __init__(self, analyzer):
self.analyzer = analyzer
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
self.avctp_assemblers = {} # AVCTP assemblers, by source_cid
self.psms = {} # PSM, by source_cid
self.peer = None
self.peer = None # ACL stream in the other direction
# pylint: disable=too-many-nested-blocks
def on_acl_pdu(self, pdu: bytes) -> None:
def on_acl_pdu(self, pdu):
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
self.analyzer.emit(l2cap_pdu)
if l2cap_pdu.cid == ATT_CID:
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
@@ -107,59 +81,46 @@ class PacketTracer:
# Check if this signals a new channel
if control_frame.code == L2CAP_CONNECTION_REQUEST:
connection_request = cast(L2CAP_Connection_Request, control_frame)
self.psms[connection_request.source_cid] = connection_request.psm
self.psms[control_frame.source_cid] = control_frame.psm
elif control_frame.code == L2CAP_CONNECTION_RESPONSE:
connection_response = cast(L2CAP_Connection_Response, control_frame)
if (
connection_response.result
control_frame.result
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
):
if self.peer and (
psm := self.peer.psms.get(connection_response.source_cid)
):
# Found a pending connection
self.psms[connection_response.destination_cid] = psm
if self.peer:
if psm := self.peer.psms.get(control_frame.source_cid):
# Found a pending connection
self.psms[control_frame.destination_cid] = psm
# For AVDTP connections, create a packet assembler for
# each direction
if psm == AVDTP_PSM:
self.avdtp_assemblers[
control_frame.source_cid
] = AVDTP_MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
control_frame.destination_cid
] = AVDTP_MessageAssembler(
self.peer.on_avdtp_message
)
# For AVDTP connections, create a packet assembler for
# each direction
if psm == avdtp.AVDTP_PSM:
self.avdtp_assemblers[
connection_response.source_cid
] = avdtp.MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
connection_response.destination_cid
] = avdtp.MessageAssembler(self.peer.on_avdtp_message)
elif psm == avctp.AVCTP_PSM:
self.avctp_assemblers[
connection_response.source_cid
] = avctp.MessageAssembler(self.on_avctp_message)
self.peer.avctp_assemblers[
connection_response.destination_cid
] = avctp.MessageAssembler(self.peer.on_avctp_message)
else:
# Try to find the PSM associated with this PDU
if self.peer and (psm := self.peer.psms.get(l2cap_pdu.cid)):
if psm == sdp.SDP_PSM:
sdp_pdu = sdp.SDP_PDU.from_bytes(l2cap_pdu.payload)
if psm == SDP_PSM:
sdp_pdu = SDP_PDU.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(sdp_pdu)
elif psm == rfcomm.RFCOMM_PSM:
rfcomm_frame = rfcomm.RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
elif psm == RFCOMM_PSM:
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(rfcomm_frame)
elif psm == avdtp.AVDTP_PSM:
elif psm == AVDTP_PSM:
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
)
if avdtp_assembler := self.avdtp_assemblers.get(l2cap_pdu.cid):
avdtp_assembler.on_pdu(l2cap_pdu.payload)
elif psm == avctp.AVCTP_PSM:
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVCTP]: {l2cap_pdu.payload.hex()}'
)
if avctp_assembler := self.avctp_assemblers.get(l2cap_pdu.cid):
avctp_assembler.on_pdu(l2cap_pdu.payload)
assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
if assembler:
assembler.on_pdu(l2cap_pdu.payload)
else:
psm_string = name_or_number(PSM_NAMES, psm)
self.analyzer.emit(
@@ -169,49 +130,22 @@ class PacketTracer:
else:
self.analyzer.emit(l2cap_pdu)
def on_avdtp_message(
self, transaction_label: int, message: avdtp.Message
) -> None:
def on_avdtp_message(self, transaction_label, message):
self.analyzer.emit(
f'{color("AVDTP", "green")} [{transaction_label}] {message}'
)
def on_avctp_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
):
if pid == avrcp.AVRCP_PID:
avc_frame = avc.Frame.from_bytes(payload)
details = str(avc_frame)
else:
details = payload.hex()
c_r = 'Command' if is_command else 'Response'
self.analyzer.emit(
f'{color("AVCTP", "green")} '
f'{c_r}[{transaction_label}][{name_or_number(AVCTP_PID_NAMES, pid)}] '
f'{"#" if ipid else ""}'
f'{details}'
)
def feed_packet(self, packet: HCI_AclDataPacket) -> None:
def feed_packet(self, packet):
self.packet_assembler.feed_packet(packet)
class Analyzer:
acl_streams: MutableMapping[int, PacketTracer.AclStream]
peer: PacketTracer.Analyzer
def __init__(self, label: str, emit_message: Callable[..., None]) -> None:
def __init__(self, label, emit_message):
self.label = label
self.emit_message = emit_message
self.acl_streams = {} # ACL streams, by connection handle
self.packet_timestamp: Optional[datetime.datetime] = None
self.peer = None # Analyzer in the other direction
def start_acl_stream(self, connection_handle: int) -> PacketTracer.AclStream:
def start_acl_stream(self, connection_handle):
logger.info(
f'[{self.label}] +++ Creating ACL stream for connection '
f'0x{connection_handle:04X}'
@@ -226,7 +160,7 @@ class PacketTracer:
return stream
def end_acl_stream(self, connection_handle: int) -> None:
def end_acl_stream(self, connection_handle):
if connection_handle in self.acl_streams:
logger.info(
f'[{self.label}] --- Removing ACL stream for connection '
@@ -237,52 +171,34 @@ class PacketTracer:
# Let the other forwarder know so it can cleanup its stream as well
self.peer.end_acl_stream(connection_handle)
def on_packet(
self, timestamp: Optional[datetime.datetime], packet: HCI_Packet
) -> None:
self.packet_timestamp = timestamp
def on_packet(self, packet):
self.emit(packet)
if packet.hci_packet_type == HCI_ACL_DATA_PACKET:
acl_packet = cast(HCI_AclDataPacket, packet)
# Look for an existing stream for this handle, create one if it is the
# first ACL packet for that connection handle
if (
stream := self.acl_streams.get(acl_packet.connection_handle)
) is None:
stream = self.start_acl_stream(acl_packet.connection_handle)
stream.feed_packet(acl_packet)
if (stream := self.acl_streams.get(packet.connection_handle)) is None:
stream = self.start_acl_stream(packet.connection_handle)
stream.feed_packet(packet)
elif packet.hci_packet_type == HCI_EVENT_PACKET:
event_packet = cast(HCI_Event, packet)
if event_packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
self.end_acl_stream(
cast(HCI_Disconnection_Complete_Event, packet).connection_handle
)
if packet.event_code == HCI_DISCONNECTION_COMPLETE_EVENT:
self.end_acl_stream(packet.connection_handle)
def emit(self, message: Any) -> None:
if self.packet_timestamp:
prefix = f"[{self.packet_timestamp.strftime('%Y-%m-%d %H:%M:%S.%f')}]"
else:
prefix = ""
self.emit_message(f'{prefix}[{self.label}] {message}')
def emit(self, message):
self.emit_message(f'[{self.label}] {message}')
def trace(
self,
packet: HCI_Packet,
direction: int = 0,
timestamp: Optional[datetime.datetime] = None,
) -> None:
def trace(self, packet, direction=0):
if direction == 0:
self.host_to_controller_analyzer.on_packet(timestamp, packet)
self.host_to_controller_analyzer.on_packet(packet)
else:
self.controller_to_host_analyzer.on_packet(timestamp, packet)
self.controller_to_host_analyzer.on_packet(packet)
def __init__(
self,
host_to_controller_label: str = color('HOST->CONTROLLER', 'blue'),
controller_to_host_label: str = color('CONTROLLER->HOST', 'cyan'),
emit_message: Callable[..., None] = logger.info,
) -> None:
host_to_controller_label=color('HOST->CONTROLLER', 'blue'),
controller_to_host_label=color('CONTROLLER->HOST', 'cyan'),
emit_message=logger.info,
):
self.host_to_controller_analyzer = PacketTracer.Analyzer(
host_to_controller_label, emit_message
)
@@ -291,15 +207,3 @@ class PacketTracer:
)
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer
def generate_irk() -> bytes:
return crypto.r()
def verify_rpa_with_irk(rpa: Address, irk: bytes) -> bool:
rpa_bytes = bytes(rpa)
prand_given = rpa_bytes[3:]
hash_given = rpa_bytes[:3]
hash_local = crypto.ah(irk, prand_given)
return hash_local[:3] == hash_given

View File

@@ -21,11 +21,12 @@ import asyncio
import dataclasses
import enum
import traceback
import pyee
from typing import Dict, List, Union, Set, Any, Optional, TYPE_CHECKING
import warnings
from typing import Dict, List, Union, Set, TYPE_CHECKING
from . import at
from . import rfcomm
from bumble import at
from bumble import rfcomm
from bumble.colors import color
from bumble.core import (
ProtocolError,
@@ -34,11 +35,6 @@ from bumble.core import (
BT_L2CAP_PROTOCOL_ID,
BT_RFCOMM_PROTOCOL_ID,
)
from bumble.hci import (
HCI_Enhanced_Setup_Synchronous_Connection_Command,
CodingFormat,
CodecID,
)
from bumble.sdp import (
DataElement,
ServiceAttribute,
@@ -69,7 +65,6 @@ class HfpProtocolError(ProtocolError):
# Protocol Support
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class HfpProtocol:
dlc: rfcomm.DLC
@@ -78,6 +73,7 @@ class HfpProtocol:
lines_available: asyncio.Event
def __init__(self, dlc: rfcomm.DLC) -> None:
warnings.warn("See HfProtocol", DeprecationWarning)
self.dlc = dlc
self.buffer = ''
self.lines = collections.deque()
@@ -126,13 +122,10 @@ class HfpProtocol:
# -----------------------------------------------------------------------------
# HF supported features (AT+BRSF=) (normative).
# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
# and 3GPP 27.007
class HfFeature(enum.IntFlag):
"""
HF supported features (AT+BRSF=) (normative).
Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
"""
EC_NR = 0x001 # Echo Cancel & Noise reduction
THREE_WAY_CALLING = 0x002
CLI_PRESENTATION_CAPABILITY = 0x004
@@ -147,13 +140,10 @@ class HfFeature(enum.IntFlag):
VOICE_RECOGNITION_TEST = 0x800
# AG supported features (+BRSF:) (normative).
# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
# and 3GPP 27.007
class AgFeature(enum.IntFlag):
"""
AG supported features (+BRSF:) (normative).
Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
"""
THREE_WAY_CALLING = 0x001
EC_NR = 0x002 # Echo Cancel & Noise reduction
VOICE_RECOGNITION_FUNCTION = 0x004
@@ -170,90 +160,52 @@ class AgFeature(enum.IntFlag):
VOICE_RECOGNITION_TEST = 0x2000
# Audio Codec IDs (normative).
# Hands-Free Profile v1.8, 10 Appendix B
class AudioCodec(enum.IntEnum):
"""
Audio Codec IDs (normative).
Hands-Free Profile v1.9, 11 Appendix B
"""
CVSD = 0x01 # Support for CVSD audio codec
MSBC = 0x02 # Support for mSBC audio codec
LC3_SWB = 0x03 # Support for LC3-SWB audio codec
# HF Indicators (normative).
# Bluetooth Assigned Numbers, 6.10.1 HF Indicators
class HfIndicator(enum.IntEnum):
"""
HF Indicators (normative).
Bluetooth Assigned Numbers, 6.10.1 HF Indicators.
"""
ENHANCED_SAFETY = 0x01 # Enhanced safety feature
BATTERY_LEVEL = 0x02 # Battery level feature
# Call Hold supported operations (normative).
# AT Commands Reference Guide, 3.5.2.3.12 +CHLD - Call Holding Services
class CallHoldOperation(enum.IntEnum):
"""
Call Hold supported operations (normative).
AT Commands Reference Guide, 3.5.2.3.12 +CHLD - Call Holding Services.
"""
RELEASE_ALL_HELD_CALLS = 0 # Release all held calls
RELEASE_ALL_ACTIVE_CALLS = 1 # Release all active calls, accept other
HOLD_ALL_ACTIVE_CALLS = 2 # Place all active calls on hold, accept other
ADD_HELD_CALL = 3 # Adds a held call to conversation
# Response Hold status (normative).
# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
# and 3GPP 27.007
class ResponseHoldStatus(enum.IntEnum):
"""
Response Hold status (normative).
Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
"""
INC_CALL_HELD = 0 # Put incoming call on hold
HELD_CALL_ACC = 1 # Accept a held incoming call
HELD_CALL_REJ = 2 # Reject a held incoming call
class AgIndicator(enum.Enum):
"""
Values for the AG indicator (normative).
Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
"""
SERVICE = 'service'
CALL = 'call'
CALL_SETUP = 'callsetup'
CALL_HELD = 'callheld'
SIGNAL = 'signal'
ROAM = 'roam'
BATTERY_CHARGE = 'battchg'
# Values for the Call Setup AG indicator (normative).
# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
# and 3GPP 27.007
class CallSetupAgIndicator(enum.IntEnum):
"""
Values for the Call Setup AG indicator (normative).
Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
"""
NOT_IN_CALL_SETUP = 0
INCOMING_CALL_PROCESS = 1
OUTGOING_CALL_SETUP = 2
REMOTE_ALERTED = 3 # Remote party alerted in an outgoing call
# Values for the Call Held AG indicator (normative).
# Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07
# and 3GPP 27.007
class CallHeldAgIndicator(enum.IntEnum):
"""
Values for the Call Held AG indicator (normative).
Hands-Free Profile v1.8, 4.34.2, AT Capabilities Re-Used from GSM 07.07 and 3GPP 27.007.
"""
NO_CALLS_HELD = 0
# Call is placed on hold or active/held calls swapped
# (The AG has both an active AND a held call)
@@ -261,24 +213,16 @@ class CallHeldAgIndicator(enum.IntEnum):
CALL_ON_HOLD_NO_ACTIVE_CALL = 2 # Call on hold, no active call
# Call Info direction (normative).
# AT Commands Reference Guide, 3.5.2.3.15 +CLCC - List Current Calls
class CallInfoDirection(enum.IntEnum):
"""
Call Info direction (normative).
AT Commands Reference Guide, 3.5.2.3.15 +CLCC - List Current Calls.
"""
MOBILE_ORIGINATED_CALL = 0
MOBILE_TERMINATED_CALL = 1
# Call Info status (normative).
# AT Commands Reference Guide, 3.5.2.3.15 +CLCC - List Current Calls
class CallInfoStatus(enum.IntEnum):
"""
Call Info status (normative).
AT Commands Reference Guide, 3.5.2.3.15 +CLCC - List Current Calls.
"""
ACTIVE = 0
HELD = 1
DIALING = 2
@@ -287,47 +231,15 @@ class CallInfoStatus(enum.IntEnum):
WAITING = 5
# Call Info mode (normative).
# AT Commands Reference Guide, 3.5.2.3.15 +CLCC - List Current Calls
class CallInfoMode(enum.IntEnum):
"""
Call Info mode (normative).
AT Commands Reference Guide, 3.5.2.3.15 +CLCC - List Current Calls.
"""
VOICE = 0
DATA = 1
FAX = 2
UNKNOWN = 9
class CallInfoMultiParty(enum.IntEnum):
"""
Call Info Multi-Party state (normative).
AT Commands Reference Guide, 3.5.2.3.15 +CLCC - List Current Calls.
"""
NOT_IN_CONFERENCE = 0
IN_CONFERENCE = 1
@dataclasses.dataclass
class CallInfo:
"""
Enhanced call status.
AT Commands Reference Guide, 3.5.2.3.15 +CLCC - List Current Calls.
"""
index: int
direction: CallInfoDirection
status: CallInfoStatus
mode: CallInfoMode
multi_party: CallInfoMultiParty
number: Optional[int] = None
type: Optional[int] = None
# -----------------------------------------------------------------------------
# Hands-Free Control Interoperability Requirements
# -----------------------------------------------------------------------------
@@ -408,9 +320,8 @@ class Configuration:
class AtResponseType(enum.Enum):
"""
Indicates if a response is expected from an AT command, and if multiple responses are accepted.
"""
"""Indicate if a response is expected from an AT command, and if multiple
responses are accepted."""
NONE = 0
SINGLE = 1
@@ -444,20 +355,9 @@ class HfIndicatorState:
enabled: bool = False
class HfProtocol(pyee.EventEmitter):
"""
Implementation for the Hands-Free side of the Hands-Free profile.
Reference specification Hands-Free Profile v1.8.
Emitted events:
codec_negotiation: When codec is renegotiated, notify the new codec.
Args:
active_codec: AudioCodec
ag_indicator: When AG update their indicators, notify the new state.
Args:
ag_indicator: AgIndicator
"""
class HfProtocol:
"""Implementation for the Hands-Free side of the Hands-Free profile.
Reference specification Hands-Free Profile v1.8"""
supported_hf_features: int
supported_audio_codecs: List[AudioCodec]
@@ -477,18 +377,14 @@ class HfProtocol(pyee.EventEmitter):
response_queue: asyncio.Queue
unsolicited_queue: asyncio.Queue
read_buffer: bytearray
active_codec: AudioCodec
def __init__(self, dlc: rfcomm.DLC, configuration: Configuration) -> None:
super().__init__()
def __init__(self, dlc: rfcomm.DLC, configuration: Configuration):
# Configure internal state.
self.dlc = dlc
self.command_lock = asyncio.Lock()
self.response_queue = asyncio.Queue()
self.unsolicited_queue = asyncio.Queue()
self.read_buffer = bytearray()
self.active_codec = AudioCodec.CVSD
# Build local features.
self.supported_hf_features = sum(configuration.supported_hf_features)
@@ -513,12 +409,10 @@ class HfProtocol(pyee.EventEmitter):
def supports_ag_feature(self, feature: AgFeature) -> bool:
return (self.supported_ag_features & feature) != 0
# Read AT messages from the RFCOMM channel.
# Enqueue AT commands, responses, unsolicited responses to their
# respective queues, and set the corresponding event.
def _read_at(self, data: bytes):
"""
Reads AT messages from the RFCOMM channel.
Enqueues AT commands, responses, unsolicited responses to their respective queues, and set the corresponding event.
"""
# Append to the read buffer.
self.read_buffer.extend(data)
@@ -546,25 +440,17 @@ class HfProtocol(pyee.EventEmitter):
else:
logger.warning(f"dropping unexpected response with code '{response.code}'")
# Send an AT command and wait for the peer response.
# Wait for the AT responses sent by the peer, to the status code.
# Raises asyncio.TimeoutError if the status is not received
# after a timeout (default 1 second).
# Raises ProtocolError if the status is not OK.
async def execute_command(
self,
cmd: str,
timeout: float = 1.0,
response_type: AtResponseType = AtResponseType.NONE,
) -> Union[None, AtResponse, List[AtResponse]]:
"""
Sends an AT command and wait for the peer response.
Wait for the AT responses sent by the peer, to the status code.
Args:
cmd: the AT command in string to execute.
timeout: timeout in float seconds.
response_type: type of response.
Raises:
asyncio.TimeoutError: the status is not received after a timeout (default 1 second).
ProtocolError: the status is not OK.
"""
async with self.command_lock:
logger.debug(f">>> {cmd}")
self.dlc.write(cmd + '\r')
@@ -587,9 +473,8 @@ class HfProtocol(pyee.EventEmitter):
raise HfpProtocolError(result.code)
responses.append(result)
# 4.2.1 Service Level Connection Initialization.
async def initiate_slc(self):
"""4.2.1 Service Level Connection Initialization."""
# 4.2.1.1 Supported features exchange
# First, in the initialization procedure, the HF shall send the
# AT+BRSF=<HF supported features> command to the AG to both notify
@@ -729,17 +614,16 @@ class HfProtocol(pyee.EventEmitter):
logger.info("SLC setup completed")
# 4.11.2 Audio Connection Setup by HF
async def setup_audio_connection(self):
"""4.11.2 Audio Connection Setup by HF."""
# When the HF triggers the establishment of the Codec Connection it
# shall send the AT command AT+BCC to the AG. The AG shall respond with
# OK if it will start the Codec Connection procedure, and with ERROR
# if it cannot start the Codec Connection procedure.
await self.execute_command("AT+BCC")
# 4.11.3 Codec Connection Setup
async def setup_codec_connection(self, codec_id: int):
"""4.11.3 Codec Connection Setup."""
# The AG shall send a +BCS=<Codec ID> unsolicited response to the HF.
# The HF shall then respond to the incoming unsolicited response with
# the AT command AT+BCS=<Codec ID>. The ID shall be the same as in the
@@ -757,29 +641,27 @@ class HfProtocol(pyee.EventEmitter):
# Synchronous Connection with the settings that are determined by the
# ID. The HF shall be ready to accept the synchronous connection
# establishment as soon as it has sent the AT commands AT+BCS=<Codec ID>.
self.active_codec = AudioCodec(codec_id)
self.emit('codec_negotiation', self.active_codec)
logger.info("codec connection setup completed")
# 4.13.1 Answer Incoming Call from the HF In-Band Ringing
async def answer_incoming_call(self):
"""4.13.1 Answer Incoming Call from the HF - In-Band Ringing."""
# The user accepts the incoming voice call by using the proper means
# provided by the HF. The HF shall then send the ATA command
# (see Section 4.34) to the AG. The AG shall then begin the procedure for
# accepting the incoming call.
await self.execute_command("ATA")
# 4.14.1 Reject an Incoming Call from the HF
async def reject_incoming_call(self):
"""4.14.1 Reject an Incoming Call from the HF."""
# The user rejects the incoming call by using the User Interface on the
# Hands-Free unit. The HF shall then send the AT+CHUP command
# (see Section 4.34) to the AG. This may happen at any time during the
# procedures described in Sections 4.13.1 and 4.13.2.
await self.execute_command("AT+CHUP")
# 4.15.1 Terminate a Call Process from the HF
async def terminate_call(self):
"""4.15.1 Terminate a Call Process from the HF."""
# The user may abort the ongoing call process using whatever means
# provided by the Hands-Free unit. The HF shall send AT+CHUP command
# (see Section 4.34) to the AG, and the AG shall then start the
@@ -788,35 +670,8 @@ class HfProtocol(pyee.EventEmitter):
# code, with the value indicating (call=0).
await self.execute_command("AT+CHUP")
async def query_current_calls(self) -> List[CallInfo]:
"""4.32.1 Query List of Current Calls in AG.
Return:
List of current calls in AG.
"""
responses = await self.execute_command(
"AT+CLCC", response_type=AtResponseType.MULTIPLE
)
assert isinstance(responses, list)
calls = []
for response in responses:
call_info = CallInfo(
index=int(response.parameters[0]),
direction=CallInfoDirection(int(response.parameters[1])),
status=CallInfoStatus(int(response.parameters[2])),
mode=CallInfoMode(int(response.parameters[3])),
multi_party=CallInfoMultiParty(int(response.parameters[4])),
)
if len(response.parameters) >= 7:
call_info.number = int(response.parameters[5])
call_info.type = int(response.parameters[6])
calls.append(call_info)
return calls
async def update_ag_indicator(self, index: int, value: int):
self.ag_indicators[index].current_status = value
self.emit('ag_indicator', self.ag_indicators[index])
logger.info(
f"AG indicator updated: {self.ag_indicators[index].description}, {value}"
)
@@ -834,11 +689,9 @@ class HfProtocol(pyee.EventEmitter):
logging.info(f"unhandled unsolicited response {result.code}")
async def run(self):
"""
Main routine for the Hands-Free side of the HFP protocol.
Initiates the service level connection then loops handling unsolicited AG responses.
"""
"""Main rountine for the Hands-Free side of the HFP protocol.
Initiates the service level connection then loops handling
unsolicited AG responses."""
try:
await self.initiate_slc()
@@ -854,13 +707,9 @@ class HfProtocol(pyee.EventEmitter):
# -----------------------------------------------------------------------------
# Profile version (normative).
# Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements
class ProfileVersion(enum.IntEnum):
"""
Profile version (normative).
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements.
"""
V1_5 = 0x0105
V1_6 = 0x0106
V1_7 = 0x0107
@@ -868,13 +717,9 @@ class ProfileVersion(enum.IntEnum):
V1_9 = 0x0109
# HF supported features (normative).
# Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements
class HfSdpFeature(enum.IntFlag):
"""
HF supported features (normative).
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements.
"""
EC_NR = 0x01 # Echo Cancel & Noise reduction
THREE_WAY_CALLING = 0x02
CLI_PRESENTATION_CAPABILITY = 0x04
@@ -885,13 +730,9 @@ class HfSdpFeature(enum.IntFlag):
VOICE_RECOGNITION_TEST = 0x80
# AG supported features (normative).
# Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements
class AgSdpFeature(enum.IntFlag):
"""
AG supported features (normative).
Hands-Free Profile v1.8, 5.3 SDP Interoperability Requirements.
"""
THREE_WAY_CALLING = 0x01
EC_NR = 0x02 # Echo Cancel & Noise reduction
VOICE_RECOGNITION_FUNCTION = 0x04
@@ -905,12 +746,9 @@ class AgSdpFeature(enum.IntFlag):
def sdp_records(
service_record_handle: int, rfcomm_channel: int, configuration: Configuration
) -> List[ServiceAttribute]:
"""
Generates the SDP record for HFP Hands-Free support.
"""Generate the SDP record for HFP Hands-Free support.
The record exposes the features supported in the input configuration,
and the allocated RFCOMM channel.
"""
and the allocated RFCOMM channel."""
hf_supported_features = 0
@@ -981,175 +819,3 @@ def sdp_records(
DataElement.unsigned_integer_16(hf_supported_features),
),
]
# -----------------------------------------------------------------------------
# ESCO Codec Default Parameters
# -----------------------------------------------------------------------------
# Hands-Free Profile v1.8, 5.7 Codec Interoperability Requirements
class DefaultCodecParameters(enum.IntEnum):
SCO_CVSD_D0 = enum.auto()
SCO_CVSD_D1 = enum.auto()
ESCO_CVSD_S1 = enum.auto()
ESCO_CVSD_S2 = enum.auto()
ESCO_CVSD_S3 = enum.auto()
ESCO_CVSD_S4 = enum.auto()
ESCO_MSBC_T1 = enum.auto()
ESCO_MSBC_T2 = enum.auto()
@dataclasses.dataclass
class EscoParameters:
# Codec specific
transmit_coding_format: CodingFormat
receive_coding_format: CodingFormat
packet_type: HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType
retransmission_effort: HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort
max_latency: int
# Common
input_coding_format: CodingFormat = CodingFormat(CodecID.LINEAR_PCM)
output_coding_format: CodingFormat = CodingFormat(CodecID.LINEAR_PCM)
input_coded_data_size: int = 16
output_coded_data_size: int = 16
input_pcm_data_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat = (
HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat.TWOS_COMPLEMENT
)
output_pcm_data_format: HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat = (
HCI_Enhanced_Setup_Synchronous_Connection_Command.PcmDataFormat.TWOS_COMPLEMENT
)
input_pcm_sample_payload_msb_position: int = 0
output_pcm_sample_payload_msb_position: int = 0
input_data_path: HCI_Enhanced_Setup_Synchronous_Connection_Command.DataPath = (
HCI_Enhanced_Setup_Synchronous_Connection_Command.DataPath.HCI
)
output_data_path: HCI_Enhanced_Setup_Synchronous_Connection_Command.DataPath = (
HCI_Enhanced_Setup_Synchronous_Connection_Command.DataPath.HCI
)
input_transport_unit_size: int = 0
output_transport_unit_size: int = 0
input_bandwidth: int = 16000
output_bandwidth: int = 16000
transmit_bandwidth: int = 8000
receive_bandwidth: int = 8000
transmit_codec_frame_size: int = 60
receive_codec_frame_size: int = 60
def asdict(self) -> Dict[str, Any]:
# dataclasses.asdict() will recursively deep-copy the entire object,
# which is expensive and breaks CodingFormat object, so let it simply copy here.
return self.__dict__
_ESCO_PARAMETERS_CVSD_D0 = EscoParameters(
transmit_coding_format=CodingFormat(CodecID.CVSD),
receive_coding_format=CodingFormat(CodecID.CVSD),
max_latency=0xFFFF,
packet_type=HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.HV1,
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.NO_RETRANSMISSION,
)
_ESCO_PARAMETERS_CVSD_D1 = EscoParameters(
transmit_coding_format=CodingFormat(CodecID.CVSD),
receive_coding_format=CodingFormat(CodecID.CVSD),
max_latency=0xFFFF,
packet_type=HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.HV3,
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.NO_RETRANSMISSION,
)
_ESCO_PARAMETERS_CVSD_S1 = EscoParameters(
transmit_coding_format=CodingFormat(CodecID.CVSD),
receive_coding_format=CodingFormat(CodecID.CVSD),
max_latency=0x0007,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
),
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_POWER,
)
_ESCO_PARAMETERS_CVSD_S2 = EscoParameters(
transmit_coding_format=CodingFormat(CodecID.CVSD),
receive_coding_format=CodingFormat(CodecID.CVSD),
max_latency=0x0007,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
),
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_POWER,
)
_ESCO_PARAMETERS_CVSD_S3 = EscoParameters(
transmit_coding_format=CodingFormat(CodecID.CVSD),
receive_coding_format=CodingFormat(CodecID.CVSD),
max_latency=0x000A,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
),
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_POWER,
)
_ESCO_PARAMETERS_CVSD_S4 = EscoParameters(
transmit_coding_format=CodingFormat(CodecID.CVSD),
receive_coding_format=CodingFormat(CodecID.CVSD),
max_latency=0x000C,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
),
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_QUALITY,
)
_ESCO_PARAMETERS_MSBC_T1 = EscoParameters(
transmit_coding_format=CodingFormat(CodecID.MSBC),
receive_coding_format=CodingFormat(CodecID.MSBC),
max_latency=0x0008,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
),
input_bandwidth=32000,
output_bandwidth=32000,
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_QUALITY,
)
_ESCO_PARAMETERS_MSBC_T2 = EscoParameters(
transmit_coding_format=CodingFormat(CodecID.MSBC),
receive_coding_format=CodingFormat(CodecID.MSBC),
max_latency=0x000D,
packet_type=(
HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV3
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_2_EV5
| HCI_Enhanced_Setup_Synchronous_Connection_Command.PacketType.NO_3_EV5
),
input_bandwidth=32000,
output_bandwidth=32000,
retransmission_effort=HCI_Enhanced_Setup_Synchronous_Connection_Command.RetransmissionEffort.OPTIMIZE_FOR_QUALITY,
)
ESCO_PARAMETERS = {
DefaultCodecParameters.SCO_CVSD_D0: _ESCO_PARAMETERS_CVSD_D0,
DefaultCodecParameters.SCO_CVSD_D1: _ESCO_PARAMETERS_CVSD_D1,
DefaultCodecParameters.ESCO_CVSD_S1: _ESCO_PARAMETERS_CVSD_S1,
DefaultCodecParameters.ESCO_CVSD_S2: _ESCO_PARAMETERS_CVSD_S2,
DefaultCodecParameters.ESCO_CVSD_S3: _ESCO_PARAMETERS_CVSD_S3,
DefaultCodecParameters.ESCO_CVSD_S4: _ESCO_PARAMETERS_CVSD_S4,
DefaultCodecParameters.ESCO_MSBC_T1: _ESCO_PARAMETERS_MSBC_T1,
DefaultCodecParameters.ESCO_MSBC_T2: _ESCO_PARAMETERS_MSBC_T2,
}

View File

@@ -1,554 +0,0 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from dataclasses import dataclass
import logging
import enum
import struct
from abc import ABC, abstractmethod
from pyee import EventEmitter
from typing import Optional, Callable, TYPE_CHECKING
from typing_extensions import override
from bumble import l2cap, device
from bumble.colors import color
from bumble.core import InvalidStateError, ProtocolError
from .hci import Address
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: on
HID_CONTROL_PSM = 0x0011
HID_INTERRUPT_PSM = 0x0013
class Message:
message_type: MessageType
# Report types
class ReportType(enum.IntEnum):
OTHER_REPORT = 0x00
INPUT_REPORT = 0x01
OUTPUT_REPORT = 0x02
FEATURE_REPORT = 0x03
# Handshake parameters
class Handshake(enum.IntEnum):
SUCCESSFUL = 0x00
NOT_READY = 0x01
ERR_INVALID_REPORT_ID = 0x02
ERR_UNSUPPORTED_REQUEST = 0x03
ERR_INVALID_PARAMETER = 0x04
ERR_UNKNOWN = 0x0E
ERR_FATAL = 0x0F
# Message Type
class MessageType(enum.IntEnum):
HANDSHAKE = 0x00
CONTROL = 0x01
GET_REPORT = 0x04
SET_REPORT = 0x05
GET_PROTOCOL = 0x06
SET_PROTOCOL = 0x07
DATA = 0x0A
# Protocol modes
class ProtocolMode(enum.IntEnum):
BOOT_PROTOCOL = 0x00
REPORT_PROTOCOL = 0x01
# Control Operations
class ControlCommand(enum.IntEnum):
SUSPEND = 0x03
EXIT_SUSPEND = 0x04
VIRTUAL_CABLE_UNPLUG = 0x05
# Class Method to derive header
@classmethod
def header(cls, lower_bits: int = 0x00) -> bytes:
return bytes([(cls.message_type << 4) | lower_bits])
# HIDP messages
@dataclass
class GetReportMessage(Message):
report_type: int
report_id: int
buffer_size: int
message_type = Message.MessageType.GET_REPORT
def __bytes__(self) -> bytes:
packet_bytes = bytearray()
packet_bytes.append(self.report_id)
if self.buffer_size == 0:
return self.header(self.report_type) + packet_bytes
else:
return (
self.header(0x08 | self.report_type)
+ packet_bytes
+ struct.pack("<H", self.buffer_size)
)
@dataclass
class SetReportMessage(Message):
report_type: int
data: bytes
message_type = Message.MessageType.SET_REPORT
def __bytes__(self) -> bytes:
return self.header(self.report_type) + self.data
@dataclass
class SendControlData(Message):
report_type: int
data: bytes
message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes:
return self.header(self.report_type) + self.data
@dataclass
class GetProtocolMessage(Message):
message_type = Message.MessageType.GET_PROTOCOL
def __bytes__(self) -> bytes:
return self.header()
@dataclass
class SetProtocolMessage(Message):
protocol_mode: int
message_type = Message.MessageType.SET_PROTOCOL
def __bytes__(self) -> bytes:
return self.header(self.protocol_mode)
@dataclass
class Suspend(Message):
message_type = Message.MessageType.CONTROL
def __bytes__(self) -> bytes:
return self.header(Message.ControlCommand.SUSPEND)
@dataclass
class ExitSuspend(Message):
message_type = Message.MessageType.CONTROL
def __bytes__(self) -> bytes:
return self.header(Message.ControlCommand.EXIT_SUSPEND)
@dataclass
class VirtualCableUnplug(Message):
message_type = Message.MessageType.CONTROL
def __bytes__(self) -> bytes:
return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)
# Device sends input report, host sends output report.
@dataclass
class SendData(Message):
data: bytes
report_type: int
message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes:
return self.header(self.report_type) + self.data
@dataclass
class SendHandshakeMessage(Message):
result_code: int
message_type = Message.MessageType.HANDSHAKE
def __bytes__(self) -> bytes:
return self.header(self.result_code)
# -----------------------------------------------------------------------------
class HID(ABC, EventEmitter):
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
connection: Optional[device.Connection] = None
class Role(enum.IntEnum):
HOST = 0x00
DEVICE = 0x01
def __init__(self, device: device.Device, role: Role) -> None:
super().__init__()
self.remote_device_bd_address: Optional[Address] = None
self.device = device
self.role = role
# Register ourselves with the L2CAP channel manager
device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
device.on('connection', self.on_device_connection)
async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel
try:
self.l2cap_ctrl_channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_CONTROL_PSM
)
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_ctrl_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
async def connect_interrupt_channel(self) -> None:
# Create a new L2CAP connection - interrupt channel
try:
self.l2cap_intr_channel = await self.device.l2cap_channel_manager.connect(
self.connection, HID_INTERRUPT_PSM
)
except ProtocolError:
logging.exception(f'L2CAP connection failed.')
raise
assert self.l2cap_intr_channel is not None
# Become a sink for the L2CAP channel
self.l2cap_intr_channel.sink = self.on_intr_pdu
async def disconnect_interrupt_channel(self) -> None:
if self.l2cap_intr_channel is None:
raise InvalidStateError('invalid state')
channel = self.l2cap_intr_channel
self.l2cap_intr_channel = None
await channel.disconnect()
async def disconnect_control_channel(self) -> None:
if self.l2cap_ctrl_channel is None:
raise InvalidStateError('invalid state')
channel = self.l2cap_ctrl_channel
self.l2cap_ctrl_channel = None
await channel.disconnect()
def on_device_connection(self, connection: device.Connection) -> None:
self.connection = connection
self.remote_device_bd_address = connection.peer_address
connection.on('disconnection', self.on_device_disconnection)
def on_device_disconnection(self, reason: int) -> None:
self.connection = None
def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
if l2cap_channel.psm == HID_CONTROL_PSM:
self.l2cap_ctrl_channel = l2cap_channel
self.l2cap_ctrl_channel.sink = self.on_ctrl_pdu
else:
self.l2cap_intr_channel = l2cap_channel
self.l2cap_intr_channel.sink = self.on_intr_pdu
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
if l2cap_channel.psm == HID_CONTROL_PSM:
self.l2cap_ctrl_channel = None
else:
self.l2cap_intr_channel = None
logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')
@abstractmethod
def on_ctrl_pdu(self, pdu: bytes) -> None:
pass
def on_intr_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
self.emit("interrupt_data", pdu)
def send_pdu_on_ctrl(self, msg: bytes) -> None:
assert self.l2cap_ctrl_channel
self.l2cap_ctrl_channel.send_pdu(msg)
def send_pdu_on_intr(self, msg: bytes) -> None:
assert self.l2cap_intr_channel
self.l2cap_intr_channel.send_pdu(msg)
def send_data(self, data: bytes) -> None:
if self.role == HID.Role.HOST:
report_type = Message.ReportType.OUTPUT_REPORT
else:
report_type = Message.ReportType.INPUT_REPORT
msg = SendData(data, report_type)
hid_message = bytes(msg)
if self.l2cap_intr_channel is not None:
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
self.send_pdu_on_intr(hid_message)
def virtual_cable_unplug(self) -> None:
msg = VirtualCableUnplug()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
# -----------------------------------------------------------------------------
class Device(HID):
class GetSetReturn(enum.IntEnum):
FAILURE = 0x00
REPORT_ID_NOT_FOUND = 0x01
ERR_UNSUPPORTED_REQUEST = 0x02
ERR_UNKNOWN = 0x03
ERR_INVALID_PARAMETER = 0x04
SUCCESS = 0xFF
class GetSetStatus:
def __init__(self) -> None:
self.data = bytearray()
self.status = 0
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE)
get_report_cb: Optional[Callable[[int, int, int], None]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
get_protocol_cb: Optional[Callable[[], None]] = None
set_protocol_cb: Optional[Callable[[int], None]] = None
@override
def on_ctrl_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
param = pdu[0] & 0x0F
message_type = pdu[0] >> 4
if message_type == Message.MessageType.GET_REPORT:
logger.debug('<<< HID GET REPORT')
self.handle_get_report(pdu)
elif message_type == Message.MessageType.SET_REPORT:
logger.debug('<<< HID SET REPORT')
self.handle_set_report(pdu)
elif message_type == Message.MessageType.GET_PROTOCOL:
logger.debug('<<< HID GET PROTOCOL')
self.handle_get_protocol(pdu)
elif message_type == Message.MessageType.SET_PROTOCOL:
logger.debug('<<< HID SET PROTOCOL')
self.handle_set_protocol(pdu)
elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA')
self.emit('control_data', pdu)
elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.SUSPEND:
logger.debug('<<< HID SUSPEND')
self.emit('suspend')
elif param == Message.ControlCommand.EXIT_SUSPEND:
logger.debug('<<< HID EXIT SUSPEND')
self.emit('exit_suspend')
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug')
else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else:
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def send_handshake_message(self, result_code: int) -> None:
msg = SendHandshakeMessage(result_code)
hid_message = bytes(msg)
logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def send_control_data(self, report_type: int, data: bytes):
msg = SendControlData(report_type=report_type, data=data)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def handle_get_report(self, pdu: bytes):
if self.get_report_cb is None:
logger.debug("GetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
buffer_flag = (pdu[0] & 0x08) >> 3
report_id = pdu[1]
logger.debug(f"buffer_flag: {buffer_flag}")
if buffer_flag == 1:
buffer_size = (pdu[3] << 8) | pdu[2]
else:
buffer_size = 0
ret = self.get_report_cb(report_id, report_type, buffer_size)
assert ret is not None
if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS:
data = bytearray()
data.append(report_id)
data.extend(ret.data)
if len(data) < self.l2cap_ctrl_channel.peer_mtu: # type: ignore[union-attr]
self.send_control_data(report_type=report_type, data=data)
else:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
self.get_report_cb = cb
logger.debug("GetReport callback registered successfully")
def handle_set_report(self, pdu: bytes):
if self.set_report_cb is None:
logger.debug("SetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
report_id = pdu[1]
report_data = pdu[2:]
report_size = len(report_data) + 1
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb(
self, cb: Callable[[int, int, int, bytes], None]
) -> None:
self.set_report_cb = cb
logger.debug("SetReport callback registered successfully")
def handle_get_protocol(self, pdu: bytes):
if self.get_protocol_cb is None:
logger.debug("GetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.get_protocol_cb()
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully")
def handle_set_protocol(self, pdu: bytes):
if self.set_protocol_cb is None:
logger.debug("SetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.set_protocol_cb(pdu[0] & 0x01)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
self.set_protocol_cb = cb
logger.debug("SetProtocol callback registered successfully")
# -----------------------------------------------------------------------------
class Host(HID):
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.HOST)
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
msg = GetReportMessage(
report_type=report_type, report_id=report_id, buffer_size=buffer_size
)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def set_report(self, report_type: int, data: bytes) -> None:
msg = SetReportMessage(report_type=report_type, data=data)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def get_protocol(self) -> None:
msg = GetProtocolMessage()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def set_protocol(self, protocol_mode: int) -> None:
msg = SetProtocolMessage(protocol_mode=protocol_mode)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def suspend(self) -> None:
msg = Suspend()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def exit_suspend(self) -> None:
msg = ExitSuspend()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
@override
def on_ctrl_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
param = pdu[0] & 0x0F
message_type = pdu[0] >> 4
if message_type == Message.MessageType.HANDSHAKE:
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
self.emit('handshake', Message.Handshake(param))
elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA')
self.emit('control_data', pdu)
elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug')
else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else:
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')

File diff suppressed because it is too large Load Diff

View File

@@ -17,7 +17,6 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import dataclasses
import enum
import logging
import struct
@@ -39,7 +38,6 @@ from typing import (
TYPE_CHECKING,
)
from .utils import deprecated
from .colors import color
from .core import BT_CENTRAL_ROLE, InvalidStateError, ProtocolError
from .hci import (
@@ -149,10 +147,9 @@ L2CAP_INVALID_CID_IN_REQUEST_REASON = 0x0002
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2046
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256
@@ -170,37 +167,6 @@ L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE = 0x01
# pylint: disable=invalid-name
@dataclasses.dataclass
class ClassicChannelSpec:
psm: Optional[int] = None
mtu: int = L2CAP_DEFAULT_MTU
@dataclasses.dataclass
class LeCreditBasedChannelSpec:
psm: Optional[int] = None
mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU
mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS
def __post_init__(self):
if (
self.max_credits < 1
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
):
raise ValueError('max credits out of range')
if (
self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
):
raise ValueError('MTU out of range')
if (
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
):
raise ValueError('MPS out of range')
class L2CAP_PDU:
'''
See Bluetooth spec @ Vol 3, Part A - 3 DATA PACKET FORMAT
@@ -208,7 +174,7 @@ class L2CAP_PDU:
@staticmethod
def from_bytes(data: bytes) -> L2CAP_PDU:
# Check parameters
# Sanity check
if len(data) < 4:
raise ValueError('not enough data for L2CAP header')
@@ -395,9 +361,6 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST
'''
psm: int
source_cid: int
@staticmethod
def parse_psm(data: bytes, offset: int = 0) -> Tuple[int, int]:
psm_length = 2
@@ -439,11 +402,6 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame):
See Bluetooth spec @ Vol 3, Part A - 4.3 CONNECTION RESPONSE
'''
source_cid: int
destination_cid: int
status: int
result: int
CONNECTION_SUCCESSFUL = 0x0000
CONNECTION_PENDING = 0x0001
CONNECTION_REFUSED_PSM_NOT_SUPPORTED = 0x0002
@@ -718,7 +676,7 @@ class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame):
# -----------------------------------------------------------------------------
class ClassicChannel(EventEmitter):
class Channel(EventEmitter):
class State(enum.IntEnum):
# States
CLOSED = 0x00
@@ -749,8 +707,6 @@ class ClassicChannel(EventEmitter):
sink: Optional[Callable[[bytes], Any]]
state: State
connection: Connection
mtu: int
peer_mtu: int
def __init__(
self,
@@ -767,7 +723,6 @@ class ClassicChannel(EventEmitter):
self.signaling_cid = signaling_cid
self.state = self.State.CLOSED
self.mtu = mtu
self.peer_mtu = L2CAP_MIN_BR_EDR_MTU
self.psm = psm
self.source_cid = source_cid
self.destination_cid = 0
@@ -864,7 +819,7 @@ class ClassicChannel(EventEmitter):
[
(
L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE,
struct.pack('<H', self.mtu),
struct.pack('<H', L2CAP_DEFAULT_MTU),
)
]
)
@@ -929,8 +884,8 @@ class ClassicChannel(EventEmitter):
options = L2CAP_Control_Frame.decode_configuration_options(request.options)
for option in options:
if option[0] == L2CAP_MTU_CONFIGURATION_PARAMETER_TYPE:
self.peer_mtu = struct.unpack('<H', option[1])[0]
logger.debug(f'peer MTU = {self.peer_mtu}')
self.mtu = struct.unpack('<H', option[1])[0]
logger.debug(f'MTU = {self.mtu}')
self.send_control_frame(
L2CAP_Configure_Response(
@@ -1029,13 +984,13 @@ class ClassicChannel(EventEmitter):
return (
f'Channel({self.source_cid}->{self.destination_cid}, '
f'PSM={self.psm}, '
f'MTU={self.mtu}/{self.peer_mtu}, '
f'MTU={self.mtu}, '
f'state={self.state.name})'
)
# -----------------------------------------------------------------------------
class LeCreditBasedChannel(EventEmitter):
class LeConnectionOrientedChannel(EventEmitter):
"""
LE Credit-based Connection Oriented Channel
"""
@@ -1049,13 +1004,11 @@ class LeCreditBasedChannel(EventEmitter):
CONNECTION_ERROR = 5
out_queue: Deque[bytes]
connection_result: Optional[asyncio.Future[LeCreditBasedChannel]]
connection_result: Optional[asyncio.Future[LeConnectionOrientedChannel]]
disconnection_result: Optional[asyncio.Future[None]]
in_sdu: Optional[bytes]
out_sdu: Optional[bytes]
state: State
connection: Connection
sink: Optional[Callable[[bytes], Any]]
def __init__(
self,
@@ -1118,7 +1071,7 @@ class LeCreditBasedChannel(EventEmitter):
def send_control_frame(self, frame: L2CAP_Control_Frame) -> None:
self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame)
async def connect(self) -> LeCreditBasedChannel:
async def connect(self) -> LeConnectionOrientedChannel:
# Check that we're in the right state
if self.state != self.State.INIT:
raise InvalidStateError('not in a connectable state')
@@ -1389,67 +1342,15 @@ class LeCreditBasedChannel(EventEmitter):
)
# -----------------------------------------------------------------------------
class ClassicChannelServer(EventEmitter):
def __init__(
self,
manager: ChannelManager,
psm: int,
handler: Optional[Callable[[ClassicChannel], Any]],
mtu: int,
) -> None:
super().__init__()
self.manager = manager
self.handler = handler
self.psm = psm
self.mtu = mtu
def on_connection(self, channel: ClassicChannel) -> None:
self.emit('connection', channel)
if self.handler:
self.handler(channel)
def close(self) -> None:
if self.psm in self.manager.servers:
del self.manager.servers[self.psm]
# -----------------------------------------------------------------------------
class LeCreditBasedChannelServer(EventEmitter):
def __init__(
self,
manager: ChannelManager,
psm: int,
handler: Optional[Callable[[LeCreditBasedChannel], Any]],
max_credits: int,
mtu: int,
mps: int,
) -> None:
super().__init__()
self.manager = manager
self.handler = handler
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
def on_connection(self, channel: LeCreditBasedChannel) -> None:
self.emit('connection', channel)
if self.handler:
self.handler(channel)
def close(self) -> None:
if self.psm in self.manager.le_coc_servers:
del self.manager.le_coc_servers[self.psm]
# -----------------------------------------------------------------------------
class ChannelManager:
identifiers: Dict[int, int]
channels: Dict[int, Dict[int, Union[ClassicChannel, LeCreditBasedChannel]]]
servers: Dict[int, ClassicChannelServer]
le_coc_channels: Dict[int, Dict[int, LeCreditBasedChannel]]
le_coc_servers: Dict[int, LeCreditBasedChannelServer]
channels: Dict[int, Dict[int, Union[Channel, LeConnectionOrientedChannel]]]
servers: Dict[int, Callable[[Channel], Any]]
le_coc_channels: Dict[int, Dict[int, LeConnectionOrientedChannel]]
le_coc_servers: Dict[
int, Tuple[Callable[[LeConnectionOrientedChannel], Any], int, int, int]
]
le_coc_requests: Dict[int, L2CAP_LE_Credit_Based_Connection_Request]
fixed_channels: Dict[int, Optional[Callable[[int, bytes], Any]]]
_host: Optional[Host]
@@ -1528,6 +1429,21 @@ class ChannelManager:
raise RuntimeError('no free CID')
@staticmethod
def check_le_coc_parameters(max_credits: int, mtu: int, mps: int) -> None:
if (
max_credits < 1
or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
):
raise ValueError('max credits out of range')
if mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
raise ValueError('MTU too small')
if (
mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
or mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
):
raise ValueError('MPS out of range')
def next_identifier(self, connection: Connection) -> int:
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
self.identifiers[connection.handle] = identifier
@@ -1542,22 +1458,8 @@ class ChannelManager:
if cid in self.fixed_channels:
del self.fixed_channels[cid]
@deprecated("Please use create_classic_server")
def register_server(
self,
psm: int,
server: Callable[[ClassicChannel], Any],
) -> int:
return self.create_classic_server(
handler=server, spec=ClassicChannelSpec(psm=psm)
).psm
def create_classic_server(
self,
spec: ClassicChannelSpec,
handler: Optional[Callable[[ClassicChannel], Any]] = None,
) -> ClassicChannelServer:
if not spec.psm:
def register_server(self, psm: int, server: Callable[[Channel], Any]) -> int:
if psm == 0:
# Find a free PSM
for candidate in range(
L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2
@@ -1566,75 +1468,62 @@ class ChannelManager:
continue
if candidate in self.servers:
continue
spec.psm = candidate
psm = candidate
break
else:
raise InvalidStateError('no free PSM')
else:
# Check that the PSM isn't already in use
if spec.psm in self.servers:
if psm in self.servers:
raise ValueError('PSM already in use')
# Check that the PSM is valid
if spec.psm % 2 == 0:
if psm % 2 == 0:
raise ValueError('invalid PSM (not odd)')
check = spec.psm >> 8
check = psm >> 8
while check:
if check % 2 != 0:
raise ValueError('invalid PSM')
check >>= 8
self.servers[spec.psm] = ClassicChannelServer(self, spec.psm, handler, spec.mtu)
self.servers[psm] = server
return self.servers[spec.psm]
return psm
@deprecated("Please use create_le_credit_based_server()")
def register_le_coc_server(
self,
psm: int,
server: Callable[[LeCreditBasedChannel], Any],
max_credits: int,
mtu: int,
mps: int,
server: Callable[[LeConnectionOrientedChannel], Any],
max_credits: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
mtu: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps: int = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
) -> int:
return self.create_le_credit_based_server(
spec=LeCreditBasedChannelSpec(
psm=None if psm == 0 else psm, mtu=mtu, mps=mps, max_credits=max_credits
),
handler=server,
).psm
self.check_le_coc_parameters(max_credits, mtu, mps)
def create_le_credit_based_server(
self,
spec: LeCreditBasedChannelSpec,
handler: Optional[Callable[[LeCreditBasedChannel], Any]] = None,
) -> LeCreditBasedChannelServer:
if not spec.psm:
if psm == 0:
# Find a free PSM
for candidate in range(
L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1
):
if candidate in self.le_coc_servers:
continue
spec.psm = candidate
psm = candidate
break
else:
raise InvalidStateError('no free PSM')
else:
# Check that the PSM isn't already in use
if spec.psm in self.le_coc_servers:
if psm in self.le_coc_servers:
raise ValueError('PSM already in use')
self.le_coc_servers[spec.psm] = LeCreditBasedChannelServer(
self,
spec.psm,
handler,
max_credits=spec.max_credits,
mtu=spec.mtu,
mps=spec.mps,
self.le_coc_servers[psm] = (
server,
max_credits or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
mtu or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
)
return self.le_coc_servers[spec.psm]
return psm
def on_disconnection(self, connection_handle: int, _reason: int) -> None:
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
@@ -1651,13 +1540,12 @@ class ChannelManager:
def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None:
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
pdu_bytes = bytes(pdu)
logger.debug(
f'{color(">>> Sending L2CAP PDU", "blue")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
f'{connection.peer_address}: {len(pdu_bytes)} bytes, {pdu_str}'
f'{connection.peer_address}: {pdu_str}'
)
self.host.send_l2cap_pdu(connection.handle, cid, pdu_bytes)
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
@@ -1762,13 +1650,13 @@ class ChannelManager:
logger.debug(
f'creating server channel with cid={source_cid} for psm {request.psm}'
)
channel = ClassicChannel(
self, connection, cid, request.psm, source_cid, server.mtu
channel = Channel(
self, connection, cid, request.psm, source_cid, L2CAP_MIN_BR_EDR_MTU
)
connection_channels[source_cid] = channel
# Notify
server.on_connection(channel)
server(channel)
channel.on_connection_request(request)
else:
logger.warning(
@@ -1934,7 +1822,7 @@ class ChannelManager:
supervision_timeout=request.timeout,
min_ce_length=0,
max_ce_length=0,
)
) # type: ignore[call-arg]
)
else:
self.send_control_frame(
@@ -1990,7 +1878,7 @@ class ChannelManager:
self, connection: Connection, cid: int, request
) -> None:
if request.le_psm in self.le_coc_servers:
server = self.le_coc_servers[request.le_psm]
(server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm]
# Check that the CID isn't already used
le_connection_channels = self.le_coc_channels.setdefault(
@@ -2004,8 +1892,8 @@ class ChannelManager:
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=server.mtu,
mps=server.mps,
mtu=mtu,
mps=mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED,
@@ -2023,8 +1911,8 @@ class ChannelManager:
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=0,
mtu=server.mtu,
mps=server.mps,
mtu=mtu,
mps=mps,
initial_credits=0,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
@@ -2037,18 +1925,18 @@ class ChannelManager:
f'creating LE CoC server channel with cid={source_cid} for psm '
f'{request.le_psm}'
)
channel = LeCreditBasedChannel(
channel = LeConnectionOrientedChannel(
self,
connection,
request.le_psm,
source_cid,
request.source_cid,
server.mtu,
server.mps,
mtu,
mps,
request.initial_credits,
request.mtu,
request.mps,
server.max_credits,
max_credits,
True,
)
connection_channels[source_cid] = channel
@@ -2061,16 +1949,16 @@ class ChannelManager:
L2CAP_LE_Credit_Based_Connection_Response(
identifier=request.identifier,
destination_cid=source_cid,
mtu=server.mtu,
mps=server.mps,
initial_credits=server.max_credits,
mtu=mtu,
mps=mps,
initial_credits=max_credits,
# pylint: disable=line-too-long
result=L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL,
),
)
# Notify
server.on_connection(channel)
server(channel)
else:
logger.info(
f'No LE server for connection 0x{connection.handle:04X} '
@@ -2125,51 +2013,37 @@ class ChannelManager:
channel.on_credits(credit.credits)
def on_channel_closed(self, channel: ClassicChannel) -> None:
def on_channel_closed(self, channel: Channel) -> None:
connection_channels = self.channels.get(channel.connection.handle)
if connection_channels:
if channel.source_cid in connection_channels:
del connection_channels[channel.source_cid]
@deprecated("Please use create_le_credit_based_channel()")
async def open_le_coc(
self, connection: Connection, psm: int, max_credits: int, mtu: int, mps: int
) -> LeCreditBasedChannel:
return await self.create_le_credit_based_channel(
connection=connection,
spec=LeCreditBasedChannelSpec(
psm=psm, max_credits=max_credits, mtu=mtu, mps=mps
),
)
) -> LeConnectionOrientedChannel:
self.check_le_coc_parameters(max_credits, mtu, mps)
async def create_le_credit_based_channel(
self,
connection: Connection,
spec: LeCreditBasedChannelSpec,
) -> LeCreditBasedChannel:
# Find a free CID for the new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use')
if spec.psm is None:
raise ValueError('PSM cannot be None')
# Create the channel
logger.debug(f'creating coc channel with cid={source_cid} for psm {spec.psm}')
channel = LeCreditBasedChannel(
logger.debug(f'creating coc channel with cid={source_cid} for psm {psm}')
channel = LeConnectionOrientedChannel(
manager=self,
connection=connection,
le_psm=spec.psm,
le_psm=psm,
source_cid=source_cid,
destination_cid=0,
mtu=spec.mtu,
mps=spec.mps,
mtu=mtu,
mps=mps,
credits=0,
peer_mtu=0,
peer_mps=0,
peer_credits=spec.max_credits,
peer_credits=max_credits,
connected=False,
)
connection_channels[source_cid] = channel
@@ -2188,15 +2062,7 @@ class ChannelManager:
return channel
@deprecated("Please use create_classic_channel()")
async def connect(self, connection: Connection, psm: int) -> ClassicChannel:
return await self.create_classic_channel(
connection=connection, spec=ClassicChannelSpec(psm=psm)
)
async def create_classic_channel(
self, connection: Connection, spec: ClassicChannelSpec
) -> ClassicChannel:
async def connect(self, connection: Connection, psm: int) -> Channel:
# NOTE: this implementation hard-codes BR/EDR
# Find a free CID for a new channel
@@ -2205,20 +2071,10 @@ class ChannelManager:
if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use')
if spec.psm is None:
raise ValueError('PSM cannot be None')
# Create the channel
logger.debug(
f'creating client channel with cid={source_cid} for psm {spec.psm}'
)
channel = ClassicChannel(
self,
connection,
L2CAP_SIGNALING_CID,
spec.psm,
source_cid,
spec.mtu,
logger.debug(f'creating client channel with cid={source_cid} for psm {psm}')
channel = Channel(
self, connection, L2CAP_SIGNALING_CID, psm, source_cid, L2CAP_MIN_BR_EDR_MTU
)
connection_channels[source_cid] = channel
@@ -2230,20 +2086,3 @@ class ChannelManager:
raise e
return channel
# -----------------------------------------------------------------------------
# Deprecated Classes
# -----------------------------------------------------------------------------
class Channel(ClassicChannel):
@deprecated("Please use ClassicChannel")
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
class LeConnectionOrientedChannel(LeCreditBasedChannel):
@deprecated("Please use LeCreditBasedChannel")
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

View File

@@ -26,13 +26,9 @@ from bumble.hci import (
HCI_SUCCESS,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_CONNECTION_TIMEOUT_ERROR,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
HCI_PAGE_TIMEOUT_ERROR,
HCI_Connection_Complete_Event,
)
from bumble import controller
from typing import Optional, Set
# -----------------------------------------------------------------------------
# Logging
@@ -61,8 +57,6 @@ class LocalLink:
Link bus for controllers to communicate with each other
'''
controllers: Set[controller.Controller]
def __init__(self):
self.controllers = set()
self.pending_connection = None
@@ -85,9 +79,7 @@ class LocalLink:
return controller
return None
def find_classic_controller(
self, address: Address
) -> Optional[controller.Controller]:
def find_classic_controller(self, address):
for controller in self.controllers:
if controller.public_address == address:
return controller
@@ -196,60 +188,6 @@ class LocalLink:
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
def create_cis(
self,
central_controller: controller.Controller,
peripheral_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
)
if peripheral_controller := self.find_controller(peripheral_address):
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_request,
central_controller.random_address,
cig_id,
cis_id,
)
def accept_cis(
self,
peripheral_controller: controller.Controller,
central_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
)
if central_controller := self.find_controller(central_address):
asyncio.get_running_loop().call_soon(
central_controller.on_link_cis_established, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_established, cig_id, cis_id
)
def disconnect_cis(
self,
initiator_controller: controller.Controller,
peer_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
)
if peer_controller := self.find_controller(peer_address):
asyncio.get_running_loop().call_soon(
initiator_controller.on_link_cis_disconnected, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peer_controller.on_link_cis_disconnected, cig_id, cis_id
)
############################################################
# Classic handlers
############################################################
@@ -333,52 +271,6 @@ class LocalLink:
initiator_controller.public_address, int(not (initiator_new_role))
)
def classic_sco_connect(
self,
initiator_controller: controller.Controller,
responder_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
# Initiator controller should handle it.
assert responder_controller
responder_controller.on_classic_connection_request(
initiator_controller.public_address,
link_type,
)
def classic_accept_sco_connection(
self,
responder_controller: controller.Controller,
initiator_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
)
initiator_controller = self.find_classic_controller(initiator_address)
if initiator_controller is None:
responder_controller.on_classic_sco_connection_complete(
responder_controller.public_address,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
link_type,
)
return
async def task():
initiator_controller.on_classic_sco_connection_complete(
responder_controller.public_address, HCI_SUCCESS, link_type
)
asyncio.create_task(task())
responder_controller.on_classic_sco_connection_complete(
initiator_controller.public_address, HCI_SUCCESS, link_type
)
# -----------------------------------------------------------------------------
class RemoteLink:

View File

@@ -15,9 +15,7 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
from dataclasses import dataclass
from typing import Optional, Tuple
from .hci import (
@@ -37,60 +35,7 @@ from .smp import (
SMP_ID_KEY_DISTRIBUTION_FLAG,
SMP_SIGN_KEY_DISTRIBUTION_FLAG,
SMP_LINK_KEY_DISTRIBUTION_FLAG,
OobContext,
OobLegacyContext,
OobSharedData,
)
from .core import AdvertisingData, LeRole
# -----------------------------------------------------------------------------
@dataclass
class OobData:
"""OOB data that can be sent from one device to another."""
address: Optional[Address] = None
role: Optional[LeRole] = None
shared_data: Optional[OobSharedData] = None
legacy_context: Optional[OobLegacyContext] = None
@classmethod
def from_ad(cls, ad: AdvertisingData) -> OobData:
instance = cls()
shared_data_c: Optional[bytes] = None
shared_data_r: Optional[bytes] = None
for ad_type, ad_data in ad.ad_structures:
if ad_type == AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS:
instance.address = Address(ad_data)
elif ad_type == AdvertisingData.LE_ROLE:
instance.role = LeRole(ad_data[0])
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE:
shared_data_c = ad_data
elif ad_type == AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE:
shared_data_r = ad_data
elif ad_type == AdvertisingData.SECURITY_MANAGER_TK_VALUE:
instance.legacy_context = OobLegacyContext(tk=ad_data)
if shared_data_c and shared_data_r:
instance.shared_data = OobSharedData(c=shared_data_c, r=shared_data_r)
return instance
def to_ad(self) -> AdvertisingData:
ad_structures = []
if self.address is not None:
ad_structures.append(
(AdvertisingData.LE_BLUETOOTH_DEVICE_ADDRESS, bytes(self.address))
)
if self.role is not None:
ad_structures.append((AdvertisingData.LE_ROLE, bytes([self.role])))
if self.shared_data is not None:
ad_structures.extend(self.shared_data.to_ad().ad_structures)
if self.legacy_context is not None:
ad_structures.append(
(AdvertisingData.SECURITY_MANAGER_TK_VALUE, self.legacy_context.tk)
)
return AdvertisingData(ad_structures)
# -----------------------------------------------------------------------------
@@ -228,14 +173,6 @@ class PairingConfig:
PUBLIC = Address.PUBLIC_DEVICE_ADDRESS
RANDOM = Address.RANDOM_DEVICE_ADDRESS
@dataclass
class OobConfig:
"""Config for OOB pairing."""
our_context: Optional[OobContext]
peer_data: Optional[OobSharedData]
legacy_context: Optional[OobLegacyContext]
def __init__(
self,
sc: bool = True,
@@ -243,20 +180,17 @@ class PairingConfig:
bonding: bool = True,
delegate: Optional[PairingDelegate] = None,
identity_address_type: Optional[AddressType] = None,
oob: Optional[OobConfig] = None,
) -> None:
self.sc = sc
self.mitm = mitm
self.bonding = bonding
self.delegate = delegate or PairingDelegate()
self.identity_address_type = identity_address_type
self.oob = oob
def __str__(self) -> str:
return (
f'PairingConfig(sc={self.sc}, '
f'mitm={self.mitm}, bonding={self.bonding}, '
f'identity_address_type={self.identity_address_type}, '
f'delegate[{self.delegate.io_capability}]), '
f'oob[{self.oob}])'
f'delegate[{self.delegate.io_capability}])'
)

View File

@@ -24,8 +24,10 @@ import grpc.aio
from .config import Config
from .device import PandoraDevice
from .asha import AshaService
from .host import HostService
from .security import SecurityService, SecurityStorageService
from pandora.asha_grpc_aio import add_ASHAServicer_to_server
from pandora.host_grpc_aio import add_HostServicer_to_server
from pandora.security_grpc_aio import (
add_SecurityServicer_to_server,
@@ -68,6 +70,7 @@ async def serve(
config.load_from_dict(bumble.config.get('server', {}))
# add Pandora services to the gRPC server.
add_ASHAServicer_to_server(AshaService(bumble.device), server)
add_HostServicer_to_server(
HostService(server, bumble.device, config), server
)

96
bumble/pandora/asha.py Normal file
View File

@@ -0,0 +1,96 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import grpc
import logging
from bumble.decoder import G722Decoder
from bumble.device import Connection, Device
from bumble.pandora import utils
from bumble.profiles import asha_service
from google.protobuf.empty_pb2 import Empty # pytype: disable=pyi-error
from pandora.asha_grpc_aio import ASHAServicer
from pandora.asha_pb2 import CaptureAudioRequest, CaptureAudioResponse, RegisterRequest
from typing import AsyncGenerator, Optional
class AshaService(ASHAServicer):
DECODE_FRAME_LENGTH = 80
device: Device
asha_service: Optional[asha_service.AshaService]
def __init__(self, device: Device) -> None:
self.log = utils.BumbleServerLoggerAdapter(
logging.getLogger(), {"service_name": "Asha", "device": device}
)
self.device = device
self.asha_service = None
@utils.rpc
async def Register(
self, request: RegisterRequest, context: grpc.ServicerContext
) -> Empty:
logging.info("Register")
if self.asha_service:
self.asha_service.capability = request.capability
self.asha_service.hisyncid = request.hisyncid
else:
self.asha_service = asha_service.AshaService(
request.capability, request.hisyncid, self.device
)
self.device.add_service(self.asha_service) # type: ignore[no-untyped-call]
return Empty()
@utils.rpc
async def CaptureAudio(
self, request: CaptureAudioRequest, context: grpc.ServicerContext
) -> AsyncGenerator[CaptureAudioResponse, None]:
connection_handle = int.from_bytes(request.connection.cookie.value, "big")
logging.info(f"CaptureAudioData connection_handle:{connection_handle}")
if not (connection := self.device.lookup_connection(connection_handle)):
raise RuntimeError(
f"Unknown connection for connection_handle:{connection_handle}"
)
decoder = G722Decoder() # type: ignore
queue: asyncio.Queue[bytes] = asyncio.Queue()
def on_data(asha_connection: Connection, data: bytes) -> None:
if asha_connection == connection:
queue.put_nowait(data)
self.asha_service.on("data", on_data) # type: ignore
try:
while data := await queue.get():
output_bytes = bytearray()
# First byte is sequence number, last 160 bytes are audio payload.
audio_payload = data[1:]
data_length = int(len(audio_payload) / AshaService.DECODE_FRAME_LENGTH)
for i in range(0, data_length):
input_data = audio_payload[
i
* AshaService.DECODE_FRAME_LENGTH : i
* AshaService.DECODE_FRAME_LENGTH
+ AshaService.DECODE_FRAME_LENGTH
]
decoded_data = decoder.decode_frame(input_data)
output_bytes.extend(decoded_data)
yield CaptureAudioResponse(data=bytes(output_bytes))
finally:
self.asha_service.remove_listener("data", on_data) # type: ignore

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from bumble.pairing import PairingConfig, PairingDelegate
from dataclasses import dataclass
from typing import Any, Dict

View File

@@ -14,7 +14,6 @@
"""Generic & dependency free Bumble (reference) device."""
from __future__ import annotations
from bumble import transport
from bumble.core import (
BT_GENERIC_AUDIO_SERVICE,

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import bumble.device
import grpc
@@ -34,11 +33,8 @@ from bumble.device import (
DEVICE_DEFAULT_SCAN_INTERVAL,
DEVICE_DEFAULT_SCAN_WINDOW,
Advertisement,
AdvertisingParameters,
AdvertisingEventProperties,
AdvertisingType,
Device,
Phy,
)
from bumble.gatt import Service
from bumble.hci import (
@@ -50,12 +46,9 @@ from bumble.hci import (
from google.protobuf import any_pb2 # pytype: disable=pyi-error
from google.protobuf import empty_pb2 # pytype: disable=pyi-error
from pandora.host_grpc_aio import HostServicer
from pandora import host_pb2
from pandora.host_pb2 import (
NOT_CONNECTABLE,
NOT_DISCOVERABLE,
DISCOVERABLE_LIMITED,
DISCOVERABLE_GENERAL,
PRIMARY_1M,
PRIMARY_CODED,
SECONDARY_1M,
@@ -71,7 +64,6 @@ from pandora.host_pb2 import (
ConnectResponse,
DataTypes,
DisconnectRequest,
DiscoverabilityMode,
InquiryResponse,
PrimaryPhy,
ReadLocalAddressResponse,
@@ -101,25 +93,6 @@ SECONDARY_PHY_MAP: Dict[int, SecondaryPhy] = {
3: SECONDARY_CODED,
}
PRIMARY_PHY_TO_BUMBLE_PHY_MAP: Dict[PrimaryPhy, Phy] = {
PRIMARY_1M: Phy.LE_1M,
PRIMARY_CODED: Phy.LE_CODED,
}
SECONDARY_PHY_TO_BUMBLE_PHY_MAP: Dict[SecondaryPhy, Phy] = {
SECONDARY_NONE: Phy.LE_1M,
SECONDARY_1M: Phy.LE_1M,
SECONDARY_2M: Phy.LE_2M,
SECONDARY_CODED: Phy.LE_CODED,
}
OWN_ADDRESS_MAP: Dict[host_pb2.OwnAddressType, bumble.hci.OwnAddressType] = {
host_pb2.PUBLIC: bumble.hci.OwnAddressType.PUBLIC,
host_pb2.RANDOM: bumble.hci.OwnAddressType.RANDOM,
host_pb2.RESOLVABLE_OR_PUBLIC: bumble.hci.OwnAddressType.RESOLVABLE_OR_PUBLIC,
host_pb2.RESOLVABLE_OR_RANDOM: bumble.hci.OwnAddressType.RESOLVABLE_OR_RANDOM,
}
class HostService(HostServicer):
waited_connections: Set[int]
@@ -307,118 +280,14 @@ class HostService(HostServicer):
async def Advertise(
self, request: AdvertiseRequest, context: grpc.ServicerContext
) -> AsyncGenerator[AdvertiseResponse, None]:
try:
if request.legacy:
async for rsp in self.legacy_advertise(request, context):
yield rsp
else:
async for rsp in self.extended_advertise(request, context):
yield rsp
finally:
pass
async def extended_advertise(
self, request: AdvertiseRequest, context: grpc.ServicerContext
) -> AsyncGenerator[AdvertiseResponse, None]:
advertising_data = bytes(self.unpack_data_types(request.data))
scan_response_data = bytes(self.unpack_data_types(request.scan_response_data))
scannable = len(scan_response_data) != 0
advertising_event_properties = AdvertisingEventProperties(
is_connectable=request.connectable,
is_scannable=scannable,
is_directed=request.target is not None,
is_high_duty_cycle_directed_connectable=False,
is_legacy=False,
is_anonymous=False,
include_tx_power=False,
)
peer_address = Address.ANY
if request.target:
# Need to reverse bytes order since Bumble Address is using MSB.
target_bytes = bytes(reversed(request.target))
if request.target_variant() == "public":
peer_address = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
else:
peer_address = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
advertising_parameters = AdvertisingParameters(
advertising_event_properties=advertising_event_properties,
own_address_type=OWN_ADDRESS_MAP[request.own_address_type],
peer_address=peer_address,
primary_advertising_phy=PRIMARY_PHY_TO_BUMBLE_PHY_MAP[request.primary_phy],
secondary_advertising_phy=SECONDARY_PHY_TO_BUMBLE_PHY_MAP[
request.secondary_phy
],
)
if advertising_interval := request.interval:
advertising_parameters.primary_advertising_interval_min = int(
advertising_interval
if not request.legacy:
raise NotImplementedError(
"TODO: add support for extended advertising in Bumble"
)
advertising_parameters.primary_advertising_interval_max = int(
advertising_interval
)
if interval_range := request.interval_range:
advertising_parameters.primary_advertising_interval_max += int(
interval_range
)
advertising_set = await self.device.create_advertising_set(
advertising_parameters=advertising_parameters,
advertising_data=advertising_data,
scan_response_data=scan_response_data,
)
pending_connection: asyncio.Future[
bumble.device.Connection
] = asyncio.get_running_loop().create_future()
if request.connectable:
def on_connection(connection: bumble.device.Connection) -> None:
if (
connection.transport == BT_LE_TRANSPORT
and connection.role == BT_PERIPHERAL_ROLE
):
pending_connection.set_result(connection)
self.device.on('connection', on_connection)
try:
# Advertise until RPC is canceled
while True:
if not advertising_set.enabled:
self.log.debug('Advertise (extended)')
await advertising_set.start()
if not request.connectable:
await asyncio.sleep(1)
continue
connection = await pending_connection
pending_connection = asyncio.get_running_loop().create_future()
cookie = any_pb2.Any(value=connection.handle.to_bytes(4, 'big'))
yield AdvertiseResponse(connection=Connection(cookie=cookie))
await asyncio.sleep(1)
finally:
try:
self.log.debug('Stop Advertise (extended)')
await advertising_set.stop()
await advertising_set.remove()
except Exception:
pass
async def legacy_advertise(
self, request: AdvertiseRequest, context: grpc.ServicerContext
) -> AsyncGenerator[AdvertiseResponse, None]:
if advertising_interval := request.interval:
self.device.config.advertising_interval_min = int(advertising_interval)
self.device.config.advertising_interval_max = int(advertising_interval)
if interval_range := request.interval_range:
self.device.config.advertising_interval_max += int(interval_range)
if request.interval:
raise NotImplementedError("TODO: add support for `request.interval`")
if request.interval_range:
raise NotImplementedError("TODO: add support for `request.interval_range`")
if request.primary_phy:
raise NotImplementedError("TODO: add support for `request.primary_phy`")
if request.secondary_phy:
@@ -486,10 +355,14 @@ class HostService(HostServicer):
target_bytes = bytes(reversed(request.target))
if request.target_variant() == "public":
target = Address(target_bytes, Address.PUBLIC_DEVICE_ADDRESS)
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY
advertising_type = (
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
) # FIXME: HIGH_DUTY ?
else:
target = Address(target_bytes, Address.RANDOM_DEVICE_ADDRESS)
advertising_type = AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY
advertising_type = (
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY
) # FIXME: HIGH_DUTY ?
if request.connectable:
@@ -547,15 +420,10 @@ class HostService(HostServicer):
self, request: ScanRequest, context: grpc.ServicerContext
) -> AsyncGenerator[ScanningResponse, None]:
# TODO: modify `start_scanning` to accept floats instead of int for ms values
self.log.debug('Scan')
if request.phys:
raise NotImplementedError("TODO: add support for `request.phys`")
scanning_phys = []
if PRIMARY_1M in request.phys:
scanning_phys.append(int(Phy.LE_1M))
if PRIMARY_CODED in request.phys:
scanning_phys.append(int(Phy.LE_CODED))
if not scanning_phys:
scanning_phys = [int(Phy.LE_1M), int(Phy.LE_CODED)]
self.log.debug('Scan')
scan_queue: asyncio.Queue[Advertisement] = asyncio.Queue()
handler = self.device.on('advertisement', scan_queue.put_nowait)
@@ -569,7 +437,6 @@ class HostService(HostServicer):
scan_window=int(request.window)
if request.window
else DEVICE_DEFAULT_SCAN_WINDOW,
scanning_phys=scanning_phys,
)
try:
@@ -866,16 +733,6 @@ class HostService(HostServicer):
)
)
flag_map = {
NOT_DISCOVERABLE: 0x00,
DISCOVERABLE_LIMITED: AdvertisingData.LE_LIMITED_DISCOVERABLE_MODE_FLAG,
DISCOVERABLE_GENERAL: AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG,
}
if dt.le_discoverability_mode:
flags = flag_map[dt.le_discoverability_mode]
ad_structures.append((AdvertisingData.FLAGS, flags.to_bytes(1, 'big')))
return AdvertisingData(ad_structures)
def pack_data_types(self, ad: AdvertisingData) -> DataTypes:

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import contextlib
import grpc
@@ -110,7 +109,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore
answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm
@@ -125,7 +124,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(numeric_comparison=number))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore
answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm
@@ -140,7 +139,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore
answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event
if answer.answer_variant() is None:
return None
@@ -157,7 +156,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore
answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event
if answer.answer_variant() is None:
return None
@@ -451,18 +450,21 @@ class SecurityService(SecurityServicer):
'security_request': pair,
}
with contextlib.closing(EventWatcher()) as watcher:
# register event handlers
for event, listener in listeners.items():
watcher.on(connection, event, listener)
# register event handlers
for event, listener in listeners.items():
connection.on(event, listener)
# security level already reached
if self.reached_security_level(connection, level):
return WaitSecurityResponse(success=empty_pb2.Empty())
# security level already reached
if self.reached_security_level(connection, level):
return WaitSecurityResponse(success=empty_pb2.Empty())
self.log.debug('Wait for security...')
kwargs = {}
kwargs[await wait_for_security] = empty_pb2.Empty()
self.log.debug('Wait for security...')
kwargs = {}
kwargs[await wait_for_security] = empty_pb2.Empty()
# remove event handlers
for event, listener in listeners.items():
connection.remove_listener(event, listener) # type: ignore
# wait for `authenticate` to finish if any
if authenticate_task is not None:

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import functools
import grpc

View File

@@ -18,9 +18,7 @@
# -----------------------------------------------------------------------------
import struct
import logging
from typing import List, Optional
from bumble import l2cap
from typing import List
from ..core import AdvertisingData
from ..device import Device, Connection
from ..gatt import (
@@ -34,6 +32,7 @@ from ..gatt import (
Characteristic,
CharacteristicValue,
)
from ..l2cap import Channel
from ..utils import AsyncRunner
# -----------------------------------------------------------------------------
@@ -54,46 +53,48 @@ class AshaService(TemplateService):
SUPPORTED_CODEC_ID = [0x02, 0x01] # Codec IDs [G.722 at 16 kHz]
RENDER_DELAY = [00, 00]
def __init__(self, capability: int, hisyncid: List[int], device: Device, psm=0):
def __init__(
self, capability: int, hisyncid: List[int], device: Device, psm: int = 0
) -> None:
self.hisyncid = hisyncid
self.capability = capability # Device Capabilities [Left, Monaural]
self.device = device
self.audio_out_data = b''
self.psm = psm # a non-zero psm is mainly for testing purpose
self.audio_out_data = b""
self.psm: int = psm # a non-zero psm is mainly for testing purpose
# Handler for volume control
def on_volume_write(connection, value):
logger.info(f'--- VOLUME Write:{value[0]}')
self.emit('volume', connection, value[0])
def on_volume_write(connection: Connection, value: bytes) -> None:
logger.info(f"--- VOLUME Write:{value[0]}")
self.emit("volume", connection, value[0])
# Handler for audio control commands
def on_audio_control_point_write(connection: Optional[Connection], value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
def on_audio_control_point_write(connection: Connection, value: bytes) -> None:
logger.info(f"--- AUDIO CONTROL POINT Write:{value.hex()}")
opcode = value[0]
if opcode == AshaService.OPCODE_START:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
audio_type = ("Unknown", "Ringtone", "Phone Call", "Media")[value[2]]
logger.info(
f'### START: codec={value[1]}, '
f'audio_type={audio_type}, '
f'volume={value[3]}, '
f'otherstate={value[4]}'
f"### START: codec={value[1]}, "
f"audio_type={audio_type}, "
f"volume={value[3]}, "
f"otherstate={value[4]}"
)
self.emit(
'start',
"start",
connection,
{
'codec': value[1],
'audiotype': value[2],
'volume': value[3],
'otherstate': value[4],
"codec": value[1],
"audiotype": value[2],
"volume": value[3],
"otherstate": value[4],
},
)
elif opcode == AshaService.OPCODE_STOP:
logger.info('### STOP')
self.emit('stop', connection)
logger.info("### STOP")
self.emit("stop", connection)
elif opcode == AshaService.OPCODE_STATUS:
logger.info(f'### STATUS: connected={value[1]}')
logger.info(f"### STATUS: connected={value[1]}")
# OPCODE_STATUS does not need audio status point update
if opcode != AshaService.OPCODE_STATUS:
@@ -103,63 +104,70 @@ class AshaService(TemplateService):
)
)
def on_read_only_properties_read(connection: Connection) -> bytes:
value = (
bytes(
[
AshaService.PROTOCOL_VERSION, # Version
self.capability,
]
)
+ bytes(self.hisyncid)
+ bytes(AshaService.FEATURE_MAP)
+ bytes(AshaService.RENDER_DELAY)
+ bytes(AshaService.RESERVED_FOR_FUTURE_USE)
+ bytes(AshaService.SUPPORTED_CODEC_ID)
)
self.emit("read_only_properties", connection, value)
return value
def on_le_psm_out_read(connection: Connection) -> bytes:
self.emit("le_psm_out", connection, self.psm)
return struct.pack("<H", self.psm)
self.read_only_properties_characteristic = Characteristic(
GATT_ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READ,
Characteristic.READABLE,
bytes(
[
AshaService.PROTOCOL_VERSION, # Version
self.capability,
]
)
+ bytes(self.hisyncid)
+ bytes(AshaService.FEATURE_MAP)
+ bytes(AshaService.RENDER_DELAY)
+ bytes(AshaService.RESERVED_FOR_FUTURE_USE)
+ bytes(AshaService.SUPPORTED_CODEC_ID),
CharacteristicValue(read=on_read_only_properties_read),
)
self.audio_control_point_characteristic = Characteristic(
GATT_ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.Properties.WRITE
| Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write),
)
self.audio_status_characteristic = Characteristic(
GATT_ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.Properties.READ | Characteristic.Properties.NOTIFY,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([0]),
)
self.volume_characteristic = Characteristic(
GATT_ASHA_VOLUME_CHARACTERISTIC,
Characteristic.Properties.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write),
)
# Register an L2CAP CoC server
def on_coc(channel):
def on_data(data):
logging.debug(f'<<< data received:{data}')
def on_coc(channel: Channel) -> None:
def on_data(data: bytes) -> None:
logging.debug(f"data received:{data.hex()}")
self.emit('data', channel.connection, data)
self.emit("data", channel.connection, data)
self.audio_out_data += data
channel.sink = on_data
# let the server find a free PSM
self.psm = device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(psm=self.psm, max_credits=8),
handler=on_coc,
).psm
self.psm = self.device.register_l2cap_channel_server(self.psm, on_coc, 8)
self.le_psm_out_characteristic = Characteristic(
GATT_ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.Properties.READ,
Characteristic.READ,
Characteristic.READABLE,
struct.pack('<H', self.psm),
CharacteristicValue(read=on_le_psm_out_read),
)
characteristics = [
@@ -172,7 +180,7 @@ class AshaService(TemplateService):
super().__init__(characteristics)
def get_advertising_data(self):
def get_advertising_data(self) -> bytes:
# Advertisement only uses 4 least significant bytes of the HiSyncId.
return bytes(
AdvertisingData(

File diff suppressed because it is too large Load Diff

View File

@@ -1,52 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from bumble import gatt
from bumble import gatt_client
from bumble.profiles import csip
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class CommonAudioServiceService(gatt.TemplateService):
UUID = gatt.GATT_COMMON_AUDIO_SERVICE
def __init__(
self,
coordinated_set_identification_service: csip.CoordinatedSetIdentificationService,
) -> None:
self.coordinated_set_identification_service = (
coordinated_set_identification_service
)
super().__init__(
characteristics=[],
included_services=[coordinated_set_identification_service],
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class CommonAudioServiceServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CommonAudioServiceService
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy

View File

@@ -1,257 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import struct
from typing import Optional, Tuple
from bumble import core
from bumble import crypto
from bumble import device
from bumble import gatt
from bumble import gatt_client
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
SET_IDENTITY_RESOLVING_KEY_LENGTH = 16
class SirkType(enum.IntEnum):
'''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''
ENCRYPTED = 0x00
PLAINTEXT = 0x01
class MemberLock(enum.IntEnum):
'''Coordinated Set Identification Service - 5.3 Set Member Lock.'''
UNLOCKED = 0x01
LOCKED = 0x02
# -----------------------------------------------------------------------------
# Crypto Toolbox
# -----------------------------------------------------------------------------
def s1(m: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.3 s1 SALT generation function.
'''
return crypto.aes_cmac(m[::-1], bytes(16))[::-1]
def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.4 k1 derivation function.
'''
t = crypto.aes_cmac(n[::-1], salt[::-1])
return crypto.aes_cmac(p[::-1], t)[::-1]
def sef(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.5 SIRK encryption function sef.
SIRK decryption function sdf shares the same algorithm. The only difference is that argument r is:
* Plaintext in encryption
* Cipher in decryption
'''
return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)
def sih(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih.
'''
return crypto.e(k, r + bytes(13))[:3]
def generate_rsi(sirk: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation.
'''
prand = crypto.generate_prand()
return sih(sirk, prand) + prand
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationService(gatt.TemplateService):
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
set_identity_resolving_key: bytes
set_identity_resolving_key_characteristic: gatt.Characteristic
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
set_member_lock_characteristic: Optional[gatt.Characteristic] = None
set_member_rank_characteristic: Optional[gatt.Characteristic] = None
def __init__(
self,
set_identity_resolving_key: bytes,
set_identity_resolving_key_type: SirkType,
coordinated_set_size: Optional[int] = None,
set_member_lock: Optional[MemberLock] = None,
set_member_rank: Optional[int] = None,
) -> None:
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
raise ValueError(
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
)
characteristics = []
self.set_identity_resolving_key = set_identity_resolving_key
self.set_identity_resolving_key_type = set_identity_resolving_key_type
self.set_identity_resolving_key_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(read=self.on_sirk_read),
)
characteristics.append(self.set_identity_resolving_key_characteristic)
if coordinated_set_size is not None:
self.coordinated_set_size_characteristic = gatt.Characteristic(
uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=struct.pack('B', coordinated_set_size),
)
characteristics.append(self.coordinated_set_size_characteristic)
if set_member_lock is not None:
self.set_member_lock_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
| gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITEABLE,
value=struct.pack('B', set_member_lock),
)
characteristics.append(self.set_member_lock_characteristic)
if set_member_rank is not None:
self.set_member_rank_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=struct.pack('B', set_member_rank),
)
characteristics.append(self.set_member_rank_characteristic)
super().__init__(characteristics)
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
sirk_bytes = self.set_identity_resolving_key
else:
assert connection
if connection.transport == core.BT_LE_TRANSPORT:
key = await connection.device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await connection.device.get_link_key(connection.peer_address)
if not key:
raise RuntimeError('LTK or LinkKey is not present')
sirk_bytes = sef(key, self.set_identity_resolving_key)
return bytes([self.set_identity_resolving_key_type]) + sirk_bytes
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER,
generate_rsi(self.set_identity_resolving_key),
),
]
)
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CoordinatedSetIdentificationService
set_identity_resolving_key: gatt_client.CharacteristicProxy
coordinated_set_size: Optional[gatt_client.CharacteristicProxy] = None
set_member_lock: Optional[gatt_client.CharacteristicProxy] = None
set_member_rank: Optional[gatt_client.CharacteristicProxy] = None
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.set_identity_resolving_key = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC
)[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC
):
self.coordinated_set_size = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_MEMBER_LOCK_CHARACTERISTIC
):
self.set_member_lock = characteristics[0]
if characteristics := service_proxy.get_characteristics_by_uuid(
gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
):
self.set_member_rank = characteristics[0]
async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
'''Reads SIRK and decrypts if encrypted.'''
response = await self.set_identity_resolving_key.read_value()
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
raise RuntimeError('Invalid SIRK value')
sirk_type = SirkType(response[0])
if sirk_type == SirkType.PLAINTEXT:
sirk = response[1:]
else:
connection = self.service_proxy.client.connection
device = connection.device
if connection.transport == core.BT_LE_TRANSPORT:
key = await device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await device.get_link_key(connection.peer_address)
if not key:
raise RuntimeError('LTK or LinkKey is not present')
sirk = sef(key, response[1:])
return (sirk_type, sirk)

View File

@@ -42,12 +42,12 @@ class HeartRateService(TemplateService):
RESET_ENERGY_EXPENDED = 0x01
class BodySensorLocation(IntEnum):
OTHER = 0
CHEST = 1
WRIST = 2
FINGER = 3
HAND = 4
EAR_LOBE = 5
OTHER = (0,)
CHEST = (1,)
WRIST = (2,)
FINGER = (3,)
HAND = (4,)
EAR_LOBE = (5,)
FOOT = 6
class HeartRateMeasurement:

View File

@@ -1,228 +0,0 @@
# Copyright 2021-2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
from bumble import att
from bumble import device
from bumble import gatt
from bumble import gatt_client
from typing import Optional
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
MIN_VOLUME = 0
MAX_VOLUME = 255
class ErrorCode(enum.IntEnum):
'''
See Volume Control Service 1.6. Application error codes.
'''
INVALID_CHANGE_COUNTER = 0x80
OPCODE_NOT_SUPPORTED = 0x81
class VolumeFlags(enum.IntFlag):
'''
See Volume Control Service 3.3. Volume Flags.
'''
VOLUME_SETTING_PERSISTED = 0x01
# RFU
class VolumeControlPointOpcode(enum.IntEnum):
'''
See Volume Control Service Table 3.3: Volume Control Point procedure requirements.
'''
# fmt: off
RELATIVE_VOLUME_DOWN = 0x00
RELATIVE_VOLUME_UP = 0x01
UNMUTE_RELATIVE_VOLUME_DOWN = 0x02
UNMUTE_RELATIVE_VOLUME_UP = 0x03
SET_ABSOLUTE_VOLUME = 0x04
UNMUTE = 0x05
MUTE = 0x06
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class VolumeControlService(gatt.TemplateService):
UUID = gatt.GATT_VOLUME_CONTROL_SERVICE
volume_state: gatt.Characteristic
volume_control_point: gatt.Characteristic
volume_flags: gatt.Characteristic
volume_setting: int
muted: int
change_counter: int
def __init__(
self,
step_size: int = 16,
volume_setting: int = 0,
muted: int = 0,
change_counter: int = 0,
volume_flags: int = 0,
) -> None:
self.step_size = step_size
self.volume_setting = volume_setting
self.muted = muted
self.change_counter = change_counter
self.volume_state = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_STATE_CHARACTERISTIC,
properties=(
gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
),
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(read=self._on_read_volume_state),
)
self.volume_control_point = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.WRITE_REQUIRES_ENCRYPTION,
value=gatt.CharacteristicValue(write=self._on_write_volume_control_point),
)
self.volume_flags = gatt.Characteristic(
uuid=gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=bytes([volume_flags]),
)
super().__init__(
[
self.volume_state,
self.volume_control_point,
self.volume_flags,
]
)
@property
def volume_state_bytes(self) -> bytes:
return bytes([self.volume_setting, self.muted, self.change_counter])
@volume_state_bytes.setter
def volume_state_bytes(self, new_value: bytes) -> None:
self.volume_setting, self.muted, self.change_counter = new_value
def _on_read_volume_state(self, _connection: Optional[device.Connection]) -> bytes:
return self.volume_state_bytes
def _on_write_volume_control_point(
self, connection: Optional[device.Connection], value: bytes
) -> None:
assert connection
opcode = VolumeControlPointOpcode(value[0])
change_counter = value[1]
if change_counter != self.change_counter:
raise att.ATT_Error(ErrorCode.INVALID_CHANGE_COUNTER)
handler = getattr(self, '_on_' + opcode.name.lower())
if handler(*value[2:]):
self.change_counter = (self.change_counter + 1) % 256
connection.abort_on(
'disconnection',
connection.device.notify_subscribers(
attribute=self.volume_state,
value=self.volume_state_bytes,
),
)
self.emit(
'volume_state', self.volume_setting, self.muted, self.change_counter
)
def _on_relative_volume_down(self) -> bool:
old_volume = self.volume_setting
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
return self.volume_setting != old_volume
def _on_relative_volume_up(self) -> bool:
old_volume = self.volume_setting
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
return self.volume_setting != old_volume
def _on_unmute_relative_volume_down(self) -> bool:
old_volume, old_muted_state = self.volume_setting, self.muted
self.volume_setting = max(self.volume_setting - self.step_size, MIN_VOLUME)
self.muted = 0
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
def _on_unmute_relative_volume_up(self) -> bool:
old_volume, old_muted_state = self.volume_setting, self.muted
self.volume_setting = min(self.volume_setting + self.step_size, MAX_VOLUME)
self.muted = 0
return (self.volume_setting, self.muted) != (old_volume, old_muted_state)
def _on_set_absolute_volume(self, volume_setting: int) -> bool:
old_volume_setting = self.volume_setting
self.volume_setting = volume_setting
return old_volume_setting != self.volume_setting
def _on_unmute(self) -> bool:
old_muted_state = self.muted
self.muted = 0
return self.muted != old_muted_state
def _on_mute(self) -> bool:
old_muted_state = self.muted
self.muted = 1
return self.muted != old_muted_state
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class VolumeControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = VolumeControlService
volume_control_point: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
self.volume_state = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_STATE_CHARACTERISTIC
)[0],
'BBB',
)
self.volume_control_point = service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_CONTROL_POINT_CHARACTERISTIC
)[0]
self.volume_flags = gatt.PackedCharacteristicAdapter(
service_proxy.get_characteristics_by_uuid(
gatt.GATT_VOLUME_FLAGS_CHARACTERISTIC
)[0],
'B',
)

View File

@@ -19,16 +19,12 @@ from __future__ import annotations
import logging
import asyncio
import dataclasses
import enum
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
from typing_extensions import Self
from pyee import EventEmitter
from bumble import core
from bumble import l2cap
from bumble import sdp
from . import core, l2cap
from .colors import color
from .core import (
UUID,
@@ -38,6 +34,15 @@ from .core import (
InvalidStateError,
ProtocolError,
)
from .sdp import (
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_PUBLIC_BROWSE_ROOT,
DataElement,
ServiceAttribute,
)
if TYPE_CHECKING:
from bumble.device import Device, Connection
@@ -55,18 +60,27 @@ logger = logging.getLogger(__name__)
RFCOMM_PSM = 0x0003
class FrameType(enum.IntEnum):
SABM = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
UA = 0x63 # Control field [0,1,1,0,_,0,1,1] LSB-first
DM = 0x0F # Control field [1,1,1,1,_,0,0,0] LSB-first
DISC = 0x43 # Control field [0,1,0,_,0,0,1,1] LSB-first
UIH = 0xEF # Control field [1,1,1,_,1,1,1,1] LSB-first
UI = 0x03 # Control field [0,0,0,_,0,0,1,1] LSB-first
class MccType(enum.IntEnum):
PN = 0x20
MSC = 0x38
# Frame types
RFCOMM_SABM_FRAME = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
RFCOMM_UA_FRAME = 0x63 # Control field [0,1,1,0,_,0,1,1] LSB-first
RFCOMM_DM_FRAME = 0x0F # Control field [1,1,1,1,_,0,0,0] LSB-first
RFCOMM_DISC_FRAME = 0x43 # Control field [0,1,0,_,0,0,1,1] LSB-first
RFCOMM_UIH_FRAME = 0xEF # Control field [1,1,1,_,1,1,1,1] LSB-first
RFCOMM_UI_FRAME = 0x03 # Control field [0,0,0,_,0,0,1,1] LSB-first
RFCOMM_FRAME_TYPE_NAMES = {
RFCOMM_SABM_FRAME: 'SABM',
RFCOMM_UA_FRAME: 'UA',
RFCOMM_DM_FRAME: 'DM',
RFCOMM_DISC_FRAME: 'DISC',
RFCOMM_UIH_FRAME: 'UIH',
RFCOMM_UI_FRAME: 'UI'
}
# MCC Types
RFCOMM_MCC_PN_TYPE = 0x20
RFCOMM_MCC_MSC_TYPE = 0x38
# FCS CRC
CRC_TABLE = bytes([
@@ -104,9 +118,8 @@ CRC_TABLE = bytes([
0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
])
RFCOMM_DEFAULT_L2CAP_MTU = 2048
RFCOMM_DEFAULT_WINDOW_SIZE = 7
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
RFCOMM_DEFAULT_INITIAL_RX_CREDITS = 7
RFCOMM_DEFAULT_PREFERRED_MTU = 1280
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
@@ -117,33 +130,29 @@ RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
# -----------------------------------------------------------------------------
def make_service_sdp_records(
service_record_handle: int, channel: int, uuid: Optional[UUID] = None
) -> List[sdp.ServiceAttribute]:
) -> List[ServiceAttribute]:
"""
Create SDP records for an RFComm service given a channel number and an
optional UUID. A Service Class Attribute is included only if the UUID is not None.
"""
records = [
sdp.ServiceAttribute(
sdp.SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
sdp.DataElement.unsigned_integer_32(service_record_handle),
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
),
sdp.ServiceAttribute(
sdp.SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence(
[sdp.DataElement.uuid(sdp.SDP_PUBLIC_BROWSE_ROOT)]
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
sdp.ServiceAttribute(
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence(
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
sdp.DataElement.sequence(
[sdp.DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]
),
sdp.DataElement.sequence(
DataElement.sequence([DataElement.uuid(BT_L2CAP_PROTOCOL_ID)]),
DataElement.sequence(
[
sdp.DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
sdp.DataElement.unsigned_integer_8(channel),
DataElement.uuid(BT_RFCOMM_PROTOCOL_ID),
DataElement.unsigned_integer_8(channel),
]
),
]
@@ -153,81 +162,15 @@ def make_service_sdp_records(
if uuid:
records.append(
sdp.ServiceAttribute(
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
sdp.DataElement.sequence([sdp.DataElement.uuid(uuid)]),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(uuid)]),
)
)
return records
# -----------------------------------------------------------------------------
async def find_rfcomm_channels(connection: Connection) -> Dict[int, List[UUID]]:
"""Searches all RFCOMM channels and their associated UUID from SDP service records.
Args:
connection: ACL connection to make SDP search.
Returns:
Dictionary mapping from channel number to service class UUID list.
"""
results = {}
async with sdp.Client(connection) as sdp_client:
search_result = await sdp_client.search_attributes(
uuids=[core.BT_RFCOMM_PROTOCOL_ID],
attribute_ids=[
sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
],
)
for attribute_lists in search_result:
service_classes: List[UUID] = []
channel: Optional[int] = None
for attribute in attribute_lists:
# The layout is [[L2CAP_PROTOCOL], [RFCOMM_PROTOCOL, RFCOMM_CHANNEL]].
if attribute.id == sdp.SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID:
protocol_descriptor_list = attribute.value.value
channel = protocol_descriptor_list[1].value[1].value
elif attribute.id == sdp.SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID:
service_class_id_list = attribute.value.value
service_classes = [
service_class.value for service_class in service_class_id_list
]
if not service_classes or not channel:
logger.warning(f"Bad result {attribute_lists}.")
else:
results[channel] = service_classes
return results
# -----------------------------------------------------------------------------
async def find_rfcomm_channel_with_uuid(
connection: Connection, uuid: str | UUID
) -> Optional[int]:
"""Searches an RFCOMM channel associated with given UUID from service records.
Args:
connection: ACL connection to make SDP search.
uuid: UUID of service record to search for.
Returns:
RFCOMM channel number if found, otherwise None.
"""
if isinstance(uuid, str):
uuid = UUID(uuid)
return next(
(
channel
for channel, class_id_list in (
await find_rfcomm_channels(connection)
).items()
if uuid in class_id_list
),
None,
)
# -----------------------------------------------------------------------------
def compute_fcs(buffer: bytes) -> int:
result = 0xFF
@@ -240,7 +183,7 @@ def compute_fcs(buffer: bytes) -> int:
class RFCOMM_Frame:
def __init__(
self,
frame_type: FrameType,
frame_type: int,
c_r: int,
dlci: int,
p_f: int,
@@ -263,11 +206,14 @@ class RFCOMM_Frame:
self.length = bytes([(length << 1) | 1])
self.address = (dlci << 2) | (c_r << 1) | 1
self.control = frame_type | (p_f << 4)
if frame_type == FrameType.UIH:
if frame_type == RFCOMM_UIH_FRAME:
self.fcs = compute_fcs(bytes([self.address, self.control]))
else:
self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)
def type_name(self) -> str:
return RFCOMM_FRAME_TYPE_NAMES[self.type]
@staticmethod
def parse_mcc(data) -> Tuple[int, bool, bytes]:
mcc_type = data[0] >> 2
@@ -291,24 +237,24 @@ class RFCOMM_Frame:
@staticmethod
def sabm(c_r: int, dlci: int):
return RFCOMM_Frame(FrameType.SABM, c_r, dlci, 1)
return RFCOMM_Frame(RFCOMM_SABM_FRAME, c_r, dlci, 1)
@staticmethod
def ua(c_r: int, dlci: int):
return RFCOMM_Frame(FrameType.UA, c_r, dlci, 1)
return RFCOMM_Frame(RFCOMM_UA_FRAME, c_r, dlci, 1)
@staticmethod
def dm(c_r: int, dlci: int):
return RFCOMM_Frame(FrameType.DM, c_r, dlci, 1)
return RFCOMM_Frame(RFCOMM_DM_FRAME, c_r, dlci, 1)
@staticmethod
def disc(c_r: int, dlci: int):
return RFCOMM_Frame(FrameType.DISC, c_r, dlci, 1)
return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1)
@staticmethod
def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0):
return RFCOMM_Frame(
FrameType.UIH, c_r, dlci, p_f, information, with_credits=(p_f == 1)
RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
)
@staticmethod
@@ -316,7 +262,7 @@ class RFCOMM_Frame:
# Extract fields
dlci = (data[0] >> 2) & 0x3F
c_r = (data[0] >> 1) & 0x01
frame_type = FrameType(data[1] & 0xEF)
frame_type = data[1] & 0xEF
p_f = (data[1] >> 4) & 0x01
length = data[2]
if length & 0x01:
@@ -345,7 +291,7 @@ class RFCOMM_Frame:
def __str__(self) -> str:
return (
f'{color(self.type.name, "yellow")}'
f'{color(self.type_name(), "yellow")}'
f'(c/r={self.c_r},'
f'dlci={self.dlci},'
f'p/f={self.p_f},'
@@ -355,7 +301,6 @@ class RFCOMM_Frame:
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_PN:
dlci: int
cl: int
@@ -365,11 +310,23 @@ class RFCOMM_MCC_PN:
max_retransmissions: int
window_size: int
def __post_init__(self) -> None:
if self.window_size < 1 or self.window_size > 7:
logger.warning(
f'Error Recovery Window size {self.window_size} is out of range [1, 7].'
)
def __init__(
self,
dlci: int,
cl: int,
priority: int,
ack_timer: int,
max_frame_size: int,
max_retransmissions: int,
window_size: int,
) -> None:
self.dlci = dlci
self.cl = cl
self.priority = priority
self.ack_timer = ack_timer
self.max_frame_size = max_frame_size
self.max_retransmissions = max_retransmissions
self.window_size = window_size
@staticmethod
def from_bytes(data: bytes) -> RFCOMM_MCC_PN:
@@ -380,7 +337,7 @@ class RFCOMM_MCC_PN:
ack_timer=data[3],
max_frame_size=data[4] | data[5] << 8,
max_retransmissions=data[6],
window_size=data[7] & 0x07,
window_size=data[7],
)
def __bytes__(self) -> bytes:
@@ -393,14 +350,23 @@ class RFCOMM_MCC_PN:
self.max_frame_size & 0xFF,
(self.max_frame_size >> 8) & 0xFF,
self.max_retransmissions & 0xFF,
# Only 3 bits are meaningful.
self.window_size & 0x07,
self.window_size & 0xFF,
]
)
def __str__(self) -> str:
return (
f'PN(dlci={self.dlci},'
f'cl={self.cl},'
f'priority={self.priority},'
f'ack_timer={self.ack_timer},'
f'max_frame_size={self.max_frame_size},'
f'max_retransmissions={self.max_retransmissions},'
f'window_size={self.window_size})'
)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_MSC:
dlci: int
fc: int
@@ -409,6 +375,16 @@ class RFCOMM_MCC_MSC:
ic: int
dv: int
def __init__(
self, dlci: int, fc: int, rtc: int, rtr: int, ic: int, dv: int
) -> None:
self.dlci = dlci
self.fc = fc
self.rtc = rtc
self.rtr = rtr
self.ic = ic
self.dv = dv
@staticmethod
def from_bytes(data: bytes) -> RFCOMM_MCC_MSC:
return RFCOMM_MCC_MSC(
@@ -433,6 +409,16 @@ class RFCOMM_MCC_MSC:
]
)
def __str__(self) -> str:
return (
f'MSC(dlci={self.dlci},'
f'fc={self.fc},'
f'rtc={self.rtc},'
f'rtr={self.rtr},'
f'ic={self.ic},'
f'dv={self.dv})'
)
# -----------------------------------------------------------------------------
class DLC(EventEmitter):
@@ -452,29 +438,25 @@ class DLC(EventEmitter):
multiplexer: Multiplexer,
dlci: int,
max_frame_size: int,
window_size: int,
initial_tx_credits: int,
) -> None:
super().__init__()
self.multiplexer = multiplexer
self.dlci = dlci
self.max_frame_size = max_frame_size
self.window_size = window_size
self.rx_credits = window_size
self.rx_threshold = window_size // 2
self.tx_credits = window_size
self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
self.rx_threshold = self.rx_credits // 2
self.tx_credits = initial_tx_credits
self.tx_buffer = b''
self.state = DLC.State.INIT
self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
self.sink = None
self.connection_result = None
self.drained = asyncio.Event()
self.drained.set()
# Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs
self.mtu = min(
max_frame_size, self.multiplexer.l2cap_channel.peer_mtu - max_overhead
max_frame_size, self.multiplexer.l2cap_channel.mtu - max_overhead
)
def change_state(self, new_state: State) -> None:
@@ -485,7 +467,7 @@ class DLC(EventEmitter):
self.multiplexer.send_frame(frame)
def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame)
def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
@@ -499,7 +481,9 @@ class DLC(EventEmitter):
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
@@ -515,7 +499,9 @@ class DLC(EventEmitter):
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
@@ -548,15 +534,14 @@ class DLC(EventEmitter):
f'[{self.dlci}] {len(data)} bytes, '
f'rx_credits={self.rx_credits}: {data.hex()}'
)
if data:
if self.sink:
self.sink(data) # pylint: disable=not-callable
if len(data) and self.sink:
self.sink(data) # pylint: disable=not-callable
# Update the credits
if self.rx_credits > 0:
self.rx_credits -= 1
else:
logger.warning(color('!!! received frame with no rx credits', 'red'))
# Update the credits
if self.rx_credits > 0:
self.rx_credits -= 1
else:
logger.warning(color('!!! received frame with no rx credits', 'red'))
# Check if there's anything to send (including credits)
self.process_tx()
@@ -569,7 +554,9 @@ class DLC(EventEmitter):
# Command
logger.debug(f'<<< MCC MSC Command: {msc}')
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=0, data=bytes(msc))
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc)
)
logger.debug(f'>>> MCC MSC Response: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
else:
@@ -593,18 +580,18 @@ class DLC(EventEmitter):
cl=0xE0,
priority=7,
ack_timer=0,
max_frame_size=self.max_frame_size,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions=0,
window_size=self.window_size,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.PN, c_r=0, data=bytes(pn))
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
logger.debug(f'>>> PN Response: {pn}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.State.CONNECTING)
def rx_credits_needed(self) -> int:
if self.rx_credits <= self.rx_threshold:
return self.window_size - self.rx_credits
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
return 0
@@ -644,8 +631,6 @@ class DLC(EventEmitter):
)
rx_credits_needed = 0
if not self.tx_buffer:
self.drained.set()
# Stream protocol
def write(self, data: Union[bytes, str]) -> None:
@@ -658,11 +643,11 @@ class DLC(EventEmitter):
raise ValueError('write only accept bytes or strings')
self.tx_buffer += data
self.drained.clear()
self.process_tx()
async def drain(self) -> None:
await self.drained.wait()
def drain(self) -> None:
# TODO
pass
def __str__(self) -> str:
return f'DLC(dlci={self.dlci},state={self.state.name})'
@@ -689,7 +674,7 @@ class Multiplexer(EventEmitter):
acceptor: Optional[Callable[[int], bool]]
dlcs: Dict[int, DLC]
def __init__(self, l2cap_channel: l2cap.ClassicChannel, role: Role) -> None:
def __init__(self, l2cap_channel: l2cap.Channel, role: Role) -> None:
super().__init__()
self.role = role
self.l2cap_channel = l2cap_channel
@@ -719,7 +704,7 @@ class Multiplexer(EventEmitter):
if frame.dlci == 0:
self.on_frame(frame)
else:
if frame.type == FrameType.DM:
if frame.type == RFCOMM_DM_FRAME:
# DM responses are for a DLCI, but since we only create the dlc when we
# receive a PN response (because we need the parameters), we handle DM
# frames at the Multiplexer level
@@ -732,7 +717,7 @@ class Multiplexer(EventEmitter):
dlc.on_frame(frame)
def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame)
def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
@@ -780,10 +765,10 @@ class Multiplexer(EventEmitter):
def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
(mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
if mcc_type == MccType.PN:
if mcc_type == RFCOMM_MCC_PN_TYPE:
pn = RFCOMM_MCC_PN.from_bytes(value)
self.on_mcc_pn(c_r, pn)
elif mcc_type == MccType.MSC:
elif mcc_type == RFCOMM_MCC_MSC_TYPE:
mcs = RFCOMM_MCC_MSC.from_bytes(value)
self.on_mcc_msc(c_r, mcs)
@@ -858,12 +843,7 @@ class Multiplexer(EventEmitter):
)
await self.disconnection_result
async def open_dlc(
self,
channel: int,
max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
window_size: int = RFCOMM_DEFAULT_WINDOW_SIZE,
) -> DLC:
async def open_dlc(self, channel: int) -> DLC:
if self.state != Multiplexer.State.CONNECTED:
if self.state == Multiplexer.State.OPENING:
raise InvalidStateError('open already in progress')
@@ -875,11 +855,11 @@ class Multiplexer(EventEmitter):
cl=0xF0,
priority=7,
ack_timer=0,
max_frame_size=max_frame_size,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions=0,
window_size=window_size,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.PN, c_r=1, data=bytes(pn))
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
logger.debug(f'>>> Sending MCC: {pn}')
self.open_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.State.OPENING)
@@ -907,28 +887,26 @@ class Multiplexer(EventEmitter):
# -----------------------------------------------------------------------------
class Client:
multiplexer: Optional[Multiplexer]
l2cap_channel: Optional[l2cap.ClassicChannel]
l2cap_channel: Optional[l2cap.Channel]
def __init__(
self, connection: Connection, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
) -> None:
def __init__(self, device: Device, connection: Connection) -> None:
self.device = device
self.connection = connection
self.l2cap_mtu = l2cap_mtu
self.l2cap_channel = None
self.multiplexer = None
async def start(self) -> Multiplexer:
# Create a new L2CAP connection
try:
self.l2cap_channel = await self.connection.create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=self.l2cap_mtu)
self.l2cap_channel = await self.device.l2cap_channel_manager.connect(
self.connection, RFCOMM_PSM
)
except ProtocolError as error:
logger.warning(f'L2CAP connection failed: {error}')
raise
assert self.l2cap_channel is not None
# Create a multiplexer to manage DLCs with the server
# Create a mutliplexer to manage DLCs with the server
self.multiplexer = Multiplexer(self.l2cap_channel, Multiplexer.Role.INITIATOR)
# Connect the multiplexer
@@ -944,34 +922,21 @@ class Client:
self.multiplexer = None
# Close the L2CAP channel
if self.l2cap_channel:
await self.l2cap_channel.disconnect()
self.l2cap_channel = None
async def __aenter__(self) -> Multiplexer:
return await self.start()
async def __aexit__(self, *args) -> None:
await self.shutdown()
# TODO
# -----------------------------------------------------------------------------
class Server(EventEmitter):
acceptors: Dict[int, Callable[[DLC], None]]
def __init__(
self, device: Device, l2cap_mtu: int = RFCOMM_DEFAULT_L2CAP_MTU
) -> None:
def __init__(self, device: Device) -> None:
super().__init__()
self.device = device
self.multiplexer = None
self.acceptors = {}
# Register ourselves with the L2CAP channel manager
self.l2cap_server = device.create_l2cap_server(
spec=l2cap.ClassicChannelSpec(psm=RFCOMM_PSM, mtu=l2cap_mtu),
handler=self.on_connection,
)
device.register_l2cap_server(RFCOMM_PSM, self.on_connection)
def listen(self, acceptor: Callable[[DLC], None], channel: int = 0) -> int:
if channel:
@@ -995,11 +960,11 @@ class Server(EventEmitter):
self.acceptors[channel] = acceptor
return channel
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
def on_connection(self, l2cap_channel: l2cap.Channel) -> None:
logger.debug(f'+++ new L2CAP connection: {l2cap_channel}')
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
def on_l2cap_channel_open(self, l2cap_channel: l2cap.Channel) -> None:
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
# Create a new multiplexer for the channel
@@ -1020,9 +985,3 @@ class Server(EventEmitter):
acceptor = self.acceptors.get(dlc.dlci >> 1)
if acceptor:
acceptor(dlc)
def __enter__(self) -> Self:
return self
def __exit__(self, *args) -> None:
self.l2cap_server.close()

View File

@@ -19,7 +19,6 @@ from __future__ import annotations
import logging
import struct
from typing import Dict, List, Type, Optional, Tuple, Union, NewType, TYPE_CHECKING
from typing_extensions import Self
from . import core, l2cap
from .colors import color
@@ -98,8 +97,7 @@ SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID = 0X000B
SDP_ICON_URL_ATTRIBUTE_ID = 0X000C
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D
# Profile-specific Attribute Identifiers (cf. Assigned Numbers for Service Discovery)
# Attribute Identifier (cf. Assigned Numbers for Service Discovery)
# used by AVRCP, HFP and A2DP
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311
@@ -117,8 +115,7 @@ SDP_ATTRIBUTE_ID_NAMES = {
SDP_DOCUMENTATION_URL_ATTRIBUTE_ID: 'SDP_DOCUMENTATION_URL_ATTRIBUTE_ID',
SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID: 'SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID',
SDP_ICON_URL_ATTRIBUTE_ID: 'SDP_ICON_URL_ATTRIBUTE_ID',
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID',
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID: 'SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID',
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID'
}
SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
@@ -170,7 +167,7 @@ class DataElement:
UUID: lambda x: DataElement(
DataElement.UUID, core.UUID.from_bytes(bytes(reversed(x)))
),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x),
TEXT_STRING: lambda x: DataElement(DataElement.TEXT_STRING, x.decode('utf8')),
BOOLEAN: lambda x: DataElement(DataElement.BOOLEAN, x[0] == 1),
SEQUENCE: lambda x: DataElement(
DataElement.SEQUENCE, DataElement.list_from_bytes(x)
@@ -232,7 +229,7 @@ class DataElement:
return DataElement(DataElement.UUID, value)
@staticmethod
def text_string(value: bytes) -> DataElement:
def text_string(value: str) -> DataElement:
return DataElement(DataElement.TEXT_STRING, value)
@staticmethod
@@ -379,7 +376,7 @@ class DataElement:
raise ValueError('invalid value_size')
elif self.type == DataElement.UUID:
data = bytes(reversed(bytes(self.value)))
elif self.type == DataElement.URL:
elif self.type in (DataElement.TEXT_STRING, DataElement.URL):
data = self.value.encode('utf8')
elif self.type == DataElement.BOOLEAN:
data = bytes([1 if self.value else 0])
@@ -761,17 +758,16 @@ class SDP_ServiceSearchAttributeResponse(SDP_PDU):
# -----------------------------------------------------------------------------
class Client:
channel: Optional[l2cap.ClassicChannel]
channel: Optional[l2cap.Channel]
def __init__(self, connection: Connection) -> None:
self.connection = connection
def __init__(self, device: Device) -> None:
self.device = device
self.pending_request = None
self.channel = None
async def connect(self) -> None:
self.channel = await self.connection.create_l2cap_channel(
spec=l2cap.ClassicChannelSpec(SDP_PSM)
)
async def connect(self, connection: Connection) -> None:
result = await self.device.l2cap_channel_manager.connect(connection, SDP_PSM)
self.channel = result
async def disconnect(self) -> None:
if self.channel:
@@ -921,18 +917,11 @@ class Client:
return ServiceAttribute.list_from_data_elements(attribute_list_sequence.value)
async def __aenter__(self) -> Self:
await self.connect()
return self
async def __aexit__(self, *args) -> None:
await self.disconnect()
# -----------------------------------------------------------------------------
class Server:
CONTINUATION_STATE = bytes([0x01, 0x43])
channel: Optional[l2cap.ClassicChannel]
channel: Optional[l2cap.Channel]
Service = NewType('Service', List[ServiceAttribute])
service_records: Dict[int, Service]
current_response: Union[None, bytes, Tuple[int, List[int]]]
@@ -944,9 +933,7 @@ class Server:
self.current_response = None
def register(self, l2cap_channel_manager: l2cap.ChannelManager) -> None:
l2cap_channel_manager.create_classic_server(
spec=l2cap.ClassicChannelSpec(psm=SDP_PSM), handler=self.on_connection
)
l2cap_channel_manager.register_server(SDP_PSM, self.on_connection)
def send_response(self, response):
logger.debug(f'{color(">>> Sending SDP Response", "blue")}: {response}')

View File

@@ -27,7 +27,6 @@ import logging
import asyncio
import enum
import secrets
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
@@ -54,7 +53,6 @@ from .core import (
BT_BR_EDR_TRANSPORT,
BT_CENTRAL_ROLE,
BT_LE_TRANSPORT,
AdvertisingData,
ProtocolError,
name_or_number,
)
@@ -187,8 +185,8 @@ SMP_KEYPRESS_AUTHREQ = 0b00010000
SMP_CT2_AUTHREQ = 0b00100000
# Crypto salt
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('000000000000000000000000746D7031')
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('000000000000000000000000746D7032')
SMP_CTKD_H7_LEBR_SALT = bytes.fromhex('00000000000000000000000000000000746D7031')
SMP_CTKD_H7_BRLE_SALT = bytes.fromhex('00000000000000000000000000000000746D7032')
# fmt: on
# pylint: enable=line-too-long
@@ -565,54 +563,6 @@ class PairingMethod(enum.IntEnum):
CTKD_OVER_CLASSIC = 4
# -----------------------------------------------------------------------------
class OobContext:
"""Cryptographic context for LE SC OOB pairing."""
ecc_key: crypto.EccKey
r: bytes
def __init__(
self, ecc_key: Optional[crypto.EccKey] = None, r: Optional[bytes] = None
) -> None:
self.ecc_key = crypto.EccKey.generate() if ecc_key is None else ecc_key
self.r = crypto.r() if r is None else r
def share(self) -> OobSharedData:
pkx = self.ecc_key.x[::-1]
return OobSharedData(c=crypto.f4(pkx, pkx, self.r, bytes(1)), r=self.r)
# -----------------------------------------------------------------------------
class OobLegacyContext:
"""Cryptographic context for LE Legacy OOB pairing."""
tk: bytes
def __init__(self, tk: Optional[bytes] = None) -> None:
self.tk = crypto.r() if tk is None else tk
# -----------------------------------------------------------------------------
@dataclass
class OobSharedData:
"""Shareable data for LE SC OOB pairing."""
c: bytes
r: bytes
def to_ad(self) -> AdvertisingData:
return AdvertisingData(
[
(AdvertisingData.LE_SECURE_CONNECTIONS_CONFIRMATION_VALUE, self.c),
(AdvertisingData.LE_SECURE_CONNECTIONS_RANDOM_VALUE, self.r),
]
)
def __str__(self) -> str:
return f'OOB(C={self.c.hex()}, R={self.r.hex()})'
# -----------------------------------------------------------------------------
class Session:
# I/O Capability to pairing method decision matrix
@@ -677,13 +627,6 @@ class Session:
},
}
ea: bytes
eb: bytes
ltk: bytes
preq: bytes
pres: bytes
tk: bytes
def __init__(
self,
manager: Manager,
@@ -693,10 +636,17 @@ class Session:
) -> None:
self.manager = manager
self.connection = connection
self.preq: Optional[bytes] = None
self.pres: Optional[bytes] = None
self.ea = None
self.eb = None
self.tk = bytes(16)
self.r = bytes(16)
self.stk = None
self.ltk = None
self.ltk_ediv = 0
self.ltk_rand = bytes(8)
self.link_key: Optional[bytes] = None
self.link_key = None
self.initiator_key_distribution: int = 0
self.responder_key_distribution: int = 0
self.peer_random_value: Optional[bytes] = None
@@ -709,7 +659,7 @@ class Session:
self.peer_bd_addr: Optional[Address] = None
self.peer_signature_key = None
self.peer_expected_distributions: List[Type[SMP_Command]] = []
self.dh_key = b''
self.dh_key = None
self.confirm_value = None
self.passkey: Optional[int] = None
self.passkey_ready = asyncio.Event()
@@ -762,8 +712,8 @@ class Session:
self.io_capability = pairing_config.delegate.io_capability
self.peer_io_capability = SMP_NO_INPUT_NO_OUTPUT_IO_CAPABILITY
# OOB
self.oob_data_flag = 0 if pairing_config.oob is None else 1
# OOB (not supported yet)
self.oob = False
# Set up addresses
self_address = connection.self_address
@@ -779,35 +729,9 @@ class Session:
self.ia = bytes(peer_address)
self.iat = 1 if peer_address.is_random else 0
# Select the ECC key, TK and r initial value
if pairing_config.oob:
self.peer_oob_data = pairing_config.oob.peer_data
if pairing_config.sc:
if pairing_config.oob.our_context is None:
raise ValueError(
"oob pairing config requires a context when sc is True"
)
self.r = pairing_config.oob.our_context.r
self.ecc_key = pairing_config.oob.our_context.ecc_key
if pairing_config.oob.legacy_context is not None:
self.tk = pairing_config.oob.legacy_context.tk
else:
if pairing_config.oob.legacy_context is None:
raise ValueError(
"oob pairing config requires a legacy context when sc is False"
)
self.r = bytes(16)
self.ecc_key = manager.ecc_key
self.tk = pairing_config.oob.legacy_context.tk
else:
self.peer_oob_data = None
self.r = bytes(16)
self.ecc_key = manager.ecc_key
self.tk = bytes(16)
@property
def pkx(self) -> Tuple[bytes, bytes]:
return (self.ecc_key.x[::-1], self.peer_public_key_x)
return (bytes(reversed(self.manager.ecc_key.x)), self.peer_public_key_x)
@property
def pka(self) -> bytes:
@@ -844,10 +768,7 @@ class Session:
return None
def decide_pairing_method(
self,
auth_req: int,
initiator_io_capability: int,
responder_io_capability: int,
self, auth_req: int, initiator_io_capability: int, responder_io_capability: int
) -> None:
if self.connection.transport == BT_BR_EDR_TRANSPORT:
self.pairing_method = PairingMethod.CTKD_OVER_CLASSIC
@@ -988,7 +909,7 @@ class Session:
command = SMP_Pairing_Request_Command(
io_capability=self.io_capability,
oob_data_flag=self.oob_data_flag,
oob_data_flag=0,
auth_req=self.auth_req,
maximum_encryption_key_size=16,
initiator_key_distribution=self.initiator_key_distribution,
@@ -1000,7 +921,7 @@ class Session:
def send_pairing_response_command(self) -> None:
response = SMP_Pairing_Response_Command(
io_capability=self.io_capability,
oob_data_flag=self.oob_data_flag,
oob_data_flag=0,
auth_req=self.auth_req,
maximum_encryption_key_size=16,
initiator_key_distribution=self.initiator_key_distribution,
@@ -1061,8 +982,8 @@ class Session:
def send_public_key_command(self) -> None:
self.send_command(
SMP_Pairing_Public_Key_Command(
public_key_x=self.ecc_key.x[::-1],
public_key_y=self.ecc_key.y[::-1],
public_key_x=bytes(reversed(self.manager.ecc_key.x)),
public_key_y=bytes(reversed(self.manager.ecc_key.y)),
)
)
@@ -1090,7 +1011,7 @@ class Session:
# We can now encrypt the connection with the short term key, so that we can
# distribute the long term and/or other keys over an encrypted connection
self.manager.device.host.send_command_sync(
HCI_LE_Enable_Encryption_Command(
HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg]
connection_handle=self.connection.handle,
random_number=bytes(8),
encrypted_diversifier=0,
@@ -1098,56 +1019,18 @@ class Session:
)
)
@classmethod
def derive_ltk(cls, link_key: bytes, ct2: bool) -> bytes:
'''Derives Long Term Key from Link Key.
Args:
link_key: BR/EDR Link Key bytes in little-endian.
ct2: whether ct2 is supported on both devices.
Returns:
LE Long Tern Key bytes in little-endian.
'''
async def derive_ltk(self) -> None:
link_key = await self.manager.device.get_link_key(self.connection.peer_address)
assert link_key is not None
ilk = (
crypto.h7(salt=SMP_CTKD_H7_BRLE_SALT, w=link_key)
if ct2
if self.ct2
else crypto.h6(link_key, b'tmp2')
)
return crypto.h6(ilk, b'brle')
@classmethod
def derive_link_key(cls, ltk: bytes, ct2: bool) -> bytes:
'''Derives Link Key from Long Term Key.
Args:
ltk: LE Long Term Key bytes in little-endian.
ct2: whether ct2 is supported on both devices.
Returns:
BR/EDR Link Key bytes in little-endian.
'''
ilk = (
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=ltk)
if ct2
else crypto.h6(ltk, b'tmp1')
)
return crypto.h6(ilk, b'lebr')
async def get_link_key_and_derive_ltk(self) -> None:
'''Retrieves BR/EDR Link Key from storage and derive it to LE LTK.'''
self.link_key = await self.manager.device.get_link_key(
self.connection.peer_address
)
if self.link_key is None:
logging.warning(
'Try to derive LTK but host does not have the LK. Send a SMP_PAIRING_FAILED but the procedure will not be paused!'
)
self.send_pairing_failed(
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR
)
else:
self.ltk = self.derive_ltk(self.link_key, self.ct2)
self.ltk = crypto.h6(ilk, b'brle')
def distribute_keys(self) -> None:
# Distribute the keys as required
if self.is_initiator:
# CTKD: Derive LTK from LinkKey
@@ -1156,7 +1039,7 @@ class Session:
and self.initiator_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
):
self.ctkd_task = self.connection.abort_on(
'disconnection', self.get_link_key_and_derive_ltk()
'disconnection', self.derive_ltk()
)
elif not self.sc:
# Distribute the LTK, EDIV and RAND
@@ -1186,7 +1069,12 @@ class Session:
# CTKD, calculate BR/EDR link key
if self.initiator_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
self.link_key = self.derive_link_key(self.ltk, self.ct2)
ilk = (
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=self.ltk)
if self.ct2
else crypto.h6(self.ltk, b'tmp1')
)
self.link_key = crypto.h6(ilk, b'lebr')
else:
# CTKD: Derive LTK from LinkKey
@@ -1195,7 +1083,7 @@ class Session:
and self.responder_key_distribution & SMP_ENC_KEY_DISTRIBUTION_FLAG
):
self.ctkd_task = self.connection.abort_on(
'disconnection', self.get_link_key_and_derive_ltk()
'disconnection', self.derive_ltk()
)
# Distribute the LTK, EDIV and RAND
elif not self.sc:
@@ -1225,7 +1113,12 @@ class Session:
# CTKD, calculate BR/EDR link key
if self.responder_key_distribution & SMP_LINK_KEY_DISTRIBUTION_FLAG:
self.link_key = self.derive_link_key(self.ltk, self.ct2)
ilk = (
crypto.h7(salt=SMP_CTKD_H7_LEBR_SALT, w=self.ltk)
if self.ct2
else crypto.h6(self.ltk, b'tmp1')
)
self.link_key = crypto.h6(ilk, b'lebr')
def compute_peer_expected_distributions(self, key_distribution_flags: int) -> None:
# Set our expectations for what to wait for in the key distribution phase
@@ -1403,7 +1296,7 @@ class Session:
try:
handler(command)
except Exception as error:
logger.exception(f'{color("!!! Exception in handler:", "red")} {error}')
logger.warning(f'{color("!!! Exception in handler:", "red")} {error}')
response = SMP_Pairing_Failed_Command(
reason=SMP_UNSPECIFIED_REASON_ERROR
)
@@ -1440,28 +1333,15 @@ class Session:
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
self.ct2 = self.ct2 and (command.auth_req & SMP_CT2_AUTHREQ != 0)
# Infer the pairing method
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
not self.sc and (self.oob_data_flag != 0 and command.oob_data_flag != 0)
):
# Use OOB
self.pairing_method = PairingMethod.OOB
if not self.sc and self.tk is None:
# For legacy OOB, TK is required.
logger.warning("legacy OOB without TK")
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
if command.oob_data_flag == 0:
# The peer doesn't have OOB data, use r=0
self.r = bytes(16)
else:
# Decide which pairing method to use from the IO capability
self.decide_pairing_method(
command.auth_req,
command.io_capability,
self.io_capability,
)
# Check for OOB
if command.oob_data_flag != 0:
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
# Decide which pairing method to use
self.decide_pairing_method(
command.auth_req, command.io_capability, self.io_capability
)
logger.debug(f'pairing method: {self.pairing_method.name}')
# Key distribution
@@ -1510,26 +1390,15 @@ class Session:
self.bonding = self.bonding and (command.auth_req & SMP_BONDING_AUTHREQ != 0)
self.sc = self.sc and (command.auth_req & SMP_SC_AUTHREQ != 0)
# Infer the pairing method
if (self.sc and (self.oob_data_flag != 0 or command.oob_data_flag != 0)) or (
not self.sc and (self.oob_data_flag != 0 and command.oob_data_flag != 0)
):
# Use OOB
self.pairing_method = PairingMethod.OOB
if not self.sc and self.tk is None:
# For legacy OOB, TK is required.
logger.warning("legacy OOB without TK")
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
if command.oob_data_flag == 0:
# The peer doesn't have OOB data, use r=0
self.r = bytes(16)
else:
# Decide which pairing method to use from the IO capability
self.decide_pairing_method(
command.auth_req, self.io_capability, command.io_capability
)
# Check for OOB
if self.sc and command.oob_data_flag:
self.send_pairing_failed(SMP_OOB_NOT_AVAILABLE_ERROR)
return
# Decide which pairing method to use
self.decide_pairing_method(
command.auth_req, self.io_capability, command.io_capability
)
logger.debug(f'pairing method: {self.pairing_method.name}')
# Key distribution
@@ -1680,13 +1549,12 @@ class Session:
if self.passkey_step < 20:
self.send_pairing_confirm_command()
return
elif self.pairing_method != PairingMethod.OOB:
else:
return
else:
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
):
self.send_pairing_random_command()
elif self.pairing_method == PairingMethod.PASSKEY:
@@ -1723,7 +1591,6 @@ class Session:
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
):
ra = bytes(16)
rb = ra
@@ -1732,6 +1599,7 @@ class Session:
ra = self.passkey.to_bytes(16, byteorder='little')
rb = ra
else:
# OOB not implemented yet
return
assert self.preq and self.pres
@@ -1783,33 +1651,18 @@ class Session:
self.peer_public_key_y = command.public_key_y
# Compute the DH key
self.dh_key = self.ecc_key.dh(
command.public_key_x[::-1],
command.public_key_y[::-1],
)[::-1]
self.dh_key = bytes(
reversed(
self.manager.ecc_key.dh(
bytes(reversed(command.public_key_x)),
bytes(reversed(command.public_key_y)),
)
)
)
logger.debug(f'DH key: {self.dh_key.hex()}')
if self.pairing_method == PairingMethod.OOB:
# Check against shared OOB data
if self.peer_oob_data:
confirm_verifier = crypto.f4(
self.peer_public_key_x,
self.peer_public_key_x,
self.peer_oob_data.r,
bytes(1),
)
if not self.check_expected_value(
self.peer_oob_data.c,
confirm_verifier,
SMP_CONFIRM_VALUE_FAILED_ERROR,
):
return
if self.is_initiator:
if self.pairing_method == PairingMethod.OOB:
self.send_pairing_random_command()
else:
self.send_pairing_confirm_command()
self.send_pairing_confirm_command()
else:
if self.pairing_method == PairingMethod.PASSKEY:
self.display_or_input_passkey()
@@ -1820,7 +1673,6 @@ class Session:
if self.pairing_method in (
PairingMethod.JUST_WORKS,
PairingMethod.NUMERIC_COMPARISON,
PairingMethod.OOB,
):
# We can now send the confirmation value
self.send_pairing_confirm_command()
@@ -1849,6 +1701,7 @@ class Session:
else:
self.send_pairing_dhkey_check_command()
else:
assert self.ltk
self.start_encryption(self.ltk)
def on_smp_pairing_failed_command(
@@ -1898,7 +1751,6 @@ class Manager(EventEmitter):
sessions: Dict[int, Session]
pairing_config_factory: Callable[[Connection], PairingConfig]
session_proxy: Type[Session]
_ecc_key: Optional[crypto.EccKey]
def __init__(
self,
@@ -1993,8 +1845,10 @@ class Manager(EventEmitter):
) -> None:
# Store the keys in the key store
if self.device.keystore and identity_address is not None:
# Make sure on_pairing emits after key update.
await self.device.update_keys(str(identity_address), keys)
self.device.abort_on(
'flush', self.device.update_keys(str(identity_address), keys)
)
# Notify the device
self.device.on_pairing(session.connection, identity_address, keys, session.sc)

View File

@@ -18,7 +18,6 @@
from contextlib import asynccontextmanager
import logging
import os
from typing import Optional
from .common import Transport, AsyncPipeSink, SnoopingTransport
from ..snoop import create_snooper
@@ -53,16 +52,8 @@ def _wrap_transport(transport: Transport) -> Transport:
async def open_transport(name: str) -> Transport:
"""
Open a transport by name.
The name must be <type>:<metadata><parameters>
Where <parameters> depend on the type (and may be empty for some types), and
<metadata> is either omitted, or a ,-separated list of <key>=<value> pairs,
enclosed in [].
If there are not metadata or parameter, the : after the <type> may be omitted.
Examples:
* usb:0
* usb:[driver=rtk]0
* android-netsim
The name must be <type>:<parameters>
Where <parameters> depend on the type (and may be empty for some types).
The supported types are:
* serial
* udp
@@ -80,105 +71,87 @@ async def open_transport(name: str) -> Transport:
* android-netsim
"""
scheme, *tail = name.split(':', 1)
spec = tail[0] if tail else None
metadata = None
if spec:
# Metadata may precede the spec
if spec.startswith('['):
metadata_str, *tail = spec[1:].split(']')
spec = tail[0] if tail else None
metadata = dict([entry.split('=') for entry in metadata_str.split(',')])
transport = await _open_transport(scheme, spec)
if metadata:
transport.source.metadata = { # type: ignore[attr-defined]
**metadata,
**getattr(transport.source, 'metadata', {}),
}
# pylint: disable=line-too-long
logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined]
return _wrap_transport(transport)
return _wrap_transport(await _open_transport(name))
# -----------------------------------------------------------------------------
async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
async def _open_transport(name: str) -> Transport:
# pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements
scheme, *spec = name.split(':', 1)
if scheme == 'serial' and spec:
from .serial import open_serial_transport
return await open_serial_transport(spec)
return await open_serial_transport(spec[0])
if scheme == 'udp' and spec:
from .udp import open_udp_transport
return await open_udp_transport(spec)
return await open_udp_transport(spec[0])
if scheme == 'tcp-client' and spec:
from .tcp_client import open_tcp_client_transport
return await open_tcp_client_transport(spec)
return await open_tcp_client_transport(spec[0])
if scheme == 'tcp-server' and spec:
from .tcp_server import open_tcp_server_transport
return await open_tcp_server_transport(spec)
return await open_tcp_server_transport(spec[0])
if scheme == 'ws-client' and spec:
from .ws_client import open_ws_client_transport
return await open_ws_client_transport(spec)
return await open_ws_client_transport(spec[0])
if scheme == 'ws-server' and spec:
from .ws_server import open_ws_server_transport
return await open_ws_server_transport(spec)
return await open_ws_server_transport(spec[0])
if scheme == 'pty':
from .pty import open_pty_transport
return await open_pty_transport(spec)
return await open_pty_transport(spec[0] if spec else None)
if scheme == 'file':
from .file import open_file_transport
assert spec is not None
return await open_file_transport(spec)
return await open_file_transport(spec[0])
if scheme == 'vhci':
from .vhci import open_vhci_transport
return await open_vhci_transport(spec)
return await open_vhci_transport(spec[0] if spec else None)
if scheme == 'hci-socket':
from .hci_socket import open_hci_socket_transport
return await open_hci_socket_transport(spec)
return await open_hci_socket_transport(spec[0] if spec else None)
if scheme == 'usb':
from .usb import open_usb_transport
assert spec
return await open_usb_transport(spec)
assert spec is not None
return await open_usb_transport(spec[0])
if scheme == 'pyusb':
from .pyusb import open_pyusb_transport
assert spec
return await open_pyusb_transport(spec)
assert spec is not None
return await open_pyusb_transport(spec[0])
if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport
return await open_android_emulator_transport(spec)
return await open_android_emulator_transport(spec[0] if spec else None)
if scheme == 'android-netsim':
from .android_netsim import open_android_netsim_transport
return await open_android_netsim_transport(spec)
return await open_android_netsim_transport(spec[0] if spec else None)
raise ValueError('unknown transport scheme')
@@ -197,13 +170,12 @@ async def open_transport_or_link(name: str) -> Transport:
"""
if name.startswith('link-relay:'):
logger.warning('Link Relay has been deprecated.')
from ..controller import Controller
from ..link import RemoteLink # lazy import
link = RemoteLink(name[11:])
await link.wait_until_connected()
controller = Controller('remote', link=link) # type:ignore[arg-type]
controller = Controller('remote', link=link)
class LinkTransport(Transport):
async def close(self):

View File

@@ -69,7 +69,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
mode = 'host'
server_host = 'localhost'
server_port = '8554'
if spec:
if spec is not None:
params = spec.split(',')
for param in params:
if param.startswith('mode='):

View File

@@ -21,7 +21,7 @@ import struct
import asyncio
import logging
import io
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
from typing import ContextManager, Tuple, Optional, Protocol, Dict
from bumble import hci
from bumble.colors import color
@@ -42,7 +42,6 @@ HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
hci.HCI_EVENT_PACKET: (1, 1, 'B'),
hci.HCI_ISO_DATA_PACKET: (2, 2, 'H'),
}
@@ -151,7 +150,7 @@ class PacketParser:
try:
self.sink.on_packet(bytes(self.packet))
except Exception as error:
logger.exception(
logger.warning(
color(f'!!! Exception in on_packet: {error}', 'red')
)
self.reset()
@@ -168,13 +167,11 @@ class PacketReader:
def __init__(self, source: io.BufferedReader) -> None:
self.source = source
self.at_end = False
def next_packet(self) -> Optional[bytes]:
# Get the packet type
packet_type = self.source.read(1)
if len(packet_type) != 1:
self.at_end = True
return None
# Get the packet info based on its type

View File

@@ -59,7 +59,10 @@ async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
) from error
# Compute the adapter index
adapter_index = int(spec) if spec else 0
if spec is None:
adapter_index = 0
else:
adapter_index = int(spec)
# Bind the socket
# NOTE: since Python doesn't support binding with the required address format (yet),

View File

@@ -113,10 +113,9 @@ async def open_pyusb_transport(spec: str) -> Transport:
self.loop.call_soon_threadsafe(self.stop_event.set)
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, device, metadata, sco_enabled):
def __init__(self, device, sco_enabled):
super().__init__()
self.device = device
self.metadata = metadata
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.dequeue_task = None
@@ -217,15 +216,6 @@ async def open_pyusb_transport(spec: str) -> Transport:
if ':' in spec:
vendor_id, product_id = spec.split(':')
device = usb_find(idVendor=int(vendor_id, 16), idProduct=int(product_id, 16))
elif '-' in spec:
def device_path(device):
if device.port_numbers:
return f'{device.bus}-{".".join(map(str, device.port_numbers))}'
else:
return str(device.bus)
device = usb_find(custom_match=lambda device: device_path(device) == spec)
else:
device_index = int(spec)
devices = list(
@@ -245,9 +235,6 @@ async def open_pyusb_transport(spec: str) -> Transport:
raise ValueError('device not found')
logger.debug(f'USB Device: {device}')
# Collect the metadata
device_metadata = {'vendor_id': device.idVendor, 'product_id': device.idProduct}
# Detach the kernel driver if needed
if device.is_kernel_driver_active(0):
logger.debug("detaching kernel driver")
@@ -302,7 +289,7 @@ async def open_pyusb_transport(spec: str) -> Transport:
# except usb.USBError:
# logger.warning('failed to set alternate setting')
packet_source = UsbPacketSource(device, device_metadata, sco_enabled)
packet_source = UsbPacketSource(device, sco_enabled)
packet_sink = UsbPacketSink(device)
packet_source.start()
packet_sink.start()

View File

@@ -24,10 +24,9 @@ import platform
import usb1
from bumble.transport.common import Transport, ParserSource
from bumble import hci
from bumble.colors import color
from bumble.utils import AsyncRunner
from .common import Transport, ParserSource
from .. import hci
from ..colors import color
# -----------------------------------------------------------------------------
@@ -108,13 +107,13 @@ async def open_usb_transport(spec: str) -> Transport:
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
READ_SIZE = 4096
READ_SIZE = 1024
class UsbPacketSink:
def __init__(self, device, acl_out):
self.device = device
self.acl_out = acl_out
self.acl_out_transfer = device.getTransfer()
self.transfer = device.getTransfer()
self.packets = collections.deque() # Queue of packets waiting to be sent
self.loop = asyncio.get_running_loop()
self.cancel_done = self.loop.create_future()
@@ -138,20 +137,21 @@ async def open_usb_transport(spec: str) -> Transport:
# The queue was previously empty, re-prime the pump
self.process_queue()
def transfer_callback(self, transfer):
def on_packet_sent(self, transfer):
status = transfer.getStatus()
# logger.debug(f'<<< USB out transfer callback: status={status}')
# pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED:
self.loop.call_soon_threadsafe(self.on_packet_sent)
self.loop.call_soon_threadsafe(self.on_packet_sent_)
elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(self.cancel_done.set_result, None)
else:
logger.warning(
color(f'!!! OUT transfer not completed: status={status}', 'red')
color(f'!!! out transfer not completed: status={status}', 'red')
)
def on_packet_sent(self):
def on_packet_sent_(self):
if self.packets:
self.packets.popleft()
self.process_queue()
@@ -163,20 +163,22 @@ async def open_usb_transport(spec: str) -> Transport:
packet = self.packets[0]
packet_type = packet[0]
if packet_type == hci.HCI_ACL_DATA_PACKET:
self.acl_out_transfer.setBulk(
self.acl_out, packet[1:], callback=self.transfer_callback
self.transfer.setBulk(
self.acl_out, packet[1:], callback=self.on_packet_sent
)
self.acl_out_transfer.submit()
logger.debug('submit ACL')
self.transfer.submit()
elif packet_type == hci.HCI_COMMAND_PACKET:
self.acl_out_transfer.setControl(
self.transfer.setControl(
USB_RECIPIENT_DEVICE | USB_REQUEST_TYPE_CLASS,
0,
0,
0,
packet[1:],
callback=self.transfer_callback,
callback=self.on_packet_sent,
)
self.acl_out_transfer.submit()
logger.debug('submit COMMAND')
self.transfer.submit()
else:
logger.warning(color(f'unsupported packet type {packet_type}', 'red'))
@@ -191,11 +193,11 @@ async def open_usb_transport(spec: str) -> Transport:
self.packets.clear()
# If we have a transfer in flight, cancel it
if self.acl_out_transfer.isSubmitted():
if self.transfer.isSubmitted():
# Try to cancel the transfer, but that may fail because it may have
# already completed
try:
self.acl_out_transfer.cancel()
self.transfer.cancel()
logger.debug('waiting for OUT transfer cancellation to be done...')
await self.cancel_done
@@ -204,22 +206,27 @@ async def open_usb_transport(spec: str) -> Transport:
logger.debug('OUT transfer likely already completed')
class UsbPacketSource(asyncio.Protocol, ParserSource):
def __init__(self, device, metadata, acl_in, events_in):
def __init__(self, context, device, metadata, acl_in, events_in):
super().__init__()
self.context = context
self.device = device
self.metadata = metadata
self.acl_in = acl_in
self.acl_in_transfer = None
self.events_in = events_in
self.events_in_transfer = None
self.loop = asyncio.get_running_loop()
self.queue = asyncio.Queue()
self.dequeue_task = None
self.closed = False
self.event_loop_done = self.loop.create_future()
self.cancel_done = {
hci.HCI_EVENT_PACKET: self.loop.create_future(),
hci.HCI_ACL_DATA_PACKET: self.loop.create_future(),
}
self.closed = False
self.events_in_transfer = None
self.acl_in_transfer = None
# Create a thread to process events
self.event_thread = threading.Thread(target=self.run)
def start(self):
# Set up transfer objects for input
@@ -227,7 +234,7 @@ async def open_usb_transport(spec: str) -> Transport:
self.events_in_transfer.setInterrupt(
self.events_in,
READ_SIZE,
callback=self.transfer_callback,
callback=self.on_packet_received,
user_data=hci.HCI_EVENT_PACKET,
)
self.events_in_transfer.submit()
@@ -236,23 +243,22 @@ async def open_usb_transport(spec: str) -> Transport:
self.acl_in_transfer.setBulk(
self.acl_in,
READ_SIZE,
callback=self.transfer_callback,
callback=self.on_packet_received,
user_data=hci.HCI_ACL_DATA_PACKET,
)
self.acl_in_transfer.submit()
self.dequeue_task = self.loop.create_task(self.dequeue())
self.event_thread.start()
@property
def usb_transfer_submitted(self):
return (
self.events_in_transfer.isSubmitted()
or self.acl_in_transfer.isSubmitted()
)
def transfer_callback(self, transfer):
def on_packet_received(self, transfer):
packet_type = transfer.getUserData()
status = transfer.getStatus()
# logger.debug(
# f'<<< USB IN transfer callback: status={status} '
# f'packet_type={packet_type} '
# f'length={transfer.getActualLength()}'
# )
# pylint: disable=no-member
if status == usb1.TRANSFER_COMPLETED:
@@ -261,18 +267,18 @@ async def open_usb_transport(spec: str) -> Transport:
+ transfer.getBuffer()[: transfer.getActualLength()]
)
self.loop.call_soon_threadsafe(self.queue.put_nowait, packet)
# Re-submit the transfer so we can receive more data
transfer.submit()
elif status == usb1.TRANSFER_CANCELLED:
self.loop.call_soon_threadsafe(
self.cancel_done[packet_type].set_result, None
)
return
else:
logger.warning(
color(f'!!! IN transfer not completed: status={status}', 'red')
color(f'!!! transfer not completed: status={status}', 'red')
)
self.loop.call_soon_threadsafe(self.on_transport_lost)
# Re-submit the transfer so we can receive more data
transfer.submit()
async def dequeue(self):
while not self.closed:
@@ -282,6 +288,21 @@ async def open_usb_transport(spec: str) -> Transport:
return
self.parser.feed_data(packet)
def run(self):
logger.debug('starting USB event loop')
while (
self.events_in_transfer.isSubmitted()
or self.acl_in_transfer.isSubmitted()
):
# pylint: disable=no-member
try:
self.context.handleEvents()
except usb1.USBErrorInterrupted:
pass
logger.debug('USB event loop done')
self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)
def close(self):
self.closed = True
@@ -310,14 +331,15 @@ async def open_usb_transport(spec: str) -> Transport:
f'IN[{packet_type}] transfer likely already completed'
)
# Wait for the thread to terminate
await self.event_loop_done
class UsbTransport(Transport):
def __init__(self, context, device, interface, setting, source, sink):
super().__init__(source, sink)
self.context = context
self.device = device
self.interface = interface
self.loop = asyncio.get_running_loop()
self.event_loop_done = self.loop.create_future()
# Get exclusive access
device.claimInterface(interface)
@@ -330,22 +352,6 @@ async def open_usb_transport(spec: str) -> Transport:
source.start()
sink.start()
# Create a thread to process events
self.event_thread = threading.Thread(target=self.run)
self.event_thread.start()
def run(self):
logger.debug('starting USB event loop')
while self.source.usb_transfer_submitted:
# pylint: disable=no-member
try:
self.context.handleEvents()
except usb1.USBErrorInterrupted:
pass
logger.debug('USB event loop done')
self.loop.call_soon_threadsafe(self.event_loop_done.set_result, None)
async def close(self):
self.source.close()
self.sink.close()
@@ -355,9 +361,6 @@ async def open_usb_transport(spec: str) -> Transport:
self.device.close()
self.context.close()
# Wait for the thread to terminate
await self.event_loop_done
# Find the device according to the spec moniker
load_libusb()
context = usb1.USBContext()
@@ -396,16 +399,6 @@ async def open_usb_transport(spec: str) -> Transport:
break
device_index -= 1
device.close()
elif '-' in spec:
def device_path(device):
return f'{device.getBusNumber()}-{".".join(map(str, device.getPortNumberList()))}'
for device in context.getDeviceIterator(skip_on_error=True):
if device_path(device) == spec:
found = device
break
device.close()
else:
# Look for a compatible device by index
def device_is_bluetooth_hci(device):
@@ -547,7 +540,7 @@ async def open_usb_transport(spec: str) -> Transport:
except usb1.USBError:
logger.warning('failed to set configuration')
source = UsbPacketSource(device, device_metadata, acl_in, events_in)
source = UsbPacketSource(context, device, device_metadata, acl_in, events_in)
sink = UsbPacketSink(device, acl_out)
return UsbTransport(context, device, interface, setting, source, sink)
except usb1.USBError as error:

View File

@@ -17,12 +17,10 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import collections
import enum
import functools
import logging
import traceback
import collections
import sys
import warnings
from typing import (
Awaitable,
Set,
@@ -35,7 +33,7 @@ from typing import (
Union,
overload,
)
from functools import wraps, partial
from pyee import EventEmitter
from .colors import color
@@ -132,14 +130,13 @@ class EventWatcher:
Args:
emitter: EventEmitter to watch
event: Event name
handler: (Optional) Event handler. When nothing is passed, this method
works as a decorator.
handler: (Optional) Event handler. When nothing is passed, this method works as a decorator.
'''
def wrapper(wrapped: _Handler) -> _Handler:
self.handlers.append((emitter, event, wrapped))
emitter.on(event, wrapped)
return wrapped
def wrapper(f: _Handler) -> _Handler:
self.handlers.append((emitter, event, f))
emitter.on(event, f)
return f
return wrapper if handler is None else wrapper(handler)
@@ -159,14 +156,13 @@ class EventWatcher:
Args:
emitter: EventEmitter to watch
event: Event name
handler: (Optional) Event handler. When nothing passed, this method works
as a decorator.
handler: (Optional) Event handler. When nothing passed, this method works as a decorator.
'''
def wrapper(wrapped: _Handler) -> _Handler:
self.handlers.append((emitter, event, wrapped))
emitter.once(event, wrapped)
return wrapped
def wrapper(f: _Handler) -> _Handler:
self.handlers.append((emitter, event, f))
emitter.once(event, f)
return f
return wrapper if handler is None else wrapper(handler)
@@ -226,13 +222,13 @@ class CompositeEventEmitter(AbortableEventEmitter):
if self._listener:
# Call the deregistration methods for each base class that has them
for cls in self._listener.__class__.mro():
if '_bumble_register_composite' in cls.__dict__:
cls._bumble_deregister_composite(self._listener, self)
if hasattr(cls, '_bumble_register_composite'):
cls._bumble_deregister_composite(listener, self)
self._listener = listener
if listener:
# Call the registration methods for each base class that has them
for cls in listener.__class__.mro():
if '_bumble_deregister_composite' in cls.__dict__:
if hasattr(cls, '_bumble_deregister_composite'):
cls._bumble_register_composite(listener, self)
@@ -279,18 +275,21 @@ class AsyncRunner:
"""
def decorator(func):
@functools.wraps(func)
@wraps(func)
def wrapper(*args, **kwargs):
coroutine = func(*args, **kwargs)
if queue is None:
# Spawn the coroutine as a task
# Create a task to run the coroutine
async def run():
try:
await coroutine
except Exception:
logger.exception(color("!!! Exception in wrapper:", "red"))
logger.warning(
f'{color("!!! Exception in wrapper:", "red")} '
f'{traceback.format_exc()}'
)
AsyncRunner.spawn(run())
asyncio.create_task(run())
else:
# Queue the coroutine to be awaited by the work queue
queue.enqueue(coroutine)
@@ -413,75 +412,18 @@ class FlowControlAsyncPipe:
self.check_pump()
# -----------------------------------------------------------------------------
async def async_call(function, *args, **kwargs):
"""
Immediately calls the function with provided args and kwargs, wrapping it in an
async function.
Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject
a running loop.
Immediately calls the function with provided args and kwargs, wrapping it in an async function.
Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject a running loop.
result = await async_call(some_function, ...)
"""
return function(*args, **kwargs)
# -----------------------------------------------------------------------------
def wrap_async(function):
"""
Wraps the provided function in an async function.
"""
return functools.partial(async_call, function)
# -----------------------------------------------------------------------------
def deprecated(msg: str):
"""
Throw deprecation warning before execution.
"""
def wrapper(function):
@functools.wraps(function)
def inner(*args, **kwargs):
warnings.warn(msg, DeprecationWarning)
return function(*args, **kwargs)
return inner
return wrapper
# -----------------------------------------------------------------------------
def experimental(msg: str):
"""
Throws a future warning before execution.
"""
def wrapper(function):
@functools.wraps(function)
def inner(*args, **kwargs):
warnings.warn(msg, FutureWarning)
return function(*args, **kwargs)
return inner
return wrapper
# -----------------------------------------------------------------------------
class OpenIntEnum(enum.IntEnum):
"""
Subclass of enum.IntEnum that can hold integer values outside the set of
predefined values. This is convenient for implementing protocols where some
integer constants may be added over time.
"""
@classmethod
def _missing_(cls, value):
if not isinstance(value, int):
return None
obj = int.__new__(cls, value)
obj._value_ = value
obj._name_ = f"{cls.__name__}[{value}]"
return obj
return partial(async_call, function)

View File

@@ -10,7 +10,7 @@ nav:
- Contributing: development/contributing.md
- Code Style: development/code_style.md
- Use Cases:
- use_cases/index.md
- Overview: use_cases/index.md
- Use Case 1: use_cases/use_case_1.md
- Use Case 2: use_cases/use_case_2.md
- Use Case 3: use_cases/use_case_3.md
@@ -23,7 +23,7 @@ nav:
- GATT: components/gatt.md
- Security Manager: components/security_manager.md
- Transports:
- transports/index.md
- Overview: transports/index.md
- Serial: transports/serial.md
- USB: transports/usb.md
- PTY: transports/pty.md
@@ -37,14 +37,14 @@ nav:
- Android Emulator: transports/android_emulator.md
- File: transports/file.md
- Drivers:
- drivers/index.md
- Overview: drivers/index.md
- Realtek: drivers/realtek.md
- API:
- Guide: api/guide.md
- Examples: api/examples.md
- Reference: api/reference.md
- Apps & Tools:
- apps_and_tools/index.md
- Overview: apps_and_tools/index.md
- Console: apps_and_tools/console.md
- Bench: apps_and_tools/bench.md
- Speaker: apps_and_tools/speaker.md
@@ -57,25 +57,16 @@ nav:
- USB Probe: apps_and_tools/usb_probe.md
- Link Relay: apps_and_tools/link_relay.md
- Hardware:
- hardware/index.md
- Overview: hardware/index.md
- Platforms:
- platforms/index.md
- Overview: platforms/index.md
- macOS: platforms/macos.md
- Linux: platforms/linux.md
- Windows: platforms/windows.md
- Android: platforms/android.md
- Zephyr: platforms/zephyr.md
- Examples:
- examples/index.md
- Extras:
- extras/index.md
- Android Remote HCI: extras/android_remote_hci.md
- Android BT Bench: extras/android_bt_bench.md
- Hive:
- hive/index.md
- Speaker: hive/web/speaker/speaker.html
- Scanner: hive/web/scanner/scanner.html
- Heart Rate Monitor: hive/web/heart_rate_monitor/heart_rate_monitor.html
- Overview: examples/index.md
copyright: Copyright 2021-2023 Google LLC
@@ -84,8 +75,6 @@ theme:
logo: 'images/logo.png'
favicon: 'images/favicon.ico'
custom_dir: 'theme'
features:
- navigation.indexes
plugins:
- mkdocstrings:
@@ -110,8 +99,6 @@ markdown_extensions:
- pymdownx.emoji:
emoji_index: !!python/name:materialx.emoji.twemoji
emoji_generator: !!python/name:materialx.emoji.to_svg
- pymdownx.tabbed:
alternate_style: true
- codehilite:
guess_lang: false
- toc:

View File

@@ -7,36 +7,16 @@ throughput and/or latency between two devices.
# General Usage
```
Usage: bumble-bench [OPTIONS] COMMAND [ARGS]...
Usage: bench.py [OPTIONS] COMMAND [ARGS]...
Options:
--device-config FILENAME Device configuration file
--role [sender|receiver|ping|pong]
--mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server]
--att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517]
--extended-data-length TEXT Request a data length upon connection,
specified as tx_octets/tx_time
--rfcomm-channel INTEGER RFComm channel to use
--rfcomm-uuid TEXT RFComm service UUID to use (ignored if
--rfcomm-channel is not 0)
--l2cap-psm INTEGER L2CAP PSM to use
--l2cap-mtu INTEGER L2CAP MTU to use
--l2cap-mps INTEGER L2CAP MPS to use
--l2cap-max-credits INTEGER L2CAP maximum number of credits allowed for
the peer
-s, --packet-size SIZE Packet size (client or ping role)
[8<=x<=4096]
-c, --packet-count COUNT Packet count (client or ping role)
-sd, --start-delay SECONDS Start delay (client or ping role)
--repeat N Repeat the run N times (client and ping
roles)(0, which is the fault, to run just
once)
--repeat-delay SECONDS Delay, in seconds, between repeats
--pace MILLISECONDS Wait N milliseconds between packets (0,
which is the fault, to send as fast as
possible)
--linger Don't exit at the end of a run (server and
pong roles)
-s, --packet-size SIZE Packet size (server role) [8<=x<=4096]
-c, --packet-count COUNT Packet count (server role)
-sd, --start-delay SECONDS Start delay (server role)
--help Show this message and exit.
Commands:
@@ -55,18 +35,17 @@ Options:
--connection-interval, --ci CONNECTION_INTERVAL
Connection interval (in ms)
--phy [1m|2m|coded] PHY to use
--authenticate Authenticate (RFComm only)
--encrypt Encrypt the connection (RFComm only)
--help Show this message and exit.
```
To test once device against another, one of the two devices must be running
To test once device against another, one of the two devices must be running
the ``peripheral`` command and the other the ``central`` command. The device
running the ``peripheral`` command will accept connections from the device
running the ``central`` command.
When using Bluetooth LE (all modes except for ``rfcomm-server`` and ``rfcomm-client``utils),
the default addresses configured in the tool should be sufficient. But when using
Bluetooth Classic, the address of the Peripheral must be specified on the Central
the default addresses configured in the tool should be sufficient. But when using
Bluetooth Classic, the address of the Peripheral must be specified on the Central
using the ``--peripheral`` option. The address will be printed by the Peripheral when
it starts.
@@ -104,7 +83,7 @@ the other on `usb:1`, and two consoles/terminals. We will run a command in each.
$ bumble-bench central usb:1
```
In this default configuration, the Central runs a Sender, as a GATT client,
In this default configuration, the Central runs a Sender, as a GATT client,
connecting to the Peripheral running a Receiver, as a GATT server.
!!! example "L2CAP Throughput"

View File

@@ -12,25 +12,12 @@ a host that send custom HCI commands that the controller may not understand.
```
python hci_bridge.py <host-transport-spec> <controller-transport-spec> [command-short-circuit-list]
```
The command-short-circuit-list field is specified by a series of comma separated Opcode Group
Field (OGF) : OpCode Command Field (OCF) pairs. The OGF/OCF values are specified in the Blutooth
core specification.
For the commands that are listed in the short-circuit-list, the HCI bridge will always generate
a Command Complete Event for the specified op code. The return parameter will be HCI_SUCCESS.
This feature can only be used for commands that return Command Complete. Other events will not be
generated by the HCI bridge tool.
!!! example "UDP to Serial"
```
python hci_bridge.py udp:0.0.0.0:9000,127.0.0.1:9001 serial:/dev/tty.usbmodem0006839912171,1000000 0x3f:0x0070,0x3f:0x0074,0x3f:0x0077,0x3f:0x0078
```
In this example, the short circuit list is specified to respond to the Vendor-specific Opcode Group
Field (0x3f) commands 0x70, 0x74, 0x77, 0x78 with Command Complete. The short circuit list can be
used where the Host uses some HCI commands that are not supported/implemented by the Controller.
!!! example "PTY to Link Relay"
```
python hci_bridge.py serial:emulated_uart_pty,1000000 link-relay:ws://127.0.0.1:10723/test
@@ -41,4 +28,3 @@ a host that send custom HCI commands that the controller may not understand.
(through which the communication with other virtual controllers will be mediated).
NOTE: this assumes you're running a Link Relay on port `10723`.

View File

@@ -5,15 +5,6 @@ Some Bluetooth controllers require a driver to function properly.
This may include, for instance, loading a Firmware image or patch,
loading a configuration.
By default, drivers will be automatically probed to determine if they should be
used with particular HCI controller.
When the transport for an HCI controller is instantiated from a transport name,
a driver may also be forced by specifying ``driver=<driver-name>`` in the optional
metadata portion of the transport name. For example,
``usb:[driver=rtk]0`` indicates that the ``rtk`` driver should be used with the
first USB device, even if a normal probe would not have selected it based on the
USB vendor ID and product ID.
Drivers included in the module are:
* [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles.

View File

@@ -1,16 +1,13 @@
REALTEK DRIVER
==============
This driver supports loading firmware images and optional config data to
This driver supports loading firmware images and optional config data to
USB dongles with a Realtek chipset.
A number of USB dongles are supported, but likely not all.
When using a USB dongle, the USB product ID and vendor ID are used
When using a USB dongle, the USB product ID and manufacturer ID are used
to find whether a matching set of firmware image and config data
is needed for that specific model. If a match exists, the driver will try
load the firmware image and, if needed, config data.
Alternatively, the metadata property ``driver=rtk`` may be specified in a transport
name to force that driver to be used (ex: ``usb:[driver=rtk]0`` instead of just
``usb:0`` for the first USB device).
The driver will look for those files by name, in order, in:
* The directory specified by the environment variable `BUMBLE_RTK_FIRMWARE_DIR`

View File

@@ -1,64 +0,0 @@
ANDROID BENCH APP
=================
This Android app that is compatible with the Bumble `bench` command line app.
This app can be used to test the throughput and latency between two Android
devices, or between an Android device and another device running the Bumble
`bench` app.
Only the RFComm Client, RFComm Server, L2CAP Client and L2CAP Server modes are
supported.
Building
--------
You can build the app by running `./gradlew build` (use `gradlew.bat` on Windows) from the `BtBench` top level directory.
You can also build with Android Studio: open the `BtBench` project. You can build and/or debug from there.
If the build succeeds, you can find the app APKs (debug and release) at:
* [Release] ``app/build/outputs/apk/release/app-release-unsigned.apk``
* [Debug] ``app/build/outputs/apk/debug/app-debug.apk``
Running
-------
### Starting the app
You can start the app from the Android launcher, from Android Studio, or with `adb`
#### Launching from the launcher
Just tap the app icon on the launcher, check the parameters, and tap
one of the benchmark action buttons.
#### Launching with `adb`
Using the `am` command, you can start the activity, and pass it arguments so that you can
automatically start the benchmark test, and/or set the parameters.
| Parameter Name | Parameter Type | Description
|------------------------|----------------|------------
| autostart | String | Benchmark to start. (rfcomm-client, rfcomm-server, l2cap-client or l2cap-server)
| packet-count | Integer | Number of packets to send (rfcomm-client and l2cap-client only)
| packet-size | Integer | Number of bytes per packet (rfcomm-client and l2cap-client only)
| peer-bluetooth-address | Integer | Peer Bluetooth address to connect to (rfcomm-client and l2cap-client | only)
!!! tip "Launching from adb with auto-start"
In this example, we auto-start the Rfcomm Server bench action.
```bash
$ adb shell am start -n com.github.google.bumble.btbench/.MainActivity --es autostart rfcomm-server
```
!!! tip "Launching from adb with auto-start and some parameters"
In this example, we auto-start the Rfcomm Client bench action, set the packet count to 100,
and the packet size to 1024, and connect to DA:4C:10:DE:17:02
```bash
$ adb shell am start -n com.github.google.bumble.btbench/.MainActivity --es autostart rfcomm-client --ei packet-count 100 --ei packet-size 1024 --es peer-bluetooth-address DA:4C:10:DE:17:02
```
#### Selecting a Peer Bluetooth Address
The app's main activity has a "Peer Bluetooth Address" setting where you can change the address.
!!! note "Bluetooth Address for L2CAP vs RFComm"
For BLE (L2CAP mode), the address of a device typically changes regularly (it is randomized for privacy), whereas the Bluetooth Classic addresses will remain the same (RFComm mode).
If two devices are paired and bonded, then they will each "see" a non-changing address for each other even with BLE (Resolvable Private Address)

View File

@@ -1,181 +0,0 @@
ANDROID REMOTE HCI APP
======================
This application allows using an android phone's built-in Bluetooth controller with
a Bumble host stack running outside the phone (typically a development laptop or desktop).
The app runs an HCI proxy between a TCP socket on the "outside" and the Bluetooth HCI HAL
on the "inside". (See [this page](https://source.android.com/docs/core/connect/bluetooth) for a high level
description of the Android Bluetooth HCI HAL).
The HCI packets received on the TCP socket are forwarded to the phone's controller, and the
packets coming from the controller are forwarded to the TCP socket.
Building
--------
You can build the app by running `./gradlew build` (use `gradlew.bat` on Windows) from the `extras/android/RemoteHCI` top level directory.
You can also build with Android Studio: open the `RemoteHCI` project. You can build and/or debug from there.
If the build succeeds, you can find the app APKs (debug and release) at:
* [Release] ``app/build/outputs/apk/release/app-release-unsigned.apk``
* [Debug] ``app/build/outputs/apk/debug/app-debug.apk``
Running
-------
!!! note
In the following examples, it is assumed that shell commands are executed while in the
app's root directory, `extras/android/RemoteHCI`. If you are in a different directory,
adjust the relative paths accordingly.
### Preconditions
When the proxy starts (tapping the "Start" button in the app's main activity, or running the proxy
from an `adb shell` command line), it will try to bind to the Bluetooth HAL.
This requires that there is no other HAL client, and requires certain privileges.
For running as a regular app, this requires disabling SELinux temporarily.
For running as a command-line executable, this just requires a root shell.
#### Root Shell
!!! tip "Restart `adb` as root"
```bash
$ adb root
```
#### Disabling SELinux
Binding to the Bluetooth HCI HAL requires certain SELinux permissions that can't simply be changed
on a device without rebuilding its system image. To bypass these restrictions, you will need
to disable SELinux on your phone (please be aware that this is global, not just for the proxy app,
so proceed with caution).
In order to disable SELinux, you need to root the phone (it may be advisable to do this on a
development phone).
!!! tip "Disabling SELinux Temporarily"
Restart `adb` as root:
```bash
$ adb root
```
Then disable SELinux
```bash
$ adb shell setenforce 0
```
Once you're done using the proxy, you can restore SELinux, if you need to, with
```bash
$ adb shell setenforce 1
```
This state will also reset to the normal SELinux enforcement when you reboot.
#### Stopping the bluetooth process
Since the Bluetooth HAL service can only accept one client, and that in normal conditions
that client is the Android's bluetooth stack, it is required to first shut down the
Android bluetooth stack process.
!!! tip "Checking if the Bluetooth process is running"
```bash
$ adb shell "ps -A | grep com.google.android.bluetooth"
```
If the process is running, you will get a line like:
```
bluetooth 10759 876 17455796 136620 do_epoll_wait 0 S com.google.android.bluetooth
```
If you don't, it means that the process is not running and you are clear to proceed.
Simply turning Bluetooth off from the phone's settings does not ensure that the bluetooth process will exit.
If the bluetooth process is still running after toggling Bluetooth off from the settings, you may try enabling
Airplane Mode, then rebooting. The bluetooth process should, in theory, not restart after the reboot.
!!! tip "Stopping the bluetooth process with adb"
```bash
$ adb shell cmd bluetooth_manager disable
```
### Running as a command line app
You push the built APK to a temporary location on the phone's filesystem, then launch the command
line executable with an `adb shell` command.
!!! tip "Pushing the executable"
```bash
$ adb push app/build/outputs/apk/release/app-release-unsigned.apk /data/local/tmp/remotehci.apk
```
Do this every time you rebuild. Alternatively, you can push the `debug` APK instead:
```bash
$ adb push app/build/outputs/apk/debug/app-debug.apk /data/local/tmp/remotehci.apk
```
!!! tip "Start the proxy from the command line"
```bash
adb shell "CLASSPATH=/data/local/tmp/remotehci.apk app_process /system/bin com.github.google.bumble.remotehci.CommandLineInterface"
```
This will run the proxy, listening on the default TCP port.
If you want a different port, pass it as a command line parameter
!!! tip "Start the proxy from the command line with a specific TCP port"
```bash
adb shell "CLASSPATH=/data/local/tmp/remotehci.apk app_process /system/bin com.github.google.bumble.remotehci.CommandLineInterface 12345"
```
### Running as a normal app
You can start the app from the Android launcher, from Android Studio, or with `adb`
#### Launching from the launcher
Just tap the app icon on the launcher, check the TCP port that is configured, and tap
the "Start" button.
#### Launching with `adb`
Using the `am` command, you can start the activity, and pass it arguments so that you can
automatically start the proxy, and/or set the port number.
!!! tip "Launching from adb with auto-start"
```bash
$ adb shell am start -n com.github.google.bumble.remotehci/.MainActivity --ez autostart true
```
!!! tip "Launching from adb with auto-start and a port"
In this example, we auto-start the proxy upon launch, with the port set to 9995
```bash
$ adb shell am start -n com.github.google.bumble.remotehci/.MainActivity --ez autostart true --ei port 9995
```
#### Selecting a TCP port
The RemoteHCI app's main activity has a "TCP Port" setting where you can change the port on
which the proxy is accepting connections. If the default value isn't suitable, you can
change it there (you can also use the special value 0 to let the OS assign a port number for you).
### Connecting to the proxy
To connect the Bumble stack to the proxy, you need to be able to reach the phone's network
stack. This can be done over the phone's WiFi connection, or, alternatively, using an `adb`
TCP forward (which should be faster than over WiFi).
!!! tip "Forwarding TCP with `adb`"
To connect to the proxy via an `adb` TCP forward, use:
```bash
$ adb forward tcp:<outside-port> tcp:<inside-port>
```
Where ``<outside-port>`` is the port number for a listening socket on your laptop or
desktop machine, and <inside-port> is the TCP port selected in the app's user interface.
Those two ports may be the same, of course.
For example, with the default TCP port 9993:
```bash
$ adb forward tcp:9993 tcp:9993
```
Once you've ensured that you can reach the proxy's TCP port on the phone, either directly or
via an `adb` forward, you can then use it as a Bumble transport, using the transport name:
``tcp-client:<host>:<port>`` syntax.
!!! example "Connecting a Bumble client"
Connecting the `bumble-controller-info` app to the phone's controller.
Assuming you have set up an `adb` forward on port 9993:
```bash
$ bumble-controller-info tcp-client:localhost:9993
```
Or over WiFi with, in this example, the IP address of the phone being ```192.168.86.27```
```bash
$ bumble-controller-info tcp-client:192.168.86.27:9993
```

View File

@@ -1,19 +0,0 @@
EXTRAS
======
A collection of add-ons, apps and tools, to the Bumble project.
Android Remote HCI
------------------
Allows using an Android phone's built-in Bluetooth controller with a Bumble
stack running on a development machine.
See [Android Remote HCI](android_remote_hci.md) for details.
Android BT Bench
----------------
An Android app that is compatible with the Bumble `bench` command line app.
This app can be used to test the throughput and latency between two Android
devices, or between an Android device and another device running the Bumble
`bench` app.

View File

@@ -3,7 +3,7 @@ HARDWARE
The Bumble Host connects to a controller over an [HCI Transport](../transports/index.md).
To use a hardware controller attached to the host on which the host application is running, the transport is typically either [HCI over UART](../transports/serial.md) or [HCI over USB](../transports/usb.md).
On Linux, the [VHCI Transport](../transports/vhci.md) can be used to communicate with any controller hardware managed by the operating system. Alternatively, a remote controller (a phyiscal controller attached to a remote host) can be used by connecting one of the networked transports (such as the [TCP Client transport](../transports/tcp_client.md), the [TCP Server transport](../transports/tcp_server.md) or the [UDP Transport](../transports/udp.md)) to an [HCI Bridge](../apps_and_tools/hci_bridge.md) bridging the network transport to a physical controller on a remote host.
On Linux, the [VHCI Transport](../transports/vhci.md) can be used to communicate with any controller hardware managed by the operating system. Alternatively, a remote controller (a phyiscal controller attached to a remote host) can be used by connecting one of the networked transports (such as the [TCP Client transport](../transports/tcp_client.md), the [TCP Server transport](../transports/tcp_server.md) or the [UDP Transport](../transports/udp.md)) to an [HCI Bridge](../apps_and_tools/hci_bridge) bridging the network transport to a physical controller on a remote host.
In theory, any controller that is compliant with the HCI over UART or HCI over USB protocols can be used.

View File

@@ -1,59 +0,0 @@
HIVE
====
Welcome to the Bumble Hive.
This is a collection of apps and virtual devices that can run entirely in a browser page.
The code for the apps and devices, as well as the Bumble runtime code, runs via [Pyodide](https://pyodide.org/).
Pyodide is a Python distribution for the browser and Node.js based on WebAssembly.
The Bumble stack uses a WebSocket to exchange HCI packets with a virtual or physical
Bluetooth controller.
The apps and devices in the hive can be accessed by following the links below. Each
page has a settings button that may be used to configure the WebSocket URL to use for
the virtual HCI connection. This will typically be the WebSocket URL for a `netsim`
daemon.
There is also a [TOML index](index.toml) that can be used by tools to know at which URL to access
each of the apps and devices, as well as their names and short descriptions.
!!! tip "Using `netsim`"
When the `netsimd` daemon is running (for example when using the Android Emulator that
is included in Android Studio), the daemon listens for connections on a TCP port.
To find out what this TCP port is, you can read the `netsim.ini` file that `netsimd`
creates, it includes a line with `web.port=<tcp-port>` (for example `web.port=7681`).
The location of the `netsim.ini` file is platform-specific.
=== "macOS"
On macOS, the directory where `netsim.ini` is stored is $TMPDIR
```bash
$ cat $TMPDIR/netsim.ini
```
=== "Linux"
On Linux, the directory where `netsim.ini` is stored is $XDG_RUNTIME_DIR
```bash
$ cat $XDG_RUNTIME_DIR/netsim.ini
```
!!! tip "Using a local radio"
You can connect the hive virtual apps and devices to a local Bluetooth radio, like,
for example, a USB dongle.
For that, you need to run a local HCI bridge to bridge a local HCI device to a WebSocket
that a web page can connect to.
Use the `bumble-hci-bridge` app, with the host transport set to a WebSocket server on an
available port (ex: `ws-server:_:7682`) and the controller transport set to the transport
name for the radio you want to use (ex: `usb:0` for the first USB dongle)
Applications
------------
* [Scanner](web/scanner/scanner.html) - Scans for BLE devices.
Virtual Devices
---------------
* [Speaker](web/speaker/speaker.html) - Virtual speaker that plays audio in a browser page.
* [Heart Rate Monitor](web/heart_rate_monitor/heart_rate_monitor.html) - Virtual heart rate monitor.

View File

@@ -1,21 +0,0 @@
version = "1.0.0"
base_url = "https://google.github.io/bumble/hive/web"
default_hci_query_param = "hci"
[[index]]
name = "speaker"
description = "Bumble Virtual Speaker"
type = "Device"
url = "speaker/speaker.html"
[[index]]
name = "scanner"
description = "Simple Scanner Application"
type = "Application"
url = "scanner/scanner.html"
[[index]]
name = "heart-rate-monitor"
description = "Virtual Heart Rate Monitor"
type = "Device"
url = "heart_rate_monitor/heart_rate_monitor.html"

View File

@@ -1 +0,0 @@
../../../../../web/bumble.js

View File

@@ -1 +0,0 @@
../../../../../../web/heart_rate_monitor/heart_rate_monitor.html

View File

@@ -1 +0,0 @@
../../../../../../web/heart_rate_monitor/heart_rate_monitor.js

View File

@@ -1 +0,0 @@
../../../../../../web/heart_rate_monitor/heart_rate_monitor.py

View File

@@ -1 +0,0 @@
../../../../../../web/scanner/scanner.css

View File

@@ -1 +0,0 @@
../../../../../../web/scanner/scanner.html

View File

@@ -1 +0,0 @@
../../../../../../web/scanner/scanner.js

View File

@@ -1 +0,0 @@
../../../../../../web/scanner/scanner.py

View File

@@ -1 +0,0 @@
../../../../../../web/speaker/logo.svg

View File

@@ -1 +0,0 @@
../../../../../../web/speaker/speaker.css

View File

@@ -1 +0,0 @@
../../../../../../web/speaker/speaker.html

View File

@@ -1 +0,0 @@
../../../../../../web/speaker/speaker.js

View File

@@ -1 +0,0 @@
../../../../../../web/speaker/speaker.py

View File

@@ -1 +0,0 @@
../../../../../web/ui.js

View File

@@ -152,23 +152,11 @@ Some platforms support features that not all platforms support
See the [Platforms page](platforms/index.md) for details.
Hive
----
The Hive is a collection of example apps and virtual devices that are implemented using the
Python Bumble API, running entirely in a web page. This is a convenient way to try out some
of the examples without any Python installation, when you have some other virtual Bluetooth
device that you can connect to or from, such as the Android Emulator.
See the [Bumble Hive](hive/index.md) for details.
Roadmap
-------
Future features to be considered include:
* More profiles
* More device examples
* Add a new type of virtual link (beyond the two existing ones) to allow for link-level simulation (timing, loss, etc)
* Bindings for languages other than Python

View File

@@ -14,7 +14,7 @@ connections.
## Moniker
The moniker syntax for an Android Emulator "netsim" transport is: `android-netsim:[<host>:<port>][<options>]`,
where `<options>` is a comma-separated list of `<name>=<value>` pairs.
where `<options>` is a ','-separated list of `<name>=<value>` pairs`.
The `mode` parameter name can specify running as a host or a controller, and `<hostname>:<port>` can specify a host name (or IP address) and TCP port number on which to reach the gRPC server for the emulator (in "host" mode), or to accept gRPC connections (in "controller" mode).
Both the `mode=<host|controller>` and `<hostname>:<port>` parameters are optional (so the moniker `android-netsim` by itself is a valid moniker, which will create a transport in `host` mode, connected to `localhost` on the default gRPC port for the Netsim background process).

View File

@@ -10,7 +10,6 @@ The moniker for a USB transport is either:
* `usb:<vendor>:<product>`
* `usb:<vendor>:<product>/<serial-number>`
* `usb:<vendor>:<product>#<index>`
* `usb:<bus>-<port_numbers>`
with `<index>` as a 0-based index (0 being the first one) to select amongst all the matching devices when there are more than one.
In the `usb:<index>` form, matching devices are the ones supporting Bluetooth HCI, as declared by their Class, Subclass and Protocol.
@@ -18,8 +17,6 @@ In the `usb:<vendor>:<product>#<index>` form, matching devices are the ones with
`<vendor>` and `<product>` are a vendor ID and product ID in hexadecimal.
with `<port_numbers>` as a list of all port numbers from root separated with dots `.`
In addition, if the moniker ends with the symbol "!", the device will be used in "forced" mode:
the first USB interface of the device will be used, regardless of the interface class/subclass.
This may be useful for some devices that use a custom class/subclass but may nonetheless work as-is.
@@ -40,9 +37,6 @@ This may be useful for some devices that use a custom class/subclass but may non
`usb:0B05:17CB!`
The BT USB dongle vendor=0B05 and product=17CB, in "forced" mode.
`usb:3-3.4.1`
The BT USB dongle on bus 3 on port path 3, 4, 1.
## Alternative
The library includes two different implementations of the USB transport, implemented using different python bindings for `libusb`.

View File

@@ -1,274 +0,0 @@
<html>
<head>
<style>
* {
font-family: sans-serif;
}
</style>
</head>
<body>
Server Port <input id="port" type="text" value="8989"></input> <button id="connectButton" onclick="connect()">Connect</button><br>
<div id="socketState"></div>
<br>
<div id="buttons"></div><br>
<hr>
<button onclick="onGetPlayStatusButtonClicked()">Get Play Status</button><br>
<div id="getPlayStatusResponseTable"></div>
<hr>
<button onclick="onGetElementAttributesButtonClicked()">Get Element Attributes</button><br>
<div id="getElementAttributesResponseTable"></div>
<hr>
<table>
<tr>
<b>VOLUME</b>:
<button onclick="onVolumeDownButtonClicked()">-</button>
<button onclick="onVolumeUpButtonClicked()">+</button>&nbsp;
<span id="volumeText"></span><br>
</tr>
<tr>
<td><b>PLAYBACK STATUS</b></td><td><span id="playbackStatusText"></span></td>
</tr>
<tr>
<td><b>POSITION</b></td><td><span id="positionText"></span></td>
</tr>
<tr>
<td><b>TRACK</b></td><td><span id="trackText"></span></td>
</tr>
<tr>
<td><b>ADDRESSED PLAYER</b></td><td><span id="addressedPlayerText"></span></td>
</tr>
<tr>
<td><b>UID COUNTER</b></td><td><span id="uidCounterText"></span></td>
</tr>
<tr>
<td><b>SUPPORTED EVENTS</b></td><td><span id="supportedEventsText"></span></td>
</tr>
<tr>
<td><b>PLAYER SETTINGS</b></td><td><div id="playerSettingsTable"></div></td>
</tr>
</table>
<script>
const portInput = document.getElementById("port")
const connectButton = document.getElementById("connectButton")
const socketState = document.getElementById("socketState")
const volumeText = document.getElementById("volumeText")
const positionText = document.getElementById("positionText")
const trackText = document.getElementById("trackText")
const playbackStatusText = document.getElementById("playbackStatusText")
const addressedPlayerText = document.getElementById("addressedPlayerText")
const uidCounterText = document.getElementById("uidCounterText")
const supportedEventsText = document.getElementById("supportedEventsText")
const playerSettingsTable = document.getElementById("playerSettingsTable")
const getPlayStatusResponseTable = document.getElementById("getPlayStatusResponseTable")
const getElementAttributesResponseTable = document.getElementById("getElementAttributesResponseTable")
let socket
let volume = 0
const keyNames = [
"SELECT",
"UP",
"DOWN",
"LEFT",
"RIGHT",
"RIGHT_UP",
"RIGHT_DOWN",
"LEFT_UP",
"LEFT_DOWN",
"ROOT_MENU",
"SETUP_MENU",
"CONTENTS_MENU",
"FAVORITE_MENU",
"EXIT",
"NUMBER_0",
"NUMBER_1",
"NUMBER_2",
"NUMBER_3",
"NUMBER_4",
"NUMBER_5",
"NUMBER_6",
"NUMBER_7",
"NUMBER_8",
"NUMBER_9",
"DOT",
"ENTER",
"CLEAR",
"CHANNEL_UP",
"CHANNEL_DOWN",
"PREVIOUS_CHANNEL",
"SOUND_SELECT",
"INPUT_SELECT",
"DISPLAY_INFORMATION",
"HELP",
"PAGE_UP",
"PAGE_DOWN",
"POWER",
"VOLUME_UP",
"VOLUME_DOWN",
"MUTE",
"PLAY",
"STOP",
"PAUSE",
"RECORD",
"REWIND",
"FAST_FORWARD",
"EJECT",
"FORWARD",
"BACKWARD",
"ANGLE",
"SUBPICTURE",
"F1",
"F2",
"F3",
"F4",
"F5",
]
document.addEventListener('keydown', onKeyDown)
document.addEventListener('keyup', onKeyUp)
const buttons = document.getElementById("buttons")
keyNames.forEach(name => {
const button = document.createElement("BUTTON")
button.appendChild(document.createTextNode(name))
button.addEventListener("mousedown", event => {
send({type: 'send-key-down', key: name})
})
button.addEventListener("mouseup", event => {
send({type: 'send-key-up', key: name})
})
buttons.appendChild(button)
})
updateVolume(0)
function connect() {
socket = new WebSocket(`ws://localhost:${portInput.value}`);
socket.onopen = _ => {
socketState.innerText = 'OPEN'
connectButton.disabled = true
}
socket.onclose = _ => {
socketState.innerText = 'CLOSED'
connectButton.disabled = false
}
socket.onerror = (error) => {
socketState.innerText = 'ERROR'
console.log(`ERROR: ${error}`)
connectButton.disabled = false
}
socket.onmessage = (message) => {
onMessage(JSON.parse(message.data))
}
}
function send(message) {
if (socket && socket.readyState == WebSocket.OPEN) {
socket.send(JSON.stringify(message))
}
}
function hmsText(position) {
const h_1 = 1000 * 60 * 60
const h = Math.floor(position / h_1)
position -= h * h_1
const m_1 = 1000 * 60
const m = Math.floor(position / m_1)
position -= m * m_1
const s_1 = 1000
const s = Math.floor(position / s_1)
position -= s * s_1
return `${h}:${m.toString().padStart(2, "0")}:${s.toString().padStart(2, "0")}:${position}`
}
function setTableHead(table, columns) {
let thead = table.createTHead()
let row = thead.insertRow()
for (let column of columns) {
let th = document.createElement("th")
let text = document.createTextNode(column)
th.appendChild(text)
row.appendChild(th)
}
}
function createTable(rows) {
const table = document.createElement("table")
if (rows.length != 0) {
columns = Object.keys(rows[0])
setTableHead(table, columns)
}
for (let element of rows) {
let row = table.insertRow()
for (key in element) {
let cell = row.insertCell()
let text = document.createTextNode(element[key])
cell.appendChild(text)
}
}
return table
}
function onMessage(message) {
console.log(message)
if (message.type == "set-volume") {
updateVolume(message.params.volume)
} else if (message.type == "supported-events") {
supportedEventsText.innerText = JSON.stringify(message.params.events)
} else if (message.type == "playback-position-changed") {
positionText.innerText = hmsText(message.params.position)
} else if (message.type == "playback-status-changed") {
playbackStatusText.innerText = message.params.status
} else if (message.type == "player-settings-changed") {
playerSettingsTable.replaceChildren(message.params.settings)
} else if (message.type == "track-changed") {
trackText.innerText = message.params.identifier
} else if (message.type == "addressed-player-changed") {
addressedPlayerText.innerText = JSON.stringify(message.params.player)
} else if (message.type == "uids-changed") {
uidCounterText.innerText = message.params.uid_counter
} else if (message.type == "get-play-status-response") {
getPlayStatusResponseTable.replaceChildren(message.params)
} else if (message.type == "get-element-attributes-response") {
getElementAttributesResponseTable.replaceChildren(createTable(message.params))
}
}
function updateVolume(newVolume) {
volume = newVolume
volumeText.innerText = `${volume} (${Math.round(100*volume/0x7F)}%)`
}
function onKeyDown(event) {
console.log(event)
send({ type: 'send-key-down', key: event.key })
}
function onKeyUp(event) {
console.log(event)
send({ type: 'send-key-up', key: event.key })
}
function onVolumeUpButtonClicked() {
updateVolume(Math.min(volume + 5, 0x7F))
send({ type: 'set-volume', volume })
}
function onVolumeDownButtonClicked() {
updateVolume(Math.max(volume - 5, 0))
send({ type: 'set-volume', volume })
}
function onGetPlayStatusButtonClicked() {
send({ type: 'get-play-status', volume })
}
function onGetElementAttributesButtonClicked() {
send({ type: 'get-element-attributes' })
}
</script>
</body>
</html>

View File

@@ -29,7 +29,6 @@ from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.profiles.device_information_service import DeviceInformationService
from bumble.profiles.heart_rate_service import HeartRateService
from bumble.utils import AsyncRunner
# -----------------------------------------------------------------------------
@@ -99,17 +98,6 @@ async def main():
)
)
# Notify subscribers of the current value as soon as they subscribe
@heart_rate_service.heart_rate_measurement_characteristic.on('subscription')
def on_subscription(connection, notify_enabled, indicate_enabled):
if notify_enabled or indicate_enabled:
AsyncRunner.spawn(
device.notify_subscriber(
connection,
heart_rate_service.heart_rate_measurement_characteristic,
)
)
# Go!
await device.power_on()
await device.start_advertising(auto_restart=True)

View File

@@ -1,248 +0,0 @@
# shift map
# letters
shift_map = {
'a': 'A',
'b': 'B',
'c': 'C',
'd': 'D',
'e': 'E',
'f': 'F',
'g': 'G',
'h': 'H',
'i': 'I',
'j': 'J',
'k': 'K',
'l': 'L',
'm': 'M',
'n': 'N',
'o': 'O',
'p': 'P',
'q': 'Q',
'r': 'R',
's': 'S',
't': 'T',
'u': 'U',
'v': 'V',
'w': 'W',
'x': 'X',
'y': 'Y',
'z': 'Z',
# numbers
'1': '!',
'2': '@',
'3': '#',
'4': '$',
'5': '%',
'6': '^',
'7': '&',
'8': '*',
'9': '(',
'0': ')',
# symbols
'-': '_',
'=': '+',
'[': '{',
']': '}',
'\\': '|',
';': ':',
'\'': '"',
',': '<',
'.': '>',
'/': '?',
'`': '~',
}
# hex map
# modifier keys
mod_keys = {
'00': '',
'01': 'left_ctrl',
'02': 'left_shift',
'04': 'left_alt',
'08': 'left_meta',
'10': 'right_ctrl',
'20': 'right_shift',
'40': 'right_alt',
'80': 'right_meta',
}
# base keys
base_keys = {
# meta
'00': '', # none
'01': 'error_ovf',
# letters
'04': 'a',
'05': 'b',
'06': 'c',
'07': 'd',
'08': 'e',
'09': 'f',
'0a': 'g',
'0b': 'h',
'0c': 'i',
'0d': 'j',
'0e': 'k',
'0f': 'l',
'10': 'm',
'11': 'n',
'12': 'o',
'13': 'p',
'14': 'q',
'15': 'r',
'16': 's',
'17': 't',
'18': 'u',
'19': 'v',
'1a': 'w',
'1b': 'x',
'1c': 'y',
'1d': 'z',
# numbers
'1e': '1',
'1f': '2',
'20': '3',
'21': '4',
'22': '5',
'23': '6',
'24': '7',
'25': '8',
'26': '9',
'27': '0',
# misc
'28': 'enter', # enter \n
'29': 'esc',
'2a': 'backspace',
'2b': 'tab',
'2c': 'spacebar', # space
'2d': '-',
'2e': '=',
'2f': '[',
'30': ']',
'31': '\\',
'32': '=',
'33': '_SEMICOLON',
'34': 'KEY_APOSTROPHE',
'35': 'KEY_GRAVE',
'36': 'KEY_COMMA',
'37': 'KEY_DOT',
'38': 'KEY_SLASH',
'39': 'KEY_CAPSLOCK',
'3a': 'KEY_F1',
'3b': 'KEY_F2',
'3c': 'KEY_F3',
'3d': 'KEY_F4',
'3e': 'KEY_F5',
'3f': 'KEY_F6',
'40': 'KEY_F7',
'41': 'KEY_F8',
'42': 'KEY_F9',
'43': 'KEY_F10',
'44': 'KEY_F11',
'45': 'KEY_F12',
'46': 'KEY_SYSRQ',
'47': 'KEY_SCROLLLOCK',
'48': 'KEY_PAUSE',
'49': 'KEY_INSERT',
'4a': 'KEY_HOME',
'4b': 'KEY_PAGEUP',
'4c': 'KEY_DELETE',
'4d': 'KEY_END',
'4e': 'KEY_PAGEDOWN',
'4f': 'KEY_RIGHT',
'50': 'KEY_LEFT',
'51': 'KEY_DOWN',
'52': 'KEY_UP',
'53': 'KEY_NUMLOCK',
'54': 'KEY_KPSLASH',
'55': 'KEY_KPASTERISK',
'56': 'KEY_KPMINUS',
'57': 'KEY_KPPLUS',
'58': 'KEY_KPENTER',
'59': 'KEY_KP1',
'5a': 'KEY_KP2',
'5b': 'KEY_KP3',
'5c': 'KEY_KP4',
'5d': 'KEY_KP5',
'5e': 'KEY_KP6',
'5f': 'KEY_KP7',
'60': 'KEY_KP8',
'61': 'KEY_KP9',
'62': 'KEY_KP0',
'63': 'KEY_KPDOT',
'64': 'KEY_102ND',
'65': 'KEY_COMPOSE',
'66': 'KEY_POWER',
'67': 'KEY_KPEQUAL',
'68': 'KEY_F13',
'69': 'KEY_F14',
'6a': 'KEY_F15',
'6b': 'KEY_F16',
'6c': 'KEY_F17',
'6d': 'KEY_F18',
'6e': 'KEY_F19',
'6f': 'KEY_F20',
'70': 'KEY_F21',
'71': 'KEY_F22',
'72': 'KEY_F23',
'73': 'KEY_F24',
'74': 'KEY_OPEN',
'75': 'KEY_HELP',
'76': 'KEY_PROPS',
'77': 'KEY_FRONT',
'78': 'KEY_STOP',
'79': 'KEY_AGAIN',
'7a': 'KEY_UNDO',
'7b': 'KEY_CUT',
'7c': 'KEY_COPY',
'7d': 'KEY_PASTE',
'7e': 'KEY_FIND',
'7f': 'KEY_MUTE',
'80': 'KEY_VOLUMEUP',
'81': 'KEY_VOLUMEDOWN',
'85': 'KEY_KPCOMMA',
'87': 'KEY_RO',
'88': 'KEY_KATAKANAHIRAGANA',
'89': 'KEY_YEN',
'8a': 'KEY_HENKAN',
'8b': 'KEY_MUHENKAN',
'8c': 'KEY_KPJPCOMMA',
'90': 'KEY_HANGEUL',
'91': 'KEY_HANJA',
'92': 'KEY_KATAKANA',
'93': 'KEY_HIRAGANA',
'94': 'KEY_ZENKAKUHANKAKU',
'b6': 'KEY_KPLEFTPAREN',
'b7': 'KEY_KPRIGHTPAREN',
'e0': 'KEY_LEFTCTRL',
'e1': 'KEY_LEFTSHIFT',
'e2': 'KEY_LEFTALT',
'e3': 'KEY_LEFTMETA',
'e4': 'KEY_RIGHTCTRL',
'e5': 'KEY_RIGHTSHIFT',
'e6': 'KEY_RIGHTALT',
'e7': 'KEY_RIGHTMETA',
'e8': 'KEY_MEDIA_PLAYPAUSE',
'e9': 'KEY_MEDIA_STOPCD',
'ea': 'KEY_MEDIA_PREVIOUSSONG',
'eb': 'KEY_MEDIA_NEXTSONG',
'ec': 'KEY_MEDIA_EJECTCD',
'ed': 'KEY_MEDIA_VOLUMEUP',
'ee': 'KEY_MEDIA_VOLUMEDOWN',
'ef': 'KEY_MEDIA_MUTE',
'f0': 'KEY_MEDIA_WWW',
'f1': 'KEY_MEDIA_BACK',
'f2': 'KEY_MEDIA_FORWARD',
'f3': 'KEY_MEDIA_STOP',
'f4': 'KEY_MEDIA_FIND',
'f5': 'KEY_MEDIA_SCROLLUP',
'f6': 'KEY_MEDIA_SCROLLDOWN',
'f7': 'KEY_MEDIA_EDIT',
'f8': 'KEY_MEDIA_SLEEP',
'f9': 'KEY_MEDIA_COFFEE',
'fa': 'KEY_MEDIA_REFRESH',
'fb': 'KEY_MEDIA_CALC',
}

View File

@@ -1,5 +0,0 @@
{
"name": "Bumble HID Keyboard",
"class_of_device": 9664,
"keystore": "JsonKeyStore"
}

View File

@@ -1,159 +0,0 @@
from bumble.colors import color
from hid_key_map import base_keys, mod_keys, shift_map
# ------------------------------------------------------------------------------
def get_key(modifier: str, key: str) -> str:
if modifier == '22':
modifier = '02'
if modifier in mod_keys:
modifier = mod_keys[modifier]
else:
return ''
if key in base_keys:
key = base_keys[key]
else:
return ''
if (modifier == 'left_shift' or modifier == 'right_shift') and key in shift_map:
key = shift_map[key]
return key
class Keyboard:
def __init__(self): # type: ignore
self.report = [
[ # Bit array for Modifier keys
0, # Right GUI - (usually the Windows key)
0, # Right ALT
0, # Right Shift
0, # Right Control
0, # Left GUI - (usually the Windows key)
0, # Left ALT
0, # Left Shift
0, # Left Control
],
0x00, # Vendor reserved
'', # Rest is space for 6 keys
'',
'',
'',
'',
'',
]
def decode_keyboard_report(self, input_report: bytes, report_length: int) -> None:
if report_length >= 8:
modifier = input_report[1]
self.report[0] = [int(x) for x in '{0:08b}'.format(modifier)]
self.report[0].reverse() # type: ignore
modifier_key = str((modifier & 0x22).to_bytes(1, "big").hex())
keycodes = []
for k in range(3, report_length):
keycodes.append(str(input_report[k].to_bytes(1, "big").hex()))
self.report[k - 1] = get_key(modifier_key, keycodes[k - 3])
else:
print(color('Warning: Not able to parse report', 'yellow'))
def print_keyboard_report(self) -> None:
print(color('\tKeyboard Input Received', 'green', None, 'bold'))
print(color(f'Keys:', 'white', None, 'bold'))
for i in range(1, 7):
print(
color(f' Key{i}{" ":>8s}= ', 'cyan', None, 'bold'), self.report[i + 1]
)
print(color(f'\nModifier Keys:', 'white', None, 'bold'))
print(
color(f' Left Ctrl : ', 'cyan'),
f'{self.report[0][0] == 1!s:<5}', # type: ignore
color(f' Left Shift : ', 'cyan'),
f'{self.report[0][1] == 1!s:<5}', # type: ignore
color(f' Left ALT : ', 'cyan'),
f'{self.report[0][2] == 1!s:<5}', # type: ignore
color(f' Left GUI : ', 'cyan'),
f'{self.report[0][3] == 1!s:<5}\n', # type: ignore
color(f' Right Ctrl : ', 'cyan'),
f'{self.report[0][4] == 1!s:<5}', # type: ignore
color(f' Right Shift : ', 'cyan'),
f'{self.report[0][5] == 1!s:<5}', # type: ignore
color(f' Right ALT : ', 'cyan'),
f'{self.report[0][6] == 1!s:<5}', # type: ignore
color(f' Right GUI : ', 'cyan'),
f'{self.report[0][7] == 1!s:<5}', # type: ignore
)
# ------------------------------------------------------------------------------
class Mouse:
def __init__(self): # type: ignore
self.report = [
[ # Bit array for Buttons
0, # Button 1 (primary/trigger
0, # Button 2 (secondary)
0, # Button 3 (tertiary)
0, # Button 4
0, # Button 5
0, # unused padding bits
0, # unused padding bits
0, # unused padding bits
],
0, # X
0, # Y
0, # Wheel
0, # AC Pan
]
def decode_mouse_report(self, input_report: bytes, report_length: int) -> None:
self.report[0] = [int(x) for x in '{0:08b}'.format(input_report[1])]
self.report[0].reverse() # type: ignore
self.report[1] = input_report[2]
self.report[2] = input_report[3]
if report_length in [5, 6]:
self.report[3] = input_report[4]
self.report[4] = input_report[5] if report_length == 6 else 0
def print_mouse_report(self) -> None:
print(color('\tMouse Input Received', 'green', None, 'bold'))
print(
color(f' Button 1 (primary/trigger) = ', 'cyan'),
self.report[0][0] == 1, # type: ignore
color(f'\n Button 2 (secondary) = ', 'cyan'),
self.report[0][1] == 1, # type: ignore
color(f'\n Button 3 (tertiary) = ', 'cyan'),
self.report[0][2] == 1, # type: ignore
color(f'\n Button4 = ', 'cyan'),
self.report[0][3] == 1, # type: ignore
color(f'\n Button5 = ', 'cyan'),
self.report[0][4] == 1, # type: ignore
color(f'\n X (X-axis displacement) = ', 'cyan'),
self.report[1],
color(f'\n Y (Y-axis displacement) = ', 'cyan'),
self.report[2],
color(f'\n Wheel = ', 'cyan'),
self.report[3],
color(f'\n AC PAN = ', 'cyan'),
self.report[4],
)
# ------------------------------------------------------------------------------
class ReportParser:
@staticmethod
def parse_input_report(input_report: bytes) -> None:
report_id = input_report[0] # pylint: disable=unsubscriptable-object
report_length = len(input_report)
# Keyboard input report (report id = 1)
if report_id == 1 and report_length >= 8:
keyboard = Keyboard() # type: ignore
keyboard.decode_keyboard_report(input_report, report_length)
keyboard.print_keyboard_report()
# Mouse input report (report id = 2)
elif report_id == 2 and report_length in [4, 5, 6]:
mouse = Mouse() # type: ignore
mouse.decode_mouse_report(input_report, report_length)
mouse.print_mouse_report()
else:
print(color(f'Warning: Parse Error Report ID {report_id}', 'yellow'))

View File

@@ -40,9 +40,9 @@
}
}
function onMouseMove(event) {
//console.log(event.movementX, event.movementY)
mouseInfo.innerText = `MOUSE: x=${event.movementX}, y=${event.movementY}`
send({ type:'mousemove', x: event.movementX, y: event.movementY })
//console.log(event.clientX, event.clientY)
mouseInfo.innerText = `MOUSE: x=${event.clientX}, y=${event.clientY}`
send({ type:'mousemove', x: event.clientX, y: event.clientY })
}
function onKeyDown(event) {

View File

@@ -1,7 +0,0 @@
{
"name": "Bumble-LEA",
"keystore": "JsonKeyStore",
"address": "F0:F1:F2:F3:F4:FA",
"class_of_device": 2376708,
"advertising_interval": 100
}

Some files were not shown because too many files have changed in this diff Show More