| import datetime |
| import logging |
| import json |
| import random |
| import time |
| import numpy as np |
| import os |
| import pickle |
| import sys |
| import hashlib |
| import torch |
| import torch.distributed as dist |
| import torch.nn.functional as F |
| import yaml |
| import transformers |
| import math |
| from contextlib import contextmanager |
|
|
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from transformers import HfArgumentParser, AutoConfig, AutoTokenizer |
| from datasets import Dataset, concatenate_datasets |
| from datasets.distributed import split_dataset_by_node |
|
|
| from src.arguments import ModelArguments, DataArguments, TrainingArguments |
| from src.data.collator.eval_collator import MultimodalEvalDataCollator |
| from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset |
| from src.eval_utils.metrics import RankingMetrics |
| from src.model.model_multilayer_AOP_infer_attn_pooling import MMEBModel |
| from src.model.processor import get_backbone_name, load_processor, COLPALI |
| from src.utils import batch_to_device, print_rank, print_master |
|
|
| |
| |
| |
| from src.classifier_utils_V5 import EarlyExitClassifier |
|
|
# Root logging config: timestamped records tagged with level, logger name and source line.
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)  # module-level logger for this evaluation script
|
|
| |
| |
| |
# Per-dataset early-exit thresholds (cutoff on the classifier's "need last layer"
# probability). Currently uniform across all benchmarks; tune entries individually
# as needed.
PER_DATASET_THRESHOLDS = dict.fromkeys(
    (
        "CIRR",
        "EDIS",
        "FashionIQ",
        "MSCOCO_i2t",
        "MSCOCO_t2i",
        "NIGHTS",
        "OVEN",
        "VisDial",
        "VisualNews_i2t",
        "VisualNews_t2i",
        "WebQA",
        "Wiki-SS-NQ",
    ),
    0.45,
)
|
|
| |
| def _parse_bool(v: str, default=False): |
| if v is None: return default |
| v = v.strip().lower() |
| return v in {"1","true","yes","y","t","on"} |
|
|
| def _parse_float(v: str, default=None): |
| try: return float(v) if v is not None else default |
| except: return default |
|
|
| def _parse_int(v: str, default=None): |
| try: return int(v) if v is not None else default |
| except: return default |
|
|
def get_env_aop_config():
    """Assemble the AOP (attention-oriented pruning) config from AOP_* env vars.

    Returns a plain dict consumed by the encoder's pruning hooks. AOP without a
    target layer is meaningless, so ``enabled`` is forced off when AOP_LAYER is
    unset. ``margin_mid`` is a placeholder slot left as None here.
    """
    env = os.environ.get
    layer_idx = _parse_int(env("AOP_LAYER"), None)
    # Force-disable when no pruning layer was configured.
    enabled = _parse_bool(env("AOP_ENABLED"), False) and (layer_idx is not None)

    return {
        "enabled": enabled,
        "apply_to": env("AOP_APPLY", "qry").strip().lower(),
        "layer_idx": layer_idx,
        "mode": env("AOP_MODE", "delta").strip().lower(),
        "delta": _parse_float(env("AOP_DELTA"), 0.10),
        "K_hat": _parse_float(env("AOP_KHAT"), 1.0),
        "keep_ratio": _parse_float(env("AOP_KEEP_RATIO"), 1.0),
        "min_keep": _parse_int(env("AOP_MIN_KEEP"), 64),
        "use_bias": _parse_bool(env("AOP_USE_BIAS"), True),
        # Which modality's tokens are eligible for pruning.
        "prune_vision": _parse_bool(env("AOP_PRUNE_VISION"), True),
        "prune_text": _parse_bool(env("AOP_PRUNE_TEXT"), False),
        # Optional per-modality overrides (None -> use the shared settings).
        "keep_ratio_vision": _parse_float(env("AOP_KEEP_RATIO_VISION"), None),
        "keep_ratio_text": _parse_float(env("AOP_KEEP_RATIO_TEXT"), None),
        "min_keep_vision": _parse_int(env("AOP_MIN_KEEP_VISION"), None),
        "min_keep_text": _parse_int(env("AOP_MIN_KEEP_TEXT"), 32),
        # Protection: never prune the trailing text tokens / special tokens.
        "protect_text_last": _parse_int(env("AOP_PROTECT_TEXT_LAST"), 16),
        "protect_special": _parse_bool(env("AOP_PROTECT_SPECIAL"), True),
        "selection": env("AOP_SELECTION", "aop").strip().lower(),
        "attn_agg": env("AOP_ATTENTION_AGG", "mean").strip().lower(),
        "random_seed": _parse_int(env("AOP_RANDOM_SEED"), None),
        # Filled in later by callers if needed.
        "margin_mid": None,
    }
|
|
| |
def get_env_vpool_config():
    """Assemble the vision-token pooling config from VPOOL_* environment variables.

    Returns a plain dict; ``stride`` falls back to the kernel size when
    VPOOL_STRIDE is unset (or zero/invalid). The nested helpers previously used
    bare ``except:`` clauses that also swallowed KeyboardInterrupt/SystemExit;
    they now catch only conversion errors.
    """
    def _b(k, d=False):
        # Boolean env flag; anything outside the truthy set counts as False.
        v = os.environ.get(k)
        if v is None: return d
        return str(v).strip().lower() in {"1","true","yes","y","on"}
    def _i(k, d=None):
        # int(None) raises TypeError when the var is missing and d is None -> returns d.
        try: return int(os.environ.get(k, d))
        except (TypeError, ValueError): return d
    def _s(k, d=None):
        v = os.environ.get(k, d)
        return None if v is None else str(v).strip().lower()
    def _f(k, d=None):
        try: return float(os.environ.get(k, d))
        except (TypeError, ValueError): return d

    return {
        "enabled": _b("VPOOL_ENABLED", False),
        "apply_to": _s("VPOOL_APPLY", "both"),
        "layer_idx": _i("VPOOL_LAYER", 1),
        "kernel": _i("VPOOL_KERNEL", 2),
        # Stride defaults to the kernel size when not explicitly configured.
        "stride": _i("VPOOL_STRIDE", None) or _i("VPOOL_KERNEL", 2),
        "method": _s("VPOOL_METHOD", "avg"),
        "protect_cls": _b("VPOOL_PROTECT_CLS", True),
        "vision_only": _b("VPOOL_ONLY_VISION", True),
        "monitor": _b("VPOOL_MONITOR", False),
        # Softmax temperature used by attention-based pooling.
        "attn_tau": _f("VPOOL_ATTN_TAU", 1.0),
    }
