""" HarEmb PII — local Gradio inference demo. Upload a PDF (or paste text), pick a device (CPU / cuda:N), and the model highlights detected PII spans across the 55-category Nemotron-PII taxonomy. Install: pip install "gradio>=4" "transformers>=4.45" torch pypdf accelerate Run from inside this folder: python app.py """ from __future__ import annotations import argparse import re from pathlib import Path from typing import Dict, List, Optional, Tuple import gradio as gr import torch from pypdf import PdfReader from transformers import pipeline # Default to loading from this folder so `python app.py` works in-place after # downloading the repo. Override by setting --model-path on the CLI. DEFAULT_MODEL = "." CHUNK_CHARS = 400_000 # ~100k tokens; well under the model's 131k window # 55 Nemotron-PII categories grouped for visual coherence; one color per # coarse "family" so the highlight legend stays readable. PALETTE: Dict[str, str] = { # Identity (red) "first_name": "#ef4444", "last_name": "#ef4444", "user_name": "#ef4444", "company_name": "#ef4444", "age": "#fb7185", "gender": "#fb7185", "race_ethnicity": "#fb7185", "sexuality": "#fb7185", "religious_belief": "#fb7185", "political_view": "#fb7185", "language": "#fb7185", "education_level": "#fb7185", "occupation": "#fb7185", "employment_status": "#fb7185", "blood_type": "#fb7185", "biometric_identifier":"#fb7185", # Contact (purple) "email": "#8b5cf6", "phone_number": "#a78bfa", "fax_number": "#a78bfa", "url": "#7c3aed", # Address (green) "street_address": "#10b981", "city": "#34d399", "county": "#34d399", "state": "#34d399", "country": "#34d399", "postcode": "#34d399", "coordinate": "#059669", # Dates (blue) "date": "#3b82f6", "date_of_birth": "#60a5fa", "date_time": "#60a5fa", "time": "#60a5fa", # Government IDs (orange) "ssn": "#f97316", "national_id": "#fb923c", "tax_id": "#fb923c", # Financial (amber) "account_number": "#f59e0b", "bank_routing_number": "#fbbf24", "swift_bic": "#fbbf24", "credit_debit_card": 
"#fbbf24", "cvv": "#fbbf24", "pin": "#fbbf24", "password": "#d97706", # Healthcare (pink) "medical_record_number": "#ec4899", "health_plan_beneficiary_number": "#f472b6", # Enterprise IDs (cyan) "customer_id": "#06b6d4", "employee_id": "#06b6d4", "unique_id": "#22d3ee", "certificate_license_number": "#22d3ee", # Vehicle (lime) "license_plate": "#84cc16", "vehicle_identifier": "#84cc16", # Digital (indigo) "ipv4": "#6366f1", "ipv6": "#6366f1", "mac_address": "#818cf8", "device_identifier": "#818cf8", "api_key": "#4f46e5", "http_cookie": "#4f46e5", } def list_devices() -> List[str]: devs = ["cpu"] if torch.cuda.is_available(): for i in range(torch.cuda.device_count()): devs.append(f"cuda:{i}") return devs _pipe_cache: Dict[Tuple[str, str], object] = {} def get_pipe(model_path: str, device: str): key = (model_path, device) if key in _pipe_cache: return _pipe_cache[key] dtype = torch.bfloat16 if device.startswith("cuda") else torch.float32 pipe = pipeline( "token-classification", model=model_path, tokenizer=model_path, trust_remote_code=True, aggregation_strategy="simple", device=device, torch_dtype=dtype, ) _pipe_cache[key] = pipe return pipe def apply_runtime_config( pipe, use_viterbi: bool, viterbi_replace: bool, top_k: Optional[int] = None, ) -> None: cfg = pipe.model.config if hasattr(cfg, "use_viterbi_decode"): cfg.use_viterbi_decode = bool(use_viterbi) if hasattr(cfg, "viterbi_replace_logits"): cfg.viterbi_replace_logits = bool(viterbi_replace) # Override the per-layer MoE top-k at inference. Both fields need to be # set: `mlp.router.top_k` is the actual router top-k, and the upstream # `mlp.num_experts` is misnamed (it's also the per-token top_k, not # num_local_experts). top_k=None leaves the trained config alone. 
if top_k is not None: n_local = int(getattr(cfg, "num_local_experts", 128)) k = max(1, min(int(top_k), n_local)) for layer in pipe.model.model.layers: mlp = getattr(layer, "mlp", None) if mlp is None: continue router = getattr(mlp, "router", None) if router is not None and hasattr(router, "top_k"): router.top_k = k if hasattr(mlp, "num_experts"): mlp.num_experts = k def model_top_k_default(model_path: str) -> int: """Read the trained `num_experts_per_tok` from the model's config without loading the weights. Falls back to 4 if the field isn't present.""" try: from transformers import AutoConfig cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True) return int(getattr(cfg, "num_experts_per_tok", 4)) except Exception: return 4 def model_num_experts(model_path: str) -> int: """Read `num_local_experts` from the model's config without loading weights. Falls back to 128 if the field isn't present.""" try: from transformers import AutoConfig cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True) return int(getattr(cfg, "num_local_experts", 128)) except Exception: return 128 def clear_model_cache() -> str: _pipe_cache.clear() if torch.cuda.is_available(): torch.cuda.empty_cache() return "Model cache cleared. Next run will reload weights." 
def extract_text(file_obj) -> str:
    """Return the text of an uploaded file: per-page extraction for PDFs,
    raw UTF-8 (with replacement) for everything else. Empty string when no
    file was provided."""
    if file_obj is None:
        return ""
    # Gradio may hand us a tempfile-like object (with .name) or a bare path.
    path = file_obj.name if hasattr(file_obj, "name") else file_obj
    p = Path(path)
    if p.suffix.lower() == ".pdf":
        reader = PdfReader(str(p))
        # extract_text() can return None on image-only pages; coerce to "".
        return "\n\n".join((page.extract_text() or "") for page in reader.pages)
    return p.read_text(encoding="utf-8", errors="replace")


