"""
HarEmb PII — local Gradio inference demo.
Upload a PDF (or paste text), pick a device (CPU / cuda:N), and the model
highlights detected PII spans across the 55-category Nemotron-PII taxonomy.
Install:
pip install "gradio>=4" "transformers>=4.45" torch pypdf accelerate
Run from inside this folder:
python app.py
"""
from __future__ import annotations
import argparse
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import gradio as gr
import torch
from pypdf import PdfReader
from transformers import pipeline
# Default to loading from this folder so `python app.py` works in-place after
# downloading the repo. Override with --model-b (and --model-a) on the CLI.
DEFAULT_MODEL = "."
CHUNK_CHARS = 400_000 # ~100k tokens; well under the model's 131k window
# 55 Nemotron-PII categories grouped for visual coherence; one color per
# coarse "family" so the highlight legend stays readable.
PALETTE: Dict[str, str] = {
# Identity (red)
"first_name": "#ef4444",
"last_name": "#ef4444",
"user_name": "#ef4444",
"company_name": "#ef4444",
"age": "#fb7185",
"gender": "#fb7185",
"race_ethnicity": "#fb7185",
"sexuality": "#fb7185",
"religious_belief": "#fb7185",
"political_view": "#fb7185",
"language": "#fb7185",
"education_level": "#fb7185",
"occupation": "#fb7185",
"employment_status": "#fb7185",
"blood_type": "#fb7185",
"biometric_identifier":"#fb7185",
# Contact (purple)
"email": "#8b5cf6",
"phone_number": "#a78bfa",
"fax_number": "#a78bfa",
"url": "#7c3aed",
# Address (green)
"street_address": "#10b981",
"city": "#34d399",
"county": "#34d399",
"state": "#34d399",
"country": "#34d399",
"postcode": "#34d399",
"coordinate": "#059669",
# Dates (blue)
"date": "#3b82f6",
"date_of_birth": "#60a5fa",
"date_time": "#60a5fa",
"time": "#60a5fa",
# Government IDs (orange)
"ssn": "#f97316",
"national_id": "#fb923c",
"tax_id": "#fb923c",
# Financial (amber)
"account_number": "#f59e0b",
"bank_routing_number": "#fbbf24",
"swift_bic": "#fbbf24",
"credit_debit_card": "#fbbf24",
"cvv": "#fbbf24",
"pin": "#fbbf24",
"password": "#d97706",
# Healthcare (pink)
"medical_record_number": "#ec4899",
"health_plan_beneficiary_number": "#f472b6",
# Enterprise IDs (cyan)
"customer_id": "#06b6d4",
"employee_id": "#06b6d4",
"unique_id": "#22d3ee",
"certificate_license_number": "#22d3ee",
# Vehicle (lime)
"license_plate": "#84cc16",
"vehicle_identifier": "#84cc16",
# Digital (indigo)
"ipv4": "#6366f1",
"ipv6": "#6366f1",
"mac_address": "#818cf8",
"device_identifier": "#818cf8",
"api_key": "#4f46e5",
"http_cookie": "#4f46e5",
}
def list_devices() -> List[str]:
devs = ["cpu"]
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
devs.append(f"cuda:{i}")
return devs
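# E.g. ["cpu"] on a CPU-only machine, or ["cpu", "cuda:0", "cuda:1"] with two GPUs.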
_pipe_cache: Dict[Tuple[str, str], object] = {}
def get_pipe(model_path: str, device: str):
key = (model_path, device)
if key in _pipe_cache:
return _pipe_cache[key]
dtype = torch.bfloat16 if device.startswith("cuda") else torch.float32
pipe = pipeline(
"token-classification",
model=model_path,
tokenizer=model_path,
trust_remote_code=True,
aggregation_strategy="simple",
device=device,
torch_dtype=dtype,
)
_pipe_cache[key] = pipe
return pipe
def apply_runtime_config(
pipe,
use_viterbi: bool,
viterbi_replace: bool,
top_k: Optional[int] = None,
) -> None:
cfg = pipe.model.config
if hasattr(cfg, "use_viterbi_decode"):
cfg.use_viterbi_decode = bool(use_viterbi)
if hasattr(cfg, "viterbi_replace_logits"):
cfg.viterbi_replace_logits = bool(viterbi_replace)
# Override the per-layer MoE top-k at inference. Both fields need to be
# set: `mlp.router.top_k` is the actual router top-k, and the upstream
# `mlp.num_experts` is misnamed (it's also the per-token top_k, not
# num_local_experts). top_k=None leaves the trained config alone.
if top_k is not None:
n_local = int(getattr(cfg, "num_local_experts", 128))
k = max(1, min(int(top_k), n_local))
for layer in pipe.model.model.layers:
mlp = getattr(layer, "mlp", None)
if mlp is None:
continue
router = getattr(mlp, "router", None)
if router is not None and hasattr(router, "top_k"):
router.top_k = k
if hasattr(mlp, "num_experts"):
mlp.num_experts = k
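# Usage sketch (top_k=8 is a hypothetical ablation value, not a recommendation):
#   pipe = get_pipe("OpenMed/privacy-filter-nemotron", "cuda:0")
#   apply_runtime_config(pipe, use_viterbi=True, viterbi_replace=True, top_k=8)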
def model_top_k_default(model_path: str) -> int:
"""Read the trained `num_experts_per_tok` from the model's config without
loading the weights. Falls back to 4 if the field isn't present."""
try:
from transformers import AutoConfig
cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
return int(getattr(cfg, "num_experts_per_tok", 4))
except Exception:
return 4
def model_num_experts(model_path: str) -> int:
"""Read `num_local_experts` from the model's config without loading
weights. Falls back to 128 if the field isn't present."""
try:
from transformers import AutoConfig
cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
return int(getattr(cfg, "num_local_experts", 128))
except Exception:
return 128
def clear_model_cache() -> str:
_pipe_cache.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return "Model cache cleared. Next run will reload weights."
def extract_text(file_obj) -> str:
if file_obj is None:
return ""
path = file_obj.name if hasattr(file_obj, "name") else file_obj
p = Path(path)
if p.suffix.lower() == ".pdf":
reader = PdfReader(str(p))
return "\n\n".join((page.extract_text() or "") for page in reader.pages)
return p.read_text(encoding="utf-8", errors="replace")
def chunk_text(text: str, max_chars: int = CHUNK_CHARS) -> List[Tuple[int, str]]:
if not text:
return []
if max_chars <= 0 or len(text) <= max_chars:
return [(0, text)]
pieces = re.split(r"(\n\s*\n)", text)
chunks: List[Tuple[int, str]] = []
cur, cur_off, pos = "", 0, 0
for piece in pieces:
if cur and len(cur) + len(piece) > max_chars and cur.strip():
chunks.append((cur_off, cur))
cur, cur_off = piece, pos
else:
if not cur:
cur_off = pos
cur += piece
pos += len(piece)
if cur.strip():
chunks.append((cur_off, cur))
return chunks
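# Worked example: with a 4-char budget the split lands on the blank-line
# separator, and each chunk carries its character offset into the full text
# (a single oversized paragraph is kept whole, never split mid-paragraph):
#   >>> chunk_text("aaa\n\nbbb", max_chars=4)
#   [(0, 'aaa'), (3, '\n\nbbb')]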
def category_of(label: str) -> str:
if len(label) > 2 and label[1] == "-":
return label[2:]
return label
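# Strips the BIO/BIOES prefix, e.g. category_of("B-email") == "email", while
# bare labels pass through unchanged: category_of("email") == "email".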
def predict(
model_path: str,
device: str,
text: str,
aggregation: str,
use_viterbi: bool,
viterbi_replace: bool,
top_k: Optional[int] = None,
chunk_chars: int = CHUNK_CHARS,
) -> List[Dict]:
if not text.strip():
return []
pipe = get_pipe(model_path, device)
apply_runtime_config(pipe, use_viterbi, viterbi_replace, top_k=top_k)
spans: List[Dict] = []
for offset, chunk in chunk_text(text, max_chars=chunk_chars):
for ent in pipe(chunk, aggregation_strategy=aggregation):
label = ent.get("entity_group") or ent.get("entity") or ""
cat = category_of(label)
if cat not in PALETTE:
continue
s = ent["start"] + offset
e = ent["end"] + offset
spans.append({
"start": s, "end": e, "label": cat,
"score": float(ent["score"]),
"text": text[s:e],
})
spans.sort(key=lambda s: (s["start"], s["end"]))
return spans
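# Each returned span is a plain dict (values below are hypothetical):
#   {"start": 8, "end": 13, "label": "first_name", "score": 0.99, "text": "Sarah"}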
def to_highlight(text: str, spans: List[Dict]) -> List[Tuple[str, Optional[str]]]:
if not text:
return []
out: List[Tuple[str, Optional[str]]] = []
cur = 0
for s in spans:
if s["start"] < cur:
continue
if s["start"] > cur:
out.append((text[cur:s["start"]], None))
out.append((text[s["start"]:s["end"]], s["label"]))
cur = s["end"]
if cur < len(text):
out.append((text[cur:], None))
return out
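# Example (hypothetical span): for text = "Call 415-555-0123 now" with one
# phone_number span over chars 5..17, the output is the (segment, label)
# tuples that gr.HighlightedText expects:
#   [("Call ", None), ("415-555-0123", "phone_number"), (" now", None)]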
def fmt_spans(spans: List[Dict], max_rows: int = 60) -> str:
if not spans:
return "_No PII spans detected._"
rows = [
f"- `{s['label']}` &nbsp; `{s['text'][:80].replace('`', '')}` &nbsp; (score {s['score']:.2f})"
for s in spans[:max_rows]
]
more = f"\n\n_…+{len(spans) - max_rows} more_" if len(spans) > max_rows else ""
return f"**Detected {len(spans)} span(s):**\n" + "\n".join(rows) + more
# Build legend HTML from PALETTE: one chip per unique color (i.e. per coarse
# family), labeled with the first category name plus a "+N" count for the
# siblings that share the color, so the legend stays readable.
def _legend_html() -> str:
seen = {}
for name, c in PALETTE.items():
seen.setdefault(c, []).append(name)
rows = []
for c, names in seen.items():
chip = (f"<span style='background:{c};color:#fff;padding:.15rem .55rem;"
f"border-radius:.3rem;font-family:monospace;'>"
f"{names[0]}{(' +'+str(len(names)-1)) if len(names)>1 else ''}</span>")
rows.append(chip)
return ("<div style='display:flex;flex-wrap:wrap;gap:.4rem;font-size:.85rem;"
"margin:.25rem 0;'>" + "".join(rows) + "</div>")
LEGEND_HTML = _legend_html()
def diff_spans(a: List[Dict], b: List[Dict]):
"""Return (only_in_a, only_in_b, agreed) span-lists. Keys are the
(start, end, label) triple — agreement requires identical category."""
key = lambda s: (s["start"], s["end"], s["label"])
sa = {key(s): s for s in a}
sb = {key(s): s for s in b}
only_a = [sa[k] for k in sa if k not in sb]
only_b = [sb[k] for k in sb if k not in sa]
both = [sa[k] for k in sa if k in sb]
return only_a, only_b, both
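# Example: identical offsets with different categories count as disagreement
# (score/text keys are omitted here for brevity; diff_spans never reads them):
#   >>> a = [{"start": 0, "end": 3, "label": "city"}]
#   >>> b = [{"start": 0, "end": 3, "label": "state"}]
#   >>> diff_spans(a, b)  # -> ([city span], [state span], [])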
def fmt_diff(label_a: str, label_b: str,
only_a: List[Dict], only_b: List[Dict], agreed: List[Dict]) -> str:
def fmt(name: str, lst: List[Dict]) -> str:
if not lst:
return f"**{name}:** none"
rows = [
f"- `{s['label']}` &nbsp; `{s['text'][:80].replace('`', '')}` "
f"&nbsp; (score {s['score']:.2f})"
for s in lst[:30]
]
more = f"\n …+{len(lst) - 30} more" if len(lst) > 30 else ""
return f"**{name} ({len(lst)}):**\n" + "\n".join(rows) + more
return "\n\n".join([
fmt(f"Only {label_a}", only_a),
fmt(f"Only {label_b}", only_b),
fmt("Agreed by both", agreed),
])
def run(
file_obj, pasted_text, device,
model_a_path, model_b_path,
use_a, use_b,
aggregation, use_viterbi, viterbi_replace,
top_k_a, top_k_b,
min_score, chunk_chars,
):
text = extract_text(file_obj) if file_obj else (pasted_text or "")
if not text.strip():
return [], [], "_Provide a PDF, a text file, or pasted text._", ""
if not (use_a or use_b):
return [], [], "_Enable at least one model._", text
a_spans = (
predict(model_a_path, device, text, aggregation,
use_viterbi, viterbi_replace,
top_k=int(top_k_a), chunk_chars=int(chunk_chars))
if use_a else []
)
b_spans = (
predict(model_b_path, device, text, aggregation,
use_viterbi, viterbi_replace,
top_k=int(top_k_b), chunk_chars=int(chunk_chars))
if use_b else []
)
thr = float(min_score)
a_spans = [s for s in a_spans if s["score"] >= thr]
b_spans = [s for s in b_spans if s["score"] >= thr]
a_hl = to_highlight(text, a_spans) if use_a else []
b_hl = to_highlight(text, b_spans) if use_b else []
label_a = Path(model_a_path).name or model_a_path
label_b = Path(model_b_path).name or model_b_path
if use_a and use_b:
only_a, only_b, agreed = diff_spans(a_spans, b_spans)
diff_md = fmt_diff(label_a, label_b, only_a, only_b, agreed)
elif use_a:
diff_md = fmt_spans(a_spans)
elif use_b:
diff_md = fmt_spans(b_spans)
else:
diff_md = "_Enable a model._"
return a_hl, b_hl, diff_md, text
def build_ui(default_model_a: str, default_model_b: str) -> gr.Blocks:
a_default_k = model_top_k_default(default_model_a)
a_n_experts = model_num_experts(default_model_a)
b_default_k = model_top_k_default(default_model_b)
b_n_experts = model_num_experts(default_model_b)
    with gr.Blocks(title="HarEmb PII", theme=gr.themes.Soft()) as demo:
gr.Markdown(
"# HarEmb · OpenMed-Nemotron PII\n"
"Detect PII across 55 categories of the Nemotron-PII taxonomy. "
"Run **two models side-by-side** to compare detections — by "
"default this checkpoint vs the OpenMed teacher it was distilled "
"from. Disable one model to view a single detection."
)
devices = list_devices()
with gr.Row():
device_dd = gr.Dropdown(devices, value=devices[0], label="Device", scale=1)
clear_btn = gr.Button("Clear model cache", variant="secondary", scale=1)
with gr.Row():
with gr.Column():
use_a = gr.Checkbox(value=True, label="Enable model A (teacher / baseline)")
model_a_tb = gr.Textbox(
value=default_model_a,
label="Model A — path / HF repo",
info="Default: OpenMed/privacy-filter-nemotron (teacher).",
)
top_k_a_sl = gr.Slider(
1, a_n_experts, value=a_default_k, step=1,
label=f"Active experts per token (top-k of {a_n_experts})",
info=f"Trained value: {a_default_k}. Lower = faster + less "
f"capacity per token. Higher = more compute, denser "
f"routing. Bypassing the trained value can drop "
f"quality — useful for ablations.",
)
with gr.Column():
use_b = gr.Checkbox(value=True, label="Enable model B (this checkpoint)")
model_b_tb = gr.Textbox(
value=default_model_b,
label="Model B — path / HF repo",
info="Default: ./ (this checkpoint).",
)
top_k_b_sl = gr.Slider(
1, b_n_experts, value=b_default_k, step=1,
label=f"Active experts per token (top-k of {b_n_experts})",
info=f"Trained value: {b_default_k}.",
)
with gr.Accordion("Inference settings", open=False):
with gr.Row():
aggregation_dd = gr.Dropdown(
["simple", "first", "max", "average", "none"],
value="simple",
label="aggregation_strategy",
info="how token-level labels are merged into spans",
)
viterbi_cb = gr.Checkbox(
value=True,
label="use_viterbi_decode",
info="constrained BIOES decoding (off = raw argmax)",
)
viterbi_replace_cb = gr.Checkbox(
value=True,
label="viterbi_replace_logits",
info="when on, outputs.logits.argmax(-1) returns the Viterbi path",
)
min_score_sl = gr.Slider(
0.0, 1.0, value=0.0, step=0.01,
label="min confidence",
info="filter out spans with score below this threshold",
)
chunk_sl = gr.Slider(
0, 500_000, value=CHUNK_CHARS, step=10_000,
label="chunk size (chars)",
info="0 = single pass; otherwise split on paragraphs at this size. "
"Model window ≈131k tokens (~500k chars).",
)
with gr.Row():
file_in = gr.File(label="PDF / text file", file_types=[".pdf", ".txt", ".md"])
text_in = gr.Textbox(
label="…or paste text",
lines=6,
placeholder=("Patient Sarah Johnson (DOB 03/15/1985), MRN 4872910, "
"phone 415-555-0123, email sarah.johnson@example.com."),
)
run_btn = gr.Button("Detect PII", variant="primary")
gr.HTML(LEGEND_HTML)
with gr.Row():
a_out = gr.HighlightedText(
label="Model A detections",
color_map=PALETTE,
show_legend=False,
combine_adjacent=False,
)
b_out = gr.HighlightedText(
label="Model B detections",
color_map=PALETTE,
show_legend=False,
combine_adjacent=False,
)
diff_out = gr.Markdown("_Run a detection to see the diff / span list._")
extracted_out = gr.Textbox(
label="Extracted text (read-only)", lines=6, interactive=False,
)
run_btn.click(
run,
[file_in, text_in, device_dd,
model_a_tb, model_b_tb, use_a, use_b,
aggregation_dd, viterbi_cb, viterbi_replace_cb,
top_k_a_sl, top_k_b_sl,
min_score_sl, chunk_sl],
[a_out, b_out, diff_out, extracted_out],
)
clear_btn.click(clear_model_cache, None, diff_out)
return demo
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="HarEmb PII — Gradio demo")
p.add_argument("--host", default="127.0.0.1", help="Bind address (default: 127.0.0.1)")
p.add_argument("--port", type=int, default=7860, help="Port (default: 7860)")
p.add_argument("--share", action="store_true", help="Create a public Gradio share link")
p.add_argument("--model-a", default="OpenMed/privacy-filter-nemotron",
help="Model A path / HF repo "
"(default: OpenMed/privacy-filter-nemotron — teacher)")
p.add_argument("--model-b", default=DEFAULT_MODEL,
help="Model B path / HF repo "
"(default: . — this checkpoint)")
return p.parse_args()
if __name__ == "__main__":
    args = parse_args()
    build_ui(
        default_model_a=args.model_a,
        default_model_b=args.model_b,
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
    )