# code_SAS_VLM2Vec/eval_test_time_multilayer_AOP.py
import datetime
import logging
import json
import random
import time
import numpy as np
import os
import pickle
import sys
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
import transformers
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
from datasets import Dataset, concatenate_datasets
from datasets.distributed import split_dataset_by_node
from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.eval_collator import MultimodalEvalDataCollator
from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
from src.eval_utils.metrics import RankingMetrics
from src.model.model_multilayer_AOP_infer import MMEBModel
from src.model.processor import get_backbone_name, load_processor, COLPALI
from src.utils import batch_to_device, print_rank, print_master
from dataclasses import dataclass
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)
def get_env_mid_layer():
    v = os.environ.get("MID_LM_LAYER", "").strip()
    if v == "" or v.lower() in {"none", "null"}:
        return None
    try:
        return int(v)
    except ValueError:
        logger.warning(f"Invalid MID_LM_LAYER={v}, ignoring.")
        return None
# ------------- AOP-Prune config parsing -------------
def _parse_bool(v: str, default=False):
    if v is None: return default
    v = v.strip().lower()
    return v in {"1", "true", "yes", "y", "t", "on"}
def _parse_float(v: str, default=None):
    try: return float(v) if v is not None else default
    except (TypeError, ValueError): return default
def _parse_int(v: str, default=None):
    try: return int(v) if v is not None else default
    except (TypeError, ValueError): return default
def get_env_aop_config():
    """
    Read the AOP pruning configuration from environment variables. This is only a
    thin "driver-level" test switch; the actual pruning logic is implemented in the
    base model (Qwen2-VLModel.forward).
    """
    enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False)
    apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower()  # qry|cand|both
    layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None)
    mode = os.environ.get("AOP_MODE", "delta").strip().lower()
    # Generic fallbacks
    delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10)
    khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0)
    keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0)
    min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64)
    use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True)
    # Per-token-type controls
    prune_vision = _parse_bool(os.environ.get("AOP_PRUNE_VISION"), True)
    prune_text = _parse_bool(os.environ.get("AOP_PRUNE_TEXT"), False)
    delta_v = _parse_float(os.environ.get("AOP_DELTA_VISION"), None)
    khat_v = _parse_float(os.environ.get("AOP_KHAT_VISION"), None)
    keep_ratio_v = _parse_float(os.environ.get("AOP_KEEP_RATIO_VISION"), None)
    min_keep_v = _parse_int(os.environ.get("AOP_MIN_KEEP_VISION"), None)
    delta_t = _parse_float(os.environ.get("AOP_DELTA_TEXT"), None)
    khat_t = _parse_float(os.environ.get("AOP_KHAT_TEXT"), None)
    keep_ratio_t = _parse_float(os.environ.get("AOP_KEEP_RATIO_TEXT"), None)
    min_keep_t = _parse_int(os.environ.get("AOP_MIN_KEEP_TEXT"), 32)
    protect_text_last = _parse_int(os.environ.get("AOP_PROTECT_TEXT_LAST"), 16)
    protect_special = _parse_bool(os.environ.get("AOP_PROTECT_SPECIAL"), True)
    margin_src = os.environ.get("AOP_MARGIN", "").strip().lower()  # "" or "mid"
    attn_impl = os.environ.get("AOP_ATTN_IMPL", "").strip().lower()  # "" or "sdpa"
    if layer_idx is None and enabled:
        logger.warning("AOP_ENABLED=1 but AOP_LAYER is unset; disabling AOP.")
        enabled = False
    # Selection strategy: aop | random | attention
    selection = os.environ.get("AOP_SELECTION", "aop").strip().lower()
    if _parse_bool(os.environ.get("AOP_RANDOM"), False):
        selection = "random"
    random_seed = _parse_int(os.environ.get("AOP_RANDOM_SEED"), None)
    attn_agg = os.environ.get("AOP_ATTENTION_AGG", "mean").strip().lower()  # mean|max|sum
    cfg = {
        "enabled": enabled,
        "apply_to": apply_to,
        "layer_idx": layer_idx,
        "mode": mode,
        # Fallbacks
        "delta": delta, "K_hat": khat,
        "keep_ratio": keep_ratio, "min_keep": min_keep,
        "use_bias": use_bias, "eps": 1e-6,
        # Per-type switches
        "prune_vision": prune_vision,
        "prune_text": prune_text,
        # Vision bucket
        "delta_vision": delta_v,
        "K_hat_vision": khat_v,
        "keep_ratio_vision": keep_ratio_v,
        "min_keep_vision": min_keep_v,
        # Text bucket
        "delta_text": delta_t,
        "K_hat_text": khat_t,
        "keep_ratio_text": keep_ratio_t,
        "min_keep_text": min_keep_t,
        # Text protection
        "protect_text_last": protect_text_last,
        "protect_special": protect_special,
        # Optional: rank-safety budget
        "margin_mid": None if margin_src != "mid" else "USE_MID_MARGIN",
        "epsilon_hat": None,
        "attn_impl_override": attn_impl if attn_impl in {"sdpa"} else "",
        # NEW: selection strategy
        "selection": selection,      # "aop", "random", or "attention"
        "random_seed": random_seed,  # optional
        "attn_agg": attn_agg,
    }
    return cfg
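# Illustrative sketch (hypothetical values; calling this mutates os.environ and is
# not part of the pipeline) of how the environment variables map onto the config:
def _example_aop_env_setup():
    os.environ.update({
        "AOP_ENABLED": "1",             # master switch
        "AOP_LAYER": "12",              # required whenever AOP is enabled
        "AOP_APPLY": "qry",             # prune only the query side
        "AOP_KEEP_RATIO_VISION": "0.5",
        "AOP_SELECTION": "aop",         # or "random" / "attention"
    })
    cfg = get_env_aop_config()
    assert cfg["enabled"] and cfg["layer_idx"] == 12
    assert cfg["apply_to"] == "qry" and cfg["keep_ratio_vision"] == 0.5
    return cfg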
def _b(k, d=False):
    v = os.environ.get(k)
    if v is None: return d
    return str(v).strip().lower() in {"1", "true", "yes", "y", "on", "t"}
def _i(k, d=None):
    try: return int(os.environ.get(k, d))
    except (TypeError, ValueError): return d
def _s(k, d=None):
    v = os.environ.get(k, d)
    return None if v is None else str(v).strip().lower()
def get_env_vpool_config():
return {
"enabled": _b("VPOOL_ENABLED", False),
"apply_to": _s("VPOOL_APPLY", "both"), # qry|cand|both
"layer_idx": _i("VPOOL_LAYER", 1),
"kernel": _i("VPOOL_KERNEL", 2),
"stride": _i("VPOOL_STRIDE", None) or _i("VPOOL_KERNEL", 2),
"method": _s("VPOOL_METHOD", "avg"), # avg|max|linear|conv
"protect_cls": _b("VPOOL_PROTECT_CLS", True),
"vision_only": _b("VPOOL_ONLY_VISION", True),
"monitor": _b("VPOOL_MONITOR", False),
}
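# Minimal usage sketch (hypothetical values): VPOOL_ENABLED=1 VPOOL_LAYER=2
# VPOOL_KERNEL=2 VPOOL_METHOD=avg yields {"enabled": True, "layer_idx": 2,
# "kernel": 2, "stride": 2, "method": "avg", ...}; note the stride defaults
# to the kernel size when VPOOL_STRIDE is unset.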
def get_env_eval_layers():
    """
    Parse the LM_LAYERS environment variable (preferred) or fall back to the
    legacy MID_LM_LAYER.
    - LM_LAYERS example: "4,8,12,last"; the tokens 'last'/'none'/'null'/'-1'
      all denote the final layer (encoded as None).
    - If LM_LAYERS is unset, fall back to the old logic:
      MID_LM_LAYER=None -> [None]; otherwise [mid, None].
    Returns: list[int | None], e.g. [4, 8, 12, None]; None means the last layer.
    """
    v = os.environ.get("LM_LAYERS", None)
    if v is not None:
        v = v.strip()
    if v:
        toks = [t.strip() for t in v.split(',') if t.strip() != ""]
        layers = []
        for tok in toks:
            tl = tok.lower()
            if tl in {"last", "none", "null", "-1"}:
                layers.append(None)
            else:
                try:
                    val = int(tok)
                    if val > 0:
                        layers.append(val)
                    else:
                        logger.warning(f"Ignoring non-positive layer '{tok}' in LM_LAYERS.")
                except ValueError:
                    logger.warning(f"Invalid token '{tok}' in LM_LAYERS; must be int or 'last'/'none'.")
        # Deduplicate while preserving order
        seen = set()
        uniq = []
        for l in layers:
            key = -1 if l is None else l
            if key in seen:
                continue
            seen.add(key)
            uniq.append(l)
        if not uniq:
            return [None]
        return uniq
    else:
        # Legacy fallback
        mid = get_env_mid_layer()
        return [None] if mid is None else [mid, None]
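# Illustrative sketch (mutates os.environ; not called by the pipeline) of the
# parsing rules above: duplicates and non-positive entries are dropped.
def _example_eval_layers():
    os.environ["LM_LAYERS"] = "4,8,8,last,0"
    layers = get_env_eval_layers()
    assert layers == [4, 8, None]
    return layers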
def make_layer_tag(keep_layers: int | None):
return f"layer{keep_layers}" if keep_layers and keep_layers > 0 else "layerlast"
def dot_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray:
# a: [Nq, D], b: [Nc, D], both L2-normalized already if normalize=true
return a @ b.T
def build_score_details(qid: int, cand_ids: list, score_vec: np.ndarray, ranked_indices: np.ndarray):
return {
"qid": int(qid),
"cand_scores": [
{"cand_id": str(cand_ids[i]), "score": float(score_vec[i])}
for i in ranked_indices
]
}
def top1_top2_margin(score_vec: np.ndarray) -> float:
    if len(score_vec) < 2:
        return float("inf")  # with a single candidate, treat the margin as unbounded
    top2 = np.partition(score_vec, -2)[-2:]
    top2.sort()
    return float(top2[-1] - top2[-2])
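# Quick sanity sketch (not called anywhere): the margin is simply top1 - top2.
def _example_margin():
    scores = np.array([0.9, 0.7, 0.3], dtype=np.float32)
    margin = top1_top2_margin(scores)  # 0.9 - 0.7 = 0.2
    assert abs(margin - 0.2) < 1e-6
    return margin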
def simulate_early_exit_by_margin(
    sims_mid: list[dict], sims_last: list[dict], labels: list[list[str]], metrics_to_report: list[str],
    taus: list[float], rank_global: bool
):
    """
    sims_mid / sims_last: one dict per query: {cand_id: score}
    labels: list of positive cand_ids for each query
    rank_global is accepted for interface compatibility but is not used here.
    Returns: coverage and metric scores for each tau threshold.
    """
    assert len(sims_mid) == len(sims_last) == len(labels)
    N = len(labels)
    results = []
    from src.eval_utils.metrics import RankingMetrics
    metrics = RankingMetrics(metrics_to_report)
    # Pre-build the pred_dict list consumed by metrics.evaluate
    def to_pred_dicts(use_mid_mask: list[bool]) -> list[dict]:
        pred_dicts = []
        for qid in range(N):
            sims_use = sims_mid[qid] if use_mid_mask[qid] else sims_last[qid]
            # Sort by score, descending
            ranked = sorted(sims_use.items(), key=lambda x: -x[1])
            pred_dicts.append({
                "prediction": [cid for cid, _ in ranked],
                "label": labels[qid],
                "rel_scores": None
            })
        return pred_dicts
    # Compute the mid-layer margins
    margins = []
    for qid in range(N):
        # Margin between the two largest scores
        if len(sims_mid[qid]) == 0:
            margins.append(0.0)
            continue
        scores = np.array(list(sims_mid[qid].values()), dtype=np.float32)
        margins.append(top1_top2_margin(scores))
    margins = np.array(margins, dtype=np.float32)
    for tau in taus:
        use_mid_mask = (margins >= tau).tolist()
        pred_dicts = to_pred_dicts(use_mid_mask)
        score_dict = metrics.evaluate(pred_dicts)
        coverage = float(np.mean(use_mid_mask))  # early-exit coverage
        results.append({
            "tau": tau,
            "coverage": coverage,
            **score_dict
        })
    return results
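# Illustrative sketch with synthetic scores (assumes RankingMetrics supports the
# "hit" metric, as the defaults used later suggest): at tau=0.0 every query exits
# early at the mid layer (coverage 1.0); raising tau routes low-margin queries
# through to the last layer.
def _example_early_exit():
    sims_mid = [{"a": 0.9, "b": 0.2}, {"a": 0.55, "b": 0.50}]
    sims_last = [{"a": 0.8, "b": 0.3}, {"b": 0.9, "a": 0.1}]
    labels = [["a"], ["b"]]
    taus = [0.0, 0.3]  # query 1 has margin 0.05 < 0.3, so it falls back at tau=0.3
    return simulate_early_exit_by_margin(sims_mid, sims_last, labels, ["hit"], taus, rank_global=True)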
def top1_top2_margin_from_array(score_vec: np.ndarray) -> float:
if score_vec is None or len(score_vec) == 0:
return 0.0
if len(score_vec) == 1:
return float('inf')
    # Take the two largest scores
top2 = np.partition(score_vec, -2)[-2:]
top2.sort()
return float(top2[-1] - top2[-2])
# --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) ---
timing_info = {}
token_info = {
"vision_tokens": 0,
"text_input_tokens": 0, # Refers to the original text token count
"text_output_tokens": 0, # Not directly applicable here as we are encoding, not generating. Will be 0.
"total_llm_input_tokens": 0, # Refers to the total tokens LLM receives (visual + formatted text)
}
# --- Hook Functions Definition ---
def timing_pre_hook(module, input):
module_id = id(module)
if module_id not in timing_info:
timing_info[module_id] = []
timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__))
def timing_post_hook(module, input, output):
module_id = id(module)
if module_id not in timing_info:
# print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})")
return
timing_info[module_id].append((time.time(), 'post', module.__class__.__name__))
    # Collect the vision token count (only from the Vision Transformer module's post hook)
    module_name = module.__class__.__name__
    if "vision" in module_name.lower() and "transformer" in module_name.lower():
        if isinstance(output, torch.Tensor):
            # Qwen2-VL's visual tower returns flattened features shaped
            # (num_vision_tokens, hidden_dim), so shape[0] is the token count.
            token_info["vision_tokens"] = output.shape[0]
        elif hasattr(output, 'last_hidden_state'):
            token_info["vision_tokens"] = output.last_hidden_state.shape[1]
def register_model_hooks(model):
    """Register timing hooks on the model's major submodules.
    Returns (modules, handles). Callers should .remove() every handle when done;
    otherwise repeated calls stack duplicate hooks and corrupt the timings."""
    registered_modules = []
    hook_handles = []
    core_model = model
    def _register(module, label):
        hook_handles.append(module.register_forward_pre_hook(timing_pre_hook))
        hook_handles.append(module.register_forward_hook(timing_post_hook))
        registered_modules.append(module)
        print_master(f"Registered hooks for {label}: {module.__class__.__name__}")
    # Vision module
    if getattr(core_model, 'visual', None) is not None:
        _register(core_model.visual, "vision module")
    else:
        print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).")
    # Merger module (if inside visual) - it's part of the vision component
    if getattr(getattr(core_model, 'visual', None), 'merger', None) is not None:
        _register(core_model.visual.merger, "merger module")
    else:
        print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).")
    # Language model body
    if getattr(core_model, 'model', None) is not None:
        _register(core_model.model, "LLM main module")
    else:
        print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).")
    # LM Head
    if getattr(core_model, 'lm_head', None) is not None:
        _register(core_model.lm_head, "LM head module")
    else:
        print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).")
    if not registered_modules:
        print_master("Warning: No major modules found for hook registration. Check model architecture.")
    return registered_modules, hook_handles
def pad_dataset_to_divisible(dataset, world_size):
num_samples = len(dataset)
if num_samples % world_size == 0:
return dataset, num_samples
num_to_add = world_size - (num_samples % world_size)
padded_size = num_samples + num_to_add
padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
padded_dataset = concatenate_datasets([dataset, padding_data])
return padded_dataset, padded_size
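# Illustrative sketch (not called by the pipeline): padding a 10-row dataset for
# world_size=4 re-appends rows 0 and 1, so every rank gets an equal shard.
def _example_padding():
    ds = Dataset.from_dict({"x": list(range(10))})
    padded, padded_size = pad_dataset_to_divisible(ds, world_size=4)
    assert padded_size == 12 and padded[10]["x"] == 0 and padded[11]["x"] == 1
    return padded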
def encode_embeddings(
model: MMEBModel,
loader: DataLoader,
training_args: TrainingArguments,
model_args: ModelArguments,
full_dataset: Dataset,
encode_side: str,
description: str = "Encoding"
) -> tuple[np.ndarray, list, list, list, list, list]:
    """
    Encodes embeddings for a given dataset using the model, handling both standard and
    late-interaction models in a DDP-safe manner.
    Returns:
        - embeddings: np.ndarray
        - infos_or_ids: list
        - batch_stats_list: list
        - img_token_masks: list[None | list[bool]]
        - txt_token_masks: list[None | list[bool]]
        - token_records: list[dict]  # per-sample pre/post/delta token counts
    """
local_rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
# Check if the model is a late-interaction type
is_late_interaction = (model_args.model_backbone == COLPALI)
local_embeds = []
local_gt_infos = []
local_max_len = 0
# --- New: List to store statistics for each batch ---
batch_stats_list = []
# --- NEW: Collect masks ---
local_img_token_masks = [] # post image mask per sample
local_txt_token_masks = [] # NEW: post text mask per sample
local_post_attn_masks = [] # NEW: post attention_mask per sample (after prune, 1/0)
# --- NEW: per-sample token reduction records ---
    local_token_records = []  # one dict per sample with pre/post/delta token counts
model.eval()
    # Register hooks once per encode_embeddings call; keep the handles so they
    # can be removed afterwards (otherwise repeated calls stack duplicate hooks).
    registered_hooks, hook_handles = register_model_hooks(model)
    # --- NEW: helpers to extract masks and serialize them ---
    def _search_key(obj, key: str):
        # Recursively search dicts/lists/tuples for the given key
if isinstance(obj, dict):
if key in obj:
return obj[key]
for v in obj.values():
r = _search_key(v, key)
if r is not None:
return r
elif isinstance(obj, (list, tuple)):
for v in obj:
r = _search_key(v, key)
if r is not None:
return r
return None
    def _to_serializable_mask_list(mask_list, batch_size: int):
        # Convert a model-returned mask (list/tensor/ndarray/None) into [None | list] * B
        if mask_list is None:
            return [None] * batch_size
        out = []
        if isinstance(mask_list, (list, tuple)):
            for m in mask_list:
                if m is None:
                    out.append(None)
                elif torch.is_tensor(m):
                    out.append(m.detach().cpu().tolist())
                elif isinstance(m, np.ndarray):
                    out.append(m.tolist())
                else:
                    # already a python list/bool
                    out.append(m)
        elif torch.is_tensor(mask_list):
            # A 2D tensor (B, L) converts directly: tolist() -> list[list[bool/int]]
            out = mask_list.detach().cpu().tolist()
        elif isinstance(mask_list, np.ndarray):
            out = mask_list.tolist()
        else:
            # Unknown type; conservatively return None placeholders
            out = [None] * batch_size
        # Align the length to batch_size
        if isinstance(out, list):
            if len(out) < batch_size:
                out = out + [None] * (batch_size - len(out))
            elif len(out) > batch_size:
                out = out[:batch_size]
        return out
    def _to_bool_lists(m, batch_size: int):
        lst = _to_serializable_mask_list(m, batch_size)
        # Normalize into list[ list[bool] | None ]
        out = []
        for x in lst:
            if x is None:
                out.append(None)
            else:
                # x may be list[int] or list[bool]
                out.append([bool(int(v)) for v in x])
        return out
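    # A minimal sketch (synthetic inputs, not used by the pipeline) of the
    # normalization these helpers perform:
    #   _to_bool_lists(torch.tensor([[1, 0], [0, 1]]), batch_size=2)
    #     -> [[True, False], [False, True]]
    #   _to_bool_lists(None, batch_size=2) -> [None, None]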
with torch.no_grad():
for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0):
# --- Reset statistics for each inference pass ---
timing_info.clear()
token_info["vision_tokens"] = 0
token_info["text_input_tokens"] = 0
token_info["text_output_tokens"] = 0
token_info["total_llm_input_tokens"] = 0
inputs = batch_to_device(inputs, training_args.device)
current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs and inputs['input_ids'] is not None else 1
with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
start_inference_time = time.time()
                # ---- NEW: enable/disable AOP per encode side ----
aop_cfg = getattr(model.encoder, "aop_prune_config", None)
_orig_enabled = None
if isinstance(aop_cfg, dict) and aop_cfg:
_orig_enabled = aop_cfg.get("enabled", False)
apply_to = aop_cfg.get("apply_to", "qry")
side_enable = (apply_to == "both") or (apply_to == encode_side)
aop_cfg["enabled"] = bool(side_enable and _orig_enabled)
setattr(model.encoder, "aop_prune_config", aop_cfg)
if encode_side == "qry":
output = model(qry=inputs)
reps = output["qry_reps"].detach()
local_gt_infos.extend(dataset_info)
else:
output = model(tgt=inputs)
reps = output["tgt_reps"].detach()
local_gt_infos.extend([info["cand_name"] for info in dataset_info])
                # ---- NEW: restore 'enabled' (avoid leaking into the next encode_side) ----
if isinstance(aop_cfg, dict) and _orig_enabled is not None:
aop_cfg["enabled"] = _orig_enabled
setattr(model.encoder, "aop_prune_config", aop_cfg)
end_inference_time = time.time()
            # --- NEW: extract the post-prune image/text masks and the post attention_mask ---
img_masks_raw = None
txt_masks_raw = None
post_attn_raw = None
if isinstance(output, dict):
img_masks_raw = _search_key(output, "image_token_bool_masks")
txt_masks_raw = _search_key(output, "text_token_bool_masks") # NEW
post_attn_raw = _search_key(output, "post_attention_mask")
if post_attn_raw is None:
                    # Compatibility with mldaop variants: some only return attention_mask
post_attn_raw = _search_key(output, "attention_mask")
            # Compatibility: the masks may instead be attached to the model itself
if img_masks_raw is None and hasattr(model, "image_token_bool_masks"):
img_masks_raw = getattr(model, "image_token_bool_masks")
if txt_masks_raw is None and hasattr(model, "text_token_bool_masks"):
txt_masks_raw = getattr(model, "text_token_bool_masks")
if post_attn_raw is None and hasattr(model, "post_attention_mask"):
post_attn_raw = getattr(model, "post_attention_mask")
img_masks_serializable = _to_serializable_mask_list(img_masks_raw, current_batch_size)
txt_masks_serializable = _to_serializable_mask_list(txt_masks_raw, current_batch_size) # NEW
post_attn_serializable = _to_serializable_mask_list(post_attn_raw, current_batch_size) # NEW
local_img_token_masks.extend(img_masks_serializable)
local_txt_token_masks.extend(txt_masks_serializable) # NEW
local_post_attn_masks.extend(post_attn_serializable) # NEW
            # --- NEW: compute pre/post/delta token counts for this batch and accumulate ---
cfg = getattr(model.encoder, "config", None)
            # pre masks come from the inputs (before pruning)
input_ids = inputs.get("input_ids", None)
attn2d_pre = inputs.get("attention_mask", None)
if input_ids is None or attn2d_pre is None or cfg is None:
                # Cannot compute the stats; leave zeros
pre_vis_counts = [0] * current_batch_size
pre_txt_counts = [0] * current_batch_size
pre_tot_counts = [0] * current_batch_size
else:
iid = input_ids
am = attn2d_pre.to(torch.bool)
image_token_id = getattr(cfg, "image_token_id", None)
video_token_id = getattr(cfg, "video_token_id", None)
bos_id = getattr(cfg, "bos_token_id", None)
eos_id = getattr(cfg, "eos_token_id", None)
pad_id = getattr(cfg, "pad_token_id", None)
is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else torch.zeros_like(iid, dtype=torch.bool)
is_vision = is_image | is_video
is_special = torch.zeros_like(iid, dtype=torch.bool)
for tid in [bos_id, eos_id, pad_id]:
if tid is not None and tid >= 0:
is_special |= (iid == tid)
pre_txt_mask = am & (~is_vision) & (~is_special)
pre_vis_mask = am & is_vision
pre_vis_counts = pre_vis_mask.sum(dim=1).tolist()
pre_txt_counts = pre_txt_mask.sum(dim=1).tolist()
pre_tot_counts = am.sum(dim=1).tolist()
            # post masks (after pruning) come from the model output; AND them with post_attn
post_text_masks = _to_bool_lists(txt_masks_raw, current_batch_size) # list[ list[bool] | None ]
post_image_masks = _to_bool_lists(img_masks_raw, current_batch_size)
post_attn_masks = _to_bool_lists(post_attn_raw, current_batch_size)
sum_pre_text = 0; sum_post_text = 0
sum_pre_vis = 0; sum_post_vis = 0
sum_pre_tot = 0; sum_post_tot = 0
for i in range(current_batch_size):
pre_text = int(pre_txt_counts[i]) if i < len(pre_txt_counts) else 0
pre_vis = int(pre_vis_counts[i]) if i < len(pre_vis_counts) else 0
pre_tot = int(pre_tot_counts[i]) if i < len(pre_tot_counts) else 0
                # post counts: each mask may be None
m_text = post_text_masks[i] if post_text_masks is not None and i < len(post_text_masks) else None
m_img = post_image_masks[i] if post_image_masks is not None and i < len(post_image_masks) else None
m_attn = post_attn_masks[i] if post_attn_masks is not None and i < len(post_attn_masks) else None
if m_attn is None:
post_text = 0; post_vis = 0; post_tot = 0
else:
                    # Count True entries after ANDing with the attention_mask
if m_text is not None:
post_text = sum(1 for a, t in zip(m_attn, m_text) if a and t)
else:
post_text = 0
if m_img is not None:
post_vis = sum(1 for a, v in zip(m_attn, m_img) if a and v)
else:
post_vis = 0
post_tot = sum(1 for a in m_attn if a)
                # Accumulate batch-level sums
sum_pre_text += pre_text; sum_post_text += post_text
sum_pre_vis += pre_vis; sum_post_vis += post_vis
sum_pre_tot += pre_tot; sum_post_tot += post_tot
                # Save the per-sample record (for the JSONL output)
local_token_records.append({
"side": encode_side,
"pre": {"text": pre_text, "vision": pre_vis, "total": pre_tot},
"post": {"text": post_text, "vision": post_vis, "total": post_tot},
"delta":{"text": pre_text - post_text, "vision": pre_vis - post_vis, "total": pre_tot - post_tot},
})
# --- Update total LLM input tokens after the model call ---
if 'input_ids' in inputs and inputs['input_ids'] is not None:
token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1]
token_info["text_input_tokens"] = token_info["total_llm_input_tokens"] - token_info["vision_tokens"]
token_info["text_input_tokens"] = max(0, token_info["text_input_tokens"])
# --- Collect and Store Batch Statistics ---
batch_inference_time = end_inference_time - start_inference_time
current_batch_stats = {
"batch_size": current_batch_size,
"total_inference_time_seconds": batch_inference_time,
"module_inference_times": {},
"token_counts": {
"visual_tokens": token_info["vision_tokens"],
"language_input_tokens_raw": token_info["text_input_tokens"],
"llm_total_input_tokens": token_info["total_llm_input_tokens"],
"language_output_tokens": token_info["text_output_tokens"],
}
}
current_batch_stats["token_reduction"] = {
"sum_pre_text": sum_pre_text,
"sum_post_text": sum_post_text,
"sum_pre_vision": sum_pre_vis,
"sum_post_vision": sum_post_vis,
"sum_pre_total": sum_pre_tot,
"sum_post_total": sum_post_tot,
}
# Calculate and store module timings for the current batch
for module_obj in registered_hooks:
module_id = id(module_obj)
module_name = module_obj.__class__.__name__
times = timing_info.get(module_id, [])
durations = []
pre_times = {}
for t, event_type, _ in times:
if event_type == 'pre':
pre_times[module_id] = t
elif event_type == 'post' and module_id in pre_times:
duration = t - pre_times.pop(module_id)
durations.append(duration)
if durations:
current_batch_stats["module_inference_times"][module_name] = {
"total": sum(durations),
"count": len(durations),
"avg": sum(durations) / len(durations)
}
else:
current_batch_stats["module_inference_times"][module_name] = {
"total": 0.0,
"count": 0,
"avg": 0.0
}
batch_stats_list.append(current_batch_stats)
# --- Debug prints (optional) ---
print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---")
print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds")
print_rank("--- Module Inference Timing Statistics ---")
for module_name, stats in current_batch_stats["module_inference_times"].items():
print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s")
print_rank("--- Token Count Statistics ---")
print_rank(f"**视觉 token 数量**: {current_batch_stats['token_counts']['visual_tokens']}")
print_rank(f"**语言输入 token 数量 (仅原始文本)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}")
print_rank(f"**LLM总输入 token 数量 (包含视觉 + 格式化文本)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}")
print_rank(f"**语言输出 token 数量**: {current_batch_stats['token_counts']['language_output_tokens']}")
if is_late_interaction and reps.dim() == 3:
local_max_len = max(local_max_len, reps.shape[1])
local_embeds.append(reps)
    # Remove the timing hooks so repeated encode_embeddings calls don't stack duplicates.
    for h in hook_handles:
        h.remove()
    if not local_embeds:
        # Handle cases where a rank gets no data (six return values, matching the normal path)
        return np.array([]), [], [], [], [], []
# === DDP Synchronization and Padding for Late-Interaction Models ===
if is_late_interaction:
if dist.is_initialized():
# 1: global max length
local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device)
dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX)
global_max_len = local_max_len_tensor.item()
else:
global_max_len = local_max_len
# 2: pad to global max length
padded_embeds = []
for reps_batch in local_embeds:
if reps_batch.dim() == 3:
B, L, H = reps_batch.shape
padding_size = global_max_len - L
padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0)
padded_embeds.append(padded_batch)
else:
padded_embeds.append(reps_batch)
embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous()
else:
embeds_tensor = torch.cat(local_embeds, dim=0).contiguous()
# === Gather embeddings and keys from all ranks ===
if dist.is_initialized() and full_dataset.num_rows >= world_size:
print_master(f"Gathering {encode_side} embeddings across all ranks...")
# tensor gather
output_shape = list(embeds_tensor.shape)
output_shape[0] = full_dataset.num_rows
embeds_tensor = embeds_tensor.to(training_args.device)
gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device)
dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor)
final_embeddings = gathered_embeds_tensor.cpu().float().numpy()
# object gather for infos and stats
gathered_gt_infos = [None for _ in range(world_size)]
dist.all_gather_object(gathered_gt_infos, local_gt_infos)
all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys]
gathered_batch_stats = [None for _ in range(world_size)]
dist.all_gather_object(gathered_batch_stats, batch_stats_list)
all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats]
# --- NEW: gather masks ---
gathered_masks = [None for _ in range(world_size)]
dist.all_gather_object(gathered_masks, local_img_token_masks)
all_img_token_masks = [m for rank_list in gathered_masks for m in rank_list]
# NEW: gather text masks
gathered_txt_masks = [None for _ in range(world_size)]
dist.all_gather_object(gathered_txt_masks, local_txt_token_masks)
all_txt_token_masks = [m for rank_list in gathered_txt_masks for m in rank_list]
        # NEW: gather post attention masks (if needed)
gathered_post_attn = [None for _ in range(world_size)]
dist.all_gather_object(gathered_post_attn, local_post_attn_masks)
all_post_attn_masks = [m for rank_list in gathered_post_attn for m in rank_list]
# NEW: gather token records
gathered_token_recs = [None for _ in range(world_size)]
dist.all_gather_object(gathered_token_recs, local_token_records)
all_token_records = [r for rank_list in gathered_token_recs for r in rank_list]
else:
all_gt_infos = local_gt_infos
final_embeddings = embeds_tensor.cpu().float().numpy()
all_batch_stats = batch_stats_list
all_img_token_masks = local_img_token_masks # NEW
all_txt_token_masks = local_txt_token_masks
all_post_attn_masks = local_post_attn_masks
all_token_records = local_token_records
return final_embeddings, all_gt_infos, all_batch_stats, all_img_token_masks, all_txt_token_masks, all_token_records
def main():
# ----------------------- Distributed init -----------------------
if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
local_rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
print_master("Distributed init debug info:")
print_master(f"RANK: {os.environ.get('RANK')}")
print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
if dist.is_initialized():
print_rank(f"dist.get_rank(): {dist.get_rank()}")
print_rank(f"dist.get_world_size(): {dist.get_world_size()}")
    # Compatibility shim for torchrun-style --local-rank arguments
    for arg in list(sys.argv):  # iterate over a copy since sys.argv is mutated below
        if arg.startswith("--local-rank="):
            rank = arg.split("=")[1]
            sys.argv.remove(arg)
            sys.argv.append('--local_rank')
            sys.argv.append(rank)
# ----------------------- Parse args -----------------------
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: ModelArguments
data_args: DataArguments
training_args: TrainingArguments
os.makedirs(data_args.encode_output_path, exist_ok=True)
    # Multi-layer evaluation support (LM_LAYERS preferred; MID_LM_LAYER kept for compatibility)
layers_to_eval = get_env_eval_layers()
print_master(f"Eval layers (qry/tgt): {layers_to_eval}")
# ----------------------- Model loading -----------------------
hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
if not getattr(model_args, "model_backbone", None):
model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
setattr(model_args, 'model_backbone', model_backbone)
setattr(training_args, 'model_backbone', model_backbone)
print_master(f'Model Backbone: {model_args.model_backbone}')
    # Only rank 0 downloads; the other ranks wait and then load from the cache
if local_rank == 0:
processor = load_processor(model_args, data_args)
model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...")
if torch.distributed.is_initialized():
torch.distributed.barrier()
if local_rank != 0:
print_rank(f"Loading the model from cache...")
processor = load_processor(model_args, data_args)
time.sleep(random.randint(2 * local_rank, 3 * local_rank))
model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
model.eval()
model = model.to(training_args.device, dtype=torch.bfloat16)
    # ---- NEW: inject the AOP pruning config (drives the AOP logic implemented in the base model) ----
aop_cfg = get_env_aop_config()
if aop_cfg["enabled"]:
        # 1) Attach to the LoRA wrapper
setattr(model.encoder, "aop_prune_config", aop_cfg)
        # 2) Sync to the base model (ForConditionalGeneration / Qwen2_5_VLModel)
try:
base = model.encoder.get_base_model() if hasattr(model.encoder, "get_base_model") else None
if base is None and hasattr(model.encoder, "model"):
base = model.encoder.model
if base is not None:
setattr(base, "aop_prune_config", aop_cfg)
if hasattr(base, "model"): # ForConditionalGeneration.model -> Qwen2_5_VLModel
setattr(base.model, "aop_prune_config", aop_cfg)
except Exception as e:
print_master(f"[AOP] warn: sync cfg to base failed: {e}")
        # Optional: override the attention implementation for analysis
attn_override = aop_cfg.get("attn_impl_override", "")
if attn_override:
try:
if hasattr(model.encoder, "model") and hasattr(model.encoder.model, "config"):
prev = model.encoder.model.config._attn_implementation
model.encoder.model.config._attn_implementation = attn_override
print_master(f"[AOP] override attn impl: {prev} -> {attn_override} (仅测试建议)")
except Exception as e:
print_master(f"[AOP] try override attn impl failed: {e}")
print_master("[AOP] AOP-Prune enabled with config: " + json.dumps({
"apply_to": aop_cfg.get("apply_to"),
"layer_idx": aop_cfg.get("layer_idx"),
"mode": aop_cfg.get("mode"),
"delta": aop_cfg.get("delta"),
"K_hat": aop_cfg.get("K_hat"),
"keep_ratio": aop_cfg.get("keep_ratio"),
"min_keep": aop_cfg.get("min_keep"),
"prune_text": aop_cfg.get("prune_text"),
"prune_vision": aop_cfg.get("prune_vision"),
"keep_ratio_text": aop_cfg.get("keep_ratio_text"),
"keep_ratio_vision": aop_cfg.get("keep_ratio_vision"),
"selection": aop_cfg.get("selection"),
"attn_agg": aop_cfg.get("attn_agg"),
}, ensure_ascii=False))
else:
print_master("[AOP] disabled (set AOP_ENABLED=1 to enable)")
    # ---- NEW: inject the Vision Pooling config ----
vpool_cfg = get_env_vpool_config()
if vpool_cfg["enabled"]:
        # 1) Attach to the LoRA wrapper
setattr(model.encoder, "vision_pooling_config", vpool_cfg)
        # 2) Sync to the base model (ForConditionalGeneration / Qwen2_5_VLModel)
try:
base = model.encoder.get_base_model() if hasattr(model.encoder, "get_base_model") else None
if base is None and hasattr(model.encoder, "model"):
base = model.encoder.model
if base is not None:
setattr(base, "vision_pooling_config", vpool_cfg)
if hasattr(base, "model"): # ForConditionalGeneration.model -> Qwen2_5_VLModel
setattr(base.model, "vision_pooling_config", vpool_cfg)
except Exception as e:
print_master(f"[VPOOL] warn: sync cfg to base failed: {e}")
print_master("[VPOOL] enabled with config: " + json.dumps({
"apply_to": vpool_cfg.get("apply_to"),
"layer_idx": vpool_cfg.get("layer_idx"),
"kernel": vpool_cfg.get("kernel"),
"stride": vpool_cfg.get("stride"),
"method": vpool_cfg.get("method"),
"vision_only": vpool_cfg.get("vision_only"),
"monitor": vpool_cfg.get("monitor"),
}, ensure_ascii=False))
else:
print_master("[VPOOL] disabled (set VPOOL_ENABLED=1 to enable)")
    # Make sure "last layer" really means no layer truncation (avoids the class's default-20-layer pitfall)
model.set_inference_layers(qry_layers=None, tgt_layers=None)
with open(data_args.dataset_config, 'r') as yaml_file:
dataset_configs = yaml.safe_load(yaml_file)
# ----------------------- Main evaluation loop -----------------------
for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()):
if dist.is_initialized():
dist.barrier()
print_master(f"\n--- Evaluating {dataset_name} ---")
        # Rewrite paths relative to data_basedir
if data_args.data_basedir is not None:
for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
if data_args.data_basedir and task_config.get(key):
task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
        # Build the datasets
full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
if dist.is_initialized():
world_size = dist.get_world_size()
padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size)
padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size)
eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size)
eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size)
else:
padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
        # Path index
saved_paths = {} # {(side, tag): path}
        # --------- Encode and save separately for each layer setting (mid layers / last layer) ---------
for keep_layers in layers_to_eval:
tag = make_layer_tag(keep_layers)
print_master(f"[{dataset_name}] Start encoding for tag={tag} (keep_layers={keep_layers})")
            # Set the model's inference layer count
model.set_inference_layers(qry_layers=keep_layers, tgt_layers=keep_layers)
            # Output paths
query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_{tag}")
cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{tag}")
dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl")
query_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_inference_stats_{tag}.json")
cand_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_inference_stats_{tag}.json")
qry_img_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_img_token_masks_{tag}.jsonl")
cand_img_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_img_token_masks_{tag}.jsonl")
            # Four additional output file paths
qry_txt_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_text_token_masks_{tag}.jsonl")
qry_token_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_token_stats_{tag}.jsonl")
cand_txt_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_text_token_masks_{tag}.jsonl")
cand_token_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_token_stats_{tag}.jsonl")
saved_paths[("qry", tag)] = query_embed_path
saved_paths[("tgt", tag)] = cand_embed_path
do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path)
do_cand = not os.path.exists(cand_embed_path)
            # Running-statistics accumulators
def init_total_stats():
return {
"total_inference_time_seconds": 0.0,
"module_inference_times": {},
"token_counts": {
"visual_tokens": 0,
"language_input_tokens_raw": 0,
"llm_total_input_tokens": 0,
"language_output_tokens": 0,
},
"reduction": { # NEW: 删前/删后数量累计
"sum_pre_text": 0, "sum_post_text": 0,
"sum_pre_vision": 0, "sum_post_vision": 0,
"sum_pre_total": 0, "sum_post_total": 0,
},
"data_point_count": 0
}
def accumulate_stats(total_stats, batch_stats):
batch_size = batch_stats["batch_size"]
total_stats["total_inference_time_seconds"] += batch_stats["total_inference_time_seconds"]
                # Module timings
for mname, mstats in batch_stats["module_inference_times"].items():
if mname not in total_stats["module_inference_times"]:
total_stats["module_inference_times"][mname] = {"total": 0.0, "count": 0}
total_stats["module_inference_times"][mname]["total"] += mstats.get("total", 0.0)
total_stats["module_inference_times"][mname]["count"] += mstats.get("count", 0)
                # Raw token stats (multiplied by batch_size to estimate totals)
total_stats["token_counts"]["visual_tokens"] += batch_stats["token_counts"]["visual_tokens"] * batch_size
total_stats["token_counts"]["language_input_tokens_raw"] += batch_stats["token_counts"]["language_input_tokens_raw"] * batch_size
total_stats["token_counts"]["llm_total_input_tokens"] += batch_stats["token_counts"]["llm_total_input_tokens"] * batch_size
total_stats["token_counts"]["language_output_tokens"] += batch_stats["token_counts"]["language_output_tokens"] * batch_size
total_stats["data_point_count"] += batch_size
                # NEW: reduction stats
red = batch_stats.get("token_reduction", None)
if red is not None:
for k in total_stats["reduction"].keys():
total_stats["reduction"][k] += int(red.get(k, 0))
def finalize_and_save_stats(total_stats, out_path, task_name, encode_side):
if local_rank != 0:
return
if total_stats["data_point_count"] <= 0:
print_master(f"No data processed for {task_name} [{encode_side}], skip saving stats.")
return
n = max(1, total_stats["data_point_count"])
red = total_stats["reduction"]
pre_txt, post_txt = red["sum_pre_text"], red["sum_post_text"]
pre_vis, post_vis = red["sum_pre_vision"], red["sum_post_vision"]
pre_tot, post_tot = red["sum_pre_total"], red["sum_post_total"]
avg_text_pruned = (pre_txt - post_txt) / n
avg_vision_pruned = (pre_vis - post_vis) / n
avg_total_pruned = (pre_tot - post_tot) / n
avg_text_keep_ratio = (post_txt / pre_txt) if pre_txt > 0 else 1.0
avg_vision_keep_ratio = (post_vis / pre_vis) if pre_vis > 0 else 1.0
avg_total_keep_ratio = (post_tot / pre_tot) if pre_tot > 0 else 1.0
final_stats = {
"task_name": task_name,
"encode_side": encode_side,
"data_point_count": total_stats["data_point_count"],
"inference_times": {
"total_inference_time_seconds": total_stats["total_inference_time_seconds"],
"avg_inference_time_per_item_seconds": total_stats["total_inference_time_seconds"] / n,
"module_average_times_per_call": {},
"module_total_times_seconds": {},
"module_calls_count": {},
},
"token_counts": {
"total_visual_tokens": total_stats["token_counts"]["visual_tokens"],
"avg_visual_tokens_per_item": total_stats["token_counts"]["visual_tokens"] / n,
"total_language_input_tokens_raw": total_stats["token_counts"]["language_input_tokens_raw"],
"avg_language_input_tokens_raw_per_item": total_stats["token_counts"]["language_input_tokens_raw"] / n,
"total_llm_total_input_tokens": total_stats["token_counts"]["llm_total_input_tokens"],
"avg_llm_total_input_tokens_per_item": total_stats["token_counts"]["llm_total_input_tokens"] / n,
"total_language_output_tokens": total_stats["token_counts"]["language_output_tokens"],
"avg_language_output_tokens_per_item": total_stats["token_counts"]["language_output_tokens"] / n,
},
"token_reduction": { # NEW: 输出平均删减与保留比例
"avg_text_pruned_per_item": float(avg_text_pruned),
"avg_vision_pruned_per_item": float(avg_vision_pruned),
"avg_total_pruned_per_item": float(avg_total_pruned),
"avg_text_keep_ratio": float(avg_text_keep_ratio),
"avg_vision_keep_ratio": float(avg_vision_keep_ratio),
"avg_total_keep_ratio": float(avg_total_keep_ratio),
"sum_pre_text": int(pre_txt), "sum_post_text": int(post_txt),
"sum_pre_vision": int(pre_vis), "sum_post_vision": int(post_vis),
"sum_pre_total": int(pre_tot), "sum_post_total": int(post_tot),
}
}
for mname, mstats in total_stats["module_inference_times"].items():
total = mstats.get("total", 0.0)
count = mstats.get("count", 0)
final_stats["inference_times"]["module_total_times_seconds"][mname] = total
final_stats["inference_times"]["module_calls_count"][mname] = count
final_stats["inference_times"]["module_average_times_per_call"][mname] = (total / count) if count > 0 else 0.0
with open(out_path, 'w', encoding='utf-8') as f:
json.dump(final_stats, f, ensure_ascii=False, indent=4)
print_master(f"[{task_name}] {encode_side} inference statistics saved to: {out_path}")
# ------- Encode queries -------
if do_query:
print_master(f"[{tag}] Encoding queries...")
eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
eval_qry_loader = DataLoader(
eval_qry_dataset,
batch_size=training_args.per_device_eval_batch_size,
collate_fn=eval_qry_collator,
num_workers=training_args.dataloader_num_workers
)
query_embeds, gt_infos, qry_batch_stats, qry_img_masks, qry_txt_masks, qry_token_records = encode_embeddings(
model, eval_qry_loader, training_args, model_args, padded_qry_dataset,
encode_side="qry", description=f"Queries[{tag}] for {dataset_name}"
)
                # Truncate back to the true (unpadded) length
true_qry_len = len(full_eval_qry_dataset)
query_embeds = query_embeds[:true_qry_len]
gt_infos = gt_infos[:true_qry_len]
qry_img_masks = qry_img_masks[:true_qry_len]
qry_txt_masks = qry_txt_masks[:true_qry_len] # NEW
qry_token_records = qry_token_records[:true_qry_len] # NEW
                # Accumulate statistics and save
qry_total_stats = init_total_stats()
for bs in qry_batch_stats:
accumulate_stats(qry_total_stats, bs)
if local_rank == 0:
with open(query_embed_path, 'wb') as f:
pickle.dump(query_embeds, f)
if not os.path.exists(dataset_info_path):
with open(dataset_info_path, 'w') as f:
for info in gt_infos:
f.write(json.dumps(info) + '\n')
                    # Save image masks
with open(qry_img_masks_path, 'w', encoding='utf-8') as f:
for i, m in enumerate(qry_img_masks):
f.write(json.dumps({"index": i, "mask": m}, ensure_ascii=False) + "\n")
                    # Save text masks (NEW)
with open(qry_txt_masks_path, 'w', encoding='utf-8') as f:
for i, m in enumerate(qry_txt_masks):
f.write(json.dumps({"index": i, "mask": m}, ensure_ascii=False) + "\n")
                    # Save per-sample token stats (NEW)
with open(qry_token_stats_path, 'w', encoding='utf-8') as f:
for i, rec in enumerate(qry_token_records):
f.write(json.dumps({"index": i, **rec}, ensure_ascii=False) + "\n")
print_master(f"Saved query embeddings to {query_embed_path}")
print_master(f"Saved query image token masks to {qry_img_masks_path}")
print_master(f"Saved query text token masks to {qry_txt_masks_path}")
print_master(f"Saved query token stats to {qry_token_stats_path}")
finalize_and_save_stats(qry_total_stats, query_inference_stats_path, dataset_name, f"query[{tag}]")
if dist.is_initialized():
dist.barrier()
# ------- Encode candidates -------
if do_cand:
print_master(f"[{tag}] Encoding candidates...")
eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
eval_cand_loader = DataLoader(
eval_cand_dataset,
batch_size=training_args.per_device_eval_batch_size,
collate_fn=eval_cand_collator,
num_workers=training_args.dataloader_num_workers
)
cand_embeds, all_cand_ids, cand_batch_stats, cand_img_masks, cand_txt_masks, cand_token_records = encode_embeddings(
model, eval_cand_loader, training_args, model_args, padded_cand_dataset,
encode_side="cand", description=f"Candidates[{tag}] for {dataset_name}"
)
true_cand_len = len(full_eval_cand_dataset)
cand_embeds = cand_embeds[:true_cand_len]
all_cand_ids = all_cand_ids[:true_cand_len]
cand_img_masks = cand_img_masks[:true_cand_len]
cand_txt_masks = cand_txt_masks[:true_cand_len] # NEW
cand_token_records = cand_token_records[:true_cand_len] # NEW
cand_total_stats = init_total_stats()
for bs in cand_batch_stats:
accumulate_stats(cand_total_stats, bs)
if local_rank == 0:
cand_embed_dict = {cid: emb for cid, emb in zip(all_cand_ids, cand_embeds)}
with open(cand_embed_path, 'wb') as f:
pickle.dump(cand_embed_dict, f)
with open(cand_img_masks_path, 'w', encoding='utf-8') as f:
for cid, m in zip(all_cand_ids, cand_img_masks):
f.write(json.dumps({"cand_id": str(cid), "mask": m}, ensure_ascii=False) + "\n")
                    # Save text masks (NEW)
with open(cand_txt_masks_path, 'w', encoding='utf-8') as f:
for cid, m in zip(all_cand_ids, cand_txt_masks):
f.write(json.dumps({"cand_id": str(cid), "mask": m}, ensure_ascii=False) + "\n")
                    # Save per-sample token stats (NEW)
with open(cand_token_stats_path, 'w', encoding='utf-8') as f:
for cid, rec in zip(all_cand_ids, cand_token_records):
f.write(json.dumps({"cand_id": str(cid), **rec}, ensure_ascii=False) + "\n")
print_master(f"Saved candidate embeddings to {cand_embed_path}")
print_master(f"Saved candidate image token masks to {cand_img_masks_path}")
print_master(f"Saved candidate text token masks to {cand_txt_masks_path}")
print_master(f"Saved candidate token stats to {cand_token_stats_path}")
finalize_and_save_stats(cand_total_stats, cand_inference_stats_path, dataset_name, f"candidate[{tag}]")
if dist.is_initialized():
dist.barrier()
# --------- Scoring per layer + combined + early-exit curve ---------
if local_rank == 0:
dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl")
            with open(dataset_info_path, 'r', encoding='utf-8') as f:
                gt_infos = [json.loads(l) for l in f]
rank_against_all_candidates = task_config.get("eval_type", "global") == "global"
metrics_to_report = task_config.get("metrics", ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"])
layer_tags = [make_layer_tag(l) for l in layers_to_eval]
sims_by_layer = {} # tag -> list[ dict(cand_id->score) ]
for tag in layer_tags:
query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_{tag}")
cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{tag}")
with open(query_embed_path, 'rb') as f:
qry_embeds = pickle.load(f)
with open(cand_embed_path, 'rb') as f:
cand_embed_dict = pickle.load(f)
pred_dicts = []
score_detail_dicts = []
sims_for_exit = []
if rank_against_all_candidates:
cand_keys = list(cand_embed_dict.keys())
cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys])
if isinstance(qry_embeds, np.ndarray) and qry_embeds.ndim == 3:
# Late-interaction
qry_embed_t = torch.from_numpy(qry_embeds)
cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds]
sim_matrix = processor.score(qry_embed_t, cand_embeds_t, batch_size=64).cpu().numpy()
else:
sim_matrix = np.dot(qry_embeds, cand_embeds.T)
ranked_all = np.argsort(-sim_matrix, axis=1)
for qid, gt_info in tqdm(enumerate(gt_infos), total=len(gt_infos), desc=f"[{tag}] scoring(all) {dataset_name}"):
ranked_indices = ranked_all[qid]
rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
rel_scores = gt_info.get("rel_scores")
pred_dicts.append({
"prediction": [cand_keys[i] for i in ranked_indices],
"label": rel_docids,
"rel_scores": rel_scores,
})
score_detail_dicts.append(build_score_details(qid, cand_keys, sim_matrix[qid], ranked_indices))
sims_for_exit.append({cand_keys[i]: float(sim_matrix[qid][i]) for i in range(len(cand_keys))})
else:
                # Non-global: score each query against only its gt_info["cand_names"] subset
for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), total=len(gt_infos), desc=f"[{tag}] scoring(local) {dataset_name}"):
cand_ids_local = gt_info["cand_names"]
cand_embeds = np.stack([cand_embed_dict[key] for key in cand_ids_local])
if isinstance(qry_embeds, np.ndarray) and qry_embeds.ndim == 3:
qry_embed_t = torch.from_numpy(np.array(qry_embed)).unsqueeze(0) # [1, Lq, H]
cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds]
sim_vec = processor.score(qry_embed_t, cand_embeds_t, batch_size=1024).cpu().numpy()[0]
else:
sim_vec = np.dot(qry_embed, cand_embeds.T)
ranked_indices = np.argsort(-sim_vec)
rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
rel_scores = gt_info.get("rel_scores")
pred_dicts.append({
"prediction": [cand_ids_local[i] for i in ranked_indices],
"label": rel_docids,
"rel_scores": rel_scores,
})
score_detail_dicts.append(build_score_details(qid, cand_ids_local, sim_vec, ranked_indices))
sims_for_exit.append({cid: float(s) for cid, s in zip(cand_ids_local, sim_vec.tolist())})
                # Save per-layer metrics and details
layer_score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_{tag}.json")
layer_pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred_{tag}.jsonl")
layer_detail_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_details_{tag}.jsonl")
metrics = RankingMetrics(metrics_to_report)
score_dict = metrics.evaluate(pred_dicts)
score_dict["num_pred"] = len(pred_dicts)
score_dict["num_data"] = len(gt_infos)
with open(layer_score_path, "w") as f:
json.dump(score_dict, f, indent=4)
with open(layer_pred_path, "w") as f:
for pred in pred_dicts:
f.write(json.dumps(pred) + '\n')
with open(layer_detail_path, "w") as f:
for detail in score_detail_dicts:
f.write(json.dumps(detail) + "\n")
print_master(f"[{dataset_name}] {tag} score: " + json.dumps({k: (f"{v:.4f}" if isinstance(v, (int, float)) else v) for k, v in score_dict.items()}))
sims_by_layer[tag] = sims_for_exit
            # Combined comparison file + early-exit curve (only when a mid layer exists)
if len(layer_tags) == 2 and "layerlast" in layer_tags:
mid_tag = [t for t in layer_tags if t != "layerlast"][0]
last_tag = "layerlast"
                # Combined details: for each query, the mid/last cand_scores, top1, and margin
combined_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_details_both_layers.jsonl")
with open(combined_path, "w", encoding='utf-8') as f:
for qid in range(len(gt_infos)):
sims_mid = sims_by_layer[mid_tag][qid]
sims_last = sims_by_layer[last_tag][qid]
def top1_cid(sims: dict):
return max(sims.items(), key=lambda x: x[1])[0] if sims else None
def margin_of(sims: dict):
vals = np.array(list(sims.values()), dtype=np.float32)
return top1_top2_margin_from_array(vals)
row = {
"qid": int(qid),
"label": gt_infos[qid]["label_name"] if isinstance(gt_infos[qid]["label_name"], list) else [gt_infos[qid]["label_name"]],
"mid": {
"top1": top1_cid(sims_mid),
"margin": margin_of(sims_mid),
"cand_scores": sims_mid
},
"last": {
"top1": top1_cid(sims_last),
"margin": margin_of(sims_last),
"cand_scores": sims_last
}
}
f.write(json.dumps(row, ensure_ascii=False) + "\n")
print_master(f"[{dataset_name}] combined details saved to {combined_path}")
                # Early-exit curve (margin thresholds)
taus = [round(x, 3) for x in np.linspace(0.0, 0.6, 31).tolist()]
labels = [
gi["label_name"] if isinstance(gi["label_name"], list) else [gi["label_name"]]
for gi in gt_infos
]
exit_curve = simulate_early_exit_by_margin(
sims_by_layer[mid_tag], sims_by_layer[last_tag], labels, metrics_to_report, taus, rank_against_all_candidates
)
curve_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_early_exit_curve_margin.json")
with open(curve_path, "w") as f:
json.dump(exit_curve, f, indent=4)
print_master(f"[{dataset_name}] early-exit curve saved to {curve_path}")
if dist.is_initialized():
dist.barrier()
if __name__ == '__main__':
main()