| import os |
| import json |
| import math |
| import time |
| import random |
| import datetime |
| import numpy as np |
| import torch |
| import torch.distributed as dist |
| from tqdm import tqdm |
| from torch.utils.data import DataLoader |
| from transformers import HfArgumentParser, AutoConfig |
| from sklearn.model_selection import train_test_split |
| import yaml |
| from datasets import concatenate_datasets |
|
|
| from src.arguments import ModelArguments, DataArguments, TrainingArguments |
| from src.data.collator.eval_collator import MultimodalEvalDataCollator |
| from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset |
| from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel |
| from src.model.processor import get_backbone_name, load_processor |
| from src.utils import batch_to_device, print_master |
|
|
|
|
| |
def _parse_bool(v: "str | None", default=False):
    """Interpret an environment-style string as a boolean.

    Args:
        v: Raw value, typically the result of ``os.environ.get``. May be
            ``None`` when the variable is not set.
        default: Value returned when ``v`` is ``None`` or contains only
            whitespace (i.e. the variable is unset or set to nothing).

    Returns:
        ``True`` when the trimmed, lower-cased value is one of
        ``1/true/yes/y/t/on``; ``default`` for a missing or blank value;
        ``False`` for any other text.
    """
    if v is None:
        return default
    token = str(v).strip().lower()
    if not token:
        # A variable exported with an empty value carries no information, so
        # it must not override the caller's default (e.g. a True default).
        return default
    return token in {"1", "true", "yes", "y", "t", "on"}
|
|
|
|
def _parse_int(v: "str | None", default=None):
    """Convert an environment-style string to ``int``.

    Args:
        v: Raw value, or ``None`` when the variable is not set.
        default: Value returned when ``v`` is ``None`` or cannot be converted.

    Returns:
        ``int(v)`` on success, otherwise ``default``.
    """
    if v is None:
        return default
    try:
        return int(v)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures fall back to the default; unrelated errors
        # are no longer hidden behind a blanket ``except Exception``.
        return default
|
|
|
|
def _parse_float(v: "str | None", default=None):
    """Convert an environment-style string to ``float``.

    Args:
        v: Raw value, or ``None`` when the variable is not set.
        default: Value returned when ``v`` is ``None`` or cannot be converted.

    Returns:
        ``float(v)`` on success, otherwise ``default``.
    """
    if v is None:
        return default
    try:
        return float(v)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures fall back to the default; unrelated errors
        # are no longer hidden behind a blanket ``except Exception``.
        return default
|
|
|
|
def get_env_aop_config():
    """Collect the AOP / early-exit settings from environment variables.

    Returns:
        A dict with the keys ``enabled``, ``apply_to``, ``layer_idx``,
        ``mode``, ``prune_vision``, ``prune_text``, ``keep_ratio_vision``,
        ``keep_ratio_text``, ``attn_agg`` and ``ee_layer``. String settings
        are stripped and lower-cased; numeric settings are ``None`` when the
        variable is unset or unparsable.
    """
    read = os.environ.get

    def _lowered(name, fallback):
        # String-valued settings are normalised the same way everywhere.
        return read(name, fallback).strip().lower()

    return {
        "enabled": _parse_bool(read("AOP_ENABLED"), False),
        "apply_to": _lowered("AOP_APPLY", "qry"),
        "layer_idx": _parse_int(read("AOP_LAYER"), None),
        "mode": _lowered("AOP_MODE", "ratio"),
        "prune_vision": _parse_bool(read("AOP_PRUNE_VISION"), True),
        "prune_text": _parse_bool(read("AOP_PRUNE_TEXT"), False),
        "keep_ratio_vision": _parse_float(read("AOP_KEEP_RATIO_VISION"), None),
        "keep_ratio_text": _parse_float(read("AOP_KEEP_RATIO_TEXT"), None),
        "attn_agg": _lowered("AOP_ATTENTION_AGG", "mean"),
        "ee_layer": _parse_int(read("EE_LAYER"), None),
    }
