# code_SAS_VLM2Vec / eval_test_time_cut_layer_unified.py
# Uploaded by MgGladys: "Add files using upload-large-folder tool" (commit 2a40e7a, verified)
# -*- coding: utf-8 -*-
import logging
import json
import random
import time
import numpy as np
import os
import pickle
import sys
import torch
import torch.distributed as dist
import torch.nn.functional as F
import yaml
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import HfArgumentParser, AutoConfig
from datasets import concatenate_datasets
from datasets.distributed import split_dataset_by_node
from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.data.collator.eval_collator import MultimodalEvalDataCollator
from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset
from src.eval_utils.metrics import RankingMetrics
from src.model.model_cut_layer_AOP import MMEBModel # NOTE: 使用你的 AOP 版本(支持 cut_layer + mask 透传)
from src.model.processor import get_backbone_name, load_processor, COLPALI
from src.utils import batch_to_device, print_rank, print_master
# Root logging: INFO level, timestamped records tagged with logger name and line.
_LOG_FORMAT = '[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# ----------------- 环境变量解析 -----------------
def _parse_bool(v: str, default=False):
if v is None: return default
v = v.strip().lower()
return v in {"1","true","yes","y","t","on"}
def _parse_float(v: str, default=None):
try: return float(v) if v is not None else default
except: return default
def _parse_int(v: str, default=None):
try: return int(v) if v is not None else default
except: return default
def get_env_eval_layers():
    """Parse the ``LM_LAYERS`` environment variable into the list of layers to evaluate.

    ``LM_LAYERS`` is a comma-separated list such as ``"4,8,12,last"``. The tokens
    ``last``/``none``/``null``/``-1`` (case-insensitive) map to ``None``, which the
    caller uses for "last layer" (no cut). Other tokens must be positive integers;
    invalid or non-positive tokens are logged and skipped. Duplicates are removed
    while preserving first-seen order.

    Returns:
        A non-empty list of ``int``/``None``; ``[None]`` when the variable is unset,
        empty, or yields no valid entries.
    """
    raw = os.environ.get("LM_LAYERS", None)
    if not raw:
        return [None]
    layers = []
    for tok in (t.strip() for t in raw.split(',')):
        if not tok:
            continue
        if tok.lower() in {"last", "none", "null", "-1"}:
            layers.append(None)
            continue
        try:
            val = int(tok)
        except ValueError:
            logger.warning(f"Invalid token '{tok}' in LM_LAYERS; ignored.")
            continue
        if val > 0:
            layers.append(val)
        else:
            logger.warning(f"Non-positive layer '{tok}' in LM_LAYERS; ignored.")
    # dict.fromkeys de-duplicates while keeping insertion order (None is hashable).
    return list(dict.fromkeys(layers)) or [None]
def get_env_zip_config():
    """Read the VisionZip (input-side visual-token compression) settings from the environment.

    Environment variables:
        ZIP_ENABLED  -- truthy value turns compression on (default off).
        ZIP_APPLY    -- side to compress: ``qry`` | ``cand`` | ``both`` (default ``both``).
        ZIP_METHOD   -- ``visionzip`` (default) or ``none``; ``none`` forces ``enabled=False``.
        ZIP_KEEP_DOM -- share of dominant tokens to keep (default 0.45).
        ZIP_KEEP_CTX -- share of contextual tokens to keep (default 0.10).

    Returns:
        dict with keys ``enabled``, ``apply_to``, ``method``, ``keep_dom``, ``keep_ctx``.
    """
    env = os.environ
    method = env.get("ZIP_METHOD", "visionzip").strip().lower()
    # Method "none" is an explicit off switch regardless of ZIP_ENABLED.
    enabled = _parse_bool(env.get("ZIP_ENABLED"), False) and method != "none"
    return {
        "enabled": enabled,
        "apply_to": env.get("ZIP_APPLY", "both").strip().lower(),
        "method": method,
        "keep_dom": _parse_float(env.get("ZIP_KEEP_DOM"), 0.45),
        "keep_ctx": _parse_float(env.get("ZIP_KEEP_CTX"), 0.10),
    }
