| |
| import logging |
| import json |
| import random |
| import time |
| import numpy as np |
| import os |
| import pickle |
| import sys |
| import torch |
| import torch.distributed as dist |
| import torch.nn.functional as F |
| import yaml |
|
|
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from transformers import HfArgumentParser, AutoConfig |
| from datasets import concatenate_datasets |
| from datasets.distributed import split_dataset_by_node |
|
|
| from src.arguments import ModelArguments, DataArguments, TrainingArguments |
| from src.data.collator.eval_collator import MultimodalEvalDataCollator |
| from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset |
| from src.eval_utils.metrics import RankingMetrics |
| from src.model.model_cut_layer_AOP import MMEBModel |
| from src.model.processor import get_backbone_name, load_processor, COLPALI |
| from src.utils import batch_to_device, print_rank, print_master |
|
|
# Configure logging once at import time: INFO level, with timestamp, level,
# logger name and line number on every record.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] %(levelname)s [%(name)s:%(lineno)s] %(message)s',
)
# Module-level logger used by the helpers in this file.
logger = logging.getLogger(__name__)
|
|
| |
def _parse_bool(v: str, default=False):
    """Interpret an environment-style string as a boolean.

    Returns ``default`` when ``v`` is None. Otherwise the text is stripped and
    lower-cased, and the result is True only for one of
    ``1 / true / yes / y / t / on``.
    """
    if v is None:
        return default
    truthy = ("1", "true", "yes", "y", "t", "on")
    normalized = v.strip().lower()
    return normalized in truthy
|
|
def _parse_float(v: str, default=None):
    """Convert ``v`` to ``float``, falling back to ``default``.

    Args:
        v: text to convert (typically an environment variable value) or None.
        default: value returned when ``v`` is None or cannot be converted.

    Returns:
        The parsed float, or ``default``.
    """
    if v is None:
        return default
    try:
        return float(v)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are treated as "use the default"; anything
        # else (e.g. KeyboardInterrupt) is allowed to propagate.
        return default
|
|
def _parse_int(v: str, default=None):
    """Convert ``v`` to ``int``, falling back to ``default``.

    Args:
        v: text to convert (typically an environment variable value) or None.
        default: value returned when ``v`` is None or cannot be converted.

    Returns:
        The parsed integer, or ``default``.
    """
    if v is None:
        return default
    try:
        return int(v)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are treated as "use the default"; anything
        # else (e.g. KeyboardInterrupt) is allowed to propagate.
        return default
|
|
def get_env_eval_layers():
    """Read the list of LM layers to evaluate from the ``LM_LAYERS`` env var.

    The variable holds a comma-separated list such as ``"4,8,12,last"``.
    The tokens ``last`` / ``none`` / ``null`` / ``-1`` (case-insensitive) map
    to ``None``, which stands for the last layer. Positive integers are kept
    as they are; any other token is ignored with a warning. Duplicates are
    dropped, keeping the first occurrence.

    Returns:
        list: layer selectors (``int`` or ``None``). ``[None]`` when the
        variable is unset, empty, or contains no usable token.
    """
    raw = os.environ.get("LM_LAYERS", None)
    if not raw:
        return [None]

    last_aliases = {"last", "none", "null", "-1"}
    layers = []
    for tok in (piece.strip() for piece in raw.split(',')):
        if not tok:
            continue
        if tok.lower() in last_aliases:
            layers.append(None)
            continue
        try:
            val = int(tok)
        except ValueError:
            logger.warning(f"Invalid token '{tok}' in LM_LAYERS; ignored.")
            continue
        if val > 0:
            layers.append(val)
        else:
            # Previously dropped without any message; make the skip visible.
            logger.warning(f"Non-positive layer '{tok}' in LM_LAYERS; ignored.")

    # dict.fromkeys removes duplicates while preserving first-seen order.
    uniq = list(dict.fromkeys(layers))
    return uniq or [None]
|
|
def get_env_zip_config():
    """Build the VisionZip (input-side token compression) settings from env vars.

    Environment variables:
      - ZIP_ENABLED=1 switches the feature on
      - ZIP_APPLY=qry|cand|both (default "both")
      - ZIP_METHOD=visionzip|none (default "visionzip"; "none" forces it off)
      - ZIP_KEEP_DOM / ZIP_KEEP_CTX: share of dominant / context tokens kept
        (defaults 0.45 and 0.10)

    Returns:
        dict with keys ``enabled``, ``apply_to``, ``method``, ``keep_dom``
        and ``keep_ctx``.
    """
    env = os.environ
    method = env.get("ZIP_METHOD", "visionzip").strip().lower()
    apply_to = env.get("ZIP_APPLY", "both").strip().lower()

    enabled = _parse_bool(env.get("ZIP_ENABLED"), False)
    if method == "none":
        # An explicit "none" method overrides ZIP_ENABLED.
        enabled = False

    return {
        "enabled": enabled,
        "apply_to": apply_to,
        "method": method,
        "keep_dom": _parse_float(env.get("ZIP_KEEP_DOM"), 0.45),
        "keep_ctx": _parse_float(env.get("ZIP_KEEP_CTX"), 0.10),
    }
