Add L2CAP CoC support (squashed)

[85542e0] fix test
[3748781] add ASAH sink example
[e782e29] add app
[83daa30] wip
[7f138a0] add test
[f732108] allow different address syntax
[9d0bbf8] rename deprecated methods
[eb303d5] add LE CoC support
This commit is contained in:
Gilles Boccon-Gibod
2022-06-02 11:44:01 -07:00
committed by Gilles Boccon-Gibod
parent be8f8ac68f
commit ce9004f0ac
19 changed files with 1882 additions and 187 deletions

View File

@@ -351,7 +351,7 @@ class MediaPacketPump:
logger.debug('pump canceled')
# Pump packets
self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
self.pump_task = asyncio.create_task(pump_packets())
async def stop(self):
# Stop the pump
@@ -1890,10 +1890,10 @@ class LocalSource(LocalStreamEndPoint, EventEmitter):
self.configuration = configuration
def on_start_command(self):
asyncio.get_running_loop().create_task(self.start())
asyncio.create_task(self.start())
def on_suspend_command(self):
asyncio.get_running_loop().create_task(self.stop())
asyncio.create_task(self.stop())
# -----------------------------------------------------------------------------

View File

@@ -43,6 +43,13 @@ logger = logging.getLogger(__name__)
# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------
DEVICE_MIN_SCAN_INTERVAL = 25
DEVICE_MAX_SCAN_INTERVAL = 10240
DEVICE_MIN_SCAN_WINDOW = 25
DEVICE_MAX_SCAN_WINDOW = 10240
DEVICE_MIN_LE_RSSI = -127
DEVICE_MAX_LE_RSSI = 20
DEVICE_DEFAULT_ADDRESS = '00:00:00:00:00:00'
DEVICE_DEFAULT_ADVERTISING_INTERVAL = 1000 # ms
DEVICE_DEFAULT_ADVERTISING_DATA = ''
@@ -62,20 +69,15 @@ DEVICE_DEFAULT_CONNECTION_MAX_LATENCY = 0
DEVICE_DEFAULT_CONNECTION_SUPERVISION_TIMEOUT = 720 # ms
DEVICE_DEFAULT_CONNECTION_MIN_CE_LENGTH = 0 # ms
DEVICE_DEFAULT_CONNECTION_MAX_CE_LENGTH = 0 # ms
DEVICE_MIN_SCAN_INTERVAL = 25
DEVICE_MAX_SCAN_INTERVAL = 10240
DEVICE_MIN_SCAN_WINDOW = 25
DEVICE_MAX_SCAN_WINDOW = 10240
DEVICE_MIN_LE_RSSI = -127
DEVICE_MAX_LE_RSSI = 20
DEVICE_DEFAULT_L2CAP_COC_MTU = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MTU
DEVICE_DEFAULT_L2CAP_COC_MPS = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_MPS
DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS = l2cap.L2CAP_LE_CREDIT_BASED_CONNECTION_DEFAULT_INITIAL_CREDITS
# -----------------------------------------------------------------------------
# Classes
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
class Advertisement:
TX_POWER_NOT_AVAILABLE = HCI_LE_Extended_Advertising_Report_Event.TX_POWER_INFORMATION_NOT_AVAILABLE
@@ -429,7 +431,16 @@ class Connection(CompositeEventEmitter):
def create_l2cap_connector(self, psm):
return self.device.create_l2cap_connector(self, psm)
async def disconnect(self, reason = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR):
async def open_l2cap_channel(
self,
psm,
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
mps=DEVICE_DEFAULT_L2CAP_COC_MPS
):
return await self.device.open_l2cap_channel(self, psm, max_credits, mtu, mps)
async def disconnect(self, reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR):
return await self.device.disconnect(self, reason)
async def pair(self):
@@ -563,6 +574,7 @@ class DeviceConfiguration:
with open(filename, 'r') as file:
self.load_from_dict(json.load(file))
# -----------------------------------------------------------------------------
# Decorators used with the following Device class
# (we define them outside of the Device class, because defining decorators
@@ -685,7 +697,7 @@ class Device(CompositeEventEmitter):
self.classic_enabled = False
self.inquiry_response = None
self.address_resolver = None
self.classic_pending_accepts = { Address.ANY: [] } # Futures, by BD address OR [Futures] for Address.ANY
self.classic_pending_accepts = {Address.ANY: []} # Futures, by BD address OR [Futures] for Address.ANY
# Use the initial config or a default
self.public_address = Address('00:00:00:00:00:00')
@@ -785,15 +797,35 @@ class Device(CompositeEventEmitter):
if transport is None or connection.transport == transport:
return connection
def register_l2cap_server(self, psm, server):
self.l2cap_channel_manager.register_server(psm, server)
def create_l2cap_connector(self, connection, psm):
return lambda: self.l2cap_channel_manager.connect(connection, psm)
def create_l2cap_registrar(self, psm):
return lambda handler: self.register_l2cap_server(psm, handler)
def register_l2cap_server(self, psm, server):
self.l2cap_channel_manager.register_server(psm, server)
def register_l2cap_channel_server(
self,
psm,
server,
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
mps=DEVICE_DEFAULT_L2CAP_COC_MPS
):
return self.l2cap_channel_manager.register_le_coc_server(psm, server, max_credits, mtu, mps)
async def open_l2cap_channel(
self,
connection,
psm,
max_credits=DEVICE_DEFAULT_L2CAP_COC_MAX_CREDITS,
mtu=DEVICE_DEFAULT_L2CAP_COC_MTU,
mps=DEVICE_DEFAULT_L2CAP_COC_MPS
):
return await self.l2cap_channel_manager.open_le_coc(connection, psm, max_credits, mtu, mps)
def send_l2cap_pdu(self, connection_handle, cid, pdu):
self.host.send_l2cap_pdu(connection_handle, cid, pdu)
@@ -1185,13 +1217,15 @@ class Device(CompositeEventEmitter):
def on_connection(connection):
if transport == BT_LE_TRANSPORT or (
# match BR/EDR connection event against peer address
connection.transport == transport and connection.peer_address == peer_address):
connection.transport == transport and connection.peer_address == peer_address
):
pending_connection.set_result(connection)
def on_connection_failure(error):
if transport == BT_LE_TRANSPORT or (
# match BR/EDR connection failure event against peer address
error.transport == transport and error.peer_address == peer_address):
error.transport == transport and error.peer_address == peer_address
):
pending_connection.set_exception(error)
# Create a future so that we can wait for the connection's result
@@ -1336,7 +1370,7 @@ class Device(CompositeEventEmitter):
if peer_address == Address.NIL:
raise ValueError('accept on nil address')
# Create a future so that we can wait for the request
# Create a future so that we can wait for the request
pending_request = asyncio.get_running_loop().create_future()
if peer_address == Address.ANY:
@@ -1349,8 +1383,7 @@ class Device(CompositeEventEmitter):
try:
# Wait for a request or a completed connection
result = await (asyncio.wait_for(pending_request, timeout) if timeout else pending_request)
except:
except Exception:
# Remove future from device context
if peer_address == Address.ANY:
self.classic_pending_accepts[Address.ANY].remove(pending_request)
@@ -1710,26 +1743,32 @@ class Device(CompositeEventEmitter):
connection.remove_listener('connection_encryption_failure', on_encryption_failure)
# [Classic only]
async def request_remote_name(self, remote: Connection | Address):
async def request_remote_name(self, remote): # remote: Connection | Address
# Set up event handlers
pending_name = asyncio.get_running_loop().create_future()
if type(remote) == Address:
peer_address = remote
handler = self.on('remote_name',
handler = self.on(
'remote_name',
lambda address, remote_name:
pending_name.set_result(remote_name) if address == remote else None)
failure_handler = self.on('remote_name_failure',
pending_name.set_result(remote_name) if address == remote else None
)
failure_handler = self.on(
'remote_name_failure',
lambda address, error_code:
pending_name.set_exception(HCI_Error(error_code)) if address == remote else None)
pending_name.set_exception(HCI_Error(error_code)) if address == remote else None
)
else:
peer_address = remote.peer_address
handler = remote.on('remote_name',
lambda:
pending_name.set_result(remote.peer_name))
failure_handler = remote.on('remote_name_failure',
lambda error_code:
pending_name.set_exception(HCI_Error(error_code)))
handler = remote.on(
'remote_name',
lambda: pending_name.set_result(remote.peer_name)
)
failure_handler = remote.on(
'remote_name_failure',
lambda error_code: pending_name.set_exception(HCI_Error(error_code))
)
try:
result = await self.send_command(
@@ -2097,7 +2136,6 @@ class Device(CompositeEventEmitter):
else:
self.emit('remote_name_failure', address, error)
# [Classic only]
@host_event_handler
@try_with_connection_from_address

View File

@@ -25,6 +25,7 @@
import asyncio
import types
import logging
from pyee import EventEmitter
from colors import color
from .core import *

View File

@@ -273,7 +273,7 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
logger.warning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
# TODO raise appropriate exception
return
break
@@ -337,7 +337,7 @@ class Client:
if response.op_code == ATT_ERROR_RESPONSE:
if response.error_code != ATT_ATTRIBUTE_NOT_FOUND_ERROR:
# Unexpected end
logger.waning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
logger.warning(f'!!! unexpected error while discovering services: {HCI_Constant.error_name(response.error_code)}')
# TODO raise appropriate exception
return
break

View File

@@ -155,7 +155,7 @@ class Server(EventEmitter):
return cccd or bytes([0, 0])
def write_cccd(self, connection, characteristic, value):
logger.debug(f'Subscription update for connection={connection.handle:04X}, handle={characteristic.handle:04X}: {value.hex()}')
logger.debug(f'Subscription update for connection=0x{connection.handle:04X}, handle=0x{characteristic.handle:04X}: {value.hex()}')
# Sanity check
if len(value) != 2:
@@ -204,7 +204,7 @@ class Server(EventEmitter):
logger.debug(f'GATT Notify from server: [0x{connection.handle:04X}] {notification}')
self.send_gatt_pdu(connection.handle, bytes(notification))
async def indicate_subscriber(self, connection, attribute, value=None, force=False):
async def indicate_subscriber(self, connection, attribute, force=False):
# Check if there's a subscriber
if not force:
subscribers = self.subscribers.get(connection.handle)

View File

@@ -2466,9 +2466,10 @@ class HCI_Write_Voice_Setting_Command(HCI_Command):
# -----------------------------------------------------------------------------
@HCI_Command.command()
class HCI_Read_Synchronous_Flow_Control_Enable_Command(HCI_Command):
'''
See Bluetooth spec @ 7.3.36 Write Synchronous Flow Control Enable Command
See Bluetooth spec @ 7.3.36 Read Synchronous Flow Control Enable Command
'''

View File

@@ -79,6 +79,8 @@ class Host(EventEmitter):
self.local_version = None
self.local_supported_commands = bytes(64)
self.local_le_features = 0
self.suggested_max_tx_octets = 251 # Max allowed
self.suggested_max_tx_time = 2120 # Max allowed
self.command_semaphore = asyncio.Semaphore(1)
self.long_term_key_provider = None
self.link_key_provider = None
@@ -138,6 +140,22 @@ class Host(EventEmitter):
f'hc_total_num_le_acl_data_packets={self.hc_total_num_le_acl_data_packets}'
)
if (
self.supports_command(HCI_LE_READ_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND) and
self.supports_command(HCI_LE_WRITE_SUGGESTED_DEFAULT_DATA_LENGTH_COMMAND)
):
response = await self.send_command(HCI_LE_Read_Suggested_Default_Data_Length_Command())
suggested_max_tx_octets = response.return_parameters.suggested_max_tx_octets
suggested_max_tx_time = response.return_parameters.suggested_max_tx_time
if (
suggested_max_tx_octets != self.suggested_max_tx_octets or
suggested_max_tx_time != self.suggested_max_tx_time
):
await self.send_command(HCI_LE_Write_Suggested_Default_Data_Length_Command(
suggested_max_tx_octets = self.suggested_max_tx_octets,
suggested_max_tx_time = self.suggested_max_tx_time
))
self.reset_done = True
@property

