|
|
|
|
|
|
| import os
|
| import json
|
| import argparse
|
| from pathlib import Path
|
| from tqdm import tqdm
|
|
|
| import torch
|
| import torch.distributed as dist_torch
|
| import torch.nn.functional as F
|
| import numpy as np
|
| from PIL import Image
|
| import lpips
|
| from dreamsim import dreamsim
|
| from torchvision import transforms
|
| from torcheval.metrics import FrechetInceptionDistance
|
| import soundfile as sf
|
| import resampy
|
| import distributed as dist
|
| import librosa
|
| from skimage.metrics import structural_similarity as sk_ssim
|
| from mel_scale import MelScale
|
|
|
|
|
|
|
|
|
def safe_import_fad():
    """Return the ``FrechetAudioDistance`` class from ``frechet_audio_distance``.

    While the module is being imported, ``sys.argv`` is cut down to just the
    program name. Anything the library parses at import time therefore never
    sees this script's CLI arguments. The caller's ``sys.argv`` is put back
    whether the import succeeds or raises.
    """
    import importlib
    import sys

    saved_argv = sys.argv[:]
    try:
        sys.argv = [saved_argv[0]]
        module = importlib.import_module("frechet_audio_distance")
        return module.FrechetAudioDistance
    finally:
        sys.argv = saved_argv
|
|
|
|
|
|
|
|
|
|
|
def setup_distributed():
    """Initialise ``torch.distributed`` (NCCL) from launcher environment variables.

    If ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` are all set, this does the
    following, in order:

    - defaults ``MASTER_ADDR``/``MASTER_PORT`` to ``127.0.0.1``/``29500``;
    - pins the process to CUDA device ``LOCAL_RANK``;
    - initialises the default process group with ``init_method="env://"``;
    - waits on a barrier.

    If any of the three variables is missing, nothing is initialised.

    Returns:
        ``(rank, world_size, local_rank)``; ``(0, 1, 0)`` for a
        non-distributed run.

    Raises:
        RuntimeError: if CUDA is unavailable or ``LOCAL_RANK`` is not a valid
            device index. These are explicit raises rather than ``assert``,
            which ``python -O`` would strip.
    """
    env_keys = ("RANK", "WORLD_SIZE", "LOCAL_RANK")
    if not all(key in os.environ for key in env_keys):
        return 0, 1, 0

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA Unavailable")
    if torch.cuda.device_count() <= local_rank:
        raise RuntimeError("local_rank out of the number of GPUs")
    torch.cuda.set_device(local_rank)

    dist_torch.init_process_group(
        backend="nccl",
        init_method="env://",
        rank=rank,
        world_size=world_size,
    )
    dist_torch.barrier()

    if rank == 0:
        print(f"[init] world_size={world_size} | rank->gpu OK")

    return rank, world_size, local_rank
|
|
|
|
|
|
|
|
|
|
|
def get_loss_fn(loss_fn_type, secs, device):
    """Build the image metric selected by ``loss_fn_type``.

    Args:
        loss_fn_type: ``'lpips'`` (AlexNet LPIPS), ``'dreamsim'`` or ``'fid'``.
        secs: Keys for the per-second FID metrics. Only used for ``'fid'``.
        device: Device the metric models / metric states are moved to.

    Returns:
        For ``'lpips'`` and ``'dreamsim'``, a callable
        ``loss_fn(img0_paths, img1_paths)``. It loads each path pair, runs the
        model once on the concatenated batch under ``torch.no_grad()`` and
        returns the mean distance as a 0-dim tensor. Both path lists must be
        non-empty, because ``torch.cat`` rejects an empty list.
        For ``'fid'``, a dict mapping each entry of ``secs`` to its own
        ``FrechetInceptionDistance(feature_dim=2048)`` on ``device``.

    Raises:
        NotImplementedError: for any other ``loss_fn_type``.
    """
    if loss_fn_type == 'lpips':
        lpips_model = lpips.LPIPS(net='alex').to(device).eval()
        model = lpips_model.forward

        def load(path):
            return lpips.im2tensor(lpips.load_image(path)).to(device)

    elif loss_fn_type == 'dreamsim':
        dreamsim_model, preprocess = dreamsim(pretrained=True, device=device)
        dreamsim_model.eval()
        model = dreamsim_model

        def load(path):
            # Context manager closes the file handle as soon as the tensor is
            # built, instead of leaving it to garbage collection.
            with Image.open(path) as img:
                return preprocess(img).to(device)

    elif loss_fn_type == 'fid':
        return {
            sec: FrechetInceptionDistance(feature_dim=2048).to(device)
            for sec in secs
        }

    else:
        raise NotImplementedError(f"Unsupported loss_fn_type: {loss_fn_type!r}")

    def loss_fn(img0_paths, img1_paths):
        img0_list, img1_list = [], []
        for p0, p1 in zip(img0_paths, img1_paths):
            img0_list.append(load(p0))
            img1_list.append(load(p1))
        all_img0 = torch.cat(img0_list, dim=0)
        all_img1 = torch.cat(img1_list, dim=0)
        with torch.no_grad():
            dist_val = model(all_img0, all_img1)
        return dist_val.mean()

    return loss_fn
