# eval_audio.py
"""Metrics and spectrogram panels for comparing a reference and an estimated
stereo signal.

The torch-based metrics (``mel_cosine_stereo``, ``drms_avg_db_stereo``) take
waveforms as tensors; the plotting helpers work on NumPy arrays.
"""
import argparse
import os
import re
from typing import Optional

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio

# Small constant added before sqrt/log10 so silent frames stay finite.
_EPS = 1e-12


def build_mel_transform(
    sample_rate,
    n_fft=1024,
    win_length=1024,
    hop_length=256,
    n_mels=80,
    power=1.0,
    f_min=0.0,
    f_max=None,
    mel_scale="htk",
    norm=None,
    device=None,
):
    """Build a ``torchaudio.transforms.MelSpectrogram`` with ``center=True``.

    All arguments except ``device`` are forwarded unchanged to the
    ``MelSpectrogram`` constructor.

    Args:
        device: If not ``None``, the transform is moved there with ``.to()``.

    Returns:
        The configured ``MelSpectrogram`` module.
    """
    mel_tf = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        power=power,
        center=True,
        norm=norm,
        mel_scale=mel_scale,
    )
    if device is not None:
        mel_tf = mel_tf.to(device)
    return mel_tf


def _ensure_stereo_torch(x):
    """Return ``x`` with exactly two rows along dim 0.

    A 1-D tensor gets a leading dim first; a single row is duplicated and
    anything beyond the first two rows is dropped.
    """
    if x.dim() == 1:
        x = x.unsqueeze(0)
    if x.size(0) == 1:
        x = x.repeat(2, 1)
    elif x.size(0) > 2:
        x = x[:2]
    return x


@torch.no_grad()
def mel_cosine_stereo(
    ref,
    hat,
    sample_rate,
    n_fft=1024,
    win_length=1024,
    hop_length=256,
    n_mels=80,
    power=1.0,
    mel_tf=None,
):
    """Mean cosine similarity between the mel spectrograms of ``ref`` and ``hat``.

    Both inputs are coerced to two channels. Each channel's mel spectrogram is
    flattened to one vector, the cosine similarity is taken per channel, and
    the channel values are averaged.

    Args:
        ref: Reference waveform tensor.
        hat: Estimated waveform tensor; moved to ``ref``'s device.
        sample_rate: Used only when ``mel_tf`` is ``None``.
        n_fft, win_length, hop_length, n_mels, power: Used only when
            ``mel_tf`` is ``None``, to build the transform.
        mel_tf: Optional prebuilt transform; it is moved to ``ref``'s device.

    Returns:
        The similarity as a Python float.
    """
    ref = _ensure_stereo_torch(ref)
    device = ref.device
    hat = _ensure_stereo_torch(hat).to(device)

    if mel_tf is None:
        mel_tf = build_mel_transform(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            n_mels=n_mels,
            power=power,
            device=device,
        )
    else:
        mel_tf = mel_tf.to(device)

    mel_ref = mel_tf(ref)
    mel_hat = mel_tf(hat)

    # Signals of different length give different frame counts; compare only
    # the shared frames, as drms_avg_db_stereo does, instead of failing on a
    # shape mismatch.
    n_frames = min(mel_ref.size(-1), mel_hat.size(-1))
    mel_ref = mel_ref[..., :n_frames]
    mel_hat = mel_hat[..., :n_frames]

    flat_ref = mel_ref.reshape(mel_ref.size(0), -1)
    flat_hat = mel_hat.reshape(mel_hat.size(0), -1)
    sim = F.cosine_similarity(flat_ref, flat_hat, dim=-1)
    return float(sim.mean().item())


@torch.no_grad()
def drms_avg_db_stereo(ref, hat, win_length=1024, hop_length=256):
    """Average framewise RMS level difference ``hat - ref`` in dB.

    Both inputs are coerced to two channels and cut into windows of
    ``win_length`` samples every ``hop_length`` samples. The per-frame RMS
    levels in dB are subtracted over the frames both signals share, then
    averaged over frames and channels.

    Returns:
        The mean dB difference as a Python float.
    """
    ref = _ensure_stereo_torch(ref)
    hat = _ensure_stereo_torch(hat).to(ref.device)

    def _rms_db(x):
        n_samples = x.size(1)
        if n_samples < win_length:
            # Zero-pad so at least one full window exists.
            x = F.pad(x, (0, win_length - n_samples))
        frames = x.unfold(dimension=-1, size=win_length, step=hop_length)
        rms = torch.sqrt(frames.pow(2).mean(dim=-1) + _EPS)
        return 20.0 * torch.log10(rms + _EPS)

    db_ref = _rms_db(ref)
    db_hat = _rms_db(hat)

    n_frames = min(db_ref.size(-1), db_hat.size(-1))
    delta_db = db_hat[:, :n_frames] - db_ref[:, :n_frames]
    return float(delta_db.mean(dim=-1).mean().item())


def load_stereo_wav_np(path):
    """Load an audio file with librosa at its native rate as a two-row array.

    Returns:
        ``(y, sr)`` where ``y`` has two rows and ``sr`` is the sample rate
        reported by ``librosa.load``.
    """
    y, sr = librosa.load(path, sr=None, mono=False)
    if y.ndim == 1:
        y = np.stack([y, y], axis=0)
    elif y.shape[0] == 1:
        # A 2-D single-channel array must be duplicated; slicing with [:2]
        # would leave it with one row.
        y = np.repeat(y, 2, axis=0)
    elif y.shape[0] > 2:
        y = y[:2]
    return y, sr