|
|
|
|
def pad_dataset_to_divisible(dataset, world_size):
    """Pad ``dataset`` so that its length is a multiple of ``world_size``.

    Padding rows are taken from the start of ``dataset`` (wrapping around if
    more rows are needed than exist) and appended with
    ``concatenate_datasets``.

    Args:
        dataset: Object supporting ``len()`` and ``.select(indices)``.
        world_size: Number of shards the result must divide into evenly.

    Returns:
        ``(dataset, length)``: the original object when no padding is needed,
        otherwise the padded dataset, together with the resulting length.
    """
    total = len(dataset)
    remainder = total % world_size
    if remainder == 0:
        return dataset, total

    shortfall = world_size - remainder
    filler = dataset.select([idx % total for idx in range(shortfall)])
    return concatenate_datasets([dataset, filler]), total + shortfall
|
|
|
|
| |
@torch.no_grad()
def encode_candidates_both_layers(model: MMEBModel, loader: DataLoader, training_args: TrainingArguments, mid_layer: int):
    """Encode every candidate once and pool both a middle and the final layer.

    The encoder is run with ``output_hidden_states=True`` and with AOP pruning
    switched off (``aop_prune_config["enabled"] = False``) for the duration of
    the call; the previous config is always restored, even on error.

    Args:
        model: Model exposing ``encoder`` and ``_pooling``.
        loader: Yields ``(inputs, infos)`` batches; each info dict must have a
            ``"cand_name"`` entry.
        training_args: Only ``training_args.device`` is read.
        mid_layer: Index into the returned ``hidden_states`` used as the
            "middle" representation.

    Returns:
        ``(cand_mid, cand_last, all_ids)``: two float numpy arrays of pooled
        representations (rows in loader order) and the matching list of
        candidate names.

    Raises:
        RuntimeError: If the loader yields no batches.
        AssertionError: If ``hidden_states`` has no entry at ``mid_layer``.
    """
    model.eval()
    encoder = model.encoder

    # Disable pruning once for the whole pass instead of toggling per batch.
    saved_cfg = getattr(encoder, "aop_prune_config", None)
    toggled = isinstance(saved_cfg, dict) and bool(saved_cfg)
    if toggled:
        aop_off = dict(saved_cfg)
        aop_off["enabled"] = False
        setattr(encoder, "aop_prune_config", aop_off)

    all_mid, all_last, all_ids = [], [], []
    try:
        for inputs, infos in tqdm(loader, desc="[DUMP] Cands[BOTH]", disable=False):
            inputs = batch_to_device(inputs, training_args.device)

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                out = encoder(
                    **inputs,
                    return_dict=True,
                    output_hidden_states=True,
                    stop_at_layer=None,
                    compute_lm_head=False,
                )

            hs_list = out.hidden_states
            assert hs_list is not None and len(hs_list) > mid_layer, "hidden_states too short for mid_layer"
            mid_hs, last_hs = hs_list[mid_layer], hs_list[-1]

            am = inputs.get("attention_mask", None)
            if am is not None and hasattr(am, "device") and am.device != mid_hs.device:
                am = am.to(mid_hs.device)

            all_mid.append(model._pooling(mid_hs, am).detach().float().cpu())
            all_last.append(model._pooling(last_hs, am).detach().float().cpu())
            all_ids.extend(info["cand_name"] for info in infos)
    finally:
        # Guarantee the caller's pruning config survives an exception.
        if toggled:
            setattr(encoder, "aop_prune_config", saved_cfg)

    if not all_mid:
        raise RuntimeError("[DUMP] candidate loader produced no batches; nothing to encode")

    cand_mid = torch.cat(all_mid, dim=0).numpy()
    cand_last = torch.cat(all_last, dim=0).numpy()
    return cand_mid, cand_last, all_ids