def get_env_aop_config():
    """Read the AOP (in-layer token pruning) settings from the environment.

    Environment variables:
        AOP_ENABLED    -- truthy value enables pruning (default off).
        AOP_APPLY      -- ``qry`` | ``cand`` | ``both`` (default ``qry``).
        AOP_LAYER      -- 1-based layer index; pruning happens once before entering
                          that layer. If unset while enabled, AOP is disabled with a warning.
        AOP_MODE       -- ``delta`` (default) or ``ratio``.
        AOP_DELTA / AOP_KHAT / AOP_KEEP_RATIO / AOP_MIN_KEEP / AOP_USE_BIAS
                       -- knobs with defaults 0.10 / 1.0 / 1.0 / 64 / True.
        AOP_ATTN_IMPL  -- optional attention-implementation override; only ``sdpa`` is kept.

    Returns:
        dict that the caller attaches to the encoder as ``aop_prune_config``.
    """
    env = os.environ
    enabled = _parse_bool(env.get("AOP_ENABLED"), False)
    layer_idx = _parse_int(env.get("AOP_LAYER"), None)
    if enabled and layer_idx is None:
        logger.warning("AOP_ENABLED=1 但未设置 AOP_LAYER,关闭 AOP。")
        enabled = False
    attn_impl = env.get("AOP_ATTN_IMPL", "").strip().lower()
    return {
        "enabled": enabled,
        "apply_to": env.get("AOP_APPLY", "qry").strip().lower(),
        "layer_idx": layer_idx,
        "mode": env.get("AOP_MODE", "delta").strip().lower(),
        "delta": _parse_float(env.get("AOP_DELTA"), 0.10),
        "K_hat": _parse_float(env.get("AOP_KHAT"), 1.0),
        "keep_ratio": _parse_float(env.get("AOP_KEEP_RATIO"), 1.0),
        "min_keep": _parse_int(env.get("AOP_MIN_KEEP"), 64),
        "use_bias": _parse_bool(env.get("AOP_USE_BIAS"), True),
        "eps": 1e-6,
        "attn_impl_override": attn_impl if attn_impl == "sdpa" else "",
    }
# ----------------- Hooks & utilities -----------------
# Per-batch profiling state filled by the forward hooks below; encode_embeddings
# clears/resets it before each batch.
# timing_info: id(module) -> list of (timestamp, 'pre' | 'post', class name)
timing_info = {}
# token_info: token counters for the current batch.
token_info = {"vision_tokens": 0, "text_input_tokens": 0, "text_output_tokens": 0, "total_llm_input_tokens": 0}

# Attribute stored on a module once its timing hooks are attached, so repeated
# register_model_hooks calls do not stack duplicate hooks.
_TIMING_HOOK_ATTR = "_eval_timing_hook_handles"


def timing_pre_hook(module, input):
    """Forward pre-hook: append a 'pre' wall-clock timestamp for ``module``."""
    timing_info.setdefault(id(module), []).append((time.time(), 'pre', module.__class__.__name__))


def timing_post_hook(module, input, output):
    """Forward hook: append a 'post' timestamp and record vision-tower token counts.

    Modules with no 'pre' entry in ``timing_info`` are ignored. Timestamps are
    host wall-clock (``time.time()``), so on CUDA they do not wait for queued kernels.
    """
    module_id = id(module)
    if module_id not in timing_info:
        return
    class_name = module.__class__.__name__
    timing_info[module_id].append((time.time(), 'post', class_name))
    lowered = class_name.lower()
    if "vision" in lowered and "transformer" in lowered:
        if isinstance(output, torch.Tensor):
            # NOTE(review): assumes a flattened (num_visual_tokens, hidden) output — confirm.
            token_info["vision_tokens"] = output.shape[0]
        elif hasattr(output, 'last_hidden_state'):
            token_info["vision_tokens"] = output.last_hidden_state.shape[1]


def register_model_hooks(model):
    """Attach timing hooks to the visual tower, its merger, the LLM and lm_head.

    The modules are looked up on ``model.encoder`` when present (else on ``model``).
    Each module is hooked at most once: the returned handles are stored on it, so
    calling this for every encode pass no longer stacks duplicate hooks.

    Returns:
        The hooked modules in the order visual, merger, LLM (``.model``), lm_head,
        skipping any that are missing or ``None``.
    """
    core = model.encoder if getattr(model, "encoder", None) is not None else model
    visual = getattr(core, "visual", None)
    candidates = (
        visual,
        getattr(visual, "merger", None),
        getattr(core, "model", None),
        getattr(core, "lm_head", None),
    )
    hooked = []
    for module in candidates:
        if module is None:
            continue
        if getattr(module, _TIMING_HOOK_ATTR, None) is None:
            handles = (
                module.register_forward_pre_hook(timing_pre_hook),
                module.register_forward_hook(timing_post_hook),
            )
            setattr(module, _TIMING_HOOK_ATTR, handles)
        hooked.append(module)
    return hooked
