code_SAS_VLM2Vec / eval_test_time_with_classifier_AOP_pooling.py
MgGladys's picture
Add files using upload-large-folder tool
ac8b25b verified
# import datetime
# import logging
# import json
# import random
# import time
# import numpy as np
# import os
# import pickle
# import sys
# import torch
# import torch.distributed as dist
# import torch.nn.functional as F
# import yaml
# import transformers
# import math
# from torch.utils.data import DataLoader
# from tqdm import tqdm
# from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
# from datasets import Dataset, concatenate_datasets
# from datasets.distributed import split_dataset_by_node
# from src.arguments import ModelArguments, DataArguments, TrainingArguments
# from src.data.collator.eval_collator import MultimodalEvalDataCollator
# from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
# from src.eval_utils.metrics import RankingMetrics
# from src.model.model_multi_layer_distill_AOP_pooling import MMEBModel
# from src.model.processor import get_backbone_name, load_processor, COLPALI
# from src.utils import batch_to_device, print_rank, print_master
# # ==========================================
# # 【V5 修改】引入 V5 版本分类器
# # ==========================================
# from src.classifier_utils_V5 import EarlyExitClassifier
# logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
# logger = logging.getLogger(__name__)
# # ===========================
# # Per-dataset thresholds (示例,需根据V5分析结果更新)
# # ===========================
# PER_DATASET_THRESHOLDS = {
# "CIRR": 0.3,
# "EDIS": 0.3,
# "FashionIQ": 0.3,
# "MSCOCO_i2t": 0.3,
# "MSCOCO_t2i": 0.3,
# "NIGHTS": 0.3,
# "OVEN": 0.3,
# "VisDial": 0.3,
# "VisualNews_i2t": 0.3,
# "VisualNews_t2i": 0.3,
# "WebQA": 0.3,
# "Wiki-SS-NQ": 0.3,
# }
# # ... (Helper Functions 保持不变) ...
# def _parse_bool(v: str, default=False):
# if v is None: return default
# v = v.strip().lower()
# return v in {"1","true","yes","y","t","on"}
# def _parse_float(v: str, default=None):
# try: return float(v) if v is not None else default
# except: return default
# def _parse_int(v: str, default=None):
# try: return int(v) if v is not None else default
# except: return default
# def get_env_aop_config():
# enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False)
# apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower()
# layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None)
# mode = os.environ.get("AOP_MODE", "delta").strip().lower()
# delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10)
# khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0)
# keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0)
# min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64)
# use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True)
# prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True)
# prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False)
# keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None)
# keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None)
# # 新增:各自的 min_keep
# min_keep_v = _parse_int(os.environ.get("AOP_MIN_KEEP_VISION"), None)
# min_keep_t = _parse_int(os.environ.get("AOP_MIN_KEEP_TEXT"), 32)
# # 新增:文本保护参数
# protect_text_last = _parse_int(os.environ.get("AOP_PROTECT_TEXT_LAST"), 16)
# protect_special = _parse_bool(os.environ.get("AOP_PROTECT_SPECIAL"), True)
# selection = os.environ.get("AOP_SELECTION", "aop").strip().lower()
# attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower()
# random_seed = _parse_int(os.environ.get("AOP_RANDOM_SEED"), None)
# if layer_idx is None and enabled:
# enabled = False
# return {
# "enabled": enabled,
# "apply_to": apply_to,
# "layer_idx": layer_idx,
# "mode": mode,
# "delta": delta,
# "K_hat": khat,
# "keep_ratio": keep_ratio,
# "min_keep": min_keep,
# "use_bias": use_bias,
# "prune_vision": prune_vision,
# "prune_text": prune_text,
# "keep_ratio_vision": keep_ratio_v,
# "keep_ratio_text": keep_ratio_t,
# "min_keep_vision": min_keep_v,
# "min_keep_text": min_keep_t,
# "protect_text_last": protect_text_last,
# "protect_special": protect_special,
# "selection": selection,
# "attn_agg": attn_agg,
# "random_seed": random_seed,
# "margin_mid": None,
# }
# def _set_attr_on_base(peft_or_base, name, value):
# try:
# base = peft_or_base.get_base_model() if hasattr(peft_or_base, "get_base_model") else None
# if base is None and hasattr(peft_or_base, "model"):
# base = peft_or_base.model
# if base is not None:
# setattr(base, name, value)
# if hasattr(base, "model"): # ForConditionalGeneration.model -> Qwen2VLModel
# setattr(base.model, name, value)
# except Exception as e:
# print_rank(f"[inject-config] warn: set {name} on base failed: {e}")
# def get_env_vpool_config():
# def _b(k, d=False):
# v = os.environ.get(k)
# if v is None: return d
# return str(v).strip().lower() in {"1","true","y","yes","on"}
# def _i(k, d=None):
# try: return int(os.environ.get(k, d))
# except: return d
# def _s(k, d=None):
# v = os.environ.get(k, d)
# return None if v is None else str(v).strip().lower()
# return {
# "enabled": _b("VPOOL_ENABLED", False),
# "apply_to": _s("VPOOL_APPLY", "both"), # qry|tgt|both
# "layer_idx": _i("VPOOL_LAYER", 1),
# "kernel": _i("VPOOL_KERNEL", 2),
# "stride": _i("VPOOL_STRIDE", None) or _i("VPOOL_KERNEL", 2),
# "method": _s("VPOOL_METHOD", "avg"),
# "protect_cls": _b("VPOOL_PROTECT_CLS", True),
# "vision_only": _b("VPOOL_ONLY_VISION", True),
# "monitor": _b("VPOOL_MONITOR", False),
# }
# def get_env_ee_config():
# ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in {"1","true","yes","on","y","t"}
# layer = int(os.environ.get("EE_LAYER", "12"))
# method = os.environ.get("EE_METHOD", "classifier").strip().lower()
# threshold = float(os.environ.get("EE_THRESHOLD", "0.8"))
# classifier_path = os.environ.get("EE_CLASSIFIER_PATH", "")
# return dict(enabled=ee_enabled, layer=layer, method=method, threshold=threshold, classifier_path=classifier_path)
# def pad_dataset_to_divisible(dataset, world_size):
# num_samples = len(dataset)
# if num_samples % world_size == 0: return dataset, num_samples
# num_to_add = world_size - (num_samples % world_size)
# padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
# return concatenate_datasets([dataset, padding_data]), num_samples + num_to_add
# # ===========================
# # Core Inference Function
# # ===========================
# def run_early_exit_queries(
# model: MMEBModel,
# classifier: EarlyExitClassifier,
# processor,
# model_args: ModelArguments,
# data_args: DataArguments,
# training_args: TrainingArguments,
# qry_dataset: Dataset,
# cand_mid_dict: dict,
# cand_last_dict: dict,
# ee_cfg: dict,
# dataset_name: str,
# out_dir: str,
# global_ranking: bool = True,
# ):
# device = training_args.device
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# is_main = (not dist.is_initialized()) or (local_rank == 0)
# # Profiling 配置
# profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in {"1", "true", "yes", "on", "y", "t"}
# topk_emb = int(os.environ.get("EE_TOPK_EMB", "5"))
# timing_stats = {
# "mid_time_sum": 0.0, "mid_num": 0, "tail_time_sum": 0.0, "tail_num": 0,
# }
# analysis_records = [] if (profile_enabled and is_main) else None
# # 1. 准备 Candidates
# cand_ids = list(cand_mid_dict.keys())
# cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
# cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32)
# cand_mid_np = cand_mid # [Nc, D], float32
# cand_last_np = cand_last # [Nc, D], float32
# cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)
# cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16)
# from contextlib import contextmanager
# # 侧别是否启用 AOP/VPOOL
# aop_cfg_model = getattr(model.encoder, "aop_prune_config", None)
# vpool_cfg_model = getattr(model.encoder, "vision_pooling_config", None)
# @contextmanager
# def _switch_aop(enable: bool):
# enc = model.encoder
# old = getattr(enc, "aop_prune_config", None)
# def _set_cfg(mod, cfg):
# setattr(mod, "aop_prune_config", cfg)
# base = mod.get_base_model() if hasattr(mod, "get_base_model") else None
# if base is None and hasattr(mod, "model"):
# base = mod.model
# if base is not None:
# setattr(base, "aop_prune_config", cfg)
# if hasattr(base, "model"):
# setattr(base.model, "aop_prune_config", cfg)
# if isinstance(old, dict) and not enable:
# cfg = dict(old); cfg["enabled"] = False
# _set_cfg(enc, cfg)
# try:
# yield
# finally:
# _set_cfg(enc, old)
# @contextmanager
# def _switch_vpool(enable: bool):
# enc = model.encoder
# old = getattr(enc, "vision_pooling_config", None)
# def _set_cfg(mod, cfg):
# setattr(mod, "vision_pooling_config", cfg)
# base = mod.get_base_model() if hasattr(mod, "get_base_model") else None
# if base is None and hasattr(mod, "model"):
# base = mod.model
# if base is not None:
# setattr(base, "vision_pooling_config", cfg)
# if hasattr(base, "model"):
# setattr(base.model, "vision_pooling_config", cfg)
# if isinstance(old, dict) and not enable:
# cfg = dict(old); cfg["enabled"] = False
# _set_cfg(enc, cfg)
# try:
# yield
# finally:
# _set_cfg(enc, old)
# collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
# loader = DataLoader(
# qry_dataset, batch_size=training_args.per_device_eval_batch_size,
# collate_fn=collator, num_workers=training_args.dataloader_num_workers
# )
# pred_dicts = []
# stats = {"exit": 0, "total": 0}
# threshold = float(ee_cfg["threshold"])
# method = ee_cfg["method"]
# target_layer_idx = int(ee_cfg["layer"])
# results_dict = {}
# global_sample_idx = 0
# use_local = (not global_ranking)
# if use_local:
# print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)")
# cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
# else:
# print_master(f"[INFO] Using GLOBAL ranking (full library)")
# # AOP Config
# aop_cfg = getattr(model.encoder, "aop_prune_config", None)
# _orig_enabled = None
# side_enable = True
# if isinstance(aop_cfg, dict) and aop_cfg:
# _orig_enabled = aop_cfg.get("enabled", False)
# apply_to = aop_cfg.get("apply_to", "qry")
# side_enable = (apply_to == "both") or (apply_to == "qry")
# model.eval()
# if classifier:
# classifier.eval()
# # 【V5 关键】分类器运行在 FP32 (默认),不要转 BF16,配合 Log/LayerNorm
# classifier.to(device)
# start_time = time.time()
# for inputs, infos in tqdm(
# loader, desc=f"[EE+AOP] {dataset_name} (tau={threshold})", disable=local_rank > 0,
# ):
# inputs = batch_to_device(inputs, device)
# B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1
# batch_start_idx = global_sample_idx
# global_sample_idx += B
# stats["total"] += B
# # ---------------------------------------------------
# # 1. 前半程: Run to Mid Layer
# # ---------------------------------------------------
# # AOP 侧别
# aop_apply = (aop_cfg_model or {}).get("apply_to", "qry")
# aop_side_qry = aop_apply in ("both", "qry")
# aop_side_tgt = aop_apply in ("both", "tgt")
# # VPOOL 侧别(一般 pooling 设在 layer=1,因此 mid 一定会经历)
# vpool_apply = (vpool_cfg_model or {}).get("apply_to", "qry")
# vpool_layer = (vpool_cfg_model or {}).get("layer_idx", 1)
# vpool_side_qry = vpool_apply in ("both", "qry")
# vpool_side_tgt = vpool_apply in ("both", "tgt")
# orig_cfg = None
# if isinstance(aop_cfg, dict) and aop_cfg:
# orig_cfg = dict(aop_cfg)
# aop_layer = aop_cfg.get("layer_idx", None)
# aop_on_mid = bool(
# _orig_enabled and side_enable
# and (aop_layer is not None)
# and (aop_layer <= target_layer_idx)
# )
# aop_cfg_mid = dict(aop_cfg)
# aop_cfg_mid["enabled"] = aop_on_mid
# setattr(model.encoder, "aop_prune_config", aop_cfg_mid)
# if profile_enabled:
# torch.cuda.synchronize()
# t0_mid = time.perf_counter()
# # --- tgt mid ---
# tgt_pre_mask = tgt_inputs.get("attention_mask", None)
# tgt_post_mask = getattr(out_mid_tgt, "attention_mask", None)
# tgt_hs_mid = getattr(out_mid_tgt, "last_hidden_state", None) or out_mid_tgt.hidden_states[-1]
# tgt_am_mid = tgt_post_mask if (tgt_post_mask is not None and tgt_post_mask.size(1) == tgt_hs_mid.size(1)) else tgt_pre_mask
# tgt_reps_mid = model._pooling(tgt_hs_mid, tgt_am_mid).detach().to(dtype=torch.bfloat16)
# # --- qry mid ---
# qry_pre_mask = qry_inputs.get("attention_mask", None)
# qry_post_mask = getattr(out_mid_qry, "attention_mask", None)
# qry_hs_mid = getattr(out_mid_qry, "last_hidden_state", None) or out_mid_qry.hidden_states[-1]
# qry_am_mid = qry_post_mask if (qry_post_mask is not None and qry_post_mask.size(1) == qry_hs_mid.size(1)) else qry_pre_mask
# qry_reps_mid = model._pooling(qry_hs_mid, qry_am_mid).detach().to(dtype=torch.bfloat16)
# with _switch_aop(aop_side_tgt), _switch_vpool(vpool_side_tgt and vpool_layer <= target_layer_idx):
# out_mid_tgt = model.encoder(
# **tgt_inputs, return_dict=True, output_hidden_states=False,
# stop_at_layer=target_layer_idx, compute_lm_head=False,
# )
# with _switch_aop(aop_side_qry), _switch_vpool(vpool_side_qry and vpool_layer <= target_layer_idx):
# out_mid_qry = model.encoder(
# **qry_inputs, return_dict=True, output_hidden_states=False,
# stop_at_layer=target_layer_idx, compute_lm_head=False,
# )
# # 这里就是“该插这三行”的地方:
# post_attn_mid = getattr(out_mid, "attention_mask", None) # [B, L'] 剪完后的 2D mask
# img_mask_mid = getattr(out_mid, "image_token_bool_masks", None) # [B, L'] 视觉 token 掩码(Qwen2.5 里目前是 pre 的,你如果想要 post,可以在 Qwen2_5_VLModel 里改成输出 cur_vision_mask)
# txt_mask_mid = getattr(out_mid, "text_token_bool_masks", None) # [B, L'] 文本 token 掩码(已经是 post-prune 的)
# # 然后按需统计,比如:
# if post_attn_mid is not None:
# # 有效 token 总数
# keep_counts = post_attn_mid.to(torch.bool).sum(dim=1) # [B]
# # 文本 token 保留数
# if txt_mask_mid is not None:
# text_keep = (post_attn_mid.to(torch.bool) & txt_mask_mid.to(torch.bool)).sum(dim=1)
# # 图像 token 保留数
# if img_mask_mid is not None:
# vis_keep = (post_attn_mid.to(torch.bool) & img_mask_mid.to(torch.bool)).sum(dim=1)
# if profile_enabled:
# torch.cuda.synchronize()
# t1_mid = time.perf_counter()
# timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B
# timing_stats["mid_num"] += B
# if isinstance(orig_cfg, dict):
# setattr(model.encoder, "aop_prune_config", orig_cfg)
# hs_mid = getattr(out_mid, "last_hidden_state", None)
# if hs_mid is None: hs_mid = out_mid.hidden_states[-1]
# am_mid = getattr(out_mid, "attention_mask", None)
# if am_mid is None: am_mid = inputs.get("attention_mask", None)
# reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(dtype=torch.bfloat16)
# # Profiling Last Layer
# reps_last_full = None
# if profile_enabled:
# with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
# out_full = model.encoder(
# **inputs, return_dict=True, output_hidden_states=False,
# stop_at_layer=None, compute_lm_head=False,
# )
# hs_last_full = getattr(out_full, "last_hidden_state", None)
# if hs_last_full is None: hs_last_full = out_full.hidden_states[-1]
# am_last_full = getattr(out_full, "attention_mask", None)
# if am_last_full is None: am_last_full = inputs.get("attention_mask", None)
# reps_last_full = model._pooling(hs_last_full, am_last_full).detach().to(dtype=torch.bfloat16)
# # ---------------------------------------------------
# # 2. 特征工程 + gating
# # ---------------------------------------------------
# exit_mask = np.zeros(B, dtype=bool)
# p_need_last_batch = None
# if method == "classifier" and classifier is not None:
# with torch.no_grad():
# cos_mid = reps_mid_t @ cand_mid_t.T # [B, N] (BF16)
# backbone_ptr = model.module if hasattr(model, "module") else model
# temp = getattr(backbone_ptr, "temperature", 0.02)
# scores_mid = cos_mid / temp
# probs_mid = torch.softmax(scores_mid, dim=1) # [B, N]
# # --- 27维特征计算 (BF16) ---
# diag_cos = cos_mid.max(dim=1)[0]
# sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True)
# s2_cos = sorted_cos[:, 1] if sorted_cos.size(1) > 1 else sorted_cos[:, 0]
# margin_mid = diag_cos - s2_cos
# margin_mean = margin_mid.mean()
# margin_std = margin_mid.std(unbiased=False) + 1e-6
# z_margin_mid = (margin_mid - margin_mean) / margin_std
# margin_median = margin_mid.median()
# mad = (margin_mid - margin_median).abs().median() + 1e-6
# mad_margin_mid = (margin_mid - margin_median) / mad
# p1_mid = probs_mid.max(dim=1)[0]
# H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1)
# gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1)
# TOPK = min(16, probs_mid.size(1))
# topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1)
# topk_mean = topk_vals.mean(dim=1)
# topk_std = topk_vals.std(dim=1, unbiased=False)
# topk_cv = topk_std / (topk_mean + 1e-6)
# centered = topk_vals - topk_mean.unsqueeze(1)
# var = (centered ** 2).mean(dim=1) + 1e-6
# m4 = (centered ** 4).mean(dim=1)
# topk_kurt = m4 / (var ** 2)
# topk_med = topk_vals.median(dim=1).values
# row_mean_cos = cos_mid.mean(dim=1)
# row_med_cos = cos_mid.median(dim=1).values
# s1_over_mean = diag_cos - row_mean_cos
# s1_over_med = diag_cos - row_med_cos
# sorted_probs, _ = torch.sort(probs_mid, dim=1, descending=True)
# p1 = sorted_probs[:, 0]
# p2 = sorted_probs[:, 1] if sorted_probs.size(1) > 1 else sorted_probs[:, 0]
# shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(dim=1)
# shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1)
# R = min(10, sorted_probs.size(1))
# x = torch.arange(R, device=device, dtype=sorted_probs.dtype)
# x_centered = x - x.mean()
# denom = (x_centered ** 2).sum()
# y = torch.log(sorted_probs[:, :R] + 1e-6)
# slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom
# row_mean_p = probs_mid.mean(dim=1)
# row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6
# z1 = (p1_mid - row_mean_p) / row_std_p
# center_p = probs_mid - row_mean_p.unsqueeze(1)
# m3 = (center_p ** 3).mean(dim=1)
# skew = m3 / (row_std_p ** 3 + 1e-6)
# s1_over_sk = p1_mid - skew
# TAIL_K = min(10, sorted_probs.size(1))
# tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1)
# HEAD_K = min(5, sorted_probs.size(1))
# head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1)
# mask_ratio = torch.zeros_like(diag_cos)
# mask_len = torch.zeros_like(diag_cos)
# mask_runs = torch.zeros_like(diag_cos)
# scalar_inputs = torch.stack([
# diag_cos, s2_cos, margin_mid, z_margin_mid, mad_margin_mid,
# p1_mid, H_mid, gini_mid,
# topk_mean, topk_std, topk_cv, topk_kurt, topk_med,
# s1_over_mean, s1_over_med,
# p1, p2, shape_H, shape_gini, slope, z1, s1_over_sk,
# tail_mean, head5_mean,
# mask_ratio, mask_len, mask_runs
# ], dim=1)
# modality_idx = torch.zeros(B, dtype=torch.long, device=device)
# if "pixel_values" in inputs and inputs["pixel_values"] is not None:
# pv = inputs["pixel_values"]
# if isinstance(pv, list):
# for i, item in enumerate(pv):
# if item is not None: modality_idx[i] = 1
# elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
# modality_idx.fill_(1)
# # =======================================================
# # 【V5 关键】强制转为 FP32 传给分类器
# # =======================================================
# scalar_inputs_f32 = scalar_inputs.float()
# qry_emb_f32 = reps_mid_t.float() # 转 FP32
# logits = classifier(scalar_inputs_f32, modality_idx, qry_emb=qry_emb_f32)
# p_need_last = torch.sigmoid(logits) # [B,1]
# p_need_last_batch = p_need_last.squeeze(1) # [B]
# should_exit = p_need_last_batch < threshold
# exit_mask = should_exit.cpu().numpy()
# if stats["total"] <= B * 3 and is_main:
# print_master(
# f"[EE Debug] Batch {stats['total']//B}: "
# f"p_need_last mean={p_need_last_batch.mean().item():.4f}, "
# f"Exit Rate={exit_mask.mean():.2%}, "
# f"Top3: diag_cos={diag_cos.mean():.3f}, margin={margin_mid.mean():.3f}"
# )
# stats["exit"] += exit_mask.sum()
# # ---------------------------------------------------
# # 3. 分支执行
# # ---------------------------------------------------
# exit_indices = np.where(exit_mask)[0]
# cont_indices = np.where(~exit_mask)[0]
# # A. 早停样本
# if len(exit_indices) > 0:
# reps_exit = reps_mid_t[exit_indices]
# if use_local:
# for i, idx in enumerate(exit_indices):
# cand_local = infos[idx].get("cand_names", [])
# if not cand_local: cids = []
# else:
# rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
# rows = [r for r in rows if r >= 0]
# if len(rows) == 0: cids = []
# else:
# cmat_np = cand_mid_np[rows]
# qry_np = reps_exit[i].detach().float().cpu().numpy()
# scores_vec = np.dot(qry_np, cmat_np.T)
# top_k = min(200, len(rows))
# order_local = np.argsort(-scores_vec)[:top_k]
# cids = [str(cand_local[o]) for o in order_local]
# _record_result(results_dict, batch_start_idx + idx, infos[idx], cids)
# else:
# reps_exit_np = reps_exit.detach().float().cpu().numpy()
# scores_full = np.dot(reps_exit_np, cand_mid_np.T)
# top_k = min(200, len(cand_ids))
# topk_inds = np.argsort(-scores_full, axis=1)[:, :top_k]
# for i, idx in enumerate(exit_indices):
# cids = [cand_ids[k] for k in topk_inds[i]]
# _record_result(results_dict, batch_start_idx + idx, infos[idx], cids)
# # B. 续跑样本
# if len(cont_indices) > 0:
# # 取中间态
# interm = getattr(out_mid, "intermediate_state", None)
# hs, am, pos = interm["hidden_states"].detach(), interm["attention_mask"].detach(), interm["position_ids"].detach()
# vm, tm = interm.get("vision_mask", None), interm.get("text_mask", None)
# next_layer = int(interm["next_layer_idx"])
# resume_state_subset = {
# "hidden_states": hs[cont_indices], "attention_mask": am[cont_indices],
# "position_ids": pos[:, cont_indices, :],
# "vision_mask": vm[cont_indices] if vm is not None else None,
# "text_mask": tm[cont_indices] if tm is not None else None,
# "next_layer_idx": next_layer,
# }
# # ====== 这里改 AOP 配置 ======
# if isinstance(aop_cfg, dict) and aop_cfg:
# aop_resume = dict(aop_cfg)
# aop_layer = aop_resume.get("layer_idx", None)
# # 情况1:AOP_LAYER 已经 <= EE_LAYER,说明 mid 那次已经剪过了 → tail 不再剪
# if (aop_layer is not None) and (aop_layer <= target_layer_idx):
# need_prune_in_tail = False
# else:
# # 情况2:AOP_LAYER 在 EE_LAYER 后面(比如 16),只在 tail 中触发
# need_prune_in_tail = bool(_orig_enabled and side_enable)
# aop_resume["enabled"] = need_prune_in_tail
# setattr(model.encoder, "aop_prune_config", aop_resume)
# # ====== AOP 配置修改结束 ======
# if profile_enabled:
# torch.cuda.synchronize()
# t0_tail = time.perf_counter()
# with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
# out_last = model.encoder(
# return_dict=True, output_hidden_states=False, stop_at_layer=None,
# resume_state=resume_state_subset, compute_lm_head=False,
# )
# if profile_enabled:
# torch.cuda.synchronize()
# t1_tail = time.perf_counter()
# timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len(cont_indices)
# timing_stats["tail_num"] += len(cont_indices)
# hs_last = out_last.last_hidden_state
# if hs_last is None: hs_last = out_last.hidden_states[-1]
# am_last = getattr(out_last, "attention_mask", None)
# if am_last is None: am_last = resume_state_subset["attention_mask"]
# reps_last_t = model._pooling(hs_last, am_last).detach().to(dtype=torch.bfloat16)
# if use_local:
# for i, idx_global in enumerate(cont_indices):
# cand_local = infos[idx_global].get("cand_names", [])
# if not cand_local: cids = []
# else:
# rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
# rows = [r for r in rows if r >= 0]
# if len(rows) == 0: cids = []
# else:
# cmat_last_np = cand_last_np[rows]
# qry_last_np = reps_last_t[i].detach().float().cpu().numpy()
# scores_vec = np.dot(qry_last_np, cmat_last_np.T)
# top_k = min(200, len(rows))
# order_local = np.argsort(-scores_vec)[:top_k]
# cids = [str(cand_local[o]) for o in order_local]
# _record_result(results_dict, batch_start_idx + idx_global, infos[idx_global], cids)
# else:
# reps_last_np = reps_last_t.detach().float().cpu().numpy()
# scores_last = np.dot(reps_last_np, cand_last_np.T)
# top_k = min(200, len(cand_ids))
# topk_inds = np.argsort(-scores_last, axis=1)[:, :top_k]
# for i, idx_global in enumerate(cont_indices):
# cids = [cand_ids[k] for k in topk_inds[i]]
# _record_result(results_dict, batch_start_idx + idx_global, infos[idx_global], cids)
# # ---------------------------------------------------
# # 4. Profiling Stats
# # ---------------------------------------------------
# if profile_enabled and is_main:
# K = min(topk_emb, cand_mid_t.size(0))
# # 转到 float32 + CPU 便于写盘
# q_mid_cpu = reps_mid_t.detach().float().cpu() # [B, D]
# q_last_cpu = (
# reps_last_full.detach().float().cpu()
# if reps_last_full is not None
# else None
# ) # [B, D]
# cand_mid_cpu = cand_mid_t.detach().float().cpu() # [Nc, D]
# cand_last_cpu = cand_last_t.detach().float().cpu() # [Nc, D]
# # mid2mid
# scores_mid_full = q_mid_cpu @ cand_mid_cpu.T # [B, Nc]
# topk_mid_vals, topk_mid_inds = torch.topk(
# scores_mid_full, k=K, dim=1
# )
# # last2last
# if q_last_cpu is not None:
# scores_last_full = q_last_cpu @ cand_last_cpu.T # [B, Nc]
# topk_last_vals, topk_last_inds = torch.topk(
# scores_last_full, k=K, dim=1
# )
# else:
# topk_last_vals = None
# topk_last_inds = None
# for i in range(B):
# qid = batch_start_idx + i
# rec = {
# "qid": int(qid),
# "early_exit": bool(exit_mask[i]),
# }
# if p_need_last_batch is not None:
# rec["p_need_last"] = float(p_need_last_batch[i].item())
# # mid2mid TopK
# mid_inds = topk_mid_inds[i].tolist()
# mid_scores = topk_mid_vals[i].tolist()
# rec["mid_topk_scores"] = mid_scores
# rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds]
# # last2last TopK
# if topk_last_inds is not None:
# last_inds = topk_last_inds[i].tolist()
# last_scores = topk_last_vals[i].tolist()
# rec["last_topk_scores"] = last_scores
# rec["last_topk_cand_ids"] = [
# cand_ids[j] for j in last_inds
# ]
# else:
# rec["last_topk_scores"] = None
# rec["last_topk_cand_ids"] = None
# analysis_records.append(rec)
# # 5. 收集 & 保存结果
# for idx in sorted(results_dict.keys()):
# pred_dicts.append(results_dict[idx])
# print_master(f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} ({stats['exit']/stats['total']:.2%})")
# metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
# score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)
# if is_main:
# os.makedirs(out_dir, exist_ok=True)
# with open(os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), "w") as f:
# json.dump(score, f, indent=4)
# # Profiling save
# if profile_enabled:
# prof_dir = os.path.join(out_dir, "profiling")
# os.makedirs(prof_dir, exist_ok=True)
# mid_avg = timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"])
# tail_avg = timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"])
# timing_out = {
# "mid_time_sum": timing_stats["mid_time_sum"],
# "mid_num": timing_stats["mid_num"],
# "tail_time_sum": timing_stats["tail_time_sum"],
# "tail_num": timing_stats["tail_num"],
# "avg_mid_time_per_query_sec": mid_avg,
# "avg_tail_time_per_cont_query_sec": tail_avg,
# "num_exit": int(stats["exit"]),
# "num_total": int(stats["total"]),
# }
# with open(os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w") as f:
# json.dump(timing_out, f, indent=2)
# embed_path = os.path.join(prof_dir, f"{dataset_name}_embeds.jsonl")
# with open(embed_path, "w") as f:
# for rec in analysis_records:
# f.write(json.dumps(rec) + "\n")
# print_master(f"[PROFILE] Saved timing to {prof_dir}, details to {embed_path}")
# elapsed = time.time() - start_time
# return score, elapsed
# def _record_result(results_dict, global_idx, info, cids):
# label = info.get("label_name") or info.get("label") or info.get("rel_docids")
# if not isinstance(label, list): label = [label]
# rel_scores = info.get("rel_scores", None)
# results_dict[global_idx] = {
# "prediction": cids, "label": label, "rel_scores": rel_scores,
# }
# # ===========================
# # Helper Functions (Pre-Computation)
# # ===========================
# # ... (encode_candidates_both_layers 保持不变) ...
# def encode_candidates_both_layers(model: MMEBModel, loader: DataLoader, training_args: TrainingArguments, model_args: ModelArguments, full_dataset: Dataset, mid_layer: int):
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# model.eval()
# all_mid, all_last, all_ids = [], [], []
# with torch.no_grad():
# for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0):
# inputs = batch_to_device(inputs, training_args.device)
# aop_cfg = getattr(model.encoder, "aop_prune_config", None)
# vpool_cfg = getattr(model.encoder, "vision_pooling_config", None)
# _orig_aop = None
# _orig_vpool = None
# if isinstance(aop_cfg, dict):
# _orig_aop = aop_cfg.get("enabled", False)
# aop_cfg = dict(aop_cfg); aop_cfg["enabled"] = False
# setattr(model.encoder, "aop_prune_config", aop_cfg)
# if isinstance(vpool_cfg, dict):
# _orig_vpool = vpool_cfg.get("enabled", False)
# vpool_cfg = dict(vpool_cfg); vpool_cfg["enabled"] = False
# setattr(model.encoder, "vision_pooling_config", vpool_cfg)
# with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
# out = model.encoder(**inputs, return_dict=True, output_hidden_states=True, stop_at_layer=None)
# # 恢复
# if isinstance(aop_cfg, dict) and _orig_aop is not None:
# aop_cfg["enabled"] = _orig_aop
# setattr(model.encoder, "aop_prune_config", aop_cfg)
# if isinstance(vpool_cfg, dict) and _orig_vpool is not None:
# vpool_cfg["enabled"] = _orig_vpool
# setattr(model.encoder, "vision_pooling_config", vpool_cfg)
# mid_hs = out.hidden_states[mid_layer]
# last_hs = out.hidden_states[-1]
# am = inputs.get("attention_mask", None)
# if am is not None and am.device != mid_hs.device: am = am.to(mid_hs.device)
# reps_mid = model._pooling(mid_hs, am)
# reps_last = model._pooling(last_hs, am)
# all_mid.append(reps_mid.detach().float().cpu())
# all_last.append(reps_last.detach().float().cpu())
# all_ids.extend([info["cand_name"] for info in dataset_info])
# if not all_mid: return np.array([]), np.array([]), []
# return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids
# # ===========================
# # Main
# # ===========================
# def main():
# if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
# dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
# local_rank = dist.get_rank() if dist.is_initialized() else 0
# parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
# model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# ee_cfg = get_env_ee_config()
# hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
# if not getattr(model_args, "model_backbone", None):
# model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
# setattr(model_args, 'model_backbone', model_backbone)
# processor = load_processor(model_args, data_args)
# model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
# model.eval()
# model = model.to(training_args.device, dtype=torch.bfloat16)
# # [ADD] 注入 AOP & VPOOL 到 wrapper + 底座
# aop_cfg = get_env_aop_config()
# vpool_cfg = get_env_vpool_config()
# if aop_cfg.get("enabled", False):
# setattr(model.encoder, "aop_prune_config", aop_cfg)
# _set_attr_on_base(model.encoder, "aop_prune_config", aop_cfg)
# model.set_inference_layers(qry_layers=None, tgt_layers=None)
# print_master("[AOP][eval] enabled with cfg: " + str({
# k: aop_cfg.get(k) for k in ["apply_to","layer_idx","mode","selection","attn_agg","prune_text","prune_vision","keep_ratio_text","keep_ratio_vision"]
# }))
# else:
# print_master("[AOP][eval] disabled")
# if vpool_cfg.get("enabled", False):
# setattr(model.encoder, "vision_pooling_config", vpool_cfg)
# _set_attr_on_base(model.encoder, "vision_pooling_config", vpool_cfg)
# print_master("[VPOOL][eval] enabled with cfg: " + str({
# k: vpool_cfg.get(k) for k in ["apply_to","layer_idx","kernel","stride","method","vision_only"]
# }))
# else:
# print_master("[VPOOL][eval] disabled")
# # 加载分类器
# classifier = None
# if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]:
# classifier_path = ee_cfg['classifier_path']
# print_master(f"[EE] Loading Classifier from {classifier_path}...")
# # 【V5 修改】获取 Backbone Hidden Size 并初始化 V5 分类器
# backbone_hidden_size = model.encoder.config.hidden_size
# print_master(f"[EE] Backbone Hidden Size: {backbone_hidden_size}")
# # 使用 EarlyExitClassifier (其实引用的是 V5)
# classifier = EarlyExitClassifier(
# input_dim=27,
# hidden_dim=128,
# embedding_dim=backbone_hidden_size
# )
# state_dict = None
# if os.path.isdir(classifier_path):
# safetensors_file = os.path.join(classifier_path, "model.safetensors")
# if os.path.exists(safetensors_file):
# from safetensors.torch import load_file
# state_dict = load_file(safetensors_file)
# else:
# layer_idx = ee_cfg.get('layer', 12)
# # 尝试加载训练脚本保存的 .pt 文件
# pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt")
# if os.path.exists(pt_file):
# state_dict = torch.load(pt_file, map_location=training_args.device)
# else:
# # 尝试 pytorch_model.bin
# pt_file_bin = os.path.join(classifier_path, "pytorch_model.bin")
# if os.path.exists(pt_file_bin):
# state_dict = torch.load(pt_file_bin, map_location=training_args.device)
# elif os.path.isfile(classifier_path):
# state_dict = torch.load(classifier_path, map_location=training_args.device)
# if state_dict is not None:
# classifier.load_state_dict(state_dict)
# classifier.to(training_args.device) # 默认 FP32
# classifier.eval()
# print_master(f"[EE] Classifier loaded successfully.")
# else:
# raise FileNotFoundError(f"Could not load classifier weights from {classifier_path}")
# with open(data_args.dataset_config, 'r') as yaml_file: dataset_configs = yaml.safe_load(yaml_file)
# for dataset_name, task_config in dataset_configs.items():
# if dist.is_initialized(): dist.barrier()
# print_master(f"\n--- Evaluating {dataset_name} ---")
# base_tau = float(ee_cfg["threshold"])
# ds_tau = PER_DATASET_THRESHOLDS.get(dataset_name, base_tau)
# ee_cfg_ds = dict(ee_cfg); ee_cfg_ds["threshold"] = float(ds_tau)
# if data_args.data_basedir:
# for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
# if task_config.get(key): task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
# full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
# full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
# mid_layer = int(ee_cfg_ds["layer"])
# cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}")
# cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast")
# if (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)):
# eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
# eval_cand_loader = DataLoader(full_eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers)
# cand_mid, cand_last, cand_ids = encode_candidates_both_layers(model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer)
# if local_rank == 0:
# with open(cand_mid_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f)
# with open(cand_last_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f)
# if dist.is_initialized(): dist.barrier()
# if local_rank == 0:
# with open(cand_mid_path, "rb") as f: cand_mid_dict = pickle.load(f)
# with open(cand_last_path, "rb") as f: cand_last_dict = pickle.load(f)
# rank_global = task_config.get("eval_type", "global") == "global"
# run_early_exit_queries(
# model, classifier, processor, model_args, data_args, training_args,
# full_eval_qry_dataset, cand_mid_dict, cand_last_dict,
# ee_cfg_ds, dataset_name, data_args.encode_output_path,
# global_ranking=rank_global,
# )
# if dist.is_initialized(): dist.barrier()
# if __name__ == '__main__':
# main()
###################################################################################################################
import datetime
import logging
import json
import random
import time
import numpy as np
import os
import pickle
import sys
import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
import torch.distributed as dist
import torch.nn.functional as F
import yaml
import transformers
import math
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
from datasets import Dataset, concatenate_datasets
from datasets.distributed import split_dataset_by_node
from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.eval_collator import MultimodalEvalDataCollator
from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
from src.eval_utils.metrics import RankingMetrics
from src.model.model_multi_layer_distill_AOP_pooling import MMEBModel
from src.model.processor import get_backbone_name, load_processor, COLPALI
from src.utils import batch_to_device, print_rank, print_master
# ==========================================
# 【V5 修改】引入 V5 版本分类器
# ==========================================
from src.classifier_utils_V5 import EarlyExitClassifier
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)
# ===========================
# Per-dataset thresholds (示例,需根据V5分析结果更新)
# ===========================
# Per-dataset early-exit thresholds (tau). run_early_exit_queries compares the
# classifier's p(need-last-layer) against this value: queries BELOW tau exit at
# the mid layer. Datasets absent from this map fall back to EE_THRESHOLD
# (see get_env_ee_config). Values were tuned per dataset; update them whenever
# the classifier (V5) is retrained.
PER_DATASET_THRESHOLDS = {
    "CIRR": 0.43,
    "EDIS": 0.46,
    "FashionIQ": 0.43,
    "MSCOCO_i2t": 0.43,
    "MSCOCO_t2i": 0.42,
    "NIGHTS": 0.48,
    "OVEN": 0.44,
    "VisDial": 0.45,
    "VisualNews_i2t": 0.43,
    "VisualNews_t2i": 0.43,
    "WebQA": 0.55,
    "Wiki-SS-NQ": 0.45,
}
# ... (Helper Functions 保持不变) ...
def _parse_bool(v: str, default=False):
if v is None: return default
v = v.strip().lower()
return v in {"1","true","yes","y","t","on"}
def _parse_float(v: str, default=None):
try: return float(v) if v is not None else default
except: return default
def _parse_int(v: str, default=None):
try: return int(v) if v is not None else default
except: return default
def get_env_aop_config():
    """Build the AOP (attention-based token pruning) config from env variables.

    Every knob comes from an ``AOP_*`` environment variable with a sensible
    default. AOP is force-disabled when ``AOP_LAYER`` is missing, since
    pruning needs a concrete layer index to hook into. ``margin_mid`` is a
    placeholder slot filled in elsewhere at runtime.
    """
    env = os.environ.get
    cfg = {
        "enabled": _parse_bool(env("AOP_ENABLED"), False),
        "apply_to": env("AOP_APPLY", "qry").strip().lower(),
        "layer_idx": _parse_int(env("AOP_LAYER"), None),
        "mode": env("AOP_MODE", "delta").strip().lower(),
        "delta": _parse_float(env("AOP_DELTA"), 0.10),
        "K_hat": _parse_float(env("AOP_KHAT"), 1.0),
        "keep_ratio": _parse_float(env("AOP_KEEP_RATIO"), 1.0),
        "min_keep": _parse_int(env("AOP_MIN_KEEP"), 64),
        "use_bias": _parse_bool(env("AOP_USE_BIAS"), True),
        "prune_vision": _parse_bool(env("AOP_PRUNE_VISION"), True),
        "prune_text": _parse_bool(env("AOP_PRUNE_TEXT"), False),
        # Optional per-modality keep ratios (fall back to `keep_ratio` downstream).
        "keep_ratio_vision": _parse_float(env("AOP_KEEP_RATIO_VISION"), None),
        "keep_ratio_text": _parse_float(env("AOP_KEEP_RATIO_TEXT"), None),
        # Per-modality minimum kept tokens.
        "min_keep_vision": _parse_int(env("AOP_MIN_KEEP_VISION"), None),
        "min_keep_text": _parse_int(env("AOP_MIN_KEEP_TEXT"), 32),
        # Text-protection knobs: keep trailing text tokens / special tokens.
        "protect_text_last": _parse_int(env("AOP_PROTECT_TEXT_LAST"), 16),
        "protect_special": _parse_bool(env("AOP_PROTECT_SPECIAL"), True),
        "selection": env("AOP_SELECTION", "aop").strip().lower(),
        "attn_agg": env("AOP_ATTENTION_AGG", "mean").strip().lower(),
        "random_seed": _parse_int(env("AOP_RANDOM_SEED"), None),
        "margin_mid": None,
    }
    # A layer index is mandatory: without it the prune hook has nowhere to attach.
    if cfg["enabled"] and cfg["layer_idx"] is None:
        cfg["enabled"] = False
    return cfg
