Compare commits

..

2 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod 07270240e3 fix merge issues 2023-12-26 11:33:15 -08:00
Gilles Boccon-Gibod d12b15b5d4 Merge/rebase 2023-12-26 11:25:54 -08:00
61 changed files with 1739 additions and 6364 deletions
+1 -1
View File
@@ -29,7 +29,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
-2
View File
@@ -10,5 +10,3 @@ __pycache__
bumble/_version.py
.vscode/launch.json
/.idea
venv/
.venv/
-4
View File
@@ -12,9 +12,7 @@
"ASHA",
"asyncio",
"ATRAC",
"avctp",
"avdtp",
"avrcp",
"bitpool",
"bitstruct",
"BSCP",
@@ -24,7 +22,6 @@
"cmac",
"CONNECTIONLESS",
"csip",
"csis",
"csrcs",
"CVSD",
"datagram",
@@ -35,7 +32,6 @@
"dhkey",
"diversifier",
"endianness",
"ESCO",
"Fitbit",
"GATTLINK",
"HANDSFREE",
+131 -395
View File
File diff suppressed because it is too large Load Diff
+4 -4
View File
@@ -777,7 +777,7 @@ class ConsoleApp:
if not service:
continue
values = [
await attribute.read_value(connection)
attribute.read_value(connection)
for connection in self.device.connections.values()
]
if not values:
@@ -796,11 +796,11 @@ class ConsoleApp:
if not characteristic:
continue
values = [
await attribute.read_value(connection)
attribute.read_value(connection)
for connection in self.device.connections.values()
]
if not values:
values = [await attribute.read_value(None)]
values = [attribute.read_value(None)]
# TODO: future optimization: convert CCCD value to human readable string
@@ -944,7 +944,7 @@ class ConsoleApp:
# send data to any subscribers
if isinstance(attribute, Characteristic):
await attribute.write_value(None, value)
attribute.write_value(None, value)
if attribute.has_properties(Characteristic.NOTIFY):
await self.device.gatt_server.notify_subscribers(attribute)
if attribute.has_properties(Characteristic.INDICATE):
+45 -66
View File
@@ -18,30 +18,28 @@
import asyncio
import os
import logging
import time
import click
from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.colors import color
from bumble.core import name_or_number
from bumble.hci import (
HCI_READ_LOCAL_EXTENDED_FEATURES_COMMAND,
HCI_READ_LOCAL_SUPPORTED_FEATURES_COMMAND,
HCI_Read_Local_Extended_Features_Command,
HCI_Read_Local_Supported_Features_Command,
map_null_terminated_utf8_string,
LeFeatureMask,
HCI_SUCCESS,
HCI_LE_SUPPORTED_FEATURES_NAMES,
HCI_VERSION_NAMES,
LMP_VERSION_NAMES,
HCI_Command,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_Read_Buffer_Size_Command,
HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_Read_Local_Name_Command,
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
@@ -50,7 +48,6 @@ from bumble.hci import (
HCI_LE_Read_Maximum_Advertising_Data_Length_Command,
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND,
HCI_LE_Read_Suggested_Default_Data_Length_Command,
HCI_Read_Local_Version_Information_Command,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
@@ -66,7 +63,37 @@ def command_succeeded(response):
# -----------------------------------------------------------------------------
async def get_classic_info(host: Host) -> None:
async def get_common_info(host):
if host.supports_command(HCI_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await host.send_command(HCI_Read_Local_Supported_Features_Command())
if response.return_parameters.status == HCI_SUCCESS:
print()
print(color('LMP Features:', 'yellow'))
# TODO: support printing discrete enum values
print(' ', response.return_parameters.lmp_features.hex())
if host.supports_command(HCI_READ_LOCAL_EXTENDED_FEATURES_COMMAND):
response = await host.send_command(
HCI_Read_Local_Extended_Features_Command(page_number=0)
)
if response.return_parameters.status == HCI_SUCCESS:
if response.return_parameters.max_page_number > 0:
print()
print(color('Extended LMP Features:', 'yellow'))
for page in range(1, response.return_parameters.max_page_number + 1):
response = await host.send_command(
HCI_Read_Local_Extended_Features_Command(page_number=page)
)
if response.return_parameters.status == HCI_SUCCESS:
# TODO: support printing discrete enum values
print(f' Page {page}:')
print(' ', response.return_parameters.extended_lmp_features.hex())
# -----------------------------------------------------------------------------
async def get_classic_info(host):
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response):
@@ -87,7 +114,7 @@ async def get_classic_info(host: Host) -> None:
# -----------------------------------------------------------------------------
async def get_le_info(host: Host) -> None:
async def get_le_info(host):
print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
@@ -140,36 +167,11 @@ async def get_le_info(host: Host) -> None:
print(color('LE Features:', 'yellow'))
for feature in host.supported_le_features:
print(LeFeatureMask(feature).name)
print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
# -----------------------------------------------------------------------------
async def get_acl_flow_control_info(host: Host) -> None:
print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
print(
color('ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_acl_data_packets} '
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
)
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
)
# -----------------------------------------------------------------------------
async def async_main(latency_probes, transport):
async def async_main(transport):
print('<<< connecting to HCI...')
async with await open_transport_or_link(transport) as (hci_source, hci_sink):
print('<<< connected')
@@ -177,23 +179,6 @@ async def async_main(latency_probes, transport):
host = Host(hci_source, hci_sink)
await host.reset()
# Measure the latency if requested
latencies = []
if latency_probes:
for _ in range(latency_probes):
start = time.time()
await host.send_command(HCI_Read_Local_Version_Information_Command())
latencies.append(1000 * (time.time() - start))
print(
color('HCI Command Latency:', 'yellow'),
(
f'min={min(latencies):.2f}, '
f'max={max(latencies):.2f}, '
f'average={sum(latencies)/len(latencies):.2f}'
),
'\n',
)
# Print version
print(color('Version:', 'yellow'))
print(
@@ -211,15 +196,15 @@ async def async_main(latency_probes, transport):
)
print(color(' LMP Subversion:', 'green'), host.local_version.lmp_subversion)
# Get the common info
await get_common_info(host)
# Get the Classic info
await get_classic_info(host)
# Get the LE info
await get_le_info(host)
# Print the ACL flow control info
await get_acl_flow_control_info(host)
# Print the list of commands supported by the controller
print()
print(color('Supported Commands:', 'yellow'))
@@ -229,16 +214,10 @@ async def async_main(latency_probes, transport):
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--latency-probes',
metavar='N',
type=int,
help='Send N commands to measure HCI transport latency statistics',
)
@click.argument('transport')
def main(latency_probes, transport):
def main(transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(async_main(latency_probes, transport))
asyncio.run(async_main(transport))
# -----------------------------------------------------------------------------
-200
View File
@@ -1,200 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import time
from typing import Optional
from bumble.colors import color
from bumble.hci import (
HCI_READ_LOOPBACK_MODE_COMMAND,
HCI_Read_Loopback_Mode_Command,
HCI_WRITE_LOOPBACK_MODE_COMMAND,
HCI_Write_Loopback_Mode_Command,
LoopbackMode,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
import click
class Loopback:
"""Send and receive ACL data packets in local loopback mode"""
def __init__(self, packet_size: int, packet_count: int, transport: str):
self.transport = transport
self.packet_size = packet_size
self.packet_count = packet_count
self.connection_handle: Optional[int] = None
self.connection_event = asyncio.Event()
self.done = asyncio.Event()
self.expected_cid = 0
self.bytes_received = 0
self.start_timestamp = 0.0
self.last_timestamp = 0.0
def on_connection(self, connection_handle: int, *args):
"""Retrieve connection handle from new connection event"""
if not self.connection_event.is_set():
# save first connection handle for ACL
# subsequent connections are SCO
self.connection_handle = connection_handle
self.connection_event.set()
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
"""Calculate packet receive speed"""
now = time.time()
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
assert connection_handle == self.connection_handle
assert cid == self.expected_cid
self.expected_cid += 1
if cid == 0:
self.start_timestamp = now
else:
elapsed_since_start = now - self.start_timestamp
elapsed_since_last = now - self.last_timestamp
self.bytes_received += len(pdu)
instant_rx_speed = len(pdu) / elapsed_since_last
average_rx_speed = self.bytes_received / elapsed_since_start
print(
color(
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f' average={average_rx_speed:.4f}',
'cyan',
)
)
self.last_timestamp = now
if self.expected_cid == self.packet_count:
print(color('@@@ Received last packet', 'green'))
self.done.set()
async def run(self):
"""Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green'))
async with await open_transport_or_link(self.transport) as (
hci_source,
hci_sink,
):
print(color('>>> Connected', 'green'))
host = Host(hci_source, hci_sink)
await host.reset()
# make sure data can fit in one l2cap pdu
l2cap_header_size = 4
max_packet_size = host.acl_packet_queue.max_packet_size - l2cap_header_size
if self.packet_size > max_packet_size:
print(
color(
f'!!! Packet size ({self.packet_size}) larger than max supported'
f' size ({max_packet_size})',
'red',
)
)
return
if not host.supports_command(
HCI_WRITE_LOOPBACK_MODE_COMMAND
) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
print(color('!!! Loopback mode not supported', 'red'))
return
# set event callbacks
host.on('connection', self.on_connection)
host.on('l2cap_pdu', self.on_l2cap_pdu)
loopback_mode = LoopbackMode.LOCAL
print(color('### Setting loopback mode', 'blue'))
await host.send_command(
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
check_result=True,
)
print(color('### Checking loopback mode', 'blue'))
response = await host.send_command(
HCI_Read_Loopback_Mode_Command(), check_result=True
)
if response.return_parameters.loopback_mode != loopback_mode:
print(color('!!! Loopback mode mismatch', 'red'))
return
await self.connection_event.wait()
print(color('### Connected', 'cyan'))
print(color('=== Start sending', 'magenta'))
start_time = time.time()
bytes_sent = 0
for cid in range(0, self.packet_count):
# using the cid as an incremental index
host.send_l2cap_pdu(
self.connection_handle, cid, bytes(self.packet_size)
)
print(
color(
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
)
)
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
await asyncio.sleep(0) # yield to allow packet receive
await self.done.wait()
print(color('=== Done!', 'magenta'))
elapsed = time.time() - start_time
average_tx_speed = bytes_sent / elapsed
print(
color(
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
f' in {elapsed:.2f} seconds)',
'green',
)
)
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--packet-size',
'-s',
metavar='SIZE',
type=click.IntRange(8, 4096),
default=500,
help='Packet size',
)
@click.option(
'--packet-count',
'-c',
metavar='COUNT',
type=int,
default=10,
help='Packet count',
)
@click.argument('transport')
def main(packet_size, packet_count, transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
loopback = Loopback(packet_size, packet_count, transport)
asyncio.run(loopback.run())
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()
+15 -15
View File
@@ -17,31 +17,25 @@
# -----------------------------------------------------------------------------
import logging
import asyncio
import sys
import os
from bumble.controller import Controller
import click
from bumble.controller import Controller, Options
from bumble.link import LocalLink
from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def async_main():
if len(sys.argv) != 3:
print(
'Usage: controllers.py <hci-transport-1> <hci-transport-2> '
'[<hci-transport-3> ...]'
)
print('example: python controllers.py pty:ble1 pty:ble2')
return
async def async_main(extended_advertising, transport_names):
# Create a local link to attach the controllers to
link = LocalLink()
# Create a transport and controller for all requested names
transports = []
controllers = []
for index, transport_name in enumerate(sys.argv[1:]):
options = Options(extended_advertising=extended_advertising)
for index, transport_name in enumerate(transport_names):
transport = await open_transport_or_link(transport_name)
transports.append(transport)
controller = Controller(
@@ -49,6 +43,7 @@ async def async_main():
host_source=transport.source,
host_sink=transport.sink,
link=link,
options=options,
)
controllers.append(controller)
@@ -61,9 +56,14 @@ async def async_main():
# -----------------------------------------------------------------------------
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main())
@click.command()
@click.option(
'--extended-advertising', is_flag=True, help="Enable extended advertising"
)
@click.argument('transports', nargs=-1, required=True)
def main(extended_advertising, transports):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(async_main(extended_advertising, transports))
# -----------------------------------------------------------------------------
+24 -32
View File
@@ -49,16 +49,14 @@ class ServerBridge:
self.tcp_port = tcp_port
async def start(self, device: Device) -> None:
# Listen for incoming L2CAP channel connections
# Listen for incoming L2CAP CoC connections
device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(
psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits
),
handler=self.on_channel,
)
print(
color(f'### Listening for channel connection on PSM {self.psm}', 'yellow')
handler=self.on_coc,
)
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
def on_ble_connection(connection):
def on_ble_disconnection(reason):
@@ -75,7 +73,7 @@ class ServerBridge:
await device.start_advertising(auto_restart=True)
# Called when a new L2CAP connection is established
def on_channel(self, l2cap_channel):
def on_coc(self, l2cap_channel):
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
class Pipe:
@@ -85,7 +83,7 @@ class ServerBridge:
self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_channel_sdu
l2cap_channel.sink = self.on_coc_sdu
async def connect_to_tcp(self):
# Connect to the TCP server
@@ -130,7 +128,7 @@ class ServerBridge:
if self.tcp_transport is not None:
self.tcp_transport.close()
def on_channel_sdu(self, sdu):
def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
if self.tcp_transport is None:
print(color('!!! TCP socket not open, dropping', 'red'))
@@ -185,7 +183,7 @@ class ClientBridge:
peer_name = writer.get_extra_info('peer_name')
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
def on_channel_sdu(sdu):
def on_coc_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
l2cap_to_tcp_pipe.write(sdu)
@@ -211,7 +209,7 @@ class ClientBridge:
writer.close()
return
l2cap_channel.sink = on_channel_sdu
l2cap_channel.sink = on_coc_sdu
l2cap_channel.on('close', on_l2cap_close)
# Start a flow control pipe from L2CAP to TCP
@@ -276,29 +274,23 @@ async def run(device_config, hci_transport, bridge):
@click.pass_context
@click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP', type=int, default=1234)
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
@click.option(
'--l2cap-max-credits',
help='Maximum L2CAP Credits',
'--l2cap-coc-max-credits',
help='Maximum L2CAP CoC Credits',
type=click.IntRange(1, 65535),
default=128,
)
@click.option(
'--l2cap-mtu',
help='L2CAP MTU',
type=click.IntRange(
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU,
),
default=1024,
'--l2cap-coc-mtu',
help='L2CAP CoC MTU',
type=click.IntRange(23, 65535),
default=1022,
)
@click.option(
'--l2cap-mps',
help='L2CAP MPS',
type=click.IntRange(
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS,
),
'--l2cap-coc-mps',
help='L2CAP CoC MPS',
type=click.IntRange(23, 65533),
default=1024,
)
def cli(
@@ -306,17 +298,17 @@ def cli(
device_config,
hci_transport,
psm,
l2cap_max_credits,
l2cap_mtu,
l2cap_mps,
l2cap_coc_max_credits,
l2cap_coc_mtu,
l2cap_coc_mps,
):
context.ensure_object(dict)
context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport
context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_max_credits
context.obj['mtu'] = l2cap_mtu
context.obj['mps'] = l2cap_mps
context.obj['max_credits'] = l2cap_coc_max_credits
context.obj['mtu'] = l2cap_coc_mtu
context.obj['mps'] = l2cap_coc_mps
# -----------------------------------------------------------------------------
+8 -3
View File
@@ -253,7 +253,7 @@ class Relay:
# ----------------------------------------------------------------------------
def main():
async def async_main():
# Check the Python version
if sys.version_info < (3, 6, 1):
print('ERROR: Python 3.6.1 or higher is required')
@@ -280,8 +280,13 @@ def main():
# Start a relay
relay = Relay(args.port)
asyncio.get_event_loop().run_until_complete(relay.start())
asyncio.get_event_loop().run_forever()
async with relay.start():
await asyncio.Future()
# ----------------------------------------------------------------------------
def main():
asyncio.run(async_main())
# ----------------------------------------------------------------------------
+8 -44
View File
@@ -26,7 +26,7 @@ from bumble.transport import open_transport_or_link
from bumble.keys import JsonKeyStore
from bumble.smp import AddressResolver
from bumble.device import Advertisement
from bumble.hci import Address, HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
from bumble.hci import HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
# -----------------------------------------------------------------------------
@@ -66,15 +66,10 @@ class AdvertisementPrinter:
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
address.address_type
]
if address.address_type in (
Address.RANDOM_IDENTITY_ADDRESS,
Address.PUBLIC_IDENTITY_ADDRESS,
):
type_color = 'yellow'
if address.is_public:
type_color = 'cyan'
else:
if address.is_public:
type_color = 'cyan'
elif address.is_static:
if address.is_static:
type_color = 'green'
address_qualifier = '(static)'
elif address.is_resolvable:
@@ -121,7 +116,6 @@ async def scan(
phy,
filter_duplicates,
raw,
irks,
keystore_file,
device_config,
transport,
@@ -146,21 +140,9 @@ async def scan(
if device.keystore:
resolving_keys = await device.keystore.get_resolving_keys()
resolver = AddressResolver(resolving_keys)
else:
resolving_keys = []
for irk_and_address in irks:
if ':' not in irk_and_address:
raise ValueError('invalid IRK:ADDRESS value')
irk_hex, address_str = irk_and_address.split(':', 1)
resolving_keys.append(
(
bytes.fromhex(irk_hex),
Address(address_str, Address.RANDOM_DEVICE_ADDRESS),
)
)
resolver = AddressResolver(resolving_keys) if resolving_keys else None
resolver = None
printer = AdvertisementPrinter(min_rssi, resolver)
if raw:
@@ -205,24 +187,8 @@ async def scan(
default=False,
help='Listen for raw advertising reports instead of processed ones',
)
@click.option(
'--irk',
metavar='<IRK_HEX>:<ADDRESS>',
help=(
'Use this IRK for resolving private addresses ' '(may be used more than once)'
),
multiple=True,
)
@click.option(
'--keystore-file',
metavar='FILE_PATH',
help='Keystore file to use when resolving addresses',
)
@click.option(
'--device-config',
metavar='FILE_PATH',
help='Device config file for the scanning device',
)
@click.option('--keystore-file', help='Keystore file to use when resolving addresses')
@click.option('--device-config', help='Device config file for the scanning device')
@click.argument('transport')
def main(
min_rssi,
@@ -232,7 +198,6 @@ def main(
phy,
filter_duplicates,
raw,
irk,
keystore_file,
device_config,
transport,
@@ -247,7 +212,6 @@ def main(
phy,
filter_duplicates,
raw,
irk,
keystore_file,
device_config,
transport,
+4 -12
View File
@@ -184,12 +184,8 @@ def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3))
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
)
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
),
),
@@ -238,12 +234,8 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):
SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID,
DataElement.sequence(
[
DataElement.sequence(
[
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
)
DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE),
DataElement.unsigned_integer_16(version_int),
]
),
),
+11 -53
View File
@@ -25,21 +25,9 @@
from __future__ import annotations
import enum
import functools
import inspect
import struct
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
Union,
TYPE_CHECKING,
)
from pyee import EventEmitter
from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
from bumble.core import UUID, name_or_number, ProtocolError
from bumble.hci import HCI_Object, key_with_value
@@ -734,38 +722,12 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
# -----------------------------------------------------------------------------
class AttributeValue:
'''
Attribute value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
class ConnectionValue(Protocol):
def read(self, connection) -> bytes:
...
def __init__(
self,
read: Union[
Callable[[Optional[Connection]], bytes],
Callable[[Optional[Connection]], Awaitable[bytes]],
None,
] = None,
write: Union[
Callable[[Optional[Connection], bytes], None],
Callable[[Optional[Connection], bytes], Awaitable[None]],
None,
] = None,
):
self._read = read
self._write = write
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
return self._read(connection) if self._read else b''
def write(
self, connection: Optional[Connection], value: bytes
) -> Union[Awaitable[None], None]:
if self._write:
return self._write(connection, value)
return None
def write(self, connection, value: bytes) -> None:
...
# -----------------------------------------------------------------------------
@@ -808,13 +770,13 @@ class Attribute(EventEmitter):
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
value: Union[bytes, AttributeValue]
value: Union[str, bytes, ConnectionValue]
def __init__(
self,
attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, AttributeValue] = b'',
value: Union[str, bytes, ConnectionValue] = b'',
) -> None:
EventEmitter.__init__(self)
self.handle = 0
@@ -844,7 +806,7 @@ class Attribute(EventEmitter):
def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
async def read_value(self, connection: Optional[Connection]) -> bytes:
def read_value(self, connection: Optional[Connection]) -> bytes:
if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
@@ -870,8 +832,6 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'read'):
try:
value = self.value.read(connection)
if inspect.isawaitable(value):
value = await value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
@@ -881,7 +841,7 @@ class Attribute(EventEmitter):
return self.encode_value(value)
async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption:
@@ -904,9 +864,7 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'write'):
try:
result = self.value.write(connection, value)
if inspect.isawaitable(result):
await result
self.value.write(connection, value) # pylint: disable=not-callable
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
-520
View File
@@ -1,520 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import enum
import struct
from typing import Dict, Type, Union, Tuple
from bumble.utils import OpenIntEnum
# -----------------------------------------------------------------------------
class Frame:
class SubunitType(enum.IntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.4
MONITOR = 0x00
AUDIO = 0x01
PRINTER = 0x02
DISC = 0x03
TAPE_RECORDER_OR_PLAYER = 0x04
TUNER = 0x05
CA = 0x06
CAMERA = 0x07
PANEL = 0x09
BULLETIN_BOARD = 0x0A
VENDOR_UNIQUE = 0x1C
EXTENDED = 0x1E
UNIT = 0x1F
class OperationCode(OpenIntEnum):
# 0x00 - 0x0F: Unit and subunit commands
VENDOR_DEPENDENT = 0x00
RESERVE = 0x01
PLUG_INFO = 0x02
# 0x10 - 0x3F: Unit commands
DIGITAL_OUTPUT = 0x10
DIGITAL_INPUT = 0x11
CHANNEL_USAGE = 0x12
OUTPUT_PLUG_SIGNAL_FORMAT = 0x18
INPUT_PLUG_SIGNAL_FORMAT = 0x19
GENERAL_BUS_SETUP = 0x1F
CONNECT_AV = 0x20
DISCONNECT_AV = 0x21
CONNECTIONS = 0x22
CONNECT = 0x24
DISCONNECT = 0x25
UNIT_INFO = 0x30
SUBUNIT_INFO = 0x31
# 0x40 - 0x7F: Subunit commands
PASS_THROUGH = 0x7C
GUI_UPDATE = 0x7D
PUSH_GUI_DATA = 0x7E
USER_ACTION = 0x7F
# 0xA0 - 0xBF: Unit and subunit commands
VERSION = 0xB0
POWER = 0xB2
subunit_type: SubunitType
subunit_id: int
opcode: OperationCode
operands: bytes
@staticmethod
def subclass(subclass):
# Infer the opcode from the class name
if subclass.__name__.endswith("CommandFrame"):
short_name = subclass.__name__.replace("CommandFrame", "")
category_class = CommandFrame
elif subclass.__name__.endswith("ResponseFrame"):
short_name = subclass.__name__.replace("ResponseFrame", "")
category_class = ResponseFrame
else:
raise ValueError(f"invalid subclass name {subclass.__name__}")
uppercase_indexes = [
i for i in range(len(short_name)) if short_name[i].isupper()
]
uppercase_indexes.append(len(short_name))
words = [
short_name[uppercase_indexes[i] : uppercase_indexes[i + 1]].upper()
for i in range(len(uppercase_indexes) - 1)
]
opcode_name = "_".join(words)
opcode = Frame.OperationCode[opcode_name]
category_class.subclasses[opcode] = subclass
return subclass
@staticmethod
def from_bytes(data: bytes) -> Frame:
if data[0] >> 4 != 0:
raise ValueError("first 4 bits must be 0s")
ctype_or_response = data[0] & 0xF
subunit_type = Frame.SubunitType(data[1] >> 3)
subunit_id = data[1] & 7
if subunit_type == Frame.SubunitType.EXTENDED:
# Not supported
raise NotImplementedError("extended subunit types not supported")
if subunit_id < 5:
opcode_offset = 2
elif subunit_id == 5:
# Extended to the next byte
extension = data[2]
if extension == 0:
raise ValueError("extended subunit ID value reserved")
if extension == 0xFF:
subunit_id = 5 + 254 + data[3]
opcode_offset = 4
else:
subunit_id = 5 + extension
opcode_offset = 3
elif subunit_id == 6:
raise ValueError("reserved subunit ID")
opcode = Frame.OperationCode(data[opcode_offset])
operands = data[opcode_offset + 1 :]
# Look for a registered subclass
if ctype_or_response < 8:
# Command
ctype = CommandFrame.CommandType(ctype_or_response)
if c_subclass := CommandFrame.subclasses.get(opcode):
return c_subclass(
ctype,
subunit_type,
subunit_id,
*c_subclass.parse_operands(operands),
)
return CommandFrame(ctype, subunit_type, subunit_id, opcode, operands)
else:
# Response
response = ResponseFrame.ResponseCode(ctype_or_response)
if r_subclass := ResponseFrame.subclasses.get(opcode):
return r_subclass(
response,
subunit_type,
subunit_id,
*r_subclass.parse_operands(operands),
)
return ResponseFrame(response, subunit_type, subunit_id, opcode, operands)
def to_bytes(
self,
ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode],
) -> bytes:
# TODO: support extended subunit types and ids.
return (
bytes(
[
ctype_or_response,
self.subunit_type << 3 | self.subunit_id,
self.opcode,
]
)
+ self.operands
)
def to_string(self, extra: str) -> str:
return (
f"{self.__class__.__name__}({extra}"
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"opcode={self.opcode.name}, "
f"operands={self.operands.hex()})"
)
def __init__(
self,
subunit_type: SubunitType,
subunit_id: int,
opcode: OperationCode,
operands: bytes,
) -> None:
self.subunit_type = subunit_type
self.subunit_id = subunit_id
self.opcode = opcode
self.operands = operands
# -----------------------------------------------------------------------------
class CommandFrame(Frame):
class CommandType(OpenIntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.1
CONTROL = 0x00
STATUS = 0x01
SPECIFIC_INQUIRY = 0x02
NOTIFY = 0x03
GENERAL_INQUIRY = 0x04
subclasses: Dict[Frame.OperationCode, Type[CommandFrame]] = {}
ctype: CommandType
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
raise NotImplementedError
def __init__(
self,
ctype: CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
opcode: Frame.OperationCode,
operands: bytes,
) -> None:
super().__init__(subunit_type, subunit_id, opcode, operands)
self.ctype = ctype
def __bytes__(self):
return self.to_bytes(self.ctype)
def __str__(self):
return self.to_string(f"ctype={self.ctype.name}, ")
# -----------------------------------------------------------------------------
class ResponseFrame(Frame):
class ResponseCode(OpenIntEnum):
# AV/C Digital Interface Command Set General Specification Version 4.1
# Table 7.2
NOT_IMPLEMENTED = 0x08
ACCEPTED = 0x09
REJECTED = 0x0A
IN_TRANSITION = 0x0B
IMPLEMENTED_OR_STABLE = 0x0C
CHANGED = 0x0D
INTERIM = 0x0F
subclasses: Dict[Frame.OperationCode, Type[ResponseFrame]] = {}
response: ResponseCode
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
raise NotImplementedError
def __init__(
self,
response: ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
opcode: Frame.OperationCode,
operands: bytes,
) -> None:
super().__init__(subunit_type, subunit_id, opcode, operands)
self.response = response
def __bytes__(self):
return self.to_bytes(self.response)
def __str__(self):
return self.to_string(f"response={self.response.name}, ")
# -----------------------------------------------------------------------------
class VendorDependentFrame:
company_id: int
vendor_dependent_data: bytes
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
return (
struct.unpack(">I", b"\x00" + operands[:3])[0],
operands[3:],
)
def make_operands(self) -> bytes:
return struct.pack(">I", self.company_id)[1:] + self.vendor_dependent_data
def __init__(self, company_id: int, vendor_dependent_data: bytes):
self.company_id = company_id
self.vendor_dependent_data = vendor_dependent_data
# -----------------------------------------------------------------------------
@Frame.subclass
class VendorDependentCommandFrame(VendorDependentFrame, CommandFrame):
def __init__(
self,
ctype: CommandFrame.CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
company_id: int,
vendor_dependent_data: bytes,
) -> None:
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
CommandFrame.__init__(
self,
ctype,
subunit_type,
subunit_id,
Frame.OperationCode.VENDOR_DEPENDENT,
self.make_operands(),
)
def __str__(self):
return (
f"VendorDependentCommandFrame(ctype={self.ctype.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"company_id=0x{self.company_id:06X}, "
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
)
# -----------------------------------------------------------------------------
@Frame.subclass
class VendorDependentResponseFrame(VendorDependentFrame, ResponseFrame):
def __init__(
self,
response: ResponseFrame.ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
company_id: int,
vendor_dependent_data: bytes,
) -> None:
VendorDependentFrame.__init__(self, company_id, vendor_dependent_data)
ResponseFrame.__init__(
self,
response,
subunit_type,
subunit_id,
Frame.OperationCode.VENDOR_DEPENDENT,
self.make_operands(),
)
def __str__(self):
return (
f"VendorDependentResponseFrame(response={self.response.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"company_id=0x{self.company_id:06X}, "
f"vendor_dependent_data={self.vendor_dependent_data.hex()})"
)
# -----------------------------------------------------------------------------
class PassThroughFrame:
"""
See AV/C Panel Subunit Specification 1.1 - 9.4 PASS THROUGH control command
"""
class StateFlag(enum.IntEnum):
PRESSED = 0
RELEASED = 1
class OperationId(OpenIntEnum):
SELECT = 0x00
UP = 0x01
DOWN = 0x01
LEFT = 0x03
RIGHT = 0x04
RIGHT_UP = 0x05
RIGHT_DOWN = 0x06
LEFT_UP = 0x07
LEFT_DOWN = 0x08
ROOT_MENU = 0x09
SETUP_MENU = 0x0A
CONTENTS_MENU = 0x0B
FAVORITE_MENU = 0x0C
EXIT = 0x0D
NUMBER_0 = 0x20
NUMBER_1 = 0x21
NUMBER_2 = 0x22
NUMBER_3 = 0x23
NUMBER_4 = 0x24
NUMBER_5 = 0x25
NUMBER_6 = 0x26
NUMBER_7 = 0x27
NUMBER_8 = 0x28
NUMBER_9 = 0x29
DOT = 0x2A
ENTER = 0x2B
CLEAR = 0x2C
CHANNEL_UP = 0x30
CHANNEL_DOWN = 0x31
PREVIOUS_CHANNEL = 0x32
SOUND_SELECT = 0x33
INPUT_SELECT = 0x34
DISPLAY_INFORMATION = 0x35
HELP = 0x36
PAGE_UP = 0x37
PAGE_DOWN = 0x38
POWER = 0x40
VOLUME_UP = 0x41
VOLUME_DOWN = 0x42
MUTE = 0x43
PLAY = 0x44
STOP = 0x45
PAUSE = 0x46
RECORD = 0x47
REWIND = 0x48
FAST_FORWARD = 0x49
EJECT = 0x4A
FORWARD = 0x4B
BACKWARD = 0x4C
ANGLE = 0x50
SUBPICTURE = 0x51
F1 = 0x71
F2 = 0x72
F3 = 0x73
F4 = 0x74
F5 = 0x75
VENDOR_UNIQUE = 0x7E
state_flag: StateFlag
operation_id: OperationId
operation_data: bytes
@staticmethod
def parse_operands(operands: bytes) -> Tuple:
return (
PassThroughFrame.StateFlag(operands[0] >> 7),
PassThroughFrame.OperationId(operands[0] & 0x7F),
operands[1 : 1 + operands[1]],
)
def make_operands(self):
return (
bytes([self.state_flag << 7 | self.operation_id, len(self.operation_data)])
+ self.operation_data
)
def __init__(
self,
state_flag: StateFlag,
operation_id: OperationId,
operation_data: bytes,
) -> None:
if len(operation_data) > 255:
raise ValueError("operation data must be <= 255 bytes")
self.state_flag = state_flag
self.operation_id = operation_id
self.operation_data = operation_data
# -----------------------------------------------------------------------------
@Frame.subclass
class PassThroughCommandFrame(PassThroughFrame, CommandFrame):
def __init__(
self,
ctype: CommandFrame.CommandType,
subunit_type: Frame.SubunitType,
subunit_id: int,
state_flag: PassThroughFrame.StateFlag,
operation_id: PassThroughFrame.OperationId,
operation_data: bytes,
) -> None:
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
CommandFrame.__init__(
self,
ctype,
subunit_type,
subunit_id,
Frame.OperationCode.PASS_THROUGH,
self.make_operands(),
)
def __str__(self):
return (
f"PassThroughCommandFrame(ctype={self.ctype.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"state_flag={self.state_flag.name}, "
f"operation_id={self.operation_id.name}, "
f"operation_data={self.operation_data.hex()})"
)
# -----------------------------------------------------------------------------
@Frame.subclass
class PassThroughResponseFrame(PassThroughFrame, ResponseFrame):
def __init__(
self,
response: ResponseFrame.ResponseCode,
subunit_type: Frame.SubunitType,
subunit_id: int,
state_flag: PassThroughFrame.StateFlag,
operation_id: PassThroughFrame.OperationId,
operation_data: bytes,
) -> None:
PassThroughFrame.__init__(self, state_flag, operation_id, operation_data)
ResponseFrame.__init__(
self,
response,
subunit_type,
subunit_id,
Frame.OperationCode.PASS_THROUGH,
self.make_operands(),
)
def __str__(self):
return (
f"PassThroughResponseFrame(response={self.response.name}, "
f"subunit_type={self.subunit_type.name}, "
f"subunit_id=0x{self.subunit_id:02X}, "
f"state_flag={self.state_flag.name}, "
f"operation_id={self.operation_id.name}, "
f"operation_data={self.operation_data.hex()})"
)
-291
View File
@@ -1,291 +0,0 @@
# Copyright 2021-2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
from enum import IntEnum
import logging
import struct
from typing import Callable, cast, Dict, Optional
from bumble.colors import color
from bumble import avc
from bumble import l2cap
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
AVCTP_PSM = 0x0017
AVCTP_BROWSING_PSM = 0x001B
# -----------------------------------------------------------------------------
class MessageAssembler:
Callback = Callable[[int, bool, bool, int, bytes], None]
transaction_label: int
pid: int
c_r: int
ipid: int
payload: bytes
number_of_packets: int
packets_received: int
def __init__(self, callback: Callback) -> None:
self.callback = callback
self.reset()
def reset(self) -> None:
self.packets_received = 0
self.transaction_label = -1
self.pid = -1
self.c_r = -1
self.ipid = -1
self.payload = b''
self.number_of_packets = 0
self.packet_count = 0
def on_pdu(self, pdu: bytes) -> None:
self.packets_received += 1
transaction_label = pdu[0] >> 4
packet_type = Protocol.PacketType((pdu[0] >> 2) & 3)
c_r = (pdu[0] >> 1) & 1
ipid = pdu[0] & 1
if c_r == 0 and ipid != 0:
logger.warning("invalid IPID in command frame")
self.reset()
return
pid_offset = 1
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START):
if self.transaction_label >= 0:
# We are already in a transaction
logger.warning("received START or SINGLE fragment while in transaction")
self.reset()
self.packets_received = 1
if packet_type == Protocol.PacketType.START:
self.number_of_packets = pdu[1]
pid_offset = 2
pid = struct.unpack_from(">H", pdu, pid_offset)[0]
self.payload += pdu[pid_offset + 2 :]
if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END):
if transaction_label != self.transaction_label:
logger.warning("transaction label does not match")
self.reset()
return
if pid != self.pid:
logger.warning("PID does not match")
self.reset()
return
if c_r != self.c_r:
logger.warning("C/R does not match")
self.reset()
return
if self.packets_received > self.number_of_packets:
logger.warning("too many fragments in transaction")
self.reset()
return
if packet_type == Protocol.PacketType.END:
if self.packets_received != self.number_of_packets:
logger.warning("premature END")
self.reset()
return
else:
self.transaction_label = transaction_label
self.c_r = c_r
self.ipid = ipid
self.pid = pid
if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END):
self.on_message_complete()
def on_message_complete(self):
try:
self.callback(
self.transaction_label,
self.c_r == 0,
self.ipid != 0,
self.pid,
self.payload,
)
except Exception as error:
logger.exception(color(f"!!! exception in callback: {error}", "red"))
self.reset()
# -----------------------------------------------------------------------------
class Protocol:
CommandHandler = Callable[[int, avc.CommandFrame], None]
command_handlers: Dict[int, CommandHandler] # Command handlers, by PID
ResponseHandler = Callable[[int, Optional[avc.ResponseFrame]], None]
response_handlers: Dict[int, ResponseHandler] # Response handlers, by PID
next_transaction_label: int
message_assembler: MessageAssembler
class PacketType(IntEnum):
SINGLE = 0b00
START = 0b01
CONTINUE = 0b10
END = 0b11
def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None:
self.command_handlers = {}
self.response_handlers = {}
self.l2cap_channel = l2cap_channel
self.message_assembler = MessageAssembler(self.on_message)
# Register to receive PDUs from the channel
l2cap_channel.sink = self.on_pdu
l2cap_channel.on("open", self.on_l2cap_channel_open)
l2cap_channel.on("close", self.on_l2cap_channel_close)
def on_l2cap_channel_open(self):
logger.debug(color("<<< AVCTP channel open", "magenta"))
def on_l2cap_channel_close(self):
logger.debug(color("<<< AVCTP channel closed", "magenta"))
def on_pdu(self, pdu: bytes) -> None:
self.message_assembler.on_pdu(pdu)
def on_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
) -> None:
logger.debug(
f"<<< AVCTP Message: pid={pid}, "
f"transaction_label={transaction_label}, "
f"is_command={is_command}, "
f"ipid={ipid}, "
f"payload={payload.hex()}"
)
# Check for invalid PID responses.
if ipid:
logger.debug(f"received IPID for PID={pid}")
# Find the appropriate handler.
if is_command:
if pid not in self.command_handlers:
logger.warning(f"no command handler for PID {pid}")
self.send_ipid(transaction_label, pid)
return
command_frame = cast(avc.CommandFrame, avc.Frame.from_bytes(payload))
self.command_handlers[pid](transaction_label, command_frame)
else:
if pid not in self.response_handlers:
logger.warning(f"no response handler for PID {pid}")
return
# By convention, for an ipid, send a None payload to the response handler.
if ipid:
response_frame = None
else:
response_frame = cast(avc.ResponseFrame, avc.Frame.from_bytes(payload))
self.response_handlers[pid](transaction_label, response_frame)
def send_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
):
# TODO: fragment large messages
packet_type = Protocol.PacketType.SINGLE
pdu = (
struct.pack(
">BH",
transaction_label << 4
| packet_type << 2
| (0 if is_command else 1) << 1
| (1 if ipid else 0),
pid,
)
+ payload
)
self.l2cap_channel.send_pdu(pdu)
def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None:
logger.debug(
">>> AVCTP command: "
f"transaction_label={transaction_label}, "
f"pid={pid}, "
f"payload={payload.hex()}"
)
self.send_message(transaction_label, True, False, pid, payload)
def send_response(self, transaction_label: int, pid: int, payload: bytes):
logger.debug(
">>> AVCTP response: "
f"transaction_label={transaction_label}, "
f"pid={pid}, "
f"payload={payload.hex()}"
)
self.send_message(transaction_label, False, False, pid, payload)
def send_ipid(self, transaction_label: int, pid: int) -> None:
logger.debug(
">>> AVCTP ipid: " f"transaction_label={transaction_label}, " f"pid={pid}"
)
self.send_message(transaction_label, False, True, pid, b'')
def register_command_handler(
self, pid: int, handler: Protocol.CommandHandler
) -> None:
self.command_handlers[pid] = handler
def unregister_command_handler(
self, pid: int, handler: Protocol.CommandHandler
) -> None:
if pid not in self.command_handlers or self.command_handlers[pid] != handler:
raise ValueError("command handler not registered")
del self.command_handlers[pid]
def register_response_handler(
self, pid: int, handler: Protocol.ResponseHandler
) -> None:
self.response_handlers[pid] = handler
def unregister_response_handler(
self, pid: int, handler: Protocol.ResponseHandler
) -> None:
if pid not in self.response_handlers or self.response_handlers[pid] != handler:
raise ValueError("response handler not registered")
del self.response_handlers[pid]
+2 -6
View File
@@ -241,10 +241,7 @@ async def find_avdtp_service_with_sdp_client(
)
if profile_descriptor_list:
for profile_descriptor in profile_descriptor_list.value:
if (
profile_descriptor.type == sdp.DataElement.SEQUENCE
and len(profile_descriptor.value) >= 2
):
if len(profile_descriptor.value) >= 2:
avdtp_version_major = profile_descriptor.value[1].value >> 8
avdtp_version_minor = profile_descriptor.value[1].value & 0xFF
return (avdtp_version_major, avdtp_version_minor)
@@ -514,8 +511,7 @@ class MessageAssembler:
try:
self.callback(self.transaction_label, message)
except Exception as error:
logger.exception(color(f'!!! exception in callback: {error}', 'red'))
logger.warning(color(f'!!! exception in callback: {error}'))
self.reset()
-1916
View File
File diff suppressed because it is too large Load Diff
+682 -492
View File
File diff suppressed because it is too large Load Diff
+10 -14
View File
@@ -97,16 +97,12 @@ class BaseError(Exception):
namespace = f'{self.error_namespace}/'
else:
namespace = ''
have_name = self.error_name != ''
have_code = self.error_code is not None
if have_name and have_code:
error_text = f'{self.error_name} [0x{self.error_code:X}]'
elif have_name and not have_code:
error_text = self.error_name
elif not have_name and have_code:
error_text = f'0x{self.error_code:X}'
else:
error_text = '<unspecified>'
error_text = {
(True, True): f'{self.error_name} [0x{self.error_code:X}]',
(True, False): self.error_name,
(False, True): f'0x{self.error_code:X}',
(False, False): '',
}[(self.error_name != '', self.error_code is not None)]
return f'{type(self).__name__}({namespace}{error_text})'
@@ -323,7 +319,7 @@ BT_HIDP_PROTOCOL_ID = UUID.from_16_bits(0x0011, 'HIDP')
BT_HARDCOPY_CONTROL_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x0012, 'HardcopyControlChannel')
BT_HARDCOPY_DATA_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x0014, 'HardcopyDataChannel')
BT_HARDCOPY_NOTIFICATION_PROTOCOL_ID = UUID.from_16_bits(0x0016, 'HardcopyNotification')
BT_AVCTP_PROTOCOL_ID = UUID.from_16_bits(0x0017, 'AVCTP')
BT_AVTCP_PROTOCOL_ID = UUID.from_16_bits(0x0017, 'AVCTP')
BT_AVDTP_PROTOCOL_ID = UUID.from_16_bits(0x0019, 'AVDTP')
BT_CMTP_PROTOCOL_ID = UUID.from_16_bits(0x001B, 'CMTP')
BT_MCAP_CONTROL_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x001E, 'MCAPControlChannel')
@@ -825,8 +821,8 @@ class AdvertisingData:
ad_structures = []
self.ad_structures = ad_structures[:]
@classmethod
def from_bytes(cls, data: bytes) -> AdvertisingData:
@staticmethod
def from_bytes(data):
instance = AdvertisingData()
instance.append(data)
return instance
@@ -982,7 +978,7 @@ class AdvertisingData:
return ad_data
def append(self, data: bytes) -> None:
def append(self, data):
offset = 0
while offset + 1 < len(data):
length = data[offset]
+93 -137
View File
@@ -21,7 +21,6 @@ import functools
import json
import asyncio
import logging
import secrets
from contextlib import asynccontextmanager, AsyncExitStack, closing
from dataclasses import dataclass
from collections.abc import Iterable
@@ -35,6 +34,7 @@ from typing import (
Tuple,
Type,
TypeVar,
Set,
Union,
cast,
overload,
@@ -49,6 +49,7 @@ from .hci import (
HCI_AUTHENTICATED_COMBINATION_KEY_GENERATED_FROM_P_256_TYPE,
HCI_CENTRAL_ROLE,
HCI_COMMAND_STATUS_PENDING,
HCI_CONNECTED_ISOCHRONOUS_STREAM_LE_SUPPORTED_FEATURE,
HCI_CONNECTION_REJECTED_DUE_TO_LIMITED_RESOURCES_ERROR,
HCI_DISPLAY_YES_NO_IO_CAPABILITY,
HCI_DISPLAY_ONLY_IO_CAPABILITY,
@@ -59,8 +60,12 @@ from .hci import (
HCI_LE_1M_PHY,
HCI_LE_1M_PHY_BIT,
HCI_LE_2M_PHY,
HCI_LE_2M_PHY_LE_SUPPORTED_FEATURE,
HCI_LE_CLEAR_RESOLVING_LIST_COMMAND,
HCI_LE_CODED_PHY,
HCI_LE_CODED_PHY_BIT,
HCI_LE_CODED_PHY_LE_SUPPORTED_FEATURE,
HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE,
HCI_LE_EXTENDED_CREATE_CONNECTION_COMMAND,
HCI_LE_RAND_COMMAND,
HCI_LE_READ_PHY_COMMAND,
@@ -81,7 +86,7 @@ from .hci import (
HCI_Constant,
HCI_Create_Connection_Cancel_Command,
HCI_Create_Connection_Command,
HCI_Connection_Complete_Event,
HCI_Create_Connection_Command,
HCI_Disconnect_Command,
HCI_Encryption_Change_Event,
HCI_Error,
@@ -102,7 +107,6 @@ from .hci import (
HCI_LE_Extended_Create_Connection_Command,
HCI_LE_Rand_Command,
HCI_LE_Read_PHY_Command,
HCI_LE_Read_Remote_Features_Command,
HCI_LE_Reject_CIS_Request_Command,
HCI_LE_Remove_Advertising_Set_Command,
HCI_LE_Set_Address_Resolution_Enable_Command,
@@ -148,8 +152,6 @@ from .hci import (
HCI_Write_Secure_Connections_Host_Support_Command,
HCI_Write_Simple_Pairing_Mode_Command,
OwnAddressType,
LeFeature,
LeFeatureMask,
phy_list_to_bits,
)
from .host import Host
@@ -243,40 +245,16 @@ DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONN
# -----------------------------------------------------------------------------
@dataclass
class Advertisement:
# Attributes
address: Address
rssi: int = HCI_LE_Extended_Advertising_Report_Event.RSSI_NOT_AVAILABLE
is_legacy: bool = False
is_anonymous: bool = False
is_connectable: bool = False
is_directed: bool = False
is_scannable: bool = False
is_scan_response: bool = False
is_complete: bool = True
is_truncated: bool = False
primary_phy: int = 0
secondary_phy: int = 0
tx_power: int = (
TX_POWER_NOT_AVAILABLE = (
HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE
)
sid: int = 0
data_bytes: bytes = b''
# Constants
TX_POWER_NOT_AVAILABLE: ClassVar[
int
] = HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE
RSSI_NOT_AVAILABLE: ClassVar[
int
] = HCI_LE_Extended_Advertising_Report_Event.RSSI_NOT_AVAILABLE
def __post_init__(self) -> None:
self.data = AdvertisingData.from_bytes(self.data_bytes)
RSSI_NOT_AVAILABLE = HCI_LE_Extended_Advertising_Report_Event.RSSI_NOT_AVAILABLE
@classmethod
def from_advertising_report(cls, report) -> Optional[Advertisement]:
def from_advertising_report(cls, report):
if isinstance(report, HCI_LE_Advertising_Report_Event.Report):
return LegacyAdvertisement.from_advertising_report(report)
@@ -285,6 +263,41 @@ class Advertisement:
return None
# pylint: disable=line-too-long
def __init__(
self,
address,
rssi=HCI_LE_Extended_Advertising_Report_Event.RSSI_NOT_AVAILABLE,
is_legacy=False,
is_anonymous=False,
is_connectable=False,
is_directed=False,
is_scannable=False,
is_scan_response=False,
is_complete=True,
is_truncated=False,
primary_phy=0,
secondary_phy=0,
tx_power=HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE,
sid=0,
data=b'',
):
self.address = address
self.rssi = rssi
self.is_legacy = is_legacy
self.is_anonymous = is_anonymous
self.is_connectable = is_connectable
self.is_directed = is_directed
self.is_scannable = is_scannable
self.is_scan_response = is_scan_response
self.is_complete = is_complete
self.is_truncated = is_truncated
self.primary_phy = primary_phy
self.secondary_phy = secondary_phy
self.tx_power = tx_power
self.sid = sid
self.data = AdvertisingData.from_bytes(data)
# -----------------------------------------------------------------------------
class LegacyAdvertisement(Advertisement):
@@ -308,7 +321,7 @@ class LegacyAdvertisement(Advertisement):
),
is_scan_response=report.event_type
== HCI_LE_Advertising_Report_Event.SCAN_RSP,
data_bytes=report.data,
data=report.data,
)
@@ -333,7 +346,7 @@ class ExtendedAdvertisement(Advertisement):
secondary_phy = report.secondary_phy,
tx_power = report.tx_power,
sid = report.advertising_sid,
data_bytes = report.data
data = report.data
)
# fmt: on
@@ -394,7 +407,7 @@ class AdvertisingType(IntEnum):
# fmt: on
@property
def has_data(self) -> bool:
def has_data(self):
return self in (
AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
AdvertisingType.UNDIRECTED_SCANNABLE,
@@ -402,7 +415,7 @@ class AdvertisingType(IntEnum):
)
@property
def is_connectable(self) -> bool:
def is_connectable(self):
return self in (
AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY,
@@ -410,14 +423,14 @@ class AdvertisingType(IntEnum):
)
@property
def is_scannable(self) -> bool:
def is_scannable(self):
return self in (
AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE,
AdvertisingType.UNDIRECTED_SCANNABLE,
)
@property
def is_directed(self) -> bool:
def is_directed(self):
return self in (
AdvertisingType.DIRECTED_CONNECTABLE_HIGH_DUTY,
AdvertisingType.DIRECTED_CONNECTABLE_LOW_DUTY,
@@ -669,7 +682,6 @@ class Connection(CompositeEventEmitter):
self_address: Address
peer_address: Address
peer_resolvable_address: Optional[Address]
peer_le_features: Optional[LeFeatureMask]
role: int
encryption: int
authenticated: bool
@@ -746,7 +758,6 @@ class Connection(CompositeEventEmitter):
) # By default, use the device's shared server
self.pairing_peer_io_capability = None
self.pairing_peer_authentication_requirements = None
self.peer_le_features = None
# [Classic only]
@classmethod
@@ -895,15 +906,6 @@ class Connection(CompositeEventEmitter):
async def request_remote_name(self):
return await self.device.request_remote_name(self)
async def get_remote_le_features(self) -> LeFeatureMask:
"""[LE Only] Reads remote LE supported features.
Returns:
LE features supported by the remote device.
"""
self.peer_le_features = await self.device.get_remote_le_features(self)
return self.peer_le_features
async def __aenter__(self):
return self
@@ -995,15 +997,12 @@ class DeviceConfiguration:
irk = config.get('irk')
if irk:
self.irk = bytes.fromhex(irk)
elif self.address != Address(DEVICE_DEFAULT_ADDRESS):
else:
# Construct an IRK from the address bytes
# NOTE: this is not secure, but will always give the same IRK for the same
# address
address_bytes = bytes(self.address)
self.irk = (address_bytes * 3)[:16]
else:
# Fallback - when both IRK and address are not set, randomly generate an IRK.
self.irk = secrets.token_bytes(16)
# Load advertising data
advertising_data = config.get('advertising_data')
@@ -1539,7 +1538,9 @@ class Device(CompositeEventEmitter):
if self.cis_enabled:
await self.send_command(
HCI_LE_Set_Host_Feature_Command(
bit_number=LeFeature.CONNECTED_ISOCHRONOUS_STREAM,
bit_number=(
HCI_CONNECTED_ISOCHRONOUS_STREAM_LE_SUPPORTED_FEATURE
),
bit_value=1,
)
)
@@ -1585,16 +1586,6 @@ class Device(CompositeEventEmitter):
if self.address_resolution_offload:
await self.send_command(HCI_LE_Clear_Resolving_List_Command())
# Add an empty entry for non-directed address generation.
await self.send_command(
HCI_LE_Add_Device_To_Resolving_List_Command(
peer_identity_address_type=Address.ANY.address_type,
peer_identity_address=Address.ANY,
peer_irk=bytes(16),
local_irk=self.irk,
)
)
for irk, address in resolving_keys:
await self.send_command(
HCI_LE_Add_Device_To_Resolving_List_Command(
@@ -1605,21 +1596,21 @@ class Device(CompositeEventEmitter):
)
)
def supports_le_features(self, feature: LeFeatureMask) -> bool:
return self.host.supports_le_features(feature)
def supports_le_feature(self, feature):
return self.host.supports_le_feature(feature)
def supports_le_phy(self, phy):
if phy == HCI_LE_1M_PHY:
return True
feature_map = {
HCI_LE_2M_PHY: LeFeatureMask.LE_2M_PHY,
HCI_LE_CODED_PHY: LeFeatureMask.LE_CODED_PHY,
HCI_LE_2M_PHY: HCI_LE_2M_PHY_LE_SUPPORTED_FEATURE,
HCI_LE_CODED_PHY: HCI_LE_CODED_PHY_LE_SUPPORTED_FEATURE,
}
if phy not in feature_map:
raise ValueError('invalid PHY')
return self.host.supports_le_features(feature_map[phy])
return self.host.supports_le_feature(feature_map[phy])
@deprecated("Please use start_legacy_advertising.")
async def start_advertising(
@@ -1695,8 +1686,8 @@ class Device(CompositeEventEmitter):
peer_address = target
peer_address_type = target.address_type
else:
peer_address = Address.ANY
peer_address_type = Address.ANY.address_type
peer_address = Address('00:00:00:00:00:00')
peer_address_type = Address.PUBLIC_DEVICE_ADDRESS
# Set the advertising parameters
await self.send_command(
@@ -1929,8 +1920,8 @@ class Device(CompositeEventEmitter):
self.advertisement_accumulators = {}
# Enable scanning
if not legacy and self.supports_le_features(
LeFeatureMask.LE_EXTENDED_ADVERTISING
if not legacy and self.supports_le_feature(
HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE
):
# Set the scanning parameters
scan_type = (
@@ -1948,7 +1939,7 @@ class Device(CompositeEventEmitter):
scanning_phys_bits |= 1 << HCI_LE_1M_PHY_BIT
scanning_phy_count += 1
if HCI_LE_CODED_PHY in scanning_phys:
if self.supports_le_features(LeFeatureMask.LE_CODED_PHY):
if self.supports_le_feature(HCI_LE_CODED_PHY_LE_SUPPORTED_FEATURE):
scanning_phys_bits |= 1 << HCI_LE_CODED_PHY_BIT
scanning_phy_count += 1
@@ -2009,7 +2000,7 @@ class Device(CompositeEventEmitter):
async def stop_scanning(self) -> None:
# Disable scanning
if self.supports_le_features(LeFeatureMask.LE_EXTENDED_ADVERTISING):
if self.supports_le_feature(HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE):
await self.send_command(
HCI_LE_Set_Extended_Scan_Enable_Command(
enable=0, filter_duplicates=0, duration=0, period=0
@@ -3080,30 +3071,34 @@ class Device(CompositeEventEmitter):
cig_id=cig_id,
)
result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs],
),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)
pending_cis_establishments: Dict[int, asyncio.Future[CisLink]] = {}
for cis_handle, _ in cis_acl_pairs:
pending_cis_establishments[
cis_handle
] = asyncio.get_running_loop().create_future()
with closing(EventWatcher()) as watcher:
pending_cis_establishments = {
cis_handle: asyncio.get_running_loop().create_future()
for cis_handle, _ in cis_acl_pairs
}
@watcher.on(self, 'cis_establishment')
def on_cis_establishment(cis_link: CisLink) -> None:
if pending_future := pending_cis_establishments.get(cis_link.handle):
if pending_future := pending_cis_establishments.get(
cis_link.handle, None
):
pending_future.set_result(cis_link)
result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs],
),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)
return await asyncio.gather(*pending_cis_establishments.values())
# [LE only]
@@ -3147,32 +3142,6 @@ class Device(CompositeEventEmitter):
)
raise HCI_StatusError(result)
async def get_remote_le_features(self, connection: Connection) -> LeFeatureMask:
"""[LE Only] Reads remote LE supported features.
Args:
handle: connection handle to read LE features.
Returns:
LE features supported by the remote device.
"""
with closing(EventWatcher()) as watcher:
read_feature_future: asyncio.Future[
LeFeatureMask
] = asyncio.get_running_loop().create_future()
def on_le_remote_features(handle: int, features: int):
if handle == connection.handle:
read_feature_future.set_result(LeFeatureMask(features))
watcher.on(self.host, 'le_remote_features', on_le_remote_features)
await self.send_command(
HCI_LE_Read_Remote_Features_Command(
connection_handle=connection.handle
),
)
return await read_feature_future
@host_event_handler
def on_flush(self):
self.emit('flush')
@@ -3350,21 +3319,8 @@ class Device(CompositeEventEmitter):
def on_connection_request(self, bd_addr, class_of_device, link_type):
logger.debug(f'*** Connection request: {bd_addr}')
# Handle SCO request.
if link_type in (
HCI_Connection_Complete_Event.SCO_LINK_TYPE,
HCI_Connection_Complete_Event.ESCO_LINK_TYPE,
):
if connection := self.find_connection_by_bd_addr(
bd_addr, transport=BT_BR_EDR_TRANSPORT
):
self.emit('sco_request', connection, link_type)
else:
logger.error(f'SCO request from a non-connected device {bd_addr}')
return
# match a pending future using `bd_addr`
elif bd_addr in self.classic_pending_accepts:
if bd_addr in self.classic_pending_accepts:
future, *_ = self.classic_pending_accepts.pop(bd_addr)
future.set_result((bd_addr, class_of_device, link_type))
@@ -3751,7 +3707,7 @@ class Device(CompositeEventEmitter):
@host_event_handler
@experimental('Only for testing')
def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None:
if sco_link := self.sco_links.get(sco_handle):
if sco_link := self.sco_links.get(sco_handle, None):
sco_link.emit('pdu', packet)
# [LE only]
@@ -3831,7 +3787,7 @@ class Device(CompositeEventEmitter):
@experimental('Only for testing')
def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None:
logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***')
if cis_link := self.cis_links.pop(cis_handle):
if cis_link := self.cis_links.pop(cis_handle, None):
cis_link.emit('establishment_failure')
self.emit('cis_establishment_failure', cis_handle, status)
@@ -3839,7 +3795,7 @@ class Device(CompositeEventEmitter):
@host_event_handler
@experimental('Only for testing')
def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None:
if cis_link := self.cis_links.get(handle):
if cis_link := self.cis_links.get(handle, None):
cis_link.emit('pdu', packet)
@host_event_handler
+48 -60
View File
@@ -23,28 +23,16 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import enum
import functools
import logging
import struct
from typing import (
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Union,
TYPE_CHECKING,
)
from typing import Optional, Sequence, Iterable, List, Union
from bumble.colors import color
from bumble.core import UUID
from bumble.att import Attribute, AttributeValue
if TYPE_CHECKING:
from bumble.gatt_client import AttributeProxy
from bumble.device import Connection
from .colors import color
from .core import UUID, get_dict_key_by_value
from .att import Attribute
# -----------------------------------------------------------------------------
@@ -534,43 +522,56 @@ class CharacteristicDeclaration(Attribute):
# -----------------------------------------------------------------------------
class CharacteristicValue(AttributeValue):
"""Same as AttributeValue, for backward compatibility"""
class CharacteristicValue:
'''
Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(self, read=None, write=None):
self._read = read
self._write = write
def read(self, connection):
return self._read(connection) if self._read else b''
def write(self, connection, value):
if self._write:
self._write(connection, value)
# -----------------------------------------------------------------------------
class CharacteristicAdapter:
'''
An adapter that can adapt Characteristic and AttributeProxy objects
by wrapping their `read_value()` and `write_value()` methods with ones that
return/accept encoded/decoded values.
For proxies (i.e used by a GATT client), the adaptation is one where the return
value of `read_value()` is decoded and the value passed to `write_value()` is
encoded. The `subscribe()` method, is wrapped with one where the values are decoded
before being passed to the subscriber.
For local values (i.e hosted by a GATT server) the adaptation is one where the
return value of `read_value()` is encoded and the value passed to `write_value()`
is decoded.
An adapter that can adapt any object with `read_value` and `write_value`
methods (like Characteristic and CharacteristicProxy objects) by wrapping
those methods with ones that return/accept encoded/decoded values.
Objects with async methods are considered proxies, so the adaptation is one
where the return value of `read_value` is decoded and the value passed to
`write_value` is encoded. Other objects are considered local characteristics
so the adaptation is one where the return value of `read_value` is encoded
and the value passed to `write_value` is decoded.
If the characteristic has a `subscribe` method, it is wrapped with one where
the values are decoded before being passed to the subscriber.
'''
read_value: Callable
write_value: Callable
def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
def __init__(self, characteristic):
self.wrapped_characteristic = characteristic
self.subscribers: Dict[
Callable, Callable
] = {} # Map from subscriber to proxy subscriber
self.subscribers = {} # Map from subscriber to proxy subscriber
if isinstance(characteristic, Characteristic):
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
else:
if asyncio.iscoroutinefunction(
characteristic.read_value
) and asyncio.iscoroutinefunction(characteristic.write_value):
self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
else:
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
if hasattr(self.wrapped_characteristic, 'subscribe'):
self.subscribe = self.wrapped_subscribe
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name):
@@ -589,13 +590,11 @@ class CharacteristicAdapter:
else:
setattr(self.wrapped_characteristic, name, value)
async def read_encoded_value(self, connection):
return self.encode_value(
await self.wrapped_characteristic.read_value(connection)
)
def read_encoded_value(self, connection):
return self.encode_value(self.wrapped_characteristic.read_value(connection))
async def write_encoded_value(self, connection, value):
return await self.wrapped_characteristic.write_value(
def write_encoded_value(self, connection, value):
return self.wrapped_characteristic.write_value(
connection, self.decode_value(value)
)
@@ -730,24 +729,13 @@ class Descriptor(Attribute):
'''
def __str__(self) -> str:
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
value = self.value.read(None)
if isinstance(value, bytes):
value_str = value.hex()
else:
value_str = '<async>'
else:
value_str = '<...>'
return (
f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'value={value_str})'
f'value={self.read_value(None).hex()})'
)
# -----------------------------------------------------------------------------
class ClientCharacteristicConfigurationBits(enum.IntFlag):
'''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit
+21 -29
View File
@@ -31,9 +31,9 @@ import struct
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
from pyee import EventEmitter
from bumble.colors import color
from bumble.core import UUID
from bumble.att import (
from .colors import color
from .core import UUID
from .att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
@@ -60,7 +60,7 @@ from bumble.att import (
ATT_Write_Response,
Attribute,
)
from bumble.gatt import (
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
@@ -74,7 +74,6 @@ from bumble.gatt import (
Descriptor,
Service,
)
from bumble.utils import AsyncRunner
if TYPE_CHECKING:
from bumble.device import Device, Connection
@@ -380,7 +379,7 @@ class Server(EventEmitter):
# Get or encode the value
value = (
await attribute.read_value(connection)
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
@@ -423,7 +422,7 @@ class Server(EventEmitter):
# Get or encode the value
value = (
await attribute.read_value(connection)
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
@@ -651,8 +650,7 @@ class Server(EventEmitter):
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_find_by_type_value_request(self, connection, request):
def on_att_find_by_type_value_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
'''
@@ -660,13 +658,13 @@ class Server(EventEmitter):
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
attributes = []
async for attribute in (
for attribute in (
attribute
for attribute in self.attributes
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type
and (await attribute.read_value(connection)) == request.attribute_value
and attribute.read_value(connection) == request.attribute_value
and pdu_space_available >= 4
):
# TODO: check permissions
@@ -704,8 +702,7 @@ class Server(EventEmitter):
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_read_by_type_request(self, connection, request):
def on_att_read_by_type_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''
@@ -728,7 +725,7 @@ class Server(EventEmitter):
and pdu_space_available
):
try:
attribute_value = await attribute.read_value(connection)
attribute_value = attribute.read_value(connection)
except ATT_Error as error:
# If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point
@@ -770,15 +767,14 @@ class Server(EventEmitter):
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_read_request(self, connection, request):
def on_att_read_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
'''
if attribute := self.get_attribute(request.attribute_handle):
try:
value = await attribute.read_value(connection)
value = attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -796,15 +792,14 @@ class Server(EventEmitter):
)
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_read_blob_request(self, connection, request):
def on_att_read_blob_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
'''
if attribute := self.get_attribute(request.attribute_handle):
try:
value = await attribute.read_value(connection)
value = attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -841,8 +836,7 @@ class Server(EventEmitter):
)
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_read_by_group_type_request(self, connection, request):
def on_att_read_by_group_type_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
'''
@@ -870,7 +864,7 @@ class Server(EventEmitter):
):
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = await attribute.read_value(connection)
attribute_value = attribute.read_value(connection)
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
@@ -909,8 +903,7 @@ class Server(EventEmitter):
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_write_request(self, connection, request):
def on_att_write_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
'''
@@ -943,13 +936,12 @@ class Server(EventEmitter):
return
# Accept the value
await attribute.write_value(connection, request.attribute_value)
attribute.write_value(connection, request.attribute_value)
# Done
self.send_response(connection, ATT_Write_Response())
@AsyncRunner.run_in_task()
async def on_att_write_command(self, connection, request):
def on_att_write_command(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
'''
@@ -967,7 +959,7 @@ class Server(EventEmitter):
# Accept the value
try:
await attribute.write_value(connection, request.attribute_value)
attribute.write_value(connection, request.attribute_value)
except Exception as error:
logger.exception(f'!!! ignoring exception: {error}')
+58 -133
View File
@@ -1360,97 +1360,55 @@ HCI_SUPPORTED_COMMANDS_FLAGS = (
# LE Supported Features
# See Bluetooth spec @ Vol 6, Part B, 4.6 FEATURE SUPPORT
class LeFeature(enum.IntEnum):
LE_ENCRYPTION = 0
CONNECTION_PARAMETERS_REQUEST_PROCEDURE = 1
EXTENDED_REJECT_INDICATION = 2
PERIPHERAL_INITIATED_FEATURE_EXCHANGE = 3
LE_PING = 4
LE_DATA_PACKET_LENGTH_EXTENSION = 5
LL_PRIVACY = 6
EXTENDED_SCANNER_FILTER_POLICIES = 7
LE_2M_PHY = 8
STABLE_MODULATION_INDEX_TRANSMITTER = 9
STABLE_MODULATION_INDEX_RECEIVER = 10
LE_CODED_PHY = 11
LE_EXTENDED_ADVERTISING = 12
LE_PERIODIC_ADVERTISING = 13
CHANNEL_SELECTION_ALGORITHM_2 = 14
LE_POWER_CLASS_1 = 15
MINIMUM_NUMBER_OF_USED_CHANNELS_PROCEDURE = 16
CONNECTION_CTE_REQUEST = 17
CONNECTION_CTE_RESPONSE = 18
CONNECTIONLESS_CTE_TRANSMITTER = 19
CONNECTIONLESS_CTR_RECEIVER = 20
ANTENNA_SWITCHING_DURING_CTE_TRANSMISSION = 21
ANTENNA_SWITCHING_DURING_CTE_RECEPTION = 22
RECEIVING_CONSTANT_TONE_EXTENSIONS = 23
PERIODIC_ADVERTISING_SYNC_TRANSFER_SENDER = 24
PERIODIC_ADVERTISING_SYNC_TRANSFER_RECIPIENT = 25
SLEEP_CLOCK_ACCURACY_UPDATES = 26
REMOTE_PUBLIC_KEY_VALIDATION = 27
CONNECTED_ISOCHRONOUS_STREAM_CENTRAL = 28
CONNECTED_ISOCHRONOUS_STREAM_PERIPHERAL = 29
ISOCHRONOUS_BROADCASTER = 30
SYNCHRONIZED_RECEIVER = 31
CONNECTED_ISOCHRONOUS_STREAM = 32
LE_POWER_CONTROL_REQUEST = 33
LE_POWER_CONTROL_REQUEST_DUP = 34
LE_PATH_LOSS_MONITORING = 35
PERIODIC_ADVERTISING_ADI_SUPPORT = 36
CONNECTION_SUBRATING = 37
CONNECTION_SUBRATING_HOST_SUPPORT = 38
CHANNEL_CLASSIFICATION = 39
ADVERTISING_CODING_SELECTION = 40
ADVERTISING_CODING_SELECTION_HOST_SUPPORT = 41
PERIODIC_ADVERTISING_WITH_RESPONSES_ADVERTISER = 43
PERIODIC_ADVERTISING_WITH_RESPONSES_SCANNER = 44
HCI_LE_ENCRYPTION_LE_SUPPORTED_FEATURE = 0
HCI_CONNECTION_PARAMETERS_REQUEST_PROCEDURE_LE_SUPPORTED_FEATURE = 1
HCI_EXTENDED_REJECT_INDICATION_LE_SUPPORTED_FEATURE = 2
HCI_PERIPHERAL_INITIATED_FEATURE_EXCHANGE_LE_SUPPORTED_FEATURE = 3
HCI_LE_PING_LE_SUPPORTED_FEATURE = 4
HCI_LE_DATA_PACKET_LENGTH_EXTENSION_LE_SUPPORTED_FEATURE = 5
HCI_LL_PRIVACY_LE_SUPPORTED_FEATURE = 6
HCI_EXTENDED_SCANNER_FILTER_POLICIES_LE_SUPPORTED_FEATURE = 7
HCI_LE_2M_PHY_LE_SUPPORTED_FEATURE = 8
HCI_STABLE_MODULATION_INDEX_TRANSMITTER_LE_SUPPORTED_FEATURE = 9
HCI_STABLE_MODULATION_INDEX_RECEIVER_LE_SUPPORTED_FEATURE = 10
HCI_LE_CODED_PHY_LE_SUPPORTED_FEATURE = 11
HCI_LE_EXTENDED_ADVERTISING_LE_SUPPORTED_FEATURE = 12
HCI_LE_PERIODIC_ADVERTISING_LE_SUPPORTED_FEATURE = 13
HCI_CHANNEL_SELECTION_ALGORITHM_2_LE_SUPPORTED_FEATURE = 14
HCI_LE_POWER_CLASS_1_LE_SUPPORTED_FEATURE = 15
HCI_MINIMUM_NUMBER_OF_USED_CHANNELS_PROCEDURE_LE_SUPPORTED_FEATURE = 16
HCI_CONNECTION_CTE_REQUEST_LE_SUPPORTED_FEATURE = 17
HCI_CONNECTION_CTE_RESPONSE_LE_SUPPORTED_FEATURE = 18
HCI_CONNECTIONLESS_CTE_TRANSMITTER_LE_SUPPORTED_FEATURE = 19
HCI_CONNECTIONLESS_CTR_RECEIVER_LE_SUPPORTED_FEATURE = 20
HCI_ANTENNA_SWITCHING_DURING_CTE_TRANSMISSION_LE_SUPPORTED_FEATURE = 21
HCI_ANTENNA_SWITCHING_DURING_CTE_RECEPTION_LE_SUPPORTED_FEATURE = 22
HCI_RECEIVING_CONSTANT_TONE_EXTENSIONS_LE_SUPPORTED_FEATURE = 23
HCI_PERIODIC_ADVERTISING_SYNC_TRANSFER_SENDER_LE_SUPPORTED_FEATURE = 24
HCI_PERIODIC_ADVERTISING_SYNC_TRANSFER_RECIPIENT_LE_SUPPORTED_FEATURE = 25
HCI_SLEEP_CLOCK_ACCURACY_UPDATES_LE_SUPPORTED_FEATURE = 26
HCI_REMOTE_PUBLIC_KEY_VALIDATION_LE_SUPPORTED_FEATURE = 27
HCI_CONNECTED_ISOCHRONOUS_STREAM_CENTRAL_LE_SUPPORTED_FEATURE = 28
HCI_CONNECTED_ISOCHRONOUS_STREAM_PERIPHERAL_LE_SUPPORTED_FEATURE = 29
HCI_ISOCHRONOUS_BROADCASTER_LE_SUPPORTED_FEATURE = 30
HCI_SYNCHRONIZED_RECEIVER_LE_SUPPORTED_FEATURE = 31
HCI_CONNECTED_ISOCHRONOUS_STREAM_LE_SUPPORTED_FEATURE = 32
HCI_LE_POWER_CONTROL_REQUEST_LE_SUPPORTED_FEATURE = 33
HCI_LE_POWER_CONTROL_REQUEST_DUP_LE_SUPPORTED_FEATURE = 34
HCI_LE_PATH_LOSS_MONITORING_LE_SUPPORTED_FEATURE = 35
HCI_PERIODIC_ADVERTISING_ADI_SUPPORT_LE_SUPPORTED_FEATURE = 36
HCI_CONNECTION_SUBRATING_LE_SUPPORTED_FEATURE = 37
HCI_CONNECTION_SUBRATING_HOST_SUPPORT_LE_SUPPORTED_FEATURE = 38
HCI_CHANNEL_CLASSIFICATION_LE_SUPPORTED_FEATURE = 39
HCI_ADVERTISING_CODING_SELECTION_LE_SUPPORTED_FEATURE = 40
HCI_ADVERTISING_CODING_SELECTION_HOST_SUPPORT_LE_SUPPORTED_FEATURE = 41
HCI_PERIODIC_ADVERTISING_WITH_RESPONSES_ADVERTISER_LE_SUPPORTED_FEATURE = 43
HCI_PERIODIC_ADVERTISING_WITH_RESPONSES_SCANNER_LE_SUPPORTED_FEATURE = 44
class LeFeatureMask(enum.IntFlag):
LE_ENCRYPTION = 1 << LeFeature.LE_ENCRYPTION
CONNECTION_PARAMETERS_REQUEST_PROCEDURE = 1 << LeFeature.CONNECTION_PARAMETERS_REQUEST_PROCEDURE
EXTENDED_REJECT_INDICATION = 1 << LeFeature.EXTENDED_REJECT_INDICATION
PERIPHERAL_INITIATED_FEATURE_EXCHANGE = 1 << LeFeature.PERIPHERAL_INITIATED_FEATURE_EXCHANGE
LE_PING = 1 << LeFeature.LE_PING
LE_DATA_PACKET_LENGTH_EXTENSION = 1 << LeFeature.LE_DATA_PACKET_LENGTH_EXTENSION
LL_PRIVACY = 1 << LeFeature.LL_PRIVACY
EXTENDED_SCANNER_FILTER_POLICIES = 1 << LeFeature.EXTENDED_SCANNER_FILTER_POLICIES
LE_2M_PHY = 1 << LeFeature.LE_2M_PHY
STABLE_MODULATION_INDEX_TRANSMITTER = 1 << LeFeature.STABLE_MODULATION_INDEX_TRANSMITTER
STABLE_MODULATION_INDEX_RECEIVER = 1 << LeFeature.STABLE_MODULATION_INDEX_RECEIVER
LE_CODED_PHY = 1 << LeFeature.LE_CODED_PHY
LE_EXTENDED_ADVERTISING = 1 << LeFeature.LE_EXTENDED_ADVERTISING
LE_PERIODIC_ADVERTISING = 1 << LeFeature.LE_PERIODIC_ADVERTISING
CHANNEL_SELECTION_ALGORITHM_2 = 1 << LeFeature.CHANNEL_SELECTION_ALGORITHM_2
LE_POWER_CLASS_1 = 1 << LeFeature.LE_POWER_CLASS_1
MINIMUM_NUMBER_OF_USED_CHANNELS_PROCEDURE = 1 << LeFeature.MINIMUM_NUMBER_OF_USED_CHANNELS_PROCEDURE
CONNECTION_CTE_REQUEST = 1 << LeFeature.CONNECTION_CTE_REQUEST
CONNECTION_CTE_RESPONSE = 1 << LeFeature.CONNECTION_CTE_RESPONSE
CONNECTIONLESS_CTE_TRANSMITTER = 1 << LeFeature.CONNECTIONLESS_CTE_TRANSMITTER
CONNECTIONLESS_CTR_RECEIVER = 1 << LeFeature.CONNECTIONLESS_CTR_RECEIVER
ANTENNA_SWITCHING_DURING_CTE_TRANSMISSION = 1 << LeFeature.ANTENNA_SWITCHING_DURING_CTE_TRANSMISSION
ANTENNA_SWITCHING_DURING_CTE_RECEPTION = 1 << LeFeature.ANTENNA_SWITCHING_DURING_CTE_RECEPTION
RECEIVING_CONSTANT_TONE_EXTENSIONS = 1 << LeFeature.RECEIVING_CONSTANT_TONE_EXTENSIONS
PERIODIC_ADVERTISING_SYNC_TRANSFER_SENDER = 1 << LeFeature.PERIODIC_ADVERTISING_SYNC_TRANSFER_SENDER
PERIODIC_ADVERTISING_SYNC_TRANSFER_RECIPIENT = 1 << LeFeature.PERIODIC_ADVERTISING_SYNC_TRANSFER_RECIPIENT
SLEEP_CLOCK_ACCURACY_UPDATES = 1 << LeFeature.SLEEP_CLOCK_ACCURACY_UPDATES
REMOTE_PUBLIC_KEY_VALIDATION = 1 << LeFeature.REMOTE_PUBLIC_KEY_VALIDATION
CONNECTED_ISOCHRONOUS_STREAM_CENTRAL = 1 << LeFeature.CONNECTED_ISOCHRONOUS_STREAM_CENTRAL
CONNECTED_ISOCHRONOUS_STREAM_PERIPHERAL = 1 << LeFeature.CONNECTED_ISOCHRONOUS_STREAM_PERIPHERAL
ISOCHRONOUS_BROADCASTER = 1 << LeFeature.ISOCHRONOUS_BROADCASTER
SYNCHRONIZED_RECEIVER = 1 << LeFeature.SYNCHRONIZED_RECEIVER
CONNECTED_ISOCHRONOUS_STREAM = 1 << LeFeature.CONNECTED_ISOCHRONOUS_STREAM
LE_POWER_CONTROL_REQUEST = 1 << LeFeature.LE_POWER_CONTROL_REQUEST
LE_POWER_CONTROL_REQUEST_DUP = 1 << LeFeature.LE_POWER_CONTROL_REQUEST_DUP
LE_PATH_LOSS_MONITORING = 1 << LeFeature.LE_PATH_LOSS_MONITORING
PERIODIC_ADVERTISING_ADI_SUPPORT = 1 << LeFeature.PERIODIC_ADVERTISING_ADI_SUPPORT
CONNECTION_SUBRATING = 1 << LeFeature.CONNECTION_SUBRATING
CONNECTION_SUBRATING_HOST_SUPPORT = 1 << LeFeature.CONNECTION_SUBRATING_HOST_SUPPORT
CHANNEL_CLASSIFICATION = 1 << LeFeature.CHANNEL_CLASSIFICATION
ADVERTISING_CODING_SELECTION = 1 << LeFeature.ADVERTISING_CODING_SELECTION
ADVERTISING_CODING_SELECTION_HOST_SUPPORT = 1 << LeFeature.ADVERTISING_CODING_SELECTION_HOST_SUPPORT
PERIODIC_ADVERTISING_WITH_RESPONSES_ADVERTISER = 1 << LeFeature.PERIODIC_ADVERTISING_WITH_RESPONSES_ADVERTISER
PERIODIC_ADVERTISING_WITH_RESPONSES_SCANNER = 1 << LeFeature.PERIODIC_ADVERTISING_WITH_RESPONSES_SCANNER
HCI_LE_SUPPORTED_FEATURES_NAMES = {
flag: feature_name for (feature_name, flag) in globals().items()
if feature_name.startswith('HCI_') and feature_name.endswith('_LE_SUPPORTED_FEATURE')
}
# fmt: on
@@ -2068,17 +2026,6 @@ class OwnAddressType(enum.IntEnum):
return {'size': 1, 'mapper': lambda x: OwnAddressType(x).name}
# -----------------------------------------------------------------------------
class LoopbackMode(enum.IntEnum):
DISABLED = 0
LOCAL = 1
REMOTE = 2
@classmethod
def type_spec(cls):
return {'size': 1, 'mapper': lambda x: LoopbackMode(x).name}
# -----------------------------------------------------------------------------
class HCI_Packet:
'''
@@ -3319,7 +3266,9 @@ class HCI_Read_Local_Supported_Commands_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command()
@HCI_Command.command(
return_parameters_fields=[('status', STATUS_SPEC), ('lmp_features', 8)]
)
class HCI_Read_Local_Supported_Features_Command(HCI_Command):
'''
See Bluetooth spec @ 7.4.3 Read Local Supported Features Command
@@ -3332,7 +3281,7 @@ class HCI_Read_Local_Supported_Features_Command(HCI_Command):
return_parameters_fields=[
('status', STATUS_SPEC),
('page_number', 1),
('maximum_page_number', 1),
('max_page_number', 1),
('extended_lmp_features', 8),
],
)
@@ -3405,27 +3354,6 @@ class HCI_Read_Encryption_Key_Size_Command(HCI_Command):
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('loopback_mode', LoopbackMode.type_spec()),
],
)
class HCI_Read_Loopback_Mode_Command(HCI_Command):
'''
See Bluetooth spec @ 7.6.1 Read Loopback Mode Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command([('loopback_mode', 1)])
class HCI_Write_Loopback_Mode_Command(HCI_Command):
'''
See Bluetooth spec @ 7.6.2 Write Loopback Mode Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command([('le_event_mask', 8)])
class HCI_LE_Set_Event_Mask_Command(HCI_Command):
@@ -3522,7 +3450,9 @@ class HCI_LE_Set_Advertising_Parameters_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command()
@HCI_Command.command(
return_parameters_fields=[('status', STATUS_SPEC), ('tx_power_level', 1)]
)
class HCI_LE_Read_Advertising_Physical_Channel_Tx_Power_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.6 LE Read Advertising Physical Channel Tx Power Command
@@ -4807,11 +4737,7 @@ class HCI_Event(HCI_Packet):
HCI_Object.init_from_bytes(self, parameters, 0, fields)
return self
def __init__(self, event_code=-1, parameters=None, **kwargs):
# Since the legacy implementation relies on an __init__ injector, typing always
# complains that positional argument event_code is not passed, so here sets a
# default value to allow building derived HCI_Event without event_code.
assert event_code != -1
def __init__(self, event_code, parameters=None, **kwargs):
super().__init__(HCI_Event.event_name(event_code))
if (fields := getattr(self, 'fields', None)) and kwargs:
HCI_Object.init_from_fields(self, fields, kwargs)
@@ -4905,8 +4831,7 @@ class HCI_Extended_Event(HCI_Event):
HCI_Object.init_from_bytes(self, parameters, 1, fields)
return self
def __init__(self, subevent_code=None, parameters=None, **kwargs):
assert subevent_code is not None
def __init__(self, subevent_code, parameters, **kwargs):
self.subevent_code = subevent_code
if parameters is None and (fields := getattr(self, 'fields', None)) and kwargs:
parameters = bytes([subevent_code]) + HCI_Object.dict_to_bytes(
@@ -5908,7 +5833,7 @@ class HCI_Inquiry_Result_With_RSSI_Event(HCI_Event):
('status', STATUS_SPEC),
('connection_handle', 2),
('page_number', 1),
('maximum_page_number', 1),
('max_page_number', 1),
('extended_lmp_features', 8),
]
)
+32 -74
View File
@@ -18,16 +18,10 @@
from __future__ import annotations
from collections.abc import Callable, MutableMapping
from typing import cast, Any, Optional
from typing import cast, Any
import logging
from bumble import avc
from bumble import avctp
from bumble import avdtp
from bumble import avrcp
from bumble import crypto
from bumble import rfcomm
from bumble import sdp
from bumble.colors import color
from bumble.att import ATT_CID, ATT_PDU
from bumble.smp import SMP_CID, SMP_Command
@@ -53,7 +47,9 @@ from bumble.hci import (
HCI_AclDataPacket,
HCI_Disconnection_Complete_Event,
)
from bumble.rfcomm import RFCOMM_Frame, RFCOMM_PSM
from bumble.sdp import SDP_PDU, SDP_PSM
from bumble import crypto
# -----------------------------------------------------------------------------
# Logging
@@ -63,35 +59,28 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
PSM_NAMES = {
rfcomm.RFCOMM_PSM: 'RFCOMM',
sdp.SDP_PSM: 'SDP',
RFCOMM_PSM: 'RFCOMM',
SDP_PSM: 'SDP',
avdtp.AVDTP_PSM: 'AVDTP',
avctp.AVCTP_PSM: 'AVCTP'
# TODO: add more PSM values
}
AVCTP_PID_NAMES = {avrcp.AVRCP_PID: 'AVRCP'}
# -----------------------------------------------------------------------------
class PacketTracer:
class AclStream:
psms: MutableMapping[int, int]
peer: Optional[PacketTracer.AclStream]
peer: PacketTracer.AclStream
avdtp_assemblers: MutableMapping[int, avdtp.MessageAssembler]
avctp_assemblers: MutableMapping[int, avctp.MessageAssembler]
def __init__(self, analyzer: PacketTracer.Analyzer) -> None:
self.analyzer = analyzer
self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid
self.avctp_assemblers = {} # AVCTP assemblers, by source_cid
self.psms = {} # PSM, by source_cid
self.peer = None
# pylint: disable=too-many-nested-blocks
def on_acl_pdu(self, pdu: bytes) -> None:
l2cap_pdu = L2CAP_PDU.from_bytes(pdu)
self.analyzer.emit(l2cap_pdu)
if l2cap_pdu.cid == ATT_CID:
att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload)
@@ -113,51 +102,42 @@ class PacketTracer:
connection_response.result
== L2CAP_Connection_Response.CONNECTION_SUCCESSFUL
):
if self.peer and (
psm := self.peer.psms.get(connection_response.source_cid)
):
# Found a pending connection
self.psms[connection_response.destination_cid] = psm
if self.peer:
if psm := self.peer.psms.get(
connection_response.source_cid
):
# Found a pending connection
self.psms[connection_response.destination_cid] = psm
# For AVDTP connections, create a packet assembler for
# each direction
if psm == avdtp.AVDTP_PSM:
self.avdtp_assemblers[
connection_response.source_cid
] = avdtp.MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
connection_response.destination_cid
] = avdtp.MessageAssembler(
self.peer.on_avdtp_message
)
# For AVDTP connections, create a packet assembler for
# each direction
if psm == avdtp.AVDTP_PSM:
self.avdtp_assemblers[
connection_response.source_cid
] = avdtp.MessageAssembler(self.on_avdtp_message)
self.peer.avdtp_assemblers[
connection_response.destination_cid
] = avdtp.MessageAssembler(self.peer.on_avdtp_message)
elif psm == avctp.AVCTP_PSM:
self.avctp_assemblers[
connection_response.source_cid
] = avctp.MessageAssembler(self.on_avctp_message)
self.peer.avctp_assemblers[
connection_response.destination_cid
] = avctp.MessageAssembler(self.peer.on_avctp_message)
else:
# Try to find the PSM associated with this PDU
if self.peer and (psm := self.peer.psms.get(l2cap_pdu.cid)):
if psm == sdp.SDP_PSM:
sdp_pdu = sdp.SDP_PDU.from_bytes(l2cap_pdu.payload)
if psm == SDP_PSM:
sdp_pdu = SDP_PDU.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(sdp_pdu)
elif psm == rfcomm.RFCOMM_PSM:
rfcomm_frame = rfcomm.RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
elif psm == RFCOMM_PSM:
rfcomm_frame = RFCOMM_Frame.from_bytes(l2cap_pdu.payload)
self.analyzer.emit(rfcomm_frame)
elif psm == avdtp.AVDTP_PSM:
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVDTP]: {l2cap_pdu.payload.hex()}'
)
if avdtp_assembler := self.avdtp_assemblers.get(l2cap_pdu.cid):
avdtp_assembler.on_pdu(l2cap_pdu.payload)
elif psm == avctp.AVCTP_PSM:
self.analyzer.emit(
f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, '
f'PSM=AVCTP]: {l2cap_pdu.payload.hex()}'
)
if avctp_assembler := self.avctp_assemblers.get(l2cap_pdu.cid):
avctp_assembler.on_pdu(l2cap_pdu.payload)
assembler = self.avdtp_assemblers.get(l2cap_pdu.cid)
if assembler:
assembler.on_pdu(l2cap_pdu.payload)
else:
psm_string = name_or_number(PSM_NAMES, psm)
self.analyzer.emit(
@@ -174,28 +154,6 @@ class PacketTracer:
f'{color("AVDTP", "green")} [{transaction_label}] {message}'
)
def on_avctp_message(
self,
transaction_label: int,
is_command: bool,
ipid: bool,
pid: int,
payload: bytes,
):
if pid == avrcp.AVRCP_PID:
avc_frame = avc.Frame.from_bytes(payload)
details = str(avc_frame)
else:
details = payload.hex()
c_r = 'Command' if is_command else 'Response'
self.analyzer.emit(
f'{color("AVCTP", "green")} '
f'{c_r}[{transaction_label}][{name_or_number(AVCTP_PID_NAMES, pid)}] '
f'{"#" if ipid else ""}'
f'{details}'
)
def feed_packet(self, packet: HCI_AclDataPacket) -> None:
self.packet_assembler.feed_packet(packet)
+1 -1
View File
@@ -402,7 +402,7 @@ class Device(HID):
report_type = pdu[0] & 0x03
buffer_flag = (pdu[0] & 0x08) >> 3
report_id = pdu[1]
logger.debug(f"buffer_flag: {buffer_flag}")
logger.debug("buffer_flag: " + str(buffer_flag))
if buffer_flag == 1:
buffer_size = (pdu[3] << 8) | pdu[2]
else:
+81 -162
View File
@@ -18,11 +18,10 @@
from __future__ import annotations
import asyncio
import collections
import dataclasses
import logging
import struct
from typing import Any, Awaitable, Callable, Deque, Dict, Optional, cast, TYPE_CHECKING
from typing import Any, Awaitable, Callable, Dict, Optional, Union, cast, TYPE_CHECKING
from bumble.colors import color
from bumble.l2cap import L2CAP_PDU
@@ -71,7 +70,6 @@ from .hci import (
HCI_Reset_Command,
HCI_Set_Event_Mask_Command,
HCI_SynchronousDataPacket,
LeFeatureMask,
)
from .core import (
BT_BR_EDR_TRANSPORT,
@@ -93,49 +91,16 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class AclPacketQueue:
max_packet_size: int
# Constants
# -----------------------------------------------------------------------------
# fmt: off
def __init__(
self,
max_packet_size: int,
max_in_flight: int,
send: Callable[[HCI_Packet], None],
) -> None:
self.max_packet_size = max_packet_size
self.max_in_flight = max_in_flight
self.in_flight = 0
self.send = send
self.packets: Deque[HCI_AclDataPacket] = collections.deque()
HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1
HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
def enqueue(self, packet: HCI_AclDataPacket) -> None:
self.packets.appendleft(packet)
self.check_queue()
if self.packets:
logger.debug(
f'{self.in_flight} ACL packets in flight, '
f'{len(self.packets)} in queue'
)
def check_queue(self) -> None:
while self.packets and self.in_flight < self.max_in_flight:
packet = self.packets.pop()
self.send(packet)
self.in_flight += 1
def on_packets_completed(self, packet_count: int) -> None:
if packet_count > self.in_flight:
logger.warning(
color(
'!!! {packet_count} completed but only '
f'{self.in_flight} in flight'
)
)
packet_count = self.in_flight
self.in_flight -= packet_count
self.check_queue()
# fmt: on
# -----------------------------------------------------------------------------
@@ -146,13 +111,6 @@ class Connection:
self.peer_address = peer_address
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport
acl_packet_queue: Optional[AclPacketQueue] = (
host.le_acl_packet_queue
if transport == BT_LE_TRANSPORT
else host.acl_packet_queue
)
assert acl_packet_queue
self.acl_packet_queue = acl_packet_queue
def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None:
self.assembler.feed_packet(packet)
@@ -162,27 +120,10 @@ class Connection:
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ScoLink:
peer_address: Address
handle: int
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class CisLink:
peer_address: Address
handle: int
# -----------------------------------------------------------------------------
class Host(AbortableEventEmitter):
connections: Dict[int, Connection]
cis_links: Dict[int, CisLink]
sco_links: Dict[int, ScoLink]
acl_packet_queue: Optional[AclPacketQueue] = None
le_acl_packet_queue: Optional[AclPacketQueue] = None
acl_packet_queue: collections.deque[HCI_AclDataPacket]
hci_sink: Optional[TransportSink] = None
hci_metadata: Dict[str, Any]
long_term_key_provider: Optional[
@@ -200,10 +141,14 @@ class Host(AbortableEventEmitter):
self.hci_metadata = {}
self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle
self.cis_links = {} # CIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle
self.pending_command = None
self.pending_response = None
self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS
self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH
self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS
self.acl_packet_queue = collections.deque()
self.acl_packets_in_flight = 0
self.local_version = None
self.local_supported_commands = bytes(64)
self.local_le_features = 0
@@ -309,54 +254,46 @@ class Host(AbortableEventEmitter):
response = await self.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
hc_acl_data_packet_length = (
self.hc_acl_data_packet_length = (
response.return_parameters.hc_acl_data_packet_length
)
hc_total_num_acl_data_packets = (
self.hc_total_num_acl_data_packets = (
response.return_parameters.hc_total_num_acl_data_packets
)
logger.debug(
'HCI ACL flow control: '
f'hc_acl_data_packet_length={hc_acl_data_packet_length},'
f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}'
f'hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
f'hc_total_num_acl_data_packets={self.hc_total_num_acl_data_packets}'
)
self.acl_packet_queue = AclPacketQueue(
max_packet_size=hc_acl_data_packet_length,
max_in_flight=hc_total_num_acl_data_packets,
send=self.send_hci_packet,
)
hc_le_acl_data_packet_length = 0
hc_total_num_le_acl_data_packets = 0
if self.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
hc_le_acl_data_packet_length = (
self.hc_le_acl_data_packet_length = (
response.return_parameters.hc_le_acl_data_packet_length
)
hc_total_num_le_acl_data_packets = (
self.hc_total_num_le_acl_data_packets = (
response.return_parameters.hc_total_num_le_acl_data_packets
)
logger.debug(
'HCI LE ACL flow control: '
f'hc_le_acl_data_packet_length={hc_le_acl_data_packet_length},'
f'hc_total_num_le_acl_data_packets={hc_total_num_le_acl_data_packets}'
f'hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
'hc_total_num_le_acl_data_packets='
f'{self.hc_total_num_le_acl_data_packets}'
)
if hc_le_acl_data_packet_length == 0 or hc_total_num_le_acl_data_packets == 0:
# LE and Classic share the same queue
self.le_acl_packet_queue = self.acl_packet_queue
else:
# Create a separate queue for LE
self.le_acl_packet_queue = AclPacketQueue(
max_packet_size=hc_le_acl_data_packet_length,
max_in_flight=hc_total_num_le_acl_data_packets,
send=self.send_hci_packet,
)
if (
response.return_parameters.hc_le_acl_data_packet_length == 0
or response.return_parameters.hc_total_num_le_acl_data_packets == 0
):
# LE and Classic share the same values
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length
self.hc_total_num_le_acl_data_packets = (
self.hc_total_num_acl_data_packets
)
if self.supports_command(
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
@@ -395,13 +332,14 @@ class Host(AbortableEventEmitter):
self.hci_metadata = getattr(source, 'metadata', self.hci_metadata)
def send_hci_packet(self, packet: HCI_Packet) -> None:
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {packet}')
if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
if self.hci_sink:
self.hci_sink.on_packet(bytes(packet))
async def send_command(self, command, check_result=False):
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}')
# Wait until we can send (only one pending command at a time)
async with self.command_semaphore:
assert self.pending_command is None
@@ -449,17 +387,6 @@ class Host(AbortableEventEmitter):
asyncio.create_task(send_command(command))
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
if not (connection := self.connections.get(connection_handle)):
logger.warning(f'connection 0x{connection_handle:04X} not found')
return
packet_queue = connection.acl_packet_queue
if packet_queue is None:
logger.warning(
f'no ACL packet queue for connection 0x{connection_handle:04X}'
)
return
# Create a PDU
l2cap_pdu = bytes(L2CAP_PDU(cid, pdu))
# Send the data to the controller via ACL packets
@@ -467,7 +394,8 @@ class Host(AbortableEventEmitter):
offset = 0
pb_flag = 0
while bytes_remaining:
data_total_length = min(bytes_remaining, packet_queue.max_packet_size)
# TODO: support different LE/Classic lengths
data_total_length = min(bytes_remaining, self.hc_le_acl_data_packet_length)
acl_packet = HCI_AclDataPacket(
connection_handle=connection_handle,
pb_flag=pb_flag,
@@ -475,12 +403,34 @@ class Host(AbortableEventEmitter):
data_total_length=data_total_length,
data=l2cap_pdu[offset : offset + data_total_length],
)
logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}')
packet_queue.enqueue(acl_packet)
logger.debug(
f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}'
)
self.queue_acl_packet(acl_packet)
pb_flag = 1
offset += data_total_length
bytes_remaining -= data_total_length
def queue_acl_packet(self, acl_packet: HCI_AclDataPacket) -> None:
self.acl_packet_queue.appendleft(acl_packet)
self.check_acl_packet_queue()
if len(self.acl_packet_queue):
logger.debug(
f'{self.acl_packets_in_flight} ACL packets in flight, '
f'{len(self.acl_packet_queue)} in queue'
)
def check_acl_packet_queue(self) -> None:
# Send all we can (TODO: support different LE/Classic limits)
while (
len(self.acl_packet_queue) > 0
and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets
):
packet = self.acl_packet_queue.pop()
self.acl_packets_in_flight += 1
self.send_hci_packet(packet)
def supports_command(self, command):
# Find the support flag position for this command
for octet, flags in enumerate(HCI_SUPPORTED_COMMANDS_FLAGS):
@@ -507,8 +457,8 @@ class Host(AbortableEventEmitter):
return commands
def supports_le_features(self, feature: LeFeatureMask) -> bool:
return (self.local_le_features & feature) == feature
def supports_le_feature(self, feature):
return (self.local_le_features & (1 << feature)) != 0
@property
def supported_le_features(self):
@@ -603,7 +553,7 @@ class Host(AbortableEventEmitter):
# This is used just for the Num_HCI_Command_Packets field, not related to
# an actual command
logger.debug('no-command event')
return
return None
return self.on_command_processed(event)
@@ -611,17 +561,18 @@ class Host(AbortableEventEmitter):
return self.on_command_processed(event)
def on_hci_number_of_completed_packets_event(self, event):
for connection_handle, num_completed_packets in zip(
event.connection_handles, event.num_completed_packets
):
if not (connection := self.connections.get(connection_handle)):
logger.warning(
'received packet completion event for unknown handle '
f'0x{connection_handle:04X}'
total_packets = sum(event.num_completed_packets)
if total_packets <= self.acl_packets_in_flight:
self.acl_packets_in_flight -= total_packets
self.check_acl_packet_queue()
else:
logger.warning(
color(
f'!!! {total_packets} completed but only '
f'{self.acl_packets_in_flight} in flight'
)
continue
connection.acl_packet_queue.on_packets_completed(num_completed_packets)
)
self.acl_packets_in_flight = 0
# Classic only
def on_hci_connection_request_event(self, event):
@@ -715,36 +666,25 @@ class Host(AbortableEventEmitter):
def on_hci_disconnection_complete_event(self, event):
# Find the connection
handle = event.connection_handle
if (
connection := (
self.connections.get(handle)
or self.cis_links.get(handle)
or self.sco_links.get(handle)
)
) is None:
if (connection := self.connections.get(event.connection_handle)) is None:
logger.warning('!!! DISCONNECTION COMPLETE: unknown handle')
return
if event.status == HCI_SUCCESS:
logger.debug(
f'### DISCONNECTION: [0x{handle:04X}] '
f'### DISCONNECTION: [0x{event.connection_handle:04X}] '
f'{connection.peer_address} '
f'reason={event.reason}'
)
del self.connections[event.connection_handle]
# Notify the listeners
self.emit('disconnection', handle, event.reason)
(
self.connections.pop(handle, 0)
or self.cis_links.pop(handle, 0)
or self.sco_links.pop(handle, 0)
)
self.emit('disconnection', event.connection_handle, event.reason)
else:
logger.debug(f'### DISCONNECTION FAILED: {event.status}')
# Notify the listeners
self.emit('disconnection_failure', handle, event.status)
self.emit('disconnection_failure', event.connection_handle, event.status)
def on_hci_le_connection_update_complete_event(self, event):
if (connection := self.connections.get(event.connection_handle)) is None:
@@ -805,10 +745,6 @@ class Host(AbortableEventEmitter):
def on_hci_le_cis_established_event(self, event):
# The remaining parameters are unused for now.
if event.status == HCI_SUCCESS:
self.cis_links[event.connection_handle] = CisLink(
handle=event.connection_handle,
peer_address=Address.ANY,
)
self.emit('cis_establishment', event.connection_handle)
else:
self.emit(
@@ -875,11 +811,6 @@ class Host(AbortableEventEmitter):
f'{event.bd_addr}'
)
self.sco_links[event.connection_handle] = ScoLink(
peer_address=event.bd_addr,
handle=event.connection_handle,
)
# Notify the client
self.emit(
'sco_connection',
@@ -1073,15 +1004,3 @@ class Host(AbortableEventEmitter):
event.bd_addr,
event.host_supported_features,
)
def on_hci_le_read_remote_features_complete_event(self, event):
if event.status != HCI_SUCCESS:
self.emit(
'le_remote_features_failure', event.connection_handle, event.status
)
else:
self.emit(
'le_remote_features',
event.connection_handle,
int.from_bytes(event.le_features, 'little'),
)
+5 -10
View File
@@ -149,10 +149,9 @@ L2CAP_INVALID_CID_IN_REQUEST_REASON = 0x0002
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2046
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256
@@ -189,11 +188,8 @@ class LeCreditBasedChannelSpec:
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
):
raise ValueError('max credits out of range')
if (
self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
):
raise ValueError('MTU out of range')
if self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
raise ValueError('MTU too small')
if (
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
@@ -1648,13 +1644,12 @@ class ChannelManager:
def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None:
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
pdu_bytes = bytes(pdu)
logger.debug(
f'{color(">>> Sending L2CAP PDU", "blue")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
f'{connection.peer_address}: {len(pdu_bytes)} bytes, {pdu_str}'
f'{connection.peer_address}: {pdu_str}'
)
self.host.send_l2cap_pdu(connection.handle, cid, pdu_bytes)
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
+33 -127
View File
@@ -26,13 +26,9 @@ from bumble.hci import (
HCI_SUCCESS,
HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR,
HCI_CONNECTION_TIMEOUT_ERROR,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
HCI_PAGE_TIMEOUT_ERROR,
HCI_Connection_Complete_Event,
)
from bumble import controller
from typing import Optional, Set
# -----------------------------------------------------------------------------
# Logging
@@ -61,8 +57,6 @@ class LocalLink:
Link bus for controllers to communicate with each other
'''
controllers: Set[controller.Controller]
def __init__(self):
self.controllers = set()
self.pending_connection = None
@@ -85,9 +79,7 @@ class LocalLink:
return controller
return None
def find_classic_controller(
self, address: Address
) -> Optional[controller.Controller]:
def find_classic_controller(self, address):
for controller in self.controllers:
if controller.public_address == address:
return controller
@@ -103,11 +95,21 @@ class LocalLink:
def on_address_changed(self, controller):
pass
def send_advertising_data(self, sender_address, data):
def send_advertising_data(self, sender_address, data, scan_response):
# Send the advertising data to all controllers, except the sender
for controller in self.controllers:
if controller.random_address != sender_address:
controller.on_link_advertising_data(sender_address, data)
controller.on_link_advertising_data(sender_address, data, scan_response)
def send_extended_advertising_data(
self, sender_address, event_type, data, scan_response
):
# Send the advertising data to all controllers, except the sender
for controller in self.controllers:
if controller.random_address != sender_address:
controller.on_link_extended_advertising_data(
sender_address, event_type, data, scan_response
)
def send_acl_data(self, sender_controller, destination_address, transport, data):
# Send the data to the first controller with a matching address
@@ -159,30 +161,34 @@ class LocalLink:
asyncio.get_running_loop().call_soon(self.on_connection_complete)
def on_disconnection_complete(
self, central_address, peripheral_address, disconnect_command
self, initiator_address, peer_address, disconnect_command
):
# Find the controller that initiated the disconnection
if not (central_controller := self.find_controller(central_address)):
if not (initiator_controller := self.find_controller(initiator_address)):
logger.warning('!!! Initiating controller not found')
return
# Disconnect from the first controller with a matching address
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_central_disconnected(
central_address, disconnect_command.reason
if peer_controller := self.find_controller(peer_address):
peer_controller.on_link_peer_disconnected(
initiator_address, disconnect_command.reason
)
central_controller.on_link_peripheral_disconnection_complete(
initiator_controller.on_link_initiated_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, central_address, peripheral_address, disconnect_command):
def disconnect(self, initiator_address, peer_address, disconnect_command):
logger.debug(
f'$$$ DISCONNECTION {central_address} -> '
f'{peripheral_address}: reason = {disconnect_command.reason}'
f'$$$ DISCONNECTION {initiator_address} -> '
f'{peer_address}: reason = {disconnect_command.reason}'
)
asyncio.get_running_loop().call_soon(
self.on_disconnection_complete,
initiator_address,
peer_address,
disconnect_command,
)
args = [central_address, peripheral_address, disconnect_command]
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)
# pylint: disable=too-many-arguments
def on_connection_encrypted(
@@ -196,60 +202,6 @@ class LocalLink:
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk)
def create_cis(
self,
central_controller: controller.Controller,
peripheral_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}'
)
if peripheral_controller := self.find_controller(peripheral_address):
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_request,
central_controller.random_address,
cig_id,
cis_id,
)
def accept_cis(
self,
peripheral_controller: controller.Controller,
central_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}'
)
if central_controller := self.find_controller(central_address):
asyncio.get_running_loop().call_soon(
central_controller.on_link_cis_established, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peripheral_controller.on_link_cis_established, cig_id, cis_id
)
def disconnect_cis(
self,
initiator_controller: controller.Controller,
peer_address: Address,
cig_id: int,
cis_id: int,
) -> None:
logger.debug(
f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}'
)
if peer_controller := self.find_controller(peer_address):
asyncio.get_running_loop().call_soon(
initiator_controller.on_link_cis_disconnected, cig_id, cis_id
)
asyncio.get_running_loop().call_soon(
peer_controller.on_link_cis_disconnected, cig_id, cis_id
)
############################################################
# Classic handlers
############################################################
@@ -333,52 +285,6 @@ class LocalLink:
initiator_controller.public_address, int(not (initiator_new_role))
)
def classic_sco_connect(
self,
initiator_controller: controller.Controller,
responder_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {initiator_controller.public_address} connects SCO to {responder_address}'
)
responder_controller = self.find_classic_controller(responder_address)
# Initiator controller should handle it.
assert responder_controller
responder_controller.on_classic_connection_request(
initiator_controller.public_address,
link_type,
)
def classic_accept_sco_connection(
self,
responder_controller: controller.Controller,
initiator_address: Address,
link_type: int,
):
logger.debug(
f'[Classic] {responder_controller.public_address} accepts to connect SCO {initiator_address}'
)
initiator_controller = self.find_classic_controller(initiator_address)
if initiator_controller is None:
responder_controller.on_classic_sco_connection_complete(
responder_controller.public_address,
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
link_type,
)
return
async def task():
initiator_controller.on_classic_sco_connection_complete(
responder_controller.public_address, HCI_SUCCESS, link_type
)
asyncio.create_task(task())
responder_controller.on_classic_sco_connection_complete(
initiator_controller.public_address, HCI_SUCCESS, link_type
)
# -----------------------------------------------------------------------------
class RemoteLink:
@@ -468,11 +374,11 @@ class RemoteLink:
async def on_left_received(self, address):
if address in self.central_connections:
self.controller.on_link_peripheral_disconnected(Address(address))
self.controller.on_link_connection_lost(Address(address))
self.central_connections.remove(address)
if address in self.peripheral_connections:
self.controller.on_link_central_disconnected(
self.controller.on_link_peer_disconnected(
address, HCI_CONNECTION_TIMEOUT_ERROR
)
self.peripheral_connections.remove(address)
@@ -492,7 +398,7 @@ class RemoteLink:
async def on_advertisement_message_received(self, sender, advertisement):
try:
self.controller.on_link_advertising_data(
Address(sender), bytes.fromhex(advertisement)
Address(sender), bytes.fromhex(advertisement), b''
)
except Exception:
logger.exception('exception')
@@ -532,7 +438,7 @@ class RemoteLink:
# Notify the controller
params = parse_parameters(message)
reason = int(params.get('reason', str(HCI_CONNECTION_TIMEOUT_ERROR)))
self.controller.on_link_central_disconnected(Address(sender), reason)
self.controller.on_link_peer_disconnected(Address(sender), reason)
# Forget the connection
if sender in self.peripheral_connections:
@@ -579,7 +485,7 @@ class RemoteLink:
async def send_advertising_data_to_relay(self, data):
await self.send_targeted_message('*', f'advertisement:{data.hex()}')
def send_advertising_data(self, _, data):
def send_advertising_data(self, _, data, scan_response):
self.execute(partial(self.send_advertising_data_to_relay, data))
async def send_acl_data_to_relay(self, peer_address, data):
+4 -5
View File
@@ -285,11 +285,10 @@ class HostService(HostServicer):
raise NotImplementedError(
"TODO: add support for extended advertising in Bumble"
)
if advertising_interval := request.interval:
self.device.config.advertising_interval_min = int(advertising_interval)
self.device.config.advertising_interval_max = int(advertising_interval)
if interval_range := request.interval_range:
self.device.config.advertising_interval_max += int(interval_range)
if request.interval:
raise NotImplementedError("TODO: add support for `request.interval`")
if request.interval_range:
raise NotImplementedError("TODO: add support for `request.interval_range`")
if request.primary_phy:
raise NotImplementedError("TODO: add support for `request.primary_phy`")
if request.secondary_phy:
+4 -4
View File
@@ -110,7 +110,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(just_works=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore
answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm
@@ -125,7 +125,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(numeric_comparison=number))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore
answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event
assert answer.answer_variant() == 'confirm' and answer.confirm is not None
return answer.confirm
@@ -140,7 +140,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(passkey_entry_request=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore
answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event
if answer.answer_variant() is None:
return None
@@ -157,7 +157,7 @@ class PairingDelegate(BasePairingDelegate):
event = self.add_origin(PairingEvent(pin_code_request=empty_pb2.Empty()))
self.service.event_queue.put_nowait(event)
answer = await anext(self.service.event_answer) # type: ignore
answer = await anext(self.service.event_answer) # pytype: disable=name-error
assert answer.event == event
if answer.answer_variant() is None:
return None
+2 -2
View File
@@ -18,7 +18,7 @@
# -----------------------------------------------------------------------------
import struct
import logging
from typing import List, Optional
from typing import List
from bumble import l2cap
from ..core import AdvertisingData
@@ -67,7 +67,7 @@ class AshaService(TemplateService):
self.emit('volume', connection, value[0])
# Handler for audio control commands
def on_audio_control_point_write(connection: Optional[Connection], value):
def on_audio_control_point_write(connection: Connection, value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == AshaService.OPCODE_START:
+3 -3
View File
@@ -114,7 +114,7 @@ class SamplingFrequency(enum.IntEnum):
'''Bluetooth Assigned Numbers, Section 6.12.5.1 - Sampling Frequency'''
# fmt: off
FREQ_8000 = 0x01
FREQ_8000 = 0x01
FREQ_11025 = 0x02
FREQ_16000 = 0x03
FREQ_22050 = 0x04
@@ -430,7 +430,7 @@ class AseResponseCode(enum.IntEnum):
REJECTED_METADATA = 0x0B
INVALID_METADATA = 0x0C
INSUFFICIENT_RESOURCES = 0x0D
UNSPECIFIED_ERROR = 0x0E
UNSPECIFIED_ERROR = 0x0E
class AseReasonCode(enum.IntEnum):
@@ -1066,7 +1066,7 @@ class AseStateMachine(gatt.Characteristic):
# Readonly. Do nothing in the setter.
pass
def on_read(self, _: Optional[device.Connection]) -> bytes:
def on_read(self, _: device.Connection) -> bytes:
return self.value
def __str__(self) -> str:
+8 -60
View File
@@ -19,7 +19,7 @@
from __future__ import annotations
import enum
import struct
from typing import Optional, Tuple
from typing import Optional
from bumble import core
from bumble import crypto
@@ -31,9 +31,6 @@ from bumble import gatt_client
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
SET_IDENTITY_RESOLVING_KEY_LENGTH = 16
class SirkType(enum.IntEnum):
'''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''
@@ -69,10 +66,6 @@ def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
def sef(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.5 SIRK encryption function sef.
SIRK decryption function sdf shares the same algorithm. The only difference is that argument r is:
* Plaintext in encryption
* Cipher in decryption
'''
return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)
@@ -112,11 +105,6 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
set_member_lock: Optional[MemberLock] = None,
set_member_rank: Optional[int] = None,
) -> None:
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
raise ValueError(
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
)
characteristics = []
self.set_identity_resolving_key = set_identity_resolving_key
@@ -125,7 +113,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
permissions=gatt.Characteristic.Permissions.READABLE,
value=gatt.CharacteristicValue(read=self.on_sirk_read),
)
characteristics.append(self.set_identity_resolving_key_characteristic)
@@ -135,7 +123,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('B', coordinated_set_size),
)
characteristics.append(self.coordinated_set_size_characteristic)
@@ -146,7 +134,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
| gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
permissions=gatt.Characteristic.Permissions.READABLE
| gatt.Characteristic.Permissions.WRITEABLE,
value=struct.pack('B', set_member_lock),
)
@@ -157,32 +145,18 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('B', set_member_rank),
)
characteristics.append(self.set_member_rank_characteristic)
super().__init__(characteristics)
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
def on_sirk_read(self, _connection: device.Connection) -> bytes:
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
sirk_bytes = self.set_identity_resolving_key
return bytes([SirkType.PLAINTEXT]) + self.set_identity_resolving_key
else:
assert connection
if connection.transport == core.BT_LE_TRANSPORT:
key = await connection.device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await connection.device.get_link_key(connection.peer_address)
if not key:
raise RuntimeError('LTK or LinkKey is not present')
sirk_bytes = sef(key, self.set_identity_resolving_key)
return bytes([self.set_identity_resolving_key_type]) + sirk_bytes
raise NotImplementedError('TODO: Pending async Characteristic read.')
def get_advertising_data(self) -> bytes:
return bytes(
@@ -229,29 +203,3 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
):
self.set_member_rank = characteristics[0]
async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
'''Reads SIRK and decrypts if encrypted.'''
response = await self.set_identity_resolving_key.read_value()
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
raise RuntimeError('Invalid SIRK value')
sirk_type = SirkType(response[0])
if sirk_type == SirkType.PLAINTEXT:
sirk = response[1:]
else:
connection = self.service_proxy.client.connection
device = connection.device
if connection.transport == core.BT_LE_TRANSPORT:
key = await device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await device.get_link_key(connection.peer_address)
if not key:
raise RuntimeError('LTK or LinkKey is not present')
sirk = sef(key, response[1:])
return (sirk_type, sirk)
+119 -74
View File
@@ -19,7 +19,6 @@ from __future__ import annotations
import logging
import asyncio
import dataclasses
import enum
from typing import Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
@@ -61,18 +60,27 @@ logger = logging.getLogger(__name__)
RFCOMM_PSM = 0x0003
class FrameType(enum.IntEnum):
SABM = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
UA = 0x63 # Control field [0,1,1,0,_,0,1,1] LSB-first
DM = 0x0F # Control field [1,1,1,1,_,0,0,0] LSB-first
DISC = 0x43 # Control field [0,1,0,_,0,0,1,1] LSB-first
UIH = 0xEF # Control field [1,1,1,_,1,1,1,1] LSB-first
UI = 0x03 # Control field [0,0,0,_,0,0,1,1] LSB-first
class MccType(enum.IntEnum):
PN = 0x20
MSC = 0x38
# Frame types
RFCOMM_SABM_FRAME = 0x2F # Control field [1,1,1,1,_,1,0,0] LSB-first
RFCOMM_UA_FRAME = 0x63 # Control field [0,1,1,0,_,0,1,1] LSB-first
RFCOMM_DM_FRAME = 0x0F # Control field [1,1,1,1,_,0,0,0] LSB-first
RFCOMM_DISC_FRAME = 0x43 # Control field [0,1,0,_,0,0,1,1] LSB-first
RFCOMM_UIH_FRAME = 0xEF # Control field [1,1,1,_,1,1,1,1] LSB-first
RFCOMM_UI_FRAME = 0x03 # Control field [0,0,0,_,0,0,1,1] LSB-first
RFCOMM_FRAME_TYPE_NAMES = {
RFCOMM_SABM_FRAME: 'SABM',
RFCOMM_UA_FRAME: 'UA',
RFCOMM_DM_FRAME: 'DM',
RFCOMM_DISC_FRAME: 'DISC',
RFCOMM_UIH_FRAME: 'UIH',
RFCOMM_UI_FRAME: 'UI'
}
# MCC Types
RFCOMM_MCC_PN_TYPE = 0x20
RFCOMM_MCC_MSC_TYPE = 0x38
# FCS CRC
CRC_TABLE = bytes([
@@ -110,8 +118,8 @@ CRC_TABLE = bytes([
0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
])
RFCOMM_DEFAULT_WINDOW_SIZE = 7
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
RFCOMM_DEFAULT_INITIAL_RX_CREDITS = 7
RFCOMM_DEFAULT_PREFERRED_MTU = 1280
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
@@ -175,7 +183,7 @@ def compute_fcs(buffer: bytes) -> int:
class RFCOMM_Frame:
def __init__(
self,
frame_type: FrameType,
frame_type: int,
c_r: int,
dlci: int,
p_f: int,
@@ -198,11 +206,14 @@ class RFCOMM_Frame:
self.length = bytes([(length << 1) | 1])
self.address = (dlci << 2) | (c_r << 1) | 1
self.control = frame_type | (p_f << 4)
if frame_type == FrameType.UIH:
if frame_type == RFCOMM_UIH_FRAME:
self.fcs = compute_fcs(bytes([self.address, self.control]))
else:
self.fcs = compute_fcs(bytes([self.address, self.control]) + self.length)
def type_name(self) -> str:
return RFCOMM_FRAME_TYPE_NAMES[self.type]
@staticmethod
def parse_mcc(data) -> Tuple[int, bool, bytes]:
mcc_type = data[0] >> 2
@@ -226,24 +237,24 @@ class RFCOMM_Frame:
@staticmethod
def sabm(c_r: int, dlci: int):
return RFCOMM_Frame(FrameType.SABM, c_r, dlci, 1)
return RFCOMM_Frame(RFCOMM_SABM_FRAME, c_r, dlci, 1)
@staticmethod
def ua(c_r: int, dlci: int):
return RFCOMM_Frame(FrameType.UA, c_r, dlci, 1)
return RFCOMM_Frame(RFCOMM_UA_FRAME, c_r, dlci, 1)
@staticmethod
def dm(c_r: int, dlci: int):
return RFCOMM_Frame(FrameType.DM, c_r, dlci, 1)
return RFCOMM_Frame(RFCOMM_DM_FRAME, c_r, dlci, 1)
@staticmethod
def disc(c_r: int, dlci: int):
return RFCOMM_Frame(FrameType.DISC, c_r, dlci, 1)
return RFCOMM_Frame(RFCOMM_DISC_FRAME, c_r, dlci, 1)
@staticmethod
def uih(c_r: int, dlci: int, information: bytes, p_f: int = 0):
return RFCOMM_Frame(
FrameType.UIH, c_r, dlci, p_f, information, with_credits=(p_f == 1)
RFCOMM_UIH_FRAME, c_r, dlci, p_f, information, with_credits=(p_f == 1)
)
@staticmethod
@@ -251,7 +262,7 @@ class RFCOMM_Frame:
# Extract fields
dlci = (data[0] >> 2) & 0x3F
c_r = (data[0] >> 1) & 0x01
frame_type = FrameType(data[1] & 0xEF)
frame_type = data[1] & 0xEF
p_f = (data[1] >> 4) & 0x01
length = data[2]
if length & 0x01:
@@ -280,7 +291,7 @@ class RFCOMM_Frame:
def __str__(self) -> str:
return (
f'{color(self.type.name, "yellow")}'
f'{color(self.type_name(), "yellow")}'
f'(c/r={self.c_r},'
f'dlci={self.dlci},'
f'p/f={self.p_f},'
@@ -290,7 +301,6 @@ class RFCOMM_Frame:
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_PN:
dlci: int
cl: int
@@ -300,11 +310,23 @@ class RFCOMM_MCC_PN:
max_retransmissions: int
window_size: int
def __post_init__(self) -> None:
if self.window_size < 1 or self.window_size > 7:
logger.warning(
f'Error Recovery Window size {self.window_size} is out of range [1, 7].'
)
def __init__(
self,
dlci: int,
cl: int,
priority: int,
ack_timer: int,
max_frame_size: int,
max_retransmissions: int,
window_size: int,
) -> None:
self.dlci = dlci
self.cl = cl
self.priority = priority
self.ack_timer = ack_timer
self.max_frame_size = max_frame_size
self.max_retransmissions = max_retransmissions
self.window_size = window_size
@staticmethod
def from_bytes(data: bytes) -> RFCOMM_MCC_PN:
@@ -315,7 +337,7 @@ class RFCOMM_MCC_PN:
ack_timer=data[3],
max_frame_size=data[4] | data[5] << 8,
max_retransmissions=data[6],
window_size=data[7] & 0x07,
window_size=data[7],
)
def __bytes__(self) -> bytes:
@@ -328,14 +350,23 @@ class RFCOMM_MCC_PN:
self.max_frame_size & 0xFF,
(self.max_frame_size >> 8) & 0xFF,
self.max_retransmissions & 0xFF,
# Only 3 bits are meaningful.
self.window_size & 0x07,
self.window_size & 0xFF,
]
)
def __str__(self) -> str:
return (
f'PN(dlci={self.dlci},'
f'cl={self.cl},'
f'priority={self.priority},'
f'ack_timer={self.ack_timer},'
f'max_frame_size={self.max_frame_size},'
f'max_retransmissions={self.max_retransmissions},'
f'window_size={self.window_size})'
)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class RFCOMM_MCC_MSC:
dlci: int
fc: int
@@ -344,6 +375,16 @@ class RFCOMM_MCC_MSC:
ic: int
dv: int
def __init__(
self, dlci: int, fc: int, rtc: int, rtr: int, ic: int, dv: int
) -> None:
self.dlci = dlci
self.fc = fc
self.rtc = rtc
self.rtr = rtr
self.ic = ic
self.dv = dv
@staticmethod
def from_bytes(data: bytes) -> RFCOMM_MCC_MSC:
return RFCOMM_MCC_MSC(
@@ -368,6 +409,16 @@ class RFCOMM_MCC_MSC:
]
)
def __str__(self) -> str:
return (
f'MSC(dlci={self.dlci},'
f'fc={self.fc},'
f'rtc={self.rtc},'
f'rtr={self.rtr},'
f'ic={self.ic},'
f'dv={self.dv})'
)
# -----------------------------------------------------------------------------
class DLC(EventEmitter):
@@ -387,24 +438,20 @@ class DLC(EventEmitter):
multiplexer: Multiplexer,
dlci: int,
max_frame_size: int,
window_size: int,
initial_tx_credits: int,
) -> None:
super().__init__()
self.multiplexer = multiplexer
self.dlci = dlci
self.max_frame_size = max_frame_size
self.window_size = window_size
self.rx_credits = window_size
self.rx_threshold = window_size // 2
self.tx_credits = window_size
self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
self.rx_threshold = self.rx_credits // 2
self.tx_credits = initial_tx_credits
self.tx_buffer = b''
self.state = DLC.State.INIT
self.role = multiplexer.role
self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
self.sink = None
self.connection_result = None
self.drained = asyncio.Event()
self.drained.set()
# Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs
@@ -420,7 +467,7 @@ class DLC(EventEmitter):
self.multiplexer.send_frame(frame)
def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame)
def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
@@ -434,7 +481,9 @@ class DLC(EventEmitter):
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
@@ -450,7 +499,9 @@ class DLC(EventEmitter):
# Exchange the modem status with the peer
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=1, data=bytes(msc))
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=1, data=bytes(msc)
)
logger.debug(f'>>> MCC MSC Command: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
@@ -483,15 +534,14 @@ class DLC(EventEmitter):
f'[{self.dlci}] {len(data)} bytes, '
f'rx_credits={self.rx_credits}: {data.hex()}'
)
if data:
if self.sink:
self.sink(data) # pylint: disable=not-callable
if len(data) and self.sink:
self.sink(data) # pylint: disable=not-callable
# Update the credits
if self.rx_credits > 0:
self.rx_credits -= 1
else:
logger.warning(color('!!! received frame with no rx credits', 'red'))
# Update the credits
if self.rx_credits > 0:
self.rx_credits -= 1
else:
logger.warning(color('!!! received frame with no rx credits', 'red'))
# Check if there's anything to send (including credits)
self.process_tx()
@@ -504,7 +554,9 @@ class DLC(EventEmitter):
# Command
logger.debug(f'<<< MCC MSC Command: {msc}')
msc = RFCOMM_MCC_MSC(dlci=self.dlci, fc=0, rtc=1, rtr=1, ic=0, dv=1)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.MSC, c_r=0, data=bytes(msc))
mcc = RFCOMM_Frame.make_mcc(
mcc_type=RFCOMM_MCC_MSC_TYPE, c_r=0, data=bytes(msc)
)
logger.debug(f'>>> MCC MSC Response: {msc}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
else:
@@ -528,18 +580,18 @@ class DLC(EventEmitter):
cl=0xE0,
priority=7,
ack_timer=0,
max_frame_size=self.max_frame_size,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions=0,
window_size=self.window_size,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.PN, c_r=0, data=bytes(pn))
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
logger.debug(f'>>> PN Response: {pn}')
self.send_frame(RFCOMM_Frame.uih(c_r=self.c_r, dlci=0, information=mcc))
self.change_state(DLC.State.CONNECTING)
def rx_credits_needed(self) -> int:
if self.rx_credits <= self.rx_threshold:
return self.window_size - self.rx_credits
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
return 0
@@ -579,8 +631,6 @@ class DLC(EventEmitter):
)
rx_credits_needed = 0
if not self.tx_buffer:
self.drained.set()
# Stream protocol
def write(self, data: Union[bytes, str]) -> None:
@@ -593,11 +643,11 @@ class DLC(EventEmitter):
raise ValueError('write only accept bytes or strings')
self.tx_buffer += data
self.drained.clear()
self.process_tx()
async def drain(self) -> None:
await self.drained.wait()
def drain(self) -> None:
# TODO
pass
def __str__(self) -> str:
return f'DLC(dlci={self.dlci},state={self.state.name})'
@@ -654,7 +704,7 @@ class Multiplexer(EventEmitter):
if frame.dlci == 0:
self.on_frame(frame)
else:
if frame.type == FrameType.DM:
if frame.type == RFCOMM_DM_FRAME:
# DM responses are for a DLCI, but since we only create the dlc when we
# receive a PN response (because we need the parameters), we handle DM
# frames at the Multiplexer level
@@ -667,7 +717,7 @@ class Multiplexer(EventEmitter):
dlc.on_frame(frame)
def on_frame(self, frame: RFCOMM_Frame) -> None:
handler = getattr(self, f'on_{frame.type.name}_frame'.lower())
handler = getattr(self, f'on_{frame.type_name()}_frame'.lower())
handler(frame)
def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:
@@ -715,10 +765,10 @@ class Multiplexer(EventEmitter):
def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
(mcc_type, c_r, value) = RFCOMM_Frame.parse_mcc(frame.information)
if mcc_type == MccType.PN:
if mcc_type == RFCOMM_MCC_PN_TYPE:
pn = RFCOMM_MCC_PN.from_bytes(value)
self.on_mcc_pn(c_r, pn)
elif mcc_type == MccType.MSC:
elif mcc_type == RFCOMM_MCC_MSC_TYPE:
mcs = RFCOMM_MCC_MSC.from_bytes(value)
self.on_mcc_msc(c_r, mcs)
@@ -793,12 +843,7 @@ class Multiplexer(EventEmitter):
)
await self.disconnection_result
async def open_dlc(
self,
channel: int,
max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
window_size: int = RFCOMM_DEFAULT_WINDOW_SIZE,
) -> DLC:
async def open_dlc(self, channel: int) -> DLC:
if self.state != Multiplexer.State.CONNECTED:
if self.state == Multiplexer.State.OPENING:
raise InvalidStateError('open already in progress')
@@ -810,11 +855,11 @@ class Multiplexer(EventEmitter):
cl=0xF0,
priority=7,
ack_timer=0,
max_frame_size=max_frame_size,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_retransmissions=0,
window_size=window_size,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=MccType.PN, c_r=1, data=bytes(pn))
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
logger.debug(f'>>> Sending MCC: {pn}')
self.open_result = asyncio.get_running_loop().create_future()
self.change_state(Multiplexer.State.OPENING)
+2 -4
View File
@@ -97,8 +97,7 @@ SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID = 0X000B
SDP_ICON_URL_ATTRIBUTE_ID = 0X000C
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D
# Profile-specific Attribute Identifiers (cf. Assigned Numbers for Service Discovery)
# Attribute Identifier (cf. Assigned Numbers for Service Discovery)
# used by AVRCP, HFP and A2DP
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311
@@ -116,8 +115,7 @@ SDP_ATTRIBUTE_ID_NAMES = {
SDP_DOCUMENTATION_URL_ATTRIBUTE_ID: 'SDP_DOCUMENTATION_URL_ATTRIBUTE_ID',
SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID: 'SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID',
SDP_ICON_URL_ATTRIBUTE_ID: 'SDP_ICON_URL_ATTRIBUTE_ID',
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID',
SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID: 'SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID',
SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID'
}
SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot')
+7 -7
View File
@@ -1134,10 +1134,8 @@ class Session:
async def get_link_key_and_derive_ltk(self) -> None:
'''Retrieves BR/EDR Link Key from storage and derive it to LE LTK.'''
self.link_key = await self.manager.device.get_link_key(
self.connection.peer_address
)
if self.link_key is None:
link_key = await self.manager.device.get_link_key(self.connection.peer_address)
if link_key is None:
logging.warning(
'Try to derive LTK but host does not have the LK. Send a SMP_PAIRING_FAILED but the procedure will not be paused!'
)
@@ -1145,7 +1143,7 @@ class Session:
SMP_CROSS_TRANSPORT_KEY_DERIVATION_NOT_ALLOWED_ERROR
)
else:
self.ltk = self.derive_ltk(self.link_key, self.ct2)
self.ltk = self.derive_ltk(link_key, self.ct2)
def distribute_keys(self) -> None:
# Distribute the keys as required
@@ -1993,8 +1991,10 @@ class Manager(EventEmitter):
) -> None:
# Store the keys in the key store
if self.device.keystore and identity_address is not None:
# Make sure on_pairing emits after key update.
await self.device.update_keys(str(identity_address), keys)
self.device.abort_on(
'flush', self.device.update_keys(str(identity_address), keys)
)
# Notify the device
self.device.on_pairing(session.connection, identity_address, keys, session.sc)
+3 -3
View File
@@ -82,13 +82,14 @@ async def open_transport(name: str) -> Transport:
scheme, *tail = name.split(':', 1)
spec = tail[0] if tail else None
metadata = None
if spec:
# Metadata may precede the spec
if spec.startswith('['):
metadata_str, *tail = spec[1:].split(']')
spec = tail[0] if tail else None
metadata = dict([entry.split('=') for entry in metadata_str.split(',')])
else:
metadata = None
transport = await _open_transport(scheme, spec)
if metadata:
@@ -197,13 +198,12 @@ async def open_transport_or_link(name: str) -> Transport:
"""
if name.startswith('link-relay:'):
logger.warning('Link Relay has been deprecated.')
from ..controller import Controller
from ..link import RemoteLink # lazy import
link = RemoteLink(name[11:])
await link.wait_until_connected()
controller = Controller('remote', link=link) # type:ignore[arg-type]
controller = Controller('remote', link=link)
class LinkTransport(Transport):
async def close(self):
+1 -1
View File
@@ -108,7 +108,7 @@ async def open_usb_transport(spec: str) -> Transport:
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
READ_SIZE = 4096
READ_SIZE = 1024
class UsbPacketSink:
def __init__(self, device, acl_out):
+25 -50
View File
@@ -17,10 +17,9 @@
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import collections
import enum
import functools
import logging
import traceback
import collections
import sys
import warnings
from typing import (
@@ -35,7 +34,7 @@ from typing import (
Union,
overload,
)
from functools import wraps, partial
from pyee import EventEmitter
from .colors import color
@@ -132,14 +131,13 @@ class EventWatcher:
Args:
emitter: EventEmitter to watch
event: Event name
handler: (Optional) Event handler. When nothing is passed, this method
works as a decorator.
handler: (Optional) Event handler. When nothing is passed, this method works as a decorator.
'''
def wrapper(wrapped: _Handler) -> _Handler:
self.handlers.append((emitter, event, wrapped))
emitter.on(event, wrapped)
return wrapped
def wrapper(f: _Handler) -> _Handler:
self.handlers.append((emitter, event, f))
emitter.on(event, f)
return f
return wrapper if handler is None else wrapper(handler)
@@ -159,14 +157,13 @@ class EventWatcher:
Args:
emitter: EventEmitter to watch
event: Event name
handler: (Optional) Event handler. When nothing passed, this method works
as a decorator.
handler: (Optional) Event handler. When nothing passed, this method works as a decorator.
'''
def wrapper(wrapped: _Handler) -> _Handler:
self.handlers.append((emitter, event, wrapped))
emitter.once(event, wrapped)
return wrapped
def wrapper(f: _Handler) -> _Handler:
self.handlers.append((emitter, event, f))
emitter.once(event, f)
return f
return wrapper if handler is None else wrapper(handler)
@@ -279,18 +276,21 @@ class AsyncRunner:
"""
def decorator(func):
@functools.wraps(func)
@wraps(func)
def wrapper(*args, **kwargs):
coroutine = func(*args, **kwargs)
if queue is None:
# Spawn the coroutine as a task
# Create a task to run the coroutine
async def run():
try:
await coroutine
except Exception:
logger.exception(color("!!! Exception in wrapper:", "red"))
logger.warning(
f'{color("!!! Exception in wrapper:", "red")} '
f'{traceback.format_exc()}'
)
AsyncRunner.spawn(run())
asyncio.create_task(run())
else:
# Queue the coroutine to be awaited by the work queue
queue.enqueue(coroutine)
@@ -413,35 +413,30 @@ class FlowControlAsyncPipe:
self.check_pump()
# -----------------------------------------------------------------------------
async def async_call(function, *args, **kwargs):
"""
Immediately calls the function with provided args and kwargs, wrapping it in an
async function.
Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject
a running loop.
Immediately calls the function with provided args and kwargs, wrapping it in an async function.
Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject a running loop.
result = await async_call(some_function, ...)
"""
return function(*args, **kwargs)
# -----------------------------------------------------------------------------
def wrap_async(function):
"""
Wraps the provided function in an async function.
"""
return functools.partial(async_call, function)
return partial(async_call, function)
# -----------------------------------------------------------------------------
def deprecated(msg: str):
"""
Throw deprecation warning before execution.
"""
def wrapper(function):
@functools.wraps(function)
@wraps(function)
def inner(*args, **kwargs):
warnings.warn(msg, DeprecationWarning)
return function(*args, **kwargs)
@@ -451,14 +446,13 @@ def deprecated(msg: str):
return wrapper
# -----------------------------------------------------------------------------
def experimental(msg: str):
"""
Throws a future warning before execution.
"""
def wrapper(function):
@functools.wraps(function)
@wraps(function)
def inner(*args, **kwargs):
warnings.warn(msg, FutureWarning)
return function(*args, **kwargs)
@@ -466,22 +460,3 @@ def experimental(msg: str):
return inner
return wrapper
# -----------------------------------------------------------------------------
class OpenIntEnum(enum.IntEnum):
"""
Subclass of enum.IntEnum that can hold integer values outside the set of
predefined values. This is convenient for implementing protocols where some
integer constants may be added over time.
"""
@classmethod
def _missing_(cls, value):
if not isinstance(value, int):
return None
obj = int.__new__(cls, value)
obj._value_ = value
obj._name_ = f"{cls.__name__}[{value}]"
return obj
+9 -30
View File
@@ -7,36 +7,16 @@ throughput and/or latency between two devices.
# General Usage
```
Usage: bumble-bench [OPTIONS] COMMAND [ARGS]...
Usage: bench.py [OPTIONS] COMMAND [ARGS]...
Options:
--device-config FILENAME Device configuration file
--role [sender|receiver|ping|pong]
--mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server]
--att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517]
--extended-data-length TEXT Request a data length upon connection,
specified as tx_octets/tx_time
--rfcomm-channel INTEGER RFComm channel to use
--rfcomm-uuid TEXT RFComm service UUID to use (ignored if
--rfcomm-channel is not 0)
--l2cap-psm INTEGER L2CAP PSM to use
--l2cap-mtu INTEGER L2CAP MTU to use
--l2cap-mps INTEGER L2CAP MPS to use
--l2cap-max-credits INTEGER L2CAP maximum number of credits allowed for
the peer
-s, --packet-size SIZE Packet size (client or ping role)
[8<=x<=4096]
-c, --packet-count COUNT Packet count (client or ping role)
-sd, --start-delay SECONDS Start delay (client or ping role)
--repeat N Repeat the run N times (client and ping
roles)(0, which is the fault, to run just
once)
--repeat-delay SECONDS Delay, in seconds, between repeats
--pace MILLISECONDS Wait N milliseconds between packets (0,
which is the fault, to send as fast as
possible)
--linger Don't exit at the end of a run (server and
pong roles)
-s, --packet-size SIZE Packet size (server role) [8<=x<=4096]
-c, --packet-count COUNT Packet count (server role)
-sd, --start-delay SECONDS Start delay (server role)
--help Show this message and exit.
Commands:
@@ -55,18 +35,17 @@ Options:
--connection-interval, --ci CONNECTION_INTERVAL
Connection interval (in ms)
--phy [1m|2m|coded] PHY to use
--authenticate Authenticate (RFComm only)
--encrypt Encrypt the connection (RFComm only)
--help Show this message and exit.
```
To test once device against another, one of the two devices must be running
To test once device against another, one of the two devices must be running
the ``peripheral`` command and the other the ``central`` command. The device
running the ``peripheral`` command will accept connections from the device
running the ``central`` command.
When using Bluetooth LE (all modes except for ``rfcomm-server`` and ``rfcomm-client``utils),
the default addresses configured in the tool should be sufficient. But when using
Bluetooth Classic, the address of the Peripheral must be specified on the Central
the default addresses configured in the tool should be sufficient. But when using
Bluetooth Classic, the address of the Peripheral must be specified on the Central
using the ``--peripheral`` option. The address will be printed by the Peripheral when
it starts.
@@ -104,7 +83,7 @@ the other on `usb:1`, and two consoles/terminals. We will run a command in each.
$ bumble-bench central usb:1
```
In this default configuration, the Central runs a Sender, as a GATT client,
In this default configuration, the Central runs a Sender, as a GATT client,
connecting to the Peripheral running a Receiver, as a GATT server.
!!! example "L2CAP Throughput"
-274
View File
@@ -1,274 +0,0 @@
<html>
<head>
<style>
* {
font-family: sans-serif;
}
</style>
</head>
<body>
Server Port <input id="port" type="text" value="8989"></input> <button id="connectButton" onclick="connect()">Connect</button><br>
<div id="socketState"></div>
<br>
<div id="buttons"></div><br>
<hr>
<button onclick="onGetPlayStatusButtonClicked()">Get Play Status</button><br>
<div id="getPlayStatusResponseTable"></div>
<hr>
<button onclick="onGetElementAttributesButtonClicked()">Get Element Attributes</button><br>
<div id="getElementAttributesResponseTable"></div>
<hr>
<table>
<tr>
<b>VOLUME</b>:
<button onclick="onVolumeDownButtonClicked()">-</button>
<button onclick="onVolumeUpButtonClicked()">+</button>&nbsp;
<span id="volumeText"></span><br>
</tr>
<tr>
<td><b>PLAYBACK STATUS</b></td><td><span id="playbackStatusText"></span></td>
</tr>
<tr>
<td><b>POSITION</b></td><td><span id="positionText"></span></td>
</tr>
<tr>
<td><b>TRACK</b></td><td><span id="trackText"></span></td>
</tr>
<tr>
<td><b>ADDRESSED PLAYER</b></td><td><span id="addressedPlayerText"></span></td>
</tr>
<tr>
<td><b>UID COUNTER</b></td><td><span id="uidCounterText"></span></td>
</tr>
<tr>
<td><b>SUPPORTED EVENTS</b></td><td><span id="supportedEventsText"></span></td>
</tr>
<tr>
<td><b>PLAYER SETTINGS</b></td><td><div id="playerSettingsTable"></div></td>
</tr>
</table>
<script>
const portInput = document.getElementById("port")
const connectButton = document.getElementById("connectButton")
const socketState = document.getElementById("socketState")
const volumeText = document.getElementById("volumeText")
const positionText = document.getElementById("positionText")
const trackText = document.getElementById("trackText")
const playbackStatusText = document.getElementById("playbackStatusText")
const addressedPlayerText = document.getElementById("addressedPlayerText")
const uidCounterText = document.getElementById("uidCounterText")
const supportedEventsText = document.getElementById("supportedEventsText")
const playerSettingsTable = document.getElementById("playerSettingsTable")
const getPlayStatusResponseTable = document.getElementById("getPlayStatusResponseTable")
const getElementAttributesResponseTable = document.getElementById("getElementAttributesResponseTable")
let socket
let volume = 0
const keyNames = [
"SELECT",
"UP",
"DOWN",
"LEFT",
"RIGHT",
"RIGHT_UP",
"RIGHT_DOWN",
"LEFT_UP",
"LEFT_DOWN",
"ROOT_MENU",
"SETUP_MENU",
"CONTENTS_MENU",
"FAVORITE_MENU",
"EXIT",
"NUMBER_0",
"NUMBER_1",
"NUMBER_2",
"NUMBER_3",
"NUMBER_4",
"NUMBER_5",
"NUMBER_6",
"NUMBER_7",
"NUMBER_8",
"NUMBER_9",
"DOT",
"ENTER",
"CLEAR",
"CHANNEL_UP",
"CHANNEL_DOWN",
"PREVIOUS_CHANNEL",
"SOUND_SELECT",
"INPUT_SELECT",
"DISPLAY_INFORMATION",
"HELP",
"PAGE_UP",
"PAGE_DOWN",
"POWER",
"VOLUME_UP",
"VOLUME_DOWN",
"MUTE",
"PLAY",
"STOP",
"PAUSE",
"RECORD",
"REWIND",
"FAST_FORWARD",
"EJECT",
"FORWARD",
"BACKWARD",
"ANGLE",
"SUBPICTURE",
"F1",
"F2",
"F3",
"F4",
"F5",
]
document.addEventListener('keydown', onKeyDown)
document.addEventListener('keyup', onKeyUp)
const buttons = document.getElementById("buttons")
keyNames.forEach(name => {
const button = document.createElement("BUTTON")
button.appendChild(document.createTextNode(name))
button.addEventListener("mousedown", event => {
send({type: 'send-key-down', key: name})
})
button.addEventListener("mouseup", event => {
send({type: 'send-key-up', key: name})
})
buttons.appendChild(button)
})
updateVolume(0)
function connect() {
socket = new WebSocket(`ws://localhost:${portInput.value}`);
socket.onopen = _ => {
socketState.innerText = 'OPEN'
connectButton.disabled = true
}
socket.onclose = _ => {
socketState.innerText = 'CLOSED'
connectButton.disabled = false
}
socket.onerror = (error) => {
socketState.innerText = 'ERROR'
console.log(`ERROR: ${error}`)
connectButton.disabled = false
}
socket.onmessage = (message) => {
onMessage(JSON.parse(message.data))
}
}
function send(message) {
if (socket && socket.readyState == WebSocket.OPEN) {
socket.send(JSON.stringify(message))
}
}
function hmsText(position) {
const h_1 = 1000 * 60 * 60
const h = Math.floor(position / h_1)
position -= h * h_1
const m_1 = 1000 * 60
const m = Math.floor(position / m_1)
position -= m * m_1
const s_1 = 1000
const s = Math.floor(position / s_1)
position -= s * s_1
return `${h}:${m.toString().padStart(2, "0")}:${s.toString().padStart(2, "0")}:${position}`
}
function setTableHead(table, columns) {
let thead = table.createTHead()
let row = thead.insertRow()
for (let column of columns) {
let th = document.createElement("th")
let text = document.createTextNode(column)
th.appendChild(text)
row.appendChild(th)
}
}
function createTable(rows) {
const table = document.createElement("table")
if (rows.length != 0) {
columns = Object.keys(rows[0])
setTableHead(table, columns)
}
for (let element of rows) {
let row = table.insertRow()
for (key in element) {
let cell = row.insertCell()
let text = document.createTextNode(element[key])
cell.appendChild(text)
}
}
return table
}
function onMessage(message) {
console.log(message)
if (message.type == "set-volume") {
updateVolume(message.params.volume)
} else if (message.type == "supported-events") {
supportedEventsText.innerText = JSON.stringify(message.params.events)
} else if (message.type == "playback-position-changed") {
positionText.innerText = hmsText(message.params.position)
} else if (message.type == "playback-status-changed") {
playbackStatusText.innerText = message.params.status
} else if (message.type == "player-settings-changed") {
playerSettingsTable.replaceChildren(message.params.settings)
} else if (message.type == "track-changed") {
trackText.innerText = message.params.identifier
} else if (message.type == "addressed-player-changed") {
addressedPlayerText.innerText = JSON.stringify(message.params.player)
} else if (message.type == "uids-changed") {
uidCounterText.innerText = message.params.uid_counter
} else if (message.type == "get-play-status-response") {
getPlayStatusResponseTable.replaceChildren(message.params)
} else if (message.type == "get-element-attributes-response") {
getElementAttributesResponseTable.replaceChildren(createTable(message.params))
}
}
function updateVolume(newVolume) {
volume = newVolume
volumeText.innerText = `${volume} (${Math.round(100*volume/0x7F)}%)`
}
function onKeyDown(event) {
console.log(event)
send({ type: 'send-key-down', key: event.key })
}
function onKeyUp(event) {
console.log(event)
send({ type: 'send-key-up', key: event.key })
}
function onVolumeUpButtonClicked() {
updateVolume(Math.min(volume + 5, 0x7F))
send({ type: 'set-volume', volume })
}
function onVolumeDownButtonClicked() {
updateVolume(Math.max(volume - 5, 0))
send({ type: 'set-volume', volume })
}
function onGetPlayStatusButtonClicked() {
send({ type: 'get-play-status', volume })
}
function onGetElementAttributesButtonClicked() {
send({ type: 'get-element-attributes' })
}
</script>
</body>
</html>
-408
View File
@@ -1,408 +0,0 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import json
import sys
import os
import logging
import websockets
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.core import BT_BR_EDR_TRANSPORT
from bumble import avc
from bumble import avrcp
from bumble import avdtp
from bumble import a2dp
from bumble import utils
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
def sdp_records():
a2dp_sink_service_record_handle = 0x00010001
avrcp_controller_service_record_handle = 0x00010002
avrcp_target_service_record_handle = 0x00010003
# pylint: disable=line-too-long
return {
a2dp_sink_service_record_handle: a2dp.make_audio_sink_service_sdp_records(
a2dp_sink_service_record_handle
),
avrcp_controller_service_record_handle: avrcp.make_controller_service_sdp_records(
avrcp_controller_service_record_handle
),
avrcp_target_service_record_handle: avrcp.make_target_service_sdp_records(
avrcp_controller_service_record_handle
),
}
# -----------------------------------------------------------------------------
def codec_capabilities():
return avdtp.MediaCodecCapabilities(
media_type=avdtp.AVDTP_AUDIO_MEDIA_TYPE,
media_codec_type=a2dp.A2DP_SBC_CODEC_TYPE,
media_codec_information=a2dp.SbcMediaCodecInformation.from_lists(
sampling_frequencies=[48000, 44100, 32000, 16000],
channel_modes=[
a2dp.SBC_MONO_CHANNEL_MODE,
a2dp.SBC_DUAL_CHANNEL_MODE,
a2dp.SBC_STEREO_CHANNEL_MODE,
a2dp.SBC_JOINT_STEREO_CHANNEL_MODE,
],
block_lengths=[4, 8, 12, 16],
subbands=[4, 8],
allocation_methods=[
a2dp.SBC_LOUDNESS_ALLOCATION_METHOD,
a2dp.SBC_SNR_ALLOCATION_METHOD,
],
minimum_bitpool_value=2,
maximum_bitpool_value=53,
),
)
# -----------------------------------------------------------------------------
def on_avdtp_connection(server):
# Add a sink endpoint to the server
sink = server.add_sink(codec_capabilities())
sink.on('rtp_packet', on_rtp_packet)
# -----------------------------------------------------------------------------
def on_rtp_packet(packet):
print(f'RTP: {packet}')
# -----------------------------------------------------------------------------
def on_avrcp_start(avrcp_protocol: avrcp.Protocol, websocket_server: WebSocketServer):
async def get_supported_events():
events = await avrcp_protocol.get_supported_events()
print("SUPPORTED EVENTS:", events)
websocket_server.send_message(
{
"type": "supported-events",
"params": {"events": [event.name for event in events]},
}
)
if avrcp.EventId.TRACK_CHANGED in events:
utils.AsyncRunner.spawn(monitor_track_changed())
if avrcp.EventId.PLAYBACK_STATUS_CHANGED in events:
utils.AsyncRunner.spawn(monitor_playback_status())
if avrcp.EventId.PLAYBACK_POS_CHANGED in events:
utils.AsyncRunner.spawn(monitor_playback_position())
if avrcp.EventId.PLAYER_APPLICATION_SETTING_CHANGED in events:
utils.AsyncRunner.spawn(monitor_player_application_settings())
if avrcp.EventId.AVAILABLE_PLAYERS_CHANGED in events:
utils.AsyncRunner.spawn(monitor_available_players())
if avrcp.EventId.ADDRESSED_PLAYER_CHANGED in events:
utils.AsyncRunner.spawn(monitor_addressed_player())
if avrcp.EventId.UIDS_CHANGED in events:
utils.AsyncRunner.spawn(monitor_uids())
if avrcp.EventId.VOLUME_CHANGED in events:
utils.AsyncRunner.spawn(monitor_volume())
utils.AsyncRunner.spawn(get_supported_events())
async def monitor_track_changed():
async for identifier in avrcp_protocol.monitor_track_changed():
print("TRACK CHANGED:", identifier.hex())
websocket_server.send_message(
{"type": "track-changed", "params": {"identifier": identifier.hex()}}
)
async def monitor_playback_status():
async for playback_status in avrcp_protocol.monitor_playback_status():
print("PLAYBACK STATUS CHANGED:", playback_status.name)
websocket_server.send_message(
{
"type": "playback-status-changed",
"params": {"status": playback_status.name},
}
)
async def monitor_playback_position():
async for playback_position in avrcp_protocol.monitor_playback_position(
playback_interval=1
):
print("PLAYBACK POSITION CHANGED:", playback_position)
websocket_server.send_message(
{
"type": "playback-position-changed",
"params": {"position": playback_position},
}
)
async def monitor_player_application_settings():
async for settings in avrcp_protocol.monitor_player_application_settings():
print("PLAYER APPLICATION SETTINGS:", settings)
settings_as_dict = [
{"attribute": setting.attribute_id.name, "value": setting.value_id.name}
for setting in settings
]
websocket_server.send_message(
{
"type": "player-settings-changed",
"params": {"settings": settings_as_dict},
}
)
async def monitor_available_players():
async for _ in avrcp_protocol.monitor_available_players():
print("AVAILABLE PLAYERS CHANGED")
websocket_server.send_message(
{"type": "available-players-changed", "params": {}}
)
async def monitor_addressed_player():
async for player in avrcp_protocol.monitor_addressed_player():
print("ADDRESSED PLAYER CHANGED")
websocket_server.send_message(
{
"type": "addressed-player-changed",
"params": {
"player": {
"player_id": player.player_id,
"uid_counter": player.uid_counter,
}
},
}
)
async def monitor_uids():
async for uid_counter in avrcp_protocol.monitor_uids():
print("UIDS CHANGED")
websocket_server.send_message(
{
"type": "uids-changed",
"params": {
"uid_counter": uid_counter,
},
}
)
async def monitor_volume():
async for volume in avrcp_protocol.monitor_volume():
print("VOLUME CHANGED:", volume)
websocket_server.send_message(
{"type": "volume-changed", "params": {"volume": volume}}
)
# -----------------------------------------------------------------------------
class WebSocketServer:
def __init__(
self, avrcp_protocol: avrcp.Protocol, avrcp_delegate: Delegate
) -> None:
self.socket = None
self.delegate = None
self.avrcp_protocol = avrcp_protocol
self.avrcp_delegate = avrcp_delegate
async def start(self) -> None:
# pylint: disable-next=no-member
await websockets.serve(self.serve, 'localhost', 8989) # type: ignore
async def serve(self, socket, _path) -> None:
print('### WebSocket connected')
self.socket = socket
while True:
try:
message = await socket.recv()
print('Received: ', str(message))
parsed = json.loads(message)
message_type = parsed['type']
if message_type == 'send-key-down':
await self.on_send_key_down(parsed)
elif message_type == 'send-key-up':
await self.on_send_key_up(parsed)
elif message_type == 'set-volume':
await self.on_set_volume(parsed)
elif message_type == 'get-play-status':
await self.on_get_play_status()
elif message_type == 'get-element-attributes':
await self.on_get_element_attributes()
except websockets.exceptions.ConnectionClosedOK:
self.socket = None
break
async def on_send_key_down(self, message: dict) -> None:
key = avc.PassThroughFrame.OperationId[message["key"]]
await self.avrcp_protocol.send_key_event(key, True)
async def on_send_key_up(self, message: dict) -> None:
key = avc.PassThroughFrame.OperationId[message["key"]]
await self.avrcp_protocol.send_key_event(key, False)
async def on_set_volume(self, message: dict) -> None:
volume = message["volume"]
self.avrcp_delegate.volume = volume
self.avrcp_protocol.notify_volume_changed(volume)
async def on_get_play_status(self) -> None:
play_status = await self.avrcp_protocol.get_play_status()
self.send_message(
{
"type": "get-play-status-response",
"params": {
"song_length": play_status.song_length,
"song_position": play_status.song_position,
"play_status": play_status.play_status.name,
},
}
)
async def on_get_element_attributes(self) -> None:
attributes = await self.avrcp_protocol.get_element_attributes(
0,
[
avrcp.MediaAttributeId.TITLE,
avrcp.MediaAttributeId.ARTIST_NAME,
avrcp.MediaAttributeId.ALBUM_NAME,
avrcp.MediaAttributeId.TRACK_NUMBER,
avrcp.MediaAttributeId.TOTAL_NUMBER_OF_TRACKS,
avrcp.MediaAttributeId.GENRE,
avrcp.MediaAttributeId.PLAYING_TIME,
avrcp.MediaAttributeId.DEFAULT_COVER_ART,
],
)
self.send_message(
{
"type": "get-element-attributes-response",
"params": [
{
"attribute_id": attribute.attribute_id.name,
"attribute_value": attribute.attribute_value,
}
for attribute in attributes
],
}
)
def send_message(self, message: dict) -> None:
if self.socket is None:
print("no socket, dropping message")
return
serialized = json.dumps(message)
utils.AsyncRunner.spawn(self.socket.send(serialized))
# -----------------------------------------------------------------------------
class Delegate(avrcp.Delegate):
def __init__(self):
super().__init__(
[avrcp.EventId.VOLUME_CHANGED, avrcp.EventId.PLAYBACK_STATUS_CHANGED]
)
self.websocket_server = None
async def set_absolute_volume(self, volume: int) -> None:
await super().set_absolute_volume(volume)
if self.websocket_server is not None:
self.websocket_server.send_message(
{"type": "set-volume", "params": {"volume": volume}}
)
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 3:
print(
'Usage: run_avrcp_controller.py <device-config> <transport-spec> '
'<sbc-file> [<bt-addr>]'
)
print('example: run_avrcp_controller.py classic1.json usb:0')
return
print('<<< connecting to HCI...')
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
print('<<< connected')
# Create a device
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
device.classic_enabled = True
# Setup the SDP to expose the sink service
device.sdp_service_records = sdp_records()
# Start the controller
await device.power_on()
# Create a listener to wait for AVDTP connections
listener = avdtp.Listener(avdtp.Listener.create_registrar(device))
listener.on('connection', on_avdtp_connection)
avrcp_delegate = Delegate()
avrcp_protocol = avrcp.Protocol(avrcp_delegate)
avrcp_protocol.listen(device)
websocket_server = WebSocketServer(avrcp_protocol, avrcp_delegate)
avrcp_delegate.websocket_server = websocket_server
avrcp_protocol.on(
"start", lambda: on_avrcp_start(avrcp_protocol, websocket_server)
)
await websocket_server.start()
if len(sys.argv) >= 5:
# Connect to the peer
target_address = sys.argv[4]
print(f'=== Connecting to {target_address}...')
connection = await device.connect(
target_address, transport=BT_BR_EDR_TRANSPORT
)
print(f'=== Connected to {connection.peer_address}!')
# Request authentication
print('*** Authenticating...')
await connection.authenticate()
print('*** Authenticated')
# Enable encryption
print('*** Enabling encryption...')
await connection.encrypt()
print('*** Encryption on')
server = await avdtp.Protocol.connect(connection)
listener.set_server(connection, server)
sink = server.add_sink(codec_capabilities())
sink.on('rtp_packet', on_rtp_packet)
await avrcp_protocol.connect(connection)
else:
# Start being discoverable and connectable
await device.set_discoverable(True)
await device.set_connectable(True)
await asyncio.get_event_loop().create_future()
# -----------------------------------------------------------------------------
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())
+2 -2
View File
@@ -590,12 +590,12 @@ async def main():
def on_set_protocol_cb(protocol: int):
retValue = hid_device.GetSetStatus()
# We do not support SET_PROTOCOL.
print(f"SET_PROTOCOL report_id: {protocol}")
print("SET_PROTOCOL report_id: " + str(protocol))
retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
return retValue
def on_virtual_cable_unplug_cb():
print('Received Virtual Cable Unplug')
print(f'Received Virtual Cable Unplug')
asyncio.create_task(handle_virtual_cable_unplug())
print('<<< connecting to HCI...')
@@ -1,7 +1,7 @@
<?xml version="1.0" encoding="utf-8"?>
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.github.google.bumble.btbench">
<uses-sdk android:minSdkVersion="30" android:targetSdkVersion="34" />
xmlns:tools="http://schemas.android.com/tools">
<!-- Request legacy Bluetooth permissions on older devices. -->
<uses-permission android:name="android.permission.BLUETOOTH" android:maxSdkVersion="30" />
<uses-permission android:name="android.permission.BLUETOOTH_ADMIN" android:maxSdkVersion="30" />
@@ -22,10 +22,11 @@
android:roundIcon="@mipmap/ic_launcher_round"
android:supportsRtl="true"
android:theme="@style/Theme.BTBench"
>
tools:targetApi="31">
<activity
android:name=".MainActivity"
android:exported="true"
android:label="@string/app_name"
android:theme="@style/Theme.BTBench">
<intent-filter>
<action android:name="android.intent.action.MAIN" />
@@ -28,8 +28,8 @@ private val Log = Logger.getLogger("btbench.l2cap-client")
class L2capClient(
private val viewModel: AppViewModel,
private val bluetoothAdapter: BluetoothAdapter,
private val context: Context
val bluetoothAdapter: BluetoothAdapter,
val context: Context
) {
@SuppressLint("MissingPermission")
fun run() {
@@ -74,18 +74,12 @@ class L2capClient(
gatt: BluetoothGatt?, status: Int, newState: Int
) {
if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) {
if (viewModel.use2mPhy) {
gatt.setPreferredPhy(
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_OPTION_NO_PREFERRED
)
}
gatt.setPreferredPhy(
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_OPTION_NO_PREFERRED
)
gatt.readPhy()
// Request an MTU update, even though we don't use GATT, because Android
// won't request a larger link layer maximum data length otherwise.
gatt.requestMtu(517)
}
}
},
@@ -23,20 +23,19 @@ import androidx.compose.runtime.setValue
import androidx.lifecycle.ViewModel
import java.util.UUID
val DEFAULT_RFCOMM_UUID: UUID = UUID.fromString("E6D55659-C8B4-4B85-96BB-B1143AF6D3AE")
val DEFAULT_RFCOMM_UUID = UUID.fromString("E6D55659-C8B4-4B85-96BB-B1143AF6D3AE")
const val DEFAULT_PEER_BLUETOOTH_ADDRESS = "AA:BB:CC:DD:EE:FF"
const val DEFAULT_SENDER_PACKET_COUNT = 100
const val DEFAULT_SENDER_PACKET_SIZE = 1024
const val DEFAULT_PSM = 128
class AppViewModel : ViewModel() {
private var preferences: SharedPreferences? = null
var peerBluetoothAddress by mutableStateOf(DEFAULT_PEER_BLUETOOTH_ADDRESS)
var l2capPsm by mutableIntStateOf(DEFAULT_PSM)
var l2capPsm by mutableStateOf(0)
var use2mPhy by mutableStateOf(true)
var mtu by mutableIntStateOf(0)
var rxPhy by mutableIntStateOf(0)
var txPhy by mutableIntStateOf(0)
var mtu by mutableStateOf(0)
var rxPhy by mutableStateOf(0)
var txPhy by mutableStateOf(0)
var senderPacketCountSlider by mutableFloatStateOf(0.0F)
var senderPacketSizeSlider by mutableFloatStateOf(0.0F)
var senderPacketCount by mutableIntStateOf(DEFAULT_SENDER_PACKET_COUNT)
@@ -80,18 +79,18 @@ class AppViewModel : ViewModel() {
}
fun updateSenderPacketCountSlider() {
senderPacketCountSlider = if (senderPacketCount <= 10) {
0.0F
if (senderPacketCount <= 10) {
senderPacketCountSlider = 0.0F
} else if (senderPacketCount <= 50) {
0.2F
senderPacketCountSlider = 0.2F
} else if (senderPacketCount <= 100) {
0.4F
senderPacketCountSlider = 0.4F
} else if (senderPacketCount <= 500) {
0.6F
senderPacketCountSlider = 0.6F
} else if (senderPacketCount <= 1000) {
0.8F
senderPacketCountSlider = 0.8F
} else {
1.0F
senderPacketCountSlider = 1.0F
}
with(preferences!!.edit()) {
@@ -101,18 +100,18 @@ class AppViewModel : ViewModel() {
}
fun updateSenderPacketCount() {
senderPacketCount = if (senderPacketCountSlider < 0.1F) {
10
if (senderPacketCountSlider < 0.1F) {
senderPacketCount = 10
} else if (senderPacketCountSlider < 0.3F) {
50
senderPacketCount = 50
} else if (senderPacketCountSlider < 0.5F) {
100
senderPacketCount = 100
} else if (senderPacketCountSlider < 0.7F) {
500
senderPacketCount = 500
} else if (senderPacketCountSlider < 0.9F) {
1000
senderPacketCount = 1000
} else {
10000
senderPacketCount = 10000
}
with(preferences!!.edit()) {
@@ -122,18 +121,18 @@ class AppViewModel : ViewModel() {
}
fun updateSenderPacketSizeSlider() {
senderPacketSizeSlider = if (senderPacketSize <= 16) {
0.0F
if (senderPacketSize <= 16) {
senderPacketSizeSlider = 0.0F
} else if (senderPacketSize <= 256) {
0.02F
senderPacketSizeSlider = 0.02F
} else if (senderPacketSize <= 512) {
0.4F
senderPacketSizeSlider = 0.4F
} else if (senderPacketSize <= 1024) {
0.6F
senderPacketSizeSlider = 0.6F
} else if (senderPacketSize <= 2048) {
0.8F
senderPacketSizeSlider = 0.8F
} else {
1.0F
senderPacketSizeSlider = 1.0F
}
with(preferences!!.edit()) {
@@ -143,18 +142,18 @@ class AppViewModel : ViewModel() {
}
fun updateSenderPacketSize() {
senderPacketSize = if (senderPacketSizeSlider < 0.1F) {
16
if (senderPacketSizeSlider < 0.1F) {
senderPacketSize = 16
} else if (senderPacketSizeSlider < 0.3F) {
256
senderPacketSize = 256
} else if (senderPacketSizeSlider < 0.5F) {
512
senderPacketSize = 512
} else if (senderPacketSizeSlider < 0.7F) {
1024
senderPacketSize = 1024
} else if (senderPacketSizeSlider < 0.9F) {
2048
senderPacketSize = 2048
} else {
4096
senderPacketSize = 4096
}
with(preferences!!.edit()) {
@@ -10,7 +10,7 @@ android {
defaultConfig {
applicationId = "com.github.google.bumble.remotehci"
minSdk = 29
minSdk = 26
targetSdk = 33
versionCode = 1
versionName = "1.0"
@@ -4,7 +4,6 @@ import android.hardware.bluetooth.V1_0.Status;
import android.os.IBinder;
import android.os.RemoteException;
import android.os.ServiceManager;
import android.os.Trace;
import android.util.Log;
import java.util.ArrayList;
@@ -54,7 +53,6 @@ class HciHidlHal extends android.hardware.bluetooth.V1_0.IBluetoothHciCallbacks.
private final android.hardware.bluetooth.V1_0.IBluetoothHci mHciService;
private final HciHalCallback mHciCallbacks;
private int mInitializationStatus = -1;
private final boolean mTracingEnabled = Trace.isEnabled();
public static HciHidlHal create(HciHalCallback hciCallbacks) {
@@ -91,7 +89,6 @@ class HciHidlHal extends android.hardware.bluetooth.V1_0.IBluetoothHciCallbacks.
}
// Map the status code.
Log.d(TAG, "Initialization status = " + mInitializationStatus);
switch (mInitializationStatus) {
case android.hardware.bluetooth.V1_0.Status.SUCCESS:
return Status.SUCCESS;
@@ -111,10 +108,6 @@ class HciHidlHal extends android.hardware.bluetooth.V1_0.IBluetoothHciCallbacks.
public void sendPacket(HciPacket.Type type, byte[] packet) {
ArrayList<Byte> data = HciPacket.byteArrayToList(packet);
if (mTracingEnabled) {
Trace.beginAsyncSection("SEND_PACKET_TO_HAL", 1);
}
try {
switch (type) {
case COMMAND:
@@ -132,10 +125,6 @@ class HciHidlHal extends android.hardware.bluetooth.V1_0.IBluetoothHciCallbacks.
} catch (RemoteException error) {
Log.w(TAG, "failed to forward packet: " + error);
}
if (mTracingEnabled) {
Trace.endAsyncSection("SEND_PACKET_TO_HAL", 1);
}
}
@Override
@@ -168,7 +157,6 @@ class HciAidlHal extends android.hardware.bluetooth.IBluetoothHciCallbacks.Stub
private final android.hardware.bluetooth.IBluetoothHci mHciService;
private final HciHalCallback mHciCallbacks;
private int mInitializationStatus = android.hardware.bluetooth.Status.SUCCESS;
private final boolean mTracingEnabled = Trace.isEnabled();
public static HciAidlHal create(HciHalCallback hciCallbacks) {
IBinder binder = ServiceManager.getService("android.hardware.bluetooth.IBluetoothHci/default");
@@ -199,7 +187,6 @@ class HciAidlHal extends android.hardware.bluetooth.IBluetoothHciCallbacks.Stub
}
// Map the status code.
Log.d(TAG, "Initialization status = " + mInitializationStatus);
switch (mInitializationStatus) {
case android.hardware.bluetooth.Status.SUCCESS:
return Status.SUCCESS;
@@ -221,10 +208,6 @@ class HciAidlHal extends android.hardware.bluetooth.IBluetoothHciCallbacks.Stub
// HciHal methods.
@Override
public void sendPacket(HciPacket.Type type, byte[] packet) {
if (mTracingEnabled) {
Trace.beginAsyncSection("SEND_PACKET_TO_HAL", 1);
}
try {
switch (type) {
case COMMAND:
@@ -246,10 +229,6 @@ class HciAidlHal extends android.hardware.bluetooth.IBluetoothHciCallbacks.Stub
} catch (RemoteException error) {
Log.w(TAG, "failed to forward packet: " + error);
}
if (mTracingEnabled) {
Trace.endAsyncSection("SEND_PACKET_TO_HAL", 1);
}
}
// IBluetoothHciCallbacks methods.
@@ -1,6 +1,5 @@
package com.github.google.bumble.remotehci;
import android.os.Trace;
import android.util.Log;
import java.io.IOException;
@@ -16,7 +15,6 @@ public class HciServer {
private final int mPort;
private final Listener mListener;
private OutputStream mOutputStream;
private final boolean mTracingEnabled = Trace.isEnabled();
public interface Listener extends HciParser.Sink {
void onHostConnectionState(boolean connected);
@@ -29,8 +27,6 @@ public class HciServer {
}
public void run() throws IOException {
Log.i(TAG, "Tracing enabled: " + mTracingEnabled);
for (;;) {
try {
loop();
@@ -46,7 +42,6 @@ public class HciServer {
try (ServerSocket serverSocket = new ServerSocket(mPort)) {
mListener.onMessage("Waiting for connection on port " + serverSocket.getLocalPort());
try (Socket clientSocket = serverSocket.accept()) {
clientSocket.setTcpNoDelay(true);
mListener.onHostConnectionState(true);
mListener.onMessage("Connected");
HciParser parser = new HciParser(mListener);
@@ -77,10 +72,6 @@ public class HciServer {
}
public void sendPacket(HciPacket.Type type, byte[] packet) {
if (mTracingEnabled) {
Trace.beginAsyncSection("SEND_PACKET_FROM_HAL", 2);
}
// Create a combined data buffer so we can write it out in a single call.
byte[] data = new byte[packet.length + 1];
data[0] = type.value;
@@ -97,9 +88,5 @@ public class HciServer {
Log.d(TAG, "no client, dropping packet");
}
}
if (mTracingEnabled) {
Trace.endAsyncSection("SEND_PACKET_FROM_HAL", 2);
}
}
}
+4 -5
View File
@@ -59,7 +59,6 @@ console_scripts =
bumble-ble-rpa-tool = bumble.apps.ble_rpa_tool:main
bumble-console = bumble.apps.console:main
bumble-controller-info = bumble.apps.controller_info:main
bumble-controller-loopback = bumble.apps.controller_loopback:main
bumble-gatt-dump = bumble.apps.gatt_dump:main
bumble-hci-bridge = bumble.apps.hci_bridge:main
bumble-l2cap-bridge = bumble.apps.l2cap_bridge:main
@@ -82,15 +81,15 @@ console_scripts =
build =
build >= 0.7
test =
pytest >= 8.0
pytest-asyncio == 0.21.1
pytest >= 6.2
pytest-asyncio >= 0.17
pytest-html >= 3.2.0
coverage >= 6.4
development =
black == 22.10
grpcio-tools >= 1.57.0
invoke >= 1.7.3
mypy == 1.8.0
mypy == 1.5.0
nox >= 2022
pylint == 2.15.8
pyyaml >= 6.0
@@ -99,7 +98,7 @@ development =
types-protobuf >= 4.21.0
avatar =
pandora-avatar == 0.0.5
rootcanal == 1.6.0 ; python_version>='3.10'
rootcanal == 1.3.0 ; python_version>='3.10'
documentation =
mkdocs >= 1.4.0
mkdocs-material >= 8.5.6
-246
View File
@@ -1,246 +0,0 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import struct
import pytest
from bumble import core
from bumble import device
from bumble import host
from bumble import controller
from bumble import link
from bumble import avc
from bumble import avrcp
from bumble import avctp
from bumble.transport import common
# -----------------------------------------------------------------------------
class TwoDevices:
def __init__(self):
self.connections = [None, None]
addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0']
self.link = link.LocalLink()
self.controllers = [
controller.Controller('C1', link=self.link, public_address=addresses[0]),
controller.Controller('C2', link=self.link, public_address=addresses[1]),
]
self.devices = [
device.Device(
address=addresses[0],
host=host.Host(
self.controllers[0], common.AsyncPipeSink(self.controllers[0])
),
),
device.Device(
address=addresses[1],
host=host.Host(
self.controllers[1], common.AsyncPipeSink(self.controllers[1])
),
),
]
self.devices[0].classic_enabled = True
self.devices[1].classic_enabled = True
self.connections = [None, None]
self.protocols = [None, None]
def on_connection(self, which, connection):
self.connections[which] = connection
async def setup_connections(self):
await self.devices[0].power_on()
await self.devices[1].power_on()
self.connections = await asyncio.gather(
self.devices[0].connect(
self.devices[1].public_address, core.BT_BR_EDR_TRANSPORT
),
self.devices[1].accept(self.devices[0].public_address),
)
self.protocols = [avrcp.Protocol(), avrcp.Protocol()]
self.protocols[0].listen(self.devices[1])
await self.protocols[1].connect(self.connections[0])
# -----------------------------------------------------------------------------
def test_frame_parser():
with pytest.raises(ValueError) as error:
avc.Frame.from_bytes(bytes.fromhex("11480000"))
x = bytes.fromhex("014D0208")
frame = avc.Frame.from_bytes(x)
assert frame.subunit_type == avc.Frame.SubunitType.PANEL
assert frame.subunit_id == 7
assert frame.opcode == 8
x = bytes.fromhex("014DFF0108")
frame = avc.Frame.from_bytes(x)
assert frame.subunit_type == avc.Frame.SubunitType.PANEL
assert frame.subunit_id == 260
assert frame.opcode == 8
x = bytes.fromhex("0148000019581000000103")
frame = avc.Frame.from_bytes(x)
assert isinstance(frame, avc.CommandFrame)
assert frame.ctype == avc.CommandFrame.CommandType.STATUS
assert frame.subunit_type == avc.Frame.SubunitType.PANEL
assert frame.subunit_id == 0
assert frame.opcode == 0
# -----------------------------------------------------------------------------
def test_vendor_dependent_command():
x = bytes.fromhex("0148000019581000000103")
frame = avc.Frame.from_bytes(x)
assert isinstance(frame, avc.VendorDependentCommandFrame)
assert frame.company_id == 0x1958
assert frame.vendor_dependent_data == bytes.fromhex("1000000103")
frame = avc.VendorDependentCommandFrame(
avc.CommandFrame.CommandType.STATUS,
avc.Frame.SubunitType.PANEL,
0,
0x1958,
bytes.fromhex("1000000103"),
)
assert bytes(frame) == x
# -----------------------------------------------------------------------------
def test_avctp_message_assembler():
received_message = []
def on_message(transaction_label, is_response, ipid, pid, payload):
received_message.append((transaction_label, is_response, ipid, pid, payload))
assembler = avctp.MessageAssembler(on_message)
payload = bytes.fromhex("01")
assembler.on_pdu(bytes([1 << 4 | 0b00 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload)
assert received_message
assert received_message[0] == (1, False, False, 0x1122, payload)
received_message = []
payload = bytes.fromhex("010203")
assembler.on_pdu(bytes([1 << 4 | 0b01 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload)
assert len(received_message) == 0
assembler.on_pdu(bytes([1 << 4 | 0b00 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload)
assert received_message
assert received_message[0] == (1, False, False, 0x1122, payload)
received_message = []
payload = bytes.fromhex("010203")
assembler.on_pdu(
bytes([1 << 4 | 0b01 << 2 | 1 << 1 | 0, 3, 0x11, 0x22]) + payload[0:1]
)
assembler.on_pdu(
bytes([1 << 4 | 0b10 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload[1:2]
)
assembler.on_pdu(
bytes([1 << 4 | 0b11 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload[2:3]
)
assert received_message
assert received_message[0] == (1, False, False, 0x1122, payload)
# received_message = []
# parameter = bytes.fromhex("010203")
# assembler.on_pdu(struct.pack(">BBH", 0x10, 0b11, len(parameter)) + parameter)
# assert len(received_message) == 0
# -----------------------------------------------------------------------------
def test_avrcp_pdu_assembler():
received_pdus = []
def on_pdu(pdu_id, parameter):
received_pdus.append((pdu_id, parameter))
assembler = avrcp.PduAssembler(on_pdu)
parameter = bytes.fromhex("01")
assembler.on_pdu(struct.pack(">BBH", 0x10, 0b00, len(parameter)) + parameter)
assert received_pdus
assert received_pdus[0] == (0x10, parameter)
received_pdus = []
parameter = bytes.fromhex("010203")
assembler.on_pdu(struct.pack(">BBH", 0x10, 0b01, len(parameter)) + parameter)
assert len(received_pdus) == 0
assembler.on_pdu(struct.pack(">BBH", 0x10, 0b00, len(parameter)) + parameter)
assert received_pdus
assert received_pdus[0] == (0x10, parameter)
received_pdus = []
parameter = bytes.fromhex("010203")
assembler.on_pdu(struct.pack(">BBH", 0x10, 0b01, 1) + parameter[0:1])
assembler.on_pdu(struct.pack(">BBH", 0x10, 0b10, 1) + parameter[1:2])
assembler.on_pdu(struct.pack(">BBH", 0x10, 0b11, 1) + parameter[2:3])
assert received_pdus
assert received_pdus[0] == (0x10, parameter)
received_pdus = []
parameter = bytes.fromhex("010203")
assembler.on_pdu(struct.pack(">BBH", 0x10, 0b11, len(parameter)) + parameter)
assert len(received_pdus) == 0
def test_passthrough_commands():
play_pressed = avc.PassThroughCommandFrame(
avc.CommandFrame.CommandType.CONTROL,
avc.CommandFrame.SubunitType.PANEL,
0,
avc.PassThroughCommandFrame.StateFlag.PRESSED,
avc.PassThroughCommandFrame.OperationId.PLAY,
b'',
)
play_pressed_bytes = bytes(play_pressed)
parsed = avc.Frame.from_bytes(play_pressed_bytes)
assert isinstance(parsed, avc.PassThroughCommandFrame)
assert parsed.operation_id == avc.PassThroughCommandFrame.OperationId.PLAY
assert bytes(parsed) == play_pressed_bytes
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_supported_events():
two_devices = TwoDevices()
await two_devices.setup_connections()
supported_events = await two_devices.protocols[0].get_supported_events()
assert supported_events == []
delegate1 = avrcp.Delegate([avrcp.EventId.VOLUME_CHANGED])
two_devices.protocols[0].delegate = delegate1
supported_events = await two_devices.protocols[1].get_supported_events()
assert supported_events == [avrcp.EventId.VOLUME_CHANGED]
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_frame_parser()
test_vendor_dependent_command()
test_avctp_message_assembler()
test_avrcp_pdu_assembler()
test_passthrough_commands()
test_get_supported_events()
+1 -2
View File
@@ -48,8 +48,7 @@ from bumble.profiles.bap import (
PublishedAudioCapabilitiesService,
PublishedAudioCapabilitiesServiceProxy,
)
from tests.test_utils import TwoDevices
from .test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Logging
+138
View File
@@ -0,0 +1,138 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import pytest
from typing import List, Optional
from unittest.mock import MagicMock
from bumble.device import Connection, Device
from bumble.host import Host
from bumble.link import LocalLink
from bumble.controller import Controller
from bumble.hci import (
Address,
HCI_CONNECTION_TERMINATED_BY_LOCAL_HOST_ERROR,
HCI_REMOTE_DEVICE_TERMINATED_CONNECTION_DUE_TO_LOW_RESOURCES_ERROR,
)
from bumble.transport import AsyncPipeSink
# -----------------------------------------------------------------------------
class TwoDevices:
connections: List[Optional[Connection]]
def __init__(self) -> None:
self.connections = [None, None]
self.link = LocalLink()
self.controllers = [
Controller('C1', link=self.link),
Controller('C2', link=self.link),
]
self.devices = [
Device(
address=Address('F0:F1:F2:F3:F4:F5'),
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
),
Device(
address=Address('F5:F4:F3:F2:F1:F0'),
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
),
]
self.paired = [None, None]
def on_connection(self, which, connection):
self.connections[which] = connection
connection.on(
'disconnection', lambda reason: self.on_disconnection(which, reason)
)
def on_disconnection(self, which, _):
self.connections[which] = None
async def setup(self):
self.devices[0].on(
'connection', lambda connection: self.on_connection(0, connection)
)
self.devices[1].on(
'connection', lambda connection: self.on_connection(1, connection)
)
await self.devices[0].power_on()
await self.devices[1].power_on()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_connection():
two_devices = TwoDevices()
await two_devices.setup()
await two_devices.devices[0].connect(two_devices.devices[1].random_address)
assert two_devices.connections[0] is not None
assert two_devices.connections[1] is not None
mock0 = MagicMock()
mock1 = MagicMock()
two_devices.connections[0].once('disconnection', mock0)
two_devices.connections[1].once('disconnection', mock1)
await two_devices.connections[0].disconnect(
HCI_REMOTE_DEVICE_TERMINATED_CONNECTION_DUE_TO_LOW_RESOURCES_ERROR
)
mock0.assert_called_once_with(HCI_CONNECTION_TERMINATED_BY_LOCAL_HOST_ERROR)
mock1.assert_called_once_with(
HCI_REMOTE_DEVICE_TERMINATED_CONNECTION_DUE_TO_LOW_RESOURCES_ERROR
)
assert two_devices.connections[0] is None
assert two_devices.connections[1] is None
await two_devices.devices[0].connect(two_devices.devices[1].random_address)
assert two_devices.connections[0] is not None
assert two_devices.connections[1] is not None
mock0 = MagicMock()
mock1 = MagicMock()
two_devices.connections[0].once('disconnection', mock0)
two_devices.connections[1].once('disconnection', mock1)
await two_devices.connections[1].disconnect(
HCI_REMOTE_DEVICE_TERMINATED_CONNECTION_DUE_TO_LOW_RESOURCES_ERROR
)
mock1.assert_called_once_with(HCI_CONNECTION_TERMINATED_BY_LOCAL_HOST_ERROR)
mock0.assert_called_once_with(
HCI_REMOTE_DEVICE_TERMINATED_CONNECTION_DUE_TO_LOW_RESOURCES_ERROR
)
assert two_devices.connections[0] is None
assert two_devices.connections[1] is None
# -----------------------------------------------------------------------------
async def run_test_controller():
await test_self_connection()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run_test_controller())
+6 -15
View File
@@ -20,7 +20,6 @@ import os
import pytest
import struct
import logging
from unittest import mock
from bumble import device
from bumble.profiles import csip
@@ -69,18 +68,14 @@ def test_sef():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
'sirk_type,', [(csip.SirkType.ENCRYPTED), (csip.SirkType.PLAINTEXT)]
)
async def test_csis(sirk_type):
async def test_csis():
SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
LTK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
devices = TwoDevices()
devices[0].add_service(
csip.CoordinatedSetIdentificationService(
set_identity_resolving_key=SIRK,
set_identity_resolving_key_type=sirk_type,
set_identity_resolving_key_type=csip.SirkType.PLAINTEXT,
coordinated_set_size=2,
set_member_lock=csip.MemberLock.UNLOCKED,
set_member_rank=0,
@@ -88,19 +83,15 @@ async def test_csis(sirk_type):
)
await devices.setup_connection()
# Mock encryption.
devices.connections[0].encryption = 1
devices.connections[1].encryption = 1
devices[0].get_long_term_key = mock.AsyncMock(return_value=LTK)
devices[1].get_long_term_key = mock.AsyncMock(return_value=LTK)
peer = device.Peer(devices.connections[1])
csis_client = await peer.discover_service_and_create_proxy(
csip.CoordinatedSetIdentificationProxy
)
assert await csis_client.read_set_identity_resolving_key() == (sirk_type, SIRK)
assert (
await csis_client.set_identity_resolving_key.read_value()
== bytes([csip.SirkType.PLAINTEXT]) + SIRK
)
assert await csis_client.coordinated_set_size.read_value() == struct.pack('B', 2)
assert await csis_client.set_member_lock.read_value() == struct.pack(
'B', csip.MemberLock.UNLOCKED
+1 -67
View File
@@ -29,7 +29,7 @@ from bumble.core import (
ConnectionParameters,
)
from bumble.device import Connection, Device
from bumble.host import AclPacketQueue, Host
from bumble.host import Host
from bumble.hci import (
HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
HCI_COMMAND_STATUS_PENDING,
@@ -50,8 +50,6 @@ from bumble.gatt import (
GATT_APPEARANCE_CHARACTERISTIC,
)
from .test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -75,13 +73,6 @@ async def test_device_connect_parallel():
d1 = Device(host=Host(None, None))
d2 = Device(host=Host(None, None))
def _send(packet):
pass
d0.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
d1.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
d2.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
# enable classic
d0.classic_enabled = True
d1.classic_enabled = True
@@ -414,63 +405,6 @@ async def test_extended_advertising_disconnection(auto_restart):
device.start_extended_advertising.assert_not_called()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_remote_le_features():
devices = TwoDevices()
await devices.setup_connection()
assert (await devices.connections[0].get_remote_le_features()) is not None
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_cis():
devices = TwoDevices()
await devices.setup_connection()
peripheral_cis_futures = {}
def on_cis_request(
acl_connection: Connection,
cis_handle: int,
_cig_id: int,
_cis_id: int,
):
acl_connection.abort_on(
'disconnection', devices[1].accept_cis_request(cis_handle)
)
peripheral_cis_futures[cis_handle] = asyncio.get_running_loop().create_future()
devices[1].on('cis_request', on_cis_request)
devices[1].on(
'cis_establishment',
lambda cis_link: peripheral_cis_futures[cis_link.handle].set_result(None),
)
cis_handles = await devices[0].setup_cig(
cig_id=1,
cis_id=[2, 3],
sdu_interval=(0, 0),
framing=0,
max_sdu=(0, 0),
retransmission_number=0,
max_transport_latency=(0, 0),
)
assert len(cis_handles) == 2
cis_links = await devices[0].create_cis(
[
(cis_handles[0], devices.connections[0].handle),
(cis_handles[1], devices.connections[0].handle),
]
)
await asyncio.gather(*peripheral_cis_futures.values())
assert len(cis_links) == 2
await cis_links[0].disconnect()
await cis_links[1].disconnect()
# -----------------------------------------------------------------------------
def test_gatt_services_with_gas():
device = Device(host=Host(None, None))
+31 -76
View File
@@ -20,10 +20,11 @@ import logging
import os
import struct
import pytest
from unittest.mock import AsyncMock, Mock, ANY
from unittest.mock import Mock, ANY
from bumble.controller import Controller
from bumble.gatt_client import CharacteristicProxy
from bumble.gatt_server import Server
from bumble.link import LocalLink
from bumble.device import Device, Peer
from bumble.host import Host
@@ -119,9 +120,9 @@ async def test_characteristic_encoding():
Characteristic.READABLE,
123,
)
x = await c.read_value(None)
x = c.read_value(None)
assert x == bytes([123])
await c.write_value(None, bytes([122]))
c.write_value(None, bytes([122]))
assert c.value == 122
class FooProxy(CharacteristicProxy):
@@ -151,22 +152,7 @@ async def test_characteristic_encoding():
bytes([123]),
)
async def async_read(connection):
return 0x05060708
async_characteristic = PackedCharacteristicAdapter(
Characteristic(
'2AB7E91B-43E8-4F73-AC3B-80C1683B47F9',
Characteristic.Properties.READ,
Characteristic.READABLE,
CharacteristicValue(read=async_read),
),
'>I',
)
service = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic, async_characteristic]
)
service = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic])
server.add_service(service)
await client.power_on()
@@ -198,13 +184,6 @@ async def test_characteristic_encoding():
await async_barrier()
assert characteristic.value == bytes([50])
c2 = peer.get_characteristics_by_uuid(async_characteristic.uuid)
assert len(c2) == 1
c2 = c2[0]
cd2 = PackedCharacteristicAdapter(c2, ">I")
cd2v = await cd2.read_value()
assert cd2v == 0x05060708
last_change = None
def on_change(value):
@@ -306,8 +285,7 @@ async def test_attribute_getters():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_CharacteristicAdapter():
def test_CharacteristicAdapter():
# Check that the CharacteristicAdapter base class is transparent
v = bytes([1, 2, 3])
c = Characteristic(
@@ -318,11 +296,11 @@ async def test_CharacteristicAdapter():
)
a = CharacteristicAdapter(c)
value = await a.read_value(None)
value = a.read_value(None)
assert value == v
v = bytes([3, 4, 5])
await a.write_value(None, v)
a.write_value(None, v)
assert c.value == v
# Simple delegated adapter
@@ -330,11 +308,11 @@ async def test_CharacteristicAdapter():
c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x))
)
value = await a.read_value(None)
value = a.read_value(None)
assert value == bytes(reversed(v))
v = bytes([3, 4, 5])
await a.write_value(None, v)
a.write_value(None, v)
assert a.value == bytes(reversed(v))
# Packed adapter with single element format
@@ -343,10 +321,10 @@ async def test_CharacteristicAdapter():
c.value = v
a = PackedCharacteristicAdapter(c, '>H')
value = await a.read_value(None)
value = a.read_value(None)
assert value == pv
c.value = None
await a.write_value(None, pv)
a.write_value(None, pv)
assert a.value == v
# Packed adapter with multi-element format
@@ -356,10 +334,10 @@ async def test_CharacteristicAdapter():
c.value = (v1, v2)
a = PackedCharacteristicAdapter(c, '>HH')
value = await a.read_value(None)
value = a.read_value(None)
assert value == pv
c.value = None
await a.write_value(None, pv)
a.write_value(None, pv)
assert a.value == (v1, v2)
# Mapped adapter
@@ -370,10 +348,10 @@ async def test_CharacteristicAdapter():
c.value = mapped
a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2'))
value = await a.read_value(None)
value = a.read_value(None)
assert value == pv
c.value = None
await a.write_value(None, pv)
a.write_value(None, pv)
assert a.value == mapped
# UTF-8 adapter
@@ -382,49 +360,27 @@ async def test_CharacteristicAdapter():
c.value = v
a = UTF8CharacteristicAdapter(c)
value = await a.read_value(None)
value = a.read_value(None)
assert value == ev
c.value = None
await a.write_value(None, ev)
a.write_value(None, ev)
assert a.value == v
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_CharacteristicValue():
def test_CharacteristicValue():
b = bytes([1, 2, 3])
async def read_value(connection):
return b
c = CharacteristicValue(read=read_value)
x = await c.read(None)
c = CharacteristicValue(read=lambda _: b)
x = c.read(None)
assert x == b
m = Mock()
c = CharacteristicValue(write=m)
result = []
c = CharacteristicValue(
write=lambda connection, value: result.append((connection, value))
)
z = object()
c.write(z, b)
m.assert_called_once_with(z, b)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_CharacteristicValue_async():
b = bytes([1, 2, 3])
async def read_value(connection):
return b
c = CharacteristicValue(read=read_value)
x = await c.read(None)
assert x == b
m = AsyncMock()
c = CharacteristicValue(write=m)
z = object()
await c.write(z, b)
m.assert_called_once_with(z, b)
assert result == [(z, b)]
# -----------------------------------------------------------------------------
@@ -1005,18 +961,12 @@ Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration
# -----------------------------------------------------------------------------
async def async_main():
test_UUID()
test_ATT_Error_Response()
test_ATT_Read_By_Group_Type_Request()
await test_read_write()
await test_read_write2()
await test_subscribe_notify()
await test_unsubscribe()
await test_characteristic_encoding()
await test_mtu_exchange()
await test_CharacteristicValue()
await test_CharacteristicValue_async()
await test_CharacteristicAdapter()
# -----------------------------------------------------------------------------
@@ -1155,4 +1105,9 @@ def test_get_attribute_group():
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
test_UUID()
test_ATT_Error_Response()
test_ATT_Read_By_Group_Type_Request()
test_CharacteristicValue()
test_CharacteristicAdapter()
asyncio.run(async_main())
-64
View File
@@ -23,10 +23,8 @@ import pytest
from typing import Tuple
from .test_utils import TwoDevices
from bumble import core
from bumble import hfp
from bumble import rfcomm
from bumble import hci
# -----------------------------------------------------------------------------
@@ -89,68 +87,6 @@ async def test_slc():
ag_task.cancel()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_sco_setup():
devices = TwoDevices()
# Enable Classic connections
devices[0].classic_enabled = True
devices[1].classic_enabled = True
# Start
await devices[0].power_on()
await devices[1].power_on()
connections = await asyncio.gather(
devices[0].connect(
devices[1].public_address, transport=core.BT_BR_EDR_TRANSPORT
),
devices[1].accept(devices[0].public_address),
)
def on_sco_request(_connection, _link_type: int):
connections[1].abort_on(
'disconnection',
devices[1].send_command(
hci.HCI_Enhanced_Accept_Synchronous_Connection_Request_Command(
bd_addr=connections[1].peer_address,
**hfp.ESCO_PARAMETERS[
hfp.DefaultCodecParameters.ESCO_CVSD_S1
].asdict(),
)
),
)
devices[1].on('sco_request', on_sco_request)
sco_connection_futures = [
asyncio.get_running_loop().create_future(),
asyncio.get_running_loop().create_future(),
]
for device, future in zip(devices, sco_connection_futures):
device.on('sco_connection', future.set_result)
await devices[0].send_command(
hci.HCI_Enhanced_Setup_Synchronous_Connection_Command(
connection_handle=connections[0].handle,
**hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_CVSD_S1].asdict(),
)
)
sco_connections = await asyncio.gather(*sco_connection_futures)
sco_disconnection_futures = [
asyncio.get_running_loop().create_future(),
asyncio.get_running_loop().create_future(),
]
for future, sco_connection in zip(sco_disconnection_futures, sco_connections):
sco_connection.on('disconnection', future.set_result)
await sco_connections[0].disconnect()
await asyncio.gather(*sco_disconnection_futures)
# -----------------------------------------------------------------------------
async def run():
await test_slc()
+1 -28
View File
@@ -15,11 +15,7 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import pytest
from . import test_utils
from bumble.rfcomm import RFCOMM_Frame, Server, Client, DLC
from bumble.rfcomm import RFCOMM_Frame
# -----------------------------------------------------------------------------
@@ -47,29 +43,6 @@ def test_frames():
basic_frame_check(frame)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_basic_connection():
devices = test_utils.TwoDevices()
await devices.setup_connection()
accept_future: asyncio.Future[DLC] = asyncio.get_running_loop().create_future()
channel = Server(devices[0]).listen(acceptor=accept_future.set_result)
multiplexer = await Client(devices.connections[1]).start()
dlcs = await asyncio.gather(accept_future, multiplexer.open_dlc(channel))
queues = [asyncio.Queue(), asyncio.Queue()]
for dlc, queue in zip(dlcs, queues):
dlc.sink = queue.put_nowait
dlcs[0].write(b'The quick brown fox jumps over the lazy dog')
assert await queues[1].get() == b'The quick brown fox jumps over the lazy dog'
dlcs[1].write(b'Lorem ipsum dolor sit amet')
assert await queues[0].get() == b'Lorem ipsum dolor sit amet'
# -----------------------------------------------------------------------------
if __name__ == '__main__':
test_frames()
-7
View File
@@ -547,13 +547,6 @@ async def test_self_smp_over_classic():
MockSmpSession.send_public_key_command.assert_not_called()
MockSmpSession.send_pairing_random_command.assert_not_called()
for i in range(2):
assert (
await two_devices.devices[i].keystore.get(
str(two_devices.connections[i].peer_address)
)
).link_key
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
+4 -5
View File
@@ -29,18 +29,17 @@ class TwoDevices:
self.connections = [None, None]
self.link = LocalLink()
addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0']
self.controllers = [
Controller('C1', link=self.link, public_address=addresses[0]),
Controller('C2', link=self.link, public_address=addresses[1]),
Controller('C1', link=self.link),
Controller('C2', link=self.link),
]
self.devices = [
Device(
address=Address(addresses[0]),
address=Address('F0:F1:F2:F3:F4:F5'),
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
),
Device(
address=Address(addresses[1]),
address=Address('F5:F4:F3:F2:F1:F0'),
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
),
]
+2 -34
View File
@@ -12,20 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import contextlib
import logging
import os
from unittest.mock import MagicMock
from pyee import EventEmitter
from bumble import utils
from pyee import EventEmitter
from unittest.mock import MagicMock
# -----------------------------------------------------------------------------
def test_on() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
@@ -38,7 +33,6 @@ def test_on() -> None:
assert mock.call_count == 1
# -----------------------------------------------------------------------------
def test_on_decorator() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
@@ -54,7 +48,6 @@ def test_on_decorator() -> None:
assert mock.call_count == 1
# -----------------------------------------------------------------------------
def test_multiple_handlers() -> None:
emitter = EventEmitter()
with contextlib.closing(utils.EventWatcher()) as context:
@@ -71,30 +64,6 @@ def test_multiple_handlers() -> None:
mock.assert_called_once_with('b')
# -----------------------------------------------------------------------------
def test_open_int_enums():
class Foo(utils.OpenIntEnum):
FOO = 1
BAR = 2
BLA = 3
x = Foo(1)
assert x.name == "FOO"
assert x.value == 1
assert int(x) == 1
assert x == 1
assert x + 1 == 2
x = Foo(4)
assert x.name == "Foo[4]"
assert x.value == 4
assert int(x) == 4
assert x == 4
assert x + 1 == 5
print(list(Foo))
# -----------------------------------------------------------------------------
def run_tests():
test_on()
@@ -106,4 +75,3 @@ def run_tests():
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
run_tests()
test_open_int_enums()