Mihai Maruseac committed on
Commit 1d92498 · unverified · 1 parent: 0bfe81b

ZeroGPU Gradio demo for OpenAI Privacy Filter


Updates the README to be informative and adds all the necessary files to showcase the ZeroGPU Gradio demo of the OpenAI Privacy Filter

Signed-off-by: Mihai Maruseac <mihaimaruseac@openai.com>

Files changed (3)
  1. README.md +29 -4
  2. app.py +1319 -0
  3. requirements.txt +5 -0
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
- title: Privacy Filter
- emoji: 📉
  colorFrom: gray
  colorTo: gray
  sdk: gradio
@@ -9,7 +9,32 @@ python_version: '3.12'
  app_file: app.py
  pinned: false
  license: apache-2.0
- short_description: Privacy Filter
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: OpenAI Privacy Filter
+ emoji: 🛡️
  colorFrom: gray
  colorTo: gray
  sdk: gradio

  app_file: app.py
  pinned: false
  license: apache-2.0
+ short_description: OpenAI Privacy Filter ZeroGPU demo
  ---

+ # OpenAI Privacy Filter
+
+ OpenAI Privacy Filter is a bidirectional token-classification model for detecting and masking personally identifiable information (PII) in text. It is intended for high-throughput data-sanitization workflows where teams need a fast, context-aware, and tunable model that they can run on-premises.
+
+ OpenAI Privacy Filter is pretrained autoregressively to arrive at a checkpoint with an architecture similar to gpt-oss, albeit at a smaller size. We then converted that checkpoint into a bidirectional token classifier over a privacy label taxonomy and post-trained it with a supervised classification loss. (For architecture details about gpt-oss, please see the gpt-oss model card.) Instead of generating text token by token, this model labels an input sequence in a single forward pass, then decodes coherent spans with a constrained Viterbi procedure. For each input token, the model predicts a probability distribution over the label taxonomy, which consists of the 8 output categories described below.
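The constrained decoding step can be illustrated with a toy sketch. This is not the model's actual implementation (which uses 33 labels and learned transition biases); it only shows why a Viterbi pass over a transition-validity mask yields coherent spans where a greedy per-token argmax can't. Here "B"egin must be followed by "E"nd, and "E" cannot follow "O":

```python
import math

# Illustrative toy tag set, an assumption for this sketch only.
TAGS = ["O", "B", "E"]
# Transitions that keep spans well-formed: no "B" without a following "E",
# no "E" that does not close an open "B".
ALLOWED = {("O", "O"), ("O", "B"), ("B", "E"), ("E", "O"), ("E", "B")}

def viterbi(probs: list[dict[str, float]]) -> list[str]:
    """Return the highest-probability tag path using only ALLOWED transitions.

    probs: one dict of tag -> probability (nonzero) per token.
    """
    # scores[i][t] = (best log-score of a path ending at token i with tag t,
    #                 backpointer to the previous tag on that path)
    scores = [{t: (math.log(probs[0][t]), None) for t in TAGS}]
    for step in probs[1:]:
        prev = scores[-1]
        cur = {}
        for t in TAGS:
            cands = [
                (prev[p][0] + math.log(step[t]), p)
                for p in TAGS
                if (p, t) in ALLOWED and prev[p][0] > -math.inf
            ]
            cur[t] = max(cands) if cands else (-math.inf, None)
        scores.append(cur)
    # Backtrack from the best final tag.
    tag = max(TAGS, key=lambda t: scores[-1][t][0])
    path = [tag]
    for layer in reversed(scores[1:]):
        tag = layer[tag][1]
        path.append(tag)
    return list(reversed(path))
```

For per-token probabilities where the greedy choice would be the invalid sequence O, B, O, O, the Viterbi path comes out as the coherent O, B, E, O instead.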
+
+ Highlights:
+
+ - Permissive Apache 2.0 license: ideal for experimentation, customization, and commercial deployment.
+ - Small size: runs in a web browser or on a laptop – 1.5B total parameters and 50M active parameters.
+ - Fine-tunable: adapt the model to specific data distributions through easy, data-efficient fine-tuning.
+ - Long context: a 128,000-token context window enables processing long text with high throughput and no chunking.
+ - Runtime control: configure precision/recall tradeoffs and detected span lengths through preset operating points.
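The precision/recall lever can be pictured with a minimal sketch. The mechanism below is an illustrative assumption, not the model's actual preset API: it simply biases the background class "O" before the per-token argmax, so a negative bias flags more tokens (recall-leaning) and a positive one flags fewer (precision-leaning):

```python
def pick_labels(token_logprobs: list[dict[str, float]],
                background_bias: float = 0.0) -> list[str]:
    """token_logprobs: one dict of label -> log-probability per token.

    A negative background_bias makes "O" less attractive, so borderline
    tokens get flagged; a positive bias does the opposite.
    """
    picked = []
    for lp in token_logprobs:
        adjusted = dict(lp)
        adjusted["O"] += background_bias  # hypothetical knob, not a real preset name
        picked.append(max(adjusted, key=adjusted.__getitem__))
    return picked
```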
+
+ ## Metadata
+
+ - Developed by: OpenAI
+ - Funded by: OpenAI
+ - Shared by: OpenAI
+ - Model type: bidirectional token classification model for privacy span detection
+ - Language(s): primarily English; selected multilingual robustness evaluations reported
+ - License: [Apache 2.0](LICENSE)
+
+ - Source repository: https://github.com/openai/privacy-filter
+ - Model weights: https://huggingface.co/openai/privacy-filter
+ - Model card: [OpenAI Privacy Filter Model Card](https://cdn.openai.com/pdf/c66281ed-b638-456a-8ce1-97e9f5264a90/OpenAI-Privacy-Filter-Model-Card.pdf)
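For readers counting labels: the 8 span categories, crossed with the B/I/E/S boundary prefixes and joined by a background class, give the 33-way token label space that app.py validates (`num_labels=33`). A quick sketch of that arithmetic, using the class names from the code below:

```python
# Span categories and boundary prefixes as defined in app.py.
SPAN_CLASSES = [
    "account_number", "private_address", "private_date", "private_email",
    "private_person", "private_phone", "private_url", "secret",
]
PREFIXES = ["B", "I", "E", "S"]  # begin / inside / end / single-token

# Background label "O" plus one tag per (category, prefix) pair.
labels = ["O"] + [f"{p}-{c}" for c in SPAN_CLASSES for p in PREFIXES]
assert len(labels) == 1 + len(PREFIXES) * len(SPAN_CLASSES) == 33
```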
app.py ADDED
@@ -0,0 +1,1319 @@
1
+ import dataclasses
2
+ import functools
3
+ import json
4
+ import math
5
+ import os
6
+
7
+ from bisect import bisect_left, bisect_right
8
+ from collections.abc import Sequence
9
+ from dataclasses import dataclass
10
+ from pathlib import Path
11
+ from typing import Final
12
+
13
+ import gradio as gr
14
+ import spaces
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from safetensors import safe_open
18
+
19
+ import tiktoken
20
+
21
+ from huggingface_hub import snapshot_download
22
+
23
+ MODEL_ROOT = snapshot_download("openai/privacy-filter", allow_patterns=["original/*"])
24
+ MODEL_DIR = Path(MODEL_ROOT) / "original"
25
+
26
+ PRIVACY_FILTER_MODEL_TYPE: Final[str] = "privacy_filter"
27
+ REQUIRED_MODEL_CONFIG_KEYS: Final[tuple[str, ...]] = (
28
+ "model_type",
29
+ "encoding",
30
+ "num_hidden_layers",
31
+ "num_experts",
32
+ "experts_per_token",
33
+ "vocab_size",
34
+ "num_labels",
35
+ "hidden_size",
36
+ "intermediate_size",
37
+ "head_dim",
38
+ "num_attention_heads",
39
+ "num_key_value_heads",
40
+ "sliding_window",
41
+ "bidirectional_context",
42
+ "bidirectional_left_context",
43
+ "bidirectional_right_context",
44
+ "default_n_ctx",
45
+ "initial_context_length",
46
+ "rope_theta",
47
+ "rope_scaling_factor",
48
+ "rope_ntk_alpha",
49
+ "rope_ntk_beta",
50
+ "param_dtype",
51
+ )
52
+ BACKGROUND_CLASS_LABEL: Final[str] = "O"
53
+ BOUNDARY_PREFIXES: Final[tuple[str, ...]] = ("B", "I", "E", "S")
54
+ EMPTY_HIGHLIGHT_PAYLOAD = {"text": "", "entities": []}
55
+ SPAN_CLASS_NAMES: Final[tuple[str, ...]] = (
56
+ BACKGROUND_CLASS_LABEL,
57
+ "account_number",
58
+ "private_address",
59
+ "private_date",
60
+ "private_email",
61
+ "private_person",
62
+ "private_phone",
63
+ "private_url",
64
+ "secret",
65
+ )
66
+ NER_CLASS_NAMES: Final[tuple[str, ...]] = (BACKGROUND_CLASS_LABEL,) + tuple(
67
+ f"{prefix}-{base_label}"
68
+ for base_label in SPAN_CLASS_NAMES
69
+ if base_label != BACKGROUND_CLASS_LABEL
70
+ for prefix in BOUNDARY_PREFIXES
71
+ )
72
+ VITERBI_TRANSITION_BIAS_KEYS: Final[tuple[str, ...]] = (
73
+ "transition_bias_background_stay",
74
+ "transition_bias_background_to_start",
75
+ "transition_bias_inside_to_continue",
76
+ "transition_bias_inside_to_end",
77
+ "transition_bias_end_to_background",
78
+ "transition_bias_end_to_start",
79
+ )
80
+ DEFAULT_VITERBI_CALIBRATION_PRESET: Final[str] = "default"
81
+
82
+
83
+ def validate_model_config_contract(
84
+ checkpoint_config: dict[str, object],
85
+ *,
86
+ context: str,
87
+ ) -> None:
88
+ missing = [key for key in REQUIRED_MODEL_CONFIG_KEYS if key not in checkpoint_config]
89
+ if missing:
90
+ raise ValueError(f"{context} is missing required model config keys: {', '.join(missing)}")
91
+ model_type = checkpoint_config.get("model_type")
92
+ if model_type != PRIVACY_FILTER_MODEL_TYPE:
93
+ raise ValueError(
94
+ f"{context} model_type must be {PRIVACY_FILTER_MODEL_TYPE!r}, got {model_type!r}"
95
+ )
96
+ if checkpoint_config.get("bidirectional_context") is not True:
97
+ raise ValueError(f"{context} must use bidirectional_context=true")
98
+
99
+ raw_left_context = checkpoint_config.get("bidirectional_left_context")
100
+ raw_right_context = checkpoint_config.get("bidirectional_right_context")
101
+ if (
102
+ not isinstance(raw_left_context, int)
103
+ or isinstance(raw_left_context, bool)
104
+ or not isinstance(raw_right_context, int)
105
+ or isinstance(raw_right_context, bool)
106
+ ):
107
+ raise ValueError(
108
+ f"{context} bidirectional context sizes must be integers "
109
+ f"(got {raw_left_context!r}/{raw_right_context!r})"
110
+ )
111
+ left_context = raw_left_context
112
+ right_context = raw_right_context
113
+ if left_context < 0 or right_context < 0:
114
+ raise ValueError(
115
+ f"{context} bidirectional context sizes must be >= 0 "
116
+ f"(got {left_context}/{right_context})"
117
+ )
118
+ if left_context != right_context:
119
+ raise ValueError(
120
+ f"{context} bidirectional context must be symmetric "
121
+ f"(got left={left_context}, right={right_context})"
122
+ )
123
+
124
+ raw_sliding_window = checkpoint_config.get("sliding_window")
125
+ if not isinstance(raw_sliding_window, int) or isinstance(raw_sliding_window, bool):
126
+ raise ValueError(f"{context} sliding_window must be an integer, got {raw_sliding_window!r}")
127
+ sliding_window = raw_sliding_window
128
+ expected_sliding_window = 2 * left_context + 1
129
+ if sliding_window != expected_sliding_window:
130
+ raise ValueError(
131
+ f"{context} sliding_window must equal 2 * bidirectional context + 1 "
132
+ f"(got {sliding_window}, expected {expected_sliding_window})"
133
+ )
134
+
135
+ num_labels_raw = checkpoint_config["num_labels"]
136
+ if not isinstance(num_labels_raw, int) or isinstance(num_labels_raw, bool):
137
+ raise ValueError(f"{context} num_labels must be an integer, got {num_labels_raw!r}")
138
+ num_labels = num_labels_raw
139
+ if num_labels != 33:
140
+ raise ValueError(
141
+ f"{context} must use num_labels=33 for the label space, got {num_labels}"
142
+ )
143
+
144
+ raw_encoding = checkpoint_config["encoding"]
145
+ if not isinstance(raw_encoding, str) or not raw_encoding.strip():
146
+ raise ValueError(f"{context} encoding must be a non-empty string")
147
+
148
+ raw_n_ctx = checkpoint_config["default_n_ctx"]
149
+ if not isinstance(raw_n_ctx, int) or isinstance(raw_n_ctx, bool):
150
+ raise ValueError(f"{context} default_n_ctx must be a positive integer, got {raw_n_ctx!r}")
151
+ n_ctx = raw_n_ctx
152
+ if n_ctx <= 0:
153
+ raise ValueError(f"{context} default_n_ctx must be positive, got {n_ctx}")
154
+
155
+ raw_param_dtype = checkpoint_config["param_dtype"]
156
+ if raw_param_dtype != "bfloat16":
157
+ raise ValueError(f"{context} param_dtype must be bfloat16, got {raw_param_dtype!r}")
158
+
159
+
160
+ def expert_linear(
161
+ x: torch.Tensor,
162
+ weight: torch.Tensor,
163
+ bias: torch.Tensor | None,
164
+ ) -> torch.Tensor:
165
+ num_rows, experts, k_dim = x.shape
166
+ _, _, _, out_dim = weight.shape
167
+ x_bmm = x.reshape(num_rows * experts, 1, k_dim)
168
+ w_bmm = weight.reshape(num_rows * experts, k_dim, out_dim)
169
+ out = torch.bmm(x_bmm, w_bmm).reshape(num_rows, experts, out_dim)
170
+ if bias is not None:
171
+ out = out + bias
172
+ return out
173
+
174
+
175
+ @dataclass
176
+ class ModelConfig:
177
+ num_hidden_layers: int
178
+ num_experts: int
179
+ experts_per_token: int
180
+ vocab_size: int
181
+ num_labels: int
182
+ hidden_size: int
183
+ intermediate_size: int
184
+ head_dim: int
185
+ num_attention_heads: int
186
+ num_key_value_heads: int
187
+ bidirectional_context_size: int
188
+ initial_context_length: int
189
+ rope_theta: float
190
+ rope_scaling_factor: float
191
+ rope_ntk_alpha: float
192
+ rope_ntk_beta: float
193
+
194
+ @classmethod
195
+ def from_checkpoint_config(
196
+ cls,
197
+ checkpoint_config: dict[str, object],
198
+ *,
199
+ context: str,
200
+ ) -> "ModelConfig":
201
+ checkpoint_config = dict(checkpoint_config)
202
+ checkpoint_config["bidirectional_context_size"] = checkpoint_config[
203
+ "bidirectional_left_context"
204
+ ]
205
+ fields = {field.name: field for field in dataclasses.fields(cls)}
206
+ config_values = {
207
+ key: value for key, value in checkpoint_config.items() if key in fields
208
+ }
209
+
210
+ missing = [
211
+ name
212
+ for name, field in fields.items()
213
+ if field.default is dataclasses.MISSING
214
+ and field.default_factory is dataclasses.MISSING
215
+ and name not in config_values
216
+ ]
217
+ if missing:
218
+ raise ValueError(
219
+ f"{context} is missing required model config fields: {', '.join(missing)}"
220
+ )
221
+
222
+ try:
223
+ return cls(**config_values)
224
+ except TypeError as exc:
225
+ raise ValueError(f"Invalid model config payload at {context}: {exc}") from exc
226
+
227
+
228
+ class RMSNorm(torch.nn.Module):
229
+ def __init__(
230
+ self, num_features: int, eps: float = 1e-05, device: torch.device | None = None
231
+ ) -> None:
232
+ super().__init__()
233
+ self.num_features = num_features
234
+ self.eps = eps
235
+ self.scale = torch.nn.Parameter(
236
+ torch.ones(num_features, device=device, dtype=torch.float32)
237
+ )
238
+
239
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
240
+ t = x.float()
241
+ t = t * torch.rsqrt(torch.mean(t**2, dim=-1, keepdim=True) + self.eps)
242
+ return (t * self.scale).to(x.dtype)
243
+
244
+
245
+ def apply_rope(
246
+ x: torch.Tensor,
247
+ cos: torch.Tensor,
248
+ sin: torch.Tensor,
249
+ ) -> torch.Tensor:
250
+ cos = cos.unsqueeze(-2).to(x.dtype)
251
+ sin = sin.unsqueeze(-2).to(x.dtype)
252
+ x1 = x[..., ::2]
253
+ x2 = x[..., 1::2]
254
+ out1 = x1 * cos - x2 * sin
255
+ out2 = x2 * cos + x1 * sin
256
+ return torch.stack((out1, out2), dim=-1).reshape(x.shape)
257
+
258
+
259
+ class RotaryEmbedding(torch.nn.Module):
260
+ def __init__(
261
+ self,
262
+ head_dim: int,
263
+ base: int,
264
+ dtype: torch.dtype,
265
+ *,
266
+ initial_context_length: int = 4096,
267
+ scaling_factor: float = 1.0,
268
+ ntk_alpha: float = 1.0,
269
+ ntk_beta: float = 32.0,
270
+ device: torch.device | None = None,
271
+ ) -> None:
272
+ super().__init__()
273
+ self.head_dim = head_dim
274
+ self.base = base
275
+ self.dtype = dtype
276
+ self.initial_context_length = initial_context_length
277
+ self.scaling_factor = scaling_factor
278
+ self.ntk_alpha = ntk_alpha
279
+ self.ntk_beta = ntk_beta
280
+ self.device = device
281
+ max_positions = int(self.initial_context_length * self.scaling_factor)
282
+ max_positions = max(max_positions, self.initial_context_length)
283
+ self.max_position_embeddings = max_positions
284
+ cos, sin = self._compute_cos_sin(self.max_position_embeddings, device=torch.device("cpu"))
285
+ target_device = device or torch.device("cpu")
286
+ self.register_buffer("cos_cache", cos.to(target_device), persistent=False)
287
+ self.register_buffer("sin_cache", sin.to(target_device), persistent=False)
288
+
289
+ def _compute_concentration_and_inv_freq(
290
+ self, device: torch.device | None = None
291
+ ) -> tuple[float, torch.Tensor]:
292
+ device = device or self.device
293
+ freq = self.base ** (
294
+ torch.arange(0, self.head_dim, 2, dtype=torch.float, device=device) / self.head_dim
295
+ )
296
+ if self.scaling_factor > 1.0:
297
+ concentration = 0.1 * math.log(self.scaling_factor) + 1.0
298
+ d_half = self.head_dim / 2
299
+ low = (
300
+ d_half
301
+ * math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi))
302
+ / math.log(self.base)
303
+ )
304
+ high = (
305
+ d_half
306
+ * math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi))
307
+ / math.log(self.base)
308
+ )
309
+ interpolation = 1.0 / (self.scaling_factor * freq)
310
+ extrapolation = 1.0 / freq
311
+ ramp = (torch.arange(d_half, dtype=torch.float32, device=freq.device) - low) / (
312
+ high - low
313
+ )
314
+ mask = 1 - ramp.clamp(0, 1)
315
+ inv_freq = interpolation * (1 - mask) + extrapolation * mask
316
+ else:
317
+ concentration = 1.0
318
+ inv_freq = 1.0 / freq
319
+ return concentration, inv_freq
320
+
321
+ def _compute_cos_sin(
322
+ self, num_tokens: int, device: torch.device | None = None
323
+ ) -> tuple[torch.Tensor, torch.Tensor]:
324
+ concentration, inv_freq = self._compute_concentration_and_inv_freq(device=device)
325
+ device = device or self.device
326
+ t = torch.arange(num_tokens, dtype=torch.float32, device=device)
327
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
328
+ cos = freqs.cos() * concentration
329
+ sin = freqs.sin() * concentration
330
+ return cos.to(self.dtype), sin.to(self.dtype)
331
+
332
+ def forward(
333
+ self,
334
+ query: torch.Tensor,
335
+ key: torch.Tensor,
336
+ ) -> tuple[torch.Tensor, torch.Tensor]:
337
+ num_tokens = query.shape[0]
338
+ if num_tokens > self.cos_cache.shape[0]:
339
+ cos, sin = self._compute_cos_sin(num_tokens, device=torch.device("cpu"))
340
+ self.cos_cache = cos.to(query.device)
341
+ self.sin_cache = sin.to(query.device)
342
+ if self.cos_cache.device != query.device:
343
+ cos_cache = self.cos_cache.to(query.device)
344
+ sin_cache = self.sin_cache.to(query.device)
345
+ else:
346
+ cos_cache = self.cos_cache
347
+ sin_cache = self.sin_cache
348
+ cos = cos_cache[:num_tokens]
349
+ sin = sin_cache[:num_tokens]
350
+
351
+ query_shape = query.shape
352
+ query = query.view(num_tokens, -1, self.head_dim)
353
+ query = apply_rope(query, cos, sin)
354
+ query = query.reshape(query_shape)
355
+
356
+ key_shape = key.shape
357
+ key = key.view(num_tokens, -1, self.head_dim)
358
+ key = apply_rope(key, cos, sin)
359
+ key = key.reshape(key_shape)
360
+ return query, key
361
+
362
+
363
+ def sdpa(
364
+ Q: torch.Tensor,
365
+ K: torch.Tensor,
366
+ V: torch.Tensor,
367
+ S: torch.Tensor,
368
+ sm_scale: float,
369
+ context_size: int,
370
+ ) -> torch.Tensor:
371
+ num_tokens, num_heads, q_mult, head_dim = Q.shape
372
+ window = 2 * context_size + 1
373
+ Kp = F.pad(K, (0, 0, 0, 0, context_size, context_size))
374
+ Vp = F.pad(V, (0, 0, 0, 0, context_size, context_size))
375
+ Kwin = Kp.unfold(0, window, 1).permute(0, 3, 1, 2)
376
+ Vwin = Vp.unfold(0, window, 1).permute(0, 3, 1, 2)
377
+ idx = torch.arange(window, device=Q.device) - context_size
378
+ pos = torch.arange(num_tokens, device=Q.device)[:, None] + idx[None, :]
379
+ valid = (pos >= 0) & (pos < num_tokens)
380
+ scores = torch.einsum("nhqd,nwhd->nhqw", Q, Kwin).float()
381
+ scores *= sm_scale
382
+ scores = scores.masked_fill(~valid[:, None, None, :], -float("inf"))
383
+ sink_scores = (S * math.log(2.0)).reshape(num_heads, q_mult)
384
+ sink_scores = sink_scores[None, :, :, None].expand(num_tokens, -1, -1, 1)
385
+ scores = torch.cat([scores, sink_scores], dim=-1)
386
+ weights = torch.softmax(scores, dim=-1)[..., :-1].to(V.dtype)
387
+ attn = torch.einsum("nhqw,nwhd->nhqd", weights, Vwin)
388
+ return attn.reshape(num_tokens, -1)
389
+
390
+
391
+ class AttentionBlock(torch.nn.Module):
392
+ def __init__(
393
+ self,
394
+ config: ModelConfig,
395
+ device: torch.device | None = None,
396
+ ) -> None:
397
+ super().__init__()
398
+ param_dtype = torch.bfloat16
399
+ self.head_dim = config.head_dim
400
+ self.num_attention_heads = config.num_attention_heads
401
+ self.num_key_value_heads = config.num_key_value_heads
402
+ self.bidirectional_context_size = int(config.bidirectional_context_size)
403
+ self.sinks = torch.nn.Parameter(
404
+ torch.empty(config.num_attention_heads, device=device, dtype=torch.float32)
405
+ )
406
+ self.norm = RMSNorm(config.hidden_size, device=device)
407
+ qkv_dim = config.head_dim * (config.num_attention_heads + 2 * config.num_key_value_heads)
408
+ self.qkv = torch.nn.Linear(config.hidden_size, qkv_dim, device=device, dtype=param_dtype)
409
+ self.out = torch.nn.Linear(
410
+ config.head_dim * config.num_attention_heads,
411
+ config.hidden_size,
412
+ device=device,
413
+ dtype=param_dtype,
414
+ )
415
+ self.qk_scale = 1 / math.sqrt(math.sqrt(config.head_dim))
416
+ self.sm_scale = 1.0
417
+ self.rope = RotaryEmbedding(
418
+ config.head_dim,
419
+ int(config.rope_theta),
420
+ torch.float32,
421
+ initial_context_length=config.initial_context_length,
422
+ scaling_factor=config.rope_scaling_factor,
423
+ ntk_alpha=config.rope_ntk_alpha,
424
+ ntk_beta=config.rope_ntk_beta,
425
+ device=device,
426
+ )
427
+
428
+ def forward(
429
+ self,
430
+ x: torch.Tensor,
431
+ ) -> torch.Tensor:
432
+ t = self.norm(x)
433
+ if t.dtype != self.qkv.weight.dtype:
434
+ t = t.to(self.qkv.weight.dtype)
435
+ qkv = F.linear(t, self.qkv.weight, self.qkv.bias)
436
+ query = qkv[:, : self.num_attention_heads * self.head_dim].contiguous()
437
+ key = qkv[
438
+ :,
439
+ self.num_attention_heads * self.head_dim : (
440
+ self.num_attention_heads + self.num_key_value_heads
441
+ )
442
+ * self.head_dim,
443
+ ].contiguous()
444
+ value = qkv[
445
+ :,
446
+ (self.num_attention_heads + self.num_key_value_heads) * self.head_dim : (
447
+ self.num_attention_heads + 2 * self.num_key_value_heads
448
+ )
449
+ * self.head_dim,
450
+ ].contiguous()
451
+
452
+ query, key = self.rope(query, key)
453
+ query = query * self.qk_scale
454
+ key = key * self.qk_scale
455
+ sinks = self.sinks
456
+ num_tokens = query.shape[0]
457
+ query = query.view(
458
+ num_tokens,
459
+ self.num_key_value_heads,
460
+ self.num_attention_heads // self.num_key_value_heads,
461
+ self.head_dim,
462
+ )
463
+ key = key.view(num_tokens, self.num_key_value_heads, self.head_dim)
464
+ value = value.view(num_tokens, self.num_key_value_heads, self.head_dim)
465
+ attn_out = sdpa(
466
+ query,
467
+ key,
468
+ value,
469
+ sinks,
470
+ self.sm_scale,
471
+ self.bidirectional_context_size,
472
+ )
473
+ if attn_out.dtype != self.out.weight.dtype:
474
+ attn_out = attn_out.to(self.out.weight.dtype)
475
+ proj_bias = self.out.bias
476
+ proj = F.linear(attn_out, self.out.weight, proj_bias)
477
+ return x + proj.to(x.dtype)
478
+
479
+
480
+ def swiglu(
481
+ x: torch.Tensor,
482
+ alpha: float = 1.702,
483
+ limit: float = 7.0,
484
+ ) -> torch.Tensor:
485
+ x_glu, x_linear = x.chunk(2, dim=-1)
486
+ x_glu = x_glu.clamp(min=None, max=limit)
487
+ x_linear = x_linear.clamp(min=-limit, max=limit)
488
+ out_glu = x_glu * torch.sigmoid(alpha * x_glu)
489
+ return out_glu * (x_linear + 1)
490
+
491
+
492
+ class MLPBlock(torch.nn.Module):
493
+ def __init__(
494
+ self,
495
+ config: ModelConfig,
496
+ device: torch.device | None = None,
497
+ ) -> None:
498
+ super().__init__()
499
+ param_dtype = torch.bfloat16
500
+ self.num_experts = config.num_experts
501
+ self.experts_per_token = config.experts_per_token
502
+ self.swiglu_limit = 7.0
503
+ self.norm = RMSNorm(config.hidden_size, device=device)
504
+ self.gate = torch.nn.Linear(
505
+ config.hidden_size, config.num_experts, device=device, dtype=param_dtype
506
+ )
507
+ self.mlp1_weight = torch.nn.Parameter(
508
+ torch.empty(
509
+ (config.num_experts, config.hidden_size, config.intermediate_size * 2),
510
+ device=device,
511
+ dtype=param_dtype,
512
+ )
513
+ )
514
+ self.mlp1_bias = torch.nn.Parameter(
515
+ torch.empty(
516
+ (config.num_experts, config.intermediate_size * 2),
517
+ device=device,
518
+ dtype=param_dtype,
519
+ )
520
+ )
521
+ self.mlp2_weight = torch.nn.Parameter(
522
+ torch.empty(
523
+ (config.num_experts, config.intermediate_size, config.hidden_size),
524
+ device=device,
525
+ dtype=param_dtype,
526
+ )
527
+ )
528
+ self.mlp2_bias = torch.nn.Parameter(
529
+ torch.empty(
530
+ (config.num_experts, config.hidden_size),
531
+ device=device,
532
+ dtype=param_dtype,
533
+ )
534
+ )
535
+
536
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
537
+ t = self.norm(x)
538
+ gate_scores = F.linear(t.float(), self.gate.weight.float(), self.gate.bias.float())
539
+ experts = torch.topk(gate_scores, k=self.experts_per_token, dim=-1, sorted=True)
540
+ expert_weights = torch.softmax(experts.values, dim=-1) / self.experts_per_token
541
+
542
+ expert_indices = experts.indices
543
+ experts_per_token_eff = self.experts_per_token
544
+
545
+ def _moe_chunk(
546
+ t_chunk: torch.Tensor,
547
+ expert_indices_chunk: torch.Tensor,
548
+ expert_weights_chunk: torch.Tensor,
549
+ ) -> torch.Tensor:
550
+ mlp1_weight = self.mlp1_weight[expert_indices_chunk].float()
551
+ mlp1_bias = self.mlp1_bias[expert_indices_chunk].float()
552
+ t_expanded = t_chunk.float().unsqueeze(1).expand(-1, expert_indices_chunk.shape[1], -1)
553
+ out = expert_linear(
554
+ t_expanded,
555
+ mlp1_weight,
556
+ mlp1_bias,
557
+ )
558
+ out = swiglu(out, limit=self.swiglu_limit)
559
+ mlp2_weight = self.mlp2_weight[expert_indices_chunk].float()
560
+ mlp2_bias = self.mlp2_bias[expert_indices_chunk].float()
561
+ out = expert_linear(
562
+ out.float(),
563
+ mlp2_weight,
564
+ mlp2_bias,
565
+ )
566
+ if out.dtype != expert_weights_chunk.dtype:
567
+ out = out.to(expert_weights_chunk.dtype)
568
+ out = torch.einsum("bec,be->bc", out, expert_weights_chunk)
569
+ out = out * experts_per_token_eff
570
+ return out.to(x.dtype)
571
+
572
+ torch_ops_chunk_size = 32
573
+ if t.shape[0] > torch_ops_chunk_size:
574
+ chunks = []
575
+ for start in range(0, t.shape[0], torch_ops_chunk_size):
576
+ end = start + torch_ops_chunk_size
577
+ chunks.append(
578
+ _moe_chunk(
579
+ t[start:end],
580
+ expert_indices[start:end],
581
+ expert_weights[start:end],
582
+ )
583
+ )
584
+ t = torch.cat(chunks, dim=0)
585
+ else:
586
+ t = _moe_chunk(t, expert_indices, expert_weights)
587
+ return x + t
588
+
589
+
590
+ class TransformerBlock(torch.nn.Module):
591
+ def __init__(
592
+ self,
593
+ config: ModelConfig,
594
+ device: torch.device | None = None,
595
+ ) -> None:
596
+ super().__init__()
597
+ self.attn = AttentionBlock(config, device=device)
598
+ self.mlp = MLPBlock(config, device=device)
599
+
600
+ def forward(
601
+ self,
602
+ x: torch.Tensor,
603
+ ) -> torch.Tensor:
604
+ x = self.attn(x)
605
+ return self.mlp(x)
606
+
607
+
608
+ class Checkpoint:
609
+ @staticmethod
610
+ def build_param_name_map(
611
+ num_hidden_layers: int,
612
+ ) -> dict[str, str]:
613
+ return (
614
+ {
615
+ f"block.{n}.mlp.mlp1_bias": f"block.{n}.mlp.swiglu.bias"
616
+ for n in range(num_hidden_layers)
617
+ }
618
+ | {
619
+ f"block.{n}.mlp.mlp1_weight": f"block.{n}.mlp.swiglu.weight"
620
+ for n in range(num_hidden_layers)
621
+ }
622
+ | {
623
+ f"block.{n}.mlp.mlp2_bias": f"block.{n}.mlp.out.bias"
624
+ for n in range(num_hidden_layers)
625
+ }
626
+ | {
627
+ f"block.{n}.mlp.mlp2_weight": f"block.{n}.mlp.out.weight"
628
+ for n in range(num_hidden_layers)
629
+ }
630
+ )
631
+
632
+ def __init__(self, path: str, device: torch.device, num_hidden_layers: int) -> None:
633
+ self.param_name_map = self.build_param_name_map(num_hidden_layers)
634
+ self.device_str = device.type if device.index is None else f"{device.type}:{device.index}"
635
+ safetensor_files = [
636
+ os.path.join(path, filename)
637
+ for filename in os.listdir(path)
638
+ if filename.endswith(".safetensors")
639
+ ]
640
+ tensor_name_to_file: dict[str, str] = {}
641
+ for safetensor_file in safetensor_files:
642
+ with safe_open(safetensor_file, framework="pt", device=self.device_str) as handle:
643
+ for key in handle.keys():
644
+ prior_file = tensor_name_to_file.get(key)
645
+ if prior_file is not None:
646
+ raise ValueError(
647
+ "Duplicate tensor name in checkpoint shards: "
648
+ f"{key!r} appears in {prior_file!r} and {safetensor_file!r}"
649
+ )
650
+ tensor_name_to_file[key] = safetensor_file
651
+ self.tensor_name_to_file = tensor_name_to_file
652
+
653
+ def get(self, name: str) -> torch.Tensor:
654
+ mapped = self.param_name_map.get(name, name)
655
+ return self._get_tensor(mapped)
656
+
657
+ def _get_tensor(self, name: str) -> torch.Tensor:
658
+ if name not in self.tensor_name_to_file:
659
+ raise KeyError(f"Tensor {name!r} not found in checkpoint")
660
+ with safe_open(
661
+ self.tensor_name_to_file[name], framework="pt", device=self.device_str
662
+ ) as handle:
663
+ return handle.get_tensor(name)
664
+
665
+ class Transformer(torch.nn.Module):
666
+ def __init__(self, config: ModelConfig, device: torch.device) -> None:
667
+ super().__init__()
668
+ param_dtype = torch.bfloat16
669
+ self.embedding = torch.nn.Embedding(
670
+ config.vocab_size, config.hidden_size, device=device, dtype=param_dtype
671
+ )
672
+ self.block = torch.nn.ModuleList(
673
+ [
674
+ TransformerBlock(config, device=device)
675
+ for _ in range(config.num_hidden_layers)
676
+ ]
677
+ )
678
+ self.norm = RMSNorm(config.hidden_size, device=device)
679
+ self.unembedding = torch.nn.Linear(
680
+ config.hidden_size,
681
+ config.num_labels,
682
+ bias=False,
683
+ device=device,
684
+ dtype=param_dtype,
685
+ )
686
+
687
+ def forward(
688
+ self,
689
+ token_ids: torch.Tensor,
690
+ ) -> torch.Tensor:
691
+ x = self.embedding(token_ids)
692
+ for block in self.block:
693
+ x = block(x)
694
+ x = self.norm(x)
695
+ x = F.linear(x, self.unembedding.weight, None)
696
+ return x
697
+
698
+ @classmethod
699
+ def from_checkpoint(
700
+ cls,
701
+ checkpoint_dir: str,
702
+ *,
703
+ device: torch.device,
704
+ ) -> "Transformer":
705
+ torch.backends.cuda.matmul.allow_tf32 = False
706
+ torch.backends.cudnn.allow_tf32 = False
707
+ torch.set_float32_matmul_precision("highest")
708
+ config_path = Path(checkpoint_dir) / "config.json"
709
+ with config_path.open("r", encoding="utf-8") as handle:
710
+ checkpoint_config = json.load(handle)
711
+ if not isinstance(checkpoint_config, dict):
712
+ raise ValueError(f"Invalid checkpoint config payload at {config_path}")
713
+ validate_model_config_contract(
714
+ checkpoint_config,
715
+ context=str(config_path),
716
+ )
717
+
718
+ config = ModelConfig.from_checkpoint_config(
719
+ checkpoint_config,
720
+ context=str(config_path),
721
+ )
722
+ checkpoint = Checkpoint(
723
+ checkpoint_dir,
724
+ device,
725
+ num_hidden_layers=config.num_hidden_layers,
726
+ )
727
+
728
+ model = cls(config=config, device=device)
729
+ model.eval()
730
+
731
+ for name, param in model.named_parameters():
732
+ loaded_tensor = checkpoint.get(name)
733
+ if param.data.shape != loaded_tensor.shape:
734
+ raise ValueError(
735
+ f"Tensor shape mismatch for {name!r}: expected {tuple(param.data.shape)}, "
736
+ f"got {tuple(loaded_tensor.shape)}"
737
+ )
738
+ param.data.copy_(loaded_tensor)
739
+
740
+ return model
741
+
742
+
743
+ @dataclass(frozen=True)
+ class LabelInfo:
+     boundary_label_lookup: dict[str, dict[str, int]]
+     token_to_span_label: dict[int, int]
+     token_boundary_tags: dict[int, str | None]
+     span_class_names: tuple[str, ...]
+     span_label_lookup: dict[str, int]
+     background_token_label: int
+     background_span_label: int
+
+
+ def labels_to_spans(
+     labels_by_index: dict[int, int], label_info: LabelInfo
+ ) -> list[tuple[int, int, int]]:
+     spans: list[tuple[int, int, int]] = []
+     current_label: int | None = None
+     start_idx: int | None = None
+     previous_idx: int | None = None
+     background_span_label = label_info.background_span_label
+
+     for token_idx in sorted(labels_by_index):
+         label_id = labels_by_index[token_idx]
+         span_label = label_info.token_to_span_label.get(label_id)
+         boundary_tag = label_info.token_boundary_tags.get(label_id)
+
+         if previous_idx is not None and token_idx != previous_idx + 1:
+             if current_label is not None and start_idx is not None:
+                 spans.append((current_label, start_idx, previous_idx + 1))
+             current_label = None
+             start_idx = None
+
+         if span_label is None:
+             previous_idx = token_idx
+             continue
+
+         if span_label == background_span_label:
+             if current_label is not None and start_idx is not None:
+                 spans.append((current_label, start_idx, token_idx))
+             current_label = None
+             start_idx = None
+             previous_idx = token_idx
+             continue
+
+         if boundary_tag == "S":
+             if current_label is not None and start_idx is not None and previous_idx is not None:
+                 spans.append((current_label, start_idx, previous_idx + 1))
+             spans.append((span_label, token_idx, token_idx + 1))
+             current_label = None
+             start_idx = None
+         elif boundary_tag == "B":
+             if current_label is not None and start_idx is not None and previous_idx is not None:
+                 spans.append((current_label, start_idx, previous_idx + 1))
+             current_label = span_label
+             start_idx = token_idx
+         elif boundary_tag == "I":
+             if current_label is None or current_label != span_label:
+                 if current_label is not None and start_idx is not None and previous_idx is not None:
+                     spans.append((current_label, start_idx, previous_idx + 1))
+                 current_label = span_label
+                 start_idx = token_idx
+         elif boundary_tag == "E":
+             if current_label is None or current_label != span_label or start_idx is None:
+                 if current_label is not None and start_idx is not None and previous_idx is not None:
+                     spans.append((current_label, start_idx, previous_idx + 1))
+                 spans.append((span_label, token_idx, token_idx + 1))
+                 current_label = None
+                 start_idx = None
+             else:
+                 spans.append((current_label, start_idx, token_idx + 1))
+                 current_label = None
+                 start_idx = None
+         else:
+             if current_label is not None and start_idx is not None and previous_idx is not None:
+                 spans.append((current_label, start_idx, previous_idx + 1))
+             current_label = None
+             start_idx = None
+
+         previous_idx = token_idx
+
+     if current_label is not None and start_idx is not None and previous_idx is not None:
+         spans.append((current_label, start_idx, previous_idx + 1))
+     return spans
+
+
827
+ def token_spans_to_char_spans(
+     spans: Sequence[tuple[int, int, int]],
+     char_starts: Sequence[int],
+     char_ends: Sequence[int],
+ ) -> list[tuple[int, int, int]]:
+     converted: list[tuple[int, int, int]] = []
+     for label_idx, token_start, token_end in spans:
+         if not (0 <= token_start < token_end <= len(char_starts)):
+             continue
+         char_start = char_starts[token_start]
+         char_end = char_ends[token_end - 1]
+         if char_end <= char_start:
+             continue
+         converted.append((label_idx, char_start, char_end))
+     return converted
+
+
+ def trim_char_spans_whitespace(
+     spans: Sequence[tuple[int, int, int]],
+     text: str,
+ ) -> list[tuple[int, int, int]]:
+     trimmed: list[tuple[int, int, int]] = []
+     for label_idx, start, end in spans:
+         if not (0 <= start < end <= len(text)):
+             continue
+         while start < end and text[start].isspace():
+             start += 1
+         while end > start and text[end - 1].isspace():
+             end -= 1
+         if end > start:
+             trimmed.append((label_idx, start, end))
+     return trimmed
+
+
861
+ @dataclass(frozen=True)
+ class InferenceRuntime:
+     model: Transformer
+     encoding: tiktoken.Encoding
+     label_info: LabelInfo
+     device: torch.device
+     n_ctx: int
+
+
+ @functools.lru_cache(maxsize=1)
+ def get_viterbi_transition_biases() -> dict[str, float]:
+     calibration_path = MODEL_DIR / "viterbi_calibration.json"
+     default_biases = {key: 0.0 for key in VITERBI_TRANSITION_BIAS_KEYS}
+     if not calibration_path.is_file():
+         return default_biases
+
+     payload = json.loads(calibration_path.read_text(encoding="utf-8"))
+     if not isinstance(payload, dict):
+         raise ValueError(f"Invalid Viterbi calibration payload at {calibration_path}")
+
+     raw_biases: object = payload
+     operating_points = payload.get("operating_points")
+     if operating_points is not None:
+         if not isinstance(operating_points, dict):
+             raise ValueError(f"Invalid operating_points payload at {calibration_path}")
+         preset_entry = operating_points.get(DEFAULT_VITERBI_CALIBRATION_PRESET)
+         if not isinstance(preset_entry, dict):
+             raise ValueError(
+                 f"Missing operating_points.{DEFAULT_VITERBI_CALIBRATION_PRESET!s} "
+                 f"in {calibration_path}"
+             )
+         raw_biases = preset_entry.get("biases")
+
+     if not isinstance(raw_biases, dict):
+         raise ValueError(f"Invalid Viterbi bias payload at {calibration_path}")
+
+     resolved_biases: dict[str, float] = {}
+     for key in VITERBI_TRANSITION_BIAS_KEYS:
+         raw_value = raw_biases.get(key)
+         if isinstance(raw_value, bool) or not isinstance(raw_value, (int, float)):
+             raise ValueError(f"Missing or invalid {key!r} in {calibration_path}")
+         resolved_biases[key] = float(raw_value)
+     return resolved_biases
+
+
906
+ @functools.lru_cache(maxsize=1)
+ def get_runtime() -> InferenceRuntime:
+     checkpoint = MODEL_DIR
+     if not checkpoint.exists() or not checkpoint.is_dir():
+         raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint}")
+     if not any(checkpoint.glob("*.safetensors")):
+         raise FileNotFoundError(f"Checkpoint directory has no .safetensors files: {checkpoint}")
+     if not torch.cuda.is_available():
+         raise RuntimeError("CUDA is not available")
+     config_path = checkpoint / "config.json"
+     checkpoint_config = json.loads(config_path.read_text(encoding="utf-8"))
+     if not isinstance(checkpoint_config, dict):
+         raise ValueError(f"Invalid checkpoint config payload at {config_path}")
+     validate_model_config_contract(
+         checkpoint_config,
+         context=str(config_path),
+     )
+     ner_class_names = NER_CLASS_NAMES
+     device = torch.device("cuda")
+     n_ctx = int(checkpoint_config["default_n_ctx"])
+
+     encoding = tiktoken.get_encoding(str(checkpoint_config["encoding"]).strip())
+     span_class_names: list[str] = [BACKGROUND_CLASS_LABEL]
+     span_label_lookup: dict[str, int] = {BACKGROUND_CLASS_LABEL: 0}
+     boundary_label_lookup: dict[str, dict[str, int]] = {}
+     token_to_span_label: dict[int, int] = {}
+     token_boundary_tags: dict[int, str | None] = {}
+     background_idx: int | None = None
+     for idx, name in enumerate(ner_class_names):
+         if name == BACKGROUND_CLASS_LABEL:
+             background_idx = idx
+             token_to_span_label[idx] = span_label_lookup[BACKGROUND_CLASS_LABEL]
+             token_boundary_tags[idx] = None
+             continue
+         boundary, base_label = name.split("-", 1)
+         span_idx = span_label_lookup.get(base_label)
+         if span_idx is None:
+             span_idx = len(span_class_names)
+             span_class_names.append(base_label)
+             span_label_lookup[base_label] = span_idx
+         token_to_span_label[idx] = span_idx
+         token_boundary_tags[idx] = boundary
+         boundary_label_lookup.setdefault(base_label, {})[boundary] = idx
+     if background_idx is None:
+         raise ValueError("Class names must include background label 'O'")
+     for base_label, mapping in boundary_label_lookup.items():
+         missing = set(BOUNDARY_PREFIXES) - set(mapping)
+         if missing:
+             raise ValueError(
+                 f"Missing boundary classes {sorted(missing)} for base label {base_label}"
+             )
+     label_info = LabelInfo(
+         boundary_label_lookup={key: dict(value) for key, value in boundary_label_lookup.items()},
+         token_to_span_label=dict(token_to_span_label),
+         token_boundary_tags=dict(token_boundary_tags),
+         span_class_names=tuple(span_class_names),
+         span_label_lookup=dict(span_label_lookup),
+         background_token_label=background_idx,
+         background_span_label=span_label_lookup[BACKGROUND_CLASS_LABEL],
+     )
+     model = Transformer.from_checkpoint(
+         checkpoint,
+         device=device,
+     )
+     return InferenceRuntime(
+         model=model,
+         encoding=encoding,
+         label_info=label_info,
+         device=device,
+         n_ctx=n_ctx,
+     )
+
+
979
+ class Decoder:
+     def __init__(self, label_info: LabelInfo) -> None:
+         self.label_info = label_info
+         num_classes = len(label_info.token_to_span_label)
+         self._start_scores = torch.full((num_classes,), -1e9, dtype=torch.float32)
+         self._end_scores = torch.full((num_classes,), -1e9, dtype=torch.float32)
+         self._transition_scores = torch.full((num_classes, num_classes), -1e9, dtype=torch.float32)
+         transition_biases = get_viterbi_transition_biases()
+
+         background_token_idx = label_info.background_token_label
+         background_span_idx = label_info.background_span_label
+         token_boundary_tags = label_info.token_boundary_tags
+         token_to_span_label = label_info.token_to_span_label
+
+         for idx in range(num_classes):
+             tag = token_boundary_tags.get(idx)
+             span_label = token_to_span_label.get(idx)
+             if tag in {"B", "S"} or idx == background_token_idx:
+                 self._start_scores[idx] = 0.0
+             if tag in {"E", "S"} or idx == background_token_idx:
+                 self._end_scores[idx] = 0.0
+
+             for next_idx in range(num_classes):
+                 next_tag = token_boundary_tags.get(next_idx)
+                 next_span_label = token_to_span_label.get(next_idx)
+                 if self._is_valid_transition(
+                     prev_tag=tag,
+                     prev_span=span_label,
+                     next_tag=next_tag,
+                     next_span=next_span_label,
+                     background_token_idx=background_token_idx,
+                     background_span_idx=background_span_idx,
+                     next_idx=next_idx,
+                 ):
+                     self._transition_scores[idx, next_idx] = self._transition_bias(
+                         prev_tag=tag,
+                         prev_span=span_label,
+                         next_tag=next_tag,
+                         next_span=next_span_label,
+                         background_span_idx=background_span_idx,
+                         biases=transition_biases,
+                     )
+
+     @staticmethod
+     def _is_valid_transition(
+         *,
+         prev_tag: str | None,
+         prev_span: int | None,
+         next_tag: str | None,
+         next_span: int | None,
+         background_token_idx: int,
+         background_span_idx: int,
+         next_idx: int,
+     ) -> bool:
+         next_is_background = next_span == background_span_idx or next_idx == background_token_idx
+         if (next_span is None or next_tag is None) and not next_is_background:
+             return False
+
+         if prev_span is None or prev_tag is None:
+             return next_is_background or next_tag in {"B", "S"}
+
+         prev_is_background = prev_span == background_span_idx
+         if prev_is_background or prev_tag in {"E", "S"}:
+             return next_is_background or next_tag in {"B", "S"}
+         if prev_tag in {"B", "I"}:
+             return prev_span == next_span and next_tag in {"I", "E"}
+         return False
+
+     @staticmethod
+     def _transition_bias(
+         *,
+         prev_tag: str | None,
+         prev_span: int | None,
+         next_tag: str | None,
+         next_span: int | None,
+         background_span_idx: int,
+         biases: dict[str, float],
+     ) -> float:
+         next_is_background = next_span == background_span_idx
+         prev_is_background = prev_span == background_span_idx
+         if prev_is_background:
+             return (
+                 biases["transition_bias_background_stay"]
+                 if next_is_background
+                 else biases["transition_bias_background_to_start"]
+             )
+         if prev_tag in {"B", "I"}:
+             return (
+                 biases["transition_bias_inside_to_continue"]
+                 if next_tag == "I"
+                 else biases["transition_bias_inside_to_end"]
+             )
+         return (
+             biases["transition_bias_end_to_background"]
+             if next_is_background
+             else biases["transition_bias_end_to_start"]
+         )
+
+     def decode(self, token_logprobs: torch.Tensor) -> list[int]:
+         if token_logprobs.ndim != 2:
+             raise ValueError("token_logprobs must have shape [seq_len, num_classes]")
+         seq_len, num_classes = token_logprobs.shape
+         if seq_len == 0:
+             return []
+
+         start_scores = self._start_scores.to(
+             device=token_logprobs.device,
+             dtype=token_logprobs.dtype,
+         )
+         end_scores = self._end_scores.to(
+             device=token_logprobs.device,
+             dtype=token_logprobs.dtype,
+         )
+         transition_scores = self._transition_scores.to(
+             device=token_logprobs.device,
+             dtype=token_logprobs.dtype,
+         )
+         scores = token_logprobs[0] + start_scores
+         backpointers = torch.empty(
+             (seq_len - 1, num_classes),
+             device=token_logprobs.device,
+             dtype=torch.int64,
+         )
+
+         for idx in range(1, seq_len):
+             transitions = scores.unsqueeze(1) + transition_scores
+             best_scores, best_paths = transitions.max(dim=0)
+             scores = best_scores + token_logprobs[idx]
+             backpointers[idx - 1] = best_paths
+
+         if not torch.isfinite(scores).any():
+             return token_logprobs.argmax(dim=1).tolist()
+
+         scores = scores + end_scores
+         last_label = scores.argmax()
+         path = torch.empty((seq_len,), device=token_logprobs.device, dtype=torch.int64)
+         path[-1] = last_label
+         for idx in range(seq_len - 2, -1, -1):
+             last_label = backpointers[idx, last_label]
+             path[idx] = last_label
+         return path.tolist()
+
+
1122
+ @torch.inference_mode()
+ def predict_text(
+     runtime: InferenceRuntime,
+     text: str,
+     decoder: Decoder,
+ ) -> tuple[str, list[dict[str, object]]]:
+     token_ids = tuple(int(token) for token in runtime.encoding.encode(text, allowed_special="all"))
+     if not token_ids:
+         return text, []
+
+     if runtime.n_ctx <= 0:
+         raise ValueError("runtime.n_ctx must be positive")
+
+     token_score_vectors: list[torch.Tensor] = []
+     for start in range(0, len(token_ids), runtime.n_ctx):
+         end = min(start + runtime.n_ctx, len(token_ids))
+         window_tokens = torch.tensor(token_ids[start:end], device=runtime.device, dtype=torch.int32)
+         logits = runtime.model(window_tokens)
+         log_probs = F.log_softmax(logits.float(), dim=-1)
+         if log_probs.shape[0] != window_tokens.shape[0]:
+             raise ValueError("Logprob output length does not match window length")
+         token_score_vectors.extend(log_probs.unbind(0))
+
+     if not token_score_vectors:
+         return text, []
+
+     stacked_scores = torch.stack(token_score_vectors, dim=0)
+     decoded_labels = decoder.decode(stacked_scores)
+     if len(decoded_labels) != len(token_ids):
+         decoded_labels = stacked_scores.argmax(dim=1).tolist()
+
+     predicted_labels_by_index = {
+         token_idx: int(label) for token_idx, label in enumerate(decoded_labels)
+     }
+     predicted_token_spans = labels_to_spans(predicted_labels_by_index, runtime.label_info)
+     token_bytes = [runtime.encoding.decode_single_token_bytes(token_id) for token_id in token_ids]
+     decoded_text = b"".join(token_bytes).decode("utf-8", errors="replace")
+     char_byte_starts: list[int] = []
+     char_byte_ends: list[int] = []
+     byte_cursor = 0
+     for ch in decoded_text:
+         char_byte_starts.append(byte_cursor)
+         byte_cursor += len(ch.encode("utf-8"))
+         char_byte_ends.append(byte_cursor)
+     char_starts: list[int] = []
+     char_ends: list[int] = []
+     token_byte_cursor = 0
+     for raw_bytes in token_bytes:
+         token_byte_start = token_byte_cursor
+         token_byte_end = token_byte_start + len(raw_bytes)
+         token_byte_cursor = token_byte_end
+         start_idx = bisect_right(char_byte_ends, token_byte_start)
+         end_idx = bisect_left(char_byte_starts, token_byte_end)
+         if end_idx < start_idx:
+             end_idx = start_idx
+         char_starts.append(start_idx)
+         char_ends.append(end_idx)
+     if char_ends and char_ends[-1] != len(decoded_text):
+         raise ValueError(
+             f"Character length mismatch for decoded text (tokens={char_ends[-1]}, text={len(decoded_text)})"
+         )
+     decoded_mismatch = decoded_text != text
+     source_text = decoded_text if decoded_mismatch else text
+     predicted_char_spans = token_spans_to_char_spans(
+         predicted_token_spans,
+         char_starts,
+         char_ends,
+     )
+     predicted_char_spans = trim_char_spans_whitespace(predicted_char_spans, source_text)
+
+     detected: list[dict[str, object]] = []
+     for label_idx, start, end in predicted_char_spans:
+         if not (0 <= start < end <= len(source_text)):
+             continue
+         label = (
+             runtime.label_info.span_class_names[label_idx]
+             if 0 <= label_idx < len(runtime.label_info.span_class_names)
+             else f"label_{label_idx}"
+         )
+         detected.append(
+             {
+                 "entity": label,
+                 "start": int(start),
+                 "end": int(end),
+             }
+         )
+
+     return source_text, detected
+
+
1212
+ @spaces.GPU
+ def predict(text: str) -> dict[str, object]:
+     text = text or ""
+     if not text.strip():
+         return EMPTY_HIGHLIGHT_PAYLOAD
+     runtime = get_runtime()
+     decoder = Decoder(label_info=runtime.label_info)
+     filtered_text, spans = predict_text(runtime, text, decoder)
+     return {
+         "text": filtered_text,
+         "entities": spans,
+     }
+
+
1226
+ def build_demo() -> gr.Blocks:
+     config_path = MODEL_DIR / "config.json"
+     checkpoint_config = json.loads(config_path.read_text(encoding="utf-8"))
+     if not isinstance(checkpoint_config, dict):
+         raise ValueError(f"Invalid checkpoint config payload at {config_path}")
+     validate_model_config_contract(
+         checkpoint_config,
+         context=str(config_path),
+     )
+     span_class_names = SPAN_CLASS_NAMES
+     web_color_palette = (
+         "#e6194b",
+         "#3cb44b",
+         "#4363d8",
+         "#f58231",
+         "#911eb4",
+         "#008080",
+         "#9a6324",
+         "#f032e6",
+         "#b59f00",
+         "#800000",
+         "#000075",
+         "#808080",
+     )
+     with gr.Blocks(
+         title="OpenAI Privacy Filter",
+         fill_width=True,
+         elem_id="privacy-filter-app",
+     ) as demo:
+         gr.Markdown("# OpenAI Privacy Filter Demo")
+         gr.Markdown("Example of using OpenAI Privacy Filter (OPF) to mask personal identifiers.")
+
+         with gr.Column(variant="panel"):
+             gr.Markdown("Input text:")
+             input_text = gr.Textbox(
+                 lines=2,
+                 placeholder="Paste text here to detect and mask personal identifiers...",
+                 show_label=False,
+                 container=False,
+             )
+
+         with gr.Column(variant="panel"):
+             gr.Markdown("Text after masking personal identifiers:")
+             output_text = gr.HighlightedText(
+                 value=EMPTY_HIGHLIGHT_PAYLOAD,
+                 color_map={
+                     label: web_color_palette[idx % len(web_color_palette)]
+                     for idx, label in enumerate(
+                         label for label in span_class_names if label != BACKGROUND_CLASS_LABEL
+                     )
+                 },
+                 combine_adjacent=False,
+                 show_legend=False,
+                 show_label=False,
+                 container=True,
+             )
+
+         with gr.Row():
+             submit_button = gr.Button("Submit", variant="primary")
+             clear_button = gr.Button("Clear")
+
+         submit_button.click(
+             fn=predict,
+             inputs=input_text,
+             outputs=output_text,
+             api_name="predict",
+         )
+         input_text.submit(
+             fn=predict,
+             inputs=input_text,
+             outputs=output_text,
+         )
+         clear_button.click(
+             lambda: ("", EMPTY_HIGHLIGHT_PAYLOAD),
+             outputs=[input_text, output_text],
+             show_progress="hidden",
+         )
+
+         gr.Examples(
+             examples=[
+                 ["Alice was born on 1990-01-02 and lives at 1 Main St."],
+                 ["Email me at alice@example.com or call 415-555-0101."],
+             ],
+             inputs=input_text,
+             outputs=output_text,
+             fn=predict,
+             cache_examples=False,
+         )
+     return demo
+
+
+ if __name__ == "__main__":
+     demo = build_demo()
+     demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ gradio>=6.9.0,<7
+ safetensors>=0.7.0,<1
+ spaces>=0.47.0,<1
+ tiktoken>=0.12.0,<1
+ torch