|
|
def get_env_aop_config():
    """Build the AOP (in-layer token pruning) settings from env vars.

    Environment variables:
      - AOP_ENABLED=1 switches the feature on
      - AOP_APPLY=qry|cand|both (default "qry")
      - AOP_LAYER=N (1-based; prune once before entering that layer)
      - AOP_MODE=delta|ratio (default "delta")
      - AOP_KEEP_RATIO / AOP_DELTA / AOP_KHAT / AOP_MIN_KEEP / AOP_USE_BIAS
      - AOP_ATTN_IMPL=sdpa optionally overrides the attention implementation;
        any other value is discarded

    AOP is switched off (with a warning) when it is enabled but AOP_LAYER is
    missing or not an integer.

    Returns:
        dict holding the pruning configuration.
    """
    read = os.environ.get

    def _lowered(name, fallback):
        return read(name, fallback).strip().lower()

    layer_idx = _parse_int(read("AOP_LAYER"), None)
    enabled = _parse_bool(read("AOP_ENABLED"), False)
    if enabled and layer_idx is None:
        logger.warning("AOP_ENABLED=1 但未设置 AOP_LAYER,关闭 AOP。")
        enabled = False

    attn_impl = _lowered("AOP_ATTN_IMPL", "")
    config = {
        "enabled": enabled,
        "apply_to": _lowered("AOP_APPLY", "qry"),
        "layer_idx": layer_idx,
        "mode": _lowered("AOP_MODE", "delta"),
        "delta": _parse_float(read("AOP_DELTA"), 0.10),
        "K_hat": _parse_float(read("AOP_KHAT"), 1.0),
        "keep_ratio": _parse_float(read("AOP_KEEP_RATIO"), 1.0),
        "min_keep": _parse_int(read("AOP_MIN_KEEP"), 64),
        "use_bias": _parse_bool(read("AOP_USE_BIAS"), True),
        "eps": 1e-6,
        # Only "sdpa" is accepted as an override.
        "attn_impl_override": "sdpa" if attn_impl == "sdpa" else "",
    }
    return config
|
|
| |
# Scratch state shared between the forward hooks and the encoding loop.
# timing_info maps id(module) -> list of (timestamp, 'pre' | 'post', class name).
timing_info = dict()
# token_info holds the token counters reported for the current batch.
token_info = {
    "vision_tokens": 0,
    "text_input_tokens": 0,
    "text_output_tokens": 0,
    "total_llm_input_tokens": 0,
}
|
|
def timing_pre_hook(module, input):
    """Forward pre-hook: note the wall-clock time just before ``module`` runs.

    Appends ``(timestamp, 'pre', class name)`` to the ``timing_info`` list
    keyed by ``id(module)``, creating that list on first use.
    """
    events = timing_info.setdefault(id(module), [])
    events.append((time.time(), 'pre', module.__class__.__name__))
|
|
def timing_post_hook(module, input, output):
    """Forward hook: note the wall-clock time right after ``module`` ran.

    Does nothing when no event list exists for this module in ``timing_info``.
    Otherwise appends ``(timestamp, 'post', class name)``. If the class name
    contains both "vision" and "transformer" (case-insensitive), it also
    stores a visual token count in ``token_info["vision_tokens"]``:
    ``output.shape[0]`` for a tensor output, or
    ``output.last_hidden_state.shape[1]`` when that attribute exists.
    """
    events = timing_info.get(id(module))
    if events is None:
        return

    cls_name = module.__class__.__name__
    events.append((time.time(), 'post', cls_name))

    lowered = cls_name.lower()
    if not ("vision" in lowered and "transformer" in lowered):
        return
    if isinstance(output, torch.Tensor):
        token_info["vision_tokens"] = output.shape[0]
    elif hasattr(output, 'last_hidden_state'):
        token_info["vision_tokens"] = output.last_hidden_state.shape[1]