|
|
def get_env_ee_config():
    """Read early-exit settings from the EE_* environment variables."""
    truthy = {"1", "true", "yes", "on", "y", "t"}
    return {
        "enabled": os.environ.get("EE_ENABLED", "0").strip().lower() in truthy,
        "layer": int(os.environ.get("EE_LAYER", "12")),
        "method": os.environ.get("EE_METHOD", "classifier").strip().lower(),
        "threshold": float(os.environ.get("EE_THRESHOLD", "0.8")),
        "classifier_path": os.environ.get("EE_CLASSIFIER_PATH", ""),
    }
|
|
| |
| def _set_attr_on_base(peft_or_base, name, value): |
| try: |
| |
| base = peft_or_base.get_base_model() if hasattr(peft_or_base, "get_base_model") else None |
| |
| |
| if base is None and hasattr(peft_or_base, "model"): |
| base = peft_or_base.model |
| |
| |
| if base is not None: |
| setattr(base, name, value) |
| |
| if hasattr(base, "model"): |
| setattr(base.model, name, value) |
| except Exception as e: |
| |
| print(f"[inject-config] warn: set {name} on base failed: {e}") |
|
|
def _maybe_load_vpool_attn_pooler_into_encoder(model: MMEBModel, ckpt_dir: str):
    """
    At inference time, load the saved ``vpool_attn_pooler`` weights into the encoder
    (the base model would otherwise keep a randomly initialized pooler).

    Two state_dict key layouts are supported:
      - pooler.state_dict() saved directly: 'mlp.0.weight'...
      - peft ``modules_to_save`` wrapper: 'modules_to_save.default.mlp.0.weight'...

    Returns True only when a weight file was found and loaded with no missing keys.
    """
    if not ckpt_dir or (not os.path.isdir(ckpt_dir)):
        print_master(f"[VPOOL-ATTN] ckpt_dir not found: {ckpt_dir}")
        return False

    # Locate the pooler module: prefer encoder.model.vpool_attn_pooler, fall back
    # to an attribute directly on the encoder.
    enc = model.encoder
    inner = getattr(enc, "model", None)
    pooler = getattr(inner, "vpool_attn_pooler", None) if inner is not None else None
    if pooler is None and hasattr(enc, "vpool_attn_pooler"):
        pooler = enc.vpool_attn_pooler
    if pooler is None:
        print_master("[VPOOL-ATTN] pooler not found in encoder, skip loading.")
        return False

    # Prefer the safetensors file; fall back to a torch pickle.
    path = os.path.join(ckpt_dir, "vpool_attn_pooler.safetensors")
    if not os.path.isfile(path):
        path = os.path.join(ckpt_dir, "vpool_attn_pooler.pt")
        if not os.path.isfile(path):
            print_master(f"[VPOOL-ATTN] no pooler weight file found under {ckpt_dir}")
            return False

    # Unwrap peft wrappers so the keys line up with the bare pooler module.
    # NOTE(review): peft's modules_to_save is usually a ModuleDict, not a plain
    # dict — confirm this isinstance(dict) check actually matches your peft version.
    target = pooler
    if hasattr(pooler, "modules_to_save") and isinstance(pooler.modules_to_save, dict) and "default" in pooler.modules_to_save:
        target = pooler.modules_to_save["default"]
    elif hasattr(pooler, "original_module"):
        target = pooler.original_module

    try:
        if path.endswith(".safetensors"):
            from safetensors.torch import load_file
            sd_raw = load_file(path, device="cpu")
        else:
            sd_raw = torch.load(path, map_location="cpu")

        # Normalize key prefixes from either saving layout down to bare keys.
        def _strip_prefix(sd: dict, prefix: str) -> dict:
            return {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}

        sd = sd_raw
        if any(k.startswith("modules_to_save.default.") for k in sd.keys()):
            sd = _strip_prefix(sd, "modules_to_save.default.")
        elif any(k.startswith("original_module.") for k in sd.keys()):
            sd = _strip_prefix(sd, "original_module.")

        missing, unexpected = target.load_state_dict(sd, strict=False)

        # Cheap sanity probe: mean |weight| of mlp[3]; None if that layer is absent.
        try:
            mean_abs = float(target.mlp[3].weight.detach().abs().mean().cpu().item())
        except Exception:
            mean_abs = None

        print_master(f"[VPOOL-ATTN] loaded pooler from {path}. missing={missing}, unexpected={unexpected}, mlp3_mean_abs={mean_abs}")
        return (len(missing) == 0)
    except Exception as e:
        print_master(f"[VPOOL-ATTN] load pooler failed: {e}")
        return False
| |
| def _norm_side(side: str) -> str: |
| s = str(side).strip().lower() |
| if s in {"cand", "tgt", "target"}: |
| return "tgt" |
| return "qry" |
|
|
| def _apply_to_match(apply_to: str, side: str) -> bool: |
| """ |
| apply_to: qry|tgt|cand|both |
| side: qry|tgt |
| """ |
| a = str(apply_to).strip().lower() |
| if a == "cand": |
| a = "tgt" |
| if a == "both": |
| return True |
| return a == side |
|
|
@contextmanager
def _temp_cfg_enabled(model: MMEBModel, cfg_name: str, enable: bool):
    """
    Temporarily override the ``enabled`` flag of ``encoder.<cfg_name>`` and mirror
    the change onto the underlying base model (ForConditionalGeneration / ``.model``).

    cfg_name: "aop_prune_config" or "vision_pooling_config"

    The effective flag is ``enable AND old["enabled"]`` — this can only further
    restrict a feature, never force-enable one that is off in the stored config.
    The original config object is restored (on both encoder and base) on exit.
    """
    enc = model.encoder
    old = getattr(enc, cfg_name, None)
    if isinstance(old, dict):
        # Work on a shallow copy so the stored config dict is never mutated.
        new = dict(old)
        new["enabled"] = bool(enable and old.get("enabled", False))
        setattr(enc, cfg_name, new)
        _set_attr_on_base(enc, cfg_name, new)
    try:
        yield
    finally:
        # NOTE(review): when the attribute was absent, this sets it to None
        # rather than leaving it unset — presumably harmless; verify downstream.
        setattr(enc, cfg_name, old)
        _set_attr_on_base(enc, cfg_name, old)