def compute_spectrogram_np(audio_stereo, n_fft=512, hop_length=160, win_length=400, pool=4):
    """Compute a pooled log-magnitude STFT for the first two rows of the input.

    The STFT magnitude of each channel is cropped to a multiple of ``pool``
    along both axes, mean-pooled over ``pool x pool`` cells, and passed through
    ``log1p``. If only one channel is present it is used for both outputs.

    Args:
        audio_stereo: Array-like of shape ``(channels, samples)``; a 1-D
            signal is treated as a single channel.
        n_fft, hop_length, win_length: Forwarded to ``librosa.stft``.
        pool: Pooling factor, must be at least 1.

    Returns:
        Array of shape ``(h // pool, w // pool, 2)`` where ``(h, w)`` is the
        STFT magnitude shape; the last axis is (first row, second row).

    Raises:
        ValueError: If ``pool`` is below 1 or the signal is too short to give
            at least one pooled cell.
    """
    if pool < 1:
        raise ValueError(f"pool must be >= 1 (got {pool})")
    audio_stereo = np.atleast_2d(np.asarray(audio_stereo))

    def _pooled_magnitude(sig):
        mag = np.abs(
            librosa.stft(sig, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        )
        h, w = mag.shape
        hq, wq = h // pool, w // pool
        if hq == 0 or wq == 0:
            raise ValueError(f"audio too short for pooling (stft shape {mag.shape})")
        mag = mag[:hq * pool, :wq * pool]
        return mag.reshape(hq, pool, wq, pool).mean(axis=(1, 3))

    left = np.log1p(_pooled_magnitude(audio_stereo[0]))
    if audio_stereo.shape[0] >= 2:
        right = np.log1p(_pooled_magnitude(audio_stereo[1]))
    else:
        right = left.copy()
    return np.stack([left, right], axis=-1)


def render_ref_hat_panel(title, spec_ref, spec_hat, out_path, cmap="magma"):
    """Save a 2x2 image grid: columns ref/hat, rows channel 0/channel 1.

    The two images in a row share one colour range (min/max over both) so
    they are directly comparable.

    Args:
        title: Figure title.
        spec_ref, spec_hat: Arrays indexed as ``[:, :, channel]`` with at
            least two channels.
        out_path: Output image path; its parent directory is created if needed.
        cmap: Matplotlib colormap name.

    Returns:
        ``False`` (after printing ``[SKIP]``) if any channel image is empty,
        otherwise ``True`` once the figure has been saved.
    """
    rows = [
        ("Left", spec_ref[:, :, 0], spec_hat[:, :, 0]),
        ("Right", spec_ref[:, :, 1], spec_hat[:, :, 1]),
    ]
    if any(img.size == 0 for _, ref_img, hat_img in rows for img in (ref_img, hat_img)):
        print("[SKIP]")
        return False

    fig, axes = plt.subplots(2, 2, figsize=(8, 6), constrained_layout=True)
    try:
        for row, (label, ref_img, hat_img) in enumerate(rows):
            vmin = min(ref_img.min(), hat_img.min())
            vmax = max(ref_img.max(), hat_img.max())
            for col, img in enumerate((ref_img, hat_img)):
                axes[row, col].imshow(
                    img, origin="lower", aspect="auto", cmap=cmap, vmin=vmin, vmax=vmax
                )
            axes[row, 0].set_ylabel(label)

        axes[0, 0].set_title("ref")
        axes[0, 1].set_title("hat")
        for ax in axes.ravel():
            ax.set_xticks([])
            ax.set_yticks([])
        fig.suptitle(title)

        os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
        # Save this figure explicitly rather than whichever one pyplot
        # currently considers active.
        fig.savefig(out_path, dpi=180)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    return True


def save_ref_hat_spectrogram_panel(
    ref,
    hat,
    out_path,
    n_fft=512,
    hop_length=160,
    win_length=400,
    pool=4,
    title="ref vs hat (binaural spectrogram)",
    cmap="magma",
):
    """Compute pooled spectrograms for ``ref`` and ``hat`` and save the panel.

    Args:
        ref, hat: Waveforms as torch tensors or array-likes; each is coerced
            to a two-row NumPy array first.
        out_path: Output image path.
        n_fft, hop_length, win_length, pool: Forwarded to
            ``compute_spectrogram_np``.
        title, cmap: Forwarded to ``render_ref_hat_panel``.

    Returns:
        The boolean returned by ``render_ref_hat_panel``.
    """

    def _to_np_stereo(x):
        if isinstance(x, torch.Tensor):
            x = x.detach().to(torch.float32).cpu().numpy()
        else:
            # Accept lists and other array-likes, not only ndarrays.
            x = np.asarray(x)
        if x.ndim == 1:
            x = np.stack([x, x], axis=0)
        elif x.shape[0] == 1:
            x = np.repeat(x, 2, axis=0)
        elif x.shape[0] > 2:
            x = x[:2]
        return x

    spec_kwargs = dict(n_fft=n_fft, hop_length=hop_length, win_length=win_length, pool=pool)
    spec_ref = compute_spectrogram_np(_to_np_stereo(ref), **spec_kwargs)
    spec_hat = compute_spectrogram_np(_to_np_stereo(hat), **spec_kwargs)
    return render_ref_hat_panel(title, spec_ref, spec_hat, out_path, cmap=cmap)