|
|
| from typing import Optional
|
| import os
|
| import re
|
| import argparse
|
| import numpy as np
|
| import torch
|
| import torch.nn.functional as F
|
| import torchaudio
|
| import librosa
|
| import matplotlib.pyplot as plt
|
|
|
# Tiny positive offset added inside sqrt/log10 so all-zero (silent) frames
# produce finite values instead of -inf/NaN.
_EPS = 1.0e-12
|
|
|
def build_mel_transform(
    sample_rate,
    n_fft=1024,
    win_length=1024,
    hop_length=256,
    n_mels=80,
    power=1.0,
    f_min=0.0,
    f_max=None,
    mel_scale="htk",
    norm=None,
    device=None,
):
    """Create a ``torchaudio.transforms.MelSpectrogram`` module.

    Every spectral argument is forwarded unchanged to the torchaudio
    constructor; ``center`` is always set to ``True``.

    Args:
        sample_rate: Sample rate forwarded to the transform.
        n_fft: FFT size.
        win_length: Analysis window length.
        hop_length: Hop between successive frames.
        n_mels: Number of mel bands.
        power: Exponent forwarded as ``power``.
        f_min: Lower frequency bound forwarded as ``f_min``.
        f_max: Upper frequency bound forwarded as ``f_max`` (``None`` lets
            torchaudio choose its default).
        mel_scale: Mel scale name forwarded as ``mel_scale``.
        norm: Filterbank normalisation forwarded as ``norm``.
        device: Optional device; when given the module is moved there.

    Returns:
        The constructed transform, moved to ``device`` if one was supplied.
    """
    mel_kwargs = dict(
        sample_rate=sample_rate,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        power=power,
        center=True,
        norm=norm,
        mel_scale=mel_scale,
    )
    transform = torchaudio.transforms.MelSpectrogram(**mel_kwargs)
    return transform if device is None else transform.to(device)
|
|
|
|
|
def _ensure_stereo_torch(x):
    """Coerce a waveform tensor to exactly two leading channels.

    A 1-D tensor gains a leading channel axis first. A single channel is then
    duplicated, more than two channels are cut down to the first two, and a
    tensor that already has two channels is returned as-is.

    Args:
        x: Tensor whose first axis (after the optional unsqueeze) is the
            channel axis.

    Returns:
        Tensor with ``size(0) == 2`` for any input with at least one channel.
    """
    wave = x.unsqueeze(0) if x.dim() == 1 else x
    n_channels = wave.size(0)
    if n_channels == 1:
        # Mono: copy the single channel into both left and right.
        return wave.repeat(2, 1)
    if n_channels > 2:
        return wave[:2]
    return wave
|
|
|
|
|
@torch.no_grad()
def mel_cosine_stereo(
    ref, hat, sample_rate,
    n_fft=1024,
    win_length=1024,
    hop_length=256,
    n_mels=80,
    power=1.0,
    mel_tf=None,
):
    """Return the channel-averaged cosine similarity of two mel spectrograms.

    Both signals are coerced to two channels, transformed with the same mel
    transform, flattened per channel and compared with cosine similarity.
    The per-channel similarities are averaged into one Python float.

    Args:
        ref: Reference waveform tensor.
        hat: Estimated waveform tensor; moved to ``ref``'s device if needed.
        sample_rate: Sample rate used when a transform has to be built.
        n_fft: FFT size used when a transform has to be built.
        win_length: Window length used when a transform has to be built.
        hop_length: Hop length used when a transform has to be built.
        n_mels: Number of mel bands used when a transform has to be built.
        power: Spectrogram exponent used when a transform has to be built.
        mel_tf: Optional pre-built transform. When given it is moved to
            ``ref``'s device and the spectral arguments above are ignored.

    Returns:
        Mean cosine similarity over the two channels, as a ``float``.
    """
    ref = _ensure_stereo_torch(ref)
    hat = _ensure_stereo_torch(hat)

    device = ref.device
    # The transform lives on ref's device, so hat has to be there as well;
    # this is a no-op when both already share a device.
    hat = hat.to(device)

    if mel_tf is None:
        mel_tf = build_mel_transform(
            sample_rate=sample_rate,
            n_fft=n_fft, win_length=win_length, hop_length=hop_length,
            n_mels=n_mels, power=power, device=device
        )
    else:
        mel_tf = mel_tf.to(device)

    mel_ref = mel_tf(ref)
    mel_hat = mel_tf(hat)

    # Signals of different length give different frame counts, which would
    # make the flattened rows incompatible. Compare only the overlapping
    # frames (assumes frames are the last axis, as MelSpectrogram produces).
    n_frames = min(mel_ref.size(-1), mel_hat.size(-1))
    mel_ref = mel_ref[..., :n_frames]
    mel_hat = mel_hat[..., :n_frames]

    # One row per channel: (channels, n_mels * frames).
    flat_ref = mel_ref.reshape(mel_ref.size(0), -1)
    flat_hat = mel_hat.reshape(mel_hat.size(0), -1)

    sim = F.cosine_similarity(flat_ref, flat_hat, dim=-1)
    return float(sim.mean().item())
