Merge pull request #376 from zxzxwu/host

Manage lifecycle of CIS and SCO links in host
This commit is contained in:
zxzxwu
2024-01-28 22:09:08 +08:00
committed by GitHub
4 changed files with 82 additions and 43 deletions

View File

@@ -3080,34 +3080,30 @@ class Device(CompositeEventEmitter):
cig_id=cig_id, cig_id=cig_id,
) )
result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs],
),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)
pending_cis_establishments: Dict[int, asyncio.Future[CisLink]] = {}
for cis_handle, _ in cis_acl_pairs:
pending_cis_establishments[
cis_handle
] = asyncio.get_running_loop().create_future()
with closing(EventWatcher()) as watcher: with closing(EventWatcher()) as watcher:
pending_cis_establishments = {
cis_handle: asyncio.get_running_loop().create_future()
for cis_handle, _ in cis_acl_pairs
}
@watcher.on(self, 'cis_establishment') @watcher.on(self, 'cis_establishment')
def on_cis_establishment(cis_link: CisLink) -> None: def on_cis_establishment(cis_link: CisLink) -> None:
if pending_future := pending_cis_establishments.get( if pending_future := pending_cis_establishments.get(cis_link.handle):
cis_link.handle, None
):
pending_future.set_result(cis_link) pending_future.set_result(cis_link)
result = await self.send_command(
HCI_LE_Create_CIS_Command(
cis_connection_handle=[p[0] for p in cis_acl_pairs],
acl_connection_handle=[p[1] for p in cis_acl_pairs],
),
)
if result.status != HCI_COMMAND_STATUS_PENDING:
logger.warning(
'HCI_LE_Create_CIS_Command failed: '
f'{HCI_Constant.error_name(result.status)}'
)
raise HCI_StatusError(result)
return await asyncio.gather(*pending_cis_establishments.values()) return await asyncio.gather(*pending_cis_establishments.values())
# [LE only] # [LE only]
@@ -3755,7 +3751,7 @@ class Device(CompositeEventEmitter):
@host_event_handler @host_event_handler
@experimental('Only for testing') @experimental('Only for testing')
def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None: def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None:
if sco_link := self.sco_links.get(sco_handle, None): if sco_link := self.sco_links.get(sco_handle):
sco_link.emit('pdu', packet) sco_link.emit('pdu', packet)
# [LE only] # [LE only]
@@ -3835,7 +3831,7 @@ class Device(CompositeEventEmitter):
@experimental('Only for testing') @experimental('Only for testing')
def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None: def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None:
logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***') logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***')
if cis_link := self.cis_links.pop(cis_handle, None): if cis_link := self.cis_links.pop(cis_handle):
cis_link.emit('establishment_failure') cis_link.emit('establishment_failure')
self.emit('cis_establishment_failure', cis_handle, status) self.emit('cis_establishment_failure', cis_handle, status)
@@ -3843,7 +3839,7 @@ class Device(CompositeEventEmitter):
@host_event_handler @host_event_handler
@experimental('Only for testing') @experimental('Only for testing')
def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None: def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None:
if cis_link := self.cis_links.get(handle, None): if cis_link := self.cis_links.get(handle):
cis_link.emit('pdu', packet) cis_link.emit('pdu', packet)
@host_event_handler @host_event_handler

View File

