Add L2CAP CoC support (squashed)

[85542e0] fix test
[3748781] add ASAH sink example
[e782e29] add app
[83daa30] wip
[7f138a0] add test
[f732108] allow different address syntax
[9d0bbf8] rename deprecated methods
[eb303d5] add LE CoC support
This commit is contained in:
Gilles Boccon-Gibod
2022-06-02 11:44:01 -07:00
committed by Gilles Boccon-Gibod
parent be8f8ac68f
commit ce9004f0ac
19 changed files with 1882 additions and 187 deletions

View File

@@ -17,13 +17,14 @@
# -----------------------------------------------------------------------------
import asyncio
import os
import struct
import logging
import click
from colors import color
from bumble.device import Device, Peer
from bumble.core import AdvertisingData
from bumble.gatt import Service, Characteristic
from bumble.gatt import Service, Characteristic, CharacteristicValue
from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link
from bumble.hci import HCI_Constant
@@ -41,13 +42,59 @@ GG_PREFERRED_MTU = 256
# -----------------------------------------------------------------------------
class GattlinkHubBridge(Device.Listener):
class GattlinkL2capEndpoint:
def __init__(self):
self.peer = None
self.rx_socket = None
self.tx_socket = None
self.rx_characteristic = None
self.tx_characteristic = None
self.l2cap_channel = None
self.l2cap_packet = b''
self.l2cap_packet_size = 0
# Called when an L2CAP SDU has been received
def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
while len(sdu):
if self.l2cap_packet_size == 0:
# Expect a new packet
self.l2cap_packet_size = sdu[0] + 1
sdu = sdu[1:]
else:
bytes_needed = self.l2cap_packet_size - len(self.l2cap_packet)
chunk = min(bytes_needed, len(sdu))
self.l2cap_packet += sdu[:chunk]
sdu = sdu[chunk:]
if len(self.l2cap_packet) == self.l2cap_packet_size:
self.on_l2cap_packet(self.l2cap_packet)
self.l2cap_packet = b''
self.l2cap_packet_size = 0
# -----------------------------------------------------------------------------
class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device, peer_address):
super().__init__()
self.device = device
self.peer_address = peer_address
self.peer = None
self.tx_socket = None
self.rx_characteristic = None
self.tx_characteristic = None
self.l2cap_psm_characteristic = None
device.listener = self
async def start(self):
# Connect to the peer
print(f'=== Connecting to {self.peer_address}...')
await self.device.connect(self.peer_address)
async def connect_l2cap(self, psm):
print(color(f'### Connecting with L2CAP on PSM = {psm}', 'yellow'))
try:
self.l2cap_channel = await self.peer.connection.open_l2cap_channel(psm)
print(color('*** Connected', 'yellow'), self.l2cap_channel)
self.l2cap_channel.sink = self.on_coc_sdu
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
@AsyncRunner.run_in_task()
async def on_connection(self, connection):
@@ -80,15 +127,24 @@ class GattlinkHubBridge(Device.Listener):
self.rx_characteristic = characteristic
elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID:
self.tx_characteristic = characteristic
elif characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID:
self.l2cap_psm_characteristic = characteristic
print('RX:', self.rx_characteristic)
print('TX:', self.tx_characteristic)
print('PSM:', self.l2cap_psm_characteristic)
# Subscribe to TX
if self.tx_characteristic:
if self.l2cap_psm_characteristic:
# Subscribe to and then read the PSM value
await self.peer.subscribe(self.l2cap_psm_characteristic, self.on_l2cap_psm_received)
psm_bytes = await self.peer.read_value(self.l2cap_psm_characteristic)
psm = struct.unpack('<H', psm_bytes)[0]
await self.connect_l2cap(psm)
elif self.tx_characteristic:
# Subscribe to TX
await self.peer.subscribe(self.tx_characteristic, self.on_tx_received)
print(color('=== Subscribed to Gattlink TX', 'yellow'))
else:
print(color('!!! Gattlink TX not found', 'red'))
print(color('!!! No Gattlink TX or PSM found', 'red'))
def on_connection_failure(self, error):
print(color(f'!!! Connection failed: {error}'))
@@ -99,31 +155,23 @@ class GattlinkHubBridge(Device.Listener):
self.rx_characteristic = None
self.peer = None
# Called when an L2CAP packet has been received
def on_l2cap_packet(self, packet):
print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(packet)
# Called by the GATT client when a notification is received
def on_tx_received(self, value):
print(color('>>> TX:', 'magenta'), value.hex())
print(color(f'<<< [GATT TX]: {len(value)} bytes', 'cyan'))
if self.tx_socket:
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(value)
# Called by asyncio when the UDP socket is created
def connection_made(self, transport):
pass
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address):
print(color('<<< RX:', 'magenta'), data.hex())
# TODO: use a queue instead of creating a task everytime
if self.peer and self.rx_characteristic:
asyncio.create_task(self.peer.write_value(self.rx_characteristic, data))
# -----------------------------------------------------------------------------
class GattlinkNodeBridge(Device.Listener):
def __init__(self):
self.peer = None
self.rx_socket = None
self.tx_socket = None
def on_l2cap_psm_received(self, value):
psm = struct.unpack('<H', value)[0]
asyncio.create_task(self.connect_l2cap(psm))
# Called by asyncio when the UDP socket is created
def connection_made(self, transport):
@@ -131,21 +179,130 @@ class GattlinkNodeBridge(Device.Listener):
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address):
print(color('<<< RX:', 'magenta'), data.hex())
print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
# TODO: use a queue instead of creating a task everytime
if self.peer and self.rx_characteristic:
if self.l2cap_channel:
print(color('>>> [L2CAP]', 'yellow'))
self.l2cap_channel.write(bytes([len(data) - 1]) + data)
elif self.peer and self.rx_characteristic:
print(color('>>> [GATT RX]', 'yellow'))
asyncio.create_task(self.peer.write_value(self.rx_characteristic, data))
# -----------------------------------------------------------------------------
async def run(hci_transport, device_address, send_host, send_port, receive_host, receive_port):
class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device):
super().__init__()
self.device = device
self.peer = None
self.tx_socket = None
self.tx_subscriber = None
self.rx_characteristic = None
# Register as a listener
device.listener = self
# Listen for incoming L2CAP CoC connections
psm = 0xFB
device.register_l2cap_channel_server(0xFB, self.on_coc)
print(f'### Listening for CoC connection on PSM {psm}')
# Setup the Gattlink service
self.rx_characteristic = Characteristic(
GG_GATTLINK_RX_CHARACTERISTIC_UUID,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=self.on_rx_write)
)
self.tx_characteristic = Characteristic(
GG_GATTLINK_TX_CHARACTERISTIC_UUID,
Characteristic.NOTIFY,
Characteristic.READABLE
)
self.tx_characteristic.on('subscription', self.on_tx_subscription)
self.psm_characteristic = Characteristic(
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([psm, 0])
)
gattlink_service = Service(
GG_GATTLINK_SERVICE_UUID,
[
self.rx_characteristic,
self.tx_characteristic,
self.psm_characteristic
]
)
device.add_services([gattlink_service])
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
bytes(reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))))
])
)
async def start(self):
await self.device.start_advertising()
# Called by asyncio when the UDP socket is created
def connection_made(self, transport):
self.transport = transport
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address):
print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
if self.l2cap_channel:
print(color('>>> [L2CAP]', 'yellow'))
self.l2cap_channel.write(bytes([len(data) - 1]) + data)
elif self.tx_subscriber:
print(color('>>> [GATT TX]', 'yellow'))
self.tx_characteristic.value = data
asyncio.create_task(self.device.notify_subscribers(self.tx_characteristic))
# Called when a write to the RX characteristic has been received
def on_rx_write(self, connection, data):
print(color(f'<<< [GATT RX]: {len(data)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(data)
# Called when the subscription to the TX characteristic has changed
def on_tx_subscription(self, peer, enabled):
print(f'### [GATT TX] subscription from {peer}: {"enabled" if enabled else "disabled"}')
if enabled:
self.tx_subscriber = peer
else:
self.tx_subscriber = None
# Called when an L2CAP packet is received
def on_l2cap_packet(self, packet):
print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(packet)
# Called when a new connection is established
def on_coc(self, channel):
print('*** CoC Connection', channel)
self.l2cap_channel = channel
channel.sink = self.on_coc_sdu
# -----------------------------------------------------------------------------
async def run(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port):
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
# Instantiate a bridge object
bridge = GattlinkNodeBridge()
device = Device.with_hci('Bumble GG', device_address, hci_source, hci_sink)
# Instantiate a bridge object
if role_or_peer_address == 'node':
bridge = GattlinkNodeBridge(device)
else:
bridge = GattlinkHubBridge(device, role_or_peer_address)
# Create a UDP to RX bridge (receive from UDP, send to RX)
loop = asyncio.get_running_loop()
@@ -160,35 +317,8 @@ async def run(hci_transport, device_address, send_host, send_port, receive_host,
remote_addr=(send_host, send_port)
)
# Create a device to manage the host, with a custom listener
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
device.listener = bridge
await device.power_on()
# Connect to the peer
# print(f'=== Connecting to {device_address}...')
# await device.connect(device_address)
# TODO move to class
gattlink_service = Service(
GG_GATTLINK_SERVICE_UUID,
[
Characteristic(
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
Characteristic.READ,
Characteristic.READABLE,
bytes([193, 0])
)
]
)
device.add_services([gattlink_service])
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, bytes(reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))))
])
)
await device.start_advertising()
await bridge.start()
# Wait until the source terminates
await hci_source.wait_for_termination()
@@ -197,15 +327,16 @@ async def run(hci_transport, device_address, send_host, send_port, receive_host,
@click.command()
@click.argument('hci_transport')
@click.argument('device_address')
@click.argument('role_or_peer_address')
@click.option('-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to')
@click.option('-sp', '--send-port', type=int, default=9001, help='UDP port to send to')
@click.option('-rh', '--receive-host', type=str, default='127.0.0.1', help='UDP host to receive on')
@click.option('-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on')
def main(hci_transport, device_address, send_host, send_port, receive_host, receive_port):
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run(hci_transport, device_address, send_host, send_port, receive_host, receive_port))
def main(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port):
asyncio.run(run(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port))
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__':
main()