|
|
def register_model_hooks(model):
    """Attach the timing hooks to the main sub-modules of ``model``.

    The hooks are looked for on ``model.encoder`` when that attribute exists
    and is not None, otherwise on ``model`` itself. The sub-modules considered
    are, in order: ``visual``, ``visual.merger``, ``model`` and ``lm_head``;
    missing or None attributes are skipped.

    Registration is idempotent: a module that already carries the hooks (it is
    marked with a private flag attribute) is not hooked again. Without this,
    each call would stack another pair of hooks on the same modules, because
    the handles returned by the ``register_*`` calls are not kept and so the
    hooks are never removed.

    Args:
        model: object exposing the sub-modules described above.

    Returns:
        list: the sub-modules that carry timing hooks, in the order above.
    """
    hooked_flag = "_timing_hooks_registered"

    core = model
    if getattr(model, 'encoder', None) is not None:
        core = model.encoder

    visual = getattr(core, 'visual', None)
    candidates = (
        visual,
        getattr(visual, 'merger', None),  # getattr on None simply yields None
        getattr(core, 'model', None),
        getattr(core, 'lm_head', None),
    )

    regs = []
    for sub in candidates:
        if sub is None:
            continue
        if not getattr(sub, hooked_flag, False):
            sub.register_forward_pre_hook(timing_pre_hook)
            sub.register_forward_hook(timing_post_hook)
            setattr(sub, hooked_flag, True)
        regs.append(sub)
    return regs
|
|
def pad_dataset_to_divisible(dataset, world_size):
    """Extend ``dataset`` so that its length is a multiple of ``world_size``.

    When the length already divides evenly the dataset is returned untouched.
    Otherwise the missing rows are taken from the start of the dataset
    (wrapping around if it is shorter than the shortfall) and appended.

    Returns:
        tuple: ``(dataset, length)`` where ``length`` is the size of the
        returned dataset.
    """
    size = len(dataset)
    remainder = size % world_size
    if remainder == 0:
        return dataset, size

    shortfall = world_size - remainder
    filler = dataset.select([idx % size for idx in range(shortfall)])
    padded = concatenate_datasets([dataset, filler])
    return padded, size + shortfall
|
|
| |
def _side_enabled(apply_to, encode_side):
    """Return True when a feature configured for ``apply_to`` covers ``encode_side``.

    "both" matches every side, and "tgt" is accepted as an alias for "cand".
    """
    return (
        apply_to == "both"
        or apply_to == encode_side
        or (encode_side == "cand" and apply_to == "tgt")
    )


def _normalize_image_masks(img_masks, batch_size):
    """Turn the model's image-token masks into a plain per-sample list.

    A 2-D tensor becomes a nested Python list. A list has its tensor items
    converted and is padded with None / truncated to ``batch_size``. Anything
    else (including None) yields ``[None] * batch_size``.
    """
    if torch.is_tensor(img_masks):
        if img_masks.dim() == 2:
            return img_masks.detach().cpu().tolist()
        return [None] * batch_size
    if isinstance(img_masks, list):
        rows = [m.detach().cpu().tolist() if torch.is_tensor(m) else m for m in img_masks]
        if len(rows) < batch_size:
            rows += [None] * (batch_size - len(rows))
        return rows[:batch_size]
    return [None] * batch_size


def _collect_module_times(regs):
    """Pair the 'pre'/'post' events in ``timing_info`` into per-module timings.

    Returns a dict keyed by module class name with ``total``, ``count`` and
    ``avg`` durations for the current batch.
    """
    stats = {}
    for module in regs:
        durations = []
        started = None
        for ts, kind, _ in timing_info.get(id(module), []):
            if kind == 'pre':
                started = ts
            elif kind == 'post' and started is not None:
                durations.append(ts - started)
                started = None
        total = sum(durations, 0.0)
        count = len(durations)
        stats[module.__class__.__name__] = {
            "total": total,
            "count": count,
            "avg": total / count if count else 0.0,
        }
    return stats


def _all_gather_flat(local_items):
    """Gather a Python list from every rank and concatenate them in rank order."""
    buckets = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(buckets, local_items)
    return [item for bucket in buckets for item in bucket]


