ysharma HF Staff commited on
Commit
01719f8
Β·
verified Β·
1 Parent(s): 47b2347

Update app_v6.py

Browse files
Files changed (1) hide show
  1. app_v6.py +154 -619
app_v6.py CHANGED
@@ -1,40 +1,42 @@
1
  """
2
- ============================================
3
- PII Explorer - Document Privacy Explorer / Playground
4
- ============================================
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  """
6
 
7
  # ── stdlib ───────────────────────────────────────────────────────
8
- import dataclasses
9
  import functools
10
  import io
11
  import json
12
- import math
13
  import os
14
  import re
15
  import tempfile
16
  import time
17
- from bisect import bisect_left, bisect_right
18
- from collections.abc import Sequence
19
- from dataclasses import dataclass
20
  from pathlib import Path
21
- from typing import Final
22
 
23
  # ── third-party ──────────────────────────────────────────────────
24
  import gradio as gr
25
  import spaces
26
- import tiktoken
27
  import torch
28
- import torch.nn.functional as F
29
- from fastapi import File, Form, UploadFile
30
- from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
31
- from huggingface_hub import snapshot_download
32
- from safetensors import safe_open
33
 
34
  # ── configuration ────────────────────────────────────────────────
35
- MODEL_REPO = os.getenv("MODEL_ID", "charles-first-org/second-model")
36
  HF_TOKEN = os.getenv("HF_TOKEN", None)
37
- MODEL_DIR = Path(snapshot_download(MODEL_REPO, token=HF_TOKEN))
38
 