def pad_dataset_to_divisible(dataset, world_size):
    """Pad ``dataset`` by cycling its leading rows until its length divides ``world_size``.

    Args:
        dataset: object with ``__len__`` and ``select`` (a ``datasets.Dataset`` here).
        world_size: number of ranks the dataset will be sharded across.

    Returns:
        ``(padded_dataset, padded_length)``. The input object is returned unchanged
        when its length is already a multiple of ``world_size`` (including 0).
    """
    n = len(dataset)
    remainder = n % world_size
    if not remainder:
        return dataset, n
    shortfall = world_size - remainder
    filler = dataset.select([idx % n for idx in range(shortfall)])
    return concatenate_datasets([dataset, filler]), n + shortfall
# ----------------- Encoding (AOP + VisionZip injection + cut_layer) -----------------
def _side_enabled(apply_to, encode_side):
    """Return True if a feature configured for ``apply_to`` applies to ``encode_side``.

    ``apply_to`` is ``"both"``, ``"qry"`` or ``"cand"``; ``"tgt"`` is accepted as an
    alias of ``"cand"``.
    """
    return apply_to in ("both", encode_side) or (encode_side == "cand" and apply_to == "tgt")


def _cuda_sync(device):
    """Wait for queued CUDA work on ``device`` so wall-clock timings include GPU compute."""
    dev = torch.device(device)
    if dev.type == "cuda" and torch.cuda.is_available():
        torch.cuda.synchronize(dev)


def _normalize_img_masks(img_masks, batch_size):
    """Turn the model's image-token masks into ``batch_size`` JSON-serializable entries.

    A 2-D tensor is converted row-wise with ``tolist()``; other tensor ranks and
    unknown types yield ``[None] * batch_size``; a list has tensor items converted
    and is truncated/padded with ``None`` to ``batch_size``.
    """
    if torch.is_tensor(img_masks):
        return img_masks.detach().cpu().tolist() if img_masks.dim() == 2 else [None] * batch_size
    if isinstance(img_masks, list):
        masks = [m.detach().cpu().tolist() if torch.is_tensor(m) else m for m in img_masks[:batch_size]]
        return masks + [None] * (batch_size - len(masks))
    return [None] * batch_size


def _collect_module_times(modules):
    """Pair each module's 'pre'/'post' hook timestamps into duration statistics.

    Returns a dict keyed by module class name with ``total``/``count``/``avg`` seconds.
    """
    result = {}
    for module in modules:
        durations, pre = [], None
        for ts, kind, _ in timing_info.get(id(module), []):
            if kind == 'pre':
                pre = ts
            elif kind == 'post' and pre is not None:
                durations.append(ts - pre)
                pre = None
        total = sum(durations, 0.0)
        result[module.__class__.__name__] = {
            "total": total,
            "count": len(durations),
            "avg": total / len(durations) if durations else 0.0,
        }
    return result


def _gather_object_lists(local_list, world_size):
    """all_gather a per-rank Python list and concatenate the parts in rank order."""
    buckets = [None] * world_size
    dist.all_gather_object(buckets, local_list)
    return [item for bucket in buckets for item in bucket]


