#!/usr/bin/env python3
"""Combined health check for sample NPZ files.

Reads ``arr_0`` from an NPZ file and runs:

1) exact-duplicate detection (SHA1 over the bytes of each whole image or
   feature row);
2) optional coarse near-duplicate detection (hash of a down-sampled grayscale
   bitmap, images only);
3) optional covariance diagnostics: using the same ``Evaluator`` as
   fid_custom, compute the eigenvalues and effective rank of the sample
   covariance, optionally compared against a reference NPZ.

Examples:
    python npz_health_check.py --npz /path/to/samples.npz
    python npz_health_check.py --npz samples.npz --ref /path/to/VIRTUAL_imagenet256_labeled.npz
    python npz_health_check.py --npz samples.npz --no-tf
"""

from __future__ import annotations

import argparse
import hashlib
import os
import sys
from collections import Counter

import numpy as np


def _sigma_diag(sigma: np.ndarray) -> dict[str, float | int]:
    """Summarize the eigenvalue spectrum of a covariance matrix.

    Args:
        sigma: Square matrix (a scalar or 1-D input is promoted to 2-D).

    Returns:
        A dict with ``dim`` (number of eigenvalues), ``eig_min``, ``eig_max``,
        ``cond`` (``eig_max / max(eig_min, 1e-30)``) and ``effective_rank``
        (``(sum of eigenvalues)^2 / sum of squared eigenvalues``, or 0.0 when
        every eigenvalue is zero).
    """
    sigma = np.atleast_2d(sigma)
    # Symmetrize before eigvalsh, which assumes a symmetric input.
    sigma = (sigma + sigma.T) / 2.0
    eig = np.linalg.eigvalsh(sigma)
    # Negative eigenvalues are clamped to zero before any statistic is taken.
    eig = np.clip(eig, 0.0, None)
    lam_min = float(eig.min())
    lam_max = float(eig.max())
    # The floor keeps the ratio finite when the smallest eigenvalue is zero.
    cond = float(lam_max / max(lam_min, 1e-30))
    s1 = float(eig.sum())
    s2 = float((eig**2).sum())
    erank = float((s1 * s1) / s2) if s2 > 0 else 0.0
    return {
        "dim": int(eig.size),
        "eig_min": lam_min,
        "eig_max": lam_max,
        "cond": cond,
        "effective_rank": erank,
    }


def _exact_dup_stats(items: np.ndarray) -> dict[str, float | int]:
    """Count byte-identical entries along the first axis of *items*.

    Each ``items[i]`` is hashed with SHA1 over its raw bytes, so two entries
    count as duplicates only when their byte representations are identical.
    """
    n = items.shape[0]
    counts = Counter(
        hashlib.sha1(np.asarray(item).tobytes()).digest() for item in items
    )
    unique = len(counts)
    dup = n - unique
    return {
        "n": n,
        "unique": unique,
        "dup_count": dup,
        # max(n, 1) avoids a division by zero on an empty batch.
        "dup_ratio_pct": 100.0 * dup / max(n, 1),
        "dup_groups": sum(1 for v in counts.values() if v > 1),
        "largest_group": max(counts.values(), default=0),
    }


def _dup_exact_images(arr: np.ndarray) -> dict[str, float | int]:
    """Return exact-duplicate statistics for a batch of images.

    Args:
        arr: Array indexed by sample along axis 0; the caller in this file
            passes an ``(N, H, W, 3)`` array. Any slice that supports
            ``tobytes`` works.

    Returns:
        Dict with ``n``, ``unique``, ``dup_count``, ``dup_ratio_pct``,
        ``dup_groups`` and ``largest_group``.
    """
    return _exact_dup_stats(arr)


def _dup_exact_rows(feat: np.ndarray) -> dict[str, float | int]:
    """Return exact-duplicate statistics for the rows of a feature matrix.

    Args:
        feat: Array indexed by sample along axis 0; the caller in this file
            passes an ``(N, D)`` array.

    Returns:
        Same keys as :func:`_dup_exact_images`.
    """
    return _exact_dup_stats(feat)


