| import os |
| import json |
| import argparse |
| import numpy as np |
| from glob import glob |
| from tqdm import tqdm |
| import re |
|
|
| |
# Optional project dependency: RankingMetrics backs the slow evaluation
# path in main(). When the package is unavailable (e.g. running this
# script standalone), fall back to None and fail lazily at use time.
try:
    from src.eval_utils.metrics import RankingMetrics
except Exception:
    RankingMetrics = None
|
|
def load_jsonl(path, limit=None):
    """Read a JSONL file into a list of parsed records.

    Blank lines are skipped. If *limit* is given, stop once that many
    records have been collected.
    """
    records = []
    with open(path, 'r', encoding='utf-8') as fh:
        for raw in fh:
            stripped = raw.strip()
            if not stripped:
                continue
            records.append(json.loads(stripped))
            if limit is not None and len(records) >= limit:
                break
    return records
|
|
def pick_mid_tag_from_dir(out_dir, dataset):
    """Find the shallowest intermediate-layer tag for *dataset* in *out_dir*.

    Scans paths named "<dataset>_qry_layer<N>", ignores the final layer
    ("layerlast"), and returns the tag with the smallest layer number
    (e.g. "layer8"), or None when no intermediate layer is present.
    """
    paths = glob(os.path.join(out_dir, f"{dataset}_qry_layer*"))
    tags = []
    for p in paths:
        tag = os.path.basename(p).split(f"{dataset}_qry_")[-1]
        if tag == "layerlast":
            continue
        if tag.startswith("layer"):
            # Skip tags whose suffix is not a plain integer (e.g. "layerX").
            # Bug fix: the original used a bare `except:`, which would also
            # hide unrelated errors such as KeyboardInterrupt.
            try:
                n = int(tag.replace("layer", ""))
            except ValueError:
                continue
            tags.append((n, tag))
    if not tags:
        return None
    tags.sort(key=lambda x: x[0])
    return tags[0][1]
|
|
def is_numberlike(v):
    """Return True when *v* can be interpreted as a finite number.

    Accepts ints/floats as well as numeric strings; rejects NaN/inf, None,
    and anything float() cannot convert.
    """
    try:
        # Bug fix: isfinite must be applied to the converted value. The
        # original called np.isfinite(v) on the raw input, which raised
        # TypeError for numeric strings like "0.5" and thus returned False.
        return bool(np.isfinite(float(v)))
    except (TypeError, ValueError):
        return False
|
|
def parse_k_from_metric_key(k: str):
    """Extract the cutoff from a metric key like "hit@10" -> 10 (None if absent)."""
    match = re.search(r'@(\d+)', k)
    if match is None:
        return None
    return int(match.group(1))
|
|
def resolve_metric_value(score_dict: dict, requested_key: str, preferred_key: str | None = None):
    """Map a requested metric name onto an actual key of *score_dict*.

    Resolution order: the explicitly preferred key, then an exact
    case-insensitive match, then the prefix match carrying the largest
    "@k" cutoff, and finally any key containing the request as a
    substring. Returns (value, resolved_key), or (None, None) when no
    numeric key matches.
    """
    if not isinstance(score_dict, dict):
        return None, None
    numeric_items = {key: val for key, val in score_dict.items() if is_numberlike(val)}
    if not numeric_items:
        return None, None

    wanted = requested_key.strip().lower()

    # 1) Caller-preferred key wins outright when present.
    if preferred_key and preferred_key in numeric_items:
        return float(numeric_items[preferred_key]), preferred_key

    # 2) Exact (case-insensitive) match.
    for key, val in numeric_items.items():
        if key.lower() == wanted:
            return float(val), key

    # 3) Prefix match on "@k"-style keys; prefer the largest cutoff.
    at_matches = []
    for key, val in numeric_items.items():
        lowered = key.lower()
        if lowered.startswith(wanted) and '@' in lowered:
            cutoff = parse_k_from_metric_key(key)
            if cutoff is not None:
                at_matches.append((cutoff, key, val))
    if at_matches:
        at_matches.sort(key=lambda item: item[0])
        _, best_key, best_val = at_matches[-1]
        return float(best_val), best_key

    # 4) Loose substring match as a last resort.
    for key, val in numeric_items.items():
        if wanted in key.lower():
            return float(val), key

    print(f"[WARN] Cannot resolve metric '{requested_key}'. Available numeric keys: {list(numeric_items.keys())}")
    return None, None
