forked from auracaster/bumble_mirror
Add L2CAP CoC support (squashed)
[85542e0] fix test [3748781] add ASAH sink example [e782e29] add app [83daa30] wip [7f138a0] add test [f732108] allow different address syntax [9d0bbf8] rename deprecated methods [eb303d5] add LE CoC support
This commit is contained in:
committed by
Gilles Boccon-Gibod
parent
be8f8ac68f
commit
ce9004f0ac
@@ -17,13 +17,14 @@
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import os
|
||||
import struct
|
||||
import logging
|
||||
import click
|
||||
from colors import color
|
||||
|
||||
from bumble.device import Device, Peer
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.gatt import Service, Characteristic
|
||||
from bumble.gatt import Service, Characteristic, CharacteristicValue
|
||||
from bumble.utils import AsyncRunner
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.hci import HCI_Constant
|
||||
@@ -41,13 +42,59 @@ GG_PREFERRED_MTU = 256
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class GattlinkHubBridge(Device.Listener):
|
||||
class GattlinkL2capEndpoint:
|
||||
def __init__(self):
|
||||
self.l2cap_channel = None
|
||||
self.l2cap_packet = b''
|
||||
self.l2cap_packet_size = 0
|
||||
|
||||
# Called when an L2CAP SDU has been received
|
||||
def on_coc_sdu(self, sdu):
|
||||
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
|
||||
while len(sdu):
|
||||
if self.l2cap_packet_size == 0:
|
||||
# Expect a new packet
|
||||
self.l2cap_packet_size = sdu[0] + 1
|
||||
sdu = sdu[1:]
|
||||
else:
|
||||
bytes_needed = self.l2cap_packet_size - len(self.l2cap_packet)
|
||||
chunk = min(bytes_needed, len(sdu))
|
||||
self.l2cap_packet += sdu[:chunk]
|
||||
sdu = sdu[chunk:]
|
||||
if len(self.l2cap_packet) == self.l2cap_packet_size:
|
||||
self.on_l2cap_packet(self.l2cap_packet)
|
||||
self.l2cap_packet = b''
|
||||
self.l2cap_packet_size = 0
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener):
|
||||
def __init__(self, device, peer_address):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.peer_address = peer_address
|
||||
self.peer = None
|
||||
self.rx_socket = None
|
||||
self.tx_socket = None
|
||||
self.rx_characteristic = None
|
||||
self.tx_characteristic = None
|
||||
self.l2cap_psm_characteristic = None
|
||||
|
||||
device.listener = self
|
||||
|
||||
async def start(self):
|
||||
# Connect to the peer
|
||||
print(f'=== Connecting to {self.peer_address}...')
|
||||
await self.device.connect(self.peer_address)
|
||||
|
||||
async def connect_l2cap(self, psm):
|
||||
print(color(f'### Connecting with L2CAP on PSM = {psm}', 'yellow'))
|
||||
try:
|
||||
self.l2cap_channel = await self.peer.connection.open_l2cap_channel(psm)
|
||||
print(color('*** Connected', 'yellow'), self.l2cap_channel)
|
||||
self.l2cap_channel.sink = self.on_coc_sdu
|
||||
|
||||
except Exception as error:
|
||||
print(color(f'!!! Connection failed: {error}', 'red'))
|
||||
|
||||
@AsyncRunner.run_in_task()
|
||||
async def on_connection(self, connection):
|
||||
@@ -80,15 +127,24 @@ class GattlinkHubBridge(Device.Listener):
|
||||
self.rx_characteristic = characteristic
|
||||
elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID:
|
||||
self.tx_characteristic = characteristic
|
||||
elif characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID:
|
||||
self.l2cap_psm_characteristic = characteristic
|
||||
print('RX:', self.rx_characteristic)
|
||||
print('TX:', self.tx_characteristic)
|
||||
print('PSM:', self.l2cap_psm_characteristic)
|
||||
|
||||
if self.l2cap_psm_characteristic:
|
||||
# Subscribe to and then read the PSM value
|
||||
await self.peer.subscribe(self.l2cap_psm_characteristic, self.on_l2cap_psm_received)
|
||||
psm_bytes = await self.peer.read_value(self.l2cap_psm_characteristic)
|
||||
psm = struct.unpack('<H', psm_bytes)[0]
|
||||
await self.connect_l2cap(psm)
|
||||
elif self.tx_characteristic:
|
||||
# Subscribe to TX
|
||||
if self.tx_characteristic:
|
||||
await self.peer.subscribe(self.tx_characteristic, self.on_tx_received)
|
||||
print(color('=== Subscribed to Gattlink TX', 'yellow'))
|
||||
else:
|
||||
print(color('!!! Gattlink TX not found', 'red'))
|
||||
print(color('!!! No Gattlink TX or PSM found', 'red'))
|
||||
|
||||
def on_connection_failure(self, error):
|
||||
print(color(f'!!! Connection failed: {error}'))
|
||||
@@ -99,31 +155,23 @@ class GattlinkHubBridge(Device.Listener):
|
||||
self.rx_characteristic = None
|
||||
self.peer = None
|
||||
|
||||
# Called when an L2CAP packet has been received
|
||||
def on_l2cap_packet(self, packet):
|
||||
print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
|
||||
print(color('>>> [UDP]', 'magenta'))
|
||||
self.tx_socket.sendto(packet)
|
||||
|
||||
# Called by the GATT client when a notification is received
|
||||
def on_tx_received(self, value):
|
||||
print(color('>>> TX:', 'magenta'), value.hex())
|
||||
print(color(f'<<< [GATT TX]: {len(value)} bytes', 'cyan'))
|
||||
if self.tx_socket:
|
||||
print(color('>>> [UDP]', 'magenta'))
|
||||
self.tx_socket.sendto(value)
|
||||
|
||||
# Called by asyncio when the UDP socket is created
|
||||
def connection_made(self, transport):
|
||||
pass
|
||||
|
||||
# Called by asyncio when a UDP datagram is received
|
||||
def datagram_received(self, data, address):
|
||||
print(color('<<< RX:', 'magenta'), data.hex())
|
||||
|
||||
# TODO: use a queue instead of creating a task everytime
|
||||
if self.peer and self.rx_characteristic:
|
||||
asyncio.create_task(self.peer.write_value(self.rx_characteristic, data))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class GattlinkNodeBridge(Device.Listener):
|
||||
def __init__(self):
|
||||
self.peer = None
|
||||
self.rx_socket = None
|
||||
self.tx_socket = None
|
||||
def on_l2cap_psm_received(self, value):
|
||||
psm = struct.unpack('<H', value)[0]
|
||||
asyncio.create_task(self.connect_l2cap(psm))
|
||||
|
||||
# Called by asyncio when the UDP socket is created
|
||||
def connection_made(self, transport):
|
||||
@@ -131,21 +179,130 @@ class GattlinkNodeBridge(Device.Listener):
|
||||
|
||||
# Called by asyncio when a UDP datagram is received
|
||||
def datagram_received(self, data, address):
|
||||
print(color('<<< RX:', 'magenta'), data.hex())
|
||||
print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
|
||||
|
||||
# TODO: use a queue instead of creating a task everytime
|
||||
if self.peer and self.rx_characteristic:
|
||||
if self.l2cap_channel:
|
||||
print(color('>>> [L2CAP]', 'yellow'))
|
||||
self.l2cap_channel.write(bytes([len(data) - 1]) + data)
|
||||
elif self.peer and self.rx_characteristic:
|
||||
print(color('>>> [GATT RX]', 'yellow'))
|
||||
asyncio.create_task(self.peer.write_value(self.rx_characteristic, data))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run(hci_transport, device_address, send_host, send_port, receive_host, receive_port):
|
||||
class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.peer = None
|
||||
self.tx_socket = None
|
||||
self.tx_subscriber = None
|
||||
self.rx_characteristic = None
|
||||
|
||||
# Register as a listener
|
||||
device.listener = self
|
||||
|
||||
# Listen for incoming L2CAP CoC connections
|
||||
psm = 0xFB
|
||||
device.register_l2cap_channel_server(0xFB, self.on_coc)
|
||||
print(f'### Listening for CoC connection on PSM {psm}')
|
||||
|
||||
# Setup the Gattlink service
|
||||
self.rx_characteristic = Characteristic(
|
||||
GG_GATTLINK_RX_CHARACTERISTIC_UUID,
|
||||
Characteristic.WRITE_WITHOUT_RESPONSE,
|
||||
Characteristic.WRITEABLE,
|
||||
CharacteristicValue(write=self.on_rx_write)
|
||||
)
|
||||
self.tx_characteristic = Characteristic(
|
||||
GG_GATTLINK_TX_CHARACTERISTIC_UUID,
|
||||
Characteristic.NOTIFY,
|
||||
Characteristic.READABLE
|
||||
)
|
||||
self.tx_characteristic.on('subscription', self.on_tx_subscription)
|
||||
self.psm_characteristic = Characteristic(
|
||||
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
|
||||
Characteristic.READ | Characteristic.NOTIFY,
|
||||
Characteristic.READABLE,
|
||||
bytes([psm, 0])
|
||||
)
|
||||
gattlink_service = Service(
|
||||
GG_GATTLINK_SERVICE_UUID,
|
||||
[
|
||||
self.rx_characteristic,
|
||||
self.tx_characteristic,
|
||||
self.psm_characteristic
|
||||
]
|
||||
)
|
||||
device.add_services([gattlink_service])
|
||||
device.advertising_data = bytes(
|
||||
AdvertisingData([
|
||||
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
|
||||
(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS,
|
||||
bytes(reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))))
|
||||
])
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
await self.device.start_advertising()
|
||||
|
||||
# Called by asyncio when the UDP socket is created
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
|
||||
# Called by asyncio when a UDP datagram is received
|
||||
def datagram_received(self, data, address):
|
||||
print(color(f'<<< [UDP]: {len(data)} bytes', 'green'))
|
||||
|
||||
if self.l2cap_channel:
|
||||
print(color('>>> [L2CAP]', 'yellow'))
|
||||
self.l2cap_channel.write(bytes([len(data) - 1]) + data)
|
||||
elif self.tx_subscriber:
|
||||
print(color('>>> [GATT TX]', 'yellow'))
|
||||
self.tx_characteristic.value = data
|
||||
asyncio.create_task(self.device.notify_subscribers(self.tx_characteristic))
|
||||
|
||||
# Called when a write to the RX characteristic has been received
|
||||
def on_rx_write(self, connection, data):
|
||||
print(color(f'<<< [GATT RX]: {len(data)} bytes', 'cyan'))
|
||||
print(color('>>> [UDP]', 'magenta'))
|
||||
self.tx_socket.sendto(data)
|
||||
|
||||
# Called when the subscription to the TX characteristic has changed
|
||||
def on_tx_subscription(self, peer, enabled):
|
||||
print(f'### [GATT TX] subscription from {peer}: {"enabled" if enabled else "disabled"}')
|
||||
if enabled:
|
||||
self.tx_subscriber = peer
|
||||
else:
|
||||
self.tx_subscriber = None
|
||||
|
||||
# Called when an L2CAP packet is received
|
||||
def on_l2cap_packet(self, packet):
|
||||
print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan'))
|
||||
print(color('>>> [UDP]', 'magenta'))
|
||||
self.tx_socket.sendto(packet)
|
||||
|
||||
# Called when a new connection is established
|
||||
def on_coc(self, channel):
|
||||
print('*** CoC Connection', channel)
|
||||
self.l2cap_channel = channel
|
||||
channel.sink = self.on_coc_sdu
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port):
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
|
||||
print('<<< connected')
|
||||
|
||||
# Instantiate a bridge object
|
||||
bridge = GattlinkNodeBridge()
|
||||
device = Device.with_hci('Bumble GG', device_address, hci_source, hci_sink)
|
||||
|
||||
# Instantiate a bridge object
|
||||
if role_or_peer_address == 'node':
|
||||
bridge = GattlinkNodeBridge(device)
|
||||
else:
|
||||
bridge = GattlinkHubBridge(device, role_or_peer_address)
|
||||
|
||||
# Create a UDP to RX bridge (receive from UDP, send to RX)
|
||||
loop = asyncio.get_running_loop()
|
||||
@@ -160,35 +317,8 @@ async def run(hci_transport, device_address, send_host, send_port, receive_host,
|
||||
remote_addr=(send_host, send_port)
|
||||
)
|
||||
|
||||
# Create a device to manage the host, with a custom listener
|
||||
device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink)
|
||||
device.listener = bridge
|
||||
await device.power_on()
|
||||
|
||||
# Connect to the peer
|
||||
# print(f'=== Connecting to {device_address}...')
|
||||
# await device.connect(device_address)
|
||||
|
||||
# TODO move to class
|
||||
gattlink_service = Service(
|
||||
GG_GATTLINK_SERVICE_UUID,
|
||||
[
|
||||
Characteristic(
|
||||
GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID,
|
||||
Characteristic.READ,
|
||||
Characteristic.READABLE,
|
||||
bytes([193, 0])
|
||||
)
|
||||
]
|
||||
)
|
||||
device.add_services([gattlink_service])
|
||||
device.advertising_data = bytes(
|
||||
AdvertisingData([
|
||||
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')),
|
||||
(AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, bytes(reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8'))))
|
||||
])
|
||||
)
|
||||
await device.start_advertising()
|
||||
await bridge.start()
|
||||
|
||||
# Wait until the source terminates
|
||||
await hci_source.wait_for_termination()
|
||||
@@ -197,15 +327,16 @@ async def run(hci_transport, device_address, send_host, send_port, receive_host,
|
||||
@click.command()
|
||||
@click.argument('hci_transport')
|
||||
@click.argument('device_address')
|
||||
@click.argument('role_or_peer_address')
|
||||
@click.option('-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to')
|
||||
@click.option('-sp', '--send-port', type=int, default=9001, help='UDP port to send to')
|
||||
@click.option('-rh', '--receive-host', type=str, default='127.0.0.1', help='UDP host to receive on')
|
||||
@click.option('-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on')
|
||||
def main(hci_transport, device_address, send_host, send_port, receive_host, receive_port):
|
||||
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
asyncio.run(run(hci_transport, device_address, send_host, send_port, receive_host, receive_port))
|
||||
def main(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port):
|
||||
asyncio.run(run(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
||||
331
apps/l2cap_bridge.py
Normal file
331
apps/l2cap_bridge.py
Normal file
@@ -0,0 +1,331 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import click
|
||||
import logging
|
||||
import os
|
||||
from colors import color
|
||||
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.device import Device
|
||||
from bumble.utils import FlowControlAsyncPipe
|
||||
from bumble.hci import HCI_Constant
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ServerBridge:
|
||||
"""
|
||||
L2CAP CoC server bridge: waits for a peer to connect an L2CAP CoC channel
|
||||
on a specified PSM. When the connection is made, the bridge connects a TCP
|
||||
socket to a remote host and bridges the data in both directions, with flow
|
||||
control.
|
||||
When the L2CAP CoC channel is closed, the bridge disconnects the TCP socket
|
||||
and waits for a new L2CAP CoC channel to be connected.
|
||||
When the TCP connection is closed by the TCP server, XXXX
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
psm,
|
||||
max_credits,
|
||||
mtu,
|
||||
mps,
|
||||
tcp_host,
|
||||
tcp_port
|
||||
):
|
||||
self.psm = psm
|
||||
self.max_credits = max_credits
|
||||
self.mtu = mtu
|
||||
self.mps = mps
|
||||
self.tcp_host = tcp_host
|
||||
self.tcp_port = tcp_port
|
||||
|
||||
async def start(self, device):
|
||||
# Listen for incoming L2CAP CoC connections
|
||||
device.register_l2cap_channel_server(
|
||||
psm = self.psm,
|
||||
server = self.on_coc,
|
||||
max_credits = self.max_credits,
|
||||
mtu = self.mtu,
|
||||
mps = self.mps
|
||||
)
|
||||
print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow'))
|
||||
|
||||
def on_ble_connection(connection):
|
||||
def on_ble_disconnection(reason):
|
||||
print(color('@@@ Bluetooth disconnection:', 'red'), HCI_Constant.error_name(reason))
|
||||
|
||||
print(color('@@@ Bluetooth connection:', 'green'), connection)
|
||||
connection.on('disconnection', on_ble_disconnection)
|
||||
|
||||
device.on('connection', on_ble_connection)
|
||||
|
||||
await device.start_advertising(auto_restart=True)
|
||||
|
||||
# Called when a new L2CAP connection is established
|
||||
def on_coc(self, l2cap_channel):
|
||||
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
|
||||
|
||||
class Pipe:
|
||||
def __init__(self, bridge, l2cap_channel):
|
||||
self.bridge = bridge
|
||||
self.tcp_transport = None
|
||||
self.l2cap_channel = l2cap_channel
|
||||
|
||||
l2cap_channel.on('close', self.on_l2cap_close)
|
||||
l2cap_channel.sink = self.on_coc_sdu
|
||||
|
||||
async def connect_to_tcp(self):
|
||||
# Connect to the TCP server
|
||||
print(color(f'### Connecting to TCP {self.bridge.tcp_host}:{self.bridge.tcp_port}...', 'yellow'))
|
||||
|
||||
class TcpClientProtocol(asyncio.Protocol):
|
||||
def __init__(self, pipe):
|
||||
self.pipe = pipe
|
||||
|
||||
def connection_lost(self, error):
|
||||
print(color(f'!!! TCP connection lost: {error}', 'red'))
|
||||
if self.pipe.l2cap_channel is not None:
|
||||
asyncio.create_task(self.pipe.l2cap_channel.disconnect())
|
||||
|
||||
def data_received(self, data):
|
||||
print(f'<<< Received on TCP: {len(data)}')
|
||||
self.pipe.l2cap_channel.write(data)
|
||||
|
||||
try:
|
||||
self.tcp_transport, _ = await asyncio.get_running_loop().create_connection(
|
||||
lambda: TcpClientProtocol(self),
|
||||
host=self.bridge.tcp_host,
|
||||
port=self.bridge.tcp_port,
|
||||
)
|
||||
print(color('### Connected', 'green'))
|
||||
except Exception as error:
|
||||
print(color(f'!!! Connection failed: {error}', 'red'))
|
||||
await self.l2cap_channel.disconnect()
|
||||
|
||||
def on_l2cap_close(self):
|
||||
self.l2cap_channel = None
|
||||
if self.tcp_transport is not None:
|
||||
self.tcp_transport.close()
|
||||
|
||||
def on_coc_sdu(self, sdu):
|
||||
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
|
||||
if self.tcp_transport is None:
|
||||
print(color('!!! TCP socket not open, dropping', 'red'))
|
||||
return
|
||||
self.tcp_transport.write(sdu)
|
||||
|
||||
pipe = Pipe(self, l2cap_channel)
|
||||
|
||||
asyncio.create_task(pipe.connect_to_tcp())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ClientBridge:
|
||||
"""
|
||||
L2CAP CoC client bridge: connects to a BLE device, then waits for an inbound
|
||||
TCP connection on a specified port number. When a TCP client connects, an
|
||||
L2CAP CoC channel connection to the BLE device is established, and the data
|
||||
is bridged in both directions, with flow control.
|
||||
When the TCP connection is closed by the client, the L2CAP CoC channel is
|
||||
disconnected, but the connection to the BLE device remains, ready for a new
|
||||
TCP client to connect.
|
||||
When the L2CAP CoC channel is closed, XXXX
|
||||
"""
|
||||
|
||||
READ_CHUNK_SIZE = 4096
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
psm,
|
||||
max_credits,
|
||||
mtu,
|
||||
mps,
|
||||
address,
|
||||
tcp_host,
|
||||
tcp_port
|
||||
):
|
||||
self.psm = psm
|
||||
self.max_credits = max_credits
|
||||
self.mtu = mtu
|
||||
self.mps = mps
|
||||
self.address = address
|
||||
self.tcp_host = tcp_host
|
||||
self.tcp_port = tcp_port
|
||||
|
||||
async def start(self, device):
|
||||
print(color(f'### Connecting to {self.address}...', 'yellow'))
|
||||
connection = await device.connect(self.address)
|
||||
print(color('### Connected', 'green'))
|
||||
|
||||
# Called when the BLE connection is disconnected
|
||||
def on_ble_disconnection(reason):
|
||||
print(color('@@@ Bluetooth disconnection:', 'red'), HCI_Constant.error_name(reason))
|
||||
|
||||
connection.on('disconnection', on_ble_disconnection)
|
||||
|
||||
# Called when a TCP connection is established
|
||||
async def on_tcp_connection(reader, writer):
|
||||
peername = writer.get_extra_info('peername')
|
||||
print(color(f'<<< TCP connection from {peername}', 'magenta'))
|
||||
|
||||
def on_coc_sdu(sdu):
|
||||
print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan'))
|
||||
l2cap_to_tcp_pipe.write(sdu)
|
||||
|
||||
def on_l2cap_close():
|
||||
print(color('*** L2CAP channel closed', 'red'))
|
||||
l2cap_to_tcp_pipe.stop()
|
||||
writer.close()
|
||||
|
||||
# Connect a new L2CAP channel
|
||||
print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow'))
|
||||
try:
|
||||
l2cap_channel = await connection.open_l2cap_channel(
|
||||
psm = self.psm,
|
||||
max_credits = self.max_credits,
|
||||
mtu = self.mtu,
|
||||
mps = self.mps
|
||||
)
|
||||
print(color('*** L2CAP channel:', 'cyan'), l2cap_channel)
|
||||
except Exception as error:
|
||||
print(color(f'!!! Connection failed: {error}', 'red'))
|
||||
writer.close()
|
||||
return
|
||||
|
||||
l2cap_channel.sink = on_coc_sdu
|
||||
l2cap_channel.on('close', on_l2cap_close)
|
||||
|
||||
# Start a flow control pipe from L2CAP to TCP
|
||||
l2cap_to_tcp_pipe = FlowControlAsyncPipe(
|
||||
l2cap_channel.pause_reading,
|
||||
l2cap_channel.resume_reading,
|
||||
writer.write,
|
||||
writer.drain
|
||||
)
|
||||
l2cap_to_tcp_pipe.start()
|
||||
|
||||
# Pipe data from TCP to L2CAP
|
||||
while True:
|
||||
try:
|
||||
data = await reader.read(self.READ_CHUNK_SIZE)
|
||||
|
||||
if len(data) == 0:
|
||||
print(color('!!! End of stream', 'red'))
|
||||
await l2cap_channel.disconnect()
|
||||
return
|
||||
|
||||
print(color(f'<<< [TCP DATA]: {len(data)} bytes', 'blue'))
|
||||
l2cap_channel.write(data)
|
||||
await l2cap_channel.drain()
|
||||
except Exception as error:
|
||||
print(f'!!! Exception: {error}')
|
||||
break
|
||||
|
||||
writer.close()
|
||||
print(color('~~~ Bye bye', 'magenta'))
|
||||
|
||||
await asyncio.start_server(
|
||||
on_tcp_connection,
|
||||
host=self.tcp_host if self.tcp_host != '_' else None,
|
||||
port=self.tcp_port
|
||||
)
|
||||
print(color(f'### Listening for TCP connections on port {self.tcp_port}', 'magenta'))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run(device_config, hci_transport, bridge):
|
||||
print('<<< connecting to HCI...')
|
||||
async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink):
|
||||
print('<<< connected')
|
||||
|
||||
device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink)
|
||||
|
||||
# Let's go
|
||||
await device.power_on()
|
||||
await bridge.start(device)
|
||||
|
||||
# Wait until the transport terminates
|
||||
await hci_source.wait_for_termination()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@click.group()
|
||||
@click.pass_context
|
||||
@click.option('--device-config', help='Device configuration file', required=True)
|
||||
@click.option('--hci-transport', help='HCI transport', required=True)
|
||||
@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234)
|
||||
@click.option('--l2cap-coc-max-credits', help='Maximum L2CAP CoC Credits', type=click.IntRange(1, 65535), default=128)
|
||||
@click.option('--l2cap-coc-mtu', help='L2CAP CoC MTU', type=click.IntRange(23, 65535), default=1022)
|
||||
@click.option('--l2cap-coc-mps', help='L2CAP CoC MPS', type=click.IntRange(23, 65533), default=1024)
|
||||
def cli(context, device_config, hci_transport, psm, l2cap_coc_max_credits, l2cap_coc_mtu, l2cap_coc_mps):
|
||||
context.ensure_object(dict)
|
||||
context.obj['device_config'] = device_config
|
||||
context.obj['hci_transport'] = hci_transport
|
||||
context.obj['psm'] = psm
|
||||
context.obj['max_credits'] = l2cap_coc_max_credits
|
||||
context.obj['mtu'] = l2cap_coc_mtu
|
||||
context.obj['mps'] = l2cap_coc_mps
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@cli.command()
|
||||
@click.pass_context
|
||||
@click.option('--tcp-host', help='TCP host', default='localhost')
|
||||
@click.option('--tcp-port', help='TCP port', default=9544)
|
||||
def server(context, tcp_host, tcp_port):
|
||||
bridge = ServerBridge(
|
||||
context.obj['psm'],
|
||||
context.obj['max_credits'],
|
||||
context.obj['mtu'],
|
||||
context.obj['mps'],
|
||||
tcp_host,
|
||||
tcp_port)
|
||||
asyncio.run(run(
|
||||
context.obj['device_config'],
|
||||
context.obj['hci_transport'],
|
||||
bridge
|
||||
))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@cli.command()
|
||||
@click.pass_context
|
||||
@click.argument('bluetooth-address')
|
||||
@click.option('--tcp-host', help='TCP host', default='_')
|
||||
@click.option('--tcp-port', help='TCP port', default=9543)
|
||||
def client(context, bluetooth_address, tcp_host, tcp_port):
|
||||
bridge = ClientBridge(
|
||||
context.obj['psm'],
|
||||
context.obj['max_credits'],
|
||||
context.obj['mtu'],
|
||||
context.obj['mps'],
|
||||
bluetooth_address,
|
||||
tcp_host,
|
||||
tcp_port
|
||||
)
|
||||
asyncio.run(run(
|
||||
context.obj['device_config'],
|
||||
context.obj['hci_transport'],
|
||||
bridge
|
||||
))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper())
|
||||
if __name__ == '__main__':
|
||||
cli(obj={})
|
||||
@@ -351,7 +351,7 @@ class MediaPacketPump:
|
||||
logger.debug('pump canceled')
|
||||
|
||||
# Pump packets
|
||||
self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
|
||||
self.pump_task = asyncio.create_task(pump_packets())
|
||||
|
||||
async def stop(self):
|
||||
# Stop the pump
|
||||
@@ -1890,10 +1890,10 @@ class LocalSource(LocalStreamEndPoint, EventEmitter):
|
||||
self.configuration = configuration
|
||||
|
||||
def on_start_command(self):
|
||||
asyncio.get_running_loop().create_task(self.start())
|
||||
asyncio.create_task(self.start())
|
||||
|
||||
def on_suspend_command(self):
|
||||
asyncio.get_running_loop().create_task(self.stop())
|
||||
asyncio.create_task(self.stop())
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
@@ -43,6 +43,13 @@ logger = logging.getLogger(__name__)
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
DEVICE_MIN_SCAN_INTERVAL = 25
|
||||
DEVICE_MAX_SCAN_INTERVAL = 10240
|
||||
DEVICE_MIN_SCAN_WINDOW = 25
|
||||
DEVICE_MAX_SCAN_WINDOW = 10240
|
||||
DEVICE_MIN_LE_RSSI = -127
|
||||
DEVICE_MAX_LE_RSSI = 20
|
||||
|
||||
DEVICE_DEFAULT_ADDRESS = '00:00:00:00:00:00'
|
||||
DEVICE_DEFAULT_ADVERTISING_INTERVAL = 1000 # ms
|
||||
DEVICE_DEFAULT_ADVERTISING_DATA = ''
|
||||
@@ -62,20 +69,15 @@ DEVICE_DEFAULT_CONNECTION_MAX_LATENCY = 0
|
||||
DEVICE_DEFAULT_CONNECTION_SUPERVISION_TIMEOUT = 720 # ms
|
||||
DEVICE_DEFAULT_CONNECTION_MIN_CE_LENGTH = 0 # ms
|
||||
DEVICE_DEFAULT_CONNECTION_MAX_CE_LENGTH = 0 # ms
|
||||
|
||||
DEVICE_MIN_SCAN_INTERVAL = 25
|
||||
DEVICE_MAX_SCAN_INTERVAL = 10240
|
||||
DEVICE_MIN_SCAN_WINDOW = 25
|
||||
DEVICE_MAX_SCAN_WINDOW = 10240
|
||||
DEVICE_MIN_LE_RSSI = -127
|
||||
DEVICE_MAX_LE_RSSI = 20
|
||||
DEVICE_DEFAULT_L2CAP_COC_MTU = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU
|
||||
DEVICE_DEFAULT_L2CAP_COC_MPS = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
|
||||
DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Classes
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class Advertisement:
|
||||
TX_POWER_NOT_AVAILABLE = HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE
|
||||
@@ -429,6 +431,15 @@ class Connection(CompositeEventEmitter):
|
||||
def create_l2cap_connector(self, psm):
|
||||
return self.device.create_l2cap_connector(self, psm)
|
||||
|
||||
async def open_l2cap_channel(
|
||||
self,
|
||||
psm,
|
||||
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
|
||||
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
|
||||
mps=DEVICE_DEFAULT_L2CAP_COC_MPS
|
||||
):
|
||||
return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps)
|
||||
|
||||
async def disconnect(self, reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR):
|
||||
return await self.device.disconnect(self, reason)
|
||||
|
||||
@@ -563,6 +574,7 @@ class DeviceConfiguration:
|
||||
with open(filename, 'r') as file:
|
||||
self.load_from_dict(json.load(file))
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Decorators used with the following Device class
|
||||
# (we define them outside of the Device class, because defining decorators
|
||||
@@ -785,15 +797,35 @@ class Device(CompositeEventEmitter):
|
||||
if transport is None or connection.transport == transport:
|
||||
return connection
|
||||
|
||||
def register_l2cap_server(self, psm, server):
|
||||
self.l2cap_channel_manager.register_server(psm, server)
|
||||
|
||||
def create_l2cap_connector(self, connection, psm):
|
||||
return lambda: self.l2cap_channel_manager.connect(connection, psm)
|
||||
|
||||
def create_l2cap_registrar(self, psm):
|
||||
return lambda handler: self.register_l2cap_server(psm, handler)
|
||||
|
||||
def register_l2cap_server(self, psm, server):
|
||||
self.l2cap_channel_manager.register_server(psm, server)
|
||||
|
||||
def register_l2cap_channel_server(
|
||||
self,
|
||||
psm,
|
||||
server,
|
||||
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
|
||||
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
|
||||
mps=DEVICE_DEFAULT_L2CAP_COC_MPS
|
||||
):
|
||||
return self.l2cap_channel_manager.register_le_coc_server(psm, server, max_credits, mtu, mps)
|
||||
|
||||
async def open_l2cap_channel(
|
||||
self,
|
||||
connection,
|
||||
psm,
|
||||
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
|
||||
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
|
||||
mps=DEVICE_DEFAULT_L2CAP_COC_MPS
|
||||
):
|
||||
return await self.l2cap_channel_manager.open_le_coc(connection, psm, max_credits, mtu, mps)
|
||||
|
||||
def send_l2cap_pdu(self, connection_handle, cid, pdu):
|
||||
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
|
||||
|
||||
@@ -1185,13 +1217,15 @@ class Device(CompositeEventEmitter):
|
||||
def on_connection(connection):
|
||||
if transport == BT_LE_TRANSPORT or (
|
||||
# match BR/EDR connection event against peer address
|
||||
connection.transport == transport and connection.peer_address == peer_address):
|
||||
connection.transport == transport and connection.peer_address == peer_address
|
||||
):
|
||||
pending_connection.set_result(connection)
|
||||
|
||||
def on_connection_failure(error):
|
||||
if transport == BT_LE_TRANSPORT or (
|
||||
# match BR/EDR connection failure event against peer address
|
||||
error.transport == transport and error.peer_address == peer_address):
|
||||
error.transport == transport and error.peer_address == peer_address
|
||||
):
|
||||
pending_connection.set_exception(error)
|
||||
|
||||
# Create a future so that we can wait for the connection's result
|
||||
@@ -1349,8 +1383,7 @@ class Device(CompositeEventEmitter):
|
||||
try:
|
||||
# Wait for a request or a completed connection
|
||||
result = await (asyncio.wait_for(pending_request, timeout) if timeout else pending_request)
|
||||
|
||||
except:
|
||||
except Exception:
|
||||
# Remove future from device context
|
||||
if peer_address == Address.ANY:
|
||||
self.classic_pending_accepts[Address.ANY].remove(pending_request)
|
||||
@@ -1710,26 +1743,32 @@ class Device(CompositeEventEmitter):
|
||||
connection.remove_listener('connection_encryption_failure', on_encryption_failure)
|
||||
|
||||
# [Classic only]
|
||||
async def request_remote_name(self, remote: Connection | Address):
|
||||
async def request_remote_name(self, remote): # remote: Connection | Address
|
||||
# Set up event handlers
|
||||
pending_name = asyncio.get_running_loop().create_future()
|
||||
|
||||
if type(remote) == Address:
|
||||
peer_address = remote
|
||||
handler = self.on('remote_name',
|
||||
handler = self.on(
|
||||
'remote_name',
|
||||
lambda address, remote_name:
|
||||
pending_name.set_result(remote_name) if address == remote else None)
|
||||
failure_handler = self.on('remote_name_failure',
|
||||
pending_name.set_result(remote_name) if address == remote else None
|
||||
)
|
||||
failure_handler = self.on(
|
||||
'remote_name_failure',
|
||||
lambda address, error_code:
|
||||
pending_name.set_exception(HCI_Error(error_code)) if address == remote else None)
|
||||
pending_name.set_exception(HCI_Error(error_code)) if address == remote else None
|
||||
)
|
||||
else:
|
||||
peer_address = remote.peer_address
|
||||
handler = remote.on('remote_name',
|
||||
lambda:
|
||||
pending_name.set_result(remote.peer_name))
|
||||
failure_handler = remote.on('remote_name_failure',
|
||||
lambda error_code:
|
||||
pending_name.set_exception(HCI_Error(error_code)))
|
||||
handler = remote.on(
|
||||
'remote_name',
|
||||
lambda: pending_name.set_result(remote.peer_name)
|
||||
)
|
||||
failure_handler = remote.on(
|
||||
'remote_name_failure',
|
||||
lambda error_code: pending_name.set_exception(HCI_Error(error_code))
|
||||
)
|
||||
|
||||
try:
|
||||
result = await self.send_command(
|
||||
@@ -2097,7 +2136,6 @@ class Device(CompositeEventEmitter):
|
||||
else:
|
||||
self.emit('remote_name_failure', address, error)
|
||||
|
||||
|
||||
# [Classic only]
|
||||
@host_event_handler
|
||||
@try_with_connection_from_address
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
import asyncio
|
||||
import types
|
||||
import logging
|
||||
from pyee import EventEmitter
|
||||
from colors import color
|
||||
|
||||
from .core import *
|
||||
|
||||
@@ -273,7 +273,7 @@ class Client:
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
# Unexpected end
|
||||
logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
|
||||
logger.warning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
|
||||
# TODO raise appropriate exception
|
||||
return
|
||||
break
|
||||
@@ -337,7 +337,7 @@ class Client:
|
||||
if response.op_code == ATT_ERROR_RESPONSE:
|
||||
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
|
||||
# Unexpected end
|
||||
logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
|
||||
logger.warning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
|
||||
# TODO raise appropriate exception
|
||||
return
|
||||
break
|
||||
|
||||
@@ -155,7 +155,7 @@ class Server(EventEmitter):
|
||||
return cccd or bytes([0, 0])
|
||||
|
||||
def write_cccd(self, connection, characteristic, value):
|
||||
logger.debug(f'Subscription update for connection={connection.handle:04X}, handle={characteristic.handle:04X}: {value.hex()}')
|
||||
logger.debug(f'Subscription update for connection=0x{connection.handle:04X}, handle=0x{characteristic.handle:04X}: {value.hex()}')
|
||||
|
||||
# Sanity check
|
||||
if len(value) != 2:
|
||||
@@ -204,7 +204,7 @@ class Server(EventEmitter):
|
||||
logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}')
|
||||
self.send_gatt_pdu(connection.handle, bytes(notification))
|
||||
|
||||
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
|
||||
async def indicate_subscriber(self, connection, attribute, force=False):
|
||||
# Check if there's a subscriber
|
||||
if not force:
|
||||
subscribers = self.subscribers.get(connection.handle)
|
||||
|
||||
@@ -2466,9 +2466,10 @@ class HCI_Write_Voice_Setting_Command(HCI_Command):
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@HCI_Command.command()
|
||||
class HCI_Read_Synchronous_Flow_Control_Enable_Command(HCI_Command):
|
||||
'''
|
||||
See Bluetooth spec @ 7.3.36 Write Synchronous Flow Control Enable Command
|
||||
See Bluetooth spec @ 7.3.36 Read Synchronous Flow Control Enable Command
|
||||
'''
|
||||
|
||||
|
||||
|
||||
@@ -79,6 +79,8 @@ class Host(EventEmitter):
|
||||
self.local_version = None
|
||||
self.local_supported_commands = bytes(64)
|
||||
self.local_le_features = 0
|
||||
self.suggested_max_tx_octets = 251 # Max allowed
|
||||
self.suggested_max_tx_time = 2120 # Max allowed
|
||||
self.command_semaphore = asyncio.Semaphore(1)
|
||||
self.long_term_key_provider = None
|
||||
self.link_key_provider = None
|
||||
@@ -138,6 +140,22 @@ class Host(EventEmitter):
|
||||
f'hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}'
|
||||
)
|
||||
|
||||
if (
|
||||
self.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND) and
|
||||
self.supports_command(HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND)
|
||||
):
|
||||
response = await self.send_command(HCI_LE_Read_Suggested_Default_Data_Length_Command())
|
||||
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets
|
||||
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
|
||||
if (
|
||||
suggested_max_tx_octets != self.suggested_max_tx_octets or
|
||||
suggested_max_tx_time != self.suggested_max_tx_time
|
||||
):
|
||||
await self.send_command(HCI_LE_Write_Suggested_Default_Data_Length_Command(
|
||||
suggested_max_tx_octets = self.suggested_max_tx_octets,
|
||||
suggested_max_tx_time = self.suggested_max_tx_time
|
||||
))
|
||||
|
||||
self.reset_done = True
|
||||
|
||||
@property
|
||||
|
||||
719
bumble/l2cap.py
719
bumble/l2cap.py
@@ -19,6 +19,7 @@ import asyncio
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from collections import deque
|
||||
from colors import color
|
||||
from pyee import EventEmitter
|
||||
|
||||
@@ -43,13 +44,23 @@ L2CAP_MIN_BR_EDR_MTU = 48
|
||||
|
||||
L2CAP_DEFAULT_MTU = 2048 # Default value for the MTU we are willing to accept
|
||||
|
||||
L2CAP_DEFAULT_CONNECTIONLESS_MTU = 1024
|
||||
|
||||
# See Bluetooth spec @ Vol 3, Part A - Table 2.1: CID name space on ACL-U, ASB-U, and AMP-U logical links
|
||||
L2CAP_ACL_U_DYNAMIC_CID_RANGE_START = 0x0040
|
||||
L2CAP_ACL_U_DYNAMIC_CID_RANGE_END = 0xFFFF
|
||||
|
||||
# See Bluetooth spec @ Vol 3, Part A - Table 2.2: CID name space on LE-U logical link
|
||||
L2CAP_LE_U_DYNAMIC_CID_RANGE_START = 0x0040
|
||||
L2CAP_LE_U_DYNAMIC_CID_RANGE_START = 0x007F
|
||||
L2CAP_LE_U_DYNAMIC_CID_RANGE_END = 0x007F
|
||||
|
||||
# PSM Range - See Bluetooth spec @ Vol 3, Part A / Table 4.5: PSM ranges and usage
|
||||
L2CAP_PSM_DYNAMIC_RANGE_START = 0x1001
|
||||
L2CAP_PSM_DYNAMIC_RANGE_END = 0xFFFF
|
||||
|
||||
# LE PSM Ranges - See Bluetooth spec @ Vol 3, Part A / Table 4.19: LE Credit Based Connection Request LE_PSM ranges
|
||||
L2CAP_LE_PSM_DYNAMIC_RANGE_START = 0x0080
|
||||
L2CAP_LE_PSM_DYNAMIC_RANGE_END = 0x00FF
|
||||
|
||||
# Frame types
|
||||
L2CAP_COMMAND_REJECT = 0x01
|
||||
@@ -107,8 +118,13 @@ L2CAP_COMMAND_NOT_UNDERSTOOD_REASON = 0x0000
|
||||
L2CAP_SIGNALING_MTU_EXCEEDED_REASON = 0x0001
|
||||
L2CAP_INVALID_CID_IN_REQUEST_REASON = 0x0002
|
||||
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2046
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048
|
||||
L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256
|
||||
|
||||
L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE = 0x01
|
||||
|
||||
@@ -172,7 +188,7 @@ class L2CAP_Control_Frame:
|
||||
self.identifier = pdu[1]
|
||||
length = struct.unpack_from('<H', pdu, 2)[0]
|
||||
if length + 4 != len(pdu):
|
||||
logger.warn(color(f'!!! length mismatch: expected {len(pdu) - 4} but got {length}', 'red'))
|
||||
logger.warning(color(f'!!! length mismatch: expected {len(pdu) - 4} but got {length}', 'red'))
|
||||
if hasattr(self, 'fields'):
|
||||
self.init_from_bytes(pdu, 4)
|
||||
return self
|
||||
@@ -268,7 +284,10 @@ class L2CAP_Command_Reject(L2CAP_Control_Frame):
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@L2CAP_Control_Frame.subclass([
|
||||
('psm', 2),
|
||||
('psm', {
|
||||
'parser': lambda data, offset: L2CAP_Connection_Request.parse_psm(data, offset),
|
||||
'serializer': lambda value: L2CAP_Connection_Request.serialize_psm(value)
|
||||
}),
|
||||
('source_cid', 2)
|
||||
])
|
||||
class L2CAP_Connection_Request(L2CAP_Control_Frame):
|
||||
@@ -276,6 +295,28 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame):
|
||||
See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST
|
||||
'''
|
||||
|
||||
@staticmethod
|
||||
def parse_psm(data, offset=0):
|
||||
psm_length = 2
|
||||
psm = data[offset] | data[offset + 1] << 8
|
||||
|
||||
# The PSM field extends until the first even octet (inclusive)
|
||||
while data[offset + psm_length - 1] % 2 == 1:
|
||||
psm |= data[offset + psm_length] << (8 * psm_length)
|
||||
psm_length += 1
|
||||
|
||||
return offset + psm_length, psm
|
||||
|
||||
@staticmethod
|
||||
def serialize_psm(psm):
|
||||
serialized = struct.pack('<H', psm & 0xFFFF)
|
||||
psm >>= 16
|
||||
while psm:
|
||||
serialized += bytes([psm & 0xFF])
|
||||
psm >>= 8
|
||||
|
||||
return serialized
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@L2CAP_Control_Frame.subclass([
|
||||
@@ -298,7 +339,7 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame):
|
||||
CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED = 0x0007
|
||||
CONNECTION_REFUSED_UNACCEPTABLE_PARAMETERS = 0x000B
|
||||
|
||||
CONNECTION_RESULT_NAMES = {
|
||||
RESULT_NAMES = {
|
||||
CONNECTION_SUCCESSFUL: 'CONNECTION_SUCCESSFUL',
|
||||
CONNECTION_PENDING: 'CONNECTION_PENDING',
|
||||
CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED: 'CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED',
|
||||
@@ -311,7 +352,7 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame):
|
||||
|
||||
@staticmethod
|
||||
def result_name(result):
|
||||
return name_or_number(L2CAP_Connection_Response.CONNECTION_RESULT_NAMES, result)
|
||||
return name_or_number(L2CAP_Connection_Response.RESULT_NAMES, result)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -521,7 +562,7 @@ class L2CAP_LE_Credit_Based_Connection_Response(L2CAP_Control_Frame):
|
||||
CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED = 0x000A
|
||||
CONNECTION_REFUSED_UNACCEPTABLE_PARAMETERS = 0x000B
|
||||
|
||||
CONNECTION_RESULT_NAMES = {
|
||||
RESULT_NAMES = {
|
||||
CONNECTION_SUCCESSFUL: 'CONNECTION_SUCCESSFUL',
|
||||
CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED: 'CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED',
|
||||
CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE: 'CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE',
|
||||
@@ -536,7 +577,7 @@ class L2CAP_LE_Credit_Based_Connection_Response(L2CAP_Control_Frame):
|
||||
|
||||
@staticmethod
|
||||
def result_name(result):
|
||||
return name_or_number(L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_RESULT_NAMES, result)
|
||||
return name_or_number(L2CAP_LE_Credit_Based_Connection_Response.RESULT_NAMES, result)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -619,6 +660,9 @@ class Channel(EventEmitter):
|
||||
def send_pdu(self, pdu):
|
||||
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
||||
|
||||
def send_control_frame(self, frame):
|
||||
self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
|
||||
|
||||
async def send_request(self, request):
|
||||
# Check that there isn't already a request pending
|
||||
if self.response:
|
||||
@@ -637,15 +681,16 @@ class Channel(EventEmitter):
|
||||
elif self.sink:
|
||||
self.sink(pdu)
|
||||
else:
|
||||
logger.warn(color('received pdu without a pending request or sink', 'red'))
|
||||
|
||||
def send_control_frame(self, frame):
|
||||
self.manager.send_control_frame(self.connection, self.signaling_cid, frame)
|
||||
logger.warning(color('received pdu without a pending request or sink', 'red'))
|
||||
|
||||
async def connect(self):
|
||||
if self.state != Channel.CLOSED:
|
||||
raise InvalidStateError('invalid state')
|
||||
|
||||
# Check that we can start a new connection
|
||||
if self.connection_result:
|
||||
raise RuntimeError('connection already pending')
|
||||
|
||||
self.change_state(Channel.WAIT_CONNECT_RSP)
|
||||
self.send_control_frame(
|
||||
L2CAP_Connection_Request(
|
||||
@@ -657,7 +702,12 @@ class Channel(EventEmitter):
|
||||
|
||||
# Create a future to wait for the state machine to get to a success or error state
|
||||
self.connection_result = asyncio.get_running_loop().create_future()
|
||||
|
||||
# Wait for the connection to succeed or fail
|
||||
try:
|
||||
return await self.connection_result
|
||||
finally:
|
||||
self.connection_result = None
|
||||
|
||||
async def disconnect(self):
|
||||
if self.state != Channel.OPEN:
|
||||
@@ -708,7 +758,7 @@ class Channel(EventEmitter):
|
||||
|
||||
def on_connection_response(self, response):
|
||||
if self.state != Channel.WAIT_CONNECT_RSP:
|
||||
logger.warn(color('invalid state', 'red'))
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
return
|
||||
|
||||
if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL:
|
||||
@@ -734,7 +784,7 @@ class Channel(EventEmitter):
|
||||
self.state != Channel.WAIT_CONFIG_REQ and
|
||||
self.state != Channel.WAIT_CONFIG_REQ_RSP
|
||||
):
|
||||
logger.warn(color('invalid state', 'red'))
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
return
|
||||
|
||||
# Decode the options
|
||||
@@ -750,7 +800,7 @@ class Channel(EventEmitter):
|
||||
source_cid = self.destination_cid,
|
||||
flags = 0x0000,
|
||||
result = L2CAP_Configure_Response.SUCCESS,
|
||||
options = request.options # TODO: don't accept everthing blindly
|
||||
options = request.options # TODO: don't accept everything blindly
|
||||
)
|
||||
)
|
||||
if self.state == Channel.WAIT_CONFIG:
|
||||
@@ -777,7 +827,7 @@ class Channel(EventEmitter):
|
||||
self.connection_result = None
|
||||
self.emit('open')
|
||||
else:
|
||||
logger.warn(color('invalid state', 'red'))
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
elif response.result == L2CAP_Configure_Response.FAILURE_UNACCEPTABLE_PARAMETERS:
|
||||
# Re-configure with what's suggested in the response
|
||||
self.send_control_frame(
|
||||
@@ -789,7 +839,7 @@ class Channel(EventEmitter):
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warn(color(f'!!! configuration rejected: {L2CAP_Configure_Response.result_name(response.result)}', 'red'))
|
||||
logger.warning(color(f'!!! configuration rejected: {L2CAP_Configure_Response.result_name(response.result)}', 'red'))
|
||||
# TODO: decide how to fail gracefully
|
||||
|
||||
def on_disconnection_request(self, request):
|
||||
@@ -805,15 +855,15 @@ class Channel(EventEmitter):
|
||||
self.emit('close')
|
||||
self.manager.on_channel_closed(self)
|
||||
else:
|
||||
logger.warn(color('invalid state', 'red'))
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
|
||||
def on_disconnection_response(self, response):
|
||||
if self.state != Channel.WAIT_DISCONNECT:
|
||||
logger.warn(color('invalid state', 'red'))
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
return
|
||||
|
||||
if response.destination_cid != self.destination_cid or response.source_cid != self.source_cid:
|
||||
logger.warn('unexpected source or destination CID')
|
||||
logger.warning('unexpected source or destination CID')
|
||||
return
|
||||
|
||||
self.change_state(Channel.CLOSED)
|
||||
@@ -827,23 +877,363 @@ class Channel(EventEmitter):
|
||||
return f'Channel({self.source_cid}->{self.destination_cid}, PSM={self.psm}, MTU={self.mtu}, state={Channel.STATE_NAMES[self.state]})'
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class LeConnectionOrientedChannel(EventEmitter):
|
||||
"""
|
||||
LE Credit-based Connection Oriented Channel
|
||||
"""
|
||||
|
||||
INIT = 0
|
||||
CONNECTED = 1
|
||||
CONNECTING = 2
|
||||
DISCONNECTING = 3
|
||||
DISCONNECTED = 4
|
||||
CONNECTION_ERROR = 5
|
||||
|
||||
STATE_NAMES = {
|
||||
INIT: 'INIT',
|
||||
CONNECTED: 'CONNECTED',
|
||||
CONNECTING: 'CONNECTING',
|
||||
DISCONNECTING: 'DISCONNECTING',
|
||||
DISCONNECTED: 'DISCONNECTED',
|
||||
CONNECTION_ERROR: 'CONNECTION_ERROR'
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def state_name(state):
|
||||
return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager,
|
||||
connection,
|
||||
le_psm,
|
||||
source_cid,
|
||||
destination_cid,
|
||||
mtu,
|
||||
mps,
|
||||
credits,
|
||||
peer_mtu,
|
||||
peer_mps,
|
||||
peer_credits,
|
||||
connected
|
||||
):
|
||||
super().__init__()
|
||||
self.manager = manager
|
||||
self.connection = connection
|
||||
self.le_psm = le_psm
|
||||
self.source_cid = source_cid
|
||||
self.destination_cid = destination_cid
|
||||
self.mtu = mtu
|
||||
self.mps = mps
|
||||
self.credits = credits
|
||||
self.peer_mtu = peer_mtu
|
||||
self.peer_mps = peer_mps
|
||||
self.peer_credits = peer_credits
|
||||
self.peer_max_credits = self.peer_credits
|
||||
self.peer_credits_threshold = self.peer_max_credits // 2
|
||||
self.in_sdu = None
|
||||
self.in_sdu_length = 0
|
||||
self.out_queue = deque()
|
||||
self.out_sdu = None
|
||||
self.sink = None
|
||||
self.connection_result = None
|
||||
self.disconnection_result = None
|
||||
self.drained = asyncio.Event()
|
||||
|
||||
self.drained.set()
|
||||
|
||||
if connected:
|
||||
self.state = LeConnectionOrientedChannel.CONNECTED
|
||||
else:
|
||||
self.state = LeConnectionOrientedChannel.INIT
|
||||
|
||||
def change_state(self, new_state):
|
||||
logger.debug(f'{self} state change -> {color(self.state_name(new_state), "cyan")}')
|
||||
self.state = new_state
|
||||
|
||||
if new_state == self.CONNECTED:
|
||||
self.emit('open')
|
||||
elif new_state == self.DISCONNECTED:
|
||||
self.emit('close')
|
||||
|
||||
def send_pdu(self, pdu):
|
||||
self.manager.send_pdu(self.connection, self.destination_cid, pdu)
|
||||
|
||||
def send_control_frame(self, frame):
|
||||
self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame)
|
||||
|
||||
async def connect(self):
|
||||
# Check that we're in the right state
|
||||
if self.state != self.INIT:
|
||||
raise InvalidStateError('not in a connectable state')
|
||||
|
||||
# Check that we can start a new connection
|
||||
identifier = self.manager.next_identifier(self.connection)
|
||||
if identifier in self.manager.le_coc_requests:
|
||||
raise RuntimeError('too many concurrent connection requests')
|
||||
|
||||
self.change_state(self.CONNECTING)
|
||||
request = L2CAP_LE_Credit_Based_Connection_Request(
|
||||
identifier = identifier,
|
||||
le_psm = self.le_psm,
|
||||
source_cid = self.source_cid,
|
||||
mtu = self.mtu,
|
||||
mps = self.mps,
|
||||
initial_credits = self.peer_credits
|
||||
)
|
||||
self.manager.le_coc_requests[identifier] = request
|
||||
self.send_control_frame(request)
|
||||
|
||||
# Create a future to wait for the response
|
||||
self.connection_result = asyncio.get_running_loop().create_future()
|
||||
|
||||
# Wait for the connection to succeed or fail
|
||||
return await self.connection_result
|
||||
|
||||
async def disconnect(self):
|
||||
# Check that we're connected
|
||||
if self.state != self.CONNECTED:
|
||||
raise InvalidStateError('not connected')
|
||||
|
||||
self.change_state(self.DISCONNECTING)
|
||||
self.flush_output()
|
||||
self.send_control_frame(
|
||||
L2CAP_Disconnection_Request(
|
||||
identifier = self.manager.next_identifier(self.connection),
|
||||
destination_cid = self.destination_cid,
|
||||
source_cid = self.source_cid
|
||||
)
|
||||
)
|
||||
|
||||
# Create a future to wait for the state machine to get to a success or error state
|
||||
self.disconnection_result = asyncio.get_running_loop().create_future()
|
||||
return await self.disconnection_result
|
||||
|
||||
def on_pdu(self, pdu):
|
||||
if self.sink is None:
|
||||
logger.warning('received pdu without a sink')
|
||||
return
|
||||
|
||||
if self.state != self.CONNECTED:
|
||||
logger.warning('received PDU while not connected, dropping')
|
||||
|
||||
# Manage the peer credits
|
||||
if self.peer_credits == 0:
|
||||
logger.warning('received LE frame when peer out of credits')
|
||||
else:
|
||||
self.peer_credits -= 1
|
||||
if self.peer_credits <= self.peer_credits_threshold:
|
||||
# The credits fell below the threshold, replenish them to the max
|
||||
self.send_control_frame(
|
||||
L2CAP_LE_Flow_Control_Credit(
|
||||
identifier = self.manager.next_identifier(self.connection),
|
||||
cid = self.source_cid,
|
||||
credits = self.peer_max_credits - self.peer_credits
|
||||
)
|
||||
)
|
||||
self.peer_credits = self.peer_max_credits
|
||||
|
||||
# Check if this starts a new SDU
|
||||
if self.in_sdu is None:
|
||||
# Start a new SDU
|
||||
self.in_sdu = pdu
|
||||
else:
|
||||
# Continue an SDU
|
||||
self.in_sdu += pdu
|
||||
|
||||
# Check if the SDU is complete
|
||||
if self.in_sdu_length == 0:
|
||||
# We don't know the size yet, check if we have received the header to compute it
|
||||
if len(self.in_sdu) >= 2:
|
||||
self.in_sdu_length = struct.unpack_from('<H', self.in_sdu, 0)[0]
|
||||
if self.in_sdu_length == 0:
|
||||
# We'll compute it later
|
||||
return
|
||||
if len(self.in_sdu) < 2 + self.in_sdu_length:
|
||||
# Not complete yet
|
||||
logger.debug(f'SDU: {len(self.in_sdu) - 2} of {self.in_sdu_length} bytes received')
|
||||
return
|
||||
if len(self.in_sdu) != 2 + self.in_sdu_length:
|
||||
# Overflow
|
||||
logger.warning(f'SDU overflow: sdu_length={self.in_sdu_length}, received {len(self.in_sdu) - 2}')
|
||||
# TODO: we should disconnect
|
||||
self.in_sdu = None
|
||||
self.in_sdu_length = 0
|
||||
return
|
||||
|
||||
# Send the SDU to the sink
|
||||
logger.debug(f'SDU complete: 2+{len(self.in_sdu) - 2} bytes')
|
||||
self.sink(self.in_sdu[2:])
|
||||
|
||||
# Prepare for a new SDU
|
||||
self.in_sdu = None
|
||||
self.in_sdu_length = 0
|
||||
|
||||
def on_connection_response(self, response):
|
||||
# Look for a matching pending response result
|
||||
if self.connection_result is None:
|
||||
logger.warning(f'received unexpected connection response (id={response.identifier})')
|
||||
return
|
||||
|
||||
if response.result == L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL:
|
||||
self.destination_cid = response.destination_cid
|
||||
self.peer_mtu = response.mtu
|
||||
self.peer_mps = response.mps
|
||||
self.credits = response.initial_credits
|
||||
self.connected = True
|
||||
self.connection_result.set_result(self)
|
||||
self.change_state(self.CONNECTED)
|
||||
else:
|
||||
self.connection_result.set_exception(
|
||||
ProtocolError(
|
||||
response.result,
|
||||
'l2cap',
|
||||
L2CAP_LE_Credit_Based_Connection_Response.result_name(response.result))
|
||||
)
|
||||
self.change_state(self.CONNECTION_ERROR)
|
||||
|
||||
# Cleanup
|
||||
self.connection_result = None
|
||||
|
||||
def on_credits(self, credits):
|
||||
self.credits += credits
|
||||
logger.debug(f'received {credits} credits, total = {self.credits}')
|
||||
|
||||
# Try to send more data if we have any queued up
|
||||
self.process_output()
|
||||
|
||||
def on_disconnection_request(self, request):
|
||||
self.send_control_frame(
|
||||
L2CAP_Disconnection_Response(
|
||||
identifier = request.identifier,
|
||||
destination_cid = request.destination_cid,
|
||||
source_cid = request.source_cid
|
||||
)
|
||||
)
|
||||
self.change_state(self.DISCONNECTED)
|
||||
self.flush_output()
|
||||
|
||||
def on_disconnection_response(self, response):
|
||||
if self.state != self.DISCONNECTING:
|
||||
logger.warning(color('invalid state', 'red'))
|
||||
return
|
||||
|
||||
if response.destination_cid != self.destination_cid or response.source_cid != self.source_cid:
|
||||
logger.warning('unexpected source or destination CID')
|
||||
return
|
||||
|
||||
self.change_state(self.DISCONNECTED)
|
||||
if self.disconnection_result:
|
||||
self.disconnection_result.set_result(None)
|
||||
self.disconnection_result = None
|
||||
|
||||
def flush_output(self):
|
||||
self.out_queue.clear()
|
||||
self.out_sdu = None
|
||||
|
||||
def process_output(self):
|
||||
while self.credits > 0:
|
||||
if self.out_sdu is not None:
|
||||
# Finish the current SDU
|
||||
packet = self.out_sdu[:self.peer_mps]
|
||||
self.send_pdu(packet)
|
||||
self.credits -= 1
|
||||
logger.debug(f'sent {len(packet)} bytes, {self.credits} credits left')
|
||||
if len(packet) == len(self.out_sdu):
|
||||
# We sent everything
|
||||
self.out_sdu = None
|
||||
else:
|
||||
# Keep what's still left to send
|
||||
self.out_sdu = self.out_sdu[len(packet):]
|
||||
continue
|
||||
elif self.out_queue:
|
||||
# Create the next SDU (2 bytes header plus up to MTU bytes payload)
|
||||
logger.debug(f'assembling SDU from {len(self.out_queue)} packets in output queue')
|
||||
payload = b''
|
||||
while self.out_queue and len(payload) < self.peer_mtu:
|
||||
# We can add more data to the payload
|
||||
chunk = self.out_queue[0][:self.peer_mtu - len(payload)]
|
||||
payload += chunk
|
||||
self.out_queue[0] = self.out_queue[0][len(chunk):]
|
||||
if len(self.out_queue[0]) == 0:
|
||||
# We consumed the entire buffer, remove it
|
||||
self.out_queue.popleft()
|
||||
logger.debug(f'packet completed, {len(self.out_queue)} left in queue')
|
||||
|
||||
# Construct the SDU with its header
|
||||
assert len(payload) != 0
|
||||
logger.debug(f'SDU complete: {len(payload)} payload bytes')
|
||||
self.out_sdu = struct.pack('<H', len(payload)) + payload
|
||||
else:
|
||||
# Nothing left to send for now
|
||||
self.drained.set()
|
||||
return
|
||||
|
||||
def write(self, data):
|
||||
if self.state != self.CONNECTED:
|
||||
logger.warning('not connected, dropping data')
|
||||
return
|
||||
|
||||
# Queue the data
|
||||
self.out_queue.append(data)
|
||||
self.drained.clear()
|
||||
logger.debug(f'{len(data)} bytes packet queued, {len(self.out_queue)} packets in queue')
|
||||
|
||||
# Send what we can
|
||||
self.process_output()
|
||||
|
||||
async def drain(self):
|
||||
await self.drained.wait()
|
||||
|
||||
def pause_reading(self):
|
||||
# TODO: not implemented yet
|
||||
pass
|
||||
|
||||
def resume_reading(self):
|
||||
# TODO: not implemented yet
|
||||
pass
|
||||
|
||||
def __str__(self):
|
||||
return f'CoC({self.source_cid}->{self.destination_cid}, State={self.state_name(self.state)}, PSM={self.le_psm}, MTU={self.mtu}/{self.peer_mtu}, MPS={self.mps}/{self.peer_mps}, credits={self.credits}/{self.peer_credits})'
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class ChannelManager:
|
||||
def __init__(self, extended_features=None, connectionless_mtu=1024):
|
||||
self.host = None
|
||||
self.channels = {} # Channels, mapped by connection and cid
|
||||
# Fixed channel handlers, mapped by cid
|
||||
self.fixed_channels = {
|
||||
L2CAP_SIGNALING_CID: None, L2CAP_LE_SIGNALING_CID: None}
|
||||
def __init__(self, extended_features=[], connectionless_mtu=L2CAP_DEFAULT_CONNECTIONLESS_MTU):
|
||||
self._host = None
|
||||
self.identifiers = {} # Incrementing identifier values by connection
|
||||
self.channels = {} # All channels, mapped by connection and source cid
|
||||
self.fixed_channels = { # Fixed channel handlers, mapped by cid
|
||||
L2CAP_SIGNALING_CID: None, L2CAP_LE_SIGNALING_CID: None
|
||||
}
|
||||
self.servers = {} # Servers accepting connections, by PSM
|
||||
self.extended_features = [] if extended_features is None else extended_features
|
||||
self.le_coc_channels = {} # LE CoC channels, mapped by connection and destination cid
|
||||
self.le_coc_servers = {} # LE CoC - Servers accepting connections, by PSM
|
||||
self.le_coc_requests = {} # LE CoC connection requests, by identifier
|
||||
self.extended_features = extended_features
|
||||
self.connectionless_mtu = connectionless_mtu
|
||||
|
||||
@property
|
||||
def host(self):
|
||||
return self._host
|
||||
|
||||
@host.setter
|
||||
def host(self, host):
|
||||
if self._host is not None:
|
||||
self._host.remove_listener('disconnection', self.on_disconnection)
|
||||
self._host = host
|
||||
if host is not None:
|
||||
host.add_listener('disconnection', self.on_disconnection)
|
||||
|
||||
def find_channel(self, connection_handle, cid):
|
||||
if connection_channels := self.channels.get(connection_handle):
|
||||
return connection_channels.get(cid)
|
||||
|
||||
def find_le_coc_channel(self, connection_handle, cid):
|
||||
if connection_channels := self.le_coc_channels.get(connection_handle):
|
||||
return connection_channels.get(cid)
|
||||
|
||||
@staticmethod
|
||||
def find_free_br_edr_cid(channels):
|
||||
# Pick the smallest valid CID that's not already in the list
|
||||
@@ -853,6 +1243,24 @@ class ChannelManager:
|
||||
if cid not in channels:
|
||||
return cid
|
||||
|
||||
@staticmethod
|
||||
def find_free_le_cid(channels):
|
||||
# Pick the smallest valid CID that's not already in the list
|
||||
# (not necessarily the most efficient algorithm, but the list of CID is
|
||||
# very small in practice)
|
||||
for cid in range(L2CAP_LE_U_DYNAMIC_CID_RANGE_START, L2CAP_LE_U_DYNAMIC_CID_RANGE_END + 1):
|
||||
if cid not in channels:
|
||||
return cid
|
||||
|
||||
@staticmethod
|
||||
def check_le_coc_parameters(max_credits, mtu, mps):
|
||||
if max_credits < 1 or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS:
|
||||
raise ValueError('max credits out of range')
|
||||
if mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU:
|
||||
raise ValueError('MTU too small')
|
||||
if mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS or mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS:
|
||||
raise ValueError('MPS out of range')
|
||||
|
||||
def next_identifier(self, connection):
|
||||
identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256
|
||||
self.identifiers[connection.handle] = identifier
|
||||
@@ -866,8 +1274,77 @@ class ChannelManager:
|
||||
del self.fixed_channels[cid]
|
||||
|
||||
def register_server(self, psm, server):
|
||||
if psm == 0:
|
||||
# Find a free PSM
|
||||
for candidate in range(L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2):
|
||||
if (candidate >> 8) % 2 == 1:
|
||||
continue
|
||||
if candidate in self.servers:
|
||||
continue
|
||||
psm = candidate
|
||||
break
|
||||
else:
|
||||
raise InvalidStateError('no free PSM')
|
||||
else:
|
||||
# Check that the PSM isn't already in use
|
||||
if psm in self.servers:
|
||||
raise ValueError('PSM already in use')
|
||||
|
||||
# Check that the PSM is valid
|
||||
if psm % 2 == 0:
|
||||
raise ValueError('invalid PSM (not odd)')
|
||||
check = psm >> 8
|
||||
while check:
|
||||
if check % 2 != 0:
|
||||
raise ValueError('invalid PSM')
|
||||
check >>= 8
|
||||
|
||||
self.servers[psm] = server
|
||||
|
||||
return psm
|
||||
|
||||
def register_le_coc_server(
|
||||
self,
|
||||
psm,
|
||||
server,
|
||||
max_credits=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
|
||||
mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
|
||||
mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
|
||||
):
|
||||
self.check_le_coc_parameters(max_credits, mtu, mps)
|
||||
|
||||
if psm == 0:
|
||||
# Find a free PSM
|
||||
for candidate in range(L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1):
|
||||
if candidate in self.le_coc_servers:
|
||||
continue
|
||||
psm = candidate
|
||||
break
|
||||
else:
|
||||
raise InvalidStateError('no free PSM')
|
||||
else:
|
||||
# Check that the PSM isn't already in use
|
||||
if psm in self.le_coc_servers:
|
||||
raise ValueError('PSM already in use')
|
||||
|
||||
self.le_coc_servers[psm] = (
|
||||
server,
|
||||
max_credits or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS,
|
||||
mtu or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
|
||||
mps or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
|
||||
)
|
||||
|
||||
return psm
|
||||
|
||||
def on_disconnection(self, connection_handle, reason):
|
||||
logger.debug(f'disconnection from {connection_handle}, cleaning up channels')
|
||||
if connection_handle in self.channels:
|
||||
del self.channels[connection_handle]
|
||||
if connection_handle in self.le_coc_channels:
|
||||
del self.le_coc_channels[connection_handle]
|
||||
if connection_handle in self.identifiers:
|
||||
del self.identifiers[connection_handle]
|
||||
|
||||
def send_pdu(self, connection, cid, pdu):
|
||||
pdu_str = pdu.hex() if type(pdu) is bytes else str(pdu)
|
||||
logger.debug(f'{color(">>> Sending L2CAP PDU", "blue")} on connection [0x{connection.handle:04X}] (CID={cid}) {connection.peer_address}: {pdu_str}')
|
||||
@@ -883,7 +1360,7 @@ class ChannelManager:
|
||||
self.fixed_channels[cid](connection.handle, pdu)
|
||||
else:
|
||||
if (channel := self.find_channel(connection.handle, cid)) is None:
|
||||
logger.warn(color(f'channel not found for 0x{connection.handle:04X}:{cid}', 'red'))
|
||||
logger.warning(color(f'channel not found for 0x{connection.handle:04X}:{cid}', 'red'))
|
||||
return
|
||||
|
||||
channel.on_pdu(pdu)
|
||||
@@ -927,7 +1404,6 @@ class ChannelManager:
|
||||
|
||||
def on_l2cap_command_reject(self, connection, cid, packet):
|
||||
logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}')
|
||||
pass
|
||||
|
||||
def on_l2cap_connection_request(self, connection, cid, request):
|
||||
# Check if there's a server for this PSM
|
||||
@@ -959,7 +1435,7 @@ class ChannelManager:
|
||||
server(channel)
|
||||
channel.on_connection_request(request)
|
||||
else:
|
||||
logger.warn(f'No server for connection 0x{connection.handle:04X} on PSM {request.psm}')
|
||||
logger.warning(f'No server for connection 0x{connection.handle:04X} on PSM {request.psm}')
|
||||
self.send_control_frame(
|
||||
connection,
|
||||
cid,
|
||||
@@ -974,35 +1450,35 @@ class ChannelManager:
|
||||
|
||||
def on_l2cap_connection_response(self, connection, cid, response):
|
||||
if (channel := self.find_channel(connection.handle, response.source_cid)) is None:
|
||||
logger.warn(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
|
||||
logger.warning(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
|
||||
return
|
||||
|
||||
channel.on_connection_response(response)
|
||||
|
||||
def on_l2cap_configure_request(self, connection, cid, request):
|
||||
if (channel := self.find_channel(connection.handle, request.destination_cid)) is None:
|
||||
logger.warn(color(f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
|
||||
logger.warning(color(f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
|
||||
return
|
||||
|
||||
channel.on_configure_request(request)
|
||||
|
||||
def on_l2cap_configure_response(self, connection, cid, response):
|
||||
if (channel := self.find_channel(connection.handle, response.source_cid)) is None:
|
||||
logger.warn(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
|
||||
logger.warning(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
|
||||
return
|
||||
|
||||
channel.on_configure_response(response)
|
||||
|
||||
def on_l2cap_disconnection_request(self, connection, cid, request):
|
||||
if (channel := self.find_channel(connection.handle, request.destination_cid)) is None:
|
||||
logger.warn(color(f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
|
||||
logger.warning(color(f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
|
||||
return
|
||||
|
||||
channel.on_disconnection_request(request)
|
||||
|
||||
def on_l2cap_disconnection_response(self, connection, cid, response):
|
||||
if (channel := self.find_channel(connection.handle, response.source_cid)) is None:
|
||||
logger.warn(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
|
||||
logger.warning(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red'))
|
||||
return
|
||||
|
||||
channel.on_disconnection_response(response)
|
||||
@@ -1076,25 +1552,123 @@ class ChannelManager:
|
||||
)
|
||||
|
||||
def on_l2cap_connection_parameter_update_response(self, connection, cid, response):
|
||||
# TODO: check response
|
||||
pass
|
||||
|
||||
def on_l2cap_le_credit_based_connection_request(self, connection, cid, request):
|
||||
# FIXME: temp fixed values
|
||||
if request.le_psm in self.le_coc_servers:
|
||||
(server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm]
|
||||
|
||||
# Check that the CID isn't already used
|
||||
le_connection_channels = self.le_coc_channels.setdefault(connection.handle, {})
|
||||
if request.source_cid in le_connection_channels:
|
||||
logger.warning(f'source CID {request.source_cid} already in use')
|
||||
self.send_control_frame(
|
||||
connection,
|
||||
cid,
|
||||
L2CAP_LE_Credit_Based_Connection_Response(
|
||||
identifier = request.identifier,
|
||||
destination_cid = 194, # FIXME: for testing only
|
||||
mtu = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
|
||||
mps = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
|
||||
initial_credits = 3, # FIXME: for testing only
|
||||
destination_cid = 0,
|
||||
mtu = mtu,
|
||||
mps = mps,
|
||||
initial_credits = 0,
|
||||
result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Find a free CID for this new channel
|
||||
connection_channels = self.channels.setdefault(connection.handle, {})
|
||||
source_cid = self.find_free_le_cid(connection_channels)
|
||||
if source_cid is None: # Should never happen!
|
||||
self.send_control_frame(
|
||||
connection,
|
||||
cid,
|
||||
L2CAP_LE_Credit_Based_Connection_Response(
|
||||
identifier = request.identifier,
|
||||
destination_cid = 0,
|
||||
mtu = mtu,
|
||||
mps = mps,
|
||||
initial_credits = 0,
|
||||
result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Create a new channel
|
||||
logger.debug(f'creating LE CoC server channel with cid={source_cid} for psm {request.le_psm}')
|
||||
channel = LeConnectionOrientedChannel(
|
||||
self,
|
||||
connection,
|
||||
request.le_psm,
|
||||
source_cid,
|
||||
request.source_cid,
|
||||
mtu,
|
||||
mps,
|
||||
request.initial_credits,
|
||||
request.mtu,
|
||||
request.mps,
|
||||
max_credits,
|
||||
True
|
||||
)
|
||||
connection_channels[source_cid] = channel
|
||||
le_connection_channels[request.source_cid] = channel
|
||||
|
||||
# Respond
|
||||
self.send_control_frame(
|
||||
connection,
|
||||
cid,
|
||||
L2CAP_LE_Credit_Based_Connection_Response(
|
||||
identifier = request.identifier,
|
||||
destination_cid = source_cid,
|
||||
mtu = mtu,
|
||||
mps = mps,
|
||||
initial_credits = max_credits,
|
||||
result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL
|
||||
)
|
||||
)
|
||||
|
||||
def on_l2cap_le_flow_control_credit(self, connection, cid, packet):
|
||||
pass
|
||||
# Notify
|
||||
server(channel)
|
||||
else:
|
||||
logger.info(f'No LE server for connection 0x{connection.handle:04X} on PSM {request.le_psm}')
|
||||
self.send_control_frame(
|
||||
connection,
|
||||
cid,
|
||||
L2CAP_LE_Credit_Based_Connection_Response(
|
||||
identifier = request.identifier,
|
||||
destination_cid = 0,
|
||||
mtu = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU,
|
||||
mps = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS,
|
||||
initial_credits = 0,
|
||||
result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED,
|
||||
)
|
||||
)
|
||||
|
||||
def on_l2cap_le_credit_based_connection_response(self, connection, cid, response):
|
||||
# Find the pending request by identifier
|
||||
request = self.le_coc_requests.get(response.identifier)
|
||||
if request is None:
|
||||
logger.warning(color('!!! received response for unknown request', 'red'))
|
||||
return
|
||||
del self.le_coc_requests[response.identifier]
|
||||
|
||||
# Find the channel for this request
|
||||
channel = self.find_channel(connection.handle, request.source_cid)
|
||||
if channel is None:
|
||||
logger.warning(color(f'received connection response for an unknown channel (cid={request.source_cid})', 'red'))
|
||||
return
|
||||
|
||||
# Process the response
|
||||
channel.on_connection_response(response)
|
||||
|
||||
def on_l2cap_le_flow_control_credit(self, connection, cid, credit):
|
||||
channel = self.find_le_coc_channel(connection.handle, credit.cid)
|
||||
if channel is None:
|
||||
logger.warning(f'received credits for an unknown channel (cid={credit.cid}')
|
||||
return
|
||||
|
||||
channel.on_credits(credit.credits)
|
||||
|
||||
def on_channel_closed(self, channel):
|
||||
connection_channels = self.channels.get(channel.connection.handle)
|
||||
@@ -1102,22 +1676,65 @@ class ChannelManager:
|
||||
if channel.source_cid in connection_channels:
|
||||
del connection_channels[channel.source_cid]
|
||||
|
||||
async def connect(self, connection, psm):
|
||||
# NOTE: this implementation hard-codes BR/EDR more
|
||||
# TODO: LE mode (maybe?)
|
||||
async def open_le_coc(self, connection, psm, max_credits, mtu, mps):
|
||||
self.check_le_coc_parameters(max_credits, mtu, mps)
|
||||
|
||||
# Find a free CID for a new channel
|
||||
# Find a free CID for the new channel
|
||||
connection_channels = self.channels.setdefault(connection.handle, {})
|
||||
cid = self.find_free_br_edr_cid(connection_channels)
|
||||
if cid is None: # Should never happen!
|
||||
source_cid = self.find_free_le_cid(connection_channels)
|
||||
if source_cid is None: # Should never happen!
|
||||
raise RuntimeError('all CIDs already in use')
|
||||
|
||||
# Create the channel
|
||||
logger.debug(f'creating client channel with cid={cid} for psm {psm}')
|
||||
channel = Channel(self, connection, L2CAP_SIGNALING_CID, psm, cid, L2CAP_MIN_BR_EDR_MTU)
|
||||
connection_channels[cid] = channel
|
||||
logger.debug(f'creating coc channel with cid={source_cid} for psm {psm}')
|
||||
channel = LeConnectionOrientedChannel(
|
||||
manager = self,
|
||||
connection = connection,
|
||||
le_psm = psm,
|
||||
source_cid = source_cid,
|
||||
destination_cid = 0,
|
||||
mtu = mtu,
|
||||
mps = mps,
|
||||
credits = 0,
|
||||
peer_mtu = 0,
|
||||
peer_mps = 0,
|
||||
peer_credits = max_credits,
|
||||
connected = False
|
||||
)
|
||||
connection_channels[source_cid] = channel
|
||||
|
||||
# Connect
|
||||
try:
|
||||
await channel.connect()
|
||||
except Exception as error:
|
||||
logger.warning(f'connection failed: {error}')
|
||||
del connection_channels[source_cid]
|
||||
raise
|
||||
|
||||
# Remember the channel by source CID and destination CID
|
||||
le_connection_channels = self.le_coc_channels.setdefault(connection.handle, {})
|
||||
le_connection_channels[channel.destination_cid] = channel
|
||||
|
||||
return channel
|
||||
|
||||
async def connect(self, connection, psm):
|
||||
# NOTE: this implementation hard-codes BR/EDR
|
||||
|
||||
# Find a free CID for a new channel
|
||||
connection_channels = self.channels.setdefault(connection.handle, {})
|
||||
source_cid = self.find_free_br_edr_cid(connection_channels)
|
||||
if source_cid is None: # Should never happen!
|
||||
raise RuntimeError('all CIDs already in use')
|
||||
|
||||
# Create the channel
|
||||
logger.debug(f'creating client channel with cid={source_cid} for psm {psm}')
|
||||
channel = Channel(self, connection, L2CAP_SIGNALING_CID, psm, source_cid, L2CAP_MIN_BR_EDR_MTU)
|
||||
connection_channels[source_cid] = channel
|
||||
|
||||
# Connect
|
||||
try:
|
||||
await channel.connect()
|
||||
except Exception:
|
||||
del connection_channels[source_cid]
|
||||
|
||||
return channel
|
||||
|
||||
@@ -274,7 +274,7 @@ class PumpedPacketSource(ParserSource):
|
||||
self.terminated.set_result(error)
|
||||
break
|
||||
|
||||
self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
|
||||
self.pump_task = asyncio.create_task(pump_packets())
|
||||
|
||||
def close(self):
|
||||
if self.pump_task:
|
||||
@@ -304,7 +304,7 @@ class PumpedPacketSink:
|
||||
logger.warn(f'exception while sending packet: {error}')
|
||||
break
|
||||
|
||||
self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
|
||||
self.pump_task = asyncio.create_task(pump_packets())
|
||||
|
||||
def close(self):
|
||||
if self.pump_task:
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import traceback
|
||||
import collections
|
||||
from functools import wraps
|
||||
from colors import color
|
||||
from pyee import EventEmitter
|
||||
@@ -140,3 +141,95 @@ class AsyncRunner:
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class FlowControlAsyncPipe:
|
||||
"""
|
||||
Asyncio pipe with flow control. When writing to the pipe, the source is
|
||||
paused (by calling a function passed in when the pipe is created) if the
|
||||
amount of queued data exceeds a specified threshold.
|
||||
"""
|
||||
def __init__(self, pause_source, resume_source, write_to_sink=None, drain_sink=None, threshold=0):
|
||||
self.pause_source = pause_source
|
||||
self.resume_source = resume_source
|
||||
self.write_to_sink = write_to_sink
|
||||
self.drain_sink = drain_sink
|
||||
self.threshold = threshold
|
||||
self.queue = collections.deque() # Queue of packets
|
||||
self.queued_bytes = 0 # Number of bytes in the queue
|
||||
self.ready_to_pump = asyncio.Event()
|
||||
self.paused = False
|
||||
self.source_paused = False
|
||||
self.pump_task = None
|
||||
|
||||
def start(self):
|
||||
if self.pump_task is None:
|
||||
self.pump_task = asyncio.create_task(self.pump())
|
||||
|
||||
self.check_pump()
|
||||
|
||||
def stop(self):
|
||||
if self.pump_task is not None:
|
||||
self.pump_task.cancel()
|
||||
self.pump_task = None
|
||||
|
||||
def write(self, packet):
|
||||
self.queued_bytes += len(packet)
|
||||
self.queue.append(packet)
|
||||
|
||||
# Pause the source if we're over the threshold
|
||||
if self.queued_bytes > self.threshold and not self.source_paused:
|
||||
logger.debug(f'pausing source (queued={self.queued_bytes})')
|
||||
self.pause_source()
|
||||
self.source_paused = True
|
||||
|
||||
self.check_pump()
|
||||
|
||||
def pause(self):
|
||||
if not self.paused:
|
||||
self.paused = True
|
||||
if not self.source_paused:
|
||||
self.pause_source()
|
||||
self.source_paused = True
|
||||
self.check_pump()
|
||||
|
||||
def resume(self):
|
||||
if self.paused:
|
||||
self.paused = False
|
||||
if self.source_paused:
|
||||
self.resume_source()
|
||||
self.source_paused = False
|
||||
self.check_pump()
|
||||
|
||||
def can_pump(self):
|
||||
return self.queue and not self.paused and self.write_to_sink is not None
|
||||
|
||||
def check_pump(self):
|
||||
if self.can_pump():
|
||||
self.ready_to_pump.set()
|
||||
else:
|
||||
self.ready_to_pump.clear()
|
||||
|
||||
async def pump(self):
|
||||
while True:
|
||||
# Wait until we can try to pump packets
|
||||
await self.ready_to_pump.wait()
|
||||
|
||||
# Try to pump a packet
|
||||
if self.can_pump():
|
||||
packet = self.queue.pop()
|
||||
self.write_to_sink(packet)
|
||||
self.queued_bytes -= len(packet)
|
||||
|
||||
# Drain the sink if we can
|
||||
if self.drain_sink:
|
||||
await self.drain_sink()
|
||||
|
||||
# Check if we can accept more
|
||||
if self.queued_bytes <= self.threshold and self.source_paused:
|
||||
logger.debug(f'resuming source (queued={self.queued_bytes})')
|
||||
self.source_paused = False
|
||||
self.resume_source()
|
||||
|
||||
self.check_pump()
|
||||
|
||||
5
examples/asha_sink1.json
Normal file
5
examples/asha_sink1.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"name": "Bumble Aid Left",
|
||||
"address": "F1:F2:F3:F4:F5:F6",
|
||||
"keystore": "JsonKeyStore"
|
||||
}
|
||||
5
examples/asha_sink2.json
Normal file
5
examples/asha_sink2.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"name": "Bumble Aid Right",
|
||||
"address": "F7:F8:F9:FA:FB:FC",
|
||||
"keystore": "JsonKeyStore"
|
||||
}
|
||||
161
examples/run_asha_sink.py
Normal file
161
examples/run_asha_sink.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import struct
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
|
||||
from bumble.core import AdvertisingData
|
||||
from bumble.device import Device
|
||||
from bumble.transport import open_transport_or_link
|
||||
from bumble.hci import UUID
|
||||
from bumble.gatt import (
|
||||
Service,
|
||||
Characteristic,
|
||||
CharacteristicValue
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Constants
|
||||
# -----------------------------------------------------------------------------
|
||||
ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid')
|
||||
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties')
|
||||
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID('f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint')
|
||||
ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID('38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus')
|
||||
ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume')
|
||||
ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT')
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def main():
|
||||
if len(sys.argv) != 4:
|
||||
print('Usage: python run_asha_sink.py <device-config> <transport-spec> <audio-file>')
|
||||
print('example: python run_asha_sink.py device1.json usb:0 audio_out.g722')
|
||||
return
|
||||
|
||||
audio_out = open(sys.argv[3], 'wb')
|
||||
|
||||
async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink):
|
||||
device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink)
|
||||
|
||||
# Handler for audio control commands
|
||||
def on_audio_control_point_write(connection, value):
|
||||
print('--- AUDIO CONTROL POINT Write:', value.hex())
|
||||
opcode = value[0]
|
||||
if opcode == 1:
|
||||
# Start
|
||||
audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]]
|
||||
print(f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}')
|
||||
elif opcode == 2:
|
||||
print('### STOP')
|
||||
elif opcode == 3:
|
||||
print(f'### STATUS: connected={value[1]}')
|
||||
|
||||
# Respond with a status
|
||||
asyncio.create_task(device.notify_subscribers(audio_status_characteristic, force=True))
|
||||
|
||||
# Handler for volume control
|
||||
def on_volume_write(connection, value):
|
||||
print('--- VOLUME Write:', value[0])
|
||||
|
||||
# Register an L2CAP CoC server
|
||||
def on_coc(channel):
|
||||
def on_data(data):
|
||||
print('<<< Voice data received:', data.hex())
|
||||
audio_out.write(data)
|
||||
|
||||
channel.sink = on_data
|
||||
|
||||
psm = device.register_l2cap_channel_server(0, on_coc, 8)
|
||||
print(f'### LE_PSM_OUT = {psm}')
|
||||
|
||||
# Add the ASHA service to the GATT server
|
||||
read_only_properties_characteristic = Characteristic(
|
||||
ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC,
|
||||
Characteristic.READ,
|
||||
Characteristic.READABLE,
|
||||
bytes([
|
||||
0x01, # Version
|
||||
0x00, # Device Capabilities [Left, Monaural]
|
||||
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, # HiSyncId
|
||||
0x01, # Feature Map [LE CoC audio output streaming supported]
|
||||
0x00, 0x00, # Render Delay
|
||||
0x00, 0x00, # RFU
|
||||
0x02, 0x00 # Codec IDs [G.722 at 16 kHz]
|
||||
])
|
||||
)
|
||||
audio_control_point_characteristic = Characteristic(
|
||||
ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC,
|
||||
Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE,
|
||||
Characteristic.WRITEABLE,
|
||||
CharacteristicValue(write=on_audio_control_point_write)
|
||||
)
|
||||
audio_status_characteristic = Characteristic(
|
||||
ASHA_AUDIO_STATUS_CHARACTERISTIC,
|
||||
Characteristic.READ | Characteristic.NOTIFY,
|
||||
Characteristic.READABLE,
|
||||
bytes([0])
|
||||
)
|
||||
volume_characteristic = Characteristic(
|
||||
ASHA_VOLUME_CHARACTERISTIC,
|
||||
Characteristic.WRITE_WITHOUT_RESPONSE,
|
||||
Characteristic.WRITEABLE,
|
||||
CharacteristicValue(write=on_volume_write)
|
||||
)
|
||||
le_psm_out_characteristic = Characteristic(
|
||||
ASHA_LE_PSM_OUT_CHARACTERISTIC,
|
||||
Characteristic.READ,
|
||||
Characteristic.READABLE,
|
||||
struct.pack('<H', psm)
|
||||
)
|
||||
device.add_service(Service(
|
||||
ASHA_SERVICE,
|
||||
[
|
||||
read_only_properties_characteristic,
|
||||
audio_control_point_characteristic,
|
||||
audio_status_characteristic,
|
||||
volume_characteristic,
|
||||
le_psm_out_characteristic
|
||||
]
|
||||
))
|
||||
|
||||
# Set the advertising data
|
||||
device.advertising_data = bytes(
|
||||
AdvertisingData([
|
||||
(AdvertisingData.COMPLETE_LOCAL_NAME, bytes(device.name, 'utf-8')),
|
||||
(AdvertisingData.FLAGS, bytes([0x06])),
|
||||
(AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, bytes(ASHA_SERVICE)),
|
||||
(AdvertisingData.SERVICE_DATA_16_BIT_UUID, bytes(ASHA_SERVICE) + bytes([
|
||||
0x01, # Protocol Version
|
||||
0x00, # Capability
|
||||
0x01, 0x02, 0x03, 0x04 # Truncated HiSyncID
|
||||
]))
|
||||
])
|
||||
)
|
||||
|
||||
# Go!
|
||||
await device.power_on()
|
||||
await device.start_advertising(auto_restart=True)
|
||||
|
||||
await hci_source.wait_for_termination()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper())
|
||||
asyncio.run(main())
|
||||
@@ -49,11 +49,17 @@ class Listener(Device.Listener, Connection.Listener):
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Alternative way to listen for subscriptions
|
||||
# -----------------------------------------------------------------------------
|
||||
def on_my_characteristic_subscription(peer, enabled):
|
||||
print(f'### My characteristic from {peer}: {"enabled" if enabled else "disabled"}')
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def main():
|
||||
if len(sys.argv) < 3:
|
||||
print('Usage: run_gatt_server.py <device-config> <transport-spec>')
|
||||
print('example: run_gatt_server.py device1.json usb:0')
|
||||
print('Usage: run_notifier.py <device-config> <transport-spec>')
|
||||
print('example: run_notifier.py device1.json usb:0')
|
||||
return
|
||||
|
||||
print('<<< connecting to HCI...')
|
||||
@@ -83,6 +89,7 @@ async def main():
|
||||
Characteristic.READABLE,
|
||||
bytes([0x42])
|
||||
)
|
||||
characteristic3.on('subscription', on_my_characteristic_subscription)
|
||||
custom_service = Service(
|
||||
'50DB505C-8AC4-4738-8448-3B1D9CC09CC5',
|
||||
[characteristic1, characteristic2, characteristic3]
|
||||
|
||||
@@ -98,6 +98,7 @@ async def list_rfcomm_channels(device, connection):
|
||||
|
||||
await sdp_client.disconnect()
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class TcpServerProtocol(asyncio.Protocol):
|
||||
def __init__(self, rfcomm_session):
|
||||
@@ -173,7 +174,7 @@ async def main():
|
||||
print('*** Encryption on')
|
||||
|
||||
# Create a client and start it
|
||||
print('@@@ Starting to RFCOMM client...')
|
||||
print('@@@ Starting RFCOMM client...')
|
||||
rfcomm_client = Client(device, connection)
|
||||
rfcomm_mux = await rfcomm_client.start()
|
||||
print('@@@ Started')
|
||||
@@ -192,7 +193,7 @@ async def main():
|
||||
if len(sys.argv) == 6:
|
||||
# A TCP port was specified, start listening
|
||||
tcp_port = int(sys.argv[5])
|
||||
asyncio.get_running_loop().create_task(tcp_server(tcp_port, session))
|
||||
asyncio.create_task(tcp_server(tcp_port, session))
|
||||
|
||||
await hci_source.wait_for_termination()
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@ console_scripts =
|
||||
bumble-controller-info = bumble.apps.controller_info:main
|
||||
bumble-gatt-dump = bumble.apps.gatt_dump:main
|
||||
bumble-hci-bridge = bumble.apps.hci_bridge:main
|
||||
bumble-l2cap-bridge = bumble.apps.l2cap_bridge:main
|
||||
bumble-pair = bumble.apps.pair:main
|
||||
bumble-scan = bumble.apps.scan:main
|
||||
bumble-show = bumble.apps.show:main
|
||||
@@ -64,6 +65,7 @@ build =
|
||||
test =
|
||||
pytest >= 6.2
|
||||
pytest-asyncio >= 0.17
|
||||
coverage >= 6.4
|
||||
development =
|
||||
invoke >= 1.4
|
||||
nox >= 2022
|
||||
|
||||
284
tests/l2cap_test.py
Normal file
284
tests/l2cap_test.py
Normal file
@@ -0,0 +1,284 @@
|
||||
# Copyright 2021-2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Imports
|
||||
# -----------------------------------------------------------------------------
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import pytest
|
||||
|
||||
from bumble.controller import Controller
|
||||
from bumble.link import LocalLink
|
||||
from bumble.device import Device
|
||||
from bumble.host import Host
|
||||
from bumble.transport import AsyncPipeSink
|
||||
from bumble.core import ProtocolError
|
||||
from bumble.l2cap import (
|
||||
L2CAP_Connection_Request
|
||||
)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Logging
|
||||
# -----------------------------------------------------------------------------
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
class TwoDevices:
|
||||
def __init__(self):
|
||||
self.connections = [None, None]
|
||||
|
||||
self.link = LocalLink()
|
||||
self.controllers = [
|
||||
Controller('C1', link = self.link),
|
||||
Controller('C2', link = self.link)
|
||||
]
|
||||
self.devices = [
|
||||
Device(
|
||||
address = 'F0:F1:F2:F3:F4:F5',
|
||||
host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0]))
|
||||
),
|
||||
Device(
|
||||
address = 'F5:F4:F3:F2:F1:F0',
|
||||
host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1]))
|
||||
)
|
||||
]
|
||||
|
||||
self.paired = [None, None]
|
||||
|
||||
def on_connection(self, which, connection):
|
||||
self.connections[which] = connection
|
||||
|
||||
def on_paired(self, which, keys):
|
||||
self.paired[which] = keys
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def setup_connection():
|
||||
# Create two devices, each with a controller, attached to the same link
|
||||
two_devices = TwoDevices()
|
||||
|
||||
# Attach listeners
|
||||
two_devices.devices[0].on('connection', lambda connection: two_devices.on_connection(0, connection))
|
||||
two_devices.devices[1].on('connection', lambda connection: two_devices.on_connection(1, connection))
|
||||
|
||||
# Start
|
||||
await two_devices.devices[0].power_on()
|
||||
await two_devices.devices[1].power_on()
|
||||
|
||||
# Connect the two devices
|
||||
await two_devices.devices[0].connect(two_devices.devices[1].random_address)
|
||||
|
||||
# Check the post conditions
|
||||
assert(two_devices.connections[0] is not None)
|
||||
assert(two_devices.connections[1] is not None)
|
||||
|
||||
return two_devices
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
def test_helpers():
|
||||
psm = L2CAP_Connection_Request.serialize_psm(0x01)
|
||||
assert(psm == bytes([0x01, 0x00]))
|
||||
|
||||
psm = L2CAP_Connection_Request.serialize_psm(0x1023)
|
||||
assert(psm == bytes([0x23, 0x10]))
|
||||
|
||||
psm = L2CAP_Connection_Request.serialize_psm(0x242311)
|
||||
assert(psm == bytes([0x11, 0x23, 0x24]))
|
||||
|
||||
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x01, 0x00, 0x44]), 1)
|
||||
assert(offset == 3)
|
||||
assert(psm == 0x01)
|
||||
|
||||
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x23, 0x10, 0x44]), 1)
|
||||
assert(offset == 3)
|
||||
assert(psm == 0x1023)
|
||||
|
||||
(offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1)
|
||||
assert(offset == 4)
|
||||
assert(psm == 0x242311)
|
||||
|
||||
rq = L2CAP_Connection_Request(psm = 0x01, source_cid = 0x44)
|
||||
brq = bytes(rq)
|
||||
srq = L2CAP_Connection_Request.from_bytes(brq)
|
||||
assert(srq.psm == rq.psm)
|
||||
assert(srq.source_cid == rq.source_cid)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_connection():
|
||||
devices = await setup_connection()
|
||||
psm = 1234
|
||||
|
||||
# Check that if there's no one listening, we can't connect
|
||||
with pytest.raises(ProtocolError):
|
||||
l2cap_channel = await devices.connections[0].open_l2cap_channel(psm)
|
||||
|
||||
# Now add a listener
|
||||
incoming_channel = None
|
||||
received = []
|
||||
|
||||
def on_coc(channel):
|
||||
nonlocal incoming_channel
|
||||
incoming_channel = channel
|
||||
|
||||
def on_data(data):
|
||||
received.append(data)
|
||||
|
||||
channel.sink = on_data
|
||||
|
||||
devices.devices[1].register_l2cap_channel_server(psm, on_coc)
|
||||
l2cap_channel = await devices.connections[0].open_l2cap_channel(psm)
|
||||
|
||||
messages = (
|
||||
bytes([1, 2, 3]),
|
||||
bytes([4, 5, 6]),
|
||||
bytes(10000)
|
||||
)
|
||||
for message in messages:
|
||||
l2cap_channel.write(message)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await l2cap_channel.drain()
|
||||
|
||||
# Test closing
|
||||
closed = [False, False]
|
||||
closed_event = asyncio.Event()
|
||||
|
||||
def on_close(which, event):
|
||||
closed[which] = True
|
||||
if event:
|
||||
event.set()
|
||||
|
||||
l2cap_channel.on('close', lambda: on_close(0, None))
|
||||
incoming_channel.on('close', lambda: on_close(1, closed_event))
|
||||
await l2cap_channel.disconnect()
|
||||
assert(closed == [True, True])
|
||||
await closed_event.wait()
|
||||
|
||||
sent_bytes = b''.join(messages)
|
||||
received_bytes = b''.join(received)
|
||||
assert(sent_bytes == received_bytes)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def transfer_payload(max_credits, mtu, mps):
|
||||
devices = await setup_connection()
|
||||
|
||||
received = []
|
||||
|
||||
def on_coc(channel):
|
||||
def on_data(data):
|
||||
received.append(data)
|
||||
|
||||
channel.sink = on_data
|
||||
|
||||
psm = devices.devices[1].register_l2cap_channel_server(
|
||||
psm = 0,
|
||||
server = on_coc,
|
||||
max_credits = max_credits,
|
||||
mtu = mtu,
|
||||
mps = mps
|
||||
)
|
||||
l2cap_channel = await devices.connections[0].open_l2cap_channel(psm)
|
||||
|
||||
messages = [
|
||||
bytes([1, 2, 3, 4, 5, 6, 7]) * x
|
||||
for x in (3, 10, 100, 500, 789)
|
||||
]
|
||||
for message in messages:
|
||||
l2cap_channel.write(message)
|
||||
await asyncio.sleep(0)
|
||||
if random.randint(0, 5) == 1:
|
||||
await l2cap_channel.drain()
|
||||
|
||||
await l2cap_channel.drain()
|
||||
await l2cap_channel.disconnect()
|
||||
|
||||
sent_bytes = b''.join(messages)
|
||||
received_bytes = b''.join(received)
|
||||
assert(sent_bytes == received_bytes)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transfer():
|
||||
for max_credits in (1, 10, 100, 10000):
|
||||
for mtu in (23, 24, 25, 26, 50, 200, 255, 256, 1000):
|
||||
for mps in (23, 24, 25, 26, 50, 200, 255, 256, 1000):
|
||||
# print(max_credits, mtu, mps)
|
||||
await transfer_payload(max_credits, mtu, mps)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
@pytest.mark.asyncio
|
||||
async def test_bidirectional_transfer():
|
||||
devices = await setup_connection()
|
||||
|
||||
client_received = []
|
||||
server_received = []
|
||||
server_channel = None
|
||||
|
||||
def on_server_coc(channel):
|
||||
nonlocal server_channel
|
||||
server_channel = channel
|
||||
|
||||
def on_server_data(data):
|
||||
server_received.append(data)
|
||||
|
||||
channel.sink = on_server_data
|
||||
|
||||
def on_client_data(data):
|
||||
client_received.append(data)
|
||||
|
||||
psm = devices.devices[1].register_l2cap_channel_server(psm=0, server=on_server_coc)
|
||||
client_channel = await devices.connections[0].open_l2cap_channel(psm)
|
||||
client_channel.sink = on_client_data
|
||||
|
||||
messages = [
|
||||
bytes([1, 2, 3, 4, 5, 6, 7]) * x
|
||||
for x in (3, 10, 100)
|
||||
]
|
||||
for message in messages:
|
||||
client_channel.write(message)
|
||||
await client_channel.drain()
|
||||
await asyncio.sleep(0)
|
||||
server_channel.write(message)
|
||||
await server_channel.drain()
|
||||
|
||||
await client_channel.disconnect()
|
||||
|
||||
message_bytes = b''.join(messages)
|
||||
client_received_bytes = b''.join(client_received)
|
||||
server_received_bytes = b''.join(server_received)
|
||||
assert(client_received_bytes == message_bytes)
|
||||
assert(server_received_bytes == message_bytes)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
async def run():
|
||||
test_helpers()
|
||||
await test_basic_connection()
|
||||
await test_transfer()
|
||||
await test_bidirectional_transfer()
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper())
|
||||
asyncio.run(run())
|
||||
Reference in New Issue
Block a user