fix linter config

This commit is contained in:
Gilles Boccon-Gibod
2025-02-03 18:02:14 -05:00
parent 6d9a0bf4e1
commit 5293d32dc6
10 changed files with 656 additions and 293 deletions

View File

@@ -16,6 +16,7 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
import dataclasses
import enum
import logging
import os
@@ -97,34 +98,6 @@ DEFAULT_RFCOMM_MTU = 2048
# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------
def parse_packet(packet):
if len(packet) < 1:
logging.info(
color(f'!!! Packet too short (got {len(packet)} bytes, need >= 1)', 'red')
)
raise ValueError('packet too short')
try:
packet_type = PacketType(packet[0])
except ValueError:
logging.info(color(f'!!! Invalid packet type 0x{packet[0]:02X}', 'red'))
raise
return (packet_type, packet[1:])
def parse_packet_sequence(packet_data):
if len(packet_data) < 5:
logging.info(
color(
f'!!!Packet too short (got {len(packet_data)} bytes, need >= 5)',
'red',
)
)
raise ValueError('packet too short')
return struct.unpack_from('>bI', packet_data, 0)
def le_phy_name(phy_id):
return {HCI_LE_1M_PHY: '1M', HCI_LE_2M_PHY: '2M', HCI_LE_CODED_PHY: 'CODED'}.get(
phy_id, HCI_Constant.le_phy_name(phy_id)
@@ -225,13 +198,135 @@ async def switch_roles(connection, role):
logging.info(f'{color("### Role switch failed:", "red")} {error}')
class PacketType(enum.IntEnum):
RESET = 0
SEQUENCE = 1
ACK = 2
# -----------------------------------------------------------------------------
# Packet
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class Packet:
class PacketType(enum.IntEnum):
RESET = 0
SEQUENCE = 1
ACK = 2
class PacketFlags(enum.IntFlag):
LAST = 1
packet_type: PacketType
flags: PacketFlags = PacketFlags(0)
sequence: int = 0
timestamp: int = 0
payload: bytes = b""
@classmethod
def from_bytes(cls, data: bytes):
if len(data) < 1:
logging.warning(
color(f'!!! Packet too short (got {len(data)} bytes, need >= 1)', 'red')
)
raise ValueError('packet too short')
try:
packet_type = cls.PacketType(data[0])
except ValueError:
logging.warning(color(f'!!! Invalid packet type 0x{data[0]:02X}', 'red'))
raise
if packet_type == cls.PacketType.RESET:
return cls(packet_type)
flags = cls.PacketFlags(data[1])
(sequence,) = struct.unpack_from("<I", data, 2)
if packet_type == cls.PacketType.ACK:
if len(data) < 6:
logging.warning(
color(
f'!!! Packet too short (got {len(data)} bytes, need >= 6)',
'red',
)
)
return cls(packet_type, flags, sequence)
if len(data) < 10:
logging.warning(
color(
f'!!! Packet too short (got {len(data)} bytes, need >= 10)', 'red'
)
)
raise ValueError('packet too short')
(timestamp,) = struct.unpack_from("<I", data, 6)
return cls(packet_type, flags, sequence, timestamp, data[10:])
def __bytes__(self):
if self.packet_type == self.PacketType.RESET:
return bytes([self.packet_type])
if self.packet_type == self.PacketType.ACK:
return struct.pack("<BBI", self.packet_type, self.flags, self.sequence)
return (
struct.pack(
"<BBII", self.packet_type, self.flags, self.sequence, self.timestamp
)
+ self.payload
)
PACKET_FLAG_LAST = 1
# -----------------------------------------------------------------------------
# Jitter Stats
# -----------------------------------------------------------------------------
class JitterStats:
def __init__(self):
self.reset()
def reset(self):
self.packets = []
self.receive_times = []
self.jitter = []
def on_packet_received(self, packet):
now = time.time()
self.packets.append(packet)
self.receive_times.append(now)
if packet.timestamp and len(self.packets) > 1:
expected_time = (
self.receive_times[0]
+ (packet.timestamp - self.packets[0].timestamp) / 1000000
)
jitter = now - expected_time
else:
jitter = 0.0
self.jitter.append(jitter)
return jitter
def show_stats(self):
if len(self.jitter) < 3:
return
average = sum(self.jitter) / len(self.jitter)
adjusted = [jitter - average for jitter in self.jitter]
log_stats('Jitter (signed)', adjusted, 3)
log_stats('Jitter (absolute)', [abs(jitter) for jitter in adjusted], 3)
# Show a histogram
bin_count = 20
bins = [0] * bin_count
interval_min = min(adjusted)
interval_max = max(adjusted)
interval_range = interval_max - interval_min
bin_thresholds = [
interval_min + i * (interval_range / bin_count) for i in range(bin_count)
]
for jitter in adjusted:
for i in reversed(range(bin_count)):
if jitter >= bin_thresholds[i]:
bins[i] += 1
break
for i in range(bin_count):
logging.info(f'@@@ >= {bin_thresholds[i]:.4f}: {bins[i]}')
# -----------------------------------------------------------------------------
@@ -281,19 +376,37 @@ class Sender:
await asyncio.sleep(self.tx_start_delay)
logging.info(color('=== Sending RESET', 'magenta'))
await self.packet_io.send_packet(bytes([PacketType.RESET]))
await self.packet_io.send_packet(
bytes(Packet(packet_type=Packet.PacketType.RESET))
)
self.start_time = time.time()
self.bytes_sent = 0
for tx_i in range(self.tx_packet_count):
packet_flags = (
PACKET_FLAG_LAST if tx_i == self.tx_packet_count - 1 else 0
if self.pace > 0:
# Wait until it is time to send the next packet
target_time = self.start_time + (tx_i * self.pace / 1000)
now = time.time()
if now < target_time:
await asyncio.sleep(target_time - now)
else:
await self.packet_io.drain()
packet = bytes(
Packet(
packet_type=Packet.PacketType.SEQUENCE,
flags=(
Packet.PacketFlags.LAST
if tx_i == self.tx_packet_count - 1
else 0
),
sequence=tx_i,
timestamp=int((time.time() - self.start_time) * 1000000),
payload=bytes(
self.tx_packet_size - 10 - self.packet_io.overhead_size
),
)
)
packet = struct.pack(
'>bbI',
PacketType.SEQUENCE,
packet_flags,
tx_i,
) + bytes(self.tx_packet_size - 6 - self.packet_io.overhead_size)
logging.info(
color(
f'Sending packet {tx_i}: {self.tx_packet_size} bytes', 'yellow'
@@ -302,14 +415,6 @@ class Sender:
self.bytes_sent += len(packet)
await self.packet_io.send_packet(packet)
if self.pace is None:
continue
if self.pace > 0:
await asyncio.sleep(self.pace / 1000)
else:
await self.packet_io.drain()
await self.done.wait()
run_counter = f'[{run + 1} of {self.repeat + 1}]' if self.repeat else ''
@@ -321,13 +426,13 @@ class Sender:
if self.repeat:
logging.info(color('--- End of runs', 'blue'))
def on_packet_received(self, packet):
def on_packet_received(self, data):
try:
packet_type, _ = parse_packet(packet)
packet = Packet.from_bytes(data)
except ValueError:
return
if packet_type == PacketType.ACK:
if packet.packet_type == Packet.PacketType.ACK:
elapsed = time.time() - self.start_time
average_tx_speed = self.bytes_sent / elapsed
self.stats.append(average_tx_speed)
@@ -350,52 +455,53 @@ class Receiver:
last_timestamp: float
def __init__(self, packet_io, linger):
self.reset()
self.jitter_stats = JitterStats()
self.packet_io = packet_io
self.packet_io.packet_listener = self
self.linger = linger
self.done = asyncio.Event()
self.reset()
def reset(self):
self.expected_packet_index = 0
self.measurements = [(time.time(), 0)]
self.total_bytes_received = 0
self.jitter_stats.reset()
def on_packet_received(self, packet):
def on_packet_received(self, data):
try:
packet_type, packet_data = parse_packet(packet)
packet = Packet.from_bytes(data)
except ValueError:
logging.exception("invalid packet")
return
if packet_type == PacketType.RESET:
if packet.packet_type == Packet.PacketType.RESET:
logging.info(color('=== Received RESET', 'magenta'))
self.reset()
return
try:
packet_flags, packet_index = parse_packet_sequence(packet_data)
except ValueError:
return
jitter = self.jitter_stats.on_packet_received(packet)
logging.info(
f'<<< Received packet {packet_index}: '
f'flags=0x{packet_flags:02X}, '
f'{len(packet) + self.packet_io.overhead_size} bytes'
f'<<< Received packet {packet.sequence}: '
f'flags={packet.flags}, '
f'jitter={jitter:.4f}, '
f'{len(data) + self.packet_io.overhead_size} bytes',
)
if packet_index != self.expected_packet_index:
if packet.sequence != self.expected_packet_index:
logging.info(
color(
f'!!! Unexpected packet, expected {self.expected_packet_index} '
f'but received {packet_index}'
f'but received {packet.sequence}'
)
)
now = time.time()
elapsed_since_start = now - self.measurements[0][0]
elapsed_since_last = now - self.measurements[-1][0]
self.measurements.append((now, len(packet)))
self.total_bytes_received += len(packet)
instant_rx_speed = len(packet) / elapsed_since_last
self.measurements.append((now, len(data)))
self.total_bytes_received += len(data)
instant_rx_speed = len(data) / elapsed_since_last
average_rx_speed = self.total_bytes_received / elapsed_since_start
window = self.measurements[-64:]
windowed_rx_speed = sum(measurement[1] for measurement in window[1:]) / (
@@ -411,15 +517,17 @@ class Receiver:
)
)
self.expected_packet_index = packet_index + 1
self.expected_packet_index = packet.sequence + 1
if packet_flags & PACKET_FLAG_LAST:
if packet.flags & Packet.PacketFlags.LAST:
AsyncRunner.spawn(
self.packet_io.send_packet(
struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index)
bytes(Packet(Packet.PacketType.ACK, packet.flags, packet.sequence))
)
)
logging.info(color('@@@ Received last packet', 'green'))
self.jitter_stats.show_stats()
if not self.linger:
self.done.set()
@@ -479,25 +587,32 @@ class Ping:
await asyncio.sleep(self.tx_start_delay)
logging.info(color('=== Sending RESET', 'magenta'))
await self.packet_io.send_packet(bytes([PacketType.RESET]))
await self.packet_io.send_packet(bytes(Packet(Packet.PacketType.RESET)))
packet_interval = self.pace / 1000
start_time = time.time()
self.next_expected_packet_index = 0
for i in range(self.tx_packet_count):
target_time = start_time + (i * packet_interval)
target_time = start_time + (i * self.pace / 1000)
now = time.time()
if now < target_time:
await asyncio.sleep(target_time - now)
now = time.time()
packet = struct.pack(
'>bbI',
PacketType.SEQUENCE,
(PACKET_FLAG_LAST if i == self.tx_packet_count - 1 else 0),
i,
) + bytes(self.tx_packet_size - 6)
packet = bytes(
Packet(
packet_type=Packet.PacketType.SEQUENCE,
flags=(
Packet.PacketFlags.LAST
if i == self.tx_packet_count - 1
else 0
),
sequence=i,
timestamp=int((now - start_time) * 1000000),
payload=bytes(self.tx_packet_size - 10),
)
)
logging.info(color(f'Sending packet {i}', 'yellow'))
self.ping_times.append(time.time())
self.ping_times.append(now)
await self.packet_io.send_packet(packet)
await self.done.wait()
@@ -531,40 +646,35 @@ class Ping:
if self.repeat:
logging.info(color('--- End of runs', 'blue'))
def on_packet_received(self, packet):
def on_packet_received(self, data):
try:
packet_type, packet_data = parse_packet(packet)
packet = Packet.from_bytes(data)
except ValueError:
return
try:
packet_flags, packet_index = parse_packet_sequence(packet_data)
except ValueError:
return
if packet_type == PacketType.ACK:
elapsed = time.time() - self.ping_times[packet_index]
if packet.packet_type == Packet.PacketType.ACK:
elapsed = time.time() - self.ping_times[packet.sequence]
rtt = elapsed * 1000
self.rtts.append(rtt)
logging.info(
color(
f'<<< Received ACK [{packet_index}], RTT={rtt:.2f}ms',
f'<<< Received ACK [{packet.sequence}], RTT={rtt:.2f}ms',
'green',
)
)
if packet_index == self.next_expected_packet_index:
if packet.sequence == self.next_expected_packet_index:
self.next_expected_packet_index += 1
else:
logging.info(
color(
f'!!! Unexpected packet, '
f'expected {self.next_expected_packet_index} '
f'but received {packet_index}'
f'but received {packet.sequence}'
)
)
if packet_flags & PACKET_FLAG_LAST:
if packet.flags & Packet.PacketFlags.LAST:
self.done.set()
return
@@ -576,89 +686,56 @@ class Pong:
expected_packet_index: int
def __init__(self, packet_io, linger):
self.reset()
self.jitter_stats = JitterStats()
self.packet_io = packet_io
self.packet_io.packet_listener = self
self.linger = linger
self.done = asyncio.Event()
self.reset()
def reset(self):
self.expected_packet_index = 0
self.receive_times = []
def on_packet_received(self, packet):
self.receive_times.append(time.time())
self.jitter_stats.reset()
def on_packet_received(self, data):
try:
packet_type, packet_data = parse_packet(packet)
packet = Packet.from_bytes(data)
except ValueError:
return
if packet_type == PacketType.RESET:
if packet.packet_type == Packet.PacketType.RESET:
logging.info(color('=== Received RESET', 'magenta'))
self.reset()
return
try:
packet_flags, packet_index = parse_packet_sequence(packet_data)
except ValueError:
return
interval = (
self.receive_times[-1] - self.receive_times[-2]
if len(self.receive_times) >= 2
else 0
)
jitter = self.jitter_stats.on_packet_received(packet)
logging.info(
color(
f'<<< Received packet {packet_index}: '
f'flags=0x{packet_flags:02X}, {len(packet)} bytes, '
f'interval={interval:.4f}',
f'<<< Received packet {packet.sequence}: '
f'flags={packet.flags}, {len(data)} bytes, '
f'jitter={jitter:.4f}',
'green',
)
)
if packet_index != self.expected_packet_index:
if packet.sequence != self.expected_packet_index:
logging.info(
color(
f'!!! Unexpected packet, expected {self.expected_packet_index} '
f'but received {packet_index}'
f'but received {packet.sequence}'
)
)
self.expected_packet_index = packet_index + 1
self.expected_packet_index = packet.sequence + 1
AsyncRunner.spawn(
self.packet_io.send_packet(
struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index)
bytes(Packet(Packet.PacketType.ACK, packet.flags, packet.sequence))
)
)
if packet_flags & PACKET_FLAG_LAST:
if len(self.receive_times) >= 3:
# Show basic stats
intervals = [
self.receive_times[i + 1] - self.receive_times[i]
for i in range(len(self.receive_times) - 1)
]
log_stats('Packet intervals', intervals, 3)
# Show a histogram
bin_count = 20
bins = [0] * bin_count
interval_min = min(intervals)
interval_max = max(intervals)
interval_range = interval_max - interval_min
bin_thresholds = [
interval_min + i * (interval_range / bin_count)
for i in range(bin_count)
]
for interval in intervals:
for i in reversed(range(bin_count)):
if interval >= bin_thresholds[i]:
bins[i] += 1
break
for i in range(bin_count):
logging.info(f'@@@ >= {bin_thresholds[i]:.4f}: {bins[i]}')
if packet.flags & Packet.PacketFlags.LAST:
self.jitter_stats.show_stats()
if not self.linger:
self.done.set()
@@ -1471,7 +1548,7 @@ def create_mode_factory(ctx, default_mode):
def create_scenario_factory(ctx, default_scenario):
scenario = ctx.obj['scenario']
if scenario is None:
scenarion = default_scenario
scenario = default_scenario
def create_scenario(packet_io):
if scenario == 'send':
@@ -1605,7 +1682,7 @@ def create_scenario_factory(ctx, default_scenario):
'--packet-size',
'-s',
metavar='SIZE',
type=click.IntRange(8, 8192),
type=click.IntRange(10, 8192),
default=500,
help='Packet size (send or ping scenario)',
)

