""" PII Reveal - Document Privacy Explorer (v5) ============================================ Changes from v4 (feedback): 1. Light theme refresh Crisper, higher-contrast neutral palette: bright white cards on a very light grey body, near-black body text, and subdued text bumped one stop darker so labels stop disappearing. Category highlight alpha bumped from 12% to 16% on light, 15% on dark. 2. PDF redaction export New POST /api/redact-pdf endpoint. Accepts the original PDF plus the list of spans + active labels the client is viewing, applies true PyMuPDF redactions (black fill, underlying text removed — not just a visual overlay), and streams the result back. The inspector gets an "Export PDF" primary button when the input file is a PDF. 3. Performance Two code-level fixes for the 100k-token slowness on T4: a) predict_text: dropped the unbind(0) -> list -> stack(0) roundtrip in favour of a single torch.cat. It was allocating 100k separate 33-wide tensors and re-stacking them for no reason. b) Decoder.decode: the Viterbi loop is inherently sequential and launches O(seq_len) CUDA kernels — on Turing (T4, compute 7.5) kernel-launch overhead dominated because the state space is tiny (33 classes). It now runs on CPU, which is bandwidth-bound on a 33x33 matrix and completes in a couple of seconds for 100k tokens. Also cached the Decoder itself with lru_cache (was being rebuilt per request). Hardware note: T4 is pre-Ampere and has no native bf16 support, so every attention matmul is emulated. Code-level changes help the decoder, but the model's attention pass will still be faster on L4 / A10 / A100 by a large factor. The v5 fixes were validated against the "is it code or hardware" question: both, but the decoder path was the dominant contribution for a 100k document. """ # ── stdlib ─────────────────────────────────────────────────────── import dataclasses import functools import io import json import math import os import re import tempfile from bisect import bisect_left, bisect_right from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path from typing import Final # ── third-party ────────────────────────────────────────────────── import gradio as gr import spaces import tiktoken import torch import torch.nn.functional as F from fastapi import File, Form, UploadFile from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse from huggingface_hub import snapshot_download from safetensors import safe_open # ── configuration ──────────────────────────────────────────────── MODEL_REPO = os.getenv("MODEL_ID", "charles-first-org/second-model") HF_TOKEN = os.getenv("HF_TOKEN", None) MODEL_DIR = Path(snapshot_download(MODEL_REPO, token=HF_TOKEN)) CATEGORIES_META = { "private_person": {"color": "#E24B4A", "cls": "hp", "label": "Person", "mono": False}, "private_date": {"color": "#7F77DD", "cls": "hd", "label": "Date", "mono": True}, "private_address": {"color": "#1D9E75", "cls": "ha", "label": "Address", "mono": False}, "private_email": {"color": "#378ADD", "cls": "he", "label": "Email", "mono": True}, "account_number": {"color": "#BA7517", "cls": "hac", "label": "Account", "mono": True}, "private_url": {"color": "#D85A30", "cls": "hu", "label": "URL", "mono": True}, "secret": {"color": "#D4537E", "cls": "hs", "label": "Secret", "mono": True}, "private_phone": {"color": "#639922", "cls": "hph", "label": "Phone", "mono": True}, } # ===================================================================== # MODEL ARCHITECTURE + INFERENCE (from reference implementation) # ===================================================================== PRIVACY_FILTER_MODEL_TYPE: Final[str] = "privacy_filter" REQUIRED_MODEL_CONFIG_KEYS: Final[tuple[str, ...]] = ( "model_type", "encoding", "num_hidden_layers", "num_experts", "experts_per_token", "vocab_size", "num_labels", "hidden_size", "intermediate_size", "head_dim", "num_attention_heads", "num_key_value_heads", "sliding_window", "bidirectional_context", "bidirectional_left_context", "bidirectional_right_context", "default_n_ctx", "initial_context_length", "rope_theta", "rope_scaling_factor", "rope_ntk_alpha", "rope_ntk_beta", "param_dtype", ) BACKGROUND_CLASS_LABEL: Final[str] = "O" BOUNDARY_PREFIXES: Final[tuple[str, ...]] = ("B", "I", "E", "S") SPAN_CLASS_NAMES: Final[tuple[str, ...]] = ( BACKGROUND_CLASS_LABEL, "account_number", "private_address", "private_date", "private_email", "private_person", "private_phone", "private_url", "secret", ) NER_CLASS_NAMES: Final[tuple[str, ...]] = (BACKGROUND_CLASS_LABEL,) + tuple( f"{prefix}-{base}" for base in SPAN_CLASS_NAMES if base != BACKGROUND_CLASS_LABEL for prefix in BOUNDARY_PREFIXES ) VITERBI_TRANSITION_BIAS_KEYS: Final[tuple[str, ...]] = ( "transition_bias_background_stay", "transition_bias_background_to_start", "transition_bias_inside_to_continue", "transition_bias_inside_to_end", "transition_bias_end_to_background", "transition_bias_end_to_start", ) DEFAULT_VITERBI_CALIBRATION_PRESET: Final[str] = "default" def validate_model_config_contract(cfg: dict, *, context: str) -> None: missing = [k for k in REQUIRED_MODEL_CONFIG_KEYS if k not in cfg] if missing: raise ValueError(f"{context} missing keys: {', '.join(missing)}") if cfg.get("model_type") != PRIVACY_FILTER_MODEL_TYPE: raise ValueError(f"{context} model_type must be {PRIVACY_FILTER_MODEL_TYPE!r}") if cfg.get("bidirectional_context") is not True: raise ValueError(f"{context} must use bidirectional_context=true") lc, rc = cfg.get("bidirectional_left_context"), cfg.get("bidirectional_right_context") if not isinstance(lc, int) or not isinstance(rc, int) or lc != rc or lc < 0: raise ValueError(f"{context} bidirectional context must be equal non-negative ints") sw = cfg.get("sliding_window") if sw != 2 * lc + 1: raise ValueError(f"{context} sliding_window must equal 2*context+1") if cfg["num_labels"] != 33: raise ValueError(f"{context} num_labels must be 33") if cfg["param_dtype"] != "bfloat16": raise ValueError(f"{context} param_dtype must be bfloat16") # ── model helpers ──────────────────────────────────────────────── def expert_linear(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None) -> torch.Tensor: n, e, k = x.shape _, _, _, o = weight.shape out = torch.bmm(x.reshape(n * e, 1, k), weight.reshape(n * e, k, o)).reshape(n, e, o) return out + bias if bias is not None else out @dataclass class ModelConfig: num_hidden_layers: int; num_experts: int; experts_per_token: int vocab_size: int; num_labels: int; hidden_size: int; intermediate_size: int head_dim: int; num_attention_heads: int; num_key_value_heads: int bidirectional_context_size: int; initial_context_length: int rope_theta: float; rope_scaling_factor: float; rope_ntk_alpha: float; rope_ntk_beta: float @classmethod def from_checkpoint_config(cls, cfg: dict, *, context: str) -> "ModelConfig": cfg = dict(cfg) cfg["bidirectional_context_size"] = cfg["bidirectional_left_context"] fields = {f.name for f in dataclasses.fields(cls)} return cls(**{k: v for k, v in cfg.items() if k in fields}) class RMSNorm(torch.nn.Module): def __init__(self, n: int, eps: float = 1e-5, device=None): super().__init__() self.eps = eps self.scale = torch.nn.Parameter(torch.ones(n, device=device, dtype=torch.float32)) def forward(self, x): t = x.float() return (t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + self.eps) * self.scale).to(x.dtype) def apply_rope(x, cos, sin): cos = cos.unsqueeze(-2).to(x.dtype); sin = sin.unsqueeze(-2).to(x.dtype) x1, x2 = x[..., ::2], x[..., 1::2] return torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).reshape(x.shape) class RotaryEmbedding(torch.nn.Module): def __init__(self, head_dim, base, dtype, *, initial_context_length=4096, scaling_factor=1.0, ntk_alpha=1.0, ntk_beta=32.0, device=None): super().__init__() self.head_dim, self.base, self.dtype = head_dim, base, dtype self.initial_context_length = initial_context_length self.scaling_factor, self.ntk_alpha, self.ntk_beta = scaling_factor, ntk_alpha, ntk_beta self.device = device mp = max(int(initial_context_length * scaling_factor), initial_context_length) self.max_position_embeddings = mp cos, sin = self._compute(mp, device=torch.device("cpu")) target = device or torch.device("cpu") self.register_buffer("cos_cache", cos.to(target), persistent=False) self.register_buffer("sin_cache", sin.to(target), persistent=False) def _inv_freq(self, device=None): device = device or self.device freq = self.base ** (torch.arange(0, self.head_dim, 2, dtype=torch.float, device=device) / self.head_dim) if self.scaling_factor > 1.0: d_half = self.head_dim / 2 low = d_half * math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi)) / math.log(self.base) high = d_half * math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi)) / math.log(self.base) interp = 1.0 / (self.scaling_factor * freq) extrap = 1.0 / freq ramp = (torch.arange(d_half, dtype=torch.float32, device=device) - low) / (high - low) mask = 1 - ramp.clamp(0, 1) return interp * (1 - mask) + extrap * mask return 1.0 / freq def _compute(self, n, device=None): inv_freq = self._inv_freq(device) t = torch.arange(n, dtype=torch.float32, device=device or self.device) freqs = torch.einsum("i,j->ij", t, inv_freq) c = 0.1 * math.log(self.scaling_factor) + 1.0 if self.scaling_factor > 1.0 else 1.0 return (freqs.cos() * c).to(self.dtype), (freqs.sin() * c).to(self.dtype) def forward(self, q, k): n = q.shape[0] if n > self.cos_cache.shape[0]: cos, sin = self._compute(n, torch.device("cpu")) self.cos_cache, self.sin_cache = cos.to(q.device), sin.to(q.device) cc = self.cos_cache.to(q.device) if self.cos_cache.device != q.device else self.cos_cache sc = self.sin_cache.to(q.device) if self.sin_cache.device != q.device else self.sin_cache cos, sin = cc[:n], sc[:n] q = apply_rope(q.view(n, -1, self.head_dim), cos, sin).reshape(q.shape) k = apply_rope(k.view(n, -1, self.head_dim), cos, sin).reshape(k.shape) return q, k def sdpa(Q, K, V, S, sm_scale, ctx): n, nh, qm, hd = Q.shape w = 2 * ctx + 1 Kp = F.pad(K, (0, 0, 0, 0, ctx, ctx)); Vp = F.pad(V, (0, 0, 0, 0, ctx, ctx)) Kw = Kp.unfold(0, w, 1).permute(0, 3, 1, 2); Vw = Vp.unfold(0, w, 1).permute(0, 3, 1, 2) idx = torch.arange(w, device=Q.device) - ctx pos = torch.arange(n, device=Q.device)[:, None] + idx[None, :] valid = (pos >= 0) & (pos < n) scores = torch.einsum("nhqd,nwhd->nhqw", Q, Kw).float() * sm_scale scores = scores.masked_fill(~valid[:, None, None, :], -float("inf")) sink = (S * math.log(2.0)).reshape(nh, qm)[None, :, :, None].expand(n, -1, -1, 1) scores = torch.cat([scores, sink], dim=-1) wt = torch.softmax(scores, dim=-1)[..., :-1].to(V.dtype) return torch.einsum("nhqw,nwhd->nhqd", wt, Vw).reshape(n, -1) class AttentionBlock(torch.nn.Module): def __init__(self, cfg: ModelConfig, device=None): super().__init__() dt = torch.bfloat16 self.head_dim, self.nah, self.nkv = cfg.head_dim, cfg.num_attention_heads, cfg.num_key_value_heads self.ctx = int(cfg.bidirectional_context_size) self.sinks = torch.nn.Parameter(torch.empty(cfg.num_attention_heads, device=device, dtype=torch.float32)) self.norm = RMSNorm(cfg.hidden_size, device=device) qkv_d = cfg.head_dim * (cfg.num_attention_heads + 2 * cfg.num_key_value_heads) self.qkv = torch.nn.Linear(cfg.hidden_size, qkv_d, device=device, dtype=dt) self.out = torch.nn.Linear(cfg.head_dim * cfg.num_attention_heads, cfg.hidden_size, device=device, dtype=dt) self.qk_scale = 1 / math.sqrt(math.sqrt(cfg.head_dim)) self.rope = RotaryEmbedding(cfg.head_dim, int(cfg.rope_theta), torch.float32, initial_context_length=cfg.initial_context_length, scaling_factor=cfg.rope_scaling_factor, ntk_alpha=cfg.rope_ntk_alpha, ntk_beta=cfg.rope_ntk_beta, device=device) def forward(self, x): t = self.norm(x).to(self.qkv.weight.dtype) qkv = F.linear(t, self.qkv.weight, self.qkv.bias) hd, nah, nkv = self.head_dim, self.nah, self.nkv q = qkv[:, :nah * hd].contiguous() k = qkv[:, nah * hd:(nah + nkv) * hd].contiguous() v = qkv[:, (nah + nkv) * hd:(nah + 2 * nkv) * hd].contiguous() q, k = self.rope(q, k) q, k = q * self.qk_scale, k * self.qk_scale n = q.shape[0] q = q.view(n, nkv, nah // nkv, hd); k = k.view(n, nkv, hd); v = v.view(n, nkv, hd) ao = sdpa(q, k, v, self.sinks, 1.0, self.ctx).to(self.out.weight.dtype) return x + F.linear(ao, self.out.weight, self.out.bias).to(x.dtype) def swiglu(x, alpha=1.702, limit=7.0): g, l = x.chunk(2, dim=-1) g, l = g.clamp(max=limit), l.clamp(-limit, limit) return g * torch.sigmoid(alpha * g) * (l + 1) class MLPBlock(torch.nn.Module): def __init__(self, cfg: ModelConfig, device=None): super().__init__() dt = torch.bfloat16 self.ne, self.ept = cfg.num_experts, cfg.experts_per_token self.norm = RMSNorm(cfg.hidden_size, device=device) self.gate = torch.nn.Linear(cfg.hidden_size, cfg.num_experts, device=device, dtype=dt) self.mlp1_weight = torch.nn.Parameter(torch.empty(cfg.num_experts, cfg.hidden_size, cfg.intermediate_size * 2, device=device, dtype=dt)) self.mlp1_bias = torch.nn.Parameter(torch.empty(cfg.num_experts, cfg.intermediate_size * 2, device=device, dtype=dt)) self.mlp2_weight = torch.nn.Parameter(torch.empty(cfg.num_experts, cfg.intermediate_size, cfg.hidden_size, device=device, dtype=dt)) self.mlp2_bias = torch.nn.Parameter(torch.empty(cfg.num_experts, cfg.hidden_size, device=device, dtype=dt)) def forward(self, x): t = self.norm(x) gs = F.linear(t.float(), self.gate.weight.float(), self.gate.bias.float()) top = torch.topk(gs, k=self.ept, dim=-1, sorted=True) ew = torch.softmax(top.values, dim=-1) / self.ept ei = top.indices ept = self.ept def _chunk(tc, eic, ewc): o = expert_linear(tc.float().unsqueeze(1).expand(-1, eic.shape[1], -1), self.mlp1_weight[eic].float(), self.mlp1_bias[eic].float()) o = swiglu(o) o = expert_linear(o.float(), self.mlp2_weight[eic].float(), self.mlp2_bias[eic].float()) return (torch.einsum("bec,be->bc", o.to(ewc.dtype), ewc) * ept).to(x.dtype) cs = 32 if t.shape[0] > cs: parts = [_chunk(t[s:s+cs], ei[s:s+cs], ew[s:s+cs]) for s in range(0, t.shape[0], cs)] return x + torch.cat(parts, 0) return x + _chunk(t, ei, ew) class TransformerBlock(torch.nn.Module): def __init__(self, cfg, device=None): super().__init__() self.attn = AttentionBlock(cfg, device=device) self.mlp = MLPBlock(cfg, device=device) def forward(self, x): return self.mlp(self.attn(x)) class Checkpoint: @staticmethod def build_param_name_map(n): return ({f"block.{i}.mlp.mlp1_bias": f"block.{i}.mlp.swiglu.bias" for i in range(n)} | {f"block.{i}.mlp.mlp1_weight": f"block.{i}.mlp.swiglu.weight" for i in range(n)} | {f"block.{i}.mlp.mlp2_bias": f"block.{i}.mlp.out.bias" for i in range(n)} | {f"block.{i}.mlp.mlp2_weight": f"block.{i}.mlp.out.weight" for i in range(n)}) def __init__(self, path, device, num_hidden_layers): self.pnm = self.build_param_name_map(num_hidden_layers) self.ds = device.type if device.index is None else f"{device.type}:{device.index}" files = [os.path.join(path, f) for f in os.listdir(path) if f.endswith(".safetensors")] self.map = {} for sf in files: with safe_open(sf, framework="pt", device=self.ds) as h: for k in h.keys(): self.map[k] = sf def get(self, name): mapped = self.pnm.get(name, name) with safe_open(self.map[mapped], framework="pt", device=self.ds) as h: return h.get_tensor(mapped) class Transformer(torch.nn.Module): def __init__(self, cfg, device): super().__init__() dt = torch.bfloat16 self.embedding = torch.nn.Embedding(cfg.vocab_size, cfg.hidden_size, device=device, dtype=dt) self.block = torch.nn.ModuleList([TransformerBlock(cfg, device=device) for _ in range(cfg.num_hidden_layers)]) self.norm = RMSNorm(cfg.hidden_size, device=device) self.unembedding = torch.nn.Linear(cfg.hidden_size, cfg.num_labels, bias=False, device=device, dtype=dt) def forward(self, token_ids): x = self.embedding(token_ids) for blk in self.block: x = blk(x) return F.linear(self.norm(x), self.unembedding.weight, None) @classmethod def from_checkpoint(cls, checkpoint_dir, *, device): torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False torch.set_float32_matmul_precision("highest") cp = json.loads((Path(checkpoint_dir) / "config.json").read_text()) validate_model_config_contract(cp, context=str(checkpoint_dir)) cfg = ModelConfig.from_checkpoint_config(cp, context=str(checkpoint_dir)) ckpt = Checkpoint(checkpoint_dir, device, cfg.num_hidden_layers) m = cls(cfg, device); m.eval() for name, param in m.named_parameters(): loaded = ckpt.get(name) if param.shape != loaded.shape: raise ValueError(f"Shape mismatch {name}: {param.shape} vs {loaded.shape}") param.data.copy_(loaded) return m # ── label info + span decoding ─────────────────────────────────── @dataclass(frozen=True) class LabelInfo: boundary_label_lookup: dict[str, dict[str, int]] token_to_span_label: dict[int, int] token_boundary_tags: dict[int, str | None] span_class_names: tuple[str, ...] span_label_lookup: dict[str, int] background_token_label: int background_span_label: int def labels_to_spans(labels_by_index, label_info): spans, cur_label, start_idx, prev_idx = [], None, None, None bg = label_info.background_span_label for ti in sorted(labels_by_index): lid = labels_by_index[ti] sl = label_info.token_to_span_label.get(lid) bt = label_info.token_boundary_tags.get(lid) if prev_idx is not None and ti != prev_idx + 1: if cur_label is not None and start_idx is not None: spans.append((cur_label, start_idx, prev_idx + 1)) cur_label = start_idx = None if sl is None: prev_idx = ti; continue if sl == bg: if cur_label is not None and start_idx is not None: spans.append((cur_label, start_idx, ti)) cur_label = start_idx = None; prev_idx = ti; continue if bt == "S": if cur_label is not None and start_idx is not None and prev_idx is not None: spans.append((cur_label, start_idx, prev_idx + 1)) spans.append((sl, ti, ti + 1)); cur_label = start_idx = None elif bt == "B": if cur_label is not None and start_idx is not None and prev_idx is not None: spans.append((cur_label, start_idx, prev_idx + 1)) cur_label, start_idx = sl, ti elif bt == "I": if cur_label is None or cur_label != sl: if cur_label is not None and start_idx is not None and prev_idx is not None: spans.append((cur_label, start_idx, prev_idx + 1)) cur_label, start_idx = sl, ti elif bt == "E": if cur_label is None or cur_label != sl or start_idx is None: if cur_label is not None and start_idx is not None and prev_idx is not None: spans.append((cur_label, start_idx, prev_idx + 1)) spans.append((sl, ti, ti + 1)); cur_label = start_idx = None else: spans.append((cur_label, start_idx, ti + 1)); cur_label = start_idx = None else: if cur_label is not None and start_idx is not None and prev_idx is not None: spans.append((cur_label, start_idx, prev_idx + 1)) cur_label = start_idx = None prev_idx = ti if cur_label is not None and start_idx is not None and prev_idx is not None: spans.append((cur_label, start_idx, prev_idx + 1)) return spans def token_spans_to_char_spans(spans, cs, ce): out = [] for li, ts, te in spans: if not (0 <= ts < te <= len(cs)): continue s, e = cs[ts], ce[te - 1] if e > s: out.append((li, s, e)) return out def trim_char_spans_whitespace(spans, text): out = [] for li, s, e in spans: if not (0 <= s < e <= len(text)): continue while s < e and text[s].isspace(): s += 1 while e > s and text[e - 1].isspace(): e -= 1 if e > s: out.append((li, s, e)) return out # ── viterbi decoder ────────────────────────────────────────────── @functools.lru_cache(maxsize=1) def get_viterbi_transition_biases(): cp = MODEL_DIR / "viterbi_calibration.json" default = {k: 0.0 for k in VITERBI_TRANSITION_BIAS_KEYS} if not cp.is_file(): return default payload = json.loads(cp.read_text()) raw = payload ops = payload.get("operating_points") if isinstance(ops, dict): preset = ops.get(DEFAULT_VITERBI_CALIBRATION_PRESET) if isinstance(preset, dict): raw = preset.get("biases", raw) if not isinstance(raw, dict): return default return {k: float(raw.get(k, 0.0)) for k in VITERBI_TRANSITION_BIAS_KEYS} class Decoder: def __init__(self, label_info): nc = len(label_info.token_to_span_label) self._start = torch.full((nc,), -1e9, dtype=torch.float32) self._end = torch.full((nc,), -1e9, dtype=torch.float32) self._trans = torch.full((nc, nc), -1e9, dtype=torch.float32) biases = get_viterbi_transition_biases() bg_tok, bg_sp = label_info.background_token_label, label_info.background_span_label ttsl, tbt = label_info.token_to_span_label, label_info.token_boundary_tags for i in range(nc): tag, sl = tbt.get(i), ttsl.get(i) if tag in {"B", "S"} or i == bg_tok: self._start[i] = 0.0 if tag in {"E", "S"} or i == bg_tok: self._end[i] = 0.0 for j in range(nc): nt, ns = tbt.get(j), ttsl.get(j) if self._valid(tag, sl, nt, ns, bg_tok, bg_sp, j): self._trans[i, j] = self._bias(tag, sl, nt, ns, bg_sp, biases) @staticmethod def _valid(pt, ps, nt, ns, bti, bsi, ni): nb = ns == bsi or ni == bti if (ns is None or nt is None) and not nb: return False if pt is None or ps is None: return nb or nt in {"B", "S"} if ps == bsi or pt in {"E", "S"}: return nb or nt in {"B", "S"} if pt in {"B", "I"}: return ps == ns and nt in {"I", "E"} return False @staticmethod def _bias(pt, ps, nt, ns, bsi, b): nb, pb = ns == bsi, ps == bsi if pb: return b["transition_bias_background_stay"] if nb else b["transition_bias_background_to_start"] if pt in {"B", "I"}: return b["transition_bias_inside_to_continue"] if nt == "I" else b["transition_bias_inside_to_end"] return b["transition_bias_end_to_background"] if nb else b["transition_bias_end_to_start"] def decode(self, lp): # Sequential Viterbi over a tiny (33-class) state space. On T4 the # per-step CUDA kernel launches dominated runtime for 100k+ tokens, # so run on CPU unconditionally — it's bandwidth-bound on 33x33 and # avoids one CUDA sync per timestep. if lp.is_cuda: lp = lp.to("cpu", dtype=torch.float32, non_blocking=True) else: lp = lp.to(dtype=torch.float32) sl, nc = lp.shape if sl == 0: return [] st, en, tr = self._start, self._end, self._trans # already CPU/fp32 scores = lp[0] + st bp = torch.empty((sl - 1, nc), dtype=torch.int64) for i in range(1, sl): t = scores.unsqueeze(1) + tr bs, bi = t.max(dim=0) scores = bs + lp[i]; bp[i - 1] = bi if not torch.isfinite(scores).any(): return lp.argmax(dim=1).tolist() scores = scores + en path = torch.empty(sl, dtype=torch.int64) path[-1] = scores.argmax() for i in range(sl - 2, -1, -1): path[i] = bp[i, path[i + 1]] return path.tolist() # ── runtime singleton ──────────────────────────────────────────── @dataclass(frozen=True) class InferenceRuntime: model: Transformer; encoding: tiktoken.Encoding; label_info: LabelInfo device: torch.device; n_ctx: int @functools.lru_cache(maxsize=1) def get_runtime(): cp = MODEL_DIR cfg = json.loads((cp / "config.json").read_text()) validate_model_config_contract(cfg, context=str(cp)) device = torch.device("cuda") encoding = tiktoken.get_encoding(str(cfg["encoding"]).strip()) scn = [BACKGROUND_CLASS_LABEL]; sll = {BACKGROUND_CLASS_LABEL: 0} bll, ttsl, tbt = {}, {}, {} bg_idx = None for idx, name in enumerate(NER_CLASS_NAMES): if name == BACKGROUND_CLASS_LABEL: bg_idx = idx; ttsl[idx] = 0; tbt[idx] = None; continue bnd, base = name.split("-", 1) si = sll.get(base) if si is None: si = len(scn); scn.append(base); sll[base] = si ttsl[idx] = si; tbt[idx] = bnd bll.setdefault(base, {})[bnd] = idx li = LabelInfo(bll, ttsl, tbt, tuple(scn), sll, bg_idx, 0) m = Transformer.from_checkpoint(str(cp), device=device) return InferenceRuntime(m, encoding, li, device, int(cfg["default_n_ctx"])) @functools.lru_cache(maxsize=1) def get_decoder(): return Decoder(label_info=get_runtime().label_info) @torch.inference_mode() def predict_text(runtime, text, decoder): tids = tuple(int(t) for t in runtime.encoding.encode(text, allowed_special="all")) if not tids: return text, [] # Run the model per-chunk and concat once. The v4 code built a Python # list via `.unbind(0)` and then rebuilt the same tensor via stack — a # no-op that paid 100k small allocations on long inputs. chunks = [] for s in range(0, len(tids), runtime.n_ctx): e = min(s + runtime.n_ctx, len(tids)) wt = torch.tensor(tids[s:e], device=runtime.device, dtype=torch.int32) lp = F.log_softmax(runtime.model(wt).float(), dim=-1) chunks.append(lp) stacked = chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=0) dl = decoder.decode(stacked) # Decoder pulls to CPU internally if len(dl) != len(tids): dl = stacked.argmax(dim=1).tolist() pli = {i: int(l) for i, l in enumerate(dl)} pts = labels_to_spans(pli, runtime.label_info) tb = [runtime.encoding.decode_single_token_bytes(t) for t in tids] dt = b"".join(tb).decode("utf-8", errors="replace") cbs, cbe = [], [] bc = 0 for ch in dt: cbs.append(bc); bc += len(ch.encode("utf-8")); cbe.append(bc) cs, ce = [], [] tbc = 0 for rb in tb: tbs = tbc; tbe = tbs + len(rb); tbc = tbe cs.append(bisect_right(cbe, tbs)); ce.append(bisect_left(cbs, tbe)) pcs = token_spans_to_char_spans(pts, cs, ce) pcs = trim_char_spans_whitespace(pcs, dt if dt != text else text) src = dt if dt != text else text detected = [] for li, s, e in pcs: if 0 <= li < len(runtime.label_info.span_class_names): lbl = runtime.label_info.span_class_names[li] else: lbl = f"label_{li}" detected.append({"label": lbl, "start": s, "end": e, "text": src[s:e]}) return src, detected # ===================================================================== # APPLICATION LAYER # ===================================================================== def extract_text(file_path: str) -> str: suffix = Path(file_path).suffix.lower() if suffix == ".pdf": import fitz doc = fitz.open(file_path) pages = [page.get_text() for page in doc] doc.close() return "\n\n".join(pages) elif suffix in (".docx", ".doc"): from docx import Document doc = Document(file_path) return "\n\n".join(p.text for p in doc.paragraphs if p.text.strip()) raise ValueError(f"Unsupported file type: {suffix}") def compute_stats(text, spans): total = len(text) pii_chars = sum(s["end"] - s["start"] for s in spans) by_cat = {} for s in spans: c = s["label"] by_cat.setdefault(c, {"count": 0, "chars": 0}) by_cat[c]["count"] += 1; by_cat[c]["chars"] += s["end"] - s["start"] return { "total_chars": total, "pii_chars": pii_chars, "pii_percentage": round(pii_chars / total * 100, 1) if total else 0, "total_spans": len(spans), "categories": by_cat, "num_categories": len(by_cat), "total_lines": text.count("\n") + 1 if total else 0, } def detect_speakers(text, spans): patterns = [r"^([A-Z][a-zA-Z ]{1,30}):\s", r"^\[([^\]]{1,30})\]\s", r"^(Speaker\s*\d+):\s"] line_sp, pos, cur = [], 0, None for line in text.split("\n"): for p in patterns: m = re.match(p, line) if m: cur = m.group(1).strip(); break line_sp.append((pos, pos + len(line), cur)); pos += len(line) + 1 result = {} for span in spans: mid = (span["start"] + span["end"]) // 2 speaker = "Document" for ls, le, sp in line_sp: if ls <= mid <= le and sp: speaker = sp; break result[speaker] = result.get(speaker, 0) + 1 return {} if list(result.keys()) == ["Document"] else result @spaces.GPU def run_pii_analysis(text: str): """GPU-accelerated PII detection.""" runtime = get_runtime() decoder = get_decoder() # cached, not rebuilt per request source_text, detected = predict_text(runtime, text, decoder) return source_text, detected def build_redacted_pdf_bytes(pdf_path: str, pii_texts: list[str]) -> bytes: """ True PyMuPDF redaction: draws a black fill rectangle over the target text AND removes the underlying text stream. Longer strings are redacted first so fuller matches win over their substrings. """ import fitz # Longest first: "Dr. Margaret Holloway" before "Margaret" ordered = sorted({t.strip() for t in pii_texts if t and len(t.strip()) >= 2}, key=len, reverse=True) doc = fitz.open(pdf_path) try: for page in doc: for needle in ordered: for rect in page.search_for(needle): page.add_redact_annot(rect, fill=(0, 0, 0)) page.apply_redactions() buf = io.BytesIO() doc.save(buf, garbage=4, deflate=True) return buf.getvalue() finally: doc.close() # ── Gradio Server ──────────────────────────────────────────────── server = gr.Server() @server.get("/", response_class=HTMLResponse) async def homepage(): return FRONTEND_HTML @server.post("/api/analyze") async def analyze_document(file: UploadFile = File(...)): suffix = Path(file.filename).suffix.lower() if suffix not in (".pdf", ".doc", ".docx"): return JSONResponse({"error": f"Unsupported: {suffix}. Use PDF, DOC, or DOCX."}, 400) with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: tmp.write(await file.read()); tmp_path = tmp.name try: text = extract_text(tmp_path) if not text.strip(): return JSONResponse({"error": "No text content found."}, 400) source_text, spans = run_pii_analysis(text) stats = compute_stats(source_text, spans) speakers = detect_speakers(source_text, spans) return JSONResponse({ "filename": file.filename, "text": source_text, "spans": spans, "stats": stats, "speakers": speakers, "categories_meta": {k: {"color": v["color"], "cls": v["cls"], "label": v["label"], "mono": v["mono"]} for k, v in CATEGORIES_META.items()}, }) except Exception as e: return JSONResponse({"error": str(e)}, 500) finally: if os.path.exists(tmp_path): os.unlink(tmp_path) @server.post("/api/redact-pdf") async def redact_pdf_endpoint( file: UploadFile = File(...), spans: str = Form(...), active: str = Form(...), ): suffix = Path(file.filename).suffix.lower() if suffix != ".pdf": return JSONResponse({"error": "PDF redaction only accepts PDF input."}, 400) try: span_list = json.loads(spans) active_set = set(json.loads(active)) except Exception as e: return JSONResponse({"error": f"Invalid payload: {e}"}, 400) pii_texts = [ s.get("text", "") for s in span_list if s.get("label") in active_set ] if not pii_texts: return JSONResponse({"error": "No active categories selected — nothing to redact."}, 400) with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: tmp.write(await file.read()); tmp_path = tmp.name try: pdf_bytes = build_redacted_pdf_bytes(tmp_path, pii_texts) out_name = (Path(file.filename).stem or "document") + ".redacted.pdf" return StreamingResponse( io.BytesIO(pdf_bytes), media_type="application/pdf", headers={"Content-Disposition": f'attachment; filename="{out_name}"'}, ) except Exception as e: return JSONResponse({"error": str(e)}, 500) finally: if os.path.exists(tmp_path): os.unlink(tmp_path) @server.api(name="analyze_text") def analyze_text_api(text: str) -> str: """Gradio API: analyze raw text for PII.""" source_text, spans = run_pii_analysis(text) stats = compute_stats(source_text, spans) return json.dumps({"text": source_text, "spans": spans, "stats": stats}, ensure_ascii=False) # ── Frontend HTML (v5) ─────────────────────────────────────────── FRONTEND_HTML = r""" PII Reveal — Inspector
PII Reveal/ inspector

See what your documents are leaking.

Find every PII span in a PDF, DOC or DOCX — names, accounts, secrets and five other entity types — then export a fully redacted copy.

Person Email Date Address Account URL Phone Secret
Drop a document, or click to browse
pdf · doc · docx · up to 128k tokens
openai privacy filter 128k ctx bfloat16 apache 2.0
Scan complete
0%
PII content
0
Spans detected
0 / 8
Categories present
0
Speakers identified
scanning document…
""" # ── launch ─────────────────────────────────────────────────────── if __name__ == "__main__": server.launch(server_name="0.0.0.0", server_port=7860)