File diff suppressed because it is too large Load Diff

View File

@@ -274,7 +274,7 @@ class PumpedPacketSource(ParserSource):
self.terminated.set_result(error)
break
self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
self.pump_task = asyncio.create_task(pump_packets())
def close(self):
if self.pump_task:
@@ -304,7 +304,7 @@ class PumpedPacketSink:
logger.warn(f'exception while sending packet: {error}')
break
self.pump_task = asyncio.get_running_loop().create_task(pump_packets())
self.pump_task = asyncio.create_task(pump_packets())
def close(self):
if self.pump_task:

View File

@@ -18,6 +18,7 @@
import asyncio
import logging
import traceback
import collections
from functools import wraps
from colors import color
from pyee import EventEmitter
@@ -140,3 +141,95 @@ class AsyncRunner:
return wrapper
return decorator
# -----------------------------------------------------------------------------
class FlowControlAsyncPipe:
"""
Asyncio pipe with flow control. When writing to the pipe, the source is
paused (by calling a function passed in when the pipe is created) if the
amount of queued data exceeds a specified threshold.
"""
def __init__(self, pause_source, resume_source, write_to_sink=None, drain_sink=None, threshold=0):
self.pause_source = pause_source
self.resume_source = resume_source
self.write_to_sink = write_to_sink
self.drain_sink = drain_sink
self.threshold = threshold
self.queue = collections.deque() # Queue of packets
self.queued_bytes = 0 # Number of bytes in the queue
self.ready_to_pump = asyncio.Event()
self.paused = False
self.source_paused = False
self.pump_task = None
def start(self):
if self.pump_task is None:
self.pump_task = asyncio.create_task(self.pump())
self.check_pump()
def stop(self):
if self.pump_task is not None:
self.pump_task.cancel()
self.pump_task = None
def write(self, packet):
self.queued_bytes += len(packet)
self.queue.append(packet)
# Pause the source if we're over the threshold
if self.queued_bytes > self.threshold and not self.source_paused:
logger.debug(f'pausing source (queued={self.queued_bytes})')
self.pause_source()
self.source_paused = True
self.check_pump()
def pause(self):
if not self.paused:
self.paused = True
if not self.source_paused:
self.pause_source()
self.source_paused = True
self.check_pump()
def resume(self):
if self.paused:
self.paused = False
if self.source_paused:
self.resume_source()
self.source_paused = False
self.check_pump()
def can_pump(self):
return self.queue and not self.paused and self.write_to_sink is not None
def check_pump(self):
if self.can_pump():
self.ready_to_pump.set()
else:
self.ready_to_pump.clear()
async def pump(self):
while True:
# Wait until we can try to pump packets
await self.ready_to_pump.wait()
# Try to pump a packet
if self.can_pump():
packet = self.queue.pop()
self.write_to_sink(packet)
self.queued_bytes -= len(packet)
# Drain the sink if we can
if self.drain_sink:
await self.drain_sink()
# Check if we can accept more
if self.queued_bytes <= self.threshold and self.source_paused:
logger.debug(f'resuming source (queued={self.queued_bytes})')
self.source_paused = False
self.resume_source()
self.check_pump()