View File

@@ -10,7 +10,7 @@ android {
defaultConfig {
applicationId = "com.github.google.bumble.btbench"
minSdk = 30
minSdk = 33
targetSdk = 34
versionCode = 1
versionName = "1.0"

View File

@@ -0,0 +1,109 @@
package com.github.google.bumble.btbench
import android.annotation.SuppressLint
import android.bluetooth.BluetoothAdapter
import android.bluetooth.BluetoothDevice
import android.bluetooth.BluetoothGatt
import android.bluetooth.BluetoothGattCallback
import android.bluetooth.BluetoothManager
import android.bluetooth.BluetoothProfile
import android.content.Context
import android.os.Build
import androidx.core.content.ContextCompat
import java.util.logging.Logger
private val Log = Logger.getLogger("btbench.connection")
open class Connection(
private val viewModel: AppViewModel,
private val bluetoothAdapter: BluetoothAdapter,
private val context: Context
) : BluetoothGattCallback() {
var remoteDevice: BluetoothDevice? = null
var gatt: BluetoothGatt? = null
@SuppressLint("MissingPermission")
open fun connect() {
val addressIsPublic = viewModel.peerBluetoothAddress.endsWith("/P")
val address = viewModel.peerBluetoothAddress.take(17)
remoteDevice = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
bluetoothAdapter.getRemoteLeDevice(
address,
if (addressIsPublic) {
BluetoothDevice.ADDRESS_TYPE_PUBLIC
} else {
BluetoothDevice.ADDRESS_TYPE_RANDOM
}
)
} else {
bluetoothAdapter.getRemoteDevice(address)
}
gatt = remoteDevice?.connectGatt(
context,
false,
this,
BluetoothDevice.TRANSPORT_LE,
if (viewModel.use2mPhy) BluetoothDevice.PHY_LE_2M_MASK else BluetoothDevice.PHY_LE_1M_MASK
)
}
@SuppressLint("MissingPermission")
fun disconnect() {
gatt?.disconnect()
}
override fun onMtuChanged(gatt: BluetoothGatt, mtu: Int, status: Int) {
Log.info("MTU update: mtu=$mtu status=$status")
viewModel.mtu = mtu
}
override fun onPhyUpdate(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) {
Log.info("PHY update: tx=$txPhy, rx=$rxPhy, status=$status")
viewModel.txPhy = txPhy
viewModel.rxPhy = rxPhy
}
override fun onPhyRead(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) {
Log.info("PHY: tx=$txPhy, rx=$rxPhy, status=$status")
viewModel.txPhy = txPhy
viewModel.rxPhy = rxPhy
}
@SuppressLint("MissingPermission")
override fun onConnectionStateChange(
gatt: BluetoothGatt?, status: Int, newState: Int
) {
if (status != BluetoothGatt.GATT_SUCCESS) {
Log.warning("onConnectionStateChange status=$status")
}
if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) {
if (viewModel.use2mPhy) {
Log.info("requesting 2M PHY")
gatt.setPreferredPhy(
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_OPTION_NO_PREFERRED
)
}
gatt.readPhy()
// Request an MTU update, even though we don't use GATT, because Android
// won't request a larger link layer maximum data length otherwise.
gatt.requestMtu(517)
// Request a specific connection priority
val connectionPriority = when (viewModel.connectionPriority) {
"BALANCED" -> BluetoothGatt.CONNECTION_PRIORITY_BALANCED
"LOW_POWER" -> BluetoothGatt.CONNECTION_PRIORITY_LOW_POWER
"HIGH" -> BluetoothGatt.CONNECTION_PRIORITY_HIGH
"DCK" -> BluetoothGatt.CONNECTION_PRIORITY_DCK
else -> 0
}
if (!gatt.requestConnectionPriority(connectionPriority)) {
Log.warning("requestConnectionPriority failed")
}
}
}
}

