| # import datetime | |
| # import logging | |
| # import json | |
| # import random | |
| # import time | |
| # import numpy as np | |
| # import os | |
| # import pickle | |
| # import sys | |
| # import torch | |
| # import torch.distributed as dist | |
| # import torch.nn.functional as F | |
| # import yaml | |
| # import transformers | |
| # import math | |
| # from torch.utils.data import DataLoader | |
| # from tqdm import tqdm | |
| # from transformers import HfArgumentParser, AutoConfig, AutoTokenizer | |
| # from datasets import Dataset, concatenate_datasets | |
| # from datasets.distributed import split_dataset_by_node | |
| # from src.arguments import ModelArguments, DataArguments, TrainingArguments | |
| # from src.data.collator.eval_collator import MultimodalEvalDataCollator | |
| # from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset | |
| # from src.eval_utils.metrics import RankingMetrics | |
| # from src.model.model_multi_layer_distill_AOP_pooling import MMEBModel | |
| # from src.model.processor import get_backbone_name, load_processor, COLPALI | |
| # from src.utils import batch_to_device, print_rank, print_master | |
| # # ========================================== | |
| # # 【V5 修改】引入 V5 版本分类器 | |
| # # ========================================== | |
| # from src.classifier_utils_V5 import EarlyExitClassifier | |
| # logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s') | |
| # logger = logging.getLogger(__name__) | |
| # # =========================== | |
| # # Per-dataset thresholds (示例,需根据V5分析结果更新) | |
| # # =========================== | |
| # PER_DATASET_THRESHOLDS = { | |
| # "CIRR": 0.3, | |
| # "EDIS": 0.3, | |
| # "FashionIQ": 0.3, | |
| # "MSCOCO_i2t": 0.3, | |
| # "MSCOCO_t2i": 0.3, | |
| # "NIGHTS": 0.3, | |
| # "OVEN": 0.3, | |
| # "VisDial": 0.3, | |
| # "VisualNews_i2t": 0.3, | |
| # "VisualNews_t2i": 0.3, | |
| # "WebQA": 0.3, | |
| # "Wiki-SS-NQ": 0.3, | |
| # } | |
| # # ... (Helper Functions 保持不变) ... | |
| # def _parse_bool(v: str, default=False): | |
| # if v is None: return default | |
| # v = v.strip().lower() | |
| # return v in {"1","true","yes","y","t","on"} | |
| # def _parse_float(v: str, default=None): | |
| # try: return float(v) if v is not None else default | |
| # except: return default | |
| # def _parse_int(v: str, default=None): | |
| # try: return int(v) if v is not None else default | |
| # except: return default | |
| # def get_env_aop_config(): | |
| # enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False) | |
| # apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower() | |
| # layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None) | |
| # mode = os.environ.get("AOP_MODE", "delta").strip().lower() | |
| # delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10) | |
| # khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0) | |
| # keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0) | |
| # min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64) | |
| # use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True) | |
| # prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True) | |
| # prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False) | |
| # keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None) | |
| # keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None) | |
| # # 新增:各自的 min_keep | |
| # min_keep_v = _parse_int(os.environ.get("AOP_MIN_KEEP_VISION"), None) | |
| # min_keep_t = _parse_int(os.environ.get("AOP_MIN_KEEP_TEXT"), 32) | |
| # # 新增:文本保护参数 | |
| # protect_text_last = _parse_int(os.environ.get("AOP_PROTECT_TEXT_LAST"), 16) | |
| # protect_special = _parse_bool(os.environ.get("AOP_PROTECT_SPECIAL"), True) | |
| # selection = os.environ.get("AOP_SELECTION", "aop").strip().lower() | |
| # attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower() | |
| # random_seed = _parse_int(os.environ.get("AOP_RANDOM_SEED"), None) | |
| # if layer_idx is None and enabled: | |
| # enabled = False | |
| # return { | |
| # "enabled": enabled, | |
| # "apply_to": apply_to, | |
| # "layer_idx": layer_idx, | |
| # "mode": mode, | |
| # "delta": delta, | |
| # "K_hat": khat, | |
| # "keep_ratio": keep_ratio, | |
| # "min_keep": min_keep, | |
| # "use_bias": use_bias, | |
| # "prune_vision": prune_vision, | |
| # "prune_text": prune_text, | |
| # "keep_ratio_vision": keep_ratio_v, | |
| # "keep_ratio_text": keep_ratio_t, | |
| # "min_keep_vision": min_keep_v, | |
| # "min_keep_text": min_keep_t, | |
| # "protect_text_last": protect_text_last, | |
| # "protect_special": protect_special, | |
| # "selection": selection, | |
| # "attn_agg": attn_agg, | |
| # "random_seed": random_seed, | |
| # "margin_mid": None, | |
| # } | |
| # def _set_attr_on_base(peft_or_base, name, value): | |
| # try: | |
| # base = peft_or_base.get_base_model() if hasattr(peft_or_base, "get_base_model") else None | |
| # if base is None and hasattr(peft_or_base, "model"): | |
| # base = peft_or_base.model | |
| # if base is not None: | |
| # setattr(base, name, value) | |
| # if hasattr(base, "model"): # ForConditionalGeneration.model -> Qwen2VLModel | |
| # setattr(base.model, name, value) | |
| # except Exception as e: | |
| # print_rank(f"[inject-config] warn: set {name} on base failed: {e}") | |
| # def get_env_vpool_config(): | |
| # def _b(k, d=False): | |
| # v = os.environ.get(k) | |
| # if v is None: return d | |
| # return str(v).strip().lower() in {"1","true","y","yes","on"} | |
| # def _i(k, d=None): | |
| # try: return int(os.environ.get(k, d)) | |
| # except: return d | |
| # def _s(k, d=None): | |
| # v = os.environ.get(k, d) | |
| # return None if v is None else str(v).strip().lower() | |
| # return { | |
| # "enabled": _b("VPOOL_ENABLED", False), | |
| # "apply_to": _s("VPOOL_APPLY", "both"), # qry|tgt|both | |
| # "layer_idx": _i("VPOOL_LAYER", 1), | |
| # "kernel": _i("VPOOL_KERNEL", 2), | |
| # "stride": _i("VPOOL_STRIDE", None) or _i("VPOOL_KERNEL", 2), | |
| # "method": _s("VPOOL_METHOD", "avg"), | |
| # "protect_cls": _b("VPOOL_PROTECT_CLS", True), | |
| # "vision_only": _b("VPOOL_ONLY_VISION", True), | |
| # "monitor": _b("VPOOL_MONITOR", False), | |
| # } | |
| # def get_env_ee_config(): | |
| # ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in {"1","true","yes","on","y","t"} | |
| # layer = int(os.environ.get("EE_LAYER", "12")) | |
| # method = os.environ.get("EE_METHOD", "classifier").strip().lower() | |
| # threshold = float(os.environ.get("EE_THRESHOLD", "0.8")) | |
| # classifier_path = os.environ.get("EE_CLASSIFIER_PATH", "") | |
| # return dict(enabled=ee_enabled, layer=layer, method=method, threshold=threshold, classifier_path=classifier_path) | |
| # def pad_dataset_to_divisible(dataset, world_size): | |
| # num_samples = len(dataset) | |
| # if num_samples % world_size == 0: return dataset, num_samples | |
| # num_to_add = world_size - (num_samples % world_size) | |
| # padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)]) | |
| # return concatenate_datasets([dataset, padding_data]), num_samples + num_to_add | |
| # # =========================== | |
| # # Core Inference Function | |
| # # =========================== | |
| # def run_early_exit_queries( | |
| # model: MMEBModel, | |
| # classifier: EarlyExitClassifier, | |
| # processor, | |
| # model_args: ModelArguments, | |
| # data_args: DataArguments, | |
| # training_args: TrainingArguments, | |
| # qry_dataset: Dataset, | |
| # cand_mid_dict: dict, | |
| # cand_last_dict: dict, | |
| # ee_cfg: dict, | |
| # dataset_name: str, | |
| # out_dir: str, | |
| # global_ranking: bool = True, | |
| # ): | |
| # device = training_args.device | |
| # local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| # is_main = (not dist.is_initialized()) or (local_rank == 0) | |
| # # Profiling 配置 | |
| # profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in {"1", "true", "yes", "on", "y", "t"} | |
| # topk_emb = int(os.environ.get("EE_TOPK_EMB", "5")) | |
| # timing_stats = { | |
| # "mid_time_sum": 0.0, "mid_num": 0, "tail_time_sum": 0.0, "tail_num": 0, | |
| # } | |
| # analysis_records = [] if (profile_enabled and is_main) else None | |
| # # 1. 准备 Candidates | |
| # cand_ids = list(cand_mid_dict.keys()) | |
| # cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32) | |
| # cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32) | |
| # cand_mid_np = cand_mid # [Nc, D], float32 | |
| # cand_last_np = cand_last # [Nc, D], float32 | |
| # cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16) | |
| # cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16) | |
| # from contextlib import contextmanager | |
| # # 侧别是否启用 AOP/VPOOL | |
| # aop_cfg_model = getattr(model.encoder, "aop_prune_config", None) | |
| # vpool_cfg_model = getattr(model.encoder, "vision_pooling_config", None) | |
| # @contextmanager | |
| # def _switch_aop(enable: bool): | |
| # enc = model.encoder | |
| # old = getattr(enc, "aop_prune_config", None) | |
| # def _set_cfg(mod, cfg): | |
| # setattr(mod, "aop_prune_config", cfg) | |
| # base = mod.get_base_model() if hasattr(mod, "get_base_model") else None | |
| # if base is None and hasattr(mod, "model"): | |
| # base = mod.model | |
| # if base is not None: | |
| # setattr(base, "aop_prune_config", cfg) | |
| # if hasattr(base, "model"): | |
| # setattr(base.model, "aop_prune_config", cfg) | |
| # if isinstance(old, dict) and not enable: | |
| # cfg = dict(old); cfg["enabled"] = False | |
| # _set_cfg(enc, cfg) | |
| # try: | |
| # yield | |
| # finally: | |
| # _set_cfg(enc, old) | |
| # @contextmanager | |
| # def _switch_vpool(enable: bool): | |
| # enc = model.encoder | |
| # old = getattr(enc, "vision_pooling_config", None) | |
| # def _set_cfg(mod, cfg): | |
| # setattr(mod, "vision_pooling_config", cfg) | |
| # base = mod.get_base_model() if hasattr(mod, "get_base_model") else None | |
| # if base is None and hasattr(mod, "model"): | |
| # base = mod.model | |
| # if base is not None: | |
| # setattr(base, "vision_pooling_config", cfg) | |
| # if hasattr(base, "model"): | |
| # setattr(base.model, "vision_pooling_config", cfg) | |
| # if isinstance(old, dict) and not enable: | |
| # cfg = dict(old); cfg["enabled"] = False | |
| # _set_cfg(enc, cfg) | |
| # try: | |
| # yield | |
| # finally: | |
| # _set_cfg(enc, old) | |
| # collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry") | |
| # loader = DataLoader( | |
| # qry_dataset, batch_size=training_args.per_device_eval_batch_size, | |
| # collate_fn=collator, num_workers=training_args.dataloader_num_workers | |
| # ) | |
| # pred_dicts = [] | |
| # stats = {"exit": 0, "total": 0} | |
| # threshold = float(ee_cfg["threshold"]) | |
| # method = ee_cfg["method"] | |
| # target_layer_idx = int(ee_cfg["layer"]) | |
| # results_dict = {} | |
| # global_sample_idx = 0 | |
| # use_local = (not global_ranking) | |
| # if use_local: | |
| # print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)") | |
| # cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)} | |
| # else: | |
| # print_master(f"[INFO] Using GLOBAL ranking (full library)") | |
| # # AOP Config | |
| # aop_cfg = getattr(model.encoder, "aop_prune_config", None) | |
| # _orig_enabled = None | |
| # side_enable = True | |
| # if isinstance(aop_cfg, dict) and aop_cfg: | |
| # _orig_enabled = aop_cfg.get("enabled", False) | |
| # apply_to = aop_cfg.get("apply_to", "qry") | |
| # side_enable = (apply_to == "both") or (apply_to == "qry") | |
| # model.eval() | |
| # if classifier: | |
| # classifier.eval() | |
| # # 【V5 关键】分类器运行在 FP32 (默认),不要转 BF16,配合 Log/LayerNorm | |
| # classifier.to(device) | |
| # start_time = time.time() | |
| # for inputs, infos in tqdm( | |
| # loader, desc=f"[EE+AOP] {dataset_name} (tau={threshold})", disable=local_rank > 0, | |
| # ): | |
| # inputs = batch_to_device(inputs, device) | |
| # B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1 | |
| # batch_start_idx = global_sample_idx | |
| # global_sample_idx += B | |
| # stats["total"] += B | |
| # # --------------------------------------------------- | |
| # # 1. 前半程: Run to Mid Layer | |
| # # --------------------------------------------------- | |
| # # AOP 侧别 | |
| # aop_apply = (aop_cfg_model or {}).get("apply_to", "qry") | |
| # aop_side_qry = aop_apply in ("both", "qry") | |
| # aop_side_tgt = aop_apply in ("both", "tgt") | |
| # # VPOOL 侧别(一般 pooling 设在 layer=1,因此 mid 一定会经历) | |
| # vpool_apply = (vpool_cfg_model or {}).get("apply_to", "qry") | |
| # vpool_layer = (vpool_cfg_model or {}).get("layer_idx", 1) | |
| # vpool_side_qry = vpool_apply in ("both", "qry") | |
| # vpool_side_tgt = vpool_apply in ("both", "tgt") | |
| # orig_cfg = None | |
| # if isinstance(aop_cfg, dict) and aop_cfg: | |
| # orig_cfg = dict(aop_cfg) | |
| # aop_layer = aop_cfg.get("layer_idx", None) | |
| # aop_on_mid = bool( | |
| # _orig_enabled and side_enable | |
| # and (aop_layer is not None) | |
| # and (aop_layer <= target_layer_idx) | |
| # ) | |
| # aop_cfg_mid = dict(aop_cfg) | |
| # aop_cfg_mid["enabled"] = aop_on_mid | |
| # setattr(model.encoder, "aop_prune_config", aop_cfg_mid) | |
| # if profile_enabled: | |
| # torch.cuda.synchronize() | |
| # t0_mid = time.perf_counter() | |
| # # --- tgt mid --- | |
| # tgt_pre_mask = tgt_inputs.get("attention_mask", None) | |
| # tgt_post_mask = getattr(out_mid_tgt, "attention_mask", None) | |
| # tgt_hs_mid = getattr(out_mid_tgt, "last_hidden_state", None) or out_mid_tgt.hidden_states[-1] | |
| # tgt_am_mid = tgt_post_mask if (tgt_post_mask is not None and tgt_post_mask.size(1) == tgt_hs_mid.size(1)) else tgt_pre_mask | |
| # tgt_reps_mid = model._pooling(tgt_hs_mid, tgt_am_mid).detach().to(dtype=torch.bfloat16) | |
| # # --- qry mid --- | |
| # qry_pre_mask = qry_inputs.get("attention_mask", None) | |
| # qry_post_mask = getattr(out_mid_qry, "attention_mask", None) | |
| # qry_hs_mid = getattr(out_mid_qry, "last_hidden_state", None) or out_mid_qry.hidden_states[-1] | |
| # qry_am_mid = qry_post_mask if (qry_post_mask is not None and qry_post_mask.size(1) == qry_hs_mid.size(1)) else qry_pre_mask | |
| # qry_reps_mid = model._pooling(qry_hs_mid, qry_am_mid).detach().to(dtype=torch.bfloat16) | |
| # with _switch_aop(aop_side_tgt), _switch_vpool(vpool_side_tgt and vpool_layer <= target_layer_idx): | |
| # out_mid_tgt = model.encoder( | |
| # **tgt_inputs, return_dict=True, output_hidden_states=False, | |
| # stop_at_layer=target_layer_idx, compute_lm_head=False, | |
| # ) | |
| # with _switch_aop(aop_side_qry), _switch_vpool(vpool_side_qry and vpool_layer <= target_layer_idx): | |
| # out_mid_qry = model.encoder( | |
| # **qry_inputs, return_dict=True, output_hidden_states=False, | |
| # stop_at_layer=target_layer_idx, compute_lm_head=False, | |
| # ) | |
| # # 这里就是“该插这三行”的地方: | |
| # post_attn_mid = getattr(out_mid, "attention_mask", None) # [B, L'] 剪完后的 2D mask | |
| # img_mask_mid = getattr(out_mid, "image_token_bool_masks", None) # [B, L'] 视觉 token 掩码(Qwen2.5 里目前是 pre 的,你如果想要 post,可以在 Qwen2_5_VLModel 里改成输出 cur_vision_mask) | |
| # txt_mask_mid = getattr(out_mid, "text_token_bool_masks", None) # [B, L'] 文本 token 掩码(已经是 post-prune 的) | |
| # # 然后按需统计,比如: | |
| # if post_attn_mid is not None: | |
| # # 有效 token 总数 | |
| # keep_counts = post_attn_mid.to(torch.bool).sum(dim=1) # [B] | |
| # # 文本 token 保留数 | |
| # if txt_mask_mid is not None: | |
| # text_keep = (post_attn_mid.to(torch.bool) & txt_mask_mid.to(torch.bool)).sum(dim=1) | |
| # # 图像 token 保留数 | |
| # if img_mask_mid is not None: | |
| # vis_keep = (post_attn_mid.to(torch.bool) & img_mask_mid.to(torch.bool)).sum(dim=1) | |
| # if profile_enabled: | |
| # torch.cuda.synchronize() | |
| # t1_mid = time.perf_counter() | |
| # timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B | |
| # timing_stats["mid_num"] += B | |
| # if isinstance(orig_cfg, dict): | |
| # setattr(model.encoder, "aop_prune_config", orig_cfg) | |
| # hs_mid = getattr(out_mid, "last_hidden_state", None) | |
| # if hs_mid is None: hs_mid = out_mid.hidden_states[-1] | |
| # am_mid = getattr(out_mid, "attention_mask", None) | |
| # if am_mid is None: am_mid = inputs.get("attention_mask", None) | |
| # reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(dtype=torch.bfloat16) | |
| # # Profiling Last Layer | |
| # reps_last_full = None | |
| # if profile_enabled: | |
| # with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): | |
| # out_full = model.encoder( | |
| # **inputs, return_dict=True, output_hidden_states=False, | |
| # stop_at_layer=None, compute_lm_head=False, | |
| # ) | |
| # hs_last_full = getattr(out_full, "last_hidden_state", None) | |
| # if hs_last_full is None: hs_last_full = out_full.hidden_states[-1] | |
| # am_last_full = getattr(out_full, "attention_mask", None) | |
| # if am_last_full is None: am_last_full = inputs.get("attention_mask", None) | |
| # reps_last_full = model._pooling(hs_last_full, am_last_full).detach().to(dtype=torch.bfloat16) | |
| # # --------------------------------------------------- | |
| # # 2. 特征工程 + gating | |
| # # --------------------------------------------------- | |
| # exit_mask = np.zeros(B, dtype=bool) | |
| # p_need_last_batch = None | |
| # if method == "classifier" and classifier is not None: | |
| # with torch.no_grad(): | |
| # cos_mid = reps_mid_t @ cand_mid_t.T # [B, N] (BF16) | |
| # backbone_ptr = model.module if hasattr(model, "module") else model | |
| # temp = getattr(backbone_ptr, "temperature", 0.02) | |
| # scores_mid = cos_mid / temp | |
| # probs_mid = torch.softmax(scores_mid, dim=1) # [B, N] | |
| # # --- 27维特征计算 (BF16) --- | |
| # diag_cos = cos_mid.max(dim=1)[0] | |
| # sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True) | |
| # s2_cos = sorted_cos[:, 1] if sorted_cos.size(1) > 1 else sorted_cos[:, 0] | |
| # margin_mid = diag_cos - s2_cos | |
| # margin_mean = margin_mid.mean() | |
| # margin_std = margin_mid.std(unbiased=False) + 1e-6 | |
| # z_margin_mid = (margin_mid - margin_mean) / margin_std | |
| # margin_median = margin_mid.median() | |
| # mad = (margin_mid - margin_median).abs().median() + 1e-6 | |
| # mad_margin_mid = (margin_mid - margin_median) / mad | |
| # p1_mid = probs_mid.max(dim=1)[0] | |
| # H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1) | |
| # gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1) | |
| # TOPK = min(16, probs_mid.size(1)) | |
| # topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1) | |
| # topk_mean = topk_vals.mean(dim=1) | |
| # topk_std = topk_vals.std(dim=1, unbiased=False) | |
| # topk_cv = topk_std / (topk_mean + 1e-6) | |
| # centered = topk_vals - topk_mean.unsqueeze(1) | |
| # var = (centered ** 2).mean(dim=1) + 1e-6 | |
| # m4 = (centered ** 4).mean(dim=1) | |
| # topk_kurt = m4 / (var ** 2) | |
| # topk_med = topk_vals.median(dim=1).values | |
| # row_mean_cos = cos_mid.mean(dim=1) | |
| # row_med_cos = cos_mid.median(dim=1).values | |
| # s1_over_mean = diag_cos - row_mean_cos | |
| # s1_over_med = diag_cos - row_med_cos | |
| # sorted_probs, _ = torch.sort(probs_mid, dim=1, descending=True) | |
| # p1 = sorted_probs[:, 0] | |
| # p2 = sorted_probs[:, 1] if sorted_probs.size(1) > 1 else sorted_probs[:, 0] | |
| # shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(dim=1) | |
| # shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1) | |
| # R = min(10, sorted_probs.size(1)) | |
| # x = torch.arange(R, device=device, dtype=sorted_probs.dtype) | |
| # x_centered = x - x.mean() | |
| # denom = (x_centered ** 2).sum() | |
| # y = torch.log(sorted_probs[:, :R] + 1e-6) | |
| # slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom | |
| # row_mean_p = probs_mid.mean(dim=1) | |
| # row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6 | |
| # z1 = (p1_mid - row_mean_p) / row_std_p | |
| # center_p = probs_mid - row_mean_p.unsqueeze(1) | |
| # m3 = (center_p ** 3).mean(dim=1) | |
| # skew = m3 / (row_std_p ** 3 + 1e-6) | |
| # s1_over_sk = p1_mid - skew | |
| # TAIL_K = min(10, sorted_probs.size(1)) | |
| # tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1) | |
| # HEAD_K = min(5, sorted_probs.size(1)) | |
| # head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1) | |
| # mask_ratio = torch.zeros_like(diag_cos) | |
| # mask_len = torch.zeros_like(diag_cos) | |
| # mask_runs = torch.zeros_like(diag_cos) | |
| # scalar_inputs = torch.stack([ | |
| # diag_cos, s2_cos, margin_mid, z_margin_mid, mad_margin_mid, | |
| # p1_mid, H_mid, gini_mid, | |
| # topk_mean, topk_std, topk_cv, topk_kurt, topk_med, | |
| # s1_over_mean, s1_over_med, | |
| # p1, p2, shape_H, shape_gini, slope, z1, s1_over_sk, | |
| # tail_mean, head5_mean, | |
| # mask_ratio, mask_len, mask_runs | |
| # ], dim=1) | |
| # modality_idx = torch.zeros(B, dtype=torch.long, device=device) | |
| # if "pixel_values" in inputs and inputs["pixel_values"] is not None: | |
| # pv = inputs["pixel_values"] | |
| # if isinstance(pv, list): | |
| # for i, item in enumerate(pv): | |
| # if item is not None: modality_idx[i] = 1 | |
| # elif isinstance(pv, torch.Tensor) and pv.numel() > 0: | |
| # modality_idx.fill_(1) | |
| # # ======================================================= | |
| # # 【V5 关键】强制转为 FP32 传给分类器 | |
| # # ======================================================= | |
| # scalar_inputs_f32 = scalar_inputs.float() | |
| # qry_emb_f32 = reps_mid_t.float() # 转 FP32 | |
| # logits = classifier(scalar_inputs_f32, modality_idx, qry_emb=qry_emb_f32) | |
| # p_need_last = torch.sigmoid(logits) # [B,1] | |
| # p_need_last_batch = p_need_last.squeeze(1) # [B] | |
| # should_exit = p_need_last_batch < threshold | |
| # exit_mask = should_exit.cpu().numpy() | |
| # if stats["total"] <= B * 3 and is_main: | |
| # print_master( | |
| # f"[EE Debug] Batch {stats['total']//B}: " | |
| # f"p_need_last mean={p_need_last_batch.mean().item():.4f}, " | |
| # f"Exit Rate={exit_mask.mean():.2%}, " | |
| # f"Top3: diag_cos={diag_cos.mean():.3f}, margin={margin_mid.mean():.3f}" | |
| # ) | |
| # stats["exit"] += exit_mask.sum() | |
| # # --------------------------------------------------- | |
| # # 3. 分支执行 | |
| # # --------------------------------------------------- | |
| # exit_indices = np.where(exit_mask)[0] | |
| # cont_indices = np.where(~exit_mask)[0] | |
| # # A. 早停样本 | |
| # if len(exit_indices) > 0: | |
| # reps_exit = reps_mid_t[exit_indices] | |
| # if use_local: | |
| # for i, idx in enumerate(exit_indices): | |
| # cand_local = infos[idx].get("cand_names", []) | |
| # if not cand_local: cids = [] | |
| # else: | |
| # rows = [cand_id2row.get(str(cid), -1) for cid in cand_local] | |
| # rows = [r for r in rows if r >= 0] | |
| # if len(rows) == 0: cids = [] | |
| # else: | |
| # cmat_np = cand_mid_np[rows] | |
| # qry_np = reps_exit[i].detach().float().cpu().numpy() | |
| # scores_vec = np.dot(qry_np, cmat_np.T) | |
| # top_k = min(200, len(rows)) | |
| # order_local = np.argsort(-scores_vec)[:top_k] | |
| # cids = [str(cand_local[o]) for o in order_local] | |
| # _record_result(results_dict, batch_start_idx + idx, infos[idx], cids) | |
| # else: | |
| # reps_exit_np = reps_exit.detach().float().cpu().numpy() | |
| # scores_full = np.dot(reps_exit_np, cand_mid_np.T) | |
| # top_k = min(200, len(cand_ids)) | |
| # topk_inds = np.argsort(-scores_full, axis=1)[:, :top_k] | |
| # for i, idx in enumerate(exit_indices): | |
| # cids = [cand_ids[k] for k in topk_inds[i]] | |
| # _record_result(results_dict, batch_start_idx + idx, infos[idx], cids) | |
| # # B. 续跑样本 | |
| # if len(cont_indices) > 0: | |
| # # 取中间态 | |
| # interm = getattr(out_mid, "intermediate_state", None) | |
| # hs, am, pos = interm["hidden_states"].detach(), interm["attention_mask"].detach(), interm["position_ids"].detach() | |
| # vm, tm = interm.get("vision_mask", None), interm.get("text_mask", None) | |
| # next_layer = int(interm["next_layer_idx"]) | |
| # resume_state_subset = { | |
| # "hidden_states": hs[cont_indices], "attention_mask": am[cont_indices], | |
| # "position_ids": pos[:, cont_indices, :], | |
| # "vision_mask": vm[cont_indices] if vm is not None else None, | |
| # "text_mask": tm[cont_indices] if tm is not None else None, | |
| # "next_layer_idx": next_layer, | |
| # } | |
| # # ====== 这里改 AOP 配置 ====== | |
| # if isinstance(aop_cfg, dict) and aop_cfg: | |
| # aop_resume = dict(aop_cfg) | |
| # aop_layer = aop_resume.get("layer_idx", None) | |
| # # 情况1:AOP_LAYER 已经 <= EE_LAYER,说明 mid 那次已经剪过了 → tail 不再剪 | |
| # if (aop_layer is not None) and (aop_layer <= target_layer_idx): | |
| # need_prune_in_tail = False | |
| # else: | |
| # # 情况2:AOP_LAYER 在 EE_LAYER 后面(比如 16),只在 tail 中触发 | |
| # need_prune_in_tail = bool(_orig_enabled and side_enable) | |
| # aop_resume["enabled"] = need_prune_in_tail | |
| # setattr(model.encoder, "aop_prune_config", aop_resume) | |
| # # ====== AOP 配置修改结束 ====== | |
| # if profile_enabled: | |
| # torch.cuda.synchronize() | |
| # t0_tail = time.perf_counter() | |
| # with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): | |
| # out_last = model.encoder( | |
| # return_dict=True, output_hidden_states=False, stop_at_layer=None, | |
| # resume_state=resume_state_subset, compute_lm_head=False, | |
| # ) | |
| # if profile_enabled: | |
| # torch.cuda.synchronize() | |
| # t1_tail = time.perf_counter() | |
| # timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len(cont_indices) | |
| # timing_stats["tail_num"] += len(cont_indices) | |
| # hs_last = out_last.last_hidden_state | |
| # if hs_last is None: hs_last = out_last.hidden_states[-1] | |
| # am_last = getattr(out_last, "attention_mask", None) | |
| # if am_last is None: am_last = resume_state_subset["attention_mask"] | |
| # reps_last_t = model._pooling(hs_last, am_last).detach().to(dtype=torch.bfloat16) | |
| # if use_local: | |
| # for i, idx_global in enumerate(cont_indices): | |
| # cand_local = infos[idx_global].get("cand_names", []) | |
| # if not cand_local: cids = [] | |
| # else: | |
| # rows = [cand_id2row.get(str(cid), -1) for cid in cand_local] | |
| # rows = [r for r in rows if r >= 0] | |
| # if len(rows) == 0: cids = [] | |
| # else: | |
| # cmat_last_np = cand_last_np[rows] | |
| # qry_last_np = reps_last_t[i].detach().float().cpu().numpy() | |
| # scores_vec = np.dot(qry_last_np, cmat_last_np.T) | |
| # top_k = min(200, len(rows)) | |
| # order_local = np.argsort(-scores_vec)[:top_k] | |
| # cids = [str(cand_local[o]) for o in order_local] | |
| # _record_result(results_dict, batch_start_idx + idx_global, infos[idx_global], cids) | |
| # else: | |
| # reps_last_np = reps_last_t.detach().float().cpu().numpy() | |
| # scores_last = np.dot(reps_last_np, cand_last_np.T) | |
| # top_k = min(200, len(cand_ids)) | |
| # topk_inds = np.argsort(-scores_last, axis=1)[:, :top_k] | |
| # for i, idx_global in enumerate(cont_indices): | |
| # cids = [cand_ids[k] for k in topk_inds[i]] | |
| # _record_result(results_dict, batch_start_idx + idx_global, infos[idx_global], cids) | |
| # # --------------------------------------------------- | |
| # # 4. Profiling Stats | |
| # # --------------------------------------------------- | |
| # if profile_enabled and is_main: | |
| # K = min(topk_emb, cand_mid_t.size(0)) | |
| # # 转到 float32 + CPU 便于写盘 | |
| # q_mid_cpu = reps_mid_t.detach().float().cpu() # [B, D] | |
| # q_last_cpu = ( | |
| # reps_last_full.detach().float().cpu() | |
| # if reps_last_full is not None | |
| # else None | |
| # ) # [B, D] | |
| # cand_mid_cpu = cand_mid_t.detach().float().cpu() # [Nc, D] | |
| # cand_last_cpu = cand_last_t.detach().float().cpu() # [Nc, D] | |
| # # mid2mid | |
| # scores_mid_full = q_mid_cpu @ cand_mid_cpu.T # [B, Nc] | |
| # topk_mid_vals, topk_mid_inds = torch.topk( | |
| # scores_mid_full, k=K, dim=1 | |
| # ) | |
| # # last2last | |
| # if q_last_cpu is not None: | |
| # scores_last_full = q_last_cpu @ cand_last_cpu.T # [B, Nc] | |
| # topk_last_vals, topk_last_inds = torch.topk( | |
| # scores_last_full, k=K, dim=1 | |
| # ) | |
| # else: | |
| # topk_last_vals = None | |
| # topk_last_inds = None | |
| # for i in range(B): | |
| # qid = batch_start_idx + i | |
| # rec = { | |
| # "qid": int(qid), | |
| # "early_exit": bool(exit_mask[i]), | |
| # } | |
| # if p_need_last_batch is not None: | |
| # rec["p_need_last"] = float(p_need_last_batch[i].item()) | |
| # # mid2mid TopK | |
| # mid_inds = topk_mid_inds[i].tolist() | |
| # mid_scores = topk_mid_vals[i].tolist() | |
| # rec["mid_topk_scores"] = mid_scores | |
| # rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds] | |
| # # last2last TopK | |
| # if topk_last_inds is not None: | |
| # last_inds = topk_last_inds[i].tolist() | |
| # last_scores = topk_last_vals[i].tolist() | |
| # rec["last_topk_scores"] = last_scores | |
| # rec["last_topk_cand_ids"] = [ | |
| # cand_ids[j] for j in last_inds | |
| # ] | |
| # else: | |
| # rec["last_topk_scores"] = None | |
| # rec["last_topk_cand_ids"] = None | |
| # analysis_records.append(rec) | |
| # # 5. 收集 & 保存结果 | |
| # for idx in sorted(results_dict.keys()): | |
| # pred_dicts.append(results_dict[idx]) | |
| # print_master(f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} ({stats['exit']/stats['total']:.2%})") | |
| # metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] | |
| # score = RankingMetrics(metrics_to_report).evaluate(pred_dicts) | |
| # if is_main: | |
| # os.makedirs(out_dir, exist_ok=True) | |
| # with open(os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), "w") as f: | |
| # json.dump(score, f, indent=4) | |
| # # Profiling save | |
| # if profile_enabled: | |
| # prof_dir = os.path.join(out_dir, "profiling") | |
| # os.makedirs(prof_dir, exist_ok=True) | |
| # mid_avg = timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"]) | |
| # tail_avg = timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"]) | |
| # timing_out = { | |
| # "mid_time_sum": timing_stats["mid_time_sum"], | |
| # "mid_num": timing_stats["mid_num"], | |
| # "tail_time_sum": timing_stats["tail_time_sum"], | |
| # "tail_num": timing_stats["tail_num"], | |
| # "avg_mid_time_per_query_sec": mid_avg, | |
| # "avg_tail_time_per_cont_query_sec": tail_avg, | |
| # "num_exit": int(stats["exit"]), | |
| # "num_total": int(stats["total"]), | |
| # } | |
| # with open(os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w") as f: | |
| # json.dump(timing_out, f, indent=2) | |
| # embed_path = os.path.join(prof_dir, f"{dataset_name}_embeds.jsonl") | |
| # with open(embed_path, "w") as f: | |
| # for rec in analysis_records: | |
| # f.write(json.dumps(rec) + "\n") | |
| # print_master(f"[PROFILE] Saved timing to {prof_dir}, details to {embed_path}") | |
| # elapsed = time.time() - start_time | |
| # return score, elapsed | |
| # def _record_result(results_dict, global_idx, info, cids): | |
| # label = info.get("label_name") or info.get("label") or info.get("rel_docids") | |
| # if not isinstance(label, list): label = [label] | |
| # rel_scores = info.get("rel_scores", None) | |
| # results_dict[global_idx] = { | |
| # "prediction": cids, "label": label, "rel_scores": rel_scores, | |
| # } | |
| # # =========================== | |
| # # Helper Functions (Pre-Computation) | |
| # # =========================== | |
| # # ... (encode_candidates_both_layers 保持不变) ... | |
| # def encode_candidates_both_layers(model: MMEBModel, loader: DataLoader, training_args: TrainingArguments, model_args: ModelArguments, full_dataset: Dataset, mid_layer: int): | |
| # local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| # model.eval() | |
| # all_mid, all_last, all_ids = [], [], [] | |
| # with torch.no_grad(): | |
| # for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0): | |
| # inputs = batch_to_device(inputs, training_args.device) | |
| # aop_cfg = getattr(model.encoder, "aop_prune_config", None) | |
| # vpool_cfg = getattr(model.encoder, "vision_pooling_config", None) | |
| # _orig_aop = None | |
| # _orig_vpool = None | |
| # if isinstance(aop_cfg, dict): | |
| # _orig_aop = aop_cfg.get("enabled", False) | |
| # aop_cfg = dict(aop_cfg); aop_cfg["enabled"] = False | |
| # setattr(model.encoder, "aop_prune_config", aop_cfg) | |
| # if isinstance(vpool_cfg, dict): | |
| # _orig_vpool = vpool_cfg.get("enabled", False) | |
| # vpool_cfg = dict(vpool_cfg); vpool_cfg["enabled"] = False | |
| # setattr(model.encoder, "vision_pooling_config", vpool_cfg) | |
| # with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): | |
| # out = model.encoder(**inputs, return_dict=True, output_hidden_states=True, stop_at_layer=None) | |
| # # 恢复 | |
| # if isinstance(aop_cfg, dict) and _orig_aop is not None: | |
| # aop_cfg["enabled"] = _orig_aop | |
| # setattr(model.encoder, "aop_prune_config", aop_cfg) | |
| # if isinstance(vpool_cfg, dict) and _orig_vpool is not None: | |
| # vpool_cfg["enabled"] = _orig_vpool | |
| # setattr(model.encoder, "vision_pooling_config", vpool_cfg) | |
| # mid_hs = out.hidden_states[mid_layer] | |
| # last_hs = out.hidden_states[-1] | |
| # am = inputs.get("attention_mask", None) | |
| # if am is not None and am.device != mid_hs.device: am = am.to(mid_hs.device) | |
| # reps_mid = model._pooling(mid_hs, am) | |
| # reps_last = model._pooling(last_hs, am) | |
| # all_mid.append(reps_mid.detach().float().cpu()) | |
| # all_last.append(reps_last.detach().float().cpu()) | |
| # all_ids.extend([info["cand_name"] for info in dataset_info]) | |
| # if not all_mid: return np.array([]), np.array([]), [] | |
| # return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids | |
| # # =========================== | |
| # # Main | |
| # # =========================== | |
| # def main(): | |
| # if "RANK" in os.environ and dist.is_available() and not dist.is_initialized(): | |
| # dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60)) | |
| # local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| # parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) | |
| # model_args, data_args, training_args = parser.parse_args_into_dataclasses() | |
| # ee_cfg = get_env_ee_config() | |
| # hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) | |
| # if not getattr(model_args, "model_backbone", None): | |
| # model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type) | |
| # setattr(model_args, 'model_backbone', model_backbone) | |
| # processor = load_processor(model_args, data_args) | |
| # model = MMEBModel.load(model_args, is_trainable=False, processor=processor) | |
| # model.eval() | |
| # model = model.to(training_args.device, dtype=torch.bfloat16) | |
| # # [ADD] 注入 AOP & VPOOL 到 wrapper + 底座 | |
| # aop_cfg = get_env_aop_config() | |
| # vpool_cfg = get_env_vpool_config() | |
| # if aop_cfg.get("enabled", False): | |
| # setattr(model.encoder, "aop_prune_config", aop_cfg) | |
| # _set_attr_on_base(model.encoder, "aop_prune_config", aop_cfg) | |
| # model.set_inference_layers(qry_layers=None, tgt_layers=None) | |
| # print_master("[AOP][eval] enabled with cfg: " + str({ | |
| # k: aop_cfg.get(k) for k in ["apply_to","layer_idx","mode","selection","attn_agg","prune_text","prune_vision","keep_ratio_text","keep_ratio_vision"] | |
| # })) | |
| # else: | |
| # print_master("[AOP][eval] disabled") | |
| # if vpool_cfg.get("enabled", False): | |
| # setattr(model.encoder, "vision_pooling_config", vpool_cfg) | |
| # _set_attr_on_base(model.encoder, "vision_pooling_config", vpool_cfg) | |
| # print_master("[VPOOL][eval] enabled with cfg: " + str({ | |
| # k: vpool_cfg.get(k) for k in ["apply_to","layer_idx","kernel","stride","method","vision_only"] | |
| # })) | |
| # else: | |
| # print_master("[VPOOL][eval] disabled") | |
| # # 加载分类器 | |
| # classifier = None | |
| # if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]: | |
| # classifier_path = ee_cfg['classifier_path'] | |
| # print_master(f"[EE] Loading Classifier from {classifier_path}...") | |
| # # 【V5 修改】获取 Backbone Hidden Size 并初始化 V5 分类器 | |
| # backbone_hidden_size = model.encoder.config.hidden_size | |
| # print_master(f"[EE] Backbone Hidden Size: {backbone_hidden_size}") | |
| # # 使用 EarlyExitClassifier (其实引用的是 V5) | |
| # classifier = EarlyExitClassifier( | |
| # input_dim=27, | |
| # hidden_dim=128, | |
| # embedding_dim=backbone_hidden_size | |
| # ) | |
| # state_dict = None | |
| # if os.path.isdir(classifier_path): | |
| # safetensors_file = os.path.join(classifier_path, "model.safetensors") | |
| # if os.path.exists(safetensors_file): | |
| # from safetensors.torch import load_file | |
| # state_dict = load_file(safetensors_file) | |
| # else: | |
| # layer_idx = ee_cfg.get('layer', 12) | |
| # # 尝试加载训练脚本保存的 .pt 文件 | |
| # pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt") | |
| # if os.path.exists(pt_file): | |
| # state_dict = torch.load(pt_file, map_location=training_args.device) | |
| # else: | |
| # # 尝试 pytorch_model.bin | |
| # pt_file_bin = os.path.join(classifier_path, "pytorch_model.bin") | |
| # if os.path.exists(pt_file_bin): | |
| # state_dict = torch.load(pt_file_bin, map_location=training_args.device) | |
| # elif os.path.isfile(classifier_path): | |
| # state_dict = torch.load(classifier_path, map_location=training_args.device) | |
| # if state_dict is not None: | |
| # classifier.load_state_dict(state_dict) | |
| # classifier.to(training_args.device) # 默认 FP32 | |
| # classifier.eval() | |
| # print_master(f"[EE] Classifier loaded successfully.") | |
| # else: | |
| # raise FileNotFoundError(f"Could not load classifier weights from {classifier_path}") | |
| # with open(data_args.dataset_config, 'r') as yaml_file: dataset_configs = yaml.safe_load(yaml_file) | |
| # for dataset_name, task_config in dataset_configs.items(): | |
| # if dist.is_initialized(): dist.barrier() | |
| # print_master(f"\n--- Evaluating {dataset_name} ---") | |
| # base_tau = float(ee_cfg["threshold"]) | |
| # ds_tau = PER_DATASET_THRESHOLDS.get(dataset_name, base_tau) | |
| # ee_cfg_ds = dict(ee_cfg); ee_cfg_ds["threshold"] = float(ds_tau) | |
| # if data_args.data_basedir: | |
| # for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]: | |
| # if task_config.get(key): task_config[key] = os.path.join(data_args.data_basedir, task_config[key]) | |
| # full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config) | |
| # full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus) | |
| # mid_layer = int(ee_cfg_ds["layer"]) | |
| # cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}") | |
| # cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast") | |
| # if (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)): | |
| # eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand") | |
| # eval_cand_loader = DataLoader(full_eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers) | |
| # cand_mid, cand_last, cand_ids = encode_candidates_both_layers(model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer) | |
| # if local_rank == 0: | |
| # with open(cand_mid_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f) | |
| # with open(cand_last_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f) | |
| # if dist.is_initialized(): dist.barrier() | |
| # if local_rank == 0: | |
| # with open(cand_mid_path, "rb") as f: cand_mid_dict = pickle.load(f) | |
| # with open(cand_last_path, "rb") as f: cand_last_dict = pickle.load(f) | |
| # rank_global = task_config.get("eval_type", "global") == "global" | |
| # run_early_exit_queries( | |
| # model, classifier, processor, model_args, data_args, training_args, | |
| # full_eval_qry_dataset, cand_mid_dict, cand_last_dict, | |
| # ee_cfg_ds, dataset_name, data_args.encode_output_path, | |
| # global_ranking=rank_global, | |
| # ) | |
| # if dist.is_initialized(): dist.barrier() | |
| # if __name__ == '__main__': | |
| # main() | |
| ################################################################################################################### | |
| import datetime | |
| import logging | |
| import json | |
| import random | |
| import time | |
| import numpy as np | |
| import os | |
| import pickle | |
| import sys | |
| import torch | |
| from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler | |
| import torch.distributed as dist | |
| import torch.nn.functional as F | |
| import yaml | |
| import transformers | |
| import math | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from transformers import HfArgumentParser, AutoConfig, AutoTokenizer | |
| from datasets import Dataset, concatenate_datasets | |
| from datasets.distributed import split_dataset_by_node | |
| from src.arguments import ModelArguments, DataArguments, TrainingArguments | |
| from src.data.collator.eval_collator import MultimodalEvalDataCollator | |
| from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset | |
| from src.eval_utils.metrics import RankingMetrics | |
| from src.model.model_multi_layer_distill_AOP_pooling import MMEBModel | |
| from src.model.processor import get_backbone_name, load_processor, COLPALI | |
| from src.utils import batch_to_device, print_rank, print_master | |
| # ========================================== | |
| # 【V5 修改】引入 V5 版本分类器 | |
| # ========================================== | |
| from src.classifier_utils_V5 import EarlyExitClassifier | |
| logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s') | |
| logger = logging.getLogger(__name__) | |
| # =========================== | |
| # Per-dataset thresholds (示例,需根据V5分析结果更新) | |
| # =========================== | |
| PER_DATASET_THRESHOLDS = { | |
| "CIRR": 0.43, | |
| "EDIS": 0.46, | |
| "FashionIQ": 0.43, | |
| "MSCOCO_i2t": 0.43, | |
| "MSCOCO_t2i": 0.42, | |
| "NIGHTS": 0.48, | |
| "OVEN": 0.44, | |
| "VisDial": 0.45, | |
| "VisualNews_i2t": 0.43, | |
| "VisualNews_t2i": 0.43, | |
| "WebQA": 0.55, | |
| "Wiki-SS-NQ": 0.45, | |
| } | |
| # ... (Helper Functions 保持不变) ... | |
| def _parse_bool(v: str, default=False): | |
| if v is None: return default | |
| v = v.strip().lower() | |
| return v in {"1","true","yes","y","t","on"} | |
| def _parse_float(v: str, default=None): | |
| try: return float(v) if v is not None else default | |
| except: return default | |
| def _parse_int(v: str, default=None): | |
| try: return int(v) if v is not None else default | |
| except: return default | |
| def get_env_aop_config(): | |
| enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False) | |
| apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower() | |
| layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None) | |
| mode = os.environ.get("AOP_MODE", "delta").strip().lower() | |
| delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10) | |
| khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0) | |
| keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0) | |
| min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64) | |
| use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True) | |
| prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True) | |
| prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False) | |
| keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None) | |
| keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None) | |
| # 新增:各自的 min_keep | |
| min_keep_v = _parse_int(os.environ.get("AOP_MIN_KEEP_VISION"), None) | |
| min_keep_t = _parse_int(os.environ.get("AOP_MIN_KEEP_TEXT"), 32) | |
| # 新增:文本保护参数 | |
| protect_text_last = _parse_int(os.environ.get("AOP_PROTECT_TEXT_LAST"), 16) | |
| protect_special = _parse_bool(os.environ.get("AOP_PROTECT_SPECIAL"), True) | |
| selection = os.environ.get("AOP_SELECTION", "aop").strip().lower() | |
| attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower() | |
| random_seed = _parse_int(os.environ.get("AOP_RANDOM_SEED"), None) | |
| if layer_idx is None and enabled: | |
| enabled = False | |
| return { | |
| "enabled": enabled, | |
| "apply_to": apply_to, | |
| "layer_idx": layer_idx, | |
| "mode": mode, | |
| "delta": delta, | |
| "K_hat": khat, | |
| "keep_ratio": keep_ratio, | |
| "min_keep": min_keep, | |
| "use_bias": use_bias, | |
| "prune_vision": prune_vision, | |
| "prune_text": prune_text, | |
| "keep_ratio_vision": keep_ratio_v, | |
| "keep_ratio_text": keep_ratio_t, | |
| "min_keep_vision": min_keep_v, | |
| "min_keep_text": min_keep_t, | |
| "protect_text_last": protect_text_last, | |
| "protect_special": protect_special, | |
| "selection": selection, | |
| "attn_agg": attn_agg, | |
| "random_seed": random_seed, | |
| "margin_mid": None, | |
| } | |
| # [新增] VPOOL 配置解析 | |
| def get_env_vpool_config(): | |
| def _b(k, d=False): | |
| v = os.environ.get(k) | |
| if v is None: return d | |
| return str(v).strip().lower() in {"1","true","yes","y","on"} | |
| def _i(k, d=None): | |
| try: return int(os.environ.get(k, d)) | |
| except: return d | |
| def _s(k, d=None): | |
| v = os.environ.get(k, d) | |
| return None if v is None else str(v).strip().lower() | |
| return { | |
| "enabled": _b("VPOOL_ENABLED", False), | |
| "apply_to": _s("VPOOL_APPLY", "both"), # qry|tgt|both | |
| "layer_idx": _i("VPOOL_LAYER", 1), | |
| "kernel": _i("VPOOL_KERNEL", 2), | |
| "stride": _i("VPOOL_STRIDE", None) or _i("VPOOL_KERNEL", 2), | |
| "method": _s("VPOOL_METHOD", "avg"), | |
| "protect_cls": _b("VPOOL_PROTECT_CLS", True), | |
| "vision_only": _b("VPOOL_ONLY_VISION", True), | |
| "monitor": _b("VPOOL_MONITOR", False), | |
| } | |
| def get_env_ee_config(): | |
| ee_enabled = os.environ.get("EE_ENABLED", "0").strip().lower() in {"1","true","yes","on","y","t"} | |
| layer = int(os.environ.get("EE_LAYER", "12")) | |
| method = os.environ.get("EE_METHOD", "classifier").strip().lower() | |
| threshold = float(os.environ.get("EE_THRESHOLD", "0.8")) | |
| classifier_path = os.environ.get("EE_CLASSIFIER_PATH", "") | |
| return dict(enabled=ee_enabled, layer=layer, method=method, threshold=threshold, classifier_path=classifier_path) | |
| # 【新增】Helper: 将配置注入底座模型 (修复 NameError) | |
| def _set_attr_on_base(peft_or_base, name, value): | |
| try: | |
| # 尝试获取 PEFT 的 base model | |
| base = peft_or_base.get_base_model() if hasattr(peft_or_base, "get_base_model") else None | |
| # 如果不是 PEFT,或者 get_base_model 返回空,尝试直接访问 .model | |
| if base is None and hasattr(peft_or_base, "model"): | |
| base = peft_or_base.model | |
| # 注入配置 | |
| if base is not None: | |
| setattr(base, name, value) | |
| # 针对 Qwen2-VL 这种包装结构 (Qwen2VLForConditionalGeneration -> Qwen2VLModel) | |
| if hasattr(base, "model"): | |
| setattr(base.model, name, value) | |
| except Exception as e: | |
| # 仅打印警告,不阻断流程 | |
| print(f"[inject-config] warn: set {name} on base failed: {e}") | |
| def pad_dataset_to_divisible(dataset, world_size): | |
| num_samples = len(dataset) | |
| if num_samples % world_size == 0: return dataset, num_samples | |
| num_to_add = world_size - (num_samples % world_size) | |
| padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)]) | |
| return concatenate_datasets([dataset, padding_data]), num_samples + num_to_add | |
| # =========================== | |
| # Core Inference Function | |
| # =========================== | |
| def run_early_exit_queries( | |
| model: MMEBModel, | |
| classifier: EarlyExitClassifier, | |
| processor, | |
| model_args: ModelArguments, | |
| data_args: DataArguments, | |
| training_args: TrainingArguments, | |
| qry_dataset: Dataset, | |
| cand_mid_dict: dict, | |
| cand_last_dict: dict, | |
| ee_cfg: dict, | |
| dataset_name: str, | |
| out_dir: str, | |
| global_ranking: bool = True, | |
| ): | |
| device = training_args.device | |
| local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| is_main = (not dist.is_initialized()) or (local_rank == 0) | |
| # Profiling 配置 | |
| profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in {"1", "true", "yes", "on", "y", "t"} | |
| torch_prof_enabled = os.environ.get("EE_TORCH_PROFILE", "0").strip().lower() in {"1", "true", "yes", "on", "y", "t"} | |
| topk_emb = int(os.environ.get("EE_TOPK_EMB", "5")) | |
| timing_stats = { | |
| "mid_time_sum": 0.0, "mid_num": 0, "tail_time_sum": 0.0, "tail_num": 0, | |
| } | |
| analysis_records = [] if (profile_enabled and is_main) else None | |
| # 1. 准备 Candidates | |
| cand_ids = list(cand_mid_dict.keys()) | |
| cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32) | |
| cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32) | |
| cand_mid_np = cand_mid # [Nc, D], float32 | |
| cand_last_np = cand_last # [Nc, D], float32 | |
| cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16) | |
| cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16) | |
| collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry") | |
| loader = DataLoader( | |
| qry_dataset, batch_size=training_args.per_device_eval_batch_size, | |
| collate_fn=collator, num_workers=training_args.dataloader_num_workers | |
| ) | |
| pred_dicts = [] | |
| stats = {"exit": 0, "total": 0} | |
| threshold = float(ee_cfg["threshold"]) | |
| method = ee_cfg["method"] | |
| target_layer_idx = int(ee_cfg["layer"]) | |
| results_dict = {} | |
| global_sample_idx = 0 | |
| use_local = (not global_ranking) | |
| if use_local: | |
| print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)") | |
| cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)} | |
| else: | |
| print_master(f"[INFO] Using GLOBAL ranking (full library)") | |
| # AOP Config | |
| aop_cfg = getattr(model.encoder, "aop_prune_config", None) | |
| _orig_enabled = None | |
| side_enable = True | |
| if isinstance(aop_cfg, dict) and aop_cfg: | |
| _orig_enabled = aop_cfg.get("enabled", False) | |
| apply_to = aop_cfg.get("apply_to", "qry") | |
| side_enable = (apply_to == "both") or (apply_to == "qry") | |
| model.eval() | |
| if classifier: | |
| classifier.eval() | |
| # 【V5 关键】分类器运行在 FP32 (默认),不要转 BF16,配合 Log/LayerNorm | |
| classifier.to(device) | |
| start_time = time.time() | |
| prof = None | |
| if torch_prof_enabled and is_main: | |
| prof_dir = os.path.join(out_dir, "torch_prof", dataset_name) | |
| os.makedirs(prof_dir, exist_ok=True) | |
| prof_schedule = schedule(wait=10, warmup=1, active=1, repeat=0) | |
| prof = profile( | |
| activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], | |
| schedule=prof_schedule, | |
| on_trace_ready=tensorboard_trace_handler(prof_dir), | |
| record_shapes=False, | |
| profile_memory=False, | |
| with_stack=True, | |
| ) | |
| prof.__enter__() | |
| for inputs, infos in tqdm( | |
| loader, desc=f"[EE+AOP] {dataset_name} (tau={threshold})", disable=local_rank > 0, | |
| ): | |
| inputs = batch_to_device(inputs, device) | |
| B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1 | |
| batch_start_idx = global_sample_idx | |
| global_sample_idx += B | |
| stats["total"] += B | |
| # --------------------------------------------------- | |
| # 1. 前半程: Run to Mid Layer | |
| # --------------------------------------------------- | |
| orig_cfg = None | |
| if isinstance(aop_cfg, dict) and aop_cfg: | |
| orig_cfg = dict(aop_cfg) | |
| aop_layer = aop_cfg.get("layer_idx", None) | |
| aop_on_mid = bool( | |
| _orig_enabled and side_enable | |
| and (aop_layer is not None) | |
| and (aop_layer <= target_layer_idx) | |
| ) | |
| aop_cfg_mid = dict(aop_cfg) | |
| aop_cfg_mid["enabled"] = aop_on_mid | |
| setattr(model.encoder, "aop_prune_config", aop_cfg_mid) | |
| if profile_enabled: | |
| torch.cuda.synchronize() | |
| t0_mid = time.perf_counter() | |
| with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): | |
| out_mid = model.encoder( | |
| **inputs, return_dict=True, output_hidden_states=False,#False | |
| stop_at_layer=target_layer_idx, compute_lm_head=False, | |
| return_intermediate_state=True | |
| ) | |
| # 这里就是“该插这三行”的地方: | |
| post_attn_mid = getattr(out_mid, "attention_mask", None) # [B, L'] 剪完后的 2D mask | |
| img_mask_mid = getattr(out_mid, "image_token_bool_masks", None) # [B, L'] 视觉 token 掩码(Qwen2.5 里目前是 pre 的,你如果想要 post,可以在 Qwen2_5_VLModel 里改成输出 cur_vision_mask) | |
| txt_mask_mid = getattr(out_mid, "text_token_bool_masks", None) # [B, L'] 文本 token 掩码(已经是 post-prune 的) | |
| # 然后按需统计,比如: | |
| if post_attn_mid is not None: | |
| # 有效 token 总数 | |
| keep_counts = post_attn_mid.to(torch.bool).sum(dim=1) # [B] | |
| # 文本 token 保留数 | |
| if txt_mask_mid is not None: | |
| text_keep = (post_attn_mid.to(torch.bool) & txt_mask_mid.to(torch.bool)).sum(dim=1) | |
| # 图像 token 保留数 | |
| if img_mask_mid is not None: | |
| vis_keep = (post_attn_mid.to(torch.bool) & img_mask_mid.to(torch.bool)).sum(dim=1) | |
| if profile_enabled: | |
| torch.cuda.synchronize() | |
| t1_mid = time.perf_counter() | |
| timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B | |
| timing_stats["mid_num"] += B | |
| if isinstance(orig_cfg, dict): | |
| setattr(model.encoder, "aop_prune_config", orig_cfg) | |
| hs_mid = getattr(out_mid, "last_hidden_state", None) | |
| if hs_mid is None: hs_mid = out_mid.hidden_states[-1] | |
| am_mid = getattr(out_mid, "attention_mask", None) | |
| if am_mid is None: am_mid = inputs.get("attention_mask", None) | |
| reps_mid_t = model._pooling(hs_mid, am_mid).detach().to(dtype=torch.bfloat16) | |
| # Profiling Last Layer | |
| reps_last_full = None | |
| if profile_enabled: | |
| with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): | |
| out_full = model.encoder( | |
| **inputs, return_dict=True, output_hidden_states=False, | |
| stop_at_layer=None, compute_lm_head=False, | |
| return_intermediate_state=True | |
| ) | |
| hs_last_full = getattr(out_full, "last_hidden_state", None) | |
| if hs_last_full is None: hs_last_full = out_full.hidden_states[-1] | |
| am_last_full = getattr(out_full, "attention_mask", None) | |
| if am_last_full is None: am_last_full = inputs.get("attention_mask", None) | |
| reps_last_full = model._pooling(hs_last_full, am_last_full).detach().to(dtype=torch.bfloat16) | |
| # --------------------------------------------------- | |
| # 2. 特征工程 + gating | |
| # --------------------------------------------------- | |
| exit_mask = np.zeros(B, dtype=bool) | |
| p_need_last_batch = None | |
| cos_mid = None | |
| if method == "classifier" and classifier is not None: | |
| cos_mid = reps_mid_t @ cand_mid_t.T # [B, Nc] on GPU | |
| with torch.no_grad(): | |
| # cos_mid = reps_mid_t @ cand_mid_t.T # [B, N] (BF16) | |
| backbone_ptr = model.module if hasattr(model, "module") else model | |
| temp = getattr(backbone_ptr, "temperature", 0.02) | |
| scores_mid = cos_mid / temp | |
| probs_mid = torch.softmax(scores_mid, dim=1) # [B, N] | |
| # --- 27维特征计算 (BF16) --- | |
| diag_cos = cos_mid.max(dim=1)[0] | |
| # sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True) | |
| # s2_cos = sorted_cos[:, 1] if sorted_cos.size(1) > 1 else sorted_cos[:, 0] | |
| top2 = torch.topk(cos_mid, k=min(2, cos_mid.size(1)), dim=1).values | |
| diag_cos = top2[:, 0] | |
| s2_cos = top2[:, 1] if top2.size(1) > 1 else top2[:, 0] | |
| margin_mid = diag_cos - s2_cos | |
| margin_mean = margin_mid.mean() | |
| margin_std = margin_mid.std(unbiased=False) + 1e-6 | |
| z_margin_mid = (margin_mid - margin_mean) / margin_std | |
| margin_median = margin_mid.median() | |
| mad = (margin_mid - margin_median).abs().median() + 1e-6 | |
| mad_margin_mid = (margin_mid - margin_median) / mad | |
| p1_mid = probs_mid.max(dim=1)[0] | |
| H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1) | |
| gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1) | |
| TOPK = min(16, probs_mid.size(1)) | |
| topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1) | |
| topk_mean = topk_vals.mean(dim=1) | |
| topk_std = topk_vals.std(dim=1, unbiased=False) | |
| topk_cv = topk_std / (topk_mean + 1e-6) | |
| centered = topk_vals - topk_mean.unsqueeze(1) | |
| var = (centered ** 2).mean(dim=1) + 1e-6 | |
| m4 = (centered ** 4).mean(dim=1) | |
| topk_kurt = m4 / (var ** 2) | |
| topk_med = topk_vals.median(dim=1).values | |
| row_mean_cos = cos_mid.mean(dim=1) | |
| row_med_cos = cos_mid.median(dim=1).values | |
| s1_over_mean = diag_cos - row_mean_cos | |
| s1_over_med = diag_cos - row_med_cos | |
| sorted_probs, _ = torch.sort(probs_mid, dim=1, descending=True) | |
| p1 = sorted_probs[:, 0] | |
| p2 = sorted_probs[:, 1] if sorted_probs.size(1) > 1 else sorted_probs[:, 0] | |
| shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(dim=1) | |
| shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1) | |
| R = min(10, sorted_probs.size(1)) | |
| x = torch.arange(R, device=device, dtype=sorted_probs.dtype) | |
| x_centered = x - x.mean() | |
| denom = (x_centered ** 2).sum() | |
| y = torch.log(sorted_probs[:, :R] + 1e-6) | |
| slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom | |
| row_mean_p = probs_mid.mean(dim=1) | |
| row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6 | |
| z1 = (p1_mid - row_mean_p) / row_std_p | |
| center_p = probs_mid - row_mean_p.unsqueeze(1) | |
| m3 = (center_p ** 3).mean(dim=1) | |
| skew = m3 / (row_std_p ** 3 + 1e-6) | |
| s1_over_sk = p1_mid - skew | |
| TAIL_K = min(10, sorted_probs.size(1)) | |
| tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1) | |
| HEAD_K = min(5, sorted_probs.size(1)) | |
| head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1) | |
| mask_ratio = torch.zeros_like(diag_cos) | |
| mask_len = torch.zeros_like(diag_cos) | |
| mask_runs = torch.zeros_like(diag_cos) | |
| scalar_inputs = torch.stack([ | |
| diag_cos, s2_cos, margin_mid, z_margin_mid, mad_margin_mid, | |
| p1_mid, H_mid, gini_mid, | |
| topk_mean, topk_std, topk_cv, topk_kurt, topk_med, | |
| s1_over_mean, s1_over_med, | |
| p1, p2, shape_H, shape_gini, slope, z1, s1_over_sk, | |
| tail_mean, head5_mean, | |
| mask_ratio, mask_len, mask_runs | |
| ], dim=1) | |
| modality_idx = torch.zeros(B, dtype=torch.long, device=device) | |
| if "pixel_values" in inputs and inputs["pixel_values"] is not None: | |
| pv = inputs["pixel_values"] | |
| if isinstance(pv, list): | |
| for i, item in enumerate(pv): | |
| if item is not None: modality_idx[i] = 1 | |
| elif isinstance(pv, torch.Tensor) and pv.numel() > 0: | |
| modality_idx.fill_(1) | |
| # ======================================================= | |
| # 【V5 关键】强制转为 FP32 传给分类器 | |
| # ======================================================= | |
| scalar_inputs_f32 = scalar_inputs.float() | |
| qry_emb_f32 = reps_mid_t.float() # 转 FP32 | |
| logits = classifier(scalar_inputs_f32, modality_idx, qry_emb=qry_emb_f32) | |
| p_need_last = torch.sigmoid(logits) # [B,1] | |
| p_need_last_batch = p_need_last.squeeze(1) # [B] | |
| should_exit = p_need_last_batch < threshold | |
| exit_mask = should_exit.cpu().numpy() | |
| if stats["total"] <= B * 3 and is_main: | |
| print_master( | |
| f"[EE Debug] Batch {stats['total']//B}: " | |
| f"p_need_last mean={p_need_last_batch.mean().item():.4f}, " | |
| f"Exit Rate={exit_mask.mean():.2%}, " | |
| f"Top3: diag_cos={diag_cos.mean():.3f}, margin={margin_mid.mean():.3f}" | |
| ) | |
| stats["exit"] += exit_mask.sum() | |
| # --------------------------------------------------- | |
| # 3. 分支执行 | |
| # --------------------------------------------------- | |
| exit_indices = np.where(exit_mask)[0] | |
| cont_indices = np.where(~exit_mask)[0] | |
| # A. 早停样本 | |
| if len(exit_indices) > 0: | |
| reps_exit = reps_mid_t[exit_indices] | |
| if use_local: | |
| # 对每个 query 的 local candidates,只在 cos_mid 上取子集再 topk | |
| for i, idx in enumerate(exit_indices): | |
| cand_local = infos[idx].get("cand_names", []) | |
| pairs = [(str(cid), cand_id2row.get(str(cid), -1)) for cid in cand_local] | |
| pairs = [(cid, r) for cid, r in pairs if r >= 0] | |
| if len(pairs) == 0: | |
| cids = [] | |
| else: | |
| cand_local_valid = [cid for cid, _ in pairs] | |
| rows_valid = [r for _, r in pairs] | |
| rows_t = torch.tensor(rows_valid, device=device, dtype=torch.long) | |
| scores = cos_mid[idx].index_select(0, rows_t) # [n_local] | |
| top_k = min(200, scores.numel()) | |
| sel = torch.topk(scores, k=top_k, largest=True).indices.cpu().tolist() | |
| cids = [cand_local_valid[j] for j in sel] | |
| _record_result(results_dict, batch_start_idx + idx, infos[idx], cids) | |
| else: | |
| # reps_exit_np = reps_exit.detach().float().cpu().numpy() | |
| # scores_full = np.dot(reps_exit_np, cand_mid_np.T) | |
| # top_k = min(200, len(cand_ids)) | |
| # topk_inds = np.argsort(-scores_full, axis=1)[:, :top_k] | |
| # for i, idx in enumerate(exit_indices): | |
| # cids = [cand_ids[k] for k in topk_inds[i]] | |
| # _record_result(results_dict, batch_start_idx + idx, infos[idx], cids) | |
| # cos_mid: [B, Nc] | |
| cos_exit = cos_mid[exit_indices] # [Be, Nc] | |
| top_k = min(200, cos_exit.size(1)) | |
| topk_inds = torch.topk(cos_exit, k=top_k, dim=1, largest=True).indices # [Be, top_k] | |
| topk_inds = topk_inds.cpu().tolist() | |
| for i, idx in enumerate(exit_indices): | |
| cids = [cand_ids[j] for j in topk_inds[i]] | |
| _record_result(results_dict, batch_start_idx + idx, infos[idx], cids) | |
| # B. 续跑样本 | |
| if len(cont_indices) > 0: | |
| # 取中间态 | |
| interm = getattr(out_mid, "intermediate_state", None) | |
| hs, am, pos = interm["hidden_states"].detach(), interm["attention_mask"].detach(), interm["position_ids"].detach() | |
| vm, tm = interm.get("vision_mask", None), interm.get("text_mask", None) | |
| next_layer = int(interm["next_layer_idx"]) | |
| resume_state_subset = { | |
| "hidden_states": hs[cont_indices], "attention_mask": am[cont_indices], | |
| "position_ids": pos[:, cont_indices, :], | |
| "vision_mask": vm[cont_indices] if vm is not None else None, | |
| "text_mask": tm[cont_indices] if tm is not None else None, | |
| "next_layer_idx": next_layer, | |
| } | |
| # ====== 这里改 AOP 配置 ====== | |
| if isinstance(aop_cfg, dict) and aop_cfg: | |
| aop_resume = dict(aop_cfg) | |
| aop_layer = aop_resume.get("layer_idx", None) | |
| # 情况1:AOP_LAYER 已经 <= EE_LAYER,说明 mid 那次已经剪过了 → tail 不再剪 | |
| if (aop_layer is not None) and (aop_layer <= target_layer_idx): | |
| need_prune_in_tail = False | |
| else: | |
| # 情况2:AOP_LAYER 在 EE_LAYER 后面(比如 16),只在 tail 中触发 | |
| need_prune_in_tail = bool(_orig_enabled and side_enable) | |
| aop_resume["enabled"] = need_prune_in_tail | |
| setattr(model.encoder, "aop_prune_config", aop_resume) | |
| # ====== AOP 配置修改结束 ====== | |
| if profile_enabled: | |
| torch.cuda.synchronize() | |
| t0_tail = time.perf_counter() | |
| with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): | |
| out_last = model.encoder( | |
| return_dict=True, output_hidden_states=False, stop_at_layer=None, | |
| resume_state=resume_state_subset, compute_lm_head=False, | |
| return_intermediate_state=False | |
| ) | |
| if profile_enabled: | |
| torch.cuda.synchronize() | |
| t1_tail = time.perf_counter() | |
| timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len(cont_indices) | |
| timing_stats["tail_num"] += len(cont_indices) | |
| hs_last = out_last.last_hidden_state | |
| if hs_last is None: hs_last = out_last.hidden_states[-1] | |
| am_last = getattr(out_last, "attention_mask", None) | |
| if am_last is None: am_last = resume_state_subset["attention_mask"] | |
| reps_last_t = model._pooling(hs_last, am_last).detach().to(dtype=torch.bfloat16) | |
| if use_local: | |
| for i, idx_global in enumerate(cont_indices): | |
| cand_local = infos[idx_global].get("cand_names", []) | |
| if not cand_local: cids = [] | |
| else: | |
| pairs = [(str(cid), cand_id2row.get(str(cid), -1)) for cid in cand_local] | |
| pairs = [(cid, r) for cid, r in pairs if r >= 0] | |
| if len(pairs) == 0: | |
| cids = [] | |
| else: | |
| cand_local_valid = [cid for cid, _ in pairs] | |
| rows_valid = [r for _, r in pairs] | |
| rows_t = torch.tensor(rows_valid, device=device, dtype=torch.long) | |
| cmat = cand_last_t.index_select(0, rows_t) # [n_local, D] | |
| scores = (reps_last_t[i].unsqueeze(0) @ cmat.T).squeeze(0) # [n_local] | |
| top_k = min(200, scores.numel()) | |
| sel = torch.topk(scores, k=top_k, largest=True).indices.cpu().tolist() | |
| cids = [cand_local_valid[j] for j in sel] | |
| _record_result(results_dict, batch_start_idx + idx_global, infos[idx_global], cids) | |
| else: | |
| # reps_last_np = reps_last_t.detach().float().cpu().numpy() | |
| # scores_last = np.dot(reps_last_np, cand_last_np.T) | |
| # top_k = min(200, len(cand_ids)) | |
| # topk_inds = np.argsort(-scores_last, axis=1)[:, :top_k] | |
| # for i, idx_global in enumerate(cont_indices): | |
| # cids = [cand_ids[k] for k in topk_inds[i]] | |
| # _record_result(results_dict, batch_start_idx + idx_global, infos[idx_global], cids) | |
| cos_last = reps_last_t @ cand_last_t.T | |
| top_k = min(200, cos_last.size(1)) | |
| topk_inds = torch.topk(cos_last, k=top_k, dim=1, largest=True).indices | |
| topk_inds = topk_inds.cpu().tolist() | |
| for i, idx_global in enumerate(cont_indices): | |
| cids = [cand_ids[j] for j in topk_inds[i]] | |
| _record_result(results_dict, batch_start_idx + idx_global, infos[idx_global], cids) | |
| # --------------------------------------------------- | |
| # 4. Profiling Stats | |
| # --------------------------------------------------- | |
| if profile_enabled and is_main: | |
| K = min(topk_emb, cand_mid_t.size(0)) | |
| # 转到 float32 + CPU 便于写盘 | |
| q_mid_cpu = reps_mid_t.detach().float().cpu() # [B, D] | |
| q_last_cpu = ( | |
| reps_last_full.detach().float().cpu() | |
| if reps_last_full is not None | |
| else None | |
| ) # [B, D] | |
| cand_mid_cpu = cand_mid_t.detach().float().cpu() # [Nc, D] | |
| cand_last_cpu = cand_last_t.detach().float().cpu() # [Nc, D] | |
| # mid2mid | |
| scores_mid_full = q_mid_cpu @ cand_mid_cpu.T # [B, Nc] | |
| topk_mid_vals, topk_mid_inds = torch.topk( | |
| scores_mid_full, k=K, dim=1 | |
| ) | |
| # last2last | |
| if q_last_cpu is not None: | |
| scores_last_full = q_last_cpu @ cand_last_cpu.T # [B, Nc] | |
| topk_last_vals, topk_last_inds = torch.topk( | |
| scores_last_full, k=K, dim=1 | |
| ) | |
| else: | |
| topk_last_vals = None | |
| topk_last_inds = None | |
| for i in range(B): | |
| qid = batch_start_idx + i | |
| rec = { | |
| "qid": int(qid), | |
| "early_exit": bool(exit_mask[i]), | |
| } | |
| if p_need_last_batch is not None: | |
| rec["p_need_last"] = float(p_need_last_batch[i].item()) | |
| # mid2mid TopK | |
| mid_inds = topk_mid_inds[i].tolist() | |
| mid_scores = topk_mid_vals[i].tolist() | |
| rec["mid_topk_scores"] = mid_scores | |
| rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds] | |
| # last2last TopK | |
| if topk_last_inds is not None: | |
| last_inds = topk_last_inds[i].tolist() | |
| last_scores = topk_last_vals[i].tolist() | |
| rec["last_topk_scores"] = last_scores | |
| rec["last_topk_cand_ids"] = [ | |
| cand_ids[j] for j in last_inds | |
| ] | |
| else: | |
| rec["last_topk_scores"] = None | |
| rec["last_topk_cand_ids"] = None | |
| analysis_records.append(rec) | |
| if prof is not None: | |
| prof.step() | |
| if prof is not None: | |
| # Trigger on_trace_ready if the schedule hasn't naturally advanced yet. | |
| prof.step() | |
| prof.__exit__(None, None, None) | |
| # 5. 收集 & 保存结果 | |
| for idx in sorted(results_dict.keys()): | |
| pred_dicts.append(results_dict[idx]) | |
| print_master(f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} ({stats['exit']/stats['total']:.2%})") | |
| metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"] | |
| score = RankingMetrics(metrics_to_report).evaluate(pred_dicts) | |
| if is_main: | |
| os.makedirs(out_dir, exist_ok=True) | |
| with open(os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), "w") as f: | |
| json.dump(score, f, indent=4) | |
| # Profiling save | |
| if profile_enabled: | |
| prof_dir = os.path.join(out_dir, "profiling") | |
| os.makedirs(prof_dir, exist_ok=True) | |
| mid_avg = timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"]) | |
| tail_avg = timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"]) | |
| timing_out = { | |
| "mid_time_sum": timing_stats["mid_time_sum"], | |
| "mid_num": timing_stats["mid_num"], | |
| "tail_time_sum": timing_stats["tail_time_sum"], | |
| "tail_num": timing_stats["tail_num"], | |
| "avg_mid_time_per_query_sec": mid_avg, | |
| "avg_tail_time_per_cont_query_sec": tail_avg, | |
| "num_exit": int(stats["exit"]), | |
| "num_total": int(stats["total"]), | |
| } | |
| with open(os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w") as f: | |
| json.dump(timing_out, f, indent=2) | |
| embed_path = os.path.join(prof_dir, f"{dataset_name}_embeds.jsonl") | |
| with open(embed_path, "w") as f: | |
| for rec in analysis_records: | |
| f.write(json.dumps(rec) + "\n") | |
| print_master(f"[PROFILE] Saved timing to {prof_dir}, details to {embed_path}") | |
| elapsed = time.time() - start_time | |
| return score, elapsed | |
| def _record_result(results_dict, global_idx, info, cids): | |
| label = info.get("label_name") or info.get("label") or info.get("rel_docids") | |
| if not isinstance(label, list): label = [label] | |
| rel_scores = info.get("rel_scores", None) | |
| results_dict[global_idx] = { | |
| "prediction": cids, "label": label, "rel_scores": rel_scores, | |
| } | |
| # =========================== | |
| # Helper Functions (Pre-Computation) | |
| # =========================== | |
| # ... (encode_candidates_both_layers 保持不变) ... | |
| def encode_candidates_both_layers(model: MMEBModel, loader: DataLoader, training_args: TrainingArguments, model_args: ModelArguments, full_dataset: Dataset, mid_layer: int): | |
| local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| model.eval() | |
| all_mid, all_last, all_ids = [], [], [] | |
| with torch.no_grad(): | |
| for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0): | |
| inputs = batch_to_device(inputs, training_args.device) | |
| aop_cfg = getattr(model.encoder, "aop_prune_config", None) | |
| _orig = None | |
| if isinstance(aop_cfg, dict): | |
| _orig = aop_cfg.get("enabled", False) | |
| aop_cfg["enabled"] = False | |
| setattr(model.encoder, "aop_prune_config", aop_cfg) | |
| with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): | |
| out = model.encoder(**inputs, return_dict=True, output_hidden_states=True, stop_at_layer=None) | |
| if isinstance(aop_cfg, dict) and _orig is not None: aop_cfg["enabled"] = _orig | |
| # mid_hs = out.hidden_states[mid_layer] | |
| # last_hs = out.hidden_states[-1] | |
| # am = inputs.get("attention_mask", None) | |
| # if am is not None and am.device != mid_hs.device: am = am.to(mid_hs.device) | |
| # reps_mid = model._pooling(mid_hs, am) | |
| # reps_last = model._pooling(last_hs, am) | |
| mid_hs = out.hidden_states[mid_layer] | |
| last_hs = out.hidden_states[-1] | |
| post_am = getattr(out, "attention_mask", None) # post (after VPOOL/AOP) | |
| pre_am = inputs.get("attention_mask", None) # pre | |
| def pick_mask(h, pre_am, post_am): | |
| if post_am is not None and post_am.size(1) == h.size(1): | |
| return post_am | |
| if pre_am is not None and pre_am.size(1) == h.size(1): | |
| return pre_am | |
| return torch.ones(h.size(0), h.size(1), dtype=torch.long, device=h.device) | |
| am_mid = pick_mask(mid_hs, pre_am, post_am) | |
| am_last = pick_mask(last_hs, pre_am, post_am) | |
| reps_mid = model._pooling(mid_hs, am_mid) | |
| reps_last = model._pooling(last_hs, am_last) | |
| all_mid.append(reps_mid.detach().float().cpu()) | |
| all_last.append(reps_last.detach().float().cpu()) | |
| all_ids.extend([info["cand_name"] for info in dataset_info]) | |
| if not all_mid: return np.array([]), np.array([]), [] | |
| return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids | |
| # =========================== | |
| # Main | |
| # =========================== | |
| def main(): | |
| if "RANK" in os.environ and dist.is_available() and not dist.is_initialized(): | |
| dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60)) | |
| local_rank = dist.get_rank() if dist.is_initialized() else 0 | |
| parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) | |
| model_args, data_args, training_args = parser.parse_args_into_dataclasses() | |
| ee_cfg = get_env_ee_config() | |
| hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) | |
| if not getattr(model_args, "model_backbone", None): | |
| model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type) | |
| setattr(model_args, 'model_backbone', model_backbone) | |
| processor = load_processor(model_args, data_args) | |
| # 1. 加载模型 | |
| model = MMEBModel.load(model_args, is_trainable=False, processor=processor) | |
| model.eval() | |
| model = model.to(training_args.device, dtype=torch.bfloat16) | |
| # 2. 注入 AOP 配置 (全局生效) | |
| aop_cfg = get_env_aop_config() | |
| if aop_cfg["enabled"]: | |
| setattr(model.encoder, "aop_prune_config", aop_cfg) | |
| _set_attr_on_base(model.encoder, "aop_prune_config", aop_cfg) | |
| print_master(f"[AOP] Enabled: {aop_cfg['apply_to']} @ layer {aop_cfg['layer_idx']}") | |
| # 3. 注入 VPOOL 配置 (全局生效) | |
| vpool_cfg = get_env_vpool_config() | |
| if vpool_cfg["enabled"]: | |
| setattr(model.encoder, "vision_pooling_config", vpool_cfg) | |
| _set_attr_on_base(model.encoder, "vision_pooling_config", vpool_cfg) | |
| print_master(f"[VPOOL] Enabled: {vpool_cfg['apply_to']} @ layer {vpool_cfg['layer_idx']}") | |
| # [删除] 这一行是旧代码遗留,新模型不需要它 | |
| # model.set_inference_layers(qry_layers=None, tgt_layers=None) | |
| # 4. 加载分类器 | |
| classifier = None | |
| if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]: | |
| classifier_path = ee_cfg['classifier_path'] | |
| print_master(f"[EE] Loading Classifier from {classifier_path}...") | |
| # 获取 Embedding 维度 (AOP 不改变隐层维度,只改变长度) | |
| backbone_hidden_size = model.encoder.config.hidden_size | |
| print_master(f"[EE] Backbone Hidden Size: {backbone_hidden_size}") | |
| classifier = EarlyExitClassifier( | |
| input_dim=27, | |
| hidden_dim=128, | |
| embedding_dim=backbone_hidden_size | |
| ) | |
| state_dict = None | |
| if os.path.isdir(classifier_path): | |
| safetensors_file = os.path.join(classifier_path, "model.safetensors") | |
| if os.path.exists(safetensors_file): | |
| from safetensors.torch import load_file | |
| state_dict = load_file(safetensors_file) | |
| else: | |
| pt_file_bin = os.path.join(classifier_path, "pytorch_model.bin") | |
| if os.path.exists(pt_file_bin): | |
| state_dict = torch.load(pt_file_bin, map_location=training_args.device) | |
| else: | |
| layer_idx = ee_cfg.get('layer', 12) | |
| pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt") | |
| if os.path.exists(pt_file): | |
| state_dict = torch.load(pt_file, map_location=training_args.device) | |
| elif os.path.isfile(classifier_path): | |
| state_dict = torch.load(classifier_path, map_location=training_args.device) | |
| if state_dict is not None: | |
| classifier.load_state_dict(state_dict) | |
| classifier.to(training_args.device) # FP32 | |
| classifier.eval() | |
| print_master(f"[EE] Classifier loaded successfully.") | |
| else: | |
| raise FileNotFoundError(f"Could not load classifier weights from {classifier_path}") | |
| # 5. 开始评测循环 | |
| with open(data_args.dataset_config, 'r') as yaml_file: dataset_configs = yaml.safe_load(yaml_file) | |
| for dataset_name, task_config in dataset_configs.items(): | |
| if dist.is_initialized(): dist.barrier() | |
| print_master(f"\n--- Evaluating {dataset_name} ---") | |
| base_tau = float(ee_cfg["threshold"]) | |
| ds_tau = PER_DATASET_THRESHOLDS.get(dataset_name, base_tau) | |
| ee_cfg_ds = dict(ee_cfg); ee_cfg_ds["threshold"] = float(ds_tau) | |
| if data_args.data_basedir: | |
| for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]: | |
| if task_config.get(key): task_config[key] = os.path.join(data_args.data_basedir, task_config[key]) | |
| full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config) | |
| full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus) | |
| mid_layer = int(ee_cfg_ds["layer"]) | |
| cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}") | |
| cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast") | |
| if (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)): | |
| eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand") | |
| eval_cand_loader = DataLoader(full_eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers) | |
| # 【关键】Candidate 编码也必须经过 AOP/VPOOL | |
| cand_mid, cand_last, cand_ids = encode_candidates_both_layers(model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer) | |
| if local_rank == 0: | |
| with open(cand_mid_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f) | |
| with open(cand_last_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f) | |
| if dist.is_initialized(): dist.barrier() | |
| if local_rank == 0: | |
| with open(cand_mid_path, "rb") as f: cand_mid_dict = pickle.load(f) | |
| with open(cand_last_path, "rb") as f: cand_last_dict = pickle.load(f) | |
| rank_global = task_config.get("eval_type", "global") == "global" | |
| run_early_exit_queries( | |
| model, classifier, processor, model_args, data_args, training_args, | |
| full_eval_qry_dataset, cand_mid_dict, cand_last_dict, | |
| ee_cfg_ds, dataset_name, data_args.encode_output_path, | |
| global_ranking=rank_global, | |
| ) | |
| if dist.is_initialized(): dist.barrier() | |
| if __name__ == '__main__': | |
| main() |