331
apps/l2cap_bridge.py Normal file
View File

@@ -0,0 +1,331 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import click
import logging
import os
from colors import color
from bumble.transport import open_transport_or_link
from bumble.device import Device
from bumble.utils import FlowControlAsyncPipe
from bumble.hci import HCI_Constant
# -----------------------------------------------------------------------------
class ServerBridge:
"""
L2CAP CoC server bridge: waits for a peer to connect an L2CAP CoC channel
on a specified PSM. When the connection is made, the bridge connects a TCP
socket to a remote host and bridges the data in both directions, with flow
control.
When the L2CAP CoC channel is closed, the bridge disconnects the TCP socket
and waits for a new L2CAP CoC channel to be connected.
When the TCP connection is closed by the TCP server, XXXX
"""
def __init__(
self,
psm,
max_credits,
mtu,
mps,
tcp_host,
tcp_port
):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device):
# Listen for incoming L2CAP CoC connections
device.register_l2cap_channel_server(
psm = self.psm,
server = self.on_coc,
max_credits = self.max_credits,
mtu = self.mtu,
mps = self.mps
)
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
def on_ble_connection(connection):
def on_ble_disconnection(reason):
print(color('@@@ Bluetooth disconnection:', 'red'), HCI_Constant.error_name(reason))
print(color('@@@ Bluetooth connection:', 'green'), connection)
connection.on('disconnection', on_ble_disconnection)
device.on('connection', on_ble_connection)
await device.start_advertising(auto_restart=True)
# Called when a new L2CAP connection is established
def on_coc(self, l2cap_channel):
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
class Pipe:
def __init__(self, bridge, l2cap_channel):
self.bridge = bridge
self.tcp_transport = None
self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_coc_sdu
async def connect_to_tcp(self):
# Connect to the TCP server
print(color(f'### Connecting to TCP {self.bridge.tcp_host}:{self.bridge.tcp_port}...', 'yellow'))
class TcpClientProtocol(asyncio.Protocol):
def __init__(self, pipe):
self.pipe = pipe
def connection_lost(self, error):
print(color(f'!!! TCP connection lost: {error}', 'red'))
if self.pipe.l2cap_channel is not None:
asyncio.create_task(self.pipe.l2cap_channel.disconnect())
def data_received(self, data):
print(f'<<< Received on TCP: {len(data)}')
self.pipe.l2cap_channel.write(data)
try:
self.tcp_transport, _ = await asyncio.get_running_loop().create_connection(
lambda: TcpClientProtocol(self),
host=self.bridge.tcp_host,
port=self.bridge.tcp_port,
)
print(color('### Connected', 'green'))
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
await self.l2cap_channel.disconnect()
def on_l2cap_close(self):
self.l2cap_channel = None
if self.tcp_transport is not None:
self.tcp_transport.close()
def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
if self.tcp_transport is None:
print(color('!!! TCP socket not open, dropping', 'red'))
return
self.tcp_transport.write(sdu)
pipe = Pipe(self, l2cap_channel)
asyncio.create_task(pipe.connect_to_tcp())
# -----------------------------------------------------------------------------
class ClientBridge:
"""
L2CAP CoC client bridge: connects to a BLE device, then waits for an inbound
TCP connection on a specified port number. When a TCP client connects, an
L2CAP CoC channel connection to the BLE device is established, and the data
is bridged in both directions, with flow control.
When the TCP connection is closed by the client, the L2CAP CoC channel is
disconnected, but the connection to the BLE device remains, ready for a new
TCP client to connect.
When the L2CAP CoC channel is closed, XXXX
"""
READ_CHUNK_SIZE = 4096
def __init__(
self,
psm,
max_credits,
mtu,
mps,
address,
tcp_host,
tcp_port
):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.address = address
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device):
print(color(f'### Connecting to {self.address}...', 'yellow'))
connection = await device.connect(self.address)
print(color('### Connected', 'green'))
# Called when the BLE connection is disconnected
def on_ble_disconnection(reason):
print(color('@@@ Bluetooth disconnection:', 'red'), HCI_Constant.error_name(reason))
connection.on('disconnection', on_ble_disconnection)
# Called when a TCP connection is established
async def on_tcp_connection(reader, writer):
peername = writer.get_extra_info('peername')
print(color(f'<<< TCP connection from {peername}', 'magenta'))
def on_coc_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
l2cap_to_tcp_pipe.write(sdu)
def on_l2cap_close():
print(color('*** L2CAP channel closed', 'red'))
l2cap_to_tcp_pipe.stop()
writer.close()
# Connect a new L2CAP channel
print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
try:
l2cap_channel = await connection.open_l2cap_channel(
psm = self.psm,
max_credits = self.max_credits,
mtu = self.mtu,
mps = self.mps
)
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
writer.close()
return
l2cap_channel.sink = on_coc_sdu
l2cap_channel.on('close', on_l2cap_close)
# Start a flow control pipe from L2CAP to TCP
l2cap_to_tcp_pipe = FlowControlAsyncPipe(
l2cap_channel.pause_reading,
l2cap_channel.resume_reading,
writer.write,
writer.drain
)
l2cap_to_tcp_pipe.start()
# Pipe data from TCP to L2CAP
while True:
try:
data = await reader.read(self.READ_CHUNK_SIZE)
if len(data) == 0:
print(color('!!! End of stream', 'red'))
await l2cap_channel.disconnect()
return
print(color(f'<<< [TCP DATA]: {len(data)} bytes', 'blue'))
l2cap_channel.write(data)
await l2cap_channel.drain()
except Exception as error:
print(f'!!! Exception: {error}')
break
writer.close()
print(color('~~~ Bye bye', 'magenta'))
await asyncio.start_server(
on_tcp_connection,
host=self.tcp_host if self.tcp_host != '_' else None,
port=self.tcp_port
)
print(color(f'### Listening for TCP connections on port {self.tcp_port}', 'magenta'))
# -----------------------------------------------------------------------------
async def run(device_config, hci_transport, bridge):
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
# Let's go
await device.power_on()
await bridge.start(device)
# Wait until the transport terminates
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
@click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
@click.option('--l2cap-coc-max-credits', help='Maximum L2CAP CoC Credits', type=click.IntRange(1, 65535), default=128)
@click.option('--l2cap-coc-mtu', help='L2CAP CoC MTU', type=click.IntRange(23, 65535), default=1022)
@click.option('--l2cap-coc-mps', help='L2CAP CoC MPS', type=click.IntRange(23, 65533), default=1024)
def cli(context, device_config, hci_transport, psm, l2cap_coc_max_credits, l2cap_coc_mtu, l2cap_coc_mps):
context.ensure_object(dict)
context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport
context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_coc_max_credits
context.obj['mtu'] = l2cap_coc_mtu
context.obj['mps'] = l2cap_coc_mps
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.option('--tcp-host', help='TCP host', default='localhost')
@click.option('--tcp-port', help='TCP port', default=9544)
def server(context, tcp_host, tcp_port):
bridge = ServerBridge(
context.obj['psm'],
context.obj['max_credits'],
context.obj['mtu'],
context.obj['mps'],
tcp_host,
tcp_port)
asyncio.run(run(
context.obj['device_config'],
context.obj['hci_transport'],
bridge
))
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.argument('bluetooth-address')
@click.option('--tcp-host', help='TCP host', default='_')
@click.option('--tcp-port', help='TCP port', default=9543)
def client(context, bluetooth_address, tcp_host, tcp_port):
bridge = ClientBridge(
context.obj['psm'],
context.obj['max_credits'],
context.obj['mtu'],
context.obj['mps'],
bluetooth_address,
tcp_host,
tcp_port
)
asyncio.run(run(
context.obj['device_config'],
context.obj['hci_transport'],
bridge
))
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__':
cli(obj={})