39
  CATEGORIES_META = {
40
  "private_person": {"color": "#E24B4A", "cls": "hp", "label": "Person", "mono": False},
@@ -48,537 +50,41 @@ CATEGORIES_META = {
48
  }
49
 
50
  # =====================================================================
51
- # MODEL ARCHITECTURE + INFERENCE (unchanged from reference impl)
52
  # =====================================================================
53
 
54
- PRIVACY_FILTER_MODEL_TYPE: Final[str] = "privacy_filter"
55
- REQUIRED_MODEL_CONFIG_KEYS: Final[tuple[str, ...]] = (
56
- "model_type", "encoding", "num_hidden_layers", "num_experts",
57
- "experts_per_token", "vocab_size", "num_labels", "hidden_size",
58
- "intermediate_size", "head_dim", "num_attention_heads",
59
- "num_key_value_heads", "sliding_window", "bidirectional_context",
60
- "bidirectional_left_context", "bidirectional_right_context",
61
- "default_n_ctx", "initial_context_length", "rope_theta",
62
- "rope_scaling_factor", "rope_ntk_alpha", "rope_ntk_beta", "param_dtype",
63
- )
64
- BACKGROUND_CLASS_LABEL: Final[str] = "O"
65
- BOUNDARY_PREFIXES: Final[tuple[str, ...]] = ("B", "I", "E", "S")
66
- SPAN_CLASS_NAMES: Final[tuple[str, ...]] = (
67
- BACKGROUND_CLASS_LABEL,
68
- "account_number", "private_address", "private_date", "private_email",
69
- "private_person", "private_phone", "private_url", "secret",
70
- )
71
- NER_CLASS_NAMES: Final[tuple[str, ...]] = (BACKGROUND_CLASS_LABEL,) + tuple(
72
- f"{prefix}-{base}"
73
- for base in SPAN_CLASS_NAMES if base != BACKGROUND_CLASS_LABEL
74
- for prefix in BOUNDARY_PREFIXES
75
- )
76
- VITERBI_TRANSITION_BIAS_KEYS: Final[tuple[str, ...]] = (
77
- "transition_bias_background_stay", "transition_bias_background_to_start",
78
- "transition_bias_inside_to_continue", "transition_bias_inside_to_end",
79
- "transition_bias_end_to_background", "transition_bias_end_to_start",
80
- )
81
- DEFAULT_VITERBI_CALIBRATION_PRESET: Final[str] = "default"
82
-
83
-
84
- def validate_model_config_contract(cfg: dict, *, context: str) -> None:
85
- missing = [k for k in REQUIRED_MODEL_CONFIG_KEYS if k not in cfg]
86
- if missing:
87
- raise ValueError(f"{context} missing keys: {', '.join(missing)}")
88
- if cfg.get("model_type") != PRIVACY_FILTER_MODEL_TYPE:
89
- raise ValueError(f"{context} model_type must be {PRIVACY_FILTER_MODEL_TYPE!r}")
90
- if cfg.get("bidirectional_context") is not True:
91
- raise ValueError(f"{context} must use bidirectional_context=true")
92
- lc, rc = cfg.get("bidirectional_left_context"), cfg.get("bidirectional_right_context")
93
- if not isinstance(lc, int) or not isinstance(rc, int) or lc != rc or lc < 0:
94
- raise ValueError(f"{context} bidirectional context must be equal non-negative ints")
95
- sw = cfg.get("sliding_window")
96
- if sw != 2 * lc + 1:
97
- raise ValueError(f"{context} sliding_window must equal 2*context+1")
98
- if cfg["num_labels"] != 33:
99
- raise ValueError(f"{context} num_labels must be 33")
100
- if cfg["param_dtype"] != "bfloat16":
101
- raise ValueError(f"{context} param_dtype must be bfloat16")
102
-
103
-
104
- def expert_linear(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None) -> torch.Tensor:
105
- n, e, k = x.shape
106
- _, _, _, o = weight.shape
107
- out = torch.bmm(x.reshape(n * e, 1, k), weight.reshape(n * e, k, o)).reshape(n, e, o)
108
- return out + bias if bias is not None else out
109
-
110
-
111
- @dataclass
112
- class ModelConfig:
113
- num_hidden_layers: int; num_experts: int; experts_per_token: int
114
- vocab_size: int; num_labels: int; hidden_size: int; intermediate_size: int
115
- head_dim: int; num_attention_heads: int; num_key_value_heads: int
116
- bidirectional_context_size: int; initial_context_length: int
117
- rope_theta: float; rope_scaling_factor: float; rope_ntk_alpha: float; rope_ntk_beta: float
118
-
119
- @classmethod
120
- def from_checkpoint_config(cls, cfg: dict, *, context: str) -> "ModelConfig":
121
- cfg = dict(cfg)
122
- cfg["bidirectional_context_size"] = cfg["bidirectional_left_context"]
123
- fields = {f.name for f in dataclasses.fields(cls)}
124
- return cls(**{k: v for k, v in cfg.items() if k in fields})
125
-
126
-
127
- class RMSNorm(torch.nn.Module):
128
- def __init__(self, n: int, eps: float = 1e-5, device=None):
129
- super().__init__()
130
- self.eps = eps
131
- self.scale = torch.nn.Parameter(torch.ones(n, device=device, dtype=torch.float32))
132
-
133
- def forward(self, x):
134
- t = x.float()
135
- return (t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + self.eps) * self.scale).to(x.dtype)
136
-
137
-
138
- def apply_rope(x, cos, sin):
139
- cos = cos.unsqueeze(-2).to(x.dtype); sin = sin.unsqueeze(-2).to(x.dtype)
140
- x1, x2 = x[..., ::2], x[..., 1::2]
141
- return torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1).reshape(x.shape)
142
-
143
-
144
- class RotaryEmbedding(torch.nn.Module):
145
- def __init__(self, head_dim, base, dtype, *, initial_context_length=4096,
146
- scaling_factor=1.0, ntk_alpha=1.0, ntk_beta=32.0, device=None):
147
- super().__init__()
148
- self.head_dim, self.base, self.dtype = head_dim, base, dtype
149
- self.initial_context_length = initial_context_length
150
- self.scaling_factor, self.ntk_alpha, self.ntk_beta = scaling_factor, ntk_alpha, ntk_beta
151
- self.device = device
152
- mp = max(int(initial_context_length * scaling_factor), initial_context_length)
153
- self.max_position_embeddings = mp
154
- cos, sin = self._compute(mp, device=torch.device("cpu"))
155
- target = device or torch.device("cpu")
156
- self.register_buffer("cos_cache", cos.to(target), persistent=False)
157
- self.register_buffer("sin_cache", sin.to(target), persistent=False)
158
-
159
- def _inv_freq(self, device=None):
160
- device = device or self.device
161
- freq = self.base ** (torch.arange(0, self.head_dim, 2, dtype=torch.float, device=device) / self.head_dim)
162
- if self.scaling_factor > 1.0:
163
- d_half = self.head_dim / 2
164
- low = d_half * math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi)) / math.log(self.base)
165
- high = d_half * math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi)) / math.log(self.base)
166
- interp = 1.0 / (self.scaling_factor * freq)
167
- extrap = 1.0 / freq
168
- ramp = (torch.arange(d_half, dtype=torch.float32, device=device) - low) / (high - low)
169
- mask = 1 - ramp.clamp(0, 1)
170
- return interp * (1 - mask) + extrap * mask
171
- return 1.0 / freq
172
-
173
- def _compute(self, n, device=None):
174
- inv_freq = self._inv_freq(device)
175
- t = torch.arange(n, dtype=torch.float32, device=device or self.device)
176
- freqs = torch.einsum("i,j->ij", t, inv_freq)
177
- c = 0.1 * math.log(self.scaling_factor) + 1.0 if self.scaling_factor > 1.0 else 1.0
178
- return (freqs.cos() * c).to(self.dtype), (freqs.sin() * c).to(self.dtype)
179
-
180
- def forward(self, q, k):
181
- n = q.shape[0]
182
- if n > self.cos_cache.shape[0]:
183
- cos, sin = self._compute(n, torch.device("cpu"))
184
- self.cos_cache, self.sin_cache = cos.to(q.device), sin.to(q.device)
185
- cc = self.cos_cache.to(q.device) if self.cos_cache.device != q.device else self.cos_cache
186
- sc = self.sin_cache.to(q.device) if self.sin_cache.device != q.device else self.sin_cache
187
- cos, sin = cc[:n], sc[:n]
188
- q = apply_rope(q.view(n, -1, self.head_dim), cos, sin).reshape(q.shape)
189
- k = apply_rope(k.view(n, -1, self.head_dim), cos, sin).reshape(k.shape)
190
- return q, k
191
-
192
-
193
- def sdpa(Q, K, V, S, sm_scale, ctx):
194
- n, nh, qm, hd = Q.shape
195
- w = 2 * ctx + 1
196
- Kp = F.pad(K, (0, 0, 0, 0, ctx, ctx)); Vp = F.pad(V, (0, 0, 0, 0, ctx, ctx))
197
- Kw = Kp.unfold(0, w, 1).permute(0, 3, 1, 2); Vw = Vp.unfold(0, w, 1).permute(0, 3, 1, 2)
198
- idx = torch.arange(w, device=Q.device) - ctx
199
- pos = torch.arange(n, device=Q.device)[:, None] + idx[None, :]
200
- valid = (pos >= 0) & (pos < n)
201
- scores = torch.einsum("nhqd,nwhd->nhqw", Q, Kw).float() * sm_scale
202
- scores = scores.masked_fill(~valid[:, None, None, :], -float("inf"))
203
- sink = (S * math.log(2.0)).reshape(nh, qm)[None, :, :, None].expand(n, -1, -1, 1)
204
- scores = torch.cat([scores, sink], dim=-1)
205
- wt = torch.softmax(scores, dim=-1)[..., :-1].to(V.dtype)
206
- return torch.einsum("nhqw,nwhd->nhqd", wt, Vw).reshape(n, -1)
207
-
208
-
209
- class AttentionBlock(torch.nn.Module):
210
- def __init__(self, cfg: ModelConfig, device=None):
211
- super().__init__()
212
- dt = torch.bfloat16
213
- self.head_dim, self.nah, self.nkv = cfg.head_dim, cfg.num_attention_heads, cfg.num_key_value_heads
214
- self.ctx = int(cfg.bidirectional_context_size)
215
- self.sinks = torch.nn.Parameter(torch.empty(cfg.num_attention_heads, device=device, dtype=torch.float32))
216
- self.norm = RMSNorm(cfg.hidden_size, device=device)
217
- qkv_d = cfg.head_dim * (cfg.num_attention_heads + 2 * cfg.num_key_value_heads)
218
- self.qkv = torch.nn.Linear(cfg.hidden_size, qkv_d, device=device, dtype=dt)
219
- self.out = torch.nn.Linear(cfg.head_dim * cfg.num_attention_heads, cfg.hidden_size, device=device, dtype=dt)
220
- self.qk_scale = 1 / math.sqrt(math.sqrt(cfg.head_dim))
221
- self.rope = RotaryEmbedding(cfg.head_dim, int(cfg.rope_theta), torch.float32,
222
- initial_context_length=cfg.initial_context_length,
223
- scaling_factor=cfg.rope_scaling_factor,
224
- ntk_alpha=cfg.rope_ntk_alpha, ntk_beta=cfg.rope_ntk_beta, device=device)
225
-
226
- def forward(self, x):
227
- t = self.norm(x).to(self.qkv.weight.dtype)
228
- qkv = F.linear(t, self.qkv.weight, self.qkv.bias)
229
- hd, nah, nkv = self.head_dim, self.nah, self.nkv
230
- q = qkv[:, :nah * hd].contiguous()
231
- k = qkv[:, nah * hd:(nah + nkv) * hd].contiguous()
232
- v = qkv[:, (nah + nkv) * hd:(nah + 2 * nkv) * hd].contiguous()
233
- q, k = self.rope(q, k)
234
- q, k = q * self.qk_scale, k * self.qk_scale
235
- n = q.shape[0]
236
- q = q.view(n, nkv, nah // nkv, hd); k = k.view(n, nkv, hd); v = v.view(n, nkv, hd)
237
- ao = sdpa(q, k, v, self.sinks, 1.0, self.ctx).to(self.out.weight.dtype)
238
- return x + F.linear(ao, self.out.weight, self.out.bias).to(x.dtype)
239
-
240
-
241
- def swiglu(x, alpha=1.702, limit=7.0):
242
- g, l = x.chunk(2, dim=-1)
243
- g, l = g.clamp(max=limit), l.clamp(-limit, limit)
244
- return g * torch.sigmoid(alpha * g) * (l + 1)
245
-
246
-
247
- class MLPBlock(torch.nn.Module):
248
- def __init__(self, cfg: ModelConfig, device=None):
249
- super().__init__()
250
- dt = torch.bfloat16
251
- self.ne, self.ept = cfg.num_experts, cfg.experts_per_token
252
- self.norm = RMSNorm(cfg.hidden_size, device=device)
253
- self.gate = torch.nn.Linear(cfg.hidden_size, cfg.num_experts, device=device, dtype=dt)
254
- self.mlp1_weight = torch.nn.Parameter(torch.empty(cfg.num_experts, cfg.hidden_size, cfg.intermediate_size * 2, device=device, dtype=dt))
255
- self.mlp1_bias = torch.nn.Parameter(torch.empty(cfg.num_experts, cfg.intermediate_size * 2, device=device, dtype=dt))
256
- self.mlp2_weight = torch.nn.Parameter(torch.empty(cfg.num_experts, cfg.intermediate_size, cfg.hidden_size, device=device, dtype=dt))
257
- self.mlp2_bias = torch.nn.Parameter(torch.empty(cfg.num_experts, cfg.hidden_size, device=device, dtype=dt))
258
-
259
- def forward(self, x):
260
- t = self.norm(x)
261
- gs = F.linear(t.float(), self.gate.weight.float(), self.gate.bias.float())
262
- top = torch.topk(gs, k=self.ept, dim=-1, sorted=True)
263
- ew = torch.softmax(top.values, dim=-1) / self.ept
264
- ei = top.indices
265
- ept = self.ept
266
-
267
- def _chunk(tc, eic, ewc):
268
- o = expert_linear(tc.float().unsqueeze(1).expand(-1, eic.shape[1], -1),
269
- self.mlp1_weight[eic].float(), self.mlp1_bias[eic].float())
270
- o = swiglu(o)
271
- o = expert_linear(o.float(), self.mlp2_weight[eic].float(), self.mlp2_bias[eic].float())
272
- return (torch.einsum("bec,be->bc", o.to(ewc.dtype), ewc) * ept).to(x.dtype)
273
-
274
- cs = 32
275
- if t.shape[0] > cs:
276
- parts = [_chunk(t[s:s+cs], ei[s:s+cs], ew[s:s+cs]) for s in range(0, t.shape[0], cs)]
277
- return x + torch.cat(parts, 0)
278
- return x + _chunk(t, ei, ew)
279
-
280
-
281
- class TransformerBlock(torch.nn.Module):
282
- def __init__(self, cfg, device=None):
283
- super().__init__()
284
- self.attn = AttentionBlock(cfg, device=device)
285
- self.mlp = MLPBlock(cfg, device=device)
286
- def forward(self, x):
287
- return self.mlp(self.attn(x))
288
-
289
-
290
- class Checkpoint:
291
- @staticmethod
292
- def build_param_name_map(n):
293
- return ({f"block.{i}.mlp.mlp1_bias": f"block.{i}.mlp.swiglu.bias" for i in range(n)}
294
- | {f"block.{i}.mlp.mlp1_weight": f"block.{i}.mlp.swiglu.weight" for i in range(n)}
295
- | {f"block.{i}.mlp.mlp2_bias": f"block.{i}.mlp.out.bias" for i in range(n)}
296
- | {f"block.{i}.mlp.mlp2_weight": f"block.{i}.mlp.out.weight" for i in range(n)})
297
-
298
- def __init__(self, path, device, num_hidden_layers):
299
- self.pnm = self.build_param_name_map(num_hidden_layers)
300
- self.ds = device.type if device.index is None else f"{device.type}:{device.index}"
301
- files = [os.path.join(path, f) for f in os.listdir(path) if f.endswith(".safetensors")]
302
- self.map = {}
303
- for sf in files:
304
- with safe_open(sf, framework="pt", device=self.ds) as h:
305
- for k in h.keys():
306
- self.map[k] = sf
307
-
308
- def get(self, name):
309
- mapped = self.pnm.get(name, name)
310
- with safe_open(self.map[mapped], framework="pt", device=self.ds) as h:
311
- return h.get_tensor(mapped)
312
-
313
-
314
- class Transformer(torch.nn.Module):
315
- def __init__(self, cfg, device):
316
- super().__init__()
317
- dt = torch.bfloat16
318
- self.embedding = torch.nn.Embedding(cfg.vocab_size, cfg.hidden_size, device=device, dtype=dt)
319
- self.block = torch.nn.ModuleList([TransformerBlock(cfg, device=device) for _ in range(cfg.num_hidden_layers)])
320
- self.norm = RMSNorm(cfg.hidden_size, device=device)
321
- self.unembedding = torch.nn.Linear(cfg.hidden_size, cfg.num_labels, bias=False, device=device, dtype=dt)
322
-
323
- def forward(self, token_ids):
324
- x = self.embedding(token_ids)
325
- for blk in self.block:
326
- x = blk(x)
327
- return F.linear(self.norm(x), self.unembedding.weight, None)
328
-
329
- @classmethod
330
- def from_checkpoint(cls, checkpoint_dir, *, device):
331
- torch.backends.cuda.matmul.allow_tf32 = False
332
- torch.backends.cudnn.allow_tf32 = False
333
- torch.set_float32_matmul_precision("highest")
334
- cp = json.loads((Path(checkpoint_dir) / "config.json").read_text())
335
- validate_model_config_contract(cp, context=str(checkpoint_dir))
336
- cfg = ModelConfig.from_checkpoint_config(cp, context=str(checkpoint_dir))
337
- ckpt = Checkpoint(checkpoint_dir, device, cfg.num_hidden_layers)
338
- m = cls(cfg, device); m.eval()
339
- for name, param in m.named_parameters():
340
- loaded = ckpt.get(name)
341
- if param.shape != loaded.shape:
342
- raise ValueError(f"Shape mismatch {name}: {param.shape} vs {loaded.shape}")
343
- param.data.copy_(loaded)
344
- return m
345
-
346
-
347
- # ── label info + span decoding ───────────────────────────────────
348
-
349
- @dataclass(frozen=True)
350
- class LabelInfo:
351
- boundary_label_lookup: dict[str, dict[str, int]]
352
- token_to_span_label: dict[int, int]
353
- token_boundary_tags: dict[int, str | None]
354
- span_class_names: tuple[str, ...]
355
- span_label_lookup: dict[str, int]
356
- background_token_label: int
357
- background_span_label: int
358
-
359
-
360
- def labels_to_spans(labels_by_index, label_info):
361
- spans, cur_label, start_idx, prev_idx = [], None, None, None
362
- bg = label_info.background_span_label
363
- for ti in sorted(labels_by_index):
364
- lid = labels_by_index[ti]
365
- sl = label_info.token_to_span_label.get(lid)
366
- bt = label_info.token_boundary_tags.get(lid)
367
- if prev_idx is not None and ti != prev_idx + 1:
368
- if cur_label is not None and start_idx is not None:
369
- spans.append((cur_label, start_idx, prev_idx + 1))
370
- cur_label = start_idx = None
371
- if sl is None:
372
- prev_idx = ti; continue
373
- if sl == bg:
374
- if cur_label is not None and start_idx is not None:
375
- spans.append((cur_label, start_idx, ti))
376
- cur_label = start_idx = None; prev_idx = ti; continue
377
- if bt == "S":
378
- if cur_label is not None and start_idx is not None and prev_idx is not None:
379
- spans.append((cur_label, start_idx, prev_idx + 1))
380
- spans.append((sl, ti, ti + 1)); cur_label = start_idx = None
381
- elif bt == "B":
382
- if cur_label is not None and start_idx is not None and prev_idx is not None:
383
- spans.append((cur_label, start_idx, prev_idx + 1))
384
- cur_label, start_idx = sl, ti
385
- elif bt == "I":
386
- if cur_label is None or cur_label != sl:
387
- if cur_label is not None and start_idx is not None and prev_idx is not None:
388
- spans.append((cur_label, start_idx, prev_idx + 1))
389
- cur_label, start_idx = sl, ti
390
- elif bt == "E":
391
- if cur_label is None or cur_label != sl or start_idx is None:
392
- if cur_label is not None and start_idx is not None and prev_idx is not None:
393
- spans.append((cur_label, start_idx, prev_idx + 1))
394
- spans.append((sl, ti, ti + 1)); cur_label = start_idx = None
395
- else:
396
- spans.append((cur_label, start_idx, ti + 1)); cur_label = start_idx = None
397
- else:
398
- if cur_label is not None and start_idx is not None and prev_idx is not None:
399
- spans.append((cur_label, start_idx, prev_idx + 1))
400
- cur_label = start_idx = None
401
- prev_idx = ti
402
- if cur_label is not None and start_idx is not None and prev_idx is not None:
403
- spans.append((cur_label, start_idx, prev_idx + 1))
404
- return spans
405
-
406
-
407
- def token_spans_to_char_spans(spans, cs, ce):
408
- out = []
409
- for li, ts, te in spans:
410
- if not (0 <= ts < te <= len(cs)):
411
- continue
412
- s, e = cs[ts], ce[te - 1]
413
- if e > s:
414
- out.append((li, s, e))
415
- return out
416
-
417
-
418
- def trim_char_spans_whitespace(spans, text):
419
- out = []
420
- for li, s, e in spans:
421
- if not (0 <= s < e <= len(text)):
422
- continue
423
- while s < e and text[s].isspace(): s += 1
424
- while e > s and text[e - 1].isspace(): e -= 1
425
- if e > s:
426
- out.append((li, s, e))
427
- return out
428
-
429
-
430
- # ── viterbi decoder ──────────────────────────────────────────────
431
-
432
  @functools.lru_cache(maxsize=1)
433
- def get_viterbi_transition_biases():
434
- cp = MODEL_DIR / "viterbi_calibration.json"
435
- default = {k: 0.0 for k in VITERBI_TRANSITION_BIAS_KEYS}
436
- if not cp.is_file():
437
- return default
438
- payload = json.loads(cp.read_text())
439
- raw = payload
440
- ops = payload.get("operating_points")
441
- if isinstance(ops, dict):
442
- preset = ops.get(DEFAULT_VITERBI_CALIBRATION_PRESET)
443
- if isinstance(preset, dict):
444
- raw = preset.get("biases", raw)
445
- if not isinstance(raw, dict):
446
- return default
447
- return {k: float(raw.get(k, 0.0)) for k in VITERBI_TRANSITION_BIAS_KEYS}
448
-
449
-
450
- class Decoder:
451
- def __init__(self, label_info):
452
- nc = len(label_info.token_to_span_label)
453
- self._start = torch.full((nc,), -1e9, dtype=torch.float32)
454
- self._end = torch.full((nc,), -1e9, dtype=torch.float32)
455
- self._trans = torch.full((nc, nc), -1e9, dtype=torch.float32)
456
- biases = get_viterbi_transition_biases()
457
- bg_tok, bg_sp = label_info.background_token_label, label_info.background_span_label
458
- ttsl, tbt = label_info.token_to_span_label, label_info.token_boundary_tags
459
- for i in range(nc):
460
- tag, sl = tbt.get(i), ttsl.get(i)
461
- if tag in {"B", "S"} or i == bg_tok: self._start[i] = 0.0
462
- if tag in {"E", "S"} or i == bg_tok: self._end[i] = 0.0
463
- for j in range(nc):
464
- nt, ns = tbt.get(j), ttsl.get(j)
465
- if self._valid(tag, sl, nt, ns, bg_tok, bg_sp, j):
466
- self._trans[i, j] = self._bias(tag, sl, nt, ns, bg_sp, biases)
467
-
468
- @staticmethod
469
- def _valid(pt, ps, nt, ns, bti, bsi, ni):
470
- nb = ns == bsi or ni == bti
471
- if (ns is None or nt is None) and not nb: return False
472
- if pt is None or ps is None: return nb or nt in {"B", "S"}
473
- if ps == bsi or pt in {"E", "S"}: return nb or nt in {"B", "S"}
474
- if pt in {"B", "I"}: return ps == ns and nt in {"I", "E"}
475
- return False
476
-
477
- @staticmethod
478
- def _bias(pt, ps, nt, ns, bsi, b):
479
- nb, pb = ns == bsi, ps == bsi
480
- if pb: return b["transition_bias_background_stay"] if nb else b["transition_bias_background_to_start"]
481
- if pt in {"B", "I"}: return b["transition_bias_inside_to_continue"] if nt == "I" else b["transition_bias_inside_to_end"]
482
- return b["transition_bias_end_to_background"] if nb else b["transition_bias_end_to_start"]
483
-
484
- def decode(self, lp):
485
- # Runs on lp's device. When lp is on CUDA, the loop streams tiny
486
- # kernels into the CUDA queue β€” on a warmed-up T4 this completes
487
- # in a few seconds. v5's move to CPU looked cheap on paper but
488
- # PyTorch CPU dispatch overhead made it far worse in practice.
489
- sl, nc = lp.shape
490
- if sl == 0: return []
491
- st = self._start.to(lp.device, lp.dtype)
492
- en = self._end.to(lp.device, lp.dtype)
493
- tr = self._trans.to(lp.device, lp.dtype)
494
- scores = lp[0] + st
495
- bp = torch.empty((sl - 1, nc), device=lp.device, dtype=torch.int64)
496
- for i in range(1, sl):
497
- t = scores.unsqueeze(1) + tr
498
- bs, bi = t.max(dim=0)
499
- scores = bs + lp[i]; bp[i - 1] = bi
500
- if not torch.isfinite(scores).any(): return lp.argmax(dim=1).tolist()
501
- scores = scores + en
502
- path = torch.empty(sl, device=lp.device, dtype=torch.int64)
503
- path[-1] = scores.argmax()
504
- for i in range(sl - 2, -1, -1): path[i] = bp[i, path[i + 1]]
505
- return path.tolist()
506
-
507
-
508
- # ── runtime singleton ────────────────────────────────────────────
509
-
510
- @dataclass(frozen=True)
511
- class InferenceRuntime:
512
- model: Transformer; encoding: tiktoken.Encoding; label_info: LabelInfo
513
- device: torch.device; n_ctx: int
514
-
515
-
516
- @functools.lru_cache(maxsize=1)
517
- def get_runtime():
518
- cp = MODEL_DIR
519
- cfg = json.loads((cp / "config.json").read_text())
520
- validate_model_config_contract(cfg, context=str(cp))
521
- device = torch.device("cuda")
522
- encoding = tiktoken.get_encoding(str(cfg["encoding"]).strip())
523
- scn = [BACKGROUND_CLASS_LABEL]; sll = {BACKGROUND_CLASS_LABEL: 0}
524
- bll, ttsl, tbt = {}, {}, {}
525
- bg_idx = None
526
- for idx, name in enumerate(NER_CLASS_NAMES):
527
- if name == BACKGROUND_CLASS_LABEL:
528
- bg_idx = idx; ttsl[idx] = 0; tbt[idx] = None; continue
529
- bnd, base = name.split("-", 1)
530
- si = sll.get(base)
531
- if si is None:
532
- si = len(scn); scn.append(base); sll[base] = si
533
- ttsl[idx] = si; tbt[idx] = bnd
534
- bll.setdefault(base, {})[bnd] = idx
535
- li = LabelInfo(bll, ttsl, tbt, tuple(scn), sll, bg_idx, 0)
536
- m = Transformer.from_checkpoint(str(cp), device=device)
537
- return InferenceRuntime(m, encoding, li, device, int(cfg["default_n_ctx"]))
538
 
539
 
540
- @functools.lru_cache(maxsize=1)
541
- def get_decoder():
542
- return Decoder(label_info=get_runtime().label_info)
543
-
544
-
545
- @torch.inference_mode()
546
- def predict_text(runtime, text, decoder):
547
- tids = tuple(int(t) for t in runtime.encoding.encode(text, allowed_special="all"))
548
- if not tids: return text, []
549
- chunks = []
550
- for s in range(0, len(tids), runtime.n_ctx):
551
- e = min(s + runtime.n_ctx, len(tids))
552
- wt = torch.tensor(tids[s:e], device=runtime.device, dtype=torch.int32)
553
- lp = F.log_softmax(runtime.model(wt).float(), dim=-1)
554
- chunks.append(lp)
555
- # Single-chunk case dodges a copy; multi-chunk falls through to cat.
556
- stacked = chunks[0] if len(chunks) == 1 else torch.cat(chunks, dim=0)
557
- dl = decoder.decode(stacked)
558
- if len(dl) != len(tids): dl = stacked.argmax(dim=1).tolist()
559
- pli = {i: int(l) for i, l in enumerate(dl)}
560
- pts = labels_to_spans(pli, runtime.label_info)
561
- tb = [runtime.encoding.decode_single_token_bytes(t) for t in tids]
562
- dt = b"".join(tb).decode("utf-8", errors="replace")
563
- cbs, cbe = [], []
564
- bc = 0
565
- for ch in dt: cbs.append(bc); bc += len(ch.encode("utf-8")); cbe.append(bc)
566
- cs, ce = [], []
567
- tbc = 0
568
- for rb in tb:
569
- tbs = tbc; tbe = tbs + len(rb); tbc = tbe
570
- cs.append(bisect_right(cbe, tbs)); ce.append(bisect_left(cbs, tbe))
571
- pcs = token_spans_to_char_spans(pts, cs, ce)
572
- pcs = trim_char_spans_whitespace(pcs, dt if dt != text else text)
573
- src = dt if dt != text else text
574
- detected = []
575
- for li, s, e in pcs:
576
- if 0 <= li < len(runtime.label_info.span_class_names):
577
- lbl = runtime.label_info.span_class_names[li]
578
- else:
579
- lbl = f"label_{li}"
580
- detected.append({"label": lbl, "start": s, "end": e, "text": src[s:e]})
581
- return src, detected
582
 
583
 
584
  # =====================================================================
@@ -637,10 +143,7 @@ def detect_speakers(text, spans):
637
  @spaces.GPU
638
  def run_pii_analysis(text: str):
639
  """GPU-accelerated PII detection."""
640
- runtime = get_runtime()
641
- decoder = get_decoder()
642
- source_text, detected = predict_text(runtime, text, decoder)
643
- return source_text, detected
644
 
645
 
646
  def build_redacted_pdf_bytes(pdf_path: str, pii_texts: list[str]) -> bytes:
@@ -691,6 +194,12 @@ def build_redacted_pdf_bytes(pdf_path: str, pii_texts: list[str]) -> bytes:
691
 
692
 
693
  # ── Gradio Server ────────────────────────────────────────────────
 
 
 
 
 
 
694
  server = gr.Server()
695
 
696
 
@@ -699,82 +208,96 @@ async def homepage():
699
  return FRONTEND_HTML
700
 
701
 
702
- @server.post("/api/analyze")
703
- async def analyze_document(file: UploadFile = File(...)):
704
- suffix = Path(file.filename).suffix.lower()
 
 
 
 
 
 
 
 
 
 
 
705
  if suffix not in (".pdf", ".doc", ".docx"):
706
- return JSONResponse({"error": f"Unsupported: {suffix}. Use PDF, DOC, or DOCX."}, 400)
707
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
708
- tmp.write(await file.read()); tmp_path = tmp.name
709
  try:
710
- text = extract_text(tmp_path)
711
  if not text.strip():
712
- return JSONResponse({"error": "No text content found."}, 400)
713
  source_text, spans = run_pii_analysis(text)
714
  stats = compute_stats(source_text, spans)
715
  speakers = detect_speakers(source_text, spans)
716
- return JSONResponse({
717
- "filename": file.filename, "text": source_text, "spans": spans,
718
- "stats": stats, "speakers": speakers,
719
- "categories_meta": {k: {"color": v["color"], "cls": v["cls"],
720
- "label": v["label"], "mono": v["mono"]}
721
- for k, v in CATEGORIES_META.items()},
722
- })
 
 
 
 
 
723
  except Exception as e:
724
- return JSONResponse({"error": str(e)}, 500)
725
- finally:
726
- if os.path.exists(tmp_path): os.unlink(tmp_path)
727
 
 
 
 
 
 
 
728
 
729
- @server.post("/api/redact-pdf")
730
- async def redact_pdf_endpoint(
731
- file: UploadFile = File(...),
732
- spans: str = Form(...),
733
- active: str = Form(...),
734
- ):
735
- suffix = Path(file.filename).suffix.lower()
736
  if suffix != ".pdf":
737
- return JSONResponse({"error": "PDF redaction only accepts PDF input."}, 400)
738
  try:
739
  span_list = json.loads(spans)
740
  active_set = set(json.loads(active))
741
  except Exception as e:
742
- return JSONResponse({"error": f"Invalid payload: {e}"}, 400)
743
 
744
  pii_texts = [
745
  s.get("text", "") for s in span_list
746
  if s.get("label") in active_set
747
  ]
748
  if not pii_texts:
749
- return JSONResponse({"error": "No active categories selected β€” nothing to redact."}, 400)
750
 
751
- with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
752
- tmp.write(await file.read()); tmp_path = tmp.name
753
  try:
754
  t0 = time.perf_counter()
755
- pdf_bytes = build_redacted_pdf_bytes(tmp_path, pii_texts)
756
- elapsed = time.perf_counter() - t0
757
- out_name = (Path(file.filename).stem or "document") + ".redacted.pdf"
758
- return StreamingResponse(
759
- io.BytesIO(pdf_bytes),
760
- media_type="application/pdf",
761
- headers={
762
- "Content-Disposition": f'attachment; filename="{out_name}"',
763
- "X-Redaction-Ms": str(int(elapsed * 1000)),
764
- },
765
- )
766
  except Exception as e:
767
- return JSONResponse({"error": str(e)}, 500)
768
- finally:
769
- if os.path.exists(tmp_path): os.unlink(tmp_path)
 
 
 
 
 
 
 
770
 
771
 
772
  @server.api(name="analyze_text")
773
- def analyze_text_api(text: str) -> str:
774
- """Gradio API: analyze raw text for PII."""
 
775
  source_text, spans = run_pii_analysis(text)
776
  stats = compute_stats(source_text, spans)
777
- return json.dumps({"text": source_text, "spans": spans, "stats": stats}, ensure_ascii=False)
778
 
779
 
780
  # ── Frontend HTML (v6) ───────────────────────────────────────────
@@ -783,7 +306,7 @@ FRONTEND_HTML = r"""<!DOCTYPE html>
783
  <head>
784
  <meta charset="UTF-8">
785
  <meta name="viewport" content="width=device-width,initial-scale=1">
786
- <title>PII Explorer β€” Playground</title>
787
  <link rel="preconnect" href="https://fonts.googleapis.com">
788
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
789
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=JetBrains+Mono:wght@400;500&family=Source+Serif+4:opsz,wght@8..60,400;8..60,500;8..60,600&display=swap" rel="stylesheet">
@@ -1046,7 +569,7 @@ button{font:inherit;color:inherit;background:transparent;border:0;cursor:pointer
1046
  <circle cx="8.5" cy="8.5" r="3.2" stroke="var(--block-background-fill)" stroke-width="1.4" fill="none"/>
1047
  <line x1="11.2" y1="11.2" x2="14.2" y2="14.2" stroke="var(--block-background-fill)" stroke-width="1.4" stroke-linecap="round"/>
1048
  </svg>
1049
- <span class="u-brand-name">PII Explorer<span class="sub">/ Playground</span></span>
1050
  </div>
1051
  <h1 class="u-title">See what your documents are leaking.</h1>
1052
  <p class="u-sub">Find every PII span in a PDF, DOC or DOCX β€” names, accounts, secrets and five other entity types β€” then export a fully redacted copy.</p>
@@ -1074,10 +597,10 @@ button{font:inherit;color:inherit;background:transparent;border:0;cursor:pointer
1074
  </div>
1075
 
1076
  <div class="u-meta">
1077
- <span><b>OpenAI Privacy Filter</b></span>
1078
  <span>128k ctx</span>
 
1079
  <span>apache 2.0</span>
1080
- <span><b>gr.Server</b></span>
1081
  </div>
1082
  </div>
1083
 
@@ -1110,7 +633,7 @@ button{font:inherit;color:inherit;background:transparent;border:0;cursor:pointer
1110
  <!-- ============ results view ============ -->
1111
  <div id="results-view">
1112
  <div class="shell">
1113
- <div class="pr-app" aria-label="PII Explorer Playground">
1114
 
1115
  <div class="pr-top">
1116
  <div class="pr-logo">
@@ -1119,7 +642,7 @@ button{font:inherit;color:inherit;background:transparent;border:0;cursor:pointer
1119
  <circle cx="8.5" cy="8.5" r="3.2" stroke="var(--block-background-fill)" stroke-width="1.4" fill="none"/>
1120
  <line x1="11.2" y1="11.2" x2="14.2" y2="14.2" stroke="var(--block-background-fill)" stroke-width="1.4" stroke-linecap="round"/>
1121
  </svg>
1122
- <span class="pr-name">PII Explorer<span class="pr-name-sub">/ Playground</span></span>
1123
  </div>
1124
  <span class="pr-file-chip" id="file-chip"></span>
1125
  <span class="pr-status" id="scan-status"><span class="pr-status-dot"></span>Scan complete</span>
@@ -1191,7 +714,19 @@ button{font:inherit;color:inherit;background:transparent;border:0;cursor:pointer
1191
 
1192
  <div class="tip" id="tip" style="display:none"></div>
1193
 
1194
- <script>
 
 
 
 
 
 
 
 
 
 
 
 
1195
  const S = {
1196
  text:'', spans:[], stats:{}, speakers:{}, catMeta:{}, filename:'', file:null,
1197
  activeCats:new Set(), scanMs:0, sortedSpans:[],
@@ -1244,20 +779,22 @@ async function uploadFile(file){
1244
  S.file = file;
1245
  showLoading('scanning document…');
1246
  document.getElementById('upload-view').style.display='none';
1247
- const form = new FormData(); form.append('file', file);
1248
  const t0 = performance.now();
1249
  try{
1250
- const r = await fetch('/api/analyze', {method:'POST', body:form});
1251
- const d = await r.json();
 
 
 
1252
  if (d.error) { showError(d.error); return; }
1253
  S.scanMs = performance.now() - t0;
1254
  S.text = d.text; S.spans = d.spans; S.stats = d.stats;
1255
  S.speakers = d.speakers||{}; S.catMeta = d.categories_meta||{};
1256
- S.filename = d.filename;
1257
  S.activeCats = new Set(Object.keys(d.stats.categories));
1258
  S.sortedSpans = [...S.spans].sort((a,b) => a.start - b.start);
1259
  renderResults();
1260
- } catch(e){ showError('Analysis failed: '+e.message); }
1261
  finally { hideLoading(); }
1262
  }
1263
 
@@ -1511,20 +1048,18 @@ document.getElementById('act-pdf').addEventListener('click', async () => {
1511
  btn.disabled = true;
1512
  showLoading('redacting PDF…');
1513
  try {
1514
- const form = new FormData();
1515
- form.append('file', S.file);
1516
- form.append('spans', JSON.stringify(S.spans));
1517
- form.append('active', JSON.stringify([...S.activeCats]));
1518
- const r = await fetch('/api/redact-pdf', { method:'POST', body: form });
1519
- if (!r.ok) {
1520
- let err = `Redaction failed (${r.status})`;
1521
- try { const j = await r.json(); err = j.error || err; } catch {}
1522
- throw new Error(err);
1523
- }
1524
- const elapsedHeader = r.headers.get('X-Redaction-Ms');
1525
- const blob = await r.blob();
1526
  download(baseName() + '.redacted.pdf', blob, 'application/pdf');
1527
- if (elapsedHeader) flash('act-pdf', `Downloaded (${(elapsedHeader/1000).toFixed(1)}s)`);
1528
  else flash('act-pdf', 'Downloaded');
1529
  } catch (e) {
1530
  alert(e.message || 'Redaction failed');
 
1
  """
2
+ =======================================
3
+ PII Reveal - Document Privacy Explorer
4
+ =======================================
5
+
6
+ Uploads a PDF/DOC/DOCX, runs the openai/privacy-filter model over the
7
+ extracted text, and returns per-span character offsets + stats for an
8
+ interactive reader view. Also supports building a black-bar redacted PDF.
9
+
10
+ Inference path: `transformers.pipeline("token-classification",
11
+ "openai/privacy-filter", aggregation_strategy="simple")` β€” the pipeline
12
+ takes care of BIOES β†’ char-level span aggregation for us.
13
+
14
+ PDF redaction (build_redacted_pdf_bytes) is optimized for large files:
15
+ per-page `needle in page_text` prefilter before page.search_for, skip
16
+ apply_redactions on pages with no matches, and save with garbage=1 to
17
+ avoid the expensive stream-recompression pass.
18
  """
19
 
20
  # ── stdlib ───────────────────────────────────────────────────────
 
21
  import functools
22
  import io
23
  import json
 
24
  import os
25
  import re
26
  import tempfile
27
  import time
 
 
 
28
  from pathlib import Path
 
29
 
30
  # ── third-party ──────────────────────────────────────────────────
31
  import gradio as gr
32
  import spaces
 
33
  import torch
34
+ from fastapi.responses import HTMLResponse
35
+ from gradio.data_classes import FileData
 
 
 
36
 
37
  # ── configuration ────────────────────────────────────────────────
38
+ PII_MODEL_REPO = os.getenv("MODEL_ID", "openai/privacy-filter")
39
  HF_TOKEN = os.getenv("HF_TOKEN", None)
 
40
 
41
  CATEGORIES_META = {
42
  "private_person": {"color": "#E24B4A", "cls": "hp", "label": "Person", "mono": False},
 
50
  }
51
 
52
  # =====================================================================
53
+ # MODEL INFERENCE (transformers pipeline β€” openai/privacy-filter)
54
  # =====================================================================
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  @functools.lru_cache(maxsize=1)
57
+ def get_pii_pipeline():
58
+ """Lazy-load the privacy filter on the GPU. Cached so repeated calls
59
+ inside a single ZeroGPU slot don't re-move weights."""
60
+ from transformers import pipeline
61
+ return pipeline(
62
+ task="token-classification",
63
+ model=PII_MODEL_REPO,
64
+ aggregation_strategy="simple", # merges BIOES tags into char-level spans
65
+ device=0,
66
+ torch_dtype=torch.bfloat16,
67
+ token=HF_TOKEN,
68
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
 
71
+ def predict_text(text: str) -> tuple[str, list[dict]]:
72
+ """Returns (source_text, spans). `spans` is a list of
73
+ {label, start, end, text} with character offsets into `text`."""
74
+ if not text.strip():
75
+ return text, []
76
+ pipe = get_pii_pipeline()
77
+ results = pipe(text)
78
+ spans = []
79
+ for r in results:
80
+ label = r.get("entity_group") or r.get("entity")
81
+ if not label or label == "O":
82
+ continue
83
+ s, e = int(r["start"]), int(r["end"])
84
+ if e <= s or s < 0 or e > len(text):
85
+ continue
86
+ spans.append({"label": label, "start": s, "end": e, "text": text[s:e]})
87
+ return text, spans
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
 
90
  # =====================================================================
 
143
  @spaces.GPU
144
  def run_pii_analysis(text: str):
145
  """GPU-accelerated PII detection."""
146
+ return predict_text(text)
 
 
 
147
 
148
 
149
  def build_redacted_pdf_bytes(pdf_path: str, pii_texts: list[str]) -> bytes:
 
194
 
195
 
196
  # ── Gradio Server ────────────────────────────────────────────────
197
+ #
198
+ # We only keep one plain FastAPI route here β€” the homepage, which
199
+ # serves the static HTML shell. The heavy lifting endpoints are
200
+ # declared with @server.api, which wraps them in Gradio's queue so
201
+ # they compose correctly with @spaces.GPU on ZeroGPU and with the
202
+ # gradio_client / @gradio/client SDKs.
203
  server = gr.Server()
204
 
205
 
 
208
  return FRONTEND_HTML
209
 
210
 
211
+ @server.api(name="analyze_document")
212
+ def analyze_document_api(file: FileData) -> dict:
213
+ """Extract text from an uploaded PDF/DOC/DOCX and run the OPF
214
+ privacy filter over it. Returns the detected spans, stats,
215
+ per-speaker counts, and the category color/label table.
216
+
217
+ Called from the browser via @gradio/client:
218
+ client.predict("/analyze_document", { file: handle_file(f) })
219
+ And from Python via gradio_client:
220
+ client.predict("/analyze_document", file=handle_file(path))
221
+ """
222
+ path = file.get("path") or ""
223
+ suffix = Path(path).suffix.lower()
224
+ orig_name = file.get("orig_name") or Path(path).name
225
  if suffix not in (".pdf", ".doc", ".docx"):
226
+ return {"error": f"Unsupported: {suffix}. Use PDF, DOC, or DOCX."}
227
+
 
228
  try:
229
+ text = extract_text(path)
230
  if not text.strip():
231
+ return {"error": "No text content found."}
232
  source_text, spans = run_pii_analysis(text)
233
  stats = compute_stats(source_text, spans)
234
  speakers = detect_speakers(source_text, spans)
235
+ return {
236
+ "filename": orig_name,
237
+ "text": source_text,
238
+ "spans": spans,
239
+ "stats": stats,
240
+ "speakers": speakers,
241
+ "categories_meta": {
242
+ k: {"color": v["color"], "cls": v["cls"],
243
+ "label": v["label"], "mono": v["mono"]}
244
+ for k, v in CATEGORIES_META.items()
245
+ },
246
+ }
247
  except Exception as e:
248
+ return {"error": str(e)}
249
+
 
250
 
251
+ @server.api(name="redact_pdf")
252
+ def redact_pdf_api(file: FileData, spans: str, active: str) -> dict:
253
+ """Build a black-bar-redacted PDF from an uploaded PDF plus the
254
+ list of spans the browser wants redacted. `spans` and `active`
255
+ are JSON strings because the JS client serializes complex objects
256
+ more predictably as strings than as nested dicts.
257
 
258
+ Returns {"pdf": FileData, "elapsed_ms": int} so the caller can
259
+ download the file and also display timing."""
260
+ path = file.get("path") or ""
261
+ suffix = Path(path).suffix.lower()
 
 
 
262
  if suffix != ".pdf":
263
+ return {"error": "PDF redaction only accepts PDF input."}
264
  try:
265
  span_list = json.loads(spans)
266
  active_set = set(json.loads(active))
267
  except Exception as e:
268
+ return {"error": f"Invalid payload: {e}"}
269
 
270
  pii_texts = [
271
  s.get("text", "") for s in span_list
272
  if s.get("label") in active_set
273
  ]
274
  if not pii_texts:
275
+ return {"error": "No active categories selected β€” nothing to redact."}
276
 
 
 
277
  try:
278
  t0 = time.perf_counter()
279
+ pdf_bytes = build_redacted_pdf_bytes(path, pii_texts)
280
+ elapsed_ms = int((time.perf_counter() - t0) * 1000)
 
 
 
 
 
 
 
 
 
281
  except Exception as e:
282
+ return {"error": str(e)}
283
+
284
+ orig_name = file.get("orig_name") or Path(path).name
285
+ stem = Path(orig_name).stem or "document"
286
+ out_path = Path(tempfile.gettempdir()) / f"{stem}.redacted.pdf"
287
+ out_path.write_bytes(pdf_bytes)
288
+ return {
289
+ "pdf": FileData(path=str(out_path)),
290
+ "elapsed_ms": elapsed_ms,
291
+ }
292
 
293
 
294
  @server.api(name="analyze_text")
295
+ def analyze_text_api(text: str) -> dict:
296
+ """Analyze raw text for PII β€” convenient for gradio_client users
297
+ who don't want to build a PDF just to test the model."""
298
  source_text, spans = run_pii_analysis(text)
299
  stats = compute_stats(source_text, spans)
300
+ return {"text": source_text, "spans": spans, "stats": stats}
301
 
302
 
303
  # ── Frontend HTML (v6) ───────────────────────────────────────────
 
306
  <head>
307
  <meta charset="UTF-8">
308
  <meta name="viewport" content="width=device-width,initial-scale=1">
309
+ <title>PII Reveal β€” Inspector</title>
310
  <link rel="preconnect" href="https://fonts.googleapis.com">
311
  <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
312
  <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=JetBrains+Mono:wght@400;500&family=Source+Serif+4:opsz,wght@8..60,400;8..60,500;8..60,600&display=swap" rel="stylesheet">
 
569
  <circle cx="8.5" cy="8.5" r="3.2" stroke="var(--block-background-fill)" stroke-width="1.4" fill="none"/>
570
  <line x1="11.2" y1="11.2" x2="14.2" y2="14.2" stroke="var(--block-background-fill)" stroke-width="1.4" stroke-linecap="round"/>
571
  </svg>
572
+ <span class="u-brand-name">PII Reveal<span class="sub">/ inspector</span></span>
573
  </div>
574
  <h1 class="u-title">See what your documents are leaking.</h1>
575
  <p class="u-sub">Find every PII span in a PDF, DOC or DOCX β€” names, accounts, secrets and five other entity types β€” then export a fully redacted copy.</p>
 
597
  </div>
598
 
599
  <div class="u-meta">
600
+ <span>openai privacy filter</span>
601
  <span>128k ctx</span>
602
+ <span>bfloat16</span>
603
  <span>apache 2.0</span>
 
604
  </div>
605
  </div>
606
 
 
633
  <!-- ============ results view ============ -->
634
  <div id="results-view">
635
  <div class="shell">
636
+ <div class="pr-app" aria-label="PII Reveal inspector">
637
 
638
  <div class="pr-top">
639
  <div class="pr-logo">
 
642
  <circle cx="8.5" cy="8.5" r="3.2" stroke="var(--block-background-fill)" stroke-width="1.4" fill="none"/>
643
  <line x1="11.2" y1="11.2" x2="14.2" y2="14.2" stroke="var(--block-background-fill)" stroke-width="1.4" stroke-linecap="round"/>
644
  </svg>
645
+ <span class="pr-name">PII Reveal<span class="pr-name-sub">/ inspector</span></span>
646
  </div>
647
  <span class="pr-file-chip" id="file-chip"></span>
648
  <span class="pr-status" id="scan-status"><span class="pr-status-dot"></span>Scan complete</span>
 
714
 
715
  <div class="tip" id="tip" style="display:none"></div>
716
 
717
+ <script type="module">
718
+ // ══════════════════════════════════════════════════════════════════
719
+ // Gradio JS client β€” /api/analyze and /api/redact-pdf were plain
720
+ // FastAPI routes in the old version, which meant requests bypassed
721
+ // Gradio's queue entirely. Now the backend exposes @server.api
722
+ // routes and we call them through the Client, which gives us queue
723
+ // serialization, progress events, and correct ZeroGPU allocation
724
+ // via @spaces.GPU.
725
+ // ══════════════════════════════════════════════════════════════════
726
+ import { Client, handle_file } from "https://cdn.jsdelivr.net/npm/@gradio/client/dist/index.min.js";
727
+
728
+ const clientPromise = Client.connect(window.location.origin);
729
+
730
  const S = {
731
  text:'', spans:[], stats:{}, speakers:{}, catMeta:{}, filename:'', file:null,
732
  activeCats:new Set(), scanMs:0, sortedSpans:[],
 
779
  S.file = file;
780
  showLoading('scanning document…');
781
  document.getElementById('upload-view').style.display='none';
 
782
  const t0 = performance.now();
783
  try{
784
+ const client = await clientPromise;
785
+ const result = await client.predict("/analyze_document", {
786
+ file: handle_file(file),
787
+ });
788
+ const d = result.data[0] || {};
789
  if (d.error) { showError(d.error); return; }
790
  S.scanMs = performance.now() - t0;
791
  S.text = d.text; S.spans = d.spans; S.stats = d.stats;
792
  S.speakers = d.speakers||{}; S.catMeta = d.categories_meta||{};
793
+ S.filename = d.filename || file.name;
794
  S.activeCats = new Set(Object.keys(d.stats.categories));
795
  S.sortedSpans = [...S.spans].sort((a,b) => a.start - b.start);
796
  renderResults();
797
+ } catch(e){ showError('Analysis failed: '+(e && e.message ? e.message : e)); }
798
  finally { hideLoading(); }
799
  }
800
 
 
1048
  btn.disabled = true;
1049
  showLoading('redacting PDF…');
1050
  try {
1051
+ const client = await clientPromise;
1052
+ const result = await client.predict("/redact_pdf", {
1053
+ file: handle_file(S.file),
1054
+ spans: JSON.stringify(S.spans),
1055
+ active: JSON.stringify([...S.activeCats]),
1056
+ });
1057
+ const d = result.data[0] || {};
1058
+ if (d.error) throw new Error(d.error);
1059
+ if (!d.pdf || !d.pdf.url) throw new Error('No PDF returned.');
1060
+ const blob = await (await fetch(d.pdf.url)).blob();
 
 
1061
  download(baseName() + '.redacted.pdf', blob, 'application/pdf');
1062
+ if (typeof d.elapsed_ms === 'number') flash('act-pdf', `Downloaded (${(d.elapsed_ms/1000).toFixed(1)}s)`);
1063
  else flash('act-pdf', 'Downloaded');
1064
  } catch (e) {
1065
  alert(e.message || 'Redaction failed');