feat: add waveform visualization and improve confidence estimation with SNR gate

This commit is contained in:
2025-11-05 18:27:24 +01:00
parent d4756cbb03
commit 6b6a193125

View File

@@ -9,6 +9,7 @@ matplotlib.use("TkAgg") # GUI-Ausgabe für interaktives Fenster
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import matplotlib.ticker as mticker import matplotlib.ticker as mticker
from matplotlib.widgets import Button, Slider, CheckButtons from matplotlib.widgets import Button, Slider, CheckButtons
import matplotlib.patches as mpatches
import threading import threading
import time import time
from datetime import datetime from datetime import datetime
@@ -35,7 +36,7 @@ def generate_tone(f_hz: float, dur_s: float, fs: int, volume: float,
out = np.concatenate([np.zeros(n_pre, dtype=np.float32), ref, np.zeros(n_post, dtype=np.float32)]) out = np.concatenate([np.zeros(n_pre, dtype=np.float32), ref, np.zeros(n_post, dtype=np.float32)])
return out, ref, n_pre return out, ref, n_pre
def detect_onset_xcorr(signal: np.ndarray, ref: np.ndarray): def detect_onset_xcorr(signal: np.ndarray, ref: np.ndarray, pre_len: int | None = None):
"""Normierte Kreuzkorrelation; liefert Onset-Index und Confidence.""" """Normierte Kreuzkorrelation; liefert Onset-Index und Confidence."""
x = signal.astype(np.float64) x = signal.astype(np.float64)
r = ref.astype(np.float64) r = ref.astype(np.float64)
@@ -53,7 +54,21 @@ def detect_onset_xcorr(signal: np.ndarray, ref: np.ndarray):
nrm = np.sqrt(E_x * E_r) + 1e-20 nrm = np.sqrt(E_x * E_r) + 1e-20
nxc = corr / nrm nxc = corr / nrm
k = int(np.argmax(nxc)) k = int(np.argmax(nxc))
conf = float(nxc[k]) peak = float(nxc[k])
# Robust confidence: compare peak to pre-silence baseline distribution
if pre_len is None or pre_len <= M:
base_end = max(1, int(len(nxc) * 0.2))
else:
base_end = max(1, min(len(nxc), int(pre_len - M + 1)))
base = nxc[:base_end]
if base.size <= 1:
conf = peak
else:
med = float(np.median(base))
mad = float(np.median(np.abs(base - med)))
scale = 1.4826 * mad + 1e-6
z = (peak - med) / scale
conf = float(1.0 / (1.0 + np.exp(-0.5 * z)))
return k, nxc, conf return k, nxc, conf
# Simple biquad band-pass (RBJ cookbook) and direct-form I filter # Simple biquad band-pass (RBJ cookbook) and direct-form I filter
@@ -90,7 +105,7 @@ def measure_latency_once(freq_hz: float, fs: int, dur_s: float, volume: float,
blocksize: int | None = None, iolatency: float | str | None = None, blocksize: int | None = None, iolatency: float | str | None = None,
estimator: str = "xcorr", xrun_counter: dict | None = None, estimator: str = "xcorr", xrun_counter: dict | None = None,
bandpass: bool = True, rms_info: dict | None = None, bandpass: bool = True, rms_info: dict | None = None,
io_info: dict | None = None): io_info: dict | None = None, diag: dict | None = None):
"""Spielt einen Ton, nimmt parallel auf, schätzt Latenz in ms, gibt Confidence zurück.""" """Spielt einen Ton, nimmt parallel auf, schätzt Latenz in ms, gibt Confidence zurück."""
play_buf, ref, n_pre = generate_tone(freq_hz, dur_s, fs, volume, pre_silence, post_silence) play_buf, ref, n_pre = generate_tone(freq_hz, dur_s, fs, volume, pre_silence, post_silence)
record_buf = [] record_buf = []
@@ -162,7 +177,37 @@ def measure_latency_once(freq_hz: float, fs: int, dur_s: float, volume: float,
rec = lfilter_biquad(b, a, rec) rec = lfilter_biquad(b, a, rec)
except Exception: except Exception:
pass pass
onset_idx, _, conf = detect_onset_xcorr(rec, ref) onset_idx, nxc, conf = detect_onset_xcorr(rec, ref, n_pre)
# Simple RMS gate: require window RMS to exceed pre-silence RMS
try:
M = len(ref)
base_rms = float(np.sqrt(np.mean(np.square(rec[:max(1, n_pre)])))) + 1e-12
w0 = int(max(0, onset_idx))
w1 = int(min(len(rec), w0 + M))
win_rms = float(np.sqrt(np.mean(np.square(rec[w0:w1])))) if w1 > w0 else 0.0
snr_lin = win_rms / max(base_rms, 1e-12)
if snr_lin < 2.0:
conf = float(min(conf, 0.2))
except Exception:
pass
# Fill diagnostics for visualization if requested
if diag is not None:
try:
diag.clear()
diag.update({
"fs": int(fs),
"play_buf": play_buf.copy(),
"rec": rec.copy(),
"ref": ref.copy(),
"n_pre": int(n_pre),
"onset_idx": int(onset_idx),
"nxc": nxc.copy(),
"bandpass": bool(bandpass)
})
except Exception:
pass
if estimator == "timeinfo": if estimator == "timeinfo":
if times.adc_first_time is None or times.dac_first_time is None: if times.adc_first_time is None or times.dac_first_time is None:
@@ -224,11 +269,11 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
confidences: list[float] = [] confidences: list[float] = []
fig = plt.figure(figsize=(11, 6)) fig = plt.figure(figsize=(11, 6))
# Leave space at bottom for controls; scatter on left, terminal on right # Leave more space at bottom for a two-row control area
plt.tight_layout(rect=[0, 0.16, 1, 1]) plt.tight_layout(rect=[0, 0.20, 1, 1])
# Scatterplot of latency vs sample index (left) # Scatterplot of latency vs sample index (top-left)
ax_sc = fig.add_axes([0.08, 0.22, 0.62, 0.72]) ax_sc = fig.add_axes([0.05, 0.55, 0.62, 0.42])
ax_sc.set_title("Latency over samples", loc="left") ax_sc.set_title("Latency over samples", loc="left")
ax_sc.set_xlabel("sample index") ax_sc.set_xlabel("sample index")
ax_sc.set_ylabel("latency [ms]") ax_sc.set_ylabel("latency [ms]")
@@ -255,27 +300,96 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
except Exception: except Exception:
pass pass
# Duplex waveform panel (below scatter, left)
ax_duplex = fig.add_axes([0.05, 0.25, 0.62, 0.25])
ax_duplex.set_title("Duplex stream (time-domain)", loc="left", fontsize=9)
ax_duplex.set_xlabel("time [ms]")
ax_duplex.set_ylabel("amplitude")
ax_duplex.grid(True, axis="both", alpha=0.25)
line_play, = ax_duplex.plot([], [], '-', color="#4C78A8", linewidth=1.0, label="playout")
line_rec, = ax_duplex.plot([], [], '-', color="#E45756", linewidth=1.0, alpha=0.9, label="record")
v_on = ax_duplex.axvline(0.0, color="#2ca02c", linestyle="--", linewidth=1.0, label="onset")
v_t0 = ax_duplex.axvline(0.0, color="#999999", linestyle=":", linewidth=1.0, label="tone start")
ax_duplex.legend(loc="upper right", fontsize=8)
# Visual box background
ax_duplex.set_facecolor("#fcfcff")
try:
ax_duplex.add_patch(mpatches.FancyBboxPatch((0, 0), 1, 1, transform=ax_duplex.transAxes,
boxstyle="round,pad=0.01", facecolor="#f7f9ff",
edgecolor="#c6d3f5", linewidth=0.8, zorder=-1, clip_on=False))
except Exception:
pass
# Terminal-style readout panel (right) # Terminal-style readout panel (right)
ax_log = fig.add_axes([0.73, 0.30, 0.23, 0.60]) ax_log = fig.add_axes([0.73, 0.30, 0.23, 0.60])
ax_log.set_title("Measurements", loc="center", fontsize=9) ax_log.set_title("Measurements", loc="center", fontsize=10)
ax_log.axis("off") ax_log.axis("off")
log_text = ax_log.text(0.0, 1.0, "", va="top", ha="left", family="monospace", fontsize=8) log_text = ax_log.text(0.0, 1.0, "", va="top", ha="left", family="monospace", fontsize=8)
LOG_WINDOW = 10 # show last 10 lines; start scrolling after 10 LOG_WINDOW = 10 # show last 10 lines; start scrolling after 10
# Visual box
try:
ax_log.add_patch(mpatches.FancyBboxPatch((0, 0), 1, 1, transform=ax_log.transAxes,
boxstyle="round,pad=0.01", facecolor="#fbfbfb",
edgecolor="#dddddd", linewidth=0.8, zorder=-1, clip_on=False))
except Exception:
pass
# Stats panel (move higher and slightly to the right) # Stats panel (move higher and slightly to the right)
ax_stats = fig.add_axes([0.78, 0.60, 0.18, 0.04]) ax_stats = fig.add_axes([0.73, 0.60, 0.18, 0.04])
ax_stats.axis("off") ax_stats.axis("off")
ax_stats.set_title("Stats", loc="left", fontsize=9)
stats_text = ax_stats.text(0.0, 1.0, "", va="top", ha="left", family="monospace", fontsize=8) stats_text = ax_stats.text(0.0, 1.0, "", va="top", ha="left", family="monospace", fontsize=8)
try:
ax_stats.add_patch(mpatches.FancyBboxPatch((0, 0), 1, 1, transform=ax_stats.transAxes,
boxstyle="round,pad=0.01", facecolor="#fbfbff",
edgecolor="#dfe6ff", linewidth=0.8, zorder=-1, clip_on=False))
except Exception:
pass
# Hardware/Status panel (just below the moved stats) # Hardware/Status panel (just below the moved stats)
ax_hw = fig.add_axes([0.80, 0.46, 0.18, 0.04]) ax_hw = fig.add_axes([0.73, 0.46, 0.18, 0.04])
ax_hw.axis("off") ax_hw.axis("off")
ax_hw.set_title("Hardware", loc="left", fontsize=9)
hw_text = ax_hw.text(0.0, 1.0, "", va="top", ha="left", family="monospace", fontsize=8) hw_text = ax_hw.text(0.0, 1.0, "", va="top", ha="left", family="monospace", fontsize=8)
try:
ax_hw.add_patch(mpatches.FancyBboxPatch((0, 0), 1, 1, transform=ax_hw.transAxes,
boxstyle="round,pad=0.01", facecolor="#fbfffb",
edgecolor="#d8f0d8", linewidth=0.8, zorder=-1, clip_on=False))
except Exception:
pass
# Information box (explains the measurement method). Placed above controls.
ax_info = fig.add_axes([0.2, 0.5, 0.5, 0.20])
ax_info.axis("off")
ax_info.set_title("Method", loc="left", fontsize=9)
# runtime I/O info (used below in info text; updated by stream callback later)
io_info = {"blocksize_actual": None}
info_text_str = (
f"Method details:\n"
f"- Signal: 440 Hz sine, dur={dur:.3f}s, pre={pre_silence:.2f}s, post={post_silence:.2f}s, vol={vol:.2f}; 5 ms fade-in/out.\n"
f"- I/O: full-duplex sd.Stream(fs={fs}, ch=1, dtype=float32, blocksize={io_info['blocksize_actual'] if io_info['blocksize_actual'] is not None else (blocksize if blocksize is not None else 'auto')}, latency={iolatency}).\n"
f"- Band-pass (optional): RBJ biquad centered 440 Hz, Q=8; direct-form I.\n"
f"- Pre-whitening: apply x[n]-0.97*x[n-1] on ref and recording.\n"
f"- Normalized xcorr: corr / sqrt(E_x*E_r) over valid lags; take peak index k and value.\n"
f"- Baseline/confidence: median/MAD of pre-silence nxc; z=(peak-med)/(1.4826*MAD+eps); conf=sigmoid(0.5*z).\n"
f"- SNR gate: window RMS vs pre-silence RMS; if <2x then cap conf≤0.2.\n"
f"- Latency (xcorr): ((k - n_pre)/fs)*1000 ms.\n"
f"- Latency (timeinfo): uses PortAudio DAC/ADC timestamps around onset.\n"
f"- Negatives are invalid → shown as NaN; conf_min and 'include low' control filtering.\n"
f"- Display: zero_offset subtracts current mean; rolling mean/std shown over window."
)
info_box = ax_info.text(
0.0, 1.0, info_text_str,
va="top", ha="left", fontsize=14, wrap=True,
bbox=dict(boxstyle="round", facecolor="#f0f6ff", edgecolor="#4C78A8", alpha=0.9)
)
running = threading.Event() running = threading.Event()
latest_changed = threading.Event() latest_changed = threading.Event()
lock = threading.Lock() lock = threading.Lock()
latest_conf = {"value": float("nan")} latest_conf = {"value": float("nan")}
last_diag = {}
info_visible = [True]
current_conf_min = [float(conf_min)] current_conf_min = [float(conf_min)]
include_low = [False] include_low = [False]
@@ -284,8 +398,6 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
xrun_counter = {"count": 0} xrun_counter = {"count": 0}
# input RMS meter shared # input RMS meter shared
rms_info = {"rms_dbfs": float('nan'), "clip": False} rms_info = {"rms_dbfs": float('nan'), "clip": False}
# runtime I/O info
io_info = {"blocksize_actual": None}
# Resolve device names for display # Resolve device names for display
try: try:
@@ -435,6 +547,45 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
sc_last.set_data([], []) sc_last.set_data([], [])
ann_last.set_text("") ann_last.set_text("")
# Update duplex waveform (left-bottom)
if last_diag:
try:
fs_d = int(last_diag.get("fs", fs))
rec = np.asarray(last_diag.get("rec", []), dtype=float)
play = np.asarray(last_diag.get("play_buf", []), dtype=float)
n_pre = int(last_diag.get("n_pre", 0))
onset_idx = int(last_diag.get("onset_idx", 0))
M = len(last_diag.get("ref", []))
# Choose a window around onset
w_before = max(M // 2, int(0.03 * fs_d))
w_after = max(int(1.5 * M), int(0.06 * fs_d))
s0 = max(0, onset_idx - w_before)
s1 = min(len(rec), onset_idx + w_after)
if s1 > s0:
t_ms = (np.arange(s0, s1) - onset_idx) * 1000.0 / fs_d
y_rec = rec[s0:s1]
y_play = play[s0:s1] if s1 <= len(play) else play[s0:min(s1, len(play))]
# Ensure same length for plotting
if y_play.shape[0] != (s1 - s0):
y_play = np.pad(y_play, (0, (s1 - s0) - y_play.shape[0]), mode='constant')
line_rec.set_data(t_ms, y_rec)
line_play.set_data(t_ms, y_play)
v_on.set_xdata([0.0, 0.0])
t0_ms = (n_pre - onset_idx) * 1000.0 / fs_d
v_t0.set_xdata([t0_ms, t0_ms])
# Y limits with padding
y_min = float(np.nanmin([np.min(y_rec), np.min(y_play)]) if y_rec.size and y_play.size else -1.0)
y_max = float(np.nanmax([np.max(y_rec), np.max(y_play)]) if y_rec.size and y_play.size else 1.0)
if not np.isfinite(y_min) or not np.isfinite(y_max) or y_min == y_max:
y_min, y_max = -1.0, 1.0
pad = 0.05 * (y_max - y_min)
ax_duplex.set_xlim(t_ms[0], t_ms[-1])
ax_duplex.set_ylim(y_min - pad, y_max + pad)
except Exception:
# If anything goes wrong, clear the duplex plot gracefully
line_rec.set_data([], [])
line_play.set_data([], [])
# Update rolling terminal (right) # Update rolling terminal (right)
lines = [] lines = []
thr = current_conf_min[0] thr = current_conf_min[0]
@@ -487,12 +638,13 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
def worker(): def worker():
f = 440.0 f = 440.0
while running.is_set(): while running.is_set():
local_diag = {}
lat_ms, conf = measure_latency_once( lat_ms, conf = measure_latency_once(
f, fs, dur, vol, indev, outdev, f, fs, dur, vol, indev, outdev,
pre_silence=pre_silence, post_silence=post_silence, pre_silence=pre_silence, post_silence=post_silence,
blocksize=blocksize, iolatency=iolatency, blocksize=blocksize, iolatency=iolatency,
estimator=estimator, xrun_counter=xrun_counter, estimator=estimator, xrun_counter=xrun_counter,
bandpass=bandpass, rms_info=rms_info, io_info=io_info bandpass=bandpass, rms_info=rms_info, io_info=io_info, diag=local_diag
) )
with lock: with lock:
# Negative latencies are physically impossible -> mark as invalid (NaN) # Negative latencies are physically impossible -> mark as invalid (NaN)
@@ -502,6 +654,8 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
latencies.append(lat_ms) latencies.append(lat_ms)
confidences.append(conf) confidences.append(conf)
latest_conf["value"] = conf latest_conf["value"] = conf
last_diag.clear()
last_diag.update(local_diag)
latest_changed.set() latest_changed.set()
def on_start(event): def on_start(event):
@@ -515,31 +669,19 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
def on_stop(event): def on_stop(event):
running.clear() running.clear()
# Slider for confidence threshold
slider_ax = fig.add_axes([0.10, 0.02, 0.32, 0.05])
slider = Slider(slider_ax, 'conf_min', 0.0, 1.0, valinit=current_conf_min[0], valstep=0.01)
def on_slider(val):
current_conf_min[0] = float(val)
update_plot()
slider.on_changed(on_slider)
# Checkbox to include low-confidence samples (placed next to conf_min slider)
cbox_ax = fig.add_axes([0.45, 0.02, 0.12, 0.05])
cbox = CheckButtons(cbox_ax, ["include low"], [include_low[0]])
def on_cbox(label):
include_low[0] = not include_low[0]
update_plot()
cbox.on_clicked(on_cbox)
# (removed middle window slider control) # (removed middle window slider control)
start_ax = fig.add_axes([0.54, 0.02, 0.13, 0.06]) # Controls, two rows (no overlap)
stop_ax = fig.add_axes([0.69, 0.02, 0.13, 0.06]) slider_ax = fig.add_axes([0.08, 0.02, 0.46, 0.06])
reset_ax = fig.add_axes([0.84, 0.02, 0.06, 0.06]) cbox_ax = fig.add_axes([0.6, 0.02, 0.12, 0.06])
save_ax = fig.add_axes([0.92, 0.02, 0.06, 0.06]) info_ax = fig.add_axes([0.75, 0.02, 0.08, 0.06])
zero_ax = fig.add_axes([0.84, 0.10, 0.06, 0.06]) start_ax = fig.add_axes([0.08, 0.10, 0.10, 0.08])
zero_clr_ax = fig.add_axes([0.92, 0.10, 0.06, 0.06]) stop_ax = fig.add_axes([0.20, 0.10, 0.10, 0.08])
reset_ax = fig.add_axes([0.32, 0.10, 0.08, 0.08])
save_ax = fig.add_axes([0.42, 0.10, 0.08, 0.08])
zero_ax = fig.add_axes([0.52, 0.10, 0.10, 0.08])
zero_clr_ax = fig.add_axes([0.64, 0.10, 0.12, 0.08])
btn_info = Button(info_ax, "Info")
btn_start = Button(start_ax, "Start") btn_start = Button(start_ax, "Start")
btn_stop = Button(stop_ax, "Stop") btn_stop = Button(stop_ax, "Stop")
btn_reset = Button(reset_ax, "Clr") btn_reset = Button(reset_ax, "Clr")
@@ -548,6 +690,25 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
btn_zero_clr = Button(zero_clr_ax, "ZeroClr") btn_zero_clr = Button(zero_clr_ax, "ZeroClr")
btn_start.on_clicked(on_start) btn_start.on_clicked(on_start)
btn_stop.on_clicked(on_stop) btn_stop.on_clicked(on_stop)
def on_info(event):
info_visible[0] = not info_visible[0]
ax_info.set_visible(info_visible[0])
fig.canvas.draw_idle()
btn_info.on_clicked(on_info)
# Now create Slider and CheckButtons after their axes exist
slider = Slider(slider_ax, 'conf_min', 0.0, 1.0, valinit=current_conf_min[0], valstep=0.01)
def on_slider(val):
current_conf_min[0] = float(val)
update_plot()
slider.on_changed(on_slider)
cbox = CheckButtons(cbox_ax, ["include low"], [include_low[0]])
def on_cbox(label):
include_low[0] = not include_low[0]
update_plot()
cbox.on_clicked(on_cbox)
def on_reset(event): def on_reset(event):
with lock: with lock:
latencies.clear() latencies.clear()
@@ -624,7 +785,7 @@ def main():
ap.add_argument("-v", "--volume", type=float, default=0.6, help="Lautstärke 0..1") ap.add_argument("-v", "--volume", type=float, default=0.6, help="Lautstärke 0..1")
ap.add_argument("--indev", type=int, default=None, help="Input-Geräteindex") ap.add_argument("--indev", type=int, default=None, help="Input-Geräteindex")
ap.add_argument("--outdev", type=int, default=None, help="Output-Geräteindex") ap.add_argument("--outdev", type=int, default=None, help="Output-Geräteindex")
ap.add_argument("--conf-min", type=float, default=0.3, help="Warnschwelle für Confidence") ap.add_argument("--conf-min", type=float, default=0.9, help="Warnschwelle für Confidence")
ap.add_argument("--blocksize", type=int, default=None, help="Audio blocksize (frames), e.g. 1024/2048") ap.add_argument("--blocksize", type=int, default=None, help="Audio blocksize (frames), e.g. 1024/2048")
ap.add_argument("--iolatency", type=str, default="high", help="Audio I/O latency (seconds or preset: 'low','high')") ap.add_argument("--iolatency", type=str, default="high", help="Audio I/O latency (seconds or preset: 'low','high')")
ap.add_argument("--estimator", type=str, choices=["xcorr","timeinfo"], default="xcorr", help="Latency estimator: 'xcorr' (default, robust) or 'timeinfo' (host timestamps)") ap.add_argument("--estimator", type=str, choices=["xcorr","timeinfo"], default="xcorr", help="Latency estimator: 'xcorr' (default, robust) or 'timeinfo' (host timestamps)")