View File

@@ -0,0 +1,219 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.github.google.bumble.btbench
import android.annotation.SuppressLint
import android.bluetooth.BluetoothAdapter
import android.bluetooth.BluetoothGatt
import android.bluetooth.BluetoothGattCharacteristic
import android.bluetooth.BluetoothGattDescriptor
import android.bluetooth.BluetoothProfile
import android.content.Context
import java.io.IOException
import java.util.UUID
import java.util.concurrent.CountDownLatch
import java.util.concurrent.Semaphore
import java.util.logging.Logger
import kotlin.concurrent.thread
private val Log = Logger.getLogger("btbench.gatt-client")
private var CCCD_UUID = UUID.fromString("00002902-0000-1000-8000-00805F9B34FB")
private val SPEED_SERVICE_UUID = UUID.fromString("50DB505C-8AC4-4738-8448-3B1D9CC09CC5")
private val SPEED_TX_UUID = UUID.fromString("E789C754-41A1-45F4-A948-A0A1A90DBA53")
private val SPEED_RX_UUID = UUID.fromString("016A2CC7-E14B-4819-935F-1F56EAE4098D")
class GattClientConnection(
viewModel: AppViewModel,
bluetoothAdapter: BluetoothAdapter,
context: Context
) : Connection(viewModel, bluetoothAdapter, context), PacketIO {
override var packetSink: PacketSink? = null
private val discoveryDone: CountDownLatch = CountDownLatch(1)
private val writeSemaphore: Semaphore = Semaphore(1)
var rxCharacteristic: BluetoothGattCharacteristic? = null
var txCharacteristic: BluetoothGattCharacteristic? = null
override fun connect() {
super.connect()
// Check if we're already connected and have discovered the services
if (gatt?.getService(SPEED_SERVICE_UUID) != null) {
onServicesDiscovered(gatt, BluetoothGatt.GATT_SUCCESS)
}
}
@SuppressLint("MissingPermission")
override fun onConnectionStateChange(
gatt: BluetoothGatt?, status: Int, newState: Int
) {
super.onConnectionStateChange(gatt, status, newState)
if (status != BluetoothGatt.GATT_SUCCESS) {
discoveryDone.countDown()
return
}
if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) {
if (!gatt.discoverServices()) {
Log.warning("discoverServices could not start")
discoveryDone.countDown()
}
}
}
@SuppressLint("MissingPermission")
override fun onServicesDiscovered(gatt: BluetoothGatt?, status: Int) {
if (status != BluetoothGatt.GATT_SUCCESS) {
Log.warning("failed to discover services: ${status}")
discoveryDone.countDown()
return
}
// Find the service
val service = gatt!!.getService(SPEED_SERVICE_UUID)
if (service == null) {
Log.warning("GATT Service not found")
discoveryDone.countDown()
return
}
// Find the RX and TX characteristics
rxCharacteristic = service.getCharacteristic(SPEED_RX_UUID)
if (rxCharacteristic == null) {
Log.warning("GATT RX Characteristics not found")
discoveryDone.countDown()
return
}
txCharacteristic = service.getCharacteristic(SPEED_TX_UUID)
if (txCharacteristic == null) {
Log.warning("GATT TX Characteristics not found")
discoveryDone.countDown()
return
}
// Subscribe to the RX characteristic
gatt.setCharacteristicNotification(rxCharacteristic, true)
val cccdDescriptor = rxCharacteristic!!.getDescriptor(CCCD_UUID)
gatt.writeDescriptor(cccdDescriptor, BluetoothGattDescriptor.ENABLE_NOTIFICATION_VALUE);
Log.info("GATT discovery complete")
discoveryDone.countDown()
}
override fun onCharacteristicWrite(
gatt: BluetoothGatt?,
characteristic: BluetoothGattCharacteristic?,
status: Int
) {
// Now we can write again
writeSemaphore.release()
if (status != BluetoothGatt.GATT_SUCCESS) {
Log.warning("onCharacteristicWrite failed: $status")
return
}
}
override fun onCharacteristicChanged(
gatt: BluetoothGatt,
characteristic: BluetoothGattCharacteristic,
value: ByteArray
) {
if (characteristic.uuid == SPEED_RX_UUID && packetSink != null) {
val packet = Packet.from(value)
packetSink!!.onPacket(packet)
}
}
@SuppressLint("MissingPermission")
override fun sendPacket(packet: Packet) {
if (txCharacteristic == null) {
Log.warning("No TX characteristic, dropping")
return
}
// Wait until we can write
writeSemaphore.acquire()
// Write the data
val data = packet.toBytes()
val clampedData = if (data.size > 512) {
// Clamp the data to the maximum allowed characteristic data size
data.copyOf(512)
} else {
data
}
gatt?.writeCharacteristic(
txCharacteristic!!,
clampedData,
BluetoothGattCharacteristic.WRITE_TYPE_NO_RESPONSE
)
}
fun waitForDiscoveryCompletion() {
discoveryDone.await()
}
}
class GattClient(
private val viewModel: AppViewModel,
bluetoothAdapter: BluetoothAdapter,
context: Context,
private val createIoClient: (packetIo: PacketIO) -> IoClient
) : Mode {
private var connection: GattClientConnection =
GattClientConnection(viewModel, bluetoothAdapter, context)
private var clientThread: Thread? = null
@SuppressLint("MissingPermission")
override fun run() {
viewModel.running = true
clientThread = thread(name = "GattClient") {
connection.connect()
viewModel.aborter = {
connection.disconnect()
}
// Discover the rx and tx characteristics
connection.waitForDiscoveryCompletion()
if (connection.rxCharacteristic == null || connection.txCharacteristic == null) {
connection.disconnect()
viewModel.running = false
return@thread
}
val ioClient = createIoClient(connection)
try {
ioClient.run()
viewModel.status = "OK"
} catch (error: IOException) {
Log.info("run ended abruptly")
viewModel.status = "ABORTED"
viewModel.lastError = "IO_ERROR"
} finally {
connection.disconnect()
viewModel.running = false
}
}
}
override fun waitForCompletion() {
clientThread?.join()
}
}