|
|
def _pool_from_outputs(model: MMEBModel, hs: torch.Tensor, out, fallback_mask: torch.Tensor | None, lens_index: int | None = None):
    """
    Unified pooling: prefer ``out.attn_lens`` (available only when the model has
    been patched to expose it and implements ``_pooling_from_lens``); otherwise
    fall back to mask-based pooling via ``model._pooling``.

    lens_index=None -> use the last attn_lens entry (corresponding to
    last_hidden_state); an out-of-range index also falls back to the last entry.
    """
    attn_lens = getattr(out, "attn_lens", None)
    if hasattr(model, "_pooling_from_lens") and isinstance(attn_lens, (list, tuple)) and len(attn_lens) > 0:
        if lens_index is None:
            lens_1d = attn_lens[-1]
        else:
            if 0 <= lens_index < len(attn_lens):
                lens_1d = attn_lens[lens_index]
            else:
                lens_1d = attn_lens[-1]
        if lens_1d is not None:
            return model._pooling_from_lens(hs, lens_1d)

    # Mask fallback: prefer the (possibly pruned) mask the model returned,
    # else the caller-supplied input attention mask.
    am = getattr(out, "attention_mask", None)
    if am is None:
        am = fallback_mask
    return model._pooling(hs, am)
|
|
| def _cfg_cache_tag(aop_cfg: dict | None, vpool_cfg: dict | None, mid_layer: int) -> str: |
| """ |
| 生成 candidate 缓存文件名的短 hash,避免切换 AOP/VPOOL 配置后误读旧缓存。 |
| """ |
| payload = { |
| "mid_layer": int(mid_layer), |
| "aop": aop_cfg if isinstance(aop_cfg, dict) else None, |
| "vpool": vpool_cfg if isinstance(vpool_cfg, dict) else None, |
| } |
| s = json.dumps(payload, sort_keys=True, ensure_ascii=False) |
| return hashlib.md5(s.encode("utf-8")).hexdigest()[:10] |
|
|
def pad_dataset_to_divisible(dataset, world_size):
    """Pad *dataset* (by cycling its own samples) until len divides *world_size*.

    Returns (possibly-padded dataset, padded length). When the length is already
    divisible the original dataset is returned untouched.
    """
    n = len(dataset)
    remainder = n % world_size
    if remainder == 0:
        return dataset, n
    extra = world_size - remainder
    padding = dataset.select([i % n for i in range(extra)])
    return concatenate_datasets([dataset, padding]), n + extra