View File

@@ -351,7 +351,7 @@ class MediaPacketPump:
logger.debug('pump canceled')
# Pump packets
self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
self.pump_task = asyncio.create_task(pump_packets())
async def stop(self):
# Stop the pump
@@ -1890,10 +1890,10 @@ class LocalSource(LocalStreamEndPoint, EventEmitter):
self.configuration = configuration
def on_start_command(self):
asyncio.get_running_loop().create_task(self.start())
asyncio.create_task(self.start())
def on_suspend_command(self):
asyncio.get_running_loop().create_task(self.stop())
asyncio.create_task(self.stop())
# -----------------------------------------------------------------------------

View File

@@ -43,6 +43,13 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEVICE_MIN_SCAN_INTERVAL = 25
DEVICE_MAX_SCAN_INTERVAL = 10240
DEVICE_MIN_SCAN_WINDOW = 25
DEVICE_MAX_SCAN_WINDOW = 10240
DEVICE_MIN_LE_RSSI = -127
DEVICE_MAX_LE_RSSI = 20
DEVICE_DEFAULT_ADDRESS = '00:00:00:00:00:00'
DEVICE_DEFAULT_ADVERTISING_INTERVAL = 1000 # ms
DEVICE_DEFAULT_ADVERTISING_DATA = ''
@@ -62,20 +69,15 @@ DEVICE_DEFAULT_CONNECTION_MAX_LATENCY = 0
DEVICE_DEFAULT_CONNECTION_SUPERVISION_TIMEOUT = 720 # ms
DEVICE_DEFAULT_CONNECTION_MIN_CE_LENGTH = 0 # ms
DEVICE_DEFAULT_CONNECTION_MAX_CE_LENGTH = 0 # ms
DEVICE_MIN_SCAN_INTERVAL = 25
DEVICE_MAX_SCAN_INTERVAL = 10240
DEVICE_MIN_SCAN_WINDOW = 25
DEVICE_MAX_SCAN_WINDOW = 10240
DEVICE_MIN_LE_RSSI = -127
DEVICE_MAX_LE_RSSI = 20
DEVICE_DEFAULT_L2CAP_COC_MTU = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU
DEVICE_DEFAULT_L2CAP_COC_MPS = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class Advertisement:
TX_POWER_NOT_AVAILABLE = HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE
@@ -429,7 +431,16 @@ class Connection(CompositeEventEmitter):
def create_l2cap_connector(self, psm):
return self.device.create_l2cap_connector(self, psm)
async def disconnect(self, reason = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR):
async def open_l2cap_channel(
self,
psm,
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
mps=DEVICE_DEFAULT_L2CAP_COC_MPS
):
return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps)
async def disconnect(self, reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR):
return await self.device.disconnect(self, reason)
async def pair(self):
@@ -563,6 +574,7 @@ class DeviceConfiguration:
with open(filename, 'r') as file:
self.load_from_dict(json.load(file))
# -----------------------------------------------------------------------------
# Decorators used with the following Device class
# (we define them outside of the Device class, because defining decorators
@@ -685,7 +697,7 @@ class Device(CompositeEventEmitter):
self.classic_enabled = False
self.inquiry_response = None
self.address_resolver = None
self.classic_pending_accepts = { Address.ANY: [] } # Futures, by BD address OR [Futures] for Address.ANY
self.classic_pending_accepts = {Address.ANY: []} # Futures, by BD address OR [Futures] for Address.ANY
# Use the initial config or a default
self.public_address = Address('00:00:00:00:00:00')
@@ -785,15 +797,35 @@ class Device(CompositeEventEmitter):
if transport is None or connection.transport == transport:
return connection
def register_l2cap_server(self, psm, server):
self.l2cap_channel_manager.register_server(psm, server)
def create_l2cap_connector(self, connection, psm):
return lambda: self.l2cap_channel_manager.connect(connection, psm)
def create_l2cap_registrar(self, psm):
return lambda handler: self.register_l2cap_server(psm, handler)
def register_l2cap_server(self, psm, server):
self.l2cap_channel_manager.register_server(psm, server)
def register_l2cap_channel_server(
self,
psm,
server,
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
mps=DEVICE_DEFAULT_L2CAP_COC_MPS
):
return self.l2cap_channel_manager.register_le_coc_server(psm, server, max_credits, mtu, mps)
async def open_l2cap_channel(
self,
connection,
psm,
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
mps=DEVICE_DEFAULT_L2CAP_COC_MPS
):
return await self.l2cap_channel_manager.open_le_coc(connection, psm, max_credits, mtu, mps)
def send_l2cap_pdu(self, connection_handle, cid, pdu):
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
@@ -1185,13 +1217,15 @@ class Device(CompositeEventEmitter):
def on_connection(connection):
if transport == BT_LE_TRANSPORT or (
# match BR/EDR connection event against peer address
connection.transport == transport and connection.peer_address == peer_address):
connection.transport == transport and connection.peer_address == peer_address
):
pending_connection.set_result(connection)
def on_connection_failure(error):
if transport == BT_LE_TRANSPORT or (
# match BR/EDR connection failure event against peer address
error.transport == transport and error.peer_address == peer_address):
error.transport == transport and error.peer_address == peer_address
):
pending_connection.set_exception(error)
# Create a future so that we can wait for the connection's result
@@ -1336,7 +1370,7 @@ class Device(CompositeEventEmitter):
if peer_address == Address.NIL:
raise ValueError('accept on nil address')
# Create a future so that we can wait for the request
# Create a future so that we can wait for the request
pending_request = asyncio.get_running_loop().create_future()
if peer_address == Address.ANY:
@@ -1349,8 +1383,7 @@ class Device(CompositeEventEmitter):
try:
# Wait for a request or a completed connection
result = await (asyncio.wait_for(pending_request, timeout) if timeout else pending_request)
except:
except Exception:
# Remove future from device context
if peer_address == Address.ANY:
self.classic_pending_accepts[Address.ANY].remove(pending_request)
@@ -1710,26 +1743,32 @@ class Device(CompositeEventEmitter):
connection.remove_listener('connection_encryption_failure', on_encryption_failure)
# [Classic only]
async def request_remote_name(self, remote: Connection | Address):
async def request_remote_name(self, remote): # remote: Connection | Address
# Set up event handlers
pending_name = asyncio.get_running_loop().create_future()
if type(remote) == Address:
peer_address = remote
handler = self.on('remote_name',
handler = self.on(
'remote_name',
lambda address, remote_name:
pending_name.set_result(remote_name) if address == remote else None)
failure_handler = self.on('remote_name_failure',
pending_name.set_result(remote_name) if address == remote else None
)
failure_handler = self.on(
'remote_name_failure',
lambda address, error_code:
pending_name.set_exception(HCI_Error(error_code)) if address == remote else None)
pending_name.set_exception(HCI_Error(error_code)) if address == remote else None
)
else:
peer_address = remote.peer_address
handler = remote.on('remote_name',
lambda:
pending_name.set_result(remote.peer_name))
failure_handler = remote.on('remote_name_failure',
lambda error_code:
pending_name.set_exception(HCI_Error(error_code)))
handler = remote.on(
'remote_name',
lambda: pending_name.set_result(remote.peer_name)
)
failure_handler = remote.on(
'remote_name_failure',
lambda error_code: pending_name.set_exception(HCI_Error(error_code))
)
try:
result = await self.send_command(
@@ -2097,7 +2136,6 @@ class Device(CompositeEventEmitter):
else:
self.emit('remote_name_failure', address, error)
# [Classic only]
@host_event_handler
@try_with_connection_from_address

View File

@@ -25,6 +25,7 @@
import asyncio
import types
import logging
from pyee import EventEmitter
from colors import color
from .core import *

View File

@@ -273,7 +273,7 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
logger.warning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
# TODO raise appropriate exception
return
break
@@ -337,7 +337,7 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
logger.warning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
# TODO raise appropriate exception
return
break

View File

@@ -155,7 +155,7 @@ class Server(EventEmitter):
return cccd or bytes([0, 0])
def write_cccd(self, connection, characteristic, value):
logger.debug(f'Subscription update for connection={connection.handle:04X}, handle={characteristic.handle:04X}: {value.hex()}')
logger.debug(f'Subscription update for connection=0x{connection.handle:04X}, handle=0x{characteristic.handle:04X}: {value.hex()}')
# Sanity check
if len(value) != 2:
@@ -204,7 +204,7 @@ class Server(EventEmitter):
logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}')
self.send_gatt_pdu(connection.handle, bytes(notification))
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
async def indicate_subscriber(self, connection, attribute, force=False):
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)

