Handle ISO data path race condition

This commit is contained in:
Josh Wu
2025-09-12 14:55:45 +08:00
parent 7d8addb849
commit 83c5061700
3 changed files with 98 additions and 48 deletions

View File

@@ -33,7 +33,6 @@ from bumble.hci import (
HCI_COMMAND_DISALLOWED_ERROR,
HCI_COMMAND_PACKET,
HCI_COMMAND_STATUS_PENDING,
HCI_CONNECTION_TIMEOUT_ERROR,
HCI_CONTROLLER_BUSY_ERROR,
HCI_EVENT_PACKET,
HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR,
@@ -88,6 +87,7 @@ class CisLink:
cis_id: int
cig_id: int
acl_connection: Optional[Connection] = None
data_paths: set[int] = dataclasses.field(default_factory=set)
# -----------------------------------------------------------------------------
@@ -381,6 +381,11 @@ class Controller:
return connection
return None
def find_iso_link_by_handle(self, handle: int) -> Optional[CisLink]:
return self.central_cis_links.get(handle) or self.peripheral_cis_links.get(
handle
)
def on_link_central_connected(self, central_address):
'''
Called when an incoming connection occurs from a central on the link
@@ -1853,16 +1858,51 @@ class Controller:
)
)
def on_hci_le_setup_iso_data_path_command(self, command):
def on_hci_le_setup_iso_data_path_command(
self, command: hci.HCI_LE_Setup_ISO_Data_Path_Command
) -> bytes:
'''
See Bluetooth spec Vol 4, Part E - 7.8.109 LE Setup ISO Data Path Command
'''
if not (iso_link := self.find_iso_link_by_handle(command.connection_handle)):
return struct.pack(
'<BH',
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
command.connection_handle,
)
if command.data_path_direction in iso_link.data_paths:
return struct.pack(
'<BH',
HCI_COMMAND_DISALLOWED_ERROR,
command.connection_handle,
)
iso_link.data_paths.add(command.data_path_direction)
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
def on_hci_le_remove_iso_data_path_command(self, command):
def on_hci_le_remove_iso_data_path_command(
self, command: hci.HCI_LE_Remove_ISO_Data_Path_Command
) -> bytes:
'''
See Bluetooth spec Vol 4, Part E - 7.8.110 LE Remove ISO Data Path Command
'''
if not (iso_link := self.find_iso_link_by_handle(command.connection_handle)):
return struct.pack(
'<BH',
HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR,
command.connection_handle,
)
data_paths: set[int] = set(
direction
for direction in hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction
if (1 << direction) & command.data_path_direction
)
if not data_paths.issubset(iso_link.data_paths):
return struct.pack(
'<BH',
HCI_COMMAND_DISALLOWED_ERROR,
command.connection_handle,
)
iso_link.data_paths.difference_update(data_paths)
return struct.pack('<BH', HCI_SUCCESS, command.connection_handle)
def on_hci_le_set_host_feature_command(

View File

@@ -1453,6 +1453,8 @@ class _IsoLink:
handle: int
device: Device
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
data_paths: set[_IsoLink.Direction]
_data_path_lock: asyncio.Lock
class Direction(IntEnum):
HOST_TO_CONTROLLER = (
@@ -1462,6 +1464,10 @@ class _IsoLink:
hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST
)
def __init__(self) -> None:
self._data_path_lock = asyncio.Lock()
self.data_paths = set()
async def setup_data_path(
self,
direction: _IsoLink.Direction,
@@ -1482,37 +1488,45 @@ class _IsoLink:
Raises:
HCI_Error: When command complete status is not HCI_SUCCESS.
"""
await self.device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=self.handle,
data_path_direction=direction,
data_path_id=data_path_id,
codec_id=codec_id or hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=controller_delay,
codec_configuration=codec_configuration,
),
check_result=True,
)
async with self._data_path_lock:
if direction in self.data_paths:
return
await self.device.send_command(
hci.HCI_LE_Setup_ISO_Data_Path_Command(
connection_handle=self.handle,
data_path_direction=direction,
data_path_id=data_path_id,
codec_id=codec_id or hci.CodingFormat(hci.CodecID.TRANSPARENT),
controller_delay=controller_delay,
codec_configuration=codec_configuration,
),
check_result=True,
)
self.data_paths.add(direction)
async def remove_data_path(self, directions: Iterable[_IsoLink.Direction]) -> int:
async def remove_data_path(self, directions: Iterable[_IsoLink.Direction]) -> None:
"""Remove a data path with controller on given direction.
Args:
direction: Direction of data path.
Returns:
Command status.
Raises:
HCI_Error: When command complete status is not HCI_SUCCESS.
"""
response = await self.device.send_command(
hci.HCI_LE_Remove_ISO_Data_Path_Command(
connection_handle=self.handle,
data_path_direction=sum(
1 << direction for direction in set(directions)
async with self._data_path_lock:
directions_to_remove = set(directions).intersection(self.data_paths)
if not directions_to_remove:
return
await self.device.send_command(
hci.HCI_LE_Remove_ISO_Data_Path_Command(
connection_handle=self.handle,
data_path_direction=sum(
1 << direction for direction in directions_to_remove
),
),
),
check_result=False,
)
return response.return_parameters.status
check_result=True,
)
self.data_paths.difference_update(directions_to_remove)
def write(self, sdu: bytes) -> None:
"""Write an ISO SDU."""
@@ -1622,7 +1636,8 @@ class CisLink(utils.EventEmitter, _IsoLink):
EVENT_ESTABLISHMENT_FAILURE: ClassVar[str] = "establishment_failure"
def __post_init__(self) -> None:
super().__init__()
utils.EventEmitter.__init__(self)
_IsoLink.__init__(self)
async def disconnect(
self, reason: int = hci.HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR
@@ -1638,6 +1653,7 @@ class BisLink(_IsoLink):
sink: Callable[[hci.HCI_IsoDataPacket], Any] | None = None
def __post_init__(self) -> None:
super().__init__()
self.device = self.big.device

View File

@@ -310,12 +310,12 @@ async def test_pacs():
@pytest.mark.asyncio
async def test_ascs():
devices = TwoDevices()
devices[0].add_service(
AudioStreamControlService(device=devices[0], sink_ase_id=[1, 2])
devices[1].add_service(
AudioStreamControlService(device=devices[1], sink_ase_id=[1, 2])
)
await devices.setup_connection()
peer = device.Peer(devices.connections[1])
peer = device.Peer(devices.connections[0])
ascs_client = await peer.discover_service_and_create_proxy(
AudioStreamControlServiceProxy
)
@@ -369,7 +369,7 @@ async def test_ascs():
await ascs_client.ase_control_point.write_value(
ASE_Config_QOS(
ase_id=[1, 2],
cig_id=[1, 2],
cig_id=[1, 1],
cis_id=[3, 4],
sdu_interval=[5, 6],
framing=[0, 1],
@@ -402,25 +402,19 @@ async def test_ascs():
)
# CIS establishment
devices[0].emit(
'cis_establishment',
device.CisLink(
device=devices[0],
acl_connection=devices.connections[0],
handle=5,
cis_id=3,
cis_handles = await devices[0].setup_cig(
device.CigParameters(
cig_id=1,
),
cis_parameters=[
device.CigParameters.CisParameters(cis_id=3),
device.CigParameters.CisParameters(cis_id=4),
],
sdu_interval_c_to_p=0,
sdu_interval_p_to_c=0,
)
)
devices[0].emit(
'cis_establishment',
device.CisLink(
device=devices[0],
acl_connection=devices.connections[0],
handle=6,
cis_id=4,
cig_id=2,
),
await devices[0].create_cis(
[(cis_handle, devices.connections[0]) for cis_handle in cis_handles]
)
assert (await notifications[1].get())[:2] == bytes(
[1, AseStateMachine.State.STREAMING]