import os import json import argparse import numpy as np from glob import glob from tqdm import tqdm import re # 仅在非 hit@1 回退慢路径时才会用到 try: from src.eval_utils.metrics import RankingMetrics except Exception: RankingMetrics = None def load_jsonl(path, limit=None): rows = [] with open(path, 'r', encoding='utf-8') as f: for i, line in enumerate(f): if not line.strip(): continue rows.append(json.loads(line)) if limit is not None and len(rows) >= limit: break return rows def pick_mid_tag_from_dir(out_dir, dataset): paths = glob(os.path.join(out_dir, f"{dataset}_qry_layer*")) tags = [] for p in paths: tag = os.path.basename(p).split(f"{dataset}_qry_")[-1] if tag == "layerlast": continue if tag.startswith("layer"): try: n = int(tag.replace("layer", "")) tags.append((n, tag)) except: pass if not tags: return None tags.sort(key=lambda x: x[0]) return tags[0][1] def is_numberlike(v): try: _ = float(v) return np.isfinite(v) except Exception: return False def parse_k_from_metric_key(k: str): m = re.search(r'@(\d+)', k) return int(m.group(1)) if m else None def resolve_metric_value(score_dict: dict, requested_key: str, preferred_key: str | None = None): if not isinstance(score_dict, dict): return None, None numeric_items = {k: v for k, v in score_dict.items() if is_numberlike(v)} if not numeric_items: return None, None req = requested_key.strip().lower() if preferred_key and preferred_key in numeric_items: return float(numeric_items[preferred_key]), preferred_key for k, v in numeric_items.items(): if k.lower() == req: return float(v), k cands = [] for k, v in numeric_items.items(): kl = k.lower() if kl.startswith(req) and '@' in kl: k_val = parse_k_from_metric_key(k) if k_val is not None: cands.append((k_val, k, v)) if cands: cands.sort(key=lambda x: x[0]) _, k_sel, v_sel = cands[-1] return float(v_sel), k_sel for k, v in numeric_items.items(): if req in k.lower(): return float(v), k print(f"[WARN] Cannot resolve metric '{requested_key}'. Available numeric keys: {list(numeric_items.keys())}") return None, None def metric_value_from_file(score_json_path, metric_key, preferred_key: str | None = None): if not os.path.exists(score_json_path): return None, None d = json.load(open(score_json_path, 'r')) val, actual_key = resolve_metric_value(d, metric_key, preferred_key=preferred_key) return val, actual_key def get_query_avg_time(stats_path): if not os.path.exists(stats_path): return None d = json.load(open(stats_path, 'r')) return d.get("inference_times", {}).get("avg_inference_time_per_item_seconds", None) def compute_speedup(coverage, avg_q_time_mid, avg_q_time_last): expected = coverage * avg_q_time_mid + (1.0 - coverage) * avg_q_time_last if expected <= 1e-12: return 1.0 return avg_q_time_last / expected def analyze_fast_hit1(rows, base_metric, avg_q_time_mid, avg_q_time_last, max_drop, grid_points=41): # 仅使用行里已有的字段:mid.top1, last.top1, mid.margin, mid.top1_score N = len(rows) y_mid = np.zeros(N, dtype=np.float32) y_last = np.zeros(N, dtype=np.float32) margins = np.zeros(N, dtype=np.float32) s1s = np.zeros(N, dtype=np.float32) print(f"[FAST][hit@1] Precomputing arrays for {N} queries...") for i, r in enumerate(rows): labels = set(r["label"] if isinstance(r["label"], list) else [r["label"]]) mid_top1 = r["mid"].get("top1", None) last_top1 = r["last"].get("top1", None) y_mid[i] = 1.0 if (mid_top1 in labels) else 0.0 y_last[i] = 1.0 if (last_top1 in labels) else 0.0 margins[i] = float(r["mid"].get("margin", 0.0)) # 取 mid 的 top1 分数(只做一次字典查找) if "cand_scores" in r["mid"] and mid_top1 is not None: s1s[i] = float(r["mid"]["cand_scores"].get(mid_top1, 0.0)) else: s1s[i] = 0.0 # 门控后的 hit@1 = mean( use_mid*y_mid + (1-use_mid)*y_last ) def eval_mask(use_mid_mask: np.ndarray): cov = float(use_mid_mask.mean()) metric = float((use_mid_mask * y_mid + (1.0 - use_mid_mask) * y_last).mean()) speed = compute_speedup(cov, avg_q_time_mid, avg_q_time_last) drop = float(base_metric - metric) return cov, metric, drop, speed # 以分位点作为阈值网格(对分布更友好) taus = np.quantile(margins, np.linspace(0.0, 1.0, grid_points)).tolist() thrs = np.quantile(s1s, np.linspace(0.0, 1.0, grid_points)).tolist() # 规则1:margin>=tau margin_sweep = [] best_margin = None print(f"[FAST][hit@1] Sweep margin, points={len(taus)}") for tau in taus: mask = (margins >= tau).astype(np.float32) cov, metric, drop, speed = eval_mask(mask) rec = {"rule": "margin>=tau", "tau": float(tau), "coverage": cov, "metric": metric, "metric_drop": drop, "speedup": speed} margin_sweep.append(rec) if drop <= max_drop and (best_margin is None or speed > best_margin["speedup"]): best_margin = rec # 规则2:top1 score >= t s1_sweep = [] best_s1 = None print(f"[FAST][hit@1] Sweep top1-score, points={len(thrs)}") for t in thrs: mask = (s1s >= t).astype(np.float32) cov, metric, drop, speed = eval_mask(mask) rec = {"rule": "s1>=t", "t": float(t), "coverage": cov, "metric": metric, "metric_drop": drop, "speedup": speed} s1_sweep.append(rec) if drop <= max_drop and (best_s1 is None or speed > best_s1["speedup"]): best_s1 = rec # 规则3:margin>=tau OR s1>=t joint_sweep = [] best_joint = None print(f"[FAST][hit@1] Sweep joint (margin OR s1), grid={len(taus)}x{len(thrs)}") for tau in tqdm(taus, desc="[FAST] joint grid"): m_mask = (margins >= tau) for t in thrs: s_mask = (s1s >= t) mask = (m_mask | s_mask).astype(np.float32) cov, metric, drop, speed = eval_mask(mask) rec = {"rule": "(margin>=tau) OR (s1>=t)", "tau": float(tau), "t": float(t), "coverage": cov, "metric": metric, "metric_drop": drop, "speedup": speed} joint_sweep.append(rec) if drop <= max_drop and (best_joint is None or speed > best_joint["speedup"]): best_joint = rec # 规则4:Conformal(margin) conformal_recs = [] print("[FAST][hit@1] Conformal (by margin) ...") qids = np.array([int(r.get("qid", i)) for i, r in enumerate(rows)], dtype=np.int64) for delta in [0.01, 0.02, 0.05, 0.1]: calib = (qids % 5 == 0) if not calib.any(): continue alpha = 1.0 - margins q = float(np.quantile(alpha[calib], 1.0 - delta)) tau = 1.0 - q mask = (margins >= tau).astype(np.float32) cov, metric, drop, speed = eval_mask(mask) conformal_recs.append({ "rule": "conformal_margin", "delta": float(delta), "tau": float(tau), "coverage": cov, "metric": metric, "metric_drop": drop, "speedup": speed }) # 推荐:在 drop<=max_drop 中选 speedup 最大 cands = [x for x in [best_margin, best_s1, best_joint] if x is not None] recommended = max(cands, key=lambda x: x["speedup"]) if cands else None return { "margin_sweep": margin_sweep, "top1_sweep": s1_sweep, "margin_or_top1_sweep": joint_sweep, "conformal": conformal_recs, "recommended": recommended } def main(): parser = argparse.ArgumentParser() parser.add_argument("--dir", required=True, help="encode_output_path") parser.add_argument("--dataset", required=True, help="dataset name (prefix)") parser.add_argument("--metric", default="mrr", help="metric key or prefix (e.g., mrr, hit, hit@1, ndcg)") parser.add_argument("--max_drop", type=float, default=0.005, help="allowed metric drop vs last") parser.add_argument("--grid_points", type=int, default=41, help="grid size for thresholds") parser.add_argument("--with_entropy", action="store_true", help="enable entropy pathway (slow)") parser.add_argument("--save", action="store_true", help="save sweep and recommendations") args = parser.parse_args() out_dir = args.dir dataset = args.dataset metric_key = args.metric max_drop = args.max_drop combined_path = os.path.join(out_dir, f"{dataset}_score_details_both_layers.jsonl") if not os.path.exists(combined_path): raise FileNotFoundError(f"Combined file not found: {combined_path}. 请先用评测脚本生成该文件。") last_tag = "layerlast" mid_tag = pick_mid_tag_from_dir(out_dir, dataset) if mid_tag is None: raise RuntimeError(f"未找到 {dataset}_qry_layer*(非 layerlast)。请确认已经生成中间层 embedding。") # 基线指标(last) last_score_path = os.path.join(out_dir, f"{dataset}_score_{last_tag}.json") base_metric, resolved_metric_key = metric_value_from_file(last_score_path, metric_key, preferred_key=None) if base_metric is None: # 回退:尝试从 combined 现算(慢路径才用) if RankingMetrics is None: raise RuntimeError(f"无法解析基线指标 '{metric_key}' 且 RankingMetrics 不可用。") rows_tmp = load_jsonl(combined_path) metrics = RankingMetrics([metric_key]) pred_dicts = [] for r in rows_tmp: ranked = sorted(r["last"]["cand_scores"].items(), key=lambda x: -x[1]) labels = r["label"] if isinstance(r["label"], list) else [r["label"]] pred_dicts.append({"prediction": [cid for cid, _ in ranked], "label": labels, "rel_scores": None}) score_fallback = metrics.evaluate(pred_dicts) base_metric, resolved_metric_key = resolve_metric_value(score_fallback, metric_key, preferred_key=None) if base_metric is None: raise RuntimeError(f"无法解析基线指标 '{metric_key}',请查看 {last_score_path} 中的可用数值键。") print(f"[INFO] Using metric: requested='{metric_key}', resolved='{resolved_metric_key}', base={base_metric:.6f}") # 平均时延(query侧) qry_stats_last = os.path.join(out_dir, f"{dataset}_qry_inference_stats_{last_tag}.json") qry_stats_mid = os.path.join(out_dir, f"{dataset}_qry_inference_stats_{mid_tag}.json") avg_q_time_last = get_query_avg_time(qry_stats_last) or 1.0 avg_q_time_mid = get_query_avg_time(qry_stats_mid) or 0.5 # 命中 hit@1 的极速路径 if resolved_metric_key.lower() in {"hit@1", "hit"}: print("[FAST] Detected hit@1 metric. Running fast analyzer...") rows = load_jsonl(combined_path) # 只加载一次 result = analyze_fast_hit1( rows=rows, base_metric=base_metric, avg_q_time_mid=avg_q_time_mid, avg_q_time_last=avg_q_time_last, max_drop=max_drop, grid_points=args.grid_points ) print("\n=== Baseline ===") print(f"- {resolved_metric_key}@last: {base_metric:.6f}") print(f"- avg_query_time: last={avg_q_time_last:.6f}s, mid={avg_q_time_mid:.6f}s") print("\n=== Best under max_drop=%.4f ===" % max_drop) for name in ["margin_sweep", "top1_sweep", "margin_or_top1_sweep"]: sweep = result[name] best = None for rec in sweep: if rec["metric_drop"] <= max_drop: if best is None or rec["speedup"] > best["speedup"]: best = rec if best: params = {k: best[k] for k in ["tau","t"] if k in best} print(f"- {name.replace('_','/')}: speedup={best['speedup']:.3f}, coverage={best['coverage']:.3f}, " f"{resolved_metric_key}={best['metric']:.6f}, drop={best['metric_drop']:.6f}, params={json.dumps(params)}") else: print(f"- {name.replace('_','/')}: no feasible point within drop <= {max_drop}") if result["recommended"]: print("\n=== Recommended config ===") print(json.dumps(result["recommended"], indent=2)) else: print("\nNo recommendation satisfies the drop constraint; consider relaxing --max_drop or increasing grid.") if args.save: with open(os.path.join(out_dir, f"{dataset}_early_exit_sweep_summary.json"), "w") as f: json.dump({ "margin_sweep": result["margin_sweep"], "top1_sweep": result["top1_sweep"], "margin_or_top1_sweep": result["margin_or_top1_sweep"], }, f, indent=2) with open(os.path.join(out_dir, f"{dataset}_early_exit_recommendations.json"), "w") as f: json.dump({ "dataset": dataset, "metric_key_requested": metric_key, "metric_key_resolved": resolved_metric_key, "base_metric_last": base_metric, "avg_query_time_last": avg_q_time_last, "avg_query_time_mid": avg_q_time_mid, "mid_tag": mid_tag, "last_tag": "layerlast", "conformal": result["conformal"], "recommended": result["recommended"] }, f, indent=2) return # 非 hit@1 情况:回退(提示会慢) print("[WARN] 非 hit@1 指标将走慢路径,可能耗时较长。建议先用 --metric hit@1 验证早停思路。") if RankingMetrics is None: raise RuntimeError("RankingMetrics 未导入,无法回退到通用慢路径。") # 慢路径:加载 rows 并用 RankingMetrics 多次评估(不做熵/温度,仍会慢) rows = load_jsonl(combined_path) metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] def build_pred_dicts(use_mid_mask): preds = [] for r, use_mid in zip(rows, use_mid_mask): sims = r["mid"]["cand_scores"] if use_mid else r["last"]["cand_scores"] ranked = sorted(sims.items(), key=lambda x: -x[1]) labels = r["label"] if isinstance(r["label"], list) else [r["label"]] preds.append({"prediction": [cid for cid, _ in ranked], "label": labels, "rel_scores": None}) return preds metrics = RankingMetrics(metrics_to_report) # 预计算 margin / s1(慢路径也尽量少排序) margins = np.array([float(r["mid"]["margin"]) for r in rows], dtype=np.float32) s1s = np.array([ float(r["mid"]["cand_scores"].get(r["mid"]["top1"], 0.0)) if "cand_scores" in r["mid"] else 0.0 for r in rows ], dtype=np.float32) taus = np.quantile(margins, np.linspace(0.0, 1.0, args.grid_points)).tolist() thrs = np.quantile(s1s, np.linspace(0.0, 1.0, args.grid_points)).tolist() # margin for tau in tqdm(taus, desc="[SLOW] margin grid"): mask = (margins >= tau).tolist() score = metrics.evaluate(build_pred_dicts(mask)) # 可按需打印或保存 # top1 for t in tqdm(thrs, desc="[SLOW] top1 grid"): mask = (s1s >= t).tolist() score = metrics.evaluate(build_pred_dicts(mask)) print("Done.") if __name__ == "__main__": main() # 运行示例: # python /home/v-menggao/code/VLM2Vec/analyze_early_exit.py --dir /home/v-menggao/code/VLM2Vec/~/exps/Qwen2vl_2B.image_qry_20_none+cand_20_none_0.1/VLM2Vec-V2.0-Qwen2VL-2B/image_retrival --dataset CIRR --metric hit@1 --max_drop 0.005 --topk 200