Compare commits

..

1 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod d43281c57e allow passing IRKs as arguments 2023-12-28 14:35:23 -08:00
23 changed files with 298 additions and 938 deletions
-2
View File
@@ -10,5 +10,3 @@ __pycache__
bumble/_version.py
.vscode/launch.json
/.idea
venv/
.venv/
-1
View File
@@ -22,7 +22,6 @@
"cmac",
"CONNECTIONLESS",
"csip",
"csis",
"csrcs",
"CVSD",
"datagram",
+66 -300
View File
@@ -80,10 +80,10 @@ SPEED_TX_UUID = 'E789C754-41A1-45F4-A948-A0A1A90DBA53'
SPEED_RX_UUID = '016A2CC7-E14B-4819-935F-1F56EAE4098D'
DEFAULT_RFCOMM_UUID = 'E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'
DEFAULT_L2CAP_PSM = 128
DEFAULT_L2CAP_PSM = 1234
DEFAULT_L2CAP_MAX_CREDITS = 128
DEFAULT_L2CAP_MTU = 1024
DEFAULT_L2CAP_MPS = 1024
DEFAULT_L2CAP_MPS = 1022
DEFAULT_LINGER_TIME = 1.0
DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0
@@ -240,23 +240,6 @@ async def find_rfcomm_channel_with_uuid(connection: Connection, uuid: str) -> in
return 0
def log_stats(title, stats):
stats_min = min(stats)
stats_max = max(stats)
stats_avg = sum(stats) / len(stats)
logging.info(
color(
(
f'### {title} stats: '
f'min={stats_min:.2f}, '
f'max={stats_max:.2f}, '
f'average={stats_avg:.2f}'
),
'cyan',
)
)
class PacketType(enum.IntEnum):
RESET = 0
SEQUENCE = 1
@@ -270,27 +253,14 @@ PACKET_FLAG_LAST = 1
# Sender
# -----------------------------------------------------------------------------
class Sender:
def __init__(
self,
packet_io,
start_delay,
repeat,
repeat_delay,
pace,
packet_size,
packet_count,
):
def __init__(self, packet_io, start_delay, packet_size, packet_count):
self.tx_start_delay = start_delay
self.tx_packet_size = packet_size
self.tx_packet_count = packet_count
self.packet_io = packet_io
self.packet_io.packet_listener = self
self.repeat = repeat
self.repeat_delay = repeat_delay
self.pace = pace
self.start_time = 0
self.bytes_sent = 0
self.stats = []
self.done = asyncio.Event()
def reset(self):
@@ -301,57 +271,27 @@ class Sender:
await self.packet_io.ready.wait()
logging.info(color('--- Go!', 'blue'))
for run in range(self.repeat + 1):
self.done.clear()
if self.tx_start_delay:
logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
await asyncio.sleep(self.tx_start_delay)
if run > 0 and self.repeat and self.repeat_delay:
logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green'))
await asyncio.sleep(self.repeat_delay)
logging.info(color('=== Sending RESET', 'magenta'))
await self.packet_io.send_packet(bytes([PacketType.RESET]))
self.start_time = time.time()
for tx_i in range(self.tx_packet_count):
packet_flags = PACKET_FLAG_LAST if tx_i == self.tx_packet_count - 1 else 0
packet = struct.pack(
'>bbI',
PacketType.SEQUENCE,
packet_flags,
tx_i,
) + bytes(self.tx_packet_size - 6)
logging.info(color(f'Sending packet {tx_i}: {len(packet)} bytes', 'yellow'))
self.bytes_sent += len(packet)
await self.packet_io.send_packet(packet)
if self.tx_start_delay:
logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
await asyncio.sleep(self.tx_start_delay)
logging.info(color('=== Sending RESET', 'magenta'))
await self.packet_io.send_packet(bytes([PacketType.RESET]))
self.start_time = time.time()
self.bytes_sent = 0
for tx_i in range(self.tx_packet_count):
packet_flags = (
PACKET_FLAG_LAST if tx_i == self.tx_packet_count - 1 else 0
)
packet = struct.pack(
'>bbI',
PacketType.SEQUENCE,
packet_flags,
tx_i,
) + bytes(self.tx_packet_size - 6 - self.packet_io.overhead_size)
logging.info(
color(
f'Sending packet {tx_i}: {self.tx_packet_size} bytes', 'yellow'
)
)
self.bytes_sent += len(packet)
await self.packet_io.send_packet(packet)
if self.pace is None:
continue
if self.pace > 0:
await asyncio.sleep(self.pace / 1000)
else:
await self.packet_io.drain()
await self.done.wait()
run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
logging.info(color(f'=== {run_counter} Done!', 'magenta'))
if self.repeat:
log_stats('Run', self.stats)
if self.repeat:
logging.info(color('--- End of runs', 'blue'))
await self.done.wait()
logging.info(color('=== Done!', 'magenta'))
def on_packet_received(self, packet):
try:
@@ -362,7 +302,6 @@ class Sender:
if packet_type == PacketType.ACK:
elapsed = time.time() - self.start_time
average_tx_speed = self.bytes_sent / elapsed
self.stats.append(average_tx_speed)
logging.info(
color(
f'@@@ Received ACK. Speed: average={average_tx_speed:.4f}'
@@ -381,17 +320,17 @@ class Receiver:
start_timestamp: float
last_timestamp: float
def __init__(self, packet_io, linger):
def __init__(self, packet_io):
self.reset()
self.packet_io = packet_io
self.packet_io.packet_listener = self
self.linger = linger
self.done = asyncio.Event()
def reset(self):
self.expected_packet_index = 0
self.measurements = [(time.time(), 0)]
self.total_bytes_received = 0
self.start_timestamp = 0.0
self.last_timestamp = 0.0
self.bytes_received = 0
def on_packet_received(self, packet):
try:
@@ -399,9 +338,12 @@ class Receiver:
except ValueError:
return
now = time.time()
if packet_type == PacketType.RESET:
logging.info(color('=== Received RESET', 'magenta'))
self.reset()
self.start_timestamp = now
return
try:
@@ -410,8 +352,7 @@ class Receiver:
return
logging.info(
f'<<< Received packet {packet_index}: '
f'flags=0x{packet_flags:02X}, '
f'{len(packet) + self.packet_io.overhead_size} bytes'
f'flags=0x{packet_flags:02X}, {len(packet)} bytes'
)
if packet_index != self.expected_packet_index:
@@ -422,27 +363,19 @@ class Receiver:
)
)
now = time.time()
elapsed_since_start = now - self.measurements[0][0]
elapsed_since_last = now - self.measurements[-1][0]
self.measurements.append((now, len(packet)))
self.total_bytes_received += len(packet)
elapsed_since_start = now - self.start_timestamp
elapsed_since_last = now - self.last_timestamp
self.bytes_received += len(packet)
instant_rx_speed = len(packet) / elapsed_since_last
average_rx_speed = self.total_bytes_received / elapsed_since_start
window = self.measurements[-64:]
windowed_rx_speed = sum(measurement[1] for measurement in window[1:]) / (
window[-1][0] - window[0][0]
)
average_rx_speed = self.bytes_received / elapsed_since_start
logging.info(
color(
'Speed: '
f'instant={instant_rx_speed:.4f}, '
f'windowed={windowed_rx_speed:.4f}, '
f'average={average_rx_speed:.4f}',
f'Speed: instant={instant_rx_speed:.4f}, average={average_rx_speed:.4f}',
'yellow',
)
)
self.last_timestamp = now
self.expected_packet_index = packet_index + 1
if packet_flags & PACKET_FLAG_LAST:
@@ -452,8 +385,7 @@ class Receiver:
)
)
logging.info(color('@@@ Received last packet', 'green'))
if not self.linger:
self.done.set()
self.done.set()
async def run(self):
await self.done.wait()
@@ -464,31 +396,16 @@ class Receiver:
# Ping
# -----------------------------------------------------------------------------
class Ping:
def __init__(
self,
packet_io,
start_delay,
repeat,
repeat_delay,
pace,
packet_size,
packet_count,
):
def __init__(self, packet_io, start_delay, packet_size, packet_count):
self.tx_start_delay = start_delay
self.tx_packet_size = packet_size
self.tx_packet_count = packet_count
self.packet_io = packet_io
self.packet_io.packet_listener = self
self.repeat = repeat
self.repeat_delay = repeat_delay
self.pace = pace
self.done = asyncio.Event()
self.current_packet_index = 0
self.ping_sent_time = 0.0
self.latencies = []
self.min_stats = []
self.max_stats = []
self.avg_stats = []
def reset(self):
pass
@@ -498,56 +415,21 @@ class Ping:
await self.packet_io.ready.wait()
logging.info(color('--- Go!', 'blue'))
for run in range(self.repeat + 1):
self.done.clear()
if self.tx_start_delay:
logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
await asyncio.sleep(self.tx_start_delay)
if run > 0 and self.repeat and self.repeat_delay:
logging.info(color(f'*** Repeat delay: {self.repeat_delay}', 'green'))
await asyncio.sleep(self.repeat_delay)
logging.info(color('=== Sending RESET', 'magenta'))
await self.packet_io.send_packet(bytes([PacketType.RESET]))
if self.tx_start_delay:
logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
await asyncio.sleep(self.tx_start_delay)
await self.send_next_ping()
logging.info(color('=== Sending RESET', 'magenta'))
await self.packet_io.send_packet(bytes([PacketType.RESET]))
self.current_packet_index = 0
await self.send_next_ping()
await self.done.wait()
min_latency = min(self.latencies)
max_latency = max(self.latencies)
avg_latency = sum(self.latencies) / len(self.latencies)
logging.info(
color(
'@@@ Latencies: '
f'min={min_latency:.2f}, '
f'max={max_latency:.2f}, '
f'average={avg_latency:.2f}'
)
)
self.min_stats.append(min_latency)
self.max_stats.append(max_latency)
self.avg_stats.append(avg_latency)
run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
logging.info(color(f'=== {run_counter} Done!', 'magenta'))
if self.repeat:
log_stats('Min Latency', self.min_stats)
log_stats('Max Latency', self.max_stats)
log_stats('Average Latency', self.avg_stats)
if self.repeat:
logging.info(color('--- End of runs', 'blue'))
await self.done.wait()
average_latency = sum(self.latencies) / len(self.latencies)
logging.info(color(f'@@@ Average latency: {average_latency:.2f}'))
logging.info(color('=== Done!', 'magenta'))
async def send_next_ping(self):
if self.pace:
await asyncio.sleep(self.pace / 1000)
packet = struct.pack(
'>bbI',
PacketType.SEQUENCE,
@@ -606,11 +488,10 @@ class Ping:
class Pong:
expected_packet_index: int
def __init__(self, packet_io, linger):
def __init__(self, packet_io):
self.reset()
self.packet_io = packet_io
self.packet_io.packet_listener = self
self.linger = linger
self.done = asyncio.Event()
def reset(self):
@@ -655,7 +536,7 @@ class Pong:
)
)
if packet_flags & PACKET_FLAG_LAST and not self.linger:
if packet_flags & PACKET_FLAG_LAST:
self.done.set()
async def run(self):
@@ -673,7 +554,6 @@ class GattClient:
self.speed_tx = None
self.packet_listener = None
self.ready = asyncio.Event()
self.overhead_size = 0
async def on_connection(self, connection):
peer = Peer(connection)
@@ -723,9 +603,6 @@ class GattClient:
async def send_packet(self, packet):
await self.speed_tx.write_value(packet)
async def drain(self):
pass
# -----------------------------------------------------------------------------
# GattServer
@@ -735,7 +612,6 @@ class GattServer:
self.device = device
self.packet_listener = None
self.ready = asyncio.Event()
self.overhead_size = 0
# Setup the GATT service
self.speed_tx = Characteristic(
@@ -777,9 +653,6 @@ class GattServer:
async def send_packet(self, packet):
await self.device.notify_subscribers(self.speed_rx, packet)
async def drain(self):
pass
# -----------------------------------------------------------------------------
# StreamedPacketIO
@@ -791,7 +664,6 @@ class StreamedPacketIO:
self.rx_packet = b''
self.rx_packet_header = b''
self.rx_packet_need = 0
self.overhead_size = 2
def on_packet(self, packet):
while packet:
@@ -843,7 +715,6 @@ class L2capClient(StreamedPacketIO):
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.l2cap_channel = None
self.ready = asyncio.Event()
async def on_connection(self, connection: Connection) -> None:
@@ -865,10 +736,9 @@ class L2capClient(StreamedPacketIO):
logging.info(color(f'!!! Connection failed: {error}', 'red'))
return
self.io_sink = l2cap_channel.write
self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_packet
l2cap_channel.on('close', self.on_l2cap_close)
self.io_sink = l2cap_channel.write
self.ready.set()
@@ -878,10 +748,6 @@ class L2capClient(StreamedPacketIO):
def on_l2cap_close(self):
logging.info(color('*** L2CAP channel closed', 'red'))
async def drain(self):
assert self.l2cap_channel
await self.l2cap_channel.drain()
# -----------------------------------------------------------------------------
# L2capServer
@@ -920,7 +786,6 @@ class L2capServer(StreamedPacketIO):
logging.info(color(f'*** L2CAP channel: {l2cap_channel}', 'cyan'))
self.io_sink = l2cap_channel.write
self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_packet
@@ -930,10 +795,6 @@ class L2capServer(StreamedPacketIO):
logging.info(color('*** L2CAP channel closed', 'red'))
self.l2cap_channel = None
async def drain(self):
assert self.l2cap_channel
await self.l2cap_channel.drain()
# -----------------------------------------------------------------------------
# RfcommClient
@@ -944,7 +805,6 @@ class RfcommClient(StreamedPacketIO):
self.device = device
self.channel = channel
self.uuid = uuid
self.rfcomm_session = None
self.ready = asyncio.Event()
async def on_connection(self, connection):
@@ -980,17 +840,12 @@ class RfcommClient(StreamedPacketIO):
rfcomm_session.sink = self.on_packet
self.io_sink = rfcomm_session.write
self.rfcomm_session = rfcomm_session
self.ready.set()
def on_disconnection(self, _):
pass
async def drain(self):
assert self.rfcomm_session
await self.rfcomm_session.drain()
# -----------------------------------------------------------------------------
# RfcommServer
@@ -998,7 +853,6 @@ class RfcommClient(StreamedPacketIO):
class RfcommServer(StreamedPacketIO):
def __init__(self, device, channel):
super().__init__()
self.dlc = None
self.ready = asyncio.Event()
# Create and register a server
@@ -1027,11 +881,6 @@ class RfcommServer(StreamedPacketIO):
logging.info(color(f'*** DLC connected: {dlc}', 'blue'))
dlc.sink = self.on_packet
self.io_sink = dlc.write
self.dlc = dlc
async def drain(self):
assert self.dlc
await self.dlc.drain()
# -----------------------------------------------------------------------------
@@ -1181,7 +1030,6 @@ class Central(Connection.Listener):
await role.run()
await asyncio.sleep(DEFAULT_LINGER_TIME)
await self.connection.disconnect()
def on_disconnection(self, reason):
logging.info(color(f'!!! Disconnection: reason={reason}', 'red'))
@@ -1272,8 +1120,12 @@ class Peripheral(Device.Listener, Connection.Listener):
# Stop being discoverable and connectable
if self.classic:
AsyncRunner.spawn(self.device.set_discoverable(False))
AsyncRunner.spawn(self.device.set_connectable(False))
async def stop_being_discoverable_connectable():
await self.device.set_discoverable(False)
await self.device.set_connectable(False)
AsyncRunner.spawn(stop_being_discoverable_connectable())
# Request a new data length if needed
if self.extended_data_length:
@@ -1289,10 +1141,6 @@ class Peripheral(Device.Listener, Connection.Listener):
self.connection = None
self.role.reset()
if self.classic:
AsyncRunner.spawn(self.device.set_discoverable(True))
AsyncRunner.spawn(self.device.set_connectable(True))
def on_connection_parameters_update(self):
print_connection(self.connection)
@@ -1320,22 +1168,10 @@ def create_mode_factory(ctx, default_mode):
return GattServer(device)
if mode == 'l2cap-client':
return L2capClient(
device,
psm=ctx.obj['l2cap_psm'],
mtu=ctx.obj['l2cap_mtu'],
mps=ctx.obj['l2cap_mps'],
max_credits=ctx.obj['l2cap_max_credits'],
)
return L2capClient(device, psm=ctx.obj['l2cap_psm'])
if mode == 'l2cap-server':
return L2capServer(
device,
psm=ctx.obj['l2cap_psm'],
mtu=ctx.obj['l2cap_mtu'],
mps=ctx.obj['l2cap_mps'],
max_credits=ctx.obj['l2cap_max_credits'],
)
return L2capServer(device, psm=ctx.obj['l2cap_psm'])
if mode == 'rfcomm-client':
return RfcommClient(
@@ -1361,29 +1197,23 @@ def create_role_factory(ctx, default_role):
return Sender(
packet_io,
start_delay=ctx.obj['start_delay'],
repeat=ctx.obj['repeat'],
repeat_delay=ctx.obj['repeat_delay'],
pace=ctx.obj['pace'],
packet_size=ctx.obj['packet_size'],
packet_count=ctx.obj['packet_count'],
)
if role == 'receiver':
return Receiver(packet_io, ctx.obj['linger'])
return Receiver(packet_io)
if role == 'ping':
return Ping(
packet_io,
start_delay=ctx.obj['start_delay'],
repeat=ctx.obj['repeat'],
repeat_delay=ctx.obj['repeat_delay'],
pace=ctx.obj['pace'],
packet_size=ctx.obj['packet_size'],
packet_count=ctx.obj['packet_count'],
)
if role == 'pong':
return Pong(packet_io, ctx.obj['linger'])
return Pong(packet_io)
raise ValueError('invalid role')
@@ -1428,7 +1258,7 @@ def create_role_factory(ctx, default_role):
@click.option(
'--rfcomm-uuid',
default=DEFAULT_RFCOMM_UUID,
help='RFComm service UUID to use (ignored if --rfcomm-channel is not 0)',
help='RFComm service UUID to use (ignored is --rfcomm-channel is not 0)',
)
@click.option(
'--l2cap-psm',
@@ -1436,31 +1266,13 @@ def create_role_factory(ctx, default_role):
default=DEFAULT_L2CAP_PSM,
help='L2CAP PSM to use',
)
@click.option(
'--l2cap-mtu',
type=int,
default=DEFAULT_L2CAP_MTU,
help='L2CAP MTU to use',
)
@click.option(
'--l2cap-mps',
type=int,
default=DEFAULT_L2CAP_MPS,
help='L2CAP MPS to use',
)
@click.option(
'--l2cap-max-credits',
type=int,
default=DEFAULT_L2CAP_MAX_CREDITS,
help='L2CAP maximum number of credits allowed for the peer',
)
@click.option(
'--packet-size',
'-s',
metavar='SIZE',
type=click.IntRange(8, 4096),
default=500,
help='Packet size (client or ping role)',
help='Packet size (server role)',
)
@click.option(
'--packet-count',
@@ -1468,7 +1280,7 @@ def create_role_factory(ctx, default_role):
metavar='COUNT',
type=int,
default=10,
help='Packet count (client or ping role)',
help='Packet count (server role)',
)
@click.option(
'--start-delay',
@@ -1476,39 +1288,7 @@ def create_role_factory(ctx, default_role):
metavar='SECONDS',
type=int,
default=1,
help='Start delay (client or ping role)',
)
@click.option(
'--repeat',
metavar='N',
type=int,
default=0,
help=(
'Repeat the run N times (client and ping roles)'
'(0, which is the fault, to run just once) '
),
)
@click.option(
'--repeat-delay',
metavar='SECONDS',
type=int,
default=1,
help=('Delay, in seconds, between repeats'),
)
@click.option(
'--pace',
metavar='MILLISECONDS',
type=int,
default=0,
help=(
'Wait N milliseconds between packets '
'(0, which is the fault, to send as fast as possible) '
),
)
@click.option(
'--linger',
is_flag=True,
help="Don't exit at the end of a run (server and pong roles)",
help='Start delay (server role)',
)
@click.pass_context
def bench(
@@ -1521,16 +1301,9 @@ def bench(
packet_size,
packet_count,
start_delay,
repeat,
repeat_delay,
pace,
linger,
rfcomm_channel,
rfcomm_uuid,
l2cap_psm,
l2cap_mtu,
l2cap_mps,
l2cap_max_credits,
):
ctx.ensure_object(dict)
ctx.obj['device_config'] = device_config
@@ -1540,16 +1313,9 @@ def bench(
ctx.obj['rfcomm_channel'] = rfcomm_channel
ctx.obj['rfcomm_uuid'] = rfcomm_uuid
ctx.obj['l2cap_psm'] = l2cap_psm
ctx.obj['l2cap_mtu'] = l2cap_mtu
ctx.obj['l2cap_mps'] = l2cap_mps
ctx.obj['l2cap_max_credits'] = l2cap_max_credits
ctx.obj['packet_size'] = packet_size
ctx.obj['packet_count'] = packet_count
ctx.obj['start_delay'] = start_delay
ctx.obj['repeat'] = repeat
ctx.obj['repeat_delay'] = repeat_delay
ctx.obj['pace'] = pace
ctx.obj['linger'] = linger
ctx.obj['extended_data_length'] = (
[int(x) for x in extended_data_length.split('/')]
+4 -4
View File
@@ -777,7 +777,7 @@ class ConsoleApp:
if not service:
continue
values = [
await attribute.read_value(connection)
attribute.read_value(connection)
for connection in self.device.connections.values()
]
if not values:
@@ -796,11 +796,11 @@ class ConsoleApp:
if not characteristic:
continue
values = [
await attribute.read_value(connection)
attribute.read_value(connection)
for connection in self.device.connections.values()
]
if not values:
values = [await attribute.read_value(None)]
values = [attribute.read_value(None)]
# TODO: future optimization: convert CCCD value to human readable string
@@ -944,7 +944,7 @@ class ConsoleApp:
# send data to any subscribers
if isinstance(attribute, Characteristic):
await attribute.write_value(None, value)
attribute.write_value(None, value)
if attribute.has_properties(Characteristic.NOTIFY):
await self.device.gatt_server.notify_subscribers(attribute)
if attribute.has_properties(Characteristic.INDICATE):
-200
View File
@@ -1,200 +0,0 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import time
from typing import Optional
from bumble.colors import color
from bumble.hci import (
HCI_READ_LOOPBACK_MODE_COMMAND,
HCI_Read_Loopback_Mode_Command,
HCI_WRITE_LOOPBACK_MODE_COMMAND,
HCI_Write_Loopback_Mode_Command,
LoopbackMode,
)
from bumble.host import Host
from bumble.transport import open_transport_or_link
import click
class Loopback:
"""Send and receive ACL data packets in local loopback mode"""
def __init__(self, packet_size: int, packet_count: int, transport: str):
self.transport = transport
self.packet_size = packet_size
self.packet_count = packet_count
self.connection_handle: Optional[int] = None
self.connection_event = asyncio.Event()
self.done = asyncio.Event()
self.expected_cid = 0
self.bytes_received = 0
self.start_timestamp = 0.0
self.last_timestamp = 0.0
def on_connection(self, connection_handle: int, *args):
"""Retrieve connection handle from new connection event"""
if not self.connection_event.is_set():
# save first connection handle for ACL
# subsequent connections are SCO
self.connection_handle = connection_handle
self.connection_event.set()
def on_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes):
"""Calculate packet receive speed"""
now = time.time()
print(f'<<< Received packet {cid}: {len(pdu)} bytes')
assert connection_handle == self.connection_handle
assert cid == self.expected_cid
self.expected_cid += 1
if cid == 0:
self.start_timestamp = now
else:
elapsed_since_start = now - self.start_timestamp
elapsed_since_last = now - self.last_timestamp
self.bytes_received += len(pdu)
instant_rx_speed = len(pdu) / elapsed_since_last
average_rx_speed = self.bytes_received / elapsed_since_start
print(
color(
f'@@@ RX speed: instant={instant_rx_speed:.4f},'
f' average={average_rx_speed:.4f}',
'cyan',
)
)
self.last_timestamp = now
if self.expected_cid == self.packet_count:
print(color('@@@ Received last packet', 'green'))
self.done.set()
async def run(self):
"""Run a loopback throughput test"""
print(color('>>> Connecting to HCI...', 'green'))
async with await open_transport_or_link(self.transport) as (
hci_source,
hci_sink,
):
print(color('>>> Connected', 'green'))
host = Host(hci_source, hci_sink)
await host.reset()
# make sure data can fit in one l2cap pdu
l2cap_header_size = 4
max_packet_size = host.acl_packet_queue.max_packet_size - l2cap_header_size
if self.packet_size > max_packet_size:
print(
color(
f'!!! Packet size ({self.packet_size}) larger than max supported'
f' size ({max_packet_size})',
'red',
)
)
return
if not host.supports_command(
HCI_WRITE_LOOPBACK_MODE_COMMAND
) or not host.supports_command(HCI_READ_LOOPBACK_MODE_COMMAND):
print(color('!!! Loopback mode not supported', 'red'))
return
# set event callbacks
host.on('connection', self.on_connection)
host.on('l2cap_pdu', self.on_l2cap_pdu)
loopback_mode = LoopbackMode.LOCAL
print(color('### Setting loopback mode', 'blue'))
await host.send_command(
HCI_Write_Loopback_Mode_Command(loopback_mode=LoopbackMode.LOCAL),
check_result=True,
)
print(color('### Checking loopback mode', 'blue'))
response = await host.send_command(
HCI_Read_Loopback_Mode_Command(), check_result=True
)
if response.return_parameters.loopback_mode != loopback_mode:
print(color('!!! Loopback mode mismatch', 'red'))
return
await self.connection_event.wait()
print(color('### Connected', 'cyan'))
print(color('=== Start sending', 'magenta'))
start_time = time.time()
bytes_sent = 0
for cid in range(0, self.packet_count):
# using the cid as an incremental index
host.send_l2cap_pdu(
self.connection_handle, cid, bytes(self.packet_size)
)
print(
color(
f'>>> Sending packet {cid}: {self.packet_size} bytes', 'yellow'
)
)
bytes_sent += self.packet_size # don't count L2CAP or HCI header sizes
await asyncio.sleep(0) # yield to allow packet receive
await self.done.wait()
print(color('=== Done!', 'magenta'))
elapsed = time.time() - start_time
average_tx_speed = bytes_sent / elapsed
print(
color(
f'@@@ TX speed: average={average_tx_speed:.4f} ({bytes_sent} bytes'
f' in {elapsed:.2f} seconds)',
'green',
)
)
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--packet-size',
'-s',
metavar='SIZE',
type=click.IntRange(8, 4096),
default=500,
help='Packet size',
)
@click.option(
'--packet-count',
'-c',
metavar='COUNT',
type=int,
default=10,
help='Packet count',
)
@click.argument('transport')
def main(packet_size, packet_count, transport):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
loopback = Loopback(packet_size, packet_count, transport)
asyncio.run(loopback.run())
# -----------------------------------------------------------------------------
if __name__ == '__main__':
main()
+24 -32
View File
@@ -49,16 +49,14 @@ class ServerBridge:
self.tcp_port = tcp_port
async def start(self, device: Device) -> None:
# Listen for incoming L2CAP channel connections
# Listen for incoming L2CAP CoC connections
device.create_l2cap_server(
spec=l2cap.LeCreditBasedChannelSpec(
psm=self.psm, mtu=self.mtu, mps=self.mps, max_credits=self.max_credits
),
handler=self.on_channel,
)
print(
color(f'### Listening for channel connection on PSM {self.psm}', 'yellow')
handler=self.on_coc,
)
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
def on_ble_connection(connection):
def on_ble_disconnection(reason):
@@ -75,7 +73,7 @@ class ServerBridge:
await device.start_advertising(auto_restart=True)
# Called when a new L2CAP connection is established
def on_channel(self, l2cap_channel):
def on_coc(self, l2cap_channel):
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
class Pipe:
@@ -85,7 +83,7 @@ class ServerBridge:
self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_channel_sdu
l2cap_channel.sink = self.on_coc_sdu
async def connect_to_tcp(self):
# Connect to the TCP server
@@ -130,7 +128,7 @@ class ServerBridge:
if self.tcp_transport is not None:
self.tcp_transport.close()
def on_channel_sdu(self, sdu):
def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
if self.tcp_transport is None:
print(color('!!! TCP socket not open, dropping', 'red'))
@@ -185,7 +183,7 @@ class ClientBridge:
peer_name = writer.get_extra_info('peer_name')
print(color(f'<<< TCP connection from {peer_name}', 'magenta'))
def on_channel_sdu(sdu):
def on_coc_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
l2cap_to_tcp_pipe.write(sdu)
@@ -211,7 +209,7 @@ class ClientBridge:
writer.close()
return
l2cap_channel.sink = on_channel_sdu
l2cap_channel.sink = on_coc_sdu
l2cap_channel.on('close', on_l2cap_close)
# Start a flow control pipe from L2CAP to TCP
@@ -276,29 +274,23 @@ async def run(device_config, hci_transport, bridge):
@click.pass_context
@click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP', type=int, default=1234)
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
@click.option(
'--l2cap-max-credits',
help='Maximum L2CAP Credits',
'--l2cap-coc-max-credits',
help='Maximum L2CAP CoC Credits',
type=click.IntRange(1, 65535),
default=128,
)
@click.option(
'--l2cap-mtu',
help='L2CAP MTU',
type=click.IntRange(
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU,
),
default=1024,
'--l2cap-coc-mtu',
help='L2CAP CoC MTU',
type=click.IntRange(23, 65535),
default=1022,
)
@click.option(
'--l2cap-mps',
help='L2CAP MPS',
type=click.IntRange(
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS,
l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS,
),
'--l2cap-coc-mps',
help='L2CAP CoC MPS',
type=click.IntRange(23, 65533),
default=1024,
)
def cli(
@@ -306,17 +298,17 @@ def cli(
device_config,
hci_transport,
psm,
l2cap_max_credits,
l2cap_mtu,
l2cap_mps,
l2cap_coc_max_credits,
l2cap_coc_mtu,
l2cap_coc_mps,
):
context.ensure_object(dict)
context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport
context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_max_credits
context.obj['mtu'] = l2cap_mtu
context.obj['mps'] = l2cap_mps
context.obj['max_credits'] = l2cap_coc_max_credits
context.obj['mtu'] = l2cap_coc_mtu
context.obj['mps'] = l2cap_coc_mps
# -----------------------------------------------------------------------------
+44 -8
View File
@@ -26,7 +26,7 @@ from bumble.transport import open_transport_or_link
from bumble.keys import JsonKeyStore
from bumble.smp import AddressResolver
from bumble.device import Advertisement
from bumble.hci import HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
from bumble.hci import Address, HCI_Constant, HCI_LE_1M_PHY, HCI_LE_CODED_PHY
# -----------------------------------------------------------------------------
@@ -66,10 +66,15 @@ class AdvertisementPrinter:
address_type_string = ('PUBLIC', 'RANDOM', 'PUBLIC_ID', 'RANDOM_ID')[
address.address_type
]
if address.is_public:
type_color = 'cyan'
if address.address_type in (
Address.RANDOM_IDENTITY_ADDRESS,
Address.PUBLIC_IDENTITY_ADDRESS,
):
type_color = 'yellow'
else:
if address.is_static:
if address.is_public:
type_color = 'cyan'
elif address.is_static:
type_color = 'green'
address_qualifier = '(static)'
elif address.is_resolvable:
@@ -116,6 +121,7 @@ async def scan(
phy,
filter_duplicates,
raw,
irks,
keystore_file,
device_config,
transport,
@@ -140,9 +146,21 @@ async def scan(
if device.keystore:
resolving_keys = await device.keystore.get_resolving_keys()
resolver = AddressResolver(resolving_keys)
else:
resolver = None
resolving_keys = []
for irk_and_address in irks:
if ':' not in irk_and_address:
raise ValueError('invalid IRK:ADDRESS value')
irk_hex, address_str = irk_and_address.split(':', 1)
resolving_keys.append(
(
bytes.fromhex(irk_hex),
Address(address_str, Address.RANDOM_DEVICE_ADDRESS),
)
)
resolver = AddressResolver(resolving_keys) if resolving_keys else None
printer = AdvertisementPrinter(min_rssi, resolver)
if raw:
@@ -187,8 +205,24 @@ async def scan(
default=False,
help='Listen for raw advertising reports instead of processed ones',
)
@click.option('--keystore-file', help='Keystore file to use when resolving addresses')
@click.option('--device-config', help='Device config file for the scanning device')
@click.option(
'--irk',
metavar='<IRK_HEX>:<ADDRESS>',
help=(
'Use this IRK for resolving private addresses ' '(may be used more than once)'
),
multiple=True,
)
@click.option(
'--keystore-file',
metavar='FILE_PATH',
help='Keystore file to use when resolving addresses',
)
@click.option(
'--device-config',
metavar='FILE_PATH',
help='Device config file for the scanning device',
)
@click.argument('transport')
def main(
min_rssi,
@@ -198,6 +232,7 @@ def main(
phy,
filter_duplicates,
raw,
irk,
keystore_file,
device_config,
transport,
@@ -212,6 +247,7 @@ def main(
phy,
filter_duplicates,
raw,
irk,
keystore_file,
device_config,
transport,
+11 -53
View File
@@ -25,21 +25,9 @@
from __future__ import annotations
import enum
import functools
import inspect
import struct
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
Union,
TYPE_CHECKING,
)
from pyee import EventEmitter
from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
from bumble.core import UUID, name_or_number, ProtocolError
from bumble.hci import HCI_Object, key_with_value
@@ -734,38 +722,12 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
# -----------------------------------------------------------------------------
class AttributeValue:
'''
Attribute value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
class ConnectionValue(Protocol):
def read(self, connection) -> bytes:
...
def __init__(
self,
read: Union[
Callable[[Optional[Connection]], bytes],
Callable[[Optional[Connection]], Awaitable[bytes]],
None,
] = None,
write: Union[
Callable[[Optional[Connection], bytes], None],
Callable[[Optional[Connection], bytes], Awaitable[None]],
None,
] = None,
):
self._read = read
self._write = write
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
return self._read(connection) if self._read else b''
def write(
self, connection: Optional[Connection], value: bytes
) -> Union[Awaitable[None], None]:
if self._write:
return self._write(connection, value)
return None
def write(self, connection, value: bytes) -> None:
...
# -----------------------------------------------------------------------------
@@ -808,13 +770,13 @@ class Attribute(EventEmitter):
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
value: Union[bytes, AttributeValue]
value: Union[str, bytes, ConnectionValue]
def __init__(
self,
attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, AttributeValue] = b'',
value: Union[str, bytes, ConnectionValue] = b'',
) -> None:
EventEmitter.__init__(self)
self.handle = 0
@@ -844,7 +806,7 @@ class Attribute(EventEmitter):
def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
async def read_value(self, connection: Optional[Connection]) -> bytes:
def read_value(self, connection: Optional[Connection]) -> bytes:
if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
@@ -870,8 +832,6 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'read'):
try:
value = self.value.read(connection)
if inspect.isawaitable(value):
value = await value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
@@ -881,7 +841,7 @@ class Attribute(EventEmitter):
return self.encode_value(value)
async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption:
@@ -904,9 +864,7 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'write'):
try:
result = self.value.write(connection, value)
if inspect.isawaitable(result):
await result
self.value.write(connection, value) # pylint: disable=not-callable
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
+48 -60
View File
@@ -23,28 +23,16 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import enum
import functools
import logging
import struct
from typing import (
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Union,
TYPE_CHECKING,
)
from typing import Optional, Sequence, Iterable, List, Union
from bumble.colors import color
from bumble.core import UUID
from bumble.att import Attribute, AttributeValue
if TYPE_CHECKING:
from bumble.gatt_client import AttributeProxy
from bumble.device import Connection
from .colors import color
from .core import UUID, get_dict_key_by_value
from .att import Attribute
# -----------------------------------------------------------------------------
@@ -534,43 +522,56 @@ class CharacteristicDeclaration(Attribute):
# -----------------------------------------------------------------------------
class CharacteristicValue(AttributeValue):
"""Same as AttributeValue, for backward compatibility"""
class CharacteristicValue:
'''
Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(self, read=None, write=None):
self._read = read
self._write = write
def read(self, connection):
return self._read(connection) if self._read else b''
def write(self, connection, value):
if self._write:
self._write(connection, value)
# -----------------------------------------------------------------------------
class CharacteristicAdapter:
'''
An adapter that can adapt Characteristic and AttributeProxy objects
by wrapping their `read_value()` and `write_value()` methods with ones that
return/accept encoded/decoded values.
For proxies (i.e used by a GATT client), the adaptation is one where the return
value of `read_value()` is decoded and the value passed to `write_value()` is
encoded. The `subscribe()` method, is wrapped with one where the values are decoded
before being passed to the subscriber.
For local values (i.e hosted by a GATT server) the adaptation is one where the
return value of `read_value()` is encoded and the value passed to `write_value()`
is decoded.
An adapter that can adapt any object with `read_value` and `write_value`
methods (like Characteristic and CharacteristicProxy objects) by wrapping
those methods with ones that return/accept encoded/decoded values.
Objects with async methods are considered proxies, so the adaptation is one
where the return value of `read_value` is decoded and the value passed to
`write_value` is encoded. Other objects are considered local characteristics
so the adaptation is one where the return value of `read_value` is encoded
and the value passed to `write_value` is decoded.
If the characteristic has a `subscribe` method, it is wrapped with one where
the values are decoded before being passed to the subscriber.
'''
read_value: Callable
write_value: Callable
def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
def __init__(self, characteristic):
self.wrapped_characteristic = characteristic
self.subscribers: Dict[
Callable, Callable
] = {} # Map from subscriber to proxy subscriber
self.subscribers = {} # Map from subscriber to proxy subscriber
if isinstance(characteristic, Characteristic):
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
else:
if asyncio.iscoroutinefunction(
characteristic.read_value
) and asyncio.iscoroutinefunction(characteristic.write_value):
self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
else:
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
if hasattr(self.wrapped_characteristic, 'subscribe'):
self.subscribe = self.wrapped_subscribe
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name):
@@ -589,13 +590,11 @@ class CharacteristicAdapter:
else:
setattr(self.wrapped_characteristic, name, value)
async def read_encoded_value(self, connection):
return self.encode_value(
await self.wrapped_characteristic.read_value(connection)
)
def read_encoded_value(self, connection):
return self.encode_value(self.wrapped_characteristic.read_value(connection))
async def write_encoded_value(self, connection, value):
return await self.wrapped_characteristic.write_value(
def write_encoded_value(self, connection, value):
return self.wrapped_characteristic.write_value(
connection, self.decode_value(value)
)
@@ -730,24 +729,13 @@ class Descriptor(Attribute):
'''
def __str__(self) -> str:
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
value = self.value.read(None)
if isinstance(value, bytes):
value_str = value.hex()
else:
value_str = '<async>'
else:
value_str = '<...>'
return (
f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'value={value_str})'
f'value={self.read_value(None).hex()})'
)
# -----------------------------------------------------------------------------
class ClientCharacteristicConfigurationBits(enum.IntFlag):
'''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit
+21 -29
View File
@@ -31,9 +31,9 @@ import struct
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
from pyee import EventEmitter
from bumble.colors import color
from bumble.core import UUID
from bumble.att import (
from .colors import color
from .core import UUID
from .att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
@@ -60,7 +60,7 @@ from bumble.att import (
ATT_Write_Response,
Attribute,
)
from bumble.gatt import (
from .gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
@@ -74,7 +74,6 @@ from bumble.gatt import (
Descriptor,
Service,
)
from bumble.utils import AsyncRunner
if TYPE_CHECKING:
from bumble.device import Device, Connection
@@ -380,7 +379,7 @@ class Server(EventEmitter):
# Get or encode the value
value = (
await attribute.read_value(connection)
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
@@ -423,7 +422,7 @@ class Server(EventEmitter):
# Get or encode the value
value = (
await attribute.read_value(connection)
attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
@@ -651,8 +650,7 @@ class Server(EventEmitter):
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_find_by_type_value_request(self, connection, request):
def on_att_find_by_type_value_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
'''
@@ -660,13 +658,13 @@ class Server(EventEmitter):
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
attributes = []
async for attribute in (
for attribute in (
attribute
for attribute in self.attributes
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type
and (await attribute.read_value(connection)) == request.attribute_value
and attribute.read_value(connection) == request.attribute_value
and pdu_space_available >= 4
):
# TODO: check permissions
@@ -704,8 +702,7 @@ class Server(EventEmitter):
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_read_by_type_request(self, connection, request):
def on_att_read_by_type_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''
@@ -728,7 +725,7 @@ class Server(EventEmitter):
and pdu_space_available
):
try:
attribute_value = await attribute.read_value(connection)
attribute_value = attribute.read_value(connection)
except ATT_Error as error:
# If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point
@@ -770,15 +767,14 @@ class Server(EventEmitter):
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_read_request(self, connection, request):
def on_att_read_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
'''
if attribute := self.get_attribute(request.attribute_handle):
try:
value = await attribute.read_value(connection)
value = attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -796,15 +792,14 @@ class Server(EventEmitter):
)
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_read_blob_request(self, connection, request):
def on_att_read_blob_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
'''
if attribute := self.get_attribute(request.attribute_handle):
try:
value = await attribute.read_value(connection)
value = attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -841,8 +836,7 @@ class Server(EventEmitter):
)
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_read_by_group_type_request(self, connection, request):
def on_att_read_by_group_type_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
'''
@@ -870,7 +864,7 @@ class Server(EventEmitter):
):
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = await attribute.read_value(connection)
attribute_value = attribute.read_value(connection)
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
@@ -909,8 +903,7 @@ class Server(EventEmitter):
self.send_response(connection, response)
@AsyncRunner.run_in_task()
async def on_att_write_request(self, connection, request):
def on_att_write_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
'''
@@ -943,13 +936,12 @@ class Server(EventEmitter):
return
# Accept the value
await attribute.write_value(connection, request.attribute_value)
attribute.write_value(connection, request.attribute_value)
# Done
self.send_response(connection, ATT_Write_Response())
@AsyncRunner.run_in_task()
async def on_att_write_command(self, connection, request):
def on_att_write_command(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
'''
@@ -967,7 +959,7 @@ class Server(EventEmitter):
# Accept the value
try:
await attribute.write_value(connection, request.attribute_value)
attribute.write_value(connection, request.attribute_value)
except Exception as error:
logger.exception(f'!!! ignoring exception: {error}')
-32
View File
@@ -2026,17 +2026,6 @@ class OwnAddressType(enum.IntEnum):
return {'size': 1, 'mapper': lambda x: OwnAddressType(x).name}
# -----------------------------------------------------------------------------
class LoopbackMode(enum.IntEnum):
DISABLED = 0
LOCAL = 1
REMOTE = 2
@classmethod
def type_spec(cls):
return {'size': 1, 'mapper': lambda x: LoopbackMode(x).name}
# -----------------------------------------------------------------------------
class HCI_Packet:
'''
@@ -3363,27 +3352,6 @@ class HCI_Read_Encryption_Key_Size_Command(HCI_Command):
'''
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[
('status', STATUS_SPEC),
('loopback_mode', LoopbackMode.type_spec()),
],
)
class HCI_Read_Loopback_Mode_Command(HCI_Command):
'''
See Bluetooth spec @ 7.6.1 Read Loopback Mode Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command([('loopback_mode', 1)])
class HCI_Write_Loopback_Mode_Command(HCI_Command):
'''
See Bluetooth spec @ 7.6.2 Write Loopback Mode Command
'''
# -----------------------------------------------------------------------------
@HCI_Command.command([('le_event_mask', 8)])
class HCI_LE_Set_Event_Mask_Command(HCI_Command):
+5 -10
View File
@@ -149,11 +149,10 @@ L2CAP_INVALID_CID_IN_REQUEST_REASON = 0x0002
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2046
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256
L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE = 0x01
@@ -189,11 +188,8 @@ class LeCreditBasedChannelSpec:
or self.max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS
):
raise ValueError('max credits out of range')
if (
self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU
or self.mtu > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MTU
):
raise ValueError('MTU out of range')
if self.mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
raise ValueError('MTU too small')
if (
self.mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS
or self.mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS
@@ -1648,13 +1644,12 @@ class ChannelManager:
def send_pdu(self, connection, cid: int, pdu: Union[SupportsBytes, bytes]) -> None:
pdu_str = pdu.hex() if isinstance(pdu, bytes) else str(pdu)
pdu_bytes = bytes(pdu)
logger.debug(
f'{color(">>> Sending L2CAP PDU", "blue")} '
f'on connection [0x{connection.handle:04X}] (CID={cid}) '
f'{connection.peer_address}: {len(pdu_bytes)} bytes, {pdu_str}'
f'{connection.peer_address}: {pdu_str}'
)
self.host.send_l2cap_pdu(connection.handle, cid, pdu_bytes)
self.host.send_l2cap_pdu(connection.handle, cid, bytes(pdu))
def on_pdu(self, connection: Connection, cid: int, pdu: bytes) -> None:
if cid in (L2CAP_SIGNALING_CID, L2CAP_LE_SIGNALING_CID):
+2 -2
View File
@@ -18,7 +18,7 @@
# -----------------------------------------------------------------------------
import struct
import logging
from typing import List, Optional
from typing import List
from bumble import l2cap
from ..core import AdvertisingData
@@ -67,7 +67,7 @@ class AshaService(TemplateService):
self.emit('volume', connection, value[0])
# Handler for audio control commands
def on_audio_control_point_write(connection: Optional[Connection], value):
def on_audio_control_point_write(connection: Connection, value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == AshaService.OPCODE_START:
+3 -3
View File
@@ -114,7 +114,7 @@ class SamplingFrequency(enum.IntEnum):
'''Bluetooth Assigned Numbers, Section 6.12.5.1 - Sampling Frequency'''
# fmt: off
FREQ_8000 = 0x01
FREQ_8000 = 0x01
FREQ_11025 = 0x02
FREQ_16000 = 0x03
FREQ_22050 = 0x04
@@ -430,7 +430,7 @@ class AseResponseCode(enum.IntEnum):
REJECTED_METADATA = 0x0B
INVALID_METADATA = 0x0C
INSUFFICIENT_RESOURCES = 0x0D
UNSPECIFIED_ERROR = 0x0E
UNSPECIFIED_ERROR = 0x0E
class AseReasonCode(enum.IntEnum):
@@ -1066,7 +1066,7 @@ class AseStateMachine(gatt.Characteristic):
# Readonly. Do nothing in the setter.
pass
def on_read(self, _: Optional[device.Connection]) -> bytes:
def on_read(self, _: device.Connection) -> bytes:
return self.value
def __str__(self) -> str:
+8 -60
View File
@@ -19,7 +19,7 @@
from __future__ import annotations
import enum
import struct
from typing import Optional, Tuple
from typing import Optional
from bumble import core
from bumble import crypto
@@ -31,9 +31,6 @@ from bumble import gatt_client
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
SET_IDENTITY_RESOLVING_KEY_LENGTH = 16
class SirkType(enum.IntEnum):
'''Coordinated Set Identification Service - 5.1 Set Identity Resolving Key.'''
@@ -69,10 +66,6 @@ def k1(n: bytes, salt: bytes, p: bytes) -> bytes:
def sef(k: bytes, r: bytes) -> bytes:
'''
Coordinated Set Identification Service - 4.5 SIRK encryption function sef.
SIRK decryption function sdf shares the same algorithm. The only difference is that argument r is:
* Plaintext in encryption
* Cipher in decryption
'''
return crypto.xor(k1(k, s1(b'SIRKenc'[::-1]), b'csis'[::-1]), r)
@@ -112,11 +105,6 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
set_member_lock: Optional[MemberLock] = None,
set_member_rank: Optional[int] = None,
) -> None:
if len(set_identity_resolving_key) != SET_IDENTITY_RESOLVING_KEY_LENGTH:
raise ValueError(
f'Invalid SIRK length {len(set_identity_resolving_key)}, expected {SET_IDENTITY_RESOLVING_KEY_LENGTH}'
)
characteristics = []
self.set_identity_resolving_key = set_identity_resolving_key
@@ -125,7 +113,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
uuid=gatt.GATT_SET_IDENTITY_RESOLVING_KEY_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
permissions=gatt.Characteristic.Permissions.READABLE,
value=gatt.CharacteristicValue(read=self.on_sirk_read),
)
characteristics.append(self.set_identity_resolving_key_characteristic)
@@ -135,7 +123,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
uuid=gatt.GATT_COORDINATED_SET_SIZE_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('B', coordinated_set_size),
)
characteristics.append(self.coordinated_set_size_characteristic)
@@ -146,7 +134,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY
| gatt.Characteristic.Properties.WRITE,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION
permissions=gatt.Characteristic.Permissions.READABLE
| gatt.Characteristic.Permissions.WRITEABLE,
value=struct.pack('B', set_member_lock),
)
@@ -157,32 +145,18 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
uuid=gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC,
properties=gatt.Characteristic.Properties.READ
| gatt.Characteristic.Properties.NOTIFY,
permissions=gatt.Characteristic.Permissions.READ_REQUIRES_ENCRYPTION,
permissions=gatt.Characteristic.Permissions.READABLE,
value=struct.pack('B', set_member_rank),
)
characteristics.append(self.set_member_rank_characteristic)
super().__init__(characteristics)
async def on_sirk_read(self, connection: Optional[device.Connection]) -> bytes:
def on_sirk_read(self, _connection: device.Connection) -> bytes:
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
sirk_bytes = self.set_identity_resolving_key
return bytes([SirkType.PLAINTEXT]) + self.set_identity_resolving_key
else:
assert connection
if connection.transport == core.BT_LE_TRANSPORT:
key = await connection.device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await connection.device.get_link_key(connection.peer_address)
if not key:
raise RuntimeError('LTK or LinkKey is not present')
sirk_bytes = sef(key, self.set_identity_resolving_key)
return bytes([self.set_identity_resolving_key_type]) + sirk_bytes
raise NotImplementedError('TODO: Pending async Characteristic read.')
def get_advertising_data(self) -> bytes:
return bytes(
@@ -229,29 +203,3 @@ class CoordinatedSetIdentificationProxy(gatt_client.ProfileServiceProxy):
gatt.GATT_SET_MEMBER_RANK_CHARACTERISTIC
):
self.set_member_rank = characteristics[0]
async def read_set_identity_resolving_key(self) -> Tuple[SirkType, bytes]:
'''Reads SIRK and decrypts if encrypted.'''
response = await self.set_identity_resolving_key.read_value()
if len(response) != SET_IDENTITY_RESOLVING_KEY_LENGTH + 1:
raise RuntimeError('Invalid SIRK value')
sirk_type = SirkType(response[0])
if sirk_type == SirkType.PLAINTEXT:
sirk = response[1:]
else:
connection = self.service_proxy.client.connection
device = connection.device
if connection.transport == core.BT_LE_TRANSPORT:
key = await device.get_long_term_key(
connection_handle=connection.handle, rand=b'', ediv=0
)
else:
key = await device.get_link_key(connection.peer_address)
if not key:
raise RuntimeError('LTK or LinkKey is not present')
sirk = sef(key, response[1:])
return (sirk_type, sirk)
+3 -7
View File
@@ -454,8 +454,6 @@ class DLC(EventEmitter):
self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0
self.sink = None
self.connection_result = None
self.drained = asyncio.Event()
self.drained.set()
# Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs
@@ -635,8 +633,6 @@ class DLC(EventEmitter):
)
rx_credits_needed = 0
if not self.tx_buffer:
self.drained.set()
# Stream protocol
def write(self, data: Union[bytes, str]) -> None:
@@ -649,11 +645,11 @@ class DLC(EventEmitter):
raise ValueError('write only accept bytes or strings')
self.tx_buffer += data
self.drained.clear()
self.process_tx()
async def drain(self) -> None:
await self.drained.wait()
def drain(self) -> None:
# TODO
pass
def __str__(self) -> str:
return f'DLC(dlci={self.dlci},state={self.state.name})'
+6 -3
View File
@@ -280,14 +280,17 @@ class AsyncRunner:
def wrapper(*args, **kwargs):
coroutine = func(*args, **kwargs)
if queue is None:
# Spawn the coroutine as a task
# Create a task to run the coroutine
async def run():
try:
await coroutine
except Exception:
logger.exception(color("!!! Exception in wrapper:", "red"))
logger.warning(
f'{color("!!! Exception in wrapper:", "red")} '
f'{traceback.format_exc()}'
)
AsyncRunner.spawn(run())
asyncio.create_task(run())
else:
# Queue the coroutine to be awaited by the work queue
queue.enqueue(coroutine)
+9 -30
View File
@@ -7,36 +7,16 @@ throughput and/or latency between two devices.
# General Usage
```
Usage: bumble-bench [OPTIONS] COMMAND [ARGS]...
Usage: bench.py [OPTIONS] COMMAND [ARGS]...
Options:
--device-config FILENAME Device configuration file
--role [sender|receiver|ping|pong]
--mode [gatt-client|gatt-server|l2cap-client|l2cap-server|rfcomm-client|rfcomm-server]
--att-mtu MTU GATT MTU (gatt-client mode) [23<=x<=517]
--extended-data-length TEXT Request a data length upon connection,
specified as tx_octets/tx_time
--rfcomm-channel INTEGER RFComm channel to use
--rfcomm-uuid TEXT RFComm service UUID to use (ignored if
--rfcomm-channel is not 0)
--l2cap-psm INTEGER L2CAP PSM to use
--l2cap-mtu INTEGER L2CAP MTU to use
--l2cap-mps INTEGER L2CAP MPS to use
--l2cap-max-credits INTEGER L2CAP maximum number of credits allowed for
the peer
-s, --packet-size SIZE Packet size (client or ping role)
[8<=x<=4096]
-c, --packet-count COUNT Packet count (client or ping role)
-sd, --start-delay SECONDS Start delay (client or ping role)
--repeat N Repeat the run N times (client and ping
roles)(0, which is the fault, to run just
once)
--repeat-delay SECONDS Delay, in seconds, between repeats
--pace MILLISECONDS Wait N milliseconds between packets (0,
which is the fault, to send as fast as
possible)
--linger Don't exit at the end of a run (server and
pong roles)
-s, --packet-size SIZE Packet size (server role) [8<=x<=4096]
-c, --packet-count COUNT Packet count (server role)
-sd, --start-delay SECONDS Start delay (server role)
--help Show this message and exit.
Commands:
@@ -55,18 +35,17 @@ Options:
--connection-interval, --ci CONNECTION_INTERVAL
Connection interval (in ms)
--phy [1m|2m|coded] PHY to use
--authenticate Authenticate (RFComm only)
--encrypt Encrypt the connection (RFComm only)
--help Show this message and exit.
```
To test once device against another, one of the two devices must be running
To test once device against another, one of the two devices must be running
the ``peripheral`` command and the other the ``central`` command. The device
running the ``peripheral`` command will accept connections from the device
running the ``central`` command.
When using Bluetooth LE (all modes except for ``rfcomm-server`` and ``rfcomm-client``utils),
the default addresses configured in the tool should be sufficient. But when using
Bluetooth Classic, the address of the Peripheral must be specified on the Central
the default addresses configured in the tool should be sufficient. But when using
Bluetooth Classic, the address of the Peripheral must be specified on the Central
using the ``--peripheral`` option. The address will be printed by the Peripheral when
it starts.
@@ -104,7 +83,7 @@ the other on `usb:1`, and two consoles/terminals. We will run a command in each.
$ bumble-bench central usb:1
```
In this default configuration, the Central runs a Sender, as a GATT client,
In this default configuration, the Central runs a Sender, as a GATT client,
connecting to the Peripheral running a Receiver, as a GATT server.
!!! example "L2CAP Throughput"
@@ -74,13 +74,11 @@ class L2capClient(
gatt: BluetoothGatt?, status: Int, newState: Int
) {
if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) {
if (viewModel.use2mPhy) {
gatt.setPreferredPhy(
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_OPTION_NO_PREFERRED
)
}
gatt.setPreferredPhy(
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_OPTION_NO_PREFERRED
)
gatt.readPhy()
// Request an MTU update, even though we don't use GATT, because Android
@@ -27,12 +27,11 @@ val DEFAULT_RFCOMM_UUID: UUID = UUID.fromString("E6D55659-C8B4-4B85-96BB-B1143AF
const val DEFAULT_PEER_BLUETOOTH_ADDRESS = "AA:BB:CC:DD:EE:FF"
const val DEFAULT_SENDER_PACKET_COUNT = 100
const val DEFAULT_SENDER_PACKET_SIZE = 1024
const val DEFAULT_PSM = 128
class AppViewModel : ViewModel() {
private var preferences: SharedPreferences? = null
var peerBluetoothAddress by mutableStateOf(DEFAULT_PEER_BLUETOOTH_ADDRESS)
var l2capPsm by mutableIntStateOf(DEFAULT_PSM)
var l2capPsm by mutableIntStateOf(0)
var use2mPhy by mutableStateOf(true)
var mtu by mutableIntStateOf(0)
var rxPhy by mutableIntStateOf(0)
+1 -2
View File
@@ -48,8 +48,7 @@ from bumble.profiles.bap import (
PublishedAudioCapabilitiesService,
PublishedAudioCapabilitiesServiceProxy,
)
from tests.test_utils import TwoDevices
from .test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Logging
+6 -15
View File
@@ -20,7 +20,6 @@ import os
import pytest
import struct
import logging
from unittest import mock
from bumble import device
from bumble.profiles import csip
@@ -69,18 +68,14 @@ def test_sef():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
@pytest.mark.parametrize(
'sirk_type,', [(csip.SirkType.ENCRYPTED), (csip.SirkType.PLAINTEXT)]
)
async def test_csis(sirk_type):
async def test_csis():
SIRK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
LTK = bytes.fromhex('2f62c8ae41867d1bb619e788a2605faa')
devices = TwoDevices()
devices[0].add_service(
csip.CoordinatedSetIdentificationService(
set_identity_resolving_key=SIRK,
set_identity_resolving_key_type=sirk_type,
set_identity_resolving_key_type=csip.SirkType.PLAINTEXT,
coordinated_set_size=2,
set_member_lock=csip.MemberLock.UNLOCKED,
set_member_rank=0,
@@ -88,19 +83,15 @@ async def test_csis(sirk_type):
)
await devices.setup_connection()
# Mock encryption.
devices.connections[0].encryption = 1
devices.connections[1].encryption = 1
devices[0].get_long_term_key = mock.AsyncMock(return_value=LTK)
devices[1].get_long_term_key = mock.AsyncMock(return_value=LTK)
peer = device.Peer(devices.connections[1])
csis_client = await peer.discover_service_and_create_proxy(
csip.CoordinatedSetIdentificationProxy
)
assert await csis_client.read_set_identity_resolving_key() == (sirk_type, SIRK)
assert (
await csis_client.set_identity_resolving_key.read_value()
== bytes([csip.SirkType.PLAINTEXT]) + SIRK
)
assert await csis_client.coordinated_set_size.read_value() == struct.pack('B', 2)
assert await csis_client.set_member_lock.read_value() == struct.pack(
'B', csip.MemberLock.UNLOCKED
+31 -76
View File
@@ -20,10 +20,11 @@ import logging
import os
import struct
import pytest
from unittest.mock import AsyncMock, Mock, ANY
from unittest.mock import Mock, ANY
from bumble.controller import Controller
from bumble.gatt_client import CharacteristicProxy
from bumble.gatt_server import Server
from bumble.link import LocalLink
from bumble.device import Device, Peer
from bumble.host import Host
@@ -119,9 +120,9 @@ async def test_characteristic_encoding():
Characteristic.READABLE,
123,
)
x = await c.read_value(None)
x = c.read_value(None)
assert x == bytes([123])
await c.write_value(None, bytes([122]))
c.write_value(None, bytes([122]))
assert c.value == 122
class FooProxy(CharacteristicProxy):
@@ -151,22 +152,7 @@ async def test_characteristic_encoding():
bytes([123]),
)
async def async_read(connection):
return 0x05060708
async_characteristic = PackedCharacteristicAdapter(
Characteristic(
'2AB7E91B-43E8-4F73-AC3B-80C1683B47F9',
Characteristic.Properties.READ,
Characteristic.READABLE,
CharacteristicValue(read=async_read),
),
'>I',
)
service = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic, async_characteristic]
)
service = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic])
server.add_service(service)
await client.power_on()
@@ -198,13 +184,6 @@ async def test_characteristic_encoding():
await async_barrier()
assert characteristic.value == bytes([50])
c2 = peer.get_characteristics_by_uuid(async_characteristic.uuid)
assert len(c2) == 1
c2 = c2[0]
cd2 = PackedCharacteristicAdapter(c2, ">I")
cd2v = await cd2.read_value()
assert cd2v == 0x05060708
last_change = None
def on_change(value):
@@ -306,8 +285,7 @@ async def test_attribute_getters():
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_CharacteristicAdapter():
def test_CharacteristicAdapter():
# Check that the CharacteristicAdapter base class is transparent
v = bytes([1, 2, 3])
c = Characteristic(
@@ -318,11 +296,11 @@ async def test_CharacteristicAdapter():
)
a = CharacteristicAdapter(c)
value = await a.read_value(None)
value = a.read_value(None)
assert value == v
v = bytes([3, 4, 5])
await a.write_value(None, v)
a.write_value(None, v)
assert c.value == v
# Simple delegated adapter
@@ -330,11 +308,11 @@ async def test_CharacteristicAdapter():
c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x))
)
value = await a.read_value(None)
value = a.read_value(None)
assert value == bytes(reversed(v))
v = bytes([3, 4, 5])
await a.write_value(None, v)
a.write_value(None, v)
assert a.value == bytes(reversed(v))
# Packed adapter with single element format
@@ -343,10 +321,10 @@ async def test_CharacteristicAdapter():
c.value = v
a = PackedCharacteristicAdapter(c, '>H')
value = await a.read_value(None)
value = a.read_value(None)
assert value == pv
c.value = None
await a.write_value(None, pv)
a.write_value(None, pv)
assert a.value == v
# Packed adapter with multi-element format
@@ -356,10 +334,10 @@ async def test_CharacteristicAdapter():
c.value = (v1, v2)
a = PackedCharacteristicAdapter(c, '>HH')
value = await a.read_value(None)
value = a.read_value(None)
assert value == pv
c.value = None
await a.write_value(None, pv)
a.write_value(None, pv)
assert a.value == (v1, v2)
# Mapped adapter
@@ -370,10 +348,10 @@ async def test_CharacteristicAdapter():
c.value = mapped
a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2'))
value = await a.read_value(None)
value = a.read_value(None)
assert value == pv
c.value = None
await a.write_value(None, pv)
a.write_value(None, pv)
assert a.value == mapped
# UTF-8 adapter
@@ -382,49 +360,27 @@ async def test_CharacteristicAdapter():
c.value = v
a = UTF8CharacteristicAdapter(c)
value = await a.read_value(None)
value = a.read_value(None)
assert value == ev
c.value = None
await a.write_value(None, ev)
a.write_value(None, ev)
assert a.value == v
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_CharacteristicValue():
def test_CharacteristicValue():
b = bytes([1, 2, 3])
async def read_value(connection):
return b
c = CharacteristicValue(read=read_value)
x = await c.read(None)
c = CharacteristicValue(read=lambda _: b)
x = c.read(None)
assert x == b
m = Mock()
c = CharacteristicValue(write=m)
result = []
c = CharacteristicValue(
write=lambda connection, value: result.append((connection, value))
)
z = object()
c.write(z, b)
m.assert_called_once_with(z, b)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_CharacteristicValue_async():
b = bytes([1, 2, 3])
async def read_value(connection):
return b
c = CharacteristicValue(read=read_value)
x = await c.read(None)
assert x == b
m = AsyncMock()
c = CharacteristicValue(write=m)
z = object()
await c.write(z, b)
m.assert_called_once_with(z, b)
assert result == [(z, b)]
# -----------------------------------------------------------------------------
@@ -1005,18 +961,12 @@ Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration
# -----------------------------------------------------------------------------
async def async_main():
test_UUID()
test_ATT_Error_Response()
test_ATT_Read_By_Group_Type_Request()
await test_read_write()
await test_read_write2()
await test_subscribe_notify()
await test_unsubscribe()
await test_characteristic_encoding()
await test_mtu_exchange()
await test_CharacteristicValue()
await test_CharacteristicValue_async()
await test_CharacteristicAdapter()
# -----------------------------------------------------------------------------
@@ -1155,4 +1105,9 @@ def test_get_attribute_group():
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
test_UUID()
test_ATT_Error_Response()
test_ATT_Read_By_Group_Type_Request()
test_CharacteristicValue()
test_CharacteristicAdapter()
asyncio.run(async_main())