|
|
|
|
|
|
|
# Small positive constant used by the spectral metrics below. It is added to
# denominators and to log10 arguments so that zero-valued bins do not produce
# division by zero or log10(0).
_EPS: float = 1e-12
|
|
|
| def _ensure_stereo_np(y: np.ndarray):
|
| if y.ndim == 1:
|
| y = np.stack([y, y], axis=0)
|
| elif y.ndim == 2:
|
| if y.shape[0] == 1:
|
| y = np.concatenate([y, y], axis=0)
|
| elif y.shape[0] > 2:
|
| y = y[:2, :]
|
| else:
|
| raise ValueError("Unsupported audio array shape")
|
| return y
|
|
|
| def _wav_to_spectrogram(wav: np.ndarray, rate: int):
|
| if rate == 44100:
|
| hop_length = 441
|
| n_fft = 2048
|
| elif rate == 16000:
|
| hop_length = 160
|
| n_fft = 743
|
| else:
|
| raise ValueError("Bad Samplerate (expected 16000 or 44100)")
|
|
|
| f = np.abs(librosa.stft(wav, hop_length=hop_length, n_fft=n_fft))
|
| f = np.transpose(f, (1, 0))
|
| f_torch = torch.tensor(f[None, None, ...], dtype=torch.float32)
|
| return f_torch
|
|
|
def _lsd_from_specs(est: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Log-spectral distance between two 4-D magnitude spectrograms.

    Computation:

    - For each element, take the squared log10 of the power ratio
      ``target**2 / (est + _EPS)**2 + _EPS``.
    - Average over dim 3, then take the square root (one RMS value per
      index of dims 0-2).
    - Average over dim 2.
    - Finally, average over everything that remains.

    Returns a 0-dim tensor.
    """
    power_ratio = target.pow(2) / (est + _EPS).pow(2) + _EPS
    squared_log = torch.log10(power_ratio).pow(2)
    rms_per_frame = squared_log.mean(dim=3).pow(0.5)
    return rms_per_frame.mean(dim=2).mean()
|
|
|
def _mel_lsd_ssim_single(
    e_wav: np.ndarray,
    g_wav: np.ndarray,
    mel_tf: MelScale,
    n_fft: int = 743,
    hop_length: int = 160,
) -> tuple[float, float]:
    """Mel-domain LSD and SSIM of an estimated waveform against a reference.

    Each waveform goes through the same pipeline: STFT magnitude (``n_fft`` /
    ``hop_length``), then ``mel_tf``, then its two axes are swapped and two
    leading singleton axes are added. The resulting 4-D tensors go to
    ``_lsd_from_specs`` and ``_ssim_from_specs`` with the estimate first.

    Returns:
        ``(mel_lsd, mel_ssim)`` as Python floats.
    """
    def _mel_as_image(signal: np.ndarray) -> torch.Tensor:
        magnitude = np.abs(librosa.stft(signal, n_fft=n_fft, hop_length=hop_length))
        mel = mel_tf(torch.from_numpy(magnitude).float())
        # Assumes mel_tf returns a 2-D tensor (transpose(0, 1) needs >= 2 dims).
        return mel.transpose(0, 1)[None, None]

    est_img = _mel_as_image(e_wav)
    ref_img = _mel_as_image(g_wav)
    lsd_value = float(_lsd_from_specs(est_img, ref_img))
    ssim_value = float(_ssim_from_specs(est_img, ref_img))
    return lsd_value, ssim_value
|
|
|
def _to_log_specs(x: torch.Tensor) -> torch.Tensor:
    """Element-wise ``log10(x + _EPS)``; the offset keeps zero bins finite."""
    shifted = x + _EPS
    return shifted.log10()
|
|
|
| def _pow_p_norm(x: torch.Tensor) -> torch.Tensor:
|
| return torch.mean(x.pow(2), dim=(2, 3))
|
|
|
def _energy_unify(est: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Rescale ``est`` so its mean power over dims (2, 3) matches ``target``'s.

    The gain ``sqrt((P_target + _EPS) / (P_est + _EPS))`` is computed for each
    index of the leading two dims. ``target`` is returned unchanged (the same
    object).
    """
    gain = ((_pow_p_norm(target) + _EPS) / (_pow_p_norm(est) + _EPS)).sqrt()
    return est * gain[..., None, None], target
|
|
|
def _sispec_from_specs(est: torch.Tensor, target: torch.Tensor, log_domain: bool) -> torch.Tensor:
    """Scale-invariant spectrogram SNR in dB, averaged to a 0-dim tensor.

    Steps:

    - With ``log_domain``, both inputs are first mapped through
      ``log10(x + _EPS)``.
    - ``est`` is rescaled to the target's mean power.
    - The SNR is ``10*log10(P_target / (P_residual + _EPS) + _EPS)``, where
      ``P`` is the mean square over dims (2, 3).
    """
    if log_domain:
        est, target = _to_log_specs(est), _to_log_specs(target)
    est_scaled, reference = _energy_unify(est, target)
    residual = est_scaled - reference
    snr = _pow_p_norm(reference) / (_pow_p_norm(residual) + _EPS) + _EPS
    return (10.0 * torch.log10(snr)).mean()
|
|
|
|
|
|
|
| def _psnr_from_tensors(gt: torch.Tensor, pred: torch.Tensor, data_range: float = 1.0, eps: float = 1e-10) -> torch.Tensor:
|
| mse = torch.mean((gt - pred) ** 2, dim=(1, 2, 3))
|
| dr = torch.as_tensor(data_range, device=gt.device, dtype=gt.dtype)
|
| psnr = 10.0 * torch.log10((dr * dr) / (mse + eps))
|
| return psnr
|
|
|
| def _ssim_from_specs(est: torch.Tensor, target: torch.Tensor) -> float:
|
| if est.is_cuda:
|
| est_np = est.detach().cpu().numpy()
|
| tgt_np = target.detach().cpu().numpy()
|
| else:
|
| est_np = est.numpy()
|
| tgt_np = target.numpy()
|
|
|
| N, C, _, _ = est_np.shape
|
| acc, cnt = 0.0, 0
|
| for n in range(N):
|
| for c in range(C):
|
| ref = tgt_np[n, c, ...]
|
| out = est_np[n, c, ...]
|
| rng = float(out.max() - out.min())
|
| rng = 1.0 if rng == 0.0 else rng
|
| s = sk_ssim(out, ref, win_size=7, data_range=rng)
|
| acc += float(s); cnt += 1
|
| return acc / max(cnt, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| class _RunningGaussianStats:
|
| def __init__(self, feat_dim: int, device: torch.device):
|
| self.D = feat_dim
|
| self.device = device
|
| self.reset()
|
|
|
| def reset(self):
|
| D = self.D
|
| self.count = torch.zeros(1, device=self.device, dtype=torch.float64)
|
| self.sum_feat = torch.zeros(D, device=self.device, dtype=torch.float64)
|
| self.sum_outer = torch.zeros(D, D, device=self.device, dtype=torch.float64)
|
|
|
| @torch.no_grad()
|
| def update(self, feats: torch.Tensor):
|
| if feats is None or feats.numel() == 0:
|
| return
|
| f = feats.to(dtype=torch.float64)
|
| self.count += torch.tensor([f.shape[0]], device=self.device, dtype=torch.float64)
|
| self.sum_feat += f.sum(dim=0)
|
| self.sum_outer += f.t().mm(f)
|
|
|
| @torch.no_grad()
|
| def sync(self):
|
| if dist_torch.is_initialized():
|
| for t in (self.count, self.sum_feat, self.sum_outer):
|
| dist_torch.all_reduce(t, op=dist_torch.ReduceOp.SUM)
|
|
|
| @torch.no_grad()
|
| def mean_cov(self, eps: float = 1e-6):
|
| n = int(self.count.item())
|
| if n == 0:
|
| return None, None
|
| mean = self.sum_feat / self.count
|
| cov = self.sum_outer / self.count - torch.ger(mean, mean)
|
| cov = cov + torch.eye(self.D, device=self.device, dtype=torch.float64) * eps
|
| return mean, cov
|
|
|
|
|
| @torch.no_grad()
|
| def _frechet_distance_torch(mean1, cov1, mean2, cov2) -> float:
|
| diff = mean1 - mean2
|
| diff2 = diff.dot(diff)
|
| evals1, evecs1 = torch.linalg.eigh(cov1)
|
| sqrt1 = evecs1 @ torch.diag(evals1.clamp(min=0).sqrt()) @ evecs1.t()
|
| prod = sqrt1 @ cov2 @ sqrt1
|
| evals_prod = torch.linalg.eigvalsh(prod).clamp(min=0).sqrt()
|
| trace = torch.trace(cov1 + cov2) - 2.0 * evals_prod.sum()
|
| return float((diff2 + trace).item())
|
|
|
|
|
class StreamingFAD:
    """
    Mono (downmix) FID-style streaming FAD:

    - update_from_wavs(paths, is_real=True/False)

    - compute()  # sums the statistics across ranks when torch.distributed is initialised

    Each file goes through these steps:

    - read with soundfile and averaged over channels;
    - resampled to ``fad_backend.sample_rate``;
    - zero-padded to at least ``pad_seconds``;
    - embedded with ``fad_backend.get_embeddings``.

    The embeddings are folded into running Gaussian statistics, one set for
    the real audio and one for the generated audio.
    """

    def __init__(self, fad_backend, pad_seconds: float = 0.96, batch_size: int = 16):
        self.fad = fad_backend
        self.device = self.fad.device
        self.bs = batch_size
        # Minimum clip length, in samples at the backend's sample rate.
        self.pad_len = int(round(self.fad.sample_rate * float(pad_seconds)))
        self.feat_dim = self._infer_feat_dim()
        self.real_stats = _RunningGaussianStats(self.feat_dim, self.device)
        self.fake_stats = _RunningGaussianStats(self.feat_dim, self.device)

    def _infer_feat_dim(self) -> int:
        """Embed one silent clip of ``pad_len`` samples and return the embedding width.

        ``.shape[-1]`` reads the same way for numpy arrays and torch tensors,
        so no type dispatch is needed.
        """
        silence = np.zeros((self.pad_len,), dtype=np.float32)
        emb = self.fad.get_embeddings([silence], sr=self.fad.sample_rate)
        return int(emb.shape[-1])

    @torch.no_grad()
    def _load_and_resample(self, path: str):
        """Load ``path`` as mono float32 at the backend rate, padded to ``pad_len``.

        Returns ``None`` if the file cannot be read, is empty or cannot be
        resampled. Read and resample errors are printed.
        """
        try:
            audio, sr = sf.read(path, dtype="float32", always_2d=False)
        except Exception as e:
            print(f"[StreamingFAD] read error: {path}: {e}")
            return None
        if audio is None or (isinstance(audio, np.ndarray) and audio.size == 0):
            return None
        if isinstance(audio, np.ndarray) and audio.ndim == 2:
            # Multi-channel soundfile data is (frames, channels); average to mono.
            audio = audio.mean(axis=1)
        if sr != self.fad.sample_rate:
            try:
                audio = resampy.resample(audio, sr, self.fad.sample_rate)
            except Exception as e:
                print(f"[StreamingFAD] resample error: {path}: {e}")
                return None
        if audio.shape[0] < self.pad_len:
            pad = np.zeros((self.pad_len - audio.shape[0],), dtype=np.float32)
            audio = np.concatenate([audio, pad], axis=0)
        return audio.astype(np.float32, copy=False)

    @torch.no_grad()
    def _embed(self, clips):
        """Embed ``clips`` in chunks of ``self.bs`` and concatenate them on ``self.device``.

        Empty chunks are skipped. Returns ``None`` when every chunk is empty.
        """
        feats_chunks = []
        for start in range(0, len(clips), self.bs):
            emb = self.fad.get_embeddings(clips[start:start + self.bs], sr=self.fad.sample_rate)
            if isinstance(emb, np.ndarray):
                if emb.size == 0:
                    continue
                emb = torch.from_numpy(emb)
            elif emb.numel() == 0:
                continue
            feats_chunks.append(emb.to(self.device))
        if not feats_chunks:
            return None
        return torch.cat(feats_chunks, dim=0)

    @torch.no_grad()
    def update_from_wavs(self, wav_paths, is_real: bool):
        """Load, embed and accumulate ``wav_paths`` into the real or fake statistics.

        Files that fail to load are skipped. An empty path list is a no-op.
        """
        if not wav_paths:
            return
        clips = [a for a in (self._load_and_resample(p) for p in wav_paths) if a is not None]
        if not clips:
            return
        feats = self._embed(clips)
        if feats is None:
            return
        (self.real_stats if is_real else self.fake_stats).update(feats)

    @torch.no_grad()
    def compute(self) -> float:
        """Return the Fréchet distance between the real and fake statistics.

        The statistics are first synced (summed) across ranks in place.

        Raises:
            RuntimeError: if either set has no samples.
        """
        self.real_stats.sync()
        self.fake_stats.sync()
        m1, c1 = self.real_stats.mean_cov()
        m2, c2 = self.fake_stats.mean_cov()
        if (m1 is None) or (m2 is None):
            raise RuntimeError("StreamingFAD: empty stats")
        return _frechet_distance_torch(m1, c1, m2, c2)
|
|
|
|
|
class StereoStreamingFAD:
    """Per-channel streaming FAD for stereo audio.

    The left and right channels are embedded separately and kept in separate
    running statistics for real and generated audio.

    - Mono files feed the same signal to both channels.
    - Channels beyond the second are ignored.
    - ``update_from_wavs(paths, is_real=True/False)`` accumulates embeddings.
    - ``compute()`` returns ``(fad_left, fad_right, fad_mean)``. It first sums
      the statistics across ranks when torch.distributed is initialised.
    """

    def __init__(self, fad_backend, pad_seconds: float = 0.96, batch_size: int = 16):
        self.fad = fad_backend
        self.device = self.fad.device
        self.bs = batch_size
        # Minimum clip length, in samples at the backend's sample rate.
        self.pad_len = int(round(self.fad.sample_rate * float(pad_seconds)))

        self.feat_dim = self._infer_feat_dim()
        self.L_real = _RunningGaussianStats(self.feat_dim, self.device)
        self.L_fake = _RunningGaussianStats(self.feat_dim, self.device)
        self.R_real = _RunningGaussianStats(self.feat_dim, self.device)
        self.R_fake = _RunningGaussianStats(self.feat_dim, self.device)

    def _infer_feat_dim(self) -> int:
        """Embed one silent clip of ``pad_len`` samples and return the embedding width."""
        silence = np.zeros((self.pad_len,), dtype=np.float32)
        emb = self.fad.get_embeddings([silence], sr=self.fad.sample_rate)
        # .shape[-1] is valid for both numpy arrays and torch tensors.
        return int(emb.shape[-1])

    @staticmethod
    def _pad_to_len(x: np.ndarray, n: int) -> np.ndarray:
        """Zero-pad a 1-D signal at the end to at least ``n`` samples; always float32."""
        if x.shape[0] >= n:
            return x.astype(np.float32, copy=False)
        pad = np.zeros((n - x.shape[0],), dtype=np.float32)
        return np.concatenate([x, pad], axis=0).astype(np.float32, copy=False)

    @torch.no_grad()
    def _load_lr_and_resample_pad(self, path: str):
        """Return ``(left, right)`` float32 signals at the backend rate, padded to ``pad_len``.

        Returns ``(None, None)`` if the file cannot be read, is empty or cannot
        be resampled. Read and resample errors are printed.
        """
        try:
            audio, sr = sf.read(path, dtype="float32", always_2d=True)
        except Exception as e:
            print(f"[StereoFAD] read error: {path}: {e}")
            return None, None
        if audio is None or audio.size == 0:
            return None, None

        # always_2d=True gives (frames, channels). Right is channel 1 when it
        # exists; a mono file reuses channel 0 for both sides.
        L = audio[:, 0]
        R = audio[:, 1] if audio.shape[1] >= 2 else audio[:, 0]

        if sr != self.fad.sample_rate:
            try:
                L = resampy.resample(L, sr, self.fad.sample_rate)
                R = resampy.resample(R, sr, self.fad.sample_rate)
            except Exception as e:
                print(f"[StereoFAD] resample error: {path}: {e}")
                return None, None

        return self._pad_to_len(L, self.pad_len), self._pad_to_len(R, self.pad_len)

    @torch.no_grad()
    def _embed_and_update(self, xs, stats_obj: "_RunningGaussianStats"):
        """Embed ``xs`` in chunks of ``self.bs`` and fold the features into ``stats_obj``.

        Empty chunks are skipped. ``stats_obj`` is left untouched if every
        chunk is empty.
        """
        feats_chunks = []
        for start in range(0, len(xs), self.bs):
            emb = self.fad.get_embeddings(xs[start:start + self.bs], sr=self.fad.sample_rate)
            if isinstance(emb, np.ndarray):
                if emb.size == 0:
                    continue
                emb = torch.from_numpy(emb)
            elif emb.numel() == 0:
                continue
            feats_chunks.append(emb.to(self.device))
        if feats_chunks:
            stats_obj.update(torch.cat(feats_chunks, dim=0))

    @torch.no_grad()
    def update_from_wavs(self, wav_paths, is_real: bool):
        """Load ``wav_paths``, split the channels and accumulate each side's embeddings.

        Files that fail to load are skipped. An empty path list is a no-op.
        """
        if not wav_paths:
            return
        L_list, R_list = [], []
        for p in wav_paths:
            L, R = self._load_lr_and_resample_pad(p)
            if L is not None and R is not None:
                L_list.append(L)
                R_list.append(R)
        if not L_list:
            return

        if is_real:
            left_stats, right_stats = self.L_real, self.R_real
        else:
            left_stats, right_stats = self.L_fake, self.R_fake
        self._embed_and_update(L_list, left_stats)
        self._embed_and_update(R_list, right_stats)

    @torch.no_grad()
    def compute(self):
        """Return ``(fad_left, fad_right, fad_mean)`` as floats.

        All four statistics are first synced (summed) across ranks in place.

        Raises:
            RuntimeError: if any of the four statistics is empty.
        """
        for t in (self.L_real, self.L_fake, self.R_real, self.R_fake):
            t.sync()
        mL_r, cL_r = self.L_real.mean_cov()
        mL_f, cL_f = self.L_fake.mean_cov()
        mR_r, cR_r = self.R_real.mean_cov()
        mR_f, cR_f = self.R_fake.mean_cov()
        if (mL_r is None) or (mL_f is None) or (mR_r is None) or (mR_f is None):
            raise RuntimeError("StereoStreamingFAD: empty stats")

        fad_left = _frechet_distance_torch(mL_r, cL_r, mL_f, cL_f)
        fad_right = _frechet_distance_torch(mR_r, cR_r, mR_f, cR_f)
        fad_mean = 0.5 * (fad_left + fad_right)
        return float(fad_left), float(fad_right), float(fad_mean)
|
|
|
|
|
|
|
|
|
|
|
def _load_librosa_stereo(path: str, sr: int) -> np.ndarray:
    """Load ``path`` with librosa at ``sr`` Hz (``mono=False``) and force two channels.

    ``_ensure_stereo_np`` duplicates mono input and truncates input with more
    than two channels.
    """
    audio, _loaded_sr = librosa.load(path, sr=sr, mono=False)
    return _ensure_stereo_np(audio)
|
|
|
def _mel_cosine_single_channel(wav: np.ndarray, ref: np.ndarray, sr: int, mel_tf: MelScale) -> float:
    """Cosine similarity between the flattened mel spectrograms of ``wav`` and ``ref``.

    Both signals get an STFT magnitude with fixed ``n_fft=743`` and
    ``hop_length=160``, followed by ``mel_tf``.
    NOTE(review): ``sr`` is accepted but unused, and the STFT parameters are
    not adapted to it.
    """
    def _mel(signal: np.ndarray) -> torch.Tensor:
        magnitude = np.abs(librosa.stft(signal, hop_length=160, n_fft=743))
        return mel_tf(torch.tensor(magnitude, dtype=torch.float32))

    est_mel = _mel(wav)
    ref_mel = _mel(ref)
    similarity = F.cosine_similarity(est_mel.flatten(), ref_mel.flatten(), dim=0)
    return float(similarity.item())
|
|
|
|
|
|
|
|
|
| def evaluate(args, dataset_name, eval_type, metric_logger, loss_fns,
|
| gt_dir, exp_dir, secs, device, rank, world_size, modals):
|
|
|
| lpips_loss_fn, dreamsim_loss_fn, fid_loss_fn = loss_fns
|
|
|
| if eval_type == 'rollout':
|
| eval_name = 'rollout'
|
| image_idxs = secs.copy()
|
| elif eval_type == 'time':
|
| eval_name = eval_type
|
| image_idxs = secs.copy()
|
| else:
|
| raise ValueError(f"Unknown eval_type {eval_type}")
|
|
|
| if 'v' in modals:
|
| for s in secs:
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_fid_{int(s)}'].update(0.0, n=0)
|
|
|
|
|
| all_eps = sorted([e for e in os.listdir(gt_dir) if os.path.isdir(os.path.join(gt_dir, e))])
|
| eps = all_eps[rank::world_size]
|
| if len(eps) == 0:
|
| return
|
|
|
| to_tensor = transforms.ToTensor()
|
|
|
| fad_streams = {}
|
| stereo_mode = False
|
| if 'a' in modals:
|
| try:
|
| FADLib = safe_import_fad()
|
| except Exception as e:
|
| if rank == 0:
|
| print(f"[WARN] Fail to import frechet_audio_distance:{e}")
|
| FADLib = None
|
|
|
| if FADLib is not None:
|
| base_fad = FADLib(
|
| model_name=args.fad_model,
|
| sample_rate=args.fad_sr,
|
| verbose=False
|
| )
|
| if args.fad_model == 'vggish' and not args.mono:
|
| stereo_mode = True
|
| for sec in secs:
|
| fad_streams[sec] = StereoStreamingFAD(base_fad, pad_seconds=args.fad_pad_sec, batch_size=16)
|
| else:
|
| for sec in secs:
|
| fad_streams[sec] = StreamingFAD(base_fad, pad_seconds=args.fad_pad_sec, batch_size=16)
|
|
|
| mel_tf = MelScale(n_mels=80, sample_rate=16000, n_stft=372)
|
|
|
| for batch_start in tqdm(range(0, len(eps), args.batch_size),
|
| total=(len(eps) + args.batch_size - 1) // args.batch_size,
|
| disable=(rank != 0)):
|
| batch_eps = eps[batch_start:batch_start + args.batch_size]
|
|
|
|
|
| gt_img_batch, exp_img_batch = {}, {}
|
| gt_img_paths_batch, exp_img_paths_batch = {}, {}
|
| denorm_pairs_by_sec = {}
|
| secs_py = [int(s) for s in secs]
|
| denorm_pairs_by_sec = {s: [] for s in secs_py}
|
| for sec in secs:
|
| gt_img_batch[sec], exp_img_batch[sec] = [], []
|
| gt_img_paths_batch[sec], exp_img_paths_batch[sec] = [], []
|
|
|
|
|
| gt_wav_paths_batch, exp_wav_paths_batch = {}, {}
|
| for sec in secs:
|
| gt_wav_paths_batch[sec], exp_wav_paths_batch[sec] = [], []
|
|
|
| for ep in batch_eps:
|
| gt_ep_dir = os.path.join(gt_dir, ep)
|
| exp_ep_dir = os.path.join(exp_dir, ep)
|
|
|
| if (not os.path.isdir(gt_ep_dir)) or (not os.path.isdir(exp_ep_dir)):
|
| continue
|
|
|
| gt_dist_p = os.path.join(gt_ep_dir, "distance.json")
|
| exp_dist_p = os.path.join(exp_ep_dir, "distance.json")
|
| try:
|
| if os.path.isfile(gt_dist_p) and os.path.isfile(exp_dist_p):
|
| with open(gt_dist_p, "r") as f: gt_list = json.load(f)
|
| with open(exp_dist_p, "r") as f: exp_list = json.load(f)
|
| gt_map = {int(it["sec"]): float(it["denorm_gt"]) for it in gt_list if "sec" in it and "denorm_gt" in it}
|
| exp_map = {int(it["sec"]): float(it["denorm_pred"]) for it in exp_list if "sec" in it and "denorm_pred" in it}
|
| for s in secs_py:
|
| if s in gt_map and s in exp_map:
|
| denorm_pairs_by_sec[s].append((gt_map[s], exp_map[s]))
|
| except Exception:
|
| pass
|
|
|
|
|
| for sec, image_idx in zip(secs, image_idxs):
|
|
|
| if 'v' in modals:
|
| gt_sec_img_path = os.path.join(gt_ep_dir, f'{int(image_idx)}.png')
|
| exp_sec_img_path = os.path.join(exp_ep_dir, f'{int(image_idx)}.png')
|
| if os.path.isfile(gt_sec_img_path) and os.path.isfile(exp_sec_img_path):
|
| try:
|
| gt_img = to_tensor(Image.open(gt_sec_img_path).convert("RGB")).unsqueeze(0).to(device)
|
| exp_img = to_tensor(Image.open(exp_sec_img_path).convert("RGB")).unsqueeze(0).to(device)
|
| if torch.isfinite(gt_img).all() and torch.isfinite(exp_img).all():
|
| gt_img_batch[sec].append(gt_img)
|
| exp_img_batch[sec].append(exp_img)
|
| gt_img_paths_batch[sec].append(gt_sec_img_path)
|
| exp_img_paths_batch[sec].append(exp_sec_img_path)
|
| except Exception:
|
| pass
|
|
|
|
|
| if 'a' in modals:
|
| gt_sec_wav_path = os.path.join(gt_ep_dir, f'{int(image_idx)}.wav')
|
| exp_sec_wav_path = os.path.join(exp_ep_dir, f'{int(image_idx)}.wav')
|
| if os.path.isfile(gt_sec_wav_path) and os.path.isfile(exp_sec_wav_path):
|
| gt_wav_paths_batch[sec].append(gt_sec_wav_path)
|
| exp_wav_paths_batch[sec].append(exp_sec_wav_path)
|
|
|
|
|
| if 'v' in modals:
|
| for sec in secs:
|
| if (len(gt_img_batch[sec]) == 0) or (len(exp_img_batch[sec]) == 0):
|
| continue
|
| lpips_dists = lpips_loss_fn(gt_img_paths_batch[sec], exp_img_paths_batch[sec])
|
| dreamsim_dists = dreamsim_loss_fn(gt_img_paths_batch[sec], exp_img_paths_batch[sec])
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_lpips_{sec}'].update(lpips_dists, n=1)
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_dreamsim_{sec}'].update(dreamsim_dists, n=1)
|
|
|
| sec_gt_batch = torch.cat(gt_img_batch[sec], dim=0)
|
| sec_exp_batch = torch.cat(exp_img_batch[sec], dim=0)
|
| if torch.isfinite(sec_gt_batch).all() and torch.isfinite(sec_exp_batch).all():
|
| fid_loss_fn[sec].update(images=sec_gt_batch, is_real=True)
|
| fid_loss_fn[sec].update(images=sec_exp_batch, is_real=False)
|
| psnr_vals = _psnr_from_tensors(sec_gt_batch, sec_exp_batch, data_range=1.0)
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_psnr_{sec}'].update(psnr_vals.mean(), n=1)
|
|
|
|
|
| if 'a' in modals:
|
|
|
| if len(fad_streams) > 0:
|
| for sec in secs:
|
| if len(gt_wav_paths_batch[sec]) == 0 and len(exp_wav_paths_batch[sec]) == 0:
|
| continue
|
| fad_streams[sec].update_from_wavs(gt_wav_paths_batch[sec], is_real=True)
|
| fad_streams[sec].update_from_wavs(exp_wav_paths_batch[sec], is_real=False)
|
|
|
|
|
| _AUDIO_SR = 16000
|
| for sec in secs:
|
| gt_list = gt_wav_paths_batch[sec]
|
| exp_list = exp_wav_paths_batch[sec]
|
| if len(gt_list) == 0 or len(exp_list) == 0:
|
| continue
|
| pair_cnt = min(len(gt_list), len(exp_list))
|
| if pair_cnt == 0:
|
| continue
|
|
|
| lsd_L, lsd_R, ssim_L, ssim_R = [], [], [], []
|
| mel_L, mel_R = [], []
|
|
|
| mel_lsd_L, mel_lsd_R = [], []
|
| mel_ssim_L, mel_ssim_R = [], []
|
|
|
| sispec_nl_L, sispec_nl_R = [], []
|
| sispec_log_L, sispec_log_R = [], []
|
| mel_sispec_nl_L, mel_sispec_n_R = [], []
|
| mel_sispec_log_L, mel_sispec_log_R = [], []
|
|
|
|
|
| for i in range(pair_cnt):
|
| gpath = gt_list[i]
|
| epath = exp_list[i]
|
| try:
|
| g_st = _load_librosa_stereo(gpath, _AUDIO_SR)
|
| e_st = _load_librosa_stereo(epath, _AUDIO_SR)
|
|
|
| if args.mono:
|
| g_mono = g_st.mean(axis=0)
|
| e_mono = e_st.mean(axis=0)
|
|
|
|
|
| gt_sp = _wav_to_spectrogram(g_mono, rate=_AUDIO_SR)
|
| ex_sp = _wav_to_spectrogram(e_mono, rate=_AUDIO_SR)
|
| lsd_val = _lsd_from_specs(ex_sp.clone(), gt_sp.clone())
|
| ssim_val = _ssim_from_specs(ex_sp.clone(), gt_sp.clone())
|
|
|
|
|
| mel_val = _mel_cosine_single_channel(e_mono, g_mono, _AUDIO_SR, mel_tf)
|
|
|
|
|
| mel_lsd_val, mel_ssim_val = _mel_lsd_ssim_single(e_mono, g_mono, mel_tf)
|
|
|
|
|
| sispec_nl = _sispec_from_specs(ex_sp.clone(), gt_sp.clone(), log_domain=False)
|
| sispec_log = _sispec_from_specs(ex_sp.clone(), gt_sp.clone(), log_domain=True)
|
|
|
| mel_sispec_nl = _sispec_from_specs(ex_m.clone(), gt_m.clone(), log_domain=False)
|
| mel_sispec_log = _sispec_from_specs(ex_m.clone(), gt_m.clone(), log_domain=True)
|
|
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_lsd_{sec}'].update(lsd_val, n=1)
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_ssim_{sec}'].update(
|
| torch.tensor(ssim_val), n=1
|
| )
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_melcos_{sec}'].update(
|
| torch.tensor(mel_val), n=1
|
| )
|
|
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_lsd_{sec}'].update(
|
| torch.tensor(float(mel_lsd_val)), n=1
|
| )
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_ssim_{sec}'].update(
|
| torch.tensor(float(mel_ssim_val)), n=1
|
| )
|
|
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_non_log_sispec_{sec}'].update(
|
| torch.tensor(float(sispec_nl)), n=1
|
| )
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_sispec_{sec}'].update(
|
| torch.tensor(float(sispec_log)), n=1
|
| )
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_final_non_log_mel_sispec_{sec}'].update(
|
| torch.tensor(float(mel_sispec_nl)), n=1
|
| )
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_sispec_{sec}'].update(
|
| torch.tensor(float(mel_sispec_log)), n=1
|
| )
|
|
|
|
|
| else:
|
| for ch, (acc_lsd, acc_ssim, acc_mel,
|
| acc_mel_lsd, acc_mel_ssim,
|
| acc_sispec_nl, acc_sispec_log,
|
| acc_mel_sispec_nl, acc_mel_sispec_log) in enumerate([
|
| (lsd_L, ssim_L, mel_L, mel_lsd_L, mel_ssim_L, sispec_nl_L, sispec_log_L, mel_sispec_nl_L, mel_sispec_log_L),
|
| (lsd_R, ssim_R, mel_R, mel_lsd_R, mel_ssim_R, sispec_nl_R, sispec_log_R, mel_sispec_n_R, mel_sispec_log_R),
|
| ]):
|
| g = g_st[ch]; e = e_st[ch]
|
|
|
| gt_sp = _wav_to_spectrogram(g, rate=_AUDIO_SR)
|
| ex_sp = _wav_to_spectrogram(e, rate=_AUDIO_SR)
|
| acc_lsd.append(float(_lsd_from_specs(ex_sp.clone(), gt_sp.clone())))
|
| acc_ssim.append(float(_ssim_from_specs(ex_sp.clone(), gt_sp.clone())))
|
|
|
| acc_mel.append(_mel_cosine_single_channel(e, g, _AUDIO_SR, mel_tf))
|
|
|
|
|
| mel_lsd_val, mel_ssim_val = _mel_lsd_ssim_single(e, g, mel_tf)
|
| acc_mel_lsd.append(mel_lsd_val)
|
| acc_mel_ssim.append(mel_ssim_val)
|
|
|
|
|
| acc_sispec_nl.append( float(_sispec_from_specs(ex_sp.clone(), gt_sp.clone(), log_domain=False)) )
|
| acc_sispec_log.append( float(_sispec_from_specs(ex_sp.clone(), gt_sp.clone(), log_domain=True)) )
|
|
|
| est_mag = np.abs(librosa.stft(e, n_fft=743, hop_length=160))
|
| ref_mag = np.abs(librosa.stft(g, n_fft=743, hop_length=160))
|
| est_mel = mel_tf(torch.from_numpy(est_mag).float())
|
| ref_mel = mel_tf(torch.from_numpy(ref_mag).float())
|
| ex_m = est_mel.T.unsqueeze(0).unsqueeze(0)
|
| gt_m = ref_mel.T.unsqueeze(0).unsqueeze(0)
|
|
|
| acc_mel_sispec_nl.append( float(_sispec_from_specs(ex_m.clone(), gt_m.clone(), log_domain=False)) )
|
| acc_mel_sispec_log.append( float(_sispec_from_specs(ex_m.clone(), gt_m.clone(), log_domain=True)) )
|
|
|
| except Exception:
|
| pass
|
|
|
| if not args.mono:
|
| def _maybe_mean(x):
|
| return float(np.mean(x)) if len(x) > 0 else None
|
|
|
| v = _maybe_mean(lsd_L); w = _maybe_mean(lsd_R)
|
| if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_lsdL_{sec}'].update(torch.tensor(v), n=1)
|
| if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_lsdR_{sec}'].update(torch.tensor(w), n=1)
|
| if v is not None and w is not None:
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_lsd_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
|
|
| v = _maybe_mean(ssim_L); w = _maybe_mean(ssim_R)
|
| if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_ssimL_{sec}'].update(torch.tensor(v), n=1)
|
| if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_ssimR_{sec}'].update(torch.tensor(w), n=1)
|
| if v is not None and w is not None:
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_ssim_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
|
|
| v = _maybe_mean(mel_L); w = _maybe_mean(mel_R)
|
| if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_melcosL_{sec}'].update(torch.tensor(v), n=1)
|
| if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_melcosR_{sec}'].update(torch.tensor(w), n=1)
|
| if v is not None and w is not None:
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_melcos_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
|
|
| v = _maybe_mean(mel_lsd_L); w = _maybe_mean(mel_lsd_R)
|
| if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_lsdL_{sec}'].update(torch.tensor(v), n=1)
|
| if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_lsdR_{sec}'].update(torch.tensor(w), n=1)
|
| if v is not None and w is not None:
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_lsd_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
|
|
| v = _maybe_mean(mel_ssim_L); w = _maybe_mean(mel_ssim_R)
|
| if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_ssimL_{sec}'].update(torch.tensor(v), n=1)
|
| if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_ssimR_{sec}'].update(torch.tensor(w), n=1)
|
| if v is not None and w is not None:
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_ssim_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
|
|
| v = _maybe_mean(sispec_nl_L); w = _maybe_mean(sispec_nl_R)
|
| if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_non_log_sispecL_{sec}'].update(torch.tensor(v), n=1)
|
| if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_non_log_sispecR_{sec}'].update(torch.tensor(w), n=1)
|
| if v is not None and w is not None:
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_non_log_sispec_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
|
|
| v = _maybe_mean(sispec_log_L); w = _maybe_mean(sispec_log_R)
|
| if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_sispecL_{sec}'].update(torch.tensor(v), n=1)
|
| if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_sispecR_{sec}'].update(torch.tensor(w), n=1)
|
| if v is not None and w is not None:
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_sispec_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
|
|
| v = _maybe_mean(mel_sispec_nl_L); w = _maybe_mean(mel_sispec_n_R)
|
| if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_non_log_mel_sispecL_{sec}'].update(torch.tensor(v), n=1)
|
| if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_non_log_mel_sispecR_{sec}'].update(torch.tensor(w), n=1)
|
| if v is not None and w is not None:
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_final_non_log_mel_sispec_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
|
|
| v = _maybe_mean(mel_sispec_log_L); w = _maybe_mean(mel_sispec_log_R)
|
| if v is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_sispecL_{sec}'].update(torch.tensor(v), n=1)
|
| if w is not None: metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_sispecR_{sec}'].update(torch.tensor(w), n=1)
|
| if v is not None and w is not None:
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_final_mel_sispec_{sec}'].update(torch.tensor(0.5*(v+w)), n=1)
|
| for s in secs_py:
|
| pairs = denorm_pairs_by_sec[s]
|
| if not pairs:
|
| continue
|
| arr = np.asarray(pairs, dtype=np.float32)
|
| mask = np.isfinite(arr).all(axis=1)
|
| if not np.any(mask):
|
| continue
|
| se_mean = float(np.mean((arr[mask, 1] - arr[mask, 0]) ** 2))
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_denorm_mse_{s}'].update(
|
| torch.tensor(se_mean), n=1
|
| )
|
|
|
| if 'v' in modals:
|
| feature_dim = 2048
|
| sec_list = [int(s) for s in secs]
|
| tmp_dir = Path(os.path.join(args.exp_dir, ".fid_tmp"))
|
| if dist_torch.is_initialized():
|
| if dist_torch.get_rank() == 0:
|
| tmp_dir.mkdir(parents=True, exist_ok=True)
|
| dist_torch.barrier()
|
| else:
|
| tmp_dir.mkdir(parents=True, exist_ok=True)
|
| if dist_torch.is_initialized():
|
| my_rank = dist_torch.get_rank()
|
| world_size = dist_torch.get_world_size()
|
| else:
|
| my_rank = 0
|
| world_size = 1
|
|
|
| for s in sec_list:
|
| fid_m = fid_loss_fn[s]
|
| state = {
|
| "real_sum": fid_m.real_sum.detach().to("cpu", torch.float64),
|
| "real_cov_sum": fid_m.real_cov_sum.detach().to("cpu", torch.float64),
|
| "fake_sum": fid_m.fake_sum.detach().to("cpu", torch.float64),
|
| "fake_cov_sum": fid_m.fake_cov_sum.detach().to("cpu", torch.float64),
|
| "num_real_images": torch.tensor(int(fid_m.num_real_images.item()), dtype=torch.int64),
|
| "num_fake_images": torch.tensor(int(fid_m.num_fake_images.item()), dtype=torch.int64),
|
| }
|
| out_path = tmp_dir / f"fid_sec{s}_rank{my_rank}.pt"
|
| torch.save(state, out_path)
|
| if dist_torch.is_initialized():
|
| dist_torch.barrier()
|
| if (not dist_torch.is_initialized()) or my_rank == 0:
|
| for s in sec_list:
|
| agg = {
|
| "real_sum": torch.zeros(feature_dim, dtype=torch.float64),
|
| "real_cov_sum": torch.zeros((feature_dim, feature_dim), dtype=torch.float64),
|
| "fake_sum": torch.zeros(feature_dim, dtype=torch.float64),
|
| "fake_cov_sum": torch.zeros((feature_dim, feature_dim), dtype=torch.float64),
|
| "num_real_images": torch.tensor(0, dtype=torch.int64),
|
| "num_fake_images": torch.tensor(0, dtype=torch.int64),
|
| }
|
| for r in range(world_size):
|
| p = tmp_dir / f"fid_sec{s}_rank{r}.pt"
|
| if not p.exists():
|
| continue
|
| st = torch.load(p, map_location="cpu")
|
| agg["real_sum"] += st["real_sum"]
|
| agg["real_cov_sum"] += st["real_cov_sum"]
|
| agg["fake_sum"] += st["fake_sum"]
|
| agg["fake_cov_sum"] += st["fake_cov_sum"]
|
| agg["num_real_images"] += st["num_real_images"]
|
| agg["num_fake_images"] += st["num_fake_images"]
|
| fid_m = fid_loss_fn[s]
|
| fid_m.real_sum = agg["real_sum"].to(fid_m.device, fid_m.real_sum.dtype)
|
| fid_m.real_cov_sum = agg["real_cov_sum"].to(fid_m.device, fid_m.real_cov_sum.dtype)
|
| fid_m.fake_sum = agg["fake_sum"].to(fid_m.device, fid_m.fake_sum.dtype)
|
| fid_m.fake_cov_sum = agg["fake_cov_sum"].to(fid_m.device, fid_m.fake_cov_sum.dtype)
|
| fid_m.num_real_images = torch.tensor(
|
| int(agg["num_real_images"].item()), device=fid_m.device, dtype=fid_m.num_real_images.dtype
|
| )
|
| fid_m.num_fake_images = torch.tensor(
|
| int(agg["num_fake_images"].item()), device=fid_m.device, dtype=fid_m.num_fake_images.dtype
|
| )
|
|
|
| try:
|
| val = float(fid_m.compute().item())
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_fid_{s}'].update(val, n=1)
|
| except Exception as e:
|
| print(f"[WARN] FID compute failed at sec={s}: {e}")
|
| for s in sec_list:
|
| for r in range(world_size):
|
| p = tmp_dir / f"fid_sec{s}_rank{r}.pt"
|
| try:
|
| if p.exists():
|
| p.unlink()
|
| except Exception:
|
| pass
|
| try:
|
| tmp_dir.rmdir()
|
| except Exception:
|
| pass
|
| if dist_torch.is_initialized():
|
| dist_torch.barrier()
|
|
|
| if 'a' in modals and len(fad_streams) > 0:
|
| for sec in secs:
|
| try:
|
| if stereo_mode:
|
| fad_L, fad_R, fad_avg = fad_streams[sec].compute()
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_fadL_{sec}'].update(fad_L, n=1)
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_fadR_{sec}'].update(fad_R, n=1)
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_fad_{sec}'].update(fad_avg, n=1)
|
| else:
|
| fad_val = float(fad_streams[sec].compute())
|
| metric_logger.meters[f'{dataset_name}_{eval_name}_fad_{sec}'].update(fad_val, n=1)
|
| except Exception as e:
|
| if rank == 0:
|
| print(f"[WARN] FAD compute failed at sec={sec}: {e}")
|
| continue
|
|
|
|
|
|
|
|
|
|
|
def save_metric_to_disk(metric_logger, log_p, rank):
    """Synchronize metric meters and write their global averages to a JSON file.

    When a torch.distributed process group is initialized, every rank calls
    ``metric_logger.synchronize_between_processes()``. That is presumably a
    cross-rank reduction, so all ranks should call this function. Only rank 0
    then writes the file.

    Args:
        metric_logger: object with a ``meters`` mapping of metric name -> meter.
            Each meter exposes a ``global_avg`` value that is convertible to
            ``float``.
        log_p: destination path of the JSON file. Missing parent directories
            are created. A bare filename is written relative to the current
            working directory.
        rank: global rank of the calling process. Ranks other than 0 write
            nothing.
    """
    if dist_torch.is_initialized():
        # NOTE(review): assumed to be a collective across ranks -- confirm in
        # the `distributed` module; if so, skipping it on any rank would hang.
        metric_logger.synchronize_between_processes()

    if rank != 0:
        return

    log_stats = {k: float(meter.global_avg) for k, meter in metric_logger.meters.items()}

    # os.makedirs("") raises FileNotFoundError, so only create a parent
    # directory when the path actually has one.
    parent_dir = os.path.dirname(log_p)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(log_p, 'w', encoding='utf-8') as json_file:
        json.dump(log_stats, json_file, indent=4)

    print(f"[OK] Metrics saved to: {log_p}")