# [新增] VPOOL 配置解析
# [New] VPOOL config parsing
def get_env_vpool_config():
    """Build the vision-token pooling (VPOOL) config from environment variables.

    Fix: the nested ``_i`` helper used a bare ``except:`` (which would also
    swallow KeyboardInterrupt/SystemExit); only the expected conversion errors
    are caught now — ``int(None)`` raises TypeError, non-numeric strings raise
    ValueError — both fall back to the default.
    """
    def _b(k, d=False):
        # Boolean env flag; unset -> default.
        v = os.environ.get(k)
        if v is None: return d
        return str(v).strip().lower() in {"1","true","yes","y","on"}
    def _i(k, d=None):
        # Integer env value; missing or malformed -> default.
        try: return int(os.environ.get(k, d))
        except (TypeError, ValueError): return d
    def _s(k, d=None):
        # Lower-cased string env value; None propagates.
        v = os.environ.get(k, d)
        return None if v is None else str(v).strip().lower()
    return {
        "enabled": _b("VPOOL_ENABLED", False),
        "apply_to": _s("VPOOL_APPLY", "both"),  # qry|tgt|both
        "layer_idx": _i("VPOOL_LAYER", 1),
        "kernel": _i("VPOOL_KERNEL", 2),
        # Stride defaults to the kernel size. NOTE: a stride of 0 also falls
        # back to the kernel because of the `or` — presumably intentional.
        "stride": _i("VPOOL_STRIDE", None) or _i("VPOOL_KERNEL", 2),
        "method": _s("VPOOL_METHOD", "avg"),
        "protect_cls": _b("VPOOL_PROTECT_CLS", True),
        "vision_only": _b("VPOOL_ONLY_VISION", True),
        "monitor": _b("VPOOL_MONITOR", False),
    }