def encode_embeddings(
    model: MMEBModel,
    loader: DataLoader,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    full_dataset,
    encode_side: str,
    zip_cfg: dict,
    aop_cfg: dict,
    description: str = "Encoding"
):
    """Encode one side (queries or candidates) of an eval dataset and collect per-batch stats.

    Args:
        model: MMEBModel, called as ``model(qry=...)`` or ``model(tgt=...)``.
        loader: DataLoader over this rank's shard, yielding ``(inputs, dataset_info)``.
        training_args: supplies ``device``.
        model_args: ``model_backbone == COLPALI`` selects late-interaction handling.
        full_dataset: the unpadded dataset; its length is used to drop the rows that
            ``pad_dataset_to_divisible`` appended before sharding.
        encode_side: ``"qry"``; any other value is treated as candidates.
        zip_cfg: VisionZip config (see ``get_env_zip_config``).
        aop_cfg: AOP config (see ``get_env_aop_config``). Its ``enabled`` flag is
            narrowed to this side for each forward pass and restored afterwards.
        description: tqdm label.

    Returns:
        ``(embeddings, infos, batch_stats, img_masks)``:
        - a float32 numpy array with one row per item;
        - the matching infos (query infos, or candidate names);
        - the per-batch stat dicts from all ranks;
        - one image-token mask entry per item.
        ``(np.array([]), [], [], [])`` when this rank produced no batches.
    """
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    is_late_interaction = (model_args.model_backbone == COLPALI)
    device = training_args.device
    embeds, infos, batch_stats_list, img_masks_all = [], [], [], []
    local_max_len = 0

    model.eval()
    regs = register_model_hooks(model)
    enc = model.encoder
    zip_active = bool(zip_cfg) and zip_cfg.get("enabled", False) \
        and _side_enabled(zip_cfg.get("apply_to", "both"), encode_side)

    with torch.no_grad():
        for inputs, dataset_info in tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0):
            # Reset the per-batch profiling state filled by the forward hooks.
            timing_info.clear()
            token_info.update({"vision_tokens": 0, "text_input_tokens": 0, "text_output_tokens": 0, "total_llm_input_tokens": 0})
            inputs = batch_to_device(inputs, device)
            input_ids = inputs.get('input_ids')
            batch_size = input_ids.shape[0] if input_ids is not None else 1

            if zip_active:
                inputs["zip_runtime_cfg"] = {
                    "enable": True,
                    "method": zip_cfg.get("method", "visionzip"),
                    "keep_dominant_ratio": float(zip_cfg.get("keep_dom", 0.45)),
                    "keep_context_ratio": float(zip_cfg.get("keep_ctx", 0.10)),
                }

            # AOP: narrow the shared config's `enabled` flag to this side for the
            # forward pass; `finally` restores it even if the forward raises.
            orig_aop_enabled = None
            if isinstance(aop_cfg, dict) and aop_cfg:
                orig_aop_enabled = aop_cfg.get("enabled", False)
                aop_cfg["enabled"] = bool(
                    orig_aop_enabled and _side_enabled(aop_cfg.get("apply_to", "qry"), encode_side)
                )
                setattr(enc, "aop_prune_config", aop_cfg)
            try:
                # CUDA kernels run asynchronously; synchronize so t1 - t0 covers the
                # GPU work instead of just the kernel launches.
                _cuda_sync(device)
                t0 = time.time()
                with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                    out = model(qry=inputs) if encode_side == "qry" else model(tgt=inputs)
                _cuda_sync(device)
                t1 = time.time()
            finally:
                if orig_aop_enabled is not None:
                    aop_cfg["enabled"] = orig_aop_enabled
                    setattr(enc, "aop_prune_config", aop_cfg)

            if encode_side == "qry":
                reps = out["qry_reps"].detach()
                infos.extend(dataset_info)
            else:
                reps = out["tgt_reps"].detach()
                infos.extend([x["cand_name"] for x in dataset_info])

            # Image-token masks: prefer the forward output, then the model attribute.
            img_masks = out.get("image_token_bool_masks", None) if isinstance(out, dict) else None
            if img_masks is None:
                img_masks = getattr(model, "_last_image_token_bool_masks", None)
            img_masks_all.extend(_normalize_img_masks(img_masks, batch_size))

            # Token counts are read after the forward, as in the original flow.
            ids_after = inputs.get('input_ids') if 'input_ids' in inputs else None
            if ids_after is not None:
                token_info["total_llm_input_tokens"] = int(ids_after.shape[1])
                token_info["text_input_tokens"] = max(0, token_info["total_llm_input_tokens"] - token_info["vision_tokens"])

            batch_stats_list.append({
                "batch_size": batch_size,
                "total_inference_time_seconds": float(t1 - t0),
                "module_inference_times": _collect_module_times(regs),
                "token_counts": {
                    "visual_tokens": token_info["vision_tokens"],
                    "language_input_tokens_raw": token_info["text_input_tokens"],
                    "llm_total_input_tokens": token_info["total_llm_input_tokens"],
                    "language_output_tokens": token_info["text_output_tokens"],
                },
            })
            print_rank(f"[{encode_side}] time={t1-t0:.4f}s, vis_tokens={token_info['vision_tokens']}")

            if is_late_interaction and reps.dim() == 3:
                local_max_len = max(local_max_len, reps.shape[1])
            embeds.append(reps)

    if not embeds:
        return np.array([]), [], [], []

    if is_late_interaction:
        global_max_len = local_max_len
        if dist.is_initialized():
            lm = torch.tensor(local_max_len, device=device)
            dist.all_reduce(lm, op=dist.ReduceOp.MAX)
            global_max_len = int(lm.item())
        # Right-pad multi-vector (B, L, H) outputs along L so batches/ranks concatenate.
        embeds = [
            F.pad(e, (0, 0, 0, global_max_len - e.shape[1]), "constant", 0) if e.dim() == 3 else e
            for e in embeds
        ]
    embeds_tensor = torch.cat(embeds, dim=0).contiguous()

    n_items = len(full_dataset)
    if dist.is_initialized() and dist.get_world_size() > 1:
        world_size = dist.get_world_size()
        print_master(f"Gathering {encode_side} embeddings across ranks...")
        embeds_tensor = embeds_tensor.to(device)
        # Each rank holds an equally sized contiguous shard of the *padded* dataset,
        # so the gather buffer is world_size x the local size. Sizing it by
        # len(full_dataset) fails whenever padding was added.
        gathered = torch.empty(
            (embeds_tensor.shape[0] * world_size, *embeds_tensor.shape[1:]),
            dtype=embeds_tensor.dtype, device=device,
        )
        dist.all_gather_into_tensor(gathered, embeds_tensor)
        final_embeddings = gathered.cpu().float().numpy()
        all_infos = _gather_object_lists(infos, world_size)
        all_stats = _gather_object_lists(batch_stats_list, world_size)
        all_masks = _gather_object_lists(img_masks_all, world_size)
    else:
        final_embeddings = embeds_tensor.cpu().float().numpy()
        all_infos, all_stats, all_masks = infos, batch_stats_list, img_masks_all
    # Drop the duplicated rows appended by pad_dataset_to_divisible (no-op otherwise).
    return final_embeddings[:n_items], all_infos[:n_items], all_stats, all_masks[:n_items]