|
|
|
|
|
@torch.no_grad()
def drms_avg_db_stereo(ref, hat, win_length=1024, hop_length=256):
    """Return the mean frame-wise RMS level difference (hat - ref) in dB.

    Each signal is coerced to two channels and cut into overlapping windows.
    The RMS of every window is converted to decibels, the two level tracks
    are trimmed to their common number of frames, and the difference is
    averaged over frames and then over channels.

    Args:
        ref: Reference waveform tensor.
        hat: Estimated waveform tensor.
        win_length: Window size in samples.
        hop_length: Step between windows in samples.

    Returns:
        Average level difference as a ``float``; positive when ``hat`` is
        louder than ``ref``.
    """

    def _frame_level_db(signal):
        # Signals shorter than one window are zero-padded on the right so
        # that at least one frame exists.
        n_samples = signal.size(1)
        if n_samples < win_length:
            signal = F.pad(signal, (0, win_length - n_samples))
        windows = signal.unfold(dimension=-1, size=win_length, step=hop_length)
        mean_square = windows.pow(2).mean(dim=-1)
        rms = torch.sqrt(mean_square + _EPS)
        return 20.0 * torch.log10(rms + _EPS)

    stereo_ref = _ensure_stereo_torch(ref)
    stereo_hat = _ensure_stereo_torch(hat)

    level_ref = _frame_level_db(stereo_ref)
    level_hat = _frame_level_db(stereo_hat)

    # Inputs may differ in length; compare only the frames both provide.
    common = min(level_ref.size(-1), level_hat.size(-1))
    delta = level_hat[:, :common] - level_ref[:, :common]
    return float(delta.mean(dim=-1).mean().item())
|
|
|
|
|
def load_stereo_wav_np(path):
    """Load an audio file as a two-channel NumPy array.

    The file is read with ``librosa.load(path, sr=None, mono=False)``, i.e.
    without asking librosa to resample or downmix.

    Args:
        path: Path of the audio file to read.

    Returns:
        Tuple ``(y, sr)`` where ``y`` has two rows (channels) whenever the
        file holds at least one channel, and ``sr`` is the sample rate
        reported by librosa.
    """
    y, sr = librosa.load(path, sr=None, mono=False)
    if y.ndim == 1:
        # Mono returned as a flat vector: use it for both channels.
        y = np.stack([y, y], axis=0)
    elif y.shape[0] == 1:
        # Mono returned with an explicit channel axis: duplicate that row
        # instead of leaving a single-channel array.
        y = np.repeat(y, 2, axis=0)
    elif y.shape[0] > 2:
        # Multichannel: keep only the first two channels.
        y = y[:2]
    return y, sr