def chunk_text(text: str, max_chars: int = CHUNK_CHARS) -> List[Tuple[int, str]]:
    """Split `text` into (char_offset, chunk) pieces of at most roughly
    `max_chars`, breaking on paragraph boundaries (blank lines).

    max_chars <= 0 (or text that already fits) yields a single chunk at
    offset 0. Offsets index into the ORIGINAL text so span positions can be
    mapped back after per-chunk inference.
    """
    if not text:
        return []
    if max_chars <= 0 or len(text) <= max_chars:
        return [(0, text)]
    # Keep the blank-line separators (capturing group) so offsets stay exact.
    pieces = re.split(r"(\n\s*\n)", text)
    chunks: List[Tuple[int, str]] = []
    cur, cur_off, pos = "", 0, 0
    for piece in pieces:
        if cur and len(cur) + len(piece) > max_chars and cur.strip():
            chunks.append((cur_off, cur))
            cur, cur_off = piece, pos
        else:
            if not cur:
                cur_off = pos
            cur += piece
        pos += len(piece)
    if cur.strip():
        chunks.append((cur_off, cur))
    return chunks


def category_of(label: str) -> str:
    """Strip a BIOES/BIO prefix ("B-", "I-", …) from a label, leaving the
    bare category name."""
    if len(label) > 2 and label[1] == "-":
        return label[2:]
    return label


def predict(
    model_path: str,
    device: str,
    text: str,
    aggregation: str,
    use_viterbi: bool,
    viterbi_replace: bool,
    top_k: Optional[int] = None,
    chunk_chars: int = CHUNK_CHARS,
) -> List[Dict]:
    """Run the pipeline over `text` (chunked) and return PII spans as dicts
    with start/end/label/score/text, sorted by position.

    Labels whose category is not in PALETTE are dropped; chunk-local
    offsets are shifted back into whole-document coordinates.
    """
    if not text.strip():
        return []
    pipe = get_pipe(model_path, device)
    apply_runtime_config(pipe, use_viterbi, viterbi_replace, top_k=top_k)
    spans: List[Dict] = []
    for offset, chunk in chunk_text(text, max_chars=chunk_chars):
        for ent in pipe(chunk, aggregation_strategy=aggregation):
            label = ent.get("entity_group") or ent.get("entity") or ""
            cat = category_of(label)
            if cat not in PALETTE:
                continue
            s = ent["start"] + offset
            e = ent["end"] + offset
            spans.append({
                "start": s,
                "end": e,
                "label": cat,
                "score": float(ent["score"]),
                "text": text[s:e],
            })
    spans.sort(key=lambda s: (s["start"], s["end"]))
    return spans


