import datetime import logging import json import random import time import numpy as np import os import pickle import sys import torch import torch.distributed as dist import torch.nn.functional as F import yaml import transformers import json import datetime from torch.utils.data import DataLoader from tqdm import tqdm from transformers import HfArgumentParser, AutoConfig, AutoTokenizer from datasets import Dataset, concatenate_datasets from datasets.distributed import split_dataset_by_node from src.arguments import ModelArguments, DataArguments, TrainingArguments from src.data.collator.eval_collator import MultimodalEvalDataCollator from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset from src.eval_utils.metrics import RankingMetrics from src.model.model_multilayer_AOP_infer import MMEBModel from src.model.processor import get_backbone_name, load_processor, COLPALI from src.utils import batch_to_device, print_rank, print_master from dataclasses import dataclass def get_env_mid_layer(): v = os.environ.get("MID_LM_LAYER", "").strip() if v == "" or v.lower() in {"none", "null"}: return None try: return int(v) except: logger.warning(f"Invalid MID_LM_LAYER={v}, ignore.") return None # ------------- AOP-Prune config parsing ------------- def _parse_bool(v: str, default=False): if v is None: return default v = v.strip().lower() return v in {"1","true","yes","y","t","on"} def _parse_float(v: str, default=None): try: return float(v) if v is not None else default except: return default def _parse_int(v: str, default=None): try: return int(v) if v is not None else default except: return default def get_env_aop_config(): """ 从环境变量读取 AOP 剪裁配置。仅作为“驱动层”的简要测试开关; 实际剪裁逻辑在底模里(Qwen2-VLModel.forward)实现。 """ enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False) apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower() # qry|cand|both layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None) mode = os.environ.get("AOP_MODE", "delta").strip().lower() # 通用回退 delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10) khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0) keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0) min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64) use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True) # 按类型控制 prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True) prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False) delta_v = _parse_float(os.environ.get("AOP_DELTA_VISION"), None) khat_v = _parse_float(os.environ.get("AOP_KHAT_VISION"), None) keep_ratio_v= _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None) min_keep_v = _parse_int(os.environ.get("AOP_MIN_KEEP_VISION"), None) delta_t = _parse_float(os.environ.get("AOP_DELTA_TEXT"), None) khat_t = _parse_float(os.environ.get("AOP_KHAT_TEXT"), None) keep_ratio_t= _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None) min_keep_t = _parse_int(os.environ.get("AOP_MIN_KEEP_TEXT"), 32) protect_text_last = _parse_int(os.environ.get("AOP_PROTECT_TEXT_LAST"), 16) protect_special = _parse_bool(os.environ.get("AOP_PROTECT_SPECIAL"), True) margin_src = os.environ.get("AOP_MARGIN", "").strip().lower() # "" or "mid" attn_impl = os.environ.get("AOP_ATTN_IMPL", "").strip().lower() # "" or "sdpa" if layer_idx is None and enabled: logger.warning("AOP_ENABLED=1 但未设置 AOP_LAYER,关闭 AOP。"); enabled=False # 新增:选择策略(aop | random) selection = os.environ.get("AOP_SELECTION", "aop").strip().lower() if _parse_bool(os.environ.get("AOP_RANDOM"), False): selection = "random" random_seed = _parse_int(os.environ.get("AOP_RANDOM_SEED"), None) # 选择策略:aop | random | attention selection = os.environ.get("AOP_SELECTION", "aop").strip().lower() if _parse_bool(os.environ.get("AOP_RANDOM"), False): selection = "random" random_seed = _parse_int(os.environ.get("AOP_RANDOM_SEED"), None) attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower() # mean|max|sum cfg = { "enabled": enabled, "apply_to": apply_to, "layer_idx": layer_idx, "mode": mode, # 回退 "delta": delta, "K_hat": khat, "keep_ratio": keep_ratio, "min_keep": min_keep, "use_bias": use_bias, "eps": 1e-6, # 类型开关 "prune_vision": prune_vision, "prune_text": prune_text, # 视觉桶 "delta_vision": delta_v, "K_hat_vision": khat_v, "keep_ratio_vision": keep_ratio_v, "min_keep_vision": min_keep_v, # 文本桶 "delta_text": delta_t, "K_hat_text": khat_t, "keep_ratio_text": keep_ratio_t, "min_keep_text": min_keep_t, # 文本保护 "protect_text_last": protect_text_last, "protect_special": protect_special, # 可选:排名安全预算 "margin_mid": None if margin_src != "mid" else "USE_MID_MARGIN", "epsilon_hat": None, "attn_impl_override": attn_impl if attn_impl in {"sdpa"} else "", # NEW: 选择策略 "selection": selection, # "aop" 或 "random" "random_seed": random_seed, # 可选 "attn_agg": attn_agg, } return cfg def _b(k, d=False): v = os.environ.get(k) if v is None: return d return str(v).strip().lower() in {"1","true","yes","y","on","t"} def _i(k, d=None): try: return int(os.environ.get(k, d)) except: return d def _s(k, d=None): v = os.environ.get(k, d) return None if v is None else str(v).strip().lower() def get_env_vpool_config(): return { "enabled": _b("VPOOL_ENABLED", False), "apply_to": _s("VPOOL_APPLY", "both"), # qry|cand|both "layer_idx": _i("VPOOL_LAYER", 1), "kernel": _i("VPOOL_KERNEL", 2), "stride": _i("VPOOL_STRIDE", None) or _i("VPOOL_KERNEL", 2), "method": _s("VPOOL_METHOD", "avg"), # avg|max|linear|conv "protect_cls": _b("VPOOL_PROTECT_CLS", True), "vision_only": _b("VPOOL_ONLY_VISION", True), "monitor": _b("VPOOL_MONITOR", False), } def get_env_eval_layers(): """ 解析环境变量 LM_LAYERS(优先)或兼容旧的 MID_LM_LAYER。 - LM_LAYERS 示例:"4,8,12,last";可包含 'last'/'none'/'null'/'-1' 表示最后一层(None)。 - 若未设置 LM_LAYERS,则回落到旧逻辑:MID_LM_LAYER=None -> [None];否则 [mid, None] 返回: list[ int | None ],例如 [4, 8, 12, None];None 代表最后一层。 """ v = os.environ.get("LM_LAYERS", None) if v is not None: v = v.strip() if v: toks = [t.strip() for t in v.split(',') if t.strip() != ""] layers = [] for tok in toks: tl = tok.lower() if tl in {"last", "none", "null", "-1"}: layers.append(None) else: try: val = int(tok) if val > 0: layers.append(val) else: logger.warning(f"Ignoring non-positive layer '{tok}' in LM_LAYERS.") except Exception: logger.warning(f"Invalid token '{tok}' in LM_LAYERS; must be int or 'last'/'none'.") # 去重但保持顺序 seen = set() uniq = [] for l in layers: key = -1 if l is None else l if key in seen: continue seen.add(key) uniq.append(l) if not uniq: return [None] return uniq else: # 兼容旧逻辑 mid = get_env_mid_layer() return [None] if mid is None else [mid, None] def make_layer_tag(keep_layers: int | None): return f"layer{keep_layers}" if keep_layers and keep_layers > 0 else "layerlast" def dot_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray: # a: [Nq, D], b: [Nc, D], both L2-normalized already if normalize=true return a @ b.T def build_score_details(qid: int, cand_ids: list, score_vec: np.ndarray, ranked_indices: np.ndarray): return { "qid": int(qid), "cand_scores": [ {"cand_id": str(cand_ids[i]), "score": float(score_vec[i])} for i in ranked_indices ] } def top1_top2_margin(score_vec: np.ndarray) -> float: if len(score_vec) < 2: return float("inf") # 只有一个候选时视作极大margin top2 = np.partition(score_vec, -2)[-2:] top2.sort() return float(top2[-1] - top2[-2]) def simulate_early_exit_by_margin( sims_mid: list[dict], sims_last: list[dict], labels: list[list[str]], metrics_to_report: list[str], taus: list[float], rank_global: bool ): """ sims_mid / sims_last: 每个query一个dict: {cand_id: score} labels: 每个query的正样本cand_id列表 返回:不同tau下的覆盖率、指标 """ assert len(sims_mid) == len(sims_last) == len(labels) N = len(labels) results = [] from src.eval_utils.metrics import RankingMetrics metrics = RankingMetrics(metrics_to_report) # 预构造 用于metrics.evaluate 的pred_dict def to_pred_dicts(use_mid_mask: list[bool]) -> list[dict]: pred_dicts = [] for qid in range(N): sims_use = sims_mid[qid] if use_mid_mask[qid] else sims_last[qid] # 排序 ranked = sorted(sims_use.items(), key=lambda x: -x[1]) pred_dicts.append({ "prediction": [cid for cid, _ in ranked], "label": labels[qid], "rel_scores": None }) return pred_dicts # 计算中间层margin margins = [] for qid in range(N): # 取前两大分数的margin if len(sims_mid[qid]) == 0: margins.append(0.0) continue scores = np.array(list(sims_mid[qid].values()), dtype=np.float32) margins.append(top1_top2_margin(scores)) margins = np.array(margins, dtype=np.float32) for tau in taus: use_mid_mask = (margins >= tau).tolist() pred_dicts = to_pred_dicts(use_mid_mask) score_dict = metrics.evaluate(pred_dicts) coverage = float(np.mean(use_mid_mask)) # 早停覆盖率 results.append({ "tau": tau, "coverage": coverage, **score_dict }) return results def top1_top2_margin_from_array(score_vec: np.ndarray) -> float: if score_vec is None or len(score_vec) == 0: return 0.0 if len(score_vec) == 1: return float('inf') # 取前两大 top2 = np.partition(score_vec, -2)[-2:] top2.sort() return float(top2[-1] - top2[-2]) logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s') logger = logging.getLogger(__name__) # --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) --- timing_info = {} token_info = { "vision_tokens": 0, "text_input_tokens": 0, # Refers to the original text token count "text_output_tokens": 0, # Not directly applicable here as we are encoding, not generating. Will be 0. "total_llm_input_tokens": 0, # Refers to the total tokens LLM receives (visual + formatted text) } # --- Hook Functions Definition --- def timing_pre_hook(module, input): module_id = id(module) if module_id not in timing_info: timing_info[module_id] = [] timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__)) def timing_post_hook(module, input, output): module_id = id(module) if module_id not in timing_info: # print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})") return timing_info[module_id].append((time.time(), 'post', module.__class__.__name__)) # Collect vision token count (only from Vision Transformer module's post hook) module_name = module.__class__.__name__ if "vision" in module_name.lower() and "transformer" in module_name.lower(): if isinstance(output, torch.Tensor): token_info["vision_tokens"] = output.shape[0] # For visual features, usually (batch_size, num_tokens, hidden_dim) elif hasattr(output, 'last_hidden_state'): token_info["vision_tokens"] = output.last_hidden_state.shape[1] def register_model_hooks(model): registered_modules = [] core_model = model # print_master(f"DEBUG: Initial model type in register_model_hooks: {type(model)}") # Vision module if hasattr(core_model, 'visual') and core_model.visual is not None: vision_module = core_model.visual vision_module.register_forward_pre_hook(timing_pre_hook) vision_module.register_forward_hook(timing_post_hook) registered_modules.append(vision_module) print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}") else: print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).") # Merger module (if inside visual) - it's part of the vision component if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None: merger_module = core_model.visual.merger merger_module.register_forward_pre_hook(timing_pre_hook) merger_module.register_forward_hook(timing_post_hook) registered_modules.append(merger_module) print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}") else: print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).") # Language model body if hasattr(core_model, 'model') and core_model.model is not None: llm_main_module = core_model.model llm_main_module.register_forward_pre_hook(timing_pre_hook) llm_main_module.register_forward_hook(timing_post_hook) registered_modules.append(llm_main_module) print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}") else: print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).") # LM Head if hasattr(core_model, 'lm_head') and core_model.lm_head is not None: lm_head_module = core_model.lm_head lm_head_module.register_forward_pre_hook(timing_pre_hook) lm_head_module.register_forward_hook(timing_post_hook) registered_modules.append(lm_head_module) print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}") else: print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).") if not registered_modules: print_master("Warning: No major modules found for hook registration. Check model architecture.") return registered_modules def pad_dataset_to_divisible(dataset, world_size): num_samples = len(dataset) if num_samples % world_size == 0: return dataset, num_samples num_to_add = world_size - (num_samples % world_size) padded_size = num_samples + num_to_add padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)]) padded_dataset = concatenate_datasets([dataset, padding_data]) return padded_dataset, padded_size def encode_embeddings( model: MMEBModel, loader: DataLoader, training_args: TrainingArguments, model_args: ModelArguments, full_dataset: Dataset, encode_side: str, description: str = "Encoding" ) -> tuple[np.ndarray, list, list, list]: # CHANGED: + list for img_token_masks """ Encodes embeddings for a given dataset using the model, handling both standard and late-interaction models in a DDP-safe manner. Returns: - embeddings: np.ndarray - infos_or_ids: list - batch_stats_list: list - img_token_masks: list[None | list[bool]] # NEW """ local_rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 # Check if the model is a late-interaction type is_late_interaction = (model_args.model_backbone == COLPALI) local_embeds = [] local_gt_infos = [] local_max_len = 0 # --- New: List to store statistics for each batch --- batch_stats_list = [] # --- NEW: Collect masks --- local_img_token_masks = [] # post image mask per sample local_txt_token_masks = [] # NEW: post text mask per sample local_post_attn_masks = [] # NEW: post attention_mask per sample (after prune, 1/0) # --- NEW: per-sample token reduction records --- local_token_records = [] # 每条样本一个 dict,含 pre/post/delta 数量 model.eval() # Register hooks for the model once per encode_embeddings call registered_hooks = register_model_hooks(model) # --- NEW: helpers to取mask并序列化 --- def _search_key(obj, key: str): # 递归搜索 dict/list/tuple,找到指定 key if isinstance(obj, dict): if key in obj: return obj[key] for v in obj.values(): r = _search_key(v, key) if r is not None: return r elif isinstance(obj, (list, tuple)): for v in obj: r = _search_key(v, key) if r is not None: return r return None def _to_serializable_mask_list(mask_list, batch_size: int): # 将模型返回的 mask(list/tensor/ndarray/None)转成 [None | list[bool]] * B if mask_list is None: return [None] * batch_size out = [] if isinstance(mask_list, (list, tuple)): for m in mask_list: if m is None: out.append(None) elif torch.is_tensor(m): out.append(m.detach().cpu().tolist()) elif isinstance(m, np.ndarray): out.append(m.tolist()) else: # already python list/bool out.append(m) elif torch.is_tensor(mask_list): # 若是 2D 张量(B, L),直接 tolist() -> list[list[bool/int]] out = mask_list.detach().cpu().tolist() elif isinstance(mask_list, np.ndarray): out = mask_list.tolist() else: # 未知类型,保守返回 None 占位 out = [None] * batch_size # 长度对齐 batch_size if isinstance(out, list): if len(out) < batch_size: out = out + [None] * (batch_size - len(out)) elif len(out) > batch_size: out = out[:batch_size] return out def _to_bool_lists(m, batch_size: int): lst = _to_serializable_mask_list(m, batch_size) # 归一化成 list[ list[bool] | None ] out = [] for x in lst: if x is None: out.append(None) else: # x 可能是 list[int] 或 list[bool] out.append([bool(int(v)) for v in x]) return out with torch.no_grad(): for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0): # --- Reset statistics for each inference pass --- timing_info.clear() token_info["vision_tokens"] = 0 token_info["text_input_tokens"] = 0 token_info["text_output_tokens"] = 0 token_info["total_llm_input_tokens"] = 0 inputs = batch_to_device(inputs, training_args.device) current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs and inputs['input_ids'] is not None else 1 with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"): start_inference_time = time.time() # ---- NEW: 按侧开/关 AOP ---- aop_cfg = getattr(model.encoder, "aop_prune_config", None) _orig_enabled = None if isinstance(aop_cfg, dict) and aop_cfg: _orig_enabled = aop_cfg.get("enabled", False) apply_to = aop_cfg.get("apply_to", "qry") side_enable = (apply_to == "both") or (apply_to == encode_side) aop_cfg["enabled"] = bool(side_enable and _orig_enabled) setattr(model.encoder, "aop_prune_config", aop_cfg) if encode_side == "qry": output = model(qry=inputs) reps = output["qry_reps"].detach() local_gt_infos.extend(dataset_info) else: output = model(tgt=inputs) reps = output["tgt_reps"].detach() local_gt_infos.extend([info["cand_name"] for info in dataset_info]) # ---- NEW: 恢复 enabled(避免影响下个 encode_side)---- if isinstance(aop_cfg, dict) and _orig_enabled is not None: aop_cfg["enabled"] = _orig_enabled setattr(model.encoder, "aop_prune_config", aop_cfg) end_inference_time = time.time() # --- NEW: 提取 post-prune 的 image/text 掩码 与 post attention_mask --- img_masks_raw = None txt_masks_raw = None post_attn_raw = None if isinstance(output, dict): img_masks_raw = _search_key(output, "image_token_bool_masks") txt_masks_raw = _search_key(output, "text_token_bool_masks") # NEW post_attn_raw = _search_key(output, "post_attention_mask") if post_attn_raw is None: # 兼容 mldaop 变体:有些只返回 attention_mask post_attn_raw = _search_key(output, "attention_mask") # 兼容:若挂在 model 上 if img_masks_raw is None and hasattr(model, "image_token_bool_masks"): img_masks_raw = getattr(model, "image_token_bool_masks") if txt_masks_raw is None and hasattr(model, "text_token_bool_masks"): txt_masks_raw = getattr(model, "text_token_bool_masks") if post_attn_raw is None and hasattr(model, "post_attention_mask"): post_attn_raw = getattr(model, "post_attention_mask") img_masks_serializable = _to_serializable_mask_list(img_masks_raw, current_batch_size) txt_masks_serializable = _to_serializable_mask_list(txt_masks_raw, current_batch_size) # NEW post_attn_serializable = _to_serializable_mask_list(post_attn_raw, current_batch_size) # NEW local_img_token_masks.extend(img_masks_serializable) local_txt_token_masks.extend(txt_masks_serializable) # NEW local_post_attn_masks.extend(post_attn_serializable) # NEW # --- NEW: 计算本 batch 的 pre/post/delta 数量并累计 --- cfg = getattr(model.encoder, "config", None) # pre masks 来自 inputs(删前) input_ids = inputs.get("input_ids", None) attn2d_pre = inputs.get("attention_mask", None) if input_ids is None or attn2d_pre is None or cfg is None: # 无法统计,留空 pre_vis_counts = [0] * current_batch_size pre_txt_counts = [0] * current_batch_size pre_tot_counts = [0] * current_batch_size else: iid = input_ids am = attn2d_pre.to(torch.bool) image_token_id = getattr(cfg, "image_token_id", None) video_token_id = getattr(cfg, "video_token_id", None) bos_id = getattr(cfg, "bos_token_id", None) eos_id = getattr(cfg, "eos_token_id", None) pad_id = getattr(cfg, "pad_token_id", None) is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool) is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool) is_vision = is_image | is_video is_special = torch.zeros_like(iid, dtype=torch.bool) for tid in [bos_id, eos_id, pad_id]: if tid is not None and tid >= 0: is_special |= (iid == tid) pre_txt_mask = am & (~is_vision) & (~is_special) pre_vis_mask = am & is_vision pre_vis_counts = pre_vis_mask.sum(dim=1).tolist() pre_txt_counts = pre_txt_mask.sum(dim=1).tolist() pre_tot_counts = am.sum(dim=1).tolist() # post masks(删后)来自模型输出;与 post_attn 做与运算 post_text_masks = _to_bool_lists(txt_masks_raw, current_batch_size) # list[ list[bool] | None ] post_image_masks = _to_bool_lists(img_masks_raw, current_batch_size) post_attn_masks = _to_bool_lists(post_attn_raw, current_batch_size) sum_pre_text = 0; sum_post_text = 0 sum_pre_vis = 0; sum_post_vis = 0 sum_pre_tot = 0; sum_post_tot = 0 for i in range(current_batch_size): pre_text = int(pre_txt_counts[i]) if i < len(pre_txt_counts) else 0 pre_vis = int(pre_vis_counts[i]) if i < len(pre_vis_counts) else 0 pre_tot = int(pre_tot_counts[i]) if i < len(pre_tot_counts) else 0 # post 计数:mask 可能为 None m_text = post_text_masks[i] if post_text_masks is not None and i < len(post_text_masks) else None m_img = post_image_masks[i] if post_image_masks is not None and i < len(post_image_masks) else None m_attn = post_attn_masks[i] if post_attn_masks is not None and i < len(post_attn_masks) else None if m_attn is None: post_text = 0; post_vis = 0; post_tot = 0 else: # 与 attention_mask 后统计 True 的数 if m_text is not None: post_text = sum(1 for a, t in zip(m_attn, m_text) if a and t) else: post_text = 0 if m_img is not None: post_vis = sum(1 for a, v in zip(m_attn, m_img) if a and v) else: post_vis = 0 post_tot = sum(1 for a in m_attn if a) # 累计 batch 级 sum_pre_text += pre_text; sum_post_text += post_text sum_pre_vis += pre_vis; sum_post_vis += post_vis sum_pre_tot += pre_tot; sum_post_tot += post_tot # 保存 per-sample 记录(用于 JSONL) local_token_records.append({ "side": encode_side, "pre": {"text": pre_text, "vision": pre_vis, "total": pre_tot}, "post": {"text": post_text, "vision": post_vis, "total": post_tot}, "delta":{"text": pre_text - post_text, "vision": pre_vis - post_vis, "total": pre_tot - post_tot}, }) # --- Update total LLM input tokens after the model call --- if 'input_ids' in inputs and inputs['input_ids'] is not None: token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1] token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"] token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"]) # --- Collect and Store Batch Statistics --- batch_inference_time = end_inference_time - start_inference_time current_batch_stats = { "batch_size": current_batch_size, "total_inference_time_seconds": batch_inference_time, "module_inference_times": {}, "token_counts": { "visual_tokens": token_info["vision_tokens"], "language_input_tokens_raw": token_info["text_input_tokens"], "llm_total_input_tokens": token_info["total_llm_input_tokens"], "language_output_tokens": token_info["text_output_tokens"], } } current_batch_stats["token_reduction"] = { "sum_pre_text": sum_pre_text, "sum_post_text": sum_post_text, "sum_pre_vision": sum_pre_vis, "sum_post_vision": sum_post_vis, "sum_pre_total": sum_pre_tot, "sum_post_total": sum_post_tot, } # Calculate and store module timings for the current batch for module_obj in registered_hooks: module_id = id(module_obj) module_name = module_obj.__class__.__name__ times = timing_info.get(module_id, []) durations = [] pre_times = {} for t, event_type, _ in times: if event_type == 'pre': pre_times[module_id] = t elif event_type == 'post' and module_id in pre_times: duration = t - pre_times.pop(module_id) durations.append(duration) if durations: current_batch_stats["module_inference_times"][module_name] = { "total": sum(durations), "count": len(durations), "avg": sum(durations) / len(durations) } else: current_batch_stats["module_inference_times"][module_name] = { "total": 0.0, "count": 0, "avg": 0.0 } batch_stats_list.append(current_batch_stats) # --- Debug prints (optional) --- print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---") print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds") print_rank("--- Module Inference Timing Statistics ---") for module_name, stats in current_batch_stats["module_inference_times"].items(): print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s") print_rank("--- Token Count Statistics ---") print_rank(f"**视觉 token 数量**: {current_batch_stats['token_counts']['visual_tokens']}") print_rank(f"**语言输入 token 数量 (仅原始文本)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}") print_rank(f"**LLM总输入 token 数量 (包含视觉 + 格式化文本)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}") print_rank(f"**语言输出 token 数量**: {current_batch_stats['token_counts']['language_output_tokens']}") if is_late_interaction and reps.dim() == 3: local_max_len = max(local_max_len, reps.shape[1]) local_embeds.append(reps) if not local_embeds: # Handle cases where a rank gets no data return np.array([]), [], [], [] # CHANGED: 4个返回值 # === DDP Synchronization and Padding for Late-Interaction Models === if is_late_interaction: if dist.is_initialized(): # 1: global max length local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device) dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX) global_max_len = local_max_len_tensor.item() else: global_max_len = local_max_len # 2: pad to global max length padded_embeds = [] for reps_batch in local_embeds: if reps_batch.dim() == 3: B, L, H = reps_batch.shape padding_size = global_max_len - L padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0) padded_embeds.append(padded_batch) else: padded_embeds.append(reps_batch) embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous() else: embeds_tensor = torch.cat(local_embeds, dim=0).contiguous() # === Gather embeddings and keys from all ranks === if dist.is_initialized() and full_dataset.num_rows >= world_size: print_master(f"Gathering {encode_side} embeddings across all ranks...") # tensor gather output_shape = list(embeds_tensor.shape) output_shape[0] = full_dataset.num_rows embeds_tensor = embeds_tensor.to(training_args.device) gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device) dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor) final_embeddings = gathered_embeds_tensor.cpu().float().numpy() # object gather for infos and stats gathered_gt_infos = [None for _ in range(world_size)] dist.all_gather_object(gathered_gt_infos, local_gt_infos) all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys] gathered_batch_stats = [None for _ in range(world_size)] dist.all_gather_object(gathered_batch_stats, batch_stats_list) all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats] # --- NEW: gather masks --- gathered_masks = [None for _ in range(world_size)] dist.all_gather_object(gathered_masks, local_img_token_masks) all_img_token_masks = [m for rank_list in gathered_masks for m in rank_list] # NEW: gather text masks gathered_txt_masks = [None for _ in range(world_size)] dist.all_gather_object(gathered_txt_masks, local_txt_token_masks) all_txt_token_masks = [m for rank_list in gathered_txt_masks for m in rank_list] # NEW: gather post attention masks(如需) gathered_post_attn = [None for _ in range(world_size)] dist.all_gather_object(gathered_post_attn, local_post_attn_masks) all_post_attn_masks = [m for rank_list in gathered_post_attn for m in rank_list] # NEW: gather token records gathered_token_recs = [None for _ in range(world_size)] dist.all_gather_object(gathered_token_recs, local_token_records) all_token_records = [r for rank_list in gathered_token_recs for r in rank_list] else: all_gt_infos = local_gt_infos final_embeddings = embeds_tensor.cpu().float().numpy() all_batch_stats = batch_stats_list all_img_token_masks = local_img_token_masks # NEW all_txt_token_masks = local_txt_token_masks all_post_attn_masks = local_post_attn_masks all_token_records = local_token_records return final_embeddings, all_gt_infos, all_batch_stats, all_img_token_masks, all_txt_token_masks, all_token_records def main(): # ----------------------- Distributed init ----------------------- if "RANK" in os.environ and dist.is_available() and not dist.is_initialized(): dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60)) local_rank = dist.get_rank() if dist.is_initialized() else 0 world_size = dist.get_world_size() if dist.is_initialized() else 1 print_master("Distributed init debug info:") print_master(f"RANK: {os.environ.get('RANK')}") print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}") print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}") print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}") print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}") if dist.is_initialized(): print_rank(f"dist.get_rank(): {dist.get_rank()}") print_rank(f"dist.get_world_size(): {dist.get_world_size()}") # 兼容 torchrun 参数 for arg in sys.argv: if arg.startswith("--local-rank="): rank = arg.split("=")[1] sys.argv.remove(arg) sys.argv.append('--local_rank') sys.argv.append(rank) # ----------------------- Parse args ----------------------- parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() model_args: ModelArguments data_args: DataArguments training_args: TrainingArguments os.makedirs(data_args.encode_output_path, exist_ok=True) # 支持多层评测(优先 LM_LAYERS,兼容 MID_LM_LAYER) layers_to_eval = get_env_eval_layers() print_master(f"Eval layers (qry/tgt): {layers_to_eval}") # ----------------------- Model loading ----------------------- hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True) if not getattr(model_args, "model_backbone", None): model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type) setattr(model_args, 'model_backbone', model_backbone) setattr(training_args, 'model_backbone', model_backbone) print_master(f'Model Backbone: {model_args.model_backbone}') # 仅 rank0 下载,其他rank等待缓存 if local_rank == 0: processor = load_processor(model_args, data_args) model = MMEBModel.load(model_args, is_trainable=False, processor=processor) print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...") if torch.distributed.is_initialized(): torch.distributed.barrier() if local_rank != 0: print_rank(f"Loading the model from cache...") processor = load_processor(model_args, data_args) time.sleep(random.randint(2 * local_rank, 3 * local_rank)) model = MMEBModel.load(model_args, is_trainable=False, processor=processor) model.eval() model = model.to(training_args.device, dtype=torch.bfloat16) # ---- NEW: AOP 剪裁配置注入(驱动底模里已实现的 AOP 逻辑)---- aop_cfg = get_env_aop_config() if aop_cfg["enabled"]: # 1) 写到 LoRA wrapper setattr(model.encoder, "aop_prune_config", aop_cfg) # 2) 同步到底座(ForConditionalGeneration / Qwen2_5_VLModel) try: base = model.encoder.get_base_model() if hasattr(model.encoder, "get_base_model") else None if base is None and hasattr(model.encoder, "model"): base = model.encoder.model if base is not None: setattr(base, "aop_prune_config", aop_cfg) if hasattr(base, "model"): # ForConditionalGeneration.model -> Qwen2_5_VLModel setattr(base.model, "aop_prune_config", aop_cfg) except Exception as e: print_master(f"[AOP] warn: sync cfg to base failed: {e}") # 可选:覆盖注意力实现用于分析 attn_override = aop_cfg.get("attn_impl_override", "") if attn_override: try: if hasattr(model.encoder, "model") and hasattr(model.encoder.model, "config"): prev = model.encoder.model.config._attn_implementation model.encoder.model.config._attn_implementation = attn_override print_master(f"[AOP] override attn impl: {prev} -> {attn_override} (仅测试建议)") except Exception as e: print_master(f"[AOP] try override attn impl failed: {e}") print_master("[AOP] AOP-Prune enabled with config: " + json.dumps({ "apply_to": aop_cfg.get("apply_to"), "layer_idx": aop_cfg.get("layer_idx"), "mode": aop_cfg.get("mode"), "delta": aop_cfg.get("delta"), "K_hat": aop_cfg.get("K_hat"), "keep_ratio": aop_cfg.get("keep_ratio"), "min_keep": aop_cfg.get("min_keep"), "prune_text": aop_cfg.get("prune_text"), "prune_vision": aop_cfg.get("prune_vision"), "keep_ratio_text": aop_cfg.get("keep_ratio_text"), "keep_ratio_vision": aop_cfg.get("keep_ratio_vision"), "selection": aop_cfg.get("selection"), "attn_agg": aop_cfg.get("attn_agg"), }, ensure_ascii=False)) else: print_master("[AOP] disabled (set AOP_ENABLED=1 to enable)") # ---- NEW: Vision Pooling 配置注入 ---- vpool_cfg = get_env_vpool_config() if vpool_cfg["enabled"]: # 1) 写到 LoRA wrapper setattr(model.encoder, "vision_pooling_config", vpool_cfg) # 2) 同步到底座(ForConditionalGeneration / Qwen2_5_VLModel) try: base = model.encoder.get_base_model() if hasattr(model.encoder, "get_base_model") else None if base is None and hasattr(model.encoder, "model"): base = model.encoder.model if base is not None: setattr(base, "vision_pooling_config", vpool_cfg) if hasattr(base, "model"): # ForConditionalGeneration.model -> Qwen2_5_VLModel setattr(base.model, "vision_pooling_config", vpool_cfg) except Exception as e: print_master(f"[VPOOL] warn: sync cfg to base failed: {e}") print_master("[VPOOL] enabled with config: " + json.dumps({ "apply_to": vpool_cfg.get("apply_to"), "layer_idx": vpool_cfg.get("layer_idx"), "kernel": vpool_cfg.get("kernel"), "stride": vpool_cfg.get("stride"), "method": vpool_cfg.get("method"), "vision_only": vpool_cfg.get("vision_only"), "monitor": vpool_cfg.get("monitor"), }, ensure_ascii=False)) else: print_master("[VPOOL] disabled (set VPOOL_ENABLED=1 to enable)") # 确保“最后一层”时不裁层(避免类里默认20层的坑) model.set_inference_layers(qry_layers=None, tgt_layers=None) with open(data_args.dataset_config, 'r') as yaml_file: dataset_configs = yaml.safe_load(yaml_file) # ----------------------- Main evaluation loop ----------------------- for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()): if dist.is_initialized(): dist.barrier() print_master(f"\n--- Evaluating {dataset_name} ---") # 根据 data_basedir 修正路径 if data_args.data_basedir is not None: for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]: if data_args.data_basedir and task_config.get(key): task_config[key] = os.path.join(data_args.data_basedir, task_config[key]) # 构建数据集 full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config) full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus) eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset if dist.is_initialized(): world_size = dist.get_world_size() padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size) padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size) eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size) eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size) else: padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset # 路径索引 saved_paths = {} # {(side, tag): path} # --------- 针对每个层设置(中间层/最后一层)分别编码与保存 --------- for keep_layers in layers_to_eval: tag = make_layer_tag(keep_layers) print_master(f"[{dataset_name}] Start encoding for tag={tag} (keep_layers={keep_layers})") # 设置模型层数 model.set_inference_layers(qry_layers=keep_layers, tgt_layers=keep_layers) # 路径 query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_{tag}") cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{tag}") dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl") query_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_inference_stats_{tag}.json") cand_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_inference_stats_{tag}.json") qry_img_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_img_token_masks_{tag}.jsonl") cand_img_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_img_token_masks_{tag}.jsonl") # 追加四个新文件路径 qry_txt_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_text_token_masks_{tag}.jsonl") qry_token_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_token_stats_{tag}.jsonl") cand_txt_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_text_token_masks_{tag}.jsonl") cand_token_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_token_stats_{tag}.jsonl") saved_paths[("qry", tag)] = query_embed_path saved_paths[("tgt", tag)] = cand_embed_path do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path) do_cand = not os.path.exists(cand_embed_path) # 动态累计统计 def init_total_stats(): return { "total_inference_time_seconds": 0.0, "module_inference_times": {}, "token_counts": { "visual_tokens": 0, "language_input_tokens_raw": 0, "llm_total_input_tokens": 0, "language_output_tokens": 0, }, "reduction": { # NEW: 删前/删后数量累计 "sum_pre_text": 0, "sum_post_text": 0, "sum_pre_vision": 0, "sum_post_vision": 0, "sum_pre_total": 0, "sum_post_total": 0, }, "data_point_count": 0 } def accumulate_stats(total_stats, batch_stats): batch_size = batch_stats["batch_size"] total_stats["total_inference_time_seconds"] += batch_stats["total_inference_time_seconds"] # 模块时间 for mname, mstats in batch_stats["module_inference_times"].items(): if mname not in total_stats["module_inference_times"]: total_stats["module_inference_times"][mname] = {"total": 0.0, "count": 0} total_stats["module_inference_times"][mname]["total"] += mstats.get("total", 0.0) total_stats["module_inference_times"][mname]["count"] += mstats.get("count", 0) # 原始 token 统计(乘以 batch_size,是为了估计总量) total_stats["token_counts"]["visual_tokens"] += batch_stats["token_counts"]["visual_tokens"] * batch_size total_stats["token_counts"]["language_input_tokens_raw"] += batch_stats["token_counts"]["language_input_tokens_raw"] * batch_size total_stats["token_counts"]["llm_total_input_tokens"] += batch_stats["token_counts"]["llm_total_input_tokens"] * batch_size total_stats["token_counts"]["language_output_tokens"] += batch_stats["token_counts"]["language_output_tokens"] * batch_size total_stats["data_point_count"] += batch_size # NEW: 删减统计 red = batch_stats.get("token_reduction", None) if red is not None: for k in total_stats["reduction"].keys(): total_stats["reduction"][k] += int(red.get(k, 0)) def finalize_and_save_stats(total_stats, out_path, task_name, encode_side): if local_rank != 0: return if total_stats["data_point_count"] <= 0: print_master(f"No data processed for {task_name} [{encode_side}], skip saving stats.") return n = max(1, total_stats["data_point_count"]) red = total_stats["reduction"] pre_txt, post_txt = red["sum_pre_text"], red["sum_post_text"] pre_vis, post_vis = red["sum_pre_vision"], red["sum_post_vision"] pre_tot, post_tot = red["sum_pre_total"], red["sum_post_total"] avg_text_pruned = (pre_txt - post_txt) / n avg_vision_pruned = (pre_vis - post_vis) / n avg_total_pruned = (pre_tot - post_tot) / n avg_text_keep_ratio = (post_txt / pre_txt) if pre_txt > 0 else 1.0 avg_vision_keep_ratio = (post_vis / pre_vis) if pre_vis > 0 else 1.0 avg_total_keep_ratio = (post_tot / pre_tot) if pre_tot > 0 else 1.0 final_stats = { "task_name": task_name, "encode_side": encode_side, "data_point_count": total_stats["data_point_count"], "inference_times": { "total_inference_time_seconds": total_stats["total_inference_time_seconds"], "avg_inference_time_per_item_seconds": total_stats["total_inference_time_seconds"] / n, "module_average_times_per_call": {}, "module_total_times_seconds": {}, "module_calls_count": {}, }, "token_counts": { "total_visual_tokens": total_stats["token_counts"]["visual_tokens"], "avg_visual_tokens_per_item": total_stats["token_counts"]["visual_tokens"] / n, "total_language_input_tokens_raw": total_stats["token_counts"]["language_input_tokens_raw"], "avg_language_input_tokens_raw_per_item": total_stats["token_counts"]["language_input_tokens_raw"] / n, "total_llm_total_input_tokens": total_stats["token_counts"]["llm_total_input_tokens"], "avg_llm_total_input_tokens_per_item": total_stats["token_counts"]["llm_total_input_tokens"] / n, "total_language_output_tokens": total_stats["token_counts"]["language_output_tokens"], "avg_language_output_tokens_per_item": total_stats["token_counts"]["language_output_tokens"] / n, }, "token_reduction": { # NEW: 输出平均删减与保留比例 "avg_text_pruned_per_item": float(avg_text_pruned), "avg_vision_pruned_per_item": float(avg_vision_pruned), "avg_total_pruned_per_item": float(avg_total_pruned), "avg_text_keep_ratio": float(avg_text_keep_ratio), "avg_vision_keep_ratio": float(avg_vision_keep_ratio), "avg_total_keep_ratio": float(avg_total_keep_ratio), "sum_pre_text": int(pre_txt), "sum_post_text": int(post_txt), "sum_pre_vision": int(pre_vis), "sum_post_vision": int(post_vis), "sum_pre_total": int(pre_tot), "sum_post_total": int(post_tot), } } for mname, mstats in total_stats["module_inference_times"].items(): total = mstats.get("total", 0.0) count = mstats.get("count", 0) final_stats["inference_times"]["module_total_times_seconds"][mname] = total final_stats["inference_times"]["module_calls_count"][mname] = count final_stats["inference_times"]["module_average_times_per_call"][mname] = (total / count) if count > 0 else 0.0 with open(out_path, 'w', encoding='utf-8') as f: json.dump(final_stats, f, ensure_ascii=False, indent=4) print_master(f"[{task_name}] {encode_side} inference statistics saved to: {out_path}") # ------- Encode queries ------- if do_query: print_master(f"[{tag}] Encoding queries...") eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry") eval_qry_loader = DataLoader( eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_qry_collator, num_workers=training_args.dataloader_num_workers ) query_embeds, gt_infos, qry_batch_stats, qry_img_masks, qry_txt_masks, qry_token_records = encode_embeddings( model, eval_qry_loader, training_args, model_args, padded_qry_dataset, encode_side="qry", description=f"Queries[{tag}] for {dataset_name}" ) # 截断到真实长度 true_qry_len = len(full_eval_qry_dataset) query_embeds = query_embeds[:true_qry_len] gt_infos = gt_infos[:true_qry_len] qry_img_masks = qry_img_masks[:true_qry_len] qry_txt_masks = qry_txt_masks[:true_qry_len] # NEW qry_token_records = qry_token_records[:true_qry_len] # NEW # 累计统计并保存 qry_total_stats = init_total_stats() for bs in qry_batch_stats: accumulate_stats(qry_total_stats, bs) if local_rank == 0: with open(query_embed_path, 'wb') as f: pickle.dump(query_embeds, f) if not os.path.exists(dataset_info_path): with open(dataset_info_path, 'w') as f: for info in gt_infos: f.write(json.dumps(info) + '\n') # 保存 image masks with open(qry_img_masks_path, 'w', encoding='utf-8') as f: for i, m in enumerate(qry_img_masks): f.write(json.dumps({"index": i, "mask": m}, ensure_ascii=False) + "\n") # 保存 text masks(NEW) with open(qry_txt_masks_path, 'w', encoding='utf-8') as f: for i, m in enumerate(qry_txt_masks): f.write(json.dumps({"index": i, "mask": m}, ensure_ascii=False) + "\n") # 保存 per-sample token 统计(NEW) with open(qry_token_stats_path, 'w', encoding='utf-8') as f: for i, rec in enumerate(qry_token_records): f.write(json.dumps({"index": i, **rec}, ensure_ascii=False) + "\n") print_master(f"Saved query embeddings to {query_embed_path}") print_master(f"Saved query image token masks to {qry_img_masks_path}") print_master(f"Saved query text token masks to {qry_txt_masks_path}") print_master(f"Saved query token stats to {qry_token_stats_path}") finalize_and_save_stats(qry_total_stats, query_inference_stats_path, dataset_name, f"query[{tag}]") if dist.is_initialized(): dist.barrier() # ------- Encode candidates ------- if do_cand: print_master(f"[{tag}] Encoding candidates...") eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand") eval_cand_loader = DataLoader( eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size, collate_fn=eval_cand_collator, num_workers=training_args.dataloader_num_workers ) cand_embeds, all_cand_ids, cand_batch_stats, cand_img_masks, cand_txt_masks, cand_token_records = encode_embeddings( model, eval_cand_loader, training_args, model_args, padded_cand_dataset, encode_side="cand", description=f"Candidates[{tag}] for {dataset_name}" ) true_cand_len = len(full_eval_cand_dataset) cand_embeds = cand_embeds[:true_cand_len] all_cand_ids = all_cand_ids[:true_cand_len] cand_img_masks = cand_img_masks[:true_cand_len] cand_txt_masks = cand_txt_masks[:true_cand_len] # NEW cand_token_records = cand_token_records[:true_cand_len] # NEW cand_total_stats = init_total_stats() for bs in cand_batch_stats: accumulate_stats(cand_total_stats, bs) if local_rank == 0: cand_embed_dict = {cid: emb for cid, emb in zip(all_cand_ids, cand_embeds)} with open(cand_embed_path, 'wb') as f: pickle.dump(cand_embed_dict, f) with open(cand_img_masks_path, 'w', encoding='utf-8') as f: for cid, m in zip(all_cand_ids, cand_img_masks): f.write(json.dumps({"cand_id": str(cid), "mask": m}, ensure_ascii=False) + "\n") # 保存 text masks(NEW) with open(cand_txt_masks_path, 'w', encoding='utf-8') as f: for cid, m in zip(all_cand_ids, cand_txt_masks): f.write(json.dumps({"cand_id": str(cid), "mask": m}, ensure_ascii=False) + "\n") # 保存 per-sample token 统计(NEW) with open(cand_token_stats_path, 'w', encoding='utf-8') as f: for cid, rec in zip(all_cand_ids, cand_token_records): f.write(json.dumps({"cand_id": str(cid), **rec}, ensure_ascii=False) + "\n") print_master(f"Saved candidate embeddings to {cand_embed_path}") print_master(f"Saved candidate image token masks to {cand_img_masks_path}") print_master(f"Saved candidate text token masks to {cand_txt_masks_path}") print_master(f"Saved candidate token stats to {cand_token_stats_path}") finalize_and_save_stats(cand_total_stats, cand_inference_stats_path, dataset_name, f"candidate[{tag}]") if dist.is_initialized(): dist.barrier() # --------- Scoring per layer + combined + early-exit curve --------- if local_rank == 0: dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl") gt_infos = [json.loads(l) for l in open(dataset_info_path)] rank_against_all_candidates = task_config.get("eval_type", "global") == "global" metrics_to_report = task_config.get("metrics", ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"]) layer_tags = [make_layer_tag(l) for l in layers_to_eval] sims_by_layer = {} # tag -> list[ dict(cand_id->score) ] for tag in layer_tags: query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_{tag}") cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{tag}") with open(query_embed_path, 'rb') as f: qry_embeds = pickle.load(f) with open(cand_embed_path, 'rb') as f: cand_embed_dict = pickle.load(f) pred_dicts = [] score_detail_dicts = [] sims_for_exit = [] if rank_against_all_candidates: cand_keys = list(cand_embed_dict.keys()) cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys]) if isinstance(qry_embeds, np.ndarray) and qry_embeds.ndim == 3: # Late-interaction qry_embed_t = torch.from_numpy(qry_embeds) cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds] sim_matrix = processor.score(qry_embed_t, cand_embeds_t, batch_size=64).cpu().numpy() else: sim_matrix = np.dot(qry_embeds, cand_embeds.T) ranked_all = np.argsort(-sim_matrix, axis=1) for qid, gt_info in tqdm(enumerate(gt_infos), total=len(gt_infos), desc=f"[{tag}] scoring(all) {dataset_name}"): ranked_indices = ranked_all[qid] rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] rel_scores = gt_info.get("rel_scores") pred_dicts.append({ "prediction": [cand_keys[i] for i in ranked_indices], "label": rel_docids, "rel_scores": rel_scores, }) score_detail_dicts.append(build_score_details(qid, cand_keys, sim_matrix[qid], ranked_indices)) sims_for_exit.append({cand_keys[i]: float(sim_matrix[qid][i]) for i in range(len(cand_keys))}) else: # 非全局:每个query用 gt_info["cand_names"] 的子集进行评分 for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), total=len(gt_infos), desc=f"[{tag}] scoring(local) {dataset_name}"): cand_ids_local = gt_info["cand_names"] cand_embeds = np.stack([cand_embed_dict[key] for key in cand_ids_local]) if isinstance(qry_embeds, np.ndarray) and qry_embeds.ndim == 3: qry_embed_t = torch.from_numpy(np.array(qry_embed)).unsqueeze(0) # [1, Lq, H] cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds] sim_vec = processor.score(qry_embed_t, cand_embeds_t, batch_size=1024).cpu().numpy()[0] else: sim_vec = np.dot(qry_embed, cand_embeds.T) ranked_indices = np.argsort(-sim_vec) rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]] rel_scores = gt_info.get("rel_scores") pred_dicts.append({ "prediction": [cand_ids_local[i] for i in ranked_indices], "label": rel_docids, "rel_scores": rel_scores, }) score_detail_dicts.append(build_score_details(qid, cand_ids_local, sim_vec, ranked_indices)) sims_for_exit.append({cid: float(s) for cid, s in zip(cand_ids_local, sim_vec.tolist())}) # 保存每层指标与详情 layer_score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_{tag}.json") layer_pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred_{tag}.jsonl") layer_detail_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_details_{tag}.jsonl") metrics = RankingMetrics(metrics_to_report) score_dict = metrics.evaluate(pred_dicts) score_dict["num_pred"] = len(pred_dicts) score_dict["num_data"] = len(gt_infos) with open(layer_score_path, "w") as f: json.dump(score_dict, f, indent=4) with open(layer_pred_path, "w") as f: for pred in pred_dicts: f.write(json.dumps(pred) + '\n') with open(layer_detail_path, "w") as f: for detail in score_detail_dicts: f.write(json.dumps(detail) + "\n") print_master(f"[{dataset_name}] {tag} score: " + json.dumps({k: (f"{v:.4f}" if isinstance(v, (int, float)) else v) for k, v in score_dict.items()})) sims_by_layer[tag] = sims_for_exit # 合并对比文件 + 早停曲线(仅在存在中间层时) if len(layer_tags) == 2 and "layerlast" in layer_tags: mid_tag = [t for t in layer_tags if t != "layerlast"][0] last_tag = "layerlast" # 合并详情:每个query包含 mid/last 的cand_scores、top1、margin combined_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_details_both_layers.jsonl") with open(combined_path, "w", encoding='utf-8') as f: for qid in range(len(gt_infos)): sims_mid = sims_by_layer[mid_tag][qid] sims_last = sims_by_layer[last_tag][qid] def top1_cid(sims: dict): return max(sims.items(), key=lambda x: x[1])[0] if sims else None def margin_of(sims: dict): vals = np.array(list(sims.values()), dtype=np.float32) return top1_top2_margin_from_array(vals) row = { "qid": int(qid), "label": gt_infos[qid]["label_name"] if isinstance(gt_infos[qid]["label_name"], list) else [gt_infos[qid]["label_name"]], "mid": { "top1": top1_cid(sims_mid), "margin": margin_of(sims_mid), "cand_scores": sims_mid }, "last": { "top1": top1_cid(sims_last), "margin": margin_of(sims_last), "cand_scores": sims_last } } f.write(json.dumps(row, ensure_ascii=False) + "\n") print_master(f"[{dataset_name}] combined details saved to {combined_path}") # 早停曲线(margin 阈值) taus = [round(x, 3) for x in np.linspace(0.0, 0.6, 31).tolist()] labels = [ gi["label_name"] if isinstance(gi["label_name"], list) else [gi["label_name"]] for gi in gt_infos ] exit_curve = simulate_early_exit_by_margin( sims_by_layer[mid_tag], sims_by_layer[last_tag], labels, metrics_to_report, taus, rank_against_all_candidates ) curve_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_early_exit_curve_margin.json") with open(curve_path, "w") as f: json.dump(exit_curve, f, indent=4) print_master(f"[{dataset_name}] early-exit curve saved to {curve_path}") if dist.is_initialized(): dist.barrier() if __name__ == '__main__': main()