def get_env_ee_config():
    """Read the early-exit (EE) gating configuration from environment variables.

    Returns a dict with: enabled flag, exit layer index, gating method,
    base exit threshold, and the classifier checkpoint path.
    """
    truthy = {"1", "true", "yes", "on", "y", "t"}
    raw_enabled = os.environ.get("EE_ENABLED", "0").strip().lower()
    return {
        "enabled": raw_enabled in truthy,
        "layer": int(os.environ.get("EE_LAYER", "12")),
        "method": os.environ.get("EE_METHOD", "classifier").strip().lower(),
        "threshold": float(os.environ.get("EE_THRESHOLD", "0.8")),
        "classifier_path": os.environ.get("EE_CLASSIFIER_PATH", ""),
    }
# 【新增】Helper: 将配置注入底座模型 (修复 NameError)
def _set_attr_on_base(peft_or_base, name, value):
try:
# 尝试获取 PEFT 的 base model
base = peft_or_base.get_base_model() if hasattr(peft_or_base, "get_base_model") else None
# 如果不是 PEFT,或者 get_base_model 返回空,尝试直接访问 .model
if base is None and hasattr(peft_or_base, "model"):
base = peft_or_base.model
# 注入配置
if base is not None:
setattr(base, name, value)
# 针对 Qwen2-VL 这种包装结构 (Qwen2VLForConditionalGeneration -> Qwen2VLModel)
if hasattr(base, "model"):
setattr(base.model, name, value)
except Exception as e:
# 仅打印警告,不阻断流程
print(f"[inject-config] warn: set {name} on base failed: {e}")
def pad_dataset_to_divisible(dataset, world_size):
    """Pad ``dataset`` by cycling its rows so its length divides ``world_size``.

    Returns ``(padded_dataset, padded_length)``. When the length already
    divides evenly, the original dataset object is returned untouched.
    """
    original_len = len(dataset)
    remainder = original_len % world_size
    if remainder == 0:
        return dataset, original_len
    shortfall = world_size - remainder
    # Repeat rows from the front (wrapping around) to fill the shortfall.
    extra_rows = dataset.select([i % original_len for i in range(shortfall)])
    return concatenate_datasets([dataset, extra_rows]), original_len + shortfall