View File

@@ -16,101 +16,25 @@ package com.github.google.bumble.btbench
import android.annotation.SuppressLint
import android.bluetooth.BluetoothAdapter
import android.bluetooth.BluetoothDevice
import android.bluetooth.BluetoothGatt
import android.bluetooth.BluetoothGattCallback
import android.bluetooth.BluetoothProfile
import android.content.Context
import android.os.Build
import java.util.logging.Logger
private val Log = Logger.getLogger("btbench.l2cap-client")
class L2capClient(
private val viewModel: AppViewModel,
private val bluetoothAdapter: BluetoothAdapter,
private val context: Context,
bluetoothAdapter: BluetoothAdapter,
context: Context,
private val createIoClient: (packetIo: PacketIO) -> IoClient
) : Mode {
private var connection: Connection = Connection(viewModel, bluetoothAdapter, context)
private var socketClient: SocketClient? = null
@SuppressLint("MissingPermission")
override fun run() {
viewModel.running = true
val addressIsPublic = viewModel.peerBluetoothAddress.endsWith("/P")
val address = viewModel.peerBluetoothAddress.take(17)
val remoteDevice = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) {
bluetoothAdapter.getRemoteLeDevice(
address,
if (addressIsPublic) {
BluetoothDevice.ADDRESS_TYPE_PUBLIC
} else {
BluetoothDevice.ADDRESS_TYPE_RANDOM
}
)
} else {
bluetoothAdapter.getRemoteDevice(address)
}
val gatt = remoteDevice.connectGatt(
context,
false,
object : BluetoothGattCallback() {
override fun onMtuChanged(gatt: BluetoothGatt, mtu: Int, status: Int) {
Log.info("MTU update: mtu=$mtu status=$status")
viewModel.mtu = mtu
}
override fun onPhyUpdate(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) {
Log.info("PHY update: tx=$txPhy, rx=$rxPhy, status=$status")
viewModel.txPhy = txPhy
viewModel.rxPhy = rxPhy
}
override fun onPhyRead(gatt: BluetoothGatt, txPhy: Int, rxPhy: Int, status: Int) {
Log.info("PHY: tx=$txPhy, rx=$rxPhy, status=$status")
viewModel.txPhy = txPhy
viewModel.rxPhy = rxPhy
}
override fun onConnectionStateChange(
gatt: BluetoothGatt?, status: Int, newState: Int
) {
if (gatt != null && newState == BluetoothProfile.STATE_CONNECTED) {
if (viewModel.use2mPhy) {
Log.info("requesting 2M PHY")
gatt.setPreferredPhy(
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_LE_2M_MASK,
BluetoothDevice.PHY_OPTION_NO_PREFERRED
)
}
gatt.readPhy()
// Request an MTU update, even though we don't use GATT, because Android
// won't request a larger link layer maximum data length otherwise.
gatt.requestMtu(517)
// Request a specific connection priority
val connectionPriority = when (viewModel.connectionPriority) {
"BALANCED" -> BluetoothGatt.CONNECTION_PRIORITY_BALANCED
"LOW_POWER" -> BluetoothGatt.CONNECTION_PRIORITY_LOW_POWER
"HIGH" -> BluetoothGatt.CONNECTION_PRIORITY_HIGH
"DCK" -> BluetoothGatt.CONNECTION_PRIORITY_DCK
else -> 0
}
if (!gatt.requestConnectionPriority(connectionPriority)) {
Log.warning("requestConnectionPriority failed")
}
}
}
},
BluetoothDevice.TRANSPORT_LE,
if (viewModel.use2mPhy) BluetoothDevice.PHY_LE_2M_MASK else BluetoothDevice.PHY_LE_1M_MASK
)
val socket = remoteDevice.createInsecureL2capChannel(viewModel.l2capPsm)
connection.connect()
val socket = connection.remoteDevice!!.createInsecureL2capChannel(viewModel.l2capPsm)
socketClient = SocketClient(viewModel, socket, createIoClient)
socketClient!!.run()
}

