| """ |
| benchmark.py — single self-contained reproducibility script for |
| haremb-privacy-filter-opennemo. |
| |
| Run from inside this folder: |
| |
| python benchmark.py # default: cuda if available |
| python benchmark.py --device cpu # cpu fallback |
| python benchmark.py --eval-pct 0.5 # smaller slice |
| python benchmark.py --no-base # skip teacher download |
| |
| Produces, in `--out` (default ./): |
| infer.log — sample inference timing + redaction example |
| compare.log — aggregate + per-category metrics, this model vs |
| OpenMed teacher (raw + viterbi streams), and |
| token-level pairwise breakdown. |
| eval_summary.png — bar charts of headline metrics + per-category |
| span-F1 (this vs teacher). |
| eval_confusion.png — token-level outcome breakdown on gold non-O |
| positions (this vs teacher). |
| eval_performance.png — model-size / compute / memory / throughput |
| comparison (this vs teacher), absolute + ratios. |
| |
| This script does not import from training code. It vendors the small set |
| of helpers it needs (BIOES decoder, span builder, eval-set sampler, |
| metrics aggregator) so the model folder is self-contained. |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import ast |
| import math |
| import os |
| import sys |
| import time |
| from collections import Counter, defaultdict |
| from pathlib import Path |
| from typing import Dict, List, Tuple |
|
|
| import numpy as np |
| import torch |
| from datasets import load_dataset |
| from torch.utils.data import DataLoader, Dataset |
| from tqdm.auto import tqdm |
| from transformers import AutoModelForTokenClassification, AutoTokenizer |
|
|
|
|
| # --------------------------------------------------------------------------- |
| # constants |
| # --------------------------------------------------------------------------- |
|
|
| SOURCE_DATASET = "nvidia/Nemotron-PII" |
| TEACHER = "OpenMed/privacy-filter-nemotron" |
|
|
| # The 55 PII categories used by nvidia/Nemotron-PII, kept sorted so the |
| # derived label ids are stable across runs. |
| NEMOTRON_CATEGORIES: List[str] = sorted([ |
| "account_number", "age", "api_key", "bank_routing_number", |
| "biometric_identifier", "blood_type", "certificate_license_number", |
| "city", "company_name", "coordinate", "country", "county", |
| "credit_debit_card", "customer_id", "cvv", "date", "date_of_birth", |
| "date_time", "device_identifier", "education_level", "email", |
| "employee_id", "employment_status", "fax_number", "first_name", |
| "gender", "health_plan_beneficiary_number", "http_cookie", "ipv4", |
| "ipv6", "language", "last_name", "license_plate", "mac_address", |
| "medical_record_number", "national_id", "occupation", "password", |
| "phone_number", "pin", "political_view", "postcode", "race_ethnicity", |
| "religious_belief", "sexuality", "ssn", "state", "street_address", |
| "swift_bic", "tax_id", "time", "unique_id", "url", "user_name", |
| "vehicle_identifier", |
| ]) |
|
|
|
|
| def nemotron_native_label_space() -> Tuple[Dict[str, int], Dict[int, str]]: |
| """O at id 0, then {B, I, E, S}-{cat} for each cat in alphabetical order.""" |
| label2id: Dict[str, int] = {"O": 0} |
| nxt = 1 |
| for cat in NEMOTRON_CATEGORIES: |
| for prefix in ("B", "I", "E", "S"): |
| label2id[f"{prefix}-{cat}"] = nxt |
| nxt += 1 |
| id2label: Dict[int, str] = {v: k for k, v in label2id.items()} |
| return label2id, id2label |
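
| # Example (follows directly from the construction above): with the 55 |
| # categories this yields 1 + 4 * 55 = 221 labels, e.g. label2id["O"] == 0, |
| # label2id["B-account_number"] == 1 and label2id["S-account_number"] == 4, |
| # since prefixes are assigned in B, I, E, S order. |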
|
|
|
|
| # --------------------------------------------------------------------------- |
| # gold character spans → BIOES token labels |
| # --------------------------------------------------------------------------- |
|
|
| def _trim_span(text: str, start: int, end: int) -> Tuple[int, int]: |
| raw = text[start:end] |
| i = 0 |
| while i < len(raw) and raw[i].isspace(): |
| i += 1 |
| j = len(raw) |
| while j > i and (raw[j - 1].isspace() or raw[j - 1] in ".,;:)"): |
| j -= 1 |
| return start + i, start + j |
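
| # Example: _trim_span("Call Bob. ", 4, 10) tightens the sloppy span " Bob. " |
| # to (5, 8), i.e. exactly "Bob": leading whitespace and the trailing ". " |
| # are dropped. |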
|
|
|
|
| def _parse_spans(spans_str) -> List[dict]: |
| if isinstance(spans_str, list): |
| return spans_str |
| try: |
| return ast.literal_eval(spans_str) |
| except (SyntaxError, ValueError): |
| return [] |
|
|
|
|
| def _assign_native_bioes_labels( |
| text: str, |
| raw_spans: List[dict], |
| tokenizer, |
| max_length: int, |
| label2id: Dict[str, int], |
| min_overlap_frac: float = 0.5, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| pf_like: List[Tuple[int, int, str]] = [] |
| for s in raw_spans: |
| cat = s.get("label") |
| if not cat: |
| continue |
| st, en = _trim_span(text, int(s["start"]), int(s["end"])) |
| if st >= en: |
| continue |
| pf_like.append((st, en, cat)) |
| pf_like.sort(key=lambda x: (x[0], -x[1])) |
|
|
| enc = tokenizer( |
| text, truncation=True, max_length=max_length, |
| padding="max_length", return_offsets_mapping=True, return_tensors="pt", |
| ) |
| input_ids = enc.input_ids[0] |
| attention_mask = enc.attention_mask[0] |
| offsets = enc.offset_mapping[0].tolist() |
|
|
| o_id = label2id["O"] |
| label_ids = [o_id] * len(input_ids) |
| locked = [False] * len(input_ids) |
|
|
| for span_start, span_end, cat in pf_like: |
| tok_indices: List[int] = [] |
| for ti, (s, e) in enumerate(offsets): |
| if s == 0 and e == 0: |
| continue |
| if e <= span_start or s >= span_end: |
| continue |
| tok_len = e - s |
| if tok_len <= 0: |
| continue |
| overlap = min(e, span_end) - max(s, span_start) |
| if overlap / tok_len >= min_overlap_frac: |
| tok_indices.append(ti) |
|
|
| if not tok_indices: |
| continue |
| tok_indices = [ti for ti in tok_indices if not locked[ti]] |
| if not tok_indices: |
| continue |
|
|
| if len(tok_indices) == 1: |
| tag = f"S-{cat}" |
| if tag in label2id: |
| label_ids[tok_indices[0]] = label2id[tag] |
| locked[tok_indices[0]] = True |
| else: |
| b_tag, i_tag, e_tag = f"B-{cat}", f"I-{cat}", f"E-{cat}" |
| if b_tag in label2id: |
| label_ids[tok_indices[0]] = label2id[b_tag] |
| locked[tok_indices[0]] = True |
| for ti in tok_indices[1:-1]: |
| if i_tag in label2id: |
| label_ids[ti] = label2id[i_tag] |
| locked[ti] = True |
| if e_tag in label2id: |
| label_ids[tok_indices[-1]] = label2id[e_tag] |
| locked[tok_indices[-1]] = True |
|
|
| label_tensor = torch.tensor(label_ids, dtype=torch.long) |
| label_tensor[attention_mask == 0] = -100 |
| return input_ids, attention_mask, label_tensor |
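
| # Example: a gold span covering three sub-tokens is tagged B-cat / I-cat / |
| # E-cat, and a span covering exactly one sub-token is tagged S-cat. Tokens |
| # already claimed by a previous span stay locked, so overlapping spans |
| # never overwrite each other. |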
|
|
|
|
| class _NemotronEvalDataset(Dataset): |
| def __init__(self, hf_split, tokenizer, label2id, max_length): |
| self.hf = hf_split |
| self.tok = tokenizer |
| self.l2i = label2id |
| self.maxlen = max_length |
|
|
| def __len__(self): |
| return len(self.hf) |
|
|
| def __getitem__(self, idx): |
| ex = self.hf[idx] |
| ids, mask, labels = _assign_native_bioes_labels( |
| ex["text"], _parse_spans(ex["spans"]), |
| self.tok, self.maxlen, self.l2i, |
| ) |
| L = int(mask.sum().item()) |
| return { |
| "input_ids": ids[:L].tolist(), |
| "labels": labels[:L].tolist(), |
| "valid_len": L, |
| } |
|
|
|
|
| def _make_collate(pad_token_id, max_length): |
| def _c(batch): |
| ids_list = [list(ex["input_ids"])[:max_length] for ex in batch] |
| labels_list = [list(ex["labels"])[:max_length] for ex in batch] |
| max_len = max(len(x) for x in ids_list) |
| B = len(batch) |
| input_ids = torch.full((B, max_len), pad_token_id, dtype=torch.long) |
| attention_mask = torch.zeros((B, max_len), dtype=torch.long) |
| labels = torch.full((B, max_len), -100, dtype=torch.long) |
| for i, (ids, lab) in enumerate(zip(ids_list, labels_list)): |
| L = len(ids) |
| input_ids[i, :L] = torch.tensor(ids, dtype=torch.long) |
| attention_mask[i, :L] = 1 |
| labels[i, :L] = torch.tensor(lab, dtype=torch.long) |
| return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels} |
| return _c |
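
| # Example: two docs with valid lengths 3 and 5 collate into a (2, 5) batch; |
| # pad slots get input_ids=pad_token_id, attention_mask=0 and labels=-100, |
| # so they are ignored by both the attention masking and the metrics. |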
|
|
|
|
| def _build_eval_streaming(test_split, target_n, chunk_size, seed) -> List[int]: |
| """Uniform chunked sampling, identical to the training-time eval split.""" |
| n_total = len(test_split) |
| target_n = min(target_n, n_total) |
| if target_n <= 0: |
| return [] |
| rng = np.random.RandomState(seed) |
| per_chunk = max(1, math.ceil(chunk_size * target_n / n_total)) |
| selected: List[int] = [] |
| for chunk_start in range(0, n_total, chunk_size): |
| if len(selected) >= target_n: |
| break |
| chunk_end = min(chunk_start + chunk_size, n_total) |
| n_in_chunk = chunk_end - chunk_start |
| n_to_pick = min(per_chunk, n_in_chunk, target_n - len(selected)) |
| if n_to_pick <= 0: |
| break |
| offsets = rng.choice(n_in_chunk, size=n_to_pick, replace=False) |
| selected.extend(int(chunk_start + o) for o in offsets) |
| return sorted(selected[:target_n]) |
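
| # Example: with n_total=100_000, chunk_size=10_000 and target_n=1_000, |
| # per_chunk = ceil(10_000 * 1_000 / 100_000) = 100, so ~100 random indices |
| # are drawn from each of the 10 chunks until 1_000 are collected. |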
|
|
|
|
| # --------------------------------------------------------------------------- |
| # BIOES decoding + span metrics |
| # --------------------------------------------------------------------------- |
|
|
| def _bioes_to_spans(labels, id2label, o_id=0): |
| """Convert per-token BIOES label ids to a set of (start, end, cat).""" |
| spans = set() |
| cur_start = None |
| cur_cat = None |
| for i, lid in enumerate(labels): |
| lid = int(lid) |
| if lid == o_id or lid < 0: |
| if cur_start is not None: |
| spans.add((cur_start, i, cur_cat)) |
| cur_start = None |
| cur_cat = None |
| continue |
| tag = id2label.get(lid, "O") |
| if tag == "O" or "-" not in tag: |
| if cur_start is not None: |
| spans.add((cur_start, i, cur_cat)) |
| cur_start = None |
| cur_cat = None |
| continue |
| prefix, cat = tag.split("-", 1) |
| if prefix == "S": |
| if cur_start is not None: |
| spans.add((cur_start, i, cur_cat)) |
| spans.add((i, i + 1, cat)) |
| cur_start = None |
| cur_cat = None |
| elif prefix == "B": |
| if cur_start is not None: |
| spans.add((cur_start, i, cur_cat)) |
| cur_start = i |
| cur_cat = cat |
| elif prefix == "I": |
| if cur_start is None or cur_cat != cat: |
| if cur_start is not None: |
| spans.add((cur_start, i, cur_cat)) |
| cur_start = i |
| cur_cat = cat |
| elif prefix == "E": |
| if cur_start is None or cur_cat != cat: |
| if cur_start is not None: |
| spans.add((cur_start, i, cur_cat)) |
| spans.add((i, i + 1, cat)) |
| cur_start = None |
| cur_cat = None |
| else: |
| spans.add((cur_start, i + 1, cur_cat)) |
| cur_start = None |
| cur_cat = None |
| if cur_start is not None: |
| spans.add((cur_start, len(labels), cur_cat)) |
| return spans |
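
| # Example: [O, B-email, E-email, O, S-city] decodes to |
| # {(1, 3, "email"), (4, 5, "city")}; a dangling entity is closed at the |
| # next O / end of sequence, so [B-city, O] still yields {(0, 1, "city")}. |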
|
|
|
|
| def _aggregate_span_metrics(gold_spans, pred_spans): |
| correct = gold_spans & pred_spans |
| n_gold = len(gold_spans) |
| n_pred = len(pred_spans) |
| n_correct = len(correct) |
| p = n_correct / n_pred if n_pred else 0.0 |
| r = n_correct / n_gold if n_gold else 0.0 |
| f1 = (2 * p * r / (p + r)) if (p + r) else 0.0 |
| per_cat: Dict[str, dict] = {} |
| cats = sorted({c for _, _, c in gold_spans} | {c for _, _, c in pred_spans}) |
| for cat in cats: |
| g_c = {s for s in gold_spans if s[2] == cat} |
| p_c = {s for s in pred_spans if s[2] == cat} |
| c_c = g_c & p_c |
| pp = len(c_c) / len(p_c) if p_c else 0.0 |
| rr = len(c_c) / len(g_c) if g_c else 0.0 |
| ff = (2 * pp * rr / (pp + rr)) if (pp + rr) else 0.0 |
| per_cat[cat] = {"precision": pp, "recall": rr, "f1": ff, |
| "n_gold": len(g_c), "n_pred": len(p_c), "n_correct": len(c_c)} |
| return {"precision": p, "recall": r, "f1": f1, |
| "n_gold": n_gold, "n_pred": n_pred, "n_correct": n_correct, |
| "per_cat": per_cat} |
|
|
|
|
| def _stream_metrics(docs, stream, id2label, o_id): |
| """Aggregate metrics over a list of {gold, raw, viterbi} per-doc dicts.""" |
| n_tokens = correct = n_non_o = non_o_correct = 0 |
| gold_spans_all: set = set() |
| pred_spans_all: set = set() |
| doc_offset = 0 |
| for doc in docs: |
| gold = [int(x) for x in doc["gold"]] |
| pred = [int(x) for x in doc[stream]] |
| n = len(gold) |
| n_tokens += n |
| for g, p in zip(gold, pred): |
| n_non_o += int(g != o_id) |
| if g == p: |
| correct += 1 |
| if g != o_id: |
| non_o_correct += 1 |
| gs = _bioes_to_spans(gold, id2label, o_id) |
| ps = _bioes_to_spans(pred, id2label, o_id) |
| gold_spans_all.update((doc_offset + s, doc_offset + e, c) for s, e, c in gs) |
| pred_spans_all.update((doc_offset + s, doc_offset + e, c) for s, e, c in ps) |
| doc_offset += n |
| span_m = _aggregate_span_metrics(gold_spans_all, pred_spans_all) |
| return { |
| "n_tokens": n_tokens, |
| "n_non_o": n_non_o, |
| "token_acc": correct / n_tokens if n_tokens else 0.0, |
| "non_o_recall": non_o_correct / n_non_o if n_non_o else 0.0, |
| "span_precision": span_m["precision"], |
| "span_recall": span_m["recall"], |
| "span_f1": span_m["f1"], |
| "n_gold_spans": span_m["n_gold"], |
| "n_pred_spans": span_m["n_pred"], |
| "n_correct_spans": span_m["n_correct"], |
| "span_per_cat": span_m["per_cat"], |
| } |
|
|
|
|
| # --------------------------------------------------------------------------- |
| # performance accounting |
| # --------------------------------------------------------------------------- |
|
|
| def _model_perf_stats(model, dtype) -> dict: |
| """Total / active / compute / MoE breakdown + on-device byte size. |
| |
| Three distinct param counts: |
| |
| * total_params — every parameter the model has on disk / in RAM. |
| * active_params_per_tok — params *touched* during one token's forward |
| pass (memory footprint per token). Counts |
| the embedding because the embedding row |
| IS read per token; counts only top-k of |
| num_experts MoE experts because routing is |
| sparse. |
| * compute_params_per_tok — params that contribute matmul FLOPs per |
| token. EXCLUDES the embedding table: |
| `embed_tokens.weight` is a gather (one row |
| read), not a matmul, so its FLOP cost is |
| negligible (~hidden_size ops vs the table |
| having ~vocab × hidden params). Counting |
| it via the standard "2 × params" matmul |
| approximation hugely inflates the apparent |
| GFLOP/token and compresses the ratio between |
| deep and shallow models. |
| |
| GFLOP/token is computed from `compute_params_per_tok`, not from |
| `active_params_per_tok`. This makes the metric reflect actual layer-wise |
| computational cost. |
| """ |
| cfg = model.config |
| num_experts = int(getattr(cfg, "num_local_experts", 1)) |
| top_k = int(getattr(cfg, "num_experts_per_tok", num_experts)) |
| expert_frac = top_k / max(1, num_experts) |
|
|
| moe_total = 0 |
| moe_active = 0 |
| other_total = 0 |
| embed_total = 0 |
| for name, p in model.named_parameters(): |
| n = p.numel() |
| # MoE expert weights are routed: only top_k of num_experts experts are |
| # touched per token, hence the expert_frac scaling for the active count. |
| if ".mlp.experts." in name: |
| moe_total += n |
| moe_active += int(round(n * expert_frac)) |
| |
| # The embedding table is counted in other_total (its rows are resident |
| # memory) and tracked separately so the FLOP estimate can exclude it. |
| elif "embed_tokens" in name: |
| embed_total += n |
| other_total += n |
| else: |
| other_total += n |
|
|
| total = moe_total + other_total |
| active_per_tok = moe_active + other_total |
|
|
| # FLOPs: the embedding lookup is a gather, not a matmul, so it is |
| # dropped from the compute count. |
| compute_per_tok = active_per_tok - embed_total |
| bytes_per_param = {torch.bfloat16: 2, torch.float16: 2, torch.float32: 4}.get(dtype, 4) |
| storage_bytes = total * bytes_per_param |
| |
| gflops_per_tok = 2 * compute_per_tok / 1e9 |
| return { |
| "total_params": total, |
| "moe_total": moe_total, |
| "moe_active_per_tok": moe_active, |
| "other_total": other_total, |
| "embed_total": embed_total, |
| "active_params_per_tok": active_per_tok, |
| "compute_params_per_tok": compute_per_tok, |
| "num_experts": num_experts, |
| "experts_per_tok": top_k, |
| "expert_frac": expert_frac, |
| "weight_bytes": storage_bytes, |
| "gflops_per_tok": gflops_per_tok, |
| } |
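
| # Worked example with hypothetical round numbers: a 50M embedding table, |
| # 100M of attention/norm/head weights and 8 experts of 25M each at top_k=2 |
| # give total = 350M, active/token = 50 + 100 + 50 = 200M, compute/token = |
| # 200 - 50 = 150M, hence 2 * 150e6 / 1e9 = 0.3 GFLOP/token. |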
|
|
|
|
| def _disk_size_bytes(model_path: str) -> int: |
| """Sum on-disk size of weight files at the given path. Falls back to 0 |
| if the path is a HF repo id (not a local directory).""" |
| p = Path(model_path) |
| if not p.is_dir(): |
| return 0 |
| total = 0 |
| for f in p.iterdir(): |
| if f.is_file() and f.suffix in {".safetensors", ".bin", ".pt", ".pth"}: |
| total += f.stat().st_size |
| return total |
|
|
|
|
| def _eval_one_model( |
| model_path: str, tokenizer, eval_ds, label2id, id2label, o_id, |
| bioes_trans, bioes_init, batch_size, max_length, device, dtype, |
| label: str, |
| ): |
| print(f"[eval] loading {label} from {model_path} ...", flush=True) |
| if torch.cuda.is_available() and device.type == "cuda": |
| torch.cuda.reset_peak_memory_stats(device) |
| torch.cuda.empty_cache() |
| mem_before = (torch.cuda.memory_allocated(device) |
| if torch.cuda.is_available() and device.type == "cuda" else 0) |
| t_load = time.time() |
| model = AutoModelForTokenClassification.from_pretrained( |
| model_path, dtype=dtype, trust_remote_code=True, |
| ).to(device).eval() |
| if hasattr(model.config, "use_viterbi_decode"): |
| model.config.use_viterbi_decode = True |
| if hasattr(model.config, "viterbi_replace_logits"): |
| model.config.viterbi_replace_logits = False |
| load_s = time.time() - t_load |
| perf = _model_perf_stats(model, dtype) |
| perf["disk_size_bytes"] = _disk_size_bytes(model_path) |
| weights_resident_bytes = ( |
| torch.cuda.memory_allocated(device) - mem_before |
| if torch.cuda.is_available() and device.type == "cuda" else perf["weight_bytes"] |
| ) |
| perf["weights_resident_bytes"] = weights_resident_bytes |
|
|
| # The BIOES Viterbi decoder ships with the checkpoint's modeling code. |
| from modeling_haremb_pii import _bioes_viterbi_batched |
|
|
| # Explicit None check: a pad_token_id of 0 is falsy but still valid. |
| pad_token_id = (tokenizer.pad_token_id |
| if tokenizer.pad_token_id is not None else 199999) |
| loader = DataLoader( |
| eval_ds, batch_size=batch_size, shuffle=False, |
| collate_fn=_make_collate(pad_token_id, max_length), num_workers=2, |
| ) |
| docs: List[dict] = [] |
| n_tok = 0 |
| if torch.cuda.is_available() and device.type == "cuda": |
| torch.cuda.synchronize() |
| torch.cuda.reset_peak_memory_stats(device) |
| t0 = time.time() |
| for batch in tqdm(loader, desc=f"eval {label}", unit="batch", leave=False): |
| ids = batch["input_ids"].to(device, non_blocking=True) |
| mask = batch["attention_mask"].to(device, non_blocking=True) |
| gold = batch["labels"].to(device, non_blocking=True) |
| with torch.no_grad(): |
| out = model(input_ids=ids, attention_mask=mask) |
| raw = out.logits.argmax(dim=-1) |
| vit = _bioes_viterbi_batched(out.logits.float(), mask, bioes_trans, bioes_init) |
| valid = (gold != -100) & mask.bool() |
| for b in range(gold.shape[0]): |
| keep = [i for i, ok in enumerate(valid[b].cpu().tolist()) if ok] |
| n_tok += len(keep) |
| docs.append({ |
| "gold": [int(gold[b, i].item()) for i in keep], |
| "raw": [int(raw[b, i].item()) for i in keep], |
| "viterbi": [int(vit[b, i].item()) for i in keep], |
| }) |
| if torch.cuda.is_available() and device.type == "cuda": |
| torch.cuda.synchronize() |
| peak_mem = torch.cuda.max_memory_allocated(device) |
| else: |
| peak_mem = 0 |
| eval_s = time.time() - t0 |
|
|
| raw_m = _stream_metrics(docs, "raw", id2label, o_id) |
| vit_m = _stream_metrics(docs, "viterbi", id2label, o_id) |
|
|
| perf["peak_eval_mem_bytes"] = peak_mem |
|
|
| del model |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
|
|
| return { |
| "label": label, |
| "n_total_M": perf["total_params"] / 1e6, |
| "load_s": load_s, |
| "eval_s": eval_s, |
| "n_tok": n_tok, |
| "throughput_tok_s": n_tok / eval_s if eval_s else 0.0, |
| "perf": perf, |
| "raw": raw_m, |
| "viterbi": vit_m, |
| "docs": docs, |
| } |
|
|
|
|
| # --------------------------------------------------------------------------- |
| # pairwise token-level comparison |
| # --------------------------------------------------------------------------- |
|
|
| def _pairwise(docs_cand, docs_ref, id2label, o_id): |
| both_correct = only_cand = only_ref = both_wrong = 0 |
| by_cat = defaultdict(lambda: {"both_correct": 0, "only_cand": 0, "only_ref": 0, "both_wrong": 0}) |
| for dc, dr in zip(docs_cand, docs_ref): |
| gold = dc["gold"] |
| cv = dc["viterbi"] |
| rv = dr["viterbi"] |
| for g, c, r in zip(gold, cv, rv): |
| cat = id2label.get(g, "O") |
| cat = cat.split("-", 1)[1] if "-" in cat else cat |
| cc = (c == g) |
| rc = (r == g) |
| if cc and rc: |
| both_correct += 1 |
| by_cat[cat]["both_correct"] += 1 |
| elif cc and not rc: |
| only_cand += 1 |
| by_cat[cat]["only_cand"] += 1 |
| elif rc and not cc: |
| only_ref += 1 |
| by_cat[cat]["only_ref"] += 1 |
| else: |
| both_wrong += 1 |
| by_cat[cat]["both_wrong"] += 1 |
| return { |
| "both_correct": both_correct, |
| "only_cand_correct": only_cand, |
| "only_ref_correct": only_ref, |
| "both_wrong": both_wrong, |
| "by_cat": dict(by_cat), |
| } |
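
| # Example: on a gold B-ssn token where B predicts B-ssn and A predicts O, |
| # the token counts as only_cand under by_cat["ssn"]; gold-O tokens land in |
| # by_cat["O"], which every non-O view downstream filters out. |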
|
|
|
|
| # --------------------------------------------------------------------------- |
| # plots |
| # --------------------------------------------------------------------------- |
|
|
| def _render_plots(cand, ref, pair, out_dir: Path, cand_label, ref_label): |
| """Render benchmark plots. |
| |
| Visual convention: |
| A = reference / teacher / baseline |
| B = candidate / this checkpoint |
| |
| The charts avoid color-only "win" encoding: labels state the actual delta |
| or ratio, and horizontal layouts keep long metric/category names readable. |
| """ |
| try: |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| from matplotlib.ticker import FuncFormatter |
| except ImportError: |
| print("[plot] matplotlib not installed, skipping", flush=True) |
| return |
|
|
| plt.rcParams.update({ |
| "figure.facecolor": "#ffffff", |
| "axes.facecolor": "#ffffff", |
| "axes.edgecolor": "#cbd5e1", |
| "axes.labelcolor": "#0f172a", |
| "xtick.color": "#334155", |
| "ytick.color": "#334155", |
| "grid.color": "#e2e8f0", |
| "font.size": 9, |
| "axes.titleweight": "bold", |
| "axes.titlesize": 11, |
| "legend.frameon": False, |
| }) |
|
|
| C_REF = "#64748b" |
| C_CAND = "#2563eb" |
| C_GOOD = "#0f766e" |
| C_BAD = "#b91c1c" |
| C_NEUTRAL = "#94a3b8" |
| C_BG = "#f8fafc" |
|
|
| def _pct_axis(ax): |
| ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _pos: f"{x:.0%}")) |
|
|
| def _metric_delta_text(delta): |
| return f"{delta:+.4f}" |
|
|
| def _value_text(v): |
| if v >= 100: |
| return f"{v:,.0f}" |
| if v >= 10: |
| return f"{v:,.1f}" |
| if v >= 1: |
| return f"{v:,.2f}" |
| return f"{v:,.3f}" |
|
|
| def _ratio_text(r, lower_is_better): |
| if r <= 0: |
| return "n/a" |
| if lower_is_better: |
| return f"{1.0 / r:.2f}x lower" if r <= 1 else f"{r:.2f}x higher" |
| return f"{r:.2f}x higher" if r >= 1 else f"{1.0 / r:.2f}x lower" |
|
|
| # figure 1: eval_summary.png, headline metrics + per-category span F1 |
| fig, axes = plt.subplots(2, 2, figsize=(8, 5), constrained_layout=True) |
| ax_head = axes[0, 0] |
| ax_delta = axes[0, 1] |
| ax_raw_vit = axes[1, 0] |
| ax_cat = axes[1, 1] |
|
|
| metrics = ["span_f1", "span_precision", "span_recall", "token_acc", "non_o_recall"] |
| labels = ["Span F1", "Span P", "Span R", "Token acc", "Non-O recall"] |
| cand_v = [cand["viterbi"][m] for m in metrics] |
| ref_v = [ref["viterbi"][m] for m in metrics] if ref is not None else None |
|
|
| y = np.arange(len(metrics)) |
| if ref_v is not None: |
| ax_head.hlines(y, ref_v, cand_v, color=C_NEUTRAL, linewidth=2, alpha=0.9) |
| ax_head.scatter(ref_v, y, s=55, color=C_REF, label=f"A: {ref_label}", zorder=3) |
| ax_head.scatter(cand_v, y, s=70, color=C_CAND, label=f"B: {cand_label}", zorder=4) |
| ax_head.set_yticks(y) |
| ax_head.set_yticklabels(labels) |
| ax_head.invert_yaxis() |
| ax_head.set_xlim(max(0.0, min(cand_v + (ref_v or cand_v)) - 0.02), |
| min(1.08, max(cand_v + (ref_v or cand_v)) + 0.05)) |
| ax_head.set_title("Headline metrics, Viterbi stream") |
| ax_head.grid(axis="x", alpha=0.7) |
| _pct_axis(ax_head) |
| if ref_v is not None: |
| for i, v in enumerate(ref_v): |
| ax_head.text(v + 0.002, i, f"{v:.4f}", va="center", ha="left", |
| fontsize=7, color=C_REF) |
| for i, v in enumerate(cand_v): |
| ax_head.text(v - 0.002, i, f"{v:.4f}", va="center", ha="right", |
| fontsize=7, color=C_CAND) |
|
|
| if ref_v is not None: |
| deltas = [b - a for a, b in zip(ref_v, cand_v)] |
| colors = [C_GOOD if d >= 0 else C_BAD for d in deltas] |
| ax_delta.axvline(0, color="#0f172a", linewidth=0.9) |
| ax_delta.barh(y, deltas, color=colors, alpha=0.9) |
| ax_delta.set_yticks(y) |
| ax_delta.set_yticklabels(labels) |
| ax_delta.invert_yaxis() |
| ax_delta.set_title("Delta: B minus A") |
| ax_delta.grid(axis="x", alpha=0.7) |
| max_abs = max([abs(d) for d in deltas] + [0.002]) |
| ax_delta.set_xlim(-max_abs * 1.45, max_abs * 1.45) |
| for i, d in enumerate(deltas): |
| ax_delta.text(d + (max_abs * 0.04 if d >= 0 else -max_abs * 0.04), i, _metric_delta_text(d), |
| ha="left" if d >= 0 else "right", va="center", fontsize=7) |
| else: |
| ax_delta.axis("off") |
|
|
| stream_rows = [ |
| ("A raw", ref["raw"]["span_f1"] if ref is not None else None, C_REF), |
| ("A viterbi", ref["viterbi"]["span_f1"] if ref is not None else None, C_REF), |
| ("B raw", cand["raw"]["span_f1"], C_CAND), |
| ("B viterbi", cand["viterbi"]["span_f1"], C_CAND), |
| ] |
| stream_rows = [r for r in stream_rows if r[1] is not None] |
| sy = np.arange(len(stream_rows)) |
| ax_raw_vit.barh(sy, [r[1] for r in stream_rows], color=[r[2] for r in stream_rows], alpha=0.88) |
| ax_raw_vit.set_yticks(sy) |
| ax_raw_vit.set_yticklabels([r[0] for r in stream_rows]) |
| ax_raw_vit.invert_yaxis() |
| ax_raw_vit.set_xlim(0, 1.08) |
| ax_raw_vit.set_title("Raw vs Viterbi span F1") |
| ax_raw_vit.grid(axis="x", alpha=0.7) |
| _pct_axis(ax_raw_vit) |
| for i, (_, v, _) in enumerate(stream_rows): |
| ax_raw_vit.text(v + 0.008, i, f"{v:.4f}", va="center", fontsize=7) |
|
|
| cand_pc = cand["viterbi"]["span_per_cat"] |
| if ref is not None: |
| ref_pc = ref["viterbi"]["span_per_cat"] |
| cats = sorted(set(cand_pc) | set(ref_pc)) |
| rows = [] |
| for c in cats: |
| a = ref_pc.get(c, {}).get("f1", 0.0) |
| b = cand_pc.get(c, {}).get("f1", 0.0) |
| n = max(cand_pc.get(c, {}).get("n_gold", 0), ref_pc.get(c, {}).get("n_gold", 0)) |
| rows.append((c, b - a, b, a, n)) |
| |
| worst = sorted(rows, key=lambda r: r[1])[:8] |
| best = sorted(rows, key=lambda r: r[1], reverse=True)[:8] |
| picked = worst + [r for r in best if r[0] not in {x[0] for x in worst}] |
| picked = sorted(picked, key=lambda r: r[1]) |
| cy = np.arange(len(picked)) |
| deltas = [r[1] for r in picked] |
| ax_cat.axvline(0, color="#0f172a", linewidth=0.9) |
| ax_cat.barh(cy, deltas, color=[C_GOOD if d >= 0 else C_BAD for d in deltas]) |
| ax_cat.set_yticks(cy) |
| ax_cat.set_yticklabels([r[0] for r in picked], fontsize=8) |
| ax_cat.set_title("Per-category span F1 delta, selected extremes") |
| ax_cat.grid(axis="x", alpha=0.7) |
| max_abs = max([abs(d) for d in deltas] + [0.05]) |
| ax_cat.set_xlim(-max_abs * 1.55, max_abs * 1.55) |
| for i, r in enumerate(picked): |
| d = r[1] |
| ax_cat.text(d + (max_abs * 0.05 if d >= 0 else -max_abs * 0.05), i, |
| f"{d:+.3f} B={r[2]:.2f} A={r[3]:.2f}", |
| va="center", ha="left" if d >= 0 else "right", fontsize=6) |
| else: |
| cats_sorted = sorted(cand_pc.keys(), key=lambda c: cand_pc[c]["f1"])[:18] |
| vals = [cand_pc[c]["f1"] for c in cats_sorted] |
| cy = np.arange(len(cats_sorted)) |
| ax_cat.barh(cy, vals, color=C_CAND) |
| ax_cat.set_yticks(cy) |
| ax_cat.set_yticklabels(cats_sorted, fontsize=8) |
| ax_cat.set_xlim(0, 1.0) |
| ax_cat.set_title("Lowest per-category span F1") |
| ax_cat.grid(axis="x", alpha=0.7) |
| _pct_axis(ax_cat) |
|
|
| fig.suptitle(f"Evaluation summary — A: {ref_label if ref else 'n/a'} | B: {cand_label}", |
| fontsize=9, fontweight="bold") |
| fig.savefig(out_dir / "eval_summary.png", dpi=160) |
| plt.close(fig) |
| print(f"[plot] wrote {out_dir / 'eval_summary.png'}", flush=True) |
|
|
| # figure 2: eval_confusion.png, token-level pairwise outcomes on gold |
| # non-O positions plus per-category net wins |
| if pair is not None: |
| fig, axes = plt.subplots( |
| 1, 2, figsize=(8, 3), constrained_layout=True, |
| gridspec_kw={"width_ratios": [0.9, 1.7]}, |
| ) |
| ax = axes[0] |
| non_o_buckets = {k: 0 for k in ["both_correct", "only_cand", "only_ref", "both_wrong"]} |
| for cat, d in pair["by_cat"].items(): |
| if cat == "O": |
| continue |
| for k in non_o_buckets: |
| non_o_buckets[k] += d[k] |
| values = [non_o_buckets["both_correct"], non_o_buckets["only_ref"], |
| non_o_buckets["only_cand"], non_o_buckets["both_wrong"]] |
| labels_ = [ |
| "Both\ncorrect", |
| "Only A\ncorrect", |
| "Only B\ncorrect", |
| "Both wrong", |
| ] |
| colors = [C_GOOD, C_REF, C_CAND, C_BAD] |
| total = max(1, sum(values)) |
| ax.barh(np.arange(4), values, color=colors) |
| ax.set_yticks(np.arange(4)) |
| ax.set_yticklabels(labels_) |
| ax.invert_yaxis() |
| ax.set_ylabel("Gold non-O tokens") |
| ax.set_title("Token outcome on gold non-O") |
| ax.grid(axis="x", alpha=0.7) |
| ax.set_xlim(0, max(values) * 1.32 if values else 1) |
| for i, v in enumerate(values): |
| ax.text(v + max(values) * 0.015, i, f"{v:,} ({v / total:.1%})", |
| va="center", fontsize=6) |
|
|
| rows = [] |
| for cat, d in pair["by_cat"].items(): |
| if cat == "O": |
| continue |
| net = d["only_cand"] - d["only_ref"] |
| active = d["only_cand"] + d["only_ref"] + d["both_wrong"] |
| if active: |
| rows.append((cat, net, d["only_cand"], d["only_ref"], d["both_wrong"])) |
| worst = sorted(rows, key=lambda r: r[1])[:8] |
| best = sorted(rows, key=lambda r: r[1], reverse=True)[:8] |
| picked = worst + [r for r in best if r[0] not in {x[0] for x in worst}] |
| picked = sorted(picked, key=lambda r: r[1]) |
| ax2 = axes[1] |
| if picked: |
| py = np.arange(len(picked)) |
| nets = [r[1] for r in picked] |
| ax2.axvline(0, color="#0f172a", linewidth=0.9) |
| ax2.barh(py, nets, color=[C_GOOD if n >= 0 else C_BAD for n in nets]) |
| ax2.set_yticks(py) |
| ax2.set_yticklabels([r[0] for r in picked], fontsize=8) |
| ax2.set_title("Net token wins by category: B only-correct minus A only-correct") |
| ax2.grid(axis="x", alpha=0.7) |
| max_abs = max([abs(n) for n in nets] + [1]) |
| ax2.set_xlim(-max_abs * 1.5, max_abs * 1.5) |
| for i, r in enumerate(picked): |
| n = r[1] |
| label_x = n + max_abs * 0.04 if n >= 0 else max_abs * 0.05 |
| ax2.text(label_x, i, |
| f"{n:+d} B={r[2]} A={r[3]} W={r[4]}", |
| va="center", ha="left", fontsize=6) |
| else: |
| ax2.axis("off") |
|
|
| fig.suptitle(f"Pairwise correctness — A: {ref_label} | B: {cand_label}", |
| fontsize=9, fontweight="bold") |
| fig.savefig(out_dir / "eval_confusion.png", dpi=160) |
| plt.close(fig) |
| print(f"[plot] wrote {out_dir / 'eval_confusion.png'}", flush=True) |
|
|
| # figure 3: eval_performance.png, footprint and speed, absolute + ratios |
| if ref is not None and "perf" in cand and "perf" in ref: |
| cp, rp = cand["perf"], ref["perf"] |
|
|
| fig, axes = plt.subplots(1, 2, figsize=(8.5, 5), constrained_layout=True) |
| ax_abs = axes[0] |
| ax_ratio = axes[1] |
|
|
| metrics = [ |
| ("Total params (M)", cp["total_params"]/1e6, rp["total_params"]/1e6, True), |
| ("Active params/tok (M)", cp["active_params_per_tok"]/1e6, rp["active_params_per_tok"]/1e6, True), |
| ("MoE expert params (M)", cp["moe_total"]/1e6, rp["moe_total"]/1e6, True), |
| ("GFLOP/token", cp["gflops_per_tok"], rp["gflops_per_tok"], True), |
| ("Weights RAM (MiB)", cp["weight_bytes"]/(1<<20), rp["weight_bytes"]/(1<<20), True), |
| ("Peak eval mem (MiB)", cp["peak_eval_mem_bytes"]/(1<<20), rp["peak_eval_mem_bytes"]/(1<<20), True), |
| ("Throughput (tok/s)", cand["throughput_tok_s"], ref["throughput_tok_s"], False), |
| ] |
|
|
| y = np.arange(len(metrics)) |
| w = 0.38 |
| cand_v = [m[1] for m in metrics] |
| ref_v = [m[2] for m in metrics] |
| ax_abs.barh(y - w/2, ref_v, w, label=f"A: {ref_label}", color=C_REF) |
| ax_abs.barh(y + w/2, cand_v, w, label=f"B: {cand_label}", color=C_CAND) |
| ax_abs.set_yticks(y) |
| ax_abs.set_yticklabels([m[0] for m in metrics], fontsize=8) |
| ax_abs.invert_yaxis() |
| ax_abs.set_xscale("log") |
| ax_abs.set_title("Absolute footprint and speed, log scale") |
| ax_abs.grid(axis="x", which="both", alpha=0.7) |
| positive_vals = [v for v in (ref_v + cand_v) if v > 0] |
| if positive_vals: |
| ax_abs.set_xlim(min(positive_vals) * 0.55, max(positive_vals) * 3.8) |
| for yi, vals in enumerate(zip(ref_v, cand_v)): |
| for off, v, col in [(-w/2, vals[0], C_REF), (w/2, vals[1], C_CAND)]: |
| if v <= 0: |
| continue |
| ax_abs.text(v * 1.05, yi + off, _value_text(v), va="center", fontsize=6, color=col) |
|
|
| ratios = [(m[1] / max(1e-12, m[2])) for m in metrics] |
| lower_better = [m[3] for m in metrics] |
| colors = [ |
| C_GOOD if ((lb and r <= 1.0) or ((not lb) and r >= 1.0)) else C_BAD |
| for r, lb in zip(ratios, lower_better) |
| ] |
| ax_ratio.axvline(1.0, color="#0f172a", linestyle="--", linewidth=0.9, alpha=0.65) |
| ax_ratio.barh(y, ratios, color=colors) |
| ax_ratio.set_yticks(y) |
| ax_ratio.set_yticklabels([m[0] for m in metrics], fontsize=8) |
| ax_ratio.invert_yaxis() |
| ax_ratio.set_xscale("log") |
| ax_ratio.set_title("B / A ratio with explicit direction") |
| ax_ratio.grid(axis="x", which="both", alpha=0.7) |
| positive_ratios = [r for r in ratios if r > 0] |
| if positive_ratios: |
| ax_ratio.set_xlim(min(positive_ratios) * 0.38, max(positive_ratios) * 2.8) |
| for i, (r, lb) in enumerate(zip(ratios, lower_better)): |
| ax_ratio.text(r * (1.05 if r >= 1 else 0.95), i, _ratio_text(r, lb), |
| va="center", ha="left" if r >= 1 else "right", fontsize=6) |
|
|
| fig.suptitle(f"Performance profile — A: {ref_label} | B: {cand_label}", |
| fontsize=9, fontweight="bold") |
| fig.savefig(out_dir / "eval_performance.png", dpi=160) |
| plt.close(fig) |
| print(f"[plot] wrote {out_dir / 'eval_performance.png'}", flush=True) |
|
|
|
|
| # --------------------------------------------------------------------------- |
| # logs |
| # --------------------------------------------------------------------------- |
|
|
| def _fmt_metrics(m): |
| return (f"span_F1={m['span_f1']:.4f} P={m['span_precision']:.4f} " |
| f"R={m['span_recall']:.4f} token_acc={m['token_acc']:.4f} " |
| f"non_o_recall={m['non_o_recall']:.4f} " |
| f"spans={m['n_gold_spans']}/{m['n_pred_spans']}/{m['n_correct_spans']}") |
|
|
|
|
| def _write_compare_log(path: Path, cand, ref, pair, args): |
| lines: List[str] = [] |
| A = lines.append |
| A(f"Benchmark: A: {ref['label']} vs B: {cand['label']}" |
| if ref else f"Benchmark: {cand['label']}") |
| A(f"Dataset: {SOURCE_DATASET}, split=test, eval_pct={args.eval_pct}, " |
| f"ctx={args.max_length}, seed={args.seed}, n_docs={args.n_docs}") |
| A(f"Eval tokens scored: {cand['n_tok']:,}") |
| A("") |
| A("=== Aggregate ===") |
| if ref is not None: |
| A(f" A: {ref['label']:<25s} RAW {_fmt_metrics(ref['raw'])}") |
| A(f" A: {ref['label']:<25s} VITERBI {_fmt_metrics(ref['viterbi'])}") |
| A(f" B: {cand['label']:<25s} RAW {_fmt_metrics(cand['raw'])}") |
| A(f" B: {cand['label']:<25s} VITERBI {_fmt_metrics(cand['viterbi'])}") |
| if ref is not None: |
| d_f1 = ref["viterbi"]["span_f1"] - cand["viterbi"]["span_f1"] |
| A("") |
| A(f"Gap B vs A (viterbi span_F1): {-d_f1:+.4f}") |
| A("") |
| if ref is not None: |
| A(f"Throughput: A: {ref['label']} {ref['throughput_tok_s']:.0f} tok/s " |
| f"({ref['n_total_M']:.2f}M params)") |
| A(f" B: {cand['label']} {cand['throughput_tok_s']:.0f} tok/s " |
| f"({cand['n_total_M']:.2f}M params)") |
| else: |
| A(f"Throughput: {cand['label']} {cand['throughput_tok_s']:.0f} tok/s " |
| f"({cand['n_total_M']:.2f}M params)") |
| A("") |
| # Plain-English ratio column: each row states its direction explicitly |
| # ("smaller" / "cheaper" / "less" / "faster"), so the reader never has to |
| # remember whether B/A > 1 is good for that particular metric. |
| |
| def _fmt_vs(b: float, a: float, kind: str) -> str: |
| if a is None or a == 0 or b is None or b == 0: |
| return "" |
| ratio = b / a |
| if kind == "size": |
| if ratio <= 1.0: |
| return f"{1.0/ratio:.2f}× smaller" |
| return f"{ratio:.2f}× larger" |
| if kind == "compute": |
| if ratio <= 1.0: |
| return f"{1.0/ratio:.2f}× cheaper" |
| return f"{ratio:.2f}× more" |
| if kind == "memory": |
| if ratio <= 1.0: |
| return f"{1.0/ratio:.2f}× less" |
| return f"{ratio:.2f}× more" |
| if kind == "speed": |
| if ratio >= 1.0: |
| return f"{ratio:.2f}× faster" |
| return f"{1.0/ratio:.2f}× slower" |
| return f"{ratio:.2f}×" |
|
|
| cp = cand["perf"] |
| A("=== Performance ===") |
| headers = ["metric", |
| f"A: {ref['label']}" if ref else "", |
| f"B: {cand['label']}", |
| "B vs A"] |
| rp = ref["perf"] if ref else None |
| rows = [ |
| ["total params (M)", |
| (f"{rp['total_params']/1e6:.2f}" if ref else ""), |
| f"{cp['total_params']/1e6:.2f}", |
| (_fmt_vs(cp['total_params'], rp['total_params'], "size") if ref else "")], |
| ["dense params (M)", |
| (f"{rp['other_total']/1e6:.2f}" if ref else ""), |
| f"{cp['other_total']/1e6:.2f}", |
| (_fmt_vs(cp['other_total'], rp['other_total'], "size") if ref else "")], |
| ["MoE expert params (M)", |
| (f"{rp['moe_total']/1e6:.2f}" if ref else ""), |
| f"{cp['moe_total']/1e6:.2f}", |
| (_fmt_vs(cp['moe_total'], rp['moe_total'], "size") if ref else "")], |
| [f"active params/token (M, mem)", |
| (f"{rp['active_params_per_tok']/1e6:.2f}" if ref else ""), |
| f"{cp['active_params_per_tok']/1e6:.2f}", |
| (_fmt_vs(cp['active_params_per_tok'], rp['active_params_per_tok'], "memory") if ref else "")], |
| [f"compute params/token (M, FLOPs)", |
| (f"{rp['compute_params_per_tok']/1e6:.2f}" if ref else ""), |
| f"{cp['compute_params_per_tok']/1e6:.2f}", |
| (_fmt_vs(cp['compute_params_per_tok'], rp['compute_params_per_tok'], "compute") if ref else "")], |
| ["GFLOP / token", |
| (f"{rp['gflops_per_tok']:.4f}" if ref else ""), |
| f"{cp['gflops_per_tok']:.4f}", |
| (_fmt_vs(cp['gflops_per_tok'], rp['gflops_per_tok'], "compute") if ref else "")], |
| ["disk size (MiB)", |
| (f"{rp['disk_size_bytes']/(1<<20):.1f}" if ref and rp['disk_size_bytes'] else ""), |
| f"{cp['disk_size_bytes']/(1<<20):.1f}", |
| (_fmt_vs(cp['disk_size_bytes'], rp['disk_size_bytes'], "size") if ref and rp['disk_size_bytes'] else "")], |
| ["weights in RAM (MiB)", |
| (f"{rp['weight_bytes']/(1<<20):.1f}" if ref else ""), |
| f"{cp['weight_bytes']/(1<<20):.1f}", |
| (_fmt_vs(cp['weight_bytes'], rp['weight_bytes'], "size") if ref else "")], |
| ["peak GPU mem eval (MiB)", |
| (f"{rp['peak_eval_mem_bytes']/(1<<20):.1f}" if ref else ""), |
| f"{cp['peak_eval_mem_bytes']/(1<<20):.1f}", |
| (_fmt_vs(cp['peak_eval_mem_bytes'], rp['peak_eval_mem_bytes'], "memory") if ref else "")], |
| ["throughput (tok/s)", |
| (f"{ref['throughput_tok_s']:.0f}" if ref else ""), |
| f"{cand['throughput_tok_s']:.0f}", |
| (_fmt_vs(cand['throughput_tok_s'], ref['throughput_tok_s'], "speed") if ref else "")], |
| ] |
| widths = [max(len(r[i]) for r in [headers] + rows) for i in range(4)] |
| sep = " " + " ".join("-" * w for w in widths) |
| A(" " + " ".join(h.ljust(widths[i]) for i, h in enumerate(headers))) |
| A(sep) |
| for r in rows: |
| A(" " + " ".join(r[i].ljust(widths[i]) for i in range(4))) |
| A("") |
| if pair is not None: |
| # token-level pairwise outcome tables |
| a_lbl = ref["label"] if ref is not None else "A" |
| b_lbl = cand["label"] |
| A(f"=== Pairwise (viterbi, all gold tokens) — A: {a_lbl} vs B: {b_lbl} ===") |
| total = (pair["both_correct"] + pair["only_cand_correct"] |
| + pair["only_ref_correct"] + pair["both_wrong"]) |
| |
| rows_all = [ |
| ("both_correct", pair["both_correct"]), |
| ("only_A_correct", pair["only_ref_correct"]), |
| ("only_B_correct", pair["only_cand_correct"]), |
| ("both_wrong", pair["both_wrong"]), |
| ] |
| for k, v in rows_all: |
| A(f" {k:<26s} {v:8d} ({100.0*v/total:.2f}%)") |
| A("") |
| A(f"=== Pairwise (viterbi, gold non-O tokens) — A: {a_lbl} vs B: {b_lbl} ===") |
| non_o = {k: 0 for k in ["both_correct", "only_cand", "only_ref", "both_wrong"]} |
| for cat, d in pair["by_cat"].items(): |
| if cat == "O": |
| continue |
| for k in non_o: |
| non_o[k] += d[k] |
| total_non_o = sum(non_o.values()) |
| rows_non_o = [ |
| ("both_correct", non_o["both_correct"]), |
| ("only_A_correct", non_o["only_ref"]), |
| ("only_B_correct", non_o["only_cand"]), |
| ("both_wrong", non_o["both_wrong"]), |
| ] |
| for k, v in rows_non_o: |
| A(f" {k:<26s} {v:8d} ({100.0*v/total_non_o:.2f}%)" if total_non_o else f" {k}: 0") |
| A("") |
| # Per-category net wins: positive net_B means B is ahead on that category. |
| nets = [] |
| for cat, d in pair["by_cat"].items(): |
| if cat == "O": |
| continue |
| nets.append((cat, d["only_cand"] - d["only_ref"], |
| d["only_cand"], d["only_ref"], d["both_wrong"])) |
| nets.sort(key=lambda x: x[1]) |
| A(f"=== Worst B-net wins by gold category — A: {a_lbl} ahead (top 15) ===") |
| for cat, net, ob, oa, bw in nets[:15]: |
| A(f" {cat:<32s} net_B={net:+5d} A_only={oa:4d} B_only={ob:4d} both_wrong={bw:4d}") |
| A("") |
| A(f"=== Best B-net wins by gold category — B: {b_lbl} ahead (top 15) ===") |
| for cat, net, ob, oa, bw in nets[::-1][:15]: |
| A(f" {cat:<32s} net_B={net:+5d} A_only={oa:4d} B_only={ob:4d} both_wrong={bw:4d}") |
| A("") |
| A("=== Per-category span F1 (viterbi) ===") |
| if ref is not None: |
| A(f" -- A: {ref['label']} --") |
| per_r = ref["viterbi"]["span_per_cat"] |
| for cat in sorted(per_r): |
| c = per_r[cat] |
| A(f" {cat:<32s} F1={c['f1']:.4f} P={c['precision']:.4f} R={c['recall']:.4f} " |
| f"({c['n_gold']}/{c['n_pred']}/{c['n_correct']})") |
| A(f" -- B: {cand['label']} --") |
| per = cand["viterbi"]["span_per_cat"] |
| for cat in sorted(per): |
| c = per[cat] |
| A(f" {cat:<32s} F1={c['f1']:.4f} P={c['precision']:.4f} R={c['recall']:.4f} " |
| f"({c['n_gold']}/{c['n_pred']}/{c['n_correct']})") |
| path.write_text("\n".join(lines) + "\n") |
| print(f"[log] wrote {path}", flush=True) |
|
|
|
|
| def _fmt_bytes(n: int) -> str: |
| if n <= 0: |
| return "—" |
| if n >= 1 << 30: |
| return f"{n / (1 << 30):.2f} GiB" |
| if n >= 1 << 20: |
| return f"{n / (1 << 20):.1f} MiB" |
| return f"{n / (1 << 10):.1f} KiB" |
|
|
|
|
| def _perf_block(stream, ctx: int) -> List[str]: |
| p = stream["perf"] |
| out = [ |
| f" total params : {p['total_params']/1e6:>9.2f}M " |
| f"({p['other_total']/1e6:.2f}M dense + {p['moe_total']/1e6:.2f}M MoE-experts)", |
| f" active params / token : {p['active_params_per_tok']/1e6:>9.2f}M " |
| f"(memory footprint — embed lookup + top_{p['experts_per_tok']}/{p['num_experts']} experts: " |
| f"{p['embed_total']/1e6:.2f}M embed + " |
| f"{p['moe_active_per_tok']/1e6:.2f}M MoE-active + " |
| f"{(p['other_total']-p['embed_total'])/1e6:.2f}M attn/norm/head)", |
| f" compute params / token : {p['compute_params_per_tok']/1e6:>9.2f}M " |
| f"(matmul FLOPs only — embedding lookup excluded)", |
| f" GFLOP / token (fwd, MAC×2): {p['gflops_per_tok']:>9.3f}", |
| f" weights size (on disk) : {_fmt_bytes(p['disk_size_bytes']):>9s}", |
| f" weights size (in RAM) : {_fmt_bytes(p['weight_bytes']):>9s}", |
| f" weights resident (GPU) : {_fmt_bytes(p['weights_resident_bytes']):>9s}", |
| f" peak GPU mem (eval, ctx={ctx}) : {_fmt_bytes(p['peak_eval_mem_bytes']):>9s}", |
| ] |
| return out |
|
|
|
|
| def _write_infer_log(path: Path, cand, ref, args, sample_text: str, tokenizer, device, dtype): |
| """Single-doc inference example + timing + performance metrics.""" |
|
|
| lines: List[str] = [] |
| A = lines.append |
| A(f"Inference benchmark: A: {ref['label']} vs B: {cand['label']}" |
| if ref else f"Inference benchmark: {cand['label']}") |
| A(f" device : {device} dtype: {dtype}") |
| A(f" ctx : {args.max_length}") |
| A("") |
| if ref is not None: |
| A(f"A: {ref['label']} (reference / teacher)") |
| A(f" load : {ref['load_s']:.2f}s") |
| A(f" eval : {ref['eval_s']:.2f}s on {ref['n_tok']:,} tokens " |
| f"({ref['throughput_tok_s']:.0f} tok/s)") |
| A("Performance:") |
| for ln in _perf_block(ref, args.max_length): |
| A(ln) |
| A("") |
| A(f"B: {cand['label']}" + (" (this checkpoint)" if ref else "")) |
| A(f" load : {cand['load_s']:.2f}s") |
| A(f" eval : {cand['eval_s']:.2f}s on {cand['n_tok']:,} tokens " |
| f"({cand['throughput_tok_s']:.0f} tok/s)") |
| A("Performance:") |
| for ln in _perf_block(cand, args.max_length): |
| A(ln) |
| A("") |
| if ref is not None: |
| cp, rp = cand["perf"], ref["perf"] |
|
|
| def _fmt(b, a, kind): |
| if a is None or a == 0 or b is None or b == 0: |
| return "—" |
| r = b / a |
| if kind == "size": |
| return f"{1.0/r:.2f}× smaller" if r <= 1.0 else f"{r:.2f}× larger" |
| if kind == "compute": |
| return f"{1.0/r:.2f}× cheaper" if r <= 1.0 else f"{r:.2f}× more" |
| if kind == "memory": |
| return f"{1.0/r:.2f}× less" if r <= 1.0 else f"{r:.2f}× more" |
| if kind == "speed": |
| return f"{r:.2f}× faster" if r >= 1.0 else f"{1.0/r:.2f}× slower" |
| return f"{r:.2f}×" |
|
|
| A(f"B vs A ({cand['label']} vs {ref['label']}):") |
| A(f" total params : {_fmt(cp['total_params'], rp['total_params'], 'size')}") |
| A(f" active params / token : {_fmt(cp['active_params_per_tok'], rp['active_params_per_tok'], 'memory')} [memory]") |
| A(f" compute params / token : {_fmt(cp['compute_params_per_tok'], rp['compute_params_per_tok'], 'compute')} [FLOPs]") |
| A(f" GFLOP / token : {_fmt(cp['gflops_per_tok'], rp['gflops_per_tok'], 'compute')}") |
| if rp['disk_size_bytes']: |
| A(f" weights size (on disk) : {_fmt(cp['disk_size_bytes'], rp['disk_size_bytes'], 'size')}") |
| else: |
| A(f" weights size (on disk) : —") |
| A(f" weights in RAM : {_fmt(cp['weight_bytes'], rp['weight_bytes'], 'size')}") |
| A(f" peak GPU mem (eval) : {_fmt(cp['peak_eval_mem_bytes'], rp['peak_eval_mem_bytes'], 'memory')}") |
| A(f" throughput : {_fmt(cand['throughput_tok_s'], ref['throughput_tok_s'], 'speed')}") |
| A("") |
|
|
| A("Sample inference (load → tokenize → forward → viterbi-decode → spans):") |
| A(f" text: {sample_text!r}") |
| model = AutoModelForTokenClassification.from_pretrained( |
| args.model_path, dtype=dtype, trust_remote_code=True, |
| ).to(device).eval() |
| if hasattr(model.config, "viterbi_replace_logits"): |
| model.config.viterbi_replace_logits = True |
| enc = tokenizer(sample_text, return_tensors="pt", truncation=True, |
| max_length=args.max_length).to(device) |
| with torch.no_grad(): |
| if torch.cuda.is_available(): |
| torch.cuda.synchronize() |
| t0 = time.time() |
| out = model(**enc) |
| if torch.cuda.is_available(): |
| torch.cuda.synchronize() |
| dt = time.time() - t0 |
| label2id, id2label = nemotron_native_label_space() |
| pred = out.logits.argmax(-1)[0].cpu().tolist() |
| spans = _bioes_to_spans(pred, id2label, 0) |
| A(f" forward latency: {dt*1000:.1f}ms ({enc.input_ids.shape[1]} tokens)") |
| A(f" detected {len(spans)} spans:") |
| tok_ids = enc.input_ids[0].cpu().tolist() |
| for s, e, cat in sorted(spans): |
| text = tokenizer.decode(tok_ids[s:e]).strip() |
| A(f" [{s:3d}, {e:3d}) {cat:<28s} {text!r}") |
| del model |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| path.write_text("\n".join(lines) + "\n") |
| print(f"[log] wrote {path}", flush=True) |
|
|
|
|
| # --------------------------------------------------------------------------- |
| # entry point |
| # --------------------------------------------------------------------------- |
|
|
| def main(): |
| p = argparse.ArgumentParser(description="Benchmark haremb-privacy-filter-opennemo") |
| p.add_argument("--device", default=None, |
| help="cuda or cpu. Default: cuda if available.") |
| p.add_argument("--dtype", default="bfloat16", |
| choices=["bfloat16", "float16", "float32"]) |
| p.add_argument("--eval-pct", type=float, default=1.0, |
| help="Percent of nvidia/Nemotron-PII test split to use. Default 1%%.") |
| p.add_argument("--eval-chunk-size", type=int, default=10_000) |
| p.add_argument("--seed", type=int, default=42) |
| p.add_argument("--max-length", type=int, default=1024) |
| p.add_argument("--batch-size", type=int, default=4) |
| p.add_argument("--out", type=str, default=".") |
| p.add_argument("--model-path", type=str, default=".", |
| help="Path to this checkpoint. Default: ./ (this folder).") |
| p.add_argument("--no-base", action="store_true", |
| help="Skip the OpenMed teacher comparison.") |
| p.add_argument("--no-plots", action="store_true", |
| help="Skip rendering eval_summary.png / eval_confusion.png.") |
| args = p.parse_args() |
|
|
| out_dir = Path(args.out).resolve() |
| out_dir.mkdir(parents=True, exist_ok=True) |
|
|
| if args.device is None: |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| else: |
| device = torch.device(args.device) |
| dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, |
| "float32": torch.float32}[args.dtype] |
|
|
| print(f"[setup] device={device} dtype={dtype} model={args.model_path} " |
| f"out={out_dir}", flush=True) |
|
|
| label2id, id2label = nemotron_native_label_space() |
| o_id = label2id["O"] |
|
|
| tokenizer = AutoTokenizer.from_pretrained(args.model_path) |
|
|
| # Data: sample the eval slice exactly like the training-time split. |
| print(f"[data] loading {SOURCE_DATASET} ...", flush=True) |
| ds = load_dataset(SOURCE_DATASET) |
| target_eval = max(1, int(round(len(ds["test"]) * args.eval_pct / 100.0))) |
| eval_indices = _build_eval_streaming( |
| ds["test"], target_n=target_eval, |
| chunk_size=args.eval_chunk_size, seed=args.seed, |
| ) |
| eval_ds = _NemotronEvalDataset( |
| ds["test"].select(eval_indices), tokenizer, label2id, args.max_length, |
| ) |
| args.n_docs = len(eval_ds) |
| print(f"[data] eval={args.n_docs:,} docs ({args.eval_pct:.2f}% of test split)", |
| flush=True) |
|
|
| # BIOES transition / initial masks for Viterbi, from the modeling code. |
| from modeling_haremb_pii import ( |
| _build_bioes_initial_mask as _bld_init, |
| _build_bioes_transition_mask as _bld_trans, |
| ) |
| bioes_trans = _bld_trans(id2label).to(device).float() |
| bioes_init = _bld_init(id2label).to(device).float() |
|
|
| # B: evaluate this checkpoint. |
| cand = _eval_one_model( |
| args.model_path, tokenizer, eval_ds, label2id, id2label, o_id, |
| bioes_trans, bioes_init, |
| args.batch_size, args.max_length, device, dtype, |
| label="haremb", |
| ) |
|
|
| print(f"\n=== {cand['label']} ===") |
| print(f"RAW {_fmt_metrics(cand['raw'])}") |
| print(f"VITERBI {_fmt_metrics(cand['viterbi'])}") |
| print(f"DELTA span_F1={cand['viterbi']['span_f1']-cand['raw']['span_f1']:+.4f} " |
| f"P={cand['viterbi']['span_precision']-cand['raw']['span_precision']:+.4f} " |
| f"R={cand['viterbi']['span_recall']-cand['raw']['span_recall']:+.4f}") |
|
|
| ref = None |
| pair = None |
| if not args.no_base: |
| ref = _eval_one_model( |
| TEACHER, tokenizer, eval_ds, label2id, id2label, o_id, |
| bioes_trans, bioes_init, |
| args.batch_size, args.max_length, device, dtype, |
| label="openmed-base", |
| ) |
| print(f"\n=== {ref['label']} (teacher) ===") |
| print(f"VITERBI {_fmt_metrics(ref['viterbi'])}") |
| pair = _pairwise(cand["docs"], ref["docs"], id2label, o_id) |
| d = ref["viterbi"]["span_f1"] - cand["viterbi"]["span_f1"] |
| print(f"\nGap to teacher (viterbi span_F1): {d:+.4f}") |
|
|
| # Logs and plots. |
| sample_text = ("Patient Sarah Johnson (DOB 03/15/1985), MRN 4872910, " |
| "phone 415-555-0123, email sarah.johnson@example.com, " |
| "credit card 4111-1111-1111-1111.") |
| _write_infer_log(out_dir / "infer.log", cand, ref, args, sample_text, |
| tokenizer, device, dtype) |
| _write_compare_log(out_dir / "compare.log", cand, ref, pair, args) |
| if not args.no_plots: |
| _render_plots(cand, ref, pair, out_dir, |
| cand_label=cand["label"], |
| ref_label=ref["label"] if ref else "") |
| print("\n[done]") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|