|
|
|
|
| |
@torch.no_grad()
def build_phaseA_features_global(
    reps_mid_t: torch.Tensor,
    cand_mid_t: torch.Tensor,
    am_mid: torch.Tensor,
    input_ids: torch.Tensor,
    cfg,
    topk: int = 200,
    temp: float = 0.05,
):
    """Build 13 per-query features, scoring each query against all candidates.

    Feature order (dim 1 of the result):
    ``[s1, s2, margin, H, sum_p2, L_txt, L_vis, L_tot, r_txt, r_vis, is_I, is_T, is_IT]``.

    Args:
        reps_mid_t: Query representations, one row per query.
        cand_mid_t: Candidate representations, one row per candidate.
        am_mid: Attention mask aligned with ``input_ids`` (non-zero = valid).
        input_ids: Token ids, one row per query.
        cfg: Object optionally providing ``image_token_id``,
            ``video_token_id``, ``bos_token_id``, ``eos_token_id`` and
            ``pad_token_id``; missing or negative ids are ignored.
        topk: Number of highest scores used for the score statistics.
        temp: Softmax temperature (floored at ``1e-6``).

    Returns:
        Tensor with one row per query and 13 columns.
    """
    scores = reps_mid_t @ cand_mid_t.T
    k = max(0, min(topk, scores.size(1)))

    if k == 0:
        # No candidates to score: use the same neutral values the local
        # variant uses for an empty candidate list.
        s1 = scores.new_zeros(scores.size(0))
        s2 = torch.zeros_like(s1)
        H = torch.ones_like(s1)
        sum_p2 = torch.zeros_like(s1)
    else:
        top_vals, _ = torch.topk(scores, k=k, dim=1)
        s1 = top_vals[:, 0]
        s2 = top_vals[:, 1] if k >= 2 else torch.zeros_like(s1)
        probs = torch.softmax(top_vals / max(temp, 1e-6), dim=1)
        if k >= 2:
            # Entropy normalised by log(k) so that it lies in [0, 1].
            H = -(probs * torch.log(probs + 1e-12)).sum(dim=1) / math.log(k)
        else:
            # log(1) == 0 would turn the normalisation into 0/0 (NaN); a
            # single-candidate distribution has zero entropy.
            H = torch.zeros_like(s1)
        sum_p2 = (probs ** 2).sum(dim=1)
    margin = s1 - s2

    # Token-type statistics over the valid (attended) positions.
    am = am_mid.to(torch.bool)
    iid = input_ids
    image_token_id = getattr(cfg, "image_token_id", None)
    video_token_id = getattr(cfg, "video_token_id", None)
    special_ids = [
        getattr(cfg, "bos_token_id", None),
        getattr(cfg, "eos_token_id", None),
        getattr(cfg, "pad_token_id", None),
    ]

    no_match = torch.zeros_like(iid, dtype=torch.bool)
    is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else no_match
    is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else no_match
    is_vision = (is_image | is_video) & am

    is_special = torch.zeros_like(iid, dtype=torch.bool)
    for tid in special_ids:
        if tid is not None and tid >= 0:
            is_special |= (iid == tid)
    is_text = am & (~is_vision) & (~is_special)

    L_vis = is_vision.sum(dim=1).float()
    L_txt = is_text.sum(dim=1).float()
    L_tot = am.sum(dim=1).float().clamp(min=1.0)  # avoid division by zero
    r_vis = L_vis / L_tot
    r_txt = L_txt / L_tot

    # One-hot modality indicator: image-only, text-only, or both.
    is_I = ((L_vis > 0) & (L_txt == 0)).float()
    is_T = ((L_txt > 0) & (L_vis == 0)).float()
    is_IT = ((L_txt > 0) & (L_vis > 0)).float()

    feats = torch.stack([s1, s2, margin, H, sum_p2, L_txt, L_vis, L_tot, r_txt, r_vis, is_I, is_T, is_IT], dim=1)
    return feats