|
|
| |
| |
| |
def run_early_exit_queries(
    model: MMEBModel,
    classifier: EarlyExitClassifier,
    processor,
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
    qry_dataset: Dataset,
    cand_mid_dict: dict,
    cand_last_dict: dict,
    ee_cfg: dict,
    dataset_name: str,
    out_dir: str,
    global_ranking: bool = True,
):
    """
    Evaluate queries with early exit: each query is encoded up to the mid layer
    ``ee_cfg["layer"]``; a confidence classifier then decides which queries exit
    there (ranked against mid-layer candidate embeddings, ``cand_mid_dict``) and
    which resume through the remaining layers (ranked against ``cand_last_dict``).

    global_ranking=True ranks against the full candidate library; False ranks
    each query only against its own ``cand_names`` subset.

    Writes ``{dataset_name}_score_earlyexit.json`` (plus optional profiling
    files when EE_PROFILE is set) under *out_dir* and returns
    (score dict from RankingMetrics, elapsed wall-clock seconds).
    """
    device = training_args.device
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    is_main = (not dist.is_initialized()) or (local_rank == 0)

    # Optional profiling: per-stage timing plus top-k embedding dumps.
    profile_enabled = os.environ.get("EE_PROFILE", "0").strip().lower() in {"1", "true", "yes", "on", "y", "t"}
    topk_emb = int(os.environ.get("EE_TOPK_EMB", "5"))
    timing_stats = {
        "mid_time_sum": 0.0, "mid_num": 0, "tail_time_sum": 0.0, "tail_num": 0,
    }
    analysis_records = [] if (profile_enabled and is_main) else None

    # Candidate matrices: row order fixed by cand_ids; kept both as float32
    # numpy (for CPU ranking) and bf16 tensors (for GPU scoring).
    cand_ids = list(cand_mid_dict.keys())
    cand_mid = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
    cand_last = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32)

    cand_mid_np = cand_mid
    cand_last_np = cand_last

    cand_mid_t = torch.from_numpy(cand_mid).to(device=device, dtype=torch.bfloat16)
    cand_last_t = torch.from_numpy(cand_last).to(device=device, dtype=torch.bfloat16)

    collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
    loader = DataLoader(
        qry_dataset, batch_size=training_args.per_device_eval_batch_size,
        collate_fn=collator, num_workers=training_args.dataloader_num_workers
    )

    pred_dicts = []
    stats = {"exit": 0, "total": 0}

    threshold = float(ee_cfg["threshold"])
    method = ee_cfg["method"]
    target_layer_idx = int(ee_cfg["layer"])
    results_dict = {}
    global_sample_idx = 0

    use_local = (not global_ranking)
    if use_local:
        print_master(f"[INFO] Using LOCAL ranking (per-query candidate sets)")
        cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
    else:
        print_master(f"[INFO] Using GLOBAL ranking (full library)")

    # Snapshot the AOP config so the tail pass can decide whether pruning still
    # has to happen after the mid layer.
    aop_cfg = getattr(model.encoder, "aop_prune_config", None)
    vpool_cfg = getattr(model.encoder, "vision_pooling_config", None)
    _orig_enabled = None
    side_enable = True
    if isinstance(aop_cfg, dict) and aop_cfg:
        _orig_enabled = aop_cfg.get("enabled", False)
        apply_to = aop_cfg.get("apply_to", "qry")
        side_enable = (apply_to == "both") or (apply_to == "qry")

    model.eval()
    if classifier:
        classifier.eval()
        classifier.to(device)

    start_time = time.time()

    for inputs, infos in tqdm(
        loader, desc=f"[EE+AOP] {dataset_name} (tau={threshold})", disable=local_rank > 0,
    ):
        inputs = batch_to_device(inputs, device)
        B = inputs["input_ids"].shape[0] if "input_ids" in inputs else 1
        batch_start_idx = global_sample_idx
        global_sample_idx += B
        stats["total"] += B

        # We are encoding the query side here.
        side = "qry"
        side_ok_aop = isinstance(aop_cfg, dict) and _apply_to_match(aop_cfg.get("apply_to", "both"), side)
        side_ok_vp = isinstance(vpool_cfg, dict) and _apply_to_match(vpool_cfg.get("apply_to", "both"), side)

        # AOP runs during the mid pass only when its target layer lies at or
        # before the early-exit layer.
        aop_layer = aop_cfg.get("layer_idx", None) if isinstance(aop_cfg, dict) else None
        aop_on_mid = bool(isinstance(aop_cfg, dict) and aop_cfg.get("enabled", False) and side_ok_aop and (aop_layer is not None) and (aop_layer <= target_layer_idx))

        if profile_enabled:
            torch.cuda.synchronize()
            t0_mid = time.perf_counter()

        # --- Mid pass: encode up to target_layer_idx, keeping resume state. ---
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            with _temp_cfg_enabled(model, "aop_prune_config", aop_on_mid):
                with _temp_cfg_enabled(model, "vision_pooling_config", bool(isinstance(vpool_cfg, dict) and vpool_cfg.get("enabled", False) and side_ok_vp)):
                    # Older backbones do not accept compute_lm_head; retry without it.
                    try:
                        out_mid = model.encoder(
                            **inputs,
                            return_dict=True,
                            output_hidden_states=False,
                            stop_at_layer=target_layer_idx,
                            compute_lm_head=False,
                            return_intermediate_state=True,
                        )
                    except TypeError:
                        out_mid = model.encoder(
                            **inputs,
                            return_dict=True,
                            output_hidden_states=False,
                            stop_at_layer=target_layer_idx,
                            return_intermediate_state=True,
                        )

        # Post-prune masks emitted by the encoder (None if not provided).
        post_attn_mid = getattr(out_mid, "attention_mask", None)
        img_mask_mid = getattr(out_mid, "image_token_bool_masks", None)
        txt_mask_mid = getattr(out_mid, "text_token_bool_masks", None)

        # Token-retention counts per sample. NOTE(review): these locals are
        # computed but never consumed below — presumably left for debugging.
        if post_attn_mid is not None:
            keep_counts = post_attn_mid.to(torch.bool).sum(dim=1)
            if txt_mask_mid is not None:
                text_keep = (post_attn_mid.to(torch.bool) & txt_mask_mid.to(torch.bool)).sum(dim=1)
            if img_mask_mid is not None:
                vis_keep = (post_attn_mid.to(torch.bool) & img_mask_mid.to(torch.bool)).sum(dim=1)

        if profile_enabled:
            torch.cuda.synchronize()
            t1_mid = time.perf_counter()
            timing_stats["mid_time_sum"] += (t1_mid - t0_mid) * B
            timing_stats["mid_num"] += B

        hs_mid = getattr(out_mid, "last_hidden_state", None)
        if hs_mid is None:
            hs_mid = out_mid[0]
        reps_mid_t = _pool_from_outputs(model, hs_mid, out_mid, inputs.get("attention_mask", None)).detach().to(dtype=torch.bfloat16)

        # Profiling reference: full (no early exit) last-layer embeddings.
        reps_last_full = None
        if profile_enabled:
            with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                try:
                    out_full = model.encoder(
                        **inputs, return_dict=True, output_hidden_states=False,
                        stop_at_layer=None, compute_lm_head=False,
                        return_intermediate_state=False
                    )
                except TypeError:
                    out_full = model.encoder(
                        **inputs, return_dict=True, output_hidden_states=False,
                        stop_at_layer=None,
                        return_intermediate_state=False
                    )
                hs_last_full = getattr(out_full, "last_hidden_state", None)
                if hs_last_full is None:
                    hs_last_full = out_full[0]
                reps_last_full = _pool_from_outputs(model, hs_last_full, out_full, inputs.get("attention_mask", None)).detach().to(dtype=torch.bfloat16)

        # --- Exit decision: classifier predicts p(need last layer) per query. ---
        exit_mask = np.zeros(B, dtype=bool)
        p_need_last_batch = None

        if method == "classifier" and classifier is not None:
            with torch.no_grad():
                # Mid-layer similarity of each query against all candidates.
                cos_mid = reps_mid_t @ cand_mid_t.T
                backbone_ptr = model.module if hasattr(model, "module") else model
                temp = getattr(backbone_ptr, "temperature", 0.02)
                scores_mid = cos_mid / temp
                probs_mid = torch.softmax(scores_mid, dim=1)

                # Feature 1-5: top-1 score, runner-up, margin, z-/MAD-normalized margin.
                diag_cos = cos_mid.max(dim=1)[0]
                sorted_cos, _ = torch.sort(cos_mid, dim=1, descending=True)
                s2_cos = sorted_cos[:, 1] if sorted_cos.size(1) > 1 else sorted_cos[:, 0]
                margin_mid = diag_cos - s2_cos

                margin_mean = margin_mid.mean()
                margin_std = margin_mid.std(unbiased=False) + 1e-6
                z_margin_mid = (margin_mid - margin_mean) / margin_std

                margin_median = margin_mid.median()
                mad = (margin_mid - margin_median).abs().median() + 1e-6
                mad_margin_mid = (margin_mid - margin_median) / mad

                # Distribution confidence: max prob, entropy, Gini impurity.
                p1_mid = probs_mid.max(dim=1)[0]
                H_mid = -(probs_mid * torch.log(probs_mid + 1e-6)).sum(dim=1)
                gini_mid = 1.0 - (probs_mid ** 2).sum(dim=1)

                # Top-k probability shape statistics (mean/std/CV/kurtosis/median).
                TOPK = min(16, probs_mid.size(1))
                topk_vals, _ = torch.topk(probs_mid, k=TOPK, dim=1)
                topk_mean = topk_vals.mean(dim=1)
                topk_std = topk_vals.std(dim=1, unbiased=False)
                topk_cv = topk_std / (topk_mean + 1e-6)

                centered = topk_vals - topk_mean.unsqueeze(1)
                var = (centered ** 2).mean(dim=1) + 1e-6
                m4 = (centered ** 4).mean(dim=1)
                topk_kurt = m4 / (var ** 2)
                topk_med = topk_vals.median(dim=1).values

                # Top-1 score relative to the row mean/median of cosines.
                row_mean_cos = cos_mid.mean(dim=1)
                row_med_cos = cos_mid.median(dim=1).values
                s1_over_mean = diag_cos - row_mean_cos
                s1_over_med = diag_cos - row_med_cos

                sorted_probs, _ = torch.sort(probs_mid, dim=1, descending=True)
                p1 = sorted_probs[:, 0]
                p2 = sorted_probs[:, 1] if sorted_probs.size(1) > 1 else sorted_probs[:, 0]

                shape_H = -(sorted_probs * torch.log(sorted_probs + 1e-6)).sum(dim=1)
                shape_gini = 1.0 - (sorted_probs ** 2).sum(dim=1)

                # Log-prob decay slope over the first R ranks (least-squares fit).
                R = min(10, sorted_probs.size(1))
                x = torch.arange(R, device=device, dtype=sorted_probs.dtype)
                x_centered = x - x.mean()
                denom = (x_centered ** 2).sum()
                y = torch.log(sorted_probs[:, :R] + 1e-6)
                slope = (x_centered.unsqueeze(0) * y).sum(dim=1) / denom

                # Row-normalized peak and skew-adjusted peak.
                row_mean_p = probs_mid.mean(dim=1)
                row_std_p = probs_mid.std(dim=1, unbiased=False) + 1e-6
                z1 = (p1_mid - row_mean_p) / row_std_p

                center_p = probs_mid - row_mean_p.unsqueeze(1)
                m3 = (center_p ** 3).mean(dim=1)
                skew = m3 / (row_std_p ** 3 + 1e-6)
                s1_over_sk = p1_mid - skew

                # Head/tail mass of the sorted probability vector.
                TAIL_K = min(10, sorted_probs.size(1))
                tail_mean = sorted_probs[:, -TAIL_K:].mean(dim=1)
                HEAD_K = min(5, sorted_probs.size(1))
                head5_mean = sorted_probs[:, :HEAD_K].mean(dim=1)

                # Mask features are zero placeholders here — the classifier was
                # presumably trained with these slots; keep the width stable.
                mask_ratio = torch.zeros_like(diag_cos)
                mask_len = torch.zeros_like(diag_cos)
                mask_runs = torch.zeros_like(diag_cos)

                # Fixed 27-feature layout — order must match classifier training.
                scalar_inputs = torch.stack([
                    diag_cos, s2_cos, margin_mid, z_margin_mid, mad_margin_mid,
                    p1_mid, H_mid, gini_mid,
                    topk_mean, topk_std, topk_cv, topk_kurt, topk_med,
                    s1_over_mean, s1_over_med,
                    p1, p2, shape_H, shape_gini, slope, z1, s1_over_sk,
                    tail_mean, head5_mean,
                    mask_ratio, mask_len, mask_runs
                ], dim=1)

                # Modality flag per sample: 1 when pixel inputs are present.
                modality_idx = torch.zeros(B, dtype=torch.long, device=device)
                if "pixel_values" in inputs and inputs["pixel_values"] is not None:
                    pv = inputs["pixel_values"]
                    if isinstance(pv, list):
                        for i, item in enumerate(pv):
                            if item is not None: modality_idx[i] = 1
                    elif isinstance(pv, torch.Tensor) and pv.numel() > 0:
                        modality_idx.fill_(1)

                # Classifier runs in fp32 regardless of the bf16 encoder pass.
                scalar_inputs_f32 = scalar_inputs.float()
                qry_emb_f32 = reps_mid_t.float()

                logits = classifier(scalar_inputs_f32, modality_idx, qry_emb=qry_emb_f32)

                # p_need_last < threshold -> the mid embedding is good enough: exit.
                p_need_last = torch.sigmoid(logits)
                p_need_last_batch = p_need_last.squeeze(1)
                should_exit = p_need_last_batch < threshold
                exit_mask = should_exit.cpu().numpy()

            # Debug snapshot for the first few batches only.
            if stats["total"] <= B * 3 and is_main:
                print_master(
                    f"[EE Debug] Batch {stats['total']//B}: "
                    f"p_need_last mean={p_need_last_batch.mean().item():.4f}, "
                    f"Exit Rate={exit_mask.mean():.2%}, "
                    f"Top3: diag_cos={diag_cos.mean():.3f}, margin={margin_mid.mean():.3f}"
                )

        stats["exit"] += exit_mask.sum()

        # --- Split the batch: exited queries rank with mid embeddings now. ---
        exit_indices = np.where(exit_mask)[0]
        cont_indices = np.where(~exit_mask)[0]

        if len(exit_indices) > 0:
            reps_exit = reps_mid_t[exit_indices]
            if use_local:
                # Rank only within each query's own candidate subset.
                for i, idx in enumerate(exit_indices):
                    cand_local = infos[idx].get("cand_names", [])
                    if not cand_local: cids = []
                    else:
                        rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                        rows = [r for r in rows if r >= 0]
                        if len(rows) == 0: cids = []
                        else:
                            cmat_np = cand_mid_np[rows]
                            qry_np = reps_exit[i].detach().float().cpu().numpy()
                            scores_vec = np.dot(qry_np, cmat_np.T)
                            top_k = min(200, len(rows))
                            order_local = np.argsort(-scores_vec)[:top_k]
                            cids = [str(cand_local[o]) for o in order_local]
                    _record_result(results_dict, batch_start_idx + idx, infos[idx], cids)
            else:
                # Rank against the whole library; keep the top 200.
                reps_exit_np = reps_exit.detach().float().cpu().numpy()
                scores_full = np.dot(reps_exit_np, cand_mid_np.T)
                top_k = min(200, len(cand_ids))
                topk_inds = np.argsort(-scores_full, axis=1)[:, :top_k]
                for i, idx in enumerate(exit_indices):
                    cids = [cand_ids[k] for k in topk_inds[i]]
                    _record_result(results_dict, batch_start_idx + idx, infos[idx], cids)

        # --- Continue the non-exited queries through the remaining layers. ---
        if len(cont_indices) > 0:
            # Resume state saved by the mid pass; slice it down to the
            # continuing subset of the batch.
            interm = getattr(out_mid, "intermediate_state", None)
            hs, am, pos = interm["hidden_states"].detach(), interm["attention_mask"].detach(), interm["position_ids"].detach()
            vm, tm = interm.get("vision_mask", None), interm.get("text_mask", None)
            next_layer = int(interm["next_layer_idx"])

            resume_state_subset = {
                "hidden_states": hs[cont_indices], "attention_mask": am[cont_indices],
                "position_ids": pos[:, cont_indices, :],
                "vision_mask": vm[cont_indices] if vm is not None else None,
                "text_mask": tm[cont_indices] if tm is not None else None,
                "next_layer_idx": next_layer,
            }

            # If AOP's layer lies after the mid layer, pruning still has to
            # happen during the tail pass.
            if isinstance(aop_cfg, dict) and aop_cfg:
                aop_resume = dict(aop_cfg)
                aop_layer = aop_resume.get("layer_idx", None)

                if (aop_layer is not None) and (aop_layer <= target_layer_idx):
                    need_prune_in_tail = False
                else:
                    need_prune_in_tail = bool(_orig_enabled and side_enable)

                aop_resume["enabled"] = need_prune_in_tail
                # NOTE(review): this setattr is never restored, so the encoder's
                # aop_prune_config stays replaced across batches — confirm intended.
                setattr(model.encoder, "aop_prune_config", aop_resume)

            if profile_enabled:
                torch.cuda.synchronize()
                t0_tail = time.perf_counter()

            with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                side = "qry"
                side_ok_aop = isinstance(aop_cfg, dict) and _apply_to_match(aop_cfg.get("apply_to", "both"), side)
                side_ok_vp = isinstance(vpool_cfg, dict) and _apply_to_match(vpool_cfg.get("apply_to", "both"), side)

                tail_aop_enable = bool(isinstance(aop_cfg, dict) and aop_cfg.get("enabled", False) and side_ok_aop and need_prune_in_tail)

                with _temp_cfg_enabled(model, "aop_prune_config", tail_aop_enable):
                    with _temp_cfg_enabled(model, "vision_pooling_config", bool(isinstance(vpool_cfg, dict) and vpool_cfg.get("enabled", False) and side_ok_vp)):
                        # Tail pass: no **inputs — everything comes from resume_state.
                        try:
                            out_last = model.encoder(
                                return_dict=True,
                                output_hidden_states=False,
                                stop_at_layer=None,
                                resume_state=resume_state_subset,
                                compute_lm_head=False,
                                return_intermediate_state=False,
                            )
                        except TypeError:
                            out_last = model.encoder(
                                return_dict=True,
                                output_hidden_states=False,
                                stop_at_layer=None,
                                resume_state=resume_state_subset,
                                return_intermediate_state=False,
                            )

            if profile_enabled:
                torch.cuda.synchronize()
                t1_tail = time.perf_counter()
                timing_stats["tail_time_sum"] += (t1_tail - t0_tail) * len(cont_indices)
                timing_stats["tail_num"] += len(cont_indices)

            hs_last = getattr(out_last, "last_hidden_state", None)
            if hs_last is None:
                hs_last = out_last[0]
            reps_last_t = _pool_from_outputs(model, hs_last, out_last, resume_state_subset["attention_mask"]).detach().to(dtype=torch.bfloat16)

            # Rank the continued queries with last-layer candidate embeddings.
            if use_local:
                for i, idx_global in enumerate(cont_indices):
                    cand_local = infos[idx_global].get("cand_names", [])
                    if not cand_local: cids = []
                    else:
                        rows = [cand_id2row.get(str(cid), -1) for cid in cand_local]
                        rows = [r for r in rows if r >= 0]
                        if len(rows) == 0: cids = []
                        else:
                            cmat_last_np = cand_last_np[rows]
                            qry_last_np = reps_last_t[i].detach().float().cpu().numpy()
                            scores_vec = np.dot(qry_last_np, cmat_last_np.T)
                            top_k = min(200, len(rows))
                            order_local = np.argsort(-scores_vec)[:top_k]
                            cids = [str(cand_local[o]) for o in order_local]
                    _record_result(results_dict, batch_start_idx + idx_global, infos[idx_global], cids)
            else:
                reps_last_np = reps_last_t.detach().float().cpu().numpy()
                scores_last = np.dot(reps_last_np, cand_last_np.T)
                top_k = min(200, len(cand_ids))
                topk_inds = np.argsort(-scores_last, axis=1)[:, :top_k]
                for i, idx_global in enumerate(cont_indices):
                    cids = [cand_ids[k] for k in topk_inds[i]]
                    _record_result(results_dict, batch_start_idx + idx_global, infos[idx_global], cids)

        # --- Profiling dump: per-query top-K under mid vs. full embeddings. ---
        if profile_enabled and is_main:
            K = min(topk_emb, cand_mid_t.size(0))
            q_mid_cpu = reps_mid_t.detach().float().cpu()
            q_last_cpu = (
                reps_last_full.detach().float().cpu()
                if reps_last_full is not None
                else None
            )
            cand_mid_cpu = cand_mid_t.detach().float().cpu()
            cand_last_cpu = cand_last_t.detach().float().cpu()

            scores_mid_full = q_mid_cpu @ cand_mid_cpu.T
            topk_mid_vals, topk_mid_inds = torch.topk(
                scores_mid_full, k=K, dim=1
            )

            if q_last_cpu is not None:
                scores_last_full = q_last_cpu @ cand_last_cpu.T
                topk_last_vals, topk_last_inds = torch.topk(
                    scores_last_full, k=K, dim=1
                )
            else:
                topk_last_vals = None
                topk_last_inds = None

            for i in range(B):
                qid = batch_start_idx + i
                rec = {
                    "qid": int(qid),
                    "early_exit": bool(exit_mask[i]),
                }
                if p_need_last_batch is not None:
                    rec["p_need_last"] = float(p_need_last_batch[i].item())

                mid_inds = topk_mid_inds[i].tolist()
                mid_scores = topk_mid_vals[i].tolist()
                rec["mid_topk_scores"] = mid_scores
                rec["mid_topk_cand_ids"] = [cand_ids[j] for j in mid_inds]

                if topk_last_inds is not None:
                    last_inds = topk_last_inds[i].tolist()
                    last_scores = topk_last_vals[i].tolist()
                    rec["last_topk_scores"] = last_scores
                    rec["last_topk_cand_ids"] = [
                        cand_ids[j] for j in last_inds
                    ]
                else:
                    rec["last_topk_scores"] = None
                    rec["last_topk_cand_ids"] = None

                analysis_records.append(rec)

    # Emit predictions in global sample order for metric computation.
    for idx in sorted(results_dict.keys()):
        pred_dicts.append(results_dict[idx])

    print_master(f"Early Exit Stats: Exit={stats['exit']}/{stats['total']} ({stats['exit']/stats['total']:.2%})")

    metrics_to_report = ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]
    score = RankingMetrics(metrics_to_report).evaluate(pred_dicts)

    if is_main:
        os.makedirs(out_dir, exist_ok=True)
        with open(os.path.join(out_dir, f"{dataset_name}_score_earlyexit.json"), "w") as f:
            json.dump(score, f, indent=4)

        if profile_enabled:
            prof_dir = os.path.join(out_dir, "profiling")
            os.makedirs(prof_dir, exist_ok=True)

            mid_avg = timing_stats["mid_time_sum"] / max(1, timing_stats["mid_num"])
            tail_avg = timing_stats["tail_time_sum"] / max(1, timing_stats["tail_num"])
            timing_out = {
                "mid_time_sum": timing_stats["mid_time_sum"],
                "mid_num": timing_stats["mid_num"],
                "tail_time_sum": timing_stats["tail_time_sum"],
                "tail_num": timing_stats["tail_num"],
                "avg_mid_time_per_query_sec": mid_avg,
                "avg_tail_time_per_cont_query_sec": tail_avg,
                "num_exit": int(stats["exit"]),
                "num_total": int(stats["total"]),
            }
            with open(os.path.join(prof_dir, f"{dataset_name}_timing.json"), "w") as f:
                json.dump(timing_out, f, indent=2)

            embed_path = os.path.join(prof_dir, f"{dataset_name}_embeds.jsonl")
            with open(embed_path, "w") as f:
                for rec in analysis_records:
                    f.write(json.dumps(rec) + "\n")
            print_master(f"[PROFILE] Saved timing to {prof_dir}, details to {embed_path}")

    elapsed = time.time() - start_time
    return score, elapsed