@@ -18,6 +18,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import collections import collections
import dataclasses
import logging import logging
import struct import struct
@@ -161,9 +162,25 @@ class Connection:
self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload) self.host.on_l2cap_pdu(self, l2cap_pdu.cid, l2cap_pdu.payload)
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class ScoLink:
peer_address: Address
handle: int
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class CisLink:
peer_address: Address
handle: int
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
class Host(AbortableEventEmitter): class Host(AbortableEventEmitter):
connections: Dict[int, Connection] connections: Dict[int, Connection]
cis_links: Dict[int, CisLink]
sco_links: Dict[int, ScoLink]
acl_packet_queue: Optional[AclPacketQueue] = None acl_packet_queue: Optional[AclPacketQueue] = None
le_acl_packet_queue: Optional[AclPacketQueue] = None le_acl_packet_queue: Optional[AclPacketQueue] = None
hci_sink: Optional[TransportSink] = None hci_sink: Optional[TransportSink] = None
@@ -183,6 +200,8 @@ class Host(AbortableEventEmitter):
self.hci_metadata = {} self.hci_metadata = {}
self.ready = False # True when we can accept incoming packets self.ready = False # True when we can accept incoming packets
self.connections = {} # Connections, by connection handle self.connections = {} # Connections, by connection handle
self.cis_links = {} # CIS links, by connection handle
self.sco_links = {} # SCO links, by connection handle
self.pending_command = None self.pending_command = None
self.pending_response = None self.pending_response = None
self.local_version = None self.local_version = None
@@ -696,25 +715,36 @@ class Host(AbortableEventEmitter):
def on_hci_disconnection_complete_event(self, event): def on_hci_disconnection_complete_event(self, event):
# Find the connection # Find the connection
if (connection := self.connections.get(event.connection_handle)) is None: handle = event.connection_handle
if (
connection := (
self.connections.get(handle)
or self.cis_links.get(handle)
or self.sco_links.get(handle)
)
) is None:
logger.warning('!!! DISCONNECTION COMPLETE: unknown handle') logger.warning('!!! DISCONNECTION COMPLETE: unknown handle')
return return
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
logger.debug( logger.debug(
f'### DISCONNECTION: [0x{event.connection_handle:04X}] ' f'### DISCONNECTION: [0x{handle:04X}] '
f'{connection.peer_address} ' f'{connection.peer_address} '
f'reason={event.reason}' f'reason={event.reason}'
) )
del self.connections[event.connection_handle]
# Notify the listeners # Notify the listeners
self.emit('disconnection', event.connection_handle, event.reason) self.emit('disconnection', handle, event.reason)
(
self.connections.pop(handle, 0)
or self.cis_links.pop(handle, 0)
or self.sco_links.pop(handle, 0)
)
else: else:
logger.debug(f'### DISCONNECTION FAILED: {event.status}') logger.debug(f'### DISCONNECTION FAILED: {event.status}')
# Notify the listeners # Notify the listeners
self.emit('disconnection_failure', event.connection_handle, event.status) self.emit('disconnection_failure', handle, event.status)
def on_hci_le_connection_update_complete_event(self, event): def on_hci_le_connection_update_complete_event(self, event):
if (connection := self.connections.get(event.connection_handle)) is None: if (connection := self.connections.get(event.connection_handle)) is None:
@@ -775,6 +805,10 @@ class Host(AbortableEventEmitter):
def on_hci_le_cis_established_event(self, event): def on_hci_le_cis_established_event(self, event):
# The remaining parameters are unused for now. # The remaining parameters are unused for now.
if event.status == HCI_SUCCESS: if event.status == HCI_SUCCESS:
self.cis_links[event.connection_handle] = CisLink(
handle=event.connection_handle,
peer_address=Address.ANY,
)
self.emit('cis_establishment', event.connection_handle) self.emit('cis_establishment', event.connection_handle)
else: else:
self.emit( self.emit(
@@ -841,6 +875,11 @@ class Host(AbortableEventEmitter):
f'{event.bd_addr}' f'{event.bd_addr}'
) )
self.sco_links[event.connection_handle] = ScoLink(
peer_address=event.bd_addr,
handle=event.connection_handle,
)
# Notify the client # Notify the client
self.emit( self.emit(
'sco_connection', 'sco_connection',

View File

@@ -467,9 +467,8 @@ async def test_cis():
await asyncio.gather(*peripheral_cis_futures.values()) await asyncio.gather(*peripheral_cis_futures.values())
assert len(cis_links) == 2 assert len(cis_links) == 2
# TODO: Fix Host CIS support. await cis_links[0].disconnect()
# await cis_links[0].disconnect() await cis_links[1].disconnect()
# await cis_links[1].disconnect()
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------

View File

@@ -24,11 +24,11 @@ from typing import Tuple
from .test_utils import TwoDevices from .test_utils import TwoDevices
from bumble import core from bumble import core
from bumble import device
from bumble import hfp from bumble import hfp
from bumble import rfcomm from bumble import rfcomm
from bumble import hci from bumble import hci
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Logging # Logging
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
@@ -109,7 +109,7 @@ async def test_sco_setup():
devices[1].accept(devices[0].public_address), devices[1].accept(devices[0].public_address),
) )
def on_sco_request(_connection: device.Connection, _link_type: int): def on_sco_request(_connection, _link_type: int):
connections[1].abort_on( connections[1].abort_on(
'disconnection', 'disconnection',
devices[1].send_command( devices[1].send_command(
@@ -124,17 +124,13 @@ async def test_sco_setup():
devices[1].on('sco_request', on_sco_request) devices[1].on('sco_request', on_sco_request)
sco_connections = [ sco_connection_futures = [
asyncio.get_running_loop().create_future(), asyncio.get_running_loop().create_future(),
asyncio.get_running_loop().create_future(), asyncio.get_running_loop().create_future(),
] ]
devices[0].on( for device, future in zip(devices, sco_connection_futures):
'sco_connection', lambda sco_link: sco_connections[0].set_result(sco_link) device.on('sco_connection', future.set_result)
)
devices[1].on(
'sco_connection', lambda sco_link: sco_connections[1].set_result(sco_link)
)
await devices[0].send_command( await devices[0].send_command(
hci.HCI_Enhanced_Setup_Synchronous_Connection_Command( hci.HCI_Enhanced_Setup_Synchronous_Connection_Command(
@@ -142,8 +138,17 @@ async def test_sco_setup():
**hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_CVSD_S1].asdict(), **hfp.ESCO_PARAMETERS[hfp.DefaultCodecParameters.ESCO_CVSD_S1].asdict(),
) )
) )
sco_connections = await asyncio.gather(*sco_connection_futures)
await asyncio.gather(*sco_connections) sco_disconnection_futures = [
asyncio.get_running_loop().create_future(),
asyncio.get_running_loop().create_future(),
]
for future, sco_connection in zip(sco_disconnection_futures, sco_connections):
sco_connection.on('disconnection', future.set_result)
await sco_connections[0].disconnect()
await asyncio.gather(*sco_disconnection_futures)
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------