|
|
|
|
@torch.no_grad()
def build_phaseA_features_local(
    reps_mid_t: torch.Tensor,
    cand_mid_t: torch.Tensor,
    am_mid: torch.Tensor,
    input_ids: torch.Tensor,
    cfg,
    per_sample_rows: list,
    topk: int = 200,
    temp: float = 0.05,
):
    """Build 13 per-query features, scoring each query against its own candidates.

    Feature order (dim 1 of the result):
    ``[s1, s2, margin, H, sum_p2, L_txt, L_vis, L_tot, r_txt, r_vis, is_I, is_T, is_IT]``.

    Args:
        reps_mid_t: Query representations, one row per query.
        cand_mid_t: Candidate representations, one row per candidate.
        am_mid: Attention mask aligned with ``input_ids`` (non-zero = valid).
        input_ids: Token ids, one row per query.
        cfg: Object optionally providing ``image_token_id``,
            ``video_token_id``, ``bos_token_id``, ``eos_token_id`` and
            ``pad_token_id``; missing or negative ids are ignored.
        per_sample_rows: For each query, the row indices into ``cand_mid_t``
            of its candidates. An empty entry yields ``s1 = s2 = 0``,
            ``H = 1`` and ``sum_p2 = 0``.
        topk: Number of highest scores used for the score statistics.
        temp: Softmax temperature (floored at ``1e-6``).

    Returns:
        Tensor with one row per query and 13 columns.
    """
    n_query = reps_mid_t.size(0)
    # Scalar fallbacks share the representation dtype/device so every list
    # handed to torch.stack is homogeneous.
    zero = reps_mid_t.new_zeros(())
    one = reps_mid_t.new_ones(())

    s1_list, s2_list, H_list, sum_p2_list = [], [], [], []
    for b in range(n_query):
        rows = per_sample_rows[b]
        k = 0 if len(rows) == 0 else max(0, min(topk, len(rows)))
        if k == 0:
            s1_list.append(zero)
            s2_list.append(zero)
            H_list.append(one)
            sum_p2_list.append(zero)
            continue

        sv = (reps_mid_t[b:b + 1] @ cand_mid_t[rows].T)[0]
        vals, _ = torch.topk(sv, k=k, dim=0)
        probs = torch.softmax(vals / max(temp, 1e-6), dim=0)
        s1_list.append(vals[0])
        if k >= 2:
            s2_list.append(vals[1])
            # Entropy normalised by log(k) so that it lies in [0, 1].
            H_list.append(-(probs * torch.log(probs + 1e-12)).sum() / math.log(k))
        else:
            s2_list.append(torch.zeros_like(vals[0]))
            # log(1) == 0 would turn the normalisation into 0/0 (NaN); a
            # single-candidate distribution has zero entropy.
            H_list.append(torch.zeros_like(vals[0]))
        sum_p2_list.append((probs ** 2).sum())

    s1 = torch.stack(s1_list)
    s2 = torch.stack(s2_list)
    H = torch.stack(H_list)
    sum_p2 = torch.stack(sum_p2_list)
    margin = s1 - s2

    # Token-type statistics over the valid (attended) positions.
    am = am_mid.to(torch.bool)
    iid = input_ids
    image_token_id = getattr(cfg, "image_token_id", None)
    video_token_id = getattr(cfg, "video_token_id", None)
    special_ids = [
        getattr(cfg, "bos_token_id", None),
        getattr(cfg, "eos_token_id", None),
        getattr(cfg, "pad_token_id", None),
    ]

    no_match = torch.zeros_like(iid, dtype=torch.bool)
    is_image = (iid == image_token_id) if (image_token_id is not None and image_token_id >= 0) else no_match
    is_video = (iid == video_token_id) if (video_token_id is not None and video_token_id >= 0) else no_match
    is_vision = (is_image | is_video) & am

    is_special = torch.zeros_like(iid, dtype=torch.bool)
    for tid in special_ids:
        if tid is not None and tid >= 0:
            is_special |= (iid == tid)
    is_text = am & (~is_vision) & (~is_special)

    L_vis = is_vision.sum(dim=1).float()
    L_txt = is_text.sum(dim=1).float()
    L_tot = am.sum(dim=1).float().clamp(min=1.0)  # avoid division by zero
    r_vis = L_vis / L_tot
    r_txt = L_txt / L_tot

    # One-hot modality indicator: image-only, text-only, or both.
    is_I = ((L_vis > 0) & (L_txt == 0)).float()
    is_T = ((L_txt > 0) & (L_vis == 0)).float()
    is_IT = ((L_txt > 0) & (L_vis > 0)).float()

    feats = torch.stack([s1, s2, margin, H, sum_p2, L_txt, L_vis, L_tot, r_txt, r_vis, is_I, is_T, is_IT], dim=1)
    return feats
