""" HarEmb PII — local Gradio inference demo. Upload a PDF (or paste text), pick a device (CPU / cuda:N), and the model highlights detected PII spans across the 55-category Nemotron-PII taxonomy. Install: pip install "gradio>=4" "transformers>=4.45" torch pypdf accelerate Run from inside this folder: python app.py """ from __future__ import annotations import argparse import re from pathlib import Path from typing import Dict, List, Optional, Tuple import gradio as gr import torch from pypdf import PdfReader from transformers import pipeline # Default to loading from this folder so `python app.py` works in-place after # downloading the repo. Override by setting --model-path on the CLI. DEFAULT_MODEL = "." CHUNK_CHARS = 400_000 # ~100k tokens; well under the model's 131k window # 55 Nemotron-PII categories grouped for visual coherence; one color per # coarse "family" so the highlight legend stays readable. PALETTE: Dict[str, str] = { # Identity (red) "first_name": "#ef4444", "last_name": "#ef4444", "user_name": "#ef4444", "company_name": "#ef4444", "age": "#fb7185", "gender": "#fb7185", "race_ethnicity": "#fb7185", "sexuality": "#fb7185", "religious_belief": "#fb7185", "political_view": "#fb7185", "language": "#fb7185", "education_level": "#fb7185", "occupation": "#fb7185", "employment_status": "#fb7185", "blood_type": "#fb7185", "biometric_identifier":"#fb7185", # Contact (purple) "email": "#8b5cf6", "phone_number": "#a78bfa", "fax_number": "#a78bfa", "url": "#7c3aed", # Address (green) "street_address": "#10b981", "city": "#34d399", "county": "#34d399", "state": "#34d399", "country": "#34d399", "postcode": "#34d399", "coordinate": "#059669", # Dates (blue) "date": "#3b82f6", "date_of_birth": "#60a5fa", "date_time": "#60a5fa", "time": "#60a5fa", # Government IDs (orange) "ssn": "#f97316", "national_id": "#fb923c", "tax_id": "#fb923c", # Financial (amber) "account_number": "#f59e0b", "bank_routing_number": "#fbbf24", "swift_bic": "#fbbf24", "credit_debit_card": 
"#fbbf24", "cvv": "#fbbf24", "pin": "#fbbf24", "password": "#d97706", # Healthcare (pink) "medical_record_number": "#ec4899", "health_plan_beneficiary_number": "#f472b6", # Enterprise IDs (cyan) "customer_id": "#06b6d4", "employee_id": "#06b6d4", "unique_id": "#22d3ee", "certificate_license_number": "#22d3ee", # Vehicle (lime) "license_plate": "#84cc16", "vehicle_identifier": "#84cc16", # Digital (indigo) "ipv4": "#6366f1", "ipv6": "#6366f1", "mac_address": "#818cf8", "device_identifier": "#818cf8", "api_key": "#4f46e5", "http_cookie": "#4f46e5", } def list_devices() -> List[str]: devs = ["cpu"] if torch.cuda.is_available(): for i in range(torch.cuda.device_count()): devs.append(f"cuda:{i}") return devs _pipe_cache: Dict[Tuple[str, str], object] = {} def get_pipe(model_path: str, device: str): key = (model_path, device) if key in _pipe_cache: return _pipe_cache[key] dtype = torch.bfloat16 if device.startswith("cuda") else torch.float32 pipe = pipeline( "token-classification", model=model_path, tokenizer=model_path, trust_remote_code=True, aggregation_strategy="simple", device=device, torch_dtype=dtype, ) _pipe_cache[key] = pipe return pipe def apply_runtime_config( pipe, use_viterbi: bool, viterbi_replace: bool, top_k: Optional[int] = None, ) -> None: cfg = pipe.model.config if hasattr(cfg, "use_viterbi_decode"): cfg.use_viterbi_decode = bool(use_viterbi) if hasattr(cfg, "viterbi_replace_logits"): cfg.viterbi_replace_logits = bool(viterbi_replace) # Override the per-layer MoE top-k at inference. Both fields need to be # set: `mlp.router.top_k` is the actual router top-k, and the upstream # `mlp.num_experts` is misnamed (it's also the per-token top_k, not # num_local_experts). top_k=None leaves the trained config alone. 
if top_k is not None: n_local = int(getattr(cfg, "num_local_experts", 128)) k = max(1, min(int(top_k), n_local)) for layer in pipe.model.model.layers: mlp = getattr(layer, "mlp", None) if mlp is None: continue router = getattr(mlp, "router", None) if router is not None and hasattr(router, "top_k"): router.top_k = k if hasattr(mlp, "num_experts"): mlp.num_experts = k def model_top_k_default(model_path: str) -> int: """Read the trained `num_experts_per_tok` from the model's config without loading the weights. Falls back to 4 if the field isn't present.""" try: from transformers import AutoConfig cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True) return int(getattr(cfg, "num_experts_per_tok", 4)) except Exception: return 4 def model_num_experts(model_path: str) -> int: """Read `num_local_experts` from the model's config without loading weights. Falls back to 128 if the field isn't present.""" try: from transformers import AutoConfig cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True) return int(getattr(cfg, "num_local_experts", 128)) except Exception: return 128 def clear_model_cache() -> str: _pipe_cache.clear() if torch.cuda.is_available(): torch.cuda.empty_cache() return "Model cache cleared. Next run will reload weights." 
def extract_text(file_obj) -> str:
    """Return the plain text of an uploaded file.

    Accepts a Gradio file object (anything with a `.name` path attribute) or a
    bare path string. PDFs are read page-by-page via pypdf; pages with no
    extractable text contribute an empty string. Everything else is read as
    UTF-8 text with undecodable bytes replaced.
    """
    if file_obj is None:
        return ""
    path = file_obj.name if hasattr(file_obj, "name") else file_obj
    p = Path(path)
    if p.suffix.lower() == ".pdf":
        reader = PdfReader(str(p))
        return "\n\n".join((page.extract_text() or "") for page in reader.pages)
    return p.read_text(encoding="utf-8", errors="replace")


