Add L2CAP CoC support (squashed)

[85542e0] fix test
[3748781] add ASAH sink example
[e782e29] add app
[83daa30] wip
[7f138a0] add test
[f732108] allow different address syntax
[9d0bbf8] rename deprecated methods
[eb303d5] add LE CoC support
This commit is contained in:
Gilles Boccon-Gibod
2022-06-02 11:44:01 -07:00
committed by Gilles Boccon-Gibod
parent be8f8ac68f
commit ce9004f0ac
19 changed files with 1882 additions and 187 deletions

View File

@@ -17,13 +17,14 @@
# -----------------------------------------------------------------------------
import asyncio
import os
import struct
import logging
import click
from colors import color
from bumble.device import Device, Peer
from bumble.core import AdvertisingData
from bumble.gatt import Service, Characteristic
from bumble.gatt import Service, Characteristic, CharacteristicValue
from bumble.utils import AsyncRunner
from bumble.transport import open_transport_or_link
from bumble.hci import HCI_Constant
@@ -41,13 +42,59 @@ GG_PREFERRED_MTU = 256
# -----------------------------------------------------------------------------
class GattlinkHubBridge(Device.Listener):
class GattlinkL2capEndpoint:
def __init__(self):
self.l2cap_channel = None
self.l2cap_packet = b''
self.l2cap_packet_size = 0
# Called when an L2CAP SDU has been received
def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
while len(sdu):
if self.l2cap_packet_size == 0:
# Expect a new packet
self.l2cap_packet_size = sdu[0] + 1
sdu = sdu[1:]
else:
bytes_needed = self.l2cap_packet_size - len(self.l2cap_packet)
chunk = min(bytes_needed, len(sdu))
self.l2cap_packet += sdu[:chunk]
sdu = sdu[chunk:]
if len(self.l2cap_packet) == self.l2cap_packet_size:
self.on_l2cap_packet(self.l2cap_packet)
self.l2cap_packet = b''
self.l2cap_packet_size = 0
# -----------------------------------------------------------------------------
class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device, peer_address):
super().__init__()
self.device = device
self.peer_address = peer_address
self.peer = None
self.rx_socket = None
self.tx_socket = None
self.rx_characteristic = None
self.tx_characteristic = None
self.l2cap_psm_characteristic = None
device.listener = self
async def start(self):
# Connect to the peer
print(f'=== Connecting to {self.peer_address}...')
await self.device.connect(self.peer_address)
async def connect_l2cap(self, psm):
print(color(f'### Connecting with L2CAP on PSM = {psm}', 'yellow'))
try:
self.l2cap_channel = await self.peer.connection.open_l2cap_channel(psm)
print(color('*** Connected', 'yellow'), self.l2cap_channel)
self.l2cap_channel.sink = self.on_coc_sdu
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
@AsyncRunner.run_in_task()
async def on_connection(self, connection):
@@ -80,15 +127,24 @@ class GattlinkHubBridge(Device.Listener):
self.rx_characteristic = characteristic
elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID:
self.tx_characteristic = characteristic
elif characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID:
self.l2cap_psm_characteristic = characteristic
print('RX:', self.rx_characteristic)
print('TX:', self.tx_characteristic)
print('PSM:', self.l2cap_psm_characteristic)
if self.l2cap_psm_characteristic:
# Subscribe to and then read the PSM value
await self.peer.subscribe(self.l2cap_psm_characteristic, self.on_l2cap_psm_received)
psm_bytes = await self.peer.read_value(self.l2cap_psm_characteristic)
psm = struct.unpack('<H', psm_bytes)[0]
await self.connect_l2cap(psm)
elif self.tx_characteristic:
# Subscribe to TX
if self.tx_characteristic:
await self.peer.subscribe(self.tx_characteristic, self.on_tx_received)
print(color('=== Subscribed to Gattlink TX', 'yellow'))
else:
print(color('!!! Gattlink TX not found', 'red'))
print(color('!!! No Gattlink TX or PSM found', 'red'))
def on_connection_failure(self, error):
print(color(f'!!! Connection failed: {error}'))
@@ -99,31 +155,23 @@ class GattlinkHubBridge(Device.Listener):
self.rx_characteristic = None
self.peer = None
# Called when an L2CAP packet has been received
def on_l2cap_packet(self, packet):
print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(packet)
# Called by the GATT client when a notification is received
def on_tx_received(self, value):
print(color('>>> TX:', 'magenta'), value.hex())
print(color(f'<<< [GATT TX]: {len(value)} bytes', 'cyan'))
if self.tx_socket:
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(value)
# Called by asyncio when the UDP socket is created
def connection_made(self, transport):
pass
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address):
print(color('<<< RX:', 'magenta'), data.hex())
# TODO: use a queue instead of creating a task everytime
if self.peer and self.rx_characteristic:
asyncio.create_task(self.peer.write_value(self.rx_characteristic, data))
# -----------------------------------------------------------------------------
class GattlinkNodeBridge(Device.Listener):
def __init__(self):
self.peer = None
self.rx_socket = None
self.tx_socket = None
def on_l2cap_psm_received(self, value):
psm = struct.unpack('<H', value)[0]
asyncio.create_task(self.connect_l2cap(psm))
# Called by asyncio when the UDP socket is created
def connection_made(self, transport):
@@ -131,21 +179,130 @@ class GattlinkNodeBridge(Device.Listener):
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address):
print(color('<<< RX:', 'magenta'), data.hex())
print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
# TODO: use a queue instead of creating a task everytime
if self.peer and self.rx_characteristic:
if self.l2cap_channel:
print(color('>>> [L2CAP]', 'yellow'))
self.l2cap_channel.write(bytes([len(data) - 1]) + data)
elif self.peer and self.rx_characteristic:
print(color('>>> [GATT RX]', 'yellow'))
asyncio.create_task(self.peer.write_value(self.rx_characteristic, data))
# -----------------------------------------------------------------------------
async def run(hci_transport, device_address, send_host, send_port, receive_host, receive_port):
class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
def __init__(self, device):
super().__init__()
self.device = device
self.peer = None
self.tx_socket = None
self.tx_subscriber = None
self.rx_characteristic = None
# Register as a listener
device.listener = self
# Listen for incoming L2CAP CoC connections
psm = 0xFB
device.register_l2cap_channel_server(0xFB, self.on_coc)
print(f'### Listening for CoC connection on PSM {psm}')
# Setup the Gattlink service
self.rx_characteristic = Characteristic(
GG_GATTLINK_RX_CHARACTERISTIC_UUID,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=self.on_rx_write)
)
self.tx_characteristic = Characteristic(
GG_GATTLINK_TX_CHARACTERISTIC_UUID,
Characteristic.NOTIFY,
Characteristic.READABLE
)
self.tx_characteristic.on('subscription', self.on_tx_subscription)
self.psm_characteristic = Characteristic(
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([psm, 0])
)
gattlink_service = Service(
GG_GATTLINK_SERVICE_UUID,
[
self.rx_characteristic,
self.tx_characteristic,
self.psm_characteristic
]
)
device.add_services([gattlink_service])
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
bytes(reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))))
])
)
async def start(self):
await self.device.start_advertising()
# Called by asyncio when the UDP socket is created
def connection_made(self, transport):
self.transport = transport
# Called by asyncio when a UDP datagram is received
def datagram_received(self, data, address):
print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
if self.l2cap_channel:
print(color('>>> [L2CAP]', 'yellow'))
self.l2cap_channel.write(bytes([len(data) - 1]) + data)
elif self.tx_subscriber:
print(color('>>> [GATT TX]', 'yellow'))
self.tx_characteristic.value = data
asyncio.create_task(self.device.notify_subscribers(self.tx_characteristic))
# Called when a write to the RX characteristic has been received
def on_rx_write(self, connection, data):
print(color(f'<<< [GATT RX]: {len(data)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(data)
# Called when the subscription to the TX characteristic has changed
def on_tx_subscription(self, peer, enabled):
print(f'### [GATT TX] subscription from {peer}: {"enabled" if enabled else "disabled"}')
if enabled:
self.tx_subscriber = peer
else:
self.tx_subscriber = None
# Called when an L2CAP packet is received
def on_l2cap_packet(self, packet):
print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
print(color('>>> [UDP]', 'magenta'))
self.tx_socket.sendto(packet)
# Called when a new connection is established
def on_coc(self, channel):
print('*** CoC Connection', channel)
self.l2cap_channel = channel
channel.sink = self.on_coc_sdu
# -----------------------------------------------------------------------------
async def run(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port):
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
# Instantiate a bridge object
bridge = GattlinkNodeBridge()
device = Device.with_hci('Bumble GG', device_address, hci_source, hci_sink)
# Instantiate a bridge object
if role_or_peer_address == 'node':
bridge = GattlinkNodeBridge(device)
else:
bridge = GattlinkHubBridge(device, role_or_peer_address)
# Create a UDP to RX bridge (receive from UDP, send to RX)
loop = asyncio.get_running_loop()
@@ -160,35 +317,8 @@ async def run(hci_transport, device_address, send_host, send_port, receive_host,
remote_addr=(send_host, send_port)
)
# Create a device to manage the host, with a custom listener
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
device.listener = bridge
await device.power_on()
# Connect to the peer
# print(f'=== Connecting to {device_address}...')
# await device.connect(device_address)
# TODO move to class
gattlink_service = Service(
GG_GATTLINK_SERVICE_UUID,
[
Characteristic(
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
Characteristic.READ,
Characteristic.READABLE,
bytes([193, 0])
)
]
)
device.add_services([gattlink_service])
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, bytes(reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))))
])
)
await device.start_advertising()
await bridge.start()
# Wait until the source terminates
await hci_source.wait_for_termination()
@@ -197,15 +327,16 @@ async def run(hci_transport, device_address, send_host, send_port, receive_host,
@click.command()
@click.argument('hci_transport')
@click.argument('device_address')
@click.argument('role_or_peer_address')
@click.option('-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to')
@click.option('-sp', '--send-port', type=int, default=9001, help='UDP port to send to')
@click.option('-rh', '--receive-host', type=str, default='127.0.0.1', help='UDP host to receive on')
@click.option('-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on')
def main(hci_transport, device_address, send_host, send_port, receive_host, receive_port):
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run(hci_transport, device_address, send_host, send_port, receive_host, receive_port))
def main(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port):
asyncio.run(run(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port))
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__':
main()