def to_highlight(text: str, spans: List[Dict]) -> List[Tuple[str, Optional[str]]]:
    """Convert sorted spans into the (text, label|None) segment list that
    gr.HighlightedText expects. Overlapping spans are resolved
    first-come-first-served (later overlaps are skipped)."""
    if not text:
        return []
    out: List[Tuple[str, Optional[str]]] = []
    cur = 0
    for s in spans:
        if s["start"] < cur:
            continue
        if s["start"] > cur:
            out.append((text[cur:s["start"]], None))
        out.append((text[s["start"]:s["end"]], s["label"]))
        cur = s["end"]
    if cur < len(text):
        out.append((text[cur:], None))
    return out


def fmt_spans(spans: List[Dict], max_rows: int = 60) -> str:
    """Render spans as a Markdown bullet list, truncated to `max_rows`."""
    if not spans:
        return "_No PII spans detected._"
    rows = [
        f"- `{s['label']}`   `{s['text'][:80].replace('`', '')}`   (score {s['score']:.2f})"
        for s in spans[:max_rows]
    ]
    more = f"\n\n_…+{len(spans) - max_rows} more_" if len(spans) > max_rows else ""
    return f"**Detected {len(spans)} span(s):**\n" + "\n".join(rows) + more


# Build legend HTML for the categories present in PALETTE — one row per family
# (we still want it readable; show one swatch per unique color).
def _legend_html() -> str:
    # NOTE(review): the HTML markup in this function was destroyed by an
    # extraction step (empty f-strings where tags used to be); the swatch
    # chips below are a reconstruction — confirm against the upstream repo.
    seen = {}
    for name, c in PALETTE.items():
        seen.setdefault(c, []).append(name)
    rows = []
    for c, names in seen.items():
        chip = (f"<span style='display:inline-flex;align-items:center;"
                f"margin:2px 8px 2px 0;font-size:12px;'>"
                f"<span style='width:12px;height:12px;border-radius:3px;"
                f"background:{c};display:inline-block;margin-right:4px;'>"
                f"</span>"
                f"{names[0]}{(' +'+str(len(names)-1)) if len(names)>1 else ''}"
                f"</span>")
        rows.append(chip)
    return ("<div style='line-height:1.8'>" + "".join(rows) + "</div>")


LEGEND_HTML = _legend_html()


def diff_spans(a: List[Dict], b: List[Dict]):
    """Return (only_in_a, only_in_b, agreed) span-lists. Keys are the
    (start, end, label) triple — agreement requires identical category."""
    key = lambda s: (s["start"], s["end"], s["label"])
    sa = {key(s): s for s in a}
    sb = {key(s): s for s in b}
    only_a = [sa[k] for k in sa if k not in sb]
    only_b = [sb[k] for k in sb if k not in sa]
    both = [sa[k] for k in sa if k in sb]
    return only_a, only_b, both


def fmt_diff(label_a: str, label_b: str, only_a: List[Dict],
             only_b: List[Dict], agreed: List[Dict]) -> str:
    """Render a three-section Markdown diff: spans unique to each model and
    spans both agree on (truncated to 30 rows per section)."""
    def fmt(name: str, lst: List[Dict]) -> str:
        if not lst:
            return f"**{name}:** none"
        rows = [
            f"- `{s['label']}`   `{s['text'][:80].replace('`', '')}` "
            f"  (score {s['score']:.2f})"
            for s in lst[:30]
        ]
        more = f"\n …+{len(lst) - 30} more" if len(lst) > 30 else ""
        return f"**{name} ({len(lst)}):**\n" + "\n".join(rows) + more
    return "\n\n".join([
        fmt(f"Only {label_a}", only_a),
        fmt(f"Only {label_b}", only_b),
        fmt("Agreed by both", agreed),
    ])