|
|
| def metric_value_from_file(score_json_path, metric_key, preferred_key: str | None = None): |
| if not os.path.exists(score_json_path): |
| return None, None |
| d = json.load(open(score_json_path, 'r')) |
| val, actual_key = resolve_metric_value(d, metric_key, preferred_key=preferred_key) |
| return val, actual_key |
|
|
def get_query_avg_time(stats_path):
    """Return the average per-item inference time recorded in *stats_path*.

    Reads "inference_times.avg_inference_time_per_item_seconds" from the
    JSON stats file; returns None when the file or either field is absent.
    """
    if not os.path.exists(stats_path):
        return None
    # Close the file deterministically (the original left `open()` unclosed).
    with open(stats_path, 'r') as f:
        d = json.load(f)
    return d.get("inference_times", {}).get("avg_inference_time_per_item_seconds", None)
|
|
def compute_speedup(coverage, avg_q_time_mid, avg_q_time_last):
    """Speedup of a mixed mid/last policy relative to always using the last layer.

    *coverage* is the fraction of queries answered by the (cheaper) mid
    layer; the remainder pays the last-layer cost. Returns 1.0 when the
    expected time is (near-)zero to avoid dividing by zero.
    """
    expected_time = coverage * avg_q_time_mid + (1.0 - coverage) * avg_q_time_last
    return 1.0 if expected_time <= 1e-12 else avg_q_time_last / expected_time
