| import time |
| import json |
| import os |
| import pickle |
| import numpy as np |
| import torch |
| import torch.distributed as dist |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from transformers import HfArgumentParser, AutoConfig |
|
|
| from src.arguments import ModelArguments, DataArguments, TrainingArguments |
| from src.data.collator.eval_collator import MultimodalEvalDataCollator |
| from src.data.eval_dataset.base_eval_dataset import AutoEvalPairDataset, generate_cand_dataset |
| from src.eval_utils.metrics import RankingMetrics |
| from src.model.model_cut_layer_AOP_add_text_cut import MMEBModel |
| from src.model.processor import get_backbone_name, load_processor |
| from src.utils import batch_to_device, print_master |
| from src.classifier_utils_V5 import EarlyExitClassifier |
|
|
| |
| |
| |
def get_env_config():
    """Read the benchmark configuration from environment variables.

    Recognized variables:
        EE_ENABLED          -- truthy string enables early-exit mode.
        EE_THRESHOLD        -- float exit threshold (default "0.5").
        EE_LAYER            -- int layer index to stop at (default "12").
        EE_CLASSIFIER_PATH  -- path to the exit classifier checkpoint.
        AOP_ENABLED         -- truthy string enables AOP mode.

    Returns:
        dict with keys: ee_enabled, ee_threshold, ee_layer,
        classifier_path, aop_enabled.
    """
    truthy = {"1", "true", "yes", "on"}

    def flag(var):
        # A flag counts as "on" when its trimmed, lower-cased value is truthy.
        return os.environ.get(var, "0").strip().lower() in truthy

    return {
        "ee_enabled": flag("EE_ENABLED"),
        "ee_threshold": float(os.environ.get("EE_THRESHOLD", "0.5")),
        "ee_layer": int(os.environ.get("EE_LAYER", "12")),
        "classifier_path": os.environ.get("EE_CLASSIFIER_PATH", ""),
        "aop_enabled": flag("AOP_ENABLED"),
    }
