# code_SAS_VLM2Vec / eval_test_time_with_classifier_AOP_attn_pooling.py
# Uploaded by MgGladys via the upload-large-folder tool (commit 2a40e7a, verified).
import datetime
import logging
import json
import random
import time
import numpy as np
import os
import pickle
import sys
import hashlib
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
import transformers
import math
from contextlib import contextmanager
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
from datasets import Dataset, concatenate_datasets
from datasets.distributed import split_dataset_by_node
from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.eval_collator import MultimodalEvalDataCollator
from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
from src.eval_utils.metrics import RankingMetrics
from src.model.model_multilayer_AOP_infer_attn_pooling import MMEBModel
from src.model.processor import get_backbone_name, load_processor, COLPALI
from src.utils import batch_to_device, print_rank, print_master
# ==========================================
# [V5 change] Import the V5 version of the early-exit classifier
# ==========================================
from src.classifier_utils_V5 import EarlyExitClassifier
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)
# ===========================
# Per-dataset thresholds (example values; update from the V5 analysis results)
# ===========================
# Early-exit decision threshold per dataset.
# NOTE(review): all values are identical placeholders (0.45); presumably these
# are tau cutoffs for the classifier's p_need_last score — confirm against the
# V5 analysis before relying on per-dataset tuning.
PER_DATASET_THRESHOLDS = {
    "CIRR": 0.45,
    "EDIS": 0.45,
    "FashionIQ": 0.45,
    "MSCOCO_i2t": 0.45,
    "MSCOCO_t2i": 0.45,
    "NIGHTS": 0.45,
    "OVEN": 0.45,
    "VisDial": 0.45,
    "VisualNews_i2t": 0.45,
    "VisualNews_t2i": 0.45,
    "WebQA": 0.45,
    "Wiki-SS-NQ": 0.45,
}
# ... (helper functions unchanged) ...
def _parse_bool(v: str, default=False):
    """Interpret an environment-variable string as a boolean.

    Returns *default* when *v* is None; otherwise True iff the trimmed,
    lower-cased value is one of the usual truthy spellings.
    """
    if v is None:
        return default
    return v.strip().lower() in {"1", "true", "yes", "y", "t", "on"}
def _parse_float(v: str, default=None):
    """Parse *v* as a float; return *default* when *v* is None or malformed.

    Uses a narrow except clause: the original bare ``except:`` would also have
    swallowed KeyboardInterrupt/SystemExit.
    """
    if v is None:
        return default
    try:
        return float(v)
    except (TypeError, ValueError):
        return default
def _parse_int(v: str, default=None):
    """Parse *v* as an int; return *default* when *v* is None or malformed.

    Uses a narrow except clause: the original bare ``except:`` would also have
    swallowed KeyboardInterrupt/SystemExit.
    """
    if v is None:
        return default
    try:
        return int(v)
    except (TypeError, ValueError):
        return default
def get_env_aop_config():
    """Assemble the AOP (attention-output pruning) config from environment
    variables.

    Returns the dict that is later attached to the encoder as
    ``aop_prune_config``. When AOP_ENABLED is set but AOP_LAYER is missing,
    AOP is silently disabled, since pruning has no target layer.
    """
    env = os.environ.get
    layer_idx = _parse_int(env("AOP_LAYER"), None)
    cfg = {
        "enabled": _parse_bool(env("AOP_ENABLED"), False),
        "apply_to": env("AOP_APPLY", "qry").strip().lower(),
        "layer_idx": layer_idx,
        "mode": env("AOP_MODE", "delta").strip().lower(),
        "delta": _parse_float(env("AOP_DELTA"), 0.10),
        "K_hat": _parse_float(env("AOP_KHAT"), 1.0),
        "keep_ratio": _parse_float(env("AOP_KEEP_RATIO"), 1.0),
        "min_keep": _parse_int(env("AOP_MIN_KEEP"), 64),
        "use_bias": _parse_bool(env("AOP_USE_BIAS"), True),
        "prune_vision": _parse_bool(env("AOP_PRUNE_VISION"), True),
        "prune_text": _parse_bool(env("AOP_PRUNE_TEXT"), False),
        "keep_ratio_vision": _parse_float(env("AOP_KEEP_RATIO_VISION"), None),
        "keep_ratio_text": _parse_float(env("AOP_KEEP_RATIO_TEXT"), None),
        # Per-modality min_keep overrides.
        "min_keep_vision": _parse_int(env("AOP_MIN_KEEP_VISION"), None),
        "min_keep_text": _parse_int(env("AOP_MIN_KEEP_TEXT"), 32),
        # Text-protection parameters.
        "protect_text_last": _parse_int(env("AOP_PROTECT_TEXT_LAST"), 16),
        "protect_special": _parse_bool(env("AOP_PROTECT_SPECIAL"), True),
        "selection": env("AOP_SELECTION", "aop").strip().lower(),
        "attn_agg": env("AOP_ATTENTION_AGG", "mean").strip().lower(),
        "random_seed": _parse_int(env("AOP_RANDOM_SEED"), None),
        "margin_mid": None,
    }
    # Pruning cannot run without a target layer; force-disable in that case.
    if cfg["enabled"] and layer_idx is None:
        cfg["enabled"] = False
    return cfg
# [New] VPOOL configuration parsing
def get_env_vpool_config():
    """Read the VPOOL (vision-token pooling) configuration from environment
    variables.

    Returns the dict later attached to the encoder as
    ``vision_pooling_config``. All parsing is best-effort: missing or
    malformed values fall back to the given defaults.
    """
    def _b(k, d=False):
        # bool: accept the usual truthy spellings, case-insensitively
        v = os.environ.get(k)
        if v is None:
            return d
        return str(v).strip().lower() in {"1", "true", "yes", "y", "on"}
    def _i(k, d=None):
        # int: None/garbage -> default. Narrow except: the original bare
        # ``except:`` would also have swallowed KeyboardInterrupt/SystemExit.
        try:
            return int(os.environ.get(k, d))
        except (TypeError, ValueError):
            return d
    def _s(k, d=None):
        v = os.environ.get(k, d)
        return None if v is None else str(v).strip().lower()
    def _f(k, d=None):
        try:
            return float(os.environ.get(k, d))
        except (TypeError, ValueError):
            return d
    return {
        "enabled": _b("VPOOL_ENABLED", False),
        "apply_to": _s("VPOOL_APPLY", "both"),  # qry|tgt|both
        "layer_idx": _i("VPOOL_LAYER", 1),
        "kernel": _i("VPOOL_KERNEL", 2),
        # Stride defaults to the kernel size (note: an explicit stride of 0
        # is falsy and therefore also falls back to the kernel).
        "stride": _i("VPOOL_STRIDE", None) or _i("VPOOL_KERNEL", 2),
        "method": _s("VPOOL_METHOD", "avg"),
        "protect_cls": _b("VPOOL_PROTECT_CLS", True),
        "vision_only": _b("VPOOL_ONLY_VISION", True),
        "monitor": _b("VPOOL_MONITOR", False),
        # NEW: temperature for attention-weighted pooling
        "attn_tau": _f("VPOOL_ATTN_TAU", 1.0),
    }
