diff --git a/apps/gg_bridge.py b/apps/gg_bridge.py index d5249139..ac3df8d7 100644 --- a/apps/gg_bridge.py +++ b/apps/gg_bridge.py @@ -17,13 +17,14 @@ # ----------------------------------------------------------------------------- import asyncio import os +import struct import logging import click from colors import color from bumble.device import Device, Peer from bumble.core import AdvertisingData -from bumble.gatt import Service, Characteristic +from bumble.gatt import Service, Characteristic, CharacteristicValue from bumble.utils import AsyncRunner from bumble.transport import open_transport_or_link from bumble.hci import HCI_Constant @@ -41,13 +42,59 @@ GG_PREFERRED_MTU = 256 # ----------------------------------------------------------------------------- -class GattlinkHubBridge(Device.Listener): +class GattlinkL2capEndpoint: def __init__(self): - self.peer = None - self.rx_socket = None - self.tx_socket = None - self.rx_characteristic = None - self.tx_characteristic = None + self.l2cap_channel = None + self.l2cap_packet = b'' + self.l2cap_packet_size = 0 + + # Called when an L2CAP SDU has been received + def on_coc_sdu(self, sdu): + print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan')) + while len(sdu): + if self.l2cap_packet_size == 0: + # Expect a new packet + self.l2cap_packet_size = sdu[0] + 1 + sdu = sdu[1:] + else: + bytes_needed = self.l2cap_packet_size - len(self.l2cap_packet) + chunk = min(bytes_needed, len(sdu)) + self.l2cap_packet += sdu[:chunk] + sdu = sdu[chunk:] + if len(self.l2cap_packet) == self.l2cap_packet_size: + self.on_l2cap_packet(self.l2cap_packet) + self.l2cap_packet = b'' + self.l2cap_packet_size = 0 + + +# ----------------------------------------------------------------------------- +class GattlinkHubBridge(GattlinkL2capEndpoint, Device.Listener): + def __init__(self, device, peer_address): + super().__init__() + self.device = device + self.peer_address = peer_address + self.peer = None + self.tx_socket = None + self.rx_characteristic = None + self.tx_characteristic = None + self.l2cap_psm_characteristic = None + + device.listener = self + + async def start(self): + # Connect to the peer + print(f'=== Connecting to {self.peer_address}...') + await self.device.connect(self.peer_address) + + async def connect_l2cap(self, psm): + print(color(f'### Connecting with L2CAP on PSM = {psm}', 'yellow')) + try: + self.l2cap_channel = await self.peer.connection.open_l2cap_channel(psm) + print(color('*** Connected', 'yellow'), self.l2cap_channel) + self.l2cap_channel.sink = self.on_coc_sdu + + except Exception as error: + print(color(f'!!! Connection failed: {error}', 'red')) @AsyncRunner.run_in_task() async def on_connection(self, connection): @@ -80,15 +127,24 @@ class GattlinkHubBridge(Device.Listener): self.rx_characteristic = characteristic elif characteristic.uuid == GG_GATTLINK_TX_CHARACTERISTIC_UUID: self.tx_characteristic = characteristic + elif characteristic.uuid == GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID: + self.l2cap_psm_characteristic = characteristic print('RX:', self.rx_characteristic) print('TX:', self.tx_characteristic) + print('PSM:', self.l2cap_psm_characteristic) - # Subscribe to TX - if self.tx_characteristic: + if self.l2cap_psm_characteristic: + # Subscribe to and then read the PSM value + await self.peer.subscribe(self.l2cap_psm_characteristic, self.on_l2cap_psm_received) + psm_bytes = await self.peer.read_value(self.l2cap_psm_characteristic) + psm = struct.unpack('>> [UDP]', 'magenta')) + self.tx_socket.sendto(packet) + # Called by the GATT client when a notification is received def on_tx_received(self, value): - print(color('>>> TX:', 'magenta'), value.hex()) + print(color(f'<<< [GATT TX]: {len(value)} bytes', 'cyan')) if self.tx_socket: + print(color('>>> [UDP]', 'magenta')) self.tx_socket.sendto(value) # Called by asyncio when the UDP socket is created - def connection_made(self, transport): - pass - - # Called by asyncio when a UDP datagram is received - def datagram_received(self, data, address): - print(color('<<< RX:', 'magenta'), data.hex()) - - # TODO: use a queue instead of creating a task everytime - if self.peer and self.rx_characteristic: - asyncio.create_task(self.peer.write_value(self.rx_characteristic, data)) - - -# ----------------------------------------------------------------------------- -class GattlinkNodeBridge(Device.Listener): - def __init__(self): - self.peer = None - self.rx_socket = None - self.tx_socket = None + def on_l2cap_psm_received(self, value): + psm = struct.unpack('>> [L2CAP]', 'yellow')) + self.l2cap_channel.write(bytes([len(data) - 1]) + data) + elif self.peer and self.rx_characteristic: + print(color('>>> [GATT RX]', 'yellow')) asyncio.create_task(self.peer.write_value(self.rx_characteristic, data)) # ----------------------------------------------------------------------------- -async def run(hci_transport, device_address, send_host, send_port, receive_host, receive_port): +class GattlinkNodeBridge(GattlinkL2capEndpoint, Device.Listener): + def __init__(self, device): + super().__init__() + self.device = device + self.peer = None + self.tx_socket = None + self.tx_subscriber = None + self.rx_characteristic = None + + # Register as a listener + device.listener = self + + # Listen for incoming L2CAP CoC connections + psm = 0xFB + device.register_l2cap_channel_server(0xFB, self.on_coc) + print(f'### Listening for CoC connection on PSM {psm}') + + # Setup the Gattlink service + self.rx_characteristic = Characteristic( + GG_GATTLINK_RX_CHARACTERISTIC_UUID, + Characteristic.WRITE_WITHOUT_RESPONSE, + Characteristic.WRITEABLE, + CharacteristicValue(write=self.on_rx_write) + ) + self.tx_characteristic = Characteristic( + GG_GATTLINK_TX_CHARACTERISTIC_UUID, + Characteristic.NOTIFY, + Characteristic.READABLE + ) + self.tx_characteristic.on('subscription', self.on_tx_subscription) + self.psm_characteristic = Characteristic( + GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID, + Characteristic.READ | Characteristic.NOTIFY, + Characteristic.READABLE, + bytes([psm, 0]) + ) + gattlink_service = Service( + GG_GATTLINK_SERVICE_UUID, + [ + self.rx_characteristic, + self.tx_characteristic, + self.psm_characteristic + ] + ) + device.add_services([gattlink_service]) + device.advertising_data = bytes( + AdvertisingData([ + (AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')), + (AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, + bytes(reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8')))) + ]) + ) + + async def start(self): + await self.device.start_advertising() + + # Called by asyncio when the UDP socket is created + def connection_made(self, transport): + self.transport = transport + + # Called by asyncio when a UDP datagram is received + def datagram_received(self, data, address): + print(color(f'<<< [UDP]: {len(data)} bytes', 'green')) + + if self.l2cap_channel: + print(color('>>> [L2CAP]', 'yellow')) + self.l2cap_channel.write(bytes([len(data) - 1]) + data) + elif self.tx_subscriber: + print(color('>>> [GATT TX]', 'yellow')) + self.tx_characteristic.value = data + asyncio.create_task(self.device.notify_subscribers(self.tx_characteristic)) + + # Called when a write to the RX characteristic has been received + def on_rx_write(self, connection, data): + print(color(f'<<< [GATT RX]: {len(data)} bytes', 'cyan')) + print(color('>>> [UDP]', 'magenta')) + self.tx_socket.sendto(data) + + # Called when the subscription to the TX characteristic has changed + def on_tx_subscription(self, peer, enabled): + print(f'### [GATT TX] subscription from {peer}: {"enabled" if enabled else "disabled"}') + if enabled: + self.tx_subscriber = peer + else: + self.tx_subscriber = None + + # Called when an L2CAP packet is received + def on_l2cap_packet(self, packet): + print(color(f'<<< [L2CAP PACKET]: {len(packet)} bytes', 'cyan')) + print(color('>>> [UDP]', 'magenta')) + self.tx_socket.sendto(packet) + + # Called when a new connection is established + def on_coc(self, channel): + print('*** CoC Connection', channel) + self.l2cap_channel = channel + channel.sink = self.on_coc_sdu + + +# ----------------------------------------------------------------------------- +async def run(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port): print('<<< connecting to HCI...') async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink): print('<<< connected') # Instantiate a bridge object - bridge = GattlinkNodeBridge() + device = Device.with_hci('Bumble GG', device_address, hci_source, hci_sink) + + # Instantiate a bridge object + if role_or_peer_address == 'node': + bridge = GattlinkNodeBridge(device) + else: + bridge = GattlinkHubBridge(device, role_or_peer_address) # Create a UDP to RX bridge (receive from UDP, send to RX) loop = asyncio.get_running_loop() @@ -160,35 +317,8 @@ async def run(hci_transport, device_address, send_host, send_port, receive_host, remote_addr=(send_host, send_port) ) - # Create a device to manage the host, with a custom listener - device = Device.with_hci('Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink) - device.listener = bridge await device.power_on() - - # Connect to the peer - # print(f'=== Connecting to {device_address}...') - # await device.connect(device_address) - - # TODO move to class - gattlink_service = Service( - GG_GATTLINK_SERVICE_UUID, - [ - Characteristic( - GG_GATTLINK_L2CAP_CHANNEL_PSM_CHARACTERISTIC_UUID, - Characteristic.READ, - Characteristic.READABLE, - bytes([193, 0]) - ) - ] - ) - device.add_services([gattlink_service]) - device.advertising_data = bytes( - AdvertisingData([ - (AdvertisingData.COMPLETE_LOCAL_NAME, bytes('Bumble GG', 'utf-8')), - (AdvertisingData.INCOMPLETE_LIST_OF_128_BIT_SERVICE_CLASS_UUIDS, bytes(reversed(bytes.fromhex('ABBAFF00E56A484CB8328B17CF6CBFE8')))) - ]) - ) - await device.start_advertising() + await bridge.start() # Wait until the source terminates await hci_source.wait_for_termination() @@ -197,15 +327,16 @@ async def run(hci_transport, device_address, send_host, send_port, receive_host, @click.command() @click.argument('hci_transport') @click.argument('device_address') +@click.argument('role_or_peer_address') @click.option('-sh', '--send-host', type=str, default='127.0.0.1', help='UDP host to send to') @click.option('-sp', '--send-port', type=int, default=9001, help='UDP port to send to') @click.option('-rh', '--receive-host', type=str, default='127.0.0.1', help='UDP host to receive on') @click.option('-rp', '--receive-port', type=int, default=9000, help='UDP port to receive on') -def main(hci_transport, device_address, send_host, send_port, receive_host, receive_port): - logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) - asyncio.run(run(hci_transport, device_address, send_host, send_port, receive_host, receive_port)) +def main(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port): + asyncio.run(run(hci_transport, device_address, role_or_peer_address, send_host, send_port, receive_host, receive_port)) # ----------------------------------------------------------------------------- +logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) if __name__ == '__main__': main() diff --git a/apps/l2cap_bridge.py b/apps/l2cap_bridge.py new file mode 100644 index 00000000..ba658c21 --- /dev/null +++ b/apps/l2cap_bridge.py @@ -0,0 +1,331 @@ +# Copyright 2021-2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import click +import logging +import os +from colors import color + +from bumble.transport import open_transport_or_link +from bumble.device import Device +from bumble.utils import FlowControlAsyncPipe +from bumble.hci import HCI_Constant + + +# ----------------------------------------------------------------------------- +class ServerBridge: + """ + L2CAP CoC server bridge: waits for a peer to connect an L2CAP CoC channel + on a specified PSM. When the connection is made, the bridge connects a TCP + socket to a remote host and bridges the data in both directions, with flow + control. + When the L2CAP CoC channel is closed, the bridge disconnects the TCP socket + and waits for a new L2CAP CoC channel to be connected. + When the TCP connection is closed by the TCP server, XXXX + """ + def __init__( + self, + psm, + max_credits, + mtu, + mps, + tcp_host, + tcp_port + ): + self.psm = psm + self.max_credits = max_credits + self.mtu = mtu + self.mps = mps + self.tcp_host = tcp_host + self.tcp_port = tcp_port + + async def start(self, device): + # Listen for incoming L2CAP CoC connections + device.register_l2cap_channel_server( + psm = self.psm, + server = self.on_coc, + max_credits = self.max_credits, + mtu = self.mtu, + mps = self.mps + ) + print(color(f'### Listening for CoC connection on PSM {self.psm}', 'yellow')) + + def on_ble_connection(connection): + def on_ble_disconnection(reason): + print(color('@@@ Bluetooth disconnection:', 'red'), HCI_Constant.error_name(reason)) + + print(color('@@@ Bluetooth connection:', 'green'), connection) + connection.on('disconnection', on_ble_disconnection) + + device.on('connection', on_ble_connection) + + await device.start_advertising(auto_restart=True) + + # Called when a new L2CAP connection is established + def on_coc(self, l2cap_channel): + print(color('*** L2CAP channel:', 'cyan'), l2cap_channel) + + class Pipe: + def __init__(self, bridge, l2cap_channel): + self.bridge = bridge + self.tcp_transport = None + self.l2cap_channel = l2cap_channel + + l2cap_channel.on('close', self.on_l2cap_close) + l2cap_channel.sink = self.on_coc_sdu + + async def connect_to_tcp(self): + # Connect to the TCP server + print(color(f'### Connecting to TCP {self.bridge.tcp_host}:{self.bridge.tcp_port}...', 'yellow')) + + class TcpClientProtocol(asyncio.Protocol): + def __init__(self, pipe): + self.pipe = pipe + + def connection_lost(self, error): + print(color(f'!!! TCP connection lost: {error}', 'red')) + if self.pipe.l2cap_channel is not None: + asyncio.create_task(self.pipe.l2cap_channel.disconnect()) + + def data_received(self, data): + print(f'<<< Received on TCP: {len(data)}') + self.pipe.l2cap_channel.write(data) + + try: + self.tcp_transport, _ = await asyncio.get_running_loop().create_connection( + lambda: TcpClientProtocol(self), + host=self.bridge.tcp_host, + port=self.bridge.tcp_port, + ) + print(color('### Connected', 'green')) + except Exception as error: + print(color(f'!!! Connection failed: {error}', 'red')) + await self.l2cap_channel.disconnect() + + def on_l2cap_close(self): + self.l2cap_channel = None + if self.tcp_transport is not None: + self.tcp_transport.close() + + def on_coc_sdu(self, sdu): + print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan')) + if self.tcp_transport is None: + print(color('!!! TCP socket not open, dropping', 'red')) + return + self.tcp_transport.write(sdu) + + pipe = Pipe(self, l2cap_channel) + + asyncio.create_task(pipe.connect_to_tcp()) + + +# ----------------------------------------------------------------------------- +class ClientBridge: + """ + L2CAP CoC client bridge: connects to a BLE device, then waits for an inbound + TCP connection on a specified port number. When a TCP client connects, an + L2CAP CoC channel connection to the BLE device is established, and the data + is bridged in both directions, with flow control. + When the TCP connection is closed by the client, the L2CAP CoC channel is + disconnected, but the connection to the BLE device remains, ready for a new + TCP client to connect. + When the L2CAP CoC channel is closed, XXXX + """ + + READ_CHUNK_SIZE = 4096 + + def __init__( + self, + psm, + max_credits, + mtu, + mps, + address, + tcp_host, + tcp_port + ): + self.psm = psm + self.max_credits = max_credits + self.mtu = mtu + self.mps = mps + self.address = address + self.tcp_host = tcp_host + self.tcp_port = tcp_port + + async def start(self, device): + print(color(f'### Connecting to {self.address}...', 'yellow')) + connection = await device.connect(self.address) + print(color('### Connected', 'green')) + + # Called when the BLE connection is disconnected + def on_ble_disconnection(reason): + print(color('@@@ Bluetooth disconnection:', 'red'), HCI_Constant.error_name(reason)) + + connection.on('disconnection', on_ble_disconnection) + + # Called when a TCP connection is established + async def on_tcp_connection(reader, writer): + peername = writer.get_extra_info('peername') + print(color(f'<<< TCP connection from {peername}', 'magenta')) + + def on_coc_sdu(sdu): + print(color(f'<<< [L2CAP SDU]: {len(sdu)} bytes', 'cyan')) + l2cap_to_tcp_pipe.write(sdu) + + def on_l2cap_close(): + print(color('*** L2CAP channel closed', 'red')) + l2cap_to_tcp_pipe.stop() + writer.close() + + # Connect a new L2CAP channel + print(color(f'>>> Opening L2CAP channel on PSM = {self.psm}', 'yellow')) + try: + l2cap_channel = await connection.open_l2cap_channel( + psm = self.psm, + max_credits = self.max_credits, + mtu = self.mtu, + mps = self.mps + ) + print(color('*** L2CAP channel:', 'cyan'), l2cap_channel) + except Exception as error: + print(color(f'!!! Connection failed: {error}', 'red')) + writer.close() + return + + l2cap_channel.sink = on_coc_sdu + l2cap_channel.on('close', on_l2cap_close) + + # Start a flow control pipe from L2CAP to TCP + l2cap_to_tcp_pipe = FlowControlAsyncPipe( + l2cap_channel.pause_reading, + l2cap_channel.resume_reading, + writer.write, + writer.drain + ) + l2cap_to_tcp_pipe.start() + + # Pipe data from TCP to L2CAP + while True: + try: + data = await reader.read(self.READ_CHUNK_SIZE) + + if len(data) == 0: + print(color('!!! End of stream', 'red')) + await l2cap_channel.disconnect() + return + + print(color(f'<<< [TCP DATA]: {len(data)} bytes', 'blue')) + l2cap_channel.write(data) + await l2cap_channel.drain() + except Exception as error: + print(f'!!! Exception: {error}') + break + + writer.close() + print(color('~~~ Bye bye', 'magenta')) + + await asyncio.start_server( + on_tcp_connection, + host=self.tcp_host if self.tcp_host != '_' else None, + port=self.tcp_port + ) + print(color(f'### Listening for TCP connections on port {self.tcp_port}', 'magenta')) + + +# ----------------------------------------------------------------------------- +async def run(device_config, hci_transport, bridge): + print('<<< connecting to HCI...') + async with await open_transport_or_link(hci_transport) as (hci_source, hci_sink): + print('<<< connected') + + device = Device.from_config_file_with_hci(device_config, hci_source, hci_sink) + + # Let's go + await device.power_on() + await bridge.start(device) + + # Wait until the transport terminates + await hci_source.wait_for_termination() + + +# ----------------------------------------------------------------------------- +@click.group() +@click.pass_context +@click.option('--device-config', help='Device configuration file', required=True) +@click.option('--hci-transport', help='HCI transport', required=True) +@click.option('--psm', help='PSM for L2CAP CoC', type=int, default=1234) +@click.option('--l2cap-coc-max-credits', help='Maximum L2CAP CoC Credits', type=click.IntRange(1, 65535), default=128) +@click.option('--l2cap-coc-mtu', help='L2CAP CoC MTU', type=click.IntRange(23, 65535), default=1022) +@click.option('--l2cap-coc-mps', help='L2CAP CoC MPS', type=click.IntRange(23, 65533), default=1024) +def cli(context, device_config, hci_transport, psm, l2cap_coc_max_credits, l2cap_coc_mtu, l2cap_coc_mps): + context.ensure_object(dict) + context.obj['device_config'] = device_config + context.obj['hci_transport'] = hci_transport + context.obj['psm'] = psm + context.obj['max_credits'] = l2cap_coc_max_credits + context.obj['mtu'] = l2cap_coc_mtu + context.obj['mps'] = l2cap_coc_mps + + +# ----------------------------------------------------------------------------- +@cli.command() +@click.pass_context +@click.option('--tcp-host', help='TCP host', default='localhost') +@click.option('--tcp-port', help='TCP port', default=9544) +def server(context, tcp_host, tcp_port): + bridge = ServerBridge( + context.obj['psm'], + context.obj['max_credits'], + context.obj['mtu'], + context.obj['mps'], + tcp_host, + tcp_port) + asyncio.run(run( + context.obj['device_config'], + context.obj['hci_transport'], + bridge + )) + + +# ----------------------------------------------------------------------------- +@cli.command() +@click.pass_context +@click.argument('bluetooth-address') +@click.option('--tcp-host', help='TCP host', default='_') +@click.option('--tcp-port', help='TCP port', default=9543) +def client(context, bluetooth_address, tcp_host, tcp_port): + bridge = ClientBridge( + context.obj['psm'], + context.obj['max_credits'], + context.obj['mtu'], + context.obj['mps'], + bluetooth_address, + tcp_host, + tcp_port + ) + asyncio.run(run( + context.obj['device_config'], + context.obj['hci_transport'], + bridge + )) + + +# ----------------------------------------------------------------------------- +logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) +if __name__ == '__main__': + cli(obj={}) diff --git a/bumble/avdtp.py b/bumble/avdtp.py index 759e38c9..7fe4fbb6 100644 --- a/bumble/avdtp.py +++ b/bumble/avdtp.py @@ -351,7 +351,7 @@ class MediaPacketPump: logger.debug('pump canceled') # Pump packets - self.pump_task = asyncio.get_running_loop().create_task(pump_packets()) + self.pump_task = asyncio.create_task(pump_packets()) async def stop(self): # Stop the pump @@ -1890,10 +1890,10 @@ class LocalSource(LocalStreamEndPoint, EventEmitter): self.configuration = configuration def on_start_command(self): - asyncio.get_running_loop().create_task(self.start()) + asyncio.create_task(self.start()) def on_suspend_command(self): - asyncio.get_running_loop().create_task(self.stop()) + asyncio.create_task(self.stop()) # ----------------------------------------------------------------------------- diff --git a/bumble/device.py b/bumble/device.py index b12fad5d..50801c7b 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -43,6 +43,13 @@ logger = logging.getLogger(__name__) # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- +DEVICE_MIN_SCAN_INTERVAL = 25 +DEVICE_MAX_SCAN_INTERVAL = 10240 +DEVICE_MIN_SCAN_WINDOW = 25 +DEVICE_MAX_SCAN_WINDOW = 10240 +DEVICE_MIN_LE_RSSI = -127 +DEVICE_MAX_LE_RSSI = 20 + DEVICE_DEFAULT_ADDRESS = '00:00:00:00:00:00' DEVICE_DEFAULT_ADVERTISING_INTERVAL = 1000 # ms DEVICE_DEFAULT_ADVERTISING_DATA = '' @@ -62,20 +69,15 @@ DEVICE_DEFAULT_CONNECTION_MAX_LATENCY = 0 DEVICE_DEFAULT_CONNECTION_SUPERVISION_TIMEOUT = 720 # ms DEVICE_DEFAULT_CONNECTION_MIN_CE_LENGTH = 0 # ms DEVICE_DEFAULT_CONNECTION_MAX_CE_LENGTH = 0 # ms - -DEVICE_MIN_SCAN_INTERVAL = 25 -DEVICE_MAX_SCAN_INTERVAL = 10240 -DEVICE_MIN_SCAN_WINDOW = 25 -DEVICE_MAX_SCAN_WINDOW = 10240 -DEVICE_MIN_LE_RSSI = -127 -DEVICE_MAX_LE_RSSI = 20 +DEVICE_DEFAULT_L2CAP_COC_MTU = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU +DEVICE_DEFAULT_L2CAP_COC_MPS = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS +DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS # ----------------------------------------------------------------------------- # Classes # ----------------------------------------------------------------------------- - # ----------------------------------------------------------------------------- class Advertisement: TX_POWER_NOT_AVAILABLE = HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE @@ -429,7 +431,16 @@ class Connection(CompositeEventEmitter): def create_l2cap_connector(self, psm): return self.device.create_l2cap_connector(self, psm) - async def disconnect(self, reason = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR): + async def open_l2cap_channel( + self, + psm, + max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS, + mtu=DEVICE_DEFAULT_L2CAP_COC_MTU, + mps=DEVICE_DEFAULT_L2CAP_COC_MPS + ): + return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps) + + async def disconnect(self, reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR): return await self.device.disconnect(self, reason) async def pair(self): @@ -563,6 +574,7 @@ class DeviceConfiguration: with open(filename, 'r') as file: self.load_from_dict(json.load(file)) + # ----------------------------------------------------------------------------- # Decorators used with the following Device class # (we define them outside of the Device class, because defining decorators @@ -685,7 +697,7 @@ class Device(CompositeEventEmitter): self.classic_enabled = False self.inquiry_response = None self.address_resolver = None - self.classic_pending_accepts = { Address.ANY: [] } # Futures, by BD address OR [Futures] for Address.ANY + self.classic_pending_accepts = {Address.ANY: []} # Futures, by BD address OR [Futures] for Address.ANY # Use the initial config or a default self.public_address = Address('00:00:00:00:00:00') @@ -785,15 +797,35 @@ class Device(CompositeEventEmitter): if transport is None or connection.transport == transport: return connection - def register_l2cap_server(self, psm, server): - self.l2cap_channel_manager.register_server(psm, server) - def create_l2cap_connector(self, connection, psm): return lambda: self.l2cap_channel_manager.connect(connection, psm) def create_l2cap_registrar(self, psm): return lambda handler: self.register_l2cap_server(psm, handler) + def register_l2cap_server(self, psm, server): + self.l2cap_channel_manager.register_server(psm, server) + + def register_l2cap_channel_server( + self, + psm, + server, + max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS, + mtu=DEVICE_DEFAULT_L2CAP_COC_MTU, + mps=DEVICE_DEFAULT_L2CAP_COC_MPS + ): + return self.l2cap_channel_manager.register_le_coc_server(psm, server, max_credits, mtu, mps) + + async def open_l2cap_channel( + self, + connection, + psm, + max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS, + mtu=DEVICE_DEFAULT_L2CAP_COC_MTU, + mps=DEVICE_DEFAULT_L2CAP_COC_MPS + ): + return await self.l2cap_channel_manager.open_le_coc(connection, psm, max_credits, mtu, mps) + def send_l2cap_pdu(self, connection_handle, cid, pdu): self.host.send_l2cap_pdu(connection_handle, cid, pdu) @@ -1185,13 +1217,15 @@ class Device(CompositeEventEmitter): def on_connection(connection): if transport == BT_LE_TRANSPORT or ( # match BR/EDR connection event against peer address - connection.transport == transport and connection.peer_address == peer_address): + connection.transport == transport and connection.peer_address == peer_address + ): pending_connection.set_result(connection) def on_connection_failure(error): if transport == BT_LE_TRANSPORT or ( # match BR/EDR connection failure event against peer address - error.transport == transport and error.peer_address == peer_address): + error.transport == transport and error.peer_address == peer_address + ): pending_connection.set_exception(error) # Create a future so that we can wait for the connection's result @@ -1336,7 +1370,7 @@ class Device(CompositeEventEmitter): if peer_address == Address.NIL: raise ValueError('accept on nil address') - # Create a future so that we can wait for the request + # Create a future so that we can wait for the request pending_request = asyncio.get_running_loop().create_future() if peer_address == Address.ANY: @@ -1349,8 +1383,7 @@ class Device(CompositeEventEmitter): try: # Wait for a request or a completed connection result = await (asyncio.wait_for(pending_request, timeout) if timeout else pending_request) - - except: + except Exception: # Remove future from device context if peer_address == Address.ANY: self.classic_pending_accepts[Address.ANY].remove(pending_request) @@ -1710,26 +1743,32 @@ class Device(CompositeEventEmitter): connection.remove_listener('connection_encryption_failure', on_encryption_failure) # [Classic only] - async def request_remote_name(self, remote: Connection | Address): + async def request_remote_name(self, remote): # remote: Connection | Address # Set up event handlers pending_name = asyncio.get_running_loop().create_future() if type(remote) == Address: peer_address = remote - handler = self.on('remote_name', + handler = self.on( + 'remote_name', lambda address, remote_name: - pending_name.set_result(remote_name) if address == remote else None) - failure_handler = self.on('remote_name_failure', + pending_name.set_result(remote_name) if address == remote else None + ) + failure_handler = self.on( + 'remote_name_failure', lambda address, error_code: - pending_name.set_exception(HCI_Error(error_code)) if address == remote else None) + pending_name.set_exception(HCI_Error(error_code)) if address == remote else None + ) else: peer_address = remote.peer_address - handler = remote.on('remote_name', - lambda: - pending_name.set_result(remote.peer_name)) - failure_handler = remote.on('remote_name_failure', - lambda error_code: - pending_name.set_exception(HCI_Error(error_code))) + handler = remote.on( + 'remote_name', + lambda: pending_name.set_result(remote.peer_name) + ) + failure_handler = remote.on( + 'remote_name_failure', + lambda error_code: pending_name.set_exception(HCI_Error(error_code)) + ) try: result = await self.send_command( @@ -2097,7 +2136,6 @@ class Device(CompositeEventEmitter): else: self.emit('remote_name_failure', address, error) - # [Classic only] @host_event_handler @try_with_connection_from_address diff --git a/bumble/gatt.py b/bumble/gatt.py index 0847652f..cc66329e 100644 --- a/bumble/gatt.py +++ b/bumble/gatt.py @@ -25,6 +25,7 @@ import asyncio import types import logging +from pyee import EventEmitter from colors import color from .core import * diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py index c1c52764..07116fa0 100644 --- a/bumble/gatt_client.py +++ b/bumble/gatt_client.py @@ -273,7 +273,7 @@ class Client: if response.op_code == ATT_ERROR_RESPONSE: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: # Unexpected end - logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}') + logger.warning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}') # TODO raise appropriate exception return break @@ -337,7 +337,7 @@ class Client: if response.op_code == ATT_ERROR_RESPONSE: if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR: # Unexpected end - logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}') + logger.warning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}') # TODO raise appropriate exception return break diff --git a/bumble/gatt_server.py b/bumble/gatt_server.py index 656df6b7..c5d8dc17 100644 --- a/bumble/gatt_server.py +++ b/bumble/gatt_server.py @@ -155,7 +155,7 @@ class Server(EventEmitter): return cccd or bytes([0, 0]) def write_cccd(self, connection, characteristic, value): - logger.debug(f'Subscription update for connection={connection.handle:04X}, handle={characteristic.handle:04X}: {value.hex()}') + logger.debug(f'Subscription update for connection=0x{connection.handle:04X}, handle=0x{characteristic.handle:04X}: {value.hex()}') # Sanity check if len(value) != 2: @@ -204,7 +204,7 @@ class Server(EventEmitter): logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}') self.send_gatt_pdu(connection.handle, bytes(notification)) - async def indicate_subscriber(self, connection, attribute, value=None, force=False): + async def indicate_subscriber(self, connection, attribute, force=False): # Check if there's a subscriber if not force: subscribers = self.subscribers.get(connection.handle) diff --git a/bumble/hci.py b/bumble/hci.py index d4cf7cc0..cf9f682c 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -2466,9 +2466,10 @@ class HCI_Write_Voice_Setting_Command(HCI_Command): # ----------------------------------------------------------------------------- +@HCI_Command.command() class HCI_Read_Synchronous_Flow_Control_Enable_Command(HCI_Command): ''' - See Bluetooth spec @ 7.3.36 Write Synchronous Flow Control Enable Command + See Bluetooth spec @ 7.3.36 Read Synchronous Flow Control Enable Command ''' diff --git a/bumble/host.py b/bumble/host.py index 32b21946..a57299b0 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -79,6 +79,8 @@ class Host(EventEmitter): self.local_version = None self.local_supported_commands = bytes(64) self.local_le_features = 0 + self.suggested_max_tx_octets = 251 # Max allowed + self.suggested_max_tx_time = 2120 # Max allowed self.command_semaphore = asyncio.Semaphore(1) self.long_term_key_provider = None self.link_key_provider = None @@ -138,6 +140,22 @@ class Host(EventEmitter): f'hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}' ) + if ( + self.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND) and + self.supports_command(HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND) + ): + response = await self.send_command(HCI_LE_Read_Suggested_Default_Data_Length_Command()) + suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets + suggested_max_tx_time = response.return_parameters.suggested_max_tx_time + if ( + suggested_max_tx_octets != self.suggested_max_tx_octets or + suggested_max_tx_time != self.suggested_max_tx_time + ): + await self.send_command(HCI_LE_Write_Suggested_Default_Data_Length_Command( + suggested_max_tx_octets = self.suggested_max_tx_octets, + suggested_max_tx_time = self.suggested_max_tx_time + )) + self.reset_done = True @property diff --git a/bumble/l2cap.py b/bumble/l2cap.py index 927454e8..78ec0ec4 100644 --- a/bumble/l2cap.py +++ b/bumble/l2cap.py @@ -19,6 +19,7 @@ import asyncio import logging import struct +from collections import deque from colors import color from pyee import EventEmitter @@ -43,13 +44,23 @@ L2CAP_MIN_BR_EDR_MTU = 48 L2CAP_DEFAULT_MTU = 2048 # Default value for the MTU we are willing to accept +L2CAP_DEFAULT_CONNECTIONLESS_MTU = 1024 + # See Bluetooth spec @ Vol 3, Part A - Table 2.1: CID name space on ACL-U, ASB-U, and AMP-U logical links L2CAP_ACL_U_DYNAMIC_CID_RANGE_START = 0x0040 L2CAP_ACL_U_DYNAMIC_CID_RANGE_END = 0xFFFF # See Bluetooth spec @ Vol 3, Part A - Table 2.2: CID name space on LE-U logical link L2CAP_LE_U_DYNAMIC_CID_RANGE_START = 0x0040 -L2CAP_LE_U_DYNAMIC_CID_RANGE_START = 0x007F +L2CAP_LE_U_DYNAMIC_CID_RANGE_END = 0x007F + +# PSM Range - See Bluetooth spec @ Vol 3, Part A / Table 4.5: PSM ranges and usage +L2CAP_PSM_DYNAMIC_RANGE_START = 0x1001 +L2CAP_PSM_DYNAMIC_RANGE_END = 0xFFFF + +# LE PSM Ranges - See Bluetooth spec @ Vol 3, Part A / Table 4.19: LE Credit Based Connection Request LE_PSM ranges +L2CAP_LE_PSM_DYNAMIC_RANGE_START = 0x0080 +L2CAP_LE_PSM_DYNAMIC_RANGE_END = 0x00FF # Frame types L2CAP_COMMAND_REJECT = 0x01 @@ -107,8 +118,13 @@ L2CAP_COMMAND_NOT_UNDERSTOOD_REASON = 0x0000 L2CAP_SIGNALING_MTU_EXCEEDED_REASON = 0x0001 L2CAP_INVALID_CID_IN_REQUEST_REASON = 0x0002 -L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2048 -L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048 +L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS = 65535 +L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU = 23 +L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS = 23 +L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS = 65533 +L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU = 2046 +L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS = 2048 +L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS = 256 L2CAP_MAXIMUM_TRANSMISSION_UNIT_CONFIGURATION_OPTION_TYPE = 0x01 @@ -172,7 +188,7 @@ class L2CAP_Control_Frame: self.identifier = pdu[1] length = struct.unpack_from('= 2: - type = data[0] + type = data[0] length = data[1] - value = data[2:2 + length] - data = data[2 + length:] + value = data[2:2 + length] + data = data[2 + length:] options.append((type, value)) return options @@ -268,7 +284,10 @@ class L2CAP_Command_Reject(L2CAP_Control_Frame): # ----------------------------------------------------------------------------- @L2CAP_Control_Frame.subclass([ - ('psm', 2), + ('psm', { + 'parser': lambda data, offset: L2CAP_Connection_Request.parse_psm(data, offset), + 'serializer': lambda value: L2CAP_Connection_Request.serialize_psm(value) + }), ('source_cid', 2) ]) class L2CAP_Connection_Request(L2CAP_Control_Frame): @@ -276,6 +295,28 @@ class L2CAP_Connection_Request(L2CAP_Control_Frame): See Bluetooth spec @ Vol 3, Part A - 4.2 CONNECTION REQUEST ''' + @staticmethod + def parse_psm(data, offset=0): + psm_length = 2 + psm = data[offset] | data[offset + 1] << 8 + + # The PSM field extends until the first even octet (inclusive) + while data[offset + psm_length - 1] % 2 == 1: + psm |= data[offset + psm_length] << (8 * psm_length) + psm_length += 1 + + return offset + psm_length, psm + + @staticmethod + def serialize_psm(psm): + serialized = struct.pack('>= 16 + while psm: + serialized += bytes([psm & 0xFF]) + psm >>= 8 + + return serialized + # ----------------------------------------------------------------------------- @L2CAP_Control_Frame.subclass([ @@ -289,16 +330,16 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame): See Bluetooth spec @ Vol 3, Part A - 4.3 CONNECTION RESPONSE ''' - CONNECTION_SUCCESSFUL = 0x0000 - CONNECTION_PENDING = 0x0001 - CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED = 0x0002 - CONNECTION_REFUSED_SECURITY_BLOCK = 0x0003 - CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE = 0x0004 - CONNECTION_REFUSED_INVALID_SOURCE_CID = 0x0006 - CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED = 0x0007 - CONNECTION_REFUSED_UNACCEPTABLE_PARAMETERS = 0x000B + CONNECTION_SUCCESSFUL = 0x0000 + CONNECTION_PENDING = 0x0001 + CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED = 0x0002 + CONNECTION_REFUSED_SECURITY_BLOCK = 0x0003 + CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE = 0x0004 + CONNECTION_REFUSED_INVALID_SOURCE_CID = 0x0006 + CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED = 0x0007 + CONNECTION_REFUSED_UNACCEPTABLE_PARAMETERS = 0x000B - CONNECTION_RESULT_NAMES = { + RESULT_NAMES = { CONNECTION_SUCCESSFUL: 'CONNECTION_SUCCESSFUL', CONNECTION_PENDING: 'CONNECTION_PENDING', CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED: 'CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED', @@ -311,7 +352,7 @@ class L2CAP_Connection_Response(L2CAP_Control_Frame): @staticmethod def result_name(result): - return name_or_number(L2CAP_Connection_Response.CONNECTION_RESULT_NAMES, result) + return name_or_number(L2CAP_Connection_Response.RESULT_NAMES, result) # ----------------------------------------------------------------------------- @@ -485,10 +526,10 @@ class L2CAP_Connection_Parameter_Update_Response(L2CAP_Control_Frame): # ----------------------------------------------------------------------------- @L2CAP_Control_Frame.subclass([ - ('le_psm', 2), - ('source_cid', 2), - ('mtu', 2), - ('mps', 2), + ('le_psm', 2), + ('source_cid', 2), + ('mtu', 2), + ('mps', 2), ('initial_credits', 2) ]) class L2CAP_LE_Credit_Based_Connection_Request(L2CAP_Control_Frame): @@ -521,7 +562,7 @@ class L2CAP_LE_Credit_Based_Connection_Response(L2CAP_Control_Frame): CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED = 0x000A CONNECTION_REFUSED_UNACCEPTABLE_PARAMETERS = 0x000B - CONNECTION_RESULT_NAMES = { + RESULT_NAMES = { CONNECTION_SUCCESSFUL: 'CONNECTION_SUCCESSFUL', CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED: 'CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED', CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE: 'CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE', @@ -536,12 +577,12 @@ class L2CAP_LE_Credit_Based_Connection_Response(L2CAP_Control_Frame): @staticmethod def result_name(result): - return name_or_number(L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_RESULT_NAMES, result) + return name_or_number(L2CAP_LE_Credit_Based_Connection_Response.RESULT_NAMES, result) # ----------------------------------------------------------------------------- @L2CAP_Control_Frame.subclass([ - ('cid', 2), + ('cid', 2), ('credits', 2) ]) class L2CAP_LE_Flow_Control_Credit(L2CAP_Control_Frame): @@ -619,6 +660,9 @@ class Channel(EventEmitter): def send_pdu(self, pdu): self.manager.send_pdu(self.connection, self.destination_cid, pdu) + def send_control_frame(self, frame): + self.manager.send_control_frame(self.connection, self.signaling_cid, frame) + async def send_request(self, request): # Check that there isn't already a request pending if self.response: @@ -637,15 +681,16 @@ class Channel(EventEmitter): elif self.sink: self.sink(pdu) else: - logger.warn(color('received pdu without a pending request or sink', 'red')) - - def send_control_frame(self, frame): - self.manager.send_control_frame(self.connection, self.signaling_cid, frame) + logger.warning(color('received pdu without a pending request or sink', 'red')) async def connect(self): if self.state != Channel.CLOSED: raise InvalidStateError('invalid state') + # Check that we can start a new connection + if self.connection_result: + raise RuntimeError('connection already pending') + self.change_state(Channel.WAIT_CONNECT_RSP) self.send_control_frame( L2CAP_Connection_Request( @@ -657,7 +702,12 @@ class Channel(EventEmitter): # Create a future to wait for the state machine to get to a success or error state self.connection_result = asyncio.get_running_loop().create_future() - return await self.connection_result + + # Wait for the connection to succeed or fail + try: + return await self.connection_result + finally: + self.connection_result = None async def disconnect(self): if self.state != Channel.OPEN: @@ -708,7 +758,7 @@ class Channel(EventEmitter): def on_connection_response(self, response): if self.state != Channel.WAIT_CONNECT_RSP: - logger.warn(color('invalid state', 'red')) + logger.warning(color('invalid state', 'red')) return if response.result == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL: @@ -734,7 +784,7 @@ class Channel(EventEmitter): self.state != Channel.WAIT_CONFIG_REQ and self.state != Channel.WAIT_CONFIG_REQ_RSP ): - logger.warn(color('invalid state', 'red')) + logger.warning(color('invalid state', 'red')) return # Decode the options @@ -750,7 +800,7 @@ class Channel(EventEmitter): source_cid = self.destination_cid, flags = 0x0000, result = L2CAP_Configure_Response.SUCCESS, - options = request.options # TODO: don't accept everthing blindly + options = request.options # TODO: don't accept everything blindly ) ) if self.state == Channel.WAIT_CONFIG: @@ -777,7 +827,7 @@ class Channel(EventEmitter): self.connection_result = None self.emit('open') else: - logger.warn(color('invalid state', 'red')) + logger.warning(color('invalid state', 'red')) elif response.result == L2CAP_Configure_Response.FAILURE_UNACCEPTABLE_PARAMETERS: # Re-configure with what's suggested in the response self.send_control_frame( @@ -789,7 +839,7 @@ class Channel(EventEmitter): ) ) else: - logger.warn(color(f'!!! configuration rejected: {L2CAP_Configure_Response.result_name(response.result)}', 'red')) + logger.warning(color(f'!!! configuration rejected: {L2CAP_Configure_Response.result_name(response.result)}', 'red')) # TODO: decide how to fail gracefully def on_disconnection_request(self, request): @@ -805,15 +855,15 @@ class Channel(EventEmitter): self.emit('close') self.manager.on_channel_closed(self) else: - logger.warn(color('invalid state', 'red')) + logger.warning(color('invalid state', 'red')) def on_disconnection_response(self, response): if self.state != Channel.WAIT_DISCONNECT: - logger.warn(color('invalid state', 'red')) + logger.warning(color('invalid state', 'red')) return if response.destination_cid != self.destination_cid or response.source_cid != self.source_cid: - logger.warn('unexpected source or destination CID') + logger.warning('unexpected source or destination CID') return self.change_state(Channel.CLOSED) @@ -827,23 +877,363 @@ class Channel(EventEmitter): return f'Channel({self.source_cid}->{self.destination_cid}, PSM={self.psm}, MTU={self.mtu}, state={Channel.STATE_NAMES[self.state]})' +# ----------------------------------------------------------------------------- +class LeConnectionOrientedChannel(EventEmitter): + """ + LE Credit-based Connection Oriented Channel + """ + + INIT = 0 + CONNECTED = 1 + CONNECTING = 2 + DISCONNECTING = 3 + DISCONNECTED = 4 + CONNECTION_ERROR = 5 + + STATE_NAMES = { + INIT: 'INIT', + CONNECTED: 'CONNECTED', + CONNECTING: 'CONNECTING', + DISCONNECTING: 'DISCONNECTING', + DISCONNECTED: 'DISCONNECTED', + CONNECTION_ERROR: 'CONNECTION_ERROR' + } + + @staticmethod + def state_name(state): + return name_or_number(LeConnectionOrientedChannel.STATE_NAMES, state) + + def __init__( + self, + manager, + connection, + le_psm, + source_cid, + destination_cid, + mtu, + mps, + credits, + peer_mtu, + peer_mps, + peer_credits, + connected + ): + super().__init__() + self.manager = manager + self.connection = connection + self.le_psm = le_psm + self.source_cid = source_cid + self.destination_cid = destination_cid + self.mtu = mtu + self.mps = mps + self.credits = credits + self.peer_mtu = peer_mtu + self.peer_mps = peer_mps + self.peer_credits = peer_credits + self.peer_max_credits = self.peer_credits + self.peer_credits_threshold = self.peer_max_credits // 2 + self.in_sdu = None + self.in_sdu_length = 0 + self.out_queue = deque() + self.out_sdu = None + self.sink = None + self.connection_result = None + self.disconnection_result = None + self.drained = asyncio.Event() + + self.drained.set() + + if connected: + self.state = LeConnectionOrientedChannel.CONNECTED + else: + self.state = LeConnectionOrientedChannel.INIT + + def change_state(self, new_state): + logger.debug(f'{self} state change -> {color(self.state_name(new_state), "cyan")}') + self.state = new_state + + if new_state == self.CONNECTED: + self.emit('open') + elif new_state == self.DISCONNECTED: + self.emit('close') + + def send_pdu(self, pdu): + self.manager.send_pdu(self.connection, self.destination_cid, pdu) + + def send_control_frame(self, frame): + self.manager.send_control_frame(self.connection, L2CAP_LE_SIGNALING_CID, frame) + + async def connect(self): + # Check that we're in the right state + if self.state != self.INIT: + raise InvalidStateError('not in a connectable state') + + # Check that we can start a new connection + identifier = self.manager.next_identifier(self.connection) + if identifier in self.manager.le_coc_requests: + raise RuntimeError('too many concurrent connection requests') + + self.change_state(self.CONNECTING) + request = L2CAP_LE_Credit_Based_Connection_Request( + identifier = identifier, + le_psm = self.le_psm, + source_cid = self.source_cid, + mtu = self.mtu, + mps = self.mps, + initial_credits = self.peer_credits + ) + self.manager.le_coc_requests[identifier] = request + self.send_control_frame(request) + + # Create a future to wait for the response + self.connection_result = asyncio.get_running_loop().create_future() + + # Wait for the connection to succeed or fail + return await self.connection_result + + async def disconnect(self): + # Check that we're connected + if self.state != self.CONNECTED: + raise InvalidStateError('not connected') + + self.change_state(self.DISCONNECTING) + self.flush_output() + self.send_control_frame( + L2CAP_Disconnection_Request( + identifier = self.manager.next_identifier(self.connection), + destination_cid = self.destination_cid, + source_cid = self.source_cid + ) + ) + + # Create a future to wait for the state machine to get to a success or error state + self.disconnection_result = asyncio.get_running_loop().create_future() + return await self.disconnection_result + + def on_pdu(self, pdu): + if self.sink is None: + logger.warning('received pdu without a sink') + return + + if self.state != self.CONNECTED: + logger.warning('received PDU while not connected, dropping') + + # Manage the peer credits + if self.peer_credits == 0: + logger.warning('received LE frame when peer out of credits') + else: + self.peer_credits -= 1 + if self.peer_credits <= self.peer_credits_threshold: + # The credits fell below the threshold, replenish them to the max + self.send_control_frame( + L2CAP_LE_Flow_Control_Credit( + identifier = self.manager.next_identifier(self.connection), + cid = self.source_cid, + credits = self.peer_max_credits - self.peer_credits + ) + ) + self.peer_credits = self.peer_max_credits + + # Check if this starts a new SDU + if self.in_sdu is None: + # Start a new SDU + self.in_sdu = pdu + else: + # Continue an SDU + self.in_sdu += pdu + + # Check if the SDU is complete + if self.in_sdu_length == 0: + # We don't know the size yet, check if we have received the header to compute it + if len(self.in_sdu) >= 2: + self.in_sdu_length = struct.unpack_from(' 0: + if self.out_sdu is not None: + # Finish the current SDU + packet = self.out_sdu[:self.peer_mps] + self.send_pdu(packet) + self.credits -= 1 + logger.debug(f'sent {len(packet)} bytes, {self.credits} credits left') + if len(packet) == len(self.out_sdu): + # We sent everything + self.out_sdu = None + else: + # Keep what's still left to send + self.out_sdu = self.out_sdu[len(packet):] + continue + elif self.out_queue: + # Create the next SDU (2 bytes header plus up to MTU bytes payload) + logger.debug(f'assembling SDU from {len(self.out_queue)} packets in output queue') + payload = b'' + while self.out_queue and len(payload) < self.peer_mtu: + # We can add more data to the payload + chunk = self.out_queue[0][:self.peer_mtu - len(payload)] + payload += chunk + self.out_queue[0] = self.out_queue[0][len(chunk):] + if len(self.out_queue[0]) == 0: + # We consumed the entire buffer, remove it + self.out_queue.popleft() + logger.debug(f'packet completed, {len(self.out_queue)} left in queue') + + # Construct the SDU with its header + assert len(payload) != 0 + logger.debug(f'SDU complete: {len(payload)} payload bytes') + self.out_sdu = struct.pack('{self.destination_cid}, State={self.state_name(self.state)}, PSM={self.le_psm}, MTU={self.mtu}/{self.peer_mtu}, MPS={self.mps}/{self.peer_mps}, credits={self.credits}/{self.peer_credits})' + + # ----------------------------------------------------------------------------- class ChannelManager: - def __init__(self, extended_features=None, connectionless_mtu=1024): - self.host = None - self.channels = {} # Channels, mapped by connection and cid - # Fixed channel handlers, mapped by cid - self.fixed_channels = { - L2CAP_SIGNALING_CID: None, L2CAP_LE_SIGNALING_CID: None} + def __init__(self, extended_features=[], connectionless_mtu=L2CAP_DEFAULT_CONNECTIONLESS_MTU): + self._host = None self.identifiers = {} # Incrementing identifier values by connection + self.channels = {} # All channels, mapped by connection and source cid + self.fixed_channels = { # Fixed channel handlers, mapped by cid + L2CAP_SIGNALING_CID: None, L2CAP_LE_SIGNALING_CID: None + } self.servers = {} # Servers accepting connections, by PSM - self.extended_features = [] if extended_features is None else extended_features + self.le_coc_channels = {} # LE CoC channels, mapped by connection and destination cid + self.le_coc_servers = {} # LE CoC - Servers accepting connections, by PSM + self.le_coc_requests = {} # LE CoC connection requests, by identifier + self.extended_features = extended_features self.connectionless_mtu = connectionless_mtu + @property + def host(self): + return self._host + + @host.setter + def host(self, host): + if self._host is not None: + self._host.remove_listener('disconnection', self.on_disconnection) + self._host = host + if host is not None: + host.add_listener('disconnection', self.on_disconnection) + def find_channel(self, connection_handle, cid): if connection_channels := self.channels.get(connection_handle): return connection_channels.get(cid) + def find_le_coc_channel(self, connection_handle, cid): + if connection_channels := self.le_coc_channels.get(connection_handle): + return connection_channels.get(cid) + @staticmethod def find_free_br_edr_cid(channels): # Pick the smallest valid CID that's not already in the list @@ -853,6 +1243,24 @@ class ChannelManager: if cid not in channels: return cid + @staticmethod + def find_free_le_cid(channels): + # Pick the smallest valid CID that's not already in the list + # (not necessarily the most efficient algorithm, but the list of CID is + # very small in practice) + for cid in range(L2CAP_LE_U_DYNAMIC_CID_RANGE_START, L2CAP_LE_U_DYNAMIC_CID_RANGE_END + 1): + if cid not in channels: + return cid + + @staticmethod + def check_le_coc_parameters(max_credits, mtu, mps): + if max_credits < 1 or max_credits > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_CREDITS: + raise ValueError('max credits out of range') + if mtu < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MTU: + raise ValueError('MTU too small') + if mps < L2CAP_LE_CREDIT_BASED_CONNECTION_MIN_MPS or mps > L2CAP_LE_CREDIT_BASED_CONNECTION_MAX_MPS: + raise ValueError('MPS out of range') + def next_identifier(self, connection): identifier = (self.identifiers.setdefault(connection.handle, 0) + 1) % 256 self.identifiers[connection.handle] = identifier @@ -866,8 +1274,77 @@ class ChannelManager: del self.fixed_channels[cid] def register_server(self, psm, server): + if psm == 0: + # Find a free PSM + for candidate in range(L2CAP_PSM_DYNAMIC_RANGE_START, L2CAP_PSM_DYNAMIC_RANGE_END + 1, 2): + if (candidate >> 8) % 2 == 1: + continue + if candidate in self.servers: + continue + psm = candidate + break + else: + raise InvalidStateError('no free PSM') + else: + # Check that the PSM isn't already in use + if psm in self.servers: + raise ValueError('PSM already in use') + + # Check that the PSM is valid + if psm % 2 == 0: + raise ValueError('invalid PSM (not odd)') + check = psm >> 8 + while check: + if check % 2 != 0: + raise ValueError('invalid PSM') + check >>= 8 + self.servers[psm] = server + return psm + + def register_le_coc_server( + self, + psm, + server, + max_credits=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS, + mtu=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, + mps=L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS + ): + self.check_le_coc_parameters(max_credits, mtu, mps) + + if psm == 0: + # Find a free PSM + for candidate in range(L2CAP_LE_PSM_DYNAMIC_RANGE_START, L2CAP_LE_PSM_DYNAMIC_RANGE_END + 1): + if candidate in self.le_coc_servers: + continue + psm = candidate + break + else: + raise InvalidStateError('no free PSM') + else: + # Check that the PSM isn't already in use + if psm in self.le_coc_servers: + raise ValueError('PSM already in use') + + self.le_coc_servers[psm] = ( + server, + max_credits or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS, + mtu or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, + mps or L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS + ) + + return psm + + def on_disconnection(self, connection_handle, reason): + logger.debug(f'disconnection from {connection_handle}, cleaning up channels') + if connection_handle in self.channels: + del self.channels[connection_handle] + if connection_handle in self.le_coc_channels: + del self.le_coc_channels[connection_handle] + if connection_handle in self.identifiers: + del self.identifiers[connection_handle] + def send_pdu(self, connection, cid, pdu): pdu_str = pdu.hex() if type(pdu) is bytes else str(pdu) logger.debug(f'{color(">>> Sending L2CAP PDU", "blue")} on connection [0x{connection.handle:04X}] (CID={cid}) {connection.peer_address}: {pdu_str}') @@ -883,7 +1360,7 @@ class ChannelManager: self.fixed_channels[cid](connection.handle, pdu) else: if (channel := self.find_channel(connection.handle, cid)) is None: - logger.warn(color(f'channel not found for 0x{connection.handle:04X}:{cid}', 'red')) + logger.warning(color(f'channel not found for 0x{connection.handle:04X}:{cid}', 'red')) return channel.on_pdu(pdu) @@ -927,7 +1404,6 @@ class ChannelManager: def on_l2cap_command_reject(self, connection, cid, packet): logger.warning(f'{color("!!! Command rejected:", "red")} {packet.reason}') - pass def on_l2cap_connection_request(self, connection, cid, request): # Check if there's a server for this PSM @@ -959,7 +1435,7 @@ class ChannelManager: server(channel) channel.on_connection_request(request) else: - logger.warn(f'No server for connection 0x{connection.handle:04X} on PSM {request.psm}') + logger.warning(f'No server for connection 0x{connection.handle:04X} on PSM {request.psm}') self.send_control_frame( connection, cid, @@ -974,35 +1450,35 @@ class ChannelManager: def on_l2cap_connection_response(self, connection, cid, response): if (channel := self.find_channel(connection.handle, response.source_cid)) is None: - logger.warn(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) + logger.warning(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) return channel.on_connection_response(response) def on_l2cap_configure_request(self, connection, cid, request): if (channel := self.find_channel(connection.handle, request.destination_cid)) is None: - logger.warn(color(f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) + logger.warning(color(f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) return channel.on_configure_request(request) def on_l2cap_configure_response(self, connection, cid, response): if (channel := self.find_channel(connection.handle, response.source_cid)) is None: - logger.warn(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) + logger.warning(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) return channel.on_configure_response(response) def on_l2cap_disconnection_request(self, connection, cid, request): if (channel := self.find_channel(connection.handle, request.destination_cid)) is None: - logger.warn(color(f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) + logger.warning(color(f'channel {request.destination_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) return channel.on_disconnection_request(request) def on_l2cap_disconnection_response(self, connection, cid, response): if (channel := self.find_channel(connection.handle, response.source_cid)) is None: - logger.warn(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) + logger.warning(color(f'channel {response.source_cid} not found for 0x{connection.handle:04X}:{cid}', 'red')) return channel.on_disconnection_response(response) @@ -1076,25 +1552,123 @@ class ChannelManager: ) def on_l2cap_connection_parameter_update_response(self, connection, cid, response): + # TODO: check response pass def on_l2cap_le_credit_based_connection_request(self, connection, cid, request): - # FIXME: temp fixed values - self.send_control_frame( - connection, - cid, - L2CAP_LE_Credit_Based_Connection_Response( - identifier = request.identifier, - destination_cid = 194, # FIXME: for testing only - mtu = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, - mps = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, - initial_credits = 3, # FIXME: for testing only - result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL - ) - ) + if request.le_psm in self.le_coc_servers: + (server, max_credits, mtu, mps) = self.le_coc_servers[request.le_psm] - def on_l2cap_le_flow_control_credit(self, connection, cid, packet): - pass + # Check that the CID isn't already used + le_connection_channels = self.le_coc_channels.setdefault(connection.handle, {}) + if request.source_cid in le_connection_channels: + logger.warning(f'source CID {request.source_cid} already in use') + self.send_control_frame( + connection, + cid, + L2CAP_LE_Credit_Based_Connection_Response( + identifier = request.identifier, + destination_cid = 0, + mtu = mtu, + mps = mps, + initial_credits = 0, + result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_SOURCE_CID_ALREADY_ALLOCATED + ) + ) + return + + # Find a free CID for this new channel + connection_channels = self.channels.setdefault(connection.handle, {}) + source_cid = self.find_free_le_cid(connection_channels) + if source_cid is None: # Should never happen! + self.send_control_frame( + connection, + cid, + L2CAP_LE_Credit_Based_Connection_Response( + identifier = request.identifier, + destination_cid = 0, + mtu = mtu, + mps = mps, + initial_credits = 0, + result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_NO_RESOURCES_AVAILABLE, + ) + ) + return + + # Create a new channel + logger.debug(f'creating LE CoC server channel with cid={source_cid} for psm {request.le_psm}') + channel = LeConnectionOrientedChannel( + self, + connection, + request.le_psm, + source_cid, + request.source_cid, + mtu, + mps, + request.initial_credits, + request.mtu, + request.mps, + max_credits, + True + ) + connection_channels[source_cid] = channel + le_connection_channels[request.source_cid] = channel + + # Respond + self.send_control_frame( + connection, + cid, + L2CAP_LE_Credit_Based_Connection_Response( + identifier = request.identifier, + destination_cid = source_cid, + mtu = mtu, + mps = mps, + initial_credits = max_credits, + result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_SUCCESSFUL + ) + ) + + # Notify + server(channel) + else: + logger.info(f'No LE server for connection 0x{connection.handle:04X} on PSM {request.le_psm}') + self.send_control_frame( + connection, + cid, + L2CAP_LE_Credit_Based_Connection_Response( + identifier = request.identifier, + destination_cid = 0, + mtu = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU, + mps = L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS, + initial_credits = 0, + result = L2CAP_LE_Credit_Based_Connection_Response.CONNECTION_REFUSED_LE_PSM_NOT_SUPPORTED, + ) + ) + + def on_l2cap_le_credit_based_connection_response(self, connection, cid, response): + # Find the pending request by identifier + request = self.le_coc_requests.get(response.identifier) + if request is None: + logger.warning(color('!!! received response for unknown request', 'red')) + return + del self.le_coc_requests[response.identifier] + + # Find the channel for this request + channel = self.find_channel(connection.handle, request.source_cid) + if channel is None: + logger.warning(color(f'received connection response for an unknown channel (cid={request.source_cid})', 'red')) + return + + # Process the response + channel.on_connection_response(response) + + def on_l2cap_le_flow_control_credit(self, connection, cid, credit): + channel = self.find_le_coc_channel(connection.handle, credit.cid) + if channel is None: + logger.warning(f'received credits for an unknown channel (cid={credit.cid}') + return + + channel.on_credits(credit.credits) def on_channel_closed(self, channel): connection_channels = self.channels.get(channel.connection.handle) @@ -1102,22 +1676,65 @@ class ChannelManager: if channel.source_cid in connection_channels: del connection_channels[channel.source_cid] - async def connect(self, connection, psm): - # NOTE: this implementation hard-codes BR/EDR more - # TODO: LE mode (maybe?) + async def open_le_coc(self, connection, psm, max_credits, mtu, mps): + self.check_le_coc_parameters(max_credits, mtu, mps) - # Find a free CID for a new channel + # Find a free CID for the new channel connection_channels = self.channels.setdefault(connection.handle, {}) - cid = self.find_free_br_edr_cid(connection_channels) - if cid is None: # Should never happen! + source_cid = self.find_free_le_cid(connection_channels) + if source_cid is None: # Should never happen! raise RuntimeError('all CIDs already in use') # Create the channel - logger.debug(f'creating client channel with cid={cid} for psm {psm}') - channel = Channel(self, connection, L2CAP_SIGNALING_CID, psm, cid, L2CAP_MIN_BR_EDR_MTU) - connection_channels[cid] = channel + logger.debug(f'creating coc channel with cid={source_cid} for psm {psm}') + channel = LeConnectionOrientedChannel( + manager = self, + connection = connection, + le_psm = psm, + source_cid = source_cid, + destination_cid = 0, + mtu = mtu, + mps = mps, + credits = 0, + peer_mtu = 0, + peer_mps = 0, + peer_credits = max_credits, + connected = False + ) + connection_channels[source_cid] = channel # Connect - await channel.connect() + try: + await channel.connect() + except Exception as error: + logger.warning(f'connection failed: {error}') + del connection_channels[source_cid] + raise + + # Remember the channel by source CID and destination CID + le_connection_channels = self.le_coc_channels.setdefault(connection.handle, {}) + le_connection_channels[channel.destination_cid] = channel + + return channel + + async def connect(self, connection, psm): + # NOTE: this implementation hard-codes BR/EDR + + # Find a free CID for a new channel + connection_channels = self.channels.setdefault(connection.handle, {}) + source_cid = self.find_free_br_edr_cid(connection_channels) + if source_cid is None: # Should never happen! + raise RuntimeError('all CIDs already in use') + + # Create the channel + logger.debug(f'creating client channel with cid={source_cid} for psm {psm}') + channel = Channel(self, connection, L2CAP_SIGNALING_CID, psm, source_cid, L2CAP_MIN_BR_EDR_MTU) + connection_channels[source_cid] = channel + + # Connect + try: + await channel.connect() + except Exception: + del connection_channels[source_cid] return channel diff --git a/bumble/transport/common.py b/bumble/transport/common.py index d5c1ae91..0f5d27f4 100644 --- a/bumble/transport/common.py +++ b/bumble/transport/common.py @@ -274,7 +274,7 @@ class PumpedPacketSource(ParserSource): self.terminated.set_result(error) break - self.pump_task = asyncio.get_running_loop().create_task(pump_packets()) + self.pump_task = asyncio.create_task(pump_packets()) def close(self): if self.pump_task: @@ -304,7 +304,7 @@ class PumpedPacketSink: logger.warn(f'exception while sending packet: {error}') break - self.pump_task = asyncio.get_running_loop().create_task(pump_packets()) + self.pump_task = asyncio.create_task(pump_packets()) def close(self): if self.pump_task: diff --git a/bumble/utils.py b/bumble/utils.py index 1ab3fd71..5d8ab954 100644 --- a/bumble/utils.py +++ b/bumble/utils.py @@ -18,6 +18,7 @@ import asyncio import logging import traceback +import collections from functools import wraps from colors import color from pyee import EventEmitter @@ -140,3 +141,95 @@ class AsyncRunner: return wrapper return decorator + + +# ----------------------------------------------------------------------------- +class FlowControlAsyncPipe: + """ + Asyncio pipe with flow control. When writing to the pipe, the source is + paused (by calling a function passed in when the pipe is created) if the + amount of queued data exceeds a specified threshold. + """ + def __init__(self, pause_source, resume_source, write_to_sink=None, drain_sink=None, threshold=0): + self.pause_source = pause_source + self.resume_source = resume_source + self.write_to_sink = write_to_sink + self.drain_sink = drain_sink + self.threshold = threshold + self.queue = collections.deque() # Queue of packets + self.queued_bytes = 0 # Number of bytes in the queue + self.ready_to_pump = asyncio.Event() + self.paused = False + self.source_paused = False + self.pump_task = None + + def start(self): + if self.pump_task is None: + self.pump_task = asyncio.create_task(self.pump()) + + self.check_pump() + + def stop(self): + if self.pump_task is not None: + self.pump_task.cancel() + self.pump_task = None + + def write(self, packet): + self.queued_bytes += len(packet) + self.queue.append(packet) + + # Pause the source if we're over the threshold + if self.queued_bytes > self.threshold and not self.source_paused: + logger.debug(f'pausing source (queued={self.queued_bytes})') + self.pause_source() + self.source_paused = True + + self.check_pump() + + def pause(self): + if not self.paused: + self.paused = True + if not self.source_paused: + self.pause_source() + self.source_paused = True + self.check_pump() + + def resume(self): + if self.paused: + self.paused = False + if self.source_paused: + self.resume_source() + self.source_paused = False + self.check_pump() + + def can_pump(self): + return self.queue and not self.paused and self.write_to_sink is not None + + def check_pump(self): + if self.can_pump(): + self.ready_to_pump.set() + else: + self.ready_to_pump.clear() + + async def pump(self): + while True: + # Wait until we can try to pump packets + await self.ready_to_pump.wait() + + # Try to pump a packet + if self.can_pump(): + packet = self.queue.pop() + self.write_to_sink(packet) + self.queued_bytes -= len(packet) + + # Drain the sink if we can + if self.drain_sink: + await self.drain_sink() + + # Check if we can accept more + if self.queued_bytes <= self.threshold and self.source_paused: + logger.debug(f'resuming source (queued={self.queued_bytes})') + self.source_paused = False + self.resume_source() + + self.check_pump() diff --git a/examples/asha_sink1.json b/examples/asha_sink1.json new file mode 100644 index 00000000..badef8b5 --- /dev/null +++ b/examples/asha_sink1.json @@ -0,0 +1,5 @@ +{ + "name": "Bumble Aid Left", + "address": "F1:F2:F3:F4:F5:F6", + "keystore": "JsonKeyStore" +} diff --git a/examples/asha_sink2.json b/examples/asha_sink2.json new file mode 100644 index 00000000..785d406a --- /dev/null +++ b/examples/asha_sink2.json @@ -0,0 +1,5 @@ +{ + "name": "Bumble Aid Right", + "address": "F7:F8:F9:FA:FB:FC", + "keystore": "JsonKeyStore" +} diff --git a/examples/run_asha_sink.py b/examples/run_asha_sink.py new file mode 100644 index 00000000..bebb5de7 --- /dev/null +++ b/examples/run_asha_sink.py @@ -0,0 +1,161 @@ +# Copyright 2021-2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import struct +import sys +import os +import logging + +from bumble.core import AdvertisingData +from bumble.device import Device +from bumble.transport import open_transport_or_link +from bumble.hci import UUID +from bumble.gatt import ( + Service, + Characteristic, + CharacteristicValue +) + + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- +ASHA_SERVICE = UUID.from_16_bits(0xFDF0, 'Audio Streaming for Hearing Aid') +ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC = UUID('6333651e-c481-4a3e-9169-7c902aad37bb', 'ReadOnlyProperties') +ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC = UUID('f0d4de7e-4a88-476c-9d9f-1937b0996cc0', 'AudioControlPoint') +ASHA_AUDIO_STATUS_CHARACTERISTIC = UUID('38663f1a-e711-4cac-b641-326b56404837', 'AudioStatus') +ASHA_VOLUME_CHARACTERISTIC = UUID('00e4ca9e-ab14-41e4-8823-f9e70c7e91df', 'Volume') +ASHA_LE_PSM_OUT_CHARACTERISTIC = UUID('2d410339-82b6-42aa-b34e-e2e01df8cc1a', 'LE_PSM_OUT') + + +# ----------------------------------------------------------------------------- +async def main(): + if len(sys.argv) != 4: + print('Usage: python run_asha_sink.py ') + print('example: python run_asha_sink.py device1.json usb:0 audio_out.g722') + return + + audio_out = open(sys.argv[3], 'wb') + + async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink): + device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink) + + # Handler for audio control commands + def on_audio_control_point_write(connection, value): + print('--- AUDIO CONTROL POINT Write:', value.hex()) + opcode = value[0] + if opcode == 1: + # Start + audio_type = ('Unknown', 'Ringtone', 'Phone Call', 'Media')[value[2]] + print(f'### START: codec={value[1]}, audio_type={audio_type}, volume={value[3]}, otherstate={value[4]}') + elif opcode == 2: + print('### STOP') + elif opcode == 3: + print(f'### STATUS: connected={value[1]}') + + # Respond with a status + asyncio.create_task(device.notify_subscribers(audio_status_characteristic, force=True)) + + # Handler for volume control + def on_volume_write(connection, value): + print('--- VOLUME Write:', value[0]) + + # Register an L2CAP CoC server + def on_coc(channel): + def on_data(data): + print('<<< Voice data received:', data.hex()) + audio_out.write(data) + + channel.sink = on_data + + psm = device.register_l2cap_channel_server(0, on_coc, 8) + print(f'### LE_PSM_OUT = {psm}') + + # Add the ASHA service to the GATT server + read_only_properties_characteristic = Characteristic( + ASHA_READ_ONLY_PROPERTIES_CHARACTERISTIC, + Characteristic.READ, + Characteristic.READABLE, + bytes([ + 0x01, # Version + 0x00, # Device Capabilities [Left, Monaural] + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, # HiSyncId + 0x01, # Feature Map [LE CoC audio output streaming supported] + 0x00, 0x00, # Render Delay + 0x00, 0x00, # RFU + 0x02, 0x00 # Codec IDs [G.722 at 16 kHz] + ]) + ) + audio_control_point_characteristic = Characteristic( + ASHA_AUDIO_CONTROL_POINT_CHARACTERISTIC, + Characteristic.WRITE | Characteristic.WRITE_WITHOUT_RESPONSE, + Characteristic.WRITEABLE, + CharacteristicValue(write=on_audio_control_point_write) + ) + audio_status_characteristic = Characteristic( + ASHA_AUDIO_STATUS_CHARACTERISTIC, + Characteristic.READ | Characteristic.NOTIFY, + Characteristic.READABLE, + bytes([0]) + ) + volume_characteristic = Characteristic( + ASHA_VOLUME_CHARACTERISTIC, + Characteristic.WRITE_WITHOUT_RESPONSE, + Characteristic.WRITEABLE, + CharacteristicValue(write=on_volume_write) + ) + le_psm_out_characteristic = Characteristic( + ASHA_LE_PSM_OUT_CHARACTERISTIC, + Characteristic.READ, + Characteristic.READABLE, + struct.pack(' ') - print('example: run_gatt_server.py device1.json usb:0') + print('Usage: run_notifier.py ') + print('example: run_notifier.py device1.json usb:0') return print('<<< connecting to HCI...') @@ -83,6 +89,7 @@ async def main(): Characteristic.READABLE, bytes([0x42]) ) + characteristic3.on('subscription', on_my_characteristic_subscription) custom_service = Service( '50DB505C-8AC4-4738-8448-3B1D9CC09CC5', [characteristic1, characteristic2, characteristic3] diff --git a/examples/run_rfcomm_client.py b/examples/run_rfcomm_client.py index 83ef8483..76586c3e 100644 --- a/examples/run_rfcomm_client.py +++ b/examples/run_rfcomm_client.py @@ -98,6 +98,7 @@ async def list_rfcomm_channels(device, connection): await sdp_client.disconnect() + # ----------------------------------------------------------------------------- class TcpServerProtocol(asyncio.Protocol): def __init__(self, rfcomm_session): @@ -173,7 +174,7 @@ async def main(): print('*** Encryption on') # Create a client and start it - print('@@@ Starting to RFCOMM client...') + print('@@@ Starting RFCOMM client...') rfcomm_client = Client(device, connection) rfcomm_mux = await rfcomm_client.start() print('@@@ Started') @@ -192,7 +193,7 @@ async def main(): if len(sys.argv) == 6: # A TCP port was specified, start listening tcp_port = int(sys.argv[5]) - asyncio.get_running_loop().create_task(tcp_server(tcp_port, session)) + asyncio.create_task(tcp_server(tcp_port, session)) await hci_source.wait_for_termination() diff --git a/setup.cfg b/setup.cfg index ff992261..c64dcce7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,6 +51,7 @@ console_scripts = bumble-controller-info = bumble.apps.controller_info:main bumble-gatt-dump = bumble.apps.gatt_dump:main bumble-hci-bridge = bumble.apps.hci_bridge:main + bumble-l2cap-bridge = bumble.apps.l2cap_bridge:main bumble-pair = bumble.apps.pair:main bumble-scan = bumble.apps.scan:main bumble-show = bumble.apps.show:main @@ -64,6 +65,7 @@ build = test = pytest >= 6.2 pytest-asyncio >= 0.17 + coverage >= 6.4 development = invoke >= 1.4 nox >= 2022 diff --git a/tests/l2cap_test.py b/tests/l2cap_test.py new file mode 100644 index 00000000..319038b6 --- /dev/null +++ b/tests/l2cap_test.py @@ -0,0 +1,284 @@ +# Copyright 2021-2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import logging +import os +import random +import pytest + +from bumble.controller import Controller +from bumble.link import LocalLink +from bumble.device import Device +from bumble.host import Host +from bumble.transport import AsyncPipeSink +from bumble.core import ProtocolError +from bumble.l2cap import ( + L2CAP_Connection_Request +) + + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +class TwoDevices: + def __init__(self): + self.connections = [None, None] + + self.link = LocalLink() + self.controllers = [ + Controller('C1', link = self.link), + Controller('C2', link = self.link) + ] + self.devices = [ + Device( + address = 'F0:F1:F2:F3:F4:F5', + host = Host(self.controllers[0], AsyncPipeSink(self.controllers[0])) + ), + Device( + address = 'F5:F4:F3:F2:F1:F0', + host = Host(self.controllers[1], AsyncPipeSink(self.controllers[1])) + ) + ] + + self.paired = [None, None] + + def on_connection(self, which, connection): + self.connections[which] = connection + + def on_paired(self, which, keys): + self.paired[which] = keys + + +# ----------------------------------------------------------------------------- +async def setup_connection(): + # Create two devices, each with a controller, attached to the same link + two_devices = TwoDevices() + + # Attach listeners + two_devices.devices[0].on('connection', lambda connection: two_devices.on_connection(0, connection)) + two_devices.devices[1].on('connection', lambda connection: two_devices.on_connection(1, connection)) + + # Start + await two_devices.devices[0].power_on() + await two_devices.devices[1].power_on() + + # Connect the two devices + await two_devices.devices[0].connect(two_devices.devices[1].random_address) + + # Check the post conditions + assert(two_devices.connections[0] is not None) + assert(two_devices.connections[1] is not None) + + return two_devices + + +# ----------------------------------------------------------------------------- +def test_helpers(): + psm = L2CAP_Connection_Request.serialize_psm(0x01) + assert(psm == bytes([0x01, 0x00])) + + psm = L2CAP_Connection_Request.serialize_psm(0x1023) + assert(psm == bytes([0x23, 0x10])) + + psm = L2CAP_Connection_Request.serialize_psm(0x242311) + assert(psm == bytes([0x11, 0x23, 0x24])) + + (offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x01, 0x00, 0x44]), 1) + assert(offset == 3) + assert(psm == 0x01) + + (offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x23, 0x10, 0x44]), 1) + assert(offset == 3) + assert(psm == 0x1023) + + (offset, psm) = L2CAP_Connection_Request.parse_psm(bytes([0x00, 0x11, 0x23, 0x24, 0x44]), 1) + assert(offset == 4) + assert(psm == 0x242311) + + rq = L2CAP_Connection_Request(psm = 0x01, source_cid = 0x44) + brq = bytes(rq) + srq = L2CAP_Connection_Request.from_bytes(brq) + assert(srq.psm == rq.psm) + assert(srq.source_cid == rq.source_cid) + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_basic_connection(): + devices = await setup_connection() + psm = 1234 + + # Check that if there's no one listening, we can't connect + with pytest.raises(ProtocolError): + l2cap_channel = await devices.connections[0].open_l2cap_channel(psm) + + # Now add a listener + incoming_channel = None + received = [] + + def on_coc(channel): + nonlocal incoming_channel + incoming_channel = channel + + def on_data(data): + received.append(data) + + channel.sink = on_data + + devices.devices[1].register_l2cap_channel_server(psm, on_coc) + l2cap_channel = await devices.connections[0].open_l2cap_channel(psm) + + messages = ( + bytes([1, 2, 3]), + bytes([4, 5, 6]), + bytes(10000) + ) + for message in messages: + l2cap_channel.write(message) + await asyncio.sleep(0) + + await l2cap_channel.drain() + + # Test closing + closed = [False, False] + closed_event = asyncio.Event() + + def on_close(which, event): + closed[which] = True + if event: + event.set() + + l2cap_channel.on('close', lambda: on_close(0, None)) + incoming_channel.on('close', lambda: on_close(1, closed_event)) + await l2cap_channel.disconnect() + assert(closed == [True, True]) + await closed_event.wait() + + sent_bytes = b''.join(messages) + received_bytes = b''.join(received) + assert(sent_bytes == received_bytes) + + +# ----------------------------------------------------------------------------- +async def transfer_payload(max_credits, mtu, mps): + devices = await setup_connection() + + received = [] + + def on_coc(channel): + def on_data(data): + received.append(data) + + channel.sink = on_data + + psm = devices.devices[1].register_l2cap_channel_server( + psm = 0, + server = on_coc, + max_credits = max_credits, + mtu = mtu, + mps = mps + ) + l2cap_channel = await devices.connections[0].open_l2cap_channel(psm) + + messages = [ + bytes([1, 2, 3, 4, 5, 6, 7]) * x + for x in (3, 10, 100, 500, 789) + ] + for message in messages: + l2cap_channel.write(message) + await asyncio.sleep(0) + if random.randint(0, 5) == 1: + await l2cap_channel.drain() + + await l2cap_channel.drain() + await l2cap_channel.disconnect() + + sent_bytes = b''.join(messages) + received_bytes = b''.join(received) + assert(sent_bytes == received_bytes) + + +@pytest.mark.asyncio +async def test_transfer(): + for max_credits in (1, 10, 100, 10000): + for mtu in (23, 24, 25, 26, 50, 200, 255, 256, 1000): + for mps in (23, 24, 25, 26, 50, 200, 255, 256, 1000): + # print(max_credits, mtu, mps) + await transfer_payload(max_credits, mtu, mps) + + +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_bidirectional_transfer(): + devices = await setup_connection() + + client_received = [] + server_received = [] + server_channel = None + + def on_server_coc(channel): + nonlocal server_channel + server_channel = channel + + def on_server_data(data): + server_received.append(data) + + channel.sink = on_server_data + + def on_client_data(data): + client_received.append(data) + + psm = devices.devices[1].register_l2cap_channel_server(psm=0, server=on_server_coc) + client_channel = await devices.connections[0].open_l2cap_channel(psm) + client_channel.sink = on_client_data + + messages = [ + bytes([1, 2, 3, 4, 5, 6, 7]) * x + for x in (3, 10, 100) + ] + for message in messages: + client_channel.write(message) + await client_channel.drain() + await asyncio.sleep(0) + server_channel.write(message) + await server_channel.drain() + + await client_channel.disconnect() + + message_bytes = b''.join(messages) + client_received_bytes = b''.join(client_received) + server_received_bytes = b''.join(server_received) + assert(client_received_bytes == message_bytes) + assert(server_received_bytes == message_bytes) + + +# ----------------------------------------------------------------------------- +async def run(): + test_helpers() + await test_basic_connection() + await test_transfer() + await test_bidirectional_transfer() + +# ----------------------------------------------------------------------------- +if __name__ == '__main__': + logging.basicConfig(level = os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + asyncio.run(run())