def _near_dup_coarse_images(
    arr: np.ndarray, stride: int = 16
) -> dict[str, float | int]:
    """Estimate near-duplicates with a coarse average hash (heuristic only).

    Every image is sub-sampled by taking each ``stride``-th pixel in both
    spatial axes and averaged over the channel axis; each remaining pixel
    becomes one bit (brighter than that image's mean or not). Images whose bit
    patterns match are counted as near-duplicates. With the default stride a
    256x256 image is reduced to 16x16 bits.

    Args:
        arr: 4-D array indexed as ``[sample, row, column, channel]``.
        stride: Sub-sampling step, must be >= 1.

    Returns:
        Dict with ``stride``, ``near_dup_count``, ``near_dup_ratio_pct`` and
        ``near_unique``.

    Raises:
        ValueError: If ``stride`` is smaller than 1.
    """
    if stride < 1:
        # A zero step is invalid for slicing and a negative one would silently
        # hash mirrored images, so reject both explicitly.
        raise ValueError(f"stride must be >= 1, got {stride}")
    n = arr.shape[0]
    if n == 0:
        # Nothing to hash; also avoids reshape(0, -1), which numpy rejects.
        return {
            "stride": stride,
            "near_dup_count": 0,
            "near_dup_ratio_pct": 0.0,
            "near_unique": 0,
        }
    small = arr[:, ::stride, ::stride, :].astype(np.float32).mean(axis=3)
    mean = small.mean(axis=(1, 2), keepdims=True)
    bits = (small > mean).reshape(n, -1)
    packed = np.packbits(bits, axis=1)
    counts = Counter(hashlib.sha1(row.tobytes()).digest() for row in packed)
    near_dup = n - len(counts)
    return {
        "stride": stride,
        "near_dup_count": near_dup,
        "near_dup_ratio_pct": 100.0 * near_dup / max(n, 1),
        "near_unique": len(counts),
    }


def _report_duplicates(
    arr: np.ndarray, near_dup: bool, near_dup_stride: int
) -> dict[str, float | int] | None:
    """Print the duplicate report for *arr* and return the exact-dup stats.

    Returns ``None`` when the shape of *arr* is neither ``(N, H, W, 3)`` nor
    ``(N, D)``, in which case duplicate detection is skipped.
    """
    if arr.ndim == 4 and arr.shape[-1] == 3:
        dup_report = _dup_exact_images(arr)
        print("\n--- 精确重复(整图字节 SHA1)---")
        print(
            f" N={dup_report['n']}, 唯一图={dup_report['unique']}, "
            f"重复张数={dup_report['dup_count']} ({dup_report['dup_ratio_pct']:.4f}%)"
        )
        print(
            f" 重复组数={dup_report['dup_groups']}, 最大组大小={dup_report['largest_group']}"
        )
        if near_dup:
            nd = _near_dup_coarse_images(np.asarray(arr), stride=near_dup_stride)
            print("\n--- 粗近重复(启发式,下采样 stride={})---".format(nd["stride"]))
            print(
                f" 粗唯一={nd['near_unique']}, 粗近重复≈{nd['near_dup_count']} "
                f"({nd['near_dup_ratio_pct']:.4f}%)"
            )
        return dup_report
    if arr.ndim == 2:
        dup_report = _dup_exact_rows(np.asarray(arr))
        print("\n--- 精确重复(特征行字节 SHA1)---")
        print(
            f" N={dup_report['n']}, 唯一行={dup_report['unique']}, "
            f"重复行={dup_report['dup_count']} ({dup_report['dup_ratio_pct']:.4f}%)"
        )
        return dup_report
    print("\n--- 跳过重复:arr_0 形状非 (N,H,W,3) 或 (N,D) ---")
    return None


def _print_verdict(
    s: dict[str, float | int], dup_report: dict[str, float | int] | None
) -> None:
    """Print the short interpretation of the spectrum and duplicate results."""
    print("\n--- 简要判读 ---")
    if s["cond"] > 1e12:
        print(" cond 极大:协方差近奇异,FID sqrtm 易不稳。")
    if s["effective_rank"] < 50 and s["dim"] >= 512:
        print(" effective_rank 偏低:特征分布可能塌缩在低维子空间。")
    if dup_report is not None:
        if dup_report["dup_ratio_pct"] > 1.0:
            print(" 精确重复率 >1%:检查采样/打包流程。")
        else:
            print(" 精确重复率很低:重复样本多半不是主因。")