def encode_embeddings(
    model: MMEBModel,
    loader: DataLoader,
    training_args: TrainingArguments,
    model_args: ModelArguments,
    full_dataset,
    encode_side: str,
    zip_cfg: dict,
    aop_cfg: dict,
    description: str = "Encoding"
):
    """Encode every batch of ``loader`` and collect embeddings plus statistics.

    For each batch the function optionally attaches a VisionZip runtime config
    to the inputs, temporarily sets the AOP ``enabled`` flag for the requested
    side, runs the model under bfloat16 autocast, and records timing, token
    counts and image-token masks. When torch.distributed is initialised and
    ``len(full_dataset) >= world_size``, the results of all ranks are gathered.

    Args:
        model: model called as ``model(qry=...)`` or ``model(tgt=...)``.
        loader: yields ``(inputs, dataset_info)`` batches.
        training_args: provides ``device``.
        model_args: provides ``model_backbone`` (compared against COLPALI to
            detect late-interaction embeddings).
        full_dataset: only its length is used, to decide whether to gather.
        encode_side: "qry" selects the query path; any other value selects the
            candidate path.
        zip_cfg: VisionZip configuration dict (may be empty/None).
        aop_cfg: AOP configuration dict (may be empty/None). Its "enabled"
            entry is modified during each forward pass and then restored.
        description: progress-bar label.

    Returns:
        tuple: ``(embeddings, infos, batch_stats, image_masks)`` where
        ``embeddings`` is a float numpy array. ``(np.array([]), [], [], [])``
        is returned when the loader produced no batch.
    """
    distributed = dist.is_initialized()
    local_rank = dist.get_rank() if distributed else 0
    is_late_interaction = (model_args.model_backbone == COLPALI)

    embeds, infos, batch_stats_list, img_masks_all = [], [], [], []
    local_max_len = 0

    model.eval()
    regs = register_model_hooks(model)

    with torch.no_grad():
        progress = tqdm(loader, desc=f"{description} (rank {local_rank})", disable=local_rank > 0)
        for inputs, dataset_info in progress:
            # Reset the per-batch scratch state filled in by the hooks.
            timing_info.clear()
            token_info.update({"vision_tokens": 0, "text_input_tokens": 0, "text_output_tokens": 0, "total_llm_input_tokens": 0})

            inputs = batch_to_device(inputs, training_args.device)
            batch_size = inputs.get('input_ids', torch.empty(1, 1)).shape[0]

            # VisionZip: hand the runtime settings to the model through the inputs.
            if zip_cfg and zip_cfg.get("enabled", False):
                if _side_enabled(zip_cfg.get("apply_to", "both"), encode_side):
                    inputs["zip_runtime_cfg"] = {
                        "enable": True,
                        "method": zip_cfg.get("method", "visionzip"),
                        "keep_dominant_ratio": float(zip_cfg.get("keep_dom", 0.45)),
                        "keep_context_ratio": float(zip_cfg.get("keep_ctx", 0.10)),
                    }

            # AOP: enable pruning only for the configured side during this pass.
            orig_enabled = None
            enc = model.encoder
            if isinstance(aop_cfg, dict) and aop_cfg:
                orig_enabled = aop_cfg.get("enabled", False)
                side_on = _side_enabled(aop_cfg.get("apply_to", "qry"), encode_side)
                aop_cfg["enabled"] = bool(side_on and orig_enabled)
                setattr(enc, "aop_prune_config", aop_cfg)

            try:
                with torch.autocast(enabled=True, dtype=torch.bfloat16, device_type="cuda"):
                    t0 = time.time()
                    if encode_side == "qry":
                        out = model(qry=inputs)
                        reps = out["qry_reps"].detach()
                        infos.extend(dataset_info)
                    else:
                        out = model(tgt=inputs)
                        reps = out["tgt_reps"].detach()
                        infos.extend([x["cand_name"] for x in dataset_info])
                    t1 = time.time()
            finally:
                # Restore the caller's flag even when the forward pass raises,
                # so the shared config dict is never left in the per-side state.
                if isinstance(aop_cfg, dict) and orig_enabled is not None:
                    aop_cfg["enabled"] = orig_enabled
                    setattr(enc, "aop_prune_config", aop_cfg)

            # Image-token masks: prefer the model output, fall back to the
            # attribute the model may have stored on itself.
            img_masks = out.get("image_token_bool_masks", None) if isinstance(out, dict) else None
            if img_masks is None and hasattr(model, "_last_image_token_bool_masks"):
                img_masks = getattr(model, "_last_image_token_bool_masks")
            img_masks_all.extend(_normalize_image_masks(img_masks, batch_size))

            # Token accounting for this batch.
            if 'input_ids' in inputs and inputs['input_ids'] is not None:
                token_info["total_llm_input_tokens"] = int(inputs['input_ids'].shape[1])
                token_info["text_input_tokens"] = max(0, token_info["total_llm_input_tokens"] - token_info["vision_tokens"])

            batch_stats_list.append({
                "batch_size": batch_size,
                "total_inference_time_seconds": float(t1 - t0),
                "module_inference_times": _collect_module_times(regs),
                "token_counts": {
                    "visual_tokens": token_info["vision_tokens"],
                    "language_input_tokens_raw": token_info["text_input_tokens"],
                    "llm_total_input_tokens": token_info["total_llm_input_tokens"],
                    "language_output_tokens": token_info["text_output_tokens"],
                },
            })

            print_rank(f"[{encode_side}] time={t1-t0:.4f}s, vis_tokens={token_info['vision_tokens']}")

            if is_late_interaction and reps.dim() == 3:
                local_max_len = max(local_max_len, reps.shape[1])
            embeds.append(reps)

    if not embeds:
        return np.array([]), [], [], []

    if is_late_interaction:
        # Multi-vector embeddings: pad the second dimension to the longest
        # sequence seen on any rank so the batches can be concatenated.
        if distributed:
            lm = torch.tensor(local_max_len, device=training_args.device)
            dist.all_reduce(lm, op=dist.ReduceOp.MAX)
            global_max_len = int(lm.item())
        else:
            global_max_len = local_max_len
        padded = []
        for e in embeds:
            if e.dim() == 3:
                e = F.pad(e, (0, 0, 0, global_max_len - e.shape[1]), "constant", 0)
            padded.append(e)
        embeds_tensor = torch.cat(padded, dim=0).contiguous()
    else:
        embeds_tensor = torch.cat(embeds, dim=0).contiguous()

    # NOTE(review): when distributed is initialised but the dataset is smaller
    # than the world size, no gather happens and each rank returns only its
    # own shard - confirm this is intended.
    if distributed and len(full_dataset) >= dist.get_world_size():
        print_master(f"Gathering {encode_side} embeddings across ranks...")
        embeds_tensor = embeds_tensor.to(training_args.device)
        # all_gather_into_tensor needs an output of exactly
        # world_size * local_rows rows. Sizing it from len(full_dataset) breaks
        # when the per-rank shards come from a padded dataset.
        output_shape = list(embeds_tensor.shape)
        output_shape[0] = embeds_tensor.shape[0] * dist.get_world_size()
        gathered = torch.empty(output_shape, dtype=embeds_tensor.dtype, device=training_args.device)
        dist.all_gather_into_tensor(gathered, embeds_tensor)
        final_embeddings = gathered.cpu().float().numpy()

        all_infos = _all_gather_flat(infos)
        all_stats = _all_gather_flat(batch_stats_list)
        all_masks = _all_gather_flat(img_masks_all)
    else:
        final_embeddings = embeds_tensor.cpu().float().numpy()
        all_infos = infos
        all_stats = batch_stats_list
        all_masks = img_masks_all

    return final_embeddings, all_infos, all_stats, all_masks