View File

@@ -146,9 +146,7 @@ class MainActivity : ComponentActivity() {
initBluetooth()
setContent {
MainView(
appViewModel,
::becomeDiscoverable,
::runScenario
appViewModel, ::becomeDiscoverable, ::runScenario
)
}
@@ -184,6 +182,8 @@ class MainActivity : ComponentActivity() {
"rfcomm-server" -> appViewModel.mode = RFCOMM_SERVER_MODE
"l2cap-client" -> appViewModel.mode = L2CAP_CLIENT_MODE
"l2cap-server" -> appViewModel.mode = L2CAP_SERVER_MODE
"gatt-client" -> appViewModel.mode = GATT_CLIENT_MODE
"gatt-server" -> appViewModel.mode = GATT_SERVER_MODE
}
}
intent.getStringExtra("autostart")?.let {
@@ -204,12 +204,14 @@ class MainActivity : ComponentActivity() {
RFCOMM_CLIENT_MODE -> RfcommClient(appViewModel, bluetoothAdapter!!, ::createIoClient)
RFCOMM_SERVER_MODE -> RfcommServer(appViewModel, bluetoothAdapter!!, ::createIoClient)
L2CAP_CLIENT_MODE -> L2capClient(
appViewModel,
bluetoothAdapter!!,
baseContext,
::createIoClient
appViewModel, bluetoothAdapter!!, baseContext, ::createIoClient
)
L2CAP_SERVER_MODE -> L2capServer(appViewModel, bluetoothAdapter!!, ::createIoClient)
GATT_CLIENT_MODE -> GattClient(
appViewModel, bluetoothAdapter!!, baseContext, ::createIoClient
)
else -> throw IllegalStateException()
}
runner.run()
@@ -283,7 +285,7 @@ fun MainView(
keyboardController?.hide()
focusManager.clearFocus()
}),
enabled = (appViewModel.mode == RFCOMM_CLIENT_MODE) or (appViewModel.mode == L2CAP_CLIENT_MODE)
enabled = (appViewModel.mode == RFCOMM_CLIENT_MODE || appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == GATT_CLIENT_MODE)
)
Divider()
TextField(
@@ -351,43 +353,36 @@ fun MainView(
keyboardController?.hide()
focusManager.clearFocus()
}),
enabled = (appViewModel.scenario == PING_SCENARIO)
enabled = (appViewModel.scenario == PING_SCENARIO || appViewModel.scenario == SEND_SCENARIO)
)
Divider()
ActionButton(
text = "Become Discoverable", onClick = becomeDiscoverable, true
)
Row(
horizontalArrangement = Arrangement.SpaceBetween,
verticalAlignment = Alignment.CenterVertically
) {
Text(text = "2M PHY")
Spacer(modifier = Modifier.padding(start = 8.dp))
Switch(
enabled = (appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == L2CAP_SERVER_MODE),
Switch(enabled = (appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == L2CAP_SERVER_MODE || appViewModel.mode == GATT_CLIENT_MODE || appViewModel.mode == GATT_SERVER_MODE),
checked = appViewModel.use2mPhy,
onCheckedChange = { appViewModel.use2mPhy = it }
)
onCheckedChange = { appViewModel.use2mPhy = it })
Column(Modifier.selectableGroup()) {
listOf(
"BALANCED",
"LOW",
"HIGH",
"DCK"
"BALANCED", "LOW", "HIGH", "DCK"
).forEach { text ->
Row(
Modifier
.selectable(
selected = (text == appViewModel.connectionPriority),
onClick = { appViewModel.updateConnectionPriority(text) },
role = Role.RadioButton
role = Role.RadioButton,
)
.padding(horizontal = 16.dp),
verticalAlignment = Alignment.CenterVertically
) {
RadioButton(
selected = (text == appViewModel.connectionPriority),
onClick = null
onClick = null,
enabled = (appViewModel.mode == L2CAP_CLIENT_MODE || appViewModel.mode == L2CAP_SERVER_MODE || appViewModel.mode == GATT_CLIENT_MODE || appViewModel.mode == GATT_SERVER_MODE)
)
Text(
text = text,
@@ -404,7 +399,9 @@ fun MainView(
RFCOMM_CLIENT_MODE,
RFCOMM_SERVER_MODE,
L2CAP_CLIENT_MODE,
L2CAP_SERVER_MODE
L2CAP_SERVER_MODE,
GATT_CLIENT_MODE,
GATT_SERVER_MODE
).forEach { text ->
Row(
Modifier
@@ -417,8 +414,7 @@ fun MainView(
verticalAlignment = Alignment.CenterVertically
) {
RadioButton(
selected = (text == appViewModel.mode),
onClick = null
selected = (text == appViewModel.mode), onClick = null
)
Text(
text = text,
@@ -430,10 +426,7 @@ fun MainView(
}
Column(Modifier.selectableGroup()) {
listOf(
SEND_SCENARIO,
RECEIVE_SCENARIO,
PING_SCENARIO,
PONG_SCENARIO
SEND_SCENARIO, RECEIVE_SCENARIO, PING_SCENARIO, PONG_SCENARIO
).forEach { text ->
Row(
Modifier
@@ -446,8 +439,7 @@ fun MainView(
verticalAlignment = Alignment.CenterVertically
) {
RadioButton(
selected = (text == appViewModel.scenario),
onClick = null
selected = (text == appViewModel.scenario), onClick = null
)
Text(
text = text,
@@ -465,20 +457,29 @@ fun MainView(
ActionButton(
text = "Stop", onClick = appViewModel::abort, enabled = appViewModel.running
)
ActionButton(
text = "Become Discoverable", onClick = becomeDiscoverable, true
)
}
Divider()
Text(
text = if (appViewModel.mtu != 0) "MTU: ${appViewModel.mtu}" else ""
)
Text(
text = if (appViewModel.rxPhy != 0 || appViewModel.txPhy != 0) "PHY: tx=${appViewModel.txPhy}, rx=${appViewModel.rxPhy}" else ""
)
if (appViewModel.mtu != 0) {
Text(
text = "MTU: ${appViewModel.mtu}"
)
}
if (appViewModel.rxPhy != 0) {
Text(
text = "PHY: tx=${appViewModel.txPhy}, rx=${appViewModel.rxPhy}"
)
}
Text(
text = "Status: ${appViewModel.status}"
)
Text(
text = "Last Error: ${appViewModel.lastError}"
)
if (appViewModel.lastError.isNotEmpty()) {
Text(
text = "Last Error: ${appViewModel.lastError}"
)
}
Text(
text = "Packets Sent: ${appViewModel.packetsSent}"
)

View File

@@ -35,6 +35,8 @@ const val L2CAP_CLIENT_MODE = "L2CAP Client"
const val L2CAP_SERVER_MODE = "L2CAP Server"
const val RFCOMM_CLIENT_MODE = "RFCOMM Client"
const val RFCOMM_SERVER_MODE = "RFCOMM Server"
const val GATT_CLIENT_MODE = "GATT Client"
const val GATT_SERVER_MODE = "GATT Server"
const val SEND_SCENARIO = "Send"
const val RECEIVE_SCENARIO = "Receive"

View File

@@ -17,6 +17,7 @@ package com.github.google.bumble.btbench
import android.bluetooth.BluetoothSocket
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.util.logging.Logger
import kotlin.math.min
@@ -37,11 +38,16 @@ abstract class Packet(val type: Int, val payload: ByteArray = ByteArray(0)) {
RESET -> ResetPacket()
SEQUENCE -> SequencePacket(
data[1].toInt(),
ByteBuffer.wrap(data, 2, 4).getInt(),
data.sliceArray(6..<data.size)
ByteBuffer.wrap(data, 2, 4).order(ByteOrder.LITTLE_ENDIAN).getInt(),
ByteBuffer.wrap(data, 6, 4).order(ByteOrder.LITTLE_ENDIAN).getInt(),
data.sliceArray(10..<data.size)
)
ACK -> AckPacket(
data[1].toInt(),
ByteBuffer.wrap(data, 2, 4).order(ByteOrder.LITTLE_ENDIAN).getInt()
)
ACK -> AckPacket(data[1].toInt(), ByteBuffer.wrap(data, 2, 4).getInt())
else -> GenericPacket(data[0].toInt(), data.sliceArray(1..<data.size))
}
}
@@ -57,16 +63,24 @@ class ResetPacket : Packet(RESET)
class AckPacket(val flags: Int, val sequenceNumber: Int) : Packet(ACK) {
override fun toBytes(): ByteArray {
return ByteBuffer.allocate(1 + 1 + 4).put(type.toByte()).put(flags.toByte())
return ByteBuffer.allocate(6).order(
ByteOrder.LITTLE_ENDIAN
).put(type.toByte()).put(flags.toByte())
.putInt(sequenceNumber).array()
}
}
class SequencePacket(val flags: Int, val sequenceNumber: Int, payload: ByteArray) :
class SequencePacket(
val flags: Int,
val sequenceNumber: Int,
val timestamp: Int,
payload: ByteArray
) :
Packet(SEQUENCE, payload) {
override fun toBytes(): ByteArray {
return ByteBuffer.allocate(1 + 1 + 4 + payload.size).put(type.toByte()).put(flags.toByte())
.putInt(sequenceNumber).put(payload).array()
return ByteBuffer.allocate(10 + payload.size).order(ByteOrder.LITTLE_ENDIAN)
.put(type.toByte()).put(flags.toByte())
.putInt(sequenceNumber).putInt(timestamp).put(payload).array()
}
}

View File

@@ -46,19 +46,23 @@ class Pinger(private val viewModel: AppViewModel, private val packetIO: PacketIO
val startTime = TimeSource.Monotonic.markNow()
for (i in 0..<packetCount) {
val now = TimeSource.Monotonic.markNow()
val targetTime = startTime + (i * viewModel.senderPacketInterval).milliseconds
val delay = targetTime - now
if (delay.isPositive()) {
Log.info("sleeping ${delay.inWholeMilliseconds} ms")
Thread.sleep(delay.inWholeMilliseconds)
var now = TimeSource.Monotonic.markNow()
if (viewModel.senderPacketInterval > 0) {
val targetTime = startTime + (i * viewModel.senderPacketInterval).milliseconds
val delay = targetTime - now
if (delay.isPositive()) {
Log.info("sleeping ${delay.inWholeMilliseconds} ms")
Thread.sleep(delay.inWholeMilliseconds)
now = TimeSource.Monotonic.markNow()
}
}
pingTimes.add(TimeSource.Monotonic.markNow())
packetIO.sendPacket(
SequencePacket(
if (i < packetCount - 1) 0 else Packet.LAST_FLAG,
i,
ByteArray(packetSize - 6)
(now - startTime).inWholeMicroseconds.toInt(),
ByteArray(packetSize - 10)
)
)
viewModel.packetsSent = i + 1

View File

@@ -16,6 +16,7 @@ package com.github.google.bumble.btbench
import java.util.concurrent.Semaphore
import java.util.logging.Logger
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.DurationUnit
import kotlin.time.TimeSource
@@ -45,20 +46,32 @@ class Sender(private val viewModel: AppViewModel, private val packetIO: PacketIO
val packetCount = viewModel.senderPacketCount
val packetSize = viewModel.senderPacketSize
for (i in 0..<packetCount - 1) {
packetIO.sendPacket(SequencePacket(0, i, ByteArray(packetSize - 6)))
for (i in 0..<packetCount) {
var now = TimeSource.Monotonic.markNow()
if (viewModel.senderPacketInterval > 0) {
val targetTime = startTime + (i * viewModel.senderPacketInterval).milliseconds
val delay = targetTime - now
if (delay.isPositive()) {
Log.info("sleeping ${delay.inWholeMilliseconds} ms")
Thread.sleep(delay.inWholeMilliseconds)
}
now = TimeSource.Monotonic.markNow()
}
val flags = when (i) {
packetCount - 1 -> Packet.LAST_FLAG
else -> 0
}
packetIO.sendPacket(
SequencePacket(
flags,
i,
(now - startTime).inWholeMicroseconds.toInt(),
ByteArray(packetSize - 10)
)
)
bytesSent += packetSize
viewModel.packetsSent = i + 1
}
packetIO.sendPacket(
SequencePacket(
Packet.LAST_FLAG,
packetCount - 1,
ByteArray(packetSize - 6)
)
)
bytesSent += packetSize
viewModel.packetsSent = packetCount
// Wait for the ACK
Log.info("waiting for ACK")