# eval_test_time_early_exit.py
import datetime
import logging
import json
import random
import time
import numpy as np
import os
import pickle
import sys
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
import transformers
import math
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
from datasets import Dataset, concatenate_datasets
from datasets.distributed import split_dataset_by_node
from src.model.vlm_backbone.qwen2_vl.modeling_qwen2_vl_train_tokrnpooling import Qwen2VLForConditionalGeneration as _Qwen2VLForConditionalGeneration_src
from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.eval_collator import MultimodalEvalDataCollator
from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
from src.eval_utils.metrics import RankingMetrics
from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel
from src.model.processor import get_backbone_name, load_processor, COLPALI
from src.utils import batch_to_device, print_rank, print_master
from dataclasses import dataclass
def get_env_mid_layer():
v = os.environ.get("MID_LM_LAYER", "").strip()
if v == "" or v.lower() in {"none", "null"}:
return None
    try:
        return int(v)
    except ValueError:
        logger.warning(f"Invalid MID_LM_LAYER={v}; ignoring.")
        return None
# ------------- AOP-Prune config parsing -------------
def _parse_bool(v: str, default=False):
    if v is None:
        return default
    v = v.strip().lower()
    return v in {"1", "true", "yes", "y", "t", "on"}
def _parse_float(v: str, default=None):
    try:
        return float(v) if v is not None else default
    except (TypeError, ValueError):
        return default
def _parse_int(v: str, default=None):
    try:
        return int(v) if v is not None else default
    except (TypeError, ValueError):
        return default
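# Hedged sketch: expected behaviour of the env-string parsers above. The name
# `_demo_env_parsers` is ours for illustration; nothing in this module calls it.
def _demo_env_parsers():
    assert _parse_bool("Yes") is True
    assert _parse_bool(None, default=False) is False
    assert _parse_float("0.3") == 0.3
    assert _parse_float("not-a-number", default=1.0) == 1.0
    assert _parse_int("64") == 64
    assert _parse_int(None, default=32) == 32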
def get_env_aop_config():
"""
从环境变量读取 AOP 剪裁配置。仅作为“驱动层”的简要测试开关;
实际剪裁逻辑在底模里(Qwen2-VLModel.forward)实现。
"""
enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False)
apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower() # qry|cand|both
layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None)
mode = os.environ.get("AOP_MODE", "delta").strip().lower()
    # Generic fallback values
delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10)
khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0)
keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0)
min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64)
use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True)
    # Per-type (vision/text) controls
prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True)
prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False)
delta_v = _parse_float(os.environ.get("AOP_DELTA_VISION"), None)
khat_v = _parse_float(os.environ.get("AOP_KHAT_VISION"), None)
keep_ratio_v= _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None)
min_keep_v = _parse_int(os.environ.get("AOP_MIN_KEEP_VISION"), None)
delta_t = _parse_float(os.environ.get("AOP_DELTA_TEXT"), None)
khat_t = _parse_float(os.environ.get("AOP_KHAT_TEXT"), None)
keep_ratio_t= _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None)
min_keep_t = _parse_int(os.environ.get("AOP_MIN_KEEP_TEXT"), 32)
protect_text_last = _parse_int(os.environ.get("AOP_PROTECT_TEXT_LAST"), 16)
protect_special = _parse_bool(os.environ.get("AOP_PROTECT_SPECIAL"), True)
margin_src = os.environ.get("AOP_MARGIN", "").strip().lower() # "" or "mid"
attn_impl = os.environ.get("AOP_ATTN_IMPL", "").strip().lower() # "" or "sdpa"
    if layer_idx is None and enabled:
        logger.warning("AOP_ENABLED=1 but AOP_LAYER is not set; disabling AOP.")
        enabled = False
    # Selection policy: aop | random | attention
    selection = os.environ.get("AOP_SELECTION", "aop").strip().lower()
    if _parse_bool(os.environ.get("AOP_RANDOM"), False):
        selection = "random"
    random_seed = _parse_int(os.environ.get("AOP_RANDOM_SEED"), None)
attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower() # mean|max|sum
cfg = {
"enabled": enabled,
"apply_to": apply_to,
"layer_idx": layer_idx,
"mode": mode,
# 回退
"delta": delta, "K_hat": khat,
"keep_ratio": keep_ratio, "min_keep": min_keep,
"use_bias": use_bias, "eps": 1e-6,
# 类型开关
"prune_vision": prune_vision,
"prune_text": prune_text,
# 视觉桶
"delta_vision": delta_v,
"K_hat_vision": khat_v,
"keep_ratio_vision": keep_ratio_v,
"min_keep_vision": min_keep_v,
# 文本桶
"delta_text": delta_t,
"K_hat_text": khat_t,
"keep_ratio_text": keep_ratio_t,
"min_keep_text": min_keep_t,
# 文本保护
"protect_text_last": protect_text_last,
"protect_special": protect_special,
# 可选:排名安全预算
"margin_mid": None if margin_src != "mid" else "USE_MID_MARGIN",
"epsilon_hat": None,
"attn_impl_override": attn_impl if attn_impl in {"sdpa"} else "",
# NEW: 选择策略
"selection": selection, # "aop" 或 "random"
"random_seed": random_seed, # 可选
"attn_agg": attn_agg,
}
return cfg
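# A minimal usage sketch for the AOP driver switches above. The values are
# illustrative placeholders, not tuned settings; AOP_LAYER must name a valid
# decoder layer of the loaded backbone, and the actual pruning still happens
# inside the backbone's forward pass.
def _demo_aop_env_config():
    os.environ.update({
        "AOP_ENABLED": "1",
        "AOP_APPLY": "qry",         # prune the query side only
        "AOP_LAYER": "12",          # decide which tokens to keep after layer 12
        "AOP_PRUNE_VISION": "1",
        "AOP_KEEP_RATIO_VISION": "0.5",
    })
    cfg = get_env_aop_config()
    assert cfg["enabled"] and cfg["layer_idx"] == 12
    return cfg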
def get_env_eval_layers():
    """
    Parse the LM_LAYERS environment variable (preferred), falling back to the
    legacy MID_LM_LAYER for compatibility.
    - LM_LAYERS example: "4,8,12,last"; the tokens 'last'/'none'/'null'/'-1'
      all denote the final layer (encoded as None).
    - If LM_LAYERS is unset or empty, fall back to the old logic:
      MID_LM_LAYER=None -> [None]; otherwise [mid, None].
    Returns: list[int | None], e.g. [4, 8, 12, None]; None means the last layer.
    """
    v = os.environ.get("LM_LAYERS", None)
    if v is not None and v.strip():
        toks = [t.strip() for t in v.split(',') if t.strip() != ""]
        layers = []
        for tok in toks:
            tl = tok.lower()
            if tl in {"last", "none", "null", "-1"}:
                layers.append(None)
            else:
                try:
                    val = int(tok)
                    if val > 0:
                        layers.append(val)
                    else:
                        logger.warning(f"Ignoring non-positive layer '{tok}' in LM_LAYERS.")
                except ValueError:
                    logger.warning(f"Invalid token '{tok}' in LM_LAYERS; must be int or 'last'/'none'.")
        # De-duplicate while preserving order
        seen = set()
        uniq = []
        for l in layers:
            key = -1 if l is None else l
            if key in seen:
                continue
            seen.add(key)
            uniq.append(l)
        if not uniq:
            return [None]
        return uniq
    # Legacy fallback (LM_LAYERS unset or empty)
    mid = get_env_mid_layer()
    return [None] if mid is None else [mid, None]
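# A minimal sketch of the LM_LAYERS contract described in the docstring above,
# with illustrative values.
def _demo_eval_layers():
    os.environ["LM_LAYERS"] = "4,8,12,last"
    assert get_env_eval_layers() == [4, 8, 12, None]
    os.environ["LM_LAYERS"] = "8,8,none"  # duplicates collapse, order is kept
    assert get_env_eval_layers() == [8, None]
    del os.environ["LM_LAYERS"]  # subsequent calls fall back to MID_LM_LAYER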
# === Early-Exit config & helpers ===
def get_env_ee_config():
ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in {"1","true","yes","on","y","t"}
    layer = int(os.environ.get("EE_LAYER", os.environ.get("AOP_LAYER", "12")))  # defaults to AOP_LAYER
method = os.environ.get("EE_METHOD", "margin").strip().lower() # margin|p1p2|entropy|gini|combined
tau = float(os.environ.get("EE_TAU", "0.2"))
topk = int(os.environ.get("EE_TOPK", "1024"))
temp = float(os.environ.get("EE_TEMP", "0.05"))
save = os.environ.get("EE_SAVE", "1").strip().lower() in {"1","true","yes","on","y","t"}
combw = os.environ.get("EE_COMB_WEIGHTS", "1.0,0.5,0.5")
try:
w_margin, w_conf, w_sq = [float(x) for x in combw.split(",")]
except Exception:
w_margin, w_conf, w_sq = 1.0, 0.5, 0.5
return dict(
enabled=ee_enabled, layer=layer, method=method, tau=tau,
topk=topk, temp=temp, save=save,
w_margin=w_margin, w_conf=w_conf, w_sq=w_sq
)
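# A minimal sketch of driving the early-exit gate from environment variables;
# the numbers are placeholders, not tuned settings.
def _demo_ee_env_config():
    os.environ.update({
        "EE_ENABLED": "1",
        "EE_LAYER": "12",        # gate after decoder layer 12
        "EE_METHOD": "margin",   # exit when the top1-top2 margin >= EE_TAU
        "EE_TAU": "0.2",
    })
    cfg = get_env_ee_config()
    assert cfg["enabled"] and cfg["method"] == "margin"
    return cfg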
def _softmax_np(x: np.ndarray, temp: float = 1.0) -> np.ndarray:
x = x - np.max(x)
ex = np.exp(x / max(1e-6, temp))
s = np.sum(ex)
return ex / max(s, 1e-12)
def confidence_from_topk(scores: np.ndarray, method="margin", temp=0.05, w_margin=1.0, w_conf=0.5, w_sq=0.5) -> float:
    # scores: top-K values, already sorted in descending order
    if scores.size == 0:
        return 0.0
    if scores.size == 1:
        return 1e9
    margin = float(scores[0] - scores[1])
    p = _softmax_np(scores, temp=temp)
    p1p2 = float(p[0] - p[1])
    H = -float(np.sum(p * np.log(p + 1e-12))) / np.log(len(p))  # normalized entropy in [0, 1]
    conf = 1.0 - H
    sqsum = float(np.sum(p ** 2))  # Gini-style concentration (larger = more peaked)
    if method == "margin":
        return margin
    if method == "p1p2":
        return p1p2
    if method == "entropy":
        return conf
    if method == "gini":
        return sqsum
    # combined
    return w_margin * margin + w_conf * conf + w_sq * sqsum
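# A tiny worked example of the confidence measures above on synthetic scores.
# With scores [0.9, 0.4, 0.1] the raw margin is 0.5; at temp=0.05 the softmax
# is extremely peaked, so the entropy-based confidence and the Gini sum both
# approach 1.0.
def _demo_confidence_measures():
    import numpy as np  # local import so the sketch stays self-contained
    scores = np.array([0.9, 0.4, 0.1], dtype=np.float32)
    for m in ("margin", "p1p2", "entropy", "gini", "combined"):
        print(m, confidence_from_topk(scores, method=m, temp=0.05))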
def run_early_exit_queries(
model: MMEBModel,
processor,
model_args: ModelArguments,
data_args: DataArguments,
training_args: TrainingArguments,
qry_dataset: Dataset,
cand_mid_dict: dict,
cand_last_dict: dict,
ee_cfg: dict,
dataset_name: str,
out_dir: str,
global_ranking: bool = True,
):
"""
仅在线早停推理(不画曲线),并额外输出:
- 每个 query 的中间层 vs cand_last 的 top-K 相似度 (mid2last)
- 对未早停的 query,再输出最后一层 vs cand_last 的 top-K 相似度 (last2last)
输出文件:
{out_dir}/{dataset}_score_earlyexit.json - 检索指标
{out_dir}/{dataset}_pred_earlyexit.jsonl - 预测列表(原有)
{out_dir}/{dataset}_sim_earlyexit.jsonl - 本函数新增的相似度信息(需 EE_SAVE=1)
"""
device = training_args.device
local_rank = dist.get_rank() if dist.is_initialized() else 0
is_main = (not dist.is_initialized()) or (local_rank == 0)
    # Candidate matrices
cand_ids = list(cand_mid_dict.keys())
cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32)
cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)
cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16)
# query DataLoader
collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
loader = DataLoader(
qry_dataset,
batch_size=training_args.per_device_eval_batch_size,
collate_fn=collator,
num_workers=training_args.dataloader_num_workers
)
pred_dicts = []
    # Whether to save similarities (reuses EE_SAVE)
save_scores = ee_cfg.get("save", False)
topk_sim = int(ee_cfg.get("topk", 1024))
sim_records = [] if (save_scores and is_main) else None
    # Enable AOP per side
aop_cfg = getattr(model.encoder, "aop_prune_config", None)
_orig_enabled = None
side_enable = True
if isinstance(aop_cfg, dict) and aop_cfg:
_orig_enabled = aop_cfg.get("enabled", False)
apply_to = aop_cfg.get("apply_to", "qry")
side_enable = (apply_to == "both") or (apply_to == "qry")
    # Gating parameters
k_conf = 2
tau = float(ee_cfg["tau"])
method= ee_cfg["method"]
temp = float(ee_cfg["temp"])
idx_global = 0
start_time = time.time()
for inputs, infos in tqdm(
loader,
desc=f"[EE] {dataset_name}@L{ee_cfg['layer']} (rank {local_rank})",
disable=local_rank > 0,
):
inputs = batch_to_device(inputs, device)
        # -------- 1) Run up to the mid layer (stop_at_layer); skip the logits --------
orig_cfg = None
if isinstance(aop_cfg, dict) and aop_cfg:
orig_cfg = dict(aop_cfg)
aop_layer = aop_cfg.get("layer_idx", None)
ee_layer = int(ee_cfg["layer"])
apply_to = aop_cfg.get("apply_to", "qry").strip().lower()
aop_on_mid = bool(
_orig_enabled and side_enable and
(aop_layer is not None) and (aop_layer < ee_layer) and
(apply_to in {"qry", "both"})
)
aop_cfg_mid = dict(aop_cfg)
aop_cfg_mid["enabled"] = aop_on_mid
setattr(model.encoder, "aop_prune_config", aop_cfg_mid)
with torch.no_grad(), torch.autocast(
device_type="cuda", dtype=torch.bfloat16, enabled=True
):
out_mid = model.encoder(
**inputs,
return_dict=True,
output_hidden_states=False,
stop_at_layer=int(ee_cfg["layer"]),
compute_lm_head=False,
)
if isinstance(orig_cfg, dict):
setattr(model.encoder, "aop_prune_config", orig_cfg)
        # EOS pooling to get the mid-layer representation
hs_mid = getattr(out_mid, "last_hidden_state", None)
if hs_mid is None:
assert out_mid.hidden_states is not None and len(out_mid.hidden_states) > 0
hs_mid = out_mid.hidden_states[-1]
am_mid = getattr(out_mid, "attention_mask", None)
if am_mid is None:
am_mid = inputs.get("attention_mask", None)
if hasattr(am_mid, "device") and am_mid.device != hs_mid.device:
am_mid = am_mid.to(hs_mid.device)
reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(device=device, dtype=torch.bfloat16) # [B,D]
B = reps_mid_t.size(0)
use_local = (not global_ranking)
        # -------- 2) Gating: based on the top-2 mid→mid scores --------
if not use_local:
            # Full corpus: score against cand_mid_t
scores_t = reps_mid_t @ cand_mid_t.T # [B, Nc]
vals_t, idxs_t = torch.topk(
scores_t, k=min(k_conf, scores_t.size(1)), dim=1
) # [B,2]
p_t = torch.softmax(vals_t / max(temp, 1e-6), dim=1)
if vals_t.size(1) >= 2:
margin_t = vals_t[:, 0] - vals_t[:, 1]
p1p2_t = p_t[:, 0] - p_t[:, 1]
else:
margin_t = torch.full((B,), float("inf"), device=device, dtype=vals_t.dtype)
p1p2_t = torch.ones(B, device=device, dtype=vals_t.dtype)
            H_t = -(p_t * (torch.log(p_t + 1e-12))).sum(dim=1) / math.log(max(vals_t.size(1), 2))  # max(..., 2) guards against log(1)=0 with a single candidate
conf_map = {
"margin": margin_t,
"p1p2": p1p2_t,
"entropy": 1.0 - H_t,
"gini": (p_t ** 2).sum(dim=1),
}
confs_t = conf_map.get(method, margin_t)
exit_mask = (confs_t >= tau).detach().cpu().numpy().astype(bool)
else:
            # Local ranking: score each query against its own cand_mid_t[rows]
confs = []
for b in range(B):
cand_local = infos[b]["cand_names"]
rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
rows = [r for r in rows if r >= 0]
if len(rows) == 0:
confs.append(0.0)
continue
cmat_t = cand_mid_t[rows] # [Nl, D]
sv_t = (reps_mid_t[b:b+1] @ cmat_t.T)[0] # [Nl]
k = 2 if sv_t.size(0) >= 2 else 1
vals_t, _ = torch.topk(sv_t, k=k, dim=0)
p_t = torch.softmax(vals_t / max(temp, 1e-6), dim=0)
if k >= 2:
margin = (vals_t[0] - vals_t[1]).item()
p1p2 = (p_t[0] - p_t[1]).item()
else:
margin, p1p2 = float("inf"), 1.0
                H = (-(p_t * (torch.log(p_t + 1e-12))).sum() / math.log(max(k, 2))).item()  # max(k, 2) guards against log(1)=0
gini = ((p_t ** 2).sum()).item()
d = {"margin": margin, "p1p2": p1p2, "entropy": 1.0 - H, "gini": gini}
confs.append(d.get(method, margin))
exit_mask = (np.array(confs) >= tau)
        # -------- 3) Retrieval + similarity records --------
        # Samples that exit early
        exit_indices = np.where(exit_mask)[0].tolist()
        # Samples that must continue to the last layer
        need_indices = np.where(~exit_mask)[0].tolist()
        # A. Early exit: rank directly by mid→mid, and additionally compute the top-K mid→last similarities
        for j in exit_indices:
            # 1) Ranking (for pred_dicts)
if not use_local:
scores_mid_mid = (reps_mid_t[j:j+1] @ cand_mid_t.T)[0] # [Nc]
order = torch.argsort(scores_mid_mid, dim=0, descending=True).detach().cpu().numpy()
cids = [cand_ids[i] for i in order]
else:
cand_local = infos[j]["cand_names"]
rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
rows = [r for r in rows if r >= 0]
if len(rows) == 0:
cids = []
else:
cmat_t = cand_mid_t[rows]
sv = (reps_mid_t[j:j+1] @ cmat_t.T)[0]
order_local = torch.argsort(sv, dim=0, descending=True).detach().cpu().numpy()
cids = [str(cand_local[i]) for i in order_local]
rel_docids = infos[j]["label_name"]
if not isinstance(rel_docids, list):
rel_docids = [rel_docids]
pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})
            # 2) Similarity record: mid→last
if save_scores and is_main:
if not use_local:
scores_mid_last = (reps_mid_t[j:j+1] @ cand_last_t.T)[0].detach().float().cpu() # [Nc]
Nc = scores_mid_last.size(0)
K = min(topk_sim, Nc)
mid_vals, mid_inds = torch.topk(scores_mid_last, k=K, dim=0)
mid_ids = [cand_ids[i] for i in mid_inds.tolist()]
else:
cand_local = infos[j]["cand_names"]
rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
rows = [r for r in rows if r >= 0]
if len(rows) == 0:
mid_vals = torch.empty(0)
mid_ids = []
else:
cmat_last = cand_last_t[rows] # [Nl, D]
sv_last = (reps_mid_t[j:j+1] @ cmat_last.T)[0].detach().float().cpu()
K = min(topk_sim, sv_last.size(0))
mid_vals, mid_inds = torch.topk(sv_last, k=K, dim=0)
mid_ids = [str(cand_local[i]) for i in mid_inds.tolist()]
rec = {
"qid": int(idx_global + j),
"early_exit": True,
"mid_topk_scores": mid_vals.tolist() if mid_vals.numel() > 0 else [],
"mid_topk_cand_ids": mid_ids,
"last_topk_scores": None,
"last_topk_cand_ids": None,
}
sim_records.append(rec)
        # B. Resumed queries: continue mid->last and rank by last→last; also record mid→last & last→last similarities
        if len(need_indices) > 0:
            # Resume from the cached intermediate state
if isinstance(aop_cfg, dict) and aop_cfg:
aop_resume = dict(aop_cfg)
aop_resume["enabled"] = bool(_orig_enabled and side_enable)
setattr(model.encoder, "aop_prune_config", aop_resume)
interm = getattr(out_mid, "intermediate_state", None)
assert interm is not None, "Model must return intermediate_state when stop_at_layer is set."
hs = interm["hidden_states"].detach()
am = interm["attention_mask"].detach()
pos = interm["position_ids"].detach()
vm = interm.get("vision_mask", None)
tm = interm.get("text_mask", None)
next_layer = int(interm["next_layer_idx"])
hs_sub = hs[need_indices]
am_sub = am[need_indices]
pos_sub = pos[:, need_indices, :]
vm_sub = vm[need_indices] if vm is not None else None
tm_sub = tm[need_indices] if tm is not None else None
resume_state = {
"hidden_states": hs_sub,
"attention_mask": am_sub,
"position_ids": pos_sub,
"vision_mask": vm_sub,
"text_mask": tm_sub,
"next_layer_idx": next_layer,
}
with torch.no_grad(), torch.autocast(
device_type="cuda", dtype=torch.bfloat16, enabled=True
):
out_last = model.encoder(
return_dict=True,
output_hidden_states=False,
stop_at_layer=None,
resume_state=resume_state,
compute_lm_head=False,
)
hs_last = getattr(out_last, "last_hidden_state", None)
if hs_last is None:
assert out_last.hidden_states is not None and len(out_last.hidden_states) > 0
hs_last = out_last.hidden_states[-1]
am_last = getattr(out_last, "attention_mask", None)
if am_last is None:
am_last = am_sub
if hasattr(am_last, "device") and am_last.device != hs_last.device:
am_last = am_last.to(hs_last.device)
reps_last_t = model._pooling(hs_last, am_last).detach().to(device=device, dtype=torch.bfloat16)
if not use_local:
scores_last_all = (reps_last_t @ cand_last_t.T).detach().float().cpu() # [N_need, Nc]
for k, j in enumerate(need_indices):
                    # 1) Ranked predictions
row = scores_last_all[k]
order = torch.argsort(row, dim=0, descending=True).tolist()
cids = [cand_ids[i] for i in order]
rel_docids = infos[j]["label_name"]
if not isinstance(rel_docids, list):
rel_docids = [rel_docids]
pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})
                    # 2) mid→last & last→last similarities
if save_scores and is_main:
# mid2last
scores_mid_last = (reps_mid_t[j:j+1] @ cand_last_t.T)[0].detach().float().cpu()
Nc = scores_mid_last.size(0)
K = min(topk_sim, Nc)
mid_vals, mid_inds = torch.topk(scores_mid_last, k=K, dim=0)
mid_ids = [cand_ids[i] for i in mid_inds.tolist()]
# last2last
last_row = row
last_vals, last_inds = torch.topk(last_row, k=K, dim=0)
last_ids = [cand_ids[i] for i in last_inds.tolist()]
rec = {
"qid": int(idx_global + j),
"early_exit": False,
"mid_topk_scores": mid_vals.tolist(),
"mid_topk_cand_ids": mid_ids,
"last_topk_scores": last_vals.tolist(),
"last_topk_cand_ids": last_ids,
}
sim_records.append(rec)
else:
# local ranking
for k, j in enumerate(need_indices):
cand_local = infos[j]["cand_names"]
rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
rows = [r for r in rows if r >= 0]
if len(rows) == 0:
cids = []
rel_docids = infos[j]["label_name"]
if not isinstance(rel_docids, list):
rel_docids = [rel_docids]
pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})
if save_scores and is_main:
rec = {
"qid": int(idx_global + j),
"early_exit": False,
"mid_topk_scores": [],
"mid_topk_cand_ids": [],
"last_topk_scores": [],
"last_topk_cand_ids": [],
}
sim_records.append(rec)
continue
                    # 1) Ranking (last→last)
cmat_last = cand_last_t[rows]
sv_last = (reps_last_t[k:k+1] @ cmat_last.T)[0].detach().float().cpu()
order_local = torch.argsort(sv_last, dim=0, descending=True).tolist()
cids = [str(cand_local[i]) for i in order_local]
rel_docids = infos[j]["label_name"]
if not isinstance(rel_docids, list):
rel_docids = [rel_docids]
pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})
                    # 2) mid→last & last→last similarities
if save_scores and is_main:
cmat_last = cand_last_t[rows]
# mid2last
sv_mid_last = (reps_mid_t[j:j+1] @ cmat_last.T)[0].detach().float().cpu()
K = min(topk_sim, sv_mid_last.size(0))
mid_vals, mid_inds = torch.topk(sv_mid_last, k=K, dim=0)
mid_ids = [str(cand_local[i]) for i in mid_inds.tolist()]
# last2last
sv_last_row = sv_last
last_vals, last_inds = torch.topk(sv_last_row, k=K, dim=0)
last_ids = [str(cand_local[i]) for i in last_inds.tolist()]
rec = {
"qid": int(idx_global + j),
"early_exit": False,
"mid_topk_scores": mid_vals.tolist(),
"mid_topk_cand_ids": mid_ids,
"last_topk_scores": last_vals.tolist(),
"last_topk_cand_ids": last_ids,
}
sim_records.append(rec)
idx_global += B
    # -------- 4) Evaluate and write outputs --------
metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)
if is_main:
os.makedirs(out_dir, exist_ok=True)
        # Original early-exit retrieval results
with open(os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), "w") as f:
json.dump(score, f, indent=4)
if ee_cfg.get("save", False):
with open(os.path.join(out_dir, f"{dataset_name}_pred_earlyexit.jsonl"), "w", encoding="utf-8") as f:
for row in pred_dicts:
f.write(json.dumps(row, ensure_ascii=False) + "\n")
        # Newly added mid/last similarity output
if save_scores and sim_records is not None:
sims_path = os.path.join(out_dir, f"{dataset_name}_sim_earlyexit.jsonl")
with open(sims_path, "w", encoding="utf-8") as f:
for rec in sim_records:
f.write(json.dumps(rec, ensure_ascii=False) + "\n")
print_master(f"[EE] Saved mid/last similarity records -> {sims_path}")
elapsed = time.time() - start_time
return score, elapsed
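# A minimal, self-contained sketch of the gating rule used in step 2 above:
# score pooled mid-layer query reps against the mid-layer candidate matrix and
# exit early wherever the top1-top2 margin clears tau. Shapes and tau are
# illustrative; the real loop additionally resumes the non-exited queries to
# the last layer.
def _demo_margin_gate(tau: float = 0.2):
    torch.manual_seed(0)
    reps_mid = torch.randn(4, 8)      # [B, D] pooled mid-layer query reps
    cand_mid = torch.randn(100, 8)    # [Nc, D] mid-layer candidate reps
    scores = reps_mid @ cand_mid.T    # [B, Nc]
    vals, _ = torch.topk(scores, k=2, dim=1)
    margin = vals[:, 0] - vals[:, 1]  # top1 - top2 per query
    return margin >= tau              # True -> answer with the mid-layer ranking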
def make_layer_tag(keep_layers: int | None):
return f"layer{keep_layers}" if keep_layers and keep_layers > 0 else "layerlast"
def dot_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray:
# a: [Nq, D], b: [Nc, D], both L2-normalized already if normalize=true
return a @ b.T
def build_score_details(qid: int, cand_ids: list, score_vec: np.ndarray, ranked_indices: np.ndarray):
return {
"qid": int(qid),
"cand_scores": [
{"cand_id": str(cand_ids[i]), "score": float(score_vec[i])}
for i in ranked_indices
]
}
def top1_top2_margin(score_vec: np.ndarray) -> float:
if len(score_vec) < 2:
        return float("inf")  # with a single candidate, treat the margin as infinite
top2 = np.partition(score_vec, -2)[-2:]
top2.sort()
return float(top2[-1] - top2[-2])
def simulate_early_exit_by_margin(
    sims_mid: list[dict], sims_last: list[dict], labels: list[list[str]], metrics_to_report: list[str],
    taus: list[float], rank_global: bool
):
    """
    sims_mid / sims_last: one dict per query: {cand_id: score}
    labels: list of positive cand_ids per query
    Returns: coverage and metrics at each tau.
    (rank_global is accepted for interface symmetry; it is not used here.)
    """
    assert len(sims_mid) == len(sims_last) == len(labels)
    N = len(labels)
    results = []
    from src.eval_utils.metrics import RankingMetrics
    metrics = RankingMetrics(metrics_to_report)
    # Pre-build the pred_dicts consumed by metrics.evaluate
    def to_pred_dicts(use_mid_mask: list[bool]) -> list[dict]:
        pred_dicts = []
        for qid in range(N):
            sims_use = sims_mid[qid] if use_mid_mask[qid] else sims_last[qid]
            # Rank candidates by descending score
            ranked = sorted(sims_use.items(), key=lambda x: -x[1])
            pred_dicts.append({
                "prediction": [cid for cid, _ in ranked],
                "label": labels[qid],
                "rel_scores": None
            })
        return pred_dicts
    # Compute the mid-layer margins
    margins = []
    for qid in range(N):
        # Margin between the two largest scores
        if len(sims_mid[qid]) == 0:
            margins.append(0.0)
            continue
        scores = np.array(list(sims_mid[qid].values()), dtype=np.float32)
        margins.append(top1_top2_margin(scores))
    margins = np.array(margins, dtype=np.float32)
    for tau in taus:
        use_mid_mask = (margins >= tau).tolist()
        pred_dicts = to_pred_dicts(use_mid_mask)
        score_dict = metrics.evaluate(pred_dicts)
        coverage = float(np.mean(use_mid_mask))  # early-exit coverage
        results.append({
            "tau": tau,
            "coverage": coverage,
            **score_dict
        })
    return results
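# A minimal sketch of the offline tau sweep above on two synthetic queries.
# It assumes src.eval_utils.metrics.RankingMetrics is importable; the scores,
# labels, and taus are made up for illustration. Query 0's mid margin (0.7)
# clears both taus, so it exits early; query 1's (~0.01) forces a last-layer
# rerank at either tau.
def _demo_margin_sweep():
    sims_mid = [{"a": 0.9, "b": 0.2}, {"a": 0.51, "b": 0.50}]
    sims_last = [{"a": 0.8, "b": 0.3}, {"b": 0.7, "a": 0.4}]
    labels = [["a"], ["b"]]
    return simulate_early_exit_by_margin(
        sims_mid, sims_last, labels,
        metrics_to_report=["hit", "mrr"],
        taus=[0.1, 0.5], rank_global=True,
    )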
def top1_top2_margin_from_array(score_vec: np.ndarray) -> float:
if score_vec is None or len(score_vec) == 0:
return 0.0
if len(score_vec) == 1:
return float('inf')
    # Take the two largest scores
top2 = np.partition(score_vec, -2)[-2:]
top2.sort()
return float(top2[-1] - top2[-2])
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)
# --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) ---
timing_info = {}
token_info = {
"vision_tokens": 0,
"text_input_tokens": 0, # Refers to the original text token count
"text_output_tokens": 0, # Not directly applicable here as we are encoding, not generating. Will be 0.
"total_llm_input_tokens": 0, # Refers to the total tokens LLM receives (visual + formatted text)
}
# --- Hook Functions Definition ---
def timing_pre_hook(module, input):
module_id = id(module)
if module_id not in timing_info:
timing_info[module_id] = []
timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__))
def timing_post_hook(module, input, output):
module_id = id(module)
if module_id not in timing_info:
# print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})")
return
timing_info[module_id].append((time.time(), 'post', module.__class__.__name__))
# Collect vision token count (only from Vision Transformer module's post hook)
module_name = module.__class__.__name__
if "vision" in module_name.lower() and "transformer" in module_name.lower():
        if isinstance(output, torch.Tensor):
            token_info["vision_tokens"] = output.shape[0]  # Qwen2-VL's vision tower emits (num_tokens, hidden_dim), so dim 0 counts tokens
elif hasattr(output, 'last_hidden_state'):
token_info["vision_tokens"] = output.last_hidden_state.shape[1]
def register_model_hooks(model):
registered_modules = []
core_model = model.encoder if hasattr(model, "encoder") and model.encoder is not None else model
# Vision module
if hasattr(core_model, 'visual') and core_model.visual is not None:
vision_module = core_model.visual
vision_module.register_forward_pre_hook(timing_pre_hook)
vision_module.register_forward_hook(timing_post_hook)
registered_modules.append(vision_module)
print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}")
else:
print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).")
# Merger module (if inside visual) - it's part of the vision component
if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None:
merger_module = core_model.visual.merger
merger_module.register_forward_pre_hook(timing_pre_hook)
merger_module.register_forward_hook(timing_post_hook)
registered_modules.append(merger_module)
print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}")
else:
print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).")
# Language model body
if hasattr(core_model, 'model') and core_model.model is not None:
llm_main_module = core_model.model
llm_main_module.register_forward_pre_hook(timing_pre_hook)
llm_main_module.register_forward_hook(timing_post_hook)
registered_modules.append(llm_main_module)
print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}")
else:
print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).")
# LM Head
if hasattr(core_model, 'lm_head') and core_model.lm_head is not None:
lm_head_module = core_model.lm_head
lm_head_module.register_forward_pre_hook(timing_pre_hook)
lm_head_module.register_forward_hook(timing_post_hook)
registered_modules.append(lm_head_module)
print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}")
else:
print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).")
if not registered_modules:
print_master("Warning: No major modules found for hook registration. Check model architecture.")
return registered_modules
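# A minimal sketch of the pre/post timing-hook pattern above, applied to a toy
# nn.Linear instead of the VLM submodules, showing how paired timestamps in
# timing_info turn into a per-module duration.
def _demo_timing_hooks():
    import torch.nn as nn
    layer = nn.Linear(8, 8)
    layer.register_forward_pre_hook(timing_pre_hook)
    layer.register_forward_hook(timing_post_hook)
    timing_info.clear()
    layer(torch.randn(2, 8))
    events = timing_info[id(layer)]     # [(t, 'pre', name), (t, 'post', name)]
    return events[1][0] - events[0][0]  # post minus pre timestamp, in seconds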
def pad_dataset_to_divisible(dataset, world_size):
num_samples = len(dataset)
if num_samples % world_size == 0:
return dataset, num_samples
num_to_add = world_size - (num_samples % world_size)
padded_size = num_samples + num_to_add
padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
padded_dataset = concatenate_datasets([dataset, padding_data])
return padded_dataset, padded_size
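# A minimal sketch of the DDP padding helper above: a 10-row dataset padded
# for world_size=4 gains 2 wrapped-around rows, so each rank gets 3 samples.
def _demo_pad_dataset():
    ds = Dataset.from_dict({"x": list(range(10))})
    padded, padded_size = pad_dataset_to_divisible(ds, world_size=4)
    assert padded_size == 12 and padded[10]["x"] == 0 and padded[11]["x"] == 1
    return padded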
def encode_embeddings(
model: MMEBModel,
loader: DataLoader,
training_args: TrainingArguments,
model_args: ModelArguments,
full_dataset: Dataset,
encode_side: str,
description: str = "Encoding"
) -> tuple[np.ndarray, list, list, list, list, list]:  # CHANGED: also returns text masks and token records
    """
    Encodes embeddings for a given dataset using the model, handling both standard and
    late-interaction models in a DDP-safe manner.
    Returns:
        - embeddings: np.ndarray
        - infos_or_ids: list
        - batch_stats_list: list
        - img_token_masks: list[None | list[bool]]
        - txt_token_masks: list[None | list[bool]]
        - token_records: list[dict]  # per-sample pre/post/delta token counts
    """
local_rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
# Check if the model is a late-interaction type
is_late_interaction = (model_args.model_backbone == COLPALI)
local_embeds = []
local_gt_infos = []
local_max_len = 0
# --- New: List to store statistics for each batch ---
batch_stats_list = []
# --- NEW: Collect masks ---
local_img_token_masks = [] # post image mask per sample
local_txt_token_masks = [] # NEW: post text mask per sample
local_post_attn_masks = [] # NEW: post attention_mask per sample (after prune, 1/0)
# --- NEW: per-sample token reduction records ---
    local_token_records = []  # one dict per sample with pre/post/delta counts
model.eval()
# Register hooks for the model once per encode_embeddings call
registered_hooks = register_model_hooks(model)
    # --- NEW: helpers to fetch masks and make them serializable ---
def _search_key(obj, key: str):
        # Recursively search dicts/lists/tuples for the given key
if isinstance(obj, dict):
if key in obj:
return obj[key]
for v in obj.values():
r = _search_key(v, key)
if r is not None:
return r
elif isinstance(obj, (list, tuple)):
for v in obj:
r = _search_key(v, key)
if r is not None:
return r
return None
def _to_serializable_mask_list(mask_list, batch_size: int):
        # Convert a model-returned mask (list/tensor/ndarray/None) into [None | list[bool]] * B
if mask_list is None:
return [None] * batch_size
out = []
if isinstance(mask_list, (list, tuple)):
for m in mask_list:
if m is None:
out.append(None)
elif torch.is_tensor(m):
out.append(m.detach().cpu().tolist())
elif isinstance(m, np.ndarray):
out.append(m.tolist())
else:
# already python list/bool
out.append(m)
elif torch.is_tensor(mask_list):
            # A 2-D tensor (B, L) maps directly via tolist() -> list[list[bool/int]]
out = mask_list.detach().cpu().tolist()
elif isinstance(mask_list, np.ndarray):
out = mask_list.tolist()
else:
            # Unknown type: conservatively return None placeholders
out = [None] * batch_size
        # Align the length to batch_size
if isinstance(out, list):
if len(out) < batch_size:
out = out + [None] * (batch_size - len(out))
elif len(out) > batch_size:
out = out[:batch_size]
return out
def _to_bool_lists(m, batch_size: int):
lst = _to_serializable_mask_list(m, batch_size)
        # Normalize into list[list[bool] | None]
out = []
for x in lst:
if x is None:
out.append(None)
else:
                # x may be list[int] or list[bool]
out.append([bool(int(v)) for v in x])
return out
with torch.no_grad():
for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0):
# --- Reset statistics for each inference pass ---
timing_info.clear()
token_info["vision_tokens"] = 0
token_info["text_input_tokens"] = 0
token_info["text_output_tokens"] = 0
token_info["total_llm_input_tokens"] = 0
inputs = batch_to_device(inputs, training_args.device)
current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs and inputs['input_ids'] is not None else 1
with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
start_inference_time = time.time()
                # ---- NEW: enable/disable AOP per encode side ----
aop_cfg = getattr(model.encoder, "aop_prune_config", None)
_orig_enabled = None
if isinstance(aop_cfg, dict) and aop_cfg:
_orig_enabled = aop_cfg.get("enabled", False)
apply_to = aop_cfg.get("apply_to", "qry")
side_enable = (apply_to == "both") or (apply_to == encode_side)
aop_cfg["enabled"] = bool(side_enable and _orig_enabled)
setattr(model.encoder, "aop_prune_config", aop_cfg)
if encode_side == "qry":
output = model(qry=inputs)
reps = output["qry_reps"].detach()
local_gt_infos.extend(dataset_info)
else:
output = model(tgt=inputs)
reps = output["tgt_reps"].detach()
local_gt_infos.extend([info["cand_name"] for info in dataset_info])
                # ---- NEW: restore `enabled` so the next encode_side is unaffected ----
if isinstance(aop_cfg, dict) and _orig_enabled is not None:
aop_cfg["enabled"] = _orig_enabled
setattr(model.encoder, "aop_prune_config", aop_cfg)
end_inference_time = time.time()
            # --- NEW: extract the post-prune image/text masks and the post attention_mask ---
img_masks_raw = None
txt_masks_raw = None
post_attn_raw = None
if isinstance(output, dict):
img_masks_raw = _search_key(output, "image_token_bool_masks")
txt_masks_raw = _search_key(output, "text_token_bool_masks") # NEW
                post_attn_raw = _search_key(output, "post_attention_mask")  # NEW (our MMEBModel.forward attaches this key)
            # Compatibility: the masks may be attached to the model instead
if img_masks_raw is None and hasattr(model, "image_token_bool_masks"):
img_masks_raw = getattr(model, "image_token_bool_masks")
if txt_masks_raw is None and hasattr(model, "text_token_bool_masks"):
txt_masks_raw = getattr(model, "text_token_bool_masks")
if post_attn_raw is None and hasattr(model, "post_attention_mask"):
post_attn_raw = getattr(model, "post_attention_mask")
img_masks_serializable = _to_serializable_mask_list(img_masks_raw, current_batch_size)
txt_masks_serializable = _to_serializable_mask_list(txt_masks_raw, current_batch_size) # NEW
post_attn_serializable = _to_serializable_mask_list(post_attn_raw, current_batch_size) # NEW
local_img_token_masks.extend(img_masks_serializable)
local_txt_token_masks.extend(txt_masks_serializable) # NEW
local_post_attn_masks.extend(post_attn_serializable) # NEW
            # --- NEW: compute and accumulate pre/post/delta token counts for this batch ---
            cfg = getattr(model.encoder, "config", None)
            # Pre-prune masks come from `inputs` (before token dropping)
input_ids = inputs.get("input_ids", None)
attn2d_pre = inputs.get("attention_mask", None)
if input_ids is None or attn2d_pre is None or cfg is None:
                # Cannot compute stats; fall back to zeros
pre_vis_counts = [0] * current_batch_size
pre_txt_counts = [0] * current_batch_size
pre_tot_counts = [0] * current_batch_size
else:
iid = input_ids
am = attn2d_pre.to(torch.bool)
image_token_id = getattr(cfg, "image_token_id", None)
video_token_id = getattr(cfg, "video_token_id", None)
bos_id = getattr(cfg, "bos_token_id", None)
eos_id = getattr(cfg, "eos_token_id", None)
pad_id = getattr(cfg, "pad_token_id", None)
is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
is_vision = is_image | is_video
is_special = torch.zeros_like(iid, dtype=torch.bool)
for tid in [bos_id, eos_id, pad_id]:
if tid is not None and tid >= 0:
is_special |= (iid == tid)
pre_txt_mask = am & (~is_vision) & (~is_special)
pre_vis_mask = am & is_vision
pre_vis_counts = pre_vis_mask.sum(dim=1).tolist()
pre_txt_counts = pre_txt_mask.sum(dim=1).tolist()
pre_tot_counts = am.sum(dim=1).tolist()
            # Post-prune masks come from the model output; AND them with post_attn
post_text_masks = _to_bool_lists(txt_masks_raw, current_batch_size) # list[ list[bool] | None ]
post_image_masks = _to_bool_lists(img_masks_raw, current_batch_size)
post_attn_masks = _to_bool_lists(post_attn_raw, current_batch_size)
sum_pre_text = 0; sum_post_text = 0
sum_pre_vis = 0; sum_post_vis = 0
sum_pre_tot = 0; sum_post_tot = 0
for i in range(current_batch_size):
pre_text = int(pre_txt_counts[i]) if i < len(pre_txt_counts) else 0
pre_vis = int(pre_vis_counts[i]) if i < len(pre_vis_counts) else 0
pre_tot = int(pre_tot_counts[i]) if i < len(pre_tot_counts) else 0
                # Post counts: any of these masks may be None
m_text = post_text_masks[i] if post_text_masks is not None and i < len(post_text_masks) else None
m_img = post_image_masks[i] if post_image_masks is not None and i < len(post_image_masks) else None
m_attn = post_attn_masks[i] if post_attn_masks is not None and i < len(post_attn_masks) else None
if m_attn is None:
post_text = 0; post_vis = 0; post_tot = 0
else:
                    # Count entries that are True after ANDing with the attention mask
if m_text is not None:
post_text = sum(1 for a, t in zip(m_attn, m_text) if a and t)
else:
post_text = 0
if m_img is not None:
post_vis = sum(1 for a, v in zip(m_attn, m_img) if a and v)
else:
post_vis = 0
post_tot = sum(1 for a in m_attn if a)
                # Accumulate batch-level sums
sum_pre_text += pre_text; sum_post_text += post_text
sum_pre_vis += pre_vis; sum_post_vis += post_vis
sum_pre_tot += pre_tot; sum_post_tot += post_tot
                # Save the per-sample record (for JSONL output)
local_token_records.append({
"side": encode_side,
"pre": {"text": pre_text, "vision": pre_vis, "total": pre_tot},
"post": {"text": post_text, "vision": post_vis, "total": post_tot},
"delta":{"text": pre_text - post_text, "vision": pre_vis - post_vis, "total": pre_tot - post_tot},
})
# --- Update total LLM input tokens after the model call ---
if 'input_ids' in inputs and inputs['input_ids'] is not None:
token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1]
token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"]
token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"])
# --- Collect and Store Batch Statistics ---
batch_inference_time = end_inference_time - start_inference_time
current_batch_stats = {
"batch_size": current_batch_size,
"total_inference_time_seconds": batch_inference_time,
"module_inference_times": {},
"token_counts": {
"visual_tokens": token_info["vision_tokens"],
"language_input_tokens_raw": token_info["text_input_tokens"],
"llm_total_input_tokens": token_info["total_llm_input_tokens"],
"language_output_tokens": token_info["text_output_tokens"],
}
}
current_batch_stats["token_reduction"] = {
"sum_pre_text": sum_pre_text,
"sum_post_text": sum_post_text,
"sum_pre_vision": sum_pre_vis,
"sum_post_vision": sum_post_vis,
"sum_pre_total": sum_pre_tot,
"sum_post_total": sum_post_tot,
}
# Calculate and store module timings for the current batch
for module_obj in registered_hooks:
module_id = id(module_obj)
module_name = module_obj.__class__.__name__
times = timing_info.get(module_id, [])
durations = []
pre_times = {}
for t, event_type, _ in times:
if event_type == 'pre':
pre_times[module_id] = t
elif event_type == 'post' and module_id in pre_times:
duration = t - pre_times.pop(module_id)
durations.append(duration)
if durations:
current_batch_stats["module_inference_times"][module_name] = {
"total": sum(durations),
"count": len(durations),
"avg": sum(durations) / len(durations)
}
else:
current_batch_stats["module_inference_times"][module_name] = {
"total": 0.0,
"count": 0,
"avg": 0.0
}
batch_stats_list.append(current_batch_stats)
# --- Debug prints (optional) ---
print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---")
print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds")
print_rank("--- Module Inference Timing Statistics ---")
for module_name, stats in current_batch_stats["module_inference_times"].items():
print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s")
print_rank("--- Token Count Statistics ---")
            print_rank(f"**Vision token count**: {current_batch_stats['token_counts']['visual_tokens']}")
            print_rank(f"**Language input token count (raw text only)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}")
            print_rank(f"**Total LLM input token count (vision + formatted text)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}")
            print_rank(f"**Language output token count**: {current_batch_stats['token_counts']['language_output_tokens']}")
if is_late_interaction and reps.dim() == 3:
local_max_len = max(local_max_len, reps.shape[1])
local_embeds.append(reps)
    if not local_embeds:
        # Handle cases where a rank gets no data; mirror the six-value return of the normal path
        return np.array([]), [], [], [], [], []
# === DDP Synchronization and Padding for Late-Interaction Models ===
if is_late_interaction:
if dist.is_initialized():
# 1: global max length
local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device)
dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX)
global_max_len = local_max_len_tensor.item()
else:
global_max_len = local_max_len
# 2: pad to global max length
padded_embeds = []
for reps_batch in local_embeds:
if reps_batch.dim() == 3:
B, L, H = reps_batch.shape
padding_size = global_max_len - L
padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0)
padded_embeds.append(padded_batch)
else:
padded_embeds.append(reps_batch)
embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous()
else:
embeds_tensor = torch.cat(local_embeds, dim=0).contiguous()
# === Gather embeddings and keys from all ranks ===
if dist.is_initialized() and full_dataset.num_rows >= world_size:
print_master(f"Gathering {encode_side} embeddings across all ranks...")
# tensor gather
output_shape = list(embeds_tensor.shape)
output_shape[0] = full_dataset.num_rows
embeds_tensor = embeds_tensor.to(training_args.device)
gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device)
dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor)
final_embeddings = gathered_embeds_tensor.cpu().float().numpy()
# object gather for infos and stats
gathered_gt_infos = [None for _ in range(world_size)]
dist.all_gather_object(gathered_gt_infos, local_gt_infos)
all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys]
gathered_batch_stats = [None for _ in range(world_size)]
dist.all_gather_object(gathered_batch_stats, batch_stats_list)
all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats]
# --- NEW: gather masks ---
gathered_masks = [None for _ in range(world_size)]
dist.all_gather_object(gathered_masks, local_img_token_masks)
all_img_token_masks = [m for rank_list in gathered_masks for m in rank_list]
# NEW: gather text masks
gathered_txt_masks = [None for _ in range(world_size)]
dist.all_gather_object(gathered_txt_masks, local_txt_token_masks)
all_txt_token_masks = [m for rank_list in gathered_txt_masks for m in rank_list]
        # NEW: gather post attention masks (if needed)
gathered_post_attn = [None for _ in range(world_size)]
dist.all_gather_object(gathered_post_attn, local_post_attn_masks)
all_post_attn_masks = [m for rank_list in gathered_post_attn for m in rank_list]
# NEW: gather token records
gathered_token_recs = [None for _ in range(world_size)]
dist.all_gather_object(gathered_token_recs, local_token_records)
all_token_records = [r for rank_list in gathered_token_recs for r in rank_list]
else:
all_gt_infos = local_gt_infos
final_embeddings = embeds_tensor.cpu().float().numpy()
all_batch_stats = batch_stats_list
all_img_token_masks = local_img_token_masks # NEW
all_txt_token_masks = local_txt_token_masks
all_post_attn_masks = local_post_attn_masks
all_token_records = local_token_records
return final_embeddings, all_gt_infos, all_batch_stats, all_img_token_masks, all_txt_token_masks, all_token_records
# === NEW: export candidates' mid-layer and last-layer vectors in one forward pass ===
def encode_candidates_both_layers(
model: MMEBModel,
loader: DataLoader,
training_args: TrainingArguments,
model_args: ModelArguments,
full_dataset: Dataset,
mid_layer: int,
) -> tuple[np.ndarray, np.ndarray, list]:
"""
单次forward到最后一层,直接从 hidden_states 取:
- mid_hidden = hidden_states[mid_layer] # 表示经过 mid_layer 层后的状态(见Qwen2_5_VLModel的all_hidden_states定义)
- last_hidden = hidden_states[-1] # 最后一层norm后的状态
然后用 _pooling(attention_mask) 取句向量,返回:
- cand_mid_embeds: np.ndarray [Nc, D]
- cand_last_embeds: np.ndarray [Nc, D]
- cand_ids: list[str]
说明:
- cand 侧默认不做 AOP 剪枝(AOP_APPLY=qry 时天然关闭),因此 mid/last 的序列长度一致,可直接用原 attention_mask 做池化。
"""
local_rank = dist.get_rank() if dist.is_initialized() else 0
model.eval()
all_mid = []
all_last = []
all_ids = []
with torch.no_grad():
for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0):
inputs = batch_to_device(inputs, training_args.device)
            # Make sure AOP never fires on the cand side (with AOP_APPLY=qry/both the backbone already gates per side; this is an extra safeguard)
aop_cfg = getattr(model.encoder, "aop_prune_config", None)
_orig_enabled = None
if isinstance(aop_cfg, dict) and aop_cfg:
_orig_enabled = aop_cfg.get("enabled", False)
apply_to = aop_cfg.get("apply_to", "qry")
side_enable = (apply_to == "both") or (apply_to == "cand")
aop_cfg["enabled"] = bool(side_enable and _orig_enabled)
setattr(model.encoder, "aop_prune_config", aop_cfg)
with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                # Key step: one forward pass yields hidden_states for all layers
out = model.encoder(
**inputs,
return_dict=True,
                    output_hidden_states=True,  # required
                    stop_at_layer=None,         # run all layers
)
                # Grab hidden_states and index the mid / last layers
hs_list = out.hidden_states
assert hs_list is not None and len(hs_list) > mid_layer, \
f"hidden_states is None or too short. Need index {mid_layer}, got len={0 if hs_list is None else len(hs_list)}"
                mid_hs = hs_list[mid_layer]  # [B, L, D]: the state after mid_layer layers (i.e., pre-layer(mid_layer+1))
                last_hs = hs_list[-1]        # [B, L, D]: the final post-norm state
                # Pool with the original attention_mask (the cand side is unpruned)
am = inputs.get("attention_mask", None)
if am is not None and hasattr(am, "device"):
if am.device != mid_hs.device:
am = am.to(mid_hs.device)
reps_mid = model._pooling(mid_hs, am) # [B, D]
reps_last = model._pooling(last_hs, am) # [B, D]
all_mid.append(reps_mid.detach().float().cpu())
all_last.append(reps_last.detach().float().cpu())
all_ids.extend([info["cand_name"] for info in dataset_info])
            # Restore the AOP switch (so the other side is unaffected)
if isinstance(aop_cfg, dict) and _orig_enabled is not None:
aop_cfg["enabled"] = _orig_enabled
setattr(model.encoder, "aop_prune_config", aop_cfg)
if not all_mid:
return np.array([]), np.array([]), []
cand_mid_embeds = torch.cat(all_mid, dim=0).numpy()
cand_last_embeds = torch.cat(all_last, dim=0).numpy()
return cand_mid_embeds, cand_last_embeds, all_ids
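# A minimal sketch of the mid/last extraction above on dummy tensors: given a
# tuple of per-layer hidden states and a padding mask, pool the mid and final
# layers. The real code uses model._pooling (e.g. EOS pooling); the mean
# pooling here is only for illustration.
def _demo_both_layers_pooling(mid_layer: int = 2):
    B, L, D, n_layers = 2, 5, 8, 4
    hs_list = tuple(torch.randn(B, L, D) for _ in range(n_layers + 1))  # embeddings + one state per layer
    am = torch.ones(B, L)
    def mean_pool(hs, mask):
        mask = mask.unsqueeze(-1)
        return (hs * mask).sum(1) / mask.sum(1).clamp(min=1)
    reps_mid = mean_pool(hs_list[mid_layer], am)  # [B, D]
    reps_last = mean_pool(hs_list[-1], am)        # [B, D]
    return reps_mid, reps_last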
def main():
# ----------------------- Distributed init -----------------------
if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
local_rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
print_master("Distributed init debug info:")
print_master(f"RANK: {os.environ.get('RANK')}")
print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
if dist.is_initialized():
print_rank(f"dist.get_rank(): {dist.get_rank()}")
print_rank(f"dist.get_world_size(): {dist.get_world_size()}")
    # Compatibility with torchrun-style arguments
    for arg in list(sys.argv):  # iterate over a copy since sys.argv is mutated below
        if arg.startswith("--local-rank="):
            rank = arg.split("=")[1]
            sys.argv.remove(arg)
            sys.argv.append('--local_rank')
            sys.argv.append(rank)
# ----------------------- Parse args -----------------------
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: ModelArguments
data_args: DataArguments
training_args: TrainingArguments
os.makedirs(data_args.encode_output_path, exist_ok=True)
    # Multi-layer evaluation support (LM_LAYERS preferred; MID_LM_LAYER kept for compatibility)
layers_to_eval = get_env_eval_layers()
print_master(f"Eval layers (qry/tgt): {layers_to_eval}")
# ----------------------- Model loading -----------------------
hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
if not getattr(model_args, "model_backbone", None):
model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
setattr(model_args, 'model_backbone', model_backbone)
setattr(training_args, 'model_backbone', model_backbone)
print_master(f'Model Backbone: {model_args.model_backbone}')
    # Only rank 0 downloads; the other ranks wait for the cache
if local_rank == 0:
processor = load_processor(model_args, data_args)
model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...")
if torch.distributed.is_initialized():
torch.distributed.barrier()
if local_rank != 0:
print_rank(f"Loading the model from cache...")
processor = load_processor(model_args, data_args)
time.sleep(random.randint(2 * local_rank, 3 * local_rank))
model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
model.eval()
model = model.to(training_args.device, dtype=torch.bfloat16)
    # ---- NEW: inject the AOP pruning config (drives the AOP logic already implemented in the backbone) ----
aop_cfg = get_env_aop_config()
if aop_cfg["enabled"]:
        # Attach the config to the backbone; its forward reads this dict and performs the pruning
setattr(model.encoder, "aop_prune_config", aop_cfg)
        # Optional: override the attention implementation so the decision layer can expose attention or compute qk manually
attn_override = aop_cfg.get("attn_impl_override", "")
if attn_override:
try:
if hasattr(model.encoder, "model") and hasattr(model.encoder.model, "config"):
prev = model.encoder.model.config._attn_implementation
model.encoder.model.config._attn_implementation = attn_override
                    print_master(f"[AOP] override attn impl: {prev} -> {attn_override} (recommended for testing only)")
except Exception as e:
print_master(f"[AOP] try override attn impl failed: {e}")
print_master("[AOP] AOP-Prune enabled with config: " + json.dumps({
"apply_to": aop_cfg["apply_to"],
"layer_idx": aop_cfg["layer_idx"],
"mode": aop_cfg["mode"],
"delta": aop_cfg["delta"],
"K_hat": aop_cfg["K_hat"],
"keep_ratio": aop_cfg["keep_ratio"],
"min_keep": aop_cfg["min_keep"],
"use_bias": aop_cfg["use_bias"],
"margin_mid?": (aop_cfg["margin_mid"] is not None),
"prune_text": aop_cfg.get("prune_text", False),
"keep_ratio_text": aop_cfg.get("keep_ratio_text", None),
"keep_ratio_vision": aop_cfg.get("keep_ratio_vision", None),
"selection": aop_cfg.get("selection", "aop"),
"attn_agg": aop_cfg.get("attn_agg", "mean"),
}))
else:
print_master("[AOP] disabled (set AOP_ENABLED=1 to enable)")
    # Make sure no layers are cut in the "last layer" case (avoids the class's default-20-layers pitfall)
model.set_inference_layers(qry_layers=None, tgt_layers=None)
with open(data_args.dataset_config, 'r') as yaml_file:
dataset_configs = yaml.safe_load(yaml_file)
# ----------------------- Main evaluation loop -----------------------
for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()):
if dist.is_initialized():
dist.barrier()
print_master(f"\n--- Evaluating {dataset_name} ---")
        # Fix up paths relative to data_basedir
if data_args.data_basedir is not None:
for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
if data_args.data_basedir and task_config.get(key):
task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
        # Build the datasets
full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
if dist.is_initialized():
world_size = dist.get_world_size()
padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size)
padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size)
eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size)
eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size)
else:
padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
        # === EE-only: online early-exit inference only (first make sure both candidate embedding files exist) ===
ee_cfg = get_env_ee_config()
assert ee_cfg["enabled"], "EE_ENABLED must be 1 for EE-only pipeline."
        # Build the tag from EE_LAYER
mid_layer = int(ee_cfg["layer"])
mid_tag = make_layer_tag(mid_layer) # e.g., layer12
last_tag = "layerlast"
        # Prepare paths
cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{mid_tag}")
cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{last_tag}")
        # Build the cand DataLoader (one pass, no sharding)
eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
eval_cand_loader = DataLoader(
full_eval_cand_dataset,
batch_size=training_args.per_device_eval_batch_size,
collate_fn=eval_cand_collator,
num_workers=training_args.dataloader_num_workers
)
        # === Replaced with: one forward pass exporting both mid/last candidate embeddings ===
need_mid = (not os.path.exists(cand_mid_path))
need_last = (not os.path.exists(cand_last_path))
if need_mid or need_last:
print_master(f"[{dataset_name}] EE-only: encoding candidates BOTH layers in one pass (mid={mid_tag}, last={last_tag}) ...")
            # Run all layers (no truncation)
model.set_inference_layers(qry_layers=None, tgt_layers=None)
cand_embeds_mid, cand_embeds_last, all_cand_ids = encode_candidates_both_layers(
model=model,
loader=eval_cand_loader,
training_args=training_args,
model_args=model_args,
full_dataset=full_eval_cand_dataset,
mid_layer=mid_layer,
)
if local_rank == 0:
if need_mid:
cand_embed_dict_mid = {cid: emb for cid, emb in zip(all_cand_ids, cand_embeds_mid)}
with open(cand_mid_path, "wb") as f:
pickle.dump(cand_embed_dict_mid, f)
print_master(f"[{dataset_name}] EE-only: saved {mid_tag} candidate embeddings -> {cand_mid_path}")
if need_last:
cand_embed_dict_last = {cid: emb for cid, emb in zip(all_cand_ids, cand_embeds_last)}
with open(cand_last_path, "wb") as f:
pickle.dump(cand_embed_dict_last, f)
print_master(f"[{dataset_name}] EE-only: saved {last_tag} candidate embeddings -> {cand_last_path}")
else:
print_master(f"[{dataset_name}] EE-only: reuse existing candidates (mid={cand_mid_path}, last={cand_last_path})")
if dist.is_initialized():
dist.barrier()
        # 3) Online early-exit gating + subset resume (no offline per-layer scoring/curves)
if local_rank == 0:
with open(cand_mid_path, "rb") as f:
cand_mid_dict = pickle.load(f)
with open(cand_last_path, "rb") as f:
cand_last_dict = pickle.load(f)
rank_global = task_config.get("eval_type", "global") == "global"
print_master(f"[{dataset_name}] Run ONLINE early-exit at layer={ee_cfg['layer']}, method={ee_cfg['method']}, tau={ee_cfg['tau']}, topk={ee_cfg['topk']}, global={rank_global}")
run_early_exit_queries(
model=model,
processor=processor,
model_args=model_args,
data_args=data_args,
training_args=training_args,
                qry_dataset=full_eval_qry_dataset,  # all queries
cand_mid_dict=cand_mid_dict,
cand_last_dict=cand_last_dict,
ee_cfg=ee_cfg,
dataset_name=dataset_name,
out_dir=data_args.encode_output_path,
global_ranking=rank_global,
)
if dist.is_initialized():
dist.barrier()
        # === End of EE-only; move on to the next dataset ===
continue
if __name__ == '__main__':
main()
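# A hypothetical launch sketch (shell command shown as a comment). The flag
# names follow the ModelArguments/DataArguments fields used above, but all
# paths and values are placeholders; adjust them to your setup:
#   EE_ENABLED=1 EE_LAYER=12 EE_METHOD=margin EE_TAU=0.2 \
#   AOP_ENABLED=1 AOP_APPLY=qry AOP_LAYER=8 \
#   torchrun --nproc_per_node=8 eval_test_time_early_exit.py \
#     --model_name <hf-model-id> --dataset_config <config.yaml> \
#     --encode_output_path <out-dir>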