|
|
|
|
| |
def compute_label_top1_equal_global(scores_mid: np.ndarray, scores_last: np.ndarray) -> np.ndarray:
    """Label each row 1 when both score matrices pick the same top-1 column.

    Args:
        scores_mid: 2-D score array (rows are reduced with ``argmax(axis=1)``).
        scores_last: 2-D score array with the same number of rows.

    Returns:
        ``int32`` array with one 0/1 entry per row.
    """
    agree = scores_mid.argmax(axis=1) == scores_last.argmax(axis=1)
    return agree.astype(np.int32)
|
|
|
|
def compute_label_top1_equal_local(scores_mid_list, scores_last_list):
    """Label each score-vector pair 1 when both agree on the top-1 index.

    Args:
        scores_mid_list: Iterable of 1-D score arrays, one per query.
        scores_last_list: Iterable of 1-D score arrays paired with the above.

    Returns:
        ``int32`` array with one 0/1 entry per pair; a pair in which either
        vector is empty is labelled 0.
    """
    labels = [
        0 if (mid.size == 0 or last.size == 0) else int(int(mid.argmax()) == int(last.argmax()))
        for mid, last in zip(scores_mid_list, scores_last_list)
    ]
    return np.array(labels, dtype=np.int32)
|
|
|
|
| |
def _dump_shard_bounds(num_items, rank, world_size):
    """Return the half-open ``[start, end)`` range of ``num_items`` owned by ``rank``.

    Shards are contiguous and differ in size by at most one, so no item is
    left unassigned when ``num_items`` is not a multiple of ``world_size``.
    """
    base, extra = divmod(num_items, world_size)
    start = rank * base + min(rank, extra)
    return start, start + base + (1 if rank < extra else 0)


def _dump_call_encoder(encoder, disable_aop, inputs, options):
    """Run ``encoder(**inputs, **options)`` under no-grad bf16 autocast.

    When ``disable_aop`` is true and the encoder carries a dict
    ``aop_prune_config``, the call runs with ``"enabled": False`` and the
    previous config is restored afterwards, even if the call raises.
    """
    saved_cfg = getattr(encoder, "aop_prune_config", None)
    toggled = bool(disable_aop) and isinstance(saved_cfg, dict)
    if toggled:
        aop_off = dict(saved_cfg)
        aop_off["enabled"] = False
        setattr(encoder, "aop_prune_config", aop_off)
    try:
        with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
            return encoder(**inputs, **options)
    finally:
        if toggled:
            setattr(encoder, "aop_prune_config", saved_cfg)


def _dump_pooled_reps(model, out, fallback_mask, device):
    """Pool an encoder output into bf16 representations; also return the mask used."""
    hs = getattr(out, "last_hidden_state", None)
    if hs is None:
        hs = out.hidden_states[-1]
    am = getattr(out, "attention_mask", None)
    if am is None:
        am = fallback_mask
    if hasattr(am, "device") and am.device != hs.device:
        am = am.to(hs.device)
    reps = model._pooling(hs, am).detach().to(device=device, dtype=torch.bfloat16)
    return reps, am


def _dump_candidate_rows(info, cand_id2row):
    """Map one query's ``cand_names`` to known candidate row indices (unknown names dropped)."""
    rows = (cand_id2row.get(str(cid), -1) for cid in info["cand_names"])
    return [r for r in rows if r >= 0]


