Upload folder using huggingface_hub
Browse files- .gitattributes +4 -0
- .gitignore +1 -0
- README.md +258 -0
- app.py +534 -0
- benchmark.py +1351 -0
- compare.log +188 -0
- config.json +788 -0
- configuration_haremb_pii.py +47 -0
- eval_confusion.png +3 -0
- eval_performance.png +0 -0
- eval_summary.png +3 -0
- haremb.png +3 -0
- infer.log +51 -0
- model.safetensors +3 -0
- modeling_haremb_pii.py +270 -0
- tokenizer.json +3 -0
- tokenizer_config.json +13 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
eval_confusion.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
eval_summary.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
haremb.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
README.md
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
pipeline_tag: token-classification
|
| 6 |
+
library_name: transformers
|
| 7 |
+
tags:
|
| 8 |
+
- pii
|
| 9 |
+
- privacy
|
| 10 |
+
- token-classification
|
| 11 |
+
- bioes
|
| 12 |
+
- moe
|
| 13 |
+
- haremb
|
| 14 |
+
base_model:
|
| 15 |
+
- OpenMed/privacy-filter-nemotron
|
| 16 |
+
datasets:
|
| 17 |
+
- nvidia/Nemotron-PII
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
# HarEmb · OpenMed-Nemotron PII
|
| 21 |
+
|
| 22 |
+
> A **single-layer** HarEmb model on the [`OpenMed/privacy-filter-nemotron`](https://huggingface.co/OpenMed/privacy-filter-nemotron) lineage. It has 287M total parameters and predicts the full **221-class BIOES** Nemotron-PII label space.
|
| 23 |
+
|
| 24 |
+
**Model**: [`fblgit/haremb-privacy-filter-opennemo`](https://huggingface.co/fblgit/haremb-privacy-filter-opennemo)
|
| 25 |
+
|
| 26 |
+

|
| 27 |
+
|
| 28 |
+
## Lineage
|
| 29 |
+
|
| 30 |
+
This model is the third step of a three-step lineage:
|
| 31 |
+
|
| 32 |
+
1. **[`openai/privacy-filter`](https://huggingface.co/openai/privacy-filter)** — OpenAI's open release of the underlying 1.4B-parameter MoE backbone (8 transformer layers, ~50M active params/token, BIOES token classifier head).
|
| 33 |
+
2. **[`OpenMed/privacy-filter-nemotron`](https://huggingface.co/OpenMed/privacy-filter-nemotron)** — OpenMed's full fine-tune of that backbone on `nvidia/Nemotron-PII`, expanding the head to 221 BIOES classes (55 fine-grained PII categories).
|
| 34 |
+
3. **`haremb-privacy-filter-opennemo`** *(this model)* — a one-layer surgical slice of the OpenMed teacher.
|
| 35 |
+
|
| 36 |
+
## What this model does
|
| 37 |
+
|
| 38 |
+
Token-level PII classification over **55 Nemotron-PII categories**. Every token receives one of `O` or `{B, I, E, S}-<category>`, covering identity, contact, address, date/time, government ID, financial, healthcare, enterprise ID, vehicle, and digital identifier categories.
|
| 39 |
+
|
| 40 |
+
In `eval()` mode the model can run constrained-BIOES Viterbi decoding internally, so `outputs.logits.argmax(-1)` is span-coherent by default. See [Output semantics](#output-semantics) for the exact fields and opt-out flags.
|
| 41 |
+
|
| 42 |
+
## Evaluation
|
| 43 |
+
|
| 44 |
+
Evaluated on a 1% slice of `nvidia/Nemotron-PII:test` (1,000 documents, ctx 1024, seed 42), Viterbi-decoded. The benchmark and app both use the convention **A = `OpenMed/privacy-filter-nemotron` (teacher / baseline)**, **B = this checkpoint** (`haremb`); ratios are reported as **B ÷ A**.
|
| 45 |
+
|
| 46 |
+
### Quality (viterbi stream)
|
| 47 |
+
|
| 48 |
+
| metric | **A: OpenMed teacher** | **B: haremb** (this) | B − A |
|
| 49 |
+
|---|---:|---:|---:|
|
| 50 |
+
| span F1 | 0.9434 | **0.9288** | −0.0146 |
|
| 51 |
+
| span precision | 0.9531 | **0.9396** | −0.0135 |
|
| 52 |
+
| span recall | 0.9338 | **0.9182** | −0.0156 |
|
| 53 |
+
| token accuracy | 0.9900 | **0.9885** | −0.0015 |
|
| 54 |
+
| non-O recall | 0.9703 | **0.9637** | −0.0066 |
|
| 55 |
+
|
| 56 |
+
### Performance (same eval set, ctx 1024, bf16, single GPU)
|
| 57 |
+
|
| 58 |
+
| metric | **A: OpenMed teacher** | **B: haremb** | B vs A |
|
| 59 |
+
|---|---:|---:|---:|
|
| 60 |
+
| total params | 1,400M | **287M** | **4.87× smaller** |
|
| 61 |
+
| dense params | 139M | 130M | 1.07× smaller |
|
| 62 |
+
| MoE expert params | 1,260M | 158M | **7.97× smaller** |
|
| 63 |
+
| **active params / token** (memory) | 178.7M | **134.5M** | 1.33× smaller |
|
| 64 |
+
| **compute params / token** (FLOPs) | 50.7M | **6.5M** | **7.85× cheaper** |
|
| 65 |
+
| GFLOP / token (forward) | 0.101 | **0.013** | **7.85× cheaper** |
|
| 66 |
+
| weights on disk | (HF repo) | **548 MiB** | — |
|
| 67 |
+
| weights in RAM | 2,669 MiB | 548 MiB | **4.87× smaller** |
|
| 68 |
+
| peak GPU memory (eval) | 3.30 GiB | **1.22 GiB** | **2.70× less** |
|
| 69 |
+
| throughput | 3,275 tok/s | **6,361 tok/s** | **1.94× faster** |
|
| 70 |
+
|
| 71 |
+
`active params / token` estimates memory bandwidth pressure, while `compute params / token` estimates matmul FLOPs and excludes the embedding table row-gather. GFLOP/token is `2 × compute_params_per_token`. `infer.log` and `compare.log` contain the full breakdown, including peak GPU memory from `torch.cuda.max_memory_allocated`.
|
| 72 |
+
|
| 73 |
+

|
| 74 |
+
|
| 75 |
+
### Quality breakdown
|
| 76 |
+
|
| 77 |
+

|
| 78 |
+
|
| 79 |
+
### Per-category highlights (viterbi span F1)
|
| 80 |
+
|
| 81 |
+
**At or near 1.000 (B)** — `biometric_identifier`, `blood_type`, `coordinate`, `health_plan_beneficiary_number`, `ipv4`, `ipv6`, `license_plate`, `mac_address`, `national_id`, `postcode` (≥ 0.99 with ≥ 100 gold spans).
|
| 82 |
+
|
| 83 |
+
**Categories where B beats A** — `gender` (0.987 vs 0.841), `political_view` (0.872 vs 0.839), `religious_belief` (0.935 vs 0.926), `state` (0.908 vs 0.829), `language` (0.897 vs 0.804), `race_ethnicity` (0.864 vs 0.861), `country` (0.952 vs 0.936). Several "fuzzy" world-knowledge categories where the 1-layer student carries the right inductive bias.
|
| 84 |
+
|
| 85 |
+
**Categories where A leads** — `occupation` (0.727 vs 0.605), `company_name` (0.929 vs 0.776), `last_name` (0.976 vs 0.931), `first_name` (0.970 vs 0.930), `user_name` (0.961 vs 0.942). Identity-noun categories where the teacher's deeper-layer mixing helps.
|
| 86 |
+
|
| 87 |
+
### Token-outcome breakdown — A: OpenMed teacher vs B: haremb (viterbi)
|
| 88 |
+
|
| 89 |
+

|
| 90 |
+
|
| 91 |
+
## Quick start
|
| 92 |
+
|
| 93 |
+
### Recommended — via OpenMed
|
| 94 |
+
|
| 95 |
+
The OpenMed wrapper is the same UX the teacher card recommends and works on this checkpoint as a drop-in:
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
pip install -U "openmed[hf]"
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
```python
|
| 102 |
+
from openmed import extract_pii, deidentify
|
| 103 |
+
|
| 104 |
+
text = (
|
| 105 |
+
"Patient Sarah Johnson (DOB 03/15/1985), MRN 4872910, "
|
| 106 |
+
"phone 415-555-0123, email sarah.johnson@example.com."
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
result = extract_pii(text, model_name="fblgit/haremb-privacy-filter-opennemo")
|
| 110 |
+
for ent in result.entities:
|
| 111 |
+
print(f"{ent.label:30s} {ent.text!r} conf={ent.confidence:.2f}")
|
| 112 |
+
|
| 113 |
+
masked = deidentify(text, method="mask",
|
| 114 |
+
model_name="fblgit/haremb-privacy-filter-opennemo")
|
| 115 |
+
fake = deidentify(text, method="replace",
|
| 116 |
+
model_name="fblgit/haremb-privacy-filter-opennemo",
|
| 117 |
+
consistent=True, seed=42)
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
### HuggingFace `transformers` pipeline
|
| 121 |
+
|
| 122 |
+
```python
|
| 123 |
+
from transformers import pipeline
|
| 124 |
+
|
| 125 |
+
pipe = pipeline(
|
| 126 |
+
"token-classification",
|
| 127 |
+
model="fblgit/haremb-privacy-filter-opennemo",
|
| 128 |
+
tokenizer="fblgit/haremb-privacy-filter-opennemo",
|
| 129 |
+
trust_remote_code=True,
|
| 130 |
+
aggregation_strategy="simple",
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
pipe("Send the invoice to billing@acmecorp.io, account 1234-5678.")
|
| 134 |
+
# → [{'entity_group': 'email', 'word': 'billing@acmecorp.io', ...},
|
| 135 |
+
# {'entity_group': 'account_number', 'word': '1234-5678', ...}]
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
### Raw `transformers` API
|
| 139 |
+
|
| 140 |
+
```python
|
| 141 |
+
import torch
|
| 142 |
+
from transformers import AutoModelForTokenClassification, AutoTokenizer
|
| 143 |
+
|
| 144 |
+
repo = "fblgit/haremb-privacy-filter-opennemo"
|
| 145 |
+
model = AutoModelForTokenClassification.from_pretrained(
|
| 146 |
+
repo, trust_remote_code=True, dtype=torch.bfloat16,
|
| 147 |
+
).to("cuda").eval()
|
| 148 |
+
tok = AutoTokenizer.from_pretrained(repo)
|
| 149 |
+
|
| 150 |
+
enc = tok("My email is foo@bar.com.", return_tensors="pt").to("cuda")
|
| 151 |
+
with torch.no_grad():
|
| 152 |
+
out = model(**enc)
|
| 153 |
+
|
| 154 |
+
# By default, `outputs.logits.argmax(-1)` follows the Viterbi-decoded path.
|
| 155 |
+
labels = out.logits.argmax(-1)[0]
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
## Output semantics
|
| 159 |
+
|
| 160 |
+
The forward pass — in `eval()` mode — runs constrained-BIOES Viterbi over the per-token logits and attaches three things to the output:
|
| 161 |
+
|
| 162 |
+
- `outputs.logits` — a tensor whose `argmax(-1)` equals the Viterbi prediction (so HF `pipeline()` and naive `argmax` consumers get span-coherent predictions automatically).
|
| 163 |
+
- `outputs.predicted_labels` — a `[B, T]` LongTensor of Viterbi-decoded label ids (`-1` at padded positions).
|
| 164 |
+
- `outputs.raw_logits` — the original per-token logits, preserved for callers that want raw confidences.
|
| 165 |
+
|
| 166 |
+
To opt out:
|
| 167 |
+
|
| 168 |
+
```python
|
| 169 |
+
model.config.viterbi_replace_logits = False # raw logits in outputs.logits
|
| 170 |
+
model.config.use_viterbi_decode = False # also skip Viterbi entirely
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
The model supports the upstream context length (max position embeddings 131,072 tokens). Practical batch sizes depend on hardware; bf16 + batch 1 + full-length is comfortable on 24 GB.
|
| 174 |
+
|
| 175 |
+
## Limitations & intended use
|
| 176 |
+
|
| 177 |
+
- **English-only training data.** Nemotron-PII is predominantly English. Performance on non-English text is not guaranteed.
|
| 178 |
+
- **Synthetic training data.** Real clinical notes, legal documents, and live web text may show different surface forms. For high-stakes deployments, collect a domain-specific eval set and re-calibrate.
|
| 179 |
+
- **Fuzzier categories** — `occupation`, `company_name`, and identity nouns (`first_name`, `last_name`, `user_name`) carry more uncertainty than formatted identifiers; downstream pipelines that only need strict PII can ignore low-confidence predictions on these.
|
| 180 |
+
- **Not a substitute for legal compliance review.** Use alongside a governance layer (human review, deterministic regex pre-filters, etc.).
|
| 181 |
+
|
| 182 |
+
## Reproducibility
|
| 183 |
+
|
| 184 |
+
Every metric, log, and plot in this card is regenerated by the single-file [`benchmark.py`](benchmark.py) shipped alongside the weights:
|
| 185 |
+
|
| 186 |
+
```bash
|
| 187 |
+
python benchmark.py # full benchmark vs OpenMed teacher
|
| 188 |
+
python benchmark.py --no-base # skip teacher download (logs only)
|
| 189 |
+
python benchmark.py --no-plots # skip matplotlib (logs + JSON only)
|
| 190 |
+
python benchmark.py --eval-pct 0.1 # smaller slice for a quick check
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
Outputs into the model folder:
|
| 194 |
+
|
| 195 |
+
- `infer.log`
|
| 196 |
+
- `compare.log`
|
| 197 |
+
- `eval_summary.png`
|
| 198 |
+
- `eval_confusion.png`
|
| 199 |
+
- `eval_performance.png`
|
| 200 |
+
|
| 201 |
+
Raw per-doc eval data is held in memory only. Pass `--out` to write artifacts somewhere else.
|
| 202 |
+
|
| 203 |
+
The Gradio demo in [`app.py`](app.py) supports **side-by-side A-vs-B comparison** between any two token-classification checkpoints with the same label space. Defaults match the report convention: **A = OpenMed/privacy-filter-nemotron** (teacher / baseline), **B = this checkpoint**. Disable either model to run single-model inference; both expose a runtime "active experts per token" slider so you can sweep MoE routing density. From inside the model folder:
|
| 204 |
+
|
| 205 |
+
```bash
|
| 206 |
+
python app.py # A=OpenMed teacher, B=. (this)
|
| 207 |
+
python app.py --model-a /path/to/another/repo # swap baseline A
|
| 208 |
+
python app.py --model-b /path/to/another/repo # swap candidate B
|
| 209 |
+
python app.py --port 7860 --share # public share link
|
| 210 |
+
```
|
| 211 |
+
|
| 212 |
+
## License
|
| 213 |
+
|
| 214 |
+
Apache-2.0, same as the lineage. Subject to the license terms of [`openai/privacy-filter`](https://huggingface.co/openai/privacy-filter) and the dataset terms of [`nvidia/Nemotron-PII`](https://huggingface.co/datasets/nvidia/Nemotron-PII).
|
| 215 |
+
|
| 216 |
+
## Citation
|
| 217 |
+
|
| 218 |
+
```bibtex
|
| 219 |
+
@misc{haremb-privacy-filter-opennemo,
|
| 220 |
+
title = {HarEmb · OpenMed-Nemotron PII: a single-layer
|
| 221 |
+
privacy-filter slice with span-coherent inference},
|
| 222 |
+
author = {fblgit},
|
| 223 |
+
year = {2026},
|
| 224 |
+
publisher = {Hugging Face},
|
| 225 |
+
url = {https://huggingface.co/fblgit/haremb-privacy-filter-opennemo},
|
| 226 |
+
howpublished = {\url{https://huggingface.co/fblgit/haremb-privacy-filter-opennemo}},
|
| 227 |
+
note = {Single-transformer-layer model on the openai/privacy-filter →
|
| 228 |
+
OpenMed/privacy-filter-nemotron lineage; 287M total params,
|
| 229 |
+
221 BIOES classes (55 fine-grained PII categories), with
|
| 230 |
+
inlined constrained-BIOES Viterbi decoding so
|
| 231 |
+
outputs.logits.argmax(-1) is span-coherent.}
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
@misc{openmed-privacy-filter-nemotron,
|
| 235 |
+
title = {OpenMed/privacy-filter-nemotron: fine-grained PII extraction
|
| 236 |
+
with 55 categories},
|
| 237 |
+
author = {OpenMed},
|
| 238 |
+
year = {2026},
|
| 239 |
+
publisher = {Hugging Face},
|
| 240 |
+
url = {https://huggingface.co/OpenMed/privacy-filter-nemotron}
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
@misc{openai-privacy-filter,
|
| 244 |
+
title = {Privacy Filter},
|
| 245 |
+
author = {OpenAI},
|
| 246 |
+
year = {2026},
|
| 247 |
+
publisher = {Hugging Face},
|
| 248 |
+
url = {https://huggingface.co/openai/privacy-filter}
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
@misc{nvidia-nemotron-pii,
|
| 252 |
+
title = {Nemotron-PII},
|
| 253 |
+
author = {NVIDIA},
|
| 254 |
+
year = {2025},
|
| 255 |
+
publisher = {Hugging Face},
|
| 256 |
+
url = {https://huggingface.co/datasets/nvidia/Nemotron-PII}
|
| 257 |
+
}
|
| 258 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HarEmb PII — local Gradio inference demo.
|
| 3 |
+
|
| 4 |
+
Upload a PDF (or paste text), pick a device (CPU / cuda:N), and the model
|
| 5 |
+
highlights detected PII spans across the 55-category Nemotron-PII taxonomy.
|
| 6 |
+
|
| 7 |
+
Install:
|
| 8 |
+
pip install "gradio>=4" "transformers>=4.45" torch pypdf accelerate
|
| 9 |
+
|
| 10 |
+
Run from inside this folder:
|
| 11 |
+
python app.py
|
| 12 |
+
"""
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import argparse
|
| 16 |
+
import re
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Dict, List, Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import gradio as gr
|
| 21 |
+
import torch
|
| 22 |
+
from pypdf import PdfReader
|
| 23 |
+
from transformers import pipeline
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Default to loading from this folder so `python app.py` works in-place after
# downloading the repo. Override by setting --model-path on the CLI.
DEFAULT_MODEL = "."

# Character budget per inference chunk. Chosen so a chunk stays comfortably
# inside the model's context window.
CHUNK_CHARS = 400_000  # ~100k tokens; well under the model's 131k window
|
| 31 |
+
|
| 32 |
+
# 55 Nemotron-PII categories grouped for visual coherence; one color per
|
| 33 |
+
# coarse "family" so the highlight legend stays readable.
|
| 34 |
+
# Highlight colors for the 55 Nemotron-PII categories. Categories in the same
# coarse family share a color so the on-screen legend stays compact.
PALETTE: Dict[str, str] = {
    # Identity & demographics (reds)
    "first_name": "#ef4444",
    "last_name": "#ef4444",
    "user_name": "#ef4444",
    "company_name": "#ef4444",
    "age": "#fb7185",
    "gender": "#fb7185",
    "race_ethnicity": "#fb7185",
    "sexuality": "#fb7185",
    "religious_belief": "#fb7185",
    "political_view": "#fb7185",
    "language": "#fb7185",
    "education_level": "#fb7185",
    "occupation": "#fb7185",
    "employment_status": "#fb7185",
    "blood_type": "#fb7185",
    "biometric_identifier": "#fb7185",
    # Contact (purples)
    "email": "#8b5cf6",
    "phone_number": "#a78bfa",
    "fax_number": "#a78bfa",
    "url": "#7c3aed",
    # Address & location (greens)
    "street_address": "#10b981",
    "city": "#34d399",
    "county": "#34d399",
    "state": "#34d399",
    "country": "#34d399",
    "postcode": "#34d399",
    "coordinate": "#059669",
    # Dates & times (blues)
    "date": "#3b82f6",
    "date_of_birth": "#60a5fa",
    "date_time": "#60a5fa",
    "time": "#60a5fa",
    # Government IDs (oranges)
    "ssn": "#f97316",
    "national_id": "#fb923c",
    "tax_id": "#fb923c",
    # Financial & credentials (ambers)
    "account_number": "#f59e0b",
    "bank_routing_number": "#fbbf24",
    "swift_bic": "#fbbf24",
    "credit_debit_card": "#fbbf24",
    "cvv": "#fbbf24",
    "pin": "#fbbf24",
    "password": "#d97706",
    # Healthcare (pinks)
    "medical_record_number": "#ec4899",
    "health_plan_beneficiary_number": "#f472b6",
    # Enterprise identifiers (cyans)
    "customer_id": "#06b6d4",
    "employee_id": "#06b6d4",
    "unique_id": "#22d3ee",
    "certificate_license_number": "#22d3ee",
    # Vehicle (limes)
    "license_plate": "#84cc16",
    "vehicle_identifier": "#84cc16",
    # Digital identifiers (indigos)
    "ipv4": "#6366f1",
    "ipv6": "#6366f1",
    "mac_address": "#818cf8",
    "device_identifier": "#818cf8",
    "api_key": "#4f46e5",
    "http_cookie": "#4f46e5",
}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def list_devices() -> List[str]:
    """Enumerate selectable inference devices: CPU plus any visible CUDA GPUs."""
    gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
    return ["cpu"] + [f"cuda:{idx}" for idx in range(gpu_count)]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# Process-wide cache of loaded pipelines, keyed by (model_path, device) so the
# same weights are never loaded twice for one device.
_pipe_cache: Dict[Tuple[str, str], object] = {}
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def get_pipe(model_path: str, device: str):
    """Return a token-classification pipeline for (model_path, device).

    Pipelines are constructed lazily and memoized in the module-level
    `_pipe_cache`; repeat calls with the same key reuse the loaded weights.
    """
    cache_key = (model_path, device)
    try:
        return _pipe_cache[cache_key]
    except KeyError:
        pass
    # bf16 on GPU keeps memory low; CPU stays on fp32 for kernel coverage.
    load_dtype = torch.bfloat16 if device.startswith("cuda") else torch.float32
    fresh = pipeline(
        "token-classification",
        model=model_path,
        tokenizer=model_path,
        trust_remote_code=True,
        aggregation_strategy="simple",
        device=device,
        torch_dtype=load_dtype,
    )
    _pipe_cache[cache_key] = fresh
    return fresh
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def apply_runtime_config(
    pipe,
    use_viterbi: bool,
    viterbi_replace: bool,
    top_k: Optional[int] = None,
) -> None:
    """Mutate a loaded pipeline's config in place for the next inference run.

    Toggles the Viterbi-decode flags when the model's config exposes them,
    and — when ``top_k`` is given — overrides the MoE routing density on
    every layer. ``top_k=None`` leaves the trained routing config alone.
    """
    cfg = pipe.model.config
    if hasattr(cfg, "use_viterbi_decode"):
        cfg.use_viterbi_decode = bool(use_viterbi)
    if hasattr(cfg, "viterbi_replace_logits"):
        cfg.viterbi_replace_logits = bool(viterbi_replace)
    if top_k is None:
        return
    # Clamp the requested value into [1, num_local_experts]. Both fields must
    # be written: `router.top_k` is the actual router top-k, and the upstream
    # `mlp.num_experts` is misnamed — it is also the per-token top_k, not
    # the number of local experts.
    limit = int(getattr(cfg, "num_local_experts", 128))
    clamped = max(1, min(int(top_k), limit))
    for layer in pipe.model.model.layers:
        mlp = getattr(layer, "mlp", None)
        if mlp is None:
            continue
        router = getattr(mlp, "router", None)
        if router is not None and hasattr(router, "top_k"):
            router.top_k = clamped
        if hasattr(mlp, "num_experts"):
            mlp.num_experts = clamped
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def model_top_k_default(model_path: str) -> int:
    """Best-effort read of the trained `num_experts_per_tok` from the model
    config, without loading weights.

    Any failure (missing repo, unavailable transformers, absent field) falls
    back to 4.
    """
    fallback = 4
    try:
        from transformers import AutoConfig

        loaded = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        return int(getattr(loaded, "num_experts_per_tok", fallback))
    except Exception:
        return fallback
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def model_num_experts(model_path: str) -> int:
    """Best-effort read of `num_local_experts` from the model config, without
    loading weights.

    Any failure (missing repo, unavailable transformers, absent field) falls
    back to 128.
    """
    fallback = 128
    try:
        from transformers import AutoConfig

        loaded = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        return int(getattr(loaded, "num_local_experts", fallback))
    except Exception:
        return fallback
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def clear_model_cache() -> str:
    """Forget every cached pipeline and release CUDA allocator blocks.

    Returns a short status string for display in the UI.
    """
    _pipe_cache.clear()
    if torch.cuda.is_available():
        # Give the freed weights back to the driver, not just the allocator.
        torch.cuda.empty_cache()
    return "Model cache cleared. Next run will reload weights."
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def extract_text(file_obj) -> str:
    """Pull plain text out of an uploaded file.

    PDFs go through pypdf (pages joined by blank lines, empty pages kept as
    empty strings); everything else is read as UTF-8 with undecodable bytes
    replaced. ``None`` yields the empty string.
    """
    if file_obj is None:
        return ""
    # Gradio file objects carry a `.name` path; plain strings are used as-is.
    raw_path = getattr(file_obj, "name", file_obj)
    source = Path(raw_path)
    if source.suffix.lower() != ".pdf":
        return source.read_text(encoding="utf-8", errors="replace")
    pages = PdfReader(str(source)).pages
    return "\n\n".join(page.extract_text() or "" for page in pages)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def chunk_text(text: str, max_chars: int = CHUNK_CHARS) -> List[Tuple[int, str]]:
    """Split *text* into (char_offset, chunk) pieces of roughly max_chars.

    Splitting happens only at blank-line boundaries; the separators are kept
    with the chunks, so each offset indexes the chunk's exact position in the
    original string. Whitespace-only chunks are dropped. A short or empty
    text, or a non-positive budget, yields at most one chunk.
    """
    if not text:
        return []
    if max_chars <= 0 or len(text) <= max_chars:
        return [(0, text)]
    # Capturing group keeps the blank-line separators as their own pieces so
    # character offsets stay exact.
    pieces = re.split(r"(\n\s*\n)", text)
    out: List[Tuple[int, str]] = []
    buf = ""
    buf_start = 0
    cursor = 0
    for piece in pieces:
        if buf and len(buf) + len(piece) > max_chars and buf.strip():
            # Budget exceeded: flush the accumulated chunk, start a new one.
            out.append((buf_start, buf))
            buf, buf_start = piece, cursor
        else:
            if not buf:
                buf_start = cursor
            buf += piece
        cursor += len(piece)
    if buf.strip():
        out.append((buf_start, buf))
    return out
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def category_of(label: str) -> str:
    """Strip a BIOES prefix (``B-``/``I-``/``E-``/``S-``) from a label.

    Labels without such a prefix (e.g. ``O``) come back unchanged; a bare
    two-character prefix like ``I-`` is also returned as-is.
    """
    has_bioes_prefix = len(label) > 2 and label[1] == "-"
    return label[2:] if has_bioes_prefix else label
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def predict(
    model_path: str,
    device: str,
    text: str,
    aggregation: str,
    use_viterbi: bool,
    viterbi_replace: bool,
    top_k: Optional[int] = None,
    chunk_chars: int = CHUNK_CHARS,
) -> List[Dict]:
    """Run PII extraction over *text* and return span dicts sorted by position.

    Each span carries start/end offsets into the original text, the coarse
    category, the pipeline confidence, and the covered text. Entities whose
    category is not in PALETTE are silently dropped.
    """
    if not text.strip():
        return []
    pipe = get_pipe(model_path, device)
    apply_runtime_config(pipe, use_viterbi, viterbi_replace, top_k=top_k)
    found: List[Dict] = []
    for base, piece in chunk_text(text, max_chars=chunk_chars):
        for ent in pipe(piece, aggregation_strategy=aggregation):
            raw_label = ent.get("entity_group") or ent.get("entity") or ""
            category = category_of(raw_label)
            if category not in PALETTE:
                continue  # unmapped label families are not highlighted
            # Chunk-local offsets -> document offsets.
            start = ent["start"] + base
            end = ent["end"] + base
            found.append({
                "start": start, "end": end, "label": category,
                "score": float(ent["score"]),
                "text": text[start:end],
            })
    found.sort(key=lambda sp: (sp["start"], sp["end"]))
    return found
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def to_highlight(text: str, spans: List[Dict]) -> List[Tuple[str, Optional[str]]]:
    """Convert (text, spans) into Gradio HighlightedText segments.

    Produces (segment, label-or-None) tuples covering the whole string in
    order. Spans are assumed pre-sorted by start; a span overlapping an
    earlier one is skipped rather than split.
    """
    if not text:
        return []
    segments: List[Tuple[str, Optional[str]]] = []
    consumed = 0
    for span in spans:
        start, end = span["start"], span["end"]
        if start < consumed:
            continue  # overlaps the previous highlight — drop it
        if start > consumed:
            segments.append((text[consumed:start], None))
        segments.append((text[start:end], span["label"]))
        consumed = end
    if consumed < len(text):
        segments.append((text[consumed:], None))
    return segments
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def fmt_spans(spans: List[Dict], max_rows: int = 60) -> str:
    """Render detected spans as a markdown bullet list, capped at *max_rows*.

    Backticks are stripped from the snippet so it cannot break the inline
    code formatting; snippets are truncated to 80 characters.
    """
    if not spans:
        return "_No PII spans detected._"
    bullets = []
    for s in spans[:max_rows]:
        snippet = s["text"][:80].replace("`", "")
        bullets.append(f"- `{s['label']}` `{snippet}` (score {s['score']:.2f})")
    overflow = len(spans) - max_rows
    tail = f"\n\n_…+{overflow} more_" if overflow > 0 else ""
    return f"**Detected {len(spans)} span(s):**\n" + "\n".join(bullets) + tail
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# Build legend HTML for the categories present in PALETTE — one row per family
|
| 290 |
+
# (we still want it readable; show one swatch per unique color).
|
| 291 |
+
def _legend_html() -> str:
    """Build the compact color legend as inline HTML.

    One chip per unique palette color, labeled with the first category using
    that color plus a ``+n`` count for the rest of its family.
    """
    by_color: Dict[str, List[str]] = {}
    for cat, color in PALETTE.items():
        by_color.setdefault(color, []).append(cat)
    chips = []
    for color, cats in by_color.items():
        extra = f" +{len(cats) - 1}" if len(cats) > 1 else ""
        chips.append(
            f"<span style='background:{color};color:#fff;padding:.15rem .55rem;"
            f"border-radius:.3rem;font-family:monospace;'>"
            f"{cats[0]}{extra}</span>"
        )
    return ("<div style='display:flex;flex-wrap:wrap;gap:.4rem;font-size:.85rem;"
            "margin:.25rem 0;'>" + "".join(chips) + "</div>")
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# Rendered once at import time; the palette is static for the process.
LEGEND_HTML = _legend_html()
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def diff_spans(a: List[Dict], b: List[Dict]):
    """Return (only_in_a, only_in_b, agreed) span-lists. Keys are the
    (start, end, label) triple — agreement requires identical category."""
    # PEP 8 (E731): use a def, not a lambda bound to a name.
    def key(s: Dict) -> Tuple[int, int, str]:
        return (s["start"], s["end"], s["label"])

    sa = {key(s): s for s in a}
    sb = {key(s): s for s in b}
    only_a = [sa[k] for k in sa if k not in sb]
    only_b = [sb[k] for k in sb if k not in sa]
    both = [sa[k] for k in sa if k in sb]
    return only_a, only_b, both
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def fmt_diff(label_a: str, label_b: str,
             only_a: List[Dict], only_b: List[Dict], agreed: List[Dict]) -> str:
    """Render the A-only / B-only / agreed span lists as three Markdown sections."""
    def section(title: str, spans: List[Dict]) -> str:
        if not spans:
            return f"**{title}:** none"
        lines = []
        for s in spans[:30]:
            snippet = s["text"][:80].replace("`", "")
            # NOTE: two spaces before "(score" reproduce the original layout.
            lines.append(f"- `{s['label']}` `{snippet}`  (score {s['score']:.2f})")
        extra = f"\n …+{len(spans) - 30} more" if len(spans) > 30 else ""
        return f"**{title} ({len(spans)}):**\n" + "\n".join(lines) + extra

    parts = [
        section(f"Only {label_a}", only_a),
        section(f"Only {label_b}", only_b),
        section("Agreed by both", agreed),
    ]
    return "\n\n".join(parts)
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def run(
    file_obj, pasted_text, device,
    model_a_path, model_b_path,
    use_a, use_b,
    aggregation, use_viterbi, viterbi_replace,
    top_k_a, top_k_b,
    min_score, chunk_chars,
):
    """Gradio click handler: extract text, run the enabled model(s), filter
    by confidence, and return (highlight_a, highlight_b, diff_md, text)."""
    text = extract_text(file_obj) if file_obj else (pasted_text or "")
    if not text.strip():
        return [], [], "_Provide a PDF, a text file, or pasted text._", ""
    if not (use_a or use_b):
        return [], [], "_Enable at least one model._", text

    def detect(model_path, top_k):
        # Same prediction settings for both models; only path/top-k differ.
        return predict(model_path, device, text, aggregation,
                       use_viterbi, viterbi_replace,
                       top_k=int(top_k), chunk_chars=int(chunk_chars))

    a_spans = detect(model_a_path, top_k_a) if use_a else []
    b_spans = detect(model_b_path, top_k_b) if use_b else []

    threshold = float(min_score)
    a_spans = [s for s in a_spans if s["score"] >= threshold]
    b_spans = [s for s in b_spans if s["score"] >= threshold]

    a_hl = to_highlight(text, a_spans) if use_a else []
    b_hl = to_highlight(text, b_spans) if use_b else []

    label_a = Path(model_a_path).name or model_a_path
    label_b = Path(model_b_path).name or model_b_path

    if use_a and use_b:
        only_a, only_b, agreed = diff_spans(a_spans, b_spans)
        diff_md = fmt_diff(label_a, label_b, only_a, only_b, agreed)
    elif use_a:
        diff_md = fmt_spans(a_spans)
    elif use_b:
        diff_md = fmt_spans(b_spans)
    else:
        diff_md = "_Enable a model._"

    return a_hl, b_hl, diff_md, text
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def build_ui(default_model_a: str, default_model_b: str) -> gr.Blocks:
    """Assemble the side-by-side PII-comparison Gradio app.

    Args:
        default_model_a: initial path/repo for model A (teacher / baseline).
        default_model_b: initial path/repo for model B (this checkpoint).

    Returns:
        An un-launched gr.Blocks instance with all callbacks wired.
    """
    # Per-model MoE defaults drive the top-k slider ranges below.
    a_default_k = model_top_k_default(default_model_a)
    a_n_experts = model_num_experts(default_model_a)
    b_default_k = model_top_k_default(default_model_b)
    b_n_experts = model_num_experts(default_model_b)

    with gr.Blocks(title="HarEmb PII") as demo:
        gr.Markdown(
            "# HarEmb · OpenMed-Nemotron PII\n"
            "Detect PII across 55 categories of the Nemotron-PII taxonomy. "
            "Run **two models side-by-side** to compare detections — by "
            "default this checkpoint vs the OpenMed teacher it was distilled "
            "from. Disable one model to view a single detection."
        )
        # Device picker + cache reset row.
        devices = list_devices()
        with gr.Row():
            device_dd = gr.Dropdown(devices, value=devices[0], label="Device", scale=1)
            clear_btn = gr.Button("Clear model cache", variant="secondary", scale=1)

        # Two model columns: enable toggle, path textbox, MoE top-k slider.
        with gr.Row():
            with gr.Column():
                use_a = gr.Checkbox(value=True, label="Enable model A (teacher / baseline)")
                model_a_tb = gr.Textbox(
                    value=default_model_a,
                    label="Model A — path / HF repo",
                    info="Default: OpenMed/privacy-filter-nemotron (teacher).",
                )
                top_k_a_sl = gr.Slider(
                    1, a_n_experts, value=a_default_k, step=1,
                    label=f"Active experts per token (top-k of {a_n_experts})",
                    info=f"Trained value: {a_default_k}. Lower = faster + less "
                         f"capacity per token. Higher = more compute, denser "
                         f"routing. Bypassing the trained value can drop "
                         f"quality — useful for ablations.",
                )
            with gr.Column():
                use_b = gr.Checkbox(value=True, label="Enable model B (this checkpoint)")
                model_b_tb = gr.Textbox(
                    value=default_model_b,
                    label="Model B — path / HF repo",
                    info="Default: ./ (this checkpoint).",
                )
                top_k_b_sl = gr.Slider(
                    1, b_n_experts, value=b_default_k, step=1,
                    label=f"Active experts per token (top-k of {b_n_experts})",
                    info=f"Trained value: {b_default_k}.",
                )

        # Advanced decoding knobs, collapsed by default.
        with gr.Accordion("Inference settings", open=False):
            with gr.Row():
                aggregation_dd = gr.Dropdown(
                    ["simple", "first", "max", "average", "none"],
                    value="simple",
                    label="aggregation_strategy",
                    info="how token-level labels are merged into spans",
                )
                viterbi_cb = gr.Checkbox(
                    value=True,
                    label="use_viterbi_decode",
                    info="constrained BIOES decoding (off = raw argmax)",
                )
                viterbi_replace_cb = gr.Checkbox(
                    value=True,
                    label="viterbi_replace_logits",
                    info="when on, outputs.logits.argmax(-1) returns the Viterbi path",
                )
                min_score_sl = gr.Slider(
                    0.0, 1.0, value=0.0, step=0.01,
                    label="min confidence",
                    info="filter out spans with score below this threshold",
                )
                chunk_sl = gr.Slider(
                    0, 500_000, value=CHUNK_CHARS, step=10_000,
                    label="chunk size (chars)",
                    info="0 = single pass; otherwise split on paragraphs at this size. "
                         "Model window ≈131k tokens (~500k chars).",
                )

        # Input sources: uploaded file OR pasted text.
        with gr.Row():
            file_in = gr.File(label="PDF / text file", file_types=[".pdf", ".txt", ".md"])
            text_in = gr.Textbox(
                label="…or paste text",
                lines=6,
                placeholder=("Patient Sarah Johnson (DOB 03/15/1985), MRN 4872910, "
                             "phone 415-555-0123, email sarah.johnson@example.com."),
            )
        run_btn = gr.Button("Detect PII", variant="primary")

        # Outputs: legend, two highlight panes, diff/span markdown, raw text.
        gr.HTML(LEGEND_HTML)
        with gr.Row():
            a_out = gr.HighlightedText(
                label="Model A detections",
                color_map=PALETTE,
                show_legend=False,
                combine_adjacent=False,
            )
            b_out = gr.HighlightedText(
                label="Model B detections",
                color_map=PALETTE,
                show_legend=False,
                combine_adjacent=False,
            )
        diff_out = gr.Markdown("_Run a detection to see the diff / span list._")
        extracted_out = gr.Textbox(
            label="Extracted text (read-only)", lines=6, interactive=False,
        )

        # Wire the detect button; input order must match run()'s signature.
        run_btn.click(
            run,
            [file_in, text_in, device_dd,
             model_a_tb, model_b_tb, use_a, use_b,
             aggregation_dd, viterbi_cb, viterbi_replace_cb,
             top_k_a_sl, top_k_b_sl,
             min_score_sl, chunk_sl],
            [a_out, b_out, diff_out, extracted_out],
        )
        clear_btn.click(clear_model_cache, None, diff_out)

    return demo
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options."""
    parser = argparse.ArgumentParser(description="HarEmb PII — Gradio demo")
    parser.add_argument("--host", default="127.0.0.1",
                        help="Bind address (default: 127.0.0.1)")
    parser.add_argument("--port", type=int, default=7860,
                        help="Port (default: 7860)")
    parser.add_argument("--share", action="store_true",
                        help="Create a public Gradio share link")
    parser.add_argument("--model-a", default="OpenMed/privacy-filter-nemotron",
                        help="Model A path / HF repo "
                             "(default: OpenMed/privacy-filter-nemotron — teacher)")
    parser.add_argument("--model-b", default=DEFAULT_MODEL,
                        help="Model B path / HF repo "
                             "(default: . — this checkpoint)")
    return parser.parse_args()
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
if __name__ == "__main__":
    args = parse_args()
    demo = build_ui(
        default_model_a=args.model_a,
        default_model_b=args.model_b,
    )
    # BUG FIX: `theme` is a gr.Blocks(...) constructor argument, not a
    # Blocks.launch() argument — forwarding it to launch() raises a
    # TypeError on current Gradio, so it must not be passed here.
    demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
    )
|
benchmark.py
ADDED
|
@@ -0,0 +1,1351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
benchmark.py — single self-contained reproducibility script for
|
| 3 |
+
haremb-privacy-filter-opennemo.
|
| 4 |
+
|
| 5 |
+
Run from inside this folder:
|
| 6 |
+
|
| 7 |
+
python benchmark.py # default: cuda if available
|
| 8 |
+
python benchmark.py --device cpu # cpu fallback
|
| 9 |
+
python benchmark.py --eval-pct 0.5 # smaller slice
|
| 10 |
+
python benchmark.py --no-base # skip teacher download
|
| 11 |
+
|
| 12 |
+
Produces, in `--out` (default ./):
|
| 13 |
+
infer.log — sample inference timing + redaction example
|
| 14 |
+
compare.log — aggregate + per-category metrics, this model vs
|
| 15 |
+
OpenMed teacher (raw + viterbi streams), and
|
| 16 |
+
token-level pairwise breakdown.
|
| 17 |
+
eval_summary.png — bar charts of headline metrics + per-category
|
| 18 |
+
span-F1 (this vs teacher).
|
| 19 |
+
eval_confusion.png — token-level outcome breakdown on gold non-O
|
| 20 |
+
positions (this vs teacher).
|
| 21 |
+
eval_performance.png — model-size / compute / memory / throughput
|
| 22 |
+
comparison (this vs teacher), absolute + ratios.
|
| 23 |
+
|
| 24 |
+
This script does not import from training code. It vendors the small set
|
| 25 |
+
of helpers it needs (BIOES decoder, span builder, eval-set sampler,
|
| 26 |
+
metrics aggregator) so the model folder is self-contained.
|
| 27 |
+
"""
|
| 28 |
+
from __future__ import annotations
|
| 29 |
+
|
| 30 |
+
import argparse
|
| 31 |
+
import ast
|
| 32 |
+
import math
|
| 33 |
+
import os
|
| 34 |
+
import sys
|
| 35 |
+
import time
|
| 36 |
+
from collections import Counter, defaultdict
|
| 37 |
+
from pathlib import Path
|
| 38 |
+
from typing import Dict, List, Tuple
|
| 39 |
+
|
| 40 |
+
import numpy as np
|
| 41 |
+
import torch
|
| 42 |
+
from datasets import load_dataset
|
| 43 |
+
from torch.utils.data import DataLoader, Dataset
|
| 44 |
+
from tqdm.auto import tqdm
|
| 45 |
+
from transformers import AutoModelForTokenClassification, AutoTokenizer
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# Constants
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
|
| 52 |
+
# Dataset the eval slice is sampled from, and the teacher checkpoint compared
# against in compare.log.
SOURCE_DATASET = "nvidia/Nemotron-PII"
TEACHER = "OpenMed/privacy-filter-nemotron"

# 55 Nemotron-PII categories, alphabetically sorted (matches the order used
# when the model was trained, so id2label / label2id round-trip cleanly).
NEMOTRON_CATEGORIES: List[str] = sorted([
    "account_number", "age", "api_key", "bank_routing_number",
    "biometric_identifier", "blood_type", "certificate_license_number",
    "city", "company_name", "coordinate", "country", "county",
    "credit_debit_card", "customer_id", "cvv", "date", "date_of_birth",
    "date_time", "device_identifier", "education_level", "email",
    "employee_id", "employment_status", "fax_number", "first_name",
    "gender", "health_plan_beneficiary_number", "http_cookie", "ipv4",
    "ipv6", "language", "last_name", "license_plate", "mac_address",
    "medical_record_number", "national_id", "occupation", "password",
    "phone_number", "pin", "political_view", "postcode", "race_ethnicity",
    "religious_belief", "sexuality", "ssn", "state", "street_address",
    "swift_bic", "tax_id", "time", "unique_id", "url", "user_name",
    "vehicle_identifier",
])
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def nemotron_native_label_space() -> Tuple[Dict[str, int], Dict[int, str]]:
    """O at id 0, then {B, I, E, S}-{cat} for each cat in alphabetical order."""
    label2id: Dict[str, int] = {"O": 0}
    next_id = 1
    for category in NEMOTRON_CATEGORIES:
        for prefix in ("B", "I", "E", "S"):
            label2id[f"{prefix}-{category}"] = next_id
            next_id += 1
    # Inverse map; ids are unique by construction.
    id2label: Dict[int, str] = {idx: tag for tag, idx in label2id.items()}
    return label2id, id2label
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
# Span parsing + char-level token alignment (vendored from the training data
|
| 88 |
+
# pipeline; identical logic, no training imports)
|
| 89 |
+
# ---------------------------------------------------------------------------
|
| 90 |
+
|
| 91 |
+
def _trim_span(text: str, start: int, end: int) -> Tuple[int, int]:
|
| 92 |
+
raw = text[start:end]
|
| 93 |
+
i = 0
|
| 94 |
+
while i < len(raw) and raw[i].isspace():
|
| 95 |
+
i += 1
|
| 96 |
+
j = len(raw)
|
| 97 |
+
while j > i and (raw[j - 1].isspace() or raw[j - 1] in ".,;:)"):
|
| 98 |
+
j -= 1
|
| 99 |
+
return start + i, start + j
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _parse_spans(spans_str) -> List[dict]:
|
| 103 |
+
if isinstance(spans_str, list):
|
| 104 |
+
return spans_str
|
| 105 |
+
try:
|
| 106 |
+
return ast.literal_eval(spans_str)
|
| 107 |
+
except (SyntaxError, ValueError):
|
| 108 |
+
return []
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _assign_native_bioes_labels(
    text: str,
    raw_spans: List[dict],
    tokenizer,
    max_length: int,
    label2id: Dict[str, int],
    min_overlap_frac: float = 0.5,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Tokenize `text` and project char-level PII spans onto BIOES token labels.

    A token is claimed by a span when at least `min_overlap_frac` of the
    token's characters fall inside the span. Spans are processed by start
    position (longer first at equal start) and lock their tokens, so later
    overlapping spans cannot relabel them.

    Returns:
        (input_ids, attention_mask, label_ids), all length `max_length`;
        padding positions in label_ids are set to -100.
    """
    # Normalize raw spans to trimmed (start, end, category) triples.
    pf_like: List[Tuple[int, int, str]] = []
    for s in raw_spans:
        cat = s.get("label")
        if not cat:
            continue
        st, en = _trim_span(text, int(s["start"]), int(s["end"]))
        if st >= en:
            continue  # span collapsed to nothing after trimming
        pf_like.append((st, en, cat))
    pf_like.sort(key=lambda x: (x[0], -x[1]))  # by start, longer spans first

    enc = tokenizer(
        text, truncation=True, max_length=max_length,
        padding="max_length", return_offsets_mapping=True, return_tensors="pt",
    )
    input_ids = enc.input_ids[0]
    attention_mask = enc.attention_mask[0]
    offsets = enc.offset_mapping[0].tolist()

    o_id = label2id["O"]
    label_ids = [o_id] * len(input_ids)
    locked = [False] * len(input_ids)  # tokens already claimed by a span

    for span_start, span_end, cat in pf_like:
        # Collect tokens whose overlap with the span is >= min_overlap_frac.
        tok_indices: List[int] = []
        for ti, (s, e) in enumerate(offsets):
            if s == 0 and e == 0:
                continue  # special / padding token (zero offsets)
            if e <= span_start or s >= span_end:
                continue  # no overlap at all
            tok_len = e - s
            if tok_len <= 0:
                continue
            overlap = min(e, span_end) - max(s, span_start)
            if overlap / tok_len >= min_overlap_frac:
                tok_indices.append(ti)

        if not tok_indices:
            continue
        # Respect tokens already assigned by an earlier (higher-priority) span.
        tok_indices = [ti for ti in tok_indices if not locked[ti]]
        if not tok_indices:
            continue

        if len(tok_indices) == 1:
            # Single-token span -> S tag.
            tag = f"S-{cat}"
            if tag in label2id:
                label_ids[tok_indices[0]] = label2id[tag]
                locked[tok_indices[0]] = True
        else:
            # Multi-token span -> B ... I ... E tags.
            b_tag, i_tag, e_tag = f"B-{cat}", f"I-{cat}", f"E-{cat}"
            if b_tag in label2id:
                label_ids[tok_indices[0]] = label2id[b_tag]
                locked[tok_indices[0]] = True
            for ti in tok_indices[1:-1]:
                if i_tag in label2id:
                    label_ids[ti] = label2id[i_tag]
                    locked[ti] = True
            if e_tag in label2id:
                label_ids[tok_indices[-1]] = label2id[e_tag]
                locked[tok_indices[-1]] = True

    label_tensor = torch.tensor(label_ids, dtype=torch.long)
    # Padding positions are ignored by the loss.
    label_tensor[attention_mask == 0] = -100
    return input_ids, attention_mask, label_tensor
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class _NemotronEvalDataset(Dataset):
    """Torch dataset over an HF split; each item carries variable-length
    token ids with native BIOES labels (see _assign_native_bioes_labels)."""

    def __init__(self, hf_split, tokenizer, label2id, max_length):
        # hf_split: indexable HF dataset with "text" and "spans" columns.
        self.hf = hf_split
        self.tok = tokenizer
        self.l2i = label2id
        self.maxlen = max_length

    def __len__(self):
        return len(self.hf)

    def __getitem__(self, idx):
        ex = self.hf[idx]
        ids, mask, labels = _assign_native_bioes_labels(
            ex["text"], _parse_spans(ex["spans"]),
            self.tok, self.maxlen, self.l2i,
        )
        # Keep only the attended prefix; the collate fn re-pads per batch.
        L = int(mask.sum().item())
        return {
            "input_ids": ids[:L].tolist(),
            "labels": labels[:L].tolist(),
            "valid_len": L,
        }
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def _make_collate(pad_token_id, max_length):
|
| 210 |
+
def _c(batch):
|
| 211 |
+
ids_list = [list(ex["input_ids"])[:max_length] for ex in batch]
|
| 212 |
+
labels_list = [list(ex["labels"])[:max_length] for ex in batch]
|
| 213 |
+
max_len = max(len(x) for x in ids_list)
|
| 214 |
+
B = len(batch)
|
| 215 |
+
input_ids = torch.full((B, max_len), pad_token_id, dtype=torch.long)
|
| 216 |
+
attention_mask = torch.zeros((B, max_len), dtype=torch.long)
|
| 217 |
+
labels = torch.full((B, max_len), -100, dtype=torch.long)
|
| 218 |
+
for i, (ids, lab) in enumerate(zip(ids_list, labels_list)):
|
| 219 |
+
L = len(ids)
|
| 220 |
+
input_ids[i, :L] = torch.tensor(ids, dtype=torch.long)
|
| 221 |
+
attention_mask[i, :L] = 1
|
| 222 |
+
labels[i, :L] = torch.tensor(lab, dtype=torch.long)
|
| 223 |
+
return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
|
| 224 |
+
return _c
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _build_eval_streaming(test_split, target_n, chunk_size, seed) -> List[int]:
|
| 228 |
+
"""Uniform chunked sampling, identical to the training-time eval split."""
|
| 229 |
+
n_total = len(test_split)
|
| 230 |
+
target_n = min(target_n, n_total)
|
| 231 |
+
if target_n <= 0:
|
| 232 |
+
return []
|
| 233 |
+
rng = np.random.RandomState(seed)
|
| 234 |
+
per_chunk = max(1, math.ceil(chunk_size * target_n / n_total))
|
| 235 |
+
selected: List[int] = []
|
| 236 |
+
for chunk_start in range(0, n_total, chunk_size):
|
| 237 |
+
if len(selected) >= target_n:
|
| 238 |
+
break
|
| 239 |
+
chunk_end = min(chunk_start + chunk_size, n_total)
|
| 240 |
+
n_in_chunk = chunk_end - chunk_start
|
| 241 |
+
n_to_pick = min(per_chunk, n_in_chunk, target_n - len(selected))
|
| 242 |
+
if n_to_pick <= 0:
|
| 243 |
+
break
|
| 244 |
+
offsets = rng.choice(n_in_chunk, size=n_to_pick, replace=False)
|
| 245 |
+
selected.extend(int(chunk_start + o) for o in offsets)
|
| 246 |
+
return sorted(selected[:target_n])
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# ---------------------------------------------------------------------------
|
| 250 |
+
# BIOES → spans + metrics
|
| 251 |
+
# ---------------------------------------------------------------------------
|
| 252 |
+
|
| 253 |
+
def _bioes_to_spans(labels, id2label, o_id=0):
    """Convert per-token BIOES label ids to a set of (start, end, cat).

    Tolerant of malformed sequences: orphan I starts a new span, orphan E
    emits a single-token span, and a span still open at sequence end is
    flushed. `end` is exclusive.
    """
    spans = set()
    cur_start = None  # start index of the span currently being built
    cur_cat = None    # category of that open span
    for i, lid in enumerate(labels):
        lid = int(lid)
        # O or ignored (-100 / negative) position: close any open span.
        if lid == o_id or lid < 0:
            if cur_start is not None:
                spans.add((cur_start, i, cur_cat))
            cur_start = None
            cur_cat = None
            continue
        tag = id2label.get(lid, "O")
        # Unknown or malformed tag behaves like O.
        if tag == "O" or "-" not in tag:
            if cur_start is not None:
                spans.add((cur_start, i, cur_cat))
            cur_start = None
            cur_cat = None
            continue
        prefix, cat = tag.split("-", 1)
        if prefix == "S":
            # Single-token span; also closes any span left open.
            if cur_start is not None:
                spans.add((cur_start, i, cur_cat))
            spans.add((i, i + 1, cat))
            cur_start = None
            cur_cat = None
        elif prefix == "B":
            # Begin a new span, implicitly closing the previous one.
            if cur_start is not None:
                spans.add((cur_start, i, cur_cat))
            cur_start = i
            cur_cat = cat
        elif prefix == "I":
            # Orphan I (no open span, or category switch) starts fresh.
            if cur_start is None or cur_cat != cat:
                if cur_start is not None:
                    spans.add((cur_start, i, cur_cat))
                cur_start = i
                cur_cat = cat
        elif prefix == "E":
            if cur_start is None or cur_cat != cat:
                # Orphan E: close any mismatched span, emit a 1-token span.
                if cur_start is not None:
                    spans.add((cur_start, i, cur_cat))
                spans.add((i, i + 1, cat))
                cur_start = None
                cur_cat = None
            else:
                # Proper end of the open span (this token inclusive).
                spans.add((cur_start, i + 1, cur_cat))
                cur_start = None
                cur_cat = None
    # Flush a span left open at sequence end.
    if cur_start is not None:
        spans.add((cur_start, len(labels), cur_cat))
    return spans
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def _aggregate_span_metrics(gold_spans, pred_spans):
|
| 308 |
+
correct = gold_spans & pred_spans
|
| 309 |
+
n_gold = len(gold_spans)
|
| 310 |
+
n_pred = len(pred_spans)
|
| 311 |
+
n_correct = len(correct)
|
| 312 |
+
p = n_correct / n_pred if n_pred else 0.0
|
| 313 |
+
r = n_correct / n_gold if n_gold else 0.0
|
| 314 |
+
f1 = (2 * p * r / (p + r)) if (p + r) else 0.0
|
| 315 |
+
per_cat: Dict[str, dict] = {}
|
| 316 |
+
cats = sorted({c for _, _, c in gold_spans} | {c for _, _, c in pred_spans})
|
| 317 |
+
for cat in cats:
|
| 318 |
+
g_c = {s for s in gold_spans if s[2] == cat}
|
| 319 |
+
p_c = {s for s in pred_spans if s[2] == cat}
|
| 320 |
+
c_c = g_c & p_c
|
| 321 |
+
pp = len(c_c) / len(p_c) if p_c else 0.0
|
| 322 |
+
rr = len(c_c) / len(g_c) if g_c else 0.0
|
| 323 |
+
ff = (2 * pp * rr / (pp + rr)) if (pp + rr) else 0.0
|
| 324 |
+
per_cat[cat] = {"precision": pp, "recall": rr, "f1": ff,
|
| 325 |
+
"n_gold": len(g_c), "n_pred": len(p_c), "n_correct": len(c_c)}
|
| 326 |
+
return {"precision": p, "recall": r, "f1": f1,
|
| 327 |
+
"n_gold": n_gold, "n_pred": n_pred, "n_correct": n_correct,
|
| 328 |
+
"per_cat": per_cat}
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def _stream_metrics(docs, stream, id2label, o_id):
    """Aggregate metrics over a list of {gold, raw, viterbi} per-doc dicts.

    ``docs`` holds parallel label-id sequences under "gold" and under the
    prediction stream named by ``stream`` ("raw" or "viterbi").  Span sets
    from different docs are pooled after shifting positions by a running
    document offset, so identical spans in different docs stay distinct.
    """
    total_tokens = 0
    total_correct = 0
    gold_entity_tokens = 0
    correct_entity_tokens = 0
    pooled_gold: set = set()
    pooled_pred: set = set()
    offset = 0
    for doc in docs:
        gold_ids = [int(v) for v in doc["gold"]]
        pred_ids = [int(v) for v in doc[stream]]
        total_tokens += len(gold_ids)
        for g, p in zip(gold_ids, pred_ids):
            if g != o_id:
                gold_entity_tokens += 1
            if g == p:
                total_correct += 1
                if g != o_id:
                    correct_entity_tokens += 1
        for s, e, c in _bioes_to_spans(gold_ids, id2label, o_id):
            pooled_gold.add((offset + s, offset + e, c))
        for s, e, c in _bioes_to_spans(pred_ids, id2label, o_id):
            pooled_pred.add((offset + s, offset + e, c))
        offset += len(gold_ids)
    span_m = _aggregate_span_metrics(pooled_gold, pooled_pred)
    return {
        "n_tokens": total_tokens,
        "n_non_o": gold_entity_tokens,
        "token_acc": total_correct / total_tokens if total_tokens else 0.0,
        "non_o_recall": correct_entity_tokens / gold_entity_tokens if gold_entity_tokens else 0.0,
        "span_precision": span_m["precision"],
        "span_recall": span_m["recall"],
        "span_f1": span_m["f1"],
        "n_gold_spans": span_m["n_gold"],
        "n_pred_spans": span_m["n_pred"],
        "n_correct_spans": span_m["n_correct"],
        "span_per_cat": span_m["per_cat"],
    }
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
# ---------------------------------------------------------------------------
|
| 370 |
+
# Forward pass + viterbi (delegates to the released modeling)
|
| 371 |
+
# ---------------------------------------------------------------------------
|
| 372 |
+
|
| 373 |
+
def _model_perf_stats(model, dtype) -> dict:
|
| 374 |
+
"""Total / active / compute / MoE breakdown + on-device byte size.
|
| 375 |
+
|
| 376 |
+
Three distinct param counts:
|
| 377 |
+
|
| 378 |
+
* total_params — every parameter the model has on disk / in RAM.
|
| 379 |
+
* active_params_per_tok — params *touched* during one token's forward
|
| 380 |
+
pass (memory footprint per token). Counts
|
| 381 |
+
the embedding because the embedding row
|
| 382 |
+
IS read per token; counts only top-k of
|
| 383 |
+
num_experts MoE experts because routing is
|
| 384 |
+
sparse.
|
| 385 |
+
* compute_params_per_tok — params that contribute matmul FLOPs per
|
| 386 |
+
token. EXCLUDES the embedding table:
|
| 387 |
+
`embed_tokens.weight` is a gather (one row
|
| 388 |
+
read), not a matmul, so its FLOP cost is
|
| 389 |
+
negligible (~hidden_size ops vs the table
|
| 390 |
+
having ~vocab × hidden params). Counting
|
| 391 |
+
it via the standard "2 × params" matmul
|
| 392 |
+
approximation hugely inflates the apparent
|
| 393 |
+
GFLOP/token and compresses the ratio between
|
| 394 |
+
deep and shallow models.
|
| 395 |
+
|
| 396 |
+
GFLOP/token is computed from `compute_params_per_tok`, not from
|
| 397 |
+
`active_params_per_tok`. This makes the metric reflect actual layer-wise
|
| 398 |
+
computational cost.
|
| 399 |
+
"""
|
| 400 |
+
cfg = model.config
|
| 401 |
+
num_experts = int(getattr(cfg, "num_local_experts", 1))
|
| 402 |
+
top_k = int(getattr(cfg, "num_experts_per_tok", num_experts))
|
| 403 |
+
expert_frac = top_k / max(1, num_experts)
|
| 404 |
+
|
| 405 |
+
moe_total = 0
|
| 406 |
+
moe_active = 0
|
| 407 |
+
other_total = 0
|
| 408 |
+
embed_total = 0 # gather-only params; excluded from FLOP estimate
|
| 409 |
+
for name, p in model.named_parameters():
|
| 410 |
+
n = p.numel()
|
| 411 |
+
# MoE expert tensors are stacked along an experts axis. The upstream
|
| 412 |
+
# exposes them under `mlp.experts.*`. Only `top_k` of `num_experts`
|
| 413 |
+
# experts contribute per token.
|
| 414 |
+
if ".mlp.experts." in name:
|
| 415 |
+
moe_total += n
|
| 416 |
+
moe_active += int(round(n * expert_frac))
|
| 417 |
+
# `embed_tokens.weight` (and any other lookup-style table) is a
|
| 418 |
+
# gather: one row of [vocab, hidden] is read per token, costing
|
| 419 |
+
# ~hidden ops, not 2 × vocab × hidden FLOPs. Tracked separately
|
| 420 |
+
# so it doesn't pollute the FLOP estimate.
|
| 421 |
+
elif "embed_tokens" in name:
|
| 422 |
+
embed_total += n
|
| 423 |
+
other_total += n
|
| 424 |
+
else:
|
| 425 |
+
other_total += n
|
| 426 |
+
|
| 427 |
+
total = moe_total + other_total
|
| 428 |
+
active_per_tok = moe_active + other_total
|
| 429 |
+
|
| 430 |
+
# Compute-relevant params per token: matmuls only. Drop the embedding
|
| 431 |
+
# lookup, which contributes ~zero FLOPs.
|
| 432 |
+
compute_per_tok = active_per_tok - embed_total
|
| 433 |
+
bytes_per_param = {torch.bfloat16: 2, torch.float16: 2, torch.float32: 4}.get(dtype, 4)
|
| 434 |
+
storage_bytes = total * bytes_per_param # in-memory dense weights
|
| 435 |
+
# GFLOP/token over matmul params only. Matmul ≈ 2 FLOPs per param.
|
| 436 |
+
gflops_per_tok = 2 * compute_per_tok / 1e9
|
| 437 |
+
return {
|
| 438 |
+
"total_params": total,
|
| 439 |
+
"moe_total": moe_total,
|
| 440 |
+
"moe_active_per_tok": moe_active,
|
| 441 |
+
"other_total": other_total,
|
| 442 |
+
"embed_total": embed_total,
|
| 443 |
+
"active_params_per_tok": active_per_tok,
|
| 444 |
+
"compute_params_per_tok": compute_per_tok,
|
| 445 |
+
"num_experts": num_experts,
|
| 446 |
+
"experts_per_tok": top_k,
|
| 447 |
+
"expert_frac": expert_frac,
|
| 448 |
+
"weight_bytes": storage_bytes,
|
| 449 |
+
"gflops_per_tok": gflops_per_tok,
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def _disk_size_bytes(model_path: str) -> int:
|
| 454 |
+
"""Sum on-disk size of weight files at the given path. Falls back to 0
|
| 455 |
+
if the path is a HF repo id (not a local directory)."""
|
| 456 |
+
p = Path(model_path)
|
| 457 |
+
if not p.is_dir():
|
| 458 |
+
return 0
|
| 459 |
+
total = 0
|
| 460 |
+
for f in p.iterdir():
|
| 461 |
+
if f.is_file() and f.suffix in {".safetensors", ".bin", ".pt", ".pth"}:
|
| 462 |
+
total += f.stat().st_size
|
| 463 |
+
return total
|
| 464 |
+
|
| 465 |
+
|
| 466 |
+
def _eval_one_model(
    model_path: str, tokenizer, eval_ds, label2id, id2label, o_id,
    bioes_trans, bioes_init, batch_size, max_length, device, dtype,
    label: str,
):
    """Load one checkpoint, score it over ``eval_ds``, and free it.

    Returns a dict with raw/viterbi `_stream_metrics`, per-doc predictions
    (``docs``), load/eval wall time, throughput, and the `_model_perf_stats`
    profile (augmented with disk size, resident weight bytes, and peak eval
    memory).  The model is deleted before returning.
    """
    print(f"[eval] loading {label} from {model_path} ...", flush=True)
    if torch.cuda.is_available() and device.type == "cuda":
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.empty_cache()
    mem_before = (torch.cuda.memory_allocated(device)
                  if torch.cuda.is_available() and device.type == "cuda" else 0)
    t_load = time.time()
    model = AutoModelForTokenClassification.from_pretrained(
        model_path, dtype=dtype, trust_remote_code=True,
    ).to(device).eval()
    # Force viterbi decoding off-model so we can compare raw vs viterbi
    # streams from the same forward pass.
    if hasattr(model.config, "use_viterbi_decode"):
        model.config.use_viterbi_decode = True
    if hasattr(model.config, "viterbi_replace_logits"):
        model.config.viterbi_replace_logits = False
    load_s = time.time() - t_load
    perf = _model_perf_stats(model, dtype)
    perf["disk_size_bytes"] = _disk_size_bytes(model_path)
    weights_resident_bytes = (
        torch.cuda.memory_allocated(device) - mem_before
        if torch.cuda.is_available() and device.type == "cuda" else perf["weight_bytes"]
    )
    perf["weights_resident_bytes"] = weights_resident_bytes

    # Vendor the batched viterbi (re-import from the released modeling file).
    from modeling_haremb_pii import _bioes_viterbi_batched

    # BUG FIX: `tokenizer.pad_token_id or 199999` silently replaced a
    # legitimate pad id of 0 with the fallback — test for None explicitly.
    # (199999 is presumably this tokenizer family's pad id — confirm.)
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 199999
    loader = DataLoader(
        eval_ds, batch_size=batch_size, shuffle=False,
        collate_fn=_make_collate(pad_token_id, max_length), num_workers=2,
    )
    docs: List[dict] = []
    n_tok = 0
    if torch.cuda.is_available() and device.type == "cuda":
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats(device)
    t0 = time.time()
    for batch in tqdm(loader, desc=f"eval {label}", unit="batch", leave=False):
        ids = batch["input_ids"].to(device, non_blocking=True)
        mask = batch["attention_mask"].to(device, non_blocking=True)
        gold = batch["labels"].to(device, non_blocking=True)
        with torch.no_grad():
            out = model(input_ids=ids, attention_mask=mask)
        raw = out.logits.argmax(dim=-1)
        vit = _bioes_viterbi_batched(out.logits.float(), mask, bioes_trans, bioes_init)
        # Score only real, labeled positions (-100 marks ignored tokens).
        valid = (gold != -100) & mask.bool()
        for b in range(gold.shape[0]):
            keep = [i for i, ok in enumerate(valid[b].cpu().tolist()) if ok]
            n_tok += len(keep)
            docs.append({
                "gold": [int(gold[b, i].item()) for i in keep],
                "raw": [int(raw[b, i].item()) for i in keep],
                "viterbi": [int(vit[b, i].item()) for i in keep],
            })
    if torch.cuda.is_available() and device.type == "cuda":
        torch.cuda.synchronize()
        peak_mem = torch.cuda.max_memory_allocated(device)
    else:
        peak_mem = 0
    eval_s = time.time() - t0

    raw_m = _stream_metrics(docs, "raw", id2label, o_id)
    vit_m = _stream_metrics(docs, "viterbi", id2label, o_id)

    perf["peak_eval_mem_bytes"] = peak_mem

    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {
        "label": label,
        "n_total_M": perf["total_params"] / 1e6,
        "load_s": load_s,
        "eval_s": eval_s,
        "n_tok": n_tok,
        "throughput_tok_s": n_tok / eval_s if eval_s else 0.0,
        "perf": perf,
        "raw": raw_m,
        "viterbi": vit_m,
        "docs": docs,
    }
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
# ---------------------------------------------------------------------------
|
| 556 |
+
# Pairwise token-level breakdown (this vs reference, viterbi stream)
|
| 557 |
+
# ---------------------------------------------------------------------------
|
| 558 |
+
|
| 559 |
+
def _pairwise(docs_cand, docs_ref, id2label, o_id):
|
| 560 |
+
both_correct = only_cand = only_ref = both_wrong = 0
|
| 561 |
+
by_cat = defaultdict(lambda: {"both_correct": 0, "only_cand": 0, "only_ref": 0, "both_wrong": 0})
|
| 562 |
+
for dc, dr in zip(docs_cand, docs_ref):
|
| 563 |
+
gold = dc["gold"]
|
| 564 |
+
cv = dc["viterbi"]
|
| 565 |
+
rv = dr["viterbi"]
|
| 566 |
+
for g, c, r in zip(gold, cv, rv):
|
| 567 |
+
cat = id2label.get(g, "O")
|
| 568 |
+
cat = cat.split("-", 1)[1] if "-" in cat else cat
|
| 569 |
+
cc = (c == g)
|
| 570 |
+
rc = (r == g)
|
| 571 |
+
if cc and rc:
|
| 572 |
+
both_correct += 1
|
| 573 |
+
by_cat[cat]["both_correct"] += 1
|
| 574 |
+
elif cc and not rc:
|
| 575 |
+
only_cand += 1
|
| 576 |
+
by_cat[cat]["only_cand"] += 1
|
| 577 |
+
elif rc and not cc:
|
| 578 |
+
only_ref += 1
|
| 579 |
+
by_cat[cat]["only_ref"] += 1
|
| 580 |
+
else:
|
| 581 |
+
both_wrong += 1
|
| 582 |
+
by_cat[cat]["both_wrong"] += 1
|
| 583 |
+
return {
|
| 584 |
+
"both_correct": both_correct,
|
| 585 |
+
"only_cand_correct": only_cand,
|
| 586 |
+
"only_ref_correct": only_ref,
|
| 587 |
+
"both_wrong": both_wrong,
|
| 588 |
+
"by_cat": dict(by_cat),
|
| 589 |
+
}
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
# ---------------------------------------------------------------------------
|
| 593 |
+
# Plot rendering
|
| 594 |
+
# ---------------------------------------------------------------------------
|
| 595 |
+
|
| 596 |
+
def _render_plots(cand, ref, pair, out_dir: Path, cand_label, ref_label):
    """Render benchmark plots.

    Visual convention:
      A = reference / teacher / baseline
      B = candidate / this checkpoint

    The charts avoid color-only "win" encoding: labels state the actual delta
    or ratio, and horizontal layouts keep long metric/category names readable.

    Writes up to three files into `out_dir`:
      - eval_summary.png      (always)
      - eval_confusion.png    (only when `pair` is not None)
      - eval_performance.png  (only when `ref` carries a "perf" profile)
    No-op (with a message) if matplotlib is not installed.
    """
    # Matplotlib is an optional dependency for this script: bail out quietly.
    try:
        import matplotlib
        matplotlib.use("Agg")  # headless backend; never opens a window
        import matplotlib.pyplot as plt
        from matplotlib.ticker import FuncFormatter
    except ImportError:
        print("[plot] matplotlib not installed, skipping", flush=True)
        return

    # Shared light theme for all three figures.
    plt.rcParams.update({
        "figure.facecolor": "#ffffff",
        "axes.facecolor": "#ffffff",
        "axes.edgecolor": "#cbd5e1",
        "axes.labelcolor": "#0f172a",
        "xtick.color": "#334155",
        "ytick.color": "#334155",
        "grid.color": "#e2e8f0",
        "font.size": 9,
        "axes.titleweight": "bold",
        "axes.titlesize": 11,
        "legend.frameon": False,
    })

    C_REF = "#64748b"   # slate
    C_CAND = "#2563eb"  # blue
    C_GOOD = "#0f766e"  # teal
    C_BAD = "#b91c1c"   # red
    C_NEUTRAL = "#94a3b8"  # light slate
    C_BG = "#f8fafc"  # NOTE(review): currently unused — kept for the palette

    def _pct_axis(ax):
        # Format the x-axis as whole percentages (0.85 -> "85%").
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _pos: f"{x:.0%}"))

    def _metric_delta_text(delta):
        # Signed 4-decimal delta label, e.g. "+0.0123".
        return f"{delta:+.4f}"

    def _value_text(v):
        # Compact numeric label: fewer decimals for larger magnitudes.
        if v >= 100:
            return f"{v:,.0f}"
        if v >= 10:
            return f"{v:,.1f}"
        if v >= 1:
            return f"{v:,.2f}"
        return f"{v:,.3f}"

    def _ratio_text(r, lower_is_better):
        # Human-readable ratio: always phrased in the dominant direction.
        if r <= 0:
            return "n/a"
        if lower_is_better:
            return f"{1.0 / r:.2f}x lower" if r <= 1 else f"{r:.2f}x higher"
        return f"{r:.2f}x higher" if r >= 1 else f"{1.0 / r:.2f}x lower"

    # --- eval_summary.png: headline metrics + category-level deltas ---
    fig, axes = plt.subplots(2, 2, figsize=(8, 5), constrained_layout=True)
    ax_head = axes[0, 0]
    ax_delta = axes[0, 1]
    ax_raw_vit = axes[1, 0]
    ax_cat = axes[1, 1]

    metrics = ["span_f1", "span_precision", "span_recall", "token_acc", "non_o_recall"]
    labels = ["Span F1", "Span P", "Span R", "Token acc", "Non-O recall"]
    cand_v = [cand["viterbi"][m] for m in metrics]
    ref_v = [ref["viterbi"][m] for m in metrics] if ref is not None else None

    # Top-left: dumbbell plot A -> B per metric (B-only scatter when no ref).
    y = np.arange(len(metrics))
    if ref_v is not None:
        ax_head.hlines(y, ref_v, cand_v, color=C_NEUTRAL, linewidth=2, alpha=0.9)
        ax_head.scatter(ref_v, y, s=55, color=C_REF, label=f"A: {ref_label}", zorder=3)
    ax_head.scatter(cand_v, y, s=70, color=C_CAND, label=f"B: {cand_label}", zorder=4)
    ax_head.set_yticks(y)
    ax_head.set_yticklabels(labels)
    ax_head.invert_yaxis()
    # `ref_v or cand_v` falls back to cand-only bounds when ref is absent.
    ax_head.set_xlim(max(0.0, min(cand_v + (ref_v or cand_v)) - 0.02),
                     min(1.08, max(cand_v + (ref_v or cand_v)) + 0.05))
    ax_head.set_title("Headline metrics, Viterbi stream")
    ax_head.grid(axis="x", alpha=0.7)
    _pct_axis(ax_head)
    if ref_v is not None:
        for i, v in enumerate(ref_v):
            ax_head.text(v + 0.002, i, f"{v:.4f}", va="center", ha="left",
                         fontsize=7, color=C_REF)
    for i, v in enumerate(cand_v):
        ax_head.text(v - 0.002, i, f"{v:.4f}", va="center", ha="right",
                     fontsize=7, color=C_CAND)

    # Top-right: signed B-minus-A bars (green = B gains), or off when no ref.
    if ref_v is not None:
        deltas = [b - a for a, b in zip(ref_v, cand_v)]
        colors = [C_GOOD if d >= 0 else C_BAD for d in deltas]
        ax_delta.axvline(0, color="#0f172a", linewidth=0.9)
        ax_delta.barh(y, deltas, color=colors, alpha=0.9)
        ax_delta.set_yticks(y)
        ax_delta.set_yticklabels(labels)
        ax_delta.invert_yaxis()
        ax_delta.set_title("Delta: B minus A")
        ax_delta.grid(axis="x", alpha=0.7)
        # Floor of 0.002 keeps the axis sane when all deltas are ~0.
        max_abs = max([abs(d) for d in deltas] + [0.002])
        ax_delta.set_xlim(-max_abs * 1.45, max_abs * 1.45)
        for i, d in enumerate(deltas):
            ax_delta.text(d + (max_abs * 0.04 if d >= 0 else -max_abs * 0.04), i, _metric_delta_text(d),
                          ha="left" if d >= 0 else "right", va="center", fontsize=7)
    else:
        ax_delta.axis("off")

    # Bottom-left: raw vs viterbi span F1, both models (A rows dropped if no ref).
    stream_rows = [
        ("A raw", ref["raw"]["span_f1"] if ref is not None else None, C_REF),
        ("A viterbi", ref["viterbi"]["span_f1"] if ref is not None else None, C_REF),
        ("B raw", cand["raw"]["span_f1"], C_CAND),
        ("B viterbi", cand["viterbi"]["span_f1"], C_CAND),
    ]
    stream_rows = [r for r in stream_rows if r[1] is not None]
    sy = np.arange(len(stream_rows))
    ax_raw_vit.barh(sy, [r[1] for r in stream_rows], color=[r[2] for r in stream_rows], alpha=0.88)
    ax_raw_vit.set_yticks(sy)
    ax_raw_vit.set_yticklabels([r[0] for r in stream_rows])
    ax_raw_vit.invert_yaxis()
    ax_raw_vit.set_xlim(0, 1.08)
    ax_raw_vit.set_title("Raw vs Viterbi span F1")
    ax_raw_vit.grid(axis="x", alpha=0.7)
    _pct_axis(ax_raw_vit)
    for i, (_, v, _) in enumerate(stream_rows):
        ax_raw_vit.text(v + 0.008, i, f"{v:.4f}", va="center", fontsize=7)

    # Bottom-right: per-category F1 deltas (with ref) or lowest B F1s (without).
    cand_pc = cand["viterbi"]["span_per_cat"]
    if ref is not None:
        ref_pc = ref["viterbi"]["span_per_cat"]
        cats = sorted(set(cand_pc) | set(ref_pc))
        rows = []
        for c in cats:
            a = ref_pc.get(c, {}).get("f1", 0.0)
            b = cand_pc.get(c, {}).get("f1", 0.0)
            n = max(cand_pc.get(c, {}).get("n_gold", 0), ref_pc.get(c, {}).get("n_gold", 0))
            rows.append((c, b - a, b, a, n))
        # Keep the categories that explain the comparison: worst and best B deltas.
        worst = sorted(rows, key=lambda r: r[1])[:8]
        best = sorted(rows, key=lambda r: r[1], reverse=True)[:8]
        picked = worst + [r for r in best if r[0] not in {x[0] for x in worst}]
        picked = sorted(picked, key=lambda r: r[1])
        cy = np.arange(len(picked))
        deltas = [r[1] for r in picked]
        ax_cat.axvline(0, color="#0f172a", linewidth=0.9)
        ax_cat.barh(cy, deltas, color=[C_GOOD if d >= 0 else C_BAD for d in deltas])
        ax_cat.set_yticks(cy)
        ax_cat.set_yticklabels([r[0] for r in picked], fontsize=8)
        ax_cat.set_title("Per-category span F1 delta, selected extremes")
        ax_cat.grid(axis="x", alpha=0.7)
        max_abs = max([abs(d) for d in deltas] + [0.05])
        ax_cat.set_xlim(-max_abs * 1.55, max_abs * 1.55)
        for i, r in enumerate(picked):
            d = r[1]
            ax_cat.text(d + (max_abs * 0.05 if d >= 0 else -max_abs * 0.05), i,
                        f"{d:+.3f} B={r[2]:.2f} A={r[3]:.2f}",
                        va="center", ha="left" if d >= 0 else "right", fontsize=6)
    else:
        cats_sorted = sorted(cand_pc.keys(), key=lambda c: cand_pc[c]["f1"])[:18]
        vals = [cand_pc[c]["f1"] for c in cats_sorted]
        cy = np.arange(len(cats_sorted))
        ax_cat.barh(cy, vals, color=C_CAND)
        ax_cat.set_yticks(cy)
        ax_cat.set_yticklabels(cats_sorted, fontsize=8)
        ax_cat.set_xlim(0, 1.0)
        ax_cat.set_title("Lowest per-category span F1")
        ax_cat.grid(axis="x", alpha=0.7)
        _pct_axis(ax_cat)

    fig.suptitle(f"Evaluation summary — A: {ref_label if ref else 'n/a'} | B: {cand_label}",
                 fontsize=9, fontweight="bold")
    fig.savefig(out_dir / "eval_summary.png", dpi=160)
    plt.close(fig)
    print(f"[plot] wrote {out_dir / 'eval_summary.png'}", flush=True)

    # --- eval_confusion.png: pairwise outcome on gold non-O tokens ---
    # Display order matches A vs B: "Only A correct" (teacher) before
    # "Only B correct" (student). Underlying buckets in `pair` are still
    # named cand/ref; we just relabel for display.
    if pair is not None:
        fig, axes = plt.subplots(
            1, 2, figsize=(8, 3), constrained_layout=True,
            gridspec_kw={"width_ratios": [0.9, 1.7]},
        )
        ax = axes[0]
        # Pool the four outcome buckets over all gold non-O categories.
        non_o_buckets = {k: 0 for k in ["both_correct", "only_cand", "only_ref", "both_wrong"]}
        for cat, d in pair["by_cat"].items():
            if cat == "O":
                continue
            for k in non_o_buckets:
                non_o_buckets[k] += d[k]
        values = [non_o_buckets["both_correct"], non_o_buckets["only_ref"],
                  non_o_buckets["only_cand"], non_o_buckets["both_wrong"]]
        labels_ = [
            "Both\ncorrect",
            "Only A\ncorrect",
            "Only B\ncorrect",
            "Both wrong",
        ]
        colors = [C_GOOD, C_REF, C_CAND, C_BAD]
        total = max(1, sum(values))  # avoid div-by-zero in the % labels
        ax.barh(np.arange(4), values, color=colors)
        ax.set_yticks(np.arange(4))
        ax.set_yticklabels(labels_)
        ax.invert_yaxis()
        ax.set_ylabel("Gold non-O tokens")
        ax.set_title("Token outcome on gold non-O")
        ax.grid(axis="x", alpha=0.7)
        ax.set_xlim(0, max(values) * 1.32 if values else 1)
        for i, v in enumerate(values):
            ax.text(v + max(values) * 0.015, i, f"{v:,} ({v / total:.1%})",
                    va="center", fontsize=6)

        # Right panel: net token wins per category (B only-correct minus A
        # only-correct), restricted to categories with any disagreement.
        rows = []
        for cat, d in pair["by_cat"].items():
            if cat == "O":
                continue
            net = d["only_cand"] - d["only_ref"]
            active = d["only_cand"] + d["only_ref"] + d["both_wrong"]
            if active:
                rows.append((cat, net, d["only_cand"], d["only_ref"], d["both_wrong"]))
        worst = sorted(rows, key=lambda r: r[1])[:8]
        best = sorted(rows, key=lambda r: r[1], reverse=True)[:8]
        picked = worst + [r for r in best if r[0] not in {x[0] for x in worst}]
        picked = sorted(picked, key=lambda r: r[1])
        ax2 = axes[1]
        if picked:
            py = np.arange(len(picked))
            nets = [r[1] for r in picked]
            ax2.axvline(0, color="#0f172a", linewidth=0.9)
            ax2.barh(py, nets, color=[C_GOOD if n >= 0 else C_BAD for n in nets])
            ax2.set_yticks(py)
            ax2.set_yticklabels([r[0] for r in picked], fontsize=8)
            ax2.set_title("Net token wins by category: B only-correct minus A only-correct")
            ax2.grid(axis="x", alpha=0.7)
            max_abs = max([abs(n) for n in nets] + [1])
            ax2.set_xlim(-max_abs * 1.5, max_abs * 1.5)
            for i, r in enumerate(picked):
                n = r[1]
                label_x = n + max_abs * 0.04 if n >= 0 else max_abs * 0.05
                ax2.text(label_x, i,
                         f"{n:+d} B={r[2]} A={r[3]} W={r[4]}",
                         va="center", ha="left", fontsize=6)
        else:
            ax2.axis("off")

        fig.suptitle(f"Pairwise correctness — A: {ref_label} | B: {cand_label}",
                     fontsize=9, fontweight="bold")
        fig.savefig(out_dir / "eval_confusion.png", dpi=160)
        plt.close(fig)
        print(f"[plot] wrote {out_dir / 'eval_confusion.png'}", flush=True)

    # --- eval_performance.png: model size, compute, throughput, memory ---
    if ref is not None and "perf" in cand and "perf" in ref:
        cp, rp = cand["perf"], ref["perf"]

        fig, axes = plt.subplots(1, 2, figsize=(8.5, 5), constrained_layout=True)
        ax_abs = axes[0]
        ax_ratio = axes[1]

        # (label, B value, A value, lower_is_better)
        metrics = [
            ("Total params (M)", cp["total_params"]/1e6, rp["total_params"]/1e6, True),
            ("Active params/tok (M)", cp["active_params_per_tok"]/1e6, rp["active_params_per_tok"]/1e6, True),
            ("MoE expert params (M)", cp["moe_total"]/1e6, rp["moe_total"]/1e6, True),
            ("GFLOP/token", cp["gflops_per_tok"], rp["gflops_per_tok"], True),
            ("Weights RAM (MiB)", cp["weight_bytes"]/(1<<20), rp["weight_bytes"]/(1<<20), True),
            ("Peak eval mem (MiB)", cp["peak_eval_mem_bytes"]/(1<<20), rp["peak_eval_mem_bytes"]/(1<<20), True),
            ("Throughput (tok/s)", cand["throughput_tok_s"], ref["throughput_tok_s"], False),
        ]

        # Left panel: paired absolute bars on a log axis (values span orders
        # of magnitude, so linear scaling would hide the small metrics).
        y = np.arange(len(metrics))
        w = 0.38
        cand_v = [m[1] for m in metrics]
        ref_v = [m[2] for m in metrics]
        ax_abs.barh(y - w/2, ref_v, w, label=f"A: {ref_label}", color=C_REF)
        ax_abs.barh(y + w/2, cand_v, w, label=f"B: {cand_label}", color=C_CAND)
        ax_abs.set_yticks(y)
        ax_abs.set_yticklabels([m[0] for m in metrics], fontsize=8)
        ax_abs.invert_yaxis()
        ax_abs.set_xscale("log")
        ax_abs.set_title("Absolute footprint and speed, log scale")
        ax_abs.grid(axis="x", which="both", alpha=0.7)
        positive_vals = [v for v in (ref_v + cand_v) if v > 0]
        if positive_vals:
            ax_abs.set_xlim(min(positive_vals) * 0.55, max(positive_vals) * 3.8)
        for yi, vals in enumerate(zip(ref_v, cand_v)):
            for off, v, col in [(-w/2, vals[0], C_REF), (w/2, vals[1], C_CAND)]:
                if v <= 0:
                    continue  # log axis cannot place non-positive values
                ax_abs.text(v * 1.05, yi + off, _value_text(v), va="center", fontsize=6, color=col)

        # Right panel: B/A ratios; green means "better" in the metric's own
        # direction (lower for size/compute/memory, higher for throughput).
        ratios = [(m[1] / max(1e-12, m[2])) for m in metrics]
        lower_better = [m[3] for m in metrics]
        colors = [
            C_GOOD if ((lb and r <= 1.0) or ((not lb) and r >= 1.0)) else C_BAD
            for r, lb in zip(ratios, lower_better)
        ]
        ax_ratio.axvline(1.0, color="#0f172a", linestyle="--", linewidth=0.9, alpha=0.65)
        ax_ratio.barh(y, ratios, color=colors)
        ax_ratio.set_yticks(y)
        ax_ratio.set_yticklabels([m[0] for m in metrics], fontsize=8)
        ax_ratio.invert_yaxis()
        ax_ratio.set_xscale("log")
        ax_ratio.set_title("B / A ratio with explicit direction")
        ax_ratio.grid(axis="x", which="both", alpha=0.7)
        positive_ratios = [r for r in ratios if r > 0]
        if positive_ratios:
            ax_ratio.set_xlim(min(positive_ratios) * 0.38, max(positive_ratios) * 2.8)
        for i, (r, lb) in enumerate(zip(ratios, lower_better)):
            ax_ratio.text(r * (1.05 if r >= 1 else 0.95), i, _ratio_text(r, lb),
                          va="center", ha="left" if r >= 1 else "right", fontsize=6)

        fig.suptitle(f"Performance profile — A: {ref_label} | B: {cand_label}",
                     fontsize=9, fontweight="bold")
        fig.savefig(out_dir / "eval_performance.png", dpi=160)
        plt.close(fig)
        print(f"[plot] wrote {out_dir / 'eval_performance.png'}", flush=True)
|
| 918 |
+
|
| 919 |
+
|
| 920 |
+
# ---------------------------------------------------------------------------
|
| 921 |
+
# Reporting
|
| 922 |
+
# ---------------------------------------------------------------------------
|
| 923 |
+
|
| 924 |
+
def _fmt_metrics(m):
|
| 925 |
+
return (f"span_F1={m['span_f1']:.4f} P={m['span_precision']:.4f} "
|
| 926 |
+
f"R={m['span_recall']:.4f} token_acc={m['token_acc']:.4f} "
|
| 927 |
+
f"non_o_recall={m['non_o_recall']:.4f} "
|
| 928 |
+
f"spans={m['n_gold_spans']}/{m['n_pred_spans']}/{m['n_correct_spans']}")
|
| 929 |
+
|
| 930 |
+
|
| 931 |
+
def _write_compare_log(path: Path, cand, ref, pair, args):
|
| 932 |
+
lines: List[str] = []
|
| 933 |
+
A = lines.append
|
| 934 |
+
A(f"Benchmark: A: {ref['label']} vs B: {cand['label']}"
|
| 935 |
+
if ref else f"Benchmark: {cand['label']}")
|
| 936 |
+
A(f"Dataset: {SOURCE_DATASET}, split=test, eval_pct={args.eval_pct}, "
|
| 937 |
+
f"ctx={args.max_length}, seed={args.seed}, n_docs={args.n_docs}")
|
| 938 |
+
A(f"Eval tokens scored: {cand['n_tok']:,}")
|
| 939 |
+
A("")
|
| 940 |
+
A("=== Aggregate ===")
|
| 941 |
+
if ref is not None:
|
| 942 |
+
A(f" A: {ref['label']:<25s} RAW {_fmt_metrics(ref['raw'])}")
|
| 943 |
+
A(f" A: {ref['label']:<25s} VITERBI {_fmt_metrics(ref['viterbi'])}")
|
| 944 |
+
A(f" B: {cand['label']:<25s} RAW {_fmt_metrics(cand['raw'])}")
|
| 945 |
+
A(f" B: {cand['label']:<25s} VITERBI {_fmt_metrics(cand['viterbi'])}")
|
| 946 |
+
if ref is not None:
|
| 947 |
+
d_f1 = ref["viterbi"]["span_f1"] - cand["viterbi"]["span_f1"]
|
| 948 |
+
A("")
|
| 949 |
+
A(f"Gap B vs A (viterbi span_F1): {-d_f1:+.4f}")
|
| 950 |
+
A("")
|
| 951 |
+
if ref is not None:
|
| 952 |
+
A(f"Throughput: A: {ref['label']} {ref['throughput_tok_s']:.0f} tok/s "
|
| 953 |
+
f"({ref['n_total_M']:.2f}M params)")
|
| 954 |
+
A(f" B: {cand['label']} {cand['throughput_tok_s']:.0f} tok/s "
|
| 955 |
+
f"({cand['n_total_M']:.2f}M params)")
|
| 956 |
+
else:
|
| 957 |
+
A(f"Throughput: {cand['label']} {cand['throughput_tok_s']:.0f} tok/s "
|
| 958 |
+
f"({cand['n_total_M']:.2f}M params)")
|
| 959 |
+
A("")
|
| 960 |
+
# ---- Performance summary table ----
|
| 961 |
+
# Column order: A (ref / teacher) first, then B (cand / student).
|
| 962 |
+
# The "B vs A" column is written in human-readable direction:
|
| 963 |
+
# - When B is smaller (size/compute/mem): "X.XX× smaller" / "X.XX× cheaper" / "X.XX× less".
|
| 964 |
+
# - When B is larger but lower-is-better: "X.XX× larger" / "X.XX× more".
|
| 965 |
+
# - When B is faster (throughput, higher-is-better): "X.XX× faster".
|
| 966 |
+
# - When B is slower: "X.XX× slower".
|
| 967 |
+
# Always uses the magnitude in the dominant direction so the reader
|
| 968 |
+
# doesn't need to mentally invert 0.21× into 4.87×.
|
| 969 |
+
def _fmt_vs(b: float, a: float, kind: str) -> str:
|
| 970 |
+
if a is None or a == 0 or b is None or b == 0:
|
| 971 |
+
return ""
|
| 972 |
+
ratio = b / a
|
| 973 |
+
if kind == "size": # lower is better; phrase as "smaller" or "larger"
|
| 974 |
+
if ratio <= 1.0:
|
| 975 |
+
return f"{1.0/ratio:.2f}× smaller"
|
| 976 |
+
return f"{ratio:.2f}× larger"
|
| 977 |
+
if kind == "compute":
|
| 978 |
+
if ratio <= 1.0:
|
| 979 |
+
return f"{1.0/ratio:.2f}× cheaper"
|
| 980 |
+
return f"{ratio:.2f}× more"
|
| 981 |
+
if kind == "memory":
|
| 982 |
+
if ratio <= 1.0:
|
| 983 |
+
return f"{1.0/ratio:.2f}× less"
|
| 984 |
+
return f"{ratio:.2f}× more"
|
| 985 |
+
if kind == "speed": # higher is better
|
| 986 |
+
if ratio >= 1.0:
|
| 987 |
+
return f"{ratio:.2f}× faster"
|
| 988 |
+
return f"{1.0/ratio:.2f}× slower"
|
| 989 |
+
return f"{ratio:.2f}×"
|
| 990 |
+
|
| 991 |
+
cp = cand["perf"]
|
| 992 |
+
A("=== Performance ===")
|
| 993 |
+
headers = ["metric",
|
| 994 |
+
f"A: {ref['label']}" if ref else "",
|
| 995 |
+
f"B: {cand['label']}",
|
| 996 |
+
"B vs A"]
|
| 997 |
+
rp = ref["perf"] if ref else None
|
| 998 |
+
rows = [
|
| 999 |
+
["total params (M)",
|
| 1000 |
+
(f"{rp['total_params']/1e6:.2f}" if ref else ""),
|
| 1001 |
+
f"{cp['total_params']/1e6:.2f}",
|
| 1002 |
+
(_fmt_vs(cp['total_params'], rp['total_params'], "size") if ref else "")],
|
| 1003 |
+
["dense params (M)",
|
| 1004 |
+
(f"{rp['other_total']/1e6:.2f}" if ref else ""),
|
| 1005 |
+
f"{cp['other_total']/1e6:.2f}",
|
| 1006 |
+
(_fmt_vs(cp['other_total'], rp['other_total'], "size") if ref else "")],
|
| 1007 |
+
["MoE expert params (M)",
|
| 1008 |
+
(f"{rp['moe_total']/1e6:.2f}" if ref else ""),
|
| 1009 |
+
f"{cp['moe_total']/1e6:.2f}",
|
| 1010 |
+
(_fmt_vs(cp['moe_total'], rp['moe_total'], "size") if ref else "")],
|
| 1011 |
+
[f"active params/token (M, mem)",
|
| 1012 |
+
(f"{rp['active_params_per_tok']/1e6:.2f}" if ref else ""),
|
| 1013 |
+
f"{cp['active_params_per_tok']/1e6:.2f}",
|
| 1014 |
+
(_fmt_vs(cp['active_params_per_tok'], rp['active_params_per_tok'], "memory") if ref else "")],
|
| 1015 |
+
[f"compute params/token (M, FLOPs)",
|
| 1016 |
+
(f"{rp['compute_params_per_tok']/1e6:.2f}" if ref else ""),
|
| 1017 |
+
f"{cp['compute_params_per_tok']/1e6:.2f}",
|
| 1018 |
+
(_fmt_vs(cp['compute_params_per_tok'], rp['compute_params_per_tok'], "compute") if ref else "")],
|
| 1019 |
+
["GFLOP / token",
|
| 1020 |
+
(f"{rp['gflops_per_tok']:.4f}" if ref else ""),
|
| 1021 |
+
f"{cp['gflops_per_tok']:.4f}",
|
| 1022 |
+
(_fmt_vs(cp['gflops_per_tok'], rp['gflops_per_tok'], "compute") if ref else "")],
|
| 1023 |
+
["disk size (MiB)",
|
| 1024 |
+
(f"{rp['disk_size_bytes']/(1<<20):.1f}" if ref and rp['disk_size_bytes'] else ""),
|
| 1025 |
+
f"{cp['disk_size_bytes']/(1<<20):.1f}",
|
| 1026 |
+
(_fmt_vs(cp['disk_size_bytes'], rp['disk_size_bytes'], "size") if ref and rp['disk_size_bytes'] else "")],
|
| 1027 |
+
["weights in RAM (MiB)",
|
| 1028 |
+
(f"{rp['weight_bytes']/(1<<20):.1f}" if ref else ""),
|
| 1029 |
+
f"{cp['weight_bytes']/(1<<20):.1f}",
|
| 1030 |
+
(_fmt_vs(cp['weight_bytes'], rp['weight_bytes'], "size") if ref else "")],
|
| 1031 |
+
["peak GPU mem eval (MiB)",
|
| 1032 |
+
(f"{rp['peak_eval_mem_bytes']/(1<<20):.1f}" if ref else ""),
|
| 1033 |
+
f"{cp['peak_eval_mem_bytes']/(1<<20):.1f}",
|
| 1034 |
+
(_fmt_vs(cp['peak_eval_mem_bytes'], rp['peak_eval_mem_bytes'], "memory") if ref else "")],
|
| 1035 |
+
["throughput (tok/s)",
|
| 1036 |
+
(f"{ref['throughput_tok_s']:.0f}" if ref else ""),
|
| 1037 |
+
f"{cand['throughput_tok_s']:.0f}",
|
| 1038 |
+
(_fmt_vs(cand['throughput_tok_s'], ref['throughput_tok_s'], "speed") if ref else "")],
|
| 1039 |
+
]
|
| 1040 |
+
widths = [max(len(r[i]) for r in [headers] + rows) for i in range(4)]
|
| 1041 |
+
sep = " " + " ".join("-" * w for w in widths)
|
| 1042 |
+
A(" " + " ".join(h.ljust(widths[i]) for i, h in enumerate(headers)))
|
| 1043 |
+
A(sep)
|
| 1044 |
+
for r in rows:
|
| 1045 |
+
A(" " + " ".join(r[i].ljust(widths[i]) for i in range(4)))
|
| 1046 |
+
A("")
|
| 1047 |
+
if pair is not None:
|
| 1048 |
+
# Display labels: A = ref (teacher), B = cand (student).
|
| 1049 |
+
a_lbl = ref["label"] if ref is not None else "A"
|
| 1050 |
+
b_lbl = cand["label"]
|
| 1051 |
+
A(f"=== Pairwise (viterbi, all gold tokens) — A: {a_lbl} vs B: {b_lbl} ===")
|
| 1052 |
+
total = (pair["both_correct"] + pair["only_cand_correct"]
|
| 1053 |
+
+ pair["only_ref_correct"] + pair["both_wrong"])
|
| 1054 |
+
# Display order: agreement, A-only, B-only, both-wrong.
|
| 1055 |
+
rows_all = [
|
| 1056 |
+
("both_correct", pair["both_correct"]),
|
| 1057 |
+
("only_A_correct", pair["only_ref_correct"]),
|
| 1058 |
+
("only_B_correct", pair["only_cand_correct"]),
|
| 1059 |
+
("both_wrong", pair["both_wrong"]),
|
| 1060 |
+
]
|
| 1061 |
+
for k, v in rows_all:
|
| 1062 |
+
A(f" {k:<26s} {v:8d} ({100.0*v/total:.2f}%)")
|
| 1063 |
+
A("")
|
| 1064 |
+
A(f"=== Pairwise (viterbi, gold non-O tokens) — A: {a_lbl} vs B: {b_lbl} ===")
|
| 1065 |
+
non_o = {k: 0 for k in ["both_correct", "only_cand", "only_ref", "both_wrong"]}
|
| 1066 |
+
for cat, d in pair["by_cat"].items():
|
| 1067 |
+
if cat == "O":
|
| 1068 |
+
continue
|
| 1069 |
+
for k in non_o:
|
| 1070 |
+
non_o[k] += d[k]
|
| 1071 |
+
total_non_o = sum(non_o.values())
|
| 1072 |
+
rows_non_o = [
|
| 1073 |
+
("both_correct", non_o["both_correct"]),
|
| 1074 |
+
("only_A_correct", non_o["only_ref"]),
|
| 1075 |
+
("only_B_correct", non_o["only_cand"]),
|
| 1076 |
+
("both_wrong", non_o["both_wrong"]),
|
| 1077 |
+
]
|
| 1078 |
+
for k, v in rows_non_o:
|
| 1079 |
+
A(f" {k:<26s} {v:8d} ({100.0*v/total_non_o:.2f}%)" if total_non_o else f" {k}: 0")
|
| 1080 |
+
A("")
|
| 1081 |
+
# Per-cat net B-wins. Net = (only_B) - (only_A) = (only_cand) - (only_ref).
|
| 1082 |
+
# Negative net = A (teacher) wins more in this category.
|
| 1083 |
+
nets = []
|
| 1084 |
+
for cat, d in pair["by_cat"].items():
|
| 1085 |
+
if cat == "O":
|
| 1086 |
+
continue
|
| 1087 |
+
nets.append((cat, d["only_cand"] - d["only_ref"],
|
| 1088 |
+
d["only_cand"], d["only_ref"], d["both_wrong"]))
|
| 1089 |
+
nets.sort(key=lambda x: x[1])
|
| 1090 |
+
A(f"=== Worst B-net wins by gold category — A: {a_lbl} ahead (top 15) ===")
|
| 1091 |
+
for cat, net, ob, oa, bw in nets[:15]:
|
| 1092 |
+
A(f" {cat:<32s} net_B={net:+5d} A_only={oa:4d} B_only={ob:4d} both_wrong={bw:4d}")
|
| 1093 |
+
A("")
|
| 1094 |
+
A(f"=== Best B-net wins by gold category — B: {b_lbl} ahead (top 15) ===")
|
| 1095 |
+
for cat, net, ob, oa, bw in nets[::-1][:15]:
|
| 1096 |
+
A(f" {cat:<32s} net_B={net:+5d} A_only={oa:4d} B_only={ob:4d} both_wrong={bw:4d}")
|
| 1097 |
+
A("")
|
| 1098 |
+
A("=== Per-category span F1 (viterbi) ===")
|
| 1099 |
+
if ref is not None:
|
| 1100 |
+
A(f" -- A: {ref['label']} --")
|
| 1101 |
+
per_r = ref["viterbi"]["span_per_cat"]
|
| 1102 |
+
for cat in sorted(per_r):
|
| 1103 |
+
c = per_r[cat]
|
| 1104 |
+
A(f" {cat:<32s} F1={c['f1']:.4f} P={c['precision']:.4f} R={c['recall']:.4f} "
|
| 1105 |
+
f"({c['n_gold']}/{c['n_pred']}/{c['n_correct']})")
|
| 1106 |
+
A(f" -- B: {cand['label']} --")
|
| 1107 |
+
per = cand["viterbi"]["span_per_cat"]
|
| 1108 |
+
for cat in sorted(per):
|
| 1109 |
+
c = per[cat]
|
| 1110 |
+
A(f" {cat:<32s} F1={c['f1']:.4f} P={c['precision']:.4f} R={c['recall']:.4f} "
|
| 1111 |
+
f"({c['n_gold']}/{c['n_pred']}/{c['n_correct']})")
|
| 1112 |
+
path.write_text("\n".join(lines) + "\n")
|
| 1113 |
+
print(f"[log] wrote {path}", flush=True)
|
| 1114 |
+
|
| 1115 |
+
|
| 1116 |
+
def _fmt_bytes(n: int) -> str:
|
| 1117 |
+
if n <= 0:
|
| 1118 |
+
return "—"
|
| 1119 |
+
if n >= 1 << 30:
|
| 1120 |
+
return f"{n / (1 << 30):.2f} GiB"
|
| 1121 |
+
if n >= 1 << 20:
|
| 1122 |
+
return f"{n / (1 << 20):.1f} MiB"
|
| 1123 |
+
return f"{n / (1 << 10):.1f} KiB"
|
| 1124 |
+
|
| 1125 |
+
|
| 1126 |
+
def _perf_block(stream, ctx: int) -> List[str]:
    """Format one model's performance counters as indented report lines.

    Args:
        stream: eval-result dict for one model; only its "perf" sub-dict is
            read here (param counts, per-token FLOPs, byte sizes).
            NOTE(review): assumes keys like 'total_params', 'moe_total',
            'gflops_per_tok' are always present — confirm against the
            producer (_eval_one_model).
        ctx: context length used during eval; echoed in the peak-GPU-memory
            line so the number is interpretable.

    Returns:
        A list of text lines ready to append verbatim to infer.log.
    """
    p = stream["perf"]
    out = [
        # Total = dense (attn/norm/embed/head) + MoE expert weights.
        f" total params : {p['total_params']/1e6:>9.2f}M "
        f"({p['other_total']/1e6:.2f}M dense + {p['moe_total']/1e6:.2f}M MoE-experts)",
        # "Active" counts what must be resident per token: embedding row
        # lookup plus only the top-k routed experts plus the dense stack.
        f" active params / token : {p['active_params_per_tok']/1e6:>9.2f}M "
        f"(memory footprint — embed lookup + top_{p['experts_per_tok']}/{p['num_experts']} experts: "
        f"{p['embed_total']/1e6:.2f}M embed + "
        f"{p['moe_active_per_tok']/1e6:.2f}M MoE-active + "
        f"{(p['other_total']-p['embed_total'])/1e6:.2f}M attn/norm/head)",
        # Compute excludes the embedding lookup (a table read, not a matmul).
        f" compute params / token : {p['compute_params_per_tok']/1e6:>9.2f}M "
        f"(matmul FLOPs only — embedding lookup excluded)",
        f" GFLOP / token (fwd, MAC×2): {p['gflops_per_tok']:>9.3f}",
        f" weights size (on disk) : {_fmt_bytes(p['disk_size_bytes']):>9s}",
        f" weights size (in RAM) : {_fmt_bytes(p['weight_bytes']):>9s}",
        f" weights resident (GPU) : {_fmt_bytes(p['weights_resident_bytes']):>9s}",
        f" peak GPU mem (eval, ctx={ctx}) : {_fmt_bytes(p['peak_eval_mem_bytes']):>9s}",
    ]
    return out
|
| 1145 |
+
|
| 1146 |
+
|
| 1147 |
+
def _write_infer_log(path: Path, cand, ref, args, sample_text: str, tokenizer, device, dtype):
    """Single-doc inference example + timing + performance metrics.

    Writes `infer.log`: load/eval timings and the `_perf_block` counters for
    A (reference/teacher, optional) and B (candidate), a B-vs-A ratio block
    when a reference is present, then one end-to-end sample inference
    (tokenize → forward → span decode) with wall-clock latency.

    Args:
        path: output file path for the log.
        cand: candidate eval-result dict (must contain 'label', 'load_s',
            'eval_s', 'n_tok', 'throughput_tok_s', 'perf').
        ref: reference eval-result dict with the same keys, or None.
        args: parsed CLI namespace; only `max_length` is read here.
        sample_text: text for the demo inference.
        tokenizer / device / dtype: shared inference plumbing.
    """
    # NOTE(review): _bioes_viterbi_batched is imported but not referenced in
    # this function — decoding happens inside the model because
    # viterbi_replace_logits is flipped on below. Confirm the import is
    # intentionally kept (e.g. to fail fast if the module is missing).
    from modeling_haremb_pii import _bioes_viterbi_batched

    lines: List[str] = []
    A = lines.append
    A(f"Inference benchmark: A: {ref['label']} vs B: {cand['label']}"
      if ref else f"Inference benchmark: {cand['label']}")
    A(f" device : {device} dtype: {dtype}")
    A(f" ctx : {args.max_length}")
    A("")
    if ref is not None:
        A(f"A: {ref['label']} (reference / teacher)")
        A(f" load : {ref['load_s']:.2f}s")
        A(f" eval : {ref['eval_s']:.2f}s on {ref['n_tok']:,} tokens "
          f"({ref['throughput_tok_s']:.0f} tok/s)")
        A("Performance:")
        for ln in _perf_block(ref, args.max_length):
            A(ln)
        A("")
    A(f"B: {cand['label']}" + (" (this checkpoint)" if ref else ""))
    A(f" load : {cand['load_s']:.2f}s")
    A(f" eval : {cand['eval_s']:.2f}s on {cand['n_tok']:,} tokens "
      f"({cand['throughput_tok_s']:.0f} tok/s)")
    A("Performance:")
    for ln in _perf_block(cand, args.max_length):
        A(ln)
    A("")
    if ref is not None:
        cp, rp = cand["perf"], ref["perf"]

        # Local duplicate of _write_compare_log's _fmt_vs: express each
        # ratio in the "dominant" direction (e.g. 4.87× smaller, 1.93×
        # faster) so the reader never has to invert 0.21× mentally.
        # Returns "—" when either side is missing/zero.
        def _fmt(b, a, kind):
            if a is None or a == 0 or b is None or b == 0:
                return "—"
            r = b / a
            if kind == "size":
                return f"{1.0/r:.2f}× smaller" if r <= 1.0 else f"{r:.2f}× larger"
            if kind == "compute":
                return f"{1.0/r:.2f}× cheaper" if r <= 1.0 else f"{r:.2f}× more"
            if kind == "memory":
                return f"{1.0/r:.2f}× less" if r <= 1.0 else f"{r:.2f}× more"
            if kind == "speed":
                return f"{r:.2f}× faster" if r >= 1.0 else f"{1.0/r:.2f}× slower"
            return f"{r:.2f}×"

        A(f"B vs A ({cand['label']} vs {ref['label']}):")
        A(f" total params : {_fmt(cp['total_params'], rp['total_params'], 'size')}")
        A(f" active params / token : {_fmt(cp['active_params_per_tok'], rp['active_params_per_tok'], 'memory')} [memory]")
        A(f" compute params / token : {_fmt(cp['compute_params_per_tok'], rp['compute_params_per_tok'], 'compute')} [FLOPs]")
        A(f" GFLOP / token : {_fmt(cp['gflops_per_tok'], rp['gflops_per_tok'], 'compute')}")
        if rp['disk_size_bytes']:
            A(f" weights size (on disk) : {_fmt(cp['disk_size_bytes'], rp['disk_size_bytes'], 'size')}")
        else:
            A(f" weights size (on disk) : —")
        A(f" weights in RAM : {_fmt(cp['weight_bytes'], rp['weight_bytes'], 'size')}")
        A(f" peak GPU mem (eval) : {_fmt(cp['peak_eval_mem_bytes'], rp['peak_eval_mem_bytes'], 'memory')}")
        A(f" throughput : {_fmt(cand['throughput_tok_s'], ref['throughput_tok_s'], 'speed')}")
        A("")

    A("Sample inference (load → tokenize → forward → viterbi-decode → spans):")
    A(f" text: {sample_text!r}")
    # NOTE(review): loads from "." rather than args.model_path — these agree
    # under the default CLI (--model-path defaults to "."), but will diverge
    # if a custom --model-path is passed. Confirm intent.
    model = AutoModelForTokenClassification.from_pretrained(
        ".", dtype=dtype, trust_remote_code=True,
    ).to(device).eval()
    # Ask the model to emit Viterbi-constrained logits so a plain argmax
    # below yields a valid BIOES sequence (flag defined by the remote code).
    if hasattr(model.config, "viterbi_replace_logits"):
        model.config.viterbi_replace_logits = True
    enc = tokenizer(sample_text, return_tensors="pt", truncation=True,
                    max_length=args.max_length).to(device)
    with torch.no_grad():
        # Synchronize around the forward pass so the measured latency covers
        # actual GPU work, not just the async kernel launch.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t0 = time.time()
        out = model(**enc)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        dt = time.time() - t0
    label2id, id2label = nemotron_native_label_space()
    pred = out.logits.argmax(-1)[0].cpu().tolist()
    spans = _bioes_to_spans(pred, id2label, 0)
    A(f" forward latency: {dt*1000:.1f}ms ({enc.input_ids.shape[1]} tokens)")
    A(f" detected {len(spans)} spans:")
    tok_ids = enc.input_ids[0].cpu().tolist()
    for s, e, cat in sorted(spans):
        text = tokenizer.decode(tok_ids[s:e]).strip()
        A(f" [{s:3d}, {e:3d}) {cat:<28s} {text!r}")
    # Free the demo model before returning so the caller's memory stats
    # aren't polluted.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    path.write_text("\n".join(lines) + "\n")
    print(f"[log] wrote {path}", flush=True)
|
| 1237 |
+
|
| 1238 |
+
|
| 1239 |
+
# ---------------------------------------------------------------------------
|
| 1240 |
+
# Main
|
| 1241 |
+
# ---------------------------------------------------------------------------
|
| 1242 |
+
|
| 1243 |
+
def main() -> None:
    """CLI entry point: benchmark this checkpoint on Nemotron-PII.

    Pipeline: parse args → pick device/dtype → build the eval slice →
    evaluate the candidate (and, unless --no-base, the OpenMed teacher) →
    print aggregate metrics → write infer.log / compare.log and optional
    plots into --out.
    """
    p = argparse.ArgumentParser(description="Benchmark haremb-privacy-filter-opennemo")
    p.add_argument("--device", default=None,
                   help="cuda or cpu. Default: cuda if available.")
    p.add_argument("--dtype", default="bfloat16",
                   choices=["bfloat16", "float16", "float32"])
    p.add_argument("--eval-pct", type=float, default=1.0,
                   help="Percent of nvidia/Nemotron-PII test split to use. Default 1%%.")
    p.add_argument("--eval-chunk-size", type=int, default=10_000)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--max-length", type=int, default=1024)
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--out", type=str, default=".")
    p.add_argument("--model-path", type=str, default=".",
                   help="Path to this checkpoint. Default: ./ (this folder).")
    p.add_argument("--no-base", action="store_true",
                   help="Skip the OpenMed teacher comparison.")
    p.add_argument("--no-plots", action="store_true",
                   help="Skip rendering eval_summary.png / eval_confusion.png.")
    args = p.parse_args()

    out_dir = Path(args.out).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    # Device: explicit flag wins; otherwise prefer CUDA when present.
    if args.device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16,
             "float32": torch.float32}[args.dtype]

    print(f"[setup] device={device} dtype={dtype} model={args.model_path} "
          f"out={out_dir}", flush=True)

    label2id, id2label = nemotron_native_label_space()
    o_id = label2id["O"]

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    # NOTE(review): pad_token_id is assigned but never used below; also
    # `or` would replace a legitimate pad id of 0 with the 199999 fallback.
    # Confirm whether this line can be removed or should use an
    # `is None` check.
    pad_token_id = tokenizer.pad_token_id or 199999

    # Build the eval set (same slice the README headline numbers reference)
    print(f"[data] loading {SOURCE_DATASET} ...", flush=True)
    ds = load_dataset(SOURCE_DATASET)
    # eval_pct is a percentage, hence the /100; always keep at least 1 doc.
    target_eval = max(1, int(round(len(ds["test"]) * args.eval_pct / 100.0)))
    eval_indices = _build_eval_streaming(
        ds["test"], target_n=target_eval,
        chunk_size=args.eval_chunk_size, seed=args.seed,
    )
    eval_ds = _NemotronEvalDataset(
        ds["test"].select(eval_indices), tokenizer, label2id, args.max_length,
    )
    # Stashed on args so the report writers can echo it in their headers.
    args.n_docs = len(eval_ds)
    print(f"[data] eval={args.n_docs:,} docs ({args.eval_pct:.2f}% of test split)",
          flush=True)

    # BIOES decoding masks (used for the explicit RAW vs VITERBI streams).
    from modeling_haremb_pii import (
        _build_bioes_initial_mask as _bld_init,
        _build_bioes_transition_mask as _bld_trans,
    )
    bioes_trans = _bld_trans(id2label).to(device).float()
    bioes_init = _bld_init(id2label).to(device).float()

    # Eval candidate
    cand = _eval_one_model(
        args.model_path, tokenizer, eval_ds, label2id, id2label, o_id,
        bioes_trans, bioes_init,
        args.batch_size, args.max_length, device, dtype,
        label="haremb",
    )

    print(f"\n=== {cand['label']} ===")
    print(f"RAW {_fmt_metrics(cand['raw'])}")
    print(f"VITERBI {_fmt_metrics(cand['viterbi'])}")
    # DELTA = how much Viterbi decoding improves over raw argmax.
    print(f"DELTA span_F1={cand['viterbi']['span_f1']-cand['raw']['span_f1']:+.4f} "
          f"P={cand['viterbi']['span_precision']-cand['raw']['span_precision']:+.4f} "
          f"R={cand['viterbi']['span_recall']-cand['raw']['span_recall']:+.4f}")

    ref = None
    pair = None
    if not args.no_base:
        # Teacher evaluated on the identical slice so the gap is apples-to-apples.
        ref = _eval_one_model(
            TEACHER, tokenizer, eval_ds, label2id, id2label, o_id,
            bioes_trans, bioes_init,
            args.batch_size, args.max_length, device, dtype,
            label="openmed-base",
        )
        print(f"\n=== {ref['label']} (teacher) ===")
        print(f"VITERBI {_fmt_metrics(ref['viterbi'])}")
        pair = _pairwise(cand["docs"], ref["docs"], id2label, o_id)
        # Positive d = teacher still ahead on viterbi span F1.
        d = ref["viterbi"]["span_f1"] - cand["viterbi"]["span_f1"]
        print(f"\nGap to teacher (viterbi span_F1): {d:+.4f}")

    # Reports
    sample_text = ("Patient Sarah Johnson (DOB 03/15/1985), MRN 4872910, "
                   "phone 415-555-0123, email sarah.johnson@example.com, "
                   "credit card 4111-1111-1111-1111.")
    _write_infer_log(out_dir / "infer.log", cand, ref, args, sample_text,
                     tokenizer, device, dtype)
    _write_compare_log(out_dir / "compare.log", cand, ref, pair, args)
    if not args.no_plots:
        _render_plots(cand, ref, pair, out_dir,
                      cand_label=cand["label"],
                      ref_label=ref["label"] if ref else "")
    print("\n[done]")
|
| 1348 |
+
|
| 1349 |
+
|
| 1350 |
+
# Standard entry guard: run the benchmark only when executed as a script,
# not when imported for its helpers.
if __name__ == "__main__":
    main()
|
compare.log
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Benchmark: A: openmed-base vs B: haremb
|
| 2 |
+
Dataset: nvidia/Nemotron-PII, split=test, eval_pct=1.0, ctx=1024, seed=42, n_docs=1000
|
| 3 |
+
Eval tokens scored: 212,909
|
| 4 |
+
|
| 5 |
+
=== Aggregate ===
|
| 6 |
+
A: openmed-base RAW span_F1=0.9174 P=0.9125 R=0.9223 token_acc=0.9895 non_o_recall=0.9685 spans=8627/8720/7957
|
| 7 |
+
A: openmed-base VITERBI span_F1=0.9434 P=0.9531 R=0.9338 token_acc=0.9900 non_o_recall=0.9703 spans=8627/8452/8056
|
| 8 |
+
B: haremb RAW span_F1=0.7741 P=0.7186 R=0.8388 token_acc=0.9831 non_o_recall=0.9467 spans=8627/10069/7236
|
| 9 |
+
B: haremb VITERBI span_F1=0.9288 P=0.9396 R=0.9182 token_acc=0.9885 non_o_recall=0.9637 spans=8627/8430/7921
|
| 10 |
+
|
| 11 |
+
Gap B vs A (viterbi span_F1): -0.0146
|
| 12 |
+
|
| 13 |
+
Throughput: A: openmed-base 3293 tok/s (1399.61M params)
|
| 14 |
+
B: haremb 6343 tok/s (287.11M params)
|
| 15 |
+
|
| 16 |
+
=== Performance ===
|
| 17 |
+
metric A: openmed-base B: haremb B vs A
|
| 18 |
+
------------------------------- --------------- --------- -------------
|
| 19 |
+
total params (M) 1399.61 287.11 4.87× smaller
|
| 20 |
+
dense params (M) 139.35 129.58 1.08× smaller
|
| 21 |
+
MoE expert params (M) 1260.26 157.53 8.00× smaller
|
| 22 |
+
active params/token (M, mem) 178.73 134.50 1.33× less
|
| 23 |
+
compute params/token (M, FLOPs) 50.69 6.46 7.85× cheaper
|
| 24 |
+
GFLOP / token 0.1014 0.0129 7.85× cheaper
|
| 25 |
+
disk size (MiB) 547.6
|
| 26 |
+
weights in RAM (MiB) 2669.5 547.6 4.87× smaller
|
| 27 |
+
peak GPU mem eval (MiB) 3376.2 1248.6 2.70× less
|
| 28 |
+
throughput (tok/s) 3293 6343 1.93× faster
|
| 29 |
+
|
| 30 |
+
=== Pairwise (viterbi, all gold tokens) — A: openmed-base vs B: haremb ===
|
| 31 |
+
both_correct 209830 (98.55%)
|
| 32 |
+
only_A_correct 958 (0.45%)
|
| 33 |
+
only_B_correct 633 (0.30%)
|
| 34 |
+
both_wrong 1488 (0.70%)
|
| 35 |
+
|
| 36 |
+
=== Pairwise (viterbi, gold non-O tokens) — A: openmed-base vs B: haremb ===
|
| 37 |
+
both_correct 43902 (95.47%)
|
| 38 |
+
only_A_correct 717 (1.56%)
|
| 39 |
+
only_B_correct 417 (0.91%)
|
| 40 |
+
both_wrong 951 (2.07%)
|
| 41 |
+
|
| 42 |
+
=== Worst B-net wins by gold category — A: openmed-base ahead (top 15) ===
|
| 43 |
+
company_name net_B= -142 A_only= 216 B_only= 74 both_wrong= 119
|
| 44 |
+
first_name net_B= -75 A_only= 82 B_only= 7 both_wrong= 19
|
| 45 |
+
last_name net_B= -65 A_only= 67 B_only= 2 both_wrong= 38
|
| 46 |
+
occupation net_B= -55 A_only= 79 B_only= 24 both_wrong= 286
|
| 47 |
+
device_identifier net_B= -29 A_only= 29 B_only= 0 both_wrong= 0
|
| 48 |
+
user_name net_B= -26 A_only= 29 B_only= 3 both_wrong= 10
|
| 49 |
+
city net_B= -16 A_only= 32 B_only= 16 both_wrong= 36
|
| 50 |
+
street_address net_B= -13 A_only= 14 B_only= 1 both_wrong= 0
|
| 51 |
+
date_of_birth net_B= -12 A_only= 12 B_only= 0 both_wrong= 0
|
| 52 |
+
email net_B= -8 A_only= 9 B_only= 1 both_wrong= 0
|
| 53 |
+
medical_record_number net_B= -7 A_only= 7 B_only= 0 both_wrong= 0
|
| 54 |
+
phone_number net_B= -6 A_only= 6 B_only= 0 both_wrong= 0
|
| 55 |
+
account_number net_B= -6 A_only= 10 B_only= 4 both_wrong= 0
|
| 56 |
+
tax_id net_B= -6 A_only= 6 B_only= 0 both_wrong= 0
|
| 57 |
+
race_ethnicity net_B= -5 A_only= 15 B_only= 10 both_wrong= 11
|
| 58 |
+
|
| 59 |
+
=== Best B-net wins by gold category — B: haremb ahead (top 15) ===
|
| 60 |
+
date net_B= +46 A_only= 15 B_only= 61 both_wrong= 145
|
| 61 |
+
fax_number net_B= +29 A_only= 1 B_only= 30 both_wrong= 44
|
| 62 |
+
unique_id net_B= +26 A_only= 4 B_only= 30 both_wrong= 0
|
| 63 |
+
ssn net_B= +18 A_only= 0 B_only= 18 both_wrong= 0
|
| 64 |
+
time net_B= +12 A_only= 9 B_only= 21 both_wrong= 87
|
| 65 |
+
political_view net_B= +11 A_only= 3 B_only= 14 both_wrong= 6
|
| 66 |
+
coordinate net_B= +11 A_only= 0 B_only= 11 both_wrong= 0
|
| 67 |
+
customer_id net_B= +8 A_only= 4 B_only= 12 both_wrong= 7
|
| 68 |
+
certificate_license_number net_B= +7 A_only= 0 B_only= 7 both_wrong= 0
|
| 69 |
+
education_level net_B= +6 A_only= 2 B_only= 8 both_wrong= 28
|
| 70 |
+
state net_B= +6 A_only= 13 B_only= 19 both_wrong= 20
|
| 71 |
+
blood_type net_B= +6 A_only= 0 B_only= 6 both_wrong= 0
|
| 72 |
+
gender net_B= +2 A_only= 0 B_only= 2 both_wrong= 2
|
| 73 |
+
http_cookie net_B= +2 A_only= 0 B_only= 2 both_wrong= 3
|
| 74 |
+
country net_B= +2 A_only= 0 B_only= 2 both_wrong= 19
|
| 75 |
+
|
| 76 |
+
=== Per-category span F1 (viterbi) ===
|
| 77 |
+
-- A: openmed-base --
|
| 78 |
+
account_number F1=0.9929 P=0.9929 R=0.9929 (140/140/139)
|
| 79 |
+
age F1=0.8840 P=0.8511 R=0.9195 (87/94/80)
|
| 80 |
+
api_key F1=0.9921 P=0.9844 R=1.0000 (63/64/63)
|
| 81 |
+
bank_routing_number F1=0.9867 P=0.9867 R=0.9867 (75/75/74)
|
| 82 |
+
biometric_identifier F1=1.0000 P=1.0000 R=1.0000 (113/113/113)
|
| 83 |
+
blood_type F1=0.9032 P=0.9032 R=0.9032 (62/62/56)
|
| 84 |
+
certificate_license_number F1=0.9697 P=1.0000 R=0.9412 (34/32/32)
|
| 85 |
+
city F1=0.9154 P=0.9583 R=0.8762 (210/192/184)
|
| 86 |
+
company_name F1=0.8824 P=0.9143 R=0.8526 (563/525/480)
|
| 87 |
+
coordinate F1=0.8000 P=0.8000 R=0.8000 (55/55/44)
|
| 88 |
+
country F1=0.9431 P=0.9324 R=0.9539 (217/222/207)
|
| 89 |
+
county F1=0.9519 P=0.9612 R=0.9429 (105/103/99)
|
| 90 |
+
credit_debit_card F1=0.9967 P=0.9934 R=1.0000 (150/151/150)
|
| 91 |
+
customer_id F1=0.9849 P=1.0000 R=0.9703 (202/196/196)
|
| 92 |
+
cvv F1=0.9787 P=1.0000 R=0.9583 (48/46/46)
|
| 93 |
+
date F1=0.9440 P=0.9571 R=0.9312 (814/792/758)
|
| 94 |
+
date_of_birth F1=1.0000 P=1.0000 R=1.0000 (164/164/164)
|
| 95 |
+
date_time F1=0.9635 P=0.9429 R=0.9851 (134/140/132)
|
| 96 |
+
device_identifier F1=0.9714 P=0.9444 R=1.0000 (51/54/51)
|
| 97 |
+
education_level F1=0.9091 P=0.9524 R=0.8696 (92/84/80)
|
| 98 |
+
email F1=0.9971 P=0.9961 R=0.9980 (511/512/510)
|
| 99 |
+
employee_id F1=0.9948 P=1.0000 R=0.9896 (96/95/95)
|
| 100 |
+
employment_status F1=0.9478 P=0.9593 R=0.9365 (126/123/118)
|
| 101 |
+
fax_number F1=0.9091 P=0.9848 R=0.8442 (77/66/65)
|
| 102 |
+
first_name F1=0.9766 P=0.9716 R=0.9816 (871/880/855)
|
| 103 |
+
gender F1=0.9737 P=0.9867 R=0.9610 (77/75/74)
|
| 104 |
+
health_plan_beneficiary_number F1=1.0000 P=1.0000 R=1.0000 (103/103/103)
|
| 105 |
+
http_cookie F1=0.9307 P=0.9400 R=0.9216 (51/50/47)
|
| 106 |
+
ipv4 F1=1.0000 P=1.0000 R=1.0000 (59/59/59)
|
| 107 |
+
ipv6 F1=1.0000 P=1.0000 R=1.0000 (21/21/21)
|
| 108 |
+
language F1=0.9000 P=0.9000 R=0.9000 (90/90/81)
|
| 109 |
+
last_name F1=0.9744 P=0.9767 R=0.9721 (646/643/628)
|
| 110 |
+
license_plate F1=1.0000 P=1.0000 R=1.0000 (55/55/55)
|
| 111 |
+
mac_address F1=1.0000 P=1.0000 R=1.0000 (30/30/30)
|
| 112 |
+
medical_record_number F1=1.0000 P=1.0000 R=1.0000 (103/103/103)
|
| 113 |
+
national_id F1=1.0000 P=1.0000 R=1.0000 (28/28/28)
|
| 114 |
+
occupation F1=0.6522 P=0.7721 R=0.5645 (372/272/210)
|
| 115 |
+
password F1=0.9217 P=0.9636 R=0.8833 (60/55/53)
|
| 116 |
+
phone_number F1=0.9751 P=0.9514 R=1.0000 (235/247/235)
|
| 117 |
+
pin F1=0.9302 P=0.8955 R=0.9677 (62/67/60)
|
| 118 |
+
political_view F1=0.8387 P=0.8125 R=0.8667 (45/48/39)
|
| 119 |
+
postcode F1=0.9934 P=0.9868 R=1.0000 (75/76/75)
|
| 120 |
+
race_ethnicity F1=0.8889 P=0.8889 R=0.8889 (81/81/72)
|
| 121 |
+
religious_belief F1=0.8936 P=0.8750 R=0.9130 (46/48/42)
|
| 122 |
+
sexuality F1=0.9667 P=1.0000 R=0.9355 (31/29/29)
|
| 123 |
+
ssn F1=0.9440 P=0.9365 R=0.9516 (62/63/59)
|
| 124 |
+
state F1=0.9198 P=0.9399 R=0.9005 (191/183/172)
|
| 125 |
+
street_address F1=0.9894 P=0.9842 R=0.9947 (188/190/187)
|
| 126 |
+
swift_bic F1=0.9905 P=0.9811 R=1.0000 (52/53/52)
|
| 127 |
+
tax_id F1=1.0000 P=1.0000 R=1.0000 (15/15/15)
|
| 128 |
+
time F1=0.8209 P=0.8514 R=0.7926 (188/175/149)
|
| 129 |
+
unique_id F1=0.9600 P=1.0000 R=0.9231 (13/12/12)
|
| 130 |
+
url F1=0.9725 P=0.9687 R=0.9763 (380/383/371)
|
| 131 |
+
user_name F1=0.9497 P=0.9264 R=0.9742 (155/163/151)
|
| 132 |
+
vehicle_identifier F1=0.9815 P=0.9636 R=1.0000 (53/55/53)
|
| 133 |
+
-- B: haremb --
|
| 134 |
+
account_number F1=0.9751 P=0.9716 R=0.9786 (140/141/137)
|
| 135 |
+
age F1=0.8571 P=0.8211 R=0.8966 (87/95/78)
|
| 136 |
+
api_key F1=0.9921 P=0.9844 R=1.0000 (63/64/63)
|
| 137 |
+
bank_routing_number F1=0.9933 P=1.0000 R=0.9867 (75/74/74)
|
| 138 |
+
biometric_identifier F1=1.0000 P=1.0000 R=1.0000 (113/113/113)
|
| 139 |
+
blood_type F1=1.0000 P=1.0000 R=1.0000 (62/62/62)
|
| 140 |
+
certificate_license_number F1=0.9855 P=0.9714 R=1.0000 (34/35/34)
|
| 141 |
+
city F1=0.8932 P=0.9109 R=0.8762 (210/202/184)
|
| 142 |
+
company_name F1=0.7766 P=0.8120 R=0.7442 (563/516/419)
|
| 143 |
+
coordinate F1=1.0000 P=1.0000 R=1.0000 (55/55/55)
|
| 144 |
+
country F1=0.9543 P=0.9457 R=0.9631 (217/221/209)
|
| 145 |
+
county F1=0.9340 P=0.9252 R=0.9429 (105/107/99)
|
| 146 |
+
credit_debit_card F1=0.9934 P=0.9868 R=1.0000 (150/152/150)
|
| 147 |
+
customer_id F1=0.9779 P=0.9707 R=0.9851 (202/205/199)
|
| 148 |
+
cvv F1=0.9792 P=0.9792 R=0.9792 (48/48/47)
|
| 149 |
+
date F1=0.9510 P=0.9599 R=0.9423 (814/799/767)
|
| 150 |
+
date_of_birth F1=0.9939 P=1.0000 R=0.9878 (164/162/162)
|
| 151 |
+
date_time F1=0.9635 P=0.9429 R=0.9851 (134/140/132)
|
| 152 |
+
device_identifier F1=0.9515 P=0.9423 R=0.9608 (51/52/49)
|
| 153 |
+
education_level F1=0.9091 P=0.9524 R=0.8696 (92/84/80)
|
| 154 |
+
email F1=0.9912 P=0.9883 R=0.9941 (511/514/508)
|
| 155 |
+
employee_id F1=0.9895 P=1.0000 R=0.9792 (96/94/94)
|
| 156 |
+
employment_status F1=0.9562 P=0.9600 R=0.9524 (126/125/120)
|
| 157 |
+
fax_number F1=0.9396 P=0.9722 R=0.9091 (77/72/70)
|
| 158 |
+
first_name F1=0.9299 P=0.9231 R=0.9369 (871/884/816)
|
| 159 |
+
gender F1=0.9870 P=0.9870 R=0.9870 (77/77/76)
|
| 160 |
+
health_plan_beneficiary_number F1=1.0000 P=1.0000 R=1.0000 (103/103/103)
|
| 161 |
+
http_cookie F1=0.9608 P=0.9608 R=0.9608 (51/51/49)
|
| 162 |
+
ipv4 F1=1.0000 P=1.0000 R=1.0000 (59/59/59)
|
| 163 |
+
ipv6 F1=1.0000 P=1.0000 R=1.0000 (21/21/21)
|
| 164 |
+
language F1=0.8966 P=0.9286 R=0.8667 (90/84/78)
|
| 165 |
+
last_name F1=0.9308 P=0.9457 R=0.9164 (646/626/592)
|
| 166 |
+
license_plate F1=1.0000 P=1.0000 R=1.0000 (55/55/55)
|
| 167 |
+
mac_address F1=1.0000 P=1.0000 R=1.0000 (30/30/30)
|
| 168 |
+
medical_record_number F1=0.9903 P=0.9903 R=0.9903 (103/103/102)
|
| 169 |
+
national_id F1=0.9825 P=0.9655 R=1.0000 (28/29/28)
|
| 170 |
+
occupation F1=0.5981 P=0.7440 R=0.5000 (372/250/186)
|
| 171 |
+
password F1=0.9391 P=0.9818 R=0.9000 (60/55/54)
|
| 172 |
+
phone_number F1=0.9730 P=0.9512 R=0.9957 (235/246/234)
|
| 173 |
+
pin F1=0.9508 P=0.9667 R=0.9355 (62/60/58)
|
| 174 |
+
political_view F1=0.8723 P=0.8367 R=0.9111 (45/49/41)
|
| 175 |
+
postcode F1=0.9934 P=0.9868 R=1.0000 (75/76/75)
|
| 176 |
+
race_ethnicity F1=0.8590 P=0.8933 R=0.8272 (81/75/67)
|
| 177 |
+
religious_belief F1=0.9348 P=0.9348 R=0.9348 (46/46/43)
|
| 178 |
+
sexuality F1=0.9492 P=1.0000 R=0.9032 (31/28/28)
|
| 179 |
+
ssn F1=0.9688 P=0.9394 R=1.0000 (62/66/62)
|
| 180 |
+
state F1=0.9105 P=0.9153 R=0.9058 (191/189/173)
|
| 181 |
+
street_address F1=0.9894 P=0.9894 R=0.9894 (188/188/186)
|
| 182 |
+
swift_bic F1=0.9905 P=0.9811 R=1.0000 (52/53/52)
|
| 183 |
+
tax_id F1=0.9655 P=1.0000 R=0.9333 (15/14/14)
|
| 184 |
+
time F1=0.8421 P=0.8786 R=0.8085 (188/173/152)
|
| 185 |
+
unique_id F1=0.8571 P=0.8000 R=0.9231 (13/15/12)
|
| 186 |
+
url F1=0.9752 P=0.9688 R=0.9816 (380/385/373)
|
| 187 |
+
user_name F1=0.9416 P=0.9477 R=0.9355 (155/153/145)
|
| 188 |
+
vehicle_identifier F1=0.9630 P=0.9455 R=0.9811 (53/55/52)
|
config.json
ADDED
|
@@ -0,0 +1,788 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"HaremPiiForTokenClassification"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": true,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"auto_map": {
|
| 8 |
+
"AutoConfig": "configuration_haremb_pii.HaremPiiConfig",
|
| 9 |
+
"AutoModelForTokenClassification": "modeling_haremb_pii.HaremPiiForTokenClassification"
|
| 10 |
+
},
|
| 11 |
+
"bos_token_id": null,
|
| 12 |
+
"classifier_dropout": 0.0,
|
| 13 |
+
"default_n_ctx": 128000,
|
| 14 |
+
"dtype": "bfloat16",
|
| 15 |
+
"eos_token_id": 199999,
|
| 16 |
+
"head_dim": 64,
|
| 17 |
+
"hidden_act": "silu",
|
| 18 |
+
"hidden_size": 640,
|
| 19 |
+
"id2label": {
|
| 20 |
+
"0": "O",
|
| 21 |
+
"1": "B-account_number",
|
| 22 |
+
"2": "I-account_number",
|
| 23 |
+
"3": "E-account_number",
|
| 24 |
+
"4": "S-account_number",
|
| 25 |
+
"5": "B-age",
|
| 26 |
+
"6": "I-age",
|
| 27 |
+
"7": "E-age",
|
| 28 |
+
"8": "S-age",
|
| 29 |
+
"9": "B-api_key",
|
| 30 |
+
"10": "I-api_key",
|
| 31 |
+
"11": "E-api_key",
|
| 32 |
+
"12": "S-api_key",
|
| 33 |
+
"13": "B-bank_routing_number",
|
| 34 |
+
"14": "I-bank_routing_number",
|
| 35 |
+
"15": "E-bank_routing_number",
|
| 36 |
+
"16": "S-bank_routing_number",
|
| 37 |
+
"17": "B-biometric_identifier",
|
| 38 |
+
"18": "I-biometric_identifier",
|
| 39 |
+
"19": "E-biometric_identifier",
|
| 40 |
+
"20": "S-biometric_identifier",
|
| 41 |
+
"21": "B-blood_type",
|
| 42 |
+
"22": "I-blood_type",
|
| 43 |
+
"23": "E-blood_type",
|
| 44 |
+
"24": "S-blood_type",
|
| 45 |
+
"25": "B-certificate_license_number",
|
| 46 |
+
"26": "I-certificate_license_number",
|
| 47 |
+
"27": "E-certificate_license_number",
|
| 48 |
+
"28": "S-certificate_license_number",
|
| 49 |
+
"29": "B-city",
|
| 50 |
+
"30": "I-city",
|
| 51 |
+
"31": "E-city",
|
| 52 |
+
"32": "S-city",
|
| 53 |
+
"33": "B-company_name",
|
| 54 |
+
"34": "I-company_name",
|
| 55 |
+
"35": "E-company_name",
|
| 56 |
+
"36": "S-company_name",
|
| 57 |
+
"37": "B-coordinate",
|
| 58 |
+
"38": "I-coordinate",
|
| 59 |
+
"39": "E-coordinate",
|
| 60 |
+
"40": "S-coordinate",
|
| 61 |
+
"41": "B-country",
|
| 62 |
+
"42": "I-country",
|
| 63 |
+
"43": "E-country",
|
| 64 |
+
"44": "S-country",
|
| 65 |
+
"45": "B-county",
|
| 66 |
+
"46": "I-county",
|
| 67 |
+
"47": "E-county",
|
| 68 |
+
"48": "S-county",
|
| 69 |
+
"49": "B-credit_debit_card",
|
| 70 |
+
"50": "I-credit_debit_card",
|
| 71 |
+
"51": "E-credit_debit_card",
|
| 72 |
+
"52": "S-credit_debit_card",
|
| 73 |
+
"53": "B-customer_id",
|
| 74 |
+
"54": "I-customer_id",
|
| 75 |
+
"55": "E-customer_id",
|
| 76 |
+
"56": "S-customer_id",
|
| 77 |
+
"57": "B-cvv",
|
| 78 |
+
"58": "I-cvv",
|
| 79 |
+
"59": "E-cvv",
|
| 80 |
+
"60": "S-cvv",
|
| 81 |
+
"61": "B-date",
|
| 82 |
+
"62": "I-date",
|
| 83 |
+
"63": "E-date",
|
| 84 |
+
"64": "S-date",
|
| 85 |
+
"65": "B-date_of_birth",
|
| 86 |
+
"66": "I-date_of_birth",
|
| 87 |
+
"67": "E-date_of_birth",
|
| 88 |
+
"68": "S-date_of_birth",
|
| 89 |
+
"69": "B-date_time",
|
| 90 |
+
"70": "I-date_time",
|
| 91 |
+
"71": "E-date_time",
|
| 92 |
+
"72": "S-date_time",
|
| 93 |
+
"73": "B-device_identifier",
|
| 94 |
+
"74": "I-device_identifier",
|
| 95 |
+
"75": "E-device_identifier",
|
| 96 |
+
"76": "S-device_identifier",
|
| 97 |
+
"77": "B-education_level",
|
| 98 |
+
"78": "I-education_level",
|
| 99 |
+
"79": "E-education_level",
|
| 100 |
+
"80": "S-education_level",
|
| 101 |
+
"81": "B-email",
|
| 102 |
+
"82": "I-email",
|
| 103 |
+
"83": "E-email",
|
| 104 |
+
"84": "S-email",
|
| 105 |
+
"85": "B-employee_id",
|
| 106 |
+
"86": "I-employee_id",
|
| 107 |
+
"87": "E-employee_id",
|
| 108 |
+
"88": "S-employee_id",
|
| 109 |
+
"89": "B-employment_status",
|
| 110 |
+
"90": "I-employment_status",
|
| 111 |
+
"91": "E-employment_status",
|
| 112 |
+
"92": "S-employment_status",
|
| 113 |
+
"93": "B-fax_number",
|
| 114 |
+
"94": "I-fax_number",
|
| 115 |
+
"95": "E-fax_number",
|
| 116 |
+
"96": "S-fax_number",
|
| 117 |
+
"97": "B-first_name",
|
| 118 |
+
"98": "I-first_name",
|
| 119 |
+
"99": "E-first_name",
|
| 120 |
+
"100": "S-first_name",
|
| 121 |
+
"101": "B-gender",
|
| 122 |
+
"102": "I-gender",
|
| 123 |
+
"103": "E-gender",
|
| 124 |
+
"104": "S-gender",
|
| 125 |
+
"105": "B-health_plan_beneficiary_number",
|
| 126 |
+
"106": "I-health_plan_beneficiary_number",
|
| 127 |
+
"107": "E-health_plan_beneficiary_number",
|
| 128 |
+
"108": "S-health_plan_beneficiary_number",
|
| 129 |
+
"109": "B-http_cookie",
|
| 130 |
+
"110": "I-http_cookie",
|
| 131 |
+
"111": "E-http_cookie",
|
| 132 |
+
"112": "S-http_cookie",
|
| 133 |
+
"113": "B-ipv4",
|
| 134 |
+
"114": "I-ipv4",
|
| 135 |
+
"115": "E-ipv4",
|
| 136 |
+
"116": "S-ipv4",
|
| 137 |
+
"117": "B-ipv6",
|
| 138 |
+
"118": "I-ipv6",
|
| 139 |
+
"119": "E-ipv6",
|
| 140 |
+
"120": "S-ipv6",
|
| 141 |
+
"121": "B-language",
|
| 142 |
+
"122": "I-language",
|
| 143 |
+
"123": "E-language",
|
| 144 |
+
"124": "S-language",
|
| 145 |
+
"125": "B-last_name",
|
| 146 |
+
"126": "I-last_name",
|
| 147 |
+
"127": "E-last_name",
|
| 148 |
+
"128": "S-last_name",
|
| 149 |
+
"129": "B-license_plate",
|
| 150 |
+
"130": "I-license_plate",
|
| 151 |
+
"131": "E-license_plate",
|
| 152 |
+
"132": "S-license_plate",
|
| 153 |
+
"133": "B-mac_address",
|
| 154 |
+
"134": "I-mac_address",
|
| 155 |
+
"135": "E-mac_address",
|
| 156 |
+
"136": "S-mac_address",
|
| 157 |
+
"137": "B-medical_record_number",
|
| 158 |
+
"138": "I-medical_record_number",
|
| 159 |
+
"139": "E-medical_record_number",
|
| 160 |
+
"140": "S-medical_record_number",
|
| 161 |
+
"141": "B-national_id",
|
| 162 |
+
"142": "I-national_id",
|
| 163 |
+
"143": "E-national_id",
|
| 164 |
+
"144": "S-national_id",
|
| 165 |
+
"145": "B-occupation",
|
| 166 |
+
"146": "I-occupation",
|
| 167 |
+
"147": "E-occupation",
|
| 168 |
+
"148": "S-occupation",
|
| 169 |
+
"149": "B-password",
|
| 170 |
+
"150": "I-password",
|
| 171 |
+
"151": "E-password",
|
| 172 |
+
"152": "S-password",
|
| 173 |
+
"153": "B-phone_number",
|
| 174 |
+
"154": "I-phone_number",
|
| 175 |
+
"155": "E-phone_number",
|
| 176 |
+
"156": "S-phone_number",
|
| 177 |
+
"157": "B-pin",
|
| 178 |
+
"158": "I-pin",
|
| 179 |
+
"159": "E-pin",
|
| 180 |
+
"160": "S-pin",
|
| 181 |
+
"161": "B-political_view",
|
| 182 |
+
"162": "I-political_view",
|
| 183 |
+
"163": "E-political_view",
|
| 184 |
+
"164": "S-political_view",
|
| 185 |
+
"165": "B-postcode",
|
| 186 |
+
"166": "I-postcode",
|
| 187 |
+
"167": "E-postcode",
|
| 188 |
+
"168": "S-postcode",
|
| 189 |
+
"169": "B-race_ethnicity",
|
| 190 |
+
"170": "I-race_ethnicity",
|
| 191 |
+
"171": "E-race_ethnicity",
|
| 192 |
+
"172": "S-race_ethnicity",
|
| 193 |
+
"173": "B-religious_belief",
|
| 194 |
+
"174": "I-religious_belief",
|
| 195 |
+
"175": "E-religious_belief",
|
| 196 |
+
"176": "S-religious_belief",
|
| 197 |
+
"177": "B-sexuality",
|
| 198 |
+
"178": "I-sexuality",
|
| 199 |
+
"179": "E-sexuality",
|
| 200 |
+
"180": "S-sexuality",
|
| 201 |
+
"181": "B-ssn",
|
| 202 |
+
"182": "I-ssn",
|
| 203 |
+
"183": "E-ssn",
|
| 204 |
+
"184": "S-ssn",
|
| 205 |
+
"185": "B-state",
|
| 206 |
+
"186": "I-state",
|
| 207 |
+
"187": "E-state",
|
| 208 |
+
"188": "S-state",
|
| 209 |
+
"189": "B-street_address",
|
| 210 |
+
"190": "I-street_address",
|
| 211 |
+
"191": "E-street_address",
|
| 212 |
+
"192": "S-street_address",
|
| 213 |
+
"193": "B-swift_bic",
|
| 214 |
+
"194": "I-swift_bic",
|
| 215 |
+
"195": "E-swift_bic",
|
| 216 |
+
"196": "S-swift_bic",
|
| 217 |
+
"197": "B-tax_id",
|
| 218 |
+
"198": "I-tax_id",
|
| 219 |
+
"199": "E-tax_id",
|
| 220 |
+
"200": "S-tax_id",
|
| 221 |
+
"201": "B-time",
|
| 222 |
+
"202": "I-time",
|
| 223 |
+
"203": "E-time",
|
| 224 |
+
"204": "S-time",
|
| 225 |
+
"205": "B-unique_id",
|
| 226 |
+
"206": "I-unique_id",
|
| 227 |
+
"207": "E-unique_id",
|
| 228 |
+
"208": "S-unique_id",
|
| 229 |
+
"209": "B-url",
|
| 230 |
+
"210": "I-url",
|
| 231 |
+
"211": "E-url",
|
| 232 |
+
"212": "S-url",
|
| 233 |
+
"213": "B-user_name",
|
| 234 |
+
"214": "I-user_name",
|
| 235 |
+
"215": "E-user_name",
|
| 236 |
+
"216": "S-user_name",
|
| 237 |
+
"217": "B-vehicle_identifier",
|
| 238 |
+
"218": "I-vehicle_identifier",
|
| 239 |
+
"219": "E-vehicle_identifier",
|
| 240 |
+
"220": "S-vehicle_identifier"
|
| 241 |
+
},
|
| 242 |
+
"initial_context_length": 4096,
|
| 243 |
+
"initializer_range": 0.02,
|
| 244 |
+
"intermediate_size": 640,
|
| 245 |
+
"label2id": {
|
| 246 |
+
"B-account_number": 1,
|
| 247 |
+
"B-age": 5,
|
| 248 |
+
"B-api_key": 9,
|
| 249 |
+
"B-bank_routing_number": 13,
|
| 250 |
+
"B-biometric_identifier": 17,
|
| 251 |
+
"B-blood_type": 21,
|
| 252 |
+
"B-certificate_license_number": 25,
|
| 253 |
+
"B-city": 29,
|
| 254 |
+
"B-company_name": 33,
|
| 255 |
+
"B-coordinate": 37,
|
| 256 |
+
"B-country": 41,
|
| 257 |
+
"B-county": 45,
|
| 258 |
+
"B-credit_debit_card": 49,
|
| 259 |
+
"B-customer_id": 53,
|
| 260 |
+
"B-cvv": 57,
|
| 261 |
+
"B-date": 61,
|
| 262 |
+
"B-date_of_birth": 65,
|
| 263 |
+
"B-date_time": 69,
|
| 264 |
+
"B-device_identifier": 73,
|
| 265 |
+
"B-education_level": 77,
|
| 266 |
+
"B-email": 81,
|
| 267 |
+
"B-employee_id": 85,
|
| 268 |
+
"B-employment_status": 89,
|
| 269 |
+
"B-fax_number": 93,
|
| 270 |
+
"B-first_name": 97,
|
| 271 |
+
"B-gender": 101,
|
| 272 |
+
"B-health_plan_beneficiary_number": 105,
|
| 273 |
+
"B-http_cookie": 109,
|
| 274 |
+
"B-ipv4": 113,
|
| 275 |
+
"B-ipv6": 117,
|
| 276 |
+
"B-language": 121,
|
| 277 |
+
"B-last_name": 125,
|
| 278 |
+
"B-license_plate": 129,
|
| 279 |
+
"B-mac_address": 133,
|
| 280 |
+
"B-medical_record_number": 137,
|
| 281 |
+
"B-national_id": 141,
|
| 282 |
+
"B-occupation": 145,
|
| 283 |
+
"B-password": 149,
|
| 284 |
+
"B-phone_number": 153,
|
| 285 |
+
"B-pin": 157,
|
| 286 |
+
"B-political_view": 161,
|
| 287 |
+
"B-postcode": 165,
|
| 288 |
+
"B-race_ethnicity": 169,
|
| 289 |
+
"B-religious_belief": 173,
|
| 290 |
+
"B-sexuality": 177,
|
| 291 |
+
"B-ssn": 181,
|
| 292 |
+
"B-state": 185,
|
| 293 |
+
"B-street_address": 189,
|
| 294 |
+
"B-swift_bic": 193,
|
| 295 |
+
"B-tax_id": 197,
|
| 296 |
+
"B-time": 201,
|
| 297 |
+
"B-unique_id": 205,
|
| 298 |
+
"B-url": 209,
|
| 299 |
+
"B-user_name": 213,
|
| 300 |
+
"B-vehicle_identifier": 217,
|
| 301 |
+
"E-account_number": 3,
|
| 302 |
+
"E-age": 7,
|
| 303 |
+
"E-api_key": 11,
|
| 304 |
+
"E-bank_routing_number": 15,
|
| 305 |
+
"E-biometric_identifier": 19,
|
| 306 |
+
"E-blood_type": 23,
|
| 307 |
+
"E-certificate_license_number": 27,
|
| 308 |
+
"E-city": 31,
|
| 309 |
+
"E-company_name": 35,
|
| 310 |
+
"E-coordinate": 39,
|
| 311 |
+
"E-country": 43,
|
| 312 |
+
"E-county": 47,
|
| 313 |
+
"E-credit_debit_card": 51,
|
| 314 |
+
"E-customer_id": 55,
|
| 315 |
+
"E-cvv": 59,
|
| 316 |
+
"E-date": 63,
|
| 317 |
+
"E-date_of_birth": 67,
|
| 318 |
+
"E-date_time": 71,
|
| 319 |
+
"E-device_identifier": 75,
|
| 320 |
+
"E-education_level": 79,
|
| 321 |
+
"E-email": 83,
|
| 322 |
+
"E-employee_id": 87,
|
| 323 |
+
"E-employment_status": 91,
|
| 324 |
+
"E-fax_number": 95,
|
| 325 |
+
"E-first_name": 99,
|
| 326 |
+
"E-gender": 103,
|
| 327 |
+
"E-health_plan_beneficiary_number": 107,
|
| 328 |
+
"E-http_cookie": 111,
|
| 329 |
+
"E-ipv4": 115,
|
| 330 |
+
"E-ipv6": 119,
|
| 331 |
+
"E-language": 123,
|
| 332 |
+
"E-last_name": 127,
|
| 333 |
+
"E-license_plate": 131,
|
| 334 |
+
"E-mac_address": 135,
|
| 335 |
+
"E-medical_record_number": 139,
|
| 336 |
+
"E-national_id": 143,
|
| 337 |
+
"E-occupation": 147,
|
| 338 |
+
"E-password": 151,
|
| 339 |
+
"E-phone_number": 155,
|
| 340 |
+
"E-pin": 159,
|
| 341 |
+
"E-political_view": 163,
|
| 342 |
+
"E-postcode": 167,
|
| 343 |
+
"E-race_ethnicity": 171,
|
| 344 |
+
"E-religious_belief": 175,
|
| 345 |
+
"E-sexuality": 179,
|
| 346 |
+
"E-ssn": 183,
|
| 347 |
+
"E-state": 187,
|
| 348 |
+
"E-street_address": 191,
|
| 349 |
+
"E-swift_bic": 195,
|
| 350 |
+
"E-tax_id": 199,
|
| 351 |
+
"E-time": 203,
|
| 352 |
+
"E-unique_id": 207,
|
| 353 |
+
"E-url": 211,
|
| 354 |
+
"E-user_name": 215,
|
| 355 |
+
"E-vehicle_identifier": 219,
|
| 356 |
+
"I-account_number": 2,
|
| 357 |
+
"I-age": 6,
|
| 358 |
+
"I-api_key": 10,
|
| 359 |
+
"I-bank_routing_number": 14,
|
| 360 |
+
"I-biometric_identifier": 18,
|
| 361 |
+
"I-blood_type": 22,
|
| 362 |
+
"I-certificate_license_number": 26,
|
| 363 |
+
"I-city": 30,
|
| 364 |
+
"I-company_name": 34,
|
| 365 |
+
"I-coordinate": 38,
|
| 366 |
+
"I-country": 42,
|
| 367 |
+
"I-county": 46,
|
| 368 |
+
"I-credit_debit_card": 50,
|
| 369 |
+
"I-customer_id": 54,
|
| 370 |
+
"I-cvv": 58,
|
| 371 |
+
"I-date": 62,
|
| 372 |
+
"I-date_of_birth": 66,
|
| 373 |
+
"I-date_time": 70,
|
| 374 |
+
"I-device_identifier": 74,
|
| 375 |
+
"I-education_level": 78,
|
| 376 |
+
"I-email": 82,
|
| 377 |
+
"I-employee_id": 86,
|
| 378 |
+
"I-employment_status": 90,
|
| 379 |
+
"I-fax_number": 94,
|
| 380 |
+
"I-first_name": 98,
|
| 381 |
+
"I-gender": 102,
|
| 382 |
+
"I-health_plan_beneficiary_number": 106,
|
| 383 |
+
"I-http_cookie": 110,
|
| 384 |
+
"I-ipv4": 114,
|
| 385 |
+
"I-ipv6": 118,
|
| 386 |
+
"I-language": 122,
|
| 387 |
+
"I-last_name": 126,
|
| 388 |
+
"I-license_plate": 130,
|
| 389 |
+
"I-mac_address": 134,
|
| 390 |
+
"I-medical_record_number": 138,
|
| 391 |
+
"I-national_id": 142,
|
| 392 |
+
"I-occupation": 146,
|
| 393 |
+
"I-password": 150,
|
| 394 |
+
"I-phone_number": 154,
|
| 395 |
+
"I-pin": 158,
|
| 396 |
+
"I-political_view": 162,
|
| 397 |
+
"I-postcode": 166,
|
| 398 |
+
"I-race_ethnicity": 170,
|
| 399 |
+
"I-religious_belief": 174,
|
| 400 |
+
"I-sexuality": 178,
|
| 401 |
+
"I-ssn": 182,
|
| 402 |
+
"I-state": 186,
|
| 403 |
+
"I-street_address": 190,
|
| 404 |
+
"I-swift_bic": 194,
|
| 405 |
+
"I-tax_id": 198,
|
| 406 |
+
"I-time": 202,
|
| 407 |
+
"I-unique_id": 206,
|
| 408 |
+
"I-url": 210,
|
| 409 |
+
"I-user_name": 214,
|
| 410 |
+
"I-vehicle_identifier": 218,
|
| 411 |
+
"O": 0,
|
| 412 |
+
"S-account_number": 4,
|
| 413 |
+
"S-age": 8,
|
| 414 |
+
"S-api_key": 12,
|
| 415 |
+
"S-bank_routing_number": 16,
|
| 416 |
+
"S-biometric_identifier": 20,
|
| 417 |
+
"S-blood_type": 24,
|
| 418 |
+
"S-certificate_license_number": 28,
|
| 419 |
+
"S-city": 32,
|
| 420 |
+
"S-company_name": 36,
|
| 421 |
+
"S-coordinate": 40,
|
| 422 |
+
"S-country": 44,
|
| 423 |
+
"S-county": 48,
|
| 424 |
+
"S-credit_debit_card": 52,
|
| 425 |
+
"S-customer_id": 56,
|
| 426 |
+
"S-cvv": 60,
|
| 427 |
+
"S-date": 64,
|
| 428 |
+
"S-date_of_birth": 68,
|
| 429 |
+
"S-date_time": 72,
|
| 430 |
+
"S-device_identifier": 76,
|
| 431 |
+
"S-education_level": 80,
|
| 432 |
+
"S-email": 84,
|
| 433 |
+
"S-employee_id": 88,
|
| 434 |
+
"S-employment_status": 92,
|
| 435 |
+
"S-fax_number": 96,
|
| 436 |
+
"S-first_name": 100,
|
| 437 |
+
"S-gender": 104,
|
| 438 |
+
"S-health_plan_beneficiary_number": 108,
|
| 439 |
+
"S-http_cookie": 112,
|
| 440 |
+
"S-ipv4": 116,
|
| 441 |
+
"S-ipv6": 120,
|
| 442 |
+
"S-language": 124,
|
| 443 |
+
"S-last_name": 128,
|
| 444 |
+
"S-license_plate": 132,
|
| 445 |
+
"S-mac_address": 136,
|
| 446 |
+
"S-medical_record_number": 140,
|
| 447 |
+
"S-national_id": 144,
|
| 448 |
+
"S-occupation": 148,
|
| 449 |
+
"S-password": 152,
|
| 450 |
+
"S-phone_number": 156,
|
| 451 |
+
"S-pin": 160,
|
| 452 |
+
"S-political_view": 164,
|
| 453 |
+
"S-postcode": 168,
|
| 454 |
+
"S-race_ethnicity": 172,
|
| 455 |
+
"S-religious_belief": 176,
|
| 456 |
+
"S-sexuality": 180,
|
| 457 |
+
"S-ssn": 184,
|
| 458 |
+
"S-state": 188,
|
| 459 |
+
"S-street_address": 192,
|
| 460 |
+
"S-swift_bic": 196,
|
| 461 |
+
"S-tax_id": 200,
|
| 462 |
+
"S-time": 204,
|
| 463 |
+
"S-unique_id": 208,
|
| 464 |
+
"S-url": 212,
|
| 465 |
+
"S-user_name": 216,
|
| 466 |
+
"S-vehicle_identifier": 220
|
| 467 |
+
},
|
| 468 |
+
"max_position_embeddings": 131072,
|
| 469 |
+
"model_type": "haremb_pii",
|
| 470 |
+
"num_attention_heads": 14,
|
| 471 |
+
"num_experts_per_tok": 4,
|
| 472 |
+
"num_hidden_layers": 1,
|
| 473 |
+
"num_key_value_heads": 2,
|
| 474 |
+
"num_local_experts": 128,
|
| 475 |
+
"opf_metadata": {
|
| 476 |
+
"category_version": "nemotron_fine_v1",
|
| 477 |
+
"encoding": "o200k_base",
|
| 478 |
+
"inference_contract_version": 1,
|
| 479 |
+
"ner_class_names": [
|
| 480 |
+
"O",
|
| 481 |
+
"B-account_number",
|
| 482 |
+
"I-account_number",
|
| 483 |
+
"E-account_number",
|
| 484 |
+
"S-account_number",
|
| 485 |
+
"B-age",
|
| 486 |
+
"I-age",
|
| 487 |
+
"E-age",
|
| 488 |
+
"S-age",
|
| 489 |
+
"B-api_key",
|
| 490 |
+
"I-api_key",
|
| 491 |
+
"E-api_key",
|
| 492 |
+
"S-api_key",
|
| 493 |
+
"B-bank_routing_number",
|
| 494 |
+
"I-bank_routing_number",
|
| 495 |
+
"E-bank_routing_number",
|
| 496 |
+
"S-bank_routing_number",
|
| 497 |
+
"B-biometric_identifier",
|
| 498 |
+
"I-biometric_identifier",
|
| 499 |
+
"E-biometric_identifier",
|
| 500 |
+
"S-biometric_identifier",
|
| 501 |
+
"B-blood_type",
|
| 502 |
+
"I-blood_type",
|
| 503 |
+
"E-blood_type",
|
| 504 |
+
"S-blood_type",
|
| 505 |
+
"B-certificate_license_number",
|
| 506 |
+
"I-certificate_license_number",
|
| 507 |
+
"E-certificate_license_number",
|
| 508 |
+
"S-certificate_license_number",
|
| 509 |
+
"B-city",
|
| 510 |
+
"I-city",
|
| 511 |
+
"E-city",
|
| 512 |
+
"S-city",
|
| 513 |
+
"B-company_name",
|
| 514 |
+
"I-company_name",
|
| 515 |
+
"E-company_name",
|
| 516 |
+
"S-company_name",
|
| 517 |
+
"B-coordinate",
|
| 518 |
+
"I-coordinate",
|
| 519 |
+
"E-coordinate",
|
| 520 |
+
"S-coordinate",
|
| 521 |
+
"B-country",
|
| 522 |
+
"I-country",
|
| 523 |
+
"E-country",
|
| 524 |
+
"S-country",
|
| 525 |
+
"B-county",
|
| 526 |
+
"I-county",
|
| 527 |
+
"E-county",
|
| 528 |
+
"S-county",
|
| 529 |
+
"B-credit_debit_card",
|
| 530 |
+
"I-credit_debit_card",
|
| 531 |
+
"E-credit_debit_card",
|
| 532 |
+
"S-credit_debit_card",
|
| 533 |
+
"B-customer_id",
|
| 534 |
+
"I-customer_id",
|
| 535 |
+
"E-customer_id",
|
| 536 |
+
"S-customer_id",
|
| 537 |
+
"B-cvv",
|
| 538 |
+
"I-cvv",
|
| 539 |
+
"E-cvv",
|
| 540 |
+
"S-cvv",
|
| 541 |
+
"B-date",
|
| 542 |
+
"I-date",
|
| 543 |
+
"E-date",
|
| 544 |
+
"S-date",
|
| 545 |
+
"B-date_of_birth",
|
| 546 |
+
"I-date_of_birth",
|
| 547 |
+
"E-date_of_birth",
|
| 548 |
+
"S-date_of_birth",
|
| 549 |
+
"B-date_time",
|
| 550 |
+
"I-date_time",
|
| 551 |
+
"E-date_time",
|
| 552 |
+
"S-date_time",
|
| 553 |
+
"B-device_identifier",
|
| 554 |
+
"I-device_identifier",
|
| 555 |
+
"E-device_identifier",
|
| 556 |
+
"S-device_identifier",
|
| 557 |
+
"B-education_level",
|
| 558 |
+
"I-education_level",
|
| 559 |
+
"E-education_level",
|
| 560 |
+
"S-education_level",
|
| 561 |
+
"B-email",
|
| 562 |
+
"I-email",
|
| 563 |
+
"E-email",
|
| 564 |
+
"S-email",
|
| 565 |
+
"B-employee_id",
|
| 566 |
+
"I-employee_id",
|
| 567 |
+
"E-employee_id",
|
| 568 |
+
"S-employee_id",
|
| 569 |
+
"B-employment_status",
|
| 570 |
+
"I-employment_status",
|
| 571 |
+
"E-employment_status",
|
| 572 |
+
"S-employment_status",
|
| 573 |
+
"B-fax_number",
|
| 574 |
+
"I-fax_number",
|
| 575 |
+
"E-fax_number",
|
| 576 |
+
"S-fax_number",
|
| 577 |
+
"B-first_name",
|
| 578 |
+
"I-first_name",
|
| 579 |
+
"E-first_name",
|
| 580 |
+
"S-first_name",
|
| 581 |
+
"B-gender",
|
| 582 |
+
"I-gender",
|
| 583 |
+
"E-gender",
|
| 584 |
+
"S-gender",
|
| 585 |
+
"B-health_plan_beneficiary_number",
|
| 586 |
+
"I-health_plan_beneficiary_number",
|
| 587 |
+
"E-health_plan_beneficiary_number",
|
| 588 |
+
"S-health_plan_beneficiary_number",
|
| 589 |
+
"B-http_cookie",
|
| 590 |
+
"I-http_cookie",
|
| 591 |
+
"E-http_cookie",
|
| 592 |
+
"S-http_cookie",
|
| 593 |
+
"B-ipv4",
|
| 594 |
+
"I-ipv4",
|
| 595 |
+
"E-ipv4",
|
| 596 |
+
"S-ipv4",
|
| 597 |
+
"B-ipv6",
|
| 598 |
+
"I-ipv6",
|
| 599 |
+
"E-ipv6",
|
| 600 |
+
"S-ipv6",
|
| 601 |
+
"B-language",
|
| 602 |
+
"I-language",
|
| 603 |
+
"E-language",
|
| 604 |
+
"S-language",
|
| 605 |
+
"B-last_name",
|
| 606 |
+
"I-last_name",
|
| 607 |
+
"E-last_name",
|
| 608 |
+
"S-last_name",
|
| 609 |
+
"B-license_plate",
|
| 610 |
+
"I-license_plate",
|
| 611 |
+
"E-license_plate",
|
| 612 |
+
"S-license_plate",
|
| 613 |
+
"B-mac_address",
|
| 614 |
+
"I-mac_address",
|
| 615 |
+
"E-mac_address",
|
| 616 |
+
"S-mac_address",
|
| 617 |
+
"B-medical_record_number",
|
| 618 |
+
"I-medical_record_number",
|
| 619 |
+
"E-medical_record_number",
|
| 620 |
+
"S-medical_record_number",
|
| 621 |
+
"B-national_id",
|
| 622 |
+
"I-national_id",
|
| 623 |
+
"E-national_id",
|
| 624 |
+
"S-national_id",
|
| 625 |
+
"B-occupation",
|
| 626 |
+
"I-occupation",
|
| 627 |
+
"E-occupation",
|
| 628 |
+
"S-occupation",
|
| 629 |
+
"B-password",
|
| 630 |
+
"I-password",
|
| 631 |
+
"E-password",
|
| 632 |
+
"S-password",
|
| 633 |
+
"B-phone_number",
|
| 634 |
+
"I-phone_number",
|
| 635 |
+
"E-phone_number",
|
| 636 |
+
"S-phone_number",
|
| 637 |
+
"B-pin",
|
| 638 |
+
"I-pin",
|
| 639 |
+
"E-pin",
|
| 640 |
+
"S-pin",
|
| 641 |
+
"B-political_view",
|
| 642 |
+
"I-political_view",
|
| 643 |
+
"E-political_view",
|
| 644 |
+
"S-political_view",
|
| 645 |
+
"B-postcode",
|
| 646 |
+
"I-postcode",
|
| 647 |
+
"E-postcode",
|
| 648 |
+
"S-postcode",
|
| 649 |
+
"B-race_ethnicity",
|
| 650 |
+
"I-race_ethnicity",
|
| 651 |
+
"E-race_ethnicity",
|
| 652 |
+
"S-race_ethnicity",
|
| 653 |
+
"B-religious_belief",
|
| 654 |
+
"I-religious_belief",
|
| 655 |
+
"E-religious_belief",
|
| 656 |
+
"S-religious_belief",
|
| 657 |
+
"B-sexuality",
|
| 658 |
+
"I-sexuality",
|
| 659 |
+
"E-sexuality",
|
| 660 |
+
"S-sexuality",
|
| 661 |
+
"B-ssn",
|
| 662 |
+
"I-ssn",
|
| 663 |
+
"E-ssn",
|
| 664 |
+
"S-ssn",
|
| 665 |
+
"B-state",
|
| 666 |
+
"I-state",
|
| 667 |
+
"E-state",
|
| 668 |
+
"S-state",
|
| 669 |
+
"B-street_address",
|
| 670 |
+
"I-street_address",
|
| 671 |
+
"E-street_address",
|
| 672 |
+
"S-street_address",
|
| 673 |
+
"B-swift_bic",
|
| 674 |
+
"I-swift_bic",
|
| 675 |
+
"E-swift_bic",
|
| 676 |
+
"S-swift_bic",
|
| 677 |
+
"B-tax_id",
|
| 678 |
+
"I-tax_id",
|
| 679 |
+
"E-tax_id",
|
| 680 |
+
"S-tax_id",
|
| 681 |
+
"B-time",
|
| 682 |
+
"I-time",
|
| 683 |
+
"E-time",
|
| 684 |
+
"S-time",
|
| 685 |
+
"B-unique_id",
|
| 686 |
+
"I-unique_id",
|
| 687 |
+
"E-unique_id",
|
| 688 |
+
"S-unique_id",
|
| 689 |
+
"B-url",
|
| 690 |
+
"I-url",
|
| 691 |
+
"E-url",
|
| 692 |
+
"S-url",
|
| 693 |
+
"B-user_name",
|
| 694 |
+
"I-user_name",
|
| 695 |
+
"E-user_name",
|
| 696 |
+
"S-user_name",
|
| 697 |
+
"B-vehicle_identifier",
|
| 698 |
+
"I-vehicle_identifier",
|
| 699 |
+
"E-vehicle_identifier",
|
| 700 |
+
"S-vehicle_identifier"
|
| 701 |
+
],
|
| 702 |
+
"span_class_names": [
|
| 703 |
+
"O",
|
| 704 |
+
"account_number",
|
| 705 |
+
"age",
|
| 706 |
+
"api_key",
|
| 707 |
+
"bank_routing_number",
|
| 708 |
+
"biometric_identifier",
|
| 709 |
+
"blood_type",
|
| 710 |
+
"certificate_license_number",
|
| 711 |
+
"city",
|
| 712 |
+
"company_name",
|
| 713 |
+
"coordinate",
|
| 714 |
+
"country",
|
| 715 |
+
"county",
|
| 716 |
+
"credit_debit_card",
|
| 717 |
+
"customer_id",
|
| 718 |
+
"cvv",
|
| 719 |
+
"date",
|
| 720 |
+
"date_of_birth",
|
| 721 |
+
"date_time",
|
| 722 |
+
"device_identifier",
|
| 723 |
+
"education_level",
|
| 724 |
+
"email",
|
| 725 |
+
"employee_id",
|
| 726 |
+
"employment_status",
|
| 727 |
+
"fax_number",
|
| 728 |
+
"first_name",
|
| 729 |
+
"gender",
|
| 730 |
+
"health_plan_beneficiary_number",
|
| 731 |
+
"http_cookie",
|
| 732 |
+
"ipv4",
|
| 733 |
+
"ipv6",
|
| 734 |
+
"language",
|
| 735 |
+
"last_name",
|
| 736 |
+
"license_plate",
|
| 737 |
+
"mac_address",
|
| 738 |
+
"medical_record_number",
|
| 739 |
+
"national_id",
|
| 740 |
+
"occupation",
|
| 741 |
+
"password",
|
| 742 |
+
"phone_number",
|
| 743 |
+
"pin",
|
| 744 |
+
"political_view",
|
| 745 |
+
"postcode",
|
| 746 |
+
"race_ethnicity",
|
| 747 |
+
"religious_belief",
|
| 748 |
+
"sexuality",
|
| 749 |
+
"ssn",
|
| 750 |
+
"state",
|
| 751 |
+
"street_address",
|
| 752 |
+
"swift_bic",
|
| 753 |
+
"tax_id",
|
| 754 |
+
"time",
|
| 755 |
+
"unique_id",
|
| 756 |
+
"url",
|
| 757 |
+
"user_name",
|
| 758 |
+
"vehicle_identifier"
|
| 759 |
+
]
|
| 760 |
+
},
|
| 761 |
+
"output_router_logits": false,
|
| 762 |
+
"pad_token_id": 199999,
|
| 763 |
+
"rms_norm_eps": 1e-05,
|
| 764 |
+
"rope_parameters": {
|
| 765 |
+
"beta_fast": 32.0,
|
| 766 |
+
"beta_slow": 1.0,
|
| 767 |
+
"factor": 32.0,
|
| 768 |
+
"original_max_position_embeddings": 4096,
|
| 769 |
+
"rope_theta": 150000.0,
|
| 770 |
+
"rope_type": "yarn",
|
| 771 |
+
"truncate": false
|
| 772 |
+
},
|
| 773 |
+
"router_aux_loss_coef": 0.001,
|
| 774 |
+
"sliding_window": 128,
|
| 775 |
+
"tie_word_embeddings": false,
|
| 776 |
+
"transformers.js_config": {
|
| 777 |
+
"use_external_data_format": {
|
| 778 |
+
"model": 1,
|
| 779 |
+
"model.onnx": 3,
|
| 780 |
+
"model_fp16.onnx": 2
|
| 781 |
+
}
|
| 782 |
+
},
|
| 783 |
+
"transformers_version": "5.7.0",
|
| 784 |
+
"use_cache": true,
|
| 785 |
+
"use_viterbi_decode": true,
|
| 786 |
+
"viterbi_replace_logits": true,
|
| 787 |
+
"vocab_size": 200064
|
| 788 |
+
}
|
configuration_haremb_pii.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HaremPiiConfig — subclass of OpenAIPrivacyFilterConfig that:
|
| 3 |
+
* sets `model_type="haremb_pii"` (so AutoConfig + auto_map dispatch works
|
| 4 |
+
with `trust_remote_code=True`)
|
| 5 |
+
* paired with HaremPiiForTokenClassification in modeling_haremb_pii.py
|
| 6 |
+
via `auto_map`
|
| 7 |
+
|
| 8 |
+
This release is a 1-layer surgical slice of the OpenMed teacher:
|
| 9 |
+
* num_hidden_layers=1
|
| 10 |
+
* inference-only — Viterbi decoding is built into the forward pass.
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from transformers.models.openai_privacy_filter.configuration_openai_privacy_filter import (
|
| 15 |
+
OpenAIPrivacyFilterConfig,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class HaremPiiConfig(OpenAIPrivacyFilterConfig):
|
| 20 |
+
"""
|
| 21 |
+
HarEmb config. `model_type="haremb_pii"` disambiguates from upstream so
|
| 22 |
+
AutoConfig + AutoModel mappings can target our subclasses without
|
| 23 |
+
colliding with the registered OpenAIPrivacyFilterConfig entry.
|
| 24 |
+
`modeling_haremb_pii` performs the auto-registration at import time.
|
| 25 |
+
"""
|
| 26 |
+
model_type = "haremb_pii"
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
use_viterbi_decode: bool = True,
|
| 31 |
+
viterbi_replace_logits: bool = True,
|
| 32 |
+
**kwargs,
|
| 33 |
+
):
|
| 34 |
+
super().__init__(**kwargs)
|
| 35 |
+
# When True (and model is in eval mode), HaremPiiForTokenClassification.forward
|
| 36 |
+
# runs constrained BIOES Viterbi over logits and attaches `predicted_labels`
|
| 37 |
+
# to the output. Set to False to skip Viterbi entirely.
|
| 38 |
+
self.use_viterbi_decode = bool(use_viterbi_decode)
|
| 39 |
+
# When True (and Viterbi is on), forward replaces `outputs.logits` with a
|
| 40 |
+
# one-hot-shaped tensor whose argmax equals the Viterbi prediction. This
|
| 41 |
+
# makes HF `pipeline()` and any naive `logits.argmax(-1)` consumer use
|
| 42 |
+
# Viterbi predictions automatically. The raw logits are preserved on
|
| 43 |
+
# the output as `raw_logits`.
|
| 44 |
+
self.viterbi_replace_logits = bool(viterbi_replace_logits)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
__all__ = ["HaremPiiConfig"]
|
eval_confusion.png
ADDED
|
Git LFS Details
|
eval_performance.png
ADDED
|
eval_summary.png
ADDED
|
Git LFS Details
|
haremb.png
ADDED
|
Git LFS Details
|
infer.log
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Inference benchmark: A: openmed-base vs B: haremb
|
| 2 |
+
device : cuda dtype: torch.bfloat16
|
| 3 |
+
ctx : 1024
|
| 4 |
+
|
| 5 |
+
A: openmed-base (reference / teacher)
|
| 6 |
+
load : 0.71s
|
| 7 |
+
eval : 64.66s on 212,909 tokens (3293 tok/s)
|
| 8 |
+
Performance:
|
| 9 |
+
total params : 1399.61M (139.35M dense + 1260.26M MoE-experts)
|
| 10 |
+
active params / token : 178.73M (memory footprint — embed lookup + top_4/128 experts: 128.04M embed + 39.38M MoE-active + 11.31M attn/norm/head)
|
| 11 |
+
compute params / token : 50.69M (matmul FLOPs only — embedding lookup excluded)
|
| 12 |
+
GFLOP / token (fwd, MAC×2): 0.101
|
| 13 |
+
weights size (on disk) : —
|
| 14 |
+
weights size (in RAM) : 2.61 GiB
|
| 15 |
+
weights resident (GPU) : 2.61 GiB
|
| 16 |
+
peak GPU mem (eval, ctx=1024) : 3.30 GiB
|
| 17 |
+
|
| 18 |
+
B: haremb (this checkpoint)
|
| 19 |
+
load : 0.10s
|
| 20 |
+
eval : 33.56s on 212,909 tokens (6343 tok/s)
|
| 21 |
+
Performance:
|
| 22 |
+
total params : 287.11M (129.58M dense + 157.53M MoE-experts)
|
| 23 |
+
active params / token : 134.50M (memory footprint — embed lookup + top_4/128 experts: 128.04M embed + 4.92M MoE-active + 1.54M attn/norm/head)
|
| 24 |
+
compute params / token : 6.46M (matmul FLOPs only — embedding lookup excluded)
|
| 25 |
+
GFLOP / token (fwd, MAC×2): 0.013
|
| 26 |
+
weights size (on disk) : 547.6 MiB
|
| 27 |
+
weights size (in RAM) : 547.6 MiB
|
| 28 |
+
weights resident (GPU) : 548.3 MiB
|
| 29 |
+
peak GPU mem (eval, ctx=1024) : 1.22 GiB
|
| 30 |
+
|
| 31 |
+
B vs A (haremb vs openmed-base):
|
| 32 |
+
total params : 4.87× smaller
|
| 33 |
+
active params / token : 1.33× less [memory]
|
| 34 |
+
compute params / token : 7.85× cheaper [FLOPs]
|
| 35 |
+
GFLOP / token : 7.85× cheaper
|
| 36 |
+
weights size (on disk) : —
|
| 37 |
+
weights in RAM : 4.87× smaller
|
| 38 |
+
peak GPU mem (eval) : 2.70× less
|
| 39 |
+
throughput : 1.93× faster
|
| 40 |
+
|
| 41 |
+
Sample inference (load → tokenize → forward → viterbi-decode → spans):
|
| 42 |
+
text: 'Patient Sarah Johnson (DOB 03/15/1985), MRN 4872910, phone 415-555-0123, email sarah.johnson@example.com, credit card 4111-1111-1111-1111.'
|
| 43 |
+
forward latency: 65.8ms (53 tokens)
|
| 44 |
+
detected 7 spans:
|
| 45 |
+
[ 1, 2) first_name 'Sarah'
|
| 46 |
+
[ 2, 3) last_name 'Johnson'
|
| 47 |
+
[ 6, 12) date '03/15/1985'
|
| 48 |
+
[ 16, 19) phone_number '4872910'
|
| 49 |
+
[ 22, 28) phone_number '415-555-0123'
|
| 50 |
+
[ 30, 37) email 'sarah.johnson@example.com'
|
| 51 |
+
[ 41, 52) credit_debit_card '4111-1111-1111-1111'
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:64059006ac732bd608cc478ee579cc96f56174d8910603b6d4747688b130b8a2
|
| 3 |
+
size 574224842
|
modeling_haremb_pii.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HaremPii — 1-layer surgical inference wrapper over OpenAI Privacy Filter.
|
| 3 |
+
|
| 4 |
+
Defines:
|
| 5 |
+
* HaremPiiForTokenClassification — subclass of
|
| 6 |
+
OpenAIPrivacyFilterForTokenClassification. Reuses the upstream forward
|
| 7 |
+
pass and adds eval-time constrained-BIOES Viterbi decoding so
|
| 8 |
+
`outputs.logits.argmax(-1)` returns the Viterbi path.
|
| 9 |
+
* HaremPiiModel — encoder alias pinned to HaremPiiConfig.
|
| 10 |
+
|
| 11 |
+
The model class is auto-registered so
|
| 12 |
+
`AutoModelForTokenClassification.from_pretrained(repo, trust_remote_code=True)`
|
| 13 |
+
dispatches to us via `config.auto_map` (model_type "haremb_pii").
|
| 14 |
+
|
| 15 |
+
This file is the released, inference-only copy. It contains no
|
| 16 |
+
training-related utilities.
|
| 17 |
+
"""
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
from typing import Optional
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
from transformers import (
|
| 25 |
+
AutoConfig,
|
| 26 |
+
AutoModel,
|
| 27 |
+
AutoModelForTokenClassification,
|
| 28 |
+
)
|
| 29 |
+
from transformers.models.openai_privacy_filter.modeling_openai_privacy_filter import (
|
| 30 |
+
OpenAIPrivacyFilterForTokenClassification,
|
| 31 |
+
OpenAIPrivacyFilterModel,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
from configuration_haremb_pii import HaremPiiConfig
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
# Constrained BIOES Viterbi (inlined so the checkpoint is self-contained)
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
# Transition rules:
|
| 41 |
+
# O -> {O, B-X, S-X}
|
| 42 |
+
# B-X -> {I-X, E-X}
|
| 43 |
+
# I-X -> {I-X, E-X}
|
| 44 |
+
# E-X -> {O, B-Y, S-Y}
|
| 45 |
+
# S-X -> {O, B-Y, S-Y}
|
| 46 |
+
# Initial state allows {O, B-X, S-X} only.
|
| 47 |
+
|
| 48 |
+
def _parse_bioes(label: str):
|
| 49 |
+
if label == "O" or "-" not in label:
|
| 50 |
+
return "O", None
|
| 51 |
+
pref, cat = label.split("-", 1)
|
| 52 |
+
return pref, cat
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _build_bioes_transition_mask(id2label) -> torch.Tensor:
|
| 56 |
+
C = len(id2label)
|
| 57 |
+
mask = torch.full((C, C), float("-inf"))
|
| 58 |
+
parsed = {i: _parse_bioes(id2label[i]) for i in range(C)}
|
| 59 |
+
for i, (p_prev, c_prev) in parsed.items():
|
| 60 |
+
for j, (p_cur, c_cur) in parsed.items():
|
| 61 |
+
ok = False
|
| 62 |
+
if p_prev == "O":
|
| 63 |
+
if p_cur in ("O", "B", "S"):
|
| 64 |
+
ok = True
|
| 65 |
+
elif p_prev == "B":
|
| 66 |
+
if p_cur in ("I", "E") and c_cur == c_prev:
|
| 67 |
+
ok = True
|
| 68 |
+
elif p_prev == "I":
|
| 69 |
+
if p_cur in ("I", "E") and c_cur == c_prev:
|
| 70 |
+
ok = True
|
| 71 |
+
elif p_prev in ("E", "S"):
|
| 72 |
+
if p_cur in ("O", "B", "S"):
|
| 73 |
+
ok = True
|
| 74 |
+
if ok:
|
| 75 |
+
mask[i, j] = 0.0
|
| 76 |
+
return mask
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _build_bioes_initial_mask(id2label) -> torch.Tensor:
|
| 80 |
+
C = len(id2label)
|
| 81 |
+
mask = torch.full((C,), float("-inf"))
|
| 82 |
+
for i, lbl in id2label.items():
|
| 83 |
+
p, _ = _parse_bioes(lbl)
|
| 84 |
+
if p in ("O", "B", "S"):
|
| 85 |
+
mask[i] = 0.0
|
| 86 |
+
return mask
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _bioes_viterbi(
    logits: torch.Tensor,
    transition_mask: torch.Tensor,
    initial_mask: torch.Tensor,
) -> torch.Tensor:
    """Single-sequence convenience wrapper around the batched decoder.

    Accepts ``[T, C]`` logits, runs the batched Viterbi on a batch of one
    with a full attention mask, and returns the ``[T]`` tag-id path.
    """
    if logits.dim() != 2:
        raise ValueError(f"expected 2D logits, got {logits.shape}")
    seq_len = logits.shape[0]
    full_mask = torch.ones((1, seq_len), dtype=torch.long, device=logits.device)
    decoded = _bioes_viterbi_batched(
        logits.unsqueeze(0), full_mask, transition_mask, initial_mask,
    )
    return decoded.squeeze(0)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _bioes_viterbi_batched(
    logits: torch.Tensor,
    attention_mask: torch.Tensor,
    transition_mask: torch.Tensor,
    initial_mask: torch.Tensor,
) -> torch.Tensor:
    """Vectorized constrained BIOES Viterbi.

    Max-product dynamic program over the tag lattice; illegal transitions
    and illegal first tags are priced at -inf by the two masks.

    Args:
        logits: [B, T, C] float
        attention_mask: [B, T] {0, 1} long/bool
        transition_mask: [C, C] 0 valid, -inf invalid
        initial_mask: [C] 0 allowed first tag, -inf forbidden

    Returns:
        [B, T] LongTensor of best constrained-BIOES tag id per token; padded
        positions hold -1.

    NOTE(review): the backtrace gating assumes right-padding (a contiguous
    run of 1s then 0s in `attention_mask`); interior zeros would corrupt the
    decoded path — confirm with callers.
    """
    if logits.dim() != 3:
        raise ValueError(f"expected 3D logits [B,T,C], got {logits.shape}")
    device = logits.device
    B, T, C = logits.shape
    # Do the DP arithmetic in fp32 so the -inf masks behave even when the
    # model emits bf16/fp16 logits.
    scores = logits.float()
    trans = transition_mask.to(device).float()
    init = initial_mask.to(device).float()
    mask = attention_mask.to(device).bool()

    # dp[b, c]: best score of any legal path over the prefix ending in tag c.
    dp = scores[:, 0] + init.unsqueeze(0)
    # back[b, t, c]: argmax predecessor tag at step t when the tag is c.
    back = torch.zeros((B, T, C), dtype=torch.long, device=device)
    trans_b = trans.unsqueeze(0)  # [1, C, C], broadcast over the batch
    for t in range(1, T):
        # cand[b, i, j] = dp[b, i] + trans[i, j]: score of stepping i -> j.
        cand = dp.unsqueeze(2) + trans_b
        best_val, best_prev = cand.max(dim=1)
        new_dp = best_val + scores[:, t]
        # Freeze dp at padded positions so the final dp reflects each
        # sequence's last *real* token, not the padding tail.
        keep = mask[:, t].unsqueeze(1)
        dp = torch.where(keep, new_dp, dp)
        back[:, t] = best_prev

    # Index of the last real token per sequence (clamped so an all-padding
    # row still yields a valid index).
    last_t = (mask.sum(dim=1) - 1).clamp_min(0)
    best_last = dp.argmax(dim=1)
    out = torch.full((B, T), -1, dtype=torch.long, device=device)
    batch_idx = torch.arange(B, device=device)
    out[batch_idx, last_t] = best_last
    # Walk backpointers from each sequence's end; positions beyond last_t
    # stay inactive and keep their -1 fill.
    current = best_last.clone()
    for t in range(T - 1, 0, -1):
        new_current = torch.gather(
            back[:, t, :], 1, current.unsqueeze(1)
        ).squeeze(1)
        active = (t <= last_t)
        current = torch.where(active, new_current, current)
        out[batch_idx, t - 1] = torch.where(
            active, current, out[batch_idx, t - 1],
        )
    return out
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# ---------------------------------------------------------------------------
|
| 161 |
+
# Architecture classes
|
| 162 |
+
# ---------------------------------------------------------------------------
|
| 163 |
+
|
| 164 |
+
class HaremPiiModel(OpenAIPrivacyFilterModel):
    """Encoder alias: the upstream backbone, unchanged, but bound to
    :class:`HaremPiiConfig` so Auto* dispatch resolves to this package."""
    config_class = HaremPiiConfig
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class HaremPiiForTokenClassification(OpenAIPrivacyFilterForTokenClassification):
    """1-layer student. Wraps the upstream forward with eval-time
    constrained-BIOES Viterbi decoding.

    Behavior toggles live on the config: `use_viterbi_decode` enables the
    eval-time decode, and `viterbi_replace_logits` additionally replaces
    `outputs.logits` with a one-hot-shaped tensor whose argmax equals the
    Viterbi path (raw scores preserved as `raw_logits`).
    """
    config_class = HaremPiiConfig

    def __init__(self, config: HaremPiiConfig):
        """Build the encoder + token-classification head by hand.

        Deliberately does NOT call the parent __init__ — see the comment
        below for why the Auto registry must be bypassed.
        """
        # Bypass GenericForTokenClassification.__init__ because it calls
        # AutoModel.from_config(config), which uses type(config) as the
        # registry key. Under the trust_remote_code Hub-loading path the
        # cached HaremPiiConfig class identity differs from whatever was
        # registered at module import (the cache hosts the class under a
        # synthetic, sha-qualified module name). Constructing the encoder
        # directly avoids the registry dispatch entirely.
        from transformers.modeling_utils import PreTrainedModel as _PreTrainedModel
        _PreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        # Encoder backbone, instantiated directly (no Auto dispatch).
        self.model = OpenAIPrivacyFilterModel(config)
        # Dropout probability: prefer classifier_dropout, fall back to
        # hidden_dropout, then to a 0.1 default.
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        # Per-token classification head: hidden_size -> num_labels.
        self.score = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()

        # Viterbi masks are built lazily on first decode and cached.
        self._viterbi_trans_mask = None
        self._viterbi_init_mask = None

    def _ensure_viterbi_masks(self):
        """Lazily build and cache the BIOES transition/initial masks.

        `config.id2label` keys are coerced with int() because they may
        arrive as strings after a JSON round-trip of the config.
        """
        if self._viterbi_trans_mask is None:
            id2label = {int(k): v for k, v in self.config.id2label.items()}
            self._viterbi_trans_mask = _build_bioes_transition_mask(id2label)
            self._viterbi_init_mask = _build_bioes_initial_mask(id2label)
        return self._viterbi_trans_mask, self._viterbi_init_mask

    @torch.no_grad()
    def decode_predictions(
        self,
        logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Constrained-BIOES Viterbi decode of `logits`.

        Accepts [T, C] (single sequence, full mask assumed) or [B, T, C]
        (mask defaults to all-ones when omitted). Returns tag ids of shape
        [T] or [B, T]; padded positions hold -1.
        """
        trans, init = self._ensure_viterbi_masks()
        if logits.dim() == 2:
            # Single sequence: wrap in a batch of one with a full mask.
            T = logits.shape[0]
            mask = torch.ones((1, T), dtype=torch.long, device=logits.device)
            return _bioes_viterbi_batched(
                logits.unsqueeze(0), mask, trans, init,
            )[0]
        if attention_mask is None:
            attention_mask = torch.ones(
                logits.shape[:2], dtype=torch.long, device=logits.device,
            )
        return _bioes_viterbi_batched(logits, attention_mask, trans, init)

    def forward(self, *args, **kwargs):
        """Upstream forward, plus eval-time Viterbi post-processing.

        In training mode, or when `config.use_viterbi_decode` is False, the
        upstream output is returned untouched.
        """
        outputs = super().forward(*args, **kwargs)
        if self.training:
            return outputs
        if not getattr(self.config, "use_viterbi_decode", True):
            return outputs

        # Recover the attention mask from kwargs, or positionally —
        # NOTE(review): assumes the upstream forward signature is
        # (input_ids, attention_mask, ...); confirm against upstream.
        attn_mask = kwargs.get("attention_mask", None)
        if attn_mask is None and len(args) >= 2:
            attn_mask = args[1]

        decoded = self.decode_predictions(outputs.logits, attention_mask=attn_mask)
        try:
            outputs.predicted_labels = decoded
        except Exception:
            # Some output containers reject new attributes; write through
            # __dict__ as a fallback.
            outputs.__dict__["predicted_labels"] = decoded

        if getattr(self.config, "viterbi_replace_logits", True):
            # Replace logits with a ±1e9 one-hot-shaped tensor so any
            # downstream argmax reproduces the Viterbi path. Padded
            # positions (decoded == -1) are clamped to label id 0.
            raw = outputs.logits
            fake = torch.full_like(raw, fill_value=-1e9)
            fake.scatter_(-1, decoded.clamp_min(0).unsqueeze(-1), 1e9)
            try:
                outputs.raw_logits = raw
                outputs.logits = fake
            except Exception:
                outputs.__dict__["raw_logits"] = raw
                outputs.__dict__["logits"] = fake
        return outputs
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# --- Auto-registry ---
# In-process registration: when this module is imported directly, the Auto*
# factories dispatch model_type "haremb_pii" to the local classes.
AutoConfig.register("haremb_pii", HaremPiiConfig, exist_ok=True)
AutoModel.register(HaremPiiConfig, HaremPiiModel, exist_ok=True)
AutoModelForTokenClassification.register(
    HaremPiiConfig, HaremPiiForTokenClassification, exist_ok=True,
)

# Hub registration: records the classes in `config.auto_map` on save, so
# `from_pretrained(repo, trust_remote_code=True)` can resolve them remotely.
HaremPiiConfig.register_for_auto_class("AutoConfig")
HaremPiiForTokenClassification.register_for_auto_class("AutoModelForTokenClassification")


__all__ = [
    "HaremPiiConfig",
    "HaremPiiModel",
    "HaremPiiForTokenClassification",
]
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0614fe83cadab421296e664e1f48f4261fa8fef6e03e63bb75c20f38e37d07d3
|
| 3 |
+
size 27868174
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"eos_token": "<|endoftext|>",
|
| 4 |
+
"is_local": false,
|
| 5 |
+
"local_files_only": false,
|
| 6 |
+
"model_input_names": [
|
| 7 |
+
"input_ids",
|
| 8 |
+
"attention_mask"
|
| 9 |
+
],
|
| 10 |
+
"model_max_length": 128000,
|
| 11 |
+
"pad_token": "<|endoftext|>",
|
| 12 |
+
"tokenizer_class": "TokenizersBackend"
|
| 13 |
+
}
|