Merge pull request #376 from zxzxwu/host

Manage lifecycle of CIS and SCO links in host
This commit is contained in:
zxzxwu
2024-01-28 22:09:08 +08:00
committed by GitHub
4 changed files with 82 additions and 43 deletions

View File

@@ -3080,6 +3080,17 @@ class Device(CompositeEventEmitter):
cig_id=cig_id,
)
with closing(EventWatcher()) as watcher:
pending_cis_establishments = {
cis_handle: asyncio.get_running_loop().create_future()
for cis_handle, _ in cis_acl_pairs
}
@watcher.on(self, 'cis_establishment')
def on_cis_establishment(cis_link: CisLink) -> None:
if pending_future := pending_cis_establishments.get(cis_link.handle):
pending_future.set_result(cis_link)
result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
@@ -3093,21 +3104,6 @@ class Device(CompositeEventEmitter):
)
raise HCI_StatusError(result)
pending_cis_establishments: Dict[int, asyncio.Future[CisLink]] = {}
for cis_handle, _ in cis_acl_pairs:
pending_cis_establishments[
cis_handle
] = asyncio.get_running_loop().create_future()
with closing(EventWatcher()) as watcher:
@watcher.on(self, 'cis_establishment')
def on_cis_establishment(cis_link: CisLink) -> None:
if pending_future := pending_cis_establishments.get(
cis_link.handle, None
):
pending_future.set_result(cis_link)
return await asyncio.gather(*pending_cis_establishments.values())
# [LE only]
@@ -3755,7 +3751,7 @@ class Device(CompositeEventEmitter):
@host_event_handler
@experimental('Only for testing')
def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None:
if sco_link := self.sco_links.get(sco_handle, None):
if sco_link := self.sco_links.get(sco_handle):
sco_link.emit('pdu', packet)
# [LE only]
@@ -3835,7 +3831,7 @@ class Device(CompositeEventEmitter):
@experimental('Only for testing')
def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None:
logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***')
if cis_link := self.cis_links.pop(cis_handle, None):
if cis_link := self.cis_links.pop(cis_handle):
cis_link.emit('establishment_failure')
self.emit('cis_establishment_failure', cis_handle, status)
@@ -3843,7 +3839,7 @@ class Device(CompositeEventEmitter):
@host_event_handler
@experimental('Only for testing')
def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None:
if cis_link := self.cis_links.get(handle, None):
if cis_link := self.cis_links.get(handle):
cis_link.emit('pdu', packet)
@host_event_handler

View File

