# Vendored from the "haremb-privacy-filter-opennemo" Hugging Face model repo
# (uploaded by fblgit via huggingface_hub, revision f0f5785).
"""
benchmark.py — single self-contained reproducibility script for
haremb-privacy-filter-opennemo.
Run from inside this folder:
python benchmark.py # default: cuda if available
python benchmark.py --device cpu # cpu fallback
python benchmark.py --eval-pct 0.5 # smaller slice
python benchmark.py --no-base # skip teacher download
Produces, in `--out` (default ./):
infer.log — sample inference timing + redaction example
compare.log — aggregate + per-category metrics, this model vs
OpenMed teacher (raw + viterbi streams), and
token-level pairwise breakdown.
eval_summary.png — bar charts of headline metrics + per-category
span-F1 (this vs teacher).
eval_confusion.png — token-level outcome breakdown on gold non-O
positions (this vs teacher).
eval_performance.png — model-size / compute / memory / throughput
comparison (this vs teacher), absolute + ratios.
This script does not import from training code. It vendors the small set
of helpers it needs (BIOES decoder, span builder, eval-set sampler,
metrics aggregator) so the model folder is self-contained.
"""
from __future__ import annotations
import argparse
import ast
import math
import os
import sys
import time
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from transformers import AutoModelForTokenClassification, AutoTokenizer
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
SOURCE_DATASET = "nvidia/Nemotron-PII"
TEACHER = "OpenMed/privacy-filter-nemotron"
# 55 Nemotron-PII categories, alphabetically sorted (matches the order used
# when the model was trained, so id2label / label2id round-trip cleanly).
NEMOTRON_CATEGORIES: List[str] = sorted([
    "account_number", "age", "api_key", "bank_routing_number",
    "biometric_identifier", "blood_type", "certificate_license_number",
    "city", "company_name", "coordinate", "country", "county",
    "credit_debit_card", "customer_id", "cvv", "date", "date_of_birth",
    "date_time", "device_identifier", "education_level", "email",
    "employee_id", "employment_status", "fax_number", "first_name",
    "gender", "health_plan_beneficiary_number", "http_cookie", "ipv4",
    "ipv6", "language", "last_name", "license_plate", "mac_address",
    "medical_record_number", "national_id", "occupation", "password",
    "phone_number", "pin", "political_view", "postcode", "race_ethnicity",
    "religious_belief", "sexuality", "ssn", "state", "street_address",
    "swift_bic", "tax_id", "time", "unique_id", "url", "user_name",
    "vehicle_identifier",
])


def nemotron_native_label_space() -> Tuple[Dict[str, int], Dict[int, str]]:
    """Build the native BIOES label space: "O" gets id 0, then the four
    prefixed tags B/I/E/S for each category, in alphabetical category order.

    Returns (label2id, id2label), exact inverses of each other.
    """
    tags = ["O"] + [
        f"{prefix}-{cat}"
        for cat in NEMOTRON_CATEGORIES
        for prefix in ("B", "I", "E", "S")
    ]
    label2id: Dict[str, int] = {tag: idx for idx, tag in enumerate(tags)}
    id2label: Dict[int, str] = dict(enumerate(tags))
    return label2id, id2label