# ----------------- Entry point -----------------
def _layer_tag(keep_layers):
    """File-name tag for a cut layer: ``layer<N>``, or ``layerlast`` when no cut."""
    return f"layer{keep_layers}" if keep_layers else "layerlast"


def _rank0_decision(flag):
    """Return rank 0's value of ``flag`` on every rank (identity when not distributed).

    Keeps all ranks on the same branch around collective calls even if their
    views of the output directory differ (e.g. non-shared filesystems).
    """
    if not dist.is_initialized():
        return flag
    obj = [flag]
    dist.broadcast_object_list(obj, src=0)
    return obj[0]


def _init_stats_total():
    """Return an empty accumulator for the per-batch stats from encode_embeddings."""
    return {
        "total_inference_time_seconds": 0.0,
        "module_inference_times": {},
        "token_counts": {"visual_tokens": 0, "language_input_tokens_raw": 0, "llm_total_input_tokens": 0, "language_output_tokens": 0},
        "data_point_count": 0,
    }


def _accumulate_stats(total, bs):
    """Fold one per-batch stats dict into ``total``.

    Each batch's token counts are multiplied by its batch size before summing.
    """
    bsz = bs["batch_size"]
    total["total_inference_time_seconds"] += bs["total_inference_time_seconds"]
    for name, ms in bs["module_inference_times"].items():
        slot = total["module_inference_times"].setdefault(name, {"total": 0.0, "count": 0})
        slot["total"] += ms.get("total", 0.0)
        slot["count"] += ms.get("count", 0)
    for k in total["token_counts"]:
        total["token_counts"][k] += bs["token_counts"][k] * bsz
    total["data_point_count"] += bsz


