fix: properly cleanup audio inputs and PipeWire capture nodes on stream stop

This commit is contained in:
pstruebi
2025-10-01 14:18:55 +02:00
parent 125385a202
commit b922eca39c
4 changed files with 226 additions and 156 deletions

View File

@@ -375,18 +375,51 @@ class Streamer():
if self.task is not None:
self.task.cancel()
# Let cancellation propagate to the stream() coroutine
await asyncio.sleep(0.01)
self.task = None
# Close audio inputs (await to ensure ALSA devices are released)
close_tasks = []
async_closers = []
sync_closers = []
for big in self.bigs.values():
ai = big.get("audio_input")
if ai and hasattr(ai, "close"):
close_tasks.append(ai.close())
# Remove reference so a fresh one is created next time
big.pop("audio_input", None)
if close_tasks:
await asyncio.gather(*close_tasks, return_exceptions=True)
if not ai:
continue
# First close any frames generator backed by the input to stop reads
frames_gen = big.get("frames_gen")
if frames_gen and hasattr(frames_gen, "aclose"):
try:
await frames_gen.aclose()
except Exception:
pass
big.pop("frames_gen", None)
if hasattr(ai, "aclose") and callable(getattr(ai, "aclose")):
async_closers.append(ai.aclose())
elif hasattr(ai, "close") and callable(getattr(ai, "close")):
sync_closers.append(ai.close)
# Remove reference so a fresh one is created next time
big.pop("audio_input", None)
if async_closers:
await asyncio.gather(*async_closers, return_exceptions=True)
for fn in sync_closers:
try:
fn()
except Exception:
pass
# Reset PortAudio to drop lingering PipeWire capture nodes
try:
import sounddevice as _sd
if hasattr(_sd, "_terminate"):
_sd._terminate()
await asyncio.sleep(0.05)
if hasattr(_sd, "_initialize"):
_sd._initialize()
except Exception:
pass
async def stream(self):
@@ -414,6 +447,8 @@ class Streamer():
big['audio_input'] = audio_source
big['encoder'] = encoder
big['precoded'] = False
# Prepare frames generator for graceful shutdown
big['frames_gen'] = big['audio_input'].frames(lc3_frame_samples)
elif audio_source == 'webrtc':
big['audio_input'] = WebRTCAudioInput()
@@ -429,6 +464,8 @@ class Streamer():
big['lc3_bytes_per_frame'] = global_config.octets_per_frame
big['encoder'] = encoder
big['precoded'] = False
# Prepare frames generator for graceful shutdown
big['frames_gen'] = big['audio_input'].frames(lc3_frame_samples)
# precoded lc3 from ram
elif isinstance(big_config[i].audio_source, bytes):
@@ -599,7 +636,12 @@ class Streamer():
stream_finished[i] = True
continue
else: # code lc3 on the fly
pcm_frame = await anext(big['audio_input'].frames(big['lc3_frame_samples']), None)
# Use stored frames generator when available so we can aclose() it on stop
frames_gen = big.get('frames_gen')
if frames_gen is None:
frames_gen = big['audio_input'].frames(big['lc3_frame_samples'])
big['frames_gen'] = frames_gen
pcm_frame = await anext(frames_gen, None)
if pcm_frame is None: # Not all streams may stop at the same time
stream_finished[i] = True

View File

@@ -56,6 +56,8 @@ class Multicaster:
"""Start streaming; if an old stream is running, stop it first to release audio devices."""
if self.streamer is not None:
await self.stop_streaming()
# Brief pause to ensure ALSA/PortAudio fully releases the input device
await asyncio.sleep(0.5)
self.streamer = multicast.Streamer(self.bigs, self.global_conf, self.big_conf)
self.streamer.start_streaming()

View File