def get_env_ee_config():
    """Read the early-exit (EE) configuration from environment variables."""
    truthy = {"1", "true", "yes", "on", "y", "t"}
    raw_enabled = os.environ.get("EE_ENABLED", "0").strip().lower()
    return dict(
        enabled=raw_enabled in truthy,
        layer=int(os.environ.get("EE_LAYER", "12")),
        method=os.environ.get("EE_METHOD", "classifier").strip().lower(),
        threshold=float(os.environ.get("EE_THRESHOLD", "0.8")),
        classifier_path=os.environ.get("EE_CLASSIFIER_PATH", ""),
    )
# [New] Helper: inject a config onto the base model (fixes a NameError)
def _set_attr_on_base(peft_or_base, name, value):
    """Best-effort: set attribute *name* to *value* on the underlying base model.

    Resolves PEFT wrappers (via get_base_model()) and plain ``.model``
    wrappers, and also mirrors the attribute onto ``base.model`` for the
    Qwen2-VL-style nesting (ForConditionalGeneration -> inner Model).
    Failures only print a warning; they never abort the caller.
    """
    try:
        base = None
        # PEFT models expose get_base_model().
        if hasattr(peft_or_base, "get_base_model"):
            base = peft_or_base.get_base_model()
        # Not PEFT (or get_base_model returned None): try the .model attribute.
        if base is None and hasattr(peft_or_base, "model"):
            base = peft_or_base.model
        if base is not None:
            setattr(base, name, value)
            # Wrapped architectures keep the real transformer one level deeper.
            if hasattr(base, "model"):
                setattr(base.model, name, value)
    except Exception as e:
        # Warn only; config injection must not break inference.
        print(f"[inject-config] warn: set {name} on base failed: {e}")
def _maybe_load_vpool_attn_pooler_into_encoder(model: MMEBModel, ckpt_dir: str):
    """
    At inference time, load the vpool_attn_pooler weights into the encoder
    (fixes the base model's randomly initialized pooler).
    Handles both state_dict key layouts:
      - pooler.state_dict() saved directly: 'mlp.0.weight'...
      - peft modules_to_save wrapper: 'modules_to_save.default.mlp.0.weight'...
    Returns True only when a weight file was found and loaded with no
    missing keys; False (after logging) in every other case.
    """
    if not ckpt_dir or (not os.path.isdir(ckpt_dir)):
        print_master(f"[VPOOL-ATTN] ckpt_dir not found: {ckpt_dir}")
        return False
    # Locate the pooler: prefer encoder.model.vpool_attn_pooler, fall back
    # to an attribute directly on the encoder.
    enc = model.encoder
    inner = getattr(enc, "model", None)
    pooler = getattr(inner, "vpool_attn_pooler", None) if inner is not None else None
    if pooler is None and hasattr(enc, "vpool_attn_pooler"):
        pooler = enc.vpool_attn_pooler
    if pooler is None:
        print_master("[VPOOL-ATTN] pooler not found in encoder, skip loading.")
        return False
    # Prefer the safetensors file; fall back to a .pt checkpoint.
    path = os.path.join(ckpt_dir, "vpool_attn_pooler.safetensors")
    if not os.path.isfile(path):
        path = os.path.join(ckpt_dir, "vpool_attn_pooler.pt")
    if not os.path.isfile(path):
        print_master(f"[VPOOL-ATTN] no pooler weight file found under {ckpt_dir}")
        return False
    # Pick the real module to load into (when the pooler is a PEFT wrapper).
    # NOTE(review): peft stores modules_to_save as an nn.ModuleDict, which is
    # NOT a dict subclass — this isinstance check may never match; confirm.
    target = pooler
    if hasattr(pooler, "modules_to_save") and isinstance(pooler.modules_to_save, dict) and "default" in pooler.modules_to_save:
        target = pooler.modules_to_save["default"]
    elif hasattr(pooler, "original_module"):
        target = pooler.original_module
    try:
        if path.endswith(".safetensors"):
            from safetensors.torch import load_file
            sd_raw = load_file(path, device="cpu")
        else:
            sd_raw = torch.load(path, map_location="cpu")
        # Strip PEFT wrapper key prefixes so keys match the bare pooler.
        def _strip_prefix(sd: dict, prefix: str) -> dict:
            return {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}
        sd = sd_raw
        if any(k.startswith("modules_to_save.default.") for k in sd.keys()):
            sd = _strip_prefix(sd, "modules_to_save.default.")
        elif any(k.startswith("original_module.") for k in sd.keys()):
            sd = _strip_prefix(sd, "original_module.")
        missing, unexpected = target.load_state_dict(sd, strict=False)
        # Log one weight statistic so a successful (non-random) load can be
        # confirmed from the console output.
        try:
            mean_abs = float(target.mlp[3].weight.detach().abs().mean().cpu().item())
        except Exception:
            mean_abs = None
        print_master(f"[VPOOL-ATTN] loaded pooler from {path}. missing={missing}, unexpected={unexpected}, mlp3_mean_abs={mean_abs}")
        return (len(missing) == 0)
    except Exception as e:
        print_master(f"[VPOOL-ATTN] load pooler failed: {e}")
        return False
def _norm_side(side: str) -> str:
    """Canonicalize a side label to 'qry' or 'tgt'.

    'cand', 'tgt' and 'target' all map to 'tgt'; everything else is 'qry'.
    """
    normalized = str(side).strip().lower()
    return "tgt" if normalized in {"cand", "tgt", "target"} else "qry"