def run(
    file_obj,
    pasted_text,
    device,
    model_a_path,
    model_b_path,
    use_a,
    use_b,
    aggregation,
    use_viterbi,
    viterbi_replace,
    top_k_a,
    top_k_b,
    min_score,
    chunk_chars,
):
    """Main click handler: extract text, run the enabled model(s), filter
    by confidence, and return (highlights_a, highlights_b, markdown,
    extracted_text)."""
    # An uploaded file wins over pasted text.
    text = extract_text(file_obj) if file_obj else (pasted_text or "")
    if not text.strip():
        return [], [], "_Provide a PDF, a text file, or pasted text._", ""
    if not (use_a or use_b):
        return [], [], "_Enable at least one model._", text
    a_spans = (
        predict(model_a_path, device, text, aggregation, use_viterbi,
                viterbi_replace, top_k=int(top_k_a),
                chunk_chars=int(chunk_chars))
        if use_a else []
    )
    b_spans = (
        predict(model_b_path, device, text, aggregation, use_viterbi,
                viterbi_replace, top_k=int(top_k_b),
                chunk_chars=int(chunk_chars))
        if use_b else []
    )
    thr = float(min_score)
    a_spans = [s for s in a_spans if s["score"] >= thr]
    b_spans = [s for s in b_spans if s["score"] >= thr]
    a_hl = to_highlight(text, a_spans) if use_a else []
    b_hl = to_highlight(text, b_spans) if use_b else []
    # Path(".").name is "" — fall back to the raw path string in that case.
    label_a = Path(model_a_path).name or model_a_path
    label_b = Path(model_b_path).name or model_b_path
    if use_a and use_b:
        only_a, only_b, agreed = diff_spans(a_spans, b_spans)
        diff_md = fmt_diff(label_a, label_b, only_a, only_b, agreed)
    elif use_a:
        diff_md = fmt_spans(a_spans)
    elif use_b:
        diff_md = fmt_spans(b_spans)
    else:
        diff_md = "_Enable a model._"
    return a_hl, b_hl, diff_md, text