|
|
def analyze_fast_hit1(rows, base_metric, avg_q_time_mid, avg_q_time_last, max_drop, grid_points=41):
    """Fast early-exit threshold analysis specialized to hit@1.

    Per query, hit@1 is a 0/1 outcome that is already known for both the
    mid layer and the last layer, so any exit rule can be scored as a
    vectorized mixture of the two outcome arrays — no re-ranking needed.

    Sweeps three rule families over quantile grids:
      * margin >= tau            (mid-layer top1-vs-top2 margin)
      * s1 >= t                  (mid-layer top-1 score)
      * (margin >= tau) OR (s1 >= t)
    plus a conformal calibration of the margin threshold, and returns all
    sweep records together with the feasible rule (metric drop <= max_drop)
    that maximizes the estimated speedup.

    rows: dicts with "label" and per-layer dicts "mid"/"last" holding
        "top1", "margin" and optionally "cand_scores" ({cand_id: score}).
    base_metric: hit@1 of the always-use-last-layer baseline.
    avg_q_time_mid / avg_q_time_last: average per-query inference times.
    Returns a dict with keys "margin_sweep", "top1_sweep",
    "margin_or_top1_sweep", "conformal", "recommended".
    """
    N = len(rows)
    y_mid = np.zeros(N, dtype=np.float32)    # 1.0 iff mid-layer top-1 is a label
    y_last = np.zeros(N, dtype=np.float32)   # 1.0 iff last-layer top-1 is a label
    margins = np.zeros(N, dtype=np.float32)  # mid-layer confidence margin
    s1s = np.zeros(N, dtype=np.float32)      # mid-layer top-1 score

    print(f"[FAST][hit@1] Precomputing arrays for {N} queries...")
    for i, r in enumerate(rows):
        labels = set(r["label"] if isinstance(r["label"], list) else [r["label"]])
        mid_top1 = r["mid"].get("top1", None)
        last_top1 = r["last"].get("top1", None)
        y_mid[i] = 1.0 if (mid_top1 in labels) else 0.0
        y_last[i] = 1.0 if (last_top1 in labels) else 0.0
        margins[i] = float(r["mid"].get("margin", 0.0))
        # Top-1 score defaults to 0.0 when candidate scores are unavailable.
        if "cand_scores" in r["mid"] and mid_top1 is not None:
            s1s[i] = float(r["mid"]["cand_scores"].get(mid_top1, 0.0))
        else:
            s1s[i] = 0.0

    def eval_mask(use_mid_mask: np.ndarray):
        # use_mid_mask: float {0,1} vector; 1.0 -> answer with the mid layer.
        cov = float(use_mid_mask.mean())
        metric = float((use_mid_mask * y_mid + (1.0 - use_mid_mask) * y_last).mean())
        speed = compute_speedup(cov, avg_q_time_mid, avg_q_time_last)
        drop = float(base_metric - metric)
        return cov, metric, drop, speed

    # Threshold candidates from empirical quantiles so the grid adapts to
    # the observed score distributions.
    taus = np.quantile(margins, np.linspace(0.0, 1.0, grid_points)).tolist()
    thrs = np.quantile(s1s, np.linspace(0.0, 1.0, grid_points)).tolist()

    # --- Rule 1: exit early when the margin is large enough. ---
    margin_sweep = []
    best_margin = None
    print(f"[FAST][hit@1] Sweep margin, points={len(taus)}")
    for tau in taus:
        mask = (margins >= tau).astype(np.float32)
        cov, metric, drop, speed = eval_mask(mask)
        rec = {"rule": "margin>=tau", "tau": float(tau), "coverage": cov, "metric": metric, "metric_drop": drop, "speedup": speed}
        margin_sweep.append(rec)
        if drop <= max_drop and (best_margin is None or speed > best_margin["speedup"]):
            best_margin = rec

    # --- Rule 2: exit early when the top-1 score is high enough. ---
    s1_sweep = []
    best_s1 = None
    print(f"[FAST][hit@1] Sweep top1-score, points={len(thrs)}")
    for t in thrs:
        mask = (s1s >= t).astype(np.float32)
        cov, metric, drop, speed = eval_mask(mask)
        rec = {"rule": "s1>=t", "t": float(t), "coverage": cov, "metric": metric, "metric_drop": drop, "speedup": speed}
        s1_sweep.append(rec)
        if drop <= max_drop and (best_s1 is None or speed > best_s1["speedup"]):
            best_s1 = rec

    # --- Rule 3: exit when either condition holds (full 2-D grid). ---
    joint_sweep = []
    best_joint = None
    print(f"[FAST][hit@1] Sweep joint (margin OR s1), grid={len(taus)}x{len(thrs)}")
    for tau in tqdm(taus, desc="[FAST] joint grid"):
        m_mask = (margins >= tau)
        for t in thrs:
            s_mask = (s1s >= t)
            mask = (m_mask | s_mask).astype(np.float32)
            cov, metric, drop, speed = eval_mask(mask)
            rec = {"rule": "(margin>=tau) OR (s1>=t)", "tau": float(tau), "t": float(t), "coverage": cov, "metric": metric, "metric_drop": drop, "speedup": speed}
            joint_sweep.append(rec)
            if drop <= max_drop and (best_joint is None or speed > best_joint["speedup"]):
                best_joint = rec

    # --- Conformal calibration of the margin threshold. ---
    conformal_recs = []
    print("[FAST][hit@1] Conformal (by margin) ...")
    qids = np.array([int(r.get("qid", i)) for i, r in enumerate(rows)], dtype=np.int64)
    # Deterministic ~20% calibration split. Fix: hoisted out of the delta
    # loop (the original recomputed it per delta although it never changes);
    # alpha below is likewise loop-invariant.
    calib = (qids % 5 == 0)
    alpha = 1.0 - margins  # nonconformity score: small margin -> high alpha
    for delta in [0.01, 0.02, 0.05, 0.1]:
        if not calib.any():
            continue
        # tau is ~ the delta-quantile of the calibration margins.
        q = float(np.quantile(alpha[calib], 1.0 - delta))
        tau = 1.0 - q
        mask = (margins >= tau).astype(np.float32)
        cov, metric, drop, speed = eval_mask(mask)
        conformal_recs.append({
            "rule": "conformal_margin",
            "delta": float(delta),
            "tau": float(tau),
            "coverage": cov,
            "metric": metric,
            "metric_drop": drop,
            "speedup": speed
        })

    # Recommend the feasible rule with the highest speedup; conformal
    # records are reported but not part of the recommendation.
    cands = [x for x in [best_margin, best_s1, best_joint] if x is not None]
    recommended = max(cands, key=lambda x: x["speedup"]) if cands else None

    return {
        "margin_sweep": margin_sweep,
        "top1_sweep": s1_sweep,
        "margin_or_top1_sweep": joint_sweep,
        "conformal": conformal_recs,
        "recommended": recommended
    }