def _write_stats(total, out_path, task_name, side_name, is_main):
    """Write the aggregated stats in ``total`` to ``out_path`` as JSON (main process only)."""
    if not is_main:
        return
    n = max(1, total["data_point_count"])
    tc = total["token_counts"]
    final = {
        "task_name": task_name,
        "encode_side": side_name,
        "data_point_count": total["data_point_count"],
        "inference_times": {
            "total_inference_time_seconds": total["total_inference_time_seconds"],
            "avg_inference_time_per_item_seconds": total["total_inference_time_seconds"] / n,
            "module_average_times_per_call": {},
            "module_total_times_seconds": {},
            "module_calls_count": {},
        },
        "token_counts": {
            "total_visual_tokens": tc["visual_tokens"],
            "avg_visual_tokens_per_item": tc["visual_tokens"] / n,
            "total_language_input_tokens_raw": tc["language_input_tokens_raw"],
            "avg_language_input_tokens_raw_per_item": tc["language_input_tokens_raw"] / n,
            "total_llm_total_input_tokens": tc["llm_total_input_tokens"],
            "avg_llm_total_input_tokens_per_item": tc["llm_total_input_tokens"] / n,
            "total_language_output_tokens": tc["language_output_tokens"],
            "avg_language_output_tokens_per_item": tc["language_output_tokens"] / n,
        },
    }
    times = final["inference_times"]
    for m, ms in total["module_inference_times"].items():
        times["module_total_times_seconds"][m] = ms["total"]
        times["module_calls_count"][m] = ms["count"]
        times["module_average_times_per_call"][m] = (ms["total"] / ms["count"]) if ms["count"] > 0 else 0.0
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(final, f, ensure_ascii=False, indent=4)
    print_master(f"[{task_name}] {side_name} stats saved: {out_path}")


