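# EE-only evaluation driver: run queries up to an intermediate layer, gate them with
# a confidence score, and resume the remaining layers only for low-confidence queries.
# Optional AOP token pruning is executed inside the Qwen2-VL backbone
# (Qwen2-VLModel.forward); this script only parses the env-var switches that drive it
# and orchestrates candidate encoding, gating, and retrieval metrics.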
import datetime
import logging
import json
import random
import time
import numpy as np
import os
import pickle
import sys
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
import transformers
import math
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
from datasets import Dataset, concatenate_datasets
from datasets.distributed import split_dataset_by_node

from src.model.vlm_backbone.qwen2_vl.modeling_qwen2_vl_train_tokrnpooling import Qwen2VLForConditionalGeneration as _Qwen2VLForConditionalGeneration_src
from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.eval_collator import MultimodalEvalDataCollator
from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
from src.eval_utils.metrics import RankingMetrics
from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel
from src.model.processor import get_backbone_name, load_processor, COLPALI
from src.utils import batch_to_device, print_rank, print_master
from dataclasses import dataclass
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)


def get_env_mid_layer():
    v = os.environ.get("MID_LM_LAYER", "").strip()
    if v == "" or v.lower() in {"none", "null"}:
        return None
    try:
        return int(v)
    except ValueError:
        logger.warning(f"Invalid MID_LM_LAYER={v}, ignore.")
        return None


# ------------- AOP-Prune config parsing -------------
def _parse_bool(v: str, default=False):
    if v is None:
        return default
    v = v.strip().lower()
    return v in {"1", "true", "yes", "y", "t", "on"}


def _parse_float(v: str, default=None):
    try:
        return float(v) if v is not None else default
    except (TypeError, ValueError):
        return default


def _parse_int(v: str, default=None):
    try:
        return int(v) if v is not None else default
    except (TypeError, ValueError):
        return default
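# Illustrative shell setup (values are examples, not defaults) driving the config
# parsed by get_env_aop_config() below: prune vision tokens at layer 12 on the
# query side, keeping at least half of them. The launch command is hypothetical.
#
#   AOP_ENABLED=1 AOP_APPLY=qry AOP_LAYER=12 \
#   AOP_PRUNE_VISION=1 AOP_KEEP_RATIO_VISION=0.5 AOP_MIN_KEEP_VISION=64 \
#   torchrun --nproc_per_node=8 eval_script.py ...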
def get_env_aop_config():
    """
    Read the AOP pruning config from environment variables. This is only a thin
    driver-level switch for quick experiments; the actual pruning logic lives in
    the backbone (Qwen2-VLModel.forward).
    """
    enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False)
    apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower()  # qry|cand|both
    layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None)
    mode = os.environ.get("AOP_MODE", "delta").strip().lower()
    # Generic fallbacks
    delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10)
    khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0)
    keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0)
    min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64)
    use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True)
    # Per-modality switches
    prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True)
    prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False)
    delta_v = _parse_float(os.environ.get("AOP_DELTA_VISION"), None)
    khat_v = _parse_float(os.environ.get("AOP_KHAT_VISION"), None)
    keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None)
    min_keep_v = _parse_int(os.environ.get("AOP_MIN_KEEP_VISION"), None)
    delta_t = _parse_float(os.environ.get("AOP_DELTA_TEXT"), None)
    khat_t = _parse_float(os.environ.get("AOP_KHAT_TEXT"), None)
    keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None)
    min_keep_t = _parse_int(os.environ.get("AOP_MIN_KEEP_TEXT"), 32)
    protect_text_last = _parse_int(os.environ.get("AOP_PROTECT_TEXT_LAST"), 16)
    protect_special = _parse_bool(os.environ.get("AOP_PROTECT_SPECIAL"), True)
    margin_src = os.environ.get("AOP_MARGIN", "").strip().lower()  # "" or "mid"
    attn_impl = os.environ.get("AOP_ATTN_IMPL", "").strip().lower()  # "" or "sdpa"
    if layer_idx is None and enabled:
        logger.warning("AOP_ENABLED=1 but AOP_LAYER is not set; disabling AOP.")
        enabled = False

    # Selection strategy: aop | random | attention
    selection = os.environ.get("AOP_SELECTION", "aop").strip().lower()
    if _parse_bool(os.environ.get("AOP_RANDOM"), False):
        selection = "random"
    random_seed = _parse_int(os.environ.get("AOP_RANDOM_SEED"), None)
    attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower()  # mean|max|sum

    cfg = {
        "enabled": enabled,
        "apply_to": apply_to,
        "layer_idx": layer_idx,
        "mode": mode,
        # Fallbacks
        "delta": delta, "K_hat": khat,
        "keep_ratio": keep_ratio, "min_keep": min_keep,
        "use_bias": use_bias, "eps": 1e-6,
        # Per-modality switches
        "prune_vision": prune_vision,
        "prune_text": prune_text,
        # Vision bucket
        "delta_vision": delta_v,
        "K_hat_vision": khat_v,
        "keep_ratio_vision": keep_ratio_v,
        "min_keep_vision": min_keep_v,
        # Text bucket
        "delta_text": delta_t,
        "K_hat_text": khat_t,
        "keep_ratio_text": keep_ratio_t,
        "min_keep_text": min_keep_t,
        # Text protection
        "protect_text_last": protect_text_last,
        "protect_special": protect_special,
        # Optional: ranking safety budget
        "margin_mid": None if margin_src != "mid" else "USE_MID_MARGIN",
        "epsilon_hat": None,
        "attn_impl_override": attn_impl if attn_impl in {"sdpa"} else "",
        # NEW: selection strategy
        "selection": selection,      # "aop" or "random"
        "random_seed": random_seed,  # optional
        "attn_agg": attn_agg,
    }
    return cfg


def get_env_eval_layers():
    """
    Parse the LM_LAYERS env var (preferred) or the legacy MID_LM_LAYER.
    - LM_LAYERS example: "4,8,12,last"; 'last'/'none'/'null'/'-1' all denote the last layer (None).
    - If LM_LAYERS is unset, fall back to the old logic: MID_LM_LAYER=None -> [None], otherwise [mid, None].
    Returns list[int | None], e.g. [4, 8, 12, None]; None stands for the last layer.
    """
    v = os.environ.get("LM_LAYERS", None)
    if v is not None:
        v = v.strip()
        if v:
            toks = [t.strip() for t in v.split(',') if t.strip() != ""]
            layers = []
            for tok in toks:
                tl = tok.lower()
                if tl in {"last", "none", "null", "-1"}:
                    layers.append(None)
                else:
                    try:
                        val = int(tok)
                        if val > 0:
                            layers.append(val)
                        else:
                            logger.warning(f"Ignoring non-positive layer '{tok}' in LM_LAYERS.")
                    except Exception:
                        logger.warning(f"Invalid token '{tok}' in LM_LAYERS; must be int or 'last'/'none'.")
            # De-duplicate while preserving order
            seen = set()
            uniq = []
            for l in layers:
                key = -1 if l is None else l
                if key in seen:
                    continue
                seen.add(key)
                uniq.append(l)
            return uniq if uniq else [None]
    # Legacy fallback (also used when LM_LAYERS is set but empty)
    mid = get_env_mid_layer()
    return [None] if mid is None else [mid, None]


# === Early-Exit config & helpers ===
def get_env_ee_config():
    ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in {"1", "true", "yes", "on", "y", "t"}
    layer = int(os.environ.get("EE_LAYER", os.environ.get("AOP_LAYER", "12")))  # defaults to AOP_LAYER
    method = os.environ.get("EE_METHOD", "margin").strip().lower()  # margin|p1p2|entropy|gini|combined
    tau = float(os.environ.get("EE_TAU", "0.2"))
    topk = int(os.environ.get("EE_TOPK", "1024"))
    temp = float(os.environ.get("EE_TEMP", "0.05"))
    save = os.environ.get("EE_SAVE", "1").strip().lower() in {"1", "true", "yes", "on", "y", "t"}
    combw = os.environ.get("EE_COMB_WEIGHTS", "1.0,0.5,0.5")
    try:
        w_margin, w_conf, w_sq = [float(x) for x in combw.split(",")]
    except Exception:
        w_margin, w_conf, w_sq = 1.0, 0.5, 0.5
    return dict(
        enabled=ee_enabled, layer=layer, method=method, tau=tau,
        topk=topk, temp=temp, save=save,
        w_margin=w_margin, w_conf=w_conf, w_sq=w_sq
    )


def _softmax_np(x: np.ndarray, temp: float = 1.0) -> np.ndarray:
    x = x - np.max(x)
    ex = np.exp(x / max(1e-6, temp))
    s = np.sum(ex)
    return ex / max(s, 1e-12)
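# Worked example for confidence_from_topk below (numbers are illustrative):
# for descending top-K scores s = [0.82, 0.60, 0.55] and temp = 0.05,
#   margin = 0.82 - 0.60 = 0.22
#   p      = _softmax_np(s, temp=0.05) ≈ [0.983, 0.012, 0.004]
#   p1p2   ≈ 0.971
# so with EE_METHOD=margin and the default EE_TAU=0.2, this query would pass
# the gate (0.22 >= 0.2) and exit early.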
# === Early-Exit config & helpers ===
def get_env_ee_config():
    ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in {"1", "true", "yes", "on", "y", "t"}
    layer = int(os.environ.get("EE_LAYER", os.environ.get("AOP_LAYER", "12")))  # defaults to AOP_LAYER
    method = os.environ.get("EE_METHOD", "margin").strip().lower()  # margin|p1p2|entropy|gini|combined
    tau = float(os.environ.get("EE_TAU", "0.2"))
    topk = int(os.environ.get("EE_TOPK", "1024"))
    temp = float(os.environ.get("EE_TEMP", "0.05"))
    save = os.environ.get("EE_SAVE", "1").strip().lower() in {"1", "true", "yes", "on", "y", "t"}
    combw = os.environ.get("EE_COMB_WEIGHTS", "1.0,0.5,0.5")
    try:
        w_margin, w_conf, w_sq = [float(x) for x in combw.split(",")]
    except Exception:
        w_margin, w_conf, w_sq = 1.0, 0.5, 0.5
    return dict(
        enabled=ee_enabled, layer=layer, method=method, tau=tau, topk=topk,
        temp=temp, save=save, w_margin=w_margin, w_conf=w_conf, w_sq=w_sq
    )


def _softmax_np(x: np.ndarray, temp: float = 1.0) -> np.ndarray:
    x = x - np.max(x)
    ex = np.exp(x / max(1e-6, temp))
    s = np.sum(ex)
    return ex / max(s, 1e-12)


def confidence_from_topk(scores: np.ndarray, method="margin", temp=0.05,
                         w_margin=1.0, w_conf=0.5, w_sq=0.5) -> float:
    # scores: top-K scores, already sorted in descending order
    if scores.size == 0:
        return 0.0
    if scores.size == 1:
        return 1e9
    margin = float(scores[0] - scores[1])
    p = _softmax_np(scores, temp=temp)
    p1p2 = float(p[0] - p[1])
    H = -float(np.sum(p * np.log(p + 1e-12))) / np.log(len(p))  # normalized entropy in [0, 1]
    conf = 1.0 - H
    sqsum = float(np.sum(p ** 2))  # Gini-style concentration (larger = more peaked)
    if method == "margin":
        return margin
    if method == "p1p2":
        return p1p2
    if method == "entropy":
        return conf
    if method == "gini":
        return sqsum
    # combined
    return w_margin * margin + w_conf * conf + w_sq * sqsum
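# Illustrative only: the gate statistics on a toy, already-sorted score vector.
#
#     s = np.array([0.90, 0.70, 0.10], dtype=np.float32)
#     confidence_from_topk(s, method="margin")           # ~0.20 (top-1 minus top-2)
#     confidence_from_topk(s, method="p1p2", temp=0.05)  # softmax gap of the top-2
#     # "entropy"/"gini" reward peaked distributions; "combined" mixes all three
#     # with the (w_margin, w_conf, w_sq) weights.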
def run_early_exit_queries(
    model: MMEBModel,
    processor,
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
    qry_dataset: Dataset,
    cand_mid_dict: dict,
    cand_last_dict: dict,
    ee_cfg: dict,
    dataset_name: str,
    out_dir: str,
    global_ranking: bool = True,
):
    """
    Online early-exit inference only (no curve plotting). In addition to the
    retrieval predictions it records:
      - per-query top-K similarities between the mid-layer query and cand_last (mid2last)
      - for queries that did NOT exit early, top-K similarities between the
        last-layer query and cand_last (last2last)
    Output files:
      {out_dir}/{dataset}_score_earlyexit.json  - retrieval metrics
      {out_dir}/{dataset}_pred_earlyexit.jsonl  - prediction lists (as before)
      {out_dir}/{dataset}_sim_earlyexit.jsonl   - similarity records added here (requires EE_SAVE=1)
    """
    device = training_args.device
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    is_main = (not dist.is_initialized()) or (local_rank == 0)

    # Candidate matrices
    cand_ids = list(cand_mid_dict.keys())
    cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
    cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
    cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32)
    cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)
    cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16)

    # Query DataLoader
    collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
    loader = DataLoader(
        qry_dataset,
        batch_size=training_args.per_device_eval_batch_size,
        collate_fn=collator,
        num_workers=training_args.dataloader_num_workers
    )

    pred_dicts = []
    # Whether to save similarities (reuses EE_SAVE)
    save_scores = ee_cfg.get("save", False)
    topk_sim = int(ee_cfg.get("topk", 1024))
    sim_records = [] if (save_scores and is_main) else None

    # Enable AOP per side
    aop_cfg = getattr(model.encoder, "aop_prune_config", None)
    _orig_enabled = None
    side_enable = True
    if isinstance(aop_cfg, dict) and aop_cfg:
        _orig_enabled = aop_cfg.get("enabled", False)
        apply_to = aop_cfg.get("apply_to", "qry")
        side_enable = (apply_to == "both") or (apply_to == "qry")

    # Gate parameters
    k_conf = 2
    tau = float(ee_cfg["tau"])
    method = ee_cfg["method"]
    temp = float(ee_cfg["temp"])

    idx_global = 0
    start_time = time.time()
    for inputs, infos in tqdm(
        loader,
        desc=f"[EE] {dataset_name}@L{ee_cfg['layer']} (rank {local_rank})",
        disable=local_rank > 0,
    ):
        inputs = batch_to_device(inputs, device)

        # -------- 1) Run up to the mid layer (stop_at_layer); skip logits --------
        orig_cfg = None
        if isinstance(aop_cfg, dict) and aop_cfg:
            orig_cfg = dict(aop_cfg)
            aop_layer = aop_cfg.get("layer_idx", None)
            ee_layer = int(ee_cfg["layer"])
            apply_to = aop_cfg.get("apply_to", "qry").strip().lower()
            aop_on_mid = bool(
                _orig_enabled and side_enable
                and (aop_layer is not None)
                and (aop_layer < ee_layer)
                and (apply_to in {"qry", "both"})
            )
            aop_cfg_mid = dict(aop_cfg)
            aop_cfg_mid["enabled"] = aop_on_mid
            setattr(model.encoder, "aop_prune_config", aop_cfg_mid)

        with torch.no_grad(), torch.autocast(
            device_type="cuda", dtype=torch.bfloat16, enabled=True
        ):
            out_mid = model.encoder(
                **inputs,
                return_dict=True,
                output_hidden_states=False,
                stop_at_layer=int(ee_cfg["layer"]),
                compute_lm_head=False,
            )
        if isinstance(orig_cfg, dict):
            setattr(model.encoder, "aop_prune_config", orig_cfg)

        # EOS pooling to obtain the mid-layer representation
        hs_mid = getattr(out_mid, "last_hidden_state", None)
        if hs_mid is None:
            assert out_mid.hidden_states is not None and len(out_mid.hidden_states) > 0
            hs_mid = out_mid.hidden_states[-1]
        am_mid = getattr(out_mid, "attention_mask", None)
        if am_mid is None:
            am_mid = inputs.get("attention_mask", None)
        if hasattr(am_mid, "device") and am_mid.device != hs_mid.device:
            am_mid = am_mid.to(hs_mid.device)
        reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(device=device, dtype=torch.bfloat16)  # [B, D]
        B = reps_mid_t.size(0)
        use_local = (not global_ranking)

        # -------- 2) Gate: top-2 scores of mid→mid --------
        if not use_local:
            # Full corpus: cand_mid_t
            scores_t = reps_mid_t @ cand_mid_t.T  # [B, Nc]
            vals_t, idxs_t = torch.topk(
                scores_t, k=min(k_conf, scores_t.size(1)), dim=1
            )  # [B, 2]
            p_t = torch.softmax(vals_t / max(temp, 1e-6), dim=1)
            if vals_t.size(1) >= 2:
                margin_t = vals_t[:, 0] - vals_t[:, 1]
                p1p2_t = p_t[:, 0] - p_t[:, 1]
            else:
                margin_t = torch.full((B,), float("inf"), device=device, dtype=vals_t.dtype)
                p1p2_t = torch.ones(B, device=device, dtype=vals_t.dtype)
            # max(..., 2) guards against log(1) = 0 in the single-candidate case
            H_t = -(p_t * (torch.log(p_t + 1e-12))).sum(dim=1) / math.log(max(vals_t.size(1), 2))
            conf_map = {
                "margin": margin_t,
                "p1p2": p1p2_t,
                "entropy": 1.0 - H_t,
                "gini": (p_t ** 2).sum(dim=1),
            }
            confs_t = conf_map.get(method, margin_t)
            exit_mask = (confs_t >= tau).detach().cpu().numpy().astype(bool)
        else:
            # Local: score each query against its own cand_mid_t[rows]
            confs = []
            for b in range(B):
                cand_local = infos[b]["cand_names"]
                rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                rows = [r for r in rows if r >= 0]
                if len(rows) == 0:
                    confs.append(0.0)
                    continue
                cmat_t = cand_mid_t[rows]                    # [Nl, D]
                sv_t = (reps_mid_t[b:b + 1] @ cmat_t.T)[0]   # [Nl]
                k = 2 if sv_t.size(0) >= 2 else 1
                vals_t, _ = torch.topk(sv_t, k=k, dim=0)
                p_t = torch.softmax(vals_t / max(temp, 1e-6), dim=0)
                if k >= 2:
                    margin = (vals_t[0] - vals_t[1]).item()
                    p1p2 = (p_t[0] - p_t[1]).item()
                else:
                    margin, p1p2 = float("inf"), 1.0
                # max(k, 2) guards against log(1) = 0 when only one candidate exists
                H = (-(p_t * (torch.log(p_t + 1e-12))).sum() / math.log(max(k, 2))).item()
                gini = ((p_t ** 2).sum()).item()
                d = {"margin": margin, "p1p2": p1p2, "entropy": 1.0 - H, "gini": gini}
                confs.append(d.get(method, margin))
            exit_mask = (np.array(confs) >= tau)

        # -------- 3) Retrieval + similarity records --------
        exit_indices = np.where(exit_mask)[0].tolist()   # early-exit samples
        need_indices = np.where(~exit_mask)[0].tolist()  # samples that continue

        # A. Early exit: rank by mid→mid, but additionally compute top-K mid→last similarities
        for j in exit_indices:
            # 1) ranking (pred_dicts)
            if not use_local:
                scores_mid_mid = (reps_mid_t[j:j + 1] @ cand_mid_t.T)[0]  # [Nc]
                order = torch.argsort(scores_mid_mid, dim=0, descending=True).detach().cpu().numpy()
                cids = [cand_ids[i] for i in order]
            else:
                cand_local = infos[j]["cand_names"]
                rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                rows = [r for r in rows if r >= 0]
                if len(rows) == 0:
                    cids = []
                else:
                    cmat_t = cand_mid_t[rows]
                    sv = (reps_mid_t[j:j + 1] @ cmat_t.T)[0]
                    order_local = torch.argsort(sv, dim=0, descending=True).detach().cpu().numpy()
                    cids = [str(cand_local[i]) for i in order_local]
            rel_docids = infos[j]["label_name"]
            if not isinstance(rel_docids, list):
                rel_docids = [rel_docids]
            pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})

            # 2) similarity record: mid→last
            if save_scores and is_main:
                if not use_local:
                    scores_mid_last = (reps_mid_t[j:j + 1] @ cand_last_t.T)[0].detach().float().cpu()  # [Nc]
                    Nc = scores_mid_last.size(0)
                    K = min(topk_sim, Nc)
                    mid_vals, mid_inds = torch.topk(scores_mid_last, k=K, dim=0)
                    mid_ids = [cand_ids[i] for i in mid_inds.tolist()]
                else:
                    cand_local = infos[j]["cand_names"]
                    rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                    rows = [r for r in rows if r >= 0]
                    if len(rows) == 0:
                        mid_vals = torch.empty(0)
                        mid_ids = []
                    else:
                        cmat_last = cand_last_t[rows]  # [Nl, D]
                        sv_last = (reps_mid_t[j:j + 1] @ cmat_last.T)[0].detach().float().cpu()
                        K = min(topk_sim, sv_last.size(0))
                        mid_vals, mid_inds = torch.topk(sv_last, k=K, dim=0)
                        mid_ids = [str(cand_local[i]) for i in mid_inds.tolist()]
                rec = {
                    "qid": int(idx_global + j),
                    "early_exit": True,
                    "mid_topk_scores": mid_vals.tolist() if mid_vals.numel() > 0 else [],
                    "mid_topk_cand_ids": mid_ids,
                    "last_topk_scores": None,
                    "last_topk_cand_ids": None,
                }
                sim_records.append(rec)

        # B. Continue: resume mid->last, then rank by last→last; also record
        #    mid→last & last→last similarities
        if len(need_indices) > 0:
            # Resume from the intermediate state
            if isinstance(aop_cfg, dict) and aop_cfg:
                aop_resume = dict(aop_cfg)
                aop_resume["enabled"] = bool(_orig_enabled and side_enable)
                setattr(model.encoder, "aop_prune_config", aop_resume)
            interm = getattr(out_mid, "intermediate_state", None)
            assert interm is not None, "Model must return intermediate_state when stop_at_layer is set."
            hs = interm["hidden_states"].detach()
            am = interm["attention_mask"].detach()
            pos = interm["position_ids"].detach()
            vm = interm.get("vision_mask", None)
            tm = interm.get("text_mask", None)
            next_layer = int(interm["next_layer_idx"])
            hs_sub = hs[need_indices]
            am_sub = am[need_indices]
            pos_sub = pos[:, need_indices, :]
            vm_sub = vm[need_indices] if vm is not None else None
            tm_sub = tm[need_indices] if tm is not None else None
            resume_state = {
                "hidden_states": hs_sub,
                "attention_mask": am_sub,
                "position_ids": pos_sub,
                "vision_mask": vm_sub,
                "text_mask": tm_sub,
                "next_layer_idx": next_layer,
            }
            with torch.no_grad(), torch.autocast(
                device_type="cuda", dtype=torch.bfloat16, enabled=True
            ):
                out_last = model.encoder(
                    return_dict=True,
                    output_hidden_states=False,
                    stop_at_layer=None,
                    resume_state=resume_state,
                    compute_lm_head=False,
                )
            hs_last = getattr(out_last, "last_hidden_state", None)
            if hs_last is None:
                assert out_last.hidden_states is not None and len(out_last.hidden_states) > 0
                hs_last = out_last.hidden_states[-1]
            am_last = getattr(out_last, "attention_mask", None)
            if am_last is None:
                am_last = am_sub
            if hasattr(am_last, "device") and am_last.device != hs_last.device:
                am_last = am_last.to(hs_last.device)
            reps_last_t = model._pooling(hs_last, am_last).detach().to(device=device, dtype=torch.bfloat16)

            if not use_local:
                scores_last_all = (reps_last_t @ cand_last_t.T).detach().float().cpu()  # [N_need, Nc]
                for k, j in enumerate(need_indices):
                    # 1) ranking
                    row = scores_last_all[k]
                    order = torch.argsort(row, dim=0, descending=True).tolist()
                    cids = [cand_ids[i] for i in order]
                    rel_docids = infos[j]["label_name"]
                    if not isinstance(rel_docids, list):
                        rel_docids = [rel_docids]
                    pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})

                    # 2) mid→last & last→last similarities
                    if save_scores and is_main:
                        # mid2last
                        scores_mid_last = (reps_mid_t[j:j + 1] @ cand_last_t.T)[0].detach().float().cpu()
                        Nc = scores_mid_last.size(0)
                        K = min(topk_sim, Nc)
                        mid_vals, mid_inds = torch.topk(scores_mid_last, k=K, dim=0)
                        mid_ids = [cand_ids[i] for i in mid_inds.tolist()]
                        # last2last
                        last_row = row
                        last_vals, last_inds = torch.topk(last_row, k=K, dim=0)
                        last_ids = [cand_ids[i] for i in last_inds.tolist()]
                        rec = {
                            "qid": int(idx_global + j),
                            "early_exit": False,
                            "mid_topk_scores": mid_vals.tolist(),
                            "mid_topk_cand_ids": mid_ids,
                            "last_topk_scores": last_vals.tolist(),
                            "last_topk_cand_ids": last_ids,
                        }
                        sim_records.append(rec)
            else:
                # local ranking
                for k, j in enumerate(need_indices):
                    cand_local = infos[j]["cand_names"]
                    rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                    rows = [r for r in rows if r >= 0]
                    if len(rows) == 0:
                        cids = []
                        rel_docids = infos[j]["label_name"]
                        if not isinstance(rel_docids, list):
                            rel_docids = [rel_docids]
                        pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})
                        if save_scores and is_main:
                            rec = {
                                "qid": int(idx_global + j),
                                "early_exit": False,
                                "mid_topk_scores": [],
                                "mid_topk_cand_ids": [],
                                "last_topk_scores": [],
                                "last_topk_cand_ids": [],
                            }
                            sim_records.append(rec)
                        continue
                    # 1) ranking (last→last)
                    cmat_last = cand_last_t[rows]
                    sv_last = (reps_last_t[k:k + 1] @ cmat_last.T)[0].detach().float().cpu()
                    order_local = torch.argsort(sv_last, dim=0, descending=True).tolist()
                    cids = [str(cand_local[i]) for i in order_local]
                    rel_docids = infos[j]["label_name"]
                    if not isinstance(rel_docids, list):
                        rel_docids = [rel_docids]
                    pred_dicts.append({"prediction": cids, "label": rel_docids, "rel_scores": None})
                    # 2) mid→last & last→last similarities
                    if save_scores and is_main:
                        # mid2last (reuse cmat_last computed for the ranking step above)
                        sv_mid_last = (reps_mid_t[j:j + 1] @ cmat_last.T)[0].detach().float().cpu()
                        K = min(topk_sim, sv_mid_last.size(0))
                        mid_vals, mid_inds = torch.topk(sv_mid_last, k=K, dim=0)
                        mid_ids = [str(cand_local[i]) for i in mid_inds.tolist()]
                        # last2last
                        sv_last_row = sv_last
                        last_vals, last_inds = torch.topk(sv_last_row, k=K, dim=0)
                        last_ids = [str(cand_local[i]) for i in last_inds.tolist()]
                        rec = {
                            "qid": int(idx_global + j),
                            "early_exit": False,
                            "mid_topk_scores": mid_vals.tolist(),
                            "mid_topk_cand_ids": mid_ids,
                            "last_topk_scores": last_vals.tolist(),
                            "last_topk_cand_ids": last_ids,
                        }
                        sim_records.append(rec)

        idx_global += B

    # -------- 4) Evaluate + write out --------
    metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
    score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)
    if is_main:
        os.makedirs(out_dir, exist_ok=True)
        # Existing early-exit retrieval results
        with open(os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), "w") as f:
            json.dump(score, f, indent=4)
        if ee_cfg.get("save", False):
            with open(os.path.join(out_dir, f"{dataset_name}_pred_earlyexit.jsonl"), "w", encoding="utf-8") as f:
                for row in pred_dicts:
                    f.write(json.dumps(row, ensure_ascii=False) + "\n")
        # New mid/last similarity output
        if save_scores and sim_records is not None:
            sims_path = os.path.join(out_dir, f"{dataset_name}_sim_earlyexit.jsonl")
            with open(sims_path, "w", encoding="utf-8") as f:
                for rec in sim_records:
                    f.write(json.dumps(rec, ensure_ascii=False) + "\n")
            print_master(f"[EE] Saved mid/last similarity records -> {sims_path}")
    elapsed = time.time() - start_time
    return score, elapsed


def make_layer_tag(keep_layers: int | None):
    return f"layer{keep_layers}" if keep_layers and keep_layers > 0 else "layerlast"


def dot_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # a: [Nq, D], b: [Nc, D]; both already L2-normalized if normalize=true
    return a @ b.T


def build_score_details(qid: int, cand_ids: list, score_vec: np.ndarray, ranked_indices: np.ndarray):
    return {
        "qid": int(qid),
        "cand_scores": [
            {"cand_id": str(cand_ids[i]), "score": float(score_vec[i])}
            for i in ranked_indices
        ]
    }


def top1_top2_margin(score_vec: np.ndarray) -> float:
    if len(score_vec) < 2:
        return float("inf")  # a single candidate counts as an infinitely large margin
    top2 = np.partition(score_vec, -2)[-2:]
    top2.sort()
    return float(top2[-1] - top2[-2])
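# Illustrative only: the margin statistic on toy score vectors.
#
#     top1_top2_margin(np.array([0.90, 0.70, 0.10]))  # -> ~0.20 (top-1 minus top-2)
#     top1_top2_margin(np.array([0.90]))              # -> inf   (single candidate)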
def simulate_early_exit_by_margin(
    sims_mid: list[dict],
    sims_last: list[dict],
    labels: list[list[str]],
    metrics_to_report: list[str],
    taus: list[float],
    rank_global: bool
):
    """
    sims_mid / sims_last: one dict per query: {cand_id: score}
    labels: list of positive cand_ids per query
    Returns: coverage and metrics for each tau
    """
    assert len(sims_mid) == len(sims_last) == len(labels)
    N = len(labels)
    results = []
    from src.eval_utils.metrics import RankingMetrics
    metrics = RankingMetrics(metrics_to_report)

    # Pre-build the pred_dicts consumed by metrics.evaluate
    def to_pred_dicts(use_mid_mask: list[bool]) -> list[dict]:
        pred_dicts = []
        for qid in range(N):
            sims_use = sims_mid[qid] if use_mid_mask[qid] else sims_last[qid]
            # Sort candidates by descending score
            ranked = sorted(sims_use.items(), key=lambda x: -x[1])
            pred_dicts.append({
                "prediction": [cid for cid, _ in ranked],
                "label": labels[qid],
                "rel_scores": None
            })
        return pred_dicts

    # Mid-layer margins (top-1 minus top-2)
    margins = []
    for qid in range(N):
        if len(sims_mid[qid]) == 0:
            margins.append(0.0)
            continue
        scores = np.array(list(sims_mid[qid].values()), dtype=np.float32)
        margins.append(top1_top2_margin(scores))
    margins = np.array(margins, dtype=np.float32)

    for tau in taus:
        use_mid_mask = (margins >= tau).tolist()
        pred_dicts = to_pred_dicts(use_mid_mask)
        score_dict = metrics.evaluate(pred_dicts)
        coverage = float(np.mean(use_mid_mask))  # early-exit coverage
        results.append({
            "tau": tau,
            "coverage": coverage,
            **score_dict
        })
    return results


def top1_top2_margin_from_array(score_vec: np.ndarray) -> float:
    if score_vec is None or len(score_vec) == 0:
        return 0.0
    if len(score_vec) == 1:
        return float('inf')
    # Take the two largest scores
    top2 = np.partition(score_vec, -2)[-2:]
    top2.sort()
    return float(top2[-1] - top2[-2])


logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)

# --- Global dictionaries for hooks (cleared before each encode_embeddings call) ---
timing_info = {}
token_info = {
    "vision_tokens": 0,
    "text_input_tokens": 0,       # refers to the original text token count
    "text_output_tokens": 0,      # not applicable here (encoding, not generating); stays 0
    "total_llm_input_tokens": 0,  # total tokens the LLM receives (visual + formatted text)
}


# --- Hook function definitions ---
def timing_pre_hook(module, input):
    module_id = id(module)
    if module_id not in timing_info:
        timing_info[module_id] = []
    timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__))


def timing_post_hook(module, input, output):
    module_id = id(module)
    if module_id not in timing_info:
        # print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})")
        return
    timing_info[module_id].append((time.time(), 'post', module.__class__.__name__))
    # Collect vision token count (only from the Vision Transformer module's post hook)
    module_name = module.__class__.__name__
    if "vision" in module_name.lower() and "transformer" in module_name.lower():
        if isinstance(output, torch.Tensor):
            token_info["vision_tokens"] = output.shape[0]
        # For visual features, usually (batch_size, num_tokens, hidden_dim)
        elif hasattr(output, 'last_hidden_state'):
            token_info["vision_tokens"] = output.last_hidden_state.shape[1]


def register_model_hooks(model):
    registered_modules = []
    core_model = model.encoder if hasattr(model, "encoder") and model.encoder is not None else model
    # Vision module
    if hasattr(core_model, 'visual') and core_model.visual is not None:
        vision_module = core_model.visual
        vision_module.register_forward_pre_hook(timing_pre_hook)
        vision_module.register_forward_hook(timing_post_hook)
        registered_modules.append(vision_module)
        print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).")
    # Merger module (if inside visual) - it's part of the vision component
    if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None:
        merger_module = core_model.visual.merger
        merger_module.register_forward_pre_hook(timing_pre_hook)
        merger_module.register_forward_hook(timing_post_hook)
        registered_modules.append(merger_module)
        print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).")
    # Language model body
    if hasattr(core_model, 'model') and core_model.model is not None:
        llm_main_module = core_model.model
        llm_main_module.register_forward_pre_hook(timing_pre_hook)
        llm_main_module.register_forward_hook(timing_post_hook)
        registered_modules.append(llm_main_module)
        print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).")
    # LM head
    if hasattr(core_model, 'lm_head') and core_model.lm_head is not None:
        lm_head_module = core_model.lm_head
        lm_head_module.register_forward_pre_hook(timing_pre_hook)
        lm_head_module.register_forward_hook(timing_post_hook)
        registered_modules.append(lm_head_module)
        print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}")
    else:
        print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).")
    if not registered_modules:
        print_master("Warning: No major modules found for hook registration. Check model architecture.")
    return registered_modules
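# A minimal, self-contained sketch (toy module, not the real backbone) of how
# the pre/post forward hooks above pair up into wall-clock timings:
#
#     import torch.nn as nn
#     lin = nn.Linear(8, 8)
#     lin.register_forward_pre_hook(timing_pre_hook)
#     lin.register_forward_hook(timing_post_hook)
#     lin(torch.randn(2, 8))
#     events = timing_info[id(lin)]  # [(t0, 'pre', 'Linear'), (t1, 'post', 'Linear')]
#     # duration = t1 - t0, matching the pairing logic inside encode_embeddings below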
def pad_dataset_to_divisible(dataset, world_size):
    num_samples = len(dataset)
    if num_samples % world_size == 0:
        return dataset, num_samples
    num_to_add = world_size - (num_samples % world_size)
    padded_size = num_samples + num_to_add
    padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
    padded_dataset = concatenate_datasets([dataset, padding_data])
    return padded_dataset, padded_size
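# Illustrative only: the padding arithmetic. With 10 samples and world_size=4,
# num_to_add = 4 - (10 % 4) = 2, so samples 0 and 1 are repeated and the padded
# dataset has 12 rows; each rank then receives exactly 3 rows.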
def encode_embeddings(
    model: MMEBModel,
    loader: DataLoader,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    full_dataset: Dataset,
    encode_side: str,
    description: str = "Encoding"
) -> tuple[np.ndarray, list, list, list, list, list]:  # CHANGED: masks + token records added
    """
    Encodes embeddings for a given dataset using the model, handling both standard
    and late-interaction models in a DDP-safe manner.
    Returns:
        - embeddings: np.ndarray
        - infos_or_ids: list
        - batch_stats_list: list
        - img_token_masks: list[None | list[bool]]  # NEW
        - txt_token_masks: list[None | list[bool]]  # NEW
        - token_records: list[dict]                 # NEW: per-sample pre/post/delta counts
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    # Check if the model is a late-interaction type
    is_late_interaction = (model_args.model_backbone == COLPALI)

    local_embeds = []
    local_gt_infos = []
    local_max_len = 0
    # --- New: list of per-batch statistics ---
    batch_stats_list = []
    # --- NEW: collected masks ---
    local_img_token_masks = []  # post-prune image mask per sample
    local_txt_token_masks = []  # NEW: post-prune text mask per sample
    local_post_attn_masks = []  # NEW: post-prune attention_mask per sample (1/0)
    # --- NEW: per-sample token reduction records ---
    local_token_records = []    # one dict per sample with pre/post/delta counts

    model.eval()
    # Register hooks for the model once per encode_embeddings call
    registered_hooks = register_model_hooks(model)

    # --- NEW: helpers to fetch masks and make them serializable ---
    def _search_key(obj, key: str):
        # Recursively search dict/list/tuple for the given key
        if isinstance(obj, dict):
            if key in obj:
                return obj[key]
            for v in obj.values():
                r = _search_key(v, key)
                if r is not None:
                    return r
        elif isinstance(obj, (list, tuple)):
            for v in obj:
                r = _search_key(v, key)
                if r is not None:
                    return r
        return None

    def _to_serializable_mask_list(mask_list, batch_size: int):
        # Convert the model-returned mask (list/tensor/ndarray/None) into [None | list[bool]] * B
        if mask_list is None:
            return [None] * batch_size
        out = []
        if isinstance(mask_list, (list, tuple)):
            for m in mask_list:
                if m is None:
                    out.append(None)
                elif torch.is_tensor(m):
                    out.append(m.detach().cpu().tolist())
                elif isinstance(m, np.ndarray):
                    out.append(m.tolist())
                else:
                    # already a python list/bool
                    out.append(m)
        elif torch.is_tensor(mask_list):
            # A 2D tensor (B, L): tolist() -> list[list[bool/int]]
            out = mask_list.detach().cpu().tolist()
        elif isinstance(mask_list, np.ndarray):
            out = mask_list.tolist()
        else:
            # Unknown type; conservatively return None placeholders
            out = [None] * batch_size
        # Align the length with batch_size
        if isinstance(out, list):
            if len(out) < batch_size:
                out = out + [None] * (batch_size - len(out))
            elif len(out) > batch_size:
                out = out[:batch_size]
        return out

    def _to_bool_lists(m, batch_size: int):
        lst = _to_serializable_mask_list(m, batch_size)
        # Normalize into list[list[bool] | None]
        out = []
        for x in lst:
            if x is None:
                out.append(None)
            else:
                # x may be list[int] or list[bool]
                out.append([bool(int(v)) for v in x])
        return out

    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0):
            # --- Reset statistics for each inference pass ---
            timing_info.clear()
            token_info["vision_tokens"] = 0
            token_info["text_input_tokens"] = 0
            token_info["text_output_tokens"] = 0
            token_info["total_llm_input_tokens"] = 0

            inputs = batch_to_device(inputs, training_args.device)
            current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs and inputs['input_ids'] is not None else 1

            with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                start_inference_time = time.time()
                # ---- NEW: enable/disable AOP per side ----
                aop_cfg = getattr(model.encoder, "aop_prune_config", None)
                _orig_enabled = None
                if isinstance(aop_cfg, dict) and aop_cfg:
                    _orig_enabled = aop_cfg.get("enabled", False)
                    apply_to = aop_cfg.get("apply_to", "qry")
                    side_enable = (apply_to == "both") or (apply_to == encode_side)
                    aop_cfg["enabled"] = bool(side_enable and _orig_enabled)
                    setattr(model.encoder, "aop_prune_config", aop_cfg)
                if encode_side == "qry":
                    output = model(qry=inputs)
                    reps = output["qry_reps"].detach()
                    local_gt_infos.extend(dataset_info)
                else:
                    output = model(tgt=inputs)
                    reps = output["tgt_reps"].detach()
                    local_gt_infos.extend([info["cand_name"] for info in dataset_info])
                # ---- NEW: restore enabled (avoid leaking into the next encode_side) ----
                if isinstance(aop_cfg, dict) and _orig_enabled is not None:
                    aop_cfg["enabled"] = _orig_enabled
                    setattr(model.encoder, "aop_prune_config", aop_cfg)
                end_inference_time = time.time()

            # --- NEW: extract post-prune image/text masks and the post attention_mask ---
            img_masks_raw = None
            txt_masks_raw = None
            post_attn_raw = None
            if isinstance(output, dict):
                img_masks_raw = _search_key(output, "image_token_bool_masks")
                txt_masks_raw = _search_key(output, "text_token_bool_masks")  # NEW
                post_attn_raw = _search_key(output, "post_attention_mask")    # NEW (set by our MMEBModel.forward)
            # Fallback: the masks may be attached to the model itself
            if img_masks_raw is None and hasattr(model, "image_token_bool_masks"):
                img_masks_raw = getattr(model, "image_token_bool_masks")
            if txt_masks_raw is None and hasattr(model, "text_token_bool_masks"):
                txt_masks_raw = getattr(model, "text_token_bool_masks")
            if post_attn_raw is None and hasattr(model, "post_attention_mask"):
                post_attn_raw = getattr(model, "post_attention_mask")
            img_masks_serializable = _to_serializable_mask_list(img_masks_raw, current_batch_size)
            txt_masks_serializable = _to_serializable_mask_list(txt_masks_raw, current_batch_size)  # NEW
            post_attn_serializable = _to_serializable_mask_list(post_attn_raw, current_batch_size)  # NEW
            local_img_token_masks.extend(img_masks_serializable)
            local_txt_token_masks.extend(txt_masks_serializable)  # NEW
            local_post_attn_masks.extend(post_attn_serializable)  # NEW

            # --- NEW: compute pre/post/delta token counts for this batch ---
            cfg = getattr(model.encoder, "config", None)
            # Pre-prune masks come from the inputs
            input_ids = inputs.get("input_ids", None)
            attn2d_pre = inputs.get("attention_mask", None)
            if input_ids is None or attn2d_pre is None or cfg is None:
                # Cannot count; leave zeros
                pre_vis_counts = [0] * current_batch_size
                pre_txt_counts = [0] * current_batch_size
                pre_tot_counts = [0] * current_batch_size
            else:
                iid = input_ids
                am = attn2d_pre.to(torch.bool)
                image_token_id = getattr(cfg, "image_token_id", None)
                video_token_id = getattr(cfg, "video_token_id", None)
                bos_id = getattr(cfg, "bos_token_id", None)
                eos_id = getattr(cfg, "eos_token_id", None)
                pad_id = getattr(cfg, "pad_token_id", None)
                is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
                is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
                is_vision = is_image | is_video
                is_special = torch.zeros_like(iid, dtype=torch.bool)
                for tid in [bos_id, eos_id, pad_id]:
                    if tid is not None and tid >= 0:
                        is_special |= (iid == tid)
                pre_txt_mask = am & (~is_vision) & (~is_special)
                pre_vis_mask = am & is_vision
                pre_vis_counts = pre_vis_mask.sum(dim=1).tolist()
                pre_txt_counts = pre_txt_mask.sum(dim=1).tolist()
                pre_tot_counts = am.sum(dim=1).tolist()

            # Post-prune masks come from the model output; AND them with post_attn
            post_text_masks = _to_bool_lists(txt_masks_raw, current_batch_size)  # list[list[bool] | None]
            post_image_masks = _to_bool_lists(img_masks_raw, current_batch_size)
            post_attn_masks = _to_bool_lists(post_attn_raw, current_batch_size)

            sum_pre_text = 0; sum_post_text = 0
            sum_pre_vis = 0; sum_post_vis = 0
            sum_pre_tot = 0; sum_post_tot = 0
            for i in range(current_batch_size):
                pre_text = int(pre_txt_counts[i]) if i < len(pre_txt_counts) else 0
                pre_vis = int(pre_vis_counts[i]) if i < len(pre_vis_counts) else 0
                pre_tot = int(pre_tot_counts[i]) if i < len(pre_tot_counts) else 0
                # Post counts: masks may be None
                m_text = post_text_masks[i] if post_text_masks is not None and i < len(post_text_masks) else None
                m_img = post_image_masks[i] if post_image_masks is not None and i < len(post_image_masks) else None
                m_attn = post_attn_masks[i] if post_attn_masks is not None and i < len(post_attn_masks) else None
                if m_attn is None:
                    post_text = 0; post_vis = 0; post_tot = 0
                else:
                    # Count True entries after ANDing with the attention mask
                    if m_text is not None:
                        post_text = sum(1 for a, t in zip(m_attn, m_text) if a and t)
                    else:
                        post_text = 0
                    if m_img is not None:
                        post_vis = sum(1 for a, v in zip(m_attn, m_img) if a and v)
                    else:
                        post_vis = 0
                    post_tot = sum(1 for a in m_attn if a)
                # Accumulate batch-level sums
                sum_pre_text += pre_text; sum_post_text += post_text
                sum_pre_vis += pre_vis; sum_post_vis += post_vis
                sum_pre_tot += pre_tot; sum_post_tot += post_tot
                # Save the per-sample record (for JSONL)
                local_token_records.append({
                    "side": encode_side,
                    "pre": {"text": pre_text, "vision": pre_vis, "total": pre_tot},
                    "post": {"text": post_text, "vision": post_vis, "total": post_tot},
                    "delta": {"text": pre_text - post_text, "vision": pre_vis - post_vis, "total": pre_tot - post_tot},
                })

            # --- Update total LLM input tokens after the model call ---
            if 'input_ids' in inputs and inputs['input_ids'] is not None:
                token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1]
                token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"]
                token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"])

            # --- Collect and store batch statistics ---
            batch_inference_time = end_inference_time - start_inference_time
            current_batch_stats = {
                "batch_size": current_batch_size,
                "total_inference_time_seconds": batch_inference_time,
                "module_inference_times": {},
                "token_counts": {
                    "visual_tokens": token_info["vision_tokens"],
                    "language_input_tokens_raw": token_info["text_input_tokens"],
                    "llm_total_input_tokens": token_info["total_llm_input_tokens"],
                    "language_output_tokens": token_info["text_output_tokens"],
                }
            }
            current_batch_stats["token_reduction"] = {
                "sum_pre_text": sum_pre_text, "sum_post_text": sum_post_text,
                "sum_pre_vision": sum_pre_vis, "sum_post_vision": sum_post_vis,
                "sum_pre_total": sum_pre_tot, "sum_post_total": sum_post_tot,
            }

            # Calculate and store module timings for the current batch
            for module_obj in registered_hooks:
                module_id = id(module_obj)
                module_name = module_obj.__class__.__name__
                times = timing_info.get(module_id, [])
                durations = []
                pre_times = {}
                for t, event_type, _ in times:
                    if event_type == 'pre':
                        pre_times[module_id] = t
                    elif event_type == 'post' and module_id in pre_times:
                        duration = t - pre_times.pop(module_id)
                        durations.append(duration)
                if durations:
                    current_batch_stats["module_inference_times"][module_name] = {
                        "total": sum(durations),
                        "count": len(durations),
                        "avg": sum(durations) / len(durations)
                    }
                else:
                    current_batch_stats["module_inference_times"][module_name] = {
                        "total": 0.0, "count": 0, "avg": 0.0
                    }
            batch_stats_list.append(current_batch_stats)

            # --- Debug prints (optional) ---
            print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---")
            print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds")
            print_rank("--- Module Inference Timing Statistics ---")
            for module_name, stats in current_batch_stats["module_inference_times"].items():
                print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s")
            print_rank("--- Token Count Statistics ---")
            print_rank(f"**Vision token count**: {current_batch_stats['token_counts']['visual_tokens']}")
            print_rank(f"**Language input token count (raw text only)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}")
            print_rank(f"**Total LLM input token count (vision + formatted text)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}")
            print_rank(f"**Language output token count**: {current_batch_stats['token_counts']['language_output_tokens']}")

            if is_late_interaction and reps.dim() == 3:
                local_max_len = max(local_max_len, reps.shape[1])
            local_embeds.append(reps)

    if not local_embeds:
        # Handle cases where a rank gets no data
        return np.array([]), [], [], [], [], []  # CHANGED: 6 return values

    # === DDP synchronization and padding for late-interaction models ===
    if is_late_interaction:
        if dist.is_initialized():
            # 1: global max length
            local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device)
            dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX)
            global_max_len = local_max_len_tensor.item()
        else:
            global_max_len = local_max_len
        # 2: pad to global max length
        padded_embeds = []
        for reps_batch in local_embeds:
            if reps_batch.dim() == 3:
                B, L, H = reps_batch.shape
                padding_size = global_max_len - L
                padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0)
                padded_embeds.append(padded_batch)
            else:
                padded_embeds.append(reps_batch)
        embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous()
    else:
        embeds_tensor = torch.cat(local_embeds, dim=0).contiguous()
    # === Gather embeddings and keys from all ranks ===
    if dist.is_initialized() and full_dataset.num_rows >= world_size:
        print_master(f"Gathering {encode_side} embeddings across all ranks...")
        # tensor gather
        output_shape = list(embeds_tensor.shape)
        output_shape[0] = full_dataset.num_rows
        embeds_tensor = embeds_tensor.to(training_args.device)
        gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device)
        dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor)
        final_embeddings = gathered_embeds_tensor.cpu().float().numpy()
        # object gather for infos and stats
        gathered_gt_infos = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_gt_infos, local_gt_infos)
        all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys]
        gathered_batch_stats = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_batch_stats, batch_stats_list)
        all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats]
        # --- NEW: gather masks ---
        gathered_masks = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_masks, local_img_token_masks)
        all_img_token_masks = [m for rank_list in gathered_masks for m in rank_list]
        # NEW: gather text masks
        gathered_txt_masks = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_txt_masks, local_txt_token_masks)
        all_txt_token_masks = [m for rank_list in gathered_txt_masks for m in rank_list]
        # NEW: gather post attention masks (if needed)
        gathered_post_attn = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_post_attn, local_post_attn_masks)
        all_post_attn_masks = [m for rank_list in gathered_post_attn for m in rank_list]
        # NEW: gather token records
        gathered_token_recs = [None for _ in range(world_size)]
        dist.all_gather_object(gathered_token_recs, local_token_records)
        all_token_records = [r for rank_list in gathered_token_recs for r in rank_list]
    else:
        all_gt_infos = local_gt_infos
        final_embeddings = embeds_tensor.cpu().float().numpy()
        all_batch_stats = batch_stats_list
        all_img_token_masks = local_img_token_masks  # NEW
        all_txt_token_masks = local_txt_token_masks
        all_post_attn_masks = local_post_attn_masks
        all_token_records = local_token_records
    return final_embeddings, all_gt_infos, all_batch_stats, all_img_token_masks, all_txt_token_masks, all_token_records
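# A minimal sketch (toy data, assuming an initialized process group) of the
# all_gather_object pattern used above: each rank contributes its local list,
# and flattening in rank order keeps the objects aligned with the tensor gather.
#
#     local_items = [f"rank{dist.get_rank()}_item{i}" for i in range(2)]
#     gathered = [None for _ in range(dist.get_world_size())]
#     dist.all_gather_object(gathered, local_items)
#     flat = [x for per_rank in gathered for x in per_rank]
#     # flat == rank 0's items, then rank 1's, ... matching all_gather_into_tensor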
# === NEW: export mid-layer and last-layer cand vectors in a single forward pass ===
def encode_candidates_both_layers(
    model: MMEBModel,
    loader: DataLoader,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    full_dataset: Dataset,
    mid_layer: int,
) -> tuple[np.ndarray, np.ndarray, list]:
    """
    Run a single forward pass to the last layer and read directly from hidden_states:
      - mid_hidden = hidden_states[mid_layer]  # the state after mid_layer layers
        (see the all_hidden_states definition in Qwen2_5_VLModel)
      - last_hidden = hidden_states[-1]        # the state after the final norm
    Then pool with _pooling(attention_mask) to get sentence vectors. Returns:
      - cand_mid_embeds: np.ndarray [Nc, D]
      - cand_last_embeds: np.ndarray [Nc, D]
      - cand_ids: list[str]
    Note:
      - The cand side is not AOP-pruned by default (with AOP_APPLY=qry it is
        naturally disabled), so mid/last share the same sequence length and the
        original attention_mask can be used for pooling directly.
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    model.eval()
    all_mid = []
    all_last = []
    all_ids = []
    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0):
            inputs = batch_to_device(inputs, training_args.device)
            # Make sure AOP is not triggered on the cand side (with AOP_APPLY=qry/both
            # the backbone already gates per side; this is an extra safeguard)
            aop_cfg = getattr(model.encoder, "aop_prune_config", None)
            _orig_enabled = None
            if isinstance(aop_cfg, dict) and aop_cfg:
                _orig_enabled = aop_cfg.get("enabled", False)
                apply_to = aop_cfg.get("apply_to", "qry")
                side_enable = (apply_to == "both") or (apply_to == "cand")
                aop_cfg["enabled"] = bool(side_enable and _orig_enabled)
                setattr(model.encoder, "aop_prune_config", aop_cfg)
            with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                # Key step: one forward pass returning the hidden_states of every layer
                out = model.encoder(
                    **inputs,
                    return_dict=True,
                    output_hidden_states=True,  # required
                    stop_at_layer=None,         # run all layers
                )
            # Index the mid and last hidden states
            hs_list = out.hidden_states
            assert hs_list is not None and len(hs_list) > mid_layer, \
                f"hidden_states is None or too short. Need index {mid_layer}, got len={0 if hs_list is None else len(hs_list)}"
            mid_hs = hs_list[mid_layer]  # [B, L, D]: the state after mid_layer layers (i.e., pre-layer(mid_layer + 1))
            last_hs = hs_list[-1]        # [B, L, D]: the state after the final norm
            # Pool with the original attention_mask (cand side is unpruned)
            am = inputs.get("attention_mask", None)
            if am is not None and hasattr(am, "device"):
                if am.device != mid_hs.device:
                    am = am.to(mid_hs.device)
            reps_mid = model._pooling(mid_hs, am)    # [B, D]
            reps_last = model._pooling(last_hs, am)  # [B, D]
            all_mid.append(reps_mid.detach().float().cpu())
            all_last.append(reps_last.detach().float().cpu())
            all_ids.extend([info["cand_name"] for info in dataset_info])
            # Restore the AOP switch (avoid leaking into the other side)
            if isinstance(aop_cfg, dict) and _orig_enabled is not None:
                aop_cfg["enabled"] = _orig_enabled
                setattr(model.encoder, "aop_prune_config", aop_cfg)
    if not all_mid:
        return np.array([]), np.array([]), []
    cand_mid_embeds = torch.cat(all_mid, dim=0).numpy()
    cand_last_embeds = torch.cat(all_last, dim=0).numpy()
    return cand_mid_embeds, cand_last_embeds, all_ids
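# Illustrative only: the hidden_states indexing convention assumed above, shown
# with a generic Hugging Face decoder (the exact meaning of hidden_states[-1]
# is model-dependent; this backbone's docstring above says it is post-norm):
#
#     out = hf_model(**batch, output_hidden_states=True, return_dict=True)
#     # out.hidden_states is a tuple of length num_layers + 1:
#     #   hidden_states[0]  -> embedding output, before any decoder layer
#     #   hidden_states[k]  -> output after decoder layer k (1-indexed)
#     #   hidden_states[-1] -> final state used for last-layer pooling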
def main():
    # ----------------------- Distributed init -----------------------
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    print_master("Distributed init debug info:")
    print_master(f"RANK: {os.environ.get('RANK')}")
    print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
    print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
    print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
    print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
    if dist.is_initialized():
        print_rank(f"dist.get_rank(): {dist.get_rank()}")
        print_rank(f"dist.get_world_size(): {dist.get_world_size()}")

    # Compatibility with torchrun-style arguments
    # (iterate over a copy: removing from sys.argv while iterating it skips elements)
    for arg in list(sys.argv):
        if arg.startswith("--local-rank="):
            rank = arg.split("=")[1]
            sys.argv.remove(arg)
            sys.argv.append('--local_rank')
            sys.argv.append(rank)

    # ----------------------- Parse args -----------------------
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments
    os.makedirs(data_args.encode_output_path, exist_ok=True)

    # Multi-layer evaluation support (LM_LAYERS preferred, MID_LM_LAYER for compatibility)
    layers_to_eval = get_env_eval_layers()
    print_master(f"Eval layers (qry/tgt): {layers_to_eval}")

    # ----------------------- Model loading -----------------------
    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, 'model_backbone', model_backbone)
        setattr(training_args, 'model_backbone', model_backbone)
    print_master(f'Model Backbone: {model_args.model_backbone}')

    # Only rank 0 downloads; the other ranks wait for the cache
    if local_rank == 0:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
        print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...")
    if torch.distributed.is_initialized():
        torch.distributed.barrier()
    if local_rank != 0:
        print_rank(f"Loading the model from cache...")
        processor = load_processor(model_args, data_args)
        time.sleep(random.randint(2 * local_rank, 3 * local_rank))
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)

    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)

    # ---- NEW: inject the AOP pruning config (drives the AOP logic already implemented in the backbone) ----
    aop_cfg = get_env_aop_config()
    if aop_cfg["enabled"]:
        # Attach the config to the backbone; its forward reads this dict and performs the pruning
        setattr(model.encoder, "aop_prune_config", aop_cfg)
        # Optional: override the attention implementation so attention weights
        # (or hand-computed qk) can be read at the decision layer
        attn_override = aop_cfg.get("attn_impl_override", "")
        if attn_override:
            try:
                if hasattr(model.encoder, "model") and hasattr(model.encoder.model, "config"):
                    prev = model.encoder.model.config._attn_implementation
                    model.encoder.model.config._attn_implementation = attn_override
                    print_master(f"[AOP] override attn impl: {prev} -> {attn_override} (recommended for testing only)")
            except Exception as e:
                print_master(f"[AOP] try override attn impl failed: {e}")
        print_master("[AOP] AOP-Prune enabled with config: " + json.dumps({
            "apply_to": aop_cfg["apply_to"],
            "layer_idx": aop_cfg["layer_idx"],
            "mode": aop_cfg["mode"],
            "delta": aop_cfg["delta"],
            "K_hat": aop_cfg["K_hat"],
            "keep_ratio": aop_cfg["keep_ratio"],
            "min_keep": aop_cfg["min_keep"],
            "use_bias": aop_cfg["use_bias"],
            "margin_mid?": (aop_cfg["margin_mid"] is not None),
            "prune_text": aop_cfg.get("prune_text", False),
            "keep_ratio_text": aop_cfg.get("keep_ratio_text", None),
            "keep_ratio_vision": aop_cfg.get("keep_ratio_vision", None),
            "selection": aop_cfg.get("selection", "aop"),
            "attn_agg": aop_cfg.get("attn_agg", "mean"),
        }))
    else:
        print_master("[AOP] disabled (set AOP_ENABLED=1 to enable)")

    # Make sure no layers are cut for the "last layer" case
    # (avoids the class's default 20-layer pitfall)
    model.set_inference_layers(qry_layers=None, tgt_layers=None)

    with open(data_args.dataset_config, 'r') as yaml_file:
        dataset_configs = yaml.safe_load(yaml_file)

    # ----------------------- Main evaluation loop -----------------------
    for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()):
        if dist.is_initialized():
            dist.barrier()
        print_master(f"\n--- Evaluating {dataset_name} ---")

        # Resolve paths against data_basedir
        if data_args.data_basedir is not None:
            for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                if data_args.data_basedir and task_config.get(key):
                    task_config[key] = os.path.join(data_args.data_basedir, task_config[key])

        # Build datasets
        full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
        full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
        eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
        if dist.is_initialized():
            world_size = dist.get_world_size()
            padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size)
            padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size)
            eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size)
            eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size)
        else:
            padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset

        # === EE-only: online early-exit inference only (first make sure both candidate embeddings exist) ===
        ee_cfg = get_env_ee_config()
        assert ee_cfg["enabled"], "EE_ENABLED must be 1 for EE-only pipeline."
        # Build tags from EE_LAYER
        mid_layer = int(ee_cfg["layer"])
        mid_tag = make_layer_tag(mid_layer)  # e.g., layer12
        last_tag = "layerlast"

        # Prepare paths
        cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{mid_tag}")
        cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{last_tag}")

        # Build the cand DataLoader (one pass, no sharding)
        eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
        eval_cand_loader = DataLoader(
            full_eval_cand_dataset,
            batch_size=training_args.per_device_eval_batch_size,
            collate_fn=eval_cand_collator,
            num_workers=training_args.dataloader_num_workers
        )

        # === Replaced with: one forward pass exporting both mid/last cand embeddings ===
        need_mid = (not os.path.exists(cand_mid_path))
        need_last = (not os.path.exists(cand_last_path))
        if need_mid or need_last:
            print_master(f"[{dataset_name}] EE-only: encoding candidates BOTH layers in one pass (mid={mid_tag}, last={last_tag}) ...")
            # Run all layers (no truncation)
            model.set_inference_layers(qry_layers=None, tgt_layers=None)
            cand_embeds_mid, cand_embeds_last, all_cand_ids = encode_candidates_both_layers(
                model=model,
                loader=eval_cand_loader,
                training_args=training_args,
                model_args=model_args,
                full_dataset=full_eval_cand_dataset,
                mid_layer=mid_layer,
            )
            if local_rank == 0:
                if need_mid:
                    cand_embed_dict_mid = {cid: emb for cid, emb in zip(all_cand_ids, cand_embeds_mid)}
                    with open(cand_mid_path, "wb") as f:
                        pickle.dump(cand_embed_dict_mid, f)
                    print_master(f"[{dataset_name}] EE-only: saved {mid_tag} candidate embeddings -> {cand_mid_path}")
                if need_last:
                    cand_embed_dict_last = {cid: emb for cid, emb in zip(all_cand_ids, cand_embeds_last)}
                    with open(cand_last_path, "wb") as f:
                        pickle.dump(cand_embed_dict_last, f)
                    print_master(f"[{dataset_name}] EE-only: saved {last_tag} candidate embeddings -> {cand_last_path}")
        else:
            print_master(f"[{dataset_name}] EE-only: reuse existing candidates (mid={cand_mid_path}, last={cand_last_path})")
        if dist.is_initialized():
            dist.barrier()

        # 3) Online early-exit gating + resuming the remaining subset
        #    (no offline per-layer scoring/curves)
        if local_rank == 0:
            with open(cand_mid_path, "rb") as f:
                cand_mid_dict = pickle.load(f)
            with open(cand_last_path, "rb") as f:
                cand_last_dict = pickle.load(f)
            rank_global = task_config.get("eval_type", "global") == "global"
            print_master(f"[{dataset_name}] Run ONLINE early-exit at layer={ee_cfg['layer']}, method={ee_cfg['method']}, tau={ee_cfg['tau']}, topk={ee_cfg['topk']}, global={rank_global}")
            run_early_exit_queries(
                model=model,
                processor=processor,
                model_args=model_args,
                data_args=data_args,
                training_args=training_args,
                qry_dataset=full_eval_qry_dataset,  # all queries
                cand_mid_dict=cand_mid_dict,
                cand_last_dict=cand_last_dict,
                ee_cfg=ee_cfg,
                dataset_name=dataset_name,
                out_dir=data_args.encode_output_path,
                global_ranking=rank_global,
            )
        if dist.is_initialized():
            dist.barrier()
        # === End of EE-only; proceed directly to the next dataset ===
        continue


if __name__ == '__main__':
    main()
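# Illustrative only: a hypothetical launch command (script name and paths are
# placeholders, not the repository's actual entry point):
#
#     EE_ENABLED=1 EE_LAYER=12 EE_METHOD=margin EE_TAU=0.2 \
#     AOP_ENABLED=1 AOP_LAYER=8 AOP_APPLY=qry \
#     torchrun --nproc_per_node=8 eval_early_exit.py \
#         --model_name <hf_model_or_path> \
#         --dataset_config configs/eval.yaml \
#         --encode_output_path outputs/ee_runs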