@@ -321,53 +321,37 @@ else:
# Input device selection for USB or AES67 mode
if audio_mode in ("USB", "AES67"):
try:
endpoint = "/audio_inputs_pw_usb" if audio_mode == "USB" else "/audio_inputs_pw_network"
resp = requests.get(f"{BACKEND_URL}{endpoint}")
device_list = resp.json().get('inputs', [])
except Exception as e:
st.error(f"Failed to fetch devices: {e}")
device_list = []
if not is_streaming:
# Only query device lists when NOT streaming to avoid extra backend calls
try:
endpoint = "/audio_inputs_pw_usb" if audio_mode == "USB" else "/audio_inputs_pw_network"
resp = requests.get(f"{BACKEND_URL}{endpoint}")
device_list = resp.json().get('inputs', [])
except Exception as e:
st.error(f"Failed to fetch devices: {e}")
device_list = []
# Display "name [id]" but use name as value
input_options = [f"{d['name']} [{d['id']}]" for d in device_list]
option_name_map = {f"{d['name']} [{d['id']}]": d['name'] for d in device_list}
device_names = [d['name'] for d in device_list]
# Display "name [id]" but use name as value
input_options = [f"{d['name']} [{d['id']}]" for d in device_list]
option_name_map = {f"{d['name']} [{d['id']}]": d['name'] for d in device_list}
device_names = [d['name'] for d in device_list]
# Determine default input by name (from persisted server state)
default_input_name = saved_settings.get('input_device')
if default_input_name not in device_names and device_names:
default_input_name = device_names[0]
default_input_label = None
for label, name in option_name_map.items():
if name == default_input_name:
default_input_label = label
break
if not input_options:
warn_text = (
"No USB audio input devices found. Connect a USB input and click Refresh."
if audio_mode == "USB" else
"No AES67/Network inputs found."
)
st.warning(warn_text)
if st.button("Refresh", disabled=is_streaming):
try:
r = requests.post(f"{BACKEND_URL}/refresh_audio_devices", timeout=8)
if not r.ok:
st.error(f"Failed to refresh: {r.text}")
except Exception as e:
st.error(f"Failed to refresh devices: {e}")
st.rerun()
input_device = None
else:
col1, col2 = st.columns([3, 1], vertical_alignment="bottom")
with col1:
selected_option = st.selectbox(
"Input Device",
input_options,
index=input_options.index(default_input_label) if default_input_label in input_options else 0
# Determine default input by name (from persisted server state)
default_input_name = saved_settings.get('input_device')
if default_input_name not in device_names and device_names:
default_input_name = device_names[0]
default_input_label = None
for label, name in option_name_map.items():
if name == default_input_name:
default_input_label = label
break
if not input_options:
warn_text = (
"No USB audio input devices found. Connect a USB input and click Refresh."
if audio_mode == "USB" else
"No AES67/Network inputs found."
)
with col2:
st.warning(warn_text)
if st.button("Refresh", disabled=is_streaming):
try:
r = requests.post(f"{BACKEND_URL}/refresh_audio_devices", timeout=8)
@@ -376,8 +360,29 @@ else:
except Exception as e:
st.error(f"Failed to refresh devices: {e}")
st.rerun()
# Send only the device name to backend
input_device = option_name_map.get(selected_option)
input_device = None
else:
col1, col2 = st.columns([3, 1], vertical_alignment="bottom")
with col1:
selected_option = st.selectbox(
"Input Device",
input_options,
index=input_options.index(default_input_label) if default_input_label in input_options else 0
)
with col2:
if st.button("Refresh", disabled=is_streaming):
try:
r = requests.post(f"{BACKEND_URL}/refresh_audio_devices", timeout=8)
if not r.ok:
st.error(f"Failed to refresh: {r.text}")
except Exception as e:
st.error(f"Failed to refresh devices: {e}")
st.rerun()
# Send only the device name to backend
input_device = option_name_map.get(selected_option)
else:
# When streaming, do not call backend for device lists. Reuse persisted selection.
input_device = saved_settings.get('input_device')
else:
input_device = None

View File

@@ -5,6 +5,7 @@ import uuid
import json
import sys
from datetime import datetime
import time
import asyncio
import numpy as np
from dotenv import load_dotenv
@@ -93,80 +94,84 @@ global_config_group = auracast_config.AuracastConfigGroup()
# Create multicast controller
multicaster1: multicast_control.Multicaster | None = None
multicaster2: multicast_control.Multicaster | None = None
_stream_lock = asyncio.Lock() # serialize initialize/stop_audio
@app.post("/init")
async def initialize(conf: auracast_config.AuracastConfigGroup):
"""Initializes the primary broadcaster (multicaster1)."""
global global_config_group
global multicaster1
try:
async with _stream_lock:
try:
# Cleanly stop any existing instance to avoid lingering PipeWire streams
if multicaster1 is not None:
log.info("Shutting down existing multicaster instance before re-initializing.")
await multicaster1.shutdown()
multicaster1 = None
conf.transport = TRANSPORT1
# Derive audio_mode and input_device from first BIG audio_source
first_source = conf.bigs[0].audio_source if conf.bigs else ''
if first_source.startswith('device:'):
input_device_name = first_source.split(':', 1)[1] if ':' in first_source else None
# Determine if the device is a USB or Network(AES67) PipeWire input
try:
usb_names = {d.get('name') for _, d in get_usb_pw_inputs()}
net_names = {d.get('name') for _, d in get_network_pw_inputs()}
except Exception:
usb_names, net_names = set(), set()
if input_device_name in net_names:
audio_mode_persist = 'AES67'
else:
audio_mode_persist = 'USB'
conf.transport = TRANSPORT1
# Derive audio_mode and input_device from first BIG audio_source
first_source = conf.bigs[0].audio_source if conf.bigs else ''
if first_source.startswith('device:'):
input_device_name = first_source.split(':', 1)[1] if ':' in first_source else None
# Determine if the device is a USB or Network(AES67) PipeWire input
try:
usb_names = {d.get('name') for _, d in get_usb_pw_inputs()}
net_names = {d.get('name') for _, d in get_network_pw_inputs()}
except Exception:
usb_names, net_names = set(), set()
audio_mode_persist = 'AES67' if input_device_name in net_names else 'USB'
# Resolve to device index and set input_format to avoid PipeWire resampling
device_index = None
if input_device_name:
device_index = int(input_device_name) if input_device_name.isdigit() else get_device_index_by_name(input_device_name)
if device_index is None:
log.error(f"Device name '{input_device_name}' not found in current device list.")
raise HTTPException(status_code=400, detail=f"Audio device '{input_device_name}' not found.")
# Map device name to current index for use with sounddevice
device_index = get_device_index_by_name(input_device_name) if input_device_name else None
if device_index is not None:
for big in conf.bigs:
if big.audio_source.startswith('device:'):
big.audio_source = f'device:{device_index}'
else:
log.error(f"Device name '{input_device_name}' not found in current device list.")
raise HTTPException(status_code=400, detail=f"Audio device '{input_device_name}' not found.")
elif first_source == 'webrtc':
audio_mode_persist = 'Webapp'
input_device_name = None
elif first_source.startswith('file:'):
audio_mode_persist = 'Demo'
input_device_name = None
else:
audio_mode_persist = 'Network'
input_device_name = None
save_stream_settings({
'channel_names': [big.name for big in conf.bigs],
'languages': [big.language for big in conf.bigs],
'audio_mode': audio_mode_persist,
'input_device': input_device_name,
'program_info': [getattr(big, 'program_info', None) for big in conf.bigs],
'gain': [getattr(big, 'input_gain', 1.0) for big in conf.bigs],
'auracast_sampling_rate_hz': conf.auracast_sampling_rate_hz,
'octets_per_frame': conf.octets_per_frame,
'immediate_rendering': getattr(conf, 'immediate_rendering', False),
'assisted_listening_stream': getattr(conf, 'assisted_listening_stream', False),
'stream_password': (conf.bigs[0].code if conf.bigs and getattr(conf.bigs[0], 'code', None) else None),
'is_streaming': False, # will be set to True below if we actually start
'timestamp': datetime.utcnow().isoformat()
})
global_config_group = conf
log.info('Initializing multicaster1 with config:\n %s', conf.model_dump_json(indent=2))
multicaster1 = multicast_control.Multicaster(conf, conf.bigs)
# Ensure target is reset before initializing broadcast
await reset_nrf54l(1)
await multicaster1.init_broadcast()
if any(big.audio_source.startswith("device:") or big.audio_source.startswith("file:") for big in conf.bigs):
log.info("Auto-starting streaming on multicaster1")
await multicaster1.start_streaming()
# Mark persisted state as streaming
settings = load_stream_settings() or {}
settings['is_streaming'] = True
settings['timestamp'] = datetime.utcnow().isoformat()
save_stream_settings(settings)
except Exception as e:
log.error("Exception in /init: %s", traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
devinfo = sd.query_devices(device_index)
capture_rate = int(devinfo.get('default_samplerate') or 48000)
max_in = int(devinfo.get('max_input_channels') or 1)
channels = max(1, min(2, max_in))
for big in conf.bigs:
big.input_format = f"int16le,{capture_rate},{channels}"
save_stream_settings({
'channel_names': [big.name for big in conf.bigs],
'languages': [big.language for big in conf.bigs],
'audio_mode': audio_mode_persist,
'input_device': input_device_name,
'program_info': [getattr(big, 'program_info', None) for big in conf.bigs],
'gain': [getattr(big, 'input_gain', 1.0) for big in conf.bigs],
'auracast_sampling_rate_hz': conf.auracast_sampling_rate_hz,
'octets_per_frame': conf.octets_per_frame,
'immediate_rendering': getattr(conf, 'immediate_rendering', False),
'assisted_listening_stream': getattr(conf, 'assisted_listening_stream', False),
'stream_password': (conf.bigs[0].code if conf.bigs and getattr(conf.bigs[0], 'code', None) else None),
'is_streaming': False, # will be set to True below if we actually start
'timestamp': datetime.utcnow().isoformat()
})
# Proceed with initialization and optional auto-start
global_config_group = conf
log.info('Initializing multicaster1 with config:\n %s', conf.model_dump_json(indent=2))
multicaster1 = multicast_control.Multicaster(conf, conf.bigs)
# Ensure target is reset before initializing broadcast
await reset_nrf54l(1)
await multicaster1.init_broadcast()
if any(big.audio_source.startswith("device:") or big.audio_source.startswith("file:") for big in conf.bigs):
log.info("Auto-starting streaming on multicaster1")
await multicaster1.start_streaming()
# Mark persisted state as streaming
settings = load_stream_settings() or {}
settings['is_streaming'] = True
settings['timestamp'] = datetime.utcnow().isoformat()
save_stream_settings(settings)
except Exception as e:
log.error("Exception in /init: %s", traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@app.post("/init2")
async def initialize2(conf: auracast_config.AuracastConfigGroup):
@@ -197,6 +202,51 @@ async def initialize2(conf: auracast_config.AuracastConfigGroup):
raise HTTPException(status_code=500, detail=str(e))
@app.post("/stop_audio")
async def stop_audio():
"""Stops streaming on both multicaster1 and multicaster2."""
try:
# First close any active WebRTC peer connections so their track loops finish cleanly
close_tasks = [pc.close() for pc in list(pcs)]
pcs.clear()
if close_tasks:
await asyncio.gather(*close_tasks, return_exceptions=True)
# Now shut down both multicasters and release audio devices
global multicaster1, multicaster2
was_running = False
if multicaster1 is not None:
await multicaster1.stop_streaming()
await multicaster1.shutdown()
multicaster1 = None
was_running = True
if multicaster2 is not None:
await multicaster2.stop_streaming()
await multicaster2.shutdown()
multicaster2 = None
was_running = True
# Persist is_streaming=False
try:
settings = load_stream_settings() or {}
if settings.get('is_streaming'):
settings['is_streaming'] = False
settings['timestamp'] = datetime.utcnow().isoformat()
save_stream_settings(settings)
except Exception:
log.warning("Failed to persist is_streaming=False during stop_audio", exc_info=True)
# Grace period: allow PipeWire/PortAudio to fully drop capture nodes
await asyncio.sleep(0.2)
return {"status": "stopped", "was_running": was_running}
except Exception as e:
log.error("Exception in /stop_audio: %s", traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@app.post("/stream_lc3")
async def send_audio(audio_data: dict[str, str]):
"""Sends a block of pre-coded LC3 audio."""
@@ -215,44 +265,6 @@ async def send_audio(audio_data: dict[str, str]):
raise HTTPException(status_code=500, detail=str(e))
@app.post("/stop_audio")
async def stop_audio():
"""Stops streaming on both multicaster1 and multicaster2."""
try:
# First close any active WebRTC peer connections so their track loops finish cleanly
close_tasks = [pc.close() for pc in list(pcs)]
pcs.clear()
if close_tasks:
await asyncio.gather(*close_tasks, return_exceptions=True)
# Now shut down both multicasters and release audio devices
global multicaster1, multicaster2
running = False
if multicaster1 is not None:
await multicaster1.shutdown()
multicaster1 = None
running = True
if multicaster2 is not None:
await multicaster2.shutdown()
multicaster2 = None
running = True
# Persist is_streaming=False
try:
settings = load_stream_settings() or {}
if settings.get('is_streaming'):
settings['is_streaming'] = False
settings['timestamp'] = datetime.utcnow().isoformat()
save_stream_settings(settings)
except Exception:
log.warning("Failed to persist is_streaming=False during stop_audio", exc_info=True)
return {"status": "stopped", "was_running": running}
except Exception as e:
log.error("Exception in /stop_audio: %s", traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@app.get("/status")
async def get_status():
"""Gets the current status of the multicaster together with persisted stream info."""
@@ -264,6 +276,13 @@ async def get_status():
return status
@app.get("/long_block")
async def long_block():
"""Test endpoint that simulates a small delay without blocking the event loop."""
time.sleep(0.3)
return True
async def _autostart_from_settings():
"""Background task: auto-start last selected device-based input at server startup.
@@ -352,6 +371,8 @@ async def _autostart_from_settings():
@app.on_event("startup")
async def _startup_autostart_event():
# Spawn the autostart task without blocking startup
log.info("Refreshing PipeWire device cache.")
refresh_pw_cache()
asyncio.create_task(_autostart_from_settings())