View File

@@ -2466,9 +2466,10 @@ class HCI_Write_Voice_Setting_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command()
class HCI_Read_Synchronous_Flow_Control_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.3.36 Write Synchronous Flow Control Enable Command
See Bluetooth spec @ 7.3.36 Read Synchronous Flow Control Enable Command
'''

View File

@@ -79,6 +79,8 @@ class Host(EventEmitter):
self.local_version = None
self.local_supported_commands = bytes(64)
self.local_le_features = 0
self.suggested_max_tx_octets = 251 # Max allowed
self.suggested_max_tx_time = 2120 # Max allowed
self.command_semaphore = asyncio.Semaphore(1)
self.long_term_key_provider = None
self.link_key_provider = None
@@ -138,6 +140,22 @@ class Host(EventEmitter):
f'hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}'
)
if (
self.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND) and
self.supports_command(HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND)
):
response = await self.send_command(HCI_LE_Read_Suggested_Default_Data_Length_Command())
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
if (
suggested_max_tx_octets != self.suggested_max_tx_octets or
suggested_max_tx_time != self.suggested_max_tx_time
):
await self.send_command(HCI_LE_Write_Suggested_Default_Data_Length_Command(
suggested_max_tx_octets = self.suggested_max_tx_octets,
suggested_max_tx_time = self.suggested_max_tx_time
))
self.reset_done = True
@property

File diff suppressed because it is too large Load Diff

View File

@@ -274,7 +274,7 @@ class PumpedPacketSource(ParserSource):
self.terminated.set_result(error)
break
self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
self.pump_task = asyncio.create_task(pump_packets())
def close(self):
if self.pump_task:
@@ -304,7 +304,7 @@ class PumpedPacketSink:
logger.warn(f'exception while sending packet: {error}')
break
self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
self.pump_task = asyncio.create_task(pump_packets())
def close(self):
if self.pump_task:

View File

@@ -18,6 +18,7 @@
import asyncio
import logging
import traceback
import collections
from functools import wraps
from colors import color
from pyee import EventEmitter
@@ -140,3 +141,95 @@ class AsyncRunner:
return wrapper
return decorator
# -----------------------------------------------------------------------------
class FlowControlAsyncPipe:
"""
Asyncio pipe with flow control. When writing to the pipe, the source is
paused (by calling a function passed in when the pipe is created) if the
amount of queued data exceeds a specified threshold.
"""
def __init__(self, pause_source, resume_source, write_to_sink=None, drain_sink=None, threshold=0):
self.pause_source = pause_source
self.resume_source = resume_source
self.write_to_sink = write_to_sink
self.drain_sink = drain_sink
self.threshold = threshold
self.queue = collections.deque() # Queue of packets
self.queued_bytes = 0 # Number of bytes in the queue
self.ready_to_pump = asyncio.Event()
self.paused = False
self.source_paused = False
self.pump_task = None
def start(self):
if self.pump_task is None:
self.pump_task = asyncio.create_task(self.pump())
self.check_pump()
def stop(self):
if self.pump_task is not None:
self.pump_task.cancel()
self.pump_task = None
def write(self, packet):
self.queued_bytes += len(packet)
self.queue.append(packet)
# Pause the source if we're over the threshold
if self.queued_bytes > self.threshold and not self.source_paused:
logger.debug(f'pausing source (queued={self.queued_bytes})')
self.pause_source()
self.source_paused = True
self.check_pump()
def pause(self):
if not self.paused:
self.paused = True
if not self.source_paused:
self.pause_source()
self.source_paused = True
self.check_pump()
def resume(self):
if self.paused:
self.paused = False
if self.source_paused:
self.resume_source()
self.source_paused = False
self.check_pump()
def can_pump(self):
return self.queue and not self.paused and self.write_to_sink is not None
def check_pump(self):
if self.can_pump():
self.ready_to_pump.set()
else:
self.ready_to_pump.clear()
async def pump(self):
while True:
# Wait until we can try to pump packets
await self.ready_to_pump.wait()
# Try to pump a packet
if self.can_pump():
packet = self.queue.pop()
self.write_to_sink(packet)
self.queued_bytes -= len(packet)
# Drain the sink if we can
if self.drain_sink:
await self.drain_sink()
# Check if we can accept more
if self.queued_bytes <= self.threshold and self.source_paused:
logger.debug(f'resuming source (queued={self.queued_bytes})')
self.source_paused = False
self.resume_source()
self.check_pump()

5
examples/asha_sink1.json Normal file
View File

@@ -0,0 +1,5 @@
{
"name": "Bumble Aid Left",
"address": "F1:F2:F3:F4:F5:F6",
"keystore": "JsonKeyStore"
}

5
examples/asha_sink2.json Normal file
View File

@@ -0,0 +1,5 @@
{
"name": "Bumble Aid Right",
"address": "F7:F8:F9:FA:FB:FC",
"keystore": "JsonKeyStore"
}

161
examples/run_asha_sink.py Normal file
View File

@@ -0,0 +1,161 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import struct
import sys
import os
import logging
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.hci import UUID
from bumble.gatt import (
Service,
Characteristic,
CharacteristicValue
)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID('f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint')
ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID('38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus')
ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT')
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) != 4:
print('Usage: python run_asha_sink.py <device-config> <transport-spec> <audio-file>')
print('example: python run_asha_sink.py device1.json usb:0 audio_out.g722')
return
audio_out = open(sys.argv[3], 'wb')
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
# Handler for audio control commands
def on_audio_control_point_write(connection, value):
print('--- AUDIO CONTROL POINT Write:', value.hex())
opcode = value[0]
if opcode == 1:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
print(f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}')
elif opcode == 2:
print('### STOP')
elif opcode == 3:
print(f'### STATUS: connected={value[1]}')
# Respond with a status
asyncio.create_task(device.notify_subscribers(audio_status_characteristic, force=True))
# Handler for volume control
def on_volume_write(connection, value):
print('--- VOLUME Write:', value[0])
# Register an L2CAP CoC server
def on_coc(channel):
def on_data(data):
print('<<< Voice data received:', data.hex())
audio_out.write(data)
channel.sink = on_data
psm = device.register_l2cap_channel_server(0, on_coc, 8)
print(f'### LE_PSM_OUT = {psm}')
# Add the ASHA service to the GATT server
read_only_properties_characteristic = Characteristic(
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([
0x01, # Version
0x00, # Device Capabilities [Left, Monaural]
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, # HiSyncId
0x01, # Feature Map [LE CoC audio output streaming supported]
0x00, 0x00, # Render Delay
0x00, 0x00, # RFU
0x02, 0x00 # Codec IDs [G.722 at 16 kHz]
])
)
audio_control_point_characteristic = Characteristic(
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write)
)
audio_status_characteristic = Characteristic(
ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([0])
)
volume_characteristic = Characteristic(
ASHA_VOLUME_CHARACTERISTIC,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write)
)
le_psm_out_characteristic = Characteristic(
ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
struct.pack('<H', psm)
)
device.add_service(Service(
ASHA_SERVICE,
[
read_only_properties_characteristic,
audio_control_point_characteristic,
audio_status_characteristic,
volume_characteristic,
le_psm_out_characteristic
]
))
# Set the advertising data
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(device.name, 'utf-8')),
(AdvertisingData.FLAGS, bytes([0x06])),
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(ASHA_SERVICE)),
(AdvertisingData.SERVICE_DATA_16_BIT_UUID, bytes(ASHA_SERVICE) + bytes([
0x01, # Protocol Version
0x00, # Capability
0x01, 0x02, 0x03, 0x04 # Truncated HiSyncID
]))
])
)
# Go!
await device.power_on()
await device.start_advertising(auto_restart=True)
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -49,11 +49,17 @@ class Listener(Device.Listener, Connection.Listener):
)
# -----------------------------------------------------------------------------
# Alternative way to listen for subscriptions
# -----------------------------------------------------------------------------
def on_my_characteristic_subscription(peer, enabled):
print(f'### My characteristic from {peer}: {"enabled" if enabled else "disabled"}')
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 3:
print('Usage: run_gatt_server.py <device-config> <transport-spec>')
print('example: run_gatt_server.py device1.json usb:0')
print('Usage: run_notifier.py <device-config> <transport-spec>')
print('example: run_notifier.py device1.json usb:0')
return
print('<<< connecting to HCI...')
@@ -83,6 +89,7 @@ async def main():
Characteristic.READABLE,
bytes([0x42])
)
characteristic3.on('subscription', on_my_characteristic_subscription)
custom_service = Service(
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
[characteristic1, characteristic2, characteristic3]

View File

@@ -98,6 +98,7 @@ async def list_rfcomm_channels(device, connection):
await sdp_client.disconnect()
# -----------------------------------------------------------------------------
class TcpServerProtocol(asyncio.Protocol):
def __init__(self, rfcomm_session):
@@ -173,7 +174,7 @@ async def main():
print('*** Encryption on')
# Create a client and start it
print('@@@ Starting to RFCOMM client...')
print('@@@ Starting RFCOMM client...')
rfcomm_client = Client(device, connection)
rfcomm_mux = await rfcomm_client.start()
print('@@@ Started')
@@ -192,7 +193,7 @@ async def main():
if len(sys.argv) == 6:
# A TCP port was specified, start listening
tcp_port = int(sys.argv[5])
asyncio.get_running_loop().create_task(tcp_server(tcp_port, session))
asyncio.create_task(tcp_server(tcp_port, session))
await hci_source.wait_for_termination()

View File

@@ -51,6 +51,7 @@ console_scripts =
bumble-controller-info = bumble.apps.controller_info:main
bumble-gatt-dump = bumble.apps.gatt_dump:main
bumble-hci-bridge = bumble.apps.hci_bridge:main
bumble-l2cap-bridge = bumble.apps.l2cap_bridge:main
bumble-pair = bumble.apps.pair:main
bumble-scan = bumble.apps.scan:main
bumble-show = bumble.apps.show:main
@@ -64,6 +65,7 @@ build =
test =
pytest >= 6.2
pytest-asyncio >= 0.17
coverage >= 6.4
development =
invoke >= 1.4
nox >= 2022

284
tests/l2cap_test.py Normal file
View File

@@ -0,0 +1,284 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import random
import pytest
from bumble.controller import Controller
from bumble.link import LocalLink
from bumble.device import Device
from bumble.host import Host
from bumble.transport import AsyncPipeSink
from bumble.core import ProtocolError
from bumble.l2cap import (
L2CAP_Connection_Request
)
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class TwoDevices:
def __init__(self):
self.connections = [None, None]
self.link = LocalLink()
self.controllers = [
Controller('C1', link = self.link),
Controller('C2', link = self.link)
]
self.devices = [
Device(
address = 'F0:F1:F2:F3:F4:F5',
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
),
Device(
address = 'F5:F4:F3:F2:F1:F0',
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
)
]
self.paired = [None, None]
def on_connection(self, which, connection):
self.connections[which] = connection
def on_paired(self, which, keys):
self.paired[which] = keys
# -----------------------------------------------------------------------------
async def setup_connection():
# Create two devices, each with a controller, attached to the same link
two_devices = TwoDevices()
# Attach listeners
two_devices.devices[0].on('connection', lambda connection: two_devices.on_connection(0, connection))
two_devices.devices[1].on('connection', lambda connection: two_devices.on_connection(1, connection))
# Start
await two_devices.devices[0].power_on()
await two_devices.devices[1].power_on()
# Connect the two devices
await two_devices.devices[0].connect(two_devices.devices[1].random_address)
# Check the post conditions
assert(two_devices.connections[0] is not None)
assert(two_devices.connections[1] is not None)
return two_devices
# -----------------------------------------------------------------------------
def test_helpers():
psm = L2CAP_Connection_Request.serialize_psm(0x01)
assert(psm == bytes([0x01, 0x00]))
psm = L2CAP_Connection_Request.serialize_psm(0x1023)
assert(psm == bytes([0x23, 0x10]))
psm = L2CAP_Connection_Request.serialize_psm(0x242311)
assert(psm == bytes([0x11, 0x23, 0x24]))
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x01, 0x00, 0x44]), 1)
assert(offset == 3)
assert(psm == 0x01)
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x23, 0x10, 0x44]), 1)
assert(offset == 3)
assert(psm == 0x1023)
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1)
assert(offset == 4)
assert(psm == 0x242311)
rq = L2CAP_Connection_Request(psm = 0x01, source_cid = 0x44)
brq = bytes(rq)
srq = L2CAP_Connection_Request.from_bytes(brq)
assert(srq.psm == rq.psm)
assert(srq.source_cid == rq.source_cid)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_basic_connection():
devices = await setup_connection()
psm = 1234
# Check that if there's no one listening, we can't connect
with pytest.raises(ProtocolError):
l2cap_channel = await devices.connections[0].open_l2cap_channel(psm)
# Now add a listener
incoming_channel = None
received = []
def on_coc(channel):
nonlocal incoming_channel
incoming_channel = channel
def on_data(data):
received.append(data)
channel.sink = on_data
devices.devices[1].register_l2cap_channel_server(psm, on_coc)
l2cap_channel = await devices.connections[0].open_l2cap_channel(psm)
messages = (
bytes([1, 2, 3]),
bytes([4, 5, 6]),
bytes(10000)
)
for message in messages:
l2cap_channel.write(message)
await asyncio.sleep(0)
await l2cap_channel.drain()
# Test closing
closed = [False, False]
closed_event = asyncio.Event()
def on_close(which, event):
closed[which] = True
if event:
event.set()
l2cap_channel.on('close', lambda: on_close(0, None))
incoming_channel.on('close', lambda: on_close(1, closed_event))
await l2cap_channel.disconnect()
assert(closed == [True, True])
await closed_event.wait()
sent_bytes = b''.join(messages)
received_bytes = b''.join(received)
assert(sent_bytes == received_bytes)
# -----------------------------------------------------------------------------
async def transfer_payload(max_credits, mtu, mps):
devices = await setup_connection()
received = []
def on_coc(channel):
def on_data(data):
received.append(data)
channel.sink = on_data
psm = devices.devices[1].register_l2cap_channel_server(
psm = 0,
server = on_coc,
max_credits = max_credits,
mtu = mtu,
mps = mps
)
l2cap_channel = await devices.connections[0].open_l2cap_channel(psm)
messages = [
bytes([1, 2, 3, 4, 5, 6, 7]) * x
for x in (3, 10, 100, 500, 789)
]
for message in messages:
l2cap_channel.write(message)
await asyncio.sleep(0)
if random.randint(0, 5) == 1:
await l2cap_channel.drain()
await l2cap_channel.drain()
await l2cap_channel.disconnect()
sent_bytes = b''.join(messages)
received_bytes = b''.join(received)
assert(sent_bytes == received_bytes)
@pytest.mark.asyncio
async def test_transfer():
for max_credits in (1, 10, 100, 10000):
for mtu in (23, 24, 25, 26, 50, 200, 255, 256, 1000):
for mps in (23, 24, 25, 26, 50, 200, 255, 256, 1000):
# print(max_credits, mtu, mps)
await transfer_payload(max_credits, mtu, mps)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_bidirectional_transfer():
devices = await setup_connection()
client_received = []
server_received = []
server_channel = None
def on_server_coc(channel):
nonlocal server_channel
server_channel = channel
def on_server_data(data):
server_received.append(data)
channel.sink = on_server_data
def on_client_data(data):
client_received.append(data)
psm = devices.devices[1].register_l2cap_channel_server(psm=0, server=on_server_coc)
client_channel = await devices.connections[0].open_l2cap_channel(psm)
client_channel.sink = on_client_data
messages = [
bytes([1, 2, 3, 4, 5, 6, 7]) * x
for x in (3, 10, 100)
]
for message in messages:
client_channel.write(message)
await client_channel.drain()
await asyncio.sleep(0)
server_channel.write(message)
await server_channel.drain()
await client_channel.disconnect()
message_bytes = b''.join(messages)
client_received_bytes = b''.join(client_received)
server_received_bytes = b''.join(server_received)
assert(client_received_bytes == message_bytes)
assert(server_received_bytes == message_bytes)
# -----------------------------------------------------------------------------
async def run():
test_helpers()
await test_basic_connection()
await test_transfer()
await test_bidirectional_transfer()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())