Compare commits

..

83 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod 00cd8fbdd0 compatibility with recent host ACL property changes 2024-01-05 12:17:09 -08:00
Michael Mogenson c48e3f5e9c Merge pull request #393 from mogenson/controller-loopback
apps: Add a controller loopback throughput test app
2024-01-05 13:13:30 -05:00
Michael Mogenson d6bbc1145a apps: Add a controller loopback throughput test app
Add a command line utility to open a transport to a BT controller, put
the controller into local loopback mode, and send and receive ACL data
packets. Record the time it takes to send and receive all packets and
calculate a throughput measurement in kB/s.

This utility is usefull for characterizing the speed of a transport to a
BT controller (such as a TCP socket or serial port) without having to
deal with a connected peer or the variability of over the air
transmissions.

The transport CLI argument is required. The packet size and packet
count arguments are optional. They default to the same values as the
bumble-bench app.
2024-01-05 10:01:24 -05:00
zxzxwu e2fec67bd9 Merge pull request #390 from zxzxwu/csip
CSIP: Encrypted SIRK implementation
2024-01-04 13:28:23 +08:00
Josh Wu 88cb3b2a4d IWYU in CSIP 2024-01-04 13:22:09 +08:00
zxzxwu 9ebb03be46 Merge pull request #389 from zxzxwu/gitignore
.gitignore: Add venv directories
2024-01-04 12:54:30 +08:00
Gilles Boccon-Gibod 80d84af76c Merge pull request #392 from google/gbg/l2cap-drain
l2cap & rfcomm drain support
2024-01-03 09:59:36 -08:00
Gilles Boccon-Gibod 8f4721758f fix typo 2024-01-03 09:53:17 -08:00
Gilles Boccon-Gibod 8864af4acd format 2024-01-02 11:35:11 -08:00
Gilles Boccon-Gibod 8980fb8cc7 add drain support and a few tool options 2024-01-02 11:07:52 -08:00
Josh Wu 2c5f3472a9 CSIP: Encrypted SIRK implementation 2023-12-30 16:06:42 +08:00
Josh Wu f18277ac78 Ignore venv directories 2023-12-30 14:23:35 +08:00
Gilles Boccon-Gibod 09e5ea5dec Merge pull request #387 from google/gbg/async-gatt-server
support async read/write for characteristic values
2023-12-29 11:28:22 -08:00
Gilles Boccon-Gibod 6810865670 Merge pull request #385 from google/gbg/android-enable-dle
request MTU change after connection
2023-12-28 13:46:25 -08:00
Gilles Boccon-Gibod 3e9e06a02c Merge pull request #386 from AlanRosenthal/main
app/bench.py: use logging rather than print()
2023-12-28 13:42:17 -08:00
Alan Rosenthal ccd12f6591 app/bench.py: use logging rather than print() 2023-12-28 16:06:50 -05:00
Gilles Boccon-Gibod f9a7843f7e request MTU change after connection 2023-12-28 11:17:18 -08:00
Gilles Boccon-Gibod 210c334db7 Merge pull request #380 from google/gbg/classic-buffer-size
support per-transport ACL queues
2023-12-28 09:24:52 -08:00
Gilles Boccon-Gibod f297cdfcce Merge pull request #384 from eukub/string-concatination-to-fstring
сhanged concatenation of strings to f-strings to improve readability
2023-12-28 09:24:25 -08:00
eukub 5b536d00ab сhanged concatenation of strings to f-strings to improve readability and unify with the rest of code 2023-12-28 16:27:36 +03:00
Gilles Boccon-Gibod b4af46ebd5 use TCP_NODELAY on socket 2023-12-27 12:11:20 -08:00
Gilles Boccon-Gibod c08da3193e format 2023-12-27 11:56:06 -08:00
Gilles Boccon-Gibod f2925ca647 support async read/write for characteristic values 2023-12-27 11:52:22 -08:00
Gilles Boccon-Gibod fd4d68e5c0 print controller flow control info 2023-12-26 13:24:24 -08:00
Gilles Boccon-Gibod 5d83deffa4 Merge pull request #345 from rdhavan/bumble_hid_device
Bumble hid device implementation - Application and hid profile
2023-12-26 11:10:34 -08:00
Gilles Boccon-Gibod 2878cca478 Merge pull request #378 from benquike/pair_linger
Improve the linger option of bumble-pair
2023-12-26 10:55:28 -08:00
Gilles Boccon-Gibod 53934716db Merge pull request #377 from benquike/irk
Add functions/tool for gen/verifying BLE IRK/RPA
2023-12-26 10:54:18 -08:00
Hui Peng d885d45824 Add functions/tool for gen/verifying BLE IRK/RPA 2023-12-26 09:34:19 -08:00
Gilles Boccon-Gibod b90d0f8710 fix tests 2023-12-26 09:09:20 -08:00
zxzxwu 8ccfc90fe6 Merge pull request #379 from zxzxwu/addr
Add random address generation methods
2023-12-25 17:28:49 +08:00
Josh Wu 92aa7e9e2a Add random address generation methods 2023-12-24 18:07:40 +08:00
Gilles Boccon-Gibod afc6d19e04 address PR comments 2023-12-23 14:21:44 -08:00
Gilles Boccon-Gibod c05f073b33 Update bumble/host.py
Co-authored-by: zxzxwu <92432172+zxzxwu@users.noreply.github.com>
2023-12-23 14:15:53 -08:00
Gilles Boccon-Gibod 2b4c2a22f4 format 2023-12-22 14:22:08 -08:00
Gilles Boccon-Gibod 47fe93a148 support per-transport ACL queues 2023-12-22 13:52:33 -08:00
zxzxwu 6139ca8045 Merge pull request #374 from zxzxwu/csip
Complete CSIP and CAP
2023-12-23 02:49:35 +08:00
Josh Wu 87c76a4a0e Complete CSIP and CAP
Also add random address generation functions.
2023-12-23 02:14:32 +08:00
Hui Peng f7b66db873 Improve the linger option in pair tool
No matter pairing fails or not, make linger effective
2023-12-21 17:25:42 -08:00
skarnataki 0b314bd7f7 Updated absctract class and method for on_ctrl_pdu in hid.py 2023-12-18 13:36:25 +00:00
skarnataki 9da2e32ad7 Review comment Fix 3 - rename json file and usage of Optional in parameters 2023-12-15 09:42:57 +00:00
Snehal Karnataki 93c0875740 Merge branch 'google:main' into bumble_hid_device 2023-12-13 09:51:27 +00:00
Gilles Boccon-Gibod a286700239 Merge pull request #368 from google/gbg/driver-load-before-reset
support drivers that can't use reset directly.
2023-12-11 18:06:23 -08:00
Gilles Boccon-Gibod 98ed772e8a address PR comments and add some typing 2023-12-11 17:52:04 -08:00
Gilles Boccon-Gibod f0b55a4f97 Merge pull request #367 from google/gbg/android-bench-update
Android bench app: add support for 2M phy
2023-12-11 10:20:56 -08:00
zxzxwu b74503d345 Merge pull request #359 from zxzxwu/ascs
Audio Stream Control Service
2023-12-12 00:47:03 +08:00
Josh Wu f911163e49 Improve ASCS logging 2023-12-12 00:36:24 +08:00
Gilles Boccon-Gibod b083cc99ad fix spec parsing 2023-12-08 18:57:02 -08:00
Gilles Boccon-Gibod 62a8ced447 support drivers that can't use reset directly. 2023-12-08 17:28:57 -08:00
Josh Wu 81a6b1e097 Replace 3.9 dict merger 2023-12-08 11:10:17 +08:00
Josh Wu dd090c9e6b Add ASCS tests 2023-12-08 11:00:44 +08:00
Josh Wu 11faa48422 Fix ASE state change 2023-12-08 09:53:14 +08:00
Josh Wu 55596176c2 ffplay routing 2023-12-08 09:53:14 +08:00
Josh Wu 4d6822d312 Remove ISO data path on release 2023-12-08 09:53:14 +08:00
Josh Wu 985c365e6d Setup data path after CIS established 2023-12-08 09:53:14 +08:00
Josh Wu af57762227 Parse CodecSpecificConfiguration 2023-12-08 09:53:14 +08:00
Josh Wu 3575f9030e Add Audio Stream Control Service 2023-12-08 09:53:14 +08:00
zxzxwu 698d947d85 Merge pull request #366 from zxzxwu/extadv
Add advertiser classes and handle adv set terminated events
2023-12-08 09:52:42 +08:00
Josh Wu ff6528d2bf Add Advertising unit tests 2023-12-08 01:38:01 +08:00
Josh Wu 72ac75a98d Add advertiser classes and handle adv set terminated events
* Convert hci.OwnAddressType to enum
* Add LegacyAdvertiser and ExtendedAdvertiser classes
* Rename start/stop_advertising() => start/stop_legacy_advertising()
* Handle HCI_Advertising_Set_Terminated
* Properly restart advertisement on disconnection
2023-12-07 15:51:51 +08:00
skarnataki 5e3ecb74e4 Review comment fix -2 2023-12-05 13:41:30 +00:00
Snehal Karnataki c59be293c8 Merge branch 'google:main' into bumble_hid_device 2023-12-05 13:07:36 +00:00
Snehal Karnataki 6d22ed80ec Merge branch 'google:main' into bumble_hid_device 2023-12-04 07:29:04 +00:00
Snehal Karnataki ffb3eca68b Merge branch 'google:main' into bumble_hid_device 2023-11-30 04:50:05 +00:00
skarnataki 403a13e4c6 Review comment fix HID device 2023-11-28 13:42:25 +00:00
Snehal Karnataki ad0f035df5 Merge branch 'google:main' into bumble_hid_device 2023-11-28 13:06:32 +00:00
skarnataki 07f71fc895 Project format and lint error fix. Redefination if Device class needs to be discussed 2023-11-27 13:04:54 +00:00
Fahad Afroze f47b9178ad Added GET_REPORT and SET_REPORT changes
Added changes to handle invalid cases
2023-11-27 11:55:35 +00:00
SneKarnataki 4f399249bd Merge branch 'google:main' into bumble_hid_device 2023-11-27 09:00:44 +00:00
skarnataki 9324237828 send_data comment fix and lint error fix 2023-11-24 11:13:20 +00:00
Fahad Afroze d1033c018a Modified DeviceData class 2023-11-24 05:42:31 +00:00
Fahad Afroze 0f29052ade Added mousemove changes
Also modified keyboard data on keyup
2023-11-23 17:46:55 +00:00
skarnataki 0578e84586 Menu and name change review comments fix 2023-11-23 15:43:22 +00:00
Fahad Afroze 6ab41c466f Add review comment changes 3 2023-11-23 12:27:56 +00:00
Fahad Afroze 98a1093ebf Add review comment changes 2
Also corrected sending mouseData
2023-11-23 09:53:16 +00:00
dhavan caf04373f3 keyboard data moved to DeviceData class 2023-11-23 08:01:07 +00:00
SneKarnataki d4e8526766 Merge branch 'google:main' into bumble_hid_device 2023-11-23 07:59:43 +00:00
dhavan 515b83a8c7 deleted: bumble/classic3.json
modified:   examples/keyboard.html
2023-11-23 06:10:52 +00:00
dhavan dc18595c8a MTU size check added 2023-11-23 05:17:44 +00:00
SneKarnataki 488bcfe9c6 Merge branch 'google:main' into bumble_hid_device 2023-11-23 04:03:53 +00:00
dhavan d6cefdff8e Renamed the status message class 2023-11-22 17:14:24 +00:00
dhavan dc410b14c4 SET_REPORT and GET_REPORT implemented 2023-11-22 16:05:33 +00:00
dhavan 4c49ef9403 SET_REPORT implemented 2023-11-22 12:31:34 +00:00
dhavan ba85dcbda5 Get the changes from hid_device to bumble_hid_device
Modified the get_report_cb
2023-11-22 11:06:27 +00:00
53 changed files with 4503 additions and 734 deletions
+2
View File
@@ -10,3 +10,5 @@ __pycache__
bumble/_version.py bumble/_version.py
.vscode/launch.json .vscode/launch.json
/.idea /.idea
venv/
.venv/
+1
View File
@@ -22,6 +22,7 @@
"cmac", "cmac",
"CONNECTIONLESS", "CONNECTIONLESS",
"csip", "csip",
"csis",
"csrcs", "csrcs",
"CVSD", "CVSD",
"datagram", "datagram",
+394 -131
View File
File diff suppressed because it is too large Load Diff
+63
View File
@@ -0,0 +1,63 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import click
from bumble.colors import color
from bumble.hci import Address
from bumble.helpers import generate_irk, verify_rpa_with_irk
@click.group()
def cli():
'''
This is a tool for generating IRK, RPA,
and verifying IRK/RPA pairs
'''
@click.command()
def gen_irk() -> None:
print(generate_irk().hex())
@click.command()
@click.argument("irk", type=str)
def gen_rpa(irk: str) -> None:
irk_bytes = bytes.fromhex(irk)
rpa = Address.generate_private_address(irk_bytes)
print(rpa.to_string(with_type_qualifier=False))
@click.command()
@click.argument("irk", type=str)
@click.argument("rpa", type=str)
def verify_rpa(irk: str, rpa: str) -> None:
address = Address(rpa)
irk_bytes = bytes.fromhex(irk)
if verify_rpa_with_irk(address, irk_bytes):
print(color("Verified", "green"))
else:
print(color("Not Verified", "red"))
def main():
cli.add_command(gen_irk)
cli.add_command(gen_rpa)
cli.add_command(verify_rpa)
cli()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()
+4 -4
View File
@@ -777,7 +777,7 @@ class ConsoleApp:
if not service: if not service:
continue continue
values = [ values = [
attribute.read_value(connection) await attribute.read_value(connection)
for connection in self.device.connections.values() for connection in self.device.connections.values()
] ]
if not values: if not values:
@@ -796,11 +796,11 @@ class ConsoleApp:
if not characteristic: if not characteristic:
continue continue
values = [ values = [
attribute.read_value(connection) await attribute.read_value(connection)
for connection in self.device.connections.values() for connection in self.device.connections.values()
] ]
if not values: if not values:
values = [attribute.read_value(None)] values = [await attribute.read_value(None)]
# TODO: future optimization: convert CCCD value to human readable string # TODO: future optimization: convert CCCD value to human readable string
@@ -944,7 +944,7 @@ class ConsoleApp:
# send data to any subscribers # send data to any subscribers
if isinstance(attribute, Characteristic): if isinstance(attribute, Characteristic):
attribute.write_value(None, value) await attribute.write_value(None, value)
if attribute.has_properties(Characteristic.NOTIFY): if attribute.has_properties(Characteristic.NOTIFY):
await self.device.gatt_server.notify_subscribers(attribute) await self.device.gatt_server.notify_subscribers(attribute)
if attribute.has_properties(Characteristic.INDICATE): if attribute.has_properties(Characteristic.INDICATE):
+34 -2
View File
@@ -32,10 +32,14 @@ from bumble.hci import (
HCI_Command, HCI_Command,
HCI_Command_Complete_Event, HCI_Command_Complete_Event,
HCI_Command_Status_Event, HCI_Command_Status_Event,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_Read_Buffer_Size_Command,
HCI_READ_BD_ADDR_COMMAND, HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command, HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND, HCI_READ_LOCAL_NAME_COMMAND,
HCI_Read_Local_Name_Command, HCI_Read_Local_Name_Command,
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND, HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Data_Length_Command, HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND, HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
@@ -59,7 +63,7 @@ def command_succeeded(response):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_classic_info(host): async def get_classic_info(host: Host) -> None:
if host.supports_command(HCI_READ_BD_ADDR_COMMAND): if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command()) response = await host.send_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response): if command_succeeded(response):
@@ -80,7 +84,7 @@ async def get_classic_info(host):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_le_info(host): async def get_le_info(host: Host) -> None:
print() print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND): if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
@@ -136,6 +140,31 @@ async def get_le_info(host):
print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature)) print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
# -----------------------------------------------------------------------------
async def get_acl_flow_control_info(host: Host) -> None:
print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
print(
color('ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_acl_data_packets} '
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
)
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_main(transport): async def async_main(transport):
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
@@ -168,6 +197,9 @@ async def async_main(transport):
# Get the LE info # Get the LE info
await get_le_info(host) await get_le_info(host)
# Print the ACL flow control info
await get_acl_flow_control_info(host)
# Print the list of commands supported by the controller # Print the list of commands supported by the controller
print() print()
print(color('Supported Commands:', 'yellow')) print(color('Supported Commands:', 'yellow'))
+200
View File
@@ -0,0 +1,200 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import time
from typing import Optional
from bumble.colors import color
from bumble.hci import (
HCI_READ_LOOPBACK_MODE_COMMAND,
HCI_Read_Loopback_Mode_Command,
HCI_WRITE_LOOPBACK_MODE_COMMAND,
HCI_Write_Loopback_Mode_Command,
LoopbackMode,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
import click
class Loopback:
"""Send and receive ACL data packets in local loopback mode"""
def __init__(self, packet_size: int, packet_count: int, transport: str):
self.transport = transport
self.packet_size = packet_size
self.packet_count = packet_count
self.connection_handle: Optional[int] = None
self.connection_event = asyncio.Event()
self.done = asyncio.Event()
self.expected_cid = 0
self.bytes_received = 0
self.start_timestamp = 0.0
self.last_timestamp = 0.0
def on_connection(self, connection_handle: int, *args):
"""Retrieve connection handle from new connection event"""
if not self.connection_event.is_set():
# save first connection handle for ACL
# subsequent connections are SCO
self.connection_handle = connection_handle
self.connection_event.set()
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
"""Calculate packet receive speed"""
now = time.time()
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
assert connection_handle == self.connection_handle
assert cid == self.expected_cid
self.expected_cid += 1
if cid == 0:
self.start_timestamp = now
else:
elapsed_since_start = now - self.start_timestamp
elapsed_since_last = now - self.last_timestamp
self.bytes_received += len(pdu)
instant_rx_speed = len(pdu) / elapsed_since_last
average_rx_speed = self.bytes_received / elapsed_since_start
print(
color(
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f' average={average_rx_speed:.4f}',
'cyan',
)
)
self.last_timestamp = now
if self.expected_cid == self.packet_count:
print(color('@@@ Received last packet', 'green'))
self.done.set()
async def run(self):
"""Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green'))
async with await open_transport_or_link(self.transport) as (
hci_source,
hci_sink,
):
print(color('>>> Connected', 'green'))
host = Host(hci_source, hci_sink)
await host.reset()
# make sure data can fit in one l2cap pdu
l2cap_header_size = 4
max_packet_size = host.acl_packet_queue.max_packet_size - l2cap_header_size
if self.packet_size > max_packet_size:
print(
color(
f'!!! Packet size ({self.packet_size}) larger than max supported'
f' size ({max_packet_size})',
'red',
)
)
return
if not host.supports_command(
HCI_WRITE_LOOPBACK_MODE_COMMAND
) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
print(color('!!! Loopback mode not supported', 'red'))
return
# set event callbacks
host.on('connection', self.on_connection)
host.on('l2cap_pdu', self.on_l2cap_pdu)
loopback_mode = LoopbackMode.LOCAL
print(color('### Setting loopback mode', 'blue'))
await host.send_command(
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
check_result=True,
)
print(color('### Checking loopback mode', 'blue'))
response = await host.send_command(
HCI_Read_Loopback_Mode_Command(), check_result=True
)
if response.return_parameters.loopback_mode != loopback_mode:
print(color('!!! Loopback mode mismatch', 'red'))
return
await self.connection_event.wait()
print(color('### Connected', 'cyan'))
print(color('=== Start sending', 'magenta'))
start_time = time.time()
bytes_sent = 0
for cid in range(0, self.packet_count):
# using the cid as an incremental index
host.send_l2cap_pdu(
self.connection_handle, cid, bytes(self.packet_size)
)
print(
color(
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
)
)
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
await asyncio.sleep(0) # yield to allow packet receive
await self.done.wait()
print(color('=== Done!', 'magenta'))
elapsed = time.time() - start_time
average_tx_speed = bytes_sent / elapsed
print(
color(
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
f' in {elapsed:.2f} seconds)',
'green',
)
)
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--packet-size',
'-s',
metavar='SIZE',
type=click.IntRange(8, 4096),
default=500,
help='Packet size',
)
@click.option(
'--packet-count',
'-c',
metavar='COUNT',
type=int,
default=10,
help='Packet count',
)
@click.argument('transport')
def main(packet_size, packet_count, transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
loopback = Loopback(packet_size, packet_count, transport)
asyncio.run(loopback.run())
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()
+32 -24
View File
@@ -49,14 +49,16 @@ class ServerBridge:
self.tcp_port = tcp_port self.tcp_port = tcp_port
async def start(self, device: Device) -> None: async def start(self, device: Device) -> None:
# Listen for incoming L2CAP CoC connections # Listen for incoming L2CAP channel connections
device.create_l2cap_server( device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec( spec=l2cap.LeCreditBasedChannelSpec(
psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits
), ),
handler=self.on_coc, handler=self.on_channel,
)
print(
color(f'### Listening for channel connection on PSM {self.psm}', 'yellow')
) )
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
def on_ble_connection(connection): def on_ble_connection(connection):
def on_ble_disconnection(reason): def on_ble_disconnection(reason):
@@ -73,7 +75,7 @@ class ServerBridge:
await device.start_advertising(auto_restart=True) await device.start_advertising(auto_restart=True)
# Called when a new L2CAP connection is established # Called when a new L2CAP connection is established
def on_coc(self, l2cap_channel): def on_channel(self, l2cap_channel):
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel) print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
class Pipe: class Pipe:
@@ -83,7 +85,7 @@ class ServerBridge:
self.l2cap_channel = l2cap_channel self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close) l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_coc_sdu l2cap_channel.sink = self.on_channel_sdu
async def connect_to_tcp(self): async def connect_to_tcp(self):
# Connect to the TCP server # Connect to the TCP server
@@ -128,7 +130,7 @@ class ServerBridge:
if self.tcp_transport is not None: if self.tcp_transport is not None:
self.tcp_transport.close() self.tcp_transport.close()
def on_coc_sdu(self, sdu): def on_channel_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan')) print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
if self.tcp_transport is None: if self.tcp_transport is None:
print(color('!!! TCP socket not open, dropping', 'red')) print(color('!!! TCP socket not open, dropping', 'red'))
@@ -183,7 +185,7 @@ class ClientBridge:
peer_name = writer.get_extra_info('peer_name') peer_name = writer.get_extra_info('peer_name')
print(color(f'<<< TCP connection from {peer_name}', 'magenta')) print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
def on_coc_sdu(sdu): def on_channel_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan')) print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
l2cap_to_tcp_pipe.write(sdu) l2cap_to_tcp_pipe.write(sdu)
@@ -209,7 +211,7 @@ class ClientBridge:
writer.close() writer.close()
return return
l2cap_channel.sink = on_coc_sdu l2cap_channel.sink = on_channel_sdu
l2cap_channel.on('close', on_l2cap_close) l2cap_channel.on('close', on_l2cap_close)
# Start a flow control pipe from L2CAP to TCP # Start a flow control pipe from L2CAP to TCP
@@ -274,23 +276,29 @@ async def run(device_config, hci_transport, bridge):
@click.pass_context @click.pass_context
@click.option('--device-config', help='Device configuration file', required=True) @click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True) @click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234) @click.option('--psm', help='PSM for L2CAP', type=int, default=1234)
@click.option( @click.option(
'--l2cap-coc-max-credits', '--l2cap-max-credits',
help='Maximum L2CAP CoC Credits', help='Maximum L2CAP Credits',
type=click.IntRange(1, 65535), type=click.IntRange(1, 65535),
default=128, default=128,
) )
@click.option( @click.option(
'--l2cap-coc-mtu', '--l2cap-mtu',
help='L2CAP CoC MTU', help='L2CAP MTU',
type=click.IntRange(23, 65535), type=click.IntRange(
default=1022, l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU,
),
default=1024,
) )
@click.option( @click.option(
'--l2cap-coc-mps', '--l2cap-mps',
help='L2CAP CoC MPS', help='L2CAP MPS',
type=click.IntRange(23, 65533), type=click.IntRange(
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS,
),
default=1024, default=1024,
) )
def cli( def cli(
@@ -298,17 +306,17 @@ def cli(
device_config, device_config,
hci_transport, hci_transport,
psm, psm,
l2cap_coc_max_credits, l2cap_max_credits,
l2cap_coc_mtu, l2cap_mtu,
l2cap_coc_mps, l2cap_mps,
): ):
context.ensure_object(dict) context.ensure_object(dict)
context.obj['device_config'] = device_config context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport context.obj['hci_transport'] = hci_transport
context.obj['psm'] = psm context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_coc_max_credits context.obj['max_credits'] = l2cap_max_credits
context.obj['mtu'] = l2cap_coc_mtu context.obj['mtu'] = l2cap_mtu
context.obj['mps'] = l2cap_coc_mps context.obj['mps'] = l2cap_mps
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
+6 -8
View File
@@ -52,11 +52,13 @@ from bumble.att import (
class Waiter: class Waiter:
instance = None instance = None
def __init__(self): def __init__(self, linger=False):
self.done = asyncio.get_running_loop().create_future() self.done = asyncio.get_running_loop().create_future()
self.linger = linger
def terminate(self): def terminate(self):
self.done.set_result(None) if not self.linger:
self.done.set_result(None)
async def wait_until_terminated(self): async def wait_until_terminated(self):
return await self.done return await self.done
@@ -302,7 +304,7 @@ async def pair(
hci_transport, hci_transport,
address_or_name, address_or_name,
): ):
Waiter.instance = Waiter() Waiter.instance = Waiter(linger=linger)
print('<<< connecting to HCI...') print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink): async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
@@ -396,7 +398,6 @@ async def pair(
address_or_name, address_or_name,
transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT, transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
) )
pairing_failure = False
if not request: if not request:
try: try:
@@ -405,11 +406,8 @@ async def pair(
else: else:
await connection.authenticate() await connection.authenticate()
except ProtocolError as error: except ProtocolError as error:
pairing_failure = True
print(color(f'Pairing failed: {error}', 'red')) print(color(f'Pairing failed: {error}', 'red'))
if not linger or pairing_failure:
return
else: else:
if mode == 'le': if mode == 'le':
# Advertise so that peers can find us and connect # Advertise so that peers can find us and connect
@@ -459,7 +457,7 @@ class LogHandler(logging.Handler):
help='Enable CTKD', help='Enable CTKD',
show_default=True, show_default=True,
) )
@click.option('--linger', default=True, is_flag=True, help='Linger after pairing') @click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
@click.option( @click.option(
'--io', '--io',
type=click.Choice( type=click.Choice(
+53 -11
View File
@@ -25,9 +25,21 @@
from __future__ import annotations from __future__ import annotations
import enum import enum
import functools import functools
import inspect
import struct import struct
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
Union,
TYPE_CHECKING,
)
from pyee import EventEmitter from pyee import EventEmitter
from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
from bumble.core import UUID, name_or_number, ProtocolError from bumble.core import UUID, name_or_number, ProtocolError
from bumble.hci import HCI_Object, key_with_value from bumble.hci import HCI_Object, key_with_value
@@ -722,12 +734,38 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class ConnectionValue(Protocol): class AttributeValue:
def read(self, connection) -> bytes: '''
... Attribute value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def write(self, connection, value: bytes) -> None: def __init__(
... self,
read: Union[
Callable[[Optional[Connection]], bytes],
Callable[[Optional[Connection]], Awaitable[bytes]],
None,
] = None,
write: Union[
Callable[[Optional[Connection], bytes], None],
Callable[[Optional[Connection], bytes], Awaitable[None]],
None,
] = None,
):
self._read = read
self._write = write
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
return self._read(connection) if self._read else b''
def write(
self, connection: Optional[Connection], value: bytes
) -> Union[Awaitable[None], None]:
if self._write:
return self._write(connection, value)
return None
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -770,13 +808,13 @@ class Attribute(EventEmitter):
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
value: Union[str, bytes, ConnectionValue] value: Union[bytes, AttributeValue]
def __init__( def __init__(
self, self,
attribute_type: Union[str, bytes, UUID], attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions], permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, ConnectionValue] = b'', value: Union[str, bytes, AttributeValue] = b'',
) -> None: ) -> None:
EventEmitter.__init__(self) EventEmitter.__init__(self)
self.handle = 0 self.handle = 0
@@ -806,7 +844,7 @@ class Attribute(EventEmitter):
def decode_value(self, value_bytes: bytes) -> Any: def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes return value_bytes
def read_value(self, connection: Optional[Connection]) -> bytes: async def read_value(self, connection: Optional[Connection]) -> bytes:
if ( if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION) (self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None and connection is not None
@@ -832,6 +870,8 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'read'): if hasattr(self.value, 'read'):
try: try:
value = self.value.read(connection) value = self.value.read(connection)
if inspect.isawaitable(value):
value = await value
except ATT_Error as error: except ATT_Error as error:
raise ATT_Error( raise ATT_Error(
error_code=error.error_code, att_handle=self.handle error_code=error.error_code, att_handle=self.handle
@@ -841,7 +881,7 @@ class Attribute(EventEmitter):
return self.encode_value(value) return self.encode_value(value)
def write_value(self, connection: Connection, value_bytes: bytes) -> None: async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if ( if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption: ) and not connection.encryption:
@@ -864,7 +904,9 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'write'): if hasattr(self.value, 'write'):
try: try:
self.value.write(connection, value) # pylint: disable=not-callable result = self.value.write(connection, value)
if inspect.isawaitable(result):
await result
except ATT_Error as error: except ATT_Error as error:
raise ATT_Error( raise ATT_Error(
error_code=error.error_code, att_handle=self.handle error_code=error.error_code, att_handle=self.handle
+28 -1
View File
@@ -134,12 +134,14 @@ class Controller:
'0000000060000000' '0000000060000000'
) # BR/EDR Not Supported, LE Supported (Controller) ) # BR/EDR Not Supported, LE Supported (Controller)
self.manufacturer_name = 0xFFFF self.manufacturer_name = 0xFFFF
self.hc_data_packet_length = 27
self.hc_total_num_data_packets = 64
self.hc_le_data_packet_length = 27 self.hc_le_data_packet_length = 27
self.hc_total_num_le_data_packets = 64 self.hc_total_num_le_data_packets = 64
self.event_mask = 0 self.event_mask = 0
self.event_mask_page_2 = 0 self.event_mask_page_2 = 0
self.supported_commands = bytes.fromhex( self.supported_commands = bytes.fromhex(
'2000800000c000000000e40000002822000000000000040000f7ffff7f000000' '2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
'30f0f9ff01008004000000000000000000000000000000000000000000000000' '30f0f9ff01008004000000000000000000000000000000000000000000000000'
) )
self.le_event_mask = 0 self.le_event_mask = 0
@@ -914,6 +916,19 @@ class Controller:
''' '''
return bytes([HCI_SUCCESS]) + self.lmp_features return bytes([HCI_SUCCESS]) + self.lmp_features
def on_hci_read_buffer_size_command(self, _command):
'''
See Bluetooth spec Vol 4, Part E - 7.4.5 Read Buffer Size Command
'''
return struct.pack(
'<BHBHH',
HCI_SUCCESS,
self.hc_data_packet_length,
0,
self.hc_total_num_data_packets,
0,
)
def on_hci_read_bd_addr_command(self, _command): def on_hci_read_bd_addr_command(self, _command):
''' '''
See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command
@@ -1263,3 +1278,15 @@ class Controller:
See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command
''' '''
return struct.pack('<BBB', HCI_SUCCESS, 0, 0) return struct.pack('<BBB', HCI_SUCCESS, 0, 0)
def on_hci_le_setup_iso_data_path_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.109 LE Setup ISO Data Path Command
'''
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
def on_hci_le_remove_iso_data_path_command(self, command):
'''
See Bluetooth spec Vol 4, Part E - 7.8.110 LE Remove ISO Data Path Command
'''
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
+10
View File
@@ -100,6 +100,16 @@ class EccKey:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
def generate_prand() -> bytes:
'''Generates random 3 bytes, with the 2 most significant bits of 0b01.
See Bluetooth spec, Vol 6, Part E - Table 1.2.
'''
prand_bytes = secrets.token_bytes(6)
return prand_bytes[:2] + bytes([(prand_bytes[2] & 0b01111111) | 0b01000000])
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def xor(x: bytes, y: bytes) -> bytes: def xor(x: bytes, y: bytes) -> bytes:
assert len(x) == len(y) assert len(x) == len(y)
+180 -46
View File
@@ -437,6 +437,38 @@ class AdvertisingType(IntEnum):
) )
# -----------------------------------------------------------------------------
@dataclass
class LegacyAdvertiser:
device: Device
advertising_type: AdvertisingType
own_address_type: OwnAddressType
auto_restart: bool
advertising_data: Optional[bytes]
scan_response_data: Optional[bytes]
async def stop(self) -> None:
await self.device.stop_legacy_advertising()
# -----------------------------------------------------------------------------
@dataclass
class ExtendedAdvertiser(CompositeEventEmitter):
device: Device
handle: int
advertising_properties: HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties
own_address_type: OwnAddressType
auto_restart: bool
advertising_data: Optional[bytes]
scan_response_data: Optional[bytes]
def __post_init__(self) -> None:
super().__init__()
async def stop(self) -> None:
await self.device.stop_extended_advertising(self.handle)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class LePhyOptions: class LePhyOptions:
# Coded PHY preference # Coded PHY preference
@@ -658,6 +690,9 @@ class Connection(CompositeEventEmitter):
gatt_client: gatt_client.Client gatt_client: gatt_client.Client
pairing_peer_io_capability: Optional[int] pairing_peer_io_capability: Optional[int]
pairing_peer_authentication_requirements: Optional[int] pairing_peer_authentication_requirements: Optional[int]
advertiser_after_disconnection: Union[
LegacyAdvertiser, ExtendedAdvertiser, None
] = None
@composite_listener @composite_listener
class Listener: class Listener:
@@ -1063,7 +1098,8 @@ class Device(CompositeEventEmitter):
] ]
advertisement_accumulators: Dict[Address, AdvertisementDataAccumulator] advertisement_accumulators: Dict[Address, AdvertisementDataAccumulator]
config: DeviceConfiguration config: DeviceConfiguration
extended_advertising_handles: Set[int] legacy_advertiser: Optional[LegacyAdvertiser]
extended_advertisers: Dict[int, ExtendedAdvertiser]
sco_links: Dict[int, ScoLink] sco_links: Dict[int, ScoLink]
cis_links: Dict[int, CisLink] cis_links: Dict[int, CisLink]
_pending_cis: Dict[int, Tuple[int, int]] _pending_cis: Dict[int, Tuple[int, int]]
@@ -1141,10 +1177,7 @@ class Device(CompositeEventEmitter):
self._host = None self._host = None
self.powered_on = False self.powered_on = False
self.advertising = False
self.advertising_type = None
self.auto_restart_inquiry = True self.auto_restart_inquiry = True
self.auto_restart_advertising = False
self.command_timeout = 10 # seconds self.command_timeout = 10 # seconds
self.gatt_server = gatt_server.Server(self) self.gatt_server = gatt_server.Server(self)
self.sdp_server = sdp.Server(self) self.sdp_server = sdp.Server(self)
@@ -1168,10 +1201,10 @@ class Device(CompositeEventEmitter):
self.classic_pending_accepts = { self.classic_pending_accepts = {
Address.ANY: [] Address.ANY: []
} # Futures, by BD address OR [Futures] for Address.ANY } # Futures, by BD address OR [Futures] for Address.ANY
self.extended_advertising_handles = set() self.legacy_advertiser = None
self.extended_advertisers = {}
# Own address type cache # Own address type cache
self.advertising_own_address_type = None
self.connect_own_address_type = None self.connect_own_address_type = None
# Use the initial config or a default # Use the initial config or a default
@@ -1579,6 +1612,7 @@ class Device(CompositeEventEmitter):
return self.host.supports_le_feature(feature_map[phy]) return self.host.supports_le_feature(feature_map[phy])
@deprecated("Please use start_legacy_advertising.")
async def start_advertising( async def start_advertising(
self, self,
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE, advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
@@ -1586,15 +1620,49 @@ class Device(CompositeEventEmitter):
own_address_type: int = OwnAddressType.RANDOM, own_address_type: int = OwnAddressType.RANDOM,
auto_restart: bool = False, auto_restart: bool = False,
) -> None: ) -> None:
await self.start_legacy_advertising(
advertising_type=advertising_type,
target=target,
own_address_type=OwnAddressType(own_address_type),
auto_restart=auto_restart,
)
async def start_legacy_advertising(
self,
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
target: Optional[Address] = None,
own_address_type: OwnAddressType = OwnAddressType.RANDOM,
auto_restart: bool = False,
advertising_data: Optional[bytes] = None,
scan_response_data: Optional[bytes] = None,
) -> LegacyAdvertiser:
"""Starts an legacy advertisement.
Args:
advertising_type: Advertising type passed to HCI_LE_Set_Advertising_Parameters_Command.
target: Directed advertising target. Directed type should be set in advertising_type arg.
own_address_type: own address type to use in the advertising.
auto_restart: whether the advertisement will be restarted after disconnection.
scan_response_data: raw scan response.
advertising_data: raw advertising data.
Returns:
LegacyAdvertiser object containing the metadata of advertisement.
"""
if self.extended_advertisers:
logger.warning(
'Trying to start Legacy and Extended Advertising at the same time!'
)
# If we're advertising, stop first # If we're advertising, stop first
if self.advertising: if self.legacy_advertiser:
await self.stop_advertising() await self.stop_advertising()
# Set/update the advertising data if the advertising type allows it # Set/update the advertising data if the advertising type allows it
if advertising_type.has_data: if advertising_type.has_data:
await self.send_command( await self.send_command(
HCI_LE_Set_Advertising_Data_Command( HCI_LE_Set_Advertising_Data_Command(
advertising_data=self.advertising_data advertising_data=advertising_data or self.advertising_data or b''
), ),
check_result=True, check_result=True,
) )
@@ -1603,7 +1671,9 @@ class Device(CompositeEventEmitter):
if advertising_type.is_scannable: if advertising_type.is_scannable:
await self.send_command( await self.send_command(
HCI_LE_Set_Scan_Response_Data_Command( HCI_LE_Set_Scan_Response_Data_Command(
scan_response_data=self.scan_response_data scan_response_data=scan_response_data
or self.scan_response_data
or b''
), ),
check_result=True, check_result=True,
) )
@@ -1640,45 +1710,57 @@ class Device(CompositeEventEmitter):
check_result=True, check_result=True,
) )
self.advertising_type = advertising_type self.legacy_advertiser = LegacyAdvertiser(
self.advertising_own_address_type = own_address_type device=self,
self.advertising = True advertising_type=advertising_type,
self.auto_restart_advertising = auto_restart own_address_type=own_address_type,
auto_restart=auto_restart,
advertising_data=advertising_data,
scan_response_data=scan_response_data,
)
return self.legacy_advertiser
@deprecated("Please use stop_legacy_advertising.")
async def stop_advertising(self) -> None: async def stop_advertising(self) -> None:
await self.stop_legacy_advertising()
async def stop_legacy_advertising(self) -> None:
# Disable advertising # Disable advertising
if self.advertising: if self.legacy_advertiser:
await self.send_command( await self.send_command(
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0), HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0),
check_result=True, check_result=True,
) )
self.advertising_type = None self.legacy_advertiser = None
self.advertising_own_address_type = None
self.advertising = False
self.auto_restart_advertising = False
@experimental('Extended Advertising is still experimental - Might be changed soon.') @experimental('Extended Advertising is still experimental - Might be changed soon.')
async def start_extended_advertising( async def start_extended_advertising(
self, self,
advertising_properties: HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties = HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING, advertising_properties: HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties = HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING,
target: Address = Address.ANY, target: Address = Address.ANY,
own_address_type: int = OwnAddressType.RANDOM, own_address_type: OwnAddressType = OwnAddressType.RANDOM,
scan_response: Optional[bytes] = None, auto_restart: bool = True,
advertising_data: Optional[bytes] = None, advertising_data: Optional[bytes] = None,
) -> int: scan_response_data: Optional[bytes] = None,
) -> ExtendedAdvertiser:
"""Starts an extended advertising set. """Starts an extended advertising set.
Args: Args:
advertising_properties: Properties to pass in HCI_LE_Set_Extended_Advertising_Parameters_Command advertising_properties: Properties to pass in HCI_LE_Set_Extended_Advertising_Parameters_Command
target: Directed advertising target. Directed property should be set in advertising_properties arg. target: Directed advertising target. Directed property should be set in advertising_properties arg.
own_address_type: own address type to use in the advertising. own_address_type: own address type to use in the advertising.
scan_response: raw scan response. When a non-none value is set, HCI_LE_Set_Extended_Scan_Response_Data_Command will be sent. auto_restart: whether the advertisement will be restarted after disconnection.
advertising_data: raw advertising data. When a non-none value is set, HCI_LE_Set_Advertising_Set_Random_Address_Command will be sent. advertising_data: raw advertising data. When a non-none value is set, HCI_LE_Set_Advertising_Set_Random_Address_Command will be sent.
scan_response_data: raw scan response. When a non-none value is set, HCI_LE_Set_Extended_Scan_Response_Data_Command will be sent.
Returns: Returns:
Handle of the new advertising set. ExtendedAdvertiser object containing the metadata of advertisement.
""" """
if self.legacy_advertiser:
logger.warning(
'Trying to start Legacy and Extended Advertising at the same time!'
)
adv_handle = -1 adv_handle = -1
# Find a free handle # Find a free handle
@@ -1686,7 +1768,7 @@ class Device(CompositeEventEmitter):
DEVICE_MIN_EXTENDED_ADVERTISING_SET_HANDLE, DEVICE_MIN_EXTENDED_ADVERTISING_SET_HANDLE,
DEVICE_MAX_EXTENDED_ADVERTISING_SET_HANDLE + 1, DEVICE_MAX_EXTENDED_ADVERTISING_SET_HANDLE + 1,
): ):
if i not in self.extended_advertising_handles: if i not in self.extended_advertisers:
adv_handle = i adv_handle = i
break break
@@ -1733,13 +1815,13 @@ class Device(CompositeEventEmitter):
) )
# Set the scan response if present # Set the scan response if present
if scan_response is not None: if scan_response_data is not None:
await self.send_command( await self.send_command(
HCI_LE_Set_Extended_Scan_Response_Data_Command( HCI_LE_Set_Extended_Scan_Response_Data_Command(
advertising_handle=adv_handle, advertising_handle=adv_handle,
operation=HCI_LE_Set_Extended_Advertising_Data_Command.Operation.COMPLETE_DATA, operation=HCI_LE_Set_Extended_Advertising_Data_Command.Operation.COMPLETE_DATA,
fragment_preference=0x01, # Should not fragment fragment_preference=0x01, # Should not fragment
scan_response_data=scan_response, scan_response_data=scan_response_data,
), ),
check_result=True, check_result=True,
) )
@@ -1774,8 +1856,16 @@ class Device(CompositeEventEmitter):
) )
raise error raise error
self.extended_advertising_handles.add(adv_handle) advertiser = self.extended_advertisers[adv_handle] = ExtendedAdvertiser(
return adv_handle device=self,
handle=adv_handle,
advertising_properties=advertising_properties,
own_address_type=own_address_type,
auto_restart=auto_restart,
advertising_data=advertising_data,
scan_response_data=scan_response_data,
)
return advertiser
@experimental('Extended Advertising is still experimental - Might be changed soon.') @experimental('Extended Advertising is still experimental - Might be changed soon.')
async def stop_extended_advertising(self, adv_handle: int) -> None: async def stop_extended_advertising(self, adv_handle: int) -> None:
@@ -1799,11 +1889,11 @@ class Device(CompositeEventEmitter):
HCI_LE_Remove_Advertising_Set_Command(advertising_handle=adv_handle), HCI_LE_Remove_Advertising_Set_Command(advertising_handle=adv_handle),
check_result=True, check_result=True,
) )
self.extended_advertising_handles.remove(adv_handle) del self.extended_advertisers[adv_handle]
@property @property
def is_advertising(self): def is_advertising(self):
return self.advertising return self.legacy_advertiser or self.extended_advertisers
async def start_scanning( async def start_scanning(
self, self,
@@ -3144,13 +3234,18 @@ class Device(CompositeEventEmitter):
# Guess which own address type is used for this connection. # Guess which own address type is used for this connection.
# This logic is somewhat correct but may need to be improved # This logic is somewhat correct but may need to be improved
# when multiple advertising are run simultaneously. # when multiple advertising are run simultaneously.
advertiser = None
if self.connect_own_address_type is not None: if self.connect_own_address_type is not None:
own_address_type = self.connect_own_address_type own_address_type = self.connect_own_address_type
elif self.legacy_advertiser:
own_address_type = self.legacy_advertiser.own_address_type
# Store advertiser for restarting - it's only required for legacy, since
# extended advertisement produces HCI_Advertising_Set_Terminated.
if self.legacy_advertiser.auto_restart:
advertiser = self.legacy_advertiser
else: else:
own_address_type = self.advertising_own_address_type # For extended advertisement, determining own address type later.
own_address_type = OwnAddressType.RANDOM
# We are no longer advertising
self.advertising = False
if own_address_type in ( if own_address_type in (
OwnAddressType.PUBLIC, OwnAddressType.PUBLIC,
@@ -3172,6 +3267,7 @@ class Device(CompositeEventEmitter):
connection_parameters, connection_parameters,
ConnectionPHY(HCI_LE_1M_PHY, HCI_LE_1M_PHY), ConnectionPHY(HCI_LE_1M_PHY, HCI_LE_1M_PHY),
) )
connection.advertiser_after_disconnection = advertiser
self.connections[connection_handle] = connection self.connections[connection_handle] = connection
# If supported, read which PHY we're connected with before # If supported, read which PHY we're connected with before
@@ -3203,10 +3299,10 @@ class Device(CompositeEventEmitter):
# For directed advertising, this means a timeout # For directed advertising, this means a timeout
if ( if (
transport == BT_LE_TRANSPORT transport == BT_LE_TRANSPORT
and self.advertising and self.legacy_advertiser
and self.advertising_type.is_directed and self.legacy_advertiser.advertising_type.is_directed
): ):
self.advertising = False self.legacy_advertiser = None
# Notify listeners # Notify listeners
error = core.ConnectionError( error = core.ConnectionError(
@@ -3268,16 +3364,30 @@ class Device(CompositeEventEmitter):
self.gatt_server.on_disconnection(connection) self.gatt_server.on_disconnection(connection)
# Restart advertising if auto-restart is enabled # Restart advertising if auto-restart is enabled
if self.auto_restart_advertising: if advertiser := connection.advertiser_after_disconnection:
logger.debug('restarting advertising') logger.debug('restarting advertising')
self.abort_on( if isinstance(advertiser, LegacyAdvertiser):
'flush', self.abort_on(
self.start_advertising( 'flush',
advertising_type=self.advertising_type, # type: ignore[arg-type] self.start_legacy_advertising(
own_address_type=self.advertising_own_address_type, # type: ignore[arg-type] advertising_type=advertiser.advertising_type,
auto_restart=True, own_address_type=advertiser.own_address_type,
), advertising_data=advertiser.advertising_data,
) scan_response_data=advertiser.scan_response_data,
auto_restart=True,
),
)
elif isinstance(advertiser, ExtendedAdvertiser):
self.abort_on(
'flush',
self.start_extended_advertising(
advertising_properties=advertiser.advertising_properties,
own_address_type=advertiser.own_address_type,
advertising_data=advertiser.advertising_data,
scan_response_data=advertiser.scan_response_data,
auto_restart=True,
),
)
elif sco_link := self.sco_links.pop(connection_handle, None): elif sco_link := self.sco_links.pop(connection_handle, None):
sco_link.emit('disconnection', reason) sco_link.emit('disconnection', reason)
elif cis_link := self.cis_links.pop(connection_handle, None): elif cis_link := self.cis_links.pop(connection_handle, None):
@@ -3600,6 +3710,30 @@ class Device(CompositeEventEmitter):
if sco_link := self.sco_links.get(sco_handle, None): if sco_link := self.sco_links.get(sco_handle, None):
sco_link.emit('pdu', packet) sco_link.emit('pdu', packet)
# [LE only]
@host_event_handler
@experimental('Only for testing')
def on_advertising_set_termination(
self,
status: int,
advertising_handle: int,
connection_handle: int,
) -> None:
if status == HCI_SUCCESS:
connection = self.lookup_connection(connection_handle)
if advertiser := self.extended_advertisers.pop(advertising_handle, None):
if connection:
if advertiser.auto_restart:
connection.advertiser_after_disconnection = advertiser
if advertiser.own_address_type in (
OwnAddressType.PUBLIC,
OwnAddressType.RESOLVABLE_OR_PUBLIC,
):
connection.self_address = self.public_address
else:
connection.self_address = self.random_address
advertiser.emit('termination', status)
# [LE only] # [LE only]
@host_event_handler @host_event_handler
@with_connection_from_handle @with_connection_from_handle
+28 -32
View File
@@ -19,12 +19,17 @@ like loading firmware after a cold start.
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import abc from __future__ import annotations
import logging import logging
import pathlib import pathlib
import platform import platform
from . import rtk from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING
from . import rtk
from .common import Driver
if TYPE_CHECKING:
from bumble.host import Host
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -32,40 +37,31 @@ from . import rtk
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Driver(abc.ABC):
"""Base class for drivers."""
@staticmethod
async def for_host(_host):
"""Return a driver instance for a host.
Args:
host: Host object for which a driver should be created.
Returns:
A Driver instance if a driver should be instantiated for this host, or
None if no driver instance of this class is needed.
"""
return None
@abc.abstractmethod
async def init_controller(self):
"""Initialize the controller."""
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Functions # Functions
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def get_driver_for_host(host): async def get_driver_for_host(host: Host) -> Optional[Driver]:
"""Probe all known diver classes until one returns a valid instance for a host, """Probe diver classes until one returns a valid instance for a host, or none is
or none is found. found.
If a "driver" HCI metadata entry is present, only that driver class will be probed.
""" """
if driver := await rtk.Driver.for_host(host): driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver}
logger.debug("Instantiated RTK driver") probe_list: Iterable[str]
return driver if driver_name := host.hci_metadata.get("driver"):
# Only probe a single driver
probe_list = [driver_name]
else:
# Probe all drivers
probe_list = driver_classes.keys()
for driver_name in probe_list:
if driver_class := driver_classes.get(driver_name):
logger.debug(f"Probing driver class: {driver_name}")
if driver := await driver_class.for_host(host):
logger.debug(f"Instantiated {driver_name} driver")
return driver
else:
logger.debug(f"Skipping unknown driver class: {driver_name}")
return None return None
+45
View File
@@ -0,0 +1,45 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Common types for drivers.
"""
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import abc
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
class Driver(abc.ABC):
"""Base class for drivers."""
@staticmethod
async def for_host(_host):
"""Return a driver instance for a host.
Args:
host: Host object for which a driver should be created.
Returns:
A Driver instance if a driver should be instantiated for this host, or
None if no driver instance of this class is needed.
"""
return None
@abc.abstractmethod
async def init_controller(self):
"""Initialize the controller."""
+11 -4
View File
@@ -41,7 +41,7 @@ from bumble.hci import (
HCI_Reset_Command, HCI_Reset_Command,
HCI_Read_Local_Version_Information_Command, HCI_Read_Local_Version_Information_Command,
) )
from bumble.drivers import common
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -285,7 +285,7 @@ class Firmware:
) )
class Driver: class Driver(common.Driver):
@dataclass @dataclass
class DriverInfo: class DriverInfo:
rom: int rom: int
@@ -470,8 +470,12 @@ class Driver:
logger.debug("USB metadata not found") logger.debug("USB metadata not found")
return False return False
vendor_id = host.hci_metadata.get("vendor_id", None) if host.hci_metadata.get('driver') == 'rtk':
product_id = host.hci_metadata.get("product_id", None) # Forced driver
return True
vendor_id = host.hci_metadata.get("vendor_id")
product_id = host.hci_metadata.get("product_id")
if vendor_id is None or product_id is None: if vendor_id is None or product_id is None:
logger.debug("USB metadata not sufficient") logger.debug("USB metadata not sufficient")
return False return False
@@ -486,6 +490,9 @@ class Driver:
@classmethod @classmethod
async def driver_info_for_host(cls, host): async def driver_info_for_host(cls, host):
await host.send_command(HCI_Reset_Command(), check_result=True)
host.ready = True # Needed to let the host know the controller is ready.
response = await host.send_command( response = await host.send_command(
HCI_Read_Local_Version_Information_Command(), check_result=True HCI_Read_Local_Version_Information_Command(), check_result=True
) )
+66 -51
View File
@@ -23,16 +23,28 @@
# Imports # Imports
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
from __future__ import annotations from __future__ import annotations
import asyncio
import enum import enum
import functools import functools
import logging import logging
import struct import struct
from typing import Optional, Sequence, Iterable, List, Union from typing import (
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Union,
TYPE_CHECKING,
)
from .colors import color from bumble.colors import color
from .core import UUID, get_dict_key_by_value from bumble.core import UUID
from .att import Attribute from bumble.att import Attribute, AttributeValue
if TYPE_CHECKING:
from bumble.gatt_client import AttributeProxy
from bumble.device import Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -368,9 +380,12 @@ class TemplateService(Service):
UUID: UUID UUID: UUID
def __init__( def __init__(
self, characteristics: List[Characteristic], primary: bool = True self,
characteristics: List[Characteristic],
primary: bool = True,
included_services: List[Service] = [],
) -> None: ) -> None:
super().__init__(self.UUID, characteristics, primary) super().__init__(self.UUID, characteristics, primary, included_services)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -519,56 +534,43 @@ class CharacteristicDeclaration(Attribute):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CharacteristicValue: class CharacteristicValue(AttributeValue):
''' """Same as AttributeValue, for backward compatibility"""
Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(self, read=None, write=None):
self._read = read
self._write = write
def read(self, connection):
return self._read(connection) if self._read else b''
def write(self, connection, value):
if self._write:
self._write(connection, value)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class CharacteristicAdapter: class CharacteristicAdapter:
''' '''
An adapter that can adapt any object with `read_value` and `write_value` An adapter that can adapt Characteristic and AttributeProxy objects
methods (like Characteristic and CharacteristicProxy objects) by wrapping by wrapping their `read_value()` and `write_value()` methods with ones that
those methods with ones that return/accept encoded/decoded values. return/accept encoded/decoded values.
Objects with async methods are considered proxies, so the adaptation is one
where the return value of `read_value` is decoded and the value passed to For proxies (i.e used by a GATT client), the adaptation is one where the return
`write_value` is encoded. Other objects are considered local characteristics value of `read_value()` is decoded and the value passed to `write_value()` is
so the adaptation is one where the return value of `read_value` is encoded encoded. The `subscribe()` method, is wrapped with one where the values are decoded
and the value passed to `write_value` is decoded. before being passed to the subscriber.
If the characteristic has a `subscribe` method, it is wrapped with one where
the values are decoded before being passed to the subscriber. For local values (i.e hosted by a GATT server) the adaptation is one where the
return value of `read_value()` is encoded and the value passed to `write_value()`
is decoded.
''' '''
def __init__(self, characteristic): read_value: Callable
self.wrapped_characteristic = characteristic write_value: Callable
self.subscribers = {} # Map from subscriber to proxy subscriber
if asyncio.iscoroutinefunction( def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
characteristic.read_value self.wrapped_characteristic = characteristic
) and asyncio.iscoroutinefunction(characteristic.write_value): self.subscribers: Dict[
self.read_value = self.read_decoded_value Callable, Callable
self.write_value = self.write_decoded_value ] = {} # Map from subscriber to proxy subscriber
else:
if isinstance(characteristic, Characteristic):
self.read_value = self.read_encoded_value self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value self.write_value = self.write_encoded_value
else:
if hasattr(self.wrapped_characteristic, 'subscribe'): self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
self.subscribe = self.wrapped_subscribe self.subscribe = self.wrapped_subscribe
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
self.unsubscribe = self.wrapped_unsubscribe self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name): def __getattr__(self, name):
@@ -587,11 +589,13 @@ class CharacteristicAdapter:
else: else:
setattr(self.wrapped_characteristic, name, value) setattr(self.wrapped_characteristic, name, value)
def read_encoded_value(self, connection): async def read_encoded_value(self, connection):
return self.encode_value(self.wrapped_characteristic.read_value(connection)) return self.encode_value(
await self.wrapped_characteristic.read_value(connection)
)
def write_encoded_value(self, connection, value): async def write_encoded_value(self, connection, value):
return self.wrapped_characteristic.write_value( return await self.wrapped_characteristic.write_value(
connection, self.decode_value(value) connection, self.decode_value(value)
) )
@@ -726,13 +730,24 @@ class Descriptor(Attribute):
''' '''
def __str__(self) -> str: def __str__(self) -> str:
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
value = self.value.read(None)
if isinstance(value, bytes):
value_str = value.hex()
else:
value_str = '<async>'
else:
value_str = '<...>'
return ( return (
f'Descriptor(handle=0x{self.handle:04X}, ' f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, ' f'type={self.type}, '
f'value={self.read_value(None).hex()})' f'value={value_str})'
) )
# -----------------------------------------------------------------------------
class ClientCharacteristicConfigurationBits(enum.IntFlag): class ClientCharacteristicConfigurationBits(enum.IntFlag):
''' '''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit
+30 -22
View File
@@ -31,9 +31,9 @@ import struct
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
from pyee import EventEmitter from pyee import EventEmitter
from .colors import color from bumble.colors import color
from .core import UUID from bumble.core import UUID
from .att import ( from bumble.att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR, ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR, ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID, ATT_CID,
@@ -60,7 +60,7 @@ from .att import (
ATT_Write_Response, ATT_Write_Response,
Attribute, Attribute,
) )
from .gatt import ( from bumble.gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE, GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR, GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_MAX_ATTRIBUTE_VALUE_SIZE, GATT_MAX_ATTRIBUTE_VALUE_SIZE,
@@ -74,6 +74,7 @@ from .gatt import (
Descriptor, Descriptor,
Service, Service,
) )
from bumble.utils import AsyncRunner
if TYPE_CHECKING: if TYPE_CHECKING:
from bumble.device import Device, Connection from bumble.device import Device, Connection
@@ -379,7 +380,7 @@ class Server(EventEmitter):
# Get or encode the value # Get or encode the value
value = ( value = (
attribute.read_value(connection) await attribute.read_value(connection)
if value is None if value is None
else attribute.encode_value(value) else attribute.encode_value(value)
) )
@@ -422,7 +423,7 @@ class Server(EventEmitter):
# Get or encode the value # Get or encode the value
value = ( value = (
attribute.read_value(connection) await attribute.read_value(connection)
if value is None if value is None
else attribute.encode_value(value) else attribute.encode_value(value)
) )
@@ -650,7 +651,8 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
def on_att_find_by_type_value_request(self, connection, request): @AsyncRunner.run_in_task()
async def on_att_find_by_type_value_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
''' '''
@@ -658,13 +660,13 @@ class Server(EventEmitter):
# Build list of returned attributes # Build list of returned attributes
pdu_space_available = connection.att_mtu - 2 pdu_space_available = connection.att_mtu - 2
attributes = [] attributes = []
for attribute in ( async for attribute in (
attribute attribute
for attribute in self.attributes for attribute in self.attributes
if attribute.handle >= request.starting_handle if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type and attribute.type == request.attribute_type
and attribute.read_value(connection) == request.attribute_value and (await attribute.read_value(connection)) == request.attribute_value
and pdu_space_available >= 4 and pdu_space_available >= 4
): ):
# TODO: check permissions # TODO: check permissions
@@ -702,7 +704,8 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
def on_att_read_by_type_request(self, connection, request): @AsyncRunner.run_in_task()
async def on_att_read_by_type_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
''' '''
@@ -725,7 +728,7 @@ class Server(EventEmitter):
and pdu_space_available and pdu_space_available
): ):
try: try:
attribute_value = attribute.read_value(connection) attribute_value = await attribute.read_value(connection)
except ATT_Error as error: except ATT_Error as error:
# If the first attribute is unreadable, return an error # If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point # Otherwise return attributes up to this point
@@ -767,14 +770,15 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
def on_att_read_request(self, connection, request): @AsyncRunner.run_in_task()
async def on_att_read_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
''' '''
if attribute := self.get_attribute(request.attribute_handle): if attribute := self.get_attribute(request.attribute_handle):
try: try:
value = attribute.read_value(connection) value = await attribute.read_value(connection)
except ATT_Error as error: except ATT_Error as error:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
@@ -792,14 +796,15 @@ class Server(EventEmitter):
) )
self.send_response(connection, response) self.send_response(connection, response)
def on_att_read_blob_request(self, connection, request): @AsyncRunner.run_in_task()
async def on_att_read_blob_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
''' '''
if attribute := self.get_attribute(request.attribute_handle): if attribute := self.get_attribute(request.attribute_handle):
try: try:
value = attribute.read_value(connection) value = await attribute.read_value(connection)
except ATT_Error as error: except ATT_Error as error:
response = ATT_Error_Response( response = ATT_Error_Response(
request_opcode_in_error=request.op_code, request_opcode_in_error=request.op_code,
@@ -836,7 +841,8 @@ class Server(EventEmitter):
) )
self.send_response(connection, response) self.send_response(connection, response)
def on_att_read_by_group_type_request(self, connection, request): @AsyncRunner.run_in_task()
async def on_att_read_by_group_type_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
''' '''
@@ -864,7 +870,7 @@ class Server(EventEmitter):
): ):
# No need to catch permission errors here, since these attributes # No need to catch permission errors here, since these attributes
# must all be world-readable # must all be world-readable
attribute_value = attribute.read_value(connection) attribute_value = await attribute.read_value(connection)
# Check the attribute value size # Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251) max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size: if len(attribute_value) > max_attribute_size:
@@ -903,7 +909,8 @@ class Server(EventEmitter):
self.send_response(connection, response) self.send_response(connection, response)
def on_att_write_request(self, connection, request): @AsyncRunner.run_in_task()
async def on_att_write_request(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
''' '''
@@ -936,12 +943,13 @@ class Server(EventEmitter):
return return
# Accept the value # Accept the value
attribute.write_value(connection, request.attribute_value) await attribute.write_value(connection, request.attribute_value)
# Done # Done
self.send_response(connection, ATT_Write_Response()) self.send_response(connection, ATT_Write_Response())
def on_att_write_command(self, connection, request): @AsyncRunner.run_in_task()
async def on_att_write_command(self, connection, request):
''' '''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
''' '''
@@ -959,9 +967,9 @@ class Server(EventEmitter):
# Accept the value # Accept the value
try: try:
attribute.write_value(connection, request.attribute_value) await attribute.write_value(connection, request.attribute_value)
except Exception as error: except Exception as error:
logger.warning(f'!!! ignoring exception: {error}') logger.exception(f'!!! ignoring exception: {error}')
def on_att_handle_value_confirmation(self, connection, _confirmation): def on_att_handle_value_confirmation(self, connection, _confirmation):
''' '''
+110 -17
View File
@@ -21,9 +21,11 @@ import dataclasses
import enum import enum
import functools import functools
import logging import logging
import secrets
import struct import struct
from typing import Any, Dict, Callable, Optional, Type, Union, List from typing import Any, Dict, Callable, Optional, Type, Union, List
from bumble import crypto
from .colors import color from .colors import color
from .core import ( from .core import (
BT_BR_EDR_TRANSPORT, BT_BR_EDR_TRANSPORT,
@@ -728,6 +730,19 @@ HCI_LE_PHY_TYPE_TO_BIT = {
HCI_LE_CODED_PHY: HCI_LE_CODED_PHY_BIT HCI_LE_CODED_PHY: HCI_LE_CODED_PHY_BIT
} }
class Phy(enum.IntEnum):
LE_1M = 0x01
LE_2M = 0x02
LE_CODED = 0x03
class PhyBit(enum.IntFlag):
LE_1M = 0b00000001
LE_2M = 0b00000010
LE_CODED = 0b00000100
# Connection Parameters # Connection Parameters
HCI_CONNECTION_INTERVAL_MS_PER_UNIT = 1.25 HCI_CONNECTION_INTERVAL_MS_PER_UNIT = 1.25
HCI_CONNECTION_LATENCY_MS_PER_UNIT = 1.25 HCI_CONNECTION_LATENCY_MS_PER_UNIT = 1.25
@@ -1868,6 +1883,43 @@ class Address:
address_type = data[offset - 1] address_type = data[offset - 1]
return Address.parse_address_with_type(data, offset, address_type) return Address.parse_address_with_type(data, offset, address_type)
@classmethod
def generate_static_address(cls) -> Address:
'''Generates Random Static Address, with the 2 most significant bits of 0b11.
See Bluetooth spec, Vol 6, Part B - Table 1.2.
'''
address_bytes = secrets.token_bytes(6)
address_bytes = address_bytes[:5] + bytes([address_bytes[5] | 0b11000000])
return Address(
address=address_bytes, address_type=Address.RANDOM_DEVICE_ADDRESS
)
@classmethod
def generate_private_address(cls, irk: bytes = b'') -> Address:
'''Generates Random Private MAC Address.
If IRK is present, a Resolvable Private Address, with the 2 most significant
bits of 0b01 will be generated. Otherwise, a Non-resolvable Private Address,
with the 2 most significant bits of 0b00 will be generated.
See Bluetooth spec, Vol 6, Part B - Table 1.2.
Args:
irk: Local Identity Resolving Key(IRK), in little-endian. If not set, a
non-resolvable address will be generated.
'''
if irk:
prand = crypto.generate_prand()
address_bytes = crypto.ah(irk, prand) + prand
else:
address_bytes = secrets.token_bytes(6)
address_bytes = address_bytes[:5] + bytes([address_bytes[5] & 0b00111111])
return Address(
address=address_bytes, address_type=Address.RANDOM_DEVICE_ADDRESS
)
def __init__( def __init__(
self, address: Union[bytes, str], address_type: int = RANDOM_DEVICE_ADDRESS self, address: Union[bytes, str], address_type: int = RANDOM_DEVICE_ADDRESS
): ):
@@ -1963,25 +2015,26 @@ Address.ANY_RANDOM = Address(b"\x00\x00\x00\x00\x00\x00", Address.RANDOM_DEVICE_
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class OwnAddressType: class OwnAddressType(enum.IntEnum):
PUBLIC = 0 PUBLIC = 0
RANDOM = 1 RANDOM = 1
RESOLVABLE_OR_PUBLIC = 2 RESOLVABLE_OR_PUBLIC = 2
RESOLVABLE_OR_RANDOM = 3 RESOLVABLE_OR_RANDOM = 3
TYPE_NAMES = { @classmethod
PUBLIC: 'PUBLIC', def type_spec(cls):
RANDOM: 'RANDOM', return {'size': 1, 'mapper': lambda x: OwnAddressType(x).name}
RESOLVABLE_OR_PUBLIC: 'RESOLVABLE_OR_PUBLIC',
RESOLVABLE_OR_RANDOM: 'RESOLVABLE_OR_RANDOM',
}
@staticmethod
def type_name(type_id):
return name_or_number(OwnAddressType.TYPE_NAMES, type_id)
# pylint: disable-next=unnecessary-lambda # -----------------------------------------------------------------------------
TYPE_SPEC = {'size': 1, 'mapper': lambda x: OwnAddressType.type_name(x)} class LoopbackMode(enum.IntEnum):
DISABLED = 0
LOCAL = 1
REMOTE = 2
@classmethod
def type_spec(cls):
return {'size': 1, 'mapper': lambda x: LoopbackMode(x).name}
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -3310,6 +3363,27 @@ class HCI_Read_Encryption_Key_Size_Command(HCI_Command):
''' '''
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('loopback_mode', LoopbackMode.type_spec()),
],
)
class HCI_Read_Loopback_Mode_Command(HCI_Command):
'''
See Bluetooth spec @ 7.6.1 Read Loopback Mode Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command([('loopback_mode', 1)])
class HCI_Write_Loopback_Mode_Command(HCI_Command):
'''
See Bluetooth spec @ 7.6.2 Write Loopback Mode Command
'''
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@HCI_Command.command([('le_event_mask', 8)]) @HCI_Command.command([('le_event_mask', 8)])
class HCI_LE_Set_Event_Mask_Command(HCI_Command): class HCI_LE_Set_Event_Mask_Command(HCI_Command):
@@ -3374,7 +3448,7 @@ class HCI_LE_Set_Random_Address_Command(HCI_Command):
), ),
}, },
), ),
('own_address_type', OwnAddressType.TYPE_SPEC), ('own_address_type', OwnAddressType.type_spec()),
('peer_address_type', Address.ADDRESS_TYPE_SPEC), ('peer_address_type', Address.ADDRESS_TYPE_SPEC),
('peer_address', Address.parse_address_preceded_by_type), ('peer_address', Address.parse_address_preceded_by_type),
('advertising_channel_map', 1), ('advertising_channel_map', 1),
@@ -3467,7 +3541,7 @@ class HCI_LE_Set_Advertising_Enable_Command(HCI_Command):
('le_scan_type', 1), ('le_scan_type', 1),
('le_scan_interval', 2), ('le_scan_interval', 2),
('le_scan_window', 2), ('le_scan_window', 2),
('own_address_type', OwnAddressType.TYPE_SPEC), ('own_address_type', OwnAddressType.type_spec()),
('scanning_filter_policy', 1), ('scanning_filter_policy', 1),
] ]
) )
@@ -3506,7 +3580,7 @@ class HCI_LE_Set_Scan_Enable_Command(HCI_Command):
('initiator_filter_policy', 1), ('initiator_filter_policy', 1),
('peer_address_type', Address.ADDRESS_TYPE_SPEC), ('peer_address_type', Address.ADDRESS_TYPE_SPEC),
('peer_address', Address.parse_address_preceded_by_type), ('peer_address', Address.parse_address_preceded_by_type),
('own_address_type', OwnAddressType.TYPE_SPEC), ('own_address_type', OwnAddressType.type_spec()),
('connection_interval_min', 2), ('connection_interval_min', 2),
('connection_interval_max', 2), ('connection_interval_max', 2),
('max_latency', 2), ('max_latency', 2),
@@ -3913,7 +3987,7 @@ class HCI_LE_Set_Advertising_Set_Random_Address_Command(HCI_Command):
), ),
}, },
), ),
('own_address_type', OwnAddressType.TYPE_SPEC), ('own_address_type', OwnAddressType.type_spec()),
('peer_address_type', Address.ADDRESS_TYPE_SPEC), ('peer_address_type', Address.ADDRESS_TYPE_SPEC),
('peer_address', Address.parse_address_preceded_by_type), ('peer_address', Address.parse_address_preceded_by_type),
('advertising_filter_policy', 1), ('advertising_filter_policy', 1),
@@ -4309,7 +4383,7 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command):
('initiator_filter_policy:', self.initiator_filter_policy), ('initiator_filter_policy:', self.initiator_filter_policy),
( (
'own_address_type: ', 'own_address_type: ',
OwnAddressType.type_name(self.own_address_type), OwnAddressType(self.own_address_type).name,
), ),
( (
'peer_address_type: ', 'peer_address_type: ',
@@ -4551,6 +4625,10 @@ class HCI_LE_Setup_ISO_Data_Path_Command(HCI_Command):
See Bluetooth spec @ 7.8.109 LE Setup ISO Data Path command See Bluetooth spec @ 7.8.109 LE Setup ISO Data Path command
''' '''
class Direction(enum.IntEnum):
HOST_TO_CONTROLLER = 0x00
CONTROLLER_TO_HOST = 0x01
connection_handle: int connection_handle: int
data_path_direction: int data_path_direction: int
data_path_id: int data_path_id: int
@@ -5190,6 +5268,21 @@ HCI_LE_Meta_Event.subevent_classes[
] = HCI_LE_Extended_Advertising_Report_Event ] = HCI_LE_Extended_Advertising_Report_Event
# -----------------------------------------------------------------------------
@HCI_LE_Meta_Event.event(
[
('status', 1),
('advertising_handle', 1),
('connection_handle', 2),
('number_completed_extended_advertising_events', 1),
]
)
class HCI_LE_Advertising_Set_Terminated_Event(HCI_LE_Meta_Event):
'''
See Bluetooth spec @ 7.7.65.18 LE Advertising Set Terminated Event
'''
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@HCI_LE_Meta_Event.event([('connection_handle', 2), ('channel_selection_algorithm', 1)]) @HCI_LE_Meta_Event.event([('connection_handle', 2), ('channel_selection_algorithm', 1)])
class HCI_LE_Channel_Selection_Algorithm_Event(HCI_LE_Meta_Event): class HCI_LE_Channel_Selection_Algorithm_Event(HCI_LE_Meta_Event):
+14
View File
@@ -37,6 +37,7 @@ from bumble.l2cap import (
L2CAP_Connection_Response, L2CAP_Connection_Response,
) )
from bumble.hci import ( from bumble.hci import (
Address,
HCI_EVENT_PACKET, HCI_EVENT_PACKET,
HCI_ACL_DATA_PACKET, HCI_ACL_DATA_PACKET,
HCI_DISCONNECTION_COMPLETE_EVENT, HCI_DISCONNECTION_COMPLETE_EVENT,
@@ -48,6 +49,7 @@ from bumble.hci import (
) )
from bumble.rfcomm import RFCOMM_Frame, RFCOMM_PSM from bumble.rfcomm import RFCOMM_Frame, RFCOMM_PSM
from bumble.sdp import SDP_PDU, SDP_PSM from bumble.sdp import SDP_PDU, SDP_PSM
from bumble import crypto
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -232,3 +234,15 @@ class PacketTracer:
) )
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer
def generate_irk() -> bytes:
return crypto.r()
def verify_rpa_with_irk(rpa: Address, irk: bytes) -> bool:
rpa_bytes = bytes(rpa)
prand_given = rpa_bytes[3:]
hash_given = rpa_bytes[:3]
hash_local = crypto.ah(irk, prand_given)
return hash_local[:3] == hash_given
+314 -93
View File
@@ -19,16 +19,17 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
import logging import logging
import enum import enum
import struct
from abc import ABC, abstractmethod
from pyee import EventEmitter from pyee import EventEmitter
from typing import Optional, TYPE_CHECKING from typing import Optional, Callable, TYPE_CHECKING
from typing_extensions import override
from bumble import l2cap from bumble import l2cap, device
from bumble.colors import color from bumble.colors import color
from bumble.core import InvalidStateError, ProtocolError from bumble.core import InvalidStateError, ProtocolError
from .hci import Address
if TYPE_CHECKING:
from bumble.device import Device, Connection
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -60,6 +61,7 @@ class Message:
NOT_READY = 0x01 NOT_READY = 0x01
ERR_INVALID_REPORT_ID = 0x02 ERR_INVALID_REPORT_ID = 0x02
ERR_UNSUPPORTED_REQUEST = 0x03 ERR_UNSUPPORTED_REQUEST = 0x03
ERR_INVALID_PARAMETER = 0x04
ERR_UNKNOWN = 0x0E ERR_UNKNOWN = 0x0E
ERR_FATAL = 0x0F ERR_FATAL = 0x0F
@@ -101,13 +103,14 @@ class GetReportMessage(Message):
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
packet_bytes = bytearray() packet_bytes = bytearray()
packet_bytes.append(self.report_id) packet_bytes.append(self.report_id)
packet_bytes.extend( if self.buffer_size == 0:
[(self.buffer_size & 0xFF), ((self.buffer_size >> 8) & 0xFF)]
)
if self.report_type == Message.ReportType.OTHER_REPORT:
return self.header(self.report_type) + packet_bytes return self.header(self.report_type) + packet_bytes
else: else:
return self.header(0x08 | self.report_type) + packet_bytes return (
self.header(0x08 | self.report_type)
+ packet_bytes
+ struct.pack("<H", self.buffer_size)
)
@dataclass @dataclass
@@ -120,6 +123,16 @@ class SetReportMessage(Message):
return self.header(self.report_type) + self.data return self.header(self.report_type) + self.data
@dataclass
class SendControlData(Message):
report_type: int
data: bytes
message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes:
return self.header(self.report_type) + self.data
@dataclass @dataclass
class GetProtocolMessage(Message): class GetProtocolMessage(Message):
message_type = Message.MessageType.GET_PROTOCOL message_type = Message.MessageType.GET_PROTOCOL
@@ -161,31 +174,47 @@ class VirtualCableUnplug(Message):
return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG) return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)
# Device sends input report, host sends output report.
@dataclass @dataclass
class SendData(Message): class SendData(Message):
data: bytes data: bytes
report_type: int
message_type = Message.MessageType.DATA message_type = Message.MessageType.DATA
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
return self.header(Message.ReportType.OUTPUT_REPORT) + self.data return self.header(self.report_type) + self.data
@dataclass
class SendHandshakeMessage(Message):
result_code: int
message_type = Message.MessageType.HANDSHAKE
def __bytes__(self) -> bytes:
return self.header(self.result_code)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Host(EventEmitter): class HID(ABC, EventEmitter):
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
l2cap_intr_channel: Optional[l2cap.ClassicChannel] l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
connection: Optional[device.Connection] = None
def __init__(self, device: Device, connection: Connection) -> None: class Role(enum.IntEnum):
HOST = 0x00
DEVICE = 0x01
def __init__(self, device: device.Device, role: Role) -> None:
super().__init__() super().__init__()
self.remote_device_bd_address: Optional[Address] = None
self.device = device self.device = device
self.connection = connection self.role = role
self.l2cap_ctrl_channel = None
self.l2cap_intr_channel = None
# Register ourselves with the L2CAP channel manager # Register ourselves with the L2CAP channel manager
device.register_l2cap_server(HID_CONTROL_PSM, self.on_connection) device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_connection) device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
device.on('connection', self.on_device_connection)
async def connect_control_channel(self) -> None: async def connect_control_channel(self) -> None:
# Create a new L2CAP connection - control channel # Create a new L2CAP connection - control channel
@@ -229,9 +258,18 @@ class Host(EventEmitter):
self.l2cap_ctrl_channel = None self.l2cap_ctrl_channel = None
await channel.disconnect() await channel.disconnect()
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None: def on_device_connection(self, connection: device.Connection) -> None:
self.connection = connection
self.remote_device_bd_address = connection.peer_address
connection.on('disconnection', self.on_device_disconnection)
def on_device_disconnection(self, reason: int) -> None:
self.connection = None
def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug(f'+++ New L2CAP connection: {l2cap_channel}') logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel)) l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None: def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
if l2cap_channel.psm == HID_CONTROL_PSM: if l2cap_channel.psm == HID_CONTROL_PSM:
@@ -242,63 +280,20 @@ class Host(EventEmitter):
self.l2cap_intr_channel.sink = self.on_intr_pdu self.l2cap_intr_channel.sink = self.on_intr_pdu
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}') logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
def on_ctrl_pdu(self, pdu: bytes) -> None: def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}') if l2cap_channel.psm == HID_CONTROL_PSM:
# Here we will receive all kinds of packets, parse and then call respective callbacks self.l2cap_ctrl_channel = None
message_type = pdu[0] >> 4
param = pdu[0] & 0x0F
if message_type == Message.MessageType.HANDSHAKE:
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
self.emit('handshake', Message.Handshake(param))
elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA')
self.emit('data', pdu)
elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.SUSPEND:
logger.debug('<<< HID SUSPEND')
self.emit('suspend', pdu)
elif param == Message.ControlCommand.EXIT_SUSPEND:
logger.debug('<<< HID EXIT SUSPEND')
self.emit('exit_suspend', pdu)
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug')
else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else: else:
logger.debug('<<< HID CONTROL DATA') self.l2cap_intr_channel = None
self.emit('data', pdu) logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')
@abstractmethod
def on_ctrl_pdu(self, pdu: bytes) -> None:
pass
def on_intr_pdu(self, pdu: bytes) -> None: def on_intr_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}') logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
self.emit("data", pdu) self.emit("interrupt_data", pdu)
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
msg = GetReportMessage(
report_type=report_type, report_id=report_id, buffer_size=buffer_size
)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def set_report(self, report_type: int, data: bytes):
msg = SetReportMessage(report_type=report_type, data=data)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def get_protocol(self):
msg = GetProtocolMessage()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def set_protocol(self, protocol_mode: int):
msg = SetProtocolMessage(protocol_mode=protocol_mode)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def send_pdu_on_ctrl(self, msg: bytes) -> None: def send_pdu_on_ctrl(self, msg: bytes) -> None:
assert self.l2cap_ctrl_channel assert self.l2cap_ctrl_channel
@@ -308,26 +303,252 @@ class Host(EventEmitter):
assert self.l2cap_intr_channel assert self.l2cap_intr_channel
self.l2cap_intr_channel.send_pdu(msg) self.l2cap_intr_channel.send_pdu(msg)
def send_data(self, data): def send_data(self, data: bytes) -> None:
msg = SendData(data) if self.role == HID.Role.HOST:
report_type = Message.ReportType.OUTPUT_REPORT
else:
report_type = Message.ReportType.INPUT_REPORT
msg = SendData(data, report_type)
hid_message = bytes(msg) hid_message = bytes(msg)
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}') if self.l2cap_intr_channel is not None:
self.send_pdu_on_intr(hid_message) logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
self.send_pdu_on_intr(hid_message)
def suspend(self): def virtual_cable_unplug(self) -> None:
msg = Suspend()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(msg)
def exit_suspend(self):
msg = ExitSuspend()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(msg)
def virtual_cable_unplug(self):
msg = VirtualCableUnplug() msg = VirtualCableUnplug()
hid_message = bytes(msg) hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}') logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(msg) self.send_pdu_on_ctrl(hid_message)
# -----------------------------------------------------------------------------
class Device(HID):
class GetSetReturn(enum.IntEnum):
FAILURE = 0x00
REPORT_ID_NOT_FOUND = 0x01
ERR_UNSUPPORTED_REQUEST = 0x02
ERR_UNKNOWN = 0x03
ERR_INVALID_PARAMETER = 0x04
SUCCESS = 0xFF
class GetSetStatus:
def __init__(self) -> None:
self.data = bytearray()
self.status = 0
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.DEVICE)
get_report_cb: Optional[Callable[[int, int, int], None]] = None
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
get_protocol_cb: Optional[Callable[[], None]] = None
set_protocol_cb: Optional[Callable[[int], None]] = None
@override
def on_ctrl_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
param = pdu[0] & 0x0F
message_type = pdu[0] >> 4
if message_type == Message.MessageType.GET_REPORT:
logger.debug('<<< HID GET REPORT')
self.handle_get_report(pdu)
elif message_type == Message.MessageType.SET_REPORT:
logger.debug('<<< HID SET REPORT')
self.handle_set_report(pdu)
elif message_type == Message.MessageType.GET_PROTOCOL:
logger.debug('<<< HID GET PROTOCOL')
self.handle_get_protocol(pdu)
elif message_type == Message.MessageType.SET_PROTOCOL:
logger.debug('<<< HID SET PROTOCOL')
self.handle_set_protocol(pdu)
elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA')
self.emit('control_data', pdu)
elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.SUSPEND:
logger.debug('<<< HID SUSPEND')
self.emit('suspend')
elif param == Message.ControlCommand.EXIT_SUSPEND:
logger.debug('<<< HID EXIT SUSPEND')
self.emit('exit_suspend')
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug')
else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else:
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def send_handshake_message(self, result_code: int) -> None:
msg = SendHandshakeMessage(result_code)
hid_message = bytes(msg)
logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def send_control_data(self, report_type: int, data: bytes):
msg = SendControlData(report_type=report_type, data=data)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def handle_get_report(self, pdu: bytes):
if self.get_report_cb is None:
logger.debug("GetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
buffer_flag = (pdu[0] & 0x08) >> 3
report_id = pdu[1]
logger.debug(f"buffer_flag: {buffer_flag}")
if buffer_flag == 1:
buffer_size = (pdu[3] << 8) | pdu[2]
else:
buffer_size = 0
ret = self.get_report_cb(report_id, report_type, buffer_size)
assert ret is not None
if ret.status == self.GetSetReturn.FAILURE:
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
elif ret.status == self.GetSetReturn.SUCCESS:
data = bytearray()
data.append(report_id)
data.extend(ret.data)
if len(data) < self.l2cap_ctrl_channel.mtu: # type: ignore[union-attr]
self.send_control_data(report_type=report_type, data=data)
else:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
self.get_report_cb = cb
logger.debug("GetReport callback registered successfully")
def handle_set_report(self, pdu: bytes):
if self.set_report_cb is None:
logger.debug("SetReport callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
report_type = pdu[0] & 0x03
report_id = pdu[1]
report_data = pdu[2:]
report_size = len(report_data) + 1
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_report_cb(
self, cb: Callable[[int, int, int, bytes], None]
) -> None:
self.set_report_cb = cb
logger.debug("SetReport callback registered successfully")
def handle_get_protocol(self, pdu: bytes):
if self.get_protocol_cb is None:
logger.debug("GetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.get_protocol_cb()
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
self.get_protocol_cb = cb
logger.debug("GetProtocol callback registered successfully")
def handle_set_protocol(self, pdu: bytes):
if self.set_protocol_cb is None:
logger.debug("SetProtocol callback not registered !!")
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
return
ret = self.set_protocol_cb(pdu[0] & 0x01)
assert ret is not None
if ret.status == self.GetSetReturn.SUCCESS:
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
else:
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
self.set_protocol_cb = cb
logger.debug("SetProtocol callback registered successfully")
# -----------------------------------------------------------------------------
class Host(HID):
def __init__(self, device: device.Device) -> None:
super().__init__(device, HID.Role.HOST)
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
msg = GetReportMessage(
report_type=report_type, report_id=report_id, buffer_size=buffer_size
)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def set_report(self, report_type: int, data: bytes) -> None:
msg = SetReportMessage(report_type=report_type, data=data)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def get_protocol(self) -> None:
msg = GetProtocolMessage()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def set_protocol(self, protocol_mode: int) -> None:
msg = SetProtocolMessage(protocol_mode=protocol_mode)
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def suspend(self) -> None:
msg = Suspend()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
def exit_suspend(self) -> None:
msg = ExitSuspend()
hid_message = bytes(msg)
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
self.send_pdu_on_ctrl(hid_message)
@override
def on_ctrl_pdu(self, pdu: bytes) -> None:
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
param = pdu[0] & 0x0F
message_type = pdu[0] >> 4
if message_type == Message.MessageType.HANDSHAKE:
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
self.emit('handshake', Message.Handshake(param))
elif message_type == Message.MessageType.DATA:
logger.debug('<<< HID CONTROL DATA')
self.emit('control_data', pdu)
elif message_type == Message.MessageType.CONTROL:
if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
self.emit('virtual_cable_unplug')
else:
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
else:
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
+131 -90
View File
@@ -21,7 +21,7 @@ import collections
import logging import logging
import struct import struct
from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable, cast from typing import Any, Awaitable, Callable, Deque, Dict, Optional, cast, TYPE_CHECKING
from bumble.colors import color from bumble.colors import color
from bumble.l2cap import L2CAP_PDU from bumble.l2cap import L2CAP_PDU
@@ -91,16 +91,49 @@ logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants class AclPacketQueue:
# ----------------------------------------------------------------------------- max_packet_size: int
# fmt: off
HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27 def __init__(
HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1 self,
HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27 max_packet_size: int,
HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1 max_in_flight: int,
send: Callable[[HCI_Packet], None],
) -> None:
self.max_packet_size = max_packet_size
self.max_in_flight = max_in_flight
self.in_flight = 0
self.send = send
self.packets: Deque[HCI_AclDataPacket] = collections.deque()
# fmt: on def enqueue(self, packet: HCI_AclDataPacket) -> None:
self.packets.appendleft(packet)
self.check_queue()
if self.packets:
logger.debug(
f'{self.in_flight} ACL packets in flight, '
f'{len(self.packets)} in queue'
)
def check_queue(self) -> None:
while self.packets and self.in_flight < self.max_in_flight:
packet = self.packets.pop()
self.send(packet)
self.in_flight += 1
def on_packets_completed(self, packet_count: int) -> None:
if packet_count > self.in_flight:
logger.warning(
color(
'!!! {packet_count} completed but only '
f'{self.in_flight} in flight'
)
)
packet_count = self.in_flight
self.in_flight -= packet_count
self.check_queue()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -111,6 +144,13 @@ class Connection:
self.peer_address = peer_address self.peer_address = peer_address
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport self.transport = transport
acl_packet_queue: Optional[AclPacketQueue] = (
host.le_acl_packet_queue
if transport == BT_LE_TRANSPORT
else host.acl_packet_queue
)
assert acl_packet_queue
self.acl_packet_queue = acl_packet_queue
def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None: def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None:
self.assembler.feed_packet(packet) self.assembler.feed_packet(packet)
@@ -123,8 +163,10 @@ class Connection:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Host(AbortableEventEmitter): class Host(AbortableEventEmitter):
connections: Dict[int, Connection] connections: Dict[int, Connection]
acl_packet_queue: collections.deque[HCI_AclDataPacket] acl_packet_queue: Optional[AclPacketQueue] = None
hci_sink: TransportSink le_acl_packet_queue: Optional[AclPacketQueue] = None
hci_sink: Optional[TransportSink] = None
hci_metadata: Dict[str, Any]
long_term_key_provider: Optional[ long_term_key_provider: Optional[
Callable[[int, bytes, int], Awaitable[Optional[bytes]]] Callable[[int, bytes, int], Awaitable[Optional[bytes]]]
] ]
@@ -137,18 +179,11 @@ class Host(AbortableEventEmitter):
) -> None: ) -> None:
super().__init__() super().__init__()
self.hci_metadata = None self.hci_metadata = {}
self.ready = False # True when we can accept incoming packets self.ready = False # True when we can accept incoming packets
self.reset_done = False
self.connections = {} # Connections, by connection handle self.connections = {} # Connections, by connection handle
self.pending_command = None self.pending_command = None
self.pending_response = None self.pending_response = None
self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS
self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH
self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS
self.acl_packet_queue = collections.deque()
self.acl_packets_in_flight = 0
self.local_version = None self.local_version = None
self.local_supported_commands = bytes(64) self.local_supported_commands = bytes(64)
self.local_le_features = 0 self.local_le_features = 0
@@ -162,10 +197,7 @@ class Host(AbortableEventEmitter):
# Connect to the source and sink if specified # Connect to the source and sink if specified
if controller_source: if controller_source:
controller_source.set_packet_sink(self) self.set_packet_source(controller_source)
self.hci_metadata = getattr(
controller_source, 'metadata', self.hci_metadata
)
if controller_sink: if controller_sink:
self.set_packet_sink(controller_sink) self.set_packet_sink(controller_sink)
@@ -200,17 +232,21 @@ class Host(AbortableEventEmitter):
self.ready = False self.ready = False
await self.flush() await self.flush()
await self.send_command(HCI_Reset_Command(), check_result=True)
self.ready = True
# Instantiate and init a driver for the host if needed. # Instantiate and init a driver for the host if needed.
# NOTE: we don't keep a reference to the driver here, because we don't # NOTE: we don't keep a reference to the driver here, because we don't
# currently have a need for the driver later on. But if the driver interface # currently have a need for the driver later on. But if the driver interface
# evolves, it may be required, then, to store a reference to the driver in # evolves, it may be required, then, to store a reference to the driver in
# an object property. # an object property.
reset_needed = True
if driver_factory is not None: if driver_factory is not None:
if driver := await driver_factory(self): if driver := await driver_factory(self):
await driver.init_controller() await driver.init_controller()
reset_needed = False
# Send a reset command unless a driver has already done so.
if reset_needed:
await self.send_command(HCI_Reset_Command(), check_result=True)
self.ready = True
response = await self.send_command( response = await self.send_command(
HCI_Read_Local_Supported_Commands_Command(), check_result=True HCI_Read_Local_Supported_Commands_Command(), check_result=True
@@ -253,46 +289,54 @@ class Host(AbortableEventEmitter):
response = await self.send_command( response = await self.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True HCI_Read_Buffer_Size_Command(), check_result=True
) )
self.hc_acl_data_packet_length = ( hc_acl_data_packet_length = (
response.return_parameters.hc_acl_data_packet_length response.return_parameters.hc_acl_data_packet_length
) )
self.hc_total_num_acl_data_packets = ( hc_total_num_acl_data_packets = (
response.return_parameters.hc_total_num_acl_data_packets response.return_parameters.hc_total_num_acl_data_packets
) )
logger.debug( logger.debug(
'HCI ACL flow control: ' 'HCI ACL flow control: '
f'hc_acl_data_packet_length={self.hc_acl_data_packet_length},' f'hc_acl_data_packet_length={hc_acl_data_packet_length},'
f'hc_total_num_acl_data_packets={self.hc_total_num_acl_data_packets}' f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}'
) )
self.acl_packet_queue = AclPacketQueue(
max_packet_size=hc_acl_data_packet_length,
max_in_flight=hc_total_num_acl_data_packets,
send=self.send_hci_packet,
)
hc_le_acl_data_packet_length = 0
hc_total_num_le_acl_data_packets = 0
if self.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND): if self.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command( response = await self.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True HCI_LE_Read_Buffer_Size_Command(), check_result=True
) )
self.hc_le_acl_data_packet_length = ( hc_le_acl_data_packet_length = (
response.return_parameters.hc_le_acl_data_packet_length response.return_parameters.hc_le_acl_data_packet_length
) )
self.hc_total_num_le_acl_data_packets = ( hc_total_num_le_acl_data_packets = (
response.return_parameters.hc_total_num_le_acl_data_packets response.return_parameters.hc_total_num_le_acl_data_packets
) )
logger.debug( logger.debug(
'HCI LE ACL flow control: ' 'HCI LE ACL flow control: '
f'hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},' f'hc_le_acl_data_packet_length={hc_le_acl_data_packet_length},'
'hc_total_num_le_acl_data_packets=' f'hc_total_num_le_acl_data_packets={hc_total_num_le_acl_data_packets}'
f'{self.hc_total_num_le_acl_data_packets}'
) )
if ( if hc_le_acl_data_packet_length == 0 or hc_total_num_le_acl_data_packets == 0:
response.return_parameters.hc_le_acl_data_packet_length == 0 # LE and Classic share the same queue
or response.return_parameters.hc_total_num_le_acl_data_packets == 0 self.le_acl_packet_queue = self.acl_packet_queue
): else:
# LE and Classic share the same values # Create a separate queue for LE
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length self.le_acl_packet_queue = AclPacketQueue(
self.hc_total_num_le_acl_data_packets = ( max_packet_size=hc_le_acl_data_packet_length,
self.hc_total_num_acl_data_packets max_in_flight=hc_total_num_le_acl_data_packets,
) send=self.send_hci_packet,
)
if self.supports_command( if self.supports_command(
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
@@ -313,29 +357,31 @@ class Host(AbortableEventEmitter):
) )
) )
self.reset_done = True
@property @property
def controller(self) -> TransportSink: def controller(self) -> Optional[TransportSink]:
return self.hci_sink return self.hci_sink
@controller.setter @controller.setter
def controller(self, controller): def controller(self, controller) -> None:
self.set_packet_sink(controller) self.set_packet_sink(controller)
if controller: if controller:
controller.set_packet_sink(self) controller.set_packet_sink(self)
def set_packet_sink(self, sink: TransportSink) -> None: def set_packet_sink(self, sink: Optional[TransportSink]) -> None:
self.hci_sink = sink self.hci_sink = sink
def set_packet_source(self, source: TransportSource) -> None:
source.set_packet_sink(self)
self.hci_metadata = getattr(source, 'metadata', self.hci_metadata)
def send_hci_packet(self, packet: HCI_Packet) -> None: def send_hci_packet(self, packet: HCI_Packet) -> None:
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {packet}')
if self.snooper: if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER) self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
self.hci_sink.on_packet(bytes(packet)) if self.hci_sink:
self.hci_sink.on_packet(bytes(packet))
async def send_command(self, command, check_result=False): async def send_command(self, command, check_result=False):
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}')
# Wait until we can send (only one pending command at a time) # Wait until we can send (only one pending command at a time)
async with self.command_semaphore: async with self.command_semaphore:
assert self.pending_command is None assert self.pending_command is None
@@ -383,6 +429,17 @@ class Host(AbortableEventEmitter):
asyncio.create_task(send_command(command)) asyncio.create_task(send_command(command))
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None: def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
if not (connection := self.connections.get(connection_handle)):
logger.warning(f'connection 0x{connection_handle:04X} not found')
return
packet_queue = connection.acl_packet_queue
if packet_queue is None:
logger.warning(
f'no ACL packet queue for connection 0x{connection_handle:04X}'
)
return
# Create a PDU
l2cap_pdu = bytes(L2CAP_PDU(cid, pdu)) l2cap_pdu = bytes(L2CAP_PDU(cid, pdu))
# Send the data to the controller via ACL packets # Send the data to the controller via ACL packets
@@ -390,8 +447,7 @@ class Host(AbortableEventEmitter):
offset = 0 offset = 0
pb_flag = 0 pb_flag = 0
while bytes_remaining: while bytes_remaining:
# TODO: support different LE/Classic lengths data_total_length = min(bytes_remaining, packet_queue.max_packet_size)
data_total_length = min(bytes_remaining, self.hc_le_acl_data_packet_length)
acl_packet = HCI_AclDataPacket( acl_packet = HCI_AclDataPacket(
connection_handle=connection_handle, connection_handle=connection_handle,
pb_flag=pb_flag, pb_flag=pb_flag,
@@ -399,34 +455,12 @@ class Host(AbortableEventEmitter):
data_total_length=data_total_length, data_total_length=data_total_length,
data=l2cap_pdu[offset : offset + data_total_length], data=l2cap_pdu[offset : offset + data_total_length],
) )
logger.debug( logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}')
f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}' packet_queue.enqueue(acl_packet)
)
self.queue_acl_packet(acl_packet)
pb_flag = 1 pb_flag = 1
offset += data_total_length offset += data_total_length
bytes_remaining -= data_total_length bytes_remaining -= data_total_length
def queue_acl_packet(self, acl_packet: HCI_AclDataPacket) -> None:
self.acl_packet_queue.appendleft(acl_packet)
self.check_acl_packet_queue()
if len(self.acl_packet_queue):
logger.debug(
f'{self.acl_packets_in_flight} ACL packets in flight, '
f'{len(self.acl_packet_queue)} in queue'
)
def check_acl_packet_queue(self) -> None:
# Send all we can (TODO: support different LE/Classic limits)
while (
len(self.acl_packet_queue) > 0
and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets
):
packet = self.acl_packet_queue.pop()
self.send_hci_packet(packet)
self.acl_packets_in_flight += 1
def supports_command(self, command): def supports_command(self, command):
# Find the support flag position for this command # Find the support flag position for this command
for octet, flags in enumerate(HCI_SUPPORTED_COMMANDS_FLAGS): for octet, flags in enumerate(HCI_SUPPORTED_COMMANDS_FLAGS):
@@ -549,7 +583,7 @@ class Host(AbortableEventEmitter):
# This is used just for the Num_HCI_Command_Packets field, not related to # This is used just for the Num_HCI_Command_Packets field, not related to
# an actual command # an actual command
logger.debug('no-command event') logger.debug('no-command event')
return None return
return self.on_command_processed(event) return self.on_command_processed(event)
@@ -557,18 +591,17 @@ class Host(AbortableEventEmitter):
return self.on_command_processed(event) return self.on_command_processed(event)
def on_hci_number_of_completed_packets_event(self, event): def on_hci_number_of_completed_packets_event(self, event):
total_packets = sum(event.num_completed_packets) for connection_handle, num_completed_packets in zip(
if total_packets <= self.acl_packets_in_flight: event.connection_handles, event.num_completed_packets
self.acl_packets_in_flight -= total_packets ):
self.check_acl_packet_queue() if not (connection := self.connections.get(connection_handle)):
else: logger.warning(
logger.warning( 'received packet completion event for unknown handle '
color( f'0x{connection_handle:04X}'
'!!! {total_packets} completed but only '
f'{self.acl_packets_in_flight} in flight'
) )
) continue
self.acl_packets_in_flight = 0
connection.acl_packet_queue.on_packets_completed(num_completed_packets)
# Classic only # Classic only
def on_hci_connection_request_event(self, event): def on_hci_connection_request_event(self, event):
@@ -721,6 +754,14 @@ class Host(AbortableEventEmitter):
def on_hci_le_extended_advertising_report_event(self, event): def on_hci_le_extended_advertising_report_event(self, event):
self.on_hci_le_advertising_report_event(event) self.on_hci_le_advertising_report_event(event)
def on_hci_le_advertising_set_terminated_event(self, event):
self.emit(
'advertising_set_termination',
event.status,
event.advertising_handle,
event.connection_handle,
)
def on_hci_le_cis_request_event(self, event): def on_hci_le_cis_request_event(self, event):
self.emit( self.emit(
'cis_request', 'cis_request',
+10 -5
View File
@@ -149,9 +149,10 @@ L2CAP_INVALID_CID_IN_REQUEST_REASON = 0x0002
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535 L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23 L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23 L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533 L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2046 L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048 L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256 L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256
@@ -188,8 +189,11 @@ class LeCreditBasedChannelSpec:
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
): ):
raise ValueError('max credits out of range') raise ValueError('max credits out of range')
if self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU: if (
raise ValueError('MTU too small') self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
):
raise ValueError('MTU out of range')
if ( if (
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
@@ -1644,12 +1648,13 @@ class ChannelManager:
def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None: def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None:
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu) pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
pdu_bytes = bytes(pdu)
logger.debug( logger.debug(
f'{color(">>> Sending L2CAP PDU", "blue")} ' f'{color(">>> Sending L2CAP PDU", "blue")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) ' f'on connection [0x{connection.handle:04X}] (CID={cid}) '
f'{connection.peer_address}: {pdu_str}' f'{connection.peer_address}: {len(pdu_bytes)} bytes, {pdu_str}'
) )
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu)) self.host.send_l2cap_pdu(connection.handle, cid, pdu_bytes)
def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None: def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID): if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
+2 -2
View File
@@ -18,7 +18,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import struct import struct
import logging import logging
from typing import List from typing import List, Optional
from bumble import l2cap from bumble import l2cap
from ..core import AdvertisingData from ..core import AdvertisingData
@@ -67,7 +67,7 @@ class AshaService(TemplateService):
self.emit('volume', connection, value[0]) self.emit('volume', connection, value[0])
# Handler for audio control commands # Handler for audio control commands
def on_audio_control_point_write(connection: Connection, value): def on_audio_control_point_write(connection: Optional[Connection], value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}') logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0] opcode = value[0]
if opcode == AshaService.OPCODE_START: if opcode == AshaService.OPCODE_START:
+753 -2
View File
@@ -23,13 +23,21 @@ import dataclasses
import enum import enum
import struct import struct
import functools import functools
from typing import Optional, List, Union import logging
from typing import Optional, List, Union, Type, Dict, Any, Tuple, cast
from bumble import colors
from bumble import device
from bumble import hci from bumble import hci
from bumble import gatt from bumble import gatt
from bumble import gatt_client from bumble import gatt_client
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -106,7 +114,7 @@ class SamplingFrequency(enum.IntEnum):
'''Bluetooth Assigned Numbers, Section 6.12.5.1 - Sampling Frequency''' '''Bluetooth Assigned Numbers, Section 6.12.5.1 - Sampling Frequency'''
# fmt: off # fmt: off
FREQ_8000 = 0x01 FREQ_8000 = 0x01
FREQ_11025 = 0x02 FREQ_11025 = 0x02
FREQ_16000 = 0x03 FREQ_16000 = 0x03
FREQ_22050 = 0x04 FREQ_22050 = 0x04
@@ -220,6 +228,231 @@ class SupportedFrameDuration(enum.IntFlag):
DURATION_10000_US_PREFERRED = 0b0010 DURATION_10000_US_PREFERRED = 0b0010
# -----------------------------------------------------------------------------
# ASE Operations
# -----------------------------------------------------------------------------
class ASE_Operation:
'''
See Audio Stream Control Service - 5 ASE Control operations.
'''
classes: Dict[int, Type[ASE_Operation]] = {}
op_code: int
name: str
fields: Optional[Sequence[Any]] = None
ase_id: List[int]
class Opcode(enum.IntEnum):
# fmt: off
CONFIG_CODEC = 0x01
CONFIG_QOS = 0x02
ENABLE = 0x03
RECEIVER_START_READY = 0x04
DISABLE = 0x05
RECEIVER_STOP_READY = 0x06
UPDATE_METADATA = 0x07
RELEASE = 0x08
@staticmethod
def from_bytes(pdu: bytes) -> ASE_Operation:
op_code = pdu[0]
cls = ASE_Operation.classes.get(op_code)
if cls is None:
instance = ASE_Operation(pdu)
instance.name = ASE_Operation.Opcode(op_code).name
instance.op_code = op_code
return instance
self = cls.__new__(cls)
ASE_Operation.__init__(self, pdu)
if self.fields is not None:
self.init_from_bytes(pdu, 1)
return self
@staticmethod
def subclass(fields):
def inner(cls: Type[ASE_Operation]):
try:
operation = ASE_Operation.Opcode[cls.__name__[4:].upper()]
cls.name = operation.name
cls.op_code = operation
except:
raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode')
cls.fields = fields
# Register a factory for this class
ASE_Operation.classes[cls.op_code] = cls
return cls
return inner
def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None:
if self.fields is not None and kwargs:
hci.HCI_Object.init_from_fields(self, self.fields, kwargs)
if pdu is None:
pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes(
kwargs, self.fields
)
self.pdu = pdu
def init_from_bytes(self, pdu: bytes, offset: int):
return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields)
def __bytes__(self) -> bytes:
return self.pdu
def __str__(self) -> str:
result = f'{colors.color(self.name, "yellow")} '
if fields := getattr(self, 'fields', None):
result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ')
else:
if len(self.pdu) > 1:
result += f': {self.pdu.hex()}'
return result
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('target_latency', 1),
('target_phy', 1),
('codec_id', hci.CodingFormat.parse_from_bytes),
('codec_specific_configuration', 'v'),
],
]
)
class ASE_Config_Codec(ASE_Operation):
'''
See Audio Stream Control Service 5.1 - Config Codec Operation
'''
target_latency: List[int]
target_phy: List[int]
codec_id: List[hci.CodingFormat]
codec_specific_configuration: List[bytes]
@ASE_Operation.subclass(
[
[
('ase_id', 1),
('cig_id', 1),
('cis_id', 1),
('sdu_interval', 3),
('framing', 1),
('phy', 1),
('max_sdu', 2),
('retransmission_number', 1),
('max_transport_latency', 2),
('presentation_delay', 3),
],
]
)
class ASE_Config_QOS(ASE_Operation):
'''
See Audio Stream Control Service 5.2 - Config Qos Operation
'''
cig_id: List[int]
cis_id: List[int]
sdu_interval: List[int]
framing: List[int]
phy: List[int]
max_sdu: List[int]
retransmission_number: List[int]
max_transport_latency: List[int]
presentation_delay: List[int]
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Enable(ASE_Operation):
'''
See Audio Stream Control Service 5.3 - Enable Operation
'''
metadata: bytes
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Start_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.4 - Receiver Start Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Disable(ASE_Operation):
'''
See Audio Stream Control Service 5.5 - Disable Operation
'''
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Receiver_Stop_Ready(ASE_Operation):
'''
See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation
'''
@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]])
class ASE_Update_Metadata(ASE_Operation):
'''
See Audio Stream Control Service 5.7 - Update Metadata Operation
'''
metadata: List[bytes]
@ASE_Operation.subclass([[('ase_id', 1)]])
class ASE_Release(ASE_Operation):
'''
See Audio Stream Control Service 5.8 - Release Operation
'''
class AseResponseCode(enum.IntEnum):
# fmt: off
SUCCESS = 0x00
UNSUPPORTED_OPCODE = 0x01
INVALID_LENGTH = 0x02
INVALID_ASE_ID = 0x03
INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04
INVALID_ASE_DIRECTION = 0x05
UNSUPPORTED_AUDIO_CAPABILITIES = 0x06
UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07
REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08
INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09
UNSUPPORTED_METADATA = 0x0A
REJECTED_METADATA = 0x0B
INVALID_METADATA = 0x0C
INSUFFICIENT_RESOURCES = 0x0D
UNSPECIFIED_ERROR = 0x0E
class AseReasonCode(enum.IntEnum):
# fmt: off
NONE = 0x00
CODEC_ID = 0x01
CODEC_SPECIFIC_CONFIGURATION = 0x02
SDU_INTERVAL = 0x03
FRAMING = 0x04
PHY = 0x05
MAXIMUM_SDU_SIZE = 0x06
RETRANSMISSION_NUMBER = 0x07
MAX_TRANSPORT_LATENCY = 0x08
PRESENTATION_DELAY = 0x09
INVALID_ASE_CIS_MAPPING = 0x0A
class AudioRole(enum.IntEnum):
SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Utils
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -325,6 +558,80 @@ class CodecSpecificCapabilities:
) )
@dataclasses.dataclass
class CodecSpecificConfiguration:
'''See:
* Bluetooth Assigned Numbers, 6.12.5 - Codec Specific Configuration LTV Structures
* Basic Audio Profile, 4.3.2 - Codec_Specific_Capabilities LTV requirements
'''
class Type(enum.IntEnum):
# fmt: off
SAMPLING_FREQUENCY = 0x01
FRAME_DURATION = 0x02
AUDIO_CHANNEL_ALLOCATION = 0x03
OCTETS_PER_FRAME = 0x04
CODEC_FRAMES_PER_SDU = 0x05
sampling_frequency: SamplingFrequency
frame_duration: FrameDuration
audio_channel_allocation: AudioLocation
octets_per_codec_frame: int
codec_frames_per_sdu: int
@classmethod
def from_bytes(cls, data: bytes) -> CodecSpecificConfiguration:
offset = 0
# Allowed default values.
audio_channel_allocation = AudioLocation.NOT_ALLOWED
codec_frames_per_sdu = 1
while offset < len(data):
length, type = struct.unpack_from('BB', data, offset)
offset += 2
value = int.from_bytes(data[offset : offset + length - 1], 'little')
offset += length - 1
if type == CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY:
sampling_frequency = SamplingFrequency(value)
elif type == CodecSpecificConfiguration.Type.FRAME_DURATION:
frame_duration = FrameDuration(value)
elif type == CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION:
audio_channel_allocation = AudioLocation(value)
elif type == CodecSpecificConfiguration.Type.OCTETS_PER_FRAME:
octets_per_codec_frame = value
elif type == CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU:
codec_frames_per_sdu = value
# It is expected here that if some fields are missing, an error should be raised.
return CodecSpecificConfiguration(
sampling_frequency=sampling_frequency,
frame_duration=frame_duration,
audio_channel_allocation=audio_channel_allocation,
octets_per_codec_frame=octets_per_codec_frame,
codec_frames_per_sdu=codec_frames_per_sdu,
)
def __bytes__(self) -> bytes:
return struct.pack(
'<BBBBBBBBIBBHBBB',
2,
CodecSpecificConfiguration.Type.SAMPLING_FREQUENCY,
self.sampling_frequency,
2,
CodecSpecificConfiguration.Type.FRAME_DURATION,
self.frame_duration,
5,
CodecSpecificConfiguration.Type.AUDIO_CHANNEL_ALLOCATION,
self.audio_channel_allocation,
3,
CodecSpecificConfiguration.Type.OCTETS_PER_FRAME,
self.octets_per_codec_frame,
2,
CodecSpecificConfiguration.Type.CODEC_FRAMES_PER_SDU,
self.codec_frames_per_sdu,
)
@dataclasses.dataclass @dataclasses.dataclass
class PacRecord: class PacRecord:
coding_format: hci.CodingFormat coding_format: hci.CodingFormat
@@ -452,6 +759,429 @@ class PublishedAudioCapabilitiesService(gatt.TemplateService):
super().__init__(characteristics) super().__init__(characteristics)
class AseStateMachine(gatt.Characteristic):
class State(enum.IntEnum):
# fmt: off
IDLE = 0x00
CODEC_CONFIGURED = 0x01
QOS_CONFIGURED = 0x02
ENABLING = 0x03
STREAMING = 0x04
DISABLING = 0x05
RELEASING = 0x06
cis_link: Optional[device.CisLink] = None
# Additional parameters in CODEC_CONFIGURED State
preferred_framing = 0 # Unframed PDU supported
preferred_phy = 0
preferred_retransmission_number = 13
preferred_max_transport_latency = 100
supported_presentation_delay_min = 0
supported_presentation_delay_max = 0
preferred_presentation_delay_min = 0
preferred_presentation_delay_max = 0
codec_id = hci.CodingFormat(hci.CodecID.LC3)
codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b''
# Additional parameters in QOS_CONFIGURED State
cig_id = 0
cis_id = 0
sdu_interval = 0
framing = 0
phy = 0
max_sdu = 0
retransmission_number = 0
max_transport_latency = 0
presentation_delay = 0
# Additional parameters in ENABLING, STREAMING, DISABLING State
# TODO: Parse this
metadata = b''
def __init__(
self,
role: AudioRole,
ase_id: int,
service: AudioStreamControlService,
) -> None:
self.service = service
self.ase_id = ase_id
self._state = AseStateMachine.State.IDLE
self.role = role
uuid = (
gatt.GATT_SINK_ASE_CHARACTERISTIC
if role == AudioRole.SINK
else gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
super().__init__(
uuid=uuid,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE,
value=gatt.CharacteristicValue(read=self.on_read),
)
self.service.device.on('cis_request', self.on_cis_request)
self.service.device.on('cis_establishment', self.on_cis_establishment)
def on_cis_request(
self,
acl_connection: device.Connection,
cis_handle: int,
cig_id: int,
cis_id: int,
) -> None:
if cis_id == self.cis_id and self.state == self.State.ENABLING:
acl_connection.abort_on(
'flush', self.service.device.accept_cis_request(cis_handle)
)
def on_cis_establishment(self, cis_link: device.CisLink) -> None:
if cis_link.cis_id == self.cis_id and self.state == self.State.ENABLING:
self.state = self.State.STREAMING
self.cis_link = cis_link
async def post_cis_established():
await self.service.device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=cis_link.handle,
data_path_direction=self.role,
data_path_id=0x00, # Fixed HCI
codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=0,
codec_configuration=b'',
)
)
await self.service.device.notify_subscribers(self, self.value)
cis_link.acl_connection.abort_on('flush', post_cis_established())
def on_config_codec(
self,
target_latency: int,
target_phy: int,
codec_id: hci.CodingFormat,
codec_specific_configuration: bytes,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
self.State.IDLE,
self.State.CODEC_CONFIGURED,
self.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.max_transport_latency = target_latency
self.phy = target_phy
self.codec_id = codec_id
if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC:
self.codec_specific_configuration = codec_specific_configuration
else:
self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes(
codec_specific_configuration
)
self.state = self.State.CODEC_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_config_qos(
self,
cig_id: int,
cis_id: int,
sdu_interval: int,
framing: int,
phy: int,
max_sdu: int,
retransmission_number: int,
max_transport_latency: int,
presentation_delay: int,
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.CODEC_CONFIGURED,
AseStateMachine.State.QOS_CONFIGURED,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.cig_id = cig_id
self.cis_id = cis_id
self.sdu_interval = sdu_interval
self.framing = framing
self.phy = phy
self.max_sdu = max_sdu
self.retransmission_number = retransmission_number
self.max_transport_latency = max_transport_latency
self.presentation_delay = presentation_delay
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.QOS_CONFIGURED:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = metadata
self.state = self.State.ENABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.ENABLING:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.STREAMING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.DISABLING
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state != AseStateMachine.State.DISABLING:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.QOS_CONFIGURED
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_update_metadata(
self, metadata: bytes
) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state not in (
AseStateMachine.State.ENABLING,
AseStateMachine.State.STREAMING,
):
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.metadata = metadata
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]:
if self.state == AseStateMachine.State.IDLE:
return (
AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION,
AseReasonCode.NONE,
)
self.state = self.State.RELEASING
async def remove_cis_async():
await self.service.device.send_command(
hci.HCI_LE_Remove_ISO_Data_Path_Command(
connection_handle=self.cis_link.handle,
data_path_direction=self.role,
)
)
self.state = self.State.IDLE
await self.service.device.notify_subscribers(self, self.value)
self.service.device.abort_on('flush', remove_cis_async())
return (AseResponseCode.SUCCESS, AseReasonCode.NONE)
@property
def state(self) -> State:
return self._state
@state.setter
def state(self, new_state: State) -> None:
logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}')
self._state = new_state
@property
def value(self):
'''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.'''
if self.state == self.State.CODEC_CONFIGURED:
codec_specific_configuration_bytes = bytes(
self.codec_specific_configuration
)
additional_parameters = (
struct.pack(
'<BBBH',
self.preferred_framing,
self.preferred_phy,
self.preferred_retransmission_number,
self.preferred_max_transport_latency,
)
+ self.supported_presentation_delay_min.to_bytes(3, 'little')
+ self.supported_presentation_delay_max.to_bytes(3, 'little')
+ self.preferred_presentation_delay_min.to_bytes(3, 'little')
+ self.preferred_presentation_delay_max.to_bytes(3, 'little')
+ bytes(self.codec_id)
+ bytes([len(codec_specific_configuration_bytes)])
+ codec_specific_configuration_bytes
)
elif self.state == self.State.QOS_CONFIGURED:
additional_parameters = (
bytes([self.cig_id, self.cis_id])
+ self.sdu_interval.to_bytes(3, 'little')
+ struct.pack(
'<BBHBH',
self.framing,
self.phy,
self.max_sdu,
self.retransmission_number,
self.max_transport_latency,
)
+ self.presentation_delay.to_bytes(3, 'little')
)
elif self.state in (
self.State.ENABLING,
self.State.STREAMING,
self.State.DISABLING,
):
additional_parameters = (
bytes([self.cig_id, self.cis_id, len(self.metadata)]) + self.metadata
)
else:
additional_parameters = b''
return bytes([self.ase_id, self.state]) + additional_parameters
@value.setter
def value(self, _new_value):
# Readonly. Do nothing in the setter.
pass
def on_read(self, _: Optional[device.Connection]) -> bytes:
return self.value
def __str__(self) -> str:
return (
f'AseStateMachine(id={self.ase_id}, role={self.role.name} '
f'state={self._state.name})'
)
class AudioStreamControlService(gatt.TemplateService):
UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE
ase_state_machines: Dict[int, AseStateMachine]
ase_control_point: gatt.Characteristic
def __init__(
self,
device: device.Device,
source_ase_id: Sequence[int] = [],
sink_ase_id: Sequence[int] = [],
) -> None:
self.device = device
self.ase_state_machines = {
**{
id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self)
for id in sink_ase_id
},
**{
id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self)
for id in source_ase_id
},
} # ASE state machines, by ASE ID
self.ase_control_point = gatt.Characteristic(
uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.WRITE
| gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.WRITEABLE,
value=gatt.CharacteristicValue(write=self.on_write_ase_control_point),
)
super().__init__([self.ase_control_point, *self.ase_state_machines.values()])
def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args):
if ase := self.ase_state_machines.get(ase_id):
handler = getattr(ase, 'on_' + opcode.name.lower())
return (ase_id, *handler(*args))
else:
return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE)
def on_write_ase_control_point(self, connection, data):
operation = ASE_Operation.from_bytes(data)
responses = []
logger.debug(f'*** ASCS Write {operation} ***')
if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC:
for ase_id, *args in zip(
operation.ase_id,
operation.target_latency,
operation.target_phy,
operation.codec_id,
operation.codec_specific_configuration,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS:
for ase_id, *args in zip(
operation.ase_id,
operation.cig_id,
operation.cis_id,
operation.sdu_interval,
operation.framing,
operation.phy,
operation.max_sdu,
operation.retransmission_number,
operation.max_transport_latency,
operation.presentation_delay,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.ENABLE,
ASE_Operation.Opcode.UPDATE_METADATA,
):
for ase_id, *args in zip(
operation.ase_id,
operation.metadata,
):
responses.append(self.on_operation(operation.op_code, ase_id, args))
elif operation.op_code in (
ASE_Operation.Opcode.RECEIVER_START_READY,
ASE_Operation.Opcode.DISABLE,
ASE_Operation.Opcode.RECEIVER_STOP_READY,
ASE_Operation.Opcode.RELEASE,
):
for ase_id in operation.ase_id:
responses.append(self.on_operation(operation.op_code, ase_id, []))
control_point_notification = bytes(
[operation.op_code, len(responses)]
) + b''.join(map(bytes, responses))
self.device.abort_on(
'flush',
self.device.notify_subscribers(
self.ase_control_point, control_point_notification
),
)
for ase_id, *_ in responses:
if ase := self.ase_state_machines.get(ase_id):
self.device.abort_on(
'flush',
self.device.notify_subscribers(ase, ase.value),
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Client # Client
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -494,3 +1224,24 @@ class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy):
gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC
): ):
self.source_audio_locations = characteristics[0] self.source_audio_locations = characteristics[0]
class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = AudioStreamControlService
sink_ase: List[gatt_client.CharacteristicProxy]
source_ase: List[gatt_client.CharacteristicProxy]
ase_control_point: gatt_client.CharacteristicProxy
def __init__(self, service_proxy: gatt_client.ServiceProxy):
self.service_proxy = service_proxy
self.sink_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SINK_ASE_CHARACTERISTIC
)
self.source_ase = service_proxy.get_characteristics_by_uuid(
gatt.GATT_SOURCE_ASE_CHARACTERISTIC
)
self.ase_control_point = service_proxy.get_characteristics_by_uuid(
gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC
)[0]
+52
View File
@@ -0,0 +1,52 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from bumble import gatt
from bumble import gatt_client
from bumble.profiles import csip
# -----------------------------------------------------------------------------
# Server
# -----------------------------------------------------------------------------
class CommonAudioServiceService(gatt.TemplateService):
UUID = gatt.GATT_COMMON_AUDIO_SERVICE
def __init__(
self,
coordinated_set_identification_service: csip.CoordinatedSetIdentificationService,
) -> None:
self.coordinated_set_identification_service = (
coordinated_set_identification_service
)
super().__init__(
characteristics=[],
included_services=[coordinated_set_identification_service],
)
# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------
class CommonAudioServiceServiceProxy(gatt_client.ProfileServiceProxy):
SERVICE_CLASS = CommonAudioServiceService
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
self.service_proxy = service_proxy
+119 -9
View File
@@ -19,8 +19,11 @@
from __future__ import annotations from __future__ import annotations
import enum import enum
import struct import struct
from typing import Optional from typing import Optional, Tuple
from bumble import core
from bumble import crypto
from bumble import device
from bumble import gatt from bumble import gatt
from bumble import gatt_client from bumble import gatt_client
@@ -28,6 +31,9 @@ from bumble import gatt_client
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Constants # Constants
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
SET_IDENTITY_RESOLVING_KEY_LENGTH = 16
class SirkType(enum.IntEnum): class SirkType(enum.IntEnum):
'''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.''' '''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''
@@ -43,9 +49,47 @@ class MemberLock(enum.IntEnum):
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Utils # Crypto Toolbox
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# TODO: Implement RSI Generator def s1(m: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.3 s1 SALT generation function.
'''
return crypto.aes_cmac(m[::-1], bytes(16))[::-1]
def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.4 k1 derivation function.
'''
t = crypto.aes_cmac(n[::-1], salt[::-1])
return crypto.aes_cmac(p[::-1], t)[::-1]
def sef(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.5 SIRK encryption function sef.
SIRK decryption function sdf shares the same algorithm. The only difference is that argument r is:
* Plaintext in encryption
* Cipher in decryption
'''
return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)
def sih(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih.
'''
return crypto.e(k, r + bytes(13))[:3]
def generate_rsi(sirk: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation.
'''
prand = crypto.generate_prand()
return sih(sirk, prand) + prand
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -54,6 +98,7 @@ class MemberLock(enum.IntEnum):
class CoordinatedSetIdentificationService(gatt.TemplateService): class CoordinatedSetIdentificationService(gatt.TemplateService):
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
set_identity_resolving_key: bytes
set_identity_resolving_key_characteristic: gatt.Characteristic set_identity_resolving_key_characteristic: gatt.Characteristic
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
set_member_lock_characteristic: Optional[gatt.Characteristic] = None set_member_lock_characteristic: Optional[gatt.Characteristic] = None
@@ -62,19 +107,26 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
def __init__( def __init__(
self, self,
set_identity_resolving_key: bytes, set_identity_resolving_key: bytes,
set_identity_resolving_key_type: SirkType,
coordinated_set_size: Optional[int] = None, coordinated_set_size: Optional[int] = None,
set_member_lock: Optional[MemberLock] = None, set_member_lock: Optional[MemberLock] = None,
set_member_rank: Optional[int] = None, set_member_rank: Optional[int] = None,
) -> None: ) -> None:
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
raise ValueError(
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
)
characteristics = [] characteristics = []
self.set_identity_resolving_key = set_identity_resolving_key
self.set_identity_resolving_key_type = set_identity_resolving_key_type
self.set_identity_resolving_key_characteristic = gatt.Characteristic( self.set_identity_resolving_key_characteristic = gatt.Characteristic(
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC, uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY, | gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE, permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
# TODO: Implement encrypted SIRK reader. value=gatt.CharacteristicValue(read=self.on_sirk_read),
value=struct.pack('B', SirkType.PLAINTEXT) + set_identity_resolving_key,
) )
characteristics.append(self.set_identity_resolving_key_characteristic) characteristics.append(self.set_identity_resolving_key_characteristic)
@@ -83,7 +135,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC, uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY, | gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE, permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=struct.pack('B', coordinated_set_size), value=struct.pack('B', coordinated_set_size),
) )
characteristics.append(self.coordinated_set_size_characteristic) characteristics.append(self.coordinated_set_size_characteristic)
@@ -94,7 +146,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
properties=gatt.Characteristic.Properties.READ properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY | gatt.Characteristic.Properties.NOTIFY
| gatt.Characteristic.Properties.WRITE, | gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.READABLE permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
| gatt.Characteristic.Permissions.WRITEABLE, | gatt.Characteristic.Permissions.WRITEABLE,
value=struct.pack('B', set_member_lock), value=struct.pack('B', set_member_lock),
) )
@@ -105,13 +157,45 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC, uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY, | gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READABLE, permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
value=struct.pack('B', set_member_rank), value=struct.pack('B', set_member_rank),
) )
characteristics.append(self.set_member_rank_characteristic) characteristics.append(self.set_member_rank_characteristic)
super().__init__(characteristics) super().__init__(characteristics)
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
sirk_bytes = self.set_identity_resolving_key
else:
assert connection
if connection.transport == core.BT_LE_TRANSPORT:
key = await connection.device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await connection.device.get_link_key(connection.peer_address)
if not key:
raise RuntimeError('LTK or LinkKey is not present')
sirk_bytes = sef(key, self.set_identity_resolving_key)
return bytes([self.set_identity_resolving_key_type]) + sirk_bytes
def get_advertising_data(self) -> bytes:
return bytes(
core.AdvertisingData(
[
(
core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER,
generate_rsi(self.set_identity_resolving_key),
),
]
)
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Client # Client
@@ -145,3 +229,29 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
): ):
self.set_member_rank = characteristics[0] self.set_member_rank = characteristics[0]
async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
'''Reads SIRK and decrypts if encrypted.'''
response = await self.set_identity_resolving_key.read_value()
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
raise RuntimeError('Invalid SIRK value')
sirk_type = SirkType(response[0])
if sirk_type == SirkType.PLAINTEXT:
sirk = response[1:]
else:
connection = self.service_proxy.client.connection
device = connection.device
if connection.transport == core.BT_LE_TRANSPORT:
key = await device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await device.get_link_key(connection.peer_address)
if not key:
raise RuntimeError('LTK or LinkKey is not present')
sirk = sef(key, response[1:])
return (sirk_type, sirk)
+31 -20
View File
@@ -118,8 +118,8 @@ CRC_TABLE = bytes([
0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF 0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
]) ])
RFCOMM_DEFAULT_INITIAL_RX_CREDITS = 7 RFCOMM_DEFAULT_WINDOW_SIZE = 16
RFCOMM_DEFAULT_PREFERRED_MTU = 1280 RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1 RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30 RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
@@ -438,20 +438,24 @@ class DLC(EventEmitter):
multiplexer: Multiplexer, multiplexer: Multiplexer,
dlci: int, dlci: int,
max_frame_size: int, max_frame_size: int,
initial_tx_credits: int, window_size: int,
) -> None: ) -> None:
super().__init__() super().__init__()
self.multiplexer = multiplexer self.multiplexer = multiplexer
self.dlci = dlci self.dlci = dlci
self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS self.max_frame_size = max_frame_size
self.rx_threshold = self.rx_credits // 2 self.window_size = window_size
self.tx_credits = initial_tx_credits self.rx_credits = window_size
self.rx_threshold = window_size // 2
self.tx_credits = window_size
self.tx_buffer = b'' self.tx_buffer = b''
self.state = DLC.State.INIT self.state = DLC.State.INIT
self.role = multiplexer.role self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0 self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
self.sink = None self.sink = None
self.connection_result = None self.connection_result = None
self.drained = asyncio.Event()
self.drained.set()
# Compute the MTU # Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs max_overhead = 4 + 1 # header with 2-byte length + fcs
@@ -537,11 +541,11 @@ class DLC(EventEmitter):
if len(data) and self.sink: if len(data) and self.sink:
self.sink(data) # pylint: disable=not-callable self.sink(data) # pylint: disable=not-callable
# Update the credits # Update the credits
if self.rx_credits > 0: if self.rx_credits > 0:
self.rx_credits -= 1 self.rx_credits -= 1
else: else:
logger.warning(color('!!! received frame with no rx credits', 'red')) logger.warning(color('!!! received frame with no rx credits', 'red'))
# Check if there's anything to send (including credits) # Check if there's anything to send (including credits)
self.process_tx() self.process_tx()
@@ -580,9 +584,9 @@ class DLC(EventEmitter):
cl=0xE0, cl=0xE0,
priority=7, priority=7,
ack_timer=0, ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU, max_frame_size=self.max_frame_size,
max_retransmissions=0, max_retransmissions=0,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS, window_size=self.window_size,
) )
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn)) mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
logger.debug(f'>>> PN Response: {pn}') logger.debug(f'>>> PN Response: {pn}')
@@ -591,7 +595,7 @@ class DLC(EventEmitter):
def rx_credits_needed(self) -> int: def rx_credits_needed(self) -> int:
if self.rx_credits <= self.rx_threshold: if self.rx_credits <= self.rx_threshold:
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits return self.window_size - self.rx_credits
return 0 return 0
@@ -631,6 +635,8 @@ class DLC(EventEmitter):
) )
rx_credits_needed = 0 rx_credits_needed = 0
if not self.tx_buffer:
self.drained.set()
# Stream protocol # Stream protocol
def write(self, data: Union[bytes, str]) -> None: def write(self, data: Union[bytes, str]) -> None:
@@ -643,11 +649,11 @@ class DLC(EventEmitter):
raise ValueError('write only accept bytes or strings') raise ValueError('write only accept bytes or strings')
self.tx_buffer += data self.tx_buffer += data
self.drained.clear()
self.process_tx() self.process_tx()
def drain(self) -> None: async def drain(self) -> None:
# TODO await self.drained.wait()
pass
def __str__(self) -> str: def __str__(self) -> str:
return f'DLC(dlci={self.dlci},state={self.state.name})' return f'DLC(dlci={self.dlci},state={self.state.name})'
@@ -843,7 +849,12 @@ class Multiplexer(EventEmitter):
) )
await self.disconnection_result await self.disconnection_result
async def open_dlc(self, channel: int) -> DLC: async def open_dlc(
self,
channel: int,
max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
window_size: int = RFCOMM_DEFAULT_WINDOW_SIZE,
) -> DLC:
if self.state != Multiplexer.State.CONNECTED: if self.state != Multiplexer.State.CONNECTED:
if self.state == Multiplexer.State.OPENING: if self.state == Multiplexer.State.OPENING:
raise InvalidStateError('open already in progress') raise InvalidStateError('open already in progress')
@@ -855,9 +866,9 @@ class Multiplexer(EventEmitter):
cl=0xF0, cl=0xF0,
priority=7, priority=7,
ack_timer=0, ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU, max_frame_size=max_frame_size,
max_retransmissions=0, max_retransmissions=0,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS, window_size=window_size,
) )
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn)) mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
logger.debug(f'>>> Sending MCC: {pn}') logger.debug(f'>>> Sending MCC: {pn}')
+49 -21
View File
@@ -18,6 +18,7 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import logging import logging
import os import os
from typing import Optional
from .common import Transport, AsyncPipeSink, SnoopingTransport from .common import Transport, AsyncPipeSink, SnoopingTransport
from ..snoop import create_snooper from ..snoop import create_snooper
@@ -52,8 +53,16 @@ def _wrap_transport(transport: Transport) -> Transport:
async def open_transport(name: str) -> Transport: async def open_transport(name: str) -> Transport:
""" """
Open a transport by name. Open a transport by name.
The name must be <type>:<parameters> The name must be <type>:<metadata><parameters>
Where <parameters> depend on the type (and may be empty for some types). Where <parameters> depend on the type (and may be empty for some types), and
<metadata> is either omitted, or a ,-separated list of <key>=<value> pairs,
enclosed in [].
If there are not metadata or parameter, the : after the <type> may be omitted.
Examples:
* usb:0
* usb:[driver=rtk]0
* android-netsim
The supported types are: The supported types are:
* serial * serial
* udp * udp
@@ -71,87 +80,106 @@ async def open_transport(name: str) -> Transport:
* android-netsim * android-netsim
""" """
return _wrap_transport(await _open_transport(name)) scheme, *tail = name.split(':', 1)
spec = tail[0] if tail else None
if spec:
# Metadata may precede the spec
if spec.startswith('['):
metadata_str, *tail = spec[1:].split(']')
spec = tail[0] if tail else None
metadata = dict([entry.split('=') for entry in metadata_str.split(',')])
else:
metadata = None
transport = await _open_transport(scheme, spec)
if metadata:
transport.source.metadata = { # type: ignore[attr-defined]
**metadata,
**getattr(transport.source, 'metadata', {}),
}
# pylint: disable=line-too-long
logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined]
return _wrap_transport(transport)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def _open_transport(name: str) -> Transport: async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
# pylint: disable=import-outside-toplevel # pylint: disable=import-outside-toplevel
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
scheme, *spec = name.split(':', 1)
if scheme == 'serial' and spec: if scheme == 'serial' and spec:
from .serial import open_serial_transport from .serial import open_serial_transport
return await open_serial_transport(spec[0]) return await open_serial_transport(spec)
if scheme == 'udp' and spec: if scheme == 'udp' and spec:
from .udp import open_udp_transport from .udp import open_udp_transport
return await open_udp_transport(spec[0]) return await open_udp_transport(spec)
if scheme == 'tcp-client' and spec: if scheme == 'tcp-client' and spec:
from .tcp_client import open_tcp_client_transport from .tcp_client import open_tcp_client_transport
return await open_tcp_client_transport(spec[0]) return await open_tcp_client_transport(spec)
if scheme == 'tcp-server' and spec: if scheme == 'tcp-server' and spec:
from .tcp_server import open_tcp_server_transport from .tcp_server import open_tcp_server_transport
return await open_tcp_server_transport(spec[0]) return await open_tcp_server_transport(spec)
if scheme == 'ws-client' and spec: if scheme == 'ws-client' and spec:
from .ws_client import open_ws_client_transport from .ws_client import open_ws_client_transport
return await open_ws_client_transport(spec[0]) return await open_ws_client_transport(spec)
if scheme == 'ws-server' and spec: if scheme == 'ws-server' and spec:
from .ws_server import open_ws_server_transport from .ws_server import open_ws_server_transport
return await open_ws_server_transport(spec[0]) return await open_ws_server_transport(spec)
if scheme == 'pty': if scheme == 'pty':
from .pty import open_pty_transport from .pty import open_pty_transport
return await open_pty_transport(spec[0] if spec else None) return await open_pty_transport(spec)
if scheme == 'file': if scheme == 'file':
from .file import open_file_transport from .file import open_file_transport
assert spec is not None assert spec is not None
return await open_file_transport(spec[0]) return await open_file_transport(spec)
if scheme == 'vhci': if scheme == 'vhci':
from .vhci import open_vhci_transport from .vhci import open_vhci_transport
return await open_vhci_transport(spec[0] if spec else None) return await open_vhci_transport(spec)
if scheme == 'hci-socket': if scheme == 'hci-socket':
from .hci_socket import open_hci_socket_transport from .hci_socket import open_hci_socket_transport
return await open_hci_socket_transport(spec[0] if spec else None) return await open_hci_socket_transport(spec)
if scheme == 'usb': if scheme == 'usb':
from .usb import open_usb_transport from .usb import open_usb_transport
assert spec is not None assert spec
return await open_usb_transport(spec[0]) return await open_usb_transport(spec)
if scheme == 'pyusb': if scheme == 'pyusb':
from .pyusb import open_pyusb_transport from .pyusb import open_pyusb_transport
assert spec is not None assert spec
return await open_pyusb_transport(spec[0]) return await open_pyusb_transport(spec)
if scheme == 'android-emulator': if scheme == 'android-emulator':
from .android_emulator import open_android_emulator_transport from .android_emulator import open_android_emulator_transport
return await open_android_emulator_transport(spec[0] if spec else None) return await open_android_emulator_transport(spec)
if scheme == 'android-netsim': if scheme == 'android-netsim':
from .android_netsim import open_android_netsim_transport from .android_netsim import open_android_netsim_transport
return await open_android_netsim_transport(spec[0] if spec else None) return await open_android_netsim_transport(spec)
raise ValueError('unknown transport scheme') raise ValueError('unknown transport scheme')
+1 -1
View File
@@ -69,7 +69,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
mode = 'host' mode = 'host'
server_host = 'localhost' server_host = 'localhost'
server_port = '8554' server_port = '8554'
if spec is not None: if spec:
params = spec.split(',') params = spec.split(',')
for param in params: for param in params:
if param.startswith('mode='): if param.startswith('mode='):
+1 -1
View File
@@ -21,7 +21,7 @@ import struct
import asyncio import asyncio
import logging import logging
import io import io
from typing import ContextManager, Tuple, Optional, Protocol, Dict from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
from bumble import hci from bumble import hci
from bumble.colors import color from bumble.colors import color
+1 -4
View File
@@ -59,10 +59,7 @@ async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
) from error ) from error
# Compute the adapter index # Compute the adapter index
if spec is None: adapter_index = int(spec) if spec else 0
adapter_index = 0
else:
adapter_index = int(spec)
# Bind the socket # Bind the socket
# NOTE: since Python doesn't support binding with the required address format (yet), # NOTE: since Python doesn't support binding with the required address format (yet),
+1 -1
View File
@@ -108,7 +108,7 @@ async def open_usb_transport(spec: str) -> Transport:
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER, USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
) )
READ_SIZE = 1024 READ_SIZE = 4096
class UsbPacketSink: class UsbPacketSink:
def __init__(self, device, acl_out): def __init__(self, device, acl_out):
+3 -6
View File
@@ -280,17 +280,14 @@ class AsyncRunner:
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
coroutine = func(*args, **kwargs) coroutine = func(*args, **kwargs)
if queue is None: if queue is None:
# Create a task to run the coroutine # Spawn the coroutine as a task
async def run(): async def run():
try: try:
await coroutine await coroutine
except Exception: except Exception:
logger.warning( logger.exception(color("!!! Exception in wrapper:", "red"))
f'{color("!!! Exception in wrapper:", "red")} '
f'{traceback.format_exc()}'
)
asyncio.create_task(run()) AsyncRunner.spawn(run())
else: else:
# Queue the coroutine to be awaited by the work queue # Queue the coroutine to be awaited by the work queue
queue.enqueue(coroutine) queue.enqueue(coroutine)
+30 -9
View File
@@ -7,16 +7,36 @@ throughput and/or latency between two devices.
# General Usage # General Usage
``` ```
Usage: bench.py [OPTIONS] COMMAND [ARGS]... Usage: bumble-bench [OPTIONS] COMMAND [ARGS]...
Options: Options:
--device-config FILENAME Device configuration file --device-config FILENAME Device configuration file
--role [sender|receiver|ping|pong] --role [sender|receiver|ping|pong]
--mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server] --mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server]
--att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517] --att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517]
-s, --packet-size SIZE Packet size (server role) [8<=x<=4096] --extended-data-length TEXT Request a data length upon connection,
-c, --packet-count COUNT Packet count (server role) specified as tx_octets/tx_time
-sd, --start-delay SECONDS Start delay (server role) --rfcomm-channel INTEGER RFComm channel to use
--rfcomm-uuid TEXT RFComm service UUID to use (ignored if
--rfcomm-channel is not 0)
--l2cap-psm INTEGER L2CAP PSM to use
--l2cap-mtu INTEGER L2CAP MTU to use
--l2cap-mps INTEGER L2CAP MPS to use
--l2cap-max-credits INTEGER L2CAP maximum number of credits allowed for
the peer
-s, --packet-size SIZE Packet size (client or ping role)
[8<=x<=4096]
-c, --packet-count COUNT Packet count (client or ping role)
-sd, --start-delay SECONDS Start delay (client or ping role)
--repeat N Repeat the run N times (client and ping
roles)(0, which is the fault, to run just
once)
--repeat-delay SECONDS Delay, in seconds, between repeats
--pace MILLISECONDS Wait N milliseconds between packets (0,
which is the fault, to send as fast as
possible)
--linger Don't exit at the end of a run (server and
pong roles)
--help Show this message and exit. --help Show this message and exit.
Commands: Commands:
@@ -35,17 +55,18 @@ Options:
--connection-interval, --ci CONNECTION_INTERVAL --connection-interval, --ci CONNECTION_INTERVAL
Connection interval (in ms) Connection interval (in ms)
--phy [1m|2m|coded] PHY to use --phy [1m|2m|coded] PHY to use
--authenticate Authenticate (RFComm only)
--encrypt Encrypt the connection (RFComm only)
--help Show this message and exit. --help Show this message and exit.
``` ```
To test once device against another, one of the two devices must be running
To test once device against another, one of the two devices must be running
the ``peripheral`` command and the other the ``central`` command. The device the ``peripheral`` command and the other the ``central`` command. The device
running the ``peripheral`` command will accept connections from the device running the ``peripheral`` command will accept connections from the device
running the ``central`` command. running the ``central`` command.
When using Bluetooth LE (all modes except for ``rfcomm-server`` and ``rfcomm-client``utils), When using Bluetooth LE (all modes except for ``rfcomm-server`` and ``rfcomm-client``utils),
the default addresses configured in the tool should be sufficient. But when using the default addresses configured in the tool should be sufficient. But when using
Bluetooth Classic, the address of the Peripheral must be specified on the Central Bluetooth Classic, the address of the Peripheral must be specified on the Central
using the ``--peripheral`` option. The address will be printed by the Peripheral when using the ``--peripheral`` option. The address will be printed by the Peripheral when
it starts. it starts.
@@ -83,7 +104,7 @@ the other on `usb:1`, and two consoles/terminals. We will run a command in each.
$ bumble-bench central usb:1 $ bumble-bench central usb:1
``` ```
In this default configuration, the Central runs a Sender, as a GATT client, In this default configuration, the Central runs a Sender, as a GATT client,
connecting to the Peripheral running a Receiver, as a GATT server. connecting to the Peripheral running a Receiver, as a GATT server.
!!! example "L2CAP Throughput" !!! example "L2CAP Throughput"
+9
View File
@@ -5,6 +5,15 @@ Some Bluetooth controllers require a driver to function properly.
This may include, for instance, loading a Firmware image or patch, This may include, for instance, loading a Firmware image or patch,
loading a configuration. loading a configuration.
By default, drivers will be automatically probed to determine if they should be
used with particular HCI controller.
When the transport for an HCI controller is instantiated from a transport name,
a driver may also be forced by specifying ``driver=<driver-name>`` in the optional
metadata portion of the transport name. For example,
``usb:[driver=-rtk]0`` indicates that the ``rtk`` driver should be used with the
first USB device, even if a normal probe would not have selected it based on the
USB vendor ID and product ID.
Drivers included in the module are: Drivers included in the module are:
* [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles. * [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles.
+5 -2
View File
@@ -1,13 +1,16 @@
REALTEK DRIVER REALTEK DRIVER
============== ==============
This driver supports loading firmware images and optional config data to This driver supports loading firmware images and optional config data to
USB dongles with a Realtek chipset. USB dongles with a Realtek chipset.
A number of USB dongles are supported, but likely not all. A number of USB dongles are supported, but likely not all.
When using a USB dongle, the USB product ID and manufacturer ID are used When using a USB dongle, the USB product ID and vendor ID are used
to find whether a matching set of firmware image and config data to find whether a matching set of firmware image and config data
is needed for that specific model. If a match exists, the driver will try is needed for that specific model. If a match exists, the driver will try
load the firmware image and, if needed, config data. load the firmware image and, if needed, config data.
Alternatively, the metadata property ``driver=rtk`` may be specified in a transport
name to force that driver to be used (ex: ``usb:[driver=rtk]0`` instead of just
``usb:0`` for the first USB device).
The driver will look for those files by name, in order, in: The driver will look for those files by name, in order, in:
* The directory specified by the environment variable `BUMBLE_RTK_FIRMWARE_DIR` * The directory specified by the environment variable `BUMBLE_RTK_FIRMWARE_DIR`
+5
View File
@@ -0,0 +1,5 @@
{
"name": "Bumble HID Keyboard",
"class_of_device": 9664,
"keystore": "JsonKeyStore"
}
+3 -3
View File
@@ -40,9 +40,9 @@
} }
} }
function onMouseMove(event) { function onMouseMove(event) {
//console.log(event.clientX, event.clientY) //console.log(event.movementX, event.movementY)
mouseInfo.innerText = `MOUSE: x=${event.clientX}, y=${event.clientY}` mouseInfo.innerText = `MOUSE: x=${event.movementX}, y=${event.movementY}`
send({ type:'mousemove', x: event.clientX, y: event.clientY }) send({ type:'mousemove', x: event.movementX, y: event.movementY })
} }
function onKeyDown(event) { function onKeyDown(event) {
+1
View File
@@ -1,5 +1,6 @@
{ {
"name": "Bumble-LEA", "name": "Bumble-LEA",
"keystore": "JsonKeyStore", "keystore": "JsonKeyStore",
"address": "F0:F1:F2:F3:F4:FA",
"advertising_interval": 100 "advertising_interval": 100
} }
+116
View File
@@ -0,0 +1,116 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import sys
import os
import secrets
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble.hci import (
Address,
OwnAddressType,
HCI_LE_Set_Extended_Advertising_Parameters_Command,
)
from bumble.profiles.cap import CommonAudioServiceService
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def main() -> None:
if len(sys.argv) < 3:
print(
'Usage: run_cig_setup.py <config-file>'
'<transport-spec-for-device-1> <transport-spec-for-device-2>'
)
print(
'example: run_cig_setup.py device1.json'
'tcp-client:127.0.0.1:6402 tcp-client:127.0.0.1:6402'
)
return
print('<<< connecting to HCI...')
hci_transports = await asyncio.gather(
open_transport_or_link(sys.argv[2]), open_transport_or_link(sys.argv[3])
)
print('<<< connected')
devices = [
Device.from_config_file_with_hci(
sys.argv[1], hci_transport.source, hci_transport.sink
)
for hci_transport in hci_transports
]
sirk = secrets.token_bytes(16)
for i, device in enumerate(devices):
device.random_address = Address(secrets.token_bytes(6))
await device.power_on()
csis = CoordinatedSetIdentificationService(
set_identity_resolving_key=sirk,
set_identity_resolving_key_type=SirkType.PLAINTEXT,
coordinated_set_size=2,
)
device.add_service(CommonAudioServiceService(csis))
advertising_data = (
bytes(
AdvertisingData(
[
(
AdvertisingData.COMPLETE_LOCAL_NAME,
bytes(f'Bumble LE Audio-{i}', 'utf-8'),
),
(
AdvertisingData.FLAGS,
bytes(
[
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
]
),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(CoordinatedSetIdentificationService.UUID),
),
]
)
)
+ csis.get_advertising_data()
)
await device.start_extended_advertising(
advertising_properties=(
HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING
),
own_address_type=OwnAddressType.RANDOM,
advertising_data=advertising_data,
)
await asyncio.gather(
*[hci_transport.source.terminated for hci_transport in hci_transports]
)
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())
+748
View File
@@ -0,0 +1,748 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import sys
import os
import logging
import json
import websockets
from bumble.colors import color
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.core import (
BT_BR_EDR_TRANSPORT,
BT_L2CAP_PROTOCOL_ID,
BT_HUMAN_INTERFACE_DEVICE_SERVICE,
BT_HIDP_PROTOCOL_ID,
UUID,
)
from bumble.hci import Address
from bumble.hid import (
Device as HID_Device,
HID_CONTROL_PSM,
HID_INTERRUPT_PSM,
Message,
)
from bumble.sdp import (
Client as SDP_Client,
DataElement,
ServiceAttribute,
SDP_PUBLIC_BROWSE_ROOT,
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_ALL_ATTRIBUTES_RANGE,
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
)
from bumble.utils import AsyncRunner
# -----------------------------------------------------------------------------
# SDP attributes for Bluetooth HID devices
SDP_HID_SERVICE_NAME_ATTRIBUTE_ID = 0x0100
SDP_HID_SERVICE_DESCRIPTION_ATTRIBUTE_ID = 0x0101
SDP_HID_PROVIDER_NAME_ATTRIBUTE_ID = 0x0102
SDP_HID_DEVICE_RELEASE_NUMBER_ATTRIBUTE_ID = 0x0200 # [DEPRECATED]
SDP_HID_PARSER_VERSION_ATTRIBUTE_ID = 0x0201
SDP_HID_DEVICE_SUBCLASS_ATTRIBUTE_ID = 0x0202
SDP_HID_COUNTRY_CODE_ATTRIBUTE_ID = 0x0203
SDP_HID_VIRTUAL_CABLE_ATTRIBUTE_ID = 0x0204
SDP_HID_RECONNECT_INITIATE_ATTRIBUTE_ID = 0x0205
SDP_HID_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0x0206
SDP_HID_LANGID_BASE_LIST_ATTRIBUTE_ID = 0x0207
SDP_HID_SDP_DISABLE_ATTRIBUTE_ID = 0x0208 # [DEPRECATED]
SDP_HID_BATTERY_POWER_ATTRIBUTE_ID = 0x0209
SDP_HID_REMOTE_WAKE_ATTRIBUTE_ID = 0x020A
SDP_HID_PROFILE_VERSION_ATTRIBUTE_ID = 0x020B # DEPRECATED]
SDP_HID_SUPERVISION_TIMEOUT_ATTRIBUTE_ID = 0x020C
SDP_HID_NORMALLY_CONNECTABLE_ATTRIBUTE_ID = 0x020D
SDP_HID_BOOT_DEVICE_ATTRIBUTE_ID = 0x020E
SDP_HID_SSR_HOST_MAX_LATENCY_ATTRIBUTE_ID = 0x020F
SDP_HID_SSR_HOST_MIN_TIMEOUT_ATTRIBUTE_ID = 0x0210
# Refer to HID profile specification v1.1.1, "5.3 Service Discovery Protocol (SDP)" for details
# HID SDP attribute values
LANGUAGE = 0x656E # 0x656E uint16 “en” (English)
ENCODING = 0x6A # 0x006A uint16 UTF-8 encoding
PRIMARY_LANGUAGE_BASE_ID = 0x100 # 0x0100 uint16 PrimaryLanguageBaseID
VERSION_NUMBER = 0x0101 # 0x0101 uint16 version number (v1.1)
SERVICE_NAME = b'Bumble HID'
SERVICE_DESCRIPTION = b'Bumble'
PROVIDER_NAME = b'Bumble'
HID_PARSER_VERSION = 0x0111 # uint16 0x0111 (v1.1.1)
HID_DEVICE_SUBCLASS = 0xC0 # Combo keyboard/pointing device
HID_COUNTRY_CODE = 0x21 # 0x21 Uint8, USA
HID_VIRTUAL_CABLE = True # Virtual cable enabled
HID_RECONNECT_INITIATE = True # Reconnect initiate enabled
REPORT_DESCRIPTOR_TYPE = 0x22 # 0x22 Type = Report Descriptor
HID_LANGID_BASE_LANGUAGE = 0x0409 # 0x0409 Language = English (United States)
HID_LANGID_BASE_BLUETOOTH_STRING_OFFSET = 0x100 # 0x0100 Default
HID_BATTERY_POWER = True # Battery power enabled
HID_REMOTE_WAKE = True # Remote wake enabled
HID_SUPERVISION_TIMEOUT = 0xC80 # uint16 0xC80 (2s)
HID_NORMALLY_CONNECTABLE = True # Normally connectable enabled
HID_BOOT_DEVICE = True # Boot device support enabled
HID_SSR_HOST_MAX_LATENCY = 0x640 # uint16 0x640 (1s)
HID_SSR_HOST_MIN_TIMEOUT = 0xC80 # uint16 0xC80 (2s)
HID_REPORT_MAP = bytes( # Text String, 50 Octet Report Descriptor
# pylint: disable=line-too-long
[
0x05,
0x01, # Usage Page (Generic Desktop Ctrls)
0x09,
0x06, # Usage (Keyboard)
0xA1,
0x01, # Collection (Application)
0x85,
0x01, # . Report ID (1)
0x05,
0x07, # . Usage Page (Kbrd/Keypad)
0x19,
0xE0, # . Usage Minimum (0xE0)
0x29,
0xE7, # . Usage Maximum (0xE7)
0x15,
0x00, # . Logical Minimum (0)
0x25,
0x01, # . Logical Maximum (1)
0x75,
0x01, # . Report Size (1)
0x95,
0x08, # . Report Count (8)
0x81,
0x02, # . Input (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95,
0x01, # . Report Count (1)
0x75,
0x08, # . Report Size (8)
0x81,
0x03, # . Input (Const,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95,
0x05, # . Report Count (5)
0x75,
0x01, # . Report Size (1)
0x05,
0x08, # . Usage Page (LEDs)
0x19,
0x01, # . Usage Minimum (Num Lock)
0x29,
0x05, # . Usage Maximum (Kana)
0x91,
0x02, # . Output (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
0x95,
0x01, # . Report Count (1)
0x75,
0x03, # . Report Size (3)
0x91,
0x03, # . Output (Const,Var,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
0x95,
0x06, # . Report Count (6)
0x75,
0x08, # . Report Size (8)
0x15,
0x00, # . Logical Minimum (0)
0x25,
0x65, # . Logical Maximum (101)
0x05,
0x07, # . Usage Page (Kbrd/Keypad)
0x19,
0x00, # . Usage Minimum (0x00)
0x29,
0x65, # . Usage Maximum (0x65)
0x81,
0x00, # . Input (Data,Array,Abs,No Wrap,Linear,Preferred State,No Null Position)
0xC0, # End Collection
0x05,
0x01, # Usage Page (Generic Desktop Ctrls)
0x09,
0x02, # Usage (Mouse)
0xA1,
0x01, # Collection (Application)
0x85,
0x02, # . Report ID (2)
0x09,
0x01, # . Usage (Pointer)
0xA1,
0x00, # . Collection (Physical)
0x05,
0x09, # . Usage Page (Button)
0x19,
0x01, # . Usage Minimum (0x01)
0x29,
0x03, # . Usage Maximum (0x03)
0x15,
0x00, # . Logical Minimum (0)
0x25,
0x01, # . Logical Maximum (1)
0x95,
0x03, # . Report Count (3)
0x75,
0x01, # . Report Size (1)
0x81,
0x02, # . Input (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x95,
0x01, # . Report Count (1)
0x75,
0x05, # . Report Size (5)
0x81,
0x03, # . Input (Const,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
0x05,
0x01, # . Usage Page (Generic Desktop Ctrls)
0x09,
0x30, # . Usage (X)
0x09,
0x31, # . Usage (Y)
0x15,
0x81, # . Logical Minimum (-127)
0x25,
0x7F, # . Logical Maximum (127)
0x75,
0x08, # . Report Size (8)
0x95,
0x02, # . Report Count (2)
0x81,
0x06, # . Input (Data,Var,Rel,No Wrap,Linear,Preferred State,No Null Position)
0xC0, # . End Collection
0xC0, # End Collection
]
)
# Default protocol mode set to report protocol
protocol_mode = Message.ProtocolMode.REPORT_PROTOCOL
# -----------------------------------------------------------------------------
def sdp_records():
service_record_handle = 0x00010002
return {
service_record_handle: [
ServiceAttribute(
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(service_record_handle),
),
ServiceAttribute(
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
),
ServiceAttribute(
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[DataElement.uuid(BT_HUMAN_INTERFACE_DEVICE_SERVICE)]
),
),
ServiceAttribute(
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(HID_CONTROL_PSM),
]
),
DataElement.sequence(
[
DataElement.uuid(BT_HIDP_PROTOCOL_ID),
]
),
]
),
),
ServiceAttribute(
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.unsigned_integer_16(LANGUAGE),
DataElement.unsigned_integer_16(ENCODING),
DataElement.unsigned_integer_16(PRIMARY_LANGUAGE_BASE_ID),
]
),
),
ServiceAttribute(
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_HUMAN_INTERFACE_DEVICE_SERVICE),
DataElement.unsigned_integer_16(VERSION_NUMBER),
]
),
]
),
),
ServiceAttribute(
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
DataElement.unsigned_integer_16(
HID_INTERRUPT_PSM
),
]
),
DataElement.sequence(
[
DataElement.uuid(BT_HIDP_PROTOCOL_ID),
]
),
]
),
]
),
),
ServiceAttribute(
SDP_HID_SERVICE_NAME_ATTRIBUTE_ID,
DataElement(DataElement.TEXT_STRING, SERVICE_NAME),
),
ServiceAttribute(
SDP_HID_SERVICE_DESCRIPTION_ATTRIBUTE_ID,
DataElement(DataElement.TEXT_STRING, SERVICE_DESCRIPTION),
),
ServiceAttribute(
SDP_HID_PROVIDER_NAME_ATTRIBUTE_ID,
DataElement(DataElement.TEXT_STRING, PROVIDER_NAME),
),
ServiceAttribute(
SDP_HID_PARSER_VERSION_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(HID_PARSER_VERSION),
),
ServiceAttribute(
SDP_HID_DEVICE_SUBCLASS_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(HID_DEVICE_SUBCLASS),
),
ServiceAttribute(
SDP_HID_COUNTRY_CODE_ATTRIBUTE_ID,
DataElement.unsigned_integer_32(HID_COUNTRY_CODE),
),
ServiceAttribute(
SDP_HID_VIRTUAL_CABLE_ATTRIBUTE_ID,
DataElement.boolean(HID_VIRTUAL_CABLE),
),
ServiceAttribute(
SDP_HID_RECONNECT_INITIATE_ATTRIBUTE_ID,
DataElement.boolean(HID_RECONNECT_INITIATE),
),
ServiceAttribute(
SDP_HID_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.unsigned_integer_16(REPORT_DESCRIPTOR_TYPE),
DataElement(DataElement.TEXT_STRING, HID_REPORT_MAP),
]
),
]
),
),
ServiceAttribute(
SDP_HID_LANGID_BASE_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.unsigned_integer_16(
HID_LANGID_BASE_LANGUAGE
),
DataElement.unsigned_integer_16(
HID_LANGID_BASE_BLUETOOTH_STRING_OFFSET
),
]
),
]
),
),
ServiceAttribute(
SDP_HID_BATTERY_POWER_ATTRIBUTE_ID,
DataElement.boolean(HID_BATTERY_POWER),
),
ServiceAttribute(
SDP_HID_REMOTE_WAKE_ATTRIBUTE_ID,
DataElement.boolean(HID_REMOTE_WAKE),
),
ServiceAttribute(
SDP_HID_SUPERVISION_TIMEOUT_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(HID_SUPERVISION_TIMEOUT),
),
ServiceAttribute(
SDP_HID_NORMALLY_CONNECTABLE_ATTRIBUTE_ID,
DataElement.boolean(HID_NORMALLY_CONNECTABLE),
),
ServiceAttribute(
SDP_HID_BOOT_DEVICE_ATTRIBUTE_ID,
DataElement.boolean(HID_BOOT_DEVICE),
),
ServiceAttribute(
SDP_HID_SSR_HOST_MAX_LATENCY_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(HID_SSR_HOST_MAX_LATENCY),
),
ServiceAttribute(
SDP_HID_SSR_HOST_MIN_TIMEOUT_ATTRIBUTE_ID,
DataElement.unsigned_integer_16(HID_SSR_HOST_MIN_TIMEOUT),
),
]
}
# -----------------------------------------------------------------------------
async def get_stream_reader(pipe) -> asyncio.StreamReader:
loop = asyncio.get_event_loop()
reader = asyncio.StreamReader(loop=loop)
protocol = asyncio.StreamReaderProtocol(reader)
await loop.connect_read_pipe(lambda: protocol, pipe)
return reader
class DeviceData:
def __init__(self) -> None:
self.keyboardData = bytearray(
[0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
)
self.mouseData = bytearray([0x02, 0x00, 0x00, 0x00])
# Device's live data - Mouse and Keyboard will be stored in this
deviceData = DeviceData()
# -----------------------------------------------------------------------------
async def keyboard_device(hid_device):
# Start a Websocket server to receive events from a web page
async def serve(websocket, _path):
global deviceData
while True:
try:
message = await websocket.recv()
print('Received: ', str(message))
parsed = json.loads(message)
message_type = parsed['type']
if message_type == 'keydown':
# Only deal with keys a to z for now
key = parsed['key']
if len(key) == 1:
code = ord(key)
if ord('a') <= code <= ord('z'):
hid_code = 0x04 + code - ord('a')
deviceData.keyboardData = bytearray(
[
0x01,
0x00,
0x00,
hid_code,
0x00,
0x00,
0x00,
0x00,
0x00,
]
)
hid_device.send_data(deviceData.keyboardData)
elif message_type == 'keyup':
deviceData.keyboardData = bytearray(
[0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
)
hid_device.send_data(deviceData.keyboardData)
elif message_type == "mousemove":
# logical min and max values
log_min = -127
log_max = 127
x = parsed['x']
y = parsed['y']
# limiting x and y values within logical max and min range
x = max(log_min, min(log_max, x))
y = max(log_min, min(log_max, y))
x_cord = x.to_bytes(signed=True)
y_cord = y.to_bytes(signed=True)
deviceData.mouseData = bytearray([0x02, 0x00]) + x_cord + y_cord
hid_device.send_data(deviceData.mouseData)
except websockets.exceptions.ConnectionClosedOK:
pass
# pylint: disable-next=no-member
await websockets.serve(serve, 'localhost', 8989)
await asyncio.get_event_loop().create_future()
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 3:
print(
'Usage: python run_hid_device.py <device-config> <transport-spec> <command>'
' where <command> is one of:\n'
' test-mode (run with menu enabled for testing)\n'
' web (run a keyboard with keypress input from a web page, '
'see keyboard.html'
)
print('example: python run_hid_device.py hid_keyboard.json usb:0 web')
print('example: python run_hid_device.py hid_keyboard.json usb:0 test-mode')
return
async def handle_virtual_cable_unplug():
hid_host_bd_addr = str(hid_device.remote_device_bd_address)
await hid_device.disconnect_interrupt_channel()
await hid_device.disconnect_control_channel()
await device.keystore.delete(hid_host_bd_addr) # type: ignore
connection = hid_device.connection
if connection is not None:
await connection.disconnect()
def on_hid_data_cb(pdu: bytes):
print(f'Received Data, PDU: {pdu.hex()}')
def on_get_report_cb(report_id: int, report_type: int, buffer_size: int):
retValue = hid_device.GetSetStatus()
print(
"GET_REPORT report_id: "
+ str(report_id)
+ "report_type: "
+ str(report_type)
+ "buffer_size:"
+ str(buffer_size)
)
if report_type == Message.ReportType.INPUT_REPORT:
if report_id == 1:
retValue.data = deviceData.keyboardData[1:]
retValue.status = hid_device.GetSetReturn.SUCCESS
elif report_id == 2:
retValue.data = deviceData.mouseData[1:]
retValue.status = hid_device.GetSetReturn.SUCCESS
else:
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
if buffer_size:
data_len = buffer_size - 1
retValue.data = retValue.data[:data_len]
elif report_type == Message.ReportType.OUTPUT_REPORT:
# This sample app has nothing to do with the report received, to enable PTS
# testing, we will return single byte random data.
retValue.data = bytearray([0x11])
retValue.status = hid_device.GetSetReturn.SUCCESS
elif report_type == Message.ReportType.FEATURE_REPORT:
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_type == Message.ReportType.OTHER_REPORT:
if report_id == 3:
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
else:
retValue.status = hid_device.GetSetReturn.FAILURE
return retValue
def on_set_report_cb(
report_id: int, report_type: int, report_size: int, data: bytes
):
retValue = hid_device.GetSetStatus()
print(
"SET_REPORT report_id: "
+ str(report_id)
+ "report_type: "
+ str(report_type)
+ "report_size "
+ str(report_size)
+ "data:"
+ str(data)
)
if report_type == Message.ReportType.FEATURE_REPORT:
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_type == Message.ReportType.INPUT_REPORT:
if report_id == 1 and report_size != len(deviceData.keyboardData):
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_id == 2 and report_size != len(deviceData.mouseData):
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
elif report_id == 3:
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
else:
retValue.status = hid_device.GetSetReturn.SUCCESS
else:
retValue.status = hid_device.GetSetReturn.SUCCESS
return retValue
def on_get_protocol_cb():
retValue = hid_device.GetSetStatus()
retValue.data = protocol_mode.to_bytes()
retValue.status = hid_device.GetSetReturn.SUCCESS
return retValue
def on_set_protocol_cb(protocol: int):
retValue = hid_device.GetSetStatus()
# We do not support SET_PROTOCOL.
print(f"SET_PROTOCOL report_id: {protocol}")
retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
return retValue
def on_virtual_cable_unplug_cb():
print('Received Virtual Cable Unplug')
asyncio.create_task(handle_virtual_cable_unplug())
print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
print('<<< connected')
# Create a device
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
device.classic_enabled = True
# Create and register HID device
hid_device = HID_Device(device)
# Register for call backs
hid_device.on('interrupt_data', on_hid_data_cb)
hid_device.register_get_report_cb(on_get_report_cb)
hid_device.register_set_report_cb(on_set_report_cb)
hid_device.register_get_protocol_cb(on_get_protocol_cb)
hid_device.register_set_protocol_cb(on_set_protocol_cb)
# Register for virtual cable unplug call back
hid_device.on('virtual_cable_unplug', on_virtual_cable_unplug_cb)
# Setup the SDP to advertise HID Device service
device.sdp_service_records = sdp_records()
# Start the controller
await device.power_on()
# Start being discoverable and connectable
await device.set_discoverable(True)
await device.set_connectable(True)
async def menu():
reader = await get_stream_reader(sys.stdin)
while True:
print(
"\n************************ HID Device Menu *****************************\n"
)
print(" 1. Connect Control Channel")
print(" 2. Connect Interrupt Channel")
print(" 3. Disconnect Control Channel")
print(" 4. Disconnect Interrupt Channel")
print(" 5. Send Report on Interrupt Channel")
print(" 6. Virtual Cable Unplug")
print(" 7. Disconnect device")
print(" 8. Delete Bonding")
print(" 9. Re-connect to device")
print("10. Exit ")
print("\nEnter your choice : \n")
choice = await reader.readline()
choice = choice.decode('utf-8').strip()
if choice == '1':
await hid_device.connect_control_channel()
elif choice == '2':
await hid_device.connect_interrupt_channel()
elif choice == '3':
await hid_device.disconnect_control_channel()
elif choice == '4':
await hid_device.disconnect_interrupt_channel()
elif choice == '5':
print(" 1. Report ID 0x01")
print(" 2. Report ID 0x02")
print(" 3. Invalid Report ID")
choice1 = await reader.readline()
choice1 = choice1.decode('utf-8').strip()
if choice1 == '1':
data = bytearray(
[0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]
)
hid_device.send_data(data)
data = bytearray(
[0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
)
hid_device.send_data(data)
elif choice1 == '2':
data = bytearray([0x02, 0x00, 0x00, 0xF6])
hid_device.send_data(data)
data = bytearray([0x02, 0x00, 0x00, 0x00])
hid_device.send_data(data)
elif choice1 == '3':
data = bytearray([0x00, 0x00, 0x00, 0x00])
hid_device.send_data(data)
data = bytearray([0x00, 0x00, 0x00, 0x00])
hid_device.send_data(data)
else:
print('Incorrect option selected')
elif choice == '6':
hid_device.virtual_cable_unplug()
try:
hid_host_bd_addr = str(hid_device.remote_device_bd_address)
await device.keystore.delete(hid_host_bd_addr)
except KeyError:
print('Device not found or Device already unpaired.')
elif choice == '7':
connection = hid_device.connection
if connection is not None:
await connection.disconnect()
else:
print("Already disconnected from device")
elif choice == '8':
try:
hid_host_bd_addr = str(hid_device.remote_device_bd_address)
await device.keystore.delete(hid_host_bd_addr)
except KeyError:
print('Device NOT found or Device already unpaired.')
elif choice == '9':
hid_host_bd_addr = str(hid_device.remote_device_bd_address)
connection = await device.connect(
hid_host_bd_addr, transport=BT_BR_EDR_TRANSPORT
)
await connection.authenticate()
await connection.encrypt()
elif choice == '10':
sys.exit("Exit successful")
else:
print("Invalid option selected.")
if (len(sys.argv) > 3) and (sys.argv[3] == 'test-mode'):
# Test mode for PTS/Unit testing
await menu()
else:
# default option is using keyboard.html (web)
print("Executing in Web mode")
await keyboard_device(hid_device)
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())
+51 -20
View File
@@ -285,7 +285,10 @@ async def main():
print('example: run_hid_host.py classic1.json usb:0 E1:CA:72:48:C4:E8/P') print('example: run_hid_host.py classic1.json usb:0 E1:CA:72:48:C4:E8/P')
return return
def on_hid_data_cb(pdu): def on_hid_control_data_cb(pdu: bytes):
print(f'Received Control Data, PDU: {pdu.hex()}')
def on_hid_interrupt_data_cb(pdu: bytes):
report_type = pdu[0] & 0x0F report_type = pdu[0] & 0x0F
if len(pdu) == 1: if len(pdu) == 1:
print(color(f'Warning: No report received', 'yellow')) print(color(f'Warning: No report received', 'yellow'))
@@ -305,7 +308,7 @@ async def main():
if (report_length <= 1) or (report_id == 0): if (report_length <= 1) or (report_id == 0):
return return
# Parse report over interrupt channel
if report_type == Message.ReportType.INPUT_REPORT: if report_type == Message.ReportType.INPUT_REPORT:
ReportParser.parse_input_report(pdu[1:]) # type: ignore ReportParser.parse_input_report(pdu[1:]) # type: ignore
@@ -313,7 +316,9 @@ async def main():
await hid_host.disconnect_interrupt_channel() await hid_host.disconnect_interrupt_channel()
await hid_host.disconnect_control_channel() await hid_host.disconnect_control_channel()
await device.keystore.delete(target_address) # type: ignore await device.keystore.delete(target_address) # type: ignore
await connection.disconnect() connection = hid_host.connection
if connection is not None:
await connection.disconnect()
def on_hid_virtual_cable_unplug_cb(): def on_hid_virtual_cable_unplug_cb():
asyncio.create_task(handle_virtual_cable_unplug()) asyncio.create_task(handle_virtual_cable_unplug())
@@ -325,6 +330,18 @@ async def main():
# Create a device # Create a device
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink) device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
device.classic_enabled = True device.classic_enabled = True
# Create HID host and start it
print('@@@ Starting HID Host...')
hid_host = Host(device)
# Register for HID data call back
hid_host.on('interrupt_data', on_hid_interrupt_data_cb)
hid_host.on('control_data', on_hid_control_data_cb)
# Register for virtual cable unplug call back
hid_host.on('virtual_cable_unplug', on_hid_virtual_cable_unplug_cb)
await device.power_on() await device.power_on()
# Connect to a peer # Connect to a peer
@@ -345,16 +362,6 @@ async def main():
await get_hid_device_sdp_record(connection) await get_hid_device_sdp_record(connection)
# Create HID host and start it
print('@@@ Starting HID Host...')
hid_host = Host(device, connection)
# Register for HID data call back
hid_host.on('data', on_hid_data_cb)
# Register for virtual cable unplug call back
hid_host.on('virtual_cable_unplug', on_hid_virtual_cable_unplug_cb)
async def menu(): async def menu():
reader = await get_stream_reader(sys.stdin) reader = await get_stream_reader(sys.stdin)
while True: while True:
@@ -369,13 +376,14 @@ async def main():
print(" 6. Set Report") print(" 6. Set Report")
print(" 7. Set Protocol Mode") print(" 7. Set Protocol Mode")
print(" 8. Get Protocol Mode") print(" 8. Get Protocol Mode")
print(" 9. Send Report") print(" 9. Send Report on Interrupt Channel")
print("10. Suspend") print("10. Suspend")
print("11. Exit Suspend") print("11. Exit Suspend")
print("12. Virtual Cable Unplug") print("12. Virtual Cable Unplug")
print("13. Disconnect device") print("13. Disconnect device")
print("14. Delete Bonding") print("14. Delete Bonding")
print("15. Re-connect to device") print("15. Re-connect to device")
print("16. Exit")
print("\nEnter your choice : \n") print("\nEnter your choice : \n")
choice = await reader.readline() choice = await reader.readline()
@@ -394,21 +402,40 @@ async def main():
await hid_host.disconnect_interrupt_channel() await hid_host.disconnect_interrupt_channel()
elif choice == '5': elif choice == '5':
print(" 1. Report ID 0x02") print(" 1. Input Report with ID 0x01")
print(" 2. Report ID 0x03") print(" 2. Input Report with ID 0x02")
print(" 3. Report ID 0x05") print(" 3. Input Report with ID 0x0F - Invalid ReportId")
print(" 4. Output Report with ID 0x02")
print(" 5. Feature Report with ID 0x05 - Unsupported Request")
print(" 6. Input Report with ID 0x02, BufferSize 3")
print(" 7. Output Report with ID 0x03, BufferSize 2")
print(" 8. Feature Report with ID 0x05, BufferSize 3")
choice1 = await reader.readline() choice1 = await reader.readline()
choice1 = choice1.decode('utf-8').strip() choice1 = choice1.decode('utf-8').strip()
if choice1 == '1': if choice1 == '1':
hid_host.get_report(1, 2, 3) hid_host.get_report(1, 1, 0)
elif choice1 == '2': elif choice1 == '2':
hid_host.get_report(2, 3, 2) hid_host.get_report(1, 2, 0)
elif choice1 == '3': elif choice1 == '3':
hid_host.get_report(3, 5, 3) hid_host.get_report(1, 5, 0)
elif choice1 == '4':
hid_host.get_report(2, 2, 0)
elif choice1 == '5':
hid_host.get_report(3, 15, 0)
elif choice1 == '6':
hid_host.get_report(1, 2, 3)
elif choice1 == '7':
hid_host.get_report(2, 3, 2)
elif choice1 == '8':
hid_host.get_report(3, 5, 3)
else: else:
print('Incorrect option selected') print('Incorrect option selected')
@@ -484,6 +511,7 @@ async def main():
hid_host.virtual_cable_unplug() hid_host.virtual_cable_unplug()
try: try:
await device.keystore.delete(target_address) await device.keystore.delete(target_address)
print("Unpair successful")
except KeyError: except KeyError:
print('Device not found or Device already unpaired.') print('Device not found or Device already unpaired.')
@@ -513,6 +541,9 @@ async def main():
await connection.authenticate() await connection.authenticate()
await connection.encrypt() await connection.encrypt()
elif choice == '16':
sys.exit("Exit successful")
else: else:
print("Invalid option selected.") print("Invalid option selected.")
+74 -13
View File
@@ -19,12 +19,15 @@ import asyncio
import logging import logging
import sys import sys
import os import os
import struct
import secrets
from bumble.core import AdvertisingData from bumble.core import AdvertisingData
from bumble.device import Device from bumble.device import Device, CisLink
from bumble.hci import ( from bumble.hci import (
CodecID, CodecID,
CodingFormat, CodingFormat,
OwnAddressType, OwnAddressType,
HCI_IsoDataPacket,
HCI_LE_Set_Extended_Advertising_Parameters_Command, HCI_LE_Set_Extended_Advertising_Parameters_Command,
) )
from bumble.profiles.bap import ( from bumble.profiles.bap import (
@@ -35,7 +38,10 @@ from bumble.profiles.bap import (
SupportedFrameDuration, SupportedFrameDuration,
PacRecord, PacRecord,
PublishedAudioCapabilitiesService, PublishedAudioCapabilitiesService,
AudioStreamControlService,
) )
from bumble.profiles.cap import CommonAudioServiceService
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
from bumble.transport import open_transport_or_link from bumble.transport import open_transport_or_link
@@ -57,6 +63,11 @@ async def main() -> None:
await device.power_on() await device.power_on()
csis = CoordinatedSetIdentificationService(
set_identity_resolving_key=secrets.token_bytes(16),
set_identity_resolving_key_type=SirkType.PLAINTEXT,
)
device.add_service(CommonAudioServiceService(csis))
device.add_service( device.add_service(
PublishedAudioCapabilitiesService( PublishedAudioCapabilitiesService(
supported_source_context=ContextType.PROHIBITED, supported_source_context=ContextType.PROHIBITED,
@@ -103,21 +114,71 @@ async def main() -> None:
) )
) )
advertising_data = bytes( device.add_service(AudioStreamControlService(device, sink_ase_id=[1, 2]))
AdvertisingData(
[ advertising_data = (
( bytes(
AdvertisingData.COMPLETE_LOCAL_NAME, AdvertisingData(
bytes('Bumble LE Audio', 'utf-8'), [
), (
( AdvertisingData.COMPLETE_LOCAL_NAME,
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes('Bumble LE Audio', 'utf-8'),
bytes(PublishedAudioCapabilitiesService.UUID), ),
), (
] AdvertisingData.FLAGS,
bytes(
[
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
| AdvertisingData.BR_EDR_HOST_FLAG
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
]
),
),
(
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
bytes(PublishedAudioCapabilitiesService.UUID),
),
]
)
)
+ csis.get_advertising_data()
)
subprocess = await asyncio.create_subprocess_shell(
f'dlc3 | ffplay pipe:0',
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdin = subprocess.stdin
assert stdin
# Write a fake LC3 header to dlc3.
stdin.write(
bytes([0x1C, 0xCC]) # Header.
+ struct.pack(
'<HHHHHHI',
18, # Header length.
24000 // 100, # Sampling Rate(/100Hz).
0, # Bitrate(unused).
1, # Channels.
10000 // 10, # Frame duration(/10us).
0, # RFU.
0x0FFFFFFF, # Frame counts.
) )
) )
def on_pdu(pdu: HCI_IsoDataPacket):
# LC3 format: |frame_length(2)| + |frame(length)|.
if pdu.iso_sdu_length:
stdin.write(struct.pack('<H', pdu.iso_sdu_length))
stdin.write(pdu.iso_sdu_fragment)
def on_cis(cis_link: CisLink):
cis_link.on('pdu', on_pdu)
device.once('cis_establishment', on_cis)
await device.start_extended_advertising( await device.start_extended_advertising(
advertising_properties=( advertising_properties=(
HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING
@@ -28,8 +28,8 @@ private val Log = Logger.getLogger("btbench.l2cap-client")
class L2capClient( class L2capClient(
private val viewModel: AppViewModel, private val viewModel: AppViewModel,
val bluetoothAdapter: BluetoothAdapter, private val bluetoothAdapter: BluetoothAdapter,
val context: Context private val context: Context
) { ) {
@SuppressLint("MissingPermission") @SuppressLint("MissingPermission")
fun run() { fun run() {
@@ -74,12 +74,18 @@ class L2capClient(
gatt: BluetoothGatt?, status: Int, newState: Int gatt: BluetoothGatt?, status: Int, newState: Int
) { ) {
if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) { if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) {
gatt.setPreferredPhy( if (viewModel.use2mPhy) {
BluetoothDevice.PHY_LE_2M_MASK, gatt.setPreferredPhy(
BluetoothDevice.PHY_LE_2M_MASK, BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_OPTION_NO_PREFERRED BluetoothDevice.PHY_LE_2M_MASK,
) BluetoothDevice.PHY_OPTION_NO_PREFERRED
)
}
gatt.readPhy() gatt.readPhy()
// Request an MTU update, even though we don't use GATT, because Android
// won't request a larger link layer maximum data length otherwise.
gatt.requestMtu(517)
} }
} }
}, },
@@ -23,19 +23,20 @@ import androidx.compose.runtime.setValue
import androidx.lifecycle.ViewModel import androidx.lifecycle.ViewModel
import java.util.UUID import java.util.UUID
val DEFAULT_RFCOMM_UUID = UUID.fromString("E6D55659-C8B4-4B85-96BB-B1143AF6D3AE") val DEFAULT_RFCOMM_UUID: UUID = UUID.fromString("E6D55659-C8B4-4B85-96BB-B1143AF6D3AE")
const val DEFAULT_PEER_BLUETOOTH_ADDRESS = "AA:BB:CC:DD:EE:FF" const val DEFAULT_PEER_BLUETOOTH_ADDRESS = "AA:BB:CC:DD:EE:FF"
const val DEFAULT_SENDER_PACKET_COUNT = 100 const val DEFAULT_SENDER_PACKET_COUNT = 100
const val DEFAULT_SENDER_PACKET_SIZE = 1024 const val DEFAULT_SENDER_PACKET_SIZE = 1024
const val DEFAULT_PSM = 128
class AppViewModel : ViewModel() { class AppViewModel : ViewModel() {
private var preferences: SharedPreferences? = null private var preferences: SharedPreferences? = null
var peerBluetoothAddress by mutableStateOf(DEFAULT_PEER_BLUETOOTH_ADDRESS) var peerBluetoothAddress by mutableStateOf(DEFAULT_PEER_BLUETOOTH_ADDRESS)
var l2capPsm by mutableStateOf(0) var l2capPsm by mutableIntStateOf(DEFAULT_PSM)
var use2mPhy by mutableStateOf(true) var use2mPhy by mutableStateOf(true)
var mtu by mutableStateOf(0) var mtu by mutableIntStateOf(0)
var rxPhy by mutableStateOf(0) var rxPhy by mutableIntStateOf(0)
var txPhy by mutableStateOf(0) var txPhy by mutableIntStateOf(0)
var senderPacketCountSlider by mutableFloatStateOf(0.0F) var senderPacketCountSlider by mutableFloatStateOf(0.0F)
var senderPacketSizeSlider by mutableFloatStateOf(0.0F) var senderPacketSizeSlider by mutableFloatStateOf(0.0F)
var senderPacketCount by mutableIntStateOf(DEFAULT_SENDER_PACKET_COUNT) var senderPacketCount by mutableIntStateOf(DEFAULT_SENDER_PACKET_COUNT)
@@ -79,18 +80,18 @@ class AppViewModel : ViewModel() {
} }
fun updateSenderPacketCountSlider() { fun updateSenderPacketCountSlider() {
if (senderPacketCount <= 10) { senderPacketCountSlider = if (senderPacketCount <= 10) {
senderPacketCountSlider = 0.0F 0.0F
} else if (senderPacketCount <= 50) { } else if (senderPacketCount <= 50) {
senderPacketCountSlider = 0.2F 0.2F
} else if (senderPacketCount <= 100) { } else if (senderPacketCount <= 100) {
senderPacketCountSlider = 0.4F 0.4F
} else if (senderPacketCount <= 500) { } else if (senderPacketCount <= 500) {
senderPacketCountSlider = 0.6F 0.6F
} else if (senderPacketCount <= 1000) { } else if (senderPacketCount <= 1000) {
senderPacketCountSlider = 0.8F 0.8F
} else { } else {
senderPacketCountSlider = 1.0F 1.0F
} }
with(preferences!!.edit()) { with(preferences!!.edit()) {
@@ -100,18 +101,18 @@ class AppViewModel : ViewModel() {
} }
fun updateSenderPacketCount() { fun updateSenderPacketCount() {
if (senderPacketCountSlider < 0.1F) { senderPacketCount = if (senderPacketCountSlider < 0.1F) {
senderPacketCount = 10 10
} else if (senderPacketCountSlider < 0.3F) { } else if (senderPacketCountSlider < 0.3F) {
senderPacketCount = 50 50
} else if (senderPacketCountSlider < 0.5F) { } else if (senderPacketCountSlider < 0.5F) {
senderPacketCount = 100 100
} else if (senderPacketCountSlider < 0.7F) { } else if (senderPacketCountSlider < 0.7F) {
senderPacketCount = 500 500
} else if (senderPacketCountSlider < 0.9F) { } else if (senderPacketCountSlider < 0.9F) {
senderPacketCount = 1000 1000
} else { } else {
senderPacketCount = 10000 10000
} }
with(preferences!!.edit()) { with(preferences!!.edit()) {
@@ -121,18 +122,18 @@ class AppViewModel : ViewModel() {
} }
fun updateSenderPacketSizeSlider() { fun updateSenderPacketSizeSlider() {
if (senderPacketSize <= 16) { senderPacketSizeSlider = if (senderPacketSize <= 16) {
senderPacketSizeSlider = 0.0F 0.0F
} else if (senderPacketSize <= 256) { } else if (senderPacketSize <= 256) {
senderPacketSizeSlider = 0.02F 0.02F
} else if (senderPacketSize <= 512) { } else if (senderPacketSize <= 512) {
senderPacketSizeSlider = 0.4F 0.4F
} else if (senderPacketSize <= 1024) { } else if (senderPacketSize <= 1024) {
senderPacketSizeSlider = 0.6F 0.6F
} else if (senderPacketSize <= 2048) { } else if (senderPacketSize <= 2048) {
senderPacketSizeSlider = 0.8F 0.8F
} else { } else {
senderPacketSizeSlider = 1.0F 1.0F
} }
with(preferences!!.edit()) { with(preferences!!.edit()) {
@@ -142,18 +143,18 @@ class AppViewModel : ViewModel() {
} }
fun updateSenderPacketSize() { fun updateSenderPacketSize() {
if (senderPacketSizeSlider < 0.1F) { senderPacketSize = if (senderPacketSizeSlider < 0.1F) {
senderPacketSize = 16 16
} else if (senderPacketSizeSlider < 0.3F) { } else if (senderPacketSizeSlider < 0.3F) {
senderPacketSize = 256 256
} else if (senderPacketSizeSlider < 0.5F) { } else if (senderPacketSizeSlider < 0.5F) {
senderPacketSize = 512 512
} else if (senderPacketSizeSlider < 0.7F) { } else if (senderPacketSizeSlider < 0.7F) {
senderPacketSize = 1024 1024
} else if (senderPacketSizeSlider < 0.9F) { } else if (senderPacketSizeSlider < 0.9F) {
senderPacketSize = 2048 2048
} else { } else {
senderPacketSize = 4096 4096
} }
with(preferences!!.edit()) { with(preferences!!.edit()) {
@@ -42,6 +42,7 @@ public class HciServer {
try (ServerSocket serverSocket = new ServerSocket(mPort)) { try (ServerSocket serverSocket = new ServerSocket(mPort)) {
mListener.onMessage("Waiting for connection on port " + serverSocket.getLocalPort()); mListener.onMessage("Waiting for connection on port " + serverSocket.getLocalPort());
try (Socket clientSocket = serverSocket.accept()) { try (Socket clientSocket = serverSocket.accept()) {
clientSocket.setTcpNoDelay(true);
mListener.onHostConnectionState(true); mListener.onHostConnectionState(true);
mListener.onMessage("Connected"); mListener.onMessage("Connected");
HciParser parser = new HciParser(mListener); HciParser parser = new HciParser(mListener);
+1
View File
@@ -56,6 +56,7 @@ install_requires =
[options.entry_points] [options.entry_points]
console_scripts = console_scripts =
bumble-ble-rpa-tool = bumble.apps.ble_rpa_tool:main
bumble-console = bumble.apps.console:main bumble-console = bumble.apps.console:main
bumble-controller-info = bumble.apps.controller_info:main bumble-controller-info = bumble.apps.controller_info:main
bumble-gatt-dump = bumble.apps.gatt_dump:main bumble-gatt-dump = bumble.apps.gatt_dump:main
+253 -1
View File
@@ -17,6 +17,7 @@
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
import asyncio import asyncio
import os import os
import functools
import pytest import pytest
import logging import logging
@@ -24,15 +25,31 @@ from bumble import device
from bumble.hci import CodecID, CodingFormat from bumble.hci import CodecID, CodingFormat
from bumble.profiles.bap import ( from bumble.profiles.bap import (
AudioLocation, AudioLocation,
AseStateMachine,
ASE_Operation,
ASE_Config_Codec,
ASE_Config_QOS,
ASE_Disable,
ASE_Enable,
ASE_Receiver_Start_Ready,
ASE_Receiver_Stop_Ready,
ASE_Release,
ASE_Update_Metadata,
SupportedFrameDuration, SupportedFrameDuration,
SupportedSamplingFrequency, SupportedSamplingFrequency,
SamplingFrequency,
FrameDuration,
CodecSpecificCapabilities, CodecSpecificCapabilities,
CodecSpecificConfiguration,
ContextType, ContextType,
PacRecord, PacRecord,
AudioStreamControlService,
AudioStreamControlServiceProxy,
PublishedAudioCapabilitiesService, PublishedAudioCapabilitiesService,
PublishedAudioCapabilitiesServiceProxy, PublishedAudioCapabilitiesServiceProxy,
) )
from .test_utils import TwoDevices from tests.test_utils import TwoDevices
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
@@ -40,6 +57,13 @@ from .test_utils import TwoDevices
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def basic_check(operation: ASE_Operation):
serialized = bytes(operation)
parsed = ASE_Operation.from_bytes(serialized)
assert bytes(parsed) == serialized
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_codec_specific_capabilities() -> None: def test_codec_specific_capabilities() -> None:
SAMPLE_FREQUENCY = SupportedSamplingFrequency.FREQ_16000 SAMPLE_FREQUENCY = SupportedSamplingFrequency.FREQ_16000
@@ -85,6 +109,92 @@ def test_vendor_specific_pac_record() -> None:
assert bytes(PacRecord.from_bytes(RAW_DATA)) == RAW_DATA assert bytes(PacRecord.from_bytes(RAW_DATA)) == RAW_DATA
# -----------------------------------------------------------------------------
def test_ASE_Config_Codec() -> None:
operation = ASE_Config_Codec(
ase_id=[1, 2],
target_latency=[3, 4],
target_phy=[5, 6],
codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)],
codec_specific_configuration=[b'foo', b'bar'],
)
basic_check(operation)
# -----------------------------------------------------------------------------
def test_ASE_Config_QOS() -> None:
operation = ASE_Config_QOS(
ase_id=[1, 2],
cig_id=[1, 2],
cis_id=[3, 4],
sdu_interval=[5, 6],
framing=[0, 1],
phy=[2, 3],
max_sdu=[4, 5],
retransmission_number=[6, 7],
max_transport_latency=[8, 9],
presentation_delay=[10, 11],
)
basic_check(operation)
# -----------------------------------------------------------------------------
def test_ASE_Enable() -> None:
operation = ASE_Enable(
ase_id=[1, 2],
metadata=[b'foo', b'bar'],
)
basic_check(operation)
# -----------------------------------------------------------------------------
def test_ASE_Update_Metadata() -> None:
operation = ASE_Update_Metadata(
ase_id=[1, 2],
metadata=[b'foo', b'bar'],
)
basic_check(operation)
# -----------------------------------------------------------------------------
def test_ASE_Disable() -> None:
operation = ASE_Disable(ase_id=[1, 2])
basic_check(operation)
# -----------------------------------------------------------------------------
def test_ASE_Release() -> None:
operation = ASE_Release(ase_id=[1, 2])
basic_check(operation)
# -----------------------------------------------------------------------------
def test_ASE_Receiver_Start_Ready() -> None:
operation = ASE_Receiver_Start_Ready(ase_id=[1, 2])
basic_check(operation)
# -----------------------------------------------------------------------------
def test_ASE_Receiver_Stop_Ready() -> None:
operation = ASE_Receiver_Stop_Ready(ase_id=[1, 2])
basic_check(operation)
# -----------------------------------------------------------------------------
def test_codec_specific_configuration() -> None:
SAMPLE_FREQUENCY = SamplingFrequency.FREQ_16000
FRAME_SURATION = FrameDuration.DURATION_10000_US
AUDIO_LOCATION = AudioLocation.FRONT_LEFT
config = CodecSpecificConfiguration(
sampling_frequency=SAMPLE_FREQUENCY,
frame_duration=FRAME_SURATION,
audio_channel_allocation=AUDIO_LOCATION,
octets_per_codec_frame=60,
codec_frames_per_sdu=1,
)
assert CodecSpecificConfiguration.from_bytes(bytes(config)) == config
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pacs(): async def test_pacs():
@@ -140,6 +250,148 @@ async def test_pacs():
) )
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_ascs():
devices = TwoDevices()
devices[0].add_service(
AudioStreamControlService(device=devices[0], sink_ase_id=[1, 2])
)
await devices.setup_connection()
peer = device.Peer(devices.connections[1])
ascs_client = await peer.discover_service_and_create_proxy(
AudioStreamControlServiceProxy
)
notifications = {1: asyncio.Queue(), 2: asyncio.Queue()}
def on_notification(data: bytes, ase_id: int):
notifications[ase_id].put_nowait(data)
# Should be idle
assert await ascs_client.sink_ase[0].read_value() == bytes(
[1, AseStateMachine.State.IDLE]
)
assert await ascs_client.sink_ase[1].read_value() == bytes(
[2, AseStateMachine.State.IDLE]
)
# Subscribe
await ascs_client.sink_ase[0].subscribe(
functools.partial(on_notification, ase_id=1)
)
await ascs_client.sink_ase[1].subscribe(
functools.partial(on_notification, ase_id=2)
)
# Config Codec
config = CodecSpecificConfiguration(
sampling_frequency=SamplingFrequency.FREQ_48000,
frame_duration=FrameDuration.DURATION_10000_US,
audio_channel_allocation=AudioLocation.FRONT_LEFT,
octets_per_codec_frame=120,
codec_frames_per_sdu=1,
)
await ascs_client.ase_control_point.write_value(
ASE_Config_Codec(
ase_id=[1, 2],
target_latency=[3, 4],
target_phy=[5, 6],
codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)],
codec_specific_configuration=[config, config],
)
)
assert (await notifications[1].get())[:2] == bytes(
[1, AseStateMachine.State.CODEC_CONFIGURED]
)
assert (await notifications[2].get())[:2] == bytes(
[2, AseStateMachine.State.CODEC_CONFIGURED]
)
# Config QOS
await ascs_client.ase_control_point.write_value(
ASE_Config_QOS(
ase_id=[1, 2],
cig_id=[1, 2],
cis_id=[3, 4],
sdu_interval=[5, 6],
framing=[0, 1],
phy=[2, 3],
max_sdu=[4, 5],
retransmission_number=[6, 7],
max_transport_latency=[8, 9],
presentation_delay=[10, 11],
)
)
assert (await notifications[1].get())[:2] == bytes(
[1, AseStateMachine.State.QOS_CONFIGURED]
)
assert (await notifications[2].get())[:2] == bytes(
[2, AseStateMachine.State.QOS_CONFIGURED]
)
# Enable
await ascs_client.ase_control_point.write_value(
ASE_Enable(
ase_id=[1, 2],
metadata=[b'foo', b'bar'],
)
)
assert (await notifications[1].get())[:2] == bytes(
[1, AseStateMachine.State.ENABLING]
)
assert (await notifications[2].get())[:2] == bytes(
[2, AseStateMachine.State.ENABLING]
)
# CIS establishment
devices[0].emit(
'cis_establishment',
device.CisLink(
device=devices[0],
acl_connection=devices.connections[0],
handle=5,
cis_id=3,
cig_id=1,
),
)
devices[0].emit(
'cis_establishment',
device.CisLink(
device=devices[0],
acl_connection=devices.connections[0],
handle=6,
cis_id=4,
cig_id=2,
),
)
assert (await notifications[1].get())[:2] == bytes(
[1, AseStateMachine.State.STREAMING]
)
assert (await notifications[2].get())[:2] == bytes(
[2, AseStateMachine.State.STREAMING]
)
# Release
await ascs_client.ase_control_point.write_value(
ASE_Release(
ase_id=[1, 2],
metadata=[b'foo', b'bar'],
)
)
assert (await notifications[1].get())[:2] == bytes(
[1, AseStateMachine.State.RELEASING]
)
assert (await notifications[2].get())[:2] == bytes(
[2, AseStateMachine.State.RELEASING]
)
assert (await notifications[1].get())[:2] == bytes([1, AseStateMachine.State.IDLE])
assert (await notifications[2].get())[:2] == bytes([2, AseStateMachine.State.IDLE])
await asyncio.sleep(0.001)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def run(): async def run():
await test_pacs() await test_pacs()
+71
View File
@@ -0,0 +1,71 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import os
import pytest
import logging
from bumble import device
from bumble import gatt
from bumble.profiles import cap
from bumble.profiles import csip
from .test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cas():
SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
devices = TwoDevices()
devices[0].add_service(
cap.CommonAudioServiceService(
csip.CoordinatedSetIdentificationService(
set_identity_resolving_key=SIRK,
set_identity_resolving_key_type=csip.SirkType.PLAINTEXT,
)
)
)
await devices.setup_connection()
peer = device.Peer(devices.connections[1])
cas_client = await peer.discover_service_and_create_proxy(
cap.CommonAudioServiceServiceProxy
)
included_services = await peer.discover_included_services(cas_client.service_proxy)
assert any(
service.uuid == gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
for service in included_services
)
# -----------------------------------------------------------------------------
async def run():
await test_cas()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())
+51 -5
View File
@@ -20,6 +20,7 @@ import os
import pytest import pytest
import struct import struct
import logging import logging
from unittest import mock
from bumble import device from bumble import device
from bumble.profiles import csip from bumble.profiles import csip
@@ -31,15 +32,55 @@ from .test_utils import TwoDevices
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def test_s1():
assert (
csip.s1(b'SIRKenc'[::-1])
== bytes.fromhex('6901983f 18149e82 3c7d133a 7d774572')[::-1]
)
# -----------------------------------------------------------------------------
def test_k1():
K = bytes.fromhex('676e1b9b d448696f 061ec622 3ce5ced9')[::-1]
SALT = csip.s1(b'SIRKenc'[::-1])
P = b'csis'[::-1]
assert (
csip.k1(K, SALT, P)
== bytes.fromhex('5277453c c094d982 b0e8ee53 2f2d1f8b')[::-1]
)
# -----------------------------------------------------------------------------
def test_sih():
SIRK = bytes.fromhex('457d7d09 21a1fd22 cecd8c86 dd72cccd')[::-1]
PRAND = bytes.fromhex('69f563')[::-1]
assert csip.sih(SIRK, PRAND) == bytes.fromhex('1948da')[::-1]
# -----------------------------------------------------------------------------
def test_sef():
SIRK = bytes.fromhex('457d7d09 21a1fd22 cecd8c86 dd72cccd')[::-1]
K = bytes.fromhex('676e1b9b d448696f 061ec622 3ce5ced9')[::-1]
assert (
csip.sef(K, SIRK) == bytes.fromhex('170a3835 e13524a0 7e2562d5 f25fd346')[::-1]
)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_csis(): @pytest.mark.parametrize(
'sirk_type,', [(csip.SirkType.ENCRYPTED), (csip.SirkType.PLAINTEXT)]
)
async def test_csis(sirk_type):
SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa') SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
LTK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
devices = TwoDevices() devices = TwoDevices()
devices[0].add_service( devices[0].add_service(
csip.CoordinatedSetIdentificationService( csip.CoordinatedSetIdentificationService(
set_identity_resolving_key=SIRK, set_identity_resolving_key=SIRK,
set_identity_resolving_key_type=sirk_type,
coordinated_set_size=2, coordinated_set_size=2,
set_member_lock=csip.MemberLock.UNLOCKED, set_member_lock=csip.MemberLock.UNLOCKED,
set_member_rank=0, set_member_rank=0,
@@ -47,15 +88,19 @@ async def test_csis():
) )
await devices.setup_connection() await devices.setup_connection()
# Mock encryption.
devices.connections[0].encryption = 1
devices.connections[1].encryption = 1
devices[0].get_long_term_key = mock.AsyncMock(return_value=LTK)
devices[1].get_long_term_key = mock.AsyncMock(return_value=LTK)
peer = device.Peer(devices.connections[1]) peer = device.Peer(devices.connections[1])
csis_client = await peer.discover_service_and_create_proxy( csis_client = await peer.discover_service_and_create_proxy(
csip.CoordinatedSetIdentificationProxy csip.CoordinatedSetIdentificationProxy
) )
assert ( assert await csis_client.read_set_identity_resolving_key() == (sirk_type, SIRK)
await csis_client.set_identity_resolving_key.read_value()
== bytes([csip.SirkType.PLAINTEXT]) + SIRK
)
assert await csis_client.coordinated_set_size.read_value() == struct.pack('B', 2) assert await csis_client.coordinated_set_size.read_value() == struct.pack('B', 2)
assert await csis_client.set_member_lock.read_value() == struct.pack( assert await csis_client.set_member_lock.read_value() == struct.pack(
'B', csip.MemberLock.UNLOCKED 'B', csip.MemberLock.UNLOCKED
@@ -65,6 +110,7 @@ async def test_csis():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def run(): async def run():
test_sih()
await test_csis() await test_csis()
+182 -2
View File
@@ -20,16 +20,23 @@ import logging
import os import os
from types import LambdaType from types import LambdaType
import pytest import pytest
from unittest import mock
from bumble.core import BT_BR_EDR_TRANSPORT from bumble.core import (
BT_BR_EDR_TRANSPORT,
BT_LE_TRANSPORT,
BT_PERIPHERAL_ROLE,
ConnectionParameters,
)
from bumble.device import Connection, Device from bumble.device import Connection, Device
from bumble.host import Host from bumble.host import AclPacketQueue, Host
from bumble.hci import ( from bumble.hci import (
HCI_ACCEPT_CONNECTION_REQUEST_COMMAND, HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
HCI_COMMAND_STATUS_PENDING, HCI_COMMAND_STATUS_PENDING,
HCI_CREATE_CONNECTION_COMMAND, HCI_CREATE_CONNECTION_COMMAND,
HCI_SUCCESS, HCI_SUCCESS,
Address, Address,
OwnAddressType,
HCI_Command_Complete_Event, HCI_Command_Complete_Event,
HCI_Command_Status_Event, HCI_Command_Status_Event,
HCI_Connection_Complete_Event, HCI_Connection_Complete_Event,
@@ -66,6 +73,13 @@ async def test_device_connect_parallel():
d1 = Device(host=Host(None, None)) d1 = Device(host=Host(None, None))
d2 = Device(host=Host(None, None)) d2 = Device(host=Host(None, None))
def _send(packet):
pass
d0.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
d1.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
d2.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
# enable classic # enable classic
d0.classic_enabled = True d0.classic_enabled = True
d1.classic_enabled = True d1.classic_enabled = True
@@ -232,6 +246,172 @@ async def test_flush():
pass pass
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_legacy_advertising():
device = Device(host=mock.AsyncMock(Host))
# Start advertising
advertiser = await device.start_legacy_advertising()
assert device.legacy_advertiser
# Stop advertising
await advertiser.stop()
assert not device.legacy_advertiser
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'own_address_type,',
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_legacy_advertising_connection(own_address_type):
device = Device(host=mock.AsyncMock(Host))
peer_address = Address('F0:F1:F2:F3:F4:F5')
# Start advertising
advertiser = await device.start_legacy_advertising()
device.on_connection(
0x0001,
BT_LE_TRANSPORT,
peer_address,
BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0),
)
if own_address_type == OwnAddressType.PUBLIC:
assert device.lookup_connection(0x0001).self_address == device.public_address
else:
assert device.lookup_connection(0x0001).self_address == device.random_address
# For unknown reason, read_phy() in on_connection() would be killed at the end of
# test, so we force scheduling here to avoid an warning.
await asyncio.sleep(0.0001)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'auto_restart,',
(True, False),
)
@pytest.mark.asyncio
async def test_legacy_advertising_disconnection(auto_restart):
device = Device(host=mock.AsyncMock(spec=Host))
peer_address = Address('F0:F1:F2:F3:F4:F5')
advertiser = await device.start_legacy_advertising(auto_restart=auto_restart)
device.on_connection(
0x0001,
BT_LE_TRANSPORT,
peer_address,
BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0),
)
device.start_legacy_advertising = mock.AsyncMock()
device.on_disconnection(0x0001, 0)
if auto_restart:
device.start_legacy_advertising.assert_called_with(
advertising_type=advertiser.advertising_type,
own_address_type=advertiser.own_address_type,
auto_restart=advertiser.auto_restart,
advertising_data=advertiser.advertising_data,
scan_response_data=advertiser.scan_response_data,
)
else:
device.start_legacy_advertising.assert_not_called()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_extended_advertising():
device = Device(host=mock.AsyncMock(Host))
# Start advertising
advertiser = await device.start_extended_advertising()
assert device.extended_advertisers
# Stop advertising
await advertiser.stop()
assert not device.extended_advertisers
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'own_address_type,',
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
)
@pytest.mark.asyncio
async def test_extended_advertising_connection(own_address_type):
device = Device(host=mock.AsyncMock(spec=Host))
peer_address = Address('F0:F1:F2:F3:F4:F5')
advertiser = await device.start_extended_advertising(
own_address_type=own_address_type
)
device.on_connection(
0x0001,
BT_LE_TRANSPORT,
peer_address,
BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0),
)
device.on_advertising_set_termination(
HCI_SUCCESS,
advertiser.handle,
0x0001,
)
if own_address_type == OwnAddressType.PUBLIC:
assert device.lookup_connection(0x0001).self_address == device.public_address
else:
assert device.lookup_connection(0x0001).self_address == device.random_address
# For unknown reason, read_phy() in on_connection() would be killed at the end of
# test, so we force scheduling here to avoid an warning.
await asyncio.sleep(0.0001)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize(
'auto_restart,',
(True, False),
)
@pytest.mark.asyncio
async def test_extended_advertising_disconnection(auto_restart):
device = Device(host=mock.AsyncMock(spec=Host))
peer_address = Address('F0:F1:F2:F3:F4:F5')
advertiser = await device.start_extended_advertising(auto_restart=auto_restart)
device.on_connection(
0x0001,
BT_LE_TRANSPORT,
peer_address,
BT_PERIPHERAL_ROLE,
ConnectionParameters(0, 0, 0),
)
device.on_advertising_set_termination(
HCI_SUCCESS,
advertiser.handle,
0x0001,
)
device.start_extended_advertising = mock.AsyncMock()
device.on_disconnection(0x0001, 0)
if auto_restart:
device.start_extended_advertising.assert_called_with(
advertising_properties=advertiser.advertising_properties,
own_address_type=advertiser.own_address_type,
auto_restart=advertiser.auto_restart,
advertising_data=advertiser.advertising_data,
scan_response_data=advertiser.scan_response_data,
)
else:
device.start_extended_advertising.assert_not_called()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_gatt_services_with_gas(): def test_gatt_services_with_gas():
device = Device(host=Host(None, None)) device = Device(host=Host(None, None))
+76 -31
View File
@@ -20,11 +20,10 @@ import logging
import os import os
import struct import struct
import pytest import pytest
from unittest.mock import Mock, ANY from unittest.mock import AsyncMock, Mock, ANY
from bumble.controller import Controller from bumble.controller import Controller
from bumble.gatt_client import CharacteristicProxy from bumble.gatt_client import CharacteristicProxy
from bumble.gatt_server import Server
from bumble.link import LocalLink from bumble.link import LocalLink
from bumble.device import Device, Peer from bumble.device import Device, Peer
from bumble.host import Host from bumble.host import Host
@@ -120,9 +119,9 @@ async def test_characteristic_encoding():
Characteristic.READABLE, Characteristic.READABLE,
123, 123,
) )
x = c.read_value(None) x = await c.read_value(None)
assert x == bytes([123]) assert x == bytes([123])
c.write_value(None, bytes([122])) await c.write_value(None, bytes([122]))
assert c.value == 122 assert c.value == 122
class FooProxy(CharacteristicProxy): class FooProxy(CharacteristicProxy):
@@ -152,7 +151,22 @@ async def test_characteristic_encoding():
bytes([123]), bytes([123]),
) )
service = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic]) async def async_read(connection):
return 0x05060708
async_characteristic = PackedCharacteristicAdapter(
Characteristic(
'2AB7E91B-43E8-4F73-AC3B-80C1683B47F9',
Characteristic.Properties.READ,
Characteristic.READABLE,
CharacteristicValue(read=async_read),
),
'>I',
)
service = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic, async_characteristic]
)
server.add_service(service) server.add_service(service)
await client.power_on() await client.power_on()
@@ -184,6 +198,13 @@ async def test_characteristic_encoding():
await async_barrier() await async_barrier()
assert characteristic.value == bytes([50]) assert characteristic.value == bytes([50])
c2 = peer.get_characteristics_by_uuid(async_characteristic.uuid)
assert len(c2) == 1
c2 = c2[0]
cd2 = PackedCharacteristicAdapter(c2, ">I")
cd2v = await cd2.read_value()
assert cd2v == 0x05060708
last_change = None last_change = None
def on_change(value): def on_change(value):
@@ -285,7 +306,8 @@ async def test_attribute_getters():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_CharacteristicAdapter(): @pytest.mark.asyncio
async def test_CharacteristicAdapter():
# Check that the CharacteristicAdapter base class is transparent # Check that the CharacteristicAdapter base class is transparent
v = bytes([1, 2, 3]) v = bytes([1, 2, 3])
c = Characteristic( c = Characteristic(
@@ -296,11 +318,11 @@ def test_CharacteristicAdapter():
) )
a = CharacteristicAdapter(c) a = CharacteristicAdapter(c)
value = a.read_value(None) value = await a.read_value(None)
assert value == v assert value == v
v = bytes([3, 4, 5]) v = bytes([3, 4, 5])
a.write_value(None, v) await a.write_value(None, v)
assert c.value == v assert c.value == v
# Simple delegated adapter # Simple delegated adapter
@@ -308,11 +330,11 @@ def test_CharacteristicAdapter():
c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x)) c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x))
) )
value = a.read_value(None) value = await a.read_value(None)
assert value == bytes(reversed(v)) assert value == bytes(reversed(v))
v = bytes([3, 4, 5]) v = bytes([3, 4, 5])
a.write_value(None, v) await a.write_value(None, v)
assert a.value == bytes(reversed(v)) assert a.value == bytes(reversed(v))
# Packed adapter with single element format # Packed adapter with single element format
@@ -321,10 +343,10 @@ def test_CharacteristicAdapter():
c.value = v c.value = v
a = PackedCharacteristicAdapter(c, '>H') a = PackedCharacteristicAdapter(c, '>H')
value = a.read_value(None) value = await a.read_value(None)
assert value == pv assert value == pv
c.value = None c.value = None
a.write_value(None, pv) await a.write_value(None, pv)
assert a.value == v assert a.value == v
# Packed adapter with multi-element format # Packed adapter with multi-element format
@@ -334,10 +356,10 @@ def test_CharacteristicAdapter():
c.value = (v1, v2) c.value = (v1, v2)
a = PackedCharacteristicAdapter(c, '>HH') a = PackedCharacteristicAdapter(c, '>HH')
value = a.read_value(None) value = await a.read_value(None)
assert value == pv assert value == pv
c.value = None c.value = None
a.write_value(None, pv) await a.write_value(None, pv)
assert a.value == (v1, v2) assert a.value == (v1, v2)
# Mapped adapter # Mapped adapter
@@ -348,10 +370,10 @@ def test_CharacteristicAdapter():
c.value = mapped c.value = mapped
a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2')) a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2'))
value = a.read_value(None) value = await a.read_value(None)
assert value == pv assert value == pv
c.value = None c.value = None
a.write_value(None, pv) await a.write_value(None, pv)
assert a.value == mapped assert a.value == mapped
# UTF-8 adapter # UTF-8 adapter
@@ -360,27 +382,49 @@ def test_CharacteristicAdapter():
c.value = v c.value = v
a = UTF8CharacteristicAdapter(c) a = UTF8CharacteristicAdapter(c)
value = a.read_value(None) value = await a.read_value(None)
assert value == ev assert value == ev
c.value = None c.value = None
a.write_value(None, ev) await a.write_value(None, ev)
assert a.value == v assert a.value == v
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
def test_CharacteristicValue(): @pytest.mark.asyncio
async def test_CharacteristicValue():
b = bytes([1, 2, 3]) b = bytes([1, 2, 3])
c = CharacteristicValue(read=lambda _: b)
x = c.read(None) async def read_value(connection):
return b
c = CharacteristicValue(read=read_value)
x = await c.read(None)
assert x == b assert x == b
result = [] m = Mock()
c = CharacteristicValue( c = CharacteristicValue(write=m)
write=lambda connection, value: result.append((connection, value))
)
z = object() z = object()
c.write(z, b) c.write(z, b)
assert result == [(z, b)] m.assert_called_once_with(z, b)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_CharacteristicValue_async():
b = bytes([1, 2, 3])
async def read_value(connection):
return b
c = CharacteristicValue(read=read_value)
x = await c.read(None)
assert x == b
m = AsyncMock()
c = CharacteristicValue(write=m)
z = object()
await c.write(z, b)
m.assert_called_once_with(z, b)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -961,12 +1005,18 @@ Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
async def async_main(): async def async_main():
test_UUID()
test_ATT_Error_Response()
test_ATT_Read_By_Group_Type_Request()
await test_read_write() await test_read_write()
await test_read_write2() await test_read_write2()
await test_subscribe_notify() await test_subscribe_notify()
await test_unsubscribe() await test_unsubscribe()
await test_characteristic_encoding() await test_characteristic_encoding()
await test_mtu_exchange() await test_mtu_exchange()
await test_CharacteristicValue()
await test_CharacteristicValue_async()
await test_CharacteristicAdapter()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -1105,9 +1155,4 @@ def test_get_attribute_group():
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
test_UUID()
test_ATT_Error_Response()
test_ATT_Read_By_Group_Type_Request()
test_CharacteristicValue()
test_CharacteristicAdapter()
asyncio.run(async_main()) asyncio.run(async_main())