forked from auracaster/bumble_mirror
Compare commits
97 Commits
gbg/androi
...
gbg/hci-la
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f377c024b | ||
|
|
aeeff18428 | ||
|
|
e2fec67bd9 | ||
|
|
88cb3b2a4d | ||
|
|
9ebb03be46 | ||
|
|
80d84af76c | ||
|
|
8f4721758f | ||
|
|
8864af4acd | ||
|
|
8980fb8cc7 | ||
|
|
2c5f3472a9 | ||
|
|
f18277ac78 | ||
|
|
09e5ea5dec | ||
|
|
6810865670 | ||
|
|
3e9e06a02c | ||
|
|
ccd12f6591 | ||
|
|
f9a7843f7e | ||
|
|
210c334db7 | ||
|
|
f297cdfcce | ||
|
|
5b536d00ab | ||
|
|
b4af46ebd5 | ||
|
|
c08da3193e | ||
|
|
f2925ca647 | ||
|
|
fd4d68e5c0 | ||
|
|
5d83deffa4 | ||
|
|
2878cca478 | ||
|
|
53934716db | ||
|
|
d885d45824 | ||
|
|
b90d0f8710 | ||
|
|
8ccfc90fe6 | ||
|
|
92aa7e9e2a | ||
|
|
afc6d19e04 | ||
|
|
c05f073b33 | ||
|
|
2b4c2a22f4 | ||
|
|
47fe93a148 | ||
|
|
6139ca8045 | ||
|
|
87c76a4a0e | ||
|
|
f7b66db873 | ||
|
|
0b314bd7f7 | ||
|
|
9da2e32ad7 | ||
|
|
93c0875740 | ||
|
|
a286700239 | ||
|
|
98ed772e8a | ||
|
|
f0b55a4f97 | ||
|
|
b74503d345 | ||
|
|
f911163e49 | ||
|
|
b083cc99ad | ||
|
|
d35643524e | ||
|
|
62a8ced447 | ||
|
|
085f163c92 | ||
|
|
81a6b1e097 | ||
|
|
dd090c9e6b | ||
|
|
11faa48422 | ||
|
|
55596176c2 | ||
|
|
4d6822d312 | ||
|
|
985c365e6d | ||
|
|
af57762227 | ||
|
|
3575f9030e | ||
|
|
698d947d85 | ||
|
|
ff6528d2bf | ||
|
|
72ac75a98d | ||
|
|
5e3ecb74e4 | ||
|
|
c59be293c8 | ||
|
|
88b4cbdf1a | ||
|
|
d6afbc6f4e | ||
|
|
fc90de3e7b | ||
|
|
847c2ef114 | ||
|
|
a0bf0c1f4d | ||
|
|
6d22ed80ec | ||
|
|
843466c822 | ||
|
|
3adcc8be09 | ||
|
|
c853d56302 | ||
|
|
dc97be5b35 | ||
|
|
73dbdfff9f | ||
|
|
dff14e1258 | ||
|
|
10a3833893 | ||
|
|
ffb3eca68b | ||
|
|
7eb493990f | ||
|
|
403a13e4c6 | ||
|
|
ad0f035df5 | ||
|
|
07f71fc895 | ||
|
|
f47b9178ad | ||
|
|
4f399249bd | ||
|
|
9324237828 | ||
|
|
d1033c018a | ||
|
|
0f29052ade | ||
|
|
0578e84586 | ||
|
|
6ab41c466f | ||
|
|
98a1093ebf | ||
|
|
caf04373f3 | ||
|
|
d4e8526766 | ||
|
|
515b83a8c7 | ||
|
|
dc18595c8a | ||
|
|
488bcfe9c6 | ||
|
|
d6cefdff8e | ||
|
|
dc410b14c4 | ||
|
|
4c49ef9403 | ||
|
|
ba85dcbda5 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -10,3 +10,5 @@ __pycache__
|
||||
bumble/_version.py
|
||||
.vscode/launch.json
|
||||
/.idea
|
||||
venv/
|
||||
.venv/
|
||||
|
||||
1
.vscode/settings.json
vendored
1
.vscode/settings.json
vendored
@@ -22,6 +22,7 @@
|
||||
"cmac",
|
||||
"CONNECTIONLESS",
|
||||
"csip",
|
||||
"csis",
|
||||
"csrcs",
|
||||
"CVSD",
|
||||
"datagram",
|
||||
|
||||
526
apps/bench.py
526
apps/bench.py
File diff suppressed because it is too large
Load Diff
63
apps/ble_rpa_tool.py
Normal file
63
apps/ble_rpa_tool.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Copyright 2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import click
|
||||
from bumble.colors import color
|
||||
from bumble.hci import Address
|
||||
from bumble.helpers import generate_irk, verify_rpa_with_irk
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
'''
|
||||
This is a tool for generating IRK, RPA,
|
||||
and verifying IRK/RPA pairs
|
||||
'''
|
||||
|
||||
|
||||
@click.command()
|
||||
def gen_irk() -> None:
|
||||
print(generate_irk().hex())
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("irk", type=str)
|
||||
def gen_rpa(irk: str) -> None:
|
||||
irk_bytes = bytes.fromhex(irk)
|
||||
rpa = Address.generate_private_address(irk_bytes)
|
||||
print(rpa.to_string(with_type_qualifier=False))
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("irk", type=str)
|
||||
@click.argument("rpa", type=str)
|
||||
def verify_rpa(irk: str, rpa: str) -> None:
|
||||
address = Address(rpa)
|
||||
irk_bytes = bytes.fromhex(irk)
|
||||
if verify_rpa_with_irk(address, irk_bytes):
|
||||
print(color("Verified", "green"))
|
||||
else:
|
||||
print(color("Not Verified", "red"))
|
||||
|
||||
|
||||
def main():
|
||||
cli.add_command(gen_irk)
|
||||
cli.add_command(gen_rpa)
|
||||
cli.add_command(verify_rpa)
|
||||
cli()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -777,7 +777,7 @@ class ConsoleApp:
|
||||
if not service:
|
||||
continue
|
||||
values = [
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
for connection in self.device.connections.values()
|
||||
]
|
||||
if not values:
|
||||
@@ -796,11 +796,11 @@ class ConsoleApp:
|
||||
if not characteristic:
|
||||
continue
|
||||
values = [
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
for connection in self.device.connections.values()
|
||||
]
|
||||
if not values:
|
||||
values = [attribute.read_value(None)]
|
||||
values = [await attribute.read_value(None)]
|
||||
|
||||
# TODO: future optimization: convert CCCD value to human readable string
|
||||
|
||||
@@ -944,7 +944,7 @@ class ConsoleApp:
|
||||
|
||||
# send data to any subscribers
|
||||
if isinstance(attribute, Characteristic):
|
||||
attribute.write_value(None, value)
|
||||
await attribute.write_value(None, value)
|
||||
if attribute.has_properties(Characteristic.NOTIFY):
|
||||
await self.device.gatt_server.notify_subscribers(attribute)
|
||||
if attribute.has_properties(Characteristic.INDICATE):
|
||||
|
||||
@@ -18,9 +18,11 @@
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import click
|
||||
from bumble.company_ids import COMPANY_IDENTIFIERS
|
||||
import time
|
||||
|
||||
import click
|
||||
|
||||
from bumble.company_ids import COMPANY_IDENTIFIERS
|
||||
from bumble.colors import color
|
||||
from bumble.core import name_or_number
|
||||
from bumble.hci import (
|
||||
@@ -32,10 +34,14 @@ from bumble.hci import (
|
||||
HCI_Command,
|
||||
HCI_Command_Complete_Event,
|
||||
HCI_Command_Status_Event,
|
||||
HCI_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_Read_Buffer_Size_Command,
|
||||
HCI_READ_BD_ADDR_COMMAND,
|
||||
HCI_Read_BD_ADDR_Command,
|
||||
HCI_READ_LOCAL_NAME_COMMAND,
|
||||
HCI_Read_Local_Name_Command,
|
||||
HCI_LE_READ_BUFFER_SIZE_COMMAND,
|
||||
HCI_LE_Read_Buffer_Size_Command,
|
||||
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Maximum_Data_Length_Command,
|
||||
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
|
||||
@@ -44,6 +50,7 @@ from bumble.hci import (
|
||||
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
|
||||
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
|
||||
HCI_LE_Read_Suggested_Default_Data_Length_Command,
|
||||
HCI_Read_Local_Version_Information_Command,
|
||||
)
|
||||
from bumble.host import Host
|
||||
from bumble.transport import open_transport_or_link
|
||||
@@ -59,7 +66,7 @@ def command_succeeded(response):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_classic_info(host):
|
||||
async def get_classic_info(host: Host) -> None:
|
||||
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
|
||||
response = await host.send_command(HCI_Read_BD_ADDR_Command())
|
||||
if command_succeeded(response):
|
||||
@@ -80,7 +87,7 @@ async def get_classic_info(host):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_le_info(host):
|
||||
async def get_le_info(host: Host) -> None:
|
||||
print()
|
||||
|
||||
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
|
||||
@@ -137,7 +144,32 @@ async def get_le_info(host):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(transport):
|
||||
async def get_acl_flow_control_info(host: Host) -> None:
|
||||
print()
|
||||
|
||||
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
print(
|
||||
color('ACL Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.hc_total_num_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
|
||||
)
|
||||
|
||||
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await host.send_command(
|
||||
HCI_LE_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
print(
|
||||
color('LE ACL Flow Control:', 'yellow'),
|
||||
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
|
||||
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main(latency_probes, transport):
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
|
||||
print('<<< connected')
|
||||
@@ -145,6 +177,23 @@ async def async_main(transport):
|
||||
host = Host(hci_source, hci_sink)
|
||||
await host.reset()
|
||||
|
||||
# Measure the latency if requested
|
||||
latencies = []
|
||||
if latency_probes:
|
||||
for _ in range(latency_probes):
|
||||
start = time.time()
|
||||
await host.send_command(HCI_Read_Local_Version_Information_Command())
|
||||
latencies.append(1000 * (time.time() - start))
|
||||
print(
|
||||
color('HCI Command Latency:', 'yellow'),
|
||||
(
|
||||
f'min={min(latencies):.2f}, '
|
||||
f'max={max(latencies):.2f}, '
|
||||
f'average={sum(latencies)/len(latencies):.2f}'
|
||||
),
|
||||
'\n',
|
||||
)
|
||||
|
||||
# Print version
|
||||
print(color('Version:', 'yellow'))
|
||||
print(
|
||||
@@ -168,6 +217,9 @@ async def async_main(transport):
|
||||
# Get the LE info
|
||||
await get_le_info(host)
|
||||
|
||||
# Print the ACL flow control info
|
||||
await get_acl_flow_control_info(host)
|
||||
|
||||
# Print the list of commands supported by the controller
|
||||
print()
|
||||
print(color('Supported Commands:', 'yellow'))
|
||||
@@ -177,10 +229,16 @@ async def async_main(transport):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.command()
|
||||
@click.option(
|
||||
'--latency-probes',
|
||||
metavar='N',
|
||||
type=int,
|
||||
help='Send N commands to measure HCI transport latency statistics',
|
||||
)
|
||||
@click.argument('transport')
|
||||
def main(transport):
|
||||
def main(latency_probes, transport):
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
asyncio.run(async_main(transport))
|
||||
asyncio.run(async_main(latency_probes, transport))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -49,14 +49,16 @@ class ServerBridge:
|
||||
self.tcp_port = tcp_port
|
||||
|
||||
async def start(self, device: Device) -> None:
|
||||
# Listen for incoming L2CAP CoC connections
|
||||
# Listen for incoming L2CAP channel connections
|
||||
device.create_l2cap_server(
|
||||
spec=l2cap.LeCreditBasedChannelSpec(
|
||||
psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits
|
||||
),
|
||||
handler=self.on_coc,
|
||||
handler=self.on_channel,
|
||||
)
|
||||
print(
|
||||
color(f'### Listening for channel connection on PSM {self.psm}', 'yellow')
|
||||
)
|
||||
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
|
||||
|
||||
def on_ble_connection(connection):
|
||||
def on_ble_disconnection(reason):
|
||||
@@ -73,7 +75,7 @@ class ServerBridge:
|
||||
await device.start_advertising(auto_restart=True)
|
||||
|
||||
# Called when a new L2CAP connection is established
|
||||
def on_coc(self, l2cap_channel):
|
||||
def on_channel(self, l2cap_channel):
|
||||
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
|
||||
|
||||
class Pipe:
|
||||
@@ -83,7 +85,7 @@ class ServerBridge:
|
||||
self.l2cap_channel = l2cap_channel
|
||||
|
||||
l2cap_channel.on('close', self.on_l2cap_close)
|
||||
l2cap_channel.sink = self.on_coc_sdu
|
||||
l2cap_channel.sink = self.on_channel_sdu
|
||||
|
||||
async def connect_to_tcp(self):
|
||||
# Connect to the TCP server
|
||||
@@ -128,7 +130,7 @@ class ServerBridge:
|
||||
if self.tcp_transport is not None:
|
||||
self.tcp_transport.close()
|
||||
|
||||
def on_coc_sdu(self, sdu):
|
||||
def on_channel_sdu(self, sdu):
|
||||
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
|
||||
if self.tcp_transport is None:
|
||||
print(color('!!! TCP socket not open, dropping', 'red'))
|
||||
@@ -183,7 +185,7 @@ class ClientBridge:
|
||||
peer_name = writer.get_extra_info('peer_name')
|
||||
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
|
||||
|
||||
def on_coc_sdu(sdu):
|
||||
def on_channel_sdu(sdu):
|
||||
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
|
||||
l2cap_to_tcp_pipe.write(sdu)
|
||||
|
||||
@@ -209,7 +211,7 @@ class ClientBridge:
|
||||
writer.close()
|
||||
return
|
||||
|
||||
l2cap_channel.sink = on_coc_sdu
|
||||
l2cap_channel.sink = on_channel_sdu
|
||||
l2cap_channel.on('close', on_l2cap_close)
|
||||
|
||||
# Start a flow control pipe from L2CAP to TCP
|
||||
@@ -274,23 +276,29 @@ async def run(device_config, hci_transport, bridge):
|
||||
@click.pass_context
|
||||
@click.option('--device-config', help='Device configuration file', required=True)
|
||||
@click.option('--hci-transport', help='HCI transport', required=True)
|
||||
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
|
||||
@click.option('--psm', help='PSM for L2CAP', type=int, default=1234)
|
||||
@click.option(
|
||||
'--l2cap-coc-max-credits',
|
||||
help='Maximum L2CAP CoC Credits',
|
||||
'--l2cap-max-credits',
|
||||
help='Maximum L2CAP Credits',
|
||||
type=click.IntRange(1, 65535),
|
||||
default=128,
|
||||
)
|
||||
@click.option(
|
||||
'--l2cap-coc-mtu',
|
||||
help='L2CAP CoC MTU',
|
||||
type=click.IntRange(23, 65535),
|
||||
default=1022,
|
||||
'--l2cap-mtu',
|
||||
help='L2CAP MTU',
|
||||
type=click.IntRange(
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU,
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU,
|
||||
),
|
||||
default=1024,
|
||||
)
|
||||
@click.option(
|
||||
'--l2cap-coc-mps',
|
||||
help='L2CAP CoC MPS',
|
||||
type=click.IntRange(23, 65533),
|
||||
'--l2cap-mps',
|
||||
help='L2CAP MPS',
|
||||
type=click.IntRange(
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS,
|
||||
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS,
|
||||
),
|
||||
default=1024,
|
||||
)
|
||||
def cli(
|
||||
@@ -298,17 +306,17 @@ def cli(
|
||||
device_config,
|
||||
hci_transport,
|
||||
psm,
|
||||
l2cap_coc_max_credits,
|
||||
l2cap_coc_mtu,
|
||||
l2cap_coc_mps,
|
||||
l2cap_max_credits,
|
||||
l2cap_mtu,
|
||||
l2cap_mps,
|
||||
):
|
||||
context.ensure_object(dict)
|
||||
context.obj['device_config'] = device_config
|
||||
context.obj['hci_transport'] = hci_transport
|
||||
context.obj['psm'] = psm
|
||||
context.obj['max_credits'] = l2cap_coc_max_credits
|
||||
context.obj['mtu'] = l2cap_coc_mtu
|
||||
context.obj['mps'] = l2cap_coc_mps
|
||||
context.obj['max_credits'] = l2cap_max_credits
|
||||
context.obj['mtu'] = l2cap_mtu
|
||||
context.obj['mps'] = l2cap_mps
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
14
apps/pair.py
14
apps/pair.py
@@ -52,11 +52,13 @@ from bumble.att import (
|
||||
class Waiter:
|
||||
instance = None
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, linger=False):
|
||||
self.done = asyncio.get_running_loop().create_future()
|
||||
self.linger = linger
|
||||
|
||||
def terminate(self):
|
||||
self.done.set_result(None)
|
||||
if not self.linger:
|
||||
self.done.set_result(None)
|
||||
|
||||
async def wait_until_terminated(self):
|
||||
return await self.done
|
||||
@@ -302,7 +304,7 @@ async def pair(
|
||||
hci_transport,
|
||||
address_or_name,
|
||||
):
|
||||
Waiter.instance = Waiter()
|
||||
Waiter.instance = Waiter(linger=linger)
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
|
||||
@@ -396,7 +398,6 @@ async def pair(
|
||||
address_or_name,
|
||||
transport=BT_LE_TRANSPORT if mode == 'le' else BT_BR_EDR_TRANSPORT,
|
||||
)
|
||||
pairing_failure = False
|
||||
|
||||
if not request:
|
||||
try:
|
||||
@@ -405,11 +406,8 @@ async def pair(
|
||||
else:
|
||||
await connection.authenticate()
|
||||
except ProtocolError as error:
|
||||
pairing_failure = True
|
||||
print(color(f'Pairing failed: {error}', 'red'))
|
||||
|
||||
if not linger or pairing_failure:
|
||||
return
|
||||
else:
|
||||
if mode == 'le':
|
||||
# Advertise so that peers can find us and connect
|
||||
@@ -459,7 +457,7 @@ class LogHandler(logging.Handler):
|
||||
help='Enable CTKD',
|
||||
show_default=True,
|
||||
)
|
||||
@click.option('--linger', default=True, is_flag=True, help='Linger after pairing')
|
||||
@click.option('--linger', default=False, is_flag=True, help='Linger after pairing')
|
||||
@click.option(
|
||||
'--io',
|
||||
type=click.Choice(
|
||||
|
||||
@@ -25,9 +25,21 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import functools
|
||||
import inspect
|
||||
import struct
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from pyee import EventEmitter
|
||||
from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
|
||||
|
||||
from bumble.core import UUID, name_or_number, ProtocolError
|
||||
from bumble.hci import HCI_Object, key_with_value
|
||||
@@ -722,12 +734,38 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ConnectionValue(Protocol):
|
||||
def read(self, connection) -> bytes:
|
||||
...
|
||||
class AttributeValue:
|
||||
'''
|
||||
Attribute value where reading and/or writing is delegated to functions
|
||||
passed as arguments to the constructor.
|
||||
'''
|
||||
|
||||
def write(self, connection, value: bytes) -> None:
|
||||
...
|
||||
def __init__(
|
||||
self,
|
||||
read: Union[
|
||||
Callable[[Optional[Connection]], bytes],
|
||||
Callable[[Optional[Connection]], Awaitable[bytes]],
|
||||
None,
|
||||
] = None,
|
||||
write: Union[
|
||||
Callable[[Optional[Connection], bytes], None],
|
||||
Callable[[Optional[Connection], bytes], Awaitable[None]],
|
||||
None,
|
||||
] = None,
|
||||
):
|
||||
self._read = read
|
||||
self._write = write
|
||||
|
||||
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
|
||||
return self._read(connection) if self._read else b''
|
||||
|
||||
def write(
|
||||
self, connection: Optional[Connection], value: bytes
|
||||
) -> Union[Awaitable[None], None]:
|
||||
if self._write:
|
||||
return self._write(connection, value)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -770,13 +808,13 @@ class Attribute(EventEmitter):
|
||||
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
|
||||
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
|
||||
|
||||
value: Union[str, bytes, ConnectionValue]
|
||||
value: Union[bytes, AttributeValue]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attribute_type: Union[str, bytes, UUID],
|
||||
permissions: Union[str, Attribute.Permissions],
|
||||
value: Union[str, bytes, ConnectionValue] = b'',
|
||||
value: Union[str, bytes, AttributeValue] = b'',
|
||||
) -> None:
|
||||
EventEmitter.__init__(self)
|
||||
self.handle = 0
|
||||
@@ -806,7 +844,7 @@ class Attribute(EventEmitter):
|
||||
def decode_value(self, value_bytes: bytes) -> Any:
|
||||
return value_bytes
|
||||
|
||||
def read_value(self, connection: Optional[Connection]) -> bytes:
|
||||
async def read_value(self, connection: Optional[Connection]) -> bytes:
|
||||
if (
|
||||
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
|
||||
and connection is not None
|
||||
@@ -832,6 +870,8 @@ class Attribute(EventEmitter):
|
||||
if hasattr(self.value, 'read'):
|
||||
try:
|
||||
value = self.value.read(connection)
|
||||
if inspect.isawaitable(value):
|
||||
value = await value
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
@@ -841,7 +881,7 @@ class Attribute(EventEmitter):
|
||||
|
||||
return self.encode_value(value)
|
||||
|
||||
def write_value(self, connection: Connection, value_bytes: bytes) -> None:
|
||||
async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
|
||||
if (
|
||||
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
|
||||
) and not connection.encryption:
|
||||
@@ -864,7 +904,9 @@ class Attribute(EventEmitter):
|
||||
|
||||
if hasattr(self.value, 'write'):
|
||||
try:
|
||||
self.value.write(connection, value) # pylint: disable=not-callable
|
||||
result = self.value.write(connection, value)
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
except ATT_Error as error:
|
||||
raise ATT_Error(
|
||||
error_code=error.error_code, att_handle=self.handle
|
||||
|
||||
@@ -134,12 +134,14 @@ class Controller:
|
||||
'0000000060000000'
|
||||
) # BR/EDR Not Supported, LE Supported (Controller)
|
||||
self.manufacturer_name = 0xFFFF
|
||||
self.hc_data_packet_length = 27
|
||||
self.hc_total_num_data_packets = 64
|
||||
self.hc_le_data_packet_length = 27
|
||||
self.hc_total_num_le_data_packets = 64
|
||||
self.event_mask = 0
|
||||
self.event_mask_page_2 = 0
|
||||
self.supported_commands = bytes.fromhex(
|
||||
'2000800000c000000000e40000002822000000000000040000f7ffff7f000000'
|
||||
'2000800000c000000000e4000000a822000000000000040000f7ffff7f000000'
|
||||
'30f0f9ff01008004000000000000000000000000000000000000000000000000'
|
||||
)
|
||||
self.le_event_mask = 0
|
||||
@@ -914,6 +916,19 @@ class Controller:
|
||||
'''
|
||||
return bytes([HCI_SUCCESS]) + self.lmp_features
|
||||
|
||||
def on_hci_read_buffer_size_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.4.5 Read Buffer Size Command
|
||||
'''
|
||||
return struct.pack(
|
||||
'<BHBHH',
|
||||
HCI_SUCCESS,
|
||||
self.hc_data_packet_length,
|
||||
0,
|
||||
self.hc_total_num_data_packets,
|
||||
0,
|
||||
)
|
||||
|
||||
def on_hci_read_bd_addr_command(self, _command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.4.6 Read BD_ADDR Command
|
||||
@@ -1263,3 +1278,15 @@ class Controller:
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.74 LE Read Transmit Power Command
|
||||
'''
|
||||
return struct.pack('<BBB', HCI_SUCCESS, 0, 0)
|
||||
|
||||
def on_hci_le_setup_iso_data_path_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.109 LE Setup ISO Data Path Command
|
||||
'''
|
||||
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
|
||||
|
||||
def on_hci_le_remove_iso_data_path_command(self, command):
|
||||
'''
|
||||
See Bluetooth spec Vol 4, Part E - 7.8.110 LE Remove ISO Data Path Command
|
||||
'''
|
||||
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
|
||||
|
||||
@@ -100,6 +100,16 @@ class EccKey:
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def generate_prand() -> bytes:
|
||||
'''Generates random 3 bytes, with the 2 most significant bits of 0b01.
|
||||
|
||||
See Bluetooth spec, Vol 6, Part E - Table 1.2.
|
||||
'''
|
||||
prand_bytes = secrets.token_bytes(6)
|
||||
return prand_bytes[:2] + bytes([(prand_bytes[2] & 0b01111111) | 0b01000000])
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def xor(x: bytes, y: bytes) -> bytes:
|
||||
assert len(x) == len(y)
|
||||
|
||||
338
bumble/device.py
338
bumble/device.py
@@ -437,6 +437,38 @@ class AdvertisingType(IntEnum):
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class LegacyAdvertiser:
|
||||
device: Device
|
||||
advertising_type: AdvertisingType
|
||||
own_address_type: OwnAddressType
|
||||
auto_restart: bool
|
||||
advertising_data: Optional[bytes]
|
||||
scan_response_data: Optional[bytes]
|
||||
|
||||
async def stop(self) -> None:
|
||||
await self.device.stop_legacy_advertising()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@dataclass
|
||||
class ExtendedAdvertiser(CompositeEventEmitter):
|
||||
device: Device
|
||||
handle: int
|
||||
advertising_properties: HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties
|
||||
own_address_type: OwnAddressType
|
||||
auto_restart: bool
|
||||
advertising_data: Optional[bytes]
|
||||
scan_response_data: Optional[bytes]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
async def stop(self) -> None:
|
||||
await self.device.stop_extended_advertising(self.handle)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class LePhyOptions:
|
||||
# Coded PHY preference
|
||||
@@ -658,6 +690,9 @@ class Connection(CompositeEventEmitter):
|
||||
gatt_client: gatt_client.Client
|
||||
pairing_peer_io_capability: Optional[int]
|
||||
pairing_peer_authentication_requirements: Optional[int]
|
||||
advertiser_after_disconnection: Union[
|
||||
LegacyAdvertiser, ExtendedAdvertiser, None
|
||||
] = None
|
||||
|
||||
@composite_listener
|
||||
class Listener:
|
||||
@@ -1063,7 +1098,8 @@ class Device(CompositeEventEmitter):
|
||||
]
|
||||
advertisement_accumulators: Dict[Address, AdvertisementDataAccumulator]
|
||||
config: DeviceConfiguration
|
||||
extended_advertising_handles: Set[int]
|
||||
legacy_advertiser: Optional[LegacyAdvertiser]
|
||||
extended_advertisers: Dict[int, ExtendedAdvertiser]
|
||||
sco_links: Dict[int, ScoLink]
|
||||
cis_links: Dict[int, CisLink]
|
||||
_pending_cis: Dict[int, Tuple[int, int]]
|
||||
@@ -1141,10 +1177,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
self._host = None
|
||||
self.powered_on = False
|
||||
self.advertising = False
|
||||
self.advertising_type = None
|
||||
self.auto_restart_inquiry = True
|
||||
self.auto_restart_advertising = False
|
||||
self.command_timeout = 10 # seconds
|
||||
self.gatt_server = gatt_server.Server(self)
|
||||
self.sdp_server = sdp.Server(self)
|
||||
@@ -1168,10 +1201,10 @@ class Device(CompositeEventEmitter):
|
||||
self.classic_pending_accepts = {
|
||||
Address.ANY: []
|
||||
} # Futures, by BD address OR [Futures] for Address.ANY
|
||||
self.extended_advertising_handles = set()
|
||||
self.legacy_advertiser = None
|
||||
self.extended_advertisers = {}
|
||||
|
||||
# Own address type cache
|
||||
self.advertising_own_address_type = None
|
||||
self.connect_own_address_type = None
|
||||
|
||||
# Use the initial config or a default
|
||||
@@ -1432,7 +1465,7 @@ class Device(CompositeEventEmitter):
|
||||
await self.host.reset()
|
||||
|
||||
# Try to get the public address from the controller
|
||||
response = await self.send_command(HCI_Read_BD_ADDR_Command()) # type: ignore[call-arg]
|
||||
response = await self.send_command(HCI_Read_BD_ADDR_Command())
|
||||
if response.return_parameters.status == HCI_SUCCESS:
|
||||
logger.debug(
|
||||
color(f'BD_ADDR: {response.return_parameters.bd_addr}', 'yellow')
|
||||
@@ -1455,7 +1488,7 @@ class Device(CompositeEventEmitter):
|
||||
HCI_Write_LE_Host_Support_Command(
|
||||
le_supported_host=int(self.le_enabled),
|
||||
simultaneous_le_host=int(self.le_simultaneous_enabled),
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
)
|
||||
|
||||
if self.le_enabled:
|
||||
@@ -1465,7 +1498,7 @@ class Device(CompositeEventEmitter):
|
||||
if self.host.supports_command(HCI_LE_RAND_COMMAND):
|
||||
# Get 8 random bytes
|
||||
response = await self.send_command(
|
||||
HCI_LE_Rand_Command(), check_result=True # type: ignore[call-arg]
|
||||
HCI_LE_Rand_Command(), check_result=True
|
||||
)
|
||||
|
||||
# Ensure the address bytes can be a static random address
|
||||
@@ -1486,7 +1519,7 @@ class Device(CompositeEventEmitter):
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Random_Address_Command(
|
||||
random_address=self.random_address
|
||||
), # type: ignore[call-arg]
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1499,12 +1532,12 @@ class Device(CompositeEventEmitter):
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Address_Resolution_Enable_Command(
|
||||
address_resolution_enable=1
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
)
|
||||
|
||||
if self.cis_enabled:
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Host_Feature_Command( # type: ignore[call-arg]
|
||||
HCI_LE_Set_Host_Feature_Command(
|
||||
bit_number=(
|
||||
HCI_CONNECTED_ISOCHRONOUS_STREAM_LE_SUPPORTED_FEATURE
|
||||
),
|
||||
@@ -1514,20 +1547,20 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
if self.classic_enabled:
|
||||
await self.send_command(
|
||||
HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8')) # type: ignore[call-arg]
|
||||
HCI_Write_Local_Name_Command(local_name=self.name.encode('utf8'))
|
||||
)
|
||||
await self.send_command(
|
||||
HCI_Write_Class_Of_Device_Command(class_of_device=self.class_of_device) # type: ignore[call-arg]
|
||||
HCI_Write_Class_Of_Device_Command(class_of_device=self.class_of_device)
|
||||
)
|
||||
await self.send_command(
|
||||
HCI_Write_Simple_Pairing_Mode_Command(
|
||||
simple_pairing_mode=int(self.classic_ssp_enabled)
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
)
|
||||
await self.send_command(
|
||||
HCI_Write_Secure_Connections_Host_Support_Command(
|
||||
secure_connections_host_support=int(self.classic_sc_enabled)
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
)
|
||||
await self.set_connectable(self.connectable)
|
||||
await self.set_discoverable(self.discoverable)
|
||||
@@ -1551,7 +1584,7 @@ class Device(CompositeEventEmitter):
|
||||
self.address_resolver = smp.AddressResolver(resolving_keys)
|
||||
|
||||
if self.address_resolution_offload:
|
||||
await self.send_command(HCI_LE_Clear_Resolving_List_Command()) # type: ignore[call-arg]
|
||||
await self.send_command(HCI_LE_Clear_Resolving_List_Command())
|
||||
|
||||
for irk, address in resolving_keys:
|
||||
await self.send_command(
|
||||
@@ -1560,7 +1593,7 @@ class Device(CompositeEventEmitter):
|
||||
peer_identity_address=address,
|
||||
peer_irk=irk,
|
||||
local_irk=self.irk,
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
)
|
||||
|
||||
def supports_le_feature(self, feature):
|
||||
@@ -1579,6 +1612,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
return self.host.supports_le_feature(feature_map[phy])
|
||||
|
||||
@deprecated("Please use start_legacy_advertising.")
|
||||
async def start_advertising(
|
||||
self,
|
||||
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
|
||||
@@ -1586,16 +1620,50 @@ class Device(CompositeEventEmitter):
|
||||
own_address_type: int = OwnAddressType.RANDOM,
|
||||
auto_restart: bool = False,
|
||||
) -> None:
|
||||
await self.start_legacy_advertising(
|
||||
advertising_type=advertising_type,
|
||||
target=target,
|
||||
own_address_type=OwnAddressType(own_address_type),
|
||||
auto_restart=auto_restart,
|
||||
)
|
||||
|
||||
async def start_legacy_advertising(
|
||||
self,
|
||||
advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
|
||||
target: Optional[Address] = None,
|
||||
own_address_type: OwnAddressType = OwnAddressType.RANDOM,
|
||||
auto_restart: bool = False,
|
||||
advertising_data: Optional[bytes] = None,
|
||||
scan_response_data: Optional[bytes] = None,
|
||||
) -> LegacyAdvertiser:
|
||||
"""Starts an legacy advertisement.
|
||||
|
||||
Args:
|
||||
advertising_type: Advertising type passed to HCI_LE_Set_Advertising_Parameters_Command.
|
||||
target: Directed advertising target. Directed type should be set in advertising_type arg.
|
||||
own_address_type: own address type to use in the advertising.
|
||||
auto_restart: whether the advertisement will be restarted after disconnection.
|
||||
scan_response_data: raw scan response.
|
||||
advertising_data: raw advertising data.
|
||||
|
||||
Returns:
|
||||
LegacyAdvertiser object containing the metadata of advertisement.
|
||||
"""
|
||||
if self.extended_advertisers:
|
||||
logger.warning(
|
||||
'Trying to start Legacy and Extended Advertising at the same time!'
|
||||
)
|
||||
|
||||
# If we're advertising, stop first
|
||||
if self.advertising:
|
||||
if self.legacy_advertiser:
|
||||
await self.stop_advertising()
|
||||
|
||||
# Set/update the advertising data if the advertising type allows it
|
||||
if advertising_type.has_data:
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Advertising_Data_Command(
|
||||
advertising_data=self.advertising_data
|
||||
), # type: ignore[call-arg]
|
||||
advertising_data=advertising_data or self.advertising_data or b''
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1603,8 +1671,10 @@ class Device(CompositeEventEmitter):
|
||||
if advertising_type.is_scannable:
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Scan_Response_Data_Command(
|
||||
scan_response_data=self.scan_response_data
|
||||
), # type: ignore[call-arg]
|
||||
scan_response_data=scan_response_data
|
||||
or self.scan_response_data
|
||||
or b''
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1630,55 +1700,67 @@ class Device(CompositeEventEmitter):
|
||||
peer_address=peer_address,
|
||||
advertising_channel_map=7,
|
||||
advertising_filter_policy=0,
|
||||
), # type: ignore[call-arg]
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
# Enable advertising
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=1), # type: ignore[call-arg]
|
||||
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=1),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
self.advertising_type = advertising_type
|
||||
self.advertising_own_address_type = own_address_type
|
||||
self.advertising = True
|
||||
self.auto_restart_advertising = auto_restart
|
||||
self.legacy_advertiser = LegacyAdvertiser(
|
||||
device=self,
|
||||
advertising_type=advertising_type,
|
||||
own_address_type=own_address_type,
|
||||
auto_restart=auto_restart,
|
||||
advertising_data=advertising_data,
|
||||
scan_response_data=scan_response_data,
|
||||
)
|
||||
return self.legacy_advertiser
|
||||
|
||||
@deprecated("Please use stop_legacy_advertising.")
|
||||
async def stop_advertising(self) -> None:
|
||||
await self.stop_legacy_advertising()
|
||||
|
||||
async def stop_legacy_advertising(self) -> None:
|
||||
# Disable advertising
|
||||
if self.advertising:
|
||||
if self.legacy_advertiser:
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0), # type: ignore[call-arg]
|
||||
HCI_LE_Set_Advertising_Enable_Command(advertising_enable=0),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
self.advertising_type = None
|
||||
self.advertising_own_address_type = None
|
||||
self.advertising = False
|
||||
self.auto_restart_advertising = False
|
||||
self.legacy_advertiser = None
|
||||
|
||||
@experimental('Extended Advertising is still experimental - Might be changed soon.')
|
||||
async def start_extended_advertising(
|
||||
self,
|
||||
advertising_properties: HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties = HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING,
|
||||
target: Address = Address.ANY,
|
||||
own_address_type: int = OwnAddressType.RANDOM,
|
||||
scan_response: Optional[bytes] = None,
|
||||
own_address_type: OwnAddressType = OwnAddressType.RANDOM,
|
||||
auto_restart: bool = True,
|
||||
advertising_data: Optional[bytes] = None,
|
||||
) -> int:
|
||||
scan_response_data: Optional[bytes] = None,
|
||||
) -> ExtendedAdvertiser:
|
||||
"""Starts an extended advertising set.
|
||||
|
||||
Args:
|
||||
advertising_properties: Properties to pass in HCI_LE_Set_Extended_Advertising_Parameters_Command
|
||||
target: Directed advertising target. Directed property should be set in advertising_properties arg.
|
||||
own_address_type: own address type to use in the advertising.
|
||||
scan_response: raw scan response. When a non-none value is set, HCI_LE_Set_Extended_Scan_Response_Data_Command will be sent.
|
||||
auto_restart: whether the advertisement will be restarted after disconnection.
|
||||
advertising_data: raw advertising data. When a non-none value is set, HCI_LE_Set_Advertising_Set_Random_Address_Command will be sent.
|
||||
scan_response_data: raw scan response. When a non-none value is set, HCI_LE_Set_Extended_Scan_Response_Data_Command will be sent.
|
||||
|
||||
Returns:
|
||||
Handle of the new advertising set.
|
||||
ExtendedAdvertiser object containing the metadata of advertisement.
|
||||
"""
|
||||
if self.legacy_advertiser:
|
||||
logger.warning(
|
||||
'Trying to start Legacy and Extended Advertising at the same time!'
|
||||
)
|
||||
|
||||
adv_handle = -1
|
||||
# Find a free handle
|
||||
@@ -1686,7 +1768,7 @@ class Device(CompositeEventEmitter):
|
||||
DEVICE_MIN_EXTENDED_ADVERTISING_SET_HANDLE,
|
||||
DEVICE_MAX_EXTENDED_ADVERTISING_SET_HANDLE + 1,
|
||||
):
|
||||
if i not in self.extended_advertising_handles:
|
||||
if i not in self.extended_advertisers:
|
||||
adv_handle = i
|
||||
break
|
||||
|
||||
@@ -1716,7 +1798,7 @@ class Device(CompositeEventEmitter):
|
||||
secondary_advertising_phy=1, # LE 1M
|
||||
advertising_sid=0,
|
||||
scan_request_notification_enable=0,
|
||||
), # type: ignore[call-arg]
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1728,19 +1810,19 @@ class Device(CompositeEventEmitter):
|
||||
operation=HCI_LE_Set_Extended_Advertising_Data_Command.Operation.COMPLETE_DATA,
|
||||
fragment_preference=0x01, # Should not fragment
|
||||
advertising_data=advertising_data,
|
||||
), # type: ignore[call-arg]
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
# Set the scan response if present
|
||||
if scan_response is not None:
|
||||
if scan_response_data is not None:
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Extended_Scan_Response_Data_Command(
|
||||
advertising_handle=adv_handle,
|
||||
operation=HCI_LE_Set_Extended_Advertising_Data_Command.Operation.COMPLETE_DATA,
|
||||
fragment_preference=0x01, # Should not fragment
|
||||
scan_response_data=scan_response,
|
||||
), # type: ignore[call-arg]
|
||||
scan_response_data=scan_response_data,
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1752,7 +1834,7 @@ class Device(CompositeEventEmitter):
|
||||
HCI_LE_Set_Advertising_Set_Random_Address_Command(
|
||||
advertising_handle=adv_handle,
|
||||
random_address=self.random_address,
|
||||
), # type: ignore[call-arg]
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1763,19 +1845,27 @@ class Device(CompositeEventEmitter):
|
||||
advertising_handles=[adv_handle],
|
||||
durations=[0], # Forever
|
||||
max_extended_advertising_events=[0], # Infinite
|
||||
), # type: ignore[call-arg]
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
except HCI_Error as error:
|
||||
# When any step fails, cleanup the advertising handle.
|
||||
await self.send_command(
|
||||
HCI_LE_Remove_Advertising_Set_Command(advertising_handle=adv_handle), # type: ignore[call-arg]
|
||||
HCI_LE_Remove_Advertising_Set_Command(advertising_handle=adv_handle),
|
||||
check_result=False,
|
||||
)
|
||||
raise error
|
||||
|
||||
self.extended_advertising_handles.add(adv_handle)
|
||||
return adv_handle
|
||||
advertiser = self.extended_advertisers[adv_handle] = ExtendedAdvertiser(
|
||||
device=self,
|
||||
handle=adv_handle,
|
||||
advertising_properties=advertising_properties,
|
||||
own_address_type=own_address_type,
|
||||
auto_restart=auto_restart,
|
||||
advertising_data=advertising_data,
|
||||
scan_response_data=scan_response_data,
|
||||
)
|
||||
return advertiser
|
||||
|
||||
@experimental('Extended Advertising is still experimental - Might be changed soon.')
|
||||
async def stop_extended_advertising(self, adv_handle: int) -> None:
|
||||
@@ -1791,19 +1881,19 @@ class Device(CompositeEventEmitter):
|
||||
advertising_handles=[adv_handle],
|
||||
durations=[0],
|
||||
max_extended_advertising_events=[0],
|
||||
), # type: ignore[call-arg]
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
# Remove advertising set
|
||||
await self.send_command(
|
||||
HCI_LE_Remove_Advertising_Set_Command(advertising_handle=adv_handle), # type: ignore[call-arg]
|
||||
HCI_LE_Remove_Advertising_Set_Command(advertising_handle=adv_handle),
|
||||
check_result=True,
|
||||
)
|
||||
self.extended_advertising_handles.remove(adv_handle)
|
||||
del self.extended_advertisers[adv_handle]
|
||||
|
||||
@property
|
||||
def is_advertising(self):
|
||||
return self.advertising
|
||||
return self.legacy_advertiser or self.extended_advertisers
|
||||
|
||||
async def start_scanning(
|
||||
self,
|
||||
@@ -1864,7 +1954,7 @@ class Device(CompositeEventEmitter):
|
||||
scan_types=[scan_type] * scanning_phy_count,
|
||||
scan_intervals=[int(scan_window / 0.625)] * scanning_phy_count,
|
||||
scan_windows=[int(scan_window / 0.625)] * scanning_phy_count,
|
||||
), # type: ignore[call-arg]
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1875,7 +1965,7 @@ class Device(CompositeEventEmitter):
|
||||
filter_duplicates=1 if filter_duplicates else 0,
|
||||
duration=0, # TODO allow other values
|
||||
period=0, # TODO allow other values
|
||||
), # type: ignore[call-arg]
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
else:
|
||||
@@ -1893,7 +1983,7 @@ class Device(CompositeEventEmitter):
|
||||
le_scan_window=int(scan_window / 0.625),
|
||||
own_address_type=own_address_type,
|
||||
scanning_filter_policy=HCI_LE_Set_Scan_Parameters_Command.BASIC_UNFILTERED_POLICY,
|
||||
), # type: ignore[call-arg]
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1901,7 +1991,7 @@ class Device(CompositeEventEmitter):
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Scan_Enable_Command(
|
||||
le_scan_enable=1, filter_duplicates=1 if filter_duplicates else 0
|
||||
), # type: ignore[call-arg]
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1914,12 +2004,12 @@ class Device(CompositeEventEmitter):
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Extended_Scan_Enable_Command(
|
||||
enable=0, filter_duplicates=0, duration=0, period=0
|
||||
), # type: ignore[call-arg]
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
else:
|
||||
await self.send_command(
|
||||
HCI_LE_Set_Scan_Enable_Command(le_scan_enable=0, filter_duplicates=0), # type: ignore[call-arg]
|
||||
HCI_LE_Set_Scan_Enable_Command(le_scan_enable=0, filter_duplicates=0),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1939,7 +2029,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
async def start_discovery(self, auto_restart: bool = True) -> None:
|
||||
await self.send_command(
|
||||
HCI_Write_Inquiry_Mode_Command(inquiry_mode=HCI_EXTENDED_INQUIRY_MODE), # type: ignore[call-arg]
|
||||
HCI_Write_Inquiry_Mode_Command(inquiry_mode=HCI_EXTENDED_INQUIRY_MODE),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -1948,7 +2038,7 @@ class Device(CompositeEventEmitter):
|
||||
lap=HCI_GENERAL_INQUIRY_LAP,
|
||||
inquiry_length=DEVICE_DEFAULT_INQUIRY_LENGTH,
|
||||
num_responses=0, # Unlimited number of responses.
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
)
|
||||
if response.status != HCI_Command_Status_Event.PENDING:
|
||||
self.discovering = False
|
||||
@@ -1959,7 +2049,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
async def stop_discovery(self) -> None:
|
||||
if self.discovering:
|
||||
await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True) # type: ignore[call-arg]
|
||||
await self.send_command(HCI_Inquiry_Cancel_Command(), check_result=True)
|
||||
self.auto_restart_inquiry = True
|
||||
self.discovering = False
|
||||
|
||||
@@ -2007,7 +2097,7 @@ class Device(CompositeEventEmitter):
|
||||
await self.send_command(
|
||||
HCI_Write_Extended_Inquiry_Response_Command(
|
||||
fec_required=0, extended_inquiry_response=self.inquiry_response
|
||||
), # type: ignore[call-arg]
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
await self.set_scan_enable(
|
||||
@@ -2196,7 +2286,7 @@ class Device(CompositeEventEmitter):
|
||||
supervision_timeouts=supervision_timeouts,
|
||||
min_ce_lengths=min_ce_lengths,
|
||||
max_ce_lengths=max_ce_lengths,
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
)
|
||||
else:
|
||||
if HCI_LE_1M_PHY not in connection_parameters_preferences:
|
||||
@@ -2225,7 +2315,7 @@ class Device(CompositeEventEmitter):
|
||||
supervision_timeout=int(prefs.supervision_timeout / 10),
|
||||
min_ce_length=int(prefs.min_ce_length / 0.625),
|
||||
max_ce_length=int(prefs.max_ce_length / 0.625),
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Save pending connection
|
||||
@@ -2242,7 +2332,7 @@ class Device(CompositeEventEmitter):
|
||||
clock_offset=0x0000,
|
||||
allow_role_switch=0x01,
|
||||
reserved=0,
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
)
|
||||
|
||||
if result.status != HCI_Command_Status_Event.PENDING:
|
||||
@@ -2261,10 +2351,10 @@ class Device(CompositeEventEmitter):
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
if transport == BT_LE_TRANSPORT:
|
||||
await self.send_command(HCI_LE_Create_Connection_Cancel_Command()) # type: ignore[call-arg]
|
||||
await self.send_command(HCI_LE_Create_Connection_Cancel_Command())
|
||||
else:
|
||||
await self.send_command(
|
||||
HCI_Create_Connection_Cancel_Command(bd_addr=peer_address) # type: ignore[call-arg]
|
||||
HCI_Create_Connection_Cancel_Command(bd_addr=peer_address)
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -2378,7 +2468,7 @@ class Device(CompositeEventEmitter):
|
||||
try:
|
||||
# Accept connection request
|
||||
await self.send_command(
|
||||
HCI_Accept_Connection_Request_Command(bd_addr=peer_address, role=role) # type: ignore[call-arg]
|
||||
HCI_Accept_Connection_Request_Command(bd_addr=peer_address, role=role)
|
||||
)
|
||||
|
||||
# Wait for connection complete
|
||||
@@ -2445,7 +2535,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
# Request a disconnection
|
||||
result = await self.send_command(
|
||||
HCI_Disconnect_Command(connection_handle=connection.handle, reason=reason) # type: ignore[call-arg]
|
||||
HCI_Disconnect_Command(connection_handle=connection.handle, reason=reason)
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -2476,7 +2566,7 @@ class Device(CompositeEventEmitter):
|
||||
connection_handle=connection.handle,
|
||||
tx_octets=tx_octets,
|
||||
tx_time=tx_time,
|
||||
), # type: ignore[call-arg]
|
||||
),
|
||||
check_result=True,
|
||||
)
|
||||
|
||||
@@ -2522,7 +2612,7 @@ class Device(CompositeEventEmitter):
|
||||
supervision_timeout=supervision_timeout,
|
||||
min_ce_length=min_ce_length,
|
||||
max_ce_length=max_ce_length,
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
)
|
||||
if result.status != HCI_Command_Status_Event.PENDING:
|
||||
raise HCI_StatusError(result)
|
||||
@@ -2850,7 +2940,7 @@ class Device(CompositeEventEmitter):
|
||||
|
||||
try:
|
||||
result = await self.send_command(
|
||||
HCI_Switch_Role_Command(bd_addr=connection.peer_address, role=role) # type: ignore[call-arg]
|
||||
HCI_Switch_Role_Command(bd_addr=connection.peer_address, role=role)
|
||||
)
|
||||
if result.status != HCI_COMMAND_STATUS_PENDING:
|
||||
logger.warning(
|
||||
@@ -2892,7 +2982,7 @@ class Device(CompositeEventEmitter):
|
||||
page_scan_repetition_mode=HCI_Remote_Name_Request_Command.R2,
|
||||
reserved=0,
|
||||
clock_offset=0, # TODO investigate non-0 values
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
)
|
||||
|
||||
if result.status != HCI_COMMAND_STATUS_PENDING:
|
||||
@@ -2938,7 +3028,7 @@ class Device(CompositeEventEmitter):
|
||||
num_cis = len(cis_id)
|
||||
|
||||
response = await self.send_command(
|
||||
HCI_LE_Set_CIG_Parameters_Command( # type: ignore[call-arg]
|
||||
HCI_LE_Set_CIG_Parameters_Command(
|
||||
cig_id=cig_id,
|
||||
sdu_interval_c_to_p=sdu_interval[0],
|
||||
sdu_interval_p_to_c=sdu_interval[1],
|
||||
@@ -2982,7 +3072,7 @@ class Device(CompositeEventEmitter):
|
||||
)
|
||||
|
||||
result = await self.send_command(
|
||||
HCI_LE_Create_CIS_Command( # type: ignore[call-arg]
|
||||
HCI_LE_Create_CIS_Command(
|
||||
cis_connection_handle=[p[0] for p in cis_acl_pairs],
|
||||
acl_connection_handle=[p[1] for p in cis_acl_pairs],
|
||||
),
|
||||
@@ -3015,9 +3105,7 @@ class Device(CompositeEventEmitter):
|
||||
@experimental('Only for testing.')
|
||||
async def accept_cis_request(self, handle: int) -> CisLink:
|
||||
result = await self.send_command(
|
||||
HCI_LE_Accept_CIS_Request_Command( # type: ignore[call-arg]
|
||||
connection_handle=handle
|
||||
),
|
||||
HCI_LE_Accept_CIS_Request_Command(connection_handle=handle),
|
||||
)
|
||||
if result.status != HCI_COMMAND_STATUS_PENDING:
|
||||
logger.warning(
|
||||
@@ -3045,9 +3133,7 @@ class Device(CompositeEventEmitter):
|
||||
reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR,
|
||||
) -> None:
|
||||
result = await self.send_command(
|
||||
HCI_LE_Reject_CIS_Request_Command( # type: ignore[call-arg]
|
||||
connection_handle=handle, reason=reason
|
||||
),
|
||||
HCI_LE_Reject_CIS_Request_Command(connection_handle=handle, reason=reason),
|
||||
)
|
||||
if result.status != HCI_COMMAND_STATUS_PENDING:
|
||||
logger.warning(
|
||||
@@ -3148,13 +3234,18 @@ class Device(CompositeEventEmitter):
|
||||
# Guess which own address type is used for this connection.
|
||||
# This logic is somewhat correct but may need to be improved
|
||||
# when multiple advertising are run simultaneously.
|
||||
advertiser = None
|
||||
if self.connect_own_address_type is not None:
|
||||
own_address_type = self.connect_own_address_type
|
||||
elif self.legacy_advertiser:
|
||||
own_address_type = self.legacy_advertiser.own_address_type
|
||||
# Store advertiser for restarting - it's only required for legacy, since
|
||||
# extended advertisement produces HCI_Advertising_Set_Terminated.
|
||||
if self.legacy_advertiser.auto_restart:
|
||||
advertiser = self.legacy_advertiser
|
||||
else:
|
||||
own_address_type = self.advertising_own_address_type
|
||||
|
||||
# We are no longer advertising
|
||||
self.advertising = False
|
||||
# For extended advertisement, determining own address type later.
|
||||
own_address_type = OwnAddressType.RANDOM
|
||||
|
||||
if own_address_type in (
|
||||
OwnAddressType.PUBLIC,
|
||||
@@ -3176,6 +3267,7 @@ class Device(CompositeEventEmitter):
|
||||
connection_parameters,
|
||||
ConnectionPHY(HCI_LE_1M_PHY, HCI_LE_1M_PHY),
|
||||
)
|
||||
connection.advertiser_after_disconnection = advertiser
|
||||
self.connections[connection_handle] = connection
|
||||
|
||||
# If supported, read which PHY we're connected with before
|
||||
@@ -3207,10 +3299,10 @@ class Device(CompositeEventEmitter):
|
||||
# For directed advertising, this means a timeout
|
||||
if (
|
||||
transport == BT_LE_TRANSPORT
|
||||
and self.advertising
|
||||
and self.advertising_type.is_directed
|
||||
and self.legacy_advertiser
|
||||
and self.legacy_advertiser.advertising_type.is_directed
|
||||
):
|
||||
self.advertising = False
|
||||
self.legacy_advertiser = None
|
||||
|
||||
# Notify listeners
|
||||
error = core.ConnectionError(
|
||||
@@ -3272,16 +3364,30 @@ class Device(CompositeEventEmitter):
|
||||
self.gatt_server.on_disconnection(connection)
|
||||
|
||||
# Restart advertising if auto-restart is enabled
|
||||
if self.auto_restart_advertising:
|
||||
if advertiser := connection.advertiser_after_disconnection:
|
||||
logger.debug('restarting advertising')
|
||||
self.abort_on(
|
||||
'flush',
|
||||
self.start_advertising(
|
||||
advertising_type=self.advertising_type, # type: ignore[arg-type]
|
||||
own_address_type=self.advertising_own_address_type, # type: ignore[arg-type]
|
||||
auto_restart=True,
|
||||
),
|
||||
)
|
||||
if isinstance(advertiser, LegacyAdvertiser):
|
||||
self.abort_on(
|
||||
'flush',
|
||||
self.start_legacy_advertising(
|
||||
advertising_type=advertiser.advertising_type,
|
||||
own_address_type=advertiser.own_address_type,
|
||||
advertising_data=advertiser.advertising_data,
|
||||
scan_response_data=advertiser.scan_response_data,
|
||||
auto_restart=True,
|
||||
),
|
||||
)
|
||||
elif isinstance(advertiser, ExtendedAdvertiser):
|
||||
self.abort_on(
|
||||
'flush',
|
||||
self.start_extended_advertising(
|
||||
advertising_properties=advertiser.advertising_properties,
|
||||
own_address_type=advertiser.own_address_type,
|
||||
advertising_data=advertiser.advertising_data,
|
||||
scan_response_data=advertiser.scan_response_data,
|
||||
auto_restart=True,
|
||||
),
|
||||
)
|
||||
elif sco_link := self.sco_links.pop(connection_handle, None):
|
||||
sco_link.emit('disconnection', reason)
|
||||
elif cis_link := self.cis_links.pop(connection_handle, None):
|
||||
@@ -3439,7 +3545,7 @@ class Device(CompositeEventEmitter):
|
||||
try:
|
||||
if await connection.abort_on('disconnection', method()):
|
||||
await self.host.send_command(
|
||||
HCI_User_Confirmation_Request_Reply_Command( # type: ignore[call-arg]
|
||||
HCI_User_Confirmation_Request_Reply_Command(
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
)
|
||||
@@ -3448,7 +3554,7 @@ class Device(CompositeEventEmitter):
|
||||
logger.warning(f'exception while confirming: {error}')
|
||||
|
||||
await self.host.send_command(
|
||||
HCI_User_Confirmation_Request_Negative_Reply_Command( # type: ignore[call-arg]
|
||||
HCI_User_Confirmation_Request_Negative_Reply_Command(
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
)
|
||||
@@ -3469,7 +3575,7 @@ class Device(CompositeEventEmitter):
|
||||
)
|
||||
if number is not None:
|
||||
await self.host.send_command(
|
||||
HCI_User_Passkey_Request_Reply_Command( # type: ignore[call-arg]
|
||||
HCI_User_Passkey_Request_Reply_Command(
|
||||
bd_addr=connection.peer_address, numeric_value=number
|
||||
)
|
||||
)
|
||||
@@ -3478,7 +3584,7 @@ class Device(CompositeEventEmitter):
|
||||
logger.warning(f'exception while asking for pass-key: {error}')
|
||||
|
||||
await self.host.send_command(
|
||||
HCI_User_Passkey_Request_Negative_Reply_Command( # type: ignore[call-arg]
|
||||
HCI_User_Passkey_Request_Negative_Reply_Command(
|
||||
bd_addr=connection.peer_address
|
||||
)
|
||||
)
|
||||
@@ -3604,6 +3710,30 @@ class Device(CompositeEventEmitter):
|
||||
if sco_link := self.sco_links.get(sco_handle, None):
|
||||
sco_link.emit('pdu', packet)
|
||||
|
||||
# [LE only]
|
||||
@host_event_handler
|
||||
@experimental('Only for testing')
|
||||
def on_advertising_set_termination(
|
||||
self,
|
||||
status: int,
|
||||
advertising_handle: int,
|
||||
connection_handle: int,
|
||||
) -> None:
|
||||
if status == HCI_SUCCESS:
|
||||
connection = self.lookup_connection(connection_handle)
|
||||
if advertiser := self.extended_advertisers.pop(advertising_handle, None):
|
||||
if connection:
|
||||
if advertiser.auto_restart:
|
||||
connection.advertiser_after_disconnection = advertiser
|
||||
if advertiser.own_address_type in (
|
||||
OwnAddressType.PUBLIC,
|
||||
OwnAddressType.RESOLVABLE_OR_PUBLIC,
|
||||
):
|
||||
connection.self_address = self.public_address
|
||||
else:
|
||||
connection.self_address = self.random_address
|
||||
advertiser.emit('termination', status)
|
||||
|
||||
# [LE only]
|
||||
@host_event_handler
|
||||
@with_connection_from_handle
|
||||
|
||||
@@ -19,12 +19,17 @@ like loading firmware after a cold start.
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import abc
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import pathlib
|
||||
import platform
|
||||
from . import rtk
|
||||
from typing import Dict, Iterable, Optional, Type, TYPE_CHECKING
|
||||
|
||||
from . import rtk
|
||||
from .common import Driver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.host import Host
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -32,40 +37,31 @@ from . import rtk
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class Driver(abc.ABC):
|
||||
"""Base class for drivers."""
|
||||
|
||||
@staticmethod
|
||||
async def for_host(_host):
|
||||
"""Return a driver instance for a host.
|
||||
|
||||
Args:
|
||||
host: Host object for which a driver should be created.
|
||||
|
||||
Returns:
|
||||
A Driver instance if a driver should be instantiated for this host, or
|
||||
None if no driver instance of this class is needed.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def init_controller(self):
|
||||
"""Initialize the controller."""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Functions
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_driver_for_host(host):
|
||||
"""Probe all known diver classes until one returns a valid instance for a host,
|
||||
or none is found.
|
||||
async def get_driver_for_host(host: Host) -> Optional[Driver]:
|
||||
"""Probe diver classes until one returns a valid instance for a host, or none is
|
||||
found.
|
||||
If a "driver" HCI metadata entry is present, only that driver class will be probed.
|
||||
"""
|
||||
if driver := await rtk.Driver.for_host(host):
|
||||
logger.debug("Instantiated RTK driver")
|
||||
return driver
|
||||
driver_classes: Dict[str, Type[Driver]] = {"rtk": rtk.Driver}
|
||||
probe_list: Iterable[str]
|
||||
if driver_name := host.hci_metadata.get("driver"):
|
||||
# Only probe a single driver
|
||||
probe_list = [driver_name]
|
||||
else:
|
||||
# Probe all drivers
|
||||
probe_list = driver_classes.keys()
|
||||
|
||||
for driver_name in probe_list:
|
||||
if driver_class := driver_classes.get(driver_name):
|
||||
logger.debug(f"Probing driver class: {driver_name}")
|
||||
if driver := await driver_class.for_host(host):
|
||||
logger.debug(f"Instantiated {driver_name} driver")
|
||||
return driver
|
||||
else:
|
||||
logger.debug(f"Skipping unknown driver class: {driver_name}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
45
bumble/drivers/common.py
Normal file
45
bumble/drivers/common.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Common types for drivers.
|
||||
"""
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import abc
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
class Driver(abc.ABC):
|
||||
"""Base class for drivers."""
|
||||
|
||||
@staticmethod
|
||||
async def for_host(_host):
|
||||
"""Return a driver instance for a host.
|
||||
|
||||
Args:
|
||||
host: Host object for which a driver should be created.
|
||||
|
||||
Returns:
|
||||
A Driver instance if a driver should be instantiated for this host, or
|
||||
None if no driver instance of this class is needed.
|
||||
"""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def init_controller(self):
|
||||
"""Initialize the controller."""
|
||||
@@ -41,7 +41,7 @@ from bumble.hci import (
|
||||
HCI_Reset_Command,
|
||||
HCI_Read_Local_Version_Information_Command,
|
||||
)
|
||||
|
||||
from bumble.drivers import common
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -285,7 +285,7 @@ class Firmware:
|
||||
)
|
||||
|
||||
|
||||
class Driver:
|
||||
class Driver(common.Driver):
|
||||
@dataclass
|
||||
class DriverInfo:
|
||||
rom: int
|
||||
@@ -470,8 +470,12 @@ class Driver:
|
||||
logger.debug("USB metadata not found")
|
||||
return False
|
||||
|
||||
vendor_id = host.hci_metadata.get("vendor_id", None)
|
||||
product_id = host.hci_metadata.get("product_id", None)
|
||||
if host.hci_metadata.get('driver') == 'rtk':
|
||||
# Forced driver
|
||||
return True
|
||||
|
||||
vendor_id = host.hci_metadata.get("vendor_id")
|
||||
product_id = host.hci_metadata.get("product_id")
|
||||
if vendor_id is None or product_id is None:
|
||||
logger.debug("USB metadata not sufficient")
|
||||
return False
|
||||
@@ -486,6 +490,9 @@ class Driver:
|
||||
|
||||
@classmethod
|
||||
async def driver_info_for_host(cls, host):
|
||||
await host.send_command(HCI_Reset_Command(), check_result=True)
|
||||
host.ready = True # Needed to let the host know the controller is ready.
|
||||
|
||||
response = await host.send_command(
|
||||
HCI_Read_Local_Version_Information_Command(), check_result=True
|
||||
)
|
||||
|
||||
117
bumble/gatt.py
117
bumble/gatt.py
@@ -23,16 +23,28 @@
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import enum
|
||||
import functools
|
||||
import logging
|
||||
import struct
|
||||
from typing import Optional, Sequence, Iterable, List, Union
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from .colors import color
|
||||
from .core import UUID, get_dict_key_by_value
|
||||
from .att import Attribute
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID
|
||||
from bumble.att import Attribute, AttributeValue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.gatt_client import AttributeProxy
|
||||
from bumble.device import Connection
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -368,9 +380,12 @@ class TemplateService(Service):
|
||||
UUID: UUID
|
||||
|
||||
def __init__(
|
||||
self, characteristics: List[Characteristic], primary: bool = True
|
||||
self,
|
||||
characteristics: List[Characteristic],
|
||||
primary: bool = True,
|
||||
included_services: List[Service] = [],
|
||||
) -> None:
|
||||
super().__init__(self.UUID, characteristics, primary)
|
||||
super().__init__(self.UUID, characteristics, primary, included_services)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -519,56 +534,43 @@ class CharacteristicDeclaration(Attribute):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CharacteristicValue:
|
||||
'''
|
||||
Characteristic value where reading and/or writing is delegated to functions
|
||||
passed as arguments to the constructor.
|
||||
'''
|
||||
|
||||
def __init__(self, read=None, write=None):
|
||||
self._read = read
|
||||
self._write = write
|
||||
|
||||
def read(self, connection):
|
||||
return self._read(connection) if self._read else b''
|
||||
|
||||
def write(self, connection, value):
|
||||
if self._write:
|
||||
self._write(connection, value)
|
||||
class CharacteristicValue(AttributeValue):
|
||||
"""Same as AttributeValue, for backward compatibility"""
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class CharacteristicAdapter:
|
||||
'''
|
||||
An adapter that can adapt any object with `read_value` and `write_value`
|
||||
methods (like Characteristic and CharacteristicProxy objects) by wrapping
|
||||
those methods with ones that return/accept encoded/decoded values.
|
||||
Objects with async methods are considered proxies, so the adaptation is one
|
||||
where the return value of `read_value` is decoded and the value passed to
|
||||
`write_value` is encoded. Other objects are considered local characteristics
|
||||
so the adaptation is one where the return value of `read_value` is encoded
|
||||
and the value passed to `write_value` is decoded.
|
||||
If the characteristic has a `subscribe` method, it is wrapped with one where
|
||||
the values are decoded before being passed to the subscriber.
|
||||
An adapter that can adapt Characteristic and AttributeProxy objects
|
||||
by wrapping their `read_value()` and `write_value()` methods with ones that
|
||||
return/accept encoded/decoded values.
|
||||
|
||||
For proxies (i.e used by a GATT client), the adaptation is one where the return
|
||||
value of `read_value()` is decoded and the value passed to `write_value()` is
|
||||
encoded. The `subscribe()` method, is wrapped with one where the values are decoded
|
||||
before being passed to the subscriber.
|
||||
|
||||
For local values (i.e hosted by a GATT server) the adaptation is one where the
|
||||
return value of `read_value()` is encoded and the value passed to `write_value()`
|
||||
is decoded.
|
||||
'''
|
||||
|
||||
def __init__(self, characteristic):
|
||||
self.wrapped_characteristic = characteristic
|
||||
self.subscribers = {} # Map from subscriber to proxy subscriber
|
||||
read_value: Callable
|
||||
write_value: Callable
|
||||
|
||||
if asyncio.iscoroutinefunction(
|
||||
characteristic.read_value
|
||||
) and asyncio.iscoroutinefunction(characteristic.write_value):
|
||||
self.read_value = self.read_decoded_value
|
||||
self.write_value = self.write_decoded_value
|
||||
else:
|
||||
def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
|
||||
self.wrapped_characteristic = characteristic
|
||||
self.subscribers: Dict[
|
||||
Callable, Callable
|
||||
] = {} # Map from subscriber to proxy subscriber
|
||||
|
||||
if isinstance(characteristic, Characteristic):
|
||||
self.read_value = self.read_encoded_value
|
||||
self.write_value = self.write_encoded_value
|
||||
|
||||
if hasattr(self.wrapped_characteristic, 'subscribe'):
|
||||
else:
|
||||
self.read_value = self.read_decoded_value
|
||||
self.write_value = self.write_decoded_value
|
||||
self.subscribe = self.wrapped_subscribe
|
||||
|
||||
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
|
||||
self.unsubscribe = self.wrapped_unsubscribe
|
||||
|
||||
def __getattr__(self, name):
|
||||
@@ -587,11 +589,13 @@ class CharacteristicAdapter:
|
||||
else:
|
||||
setattr(self.wrapped_characteristic, name, value)
|
||||
|
||||
def read_encoded_value(self, connection):
|
||||
return self.encode_value(self.wrapped_characteristic.read_value(connection))
|
||||
async def read_encoded_value(self, connection):
|
||||
return self.encode_value(
|
||||
await self.wrapped_characteristic.read_value(connection)
|
||||
)
|
||||
|
||||
def write_encoded_value(self, connection, value):
|
||||
return self.wrapped_characteristic.write_value(
|
||||
async def write_encoded_value(self, connection, value):
|
||||
return await self.wrapped_characteristic.write_value(
|
||||
connection, self.decode_value(value)
|
||||
)
|
||||
|
||||
@@ -726,13 +730,24 @@ class Descriptor(Attribute):
|
||||
'''
|
||||
|
||||
def __str__(self) -> str:
|
||||
if isinstance(self.value, bytes):
|
||||
value_str = self.value.hex()
|
||||
elif isinstance(self.value, CharacteristicValue):
|
||||
value = self.value.read(None)
|
||||
if isinstance(value, bytes):
|
||||
value_str = value.hex()
|
||||
else:
|
||||
value_str = '<async>'
|
||||
else:
|
||||
value_str = '<...>'
|
||||
return (
|
||||
f'Descriptor(handle=0x{self.handle:04X}, '
|
||||
f'type={self.type}, '
|
||||
f'value={self.read_value(None).hex()})'
|
||||
f'value={value_str})'
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ClientCharacteristicConfigurationBits(enum.IntFlag):
|
||||
'''
|
||||
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit
|
||||
|
||||
@@ -31,9 +31,9 @@ import struct
|
||||
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
|
||||
from pyee import EventEmitter
|
||||
|
||||
from .colors import color
|
||||
from .core import UUID
|
||||
from .att import (
|
||||
from bumble.colors import color
|
||||
from bumble.core import UUID
|
||||
from bumble.att import (
|
||||
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
|
||||
ATT_ATTRIBUTE_NOT_LONG_ERROR,
|
||||
ATT_CID,
|
||||
@@ -60,7 +60,7 @@ from .att import (
|
||||
ATT_Write_Response,
|
||||
Attribute,
|
||||
)
|
||||
from .gatt import (
|
||||
from bumble.gatt import (
|
||||
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
|
||||
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
|
||||
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
|
||||
@@ -74,6 +74,7 @@ from .gatt import (
|
||||
Descriptor,
|
||||
Service,
|
||||
)
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Device, Connection
|
||||
@@ -379,7 +380,7 @@ class Server(EventEmitter):
|
||||
|
||||
# Get or encode the value
|
||||
value = (
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
if value is None
|
||||
else attribute.encode_value(value)
|
||||
)
|
||||
@@ -422,7 +423,7 @@ class Server(EventEmitter):
|
||||
|
||||
# Get or encode the value
|
||||
value = (
|
||||
attribute.read_value(connection)
|
||||
await attribute.read_value(connection)
|
||||
if value is None
|
||||
else attribute.encode_value(value)
|
||||
)
|
||||
@@ -650,7 +651,8 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_find_by_type_value_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_find_by_type_value_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
|
||||
'''
|
||||
@@ -658,13 +660,13 @@ class Server(EventEmitter):
|
||||
# Build list of returned attributes
|
||||
pdu_space_available = connection.att_mtu - 2
|
||||
attributes = []
|
||||
for attribute in (
|
||||
async for attribute in (
|
||||
attribute
|
||||
for attribute in self.attributes
|
||||
if attribute.handle >= request.starting_handle
|
||||
and attribute.handle <= request.ending_handle
|
||||
and attribute.type == request.attribute_type
|
||||
and attribute.read_value(connection) == request.attribute_value
|
||||
and (await attribute.read_value(connection)) == request.attribute_value
|
||||
and pdu_space_available >= 4
|
||||
):
|
||||
# TODO: check permissions
|
||||
@@ -702,7 +704,8 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_by_type_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_by_type_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
|
||||
'''
|
||||
@@ -725,7 +728,7 @@ class Server(EventEmitter):
|
||||
and pdu_space_available
|
||||
):
|
||||
try:
|
||||
attribute_value = attribute.read_value(connection)
|
||||
attribute_value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
# If the first attribute is unreadable, return an error
|
||||
# Otherwise return attributes up to this point
|
||||
@@ -767,14 +770,15 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
|
||||
'''
|
||||
|
||||
if attribute := self.get_attribute(request.attribute_handle):
|
||||
try:
|
||||
value = attribute.read_value(connection)
|
||||
value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -792,14 +796,15 @@ class Server(EventEmitter):
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_blob_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_blob_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
|
||||
'''
|
||||
|
||||
if attribute := self.get_attribute(request.attribute_handle):
|
||||
try:
|
||||
value = attribute.read_value(connection)
|
||||
value = await attribute.read_value(connection)
|
||||
except ATT_Error as error:
|
||||
response = ATT_Error_Response(
|
||||
request_opcode_in_error=request.op_code,
|
||||
@@ -836,7 +841,8 @@ class Server(EventEmitter):
|
||||
)
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_read_by_group_type_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_read_by_group_type_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
|
||||
'''
|
||||
@@ -864,7 +870,7 @@ class Server(EventEmitter):
|
||||
):
|
||||
# No need to catch permission errors here, since these attributes
|
||||
# must all be world-readable
|
||||
attribute_value = attribute.read_value(connection)
|
||||
attribute_value = await attribute.read_value(connection)
|
||||
# Check the attribute value size
|
||||
max_attribute_size = min(connection.att_mtu - 6, 251)
|
||||
if len(attribute_value) > max_attribute_size:
|
||||
@@ -903,7 +909,8 @@ class Server(EventEmitter):
|
||||
|
||||
self.send_response(connection, response)
|
||||
|
||||
def on_att_write_request(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_write_request(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
|
||||
'''
|
||||
@@ -936,12 +943,13 @@ class Server(EventEmitter):
|
||||
return
|
||||
|
||||
# Accept the value
|
||||
attribute.write_value(connection, request.attribute_value)
|
||||
await attribute.write_value(connection, request.attribute_value)
|
||||
|
||||
# Done
|
||||
self.send_response(connection, ATT_Write_Response())
|
||||
|
||||
def on_att_write_command(self, connection, request):
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_att_write_command(self, connection, request):
|
||||
'''
|
||||
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
|
||||
'''
|
||||
@@ -959,9 +967,9 @@ class Server(EventEmitter):
|
||||
|
||||
# Accept the value
|
||||
try:
|
||||
attribute.write_value(connection, request.attribute_value)
|
||||
await attribute.write_value(connection, request.attribute_value)
|
||||
except Exception as error:
|
||||
logger.warning(f'!!! ignoring exception: {error}')
|
||||
logger.exception(f'!!! ignoring exception: {error}')
|
||||
|
||||
def on_att_handle_value_confirmation(self, connection, _confirmation):
|
||||
'''
|
||||
|
||||
219
bumble/hci.py
219
bumble/hci.py
@@ -21,9 +21,11 @@ import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import logging
|
||||
import secrets
|
||||
import struct
|
||||
from typing import Any, Dict, Callable, Optional, Type, Union, List
|
||||
|
||||
from bumble import crypto
|
||||
from .colors import color
|
||||
from .core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
@@ -561,6 +563,12 @@ HCI_LE_TRANSMITTER_TEST_V4_COMMAND = hci_c
|
||||
HCI_LE_SET_DATA_RELATED_ADDRESS_CHANGES_COMMAND = hci_command_op_code(0x08, 0x007C)
|
||||
HCI_LE_SET_DEFAULT_SUBRATE_COMMAND = hci_command_op_code(0x08, 0x007D)
|
||||
HCI_LE_SUBRATE_REQUEST_COMMAND = hci_command_op_code(0x08, 0x007E)
|
||||
HCI_LE_SET_EXTENDED_ADVERTISING_PARAMETERS_V2_COMMAND = hci_command_op_code(0x08, 0x007F)
|
||||
HCI_LE_SET_PERIODIC_ADVERTISING_SUBEVENT_DATA_COMMAND = hci_command_op_code(0x08, 0x0082)
|
||||
HCI_LE_SET_PERIODIC_ADVERTISING_RESPONSE_DATA_COMMAND = hci_command_op_code(0x08, 0x0083)
|
||||
HCI_LE_SET_PERIODIC_SYNC_SUBEVENT_COMMAND = hci_command_op_code(0x08, 0x0084)
|
||||
HCI_LE_EXTENDED_CREATE_CONNECTION_V2_COMMAND = hci_command_op_code(0x08, 0x0085)
|
||||
HCI_LE_SET_PERIODIC_ADVERTISING_PARAMETERS_V2_COMMAND = hci_command_op_code(0x08, 0x0086)
|
||||
|
||||
|
||||
# HCI Error Codes
|
||||
@@ -722,6 +730,19 @@ HCI_LE_PHY_TYPE_TO_BIT = {
|
||||
HCI_LE_CODED_PHY: HCI_LE_CODED_PHY_BIT
|
||||
}
|
||||
|
||||
|
||||
class Phy(enum.IntEnum):
|
||||
LE_1M = 0x01
|
||||
LE_2M = 0x02
|
||||
LE_CODED = 0x03
|
||||
|
||||
|
||||
class PhyBit(enum.IntFlag):
|
||||
LE_1M = 0b00000001
|
||||
LE_2M = 0b00000010
|
||||
LE_CODED = 0b00000100
|
||||
|
||||
|
||||
# Connection Parameters
|
||||
HCI_CONNECTION_INTERVAL_MS_PER_UNIT = 1.25
|
||||
HCI_CONNECTION_LATENCY_MS_PER_UNIT = 1.25
|
||||
@@ -1317,56 +1338,72 @@ HCI_SUPPORTED_COMMANDS_FLAGS = (
|
||||
(
|
||||
HCI_LE_SET_DEFAULT_SUBRATE_COMMAND,
|
||||
HCI_LE_SUBRATE_REQUEST_COMMAND,
|
||||
HCI_LE_SET_EXTENDED_ADVERTISING_PARAMETERS_V2_COMMAND,
|
||||
None,
|
||||
None,
|
||||
HCI_LE_SET_PERIODIC_ADVERTISING_SUBEVENT_DATA_COMMAND,
|
||||
HCI_LE_SET_PERIODIC_ADVERTISING_RESPONSE_DATA_COMMAND,
|
||||
HCI_LE_SET_PERIODIC_SYNC_SUBEVENT_COMMAND
|
||||
),
|
||||
# Octet 47
|
||||
(
|
||||
HCI_LE_EXTENDED_CREATE_CONNECTION_V2_COMMAND,
|
||||
HCI_LE_SET_PERIODIC_ADVERTISING_PARAMETERS_V2_COMMAND,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None
|
||||
)
|
||||
)
|
||||
|
||||
# LE Supported Features
|
||||
HCI_LE_ENCRYPTION_LE_SUPPORTED_FEATURE = 0
|
||||
HCI_CONNECTION_PARAMETERS_REQUEST_PROCEDURE_LE_SUPPORTED_FEATURE = 1
|
||||
HCI_EXTENDED_REJECT_INDICATION_LE_SUPPORTED_FEATURE = 2
|
||||
HCI_PERIPHERAL_INITIATED_FEATURE_EXCHANGE_LE_SUPPORTED_FEATURE = 3
|
||||
HCI_LE_PING_LE_SUPPORTED_FEATURE = 4
|
||||
HCI_LE_DATA_PACKET_LENGTH_EXTENSION_LE_SUPPORTED_FEATURE = 5
|
||||
HCI_LL_PRIVACY_LE_SUPPORTED_FEATURE = 6
|
||||
HCI_EXTENDED_SCANNER_FILTER_POLICIES_LE_SUPPORTED_FEATURE = 7
|
||||
HCI_LE_2M_PHY_LE_SUPPORTED_FEATURE = 8
|
||||
HCI_STABLE_MODULATION_INDEX_TRANSMITTER_LE_SUPPORTED_FEATURE = 9
|
||||
HCI_STABLE_MODULATION_INDEX_RECEIVER_LE_SUPPORTED_FEATURE = 10
|
||||
HCI_LE_CODED_PHY_LE_SUPPORTED_FEATURE = 11
|
||||
HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE = 12
|
||||
HCI_LE_PERIODIC_ADVERTISING_LE_SUPPORTED_FEATURE = 13
|
||||
HCI_CHANNEL_SELECTION_ALGORITHM_2_LE_SUPPORTED_FEATURE = 14
|
||||
HCI_LE_POWER_CLASS_1_LE_SUPPORTED_FEATURE = 15
|
||||
HCI_MINIMUM_NUMBER_OF_USED_CHANNELS_PROCEDURE_LE_SUPPORTED_FEATURE = 16
|
||||
HCI_CONNECTION_CTE_REQUEST_LE_SUPPORTED_FEATURE = 17
|
||||
HCI_CONNECTION_CTE_RESPONSE_LE_SUPPORTED_FEATURE = 18
|
||||
HCI_CONNECTIONLESS_CTE_TRANSMITTER_LE_SUPPORTED_FEATURE = 19
|
||||
HCI_CONNECTIONLESS_CTR_RECEIVER_LE_SUPPORTED_FEATURE = 20
|
||||
HCI_ANTENNA_SWITCHING_DURING_CTE_TRANSMISSION_LE_SUPPORTED_FEATURE = 21
|
||||
HCI_ANTENNA_SWITCHING_DURING_CTE_RECEPTION_LE_SUPPORTED_FEATURE = 22
|
||||
HCI_RECEIVING_CONSTANT_TONE_EXTENSIONS_LE_SUPPORTED_FEATURE = 23
|
||||
HCI_PERIODIC_ADVERTISING_SYNC_TRANSFER_SENDER_LE_SUPPORTED_FEATURE = 24
|
||||
HCI_PERIODIC_ADVERTISING_SYNC_TRANSFER_RECIPIENT_LE_SUPPORTED_FEATURE = 25
|
||||
HCI_SLEEP_CLOCK_ACCURACY_UPDATES_LE_SUPPORTED_FEATURE = 26
|
||||
HCI_REMOTE_PUBLIC_KEY_VALIDATION_LE_SUPPORTED_FEATURE = 27
|
||||
HCI_CONNECTED_ISOCHRONOUS_STREAM_CENTRAL_LE_SUPPORTED_FEATURE = 28
|
||||
HCI_CONNECTED_ISOCHRONOUS_STREAM_PERIPHERAL_LE_SUPPORTED_FEATURE = 29
|
||||
HCI_ISOCHRONOUS_BROADCASTER_LE_SUPPORTED_FEATURE = 30
|
||||
HCI_SYNCHRONIZED_RECEIVER_LE_SUPPORTED_FEATURE = 31
|
||||
HCI_CONNECTED_ISOCHRONOUS_STREAM_LE_SUPPORTED_FEATURE = 32
|
||||
HCI_LE_POWER_CONTROL_REQUEST_LE_SUPPORTED_FEATURE = 33
|
||||
HCI_LE_POWER_CONTROL_REQUEST_DUP_LE_SUPPORTED_FEATURE = 34
|
||||
HCI_LE_PATH_LOSS_MONITORING_LE_SUPPORTED_FEATURE = 35
|
||||
HCI_PERIODIC_ADVERTISING_ADI_SUPPORT_LE_SUPPORTED_FEATURE = 36
|
||||
HCI_CONNECTION_SUBRATING_LE_SUPPORTED_FEATURE = 37
|
||||
HCI_CONNECTION_SUBRATING_HOST_SUPPORT_LE_SUPPORTED_FEATURE = 38
|
||||
HCI_CHANNEL_CLASSIFICATION_LE_SUPPORTED_FEATURE = 39
|
||||
# See Bluetooth spec @ Vol 6, Part B, 4.6 FEATURE SUPPORT
|
||||
HCI_LE_ENCRYPTION_LE_SUPPORTED_FEATURE = 0
|
||||
HCI_CONNECTION_PARAMETERS_REQUEST_PROCEDURE_LE_SUPPORTED_FEATURE = 1
|
||||
HCI_EXTENDED_REJECT_INDICATION_LE_SUPPORTED_FEATURE = 2
|
||||
HCI_PERIPHERAL_INITIATED_FEATURE_EXCHANGE_LE_SUPPORTED_FEATURE = 3
|
||||
HCI_LE_PING_LE_SUPPORTED_FEATURE = 4
|
||||
HCI_LE_DATA_PACKET_LENGTH_EXTENSION_LE_SUPPORTED_FEATURE = 5
|
||||
HCI_LL_PRIVACY_LE_SUPPORTED_FEATURE = 6
|
||||
HCI_EXTENDED_SCANNER_FILTER_POLICIES_LE_SUPPORTED_FEATURE = 7
|
||||
HCI_LE_2M_PHY_LE_SUPPORTED_FEATURE = 8
|
||||
HCI_STABLE_MODULATION_INDEX_TRANSMITTER_LE_SUPPORTED_FEATURE = 9
|
||||
HCI_STABLE_MODULATION_INDEX_RECEIVER_LE_SUPPORTED_FEATURE = 10
|
||||
HCI_LE_CODED_PHY_LE_SUPPORTED_FEATURE = 11
|
||||
HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE = 12
|
||||
HCI_LE_PERIODIC_ADVERTISING_LE_SUPPORTED_FEATURE = 13
|
||||
HCI_CHANNEL_SELECTION_ALGORITHM_2_LE_SUPPORTED_FEATURE = 14
|
||||
HCI_LE_POWER_CLASS_1_LE_SUPPORTED_FEATURE = 15
|
||||
HCI_MINIMUM_NUMBER_OF_USED_CHANNELS_PROCEDURE_LE_SUPPORTED_FEATURE = 16
|
||||
HCI_CONNECTION_CTE_REQUEST_LE_SUPPORTED_FEATURE = 17
|
||||
HCI_CONNECTION_CTE_RESPONSE_LE_SUPPORTED_FEATURE = 18
|
||||
HCI_CONNECTIONLESS_CTE_TRANSMITTER_LE_SUPPORTED_FEATURE = 19
|
||||
HCI_CONNECTIONLESS_CTR_RECEIVER_LE_SUPPORTED_FEATURE = 20
|
||||
HCI_ANTENNA_SWITCHING_DURING_CTE_TRANSMISSION_LE_SUPPORTED_FEATURE = 21
|
||||
HCI_ANTENNA_SWITCHING_DURING_CTE_RECEPTION_LE_SUPPORTED_FEATURE = 22
|
||||
HCI_RECEIVING_CONSTANT_TONE_EXTENSIONS_LE_SUPPORTED_FEATURE = 23
|
||||
HCI_PERIODIC_ADVERTISING_SYNC_TRANSFER_SENDER_LE_SUPPORTED_FEATURE = 24
|
||||
HCI_PERIODIC_ADVERTISING_SYNC_TRANSFER_RECIPIENT_LE_SUPPORTED_FEATURE = 25
|
||||
HCI_SLEEP_CLOCK_ACCURACY_UPDATES_LE_SUPPORTED_FEATURE = 26
|
||||
HCI_REMOTE_PUBLIC_KEY_VALIDATION_LE_SUPPORTED_FEATURE = 27
|
||||
HCI_CONNECTED_ISOCHRONOUS_STREAM_CENTRAL_LE_SUPPORTED_FEATURE = 28
|
||||
HCI_CONNECTED_ISOCHRONOUS_STREAM_PERIPHERAL_LE_SUPPORTED_FEATURE = 29
|
||||
HCI_ISOCHRONOUS_BROADCASTER_LE_SUPPORTED_FEATURE = 30
|
||||
HCI_SYNCHRONIZED_RECEIVER_LE_SUPPORTED_FEATURE = 31
|
||||
HCI_CONNECTED_ISOCHRONOUS_STREAM_LE_SUPPORTED_FEATURE = 32
|
||||
HCI_LE_POWER_CONTROL_REQUEST_LE_SUPPORTED_FEATURE = 33
|
||||
HCI_LE_POWER_CONTROL_REQUEST_DUP_LE_SUPPORTED_FEATURE = 34
|
||||
HCI_LE_PATH_LOSS_MONITORING_LE_SUPPORTED_FEATURE = 35
|
||||
HCI_PERIODIC_ADVERTISING_ADI_SUPPORT_LE_SUPPORTED_FEATURE = 36
|
||||
HCI_CONNECTION_SUBRATING_LE_SUPPORTED_FEATURE = 37
|
||||
HCI_CONNECTION_SUBRATING_HOST_SUPPORT_LE_SUPPORTED_FEATURE = 38
|
||||
HCI_CHANNEL_CLASSIFICATION_LE_SUPPORTED_FEATURE = 39
|
||||
HCI_ADVERTISING_CODING_SELECTION_LE_SUPPORTED_FEATURE = 40
|
||||
HCI_ADVERTISING_CODING_SELECTION_HOST_SUPPORT_LE_SUPPORTED_FEATURE = 41
|
||||
HCI_PERIODIC_ADVERTISING_WITH_RESPONSES_ADVERTISER_LE_SUPPORTED_FEATURE = 43
|
||||
HCI_PERIODIC_ADVERTISING_WITH_RESPONSES_SCANNER_LE_SUPPORTED_FEATURE = 44
|
||||
|
||||
HCI_LE_SUPPORTED_FEATURES_NAMES = {
|
||||
flag: feature_name for (feature_name, flag) in globals().items()
|
||||
@@ -1629,7 +1666,7 @@ class HCI_Object:
|
||||
field_bytes = bytes(field_value)
|
||||
elif field_type == 'v':
|
||||
# Variable-length bytes field, with 1-byte length at the beginning
|
||||
field_bytes = bytes(field_bytes)
|
||||
field_bytes = bytes(field_value)
|
||||
field_length = len(field_bytes)
|
||||
field_bytes = bytes([field_length]) + field_bytes
|
||||
elif isinstance(field_value, (bytes, bytearray)) or hasattr(
|
||||
@@ -1846,6 +1883,43 @@ class Address:
|
||||
address_type = data[offset - 1]
|
||||
return Address.parse_address_with_type(data, offset, address_type)
|
||||
|
||||
@classmethod
|
||||
def generate_static_address(cls) -> Address:
|
||||
'''Generates Random Static Address, with the 2 most significant bits of 0b11.
|
||||
|
||||
See Bluetooth spec, Vol 6, Part B - Table 1.2.
|
||||
'''
|
||||
address_bytes = secrets.token_bytes(6)
|
||||
address_bytes = address_bytes[:5] + bytes([address_bytes[5] | 0b11000000])
|
||||
return Address(
|
||||
address=address_bytes, address_type=Address.RANDOM_DEVICE_ADDRESS
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_private_address(cls, irk: bytes = b'') -> Address:
|
||||
'''Generates Random Private MAC Address.
|
||||
|
||||
If IRK is present, a Resolvable Private Address, with the 2 most significant
|
||||
bits of 0b01 will be generated. Otherwise, a Non-resolvable Private Address,
|
||||
with the 2 most significant bits of 0b00 will be generated.
|
||||
|
||||
See Bluetooth spec, Vol 6, Part B - Table 1.2.
|
||||
|
||||
Args:
|
||||
irk: Local Identity Resolving Key(IRK), in little-endian. If not set, a
|
||||
non-resolvable address will be generated.
|
||||
'''
|
||||
if irk:
|
||||
prand = crypto.generate_prand()
|
||||
address_bytes = crypto.ah(irk, prand) + prand
|
||||
else:
|
||||
address_bytes = secrets.token_bytes(6)
|
||||
address_bytes = address_bytes[:5] + bytes([address_bytes[5] & 0b00111111])
|
||||
|
||||
return Address(
|
||||
address=address_bytes, address_type=Address.RANDOM_DEVICE_ADDRESS
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self, address: Union[bytes, str], address_type: int = RANDOM_DEVICE_ADDRESS
|
||||
):
|
||||
@@ -1941,25 +2015,15 @@ Address.ANY_RANDOM = Address(b"\x00\x00\x00\x00\x00\x00", Address.RANDOM_DEVICE_
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class OwnAddressType:
|
||||
class OwnAddressType(enum.IntEnum):
|
||||
PUBLIC = 0
|
||||
RANDOM = 1
|
||||
RESOLVABLE_OR_PUBLIC = 2
|
||||
RESOLVABLE_OR_RANDOM = 3
|
||||
|
||||
TYPE_NAMES = {
|
||||
PUBLIC: 'PUBLIC',
|
||||
RANDOM: 'RANDOM',
|
||||
RESOLVABLE_OR_PUBLIC: 'RESOLVABLE_OR_PUBLIC',
|
||||
RESOLVABLE_OR_RANDOM: 'RESOLVABLE_OR_RANDOM',
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def type_name(type_id):
|
||||
return name_or_number(OwnAddressType.TYPE_NAMES, type_id)
|
||||
|
||||
# pylint: disable-next=unnecessary-lambda
|
||||
TYPE_SPEC = {'size': 1, 'mapper': lambda x: OwnAddressType.type_name(x)}
|
||||
@classmethod
|
||||
def type_spec(cls):
|
||||
return {'size': 1, 'mapper': lambda x: OwnAddressType(x).name}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -1986,6 +2050,9 @@ class HCI_Packet:
|
||||
if packet_type == HCI_EVENT_PACKET:
|
||||
return HCI_Event.from_bytes(packet)
|
||||
|
||||
if packet_type == HCI_ISO_DATA_PACKET:
|
||||
return HCI_IsoDataPacket.from_bytes(packet)
|
||||
|
||||
return HCI_CustomPacket(packet)
|
||||
|
||||
def __init__(self, name):
|
||||
@@ -2018,6 +2085,7 @@ class HCI_Command(HCI_Packet):
|
||||
hci_packet_type = HCI_COMMAND_PACKET
|
||||
command_names: Dict[int, str] = {}
|
||||
command_classes: Dict[int, Type[HCI_Command]] = {}
|
||||
op_code: int
|
||||
|
||||
@staticmethod
|
||||
def command(fields=(), return_parameters_fields=()):
|
||||
@@ -2103,7 +2171,11 @@ class HCI_Command(HCI_Packet):
|
||||
return_parameters.fields = cls.return_parameters_fields
|
||||
return return_parameters
|
||||
|
||||
def __init__(self, op_code, parameters=None, **kwargs):
|
||||
def __init__(self, op_code=-1, parameters=None, **kwargs):
|
||||
# Since the legacy implementation relies on an __init__ injector, typing always
|
||||
# complains that positional argument op_code is not passed, so here sets a
|
||||
# default value to allow building derived HCI_Command without op_code.
|
||||
assert op_code != -1
|
||||
super().__init__(HCI_Command.command_name(op_code))
|
||||
if (fields := getattr(self, 'fields', None)) and kwargs:
|
||||
HCI_Object.init_from_fields(self, fields, kwargs)
|
||||
@@ -3344,7 +3416,7 @@ class HCI_LE_Set_Random_Address_Command(HCI_Command):
|
||||
),
|
||||
},
|
||||
),
|
||||
('own_address_type', OwnAddressType.TYPE_SPEC),
|
||||
('own_address_type', OwnAddressType.type_spec()),
|
||||
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||
('peer_address', Address.parse_address_preceded_by_type),
|
||||
('advertising_channel_map', 1),
|
||||
@@ -3437,7 +3509,7 @@ class HCI_LE_Set_Advertising_Enable_Command(HCI_Command):
|
||||
('le_scan_type', 1),
|
||||
('le_scan_interval', 2),
|
||||
('le_scan_window', 2),
|
||||
('own_address_type', OwnAddressType.TYPE_SPEC),
|
||||
('own_address_type', OwnAddressType.type_spec()),
|
||||
('scanning_filter_policy', 1),
|
||||
]
|
||||
)
|
||||
@@ -3476,7 +3548,7 @@ class HCI_LE_Set_Scan_Enable_Command(HCI_Command):
|
||||
('initiator_filter_policy', 1),
|
||||
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||
('peer_address', Address.parse_address_preceded_by_type),
|
||||
('own_address_type', OwnAddressType.TYPE_SPEC),
|
||||
('own_address_type', OwnAddressType.type_spec()),
|
||||
('connection_interval_min', 2),
|
||||
('connection_interval_max', 2),
|
||||
('max_latency', 2),
|
||||
@@ -3883,7 +3955,7 @@ class HCI_LE_Set_Advertising_Set_Random_Address_Command(HCI_Command):
|
||||
),
|
||||
},
|
||||
),
|
||||
('own_address_type', OwnAddressType.TYPE_SPEC),
|
||||
('own_address_type', OwnAddressType.type_spec()),
|
||||
('peer_address_type', Address.ADDRESS_TYPE_SPEC),
|
||||
('peer_address', Address.parse_address_preceded_by_type),
|
||||
('advertising_filter_policy', 1),
|
||||
@@ -4279,7 +4351,7 @@ class HCI_LE_Extended_Create_Connection_Command(HCI_Command):
|
||||
('initiator_filter_policy:', self.initiator_filter_policy),
|
||||
(
|
||||
'own_address_type: ',
|
||||
OwnAddressType.type_name(self.own_address_type),
|
||||
OwnAddressType(self.own_address_type).name,
|
||||
),
|
||||
(
|
||||
'peer_address_type: ',
|
||||
@@ -4521,6 +4593,10 @@ class HCI_LE_Setup_ISO_Data_Path_Command(HCI_Command):
|
||||
See Bluetooth spec @ 7.8.109 LE Setup ISO Data Path command
|
||||
'''
|
||||
|
||||
class Direction(enum.IntEnum):
|
||||
HOST_TO_CONTROLLER = 0x00
|
||||
CONTROLLER_TO_HOST = 0x01
|
||||
|
||||
connection_handle: int
|
||||
data_path_direction: int
|
||||
data_path_id: int
|
||||
@@ -5160,6 +5236,21 @@ HCI_LE_Meta_Event.subevent_classes[
|
||||
] = HCI_LE_Extended_Advertising_Report_Event
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_LE_Meta_Event.event(
|
||||
[
|
||||
('status', 1),
|
||||
('advertising_handle', 1),
|
||||
('connection_handle', 2),
|
||||
('number_completed_extended_advertising_events', 1),
|
||||
]
|
||||
)
|
||||
class HCI_LE_Advertising_Set_Terminated_Event(HCI_LE_Meta_Event):
|
||||
'''
|
||||
See Bluetooth spec @ 7.7.65.18 LE Advertising Set Terminated Event
|
||||
'''
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_LE_Meta_Event.event([('connection_handle', 2), ('channel_selection_algorithm', 1)])
|
||||
class HCI_LE_Channel_Selection_Algorithm_Event(HCI_LE_Meta_Event):
|
||||
@@ -6093,7 +6184,7 @@ class HCI_IsoDataPacket(HCI_Packet):
|
||||
if ts_flag:
|
||||
if not should_include_sdu_info:
|
||||
logger.warn(f'Timestamp included when pb_flag={bin(pb_flag)}')
|
||||
time_stamp, _ = struct.unpack_from('<I', packet, pos)
|
||||
time_stamp, *_ = struct.unpack_from('<I', packet, pos)
|
||||
pos += 4
|
||||
|
||||
if should_include_sdu_info:
|
||||
@@ -6160,7 +6251,7 @@ class HCI_IsoDataPacket(HCI_Packet):
|
||||
self.packet_sequence_number,
|
||||
self.iso_sdu_length | self.packet_status_flag << 14,
|
||||
]
|
||||
return struct.pack(fmt, args) + self.iso_sdu_fragment
|
||||
return struct.pack(fmt, *args) + self.iso_sdu_fragment
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
|
||||
@@ -37,6 +37,7 @@ from bumble.l2cap import (
|
||||
L2CAP_Connection_Response,
|
||||
)
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
HCI_EVENT_PACKET,
|
||||
HCI_ACL_DATA_PACKET,
|
||||
HCI_DISCONNECTION_COMPLETE_EVENT,
|
||||
@@ -48,6 +49,7 @@ from bumble.hci import (
|
||||
)
|
||||
from bumble.rfcomm import RFCOMM_Frame, RFCOMM_PSM
|
||||
from bumble.sdp import SDP_PDU, SDP_PSM
|
||||
from bumble import crypto
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
@@ -232,3 +234,15 @@ class PacketTracer:
|
||||
)
|
||||
self.host_to_controller_analyzer.peer = self.controller_to_host_analyzer
|
||||
self.controller_to_host_analyzer.peer = self.host_to_controller_analyzer
|
||||
|
||||
|
||||
def generate_irk() -> bytes:
|
||||
return crypto.r()
|
||||
|
||||
|
||||
def verify_rpa_with_irk(rpa: Address, irk: bytes) -> bool:
|
||||
rpa_bytes = bytes(rpa)
|
||||
prand_given = rpa_bytes[3:]
|
||||
hash_given = rpa_bytes[:3]
|
||||
hash_local = crypto.ah(irk, prand_given)
|
||||
return hash_local[:3] == hash_given
|
||||
|
||||
407
bumble/hid.py
407
bumble/hid.py
@@ -19,16 +19,17 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import enum
|
||||
import struct
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pyee import EventEmitter
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import Optional, Callable, TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
|
||||
from bumble import l2cap
|
||||
from bumble import l2cap, device
|
||||
from bumble.colors import color
|
||||
from bumble.core import InvalidStateError, ProtocolError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from bumble.device import Device, Connection
|
||||
from .hci import Address
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -60,6 +61,7 @@ class Message:
|
||||
NOT_READY = 0x01
|
||||
ERR_INVALID_REPORT_ID = 0x02
|
||||
ERR_UNSUPPORTED_REQUEST = 0x03
|
||||
ERR_INVALID_PARAMETER = 0x04
|
||||
ERR_UNKNOWN = 0x0E
|
||||
ERR_FATAL = 0x0F
|
||||
|
||||
@@ -101,13 +103,14 @@ class GetReportMessage(Message):
|
||||
def __bytes__(self) -> bytes:
|
||||
packet_bytes = bytearray()
|
||||
packet_bytes.append(self.report_id)
|
||||
packet_bytes.extend(
|
||||
[(self.buffer_size & 0xFF), ((self.buffer_size >> 8) & 0xFF)]
|
||||
)
|
||||
if self.report_type == Message.ReportType.OTHER_REPORT:
|
||||
if self.buffer_size == 0:
|
||||
return self.header(self.report_type) + packet_bytes
|
||||
else:
|
||||
return self.header(0x08 | self.report_type) + packet_bytes
|
||||
return (
|
||||
self.header(0x08 | self.report_type)
|
||||
+ packet_bytes
|
||||
+ struct.pack("<H", self.buffer_size)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -120,6 +123,16 @@ class SetReportMessage(Message):
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendControlData(Message):
|
||||
report_type: int
|
||||
data: bytes
|
||||
message_type = Message.MessageType.DATA
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class GetProtocolMessage(Message):
|
||||
message_type = Message.MessageType.GET_PROTOCOL
|
||||
@@ -161,31 +174,47 @@ class VirtualCableUnplug(Message):
|
||||
return self.header(Message.ControlCommand.VIRTUAL_CABLE_UNPLUG)
|
||||
|
||||
|
||||
# Device sends input report, host sends output report.
|
||||
@dataclass
|
||||
class SendData(Message):
|
||||
data: bytes
|
||||
report_type: int
|
||||
message_type = Message.MessageType.DATA
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(Message.ReportType.OUTPUT_REPORT) + self.data
|
||||
return self.header(self.report_type) + self.data
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendHandshakeMessage(Message):
|
||||
result_code: int
|
||||
message_type = Message.MessageType.HANDSHAKE
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
return self.header(self.result_code)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Host(EventEmitter):
|
||||
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel]
|
||||
l2cap_intr_channel: Optional[l2cap.ClassicChannel]
|
||||
class HID(ABC, EventEmitter):
|
||||
l2cap_ctrl_channel: Optional[l2cap.ClassicChannel] = None
|
||||
l2cap_intr_channel: Optional[l2cap.ClassicChannel] = None
|
||||
connection: Optional[device.Connection] = None
|
||||
|
||||
def __init__(self, device: Device, connection: Connection) -> None:
|
||||
class Role(enum.IntEnum):
|
||||
HOST = 0x00
|
||||
DEVICE = 0x01
|
||||
|
||||
def __init__(self, device: device.Device, role: Role) -> None:
|
||||
super().__init__()
|
||||
self.remote_device_bd_address: Optional[Address] = None
|
||||
self.device = device
|
||||
self.connection = connection
|
||||
|
||||
self.l2cap_ctrl_channel = None
|
||||
self.l2cap_intr_channel = None
|
||||
self.role = role
|
||||
|
||||
# Register ourselves with the L2CAP channel manager
|
||||
device.register_l2cap_server(HID_CONTROL_PSM, self.on_connection)
|
||||
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_connection)
|
||||
device.register_l2cap_server(HID_CONTROL_PSM, self.on_l2cap_connection)
|
||||
device.register_l2cap_server(HID_INTERRUPT_PSM, self.on_l2cap_connection)
|
||||
|
||||
device.on('connection', self.on_device_connection)
|
||||
|
||||
async def connect_control_channel(self) -> None:
|
||||
# Create a new L2CAP connection - control channel
|
||||
@@ -229,9 +258,18 @@ class Host(EventEmitter):
|
||||
self.l2cap_ctrl_channel = None
|
||||
await channel.disconnect()
|
||||
|
||||
def on_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
def on_device_connection(self, connection: device.Connection) -> None:
|
||||
self.connection = connection
|
||||
self.remote_device_bd_address = connection.peer_address
|
||||
connection.on('disconnection', self.on_device_disconnection)
|
||||
|
||||
def on_device_disconnection(self, reason: int) -> None:
|
||||
self.connection = None
|
||||
|
||||
def on_l2cap_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
logger.debug(f'+++ New L2CAP connection: {l2cap_channel}')
|
||||
l2cap_channel.on('open', lambda: self.on_l2cap_channel_open(l2cap_channel))
|
||||
l2cap_channel.on('close', lambda: self.on_l2cap_channel_close(l2cap_channel))
|
||||
|
||||
def on_l2cap_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
if l2cap_channel.psm == HID_CONTROL_PSM:
|
||||
@@ -242,63 +280,20 @@ class Host(EventEmitter):
|
||||
self.l2cap_intr_channel.sink = self.on_intr_pdu
|
||||
logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}')
|
||||
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
|
||||
# Here we will receive all kinds of packets, parse and then call respective callbacks
|
||||
message_type = pdu[0] >> 4
|
||||
param = pdu[0] & 0x0F
|
||||
|
||||
if message_type == Message.MessageType.HANDSHAKE:
|
||||
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
|
||||
self.emit('handshake', Message.Handshake(param))
|
||||
elif message_type == Message.MessageType.DATA:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('data', pdu)
|
||||
elif message_type == Message.MessageType.CONTROL:
|
||||
if param == Message.ControlCommand.SUSPEND:
|
||||
logger.debug('<<< HID SUSPEND')
|
||||
self.emit('suspend', pdu)
|
||||
elif param == Message.ControlCommand.EXIT_SUSPEND:
|
||||
logger.debug('<<< HID EXIT SUSPEND')
|
||||
self.emit('exit_suspend', pdu)
|
||||
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
|
||||
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
|
||||
self.emit('virtual_cable_unplug')
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||
def on_l2cap_channel_close(self, l2cap_channel: l2cap.ClassicChannel) -> None:
|
||||
if l2cap_channel.psm == HID_CONTROL_PSM:
|
||||
self.l2cap_ctrl_channel = None
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('data', pdu)
|
||||
self.l2cap_intr_channel = None
|
||||
logger.debug(f'$$$ L2CAP channel close: {l2cap_channel}')
|
||||
|
||||
@abstractmethod
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
pass
|
||||
|
||||
def on_intr_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID INTERRUPT PDU: {pdu.hex()}')
|
||||
self.emit("data", pdu)
|
||||
|
||||
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
|
||||
msg = GetReportMessage(
|
||||
report_type=report_type, report_id=report_id, buffer_size=buffer_size
|
||||
)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def set_report(self, report_type: int, data: bytes):
|
||||
msg = SetReportMessage(report_type=report_type, data=data)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def get_protocol(self):
|
||||
msg = GetProtocolMessage()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def set_protocol(self, protocol_mode: int):
|
||||
msg = SetProtocolMessage(protocol_mode=protocol_mode)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
self.emit("interrupt_data", pdu)
|
||||
|
||||
def send_pdu_on_ctrl(self, msg: bytes) -> None:
|
||||
assert self.l2cap_ctrl_channel
|
||||
@@ -308,26 +303,252 @@ class Host(EventEmitter):
|
||||
assert self.l2cap_intr_channel
|
||||
self.l2cap_intr_channel.send_pdu(msg)
|
||||
|
||||
def send_data(self, data):
|
||||
msg = SendData(data)
|
||||
def send_data(self, data: bytes) -> None:
|
||||
if self.role == HID.Role.HOST:
|
||||
report_type = Message.ReportType.OUTPUT_REPORT
|
||||
else:
|
||||
report_type = Message.ReportType.INPUT_REPORT
|
||||
msg = SendData(data, report_type)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_intr(hid_message)
|
||||
if self.l2cap_intr_channel is not None:
|
||||
logger.debug(f'>>> HID INTERRUPT SEND DATA, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_intr(hid_message)
|
||||
|
||||
def suspend(self):
|
||||
msg = Suspend()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(msg)
|
||||
|
||||
def exit_suspend(self):
|
||||
msg = ExitSuspend()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(msg)
|
||||
|
||||
def virtual_cable_unplug(self):
|
||||
def virtual_cable_unplug(self) -> None:
|
||||
msg = VirtualCableUnplug()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL VIRTUAL CABLE UNPLUG, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(msg)
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Device(HID):
|
||||
class GetSetReturn(enum.IntEnum):
|
||||
FAILURE = 0x00
|
||||
REPORT_ID_NOT_FOUND = 0x01
|
||||
ERR_UNSUPPORTED_REQUEST = 0x02
|
||||
ERR_UNKNOWN = 0x03
|
||||
ERR_INVALID_PARAMETER = 0x04
|
||||
SUCCESS = 0xFF
|
||||
|
||||
class GetSetStatus:
|
||||
def __init__(self) -> None:
|
||||
self.data = bytearray()
|
||||
self.status = 0
|
||||
|
||||
def __init__(self, device: device.Device) -> None:
|
||||
super().__init__(device, HID.Role.DEVICE)
|
||||
get_report_cb: Optional[Callable[[int, int, int], None]] = None
|
||||
set_report_cb: Optional[Callable[[int, int, int, bytes], None]] = None
|
||||
get_protocol_cb: Optional[Callable[[], None]] = None
|
||||
set_protocol_cb: Optional[Callable[[int], None]] = None
|
||||
|
||||
@override
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
|
||||
param = pdu[0] & 0x0F
|
||||
message_type = pdu[0] >> 4
|
||||
|
||||
if message_type == Message.MessageType.GET_REPORT:
|
||||
logger.debug('<<< HID GET REPORT')
|
||||
self.handle_get_report(pdu)
|
||||
elif message_type == Message.MessageType.SET_REPORT:
|
||||
logger.debug('<<< HID SET REPORT')
|
||||
self.handle_set_report(pdu)
|
||||
elif message_type == Message.MessageType.GET_PROTOCOL:
|
||||
logger.debug('<<< HID GET PROTOCOL')
|
||||
self.handle_get_protocol(pdu)
|
||||
elif message_type == Message.MessageType.SET_PROTOCOL:
|
||||
logger.debug('<<< HID SET PROTOCOL')
|
||||
self.handle_set_protocol(pdu)
|
||||
elif message_type == Message.MessageType.DATA:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('control_data', pdu)
|
||||
elif message_type == Message.MessageType.CONTROL:
|
||||
if param == Message.ControlCommand.SUSPEND:
|
||||
logger.debug('<<< HID SUSPEND')
|
||||
self.emit('suspend')
|
||||
elif param == Message.ControlCommand.EXIT_SUSPEND:
|
||||
logger.debug('<<< HID EXIT SUSPEND')
|
||||
self.emit('exit_suspend')
|
||||
elif param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
|
||||
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
|
||||
self.emit('virtual_cable_unplug')
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||
else:
|
||||
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def send_handshake_message(self, result_code: int) -> None:
|
||||
msg = SendHandshakeMessage(result_code)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID HANDSHAKE MESSAGE, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def send_control_data(self, report_type: int, data: bytes):
|
||||
msg = SendControlData(report_type=report_type, data=data)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL DATA: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def handle_get_report(self, pdu: bytes):
|
||||
if self.get_report_cb is None:
|
||||
logger.debug("GetReport callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
report_type = pdu[0] & 0x03
|
||||
buffer_flag = (pdu[0] & 0x08) >> 3
|
||||
report_id = pdu[1]
|
||||
logger.debug(f"buffer_flag: {buffer_flag}")
|
||||
if buffer_flag == 1:
|
||||
buffer_size = (pdu[3] << 8) | pdu[2]
|
||||
else:
|
||||
buffer_size = 0
|
||||
|
||||
ret = self.get_report_cb(report_id, report_type, buffer_size)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.FAILURE:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNKNOWN)
|
||||
elif ret.status == self.GetSetReturn.SUCCESS:
|
||||
data = bytearray()
|
||||
data.append(report_id)
|
||||
data.extend(ret.data)
|
||||
if len(data) < self.l2cap_ctrl_channel.mtu: # type: ignore[union-attr]
|
||||
self.send_control_data(report_type=report_type, data=data)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
|
||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.ERR_UNSUPPORTED_REQUEST:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_get_report_cb(self, cb: Callable[[int, int, int], None]) -> None:
|
||||
self.get_report_cb = cb
|
||||
logger.debug("GetReport callback registered successfully")
|
||||
|
||||
def handle_set_report(self, pdu: bytes):
|
||||
if self.set_report_cb is None:
|
||||
logger.debug("SetReport callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
report_type = pdu[0] & 0x03
|
||||
report_id = pdu[1]
|
||||
report_data = pdu[2:]
|
||||
report_size = len(report_data) + 1
|
||||
ret = self.set_report_cb(report_id, report_type, report_size, report_data)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||
elif ret.status == self.GetSetReturn.ERR_INVALID_PARAMETER:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_PARAMETER)
|
||||
elif ret.status == self.GetSetReturn.REPORT_ID_NOT_FOUND:
|
||||
self.send_handshake_message(Message.Handshake.ERR_INVALID_REPORT_ID)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_set_report_cb(
|
||||
self, cb: Callable[[int, int, int, bytes], None]
|
||||
) -> None:
|
||||
self.set_report_cb = cb
|
||||
logger.debug("SetReport callback registered successfully")
|
||||
|
||||
def handle_get_protocol(self, pdu: bytes):
|
||||
if self.get_protocol_cb is None:
|
||||
logger.debug("GetProtocol callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
ret = self.get_protocol_cb()
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_control_data(Message.ReportType.OTHER_REPORT, ret.data)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_get_protocol_cb(self, cb: Callable[[], None]) -> None:
|
||||
self.get_protocol_cb = cb
|
||||
logger.debug("GetProtocol callback registered successfully")
|
||||
|
||||
def handle_set_protocol(self, pdu: bytes):
|
||||
if self.set_protocol_cb is None:
|
||||
logger.debug("SetProtocol callback not registered !!")
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
return
|
||||
ret = self.set_protocol_cb(pdu[0] & 0x01)
|
||||
assert ret is not None
|
||||
if ret.status == self.GetSetReturn.SUCCESS:
|
||||
self.send_handshake_message(Message.Handshake.SUCCESSFUL)
|
||||
else:
|
||||
self.send_handshake_message(Message.Handshake.ERR_UNSUPPORTED_REQUEST)
|
||||
|
||||
def register_set_protocol_cb(self, cb: Callable[[int], None]) -> None:
|
||||
self.set_protocol_cb = cb
|
||||
logger.debug("SetProtocol callback registered successfully")
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Host(HID):
|
||||
def __init__(self, device: device.Device) -> None:
|
||||
super().__init__(device, HID.Role.HOST)
|
||||
|
||||
def get_report(self, report_type: int, report_id: int, buffer_size: int) -> None:
|
||||
msg = GetReportMessage(
|
||||
report_type=report_type, report_id=report_id, buffer_size=buffer_size
|
||||
)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL GET REPORT, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def set_report(self, report_type: int, data: bytes) -> None:
|
||||
msg = SetReportMessage(report_type=report_type, data=data)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SET REPORT, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def get_protocol(self) -> None:
|
||||
msg = GetProtocolMessage()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL GET PROTOCOL, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def set_protocol(self, protocol_mode: int) -> None:
|
||||
msg = SetProtocolMessage(protocol_mode=protocol_mode)
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SET PROTOCOL, PDU: {hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def suspend(self) -> None:
|
||||
msg = Suspend()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL SUSPEND, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
def exit_suspend(self) -> None:
|
||||
msg = ExitSuspend()
|
||||
hid_message = bytes(msg)
|
||||
logger.debug(f'>>> HID CONTROL EXIT SUSPEND, PDU:{hid_message.hex()}')
|
||||
self.send_pdu_on_ctrl(hid_message)
|
||||
|
||||
@override
|
||||
def on_ctrl_pdu(self, pdu: bytes) -> None:
|
||||
logger.debug(f'<<< HID CONTROL PDU: {pdu.hex()}')
|
||||
param = pdu[0] & 0x0F
|
||||
message_type = pdu[0] >> 4
|
||||
if message_type == Message.MessageType.HANDSHAKE:
|
||||
logger.debug(f'<<< HID HANDSHAKE: {Message.Handshake(param).name}')
|
||||
self.emit('handshake', Message.Handshake(param))
|
||||
elif message_type == Message.MessageType.DATA:
|
||||
logger.debug('<<< HID CONTROL DATA')
|
||||
self.emit('control_data', pdu)
|
||||
elif message_type == Message.MessageType.CONTROL:
|
||||
if param == Message.ControlCommand.VIRTUAL_CABLE_UNPLUG:
|
||||
logger.debug('<<< HID VIRTUAL CABLE UNPLUG')
|
||||
self.emit('virtual_cable_unplug')
|
||||
else:
|
||||
logger.debug('<<< HID CONTROL OPERATION UNSUPPORTED')
|
||||
else:
|
||||
logger.debug('<<< HID MESSAGE TYPE UNSUPPORTED')
|
||||
|
||||
221
bumble/host.py
221
bumble/host.py
@@ -21,7 +21,7 @@ import collections
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from typing import Optional, TYPE_CHECKING, Dict, Callable, Awaitable, cast
|
||||
from typing import Any, Awaitable, Callable, Deque, Dict, Optional, cast, TYPE_CHECKING
|
||||
|
||||
from bumble.colors import color
|
||||
from bumble.l2cap import L2CAP_PDU
|
||||
@@ -91,16 +91,49 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
# fmt: off
|
||||
class AclPacketQueue:
|
||||
max_packet_size: int
|
||||
|
||||
HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27
|
||||
HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1
|
||||
HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27
|
||||
HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
|
||||
def __init__(
|
||||
self,
|
||||
max_packet_size: int,
|
||||
max_in_flight: int,
|
||||
send: Callable[[HCI_Packet], None],
|
||||
) -> None:
|
||||
self.max_packet_size = max_packet_size
|
||||
self.max_in_flight = max_in_flight
|
||||
self.in_flight = 0
|
||||
self.send = send
|
||||
self.packets: Deque[HCI_AclDataPacket] = collections.deque()
|
||||
|
||||
# fmt: on
|
||||
def enqueue(self, packet: HCI_AclDataPacket) -> None:
|
||||
self.packets.appendleft(packet)
|
||||
self.check_queue()
|
||||
|
||||
if self.packets:
|
||||
logger.debug(
|
||||
f'{self.in_flight} ACL packets in flight, '
|
||||
f'{len(self.packets)} in queue'
|
||||
)
|
||||
|
||||
def check_queue(self) -> None:
|
||||
while self.packets and self.in_flight < self.max_in_flight:
|
||||
packet = self.packets.pop()
|
||||
self.send(packet)
|
||||
self.in_flight += 1
|
||||
|
||||
def on_packets_completed(self, packet_count: int) -> None:
|
||||
if packet_count > self.in_flight:
|
||||
logger.warning(
|
||||
color(
|
||||
'!!! {packet_count} completed but only '
|
||||
f'{self.in_flight} in flight'
|
||||
)
|
||||
)
|
||||
packet_count = self.in_flight
|
||||
|
||||
self.in_flight -= packet_count
|
||||
self.check_queue()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -111,6 +144,13 @@ class Connection:
|
||||
self.peer_address = peer_address
|
||||
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
|
||||
self.transport = transport
|
||||
acl_packet_queue: Optional[AclPacketQueue] = (
|
||||
host.le_acl_packet_queue
|
||||
if transport == BT_LE_TRANSPORT
|
||||
else host.acl_packet_queue
|
||||
)
|
||||
assert acl_packet_queue
|
||||
self.acl_packet_queue = acl_packet_queue
|
||||
|
||||
def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None:
|
||||
self.assembler.feed_packet(packet)
|
||||
@@ -123,8 +163,10 @@ class Connection:
|
||||
# -----------------------------------------------------------------------------
|
||||
class Host(AbortableEventEmitter):
|
||||
connections: Dict[int, Connection]
|
||||
acl_packet_queue: collections.deque[HCI_AclDataPacket]
|
||||
hci_sink: TransportSink
|
||||
acl_packet_queue: Optional[AclPacketQueue] = None
|
||||
le_acl_packet_queue: Optional[AclPacketQueue] = None
|
||||
hci_sink: Optional[TransportSink] = None
|
||||
hci_metadata: Dict[str, Any]
|
||||
long_term_key_provider: Optional[
|
||||
Callable[[int, bytes, int], Awaitable[Optional[bytes]]]
|
||||
]
|
||||
@@ -137,18 +179,11 @@ class Host(AbortableEventEmitter):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hci_metadata = None
|
||||
self.hci_metadata = {}
|
||||
self.ready = False # True when we can accept incoming packets
|
||||
self.reset_done = False
|
||||
self.connections = {} # Connections, by connection handle
|
||||
self.pending_command = None
|
||||
self.pending_response = None
|
||||
self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
|
||||
self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS
|
||||
self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH
|
||||
self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS
|
||||
self.acl_packet_queue = collections.deque()
|
||||
self.acl_packets_in_flight = 0
|
||||
self.local_version = None
|
||||
self.local_supported_commands = bytes(64)
|
||||
self.local_le_features = 0
|
||||
@@ -162,10 +197,7 @@ class Host(AbortableEventEmitter):
|
||||
|
||||
# Connect to the source and sink if specified
|
||||
if controller_source:
|
||||
controller_source.set_packet_sink(self)
|
||||
self.hci_metadata = getattr(
|
||||
controller_source, 'metadata', self.hci_metadata
|
||||
)
|
||||
self.set_packet_source(controller_source)
|
||||
if controller_sink:
|
||||
self.set_packet_sink(controller_sink)
|
||||
|
||||
@@ -200,17 +232,21 @@ class Host(AbortableEventEmitter):
|
||||
self.ready = False
|
||||
await self.flush()
|
||||
|
||||
await self.send_command(HCI_Reset_Command(), check_result=True)
|
||||
self.ready = True
|
||||
|
||||
# Instantiate and init a driver for the host if needed.
|
||||
# NOTE: we don't keep a reference to the driver here, because we don't
|
||||
# currently have a need for the driver later on. But if the driver interface
|
||||
# evolves, it may be required, then, to store a reference to the driver in
|
||||
# an object property.
|
||||
reset_needed = True
|
||||
if driver_factory is not None:
|
||||
if driver := await driver_factory(self):
|
||||
await driver.init_controller()
|
||||
reset_needed = False
|
||||
|
||||
# Send a reset command unless a driver has already done so.
|
||||
if reset_needed:
|
||||
await self.send_command(HCI_Reset_Command(), check_result=True)
|
||||
self.ready = True
|
||||
|
||||
response = await self.send_command(
|
||||
HCI_Read_Local_Supported_Commands_Command(), check_result=True
|
||||
@@ -253,46 +289,54 @@ class Host(AbortableEventEmitter):
|
||||
response = await self.send_command(
|
||||
HCI_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
self.hc_acl_data_packet_length = (
|
||||
hc_acl_data_packet_length = (
|
||||
response.return_parameters.hc_acl_data_packet_length
|
||||
)
|
||||
self.hc_total_num_acl_data_packets = (
|
||||
hc_total_num_acl_data_packets = (
|
||||
response.return_parameters.hc_total_num_acl_data_packets
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'HCI ACL flow control: '
|
||||
f'hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
|
||||
f'hc_total_num_acl_data_packets={self.hc_total_num_acl_data_packets}'
|
||||
f'hc_acl_data_packet_length={hc_acl_data_packet_length},'
|
||||
f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}'
|
||||
)
|
||||
|
||||
self.acl_packet_queue = AclPacketQueue(
|
||||
max_packet_size=hc_acl_data_packet_length,
|
||||
max_in_flight=hc_total_num_acl_data_packets,
|
||||
send=self.send_hci_packet,
|
||||
)
|
||||
|
||||
hc_le_acl_data_packet_length = 0
|
||||
hc_total_num_le_acl_data_packets = 0
|
||||
if self.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
|
||||
response = await self.send_command(
|
||||
HCI_LE_Read_Buffer_Size_Command(), check_result=True
|
||||
)
|
||||
self.hc_le_acl_data_packet_length = (
|
||||
hc_le_acl_data_packet_length = (
|
||||
response.return_parameters.hc_le_acl_data_packet_length
|
||||
)
|
||||
self.hc_total_num_le_acl_data_packets = (
|
||||
hc_total_num_le_acl_data_packets = (
|
||||
response.return_parameters.hc_total_num_le_acl_data_packets
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'HCI LE ACL flow control: '
|
||||
f'hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
|
||||
'hc_total_num_le_acl_data_packets='
|
||||
f'{self.hc_total_num_le_acl_data_packets}'
|
||||
f'hc_le_acl_data_packet_length={hc_le_acl_data_packet_length},'
|
||||
f'hc_total_num_le_acl_data_packets={hc_total_num_le_acl_data_packets}'
|
||||
)
|
||||
|
||||
if (
|
||||
response.return_parameters.hc_le_acl_data_packet_length == 0
|
||||
or response.return_parameters.hc_total_num_le_acl_data_packets == 0
|
||||
):
|
||||
# LE and Classic share the same values
|
||||
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length
|
||||
self.hc_total_num_le_acl_data_packets = (
|
||||
self.hc_total_num_acl_data_packets
|
||||
)
|
||||
if hc_le_acl_data_packet_length == 0 or hc_total_num_le_acl_data_packets == 0:
|
||||
# LE and Classic share the same queue
|
||||
self.le_acl_packet_queue = self.acl_packet_queue
|
||||
else:
|
||||
# Create a separate queue for LE
|
||||
self.le_acl_packet_queue = AclPacketQueue(
|
||||
max_packet_size=hc_le_acl_data_packet_length,
|
||||
max_in_flight=hc_total_num_le_acl_data_packets,
|
||||
send=self.send_hci_packet,
|
||||
)
|
||||
|
||||
if self.supports_command(
|
||||
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
|
||||
@@ -313,29 +357,31 @@ class Host(AbortableEventEmitter):
|
||||
)
|
||||
)
|
||||
|
||||
self.reset_done = True
|
||||
|
||||
@property
|
||||
def controller(self) -> TransportSink:
|
||||
def controller(self) -> Optional[TransportSink]:
|
||||
return self.hci_sink
|
||||
|
||||
@controller.setter
|
||||
def controller(self, controller):
|
||||
def controller(self, controller) -> None:
|
||||
self.set_packet_sink(controller)
|
||||
if controller:
|
||||
controller.set_packet_sink(self)
|
||||
|
||||
def set_packet_sink(self, sink: TransportSink) -> None:
|
||||
def set_packet_sink(self, sink: Optional[TransportSink]) -> None:
|
||||
self.hci_sink = sink
|
||||
|
||||
def set_packet_source(self, source: TransportSource) -> None:
|
||||
source.set_packet_sink(self)
|
||||
self.hci_metadata = getattr(source, 'metadata', self.hci_metadata)
|
||||
|
||||
def send_hci_packet(self, packet: HCI_Packet) -> None:
|
||||
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {packet}')
|
||||
if self.snooper:
|
||||
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
|
||||
self.hci_sink.on_packet(bytes(packet))
|
||||
if self.hci_sink:
|
||||
self.hci_sink.on_packet(bytes(packet))
|
||||
|
||||
async def send_command(self, command, check_result=False):
|
||||
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}')
|
||||
|
||||
# Wait until we can send (only one pending command at a time)
|
||||
async with self.command_semaphore:
|
||||
assert self.pending_command is None
|
||||
@@ -383,6 +429,17 @@ class Host(AbortableEventEmitter):
|
||||
asyncio.create_task(send_command(command))
|
||||
|
||||
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
|
||||
if not (connection := self.connections.get(connection_handle)):
|
||||
logger.warning(f'connection 0x{connection_handle:04X} not found')
|
||||
return
|
||||
packet_queue = connection.acl_packet_queue
|
||||
if packet_queue is None:
|
||||
logger.warning(
|
||||
f'no ACL packet queue for connection 0x{connection_handle:04X}'
|
||||
)
|
||||
return
|
||||
|
||||
# Create a PDU
|
||||
l2cap_pdu = bytes(L2CAP_PDU(cid, pdu))
|
||||
|
||||
# Send the data to the controller via ACL packets
|
||||
@@ -390,8 +447,7 @@ class Host(AbortableEventEmitter):
|
||||
offset = 0
|
||||
pb_flag = 0
|
||||
while bytes_remaining:
|
||||
# TODO: support different LE/Classic lengths
|
||||
data_total_length = min(bytes_remaining, self.hc_le_acl_data_packet_length)
|
||||
data_total_length = min(bytes_remaining, packet_queue.max_packet_size)
|
||||
acl_packet = HCI_AclDataPacket(
|
||||
connection_handle=connection_handle,
|
||||
pb_flag=pb_flag,
|
||||
@@ -399,34 +455,12 @@ class Host(AbortableEventEmitter):
|
||||
data_total_length=data_total_length,
|
||||
data=l2cap_pdu[offset : offset + data_total_length],
|
||||
)
|
||||
logger.debug(
|
||||
f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}'
|
||||
)
|
||||
self.queue_acl_packet(acl_packet)
|
||||
logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}')
|
||||
packet_queue.enqueue(acl_packet)
|
||||
pb_flag = 1
|
||||
offset += data_total_length
|
||||
bytes_remaining -= data_total_length
|
||||
|
||||
def queue_acl_packet(self, acl_packet: HCI_AclDataPacket) -> None:
|
||||
self.acl_packet_queue.appendleft(acl_packet)
|
||||
self.check_acl_packet_queue()
|
||||
|
||||
if len(self.acl_packet_queue):
|
||||
logger.debug(
|
||||
f'{self.acl_packets_in_flight} ACL packets in flight, '
|
||||
f'{len(self.acl_packet_queue)} in queue'
|
||||
)
|
||||
|
||||
def check_acl_packet_queue(self) -> None:
|
||||
# Send all we can (TODO: support different LE/Classic limits)
|
||||
while (
|
||||
len(self.acl_packet_queue) > 0
|
||||
and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets
|
||||
):
|
||||
packet = self.acl_packet_queue.pop()
|
||||
self.send_hci_packet(packet)
|
||||
self.acl_packets_in_flight += 1
|
||||
|
||||
def supports_command(self, command):
|
||||
# Find the support flag position for this command
|
||||
for octet, flags in enumerate(HCI_SUPPORTED_COMMANDS_FLAGS):
|
||||
@@ -549,7 +583,7 @@ class Host(AbortableEventEmitter):
|
||||
# This is used just for the Num_HCI_Command_Packets field, not related to
|
||||
# an actual command
|
||||
logger.debug('no-command event')
|
||||
return None
|
||||
return
|
||||
|
||||
return self.on_command_processed(event)
|
||||
|
||||
@@ -557,18 +591,17 @@ class Host(AbortableEventEmitter):
|
||||
return self.on_command_processed(event)
|
||||
|
||||
def on_hci_number_of_completed_packets_event(self, event):
|
||||
total_packets = sum(event.num_completed_packets)
|
||||
if total_packets <= self.acl_packets_in_flight:
|
||||
self.acl_packets_in_flight -= total_packets
|
||||
self.check_acl_packet_queue()
|
||||
else:
|
||||
logger.warning(
|
||||
color(
|
||||
'!!! {total_packets} completed but only '
|
||||
f'{self.acl_packets_in_flight} in flight'
|
||||
for connection_handle, num_completed_packets in zip(
|
||||
event.connection_handles, event.num_completed_packets
|
||||
):
|
||||
if not (connection := self.connections.get(connection_handle)):
|
||||
logger.warning(
|
||||
'received packet completion event for unknown handle '
|
||||
f'0x{connection_handle:04X}'
|
||||
)
|
||||
)
|
||||
self.acl_packets_in_flight = 0
|
||||
continue
|
||||
|
||||
connection.acl_packet_queue.on_packets_completed(num_completed_packets)
|
||||
|
||||
# Classic only
|
||||
def on_hci_connection_request_event(self, event):
|
||||
@@ -721,6 +754,14 @@ class Host(AbortableEventEmitter):
|
||||
def on_hci_le_extended_advertising_report_event(self, event):
|
||||
self.on_hci_le_advertising_report_event(event)
|
||||
|
||||
def on_hci_le_advertising_set_terminated_event(self, event):
|
||||
self.emit(
|
||||
'advertising_set_termination',
|
||||
event.status,
|
||||
event.advertising_handle,
|
||||
event.connection_handle,
|
||||
)
|
||||
|
||||
def on_hci_le_cis_request_event(self, event):
|
||||
self.emit(
|
||||
'cis_request',
|
||||
|
||||
@@ -149,9 +149,10 @@ L2CAP_INVALID_CID_IN_REQUEST_REASON = 0x0002
|
||||
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU = 65535
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2046
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256
|
||||
|
||||
@@ -188,8 +189,11 @@ class LeCreditBasedChannelSpec:
|
||||
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
|
||||
):
|
||||
raise ValueError('max credits out of range')
|
||||
if self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
|
||||
raise ValueError('MTU too small')
|
||||
if (
|
||||
self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU
|
||||
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
|
||||
):
|
||||
raise ValueError('MTU out of range')
|
||||
if (
|
||||
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
|
||||
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
|
||||
@@ -1644,12 +1648,13 @@ class ChannelManager:
|
||||
|
||||
def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None:
|
||||
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
|
||||
pdu_bytes = bytes(pdu)
|
||||
logger.debug(
|
||||
f'{color(">>> Sending L2CAP PDU", "blue")} '
|
||||
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
|
||||
f'{connection.peer_address}: {pdu_str}'
|
||||
f'{connection.peer_address}: {len(pdu_bytes)} bytes, {pdu_str}'
|
||||
)
|
||||
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
|
||||
self.host.send_l2cap_pdu(connection.handle, cid, pdu_bytes)
|
||||
|
||||
def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
|
||||
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
|
||||
@@ -1926,7 +1931,7 @@ class ChannelManager:
|
||||
supervision_timeout=request.timeout,
|
||||
min_ce_length=0,
|
||||
max_ce_length=0,
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.send_control_frame(
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
import struct
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from bumble import l2cap
|
||||
from ..core import AdvertisingData
|
||||
@@ -67,7 +67,7 @@ class AshaService(TemplateService):
|
||||
self.emit('volume', connection, value[0])
|
||||
|
||||
# Handler for audio control commands
|
||||
def on_audio_control_point_write(connection: Connection, value):
|
||||
def on_audio_control_point_write(connection: Optional[Connection], value):
|
||||
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
|
||||
opcode = value[0]
|
||||
if opcode == AshaService.OPCODE_START:
|
||||
|
||||
1247
bumble/profiles/bap.py
Normal file
1247
bumble/profiles/bap.py
Normal file
File diff suppressed because it is too large
Load Diff
52
bumble/profiles/cap.py
Normal file
52
bumble/profiles/cap.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
from __future__ import annotations
|
||||
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
from bumble.profiles import csip
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Server
|
||||
# -----------------------------------------------------------------------------
|
||||
class CommonAudioServiceService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_COMMON_AUDIO_SERVICE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
coordinated_set_identification_service: csip.CoordinatedSetIdentificationService,
|
||||
) -> None:
|
||||
self.coordinated_set_identification_service = (
|
||||
coordinated_set_identification_service
|
||||
)
|
||||
super().__init__(
|
||||
characteristics=[],
|
||||
included_services=[coordinated_set_identification_service],
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
# -----------------------------------------------------------------------------
|
||||
class CommonAudioServiceServiceProxy(gatt_client.ProfileServiceProxy):
|
||||
SERVICE_CLASS = CommonAudioServiceService
|
||||
|
||||
def __init__(self, service_proxy: gatt_client.ServiceProxy) -> None:
|
||||
self.service_proxy = service_proxy
|
||||
@@ -19,8 +19,11 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
import struct
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from bumble import core
|
||||
from bumble import crypto
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble import gatt_client
|
||||
|
||||
@@ -28,6 +31,9 @@ from bumble import gatt_client
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
SET_IDENTITY_RESOLVING_KEY_LENGTH = 16
|
||||
|
||||
|
||||
class SirkType(enum.IntEnum):
|
||||
'''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''
|
||||
|
||||
@@ -43,9 +49,47 @@ class MemberLock(enum.IntEnum):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Utils
|
||||
# Crypto Toolbox
|
||||
# -----------------------------------------------------------------------------
|
||||
# TODO: Implement RSI Generator
|
||||
def s1(m: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.3 s1 SALT generation function.
|
||||
'''
|
||||
return crypto.aes_cmac(m[::-1], bytes(16))[::-1]
|
||||
|
||||
|
||||
def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.4 k1 derivation function.
|
||||
'''
|
||||
t = crypto.aes_cmac(n[::-1], salt[::-1])
|
||||
return crypto.aes_cmac(p[::-1], t)[::-1]
|
||||
|
||||
|
||||
def sef(k: bytes, r: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.5 SIRK encryption function sef.
|
||||
|
||||
SIRK decryption function sdf shares the same algorithm. The only difference is that argument r is:
|
||||
* Plaintext in encryption
|
||||
* Cipher in decryption
|
||||
'''
|
||||
return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)
|
||||
|
||||
|
||||
def sih(k: bytes, r: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.7 Resolvable Set Identifier hash function sih.
|
||||
'''
|
||||
return crypto.e(k, r + bytes(13))[:3]
|
||||
|
||||
|
||||
def generate_rsi(sirk: bytes) -> bytes:
|
||||
'''
|
||||
Coordinated Set Identification Service - 4.8 Resolvable Set Identifier generation operation.
|
||||
'''
|
||||
prand = crypto.generate_prand()
|
||||
return sih(sirk, prand) + prand
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -54,6 +98,7 @@ class MemberLock(enum.IntEnum):
|
||||
class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
UUID = gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
|
||||
|
||||
set_identity_resolving_key: bytes
|
||||
set_identity_resolving_key_characteristic: gatt.Characteristic
|
||||
coordinated_set_size_characteristic: Optional[gatt.Characteristic] = None
|
||||
set_member_lock_characteristic: Optional[gatt.Characteristic] = None
|
||||
@@ -62,19 +107,26 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
def __init__(
|
||||
self,
|
||||
set_identity_resolving_key: bytes,
|
||||
set_identity_resolving_key_type: SirkType,
|
||||
coordinated_set_size: Optional[int] = None,
|
||||
set_member_lock: Optional[MemberLock] = None,
|
||||
set_member_rank: Optional[int] = None,
|
||||
) -> None:
|
||||
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
|
||||
raise ValueError(
|
||||
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
|
||||
)
|
||||
|
||||
characteristics = []
|
||||
|
||||
self.set_identity_resolving_key = set_identity_resolving_key
|
||||
self.set_identity_resolving_key_type = set_identity_resolving_key_type
|
||||
self.set_identity_resolving_key_characteristic = gatt.Characteristic(
|
||||
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
# TODO: Implement encrypted SIRK reader.
|
||||
value=struct.pack('B', SirkType.PLAINTEXT) + set_identity_resolving_key,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=gatt.CharacteristicValue(read=self.on_sirk_read),
|
||||
)
|
||||
characteristics.append(self.set_identity_resolving_key_characteristic)
|
||||
|
||||
@@ -83,7 +135,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=struct.pack('B', coordinated_set_size),
|
||||
)
|
||||
characteristics.append(self.coordinated_set_size_characteristic)
|
||||
@@ -94,7 +146,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY
|
||||
| gatt.Characteristic.Properties.WRITE,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
|
||||
| gatt.Characteristic.Permissions.WRITEABLE,
|
||||
value=struct.pack('B', set_member_lock),
|
||||
)
|
||||
@@ -105,13 +157,45 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
|
||||
uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
|
||||
properties=gatt.Characteristic.Properties.READ
|
||||
| gatt.Characteristic.Properties.NOTIFY,
|
||||
permissions=gatt.Characteristic.Permissions.READABLE,
|
||||
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
|
||||
value=struct.pack('B', set_member_rank),
|
||||
)
|
||||
characteristics.append(self.set_member_rank_characteristic)
|
||||
|
||||
super().__init__(characteristics)
|
||||
|
||||
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
|
||||
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
|
||||
sirk_bytes = self.set_identity_resolving_key
|
||||
else:
|
||||
assert connection
|
||||
|
||||
if connection.transport == core.BT_LE_TRANSPORT:
|
||||
key = await connection.device.get_long_term_key(
|
||||
connection_handle=connection.handle, rand=b'', ediv=0
|
||||
)
|
||||
else:
|
||||
key = await connection.device.get_link_key(connection.peer_address)
|
||||
|
||||
if not key:
|
||||
raise RuntimeError('LTK or LinkKey is not present')
|
||||
|
||||
sirk_bytes = sef(key, self.set_identity_resolving_key)
|
||||
|
||||
return bytes([self.set_identity_resolving_key_type]) + sirk_bytes
|
||||
|
||||
def get_advertising_data(self) -> bytes:
|
||||
return bytes(
|
||||
core.AdvertisingData(
|
||||
[
|
||||
(
|
||||
core.AdvertisingData.RESOLVABLE_SET_IDENTIFIER,
|
||||
generate_rsi(self.set_identity_resolving_key),
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Client
|
||||
@@ -145,3 +229,29 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
|
||||
gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
|
||||
):
|
||||
self.set_member_rank = characteristics[0]
|
||||
|
||||
async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
|
||||
'''Reads SIRK and decrypts if encrypted.'''
|
||||
response = await self.set_identity_resolving_key.read_value()
|
||||
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
|
||||
raise RuntimeError('Invalid SIRK value')
|
||||
|
||||
sirk_type = SirkType(response[0])
|
||||
if sirk_type == SirkType.PLAINTEXT:
|
||||
sirk = response[1:]
|
||||
else:
|
||||
connection = self.service_proxy.client.connection
|
||||
device = connection.device
|
||||
if connection.transport == core.BT_LE_TRANSPORT:
|
||||
key = await device.get_long_term_key(
|
||||
connection_handle=connection.handle, rand=b'', ediv=0
|
||||
)
|
||||
else:
|
||||
key = await device.get_link_key(connection.peer_address)
|
||||
|
||||
if not key:
|
||||
raise RuntimeError('LTK or LinkKey is not present')
|
||||
|
||||
sirk = sef(key, response[1:])
|
||||
|
||||
return (sirk_type, sirk)
|
||||
|
||||
@@ -118,8 +118,8 @@ CRC_TABLE = bytes([
|
||||
0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
|
||||
])
|
||||
|
||||
RFCOMM_DEFAULT_INITIAL_RX_CREDITS = 7
|
||||
RFCOMM_DEFAULT_PREFERRED_MTU = 1280
|
||||
RFCOMM_DEFAULT_WINDOW_SIZE = 16
|
||||
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
|
||||
|
||||
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
|
||||
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
|
||||
@@ -438,20 +438,24 @@ class DLC(EventEmitter):
|
||||
multiplexer: Multiplexer,
|
||||
dlci: int,
|
||||
max_frame_size: int,
|
||||
initial_tx_credits: int,
|
||||
window_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.multiplexer = multiplexer
|
||||
self.dlci = dlci
|
||||
self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
|
||||
self.rx_threshold = self.rx_credits // 2
|
||||
self.tx_credits = initial_tx_credits
|
||||
self.max_frame_size = max_frame_size
|
||||
self.window_size = window_size
|
||||
self.rx_credits = window_size
|
||||
self.rx_threshold = window_size // 2
|
||||
self.tx_credits = window_size
|
||||
self.tx_buffer = b''
|
||||
self.state = DLC.State.INIT
|
||||
self.role = multiplexer.role
|
||||
self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
|
||||
self.sink = None
|
||||
self.connection_result = None
|
||||
self.drained = asyncio.Event()
|
||||
self.drained.set()
|
||||
|
||||
# Compute the MTU
|
||||
max_overhead = 4 + 1 # header with 2-byte length + fcs
|
||||
@@ -537,11 +541,11 @@ class DLC(EventEmitter):
|
||||
if len(data) and self.sink:
|
||||
self.sink(data) # pylint: disable=not-callable
|
||||
|
||||
# Update the credits
|
||||
if self.rx_credits > 0:
|
||||
self.rx_credits -= 1
|
||||
else:
|
||||
logger.warning(color('!!! received frame with no rx credits', 'red'))
|
||||
# Update the credits
|
||||
if self.rx_credits > 0:
|
||||
self.rx_credits -= 1
|
||||
else:
|
||||
logger.warning(color('!!! received frame with no rx credits', 'red'))
|
||||
|
||||
# Check if there's anything to send (including credits)
|
||||
self.process_tx()
|
||||
@@ -580,9 +584,9 @@ class DLC(EventEmitter):
|
||||
cl=0xE0,
|
||||
priority=7,
|
||||
ack_timer=0,
|
||||
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
|
||||
max_frame_size=self.max_frame_size,
|
||||
max_retransmissions=0,
|
||||
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
|
||||
window_size=self.window_size,
|
||||
)
|
||||
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
|
||||
logger.debug(f'>>> PN Response: {pn}')
|
||||
@@ -591,7 +595,7 @@ class DLC(EventEmitter):
|
||||
|
||||
def rx_credits_needed(self) -> int:
|
||||
if self.rx_credits <= self.rx_threshold:
|
||||
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
|
||||
return self.window_size - self.rx_credits
|
||||
|
||||
return 0
|
||||
|
||||
@@ -631,6 +635,8 @@ class DLC(EventEmitter):
|
||||
)
|
||||
|
||||
rx_credits_needed = 0
|
||||
if not self.tx_buffer:
|
||||
self.drained.set()
|
||||
|
||||
# Stream protocol
|
||||
def write(self, data: Union[bytes, str]) -> None:
|
||||
@@ -643,11 +649,11 @@ class DLC(EventEmitter):
|
||||
raise ValueError('write only accept bytes or strings')
|
||||
|
||||
self.tx_buffer += data
|
||||
self.drained.clear()
|
||||
self.process_tx()
|
||||
|
||||
def drain(self) -> None:
|
||||
# TODO
|
||||
pass
|
||||
async def drain(self) -> None:
|
||||
await self.drained.wait()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'DLC(dlci={self.dlci},state={self.state.name})'
|
||||
@@ -843,7 +849,12 @@ class Multiplexer(EventEmitter):
|
||||
)
|
||||
await self.disconnection_result
|
||||
|
||||
async def open_dlc(self, channel: int) -> DLC:
|
||||
async def open_dlc(
|
||||
self,
|
||||
channel: int,
|
||||
max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
|
||||
window_size: int = RFCOMM_DEFAULT_WINDOW_SIZE,
|
||||
) -> DLC:
|
||||
if self.state != Multiplexer.State.CONNECTED:
|
||||
if self.state == Multiplexer.State.OPENING:
|
||||
raise InvalidStateError('open already in progress')
|
||||
@@ -855,9 +866,9 @@ class Multiplexer(EventEmitter):
|
||||
cl=0xF0,
|
||||
priority=7,
|
||||
ack_timer=0,
|
||||
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
|
||||
max_frame_size=max_frame_size,
|
||||
max_retransmissions=0,
|
||||
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
|
||||
window_size=window_size,
|
||||
)
|
||||
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
|
||||
logger.debug(f'>>> Sending MCC: {pn}')
|
||||
|
||||
@@ -1090,7 +1090,7 @@ class Session:
|
||||
# We can now encrypt the connection with the short term key, so that we can
|
||||
# distribute the long term and/or other keys over an encrypted connection
|
||||
self.manager.device.host.send_command_sync(
|
||||
HCI_LE_Enable_Encryption_Command( # type: ignore[call-arg]
|
||||
HCI_LE_Enable_Encryption_Command(
|
||||
connection_handle=self.connection.handle,
|
||||
random_number=bytes(8),
|
||||
encrypted_diversifier=0,
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from .common import Transport, AsyncPipeSink, SnoopingTransport
|
||||
from ..snoop import create_snooper
|
||||
@@ -52,8 +53,16 @@ def _wrap_transport(transport: Transport) -> Transport:
|
||||
async def open_transport(name: str) -> Transport:
|
||||
"""
|
||||
Open a transport by name.
|
||||
The name must be <type>:<parameters>
|
||||
Where <parameters> depend on the type (and may be empty for some types).
|
||||
The name must be <type>:<metadata><parameters>
|
||||
Where <parameters> depend on the type (and may be empty for some types), and
|
||||
<metadata> is either omitted, or a ,-separated list of <key>=<value> pairs,
|
||||
enclosed in [].
|
||||
If there are not metadata or parameter, the : after the <type> may be omitted.
|
||||
Examples:
|
||||
* usb:0
|
||||
* usb:[driver=rtk]0
|
||||
* android-netsim
|
||||
|
||||
The supported types are:
|
||||
* serial
|
||||
* udp
|
||||
@@ -71,87 +80,106 @@ async def open_transport(name: str) -> Transport:
|
||||
* android-netsim
|
||||
"""
|
||||
|
||||
return _wrap_transport(await _open_transport(name))
|
||||
scheme, *tail = name.split(':', 1)
|
||||
spec = tail[0] if tail else None
|
||||
if spec:
|
||||
# Metadata may precede the spec
|
||||
if spec.startswith('['):
|
||||
metadata_str, *tail = spec[1:].split(']')
|
||||
spec = tail[0] if tail else None
|
||||
metadata = dict([entry.split('=') for entry in metadata_str.split(',')])
|
||||
else:
|
||||
metadata = None
|
||||
|
||||
transport = await _open_transport(scheme, spec)
|
||||
if metadata:
|
||||
transport.source.metadata = { # type: ignore[attr-defined]
|
||||
**metadata,
|
||||
**getattr(transport.source, 'metadata', {}),
|
||||
}
|
||||
# pylint: disable=line-too-long
|
||||
logger.debug(f'HCI metadata: {transport.source.metadata}') # type: ignore[attr-defined]
|
||||
|
||||
return _wrap_transport(transport)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def _open_transport(name: str) -> Transport:
|
||||
async def _open_transport(scheme: str, spec: Optional[str]) -> Transport:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# pylint: disable=too-many-return-statements
|
||||
|
||||
scheme, *spec = name.split(':', 1)
|
||||
if scheme == 'serial' and spec:
|
||||
from .serial import open_serial_transport
|
||||
|
||||
return await open_serial_transport(spec[0])
|
||||
return await open_serial_transport(spec)
|
||||
|
||||
if scheme == 'udp' and spec:
|
||||
from .udp import open_udp_transport
|
||||
|
||||
return await open_udp_transport(spec[0])
|
||||
return await open_udp_transport(spec)
|
||||
|
||||
if scheme == 'tcp-client' and spec:
|
||||
from .tcp_client import open_tcp_client_transport
|
||||
|
||||
return await open_tcp_client_transport(spec[0])
|
||||
return await open_tcp_client_transport(spec)
|
||||
|
||||
if scheme == 'tcp-server' and spec:
|
||||
from .tcp_server import open_tcp_server_transport
|
||||
|
||||
return await open_tcp_server_transport(spec[0])
|
||||
return await open_tcp_server_transport(spec)
|
||||
|
||||
if scheme == 'ws-client' and spec:
|
||||
from .ws_client import open_ws_client_transport
|
||||
|
||||
return await open_ws_client_transport(spec[0])
|
||||
return await open_ws_client_transport(spec)
|
||||
|
||||
if scheme == 'ws-server' and spec:
|
||||
from .ws_server import open_ws_server_transport
|
||||
|
||||
return await open_ws_server_transport(spec[0])
|
||||
return await open_ws_server_transport(spec)
|
||||
|
||||
if scheme == 'pty':
|
||||
from .pty import open_pty_transport
|
||||
|
||||
return await open_pty_transport(spec[0] if spec else None)
|
||||
return await open_pty_transport(spec)
|
||||
|
||||
if scheme == 'file':
|
||||
from .file import open_file_transport
|
||||
|
||||
assert spec is not None
|
||||
return await open_file_transport(spec[0])
|
||||
return await open_file_transport(spec)
|
||||
|
||||
if scheme == 'vhci':
|
||||
from .vhci import open_vhci_transport
|
||||
|
||||
return await open_vhci_transport(spec[0] if spec else None)
|
||||
return await open_vhci_transport(spec)
|
||||
|
||||
if scheme == 'hci-socket':
|
||||
from .hci_socket import open_hci_socket_transport
|
||||
|
||||
return await open_hci_socket_transport(spec[0] if spec else None)
|
||||
return await open_hci_socket_transport(spec)
|
||||
|
||||
if scheme == 'usb':
|
||||
from .usb import open_usb_transport
|
||||
|
||||
assert spec is not None
|
||||
return await open_usb_transport(spec[0])
|
||||
assert spec
|
||||
return await open_usb_transport(spec)
|
||||
|
||||
if scheme == 'pyusb':
|
||||
from .pyusb import open_pyusb_transport
|
||||
|
||||
assert spec is not None
|
||||
return await open_pyusb_transport(spec[0])
|
||||
assert spec
|
||||
return await open_pyusb_transport(spec)
|
||||
|
||||
if scheme == 'android-emulator':
|
||||
from .android_emulator import open_android_emulator_transport
|
||||
|
||||
return await open_android_emulator_transport(spec[0] if spec else None)
|
||||
return await open_android_emulator_transport(spec)
|
||||
|
||||
if scheme == 'android-netsim':
|
||||
from .android_netsim import open_android_netsim_transport
|
||||
|
||||
return await open_android_netsim_transport(spec[0] if spec else None)
|
||||
return await open_android_netsim_transport(spec)
|
||||
|
||||
raise ValueError('unknown transport scheme')
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ async def open_android_emulator_transport(spec: Optional[str]) -> Transport:
|
||||
mode = 'host'
|
||||
server_host = 'localhost'
|
||||
server_port = '8554'
|
||||
if spec is not None:
|
||||
if spec:
|
||||
params = spec.split(',')
|
||||
for param in params:
|
||||
if param.startswith('mode='):
|
||||
|
||||
@@ -21,7 +21,7 @@ import struct
|
||||
import asyncio
|
||||
import logging
|
||||
import io
|
||||
from typing import ContextManager, Tuple, Optional, Protocol, Dict
|
||||
from typing import Any, ContextManager, Tuple, Optional, Protocol, Dict
|
||||
|
||||
from bumble import hci
|
||||
from bumble.colors import color
|
||||
@@ -42,6 +42,7 @@ HCI_PACKET_INFO: Dict[int, Tuple[int, int, str]] = {
|
||||
hci.HCI_ACL_DATA_PACKET: (2, 2, 'H'),
|
||||
hci.HCI_SYNCHRONOUS_DATA_PACKET: (1, 2, 'B'),
|
||||
hci.HCI_EVENT_PACKET: (1, 1, 'B'),
|
||||
hci.HCI_ISO_DATA_PACKET: (2, 2, 'H'),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -59,10 +59,7 @@ async def open_hci_socket_transport(spec: Optional[str]) -> Transport:
|
||||
) from error
|
||||
|
||||
# Compute the adapter index
|
||||
if spec is None:
|
||||
adapter_index = 0
|
||||
else:
|
||||
adapter_index = int(spec)
|
||||
adapter_index = int(spec) if spec else 0
|
||||
|
||||
# Bind the socket
|
||||
# NOTE: since Python doesn't support binding with the required address format (yet),
|
||||
|
||||
@@ -108,7 +108,7 @@ async def open_usb_transport(spec: str) -> Transport:
|
||||
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
|
||||
)
|
||||
|
||||
READ_SIZE = 1024
|
||||
READ_SIZE = 4096
|
||||
|
||||
class UsbPacketSink:
|
||||
def __init__(self, device, acl_out):
|
||||
|
||||
@@ -280,17 +280,14 @@ class AsyncRunner:
|
||||
def wrapper(*args, **kwargs):
|
||||
coroutine = func(*args, **kwargs)
|
||||
if queue is None:
|
||||
# Create a task to run the coroutine
|
||||
# Spawn the coroutine as a task
|
||||
async def run():
|
||||
try:
|
||||
await coroutine
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f'{color("!!! Exception in wrapper:", "red")} '
|
||||
f'{traceback.format_exc()}'
|
||||
)
|
||||
logger.exception(color("!!! Exception in wrapper:", "red"))
|
||||
|
||||
asyncio.create_task(run())
|
||||
AsyncRunner.spawn(run())
|
||||
else:
|
||||
# Queue the coroutine to be awaited by the work queue
|
||||
queue.enqueue(coroutine)
|
||||
|
||||
@@ -7,16 +7,36 @@ throughput and/or latency between two devices.
|
||||
# General Usage
|
||||
|
||||
```
|
||||
Usage: bench.py [OPTIONS] COMMAND [ARGS]...
|
||||
Usage: bumble-bench [OPTIONS] COMMAND [ARGS]...
|
||||
|
||||
Options:
|
||||
--device-config FILENAME Device configuration file
|
||||
--role [sender|receiver|ping|pong]
|
||||
--mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server]
|
||||
--att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517]
|
||||
-s, --packet-size SIZE Packet size (server role) [8<=x<=4096]
|
||||
-c, --packet-count COUNT Packet count (server role)
|
||||
-sd, --start-delay SECONDS Start delay (server role)
|
||||
--extended-data-length TEXT Request a data length upon connection,
|
||||
specified as tx_octets/tx_time
|
||||
--rfcomm-channel INTEGER RFComm channel to use
|
||||
--rfcomm-uuid TEXT RFComm service UUID to use (ignored if
|
||||
--rfcomm-channel is not 0)
|
||||
--l2cap-psm INTEGER L2CAP PSM to use
|
||||
--l2cap-mtu INTEGER L2CAP MTU to use
|
||||
--l2cap-mps INTEGER L2CAP MPS to use
|
||||
--l2cap-max-credits INTEGER L2CAP maximum number of credits allowed for
|
||||
the peer
|
||||
-s, --packet-size SIZE Packet size (client or ping role)
|
||||
[8<=x<=4096]
|
||||
-c, --packet-count COUNT Packet count (client or ping role)
|
||||
-sd, --start-delay SECONDS Start delay (client or ping role)
|
||||
--repeat N Repeat the run N times (client and ping
|
||||
roles)(0, which is the fault, to run just
|
||||
once)
|
||||
--repeat-delay SECONDS Delay, in seconds, between repeats
|
||||
--pace MILLISECONDS Wait N milliseconds between packets (0,
|
||||
which is the fault, to send as fast as
|
||||
possible)
|
||||
--linger Don't exit at the end of a run (server and
|
||||
pong roles)
|
||||
--help Show this message and exit.
|
||||
|
||||
Commands:
|
||||
@@ -35,17 +55,18 @@ Options:
|
||||
--connection-interval, --ci CONNECTION_INTERVAL
|
||||
Connection interval (in ms)
|
||||
--phy [1m|2m|coded] PHY to use
|
||||
--authenticate Authenticate (RFComm only)
|
||||
--encrypt Encrypt the connection (RFComm only)
|
||||
--help Show this message and exit.
|
||||
```
|
||||
|
||||
|
||||
To test once device against another, one of the two devices must be running
|
||||
To test once device against another, one of the two devices must be running
|
||||
the ``peripheral`` command and the other the ``central`` command. The device
|
||||
running the ``peripheral`` command will accept connections from the device
|
||||
running the ``central`` command.
|
||||
When using Bluetooth LE (all modes except for ``rfcomm-server`` and ``rfcomm-client``utils),
|
||||
the default addresses configured in the tool should be sufficient. But when using
|
||||
Bluetooth Classic, the address of the Peripheral must be specified on the Central
|
||||
the default addresses configured in the tool should be sufficient. But when using
|
||||
Bluetooth Classic, the address of the Peripheral must be specified on the Central
|
||||
using the ``--peripheral`` option. The address will be printed by the Peripheral when
|
||||
it starts.
|
||||
|
||||
@@ -83,7 +104,7 @@ the other on `usb:1`, and two consoles/terminals. We will run a command in each.
|
||||
$ bumble-bench central usb:1
|
||||
```
|
||||
|
||||
In this default configuration, the Central runs a Sender, as a GATT client,
|
||||
In this default configuration, the Central runs a Sender, as a GATT client,
|
||||
connecting to the Peripheral running a Receiver, as a GATT server.
|
||||
|
||||
!!! example "L2CAP Throughput"
|
||||
|
||||
@@ -5,6 +5,15 @@ Some Bluetooth controllers require a driver to function properly.
|
||||
This may include, for instance, loading a Firmware image or patch,
|
||||
loading a configuration.
|
||||
|
||||
By default, drivers will be automatically probed to determine if they should be
|
||||
used with particular HCI controller.
|
||||
When the transport for an HCI controller is instantiated from a transport name,
|
||||
a driver may also be forced by specifying ``driver=<driver-name>`` in the optional
|
||||
metadata portion of the transport name. For example,
|
||||
``usb:[driver=-rtk]0`` indicates that the ``rtk`` driver should be used with the
|
||||
first USB device, even if a normal probe would not have selected it based on the
|
||||
USB vendor ID and product ID.
|
||||
|
||||
Drivers included in the module are:
|
||||
|
||||
* [Realtek](realtek.md): Loading of Firmware and Config for Realtek USB dongles.
|
||||
@@ -1,13 +1,16 @@
|
||||
REALTEK DRIVER
|
||||
==============
|
||||
|
||||
This driver supports loading firmware images and optional config data to
|
||||
This driver supports loading firmware images and optional config data to
|
||||
USB dongles with a Realtek chipset.
|
||||
A number of USB dongles are supported, but likely not all.
|
||||
When using a USB dongle, the USB product ID and manufacturer ID are used
|
||||
When using a USB dongle, the USB product ID and vendor ID are used
|
||||
to find whether a matching set of firmware image and config data
|
||||
is needed for that specific model. If a match exists, the driver will try
|
||||
load the firmware image and, if needed, config data.
|
||||
Alternatively, the metadata property ``driver=rtk`` may be specified in a transport
|
||||
name to force that driver to be used (ex: ``usb:[driver=rtk]0`` instead of just
|
||||
``usb:0`` for the first USB device).
|
||||
The driver will look for those files by name, in order, in:
|
||||
|
||||
* The directory specified by the environment variable `BUMBLE_RTK_FIRMWARE_DIR`
|
||||
|
||||
5
examples/hid_keyboard.json
Normal file
5
examples/hid_keyboard.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"name": "Bumble HID Keyboard",
|
||||
"class_of_device": 9664,
|
||||
"keystore": "JsonKeyStore"
|
||||
}
|
||||
@@ -40,9 +40,9 @@
|
||||
}
|
||||
}
|
||||
function onMouseMove(event) {
|
||||
//console.log(event.clientX, event.clientY)
|
||||
mouseInfo.innerText = `MOUSE: x=${event.clientX}, y=${event.clientY}`
|
||||
send({ type:'mousemove', x: event.clientX, y: event.clientY })
|
||||
//console.log(event.movementX, event.movementY)
|
||||
mouseInfo.innerText = `MOUSE: x=${event.movementX}, y=${event.movementY}`
|
||||
send({ type:'mousemove', x: event.movementX, y: event.movementY })
|
||||
}
|
||||
|
||||
function onKeyDown(event) {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
{
|
||||
"name": "Bumble-LEA",
|
||||
"keystore": "JsonKeyStore",
|
||||
"address": "F0:F1:F2:F3:F4:FA",
|
||||
"advertising_interval": 100
|
||||
}
|
||||
|
||||
116
examples/run_csis_servers.py
Normal file
116
examples/run_csis_servers.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
import secrets
|
||||
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.device import Device
|
||||
from bumble.hci import (
|
||||
Address,
|
||||
OwnAddressType,
|
||||
HCI_LE_Set_Extended_Advertising_Parameters_Command,
|
||||
)
|
||||
from bumble.profiles.cap import CommonAudioServiceService
|
||||
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
|
||||
|
||||
from bumble.transport import open_transport_or_link
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def main() -> None:
|
||||
if len(sys.argv) < 3:
|
||||
print(
|
||||
'Usage: run_cig_setup.py <config-file>'
|
||||
'<transport-spec-for-device-1> <transport-spec-for-device-2>'
|
||||
)
|
||||
print(
|
||||
'example: run_cig_setup.py device1.json'
|
||||
'tcp-client:127.0.0.1:6402 tcp-client:127.0.0.1:6402'
|
||||
)
|
||||
return
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
hci_transports = await asyncio.gather(
|
||||
open_transport_or_link(sys.argv[2]), open_transport_or_link(sys.argv[3])
|
||||
)
|
||||
print('<<< connected')
|
||||
|
||||
devices = [
|
||||
Device.from_config_file_with_hci(
|
||||
sys.argv[1], hci_transport.source, hci_transport.sink
|
||||
)
|
||||
for hci_transport in hci_transports
|
||||
]
|
||||
|
||||
sirk = secrets.token_bytes(16)
|
||||
|
||||
for i, device in enumerate(devices):
|
||||
device.random_address = Address(secrets.token_bytes(6))
|
||||
await device.power_on()
|
||||
csis = CoordinatedSetIdentificationService(
|
||||
set_identity_resolving_key=sirk,
|
||||
set_identity_resolving_key_type=SirkType.PLAINTEXT,
|
||||
coordinated_set_size=2,
|
||||
)
|
||||
device.add_service(CommonAudioServiceService(csis))
|
||||
advertising_data = (
|
||||
bytes(
|
||||
AdvertisingData(
|
||||
[
|
||||
(
|
||||
AdvertisingData.COMPLETE_LOCAL_NAME,
|
||||
bytes(f'Bumble LE Audio-{i}', 'utf-8'),
|
||||
),
|
||||
(
|
||||
AdvertisingData.FLAGS,
|
||||
bytes(
|
||||
[
|
||||
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
|
||||
| AdvertisingData.BR_EDR_HOST_FLAG
|
||||
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
|
||||
]
|
||||
),
|
||||
),
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
bytes(CoordinatedSetIdentificationService.UUID),
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
+ csis.get_advertising_data()
|
||||
)
|
||||
await device.start_extended_advertising(
|
||||
advertising_properties=(
|
||||
HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING
|
||||
),
|
||||
own_address_type=OwnAddressType.RANDOM,
|
||||
advertising_data=advertising_data,
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
*[hci_transport.source.terminated for hci_transport in hci_transports]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
|
||||
asyncio.run(main())
|
||||
@@ -73,7 +73,6 @@ async def main() -> None:
|
||||
HCI_Enhanced_Setup_Synchronous_Connection_Command(
|
||||
connection_handle=connections[0].handle,
|
||||
**ESCO_PARAMETERS[DefaultCodecParameters.ESCO_CVSD_S3].asdict(),
|
||||
# type: ignore[call-args]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
748
examples/run_hid_device.py
Normal file
748
examples/run_hid_device.py
Normal file
@@ -0,0 +1,748 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
import json
|
||||
import websockets
|
||||
from bumble.colors import color
|
||||
|
||||
from bumble.device import Device
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_L2CAP_PROTOCOL_ID,
|
||||
BT_HUMAN_INTERFACE_DEVICE_SERVICE,
|
||||
BT_HIDP_PROTOCOL_ID,
|
||||
UUID,
|
||||
)
|
||||
from bumble.hci import Address
|
||||
from bumble.hid import (
|
||||
Device as HID_Device,
|
||||
HID_CONTROL_PSM,
|
||||
HID_INTERRUPT_PSM,
|
||||
Message,
|
||||
)
|
||||
from bumble.sdp import (
|
||||
Client as SDP_Client,
|
||||
DataElement,
|
||||
ServiceAttribute,
|
||||
SDP_PUBLIC_BROWSE_ROOT,
|
||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
SDP_ALL_ATTRIBUTES_RANGE,
|
||||
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
|
||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||
)
|
||||
from bumble.utils import AsyncRunner
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# SDP attributes for Bluetooth HID devices
|
||||
SDP_HID_SERVICE_NAME_ATTRIBUTE_ID = 0x0100
|
||||
SDP_HID_SERVICE_DESCRIPTION_ATTRIBUTE_ID = 0x0101
|
||||
SDP_HID_PROVIDER_NAME_ATTRIBUTE_ID = 0x0102
|
||||
SDP_HID_DEVICE_RELEASE_NUMBER_ATTRIBUTE_ID = 0x0200 # [DEPRECATED]
|
||||
SDP_HID_PARSER_VERSION_ATTRIBUTE_ID = 0x0201
|
||||
SDP_HID_DEVICE_SUBCLASS_ATTRIBUTE_ID = 0x0202
|
||||
SDP_HID_COUNTRY_CODE_ATTRIBUTE_ID = 0x0203
|
||||
SDP_HID_VIRTUAL_CABLE_ATTRIBUTE_ID = 0x0204
|
||||
SDP_HID_RECONNECT_INITIATE_ATTRIBUTE_ID = 0x0205
|
||||
SDP_HID_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0x0206
|
||||
SDP_HID_LANGID_BASE_LIST_ATTRIBUTE_ID = 0x0207
|
||||
SDP_HID_SDP_DISABLE_ATTRIBUTE_ID = 0x0208 # [DEPRECATED]
|
||||
SDP_HID_BATTERY_POWER_ATTRIBUTE_ID = 0x0209
|
||||
SDP_HID_REMOTE_WAKE_ATTRIBUTE_ID = 0x020A
|
||||
SDP_HID_PROFILE_VERSION_ATTRIBUTE_ID = 0x020B # DEPRECATED]
|
||||
SDP_HID_SUPERVISION_TIMEOUT_ATTRIBUTE_ID = 0x020C
|
||||
SDP_HID_NORMALLY_CONNECTABLE_ATTRIBUTE_ID = 0x020D
|
||||
SDP_HID_BOOT_DEVICE_ATTRIBUTE_ID = 0x020E
|
||||
SDP_HID_SSR_HOST_MAX_LATENCY_ATTRIBUTE_ID = 0x020F
|
||||
SDP_HID_SSR_HOST_MIN_TIMEOUT_ATTRIBUTE_ID = 0x0210
|
||||
|
||||
# Refer to HID profile specification v1.1.1, "5.3 Service Discovery Protocol (SDP)" for details
|
||||
# HID SDP attribute values
|
||||
LANGUAGE = 0x656E # 0x656E uint16 “en” (English)
|
||||
ENCODING = 0x6A # 0x006A uint16 UTF-8 encoding
|
||||
PRIMARY_LANGUAGE_BASE_ID = 0x100 # 0x0100 uint16 PrimaryLanguageBaseID
|
||||
VERSION_NUMBER = 0x0101 # 0x0101 uint16 version number (v1.1)
|
||||
SERVICE_NAME = b'Bumble HID'
|
||||
SERVICE_DESCRIPTION = b'Bumble'
|
||||
PROVIDER_NAME = b'Bumble'
|
||||
HID_PARSER_VERSION = 0x0111 # uint16 0x0111 (v1.1.1)
|
||||
HID_DEVICE_SUBCLASS = 0xC0 # Combo keyboard/pointing device
|
||||
HID_COUNTRY_CODE = 0x21 # 0x21 Uint8, USA
|
||||
HID_VIRTUAL_CABLE = True # Virtual cable enabled
|
||||
HID_RECONNECT_INITIATE = True # Reconnect initiate enabled
|
||||
REPORT_DESCRIPTOR_TYPE = 0x22 # 0x22 Type = Report Descriptor
|
||||
HID_LANGID_BASE_LANGUAGE = 0x0409 # 0x0409 Language = English (United States)
|
||||
HID_LANGID_BASE_BLUETOOTH_STRING_OFFSET = 0x100 # 0x0100 Default
|
||||
HID_BATTERY_POWER = True # Battery power enabled
|
||||
HID_REMOTE_WAKE = True # Remote wake enabled
|
||||
HID_SUPERVISION_TIMEOUT = 0xC80 # uint16 0xC80 (2s)
|
||||
HID_NORMALLY_CONNECTABLE = True # Normally connectable enabled
|
||||
HID_BOOT_DEVICE = True # Boot device support enabled
|
||||
HID_SSR_HOST_MAX_LATENCY = 0x640 # uint16 0x640 (1s)
|
||||
HID_SSR_HOST_MIN_TIMEOUT = 0xC80 # uint16 0xC80 (2s)
|
||||
HID_REPORT_MAP = bytes( # Text String, 50 Octet Report Descriptor
|
||||
# pylint: disable=line-too-long
|
||||
[
|
||||
0x05,
|
||||
0x01, # Usage Page (Generic Desktop Ctrls)
|
||||
0x09,
|
||||
0x06, # Usage (Keyboard)
|
||||
0xA1,
|
||||
0x01, # Collection (Application)
|
||||
0x85,
|
||||
0x01, # . Report ID (1)
|
||||
0x05,
|
||||
0x07, # . Usage Page (Kbrd/Keypad)
|
||||
0x19,
|
||||
0xE0, # . Usage Minimum (0xE0)
|
||||
0x29,
|
||||
0xE7, # . Usage Maximum (0xE7)
|
||||
0x15,
|
||||
0x00, # . Logical Minimum (0)
|
||||
0x25,
|
||||
0x01, # . Logical Maximum (1)
|
||||
0x75,
|
||||
0x01, # . Report Size (1)
|
||||
0x95,
|
||||
0x08, # . Report Count (8)
|
||||
0x81,
|
||||
0x02, # . Input (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
|
||||
0x95,
|
||||
0x01, # . Report Count (1)
|
||||
0x75,
|
||||
0x08, # . Report Size (8)
|
||||
0x81,
|
||||
0x03, # . Input (Const,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
|
||||
0x95,
|
||||
0x05, # . Report Count (5)
|
||||
0x75,
|
||||
0x01, # . Report Size (1)
|
||||
0x05,
|
||||
0x08, # . Usage Page (LEDs)
|
||||
0x19,
|
||||
0x01, # . Usage Minimum (Num Lock)
|
||||
0x29,
|
||||
0x05, # . Usage Maximum (Kana)
|
||||
0x91,
|
||||
0x02, # . Output (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
|
||||
0x95,
|
||||
0x01, # . Report Count (1)
|
||||
0x75,
|
||||
0x03, # . Report Size (3)
|
||||
0x91,
|
||||
0x03, # . Output (Const,Var,Abs,No Wrap,Linear,Preferred State,No Null Position,Non-volatile)
|
||||
0x95,
|
||||
0x06, # . Report Count (6)
|
||||
0x75,
|
||||
0x08, # . Report Size (8)
|
||||
0x15,
|
||||
0x00, # . Logical Minimum (0)
|
||||
0x25,
|
||||
0x65, # . Logical Maximum (101)
|
||||
0x05,
|
||||
0x07, # . Usage Page (Kbrd/Keypad)
|
||||
0x19,
|
||||
0x00, # . Usage Minimum (0x00)
|
||||
0x29,
|
||||
0x65, # . Usage Maximum (0x65)
|
||||
0x81,
|
||||
0x00, # . Input (Data,Array,Abs,No Wrap,Linear,Preferred State,No Null Position)
|
||||
0xC0, # End Collection
|
||||
0x05,
|
||||
0x01, # Usage Page (Generic Desktop Ctrls)
|
||||
0x09,
|
||||
0x02, # Usage (Mouse)
|
||||
0xA1,
|
||||
0x01, # Collection (Application)
|
||||
0x85,
|
||||
0x02, # . Report ID (2)
|
||||
0x09,
|
||||
0x01, # . Usage (Pointer)
|
||||
0xA1,
|
||||
0x00, # . Collection (Physical)
|
||||
0x05,
|
||||
0x09, # . Usage Page (Button)
|
||||
0x19,
|
||||
0x01, # . Usage Minimum (0x01)
|
||||
0x29,
|
||||
0x03, # . Usage Maximum (0x03)
|
||||
0x15,
|
||||
0x00, # . Logical Minimum (0)
|
||||
0x25,
|
||||
0x01, # . Logical Maximum (1)
|
||||
0x95,
|
||||
0x03, # . Report Count (3)
|
||||
0x75,
|
||||
0x01, # . Report Size (1)
|
||||
0x81,
|
||||
0x02, # . Input (Data,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
|
||||
0x95,
|
||||
0x01, # . Report Count (1)
|
||||
0x75,
|
||||
0x05, # . Report Size (5)
|
||||
0x81,
|
||||
0x03, # . Input (Const,Var,Abs,No Wrap,Linear,Preferred State,No Null Position)
|
||||
0x05,
|
||||
0x01, # . Usage Page (Generic Desktop Ctrls)
|
||||
0x09,
|
||||
0x30, # . Usage (X)
|
||||
0x09,
|
||||
0x31, # . Usage (Y)
|
||||
0x15,
|
||||
0x81, # . Logical Minimum (-127)
|
||||
0x25,
|
||||
0x7F, # . Logical Maximum (127)
|
||||
0x75,
|
||||
0x08, # . Report Size (8)
|
||||
0x95,
|
||||
0x02, # . Report Count (2)
|
||||
0x81,
|
||||
0x06, # . Input (Data,Var,Rel,No Wrap,Linear,Preferred State,No Null Position)
|
||||
0xC0, # . End Collection
|
||||
0xC0, # End Collection
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# Default protocol mode set to report protocol
|
||||
protocol_mode = Message.ProtocolMode.REPORT_PROTOCOL
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def sdp_records():
|
||||
service_record_handle = 0x00010002
|
||||
return {
|
||||
service_record_handle: [
|
||||
ServiceAttribute(
|
||||
SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_32(service_record_handle),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[DataElement.uuid(BT_HUMAN_INTERFACE_DEVICE_SERVICE)]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
|
||||
DataElement.unsigned_integer_16(HID_CONTROL_PSM),
|
||||
]
|
||||
),
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_HIDP_PROTOCOL_ID),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_LANGUAGE_BASE_ATTRIBUTE_ID_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.unsigned_integer_16(LANGUAGE),
|
||||
DataElement.unsigned_integer_16(ENCODING),
|
||||
DataElement.unsigned_integer_16(PRIMARY_LANGUAGE_BASE_ID),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_HUMAN_INTERFACE_DEVICE_SERVICE),
|
||||
DataElement.unsigned_integer_16(VERSION_NUMBER),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_L2CAP_PROTOCOL_ID),
|
||||
DataElement.unsigned_integer_16(
|
||||
HID_INTERRUPT_PSM
|
||||
),
|
||||
]
|
||||
),
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.uuid(BT_HIDP_PROTOCOL_ID),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_SERVICE_NAME_ATTRIBUTE_ID,
|
||||
DataElement(DataElement.TEXT_STRING, SERVICE_NAME),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_SERVICE_DESCRIPTION_ATTRIBUTE_ID,
|
||||
DataElement(DataElement.TEXT_STRING, SERVICE_DESCRIPTION),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_PROVIDER_NAME_ATTRIBUTE_ID,
|
||||
DataElement(DataElement.TEXT_STRING, PROVIDER_NAME),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_PARSER_VERSION_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_32(HID_PARSER_VERSION),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_DEVICE_SUBCLASS_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_32(HID_DEVICE_SUBCLASS),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_COUNTRY_CODE_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_32(HID_COUNTRY_CODE),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_VIRTUAL_CABLE_ATTRIBUTE_ID,
|
||||
DataElement.boolean(HID_VIRTUAL_CABLE),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_RECONNECT_INITIATE_ATTRIBUTE_ID,
|
||||
DataElement.boolean(HID_RECONNECT_INITIATE),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_DESCRIPTOR_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.unsigned_integer_16(REPORT_DESCRIPTOR_TYPE),
|
||||
DataElement(DataElement.TEXT_STRING, HID_REPORT_MAP),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_LANGID_BASE_LIST_ATTRIBUTE_ID,
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.sequence(
|
||||
[
|
||||
DataElement.unsigned_integer_16(
|
||||
HID_LANGID_BASE_LANGUAGE
|
||||
),
|
||||
DataElement.unsigned_integer_16(
|
||||
HID_LANGID_BASE_BLUETOOTH_STRING_OFFSET
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_BATTERY_POWER_ATTRIBUTE_ID,
|
||||
DataElement.boolean(HID_BATTERY_POWER),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_REMOTE_WAKE_ATTRIBUTE_ID,
|
||||
DataElement.boolean(HID_REMOTE_WAKE),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_SUPERVISION_TIMEOUT_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_16(HID_SUPERVISION_TIMEOUT),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_NORMALLY_CONNECTABLE_ATTRIBUTE_ID,
|
||||
DataElement.boolean(HID_NORMALLY_CONNECTABLE),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_BOOT_DEVICE_ATTRIBUTE_ID,
|
||||
DataElement.boolean(HID_BOOT_DEVICE),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_SSR_HOST_MAX_LATENCY_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_16(HID_SSR_HOST_MAX_LATENCY),
|
||||
),
|
||||
ServiceAttribute(
|
||||
SDP_HID_SSR_HOST_MIN_TIMEOUT_ATTRIBUTE_ID,
|
||||
DataElement.unsigned_integer_16(HID_SSR_HOST_MIN_TIMEOUT),
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def get_stream_reader(pipe) -> asyncio.StreamReader:
|
||||
loop = asyncio.get_event_loop()
|
||||
reader = asyncio.StreamReader(loop=loop)
|
||||
protocol = asyncio.StreamReaderProtocol(reader)
|
||||
await loop.connect_read_pipe(lambda: protocol, pipe)
|
||||
return reader
|
||||
|
||||
|
||||
class DeviceData:
|
||||
def __init__(self) -> None:
|
||||
self.keyboardData = bytearray(
|
||||
[0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
|
||||
)
|
||||
self.mouseData = bytearray([0x02, 0x00, 0x00, 0x00])
|
||||
|
||||
|
||||
# Device's live data - Mouse and Keyboard will be stored in this
|
||||
deviceData = DeviceData()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def keyboard_device(hid_device):
|
||||
|
||||
# Start a Websocket server to receive events from a web page
|
||||
async def serve(websocket, _path):
|
||||
global deviceData
|
||||
while True:
|
||||
try:
|
||||
message = await websocket.recv()
|
||||
print('Received: ', str(message))
|
||||
parsed = json.loads(message)
|
||||
message_type = parsed['type']
|
||||
if message_type == 'keydown':
|
||||
# Only deal with keys a to z for now
|
||||
key = parsed['key']
|
||||
if len(key) == 1:
|
||||
code = ord(key)
|
||||
if ord('a') <= code <= ord('z'):
|
||||
hid_code = 0x04 + code - ord('a')
|
||||
deviceData.keyboardData = bytearray(
|
||||
[
|
||||
0x01,
|
||||
0x00,
|
||||
0x00,
|
||||
hid_code,
|
||||
0x00,
|
||||
0x00,
|
||||
0x00,
|
||||
0x00,
|
||||
0x00,
|
||||
]
|
||||
)
|
||||
hid_device.send_data(deviceData.keyboardData)
|
||||
elif message_type == 'keyup':
|
||||
deviceData.keyboardData = bytearray(
|
||||
[0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
|
||||
)
|
||||
hid_device.send_data(deviceData.keyboardData)
|
||||
elif message_type == "mousemove":
|
||||
# logical min and max values
|
||||
log_min = -127
|
||||
log_max = 127
|
||||
x = parsed['x']
|
||||
y = parsed['y']
|
||||
# limiting x and y values within logical max and min range
|
||||
x = max(log_min, min(log_max, x))
|
||||
y = max(log_min, min(log_max, y))
|
||||
x_cord = x.to_bytes(signed=True)
|
||||
y_cord = y.to_bytes(signed=True)
|
||||
deviceData.mouseData = bytearray([0x02, 0x00]) + x_cord + y_cord
|
||||
hid_device.send_data(deviceData.mouseData)
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
pass
|
||||
|
||||
# pylint: disable-next=no-member
|
||||
await websockets.serve(serve, 'localhost', 8989)
|
||||
await asyncio.get_event_loop().create_future()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def main():
|
||||
if len(sys.argv) < 3:
|
||||
print(
|
||||
'Usage: python run_hid_device.py <device-config> <transport-spec> <command>'
|
||||
' where <command> is one of:\n'
|
||||
' test-mode (run with menu enabled for testing)\n'
|
||||
' web (run a keyboard with keypress input from a web page, '
|
||||
'see keyboard.html'
|
||||
)
|
||||
print('example: python run_hid_device.py hid_keyboard.json usb:0 web')
|
||||
print('example: python run_hid_device.py hid_keyboard.json usb:0 test-mode')
|
||||
|
||||
return
|
||||
|
||||
async def handle_virtual_cable_unplug():
|
||||
hid_host_bd_addr = str(hid_device.remote_device_bd_address)
|
||||
await hid_device.disconnect_interrupt_channel()
|
||||
await hid_device.disconnect_control_channel()
|
||||
await device.keystore.delete(hid_host_bd_addr) # type: ignore
|
||||
connection = hid_device.connection
|
||||
if connection is not None:
|
||||
await connection.disconnect()
|
||||
|
||||
def on_hid_data_cb(pdu: bytes):
|
||||
print(f'Received Data, PDU: {pdu.hex()}')
|
||||
|
||||
def on_get_report_cb(report_id: int, report_type: int, buffer_size: int):
|
||||
retValue = hid_device.GetSetStatus()
|
||||
print(
|
||||
"GET_REPORT report_id: "
|
||||
+ str(report_id)
|
||||
+ "report_type: "
|
||||
+ str(report_type)
|
||||
+ "buffer_size:"
|
||||
+ str(buffer_size)
|
||||
)
|
||||
if report_type == Message.ReportType.INPUT_REPORT:
|
||||
if report_id == 1:
|
||||
retValue.data = deviceData.keyboardData[1:]
|
||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
||||
elif report_id == 2:
|
||||
retValue.data = deviceData.mouseData[1:]
|
||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
||||
else:
|
||||
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
|
||||
|
||||
if buffer_size:
|
||||
data_len = buffer_size - 1
|
||||
retValue.data = retValue.data[:data_len]
|
||||
elif report_type == Message.ReportType.OUTPUT_REPORT:
|
||||
# This sample app has nothing to do with the report received, to enable PTS
|
||||
# testing, we will return single byte random data.
|
||||
retValue.data = bytearray([0x11])
|
||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
||||
elif report_type == Message.ReportType.FEATURE_REPORT:
|
||||
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
|
||||
elif report_type == Message.ReportType.OTHER_REPORT:
|
||||
if report_id == 3:
|
||||
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
|
||||
else:
|
||||
retValue.status = hid_device.GetSetReturn.FAILURE
|
||||
|
||||
return retValue
|
||||
|
||||
def on_set_report_cb(
|
||||
report_id: int, report_type: int, report_size: int, data: bytes
|
||||
):
|
||||
retValue = hid_device.GetSetStatus()
|
||||
print(
|
||||
"SET_REPORT report_id: "
|
||||
+ str(report_id)
|
||||
+ "report_type: "
|
||||
+ str(report_type)
|
||||
+ "report_size "
|
||||
+ str(report_size)
|
||||
+ "data:"
|
||||
+ str(data)
|
||||
)
|
||||
if report_type == Message.ReportType.FEATURE_REPORT:
|
||||
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
|
||||
elif report_type == Message.ReportType.INPUT_REPORT:
|
||||
if report_id == 1 and report_size != len(deviceData.keyboardData):
|
||||
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
|
||||
elif report_id == 2 and report_size != len(deviceData.mouseData):
|
||||
retValue.status = hid_device.GetSetReturn.ERR_INVALID_PARAMETER
|
||||
elif report_id == 3:
|
||||
retValue.status = hid_device.GetSetReturn.REPORT_ID_NOT_FOUND
|
||||
else:
|
||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
||||
else:
|
||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
||||
|
||||
return retValue
|
||||
|
||||
def on_get_protocol_cb():
|
||||
retValue = hid_device.GetSetStatus()
|
||||
retValue.data = protocol_mode.to_bytes()
|
||||
retValue.status = hid_device.GetSetReturn.SUCCESS
|
||||
return retValue
|
||||
|
||||
def on_set_protocol_cb(protocol: int):
|
||||
retValue = hid_device.GetSetStatus()
|
||||
# We do not support SET_PROTOCOL.
|
||||
print(f"SET_PROTOCOL report_id: {protocol}")
|
||||
retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
|
||||
return retValue
|
||||
|
||||
def on_virtual_cable_unplug_cb():
|
||||
print('Received Virtual Cable Unplug')
|
||||
asyncio.create_task(handle_virtual_cable_unplug())
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
|
||||
print('<<< connected')
|
||||
|
||||
# Create a device
|
||||
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
|
||||
device.classic_enabled = True
|
||||
|
||||
# Create and register HID device
|
||||
hid_device = HID_Device(device)
|
||||
|
||||
# Register for call backs
|
||||
hid_device.on('interrupt_data', on_hid_data_cb)
|
||||
|
||||
hid_device.register_get_report_cb(on_get_report_cb)
|
||||
hid_device.register_set_report_cb(on_set_report_cb)
|
||||
hid_device.register_get_protocol_cb(on_get_protocol_cb)
|
||||
hid_device.register_set_protocol_cb(on_set_protocol_cb)
|
||||
|
||||
# Register for virtual cable unplug call back
|
||||
hid_device.on('virtual_cable_unplug', on_virtual_cable_unplug_cb)
|
||||
|
||||
# Setup the SDP to advertise HID Device service
|
||||
device.sdp_service_records = sdp_records()
|
||||
|
||||
# Start the controller
|
||||
await device.power_on()
|
||||
|
||||
# Start being discoverable and connectable
|
||||
await device.set_discoverable(True)
|
||||
await device.set_connectable(True)
|
||||
|
||||
async def menu():
|
||||
reader = await get_stream_reader(sys.stdin)
|
||||
while True:
|
||||
print(
|
||||
"\n************************ HID Device Menu *****************************\n"
|
||||
)
|
||||
print(" 1. Connect Control Channel")
|
||||
print(" 2. Connect Interrupt Channel")
|
||||
print(" 3. Disconnect Control Channel")
|
||||
print(" 4. Disconnect Interrupt Channel")
|
||||
print(" 5. Send Report on Interrupt Channel")
|
||||
print(" 6. Virtual Cable Unplug")
|
||||
print(" 7. Disconnect device")
|
||||
print(" 8. Delete Bonding")
|
||||
print(" 9. Re-connect to device")
|
||||
print("10. Exit ")
|
||||
print("\nEnter your choice : \n")
|
||||
|
||||
choice = await reader.readline()
|
||||
choice = choice.decode('utf-8').strip()
|
||||
|
||||
if choice == '1':
|
||||
await hid_device.connect_control_channel()
|
||||
|
||||
elif choice == '2':
|
||||
await hid_device.connect_interrupt_channel()
|
||||
|
||||
elif choice == '3':
|
||||
await hid_device.disconnect_control_channel()
|
||||
|
||||
elif choice == '4':
|
||||
await hid_device.disconnect_interrupt_channel()
|
||||
|
||||
elif choice == '5':
|
||||
print(" 1. Report ID 0x01")
|
||||
print(" 2. Report ID 0x02")
|
||||
print(" 3. Invalid Report ID")
|
||||
|
||||
choice1 = await reader.readline()
|
||||
choice1 = choice1.decode('utf-8').strip()
|
||||
|
||||
if choice1 == '1':
|
||||
data = bytearray(
|
||||
[0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00]
|
||||
)
|
||||
hid_device.send_data(data)
|
||||
data = bytearray(
|
||||
[0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]
|
||||
)
|
||||
hid_device.send_data(data)
|
||||
|
||||
elif choice1 == '2':
|
||||
data = bytearray([0x02, 0x00, 0x00, 0xF6])
|
||||
hid_device.send_data(data)
|
||||
data = bytearray([0x02, 0x00, 0x00, 0x00])
|
||||
hid_device.send_data(data)
|
||||
|
||||
elif choice1 == '3':
|
||||
data = bytearray([0x00, 0x00, 0x00, 0x00])
|
||||
hid_device.send_data(data)
|
||||
data = bytearray([0x00, 0x00, 0x00, 0x00])
|
||||
hid_device.send_data(data)
|
||||
|
||||
else:
|
||||
print('Incorrect option selected')
|
||||
|
||||
elif choice == '6':
|
||||
hid_device.virtual_cable_unplug()
|
||||
try:
|
||||
hid_host_bd_addr = str(hid_device.remote_device_bd_address)
|
||||
await device.keystore.delete(hid_host_bd_addr)
|
||||
except KeyError:
|
||||
print('Device not found or Device already unpaired.')
|
||||
|
||||
elif choice == '7':
|
||||
connection = hid_device.connection
|
||||
if connection is not None:
|
||||
await connection.disconnect()
|
||||
else:
|
||||
print("Already disconnected from device")
|
||||
|
||||
elif choice == '8':
|
||||
try:
|
||||
hid_host_bd_addr = str(hid_device.remote_device_bd_address)
|
||||
await device.keystore.delete(hid_host_bd_addr)
|
||||
except KeyError:
|
||||
print('Device NOT found or Device already unpaired.')
|
||||
|
||||
elif choice == '9':
|
||||
hid_host_bd_addr = str(hid_device.remote_device_bd_address)
|
||||
connection = await device.connect(
|
||||
hid_host_bd_addr, transport=BT_BR_EDR_TRANSPORT
|
||||
)
|
||||
await connection.authenticate()
|
||||
await connection.encrypt()
|
||||
|
||||
elif choice == '10':
|
||||
sys.exit("Exit successful")
|
||||
|
||||
else:
|
||||
print("Invalid option selected.")
|
||||
|
||||
if (len(sys.argv) > 3) and (sys.argv[3] == 'test-mode'):
|
||||
# Test mode for PTS/Unit testing
|
||||
await menu()
|
||||
else:
|
||||
# default option is using keyboard.html (web)
|
||||
print("Executing in Web mode")
|
||||
await keyboard_device(hid_device)
|
||||
|
||||
await hci_source.wait_for_termination()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
|
||||
asyncio.run(main())
|
||||
@@ -285,7 +285,10 @@ async def main():
|
||||
print('example: run_hid_host.py classic1.json usb:0 E1:CA:72:48:C4:E8/P')
|
||||
return
|
||||
|
||||
def on_hid_data_cb(pdu):
|
||||
def on_hid_control_data_cb(pdu: bytes):
|
||||
print(f'Received Control Data, PDU: {pdu.hex()}')
|
||||
|
||||
def on_hid_interrupt_data_cb(pdu: bytes):
|
||||
report_type = pdu[0] & 0x0F
|
||||
if len(pdu) == 1:
|
||||
print(color(f'Warning: No report received', 'yellow'))
|
||||
@@ -305,7 +308,7 @@ async def main():
|
||||
|
||||
if (report_length <= 1) or (report_id == 0):
|
||||
return
|
||||
|
||||
# Parse report over interrupt channel
|
||||
if report_type == Message.ReportType.INPUT_REPORT:
|
||||
ReportParser.parse_input_report(pdu[1:]) # type: ignore
|
||||
|
||||
@@ -313,7 +316,9 @@ async def main():
|
||||
await hid_host.disconnect_interrupt_channel()
|
||||
await hid_host.disconnect_control_channel()
|
||||
await device.keystore.delete(target_address) # type: ignore
|
||||
await connection.disconnect()
|
||||
connection = hid_host.connection
|
||||
if connection is not None:
|
||||
await connection.disconnect()
|
||||
|
||||
def on_hid_virtual_cable_unplug_cb():
|
||||
asyncio.create_task(handle_virtual_cable_unplug())
|
||||
@@ -325,6 +330,18 @@ async def main():
|
||||
# Create a device
|
||||
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
|
||||
device.classic_enabled = True
|
||||
|
||||
# Create HID host and start it
|
||||
print('@@@ Starting HID Host...')
|
||||
hid_host = Host(device)
|
||||
|
||||
# Register for HID data call back
|
||||
hid_host.on('interrupt_data', on_hid_interrupt_data_cb)
|
||||
hid_host.on('control_data', on_hid_control_data_cb)
|
||||
|
||||
# Register for virtual cable unplug call back
|
||||
hid_host.on('virtual_cable_unplug', on_hid_virtual_cable_unplug_cb)
|
||||
|
||||
await device.power_on()
|
||||
|
||||
# Connect to a peer
|
||||
@@ -345,16 +362,6 @@ async def main():
|
||||
|
||||
await get_hid_device_sdp_record(connection)
|
||||
|
||||
# Create HID host and start it
|
||||
print('@@@ Starting HID Host...')
|
||||
hid_host = Host(device, connection)
|
||||
|
||||
# Register for HID data call back
|
||||
hid_host.on('data', on_hid_data_cb)
|
||||
|
||||
# Register for virtual cable unplug call back
|
||||
hid_host.on('virtual_cable_unplug', on_hid_virtual_cable_unplug_cb)
|
||||
|
||||
async def menu():
|
||||
reader = await get_stream_reader(sys.stdin)
|
||||
while True:
|
||||
@@ -369,13 +376,14 @@ async def main():
|
||||
print(" 6. Set Report")
|
||||
print(" 7. Set Protocol Mode")
|
||||
print(" 8. Get Protocol Mode")
|
||||
print(" 9. Send Report")
|
||||
print(" 9. Send Report on Interrupt Channel")
|
||||
print("10. Suspend")
|
||||
print("11. Exit Suspend")
|
||||
print("12. Virtual Cable Unplug")
|
||||
print("13. Disconnect device")
|
||||
print("14. Delete Bonding")
|
||||
print("15. Re-connect to device")
|
||||
print("16. Exit")
|
||||
print("\nEnter your choice : \n")
|
||||
|
||||
choice = await reader.readline()
|
||||
@@ -394,21 +402,40 @@ async def main():
|
||||
await hid_host.disconnect_interrupt_channel()
|
||||
|
||||
elif choice == '5':
|
||||
print(" 1. Report ID 0x02")
|
||||
print(" 2. Report ID 0x03")
|
||||
print(" 3. Report ID 0x05")
|
||||
print(" 1. Input Report with ID 0x01")
|
||||
print(" 2. Input Report with ID 0x02")
|
||||
print(" 3. Input Report with ID 0x0F - Invalid ReportId")
|
||||
print(" 4. Output Report with ID 0x02")
|
||||
print(" 5. Feature Report with ID 0x05 - Unsupported Request")
|
||||
print(" 6. Input Report with ID 0x02, BufferSize 3")
|
||||
print(" 7. Output Report with ID 0x03, BufferSize 2")
|
||||
print(" 8. Feature Report with ID 0x05, BufferSize 3")
|
||||
choice1 = await reader.readline()
|
||||
choice1 = choice1.decode('utf-8').strip()
|
||||
|
||||
if choice1 == '1':
|
||||
hid_host.get_report(1, 2, 3)
|
||||
hid_host.get_report(1, 1, 0)
|
||||
|
||||
elif choice1 == '2':
|
||||
hid_host.get_report(2, 3, 2)
|
||||
hid_host.get_report(1, 2, 0)
|
||||
|
||||
elif choice1 == '3':
|
||||
hid_host.get_report(3, 5, 3)
|
||||
hid_host.get_report(1, 5, 0)
|
||||
|
||||
elif choice1 == '4':
|
||||
hid_host.get_report(2, 2, 0)
|
||||
|
||||
elif choice1 == '5':
|
||||
hid_host.get_report(3, 15, 0)
|
||||
|
||||
elif choice1 == '6':
|
||||
hid_host.get_report(1, 2, 3)
|
||||
|
||||
elif choice1 == '7':
|
||||
hid_host.get_report(2, 3, 2)
|
||||
|
||||
elif choice1 == '8':
|
||||
hid_host.get_report(3, 5, 3)
|
||||
else:
|
||||
print('Incorrect option selected')
|
||||
|
||||
@@ -484,6 +511,7 @@ async def main():
|
||||
hid_host.virtual_cable_unplug()
|
||||
try:
|
||||
await device.keystore.delete(target_address)
|
||||
print("Unpair successful")
|
||||
except KeyError:
|
||||
print('Device not found or Device already unpaired.')
|
||||
|
||||
@@ -513,6 +541,9 @@ async def main():
|
||||
await connection.authenticate()
|
||||
await connection.encrypt()
|
||||
|
||||
elif choice == '16':
|
||||
sys.exit("Exit successful")
|
||||
|
||||
else:
|
||||
print("Invalid option selected.")
|
||||
|
||||
|
||||
195
examples/run_unicast_server.py
Normal file
195
examples/run_unicast_server.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
import struct
|
||||
import secrets
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.device import Device, CisLink
|
||||
from bumble.hci import (
|
||||
CodecID,
|
||||
CodingFormat,
|
||||
OwnAddressType,
|
||||
HCI_IsoDataPacket,
|
||||
HCI_LE_Set_Extended_Advertising_Parameters_Command,
|
||||
)
|
||||
from bumble.profiles.bap import (
|
||||
CodecSpecificCapabilities,
|
||||
ContextType,
|
||||
AudioLocation,
|
||||
SupportedSamplingFrequency,
|
||||
SupportedFrameDuration,
|
||||
PacRecord,
|
||||
PublishedAudioCapabilitiesService,
|
||||
AudioStreamControlService,
|
||||
)
|
||||
from bumble.profiles.cap import CommonAudioServiceService
|
||||
from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType
|
||||
|
||||
from bumble.transport import open_transport_or_link
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def main() -> None:
|
||||
if len(sys.argv) < 3:
|
||||
print('Usage: run_cig_setup.py <config-file>' '<transport-spec-for-device>')
|
||||
return
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(sys.argv[2]) as hci_transport:
|
||||
print('<<< connected')
|
||||
|
||||
device = Device.from_config_file_with_hci(
|
||||
sys.argv[1], hci_transport.source, hci_transport.sink
|
||||
)
|
||||
device.cis_enabled = True
|
||||
|
||||
await device.power_on()
|
||||
|
||||
csis = CoordinatedSetIdentificationService(
|
||||
set_identity_resolving_key=secrets.token_bytes(16),
|
||||
set_identity_resolving_key_type=SirkType.PLAINTEXT,
|
||||
)
|
||||
device.add_service(CommonAudioServiceService(csis))
|
||||
device.add_service(
|
||||
PublishedAudioCapabilitiesService(
|
||||
supported_source_context=ContextType.PROHIBITED,
|
||||
available_source_context=ContextType.PROHIBITED,
|
||||
supported_sink_context=ContextType.MEDIA,
|
||||
available_sink_context=ContextType.MEDIA,
|
||||
sink_audio_locations=(
|
||||
AudioLocation.FRONT_LEFT | AudioLocation.FRONT_RIGHT
|
||||
),
|
||||
sink_pac=[
|
||||
# Codec Capability Setting 16_2
|
||||
PacRecord(
|
||||
coding_format=CodingFormat(CodecID.LC3),
|
||||
codec_specific_capabilities=CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=(
|
||||
SupportedSamplingFrequency.FREQ_16000
|
||||
),
|
||||
supported_frame_durations=(
|
||||
SupportedFrameDuration.DURATION_10000_US_SUPPORTED
|
||||
),
|
||||
supported_audio_channel_counts=[1],
|
||||
min_octets_per_codec_frame=40,
|
||||
max_octets_per_codec_frame=40,
|
||||
supported_max_codec_frames_per_sdu=1,
|
||||
),
|
||||
),
|
||||
# Codec Capability Setting 24_2
|
||||
PacRecord(
|
||||
coding_format=CodingFormat(CodecID.LC3),
|
||||
codec_specific_capabilities=CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=(
|
||||
SupportedSamplingFrequency.FREQ_24000
|
||||
),
|
||||
supported_frame_durations=(
|
||||
SupportedFrameDuration.DURATION_10000_US_SUPPORTED
|
||||
),
|
||||
supported_audio_channel_counts=[1],
|
||||
min_octets_per_codec_frame=60,
|
||||
max_octets_per_codec_frame=60,
|
||||
supported_max_codec_frames_per_sdu=1,
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
device.add_service(AudioStreamControlService(device, sink_ase_id=[1, 2]))
|
||||
|
||||
advertising_data = (
|
||||
bytes(
|
||||
AdvertisingData(
|
||||
[
|
||||
(
|
||||
AdvertisingData.COMPLETE_LOCAL_NAME,
|
||||
bytes('Bumble LE Audio', 'utf-8'),
|
||||
),
|
||||
(
|
||||
AdvertisingData.FLAGS,
|
||||
bytes(
|
||||
[
|
||||
AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG
|
||||
| AdvertisingData.BR_EDR_HOST_FLAG
|
||||
| AdvertisingData.BR_EDR_CONTROLLER_FLAG
|
||||
]
|
||||
),
|
||||
),
|
||||
(
|
||||
AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS,
|
||||
bytes(PublishedAudioCapabilitiesService.UUID),
|
||||
),
|
||||
]
|
||||
)
|
||||
)
|
||||
+ csis.get_advertising_data()
|
||||
)
|
||||
subprocess = await asyncio.create_subprocess_shell(
|
||||
f'dlc3 | ffplay pipe:0',
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdin = subprocess.stdin
|
||||
assert stdin
|
||||
|
||||
# Write a fake LC3 header to dlc3.
|
||||
stdin.write(
|
||||
bytes([0x1C, 0xCC]) # Header.
|
||||
+ struct.pack(
|
||||
'<HHHHHHI',
|
||||
18, # Header length.
|
||||
24000 // 100, # Sampling Rate(/100Hz).
|
||||
0, # Bitrate(unused).
|
||||
1, # Channels.
|
||||
10000 // 10, # Frame duration(/10us).
|
||||
0, # RFU.
|
||||
0x0FFFFFFF, # Frame counts.
|
||||
)
|
||||
)
|
||||
|
||||
def on_pdu(pdu: HCI_IsoDataPacket):
|
||||
# LC3 format: |frame_length(2)| + |frame(length)|.
|
||||
if pdu.iso_sdu_length:
|
||||
stdin.write(struct.pack('<H', pdu.iso_sdu_length))
|
||||
stdin.write(pdu.iso_sdu_fragment)
|
||||
|
||||
def on_cis(cis_link: CisLink):
|
||||
cis_link.on('pdu', on_pdu)
|
||||
|
||||
device.once('cis_establishment', on_cis)
|
||||
|
||||
await device.start_extended_advertising(
|
||||
advertising_properties=(
|
||||
HCI_LE_Set_Extended_Advertising_Parameters_Command.AdvertisingProperties.CONNECTABLE_ADVERTISING
|
||||
),
|
||||
own_address_type=OwnAddressType.RANDOM,
|
||||
advertising_data=advertising_data,
|
||||
)
|
||||
|
||||
await hci_transport.source.terminated
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
|
||||
asyncio.run(main())
|
||||
@@ -16,17 +16,83 @@ package com.github.google.bumble.btbench
|
||||
|
||||
import android.annotation.SuppressLint
|
||||
import android.bluetooth.BluetoothAdapter
|
||||
import java.io.IOException
|
||||
import android.bluetooth.BluetoothDevice
|
||||
import android.bluetooth.BluetoothGatt
|
||||
import android.bluetooth.BluetoothGattCallback
|
||||
import android.bluetooth.BluetoothProfile
|
||||
import android.content.Context
|
||||
import android.os.Build
|
||||
import java.util.logging.Logger
|
||||
import kotlin.concurrent.thread
|
||||
|
||||
private val Log = Logger.getLogger("btbench.l2cap-client")
|
||||
|
||||
class L2capClient(private val viewModel: AppViewModel, val bluetoothAdapter: BluetoothAdapter) {
|
||||
class L2capClient(
|
||||
private val viewModel: AppViewModel,
|
||||
private val bluetoothAdapter: BluetoothAdapter,
|
||||
private val context: Context
|
||||
) {
|
||||
@SuppressLint("MissingPermission")
|
||||
fun run() {
|
||||
viewModel.running = true
|
||||
val remoteDevice = bluetoothAdapter.getRemoteDevice(viewModel.peerBluetoothAddress)
|
||||
val addressIsPublic = viewModel.peerBluetoothAddress.endsWith("/P")
|
||||
val address = viewModel.peerBluetoothAddress.take(17)
|
||||
val remoteDevice = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
|
||||
bluetoothAdapter.getRemoteLeDevice(
|
||||
address,
|
||||
if (addressIsPublic) {
|
||||
BluetoothDevice.ADDRESS_TYPE_PUBLIC
|
||||
} else {
|
||||
BluetoothDevice.ADDRESS_TYPE_RANDOM
|
||||
}
|
||||
)
|
||||
} else {
|
||||
bluetoothAdapter.getRemoteDevice(address)
|
||||
}
|
||||
|
||||
val gatt = remoteDevice.connectGatt(
|
||||
context,
|
||||
false,
|
||||
object : BluetoothGattCallback() {
|
||||
override fun onMtuChanged(gatt: BluetoothGatt, mtu: Int, status: Int) {
|
||||
Log.info("MTU update: mtu=$mtu status=$status")
|
||||
viewModel.mtu = mtu
|
||||
}
|
||||
|
||||
override fun onPhyUpdate(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) {
|
||||
Log.info("PHY update: tx=$txPhy, rx=$rxPhy, status=$status")
|
||||
viewModel.txPhy = txPhy
|
||||
viewModel.rxPhy = rxPhy
|
||||
}
|
||||
|
||||
override fun onPhyRead(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) {
|
||||
Log.info("PHY: tx=$txPhy, rx=$rxPhy, status=$status")
|
||||
viewModel.txPhy = txPhy
|
||||
viewModel.rxPhy = rxPhy
|
||||
}
|
||||
|
||||
override fun onConnectionStateChange(
|
||||
gatt: BluetoothGatt?, status: Int, newState: Int
|
||||
) {
|
||||
if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) {
|
||||
if (viewModel.use2mPhy) {
|
||||
gatt.setPreferredPhy(
|
||||
BluetoothDevice.PHY_LE_2M_MASK,
|
||||
BluetoothDevice.PHY_LE_2M_MASK,
|
||||
BluetoothDevice.PHY_OPTION_NO_PREFERRED
|
||||
)
|
||||
}
|
||||
gatt.readPhy()
|
||||
|
||||
// Request an MTU update, even though we don't use GATT, because Android
|
||||
// won't request a larger link layer maximum data length otherwise.
|
||||
gatt.requestMtu(517)
|
||||
}
|
||||
}
|
||||
},
|
||||
BluetoothDevice.TRANSPORT_LE,
|
||||
if (viewModel.use2mPhy) BluetoothDevice.PHY_LE_2M_MASK else BluetoothDevice.PHY_LE_1M_MASK
|
||||
)
|
||||
|
||||
val socket = remoteDevice.createInsecureL2capChannel(viewModel.l2capPsm)
|
||||
|
||||
val client = SocketClient(viewModel, socket)
|
||||
|
||||
@@ -30,7 +30,7 @@ private val Log = Logger.getLogger("btbench.l2cap-server")
|
||||
class L2capServer(private val viewModel: AppViewModel, private val bluetoothAdapter: BluetoothAdapter) {
|
||||
@SuppressLint("MissingPermission")
|
||||
fun run() {
|
||||
// Advertise to that the peer can find us and connect.
|
||||
// Advertise so that the peer can find us and connect.
|
||||
val callback = object: AdvertiseCallback() {
|
||||
override fun onStartFailure(errorCode: Int) {
|
||||
Log.warning("failed to start advertising: $errorCode")
|
||||
@@ -50,13 +50,12 @@ class L2capServer(private val viewModel: AppViewModel, private val bluetoothAdap
|
||||
val advertiseData = AdvertiseData.Builder().build()
|
||||
val scanData = AdvertiseData.Builder().setIncludeDeviceName(true).build()
|
||||
val advertiser = bluetoothAdapter.bluetoothLeAdvertiser
|
||||
advertiser.startAdvertising(advertiseSettings, advertiseData, scanData, callback)
|
||||
|
||||
val serverSocket = bluetoothAdapter.listenUsingInsecureL2capChannel()
|
||||
viewModel.l2capPsm = serverSocket.psm
|
||||
Log.info("psm = $serverSocket.psm")
|
||||
|
||||
val server = SocketServer(viewModel, serverSocket)
|
||||
server.run({ advertiser.stopAdvertising(callback) })
|
||||
server.run({ advertiser.stopAdvertising(callback) }, { advertiser.startAdvertising(advertiseSettings, advertiseData, scanData, callback) })
|
||||
}
|
||||
}
|
||||
@@ -26,23 +26,33 @@ import android.os.Bundle
|
||||
import androidx.activity.ComponentActivity
|
||||
import androidx.activity.compose.setContent
|
||||
import androidx.activity.result.contract.ActivityResultContracts
|
||||
import androidx.compose.foundation.layout.Arrangement
|
||||
import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.Row
|
||||
import androidx.compose.foundation.layout.Spacer
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.text.KeyboardActions
|
||||
import androidx.compose.foundation.text.KeyboardOptions
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.Divider
|
||||
import androidx.compose.material3.MaterialTheme
|
||||
import androidx.compose.material3.Slider
|
||||
import androidx.compose.material3.Surface
|
||||
import androidx.compose.material3.Switch
|
||||
import androidx.compose.material3.Text
|
||||
import androidx.compose.material3.TextField
|
||||
import androidx.compose.runtime.Composable
|
||||
import androidx.compose.runtime.remember
|
||||
import androidx.compose.ui.Alignment
|
||||
import androidx.compose.ui.ExperimentalComposeUiApi
|
||||
import androidx.compose.ui.Modifier
|
||||
import androidx.compose.ui.focus.FocusRequester
|
||||
import androidx.compose.ui.focus.focusRequester
|
||||
import androidx.compose.ui.platform.LocalFocusManager
|
||||
import androidx.compose.ui.platform.LocalSoftwareKeyboardController
|
||||
import androidx.compose.ui.text.font.FontWeight
|
||||
import androidx.compose.ui.text.input.ImeAction
|
||||
@@ -171,7 +181,7 @@ class MainActivity : ComponentActivity() {
|
||||
}
|
||||
|
||||
private fun runL2capClient() {
|
||||
val l2capClient = bluetoothAdapter?.let { L2capClient(appViewModel, it) }
|
||||
val l2capClient = bluetoothAdapter?.let { L2capClient(appViewModel, it, baseContext) }
|
||||
l2capClient?.run()
|
||||
}
|
||||
|
||||
@@ -199,9 +209,12 @@ fun MainView(
|
||||
runL2capServer: () -> Unit
|
||||
) {
|
||||
BTBenchTheme {
|
||||
// A surface container using the 'background' color from the theme
|
||||
val scrollState = rememberScrollState()
|
||||
Surface(
|
||||
modifier = Modifier.fillMaxSize(), color = MaterialTheme.colorScheme.background
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.verticalScroll(scrollState),
|
||||
color = MaterialTheme.colorScheme.background
|
||||
) {
|
||||
Column(modifier = Modifier.padding(horizontal = 16.dp)) {
|
||||
Text(
|
||||
@@ -212,28 +225,33 @@ fun MainView(
|
||||
)
|
||||
Divider()
|
||||
val keyboardController = LocalSoftwareKeyboardController.current
|
||||
TextField(label = {
|
||||
Text(text = "Peer Bluetooth Address")
|
||||
},
|
||||
val focusRequester = remember { FocusRequester() }
|
||||
val focusManager = LocalFocusManager.current
|
||||
TextField(
|
||||
label = {
|
||||
Text(text = "Peer Bluetooth Address")
|
||||
},
|
||||
value = appViewModel.peerBluetoothAddress,
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
modifier = Modifier.fillMaxWidth().focusRequester(focusRequester),
|
||||
keyboardOptions = KeyboardOptions.Default.copy(
|
||||
keyboardType = KeyboardType.Ascii, imeAction = ImeAction.Done
|
||||
),
|
||||
onValueChange = {
|
||||
appViewModel.updatePeerBluetoothAddress(it)
|
||||
},
|
||||
keyboardActions = KeyboardActions(onDone = { keyboardController?.hide() })
|
||||
keyboardActions = KeyboardActions(onDone = {
|
||||
keyboardController?.hide()
|
||||
focusManager.clearFocus()
|
||||
})
|
||||
)
|
||||
Divider()
|
||||
TextField(label = {
|
||||
Text(text = "L2CAP PSM")
|
||||
},
|
||||
value = appViewModel.l2capPsm.toString(),
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
modifier = Modifier.fillMaxWidth().focusRequester(focusRequester),
|
||||
keyboardOptions = KeyboardOptions.Default.copy(
|
||||
keyboardType = KeyboardType.Number,
|
||||
imeAction = ImeAction.Done
|
||||
keyboardType = KeyboardType.Number, imeAction = ImeAction.Done
|
||||
),
|
||||
onValueChange = {
|
||||
if (it.isNotEmpty()) {
|
||||
@@ -243,7 +261,11 @@ fun MainView(
|
||||
}
|
||||
}
|
||||
},
|
||||
keyboardActions = KeyboardActions(onDone = { keyboardController?.hide() }))
|
||||
keyboardActions = KeyboardActions(onDone = {
|
||||
keyboardController?.hide()
|
||||
focusManager.clearFocus()
|
||||
})
|
||||
)
|
||||
Divider()
|
||||
Slider(
|
||||
value = appViewModel.senderPacketCountSlider, onValueChange = {
|
||||
@@ -264,7 +286,19 @@ fun MainView(
|
||||
ActionButton(
|
||||
text = "Become Discoverable", onClick = becomeDiscoverable, true
|
||||
)
|
||||
Row() {
|
||||
Row(
|
||||
horizontalArrangement = Arrangement.SpaceBetween,
|
||||
verticalAlignment = Alignment.CenterVertically
|
||||
) {
|
||||
Text(text = "2M PHY")
|
||||
Spacer(modifier = Modifier.padding(start = 8.dp))
|
||||
Switch(
|
||||
checked = appViewModel.use2mPhy,
|
||||
onCheckedChange = { appViewModel.use2mPhy = it }
|
||||
)
|
||||
|
||||
}
|
||||
Row {
|
||||
ActionButton(
|
||||
text = "RFCOMM Client", onClick = runRfcommClient, !appViewModel.running
|
||||
)
|
||||
@@ -272,7 +306,7 @@ fun MainView(
|
||||
text = "RFCOMM Server", onClick = runRfcommServer, !appViewModel.running
|
||||
)
|
||||
}
|
||||
Row() {
|
||||
Row {
|
||||
ActionButton(
|
||||
text = "L2CAP Client", onClick = runL2capClient, !appViewModel.running
|
||||
)
|
||||
@@ -281,6 +315,12 @@ fun MainView(
|
||||
)
|
||||
}
|
||||
Divider()
|
||||
Text(
|
||||
text = if (appViewModel.mtu != 0) "MTU: ${appViewModel.mtu}" else ""
|
||||
)
|
||||
Text(
|
||||
text = if (appViewModel.rxPhy != 0 || appViewModel.txPhy != 0) "PHY: tx=${appViewModel.txPhy}, rx=${appViewModel.rxPhy}" else ""
|
||||
)
|
||||
Text(
|
||||
text = "Packets Sent: ${appViewModel.packetsSent}"
|
||||
)
|
||||
|
||||
@@ -23,15 +23,20 @@ import androidx.compose.runtime.setValue
|
||||
import androidx.lifecycle.ViewModel
|
||||
import java.util.UUID
|
||||
|
||||
val DEFAULT_RFCOMM_UUID = UUID.fromString("E6D55659-C8B4-4B85-96BB-B1143AF6D3AE")
|
||||
val DEFAULT_RFCOMM_UUID: UUID = UUID.fromString("E6D55659-C8B4-4B85-96BB-B1143AF6D3AE")
|
||||
const val DEFAULT_PEER_BLUETOOTH_ADDRESS = "AA:BB:CC:DD:EE:FF"
|
||||
const val DEFAULT_SENDER_PACKET_COUNT = 100
|
||||
const val DEFAULT_SENDER_PACKET_SIZE = 1024
|
||||
const val DEFAULT_PSM = 128
|
||||
|
||||
class AppViewModel : ViewModel() {
|
||||
private var preferences: SharedPreferences? = null
|
||||
var peerBluetoothAddress by mutableStateOf(DEFAULT_PEER_BLUETOOTH_ADDRESS)
|
||||
var l2capPsm by mutableStateOf(0)
|
||||
var l2capPsm by mutableIntStateOf(DEFAULT_PSM)
|
||||
var use2mPhy by mutableStateOf(true)
|
||||
var mtu by mutableIntStateOf(0)
|
||||
var rxPhy by mutableIntStateOf(0)
|
||||
var txPhy by mutableIntStateOf(0)
|
||||
var senderPacketCountSlider by mutableFloatStateOf(0.0F)
|
||||
var senderPacketSizeSlider by mutableFloatStateOf(0.0F)
|
||||
var senderPacketCount by mutableIntStateOf(DEFAULT_SENDER_PACKET_COUNT)
|
||||
@@ -64,28 +69,29 @@ class AppViewModel : ViewModel() {
|
||||
}
|
||||
|
||||
fun updatePeerBluetoothAddress(peerBluetoothAddress: String) {
|
||||
this.peerBluetoothAddress = peerBluetoothAddress
|
||||
val address = peerBluetoothAddress.uppercase()
|
||||
this.peerBluetoothAddress = address
|
||||
|
||||
// Save the address to the preferences
|
||||
with(preferences!!.edit()) {
|
||||
putString(PEER_BLUETOOTH_ADDRESS_PREF_KEY, peerBluetoothAddress)
|
||||
putString(PEER_BLUETOOTH_ADDRESS_PREF_KEY, address)
|
||||
apply()
|
||||
}
|
||||
}
|
||||
|
||||
fun updateSenderPacketCountSlider() {
|
||||
if (senderPacketCount <= 10) {
|
||||
senderPacketCountSlider = 0.0F
|
||||
senderPacketCountSlider = if (senderPacketCount <= 10) {
|
||||
0.0F
|
||||
} else if (senderPacketCount <= 50) {
|
||||
senderPacketCountSlider = 0.2F
|
||||
0.2F
|
||||
} else if (senderPacketCount <= 100) {
|
||||
senderPacketCountSlider = 0.4F
|
||||
0.4F
|
||||
} else if (senderPacketCount <= 500) {
|
||||
senderPacketCountSlider = 0.6F
|
||||
0.6F
|
||||
} else if (senderPacketCount <= 1000) {
|
||||
senderPacketCountSlider = 0.8F
|
||||
0.8F
|
||||
} else {
|
||||
senderPacketCountSlider = 1.0F
|
||||
1.0F
|
||||
}
|
||||
|
||||
with(preferences!!.edit()) {
|
||||
@@ -95,18 +101,18 @@ class AppViewModel : ViewModel() {
|
||||
}
|
||||
|
||||
fun updateSenderPacketCount() {
|
||||
if (senderPacketCountSlider < 0.1F) {
|
||||
senderPacketCount = 10
|
||||
senderPacketCount = if (senderPacketCountSlider < 0.1F) {
|
||||
10
|
||||
} else if (senderPacketCountSlider < 0.3F) {
|
||||
senderPacketCount = 50
|
||||
50
|
||||
} else if (senderPacketCountSlider < 0.5F) {
|
||||
senderPacketCount = 100
|
||||
100
|
||||
} else if (senderPacketCountSlider < 0.7F) {
|
||||
senderPacketCount = 500
|
||||
500
|
||||
} else if (senderPacketCountSlider < 0.9F) {
|
||||
senderPacketCount = 1000
|
||||
1000
|
||||
} else {
|
||||
senderPacketCount = 10000
|
||||
10000
|
||||
}
|
||||
|
||||
with(preferences!!.edit()) {
|
||||
@@ -116,18 +122,18 @@ class AppViewModel : ViewModel() {
|
||||
}
|
||||
|
||||
fun updateSenderPacketSizeSlider() {
|
||||
if (senderPacketSize <= 1) {
|
||||
senderPacketSizeSlider = 0.0F
|
||||
senderPacketSizeSlider = if (senderPacketSize <= 16) {
|
||||
0.0F
|
||||
} else if (senderPacketSize <= 256) {
|
||||
senderPacketSizeSlider = 0.02F
|
||||
0.02F
|
||||
} else if (senderPacketSize <= 512) {
|
||||
senderPacketSizeSlider = 0.4F
|
||||
0.4F
|
||||
} else if (senderPacketSize <= 1024) {
|
||||
senderPacketSizeSlider = 0.6F
|
||||
0.6F
|
||||
} else if (senderPacketSize <= 2048) {
|
||||
senderPacketSizeSlider = 0.8F
|
||||
0.8F
|
||||
} else {
|
||||
senderPacketSizeSlider = 1.0F
|
||||
1.0F
|
||||
}
|
||||
|
||||
with(preferences!!.edit()) {
|
||||
@@ -137,18 +143,18 @@ class AppViewModel : ViewModel() {
|
||||
}
|
||||
|
||||
fun updateSenderPacketSize() {
|
||||
if (senderPacketSizeSlider < 0.1F) {
|
||||
senderPacketSize = 1
|
||||
senderPacketSize = if (senderPacketSizeSlider < 0.1F) {
|
||||
16
|
||||
} else if (senderPacketSizeSlider < 0.3F) {
|
||||
senderPacketSize = 256
|
||||
256
|
||||
} else if (senderPacketSizeSlider < 0.5F) {
|
||||
senderPacketSize = 512
|
||||
512
|
||||
} else if (senderPacketSizeSlider < 0.7F) {
|
||||
senderPacketSize = 1024
|
||||
1024
|
||||
} else if (senderPacketSizeSlider < 0.9F) {
|
||||
senderPacketSize = 2048
|
||||
2048
|
||||
} else {
|
||||
senderPacketSize = 4096
|
||||
4096
|
||||
}
|
||||
|
||||
with(preferences!!.edit()) {
|
||||
|
||||
@@ -25,7 +25,8 @@ private val Log = Logger.getLogger("btbench.rfcomm-client")
|
||||
class RfcommClient(private val viewModel: AppViewModel, val bluetoothAdapter: BluetoothAdapter) {
|
||||
@SuppressLint("MissingPermission")
|
||||
fun run() {
|
||||
val remoteDevice = bluetoothAdapter.getRemoteDevice(viewModel.peerBluetoothAddress)
|
||||
val address = viewModel.peerBluetoothAddress.take(17)
|
||||
val remoteDevice = bluetoothAdapter.getRemoteDevice(address)
|
||||
val socket = remoteDevice.createInsecureRfcommSocketToServiceRecord(
|
||||
DEFAULT_RFCOMM_UUID
|
||||
)
|
||||
|
||||
@@ -30,6 +30,6 @@ class RfcommServer(private val viewModel: AppViewModel, val bluetoothAdapter: Bl
|
||||
)
|
||||
|
||||
val server = SocketServer(viewModel, serverSocket)
|
||||
server.run({})
|
||||
server.run({}, {})
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,8 @@ import kotlin.concurrent.thread
|
||||
|
||||
private val Log = Logger.getLogger("btbench.socket-client")
|
||||
|
||||
private const val DEFAULT_STARTUP_DELAY = 3000
|
||||
|
||||
class SocketClient(private val viewModel: AppViewModel, private val socket: BluetoothSocket) {
|
||||
@SuppressLint("MissingPermission")
|
||||
fun run() {
|
||||
@@ -56,6 +58,10 @@ class SocketClient(private val viewModel: AppViewModel, private val socket: Blue
|
||||
socketDataSource.receive()
|
||||
}
|
||||
|
||||
Log.info("Startup delay: $DEFAULT_STARTUP_DELAY")
|
||||
Thread.sleep(DEFAULT_STARTUP_DELAY.toLong());
|
||||
Log.info("Starting to send")
|
||||
|
||||
sender.run()
|
||||
cleanup()
|
||||
}
|
||||
|
||||
@@ -22,14 +22,13 @@ import kotlin.concurrent.thread
|
||||
private val Log = Logger.getLogger("btbench.socket-server")
|
||||
|
||||
class SocketServer(private val viewModel: AppViewModel, private val serverSocket: BluetoothServerSocket) {
|
||||
fun run(onTerminate: () -> Unit) {
|
||||
fun run(onConnected: () -> Unit, onDisconnected: () -> Unit) {
|
||||
var aborted = false
|
||||
viewModel.running = true
|
||||
|
||||
fun cleanup() {
|
||||
serverSocket.close()
|
||||
viewModel.running = false
|
||||
onTerminate()
|
||||
}
|
||||
|
||||
thread(name = "SocketServer") {
|
||||
@@ -38,6 +37,7 @@ class SocketServer(private val viewModel: AppViewModel, private val serverSocket
|
||||
serverSocket.close()
|
||||
}
|
||||
Log.info("waiting for connection...")
|
||||
onDisconnected()
|
||||
val socket = try {
|
||||
serverSocket.accept()
|
||||
} catch (error: IOException) {
|
||||
@@ -45,7 +45,8 @@ class SocketServer(private val viewModel: AppViewModel, private val serverSocket
|
||||
cleanup()
|
||||
return@thread
|
||||
}
|
||||
Log.info("got connection")
|
||||
Log.info("got connection from ${socket.remoteDevice.address}")
|
||||
onConnected()
|
||||
|
||||
viewModel.aborter = {
|
||||
aborted = true
|
||||
|
||||
@@ -42,6 +42,7 @@ public class HciServer {
|
||||
try (ServerSocket serverSocket = new ServerSocket(mPort)) {
|
||||
mListener.onMessage("Waiting for connection on port " + serverSocket.getLocalPort());
|
||||
try (Socket clientSocket = serverSocket.accept()) {
|
||||
clientSocket.setTcpNoDelay(true);
|
||||
mListener.onHostConnectionState(true);
|
||||
mListener.onMessage("Connected");
|
||||
HciParser parser = new HciParser(mListener);
|
||||
|
||||
@@ -10,8 +10,10 @@ import androidx.compose.foundation.layout.Column
|
||||
import androidx.compose.foundation.layout.fillMaxSize
|
||||
import androidx.compose.foundation.layout.fillMaxWidth
|
||||
import androidx.compose.foundation.layout.padding
|
||||
import androidx.compose.foundation.rememberScrollState
|
||||
import androidx.compose.foundation.text.KeyboardActions
|
||||
import androidx.compose.foundation.text.KeyboardOptions
|
||||
import androidx.compose.foundation.verticalScroll
|
||||
import androidx.compose.material3.Button
|
||||
import androidx.compose.material3.Divider
|
||||
import androidx.compose.material3.ExperimentalMaterial3Api
|
||||
@@ -71,7 +73,7 @@ class AppViewModel : ViewModel(), HciProxy.Listener {
|
||||
this.tcpPort = tcpPort
|
||||
|
||||
// Save the port to the preferences
|
||||
with (preferences!!.edit()) {
|
||||
with(preferences!!.edit()) {
|
||||
putString(TCP_PORT_PREF_KEY, tcpPort.toString())
|
||||
apply()
|
||||
}
|
||||
@@ -138,7 +140,8 @@ class MainActivity : ComponentActivity() {
|
||||
log.warning("Exception while running HCI Server: $error")
|
||||
} catch (error: HalException) {
|
||||
log.warning("HAL exception: ${error.message}")
|
||||
appViewModel.message = "Cannot bind to HAL (${error.message}). You may need to use the command 'setenforce 0' in a root adb shell."
|
||||
appViewModel.message =
|
||||
"Cannot bind to HAL (${error.message}). You may need to use the command 'setenforce 0' in a root adb shell."
|
||||
}
|
||||
log.info("HCI Proxy thread ended")
|
||||
appViewModel.canStart = true
|
||||
@@ -157,9 +160,12 @@ fun ActionButton(text: String, onClick: () -> Unit, enabled: Boolean) {
|
||||
@Composable
|
||||
fun MainView(appViewModel: AppViewModel, startProxy: () -> Unit) {
|
||||
RemoteHCITheme {
|
||||
// A surface container using the 'background' color from the theme
|
||||
val scrollState = rememberScrollState()
|
||||
Surface(
|
||||
modifier = Modifier.fillMaxSize(), color = MaterialTheme.colorScheme.background
|
||||
modifier = Modifier
|
||||
.fillMaxSize()
|
||||
.verticalScroll(scrollState),
|
||||
color = MaterialTheme.colorScheme.background
|
||||
) {
|
||||
Column(modifier = Modifier.padding(horizontal = 16.dp)) {
|
||||
Text(
|
||||
@@ -174,13 +180,15 @@ fun MainView(appViewModel: AppViewModel, startProxy: () -> Unit) {
|
||||
)
|
||||
Divider()
|
||||
val keyboardController = LocalSoftwareKeyboardController.current
|
||||
TextField(
|
||||
label = {
|
||||
Text(text = "TCP Port")
|
||||
},
|
||||
TextField(label = {
|
||||
Text(text = "TCP Port")
|
||||
},
|
||||
value = appViewModel.tcpPort.toString(),
|
||||
modifier = Modifier.fillMaxWidth(),
|
||||
keyboardOptions = KeyboardOptions.Default.copy(keyboardType = KeyboardType.Number, imeAction = ImeAction.Done),
|
||||
keyboardOptions = KeyboardOptions.Default.copy(
|
||||
keyboardType = KeyboardType.Number,
|
||||
imeAction = ImeAction.Done
|
||||
),
|
||||
onValueChange = {
|
||||
if (it.isNotEmpty()) {
|
||||
val tcpPort = it.toIntOrNull()
|
||||
@@ -189,10 +197,7 @@ fun MainView(appViewModel: AppViewModel, startProxy: () -> Unit) {
|
||||
}
|
||||
}
|
||||
},
|
||||
keyboardActions = KeyboardActions(
|
||||
onDone = {keyboardController?.hide()}
|
||||
)
|
||||
)
|
||||
keyboardActions = KeyboardActions(onDone = { keyboardController?.hide() }))
|
||||
Divider()
|
||||
val connectState = if (appViewModel.hostConnected) "CONNECTED" else "DISCONNECTED"
|
||||
Text(
|
||||
|
||||
8
rust/Cargo.lock
generated
8
rust/Cargo.lock
generated
@@ -1073,9 +1073,9 @@ checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
|
||||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.57"
|
||||
version = "0.10.60"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bac25ee399abb46215765b1cb35bc0212377e58a061560d8b29b024fd0430e7c"
|
||||
checksum = "79a4c6c3a2b158f7f8f2a2fc5a969fa3a068df6fc9dbb4a43845436e3af7c800"
|
||||
dependencies = [
|
||||
"bitflags 2.4.0",
|
||||
"cfg-if",
|
||||
@@ -1105,9 +1105,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-sys"
|
||||
version = "0.9.92"
|
||||
version = "0.9.96"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "db7e971c2c2bba161b2d2fdf37080177eff520b3bc044787c7f1f5f9e78d869b"
|
||||
checksum = "3812c071ba60da8b5677cc12bcb1d42989a65553772897a7e0355545a819838f"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
|
||||
@@ -56,6 +56,7 @@ install_requires =
|
||||
|
||||
[options.entry_points]
|
||||
console_scripts =
|
||||
bumble-ble-rpa-tool = bumble.apps.ble_rpa_tool:main
|
||||
bumble-console = bumble.apps.console:main
|
||||
bumble-controller-info = bumble.apps.controller_info:main
|
||||
bumble-gatt-dump = bumble.apps.gatt_dump:main
|
||||
|
||||
403
tests/bap_test.py
Normal file
403
tests/bap_test.py
Normal file
@@ -0,0 +1,403 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import os
|
||||
import functools
|
||||
import pytest
|
||||
import logging
|
||||
|
||||
from bumble import device
|
||||
from bumble.hci import CodecID, CodingFormat
|
||||
from bumble.profiles.bap import (
|
||||
AudioLocation,
|
||||
AseStateMachine,
|
||||
ASE_Operation,
|
||||
ASE_Config_Codec,
|
||||
ASE_Config_QOS,
|
||||
ASE_Disable,
|
||||
ASE_Enable,
|
||||
ASE_Receiver_Start_Ready,
|
||||
ASE_Receiver_Stop_Ready,
|
||||
ASE_Release,
|
||||
ASE_Update_Metadata,
|
||||
SupportedFrameDuration,
|
||||
SupportedSamplingFrequency,
|
||||
SamplingFrequency,
|
||||
FrameDuration,
|
||||
CodecSpecificCapabilities,
|
||||
CodecSpecificConfiguration,
|
||||
ContextType,
|
||||
PacRecord,
|
||||
AudioStreamControlService,
|
||||
AudioStreamControlServiceProxy,
|
||||
PublishedAudioCapabilitiesService,
|
||||
PublishedAudioCapabilitiesServiceProxy,
|
||||
)
|
||||
from tests.test_utils import TwoDevices
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def basic_check(operation: ASE_Operation):
|
||||
serialized = bytes(operation)
|
||||
parsed = ASE_Operation.from_bytes(serialized)
|
||||
assert bytes(parsed) == serialized
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_codec_specific_capabilities() -> None:
|
||||
SAMPLE_FREQUENCY = SupportedSamplingFrequency.FREQ_16000
|
||||
FRAME_SURATION = SupportedFrameDuration.DURATION_10000_US_SUPPORTED
|
||||
AUDIO_CHANNEL_COUNTS = [1]
|
||||
cap = CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=SAMPLE_FREQUENCY,
|
||||
supported_frame_durations=FRAME_SURATION,
|
||||
supported_audio_channel_counts=AUDIO_CHANNEL_COUNTS,
|
||||
min_octets_per_codec_frame=40,
|
||||
max_octets_per_codec_frame=40,
|
||||
supported_max_codec_frames_per_sdu=1,
|
||||
)
|
||||
assert CodecSpecificCapabilities.from_bytes(bytes(cap)) == cap
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_pac_record() -> None:
|
||||
SAMPLE_FREQUENCY = SupportedSamplingFrequency.FREQ_16000
|
||||
FRAME_SURATION = SupportedFrameDuration.DURATION_10000_US_SUPPORTED
|
||||
AUDIO_CHANNEL_COUNTS = [1]
|
||||
cap = CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=SAMPLE_FREQUENCY,
|
||||
supported_frame_durations=FRAME_SURATION,
|
||||
supported_audio_channel_counts=AUDIO_CHANNEL_COUNTS,
|
||||
min_octets_per_codec_frame=40,
|
||||
max_octets_per_codec_frame=40,
|
||||
supported_max_codec_frames_per_sdu=1,
|
||||
)
|
||||
|
||||
pac_record = PacRecord(
|
||||
coding_format=CodingFormat(CodecID.LC3),
|
||||
codec_specific_capabilities=cap,
|
||||
metadata=b'',
|
||||
)
|
||||
assert PacRecord.from_bytes(bytes(pac_record)) == pac_record
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_vendor_specific_pac_record() -> None:
|
||||
# Vendor-Specific codec, Google, ID=0xFFFF. No capabilities and metadata.
|
||||
RAW_DATA = bytes.fromhex('ffe000ffff0000')
|
||||
assert bytes(PacRecord.from_bytes(RAW_DATA)) == RAW_DATA
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Config_Codec() -> None:
|
||||
operation = ASE_Config_Codec(
|
||||
ase_id=[1, 2],
|
||||
target_latency=[3, 4],
|
||||
target_phy=[5, 6],
|
||||
codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)],
|
||||
codec_specific_configuration=[b'foo', b'bar'],
|
||||
)
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Config_QOS() -> None:
|
||||
operation = ASE_Config_QOS(
|
||||
ase_id=[1, 2],
|
||||
cig_id=[1, 2],
|
||||
cis_id=[3, 4],
|
||||
sdu_interval=[5, 6],
|
||||
framing=[0, 1],
|
||||
phy=[2, 3],
|
||||
max_sdu=[4, 5],
|
||||
retransmission_number=[6, 7],
|
||||
max_transport_latency=[8, 9],
|
||||
presentation_delay=[10, 11],
|
||||
)
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Enable() -> None:
|
||||
operation = ASE_Enable(
|
||||
ase_id=[1, 2],
|
||||
metadata=[b'foo', b'bar'],
|
||||
)
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Update_Metadata() -> None:
|
||||
operation = ASE_Update_Metadata(
|
||||
ase_id=[1, 2],
|
||||
metadata=[b'foo', b'bar'],
|
||||
)
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Disable() -> None:
|
||||
operation = ASE_Disable(ase_id=[1, 2])
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Release() -> None:
|
||||
operation = ASE_Release(ase_id=[1, 2])
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Receiver_Start_Ready() -> None:
|
||||
operation = ASE_Receiver_Start_Ready(ase_id=[1, 2])
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_ASE_Receiver_Stop_Ready() -> None:
|
||||
operation = ASE_Receiver_Stop_Ready(ase_id=[1, 2])
|
||||
basic_check(operation)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_codec_specific_configuration() -> None:
|
||||
SAMPLE_FREQUENCY = SamplingFrequency.FREQ_16000
|
||||
FRAME_SURATION = FrameDuration.DURATION_10000_US
|
||||
AUDIO_LOCATION = AudioLocation.FRONT_LEFT
|
||||
config = CodecSpecificConfiguration(
|
||||
sampling_frequency=SAMPLE_FREQUENCY,
|
||||
frame_duration=FRAME_SURATION,
|
||||
audio_channel_allocation=AUDIO_LOCATION,
|
||||
octets_per_codec_frame=60,
|
||||
codec_frames_per_sdu=1,
|
||||
)
|
||||
assert CodecSpecificConfiguration.from_bytes(bytes(config)) == config
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_pacs():
|
||||
devices = TwoDevices()
|
||||
devices[0].add_service(
|
||||
PublishedAudioCapabilitiesService(
|
||||
supported_sink_context=ContextType.MEDIA,
|
||||
available_sink_context=ContextType.MEDIA,
|
||||
supported_source_context=0,
|
||||
available_source_context=0,
|
||||
sink_pac=[
|
||||
# Codec Capability Setting 16_2
|
||||
PacRecord(
|
||||
coding_format=CodingFormat(CodecID.LC3),
|
||||
codec_specific_capabilities=CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=(
|
||||
SupportedSamplingFrequency.FREQ_16000
|
||||
),
|
||||
supported_frame_durations=(
|
||||
SupportedFrameDuration.DURATION_10000_US_SUPPORTED
|
||||
),
|
||||
supported_audio_channel_counts=[1],
|
||||
min_octets_per_codec_frame=40,
|
||||
max_octets_per_codec_frame=40,
|
||||
supported_max_codec_frames_per_sdu=1,
|
||||
),
|
||||
),
|
||||
# Codec Capability Setting 24_2
|
||||
PacRecord(
|
||||
coding_format=CodingFormat(CodecID.LC3),
|
||||
codec_specific_capabilities=CodecSpecificCapabilities(
|
||||
supported_sampling_frequencies=(
|
||||
SupportedSamplingFrequency.FREQ_24000
|
||||
),
|
||||
supported_frame_durations=(
|
||||
SupportedFrameDuration.DURATION_10000_US_SUPPORTED
|
||||
),
|
||||
supported_audio_channel_counts=[1],
|
||||
min_octets_per_codec_frame=60,
|
||||
max_octets_per_codec_frame=60,
|
||||
supported_max_codec_frames_per_sdu=1,
|
||||
),
|
||||
),
|
||||
],
|
||||
sink_audio_locations=AudioLocation.FRONT_LEFT | AudioLocation.FRONT_RIGHT,
|
||||
)
|
||||
)
|
||||
|
||||
await devices.setup_connection()
|
||||
peer = device.Peer(devices.connections[1])
|
||||
pacs_client = await peer.discover_service_and_create_proxy(
|
||||
PublishedAudioCapabilitiesServiceProxy
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_ascs():
|
||||
devices = TwoDevices()
|
||||
devices[0].add_service(
|
||||
AudioStreamControlService(device=devices[0], sink_ase_id=[1, 2])
|
||||
)
|
||||
|
||||
await devices.setup_connection()
|
||||
peer = device.Peer(devices.connections[1])
|
||||
ascs_client = await peer.discover_service_and_create_proxy(
|
||||
AudioStreamControlServiceProxy
|
||||
)
|
||||
|
||||
notifications = {1: asyncio.Queue(), 2: asyncio.Queue()}
|
||||
|
||||
def on_notification(data: bytes, ase_id: int):
|
||||
notifications[ase_id].put_nowait(data)
|
||||
|
||||
# Should be idle
|
||||
assert await ascs_client.sink_ase[0].read_value() == bytes(
|
||||
[1, AseStateMachine.State.IDLE]
|
||||
)
|
||||
assert await ascs_client.sink_ase[1].read_value() == bytes(
|
||||
[2, AseStateMachine.State.IDLE]
|
||||
)
|
||||
|
||||
# Subscribe
|
||||
await ascs_client.sink_ase[0].subscribe(
|
||||
functools.partial(on_notification, ase_id=1)
|
||||
)
|
||||
await ascs_client.sink_ase[1].subscribe(
|
||||
functools.partial(on_notification, ase_id=2)
|
||||
)
|
||||
|
||||
# Config Codec
|
||||
config = CodecSpecificConfiguration(
|
||||
sampling_frequency=SamplingFrequency.FREQ_48000,
|
||||
frame_duration=FrameDuration.DURATION_10000_US,
|
||||
audio_channel_allocation=AudioLocation.FRONT_LEFT,
|
||||
octets_per_codec_frame=120,
|
||||
codec_frames_per_sdu=1,
|
||||
)
|
||||
await ascs_client.ase_control_point.write_value(
|
||||
ASE_Config_Codec(
|
||||
ase_id=[1, 2],
|
||||
target_latency=[3, 4],
|
||||
target_phy=[5, 6],
|
||||
codec_id=[CodingFormat(CodecID.LC3), CodingFormat(CodecID.LC3)],
|
||||
codec_specific_configuration=[config, config],
|
||||
)
|
||||
)
|
||||
assert (await notifications[1].get())[:2] == bytes(
|
||||
[1, AseStateMachine.State.CODEC_CONFIGURED]
|
||||
)
|
||||
assert (await notifications[2].get())[:2] == bytes(
|
||||
[2, AseStateMachine.State.CODEC_CONFIGURED]
|
||||
)
|
||||
|
||||
# Config QOS
|
||||
await ascs_client.ase_control_point.write_value(
|
||||
ASE_Config_QOS(
|
||||
ase_id=[1, 2],
|
||||
cig_id=[1, 2],
|
||||
cis_id=[3, 4],
|
||||
sdu_interval=[5, 6],
|
||||
framing=[0, 1],
|
||||
phy=[2, 3],
|
||||
max_sdu=[4, 5],
|
||||
retransmission_number=[6, 7],
|
||||
max_transport_latency=[8, 9],
|
||||
presentation_delay=[10, 11],
|
||||
)
|
||||
)
|
||||
assert (await notifications[1].get())[:2] == bytes(
|
||||
[1, AseStateMachine.State.QOS_CONFIGURED]
|
||||
)
|
||||
assert (await notifications[2].get())[:2] == bytes(
|
||||
[2, AseStateMachine.State.QOS_CONFIGURED]
|
||||
)
|
||||
|
||||
# Enable
|
||||
await ascs_client.ase_control_point.write_value(
|
||||
ASE_Enable(
|
||||
ase_id=[1, 2],
|
||||
metadata=[b'foo', b'bar'],
|
||||
)
|
||||
)
|
||||
assert (await notifications[1].get())[:2] == bytes(
|
||||
[1, AseStateMachine.State.ENABLING]
|
||||
)
|
||||
assert (await notifications[2].get())[:2] == bytes(
|
||||
[2, AseStateMachine.State.ENABLING]
|
||||
)
|
||||
|
||||
# CIS establishment
|
||||
devices[0].emit(
|
||||
'cis_establishment',
|
||||
device.CisLink(
|
||||
device=devices[0],
|
||||
acl_connection=devices.connections[0],
|
||||
handle=5,
|
||||
cis_id=3,
|
||||
cig_id=1,
|
||||
),
|
||||
)
|
||||
devices[0].emit(
|
||||
'cis_establishment',
|
||||
device.CisLink(
|
||||
device=devices[0],
|
||||
acl_connection=devices.connections[0],
|
||||
handle=6,
|
||||
cis_id=4,
|
||||
cig_id=2,
|
||||
),
|
||||
)
|
||||
assert (await notifications[1].get())[:2] == bytes(
|
||||
[1, AseStateMachine.State.STREAMING]
|
||||
)
|
||||
assert (await notifications[2].get())[:2] == bytes(
|
||||
[2, AseStateMachine.State.STREAMING]
|
||||
)
|
||||
|
||||
# Release
|
||||
await ascs_client.ase_control_point.write_value(
|
||||
ASE_Release(
|
||||
ase_id=[1, 2],
|
||||
metadata=[b'foo', b'bar'],
|
||||
)
|
||||
)
|
||||
assert (await notifications[1].get())[:2] == bytes(
|
||||
[1, AseStateMachine.State.RELEASING]
|
||||
)
|
||||
assert (await notifications[2].get())[:2] == bytes(
|
||||
[2, AseStateMachine.State.RELEASING]
|
||||
)
|
||||
assert (await notifications[1].get())[:2] == bytes([1, AseStateMachine.State.IDLE])
|
||||
assert (await notifications[2].get())[:2] == bytes([2, AseStateMachine.State.IDLE])
|
||||
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run():
|
||||
await test_pacs()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
asyncio.run(run())
|
||||
71
tests/cap_test.py
Normal file
71
tests/cap_test.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Copyright 2021-2023 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import os
|
||||
import pytest
|
||||
import logging
|
||||
|
||||
from bumble import device
|
||||
from bumble import gatt
|
||||
from bumble.profiles import cap
|
||||
from bumble.profiles import csip
|
||||
from .test_utils import TwoDevices
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_cas():
|
||||
SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
|
||||
|
||||
devices = TwoDevices()
|
||||
devices[0].add_service(
|
||||
cap.CommonAudioServiceService(
|
||||
csip.CoordinatedSetIdentificationService(
|
||||
set_identity_resolving_key=SIRK,
|
||||
set_identity_resolving_key_type=csip.SirkType.PLAINTEXT,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
await devices.setup_connection()
|
||||
peer = device.Peer(devices.connections[1])
|
||||
cas_client = await peer.discover_service_and_create_proxy(
|
||||
cap.CommonAudioServiceServiceProxy
|
||||
)
|
||||
|
||||
included_services = await peer.discover_included_services(cas_client.service_proxy)
|
||||
assert any(
|
||||
service.uuid == gatt.GATT_COORDINATED_SET_IDENTIFICATION_SERVICE
|
||||
for service in included_services
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run():
|
||||
await test_cas()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
asyncio.run(run())
|
||||
@@ -20,6 +20,7 @@ import os
|
||||
import pytest
|
||||
import struct
|
||||
import logging
|
||||
from unittest import mock
|
||||
|
||||
from bumble import device
|
||||
from bumble.profiles import csip
|
||||
@@ -31,15 +32,55 @@ from .test_utils import TwoDevices
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_s1():
|
||||
assert (
|
||||
csip.s1(b'SIRKenc'[::-1])
|
||||
== bytes.fromhex('6901983f 18149e82 3c7d133a 7d774572')[::-1]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_k1():
|
||||
K = bytes.fromhex('676e1b9b d448696f 061ec622 3ce5ced9')[::-1]
|
||||
SALT = csip.s1(b'SIRKenc'[::-1])
|
||||
P = b'csis'[::-1]
|
||||
assert (
|
||||
csip.k1(K, SALT, P)
|
||||
== bytes.fromhex('5277453c c094d982 b0e8ee53 2f2d1f8b')[::-1]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_sih():
|
||||
SIRK = bytes.fromhex('457d7d09 21a1fd22 cecd8c86 dd72cccd')[::-1]
|
||||
PRAND = bytes.fromhex('69f563')[::-1]
|
||||
assert csip.sih(SIRK, PRAND) == bytes.fromhex('1948da')[::-1]
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_sef():
|
||||
SIRK = bytes.fromhex('457d7d09 21a1fd22 cecd8c86 dd72cccd')[::-1]
|
||||
K = bytes.fromhex('676e1b9b d448696f 061ec622 3ce5ced9')[::-1]
|
||||
assert (
|
||||
csip.sef(K, SIRK) == bytes.fromhex('170a3835 e13524a0 7e2562d5 f25fd346')[::-1]
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_csis():
|
||||
@pytest.mark.parametrize(
|
||||
'sirk_type,', [(csip.SirkType.ENCRYPTED), (csip.SirkType.PLAINTEXT)]
|
||||
)
|
||||
async def test_csis(sirk_type):
|
||||
SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
|
||||
LTK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
|
||||
|
||||
devices = TwoDevices()
|
||||
devices[0].add_service(
|
||||
csip.CoordinatedSetIdentificationService(
|
||||
set_identity_resolving_key=SIRK,
|
||||
set_identity_resolving_key_type=sirk_type,
|
||||
coordinated_set_size=2,
|
||||
set_member_lock=csip.MemberLock.UNLOCKED,
|
||||
set_member_rank=0,
|
||||
@@ -47,15 +88,19 @@ async def test_csis():
|
||||
)
|
||||
|
||||
await devices.setup_connection()
|
||||
|
||||
# Mock encryption.
|
||||
devices.connections[0].encryption = 1
|
||||
devices.connections[1].encryption = 1
|
||||
devices[0].get_long_term_key = mock.AsyncMock(return_value=LTK)
|
||||
devices[1].get_long_term_key = mock.AsyncMock(return_value=LTK)
|
||||
|
||||
peer = device.Peer(devices.connections[1])
|
||||
csis_client = await peer.discover_service_and_create_proxy(
|
||||
csip.CoordinatedSetIdentificationProxy
|
||||
)
|
||||
|
||||
assert (
|
||||
await csis_client.set_identity_resolving_key.read_value()
|
||||
== bytes([csip.SirkType.PLAINTEXT]) + SIRK
|
||||
)
|
||||
assert await csis_client.read_set_identity_resolving_key() == (sirk_type, SIRK)
|
||||
assert await csis_client.coordinated_set_size.read_value() == struct.pack('B', 2)
|
||||
assert await csis_client.set_member_lock.read_value() == struct.pack(
|
||||
'B', csip.MemberLock.UNLOCKED
|
||||
@@ -65,6 +110,7 @@ async def test_csis():
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run():
|
||||
test_sih()
|
||||
await test_csis()
|
||||
|
||||
|
||||
|
||||
@@ -20,16 +20,23 @@ import logging
|
||||
import os
|
||||
from types import LambdaType
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from bumble.core import BT_BR_EDR_TRANSPORT
|
||||
from bumble.core import (
|
||||
BT_BR_EDR_TRANSPORT,
|
||||
BT_LE_TRANSPORT,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ConnectionParameters,
|
||||
)
|
||||
from bumble.device import Connection, Device
|
||||
from bumble.host import Host
|
||||
from bumble.host import AclPacketQueue, Host
|
||||
from bumble.hci import (
|
||||
HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
|
||||
HCI_COMMAND_STATUS_PENDING,
|
||||
HCI_CREATE_CONNECTION_COMMAND,
|
||||
HCI_SUCCESS,
|
||||
Address,
|
||||
OwnAddressType,
|
||||
HCI_Command_Complete_Event,
|
||||
HCI_Command_Status_Event,
|
||||
HCI_Connection_Complete_Event,
|
||||
@@ -66,6 +73,13 @@ async def test_device_connect_parallel():
|
||||
d1 = Device(host=Host(None, None))
|
||||
d2 = Device(host=Host(None, None))
|
||||
|
||||
def _send(packet):
|
||||
pass
|
||||
|
||||
d0.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
|
||||
d1.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
|
||||
d2.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
|
||||
|
||||
# enable classic
|
||||
d0.classic_enabled = True
|
||||
d1.classic_enabled = True
|
||||
@@ -232,6 +246,172 @@ async def test_flush():
|
||||
pass
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_advertising():
|
||||
device = Device(host=mock.AsyncMock(Host))
|
||||
|
||||
# Start advertising
|
||||
advertiser = await device.start_legacy_advertising()
|
||||
assert device.legacy_advertiser
|
||||
|
||||
# Stop advertising
|
||||
await advertiser.stop()
|
||||
assert not device.legacy_advertiser
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'own_address_type,',
|
||||
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_advertising_connection(own_address_type):
|
||||
device = Device(host=mock.AsyncMock(Host))
|
||||
peer_address = Address('F0:F1:F2:F3:F4:F5')
|
||||
|
||||
# Start advertising
|
||||
advertiser = await device.start_legacy_advertising()
|
||||
device.on_connection(
|
||||
0x0001,
|
||||
BT_LE_TRANSPORT,
|
||||
peer_address,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ConnectionParameters(0, 0, 0),
|
||||
)
|
||||
|
||||
if own_address_type == OwnAddressType.PUBLIC:
|
||||
assert device.lookup_connection(0x0001).self_address == device.public_address
|
||||
else:
|
||||
assert device.lookup_connection(0x0001).self_address == device.random_address
|
||||
|
||||
# For unknown reason, read_phy() in on_connection() would be killed at the end of
|
||||
# test, so we force scheduling here to avoid an warning.
|
||||
await asyncio.sleep(0.0001)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'auto_restart,',
|
||||
(True, False),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_advertising_disconnection(auto_restart):
|
||||
device = Device(host=mock.AsyncMock(spec=Host))
|
||||
peer_address = Address('F0:F1:F2:F3:F4:F5')
|
||||
advertiser = await device.start_legacy_advertising(auto_restart=auto_restart)
|
||||
device.on_connection(
|
||||
0x0001,
|
||||
BT_LE_TRANSPORT,
|
||||
peer_address,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ConnectionParameters(0, 0, 0),
|
||||
)
|
||||
|
||||
device.start_legacy_advertising = mock.AsyncMock()
|
||||
|
||||
device.on_disconnection(0x0001, 0)
|
||||
|
||||
if auto_restart:
|
||||
device.start_legacy_advertising.assert_called_with(
|
||||
advertising_type=advertiser.advertising_type,
|
||||
own_address_type=advertiser.own_address_type,
|
||||
auto_restart=advertiser.auto_restart,
|
||||
advertising_data=advertiser.advertising_data,
|
||||
scan_response_data=advertiser.scan_response_data,
|
||||
)
|
||||
else:
|
||||
device.start_legacy_advertising.assert_not_called()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_extended_advertising():
|
||||
device = Device(host=mock.AsyncMock(Host))
|
||||
|
||||
# Start advertising
|
||||
advertiser = await device.start_extended_advertising()
|
||||
assert device.extended_advertisers
|
||||
|
||||
# Stop advertising
|
||||
await advertiser.stop()
|
||||
assert not device.extended_advertisers
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'own_address_type,',
|
||||
(OwnAddressType.PUBLIC, OwnAddressType.RANDOM),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_extended_advertising_connection(own_address_type):
|
||||
device = Device(host=mock.AsyncMock(spec=Host))
|
||||
peer_address = Address('F0:F1:F2:F3:F4:F5')
|
||||
advertiser = await device.start_extended_advertising(
|
||||
own_address_type=own_address_type
|
||||
)
|
||||
device.on_connection(
|
||||
0x0001,
|
||||
BT_LE_TRANSPORT,
|
||||
peer_address,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ConnectionParameters(0, 0, 0),
|
||||
)
|
||||
device.on_advertising_set_termination(
|
||||
HCI_SUCCESS,
|
||||
advertiser.handle,
|
||||
0x0001,
|
||||
)
|
||||
|
||||
if own_address_type == OwnAddressType.PUBLIC:
|
||||
assert device.lookup_connection(0x0001).self_address == device.public_address
|
||||
else:
|
||||
assert device.lookup_connection(0x0001).self_address == device.random_address
|
||||
|
||||
# For unknown reason, read_phy() in on_connection() would be killed at the end of
|
||||
# test, so we force scheduling here to avoid an warning.
|
||||
await asyncio.sleep(0.0001)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.parametrize(
|
||||
'auto_restart,',
|
||||
(True, False),
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_extended_advertising_disconnection(auto_restart):
|
||||
device = Device(host=mock.AsyncMock(spec=Host))
|
||||
peer_address = Address('F0:F1:F2:F3:F4:F5')
|
||||
advertiser = await device.start_extended_advertising(auto_restart=auto_restart)
|
||||
device.on_connection(
|
||||
0x0001,
|
||||
BT_LE_TRANSPORT,
|
||||
peer_address,
|
||||
BT_PERIPHERAL_ROLE,
|
||||
ConnectionParameters(0, 0, 0),
|
||||
)
|
||||
device.on_advertising_set_termination(
|
||||
HCI_SUCCESS,
|
||||
advertiser.handle,
|
||||
0x0001,
|
||||
)
|
||||
|
||||
device.start_extended_advertising = mock.AsyncMock()
|
||||
|
||||
device.on_disconnection(0x0001, 0)
|
||||
|
||||
if auto_restart:
|
||||
device.start_extended_advertising.assert_called_with(
|
||||
advertising_properties=advertiser.advertising_properties,
|
||||
own_address_type=advertiser.own_address_type,
|
||||
auto_restart=advertiser.auto_restart,
|
||||
advertising_data=advertiser.advertising_data,
|
||||
scan_response_data=advertiser.scan_response_data,
|
||||
)
|
||||
else:
|
||||
device.start_extended_advertising.assert_not_called()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_gatt_services_with_gas():
|
||||
device = Device(host=Host(None, None))
|
||||
|
||||
@@ -20,11 +20,10 @@ import logging
|
||||
import os
|
||||
import struct
|
||||
import pytest
|
||||
from unittest.mock import Mock, ANY
|
||||
from unittest.mock import AsyncMock, Mock, ANY
|
||||
|
||||
from bumble.controller import Controller
|
||||
from bumble.gatt_client import CharacteristicProxy
|
||||
from bumble.gatt_server import Server
|
||||
from bumble.link import LocalLink
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.host import Host
|
||||
@@ -120,9 +119,9 @@ async def test_characteristic_encoding():
|
||||
Characteristic.READABLE,
|
||||
123,
|
||||
)
|
||||
x = c.read_value(None)
|
||||
x = await c.read_value(None)
|
||||
assert x == bytes([123])
|
||||
c.write_value(None, bytes([122]))
|
||||
await c.write_value(None, bytes([122]))
|
||||
assert c.value == 122
|
||||
|
||||
class FooProxy(CharacteristicProxy):
|
||||
@@ -152,7 +151,22 @@ async def test_characteristic_encoding():
|
||||
bytes([123]),
|
||||
)
|
||||
|
||||
service = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic])
|
||||
async def async_read(connection):
|
||||
return 0x05060708
|
||||
|
||||
async_characteristic = PackedCharacteristicAdapter(
|
||||
Characteristic(
|
||||
'2AB7E91B-43E8-4F73-AC3B-80C1683B47F9',
|
||||
Characteristic.Properties.READ,
|
||||
Characteristic.READABLE,
|
||||
CharacteristicValue(read=async_read),
|
||||
),
|
||||
'>I',
|
||||
)
|
||||
|
||||
service = Service(
|
||||
'3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic, async_characteristic]
|
||||
)
|
||||
server.add_service(service)
|
||||
|
||||
await client.power_on()
|
||||
@@ -184,6 +198,13 @@ async def test_characteristic_encoding():
|
||||
await async_barrier()
|
||||
assert characteristic.value == bytes([50])
|
||||
|
||||
c2 = peer.get_characteristics_by_uuid(async_characteristic.uuid)
|
||||
assert len(c2) == 1
|
||||
c2 = c2[0]
|
||||
cd2 = PackedCharacteristicAdapter(c2, ">I")
|
||||
cd2v = await cd2.read_value()
|
||||
assert cd2v == 0x05060708
|
||||
|
||||
last_change = None
|
||||
|
||||
def on_change(value):
|
||||
@@ -285,7 +306,8 @@ async def test_attribute_getters():
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_CharacteristicAdapter():
|
||||
@pytest.mark.asyncio
|
||||
async def test_CharacteristicAdapter():
|
||||
# Check that the CharacteristicAdapter base class is transparent
|
||||
v = bytes([1, 2, 3])
|
||||
c = Characteristic(
|
||||
@@ -296,11 +318,11 @@ def test_CharacteristicAdapter():
|
||||
)
|
||||
a = CharacteristicAdapter(c)
|
||||
|
||||
value = a.read_value(None)
|
||||
value = await a.read_value(None)
|
||||
assert value == v
|
||||
|
||||
v = bytes([3, 4, 5])
|
||||
a.write_value(None, v)
|
||||
await a.write_value(None, v)
|
||||
assert c.value == v
|
||||
|
||||
# Simple delegated adapter
|
||||
@@ -308,11 +330,11 @@ def test_CharacteristicAdapter():
|
||||
c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x))
|
||||
)
|
||||
|
||||
value = a.read_value(None)
|
||||
value = await a.read_value(None)
|
||||
assert value == bytes(reversed(v))
|
||||
|
||||
v = bytes([3, 4, 5])
|
||||
a.write_value(None, v)
|
||||
await a.write_value(None, v)
|
||||
assert a.value == bytes(reversed(v))
|
||||
|
||||
# Packed adapter with single element format
|
||||
@@ -321,10 +343,10 @@ def test_CharacteristicAdapter():
|
||||
c.value = v
|
||||
a = PackedCharacteristicAdapter(c, '>H')
|
||||
|
||||
value = a.read_value(None)
|
||||
value = await a.read_value(None)
|
||||
assert value == pv
|
||||
c.value = None
|
||||
a.write_value(None, pv)
|
||||
await a.write_value(None, pv)
|
||||
assert a.value == v
|
||||
|
||||
# Packed adapter with multi-element format
|
||||
@@ -334,10 +356,10 @@ def test_CharacteristicAdapter():
|
||||
c.value = (v1, v2)
|
||||
a = PackedCharacteristicAdapter(c, '>HH')
|
||||
|
||||
value = a.read_value(None)
|
||||
value = await a.read_value(None)
|
||||
assert value == pv
|
||||
c.value = None
|
||||
a.write_value(None, pv)
|
||||
await a.write_value(None, pv)
|
||||
assert a.value == (v1, v2)
|
||||
|
||||
# Mapped adapter
|
||||
@@ -348,10 +370,10 @@ def test_CharacteristicAdapter():
|
||||
c.value = mapped
|
||||
a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2'))
|
||||
|
||||
value = a.read_value(None)
|
||||
value = await a.read_value(None)
|
||||
assert value == pv
|
||||
c.value = None
|
||||
a.write_value(None, pv)
|
||||
await a.write_value(None, pv)
|
||||
assert a.value == mapped
|
||||
|
||||
# UTF-8 adapter
|
||||
@@ -360,27 +382,49 @@ def test_CharacteristicAdapter():
|
||||
c.value = v
|
||||
a = UTF8CharacteristicAdapter(c)
|
||||
|
||||
value = a.read_value(None)
|
||||
value = await a.read_value(None)
|
||||
assert value == ev
|
||||
c.value = None
|
||||
a.write_value(None, ev)
|
||||
await a.write_value(None, ev)
|
||||
assert a.value == v
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_CharacteristicValue():
|
||||
@pytest.mark.asyncio
|
||||
async def test_CharacteristicValue():
|
||||
b = bytes([1, 2, 3])
|
||||
c = CharacteristicValue(read=lambda _: b)
|
||||
x = c.read(None)
|
||||
|
||||
async def read_value(connection):
|
||||
return b
|
||||
|
||||
c = CharacteristicValue(read=read_value)
|
||||
x = await c.read(None)
|
||||
assert x == b
|
||||
|
||||
result = []
|
||||
c = CharacteristicValue(
|
||||
write=lambda connection, value: result.append((connection, value))
|
||||
)
|
||||
m = Mock()
|
||||
c = CharacteristicValue(write=m)
|
||||
z = object()
|
||||
c.write(z, b)
|
||||
assert result == [(z, b)]
|
||||
m.assert_called_once_with(z, b)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_CharacteristicValue_async():
|
||||
b = bytes([1, 2, 3])
|
||||
|
||||
async def read_value(connection):
|
||||
return b
|
||||
|
||||
c = CharacteristicValue(read=read_value)
|
||||
x = await c.read(None)
|
||||
assert x == b
|
||||
|
||||
m = AsyncMock()
|
||||
c = CharacteristicValue(write=m)
|
||||
z = object()
|
||||
await c.write(z, b)
|
||||
m.assert_called_once_with(z, b)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -961,12 +1005,18 @@ Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def async_main():
|
||||
test_UUID()
|
||||
test_ATT_Error_Response()
|
||||
test_ATT_Read_By_Group_Type_Request()
|
||||
await test_read_write()
|
||||
await test_read_write2()
|
||||
await test_subscribe_notify()
|
||||
await test_unsubscribe()
|
||||
await test_characteristic_encoding()
|
||||
await test_mtu_exchange()
|
||||
await test_CharacteristicValue()
|
||||
await test_CharacteristicValue_async()
|
||||
await test_CharacteristicAdapter()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -1105,9 +1155,4 @@ def test_get_attribute_group():
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
test_UUID()
|
||||
test_ATT_Error_Response()
|
||||
test_ATT_Read_By_Group_Type_Request()
|
||||
test_CharacteristicValue()
|
||||
test_CharacteristicAdapter()
|
||||
asyncio.run(async_main())
|
||||
|
||||
@@ -32,6 +32,7 @@ from bumble.hci import (
|
||||
HCI_CustomPacket,
|
||||
HCI_Disconnect_Command,
|
||||
HCI_Event,
|
||||
HCI_IsoDataPacket,
|
||||
HCI_LE_Add_Device_To_Filter_Accept_List_Command,
|
||||
HCI_LE_Advertising_Report_Event,
|
||||
HCI_LE_Channel_Selection_Algorithm_Event,
|
||||
@@ -53,6 +54,7 @@ from bumble.hci import (
|
||||
HCI_LE_Set_Random_Address_Command,
|
||||
HCI_LE_Set_Scan_Enable_Command,
|
||||
HCI_LE_Set_Scan_Parameters_Command,
|
||||
HCI_LE_Setup_ISO_Data_Path_Command,
|
||||
HCI_Number_Of_Completed_Packets_Event,
|
||||
HCI_Packet,
|
||||
HCI_PIN_Code_Request_Reply_Command,
|
||||
@@ -455,6 +457,14 @@ def test_HCI_LE_Setup_ISO_Data_Path_Command():
|
||||
assert command.controller_delay == 0
|
||||
assert command.codec_configuration == b''
|
||||
|
||||
command = HCI_LE_Setup_ISO_Data_Path_Command(
|
||||
connection_handle=0x0060,
|
||||
data_path_direction=0x00,
|
||||
data_path_id=0x01,
|
||||
codec_id=CodingFormat(CodecID.TRANSPARENT),
|
||||
controller_delay=0x00,
|
||||
codec_configuration=b'',
|
||||
)
|
||||
basic_check(command)
|
||||
|
||||
|
||||
@@ -477,6 +487,29 @@ def test_custom():
|
||||
assert packet.payload == data
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_iso_data_packet():
|
||||
data = bytes.fromhex(
|
||||
'05616044002ac9f0a193003c00e83b477b00eba8d41dc018bf1a980f0290afe1e7c37652096697'
|
||||
'52b6a535a8df61e22931ef5a36281bc77ed6a3206d984bcdabee6be831c699cb50e2'
|
||||
)
|
||||
packet = HCI_IsoDataPacket.from_bytes(data)
|
||||
assert packet.connection_handle == 0x0061
|
||||
assert packet.packet_status_flag == 0
|
||||
assert packet.pb_flag == 0x02
|
||||
assert packet.ts_flag == 0x01
|
||||
assert packet.data_total_length == 68
|
||||
assert packet.time_stamp == 2716911914
|
||||
assert packet.packet_sequence_number == 147
|
||||
assert packet.iso_sdu_length == 60
|
||||
assert packet.iso_sdu_fragment == bytes.fromhex(
|
||||
'e83b477b00eba8d41dc018bf1a980f0290afe1e7c3765209669752b6a535a8df61e22931ef5a3'
|
||||
'6281bc77ed6a3206d984bcdabee6be831c699cb50e2'
|
||||
)
|
||||
|
||||
assert packet.to_bytes() == data
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def run_test_events():
|
||||
test_HCI_Event()
|
||||
@@ -515,6 +548,7 @@ def run_test_commands():
|
||||
test_HCI_LE_Set_Default_PHY_Command()
|
||||
test_HCI_LE_Set_Extended_Scan_Parameters_Command()
|
||||
test_HCI_LE_Set_Extended_Advertising_Enable_Command()
|
||||
test_HCI_LE_Setup_ISO_Data_Path_Command()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -523,3 +557,4 @@ if __name__ == '__main__':
|
||||
run_test_commands()
|
||||
test_address()
|
||||
test_custom()
|
||||
test_iso_data_packet()
|
||||
|
||||
Reference in New Issue
Block a user