def main() -> None:
    """Command-line entry point: parse arguments and run the requested checks.

    Exits with status 1 when the input file is missing, lacks ``arr_0``, or
    when TensorFlow is required but not installed.
    """
    p = argparse.ArgumentParser(description="NPZ 样本体检:重复 + 协方差谱(可选 TF)")
    p.add_argument("--npz", required=True, help="含 arr_0 的 .npz(图像或后续由 Evaluator 读图)")
    p.add_argument(
        "--ref",
        default=None,
        help="可选参考 NPZ(含 arr_0 或预存 mu/sigma),用于对比 effective_rank / cond",
    )
    p.add_argument("--no-tf", action="store_true", help="不跑 Inception,仅做重复与(无特征时无谱)")
    p.add_argument("--no-dup", action="store_true", help="跳过重复检测")
    p.add_argument("--near-dup", action="store_true", help="启用粗近重复(图像)")
    p.add_argument("--near-dup-stride", type=int, default=16, help="粗近重复下采样步长")
    args = p.parse_args()
    if args.near_dup and args.near_dup_stride < 1:
        p.error("--near-dup-stride 必须 >= 1")

    path = os.path.abspath(args.npz)
    if not os.path.isfile(path):
        print(f"文件不存在: {path}", file=sys.stderr)
        sys.exit(1)

    # Use the archive as a context manager so its file handle is released;
    # indexing the archive returns an in-memory array that outlives it.
    with np.load(path, mmap_mode="r") as data:
        if "arr_0" not in data.files:
            print("NPZ 中无 arr_0", file=sys.stderr)
            sys.exit(1)
        arr = data["arr_0"]

    print(f"=== NPZ: {path}")
    print(f"arr_0 shape={arr.shape}, dtype={arr.dtype}")

    dup_report = None
    if not args.no_dup:
        dup_report = _report_duplicates(
            arr, near_dup=args.near_dup, near_dup_stride=args.near_dup_stride
        )

    if args.no_tf:
        print("\n--no-tf:跳过 Inception 协方差谱。")
        return

    # Hide GPUs through the environment before TensorFlow is imported, so the
    # setting is in place regardless of when TF first probes CUDA devices.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    try:
        import tensorflow.compat.v1 as tf
    except ImportError as e:
        print("未安装 tensorflow,无法跑 Inception 诊断。可用 --no-tf 仅查重复。", file=sys.stderr)
        raise SystemExit(1) from e

    from evaluator import Evaluator

    config = tf.ConfigProto(allow_soft_placement=True, device_count={"GPU": 0})
    evaluator = Evaluator(tf.Session(config=config))
    print("\n--- TensorFlow Inception 特征(与 fid_custom 一致)---")
    print("warming up...")
    evaluator.warmup()
    print("computing sample activations & statistics...")
    acts = evaluator.read_activations(path)
    stats, _ = evaluator.read_statistics(path, acts)

    s = _sigma_diag(stats.sigma)
    print("\n--- 样本协方差 Σ(pool 特征)---")
    print(f" dim={s['dim']}")
    print(f" λ_min={s['eig_min']:.3e}, λ_max={s['eig_max']:.3e}, cond={s['cond']:.3e}")
    # The label spells out the quantity _sigma_diag computes: tr(Σ)² / tr(Σ²).
    print(f" effective_rank (tr(Σ)²/tr(Σ²))={s['effective_rank']:.2f}")

    if args.ref:
        ref_path = os.path.abspath(args.ref)
        if not os.path.isfile(ref_path):
            # A missing reference is reported but does not abort the run.
            print(f"参考文件不存在: {ref_path}", file=sys.stderr)
        else:
            print("\ncomputing reference activations & statistics...")
            ref_acts = evaluator.read_activations(ref_path)
            ref_stats, _ = evaluator.read_statistics(ref_path, ref_acts)
            r = _sigma_diag(ref_stats.sigma)
            print("\n--- 参考协方差 Σ_ref ---")
            print(f" λ_min={r['eig_min']:.3e}, cond={r['cond']:.3e}, effective_rank={r['effective_rank']:.2f}")
            print("\n--- 对比(样本 vs 参考)---")
            print(f" cond 比: sample/ref = {s['cond'] / max(r['cond'], 1e-30):.3f}")
            print(f" effective_rank 比: sample/ref = {s['effective_rank'] / max(r['effective_rank'], 1e-30):.3f}")

    _print_verdict(s, dup_report)


if __name__ == "__main__":
    main()