# code_SAS_VLM2Vec / eval_test_time_cut_layer_QCTM.py
import datetime
import logging
import json
import random
import time
import numpy as np
import os
import pickle
import sys
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
import transformers
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import HfArgumentParser, AutoConfig, AutoTokenizer
from datasets import Dataset, concatenate_datasets
from datasets.distributed import split_dataset_by_node
from src.model.vlm_backbone.qwen2_vl.modeling_qwen2_vl_train_tokrnpooling import Qwen2VLForConditionalGeneration as _Qwen2VLForConditionalGeneration_src
from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.eval_collator import MultimodalEvalDataCollator
from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
from src.eval_utils.metrics import RankingMetrics
from src.model.model_cut_layer_QCTM import MMEBModel
from src.model.processor import get_backbone_name, load_processor, COLPALI
from src.utils import batch_to_device, print_rank, print_master
from dataclasses import dataclass
def get_env_mid_layer():
    v = os.environ.get("MID_LM_LAYER", "").strip()
    if v == "" or v.lower() in {"none", "null"}:
        return None
    try:
        return int(v)
    except ValueError:
        logger.warning(f"Invalid MID_LM_LAYER={v}; ignoring.")
        return None
# ------------- AOP-Prune config parsing -------------
def _parse_bool(v: str, default=False):
    if v is None: return default
    v = v.strip().lower()
    return v in {"1", "true", "yes", "y", "t", "on"}
def _parse_float(v: str, default=None):
    try: return float(v) if v is not None else default
    except (TypeError, ValueError): return default
def _parse_int(v: str, default=None):
    try: return int(v) if v is not None else default
    except (TypeError, ValueError): return default
def get_env_aop_config():
"""
从环境变量读取 AOP 剪裁配置。仅作为“驱动层”的简要测试开关;
实际剪裁逻辑在底模里(Qwen2-VLModel.forward)实现。
"""
enabled = _parse_bool(os.environ.get("AOP_ENABLED"), False)
apply_to = os.environ.get("AOP_APPLY", "qry").strip().lower() # qry|cand|both
layer_idx = _parse_int(os.environ.get("AOP_LAYER"), None)
mode = os.environ.get("AOP_MODE", "delta").strip().lower()
delta = _parse_float(os.environ.get("AOP_DELTA"), 0.10)
khat = _parse_float(os.environ.get("AOP_KHAT"), 1.0)
keep_ratio = _parse_float(os.environ.get("AOP_KEEP_RATIO"), 1.0)
min_keep = _parse_int(os.environ.get("AOP_MIN_KEEP"), 64)
use_bias = _parse_bool(os.environ.get("AOP_USE_BIAS"), True)
margin_src = os.environ.get("AOP_MARGIN", "").strip().lower() # "" or "mid"
attn_impl = os.environ.get("AOP_ATTN_IMPL", "").strip().lower() # "" or "sdpa"
    if layer_idx is None and enabled:
        logger.warning("AOP_ENABLED=1 but AOP_LAYER is unset; disabling AOP.")
        enabled = False
    cfg = {
        "enabled": enabled,
        "apply_to": apply_to,  # controls per-side enabling inside encode_embeddings
        "layer_idx": layer_idx,
        "mode": mode, "delta": delta, "K_hat": khat,
        "keep_ratio": keep_ratio, "min_keep": min_keep,
        "use_bias": use_bias, "eps": 1e-6,
        # Optional safety budget (mid margin): usable if the backbone can obtain m_mid
        "margin_mid": None if margin_src != "mid" else "USE_MID_MARGIN",
        # Only consulted if the backbone implementation needs it; passed through as-is
        "epsilon_hat": None,
        # Optional attention-implementation override, so the chosen layer can expose
        # attention weights or support approximate qk computation
        "attn_impl_override": attn_impl if attn_impl in {"sdpa"} else "",
    }
    return cfg
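# Hedged usage sketch for the AOP switches above (flag values are illustrative,
# not tuned recommendations; prefix the launch command with the env vars):
#   AOP_ENABLED=1 AOP_APPLY=qry AOP_LAYER=12 AOP_MODE=delta AOP_DELTA=0.10 \
#   AOP_KEEP_RATIO=0.7 AOP_MIN_KEEP=64 AOP_ATTN_IMPL=sdpa \
#   torchrun eval_test_time_cut_layer_QCTM.py ...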
def get_env_eval_layers():
"""
解析环境变量 LM_LAYERS(优先)或兼容旧的 MID_LM_LAYER。
- LM_LAYERS 示例:"4,8,12,last";可包含 'last'/'none'/'null'/'-1' 表示最后一层(None)。
- 若未设置 LM_LAYERS,则回落到旧逻辑:MID_LM_LAYER=None -> [None];否则 [mid, None]
返回: list[ int | None ],例如 [4, 8, 12, None];None 代表最后一层。
"""
v = os.environ.get("LM_LAYERS", None)
if v is not None:
v = v.strip()
if v:
toks = [t.strip() for t in v.split(',') if t.strip() != ""]
layers = []
for tok in toks:
tl = tok.lower()
if tl in {"last", "none", "null", "-1"}:
layers.append(None)
else:
try:
val = int(tok)
if val > 0:
layers.append(val)
else:
logger.warning(f"Ignoring non-positive layer '{tok}' in LM_LAYERS.")
except Exception:
logger.warning(f"Invalid token '{tok}' in LM_LAYERS; must be int or 'last'/'none'.")
            # Deduplicate while preserving order
seen = set()
uniq = []
for l in layers:
key = -1 if l is None else l
if key in seen:
continue
seen.add(key)
uniq.append(l)
if not uniq:
return [None]
return uniq
    # Fall back to the legacy single-mid-layer logic when LM_LAYERS is unset or empty
    mid = get_env_mid_layer()
    return [None] if mid is None else [mid, None]
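# Minimal self-check for the parsing rules above (hypothetical helper, never
# called anywhere in this script):
def _demo_get_env_eval_layers():
    os.environ["LM_LAYERS"] = "4,8,12,last"
    assert get_env_eval_layers() == [4, 8, 12, None]
    os.environ["LM_LAYERS"] = "8,8,none,last"  # duplicates and last-layer aliases collapse
    assert get_env_eval_layers() == [8, None]
    del os.environ["LM_LAYERS"]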
def make_layer_tag(keep_layers: int | None):
return f"layer{keep_layers}" if keep_layers and keep_layers > 0 else "layerlast"
def dot_sim(a: np.ndarray, b: np.ndarray) -> np.ndarray:
# a: [Nq, D], b: [Nc, D], both L2-normalized already if normalize=true
return a @ b.T
def build_score_details(qid: int, cand_ids: list, score_vec: np.ndarray, ranked_indices: np.ndarray):
return {
"qid": int(qid),
"cand_scores": [
{"cand_id": str(cand_ids[i]), "score": float(score_vec[i])}
for i in ranked_indices
]
}
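# Tiny check of the detail record above (hypothetical helper, never called):
# ranked_indices orders cand_scores best-first.
def _demo_build_score_details():
    out = build_score_details(0, ["a", "b"], np.array([0.2, 0.9]), np.array([1, 0]))
    assert [c["cand_id"] for c in out["cand_scores"]] == ["b", "a"]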
def top1_top2_margin(score_vec: np.ndarray) -> float:
    if len(score_vec) < 2:
        return float("inf")  # a single candidate counts as an unbounded margin
top2 = np.partition(score_vec, -2)[-2:]
top2.sort()
return float(top2[-1] - top2[-2])
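# Quick numeric check of the margin definition above (hypothetical helper, not
# called anywhere): for scores [0.9, 0.7, 0.2] the top-2 gap is 0.9 - 0.7 = 0.2,
# and a single candidate yields an infinite margin (always exits early).
def _demo_top1_top2_margin():
    assert abs(top1_top2_margin(np.array([0.9, 0.7, 0.2])) - 0.2) < 1e-6
    assert top1_top2_margin(np.array([0.5])) == float("inf")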
def simulate_early_exit_by_margin(
sims_mid: list[dict], sims_last: list[dict], labels: list[list[str]], metrics_to_report: list[str],
taus: list[float], rank_global: bool
):
"""
sims_mid / sims_last: 每个query一个dict: {cand_id: score}
labels: 每个query的正样本cand_id列表
返回:不同tau下的覆盖率、指标
"""
assert len(sims_mid) == len(sims_last) == len(labels)
N = len(labels)
results = []
    metrics = RankingMetrics(metrics_to_report)  # imported at module level
    # Pre-build the pred_dict list consumed by metrics.evaluate
def to_pred_dicts(use_mid_mask: list[bool]) -> list[dict]:
pred_dicts = []
for qid in range(N):
sims_use = sims_mid[qid] if use_mid_mask[qid] else sims_last[qid]
            # rank candidates by descending score
ranked = sorted(sims_use.items(), key=lambda x: -x[1])
pred_dicts.append({
"prediction": [cid for cid, _ in ranked],
"label": labels[qid],
"rel_scores": None
})
return pred_dicts
    # Compute the per-query margin at the mid layer
    margins = []
    for qid in range(N):
        # margin between the two largest scores
if len(sims_mid[qid]) == 0:
margins.append(0.0)
continue
scores = np.array(list(sims_mid[qid].values()), dtype=np.float32)
margins.append(top1_top2_margin(scores))
margins = np.array(margins, dtype=np.float32)
for tau in taus:
use_mid_mask = (margins >= tau).tolist()
pred_dicts = to_pred_dicts(use_mid_mask)
score_dict = metrics.evaluate(pred_dicts)
        coverage = float(np.mean(use_mid_mask))  # early-exit coverage
results.append({
"tau": tau,
"coverage": coverage,
**score_dict
})
return results
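# Hedged usage sketch for simulate_early_exit_by_margin (hypothetical inputs;
# assumes RankingMetrics accepts the same metric names used in main() below).
# Query 0 has a confident mid-layer margin (0.8), query 1 a narrow one (0.1),
# so raising tau shifts query 1 back to the last layer.
def _demo_simulate_early_exit():
    sims_mid = [{"a": 0.9, "b": 0.1}, {"a": 0.55, "b": 0.45}]
    sims_last = [{"a": 0.8, "b": 0.2}, {"a": 0.3, "b": 0.7}]
    labels = [["a"], ["b"]]
    rows = simulate_early_exit_by_margin(
        sims_mid, sims_last, labels, ["hit"], taus=[0.0, 0.3], rank_global=True)
    # tau=0.0: both queries answered at the mid layer (coverage 1.0);
    # tau=0.3: only query 0 exits early (coverage 0.5) and query 1 is corrected.
    assert rows[0]["coverage"] == 1.0 and rows[1]["coverage"] == 0.5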
def top1_top2_margin_from_array(score_vec: np.ndarray) -> float:
if score_vec is None or len(score_vec) == 0:
return 0.0
if len(score_vec) == 1:
return float('inf')
    # take the two largest scores
top2 = np.partition(score_vec, -2)[-2:]
top2.sort()
return float(top2[-1] - top2[-2])
# ------------- QCTM (ViT->LLM pre-merge) config parsing -------------
def get_env_qctm_config():
"""
从环境变量读取 QCTM 合并配置;仅作为“驱动层”开关,实际合并在底模 forward 里实现。
环境变量示例:
QCTM_ENABLED=1
QCTM_APPLY=both # qry|cand|both
QCTM_KEEP_RATIO=0.5 # 保留比例
QCTM_MIN_KEEP=64 # 最少保留视觉 token
QCTM_KNN_K=8 # kNN 近邻
QCTM_CK=1.0 # (†) 键项常数,严格用1.0;排序更紧可0.5~0.8(安全由(★)兜底)
QCTM_KAPPA_V=1.1 # κ_v 放缩,确保 S̃≥S∞
QCTM_DEBUG=0
"""
enabled = _parse_bool(os.environ.get("QCTM_ENABLED"), False)
apply_to = os.environ.get("QCTM_APPLY", "qry").strip().lower() # qry|cand|both
keep_ratio = _parse_float(os.environ.get("QCTM_KEEP_RATIO"), 0.5)
min_keep = _parse_int(os.environ.get("QCTM_MIN_KEEP"), 64)
knn_k = _parse_int(os.environ.get("QCTM_KNN_K"), 8)
c_k = _parse_float(os.environ.get("QCTM_CK"), 1.0)
kappa_v = _parse_float(os.environ.get("QCTM_KAPPA_V"), 1.1)
debug = _parse_bool(os.environ.get("QCTM_DEBUG"), False)
    cfg = {
        "enabled": enabled,
        "apply_to": apply_to,  # controls per-side enabling inside encode_embeddings
        "keep_ratio": keep_ratio,
        "min_keep": min_keep,
        "knn_k": knn_k,
        "c_k": c_k,
        "kappa_v": kappa_v,
        "debug": debug,
        # Additional DPPM/budget parameters can be passed through here if the backbone supports them
    }
    return cfg
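# Expected shape of each per-sample qctm_stats record, inferred from the logging
# in encode_embeddings below (the backbone produces these records; the exact key
# set is an assumption, not a documented contract):
#   {"batch_index": int, "valid_len_before": int, "valid_len_after": int,
#    "Nv_before": int, "Nv_after": int, "effective_keep_ratio": float,
#    "num_pairs": int, "kept_abs": list[int], "pruned_abs": list[int]}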
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s')
logger = logging.getLogger(__name__)
# --- Global Dictionaries for Hooks (will be cleared before each encode_embeddings call) ---
timing_info = {}
token_info = {
"vision_tokens": 0,
"text_input_tokens": 0, # Refers to the original text token count
"text_output_tokens": 0, # Not directly applicable here as we are encoding, not generating. Will be 0.
"total_llm_input_tokens": 0, # Refers to the total tokens LLM receives (visual + formatted text)
}
# --- Hook Functions Definition ---
def timing_pre_hook(module, input):
module_id = id(module)
if module_id not in timing_info:
timing_info[module_id] = []
timing_info[module_id].append((time.time(), 'pre', module.__class__.__name__))
def timing_post_hook(module, input, output):
module_id = id(module)
if module_id not in timing_info:
# print(f"Warning: No pre-hook data for module {module.__class__.__name__} ({module_id})")
return
timing_info[module_id].append((time.time(), 'post', module.__class__.__name__))
# Collect vision token count (only from Vision Transformer module's post hook)
module_name = module.__class__.__name__
if "vision" in module_name.lower() and "transformer" in module_name.lower():
        if isinstance(output, torch.Tensor):
            # Qwen2-VL's visual tower emits (num_tokens, hidden_dim), so dim 0 is the token count
            token_info["vision_tokens"] = output.shape[0]
        elif hasattr(output, 'last_hidden_state'):
            token_info["vision_tokens"] = output.last_hidden_state.shape[1]
def register_model_hooks(model):
registered_modules = []
core_model = model
# print_master(f"DEBUG: Initial model type in register_model_hooks: {type(model)}")
if hasattr(model, 'encoder') and model.encoder is not None:
# print_master(f"DEBUG: model has 'encoder' attribute. Type of model.encoder: {type(model.encoder)}")
        # Check against the Qwen2VLForConditionalGeneration imported from the 'src' path
if isinstance(model.encoder, _Qwen2VLForConditionalGeneration_src):
# print_master("Detected MMEBModel structure, registering hooks on model.encoder's sub-modules.")
core_model = model.encoder
else:
print_master(f"WARNING: model.encoder is not an instance of _Qwen2VLForConditionalGeneration_src. Its type is {type(model.encoder)}. Hooks will be registered on top-level model if applicable.")
else:
print_master("WARNING: Model structure does not have an 'encoder' attribute. Registering hooks directly on top-level modules.")
# Vision module
if hasattr(core_model, 'visual') and core_model.visual is not None:
vision_module = core_model.visual
vision_module.register_forward_pre_hook(timing_pre_hook)
vision_module.register_forward_hook(timing_post_hook)
registered_modules.append(vision_module)
print_master(f"Registered hooks for vision module: {vision_module.__class__.__name__}")
else:
print_master(f"WARNING: No 'visual' attribute found on core_model ({type(core_model)}).")
# Merger module (if inside visual) - it's part of the vision component
if hasattr(core_model, 'visual') and hasattr(core_model.visual, 'merger') and core_model.visual.merger is not None:
merger_module = core_model.visual.merger
merger_module.register_forward_pre_hook(timing_pre_hook)
merger_module.register_forward_hook(timing_post_hook)
registered_modules.append(merger_module)
print_master(f"Registered hooks for merger module: {merger_module.__class__.__name__}")
else:
print_master(f"WARNING: No 'merger' attribute found on core_model.visual ({type(getattr(core_model, 'visual', 'N/A'))}).")
# Language model body
if hasattr(core_model, 'model') and core_model.model is not None:
llm_main_module = core_model.model
llm_main_module.register_forward_pre_hook(timing_pre_hook)
llm_main_module.register_forward_hook(timing_post_hook)
registered_modules.append(llm_main_module)
print_master(f"Registered hooks for LLM main module: {llm_main_module.__class__.__name__}")
else:
print_master(f"WARNING: No 'model' attribute found on core_model ({type(core_model)}).")
# LM Head
if hasattr(core_model, 'lm_head') and core_model.lm_head is not None:
lm_head_module = core_model.lm_head
lm_head_module.register_forward_pre_hook(timing_pre_hook)
lm_head_module.register_forward_hook(timing_post_hook)
registered_modules.append(lm_head_module)
print_master(f"Registered hooks for LM head module: {lm_head_module.__class__.__name__}")
else:
print_master(f"WARNING: No 'lm_head' attribute found on core_model ({type(core_model)}).")
if not registered_modules:
print_master("Warning: No major modules found for hook registration. Check model architecture.")
return registered_modules
def pad_dataset_to_divisible(dataset, world_size):
num_samples = len(dataset)
if num_samples % world_size == 0:
return dataset, num_samples
num_to_add = world_size - (num_samples % world_size)
padded_size = num_samples + num_to_add
padding_data = dataset.select([i % len(dataset) for i in range(num_to_add)])
padded_dataset = concatenate_datasets([dataset, padding_data])
return padded_dataset, padded_size
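# Worked example for the padding above (hypothetical helper, not called anywhere):
# 10 rows with world_size=4 need 2 more rows, so rows 0 and 1 are repeated and
# every rank receives exactly 3 samples; main() later slices the duplicates off
# via the true_Nq/true_Nc truncation.
def _demo_pad_dataset_to_divisible():
    ds = Dataset.from_dict({"x": list(range(10))})
    padded, size = pad_dataset_to_divisible(ds, world_size=4)
    assert size == 12 and padded["x"][-2:] == [0, 1]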
def encode_embeddings(
model: MMEBModel,
loader: DataLoader,
training_args: TrainingArguments,
model_args: ModelArguments,
full_dataset: Dataset,
encode_side: str,
description: str = "Encoding"
) -> tuple[np.ndarray, list, list, list, list]: # CHANGED: + list for qctm_stats
"""
Encodes embeddings for a given dataset using the model, handling both standard and
late-interaction models in a DDP-safe manner.
Returns:
- embeddings: np.ndarray
- infos_or_ids: list
- batch_stats_list: list
- img_token_masks: list[None | list[bool]] # NEW
"""
local_rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
# Check if the model is a late-interaction type
is_late_interaction = (model_args.model_backbone == COLPALI)
local_embeds = []
local_gt_infos = []
local_max_len = 0
# --- New: List to store statistics for each batch ---
batch_stats_list = []
    # --- NEW: Collect image token masks locally ---
    local_img_token_masks = []  # one entry per sample: None or [bool, ...]
    # --- NEW: Collect QCTM stats locally ---
    local_qctm_stats = []  # one entry per sample: None or a dict (see the logging below for keys)
model.eval()
# Register hooks for the model once per encode_embeddings call
registered_hooks = register_model_hooks(model)
    # --- NEW: helpers to extract masks and serialize them ---
    def _search_key(obj, key: str):
        # Recursively search dicts/lists/tuples for the given key
if isinstance(obj, dict):
if key in obj:
return obj[key]
for v in obj.values():
r = _search_key(v, key)
if r is not None:
return r
elif isinstance(obj, (list, tuple)):
for v in obj:
r = _search_key(v, key)
if r is not None:
return r
return None
    def _to_serializable_mask_list(mask_list, batch_size: int):
        # Convert model-returned masks (list/tensor/ndarray/None) into [None | list[bool]] * B
if mask_list is None:
return [None] * batch_size
out = []
if isinstance(mask_list, (list, tuple)):
for m in mask_list:
if m is None:
out.append(None)
elif torch.is_tensor(m):
out.append(m.detach().cpu().tolist())
elif isinstance(m, np.ndarray):
out.append(m.tolist())
else:
# already python list/bool
out.append(m)
elif torch.is_tensor(mask_list):
            # A 2D tensor (B, L) converts directly via tolist() -> list[list[bool/int]]
out = mask_list.detach().cpu().tolist()
elif isinstance(mask_list, np.ndarray):
out = mask_list.tolist()
        else:
            # Unknown type: conservatively fall back to None placeholders
            out = [None] * batch_size
        # Align the output length to batch_size
if isinstance(out, list):
if len(out) < batch_size:
out = out + [None] * (batch_size - len(out))
elif len(out) > batch_size:
out = out[:batch_size]
return out
with torch.no_grad():
for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0):
# --- Reset statistics for each inference pass ---
timing_info.clear()
token_info["vision_tokens"] = 0
token_info["text_input_tokens"] = 0
token_info["text_output_tokens"] = 0
token_info["total_llm_input_tokens"] = 0
inputs = batch_to_device(inputs, training_args.device)
current_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs and inputs['input_ids'] is not None else 1
with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
start_inference_time = time.time()
                # ---- NEW: enable/disable AOP per side ----
aop_cfg = getattr(model.encoder, "aop_prune_config", None)
_orig_enabled = None
if isinstance(aop_cfg, dict) and aop_cfg:
_orig_enabled = aop_cfg.get("enabled", False)
apply_to = aop_cfg.get("apply_to", "qry")
side_enable = (apply_to == "both") or (apply_to == encode_side)
aop_cfg["enabled"] = bool(side_enable and _orig_enabled)
setattr(model.encoder, "aop_prune_config", aop_cfg)
                # ---- NEW: enable/disable QCTM per side ----
qctm_cfg = getattr(model.encoder, "qctm_config", None)
_qctm_orig_enabled = None
if isinstance(qctm_cfg, dict) and qctm_cfg:
_qctm_orig_enabled = qctm_cfg.get("enabled", False)
apply_to = qctm_cfg.get("apply_to", "qry")
side_enable = (apply_to == "both") or (apply_to == encode_side)
qctm_cfg["enabled"] = bool(side_enable and _qctm_orig_enabled)
setattr(model.encoder, "qctm_config", qctm_cfg)
if encode_side == "qry":
output = model(qry=inputs)
reps = output["qry_reps"].detach()
local_gt_infos.extend(dataset_info)
else:
output = model(tgt=inputs)
reps = output["tgt_reps"].detach()
local_gt_infos.extend([info["cand_name"] for info in dataset_info])
                # ---- NEW: restore the AOP enabled flag (avoid leaking into the next encode_side) ----
if isinstance(aop_cfg, dict) and _orig_enabled is not None:
aop_cfg["enabled"] = _orig_enabled
setattr(model.encoder, "aop_prune_config", aop_cfg)
                # ---- NEW: restore the QCTM enabled flag (avoid leaking into the other side) ----
if isinstance(qctm_cfg, dict) and _qctm_orig_enabled is not None:
qctm_cfg["enabled"] = _qctm_orig_enabled
setattr(model.encoder, "qctm_config", qctm_cfg)
end_inference_time = time.time()
            # --- NEW: fetch qctm_stats and log a sample ---
            qctm_stats = _search_key(output, "qctm_stats")
            # qctm_stats is expected to be a List[dict], one per sample (None entries allowed)
            if isinstance(qctm_stats, list):
                valid_stats = [st for st in qctm_stats if isinstance(st, dict)]
                # Append to the local per-sample list (kept aligned with sample order)
                local_qctm_stats.extend(qctm_stats)
                # Log only the first N dict entries (configurable via QCTM_LOG_N);
                # iterating valid_stats keeps None placeholders from crashing the format strings
                log_n = int(os.environ.get("QCTM_LOG_N", "3"))
                for i, st in enumerate(valid_stats[:log_n]):
                    print_rank(
                        f"[QCTM][batch] bidx={st.get('batch_index')} "
                        f"valid_before={st.get('valid_len_before')} -> valid_after={st.get('valid_len_after')}, "
                        f"Nv {st.get('Nv_before')} -> {st.get('Nv_after')} "
                        f"(eff_keep={st.get('effective_keep_ratio'):.3f}), "
                        f"pairs={st.get('num_pairs')}"
                    )
                    if os.environ.get("QCTM_LOG_DETAIL", "0") == "1":
                        print_rank(f" kept_abs(first 32)={(st.get('kept_abs') or [])[:32]}")
                        print_rank(f" pruned_abs(first 32)={(st.get('pruned_abs') or [])[:32]}")
            else:
                # No stats returned: pad with None per sample to stay aligned with sample counts
                local_qctm_stats.extend([None] * current_batch_size)
            # --- NEW: extract and save this batch's image_token_bool_masks ---
            # The MMEBModel output is expected to contain 'image_token_bool_masks',
            # either at the top level or nested inside
            img_masks_raw = None
            if isinstance(output, dict):
                img_masks_raw = _search_key(output, "image_token_bool_masks")
            # Optional: if the masks were attached to MMEBModel as an attribute, read them from there
            if img_masks_raw is None and hasattr(model, "image_token_bool_masks"):
                img_masks_raw = getattr(model, "image_token_bool_masks")
            img_masks_serializable = _to_serializable_mask_list(img_masks_raw, current_batch_size)
            local_img_token_masks.extend(img_masks_serializable)
            # --- Update token counts, using merged masks when available ---
            # If the model returned post-merge image_token_bool_masks (one list[bool] per
            # sample), estimate the merged LLM input length as len(mask) and the visual
            # token count as sum(mask)
            if any(isinstance(m, list) for m in img_masks_serializable):
                llm_lens = [len(m) for m in img_masks_serializable if isinstance(m, list)]
                vis_lens = [sum(1 for x in m if x) for m in img_masks_serializable if isinstance(m, list)]
                if len(llm_lens) > 0:
                    # Store per-sample means here; accumulate_stats later multiplies by
                    # batch_size, and the finalize step averages over all samples
                    token_info["total_llm_input_tokens"] = float(np.mean(llm_lens))
                    token_info["vision_tokens"] = float(np.mean(vis_lens))
                    token_info["text_input_tokens"] = float(max(0.0, token_info["total_llm_input_tokens"] - token_info["vision_tokens"]))
                else:
                    # fallback
                    if 'input_ids' in inputs and inputs['input_ids'] is not None:
                        token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1]
                        token_info["text_input_tokens"] = max(0, token_info["total_llm_input_tokens"] - token_info["vision_tokens"])
            else:
                # fallback: without masks, keep the original estimate (input_ids length)
                if 'input_ids' in inputs and inputs['input_ids'] is not None:
                    token_info["total_llm_input_tokens"] = inputs['input_ids'].shape[1]
                    token_info["text_input_tokens"] = max(0, token_info["total_llm_input_tokens"] - token_info["vision_tokens"])
# --- Collect and Store Batch Statistics ---
batch_inference_time = end_inference_time - start_inference_time
current_batch_stats = {
"batch_size": current_batch_size,
"total_inference_time_seconds": batch_inference_time,
"module_inference_times": {},
"token_counts": {
"visual_tokens": token_info["vision_tokens"],
"language_input_tokens_raw": token_info["text_input_tokens"],
"llm_total_input_tokens": token_info["total_llm_input_tokens"],
"language_output_tokens": token_info["text_output_tokens"],
}
}
# Calculate and store module timings for the current batch
for module_obj in registered_hooks:
module_id = id(module_obj)
module_name = module_obj.__class__.__name__
times = timing_info.get(module_id, [])
durations = []
pre_times = {}
for t, event_type, _ in times:
if event_type == 'pre':
pre_times[module_id] = t
elif event_type == 'post' and module_id in pre_times:
duration = t - pre_times.pop(module_id)
durations.append(duration)
if durations:
current_batch_stats["module_inference_times"][module_name] = {
"total": sum(durations),
"count": len(durations),
"avg": sum(durations) / len(durations)
}
else:
current_batch_stats["module_inference_times"][module_name] = {
"total": 0.0,
"count": 0,
"avg": 0.0
}
batch_stats_list.append(current_batch_stats)
# --- Debug prints (optional) ---
print_rank(f"\n--- Inference Statistics for {encode_side} batch (Rank {local_rank}) ---")
print_rank(f"Batch Inference took: {batch_inference_time:.4f} seconds")
print_rank("--- Module Inference Timing Statistics ---")
for module_name, stats in current_batch_stats["module_inference_times"].items():
print_rank(f"**{module_name}**: Total: {stats['total']:.6f}s, Count: {stats['count']}, Avg: {stats['avg']:.6f}s")
print_rank("--- Token Count Statistics ---")
print_rank(f"**视觉 token 数量**: {current_batch_stats['token_counts']['visual_tokens']}")
print_rank(f"**语言输入 token 数量 (仅原始文本)**: {current_batch_stats['token_counts']['language_input_tokens_raw']}")
print_rank(f"**LLM总输入 token 数量 (包含视觉 + 格式化文本)**: {current_batch_stats['token_counts']['llm_total_input_tokens']}")
print_rank(f"**语言输出 token 数量**: {current_batch_stats['token_counts']['language_output_tokens']}")
if is_late_interaction and reps.dim() == 3:
local_max_len = max(local_max_len, reps.shape[1])
local_embeds.append(reps)
if not local_embeds:
# Handle cases where a rank gets no data
        return np.array([]), [], [], [], []  # CHANGED: five return values
# === DDP Synchronization and Padding for Late-Interaction Models ===
if is_late_interaction:
if dist.is_initialized():
# 1: global max length
local_max_len_tensor = torch.tensor(local_max_len, device=training_args.device)
dist.all_reduce(local_max_len_tensor, op=dist.ReduceOp.MAX)
global_max_len = local_max_len_tensor.item()
else:
global_max_len = local_max_len
# 2: pad to global max length
padded_embeds = []
for reps_batch in local_embeds:
if reps_batch.dim() == 3:
B, L, H = reps_batch.shape
padding_size = global_max_len - L
padded_batch = F.pad(reps_batch, (0, 0, 0, padding_size), "constant", 0)
padded_embeds.append(padded_batch)
else:
padded_embeds.append(reps_batch)
embeds_tensor = torch.cat(padded_embeds, dim=0).contiguous()
else:
embeds_tensor = torch.cat(local_embeds, dim=0).contiguous()
# === Gather embeddings and keys from all ranks ===
if dist.is_initialized() and full_dataset.num_rows >= world_size:
print_master(f"Gathering {encode_side} embeddings across all ranks...")
# tensor gather
output_shape = list(embeds_tensor.shape)
output_shape[0] = full_dataset.num_rows
embeds_tensor = embeds_tensor.to(training_args.device)
gathered_embeds_tensor = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device)
dist.all_gather_into_tensor(gathered_embeds_tensor, embeds_tensor)
final_embeddings = gathered_embeds_tensor.cpu().float().numpy()
# object gather for infos and stats
gathered_gt_infos = [None for _ in range(world_size)]
dist.all_gather_object(gathered_gt_infos, local_gt_infos)
all_gt_infos = [key for rank_keys in gathered_gt_infos for key in rank_keys]
gathered_batch_stats = [None for _ in range(world_size)]
dist.all_gather_object(gathered_batch_stats, batch_stats_list)
all_batch_stats = [stats for rank_stats in gathered_batch_stats for stats in rank_stats]
# --- NEW: gather masks ---
gathered_masks = [None for _ in range(world_size)]
dist.all_gather_object(gathered_masks, local_img_token_masks)
all_img_token_masks = [m for rank_list in gathered_masks for m in rank_list]
# --- NEW: gather qctm_stats ---
gathered_qctm_stats = [None for _ in range(world_size)]
dist.all_gather_object(gathered_qctm_stats, local_qctm_stats)
all_qctm_stats = [s for rank_list in gathered_qctm_stats for s in rank_list]
else:
all_gt_infos = local_gt_infos
final_embeddings = embeds_tensor.cpu().float().numpy()
all_batch_stats = batch_stats_list
all_img_token_masks = local_img_token_masks # NEW
all_qctm_stats = local_qctm_stats # NEW
return final_embeddings, all_gt_infos, all_batch_stats, all_img_token_masks, all_qctm_stats # CHANGED: +stats
def main():
# ----------------------- Distributed init -----------------------
if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
local_rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
print_master("Distributed init debug info:")
print_master(f"RANK: {os.environ.get('RANK')}")
print_master(f"LOCAL_RANK: {os.environ.get('LOCAL_RANK')}")
print_master(f"WORLD_SIZE: {os.environ.get('WORLD_SIZE')}")
print_master(f"MASTER_ADDR: {os.environ.get('MASTER_ADDR')}")
print_master(f"MASTER_PORT: {os.environ.get('MASTER_PORT')}")
if dist.is_initialized():
print_rank(f"dist.get_rank(): {dist.get_rank()}")
print_rank(f"dist.get_world_size(): {dist.get_world_size()}")
    # Compatibility shim for torchrun's "--local-rank=N" argument style
    for arg in list(sys.argv):  # iterate over a copy since we mutate sys.argv
        if arg.startswith("--local-rank="):
            rank = arg.split("=")[1]
            sys.argv.remove(arg)
            sys.argv.append('--local_rank')
            sys.argv.append(rank)
# ----------------------- Parse args -----------------------
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
model_args: ModelArguments
data_args: DataArguments
training_args: TrainingArguments
os.makedirs(data_args.encode_output_path, exist_ok=True)
    # Multi-layer evaluation (LM_LAYERS preferred; MID_LM_LAYER kept for backward compatibility)
layers_to_eval = get_env_eval_layers()
print_master(f"Eval layers (qry/tgt): {layers_to_eval}")
# ----------------------- Model loading -----------------------
hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
if not getattr(model_args, "model_backbone", None):
model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
setattr(model_args, 'model_backbone', model_backbone)
setattr(training_args, 'model_backbone', model_backbone)
print_master(f'Model Backbone: {model_args.model_backbone}')
    # Only rank 0 downloads; the other ranks wait and then load from the cache
if local_rank == 0:
processor = load_processor(model_args, data_args)
model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
print_master(f"[rank=0] Loading the model from Huggingface: {model_args.model_name}...")
if torch.distributed.is_initialized():
torch.distributed.barrier()
if local_rank != 0:
print_rank(f"Loading the model from cache...")
processor = load_processor(model_args, data_args)
time.sleep(random.randint(2 * local_rank, 3 * local_rank))
model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
model.eval()
model = model.to(training_args.device, dtype=torch.bfloat16)
    # ---- NEW: inject the AOP pruning config (drives the AOP logic implemented in the backbone) ----
aop_cfg = get_env_aop_config()
if aop_cfg["enabled"]:
        # Stash the config on the backbone; its forward reads this dict and performs the pruning
setattr(model.encoder, "aop_prune_config", aop_cfg)
        # Optional: override the attention implementation so the decision layer can expose attention weights or recompute qk by hand
attn_override = aop_cfg.get("attn_impl_override", "")
if attn_override:
try:
if hasattr(model.encoder, "model") and hasattr(model.encoder.model, "config"):
prev = model.encoder.model.config._attn_implementation
model.encoder.model.config._attn_implementation = attn_override
print_master(f"[AOP] override attn impl: {prev} -> {attn_override} (仅测试建议)")
except Exception as e:
print_master(f"[AOP] try override attn impl failed: {e}")
print_master("[AOP] AOP-Prune enabled with config: " + json.dumps({
"apply_to": aop_cfg["apply_to"],
"layer_idx": aop_cfg["layer_idx"],
"mode": aop_cfg["mode"],
"delta": aop_cfg["delta"],
"K_hat": aop_cfg["K_hat"],
"keep_ratio": aop_cfg["keep_ratio"],
"min_keep": aop_cfg["min_keep"],
"use_bias": aop_cfg["use_bias"],
"margin_mid?": (aop_cfg["margin_mid"] is not None)
}))
else:
print_master("[AOP] disabled (set AOP_ENABLED=1 to enable)")
    # ---- NEW: inject the QCTM merge config (drives the QCTM logic implemented in the backbone) ----
qctm_cfg = get_env_qctm_config()
if qctm_cfg["enabled"]:
        # Stash the config on the backbone; its forward reads this dict and performs the merging
setattr(model.encoder, "qctm_config", qctm_cfg)
print_master("[QCTM] enabled with config: " + json.dumps({
"apply_to": qctm_cfg["apply_to"],
"keep_ratio": qctm_cfg["keep_ratio"],
"min_keep": qctm_cfg["min_keep"],
"knn_k": qctm_cfg["knn_k"],
"c_k": qctm_cfg["c_k"],
"kappa_v": qctm_cfg["kappa_v"],
}))
else:
print_master("[QCTM] disabled (set QCTM_ENABLED=1 to enable)")
    # Make sure the "last layer" setting performs no layer cutting (avoids the class default of 20 layers)
model.set_inference_layers(qry_layers=None, tgt_layers=None)
with open(data_args.dataset_config, 'r') as yaml_file:
dataset_configs = yaml.safe_load(yaml_file)
# ----------------------- Main evaluation loop -----------------------
for dataset_idx, (dataset_name, task_config) in enumerate(dataset_configs.items()):
if dist.is_initialized():
dist.barrier()
print_master(f"\n--- Evaluating {dataset_name} ---")
        # Resolve dataset paths against data_basedir
if data_args.data_basedir is not None:
for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
if data_args.data_basedir and task_config.get(key):
task_config[key] = os.path.join(data_args.data_basedir, task_config[key])
        # Build the query and candidate datasets
full_eval_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_config)
full_eval_cand_dataset = generate_cand_dataset(full_eval_qry_dataset, corpus)
eval_qry_dataset, eval_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
if dist.is_initialized():
world_size = dist.get_world_size()
padded_qry_dataset, _ = pad_dataset_to_divisible(full_eval_qry_dataset, world_size)
padded_cand_dataset, _ = pad_dataset_to_divisible(full_eval_cand_dataset, world_size)
eval_qry_dataset = split_dataset_by_node(padded_qry_dataset, rank=local_rank, world_size=world_size)
eval_cand_dataset = split_dataset_by_node(padded_cand_dataset, rank=local_rank, world_size=world_size)
else:
padded_qry_dataset, padded_cand_dataset = full_eval_qry_dataset, full_eval_cand_dataset
        # Index of saved embedding paths
        saved_paths = {}  # {(side, tag): path}
        # --------- Encode and save separately for each layer setting (intermediate layers / last layer) ---------
for keep_layers in layers_to_eval:
tag = make_layer_tag(keep_layers)
print_master(f"[{dataset_name}] Start encoding for tag={tag} (keep_layers={keep_layers})")
            # Set the number of transformer layers used at inference time
            model.set_inference_layers(qry_layers=keep_layers, tgt_layers=keep_layers)
            # Output paths
query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_{tag}")
cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{tag}")
dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl")
query_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_inference_stats_{tag}.json")
cand_inference_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_inference_stats_{tag}.json")
qry_img_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_img_token_masks_{tag}.jsonl")
cand_img_masks_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_img_token_masks_{tag}.jsonl")
qry_qctm_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_qctm_stats_{tag}.jsonl")
cand_qctm_stats_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_cand_qctm_stats_{tag}.jsonl")
saved_paths[("qry", tag)] = query_embed_path
saved_paths[("tgt", tag)] = cand_embed_path
do_query = not os.path.exists(query_embed_path) or not os.path.exists(dataset_info_path)
do_cand = not os.path.exists(cand_embed_path)
            # Running totals, accumulated batch by batch
            def init_total_stats():
                return {
                    "total_inference_time_seconds": 0.0,
                    "module_inference_times": {},  # dynamic module name -> {"total": float, "count": int}
"token_counts": {
"visual_tokens": 0,
"language_input_tokens_raw": 0,
"llm_total_input_tokens": 0,
"language_output_tokens": 0,
},
"data_point_count": 0
}
def accumulate_stats(total_stats, batch_stats):
batch_size = batch_stats["batch_size"]
total_stats["total_inference_time_seconds"] += batch_stats["total_inference_time_seconds"]
                # Module timings
for mname, mstats in batch_stats["module_inference_times"].items():
if mname not in total_stats["module_inference_times"]:
total_stats["module_inference_times"][mname] = {"total": 0.0, "count": 0}
total_stats["module_inference_times"][mname]["total"] += mstats.get("total", 0.0)
total_stats["module_inference_times"][mname]["count"] += mstats.get("count", 0)
                # Token counts (multiply per-sample values by batch_size before accumulating)
total_stats["token_counts"]["visual_tokens"] += batch_stats["token_counts"]["visual_tokens"] * batch_size
total_stats["token_counts"]["language_input_tokens_raw"] += batch_stats["token_counts"]["language_input_tokens_raw"] * batch_size
total_stats["token_counts"]["llm_total_input_tokens"] += batch_stats["token_counts"]["llm_total_input_tokens"] * batch_size
total_stats["token_counts"]["language_output_tokens"] += batch_stats["token_counts"]["language_output_tokens"] * batch_size
total_stats["data_point_count"] += batch_size
def finalize_and_save_stats(total_stats, out_path, task_name, encode_side):
if local_rank != 0:
return
if total_stats["data_point_count"] <= 0:
print_master(f"No data processed for {task_name} [{encode_side}], skip saving stats.")
return
final_stats = {
"task_name": task_name,
"encode_side": encode_side,
"data_point_count": total_stats["data_point_count"],
"inference_times": {
"total_inference_time_seconds": total_stats["total_inference_time_seconds"],
"avg_inference_time_per_item_seconds": total_stats["total_inference_time_seconds"] / max(1, total_stats["data_point_count"]),
"module_average_times_per_call": {},
"module_total_times_seconds": {},
"module_calls_count": {},
},
"token_counts": {
"total_visual_tokens": total_stats["token_counts"]["visual_tokens"],
"avg_visual_tokens_per_item": total_stats["token_counts"]["visual_tokens"] / max(1, total_stats["data_point_count"]),
"total_language_input_tokens_raw": total_stats["token_counts"]["language_input_tokens_raw"],
"avg_language_input_tokens_raw_per_item": total_stats["token_counts"]["language_input_tokens_raw"] / max(1, total_stats["data_point_count"]),
"total_llm_total_input_tokens": total_stats["token_counts"]["llm_total_input_tokens"],
"avg_llm_total_input_tokens_per_item": total_stats["token_counts"]["llm_total_input_tokens"] / max(1, total_stats["data_point_count"]),
"total_language_output_tokens": total_stats["token_counts"]["language_output_tokens"],
"avg_language_output_tokens_per_item": total_stats["token_counts"]["language_output_tokens"] / max(1, total_stats["data_point_count"]),
}
}
for mname, mstats in total_stats["module_inference_times"].items():
total = mstats.get("total", 0.0)
count = mstats.get("count", 0)
final_stats["inference_times"]["module_total_times_seconds"][mname] = total
final_stats["inference_times"]["module_calls_count"][mname] = count
final_stats["inference_times"]["module_average_times_per_call"][mname] = (total / count) if count > 0 else 0.0
with open(out_path, 'w', encoding='utf-8') as f:
json.dump(final_stats, f, ensure_ascii=False, indent=4)
print_master(f"[{task_name}] {encode_side} inference statistics saved to: {out_path}")
# ------- Encode queries -------
if do_query:
print_master(f"[{tag}] Encoding queries...")
eval_qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
eval_qry_loader = DataLoader(
eval_qry_dataset,
batch_size=training_args.per_device_eval_batch_size,
collate_fn=eval_qry_collator,
num_workers=training_args.dataloader_num_workers
)
query_embeds, gt_infos, qry_batch_stats, qry_img_masks, qry_qctm_stats = encode_embeddings(
model, eval_qry_loader, training_args, model_args, padded_qry_dataset,
encode_side="qry", description=f"Queries[{tag}] for {dataset_name}"
)
                # Truncate to the true (unpadded) dataset length
true_Nq = len(full_eval_qry_dataset)
query_embeds = query_embeds[:true_Nq]
gt_infos = gt_infos[:true_Nq]
qry_img_masks = qry_img_masks[:true_Nq]
if isinstance(qry_qctm_stats, list):
qry_qctm_stats = qry_qctm_stats[:true_Nq]
                # Accumulate statistics and save
qry_total_stats = init_total_stats()
for bs in qry_batch_stats:
accumulate_stats(qry_total_stats, bs)
if local_rank == 0:
with open(query_embed_path, 'wb') as f:
pickle.dump(query_embeds, f)
                    # dataset_info only needs to be written once; write it the first time through
if not os.path.exists(dataset_info_path):
with open(dataset_info_path, 'w') as f:
for info in gt_infos:
f.write(json.dumps(info) + '\n')
                    # Save the image token masks
with open(qry_img_masks_path, 'w', encoding='utf-8') as f:
for i, m in enumerate(qry_img_masks):
f.write(json.dumps({"index": i, "mask": m}, ensure_ascii=False) + "\n")
print_master(f"Saved query embeddings to {query_embed_path}")
print_master(f"Saved query image token masks to {qry_img_masks_path}")
                    # Save QCTM stats (toggle with QCTM_SAVE_STATS, default on)
if os.environ.get("QCTM_SAVE_STATS", "1") == "1" and isinstance(qry_qctm_stats, list):
with open(qry_qctm_stats_path, 'w', encoding='utf-8') as f:
for i, st in enumerate(qry_qctm_stats):
f.write(json.dumps({"index": i, "qctm_stats": st}, ensure_ascii=False) + "\n")
print_master(f"Saved query QCTM stats to {qry_qctm_stats_path}")
if local_rank == 0 and isinstance(qry_qctm_stats, list) and len(qry_qctm_stats) > 0:
effs = [st.get("effective_keep_ratio") for st in qry_qctm_stats if isinstance(st, dict)]
pairs = [st.get("num_pairs") for st in qry_qctm_stats if isinstance(st, dict)]
if effs:
print_master(f"[QCTM][{dataset_name}][{tag}][qry] eff_keep_ratio avg/min/max = "
f"{np.mean(effs):.3f}/{np.min(effs):.3f}/{np.max(effs):.3f}; "
f"num_pairs avg = {np.mean(pairs):.2f}")
finalize_and_save_stats(qry_total_stats, query_inference_stats_path, dataset_name, f"query[{tag}]")
if dist.is_initialized():
dist.barrier()
# ------- Encode candidates -------
if do_cand:
print_master(f"[{tag}] Encoding candidates...")
eval_cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
eval_cand_loader = DataLoader(
eval_cand_dataset,
batch_size=training_args.per_device_eval_batch_size,
collate_fn=eval_cand_collator,
num_workers=training_args.dataloader_num_workers
)
cand_embeds, all_cand_ids, cand_batch_stats, cand_img_masks, cand_qctm_stats = encode_embeddings(
model, eval_cand_loader, training_args, model_args, padded_cand_dataset,
encode_side="cand", description=f"Candidates[{tag}] for {dataset_name}"
)
true_Nc = len(full_eval_cand_dataset)
cand_embeds = cand_embeds[:true_Nc]
all_cand_ids = all_cand_ids[:true_Nc]
cand_img_masks = cand_img_masks[:true_Nc]
if isinstance(cand_qctm_stats, list):
cand_qctm_stats = cand_qctm_stats[:true_Nc]
cand_total_stats = init_total_stats()
for bs in cand_batch_stats:
accumulate_stats(cand_total_stats, bs)
if local_rank == 0:
cand_embed_dict = {cid: emb for cid, emb in zip(all_cand_ids, cand_embeds)}
with open(cand_embed_path, 'wb') as f:
pickle.dump(cand_embed_dict, f)
with open(cand_img_masks_path, 'w', encoding='utf-8') as f:
for cid, m in zip(all_cand_ids, cand_img_masks):
f.write(json.dumps({"cand_id": str(cid), "mask": m}, ensure_ascii=False) + "\n")
print_master(f"Saved candidate embeddings to {cand_embed_path}")
print_master(f"Saved candidate image token masks to {cand_img_masks_path}")
if os.environ.get("QCTM_SAVE_STATS", "1") == "1" and isinstance(cand_qctm_stats, list):
with open(cand_qctm_stats_path, 'w', encoding='utf-8') as f:
for cid, st in zip(all_cand_ids, cand_qctm_stats):
f.write(json.dumps({"cand_id": str(cid), "qctm_stats": st}, ensure_ascii=False) + "\n")
print_master(f"Saved candidate QCTM stats to {cand_qctm_stats_path}")
if local_rank == 0 and isinstance(cand_qctm_stats, list) and len(cand_qctm_stats) > 0:
effs = [st.get("effective_keep_ratio") for st in cand_qctm_stats if isinstance(st, dict)]
pairs = [st.get("num_pairs") for st in cand_qctm_stats if isinstance(st, dict)]
if effs:
print_master(f"[QCTM][{dataset_name}][{tag}][cand] eff_keep_ratio avg/min/max = "
f"{np.mean(effs):.3f}/{np.min(effs):.3f}/{np.max(effs):.3f}; "
f"num_pairs avg = {np.mean(pairs):.2f}")
finalize_and_save_stats(cand_total_stats, cand_inference_stats_path, dataset_name, f"candidate[{tag}]")
if dist.is_initialized():
dist.barrier()
# --------- Scoring per layer + combined + early-exit curve ---------
if local_rank == 0:
dataset_info_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_info.jsonl")
            with open(dataset_info_path, 'r') as f:
                gt_infos = [json.loads(l) for l in f]
rank_against_all_candidates = task_config.get("eval_type", "global") == "global"
metrics_to_report = task_config.get("metrics", ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"])
layer_tags = [make_layer_tag(l) for l in layers_to_eval]
sims_by_layer = {} # tag -> list[ dict(cand_id->score) ]
for tag in layer_tags:
query_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_qry_{tag}")
cand_embed_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_tgt_{tag}")
with open(query_embed_path, 'rb') as f:
qry_embeds = pickle.load(f)
with open(cand_embed_path, 'rb') as f:
cand_embed_dict = pickle.load(f)
pred_dicts = []
score_detail_dicts = []
sims_for_exit = []
if rank_against_all_candidates:
cand_keys = list(cand_embed_dict.keys())
cand_embeds = np.stack([cand_embed_dict[key] for key in cand_keys])
if isinstance(qry_embeds, np.ndarray) and qry_embeds.ndim == 3:
# Late-interaction
qry_embed_t = torch.from_numpy(qry_embeds)
cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds]
sim_matrix = processor.score(qry_embed_t, cand_embeds_t, batch_size=64).cpu().numpy()
else:
sim_matrix = np.dot(qry_embeds, cand_embeds.T)
ranked_all = np.argsort(-sim_matrix, axis=1)
for qid, gt_info in tqdm(enumerate(gt_infos), total=len(gt_infos), desc=f"[{tag}] scoring(all) {dataset_name}"):
ranked_indices = ranked_all[qid]
rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
rel_scores = gt_info.get("rel_scores")
pred_dicts.append({
"prediction": [cand_keys[i] for i in ranked_indices],
"label": rel_docids,
"rel_scores": rel_scores,
})
score_detail_dicts.append(build_score_details(qid, cand_keys, sim_matrix[qid], ranked_indices))
sims_for_exit.append({cand_keys[i]: float(sim_matrix[qid][i]) for i in range(len(cand_keys))})
else:
                # Non-global: score each query against only the subset in gt_info["cand_names"]
for qid, (qry_embed, gt_info) in tqdm(enumerate(zip(qry_embeds, gt_infos)), total=len(gt_infos), desc=f"[{tag}] scoring(local) {dataset_name}"):
cand_ids_local = gt_info["cand_names"]
cand_embeds = np.stack([cand_embed_dict[key] for key in cand_ids_local])
if isinstance(qry_embeds, np.ndarray) and qry_embeds.ndim == 3:
qry_embed_t = torch.from_numpy(np.array(qry_embed)).unsqueeze(0) # [1, Lq, H]
cand_embeds_t = [torch.from_numpy(np.array(t)) for t in cand_embeds]
sim_vec = processor.score(qry_embed_t, cand_embeds_t, batch_size=1024).cpu().numpy()[0]
else:
sim_vec = np.dot(qry_embed, cand_embeds.T)
ranked_indices = np.argsort(-sim_vec)
rel_docids = gt_info["label_name"] if isinstance(gt_info["label_name"], list) else [gt_info["label_name"]]
rel_scores = gt_info.get("rel_scores")
pred_dicts.append({
"prediction": [cand_ids_local[i] for i in ranked_indices],
"label": rel_docids,
"rel_scores": rel_scores,
})
score_detail_dicts.append(build_score_details(qid, cand_ids_local, sim_vec, ranked_indices))
sims_for_exit.append({cid: float(s) for cid, s in zip(cand_ids_local, sim_vec.tolist())})
                # Save this layer's metrics and score details
layer_score_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_{tag}.json")
layer_pred_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_pred_{tag}.jsonl")
layer_detail_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_score_details_{tag}.jsonl")
metrics = RankingMetrics(metrics_to_report)
score_dict = metrics.evaluate(pred_dicts)
score_dict["num_pred"] = len(pred_dicts)
score_dict["num_data"] = len(gt_infos)
with open(layer_score_path, "w") as f:
json.dump(score_dict, f, indent=4)
with open(layer_pred_path, "w") as f:
for pred in pred_dicts:
f.write(json.dumps(pred) + '\n')
with open(layer_detail_path, "w") as f:
for detail in score_detail_dicts:
f.write(json.dumps(detail) + "\n")
print_master(f"[{dataset_name}] {tag} score: " + json.dumps({k: (f"{v:.4f}" if isinstance(v, (int, float)) else v) for k, v in score_dict.items()}))
sims_by_layer[tag] = sims_for_exit
            # Combined comparison files + early-exit curves (vs. the last layer), if a last layer was evaluated
            last_tag = "layerlast" if "layerlast" in layer_tags else None
            if last_tag is not None:
                # Prepare labels once
labels = [
gi["label_name"] if isinstance(gi["label_name"], list) else [gi["label_name"]]
for gi in gt_infos
]
taus = [round(x, 3) for x in np.linspace(0.0, 0.6, 31).tolist()]
                # Compare each intermediate layer against the last layer
                for mid_tag in [t for t in layer_tags if t != last_tag]:
                    # Combined details: per query, the mid/last cand_scores, top1, and margin
combined_path = os.path.join(
data_args.encode_output_path,
f"{dataset_name}_score_details_{mid_tag}_vs_last.jsonl"
)
with open(combined_path, "w", encoding="utf-8") as f:
for qid in range(len(gt_infos)):
sims_mid = sims_by_layer[mid_tag][qid]
sims_last = sims_by_layer[last_tag][qid]
def top1_cid(sims: dict):
return max(sims.items(), key=lambda x: x[1])[0] if sims else None
def margin_of(sims: dict):
vals = np.array(list(sims.values()), dtype=np.float32)
return top1_top2_margin_from_array(vals)
row = {
"qid": int(qid),
"label": labels[qid],
"mid": {
"top1": top1_cid(sims_mid),
"margin": margin_of(sims_mid),
"cand_scores": sims_mid
},
"last": {
"top1": top1_cid(sims_last),
"margin": margin_of(sims_last),
"cand_scores": sims_last
}
}
f.write(json.dumps(row, ensure_ascii=False) + "\n")
print_master(f"[{dataset_name}] combined details saved to {combined_path} (mid={mid_tag} vs last)")
                    # Early-exit curve (sweep the margin threshold)
exit_curve = simulate_early_exit_by_margin(
sims_by_layer[mid_tag], sims_by_layer[last_tag], labels, metrics_to_report, taus, rank_against_all_candidates
)
curve_path = os.path.join(
data_args.encode_output_path,
f"{dataset_name}_early_exit_curve_margin_{mid_tag}_vs_last.json"
)
with open(curve_path, "w") as f:
json.dump(exit_curve, f, indent=4)
print_master(f"[{dataset_name}] early-exit curve saved to {curve_path} (mid={mid_tag} vs last)")
if dist.is_initialized():
dist.barrier()
if __name__ == '__main__':
main()