# code_SAS_VLM2Vec/analyze_early_exit.py
import os
import json
import argparse
import numpy as np
from glob import glob
from tqdm import tqdm
import re
# Only needed on the slow fallback path for non-hit@1 metrics
try:
from src.eval_utils.metrics import RankingMetrics
except Exception:
RankingMetrics = None
def load_jsonl(path, limit=None):
rows = []
with open(path, 'r', encoding='utf-8') as f:
for i, line in enumerate(f):
if not line.strip():
continue
rows.append(json.loads(line))
if limit is not None and len(rows) >= limit:
break
return rows
def pick_mid_tag_from_dir(out_dir, dataset):
paths = glob(os.path.join(out_dir, f"{dataset}_qry_layer*"))
tags = []
for p in paths:
tag = os.path.basename(p).split(f"{dataset}_qry_")[-1]
if tag == "layerlast":
continue
if tag.startswith("layer"):
try:
n = int(tag.replace("layer", ""))
tags.append((n, tag))
            except ValueError:
                pass
if not tags:
return None
tags.sort(key=lambda x: x[0])
return tags[0][1]
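# Illustrative example (hypothetical filenames): if out_dir contains
# "CIRR_qry_layer8" and "CIRR_qry_layer16" next to "CIRR_qry_layerlast",
# pick_mid_tag_from_dir(out_dir, "CIRR") returns "layer8", i.e. the earliest
# available intermediate layer.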
def is_numberlike(v):
    try:
        # Convert first, then test the converted value (so numeric strings pass too)
        return np.isfinite(float(v))
    except Exception:
        return False
def parse_k_from_metric_key(k: str):
m = re.search(r'@(\d+)', k)
return int(m.group(1)) if m else None
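# e.g. parse_k_from_metric_key("ndcg@10") -> 10; parse_k_from_metric_key("mrr") -> None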
def resolve_metric_value(score_dict: dict, requested_key: str, preferred_key: str | None = None):
if not isinstance(score_dict, dict):
return None, None
numeric_items = {k: v for k, v in score_dict.items() if is_numberlike(v)}
if not numeric_items:
return None, None
req = requested_key.strip().lower()
if preferred_key and preferred_key in numeric_items:
return float(numeric_items[preferred_key]), preferred_key
for k, v in numeric_items.items():
if k.lower() == req:
return float(v), k
cands = []
for k, v in numeric_items.items():
kl = k.lower()
if kl.startswith(req) and '@' in kl:
k_val = parse_k_from_metric_key(k)
if k_val is not None:
cands.append((k_val, k, v))
if cands:
cands.sort(key=lambda x: x[0])
_, k_sel, v_sel = cands[-1]
return float(v_sel), k_sel
for k, v in numeric_items.items():
if req in k.lower():
return float(v), k
print(f"[WARN] Cannot resolve metric '{requested_key}'. Available numeric keys: {list(numeric_items.keys())}")
return None, None
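# Resolution order: (1) preferred_key when present, (2) exact case-insensitive
# match, (3) prefix match with an @k suffix, keeping the largest k, (4) substring
# match. Illustrative call with made-up scores:
#   resolve_metric_value({"hit@1": 0.71, "hit@5": 0.83}, "hit") -> (0.83, "hit@5")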
def metric_value_from_file(score_json_path, metric_key, preferred_key: str | None = None):
if not os.path.exists(score_json_path):
return None, None
    with open(score_json_path, 'r') as f:
        d = json.load(f)
val, actual_key = resolve_metric_value(d, metric_key, preferred_key=preferred_key)
return val, actual_key
def get_query_avg_time(stats_path):
if not os.path.exists(stats_path):
return None
    with open(stats_path, 'r') as f:
        d = json.load(f)
return d.get("inference_times", {}).get("avg_inference_time_per_item_seconds", None)
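# Expected stats-file shape, inferred from the lookup above (the value is
# illustrative; other fields may exist but are not required):
#   {"inference_times": {"avg_inference_time_per_item_seconds": 0.012, ...}, ...}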
def compute_speedup(coverage, avg_q_time_mid, avg_q_time_last):
expected = coverage * avg_q_time_mid + (1.0 - coverage) * avg_q_time_last
if expected <= 1e-12:
return 1.0
return avg_q_time_last / expected
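# Worked example of the latency model above: with coverage=0.6,
# avg_q_time_mid=0.5s and avg_q_time_last=1.0s, the expected per-query time is
# 0.6*0.5 + 0.4*1.0 = 0.7s, so speedup = 1.0 / 0.7 ≈ 1.43x versus always
# running the last layer.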
def analyze_fast_hit1(rows, base_metric, avg_q_time_mid, avg_q_time_last, max_drop, grid_points=41):
    # Uses only fields already present in each row: mid.top1, last.top1, mid.margin, and the mid top-1 score from mid.cand_scores
N = len(rows)
y_mid = np.zeros(N, dtype=np.float32)
y_last = np.zeros(N, dtype=np.float32)
margins = np.zeros(N, dtype=np.float32)
s1s = np.zeros(N, dtype=np.float32)
print(f"[FAST][hit@1] Precomputing arrays for {N} queries...")
for i, r in enumerate(rows):
labels = set(r["label"] if isinstance(r["label"], list) else [r["label"]])
mid_top1 = r["mid"].get("top1", None)
last_top1 = r["last"].get("top1", None)
y_mid[i] = 1.0 if (mid_top1 in labels) else 0.0
y_last[i] = 1.0 if (last_top1 in labels) else 0.0
margins[i] = float(r["mid"].get("margin", 0.0))
        # Look up the mid top-1 score (a single dict lookup)
if "cand_scores" in r["mid"] and mid_top1 is not None:
s1s[i] = float(r["mid"]["cand_scores"].get(mid_top1, 0.0))
else:
s1s[i] = 0.0
    # Gated hit@1 = mean( use_mid*y_mid + (1-use_mid)*y_last )
def eval_mask(use_mid_mask: np.ndarray):
cov = float(use_mid_mask.mean())
metric = float((use_mid_mask * y_mid + (1.0 - use_mid_mask) * y_last).mean())
speed = compute_speedup(cov, avg_q_time_mid, avg_q_time_last)
drop = float(base_metric - metric)
return cov, metric, drop, speed
    # Use quantiles of each score as the threshold grid (adapts to the distribution)
taus = np.quantile(margins, np.linspace(0.0, 1.0, grid_points)).tolist()
thrs = np.quantile(s1s, np.linspace(0.0, 1.0, grid_points)).tolist()
    # Rule 1: exit early when margin >= tau
margin_sweep = []
best_margin = None
print(f"[FAST][hit@1] Sweep margin, points={len(taus)}")
for tau in taus:
mask = (margins >= tau).astype(np.float32)
cov, metric, drop, speed = eval_mask(mask)
rec = {"rule": "margin>=tau", "tau": float(tau), "coverage": cov, "metric": metric, "metric_drop": drop, "speedup": speed}
margin_sweep.append(rec)
if drop <= max_drop and (best_margin is None or speed > best_margin["speedup"]):
best_margin = rec
    # Rule 2: exit early when the top-1 score s1 >= t
s1_sweep = []
best_s1 = None
print(f"[FAST][hit@1] Sweep top1-score, points={len(thrs)}")
for t in thrs:
mask = (s1s >= t).astype(np.float32)
cov, metric, drop, speed = eval_mask(mask)
rec = {"rule": "s1>=t", "t": float(t), "coverage": cov, "metric": metric, "metric_drop": drop, "speedup": speed}
s1_sweep.append(rec)
if drop <= max_drop and (best_s1 is None or speed > best_s1["speedup"]):
best_s1 = rec
    # Rule 3: exit early when (margin >= tau) OR (s1 >= t)
joint_sweep = []
best_joint = None
print(f"[FAST][hit@1] Sweep joint (margin OR s1), grid={len(taus)}x{len(thrs)}")
for tau in tqdm(taus, desc="[FAST] joint grid"):
m_mask = (margins >= tau)
for t in thrs:
s_mask = (s1s >= t)
mask = (m_mask | s_mask).astype(np.float32)
cov, metric, drop, speed = eval_mask(mask)
rec = {"rule": "(margin>=tau) OR (s1>=t)", "tau": float(tau), "t": float(t), "coverage": cov, "metric": metric, "metric_drop": drop, "speedup": speed}
joint_sweep.append(rec)
if drop <= max_drop and (best_joint is None or speed > best_joint["speedup"]):
best_joint = rec
    # Rule 4: conformal calibration on the margin
conformal_recs = []
print("[FAST][hit@1] Conformal (by margin) ...")
    qids = np.array([int(r.get("qid", i)) for i, r in enumerate(rows)], dtype=np.int64)
    # Hold out every 5th qid as a calibration split. The nonconformity score is
    # alpha = 1 - margin; the (1 - delta) quantile of alpha on the calibration
    # set gives a margin threshold tau = 1 - q.
    calib = (qids % 5 == 0)
    alpha = 1.0 - margins
    for delta in [0.01, 0.02, 0.05, 0.1]:
        if not calib.any():
            continue
        q = float(np.quantile(alpha[calib], 1.0 - delta))
        tau = 1.0 - q
        mask = (margins >= tau).astype(np.float32)
cov, metric, drop, speed = eval_mask(mask)
conformal_recs.append({
"rule": "conformal_margin",
"delta": float(delta),
"tau": float(tau),
"coverage": cov,
"metric": metric,
"metric_drop": drop,
"speedup": speed
})
    # Recommendation: among rules meeting drop <= max_drop, pick the largest speedup
cands = [x for x in [best_margin, best_s1, best_joint] if x is not None]
recommended = max(cands, key=lambda x: x["speedup"]) if cands else None
return {
"margin_sweep": margin_sweep,
"top1_sweep": s1_sweep,
"margin_or_top1_sweep": joint_sweep,
"conformal": conformal_recs,
"recommended": recommended
}
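# Row schema assumed for the combined *_score_details_both_layers.jsonl, inferred
# from the field accesses above (candidate ids and values are illustrative):
#   {"qid": 17,
#    "label": "cand_42" or ["cand_42", ...],
#    "mid":  {"top1": "cand_42", "margin": 0.13, "cand_scores": {"cand_42": 0.91, ...}},
#    "last": {"top1": "cand_42", "cand_scores": {"cand_42": 0.95, ...}}}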
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--dir", required=True, help="encode_output_path")
parser.add_argument("--dataset", required=True, help="dataset name (prefix)")
parser.add_argument("--metric", default="mrr", help="metric key or prefix (e.g., mrr, hit, hit@1, ndcg)")
parser.add_argument("--max_drop", type=float, default=0.005, help="allowed metric drop vs last")
parser.add_argument("--grid_points", type=int, default=41, help="grid size for thresholds")
parser.add_argument("--with_entropy", action="store_true", help="enable entropy pathway (slow)")
parser.add_argument("--save", action="store_true", help="save sweep and recommendations")
args = parser.parse_args()
out_dir = args.dir
dataset = args.dataset
metric_key = args.metric
max_drop = args.max_drop
combined_path = os.path.join(out_dir, f"{dataset}_score_details_both_layers.jsonl")
if not os.path.exists(combined_path):
        raise FileNotFoundError(f"Combined file not found: {combined_path}. Generate it with the evaluation script first.")
last_tag = "layerlast"
mid_tag = pick_mid_tag_from_dir(out_dir, dataset)
if mid_tag is None:
raise RuntimeError(f"未找到 {dataset}_qry_layer*(非 layerlast)。请确认已经生成中间层 embedding。")
    # Baseline metric (last layer)
last_score_path = os.path.join(out_dir, f"{dataset}_score_{last_tag}.json")
base_metric, resolved_metric_key = metric_value_from_file(last_score_path, metric_key, preferred_key=None)
if base_metric is None:
        # Fallback: recompute from the combined file (only used on this slow path)
if RankingMetrics is None:
raise RuntimeError(f"无法解析基线指标 '{metric_key}' 且 RankingMetrics 不可用。")
rows_tmp = load_jsonl(combined_path)
metrics = RankingMetrics([metric_key])
pred_dicts = []
for r in rows_tmp:
ranked = sorted(r["last"]["cand_scores"].items(), key=lambda x: -x[1])
labels = r["label"] if isinstance(r["label"], list) else [r["label"]]
pred_dicts.append({"prediction": [cid for cid, _ in ranked], "label": labels, "rel_scores": None})
score_fallback = metrics.evaluate(pred_dicts)
base_metric, resolved_metric_key = resolve_metric_value(score_fallback, metric_key, preferred_key=None)
if base_metric is None:
raise RuntimeError(f"无法解析基线指标 '{metric_key}',请查看 {last_score_path} 中的可用数值键。")
print(f"[INFO] Using metric: requested='{metric_key}', resolved='{resolved_metric_key}', base={base_metric:.6f}")
    # Average per-query latency (query side); fall back to nominal defaults if stats are missing
qry_stats_last = os.path.join(out_dir, f"{dataset}_qry_inference_stats_{last_tag}.json")
qry_stats_mid = os.path.join(out_dir, f"{dataset}_qry_inference_stats_{mid_tag}.json")
avg_q_time_last = get_query_avg_time(qry_stats_last) or 1.0
avg_q_time_mid = get_query_avg_time(qry_stats_mid) or 0.5
    # Fast path for the hit@1 metric
if resolved_metric_key.lower() in {"hit@1", "hit"}:
print("[FAST] Detected hit@1 metric. Running fast analyzer...")
        rows = load_jsonl(combined_path)  # load once
result = analyze_fast_hit1(
rows=rows,
base_metric=base_metric,
avg_q_time_mid=avg_q_time_mid,
avg_q_time_last=avg_q_time_last,
max_drop=max_drop,
grid_points=args.grid_points
)
print("\n=== Baseline ===")
print(f"- {resolved_metric_key}@last: {base_metric:.6f}")
print(f"- avg_query_time: last={avg_q_time_last:.6f}s, mid={avg_q_time_mid:.6f}s")
print("\n=== Best under max_drop=%.4f ===" % max_drop)
for name in ["margin_sweep", "top1_sweep", "margin_or_top1_sweep"]:
sweep = result[name]
best = None
for rec in sweep:
if rec["metric_drop"] <= max_drop:
if best is None or rec["speedup"] > best["speedup"]:
best = rec
if best:
params = {k: best[k] for k in ["tau","t"] if k in best}
print(f"- {name.replace('_','/')}: speedup={best['speedup']:.3f}, coverage={best['coverage']:.3f}, "
f"{resolved_metric_key}={best['metric']:.6f}, drop={best['metric_drop']:.6f}, params={json.dumps(params)}")
else:
print(f"- {name.replace('_','/')}: no feasible point within drop <= {max_drop}")
if result["recommended"]:
print("\n=== Recommended config ===")
print(json.dumps(result["recommended"], indent=2))
else:
print("\nNo recommendation satisfies the drop constraint; consider relaxing --max_drop or increasing grid.")
if args.save:
with open(os.path.join(out_dir, f"{dataset}_early_exit_sweep_summary.json"), "w") as f:
json.dump({
"margin_sweep": result["margin_sweep"],
"top1_sweep": result["top1_sweep"],
"margin_or_top1_sweep": result["margin_or_top1_sweep"],
}, f, indent=2)
with open(os.path.join(out_dir, f"{dataset}_early_exit_recommendations.json"), "w") as f:
json.dump({
"dataset": dataset,
"metric_key_requested": metric_key,
"metric_key_resolved": resolved_metric_key,
"base_metric_last": base_metric,
"avg_query_time_last": avg_q_time_last,
"avg_query_time_mid": avg_q_time_mid,
"mid_tag": mid_tag,
"last_tag": "layerlast",
"conformal": result["conformal"],
"recommended": result["recommended"]
}, f, indent=2)
return
    # Non-hit@1 metrics: fall back to the slow path (warn that it can take a while)
    print("[WARN] Non-hit@1 metrics take the slow path and may be time-consuming. Consider validating the early-exit idea with --metric hit@1 first.")
if RankingMetrics is None:
raise RuntimeError("RankingMetrics 未导入,无法回退到通用慢路径。")
    # Slow path: load rows and evaluate each threshold with RankingMetrics (no entropy/temperature handling; still slow)
rows = load_jsonl(combined_path)
metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
def build_pred_dicts(use_mid_mask):
preds = []
for r, use_mid in zip(rows, use_mid_mask):
sims = r["mid"]["cand_scores"] if use_mid else r["last"]["cand_scores"]
ranked = sorted(sims.items(), key=lambda x: -x[1])
labels = r["label"] if isinstance(r["label"], list) else [r["label"]]
preds.append({"prediction": [cid for cid, _ in ranked], "label": labels, "rel_scores": None})
return preds
metrics = RankingMetrics(metrics_to_report)
    # Precompute margin / s1 (keep sorting to a minimum even on the slow path)
margins = np.array([float(r["mid"]["margin"]) for r in rows], dtype=np.float32)
s1s = np.array([
float(r["mid"]["cand_scores"].get(r["mid"]["top1"], 0.0)) if "cand_scores" in r["mid"] else 0.0
for r in rows
], dtype=np.float32)
taus = np.quantile(margins, np.linspace(0.0, 1.0, args.grid_points)).tolist()
thrs = np.quantile(s1s, np.linspace(0.0, 1.0, args.grid_points)).tolist()
# margin
for tau in tqdm(taus, desc="[SLOW] margin grid"):
mask = (margins >= tau).tolist()
score = metrics.evaluate(build_pred_dicts(mask))
        # `score` holds all requested metrics for this threshold; print or save as needed
# top1
for t in tqdm(thrs, desc="[SLOW] top1 grid"):
mask = (s1s >= t).tolist()
        score = metrics.evaluate(build_pred_dicts(mask))  # likewise: print or save as needed
print("Done.")
if __name__ == "__main__":
main()
# Example usage:
# python /home/v-menggao/code/VLM2Vec/analyze_early_exit.py --dir /home/v-menggao/code/VLM2Vec/~/exps/Qwen2vl_2B.image_qry_20_none+cand_20_none_0.1/VLM2Vec-V2.0-Qwen2VL-2B/image_retrival --dataset CIRR --metric hit@1 --max_drop 0.005
# python /home/v-menggao/code/VLM2Vec/analyze_early_exit.py --dir /home/v-menggao/code/VLM2Vec/~/exps/Qwen2vl_2B.image_qry_20_none+cand_20_none_0.1/VLM2Vec-V2.0-Qwen2VL-2B/image_retrival --dataset CIRR --metric hit@1 --max_drop 0.005 --topk 200