| |
| """ |
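| Fréchet Inception Distance (FID) with two selectable cross terms: "product", i.e. Tr(sqrt(Σ₁Σ₂)), which mirrors the sqrt(Σ₁Σ₂) path of the evaluator, and "psd_geometric" (the default), i.e. Tr(sqrt(Σ₁^{1/2} Σ₂ Σ₁^{1/2})). |
| To find out which step introduces imaginary parts or negative values, see fid() and the output of --diagnose. |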
| """ |
| import argparse |
| import os |
| import warnings |
|
|
| import numpy as np |
| from scipy import linalg |
| import tensorflow.compat.v1 as tf |
|
|
| from evaluator import Evaluator |
|
|
|
|
def _trace_sqrt_symmetric_psd(M):
    """Return Tr(sqrt(M)) for a matrix that is expected to be symmetric PSD.

    ``scipy.linalg.sqrtm`` is deliberately not used here: for matrices that
    are only numerically "almost PSD" it may return complex values. Instead:

    1. symmetrize the input, ``M <- (M + M^T) / 2``;
    2. compute the eigenvalues of the symmetric matrix;
    3. clip negative eigenvalues caused by round-off to 0 (PSD projection);
    4. return ``sum(sqrt(max(lambda_i, 0)))``.

    Args:
        M: Square 2-D array-like.

    Returns:
        Tr(sqrt(M)) as a Python float, or NaN (with a warning) when the
        symmetrized matrix contains NaN/Inf.

    Raises:
        ValueError: If ``M`` is not a square 2-D matrix.
        numpy.linalg.LinAlgError: If the eigenvalue solver fails even after
            the diagonal-jitter retry.
    """
    M = np.asarray(M)
    if M.ndim != 2 or M.shape[0] != M.shape[1]:
        # Without this check a (1, n) input would silently broadcast in
        # M + M.T and yield a meaningless n x n matrix.
        raise ValueError(f"expected a square 2-D matrix, got shape {M.shape}")

    M = (M + M.T) / 2.0

    if not np.isfinite(M).all():
        warnings.warn("mid matrix contains NaN/Inf; returning NaN for Tr(sqrt(M))")
        return float("nan")

    try:
        w = np.linalg.eigvalsh(M)
    except np.linalg.LinAlgError:
        # Rare fallback: retry with a small diagonal shift scaled by the mean
        # diagonal entry. The identity matrix and the jitter are only built
        # here so the common path does not pay for a d x d allocation.
        n = M.shape[0]
        diag_jitter = 1e-6 * (float(np.trace(M)) / max(n, 1))
        if not np.isfinite(diag_jitter) or diag_jitter <= 0:
            diag_jitter = 1e-6
        w = np.linalg.eigvalsh(M + diag_jitter * np.eye(n, dtype=M.dtype))

    w_min = float(w.min()) if w.size else 0.0
    if w_min < -1e-6:
        warnings.warn(
            "PSD 中间矩阵出现明显负特征值 λ_min={:.3e};已截断到 0 以稳定 Tr(sqrt(M))".format(
                w_min
            )
        )
    # PSD projection: negative eigenvalues contribute nothing to the trace.
    w = np.clip(w, 0.0, None)
    return float(np.sqrt(w).sum())
|
|
|
|
def fid_psd_geometric(mn1, cov1, mn2, cov2):
    """Fréchet distance using the symmetric-PSD form of the cross term.

    The cross term is ``Tr(sqrt(Σ₁^{1/2} Σ₂ Σ₁^{1/2}))``. The inner matrix is
    symmetric PSD, unlike ``Σ₁Σ₂``, so its square root can be taken through a
    symmetric eigendecomposition without ever producing complex values. In
    exact arithmetic the result is a non-negative squared distance;
    numerically it may come out as a tiny negative number.

    Intended as a cross-check: if this value is non-negative while the
    ``sqrt(Σ₁Σ₂)`` path of ``fid()`` is negative, the problem lies in the
    ``sqrtm`` / ``.real`` / jitter steps of that path.

    Args:
        mn1: Mean vector of the first distribution.
        cov1: Covariance matrix of the first distribution.
        mn2: Mean vector of the second distribution.
        cov2: Covariance matrix of the second distribution.

    Returns:
        ``‖μ₁−μ₂‖² + Tr(Σ₁) + Tr(Σ₂) − 2·Tr(cross term)`` as a Python float.
    """
    mn1 = np.atleast_1d(mn1)
    mn2 = np.atleast_1d(mn2)
    cov1 = np.atleast_2d(cov1)
    cov2 = np.atleast_2d(cov2)
    # Remove round-off asymmetry before the symmetric eigendecomposition.
    cov1 = (cov1 + cov1.T) / 2.0
    cov2 = (cov2 + cov2.T) / 2.0
    diff = mn1 - mn2

    w1, v1 = np.linalg.eigh(cov1)
    w1_min = float(w1.min()) if w1.size else 0.0
    if w1_min < -1e-6:
        # The literal braces of "{1/2}" must be doubled: a bare "{1/2}" is
        # parsed by str.format as a replacement field and raises KeyError,
        # turning this warning path into a crash.
        warnings.warn(
            "Σ₁ 出现明显负特征值 λ_min={:.3e};已截断到 0 以稳定 Σ₁^{{1/2}}".format(w1_min)
        )
    w1 = np.clip(w1, 0.0, None)
    # Σ₁^{1/2} = V diag(sqrt(λ)) V^T; the broadcast scales the columns of V.
    s1 = (v1 * np.sqrt(w1)) @ v1.T
    mid = s1 @ cov2 @ s1
    tr_cross = _trace_sqrt_symmetric_psd(mid)
    return float(diff.dot(diff) + np.trace(cov1) + np.trace(cov2) - 2.0 * tr_cross)
|
|
|
|
def _fid_product_path(diff, cov1, cov2, eps):
    """Compute the FID whose cross term is ``Tr(sqrt(cov1 @ cov2))``.

    Returns ``(fid_value, diagnostics)`` where ``diagnostics`` records the
    symmetry error of the product, the jitter that was needed, and whether
    ``sqrtm`` produced an imaginary part. Raises ``ValueError`` when ``sqrtm``
    stays non-finite for every jitter level.
    """
    d = cov1.shape[0]
    prod = cov1.dot(cov2)
    prod_sym_err = np.max(np.abs(prod - prod.T))

    covmean = None
    j_used = None
    for j in (0.0, eps, 1e-5, 1e-4, 1e-3):
        if j == 0.0:
            # Reuse the product computed above instead of multiplying again.
            target = prod
        else:
            offset = np.eye(d) * j
            target = (cov1 + offset).dot(cov2 + offset)
        # NOTE(review): `disp` is reported as deprecated in recent SciPy
        # releases — confirm against the pinned SciPy version.
        cm, _ = linalg.sqrtm(target, disp=False)
        if np.isfinite(cm).all():
            covmean = cm
            j_used = j
            break
    if covmean is None:
        raise ValueError("FID sqrtm failed: non-finite covmean even after jitter retries")

    had_imag = bool(np.iscomplexobj(covmean))
    imag_max = float(np.max(np.abs(covmean.imag))) if had_imag else 0.0
    if had_imag:
        if imag_max > 1e-3:
            warnings.warn(
                "Large imaginary component in sqrtm ({:.6f}); taking real part.".format(imag_max)
            )
        covmean = covmean.real

    tr_covmean = np.trace(covmean)
    fid_product = float(diff.dot(diff) + np.trace(cov1) + np.trace(cov2) - 2.0 * tr_covmean)
    diagnostics = {
        "prod_symmetry_error": float(prod_sym_err),
        "jitter_used": float(j_used),
        "had_imaginary_sqrtm": had_imag,
        "imag_max": imag_max,
        "fid_sqrt_product_path": fid_product,
    }
    return fid_product, diagnostics


def _fid_diagnosis_warnings(asym1, asym2, product_diag, fid_psd):
    """Turn the raw diagnostics into a list of human-readable risk messages."""
    wlist = []
    if asym1 > 1e-4 or asym2 > 1e-4:
        wlist.append(
            "对称化前协方差不对称量过大 (max asym1={:.3e}, asym2={:.3e})".format(asym1, asym2)
        )
    if product_diag.get("prod_symmetry_error", 0.0) > 1e-2:
        wlist.append(
            "Σ₁Σ₂ 与对称矩阵偏离较大 (max|A-A^T|={:.3e}),sqrtm 复值风险高".format(
                product_diag["prod_symmetry_error"]
            )
        )
    if product_diag.get("jitter_used") and product_diag["jitter_used"] > 0.0:
        wlist.append(
            "使用了 jitter={}:sqrtm 与 Tr(Σ₁)+Tr(Σ₂) 项不一致,负 FID 风险".format(
                product_diag["jitter_used"]
            )
        )
    if product_diag.get("had_imaginary_sqrtm"):
        wlist.append(
            "sqrtm(Σ₁Σ₂) 含虚部 (max|Im|={:.3e}),已取 .real,与理论 Fréchet 项可能失配".format(
                product_diag["imag_max"]
            )
        )
    product_value = product_diag.get("fid_sqrt_product_path")
    if product_value is not None and product_value < 0:
        wlist.append("product 路径 FID<0(数值上不应出现)")
    if (
        not np.isnan(fid_psd)
        and fid_psd >= -1e-5
        and product_diag.get("fid_sqrt_product_path", 0.0) < 0
    ):
        wlist.append(
            "对照 fid_psd_geometric≈{:.6f} 非负,负值主要来自 sqrt(Σ₁Σ₂)/.real/jitter 路径".format(fid_psd)
        )
    return wlist


def fid(mn1, cov1, mn2, cov2, eps=1e-6, diagnose=False, cross_type="psd_geometric"):
    """Compute ``‖μ₁−μ₂‖² + Tr(Σ₁) + Tr(Σ₂) − 2·Tr(cross term)``.

    Args:
        mn1: Mean vector of the first distribution.
        cov1: Covariance matrix of the first distribution.
        mn2: Mean vector of the second distribution.
        cov2: Covariance matrix of the second distribution.
        eps: First non-zero diagonal jitter tried when ``sqrtm`` of the plain
            product is non-finite (product path only).
        diagnose: When True, both cross terms are evaluated and the function
            returns ``(fid_value, diag_dict)`` to help locate numerical
            mismatches.
        cross_type: ``"product"`` uses ``Tr(sqrt(Σ₁Σ₂))`` (the evaluator-style
            path); ``"psd_geometric"`` uses
            ``Tr(sqrt(Σ₁^{1/2} Σ₂ Σ₁^{1/2}))``, whose inner matrix is
            symmetric PSD and therefore numerically more stable.

    Returns:
        A float when ``diagnose`` is False, otherwise ``(float, dict)``.

    Raises:
        ValueError: On an unknown ``cross_type``, mismatching shapes, or when
            the product path is evaluated and ``sqrtm`` never becomes finite.
    """
    if cross_type not in ("product", "psd_geometric"):
        # A typo must not silently select a different metric.
        raise ValueError(
            f"unknown cross_type: {cross_type!r}; expected 'product' or 'psd_geometric'"
        )

    mn1 = np.atleast_1d(mn1)
    mn2 = np.atleast_1d(mn2)
    cov1 = np.atleast_2d(cov1)
    cov2 = np.atleast_2d(cov2)

    if mn1.shape != mn2.shape:
        raise ValueError(f"mean shape mismatch: {mn1.shape} vs {mn2.shape}")
    if cov1.shape != cov2.shape:
        raise ValueError(f"cov shape mismatch: {cov1.shape} vs {cov2.shape}")

    diff = mn1 - mn2

    # Measure the asymmetry before removing it so it can be reported.
    asym1 = np.max(np.abs(cov1 - cov1.T))
    asym2 = np.max(np.abs(cov2 - cov2.T))
    cov1 = (cov1 + cov1.T) / 2.0
    cov2 = (cov2 + cov2.T) / 2.0

    if not diagnose:
        # Evaluate only the requested path; the other one costs extra
        # decompositions whose result would be thrown away.
        if cross_type == "product":
            fid_product, _ = _fid_product_path(diff, cov1, cov2, eps)
            return fid_product
        return fid_psd_geometric(mn1, cov1, mn2, cov2)

    fid_product, partial = _fid_product_path(diff, cov1, cov2, eps)
    product_diag = {
        "asym_before_symmetrize_max": (float(asym1), float(asym2)),
        **partial,
    }

    # Best effort in diagnose mode: a failure of the comparison path is
    # recorded in the report instead of aborting the diagnosis.
    fid_psd_err = None
    try:
        fid_psd = fid_psd_geometric(mn1, cov1, mn2, cov2)
    except Exception as e:
        fid_psd = float("nan")
        fid_psd_err = repr(e)

    fid_val = fid_product if cross_type == "product" else fid_psd

    diag = {
        **product_diag,
        "fid_psd_geometric": fid_psd,
        "fid_psd_geometric_error": fid_psd_err,
        "warnings": _fid_diagnosis_warnings(asym1, asym2, product_diag, fid_psd),
    }
    return float(fid_val), diag
|
|
|
|
def print_diagnosis(diag):
    """Print a human-readable summary of a ``fid(..., diagnose=True)`` report.

    Args:
        diag: Mapping holding the diagnostic entries. Every entry is
            optional: missing values are skipped or shown as ``nan``/``None``
            instead of raising.
    """
    print("--- FID 诊断(按代码风险点)---")
    a0, a1 = diag.get("asym_before_symmetrize_max", (None, None))
    if a0 is not None:
        # fid() stores these in argument order (cov1, cov2) and this function
        # cannot know which one is the reference set, so label them by
        # position. The former "ref"/"sample" labels were swapped for the
        # call in main(), which passes the sample statistics first.
        print(f" 对称化前 |Σ-Σ^T|_max: Σ₁={a0:.3e}, Σ₂={a1:.3e}")
    print(f" |Σ₁Σ₂ - (Σ₁Σ₂)^T|_max: {diag.get('prod_symmetry_error', float('nan')):.3e}")
    print(f" sqrtm 使用的 jitter: {diag.get('jitter_used', None)}")
    if diag.get("had_imaginary_sqrtm"):
        print(
            f" sqrtm(Σ₁Σ₂) 含虚部: {diag['had_imaginary_sqrtm']}, "
            f"max|Im|: {diag.get('imag_max', float('nan')):.3e}"
        )
    else:
        print(f" sqrtm(Σ₁Σ₂) 含虚部: {diag.get('had_imaginary_sqrtm', False)}")
    if diag.get("fid_sqrt_product_path") is not None:
        print(f" FID(sqrt(Σ₁Σ₂) 路径): {diag['fid_sqrt_product_path']:.6f}")
    if diag.get("fid_psd_geometric_error"):
        print(f" FID(PSD 几何平均路径) 未算: {diag['fid_psd_geometric_error']}")
    else:
        psd_value = diag.get("fid_psd_geometric")
        if psd_value is None:
            # A None value cannot be rendered with the ':.6f' format spec.
            psd_value = float("nan")
        print(f" FID(PSD 几何平均路径): {psd_value:.6f}")
    # Read defensively like every other key: no warnings means nothing to list.
    for s in diag.get("warnings", []):
        print(f" [检查] {s}")
    print("---")
|
|
|
|
def main():
    """Command-line entry point: compute the custom FID between two npz batches.

    Parses the CLI options, builds a CPU-only TensorFlow session for the
    ``Evaluator``, obtains activations and statistics for the reference and
    the sample batch, evaluates ``fid()`` (optionally with diagnostics) and,
    when ``--save_txt`` is given, writes a short text report.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--ref_batch",
        default="/gemini/space/zhaozy/guzhenyu/UAVFlow/UAV_Flow_base/exps/jsflow-experiment/samples/VIRTUAL_imagenet256_labeled/imagenet/VIRTUAL_imagenet256_labeled.npz",
        help="reference npz (arr_0 images or stats npz)",
    )
    cli.add_argument("--sample_batch", required=True, help="sample npz path")
    cli.add_argument("--save_txt", type=str, default=None, help="optional txt output path")
    cli.add_argument(
        "--cross-type",
        type=str,
        default="psd_geometric",
        choices=["product", "psd_geometric"],
        help="交叉项计算方式:product=Tr(sqrt(Σ₁Σ₂))(对齐 evaluator),psd_geometric=Tr(sqrt(Σ₁^{1/2}Σ₂Σ₁^{1/2}))(更稳)",
    )
    cli.add_argument(
        "--diagnose",
        action="store_true",
        help="打印逐行风险检测:jitter/虚部/Σ₁Σ₂ 对称性,并对照 PSD 几何平均路径",
    )
    args = cli.parse_args()

    # Force CPU execution: hide CUDA devices and request zero GPU devices.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    session = tf.Session(
        config=tf.ConfigProto(allow_soft_placement=True, device_count={"GPU": 0})
    )
    evaluator = Evaluator(session)

    print("warming up TensorFlow...")
    evaluator.warmup()

    print("computing reference batch activations...")
    ref_acts = evaluator.read_activations(args.ref_batch)
    print("computing/reading reference batch statistics...")
    ref_stats, _ = evaluator.read_statistics(args.ref_batch, ref_acts)

    print("computing sample batch activations...")
    sample_acts = evaluator.read_activations(args.sample_batch)
    print("computing/reading sample batch statistics...")
    sample_stats, _ = evaluator.read_statistics(args.sample_batch, sample_acts)

    print("Computing custom FID...")
    # A single call serves both modes: with diagnose=True fid() returns a
    # (value, report) pair, otherwise just the value.
    outcome = fid(
        sample_stats.mu,
        sample_stats.sigma,
        ref_stats.mu,
        ref_stats.sigma,
        diagnose=args.diagnose,
        cross_type=args.cross_type,
    )
    if args.diagnose:
        fid_value, diag = outcome
        print_diagnosis(diag)
    else:
        fid_value = outcome
    print(f"FID(custom): {fid_value}")

    if not args.save_txt:
        return

    with open(args.save_txt, "w", encoding="utf-8") as report:
        print(f"ref_batch: {args.ref_batch}", file=report)
        print(f"sample_batch: {args.sample_batch}", file=report)
        print(f"FID(custom): {fid_value}", file=report)
        if args.diagnose:
            print(
                f"diagnose: jitter_used={diag['jitter_used']}, imag_max={diag['imag_max']}",
                file=report,
            )
            for message in diag["warnings"]:
                print(f" {message}", file=report)
    print(f"Saved report to {args.save_txt}")


# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|
|