|
|
| |
| |
| |
def run_benchmark(
    model, classifier, processor, model_args, data_args, training_args,
    qry_dataset, cand_mid_dict, cand_last_dict, cfg, dataset_name, out_dir
):
    """Time query encoding against precomputed candidate embeddings.

    Runs either a full forward pass (baseline) or an early-exit pipeline
    (when cfg["ee_enabled"]): queries are encoded up to cfg["ee_layer"],
    an exit classifier scores them, the batch's top-probability half
    continues to the final layer via the encoder's resume_state protocol,
    and the rest exits early against mid-layer candidate embeddings.

    Args:
        model: MMEBModel wrapper exposing .encoder(...) and ._pooling(...).
        classifier: exit classifier (may be None in baseline mode).
        processor: backbone processor for the eval collator.
        model_args / data_args / training_args: parsed argument dataclasses.
        qry_dataset: query-side eval dataset.
        cand_mid_dict: candidate id -> mid-layer embedding (np array).
        cand_last_dict: candidate id -> last-layer embedding (np array).
        cfg: dict from get_env_config().
        dataset_name: label used in progress bar / log line.
        out_dir: NOTE(review): currently unused in this function.

    Returns:
        list of {"prediction": [cand_ids...], "label": ...} dicts
        (one per query) built by _record.
    """
    device = training_args.device
    # Rank 0 owns the progress bar; other ranks run silently.
    local_rank = dist.get_rank() if dist.is_initialized() else 0
    is_main = (local_rank == 0)

    collator = MultimodalEvalDataCollator(processor, model_args, data_args, "qry")
    loader = DataLoader(
        qry_dataset, batch_size=training_args.per_device_eval_batch_size,
        collate_fn=collator, num_workers=training_args.dataloader_num_workers
    )

    # Candidate matrices share one id ordering so score columns line up.
    cand_ids = list(cand_last_dict.keys())
    cand_last_np = np.stack([cand_last_dict[c] for c in cand_ids]).astype(np.float32)
    if cfg["ee_enabled"]:
        # Mid-layer candidates are kept on-GPU in bf16 for the exit scoring matmul.
        cand_mid_np = np.stack([cand_mid_dict[c] for c in cand_ids]).astype(np.float32)
        cand_mid_t = torch.from_numpy(cand_mid_np).to(device=device, dtype=torch.bfloat16)

    model.eval()
    if classifier: classifier.eval(); classifier.to(device)

    # Warm up kernels / allocator so the timed loop is not skewed by startup cost.
    print_master(f"🔥 Warming up GPU...")
    for _ in range(5):
        # NOTE(review): the timed loop below unpacks (inputs, infos) from this
        # same loader, but here the whole yielded batch is forwarded directly —
        # confirm the collator's return shape makes both usages valid.
        dummy_inputs = next(iter(loader))
        dummy_inputs = batch_to_device(dummy_inputs, device)
        with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
            _ = model.encoder(**dummy_inputs, stop_at_layer=None)
    torch.cuda.synchronize()

    total_samples = 0
    start_time = time.perf_counter()
    pred_dicts = []

    for inputs, infos in tqdm(loader, desc=f"Benchmarking {dataset_name}", disable=not is_main):
        inputs = batch_to_device(inputs, device)
        B = inputs["input_ids"].shape[0]
        total_samples += B

        if not cfg["ee_enabled"]:
            # ---- Baseline: single full forward pass, score vs. last-layer candidates.
            with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
                out = model.encoder(
                    **inputs,
                    return_dict=True,
                    output_hidden_states=False,
                    stop_at_layer=None,
                    compute_lm_head=False
                )
                hs = out.last_hidden_state
                if hs is None: hs = out.hidden_states[-1]
                # Prefer the encoder-provided mask; fall back to the input mask.
                am = getattr(out, "attention_mask", None)
                if am is None: am = inputs.get("attention_mask", None)
                reps = model._pooling(hs, am).float().cpu().numpy()

                scores = np.dot(reps, cand_last_np.T)
                # Keep the 50 highest-scoring candidates per query.
                topk_inds = np.argsort(-scores, axis=1)[:, :50]

                for i in range(B):
                    cids = [cand_ids[k] for k in topk_inds[i]]
                    _record(pred_dicts, infos[i], cids)

        else:
            # ---- Early-exit path: partial forward, classify, split the batch.
            with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
                out_mid = model.encoder(
                    **inputs, return_dict=True, output_hidden_states=False,
                    stop_at_layer=cfg["ee_layer"], compute_lm_head=False
                )
                hs_mid = getattr(out_mid, "last_hidden_state", None)
                if hs_mid is None: hs_mid = out_mid.hidden_states[-1]
                am_mid = getattr(out_mid, "attention_mask", None)
                if am_mid is None: am_mid = inputs.get("attention_mask", None)
                reps_mid = model._pooling(hs_mid, am_mid)

                # Query-vs-candidate similarity at the mid layer feeds the classifier.
                cos_mid = reps_mid @ cand_mid_t.T

                # NOTE(review): placeholder features — see _mock_feature_extraction.
                scalar_inputs = _mock_feature_extraction(cos_mid, device)

                logits = classifier(scalar_inputs.float(), torch.zeros(B, dtype=torch.long, device=device), qry_emb=reps_mid.float())
                probs = torch.sigmoid(logits).squeeze(1)

                # Dynamic threshold: the top half of the batch (by probability)
                # continues; the rest exits early.
                # NOTE(review): cfg["ee_threshold"] is never used — confirm the
                # fixed-threshold mode was intentionally replaced by this
                # per-batch 50% split.
                # NOTE(review): when B == 1, k == 0 and top_vals[-1] raises
                # IndexError — guard if batch size 1 can occur.
                k = int(B * 0.5)
                top_vals, _ = torch.topk(probs, k=k)
                dyn_thresh = top_vals[-1]
                exit_mask = (probs < dyn_thresh).cpu().numpy()

                exit_indices = np.where(exit_mask)[0]
                cont_indices = np.where(~exit_mask)[0]

                # Early-exit subset: rank against mid-layer candidate embeddings.
                if len(exit_indices) > 0:
                    reps_exit_np = reps_mid[exit_indices].float().cpu().numpy()
                    scores = np.dot(reps_exit_np, cand_mid_np.T)
                    inds = np.argsort(-scores, axis=1)[:, :50]
                    for i, idx in enumerate(exit_indices):
                        _record(pred_dicts, infos[idx], [cand_ids[x] for x in inds[i]])

                # Continuing subset: resume the encoder from the saved mid state.
                if len(cont_indices) > 0:
                    interm = out_mid.intermediate_state
                    # Slice only the continuing rows out of each cached tensor.
                    subset = {k: v[cont_indices] if v is not None and isinstance(v, torch.Tensor) else v
                              for k, v in interm.items() if k in ["hidden_states", "attention_mask", "position_ids"]}
                    subset["next_layer_idx"] = int(interm["next_layer_idx"])

                    out_last = model.encoder(
                        return_dict=True, output_hidden_states=False, stop_at_layer=None,
                        resume_state=subset, compute_lm_head=False
                    )

                    hs = out_last.last_hidden_state
                    am = subset["attention_mask"]
                    reps_cont = model._pooling(hs, am).float().cpu().numpy()
                    scores = np.dot(reps_cont, cand_last_np.T)
                    inds = np.argsort(-scores, axis=1)[:, :50]
                    for i, idx in enumerate(cont_indices):
                        _record(pred_dicts, infos[idx], [cand_ids[x] for x in inds[i]])

    # Synchronize before stopping the clock so async CUDA work is counted.
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    total_time = end_time - start_time

    latency_ms = (total_time / total_samples) * 1000
    throughput = total_samples / total_time

    print_master(f"\n[BENCHMARK_RESULT] Mode={'Ours' if cfg['ee_enabled'] else 'Baseline'} | Samples={total_samples} | TotalTime={total_time:.4f}s | Latency={latency_ms:.4f}ms | Throughput={throughput:.2f}qps")

    return pred_dicts