|
|
|
|
|
def compute_spectrogram_np(audio_stereo,
                           n_fft=512,
                           hop_length=160,
                           win_length=400,
                           pool=4):
    """Compute a pooled log-magnitude STFT for the left and right channels.

    For each channel the STFT magnitude is average-pooled over
    ``pool x pool`` blocks (rows and columns that do not fill a whole block
    are dropped) and then passed through ``log1p``.

    Args:
        audio_stereo: Array indexed as ``[channel, sample]``. When it holds
            fewer than two channels the left result is reused for the right.
        n_fft: FFT size passed to ``librosa.stft``.
        hop_length: Hop length passed to ``librosa.stft``.
        win_length: Window length passed to ``librosa.stft``.
        pool: Side length of the averaging block.

    Returns:
        Array with left and right spectrograms stacked on the last axis.

    Raises:
        ValueError: If the STFT has fewer than ``pool`` rows or columns.
    """

    def _pooled_log_magnitude(channel):
        st = np.abs(
            librosa.stft(
                channel,
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=win_length,
            )
        )
        n_rows, n_cols = (extent // pool for extent in st.shape)
        if 0 in (n_rows, n_cols):
            raise ValueError(f"audio too short for pooling (stft shape {st.shape})")
        # Crop to a multiple of the block size, then average every block.
        cropped = st[:n_rows * pool, :n_cols * pool]
        blocks = cropped.reshape(n_rows, pool, n_cols, pool)
        return np.log1p(blocks.mean(axis=(1, 3)))

    left = _pooled_log_magnitude(audio_stereo[0])
    has_right = audio_stereo.shape[0] >= 2
    right = _pooled_log_magnitude(audio_stereo[1]) if has_right else left.copy()
    return np.stack([left, right], axis=-1)
|
|
|
|
|
def render_ref_hat_panel(title, spec_ref, spec_hat, out_path, cmap="magma"):
    """Save a 2x2 image grid comparing reference and estimate spectrograms.

    Columns are ``ref`` and ``hat``; rows are the left (index 0) and right
    (index 1) slices of the last axis. Each row shares one colour range
    taken from both of its images, so ref and hat are directly comparable.

    Args:
        title: Figure title.
        spec_ref: Reference array indexed as ``[:, :, channel]``.
        spec_hat: Estimate array indexed as ``[:, :, channel]``.
        out_path: Destination image path; its directory is created if needed.
        cmap: Matplotlib colormap name.

    Returns:
        ``True`` when the image was written, ``False`` when any of the four
        panels is empty (in which case ``[SKIP]`` is printed).
    """
    left_ref, right_ref = spec_ref[:, :, 0], spec_ref[:, :, 1]
    left_hat, right_hat = spec_hat[:, :, 0], spec_hat[:, :, 1]

    if any(a.size == 0 for a in (left_ref, left_hat, right_ref, right_hat)):
        print("[SKIP]")
        return False

    vmin_left = min(left_ref.min(), left_hat.min())
    vmax_left = max(left_ref.max(), left_hat.max())
    vmin_right = min(right_ref.min(), right_hat.min())
    vmax_right = max(right_ref.max(), right_hat.max())

    fig, axes = plt.subplots(2, 2, figsize=(8, 6), constrained_layout=True)
    try:
        panels = (
            (axes[0, 0], left_ref, vmin_left, vmax_left),
            (axes[0, 1], left_hat, vmin_left, vmax_left),
            (axes[1, 0], right_ref, vmin_right, vmax_right),
            (axes[1, 1], right_hat, vmin_right, vmax_right),
        )
        for ax, image, lo, hi in panels:
            ax.imshow(image, origin="lower", aspect="auto", cmap=cmap, vmin=lo, vmax=hi)

        axes[0, 0].set_title("ref")
        axes[0, 1].set_title("hat")
        axes[0, 0].set_ylabel("Left")
        axes[1, 0].set_ylabel("Right")

        for ax in axes.ravel():
            ax.set_xticks([])
            ax.set_yticks([])

        fig.suptitle(title)
        os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
        # Save this figure explicitly rather than whatever pyplot currently
        # considers the active one.
        fig.savefig(out_path, dpi=180)
    finally:
        # Always release the figure, even when directory creation or saving
        # fails, so repeated calls do not accumulate open figures.
        plt.close(fig)
    return True
|
|
|
|
|
def save_ref_hat_spectrogram_panel(
    ref, hat, out_path,
    n_fft=512,
    hop_length=160,
    win_length=400,
    pool=4,
    title="ref vs hat (binaural spectrogram)",
    cmap="magma",
):
    """Compute spectrograms for ``ref`` and ``hat`` and save them as a panel.

    Each input is converted to a two-channel NumPy array, turned into a
    spectrogram with ``compute_spectrogram_np`` and drawn with
    ``render_ref_hat_panel``.

    Args:
        ref: Reference waveform; a torch tensor or an array-like with
            ``ndim``/``shape``.
        hat: Estimated waveform, same accepted types as ``ref``.
        out_path: Destination image path.
        n_fft: FFT size forwarded to ``compute_spectrogram_np``.
        hop_length: Hop length forwarded to ``compute_spectrogram_np``.
        win_length: Window length forwarded to ``compute_spectrogram_np``.
        pool: Pooling factor forwarded to ``compute_spectrogram_np``.
        title: Figure title.
        cmap: Matplotlib colormap name.

    Returns:
        Whatever ``render_ref_hat_panel`` returns.
    """

    def _as_two_channel_array(signal):
        # Tensors are detached, cast to float32 and moved to host memory
        # before becoming NumPy arrays.
        if isinstance(signal, torch.Tensor):
            signal = signal.detach().to(torch.float32).cpu().numpy()
        if signal.ndim == 1:
            return np.stack([signal, signal], axis=0)
        n_channels = signal.shape[0]
        if n_channels == 1:
            return np.repeat(signal, 2, axis=0)
        if n_channels > 2:
            return signal[:2]
        return signal

    stft_options = dict(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        pool=pool,
    )

    # Convert both inputs first, then compute the spectrograms in ref, hat order.
    waveforms = [_as_two_channel_array(signal) for signal in (ref, hat)]
    spec_ref, spec_hat = (
        compute_spectrogram_np(waveform, **stft_options) for waveform in waveforms
    )
    return render_ref_hat_panel(title, spec_ref, spec_hat, out_path, cmap=cmap)
|
|
|