|
|
| |
def _layer_tag(keep_layers):
    """Return the file-name tag for a layer selector (None/0 -> "layerlast")."""
    return f"layer{keep_layers}" if keep_layers else "layerlast"


def _new_stats_total():
    """Return a zeroed accumulator for the inference statistics of one side."""
    return {
        "total_inference_time_seconds": 0.0,
        "module_inference_times": {},
        "token_counts": {
            "visual_tokens": 0,
            "language_input_tokens_raw": 0,
            "llm_total_input_tokens": 0,
            "language_output_tokens": 0,
        },
        "data_point_count": 0,
    }


def _accumulate_batch_stats(total, bs):
    """Fold one batch's statistics into ``total``.

    Token counts are multiplied by the batch size before being added.
    """
    bsz = bs["batch_size"]
    total["total_inference_time_seconds"] += bs["total_inference_time_seconds"]
    for name, ms in bs["module_inference_times"].items():
        slot = total["module_inference_times"].setdefault(name, {"total": 0.0, "count": 0})
        slot["total"] += ms.get("total", 0.0)
        slot["count"] += ms.get("count", 0)
    for key in total["token_counts"]:
        total["token_counts"][key] += bs["token_counts"][key] * bsz
    total["data_point_count"] += bsz


def _write_stats_summary(total, out_path, task_name, side_name):
    """Write the aggregated statistics in ``total`` to ``out_path`` as JSON."""
    n = max(1, total["data_point_count"])  # avoid division by zero
    counts = total["token_counts"]
    final = {
        "task_name": task_name,
        "encode_side": side_name,
        "data_point_count": total["data_point_count"],
        "inference_times": {
            "total_inference_time_seconds": total["total_inference_time_seconds"],
            "avg_inference_time_per_item_seconds": total["total_inference_time_seconds"] / n,
            "module_average_times_per_call": {},
            "module_total_times_seconds": {},
            "module_calls_count": {},
        },
        "token_counts": {
            "total_visual_tokens": counts["visual_tokens"],
            "avg_visual_tokens_per_item": counts["visual_tokens"] / n,
            "total_language_input_tokens_raw": counts["language_input_tokens_raw"],
            "avg_language_input_tokens_raw_per_item": counts["language_input_tokens_raw"] / n,
            "total_llm_total_input_tokens": counts["llm_total_input_tokens"],
            "avg_llm_total_input_tokens_per_item": counts["llm_total_input_tokens"] / n,
            "total_language_output_tokens": counts["language_output_tokens"],
            "avg_language_output_tokens_per_item": counts["language_output_tokens"] / n,
        },
    }
    times = final["inference_times"]
    for name, ms in total["module_inference_times"].items():
        times["module_total_times_seconds"][name] = ms["total"]
        times["module_calls_count"][name] = ms["count"]
        times["module_average_times_per_call"][name] = (ms["total"] / ms["count"]) if ms["count"] > 0 else 0.0
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(final, f, ensure_ascii=False, indent=4)
    print_master(f"[{task_name}] {side_name} stats saved: {out_path}")


