Compare commits

..

17 Commits

Author SHA1 Message Date
Gilles Boccon-Gibod 09e5ea5dec Merge pull request #387 from google/gbg/async-gatt-server
support async read/write for characteristic values
2023-12-29 11:28:22 -08:00
Gilles Boccon-Gibod 6810865670 Merge pull request #385 from google/gbg/android-enable-dle
request MTU change after connection
2023-12-28 13:46:25 -08:00
Gilles Boccon-Gibod 3e9e06a02c Merge pull request #386 from AlanRosenthal/main
app/bench.py: use logging rather than print()
2023-12-28 13:42:17 -08:00
Alan Rosenthal ccd12f6591 app/bench.py: use logging rather than print() 2023-12-28 16:06:50 -05:00
Gilles Boccon-Gibod f9a7843f7e request MTU change after connection 2023-12-28 11:17:18 -08:00
Gilles Boccon-Gibod 210c334db7 Merge pull request #380 from google/gbg/classic-buffer-size
support per-transport ACL queues
2023-12-28 09:24:52 -08:00
Gilles Boccon-Gibod f297cdfcce Merge pull request #384 from eukub/string-concatination-to-fstring
сhanged concatenation of strings to f-strings to improve readability
2023-12-28 09:24:25 -08:00
eukub 5b536d00ab сhanged concatenation of strings to f-strings to improve readability and unify with the rest of code 2023-12-28 16:27:36 +03:00
Gilles Boccon-Gibod b4af46ebd5 use TCP_NODELAY on socket 2023-12-27 12:11:20 -08:00
Gilles Boccon-Gibod c08da3193e format 2023-12-27 11:56:06 -08:00
Gilles Boccon-Gibod f2925ca647 support async read/write for characteristic values 2023-12-27 11:52:22 -08:00
Gilles Boccon-Gibod fd4d68e5c0 print controller flow control info 2023-12-26 13:24:24 -08:00
Gilles Boccon-Gibod b90d0f8710 fix tests 2023-12-26 09:09:20 -08:00
Gilles Boccon-Gibod afc6d19e04 address PR comments 2023-12-23 14:21:44 -08:00
Gilles Boccon-Gibod c05f073b33 Update bumble/host.py
Co-authored-by: zxzxwu <92432172+zxzxwu@users.noreply.github.com>
2023-12-23 14:15:53 -08:00
Gilles Boccon-Gibod 2b4c2a22f4 format 2023-12-22 14:22:08 -08:00
Gilles Boccon-Gibod 47fe93a148 support per-transport ACL queues 2023-12-22 13:52:33 -08:00
28 changed files with 724 additions and 1248 deletions
+110 -81
View File
@@ -82,10 +82,11 @@ SPEED_RX_UUID = '016A2CC7-E14B-4819-935F-1F56EAE4098D'
DEFAULT_RFCOMM_UUID = 'E6D55659-C8B4-4B85-96BB-B1143AF6D3AE'
DEFAULT_L2CAP_PSM = 1234
DEFAULT_L2CAP_MAX_CREDITS = 128
DEFAULT_L2CAP_MTU = 1022
DEFAULT_L2CAP_MPS = 1024
DEFAULT_L2CAP_MTU = 1024
DEFAULT_L2CAP_MPS = 1022
DEFAULT_LINGER_TIME = 1.0
DEFAULT_POST_CONNECTION_WAIT_TIME = 1.0
DEFAULT_RFCOMM_CHANNEL = 8
@@ -95,7 +96,7 @@ DEFAULT_RFCOMM_CHANNEL = 8
# -----------------------------------------------------------------------------
def parse_packet(packet):
if len(packet) < 1:
print(
logging.info(
color(f'!!! Packet too short (got {len(packet)} bytes, need >= 1)', 'red')
)
raise ValueError('packet too short')
@@ -103,7 +104,7 @@ def parse_packet(packet):
try:
packet_type = PacketType(packet[0])
except ValueError:
print(color(f'!!! Invalid packet type 0x{packet[0]:02X}', 'red'))
logging.info(color(f'!!! Invalid packet type 0x{packet[0]:02X}', 'red'))
raise
return (packet_type, packet[1:])
@@ -111,7 +112,7 @@ def parse_packet(packet):
def parse_packet_sequence(packet_data):
if len(packet_data) < 5:
print(
logging.info(
color(
f'!!!Packet too short (got {len(packet_data)} bytes, need >= 5)',
'red',
@@ -155,7 +156,7 @@ def print_connection(connection):
mtu = connection.att_mtu
print(
logging.info(
f'{color("@@@ Connection:", "yellow")} '
f'{connection_parameters} '
f'{data_length} '
@@ -266,15 +267,15 @@ class Sender:
pass
async def run(self):
print(color('--- Waiting for I/O to be ready...', 'blue'))
logging.info(color('--- Waiting for I/O to be ready...', 'blue'))
await self.packet_io.ready.wait()
print(color('--- Go!', 'blue'))
logging.info(color('--- Go!', 'blue'))
if self.tx_start_delay:
print(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
await asyncio.sleep(self.tx_start_delay)
print(color('=== Sending RESET', 'magenta'))
logging.info(color('=== Sending RESET', 'magenta'))
await self.packet_io.send_packet(bytes([PacketType.RESET]))
self.start_time = time.time()
for tx_i in range(self.tx_packet_count):
@@ -285,12 +286,12 @@ class Sender:
packet_flags,
tx_i,
) + bytes(self.tx_packet_size - 6)
print(color(f'Sending packet {tx_i}: {len(packet)} bytes', 'yellow'))
logging.info(color(f'Sending packet {tx_i}: {len(packet)} bytes', 'yellow'))
self.bytes_sent += len(packet)
await self.packet_io.send_packet(packet)
await self.done.wait()
print(color('=== Done!', 'magenta'))
logging.info(color('=== Done!', 'magenta'))
def on_packet_received(self, packet):
try:
@@ -301,7 +302,7 @@ class Sender:
if packet_type == PacketType.ACK:
elapsed = time.time() - self.start_time
average_tx_speed = self.bytes_sent / elapsed
print(
logging.info(
color(
f'@@@ Received ACK. Speed: average={average_tx_speed:.4f}'
f' ({self.bytes_sent} bytes in {elapsed:.2f} seconds)',
@@ -315,6 +316,10 @@ class Sender:
# Receiver
# -----------------------------------------------------------------------------
class Receiver:
expected_packet_index: int
start_timestamp: float
last_timestamp: float
def __init__(self, packet_io):
self.reset()
self.packet_io = packet_io
@@ -336,7 +341,7 @@ class Receiver:
now = time.time()
if packet_type == PacketType.RESET:
print(color('=== Received RESET', 'magenta'))
logging.info(color('=== Received RESET', 'magenta'))
self.reset()
self.start_timestamp = now
return
@@ -345,13 +350,13 @@ class Receiver:
packet_flags, packet_index = parse_packet_sequence(packet_data)
except ValueError:
return
print(
logging.info(
f'<<< Received packet {packet_index}: '
f'flags=0x{packet_flags:02X}, {len(packet)} bytes'
)
if packet_index != self.expected_packet_index:
print(
logging.info(
color(
f'!!! Unexpected packet, expected {self.expected_packet_index} '
f'but received {packet_index}'
@@ -363,7 +368,7 @@ class Receiver:
self.bytes_received += len(packet)
instant_rx_speed = len(packet) / elapsed_since_last
average_rx_speed = self.bytes_received / elapsed_since_start
print(
logging.info(
color(
f'Speed: instant={instant_rx_speed:.4f}, average={average_rx_speed:.4f}',
'yellow',
@@ -379,12 +384,12 @@ class Receiver:
struct.pack('>bbI', PacketType.ACK, packet_flags, packet_index)
)
)
print(color('@@@ Received last packet', 'green'))
logging.info(color('@@@ Received last packet', 'green'))
self.done.set()
async def run(self):
await self.done.wait()
print(color('=== Done!', 'magenta'))
logging.info(color('=== Done!', 'magenta'))
# -----------------------------------------------------------------------------
@@ -406,23 +411,23 @@ class Ping:
pass
async def run(self):
print(color('--- Waiting for I/O to be ready...', 'blue'))
logging.info(color('--- Waiting for I/O to be ready...', 'blue'))
await self.packet_io.ready.wait()
print(color('--- Go!', 'blue'))
logging.info(color('--- Go!', 'blue'))
if self.tx_start_delay:
print(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
logging.info(color(f'*** Startup delay: {self.tx_start_delay}', 'blue'))
await asyncio.sleep(self.tx_start_delay)
print(color('=== Sending RESET', 'magenta'))
logging.info(color('=== Sending RESET', 'magenta'))
await self.packet_io.send_packet(bytes([PacketType.RESET]))
await self.send_next_ping()
await self.done.wait()
average_latency = sum(self.latencies) / len(self.latencies)
print(color(f'@@@ Average latency: {average_latency:.2f}'))
print(color('=== Done!', 'magenta'))
logging.info(color(f'@@@ Average latency: {average_latency:.2f}'))
logging.info(color('=== Done!', 'magenta'))
async def send_next_ping(self):
packet = struct.pack(
@@ -433,7 +438,7 @@ class Ping:
else 0,
self.current_packet_index,
) + bytes(self.tx_packet_size - 6)
print(color(f'Sending packet {self.current_packet_index}', 'yellow'))
logging.info(color(f'Sending packet {self.current_packet_index}', 'yellow'))
self.ping_sent_time = time.time()
await self.packet_io.send_packet(packet)
@@ -453,7 +458,7 @@ class Ping:
if packet_type == PacketType.ACK:
latency = elapsed * 1000
self.latencies.append(latency)
print(
logging.info(
color(
f'<<< Received ACK [{packet_index}], latency={latency:.2f}ms',
'green',
@@ -463,7 +468,7 @@ class Ping:
if packet_index == self.current_packet_index:
self.current_packet_index += 1
else:
print(
logging.info(
color(
f'!!! Unexpected packet, expected {self.current_packet_index} '
f'but received {packet_index}'
@@ -481,6 +486,8 @@ class Ping:
# Pong
# -----------------------------------------------------------------------------
class Pong:
expected_packet_index: int
def __init__(self, packet_io):
self.reset()
self.packet_io = packet_io
@@ -497,7 +504,7 @@ class Pong:
return
if packet_type == PacketType.RESET:
print(color('=== Received RESET', 'magenta'))
logging.info(color('=== Received RESET', 'magenta'))
self.reset()
return
@@ -505,7 +512,7 @@ class Pong:
packet_flags, packet_index = parse_packet_sequence(packet_data)
except ValueError:
return
print(
logging.info(
color(
f'<<< Received packet {packet_index}: '
f'flags=0x{packet_flags:02X}, {len(packet)} bytes',
@@ -514,7 +521,7 @@ class Pong:
)
if packet_index != self.expected_packet_index:
print(
logging.info(
color(
f'!!! Unexpected packet, expected {self.expected_packet_index} '
f'but received {packet_index}'
@@ -534,7 +541,7 @@ class Pong:
async def run(self):
await self.done.wait()
print(color('=== Done!', 'magenta'))
logging.info(color('=== Done!', 'magenta'))
# -----------------------------------------------------------------------------
@@ -552,36 +559,36 @@ class GattClient:
peer = Peer(connection)
if self.att_mtu:
print(color(f'*** Requesting MTU update: {self.att_mtu}', 'blue'))
logging.info(color(f'*** Requesting MTU update: {self.att_mtu}', 'blue'))
await peer.request_mtu(self.att_mtu)
print(color('*** Discovering services...', 'blue'))
logging.info(color('*** Discovering services...', 'blue'))
await peer.discover_services()
speed_services = peer.get_services_by_uuid(SPEED_SERVICE_UUID)
if not speed_services:
print(color('!!! Speed Service not found', 'red'))
logging.info(color('!!! Speed Service not found', 'red'))
return
speed_service = speed_services[0]
print(color('*** Discovering characteristics...', 'blue'))
logging.info(color('*** Discovering characteristics...', 'blue'))
await speed_service.discover_characteristics()
speed_txs = speed_service.get_characteristics_by_uuid(SPEED_TX_UUID)
if not speed_txs:
print(color('!!! Speed TX not found', 'red'))
logging.info(color('!!! Speed TX not found', 'red'))
return
self.speed_tx = speed_txs[0]
speed_rxs = speed_service.get_characteristics_by_uuid(SPEED_RX_UUID)
if not speed_rxs:
print(color('!!! Speed RX not found', 'red'))
logging.info(color('!!! Speed RX not found', 'red'))
return
self.speed_rx = speed_rxs[0]
print(color('*** Subscribing to RX', 'blue'))
logging.info(color('*** Subscribing to RX', 'blue'))
await self.speed_rx.subscribe(self.on_packet_received)
print(color('*** Discovery complete', 'blue'))
logging.info(color('*** Discovery complete', 'blue'))
connection.on('disconnection', self.on_disconnection)
self.ready.set()
@@ -633,10 +640,10 @@ class GattServer:
def on_rx_subscription(self, _connection, notify_enabled, _indicate_enabled):
if notify_enabled:
print(color('*** RX subscription', 'blue'))
logging.info(color('*** RX subscription', 'blue'))
self.ready.set()
else:
print(color('*** RX un-subscription', 'blue'))
logging.info(color('*** RX un-subscription', 'blue'))
self.ready.clear()
def on_tx_write(self, _, value):
@@ -684,7 +691,7 @@ class StreamedPacketIO:
async def send_packet(self, packet):
if not self.io_sink:
print(color('!!! No sink, dropping packet', 'red'))
logging.info(color('!!! No sink, dropping packet', 'red'))
return
# pylint: disable-next=not-callable
@@ -714,7 +721,7 @@ class L2capClient(StreamedPacketIO):
connection.on('disconnection', self.on_disconnection)
# Connect a new L2CAP channel
print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
logging.info(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
try:
l2cap_channel = await connection.create_l2cap_channel(
spec=l2cap.LeCreditBasedChannelSpec(
@@ -724,9 +731,9 @@ class L2capClient(StreamedPacketIO):
mps=self.mps,
)
)
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
logging.info(color(f'*** L2CAP channel: {l2cap_channel}', 'cyan'))
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
logging.info(color(f'!!! Connection failed: {error}', 'red'))
return
l2cap_channel.sink = self.on_packet
@@ -739,7 +746,7 @@ class L2capClient(StreamedPacketIO):
pass
def on_l2cap_close(self):
print(color('*** L2CAP channel closed', 'red'))
logging.info(color('*** L2CAP channel closed', 'red'))
# -----------------------------------------------------------------------------
@@ -765,7 +772,9 @@ class L2capServer(StreamedPacketIO):
),
handler=self.on_l2cap_channel,
)
print(color(f'### Listening for L2CAP connection on PSM {psm}', 'yellow'))
logging.info(
color(f'### Listening for L2CAP connection on PSM {psm}', 'yellow')
)
async def on_connection(self, connection):
connection.on('disconnection', self.on_disconnection)
@@ -774,7 +783,7 @@ class L2capServer(StreamedPacketIO):
pass
def on_l2cap_channel(self, l2cap_channel):
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
logging.info(color(f'*** L2CAP channel: {l2cap_channel}', 'cyan'))
self.io_sink = l2cap_channel.write
l2cap_channel.on('close', self.on_l2cap_close)
@@ -783,7 +792,7 @@ class L2capServer(StreamedPacketIO):
self.ready.set()
def on_l2cap_close(self):
print(color('*** L2CAP channel closed', 'red'))
logging.info(color('*** L2CAP channel closed', 'red'))
self.l2cap_channel = None
@@ -804,28 +813,28 @@ class RfcommClient(StreamedPacketIO):
# Find the channel number if not specified
channel = self.channel
if channel == 0:
print(
logging.info(
color(f'@@@ Discovering channel number from UUID {self.uuid}', 'cyan')
)
channel = await find_rfcomm_channel_with_uuid(connection, self.uuid)
print(color(f'@@@ Channel number = {channel}', 'cyan'))
logging.info(color(f'@@@ Channel number = {channel}', 'cyan'))
if channel == 0:
print(color('!!! No RFComm service with this UUID found', 'red'))
logging.info(color('!!! No RFComm service with this UUID found', 'red'))
await connection.disconnect()
return
# Create a client and start it
print(color('*** Starting RFCOMM client...', 'blue'))
logging.info(color('*** Starting RFCOMM client...', 'blue'))
rfcomm_client = bumble.rfcomm.Client(connection)
rfcomm_mux = await rfcomm_client.start()
print(color('*** Started', 'blue'))
logging.info(color('*** Started', 'blue'))
print(color(f'### Opening session for channel {channel}...', 'yellow'))
logging.info(color(f'### Opening session for channel {channel}...', 'yellow'))
try:
rfcomm_session = await rfcomm_mux.open_dlc(channel)
print(color('### Session open', 'yellow'), rfcomm_session)
logging.info(color(f'### Session open: {rfcomm_session}', 'yellow'))
except bumble.core.ConnectionError as error:
print(color(f'!!! Session open failed: {error}', 'red'))
logging.info(color(f'!!! Session open failed: {error}', 'red'))
await rfcomm_mux.disconnect()
return
@@ -855,7 +864,7 @@ class RfcommServer(StreamedPacketIO):
# Setup the SDP to advertise this channel
device.sdp_service_records = make_sdp_records(channel_number)
print(
logging.info(
color(
f'### Listening for RFComm connection on channel {channel_number}',
'yellow',
@@ -869,7 +878,7 @@ class RfcommServer(StreamedPacketIO):
pass
def on_dlc(self, dlc):
print(color('*** DLC connected:', 'blue'), dlc)
logging.info(color(f'*** DLC connected: {dlc}', 'blue'))
dlc.sink = self.on_packet
self.io_sink = dlc.write
@@ -935,12 +944,12 @@ class Central(Connection.Listener):
self.connection_parameter_preferences = None
async def run(self):
print(color('>>> Connecting to HCI...', 'green'))
logging.info(color('>>> Connecting to HCI...', 'green'))
async with await open_transport_or_link(self.transport) as (
hci_source,
hci_sink,
):
print(color('>>> Connected', 'green'))
logging.info(color('>>> Connected', 'green'))
central_address = DEFAULT_CENTRAL_ADDRESS
self.device = Device.with_hci(
@@ -952,7 +961,13 @@ class Central(Connection.Listener):
await self.device.power_on()
print(color(f'### Connecting to {self.peripheral_address}...', 'cyan'))
if self.classic:
await self.device.set_discoverable(False)
await self.device.set_connectable(False)
logging.info(
color(f'### Connecting to {self.peripheral_address}...', 'cyan')
)
try:
self.connection = await self.device.connect(
self.peripheral_address,
@@ -960,21 +975,26 @@ class Central(Connection.Listener):
transport=BT_BR_EDR_TRANSPORT if self.classic else BT_LE_TRANSPORT,
)
except CommandTimeoutError:
print(color('!!! Connection timed out', 'red'))
logging.info(color('!!! Connection timed out', 'red'))
return
except bumble.core.ConnectionError as error:
print(color(f'!!! Connection error: {error}', 'red'))
logging.info(color(f'!!! Connection error: {error}', 'red'))
return
except HCI_StatusError as error:
print(color(f'!!! Connection failed: {error.error_name}'))
logging.info(color(f'!!! Connection failed: {error.error_name}'))
return
print(color('### Connected', 'cyan'))
logging.info(color('### Connected', 'cyan'))
self.connection.listener = self
print_connection(self.connection)
# Wait a bit after the connection, some controllers aren't very good when
# we start sending data right away while some connection parameters are
# updated post connection
await asyncio.sleep(DEFAULT_POST_CONNECTION_WAIT_TIME)
# Request a new data length if requested
if self.extended_data_length:
print(color('+++ Requesting extended data length', 'cyan'))
logging.info(color('+++ Requesting extended data length', 'cyan'))
await self.connection.set_data_length(
self.extended_data_length[0], self.extended_data_length[1]
)
@@ -982,16 +1002,16 @@ class Central(Connection.Listener):
# Authenticate if requested
if self.authenticate:
# Request authentication
print(color('*** Authenticating...', 'cyan'))
logging.info(color('*** Authenticating...', 'cyan'))
await self.connection.authenticate()
print(color('*** Authenticated', 'cyan'))
logging.info(color('*** Authenticated', 'cyan'))
# Encrypt if requested
if self.encrypt:
# Enable encryption
print(color('*** Enabling encryption...', 'cyan'))
logging.info(color('*** Enabling encryption...', 'cyan'))
await self.connection.encrypt()
print(color('*** Encryption on', 'cyan'))
logging.info(color('*** Encryption on', 'cyan'))
# Set the PHY if requested
if self.phy is not None:
@@ -1000,7 +1020,7 @@ class Central(Connection.Listener):
tx_phys=[self.phy], rx_phys=[self.phy]
)
except HCI_Error as error:
print(
logging.info(
color(
f'!!! Unable to set the PHY: {error.error_name}', 'yellow'
)
@@ -1012,7 +1032,7 @@ class Central(Connection.Listener):
await asyncio.sleep(DEFAULT_LINGER_TIME)
def on_disconnection(self, reason):
print(color(f'!!! Disconnection: reason={reason}', 'red'))
logging.info(color(f'!!! Disconnection: reason={reason}', 'red'))
self.connection = None
def on_connection_parameters_update(self):
@@ -1047,12 +1067,12 @@ class Peripheral(Device.Listener, Connection.Listener):
self.connected = asyncio.Event()
async def run(self):
print(color('>>> Connecting to HCI...', 'green'))
logging.info(color('>>> Connecting to HCI...', 'green'))
async with await open_transport_or_link(self.transport) as (
hci_source,
hci_sink,
):
print(color('>>> Connected', 'green'))
logging.info(color('>>> Connected', 'green'))
peripheral_address = DEFAULT_PERIPHERAL_ADDRESS
self.device = Device.with_hci(
@@ -1072,7 +1092,7 @@ class Peripheral(Device.Listener, Connection.Listener):
await self.device.start_advertising(auto_restart=True)
if self.classic:
print(
logging.info(
color(
'### Waiting for connection on'
f' {self.device.public_address}...',
@@ -1080,14 +1100,14 @@ class Peripheral(Device.Listener, Connection.Listener):
)
)
else:
print(
logging.info(
color(
f'### Waiting for connection on {peripheral_address}...',
'cyan',
)
)
await self.connected.wait()
print(color('### Connected', 'cyan'))
logging.info(color('### Connected', 'cyan'))
await self.mode.on_connection(self.connection)
await self.role.run()
@@ -1098,9 +1118,18 @@ class Peripheral(Device.Listener, Connection.Listener):
self.connection = connection
self.connected.set()
# Stop being discoverable and connectable
if self.classic:
async def stop_being_discoverable_connectable():
await self.device.set_discoverable(False)
await self.device.set_connectable(False)
AsyncRunner.spawn(stop_being_discoverable_connectable())
# Request a new data length if needed
if self.extended_data_length:
print("+++ Requesting extended data length")
logging.info("+++ Requesting extended data length")
AsyncRunner.spawn(
connection.set_data_length(
self.extended_data_length[0], self.extended_data_length[1]
@@ -1108,7 +1137,7 @@ class Peripheral(Device.Listener, Connection.Listener):
)
def on_disconnection(self, reason):
print(color(f'!!! Disconnection: reason={reason}', 'red'))
logging.info(color(f'!!! Disconnection: reason={reason}', 'red'))
self.connection = None
self.role.reset()
+4 -4
View File
@@ -777,7 +777,7 @@ class ConsoleApp:
if not service:
continue
values = [
attribute.read_value(connection)
await attribute.read_value(connection)
for connection in self.device.connections.values()
]
if not values:
@@ -796,11 +796,11 @@ class ConsoleApp:
if not characteristic:
continue
values = [
attribute.read_value(connection)
await attribute.read_value(connection)
for connection in self.device.connections.values()
]
if not values:
values = [attribute.read_value(None)]
values = [await attribute.read_value(None)]
# TODO: future optimization: convert CCCD value to human readable string
@@ -944,7 +944,7 @@ class ConsoleApp:
# send data to any subscribers
if isinstance(attribute, Characteristic):
attribute.write_value(None, value)
await attribute.write_value(None, value)
if attribute.has_properties(Characteristic.NOTIFY):
await self.device.gatt_server.notify_subscribers(attribute)
if attribute.has_properties(Characteristic.INDICATE):
+34 -39
View File
@@ -24,10 +24,6 @@ from bumble.company_ids import COMPANY_IDENTIFIERS
from bumble.colors import color
from bumble.core import name_or_number
from bumble.hci import (
HCI_READ_LOCAL_EXTENDED_FEATURES_COMMAND,
HCI_READ_LOCAL_SUPPORTED_FEATURES_COMMAND,
HCI_Read_Local_Extended_Features_Command,
HCI_Read_Local_Supported_Features_Command,
map_null_terminated_utf8_string,
HCI_SUCCESS,
HCI_LE_SUPPORTED_FEATURES_NAMES,
@@ -36,10 +32,14 @@ from bumble.hci import (
HCI_Command,
HCI_Command_Complete_Event,
HCI_Command_Status_Event,
HCI_READ_BUFFER_SIZE_COMMAND,
HCI_Read_Buffer_Size_Command,
HCI_READ_BD_ADDR_COMMAND,
HCI_Read_BD_ADDR_Command,
HCI_READ_LOCAL_NAME_COMMAND,
HCI_Read_Local_Name_Command,
HCI_LE_READ_BUFFER_SIZE_COMMAND,
HCI_LE_Read_Buffer_Size_Command,
HCI_LE_READ_MAXIMUM_DATA_LENGTH_COMMAND,
HCI_LE_Read_Maximum_Data_Length_Command,
HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND,
@@ -63,37 +63,7 @@ def command_succeeded(response):
# -----------------------------------------------------------------------------
async def get_common_info(host):
if host.supports_command(HCI_READ_LOCAL_SUPPORTED_FEATURES_COMMAND):
response = await host.send_command(HCI_Read_Local_Supported_Features_Command())
if response.return_parameters.status == HCI_SUCCESS:
print()
print(color('LMP Features:', 'yellow'))
# TODO: support printing discrete enum values
print(' ', response.return_parameters.lmp_features.hex())
if host.supports_command(HCI_READ_LOCAL_EXTENDED_FEATURES_COMMAND):
response = await host.send_command(
HCI_Read_Local_Extended_Features_Command(page_number=0)
)
if response.return_parameters.status == HCI_SUCCESS:
if response.return_parameters.max_page_number > 0:
print()
print(color('Extended LMP Features:', 'yellow'))
for page in range(1, response.return_parameters.max_page_number + 1):
response = await host.send_command(
HCI_Read_Local_Extended_Features_Command(page_number=page)
)
if response.return_parameters.status == HCI_SUCCESS:
# TODO: support printing discrete enum values
print(f' Page {page}:')
print(' ', response.return_parameters.extended_lmp_features.hex())
# -----------------------------------------------------------------------------
async def get_classic_info(host):
async def get_classic_info(host: Host) -> None:
if host.supports_command(HCI_READ_BD_ADDR_COMMAND):
response = await host.send_command(HCI_Read_BD_ADDR_Command())
if command_succeeded(response):
@@ -114,7 +84,7 @@ async def get_classic_info(host):
# -----------------------------------------------------------------------------
async def get_le_info(host):
async def get_le_info(host: Host) -> None:
print()
if host.supports_command(HCI_LE_READ_NUMBER_OF_SUPPORTED_ADVERTISING_SETS_COMMAND):
@@ -170,6 +140,31 @@ async def get_le_info(host):
print(' ', name_or_number(HCI_LE_SUPPORTED_FEATURES_NAMES, feature))
# -----------------------------------------------------------------------------
async def get_acl_flow_control_info(host: Host) -> None:
print()
if host.supports_command(HCI_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
print(
color('ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_acl_data_packets} '
f'packets of size {response.return_parameters.hc_acl_data_packet_length}',
)
if host.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await host.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
print(
color('LE ACL Flow Control:', 'yellow'),
f'{response.return_parameters.hc_total_num_le_acl_data_packets} '
f'packets of size {response.return_parameters.hc_le_acl_data_packet_length}',
)
# -----------------------------------------------------------------------------
async def async_main(transport):
print('<<< connecting to HCI...')
@@ -196,15 +191,15 @@ async def async_main(transport):
)
print(color(' LMP Subversion:', 'green'), host.local_version.lmp_subversion)
# Get the common info
await get_common_info(host)
# Get the Classic info
await get_classic_info(host)
# Get the LE info
await get_le_info(host)
# Print the ACL flow control info
await get_acl_flow_control_info(host)
# Print the list of commands supported by the controller
print()
print(color('Supported Commands:', 'yellow'))
+15 -15
View File
@@ -17,25 +17,31 @@
# -----------------------------------------------------------------------------
import logging
import asyncio
import sys
import os
import click
from bumble.controller import Controller, Options
from bumble.controller import Controller
from bumble.link import LocalLink
from bumble.transport import open_transport_or_link
# -----------------------------------------------------------------------------
async def async_main(extended_advertising, transport_names):
async def async_main():
if len(sys.argv) != 3:
print(
'Usage: controllers.py <hci-transport-1> <hci-transport-2> '
'[<hci-transport-3> ...]'
)
print('example: python controllers.py pty:ble1 pty:ble2')
return
# Create a local link to attach the controllers to
link = LocalLink()
# Create a transport and controller for all requested names
transports = []
controllers = []
options = Options(extended_advertising=extended_advertising)
for index, transport_name in enumerate(transport_names):
for index, transport_name in enumerate(sys.argv[1:]):
transport = await open_transport_or_link(transport_name)
transports.append(transport)
controller = Controller(
@@ -43,7 +49,6 @@ async def async_main(extended_advertising, transport_names):
host_source=transport.source,
host_sink=transport.sink,
link=link,
options=options,
)
controllers.append(controller)
@@ -56,14 +61,9 @@ async def async_main(extended_advertising, transport_names):
# -----------------------------------------------------------------------------
@click.command()
@click.option(
'--extended-advertising', is_flag=True, help="Enable extended advertising"
)
@click.argument('transports', nargs=-1, required=True)
def main(extended_advertising, transports):
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
asyncio.run(async_main(extended_advertising, transports))
def main():
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(async_main())
# -----------------------------------------------------------------------------
+3 -8
View File
@@ -253,7 +253,7 @@ class Relay:
# ----------------------------------------------------------------------------
async def async_main():
def main():
# Check the Python version
if sys.version_info < (3, 6, 1):
print('ERROR: Python 3.6.1 or higher is required')
@@ -280,13 +280,8 @@ async def async_main():
# Start a relay
relay = Relay(args.port)
async with relay.start():
await asyncio.Future()
# ----------------------------------------------------------------------------
def main():
asyncio.run(async_main())
asyncio.get_event_loop().run_until_complete(relay.start())
asyncio.get_event_loop().run_forever()
# ----------------------------------------------------------------------------
+53 -11
View File
@@ -25,9 +25,21 @@
from __future__ import annotations
import enum
import functools
import inspect
import struct
from typing import (
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Type,
Union,
TYPE_CHECKING,
)
from pyee import EventEmitter
from typing import Dict, Type, List, Protocol, Union, Optional, Any, TYPE_CHECKING
from bumble.core import UUID, name_or_number, ProtocolError
from bumble.hci import HCI_Object, key_with_value
@@ -722,12 +734,38 @@ class ATT_Handle_Value_Confirmation(ATT_PDU):
# -----------------------------------------------------------------------------
class ConnectionValue(Protocol):
def read(self, connection) -> bytes:
...
class AttributeValue:
'''
Attribute value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def write(self, connection, value: bytes) -> None:
...
def __init__(
self,
read: Union[
Callable[[Optional[Connection]], bytes],
Callable[[Optional[Connection]], Awaitable[bytes]],
None,
] = None,
write: Union[
Callable[[Optional[Connection], bytes], None],
Callable[[Optional[Connection], bytes], Awaitable[None]],
None,
] = None,
):
self._read = read
self._write = write
def read(self, connection: Optional[Connection]) -> Union[bytes, Awaitable[bytes]]:
return self._read(connection) if self._read else b''
def write(
self, connection: Optional[Connection], value: bytes
) -> Union[Awaitable[None], None]:
if self._write:
return self._write(connection, value)
return None
# -----------------------------------------------------------------------------
@@ -770,13 +808,13 @@ class Attribute(EventEmitter):
READ_REQUIRES_AUTHORIZATION = Permissions.READ_REQUIRES_AUTHORIZATION
WRITE_REQUIRES_AUTHORIZATION = Permissions.WRITE_REQUIRES_AUTHORIZATION
value: Union[str, bytes, ConnectionValue]
value: Union[bytes, AttributeValue]
def __init__(
self,
attribute_type: Union[str, bytes, UUID],
permissions: Union[str, Attribute.Permissions],
value: Union[str, bytes, ConnectionValue] = b'',
value: Union[str, bytes, AttributeValue] = b'',
) -> None:
EventEmitter.__init__(self)
self.handle = 0
@@ -806,7 +844,7 @@ class Attribute(EventEmitter):
def decode_value(self, value_bytes: bytes) -> Any:
return value_bytes
def read_value(self, connection: Optional[Connection]) -> bytes:
async def read_value(self, connection: Optional[Connection]) -> bytes:
if (
(self.permissions & self.READ_REQUIRES_ENCRYPTION)
and connection is not None
@@ -832,6 +870,8 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'read'):
try:
value = self.value.read(connection)
if inspect.isawaitable(value):
value = await value
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
@@ -841,7 +881,7 @@ class Attribute(EventEmitter):
return self.encode_value(value)
def write_value(self, connection: Connection, value_bytes: bytes) -> None:
async def write_value(self, connection: Connection, value_bytes: bytes) -> None:
if (
self.permissions & self.WRITE_REQUIRES_ENCRYPTION
) and not connection.encryption:
@@ -864,7 +904,9 @@ class Attribute(EventEmitter):
if hasattr(self.value, 'write'):
try:
self.value.write(connection, value) # pylint: disable=not-callable
result = self.value.write(connection, value)
if inspect.isawaitable(result):
await result
except ATT_Error as error:
raise ATT_Error(
error_code=error.error_code, att_handle=self.handle
+125 -665
View File
File diff suppressed because it is too large Load Diff
+61 -49
View File
@@ -23,16 +23,28 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations
import asyncio
import enum
import functools
import logging
import struct
from typing import Optional, Sequence, Iterable, List, Union
from typing import (
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Union,
TYPE_CHECKING,
)
from .colors import color
from .core import UUID, get_dict_key_by_value
from .att import Attribute
from bumble.colors import color
from bumble.core import UUID
from bumble.att import Attribute, AttributeValue
if TYPE_CHECKING:
from bumble.gatt_client import AttributeProxy
from bumble.device import Connection
# -----------------------------------------------------------------------------
@@ -522,56 +534,43 @@ class CharacteristicDeclaration(Attribute):
# -----------------------------------------------------------------------------
class CharacteristicValue:
'''
Characteristic value where reading and/or writing is delegated to functions
passed as arguments to the constructor.
'''
def __init__(self, read=None, write=None):
self._read = read
self._write = write
def read(self, connection):
return self._read(connection) if self._read else b''
def write(self, connection, value):
if self._write:
self._write(connection, value)
class CharacteristicValue(AttributeValue):
"""Same as AttributeValue, for backward compatibility"""
# -----------------------------------------------------------------------------
class CharacteristicAdapter:
'''
An adapter that can adapt any object with `read_value` and `write_value`
methods (like Characteristic and CharacteristicProxy objects) by wrapping
those methods with ones that return/accept encoded/decoded values.
Objects with async methods are considered proxies, so the adaptation is one
where the return value of `read_value` is decoded and the value passed to
`write_value` is encoded. Other objects are considered local characteristics
so the adaptation is one where the return value of `read_value` is encoded
and the value passed to `write_value` is decoded.
If the characteristic has a `subscribe` method, it is wrapped with one where
the values are decoded before being passed to the subscriber.
An adapter that can adapt Characteristic and AttributeProxy objects
by wrapping their `read_value()` and `write_value()` methods with ones that
return/accept encoded/decoded values.
For proxies (i.e used by a GATT client), the adaptation is one where the return
value of `read_value()` is decoded and the value passed to `write_value()` is
encoded. The `subscribe()` method, is wrapped with one where the values are decoded
before being passed to the subscriber.
For local values (i.e hosted by a GATT server) the adaptation is one where the
return value of `read_value()` is encoded and the value passed to `write_value()`
is decoded.
'''
def __init__(self, characteristic):
self.wrapped_characteristic = characteristic
self.subscribers = {} # Map from subscriber to proxy subscriber
read_value: Callable
write_value: Callable
if asyncio.iscoroutinefunction(
characteristic.read_value
) and asyncio.iscoroutinefunction(characteristic.write_value):
self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
else:
def __init__(self, characteristic: Union[Characteristic, AttributeProxy]):
self.wrapped_characteristic = characteristic
self.subscribers: Dict[
Callable, Callable
] = {} # Map from subscriber to proxy subscriber
if isinstance(characteristic, Characteristic):
self.read_value = self.read_encoded_value
self.write_value = self.write_encoded_value
if hasattr(self.wrapped_characteristic, 'subscribe'):
else:
self.read_value = self.read_decoded_value
self.write_value = self.write_decoded_value
self.subscribe = self.wrapped_subscribe
if hasattr(self.wrapped_characteristic, 'unsubscribe'):
self.unsubscribe = self.wrapped_unsubscribe
def __getattr__(self, name):
@@ -590,11 +589,13 @@ class CharacteristicAdapter:
else:
setattr(self.wrapped_characteristic, name, value)
def read_encoded_value(self, connection):
return self.encode_value(self.wrapped_characteristic.read_value(connection))
async def read_encoded_value(self, connection):
return self.encode_value(
await self.wrapped_characteristic.read_value(connection)
)
def write_encoded_value(self, connection, value):
return self.wrapped_characteristic.write_value(
async def write_encoded_value(self, connection, value):
return await self.wrapped_characteristic.write_value(
connection, self.decode_value(value)
)
@@ -729,13 +730,24 @@ class Descriptor(Attribute):
'''
def __str__(self) -> str:
if isinstance(self.value, bytes):
value_str = self.value.hex()
elif isinstance(self.value, CharacteristicValue):
value = self.value.read(None)
if isinstance(value, bytes):
value_str = value.hex()
else:
value_str = '<async>'
else:
value_str = '<...>'
return (
f'Descriptor(handle=0x{self.handle:04X}, '
f'type={self.type}, '
f'value={self.read_value(None).hex()})'
f'value={value_str})'
)
# -----------------------------------------------------------------------------
class ClientCharacteristicConfigurationBits(enum.IntFlag):
'''
See Vol 3, Part G - 3.3.3.3 - Table 3.11 Client Characteristic Configuration bit
+29 -21
View File
@@ -31,9 +31,9 @@ import struct
from typing import List, Tuple, Optional, TypeVar, Type, Dict, Iterable, TYPE_CHECKING
from pyee import EventEmitter
from .colors import color
from .core import UUID
from .att import (
from bumble.colors import color
from bumble.core import UUID
from bumble.att import (
ATT_ATTRIBUTE_NOT_FOUND_ERROR,
ATT_ATTRIBUTE_NOT_LONG_ERROR,
ATT_CID,
@@ -60,7 +60,7 @@ from .att import (
ATT_Write_Response,
Attribute,
)
from .gatt import (
from bumble.gatt import (
GATT_CHARACTERISTIC_ATTRIBUTE_TYPE,
GATT_CLIENT_CHARACTERISTIC_CONFIGURATION_DESCRIPTOR,
GATT_MAX_ATTRIBUTE_VALUE_SIZE,
@@ -74,6 +74,7 @@ from .gatt import (
Descriptor,
Service,
)
from bumble.utils import AsyncRunner
if TYPE_CHECKING:
from bumble.device import Device, Connection
@@ -379,7 +380,7 @@ class Server(EventEmitter):
# Get or encode the value
value = (
attribute.read_value(connection)
await attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
@@ -422,7 +423,7 @@ class Server(EventEmitter):
# Get or encode the value
value = (
attribute.read_value(connection)
await attribute.read_value(connection)
if value is None
else attribute.encode_value(value)
)
@@ -650,7 +651,8 @@ class Server(EventEmitter):
self.send_response(connection, response)
def on_att_find_by_type_value_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_find_by_type_value_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.3.3 Find By Type Value Request
'''
@@ -658,13 +660,13 @@ class Server(EventEmitter):
# Build list of returned attributes
pdu_space_available = connection.att_mtu - 2
attributes = []
for attribute in (
async for attribute in (
attribute
for attribute in self.attributes
if attribute.handle >= request.starting_handle
and attribute.handle <= request.ending_handle
and attribute.type == request.attribute_type
and attribute.read_value(connection) == request.attribute_value
and (await attribute.read_value(connection)) == request.attribute_value
and pdu_space_available >= 4
):
# TODO: check permissions
@@ -702,7 +704,8 @@ class Server(EventEmitter):
self.send_response(connection, response)
def on_att_read_by_type_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_read_by_type_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.1 Read By Type Request
'''
@@ -725,7 +728,7 @@ class Server(EventEmitter):
and pdu_space_available
):
try:
attribute_value = attribute.read_value(connection)
attribute_value = await attribute.read_value(connection)
except ATT_Error as error:
# If the first attribute is unreadable, return an error
# Otherwise return attributes up to this point
@@ -767,14 +770,15 @@ class Server(EventEmitter):
self.send_response(connection, response)
def on_att_read_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_read_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.3 Read Request
'''
if attribute := self.get_attribute(request.attribute_handle):
try:
value = attribute.read_value(connection)
value = await attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -792,14 +796,15 @@ class Server(EventEmitter):
)
self.send_response(connection, response)
def on_att_read_blob_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_read_blob_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.5 Read Blob Request
'''
if attribute := self.get_attribute(request.attribute_handle):
try:
value = attribute.read_value(connection)
value = await attribute.read_value(connection)
except ATT_Error as error:
response = ATT_Error_Response(
request_opcode_in_error=request.op_code,
@@ -836,7 +841,8 @@ class Server(EventEmitter):
)
self.send_response(connection, response)
def on_att_read_by_group_type_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_read_by_group_type_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.4.9 Read by Group Type Request
'''
@@ -864,7 +870,7 @@ class Server(EventEmitter):
):
# No need to catch permission errors here, since these attributes
# must all be world-readable
attribute_value = attribute.read_value(connection)
attribute_value = await attribute.read_value(connection)
# Check the attribute value size
max_attribute_size = min(connection.att_mtu - 6, 251)
if len(attribute_value) > max_attribute_size:
@@ -903,7 +909,8 @@ class Server(EventEmitter):
self.send_response(connection, response)
def on_att_write_request(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_write_request(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.1 Write Request
'''
@@ -936,12 +943,13 @@ class Server(EventEmitter):
return
# Accept the value
attribute.write_value(connection, request.attribute_value)
await attribute.write_value(connection, request.attribute_value)
# Done
self.send_response(connection, ATT_Write_Response())
def on_att_write_command(self, connection, request):
@AsyncRunner.run_in_task()
async def on_att_write_command(self, connection, request):
'''
See Bluetooth spec Vol 3, Part F - 3.4.5.3 Write Command
'''
@@ -959,7 +967,7 @@ class Server(EventEmitter):
# Accept the value
try:
attribute.write_value(connection, request.attribute_value)
await attribute.write_value(connection, request.attribute_value)
except Exception as error:
logger.exception(f'!!! ignoring exception: {error}')
+4 -8
View File
@@ -3266,9 +3266,7 @@ class HCI_Read_Local_Supported_Commands_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[('status', STATUS_SPEC), ('lmp_features', 8)]
)
@HCI_Command.command()
class HCI_Read_Local_Supported_Features_Command(HCI_Command):
'''
See Bluetooth spec @ 7.4.3 Read Local Supported Features Command
@@ -3281,7 +3279,7 @@ class HCI_Read_Local_Supported_Features_Command(HCI_Command):
return_parameters_fields=[
('status', STATUS_SPEC),
('page_number', 1),
('max_page_number', 1),
('maximum_page_number', 1),
('extended_lmp_features', 8),
],
)
@@ -3450,9 +3448,7 @@ class HCI_LE_Set_Advertising_Parameters_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command(
return_parameters_fields=[('status', STATUS_SPEC), ('tx_power_level', 1)]
)
@HCI_Command.command()
class HCI_LE_Read_Advertising_Physical_Channel_Tx_Power_Command(HCI_Command):
'''
See Bluetooth spec @ 7.8.6 LE Read Advertising Physical Channel Tx Power Command
@@ -5833,7 +5829,7 @@ class HCI_Inquiry_Result_With_RSSI_Event(HCI_Event):
('status', STATUS_SPEC),
('connection_handle', 2),
('page_number', 1),
('max_page_number', 1),
('maximum_page_number', 1),
('extended_lmp_features', 8),
]
)
+1 -1
View File
@@ -402,7 +402,7 @@ class Device(HID):
report_type = pdu[0] & 0x03
buffer_flag = (pdu[0] & 0x08) >> 3
report_id = pdu[1]
logger.debug("buffer_flag: " + str(buffer_flag))
logger.debug(f"buffer_flag: {buffer_flag}")
if buffer_flag == 1:
buffer_size = (pdu[3] << 8) | pdu[2]
else:
+103 -74
View File
@@ -21,7 +21,7 @@ import collections
import logging
import struct
from typing import Any, Awaitable, Callable, Dict, Optional, Union, cast, TYPE_CHECKING
from typing import Any, Awaitable, Callable, Deque, Dict, Optional, cast, TYPE_CHECKING
from bumble.colors import color
from bumble.l2cap import L2CAP_PDU
@@ -91,16 +91,49 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
# fmt: off
class AclPacketQueue:
max_packet_size: int
HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS = 1
HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH = 27
HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS = 1
def __init__(
self,
max_packet_size: int,
max_in_flight: int,
send: Callable[[HCI_Packet], None],
) -> None:
self.max_packet_size = max_packet_size
self.max_in_flight = max_in_flight
self.in_flight = 0
self.send = send
self.packets: Deque[HCI_AclDataPacket] = collections.deque()
# fmt: on
def enqueue(self, packet: HCI_AclDataPacket) -> None:
self.packets.appendleft(packet)
self.check_queue()
if self.packets:
logger.debug(
f'{self.in_flight} ACL packets in flight, '
f'{len(self.packets)} in queue'
)
def check_queue(self) -> None:
while self.packets and self.in_flight < self.max_in_flight:
packet = self.packets.pop()
self.send(packet)
self.in_flight += 1
def on_packets_completed(self, packet_count: int) -> None:
if packet_count > self.in_flight:
logger.warning(
color(
'!!! {packet_count} completed but only '
f'{self.in_flight} in flight'
)
)
packet_count = self.in_flight
self.in_flight -= packet_count
self.check_queue()
# -----------------------------------------------------------------------------
@@ -111,6 +144,13 @@ class Connection:
self.peer_address = peer_address
self.assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu)
self.transport = transport
acl_packet_queue: Optional[AclPacketQueue] = (
host.le_acl_packet_queue
if transport == BT_LE_TRANSPORT
else host.acl_packet_queue
)
assert acl_packet_queue
self.acl_packet_queue = acl_packet_queue
def on_hci_acl_data_packet(self, packet: HCI_AclDataPacket) -> None:
self.assembler.feed_packet(packet)
@@ -123,7 +163,8 @@ class Connection:
# -----------------------------------------------------------------------------
class Host(AbortableEventEmitter):
connections: Dict[int, Connection]
acl_packet_queue: collections.deque[HCI_AclDataPacket]
acl_packet_queue: Optional[AclPacketQueue] = None
le_acl_packet_queue: Optional[AclPacketQueue] = None
hci_sink: Optional[TransportSink] = None
hci_metadata: Dict[str, Any]
long_term_key_provider: Optional[
@@ -143,12 +184,6 @@ class Host(AbortableEventEmitter):
self.connections = {} # Connections, by connection handle
self.pending_command = None
self.pending_response = None
self.hc_le_acl_data_packet_length = HOST_DEFAULT_HC_LE_ACL_DATA_PACKET_LENGTH
self.hc_total_num_le_acl_data_packets = HOST_HC_TOTAL_NUM_LE_ACL_DATA_PACKETS
self.hc_acl_data_packet_length = HOST_DEFAULT_HC_ACL_DATA_PACKET_LENGTH
self.hc_total_num_acl_data_packets = HOST_HC_TOTAL_NUM_ACL_DATA_PACKETS
self.acl_packet_queue = collections.deque()
self.acl_packets_in_flight = 0
self.local_version = None
self.local_supported_commands = bytes(64)
self.local_le_features = 0
@@ -254,46 +289,54 @@ class Host(AbortableEventEmitter):
response = await self.send_command(
HCI_Read_Buffer_Size_Command(), check_result=True
)
self.hc_acl_data_packet_length = (
hc_acl_data_packet_length = (
response.return_parameters.hc_acl_data_packet_length
)
self.hc_total_num_acl_data_packets = (
hc_total_num_acl_data_packets = (
response.return_parameters.hc_total_num_acl_data_packets
)
logger.debug(
'HCI ACL flow control: '
f'hc_acl_data_packet_length={self.hc_acl_data_packet_length},'
f'hc_total_num_acl_data_packets={self.hc_total_num_acl_data_packets}'
f'hc_acl_data_packet_length={hc_acl_data_packet_length},'
f'hc_total_num_acl_data_packets={hc_total_num_acl_data_packets}'
)
self.acl_packet_queue = AclPacketQueue(
max_packet_size=hc_acl_data_packet_length,
max_in_flight=hc_total_num_acl_data_packets,
send=self.send_hci_packet,
)
hc_le_acl_data_packet_length = 0
hc_total_num_le_acl_data_packets = 0
if self.supports_command(HCI_LE_READ_BUFFER_SIZE_COMMAND):
response = await self.send_command(
HCI_LE_Read_Buffer_Size_Command(), check_result=True
)
self.hc_le_acl_data_packet_length = (
hc_le_acl_data_packet_length = (
response.return_parameters.hc_le_acl_data_packet_length
)
self.hc_total_num_le_acl_data_packets = (
hc_total_num_le_acl_data_packets = (
response.return_parameters.hc_total_num_le_acl_data_packets
)
logger.debug(
'HCI LE ACL flow control: '
f'hc_le_acl_data_packet_length={self.hc_le_acl_data_packet_length},'
'hc_total_num_le_acl_data_packets='
f'{self.hc_total_num_le_acl_data_packets}'
f'hc_le_acl_data_packet_length={hc_le_acl_data_packet_length},'
f'hc_total_num_le_acl_data_packets={hc_total_num_le_acl_data_packets}'
)
if (
response.return_parameters.hc_le_acl_data_packet_length == 0
or response.return_parameters.hc_total_num_le_acl_data_packets == 0
):
# LE and Classic share the same values
self.hc_le_acl_data_packet_length = self.hc_acl_data_packet_length
self.hc_total_num_le_acl_data_packets = (
self.hc_total_num_acl_data_packets
)
if hc_le_acl_data_packet_length == 0 or hc_total_num_le_acl_data_packets == 0:
# LE and Classic share the same queue
self.le_acl_packet_queue = self.acl_packet_queue
else:
# Create a separate queue for LE
self.le_acl_packet_queue = AclPacketQueue(
max_packet_size=hc_le_acl_data_packet_length,
max_in_flight=hc_total_num_le_acl_data_packets,
send=self.send_hci_packet,
)
if self.supports_command(
HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND
@@ -332,14 +375,13 @@ class Host(AbortableEventEmitter):
self.hci_metadata = getattr(source, 'metadata', self.hci_metadata)
def send_hci_packet(self, packet: HCI_Packet) -> None:
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {packet}')
if self.snooper:
self.snooper.snoop(bytes(packet), Snooper.Direction.HOST_TO_CONTROLLER)
if self.hci_sink:
self.hci_sink.on_packet(bytes(packet))
async def send_command(self, command, check_result=False):
logger.debug(f'{color("### HOST -> CONTROLLER", "blue")}: {command}')
# Wait until we can send (only one pending command at a time)
async with self.command_semaphore:
assert self.pending_command is None
@@ -387,6 +429,17 @@ class Host(AbortableEventEmitter):
asyncio.create_task(send_command(command))
def send_l2cap_pdu(self, connection_handle: int, cid: int, pdu: bytes) -> None:
if not (connection := self.connections.get(connection_handle)):
logger.warning(f'connection 0x{connection_handle:04X} not found')
return
packet_queue = connection.acl_packet_queue
if packet_queue is None:
logger.warning(
f'no ACL packet queue for connection 0x{connection_handle:04X}'
)
return
# Create a PDU
l2cap_pdu = bytes(L2CAP_PDU(cid, pdu))
# Send the data to the controller via ACL packets
@@ -394,8 +447,7 @@ class Host(AbortableEventEmitter):
offset = 0
pb_flag = 0
while bytes_remaining:
# TODO: support different LE/Classic lengths
data_total_length = min(bytes_remaining, self.hc_le_acl_data_packet_length)
data_total_length = min(bytes_remaining, packet_queue.max_packet_size)
acl_packet = HCI_AclDataPacket(
connection_handle=connection_handle,
pb_flag=pb_flag,
@@ -403,34 +455,12 @@ class Host(AbortableEventEmitter):
data_total_length=data_total_length,
data=l2cap_pdu[offset : offset + data_total_length],
)
logger.debug(
f'{color("### HOST -> CONTROLLER", "blue")}: (CID={cid}) {acl_packet}'
)
self.queue_acl_packet(acl_packet)
logger.debug(f'>>> ACL packet enqueue: (CID={cid}) {acl_packet}')
packet_queue.enqueue(acl_packet)
pb_flag = 1
offset += data_total_length
bytes_remaining -= data_total_length
def queue_acl_packet(self, acl_packet: HCI_AclDataPacket) -> None:
self.acl_packet_queue.appendleft(acl_packet)
self.check_acl_packet_queue()
if len(self.acl_packet_queue):
logger.debug(
f'{self.acl_packets_in_flight} ACL packets in flight, '
f'{len(self.acl_packet_queue)} in queue'
)
def check_acl_packet_queue(self) -> None:
# Send all we can (TODO: support different LE/Classic limits)
while (
len(self.acl_packet_queue) > 0
and self.acl_packets_in_flight < self.hc_total_num_le_acl_data_packets
):
packet = self.acl_packet_queue.pop()
self.acl_packets_in_flight += 1
self.send_hci_packet(packet)
def supports_command(self, command):
# Find the support flag position for this command
for octet, flags in enumerate(HCI_SUPPORTED_COMMANDS_FLAGS):
@@ -553,7 +583,7 @@ class Host(AbortableEventEmitter):
# This is used just for the Num_HCI_Command_Packets field, not related to
# an actual command
logger.debug('no-command event')
return None
return
return self.on_command_processed(event)
@@ -561,18 +591,17 @@ class Host(AbortableEventEmitter):
return self.on_command_processed(event)
def on_hci_number_of_completed_packets_event(self, event):
total_packets = sum(event.num_completed_packets)
if total_packets <= self.acl_packets_in_flight:
self.acl_packets_in_flight -= total_packets
self.check_acl_packet_queue()
else:
logger.warning(
color(
f'!!! {total_packets} completed but only '
f'{self.acl_packets_in_flight} in flight'
for connection_handle, num_completed_packets in zip(
event.connection_handles, event.num_completed_packets
):
if not (connection := self.connections.get(connection_handle)):
logger.warning(
'received packet completion event for unknown handle '
f'0x{connection_handle:04X}'
)
)
self.acl_packets_in_flight = 0
continue
connection.acl_packet_queue.on_packets_completed(num_completed_packets)
# Classic only
def on_hci_connection_request_event(self, event):
+2 -2
View File
@@ -151,8 +151,8 @@ L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2046
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2046
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256
L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE = 0x01
+18 -32
View File
@@ -95,21 +95,11 @@ class LocalLink:
def on_address_changed(self, controller):
pass
def send_advertising_data(self, sender_address, data, scan_response):
def send_advertising_data(self, sender_address, data):
# Send the advertising data to all controllers, except the sender
for controller in self.controllers:
if controller.random_address != sender_address:
controller.on_link_advertising_data(sender_address, data, scan_response)
def send_extended_advertising_data(
self, sender_address, event_type, data, scan_response
):
# Send the advertising data to all controllers, except the sender
for controller in self.controllers:
if controller.random_address != sender_address:
controller.on_link_extended_advertising_data(
sender_address, event_type, data, scan_response
)
controller.on_link_advertising_data(sender_address, data)
def send_acl_data(self, sender_controller, destination_address, transport, data):
# Send the data to the first controller with a matching address
@@ -161,34 +151,30 @@ class LocalLink:
asyncio.get_running_loop().call_soon(self.on_connection_complete)
def on_disconnection_complete(
self, initiator_address, peer_address, disconnect_command
self, central_address, peripheral_address, disconnect_command
):
# Find the controller that initiated the disconnection
if not (initiator_controller := self.find_controller(initiator_address)):
if not (central_controller := self.find_controller(central_address)):
logger.warning('!!! Initiating controller not found')
return
# Disconnect from the first controller with a matching address
if peer_controller := self.find_controller(peer_address):
peer_controller.on_link_peer_disconnected(
initiator_address, disconnect_command.reason
if peripheral_controller := self.find_controller(peripheral_address):
peripheral_controller.on_link_central_disconnected(
central_address, disconnect_command.reason
)
initiator_controller.on_link_initiated_disconnection_complete(
central_controller.on_link_peripheral_disconnection_complete(
disconnect_command, HCI_SUCCESS
)
def disconnect(self, initiator_address, peer_address, disconnect_command):
def disconnect(self, central_address, peripheral_address, disconnect_command):
logger.debug(
f'$$$ DISCONNECTION {initiator_address} -> '
f'{peer_address}: reason = {disconnect_command.reason}'
)
asyncio.get_running_loop().call_soon(
self.on_disconnection_complete,
initiator_address,
peer_address,
disconnect_command,
f'$$$ DISCONNECTION {central_address} -> '
f'{peripheral_address}: reason = {disconnect_command.reason}'
)
args = [central_address, peripheral_address, disconnect_command]
asyncio.get_running_loop().call_soon(self.on_disconnection_complete, *args)
# pylint: disable=too-many-arguments
def on_connection_encrypted(
@@ -374,11 +360,11 @@ class RemoteLink:
async def on_left_received(self, address):
if address in self.central_connections:
self.controller.on_link_connection_lost(Address(address))
self.controller.on_link_peripheral_disconnected(Address(address))
self.central_connections.remove(address)
if address in self.peripheral_connections:
self.controller.on_link_peer_disconnected(
self.controller.on_link_central_disconnected(
address, HCI_CONNECTION_TIMEOUT_ERROR
)
self.peripheral_connections.remove(address)
@@ -398,7 +384,7 @@ class RemoteLink:
async def on_advertisement_message_received(self, sender, advertisement):
try:
self.controller.on_link_advertising_data(
Address(sender), bytes.fromhex(advertisement), b''
Address(sender), bytes.fromhex(advertisement)
)
except Exception:
logger.exception('exception')
@@ -438,7 +424,7 @@ class RemoteLink:
# Notify the controller
params = parse_parameters(message)
reason = int(params.get('reason', str(HCI_CONNECTION_TIMEOUT_ERROR)))
self.controller.on_link_peer_disconnected(Address(sender), reason)
self.controller.on_link_central_disconnected(Address(sender), reason)
# Forget the connection
if sender in self.peripheral_connections:
@@ -485,7 +471,7 @@ class RemoteLink:
async def send_advertising_data_to_relay(self, data):
await self.send_targeted_message('*', f'advertisement:{data.hex()}')
def send_advertising_data(self, _, data, scan_response):
def send_advertising_data(self, _, data):
self.execute(partial(self.send_advertising_data_to_relay, data))
async def send_acl_data_to_relay(self, peer_address, data):
+2 -2
View File
@@ -18,7 +18,7 @@
# -----------------------------------------------------------------------------
import struct
import logging
from typing import List
from typing import List, Optional
from bumble import l2cap
from ..core import AdvertisingData
@@ -67,7 +67,7 @@ class AshaService(TemplateService):
self.emit('volume', connection, value[0])
# Handler for audio control commands
def on_audio_control_point_write(connection: Connection, value):
def on_audio_control_point_write(connection: Optional[Connection], value):
logger.info(f'--- AUDIO CONTROL POINT Write:{value.hex()}')
opcode = value[0]
if opcode == AshaService.OPCODE_START:
+3 -3
View File
@@ -114,7 +114,7 @@ class SamplingFrequency(enum.IntEnum):
'''Bluetooth Assigned Numbers, Section 6.12.5.1 - Sampling Frequency'''
# fmt: off
FREQ_8000 = 0x01
FREQ_8000 = 0x01
FREQ_11025 = 0x02
FREQ_16000 = 0x03
FREQ_22050 = 0x04
@@ -430,7 +430,7 @@ class AseResponseCode(enum.IntEnum):
REJECTED_METADATA = 0x0B
INVALID_METADATA = 0x0C
INSUFFICIENT_RESOURCES = 0x0D
UNSPECIFIED_ERROR = 0x0E
UNSPECIFIED_ERROR = 0x0E
class AseReasonCode(enum.IntEnum):
@@ -1066,7 +1066,7 @@ class AseStateMachine(gatt.Characteristic):
# Readonly. Do nothing in the setter.
pass
def on_read(self, _: device.Connection) -> bytes:
def on_read(self, _: Optional[device.Connection]) -> bytes:
return self.value
def __str__(self) -> str:
+1 -1
View File
@@ -152,7 +152,7 @@ class CoordinatedSetIdentificationService(gatt.TemplateService):
super().__init__(characteristics)
def on_sirk_read(self, _connection: device.Connection) -> bytes:
def on_sirk_read(self, _connection: Optional[device.Connection]) -> bytes:
if self.set_identity_resolving_key_type == SirkType.PLAINTEXT:
return bytes([SirkType.PLAINTEXT]) + self.set_identity_resolving_key
else:
+24 -17
View File
@@ -118,8 +118,8 @@ CRC_TABLE = bytes([
0XBA, 0X2B, 0X59, 0XC8, 0XBD, 0X2C, 0X5E, 0XCF
])
RFCOMM_DEFAULT_INITIAL_RX_CREDITS = 7
RFCOMM_DEFAULT_PREFERRED_MTU = 1280
RFCOMM_DEFAULT_WINDOW_SIZE = 16
RFCOMM_DEFAULT_MAX_FRAME_SIZE = 2000
RFCOMM_DYNAMIC_CHANNEL_NUMBER_START = 1
RFCOMM_DYNAMIC_CHANNEL_NUMBER_END = 30
@@ -438,14 +438,16 @@ class DLC(EventEmitter):
multiplexer: Multiplexer,
dlci: int,
max_frame_size: int,
initial_tx_credits: int,
window_size: int,
) -> None:
super().__init__()
self.multiplexer = multiplexer
self.dlci = dlci
self.rx_credits = RFCOMM_DEFAULT_INITIAL_RX_CREDITS
self.rx_threshold = self.rx_credits // 2
self.tx_credits = initial_tx_credits
self.max_frame_size = max_frame_size
self.window_size = window_size
self.rx_credits = window_size
self.rx_threshold = window_size // 2
self.tx_credits = window_size
self.tx_buffer = b''
self.state = DLC.State.INIT
self.role = multiplexer.role
@@ -537,11 +539,11 @@ class DLC(EventEmitter):
if len(data) and self.sink:
self.sink(data) # pylint: disable=not-callable
# Update the credits
if self.rx_credits > 0:
self.rx_credits -= 1
else:
logger.warning(color('!!! received frame with no rx credits', 'red'))
# Update the credits
if self.rx_credits > 0:
self.rx_credits -= 1
else:
logger.warning(color('!!! received frame with no rx credits', 'red'))
# Check if there's anything to send (including credits)
self.process_tx()
@@ -580,9 +582,9 @@ class DLC(EventEmitter):
cl=0xE0,
priority=7,
ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_frame_size=self.max_frame_size,
max_retransmissions=0,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
window_size=self.window_size,
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=0, data=bytes(pn))
logger.debug(f'>>> PN Response: {pn}')
@@ -591,7 +593,7 @@ class DLC(EventEmitter):
def rx_credits_needed(self) -> int:
if self.rx_credits <= self.rx_threshold:
return RFCOMM_DEFAULT_INITIAL_RX_CREDITS - self.rx_credits
return self.window_size - self.rx_credits
return 0
@@ -843,7 +845,12 @@ class Multiplexer(EventEmitter):
)
await self.disconnection_result
async def open_dlc(self, channel: int) -> DLC:
async def open_dlc(
self,
channel: int,
max_frame_size: int = RFCOMM_DEFAULT_MAX_FRAME_SIZE,
window_size: int = RFCOMM_DEFAULT_WINDOW_SIZE,
) -> DLC:
if self.state != Multiplexer.State.CONNECTED:
if self.state == Multiplexer.State.OPENING:
raise InvalidStateError('open already in progress')
@@ -855,9 +862,9 @@ class Multiplexer(EventEmitter):
cl=0xF0,
priority=7,
ack_timer=0,
max_frame_size=RFCOMM_DEFAULT_PREFERRED_MTU,
max_frame_size=max_frame_size,
max_retransmissions=0,
window_size=RFCOMM_DEFAULT_INITIAL_RX_CREDITS,
window_size=window_size,
)
mcc = RFCOMM_Frame.make_mcc(mcc_type=RFCOMM_MCC_PN_TYPE, c_r=1, data=bytes(pn))
logger.debug(f'>>> Sending MCC: {pn}')
+1 -1
View File
@@ -108,7 +108,7 @@ async def open_usb_transport(spec: str) -> Transport:
USB_DEVICE_PROTOCOL_BLUETOOTH_PRIMARY_CONTROLLER,
)
READ_SIZE = 1024
READ_SIZE = 4096
class UsbPacketSink:
def __init__(self, device, acl_out):
+3 -6
View File
@@ -280,17 +280,14 @@ class AsyncRunner:
def wrapper(*args, **kwargs):
coroutine = func(*args, **kwargs)
if queue is None:
# Create a task to run the coroutine
# Spawn the coroutine as a task
async def run():
try:
await coroutine
except Exception:
logger.warning(
f'{color("!!! Exception in wrapper:", "red")} '
f'{traceback.format_exc()}'
)
logger.exception(color("!!! Exception in wrapper:", "red"))
asyncio.create_task(run())
AsyncRunner.spawn(run())
else:
# Queue the coroutine to be awaited by the work queue
queue.enqueue(coroutine)
+2 -2
View File
@@ -590,12 +590,12 @@ async def main():
def on_set_protocol_cb(protocol: int):
retValue = hid_device.GetSetStatus()
# We do not support SET_PROTOCOL.
print("SET_PROTOCOL report_id: " + str(protocol))
print(f"SET_PROTOCOL report_id: {protocol}")
retValue.status = hid_device.GetSetReturn.ERR_UNSUPPORTED_REQUEST
return retValue
def on_virtual_cable_unplug_cb():
print(f'Received Virtual Cable Unplug')
print('Received Virtual Cable Unplug')
asyncio.create_task(handle_virtual_cable_unplug())
print('<<< connecting to HCI...')
@@ -28,8 +28,8 @@ private val Log = Logger.getLogger("btbench.l2cap-client")
class L2capClient(
private val viewModel: AppViewModel,
val bluetoothAdapter: BluetoothAdapter,
val context: Context
private val bluetoothAdapter: BluetoothAdapter,
private val context: Context
) {
@SuppressLint("MissingPermission")
fun run() {
@@ -80,6 +80,10 @@ class L2capClient(
BluetoothDevice.PHY_OPTION_NO_PREFERRED
)
gatt.readPhy()
// Request an MTU update, even though we don't use GATT, because Android
// won't request a larger link layer maximum data length otherwise.
gatt.requestMtu(517)
}
}
},
@@ -23,7 +23,7 @@ import androidx.compose.runtime.setValue
import androidx.lifecycle.ViewModel
import java.util.UUID
val DEFAULT_RFCOMM_UUID = UUID.fromString("E6D55659-C8B4-4B85-96BB-B1143AF6D3AE")
val DEFAULT_RFCOMM_UUID: UUID = UUID.fromString("E6D55659-C8B4-4B85-96BB-B1143AF6D3AE")
const val DEFAULT_PEER_BLUETOOTH_ADDRESS = "AA:BB:CC:DD:EE:FF"
const val DEFAULT_SENDER_PACKET_COUNT = 100
const val DEFAULT_SENDER_PACKET_SIZE = 1024
@@ -31,11 +31,11 @@ const val DEFAULT_SENDER_PACKET_SIZE = 1024
class AppViewModel : ViewModel() {
private var preferences: SharedPreferences? = null
var peerBluetoothAddress by mutableStateOf(DEFAULT_PEER_BLUETOOTH_ADDRESS)
var l2capPsm by mutableStateOf(0)
var l2capPsm by mutableIntStateOf(0)
var use2mPhy by mutableStateOf(true)
var mtu by mutableStateOf(0)
var rxPhy by mutableStateOf(0)
var txPhy by mutableStateOf(0)
var mtu by mutableIntStateOf(0)
var rxPhy by mutableIntStateOf(0)
var txPhy by mutableIntStateOf(0)
var senderPacketCountSlider by mutableFloatStateOf(0.0F)
var senderPacketSizeSlider by mutableFloatStateOf(0.0F)
var senderPacketCount by mutableIntStateOf(DEFAULT_SENDER_PACKET_COUNT)
@@ -79,18 +79,18 @@ class AppViewModel : ViewModel() {
}
fun updateSenderPacketCountSlider() {
if (senderPacketCount <= 10) {
senderPacketCountSlider = 0.0F
senderPacketCountSlider = if (senderPacketCount <= 10) {
0.0F
} else if (senderPacketCount <= 50) {
senderPacketCountSlider = 0.2F
0.2F
} else if (senderPacketCount <= 100) {
senderPacketCountSlider = 0.4F
0.4F
} else if (senderPacketCount <= 500) {
senderPacketCountSlider = 0.6F
0.6F
} else if (senderPacketCount <= 1000) {
senderPacketCountSlider = 0.8F
0.8F
} else {
senderPacketCountSlider = 1.0F
1.0F
}
with(preferences!!.edit()) {
@@ -100,18 +100,18 @@ class AppViewModel : ViewModel() {
}
fun updateSenderPacketCount() {
if (senderPacketCountSlider < 0.1F) {
senderPacketCount = 10
senderPacketCount = if (senderPacketCountSlider < 0.1F) {
10
} else if (senderPacketCountSlider < 0.3F) {
senderPacketCount = 50
50
} else if (senderPacketCountSlider < 0.5F) {
senderPacketCount = 100
100
} else if (senderPacketCountSlider < 0.7F) {
senderPacketCount = 500
500
} else if (senderPacketCountSlider < 0.9F) {
senderPacketCount = 1000
1000
} else {
senderPacketCount = 10000
10000
}
with(preferences!!.edit()) {
@@ -121,18 +121,18 @@ class AppViewModel : ViewModel() {
}
fun updateSenderPacketSizeSlider() {
if (senderPacketSize <= 16) {
senderPacketSizeSlider = 0.0F
senderPacketSizeSlider = if (senderPacketSize <= 16) {
0.0F
} else if (senderPacketSize <= 256) {
senderPacketSizeSlider = 0.02F
0.02F
} else if (senderPacketSize <= 512) {
senderPacketSizeSlider = 0.4F
0.4F
} else if (senderPacketSize <= 1024) {
senderPacketSizeSlider = 0.6F
0.6F
} else if (senderPacketSize <= 2048) {
senderPacketSizeSlider = 0.8F
0.8F
} else {
senderPacketSizeSlider = 1.0F
1.0F
}
with(preferences!!.edit()) {
@@ -142,18 +142,18 @@ class AppViewModel : ViewModel() {
}
fun updateSenderPacketSize() {
if (senderPacketSizeSlider < 0.1F) {
senderPacketSize = 16
senderPacketSize = if (senderPacketSizeSlider < 0.1F) {
16
} else if (senderPacketSizeSlider < 0.3F) {
senderPacketSize = 256
256
} else if (senderPacketSizeSlider < 0.5F) {
senderPacketSize = 512
512
} else if (senderPacketSizeSlider < 0.7F) {
senderPacketSize = 1024
1024
} else if (senderPacketSizeSlider < 0.9F) {
senderPacketSize = 2048
2048
} else {
senderPacketSize = 4096
4096
}
with(preferences!!.edit()) {
@@ -42,6 +42,7 @@ public class HciServer {
try (ServerSocket serverSocket = new ServerSocket(mPort)) {
mListener.onMessage("Waiting for connection on port " + serverSocket.getLocalPort());
try (Socket clientSocket = serverSocket.accept()) {
clientSocket.setTcpNoDelay(true);
mListener.onHostConnectionState(true);
mListener.onMessage("Connected");
HciParser parser = new HciParser(mListener);
+2 -1
View File
@@ -48,7 +48,8 @@ from bumble.profiles.bap import (
PublishedAudioCapabilitiesService,
PublishedAudioCapabilitiesServiceProxy,
)
from .test_utils import TwoDevices
from tests.test_utils import TwoDevices
# -----------------------------------------------------------------------------
# Logging
-138
View File
@@ -1,138 +0,0 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import pytest
from typing import List, Optional
from unittest.mock import MagicMock
from bumble.device import Connection, Device
from bumble.host import Host
from bumble.link import LocalLink
from bumble.controller import Controller
from bumble.hci import (
Address,
HCI_CONNECTION_TERMINATED_BY_LOCAL_HOST_ERROR,
HCI_REMOTE_DEVICE_TERMINATED_CONNECTION_DUE_TO_LOW_RESOURCES_ERROR,
)
from bumble.transport import AsyncPipeSink
# -----------------------------------------------------------------------------
class TwoDevices:
connections: List[Optional[Connection]]
def __init__(self) -> None:
self.connections = [None, None]
self.link = LocalLink()
self.controllers = [
Controller('C1', link=self.link),
Controller('C2', link=self.link),
]
self.devices = [
Device(
address=Address('F0:F1:F2:F3:F4:F5'),
host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])),
),
Device(
address=Address('F5:F4:F3:F2:F1:F0'),
host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])),
),
]
self.paired = [None, None]
def on_connection(self, which, connection):
self.connections[which] = connection
connection.on(
'disconnection', lambda reason: self.on_disconnection(which, reason)
)
def on_disconnection(self, which, _):
self.connections[which] = None
async def setup(self):
self.devices[0].on(
'connection', lambda connection: self.on_connection(0, connection)
)
self.devices[1].on(
'connection', lambda connection: self.on_connection(1, connection)
)
await self.devices[0].power_on()
await self.devices[1].power_on()
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_self_connection():
two_devices = TwoDevices()
await two_devices.setup()
await two_devices.devices[0].connect(two_devices.devices[1].random_address)
assert two_devices.connections[0] is not None
assert two_devices.connections[1] is not None
mock0 = MagicMock()
mock1 = MagicMock()
two_devices.connections[0].once('disconnection', mock0)
two_devices.connections[1].once('disconnection', mock1)
await two_devices.connections[0].disconnect(
HCI_REMOTE_DEVICE_TERMINATED_CONNECTION_DUE_TO_LOW_RESOURCES_ERROR
)
mock0.assert_called_once_with(HCI_CONNECTION_TERMINATED_BY_LOCAL_HOST_ERROR)
mock1.assert_called_once_with(
HCI_REMOTE_DEVICE_TERMINATED_CONNECTION_DUE_TO_LOW_RESOURCES_ERROR
)
assert two_devices.connections[0] is None
assert two_devices.connections[1] is None
await two_devices.devices[0].connect(two_devices.devices[1].random_address)
assert two_devices.connections[0] is not None
assert two_devices.connections[1] is not None
mock0 = MagicMock()
mock1 = MagicMock()
two_devices.connections[0].once('disconnection', mock0)
two_devices.connections[1].once('disconnection', mock1)
await two_devices.connections[1].disconnect(
HCI_REMOTE_DEVICE_TERMINATED_CONNECTION_DUE_TO_LOW_RESOURCES_ERROR
)
mock1.assert_called_once_with(HCI_CONNECTION_TERMINATED_BY_LOCAL_HOST_ERROR)
mock0.assert_called_once_with(
HCI_REMOTE_DEVICE_TERMINATED_CONNECTION_DUE_TO_LOW_RESOURCES_ERROR
)
assert two_devices.connections[0] is None
assert two_devices.connections[1] is None
# -----------------------------------------------------------------------------
async def run_test_controller():
await test_self_connection()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run_test_controller())
+8 -1
View File
@@ -29,7 +29,7 @@ from bumble.core import (
ConnectionParameters,
)
from bumble.device import Connection, Device
from bumble.host import Host
from bumble.host import AclPacketQueue, Host
from bumble.hci import (
HCI_ACCEPT_CONNECTION_REQUEST_COMMAND,
HCI_COMMAND_STATUS_PENDING,
@@ -73,6 +73,13 @@ async def test_device_connect_parallel():
d1 = Device(host=Host(None, None))
d2 = Device(host=Host(None, None))
def _send(packet):
pass
d0.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
d1.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
d2.host.acl_packet_queue = AclPacketQueue(0, 0, _send)
# enable classic
d0.classic_enabled = True
d1.classic_enabled = True
+76 -31
View File
@@ -20,11 +20,10 @@ import logging
import os
import struct
import pytest
from unittest.mock import Mock, ANY
from unittest.mock import AsyncMock, Mock, ANY
from bumble.controller import Controller
from bumble.gatt_client import CharacteristicProxy
from bumble.gatt_server import Server
from bumble.link import LocalLink
from bumble.device import Device, Peer
from bumble.host import Host
@@ -120,9 +119,9 @@ async def test_characteristic_encoding():
Characteristic.READABLE,
123,
)
x = c.read_value(None)
x = await c.read_value(None)
assert x == bytes([123])
c.write_value(None, bytes([122]))
await c.write_value(None, bytes([122]))
assert c.value == 122
class FooProxy(CharacteristicProxy):
@@ -152,7 +151,22 @@ async def test_characteristic_encoding():
bytes([123]),
)
service = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic])
async def async_read(connection):
return 0x05060708
async_characteristic = PackedCharacteristicAdapter(
Characteristic(
'2AB7E91B-43E8-4F73-AC3B-80C1683B47F9',
Characteristic.Properties.READ,
Characteristic.READABLE,
CharacteristicValue(read=async_read),
),
'>I',
)
service = Service(
'3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic, async_characteristic]
)
server.add_service(service)
await client.power_on()
@@ -184,6 +198,13 @@ async def test_characteristic_encoding():
await async_barrier()
assert characteristic.value == bytes([50])
c2 = peer.get_characteristics_by_uuid(async_characteristic.uuid)
assert len(c2) == 1
c2 = c2[0]
cd2 = PackedCharacteristicAdapter(c2, ">I")
cd2v = await cd2.read_value()
assert cd2v == 0x05060708
last_change = None
def on_change(value):
@@ -285,7 +306,8 @@ async def test_attribute_getters():
# -----------------------------------------------------------------------------
def test_CharacteristicAdapter():
@pytest.mark.asyncio
async def test_CharacteristicAdapter():
# Check that the CharacteristicAdapter base class is transparent
v = bytes([1, 2, 3])
c = Characteristic(
@@ -296,11 +318,11 @@ def test_CharacteristicAdapter():
)
a = CharacteristicAdapter(c)
value = a.read_value(None)
value = await a.read_value(None)
assert value == v
v = bytes([3, 4, 5])
a.write_value(None, v)
await a.write_value(None, v)
assert c.value == v
# Simple delegated adapter
@@ -308,11 +330,11 @@ def test_CharacteristicAdapter():
c, lambda x: bytes(reversed(x)), lambda x: bytes(reversed(x))
)
value = a.read_value(None)
value = await a.read_value(None)
assert value == bytes(reversed(v))
v = bytes([3, 4, 5])
a.write_value(None, v)
await a.write_value(None, v)
assert a.value == bytes(reversed(v))
# Packed adapter with single element format
@@ -321,10 +343,10 @@ def test_CharacteristicAdapter():
c.value = v
a = PackedCharacteristicAdapter(c, '>H')
value = a.read_value(None)
value = await a.read_value(None)
assert value == pv
c.value = None
a.write_value(None, pv)
await a.write_value(None, pv)
assert a.value == v
# Packed adapter with multi-element format
@@ -334,10 +356,10 @@ def test_CharacteristicAdapter():
c.value = (v1, v2)
a = PackedCharacteristicAdapter(c, '>HH')
value = a.read_value(None)
value = await a.read_value(None)
assert value == pv
c.value = None
a.write_value(None, pv)
await a.write_value(None, pv)
assert a.value == (v1, v2)
# Mapped adapter
@@ -348,10 +370,10 @@ def test_CharacteristicAdapter():
c.value = mapped
a = MappedCharacteristicAdapter(c, '>HH', ('v1', 'v2'))
value = a.read_value(None)
value = await a.read_value(None)
assert value == pv
c.value = None
a.write_value(None, pv)
await a.write_value(None, pv)
assert a.value == mapped
# UTF-8 adapter
@@ -360,27 +382,49 @@ def test_CharacteristicAdapter():
c.value = v
a = UTF8CharacteristicAdapter(c)
value = a.read_value(None)
value = await a.read_value(None)
assert value == ev
c.value = None
a.write_value(None, ev)
await a.write_value(None, ev)
assert a.value == v
# -----------------------------------------------------------------------------
def test_CharacteristicValue():
@pytest.mark.asyncio
async def test_CharacteristicValue():
b = bytes([1, 2, 3])
c = CharacteristicValue(read=lambda _: b)
x = c.read(None)
async def read_value(connection):
return b
c = CharacteristicValue(read=read_value)
x = await c.read(None)
assert x == b
result = []
c = CharacteristicValue(
write=lambda connection, value: result.append((connection, value))
)
m = Mock()
c = CharacteristicValue(write=m)
z = object()
c.write(z, b)
assert result == [(z, b)]
m.assert_called_once_with(z, b)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_CharacteristicValue_async():
b = bytes([1, 2, 3])
async def read_value(connection):
return b
c = CharacteristicValue(read=read_value)
x = await c.read(None)
assert x == b
m = AsyncMock()
c = CharacteristicValue(write=m)
z = object()
await c.write(z, b)
m.assert_called_once_with(z, b)
# -----------------------------------------------------------------------------
@@ -961,12 +1005,18 @@ Descriptor(handle=0x0009, type=UUID-16:2902 (Client Characteristic Configuration
# -----------------------------------------------------------------------------
async def async_main():
test_UUID()
test_ATT_Error_Response()
test_ATT_Read_By_Group_Type_Request()
await test_read_write()
await test_read_write2()
await test_subscribe_notify()
await test_unsubscribe()
await test_characteristic_encoding()
await test_mtu_exchange()
await test_CharacteristicValue()
await test_CharacteristicValue_async()
await test_CharacteristicAdapter()
# -----------------------------------------------------------------------------
@@ -1105,9 +1155,4 @@ def test_get_attribute_group():
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
test_UUID()
test_ATT_Error_Response()
test_ATT_Read_By_Group_Type_Request()
test_CharacteristicValue()
test_CharacteristicAdapter()
asyncio.run(async_main())