fblgit committed on
Commit f0f5785 · 1 Parent(s): 8fb7b0f

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ eval_confusion.png filter=lfs diff=lfs merge=lfs -text
37
+ eval_summary.png filter=lfs diff=lfs merge=lfs -text
38
+ haremb.png filter=lfs diff=lfs merge=lfs -text
39
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
1
+ __pycache__
README.md ADDED
@@ -0,0 +1,258 @@
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ pipeline_tag: token-classification
6
+ library_name: transformers
7
+ tags:
8
+ - pii
9
+ - privacy
10
+ - token-classification
11
+ - bioes
12
+ - moe
13
+ - haremb
14
+ base_model:
15
+ - OpenMed/privacy-filter-nemotron
16
+ datasets:
17
+ - nvidia/Nemotron-PII
18
+ ---
19
+
20
+ # HarEmb · OpenMed-Nemotron PII
21
+
22
+ > A **single-layer** HarEmb model on the [`OpenMed/privacy-filter-nemotron`](https://huggingface.co/OpenMed/privacy-filter-nemotron) lineage. It has 287M total parameters and predicts the full **221-class BIOES** Nemotron-PII label space.
23
+
24
+ **Model**: [`fblgit/haremb-privacy-filter-opennemo`](https://huggingface.co/fblgit/haremb-privacy-filter-opennemo)
25
+
26
+ ![HarEmb architecture](haremb.png)
27
+
28
+ ## Lineage
29
+
30
+ This model is the third step in a three-step lineage:
31
+
32
+ 1. **[`openai/privacy-filter`](https://huggingface.co/openai/privacy-filter)** — OpenAI's open release of the underlying 1.4B-parameter MoE backbone (8 transformer layers, ~50M active params/token, BIOES token classifier head).
33
+ 2. **[`OpenMed/privacy-filter-nemotron`](https://huggingface.co/OpenMed/privacy-filter-nemotron)** — OpenMed's full fine-tune of that backbone on `nvidia/Nemotron-PII`, expanding the head to 221 BIOES classes (55 fine-grained PII categories).
34
+ 3. **`haremb-privacy-filter-opennemo`** *(this model)* — a one-layer surgical slice of the OpenMed teacher (idea sketched below).
35
+
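+ The slicing recipe itself is not published with this card, so the snippet below is purely an intuition sketch of what a "keep one layer" surgery on the teacher could look like (the `model.model.layers` path matches what `app.py` uses for this architecture); it is not the actual training procedure.
+
+ ```python
+ # HYPOTHETICAL sketch of a one-layer slice. The real surgery/distillation
+ # recipe behind this checkpoint is not documented here.
+ import torch
+ from transformers import AutoModelForTokenClassification
+
+ teacher = AutoModelForTokenClassification.from_pretrained(
+     "OpenMed/privacy-filter-nemotron", trust_remote_code=True,
+ )
+ teacher.model.layers = torch.nn.ModuleList([teacher.model.layers[0]])
+ teacher.config.num_hidden_layers = 1  # keep the config consistent
+ ```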
36
+ ## What this model does
37
+
38
+ Token-level PII classification over **55 Nemotron-PII categories**. Every token receives one of `O` or `{B, I, E, S}-<category>`, covering identity, contact, address, date/time, government ID, financial, healthcare, enterprise ID, vehicle, and digital identifier categories.
39
+
40
+ In `eval()` mode the model can run constrained-BIOES Viterbi decoding internally, so `outputs.logits.argmax(-1)` is span-coherent by default. See [Output semantics](#output-semantics) for the exact fields and opt-out flags.
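+ As a concrete example, `sarah.johnson@example.com` split into three tokens is tagged `B-email I-email E-email`, while a single-token email gets `S-email`. "Constrained" means the decoder only walks legal BIOES transitions; below is a minimal sketch of the legality rule (the shipped modeling file implements the real batched Viterbi):
+
+ ```python
+ # Sketch of the BIOES transition constraint the Viterbi decoder enforces.
+ def legal(prev: str, nxt: str) -> bool:
+     p, pc = prev.split("-", 1) if "-" in prev else (prev, None)
+     n, nc = nxt.split("-", 1) if "-" in nxt else (nxt, None)
+     if p in ("B", "I"):                      # mid-span: continue same category
+         return n in ("I", "E") and nc == pc
+     return n in ("O", "B", "S")              # span closed: start a fresh one
+
+ assert legal("B-email", "E-email")
+ assert not legal("B-email", "B-phone_number")  # cannot restart mid-span
+ ```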
41
+
42
+ ## Evaluation
43
+
44
+ Evaluated on a 1% slice of `nvidia/Nemotron-PII:test` (1,000 documents, ctx 1024, seed 42), Viterbi-decoded. The benchmark and app both use the convention **A = `OpenMed/privacy-filter-nemotron` (teacher / baseline)**, **B = this checkpoint** (`haremb`); ratios are reported as **B ÷ A**.
45
+
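+ The slice can be rebuilt from the samplers vendored in [`benchmark.py`](benchmark.py). A minimal sketch, assuming `benchmark.py` is importable from the model folder (the `chunk_size` value here is illustrative; `benchmark.py` holds the exact default):
+
+ ```python
+ # Rebuild the eval slice: 1,000 test documents via chunked uniform sampling.
+ from datasets import load_dataset
+ from benchmark import _build_eval_streaming
+
+ test = load_dataset("nvidia/Nemotron-PII", split="test")
+ idx = _build_eval_streaming(test, target_n=1_000, chunk_size=10_000, seed=42)
+ eval_split = test.select(idx)
+ ```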
46
+ ### Quality (viterbi stream)
47
+
48
+ | metric | **A: OpenMed teacher** | **B: haremb** (this) | B − A |
49
+ |---|---:|---:|---:|
50
+ | span F1 | 0.9434 | **0.9288** | −0.0146 |
51
+ | span precision | 0.9531 | **0.9396** | −0.0135 |
52
+ | span recall | 0.9338 | **0.9182** | −0.0156 |
53
+ | token accuracy | 0.9900 | **0.9885** | −0.0015 |
54
+ | non-O recall | 0.9703 | **0.9637** | −0.0066 |
55
+
56
+ ### Performance (same eval set, ctx 1024, bf16, single GPU)
57
+
58
+ | metric | **A: OpenMed teacher** | **B: haremb** | B vs A |
59
+ |---|---:|---:|---:|
60
+ | total params | 1,400M | **287M** | **4.87× smaller** |
61
+ | dense params | 139M | 130M | 1.07× smaller |
62
+ | MoE expert params | 1,260M | 158M | **7.97× smaller** |
63
+ | **active params / token** (memory) | 178.7M | **134.5M** | 1.33× smaller |
64
+ | **compute params / token** (FLOPs) | 50.7M | **6.5M** | **7.85× cheaper** |
65
+ | GFLOP / token (forward) | 0.101 | **0.013** | **7.85× cheaper** |
66
+ | weights on disk | (HF repo) | **548 MiB** | — |
67
+ | weights in RAM | 2,669 MiB | 548 MiB | **4.87× smaller** |
68
+ | peak GPU memory (eval) | 3.30 GiB | **1.22 GiB** | **2.70× less** |
69
+ | throughput | 3,275 tok/s | **6,361 tok/s** | **1.94× faster** |
70
+
71
+ `active params / token` estimates memory bandwidth pressure, while `compute params / token` estimates matmul FLOPs and excludes the embedding table row-gather. GFLOP/token is `2 × compute_params_per_token`. `infer.log` and `compare.log` contain the full breakdown, including peak GPU memory from `torch.cuda.max_memory_allocated`.
72
+
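+ A back-of-envelope check of the GFLOP rows above, using the stated 2-FLOPs-per-matmul-param rule:
+
+ ```python
+ # GFLOP/token = 2 × compute_params_per_token
+ for name, compute_params in [("A: teacher", 50.7e6), ("B: haremb", 6.5e6)]:
+     print(f"{name}: {2 * compute_params / 1e9:.3f} GFLOP/token")
+ # A: teacher: 0.101 GFLOP/token
+ # B: haremb: 0.013 GFLOP/token
+ ```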
73
+ ![Performance profile — absolute footprint and B/A ratio, A teacher vs B candidate](eval_performance.png)
74
+
75
+ ### Quality breakdown
76
+
77
+ ![Eval summary — headline metrics, raw-vs-viterbi span F1, and selected per-category deltas](eval_summary.png)
78
+
79
+ ### Per-category highlights (viterbi span F1)
80
+
81
+ **At or near 1.000 (B)** — `biometric_identifier`, `blood_type`, `coordinate`, `health_plan_beneficiary_number`, `ipv4`, `ipv6`, `license_plate`, `mac_address`, `national_id`, `postcode` (≥ 0.99 with ≥ 100 gold spans).
82
+
83
+ **Categories where B beats A** (B vs A) — `gender` (0.987 vs 0.841), `political_view` (0.872 vs 0.839), `religious_belief` (0.935 vs 0.926), `state` (0.908 vs 0.829), `language` (0.897 vs 0.804), `race_ethnicity` (0.864 vs 0.861), `country` (0.952 vs 0.936). These are mostly "fuzzy" world-knowledge categories where the one-layer student carries the right inductive bias.
84
+
85
+ **Categories where A leads** (A vs B) — `occupation` (0.727 vs 0.605), `company_name` (0.929 vs 0.776), `last_name` (0.976 vs 0.931), `first_name` (0.970 vs 0.930), `user_name` (0.961 vs 0.942). These are identity-noun categories where the teacher's deeper-layer mixing helps.
86
+
87
+ ### Token-outcome breakdown — A: OpenMed teacher vs B: haremb (viterbi)
88
+
89
+ ![Pairwise token outcome and net category wins on gold non-O tokens](eval_confusion.png)
90
+
91
+ ## Quick start
92
+
93
+ ### Recommended — via OpenMed
94
+
95
+ The OpenMed wrapper offers the same UX the teacher card recommends and works with this checkpoint as a drop-in:
96
+
97
+ ```bash
98
+ pip install -U "openmed[hf]"
99
+ ```
100
+
101
+ ```python
102
+ from openmed import extract_pii, deidentify
103
+
104
+ text = (
105
+ "Patient Sarah Johnson (DOB 03/15/1985), MRN 4872910, "
106
+ "phone 415-555-0123, email sarah.johnson@example.com."
107
+ )
108
+
109
+ result = extract_pii(text, model_name="fblgit/haremb-privacy-filter-opennemo")
110
+ for ent in result.entities:
111
+ print(f"{ent.label:30s} {ent.text!r} conf={ent.confidence:.2f}")
112
+
113
+ masked = deidentify(text, method="mask",
114
+ model_name="fblgit/haremb-privacy-filter-opennemo")
115
+ fake = deidentify(text, method="replace",
116
+ model_name="fblgit/haremb-privacy-filter-opennemo",
117
+ consistent=True, seed=42)
118
+ ```
119
+
120
+ ### Hugging Face `transformers` pipeline
121
+
122
+ ```python
123
+ from transformers import pipeline
124
+
125
+ pipe = pipeline(
126
+ "token-classification",
127
+ model="fblgit/haremb-privacy-filter-opennemo",
128
+ tokenizer="fblgit/haremb-privacy-filter-opennemo",
129
+ trust_remote_code=True,
130
+ aggregation_strategy="simple",
131
+ )
132
+
133
+ pipe("Send the invoice to billing@acmecorp.io, account 1234-5678.")
134
+ # → [{'entity_group': 'email', 'word': 'billing@acmecorp.io', ...},
135
+ # {'entity_group': 'account_number', 'word': '1234-5678', ...}]
136
+ ```
137
+
138
+ ### Raw `transformers` API
139
+
140
+ ```python
141
+ import torch
142
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
143
+
144
+ repo = "fblgit/haremb-privacy-filter-opennemo"
145
+ model = AutoModelForTokenClassification.from_pretrained(
146
+ repo, trust_remote_code=True, dtype=torch.bfloat16,
147
+ ).to("cuda").eval()
148
+ tok = AutoTokenizer.from_pretrained(repo)
149
+
150
+ enc = tok("My email is foo@bar.com.", return_tensors="pt").to("cuda")
151
+ with torch.no_grad():
152
+ out = model(**enc)
153
+
154
+ # By default, `outputs.logits.argmax(-1)` follows the Viterbi-decoded path.
155
+ labels = out.logits.argmax(-1)[0]
156
+ ```
157
+
158
+ ## Output semantics
159
+
160
+ The forward pass — in `eval()` mode — runs constrained-BIOES Viterbi over the per-token logits and attaches three things to the output:
161
+
162
+ - `outputs.logits` — a tensor whose `argmax(-1)` equals the Viterbi prediction (so HF `pipeline()` and naive `argmax` consumers get span-coherent predictions automatically).
163
+ - `outputs.predicted_labels` — a `[B, T]` LongTensor of Viterbi-decoded label ids (`-1` at padded positions).
164
+ - `outputs.raw_logits` — the original per-token logits, preserved for callers that want raw confidences.
165
+
166
+ To opt out:
167
+
168
+ ```python
169
+ model.config.viterbi_replace_logits = False # raw logits in outputs.logits
170
+ model.config.use_viterbi_decode = False # also skip Viterbi entirely
171
+ ```
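+ For example, to keep span-coherent labels while reading a per-token confidence off the raw logits, here is a minimal sketch reusing `model` and `enc` from the snippet above (softmax at the Viterbi label is one reasonable confidence choice, not the only one):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ with torch.no_grad():
+     out = model(**enc)
+ labels = out.predicted_labels[0]                      # Viterbi path; -1 = pad
+ probs = F.softmax(out.raw_logits[0].float(), dim=-1)  # raw per-class probs
+ conf = probs.gather(-1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
+ ```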
172
+
173
+ The model supports the upstream context length (max position embeddings 131,072 tokens). Practical batch sizes depend on hardware; bf16 at batch size 1 with a full-length context is comfortable on a 24 GB GPU.
174
+
175
+ ## Limitations & intended use
176
+
177
+ - **English-only training data.** Nemotron-PII is predominantly English. Performance on non-English text is not guaranteed.
178
+ - **Synthetic training data.** Real clinical notes, legal documents, and live web text may show different surface forms. For high-stakes deployments, collect a domain-specific eval set and re-calibrate.
179
+ - **Fuzzier categories** — `occupation`, `company_name`, and identity nouns (`first_name`, `last_name`, `user_name`) carry more uncertainty than formatted identifiers; downstream pipelines that only need strict PII can ignore low-confidence predictions on these.
180
+ - **Not a substitute for legal compliance review.** Use alongside a governance layer (human review, deterministic regex pre-filters, etc.).
181
+
182
+ ## Reproducibility
183
+
184
+ Every metric, log, and plot in this card is regenerated by the single-file [`benchmark.py`](benchmark.py) shipped alongside the weights:
185
+
186
+ ```bash
187
+ python benchmark.py # full benchmark vs OpenMed teacher
188
+ python benchmark.py --no-base # skip teacher download (logs only)
189
+ python benchmark.py --no-plots # skip matplotlib (logs + JSON only)
190
+ python benchmark.py --eval-pct 0.1 # smaller slice for a quick check
191
+ ```
192
+
193
+ Outputs written to the model folder:
194
+
195
+ - `infer.log`
196
+ - `compare.log`
197
+ - `eval_summary.png`
198
+ - `eval_confusion.png`
199
+ - `eval_performance.png`
200
+
201
+ Raw per-doc eval data is held in memory only. Pass `--out` to write artifacts somewhere else.
202
+
203
+ The Gradio demo in [`app.py`](app.py) supports **side-by-side A-vs-B comparison** between any two token-classification checkpoints with the same label space. Defaults match the report convention: **A = OpenMed/privacy-filter-nemotron** (teacher / baseline), **B = this checkpoint**. Disable either model to run single-model inference; both expose a runtime "active experts per token" slider so you can sweep MoE routing density. From inside the model folder:
204
+
205
+ ```bash
206
+ python app.py # A=OpenMed teacher, B=. (this)
207
+ python app.py --model-a /path/to/another/repo # swap baseline A
208
+ python app.py --model-b /path/to/another/repo # swap candidate B
209
+ python app.py --port 7860 --share # public share link
210
+ ```
211
+
212
+ ## License
213
+
214
+ Apache-2.0, same as the lineage. Subject to the license terms of [`openai/privacy-filter`](https://huggingface.co/openai/privacy-filter) and the dataset terms of [`nvidia/Nemotron-PII`](https://huggingface.co/datasets/nvidia/Nemotron-PII).
215
+
216
+ ## Citation
217
+
218
+ ```bibtex
219
+ @misc{haremb-privacy-filter-opennemo,
220
+ title = {HarEmb · OpenMed-Nemotron PII: a single-layer
221
+ privacy-filter slice with span-coherent inference},
222
+ author = {fblgit},
223
+ year = {2026},
224
+ publisher = {Hugging Face},
225
+ url = {https://huggingface.co/fblgit/haremb-privacy-filter-opennemo},
226
+ howpublished = {\url{https://huggingface.co/fblgit/haremb-privacy-filter-opennemo}},
227
+ note = {Single-transformer-layer model on the openai/privacy-filter →
228
+ OpenMed/privacy-filter-nemotron lineage; 287M total params,
229
+ 221 BIOES classes (55 fine-grained PII categories), with
230
+ inlined constrained-BIOES Viterbi decoding so
231
+ outputs.logits.argmax(-1) is span-coherent.}
232
+ }
233
+
234
+ @misc{openmed-privacy-filter-nemotron,
235
+ title = {OpenMed/privacy-filter-nemotron: fine-grained PII extraction
236
+ with 55 categories},
237
+ author = {OpenMed},
238
+ year = {2026},
239
+ publisher = {Hugging Face},
240
+ url = {https://huggingface.co/OpenMed/privacy-filter-nemotron}
241
+ }
242
+
243
+ @misc{openai-privacy-filter,
244
+ title = {Privacy Filter},
245
+ author = {OpenAI},
246
+ year = {2026},
247
+ publisher = {Hugging Face},
248
+ url = {https://huggingface.co/openai/privacy-filter}
249
+ }
250
+
251
+ @misc{nvidia-nemotron-pii,
252
+ title = {Nemotron-PII},
253
+ author = {NVIDIA},
254
+ year = {2025},
255
+ publisher = {Hugging Face},
256
+ url = {https://huggingface.co/datasets/nvidia/Nemotron-PII}
257
+ }
258
+ ```
app.py ADDED
@@ -0,0 +1,534 @@
1
+ """
2
+ HarEmb PII — local Gradio inference demo.
3
+
4
+ Upload a PDF (or paste text), pick a device (CPU / cuda:N), and the model
5
+ highlights detected PII spans across the 55-category Nemotron-PII taxonomy.
6
+
7
+ Install:
8
+ pip install "gradio>=4" "transformers>=4.45" torch pypdf accelerate
9
+
10
+ Run from inside this folder:
11
+ python app.py
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import re
17
+ from pathlib import Path
18
+ from typing import Dict, List, Optional, Tuple
19
+
20
+ import gradio as gr
21
+ import torch
22
+ from pypdf import PdfReader
23
+ from transformers import pipeline
24
+
25
+
26
+ # Default to loading from this folder so `python app.py` works in-place after
27
+ # downloading the repo. Override with --model-a / --model-b on the CLI.
28
+ DEFAULT_MODEL = "."
29
+
30
+ CHUNK_CHARS = 400_000 # ~100k tokens; well under the model's 131k window
31
+
32
+ # 55 Nemotron-PII categories grouped for visual coherence; one color per
33
+ # coarse "family" so the highlight legend stays readable.
34
+ PALETTE: Dict[str, str] = {
35
+ # Identity (red)
36
+ "first_name": "#ef4444",
37
+ "last_name": "#ef4444",
38
+ "user_name": "#ef4444",
39
+ "company_name": "#ef4444",
40
+ "age": "#fb7185",
41
+ "gender": "#fb7185",
42
+ "race_ethnicity": "#fb7185",
43
+ "sexuality": "#fb7185",
44
+ "religious_belief": "#fb7185",
45
+ "political_view": "#fb7185",
46
+ "language": "#fb7185",
47
+ "education_level": "#fb7185",
48
+ "occupation": "#fb7185",
49
+ "employment_status": "#fb7185",
50
+ "blood_type": "#fb7185",
51
+ "biometric_identifier":"#fb7185",
52
+ # Contact (purple)
53
+ "email": "#8b5cf6",
54
+ "phone_number": "#a78bfa",
55
+ "fax_number": "#a78bfa",
56
+ "url": "#7c3aed",
57
+ # Address (green)
58
+ "street_address": "#10b981",
59
+ "city": "#34d399",
60
+ "county": "#34d399",
61
+ "state": "#34d399",
62
+ "country": "#34d399",
63
+ "postcode": "#34d399",
64
+ "coordinate": "#059669",
65
+ # Dates (blue)
66
+ "date": "#3b82f6",
67
+ "date_of_birth": "#60a5fa",
68
+ "date_time": "#60a5fa",
69
+ "time": "#60a5fa",
70
+ # Government IDs (orange)
71
+ "ssn": "#f97316",
72
+ "national_id": "#fb923c",
73
+ "tax_id": "#fb923c",
74
+ # Financial (amber)
75
+ "account_number": "#f59e0b",
76
+ "bank_routing_number": "#fbbf24",
77
+ "swift_bic": "#fbbf24",
78
+ "credit_debit_card": "#fbbf24",
79
+ "cvv": "#fbbf24",
80
+ "pin": "#fbbf24",
81
+ "password": "#d97706",
82
+ # Healthcare (pink)
83
+ "medical_record_number": "#ec4899",
84
+ "health_plan_beneficiary_number": "#f472b6",
85
+ # Enterprise IDs (cyan)
86
+ "customer_id": "#06b6d4",
87
+ "employee_id": "#06b6d4",
88
+ "unique_id": "#22d3ee",
89
+ "certificate_license_number": "#22d3ee",
90
+ # Vehicle (lime)
91
+ "license_plate": "#84cc16",
92
+ "vehicle_identifier": "#84cc16",
93
+ # Digital (indigo)
94
+ "ipv4": "#6366f1",
95
+ "ipv6": "#6366f1",
96
+ "mac_address": "#818cf8",
97
+ "device_identifier": "#818cf8",
98
+ "api_key": "#4f46e5",
99
+ "http_cookie": "#4f46e5",
100
+ }
101
+
102
+
103
+ def list_devices() -> List[str]:
104
+ devs = ["cpu"]
105
+ if torch.cuda.is_available():
106
+ for i in range(torch.cuda.device_count()):
107
+ devs.append(f"cuda:{i}")
108
+ return devs
109
+
110
+
111
+ _pipe_cache: Dict[Tuple[str, str], object] = {}
112
+
113
+
114
+ def get_pipe(model_path: str, device: str):
115
+ key = (model_path, device)
116
+ if key in _pipe_cache:
117
+ return _pipe_cache[key]
118
+ dtype = torch.bfloat16 if device.startswith("cuda") else torch.float32
119
+ pipe = pipeline(
120
+ "token-classification",
121
+ model=model_path,
122
+ tokenizer=model_path,
123
+ trust_remote_code=True,
124
+ aggregation_strategy="simple",
125
+ device=device,
126
+ torch_dtype=dtype,
127
+ )
128
+ _pipe_cache[key] = pipe
129
+ return pipe
130
+
131
+
132
+ def apply_runtime_config(
133
+ pipe,
134
+ use_viterbi: bool,
135
+ viterbi_replace: bool,
136
+ top_k: Optional[int] = None,
137
+ ) -> None:
138
+ cfg = pipe.model.config
139
+ if hasattr(cfg, "use_viterbi_decode"):
140
+ cfg.use_viterbi_decode = bool(use_viterbi)
141
+ if hasattr(cfg, "viterbi_replace_logits"):
142
+ cfg.viterbi_replace_logits = bool(viterbi_replace)
143
+ # Override the per-layer MoE top-k at inference. Both fields need to be
144
+ # set: `mlp.router.top_k` is the actual router top-k, and the upstream
145
+ # `mlp.num_experts` is misnamed (it's also the per-token top_k, not
146
+ # num_local_experts). top_k=None leaves the trained config alone.
147
+ if top_k is not None:
148
+ n_local = int(getattr(cfg, "num_local_experts", 128))
149
+ k = max(1, min(int(top_k), n_local))
150
+ for layer in pipe.model.model.layers:
151
+ mlp = getattr(layer, "mlp", None)
152
+ if mlp is None:
153
+ continue
154
+ router = getattr(mlp, "router", None)
155
+ if router is not None and hasattr(router, "top_k"):
156
+ router.top_k = k
157
+ if hasattr(mlp, "num_experts"):
158
+ mlp.num_experts = k
159
+
160
+
161
+ def model_top_k_default(model_path: str) -> int:
162
+ """Read the trained `num_experts_per_tok` from the model's config without
163
+ loading the weights. Falls back to 4 if the field isn't present."""
164
+ try:
165
+ from transformers import AutoConfig
166
+ cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
167
+ return int(getattr(cfg, "num_experts_per_tok", 4))
168
+ except Exception:
169
+ return 4
170
+
171
+
172
+ def model_num_experts(model_path: str) -> int:
173
+ """Read `num_local_experts` from the model's config without loading
174
+ weights. Falls back to 128 if the field isn't present."""
175
+ try:
176
+ from transformers import AutoConfig
177
+ cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
178
+ return int(getattr(cfg, "num_local_experts", 128))
179
+ except Exception:
180
+ return 128
181
+
182
+
183
+ def clear_model_cache() -> str:
184
+ _pipe_cache.clear()
185
+ if torch.cuda.is_available():
186
+ torch.cuda.empty_cache()
187
+ return "Model cache cleared. Next run will reload weights."
188
+
189
+
190
+ def extract_text(file_obj) -> str:
191
+ if file_obj is None:
192
+ return ""
193
+ path = file_obj.name if hasattr(file_obj, "name") else file_obj
194
+ p = Path(path)
195
+ if p.suffix.lower() == ".pdf":
196
+ reader = PdfReader(str(p))
197
+ return "\n\n".join((page.extract_text() or "") for page in reader.pages)
198
+ return p.read_text(encoding="utf-8", errors="replace")
199
+
200
+
201
+ def chunk_text(text: str, max_chars: int = CHUNK_CHARS) -> List[Tuple[int, str]]:
202
+ if not text:
203
+ return []
204
+ if max_chars <= 0 or len(text) <= max_chars:
205
+ return [(0, text)]
206
+ pieces = re.split(r"(\n\s*\n)", text)
207
+ chunks: List[Tuple[int, str]] = []
208
+ cur, cur_off, pos = "", 0, 0
209
+ for piece in pieces:
210
+ if cur and len(cur) + len(piece) > max_chars and cur.strip():
211
+ chunks.append((cur_off, cur))
212
+ cur, cur_off = piece, pos
213
+ else:
214
+ if not cur:
215
+ cur_off = pos
216
+ cur += piece
217
+ pos += len(piece)
218
+ if cur.strip():
219
+ chunks.append((cur_off, cur))
220
+ return chunks
221
+
222
+
223
+ def category_of(label: str) -> str:
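+ """Strip the BIOES prefix ("B-email" -> "email"); plain labels pass through."""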
224
+ if len(label) > 2 and label[1] == "-":
225
+ return label[2:]
226
+ return label
227
+
228
+
229
+ def predict(
230
+ model_path: str,
231
+ device: str,
232
+ text: str,
233
+ aggregation: str,
234
+ use_viterbi: bool,
235
+ viterbi_replace: bool,
236
+ top_k: Optional[int] = None,
237
+ chunk_chars: int = CHUNK_CHARS,
238
+ ) -> List[Dict]:
239
+ if not text.strip():
240
+ return []
241
+ pipe = get_pipe(model_path, device)
242
+ apply_runtime_config(pipe, use_viterbi, viterbi_replace, top_k=top_k)
243
+ spans: List[Dict] = []
244
+ for offset, chunk in chunk_text(text, max_chars=chunk_chars):
245
+ for ent in pipe(chunk, aggregation_strategy=aggregation):
246
+ label = ent.get("entity_group") or ent.get("entity") or ""
247
+ cat = category_of(label)
248
+ if cat not in PALETTE:
249
+ continue
250
+ s = ent["start"] + offset
251
+ e = ent["end"] + offset
252
+ spans.append({
253
+ "start": s, "end": e, "label": cat,
254
+ "score": float(ent["score"]),
255
+ "text": text[s:e],
256
+ })
257
+ spans.sort(key=lambda s: (s["start"], s["end"]))
258
+ return spans
259
+
260
+
261
+ def to_highlight(text: str, spans: List[Dict]) -> List[Tuple[str, Optional[str]]]:
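+ """Convert sorted spans into (text, label) segments for gr.HighlightedText, skipping spans that overlap an earlier one."""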
262
+ if not text:
263
+ return []
264
+ out: List[Tuple[str, Optional[str]]] = []
265
+ cur = 0
266
+ for s in spans:
267
+ if s["start"] < cur:
268
+ continue
269
+ if s["start"] > cur:
270
+ out.append((text[cur:s["start"]], None))
271
+ out.append((text[s["start"]:s["end"]], s["label"]))
272
+ cur = s["end"]
273
+ if cur < len(text):
274
+ out.append((text[cur:], None))
275
+ return out
276
+
277
+
278
+ def fmt_spans(spans: List[Dict], max_rows: int = 60) -> str:
279
+ if not spans:
280
+ return "_No PII spans detected._"
281
+ rows = [
282
+ f"- `{s['label']}` &nbsp; `{s['text'][:80].replace('`', '')}` &nbsp; (score {s['score']:.2f})"
283
+ for s in spans[:max_rows]
284
+ ]
285
+ more = f"\n\n_…+{len(spans) - max_rows} more_" if len(spans) > max_rows else ""
286
+ return f"**Detected {len(spans)} span(s):**\n" + "\n".join(rows) + more
287
+
288
+
289
+ # Build legend HTML for the categories present in PALETTE — one row per family
290
+ # (we still want it readable; show one swatch per unique color).
291
+ def _legend_html() -> str:
292
+ seen = {}
293
+ for name, c in PALETTE.items():
294
+ seen.setdefault(c, []).append(name)
295
+ rows = []
296
+ for c, names in seen.items():
297
+ chip = (f"<span style='background:{c};color:#fff;padding:.15rem .55rem;"
298
+ f"border-radius:.3rem;font-family:monospace;'>"
299
+ f"{names[0]}{(' +'+str(len(names)-1)) if len(names)>1 else ''}</span>")
300
+ rows.append(chip)
301
+ return ("<div style='display:flex;flex-wrap:wrap;gap:.4rem;font-size:.85rem;"
302
+ "margin:.25rem 0;'>" + "".join(rows) + "</div>")
303
+
304
+
305
+ LEGEND_HTML = _legend_html()
306
+
307
+
308
+ def diff_spans(a: List[Dict], b: List[Dict]):
309
+ """Return (only_in_a, only_in_b, agreed) span-lists. Keys are the
310
+ (start, end, label) triple — agreement requires identical category."""
311
+ key = lambda s: (s["start"], s["end"], s["label"])
312
+ sa = {key(s): s for s in a}
313
+ sb = {key(s): s for s in b}
314
+ only_a = [sa[k] for k in sa if k not in sb]
315
+ only_b = [sb[k] for k in sb if k not in sa]
316
+ both = [sa[k] for k in sa if k in sb]
317
+ return only_a, only_b, both
318
+
319
+
320
+ def fmt_diff(label_a: str, label_b: str,
321
+ only_a: List[Dict], only_b: List[Dict], agreed: List[Dict]) -> str:
322
+ def fmt(name: str, lst: List[Dict]) -> str:
323
+ if not lst:
324
+ return f"**{name}:** none"
325
+ rows = [
326
+ f"- `{s['label']}` &nbsp; `{s['text'][:80].replace('`', '')}` "
327
+ f"&nbsp; (score {s['score']:.2f})"
328
+ for s in lst[:30]
329
+ ]
330
+ more = f"\n …+{len(lst) - 30} more" if len(lst) > 30 else ""
331
+ return f"**{name} ({len(lst)}):**\n" + "\n".join(rows) + more
332
+
333
+ return "\n\n".join([
334
+ fmt(f"Only {label_a}", only_a),
335
+ fmt(f"Only {label_b}", only_b),
336
+ fmt("Agreed by both", agreed),
337
+ ])
338
+
339
+
340
+ def run(
341
+ file_obj, pasted_text, device,
342
+ model_a_path, model_b_path,
343
+ use_a, use_b,
344
+ aggregation, use_viterbi, viterbi_replace,
345
+ top_k_a, top_k_b,
346
+ min_score, chunk_chars,
347
+ ):
348
+ text = extract_text(file_obj) if file_obj else (pasted_text or "")
349
+ if not text.strip():
350
+ return [], [], "_Provide a PDF, a text file, or pasted text._", ""
351
+ if not (use_a or use_b):
352
+ return [], [], "_Enable at least one model._", text
353
+
354
+ a_spans = (
355
+ predict(model_a_path, device, text, aggregation,
356
+ use_viterbi, viterbi_replace,
357
+ top_k=int(top_k_a), chunk_chars=int(chunk_chars))
358
+ if use_a else []
359
+ )
360
+ b_spans = (
361
+ predict(model_b_path, device, text, aggregation,
362
+ use_viterbi, viterbi_replace,
363
+ top_k=int(top_k_b), chunk_chars=int(chunk_chars))
364
+ if use_b else []
365
+ )
366
+ thr = float(min_score)
367
+ a_spans = [s for s in a_spans if s["score"] >= thr]
368
+ b_spans = [s for s in b_spans if s["score"] >= thr]
369
+
370
+ a_hl = to_highlight(text, a_spans) if use_a else []
371
+ b_hl = to_highlight(text, b_spans) if use_b else []
372
+
373
+ label_a = Path(model_a_path).name or model_a_path
374
+ label_b = Path(model_b_path).name or model_b_path
375
+
376
+ if use_a and use_b:
377
+ only_a, only_b, agreed = diff_spans(a_spans, b_spans)
378
+ diff_md = fmt_diff(label_a, label_b, only_a, only_b, agreed)
379
+ elif use_a:
380
+ diff_md = fmt_spans(a_spans)
381
+ elif use_b:
382
+ diff_md = fmt_spans(b_spans)
383
+ else:
384
+ diff_md = "_Enable a model._"
385
+
386
+ return a_hl, b_hl, diff_md, text
387
+
388
+
389
+ def build_ui(default_model_a: str, default_model_b: str) -> gr.Blocks:
390
+ a_default_k = model_top_k_default(default_model_a)
391
+ a_n_experts = model_num_experts(default_model_a)
392
+ b_default_k = model_top_k_default(default_model_b)
393
+ b_n_experts = model_num_experts(default_model_b)
394
+
395
+ with gr.Blocks(title="HarEmb PII", theme=gr.themes.Soft()) as demo:
396
+ gr.Markdown(
397
+ "# HarEmb · OpenMed-Nemotron PII\n"
398
+ "Detect PII across 55 categories of the Nemotron-PII taxonomy. "
399
+ "Run **two models side-by-side** to compare detections — by "
400
+ "default this checkpoint vs the OpenMed teacher it was distilled "
401
+ "from. Disable one model to view a single detection."
402
+ )
403
+ devices = list_devices()
404
+ with gr.Row():
405
+ device_dd = gr.Dropdown(devices, value=devices[0], label="Device", scale=1)
406
+ clear_btn = gr.Button("Clear model cache", variant="secondary", scale=1)
407
+
408
+ with gr.Row():
409
+ with gr.Column():
410
+ use_a = gr.Checkbox(value=True, label="Enable model A (teacher / baseline)")
411
+ model_a_tb = gr.Textbox(
412
+ value=default_model_a,
413
+ label="Model A — path / HF repo",
414
+ info="Default: OpenMed/privacy-filter-nemotron (teacher).",
415
+ )
416
+ top_k_a_sl = gr.Slider(
417
+ 1, a_n_experts, value=a_default_k, step=1,
418
+ label=f"Active experts per token (top-k of {a_n_experts})",
419
+ info=f"Trained value: {a_default_k}. Lower = faster + less "
420
+ f"capacity per token. Higher = more compute, denser "
421
+ f"routing. Bypassing the trained value can drop "
422
+ f"quality — useful for ablations.",
423
+ )
424
+ with gr.Column():
425
+ use_b = gr.Checkbox(value=True, label="Enable model B (this checkpoint)")
426
+ model_b_tb = gr.Textbox(
427
+ value=default_model_b,
428
+ label="Model B — path / HF repo",
429
+ info="Default: ./ (this checkpoint).",
430
+ )
431
+ top_k_b_sl = gr.Slider(
432
+ 1, b_n_experts, value=b_default_k, step=1,
433
+ label=f"Active experts per token (top-k of {b_n_experts})",
434
+ info=f"Trained value: {b_default_k}.",
435
+ )
436
+
437
+ with gr.Accordion("Inference settings", open=False):
438
+ with gr.Row():
439
+ aggregation_dd = gr.Dropdown(
440
+ ["simple", "first", "max", "average", "none"],
441
+ value="simple",
442
+ label="aggregation_strategy",
443
+ info="how token-level labels are merged into spans",
444
+ )
445
+ viterbi_cb = gr.Checkbox(
446
+ value=True,
447
+ label="use_viterbi_decode",
448
+ info="constrained BIOES decoding (off = raw argmax)",
449
+ )
450
+ viterbi_replace_cb = gr.Checkbox(
451
+ value=True,
452
+ label="viterbi_replace_logits",
453
+ info="when on, outputs.logits.argmax(-1) returns the Viterbi path",
454
+ )
455
+ min_score_sl = gr.Slider(
456
+ 0.0, 1.0, value=0.0, step=0.01,
457
+ label="min confidence",
458
+ info="filter out spans with score below this threshold",
459
+ )
460
+ chunk_sl = gr.Slider(
461
+ 0, 500_000, value=CHUNK_CHARS, step=10_000,
462
+ label="chunk size (chars)",
463
+ info="0 = single pass; otherwise split on paragraphs at this size. "
464
+ "Model window ≈131k tokens (~500k chars).",
465
+ )
466
+
467
+ with gr.Row():
468
+ file_in = gr.File(label="PDF / text file", file_types=[".pdf", ".txt", ".md"])
469
+ text_in = gr.Textbox(
470
+ label="…or paste text",
471
+ lines=6,
472
+ placeholder=("Patient Sarah Johnson (DOB 03/15/1985), MRN 4872910, "
473
+ "phone 415-555-0123, email sarah.johnson@example.com."),
474
+ )
475
+ run_btn = gr.Button("Detect PII", variant="primary")
476
+
477
+ gr.HTML(LEGEND_HTML)
478
+ with gr.Row():
479
+ a_out = gr.HighlightedText(
480
+ label="Model A detections",
481
+ color_map=PALETTE,
482
+ show_legend=False,
483
+ combine_adjacent=False,
484
+ )
485
+ b_out = gr.HighlightedText(
486
+ label="Model B detections",
487
+ color_map=PALETTE,
488
+ show_legend=False,
489
+ combine_adjacent=False,
490
+ )
491
+ diff_out = gr.Markdown("_Run a detection to see the diff / span list._")
492
+ extracted_out = gr.Textbox(
493
+ label="Extracted text (read-only)", lines=6, interactive=False,
494
+ )
495
+
496
+ run_btn.click(
497
+ run,
498
+ [file_in, text_in, device_dd,
499
+ model_a_tb, model_b_tb, use_a, use_b,
500
+ aggregation_dd, viterbi_cb, viterbi_replace_cb,
501
+ top_k_a_sl, top_k_b_sl,
502
+ min_score_sl, chunk_sl],
503
+ [a_out, b_out, diff_out, extracted_out],
504
+ )
505
+ clear_btn.click(clear_model_cache, None, diff_out)
506
+
507
+ return demo
508
+
509
+
510
+ def parse_args() -> argparse.Namespace:
511
+ p = argparse.ArgumentParser(description="HarEmb PII — Gradio demo")
512
+ p.add_argument("--host", default="127.0.0.1", help="Bind address (default: 127.0.0.1)")
513
+ p.add_argument("--port", type=int, default=7860, help="Port (default: 7860)")
514
+ p.add_argument("--share", action="store_true", help="Create a public Gradio share link")
515
+ p.add_argument("--model-a", default="OpenMed/privacy-filter-nemotron",
516
+ help="Model A path / HF repo "
517
+ "(default: OpenMed/privacy-filter-nemotron — teacher)")
518
+ p.add_argument("--model-b", default=DEFAULT_MODEL,
519
+ help="Model B path / HF repo "
520
+ "(default: . — this checkpoint)")
521
+ return p.parse_args()
522
+
523
+
524
+ if __name__ == "__main__":
525
+ args = parse_args()
526
+ build_ui(
527
+ default_model_a=args.model_a,
528
+ default_model_b=args.model_b,
529
+ ).launch(
530
+ server_name=args.host,
531
+ server_port=args.port,
532
+ share=args.share,
533
534
+ )
benchmark.py ADDED
@@ -0,0 +1,1351 @@
1
+ """
2
+ benchmark.py — single self-contained reproducibility script for
3
+ haremb-privacy-filter-opennemo.
4
+
5
+ Run from inside this folder:
6
+
7
+ python benchmark.py # default: cuda if available
8
+ python benchmark.py --device cpu # cpu fallback
9
+ python benchmark.py --eval-pct 0.5 # smaller slice
10
+ python benchmark.py --no-base # skip teacher download
11
+
12
+ Produces, in `--out` (default ./):
13
+ infer.log — sample inference timing + redaction example
14
+ compare.log — aggregate + per-category metrics, this model vs
15
+ OpenMed teacher (raw + viterbi streams), and
16
+ token-level pairwise breakdown.
17
+ eval_summary.png — bar charts of headline metrics + per-category
18
+ span-F1 (this vs teacher).
19
+ eval_confusion.png — token-level outcome breakdown on gold non-O
20
+ positions (this vs teacher).
21
+ eval_performance.png — model-size / compute / memory / throughput
22
+ comparison (this vs teacher), absolute + ratios.
23
+
24
+ This script does not import from training code. It vendors the small set
25
+ of helpers it needs (BIOES decoder, span builder, eval-set sampler,
26
+ metrics aggregator) so the model folder is self-contained.
27
+ """
28
+ from __future__ import annotations
29
+
30
+ import argparse
31
+ import ast
32
+ import math
33
+ import os
34
+ import sys
35
+ import time
36
+ from collections import Counter, defaultdict
37
+ from pathlib import Path
38
+ from typing import Dict, List, Tuple
39
+
40
+ import numpy as np
41
+ import torch
42
+ from datasets import load_dataset
43
+ from torch.utils.data import DataLoader, Dataset
44
+ from tqdm.auto import tqdm
45
+ from transformers import AutoModelForTokenClassification, AutoTokenizer
46
+
47
+
48
+ # ---------------------------------------------------------------------------
49
+ # Constants
50
+ # ---------------------------------------------------------------------------
51
+
52
+ SOURCE_DATASET = "nvidia/Nemotron-PII"
53
+ TEACHER = "OpenMed/privacy-filter-nemotron"
54
+
55
+ # 55 Nemotron-PII categories, alphabetically sorted (matches the order used
56
+ # when the model was trained, so id2label / label2id round-trip cleanly).
57
+ NEMOTRON_CATEGORIES: List[str] = sorted([
58
+ "account_number", "age", "api_key", "bank_routing_number",
59
+ "biometric_identifier", "blood_type", "certificate_license_number",
60
+ "city", "company_name", "coordinate", "country", "county",
61
+ "credit_debit_card", "customer_id", "cvv", "date", "date_of_birth",
62
+ "date_time", "device_identifier", "education_level", "email",
63
+ "employee_id", "employment_status", "fax_number", "first_name",
64
+ "gender", "health_plan_beneficiary_number", "http_cookie", "ipv4",
65
+ "ipv6", "language", "last_name", "license_plate", "mac_address",
66
+ "medical_record_number", "national_id", "occupation", "password",
67
+ "phone_number", "pin", "political_view", "postcode", "race_ethnicity",
68
+ "religious_belief", "sexuality", "ssn", "state", "street_address",
69
+ "swift_bic", "tax_id", "time", "unique_id", "url", "user_name",
70
+ "vehicle_identifier",
71
+ ])
72
+
73
+
74
+ def nemotron_native_label_space() -> Tuple[Dict[str, int], Dict[int, str]]:
75
+ """O at id 0, then {B, I, E, S}-{cat} for each cat in alphabetical order."""
76
+ label2id: Dict[str, int] = {"O": 0}
77
+ nxt = 1
78
+ for cat in NEMOTRON_CATEGORIES:
79
+ for prefix in ("B", "I", "E", "S"):
80
+ label2id[f"{prefix}-{cat}"] = nxt
81
+ nxt += 1
82
+ id2label: Dict[int, str] = {v: k for k, v in label2id.items()}
83
+ return label2id, id2label
84
+
85
+
86
+ # ---------------------------------------------------------------------------
87
+ # Span parsing + char-level token alignment (vendored from the training data
88
+ # pipeline; identical logic, no training imports)
89
+ # ---------------------------------------------------------------------------
90
+
91
+ def _trim_span(text: str, start: int, end: int) -> Tuple[int, int]:
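+ """Trim leading whitespace and trailing whitespace/closing punctuation from text[start:end]; returns the adjusted (start, end)."""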
92
+ raw = text[start:end]
93
+ i = 0
94
+ while i < len(raw) and raw[i].isspace():
95
+ i += 1
96
+ j = len(raw)
97
+ while j > i and (raw[j - 1].isspace() or raw[j - 1] in ".,;:)"):
98
+ j -= 1
99
+ return start + i, start + j
100
+
101
+
102
+ def _parse_spans(spans_str) -> List[dict]:
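+ """Spans may arrive as a list or as a Python-literal string; parse safely and return [] on malformed input."""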
103
+ if isinstance(spans_str, list):
104
+ return spans_str
105
+ try:
106
+ return ast.literal_eval(spans_str)
107
+ except (SyntaxError, ValueError):
108
+ return []
109
+
110
+
111
+ def _assign_native_bioes_labels(
112
+ text: str,
113
+ raw_spans: List[dict],
114
+ tokenizer,
115
+ max_length: int,
116
+ label2id: Dict[str, int],
117
+ min_overlap_frac: float = 0.5,
118
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
119
+ pf_like: List[Tuple[int, int, str]] = []
120
+ for s in raw_spans:
121
+ cat = s.get("label")
122
+ if not cat:
123
+ continue
124
+ st, en = _trim_span(text, int(s["start"]), int(s["end"]))
125
+ if st >= en:
126
+ continue
127
+ pf_like.append((st, en, cat))
128
+ pf_like.sort(key=lambda x: (x[0], -x[1]))
129
+
130
+ enc = tokenizer(
131
+ text, truncation=True, max_length=max_length,
132
+ padding="max_length", return_offsets_mapping=True, return_tensors="pt",
133
+ )
134
+ input_ids = enc.input_ids[0]
135
+ attention_mask = enc.attention_mask[0]
136
+ offsets = enc.offset_mapping[0].tolist()
137
+
138
+ o_id = label2id["O"]
139
+ label_ids = [o_id] * len(input_ids)
140
+ locked = [False] * len(input_ids)
141
+
142
+ for span_start, span_end, cat in pf_like:
143
+ tok_indices: List[int] = []
144
+ for ti, (s, e) in enumerate(offsets):
145
+ if s == 0 and e == 0:
146
+ continue
147
+ if e <= span_start or s >= span_end:
148
+ continue
149
+ tok_len = e - s
150
+ if tok_len <= 0:
151
+ continue
152
+ overlap = min(e, span_end) - max(s, span_start)
153
+ if overlap / tok_len >= min_overlap_frac:
154
+ tok_indices.append(ti)
155
+
156
+ if not tok_indices:
157
+ continue
158
+ tok_indices = [ti for ti in tok_indices if not locked[ti]]
159
+ if not tok_indices:
160
+ continue
161
+
162
+ if len(tok_indices) == 1:
163
+ tag = f"S-{cat}"
164
+ if tag in label2id:
165
+ label_ids[tok_indices[0]] = label2id[tag]
166
+ locked[tok_indices[0]] = True
167
+ else:
168
+ b_tag, i_tag, e_tag = f"B-{cat}", f"I-{cat}", f"E-{cat}"
169
+ if b_tag in label2id:
170
+ label_ids[tok_indices[0]] = label2id[b_tag]
171
+ locked[tok_indices[0]] = True
172
+ for ti in tok_indices[1:-1]:
173
+ if i_tag in label2id:
174
+ label_ids[ti] = label2id[i_tag]
175
+ locked[ti] = True
176
+ if e_tag in label2id:
177
+ label_ids[tok_indices[-1]] = label2id[e_tag]
178
+ locked[tok_indices[-1]] = True
179
+
180
+ label_tensor = torch.tensor(label_ids, dtype=torch.long)
181
+ label_tensor[attention_mask == 0] = -100
182
+ return input_ids, attention_mask, label_tensor
183
+
184
+
185
+ class _NemotronEvalDataset(Dataset):
186
+ def __init__(self, hf_split, tokenizer, label2id, max_length):
187
+ self.hf = hf_split
188
+ self.tok = tokenizer
189
+ self.l2i = label2id
190
+ self.maxlen = max_length
191
+
192
+ def __len__(self):
193
+ return len(self.hf)
194
+
195
+ def __getitem__(self, idx):
196
+ ex = self.hf[idx]
197
+ ids, mask, labels = _assign_native_bioes_labels(
198
+ ex["text"], _parse_spans(ex["spans"]),
199
+ self.tok, self.maxlen, self.l2i,
200
+ )
201
+ L = int(mask.sum().item())
202
+ return {
203
+ "input_ids": ids[:L].tolist(),
204
+ "labels": labels[:L].tolist(),
205
+ "valid_len": L,
206
+ }
207
+
208
+
209
+ def _make_collate(pad_token_id, max_length):
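+ """Build a collate_fn that pads each batch to its longest (truncated) sequence; padding uses pad_token_id / mask 0 / label -100."""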
210
+ def _c(batch):
211
+ ids_list = [list(ex["input_ids"])[:max_length] for ex in batch]
212
+ labels_list = [list(ex["labels"])[:max_length] for ex in batch]
213
+ max_len = max(len(x) for x in ids_list)
214
+ B = len(batch)
215
+ input_ids = torch.full((B, max_len), pad_token_id, dtype=torch.long)
216
+ attention_mask = torch.zeros((B, max_len), dtype=torch.long)
217
+ labels = torch.full((B, max_len), -100, dtype=torch.long)
218
+ for i, (ids, lab) in enumerate(zip(ids_list, labels_list)):
219
+ L = len(ids)
220
+ input_ids[i, :L] = torch.tensor(ids, dtype=torch.long)
221
+ attention_mask[i, :L] = 1
222
+ labels[i, :L] = torch.tensor(lab, dtype=torch.long)
223
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
224
+ return _c
225
+
226
+
227
+ def _build_eval_streaming(test_split, target_n, chunk_size, seed) -> List[int]:
228
+ """Uniform chunked sampling, identical to the training-time eval split."""
229
+ n_total = len(test_split)
230
+ target_n = min(target_n, n_total)
231
+ if target_n <= 0:
232
+ return []
233
+ rng = np.random.RandomState(seed)
234
+ per_chunk = max(1, math.ceil(chunk_size * target_n / n_total))
235
+ selected: List[int] = []
236
+ for chunk_start in range(0, n_total, chunk_size):
237
+ if len(selected) >= target_n:
238
+ break
239
+ chunk_end = min(chunk_start + chunk_size, n_total)
240
+ n_in_chunk = chunk_end - chunk_start
241
+ n_to_pick = min(per_chunk, n_in_chunk, target_n - len(selected))
242
+ if n_to_pick <= 0:
243
+ break
244
+ offsets = rng.choice(n_in_chunk, size=n_to_pick, replace=False)
245
+ selected.extend(int(chunk_start + o) for o in offsets)
246
+ return sorted(selected[:target_n])
247
+
248
+
249
+ # ---------------------------------------------------------------------------
250
+ # BIOES → spans + metrics
251
+ # ---------------------------------------------------------------------------
252
+
253
+ def _bioes_to_spans(labels, id2label, o_id=0):
254
+ """Convert per-token BIOES label ids to a set of (start, end, cat)."""
255
+ spans = set()
256
+ cur_start = None
257
+ cur_cat = None
258
+ for i, lid in enumerate(labels):
259
+ lid = int(lid)
260
+ if lid == o_id or lid < 0:
261
+ if cur_start is not None:
262
+ spans.add((cur_start, i, cur_cat))
263
+ cur_start = None
264
+ cur_cat = None
265
+ continue
266
+ tag = id2label.get(lid, "O")
267
+ if tag == "O" or "-" not in tag:
268
+ if cur_start is not None:
269
+ spans.add((cur_start, i, cur_cat))
270
+ cur_start = None
271
+ cur_cat = None
272
+ continue
273
+ prefix, cat = tag.split("-", 1)
274
+ if prefix == "S":
275
+ if cur_start is not None:
276
+ spans.add((cur_start, i, cur_cat))
277
+ spans.add((i, i + 1, cat))
278
+ cur_start = None
279
+ cur_cat = None
280
+ elif prefix == "B":
281
+ if cur_start is not None:
282
+ spans.add((cur_start, i, cur_cat))
283
+ cur_start = i
284
+ cur_cat = cat
285
+ elif prefix == "I":
286
+ if cur_start is None or cur_cat != cat:
287
+ if cur_start is not None:
288
+ spans.add((cur_start, i, cur_cat))
289
+ cur_start = i
290
+ cur_cat = cat
291
+ elif prefix == "E":
292
+ if cur_start is None or cur_cat != cat:
293
+ if cur_start is not None:
294
+ spans.add((cur_start, i, cur_cat))
295
+ spans.add((i, i + 1, cat))
296
+ cur_start = None
297
+ cur_cat = None
298
+ else:
299
+ spans.add((cur_start, i + 1, cur_cat))
300
+ cur_start = None
301
+ cur_cat = None
302
+ if cur_start is not None:
303
+ spans.add((cur_start, len(labels), cur_cat))
304
+ return spans
305
+
306
+
307
+ def _aggregate_span_metrics(gold_spans, pred_spans):
308
+ correct = gold_spans & pred_spans
309
+ n_gold = len(gold_spans)
310
+ n_pred = len(pred_spans)
311
+ n_correct = len(correct)
312
+ p = n_correct / n_pred if n_pred else 0.0
313
+ r = n_correct / n_gold if n_gold else 0.0
314
+ f1 = (2 * p * r / (p + r)) if (p + r) else 0.0
315
+ per_cat: Dict[str, dict] = {}
316
+ cats = sorted({c for _, _, c in gold_spans} | {c for _, _, c in pred_spans})
317
+ for cat in cats:
318
+ g_c = {s for s in gold_spans if s[2] == cat}
319
+ p_c = {s for s in pred_spans if s[2] == cat}
320
+ c_c = g_c & p_c
321
+ pp = len(c_c) / len(p_c) if p_c else 0.0
322
+ rr = len(c_c) / len(g_c) if g_c else 0.0
323
+ ff = (2 * pp * rr / (pp + rr)) if (pp + rr) else 0.0
324
+ per_cat[cat] = {"precision": pp, "recall": rr, "f1": ff,
325
+ "n_gold": len(g_c), "n_pred": len(p_c), "n_correct": len(c_c)}
326
+ return {"precision": p, "recall": r, "f1": f1,
327
+ "n_gold": n_gold, "n_pred": n_pred, "n_correct": n_correct,
328
+ "per_cat": per_cat}
329
+
330
+
331
+ def _stream_metrics(docs, stream, id2label, o_id):
332
+ """Aggregate metrics over a list of {gold, raw, viterbi} per-doc dicts."""
333
+ n_tokens = correct = n_non_o = non_o_correct = 0
334
+ gold_spans_all: set = set()
335
+ pred_spans_all: set = set()
336
+ doc_offset = 0
337
+ for doc in docs:
338
+ gold = [int(x) for x in doc["gold"]]
339
+ pred = [int(x) for x in doc[stream]]
340
+ n = len(gold)
341
+ n_tokens += n
342
+ for g, p in zip(gold, pred):
343
+ n_non_o += int(g != o_id)
344
+ if g == p:
345
+ correct += 1
346
+ if g != o_id:
347
+ non_o_correct += 1
348
+ gs = _bioes_to_spans(gold, id2label, o_id)
349
+ ps = _bioes_to_spans(pred, id2label, o_id)
350
+ gold_spans_all.update((doc_offset + s, doc_offset + e, c) for s, e, c in gs)
351
+ pred_spans_all.update((doc_offset + s, doc_offset + e, c) for s, e, c in ps)
352
+ doc_offset += n
353
+ span_m = _aggregate_span_metrics(gold_spans_all, pred_spans_all)
354
+ return {
355
+ "n_tokens": n_tokens,
356
+ "n_non_o": n_non_o,
357
+ "token_acc": correct / n_tokens if n_tokens else 0.0,
358
+ "non_o_recall": non_o_correct / n_non_o if n_non_o else 0.0,
359
+ "span_precision": span_m["precision"],
360
+ "span_recall": span_m["recall"],
361
+ "span_f1": span_m["f1"],
362
+ "n_gold_spans": span_m["n_gold"],
363
+ "n_pred_spans": span_m["n_pred"],
364
+ "n_correct_spans": span_m["n_correct"],
365
+ "span_per_cat": span_m["per_cat"],
366
+ }
367
+
368
+
369
+ # ---------------------------------------------------------------------------
370
+ # Forward pass + viterbi (delegates to the released modeling)
371
+ # ---------------------------------------------------------------------------
372
+
373
+ def _model_perf_stats(model, dtype) -> dict:
374
+ """Total / active / compute / MoE breakdown + on-device byte size.
375
+
376
+ Three distinct param counts:
377
+
378
+ * total_params — every parameter the model has on disk / in RAM.
379
+ * active_params_per_tok — params *touched* during one token's forward
380
+ pass (memory footprint per token). Counts
381
+ the embedding because the embedding row
382
+ IS read per token; counts only top-k of
383
+ num_experts MoE experts because routing is
384
+ sparse.
385
+ * compute_params_per_tok — params that contribute matmul FLOPs per
386
+ token. EXCLUDES the embedding table:
387
+ `embed_tokens.weight` is a gather (one row
388
+ read), not a matmul, so its FLOP cost is
389
+ negligible (~hidden_size ops vs the table
390
+ having ~vocab × hidden params). Counting
391
+ it via the standard "2 × params" matmul
392
+ approximation hugely inflates the apparent
393
+ GFLOP/token and compresses the ratio between
394
+ deep and shallow models.
395
+
396
+ GFLOP/token is computed from `compute_params_per_tok`, not from
397
+ `active_params_per_tok`. This makes the metric reflect actual layer-wise
398
+ computational cost.
399
+ """
400
+ cfg = model.config
401
+ num_experts = int(getattr(cfg, "num_local_experts", 1))
402
+ top_k = int(getattr(cfg, "num_experts_per_tok", num_experts))
403
+ expert_frac = top_k / max(1, num_experts)
404
+
405
+ moe_total = 0
406
+ moe_active = 0
407
+ other_total = 0
408
+ embed_total = 0 # gather-only params; excluded from FLOP estimate
409
+ for name, p in model.named_parameters():
410
+ n = p.numel()
411
+ # MoE expert tensors are stacked along an experts axis. The upstream
412
+ # exposes them under `mlp.experts.*`. Only `top_k` of `num_experts`
413
+ # experts contribute per token.
414
+ if ".mlp.experts." in name:
415
+ moe_total += n
416
+ moe_active += int(round(n * expert_frac))
417
+ # `embed_tokens.weight` (and any other lookup-style table) is a
418
+ # gather: one row of [vocab, hidden] is read per token, costing
419
+ # ~hidden ops, not 2 × vocab × hidden FLOPs. Tracked separately
420
+ # so it doesn't pollute the FLOP estimate.
421
+ elif "embed_tokens" in name:
422
+ embed_total += n
423
+ other_total += n
424
+ else:
425
+ other_total += n
426
+
427
+ total = moe_total + other_total
428
+ active_per_tok = moe_active + other_total
429
+
430
+ # Compute-relevant params per token: matmuls only. Drop the embedding
431
+ # lookup, which contributes ~zero FLOPs.
432
+ compute_per_tok = active_per_tok - embed_total
433
+ bytes_per_param = {torch.bfloat16: 2, torch.float16: 2, torch.float32: 4}.get(dtype, 4)
434
+ storage_bytes = total * bytes_per_param # in-memory dense weights
435
+ # GFLOP/token over matmul params only. Matmul ≈ 2 FLOPs per param.
436
+ gflops_per_tok = 2 * compute_per_tok / 1e9
437
+ return {
438
+ "total_params": total,
439
+ "moe_total": moe_total,
440
+ "moe_active_per_tok": moe_active,
441
+ "other_total": other_total,
442
+ "embed_total": embed_total,
443
+ "active_params_per_tok": active_per_tok,
444
+ "compute_params_per_tok": compute_per_tok,
445
+ "num_experts": num_experts,
446
+ "experts_per_tok": top_k,
447
+ "expert_frac": expert_frac,
448
+ "weight_bytes": storage_bytes,
449
+ "gflops_per_tok": gflops_per_tok,
450
+ }
451
+
452
+
453
+ def _disk_size_bytes(model_path: str) -> int:
454
+ """Sum on-disk size of weight files at the given path. Falls back to 0
455
+ if the path is a HF repo id (not a local directory)."""
456
+ p = Path(model_path)
457
+ if not p.is_dir():
458
+ return 0
459
+ total = 0
460
+ for f in p.iterdir():
461
+ if f.is_file() and f.suffix in {".safetensors", ".bin", ".pt", ".pth"}:
462
+ total += f.stat().st_size
463
+ return total
464
+
465
+
466
+ def _eval_one_model(
467
+ model_path: str, tokenizer, eval_ds, label2id, id2label, o_id,
468
+ bioes_trans, bioes_init, batch_size, max_length, device, dtype,
469
+ label: str,
470
+ ):
471
+ print(f"[eval] loading {label} from {model_path} ...", flush=True)
472
+ if torch.cuda.is_available() and device.type == "cuda":
473
+ torch.cuda.reset_peak_memory_stats(device)
474
+ torch.cuda.empty_cache()
475
+ mem_before = (torch.cuda.memory_allocated(device)
476
+ if torch.cuda.is_available() and device.type == "cuda" else 0)
477
+ t_load = time.time()
478
+ model = AutoModelForTokenClassification.from_pretrained(
479
+ model_path, dtype=dtype, trust_remote_code=True,
480
+ ).to(device).eval()
481
+ if hasattr(model.config, "use_viterbi_decode"):
482
+ model.config.use_viterbi_decode = True
483
+ if hasattr(model.config, "viterbi_replace_logits"):
484
+ model.config.viterbi_replace_logits = False
485
+ load_s = time.time() - t_load
486
+ perf = _model_perf_stats(model, dtype)
487
+ perf["disk_size_bytes"] = _disk_size_bytes(model_path)
488
+ weights_resident_bytes = (
489
+ torch.cuda.memory_allocated(device) - mem_before
490
+ if torch.cuda.is_available() and device.type == "cuda" else perf["weight_bytes"]
491
+ )
492
+ perf["weights_resident_bytes"] = weights_resident_bytes
493
+
494
+ # Reuse the batched viterbi from the released modeling file.
495
+ from modeling_haremb_pii import _bioes_viterbi_batched
496
+
497
+ pad_token_id = tokenizer.pad_token_id or 199999  # fallback pad id if the tokenizer defines none
498
+ loader = DataLoader(
499
+ eval_ds, batch_size=batch_size, shuffle=False,
500
+ collate_fn=_make_collate(pad_token_id, max_length), num_workers=2,
501
+ )
502
+ docs: List[dict] = []
503
+ n_tok = 0
504
+ if torch.cuda.is_available() and device.type == "cuda":
505
+ torch.cuda.synchronize()
506
+ torch.cuda.reset_peak_memory_stats(device)
507
+ t0 = time.time()
508
+ for batch in tqdm(loader, desc=f"eval {label}", unit="batch", leave=False):
509
+ ids = batch["input_ids"].to(device, non_blocking=True)
510
+ mask = batch["attention_mask"].to(device, non_blocking=True)
511
+ gold = batch["labels"].to(device, non_blocking=True)
512
+ with torch.no_grad():
513
+ out = model(input_ids=ids, attention_mask=mask)
514
+ raw = out.logits.argmax(dim=-1)
515
+ vit = _bioes_viterbi_batched(out.logits.float(), mask, bioes_trans, bioes_init)
516
+ valid = (gold != -100) & mask.bool()
517
+ for b in range(gold.shape[0]):
518
+ keep = [i for i, ok in enumerate(valid[b].cpu().tolist()) if ok]
519
+ n_tok += len(keep)
520
+ docs.append({
521
+ "gold": [int(gold[b, i].item()) for i in keep],
522
+ "raw": [int(raw[b, i].item()) for i in keep],
523
+ "viterbi": [int(vit[b, i].item()) for i in keep],
524
+ })
525
+ if torch.cuda.is_available() and device.type == "cuda":
526
+ torch.cuda.synchronize()
527
+ peak_mem = torch.cuda.max_memory_allocated(device)
528
+ else:
529
+ peak_mem = 0
530
+ eval_s = time.time() - t0
531
+
532
+ raw_m = _stream_metrics(docs, "raw", id2label, o_id)
533
+ vit_m = _stream_metrics(docs, "viterbi", id2label, o_id)
534
+
535
+ perf["peak_eval_mem_bytes"] = peak_mem
536
+
537
+ del model
538
+ if torch.cuda.is_available():
539
+ torch.cuda.empty_cache()
540
+
541
+ return {
542
+ "label": label,
543
+ "n_total_M": perf["total_params"] / 1e6,
544
+ "load_s": load_s,
545
+ "eval_s": eval_s,
546
+ "n_tok": n_tok,
547
+ "throughput_tok_s": n_tok / eval_s if eval_s else 0.0,
548
+ "perf": perf,
549
+ "raw": raw_m,
550
+ "viterbi": vit_m,
551
+ "docs": docs,
552
+ }
553
+
554
+
555
+ # ---------------------------------------------------------------------------
556
+ # Pairwise token-level breakdown (this vs reference, viterbi stream)
557
+ # ---------------------------------------------------------------------------
558
+
559
+ def _pairwise(docs_cand, docs_ref, id2label, o_id):
560
+ both_correct = only_cand = only_ref = both_wrong = 0
561
+ by_cat = defaultdict(lambda: {"both_correct": 0, "only_cand": 0, "only_ref": 0, "both_wrong": 0})
562
+ for dc, dr in zip(docs_cand, docs_ref):
563
+ gold = dc["gold"]
564
+ cv = dc["viterbi"]
565
+ rv = dr["viterbi"]
566
+ for g, c, r in zip(gold, cv, rv):
567
+ cat = id2label.get(g, "O")
568
+ cat = cat.split("-", 1)[1] if "-" in cat else cat
569
+ cc = (c == g)
570
+ rc = (r == g)
571
+ if cc and rc:
572
+ both_correct += 1
573
+ by_cat[cat]["both_correct"] += 1
574
+ elif cc and not rc:
575
+ only_cand += 1
576
+ by_cat[cat]["only_cand"] += 1
577
+ elif rc and not cc:
578
+ only_ref += 1
579
+ by_cat[cat]["only_ref"] += 1
580
+ else:
581
+ both_wrong += 1
582
+ by_cat[cat]["both_wrong"] += 1
583
+ return {
584
+ "both_correct": both_correct,
585
+ "only_cand_correct": only_cand,
586
+ "only_ref_correct": only_ref,
587
+ "both_wrong": both_wrong,
588
+ "by_cat": dict(by_cat),
589
+ }
590
+
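+ # Sanity check: the four buckets partition every scored gold token; in the
+ # shipped compare.log, 209,830 + 958 + 633 + 1,488 = 212,909 tokens scored.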
591
+
592
+ # ---------------------------------------------------------------------------
593
+ # Plot rendering
594
+ # ---------------------------------------------------------------------------
595
+
596
+ def _render_plots(cand, ref, pair, out_dir: Path, cand_label, ref_label):
597
+ """Render benchmark plots.
598
+
599
+ Visual convention:
600
+ A = reference / teacher / baseline
601
+ B = candidate / this checkpoint
602
+
603
+ The charts avoid color-only "win" encoding: labels state the actual delta
604
+ or ratio, and horizontal layouts keep long metric/category names readable.
605
+ """
606
+ try:
607
+ import matplotlib
608
+ matplotlib.use("Agg")
609
+ import matplotlib.pyplot as plt
610
+ from matplotlib.ticker import FuncFormatter
611
+ except ImportError:
612
+ print("[plot] matplotlib not installed, skipping", flush=True)
613
+ return
614
+
615
+ plt.rcParams.update({
616
+ "figure.facecolor": "#ffffff",
617
+ "axes.facecolor": "#ffffff",
618
+ "axes.edgecolor": "#cbd5e1",
619
+ "axes.labelcolor": "#0f172a",
620
+ "xtick.color": "#334155",
621
+ "ytick.color": "#334155",
622
+ "grid.color": "#e2e8f0",
623
+ "font.size": 9,
624
+ "axes.titleweight": "bold",
625
+ "axes.titlesize": 11,
626
+ "legend.frameon": False,
627
+ })
628
+
629
+ C_REF = "#64748b" # slate
630
+ C_CAND = "#2563eb" # blue
631
+ C_GOOD = "#0f766e" # teal
632
+ C_BAD = "#b91c1c" # red
633
+ C_NEUTRAL = "#94a3b8" # light slate
634
+ C_BG = "#f8fafc"
635
+
636
+ def _pct_axis(ax):
637
+ ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _pos: f"{x:.0%}"))
638
+
639
+ def _metric_delta_text(delta):
640
+ return f"{delta:+.4f}"
641
+
642
+ def _value_text(v):
643
+ if v >= 100:
644
+ return f"{v:,.0f}"
645
+ if v >= 10:
646
+ return f"{v:,.1f}"
647
+ if v >= 1:
648
+ return f"{v:,.2f}"
649
+ return f"{v:,.3f}"
650
+
651
+ def _ratio_text(r, lower_is_better):
652
+ if r <= 0:
653
+ return "n/a"
654
+ if lower_is_better:
655
+ return f"{1.0 / r:.2f}x lower" if r <= 1 else f"{r:.2f}x higher"
656
+ return f"{r:.2f}x higher" if r >= 1 else f"{1.0 / r:.2f}x lower"
657
+
658
+ # --- eval_summary.png: headline metrics + category-level deltas ---
659
+ fig, axes = plt.subplots(2, 2, figsize=(8, 5), constrained_layout=True)
660
+ ax_head = axes[0, 0]
661
+ ax_delta = axes[0, 1]
662
+ ax_raw_vit = axes[1, 0]
663
+ ax_cat = axes[1, 1]
664
+
665
+ metrics = ["span_f1", "span_precision", "span_recall", "token_acc", "non_o_recall"]
666
+ labels = ["Span F1", "Span P", "Span R", "Token acc", "Non-O recall"]
667
+ cand_v = [cand["viterbi"][m] for m in metrics]
668
+ ref_v = [ref["viterbi"][m] for m in metrics] if ref is not None else None
669
+
670
+ y = np.arange(len(metrics))
671
+ if ref_v is not None:
672
+ ax_head.hlines(y, ref_v, cand_v, color=C_NEUTRAL, linewidth=2, alpha=0.9)
673
+ ax_head.scatter(ref_v, y, s=55, color=C_REF, label=f"A: {ref_label}", zorder=3)
674
+ ax_head.scatter(cand_v, y, s=70, color=C_CAND, label=f"B: {cand_label}", zorder=4)
675
+ ax_head.set_yticks(y)
676
+ ax_head.set_yticklabels(labels)
677
+ ax_head.invert_yaxis()
678
+ ax_head.set_xlim(max(0.0, min(cand_v + (ref_v or cand_v)) - 0.02),
679
+ min(1.08, max(cand_v + (ref_v or cand_v)) + 0.05))
680
+ ax_head.set_title("Headline metrics, Viterbi stream")
681
+ ax_head.grid(axis="x", alpha=0.7)
682
+ _pct_axis(ax_head)
683
+ if ref_v is not None:
684
+ for i, v in enumerate(ref_v):
685
+ ax_head.text(v + 0.002, i, f"{v:.4f}", va="center", ha="left",
686
+ fontsize=7, color=C_REF)
687
+ for i, v in enumerate(cand_v):
688
+ ax_head.text(v - 0.002, i, f"{v:.4f}", va="center", ha="right",
689
+ fontsize=7, color=C_CAND)
690
+
691
+ if ref_v is not None:
692
+ deltas = [b - a for a, b in zip(ref_v, cand_v)]
693
+ colors = [C_GOOD if d >= 0 else C_BAD for d in deltas]
694
+ ax_delta.axvline(0, color="#0f172a", linewidth=0.9)
695
+ ax_delta.barh(y, deltas, color=colors, alpha=0.9)
696
+ ax_delta.set_yticks(y)
697
+ ax_delta.set_yticklabels(labels)
698
+ ax_delta.invert_yaxis()
699
+ ax_delta.set_title("Delta: B minus A")
700
+ ax_delta.grid(axis="x", alpha=0.7)
701
+ max_abs = max([abs(d) for d in deltas] + [0.002])
702
+ ax_delta.set_xlim(-max_abs * 1.45, max_abs * 1.45)
703
+ for i, d in enumerate(deltas):
704
+ ax_delta.text(d + (max_abs * 0.04 if d >= 0 else -max_abs * 0.04), i, _metric_delta_text(d),
705
+ ha="left" if d >= 0 else "right", va="center", fontsize=7)
706
+ else:
707
+ ax_delta.axis("off")
708
+
709
+ stream_rows = [
710
+ ("A raw", ref["raw"]["span_f1"] if ref is not None else None, C_REF),
711
+ ("A viterbi", ref["viterbi"]["span_f1"] if ref is not None else None, C_REF),
712
+ ("B raw", cand["raw"]["span_f1"], C_CAND),
713
+ ("B viterbi", cand["viterbi"]["span_f1"], C_CAND),
714
+ ]
715
+ stream_rows = [r for r in stream_rows if r[1] is not None]
716
+ sy = np.arange(len(stream_rows))
717
+ ax_raw_vit.barh(sy, [r[1] for r in stream_rows], color=[r[2] for r in stream_rows], alpha=0.88)
718
+ ax_raw_vit.set_yticks(sy)
719
+ ax_raw_vit.set_yticklabels([r[0] for r in stream_rows])
720
+ ax_raw_vit.invert_yaxis()
721
+ ax_raw_vit.set_xlim(0, 1.08)
722
+ ax_raw_vit.set_title("Raw vs Viterbi span F1")
723
+ ax_raw_vit.grid(axis="x", alpha=0.7)
724
+ _pct_axis(ax_raw_vit)
725
+ for i, (_, v, _) in enumerate(stream_rows):
726
+ ax_raw_vit.text(v + 0.008, i, f"{v:.4f}", va="center", fontsize=7)
727
+
728
+ cand_pc = cand["viterbi"]["span_per_cat"]
729
+ if ref is not None:
730
+ ref_pc = ref["viterbi"]["span_per_cat"]
731
+ cats = sorted(set(cand_pc) | set(ref_pc))
732
+ rows = []
733
+ for c in cats:
734
+ a = ref_pc.get(c, {}).get("f1", 0.0)
735
+ b = cand_pc.get(c, {}).get("f1", 0.0)
736
+ n = max(cand_pc.get(c, {}).get("n_gold", 0), ref_pc.get(c, {}).get("n_gold", 0))
737
+ rows.append((c, b - a, b, a, n))
738
+ # Keep the categories that explain the comparison: worst and best B deltas.
739
+ worst = sorted(rows, key=lambda r: r[1])[:8]
740
+ best = sorted(rows, key=lambda r: r[1], reverse=True)[:8]
741
+ picked = worst + [r for r in best if r[0] not in {x[0] for x in worst}]
742
+ picked = sorted(picked, key=lambda r: r[1])
743
+ cy = np.arange(len(picked))
744
+ deltas = [r[1] for r in picked]
745
+ ax_cat.axvline(0, color="#0f172a", linewidth=0.9)
746
+ ax_cat.barh(cy, deltas, color=[C_GOOD if d >= 0 else C_BAD for d in deltas])
747
+ ax_cat.set_yticks(cy)
748
+ ax_cat.set_yticklabels([r[0] for r in picked], fontsize=8)
749
+ ax_cat.set_title("Per-category span F1 delta, selected extremes")
750
+ ax_cat.grid(axis="x", alpha=0.7)
751
+ max_abs = max([abs(d) for d in deltas] + [0.05])
752
+ ax_cat.set_xlim(-max_abs * 1.55, max_abs * 1.55)
753
+ for i, r in enumerate(picked):
754
+ d = r[1]
755
+ ax_cat.text(d + (max_abs * 0.05 if d >= 0 else -max_abs * 0.05), i,
756
+ f"{d:+.3f} B={r[2]:.2f} A={r[3]:.2f}",
757
+ va="center", ha="left" if d >= 0 else "right", fontsize=6)
758
+ else:
759
+ cats_sorted = sorted(cand_pc.keys(), key=lambda c: cand_pc[c]["f1"])[:18]
760
+ vals = [cand_pc[c]["f1"] for c in cats_sorted]
761
+ cy = np.arange(len(cats_sorted))
762
+ ax_cat.barh(cy, vals, color=C_CAND)
763
+ ax_cat.set_yticks(cy)
764
+ ax_cat.set_yticklabels(cats_sorted, fontsize=8)
765
+ ax_cat.set_xlim(0, 1.0)
766
+ ax_cat.set_title("Lowest per-category span F1")
767
+ ax_cat.grid(axis="x", alpha=0.7)
768
+ _pct_axis(ax_cat)
769
+
770
+ fig.suptitle(f"Evaluation summary — A: {ref_label if ref else 'n/a'} | B: {cand_label}",
771
+ fontsize=9, fontweight="bold")
772
+ fig.savefig(out_dir / "eval_summary.png", dpi=160)
773
+ plt.close(fig)
774
+ print(f"[plot] wrote {out_dir / 'eval_summary.png'}", flush=True)
775
+
776
+ # --- eval_confusion.png: pairwise outcome on gold non-O tokens ---
777
+ # Display order matches A vs B: "Only A correct" (teacher) before
778
+ # "Only B correct" (student). Underlying buckets in `pair` are still
779
+ # named cand/ref; we just relabel for display.
780
+ if pair is not None:
781
+ fig, axes = plt.subplots(
782
+ 1, 2, figsize=(8, 3), constrained_layout=True,
783
+ gridspec_kw={"width_ratios": [0.9, 1.7]},
784
+ )
785
+ ax = axes[0]
786
+ non_o_buckets = {k: 0 for k in ["both_correct", "only_cand", "only_ref", "both_wrong"]}
787
+ for cat, d in pair["by_cat"].items():
788
+ if cat == "O":
789
+ continue
790
+ for k in non_o_buckets:
791
+ non_o_buckets[k] += d[k]
792
+ values = [non_o_buckets["both_correct"], non_o_buckets["only_ref"],
793
+ non_o_buckets["only_cand"], non_o_buckets["both_wrong"]]
794
+ labels_ = [
795
+ "Both\ncorrect",
796
+ "Only A\ncorrect",
797
+ "Only B\ncorrect",
798
+ "Both wrong",
799
+ ]
800
+ colors = [C_GOOD, C_REF, C_CAND, C_BAD]
801
+ total = max(1, sum(values))
802
+ ax.barh(np.arange(4), values, color=colors)
803
+ ax.set_yticks(np.arange(4))
804
+ ax.set_yticklabels(labels_)
805
+ ax.invert_yaxis()
806
+ ax.set_ylabel("Gold non-O tokens")
807
+ ax.set_title("Token outcome on gold non-O")
808
+ ax.grid(axis="x", alpha=0.7)
809
+ ax.set_xlim(0, max(values) * 1.32 if values else 1)
810
+ for i, v in enumerate(values):
811
+ ax.text(v + max(values) * 0.015, i, f"{v:,} ({v / total:.1%})",
812
+ va="center", fontsize=6)
813
+
814
+ rows = []
815
+ for cat, d in pair["by_cat"].items():
816
+ if cat == "O":
817
+ continue
818
+ net = d["only_cand"] - d["only_ref"]
819
+ active = d["only_cand"] + d["only_ref"] + d["both_wrong"]
820
+ if active:
821
+ rows.append((cat, net, d["only_cand"], d["only_ref"], d["both_wrong"]))
822
+ worst = sorted(rows, key=lambda r: r[1])[:8]
823
+ best = sorted(rows, key=lambda r: r[1], reverse=True)[:8]
824
+ picked = worst + [r for r in best if r[0] not in {x[0] for x in worst}]
825
+ picked = sorted(picked, key=lambda r: r[1])
826
+ ax2 = axes[1]
827
+ if picked:
828
+ py = np.arange(len(picked))
829
+ nets = [r[1] for r in picked]
830
+ ax2.axvline(0, color="#0f172a", linewidth=0.9)
831
+ ax2.barh(py, nets, color=[C_GOOD if n >= 0 else C_BAD for n in nets])
832
+ ax2.set_yticks(py)
833
+ ax2.set_yticklabels([r[0] for r in picked], fontsize=8)
834
+ ax2.set_title("Net token wins by category: B only-correct minus A only-correct")
835
+ ax2.grid(axis="x", alpha=0.7)
836
+ max_abs = max([abs(n) for n in nets] + [1])
837
+ ax2.set_xlim(-max_abs * 1.5, max_abs * 1.5)
838
+ for i, r in enumerate(picked):
839
+ n = r[1]
840
+ label_x = n + max_abs * 0.04 if n >= 0 else max_abs * 0.05
841
+ ax2.text(label_x, i,
842
+ f"{n:+d} B={r[2]} A={r[3]} W={r[4]}",
843
+ va="center", ha="left", fontsize=6)
844
+ else:
845
+ ax2.axis("off")
846
+
847
+ fig.suptitle(f"Pairwise correctness — A: {ref_label} | B: {cand_label}",
848
+ fontsize=9, fontweight="bold")
849
+ fig.savefig(out_dir / "eval_confusion.png", dpi=160)
850
+ plt.close(fig)
851
+ print(f"[plot] wrote {out_dir / 'eval_confusion.png'}", flush=True)
852
+
853
+ # --- eval_performance.png: model size, compute, throughput, memory ---
854
+ if ref is not None and "perf" in cand and "perf" in ref:
855
+ cp, rp = cand["perf"], ref["perf"]
856
+
857
+ fig, axes = plt.subplots(1, 2, figsize=(8.5, 5), constrained_layout=True)
858
+ ax_abs = axes[0]
859
+ ax_ratio = axes[1]
860
+
861
+ metrics = [
862
+ ("Total params (M)", cp["total_params"]/1e6, rp["total_params"]/1e6, True),
863
+ ("Active params/tok (M)", cp["active_params_per_tok"]/1e6, rp["active_params_per_tok"]/1e6, True),
864
+ ("MoE expert params (M)", cp["moe_total"]/1e6, rp["moe_total"]/1e6, True),
865
+ ("GFLOP/token", cp["gflops_per_tok"], rp["gflops_per_tok"], True),
866
+ ("Weights RAM (MiB)", cp["weight_bytes"]/(1<<20), rp["weight_bytes"]/(1<<20), True),
867
+ ("Peak eval mem (MiB)", cp["peak_eval_mem_bytes"]/(1<<20), rp["peak_eval_mem_bytes"]/(1<<20), True),
868
+ ("Throughput (tok/s)", cand["throughput_tok_s"], ref["throughput_tok_s"], False),
869
+ ]
870
+
871
+ y = np.arange(len(metrics))
872
+ w = 0.38
873
+ cand_v = [m[1] for m in metrics]
874
+ ref_v = [m[2] for m in metrics]
875
+ ax_abs.barh(y - w/2, ref_v, w, label=f"A: {ref_label}", color=C_REF)
876
+ ax_abs.barh(y + w/2, cand_v, w, label=f"B: {cand_label}", color=C_CAND)
877
+ ax_abs.set_yticks(y)
878
+ ax_abs.set_yticklabels([m[0] for m in metrics], fontsize=8)
879
+ ax_abs.invert_yaxis()
880
+ ax_abs.set_xscale("log")
881
+ ax_abs.set_title("Absolute footprint and speed, log scale")
882
+ ax_abs.grid(axis="x", which="both", alpha=0.7)
883
+ positive_vals = [v for v in (ref_v + cand_v) if v > 0]
884
+ if positive_vals:
885
+ ax_abs.set_xlim(min(positive_vals) * 0.55, max(positive_vals) * 3.8)
886
+ for yi, vals in enumerate(zip(ref_v, cand_v)):
887
+ for off, v, col in [(-w/2, vals[0], C_REF), (w/2, vals[1], C_CAND)]:
888
+ if v <= 0:
889
+ continue
890
+ ax_abs.text(v * 1.05, yi + off, _value_text(v), va="center", fontsize=6, color=col)
891
+
892
+ ratios = [(m[1] / max(1e-12, m[2])) for m in metrics]
893
+ lower_better = [m[3] for m in metrics]
894
+ colors = [
895
+ C_GOOD if ((lb and r <= 1.0) or ((not lb) and r >= 1.0)) else C_BAD
896
+ for r, lb in zip(ratios, lower_better)
897
+ ]
898
+ ax_ratio.axvline(1.0, color="#0f172a", linestyle="--", linewidth=0.9, alpha=0.65)
899
+ ax_ratio.barh(y, ratios, color=colors)
900
+ ax_ratio.set_yticks(y)
901
+ ax_ratio.set_yticklabels([m[0] for m in metrics], fontsize=8)
902
+ ax_ratio.invert_yaxis()
903
+ ax_ratio.set_xscale("log")
904
+ ax_ratio.set_title("B / A ratio with explicit direction")
905
+ ax_ratio.grid(axis="x", which="both", alpha=0.7)
906
+ positive_ratios = [r for r in ratios if r > 0]
907
+ if positive_ratios:
908
+ ax_ratio.set_xlim(min(positive_ratios) * 0.38, max(positive_ratios) * 2.8)
909
+ for i, (r, lb) in enumerate(zip(ratios, lower_better)):
910
+ ax_ratio.text(r * (1.05 if r >= 1 else 0.95), i, _ratio_text(r, lb),
911
+ va="center", ha="left" if r >= 1 else "right", fontsize=6)
912
+
913
+ fig.suptitle(f"Performance profile — A: {ref_label} | B: {cand_label}",
914
+ fontsize=9, fontweight="bold")
915
+ fig.savefig(out_dir / "eval_performance.png", dpi=160)
916
+ plt.close(fig)
917
+ print(f"[plot] wrote {out_dir / 'eval_performance.png'}", flush=True)
918
+
919
+
920
+ # ---------------------------------------------------------------------------
921
+ # Reporting
922
+ # ---------------------------------------------------------------------------
923
+
924
+ def _fmt_metrics(m):
925
+ return (f"span_F1={m['span_f1']:.4f} P={m['span_precision']:.4f} "
926
+ f"R={m['span_recall']:.4f} token_acc={m['token_acc']:.4f} "
927
+ f"non_o_recall={m['non_o_recall']:.4f} "
928
+ f"spans={m['n_gold_spans']}/{m['n_pred_spans']}/{m['n_correct_spans']}")
929
+
930
+
931
+ def _write_compare_log(path: Path, cand, ref, pair, args):
932
+ lines: List[str] = []
933
+ A = lines.append
934
+ A(f"Benchmark: A: {ref['label']} vs B: {cand['label']}"
935
+ if ref else f"Benchmark: {cand['label']}")
936
+ A(f"Dataset: {SOURCE_DATASET}, split=test, eval_pct={args.eval_pct}, "
937
+ f"ctx={args.max_length}, seed={args.seed}, n_docs={args.n_docs}")
938
+ A(f"Eval tokens scored: {cand['n_tok']:,}")
939
+ A("")
940
+ A("=== Aggregate ===")
941
+ if ref is not None:
942
+ A(f" A: {ref['label']:<25s} RAW {_fmt_metrics(ref['raw'])}")
943
+ A(f" A: {ref['label']:<25s} VITERBI {_fmt_metrics(ref['viterbi'])}")
944
+ A(f" B: {cand['label']:<25s} RAW {_fmt_metrics(cand['raw'])}")
945
+ A(f" B: {cand['label']:<25s} VITERBI {_fmt_metrics(cand['viterbi'])}")
946
+ if ref is not None:
947
+ d_f1 = ref["viterbi"]["span_f1"] - cand["viterbi"]["span_f1"]
948
+ A("")
949
+ A(f"Gap B vs A (viterbi span_F1): {-d_f1:+.4f}")
950
+ A("")
951
+ if ref is not None:
952
+ A(f"Throughput: A: {ref['label']} {ref['throughput_tok_s']:.0f} tok/s "
953
+ f"({ref['n_total_M']:.2f}M params)")
954
+ A(f" B: {cand['label']} {cand['throughput_tok_s']:.0f} tok/s "
955
+ f"({cand['n_total_M']:.2f}M params)")
956
+ else:
957
+ A(f"Throughput: {cand['label']} {cand['throughput_tok_s']:.0f} tok/s "
958
+ f"({cand['n_total_M']:.2f}M params)")
959
+ A("")
960
+ # ---- Performance summary table ----
961
+ # Column order: A (ref / teacher) first, then B (cand / student).
962
+ # The "B vs A" column is written in human-readable direction:
963
+ # - When B is smaller (size/compute/mem): "X.XX× smaller" / "X.XX× cheaper" / "X.XX× less".
964
+ # - When B is larger but lower-is-better: "X.XX× larger" / "X.XX× more".
965
+ # - When B is faster (throughput, higher-is-better): "X.XX× faster".
966
+ # - When B is slower: "X.XX× slower".
967
+ # Always uses the magnitude in the dominant direction so the reader
968
+ # doesn't need to mentally invert 0.21× into 4.87×.
969
+ def _fmt_vs(b: float, a: float, kind: str) -> str:
970
+ if a is None or a == 0 or b is None or b == 0:
971
+ return ""
972
+ ratio = b / a
973
+ if kind == "size": # lower is better; phrase as "smaller" or "larger"
974
+ if ratio <= 1.0:
975
+ return f"{1.0/ratio:.2f}× smaller"
976
+ return f"{ratio:.2f}× larger"
977
+ if kind == "compute":
978
+ if ratio <= 1.0:
979
+ return f"{1.0/ratio:.2f}× cheaper"
980
+ return f"{ratio:.2f}× more"
981
+ if kind == "memory":
982
+ if ratio <= 1.0:
983
+ return f"{1.0/ratio:.2f}× less"
984
+ return f"{ratio:.2f}× more"
985
+ if kind == "speed": # higher is better
986
+ if ratio >= 1.0:
987
+ return f"{ratio:.2f}× faster"
988
+ return f"{1.0/ratio:.2f}× slower"
989
+ return f"{ratio:.2f}×"
990
+
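+ # e.g. with the compare.log numbers below: _fmt_vs(287.11, 1399.61, "size")
+ # -> "4.87× smaller" and _fmt_vs(6343, 3293, "speed") -> "1.93× faster".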
991
+ cp = cand["perf"]
992
+ A("=== Performance ===")
993
+ headers = ["metric",
994
+ f"A: {ref['label']}" if ref else "",
995
+ f"B: {cand['label']}",
996
+ "B vs A"]
997
+ rp = ref["perf"] if ref else None
998
+ rows = [
999
+ ["total params (M)",
1000
+ (f"{rp['total_params']/1e6:.2f}" if ref else ""),
1001
+ f"{cp['total_params']/1e6:.2f}",
1002
+ (_fmt_vs(cp['total_params'], rp['total_params'], "size") if ref else "")],
1003
+ ["dense params (M)",
1004
+ (f"{rp['other_total']/1e6:.2f}" if ref else ""),
1005
+ f"{cp['other_total']/1e6:.2f}",
1006
+ (_fmt_vs(cp['other_total'], rp['other_total'], "size") if ref else "")],
1007
+ ["MoE expert params (M)",
1008
+ (f"{rp['moe_total']/1e6:.2f}" if ref else ""),
1009
+ f"{cp['moe_total']/1e6:.2f}",
1010
+ (_fmt_vs(cp['moe_total'], rp['moe_total'], "size") if ref else "")],
1011
+ [f"active params/token (M, mem)",
1012
+ (f"{rp['active_params_per_tok']/1e6:.2f}" if ref else ""),
1013
+ f"{cp['active_params_per_tok']/1e6:.2f}",
1014
+ (_fmt_vs(cp['active_params_per_tok'], rp['active_params_per_tok'], "memory") if ref else "")],
1015
+ [f"compute params/token (M, FLOPs)",
1016
+ (f"{rp['compute_params_per_tok']/1e6:.2f}" if ref else ""),
1017
+ f"{cp['compute_params_per_tok']/1e6:.2f}",
1018
+ (_fmt_vs(cp['compute_params_per_tok'], rp['compute_params_per_tok'], "compute") if ref else "")],
1019
+ ["GFLOP / token",
1020
+ (f"{rp['gflops_per_tok']:.4f}" if ref else ""),
1021
+ f"{cp['gflops_per_tok']:.4f}",
1022
+ (_fmt_vs(cp['gflops_per_tok'], rp['gflops_per_tok'], "compute") if ref else "")],
1023
+ ["disk size (MiB)",
1024
+ (f"{rp['disk_size_bytes']/(1<<20):.1f}" if ref and rp['disk_size_bytes'] else ""),
1025
+ f"{cp['disk_size_bytes']/(1<<20):.1f}",
1026
+ (_fmt_vs(cp['disk_size_bytes'], rp['disk_size_bytes'], "size") if ref and rp['disk_size_bytes'] else "")],
1027
+ ["weights in RAM (MiB)",
1028
+ (f"{rp['weight_bytes']/(1<<20):.1f}" if ref else ""),
1029
+ f"{cp['weight_bytes']/(1<<20):.1f}",
1030
+ (_fmt_vs(cp['weight_bytes'], rp['weight_bytes'], "size") if ref else "")],
1031
+ ["peak GPU mem eval (MiB)",
1032
+ (f"{rp['peak_eval_mem_bytes']/(1<<20):.1f}" if ref else ""),
1033
+ f"{cp['peak_eval_mem_bytes']/(1<<20):.1f}",
1034
+ (_fmt_vs(cp['peak_eval_mem_bytes'], rp['peak_eval_mem_bytes'], "memory") if ref else "")],
1035
+ ["throughput (tok/s)",
1036
+ (f"{ref['throughput_tok_s']:.0f}" if ref else ""),
1037
+ f"{cand['throughput_tok_s']:.0f}",
1038
+ (_fmt_vs(cand['throughput_tok_s'], ref['throughput_tok_s'], "speed") if ref else "")],
1039
+ ]
1040
+ widths = [max(len(r[i]) for r in [headers] + rows) for i in range(4)]
1041
+ sep = " " + " ".join("-" * w for w in widths)
1042
+ A(" " + " ".join(h.ljust(widths[i]) for i, h in enumerate(headers)))
1043
+ A(sep)
1044
+ for r in rows:
1045
+ A(" " + " ".join(r[i].ljust(widths[i]) for i in range(4)))
1046
+ A("")
1047
+ if pair is not None:
1048
+ # Display labels: A = ref (teacher), B = cand (student).
1049
+ a_lbl = ref["label"] if ref is not None else "A"
1050
+ b_lbl = cand["label"]
1051
+ A(f"=== Pairwise (viterbi, all gold tokens) — A: {a_lbl} vs B: {b_lbl} ===")
1052
+ total = (pair["both_correct"] + pair["only_cand_correct"]
1053
+ + pair["only_ref_correct"] + pair["both_wrong"])
1054
+ # Display order: agreement, A-only, B-only, both-wrong.
1055
+ rows_all = [
1056
+ ("both_correct", pair["both_correct"]),
1057
+ ("only_A_correct", pair["only_ref_correct"]),
1058
+ ("only_B_correct", pair["only_cand_correct"]),
1059
+ ("both_wrong", pair["both_wrong"]),
1060
+ ]
1061
+ for k, v in rows_all:
1062
+ A(f" {k:<26s} {v:8d} ({100.0*v/total:.2f}%)")
1063
+ A("")
1064
+ A(f"=== Pairwise (viterbi, gold non-O tokens) — A: {a_lbl} vs B: {b_lbl} ===")
1065
+ non_o = {k: 0 for k in ["both_correct", "only_cand", "only_ref", "both_wrong"]}
1066
+ for cat, d in pair["by_cat"].items():
1067
+ if cat == "O":
1068
+ continue
1069
+ for k in non_o:
1070
+ non_o[k] += d[k]
1071
+ total_non_o = sum(non_o.values())
1072
+ rows_non_o = [
1073
+ ("both_correct", non_o["both_correct"]),
1074
+ ("only_A_correct", non_o["only_ref"]),
1075
+ ("only_B_correct", non_o["only_cand"]),
1076
+ ("both_wrong", non_o["both_wrong"]),
1077
+ ]
1078
+ for k, v in rows_non_o:
1079
+ A(f" {k:<26s} {v:8d} ({100.0*v/total_non_o:.2f}%)" if total_non_o else f" {k}: 0")
1080
+ A("")
1081
+ # Per-cat net B-wins. Net = (only_B) - (only_A) = (only_cand) - (only_ref).
1082
+ # Negative net = A (teacher) wins more in this category.
1083
+ nets = []
1084
+ for cat, d in pair["by_cat"].items():
1085
+ if cat == "O":
1086
+ continue
1087
+ nets.append((cat, d["only_cand"] - d["only_ref"],
1088
+ d["only_cand"], d["only_ref"], d["both_wrong"]))
1089
+ nets.sort(key=lambda x: x[1])
1090
+ A(f"=== Worst B-net wins by gold category — A: {a_lbl} ahead (top 15) ===")
1091
+ for cat, net, ob, oa, bw in nets[:15]:
1092
+ A(f" {cat:<32s} net_B={net:+5d} A_only={oa:4d} B_only={ob:4d} both_wrong={bw:4d}")
1093
+ A("")
1094
+ A(f"=== Best B-net wins by gold category — B: {b_lbl} ahead (top 15) ===")
1095
+ for cat, net, ob, oa, bw in nets[::-1][:15]:
1096
+ A(f" {cat:<32s} net_B={net:+5d} A_only={oa:4d} B_only={ob:4d} both_wrong={bw:4d}")
1097
+ A("")
1098
+ A("=== Per-category span F1 (viterbi) ===")
1099
+ if ref is not None:
1100
+ A(f" -- A: {ref['label']} --")
1101
+ per_r = ref["viterbi"]["span_per_cat"]
1102
+ for cat in sorted(per_r):
1103
+ c = per_r[cat]
1104
+ A(f" {cat:<32s} F1={c['f1']:.4f} P={c['precision']:.4f} R={c['recall']:.4f} "
1105
+ f"({c['n_gold']}/{c['n_pred']}/{c['n_correct']})")
1106
+ A(f" -- B: {cand['label']} --")
1107
+ per = cand["viterbi"]["span_per_cat"]
1108
+ for cat in sorted(per):
1109
+ c = per[cat]
1110
+ A(f" {cat:<32s} F1={c['f1']:.4f} P={c['precision']:.4f} R={c['recall']:.4f} "
1111
+ f"({c['n_gold']}/{c['n_pred']}/{c['n_correct']})")
1112
+ path.write_text("\n".join(lines) + "\n")
1113
+ print(f"[log] wrote {path}", flush=True)
1114
+
1115
+
1116
+ def _fmt_bytes(n: int) -> str:
1117
+ if n <= 0:
1118
+ return "—"
1119
+ if n >= 1 << 30:
1120
+ return f"{n / (1 << 30):.2f} GiB"
1121
+ if n >= 1 << 20:
1122
+ return f"{n / (1 << 20):.1f} MiB"
1123
+ return f"{n / (1 << 10):.1f} KiB"
1124
+
1125
+
1126
+ def _perf_block(stream, ctx: int) -> List[str]:
1127
+ p = stream["perf"]
1128
+ out = [
1129
+ f" total params : {p['total_params']/1e6:>9.2f}M "
1130
+ f"({p['other_total']/1e6:.2f}M dense + {p['moe_total']/1e6:.2f}M MoE-experts)",
1131
+ f" active params / token : {p['active_params_per_tok']/1e6:>9.2f}M "
1132
+ f"(memory footprint — embed lookup + top_{p['experts_per_tok']}/{p['num_experts']} experts: "
1133
+ f"{p['embed_total']/1e6:.2f}M embed + "
1134
+ f"{p['moe_active_per_tok']/1e6:.2f}M MoE-active + "
1135
+ f"{(p['other_total']-p['embed_total'])/1e6:.2f}M attn/norm/head)",
1136
+ f" compute params / token : {p['compute_params_per_tok']/1e6:>9.2f}M "
1137
+ f"(matmul FLOPs only — embedding lookup excluded)",
1138
+ f" GFLOP / token (fwd, MAC×2): {p['gflops_per_tok']:>9.3f}",
1139
+ f" weights size (on disk) : {_fmt_bytes(p['disk_size_bytes']):>9s}",
1140
+ f" weights size (in RAM) : {_fmt_bytes(p['weight_bytes']):>9s}",
1141
+ f" weights resident (GPU) : {_fmt_bytes(p['weights_resident_bytes']):>9s}",
1142
+ f" peak GPU mem (eval, ctx={ctx}) : {_fmt_bytes(p['peak_eval_mem_bytes']):>9s}",
1143
+ ]
1144
+ return out
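+ # Note the MAC×2 convention: gflops_per_tok ≈ compute_params_per_tok × 2 / 1e9,
+ # consistent with compare.log (6.46M compute params/token → 0.0129 GFLOP/token).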
1145
+
1146
+
1147
+ def _write_infer_log(path: Path, cand, ref, args, sample_text: str, tokenizer, device, dtype):
1148
+ """Single-doc inference example + timing + performance metrics."""
1150
+
1151
+ lines: List[str] = []
1152
+ A = lines.append
1153
+ A(f"Inference benchmark: A: {ref['label']} vs B: {cand['label']}"
1154
+ if ref else f"Inference benchmark: {cand['label']}")
1155
+ A(f" device : {device} dtype: {dtype}")
1156
+ A(f" ctx : {args.max_length}")
1157
+ A("")
1158
+ if ref is not None:
1159
+ A(f"A: {ref['label']} (reference / teacher)")
1160
+ A(f" load : {ref['load_s']:.2f}s")
1161
+ A(f" eval : {ref['eval_s']:.2f}s on {ref['n_tok']:,} tokens "
1162
+ f"({ref['throughput_tok_s']:.0f} tok/s)")
1163
+ A("Performance:")
1164
+ for ln in _perf_block(ref, args.max_length):
1165
+ A(ln)
1166
+ A("")
1167
+ A(f"B: {cand['label']}" + (" (this checkpoint)" if ref else ""))
1168
+ A(f" load : {cand['load_s']:.2f}s")
1169
+ A(f" eval : {cand['eval_s']:.2f}s on {cand['n_tok']:,} tokens "
1170
+ f"({cand['throughput_tok_s']:.0f} tok/s)")
1171
+ A("Performance:")
1172
+ for ln in _perf_block(cand, args.max_length):
1173
+ A(ln)
1174
+ A("")
1175
+ if ref is not None:
1176
+ cp, rp = cand["perf"], ref["perf"]
1177
+
1178
+ def _fmt(b, a, kind):
1179
+ if a is None or a == 0 or b is None or b == 0:
1180
+ return "—"
1181
+ r = b / a
1182
+ if kind == "size":
1183
+ return f"{1.0/r:.2f}× smaller" if r <= 1.0 else f"{r:.2f}× larger"
1184
+ if kind == "compute":
1185
+ return f"{1.0/r:.2f}× cheaper" if r <= 1.0 else f"{r:.2f}× more"
1186
+ if kind == "memory":
1187
+ return f"{1.0/r:.2f}× less" if r <= 1.0 else f"{r:.2f}× more"
1188
+ if kind == "speed":
1189
+ return f"{r:.2f}× faster" if r >= 1.0 else f"{1.0/r:.2f}× slower"
1190
+ return f"{r:.2f}×"
1191
+
1192
+ A(f"B vs A ({cand['label']} vs {ref['label']}):")
1193
+ A(f" total params : {_fmt(cp['total_params'], rp['total_params'], 'size')}")
1194
+ A(f" active params / token : {_fmt(cp['active_params_per_tok'], rp['active_params_per_tok'], 'memory')} [memory]")
1195
+ A(f" compute params / token : {_fmt(cp['compute_params_per_tok'], rp['compute_params_per_tok'], 'compute')} [FLOPs]")
1196
+ A(f" GFLOP / token : {_fmt(cp['gflops_per_tok'], rp['gflops_per_tok'], 'compute')}")
1197
+ if rp['disk_size_bytes']:
1198
+ A(f" weights size (on disk) : {_fmt(cp['disk_size_bytes'], rp['disk_size_bytes'], 'size')}")
1199
+ else:
1200
+ A(f" weights size (on disk) : —")
1201
+ A(f" weights in RAM : {_fmt(cp['weight_bytes'], rp['weight_bytes'], 'size')}")
1202
+ A(f" peak GPU mem (eval) : {_fmt(cp['peak_eval_mem_bytes'], rp['peak_eval_mem_bytes'], 'memory')}")
1203
+ A(f" throughput : {_fmt(cand['throughput_tok_s'], ref['throughput_tok_s'], 'speed')}")
1204
+ A("")
1205
+
1206
+ A("Sample inference (load → tokenize → forward → viterbi-decode → spans):")
1207
+ A(f" text: {sample_text!r}")
1208
+ model = AutoModelForTokenClassification.from_pretrained(
1209
+ ".", dtype=dtype, trust_remote_code=True,
1210
+ ).to(device).eval()
1211
+ if hasattr(model.config, "viterbi_replace_logits"):
1212
+ model.config.viterbi_replace_logits = True
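+ # With viterbi_replace_logits=True the returned logits are BIOES-constrained,
+ # so the plain argmax below already yields span-coherent predictions.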
1213
+ enc = tokenizer(sample_text, return_tensors="pt", truncation=True,
1214
+ max_length=args.max_length).to(device)
1215
+ with torch.no_grad():
1216
+ if torch.cuda.is_available():
1217
+ torch.cuda.synchronize()
1218
+ t0 = time.time()
1219
+ out = model(**enc)
1220
+ if torch.cuda.is_available():
1221
+ torch.cuda.synchronize()
1222
+ dt = time.time() - t0
1223
+ label2id, id2label = nemotron_native_label_space()
1224
+ pred = out.logits.argmax(-1)[0].cpu().tolist()
1225
+ spans = _bioes_to_spans(pred, id2label, 0)
1226
+ A(f" forward latency: {dt*1000:.1f}ms ({enc.input_ids.shape[1]} tokens)")
1227
+ A(f" detected {len(spans)} spans:")
1228
+ tok_ids = enc.input_ids[0].cpu().tolist()
1229
+ for s, e, cat in sorted(spans):
1230
+ text = tokenizer.decode(tok_ids[s:e]).strip()
1231
+ A(f" [{s:3d}, {e:3d}) {cat:<28s} {text!r}")
1232
+ del model
1233
+ if torch.cuda.is_available():
1234
+ torch.cuda.empty_cache()
1235
+ path.write_text("\n".join(lines) + "\n")
1236
+ print(f"[log] wrote {path}", flush=True)
1237
+
1238
+
1239
+ # ---------------------------------------------------------------------------
1240
+ # Main
1241
+ # ---------------------------------------------------------------------------
1242
+
1243
+ def main():
1244
+ p = argparse.ArgumentParser(description="Benchmark haremb-privacy-filter-opennemo")
1245
+ p.add_argument("--device", default=None,
1246
+ help="cuda or cpu. Default: cuda if available.")
1247
+ p.add_argument("--dtype", default="bfloat16",
1248
+ choices=["bfloat16", "float16", "float32"])
1249
+ p.add_argument("--eval-pct", type=float, default=1.0,
1250
+ help="Percent of nvidia/Nemotron-PII test split to use. Default 1%%.")
1251
+ p.add_argument("--eval-chunk-size", type=int, default=10_000)
1252
+ p.add_argument("--seed", type=int, default=42)
1253
+ p.add_argument("--max-length", type=int, default=1024)
1254
+ p.add_argument("--batch-size", type=int, default=4)
1255
+ p.add_argument("--out", type=str, default=".")
1256
+ p.add_argument("--model-path", type=str, default=".",
1257
+ help="Path to this checkpoint. Default: ./ (this folder).")
1258
+ p.add_argument("--no-base", action="store_true",
1259
+ help="Skip the OpenMed teacher comparison.")
1260
+ p.add_argument("--no-plots", action="store_true",
1261
+ help="Skip rendering eval_summary.png / eval_confusion.png.")
1262
+ args = p.parse_args()
1263
+
1264
+ out_dir = Path(args.out).resolve()
1265
+ out_dir.mkdir(parents=True, exist_ok=True)
1266
+
1267
+ if args.device is None:
1268
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1269
+ else:
1270
+ device = torch.device(args.device)
1271
+ dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16,
1272
+ "float32": torch.float32}[args.dtype]
1273
+
1274
+ print(f"[setup] device={device} dtype={dtype} model={args.model_path} "
1275
+ f"out={out_dir}", flush=True)
1276
+
1277
+ label2id, id2label = nemotron_native_label_space()
1278
+ o_id = label2id["O"]
1279
+
1280
+ tokenizer = AutoTokenizer.from_pretrained(args.model_path)
1281
+ pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 199999
1282
+
1283
+ # Build the eval set (same slice the README headline numbers reference)
1284
+ print(f"[data] loading {SOURCE_DATASET} ...", flush=True)
1285
+ ds = load_dataset(SOURCE_DATASET)
1286
+ target_eval = max(1, int(round(len(ds["test"]) * args.eval_pct / 100.0)))
1287
+ eval_indices = _build_eval_streaming(
1288
+ ds["test"], target_n=target_eval,
1289
+ chunk_size=args.eval_chunk_size, seed=args.seed,
1290
+ )
1291
+ eval_ds = _NemotronEvalDataset(
1292
+ ds["test"].select(eval_indices), tokenizer, label2id, args.max_length,
1293
+ )
1294
+ args.n_docs = len(eval_ds)
1295
+ print(f"[data] eval={args.n_docs:,} docs ({args.eval_pct:.2f}% of test split)",
1296
+ flush=True)
1297
+
1298
+ # BIOES decoding masks (used for the explicit RAW vs VITERBI streams).
1299
+ from modeling_haremb_pii import (
1300
+ _build_bioes_initial_mask as _bld_init,
1301
+ _build_bioes_transition_mask as _bld_trans,
1302
+ )
1303
+ bioes_trans = _bld_trans(id2label).to(device).float()
1304
+ bioes_init = _bld_init(id2label).to(device).float()
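+ # Standard BIOES constraints, which is what these masks encode: a sequence may
+ # start only with O, B-*, or S-*, and B-x/I-x may only continue as I-x or E-x.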
1305
+
1306
+ # Eval candidate
1307
+ cand = _eval_one_model(
1308
+ args.model_path, tokenizer, eval_ds, label2id, id2label, o_id,
1309
+ bioes_trans, bioes_init,
1310
+ args.batch_size, args.max_length, device, dtype,
1311
+ label="haremb",
1312
+ )
1313
+
1314
+ print(f"\n=== {cand['label']} ===")
1315
+ print(f"RAW {_fmt_metrics(cand['raw'])}")
1316
+ print(f"VITERBI {_fmt_metrics(cand['viterbi'])}")
1317
+ print(f"DELTA span_F1={cand['viterbi']['span_f1']-cand['raw']['span_f1']:+.4f} "
1318
+ f"P={cand['viterbi']['span_precision']-cand['raw']['span_precision']:+.4f} "
1319
+ f"R={cand['viterbi']['span_recall']-cand['raw']['span_recall']:+.4f}")
1320
+
1321
+ ref = None
1322
+ pair = None
1323
+ if not args.no_base:
1324
+ ref = _eval_one_model(
1325
+ TEACHER, tokenizer, eval_ds, label2id, id2label, o_id,
1326
+ bioes_trans, bioes_init,
1327
+ args.batch_size, args.max_length, device, dtype,
1328
+ label="openmed-base",
1329
+ )
1330
+ print(f"\n=== {ref['label']} (teacher) ===")
1331
+ print(f"VITERBI {_fmt_metrics(ref['viterbi'])}")
1332
+ pair = _pairwise(cand["docs"], ref["docs"], id2label, o_id)
1333
+ d = ref["viterbi"]["span_f1"] - cand["viterbi"]["span_f1"]
1334
+ print(f"\nGap to teacher (viterbi span_F1): {d:+.4f}")
1335
+
1336
+ # Reports
1337
+ sample_text = ("Patient Sarah Johnson (DOB 03/15/1985), MRN 4872910, "
1338
+ "phone 415-555-0123, email sarah.johnson@example.com, "
1339
+ "credit card 4111-1111-1111-1111.")
1340
+ _write_infer_log(out_dir / "infer.log", cand, ref, args, sample_text,
1341
+ tokenizer, device, dtype)
1342
+ _write_compare_log(out_dir / "compare.log", cand, ref, pair, args)
1343
+ if not args.no_plots:
1344
+ _render_plots(cand, ref, pair, out_dir,
1345
+ cand_label=cand["label"],
1346
+ ref_label=ref["label"] if ref else "")
1347
+ print("\n[done]")
1348
+
1349
+
1350
+ if __name__ == "__main__":
1351
+ main()
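+ # Example invocation (script filename assumed), run from the checkpoint folder:
+ #   python benchmark.py --eval-pct 1.0 --batch-size 4 --out .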
compare.log ADDED
@@ -0,0 +1,188 @@
1
+ Benchmark: A: openmed-base vs B: haremb
2
+ Dataset: nvidia/Nemotron-PII, split=test, eval_pct=1.0, ctx=1024, seed=42, n_docs=1000
3
+ Eval tokens scored: 212,909
4
+
5
+ === Aggregate ===
6
+ A: openmed-base RAW span_F1=0.9174 P=0.9125 R=0.9223 token_acc=0.9895 non_o_recall=0.9685 spans=8627/8720/7957
7
+ A: openmed-base VITERBI span_F1=0.9434 P=0.9531 R=0.9338 token_acc=0.9900 non_o_recall=0.9703 spans=8627/8452/8056
8
+ B: haremb RAW span_F1=0.7741 P=0.7186 R=0.8388 token_acc=0.9831 non_o_recall=0.9467 spans=8627/10069/7236
9
+ B: haremb VITERBI span_F1=0.9288 P=0.9396 R=0.9182 token_acc=0.9885 non_o_recall=0.9637 spans=8627/8430/7921
10
+
11
+ Gap B vs A (viterbi span_F1): -0.0146
12
+
13
+ Throughput: A: openmed-base 3293 tok/s (1399.61M params)
14
+ B: haremb 6343 tok/s (287.11M params)
15
+
16
+ === Performance ===
17
+ metric                          A: openmed-base B: haremb B vs A
18
+ ------------------------------- --------------- --------- -------------
19
+ total params (M)                1399.61         287.11    4.87× smaller
20
+ dense params (M)                139.35          129.58    1.08× smaller
21
+ MoE expert params (M)           1260.26         157.53    8.00× smaller
22
+ active params/token (M, mem)    178.73          134.50    1.33× less
23
+ compute params/token (M, FLOPs) 50.69           6.46      7.85× cheaper
24
+ GFLOP / token                   0.1014          0.0129    7.85× cheaper
25
+ disk size (MiB)                                 547.6
26
+ weights in RAM (MiB)            2669.5          547.6     4.87× smaller
27
+ peak GPU mem eval (MiB)         3376.2          1248.6    2.70× less
28
+ throughput (tok/s)              3293            6343      1.93× faster
29
+
30
+ === Pairwise (viterbi, all gold tokens) — A: openmed-base vs B: haremb ===
31
+ both_correct 209830 (98.55%)
32
+ only_A_correct 958 (0.45%)
33
+ only_B_correct 633 (0.30%)
34
+ both_wrong 1488 (0.70%)
35
+
36
+ === Pairwise (viterbi, gold non-O tokens) — A: openmed-base vs B: haremb ===
37
+ both_correct 43902 (95.47%)
38
+ only_A_correct 717 (1.56%)
39
+ only_B_correct 417 (0.91%)
40
+ both_wrong 951 (2.07%)
41
+
42
+ === Worst B-net wins by gold category — A: openmed-base ahead (top 15) ===
43
+ company_name net_B= -142 A_only= 216 B_only= 74 both_wrong= 119
44
+ first_name net_B= -75 A_only= 82 B_only= 7 both_wrong= 19
45
+ last_name net_B= -65 A_only= 67 B_only= 2 both_wrong= 38
46
+ occupation net_B= -55 A_only= 79 B_only= 24 both_wrong= 286
47
+ device_identifier net_B= -29 A_only= 29 B_only= 0 both_wrong= 0
48
+ user_name net_B= -26 A_only= 29 B_only= 3 both_wrong= 10
49
+ city net_B= -16 A_only= 32 B_only= 16 both_wrong= 36
50
+ street_address net_B= -13 A_only= 14 B_only= 1 both_wrong= 0
51
+ date_of_birth net_B= -12 A_only= 12 B_only= 0 both_wrong= 0
52
+ email net_B= -8 A_only= 9 B_only= 1 both_wrong= 0
53
+ medical_record_number net_B= -7 A_only= 7 B_only= 0 both_wrong= 0
54
+ phone_number net_B= -6 A_only= 6 B_only= 0 both_wrong= 0
55
+ account_number net_B= -6 A_only= 10 B_only= 4 both_wrong= 0
56
+ tax_id net_B= -6 A_only= 6 B_only= 0 both_wrong= 0
57
+ race_ethnicity net_B= -5 A_only= 15 B_only= 10 both_wrong= 11
58
+
59
+ === Best B-net wins by gold category — B: haremb ahead (top 15) ===
60
+ date net_B= +46 A_only= 15 B_only= 61 both_wrong= 145
61
+ fax_number net_B= +29 A_only= 1 B_only= 30 both_wrong= 44
62
+ unique_id net_B= +26 A_only= 4 B_only= 30 both_wrong= 0
63
+ ssn net_B= +18 A_only= 0 B_only= 18 both_wrong= 0
64
+ time net_B= +12 A_only= 9 B_only= 21 both_wrong= 87
65
+ political_view net_B= +11 A_only= 3 B_only= 14 both_wrong= 6
66
+ coordinate net_B= +11 A_only= 0 B_only= 11 both_wrong= 0
67
+ customer_id net_B= +8 A_only= 4 B_only= 12 both_wrong= 7
68
+ certificate_license_number net_B= +7 A_only= 0 B_only= 7 both_wrong= 0
69
+ education_level net_B= +6 A_only= 2 B_only= 8 both_wrong= 28
70
+ state net_B= +6 A_only= 13 B_only= 19 both_wrong= 20
71
+ blood_type net_B= +6 A_only= 0 B_only= 6 both_wrong= 0
72
+ gender net_B= +2 A_only= 0 B_only= 2 both_wrong= 2
73
+ http_cookie net_B= +2 A_only= 0 B_only= 2 both_wrong= 3
74
+ country net_B= +2 A_only= 0 B_only= 2 both_wrong= 19
75
+
76
+ === Per-category span F1 (viterbi) ===
77
+ -- A: openmed-base --
78
+ account_number F1=0.9929 P=0.9929 R=0.9929 (140/140/139)
79
+ age F1=0.8840 P=0.8511 R=0.9195 (87/94/80)
80
+ api_key F1=0.9921 P=0.9844 R=1.0000 (63/64/63)
81
+ bank_routing_number F1=0.9867 P=0.9867 R=0.9867 (75/75/74)
82
+ biometric_identifier F1=1.0000 P=1.0000 R=1.0000 (113/113/113)
83
+ blood_type F1=0.9032 P=0.9032 R=0.9032 (62/62/56)
84
+ certificate_license_number F1=0.9697 P=1.0000 R=0.9412 (34/32/32)
85
+ city F1=0.9154 P=0.9583 R=0.8762 (210/192/184)
86
+ company_name F1=0.8824 P=0.9143 R=0.8526 (563/525/480)
87
+ coordinate F1=0.8000 P=0.8000 R=0.8000 (55/55/44)
88
+ country F1=0.9431 P=0.9324 R=0.9539 (217/222/207)
89
+ county F1=0.9519 P=0.9612 R=0.9429 (105/103/99)
90
+ credit_debit_card F1=0.9967 P=0.9934 R=1.0000 (150/151/150)
91
+ customer_id F1=0.9849 P=1.0000 R=0.9703 (202/196/196)
92
+ cvv F1=0.9787 P=1.0000 R=0.9583 (48/46/46)
93
+ date F1=0.9440 P=0.9571 R=0.9312 (814/792/758)
94
+ date_of_birth F1=1.0000 P=1.0000 R=1.0000 (164/164/164)
95
+ date_time F1=0.9635 P=0.9429 R=0.9851 (134/140/132)
96
+ device_identifier F1=0.9714 P=0.9444 R=1.0000 (51/54/51)
97
+ education_level F1=0.9091 P=0.9524 R=0.8696 (92/84/80)
98
+ email F1=0.9971 P=0.9961 R=0.9980 (511/512/510)
99
+ employee_id F1=0.9948 P=1.0000 R=0.9896 (96/95/95)
100
+ employment_status F1=0.9478 P=0.9593 R=0.9365 (126/123/118)
101
+ fax_number F1=0.9091 P=0.9848 R=0.8442 (77/66/65)
102
+ first_name F1=0.9766 P=0.9716 R=0.9816 (871/880/855)
103
+ gender F1=0.9737 P=0.9867 R=0.9610 (77/75/74)
104
+ health_plan_beneficiary_number F1=1.0000 P=1.0000 R=1.0000 (103/103/103)
105
+ http_cookie F1=0.9307 P=0.9400 R=0.9216 (51/50/47)
106
+ ipv4 F1=1.0000 P=1.0000 R=1.0000 (59/59/59)
107
+ ipv6 F1=1.0000 P=1.0000 R=1.0000 (21/21/21)
108
+ language F1=0.9000 P=0.9000 R=0.9000 (90/90/81)
109
+ last_name F1=0.9744 P=0.9767 R=0.9721 (646/643/628)
110
+ license_plate F1=1.0000 P=1.0000 R=1.0000 (55/55/55)
111
+ mac_address F1=1.0000 P=1.0000 R=1.0000 (30/30/30)
112
+ medical_record_number F1=1.0000 P=1.0000 R=1.0000 (103/103/103)
113
+ national_id F1=1.0000 P=1.0000 R=1.0000 (28/28/28)
114
+ occupation F1=0.6522 P=0.7721 R=0.5645 (372/272/210)
115
+ password F1=0.9217 P=0.9636 R=0.8833 (60/55/53)
116
+ phone_number F1=0.9751 P=0.9514 R=1.0000 (235/247/235)
117
+ pin F1=0.9302 P=0.8955 R=0.9677 (62/67/60)
118
+ political_view F1=0.8387 P=0.8125 R=0.8667 (45/48/39)
119
+ postcode F1=0.9934 P=0.9868 R=1.0000 (75/76/75)
120
+ race_ethnicity F1=0.8889 P=0.8889 R=0.8889 (81/81/72)
121
+ religious_belief F1=0.8936 P=0.8750 R=0.9130 (46/48/42)
122
+ sexuality F1=0.9667 P=1.0000 R=0.9355 (31/29/29)
123
+ ssn F1=0.9440 P=0.9365 R=0.9516 (62/63/59)
124
+ state F1=0.9198 P=0.9399 R=0.9005 (191/183/172)
125
+ street_address F1=0.9894 P=0.9842 R=0.9947 (188/190/187)
126
+ swift_bic F1=0.9905 P=0.9811 R=1.0000 (52/53/52)
127
+ tax_id F1=1.0000 P=1.0000 R=1.0000 (15/15/15)
128
+ time F1=0.8209 P=0.8514 R=0.7926 (188/175/149)
129
+ unique_id F1=0.9600 P=1.0000 R=0.9231 (13/12/12)
130
+ url F1=0.9725 P=0.9687 R=0.9763 (380/383/371)
131
+ user_name F1=0.9497 P=0.9264 R=0.9742 (155/163/151)
132
+ vehicle_identifier F1=0.9815 P=0.9636 R=1.0000 (53/55/53)
133
+ -- B: haremb --
134
+ account_number F1=0.9751 P=0.9716 R=0.9786 (140/141/137)
135
+ age F1=0.8571 P=0.8211 R=0.8966 (87/95/78)
136
+ api_key F1=0.9921 P=0.9844 R=1.0000 (63/64/63)
137
+ bank_routing_number F1=0.9933 P=1.0000 R=0.9867 (75/74/74)
138
+ biometric_identifier F1=1.0000 P=1.0000 R=1.0000 (113/113/113)
139
+ blood_type F1=1.0000 P=1.0000 R=1.0000 (62/62/62)
140
+ certificate_license_number F1=0.9855 P=0.9714 R=1.0000 (34/35/34)
141
+ city F1=0.8932 P=0.9109 R=0.8762 (210/202/184)
142
+ company_name F1=0.7766 P=0.8120 R=0.7442 (563/516/419)
143
+ coordinate F1=1.0000 P=1.0000 R=1.0000 (55/55/55)
144
+ country F1=0.9543 P=0.9457 R=0.9631 (217/221/209)
145
+ county F1=0.9340 P=0.9252 R=0.9429 (105/107/99)
146
+ credit_debit_card F1=0.9934 P=0.9868 R=1.0000 (150/152/150)
147
+ customer_id F1=0.9779 P=0.9707 R=0.9851 (202/205/199)
148
+ cvv F1=0.9792 P=0.9792 R=0.9792 (48/48/47)
149
+ date F1=0.9510 P=0.9599 R=0.9423 (814/799/767)
150
+ date_of_birth F1=0.9939 P=1.0000 R=0.9878 (164/162/162)
151
+ date_time F1=0.9635 P=0.9429 R=0.9851 (134/140/132)
152
+ device_identifier F1=0.9515 P=0.9423 R=0.9608 (51/52/49)
153
+ education_level F1=0.9091 P=0.9524 R=0.8696 (92/84/80)
154
+ email F1=0.9912 P=0.9883 R=0.9941 (511/514/508)
155
+ employee_id F1=0.9895 P=1.0000 R=0.9792 (96/94/94)
156
+ employment_status F1=0.9562 P=0.9600 R=0.9524 (126/125/120)
157
+ fax_number F1=0.9396 P=0.9722 R=0.9091 (77/72/70)
158
+ first_name F1=0.9299 P=0.9231 R=0.9369 (871/884/816)
159
+ gender F1=0.9870 P=0.9870 R=0.9870 (77/77/76)
160
+ health_plan_beneficiary_number F1=1.0000 P=1.0000 R=1.0000 (103/103/103)
161
+ http_cookie F1=0.9608 P=0.9608 R=0.9608 (51/51/49)
162
+ ipv4 F1=1.0000 P=1.0000 R=1.0000 (59/59/59)
163
+ ipv6 F1=1.0000 P=1.0000 R=1.0000 (21/21/21)
164
+ language F1=0.8966 P=0.9286 R=0.8667 (90/84/78)
165
+ last_name F1=0.9308 P=0.9457 R=0.9164 (646/626/592)
166
+ license_plate F1=1.0000 P=1.0000 R=1.0000 (55/55/55)
167
+ mac_address F1=1.0000 P=1.0000 R=1.0000 (30/30/30)
168
+ medical_record_number F1=0.9903 P=0.9903 R=0.9903 (103/103/102)
169
+ national_id F1=0.9825 P=0.9655 R=1.0000 (28/29/28)
170
+ occupation F1=0.5981 P=0.7440 R=0.5000 (372/250/186)
171
+ password F1=0.9391 P=0.9818 R=0.9000 (60/55/54)
172
+ phone_number F1=0.9730 P=0.9512 R=0.9957 (235/246/234)
173
+ pin F1=0.9508 P=0.9667 R=0.9355 (62/60/58)
174
+ political_view F1=0.8723 P=0.8367 R=0.9111 (45/49/41)
175
+ postcode F1=0.9934 P=0.9868 R=1.0000 (75/76/75)
176
+ race_ethnicity F1=0.8590 P=0.8933 R=0.8272 (81/75/67)
177
+ religious_belief F1=0.9348 P=0.9348 R=0.9348 (46/46/43)
178
+ sexuality F1=0.9492 P=1.0000 R=0.9032 (31/28/28)
179
+ ssn F1=0.9688 P=0.9394 R=1.0000 (62/66/62)
180
+ state F1=0.9105 P=0.9153 R=0.9058 (191/189/173)
181
+ street_address F1=0.9894 P=0.9894 R=0.9894 (188/188/186)
182
+ swift_bic F1=0.9905 P=0.9811 R=1.0000 (52/53/52)
183
+ tax_id F1=0.9655 P=1.0000 R=0.9333 (15/14/14)
184
+ time F1=0.8421 P=0.8786 R=0.8085 (188/173/152)
185
+ unique_id F1=0.8571 P=0.8000 R=0.9231 (13/15/12)
186
+ url F1=0.9752 P=0.9688 R=0.9816 (380/385/373)
187
+ user_name F1=0.9416 P=0.9477 R=0.9355 (155/153/145)
188
+ vehicle_identifier F1=0.9630 P=0.9455 R=0.9811 (53/55/52)
config.json ADDED
@@ -0,0 +1,788 @@
1
+ {
2
+ "architectures": [
3
+ "HaremPiiForTokenClassification"
4
+ ],
5
+ "attention_bias": true,
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_haremb_pii.HaremPiiConfig",
9
+ "AutoModelForTokenClassification": "modeling_haremb_pii.HaremPiiForTokenClassification"
10
+ },
11
+ "bos_token_id": null,
12
+ "classifier_dropout": 0.0,
13
+ "default_n_ctx": 128000,
14
+ "dtype": "bfloat16",
15
+ "eos_token_id": 199999,
16
+ "head_dim": 64,
17
+ "hidden_act": "silu",
18
+ "hidden_size": 640,
19
+ "id2label": {
20
+ "0": "O",
21
+ "1": "B-account_number",
22
+ "2": "I-account_number",
23
+ "3": "E-account_number",
24
+ "4": "S-account_number",
25
+ "5": "B-age",
26
+ "6": "I-age",
27
+ "7": "E-age",
28
+ "8": "S-age",
29
+ "9": "B-api_key",
30
+ "10": "I-api_key",
31
+ "11": "E-api_key",
32
+ "12": "S-api_key",
33
+ "13": "B-bank_routing_number",
34
+ "14": "I-bank_routing_number",
35
+ "15": "E-bank_routing_number",
36
+ "16": "S-bank_routing_number",
37
+ "17": "B-biometric_identifier",
38
+ "18": "I-biometric_identifier",
39
+ "19": "E-biometric_identifier",
40
+ "20": "S-biometric_identifier",
41
+ "21": "B-blood_type",
42
+ "22": "I-blood_type",
43
+ "23": "E-blood_type",
44
+ "24": "S-blood_type",
45
+ "25": "B-certificate_license_number",
46
+ "26": "I-certificate_license_number",
47
+ "27": "E-certificate_license_number",
48
+ "28": "S-certificate_license_number",
49
+ "29": "B-city",
50
+ "30": "I-city",
51
+ "31": "E-city",
52
+ "32": "S-city",
53
+ "33": "B-company_name",
54
+ "34": "I-company_name",
55
+ "35": "E-company_name",
56
+ "36": "S-company_name",
57
+ "37": "B-coordinate",
58
+ "38": "I-coordinate",
59
+ "39": "E-coordinate",
60
+ "40": "S-coordinate",
61
+ "41": "B-country",
62
+ "42": "I-country",
63
+ "43": "E-country",
64
+ "44": "S-country",
65
+ "45": "B-county",
66
+ "46": "I-county",
67
+ "47": "E-county",
68
+ "48": "S-county",
69
+ "49": "B-credit_debit_card",
70
+ "50": "I-credit_debit_card",
71
+ "51": "E-credit_debit_card",
72
+ "52": "S-credit_debit_card",
73
+ "53": "B-customer_id",
74
+ "54": "I-customer_id",
75
+ "55": "E-customer_id",
76
+ "56": "S-customer_id",
77
+ "57": "B-cvv",
78
+ "58": "I-cvv",
79
+ "59": "E-cvv",
80
+ "60": "S-cvv",
81
+ "61": "B-date",
82
+ "62": "I-date",
83
+ "63": "E-date",
84
+ "64": "S-date",
85
+ "65": "B-date_of_birth",
86
+ "66": "I-date_of_birth",
87
+ "67": "E-date_of_birth",
88
+ "68": "S-date_of_birth",
89
+ "69": "B-date_time",
90
+ "70": "I-date_time",
91
+ "71": "E-date_time",
92
+ "72": "S-date_time",
93
+ "73": "B-device_identifier",
94
+ "74": "I-device_identifier",
95
+ "75": "E-device_identifier",
96
+ "76": "S-device_identifier",
97
+ "77": "B-education_level",
98
+ "78": "I-education_level",
99
+ "79": "E-education_level",
100
+ "80": "S-education_level",
101
+ "81": "B-email",
102
+ "82": "I-email",
103
+ "83": "E-email",
104
+ "84": "S-email",
105
+ "85": "B-employee_id",
106
+ "86": "I-employee_id",
107
+ "87": "E-employee_id",
108
+ "88": "S-employee_id",
109
+ "89": "B-employment_status",
110
+ "90": "I-employment_status",
111
+ "91": "E-employment_status",
112
+ "92": "S-employment_status",
113
+ "93": "B-fax_number",
114
+ "94": "I-fax_number",
115
+ "95": "E-fax_number",
116
+ "96": "S-fax_number",
117
+ "97": "B-first_name",
118
+ "98": "I-first_name",
119
+ "99": "E-first_name",
120
+ "100": "S-first_name",
121
+ "101": "B-gender",
122
+ "102": "I-gender",
123
+ "103": "E-gender",
124
+ "104": "S-gender",
125
+ "105": "B-health_plan_beneficiary_number",
126
+ "106": "I-health_plan_beneficiary_number",
127
+ "107": "E-health_plan_beneficiary_number",
128
+ "108": "S-health_plan_beneficiary_number",
129
+ "109": "B-http_cookie",
130
+ "110": "I-http_cookie",
131
+ "111": "E-http_cookie",
132
+ "112": "S-http_cookie",
133
+ "113": "B-ipv4",
134
+ "114": "I-ipv4",
135
+ "115": "E-ipv4",
136
+ "116": "S-ipv4",
137
+ "117": "B-ipv6",
138
+ "118": "I-ipv6",
139
+ "119": "E-ipv6",
140
+ "120": "S-ipv6",
141
+ "121": "B-language",
142
+ "122": "I-language",
143
+ "123": "E-language",
144
+ "124": "S-language",
145
+ "125": "B-last_name",
146
+ "126": "I-last_name",
147
+ "127": "E-last_name",
148
+ "128": "S-last_name",
149
+ "129": "B-license_plate",
150
+ "130": "I-license_plate",
151
+ "131": "E-license_plate",
152
+ "132": "S-license_plate",
153
+ "133": "B-mac_address",
154
+ "134": "I-mac_address",
155
+ "135": "E-mac_address",
156
+ "136": "S-mac_address",
157
+ "137": "B-medical_record_number",
158
+ "138": "I-medical_record_number",
159
+ "139": "E-medical_record_number",
160
+ "140": "S-medical_record_number",
161
+ "141": "B-national_id",
162
+ "142": "I-national_id",
163
+ "143": "E-national_id",
164
+ "144": "S-national_id",
165
+ "145": "B-occupation",
166
+ "146": "I-occupation",
167
+ "147": "E-occupation",
168
+ "148": "S-occupation",
169
+ "149": "B-password",
170
+ "150": "I-password",
171
+ "151": "E-password",
172
+ "152": "S-password",
173
+ "153": "B-phone_number",
174
+ "154": "I-phone_number",
175
+ "155": "E-phone_number",
176
+ "156": "S-phone_number",
177
+ "157": "B-pin",
178
+ "158": "I-pin",
179
+ "159": "E-pin",
180
+ "160": "S-pin",
181
+ "161": "B-political_view",
182
+ "162": "I-political_view",
183
+ "163": "E-political_view",
184
+ "164": "S-political_view",
185
+ "165": "B-postcode",
186
+ "166": "I-postcode",
187
+ "167": "E-postcode",
188
+ "168": "S-postcode",
189
+ "169": "B-race_ethnicity",
190
+ "170": "I-race_ethnicity",
191
+ "171": "E-race_ethnicity",
192
+ "172": "S-race_ethnicity",
193
+ "173": "B-religious_belief",
194
+ "174": "I-religious_belief",
195
+ "175": "E-religious_belief",
196
+ "176": "S-religious_belief",
197
+ "177": "B-sexuality",
198
+ "178": "I-sexuality",
199
+ "179": "E-sexuality",
200
+ "180": "S-sexuality",
201
+ "181": "B-ssn",
202
+ "182": "I-ssn",
203
+ "183": "E-ssn",
204
+ "184": "S-ssn",
205
+ "185": "B-state",
206
+ "186": "I-state",
207
+ "187": "E-state",
208
+ "188": "S-state",
209
+ "189": "B-street_address",
210
+ "190": "I-street_address",
211
+ "191": "E-street_address",
212
+ "192": "S-street_address",
213
+ "193": "B-swift_bic",
214
+ "194": "I-swift_bic",
215
+ "195": "E-swift_bic",
216
+ "196": "S-swift_bic",
217
+ "197": "B-tax_id",
218
+ "198": "I-tax_id",
219
+ "199": "E-tax_id",
220
+ "200": "S-tax_id",
221
+ "201": "B-time",
222
+ "202": "I-time",
223
+ "203": "E-time",
224
+ "204": "S-time",
225
+ "205": "B-unique_id",
226
+ "206": "I-unique_id",
227
+ "207": "E-unique_id",
228
+ "208": "S-unique_id",
229
+ "209": "B-url",
230
+ "210": "I-url",
231
+ "211": "E-url",
232
+ "212": "S-url",
233
+ "213": "B-user_name",
234
+ "214": "I-user_name",
235
+ "215": "E-user_name",
236
+ "216": "S-user_name",
237
+ "217": "B-vehicle_identifier",
238
+ "218": "I-vehicle_identifier",
239
+ "219": "E-vehicle_identifier",
240
+ "220": "S-vehicle_identifier"
241
+ },
242
+ "initial_context_length": 4096,
243
+ "initializer_range": 0.02,
244
+ "intermediate_size": 640,
245
+ "label2id": {
246
+ "B-account_number": 1,
247
+ "B-age": 5,
248
+ "B-api_key": 9,
249
+ "B-bank_routing_number": 13,
250
+ "B-biometric_identifier": 17,
251
+ "B-blood_type": 21,
252
+ "B-certificate_license_number": 25,
253
+ "B-city": 29,
254
+ "B-company_name": 33,
255
+ "B-coordinate": 37,
256
+ "B-country": 41,
257
+ "B-county": 45,
258
+ "B-credit_debit_card": 49,
259
+ "B-customer_id": 53,
260
+ "B-cvv": 57,
261
+ "B-date": 61,
262
+ "B-date_of_birth": 65,
263
+ "B-date_time": 69,
264
+ "B-device_identifier": 73,
265
+ "B-education_level": 77,
266
+ "B-email": 81,
267
+ "B-employee_id": 85,
268
+ "B-employment_status": 89,
269
+ "B-fax_number": 93,
270
+ "B-first_name": 97,
271
+ "B-gender": 101,
272
+ "B-health_plan_beneficiary_number": 105,
273
+ "B-http_cookie": 109,
274
+ "B-ipv4": 113,
275
+ "B-ipv6": 117,
276
+ "B-language": 121,
277
+ "B-last_name": 125,
278
+ "B-license_plate": 129,
279
+ "B-mac_address": 133,
280
+ "B-medical_record_number": 137,
281
+ "B-national_id": 141,
282
+ "B-occupation": 145,
283
+ "B-password": 149,
284
+ "B-phone_number": 153,
285
+ "B-pin": 157,
286
+ "B-political_view": 161,
287
+ "B-postcode": 165,
288
+ "B-race_ethnicity": 169,
289
+ "B-religious_belief": 173,
290
+ "B-sexuality": 177,
291
+ "B-ssn": 181,
292
+ "B-state": 185,
293
+ "B-street_address": 189,
294
+ "B-swift_bic": 193,
295
+ "B-tax_id": 197,
296
+ "B-time": 201,
297
+ "B-unique_id": 205,
298
+ "B-url": 209,
299
+ "B-user_name": 213,
300
+ "B-vehicle_identifier": 217,
301
+ "E-account_number": 3,
302
+ "E-age": 7,
303
+ "E-api_key": 11,
304
+ "E-bank_routing_number": 15,
305
+ "E-biometric_identifier": 19,
306
+ "E-blood_type": 23,
307
+ "E-certificate_license_number": 27,
308
+ "E-city": 31,
309
+ "E-company_name": 35,
310
+ "E-coordinate": 39,
311
+ "E-country": 43,
312
+ "E-county": 47,
313
+ "E-credit_debit_card": 51,
314
+ "E-customer_id": 55,
315
+ "E-cvv": 59,
316
+ "E-date": 63,
317
+ "E-date_of_birth": 67,
318
+ "E-date_time": 71,
319
+ "E-device_identifier": 75,
320
+ "E-education_level": 79,
321
+ "E-email": 83,
322
+ "E-employee_id": 87,
323
+ "E-employment_status": 91,
324
+ "E-fax_number": 95,
325
+ "E-first_name": 99,
326
+ "E-gender": 103,
327
+ "E-health_plan_beneficiary_number": 107,
328
+ "E-http_cookie": 111,
329
+ "E-ipv4": 115,
330
+ "E-ipv6": 119,
331
+ "E-language": 123,
332
+ "E-last_name": 127,
333
+ "E-license_plate": 131,
334
+ "E-mac_address": 135,
335
+ "E-medical_record_number": 139,
336
+ "E-national_id": 143,
337
+ "E-occupation": 147,
338
+ "E-password": 151,
339
+ "E-phone_number": 155,
340
+ "E-pin": 159,
341
+ "E-political_view": 163,
342
+ "E-postcode": 167,
343
+ "E-race_ethnicity": 171,
344
+ "E-religious_belief": 175,
345
+ "E-sexuality": 179,
346
+ "E-ssn": 183,
347
+ "E-state": 187,
348
+ "E-street_address": 191,
349
+ "E-swift_bic": 195,
350
+ "E-tax_id": 199,
351
+ "E-time": 203,
352
+ "E-unique_id": 207,
353
+ "E-url": 211,
354
+ "E-user_name": 215,
355
+ "E-vehicle_identifier": 219,
356
+ "I-account_number": 2,
357
+ "I-age": 6,
358
+ "I-api_key": 10,
359
+ "I-bank_routing_number": 14,
360
+ "I-biometric_identifier": 18,
361
+ "I-blood_type": 22,
362
+ "I-certificate_license_number": 26,
363
+ "I-city": 30,
364
+ "I-company_name": 34,
365
+ "I-coordinate": 38,
366
+ "I-country": 42,
367
+ "I-county": 46,
368
+ "I-credit_debit_card": 50,
369
+ "I-customer_id": 54,
370
+ "I-cvv": 58,
371
+ "I-date": 62,
372
+ "I-date_of_birth": 66,
373
+ "I-date_time": 70,
374
+ "I-device_identifier": 74,
375
+ "I-education_level": 78,
376
+ "I-email": 82,
377
+ "I-employee_id": 86,
378
+ "I-employment_status": 90,
379
+ "I-fax_number": 94,
380
+ "I-first_name": 98,
381
+ "I-gender": 102,
382
+ "I-health_plan_beneficiary_number": 106,
383
+ "I-http_cookie": 110,
384
+ "I-ipv4": 114,
385
+ "I-ipv6": 118,
386
+ "I-language": 122,
387
+ "I-last_name": 126,
388
+ "I-license_plate": 130,
389
+ "I-mac_address": 134,
390
+ "I-medical_record_number": 138,
391
+ "I-national_id": 142,
392
+ "I-occupation": 146,
393
+ "I-password": 150,
394
+ "I-phone_number": 154,
395
+ "I-pin": 158,
396
+ "I-political_view": 162,
397
+ "I-postcode": 166,
398
+ "I-race_ethnicity": 170,
399
+ "I-religious_belief": 174,
400
+ "I-sexuality": 178,
401
+ "I-ssn": 182,
402
+ "I-state": 186,
403
+ "I-street_address": 190,
404
+ "I-swift_bic": 194,
405
+ "I-tax_id": 198,
406
+ "I-time": 202,
407
+ "I-unique_id": 206,
408
+ "I-url": 210,
409
+ "I-user_name": 214,
410
+ "I-vehicle_identifier": 218,
411
+ "O": 0,
412
+ "S-account_number": 4,
413
+ "S-age": 8,
414
+ "S-api_key": 12,
415
+ "S-bank_routing_number": 16,
416
+ "S-biometric_identifier": 20,
417
+ "S-blood_type": 24,
418
+ "S-certificate_license_number": 28,
419
+ "S-city": 32,
420
+ "S-company_name": 36,
421
+ "S-coordinate": 40,
422
+ "S-country": 44,
423
+ "S-county": 48,
424
+ "S-credit_debit_card": 52,
425
+ "S-customer_id": 56,
426
+ "S-cvv": 60,
427
+ "S-date": 64,
428
+ "S-date_of_birth": 68,
429
+ "S-date_time": 72,
430
+ "S-device_identifier": 76,
431
+ "S-education_level": 80,
432
+ "S-email": 84,
433
+ "S-employee_id": 88,
434
+ "S-employment_status": 92,
435
+ "S-fax_number": 96,
436
+ "S-first_name": 100,
437
+ "S-gender": 104,
438
+ "S-health_plan_beneficiary_number": 108,
439
+ "S-http_cookie": 112,
440
+ "S-ipv4": 116,
441
+ "S-ipv6": 120,
442
+ "S-language": 124,
443
+ "S-last_name": 128,
444
+ "S-license_plate": 132,
445
+ "S-mac_address": 136,
446
+ "S-medical_record_number": 140,
447
+ "S-national_id": 144,
448
+ "S-occupation": 148,
449
+ "S-password": 152,
450
+ "S-phone_number": 156,
451
+ "S-pin": 160,
452
+ "S-political_view": 164,
453
+ "S-postcode": 168,
454
+ "S-race_ethnicity": 172,
455
+ "S-religious_belief": 176,
456
+ "S-sexuality": 180,
457
+ "S-ssn": 184,
458
+ "S-state": 188,
459
+ "S-street_address": 192,
460
+ "S-swift_bic": 196,
461
+ "S-tax_id": 200,
462
+ "S-time": 204,
463
+ "S-unique_id": 208,
464
+ "S-url": 212,
465
+ "S-user_name": 216,
466
+ "S-vehicle_identifier": 220
467
+ },
468
+ "max_position_embeddings": 131072,
469
+ "model_type": "haremb_pii",
470
+ "num_attention_heads": 14,
471
+ "num_experts_per_tok": 4,
472
+ "num_hidden_layers": 1,
473
+ "num_key_value_heads": 2,
474
+ "num_local_experts": 128,
475
+ "opf_metadata": {
476
+ "category_version": "nemotron_fine_v1",
477
+ "encoding": "o200k_base",
478
+ "inference_contract_version": 1,
479
+ "ner_class_names": [
480
+ "O",
481
+ "B-account_number",
482
+ "I-account_number",
483
+ "E-account_number",
484
+ "S-account_number",
485
+ "B-age",
486
+ "I-age",
487
+ "E-age",
488
+ "S-age",
489
+ "B-api_key",
490
+ "I-api_key",
491
+ "E-api_key",
492
+ "S-api_key",
493
+ "B-bank_routing_number",
494
+ "I-bank_routing_number",
495
+ "E-bank_routing_number",
496
+ "S-bank_routing_number",
497
+ "B-biometric_identifier",
498
+ "I-biometric_identifier",
499
+ "E-biometric_identifier",
500
+ "S-biometric_identifier",
501
+ "B-blood_type",
502
+ "I-blood_type",
503
+ "E-blood_type",
504
+ "S-blood_type",
505
+ "B-certificate_license_number",
506
+ "I-certificate_license_number",
507
+ "E-certificate_license_number",
508
+ "S-certificate_license_number",
509
+ "B-city",
510
+ "I-city",
511
+ "E-city",
512
+ "S-city",
513
+ "B-company_name",
514
+ "I-company_name",
515
+ "E-company_name",
516
+ "S-company_name",
517
+ "B-coordinate",
518
+ "I-coordinate",
519
+ "E-coordinate",
520
+ "S-coordinate",
521
+ "B-country",
522
+ "I-country",
523
+ "E-country",
524
+ "S-country",
525
+ "B-county",
526
+ "I-county",
527
+ "E-county",
528
+ "S-county",
529
+ "B-credit_debit_card",
530
+ "I-credit_debit_card",
531
+ "E-credit_debit_card",
532
+ "S-credit_debit_card",
533
+ "B-customer_id",
534
+ "I-customer_id",
535
+ "E-customer_id",
536
+ "S-customer_id",
537
+ "B-cvv",
538
+ "I-cvv",
539
+ "E-cvv",
540
+ "S-cvv",
541
+ "B-date",
542
+ "I-date",
543
+ "E-date",
544
+ "S-date",
545
+ "B-date_of_birth",
546
+ "I-date_of_birth",
547
+ "E-date_of_birth",
548
+ "S-date_of_birth",
549
+ "B-date_time",
550
+ "I-date_time",
551
+ "E-date_time",
552
+ "S-date_time",
553
+ "B-device_identifier",
554
+ "I-device_identifier",
555
+ "E-device_identifier",
556
+ "S-device_identifier",
557
+ "B-education_level",
558
+ "I-education_level",
559
+ "E-education_level",
560
+ "S-education_level",
561
+ "B-email",
562
+ "I-email",
563
+ "E-email",
564
+ "S-email",
565
+ "B-employee_id",
566
+ "I-employee_id",
567
+ "E-employee_id",
568
+ "S-employee_id",
569
+ "B-employment_status",
570
+ "I-employment_status",
571
+ "E-employment_status",
572
+ "S-employment_status",
573
+ "B-fax_number",
574
+ "I-fax_number",
575
+ "E-fax_number",
576
+ "S-fax_number",
577
+ "B-first_name",
578
+ "I-first_name",
579
+ "E-first_name",
580
+ "S-first_name",
581
+ "B-gender",
582
+ "I-gender",
583
+ "E-gender",
584
+ "S-gender",
585
+ "B-health_plan_beneficiary_number",
586
+ "I-health_plan_beneficiary_number",
587
+ "E-health_plan_beneficiary_number",
588
+ "S-health_plan_beneficiary_number",
589
+ "B-http_cookie",
590
+ "I-http_cookie",
591
+ "E-http_cookie",
592
+ "S-http_cookie",
593
+ "B-ipv4",
594
+ "I-ipv4",
595
+ "E-ipv4",
596
+ "S-ipv4",
597
+ "B-ipv6",
598
+ "I-ipv6",
599
+ "E-ipv6",
600
+ "S-ipv6",
601
+ "B-language",
602
+ "I-language",
603
+ "E-language",
604
+ "S-language",
605
+ "B-last_name",
606
+ "I-last_name",
607
+ "E-last_name",
608
+ "S-last_name",
609
+ "B-license_plate",
610
+ "I-license_plate",
611
+ "E-license_plate",
612
+ "S-license_plate",
613
+ "B-mac_address",
614
+ "I-mac_address",
615
+ "E-mac_address",
616
+ "S-mac_address",
617
+ "B-medical_record_number",
618
+ "I-medical_record_number",
619
+ "E-medical_record_number",
620
+ "S-medical_record_number",
621
+ "B-national_id",
622
+ "I-national_id",
623
+ "E-national_id",
624
+ "S-national_id",
625
+ "B-occupation",
626
+ "I-occupation",
627
+ "E-occupation",
628
+ "S-occupation",
629
+ "B-password",
630
+ "I-password",
631
+ "E-password",
632
+ "S-password",
633
+ "B-phone_number",
634
+ "I-phone_number",
635
+ "E-phone_number",
636
+ "S-phone_number",
637
+ "B-pin",
638
+ "I-pin",
639
+ "E-pin",
640
+ "S-pin",
641
+ "B-political_view",
642
+ "I-political_view",
643
+ "E-political_view",
644
+ "S-political_view",
645
+ "B-postcode",
646
+ "I-postcode",
647
+ "E-postcode",
648
+ "S-postcode",
649
+ "B-race_ethnicity",
650
+ "I-race_ethnicity",
651
+ "E-race_ethnicity",
652
+ "S-race_ethnicity",
653
+ "B-religious_belief",
654
+ "I-religious_belief",
655
+ "E-religious_belief",
656
+ "S-religious_belief",
657
+ "B-sexuality",
658
+ "I-sexuality",
659
+ "E-sexuality",
660
+ "S-sexuality",
661
+ "B-ssn",
662
+ "I-ssn",
663
+ "E-ssn",
664
+ "S-ssn",
665
+ "B-state",
666
+ "I-state",
667
+ "E-state",
668
+ "S-state",
669
+ "B-street_address",
670
+ "I-street_address",
671
+ "E-street_address",
672
+ "S-street_address",
673
+ "B-swift_bic",
674
+ "I-swift_bic",
675
+ "E-swift_bic",
676
+ "S-swift_bic",
677
+ "B-tax_id",
678
+ "I-tax_id",
679
+ "E-tax_id",
680
+ "S-tax_id",
681
+ "B-time",
682
+ "I-time",
683
+ "E-time",
684
+ "S-time",
685
+ "B-unique_id",
686
+ "I-unique_id",
687
+ "E-unique_id",
688
+ "S-unique_id",
689
+ "B-url",
690
+ "I-url",
691
+ "E-url",
692
+ "S-url",
693
+ "B-user_name",
694
+ "I-user_name",
695
+ "E-user_name",
696
+ "S-user_name",
697
+ "B-vehicle_identifier",
698
+ "I-vehicle_identifier",
699
+ "E-vehicle_identifier",
700
+ "S-vehicle_identifier"
701
+ ],
702
+ "span_class_names": [
703
+ "O",
704
+ "account_number",
705
+ "age",
706
+ "api_key",
707
+ "bank_routing_number",
708
+ "biometric_identifier",
709
+ "blood_type",
710
+ "certificate_license_number",
711
+ "city",
712
+ "company_name",
713
+ "coordinate",
714
+ "country",
715
+ "county",
716
+ "credit_debit_card",
717
+ "customer_id",
718
+ "cvv",
719
+ "date",
720
+ "date_of_birth",
721
+ "date_time",
722
+ "device_identifier",
723
+ "education_level",
724
+ "email",
725
+ "employee_id",
726
+ "employment_status",
727
+ "fax_number",
728
+ "first_name",
729
+ "gender",
730
+ "health_plan_beneficiary_number",
731
+ "http_cookie",
732
+ "ipv4",
733
+ "ipv6",
734
+ "language",
735
+ "last_name",
736
+ "license_plate",
737
+ "mac_address",
738
+ "medical_record_number",
739
+ "national_id",
740
+ "occupation",
741
+ "password",
742
+ "phone_number",
743
+ "pin",
744
+ "political_view",
745
+ "postcode",
746
+ "race_ethnicity",
747
+ "religious_belief",
748
+ "sexuality",
749
+ "ssn",
750
+ "state",
751
+ "street_address",
752
+ "swift_bic",
753
+ "tax_id",
754
+ "time",
755
+ "unique_id",
756
+ "url",
757
+ "user_name",
758
+ "vehicle_identifier"
759
+ ]
760
+ },
761
+ "output_router_logits": false,
762
+ "pad_token_id": 199999,
763
+ "rms_norm_eps": 1e-05,
764
+ "rope_parameters": {
765
+ "beta_fast": 32.0,
766
+ "beta_slow": 1.0,
767
+ "factor": 32.0,
768
+ "original_max_position_embeddings": 4096,
769
+ "rope_theta": 150000.0,
770
+ "rope_type": "yarn",
771
+ "truncate": false
772
+ },
773
+ "router_aux_loss_coef": 0.001,
774
+ "sliding_window": 128,
775
+ "tie_word_embeddings": false,
776
+ "transformers.js_config": {
777
+ "use_external_data_format": {
778
+ "model": 1,
779
+ "model.onnx": 3,
780
+ "model_fp16.onnx": 2
781
+ }
782
+ },
783
+ "transformers_version": "5.7.0",
784
+ "use_cache": true,
785
+ "use_viterbi_decode": true,
786
+ "viterbi_replace_logits": true,
787
+ "vocab_size": 200064
788
+ }
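
The head size above is easy to sanity-check: 221 classes = 1 `O` tag plus 4 BIOES tags (`B/I/E/S`) for each of the 55 span categories, i.e. 55 × 4 + 1 = 221. A minimal sketch of that check, where `"path/to/checkpoint"` is a placeholder for this repo id or a local clone:

```python
# Hypothetical sanity check; the path is a placeholder, the field names come
# from the config above.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("path/to/checkpoint", trust_remote_code=True)
categories = [c for c in cfg.opf_metadata["span_class_names"] if c != "O"]
assert len(categories) == 55
assert len(cfg.id2label) == 4 * len(categories) + 1  # B/I/E/S per category, plus O
```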
configuration_haremb_pii.py ADDED
@@ -0,0 +1,47 @@
+ """
+ HaremPiiConfig — subclass of OpenAIPrivacyFilterConfig that:
+   * sets `model_type="haremb_pii"` (so AutoConfig + auto_map dispatch works
+     with `trust_remote_code=True`)
+   * paired with HaremPiiForTokenClassification in modeling_haremb_pii.py
+     via `auto_map`
+
+ This release is a 1-layer surgical slice of the OpenMed teacher:
+   * num_hidden_layers=1
+   * inference-only — Viterbi decoding is built into the forward pass.
+ """
+ from __future__ import annotations
+
+ from transformers.models.openai_privacy_filter.configuration_openai_privacy_filter import (
+     OpenAIPrivacyFilterConfig,
+ )
+
+
+ class HaremPiiConfig(OpenAIPrivacyFilterConfig):
+     """
+     HarEmb config. `model_type="haremb_pii"` disambiguates from upstream so
+     AutoConfig + AutoModel mappings can target our subclasses without
+     colliding with the registered OpenAIPrivacyFilterConfig entry.
+     `modeling_haremb_pii` performs the auto-registration at import time.
+     """
+     model_type = "haremb_pii"
+
+     def __init__(
+         self,
+         use_viterbi_decode: bool = True,
+         viterbi_replace_logits: bool = True,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         # When True (and model is in eval mode), HaremPiiForTokenClassification.forward
+         # runs constrained BIOES Viterbi over logits and attaches `predicted_labels`
+         # to the output. Set to False to skip Viterbi entirely.
+         self.use_viterbi_decode = bool(use_viterbi_decode)
+         # When True (and Viterbi is on), forward replaces `outputs.logits` with a
+         # one-hot-shaped tensor whose argmax equals the Viterbi prediction. This
+         # makes HF `pipeline()` and any naive `logits.argmax(-1)` consumer use
+         # Viterbi predictions automatically. The raw logits are preserved on
+         # the output as `raw_logits`.
+         self.viterbi_replace_logits = bool(viterbi_replace_logits)
+
+
+ __all__ = ["HaremPiiConfig"]
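
A minimal usage sketch of the two flags documented above, assuming only the loading path the docstrings describe; `"path/to/checkpoint"` is a placeholder for this repo id or a local clone:

```python
# Hypothetical usage sketch; the flag names come from HaremPiiConfig above.
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

repo = "path/to/checkpoint"  # placeholder
tok = AutoTokenizer.from_pretrained(repo)
model = AutoModelForTokenClassification.from_pretrained(
    repo, trust_remote_code=True
).eval()  # Viterbi only runs in eval mode

enc = tok("email sarah.johnson@example.com", return_tensors="pt")
with torch.no_grad():
    out = model(**enc)

# With both flags at their defaults, argmax over logits already equals the
# Viterbi path, and the unconstrained scores survive as `raw_logits`.
print(out.predicted_labels)    # constrained-BIOES tag ids
print(out.logits.argmax(-1))   # same path, via the replaced logits

model.config.use_viterbi_decode = False        # skip Viterbi entirely
# model.config.viterbi_replace_logits = False  # keep Viterbi, leave logits raw
```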
eval_confusion.png ADDED
Git LFS Details
  • SHA256: d3f48f7c2692cd641800d883952b19b7aa3ca2b1e3cfbe9846f082dcbd2a2b16
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB
eval_performance.png ADDED
eval_summary.png ADDED
Git LFS Details
  • SHA256: 77c7da68b68cfa9223fc8423ad2be6124e3c14e6d9f49d1d683db2630f2f5b71
  • Pointer size: 131 Bytes
  • Size of remote file: 167 kB
haremb.png ADDED
Git LFS Details
  • SHA256: 3d53d359c14545f778e4789ac53d40eb8a4b0fc0ed7bf05a72c101ee60e33db1
  • Pointer size: 131 Bytes
  • Size of remote file: 767 kB
infer.log ADDED
@@ -0,0 +1,51 @@
+ Inference benchmark: A: openmed-base vs B: haremb
+ device : cuda dtype: torch.bfloat16
+ ctx    : 1024
+
+ A: openmed-base (reference / teacher)
+   load : 0.71s
+   eval : 64.66s on 212,909 tokens (3293 tok/s)
+   Performance:
+     total params              : 1399.61M (139.35M dense + 1260.26M MoE-experts)
+     active params / token     : 178.73M (memory footprint — embed lookup + top_4/128 experts: 128.04M embed + 39.38M MoE-active + 11.31M attn/norm/head)
+     compute params / token    : 50.69M (matmul FLOPs only — embedding lookup excluded)
+     GFLOP / token (fwd, MAC×2): 0.101
+     weights size (on disk)    : —
+     weights size (in RAM)     : 2.61 GiB
+     weights resident (GPU)    : 2.61 GiB
+     peak GPU mem (eval, ctx=1024) : 3.30 GiB
+
+ B: haremb (this checkpoint)
+   load : 0.10s
+   eval : 33.56s on 212,909 tokens (6343 tok/s)
+   Performance:
+     total params              : 287.11M (129.58M dense + 157.53M MoE-experts)
+     active params / token     : 134.50M (memory footprint — embed lookup + top_4/128 experts: 128.04M embed + 4.92M MoE-active + 1.54M attn/norm/head)
+     compute params / token    : 6.46M (matmul FLOPs only — embedding lookup excluded)
+     GFLOP / token (fwd, MAC×2): 0.013
+     weights size (on disk)    : 547.6 MiB
+     weights size (in RAM)     : 547.6 MiB
+     weights resident (GPU)    : 548.3 MiB
+     peak GPU mem (eval, ctx=1024) : 1.22 GiB
+
+ B vs A (haremb vs openmed-base):
+   total params           : 4.87× smaller
+   active params / token  : 1.33× less [memory]
+   compute params / token : 7.85× cheaper [FLOPs]
+   GFLOP / token          : 7.85× cheaper
+   weights size (on disk) : —
+   weights in RAM         : 4.87× smaller
+   peak GPU mem (eval)    : 2.70× less
+   throughput             : 1.93× faster
+
+ Sample inference (load → tokenize → forward → viterbi-decode → spans):
+   text: 'Patient Sarah Johnson (DOB 03/15/1985), MRN 4872910, phone 415-555-0123, email sarah.johnson@example.com, credit card 4111-1111-1111-1111.'
+   forward latency: 65.8ms (53 tokens)
+   detected 7 spans:
+     [ 1,  2) first_name        'Sarah'
+     [ 2,  3) last_name         'Johnson'
+     [ 6, 12) date              '03/15/1985'
+     [16, 19) phone_number      '4872910'
+     [22, 28) phone_number      '415-555-0123'
+     [30, 37) email             'sarah.johnson@example.com'
+     [41, 52) credit_debit_card '4111-1111-1111-1111'
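
The last step of that pipeline (tags → spans) is not shown in the log; a minimal sketch of what such a collapse could look like (this helper is illustrative and not shipped in the repo):

```python
def bioes_spans(tag_ids, id2label):
    """Collapse a constrained-BIOES tag sequence into (start, end, category) spans."""
    spans, start, cat = [], None, None
    for i, tid in enumerate(tag_ids):
        label = id2label[int(tid)] if int(tid) >= 0 else "O"  # -1 marks padding
        prefix, _, category = label.partition("-")
        if prefix == "S":                          # single-token entity
            spans.append((i, i + 1, category))
        elif prefix == "B":                        # entity opens
            start, cat = i, category
        elif prefix == "E" and start is not None:  # entity closes
            spans.append((start, i + 1, cat))
            start, cat = None, None
    return spans
```

Because the decoder only emits legal transitions, every `B` is eventually closed by an `E`, so a helper like this never leaves a dangling span.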
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:64059006ac732bd608cc478ee579cc96f56174d8910603b6d4747688b130b8a2
+ size 574224842
modeling_haremb_pii.py ADDED
@@ -0,0 +1,270 @@
+ """
+ HaremPii — 1-layer surgical inference wrapper over OpenAI Privacy Filter.
+
+ Defines:
+   * HaremPiiForTokenClassification — subclass of
+     OpenAIPrivacyFilterForTokenClassification. Reuses the upstream forward
+     pass and adds eval-time constrained-BIOES Viterbi decoding so
+     `outputs.logits.argmax(-1)` returns the Viterbi path.
+   * HaremPiiModel — encoder alias pinned to HaremPiiConfig.
+
+ The model class is auto-registered so
+ `AutoModelForTokenClassification.from_pretrained(repo, trust_remote_code=True)`
+ dispatches to us via `config.auto_map` (model_type "haremb_pii").
+
+ This file is the released, inference-only copy. It contains no
+ training-related utilities.
+ """
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ from transformers import (
+     AutoConfig,
+     AutoModel,
+     AutoModelForTokenClassification,
+ )
+ from transformers.models.openai_privacy_filter.modeling_openai_privacy_filter import (
+     OpenAIPrivacyFilterForTokenClassification,
+     OpenAIPrivacyFilterModel,
+ )
+
+ from configuration_haremb_pii import HaremPiiConfig
+
+
+ # ---------------------------------------------------------------------------
+ # Constrained BIOES Viterbi (inlined so the checkpoint is self-contained)
+ # ---------------------------------------------------------------------------
+ # Transition rules:
+ #   O   -> {O, B-X, S-X}
+ #   B-X -> {I-X, E-X}
+ #   I-X -> {I-X, E-X}
+ #   E-X -> {O, B-Y, S-Y}
+ #   S-X -> {O, B-Y, S-Y}
+ # Initial state allows {O, B-X, S-X} only.
+
+ def _parse_bioes(label: str):
+     if label == "O" or "-" not in label:
+         return "O", None
+     pref, cat = label.split("-", 1)
+     return pref, cat
+
+
+ def _build_bioes_transition_mask(id2label) -> torch.Tensor:
+     C = len(id2label)
+     mask = torch.full((C, C), float("-inf"))
+     parsed = {i: _parse_bioes(id2label[i]) for i in range(C)}
+     for i, (p_prev, c_prev) in parsed.items():
+         for j, (p_cur, c_cur) in parsed.items():
+             ok = False
+             if p_prev == "O":
+                 if p_cur in ("O", "B", "S"):
+                     ok = True
+             elif p_prev == "B":
+                 if p_cur in ("I", "E") and c_cur == c_prev:
+                     ok = True
+             elif p_prev == "I":
+                 if p_cur in ("I", "E") and c_cur == c_prev:
+                     ok = True
+             elif p_prev in ("E", "S"):
+                 if p_cur in ("O", "B", "S"):
+                     ok = True
+             if ok:
+                 mask[i, j] = 0.0
+     return mask
+
+
+ def _build_bioes_initial_mask(id2label) -> torch.Tensor:
+     C = len(id2label)
+     mask = torch.full((C,), float("-inf"))
+     for i, lbl in id2label.items():
+         p, _ = _parse_bioes(lbl)
+         if p in ("O", "B", "S"):
+             mask[i] = 0.0
+     return mask
+
+
+ def _bioes_viterbi(
+     logits: torch.Tensor,
+     transition_mask: torch.Tensor,
+     initial_mask: torch.Tensor,
+ ) -> torch.Tensor:
+     if logits.dim() != 2:
+         raise ValueError(f"expected 2D logits, got {logits.shape}")
+     T = logits.shape[0]
+     mask = torch.ones((1, T), dtype=torch.long, device=logits.device)
+     out = _bioes_viterbi_batched(
+         logits.unsqueeze(0), mask, transition_mask, initial_mask,
+     )
+     return out[0]
+
+
+ def _bioes_viterbi_batched(
+     logits: torch.Tensor,
+     attention_mask: torch.Tensor,
+     transition_mask: torch.Tensor,
+     initial_mask: torch.Tensor,
+ ) -> torch.Tensor:
+     """Vectorized constrained BIOES Viterbi.
+
+     Args:
+         logits: [B, T, C] float
+         attention_mask: [B, T] {0, 1} long/bool
+         transition_mask: [C, C] 0 valid, -inf invalid
+         initial_mask: [C] 0 allowed first tag, -inf forbidden
+
+     Returns:
+         [B, T] LongTensor of best constrained-BIOES tag id per token; padded
+         positions hold -1.
+     """
+     if logits.dim() != 3:
+         raise ValueError(f"expected 3D logits [B,T,C], got {logits.shape}")
+     device = logits.device
+     B, T, C = logits.shape
+     scores = logits.float()
+     trans = transition_mask.to(device).float()
+     init = initial_mask.to(device).float()
+     mask = attention_mask.to(device).bool()
+
+     dp = scores[:, 0] + init.unsqueeze(0)
+     back = torch.zeros((B, T, C), dtype=torch.long, device=device)
+     trans_b = trans.unsqueeze(0)
+     for t in range(1, T):
+         cand = dp.unsqueeze(2) + trans_b
+         best_val, best_prev = cand.max(dim=1)
+         new_dp = best_val + scores[:, t]
+         keep = mask[:, t].unsqueeze(1)
+         dp = torch.where(keep, new_dp, dp)
+         back[:, t] = best_prev
+
+     last_t = (mask.sum(dim=1) - 1).clamp_min(0)
+     # A valid BIOES sequence may not end inside a span (on B-/I-). A tag is a
+     # legal terminal exactly when it can transition into some legal initial
+     # tag, so the terminal mask falls out of the two masks we already have.
+     final_ok = torch.isfinite(trans + init.unsqueeze(0)).any(dim=1)
+     best_last = dp.masked_fill(~final_ok.unsqueeze(0), float("-inf")).argmax(dim=1)
+     out = torch.full((B, T), -1, dtype=torch.long, device=device)
+     batch_idx = torch.arange(B, device=device)
+     out[batch_idx, last_t] = best_last
+     current = best_last.clone()
+     for t in range(T - 1, 0, -1):
+         new_current = torch.gather(
+             back[:, t, :], 1, current.unsqueeze(1)
+         ).squeeze(1)
+         active = (t <= last_t)
+         current = torch.where(active, new_current, current)
+         out[batch_idx, t - 1] = torch.where(
+             active, current, out[batch_idx, t - 1],
+         )
+     return out
+
+
+ # ---------------------------------------------------------------------------
+ # Architecture classes
+ # ---------------------------------------------------------------------------
+
+ class HaremPiiModel(OpenAIPrivacyFilterModel):
+     """Thin alias of the upstream encoder pinned to HaremPiiConfig."""
+     config_class = HaremPiiConfig
+
+
+ class HaremPiiForTokenClassification(OpenAIPrivacyFilterForTokenClassification):
+     """1-layer student. Wraps the upstream forward with eval-time
+     constrained-BIOES Viterbi decoding."""
+     config_class = HaremPiiConfig
+
+     def __init__(self, config: HaremPiiConfig):
+         # Bypass GenericForTokenClassification.__init__ because it calls
+         # AutoModel.from_config(config), which uses type(config) as the
+         # registry key. Under the trust_remote_code Hub-loading path the
+         # cached HaremPiiConfig class identity differs from whatever was
+         # registered at module import (the cache hosts the class under a
+         # synthetic, sha-qualified module name). Constructing the encoder
+         # directly avoids the registry dispatch entirely.
+         from transformers.modeling_utils import PreTrainedModel as _PreTrainedModel
+         _PreTrainedModel.__init__(self, config)
+         self.num_labels = config.num_labels
+         self.model = OpenAIPrivacyFilterModel(config)
+         if getattr(config, "classifier_dropout", None) is not None:
+             classifier_dropout = config.classifier_dropout
+         elif getattr(config, "hidden_dropout", None) is not None:
+             classifier_dropout = config.hidden_dropout
+         else:
+             classifier_dropout = 0.1
+         self.dropout = nn.Dropout(classifier_dropout)
+         self.score = nn.Linear(config.hidden_size, config.num_labels)
+         self.post_init()
+
+         self._viterbi_trans_mask = None
+         self._viterbi_init_mask = None
+
+     def _ensure_viterbi_masks(self):
+         if self._viterbi_trans_mask is None:
+             id2label = {int(k): v for k, v in self.config.id2label.items()}
+             self._viterbi_trans_mask = _build_bioes_transition_mask(id2label)
+             self._viterbi_init_mask = _build_bioes_initial_mask(id2label)
+         return self._viterbi_trans_mask, self._viterbi_init_mask
+
+     @torch.no_grad()
+     def decode_predictions(
+         self,
+         logits: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         trans, init = self._ensure_viterbi_masks()
+         if logits.dim() == 2:
+             # Single unbatched sequence: reuse the 2D helper directly.
+             return _bioes_viterbi(logits, trans, init)
+         if attention_mask is None:
+             attention_mask = torch.ones(
+                 logits.shape[:2], dtype=torch.long, device=logits.device,
+             )
+         return _bioes_viterbi_batched(logits, attention_mask, trans, init)
+
+     def forward(self, *args, **kwargs):
+         outputs = super().forward(*args, **kwargs)
+         if self.training:
+             return outputs
+         if not getattr(self.config, "use_viterbi_decode", True):
+             return outputs
+
+         attn_mask = kwargs.get("attention_mask", None)
+         if attn_mask is None and len(args) >= 2:
+             attn_mask = args[1]
+
+         decoded = self.decode_predictions(outputs.logits, attention_mask=attn_mask)
+         try:
+             outputs.predicted_labels = decoded
+         except Exception:
+             outputs.__dict__["predicted_labels"] = decoded
+
+         if getattr(self.config, "viterbi_replace_logits", True):
+             raw = outputs.logits
+             fake = torch.full_like(raw, fill_value=-1e9)
+             fake.scatter_(-1, decoded.clamp_min(0).unsqueeze(-1), 1e9)
+             try:
+                 outputs.raw_logits = raw
+                 outputs.logits = fake
+             except Exception:
+                 outputs.__dict__["raw_logits"] = raw
+                 outputs.__dict__["logits"] = fake
+         return outputs
+
+
+ # --- Auto-registry ---
+ AutoConfig.register("haremb_pii", HaremPiiConfig, exist_ok=True)
+ AutoModel.register(HaremPiiConfig, HaremPiiModel, exist_ok=True)
+ AutoModelForTokenClassification.register(
+     HaremPiiConfig, HaremPiiForTokenClassification, exist_ok=True,
+ )
+
+ HaremPiiConfig.register_for_auto_class("AutoConfig")
+ HaremPiiForTokenClassification.register_for_auto_class("AutoModelForTokenClassification")
+
+
+ __all__ = [
+     "HaremPiiConfig",
+     "HaremPiiModel",
+     "HaremPiiForTokenClassification",
+ ]
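
A small sanity sketch of the mask semantics above, assuming the file is importable as `modeling_haremb_pii` (the toy label set is illustrative):

```python
import torch
from modeling_haremb_pii import (
    _bioes_viterbi_batched,
    _build_bioes_initial_mask,
    _build_bioes_transition_mask,
)

id2label = {0: "O", 1: "B-email", 2: "I-email", 3: "E-email", 4: "S-email"}
trans = _build_bioes_transition_mask(id2label)
init = _build_bioes_initial_mask(id2label)

assert torch.isinf(trans[0, 2])  # O -> I-email is forbidden
assert trans[1, 3] == 0.0        # B-email -> E-email is allowed
assert torch.isinf(init[2])      # a sequence cannot open inside a span

# Logits whose naive argmax would be the illegal path I,I,I still decode legally:
logits = torch.zeros(1, 3, 5)
logits[0, :, 2] = 5.0
mask = torch.ones(1, 3, dtype=torch.long)
print(_bioes_viterbi_batched(logits, mask, trans, init))
# tensor([[1, 2, 3]])  -> B-email, I-email, E-email
```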
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0614fe83cadab421296e664e1f48f4261fa8fef6e03e63bb75c20f38e37d07d3
+ size 27868174
tokenizer_config.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "backend": "tokenizers",
+   "eos_token": "<|endoftext|>",
+   "is_local": false,
+   "local_files_only": false,
+   "model_input_names": [
+     "input_ids",
+     "attention_mask"
+   ],
+   "model_max_length": 128000,
+   "pad_token": "<|endoftext|>",
+   "tokenizer_class": "TokenizersBackend"
+ }
+ }