feat: add waveform visualization and improve confidence estimation with SNR gate

This commit is contained in:
2025-11-05 18:27:24 +01:00
parent d4756cbb03
commit 6b6a193125

View File

@@ -9,6 +9,7 @@ matplotlib.use("TkAgg") # GUI-Ausgabe für interaktives Fenster
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from matplotlib.widgets import Button, Slider, CheckButtons
import matplotlib.patches as mpatches
import threading
import time
from datetime import datetime
@@ -35,7 +36,7 @@ def generate_tone(f_hz: float, dur_s: float, fs: int, volume: float,
out = np.concatenate([np.zeros(n_pre, dtype=np.float32), ref, np.zeros(n_post, dtype=np.float32)])
return out, ref, n_pre
def detect_onset_xcorr(signal: np.ndarray, ref: np.ndarray):
def detect_onset_xcorr(signal: np.ndarray, ref: np.ndarray, pre_len: int | None = None):
"""Normierte Kreuzkorrelation; liefert Onset-Index und Confidence."""
x = signal.astype(np.float64)
r = ref.astype(np.float64)
@@ -53,7 +54,21 @@ def detect_onset_xcorr(signal: np.ndarray, ref: np.ndarray):
nrm = np.sqrt(E_x * E_r) + 1e-20
nxc = corr / nrm
k = int(np.argmax(nxc))
conf = float(nxc[k])
peak = float(nxc[k])
# Robust confidence: compare the peak to the baseline nxc distribution taken from the
# pre-silence lags (or from the first 20% of lags when pre_len is None or <= M)
if pre_len is None or pre_len <= M:
base_end = max(1, int(len(nxc) * 0.2))
else:
base_end = max(1, min(len(nxc), int(pre_len - M + 1)))
base = nxc[:base_end]
if base.size <= 1:
conf = peak
else:
med = float(np.median(base))
mad = float(np.median(np.abs(base - med)))
scale = 1.4826 * mad + 1e-6
z = (peak - med) / scale
conf = float(1.0 / (1.0 + np.exp(-0.5 * z)))
return k, nxc, conf
# Simple biquad band-pass (RBJ cookbook) and direct-form I filter
@@ -90,7 +105,7 @@ def measure_latency_once(freq_hz: float, fs: int, dur_s: float, volume: float,
blocksize: int | None = None, iolatency: float | str | None = None,
estimator: str = "xcorr", xrun_counter: dict | None = None,
bandpass: bool = True, rms_info: dict | None = None,
io_info: dict | None = None):
io_info: dict | None = None, diag: dict | None = None):
"""Spielt einen Ton, nimmt parallel auf, schätzt Latenz in ms, gibt Confidence zurück."""
play_buf, ref, n_pre = generate_tone(freq_hz, dur_s, fs, volume, pre_silence, post_silence)
record_buf = []
@@ -162,7 +177,37 @@ def measure_latency_once(freq_hz: float, fs: int, dur_s: float, volume: float,
rec = lfilter_biquad(b, a, rec)
except Exception:
pass
onset_idx, _, conf = detect_onset_xcorr(rec, ref)
onset_idx, nxc, conf = detect_onset_xcorr(rec, ref, n_pre)
# Simple SNR gate: cap confidence at 0.2 unless the window RMS is at least 2x the pre-silence RMS
try:
M = len(ref)
base_rms = float(np.sqrt(np.mean(np.square(rec[:max(1, n_pre)])))) + 1e-12
w0 = int(max(0, onset_idx))
w1 = int(min(len(rec), w0 + M))
win_rms = float(np.sqrt(np.mean(np.square(rec[w0:w1])))) if w1 > w0 else 0.0
snr_lin = win_rms / max(base_rms, 1e-12)
if snr_lin < 2.0:
conf = float(min(conf, 0.2))
except Exception:
pass
# Fill diagnostics for visualization if requested
if diag is not None:
try:
diag.clear()
diag.update({
"fs": int(fs),
"play_buf": play_buf.copy(),
"rec": rec.copy(),
"ref": ref.copy(),
"n_pre": int(n_pre),
"onset_idx": int(onset_idx),
"nxc": nxc.copy(),
"bandpass": bool(bandpass)
})
except Exception:
pass
if estimator == "timeinfo":
if times.adc_first_time is None or times.dac_first_time is None:
@@ -224,11 +269,11 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
confidences: list[float] = []
fig = plt.figure(figsize=(11, 6))
# Leave space at bottom for controls; scatter on left, terminal on right
plt.tight_layout(rect=[0, 0.16, 1, 1])
# Leave more space at bottom for a two-row control area
plt.tight_layout(rect=[0, 0.20, 1, 1])
# Scatterplot of latency vs sample index (left)
ax_sc = fig.add_axes([0.08, 0.22, 0.62, 0.72])
# Scatterplot of latency vs sample index (top-left)
ax_sc = fig.add_axes([0.05, 0.55, 0.62, 0.42])
ax_sc.set_title("Latency over samples", loc="left")
ax_sc.set_xlabel("sample index")
ax_sc.set_ylabel("latency [ms]")
@@ -255,27 +300,96 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
except Exception:
pass
# Duplex waveform panel (below scatter, left)
ax_duplex = fig.add_axes([0.05, 0.25, 0.62, 0.25])
ax_duplex.set_title("Duplex stream (time-domain)", loc="left", fontsize=9)
ax_duplex.set_xlabel("time [ms]")
ax_duplex.set_ylabel("amplitude")
ax_duplex.grid(True, axis="both", alpha=0.25)
line_play, = ax_duplex.plot([], [], '-', color="#4C78A8", linewidth=1.0, label="playout")
line_rec, = ax_duplex.plot([], [], '-', color="#E45756", linewidth=1.0, alpha=0.9, label="record")
v_on = ax_duplex.axvline(0.0, color="#2ca02c", linestyle="--", linewidth=1.0, label="onset")
v_t0 = ax_duplex.axvline(0.0, color="#999999", linestyle=":", linewidth=1.0, label="tone start")
ax_duplex.legend(loc="upper right", fontsize=8)
# Visual box background
ax_duplex.set_facecolor("#fcfcff")
try:
ax_duplex.add_patch(mpatches.FancyBboxPatch((0, 0), 1, 1, transform=ax_duplex.transAxes,
boxstyle="round,pad=0.01", facecolor="#f7f9ff",
edgecolor="#c6d3f5", linewidth=0.8, zorder=-1, clip_on=False))
except Exception:
pass
# Terminal-style readout panel (right)
ax_log = fig.add_axes([0.73, 0.30, 0.23, 0.60])
ax_log.set_title("Measurements", loc="center", fontsize=9)
ax_log.set_title("Measurements", loc="center", fontsize=10)
ax_log.axis("off")
log_text = ax_log.text(0.0, 1.0, "", va="top", ha="left", family="monospace", fontsize=8)
LOG_WINDOW = 10 # show last 10 lines; start scrolling after 10
# Visual box
try:
ax_log.add_patch(mpatches.FancyBboxPatch((0, 0), 1, 1, transform=ax_log.transAxes,
boxstyle="round,pad=0.01", facecolor="#fbfbfb",
edgecolor="#dddddd", linewidth=0.8, zorder=-1, clip_on=False))
except Exception:
pass
# Stats panel (right column, left edge aligned with the Measurements panel)
ax_stats = fig.add_axes([0.78, 0.60, 0.18, 0.04])
ax_stats = fig.add_axes([0.73, 0.60, 0.18, 0.04])
ax_stats.axis("off")
ax_stats.set_title("Stats", loc="left", fontsize=9)
stats_text = ax_stats.text(0.0, 1.0, "", va="top", ha="left", family="monospace", fontsize=8)
try:
ax_stats.add_patch(mpatches.FancyBboxPatch((0, 0), 1, 1, transform=ax_stats.transAxes,
boxstyle="round,pad=0.01", facecolor="#fbfbff",
edgecolor="#dfe6ff", linewidth=0.8, zorder=-1, clip_on=False))
except Exception:
pass
# Hardware/Status panel (just below the moved stats)
ax_hw = fig.add_axes([0.80, 0.46, 0.18, 0.04])
ax_hw = fig.add_axes([0.73, 0.46, 0.18, 0.04])
ax_hw.axis("off")
ax_hw.set_title("Hardware", loc="left", fontsize=9)
hw_text = ax_hw.text(0.0, 1.0, "", va="top", ha="left", family="monospace", fontsize=8)
try:
ax_hw.add_patch(mpatches.FancyBboxPatch((0, 0), 1, 1, transform=ax_hw.transAxes,
boxstyle="round,pad=0.01", facecolor="#fbfffb",
edgecolor="#d8f0d8", linewidth=0.8, zorder=-1, clip_on=False))
except Exception:
pass
# Information box (explains the measurement method). Overlays the plot area; toggled via the Info button.
ax_info = fig.add_axes([0.2, 0.5, 0.5, 0.20])
ax_info.axis("off")
ax_info.set_title("Method", loc="left", fontsize=9)
# runtime I/O info (updated by stream callback later; the info text below is formatted
# once, right here, so it only ever sees the initial None)
io_info = {"blocksize_actual": None}
info_text_str = (
f"Method details:\n"
f"- Signal: 440 Hz sine, dur={dur:.3f}s, pre={pre_silence:.2f}s, post={post_silence:.2f}s, vol={vol:.2f}; 5 ms fade-in/out.\n"
f"- I/O: full-duplex sd.Stream(fs={fs}, ch=1, dtype=float32, blocksize={io_info['blocksize_actual'] if io_info['blocksize_actual'] is not None else (blocksize if blocksize is not None else 'auto')}, latency={iolatency}).\n"
f"- Band-pass (optional): RBJ biquad centered 440 Hz, Q=8; direct-form I.\n"
f"- Pre-whitening: apply x[n]-0.97*x[n-1] on ref and recording.\n"
f"- Normalized xcorr: corr / sqrt(E_x*E_r) over valid lags; take peak index k and value.\n"
f"- Baseline/confidence: median/MAD of pre-silence nxc; z=(peak-med)/(1.4826*MAD+eps); conf=sigmoid(0.5*z).\n"
f"- SNR gate: window RMS vs pre-silence RMS; if <2x then cap conf≤0.2.\n"
f"- Latency (xcorr): ((k - n_pre)/fs)*1000 ms.\n"
f"- Latency (timeinfo): uses PortAudio DAC/ADC timestamps around onset.\n"
f"- Negatives are invalid → shown as NaN; conf_min and 'include low' control filtering.\n"
f"- Display: zero_offset subtracts current mean; rolling mean/std shown over window."
)
info_box = ax_info.text(
0.0, 1.0, info_text_str,
va="top", ha="left", fontsize=14, wrap=True,
bbox=dict(boxstyle="round", facecolor="#f0f6ff", edgecolor="#4C78A8", alpha=0.9)
)
running = threading.Event()
latest_changed = threading.Event()
lock = threading.Lock()
latest_conf = {"value": float("nan")}
last_diag = {}
info_visible = [True]
current_conf_min = [float(conf_min)]
include_low = [False]
@@ -284,8 +398,6 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
xrun_counter = {"count": 0}
# input RMS meter shared
rms_info = {"rms_dbfs": float('nan'), "clip": False}
# runtime I/O info
io_info = {"blocksize_actual": None}
# Resolve device names for display
try:
@@ -435,6 +547,45 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
sc_last.set_data([], [])
ann_last.set_text("")
# Update duplex waveform (left-bottom)
if last_diag:
try:
fs_d = int(last_diag.get("fs", fs))
rec = np.asarray(last_diag.get("rec", []), dtype=float)
play = np.asarray(last_diag.get("play_buf", []), dtype=float)
n_pre = int(last_diag.get("n_pre", 0))
onset_idx = int(last_diag.get("onset_idx", 0))
M = len(last_diag.get("ref", []))
# Choose a window around onset
w_before = max(M // 2, int(0.03 * fs_d))
w_after = max(int(1.5 * M), int(0.06 * fs_d))
s0 = max(0, onset_idx - w_before)
s1 = min(len(rec), onset_idx + w_after)
if s1 > s0:
t_ms = (np.arange(s0, s1) - onset_idx) * 1000.0 / fs_d
y_rec = rec[s0:s1]
y_play = play[s0:s1] if s1 <= len(play) else play[s0:min(s1, len(play))]
# Ensure same length for plotting
if y_play.shape[0] != (s1 - s0):
y_play = np.pad(y_play, (0, (s1 - s0) - y_play.shape[0]), mode='constant')
line_rec.set_data(t_ms, y_rec)
line_play.set_data(t_ms, y_play)
v_on.set_xdata([0.0, 0.0])
t0_ms = (n_pre - onset_idx) * 1000.0 / fs_d
v_t0.set_xdata([t0_ms, t0_ms])
# Y limits with padding
y_min = float(np.nanmin([np.min(y_rec), np.min(y_play)]) if y_rec.size and y_play.size else -1.0)
y_max = float(np.nanmax([np.max(y_rec), np.max(y_play)]) if y_rec.size and y_play.size else 1.0)
if not np.isfinite(y_min) or not np.isfinite(y_max) or y_min == y_max:
y_min, y_max = -1.0, 1.0
pad = 0.05 * (y_max - y_min)
ax_duplex.set_xlim(t_ms[0], t_ms[-1])
ax_duplex.set_ylim(y_min - pad, y_max + pad)
except Exception:
# If anything goes wrong, clear the duplex plot gracefully
line_rec.set_data([], [])
line_play.set_data([], [])
# Update rolling terminal (right)
lines = []
thr = current_conf_min[0]
@@ -487,12 +638,13 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
def worker():
f = 440.0
while running.is_set():
local_diag = {}
lat_ms, conf = measure_latency_once(
f, fs, dur, vol, indev, outdev,
pre_silence=pre_silence, post_silence=post_silence,
blocksize=blocksize, iolatency=iolatency,
estimator=estimator, xrun_counter=xrun_counter,
bandpass=bandpass, rms_info=rms_info, io_info=io_info
bandpass=bandpass, rms_info=rms_info, io_info=io_info, diag=local_diag
)
with lock:
# Negative latencies are physically impossible -> mark as invalid (NaN)
@@ -502,6 +654,8 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
latencies.append(lat_ms)
confidences.append(conf)
latest_conf["value"] = conf
last_diag.clear()
last_diag.update(local_diag)
latest_changed.set()
def on_start(event):
@@ -515,31 +669,19 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
def on_stop(event):
running.clear()
# Slider for confidence threshold
slider_ax = fig.add_axes([0.10, 0.02, 0.32, 0.05])
slider = Slider(slider_ax, 'conf_min', 0.0, 1.0, valinit=current_conf_min[0], valstep=0.01)
def on_slider(val):
current_conf_min[0] = float(val)
update_plot()
slider.on_changed(on_slider)
# Checkbox to include low-confidence samples (placed next to conf_min slider)
cbox_ax = fig.add_axes([0.45, 0.02, 0.12, 0.05])
cbox = CheckButtons(cbox_ax, ["include low"], [include_low[0]])
def on_cbox(label):
include_low[0] = not include_low[0]
update_plot()
cbox.on_clicked(on_cbox)
# (removed middle window slider control)
start_ax = fig.add_axes([0.54, 0.02, 0.13, 0.06])
stop_ax = fig.add_axes([0.69, 0.02, 0.13, 0.06])
reset_ax = fig.add_axes([0.84, 0.02, 0.06, 0.06])
save_ax = fig.add_axes([0.92, 0.02, 0.06, 0.06])
zero_ax = fig.add_axes([0.84, 0.10, 0.06, 0.06])
zero_clr_ax = fig.add_axes([0.92, 0.10, 0.06, 0.06])
# Controls, two rows (no overlap)
slider_ax = fig.add_axes([0.08, 0.02, 0.46, 0.06])
cbox_ax = fig.add_axes([0.6, 0.02, 0.12, 0.06])
info_ax = fig.add_axes([0.75, 0.02, 0.08, 0.06])
start_ax = fig.add_axes([0.08, 0.10, 0.10, 0.08])
stop_ax = fig.add_axes([0.20, 0.10, 0.10, 0.08])
reset_ax = fig.add_axes([0.32, 0.10, 0.08, 0.08])
save_ax = fig.add_axes([0.42, 0.10, 0.08, 0.08])
zero_ax = fig.add_axes([0.52, 0.10, 0.10, 0.08])
zero_clr_ax = fig.add_axes([0.64, 0.10, 0.12, 0.08])
btn_info = Button(info_ax, "Info")
btn_start = Button(start_ax, "Start")
btn_stop = Button(stop_ax, "Stop")
btn_reset = Button(reset_ax, "Clr")
@@ -548,6 +690,25 @@ def run_gui(fs: int, dur: float, vol: float, indev: int | None, outdev: int | No
btn_zero_clr = Button(zero_clr_ax, "ZeroClr")
btn_start.on_clicked(on_start)
btn_stop.on_clicked(on_stop)
def on_info(event):
info_visible[0] = not info_visible[0]
ax_info.set_visible(info_visible[0])
fig.canvas.draw_idle()
btn_info.on_clicked(on_info)
# Now create Slider and CheckButtons after their axes exist
slider = Slider(slider_ax, 'conf_min', 0.0, 1.0, valinit=current_conf_min[0], valstep=0.01)
def on_slider(val):
current_conf_min[0] = float(val)
update_plot()
slider.on_changed(on_slider)
cbox = CheckButtons(cbox_ax, ["include low"], [include_low[0]])
def on_cbox(label):
include_low[0] = not include_low[0]
update_plot()
cbox.on_clicked(on_cbox)
def on_reset(event):
with lock:
latencies.clear()
@@ -624,7 +785,7 @@ def main():
ap.add_argument("-v", "--volume", type=float, default=0.6, help="Lautstärke 0..1")
ap.add_argument("--indev", type=int, default=None, help="Input-Geräteindex")
ap.add_argument("--outdev", type=int, default=None, help="Output-Geräteindex")
ap.add_argument("--conf-min", type=float, default=0.3, help="Warnschwelle für Confidence")
ap.add_argument("--conf-min", type=float, default=0.9, help="Warnschwelle für Confidence")
ap.add_argument("--blocksize", type=int, default=None, help="Audio blocksize (frames), e.g. 1024/2048")
ap.add_argument("--iolatency", type=str, default="high", help="Audio I/O latency (seconds or preset: 'low','high')")
ap.add_argument("--estimator", type=str, choices=["xcorr","timeinfo"], default="xcorr", help="Latency estimator: 'xcorr' (default, robust) or 'timeinfo' (host timestamps)")