""" benchmark.py — single self-contained reproducibility script for haremb-privacy-filter-opennemo. Run from inside this folder: python benchmark.py # default: cuda if available python benchmark.py --device cpu # cpu fallback python benchmark.py --eval-pct 0.5 # smaller slice python benchmark.py --no-base # skip teacher download Produces, in `--out` (default ./): infer.log — sample inference timing + redaction example compare.log — aggregate + per-category metrics, this model vs OpenMed teacher (raw + viterbi streams), and token-level pairwise breakdown. eval_summary.png — bar charts of headline metrics + per-category span-F1 (this vs teacher). eval_confusion.png — token-level outcome breakdown on gold non-O positions (this vs teacher). eval_performance.png — model-size / compute / memory / throughput comparison (this vs teacher), absolute + ratios. This script does not import from training code. It vendors the small set of helpers it needs (BIOES decoder, span builder, eval-set sampler, metrics aggregator) so the model folder is self-contained. """ from __future__ import annotations import argparse import ast import math import os import sys import time from collections import Counter, defaultdict from pathlib import Path from typing import Dict, List, Tuple import numpy as np import torch from datasets import load_dataset from torch.utils.data import DataLoader, Dataset from tqdm.auto import tqdm from transformers import AutoModelForTokenClassification, AutoTokenizer # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- SOURCE_DATASET = "nvidia/Nemotron-PII" TEACHER = "OpenMed/privacy-filter-nemotron" # 55 Nemotron-PII categories, alphabetically sorted (matches the order used # when the model was trained, so id2label / label2id round-trip cleanly). NEMOTRON_CATEGORIES: List[str] = sorted([ "account_number", "age", "api_key", "bank_routing_number", "biometric_identifier", "blood_type", "certificate_license_number", "city", "company_name", "coordinate", "country", "county", "credit_debit_card", "customer_id", "cvv", "date", "date_of_birth", "date_time", "device_identifier", "education_level", "email", "employee_id", "employment_status", "fax_number", "first_name", "gender", "health_plan_beneficiary_number", "http_cookie", "ipv4", "ipv6", "language", "last_name", "license_plate", "mac_address", "medical_record_number", "national_id", "occupation", "password", "phone_number", "pin", "political_view", "postcode", "race_ethnicity", "religious_belief", "sexuality", "ssn", "state", "street_address", "swift_bic", "tax_id", "time", "unique_id", "url", "user_name", "vehicle_identifier", ]) def nemotron_native_label_space() -> Tuple[Dict[str, int], Dict[int, str]]: """O at id 0, then {B, I, E, S}-{cat} for each cat in alphabetical order.""" label2id: Dict[str, int] = {"O": 0} nxt = 1 for cat in NEMOTRON_CATEGORIES: for prefix in ("B", "I", "E", "S"): label2id[f"{prefix}-{cat}"] = nxt nxt += 1 id2label: Dict[int, str] = {v: k for k, v in label2id.items()} return label2id, id2label # --------------------------------------------------------------------------- # Span parsing + char-level token alignment (vendored from the training data # pipeline; identical logic, no training imports) # --------------------------------------------------------------------------- def _trim_span(text: str, start: int, end: int) -> Tuple[int, int]: raw = text[start:end] i = 0 while i < len(raw) and raw[i].isspace(): i += 1 j = len(raw) while j > i and (raw[j - 1].isspace() or raw[j - 1] in ".,;:)"): j -= 1 return start + i, start + j def _parse_spans(spans_str) -> List[dict]: if isinstance(spans_str, list): return spans_str try: return ast.literal_eval(spans_str) except (SyntaxError, ValueError): return [] def _assign_native_bioes_labels( text: str, raw_spans: List[dict], tokenizer, max_length: int, label2id: Dict[str, int], min_overlap_frac: float = 0.5, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: pf_like: List[Tuple[int, int, str]] = [] for s in raw_spans: cat = s.get("label") if not cat: continue st, en = _trim_span(text, int(s["start"]), int(s["end"])) if st >= en: continue pf_like.append((st, en, cat)) pf_like.sort(key=lambda x: (x[0], -x[1])) enc = tokenizer( text, truncation=True, max_length=max_length, padding="max_length", return_offsets_mapping=True, return_tensors="pt", ) input_ids = enc.input_ids[0] attention_mask = enc.attention_mask[0] offsets = enc.offset_mapping[0].tolist() o_id = label2id["O"] label_ids = [o_id] * len(input_ids) locked = [False] * len(input_ids) for span_start, span_end, cat in pf_like: tok_indices: List[int] = [] for ti, (s, e) in enumerate(offsets): if s == 0 and e == 0: continue if e <= span_start or s >= span_end: continue tok_len = e - s if tok_len <= 0: continue overlap = min(e, span_end) - max(s, span_start) if overlap / tok_len >= min_overlap_frac: tok_indices.append(ti) if not tok_indices: continue tok_indices = [ti for ti in tok_indices if not locked[ti]] if not tok_indices: continue if len(tok_indices) == 1: tag = f"S-{cat}" if tag in label2id: label_ids[tok_indices[0]] = label2id[tag] locked[tok_indices[0]] = True else: b_tag, i_tag, e_tag = f"B-{cat}", f"I-{cat}", f"E-{cat}" if b_tag in label2id: label_ids[tok_indices[0]] = label2id[b_tag] locked[tok_indices[0]] = True for ti in tok_indices[1:-1]: if i_tag in label2id: label_ids[ti] = label2id[i_tag] locked[ti] = True if e_tag in label2id: label_ids[tok_indices[-1]] = label2id[e_tag] locked[tok_indices[-1]] = True label_tensor = torch.tensor(label_ids, dtype=torch.long) label_tensor[attention_mask == 0] = -100 return input_ids, attention_mask, label_tensor class _NemotronEvalDataset(Dataset): def __init__(self, hf_split, tokenizer, label2id, max_length): self.hf = hf_split self.tok = tokenizer self.l2i = label2id self.maxlen = max_length def __len__(self): return len(self.hf) def __getitem__(self, idx): ex = self.hf[idx] ids, mask, labels = _assign_native_bioes_labels( ex["text"], _parse_spans(ex["spans"]), self.tok, self.maxlen, self.l2i, ) L = int(mask.sum().item()) return { "input_ids": ids[:L].tolist(), "labels": labels[:L].tolist(), "valid_len": L, } def _make_collate(pad_token_id, max_length): def _c(batch): ids_list = [list(ex["input_ids"])[:max_length] for ex in batch] labels_list = [list(ex["labels"])[:max_length] for ex in batch] max_len = max(len(x) for x in ids_list) B = len(batch) input_ids = torch.full((B, max_len), pad_token_id, dtype=torch.long) attention_mask = torch.zeros((B, max_len), dtype=torch.long) labels = torch.full((B, max_len), -100, dtype=torch.long) for i, (ids, lab) in enumerate(zip(ids_list, labels_list)): L = len(ids) input_ids[i, :L] = torch.tensor(ids, dtype=torch.long) attention_mask[i, :L] = 1 labels[i, :L] = torch.tensor(lab, dtype=torch.long) return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels} return _c def _build_eval_streaming(test_split, target_n, chunk_size, seed) -> List[int]: """Uniform chunked sampling, identical to the training-time eval split.""" n_total = len(test_split) target_n = min(target_n, n_total) if target_n <= 0: return [] rng = np.random.RandomState(seed) per_chunk = max(1, math.ceil(chunk_size * target_n / n_total)) selected: List[int] = [] for chunk_start in range(0, n_total, chunk_size): if len(selected) >= target_n: break chunk_end = min(chunk_start + chunk_size, n_total) n_in_chunk = chunk_end - chunk_start n_to_pick = min(per_chunk, n_in_chunk, target_n - len(selected)) if n_to_pick <= 0: break offsets = rng.choice(n_in_chunk, size=n_to_pick, replace=False) selected.extend(int(chunk_start + o) for o in offsets) return sorted(selected[:target_n]) # --------------------------------------------------------------------------- # BIOES → spans + metrics # --------------------------------------------------------------------------- def _bioes_to_spans(labels, id2label, o_id=0): """Convert per-token BIOES label ids to a set of (start, end, cat).""" spans = set() cur_start = None cur_cat = None for i, lid in enumerate(labels): lid = int(lid) if lid == o_id or lid < 0: if cur_start is not None: spans.add((cur_start, i, cur_cat)) cur_start = None cur_cat = None continue tag = id2label.get(lid, "O") if tag == "O" or "-" not in tag: if cur_start is not None: spans.add((cur_start, i, cur_cat)) cur_start = None cur_cat = None continue prefix, cat = tag.split("-", 1) if prefix == "S": if cur_start is not None: spans.add((cur_start, i, cur_cat)) spans.add((i, i + 1, cat)) cur_start = None cur_cat = None elif prefix == "B": if cur_start is not None: spans.add((cur_start, i, cur_cat)) cur_start = i cur_cat = cat elif prefix == "I": if cur_start is None or cur_cat != cat: if cur_start is not None: spans.add((cur_start, i, cur_cat)) cur_start = i cur_cat = cat elif prefix == "E": if cur_start is None or cur_cat != cat: if cur_start is not None: spans.add((cur_start, i, cur_cat)) spans.add((i, i + 1, cat)) cur_start = None cur_cat = None else: spans.add((cur_start, i + 1, cur_cat)) cur_start = None cur_cat = None if cur_start is not None: spans.add((cur_start, len(labels), cur_cat)) return spans def _aggregate_span_metrics(gold_spans, pred_spans): correct = gold_spans & pred_spans n_gold = len(gold_spans) n_pred = len(pred_spans) n_correct = len(correct) p = n_correct / n_pred if n_pred else 0.0 r = n_correct / n_gold if n_gold else 0.0 f1 = (2 * p * r / (p + r)) if (p + r) else 0.0 per_cat: Dict[str, dict] = {} cats = sorted({c for _, _, c in gold_spans} | {c for _, _, c in pred_spans}) for cat in cats: g_c = {s for s in gold_spans if s[2] == cat} p_c = {s for s in pred_spans if s[2] == cat} c_c = g_c & p_c pp = len(c_c) / len(p_c) if p_c else 0.0 rr = len(c_c) / len(g_c) if g_c else 0.0 ff = (2 * pp * rr / (pp + rr)) if (pp + rr) else 0.0 per_cat[cat] = {"precision": pp, "recall": rr, "f1": ff, "n_gold": len(g_c), "n_pred": len(p_c), "n_correct": len(c_c)} return {"precision": p, "recall": r, "f1": f1, "n_gold": n_gold, "n_pred": n_pred, "n_correct": n_correct, "per_cat": per_cat} def _stream_metrics(docs, stream, id2label, o_id): """Aggregate metrics over a list of {gold, raw, viterbi} per-doc dicts.""" n_tokens = correct = n_non_o = non_o_correct = 0 gold_spans_all: set = set() pred_spans_all: set = set() doc_offset = 0 for doc in docs: gold = [int(x) for x in doc["gold"]] pred = [int(x) for x in doc[stream]] n = len(gold) n_tokens += n for g, p in zip(gold, pred): n_non_o += int(g != o_id) if g == p: correct += 1 if g != o_id: non_o_correct += 1 gs = _bioes_to_spans(gold, id2label, o_id) ps = _bioes_to_spans(pred, id2label, o_id) gold_spans_all.update((doc_offset + s, doc_offset + e, c) for s, e, c in gs) pred_spans_all.update((doc_offset + s, doc_offset + e, c) for s, e, c in ps) doc_offset += n span_m = _aggregate_span_metrics(gold_spans_all, pred_spans_all) return { "n_tokens": n_tokens, "n_non_o": n_non_o, "token_acc": correct / n_tokens if n_tokens else 0.0, "non_o_recall": non_o_correct / n_non_o if n_non_o else 0.0, "span_precision": span_m["precision"], "span_recall": span_m["recall"], "span_f1": span_m["f1"], "n_gold_spans": span_m["n_gold"], "n_pred_spans": span_m["n_pred"], "n_correct_spans": span_m["n_correct"], "span_per_cat": span_m["per_cat"], } # --------------------------------------------------------------------------- # Forward pass + viterbi (delegates to the released modeling) # --------------------------------------------------------------------------- def _model_perf_stats(model, dtype) -> dict: """Total / active / compute / MoE breakdown + on-device byte size. Three distinct param counts: * total_params — every parameter the model has on disk / in RAM. * active_params_per_tok — params *touched* during one token's forward pass (memory footprint per token). Counts the embedding because the embedding row IS read per token; counts only top-k of num_experts MoE experts because routing is sparse. * compute_params_per_tok — params that contribute matmul FLOPs per token. EXCLUDES the embedding table: `embed_tokens.weight` is a gather (one row read), not a matmul, so its FLOP cost is negligible (~hidden_size ops vs the table having ~vocab × hidden params). Counting it via the standard "2 × params" matmul approximation hugely inflates the apparent GFLOP/token and compresses the ratio between deep and shallow models. GFLOP/token is computed from `compute_params_per_tok`, not from `active_params_per_tok`. This makes the metric reflect actual layer-wise computational cost. """ cfg = model.config num_experts = int(getattr(cfg, "num_local_experts", 1)) top_k = int(getattr(cfg, "num_experts_per_tok", num_experts)) expert_frac = top_k / max(1, num_experts) moe_total = 0 moe_active = 0 other_total = 0 embed_total = 0 # gather-only params; excluded from FLOP estimate for name, p in model.named_parameters(): n = p.numel() # MoE expert tensors are stacked along an experts axis. The upstream # exposes them under `mlp.experts.*`. Only `top_k` of `num_experts` # experts contribute per token. if ".mlp.experts." in name: moe_total += n moe_active += int(round(n * expert_frac)) # `embed_tokens.weight` (and any other lookup-style table) is a # gather: one row of [vocab, hidden] is read per token, costing # ~hidden ops, not 2 × vocab × hidden FLOPs. Tracked separately # so it doesn't pollute the FLOP estimate. elif "embed_tokens" in name: embed_total += n other_total += n else: other_total += n total = moe_total + other_total active_per_tok = moe_active + other_total # Compute-relevant params per token: matmuls only. Drop the embedding # lookup, which contributes ~zero FLOPs. compute_per_tok = active_per_tok - embed_total bytes_per_param = {torch.bfloat16: 2, torch.float16: 2, torch.float32: 4}.get(dtype, 4) storage_bytes = total * bytes_per_param # in-memory dense weights # GFLOP/token over matmul params only. Matmul ≈ 2 FLOPs per param. gflops_per_tok = 2 * compute_per_tok / 1e9 return { "total_params": total, "moe_total": moe_total, "moe_active_per_tok": moe_active, "other_total": other_total, "embed_total": embed_total, "active_params_per_tok": active_per_tok, "compute_params_per_tok": compute_per_tok, "num_experts": num_experts, "experts_per_tok": top_k, "expert_frac": expert_frac, "weight_bytes": storage_bytes, "gflops_per_tok": gflops_per_tok, } def _disk_size_bytes(model_path: str) -> int: """Sum on-disk size of weight files at the given path. Falls back to 0 if the path is a HF repo id (not a local directory).""" p = Path(model_path) if not p.is_dir(): return 0 total = 0 for f in p.iterdir(): if f.is_file() and f.suffix in {".safetensors", ".bin", ".pt", ".pth"}: total += f.stat().st_size return total def _eval_one_model( model_path: str, tokenizer, eval_ds, label2id, id2label, o_id, bioes_trans, bioes_init, batch_size, max_length, device, dtype, label: str, ): print(f"[eval] loading {label} from {model_path} ...", flush=True) if torch.cuda.is_available() and device.type == "cuda": torch.cuda.reset_peak_memory_stats(device) torch.cuda.empty_cache() mem_before = (torch.cuda.memory_allocated(device) if torch.cuda.is_available() and device.type == "cuda" else 0) t_load = time.time() model = AutoModelForTokenClassification.from_pretrained( model_path, dtype=dtype, trust_remote_code=True, ).to(device).eval() if hasattr(model.config, "use_viterbi_decode"): model.config.use_viterbi_decode = True if hasattr(model.config, "viterbi_replace_logits"): model.config.viterbi_replace_logits = False load_s = time.time() - t_load perf = _model_perf_stats(model, dtype) perf["disk_size_bytes"] = _disk_size_bytes(model_path) weights_resident_bytes = ( torch.cuda.memory_allocated(device) - mem_before if torch.cuda.is_available() and device.type == "cuda" else perf["weight_bytes"] ) perf["weights_resident_bytes"] = weights_resident_bytes # Vendor the batched viterbi (re-import from the released modeling file). from modeling_haremb_pii import _bioes_viterbi_batched pad_token_id = tokenizer.pad_token_id or 199999 loader = DataLoader( eval_ds, batch_size=batch_size, shuffle=False, collate_fn=_make_collate(pad_token_id, max_length), num_workers=2, ) docs: List[dict] = [] n_tok = 0 if torch.cuda.is_available() and device.type == "cuda": torch.cuda.synchronize() torch.cuda.reset_peak_memory_stats(device) t0 = time.time() for batch in tqdm(loader, desc=f"eval {label}", unit="batch", leave=False): ids = batch["input_ids"].to(device, non_blocking=True) mask = batch["attention_mask"].to(device, non_blocking=True) gold = batch["labels"].to(device, non_blocking=True) with torch.no_grad(): out = model(input_ids=ids, attention_mask=mask) raw = out.logits.argmax(dim=-1) vit = _bioes_viterbi_batched(out.logits.float(), mask, bioes_trans, bioes_init) valid = (gold != -100) & mask.bool() for b in range(gold.shape[0]): keep = [i for i, ok in enumerate(valid[b].cpu().tolist()) if ok] n_tok += len(keep) docs.append({ "gold": [int(gold[b, i].item()) for i in keep], "raw": [int(raw[b, i].item()) for i in keep], "viterbi": [int(vit[b, i].item()) for i in keep], }) if torch.cuda.is_available() and device.type == "cuda": torch.cuda.synchronize() peak_mem = torch.cuda.max_memory_allocated(device) else: peak_mem = 0 eval_s = time.time() - t0 raw_m = _stream_metrics(docs, "raw", id2label, o_id) vit_m = _stream_metrics(docs, "viterbi", id2label, o_id) perf["peak_eval_mem_bytes"] = peak_mem del model if torch.cuda.is_available(): torch.cuda.empty_cache() return { "label": label, "n_total_M": perf["total_params"] / 1e6, "load_s": load_s, "eval_s": eval_s, "n_tok": n_tok, "throughput_tok_s": n_tok / eval_s if eval_s else 0.0, "perf": perf, "raw": raw_m, "viterbi": vit_m, "docs": docs, } # --------------------------------------------------------------------------- # Pairwise token-level breakdown (this vs reference, viterbi stream) # --------------------------------------------------------------------------- def _pairwise(docs_cand, docs_ref, id2label, o_id): both_correct = only_cand = only_ref = both_wrong = 0 by_cat = defaultdict(lambda: {"both_correct": 0, "only_cand": 0, "only_ref": 0, "both_wrong": 0}) for dc, dr in zip(docs_cand, docs_ref): gold = dc["gold"] cv = dc["viterbi"] rv = dr["viterbi"] for g, c, r in zip(gold, cv, rv): cat = id2label.get(g, "O") cat = cat.split("-", 1)[1] if "-" in cat else cat cc = (c == g) rc = (r == g) if cc and rc: both_correct += 1 by_cat[cat]["both_correct"] += 1 elif cc and not rc: only_cand += 1 by_cat[cat]["only_cand"] += 1 elif rc and not cc: only_ref += 1 by_cat[cat]["only_ref"] += 1 else: both_wrong += 1 by_cat[cat]["both_wrong"] += 1 return { "both_correct": both_correct, "only_cand_correct": only_cand, "only_ref_correct": only_ref, "both_wrong": both_wrong, "by_cat": dict(by_cat), } # --------------------------------------------------------------------------- # Plot rendering # --------------------------------------------------------------------------- def _render_plots(cand, ref, pair, out_dir: Path, cand_label, ref_label): """Render benchmark plots. Visual convention: A = reference / teacher / baseline B = candidate / this checkpoint The charts avoid color-only "win" encoding: labels state the actual delta or ratio, and horizontal layouts keep long metric/category names readable. """ try: import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from matplotlib.ticker import FuncFormatter except ImportError: print("[plot] matplotlib not installed, skipping", flush=True) return plt.rcParams.update({ "figure.facecolor": "#ffffff", "axes.facecolor": "#ffffff", "axes.edgecolor": "#cbd5e1", "axes.labelcolor": "#0f172a", "xtick.color": "#334155", "ytick.color": "#334155", "grid.color": "#e2e8f0", "font.size": 9, "axes.titleweight": "bold", "axes.titlesize": 11, "legend.frameon": False, }) C_REF = "#64748b" # slate C_CAND = "#2563eb" # blue C_GOOD = "#0f766e" # teal C_BAD = "#b91c1c" # red C_NEUTRAL = "#94a3b8" # light slate C_BG = "#f8fafc" def _pct_axis(ax): ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _pos: f"{x:.0%}")) def _metric_delta_text(delta): return f"{delta:+.4f}" def _value_text(v): if v >= 100: return f"{v:,.0f}" if v >= 10: return f"{v:,.1f}" if v >= 1: return f"{v:,.2f}" return f"{v:,.3f}" def _ratio_text(r, lower_is_better): if r <= 0: return "n/a" if lower_is_better: return f"{1.0 / r:.2f}x lower" if r <= 1 else f"{r:.2f}x higher" return f"{r:.2f}x higher" if r >= 1 else f"{1.0 / r:.2f}x lower" # --- eval_summary.png: headline metrics + category-level deltas --- fig, axes = plt.subplots(2, 2, figsize=(8, 5), constrained_layout=True) ax_head = axes[0, 0] ax_delta = axes[0, 1] ax_raw_vit = axes[1, 0] ax_cat = axes[1, 1] metrics = ["span_f1", "span_precision", "span_recall", "token_acc", "non_o_recall"] labels = ["Span F1", "Span P", "Span R", "Token acc", "Non-O recall"] cand_v = [cand["viterbi"][m] for m in metrics] ref_v = [ref["viterbi"][m] for m in metrics] if ref is not None else None y = np.arange(len(metrics)) if ref_v is not None: ax_head.hlines(y, ref_v, cand_v, color=C_NEUTRAL, linewidth=2, alpha=0.9) ax_head.scatter(ref_v, y, s=55, color=C_REF, label=f"A: {ref_label}", zorder=3) ax_head.scatter(cand_v, y, s=70, color=C_CAND, label=f"B: {cand_label}", zorder=4) ax_head.set_yticks(y) ax_head.set_yticklabels(labels) ax_head.invert_yaxis() ax_head.set_xlim(max(0.0, min(cand_v + (ref_v or cand_v)) - 0.02), min(1.08, max(cand_v + (ref_v or cand_v)) + 0.05)) ax_head.set_title("Headline metrics, Viterbi stream") ax_head.grid(axis="x", alpha=0.7) _pct_axis(ax_head) if ref_v is not None: for i, v in enumerate(ref_v): ax_head.text(v + 0.002, i, f"{v:.4f}", va="center", ha="left", fontsize=7, color=C_REF) for i, v in enumerate(cand_v): ax_head.text(v - 0.002, i, f"{v:.4f}", va="center", ha="right", fontsize=7, color=C_CAND) if ref_v is not None: deltas = [b - a for a, b in zip(ref_v, cand_v)] colors = [C_GOOD if d >= 0 else C_BAD for d in deltas] ax_delta.axvline(0, color="#0f172a", linewidth=0.9) ax_delta.barh(y, deltas, color=colors, alpha=0.9) ax_delta.set_yticks(y) ax_delta.set_yticklabels(labels) ax_delta.invert_yaxis() ax_delta.set_title("Delta: B minus A") ax_delta.grid(axis="x", alpha=0.7) max_abs = max([abs(d) for d in deltas] + [0.002]) ax_delta.set_xlim(-max_abs * 1.45, max_abs * 1.45) for i, d in enumerate(deltas): ax_delta.text(d + (max_abs * 0.04 if d >= 0 else -max_abs * 0.04), i, _metric_delta_text(d), ha="left" if d >= 0 else "right", va="center", fontsize=7) else: ax_delta.axis("off") stream_rows = [ ("A raw", ref["raw"]["span_f1"] if ref is not None else None, C_REF), ("A viterbi", ref["viterbi"]["span_f1"] if ref is not None else None, C_REF), ("B raw", cand["raw"]["span_f1"], C_CAND), ("B viterbi", cand["viterbi"]["span_f1"], C_CAND), ] stream_rows = [r for r in stream_rows if r[1] is not None] sy = np.arange(len(stream_rows)) ax_raw_vit.barh(sy, [r[1] for r in stream_rows], color=[r[2] for r in stream_rows], alpha=0.88) ax_raw_vit.set_yticks(sy) ax_raw_vit.set_yticklabels([r[0] for r in stream_rows]) ax_raw_vit.invert_yaxis() ax_raw_vit.set_xlim(0, 1.08) ax_raw_vit.set_title("Raw vs Viterbi span F1") ax_raw_vit.grid(axis="x", alpha=0.7) _pct_axis(ax_raw_vit) for i, (_, v, _) in enumerate(stream_rows): ax_raw_vit.text(v + 0.008, i, f"{v:.4f}", va="center", fontsize=7) cand_pc = cand["viterbi"]["span_per_cat"] if ref is not None: ref_pc = ref["viterbi"]["span_per_cat"] cats = sorted(set(cand_pc) | set(ref_pc)) rows = [] for c in cats: a = ref_pc.get(c, {}).get("f1", 0.0) b = cand_pc.get(c, {}).get("f1", 0.0) n = max(cand_pc.get(c, {}).get("n_gold", 0), ref_pc.get(c, {}).get("n_gold", 0)) rows.append((c, b - a, b, a, n)) # Keep the categories that explain the comparison: worst and best B deltas. worst = sorted(rows, key=lambda r: r[1])[:8] best = sorted(rows, key=lambda r: r[1], reverse=True)[:8] picked = worst + [r for r in best if r[0] not in {x[0] for x in worst}] picked = sorted(picked, key=lambda r: r[1]) cy = np.arange(len(picked)) deltas = [r[1] for r in picked] ax_cat.axvline(0, color="#0f172a", linewidth=0.9) ax_cat.barh(cy, deltas, color=[C_GOOD if d >= 0 else C_BAD for d in deltas]) ax_cat.set_yticks(cy) ax_cat.set_yticklabels([r[0] for r in picked], fontsize=8) ax_cat.set_title("Per-category span F1 delta, selected extremes") ax_cat.grid(axis="x", alpha=0.7) max_abs = max([abs(d) for d in deltas] + [0.05]) ax_cat.set_xlim(-max_abs * 1.55, max_abs * 1.55) for i, r in enumerate(picked): d = r[1] ax_cat.text(d + (max_abs * 0.05 if d >= 0 else -max_abs * 0.05), i, f"{d:+.3f} B={r[2]:.2f} A={r[3]:.2f}", va="center", ha="left" if d >= 0 else "right", fontsize=6) else: cats_sorted = sorted(cand_pc.keys(), key=lambda c: cand_pc[c]["f1"])[:18] vals = [cand_pc[c]["f1"] for c in cats_sorted] cy = np.arange(len(cats_sorted)) ax_cat.barh(cy, vals, color=C_CAND) ax_cat.set_yticks(cy) ax_cat.set_yticklabels(cats_sorted, fontsize=8) ax_cat.set_xlim(0, 1.0) ax_cat.set_title("Lowest per-category span F1") ax_cat.grid(axis="x", alpha=0.7) _pct_axis(ax_cat) fig.suptitle(f"Evaluation summary — A: {ref_label if ref else 'n/a'} | B: {cand_label}", fontsize=9, fontweight="bold") fig.savefig(out_dir / "eval_summary.png", dpi=160) plt.close(fig) print(f"[plot] wrote {out_dir / 'eval_summary.png'}", flush=True) # --- eval_confusion.png: pairwise outcome on gold non-O tokens --- # Display order matches A vs B: "Only A correct" (teacher) before # "Only B correct" (student). Underlying buckets in `pair` are still # named cand/ref; we just relabel for display. if pair is not None: fig, axes = plt.subplots( 1, 2, figsize=(8, 3), constrained_layout=True, gridspec_kw={"width_ratios": [0.9, 1.7]}, ) ax = axes[0] non_o_buckets = {k: 0 for k in ["both_correct", "only_cand", "only_ref", "both_wrong"]} for cat, d in pair["by_cat"].items(): if cat == "O": continue for k in non_o_buckets: non_o_buckets[k] += d[k] values = [non_o_buckets["both_correct"], non_o_buckets["only_ref"], non_o_buckets["only_cand"], non_o_buckets["both_wrong"]] labels_ = [ "Both\ncorrect", "Only A\ncorrect", "Only B\ncorrect", "Both wrong", ] colors = [C_GOOD, C_REF, C_CAND, C_BAD] total = max(1, sum(values)) ax.barh(np.arange(4), values, color=colors) ax.set_yticks(np.arange(4)) ax.set_yticklabels(labels_) ax.invert_yaxis() ax.set_ylabel("Gold non-O tokens") ax.set_title("Token outcome on gold non-O") ax.grid(axis="x", alpha=0.7) ax.set_xlim(0, max(values) * 1.32 if values else 1) for i, v in enumerate(values): ax.text(v + max(values) * 0.015, i, f"{v:,} ({v / total:.1%})", va="center", fontsize=6) rows = [] for cat, d in pair["by_cat"].items(): if cat == "O": continue net = d["only_cand"] - d["only_ref"] active = d["only_cand"] + d["only_ref"] + d["both_wrong"] if active: rows.append((cat, net, d["only_cand"], d["only_ref"], d["both_wrong"])) worst = sorted(rows, key=lambda r: r[1])[:8] best = sorted(rows, key=lambda r: r[1], reverse=True)[:8] picked = worst + [r for r in best if r[0] not in {x[0] for x in worst}] picked = sorted(picked, key=lambda r: r[1]) ax2 = axes[1] if picked: py = np.arange(len(picked)) nets = [r[1] for r in picked] ax2.axvline(0, color="#0f172a", linewidth=0.9) ax2.barh(py, nets, color=[C_GOOD if n >= 0 else C_BAD for n in nets]) ax2.set_yticks(py) ax2.set_yticklabels([r[0] for r in picked], fontsize=8) ax2.set_title("Net token wins by category: B only-correct minus A only-correct") ax2.grid(axis="x", alpha=0.7) max_abs = max([abs(n) for n in nets] + [1]) ax2.set_xlim(-max_abs * 1.5, max_abs * 1.5) for i, r in enumerate(picked): n = r[1] label_x = n + max_abs * 0.04 if n >= 0 else max_abs * 0.05 ax2.text(label_x, i, f"{n:+d} B={r[2]} A={r[3]} W={r[4]}", va="center", ha="left", fontsize=6) else: ax2.axis("off") fig.suptitle(f"Pairwise correctness — A: {ref_label} | B: {cand_label}", fontsize=9, fontweight="bold") fig.savefig(out_dir / "eval_confusion.png", dpi=160) plt.close(fig) print(f"[plot] wrote {out_dir / 'eval_confusion.png'}", flush=True) # --- eval_performance.png: model size, compute, throughput, memory --- if ref is not None and "perf" in cand and "perf" in ref: cp, rp = cand["perf"], ref["perf"] fig, axes = plt.subplots(1, 2, figsize=(8.5, 5), constrained_layout=True) ax_abs = axes[0] ax_ratio = axes[1] metrics = [ ("Total params (M)", cp["total_params"]/1e6, rp["total_params"]/1e6, True), ("Active params/tok (M)", cp["active_params_per_tok"]/1e6, rp["active_params_per_tok"]/1e6, True), ("MoE expert params (M)", cp["moe_total"]/1e6, rp["moe_total"]/1e6, True), ("GFLOP/token", cp["gflops_per_tok"], rp["gflops_per_tok"], True), ("Weights RAM (MiB)", cp["weight_bytes"]/(1<<20), rp["weight_bytes"]/(1<<20), True), ("Peak eval mem (MiB)", cp["peak_eval_mem_bytes"]/(1<<20), rp["peak_eval_mem_bytes"]/(1<<20), True), ("Throughput (tok/s)", cand["throughput_tok_s"], ref["throughput_tok_s"], False), ] y = np.arange(len(metrics)) w = 0.38 cand_v = [m[1] for m in metrics] ref_v = [m[2] for m in metrics] ax_abs.barh(y - w/2, ref_v, w, label=f"A: {ref_label}", color=C_REF) ax_abs.barh(y + w/2, cand_v, w, label=f"B: {cand_label}", color=C_CAND) ax_abs.set_yticks(y) ax_abs.set_yticklabels([m[0] for m in metrics], fontsize=8) ax_abs.invert_yaxis() ax_abs.set_xscale("log") ax_abs.set_title("Absolute footprint and speed, log scale") ax_abs.grid(axis="x", which="both", alpha=0.7) positive_vals = [v for v in (ref_v + cand_v) if v > 0] if positive_vals: ax_abs.set_xlim(min(positive_vals) * 0.55, max(positive_vals) * 3.8) for yi, vals in enumerate(zip(ref_v, cand_v)): for off, v, col in [(-w/2, vals[0], C_REF), (w/2, vals[1], C_CAND)]: if v <= 0: continue ax_abs.text(v * 1.05, yi + off, _value_text(v), va="center", fontsize=6, color=col) ratios = [(m[1] / max(1e-12, m[2])) for m in metrics] lower_better = [m[3] for m in metrics] colors = [ C_GOOD if ((lb and r <= 1.0) or ((not lb) and r >= 1.0)) else C_BAD for r, lb in zip(ratios, lower_better) ] ax_ratio.axvline(1.0, color="#0f172a", linestyle="--", linewidth=0.9, alpha=0.65) ax_ratio.barh(y, ratios, color=colors) ax_ratio.set_yticks(y) ax_ratio.set_yticklabels([m[0] for m in metrics], fontsize=8) ax_ratio.invert_yaxis() ax_ratio.set_xscale("log") ax_ratio.set_title("B / A ratio with explicit direction") ax_ratio.grid(axis="x", which="both", alpha=0.7) positive_ratios = [r for r in ratios if r > 0] if positive_ratios: ax_ratio.set_xlim(min(positive_ratios) * 0.38, max(positive_ratios) * 2.8) for i, (r, lb) in enumerate(zip(ratios, lower_better)): ax_ratio.text(r * (1.05 if r >= 1 else 0.95), i, _ratio_text(r, lb), va="center", ha="left" if r >= 1 else "right", fontsize=6) fig.suptitle(f"Performance profile — A: {ref_label} | B: {cand_label}", fontsize=9, fontweight="bold") fig.savefig(out_dir / "eval_performance.png", dpi=160) plt.close(fig) print(f"[plot] wrote {out_dir / 'eval_performance.png'}", flush=True) # --------------------------------------------------------------------------- # Reporting # --------------------------------------------------------------------------- def _fmt_metrics(m): return (f"span_F1={m['span_f1']:.4f} P={m['span_precision']:.4f} " f"R={m['span_recall']:.4f} token_acc={m['token_acc']:.4f} " f"non_o_recall={m['non_o_recall']:.4f} " f"spans={m['n_gold_spans']}/{m['n_pred_spans']}/{m['n_correct_spans']}") def _write_compare_log(path: Path, cand, ref, pair, args): lines: List[str] = [] A = lines.append A(f"Benchmark: A: {ref['label']} vs B: {cand['label']}" if ref else f"Benchmark: {cand['label']}") A(f"Dataset: {SOURCE_DATASET}, split=test, eval_pct={args.eval_pct}, " f"ctx={args.max_length}, seed={args.seed}, n_docs={args.n_docs}") A(f"Eval tokens scored: {cand['n_tok']:,}") A("") A("=== Aggregate ===") if ref is not None: A(f" A: {ref['label']:<25s} RAW {_fmt_metrics(ref['raw'])}") A(f" A: {ref['label']:<25s} VITERBI {_fmt_metrics(ref['viterbi'])}") A(f" B: {cand['label']:<25s} RAW {_fmt_metrics(cand['raw'])}") A(f" B: {cand['label']:<25s} VITERBI {_fmt_metrics(cand['viterbi'])}") if ref is not None: d_f1 = ref["viterbi"]["span_f1"] - cand["viterbi"]["span_f1"] A("") A(f"Gap B vs A (viterbi span_F1): {-d_f1:+.4f}") A("") if ref is not None: A(f"Throughput: A: {ref['label']} {ref['throughput_tok_s']:.0f} tok/s " f"({ref['n_total_M']:.2f}M params)") A(f" B: {cand['label']} {cand['throughput_tok_s']:.0f} tok/s " f"({cand['n_total_M']:.2f}M params)") else: A(f"Throughput: {cand['label']} {cand['throughput_tok_s']:.0f} tok/s " f"({cand['n_total_M']:.2f}M params)") A("") # ---- Performance summary table ---- # Column order: A (ref / teacher) first, then B (cand / student). # The "B vs A" column is written in human-readable direction: # - When B is smaller (size/compute/mem): "X.XX× smaller" / "X.XX× cheaper" / "X.XX× less". # - When B is larger but lower-is-better: "X.XX× larger" / "X.XX× more". # - When B is faster (throughput, higher-is-better): "X.XX× faster". # - When B is slower: "X.XX× slower". # Always uses the magnitude in the dominant direction so the reader # doesn't need to mentally invert 0.21× into 4.87×. def _fmt_vs(b: float, a: float, kind: str) -> str: if a is None or a == 0 or b is None or b == 0: return "" ratio = b / a if kind == "size": # lower is better; phrase as "smaller" or "larger" if ratio <= 1.0: return f"{1.0/ratio:.2f}× smaller" return f"{ratio:.2f}× larger" if kind == "compute": if ratio <= 1.0: return f"{1.0/ratio:.2f}× cheaper" return f"{ratio:.2f}× more" if kind == "memory": if ratio <= 1.0: return f"{1.0/ratio:.2f}× less" return f"{ratio:.2f}× more" if kind == "speed": # higher is better if ratio >= 1.0: return f"{ratio:.2f}× faster" return f"{1.0/ratio:.2f}× slower" return f"{ratio:.2f}×" cp = cand["perf"] A("=== Performance ===") headers = ["metric", f"A: {ref['label']}" if ref else "", f"B: {cand['label']}", "B vs A"] rp = ref["perf"] if ref else None rows = [ ["total params (M)", (f"{rp['total_params']/1e6:.2f}" if ref else ""), f"{cp['total_params']/1e6:.2f}", (_fmt_vs(cp['total_params'], rp['total_params'], "size") if ref else "")], ["dense params (M)", (f"{rp['other_total']/1e6:.2f}" if ref else ""), f"{cp['other_total']/1e6:.2f}", (_fmt_vs(cp['other_total'], rp['other_total'], "size") if ref else "")], ["MoE expert params (M)", (f"{rp['moe_total']/1e6:.2f}" if ref else ""), f"{cp['moe_total']/1e6:.2f}", (_fmt_vs(cp['moe_total'], rp['moe_total'], "size") if ref else "")], [f"active params/token (M, mem)", (f"{rp['active_params_per_tok']/1e6:.2f}" if ref else ""), f"{cp['active_params_per_tok']/1e6:.2f}", (_fmt_vs(cp['active_params_per_tok'], rp['active_params_per_tok'], "memory") if ref else "")], [f"compute params/token (M, FLOPs)", (f"{rp['compute_params_per_tok']/1e6:.2f}" if ref else ""), f"{cp['compute_params_per_tok']/1e6:.2f}", (_fmt_vs(cp['compute_params_per_tok'], rp['compute_params_per_tok'], "compute") if ref else "")], ["GFLOP / token", (f"{rp['gflops_per_tok']:.4f}" if ref else ""), f"{cp['gflops_per_tok']:.4f}", (_fmt_vs(cp['gflops_per_tok'], rp['gflops_per_tok'], "compute") if ref else "")], ["disk size (MiB)", (f"{rp['disk_size_bytes']/(1<<20):.1f}" if ref and rp['disk_size_bytes'] else ""), f"{cp['disk_size_bytes']/(1<<20):.1f}", (_fmt_vs(cp['disk_size_bytes'], rp['disk_size_bytes'], "size") if ref and rp['disk_size_bytes'] else "")], ["weights in RAM (MiB)", (f"{rp['weight_bytes']/(1<<20):.1f}" if ref else ""), f"{cp['weight_bytes']/(1<<20):.1f}", (_fmt_vs(cp['weight_bytes'], rp['weight_bytes'], "size") if ref else "")], ["peak GPU mem eval (MiB)", (f"{rp['peak_eval_mem_bytes']/(1<<20):.1f}" if ref else ""), f"{cp['peak_eval_mem_bytes']/(1<<20):.1f}", (_fmt_vs(cp['peak_eval_mem_bytes'], rp['peak_eval_mem_bytes'], "memory") if ref else "")], ["throughput (tok/s)", (f"{ref['throughput_tok_s']:.0f}" if ref else ""), f"{cand['throughput_tok_s']:.0f}", (_fmt_vs(cand['throughput_tok_s'], ref['throughput_tok_s'], "speed") if ref else "")], ] widths = [max(len(r[i]) for r in [headers] + rows) for i in range(4)] sep = " " + " ".join("-" * w for w in widths) A(" " + " ".join(h.ljust(widths[i]) for i, h in enumerate(headers))) A(sep) for r in rows: A(" " + " ".join(r[i].ljust(widths[i]) for i in range(4))) A("") if pair is not None: # Display labels: A = ref (teacher), B = cand (student). a_lbl = ref["label"] if ref is not None else "A" b_lbl = cand["label"] A(f"=== Pairwise (viterbi, all gold tokens) — A: {a_lbl} vs B: {b_lbl} ===") total = (pair["both_correct"] + pair["only_cand_correct"] + pair["only_ref_correct"] + pair["both_wrong"]) # Display order: agreement, A-only, B-only, both-wrong. rows_all = [ ("both_correct", pair["both_correct"]), ("only_A_correct", pair["only_ref_correct"]), ("only_B_correct", pair["only_cand_correct"]), ("both_wrong", pair["both_wrong"]), ] for k, v in rows_all: A(f" {k:<26s} {v:8d} ({100.0*v/total:.2f}%)") A("") A(f"=== Pairwise (viterbi, gold non-O tokens) — A: {a_lbl} vs B: {b_lbl} ===") non_o = {k: 0 for k in ["both_correct", "only_cand", "only_ref", "both_wrong"]} for cat, d in pair["by_cat"].items(): if cat == "O": continue for k in non_o: non_o[k] += d[k] total_non_o = sum(non_o.values()) rows_non_o = [ ("both_correct", non_o["both_correct"]), ("only_A_correct", non_o["only_ref"]), ("only_B_correct", non_o["only_cand"]), ("both_wrong", non_o["both_wrong"]), ] for k, v in rows_non_o: A(f" {k:<26s} {v:8d} ({100.0*v/total_non_o:.2f}%)" if total_non_o else f" {k}: 0") A("") # Per-cat net B-wins. Net = (only_B) - (only_A) = (only_cand) - (only_ref). # Negative net = A (teacher) wins more in this category. nets = [] for cat, d in pair["by_cat"].items(): if cat == "O": continue nets.append((cat, d["only_cand"] - d["only_ref"], d["only_cand"], d["only_ref"], d["both_wrong"])) nets.sort(key=lambda x: x[1]) A(f"=== Worst B-net wins by gold category — A: {a_lbl} ahead (top 15) ===") for cat, net, ob, oa, bw in nets[:15]: A(f" {cat:<32s} net_B={net:+5d} A_only={oa:4d} B_only={ob:4d} both_wrong={bw:4d}") A("") A(f"=== Best B-net wins by gold category — B: {b_lbl} ahead (top 15) ===") for cat, net, ob, oa, bw in nets[::-1][:15]: A(f" {cat:<32s} net_B={net:+5d} A_only={oa:4d} B_only={ob:4d} both_wrong={bw:4d}") A("") A("=== Per-category span F1 (viterbi) ===") if ref is not None: A(f" -- A: {ref['label']} --") per_r = ref["viterbi"]["span_per_cat"] for cat in sorted(per_r): c = per_r[cat] A(f" {cat:<32s} F1={c['f1']:.4f} P={c['precision']:.4f} R={c['recall']:.4f} " f"({c['n_gold']}/{c['n_pred']}/{c['n_correct']})") A(f" -- B: {cand['label']} --") per = cand["viterbi"]["span_per_cat"] for cat in sorted(per): c = per[cat] A(f" {cat:<32s} F1={c['f1']:.4f} P={c['precision']:.4f} R={c['recall']:.4f} " f"({c['n_gold']}/{c['n_pred']}/{c['n_correct']})") path.write_text("\n".join(lines) + "\n") print(f"[log] wrote {path}", flush=True) def _fmt_bytes(n: int) -> str: if n <= 0: return "—" if n >= 1 << 30: return f"{n / (1 << 30):.2f} GiB" if n >= 1 << 20: return f"{n / (1 << 20):.1f} MiB" return f"{n / (1 << 10):.1f} KiB" def _perf_block(stream, ctx: int) -> List[str]: p = stream["perf"] out = [ f" total params : {p['total_params']/1e6:>9.2f}M " f"({p['other_total']/1e6:.2f}M dense + {p['moe_total']/1e6:.2f}M MoE-experts)", f" active params / token : {p['active_params_per_tok']/1e6:>9.2f}M " f"(memory footprint — embed lookup + top_{p['experts_per_tok']}/{p['num_experts']} experts: " f"{p['embed_total']/1e6:.2f}M embed + " f"{p['moe_active_per_tok']/1e6:.2f}M MoE-active + " f"{(p['other_total']-p['embed_total'])/1e6:.2f}M attn/norm/head)", f" compute params / token : {p['compute_params_per_tok']/1e6:>9.2f}M " f"(matmul FLOPs only — embedding lookup excluded)", f" GFLOP / token (fwd, MAC×2): {p['gflops_per_tok']:>9.3f}", f" weights size (on disk) : {_fmt_bytes(p['disk_size_bytes']):>9s}", f" weights size (in RAM) : {_fmt_bytes(p['weight_bytes']):>9s}", f" weights resident (GPU) : {_fmt_bytes(p['weights_resident_bytes']):>9s}", f" peak GPU mem (eval, ctx={ctx}) : {_fmt_bytes(p['peak_eval_mem_bytes']):>9s}", ] return out def _write_infer_log(path: Path, cand, ref, args, sample_text: str, tokenizer, device, dtype): """Single-doc inference example + timing + performance metrics.""" from modeling_haremb_pii import _bioes_viterbi_batched lines: List[str] = [] A = lines.append A(f"Inference benchmark: A: {ref['label']} vs B: {cand['label']}" if ref else f"Inference benchmark: {cand['label']}") A(f" device : {device} dtype: {dtype}") A(f" ctx : {args.max_length}") A("") if ref is not None: A(f"A: {ref['label']} (reference / teacher)") A(f" load : {ref['load_s']:.2f}s") A(f" eval : {ref['eval_s']:.2f}s on {ref['n_tok']:,} tokens " f"({ref['throughput_tok_s']:.0f} tok/s)") A("Performance:") for ln in _perf_block(ref, args.max_length): A(ln) A("") A(f"B: {cand['label']}" + (" (this checkpoint)" if ref else "")) A(f" load : {cand['load_s']:.2f}s") A(f" eval : {cand['eval_s']:.2f}s on {cand['n_tok']:,} tokens " f"({cand['throughput_tok_s']:.0f} tok/s)") A("Performance:") for ln in _perf_block(cand, args.max_length): A(ln) A("") if ref is not None: cp, rp = cand["perf"], ref["perf"] def _fmt(b, a, kind): if a is None or a == 0 or b is None or b == 0: return "—" r = b / a if kind == "size": return f"{1.0/r:.2f}× smaller" if r <= 1.0 else f"{r:.2f}× larger" if kind == "compute": return f"{1.0/r:.2f}× cheaper" if r <= 1.0 else f"{r:.2f}× more" if kind == "memory": return f"{1.0/r:.2f}× less" if r <= 1.0 else f"{r:.2f}× more" if kind == "speed": return f"{r:.2f}× faster" if r >= 1.0 else f"{1.0/r:.2f}× slower" return f"{r:.2f}×" A(f"B vs A ({cand['label']} vs {ref['label']}):") A(f" total params : {_fmt(cp['total_params'], rp['total_params'], 'size')}") A(f" active params / token : {_fmt(cp['active_params_per_tok'], rp['active_params_per_tok'], 'memory')} [memory]") A(f" compute params / token : {_fmt(cp['compute_params_per_tok'], rp['compute_params_per_tok'], 'compute')} [FLOPs]") A(f" GFLOP / token : {_fmt(cp['gflops_per_tok'], rp['gflops_per_tok'], 'compute')}") if rp['disk_size_bytes']: A(f" weights size (on disk) : {_fmt(cp['disk_size_bytes'], rp['disk_size_bytes'], 'size')}") else: A(f" weights size (on disk) : —") A(f" weights in RAM : {_fmt(cp['weight_bytes'], rp['weight_bytes'], 'size')}") A(f" peak GPU mem (eval) : {_fmt(cp['peak_eval_mem_bytes'], rp['peak_eval_mem_bytes'], 'memory')}") A(f" throughput : {_fmt(cand['throughput_tok_s'], ref['throughput_tok_s'], 'speed')}") A("") A("Sample inference (load → tokenize → forward → viterbi-decode → spans):") A(f" text: {sample_text!r}") model = AutoModelForTokenClassification.from_pretrained( ".", dtype=dtype, trust_remote_code=True, ).to(device).eval() if hasattr(model.config, "viterbi_replace_logits"): model.config.viterbi_replace_logits = True enc = tokenizer(sample_text, return_tensors="pt", truncation=True, max_length=args.max_length).to(device) with torch.no_grad(): if torch.cuda.is_available(): torch.cuda.synchronize() t0 = time.time() out = model(**enc) if torch.cuda.is_available(): torch.cuda.synchronize() dt = time.time() - t0 label2id, id2label = nemotron_native_label_space() pred = out.logits.argmax(-1)[0].cpu().tolist() spans = _bioes_to_spans(pred, id2label, 0) A(f" forward latency: {dt*1000:.1f}ms ({enc.input_ids.shape[1]} tokens)") A(f" detected {len(spans)} spans:") tok_ids = enc.input_ids[0].cpu().tolist() for s, e, cat in sorted(spans): text = tokenizer.decode(tok_ids[s:e]).strip() A(f" [{s:3d}, {e:3d}) {cat:<28s} {text!r}") del model if torch.cuda.is_available(): torch.cuda.empty_cache() path.write_text("\n".join(lines) + "\n") print(f"[log] wrote {path}", flush=True) # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main(): p = argparse.ArgumentParser(description="Benchmark haremb-privacy-filter-opennemo") p.add_argument("--device", default=None, help="cuda or cpu. Default: cuda if available.") p.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"]) p.add_argument("--eval-pct", type=float, default=1.0, help="Percent of nvidia/Nemotron-PII test split to use. Default 1%%.") p.add_argument("--eval-chunk-size", type=int, default=10_000) p.add_argument("--seed", type=int, default=42) p.add_argument("--max-length", type=int, default=1024) p.add_argument("--batch-size", type=int, default=4) p.add_argument("--out", type=str, default=".") p.add_argument("--model-path", type=str, default=".", help="Path to this checkpoint. Default: ./ (this folder).") p.add_argument("--no-base", action="store_true", help="Skip the OpenMed teacher comparison.") p.add_argument("--no-plots", action="store_true", help="Skip rendering eval_summary.png / eval_confusion.png.") args = p.parse_args() out_dir = Path(args.out).resolve() out_dir.mkdir(parents=True, exist_ok=True) if args.device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: device = torch.device(args.device) dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype] print(f"[setup] device={device} dtype={dtype} model={args.model_path} " f"out={out_dir}", flush=True) label2id, id2label = nemotron_native_label_space() o_id = label2id["O"] tokenizer = AutoTokenizer.from_pretrained(args.model_path) pad_token_id = tokenizer.pad_token_id or 199999 # Build the eval set (same slice the README headline numbers reference) print(f"[data] loading {SOURCE_DATASET} ...", flush=True) ds = load_dataset(SOURCE_DATASET) target_eval = max(1, int(round(len(ds["test"]) * args.eval_pct / 100.0))) eval_indices = _build_eval_streaming( ds["test"], target_n=target_eval, chunk_size=args.eval_chunk_size, seed=args.seed, ) eval_ds = _NemotronEvalDataset( ds["test"].select(eval_indices), tokenizer, label2id, args.max_length, ) args.n_docs = len(eval_ds) print(f"[data] eval={args.n_docs:,} docs ({args.eval_pct:.2f}% of test split)", flush=True) # BIOES decoding masks (used for the explicit RAW vs VITERBI streams). from modeling_haremb_pii import ( _build_bioes_initial_mask as _bld_init, _build_bioes_transition_mask as _bld_trans, ) bioes_trans = _bld_trans(id2label).to(device).float() bioes_init = _bld_init(id2label).to(device).float() # Eval candidate cand = _eval_one_model( args.model_path, tokenizer, eval_ds, label2id, id2label, o_id, bioes_trans, bioes_init, args.batch_size, args.max_length, device, dtype, label="haremb", ) print(f"\n=== {cand['label']} ===") print(f"RAW {_fmt_metrics(cand['raw'])}") print(f"VITERBI {_fmt_metrics(cand['viterbi'])}") print(f"DELTA span_F1={cand['viterbi']['span_f1']-cand['raw']['span_f1']:+.4f} " f"P={cand['viterbi']['span_precision']-cand['raw']['span_precision']:+.4f} " f"R={cand['viterbi']['span_recall']-cand['raw']['span_recall']:+.4f}") ref = None pair = None if not args.no_base: ref = _eval_one_model( TEACHER, tokenizer, eval_ds, label2id, id2label, o_id, bioes_trans, bioes_init, args.batch_size, args.max_length, device, dtype, label="openmed-base", ) print(f"\n=== {ref['label']} (teacher) ===") print(f"VITERBI {_fmt_metrics(ref['viterbi'])}") pair = _pairwise(cand["docs"], ref["docs"], id2label, o_id) d = ref["viterbi"]["span_f1"] - cand["viterbi"]["span_f1"] print(f"\nGap to teacher (viterbi span_F1): {d:+.4f}") # Reports sample_text = ("Patient Sarah Johnson (DOB 03/15/1985), MRN 4872910, " "phone 415-555-0123, email sarah.johnson@example.com, " "credit card 4111-1111-1111-1111.") _write_infer_log(out_dir / "infer.log", cand, ref, args, sample_text, tokenizer, device, dtype) _write_compare_log(out_dir / "compare.log", cand, ref, pair, args) if not args.no_plots: _render_plots(cand, ref, pair, out_dir, cand_label=cand["label"], ref_label=ref["label"] if ref else "") print("\n[done]") if __name__ == "__main__": main()