def _apply_to_match(apply_to: str, side: str) -> bool:
    """Return True when config target *apply_to* covers *side*.

    apply_to: qry|tgt|cand|both ('cand' is an alias for 'tgt')
    side: qry|tgt
    """
    target = str(apply_to).strip().lower()
    if target == "cand":
        target = "tgt"
    return True if target == "both" else target == side
@contextmanager
def _temp_cfg_enabled(model: MMEBModel, cfg_name: str, enable: bool):
    """
    Temporarily override the 'enabled' flag of encoder.<cfg_name>, mirroring
    the patched dict onto the base model; restore the original on exit.
    cfg_name: "aop_prune_config" or "vision_pooling_config"

    The flag is only forced on when the stored config was already enabled —
    this context can disable a pass, never enable a disabled config.
    """
    enc = model.encoder
    original = getattr(enc, cfg_name, None)
    if isinstance(original, dict):
        patched = {**original, "enabled": bool(enable and original.get("enabled", False))}
        setattr(enc, cfg_name, patched)
        _set_attr_on_base(enc, cfg_name, patched)
    try:
        yield
    finally:
        # Restore unconditionally (also writes None back when no config existed).
        setattr(enc, cfg_name, original)
        _set_attr_on_base(enc, cfg_name, original)
def _pool_from_outputs(model: MMEBModel, hs: torch.Tensor, out, fallback_mask: torch.Tensor | None, lens_index: int | None = None):
    """
    Unified pooling: prefer out.attn_lens when the model supports length-based
    pooling (_pooling_from_lens); otherwise fall back to mask-based pooling.
    lens_index=None -> use the last entry (matches last_hidden_state).
    """
    attn_lens = getattr(out, "attn_lens", None)
    lens_usable = (
        hasattr(model, "_pooling_from_lens")
        and isinstance(attn_lens, (list, tuple))
        and len(attn_lens) > 0
    )
    if lens_usable:
        # None and out-of-range indices both fall back to the last entry.
        if lens_index is not None and 0 <= lens_index < len(attn_lens):
            lens_1d = attn_lens[lens_index]
        else:
            lens_1d = attn_lens[-1]
        if lens_1d is not None:
            return model._pooling_from_lens(hs, lens_1d)
    # Fallback: mask pooling (prefer out.attention_mask, else the input mask).
    mask = getattr(out, "attention_mask", None)
    return model._pooling(hs, mask if mask is not None else fallback_mask)
def _cfg_cache_tag(aop_cfg: dict | None, vpool_cfg: dict | None, mid_layer: int) -> str:
"""
生成 candidate 缓存文件名的短 hash,避免切换 AOP/VPOOL 配置后误读旧缓存。
"""
payload = {
"mid_layer": int(mid_layer),
"aop": aop_cfg if isinstance(aop_cfg, dict) else None,
"vpool": vpool_cfg if isinstance(vpool_cfg, dict) else None,
}
s = json.dumps(payload, sort_keys=True, ensure_ascii=False)
return hashlib.md5(s.encode("utf-8")).hexdigest()[:10]
def pad_dataset_to_divisible(dataset, world_size):
    """Pad *dataset* so its length is a multiple of *world_size*.

    Padding rows are taken from the dataset itself (round-robin from the
    start). Returns (possibly padded dataset, resulting length).
    """
    n = len(dataset)
    remainder = n % world_size
    if remainder == 0:
        return dataset, n
    extra = world_size - remainder
    # Reuse existing rows cyclically as padding.
    pad = dataset.select([i % n for i in range(extra)])
    return concatenate_datasets([dataset, pad]), n + extra