def chunk_text(text: str, max_chars: int = CHUNK_CHARS) -> List[Tuple[int, str]]:
    """Split `text` into (char_offset, chunk) pairs of at most ~max_chars each.

    Splits on blank-line boundaries (the capturing group keeps the separators
    so offsets stay exact). A single paragraph longer than max_chars becomes
    one oversized chunk rather than being cut mid-paragraph. Whitespace-only
    trailing content is dropped.
    """
    if not text:
        return []
    if max_chars <= 0 or len(text) <= max_chars:
        return [(0, text)]
    pieces = re.split(r"(\n\s*\n)", text)
    chunks: List[Tuple[int, str]] = []
    # cur = chunk being accumulated; cur_off = its start offset in `text`;
    # pos = offset of the next piece.
    cur, cur_off, pos = "", 0, 0
    for piece in pieces:
        # Flush only when the current chunk is non-empty AND has real content.
        if cur and len(cur) + len(piece) > max_chars and cur.strip():
            chunks.append((cur_off, cur))
            cur, cur_off = piece, pos
        else:
            if not cur:
                cur_off = pos
            cur += piece
        pos += len(piece)
    if cur.strip():
        chunks.append((cur_off, cur))
    return chunks


def category_of(label: str) -> str:
    """Strip a BIO prefix ("B-", "I-", ...) from a tag, e.g. "B-email" -> "email".

    Labels without a "-" in position 1 are returned unchanged.
    """
    if len(label) > 2 and label[1] == "-":
        return label[2:]
    return label


def predict(
    model_path: str,
    device: str,
    text: str,
    aggregation: str,
    use_viterbi: bool,
    viterbi_replace: bool,
    top_k: Optional[int] = None,
    chunk_chars: int = CHUNK_CHARS,
) -> List[Dict]:
    """Run PII detection over `text` and return sorted span dicts.

    Each span dict has keys: start, end (absolute char offsets into `text`),
    label (PALETTE category), score (float), and text (the matched substring).
    Entities whose category is not in PALETTE are discarded. The text is
    chunked so long documents fit the model window; chunk offsets are added
    back so spans index into the full input.
    """
    if not text.strip():
        return []
    pipe = get_pipe(model_path, device)
    # Re-apply UI toggles every call: the pipeline is cached across requests.
    apply_runtime_config(pipe, use_viterbi, viterbi_replace, top_k=top_k)
    spans: List[Dict] = []
    for offset, chunk in chunk_text(text, max_chars=chunk_chars):
        for ent in pipe(chunk, aggregation_strategy=aggregation):
            # Aggregated pipelines emit "entity_group"; raw ones emit "entity".
            label = ent.get("entity_group") or ent.get("entity") or ""
            cat = category_of(label)
            if cat not in PALETTE:
                continue
            s = ent["start"] + offset
            e = ent["end"] + offset
            spans.append({
                "start": s,
                "end": e,
                "label": cat,
                "score": float(ent["score"]),
                "text": text[s:e],
            })
    spans.sort(key=lambda s: (s["start"], s["end"]))
    return spans


def to_highlight(text: str, spans: List[Dict]) -> List[Tuple[str, Optional[str]]]:
    """Convert `text` + spans into gr.HighlightedText segments.

    Returns (segment, label-or-None) pairs covering the whole text in order.
    Spans must be sorted by start; a span overlapping an already-emitted
    region is skipped entirely rather than truncated.
    """
    if not text:
        return []
    out: List[Tuple[str, Optional[str]]] = []
    cur = 0
    for s in spans:
        if s["start"] < cur:
            continue  # overlaps previous span — drop it
        if s["start"] > cur:
            out.append((text[cur:s["start"]], None))  # unlabeled gap
        out.append((text[s["start"]:s["end"]], s["label"]))
        cur = s["end"]
    if cur < len(text):
        out.append((text[cur:], None))
    return out


def fmt_spans(spans: List[Dict], max_rows: int = 60) -> str:
    """Render detected spans as a Markdown bullet list, capped at `max_rows`.

    Span text is clipped to 80 chars and backticks are removed so the inline
    code formatting can't be broken by the matched text.
    """
    if not spans:
        return "_No PII spans detected._"
    rows = [
        f"- `{s['label']}` `{s['text'][:80].replace('`', '')}` (score {s['score']:.2f})"
        for s in spans[:max_rows]
    ]
    more = f"\n\n_…+{len(spans) - max_rows} more_" if len(spans) > max_rows else ""
    return f"**Detected {len(spans)} span(s):**\n" + "\n".join(rows) + more


# Build legend HTML for the categories present in PALETTE — one row per family
# (we still want it readable; show one swatch per unique color).
# NOTE(review): the HTML markup inside this function appears to have been
# stripped/garbled by text extraction (empty f"" literal, truncated return) —
# recover the original markup from version control before editing this body.
def _legend_html() -> str:
    seen = {}
    for name, c in PALETTE.items():
        seen.setdefault(c, []).append(name)
    rows = []
    for c, names in seen.items():
        chip = (f"" f"{names[0]}{(' +'+str(len(names)-1)) if len(names)>1 else ''}")
        rows.append(chip)
    return ("