def main():
    """Run the cut-layer / VisionZip / AOP retrieval evaluation.

    Configuration comes from the HF CLI dataclasses plus the ``LM_LAYERS``,
    ``ZIP_*`` and ``AOP_*`` environment variables. For every dataset in
    ``data_args.dataset_config`` and every requested layer, queries and
    candidates are encoded (skipping outputs already on disk). Embeddings,
    image-token masks and inference stats are written to
    ``data_args.encode_output_path``. Rank 0 then scores every layer.
    """
    from datetime import timedelta  # torch.distributed does not publicly export timedelta

    # DDP init (compatible with torchrun).
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        timeout = int(os.environ.get("DDP_TIMEOUT_MIN", "60"))
        dist.init_process_group(backend="nccl", timeout=timedelta(minutes=timeout))

    # Rewrite `--local-rank=N` as `--local_rank N` for HfArgumentParser. Build a new
    # list rather than removing items from sys.argv while iterating over it.
    kept, moved = [], []
    for arg in sys.argv:
        if arg.startswith("--local-rank="):
            moved += ['--local_rank', arg.split("=")[1]]
        else:
            kept.append(arg)
    sys.argv[:] = kept + moved

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    out_dir = data_args.encode_output_path
    os.makedirs(out_dir, exist_ok=True)

    layers_to_eval = get_env_eval_layers()
    zip_cfg = get_env_zip_config()
    aop_cfg = get_env_aop_config()
    print_master(f"Eval layers: {layers_to_eval}")
    print_master(f"[ZIP] enabled={zip_cfg.get('enabled',False)}, apply_to={zip_cfg.get('apply_to')}, method={zip_cfg.get('method')}")
    print_master(f"[AOP] enabled={aop_cfg.get('enabled',False)}, apply_to={aop_cfg.get('apply_to')}, layer={aop_cfg.get('layer_idx')}")

    # Load the model.
    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, 'model_backbone', model_backbone)
        setattr(training_args, 'model_backbone', model_backbone)
    print_master(f"Backbone: {model_args.model_backbone}")

    # Rank 0 loads first (populating the cache); the other ranks wait, then load.
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    is_main = local_rank == 0
    if is_main:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
        print_master(f"[rank=0] loaded {model_args.model_name}")
    if dist.is_initialized():
        dist.barrier()
    if not is_main:
        processor = load_processor(model_args, data_args)
        time.sleep(random.randint(2 * local_rank, 3 * local_rank))  # stagger ranks
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)
    # Default: no layer cut; each layer iteration below overrides this.
    model.set_inference_layers(qry_layers=None, tgt_layers=None)

    # Attach the AOP config to the encoder instance.
    if aop_cfg.get("enabled", False):
        setattr(model.encoder, "aop_prune_config", aop_cfg)
        attn_override = aop_cfg.get("attn_impl_override", "")
        if attn_override and hasattr(model.encoder, "model") and hasattr(model.encoder.model, "config"):
            prev = model.encoder.model.config._attn_implementation
            model.encoder.model.config._attn_implementation = attn_override
            print_master(f"[AOP] attn impl override: {prev} -> {attn_override}")
    else:
        print_master("[AOP] disabled")

    with open(data_args.dataset_config, 'r') as f:
        dataset_configs = yaml.safe_load(f)

    # ----------------------- Main evaluation loop -----------------------
    for dataset_name, task_cfg in dataset_configs.items():
        if dist.is_initialized():
            dist.barrier()
        print_master(f"\n--- Evaluating {dataset_name} ---")

        # Resolve dataset-relative paths against data_basedir.
        if data_args.data_basedir:
            for k in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                if task_cfg.get(k):
                    task_cfg[k] = os.path.join(data_args.data_basedir, task_cfg[k])

        full_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_cfg)
        full_cand_dataset = generate_cand_dataset(full_qry_dataset, corpus)
        eval_qry_dataset, eval_cand_dataset = full_qry_dataset, full_cand_dataset
        if dist.is_initialized():
            # Pad so every rank gets an equal shard; padding rows are dropped after gathering.
            ws = dist.get_world_size()
            padded_qry, _ = pad_dataset_to_divisible(full_qry_dataset, ws)
            padded_cand, _ = pad_dataset_to_divisible(full_cand_dataset, ws)
            eval_qry_dataset = split_dataset_by_node(padded_qry, rank=local_rank, world_size=ws)
            eval_cand_dataset = split_dataset_by_node(padded_cand, rank=local_rank, world_size=ws)

        info_path = os.path.join(out_dir, f"{dataset_name}_info.jsonl")
        for keep_layers in layers_to_eval:
            tag = _layer_tag(keep_layers)
            print_master(f"[{dataset_name}] tag={tag}, keep_layers={keep_layers}")
            model.set_inference_layers(qry_layers=keep_layers, tgt_layers=keep_layers)

            q_path = os.path.join(out_dir, f"{dataset_name}_qry_{tag}")
            c_path = os.path.join(out_dir, f"{dataset_name}_tgt_{tag}")
            q_stats_path = os.path.join(out_dir, f"{dataset_name}_qry_inference_stats_{tag}.json")
            c_stats_path = os.path.join(out_dir, f"{dataset_name}_cand_inference_stats_{tag}.json")
            q_masks_path = os.path.join(out_dir, f"{dataset_name}_qry_img_token_masks_{tag}.jsonl")
            c_masks_path = os.path.join(out_dir, f"{dataset_name}_cand_img_token_masks_{tag}.jsonl")

            # Encode queries.
            if _rank0_decision((not os.path.exists(q_path)) or (not os.path.exists(info_path))):
                print_master(f"[{tag}] Encode queries...")
                coll_q = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
                loader_q = DataLoader(eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size,
                                      collate_fn=coll_q, num_workers=training_args.dataloader_num_workers)
                q_embeds, q_infos, q_stats, q_masks = encode_embeddings(
                    model, loader_q, training_args, model_args, full_qry_dataset,
                    encode_side="qry", zip_cfg=zip_cfg, aop_cfg=aop_cfg, description=f"Queries[{tag}] {dataset_name}"
                )
                q_embeds = q_embeds[:len(full_qry_dataset)]
                q_infos = q_infos[:len(full_qry_dataset)]
                q_masks = q_masks[:len(full_qry_dataset)]
                q_total = _init_stats_total()
                for bs in q_stats:
                    _accumulate_stats(q_total, bs)
                if is_main:
                    with open(q_path, 'wb') as f:
                        pickle.dump(q_embeds, f)
                    if not os.path.exists(info_path):
                        with open(info_path, 'w') as f:
                            for info in q_infos:
                                f.write(json.dumps(info) + '\n')
                    with open(q_masks_path, 'w', encoding='utf-8') as f:
                        for i, m in enumerate(q_masks):
                            f.write(json.dumps({"index": i, "mask": m}, ensure_ascii=False) + "\n")
                _write_stats(q_total, q_stats_path, dataset_name, f"query[{tag}]", is_main)
            if dist.is_initialized():
                dist.barrier()

            # Encode candidates.
            if _rank0_decision(not os.path.exists(c_path)):
                print_master(f"[{tag}] Encode candidates...")
                coll_c = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
                loader_c = DataLoader(eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size,
                                      collate_fn=coll_c, num_workers=training_args.dataloader_num_workers)
                c_embeds, c_ids, c_stats, c_masks = encode_embeddings(
                    model, loader_c, training_args, model_args, full_cand_dataset,
                    encode_side="cand", zip_cfg=zip_cfg, aop_cfg=aop_cfg, description=f"Cands[{tag}] {dataset_name}"
                )
                c_embeds = c_embeds[:len(full_cand_dataset)]
                c_ids = c_ids[:len(full_cand_dataset)]
                c_masks = c_masks[:len(full_cand_dataset)]
                c_total = _init_stats_total()
                for bs in c_stats:
                    _accumulate_stats(c_total, bs)
                if is_main:
                    with open(c_path, 'wb') as f:
                        pickle.dump(dict(zip(c_ids, c_embeds)), f)
                    with open(c_masks_path, 'w', encoding='utf-8') as f:
                        for cid, m in zip(c_ids, c_masks):
                            f.write(json.dumps({"cand_id": str(cid), "mask": m}, ensure_ascii=False) + "\n")
                _write_stats(c_total, c_stats_path, dataset_name, f"cand[{tag}]", is_main)
            if dist.is_initialized():
                dist.barrier()

        # Scoring (rank 0 only), once every layer has been encoded.
        if is_main:
            with open(info_path) as f:
                gt_infos = [json.loads(line) for line in f]
            rank_global = (task_cfg.get("eval_type", "global") == "global")
            metrics_to_report = task_cfg.get("metrics", ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"])
            for keep_layers in layers_to_eval:
                tag = _layer_tag(keep_layers)
                q_path = os.path.join(out_dir, f"{dataset_name}_qry_{tag}")
                c_path = os.path.join(out_dir, f"{dataset_name}_tgt_{tag}")
                # These pickles were written by this script above; do not point
                # encode_output_path at untrusted files.
                with open(q_path, 'rb') as f:
                    qry_embeds = pickle.load(f)
                with open(c_path, 'rb') as f:
                    cand_embed_dict = pickle.load(f)

                # NOTE(review): for late-interaction (3-D) embeddings the plain dot
                # product below is a placeholder, not a MaxSim score; plug in the
                # processor's scoring before evaluating COLPALI.
                pred_dicts = []
                if rank_global:
                    ck = list(cand_embed_dict.keys())
                    ce = np.stack([cand_embed_dict[k] for k in ck])
                    sim = qry_embeds @ ce.T
                    ranked = np.argsort(-sim, axis=1)
                    for qid, gi in tqdm(enumerate(gt_infos), total=len(gt_infos), desc=f"[{tag}] scoring(all) {dataset_name}"):
                        rid = ranked[qid]
                        label = gi["label_name"] if isinstance(gi["label_name"], list) else [gi["label_name"]]
                        pred_dicts.append({"prediction": [ck[i] for i in rid], "label": label, "rel_scores": gi.get("rel_scores")})
                else:
                    for qe, gi in tqdm(zip(qry_embeds, gt_infos), total=len(gt_infos), desc=f"[{tag}] scoring(local) {dataset_name}"):
                        cand_ids = gi["cand_names"]
                        ce = np.stack([cand_embed_dict[k] for k in cand_ids])
                        rid = np.argsort(-(qe @ ce.T))
                        label = gi["label_name"] if isinstance(gi["label_name"], list) else [gi["label_name"]]
                        pred_dicts.append({"prediction": [cand_ids[i] for i in rid], "label": label, "rel_scores": gi.get("rel_scores")})

                metrics = RankingMetrics(metrics_to_report)
                score = metrics.evaluate(pred_dicts)
                score["num_pred"] = len(pred_dicts)
                score["num_data"] = len(gt_infos)
                layer_score_path = os.path.join(out_dir, f"{dataset_name}_score_{tag}.json")
                layer_pred_path = os.path.join(out_dir, f"{dataset_name}_pred_{tag}.jsonl")
                with open(layer_score_path, "w") as f:
                    json.dump(score, f, indent=4)
                with open(layer_pred_path, "w") as f:
                    for p in pred_dicts:
                        f.write(json.dumps(p) + '\n')
                print_master(f"[{dataset_name}] {tag} score: " + json.dumps({k: (f'{v:.4f}' if isinstance(v,(int,float)) else v) for k,v in score.items()}))
        if dist.is_initialized():
            dist.barrier()
if __name__ == '__main__':
    # Script entry point; importing this module runs nothing.
    main()