|
|
def main():
    """CLI driver: analyze early-exit (mid-layer vs last-layer) trade-offs.

    Expects the evaluation pipeline to have produced, under --dir:
      * <dataset>_score_details_both_layers.jsonl  (per-query details)
      * <dataset>_score_layerlast.json             (baseline scores)
      * <dataset>_qry_inference_stats_layer*.json  (timing stats)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dir", required=True, help="encode_output_path")
    parser.add_argument("--dataset", required=True, help="dataset name (prefix)")
    parser.add_argument("--metric", default="mrr", help="metric key or prefix (e.g., mrr, hit, hit@1, ndcg)")
    parser.add_argument("--max_drop", type=float, default=0.005, help="allowed metric drop vs last")
    parser.add_argument("--grid_points", type=int, default=41, help="grid size for thresholds")
    parser.add_argument("--with_entropy", action="store_true", help="enable entropy pathway (slow)")
    parser.add_argument("--save", action="store_true", help="save sweep and recommendations")
    args = parser.parse_args()

    out_dir = args.dir
    dataset = args.dataset
    metric_key = args.metric
    max_drop = args.max_drop

    # Required input: per-query score details for both layers.
    combined_path = os.path.join(out_dir, f"{dataset}_score_details_both_layers.jsonl")
    if not os.path.exists(combined_path):
        raise FileNotFoundError(f"Combined file not found: {combined_path}. 请先用评测脚本生成该文件。")

    # Locate the intermediate-layer tag (e.g. "layer8"); "layerlast" is the baseline.
    last_tag = "layerlast"
    mid_tag = pick_mid_tag_from_dir(out_dir, dataset)
    if mid_tag is None:
        raise RuntimeError(f"未找到 {dataset}_qry_layer*(非 layerlast)。请确认已经生成中间层 embedding。")

    # Baseline metric from the last-layer score file. If the requested key
    # cannot be resolved there, recompute it from the combined details via
    # RankingMetrics (when available).
    last_score_path = os.path.join(out_dir, f"{dataset}_score_{last_tag}.json")
    base_metric, resolved_metric_key = metric_value_from_file(last_score_path, metric_key, preferred_key=None)
    if base_metric is None:
        if RankingMetrics is None:
            raise RuntimeError(f"无法解析基线指标 '{metric_key}' 且 RankingMetrics 不可用。")
        rows_tmp = load_jsonl(combined_path)
        metrics = RankingMetrics([metric_key])
        pred_dicts = []
        for r in rows_tmp:
            # Rank candidates by descending last-layer similarity.
            ranked = sorted(r["last"]["cand_scores"].items(), key=lambda x: -x[1])
            labels = r["label"] if isinstance(r["label"], list) else [r["label"]]
            pred_dicts.append({"prediction": [cid for cid, _ in ranked], "label": labels, "rel_scores": None})
        score_fallback = metrics.evaluate(pred_dicts)
        base_metric, resolved_metric_key = resolve_metric_value(score_fallback, metric_key, preferred_key=None)
        if base_metric is None:
            raise RuntimeError(f"无法解析基线指标 '{metric_key}',请查看 {last_score_path} 中的可用数值键。")

    print(f"[INFO] Using metric: requested='{metric_key}', resolved='{resolved_metric_key}', base={base_metric:.6f}")

    # Average per-query inference times; the `or` fallbacks keep the
    # speedup math well-defined when a stats file is missing.
    qry_stats_last = os.path.join(out_dir, f"{dataset}_qry_inference_stats_{last_tag}.json")
    qry_stats_mid = os.path.join(out_dir, f"{dataset}_qry_inference_stats_{mid_tag}.json")
    avg_q_time_last = get_query_avg_time(qry_stats_last) or 1.0
    avg_q_time_mid = get_query_avg_time(qry_stats_mid) or 0.5

    # Fast vectorized path for hit@1-style metrics.
    if resolved_metric_key.lower() in {"hit@1", "hit"}:
        print("[FAST] Detected hit@1 metric. Running fast analyzer...")
        rows = load_jsonl(combined_path)
        result = analyze_fast_hit1(
            rows=rows,
            base_metric=base_metric,
            avg_q_time_mid=avg_q_time_mid,
            avg_q_time_last=avg_q_time_last,
            max_drop=max_drop,
            grid_points=args.grid_points
        )

        print("\n=== Baseline ===")
        print(f"- {resolved_metric_key}@last: {base_metric:.6f}")
        print(f"- avg_query_time: last={avg_q_time_last:.6f}s, mid={avg_q_time_mid:.6f}s")

        print("\n=== Best under max_drop=%.4f ===" % max_drop)
        for name in ["margin_sweep", "top1_sweep", "margin_or_top1_sweep"]:
            sweep = result[name]
            # Best feasible point: drop within budget, maximum speedup.
            best = None
            for rec in sweep:
                if rec["metric_drop"] <= max_drop:
                    if best is None or rec["speedup"] > best["speedup"]:
                        best = rec
            if best:
                params = {k: best[k] for k in ["tau","t"] if k in best}
                print(f"- {name.replace('_','/')}: speedup={best['speedup']:.3f}, coverage={best['coverage']:.3f}, "
                      f"{resolved_metric_key}={best['metric']:.6f}, drop={best['metric_drop']:.6f}, params={json.dumps(params)}")
            else:
                print(f"- {name.replace('_','/')}: no feasible point within drop <= {max_drop}")

        if result["recommended"]:
            print("\n=== Recommended config ===")
            print(json.dumps(result["recommended"], indent=2))
        else:
            print("\nNo recommendation satisfies the drop constraint; consider relaxing --max_drop or increasing grid.")

        if args.save:
            # Persist the raw sweeps and the recommendation as separate files.
            with open(os.path.join(out_dir, f"{dataset}_early_exit_sweep_summary.json"), "w") as f:
                json.dump({
                    "margin_sweep": result["margin_sweep"],
                    "top1_sweep": result["top1_sweep"],
                    "margin_or_top1_sweep": result["margin_or_top1_sweep"],
                }, f, indent=2)
            with open(os.path.join(out_dir, f"{dataset}_early_exit_recommendations.json"), "w") as f:
                json.dump({
                    "dataset": dataset,
                    "metric_key_requested": metric_key,
                    "metric_key_resolved": resolved_metric_key,
                    "base_metric_last": base_metric,
                    "avg_query_time_last": avg_q_time_last,
                    "avg_query_time_mid": avg_q_time_mid,
                    "mid_tag": mid_tag,
                    "last_tag": "layerlast",
                    "conformal": result["conformal"],
                    "recommended": result["recommended"]
                }, f, indent=2)
        return

    # Slow generic path: re-evaluates full ranking metrics per threshold.
    print("[WARN] 非 hit@1 指标将走慢路径,可能耗时较长。建议先用 --metric hit@1 验证早停思路。")
    if RankingMetrics is None:
        raise RuntimeError("RankingMetrics 未导入,无法回退到通用慢路径。")

    rows = load_jsonl(combined_path)
    metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]

    def build_pred_dicts(use_mid_mask):
        # Build RankingMetrics inputs, choosing mid- or last-layer candidate
        # scores per query according to the boolean mask.
        preds = []
        for r, use_mid in zip(rows, use_mid_mask):
            sims = r["mid"]["cand_scores"] if use_mid else r["last"]["cand_scores"]
            ranked = sorted(sims.items(), key=lambda x: -x[1])
            labels = r["label"] if isinstance(r["label"], list) else [r["label"]]
            preds.append({"prediction": [cid for cid, _ in ranked], "label": labels, "rel_scores": None})
        return preds

    metrics = RankingMetrics(metrics_to_report)

    # Quantile grids over mid-layer margin and top-1 score.
    margins = np.array([float(r["mid"]["margin"]) for r in rows], dtype=np.float32)
    s1s = np.array([
        float(r["mid"]["cand_scores"].get(r["mid"]["top1"], 0.0)) if "cand_scores" in r["mid"] else 0.0
        for r in rows
    ], dtype=np.float32)
    taus = np.quantile(margins, np.linspace(0.0, 1.0, args.grid_points)).tolist()
    thrs = np.quantile(s1s, np.linspace(0.0, 1.0, args.grid_points)).tolist()

    # NOTE(review): both slow-path sweeps compute `score` but never record
    # or report it — this path looks unfinished; confirm intended output.
    for tau in tqdm(taus, desc="[SLOW] margin grid"):
        mask = (margins >= tau).tolist()
        score = metrics.evaluate(build_pred_dicts(mask))

    for t in tqdm(thrs, desc="[SLOW] top1 grid"):
        mask = (s1s >= t).tolist()
        score = metrics.evaluate(build_pred_dicts(mask))

    print("Done.")
|
|
| if __name__ == "__main__": |
| main() |
|
|
| |
| |