"""Gradio Space for openai/privacy-filter: PII span detection and redaction."""

import dataclasses
import functools
import inspect
import json
import math
import os
from bisect import bisect_left, bisect_right
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Final

import gradio as gr
import spaces
import tiktoken
import torch
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from safetensors import safe_open

MODEL_ROOT = snapshot_download("openai/privacy-filter", allow_patterns=["original/*"])
MODEL_DIR = Path(MODEL_ROOT) / "original"

PRIVACY_FILTER_MODEL_TYPE: Final[str] = "privacy_filter"
REQUIRED_MODEL_CONFIG_KEYS: Final[tuple[str, ...]] = (
    "model_type",
    "encoding",
    "num_hidden_layers",
    "num_experts",
    "experts_per_token",
    "vocab_size",
    "num_labels",
    "hidden_size",
    "intermediate_size",
    "head_dim",
    "num_attention_heads",
    "num_key_value_heads",
    "sliding_window",
    "bidirectional_context",
    "bidirectional_left_context",
    "bidirectional_right_context",
    "default_n_ctx",
    "initial_context_length",
    "rope_theta",
    "rope_scaling_factor",
    "rope_ntk_alpha",
    "rope_ntk_beta",
    "param_dtype",
)
BACKGROUND_CLASS_LABEL: Final[str] = "O"
BOUNDARY_PREFIXES: Final[tuple[str, ...]] = ("B", "I", "E", "S")
EMPTY_HIGHLIGHT_PAYLOAD = {"text": "", "entities": []}
EMPTY_SUMMARY_MARKDOWN = "_No entities detected yet._"
SPAN_CLASS_NAMES: Final[tuple[str, ...]] = (
    BACKGROUND_CLASS_LABEL,
    "account_number",
    "private_address",
    "private_date",
    "private_email",
    "private_person",
    "private_phone",
    "private_url",
    "secret",
)
REDACTION_LABEL_MAP: Final[dict[str, str]] = {
    "account_number": "[ACCOUNT_NUMBER]",
    "private_address": "[ADDRESS]",
    "private_date": "[DATE]",
    "private_email": "[EMAIL]",
    "private_person": "[PERSON]",
    "private_phone": "[PHONE]",
    "private_url": "[URL]",
    "secret": "[SECRET]",
}
NER_CLASS_NAMES: Final[tuple[str, ...]] = (BACKGROUND_CLASS_LABEL,) + tuple(
    f"{prefix}-{base_label}"
    for base_label in SPAN_CLASS_NAMES
    if base_label != BACKGROUND_CLASS_LABEL
    for prefix in BOUNDARY_PREFIXES
)
VITERBI_TRANSITION_BIAS_KEYS: Final[tuple[str, ...]] = (
    "transition_bias_background_stay",
    "transition_bias_background_to_start",
    "transition_bias_inside_to_continue",
    "transition_bias_inside_to_end",
    "transition_bias_end_to_background",
    "transition_bias_end_to_start",
)
DEFAULT_VITERBI_CALIBRATION_PRESET: Final[str] = "default"


def supported_kwargs(
    factory: object,
    **kwargs: object,
) -> dict[str, object]:
    """Return only the kwargs that ``factory``'s signature accepts."""
    signature = inspect.signature(factory)
    return {key: value for key, value in kwargs.items() if key in signature.parameters}
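

# Illustrative sketch: ``supported_kwargs`` keeps this app compatible across
# Gradio versions by dropping constructor kwargs the installed version does
# not accept. For example (hypothetical signature, not executed here):
#
#     def make_box(lines: int = 1, label: str | None = None): ...
#
#     supported_kwargs(make_box, lines=6, label="Input", container=False)
#     # -> {"lines": 6, "label": "Input"}   ("container" is filtered out)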


def validate_model_config_contract(
    checkpoint_config: dict[str, object],
    *,
    context: str,
) -> None:
    missing = [key for key in REQUIRED_MODEL_CONFIG_KEYS if key not in checkpoint_config]
    if missing:
        raise ValueError(f"{context} is missing required model config keys: {', '.join(missing)}")
    model_type = checkpoint_config.get("model_type")
    if model_type != PRIVACY_FILTER_MODEL_TYPE:
        raise ValueError(
            f"{context} model_type must be {PRIVACY_FILTER_MODEL_TYPE!r}, got {model_type!r}"
        )
    if checkpoint_config.get("bidirectional_context") is not True:
        raise ValueError(f"{context} must use bidirectional_context=true")
    raw_left_context = checkpoint_config.get("bidirectional_left_context")
    raw_right_context = checkpoint_config.get("bidirectional_right_context")
    if (
        not isinstance(raw_left_context, int)
        or isinstance(raw_left_context, bool)
        or not isinstance(raw_right_context, int)
        or isinstance(raw_right_context, bool)
    ):
        raise ValueError(
            f"{context} bidirectional context sizes must be integers "
            f"(got {raw_left_context!r}/{raw_right_context!r})"
        )
    left_context = raw_left_context
    right_context = raw_right_context
    if left_context < 0 or right_context < 0:
        raise ValueError(
            f"{context} bidirectional context sizes must be >= 0 "
            f"(got {left_context}/{right_context})"
        )
    if left_context != right_context:
        raise ValueError(
            f"{context} bidirectional context must be symmetric "
            f"(got left={left_context}, right={right_context})"
        )
    raw_sliding_window = checkpoint_config.get("sliding_window")
    if not isinstance(raw_sliding_window, int) or isinstance(raw_sliding_window, bool):
        raise ValueError(f"{context} sliding_window must be an integer, got {raw_sliding_window!r}")
    sliding_window = raw_sliding_window
    expected_sliding_window = 2 * left_context + 1
    if sliding_window != expected_sliding_window:
        raise ValueError(
            f"{context} sliding_window must equal 2 * bidirectional context + 1 "
            f"(got {sliding_window}, expected {expected_sliding_window})"
        )
    num_labels_raw = checkpoint_config["num_labels"]
    if not isinstance(num_labels_raw, int) or isinstance(num_labels_raw, bool):
        raise ValueError(f"{context} num_labels must be an integer, got {num_labels_raw!r}")
    num_labels = num_labels_raw
    if num_labels != 33:
        raise ValueError(
            f"{context} must use num_labels=33 for the label space, got {num_labels}"
        )
    raw_encoding = checkpoint_config["encoding"]
    if not isinstance(raw_encoding, str) or not raw_encoding.strip():
        raise ValueError(f"{context} encoding must be a non-empty string")
    raw_n_ctx = checkpoint_config["default_n_ctx"]
    if not isinstance(raw_n_ctx, int) or isinstance(raw_n_ctx, bool):
        raise ValueError(f"{context} default_n_ctx must be a positive integer, got {raw_n_ctx!r}")
    n_ctx = raw_n_ctx
    if n_ctx <= 0:
        raise ValueError(f"{context} default_n_ctx must be positive, got {n_ctx}")
    raw_param_dtype = checkpoint_config["param_dtype"]
    if raw_param_dtype != "bfloat16":
        raise ValueError(f"{context} param_dtype must be bfloat16, got {raw_param_dtype!r}")


def expert_linear(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
) -> torch.Tensor:
    num_rows, experts, k_dim = x.shape
    _, _, _, out_dim = weight.shape
    x_bmm = x.reshape(num_rows * experts, 1, k_dim)
    w_bmm = weight.reshape(num_rows * experts, k_dim, out_dim)
    out = torch.bmm(x_bmm, w_bmm).reshape(num_rows, experts, out_dim)
    if bias is not None:
        out = out + bias
    return out
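

# Shape sketch for expert_linear (illustrative values):
#
#     x      : [rows, experts, k_dim]           e.g. [32, 4, 2880]
#     weight : [rows, experts, k_dim, out_dim]  e.g. [32, 4, 2880, 5760]
#     bias   : [rows, experts, out_dim]
#     out    : [rows, experts, out_dim]
#
# Each (row, expert) pair gets its own matmul, batched through torch.bmm.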


@dataclass
class ModelConfig:
    num_hidden_layers: int
    num_experts: int
    experts_per_token: int
    vocab_size: int
    num_labels: int
    hidden_size: int
    intermediate_size: int
    head_dim: int
    num_attention_heads: int
    num_key_value_heads: int
    bidirectional_context_size: int
    initial_context_length: int
    rope_theta: float
    rope_scaling_factor: float
    rope_ntk_alpha: float
    rope_ntk_beta: float

    @classmethod
    def from_checkpoint_config(
        cls,
        checkpoint_config: dict[str, object],
        *,
        context: str,
    ) -> "ModelConfig":
        checkpoint_config = dict(checkpoint_config)
        checkpoint_config["bidirectional_context_size"] = checkpoint_config[
            "bidirectional_left_context"
        ]
        fields = {field.name: field for field in dataclasses.fields(cls)}
        config_values = {
            key: value for key, value in checkpoint_config.items() if key in fields
        }
        missing = [
            name
            for name, field in fields.items()
            if field.default is dataclasses.MISSING
            and field.default_factory is dataclasses.MISSING
            and name not in config_values
        ]
        if missing:
            raise ValueError(
                f"{context} is missing required model config fields: {', '.join(missing)}"
            )
        try:
            return cls(**config_values)
        except TypeError as exc:
            raise ValueError(f"Invalid model config payload at {context}: {exc}") from exc


class RMSNorm(torch.nn.Module):
    def __init__(
        self, num_features: int, eps: float = 1e-05, device: torch.device | None = None
    ) -> None:
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.scale = torch.nn.Parameter(
            torch.ones(num_features, device=device, dtype=torch.float32)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = x.float()
        t = t * torch.rsqrt(torch.mean(t**2, dim=-1, keepdim=True) + self.eps)
        return (t * self.scale).to(x.dtype)


def apply_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    out1 = x1 * cos - x2 * sin
    out2 = x2 * cos + x1 * sin
    return torch.stack((out1, out2), dim=-1).reshape(x.shape)
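

# apply_rope rotates interleaved (even, odd) feature pairs by a position-
# dependent angle. Worked example for one pair (x1, x2) at angle theta:
#
#     out1 = x1 * cos(theta) - x2 * sin(theta)
#     out2 = x2 * cos(theta) + x1 * sin(theta)
#
# i.e. a standard 2-D rotation, applied per pair and per position; stacking on
# dim=-1 then reshaping restores the original interleaved layout.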


class RotaryEmbedding(torch.nn.Module):
    def __init__(
        self,
        head_dim: int,
        base: int,
        dtype: torch.dtype,
        *,
        initial_context_length: int = 4096,
        scaling_factor: float = 1.0,
        ntk_alpha: float = 1.0,
        ntk_beta: float = 32.0,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.base = base
        self.dtype = dtype
        self.initial_context_length = initial_context_length
        self.scaling_factor = scaling_factor
        self.ntk_alpha = ntk_alpha
        self.ntk_beta = ntk_beta
        self.device = device
        max_positions = int(self.initial_context_length * self.scaling_factor)
        max_positions = max(max_positions, self.initial_context_length)
        self.max_position_embeddings = max_positions
        cos, sin = self._compute_cos_sin(self.max_position_embeddings, device=torch.device("cpu"))
        target_device = device or torch.device("cpu")
        self.register_buffer("cos_cache", cos.to(target_device), persistent=False)
        self.register_buffer("sin_cache", sin.to(target_device), persistent=False)

    def _compute_concentration_and_inv_freq(
        self, device: torch.device | None = None
    ) -> tuple[float, torch.Tensor]:
        device = device or self.device
        freq = self.base ** (
            torch.arange(0, self.head_dim, 2, dtype=torch.float, device=device) / self.head_dim
        )
        if self.scaling_factor > 1.0:
            concentration = 0.1 * math.log(self.scaling_factor) + 1.0
            d_half = self.head_dim / 2
            low = (
                d_half
                * math.log(self.initial_context_length / (self.ntk_beta * 2 * math.pi))
                / math.log(self.base)
            )
            high = (
                d_half
                * math.log(self.initial_context_length / (self.ntk_alpha * 2 * math.pi))
                / math.log(self.base)
            )
            interpolation = 1.0 / (self.scaling_factor * freq)
            extrapolation = 1.0 / freq
            ramp = (torch.arange(d_half, dtype=torch.float32, device=freq.device) - low) / (
                high - low
            )
            mask = 1 - ramp.clamp(0, 1)
            inv_freq = interpolation * (1 - mask) + extrapolation * mask
        else:
            concentration = 1.0
            inv_freq = 1.0 / freq
        return concentration, inv_freq

    def _compute_cos_sin(
        self, num_tokens: int, device: torch.device | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        concentration, inv_freq = self._compute_concentration_and_inv_freq(device=device)
        device = device or self.device
        t = torch.arange(num_tokens, dtype=torch.float32, device=device)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        cos = freqs.cos() * concentration
        sin = freqs.sin() * concentration
        return cos.to(self.dtype), sin.to(self.dtype)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        num_tokens = query.shape[0]
        if num_tokens > self.cos_cache.shape[0]:
            cos, sin = self._compute_cos_sin(num_tokens, device=torch.device("cpu"))
            self.cos_cache = cos.to(query.device)
            self.sin_cache = sin.to(query.device)
        if self.cos_cache.device != query.device:
            cos_cache = self.cos_cache.to(query.device)
            sin_cache = self.sin_cache.to(query.device)
        else:
            cos_cache = self.cos_cache
            sin_cache = self.sin_cache
        cos = cos_cache[:num_tokens]
        sin = sin_cache[:num_tokens]
        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_dim)
        query = apply_rope(query, cos, sin)
        query = query.reshape(query_shape)
        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_dim)
        key = apply_rope(key, cos, sin)
        key = key.reshape(key_shape)
        return query, key
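

# The scaling branch above is a YaRN-style "NTK-by-parts" scheme: low-frequency
# dimensions (long wavelengths) use position interpolation (1 / (s * freq)),
# high-frequency dimensions keep the original frequencies (1 / freq), and a
# linear ramp between the ntk_beta and ntk_alpha cutoffs blends the two. The
# `concentration` factor (0.1 * ln(s) + 1) scales cos/sin to compensate for
# the attention-entropy shift that context extension introduces.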


def sdpa(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    S: torch.Tensor,
    sm_scale: float,
    context_size: int,
) -> torch.Tensor:
    num_tokens, num_heads, q_mult, head_dim = Q.shape
    window = 2 * context_size + 1
    Kp = F.pad(K, (0, 0, 0, 0, context_size, context_size))
    Vp = F.pad(V, (0, 0, 0, 0, context_size, context_size))
    Kwin = Kp.unfold(0, window, 1).permute(0, 3, 1, 2)
    Vwin = Vp.unfold(0, window, 1).permute(0, 3, 1, 2)
    idx = torch.arange(window, device=Q.device) - context_size
    pos = torch.arange(num_tokens, device=Q.device)[:, None] + idx[None, :]
    valid = (pos >= 0) & (pos < num_tokens)
    scores = torch.einsum("nhqd,nwhd->nhqw", Q, Kwin).float()
    scores *= sm_scale
    scores = scores.masked_fill(~valid[:, None, None, :], -float("inf"))
    sink_scores = (S * math.log(2.0)).reshape(num_heads, q_mult)
    sink_scores = sink_scores[None, :, :, None].expand(num_tokens, -1, -1, 1)
    scores = torch.cat([scores, sink_scores], dim=-1)
    weights = torch.softmax(scores, dim=-1)[..., :-1].to(V.dtype)
    attn = torch.einsum("nhqw,nwhd->nhqd", weights, Vwin)
    return attn.reshape(num_tokens, -1)
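

# sdpa implements symmetric sliding-window attention: each token attends to a
# window of (2 * context_size + 1) neighbours, plus one learned "sink" logit
# per head that soaks up probability mass without contributing a value vector.
# Illustrative shapes for context_size=2 and 10 tokens:
#
#     K padded to [14, h, d], unfolded into windows Kwin: [10, 5, h, d]
#     scores: [10, h, q_mult, 5] -> cat sink -> [10, h, q_mult, 6]
#
# The softmax runs over the last dim, then the sink column is dropped before
# the weighted sum over values.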


class AttentionBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        param_dtype = torch.bfloat16
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.bidirectional_context_size = int(config.bidirectional_context_size)
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads, device=device, dtype=torch.float32)
        )
        self.norm = RMSNorm(config.hidden_size, device=device)
        qkv_dim = config.head_dim * (config.num_attention_heads + 2 * config.num_key_value_heads)
        self.qkv = torch.nn.Linear(config.hidden_size, qkv_dim, device=device, dtype=param_dtype)
        self.out = torch.nn.Linear(
            config.head_dim * config.num_attention_heads,
            config.hidden_size,
            device=device,
            dtype=param_dtype,
        )
        self.qk_scale = 1 / math.sqrt(math.sqrt(config.head_dim))
        self.sm_scale = 1.0
        self.rope = RotaryEmbedding(
            config.head_dim,
            int(config.rope_theta),
            torch.float32,
            initial_context_length=config.initial_context_length,
            scaling_factor=config.rope_scaling_factor,
            ntk_alpha=config.rope_ntk_alpha,
            ntk_beta=config.rope_ntk_beta,
            device=device,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        t = self.norm(x)
        if t.dtype != self.qkv.weight.dtype:
            t = t.to(self.qkv.weight.dtype)
        qkv = F.linear(t, self.qkv.weight, self.qkv.bias)
        query = qkv[:, : self.num_attention_heads * self.head_dim].contiguous()
        key = qkv[
            :,
            self.num_attention_heads * self.head_dim : (
                self.num_attention_heads + self.num_key_value_heads
            )
            * self.head_dim,
        ].contiguous()
        value = qkv[
            :,
            (self.num_attention_heads + self.num_key_value_heads) * self.head_dim : (
                self.num_attention_heads + 2 * self.num_key_value_heads
            )
            * self.head_dim,
        ].contiguous()
        query, key = self.rope(query, key)
        query = query * self.qk_scale
        key = key * self.qk_scale
        sinks = self.sinks
        num_tokens = query.shape[0]
        query = query.view(
            num_tokens,
            self.num_key_value_heads,
            self.num_attention_heads // self.num_key_value_heads,
            self.head_dim,
        )
        key = key.view(num_tokens, self.num_key_value_heads, self.head_dim)
        value = value.view(num_tokens, self.num_key_value_heads, self.head_dim)
        attn_out = sdpa(
            query,
            key,
            value,
            sinks,
            self.sm_scale,
            self.bidirectional_context_size,
        )
        if attn_out.dtype != self.out.weight.dtype:
            attn_out = attn_out.to(self.out.weight.dtype)
        proj_bias = self.out.bias
        proj = F.linear(attn_out, self.out.weight, proj_bias)
        return x + proj.to(x.dtype)


def swiglu(
    x: torch.Tensor,
    alpha: float = 1.702,
    limit: float = 7.0,
) -> torch.Tensor:
    x_glu, x_linear = x.chunk(2, dim=-1)
    x_glu = x_glu.clamp(min=None, max=limit)
    x_linear = x_linear.clamp(min=-limit, max=limit)
    out_glu = x_glu * torch.sigmoid(alpha * x_glu)
    return out_glu * (x_linear + 1)
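

# This is a clamped SwiGLU variant: the gate path is capped above at `limit`,
# the linear path is clamped to [-limit, limit], and the +1 offset means a
# zeroed linear projection still passes the gated activation through.
# sigmoid(1.702 * x) is the common fast approximation to the GELU gate.
# Worked example at x_glu = x_linear = 2, limit = 7:
#
#     out = (2 * sigmoid(3.404)) * (2 + 1) ~= 1.935 * 3 ~= 5.81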


class MLPBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        param_dtype = torch.bfloat16
        self.num_experts = config.num_experts
        self.experts_per_token = config.experts_per_token
        self.swiglu_limit = 7.0
        self.norm = RMSNorm(config.hidden_size, device=device)
        self.gate = torch.nn.Linear(
            config.hidden_size, config.num_experts, device=device, dtype=param_dtype
        )
        self.mlp1_weight = torch.nn.Parameter(
            torch.empty(
                (config.num_experts, config.hidden_size, config.intermediate_size * 2),
                device=device,
                dtype=param_dtype,
            )
        )
        self.mlp1_bias = torch.nn.Parameter(
            torch.empty(
                (config.num_experts, config.intermediate_size * 2),
                device=device,
                dtype=param_dtype,
            )
        )
        self.mlp2_weight = torch.nn.Parameter(
            torch.empty(
                (config.num_experts, config.intermediate_size, config.hidden_size),
                device=device,
                dtype=param_dtype,
            )
        )
        self.mlp2_bias = torch.nn.Parameter(
            torch.empty(
                (config.num_experts, config.hidden_size),
                device=device,
                dtype=param_dtype,
            )
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = self.norm(x)
        gate_scores = F.linear(t.float(), self.gate.weight.float(), self.gate.bias.float())
        experts = torch.topk(gate_scores, k=self.experts_per_token, dim=-1, sorted=True)
        expert_weights = torch.softmax(experts.values, dim=-1) / self.experts_per_token
        expert_indices = experts.indices
        experts_per_token_eff = self.experts_per_token

        def _moe_chunk(
            t_chunk: torch.Tensor,
            expert_indices_chunk: torch.Tensor,
            expert_weights_chunk: torch.Tensor,
        ) -> torch.Tensor:
            mlp1_weight = self.mlp1_weight[expert_indices_chunk].float()
            mlp1_bias = self.mlp1_bias[expert_indices_chunk].float()
            t_expanded = t_chunk.float().unsqueeze(1).expand(-1, expert_indices_chunk.shape[1], -1)
            out = expert_linear(
                t_expanded,
                mlp1_weight,
                mlp1_bias,
            )
            out = swiglu(out, limit=self.swiglu_limit)
            mlp2_weight = self.mlp2_weight[expert_indices_chunk].float()
            mlp2_bias = self.mlp2_bias[expert_indices_chunk].float()
            out = expert_linear(
                out.float(),
                mlp2_weight,
                mlp2_bias,
            )
            if out.dtype != expert_weights_chunk.dtype:
                out = out.to(expert_weights_chunk.dtype)
            out = torch.einsum("bec,be->bc", out, expert_weights_chunk)
            out = out * experts_per_token_eff
            return out.to(x.dtype)

        torch_ops_chunk_size = 32
        if t.shape[0] > torch_ops_chunk_size:
            chunks = []
            for start in range(0, t.shape[0], torch_ops_chunk_size):
                end = start + torch_ops_chunk_size
                chunks.append(
                    _moe_chunk(
                        t[start:end],
                        expert_indices[start:end],
                        expert_weights[start:end],
                    )
                )
            t = torch.cat(chunks, dim=0)
        else:
            t = _moe_chunk(t, expert_indices, expert_weights)
        return x + t


class TransformerBlock(torch.nn.Module):
    def __init__(
        self,
        config: ModelConfig,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        self.attn = AttentionBlock(config, device=device)
        self.mlp = MLPBlock(config, device=device)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        x = self.attn(x)
        return self.mlp(x)


class Checkpoint:
    @staticmethod
    def build_param_name_map(
        num_hidden_layers: int,
    ) -> dict[str, str]:
        return (
            {
                f"block.{n}.mlp.mlp1_bias": f"block.{n}.mlp.swiglu.bias"
                for n in range(num_hidden_layers)
            }
            | {
                f"block.{n}.mlp.mlp1_weight": f"block.{n}.mlp.swiglu.weight"
                for n in range(num_hidden_layers)
            }
            | {
                f"block.{n}.mlp.mlp2_bias": f"block.{n}.mlp.out.bias"
                for n in range(num_hidden_layers)
            }
            | {
                f"block.{n}.mlp.mlp2_weight": f"block.{n}.mlp.out.weight"
                for n in range(num_hidden_layers)
            }
        )
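
    # The map translates this module's parameter names into the names stored
    # in the checkpoint shards, e.g. (illustrative, layer 0):
    #
    #     "block.0.mlp.mlp1_weight" -> "block.0.mlp.swiglu.weight"
    #     "block.0.mlp.mlp2_bias"   -> "block.0.mlp.out.bias"
    #
    # Names not present in the map (attention, norms, embeddings) pass through
    # unchanged in `get`.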

    def __init__(self, path: str | Path, device: torch.device, num_hidden_layers: int) -> None:
        self.param_name_map = self.build_param_name_map(num_hidden_layers)
        self.device_str = device.type if device.index is None else f"{device.type}:{device.index}"
        safetensor_files = [
            os.path.join(path, filename)
            for filename in os.listdir(path)
            if filename.endswith(".safetensors")
        ]
        tensor_name_to_file: dict[str, str] = {}
        for safetensor_file in safetensor_files:
            with safe_open(safetensor_file, framework="pt", device=self.device_str) as handle:
                for key in handle.keys():
                    prior_file = tensor_name_to_file.get(key)
                    if prior_file is not None:
                        raise ValueError(
                            "Duplicate tensor name in checkpoint shards: "
                            f"{key!r} appears in {prior_file!r} and {safetensor_file!r}"
                        )
                    tensor_name_to_file[key] = safetensor_file
        self.tensor_name_to_file = tensor_name_to_file

    def get(self, name: str) -> torch.Tensor:
        mapped = self.param_name_map.get(name, name)
        return self._get_tensor(mapped)

    def _get_tensor(self, name: str) -> torch.Tensor:
        if name not in self.tensor_name_to_file:
            raise KeyError(f"Tensor {name!r} not found in checkpoint")
        with safe_open(
            self.tensor_name_to_file[name], framework="pt", device=self.device_str
        ) as handle:
            return handle.get_tensor(name)


class Transformer(torch.nn.Module):
    def __init__(self, config: ModelConfig, device: torch.device) -> None:
        super().__init__()
        param_dtype = torch.bfloat16
        self.embedding = torch.nn.Embedding(
            config.vocab_size, config.hidden_size, device=device, dtype=param_dtype
        )
        self.block = torch.nn.ModuleList(
            [
                TransformerBlock(config, device=device)
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, device=device)
        self.unembedding = torch.nn.Linear(
            config.hidden_size,
            config.num_labels,
            bias=False,
            device=device,
            dtype=param_dtype,
        )

    def forward(
        self,
        token_ids: torch.Tensor,
    ) -> torch.Tensor:
        x = self.embedding(token_ids)
        for block in self.block:
            x = block(x)
        x = self.norm(x)
        x = F.linear(x, self.unembedding.weight, None)
        return x

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_dir: str | Path,
        *,
        device: torch.device,
    ) -> "Transformer":
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        torch.set_float32_matmul_precision("highest")
        config_path = Path(checkpoint_dir) / "config.json"
        with config_path.open("r", encoding="utf-8") as handle:
            checkpoint_config = json.load(handle)
        if not isinstance(checkpoint_config, dict):
            raise ValueError(f"Invalid checkpoint config payload at {config_path}")
        validate_model_config_contract(
            checkpoint_config,
            context=str(config_path),
        )
        config = ModelConfig.from_checkpoint_config(
            checkpoint_config,
            context=str(config_path),
        )
        checkpoint = Checkpoint(
            checkpoint_dir,
            device,
            num_hidden_layers=config.num_hidden_layers,
        )
        model = cls(config=config, device=device)
        model.eval()
        for name, param in model.named_parameters():
            loaded_tensor = checkpoint.get(name)
            if param.data.shape != loaded_tensor.shape:
                raise ValueError(
                    f"Tensor shape mismatch for {name!r}: expected {tuple(param.data.shape)}, "
                    f"got {tuple(loaded_tensor.shape)}"
                )
            param.data.copy_(loaded_tensor)
        return model


@dataclass
class LabelInfo:
    boundary_label_lookup: dict[str, dict[str, int]]
    token_to_span_label: dict[int, int]
    token_boundary_tags: dict[int, str | None]
    span_class_names: tuple[str, ...]
    span_label_lookup: dict[str, int]
    background_token_label: int
    background_span_label: int


def labels_to_spans(
    labels_by_index: dict[int, int], label_info: LabelInfo
) -> list[tuple[int, int, int]]:
    spans: list[tuple[int, int, int]] = []
    current_label: int | None = None
    start_idx: int | None = None
    previous_idx: int | None = None
    background_span_label = label_info.background_span_label
    for token_idx in sorted(labels_by_index):
        label_id = labels_by_index[token_idx]
        span_label = label_info.token_to_span_label.get(label_id)
        boundary_tag = label_info.token_boundary_tags.get(label_id)
        if previous_idx is not None and token_idx != previous_idx + 1:
            if current_label is not None and start_idx is not None:
                spans.append((current_label, start_idx, previous_idx + 1))
            current_label = None
            start_idx = None
        if span_label is None:
            previous_idx = token_idx
            continue
        if span_label == background_span_label:
            if current_label is not None and start_idx is not None:
                spans.append((current_label, start_idx, token_idx))
            current_label = None
            start_idx = None
            previous_idx = token_idx
            continue
        if boundary_tag == "S":
            if current_label is not None and start_idx is not None and previous_idx is not None:
                spans.append((current_label, start_idx, previous_idx + 1))
            spans.append((span_label, token_idx, token_idx + 1))
            current_label = None
            start_idx = None
        elif boundary_tag == "B":
            if current_label is not None and start_idx is not None and previous_idx is not None:
                spans.append((current_label, start_idx, previous_idx + 1))
            current_label = span_label
            start_idx = token_idx
        elif boundary_tag == "I":
            if current_label is None or current_label != span_label:
                if current_label is not None and start_idx is not None and previous_idx is not None:
                    spans.append((current_label, start_idx, previous_idx + 1))
                current_label = span_label
                start_idx = token_idx
        elif boundary_tag == "E":
            if current_label is None or current_label != span_label or start_idx is None:
                if current_label is not None and start_idx is not None and previous_idx is not None:
                    spans.append((current_label, start_idx, previous_idx + 1))
                spans.append((span_label, token_idx, token_idx + 1))
                current_label = None
                start_idx = None
            else:
                spans.append((current_label, start_idx, token_idx + 1))
                current_label = None
                start_idx = None
        else:
            if current_label is not None and start_idx is not None and previous_idx is not None:
                spans.append((current_label, start_idx, previous_idx + 1))
            current_label = None
            start_idx = None
        previous_idx = token_idx
    if current_label is not None and start_idx is not None and previous_idx is not None:
        spans.append((current_label, start_idx, previous_idx + 1))
    return spans
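

# BIOES decoding example (illustrative label ids; suppose span class 1 maps to
# token labels B=1, I=2, E=3, S=4 and 0 is the background "O"):
#
#     labels_by_index = {0: 0, 1: 1, 2: 2, 3: 3, 4: 0, 5: 4}
#     labels_to_spans(labels_by_index, label_info)
#     # -> [(1, 1, 4), (1, 5, 6)]
#
# i.e. B-I-E over tokens 1..3 yields one span [1, 4), and the singleton S tag
# at token 5 yields [5, 6). Span ends are exclusive token indices.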


def token_spans_to_char_spans(
    spans: Sequence[tuple[int, int, int]],
    char_starts: Sequence[int],
    char_ends: Sequence[int],
) -> list[tuple[int, int, int]]:
    converted: list[tuple[int, int, int]] = []
    for label_idx, token_start, token_end in spans:
        if not (0 <= token_start < token_end <= len(char_starts)):
            continue
        char_start = char_starts[token_start]
        char_end = char_ends[token_end - 1]
        if char_end <= char_start:
            continue
        converted.append((label_idx, char_start, char_end))
    return converted


def trim_char_spans_whitespace(
    spans: Sequence[tuple[int, int, int]],
    text: str,
) -> list[tuple[int, int, int]]:
    trimmed: list[tuple[int, int, int]] = []
    for label_idx, start, end in spans:
        if not (0 <= start < end <= len(text)):
            continue
        while start < end and text[start].isspace():
            start += 1
        while end > start and text[end - 1].isspace():
            end -= 1
        if end > start:
            trimmed.append((label_idx, start, end))
    return trimmed


@dataclass
class InferenceRuntime:
    model: Transformer
    encoding: tiktoken.Encoding
    label_info: LabelInfo
    device: torch.device
    n_ctx: int


def get_viterbi_transition_biases() -> dict[str, float]:
    calibration_path = MODEL_DIR / "viterbi_calibration.json"
    default_biases = {key: 0.0 for key in VITERBI_TRANSITION_BIAS_KEYS}
    if not calibration_path.is_file():
        return default_biases
    payload = json.loads(calibration_path.read_text(encoding="utf-8"))
    if not isinstance(payload, dict):
        raise ValueError(f"Invalid Viterbi calibration payload at {calibration_path}")
    raw_biases: object = payload
    operating_points = payload.get("operating_points")
    if operating_points is not None:
        if not isinstance(operating_points, dict):
            raise ValueError(f"Invalid operating_points payload at {calibration_path}")
        preset_entry = operating_points.get(DEFAULT_VITERBI_CALIBRATION_PRESET)
        if not isinstance(preset_entry, dict):
            raise ValueError(
                f"Missing operating_points.{DEFAULT_VITERBI_CALIBRATION_PRESET!s} "
                f"in {calibration_path}"
            )
        raw_biases = preset_entry.get("biases")
    if not isinstance(raw_biases, dict):
        raise ValueError(f"Invalid Viterbi bias payload at {calibration_path}")
    resolved_biases: dict[str, float] = {}
    for key in VITERBI_TRANSITION_BIAS_KEYS:
        raw_value = raw_biases.get(key)
        if isinstance(raw_value, bool) or not isinstance(raw_value, (int, float)):
            raise ValueError(f"Missing or invalid {key!r} in {calibration_path}")
        resolved_biases[key] = float(raw_value)
    return resolved_biases
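

# Both accepted calibration layouts, sketched with hypothetical numbers (real
# values, if present, ship with the checkpoint):
#
#     {"transition_bias_background_stay": 0.5, ...}           # flat layout
#
#     {"operating_points": {"default": {"biases": {           # preset layout
#         "transition_bias_background_stay": 0.5, ...}}}}
#
# When viterbi_calibration.json is absent, every bias defaults to 0.0 and the
# decoder reduces to hard BIOES transition constraints.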


# Loading the checkpoint is expensive; cache the assembled runtime so repeated
# predictions reuse the same model, tokenizer, and label tables.
@functools.lru_cache(maxsize=1)
def get_runtime() -> InferenceRuntime:
    checkpoint = MODEL_DIR
    if not checkpoint.exists() or not checkpoint.is_dir():
        raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint}")
    if not any(checkpoint.glob("*.safetensors")):
        raise FileNotFoundError(f"Checkpoint directory has no .safetensors files: {checkpoint}")
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    config_path = checkpoint / "config.json"
    checkpoint_config = json.loads(config_path.read_text(encoding="utf-8"))
    if not isinstance(checkpoint_config, dict):
        raise ValueError(f"Invalid checkpoint config payload at {config_path}")
    validate_model_config_contract(
        checkpoint_config,
        context=str(config_path),
    )
    ner_class_names = NER_CLASS_NAMES
    device = torch.device("cuda")
    n_ctx = int(checkpoint_config["default_n_ctx"])
    encoding = tiktoken.get_encoding(str(checkpoint_config["encoding"]).strip())
    span_class_names: list[str] = [BACKGROUND_CLASS_LABEL]
    span_label_lookup: dict[str, int] = {BACKGROUND_CLASS_LABEL: 0}
    boundary_label_lookup: dict[str, dict[str, int]] = {}
    token_to_span_label: dict[int, int] = {}
    token_boundary_tags: dict[int, str | None] = {}
    background_idx: int | None = None
    for idx, name in enumerate(ner_class_names):
        if name == BACKGROUND_CLASS_LABEL:
            background_idx = idx
            token_to_span_label[idx] = span_label_lookup[BACKGROUND_CLASS_LABEL]
            token_boundary_tags[idx] = None
            continue
        boundary, base_label = name.split("-", 1)
        span_idx = span_label_lookup.get(base_label)
        if span_idx is None:
            span_idx = len(span_class_names)
            span_class_names.append(base_label)
            span_label_lookup[base_label] = span_idx
        token_to_span_label[idx] = span_idx
        token_boundary_tags[idx] = boundary
        boundary_label_lookup.setdefault(base_label, {})[boundary] = idx
    if background_idx is None:
        raise ValueError("Class names must include background label 'O'")
    for base_label, mapping in boundary_label_lookup.items():
        missing = set(BOUNDARY_PREFIXES) - set(mapping)
        if missing:
            raise ValueError(
                f"Missing boundary classes {sorted(missing)} for base label {base_label}"
            )
    label_info = LabelInfo(
        boundary_label_lookup={key: dict(value) for key, value in boundary_label_lookup.items()},
        token_to_span_label=dict(token_to_span_label),
        token_boundary_tags=dict(token_boundary_tags),
        span_class_names=tuple(span_class_names),
        span_label_lookup=dict(span_label_lookup),
        background_token_label=background_idx,
        background_span_label=span_label_lookup[BACKGROUND_CLASS_LABEL],
    )
    model = Transformer.from_checkpoint(
        checkpoint,
        device=device,
    )
    return InferenceRuntime(
        model=model,
        encoding=encoding,
        label_info=label_info,
        device=device,
        n_ctx=n_ctx,
    )


class Decoder:
    def __init__(self, label_info: LabelInfo) -> None:
        self.label_info = label_info
        num_classes = len(label_info.token_to_span_label)
        self._start_scores = torch.full((num_classes,), -1e9, dtype=torch.float32)
        self._end_scores = torch.full((num_classes,), -1e9, dtype=torch.float32)
        self._transition_scores = torch.full((num_classes, num_classes), -1e9, dtype=torch.float32)
        transition_biases = get_viterbi_transition_biases()
        background_token_idx = label_info.background_token_label
        background_span_idx = label_info.background_span_label
        token_boundary_tags = label_info.token_boundary_tags
        token_to_span_label = label_info.token_to_span_label
        for idx in range(num_classes):
            tag = token_boundary_tags.get(idx)
            span_label = token_to_span_label.get(idx)
            if tag in {"B", "S"} or idx == background_token_idx:
                self._start_scores[idx] = 0.0
            if tag in {"E", "S"} or idx == background_token_idx:
                self._end_scores[idx] = 0.0
            for next_idx in range(num_classes):
                next_tag = token_boundary_tags.get(next_idx)
                next_span_label = token_to_span_label.get(next_idx)
                if self._is_valid_transition(
                    prev_tag=tag,
                    prev_span=span_label,
                    next_tag=next_tag,
                    next_span=next_span_label,
                    background_token_idx=background_token_idx,
                    background_span_idx=background_span_idx,
                    next_idx=next_idx,
                ):
                    self._transition_scores[idx, next_idx] = self._transition_bias(
                        prev_tag=tag,
                        prev_span=span_label,
                        next_tag=next_tag,
                        next_span=next_span_label,
                        background_span_idx=background_span_idx,
                        biases=transition_biases,
                    )

    @staticmethod
    def _is_valid_transition(
        *,
        prev_tag: str | None,
        prev_span: int | None,
        next_tag: str | None,
        next_span: int | None,
        background_token_idx: int,
        background_span_idx: int,
        next_idx: int,
    ) -> bool:
        next_is_background = next_span == background_span_idx or next_idx == background_token_idx
        if (next_span is None or next_tag is None) and not next_is_background:
            return False
        if prev_span is None or prev_tag is None:
            return next_is_background or next_tag in {"B", "S"}
        prev_is_background = prev_span == background_span_idx
        if prev_is_background or prev_tag in {"E", "S"}:
            return next_is_background or next_tag in {"B", "S"}
        if prev_tag in {"B", "I"}:
            return prev_span == next_span and next_tag in {"I", "E"}
        return False

    @staticmethod
    def _transition_bias(
        *,
        prev_tag: str | None,
        prev_span: int | None,
        next_tag: str | None,
        next_span: int | None,
        background_span_idx: int,
        biases: dict[str, float],
    ) -> float:
        next_is_background = next_span == background_span_idx
        prev_is_background = prev_span == background_span_idx
        if prev_is_background:
            return (
                biases["transition_bias_background_stay"]
                if next_is_background
                else biases["transition_bias_background_to_start"]
            )
        if prev_tag in {"B", "I"}:
            return (
                biases["transition_bias_inside_to_continue"]
                if next_tag == "I"
                else biases["transition_bias_inside_to_end"]
            )
        return (
            biases["transition_bias_end_to_background"]
            if next_is_background
            else biases["transition_bias_end_to_start"]
        )

    def decode(self, token_logprobs: torch.Tensor) -> list[int]:
        if token_logprobs.ndim != 2:
            raise ValueError("token_logprobs must have shape [seq_len, num_classes]")
        seq_len, num_classes = token_logprobs.shape
        if seq_len == 0:
            return []
        start_scores = self._start_scores.to(
            device=token_logprobs.device,
            dtype=token_logprobs.dtype,
        )
        end_scores = self._end_scores.to(
            device=token_logprobs.device,
            dtype=token_logprobs.dtype,
        )
        transition_scores = self._transition_scores.to(
            device=token_logprobs.device,
            dtype=token_logprobs.dtype,
        )
        scores = token_logprobs[0] + start_scores
        backpointers = torch.empty(
            (seq_len - 1, num_classes),
            device=token_logprobs.device,
            dtype=torch.int64,
        )
        for idx in range(1, seq_len):
            transitions = scores.unsqueeze(1) + transition_scores
            best_scores, best_paths = transitions.max(dim=0)
            scores = best_scores + token_logprobs[idx]
            backpointers[idx - 1] = best_paths
        if not torch.isfinite(scores).any():
            return token_logprobs.argmax(dim=1).tolist()
        scores = scores + end_scores
        last_label = scores.argmax()
        path = torch.empty((seq_len,), device=token_logprobs.device, dtype=torch.int64)
        path[-1] = last_label
        for idx in range(seq_len - 2, -1, -1):
            last_label = backpointers[idx, last_label]
            path[idx] = last_label
        return path.tolist()
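

# Viterbi recurrence used in Decoder.decode, written out for one step:
#
#     score[t][j] = max_i(score[t-1][i] + transition[i][j]) + logprob[t][j]
#
# Invalid BIOES transitions carry a -1e9 score, so the max effectively searches
# only well-formed tag sequences; the calibrated biases nudge the remaining
# valid transitions. If every path is pruned (all scores non-finite), decode
# falls back to per-token argmax.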


@torch.inference_mode()
def predict_text(
    runtime: InferenceRuntime,
    text: str,
    decoder: Decoder,
) -> tuple[str, list[dict[str, object]]]:
    token_ids = tuple(int(token) for token in runtime.encoding.encode(text, allowed_special="all"))
    if not token_ids:
        return text, []
    if runtime.n_ctx <= 0:
        raise ValueError("runtime.n_ctx must be positive")
    token_score_vectors: list[torch.Tensor] = []
    for start in range(0, len(token_ids), runtime.n_ctx):
        end = min(start + runtime.n_ctx, len(token_ids))
        window_tokens = torch.tensor(token_ids[start:end], device=runtime.device, dtype=torch.int32)
        logits = runtime.model(window_tokens)
        log_probs = F.log_softmax(logits.float(), dim=-1)
        if log_probs.shape[0] != window_tokens.shape[0]:
            raise ValueError("Logprob output length does not match window length")
        token_score_vectors.extend(log_probs.unbind(0))
    if not token_score_vectors:
        return text, []
    stacked_scores = torch.stack(token_score_vectors, dim=0)
    decoded_labels = decoder.decode(stacked_scores)
    if len(decoded_labels) != len(token_ids):
        decoded_labels = stacked_scores.argmax(dim=1).tolist()
    predicted_labels_by_index = {
        token_idx: int(label) for token_idx, label in enumerate(decoded_labels)
    }
    predicted_token_spans = labels_to_spans(predicted_labels_by_index, runtime.label_info)
    token_bytes = [runtime.encoding.decode_single_token_bytes(token_id) for token_id in token_ids]
    decoded_text = b"".join(token_bytes).decode("utf-8", errors="replace")
    char_byte_starts: list[int] = []
    char_byte_ends: list[int] = []
    byte_cursor = 0
    for ch in decoded_text:
        char_byte_starts.append(byte_cursor)
        byte_cursor += len(ch.encode("utf-8"))
        char_byte_ends.append(byte_cursor)
    char_starts: list[int] = []
    char_ends: list[int] = []
    token_byte_cursor = 0
    for raw_bytes in token_bytes:
        token_byte_start = token_byte_cursor
        token_byte_end = token_byte_start + len(raw_bytes)
        token_byte_cursor = token_byte_end
        start_idx = bisect_right(char_byte_ends, token_byte_start)
        end_idx = bisect_left(char_byte_starts, token_byte_end)
        if end_idx < start_idx:
            end_idx = start_idx
        char_starts.append(start_idx)
        char_ends.append(end_idx)
    if char_ends and char_ends[-1] != len(decoded_text):
        raise ValueError(
            f"Character length mismatch for decoded text "
            f"(tokens={char_ends[-1]}, text={len(decoded_text)})"
        )
    decoded_mismatch = decoded_text != text
    source_text = decoded_text if decoded_mismatch else text
    predicted_char_spans = token_spans_to_char_spans(
        predicted_token_spans,
        char_starts,
        char_ends,
    )
    predicted_char_spans = trim_char_spans_whitespace(predicted_char_spans, source_text)
    detected: list[dict[str, object]] = []
    for label_idx, start, end in predicted_char_spans:
        if not (0 <= start < end <= len(source_text)):
            continue
        label = (
            runtime.label_info.span_class_names[label_idx]
            if 0 <= label_idx < len(runtime.label_info.span_class_names)
            else f"label_{label_idx}"
        )
        detected.append(
            {
                "entity": label,
                "start": int(start),
                "end": int(end),
            }
        )
    return source_text, detected
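

# Token-to-character alignment above works at the byte level because tiktoken
# tokens are byte sequences that can split multi-byte UTF-8 characters.
# Illustrative walkthrough: if a token covers bytes [5, 9) of the decoded text,
#
#     start_idx = bisect_right(char_byte_ends, 5)   # first char ending after 5
#     end_idx   = bisect_left(char_byte_starts, 9)  # chars starting before 9
#
# yields the half-open character range the token overlaps, so spans stay valid
# even when a character's bytes are shared across neighbouring tokens.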


# ZeroGPU Spaces allocate a GPU only inside functions decorated with spaces.GPU.
@spaces.GPU
def predict(text: str) -> dict[str, object]:
    text = text or ""
    if not text.strip():
        return EMPTY_HIGHLIGHT_PAYLOAD
    runtime = get_runtime()
    decoder = Decoder(label_info=runtime.label_info)
    filtered_text, spans = predict_text(runtime, text, decoder)
    return {
        "text": filtered_text,
        "entities": spans,
    }


def build_redacted_text(text: str, entities: Sequence[dict[str, object]]) -> str:
    if not text or not entities:
        return text
    redacted_parts: list[str] = []
    cursor = 0
    sorted_entities = sorted(
        entities,
        key=lambda item: (
            int(item.get("start", 0)),
            int(item.get("end", 0)),
        ),
    )
    for entity in sorted_entities:
        start_raw = entity.get("start")
        end_raw = entity.get("end")
        label_raw = entity.get("entity")
        if not isinstance(start_raw, int) or not isinstance(end_raw, int):
            continue
        if not isinstance(label_raw, str):
            continue
        if start_raw < cursor or start_raw >= end_raw:
            continue
        start = max(0, min(start_raw, len(text)))
        end = max(0, min(end_raw, len(text)))
        if start < cursor or start >= end:
            continue
        redacted_parts.append(text[cursor:start])
        replacement = REDACTION_LABEL_MAP.get(label_raw, "[REDACTED]")
        redacted_parts.append(replacement)
        cursor = end
    redacted_parts.append(text[cursor:])
    return "".join(redacted_parts)


def summarize_entities_markdown(entities: Sequence[dict[str, object]]) -> str:
    if not entities:
        return EMPTY_SUMMARY_MARKDOWN
    counts: dict[str, int] = {}
    for entity in entities:
        label = entity.get("entity")
        if not isinstance(label, str):
            continue
        counts[label] = counts.get(label, 0) + 1
    if not counts:
        return EMPTY_SUMMARY_MARKDOWN
    ordered_labels = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    lines = ["**Detected entities**"]
    lines.extend(f"- `{label}`: {count}" for label, count in ordered_labels)
    return "\n".join(lines)


def predict_for_demo(text: str) -> tuple[dict[str, object], str, str]:
    prediction = predict(text)
    detected = prediction.get("entities")
    source_text = prediction.get("text")
    entities = detected if isinstance(detected, list) else []
    display_text = source_text if isinstance(source_text, str) else (text or "")
    redacted_text = build_redacted_text(display_text, entities)
    summary = summarize_entities_markdown(entities)
    return prediction, redacted_text, summary


def build_demo() -> gr.Blocks:
    config_path = MODEL_DIR / "config.json"
    checkpoint_config = json.loads(config_path.read_text(encoding="utf-8"))
    if not isinstance(checkpoint_config, dict):
        raise ValueError(f"Invalid checkpoint config payload at {config_path}")
    validate_model_config_contract(
        checkpoint_config,
        context=str(config_path),
    )
    span_class_names = SPAN_CLASS_NAMES
    web_color_palette = (
        "#e6194b",
        "#3cb44b",
        "#4363d8",
        "#f58231",
        "#911eb4",
        "#008080",
        "#9a6324",
        "#f032e6",
        "#b59f00",
        "#800000",
        "#000075",
        "#808080",
    )
    with gr.Blocks(
        **supported_kwargs(
            gr.Blocks,
            title="OpenAI Privacy Filter",
            fill_width=True,
            elem_id="privacy-filter-app",
        )
    ) as demo:
        gr.Markdown("# OpenAI Privacy Filter Demo")
        gr.Markdown(
            "Detect and redact personal identifiers using `openai/privacy-filter`.\n\n"
            "This demo highlights predicted spans and generates a redacted text variant "
            "with label placeholders."
        )
        with gr.Column(variant="panel"):
            input_text = gr.Textbox(
                **supported_kwargs(
                    gr.Textbox,
                    lines=6,
                    label="Input text with PII",
                    placeholder=(
                        "Paste text to detect personal identifiers and generate "
                        "redacted output..."
                    ),
                    container=False,
                )
            )
            with gr.Row():
                submit_button = gr.Button("Detect & Redact", variant="primary")
                clear_button = gr.Button("Clear")
        with gr.Column(variant="panel"):
            output_text = gr.HighlightedText(
                **supported_kwargs(
                    gr.HighlightedText,
                    label="Detected entities (highlighted)",
                    value=EMPTY_HIGHLIGHT_PAYLOAD,
                    color_map={
                        label: web_color_palette[idx % len(web_color_palette)]
                        for idx, label in enumerate(
                            label for label in span_class_names if label != BACKGROUND_CLASS_LABEL
                        )
                    },
                    combine_adjacent=False,
                    show_legend=False,
                    container=True,
                )
            )
            redacted_output = gr.Textbox(
                **supported_kwargs(
                    gr.Textbox,
                    label="Redacted text output",
                    lines=6,
                    show_copy_button=True,
                    interactive=False,
                )
            )
            entity_summary = gr.Markdown(EMPTY_SUMMARY_MARKDOWN)
        with gr.Accordion("How to read results", open=False):
            gr.Markdown(
                "- Detects 8 span categories: person, email, phone, address, date, URL, "
                "account number, and secrets.\n"
                "- Uses sequence decoding (BIOES + constrained Viterbi) for cleaner "
                "boundaries.\n"
                "- Best treated as a redaction aid, not a standalone compliance or "
                "anonymization guarantee.\n"
                "- Official card notes strongest support is English, with limited "
                "multilingual robustness."
            )
        submit_button.click(
            fn=predict_for_demo,
            inputs=input_text,
            outputs=[output_text, redacted_output, entity_summary],
            api_name="predict_and_redact",
        )
        input_text.submit(
            fn=predict_for_demo,
            inputs=input_text,
            outputs=[output_text, redacted_output, entity_summary],
        )
        clear_button.click(
            lambda: ("", EMPTY_HIGHLIGHT_PAYLOAD, "", EMPTY_SUMMARY_MARKDOWN),
            outputs=[input_text, output_text, redacted_output, entity_summary],
        )
        gr.Markdown("### Multilingual quick examples")
        gr.Examples(
            examples=[
                ["Alice was born on 1990-01-02 and lives at 1 Main St."],
                ["Email me at alice@example.com or call 415-555-0101."],
                ["Me llamo Laura Gómez y vivo en Calle de Alcalá 21, Madrid."],
                ["Mon e-mail est jean.dupont@example.fr et mon téléphone est +33 6 12 34 56 78."],
                ["私の名前は山田太郎です。メールはtaro.yamada@example.jpです。"],
                ["اسمي أحمد وبريدي هو ahmed@example.com ورقم هاتفي +971501234567."],
            ],
            inputs=input_text,
            outputs=[output_text, redacted_output, entity_summary],
            fn=predict_for_demo,
            cache_examples=False,
        )
    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.launch()