# ---------------------------------------------------------------------------
# Span parsing + char-level token alignment (vendored from the training data
# pipeline; identical logic, no training imports)
# ---------------------------------------------------------------------------
def _trim_span(text: str, start: int, end: int) -> Tuple[int, int]:
raw = text[start:end]
i = 0
while i < len(raw) and raw[i].isspace():
i += 1
j = len(raw)
while j > i and (raw[j - 1].isspace() or raw[j - 1] in ".,;:)"):
j -= 1
return start + i, start + j
def _parse_spans(spans_str) -> List[dict]:
if isinstance(spans_str, list):
return spans_str
try:
return ast.literal_eval(spans_str)
except (SyntaxError, ValueError):
return []
def _assign_native_bioes_labels(
    text: str,
    raw_spans: List[dict],
    tokenizer,
    max_length: int,
    label2id: Dict[str, int],
    min_overlap_frac: float = 0.5,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Tokenize `text` and project char-level gold spans onto BIOES token labels.

    Returns (input_ids, attention_mask, label_ids) as 1-D tensors of length
    `max_length`; positions where attention_mask == 0 get label -100 so
    loss/metric code skips them.

    A token is claimed by a span when at least `min_overlap_frac` of the
    token's characters fall inside the span. Spans are processed in
    (start ascending, end descending) order and tokens are locked once
    labeled, so at equal start the longer span wins.
    """
    # Normalize raw span dicts to (start, end, category): trim whitespace /
    # trailing punctuation, drop unlabeled or empty-after-trim spans.
    pf_like: List[Tuple[int, int, str]] = []
    for s in raw_spans:
        cat = s.get("label")
        if not cat:
            continue
        st, en = _trim_span(text, int(s["start"]), int(s["end"]))
        if st >= en:
            continue
        pf_like.append((st, en, cat))
    # Start ascending, then longer spans first — combined with the `locked`
    # flags below this gives longest-span precedence at each position.
    pf_like.sort(key=lambda x: (x[0], -x[1]))
    enc = tokenizer(
        text, truncation=True, max_length=max_length,
        padding="max_length", return_offsets_mapping=True, return_tensors="pt",
    )
    input_ids = enc.input_ids[0]
    attention_mask = enc.attention_mask[0]
    offsets = enc.offset_mapping[0].tolist()
    o_id = label2id["O"]
    label_ids = [o_id] * len(input_ids)
    locked = [False] * len(input_ids)
    for span_start, span_end, cat in pf_like:
        # Collect token indices whose char range overlaps the span enough.
        tok_indices: List[int] = []
        for ti, (s, e) in enumerate(offsets):
            if s == 0 and e == 0:
                # (0, 0) offsets mark special / padding tokens.
                continue
            if e <= span_start or s >= span_end:
                continue
            tok_len = e - s
            if tok_len <= 0:
                continue
            overlap = min(e, span_end) - max(s, span_start)
            if overlap / tok_len >= min_overlap_frac:
                tok_indices.append(ti)
        if not tok_indices:
            continue
        # Drop tokens already claimed by an earlier (higher-priority) span.
        tok_indices = [ti for ti in tok_indices if not locked[ti]]
        if not tok_indices:
            continue
        if len(tok_indices) == 1:
            # Single-token entity -> S tag.
            tag = f"S-{cat}"
            if tag in label2id:
                label_ids[tok_indices[0]] = label2id[tag]
                locked[tok_indices[0]] = True
        else:
            # Multi-token entity -> B ... I ... E tags.
            b_tag, i_tag, e_tag = f"B-{cat}", f"I-{cat}", f"E-{cat}"
            if b_tag in label2id:
                label_ids[tok_indices[0]] = label2id[b_tag]
                locked[tok_indices[0]] = True
            for ti in tok_indices[1:-1]:
                if i_tag in label2id:
                    label_ids[ti] = label2id[i_tag]
                    locked[ti] = True
            if e_tag in label2id:
                label_ids[tok_indices[-1]] = label2id[e_tag]
                locked[tok_indices[-1]] = True
    label_tensor = torch.tensor(label_ids, dtype=torch.long)
    # Padding positions are excluded from loss / metrics.
    label_tensor[attention_mask == 0] = -100
    return input_ids, attention_mask, label_tensor
class _NemotronEvalDataset(Dataset):
    """Wraps an HF split; tokenizes and BIOES-labels one example per item.

    Each item is returned un-padded (truncated to the attention-mask length)
    so the collate function can re-pad to the batch maximum.
    """

    def __init__(self, hf_split, tokenizer, label2id, max_length):
        self.hf = hf_split
        self.tok = tokenizer
        self.l2i = label2id
        self.maxlen = max_length

    def __len__(self):
        return len(self.hf)

    def __getitem__(self, idx):
        example = self.hf[idx]
        input_ids, attention_mask, labels = _assign_native_bioes_labels(
            example["text"],
            _parse_spans(example["spans"]),
            self.tok,
            self.maxlen,
            self.l2i,
        )
        valid_len = int(attention_mask.sum().item())
        return {
            "input_ids": input_ids[:valid_len].tolist(),
            "labels": labels[:valid_len].tolist(),
            "valid_len": valid_len,
        }
def _make_collate(pad_token_id, max_length):
def _c(batch):
ids_list = [list(ex["input_ids"])[:max_length] for ex in batch]
labels_list = [list(ex["labels"])[:max_length] for ex in batch]
max_len = max(len(x) for x in ids_list)
B = len(batch)
input_ids = torch.full((B, max_len), pad_token_id, dtype=torch.long)
attention_mask = torch.zeros((B, max_len), dtype=torch.long)
labels = torch.full((B, max_len), -100, dtype=torch.long)
for i, (ids, lab) in enumerate(zip(ids_list, labels_list)):
L = len(ids)
input_ids[i, :L] = torch.tensor(ids, dtype=torch.long)
attention_mask[i, :L] = 1
labels[i, :L] = torch.tensor(lab, dtype=torch.long)
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
return _c
def _build_eval_streaming(test_split, target_n, chunk_size, seed) -> List[int]:
"""Uniform chunked sampling, identical to the training-time eval split."""
n_total = len(test_split)
target_n = min(target_n, n_total)
if target_n <= 0:
return []
rng = np.random.RandomState(seed)
per_chunk = max(1, math.ceil(chunk_size * target_n / n_total))
selected: List[int] = []
for chunk_start in range(0, n_total, chunk_size):
if len(selected) >= target_n:
break
chunk_end = min(chunk_start + chunk_size, n_total)
n_in_chunk = chunk_end - chunk_start
n_to_pick = min(per_chunk, n_in_chunk, target_n - len(selected))
if n_to_pick <= 0:
break
offsets = rng.choice(n_in_chunk, size=n_to_pick, replace=False)
selected.extend(int(chunk_start + o) for o in offsets)
return sorted(selected[:target_n])
# ---------------------------------------------------------------------------
# BIOES → spans + metrics
# ---------------------------------------------------------------------------
def _bioes_to_spans(labels, id2label, o_id=0):
"""Convert per-token BIOES label ids to a set of (start, end, cat)."""
spans = set()
cur_start = None
cur_cat = None
for i, lid in enumerate(labels):
lid = int(lid)
if lid == o_id or lid < 0:
if cur_start is not None:
spans.add((cur_start, i, cur_cat))
cur_start = None
cur_cat = None
continue
tag = id2label.get(lid, "O")
if tag == "O" or "-" not in tag:
if cur_start is not None:
spans.add((cur_start, i, cur_cat))
cur_start = None
cur_cat = None
continue
prefix, cat = tag.split("-", 1)
if prefix == "S":
if cur_start is not None:
spans.add((cur_start, i, cur_cat))
spans.add((i, i + 1, cat))
cur_start = None
cur_cat = None
elif prefix == "B":
if cur_start is not None:
spans.add((cur_start, i, cur_cat))
cur_start = i
cur_cat = cat
elif prefix == "I":
if cur_start is None or cur_cat != cat:
if cur_start is not None:
spans.add((cur_start, i, cur_cat))
cur_start = i
cur_cat = cat
elif prefix == "E":
if cur_start is None or cur_cat != cat:
if cur_start is not None:
spans.add((cur_start, i, cur_cat))
spans.add((i, i + 1, cat))
cur_start = None
cur_cat = None
else:
spans.add((cur_start, i + 1, cur_cat))
cur_start = None
cur_cat = None
if cur_start is not None:
spans.add((cur_start, len(labels), cur_cat))
return spans
def _aggregate_span_metrics(gold_spans, pred_spans):
correct = gold_spans & pred_spans
n_gold = len(gold_spans)
n_pred = len(pred_spans)
n_correct = len(correct)
p = n_correct / n_pred if n_pred else 0.0
r = n_correct / n_gold if n_gold else 0.0
f1 = (2 * p * r / (p + r)) if (p + r) else 0.0
per_cat: Dict[str, dict] = {}
cats = sorted({c for _, _, c in gold_spans} | {c for _, _, c in pred_spans})
for cat in cats:
g_c = {s for s in gold_spans if s[2] == cat}
p_c = {s for s in pred_spans if s[2] == cat}
c_c = g_c & p_c
pp = len(c_c) / len(p_c) if p_c else 0.0
rr = len(c_c) / len(g_c) if g_c else 0.0
ff = (2 * pp * rr / (pp + rr)) if (pp + rr) else 0.0
per_cat[cat] = {"precision": pp, "recall": rr, "f1": ff,
"n_gold": len(g_c), "n_pred": len(p_c), "n_correct": len(c_c)}
return {"precision": p, "recall": r, "f1": f1,
"n_gold": n_gold, "n_pred": n_pred, "n_correct": n_correct,
"per_cat": per_cat}
def _stream_metrics(docs, stream, id2label, o_id):
    """Aggregate token- and span-level metrics for one prediction stream.

    `docs` is a list of per-doc dicts holding parallel label-id lists under
    "gold" and under `stream` (e.g. "raw" or "viterbi").
    """
    n_tokens = 0
    n_correct = 0
    n_non_o = 0
    n_non_o_correct = 0
    all_gold: set = set()
    all_pred: set = set()
    offset = 0
    for doc in docs:
        gold = [int(x) for x in doc["gold"]]
        pred = [int(x) for x in doc[stream]]
        n_tokens += len(gold)
        for g, p in zip(gold, pred):
            if g != o_id:
                n_non_o += 1
            if g == p:
                n_correct += 1
                if g != o_id:
                    n_non_o_correct += 1
        # Shift span positions by a per-document offset so identical spans
        # from different documents stay distinct in the global sets.
        for s, e, c in _bioes_to_spans(gold, id2label, o_id):
            all_gold.add((offset + s, offset + e, c))
        for s, e, c in _bioes_to_spans(pred, id2label, o_id):
            all_pred.add((offset + s, offset + e, c))
        offset += len(gold)
    span_m = _aggregate_span_metrics(all_gold, all_pred)
    return {
        "n_tokens": n_tokens,
        "n_non_o": n_non_o,
        "token_acc": n_correct / n_tokens if n_tokens else 0.0,
        "non_o_recall": n_non_o_correct / n_non_o if n_non_o else 0.0,
        "span_precision": span_m["precision"],
        "span_recall": span_m["recall"],
        "span_f1": span_m["f1"],
        "n_gold_spans": span_m["n_gold"],
        "n_pred_spans": span_m["n_pred"],
        "n_correct_spans": span_m["n_correct"],
        "span_per_cat": span_m["per_cat"],
    }
# ---------------------------------------------------------------------------
# Forward pass + viterbi (delegates to the released modeling)
# ---------------------------------------------------------------------------
def _model_perf_stats(model, dtype) -> dict:
"""Total / active / compute / MoE breakdown + on-device byte size.
Three distinct param counts:
* total_params — every parameter the model has on disk / in RAM.
* active_params_per_tok — params *touched* during one token's forward
pass (memory footprint per token). Counts
the embedding because the embedding row
IS read per token; counts only top-k of
num_experts MoE experts because routing is
sparse.
* compute_params_per_tok — params that contribute matmul FLOPs per
token. EXCLUDES the embedding table:
`embed_tokens.weight` is a gather (one row
read), not a matmul, so its FLOP cost is
negligible (~hidden_size ops vs the table
having ~vocab × hidden params). Counting
it via the standard "2 × params" matmul
approximation hugely inflates the apparent
GFLOP/token and compresses the ratio between
deep and shallow models.
GFLOP/token is computed from `compute_params_per_tok`, not from
`active_params_per_tok`. This makes the metric reflect actual layer-wise
computational cost.
"""
cfg = model.config
num_experts = int(getattr(cfg, "num_local_experts", 1))
top_k = int(getattr(cfg, "num_experts_per_tok", num_experts))
expert_frac = top_k / max(1, num_experts)
moe_total = 0
moe_active = 0
other_total = 0
embed_total = 0 # gather-only params; excluded from FLOP estimate
for name, p in model.named_parameters():
n = p.numel()
# MoE expert tensors are stacked along an experts axis. The upstream
# exposes them under `mlp.experts.*`. Only `top_k` of `num_experts`
# experts contribute per token.
if ".mlp.experts." in name:
moe_total += n
moe_active += int(round(n * expert_frac))
# `embed_tokens.weight` (and any other lookup-style table) is a
# gather: one row of [vocab, hidden] is read per token, costing
# ~hidden ops, not 2 × vocab × hidden FLOPs. Tracked separately
# so it doesn't pollute the FLOP estimate.
elif "embed_tokens" in name:
embed_total += n
other_total += n
else:
other_total += n
total = moe_total + other_total
active_per_tok = moe_active + other_total
# Compute-relevant params per token: matmuls only. Drop the embedding
# lookup, which contributes ~zero FLOPs.
compute_per_tok = active_per_tok - embed_total
bytes_per_param = {torch.bfloat16: 2, torch.float16: 2, torch.float32: 4}.get(dtype, 4)
storage_bytes = total * bytes_per_param # in-memory dense weights
# GFLOP/token over matmul params only. Matmul ≈ 2 FLOPs per param.
gflops_per_tok = 2 * compute_per_tok / 1e9
return {
"total_params": total,
"moe_total": moe_total,
"moe_active_per_tok": moe_active,
"other_total": other_total,
"embed_total": embed_total,
"active_params_per_tok": active_per_tok,
"compute_params_per_tok": compute_per_tok,
"num_experts": num_experts,
"experts_per_tok": top_k,
"expert_frac": expert_frac,
"weight_bytes": storage_bytes,
"gflops_per_tok": gflops_per_tok,
}
def _disk_size_bytes(model_path: str) -> int:
"""Sum on-disk size of weight files at the given path. Falls back to 0
if the path is a HF repo id (not a local directory)."""
p = Path(model_path)
if not p.is_dir():
return 0
total = 0
for f in p.iterdir():
if f.is_file() and f.suffix in {".safetensors", ".bin", ".pt", ".pth"}:
total += f.stat().st_size
return total
def _eval_one_model(
    model_path: str, tokenizer, eval_ds, label2id, id2label, o_id,
    bioes_trans, bioes_init, batch_size, max_length, device, dtype,
    label: str,
):
    """Load one token-classification model and score it on `eval_ds`.

    Returns a dict with performance stats ("perf"), raw / viterbi stream
    metrics, load and eval wall-times, throughput, and the per-doc
    predictions ("docs") used later for pairwise comparison. CUDA memory
    counters are used when the target device is a CUDA device.
    """
    print(f"[eval] loading {label} from {model_path} ...", flush=True)
    use_cuda = torch.cuda.is_available() and device.type == "cuda"
    if use_cuda:
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.empty_cache()
    mem_before = torch.cuda.memory_allocated(device) if use_cuda else 0
    t_load = time.time()
    model = AutoModelForTokenClassification.from_pretrained(
        model_path, dtype=dtype, trust_remote_code=True,
    ).to(device).eval()
    # Ask the remote-code model for viterbi output without replacing the raw
    # logits, so both streams can be scored side by side.
    if hasattr(model.config, "use_viterbi_decode"):
        model.config.use_viterbi_decode = True
    if hasattr(model.config, "viterbi_replace_logits"):
        model.config.viterbi_replace_logits = False
    load_s = time.time() - t_load
    perf = _model_perf_stats(model, dtype)
    perf["disk_size_bytes"] = _disk_size_bytes(model_path)
    weights_resident_bytes = (
        torch.cuda.memory_allocated(device) - mem_before
        if use_cuda else perf["weight_bytes"]
    )
    perf["weights_resident_bytes"] = weights_resident_bytes
    # Vendor the batched viterbi (re-import from the released modeling file).
    from modeling_haremb_pii import _bioes_viterbi_batched
    # BUGFIX: the previous `tokenizer.pad_token_id or 199999` silently
    # replaced a legitimate pad id of 0 (falsy) with the fallback; only
    # substitute the fallback when the tokenizer truly has no pad id.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 199999
    loader = DataLoader(
        eval_ds, batch_size=batch_size, shuffle=False,
        collate_fn=_make_collate(pad_token_id, max_length), num_workers=2,
    )
    docs: List[dict] = []
    n_tok = 0
    if use_cuda:
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats(device)
    t0 = time.time()
    for batch in tqdm(loader, desc=f"eval {label}", unit="batch", leave=False):
        ids = batch["input_ids"].to(device, non_blocking=True)
        mask = batch["attention_mask"].to(device, non_blocking=True)
        gold = batch["labels"].to(device, non_blocking=True)
        with torch.no_grad():
            out = model(input_ids=ids, attention_mask=mask)
        raw = out.logits.argmax(dim=-1)
        vit = _bioes_viterbi_batched(out.logits.float(), mask, bioes_trans, bioes_init)
        # Score only positions that are real tokens with a gold label.
        valid = (gold != -100) & mask.bool()
        for b in range(gold.shape[0]):
            keep = [i for i, ok in enumerate(valid[b].cpu().tolist()) if ok]
            n_tok += len(keep)
            docs.append({
                "gold": [int(gold[b, i].item()) for i in keep],
                "raw": [int(raw[b, i].item()) for i in keep],
                "viterbi": [int(vit[b, i].item()) for i in keep],
            })
    if use_cuda:
        torch.cuda.synchronize()
        peak_mem = torch.cuda.max_memory_allocated(device)
    else:
        peak_mem = 0
    eval_s = time.time() - t0
    raw_m = _stream_metrics(docs, "raw", id2label, o_id)
    vit_m = _stream_metrics(docs, "viterbi", id2label, o_id)
    perf["peak_eval_mem_bytes"] = peak_mem
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return {
        "label": label,
        "n_total_M": perf["total_params"] / 1e6,
        "load_s": load_s,
        "eval_s": eval_s,
        "n_tok": n_tok,
        "throughput_tok_s": n_tok / eval_s if eval_s else 0.0,
        "perf": perf,
        "raw": raw_m,
        "viterbi": vit_m,
        "docs": docs,
    }
# ---------------------------------------------------------------------------
# Pairwise token-level breakdown (this vs reference, viterbi stream)
# ---------------------------------------------------------------------------
def _pairwise(docs_cand, docs_ref, id2label, o_id):
both_correct = only_cand = only_ref = both_wrong = 0
by_cat = defaultdict(lambda: {"both_correct": 0, "only_cand": 0, "only_ref": 0, "both_wrong": 0})
for dc, dr in zip(docs_cand, docs_ref):
gold = dc["gold"]
cv = dc["viterbi"]
rv = dr["viterbi"]
for g, c, r in zip(gold, cv, rv):
cat = id2label.get(g, "O")
cat = cat.split("-", 1)[1] if "-" in cat else cat
cc = (c == g)
rc = (r == g)
if cc and rc:
both_correct += 1
by_cat[cat]["both_correct"] += 1
elif cc and not rc:
only_cand += 1
by_cat[cat]["only_cand"] += 1
elif rc and not cc:
only_ref += 1
by_cat[cat]["only_ref"] += 1
else:
both_wrong += 1
by_cat[cat]["both_wrong"] += 1
return {
"both_correct": both_correct,
"only_cand_correct": only_cand,
"only_ref_correct": only_ref,
"both_wrong": both_wrong,
"by_cat": dict(by_cat),
}
# ---------------------------------------------------------------------------
# Plot rendering
# ---------------------------------------------------------------------------
def _render_plots(cand, ref, pair, out_dir: Path, cand_label, ref_label):
    """Render benchmark plots.

    Writes up to three PNGs into `out_dir`: eval_summary.png (headline and
    per-category metrics), eval_confusion.png (pairwise token outcomes,
    only when `pair` is given), and eval_performance.png (size / compute /
    memory / throughput, only when `ref` is given).

    Visual convention:
    A = reference / teacher / baseline
    B = candidate / this checkpoint
    The charts avoid color-only "win" encoding: labels state the actual delta
    or ratio, and horizontal layouts keep long metric/category names readable.

    `ref` and `pair` may be None; the affected panels/figures are skipped.
    Silently returns if matplotlib is unavailable.
    """
    try:
        import matplotlib
        matplotlib.use("Agg")  # headless backend: write files, no display
        import matplotlib.pyplot as plt
        from matplotlib.ticker import FuncFormatter
    except ImportError:
        print("[plot] matplotlib not installed, skipping", flush=True)
        return
    plt.rcParams.update({
        "figure.facecolor": "#ffffff",
        "axes.facecolor": "#ffffff",
        "axes.edgecolor": "#cbd5e1",
        "axes.labelcolor": "#0f172a",
        "xtick.color": "#334155",
        "ytick.color": "#334155",
        "grid.color": "#e2e8f0",
        "font.size": 9,
        "axes.titleweight": "bold",
        "axes.titlesize": 11,
        "legend.frameon": False,
    })
    C_REF = "#64748b"  # slate
    C_CAND = "#2563eb"  # blue
    C_GOOD = "#0f766e"  # teal
    C_BAD = "#b91c1c"  # red
    C_NEUTRAL = "#94a3b8"  # light slate
    C_BG = "#f8fafc"

    def _pct_axis(ax):
        # Format the x axis as whole percentages.
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _pos: f"{x:.0%}"))

    def _metric_delta_text(delta):
        return f"{delta:+.4f}"

    def _value_text(v):
        # Adaptive precision so labels stay short across magnitudes.
        if v >= 100:
            return f"{v:,.0f}"
        if v >= 10:
            return f"{v:,.1f}"
        if v >= 1:
            return f"{v:,.2f}"
        return f"{v:,.3f}"

    def _ratio_text(r, lower_is_better):
        if r <= 0:
            return "n/a"
        if lower_is_better:
            return f"{1.0 / r:.2f}x lower" if r <= 1 else f"{r:.2f}x higher"
        return f"{r:.2f}x higher" if r >= 1 else f"{1.0 / r:.2f}x lower"

    # --- eval_summary.png: headline metrics + category-level deltas ---
    fig, axes = plt.subplots(2, 2, figsize=(8, 5), constrained_layout=True)
    ax_head = axes[0, 0]
    ax_delta = axes[0, 1]
    ax_raw_vit = axes[1, 0]
    ax_cat = axes[1, 1]
    metrics = ["span_f1", "span_precision", "span_recall", "token_acc", "non_o_recall"]
    labels = ["Span F1", "Span P", "Span R", "Token acc", "Non-O recall"]
    cand_v = [cand["viterbi"][m] for m in metrics]
    ref_v = [ref["viterbi"][m] for m in metrics] if ref is not None else None
    y = np.arange(len(metrics))
    if ref_v is not None:
        # Dumbbell: connector between A and B, A dot underneath.
        ax_head.hlines(y, ref_v, cand_v, color=C_NEUTRAL, linewidth=2, alpha=0.9)
        ax_head.scatter(ref_v, y, s=55, color=C_REF, label=f"A: {ref_label}", zorder=3)
    ax_head.scatter(cand_v, y, s=70, color=C_CAND, label=f"B: {cand_label}", zorder=4)
    ax_head.set_yticks(y)
    ax_head.set_yticklabels(labels)
    ax_head.invert_yaxis()
    ax_head.set_xlim(max(0.0, min(cand_v + (ref_v or cand_v)) - 0.02),
                     min(1.08, max(cand_v + (ref_v or cand_v)) + 0.05))
    ax_head.set_title("Headline metrics, Viterbi stream")
    ax_head.grid(axis="x", alpha=0.7)
    _pct_axis(ax_head)
    if ref_v is not None:
        for i, v in enumerate(ref_v):
            ax_head.text(v + 0.002, i, f"{v:.4f}", va="center", ha="left",
                         fontsize=7, color=C_REF)
    for i, v in enumerate(cand_v):
        ax_head.text(v - 0.002, i, f"{v:.4f}", va="center", ha="right",
                     fontsize=7, color=C_CAND)
    if ref_v is not None:
        deltas = [b - a for a, b in zip(ref_v, cand_v)]
        colors = [C_GOOD if d >= 0 else C_BAD for d in deltas]
        ax_delta.axvline(0, color="#0f172a", linewidth=0.9)
        ax_delta.barh(y, deltas, color=colors, alpha=0.9)
        ax_delta.set_yticks(y)
        ax_delta.set_yticklabels(labels)
        ax_delta.invert_yaxis()
        ax_delta.set_title("Delta: B minus A")
        ax_delta.grid(axis="x", alpha=0.7)
        max_abs = max([abs(d) for d in deltas] + [0.002])
        ax_delta.set_xlim(-max_abs * 1.45, max_abs * 1.45)
        for i, d in enumerate(deltas):
            ax_delta.text(d + (max_abs * 0.04 if d >= 0 else -max_abs * 0.04), i, _metric_delta_text(d),
                          ha="left" if d >= 0 else "right", va="center", fontsize=7)
    else:
        ax_delta.axis("off")
    stream_rows = [
        ("A raw", ref["raw"]["span_f1"] if ref is not None else None, C_REF),
        ("A viterbi", ref["viterbi"]["span_f1"] if ref is not None else None, C_REF),
        ("B raw", cand["raw"]["span_f1"], C_CAND),
        ("B viterbi", cand["viterbi"]["span_f1"], C_CAND),
    ]
    # Drop A rows when there is no reference model.
    stream_rows = [r for r in stream_rows if r[1] is not None]
    sy = np.arange(len(stream_rows))
    ax_raw_vit.barh(sy, [r[1] for r in stream_rows], color=[r[2] for r in stream_rows], alpha=0.88)
    ax_raw_vit.set_yticks(sy)
    ax_raw_vit.set_yticklabels([r[0] for r in stream_rows])
    ax_raw_vit.invert_yaxis()
    ax_raw_vit.set_xlim(0, 1.08)
    ax_raw_vit.set_title("Raw vs Viterbi span F1")
    ax_raw_vit.grid(axis="x", alpha=0.7)
    _pct_axis(ax_raw_vit)
    for i, (_, v, _) in enumerate(stream_rows):
        ax_raw_vit.text(v + 0.008, i, f"{v:.4f}", va="center", fontsize=7)
    cand_pc = cand["viterbi"]["span_per_cat"]
    if ref is not None:
        ref_pc = ref["viterbi"]["span_per_cat"]
        cats = sorted(set(cand_pc) | set(ref_pc))
        rows = []
        for c in cats:
            a = ref_pc.get(c, {}).get("f1", 0.0)
            b = cand_pc.get(c, {}).get("f1", 0.0)
            n = max(cand_pc.get(c, {}).get("n_gold", 0), ref_pc.get(c, {}).get("n_gold", 0))
            rows.append((c, b - a, b, a, n))
        # Keep the categories that explain the comparison: worst and best B deltas.
        worst = sorted(rows, key=lambda r: r[1])[:8]
        best = sorted(rows, key=lambda r: r[1], reverse=True)[:8]
        picked = worst + [r for r in best if r[0] not in {x[0] for x in worst}]
        picked = sorted(picked, key=lambda r: r[1])
        cy = np.arange(len(picked))
        deltas = [r[1] for r in picked]
        ax_cat.axvline(0, color="#0f172a", linewidth=0.9)
        ax_cat.barh(cy, deltas, color=[C_GOOD if d >= 0 else C_BAD for d in deltas])
        ax_cat.set_yticks(cy)
        ax_cat.set_yticklabels([r[0] for r in picked], fontsize=8)
        ax_cat.set_title("Per-category span F1 delta, selected extremes")
        ax_cat.grid(axis="x", alpha=0.7)
        max_abs = max([abs(d) for d in deltas] + [0.05])
        ax_cat.set_xlim(-max_abs * 1.55, max_abs * 1.55)
        for i, r in enumerate(picked):
            d = r[1]
            ax_cat.text(d + (max_abs * 0.05 if d >= 0 else -max_abs * 0.05), i,
                        f"{d:+.3f} B={r[2]:.2f} A={r[3]:.2f}",
                        va="center", ha="left" if d >= 0 else "right", fontsize=6)
    else:
        # No reference: show the candidate's weakest categories instead.
        cats_sorted = sorted(cand_pc.keys(), key=lambda c: cand_pc[c]["f1"])[:18]
        vals = [cand_pc[c]["f1"] for c in cats_sorted]
        cy = np.arange(len(cats_sorted))
        ax_cat.barh(cy, vals, color=C_CAND)
        ax_cat.set_yticks(cy)
        ax_cat.set_yticklabels(cats_sorted, fontsize=8)
        ax_cat.set_xlim(0, 1.0)
        ax_cat.set_title("Lowest per-category span F1")
        ax_cat.grid(axis="x", alpha=0.7)
        _pct_axis(ax_cat)
    fig.suptitle(f"Evaluation summary — A: {ref_label if ref else 'n/a'} | B: {cand_label}",
                 fontsize=9, fontweight="bold")
    fig.savefig(out_dir / "eval_summary.png", dpi=160)
    plt.close(fig)
    print(f"[plot] wrote {out_dir / 'eval_summary.png'}", flush=True)
    # --- eval_confusion.png: pairwise outcome on gold non-O tokens ---
    # Display order matches A vs B: "Only A correct" (teacher) before
    # "Only B correct" (student). Underlying buckets in `pair` are still
    # named cand/ref; we just relabel for display.
    if pair is not None:
        fig, axes = plt.subplots(
            1, 2, figsize=(8, 3), constrained_layout=True,
            gridspec_kw={"width_ratios": [0.9, 1.7]},
        )
        ax = axes[0]
        # Sum buckets over all non-O gold categories.
        non_o_buckets = {k: 0 for k in ["both_correct", "only_cand", "only_ref", "both_wrong"]}
        for cat, d in pair["by_cat"].items():
            if cat == "O":
                continue
            for k in non_o_buckets:
                non_o_buckets[k] += d[k]
        values = [non_o_buckets["both_correct"], non_o_buckets["only_ref"],
                  non_o_buckets["only_cand"], non_o_buckets["both_wrong"]]
        labels_ = [
            "Both\ncorrect",
            "Only A\ncorrect",
            "Only B\ncorrect",
            "Both wrong",
        ]
        colors = [C_GOOD, C_REF, C_CAND, C_BAD]
        total = max(1, sum(values))
        ax.barh(np.arange(4), values, color=colors)
        ax.set_yticks(np.arange(4))
        ax.set_yticklabels(labels_)
        ax.invert_yaxis()
        ax.set_ylabel("Gold non-O tokens")
        ax.set_title("Token outcome on gold non-O")
        ax.grid(axis="x", alpha=0.7)
        ax.set_xlim(0, max(values) * 1.32 if values else 1)
        for i, v in enumerate(values):
            ax.text(v + max(values) * 0.015, i, f"{v:,} ({v / total:.1%})",
                    va="center", fontsize=6)
        rows = []
        for cat, d in pair["by_cat"].items():
            if cat == "O":
                continue
            net = d["only_cand"] - d["only_ref"]
            # Skip categories where both models agreed everywhere.
            active = d["only_cand"] + d["only_ref"] + d["both_wrong"]
            if active:
                rows.append((cat, net, d["only_cand"], d["only_ref"], d["both_wrong"]))
        worst = sorted(rows, key=lambda r: r[1])[:8]
        best = sorted(rows, key=lambda r: r[1], reverse=True)[:8]
        picked = worst + [r for r in best if r[0] not in {x[0] for x in worst}]
        picked = sorted(picked, key=lambda r: r[1])
        ax2 = axes[1]
        if picked:
            py = np.arange(len(picked))
            nets = [r[1] for r in picked]
            ax2.axvline(0, color="#0f172a", linewidth=0.9)
            ax2.barh(py, nets, color=[C_GOOD if n >= 0 else C_BAD for n in nets])
            ax2.set_yticks(py)
            ax2.set_yticklabels([r[0] for r in picked], fontsize=8)
            ax2.set_title("Net token wins by category: B only-correct minus A only-correct")
            ax2.grid(axis="x", alpha=0.7)
            max_abs = max([abs(n) for n in nets] + [1])
            ax2.set_xlim(-max_abs * 1.5, max_abs * 1.5)
            for i, r in enumerate(picked):
                n = r[1]
                label_x = n + max_abs * 0.04 if n >= 0 else max_abs * 0.05
                ax2.text(label_x, i,
                         f"{n:+d} B={r[2]} A={r[3]} W={r[4]}",
                         va="center", ha="left", fontsize=6)
        else:
            ax2.axis("off")
        fig.suptitle(f"Pairwise correctness — A: {ref_label} | B: {cand_label}",
                     fontsize=9, fontweight="bold")
        fig.savefig(out_dir / "eval_confusion.png", dpi=160)
        plt.close(fig)
        print(f"[plot] wrote {out_dir / 'eval_confusion.png'}", flush=True)
    # --- eval_performance.png: model size, compute, throughput, memory ---
    if ref is not None and "perf" in cand and "perf" in ref:
        cp, rp = cand["perf"], ref["perf"]
        fig, axes = plt.subplots(1, 2, figsize=(8.5, 5), constrained_layout=True)
        ax_abs = axes[0]
        ax_ratio = axes[1]
        # (label, candidate value, reference value, lower-is-better)
        metrics = [
            ("Total params (M)", cp["total_params"]/1e6, rp["total_params"]/1e6, True),
            ("Active params/tok (M)", cp["active_params_per_tok"]/1e6, rp["active_params_per_tok"]/1e6, True),
            ("MoE expert params (M)", cp["moe_total"]/1e6, rp["moe_total"]/1e6, True),
            ("GFLOP/token", cp["gflops_per_tok"], rp["gflops_per_tok"], True),
            ("Weights RAM (MiB)", cp["weight_bytes"]/(1<<20), rp["weight_bytes"]/(1<<20), True),
            ("Peak eval mem (MiB)", cp["peak_eval_mem_bytes"]/(1<<20), rp["peak_eval_mem_bytes"]/(1<<20), True),
            ("Throughput (tok/s)", cand["throughput_tok_s"], ref["throughput_tok_s"], False),
        ]
        y = np.arange(len(metrics))
        w = 0.38
        cand_v = [m[1] for m in metrics]
        ref_v = [m[2] for m in metrics]
        ax_abs.barh(y - w/2, ref_v, w, label=f"A: {ref_label}", color=C_REF)
        ax_abs.barh(y + w/2, cand_v, w, label=f"B: {cand_label}", color=C_CAND)
        ax_abs.set_yticks(y)
        ax_abs.set_yticklabels([m[0] for m in metrics], fontsize=8)
        ax_abs.invert_yaxis()
        ax_abs.set_xscale("log")
        ax_abs.set_title("Absolute footprint and speed, log scale")
        ax_abs.grid(axis="x", which="both", alpha=0.7)
        positive_vals = [v for v in (ref_v + cand_v) if v > 0]
        if positive_vals:
            ax_abs.set_xlim(min(positive_vals) * 0.55, max(positive_vals) * 3.8)
        for yi, vals in enumerate(zip(ref_v, cand_v)):
            for off, v, col in [(-w/2, vals[0], C_REF), (w/2, vals[1], C_CAND)]:
                if v <= 0:
                    # log axis cannot place zero/negative bars' labels
                    continue
                ax_abs.text(v * 1.05, yi + off, _value_text(v), va="center", fontsize=6, color=col)
        ratios = [(m[1] / max(1e-12, m[2])) for m in metrics]
        lower_better = [m[3] for m in metrics]
        colors = [
            C_GOOD if ((lb and r <= 1.0) or ((not lb) and r >= 1.0)) else C_BAD
            for r, lb in zip(ratios, lower_better)
        ]
        ax_ratio.axvline(1.0, color="#0f172a", linestyle="--", linewidth=0.9, alpha=0.65)
        ax_ratio.barh(y, ratios, color=colors)
        ax_ratio.set_yticks(y)
        ax_ratio.set_yticklabels([m[0] for m in metrics], fontsize=8)
        ax_ratio.invert_yaxis()
        ax_ratio.set_xscale("log")
        ax_ratio.set_title("B / A ratio with explicit direction")
        ax_ratio.grid(axis="x", which="both", alpha=0.7)
        positive_ratios = [r for r in ratios if r > 0]
        if positive_ratios:
            ax_ratio.set_xlim(min(positive_ratios) * 0.38, max(positive_ratios) * 2.8)
        for i, (r, lb) in enumerate(zip(ratios, lower_better)):
            ax_ratio.text(r * (1.05 if r >= 1 else 0.95), i, _ratio_text(r, lb),
                          va="center", ha="left" if r >= 1 else "right", fontsize=6)
        fig.suptitle(f"Performance profile — A: {ref_label} | B: {cand_label}",
                     fontsize=9, fontweight="bold")
        fig.savefig(out_dir / "eval_performance.png", dpi=160)
        plt.close(fig)
        print(f"[plot] wrote {out_dir / 'eval_performance.png'}", flush=True)
# ---------------------------------------------------------------------------
# Reporting
# ---------------------------------------------------------------------------
def _fmt_metrics(m):
return (f"span_F1={m['span_f1']:.4f} P={m['span_precision']:.4f} "
f"R={m['span_recall']:.4f} token_acc={m['token_acc']:.4f} "
f"non_o_recall={m['non_o_recall']:.4f} "
f"spans={m['n_gold_spans']}/{m['n_pred_spans']}/{m['n_correct_spans']}")
def _write_compare_log(path: Path, cand, ref, pair, args):
    """Write compare.log: aggregate metrics, a performance table, pairwise
    token-level outcomes, and per-category span F1.

    Args:
        path: destination file for the report.
        cand: candidate (B, this checkpoint) result dict from the eval loop:
            label, raw/viterbi metric dicts, perf dict, throughput, counts.
        ref:  reference/teacher (A) result dict of the same shape, or None
            when the teacher comparison was skipped (--no-base).
        pair: pairwise token-outcome dict from _pairwise(), or None.
        args: parsed CLI namespace (eval_pct, max_length, seed, n_docs).
    """
    lines: List[str] = []
    A = lines.append  # shorthand: every report line is funneled through A(...)
    # Header: A-vs-B framing when a reference exists, single-model otherwise.
    A(f"Benchmark: A: {ref['label']} vs B: {cand['label']}"
      if ref else f"Benchmark: {cand['label']}")
    A(f"Dataset: {SOURCE_DATASET}, split=test, eval_pct={args.eval_pct}, "
      f"ctx={args.max_length}, seed={args.seed}, n_docs={args.n_docs}")
    A(f"Eval tokens scored: {cand['n_tok']:,}")
    A("")
    A("=== Aggregate ===")
    if ref is not None:
        A(f" A: {ref['label']:<25s} RAW {_fmt_metrics(ref['raw'])}")
        A(f" A: {ref['label']:<25s} VITERBI {_fmt_metrics(ref['viterbi'])}")
    # B lines are always written, with or without a reference.
    A(f" B: {cand['label']:<25s} RAW {_fmt_metrics(cand['raw'])}")
    A(f" B: {cand['label']:<25s} VITERBI {_fmt_metrics(cand['viterbi'])}")
    if ref is not None:
        # d_f1 > 0 means A (teacher) is ahead; the sign is flipped below so
        # the printed "B vs A" gap is positive when B is ahead.
        d_f1 = ref["viterbi"]["span_f1"] - cand["viterbi"]["span_f1"]
        A("")
        A(f"Gap B vs A (viterbi span_F1): {-d_f1:+.4f}")
    A("")
    if ref is not None:
        A(f"Throughput: A: {ref['label']} {ref['throughput_tok_s']:.0f} tok/s "
          f"({ref['n_total_M']:.2f}M params)")
        A(f" B: {cand['label']} {cand['throughput_tok_s']:.0f} tok/s "
          f"({cand['n_total_M']:.2f}M params)")
    else:
        A(f"Throughput: {cand['label']} {cand['throughput_tok_s']:.0f} tok/s "
          f"({cand['n_total_M']:.2f}M params)")
    A("")
    # ---- Performance summary table ----
    # Column order: A (ref / teacher) first, then B (cand / student).
    # The "B vs A" column is written in human-readable direction:
    #   - When B is smaller (size/compute/mem): "X.XX× smaller" / "X.XX× cheaper" / "X.XX× less".
    #   - When B is larger but lower-is-better: "X.XX× larger" / "X.XX× more".
    #   - When B is faster (throughput, higher-is-better): "X.XX× faster".
    #   - When B is slower: "X.XX× slower".
    # Always uses the magnitude in the dominant direction so the reader
    # doesn't need to mentally invert 0.21× into 4.87×.
    def _fmt_vs(b: float, a: float, kind: str) -> str:
        # Returns "" (blank cell) when either side is missing or zero, so
        # the ratio is never computed against a degenerate denominator.
        if a is None or a == 0 or b is None or b == 0:
            return ""
        ratio = b / a
        if kind == "size": # lower is better; phrase as "smaller" or "larger"
            if ratio <= 1.0:
                return f"{1.0/ratio:.2f}× smaller"
            return f"{ratio:.2f}× larger"
        if kind == "compute":
            if ratio <= 1.0:
                return f"{1.0/ratio:.2f}× cheaper"
            return f"{ratio:.2f}× more"
        if kind == "memory":
            if ratio <= 1.0:
                return f"{1.0/ratio:.2f}× less"
            return f"{ratio:.2f}× more"
        if kind == "speed": # higher is better
            if ratio >= 1.0:
                return f"{ratio:.2f}× faster"
            return f"{1.0/ratio:.2f}× slower"
        # Unknown kind: fall back to the bare ratio.
        return f"{ratio:.2f}×"
    cp = cand["perf"]
    A("=== Performance ===")
    headers = ["metric",
               f"A: {ref['label']}" if ref else "",
               f"B: {cand['label']}",
               "B vs A"]
    rp = ref["perf"] if ref else None
    # Each row: [metric name, A cell, B cell, "B vs A" cell]; A cells are
    # blank strings when no reference was evaluated.
    rows = [
        ["total params (M)",
         (f"{rp['total_params']/1e6:.2f}" if ref else ""),
         f"{cp['total_params']/1e6:.2f}",
         (_fmt_vs(cp['total_params'], rp['total_params'], "size") if ref else "")],
        ["dense params (M)",
         (f"{rp['other_total']/1e6:.2f}" if ref else ""),
         f"{cp['other_total']/1e6:.2f}",
         (_fmt_vs(cp['other_total'], rp['other_total'], "size") if ref else "")],
        ["MoE expert params (M)",
         (f"{rp['moe_total']/1e6:.2f}" if ref else ""),
         f"{cp['moe_total']/1e6:.2f}",
         (_fmt_vs(cp['moe_total'], rp['moe_total'], "size") if ref else "")],
        [f"active params/token (M, mem)",
         (f"{rp['active_params_per_tok']/1e6:.2f}" if ref else ""),
         f"{cp['active_params_per_tok']/1e6:.2f}",
         (_fmt_vs(cp['active_params_per_tok'], rp['active_params_per_tok'], "memory") if ref else "")],
        [f"compute params/token (M, FLOPs)",
         (f"{rp['compute_params_per_tok']/1e6:.2f}" if ref else ""),
         f"{cp['compute_params_per_tok']/1e6:.2f}",
         (_fmt_vs(cp['compute_params_per_tok'], rp['compute_params_per_tok'], "compute") if ref else "")],
        ["GFLOP / token",
         (f"{rp['gflops_per_tok']:.4f}" if ref else ""),
         f"{cp['gflops_per_tok']:.4f}",
         (_fmt_vs(cp['gflops_per_tok'], rp['gflops_per_tok'], "compute") if ref else "")],
        ["disk size (MiB)",
         # Disk size may legitimately be 0/unknown; blank the cell then.
         (f"{rp['disk_size_bytes']/(1<<20):.1f}" if ref and rp['disk_size_bytes'] else ""),
         f"{cp['disk_size_bytes']/(1<<20):.1f}",
         (_fmt_vs(cp['disk_size_bytes'], rp['disk_size_bytes'], "size") if ref and rp['disk_size_bytes'] else "")],
        ["weights in RAM (MiB)",
         (f"{rp['weight_bytes']/(1<<20):.1f}" if ref else ""),
         f"{cp['weight_bytes']/(1<<20):.1f}",
         (_fmt_vs(cp['weight_bytes'], rp['weight_bytes'], "size") if ref else "")],
        ["peak GPU mem eval (MiB)",
         (f"{rp['peak_eval_mem_bytes']/(1<<20):.1f}" if ref else ""),
         f"{cp['peak_eval_mem_bytes']/(1<<20):.1f}",
         (_fmt_vs(cp['peak_eval_mem_bytes'], rp['peak_eval_mem_bytes'], "memory") if ref else "")],
        ["throughput (tok/s)",
         (f"{ref['throughput_tok_s']:.0f}" if ref else ""),
         f"{cand['throughput_tok_s']:.0f}",
         (_fmt_vs(cand['throughput_tok_s'], ref['throughput_tok_s'], "speed") if ref else "")],
    ]
    # Column widths sized to the widest cell (headers included).
    widths = [max(len(r[i]) for r in [headers] + rows) for i in range(4)]
    sep = " " + " ".join("-" * w for w in widths)
    A(" " + " ".join(h.ljust(widths[i]) for i, h in enumerate(headers)))
    A(sep)
    for r in rows:
        A(" " + " ".join(r[i].ljust(widths[i]) for i in range(4)))
    A("")
    if pair is not None:
        # Display labels: A = ref (teacher), B = cand (student).
        a_lbl = ref["label"] if ref is not None else "A"
        b_lbl = cand["label"]
        A(f"=== Pairwise (viterbi, all gold tokens) — A: {a_lbl} vs B: {b_lbl} ===")
        total = (pair["both_correct"] + pair["only_cand_correct"]
                 + pair["only_ref_correct"] + pair["both_wrong"])
        # Display order: agreement, A-only, B-only, both-wrong.
        rows_all = [
            ("both_correct", pair["both_correct"]),
            ("only_A_correct", pair["only_ref_correct"]),
            ("only_B_correct", pair["only_cand_correct"]),
            ("both_wrong", pair["both_wrong"]),
        ]
        for k, v in rows_all:
            A(f" {k:<26s} {v:8d} ({100.0*v/total:.2f}%)")
        A("")
        A(f"=== Pairwise (viterbi, gold non-O tokens) — A: {a_lbl} vs B: {b_lbl} ===")
        # Re-aggregate the per-category counters over non-O gold tokens only.
        non_o = {k: 0 for k in ["both_correct", "only_cand", "only_ref", "both_wrong"]}
        for cat, d in pair["by_cat"].items():
            if cat == "O":
                continue
            for k in non_o:
                non_o[k] += d[k]
        total_non_o = sum(non_o.values())
        rows_non_o = [
            ("both_correct", non_o["both_correct"]),
            ("only_A_correct", non_o["only_ref"]),
            ("only_B_correct", non_o["only_cand"]),
            ("both_wrong", non_o["both_wrong"]),
        ]
        for k, v in rows_non_o:
            # Guard the percentage against an empty non-O population.
            A(f" {k:<26s} {v:8d} ({100.0*v/total_non_o:.2f}%)" if total_non_o else f" {k}: 0")
        A("")
        # Per-cat net B-wins. Net = (only_B) - (only_A) = (only_cand) - (only_ref).
        # Negative net = A (teacher) wins more in this category.
        nets = []
        for cat, d in pair["by_cat"].items():
            if cat == "O":
                continue
            nets.append((cat, d["only_cand"] - d["only_ref"],
                         d["only_cand"], d["only_ref"], d["both_wrong"]))
        # Ascending by net: worst categories for B first.
        nets.sort(key=lambda x: x[1])
        A(f"=== Worst B-net wins by gold category — A: {a_lbl} ahead (top 15) ===")
        for cat, net, ob, oa, bw in nets[:15]:
            A(f" {cat:<32s} net_B={net:+5d} A_only={oa:4d} B_only={ob:4d} both_wrong={bw:4d}")
        A("")
        A(f"=== Best B-net wins by gold category — B: {b_lbl} ahead (top 15) ===")
        # Same list reversed: best categories for B first.
        for cat, net, ob, oa, bw in nets[::-1][:15]:
            A(f" {cat:<32s} net_B={net:+5d} A_only={oa:4d} B_only={ob:4d} both_wrong={bw:4d}")
        A("")
    A("=== Per-category span F1 (viterbi) ===")
    if ref is not None:
        A(f" -- A: {ref['label']} --")
        per_r = ref["viterbi"]["span_per_cat"]
        for cat in sorted(per_r):
            c = per_r[cat]
            A(f" {cat:<32s} F1={c['f1']:.4f} P={c['precision']:.4f} R={c['recall']:.4f} "
              f"({c['n_gold']}/{c['n_pred']}/{c['n_correct']})")
    A(f" -- B: {cand['label']} --")
    per = cand["viterbi"]["span_per_cat"]
    for cat in sorted(per):
        c = per[cat]
        A(f" {cat:<32s} F1={c['f1']:.4f} P={c['precision']:.4f} R={c['recall']:.4f} "
          f"({c['n_gold']}/{c['n_pred']}/{c['n_correct']})")
    path.write_text("\n".join(lines) + "\n")
    print(f"[log] wrote {path}", flush=True)
def _fmt_bytes(n: int) -> str:
if n <= 0:
return "—"
if n >= 1 << 30:
return f"{n / (1 << 30):.2f} GiB"
if n >= 1 << 20:
return f"{n / (1 << 20):.1f} MiB"
return f"{n / (1 << 10):.1f} KiB"
def _perf_block(stream, ctx: int) -> List[str]:
    """Render one model's performance numbers (params, FLOPs, bytes) as
    pre-formatted report lines for infer.log.

    `stream` is an eval-result dict carrying a "perf" sub-dict; `ctx` is the
    context length used during eval (echoed in the peak-memory line).
    """
    perf = stream["perf"]
    report: List[str] = []
    add = report.append
    add(f" total params : {perf['total_params']/1e6:>9.2f}M "
        f"({perf['other_total']/1e6:.2f}M dense + {perf['moe_total']/1e6:.2f}M MoE-experts)")
    add(f" active params / token : {perf['active_params_per_tok']/1e6:>9.2f}M "
        f"(memory footprint — embed lookup + top_{perf['experts_per_tok']}/{perf['num_experts']} experts: "
        f"{perf['embed_total']/1e6:.2f}M embed + "
        f"{perf['moe_active_per_tok']/1e6:.2f}M MoE-active + "
        f"{(perf['other_total']-perf['embed_total'])/1e6:.2f}M attn/norm/head)")
    add(f" compute params / token : {perf['compute_params_per_tok']/1e6:>9.2f}M "
        f"(matmul FLOPs only — embedding lookup excluded)")
    add(f" GFLOP / token (fwd, MAC×2): {perf['gflops_per_tok']:>9.3f}")
    add(f" weights size (on disk) : {_fmt_bytes(perf['disk_size_bytes']):>9s}")
    add(f" weights size (in RAM) : {_fmt_bytes(perf['weight_bytes']):>9s}")
    add(f" weights resident (GPU) : {_fmt_bytes(perf['weights_resident_bytes']):>9s}")
    add(f" peak GPU mem (eval, ctx={ctx}) : {_fmt_bytes(perf['peak_eval_mem_bytes']):>9s}")
    return report
def _write_infer_log(path: Path, cand, ref, args, sample_text: str, tokenizer, device, dtype):
    """Single-doc inference example + timing + performance metrics.

    Writes infer.log: load/eval timings and a performance block for B (and
    for A when a reference was evaluated), a B-vs-A ratio summary, then a
    live single-document forward pass with decoded PII spans.

    Args:
        path: destination file (infer.log).
        cand: candidate (B) result dict from the eval loop.
        ref:  reference/teacher (A) result dict, or None (--no-base).
        args: parsed CLI namespace (max_length, model_path, ...).
        sample_text: document to run through the model as the worked example.
        tokenizer / device / dtype: shared inference setup.
    """
    # NOTE(review): the imported name is not used below; the import may be
    # kept only for its side effects (making the custom model code
    # importable before from_pretrained) — confirm before removing.
    from modeling_haremb_pii import _bioes_viterbi_batched
    lines: List[str] = []
    A = lines.append  # every report line goes through A(...)
    A(f"Inference benchmark: A: {ref['label']} vs B: {cand['label']}"
      if ref else f"Inference benchmark: {cand['label']}")
    A(f" device : {device} dtype: {dtype}")
    A(f" ctx : {args.max_length}")
    A("")
    if ref is not None:
        A(f"A: {ref['label']} (reference / teacher)")
        A(f" load : {ref['load_s']:.2f}s")
        A(f" eval : {ref['eval_s']:.2f}s on {ref['n_tok']:,} tokens "
          f"({ref['throughput_tok_s']:.0f} tok/s)")
        A("Performance:")
        for ln in _perf_block(ref, args.max_length):
            A(ln)
        A("")
    A(f"B: {cand['label']}" + (" (this checkpoint)" if ref else ""))
    A(f" load : {cand['load_s']:.2f}s")
    A(f" eval : {cand['eval_s']:.2f}s on {cand['n_tok']:,} tokens "
      f"({cand['throughput_tok_s']:.0f} tok/s)")
    A("Performance:")
    for ln in _perf_block(cand, args.max_length):
        A(ln)
    A("")
    if ref is not None:
        cp, rp = cand["perf"], ref["perf"]
        def _fmt(b, a, kind):
            # Same dominant-direction phrasing as the compare.log table;
            # "—" marks a missing/zero side instead of a blank cell.
            if a is None or a == 0 or b is None or b == 0:
                return "—"
            r = b / a
            if kind == "size":
                return f"{1.0/r:.2f}× smaller" if r <= 1.0 else f"{r:.2f}× larger"
            if kind == "compute":
                return f"{1.0/r:.2f}× cheaper" if r <= 1.0 else f"{r:.2f}× more"
            if kind == "memory":
                return f"{1.0/r:.2f}× less" if r <= 1.0 else f"{r:.2f}× more"
            if kind == "speed":
                return f"{r:.2f}× faster" if r >= 1.0 else f"{1.0/r:.2f}× slower"
            return f"{r:.2f}×"
        A(f"B vs A ({cand['label']} vs {ref['label']}):")
        A(f" total params : {_fmt(cp['total_params'], rp['total_params'], 'size')}")
        A(f" active params / token : {_fmt(cp['active_params_per_tok'], rp['active_params_per_tok'], 'memory')} [memory]")
        A(f" compute params / token : {_fmt(cp['compute_params_per_tok'], rp['compute_params_per_tok'], 'compute')} [FLOPs]")
        A(f" GFLOP / token : {_fmt(cp['gflops_per_tok'], rp['gflops_per_tok'], 'compute')}")
        if rp['disk_size_bytes']:
            A(f" weights size (on disk) : {_fmt(cp['disk_size_bytes'], rp['disk_size_bytes'], 'size')}")
        else:
            A(f" weights size (on disk) : —")
        A(f" weights in RAM : {_fmt(cp['weight_bytes'], rp['weight_bytes'], 'size')}")
        A(f" peak GPU mem (eval) : {_fmt(cp['peak_eval_mem_bytes'], rp['peak_eval_mem_bytes'], 'memory')}")
        A(f" throughput : {_fmt(cand['throughput_tok_s'], ref['throughput_tok_s'], 'speed')}")
        A("")
    A("Sample inference (load → tokenize → forward → viterbi-decode → spans):")
    A(f" text: {sample_text!r}")
    # Bug fix: previously loaded from "." (the CWD), silently ignoring
    # --model-path; load the checkpoint actually under test instead.
    # (Identical behavior for the default --model-path of ".".)
    model = AutoModelForTokenClassification.from_pretrained(
        args.model_path, dtype=dtype, trust_remote_code=True,
    ).to(device).eval()
    if hasattr(model.config, "viterbi_replace_logits"):
        model.config.viterbi_replace_logits = True
    enc = tokenizer(sample_text, return_tensors="pt", truncation=True,
                    max_length=args.max_length).to(device)
    with torch.no_grad():
        # Synchronize around the forward pass so the wall-clock latency
        # reflects actual GPU work, not async kernel launch time.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.time()
        out = model(**enc)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        dt = time.time() - t0
    label2id, id2label = nemotron_native_label_space()
    # Greedy per-token argmax over the (possibly viterbi-smoothed) logits.
    pred = out.logits.argmax(-1)[0].cpu().tolist()
    spans = _bioes_to_spans(pred, id2label, 0)
    A(f" forward latency: {dt*1000:.1f}ms ({enc.input_ids.shape[1]} tokens)")
    A(f" detected {len(spans)} spans:")
    tok_ids = enc.input_ids[0].cpu().tolist()
    for s, e, cat in sorted(spans):
        text = tokenizer.decode(tok_ids[s:e]).strip()
        A(f" [{s:3d}, {e:3d}) {cat:<28s} {text!r}")
    # Release the sample model before the caller continues.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    path.write_text("\n".join(lines) + "\n")
    print(f"[log] wrote {path}", flush=True)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """CLI entry point.

    Samples an eval slice of nvidia/Nemotron-PII, scores this checkpoint
    (and, unless --no-base, the OpenMed teacher), then writes infer.log,
    compare.log and the summary plots into --out.
    """
    p = argparse.ArgumentParser(description="Benchmark haremb-privacy-filter-opennemo")
    p.add_argument("--device", default=None,
                   help="cuda or cpu. Default: cuda if available.")
    p.add_argument("--dtype", default="bfloat16",
                   choices=["bfloat16", "float16", "float32"])
    p.add_argument("--eval-pct", type=float, default=1.0,
                   help="Percent of nvidia/Nemotron-PII test split to use. Default 1%%.")
    p.add_argument("--eval-chunk-size", type=int, default=10_000)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--max-length", type=int, default=1024)
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--out", type=str, default=".")
    p.add_argument("--model-path", type=str, default=".",
                   help="Path to this checkpoint. Default: ./ (this folder).")
    p.add_argument("--no-base", action="store_true",
                   help="Skip the OpenMed teacher comparison.")
    p.add_argument("--no-plots", action="store_true",
                   help="Skip rendering eval_summary.png / eval_confusion.png.")
    args = p.parse_args()
    out_dir = Path(args.out).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    if args.device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16,
             "float32": torch.float32}[args.dtype]
    print(f"[setup] device={device} dtype={dtype} model={args.model_path} "
          f"out={out_dir}", flush=True)
    label2id, id2label = nemotron_native_label_space()
    o_id = label2id["O"]
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    # Bug fix: `tokenizer.pad_token_id or 199999` would silently replace a
    # legitimate pad id of 0 with the fallback; only fall back when the
    # tokenizer defines no pad token at all.
    # NOTE(review): pad_token_id is not referenced later in this function —
    # confirm whether it is still needed.
    pad_token_id = (tokenizer.pad_token_id
                    if tokenizer.pad_token_id is not None else 199999)
    # Build the eval set (same slice the README headline numbers reference)
    print(f"[data] loading {SOURCE_DATASET} ...", flush=True)
    ds = load_dataset(SOURCE_DATASET)
    # eval_pct is a percentage of the test split; always keep at least 1 doc.
    target_eval = max(1, int(round(len(ds["test"]) * args.eval_pct / 100.0)))
    eval_indices = _build_eval_streaming(
        ds["test"], target_n=target_eval,
        chunk_size=args.eval_chunk_size, seed=args.seed,
    )
    eval_ds = _NemotronEvalDataset(
        ds["test"].select(eval_indices), tokenizer, label2id, args.max_length,
    )
    # Stashed on args so the report writers can echo the eval size.
    args.n_docs = len(eval_ds)
    print(f"[data] eval={args.n_docs:,} docs ({args.eval_pct:.2f}% of test split)",
          flush=True)
    # BIOES decoding masks (used for the explicit RAW vs VITERBI streams).
    from modeling_haremb_pii import (
        _build_bioes_initial_mask as _bld_init,
        _build_bioes_transition_mask as _bld_trans,
    )
    bioes_trans = _bld_trans(id2label).to(device).float()
    bioes_init = _bld_init(id2label).to(device).float()
    # Eval candidate (B: this checkpoint).
    cand = _eval_one_model(
        args.model_path, tokenizer, eval_ds, label2id, id2label, o_id,
        bioes_trans, bioes_init,
        args.batch_size, args.max_length, device, dtype,
        label="haremb",
    )
    print(f"\n=== {cand['label']} ===")
    print(f"RAW {_fmt_metrics(cand['raw'])}")
    print(f"VITERBI {_fmt_metrics(cand['viterbi'])}")
    print(f"DELTA span_F1={cand['viterbi']['span_f1']-cand['raw']['span_f1']:+.4f} "
          f"P={cand['viterbi']['span_precision']-cand['raw']['span_precision']:+.4f} "
          f"R={cand['viterbi']['span_recall']-cand['raw']['span_recall']:+.4f}")
    ref = None
    pair = None
    if not args.no_base:
        # Eval the OpenMed teacher (A) with the same pipeline + settings.
        ref = _eval_one_model(
            TEACHER, tokenizer, eval_ds, label2id, id2label, o_id,
            bioes_trans, bioes_init,
            args.batch_size, args.max_length, device, dtype,
            label="openmed-base",
        )
        print(f"\n=== {ref['label']} (teacher) ===")
        print(f"VITERBI {_fmt_metrics(ref['viterbi'])}")
        pair = _pairwise(cand["docs"], ref["docs"], id2label, o_id)
        d = ref["viterbi"]["span_f1"] - cand["viterbi"]["span_f1"]
        print(f"\nGap to teacher (viterbi span_F1): {d:+.4f}")
    # Reports
    sample_text = ("Patient Sarah Johnson (DOB 03/15/1985), MRN 4872910, "
                   "phone 415-555-0123, email sarah.johnson@example.com, "
                   "credit card 4111-1111-1111-1111.")
    _write_infer_log(out_dir / "infer.log", cand, ref, args, sample_text,
                     tokenizer, device, dtype)
    _write_compare_log(out_dir / "compare.log", cand, ref, pair, args)
    if not args.no_plots:
        _render_plots(cand, ref, pair, out_dir,
                      cand_label=cand["label"],
                      ref_label=ref["label"] if ref else "")
    print("\n[done]")
# Script entry point: run the benchmark when executed directly.
if __name__ == "__main__":
    main()