| """ |
| HarEmb PII — local Gradio inference demo. |
| |
| Upload a PDF (or paste text), pick a device (CPU / cuda:N), and the model |
| highlights detected PII spans across the 55-category Nemotron-PII taxonomy. |
| |
| Install: |
| pip install "gradio>=4" "transformers>=4.45" torch pypdf accelerate |
| |
| Run from inside this folder: |
| python app.py |
| """ |
from __future__ import annotations

import argparse
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import gradio as gr
import torch
from pypdf import PdfReader
from transformers import pipeline


DEFAULT_MODEL = "."

CHUNK_CHARS = 400_000
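# 400k characters per chunk stays well inside the model's ≈131k-token
# (~500k-char) window noted on the chunk-size slider below.
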
PALETTE: Dict[str, str] = {
    # Names / identity
    "first_name": "#ef4444",
    "last_name": "#ef4444",
    "user_name": "#ef4444",
    "company_name": "#ef4444",
    # Demographics
    "age": "#fb7185",
    "gender": "#fb7185",
    "race_ethnicity": "#fb7185",
    "sexuality": "#fb7185",
    "religious_belief": "#fb7185",
    "political_view": "#fb7185",
    "language": "#fb7185",
    "education_level": "#fb7185",
    "occupation": "#fb7185",
    "employment_status": "#fb7185",
    "blood_type": "#fb7185",
    "biometric_identifier": "#fb7185",
    # Contact
    "email": "#8b5cf6",
    "phone_number": "#a78bfa",
    "fax_number": "#a78bfa",
    "url": "#7c3aed",
    # Location
    "street_address": "#10b981",
    "city": "#34d399",
    "county": "#34d399",
    "state": "#34d399",
    "country": "#34d399",
    "postcode": "#34d399",
    "coordinate": "#059669",
    # Dates & times
    "date": "#3b82f6",
    "date_of_birth": "#60a5fa",
    "date_time": "#60a5fa",
    "time": "#60a5fa",
    # Government identifiers
    "ssn": "#f97316",
    "national_id": "#fb923c",
    "tax_id": "#fb923c",
    # Financial
    "account_number": "#f59e0b",
    "bank_routing_number": "#fbbf24",
    "swift_bic": "#fbbf24",
    "credit_debit_card": "#fbbf24",
    "cvv": "#fbbf24",
    "pin": "#fbbf24",
    "password": "#d97706",
    # Medical
    "medical_record_number": "#ec4899",
    "health_plan_beneficiary_number": "#f472b6",
    # Organisational / customer identifiers
    "customer_id": "#06b6d4",
    "employee_id": "#06b6d4",
    "unique_id": "#22d3ee",
    "certificate_license_number": "#22d3ee",
    # Vehicle
    "license_plate": "#84cc16",
    "vehicle_identifier": "#84cc16",
    # Network & device
    "ipv4": "#6366f1",
    "ipv6": "#6366f1",
    "mac_address": "#818cf8",
    "device_identifier": "#818cf8",
    "api_key": "#4f46e5",
    "http_cookie": "#4f46e5",
}
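# predict() silently drops any span whose category is missing from this
# mapping, so PALETTE doubles as the allow-list for the 55 taxonomy labels.
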
def list_devices() -> List[str]:
    devs = ["cpu"]
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            devs.append(f"cuda:{i}")
    return devs


# One pipeline per (model_path, device) pair, reused across runs.
_pipe_cache: Dict[Tuple[str, str], object] = {}


def get_pipe(model_path: str, device: str):
    key = (model_path, device)
    if key in _pipe_cache:
        return _pipe_cache[key]
    dtype = torch.bfloat16 if device.startswith("cuda") else torch.float32
    pipe = pipeline(
        "token-classification",
        model=model_path,
        tokenizer=model_path,
        trust_remote_code=True,
        aggregation_strategy="simple",
        device=device,
        torch_dtype=dtype,
    )
    _pipe_cache[key] = pipe
    return pipe
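# Minimal programmatic sketch (assumes this folder holds the checkpoint):
#
#     pipe = get_pipe(".", "cpu")
#     apply_runtime_config(pipe, use_viterbi=True, viterbi_replace=True)
#     pipe("Call Sarah Johnson at 415-555-0123.")
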
def apply_runtime_config(
    pipe,
    use_viterbi: bool,
    viterbi_replace: bool,
    top_k: Optional[int] = None,
) -> None:
    """Toggle the Viterbi decoding flags on the model config and, optionally,
    override the number of experts routed per token."""
    cfg = pipe.model.config
    if hasattr(cfg, "use_viterbi_decode"):
        cfg.use_viterbi_decode = bool(use_viterbi)
    if hasattr(cfg, "viterbi_replace_logits"):
        cfg.viterbi_replace_logits = bool(viterbi_replace)

    # Optionally override how many experts are routed per token. Patch both
    # the router's `top_k` and, where present, the MLP's `num_experts`, so
    # whichever attribute the MoE forward reads reflects the override.
    if top_k is not None:
        n_local = int(getattr(cfg, "num_local_experts", 128))
        k = max(1, min(int(top_k), n_local))
        for layer in pipe.model.model.layers:
            mlp = getattr(layer, "mlp", None)
            if mlp is None:
                continue
            router = getattr(mlp, "router", None)
            if router is not None and hasattr(router, "top_k"):
                router.top_k = k
            if hasattr(mlp, "num_experts"):
                mlp.num_experts = k
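# Note: this mutates the cached pipeline in place, so an override persists
# for that (model, device) pair until changed again or the cache is cleared.
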
def model_top_k_default(model_path: str) -> int:
    """Read the trained `num_experts_per_tok` from the model's config without
    loading the weights. Falls back to 4 if the field isn't present."""
    try:
        from transformers import AutoConfig

        cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        return int(getattr(cfg, "num_experts_per_tok", 4))
    except Exception:
        return 4


def model_num_experts(model_path: str) -> int:
    """Read `num_local_experts` from the model's config without loading
    weights. Falls back to 128 if the field isn't present."""
    try:
        from transformers import AutoConfig

        cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
        return int(getattr(cfg, "num_local_experts", 128))
    except Exception:
        return 128
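# Both config readers above are cheap (no weights are loaded), so build_ui
# can call them at startup to size the expert sliders.
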
def clear_model_cache() -> str:
    _pipe_cache.clear()
    if torch.cuda.is_available():
        # Return now-unreferenced CUDA memory to the driver.
        torch.cuda.empty_cache()
    return "Model cache cleared. Next run will reload weights."


def extract_text(file_obj) -> str:
    if file_obj is None:
        return ""
    # gr.File may hand us a tempfile-like object or a plain path string.
    path = file_obj.name if hasattr(file_obj, "name") else file_obj
    p = Path(path)
    if p.suffix.lower() == ".pdf":
        reader = PdfReader(str(p))
        # pypdf returns None for pages with no extractable text.
        return "\n\n".join((page.extract_text() or "") for page in reader.pages)
    return p.read_text(encoding="utf-8", errors="replace")


def chunk_text(text: str, max_chars: int = CHUNK_CHARS) -> List[Tuple[int, str]]:
    """Split `text` on paragraph boundaries into (offset, chunk) pieces of at
    most ~`max_chars` characters. A single paragraph longer than `max_chars`
    is kept whole rather than hard-split."""
    if not text:
        return []
    if max_chars <= 0 or len(text) <= max_chars:
        return [(0, text)]
    # Split while keeping the blank-line separators so offsets stay exact.
    pieces = re.split(r"(\n\s*\n)", text)
    chunks: List[Tuple[int, str]] = []
    cur, cur_off, pos = "", 0, 0
    for piece in pieces:
        if cur and len(cur) + len(piece) > max_chars and cur.strip():
            chunks.append((cur_off, cur))
            cur, cur_off = piece, pos
        else:
            if not cur:
                cur_off = pos
            cur += piece
        pos += len(piece)
    if cur.strip():
        chunks.append((cur_off, cur))
    return chunks
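# Offset bookkeeping example:
#     chunk_text("a\n\nb", max_chars=2) -> [(0, "a"), (1, "\n\nb")]
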
def category_of(label: str) -> str:
    """Strip a BIO/BIOES prefix such as "B-" or "I-" from a label."""
    if len(label) > 2 and label[1] == "-":
        return label[2:]
    return label
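# e.g. category_of("B-email") -> "email"; category_of("email") -> "email".
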
def predict(
    model_path: str,
    device: str,
    text: str,
    aggregation: str,
    use_viterbi: bool,
    viterbi_replace: bool,
    top_k: Optional[int] = None,
    chunk_chars: int = CHUNK_CHARS,
) -> List[Dict]:
    """Run the token-classification pipeline over `text` (chunked if needed)
    and return a flat, sorted list of span dicts with document-level offsets."""
    if not text.strip():
        return []
    pipe = get_pipe(model_path, device)
    apply_runtime_config(pipe, use_viterbi, viterbi_replace, top_k=top_k)
    spans: List[Dict] = []
    for offset, chunk in chunk_text(text, max_chars=chunk_chars):
        for ent in pipe(chunk, aggregation_strategy=aggregation):
            label = ent.get("entity_group") or ent.get("entity") or ""
            cat = category_of(label)
            if cat not in PALETTE:
                continue
            s = ent["start"] + offset
            e = ent["end"] + offset
            spans.append({
                "start": s, "end": e, "label": cat,
                "score": float(ent["score"]),
                "text": text[s:e],
            })
    spans.sort(key=lambda sp: (sp["start"], sp["end"]))
    return spans


def to_highlight(text: str, spans: List[Dict]) -> List[Tuple[str, Optional[str]]]:
    if not text:
        return []
    out: List[Tuple[str, Optional[str]]] = []
    cur = 0
    for s in spans:
        if s["start"] < cur:
            # Overlapping span; keep the earlier one.
            continue
        if s["start"] > cur:
            out.append((text[cur:s["start"]], None))
        out.append((text[s["start"]:s["end"]], s["label"]))
        cur = s["end"]
    if cur < len(text):
        out.append((text[cur:], None))
    return out
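# The (text, label) tuple list is the value format gr.HighlightedText
# expects; a None label renders the segment unhighlighted.
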
def fmt_spans(spans: List[Dict], max_rows: int = 60) -> str:
    if not spans:
        return "_No PII spans detected._"
    rows = [
        f"- `{s['label']}` `{s['text'][:80].replace('`', '')}` (score {s['score']:.2f})"
        for s in spans[:max_rows]
    ]
    more = f"\n\n_…+{len(spans) - max_rows} more_" if len(spans) > max_rows else ""
    return f"**Detected {len(spans)} span(s):**\n" + "\n".join(rows) + more


def _legend_html() -> str:
    # Group categories that share a colour so the legend shows one chip each.
    seen: Dict[str, List[str]] = {}
    for name, c in PALETTE.items():
        seen.setdefault(c, []).append(name)
    rows = []
    for c, names in seen.items():
        chip = (f"<span style='background:{c};color:#fff;padding:.15rem .55rem;"
                f"border-radius:.3rem;font-family:monospace;'>"
                f"{names[0]}{(' +' + str(len(names) - 1)) if len(names) > 1 else ''}</span>")
        rows.append(chip)
    return ("<div style='display:flex;flex-wrap:wrap;gap:.4rem;font-size:.85rem;"
            "margin:.25rem 0;'>" + "".join(rows) + "</div>")


LEGEND_HTML = _legend_html()


def diff_spans(a: List[Dict], b: List[Dict]):
    """Return (only_in_a, only_in_b, agreed) span lists, keyed by the
    (start, end, label) triple — agreement requires an identical category."""
    key = lambda s: (s["start"], s["end"], s["label"])
    sa = {key(s): s for s in a}
    sb = {key(s): s for s in b}
    only_a = [sa[k] for k in sa if k not in sb]
    only_b = [sb[k] for k in sb if k not in sa]
    both = [sa[k] for k in sa if k in sb]
    return only_a, only_b, both


def fmt_diff(label_a: str, label_b: str,
             only_a: List[Dict], only_b: List[Dict], agreed: List[Dict]) -> str:
    def fmt(name: str, lst: List[Dict]) -> str:
        if not lst:
            return f"**{name}:** none"
        rows = [
            f"- `{s['label']}` `{s['text'][:80].replace('`', '')}` "
            f"(score {s['score']:.2f})"
            for s in lst[:30]
        ]
        more = f"\n …+{len(lst) - 30} more" if len(lst) > 30 else ""
        return f"**{name} ({len(lst)}):**\n" + "\n".join(rows) + more

    return "\n\n".join([
        fmt(f"Only {label_a}", only_a),
        fmt(f"Only {label_b}", only_b),
        fmt("Agreed by both", agreed),
    ])


def run(
    file_obj, pasted_text, device,
    model_a_path, model_b_path,
    use_a, use_b,
    aggregation, use_viterbi, viterbi_replace,
    top_k_a, top_k_b,
    min_score, chunk_chars,
):
    text = extract_text(file_obj) if file_obj else (pasted_text or "")
    if not text.strip():
        return [], [], "_Provide a PDF, a text file, or pasted text._", ""
    if not (use_a or use_b):
        return [], [], "_Enable at least one model._", text

    a_spans = (
        predict(model_a_path, device, text, aggregation,
                use_viterbi, viterbi_replace,
                top_k=int(top_k_a), chunk_chars=int(chunk_chars))
        if use_a else []
    )
    b_spans = (
        predict(model_b_path, device, text, aggregation,
                use_viterbi, viterbi_replace,
                top_k=int(top_k_b), chunk_chars=int(chunk_chars))
        if use_b else []
    )
    thr = float(min_score)
    a_spans = [s for s in a_spans if s["score"] >= thr]
    b_spans = [s for s in b_spans if s["score"] >= thr]

    a_hl = to_highlight(text, a_spans) if use_a else []
    b_hl = to_highlight(text, b_spans) if use_b else []

    label_a = Path(model_a_path).name or model_a_path
    label_b = Path(model_b_path).name or model_b_path

    if use_a and use_b:
        only_a, only_b, agreed = diff_spans(a_spans, b_spans)
        diff_md = fmt_diff(label_a, label_b, only_a, only_b, agreed)
    elif use_a:
        diff_md = fmt_spans(a_spans)
    elif use_b:
        diff_md = fmt_spans(b_spans)
    else:
        diff_md = "_Enable a model._"

    return a_hl, b_hl, diff_md, text
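# run() returns (model-A highlights, model-B highlights, diff markdown,
# extracted text), matching the four outputs wired up in build_ui below.
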
def build_ui(default_model_a: str, default_model_b: str) -> gr.Blocks:
    a_default_k = model_top_k_default(default_model_a)
    a_n_experts = model_num_experts(default_model_a)
    b_default_k = model_top_k_default(default_model_b)
    b_n_experts = model_num_experts(default_model_b)

    with gr.Blocks(title="HarEmb PII", theme=gr.themes.Soft()) as demo:
        gr.Markdown(
            "# HarEmb · OpenMed-Nemotron PII\n"
            "Detect PII across the 55 categories of the Nemotron-PII "
            "taxonomy. Run **two models side-by-side** to compare detections "
            "— by default this checkpoint vs the OpenMed teacher it was "
            "distilled from. Disable one model to view a single model's "
            "detections."
        )
        devices = list_devices()
        with gr.Row():
            device_dd = gr.Dropdown(devices, value=devices[0], label="Device", scale=1)
            clear_btn = gr.Button("Clear model cache", variant="secondary", scale=1)

        with gr.Row():
            with gr.Column():
                use_a = gr.Checkbox(value=True, label="Enable model A (teacher / baseline)")
                model_a_tb = gr.Textbox(
                    value=default_model_a,
                    label="Model A — path / HF repo",
                    info="Default: OpenMed/privacy-filter-nemotron (teacher).",
                )
                top_k_a_sl = gr.Slider(
                    1, a_n_experts, value=a_default_k, step=1,
                    label=f"Active experts per token (top-k of {a_n_experts})",
                    info=f"Trained value: {a_default_k}. Lower = faster but "
                         f"less capacity per token; higher = more compute and "
                         f"denser routing. Overriding the trained value can "
                         f"drop quality — useful for ablations.",
                )
            with gr.Column():
                use_b = gr.Checkbox(value=True, label="Enable model B (this checkpoint)")
                model_b_tb = gr.Textbox(
                    value=default_model_b,
                    label="Model B — path / HF repo",
                    info="Default: ./ (this checkpoint).",
                )
                top_k_b_sl = gr.Slider(
                    1, b_n_experts, value=b_default_k, step=1,
                    label=f"Active experts per token (top-k of {b_n_experts})",
                    info=f"Trained value: {b_default_k}.",
                )

        with gr.Accordion("Inference settings", open=False):
            with gr.Row():
                aggregation_dd = gr.Dropdown(
                    ["simple", "first", "max", "average", "none"],
                    value="simple",
                    label="aggregation_strategy",
                    info="How token-level labels are merged into spans.",
                )
                viterbi_cb = gr.Checkbox(
                    value=True,
                    label="use_viterbi_decode",
                    info="Constrained BIOES decoding (off = raw argmax).",
                )
                viterbi_replace_cb = gr.Checkbox(
                    value=True,
                    label="viterbi_replace_logits",
                    info="When on, outputs.logits.argmax(-1) returns the Viterbi path.",
                )
                min_score_sl = gr.Slider(
                    0.0, 1.0, value=0.0, step=0.01,
                    label="min confidence",
                    info="Filter out spans scoring below this threshold.",
                )
                chunk_sl = gr.Slider(
                    0, 500_000, value=CHUNK_CHARS, step=10_000,
                    label="chunk size (chars)",
                    info="0 = single pass; otherwise split on paragraphs at this size. "
                         "Model window ≈131k tokens (~500k chars).",
                )

        with gr.Row():
            file_in = gr.File(label="PDF / text file", file_types=[".pdf", ".txt", ".md"])
            text_in = gr.Textbox(
                label="…or paste text",
                lines=6,
                placeholder=("Patient Sarah Johnson (DOB 03/15/1985), MRN 4872910, "
                             "phone 415-555-0123, email sarah.johnson@example.com."),
            )
        run_btn = gr.Button("Detect PII", variant="primary")

        gr.HTML(LEGEND_HTML)
        with gr.Row():
            a_out = gr.HighlightedText(
                label="Model A detections",
                color_map=PALETTE,
                show_legend=False,
                combine_adjacent=False,
            )
            b_out = gr.HighlightedText(
                label="Model B detections",
                color_map=PALETTE,
                show_legend=False,
                combine_adjacent=False,
            )
        diff_out = gr.Markdown("_Run a detection to see the diff / span list._")
        extracted_out = gr.Textbox(
            label="Extracted text (read-only)", lines=6, interactive=False,
        )

        run_btn.click(
            run,
            [file_in, text_in, device_dd,
             model_a_tb, model_b_tb, use_a, use_b,
             aggregation_dd, viterbi_cb, viterbi_replace_cb,
             top_k_a_sl, top_k_b_sl,
             min_score_sl, chunk_sl],
            [a_out, b_out, diff_out, extracted_out],
        )
        clear_btn.click(clear_model_cache, None, diff_out)

    return demo
def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="HarEmb PII — Gradio demo")
    p.add_argument("--host", default="127.0.0.1", help="Bind address (default: 127.0.0.1)")
    p.add_argument("--port", type=int, default=7860, help="Port (default: 7860)")
    p.add_argument("--share", action="store_true", help="Create a public Gradio share link")
    p.add_argument("--model-a", default="OpenMed/privacy-filter-nemotron",
                   help="Model A path / HF repo "
                        "(default: OpenMed/privacy-filter-nemotron — teacher)")
    p.add_argument("--model-b", default=DEFAULT_MODEL,
                   help="Model B path / HF repo "
                        "(default: . — this checkpoint)")
    return p.parse_args()
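# Examples:
#     python app.py
#     python app.py --host 0.0.0.0 --port 7861 --share
#     python app.py --model-b /path/to/checkpoint
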
if __name__ == "__main__":
    args = parse_args()
    build_ui(
        default_model_a=args.model_a,
        default_model_b=args.model_b,
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
    )