|
|
| def _record(pred_dicts, info, cids): |
| pred_dicts.append({ |
| "prediction": cids, |
| "label": info.get("label_name") or info.get("label"), |
| }) |
|
|
| def _mock_feature_extraction(cos_mid, device): |
| |
| |
| return torch.randn(cos_mid.size(0), 27, device=device) |
|
|
| |
def main():
    """Entry point: select baseline vs. accelerated mode and launch the benchmark.

    NOTE(review): this function is an incomplete stub. `model` is never
    constructed, so the run_benchmark call raises NameError, and a literal
    Ellipsis is passed in place of the remaining positional arguments.
    Argument parsing (HfArgumentParser over ModelArguments/DataArguments/
    TrainingArguments), processor/model loading, query/candidate dataset
    construction, and — in EE mode — loading EarlyExitClassifier from
    cfg["classifier_path"] all still need to be implemented.
    """
    cfg = get_env_config()

    classifier = None
    if cfg["ee_enabled"]:
        print_master(f"🚀 Mode: Accelerated (EE + AOP)")
        # NOTE(review): cfg["classifier_path"] is read by get_env_config but the
        # classifier is never instantiated/loaded here — presumably an
        # EarlyExitClassifier checkpoint should be loaded; confirm and implement.
    else:
        print_master(f"🐢 Mode: Baseline (Full Forward)")

    # NOTE(review): `model` is undefined and `...` (Ellipsis) is a placeholder
    # for the remaining run_benchmark arguments — this call fails as written.
    run_benchmark(model, classifier, ...)
|
|
# Script entry point: run only when executed directly, not on import.
if __name__ == '__main__':
    main()