331
apps/l2cap_bridge.py Normal file
View File

@@ -0,0 +1,331 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import click
import logging
import os
from colors import color
from bumble.transport import open_transport_or_link
from bumble.device import Device
from bumble.utils import FlowControlAsyncPipe
from bumble.hci import HCI_Constant
# -----------------------------------------------------------------------------
class ServerBridge:
"""
L2CAP CoC server bridge: waits for a peer to connect an L2CAP CoC channel
on a specified PSM. When the connection is made, the bridge connects a TCP
socket to a remote host and bridges the data in both directions, with flow
control.
When the L2CAP CoC channel is closed, the bridge disconnects the TCP socket
and waits for a new L2CAP CoC channel to be connected.
When the TCP connection is closed by the TCP server, XXXX
"""
def __init__(
self,
psm,
max_credits,
mtu,
mps,
tcp_host,
tcp_port
):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device):
# Listen for incoming L2CAP CoC connections
device.register_l2cap_channel_server(
psm = self.psm,
server = self.on_coc,
max_credits = self.max_credits,
mtu = self.mtu,
mps = self.mps
)
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
def on_ble_connection(connection):
def on_ble_disconnection(reason):
print(color('@@@ Bluetooth disconnection:', 'red'), HCI_Constant.error_name(reason))
print(color('@@@ Bluetooth connection:', 'green'), connection)
connection.on('disconnection', on_ble_disconnection)
device.on('connection', on_ble_connection)
await device.start_advertising(auto_restart=True)
# Called when a new L2CAP connection is established
def on_coc(self, l2cap_channel):
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
class Pipe:
def __init__(self, bridge, l2cap_channel):
self.bridge = bridge
self.tcp_transport = None
self.l2cap_channel = l2cap_channel
l2cap_channel.on('close', self.on_l2cap_close)
l2cap_channel.sink = self.on_coc_sdu
async def connect_to_tcp(self):
# Connect to the TCP server
print(color(f'### Connecting to TCP {self.bridge.tcp_host}:{self.bridge.tcp_port}...', 'yellow'))
class TcpClientProtocol(asyncio.Protocol):
def __init__(self, pipe):
self.pipe = pipe
def connection_lost(self, error):
print(color(f'!!! TCP connection lost: {error}', 'red'))
if self.pipe.l2cap_channel is not None:
asyncio.create_task(self.pipe.l2cap_channel.disconnect())
def data_received(self, data):
print(f'<<< Received on TCP: {len(data)}')
self.pipe.l2cap_channel.write(data)
try:
self.tcp_transport, _ = await asyncio.get_running_loop().create_connection(
lambda: TcpClientProtocol(self),
host=self.bridge.tcp_host,
port=self.bridge.tcp_port,
)
print(color('### Connected', 'green'))
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
await self.l2cap_channel.disconnect()
def on_l2cap_close(self):
self.l2cap_channel = None
if self.tcp_transport is not None:
self.tcp_transport.close()
def on_coc_sdu(self, sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
if self.tcp_transport is None:
print(color('!!! TCP socket not open, dropping', 'red'))
return
self.tcp_transport.write(sdu)
pipe = Pipe(self, l2cap_channel)
asyncio.create_task(pipe.connect_to_tcp())
# -----------------------------------------------------------------------------
class ClientBridge:
"""
L2CAP CoC client bridge: connects to a BLE device, then waits for an inbound
TCP connection on a specified port number. When a TCP client connects, an
L2CAP CoC channel connection to the BLE device is established, and the data
is bridged in both directions, with flow control.
When the TCP connection is closed by the client, the L2CAP CoC channel is
disconnected, but the connection to the BLE device remains, ready for a new
TCP client to connect.
When the L2CAP CoC channel is closed, XXXX
"""
READ_CHUNK_SIZE = 4096
def __init__(
self,
psm,
max_credits,
mtu,
mps,
address,
tcp_host,
tcp_port
):
self.psm = psm
self.max_credits = max_credits
self.mtu = mtu
self.mps = mps
self.address = address
self.tcp_host = tcp_host
self.tcp_port = tcp_port
async def start(self, device):
print(color(f'### Connecting to {self.address}...', 'yellow'))
connection = await device.connect(self.address)
print(color('### Connected', 'green'))
# Called when the BLE connection is disconnected
def on_ble_disconnection(reason):
print(color('@@@ Bluetooth disconnection:', 'red'), HCI_Constant.error_name(reason))
connection.on('disconnection', on_ble_disconnection)
# Called when a TCP connection is established
async def on_tcp_connection(reader, writer):
peername = writer.get_extra_info('peername')
print(color(f'<<< TCP connection from {peername}', 'magenta'))
def on_coc_sdu(sdu):
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
l2cap_to_tcp_pipe.write(sdu)
def on_l2cap_close():
print(color('*** L2CAP channel closed', 'red'))
l2cap_to_tcp_pipe.stop()
writer.close()
# Connect a new L2CAP channel
print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
try:
l2cap_channel = await connection.open_l2cap_channel(
psm = self.psm,
max_credits = self.max_credits,
mtu = self.mtu,
mps = self.mps
)
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
except Exception as error:
print(color(f'!!! Connection failed: {error}', 'red'))
writer.close()
return
l2cap_channel.sink = on_coc_sdu
l2cap_channel.on('close', on_l2cap_close)
# Start a flow control pipe from L2CAP to TCP
l2cap_to_tcp_pipe = FlowControlAsyncPipe(
l2cap_channel.pause_reading,
l2cap_channel.resume_reading,
writer.write,
writer.drain
)
l2cap_to_tcp_pipe.start()
# Pipe data from TCP to L2CAP
while True:
try:
data = await reader.read(self.READ_CHUNK_SIZE)
if len(data) == 0:
print(color('!!! End of stream', 'red'))
await l2cap_channel.disconnect()
return
print(color(f'<<< [TCP DATA]: {len(data)} bytes', 'blue'))
l2cap_channel.write(data)
await l2cap_channel.drain()
except Exception as error:
print(f'!!! Exception: {error}')
break
writer.close()
print(color('~~~ Bye bye', 'magenta'))
await asyncio.start_server(
on_tcp_connection,
host=self.tcp_host if self.tcp_host != '_' else None,
port=self.tcp_port
)
print(color(f'### Listening for TCP connections on port {self.tcp_port}', 'magenta'))
# -----------------------------------------------------------------------------
async def run(device_config, hci_transport, bridge):
print('<<< connecting to HCI...')
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
print('<<< connected')
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
# Let's go
await device.power_on()
await bridge.start(device)
# Wait until the transport terminates
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
@click.group()
@click.pass_context
@click.option('--device-config', help='Device configuration file', required=True)
@click.option('--hci-transport', help='HCI transport', required=True)
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
@click.option('--l2cap-coc-max-credits', help='Maximum L2CAP CoC Credits', type=click.IntRange(1, 65535), default=128)
@click.option('--l2cap-coc-mtu', help='L2CAP CoC MTU', type=click.IntRange(23, 65535), default=1022)
@click.option('--l2cap-coc-mps', help='L2CAP CoC MPS', type=click.IntRange(23, 65533), default=1024)
def cli(context, device_config, hci_transport, psm, l2cap_coc_max_credits, l2cap_coc_mtu, l2cap_coc_mps):
context.ensure_object(dict)
context.obj['device_config'] = device_config
context.obj['hci_transport'] = hci_transport
context.obj['psm'] = psm
context.obj['max_credits'] = l2cap_coc_max_credits
context.obj['mtu'] = l2cap_coc_mtu
context.obj['mps'] = l2cap_coc_mps
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.option('--tcp-host', help='TCP host', default='localhost')
@click.option('--tcp-port', help='TCP port', default=9544)
def server(context, tcp_host, tcp_port):
bridge = ServerBridge(
context.obj['psm'],
context.obj['max_credits'],
context.obj['mtu'],
context.obj['mps'],
tcp_host,
tcp_port)
asyncio.run(run(
context.obj['device_config'],
context.obj['hci_transport'],
bridge
))
# -----------------------------------------------------------------------------
@cli.command()
@click.pass_context
@click.argument('bluetooth-address')
@click.option('--tcp-host', help='TCP host', default='_')
@click.option('--tcp-port', help='TCP port', default=9543)
def client(context, bluetooth_address, tcp_host, tcp_port):
bridge = ClientBridge(
context.obj['psm'],
context.obj['max_credits'],
context.obj['mtu'],
context.obj['mps'],
bluetooth_address,
tcp_host,
tcp_port
)
asyncio.run(run(
context.obj['device_config'],
context.obj['hci_transport'],
bridge
))
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
if __name__ == '__main__':
cli(obj={})

View File

@@ -351,7 +351,7 @@ class MediaPacketPump:
logger.debug('pump canceled')
# Pump packets
self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
self.pump_task = asyncio.create_task(pump_packets())
async def stop(self):
# Stop the pump
@@ -1890,10 +1890,10 @@ class LocalSource(LocalStreamEndPoint, EventEmitter):
self.configuration = configuration
def on_start_command(self):
asyncio.get_running_loop().create_task(self.start())
asyncio.create_task(self.start())
def on_suspend_command(self):
asyncio.get_running_loop().create_task(self.stop())
asyncio.create_task(self.stop())
# -----------------------------------------------------------------------------

View File

@@ -43,6 +43,13 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEVICE_MIN_SCAN_INTERVAL = 25
DEVICE_MAX_SCAN_INTERVAL = 10240
DEVICE_MIN_SCAN_WINDOW = 25
DEVICE_MAX_SCAN_WINDOW = 10240
DEVICE_MIN_LE_RSSI = -127
DEVICE_MAX_LE_RSSI = 20
DEVICE_DEFAULT_ADDRESS = '00:00:00:00:00:00'
DEVICE_DEFAULT_ADVERTISING_INTERVAL = 1000 # ms
DEVICE_DEFAULT_ADVERTISING_DATA = ''
@@ -62,20 +69,15 @@ DEVICE_DEFAULT_CONNECTION_MAX_LATENCY = 0
DEVICE_DEFAULT_CONNECTION_SUPERVISION_TIMEOUT = 720 # ms
DEVICE_DEFAULT_CONNECTION_MIN_CE_LENGTH = 0 # ms
DEVICE_DEFAULT_CONNECTION_MAX_CE_LENGTH = 0 # ms
DEVICE_MIN_SCAN_INTERVAL = 25
DEVICE_MAX_SCAN_INTERVAL = 10240
DEVICE_MIN_SCAN_WINDOW = 25
DEVICE_MAX_SCAN_WINDOW = 10240
DEVICE_MIN_LE_RSSI = -127
DEVICE_MAX_LE_RSSI = 20
DEVICE_DEFAULT_L2CAP_COC_MTU = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU
DEVICE_DEFAULT_L2CAP_COC_MPS = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class Advertisement:
TX_POWER_NOT_AVAILABLE = HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE
@@ -429,6 +431,15 @@ class Connection(CompositeEventEmitter):
def create_l2cap_connector(self, psm):
return self.device.create_l2cap_connector(self, psm)
async def open_l2cap_channel(
self,
psm,
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
mps=DEVICE_DEFAULT_L2CAP_COC_MPS
):
return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps)
async def disconnect(self, reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR):
return await self.device.disconnect(self, reason)
@@ -563,6 +574,7 @@ class DeviceConfiguration:
with open(filename, 'r') as file:
self.load_from_dict(json.load(file))
# -----------------------------------------------------------------------------
# Decorators used with the following Device class
# (we define them outside of the Device class, because defining decorators
@@ -785,15 +797,35 @@ class Device(CompositeEventEmitter):
if transport is None or connection.transport == transport:
return connection
def register_l2cap_server(self, psm, server):
self.l2cap_channel_manager.register_server(psm, server)
def create_l2cap_connector(self, connection, psm):
return lambda: self.l2cap_channel_manager.connect(connection, psm)
def create_l2cap_registrar(self, psm):
return lambda handler: self.register_l2cap_server(psm, handler)
def register_l2cap_server(self, psm, server):
self.l2cap_channel_manager.register_server(psm, server)
def register_l2cap_channel_server(
self,
psm,
server,
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
mps=DEVICE_DEFAULT_L2CAP_COC_MPS
):
return self.l2cap_channel_manager.register_le_coc_server(psm, server, max_credits, mtu, mps)
async def open_l2cap_channel(
self,
connection,
psm,
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
mps=DEVICE_DEFAULT_L2CAP_COC_MPS
):
return await self.l2cap_channel_manager.open_le_coc(connection, psm, max_credits, mtu, mps)
def send_l2cap_pdu(self, connection_handle, cid, pdu):
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
@@ -1185,13 +1217,15 @@ class Device(CompositeEventEmitter):
def on_connection(connection):
if transport == BT_LE_TRANSPORT or (
# match BR/EDR connection event against peer address
connection.transport == transport and connection.peer_address == peer_address):
connection.transport == transport and connection.peer_address == peer_address
):
pending_connection.set_result(connection)
def on_connection_failure(error):
if transport == BT_LE_TRANSPORT or (
# match BR/EDR connection failure event against peer address
error.transport == transport and error.peer_address == peer_address):
error.transport == transport and error.peer_address == peer_address
):
pending_connection.set_exception(error)
# Create a future so that we can wait for the connection's result
@@ -1349,8 +1383,7 @@ class Device(CompositeEventEmitter):
try:
# Wait for a request or a completed connection
result = await (asyncio.wait_for(pending_request, timeout) if timeout else pending_request)
except:
except Exception:
# Remove future from device context
if peer_address == Address.ANY:
self.classic_pending_accepts[Address.ANY].remove(pending_request)
@@ -1710,26 +1743,32 @@ class Device(CompositeEventEmitter):
connection.remove_listener('connection_encryption_failure', on_encryption_failure)
# [Classic only]
async def request_remote_name(self, remote: Connection | Address):
async def request_remote_name(self, remote): # remote: Connection | Address
# Set up event handlers
pending_name = asyncio.get_running_loop().create_future()
if type(remote) == Address:
peer_address = remote
handler = self.on('remote_name',
handler = self.on(
'remote_name',
lambda address, remote_name:
pending_name.set_result(remote_name) if address == remote else None)
failure_handler = self.on('remote_name_failure',
pending_name.set_result(remote_name) if address == remote else None
)
failure_handler = self.on(
'remote_name_failure',
lambda address, error_code:
pending_name.set_exception(HCI_Error(error_code)) if address == remote else None)
pending_name.set_exception(HCI_Error(error_code)) if address == remote else None
)
else:
peer_address = remote.peer_address
handler = remote.on('remote_name',
lambda:
pending_name.set_result(remote.peer_name))
failure_handler = remote.on('remote_name_failure',
lambda error_code:
pending_name.set_exception(HCI_Error(error_code)))
handler = remote.on(
'remote_name',
lambda: pending_name.set_result(remote.peer_name)
)
failure_handler = remote.on(
'remote_name_failure',
lambda error_code: pending_name.set_exception(HCI_Error(error_code))
)
try:
result = await self.send_command(
@@ -2097,7 +2136,6 @@ class Device(CompositeEventEmitter):
else:
self.emit('remote_name_failure', address, error)
# [Classic only]
@host_event_handler
@try_with_connection_from_address

View File

@@ -25,6 +25,7 @@
import asyncio
import types
import logging
from pyee import EventEmitter
from colors import color
from .core import *

View File

@@ -273,7 +273,7 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
logger.warning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
# TODO raise appropriate exception
return
break
@@ -337,7 +337,7 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
logger.warning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
# TODO raise appropriate exception
return
break

View File

@@ -155,7 +155,7 @@ class Server(EventEmitter):
return cccd or bytes([0, 0])
def write_cccd(self, connection, characteristic, value):
logger.debug(f'Subscription update for connection={connection.handle:04X}, handle={characteristic.handle:04X}: {value.hex()}')
logger.debug(f'Subscription update for connection=0x{connection.handle:04X}, handle=0x{characteristic.handle:04X}: {value.hex()}')
# Sanity check
if len(value) != 2:
@@ -204,7 +204,7 @@ class Server(EventEmitter):
logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}')
self.send_gatt_pdu(connection.handle, bytes(notification))
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
async def indicate_subscriber(self, connection, attribute, force=False):
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)