def main():
    """Dump early-exit ("phase A") features and labels for every configured dataset.

    For each dataset in ``data_args.dataset_config`` the queries are split
    into train/val/test (10% / 10% / rest, ``random_state=42``). Every query
    is encoded up to layer ``EE_LAYER`` and then resumed to the final layer.
    One JSON line per query (features plus ``y_exit``, which is 1 when the
    mid-layer and final-layer top-1 candidates agree) is written to
    ``<encode_output_path>/<dataset>_<split>_features.jsonl.rank<rank>``.
    Mean/std of the train-split features are saved by rank 0 to
    ``<dataset>_phaseA_scaler.json``.

    Environment variables read here:
        EE_LAYER (falls back to AOP_LAYER, then "12"), EE_FEAT_TOPK
        (default "200"), DUMP_EXIT_NO_AOP (default "1": run the query
        encoder with AOP pruning disabled), and RANK (triggers
        ``torch.distributed`` initialisation with the NCCL backend).
    """
    if "RANK" in os.environ and dist.is_available() and not dist.is_initialized():
        dist.init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=60))
    distributed = dist.is_initialized()
    local_rank = dist.get_rank() if distributed else 0
    world_size = dist.get_world_size() if distributed else 1

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    os.makedirs(data_args.encode_output_path, exist_ok=True)

    hf_config = AutoConfig.from_pretrained(model_args.model_name, trust_remote_code=True)
    if not getattr(model_args, "model_backbone", None):
        model_backbone = get_backbone_name(hf_config=hf_config, model_type=model_args.model_type)
        setattr(model_args, 'model_backbone', model_backbone)
        setattr(training_args, 'model_backbone', model_backbone)

    # Rank 0 loads first; the other ranks wait at the barrier and then load
    # with a small random stagger.
    if local_rank == 0:
        processor = load_processor(model_args, data_args)
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)
    if distributed:
        dist.barrier()
    if local_rank != 0:
        processor = load_processor(model_args, data_args)
        time.sleep(random.randint(2 * local_rank, 3 * local_rank))
        model = MMEBModel.load(model_args, is_trainable=False, processor=processor)

    model.eval()
    model = model.to(training_args.device, dtype=torch.bfloat16)
    device = training_args.device

    ee_layer = int(os.environ.get("EE_LAYER", os.environ.get("AOP_LAYER", "12")))
    feat_topk = int(os.environ.get("EE_FEAT_TOPK", "200"))
    force_no_aop = os.environ.get("DUMP_EXIT_NO_AOP", "1").strip().lower() in {"1", "true", "yes", "on"}

    TRAIN_RATIO = 0.1
    VAL_RATIO = 0.1
    FEAT_DIM = 13  # number of columns produced by build_phaseA_features_*

    with open(data_args.dataset_config, 'r', encoding='utf-8') as yf:
        dataset_configs = yaml.safe_load(yf)

    for dataset_name, task_cfg in dataset_configs.items():
        if distributed:
            dist.barrier()
        print_master(f"\n[DUMP] Processing {dataset_name} ...")

        if data_args.data_basedir:
            for key in ["image_root", "video_root", "frame_root", "clip_root", "data_path"]:
                if task_cfg.get(key):
                    task_cfg[key] = os.path.join(data_args.data_basedir, task_cfg[key])

        full_qry, corpus = AutoEvalPairDataset.instantiate(model_args=model_args, data_args=data_args, **task_cfg)
        full_cand = generate_cand_dataset(full_qry, corpus)
        rank_global = task_cfg.get("eval_type", "global") == "global"

        # Deterministic train / val / test split over query indices.
        all_indices = np.arange(len(full_qry))
        train_idxs, temp_idxs = train_test_split(
            all_indices, train_size=TRAIN_RATIO, random_state=42, shuffle=True
        )
        val_relative_ratio = VAL_RATIO / (1.0 - TRAIN_RATIO)
        val_idxs, test_idxs = train_test_split(
            temp_idxs, train_size=val_relative_ratio, random_state=42, shuffle=True
        )
        print_master(f"[DUMP] Split sizes -> Train: {len(train_idxs)}, Val: {len(val_idxs)}, Test: {len(test_idxs)}")

        splits = {
            "train": {"ds": full_qry.select(train_idxs), "indices": train_idxs},
            "val": {"ds": full_qry.select(val_idxs), "indices": val_idxs},
            "test": {"ds": full_qry.select(test_idxs), "indices": test_idxs},
        }

        # Every rank encodes the full candidate set at both layers.
        cand_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "cand")
        cand_loader = DataLoader(
            full_cand, batch_size=training_args.per_device_eval_batch_size,
            collate_fn=cand_collator, num_workers=training_args.dataloader_num_workers
        )
        cand_mid_np, cand_last_np, cand_ids = encode_candidates_both_layers(model, cand_loader, training_args, mid_layer=ee_layer)
        cand_id2row = {str(cid): i for i, cid in enumerate(cand_ids)}
        cand_mid_t = torch.from_numpy(cand_mid_np).to(device=device, dtype=torch.bfloat16)
        cand_last_t = torch.from_numpy(cand_last_np).to(device=device, dtype=torch.bfloat16)

        # Running train-split statistics, kept in float64 to limit the
        # cancellation error of the E[x^2] - E[x]^2 variance formula.
        sum_feat = np.zeros(FEAT_DIM, dtype=np.float64)
        sum2_feat = np.zeros(FEAT_DIM, dtype=np.float64)
        n_feat = 0
        scaler_path = os.path.join(data_args.encode_output_path, f"{dataset_name}_phaseA_scaler.json")

        for split_name, split_info in splits.items():
            qry_dataset = split_info["ds"]
            global_indices = split_info["indices"]
            if len(qry_dataset) == 0:
                continue

            if distributed:
                # Balanced contiguous shards; the remainder is spread over the
                # first ranks instead of being dropped.
                start_idx, end_idx = _dump_shard_bounds(len(qry_dataset), local_rank, world_size)
                if end_idx <= start_idx:
                    local_dataset = qry_dataset.select([])
                    local_indices = []
                else:
                    local_dataset = qry_dataset.select(range(start_idx, end_idx))
                    local_indices = global_indices[start_idx:end_idx]
            else:
                local_dataset = qry_dataset
                local_indices = global_indices

            qry_collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
            qry_loader = DataLoader(
                local_dataset,
                batch_size=training_args.per_device_eval_batch_size,
                collate_fn=qry_collator,
                num_workers=training_args.dataloader_num_workers,
                shuffle=False
            )

            feat_out_path_rank = os.path.join(data_args.encode_output_path, f"{dataset_name}_{split_name}_features.jsonl.rank{local_rank}")
            print_master(f" -> Dump {split_name} features to {feat_out_path_rank} ...")

            # Position in local_indices of the first sample of the next batch;
            # valid because the loader is not shuffled.
            cursor = 0

            with open(feat_out_path_rank, "w", encoding="utf-8") as fout:
                for inputs, infos in tqdm(qry_loader, desc=f"[{split_name.upper()}]", disable=(local_rank != 0)):
                    inputs = batch_to_device(inputs, device)
                    B = inputs["input_ids"].size(0)
                    batch_global_ids = local_indices[cursor:cursor + B]
                    cursor += B

                    # Pass 1: run up to the exit layer and keep the resumable state.
                    out_mid = _dump_call_encoder(
                        model.encoder, force_no_aop, inputs,
                        dict(
                            return_dict=True, output_hidden_states=False,
                            stop_at_layer=int(ee_layer), compute_lm_head=False,
                            return_intermediate_state=True,
                        ),
                    )
                    reps_mid_t, am_mid = _dump_pooled_reps(model, out_mid, inputs.get("attention_mask"), device)

                    if rank_global:
                        rows_list = None
                        feats_t = build_phaseA_features_global(reps_mid_t, cand_mid_t, am_mid, inputs["input_ids"], model.encoder.config, topk=feat_topk)
                    else:
                        rows_list = [_dump_candidate_rows(info, cand_id2row) for info in infos[:B]]
                        feats_t = build_phaseA_features_local(reps_mid_t, cand_mid_t, am_mid, inputs["input_ids"], model.encoder.config, rows_list, topk=feat_topk)
                    feats_np = feats_t.detach().float().cpu().numpy()

                    # Pass 2: resume from the exit layer to the final layer.
                    interm = getattr(out_mid, "intermediate_state", None)
                    if interm is None:
                        raise RuntimeError("[DUMP] encoder returned no intermediate_state; cannot resume to the last layer")
                    resume_state = {
                        "hidden_states": interm["hidden_states"].detach(),
                        "attention_mask": interm["attention_mask"].detach(),
                        "position_ids": interm["position_ids"].detach(),
                        "vision_mask": interm.get("vision_mask"),
                        "text_mask": interm.get("text_mask"),
                        "next_layer_idx": int(interm["next_layer_idx"]),
                    }
                    out_last = _dump_call_encoder(
                        model.encoder, force_no_aop, {},
                        dict(
                            return_dict=True, output_hidden_states=False, stop_at_layer=None,
                            resume_state=resume_state, compute_lm_head=False,
                        ),
                    )
                    reps_last_t, _ = _dump_pooled_reps(model, out_last, resume_state["attention_mask"], device)

                    # Label: does the mid-layer top-1 match the final-layer top-1?
                    if rank_global:
                        sim_mid = (reps_mid_t @ cand_mid_t.T).detach().float().cpu().numpy()
                        sim_last = (reps_last_t @ cand_last_t.T).detach().float().cpu().numpy()
                        y = compute_label_top1_equal_global(sim_mid, sim_last)
                    else:
                        y = np.zeros(B, dtype=np.int32)
                        for b_idx, rows in enumerate(rows_list):
                            if not rows:
                                continue  # no known candidates -> label stays 0
                            sv_mid = (reps_mid_t[b_idx:b_idx + 1] @ cand_mid_t[rows].T)[0].detach().float().cpu().numpy()
                            sv_last = (reps_last_t[b_idx:b_idx + 1] @ cand_last_t[rows].T)[0].detach().float().cpu().numpy()
                            y[b_idx] = int(int(sv_mid.argmax()) == int(sv_last.argmax()))

                    if split_name == "train":
                        feats64 = feats_np.astype(np.float64)
                        sum_feat += feats64.sum(axis=0)
                        sum2_feat += (feats64 ** 2).sum(axis=0)
                        n_feat += feats_np.shape[0]

                    L_txt = feats_np[:, 5]
                    L_vis = feats_np[:, 6]
                    types = np.where((L_vis > 0) & (L_txt == 0), "I", np.where((L_txt > 0) & (L_vis == 0), "T", "IT"))

                    for b_idx in range(B):
                        row = {
                            "dataset": dataset_name,
                            "split": split_name,
                            "qid": int(batch_global_ids[b_idx]),
                            "type": str(types[b_idx]),
                            "feats": feats_np[b_idx].tolist(),
                            "y_exit": int(y[b_idx]),
                        }
                        fout.write(json.dumps(row, ensure_ascii=False) + "\n")

        # Combine the per-rank train statistics.
        if distributed:
            dist.barrier()
            stats_vec = torch.tensor(
                np.concatenate([sum_feat, sum2_feat, [n_feat]]),
                device=device, dtype=torch.float64
            )
            dist.all_reduce(stats_vec, op=dist.ReduceOp.SUM)
            sum_feat_all = stats_vec[:FEAT_DIM].cpu().numpy()
            sum2_feat_all = stats_vec[FEAT_DIM:2 * FEAT_DIM].cpu().numpy()
            n_feat_all = int(round(stats_vec[2 * FEAT_DIM].item()))
        else:
            sum_feat_all = sum_feat
            sum2_feat_all = sum2_feat
            n_feat_all = n_feat

        if local_rank == 0 and n_feat_all > 0:
            mean_vec = sum_feat_all / n_feat_all
            var = sum2_feat_all / n_feat_all - mean_vec ** 2
            mean = mean_vec.tolist()
            # Floor std at 1e-6 so constant features do not divide by zero later.
            std = [float(max(1e-6, math.sqrt(max(0.0, v)))) for v in var.tolist()]
            with open(scaler_path, "w", encoding="utf-8") as f:
                json.dump({"mean": mean, "std": std, "in_dim": len(mean), "n_samples": n_feat_all, "dataset": dataset_name}, f, indent=2)
            print_master(f"[DUMP] {dataset_name} Scaler saved -> {scaler_path}")

    if distributed:
        dist.barrier()
|
|
# Script entry point: run the dump only when executed directly, not on import.
if __name__ == "__main__":
    main()