# ===========================
# Core Inference Function
# ===========================
def run_early_exit_queries(
    model: MMEBModel,
    classifier: EarlyExitClassifier,
    processor,
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
    qry_dataset: Dataset,
    cand_mid_dict: dict,
    cand_last_dict: dict,
    ee_cfg: dict,
    dataset_name: str,
    out_dir: str,
    global_ranking: bool = True,
):
    """Run queries with early exit: encode each query up to EE_LAYER, let the
    classifier decide whether the mid-layer embedding suffices; exited queries
    are ranked against mid-layer candidate embeddings, the rest resume the
    forward pass to the last layer and are ranked against last-layer
    candidate embeddings. Returns (score dict from RankingMetrics, elapsed seconds).
    """
    device = training_args.device
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    is_main = (not dist.is_initialized()) or (local_rank == 0)
    # Profiling configuration
    profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in {"1", "true", "yes", "on", "y", "t"}
    topk_emb = int(os.environ.get("EE_TOPK_EMB", "5"))
    timing_stats = {
        "mid_time_sum": 0.0, "mid_num": 0, "tail_time_sum": 0.0, "tail_num": 0,
    }
    analysis_records = [] if (profile_enabled and is_main) else None
    # 1. Prepare candidates (both mid-layer and last-layer embedding banks)
    cand_ids = list(cand_mid_dict.keys())
    cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
    cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32)
    cand_mid_np = cand_mid  # [Nc, D], float32
    cand_last_np = cand_last  # [Nc, D], float32
    cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)
    cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16)
    collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
    loader = DataLoader(
        qry_dataset, batch_size=training_args.per_device_eval_batch_size,
        collate_fn=collator, num_workers=training_args.dataloader_num_workers
    )
    pred_dicts = []
    stats = {"exit": 0, "total": 0}
    threshold = float(ee_cfg["threshold"])
    method = ee_cfg["method"]
    target_layer_idx = int(ee_cfg["layer"])
    results_dict = {}
    global_sample_idx = 0
    use_local = (not global_ranking)
    if use_local:
        print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)")
        cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
    else:
        print_master(f"[INFO] Using GLOBAL ranking (full library)")
    # AOP/VPOOL config (used for per-side gating)
    aop_cfg = getattr(model.encoder, "aop_prune_config", None)
    vpool_cfg = getattr(model.encoder, "vision_pooling_config", None)
    _orig_enabled = None
    side_enable = True
    if isinstance(aop_cfg, dict) and aop_cfg:
        _orig_enabled = aop_cfg.get("enabled", False)
        apply_to = aop_cfg.get("apply_to", "qry")
        side_enable = (apply_to == "both") or (apply_to == "qry")
    model.eval()
    if classifier:
        classifier.eval()
        # [V5 key] The classifier runs in FP32 (default); do NOT cast to BF16
        # (FP32 is required for its Log/LayerNorm components).
        classifier.to(device)
    start_time = time.time()
    for inputs, infos in tqdm(
        loader, desc=f"[EE+AOP] {dataset_name} (tau={threshold})", disable=local_rank > 0,
    ):
        inputs = batch_to_device(inputs, device)
        B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1
        batch_start_idx = global_sample_idx
        global_sample_idx += B
        stats["total"] += B
        # ---------------------------------------------------
        # 1. First half: run to the mid (EE) layer
        # ---------------------------------------------------
        side = "qry"
        side_ok_aop = isinstance(aop_cfg, dict) and _apply_to_match(aop_cfg.get("apply_to", "both"), side)
        side_ok_vp = isinstance(vpool_cfg, dict) and _apply_to_match(vpool_cfg.get("apply_to", "both"), side)
        # AOP in the mid pass: pruning only fires here when AOP_LAYER <= EE_LAYER
        aop_layer = aop_cfg.get("layer_idx", None) if isinstance(aop_cfg, dict) else None
        aop_on_mid = bool(isinstance(aop_cfg, dict) and aop_cfg.get("enabled", False) and side_ok_aop and (aop_layer is not None) and (aop_layer <= target_layer_idx))
        if profile_enabled:
            torch.cuda.synchronize()
            t0_mid = time.perf_counter()
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            with _temp_cfg_enabled(model, "aop_prune_config", aop_on_mid):
                with _temp_cfg_enabled(model, "vision_pooling_config", bool(isinstance(vpool_cfg, dict) and vpool_cfg.get("enabled", False) and side_ok_vp)):
                    # NEW: some encoders do not accept the compute_lm_head kwarg
                    try:
                        out_mid = model.encoder(
                            **inputs,
                            return_dict=True,
                            output_hidden_states=False,
                            stop_at_layer=target_layer_idx,
                            compute_lm_head=False,
                            return_intermediate_state=True,
                        )
                    except TypeError:
                        out_mid = model.encoder(
                            **inputs,
                            return_dict=True,
                            output_hidden_states=False,
                            stop_at_layer=target_layer_idx,
                            return_intermediate_state=True,
                        )
        # Post-prune diagnostics extracted from the mid forward:
        post_attn_mid = getattr(out_mid, "attention_mask", None)  # [B, L'] 2D mask after pruning
        img_mask_mid = getattr(out_mid, "image_token_bool_masks", None)  # [B, L'] vision-token mask (pre-prune in current Qwen2.5; emit cur_vision_mask in Qwen2_5_VLModel for a post-prune mask)
        txt_mask_mid = getattr(out_mid, "text_token_bool_masks", None)  # [B, L'] text-token mask (already post-prune)
        # Compute statistics as needed, for example:
        if post_attn_mid is not None:
            # total number of valid tokens
            keep_counts = post_attn_mid.to(torch.bool).sum(dim=1)  # [B]
            # number of text tokens kept
            if txt_mask_mid is not None:
                text_keep = (post_attn_mid.to(torch.bool) & txt_mask_mid.to(torch.bool)).sum(dim=1)
            # number of image tokens kept
            if img_mask_mid is not None:
                vis_keep = (post_attn_mid.to(torch.bool) & img_mask_mid.to(torch.bool)).sum(dim=1)
        if profile_enabled:
            torch.cuda.synchronize()
            t1_mid = time.perf_counter()
            timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B
            timing_stats["mid_num"] += B
        hs_mid = getattr(out_mid, "last_hidden_state", None)
        if hs_mid is None:
            # Compat: a few implementations may not populate last_hidden_state
            hs_mid = out_mid[0]
        reps_mid_t = _pool_from_outputs(model, hs_mid, out_mid, inputs.get("attention_mask", None)).detach().to(dtype=torch.bfloat16)
        # Profiling only: a full (no early exit) last-layer forward for reference
        reps_last_full = None
        if profile_enabled:
            with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                try:
                    out_full = model.encoder(
                        **inputs, return_dict=True, output_hidden_states=False,
                        stop_at_layer=None, compute_lm_head=False,
                        return_intermediate_state=False
                    )
                except TypeError:
                    out_full = model.encoder(
                        **inputs, return_dict=True, output_hidden_states=False,
                        stop_at_layer=None,
                        return_intermediate_state=False
                    )
            hs_last_full = getattr(out_full, "last_hidden_state", None)
            if hs_last_full is None:
                hs_last_full = out_full[0]
            reps_last_full = _pool_from_outputs(model, hs_last_full, out_full, inputs.get("attention_mask", None)).detach().to(dtype=torch.bfloat16)
        # ---------------------------------------------------
        # 2. Feature engineering + gating
        # ---------------------------------------------------
        exit_mask = np.zeros(B, dtype=bool)
        p_need_last_batch = None
        if method == "classifier" and classifier is not None:
            with torch.no_grad():
                cos_mid = reps_mid_t @ cand_mid_t.T  # [B, N] (BF16)
                backbone_ptr = model.module if hasattr(model, "module") else model
                temp = getattr(backbone_ptr, "temperature", 0.02)
                scores_mid = cos_mid / temp
                probs_mid = torch.softmax(scores_mid, dim=1)  # [B, N]
                # --- 27-dim feature computation (BF16) ---
                diag_cos = cos_mid.max(dim=1)[0]
                sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True)
                s2_cos = sorted_cos[:, 1] if sorted_cos.size(1) > 1 else sorted_cos[:, 0]
                margin_mid = diag_cos - s2_cos
                margin_mean = margin_mid.mean()
                margin_std = margin_mid.std(unbiased=False) + 1e-6
                z_margin_mid = (margin_mid - margin_mean) / margin_std
                margin_median = margin_mid.median()
                mad = (margin_mid - margin_median).abs().median() + 1e-6
                mad_margin_mid = (margin_mid - margin_median) / mad
                p1_mid = probs_mid.max(dim=1)[0]
                H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1)
                gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1)
                TOPK = min(16, probs_mid.size(1))
                topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1)
                topk_mean = topk_vals.mean(dim=1)
                topk_std = topk_vals.std(dim=1, unbiased=False)
                topk_cv = topk_std / (topk_mean + 1e-6)
                centered = topk_vals - topk_mean.unsqueeze(1)
                var = (centered ** 2).mean(dim=1) + 1e-6
                m4 = (centered ** 4).mean(dim=1)
                topk_kurt = m4 / (var ** 2)
                topk_med = topk_vals.median(dim=1).values
                row_mean_cos = cos_mid.mean(dim=1)
                row_med_cos = cos_mid.median(dim=1).values
                s1_over_mean = diag_cos - row_mean_cos
                s1_over_med = diag_cos - row_med_cos
                sorted_probs, _ = torch.sort(probs_mid, dim=1, descending=True)
                p1 = sorted_probs[:, 0]
                p2 = sorted_probs[:, 1] if sorted_probs.size(1) > 1 else sorted_probs[:, 0]
                shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(dim=1)
                shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1)
                # Log-probability slope over the top-R ranks (least-squares fit)
                R = min(10, sorted_probs.size(1))
                x = torch.arange(R, device=device, dtype=sorted_probs.dtype)
                x_centered = x - x.mean()
                denom = (x_centered ** 2).sum()
                y = torch.log(sorted_probs[:, :R] + 1e-6)
                slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom
                row_mean_p = probs_mid.mean(dim=1)
                row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6
                z1 = (p1_mid - row_mean_p) / row_std_p
                center_p = probs_mid - row_mean_p.unsqueeze(1)
                m3 = (center_p ** 3).mean(dim=1)
                skew = m3 / (row_std_p ** 3 + 1e-6)
                s1_over_sk = p1_mid - skew
                TAIL_K = min(10, sorted_probs.size(1))
                tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1)
                HEAD_K = min(5, sorted_probs.size(1))
                head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1)
                # Placeholder mask features (zeros; kept for a fixed 27-dim layout)
                mask_ratio = torch.zeros_like(diag_cos)
                mask_len = torch.zeros_like(diag_cos)
                mask_runs = torch.zeros_like(diag_cos)
                scalar_inputs = torch.stack([
                    diag_cos, s2_cos, margin_mid, z_margin_mid, mad_margin_mid,
                    p1_mid, H_mid, gini_mid,
                    topk_mean, topk_std, topk_cv, topk_kurt, topk_med,
                    s1_over_mean, s1_over_med,
                    p1, p2, shape_H, shape_gini, slope, z1, s1_over_sk,
                    tail_mean, head5_mean,
                    mask_ratio, mask_len, mask_runs
                ], dim=1)
                # Modality indicator: 1 when the query carries pixel values
                modality_idx = torch.zeros(B, dtype=torch.long, device=device)
                if "pixel_values" in inputs and inputs["pixel_values"] is not None:
                    pv = inputs["pixel_values"]
                    if isinstance(pv, list):
                        for i, item in enumerate(pv):
                            if item is not None: modality_idx[i] = 1
                    elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
                        modality_idx.fill_(1)
                # =======================================================
                # [V5 key] Force FP32 before feeding the classifier
                # =======================================================
                scalar_inputs_f32 = scalar_inputs.float()
                qry_emb_f32 = reps_mid_t.float()  # cast to FP32
                logits = classifier(scalar_inputs_f32, modality_idx, qry_emb=qry_emb_f32)
                p_need_last = torch.sigmoid(logits)  # [B,1]
                p_need_last_batch = p_need_last.squeeze(1)  # [B]
                should_exit = p_need_last_batch < threshold
                exit_mask = should_exit.cpu().numpy()
                # Debug print for the first few batches only
                if stats["total"] <= B * 3 and is_main:
                    print_master(
                        f"[EE Debug] Batch {stats['total']//B}: "
                        f"p_need_last mean={p_need_last_batch.mean().item():.4f}, "
                        f"Exit Rate={exit_mask.mean():.2%}, "
                        f"Top3: diag_cos={diag_cos.mean():.3f}, margin={margin_mid.mean():.3f}"
                    )
        stats["exit"] += exit_mask.sum()
        # ---------------------------------------------------
        # 3. Branch execution
        # ---------------------------------------------------
        exit_indices = np.where(exit_mask)[0]
        cont_indices = np.where(~exit_mask)[0]
        # A. Early-exit samples: rank with mid-layer embeddings
        if len(exit_indices) > 0:
            reps_exit = reps_mid_t[exit_indices]
            if use_local:
                for i, idx in enumerate(exit_indices):
                    cand_local = infos[idx].get("cand_names", [])
                    if not cand_local: cids = []
                    else:
                        rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                        rows = [r for r in rows if r >= 0]
                        if len(rows) == 0: cids = []
                        else:
                            cmat_np = cand_mid_np[rows]
                            qry_np = reps_exit[i].detach().float().cpu().numpy()
                            scores_vec = np.dot(qry_np, cmat_np.T)
                            top_k = min(200, len(rows))
                            order_local = np.argsort(-scores_vec)[:top_k]
                            cids = [str(cand_local[o]) for o in order_local]
                    _record_result(results_dict, batch_start_idx + idx, infos[idx], cids)
            else:
                reps_exit_np = reps_exit.detach().float().cpu().numpy()
                scores_full = np.dot(reps_exit_np, cand_mid_np.T)
                top_k = min(200, len(cand_ids))
                topk_inds = np.argsort(-scores_full, axis=1)[:, :top_k]
                for i, idx in enumerate(exit_indices):
                    cids = [cand_ids[k] for k in topk_inds[i]]
                    _record_result(results_dict, batch_start_idx + idx, infos[idx], cids)
        # B. Continued samples: resume the forward pass to the last layer
        if len(cont_indices) > 0:
            # Fetch the intermediate state saved by the mid forward
            interm = getattr(out_mid, "intermediate_state", None)
            hs, am, pos = interm["hidden_states"].detach(), interm["attention_mask"].detach(), interm["position_ids"].detach()
            vm, tm = interm.get("vision_mask", None), interm.get("text_mask", None)
            next_layer = int(interm["next_layer_idx"])
            resume_state_subset = {
                "hidden_states": hs[cont_indices], "attention_mask": am[cont_indices],
                "position_ids": pos[:, cont_indices, :],
                "vision_mask": vm[cont_indices] if vm is not None else None,
                "text_mask": tm[cont_indices] if tm is not None else None,
                "next_layer_idx": next_layer,
            }
            # ====== Adjust the AOP config here ======
            if isinstance(aop_cfg, dict) and aop_cfg:
                aop_resume = dict(aop_cfg)
                aop_layer = aop_resume.get("layer_idx", None)
                # Case 1: AOP_LAYER <= EE_LAYER — the mid pass already pruned,
                # so the tail must not prune again.
                if (aop_layer is not None) and (aop_layer <= target_layer_idx):
                    need_prune_in_tail = False
                else:
                    # Case 2: AOP_LAYER is after EE_LAYER (e.g. 16) — pruning
                    # can only fire during the tail.
                    need_prune_in_tail = bool(_orig_enabled and side_enable)
                aop_resume["enabled"] = need_prune_in_tail
                setattr(model.encoder, "aop_prune_config", aop_resume)
            # ====== End of AOP config adjustment ======
            if profile_enabled:
                torch.cuda.synchronize()
                t0_tail = time.perf_counter()
            with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                side = "qry"
                side_ok_aop = isinstance(aop_cfg, dict) and _apply_to_match(aop_cfg.get("apply_to", "both"), side)
                side_ok_vp = isinstance(vpool_cfg, dict) and _apply_to_match(vpool_cfg.get("apply_to", "both"), side)
                # AOP in the tail: only when apply_to matches AND need_prune_in_tail
                # (short-circuit keeps need_prune_in_tail unevaluated when aop_cfg is not a dict)
                tail_aop_enable = bool(isinstance(aop_cfg, dict) and aop_cfg.get("enabled", False) and side_ok_aop and need_prune_in_tail)
                with _temp_cfg_enabled(model, "aop_prune_config", tail_aop_enable):
                    with _temp_cfg_enabled(model, "vision_pooling_config", bool(isinstance(vpool_cfg, dict) and vpool_cfg.get("enabled", False) and side_ok_vp)):
                        try:
                            out_last = model.encoder(
                                return_dict=True,
                                output_hidden_states=False,
                                stop_at_layer=None,
                                resume_state=resume_state_subset,
                                compute_lm_head=False,
                                return_intermediate_state=False,
                            )
                        except TypeError:
                            out_last = model.encoder(
                                return_dict=True,
                                output_hidden_states=False,
                                stop_at_layer=None,
                                resume_state=resume_state_subset,
                                return_intermediate_state=False,
                            )
            if profile_enabled:
                torch.cuda.synchronize()
                t1_tail = time.perf_counter()
                timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len(cont_indices)
                timing_stats["tail_num"] += len(cont_indices)
            hs_last = getattr(out_last, "last_hidden_state", None)
            if hs_last is None:
                hs_last = out_last[0]
            reps_last_t = _pool_from_outputs(model, hs_last, out_last, resume_state_subset["attention_mask"]).detach().to(dtype=torch.bfloat16)
            if use_local:
                for i, idx_global in enumerate(cont_indices):
                    cand_local = infos[idx_global].get("cand_names", [])
                    if not cand_local: cids = []
                    else:
                        rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                        rows = [r for r in rows if r >= 0]
                        if len(rows) == 0: cids = []
                        else:
                            cmat_last_np = cand_last_np[rows]
                            qry_last_np = reps_last_t[i].detach().float().cpu().numpy()
                            scores_vec = np.dot(qry_last_np, cmat_last_np.T)
                            top_k = min(200, len(rows))
                            order_local = np.argsort(-scores_vec)[:top_k]
                            cids = [str(cand_local[o]) for o in order_local]
                    _record_result(results_dict, batch_start_idx + idx_global, infos[idx_global], cids)
            else:
                reps_last_np = reps_last_t.detach().float().cpu().numpy()
                scores_last = np.dot(reps_last_np, cand_last_np.T)
                top_k = min(200, len(cand_ids))
                topk_inds = np.argsort(-scores_last, axis=1)[:, :top_k]
                for i, idx_global in enumerate(cont_indices):
                    cids = [cand_ids[k] for k in topk_inds[i]]
                    _record_result(results_dict, batch_start_idx + idx_global, infos[idx_global], cids)
        # ---------------------------------------------------
        # 4. Profiling Stats
        # ---------------------------------------------------
        if profile_enabled and is_main:
            K = min(topk_emb, cand_mid_t.size(0))
            # Move to float32 + CPU for serialization
            q_mid_cpu = reps_mid_t.detach().float().cpu()  # [B, D]
            q_last_cpu = (
                reps_last_full.detach().float().cpu()
                if reps_last_full is not None
                else None
            )  # [B, D]
            cand_mid_cpu = cand_mid_t.detach().float().cpu()  # [Nc, D]
            cand_last_cpu = cand_last_t.detach().float().cpu()  # [Nc, D]
            # mid2mid
            scores_mid_full = q_mid_cpu @ cand_mid_cpu.T  # [B, Nc]
            topk_mid_vals, topk_mid_inds = torch.topk(
                scores_mid_full, k=K, dim=1
            )
            # last2last
            if q_last_cpu is not None:
                scores_last_full = q_last_cpu @ cand_last_cpu.T  # [B, Nc]
                topk_last_vals, topk_last_inds = torch.topk(
                    scores_last_full, k=K, dim=1
                )
            else:
                topk_last_vals = None
                topk_last_inds = None
            for i in range(B):
                qid = batch_start_idx + i
                rec = {
                    "qid": int(qid),
                    "early_exit": bool(exit_mask[i]),
                }
                if p_need_last_batch is not None:
                    rec["p_need_last"] = float(p_need_last_batch[i].item())
                # mid2mid TopK
                mid_inds = topk_mid_inds[i].tolist()
                mid_scores = topk_mid_vals[i].tolist()
                rec["mid_topk_scores"] = mid_scores
                rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds]
                # last2last TopK
                if topk_last_inds is not None:
                    last_inds = topk_last_inds[i].tolist()
                    last_scores = topk_last_vals[i].tolist()
                    rec["last_topk_scores"] = last_scores
                    rec["last_topk_cand_ids"] = [
                        cand_ids[j] for j in last_inds
                    ]
                else:
                    rec["last_topk_scores"] = None
                    rec["last_topk_cand_ids"] = None
                analysis_records.append(rec)
    # 5. Collect & save results
    for idx in sorted(results_dict.keys()):
        pred_dicts.append(results_dict[idx])
    print_master(f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} ({stats['exit']/stats['total']:.2%})")
    metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
    score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)
    if is_main:
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), "w") as f:
            json.dump(score, f, indent=4)
        # Profiling save
        if profile_enabled:
            prof_dir = os.path.join(out_dir, "profiling")
            os.makedirs(prof_dir, exist_ok=True)
            mid_avg = timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"])
            tail_avg = timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"])
            timing_out = {
                "mid_time_sum": timing_stats["mid_time_sum"],
                "mid_num": timing_stats["mid_num"],
                "tail_time_sum": timing_stats["tail_time_sum"],
                "tail_num": timing_stats["tail_num"],
                "avg_mid_time_per_query_sec": mid_avg,
                "avg_tail_time_per_cont_query_sec": tail_avg,
                "num_exit": int(stats["exit"]),
                "num_total": int(stats["total"]),
            }
            with open(os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w") as f:
                json.dump(timing_out, f, indent=2)
            embed_path = os.path.join(prof_dir, f"{dataset_name}_embeds.jsonl")
            with open(embed_path, "w") as f:
                for rec in analysis_records:
                    f.write(json.dumps(rec) + "\n")
            print_master(f"[PROFILE] Saved timing to {prof_dir}, details to {embed_path}")
    elapsed = time.time() - start_time
    return score, elapsed
def _record_result(results_dict, global_idx, info, cids):
    """Store the ranked candidate ids plus ground-truth labels for one query.

    The label is looked up under 'label_name', then 'label', then
    'rel_docids' (first truthy value wins) and is always stored as a list.
    """
    label = info.get("label_name") or info.get("label") or info.get("rel_docids")
    labels = label if isinstance(label, list) else [label]
    results_dict[global_idx] = {
        "prediction": cids,
        "label": labels,
        "rel_scores": info.get("rel_scores", None),
    }
# ===========================
# Helper Functions (Pre-Computation)
# ===========================
# ... (encode_candidates_both_layers unchanged) ...
def encode_candidates_both_layers(model: MMEBModel, loader: DataLoader, training_args: TrainingArguments, model_args: ModelArguments, full_dataset: Dataset, mid_layer: int):
    """Encode every candidate twice in one pass: once pooled at *mid_layer*
    and once pooled at the last layer (by resuming from the mid state).

    Returns (mid_embeddings [N, D] np.float32, last_embeddings [N, D]
    np.float32, candidate id list). Empty arrays when the loader is empty.
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    model.eval()
    all_mid, all_last, all_ids = [], [], []
    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0):
            inputs = batch_to_device(inputs, training_args.device)
            # Per-side gating: candidates belong to the tgt side
            side = "tgt"
            aop_cfg = getattr(model.encoder, "aop_prune_config", None)
            vpool_cfg = getattr(model.encoder, "vision_pooling_config", None)
            side_ok_aop = isinstance(aop_cfg, dict) and _apply_to_match(aop_cfg.get("apply_to", "both"), side)
            side_ok_vp = isinstance(vpool_cfg, dict) and _apply_to_match(vpool_cfg.get("apply_to", "both"), side)
            # Mid forward for candidates: pruning only fires in this pass when AOP_LAYER <= mid_layer
            aop_layer = aop_cfg.get("layer_idx", None) if isinstance(aop_cfg, dict) else None
            aop_on_mid = bool(isinstance(aop_cfg, dict) and aop_cfg.get("enabled", False) and side_ok_aop and (aop_layer is not None) and (aop_layer <= mid_layer))
            with _temp_cfg_enabled(model, "aop_prune_config", aop_on_mid):
                with _temp_cfg_enabled(model, "vision_pooling_config", bool(isinstance(vpool_cfg, dict) and vpool_cfg.get("enabled", False) and side_ok_vp)):
                    out_mid = model.encoder(
                        **inputs,
                        return_dict=True,
                        output_hidden_states=False,
                        stop_at_layer=mid_layer,
                        compute_lm_head=False,
                        return_intermediate_state=True,
                    )
            hs_mid = getattr(out_mid, "last_hidden_state", None)
            if hs_mid is None:
                hs_mid = out_mid[0]
            reps_mid = _pool_from_outputs(model, hs_mid, out_mid, inputs.get("attention_mask", None))
            # Resume the forward pass and run to the last layer
            interm = getattr(out_mid, "intermediate_state", None)
            if interm is None:
                raise RuntimeError("out_mid has no intermediate_state; stop_at_layer/resume_state not supported by this backbone.")
            resume_state = {
                "hidden_states": interm["hidden_states"],
                "attention_mask": interm["attention_mask"],
                "position_ids": interm["position_ids"],
                "vision_mask": interm.get("vision_mask", None),
                "text_mask": interm.get("text_mask", None),
                "special_mask": interm.get("special_mask", None),
                "next_layer_idx": int(interm["next_layer_idx"]),
            }
            # Does the tail still need pruning? If AOP_LAYER <= mid_layer the mid
            # pass already pruned -> no tail pruning; otherwise decided by apply_to.
            if (aop_layer is not None) and (aop_layer <= mid_layer):
                aop_on_tail = False
            else:
                aop_on_tail = bool(isinstance(aop_cfg, dict) and aop_cfg.get("enabled", False) and side_ok_aop)
            with _temp_cfg_enabled(model, "aop_prune_config", aop_on_tail):
                with _temp_cfg_enabled(model, "vision_pooling_config", bool(isinstance(vpool_cfg, dict) and vpool_cfg.get("enabled", False) and side_ok_vp)):
                    out_last = model.encoder(
                        return_dict=True,
                        output_hidden_states=False,
                        stop_at_layer=None,
                        resume_state=resume_state,
                        compute_lm_head=False,
                        return_intermediate_state=False,
                    )
            hs_last = getattr(out_last, "last_hidden_state", None)
            if hs_last is None:
                hs_last = out_last[0]
            reps_last = _pool_from_outputs(model, hs_last, out_last, resume_state["attention_mask"])
            all_mid.append(reps_mid.detach().float().cpu())
            all_last.append(reps_last.detach().float().cpu())
            all_ids.extend([info["cand_name"] for info in dataset_info])
    if not all_mid: return np.array([]), np.array([]), []
    return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids
# ===========================
# Main
# ===========================
def _resolve_classifier_state_dict(classifier_path, layer_idx, map_location):
    """Locate and load the early-exit classifier weights.

    Search order when *classifier_path* is a directory:
      1. model.safetensors
      2. pytorch_model.bin
      3. early_exit_classifier_layer_{layer_idx}.pt
    A plain file path is loaded directly with torch.load.

    Returns the state dict, or None if no weight file was found.
    """
    if os.path.isdir(classifier_path):
        safetensors_file = os.path.join(classifier_path, "model.safetensors")
        if os.path.exists(safetensors_file):
            from safetensors.torch import load_file
            return load_file(safetensors_file)
        pt_file_bin = os.path.join(classifier_path, "pytorch_model.bin")
        if os.path.exists(pt_file_bin):
            # NOTE(review): torch.load is pickle-based and the path comes from
            # config — consider weights_only=True (torch>=1.13) to harden this.
            return torch.load(pt_file_bin, map_location=map_location)
        pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt")
        if os.path.exists(pt_file):
            return torch.load(pt_file, map_location=map_location)
        return None
    if os.path.isfile(classifier_path):
        return torch.load(classifier_path, map_location=map_location)
    return None


def main():
    """Entry point: load the MMEB model (with optional AOP pruning / vision
    pooling configs and an early-exit classifier), then evaluate every dataset
    listed in the YAML dataset config.

    Distributed notes: all ranks encode candidates, but only rank 0 writes the
    candidate caches and runs the query-side early-exit evaluation; barriers
    keep ranks in sync per dataset.
    """
    # Distributed init (NCCL) when launched under torchrun; single-process otherwise.
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
    local_rank = dist.get_rank() if dist.is_initialized() else 0

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    ee_cfg = get_env_ee_config()

    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, 'model_backbone', model_backbone)
    processor = load_processor(model_args, data_args)

    # 1. Load the model (inference only), plus vpool_attn_pooler weights if the
    #    checkpoint directory provides them.
    model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
    _maybe_load_vpool_attn_pooler_into_encoder(model, model_args.model_name)
    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)

    # 2. Inject the AOP (token pruning) config globally on the encoder.
    aop_cfg = get_env_aop_config()
    if aop_cfg["enabled"]:
        setattr(model.encoder, "aop_prune_config", aop_cfg)
        _set_attr_on_base(model.encoder, "aop_prune_config", aop_cfg)
        print_master(f"[AOP] Enabled: {aop_cfg['apply_to']} @ layer {aop_cfg['layer_idx']}")

    # 3. Inject the vision-pooling config globally on the encoder.
    vpool_cfg = get_env_vpool_config()
    if vpool_cfg["enabled"]:
        setattr(model.encoder, "vision_pooling_config", vpool_cfg)
        _set_attr_on_base(model.encoder, "vision_pooling_config", vpool_cfg)
        print_master(f"[VPOOL] Enabled: {vpool_cfg.get('apply_to')} @ layer {vpool_cfg.get('layer_idx')} "
                     f"method={vpool_cfg.get('method')} attn_tau={vpool_cfg.get('attn_tau', None)}")

    # 4. Load the early-exit classifier (kept in FP32, eval mode).
    classifier = None
    if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]:
        classifier_path = ee_cfg['classifier_path']
        print_master(f"[EE] Loading Classifier from {classifier_path}...")
        # AOP only shortens the sequence; the hidden dimension is unchanged.
        backbone_hidden_size = model.encoder.config.hidden_size
        print_master(f"[EE] Backbone Hidden Size: {backbone_hidden_size}")
        classifier = EarlyExitClassifier(
            input_dim=27,
            hidden_dim=128,
            embedding_dim=backbone_hidden_size
        )
        state_dict = _resolve_classifier_state_dict(
            classifier_path, ee_cfg.get('layer', 12), training_args.device)
        if state_dict is None:
            raise FileNotFoundError(f"Could not load classifier weights from {classifier_path}")
        classifier.load_state_dict(state_dict)
        classifier.to(training_args.device)  # FP32
        classifier.eval()
        print_master("[EE] Classifier loaded successfully.")

    # 5. Evaluation loop over the YAML-configured datasets.
    with open(data_args.dataset_config, 'r') as yaml_file:
        dataset_configs = yaml.safe_load(yaml_file)
    for dataset_name, task_config in dataset_configs.items():
        if dist.is_initialized():
            dist.barrier()
        print_master(f"\n--- Evaluating {dataset_name} ---")
        # Per-dataset early-exit threshold, falling back to the global value.
        base_tau = float(ee_cfg["threshold"])
        ds_tau = PER_DATASET_THRESHOLDS.get(dataset_name, base_tau)
        ee_cfg_ds = dict(ee_cfg)
        ee_cfg_ds["threshold"] = float(ds_tau)
        # Re-root relative data paths under data_basedir when provided.
        if data_args.data_basedir:
            for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                if task_config.get(key):
                    task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
        full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(
            model_args=model_args, data_args=data_args, **task_config)
        full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
        mid_layer = int(ee_cfg_ds["layer"])
        # Candidate cache paths carry an AOP/VPOOL config hash so a config
        # change never silently reuses stale embeddings.
        aop_cfg_now = getattr(model.encoder, "aop_prune_config", None)
        vpool_cfg_now = getattr(model.encoder, "vision_pooling_config", None)
        tag = _cfg_cache_tag(aop_cfg_now if isinstance(aop_cfg_now, dict) else None,
                             vpool_cfg_now if isinstance(vpool_cfg_now, dict) else None,
                             mid_layer)
        cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}_{tag}")
        cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast_{tag}")
        if (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)):
            eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
            eval_cand_loader = DataLoader(full_eval_cand_dataset,
                                          batch_size=training_args.per_device_eval_batch_size,
                                          collate_fn=eval_cand_collator,
                                          num_workers=training_args.dataloader_num_workers)
            # Candidate encoding must also go through AOP/VPOOL so query and
            # candidate embeddings live in the same space.
            cand_mid, cand_last, cand_ids = encode_candidates_both_layers(
                model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer)
            if local_rank == 0:
                with open(cand_mid_path, "wb") as f:
                    pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f)
                with open(cand_last_path, "wb") as f:
                    pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f)
        if dist.is_initialized():
            dist.barrier()
        # Query-side early-exit evaluation runs on rank 0 only.
        if local_rank == 0:
            with open(cand_mid_path, "rb") as f:
                cand_mid_dict = pickle.load(f)
            with open(cand_last_path, "rb") as f:
                cand_last_dict = pickle.load(f)
            rank_global = task_config.get("eval_type", "global") == "global"
            run_early_exit_queries(
                model, classifier, processor, model_args, data_args, training_args,
                full_eval_qry_dataset, cand_mid_dict, cand_last_dict,
                ee_cfg_ds, dataset_name, data_args.encode_output_path,
                global_ranking=rank_global,
            )
        if dist.is_initialized():
            dist.barrier()


if __name__ == '__main__':
    main()