|
|
|
|
|
|
|
|
|
|
|
def main(args):
    """Run the evaluation for one dataset / eval type and save the metrics.

    Steps:
    1. Set up torch.distributed when launched with RANK / WORLD_SIZE /
       LOCAL_RANK in the environment (see ``setup_distributed``).
    2. Build the LPIPS, DreamSim and FID metric objects for ``secs`` = 1..16.
    3. Run ``evaluate`` under ``torch.no_grad()``.
    4. Write the averaged metrics to ``<exp_dir>/<dataset>_<eval_name>.json``
       via ``save_metric_to_disk``.

    The process group (if any) is always torn down in ``finally``.

    Args:
        args: parsed CLI namespace. This function reads ``dataset``,
            ``eval_name``, ``modals``, ``gt_dir`` and ``exp_dir``, and forwards
            the whole namespace to ``evaluate``.

    Raises:
        Exception: any error raised while building the metrics, evaluating or
            saving is reported together with the failing rank and then
            re-raised, so the process exits with a non-zero status instead of
            silently looking successful.
    """
    rank, world_size, local_rank = setup_distributed()

    device = f"cuda:{local_rank}" if world_size > 1 else ("cuda" if torch.cuda.is_available() else "cpu")

    torch.backends.cudnn.benchmark = True

    dataset_name = args.dataset
    secs = np.arange(1, 17, dtype=int)  # 1, 2, ..., 16

    try:
        # Built inside the try so that a failure while constructing these
        # metrics still reaches the finally block and destroys the process group.
        lpips_loss_fn = get_loss_fn('lpips', secs, device)
        dreamsim_loss_fn = get_loss_fn('dreamsim', secs, device)
        fid_metrics_vision = get_loss_fn('fid', secs, device)

        metric_logger = dist.MetricLogger(delimiter=" ")
        if rank == 0:
            print(f"Evaluating {args.eval_name} {dataset_name} | modals = {args.modals}")

        time_loss_fns = (lpips_loss_fn, dreamsim_loss_fn, fid_metrics_vision)

        with torch.no_grad():
            evaluate(
                args=args,
                dataset_name=dataset_name,
                eval_type=args.eval_name,
                metric_logger=metric_logger,
                loss_fns=time_loss_fns,
                gt_dir=args.gt_dir,
                exp_dir=args.exp_dir,
                secs=secs,
                device=device,
                rank=rank,
                world_size=world_size,
                modals=args.modals
            )

        output_fn = os.path.join(args.exp_dir, f'{dataset_name}_{args.eval_name}.json')
        save_metric_to_disk(metric_logger, output_fn, rank)

    except Exception as e:
        # Report on whichever rank failed (not only rank 0), then re-raise so
        # the traceback is shown and the exit status reflects the failure.
        print(f"[ERROR] rank {rank}: evaluation failed: {e!r}", flush=True)
        raise

    finally:
        if dist_torch.is_initialized():
            dist_torch.barrier()
            dist_torch.destroy_process_group()
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(allow_abbrev=False)

    # Command-line options as (flag, add_argument keyword arguments).
    # They are registered in this exact order, which is also the order in
    # which argparse lists them in --help.
    _cli_options = (
        # data locations and evaluation selection
        ("--batch_size", dict(type=int, default=64, help="batch size")),
        ("--gt_dir", dict(type=str, required=True, help="gt directory")),
        ("--exp_dir", dict(type=str, required=True,
                           help="experiment directory (also save json here)")),
        ("--eval_name", dict(type=str, default='time', choices=['time', 'rollout'],
                             help="eval type")),
        ("--dataset", dict(type=str, required=True,
                           help="dataset name (for metric keys & json name)")),
        ("--modals", dict(type=str, default="av", choices=["a", "v", "av"],
                          help="a=audio only (wav), v= image only (png), av=both")),
        # Frechet Audio Distance configuration
        ("--fad_model", dict(type=str, default="vggish",
                             choices=["vggish", "pann", "clap", "encodec"],
                             help="embedding model for FAD")),
        ("--fad_sr", dict(type=int, default=16000,
                          help="sampling rate for FAD")),
        # audio channel handling / FAD input padding
        ("--mono", dict(action="store_true",
                        help="default as stereo, add --mono to mono")),
        ("--fad_pad_sec", dict(type=float, default=1.0,
                               help="pad the input of VGGish to x seconds")),
    )
    for _flag, _kwargs in _cli_options:
        parser.add_argument(_flag, **_kwargs)

    args = parser.parse_args()
    main(args)
|
|
|