def build_ui(default_model_a: str, default_model_b: str) -> gr.Blocks:
    """Assemble the Gradio Blocks UI (two comparable models, inference
    settings accordion, file/paste inputs, highlighted outputs)."""
    a_default_k = model_top_k_default(default_model_a)
    a_n_experts = model_num_experts(default_model_a)
    b_default_k = model_top_k_default(default_model_b)
    b_n_experts = model_num_experts(default_model_b)

    # FIX: `theme` belongs to the Blocks constructor, not launch() —
    # Blocks.launch() has no `theme` parameter and raised TypeError before.
    with gr.Blocks(title="HarEmb PII", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "# HarEmb · OpenMed-Nemotron PII\n"
            "Detect PII across 55 categories of the Nemotron-PII taxonomy. "
            "Run **two models side-by-side** to compare detections — by "
            "default this checkpoint vs the OpenMed teacher it was distilled "
            "from. Disable one model to view a single detection."
        )
        devices = list_devices()
        with gr.Row():
            device_dd = gr.Dropdown(devices, value=devices[0], label="Device", scale=1)
            clear_btn = gr.Button("Clear model cache", variant="secondary", scale=1)
        with gr.Row():
            with gr.Column():
                use_a = gr.Checkbox(value=True, label="Enable model A (teacher / baseline)")
                model_a_tb = gr.Textbox(
                    value=default_model_a,
                    label="Model A — path / HF repo",
                    info="Default: OpenMed/privacy-filter-nemotron (teacher).",
                )
                top_k_a_sl = gr.Slider(
                    1, a_n_experts, value=a_default_k, step=1,
                    label=f"Active experts per token (top-k of {a_n_experts})",
                    info=f"Trained value: {a_default_k}. Lower = faster + less "
                         f"capacity per token. Higher = more compute, denser "
                         f"routing. Bypassing the trained value can drop "
                         f"quality — useful for ablations.",
                )
            with gr.Column():
                use_b = gr.Checkbox(value=True, label="Enable model B (this checkpoint)")
                model_b_tb = gr.Textbox(
                    value=default_model_b,
                    label="Model B — path / HF repo",
                    info="Default: ./ (this checkpoint).",
                )
                top_k_b_sl = gr.Slider(
                    1, b_n_experts, value=b_default_k, step=1,
                    label=f"Active experts per token (top-k of {b_n_experts})",
                    info=f"Trained value: {b_default_k}.",
                )
        with gr.Accordion("Inference settings", open=False):
            with gr.Row():
                aggregation_dd = gr.Dropdown(
                    ["simple", "first", "max", "average", "none"],
                    value="simple",
                    label="aggregation_strategy",
                    info="how token-level labels are merged into spans",
                )
                viterbi_cb = gr.Checkbox(
                    value=True, label="use_viterbi_decode",
                    info="constrained BIOES decoding (off = raw argmax)",
                )
                viterbi_replace_cb = gr.Checkbox(
                    value=True, label="viterbi_replace_logits",
                    info="when on, outputs.logits.argmax(-1) returns the Viterbi path",
                )
                min_score_sl = gr.Slider(
                    0.0, 1.0, value=0.0, step=0.01,
                    label="min confidence",
                    info="filter out spans with score below this threshold",
                )
                chunk_sl = gr.Slider(
                    0, 500_000, value=CHUNK_CHARS, step=10_000,
                    label="chunk size (chars)",
                    info="0 = single pass; otherwise split on paragraphs at this size. "
                         "Model window ≈131k tokens (~500k chars).",
                )
        with gr.Row():
            file_in = gr.File(label="PDF / text file", file_types=[".pdf", ".txt", ".md"])
            text_in = gr.Textbox(
                label="…or paste text", lines=6,
                placeholder=("Patient Sarah Johnson (DOB 03/15/1985), MRN 4872910, "
                             "phone 415-555-0123, email sarah.johnson@example.com."),
            )
        run_btn = gr.Button("Detect PII", variant="primary")
        gr.HTML(LEGEND_HTML)
        with gr.Row():
            a_out = gr.HighlightedText(
                label="Model A detections",
                color_map=PALETTE,
                show_legend=False,
                combine_adjacent=False,
            )
            b_out = gr.HighlightedText(
                label="Model B detections",
                color_map=PALETTE,
                show_legend=False,
                combine_adjacent=False,
            )
        diff_out = gr.Markdown("_Run a detection to see the diff / span list._")
        extracted_out = gr.Textbox(
            label="Extracted text (read-only)", lines=6, interactive=False,
        )
        run_btn.click(
            run,
            [file_in, text_in, device_dd, model_a_tb, model_b_tb,
             use_a, use_b, aggregation_dd, viterbi_cb, viterbi_replace_cb,
             top_k_a_sl, top_k_b_sl, min_score_sl, chunk_sl],
            [a_out, b_out, diff_out, extracted_out],
        )
        clear_btn.click(clear_model_cache, None, diff_out)
    return demo


def parse_args() -> argparse.Namespace:
    """Parse CLI options: bind address/port, share link, and both model
    paths."""
    p = argparse.ArgumentParser(description="HarEmb PII — Gradio demo")
    p.add_argument("--host", default="127.0.0.1",
                   help="Bind address (default: 127.0.0.1)")
    p.add_argument("--port", type=int, default=7860,
                   help="Port (default: 7860)")
    p.add_argument("--share", action="store_true",
                   help="Create a public Gradio share link")
    p.add_argument("--model-a", default="OpenMed/privacy-filter-nemotron",
                   help="Model A path / HF repo "
                        "(default: OpenMed/privacy-filter-nemotron — teacher)")
    p.add_argument("--model-b", default=DEFAULT_MODEL,
                   help="Model B path / HF repo "
                        "(default: . — this checkpoint)")
    return p.parse_args()


if __name__ == "__main__":
    args = parse_args()
    # Theme is set on gr.Blocks inside build_ui; launch() does not accept it.
    build_ui(
        default_model_a=args.model_a,
        default_model_b=args.model_b,
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
    )