# ===========================
# Core Inference Function
# ===========================
def run_early_exit_queries(
    model: MMEBModel,
    classifier: EarlyExitClassifier,
    processor,
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
    qry_dataset: Dataset,
    cand_mid_dict: dict,
    cand_last_dict: dict,
    ee_cfg: dict,
    dataset_name: str,
    out_dir: str,
    global_ranking: bool = True,
):
    """Run early-exit retrieval over the query set.

    Per query batch:
      1. "Mid" pass: encode up to ``ee_cfg['layer']``, enabling AOP token
         pruning only when the AOP layer lies at/before the exit layer.
      2. Gating: build 27 scalar confidence features from mid-layer
         query-candidate cosine scores and feed them (cast to FP32) to
         ``classifier``; queries whose sigmoid output (p(need-last-layer))
         falls below ``ee_cfg['threshold']`` exit early and are ranked
         against the mid-layer candidate embeddings.
      3. "Tail" pass: the remaining queries resume from the cached
         intermediate state to the final layer and are ranked against the
         last-layer candidate embeddings.

    Rankings (top-200 candidate ids per query) are scored with
    ``RankingMetrics`` and written to
    ``{out_dir}/{dataset_name}_score_earlyexit.json`` on the main rank.
    With EE_PROFILE set, stage timings and per-query top-k dumps are also
    written under ``{out_dir}/profiling``; with EE_TORCH_PROFILE set, a
    torch profiler trace goes to ``{out_dir}/torch_prof/{dataset_name}``.

    ``global_ranking=False`` ranks each query only within its own
    ``cand_names`` list (local ranking) instead of the whole library.

    Returns:
        (score, elapsed): metrics dict and wall-clock seconds for the run.
    """
    device = training_args.device
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    is_main = (not dist.is_initialized()) or (local_rank == 0)
    # Profiling configuration (env-driven).
    profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in {"1", "true", "yes", "on", "y", "t"}
    torch_prof_enabled = os.environ.get("EE_TORCH_PROFILE", "0").strip().lower() in {"1", "true", "yes", "on", "y", "t"}
    topk_emb = int(os.environ.get("EE_TOPK_EMB", "5"))
    timing_stats = {
        "mid_time_sum": 0.0, "mid_num": 0, "tail_time_sum": 0.0, "tail_num": 0,
    }
    analysis_records = [] if (profile_enabled and is_main) else None
    # 1. Prepare candidates: stack precomputed embeddings in a fixed id order.
    cand_ids = list(cand_mid_dict.keys())
    cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
    cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32)
    # NOTE(review): cand_mid_np / cand_last_np are no longer read (the numpy
    # ranking paths below were replaced by the GPU topk paths) — dead aliases.
    cand_mid_np = cand_mid  # [Nc, D], float32
    cand_last_np = cand_last  # [Nc, D], float32
    cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)
    cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16)
    collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
    loader = DataLoader(
        qry_dataset, batch_size=training_args.per_device_eval_batch_size,
        collate_fn=collator, num_workers=training_args.dataloader_num_workers
    )
    pred_dicts = []
    stats = {"exit": 0, "total": 0}
    threshold = float(ee_cfg["threshold"])
    method = ee_cfg["method"]
    target_layer_idx = int(ee_cfg["layer"])
    results_dict = {}
    global_sample_idx = 0
    use_local = (not global_ranking)
    if use_local:
        print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)")
        cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
    else:
        print_master(f"[INFO] Using GLOBAL ranking (full library)")
    # AOP config: remember the original enabled flag and whether the query
    # side participates, so the flag can be toggled per pass below.
    aop_cfg = getattr(model.encoder, "aop_prune_config", None)
    _orig_enabled = None
    side_enable = True
    if isinstance(aop_cfg, dict) and aop_cfg:
        _orig_enabled = aop_cfg.get("enabled", False)
        apply_to = aop_cfg.get("apply_to", "qry")
        side_enable = (apply_to == "both") or (apply_to == "qry")
    model.eval()
    if classifier:
        classifier.eval()
        # [V5 key] The classifier runs in FP32 (default); do NOT cast to BF16 —
        # it relies on Log/LayerNorm numerics.
        classifier.to(device)
    start_time = time.time()
    prof = None
    if torch_prof_enabled and is_main:
        prof_dir = os.path.join(out_dir, "torch_prof", dataset_name)
        os.makedirs(prof_dir, exist_ok=True)
        prof_schedule = schedule(wait=10, warmup=1, active=1, repeat=0)
        prof = profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=prof_schedule,
            on_trace_ready=tensorboard_trace_handler(prof_dir),
            record_shapes=False,
            profile_memory=False,
            with_stack=True,
        )
        prof.__enter__()
    for inputs, infos in tqdm(
        loader, desc=f"[EE+AOP] {dataset_name} (tau={threshold})", disable=local_rank > 0,
    ):
        inputs = batch_to_device(inputs, device)
        B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1
        batch_start_idx = global_sample_idx
        global_sample_idx += B
        stats["total"] += B
        # ---------------------------------------------------
        # 1. First half: run up to the mid (exit) layer.
        # ---------------------------------------------------
        orig_cfg = None
        if isinstance(aop_cfg, dict) and aop_cfg:
            orig_cfg = dict(aop_cfg)
            aop_layer = aop_cfg.get("layer_idx", None)
            # Enable AOP in the mid pass only if its layer falls at/before the
            # exit layer; otherwise pruning belongs to the tail pass.
            aop_on_mid = bool(
                _orig_enabled and side_enable
                and (aop_layer is not None)
                and (aop_layer <= target_layer_idx)
            )
            aop_cfg_mid = dict(aop_cfg)
            aop_cfg_mid["enabled"] = aop_on_mid
            setattr(model.encoder, "aop_prune_config", aop_cfg_mid)
        if profile_enabled:
            torch.cuda.synchronize()
            t0_mid = time.perf_counter()
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            out_mid = model.encoder(
                **inputs, return_dict=True, output_hidden_states=False,
                stop_at_layer=target_layer_idx, compute_lm_head=False,
                return_intermediate_state=True
            )
        # Post-prune masks exposed by the mid pass:
        post_attn_mid = getattr(out_mid, "attention_mask", None)  # [B, L'] 2D mask after pruning
        img_mask_mid = getattr(out_mid, "image_token_bool_masks", None)  # [B, L'] vision-token mask (pre-prune in current Qwen2.5 impl)
        txt_mask_mid = getattr(out_mid, "text_token_bool_masks", None)  # [B, L'] text-token mask (already post-prune)
        # Token keep statistics.
        # NOTE(review): keep_counts / text_keep / vis_keep are computed but
        # never consumed below — kept around for ad-hoc debugging.
        if post_attn_mid is not None:
            # Total valid tokens per sample.
            keep_counts = post_attn_mid.to(torch.bool).sum(dim=1)  # [B]
            # Retained text tokens.
            if txt_mask_mid is not None:
                text_keep = (post_attn_mid.to(torch.bool) & txt_mask_mid.to(torch.bool)).sum(dim=1)
            # Retained vision tokens.
            if img_mask_mid is not None:
                vis_keep = (post_attn_mid.to(torch.bool) & img_mask_mid.to(torch.bool)).sum(dim=1)
        if profile_enabled:
            torch.cuda.synchronize()
            t1_mid = time.perf_counter()
            timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B
            timing_stats["mid_num"] += B
        # Restore the pre-pass AOP config.
        if isinstance(orig_cfg, dict):
            setattr(model.encoder, "aop_prune_config", orig_cfg)
        hs_mid = getattr(out_mid, "last_hidden_state", None)
        if hs_mid is None: hs_mid = out_mid.hidden_states[-1]
        am_mid = getattr(out_mid, "attention_mask", None)
        if am_mid is None: am_mid = inputs.get("attention_mask", None)
        reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(dtype=torch.bfloat16)
        # Profiling-only: also run a full (last-layer) pass for mid-vs-last analysis.
        reps_last_full = None
        if profile_enabled:
            with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                out_full = model.encoder(
                    **inputs, return_dict=True, output_hidden_states=False,
                    stop_at_layer=None, compute_lm_head=False,
                    return_intermediate_state=True
                )
                hs_last_full = getattr(out_full, "last_hidden_state", None)
                if hs_last_full is None: hs_last_full = out_full.hidden_states[-1]
                am_last_full = getattr(out_full, "attention_mask", None)
                if am_last_full is None: am_last_full = inputs.get("attention_mask", None)
                reps_last_full = model._pooling(hs_last_full, am_last_full).detach().to(dtype=torch.bfloat16)
        # ---------------------------------------------------
        # 2. Feature engineering + gating.
        # ---------------------------------------------------
        exit_mask = np.zeros(B, dtype=bool)
        p_need_last_batch = None
        cos_mid = None
        if method == "classifier" and classifier is not None:
            cos_mid = reps_mid_t @ cand_mid_t.T  # [B, Nc] on GPU
            with torch.no_grad():
                backbone_ptr = model.module if hasattr(model, "module") else model
                temp = getattr(backbone_ptr, "temperature", 0.02)
                scores_mid = cos_mid / temp
                probs_mid = torch.softmax(scores_mid, dim=1)  # [B, N]
                # --- 27-dim scalar feature computation (BF16) ---
                diag_cos = cos_mid.max(dim=1)[0]
                # Top-1 / top-2 cosine and their margin.
                top2 = torch.topk(cos_mid, k=min(2, cos_mid.size(1)), dim=1).values
                diag_cos = top2[:, 0]
                s2_cos = top2[:, 1] if top2.size(1) > 1 else top2[:, 0]
                margin_mid = diag_cos - s2_cos
                # Batch-standardized margin (z-score) and robust (MAD) variant.
                margin_mean = margin_mid.mean()
                margin_std = margin_mid.std(unbiased=False) + 1e-6
                z_margin_mid = (margin_mid - margin_mean) / margin_std
                margin_median = margin_mid.median()
                mad = (margin_mid - margin_median).abs().median() + 1e-6
                mad_margin_mid = (margin_mid - margin_median) / mad
                # Softmax confidence, entropy, Gini impurity.
                p1_mid = probs_mid.max(dim=1)[0]
                H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1)
                gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1)
                # Top-K probability shape statistics.
                TOPK = min(16, probs_mid.size(1))
                topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1)
                topk_mean = topk_vals.mean(dim=1)
                topk_std = topk_vals.std(dim=1, unbiased=False)
                topk_cv = topk_std / (topk_mean + 1e-6)
                centered = topk_vals - topk_mean.unsqueeze(1)
                var = (centered ** 2).mean(dim=1) + 1e-6
                m4 = (centered ** 4).mean(dim=1)
                topk_kurt = m4 / (var ** 2)
                topk_med = topk_vals.median(dim=1).values
                # Top-1 score relative to the row mean/median cosine.
                row_mean_cos = cos_mid.mean(dim=1)
                row_med_cos = cos_mid.median(dim=1).values
                s1_over_mean = diag_cos - row_mean_cos
                s1_over_med = diag_cos - row_med_cos
                # Sorted-probability shape features.
                sorted_probs, _ = torch.sort(probs_mid, dim=1, descending=True)
                p1 = sorted_probs[:, 0]
                p2 = sorted_probs[:, 1] if sorted_probs.size(1) > 1 else sorted_probs[:, 0]
                shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(dim=1)
                shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1)
                # Log-probability decay slope over the top-R ranks
                # (least-squares slope with a centered rank regressor).
                R = min(10, sorted_probs.size(1))
                x = torch.arange(R, device=device, dtype=sorted_probs.dtype)
                x_centered = x - x.mean()
                denom = (x_centered ** 2).sum()
                y = torch.log(sorted_probs[:, :R] + 1e-6)
                slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom
                # Row-wise z-score of p1 and skewness-adjusted confidence.
                row_mean_p = probs_mid.mean(dim=1)
                row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6
                z1 = (p1_mid - row_mean_p) / row_std_p
                center_p = probs_mid - row_mean_p.unsqueeze(1)
                m3 = (center_p ** 3).mean(dim=1)
                skew = m3 / (row_std_p ** 3 + 1e-6)
                s1_over_sk = p1_mid - skew
                # Tail / head mass of the sorted distribution.
                TAIL_K = min(10, sorted_probs.size(1))
                tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1)
                HEAD_K = min(5, sorted_probs.size(1))
                head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1)
                # Mask features are zero placeholders here (slots 25-27 of the
                # 27-dim input the V5 classifier expects).
                mask_ratio = torch.zeros_like(diag_cos)
                mask_len = torch.zeros_like(diag_cos)
                mask_runs = torch.zeros_like(diag_cos)
                scalar_inputs = torch.stack([
                    diag_cos, s2_cos, margin_mid, z_margin_mid, mad_margin_mid,
                    p1_mid, H_mid, gini_mid,
                    topk_mean, topk_std, topk_cv, topk_kurt, topk_med,
                    s1_over_mean, s1_over_med,
                    p1, p2, shape_H, shape_gini, slope, z1, s1_over_sk,
                    tail_mean, head5_mean,
                    mask_ratio, mask_len, mask_runs
                ], dim=1)
                # Modality indicator: 1 when the sample carries pixel values.
                modality_idx = torch.zeros(B, dtype=torch.long, device=device)
                if "pixel_values" in inputs and inputs["pixel_values"] is not None:
                    pv = inputs["pixel_values"]
                    if isinstance(pv, list):
                        for i, item in enumerate(pv):
                            if item is not None: modality_idx[i] = 1
                    elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
                        modality_idx.fill_(1)
                # =======================================================
                # [V5 key] Force FP32 before feeding the classifier.
                # =======================================================
                scalar_inputs_f32 = scalar_inputs.float()
                qry_emb_f32 = reps_mid_t.float()  # cast to FP32
                logits = classifier(scalar_inputs_f32, modality_idx, qry_emb=qry_emb_f32)
                p_need_last = torch.sigmoid(logits)  # [B,1]
                p_need_last_batch = p_need_last.squeeze(1)  # [B]
                # Exit when the predicted need for the last layer is low.
                should_exit = p_need_last_batch < threshold
                exit_mask = should_exit.cpu().numpy()
            # Debug print for roughly the first three batches only.
            if stats["total"] <= B * 3 and is_main:
                print_master(
                    f"[EE Debug] Batch {stats['total']//B}: "
                    f"p_need_last mean={p_need_last_batch.mean().item():.4f}, "
                    f"Exit Rate={exit_mask.mean():.2%}, "
                    f"Top3: diag_cos={diag_cos.mean():.3f}, margin={margin_mid.mean():.3f}"
                )
        stats["exit"] += exit_mask.sum()
        # ---------------------------------------------------
        # 3. Branch execution.
        # ---------------------------------------------------
        exit_indices = np.where(exit_mask)[0]
        cont_indices = np.where(~exit_mask)[0]
        # A. Early-exit samples: rank with mid-layer embeddings.
        if len(exit_indices) > 0:
            # NOTE(review): reps_exit is unused in both branches below
            # (ranking reuses the precomputed cos_mid rows instead).
            reps_exit = reps_mid_t[exit_indices]
            if use_local:
                # Per-query local candidates: subset cos_mid rows, then topk.
                for i, idx in enumerate(exit_indices):
                    cand_local = infos[idx].get("cand_names", [])
                    pairs = [(str(cid), cand_id2row.get(str(cid), -1)) for cid in cand_local]
                    pairs = [(cid, r) for cid, r in pairs if r >= 0]
                    if len(pairs) == 0:
                        cids = []
                    else:
                        cand_local_valid = [cid for cid, _ in pairs]
                        rows_valid = [r for _, r in pairs]
                        rows_t = torch.tensor(rows_valid, device=device, dtype=torch.long)
                        scores = cos_mid[idx].index_select(0, rows_t)  # [n_local]
                        top_k = min(200, scores.numel())
                        sel = torch.topk(scores, k=top_k, largest=True).indices.cpu().tolist()
                        cids = [cand_local_valid[j] for j in sel]
                    _record_result(results_dict, batch_start_idx + idx, infos[idx], cids)
            else:
                # Global ranking directly on the GPU score matrix. cos_mid: [B, Nc]
                cos_exit = cos_mid[exit_indices]  # [Be, Nc]
                top_k = min(200, cos_exit.size(1))
                topk_inds = torch.topk(cos_exit, k=top_k, dim=1, largest=True).indices  # [Be, top_k]
                topk_inds = topk_inds.cpu().tolist()
                for i, idx in enumerate(exit_indices):
                    cids = [cand_ids[j] for j in topk_inds[i]]
                    _record_result(results_dict, batch_start_idx + idx, infos[idx], cids)
        # B. Continuing samples: resume from the mid-layer state to the end.
        if len(cont_indices) > 0:
            # Grab the cached intermediate state and slice out the continuers.
            interm = getattr(out_mid, "intermediate_state", None)
            hs, am, pos = interm["hidden_states"].detach(), interm["attention_mask"].detach(), interm["position_ids"].detach()
            vm, tm = interm.get("vision_mask", None), interm.get("text_mask", None)
            next_layer = int(interm["next_layer_idx"])
            resume_state_subset = {
                "hidden_states": hs[cont_indices], "attention_mask": am[cont_indices],
                "position_ids": pos[:, cont_indices, :],
                "vision_mask": vm[cont_indices] if vm is not None else None,
                "text_mask": tm[cont_indices] if tm is not None else None,
                "next_layer_idx": next_layer,
            }
            # ====== Adjust the AOP config for the tail pass ======
            if isinstance(aop_cfg, dict) and aop_cfg:
                aop_resume = dict(aop_cfg)
                aop_layer = aop_resume.get("layer_idx", None)
                # Case 1: AOP_LAYER <= EE_LAYER — pruning already happened in
                # the mid pass, so do NOT prune again in the tail.
                if (aop_layer is not None) and (aop_layer <= target_layer_idx):
                    need_prune_in_tail = False
                else:
                    # Case 2: AOP_LAYER is beyond EE_LAYER (e.g. 16) — trigger
                    # pruning only in the tail pass.
                    need_prune_in_tail = bool(_orig_enabled and side_enable)
                aop_resume["enabled"] = need_prune_in_tail
                setattr(model.encoder, "aop_prune_config", aop_resume)
            # ====== End of AOP config adjustment ======
            if profile_enabled:
                torch.cuda.synchronize()
                t0_tail = time.perf_counter()
            with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                out_last = model.encoder(
                    return_dict=True, output_hidden_states=False, stop_at_layer=None,
                    resume_state=resume_state_subset, compute_lm_head=False,
                    return_intermediate_state=False
                )
            if profile_enabled:
                torch.cuda.synchronize()
                t1_tail = time.perf_counter()
                timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len(cont_indices)
                timing_stats["tail_num"] += len(cont_indices)
            hs_last = out_last.last_hidden_state
            if hs_last is None: hs_last = out_last.hidden_states[-1]
            am_last = getattr(out_last, "attention_mask", None)
            if am_last is None: am_last = resume_state_subset["attention_mask"]
            reps_last_t = model._pooling(hs_last, am_last).detach().to(dtype=torch.bfloat16)
            if use_local:
                # Local ranking against each query's own candidate subset,
                # now using last-layer candidate embeddings.
                for i, idx_global in enumerate(cont_indices):
                    cand_local = infos[idx_global].get("cand_names", [])
                    if not cand_local: cids = []
                    else:
                        pairs = [(str(cid), cand_id2row.get(str(cid), -1)) for cid in cand_local]
                        pairs = [(cid, r) for cid, r in pairs if r >= 0]
                        if len(pairs) == 0:
                            cids = []
                        else:
                            cand_local_valid = [cid for cid, _ in pairs]
                            rows_valid = [r for _, r in pairs]
                            rows_t = torch.tensor(rows_valid, device=device, dtype=torch.long)
                            cmat = cand_last_t.index_select(0, rows_t)  # [n_local, D]
                            scores = (reps_last_t[i].unsqueeze(0) @ cmat.T).squeeze(0)  # [n_local]
                            top_k = min(200, scores.numel())
                            sel = torch.topk(scores, k=top_k, largest=True).indices.cpu().tolist()
                            cids = [cand_local_valid[j] for j in sel]
                    _record_result(results_dict, batch_start_idx + idx_global, infos[idx_global], cids)
            else:
                # Global ranking on the GPU against last-layer embeddings.
                cos_last = reps_last_t @ cand_last_t.T
                top_k = min(200, cos_last.size(1))
                topk_inds = torch.topk(cos_last, k=top_k, dim=1, largest=True).indices
                topk_inds = topk_inds.cpu().tolist()
                for i, idx_global in enumerate(cont_indices):
                    cids = [cand_ids[j] for j in topk_inds[i]]
                    _record_result(results_dict, batch_start_idx + idx_global, infos[idx_global], cids)
        # ---------------------------------------------------
        # 4. Profiling stats (main rank only): per-query top-k dumps.
        # ---------------------------------------------------
        if profile_enabled and is_main:
            K = min(topk_emb, cand_mid_t.size(0))
            # Move to float32 + CPU for JSON serialization.
            q_mid_cpu = reps_mid_t.detach().float().cpu()  # [B, D]
            q_last_cpu = (
                reps_last_full.detach().float().cpu()
                if reps_last_full is not None
                else None
            )  # [B, D]
            cand_mid_cpu = cand_mid_t.detach().float().cpu()  # [Nc, D]
            cand_last_cpu = cand_last_t.detach().float().cpu()  # [Nc, D]
            # mid-query vs mid-candidate scores.
            scores_mid_full = q_mid_cpu @ cand_mid_cpu.T  # [B, Nc]
            topk_mid_vals, topk_mid_inds = torch.topk(
                scores_mid_full, k=K, dim=1
            )
            # last-query vs last-candidate scores.
            if q_last_cpu is not None:
                scores_last_full = q_last_cpu @ cand_last_cpu.T  # [B, Nc]
                topk_last_vals, topk_last_inds = torch.topk(
                    scores_last_full, k=K, dim=1
                )
            else:
                topk_last_vals = None
                topk_last_inds = None
            for i in range(B):
                qid = batch_start_idx + i
                rec = {
                    "qid": int(qid),
                    "early_exit": bool(exit_mask[i]),
                }
                if p_need_last_batch is not None:
                    rec["p_need_last"] = float(p_need_last_batch[i].item())
                # mid2mid TopK
                mid_inds = topk_mid_inds[i].tolist()
                mid_scores = topk_mid_vals[i].tolist()
                rec["mid_topk_scores"] = mid_scores
                rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds]
                # last2last TopK
                if topk_last_inds is not None:
                    last_inds = topk_last_inds[i].tolist()
                    last_scores = topk_last_vals[i].tolist()
                    rec["last_topk_scores"] = last_scores
                    rec["last_topk_cand_ids"] = [
                        cand_ids[j] for j in last_inds
                    ]
                else:
                    rec["last_topk_scores"] = None
                    rec["last_topk_cand_ids"] = None
                analysis_records.append(rec)
        if prof is not None:
            prof.step()
    if prof is not None:
        # Trigger on_trace_ready if the schedule hasn't naturally advanced yet.
        prof.step()
        prof.__exit__(None, None, None)
    # 5. Collect results in global-index order and persist scores.
    for idx in sorted(results_dict.keys()):
        pred_dicts.append(results_dict[idx])
    # NOTE(review): divides by stats['total'] — assumes a non-empty loader.
    print_master(f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} ({stats['exit']/stats['total']:.2%})")
    metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
    score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)
    if is_main:
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), "w") as f:
            json.dump(score, f, indent=4)
        # Persist profiling artifacts (timings + per-query embedding dumps).
        if profile_enabled:
            prof_dir = os.path.join(out_dir, "profiling")
            os.makedirs(prof_dir, exist_ok=True)
            mid_avg = timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"])
            tail_avg = timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"])
            timing_out = {
                "mid_time_sum": timing_stats["mid_time_sum"],
                "mid_num": timing_stats["mid_num"],
                "tail_time_sum": timing_stats["tail_time_sum"],
                "tail_num": timing_stats["tail_num"],
                "avg_mid_time_per_query_sec": mid_avg,
                "avg_tail_time_per_cont_query_sec": tail_avg,
                "num_exit": int(stats["exit"]),
                "num_total": int(stats["total"]),
            }
            with open(os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w") as f:
                json.dump(timing_out, f, indent=2)
            embed_path = os.path.join(prof_dir, f"{dataset_name}_embeds.jsonl")
            with open(embed_path, "w") as f:
                for rec in analysis_records:
                    f.write(json.dumps(rec) + "\n")
            print_master(f"[PROFILE] Saved timing to {prof_dir}, details to {embed_path}")
    elapsed = time.time() - start_time
    return score, elapsed
def _record_result(results_dict, global_idx, info, cids):
label = info.get("label_name") or info.get("label") or info.get("rel_docids")
if not isinstance(label, list): label = [label]
rel_scores = info.get("rel_scores", None)
results_dict[global_idx] = {
"prediction": cids, "label": label, "rel_scores": rel_scores,
}
# ===========================
# Helper Functions (Pre-Computation)
# ===========================
# ... (encode_candidates_both_layers 保持不变) ...
def encode_candidates_both_layers(model: MMEBModel, loader: DataLoader, training_args: TrainingArguments, model_args: ModelArguments, full_dataset: Dataset, mid_layer: int):
    """Encode every candidate in one full forward pass, pooling BOTH the
    mid-layer (``hidden_states[mid_layer]``) and the final-layer hidden states.

    AOP pruning is temporarily disabled for the candidate side (only queries
    are pruned at eval time) and restored afterwards. Embeddings are returned
    as float32 numpy arrays so the caller can pickle them.

    Returns:
        (mid_embs, last_embs, cand_ids): two ``[N, D]`` float32 arrays and the
        matching list of candidate names; empty arrays/list when the loader
        yields nothing.

    NOTE(review): ``full_dataset`` and ``model_args`` are accepted but unused
    in this body — presumably kept for signature parity with other encoders.
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    model.eval()
    all_mid, all_last, all_ids = [], [], []
    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0):
            inputs = batch_to_device(inputs, training_args.device)
            # Temporarily switch AOP off for the candidate pass.
            aop_cfg = getattr(model.encoder, "aop_prune_config", None)
            _orig = None
            if isinstance(aop_cfg, dict):
                _orig = aop_cfg.get("enabled", False)
                aop_cfg["enabled"] = False
                setattr(model.encoder, "aop_prune_config", aop_cfg)
            with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                out = model.encoder(**inputs, return_dict=True, output_hidden_states=True, stop_at_layer=None)
            # Restore the original AOP enabled flag.
            if isinstance(aop_cfg, dict) and _orig is not None: aop_cfg["enabled"] = _orig
            mid_hs = out.hidden_states[mid_layer]
            last_hs = out.hidden_states[-1]
            # Two candidate masks: the one the encoder emitted after any
            # VPOOL/AOP length changes ("post"), and the input mask ("pre").
            post_am = getattr(out, "attention_mask", None)  # post (after VPOOL/AOP)
            pre_am = inputs.get("attention_mask", None)  # pre
            def pick_mask(h, pre_am, post_am):
                # Choose whichever mask matches this layer's sequence length;
                # fall back to an all-ones mask if neither fits.
                if post_am is not None and post_am.size(1) == h.size(1):
                    return post_am
                if pre_am is not None and pre_am.size(1) == h.size(1):
                    return pre_am
                return torch.ones(h.size(0), h.size(1), dtype=torch.long, device=h.device)
            am_mid = pick_mask(mid_hs, pre_am, post_am)
            am_last = pick_mask(last_hs, pre_am, post_am)
            reps_mid = model._pooling(mid_hs, am_mid)
            reps_last = model._pooling(last_hs, am_last)
            all_mid.append(reps_mid.detach().float().cpu())
            all_last.append(reps_last.detach().float().cpu())
            all_ids.extend([info["cand_name"] for info in dataset_info])
    if not all_mid: return np.array([]), np.array([]), []
    return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids
# ===========================
# Main
# ===========================
def main():
if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
local_rank = dist.get_rank() if dist.is_initialized() else 0
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
ee_cfg = get_env_ee_config()
hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
if not getattr(model_args, "model_backbone", None):
model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
setattr(model_args, 'model_backbone', model_backbone)
processor = load_processor(model_args, data_args)
# 1. 加载模型
model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
model.eval()
model = model.to(training_args.device, dtype=torch.bfloat16)
# 2. 注入 AOP 配置 (全局生效)
aop_cfg = get_env_aop_config()
if aop_cfg["enabled"]:
setattr(model.encoder, "aop_prune_config", aop_cfg)
_set_attr_on_base(model.encoder, "aop_prune_config", aop_cfg)
print_master(f"[AOP] Enabled: {aop_cfg['apply_to']} @ layer {aop_cfg['layer_idx']}")
# 3. 注入 VPOOL 配置 (全局生效)
vpool_cfg = get_env_vpool_config()
if vpool_cfg["enabled"]:
setattr(model.encoder, "vision_pooling_config", vpool_cfg)
_set_attr_on_base(model.encoder, "vision_pooling_config", vpool_cfg)
print_master(f"[VPOOL] Enabled: {vpool_cfg['apply_to']} @ layer {vpool_cfg['layer_idx']}")
# [删除] 这一行是旧代码遗留,新模型不需要它
# model.set_inference_layers(qry_layers=None, tgt_layers=None)
# 4. 加载分类器
classifier = None
if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]:
classifier_path = ee_cfg['classifier_path']
print_master(f"[EE] Loading Classifier from {classifier_path}...")
# 获取 Embedding 维度 (AOP 不改变隐层维度,只改变长度)
backbone_hidden_size = model.encoder.config.hidden_size
print_master(f"[EE] Backbone Hidden Size: {backbone_hidden_size}")
classifier = EarlyExitClassifier(
input_dim=27,
hidden_dim=128,
embedding_dim=backbone_hidden_size
)
state_dict = None
if os.path.isdir(classifier_path):
safetensors_file = os.path.join(classifier_path, "model.safetensors")
if os.path.exists(safetensors_file):
from safetensors.torch import load_file
state_dict = load_file(safetensors_file)
else:
pt_file_bin = os.path.join(classifier_path, "pytorch_model.bin")
if os.path.exists(pt_file_bin):
state_dict = torch.load(pt_file_bin, map_location=training_args.device)
else:
layer_idx = ee_cfg.get('layer', 12)
pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt")
if os.path.exists(pt_file):
state_dict = torch.load(pt_file, map_location=training_args.device)
elif os.path.isfile(classifier_path):
state_dict = torch.load(classifier_path, map_location=training_args.device)
if state_dict is not None:
classifier.load_state_dict(state_dict)
classifier.to(training_args.device) # FP32
classifier.eval()
print_master(f"[EE] Classifier loaded successfully.")
else:
raise FileNotFoundError(f"Could not load classifier weights from {classifier_path}")
# 5. 开始评测循环
with open(data_args.dataset_config, 'r') as yaml_file: dataset_configs = yaml.safe_load(yaml_file)
for dataset_name, task_config in dataset_configs.items():
if dist.is_initialized(): dist.barrier()
print_master(f"\n--- Evaluating {dataset_name} ---")
base_tau = float(ee_cfg["threshold"])
ds_tau = PER_DATASET_THRESHOLDS.get(dataset_name, base_tau)
ee_cfg_ds = dict(ee_cfg); ee_cfg_ds["threshold"] = float(ds_tau)
if data_args.data_basedir:
for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
if task_config.get(key): task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
mid_layer = int(ee_cfg_ds["layer"])
cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}")
cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast")
if (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)):
eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
eval_cand_loader = DataLoader(full_eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers)
# 【关键】Candidate 编码也必须经过 AOP/VPOOL
cand_mid, cand_last, cand_ids = encode_candidates_both_layers(model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer)
if local_rank == 0:
with open(cand_mid_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f)
with open(cand_last_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f)
if dist.is_initialized(): dist.barrier()
if local_rank == 0:
with open(cand_mid_path, "rb") as f: cand_mid_dict = pickle.load(f)
with open(cand_last_path, "rb") as f: cand_last_dict = pickle.load(f)
rank_global = task_config.get("eval_type", "global") == "global"
run_early_exit_queries(
model, classifier, processor, model_args, data_args, training_args,
full_eval_qry_dataset, cand_mid_dict, cand_last_dict,
ee_cfg_ds, dataset_name, data_args.encode_output_path,
global_ranking=rank_global,
)
if dist.is_initialized(): dist.barrier()
if __name__ == '__main__':
main()