|
|
| def _record_result(results_dict, global_idx, info, cids): |
| label = info.get("label_name") or info.get("label") or info.get("rel_docids") |
| if not isinstance(label, list): label = [label] |
| rel_scores = info.get("rel_scores", None) |
| results_dict[global_idx] = { |
| "prediction": cids, "label": label, "rel_scores": rel_scores, |
| } |
|
|
| |
| |
| |
| |
def encode_candidates_both_layers(model: MMEBModel, loader: DataLoader, training_args: TrainingArguments, model_args: ModelArguments, full_dataset: Dataset, mid_layer: int):
    """
    Encode every candidate twice in a single forward sweep: once pooled at
    *mid_layer* (for early-exit scoring) and once at the final layer (resumed
    from the mid-layer intermediate state, so the prefix is computed only once).

    Returns (mid_embeddings ndarray, last_embeddings ndarray, candidate ids);
    empty arrays when the loader yields nothing. ``full_dataset`` is accepted
    for signature compatibility but not read here.
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    model.eval()
    all_mid, all_last, all_ids = [], [], []
    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"Candidates[BOTH] (rank {local_rank})", disable=local_rank > 0):
            inputs = batch_to_device(inputs, training_args.device)
            # Candidate side: AOP/VPOOL only apply if configured for "tgt"/"both".
            side = "tgt"
            aop_cfg = getattr(model.encoder, "aop_prune_config", None)
            vpool_cfg = getattr(model.encoder, "vision_pooling_config", None)
            side_ok_aop = isinstance(aop_cfg, dict) and _apply_to_match(aop_cfg.get("apply_to", "both"), side)
            side_ok_vp = isinstance(vpool_cfg, dict) and _apply_to_match(vpool_cfg.get("apply_to", "both"), side)

            # AOP runs in the mid segment only when its layer is <= mid_layer.
            aop_layer = aop_cfg.get("layer_idx", None) if isinstance(aop_cfg, dict) else None
            aop_on_mid = bool(isinstance(aop_cfg, dict) and aop_cfg.get("enabled", False) and side_ok_aop and (aop_layer is not None) and (aop_layer <= mid_layer))

            # --- Segment 1: embed up to mid_layer, keeping resume state. ---
            with _temp_cfg_enabled(model, "aop_prune_config", aop_on_mid):
                with _temp_cfg_enabled(model, "vision_pooling_config", bool(isinstance(vpool_cfg, dict) and vpool_cfg.get("enabled", False) and side_ok_vp)):
                    out_mid = model.encoder(
                        **inputs,
                        return_dict=True,
                        output_hidden_states=False,
                        stop_at_layer=mid_layer,
                        compute_lm_head=False,
                        return_intermediate_state=True,
                    )

            hs_mid = getattr(out_mid, "last_hidden_state", None)
            if hs_mid is None:
                hs_mid = out_mid[0]
            reps_mid = _pool_from_outputs(model, hs_mid, out_mid, inputs.get("attention_mask", None))

            # Resume state is mandatory here — without it we cannot continue to
            # the last layer from the mid-layer activations.
            interm = getattr(out_mid, "intermediate_state", None)
            if interm is None:
                raise RuntimeError("out_mid has no intermediate_state; stop_at_layer/resume_state not supported by this backbone.")

            resume_state = {
                "hidden_states": interm["hidden_states"],
                "attention_mask": interm["attention_mask"],
                "position_ids": interm["position_ids"],
                "vision_mask": interm.get("vision_mask", None),
                "text_mask": interm.get("text_mask", None),
                "special_mask": interm.get("special_mask", None),
                "next_layer_idx": int(interm["next_layer_idx"]),
            }

            # If AOP targets a layer past mid_layer, it still needs to run in
            # the tail segment.
            if (aop_layer is not None) and (aop_layer <= mid_layer):
                aop_on_tail = False
            else:
                aop_on_tail = bool(isinstance(aop_cfg, dict) and aop_cfg.get("enabled", False) and side_ok_aop)

            # --- Segment 2: resume from mid state through the final layer. ---
            with _temp_cfg_enabled(model, "aop_prune_config", aop_on_tail):
                with _temp_cfg_enabled(model, "vision_pooling_config", bool(isinstance(vpool_cfg, dict) and vpool_cfg.get("enabled", False) and side_ok_vp)):
                    out_last = model.encoder(
                        return_dict=True,
                        output_hidden_states=False,
                        stop_at_layer=None,
                        resume_state=resume_state,
                        compute_lm_head=False,
                        return_intermediate_state=False,
                    )

            hs_last = getattr(out_last, "last_hidden_state", None)
            if hs_last is None:
                hs_last = out_last[0]
            reps_last = _pool_from_outputs(model, hs_last, out_last, resume_state["attention_mask"])

            all_mid.append(reps_mid.detach().float().cpu())
            all_last.append(reps_last.detach().float().cpu())
            all_ids.extend([info["cand_name"] for info in dataset_info])
    if not all_mid: return np.array([]), np.array([]), []
    return torch.cat(all_mid, dim=0).numpy(), torch.cat(all_last, dim=0).numpy(), all_ids
|
|
| |
| |
| |
def main():
    """Entry point: distributed early-exit retrieval evaluation.

    For each dataset in the YAML config: encode candidates at both the
    early-exit layer and the final layer (cached to disk as pickles), then
    run the early-exit query loop on rank 0 only.
    """
    # Initialize NCCL when launched under torchrun (RANK is set by the launcher).
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
    local_rank = dist.get_rank() if dist.is_initialized() else 0

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    ee_cfg = get_env_ee_config()  # early-exit settings are read from env vars

    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    # Infer the backbone name from the HF config unless it was given explicitly.
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, 'model_backbone', model_backbone)

    processor = load_processor(model_args, data_args)

    model = MMEBModel.load(model_args, is_trainable=False, processor=processor)

    # Optionally restore a trained vision-pooling attention head into the encoder.
    _maybe_load_vpool_attn_pooler_into_encoder(model, model_args.model_name)

    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)

    # Attention-output pruning (AOP): set the config on both the wrapper
    # encoder and its base module so inner forward paths can see it.
    aop_cfg = get_env_aop_config()
    if aop_cfg["enabled"]:
        setattr(model.encoder, "aop_prune_config", aop_cfg)
        _set_attr_on_base(model.encoder, "aop_prune_config", aop_cfg)
        print_master(f"[AOP] Enabled: {aop_cfg['apply_to']} @ layer {aop_cfg['layer_idx']}")

    # Vision token pooling: same dual-attachment pattern as AOP above.
    vpool_cfg = get_env_vpool_config()
    if vpool_cfg["enabled"]:
        setattr(model.encoder, "vision_pooling_config", vpool_cfg)
        _set_attr_on_base(model.encoder, "vision_pooling_config", vpool_cfg)
        print_master(f"[VPOOL] Enabled: {vpool_cfg.get('apply_to')} @ layer {vpool_cfg.get('layer_idx')} "
                     f"method={vpool_cfg.get('method')} attn_tau={vpool_cfg.get('attn_tau', None)}")

    # Optional early-exit classifier that decides per-query whether the
    # mid-layer embedding suffices or the full forward pass is needed.
    classifier = None
    if ee_cfg["method"] == "classifier" and ee_cfg["enabled"]:
        classifier_path = ee_cfg['classifier_path']
        print_master(f"[EE] Loading Classifier from {classifier_path}...")

        backbone_hidden_size = model.encoder.config.hidden_size
        print_master(f"[EE] Backbone Hidden Size: {backbone_hidden_size}")

        # NOTE(review): input_dim=27 appears to be the fixed feature count the
        # classifier was trained with — confirm against classifier_utils_V5.
        classifier = EarlyExitClassifier(
            input_dim=27,
            hidden_dim=128,
            embedding_dim=backbone_hidden_size
        )

        # Resolve weights from several supported layouts, in priority order:
        # dir/model.safetensors -> dir/pytorch_model.bin ->
        # dir/early_exit_classifier_layer_{L}.pt -> a direct file path.
        state_dict = None
        if os.path.isdir(classifier_path):
            safetensors_file = os.path.join(classifier_path, "model.safetensors")
            if os.path.exists(safetensors_file):
                from safetensors.torch import load_file
                state_dict = load_file(safetensors_file)
            else:
                pt_file_bin = os.path.join(classifier_path, "pytorch_model.bin")
                if os.path.exists(pt_file_bin):
                    state_dict = torch.load(pt_file_bin, map_location=training_args.device)
                else:
                    layer_idx = ee_cfg.get('layer', 12)
                    pt_file = os.path.join(classifier_path, f"early_exit_classifier_layer_{layer_idx}.pt")
                    if os.path.exists(pt_file):
                        state_dict = torch.load(pt_file, map_location=training_args.device)

        elif os.path.isfile(classifier_path):
            state_dict = torch.load(classifier_path, map_location=training_args.device)

        if state_dict is not None:
            classifier.load_state_dict(state_dict)
            classifier.to(training_args.device)
            classifier.eval()
            print_master(f"[EE] Classifier loaded successfully.")
        else:
            raise FileNotFoundError(f"Could not load classifier weights from {classifier_path}")

    # Iterate over all datasets listed in the YAML evaluation config.
    with open(data_args.dataset_config, 'r') as yaml_file: dataset_configs = yaml.safe_load(yaml_file)
    for dataset_name, task_config in dataset_configs.items():
        if dist.is_initialized(): dist.barrier()  # keep ranks in lockstep per dataset
        print_master(f"\n--- Evaluating {dataset_name} ---")
        # Per-dataset exit threshold overrides the global env threshold.
        base_tau = float(ee_cfg["threshold"])
        ds_tau = PER_DATASET_THRESHOLDS.get(dataset_name, base_tau)
        ee_cfg_ds = dict(ee_cfg); ee_cfg_ds["threshold"] = float(ds_tau)
        # Re-root all relative data paths under the shared base directory.
        if data_args.data_basedir:
            for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                if task_config.get(key): task_config[key] = os.path.join(data_args.data_basedir, task_config[key])

        full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
        full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)

        mid_layer = int(ee_cfg_ds["layer"])

        # Cache filenames are tagged with the current AOP/vpool config so a
        # config change never reuses stale embeddings.
        aop_cfg_now = getattr(model.encoder, "aop_prune_config", None)
        vpool_cfg_now = getattr(model.encoder, "vision_pooling_config", None)
        tag = _cfg_cache_tag(aop_cfg_now if isinstance(aop_cfg_now, dict) else None,
                             vpool_cfg_now if isinstance(vpool_cfg_now, dict) else None,
                             mid_layer)
        cand_mid_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layer{mid_layer}_{tag}")
        cand_last_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_layerlast_{tag}")

        # Encode candidates only on a cache miss; rank 0 writes, then all
        # ranks synchronize before anyone reads the files back.
        if (not os.path.exists(cand_mid_path)) or (not os.path.exists(cand_last_path)):
            eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
            eval_cand_loader = DataLoader(full_eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers)

            cand_mid, cand_last, cand_ids = encode_candidates_both_layers(model, eval_cand_loader, training_args, model_args, full_eval_cand_dataset, mid_layer)
            if local_rank == 0:
                with open(cand_mid_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_mid)}, f)
                with open(cand_last_path, "wb") as f: pickle.dump({cid: emb for cid, emb in zip(cand_ids, cand_last)}, f)
            if dist.is_initialized(): dist.barrier()

        # Query-side early-exit evaluation runs on rank 0 only.
        # NOTE: pickle.load is safe here only because the caches are produced
        # locally by this same script — never point these paths at untrusted files.
        if local_rank == 0:
            with open(cand_mid_path, "rb") as f: cand_mid_dict = pickle.load(f)
            with open(cand_last_path, "rb") as f: cand_last_dict = pickle.load(f)
            rank_global = task_config.get("eval_type", "global") == "global"

            run_early_exit_queries(
                model, classifier, processor, model_args, data_args, training_args,
                full_eval_qry_dataset, cand_mid_dict, cand_last_dict,
                ee_cfg_ds, dataset_name, data_args.encode_output_path,
                global_ranking=rank_global,
            )
        if dist.is_initialized(): dist.barrier()
|
# Script entry point (works both standalone and under torchrun).
if __name__ == '__main__':
    main()