View File

@@ -2466,9 +2466,10 @@ class HCI_Write_Voice_Setting_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command()
class HCI_Read_Synchronous_Flow_Control_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.3.36 Write Synchronous Flow Control Enable Command
See Bluetooth spec @ 7.3.36 Read Synchronous Flow Control Enable Command
'''

View File

@@ -79,6 +79,8 @@ class Host(EventEmitter):
self.local_version = None
self.local_supported_commands = bytes(64)
self.local_le_features = 0
self.suggested_max_tx_octets = 251 # Max allowed
self.suggested_max_tx_time = 2120 # Max allowed
self.command_semaphore = asyncio.Semaphore(1)
self.long_term_key_provider = None
self.link_key_provider = None
@@ -138,6 +140,22 @@ class Host(EventEmitter):
f'hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}'
)
if (
self.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND) and
self.supports_command(HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND)
):
response = await self.send_command(HCI_LE_Read_Suggested_Default_Data_Length_Command())
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
if (
suggested_max_tx_octets != self.suggested_max_tx_octets or
suggested_max_tx_time != self.suggested_max_tx_time
):
await self.send_command(HCI_LE_Write_Suggested_Default_Data_Length_Command(
suggested_max_tx_octets = self.suggested_max_tx_octets,
suggested_max_tx_time = self.suggested_max_tx_time
))
self.reset_done = True
@property

View File

@@ -19,6 +19,7 @@ import asyncio
import logging
import struct
from collections import deque
from colors import color
from pyee import EventEmitter
@@ -43,13 +44,23 @@ L2CAP_MIN_BR_EDR_MTU = 48
L2CAP_DEFAULT_MTU = 2048 # Default value for the MTU we are willing to accept
L2CAP_DEFAULT_CONNECTIONLESS_MTU = 1024
# See Bluetooth spec @ Vol 3, Part A - Table 2.1: CID name space on ACL-U, ASB-U, and AMP-U logical links
L2CAP_ACL_U_DYNAMIC_CID_RANGE_START = 0x0040
L2CAP_ACL_U_DYNAMIC_CID_RANGE_END = 0xFFFF
# See Bluetooth spec @ Vol 3, Part A - Table 2.2: CID name space on LE-U logical link
L2CAP_LE_U_DYNAMIC_CID_RANGE_START = 0x0040
L2CAP_LE_U_DYNAMIC_CID_RANGE_START = 0x007F
L2CAP_LE_U_DYNAMIC_CID_RANGE_END = 0x007F
# PSM Range - See Bluetooth spec @ Vol 3, Part A / Table 4.5: PSM ranges and usage
L2CAP_PSM_DYNAMIC_RANGE_START = 0x1001
L2CAP_PSM_DYNAMIC_RANGE_END = 0xFFFF
# LE PSM Ranges - See Bluetooth spec @ Vol 3, Part A / Table 4.19: LE Credit Based Connection Request LE_PSM ranges
L2CAP_LE_PSM_DYNAMIC_RANGE_START = 0x0080
L2CAP_LE_PSM_DYNAMIC_RANGE_END = 0x00FF
# Frame types
L2CAP_COMMAND_REJECT = 0x01
@@ -107,8 +118,13 @@ L2CAP_COMMAND_NOT_UNDERSTOOD_REASON = 0x0000
L2CAP_SIGNALING_MTU_EXCEEDED_REASON = 0x0001
L2CAP_INVALID_CID_IN_REQUEST_REASON = 0x0002
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2046
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256
L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE = 0x01
@@ -172,7 +188,7 @@ class L2CAP_Control_Frame:
self.identifier = pdu[1]
length = struct.unpack_from('<H', pdu, 2)[0]
if length + 4 != len(pdu):
logger.warn(color(f'!!! length mismatch: expected {len(pdu) - 4} but got {length}', 'red'))
logger.warning(color(f'!!! length mismatch: expected {len(pdu) - 4} but got {length}', 'red'))
if hasattr(self, 'fields'):
self.init_from_bytes(pdu, 4)
return self
@@ -268,7 +284,10 @@ class L2CAP_Command_Reject(L2CAP_Control_Frame):
# -----------------------------------------------------------------------------
@L2CAP_Control_Frame.subclass([
('psm', 2),
('psm', {
'parser': lambda data, offset: L2CAP_Connection_Request.parse_psm(data, offset),
'serializer': lambda value: L2CAP_Connection_Request.serialize_psm(value)
}),
('source_cid', 2)
])
class L2CAP_Connection_Request(L2CAP_Control_Frame):
@@ -276,6 +295,28 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST
'''
@staticmethod
def parse_psm(data, offset=0):
psm_length = 2
psm = data[offset] | data[offset + 1] << 8
# The PSM field extends until the first even octet (inclusive)
while data[offset + psm_length - 1] % 2 == 1:
psm |= data[offset + psm_length] << (8 * psm_length)
psm_length += 1
return offset + psm_length, psm
@staticmethod
def serialize_psm(psm):
serialized = struct.pack('<H', psm & 0xFFFF)
psm >>= 16
while psm:
serialized += bytes([psm & 0xFF])
psm >>= 8
return serialized
# -----------------------------------------------------------------------------
@L2CAP_Control_Frame.subclass([
@@ -298,7 +339,7 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame):
CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED = 0x0007
CONNECTION_REFUSED_UNACCEPTABLE_PARAMETERS = 0x000B
CONNECTION_RESULT_NAMES = {
RESULT_NAMES = {
CONNECTION_SUCCESSFUL: 'CONNECTION_SUCCESSFUL',
CONNECTION_PENDING: 'CONNECTION_PENDING',
CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED: 'CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED',
@@ -311,7 +352,7 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame):
@staticmethod
def result_name(result):
return name_or_number(L2CAP_Connection_Response.CONNECTION_RESULT_NAMES, result)
return name_or_number(L2CAP_Connection_Response.RESULT_NAMES, result)
# -----------------------------------------------------------------------------
@@ -521,7 +562,7 @@ class L2CAP_LE_Credit_Based_Connection_Response(L2CAP_Control_Frame):
CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED = 0x000A
CONNECTION_REFUSED_UNACCEPTABLE_PARAMETERS = 0x000B
CONNECTION_RESULT_NAMES = {
RESULT_NAMES = {
CONNECTION_SUCCESSFUL: 'CONNECTION_SUCCESSFUL',
CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED: 'CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED',
CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE: 'CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE',
@@ -536,7 +577,7 @@ class L2CAP_LE_Credit_Based_Connection_Response(L2CAP_Control_Frame):
@staticmethod
def result_name(result):
return name_or_number(L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_RESULT_NAMES, result)
return name_or_number(L2CAP_LE_Credit_Based_Connection_Response.RESULT_NAMES, result)
# -----------------------------------------------------------------------------
@@ -619,6 +660,9 @@ class Channel(EventEmitter):
def send_pdu(self, pdu):
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame):
self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
async def send_request(self, request):
# Check that there isn't already a request pending
if self.response:
@@ -637,15 +681,16 @@ class Channel(EventEmitter):
elif self.sink:
self.sink(pdu)
else:
logger.warn(color('received pdu without a pending request or sink', 'red'))
def send_control_frame(self, frame):
self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
logger.warning(color('received pdu without a pending request or sink', 'red'))
async def connect(self):
if self.state != Channel.CLOSED:
raise InvalidStateError('invalid state')
# Check that we can start a new connection
if self.connection_result:
raise RuntimeError('connection already pending')
self.change_state(Channel.WAIT_CONNECT_RSP)
self.send_control_frame(
L2CAP_Connection_Request(
@@ -657,7 +702,12 @@ class Channel(EventEmitter):
# Create a future to wait for the state machine to get to a success or error state
self.connection_result = asyncio.get_running_loop().create_future()
# Wait for the connection to succeed or fail
try:
return await self.connection_result
finally:
self.connection_result = None
async def disconnect(self):
if self.state != Channel.OPEN:
@@ -708,7 +758,7 @@ class Channel(EventEmitter):
def on_connection_response(self, response):
if self.state != Channel.WAIT_CONNECT_RSP:
logger.warn(color('invalid state', 'red'))
logger.warning(color('invalid state', 'red'))
return
if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
@@ -734,7 +784,7 @@ class Channel(EventEmitter):
self.state != Channel.WAIT_CONFIG_REQ and
self.state != Channel.WAIT_CONFIG_REQ_RSP
):
logger.warn(color('invalid state', 'red'))
logger.warning(color('invalid state', 'red'))
return
# Decode the options
@@ -750,7 +800,7 @@ class Channel(EventEmitter):
source_cid = self.destination_cid,
flags = 0x0000,
result = L2CAP_Configure_Response.SUCCESS,
options = request.options # TODO: don't accept everthing blindly
options = request.options # TODO: don't accept everything blindly
)
)
if self.state == Channel.WAIT_CONFIG:
@@ -777,7 +827,7 @@ class Channel(EventEmitter):
self.connection_result = None
self.emit('open')
else:
logger.warn(color('invalid state', 'red'))
logger.warning(color('invalid state', 'red'))
elif response.result == L2CAP_Configure_Response.FAILURE_UNACCEPTABLE_PARAMETERS:
# Re-configure with what's suggested in the response
self.send_control_frame(
@@ -789,7 +839,7 @@ class Channel(EventEmitter):
)
)
else:
logger.warn(color(f'!!! configuration rejected: {L2CAP_Configure_Response.result_name(response.result)}', 'red'))
logger.warning(color(f'!!! configuration rejected: {L2CAP_Configure_Response.result_name(response.result)}', 'red'))
# TODO: decide how to fail gracefully
def on_disconnection_request(self, request):
@@ -805,15 +855,15 @@ class Channel(EventEmitter):
self.emit('close')
self.manager.on_channel_closed(self)
else:
logger.warn(color('invalid state', 'red'))
logger.warning(color('invalid state', 'red'))
def on_disconnection_response(self, response):
if self.state != Channel.WAIT_DISCONNECT:
logger.warn(color('invalid state', 'red'))
logger.warning(color('invalid state', 'red'))
return
if response.destination_cid != self.destination_cid or response.source_cid != self.source_cid:
logger.warn('unexpected source or destination CID')
logger.warning('unexpected source or destination CID')
return
self.change_state(Channel.CLOSED)
@@ -827,23 +877,363 @@ class Channel(EventEmitter):
return f'Channel({self.source_cid}->{self.destination_cid}, PSM={self.psm}, MTU={self.mtu}, state={Channel.STATE_NAMES[self.state]})'
# -----------------------------------------------------------------------------
class LeConnectionOrientedChannel(EventEmitter):
"""
LE Credit-based Connection Oriented Channel
"""
INIT = 0
CONNECTED = 1
CONNECTING = 2
DISCONNECTING = 3
DISCONNECTED = 4
CONNECTION_ERROR = 5
STATE_NAMES = {
INIT: 'INIT',
CONNECTED: 'CONNECTED',
CONNECTING: 'CONNECTING',
DISCONNECTING: 'DISCONNECTING',
DISCONNECTED: 'DISCONNECTED',
CONNECTION_ERROR: 'CONNECTION_ERROR'
}
@staticmethod
def state_name(state):
return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
def __init__(
self,
manager,
connection,
le_psm,
source_cid,
destination_cid,
mtu,
mps,
credits,
peer_mtu,
peer_mps,
peer_credits,
connected
):
super().__init__()
self.manager = manager
self.connection = connection
self.le_psm = le_psm
self.source_cid = source_cid
self.destination_cid = destination_cid
self.mtu = mtu
self.mps = mps
self.credits = credits
self.peer_mtu = peer_mtu
self.peer_mps = peer_mps
self.peer_credits = peer_credits
self.peer_max_credits = self.peer_credits
self.peer_credits_threshold = self.peer_max_credits // 2
self.in_sdu = None
self.in_sdu_length = 0
self.out_queue = deque()
self.out_sdu = None
self.sink = None
self.connection_result = None
self.disconnection_result = None
self.drained = asyncio.Event()
self.drained.set()
if connected:
self.state = LeConnectionOrientedChannel.CONNECTED
else:
self.state = LeConnectionOrientedChannel.INIT
def change_state(self, new_state):
logger.debug(f'{self} state change -> {color(self.state_name(new_state), "cyan")}')
self.state = new_state
if new_state == self.CONNECTED:
self.emit('open')
elif new_state == self.DISCONNECTED:
self.emit('close')
def send_pdu(self, pdu):
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
def send_control_frame(self, frame):
self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame)
async def connect(self):
# Check that we're in the right state
if self.state != self.INIT:
raise InvalidStateError('not in a connectable state')
# Check that we can start a new connection
identifier = self.manager.next_identifier(self.connection)
if identifier in self.manager.le_coc_requests:
raise RuntimeError('too many concurrent connection requests')
self.change_state(self.CONNECTING)
request = L2CAP_LE_Credit_Based_Connection_Request(
identifier = identifier,
le_psm = self.le_psm,
source_cid = self.source_cid,
mtu = self.mtu,
mps = self.mps,
initial_credits = self.peer_credits
)
self.manager.le_coc_requests[identifier] = request
self.send_control_frame(request)
# Create a future to wait for the response
self.connection_result = asyncio.get_running_loop().create_future()
# Wait for the connection to succeed or fail
return await self.connection_result
async def disconnect(self):
# Check that we're connected
if self.state != self.CONNECTED:
raise InvalidStateError('not connected')
self.change_state(self.DISCONNECTING)
self.flush_output()
self.send_control_frame(
L2CAP_Disconnection_Request(
identifier = self.manager.next_identifier(self.connection),
destination_cid = self.destination_cid,
source_cid = self.source_cid
)
)
# Create a future to wait for the state machine to get to a success or error state
self.disconnection_result = asyncio.get_running_loop().create_future()
return await self.disconnection_result
def on_pdu(self, pdu):
if self.sink is None:
logger.warning('received pdu without a sink')
return
if self.state != self.CONNECTED:
logger.warning('received PDU while not connected, dropping')
# Manage the peer credits
if self.peer_credits == 0:
logger.warning('received LE frame when peer out of credits')
else:
self.peer_credits -= 1
if self.peer_credits <= self.peer_credits_threshold:
# The credits fell below the threshold, replenish them to the max
self.send_control_frame(
L2CAP_LE_Flow_Control_Credit(
identifier = self.manager.next_identifier(self.connection),
cid = self.source_cid,
credits = self.peer_max_credits - self.peer_credits
)
)
self.peer_credits = self.peer_max_credits
# Check if this starts a new SDU
if self.in_sdu is None:
# Start a new SDU
self.in_sdu = pdu
else:
# Continue an SDU
self.in_sdu += pdu
# Check if the SDU is complete
if self.in_sdu_length == 0:
# We don't know the size yet, check if we have received the header to compute it
if len(self.in_sdu) >= 2:
self.in_sdu_length = struct.unpack_from('<H', self.in_sdu, 0)[0]
if self.in_sdu_length == 0:
# We'll compute it later
return
if len(self.in_sdu) < 2 + self.in_sdu_length:
# Not complete yet
logger.debug(f'SDU: {len(self.in_sdu) - 2} of {self.in_sdu_length} bytes received')
return
if len(self.in_sdu) != 2 + self.in_sdu_length:
# Overflow
logger.warning(f'SDU overflow: sdu_length={self.in_sdu_length}, received {len(self.in_sdu) - 2}')
# TODO: we should disconnect
self.in_sdu = None
self.in_sdu_length = 0
return
# Send the SDU to the sink
logger.debug(f'SDU complete: 2+{len(self.in_sdu) - 2} bytes')
self.sink(self.in_sdu[2:])
# Prepare for a new SDU
self.in_sdu = None
self.in_sdu_length = 0
def on_connection_response(self, response):
# Look for a matching pending response result
if self.connection_result is None:
logger.warning(f'received unexpected connection response (id={response.identifier})')
return
if response.result == L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL:
self.destination_cid = response.destination_cid
self.peer_mtu = response.mtu
self.peer_mps = response.mps
self.credits = response.initial_credits
self.connected = True
self.connection_result.set_result(self)
self.change_state(self.CONNECTED)
else:
self.connection_result.set_exception(
ProtocolError(
response.result,
'l2cap',
L2CAP_LE_Credit_Based_Connection_Response.result_name(response.result))
)
self.change_state(self.CONNECTION_ERROR)
# Cleanup
self.connection_result = None
def on_credits(self, credits):
self.credits += credits
logger.debug(f'received {credits} credits, total = {self.credits}')
# Try to send more data if we have any queued up
self.process_output()
def on_disconnection_request(self, request):
self.send_control_frame(
L2CAP_Disconnection_Response(
identifier = request.identifier,
destination_cid = request.destination_cid,
source_cid = request.source_cid
)
)
self.change_state(self.DISCONNECTED)
self.flush_output()
def on_disconnection_response(self, response):
if self.state != self.DISCONNECTING:
logger.warning(color('invalid state', 'red'))
return
if response.destination_cid != self.destination_cid or response.source_cid != self.source_cid:
logger.warning('unexpected source or destination CID')
return
self.change_state(self.DISCONNECTED)
if self.disconnection_result:
self.disconnection_result.set_result(None)
self.disconnection_result = None
def flush_output(self):
self.out_queue.clear()
self.out_sdu = None
def process_output(self):
while self.credits > 0:
if self.out_sdu is not None:
# Finish the current SDU
packet = self.out_sdu[:self.peer_mps]
self.send_pdu(packet)
self.credits -= 1
logger.debug(f'sent {len(packet)} bytes, {self.credits} credits left')
if len(packet) == len(self.out_sdu):
# We sent everything
self.out_sdu = None
else:
# Keep what's still left to send
self.out_sdu = self.out_sdu[len(packet):]
continue
elif self.out_queue:
# Create the next SDU (2 bytes header plus up to MTU bytes payload)
logger.debug(f'assembling SDU from {len(self.out_queue)} packets in output queue')
payload = b''
while self.out_queue and len(payload) < self.peer_mtu:
# We can add more data to the payload
chunk = self.out_queue[0][:self.peer_mtu - len(payload)]
payload += chunk
self.out_queue[0] = self.out_queue[0][len(chunk):]
if len(self.out_queue[0]) == 0:
# We consumed the entire buffer, remove it
self.out_queue.popleft()
logger.debug(f'packet completed, {len(self.out_queue)} left in queue')
# Construct the SDU with its header
assert len(payload) != 0
logger.debug(f'SDU complete: {len(payload)} payload bytes')
self.out_sdu = struct.pack('<H', len(payload)) + payload
else:
# Nothing left to send for now
self.drained.set()
return
def write(self, data):
if self.state != self.CONNECTED:
logger.warning('not connected, dropping data')
return
# Queue the data
self.out_queue.append(data)
self.drained.clear()
logger.debug(f'{len(data)} bytes packet queued, {len(self.out_queue)} packets in queue')
# Send what we can
self.process_output()
async def drain(self):
await self.drained.wait()
def pause_reading(self):
# TODO: not implemented yet
pass
def resume_reading(self):
# TODO: not implemented yet
pass
def __str__(self):
return f'CoC({self.source_cid}->{self.destination_cid}, State={self.state_name(self.state)}, PSM={self.le_psm}, MTU={self.mtu}/{self.peer_mtu}, MPS={self.mps}/{self.peer_mps}, credits={self.credits}/{self.peer_credits})'
# -----------------------------------------------------------------------------
class ChannelManager:
def __init__(self, extended_features=None, connectionless_mtu=1024):
self.host = None
self.channels = {} # Channels, mapped by connection and cid
# Fixed channel handlers, mapped by cid
self.fixed_channels = {
L2CAP_SIGNALING_CID: None, L2CAP_LE_SIGNALING_CID: None}
def __init__(self, extended_features=[], connectionless_mtu=L2CAP_DEFAULT_CONNECTIONLESS_MTU):
self._host = None
self.identifiers = {} # Incrementing identifier values by connection
self.channels = {} # All channels, mapped by connection and source cid
self.fixed_channels = { # Fixed channel handlers, mapped by cid
L2CAP_SIGNALING_CID: None, L2CAP_LE_SIGNALING_CID: None
}
self.servers = {} # Servers accepting connections, by PSM
self.extended_features = [] if extended_features is None else extended_features
self.le_coc_channels = {} # LE CoC channels, mapped by connection and destination cid
self.le_coc_servers = {} # LE CoC - Servers accepting connections, by PSM
self.le_coc_requests = {} # LE CoC connection requests, by identifier
self.extended_features = extended_features
self.connectionless_mtu = connectionless_mtu
@property
def host(self):
return self._host
@host.setter
def host(self, host):
if self._host is not None:
self._host.remove_listener('disconnection', self.on_disconnection)
self._host = host
if host is not None:
host.add_listener('disconnection', self.on_disconnection)
def find_channel(self, connection_handle, cid):
if connection_channels := self.channels.get(connection_handle):
return connection_channels.get(cid)
def find_le_coc_channel(self, connection_handle, cid):
if connection_channels := self.le_coc_channels.get(connection_handle):
return connection_channels.get(cid)
@staticmethod
def find_free_br_edr_cid(channels):
# Pick the smallest valid CID that's not already in the list
@@ -853,6 +1243,24 @@ class ChannelManager:
if cid not in channels:
return cid
@staticmethod
def find_free_le_cid(channels):
# Pick the smallest valid CID that's not already in the list
# (not necessarily the most efficient algorithm, but the list of CID is
# very small in practice)
for cid in range(L2CAP_LE_U_DYNAMIC_CID_RANGE_START, L2CAP_LE_U_DYNAMIC_CID_RANGE_END + 1):
if cid not in channels:
return cid
@staticmethod
def check_le_coc_parameters(max_credits, mtu, mps):
if max_credits < 1 or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS:
raise ValueError('max credits out of range')
if mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
raise ValueError('MTU too small')
if mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS or mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS:
raise ValueError('MPS out of range')
def next_identifier(self, connection):
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
self.identifiers[connection.handle] = identifier
@@ -866,8 +1274,77 @@ class ChannelManager:
del self.fixed_channels[cid]
def register_server(self, psm, server):
if psm == 0:
# Find a free PSM
for candidate in range(L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2):
if (candidate >> 8) % 2 == 1:
continue
if candidate in self.servers:
continue
psm = candidate
break
else:
raise InvalidStateError('no free PSM')
else:
# Check that the PSM isn't already in use
if psm in self.servers:
raise ValueError('PSM already in use')
# Check that the PSM is valid
if psm % 2 == 0:
raise ValueError('invalid PSM (not odd)')
check = psm >> 8
while check:
if check % 2 != 0:
raise ValueError('invalid PSM')
check >>= 8
self.servers[psm] = server
return psm
def register_le_coc_server(
self,
psm,
server,
max_credits=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
):
self.check_le_coc_parameters(max_credits, mtu, mps)
if psm == 0:
# Find a free PSM
for candidate in range(L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1):
if candidate in self.le_coc_servers:
continue
psm = candidate
break
else:
raise InvalidStateError('no free PSM')
else:
# Check that the PSM isn't already in use
if psm in self.le_coc_servers:
raise ValueError('PSM already in use')
self.le_coc_servers[psm] = (
server,
max_credits or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
mtu or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
)
return psm
def on_disconnection(self, connection_handle, reason):
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
if connection_handle in self.channels:
del self.channels[connection_handle]
if connection_handle in self.le_coc_channels:
del self.le_coc_channels[connection_handle]
if connection_handle in self.identifiers:
del self.identifiers[connection_handle]
def send_pdu(self, connection, cid, pdu):
pdu_str = pdu.hex() if type(pdu) is bytes else str(pdu)
logger.debug(f'{color(">>> Sending L2CAP PDU", "blue")} on connection [0x{connection.handle:04X}] (CID={cid}) {connection.peer_address}: {pdu_str}')
@@ -883,7 +1360,7 @@ class ChannelManager:
self.fixed_channels[cid](connection.handle, pdu)
else:
if (channel := self.find_channel(connection.handle, cid)) is None:
logger.warn(color(f'channel not found for 0x{connection.handle:04X}:{cid}', 'red'))
logger.warning(color(f'channel not found for 0x{connection.handle:04X}:{cid}', 'red'))
return
channel.on_pdu(pdu)
@@ -927,7 +1404,6 @@ class ChannelManager:
def on_l2cap_command_reject(self, connection, cid, packet):
logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}')
pass
def on_l2cap_connection_request(self, connection, cid, request):
# Check if there's a server for this PSM
@@ -959,7 +1435,7 @@ class ChannelManager:
server(channel)
channel.on_connection_request(request)
else:
logger.warn(f'No server for connection 0x{connection.handle:04X} on PSM {request.psm}')
logger.warning(f'No server for connection 0x{connection.handle:04X} on PSM {request.psm}')
self.send_control_frame(
connection,
cid,
@@ -974,35 +1450,35 @@ class ChannelManager:
def on_l2cap_connection_response(self, connection, cid, response):
if (channel := self.find_channel(connection.handle, response.source_cid)) is None:
logger.warn(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
logger.warning(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
return
channel.on_connection_response(response)
def on_l2cap_configure_request(self, connection, cid, request):
if (channel := self.find_channel(connection.handle, request.destination_cid)) is None:
logger.warn(color(f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
logger.warning(color(f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
return
channel.on_configure_request(request)
def on_l2cap_configure_response(self, connection, cid, response):
if (channel := self.find_channel(connection.handle, response.source_cid)) is None:
logger.warn(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
logger.warning(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
return
channel.on_configure_response(response)
def on_l2cap_disconnection_request(self, connection, cid, request):
if (channel := self.find_channel(connection.handle, request.destination_cid)) is None:
logger.warn(color(f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
logger.warning(color(f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
return
channel.on_disconnection_request(request)
def on_l2cap_disconnection_response(self, connection, cid, response):
if (channel := self.find_channel(connection.handle, response.source_cid)) is None:
logger.warn(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
logger.warning(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
return
channel.on_disconnection_response(response)
@@ -1076,25 +1552,123 @@ class ChannelManager:
)
def on_l2cap_connection_parameter_update_response(self, connection, cid, response):
# TODO: check response
pass
def on_l2cap_le_credit_based_connection_request(self, connection, cid, request):
# FIXME: temp fixed values
if request.le_psm in self.le_coc_servers:
(server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm]
# Check that the CID isn't already used
le_connection_channels = self.le_coc_channels.setdefault(connection.handle, {})
if request.source_cid in le_connection_channels:
logger.warning(f'source CID {request.source_cid} already in use')
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier = request.identifier,
destination_cid = 194, # FIXME: for testing only
mtu = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
initial_credits = 3, # FIXME: for testing only
destination_cid = 0,
mtu = mtu,
mps = mps,
initial_credits = 0,
result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED
)
)
return
# Find a free CID for this new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen!
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier = request.identifier,
destination_cid = 0,
mtu = mtu,
mps = mps,
initial_credits = 0,
result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
)
)
return
# Create a new channel
logger.debug(f'creating LE CoC server channel with cid={source_cid} for psm {request.le_psm}')
channel = LeConnectionOrientedChannel(
self,
connection,
request.le_psm,
source_cid,
request.source_cid,
mtu,
mps,
request.initial_credits,
request.mtu,
request.mps,
max_credits,
True
)
connection_channels[source_cid] = channel
le_connection_channels[request.source_cid] = channel
# Respond
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier = request.identifier,
destination_cid = source_cid,
mtu = mtu,
mps = mps,
initial_credits = max_credits,
result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL
)
)
def on_l2cap_le_flow_control_credit(self, connection, cid, packet):
pass
# Notify
server(channel)
else:
logger.info(f'No LE server for connection 0x{connection.handle:04X} on PSM {request.le_psm}')
self.send_control_frame(
connection,
cid,
L2CAP_LE_Credit_Based_Connection_Response(
identifier = request.identifier,
destination_cid = 0,
mtu = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
mps = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
initial_credits = 0,
result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED,
)
)
def on_l2cap_le_credit_based_connection_response(self, connection, cid, response):
# Find the pending request by identifier
request = self.le_coc_requests.get(response.identifier)
if request is None:
logger.warning(color('!!! received response for unknown request', 'red'))
return
del self.le_coc_requests[response.identifier]
# Find the channel for this request
channel = self.find_channel(connection.handle, request.source_cid)
if channel is None:
logger.warning(color(f'received connection response for an unknown channel (cid={request.source_cid})', 'red'))
return
# Process the response
channel.on_connection_response(response)
def on_l2cap_le_flow_control_credit(self, connection, cid, credit):
channel = self.find_le_coc_channel(connection.handle, credit.cid)
if channel is None:
logger.warning(f'received credits for an unknown channel (cid={credit.cid}')
return
channel.on_credits(credit.credits)
def on_channel_closed(self, channel):
connection_channels = self.channels.get(channel.connection.handle)
@@ -1102,22 +1676,65 @@ class ChannelManager:
if channel.source_cid in connection_channels:
del connection_channels[channel.source_cid]
async def connect(self, connection, psm):
# NOTE: this implementation hard-codes BR/EDR more
# TODO: LE mode (maybe?)
async def open_le_coc(self, connection, psm, max_credits, mtu, mps):
self.check_le_coc_parameters(max_credits, mtu, mps)
# Find a free CID for a new channel
# Find a free CID for the new channel
connection_channels = self.channels.setdefault(connection.handle, {})
cid = self.find_free_br_edr_cid(connection_channels)
if cid is None: # Should never happen!
source_cid = self.find_free_le_cid(connection_channels)
if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use')
# Create the channel
logger.debug(f'creating client channel with cid={cid} for psm {psm}')
channel = Channel(self, connection, L2CAP_SIGNALING_CID, psm, cid, L2CAP_MIN_BR_EDR_MTU)
connection_channels[cid] = channel
logger.debug(f'creating coc channel with cid={source_cid} for psm {psm}')
channel = LeConnectionOrientedChannel(
manager = self,
connection = connection,
le_psm = psm,
source_cid = source_cid,
destination_cid = 0,
mtu = mtu,
mps = mps,
credits = 0,
peer_mtu = 0,
peer_mps = 0,
peer_credits = max_credits,
connected = False
)
connection_channels[source_cid] = channel
# Connect
try:
await channel.connect()
except Exception as error:
logger.warning(f'connection failed: {error}')
del connection_channels[source_cid]
raise
# Remember the channel by source CID and destination CID
le_connection_channels = self.le_coc_channels.setdefault(connection.handle, {})
le_connection_channels[channel.destination_cid] = channel
return channel
async def connect(self, connection, psm):
# NOTE: this implementation hard-codes BR/EDR
# Find a free CID for a new channel
connection_channels = self.channels.setdefault(connection.handle, {})
source_cid = self.find_free_br_edr_cid(connection_channels)
if source_cid is None: # Should never happen!
raise RuntimeError('all CIDs already in use')
# Create the channel
logger.debug(f'creating client channel with cid={source_cid} for psm {psm}')
channel = Channel(self, connection, L2CAP_SIGNALING_CID, psm, source_cid, L2CAP_MIN_BR_EDR_MTU)
connection_channels[source_cid] = channel
# Connect
try:
await channel.connect()
except Exception:
del connection_channels[source_cid]
return channel

View File

@@ -274,7 +274,7 @@ class PumpedPacketSource(ParserSource):
self.terminated.set_result(error)
break
self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
self.pump_task = asyncio.create_task(pump_packets())
def close(self):
if self.pump_task:
@@ -304,7 +304,7 @@ class PumpedPacketSink:
logger.warn(f'exception while sending packet: {error}')
break
self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
self.pump_task = asyncio.create_task(pump_packets())
def close(self):
if self.pump_task:

View File

@@ -18,6 +18,7 @@
import asyncio
import logging
import traceback
import collections
from functools import wraps
from colors import color
from pyee import EventEmitter
@@ -140,3 +141,95 @@ class AsyncRunner:
return wrapper
return decorator
# -----------------------------------------------------------------------------
class FlowControlAsyncPipe:
"""
Asyncio pipe with flow control. When writing to the pipe, the source is
paused (by calling a function passed in when the pipe is created) if the
amount of queued data exceeds a specified threshold.
"""
def __init__(self, pause_source, resume_source, write_to_sink=None, drain_sink=None, threshold=0):
self.pause_source = pause_source
self.resume_source = resume_source
self.write_to_sink = write_to_sink
self.drain_sink = drain_sink
self.threshold = threshold
self.queue = collections.deque() # Queue of packets
self.queued_bytes = 0 # Number of bytes in the queue
self.ready_to_pump = asyncio.Event()
self.paused = False
self.source_paused = False
self.pump_task = None
def start(self):
if self.pump_task is None:
self.pump_task = asyncio.create_task(self.pump())
self.check_pump()
def stop(self):
if self.pump_task is not None:
self.pump_task.cancel()
self.pump_task = None
def write(self, packet):
self.queued_bytes += len(packet)
self.queue.append(packet)
# Pause the source if we're over the threshold
if self.queued_bytes > self.threshold and not self.source_paused:
logger.debug(f'pausing source (queued={self.queued_bytes})')
self.pause_source()
self.source_paused = True
self.check_pump()
def pause(self):
if not self.paused:
self.paused = True
if not self.source_paused:
self.pause_source()
self.source_paused = True
self.check_pump()
def resume(self):
if self.paused:
self.paused = False
if self.source_paused:
self.resume_source()
self.source_paused = False
self.check_pump()
def can_pump(self):
return self.queue and not self.paused and self.write_to_sink is not None
def check_pump(self):
if self.can_pump():
self.ready_to_pump.set()
else:
self.ready_to_pump.clear()
async def pump(self):
while True:
# Wait until we can try to pump packets
await self.ready_to_pump.wait()
# Try to pump a packet
if self.can_pump():
packet = self.queue.pop()
self.write_to_sink(packet)
self.queued_bytes -= len(packet)
# Drain the sink if we can
if self.drain_sink:
await self.drain_sink()
# Check if we can accept more
if self.queued_bytes <= self.threshold and self.source_paused:
logger.debug(f'resuming source (queued={self.queued_bytes})')
self.source_paused = False
self.resume_source()
self.check_pump()

5
examples/asha_sink1.json Normal file
View File

@@ -0,0 +1,5 @@
{
"name": "Bumble Aid Left",
"address": "F1:F2:F3:F4:F5:F6",
"keystore": "JsonKeyStore"
}

5
examples/asha_sink2.json Normal file
View File

@@ -0,0 +1,5 @@
{
"name": "Bumble Aid Right",
"address": "F7:F8:F9:FA:FB:FC",
"keystore": "JsonKeyStore"
}

161
examples/run_asha_sink.py Normal file
View File

@@ -0,0 +1,161 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import struct
import sys
import os
import logging
from bumble.core import AdvertisingData
from bumble.device import Device
from bumble.transport import open_transport_or_link
from bumble.hci import UUID
from bumble.gatt import (
Service,
Characteristic,
CharacteristicValue
)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID('f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint')
ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID('38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus')
ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT')
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) != 4:
print('Usage: python run_asha_sink.py <device-config> <transport-spec> <audio-file>')
print('example: python run_asha_sink.py device1.json usb:0 audio_out.g722')
return
audio_out = open(sys.argv[3], 'wb')
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
# Handler for audio control commands
def on_audio_control_point_write(connection, value):
print('--- AUDIO CONTROL POINT Write:', value.hex())
opcode = value[0]
if opcode == 1:
# Start
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
print(f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}')
elif opcode == 2:
print('### STOP')
elif opcode == 3:
print(f'### STATUS: connected={value[1]}')
# Respond with a status
asyncio.create_task(device.notify_subscribers(audio_status_characteristic, force=True))
# Handler for volume control
def on_volume_write(connection, value):
print('--- VOLUME Write:', value[0])
# Register an L2CAP CoC server
def on_coc(channel):
def on_data(data):
print('<<< Voice data received:', data.hex())
audio_out.write(data)
channel.sink = on_data
psm = device.register_l2cap_channel_server(0, on_coc, 8)
print(f'### LE_PSM_OUT = {psm}')
# Add the ASHA service to the GATT server
read_only_properties_characteristic = Characteristic(
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
bytes([
0x01, # Version
0x00, # Device Capabilities [Left, Monaural]
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, # HiSyncId
0x01, # Feature Map [LE CoC audio output streaming supported]
0x00, 0x00, # Render Delay
0x00, 0x00, # RFU
0x02, 0x00 # Codec IDs [G.722 at 16 kHz]
])
)
audio_control_point_characteristic = Characteristic(
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_audio_control_point_write)
)
audio_status_characteristic = Characteristic(
ASHA_AUDIO_STATUS_CHARACTERISTIC,
Characteristic.READ | Characteristic.NOTIFY,
Characteristic.READABLE,
bytes([0])
)
volume_characteristic = Characteristic(
ASHA_VOLUME_CHARACTERISTIC,
Characteristic.WRITE_WITHOUT_RESPONSE,
Characteristic.WRITEABLE,
CharacteristicValue(write=on_volume_write)
)
le_psm_out_characteristic = Characteristic(
ASHA_LE_PSM_OUT_CHARACTERISTIC,
Characteristic.READ,
Characteristic.READABLE,
struct.pack('<H', psm)
)
device.add_service(Service(
ASHA_SERVICE,
[
read_only_properties_characteristic,
audio_control_point_characteristic,
audio_status_characteristic,
volume_characteristic,
le_psm_out_characteristic
]
))
# Set the advertising data
device.advertising_data = bytes(
AdvertisingData([
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(device.name, 'utf-8')),
(AdvertisingData.FLAGS, bytes([0x06])),
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(ASHA_SERVICE)),
(AdvertisingData.SERVICE_DATA_16_BIT_UUID, bytes(ASHA_SERVICE) + bytes([
0x01, # Protocol Version
0x00, # Capability
0x01, 0x02, 0x03, 0x04 # Truncated HiSyncID
]))
])
)
# Go!
await device.power_on()
await device.start_advertising(auto_restart=True)
await hci_source.wait_for_termination()
# -----------------------------------------------------------------------------
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
asyncio.run(main())

View File

@@ -49,11 +49,17 @@ class Listener(Device.Listener, Connection.Listener):
)
# -----------------------------------------------------------------------------
# Alternative way to listen for subscriptions
# -----------------------------------------------------------------------------
def on_my_characteristic_subscription(peer, enabled):
print(f'### My characteristic from {peer}: {"enabled" if enabled else "disabled"}')
# -----------------------------------------------------------------------------
async def main():
if len(sys.argv) < 3:
print('Usage: run_gatt_server.py <device-config> <transport-spec>')
print('example: run_gatt_server.py device1.json usb:0')
print('Usage: run_notifier.py <device-config> <transport-spec>')
print('example: run_notifier.py device1.json usb:0')
return
print('<<< connecting to HCI...')
@@ -83,6 +89,7 @@ async def main():
Characteristic.READABLE,
bytes([0x42])
)
characteristic3.on('subscription', on_my_characteristic_subscription)
custom_service = Service(
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
[characteristic1, characteristic2, characteristic3]

View File

@@ -98,6 +98,7 @@ async def list_rfcomm_channels(device, connection):
await sdp_client.disconnect()
# -----------------------------------------------------------------------------
class TcpServerProtocol(asyncio.Protocol):
def __init__(self, rfcomm_session):
@@ -173,7 +174,7 @@ async def main():
print('*** Encryption on')
# Create a client and start it
print('@@@ Starting to RFCOMM client...')
print('@@@ Starting RFCOMM client...')
rfcomm_client = Client(device, connection)
rfcomm_mux = await rfcomm_client.start()
print('@@@ Started')
@@ -192,7 +193,7 @@ async def main():
if len(sys.argv) == 6:
# A TCP port was specified, start listening
tcp_port = int(sys.argv[5])
asyncio.get_running_loop().create_task(tcp_server(tcp_port, session))
asyncio.create_task(tcp_server(tcp_port, session))
await hci_source.wait_for_termination()

View File

@@ -51,6 +51,7 @@ console_scripts =
bumble-controller-info = bumble.apps.controller_info:main
bumble-gatt-dump = bumble.apps.gatt_dump:main
bumble-hci-bridge = bumble.apps.hci_bridge:main
bumble-l2cap-bridge = bumble.apps.l2cap_bridge:main
bumble-pair = bumble.apps.pair:main
bumble-scan = bumble.apps.scan:main
bumble-show = bumble.apps.show:main
@@ -64,6 +65,7 @@ build =
test =
pytest >= 6.2
pytest-asyncio >= 0.17
coverage >= 6.4
development =
invoke >= 1.4
nox >= 2022

284
tests/l2cap_test.py Normal file
View File

@@ -0,0 +1,284 @@
# Copyright 2021-2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
import asyncio
import logging
import os
import random
import pytest
from bumble.controller import Controller
from bumble.link import LocalLink
from bumble.device import Device
from bumble.host import Host
from bumble.transport import AsyncPipeSink
from bumble.core import ProtocolError
from bumble.l2cap import (
L2CAP_Connection_Request
)
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
class TwoDevices:
def __init__(self):
self.connections = [None, None]
self.link = LocalLink()
self.controllers = [
Controller('C1', link = self.link),
Controller('C2', link = self.link)
]
self.devices = [
Device(
address = 'F0:F1:F2:F3:F4:F5',
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
),
Device(
address = 'F5:F4:F3:F2:F1:F0',
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
)
]
self.paired = [None, None]
def on_connection(self, which, connection):
self.connections[which] = connection
def on_paired(self, which, keys):
self.paired[which] = keys
# -----------------------------------------------------------------------------
async def setup_connection():
# Create two devices, each with a controller, attached to the same link
two_devices = TwoDevices()
# Attach listeners
two_devices.devices[0].on('connection', lambda connection: two_devices.on_connection(0, connection))
two_devices.devices[1].on('connection', lambda connection: two_devices.on_connection(1, connection))
# Start
await two_devices.devices[0].power_on()
await two_devices.devices[1].power_on()
# Connect the two devices
await two_devices.devices[0].connect(two_devices.devices[1].random_address)
# Check the post conditions
assert(two_devices.connections[0] is not None)
assert(two_devices.connections[1] is not None)
return two_devices
# -----------------------------------------------------------------------------
def test_helpers():
psm = L2CAP_Connection_Request.serialize_psm(0x01)
assert(psm == bytes([0x01, 0x00]))
psm = L2CAP_Connection_Request.serialize_psm(0x1023)
assert(psm == bytes([0x23, 0x10]))
psm = L2CAP_Connection_Request.serialize_psm(0x242311)
assert(psm == bytes([0x11, 0x23, 0x24]))
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x01, 0x00, 0x44]), 1)
assert(offset == 3)
assert(psm == 0x01)
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x23, 0x10, 0x44]), 1)
assert(offset == 3)
assert(psm == 0x1023)
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1)
assert(offset == 4)
assert(psm == 0x242311)
rq = L2CAP_Connection_Request(psm = 0x01, source_cid = 0x44)
brq = bytes(rq)
srq = L2CAP_Connection_Request.from_bytes(brq)
assert(srq.psm == rq.psm)
assert(srq.source_cid == rq.source_cid)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_basic_connection():
devices = await setup_connection()
psm = 1234
# Check that if there's no one listening, we can't connect
with pytest.raises(ProtocolError):
l2cap_channel = await devices.connections[0].open_l2cap_channel(psm)
# Now add a listener
incoming_channel = None
received = []
def on_coc(channel):
nonlocal incoming_channel
incoming_channel = channel
def on_data(data):
received.append(data)
channel.sink = on_data
devices.devices[1].register_l2cap_channel_server(psm, on_coc)
l2cap_channel = await devices.connections[0].open_l2cap_channel(psm)
messages = (
bytes([1, 2, 3]),
bytes([4, 5, 6]),
bytes(10000)
)
for message in messages:
l2cap_channel.write(message)
await asyncio.sleep(0)
await l2cap_channel.drain()
# Test closing
closed = [False, False]
closed_event = asyncio.Event()
def on_close(which, event):
closed[which] = True
if event:
event.set()
l2cap_channel.on('close', lambda: on_close(0, None))
incoming_channel.on('close', lambda: on_close(1, closed_event))
await l2cap_channel.disconnect()
assert(closed == [True, True])
await closed_event.wait()
sent_bytes = b''.join(messages)
received_bytes = b''.join(received)
assert(sent_bytes == received_bytes)
# -----------------------------------------------------------------------------
async def transfer_payload(max_credits, mtu, mps):
devices = await setup_connection()
received = []
def on_coc(channel):
def on_data(data):
received.append(data)
channel.sink = on_data
psm = devices.devices[1].register_l2cap_channel_server(
psm = 0,
server = on_coc,
max_credits = max_credits,
mtu = mtu,
mps = mps
)
l2cap_channel = await devices.connections[0].open_l2cap_channel(psm)
messages = [
bytes([1, 2, 3, 4, 5, 6, 7]) * x
for x in (3, 10, 100, 500, 789)
]
for message in messages:
l2cap_channel.write(message)
await asyncio.sleep(0)
if random.randint(0, 5) == 1:
await l2cap_channel.drain()
await l2cap_channel.drain()
await l2cap_channel.disconnect()
sent_bytes = b''.join(messages)
received_bytes = b''.join(received)
assert(sent_bytes == received_bytes)
@pytest.mark.asyncio
async def test_transfer():
for max_credits in (1, 10, 100, 10000):
for mtu in (23, 24, 25, 26, 50, 200, 255, 256, 1000):
for mps in (23, 24, 25, 26, 50, 200, 255, 256, 1000):
# print(max_credits, mtu, mps)
await transfer_payload(max_credits, mtu, mps)
# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_bidirectional_transfer():
devices = await setup_connection()
client_received = []
server_received = []
server_channel = None
def on_server_coc(channel):
nonlocal server_channel
server_channel = channel
def on_server_data(data):
server_received.append(data)
channel.sink = on_server_data
def on_client_data(data):
client_received.append(data)
psm = devices.devices[1].register_l2cap_channel_server(psm=0, server=on_server_coc)
client_channel = await devices.connections[0].open_l2cap_channel(psm)
client_channel.sink = on_client_data
messages = [
bytes([1, 2, 3, 4, 5, 6, 7]) * x
for x in (3, 10, 100)
]
for message in messages:
client_channel.write(message)
await client_channel.drain()
await asyncio.sleep(0)
server_channel.write(message)
await server_channel.drain()
await client_channel.disconnect()
message_bytes = b''.join(messages)
client_received_bytes = b''.join(client_received)
server_received_bytes = b''.join(server_received)
assert(client_received_bytes == message_bytes)
assert(server_received_bytes == message_bytes)
# -----------------------------------------------------------------------------
async def run():
test_helpers()
await test_basic_connection()
await test_transfer()
await test_bidirectional_transfer()
# -----------------------------------------------------------------------------
if __name__ == '__main__':
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
asyncio.run(run())