def _build_predictions(qry_embeds, cand_embed_dict, gt_infos, rank_global, tag, dataset_name):
    """Rank candidates for every query and return the prediction records.

    With ``rank_global`` every query is scored against all candidates;
    otherwise each query is scored only against its own ``cand_names``.
    """
    def _labels(gi):
        name = gi["label_name"]
        return name if isinstance(name, list) else [name]

    pred_dicts = []
    if rank_global:
        ck = list(cand_embed_dict.keys())
        ce = np.stack([cand_embed_dict[k] for k in ck])
        # NOTE(review): 3-D (late-interaction) query embeddings go through the
        # same matrix product as 2-D ones; there is no dedicated scoring path.
        sim = qry_embeds @ ce.T
        ranked = np.argsort(-sim, axis=1)
        for qid, gi in tqdm(enumerate(gt_infos), total=len(gt_infos), desc=f"[{tag}] scoring(all) {dataset_name}"):
            pred_dicts.append({
                "prediction": [ck[i] for i in ranked[qid]],
                "label": _labels(gi),
                "rel_scores": gi.get("rel_scores"),
            })
    else:
        for qe, gi in tqdm(zip(qry_embeds, gt_infos), total=len(gt_infos), desc=f"[{tag}] scoring(local) {dataset_name}"):
            cand_ids = gi["cand_names"]
            ce = np.stack([cand_embed_dict[k] for k in cand_ids])
            order = np.argsort(-(qe @ ce.T))
            pred_dicts.append({
                "prediction": [cand_ids[i] for i in order],
                "label": _labels(gi),
                "rel_scores": gi.get("rel_scores"),
            })
    return pred_dicts