@@ -18,6 +18,7 @@
from __future__ import annotations
import asyncio
import collections
import dataclasses
import logging
import struct
@@ -161,9 +162,25 @@ class Connection:
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ScoLink:
peer_address: Address
handle: int
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class CisLink:
peer_address: Address
handle: int
# -----------------------------------------------------------------------------
class Host(AbortableEventEmitter):
connections: Dict[int, Connection]
cis_links: Dict[int, CisLink]
sco_links: Dict[int, ScoLink]
acl_packet_queue: Optional[AclPacketQueue] = None
le_acl_packet_queue: Optional[AclPacketQueue] = None
hci_sink: Optional[TransportSink] = None
@@ -183,6 +200,8 @@ class Host(AbortableEventEmitter):
self.hci_metadata = {}
self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle
self.cis_links = {} # CIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle
self.pending_command = None
self.pending_response = None
self.local_version = None
@@ -696,25 +715,36 @@ class Host(AbortableEventEmitter):
def on_hci_disconnection_complete_event(self, event):
# Find the connection
if (connection := self.connections.get(event.connection_handle)) is None:
handle = event.connection_handle
if (
connection := (
self.connections.get(handle)
or self.cis_links.get(handle)
or self.sco_links.get(handle)
)
) is None:
logger.warning('!!! DISCONNECTION COMPLETE: unknown handle')
return
if event.status == HCI_SUCCESS:
logger.debug(
f'### DISCONNECTION: [0x{event.connection_handle:04X}] '
f'### DISCONNECTION: [0x{handle:04X}] '
f'{connection.peer_address} '
f'reason={event.reason}'
)
del self.connections[event.connection_handle]
# Notify the listeners
self.emit('disconnection', event.connection_handle, event.reason)
self.emit('disconnection', handle, event.reason)
(
self.connections.pop(handle, 0)
or self.cis_links.pop(handle, 0)
or self.sco_links.pop(handle, 0)
)
else:
logger.debug(f'### DISCONNECTION FAILED: {event.status}')
# Notify the listeners
self.emit('disconnection_failure', event.connection_handle, event.status)
self.emit('disconnection_failure', handle, event.status)
def on_hci_le_connection_update_complete_event(self, event):
if (connection := self.connections.get(event.connection_handle)) is None:
@@ -775,6 +805,10 @@ class Host(AbortableEventEmitter):
def on_hci_le_cis_established_event(self, event):
# The remaining parameters are unused for now.
if event.status == HCI_SUCCESS:
self.cis_links[event.connection_handle] = CisLink(
handle=event.connection_handle,
peer_address=Address.ANY,
)
self.emit('cis_establishment', event.connection_handle)
else:
self.emit(
@@ -841,6 +875,11 @@ class Host(AbortableEventEmitter):
f'{event.bd_addr}'
)
self.sco_links[event.connection_handle] = ScoLink(
peer_address=event.bd_addr,
handle=event.connection_handle,
)
# Notify the client
self.emit(
'sco_connection',

View File

@@ -467,9 +467,8 @@ async def test_cis():
await asyncio.gather(*peripheral_cis_futures.values())
assert len(cis_links) == 2
# TODO: Fix Host CIS support.
# await cis_links[0].disconnect()
# await cis_links[1].disconnect()
await cis_links[0].disconnect()
await cis_links[1].disconnect()
# -----------------------------------------------------------------------------

View File

@@ -24,11 +24,11 @@ from typing import Tuple
from .test_utils import TwoDevices
from bumble import core
from bumble import device
from bumble import hfp
from bumble import rfcomm
from bumble import hci
# -----------------------------------------------------------------------------
# Logging
# -----------------------------------------------------------------------------
@@ -109,7 +109,7 @@ async def test_sco_setup():
devices[1].accept(devices[0].public_address),
)
def on_sco_request(_connection: device.Connection, _link_type: int):
def on_sco_request(_connection, _link_type: int):
connections[1].abort_on(
'disconnection',
devices[1].send_command(
@@ -124,17 +124,13 @@ async def test_sco_setup():
devices[1].on('sco_request', on_sco_request)
sco_connections = [
sco_connection_futures = [
asyncio.get_running_loop().create_future(),
asyncio.get_running_loop().create_future(),
]
devices[0].on(
'sco_connection', lambda sco_link: sco_connections[0].set_result(sco_link)
)
devices[1].on(
'sco_connection', lambda sco_link: sco_connections[1].set_result(sco_link)
)
for device, future in zip(devices, sco_connection_futures):
device.on('sco_connection', future.set_result)
await devices[0].send_command(
hci.HCI_Enhanced_Setup_Synchronous_Connection_Command(
@@ -142,8 +138,17 @@ async def test_sco_setup():
**hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_CVSD_S1].asdict(),
)
)
sco_connections = await asyncio.gather(*sco_connection_futures)
await asyncio.gather(*sco_connections)
sco_disconnection_futures = [
asyncio.get_running_loop().create_future(),
asyncio.get_running_loop().create_future(),
]
for future, sco_connection in zip(sco_disconnection_futures, sco_connections):
sco_connection.on('disconnection', future.set_result)
await sco_connections[0].disconnect()
await asyncio.gather(*sco_disconnection_futures)
# -----------------------------------------------------------------------------