def main():
    """Run the evaluation: encode queries/candidates per dataset and layer, then score.

    Steps:
      1. Initialise torch.distributed (NCCL) when ``RANK`` is set; the timeout
         in minutes comes from ``DDP_TIMEOUT_MIN`` (default 60).
      2. Parse model/data/training arguments and read the layer, ZIP and AOP
         settings from the environment.
      3. Load the processor and the model, with rank 0 loading first.
      4. For every dataset in the YAML config and every requested layer,
         encode queries and candidates unless their output files already
         exist, and save embeddings, masks and inference statistics.
      5. On rank 0, rank candidates, compute the metrics and save scores and
         predictions for every layer.
    """
    # Local import: use the stdlib type directly instead of relying on
    # torch.distributed re-exporting ``timedelta``.
    from datetime import timedelta

    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        timeout = int(os.environ.get("DDP_TIMEOUT_MIN", "60"))
        dist.init_process_group(backend="nccl", timeout=timedelta(minutes=timeout))

    # Rewrite "--local-rank=N" as "--local_rank N" at the end of argv. A new
    # list is built because removing items from sys.argv while iterating over
    # it can skip elements.
    kept_args, local_rank_args = [], []
    for arg in sys.argv:
        if arg.startswith("--local-rank="):
            local_rank_args += ['--local_rank', arg.split("=", 1)[1]]
        else:
            kept_args.append(arg)
    sys.argv[:] = kept_args + local_rank_args

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    out_dir = data_args.encode_output_path
    os.makedirs(out_dir, exist_ok=True)

    layers_to_eval = get_env_eval_layers()
    zip_cfg = get_env_zip_config()
    aop_cfg = get_env_aop_config()
    print_master(f"Eval layers: {layers_to_eval}")
    print_master(f"[ZIP] enabled={zip_cfg.get('enabled',False)}, apply_to={zip_cfg.get('apply_to')}, method={zip_cfg.get('method')}")
    print_master(f"[AOP] enabled={aop_cfg.get('enabled',False)}, apply_to={aop_cfg.get('apply_to')}, layer={aop_cfg.get('layer_idx')}")

    # Resolve the backbone name when it was not given explicitly.
    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, 'model_backbone', model_backbone)
        setattr(training_args, 'model_backbone', model_backbone)
    print_master(f"Backbone: {model_args.model_backbone}")

    # Despite its name this is the rank returned by dist.get_rank().
    # Rank 0 loads first; the other ranks load after the barrier with a
    # rank-dependent random delay.
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    if local_rank == 0:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
        print_master(f"[rank=0] loaded {model_args.model_name}")
    if dist.is_initialized():
        dist.barrier()
    if local_rank != 0:
        processor = load_processor(model_args, data_args)
        time.sleep(random.randint(2 * local_rank, 3 * local_rank))
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)

    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)
    model.set_inference_layers(qry_layers=None, tgt_layers=None)

    if aop_cfg.get("enabled", False):
        setattr(model.encoder, "aop_prune_config", aop_cfg)
        attn_override = aop_cfg.get("attn_impl_override", "")
        if attn_override and hasattr(model.encoder, "model") and hasattr(model.encoder.model, "config"):
            prev = model.encoder.model.config._attn_implementation
            model.encoder.model.config._attn_implementation = attn_override
            print_master(f"[AOP] attn impl override: {prev} -> {attn_override}")
    else:
        print_master("[AOP] disabled")

    with open(data_args.dataset_config, 'r') as f:
        dataset_configs = yaml.safe_load(f)

    for dataset_name, task_cfg in dataset_configs.items():
        if dist.is_initialized():
            dist.barrier()
        print_master(f"\n--- Evaluating {dataset_name} ---")

        # Make the configured data locations relative to the base directory.
        if data_args.data_basedir:
            for k in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                if task_cfg.get(k):
                    task_cfg[k] = os.path.join(data_args.data_basedir, task_cfg[k])

        full_qry_dataset, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_cfg)
        full_cand_dataset = generate_cand_dataset(full_qry_dataset, corpus)
        eval_qry_dataset, eval_cand_dataset = full_qry_dataset, full_cand_dataset

        if dist.is_initialized():
            # Pad so that every rank receives the same number of rows, then shard.
            ws = dist.get_world_size()
            padded_qry, _ = pad_dataset_to_divisible(full_qry_dataset, ws)
            padded_cand, _ = pad_dataset_to_divisible(full_cand_dataset, ws)
            eval_qry_dataset = split_dataset_by_node(padded_qry, rank=local_rank, world_size=ws)
            eval_cand_dataset = split_dataset_by_node(padded_cand, rank=local_rank, world_size=ws)

        info_path = os.path.join(out_dir, f"{dataset_name}_info.jsonl")

        for keep_layers in layers_to_eval:
            tag = _layer_tag(keep_layers)
            print_master(f"[{dataset_name}] tag={tag}, keep_layers={keep_layers}")
            model.set_inference_layers(qry_layers=keep_layers, tgt_layers=keep_layers)

            q_path = os.path.join(out_dir, f"{dataset_name}_qry_{tag}")
            c_path = os.path.join(out_dir, f"{dataset_name}_tgt_{tag}")
            q_stats_path = os.path.join(out_dir, f"{dataset_name}_qry_inference_stats_{tag}.json")
            c_stats_path = os.path.join(out_dir, f"{dataset_name}_cand_inference_stats_{tag}.json")
            q_masks_path = os.path.join(out_dir, f"{dataset_name}_qry_img_token_masks_{tag}.jsonl")
            c_masks_path = os.path.join(out_dir, f"{dataset_name}_cand_img_token_masks_{tag}.jsonl")

            # ---- queries (skipped when the outputs are already on disk) ----
            if (not os.path.exists(q_path)) or (not os.path.exists(info_path)):
                print_master(f"[{tag}] Encode queries...")
                coll_q = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
                loader_q = DataLoader(eval_qry_dataset, batch_size=training_args.per_device_eval_batch_size,
                                      collate_fn=coll_q, num_workers=training_args.dataloader_num_workers)
                q_embeds, q_infos, q_stats, q_masks = encode_embeddings(
                    model, loader_q, training_args, model_args, full_qry_dataset,
                    encode_side="qry", zip_cfg=zip_cfg, aop_cfg=aop_cfg, description=f"Queries[{tag}] {dataset_name}"
                )
                # Drop the rows that only exist because of padding.
                n_qry = len(full_qry_dataset)
                q_embeds, q_infos, q_masks = q_embeds[:n_qry], q_infos[:n_qry], q_masks[:n_qry]

                q_total = _new_stats_total()
                for bs in q_stats:
                    _accumulate_batch_stats(q_total, bs)

                if local_rank == 0:
                    with open(q_path, 'wb') as f:
                        pickle.dump(q_embeds, f)
                    if not os.path.exists(info_path):
                        with open(info_path, 'w') as f:
                            for info in q_infos:
                                f.write(json.dumps(info) + '\n')
                    with open(q_masks_path, 'w', encoding='utf-8') as f:
                        for i, m in enumerate(q_masks):
                            f.write(json.dumps({"index": i, "mask": m}, ensure_ascii=False) + "\n")
                    _write_stats_summary(q_total, q_stats_path, dataset_name, f"query[{tag}]")

                if dist.is_initialized():
                    dist.barrier()

            # ---- candidates (skipped when the output is already on disk) ----
            if not os.path.exists(c_path):
                print_master(f"[{tag}] Encode candidates...")
                coll_c = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
                loader_c = DataLoader(eval_cand_dataset, batch_size=training_args.per_device_eval_batch_size,
                                      collate_fn=coll_c, num_workers=training_args.dataloader_num_workers)
                c_embeds, c_ids, c_stats, c_masks = encode_embeddings(
                    model, loader_c, training_args, model_args, full_cand_dataset,
                    encode_side="cand", zip_cfg=zip_cfg, aop_cfg=aop_cfg, description=f"Cands[{tag}] {dataset_name}"
                )
                n_cand = len(full_cand_dataset)
                c_embeds, c_ids, c_masks = c_embeds[:n_cand], c_ids[:n_cand], c_masks[:n_cand]

                c_total = _new_stats_total()
                for bs in c_stats:
                    _accumulate_batch_stats(c_total, bs)

                if local_rank == 0:
                    cdict = {cid: emb for cid, emb in zip(c_ids, c_embeds)}
                    with open(c_path, 'wb') as f:
                        pickle.dump(cdict, f)
                    with open(c_masks_path, 'w', encoding='utf-8') as f:
                        for cid, m in zip(c_ids, c_masks):
                            f.write(json.dumps({"cand_id": str(cid), "mask": m}, ensure_ascii=False) + "\n")
                    _write_stats_summary(c_total, c_stats_path, dataset_name, f"cand[{tag}]")

                if dist.is_initialized():
                    dist.barrier()

        # ---- scoring: rank 0 only ----
        if local_rank == 0:
            with open(info_path) as f:
                gt_infos = [json.loads(line) for line in f]
            rank_global = (dataset_configs[dataset_name].get("eval_type", "global") == "global")
            metrics_to_report = dataset_configs[dataset_name].get("metrics", ["hit", "ndcg", "precision", "recall", "f1", "map", "mrr"])

            for keep_layers in layers_to_eval:
                tag = _layer_tag(keep_layers)
                q_path = os.path.join(out_dir, f"{dataset_name}_qry_{tag}")
                c_path = os.path.join(out_dir, f"{dataset_name}_tgt_{tag}")
                # These pickles are produced by this script; do not point the
                # output directory at files from an untrusted source.
                with open(q_path, 'rb') as f:
                    qry_embeds = pickle.load(f)
                with open(c_path, 'rb') as f:
                    cand_embed_dict = pickle.load(f)

                pred_dicts = _build_predictions(qry_embeds, cand_embed_dict, gt_infos, rank_global, tag, dataset_name)

                metrics = RankingMetrics(metrics_to_report)
                score = metrics.evaluate(pred_dicts)
                score["num_pred"] = len(pred_dicts)
                score["num_data"] = len(gt_infos)

                layer_score_path = os.path.join(out_dir, f"{dataset_name}_score_{tag}.json")
                layer_pred_path = os.path.join(out_dir, f"{dataset_name}_pred_{tag}.jsonl")
                with open(layer_score_path, "w") as f:
                    json.dump(score, f, indent=4)
                with open(layer_pred_path, "w") as f:
                    for p in pred_dicts:
                        f.write(json.dumps(p) + '\n')
                print_master(f"[{dataset_name}] {tag} score: " + json.dumps({k: (f'{v:.4f}' if isinstance(v,(int,float)) else v) for k,v in score.items()}))

        # Every rank waits here so nobody starts the next dataset early.
        if dist.is_initialized():
            dist.barrier()
|
|
# Run the evaluation only when the file is executed as a script.
if __name__ == "__main__":
    main()