""" HaremPii — 1-layer surgical inference wrapper over OpenAI Privacy Filter. Defines: * HaremPiiForTokenClassification — subclass of OpenAIPrivacyFilterForTokenClassification. Reuses the upstream forward pass and adds eval-time constrained-BIOES Viterbi decoding so `outputs.logits.argmax(-1)` returns the Viterbi path. * HaremPiiModel — encoder alias pinned to HaremPiiConfig. The model class is auto-registered so `AutoModelForTokenClassification.from_pretrained(repo, trust_remote_code=True)` dispatches to us via `config.auto_map` (model_type "haremb_pii"). This file is the released, inference-only copy. It contains no training-related utilities. """ from __future__ import annotations from typing import Optional import torch import torch.nn as nn from transformers import ( AutoConfig, AutoModel, AutoModelForTokenClassification, ) from transformers.models.openai_privacy_filter.modeling_openai_privacy_filter import ( OpenAIPrivacyFilterForTokenClassification, OpenAIPrivacyFilterModel, ) from configuration_haremb_pii import HaremPiiConfig # --------------------------------------------------------------------------- # Constrained BIOES Viterbi (inlined so the checkpoint is self-contained) # --------------------------------------------------------------------------- # Transition rules: # O -> {O, B-X, S-X} # B-X -> {I-X, E-X} # I-X -> {I-X, E-X} # E-X -> {O, B-Y, S-Y} # S-X -> {O, B-Y, S-Y} # Initial state allows {O, B-X, S-X} only. 
def _parse_bioes(label: str):
    """Split a BIOES tag into (prefix, category).

    "B-PER" -> ("B", "PER"); "O" (or any label without a dash) -> ("O", None).
    """
    if label == "O" or "-" not in label:
        return "O", None
    pref, cat = label.split("-", 1)
    return pref, cat


def _build_bioes_transition_mask(id2label) -> torch.Tensor:
    """Build a [C, C] additive mask over tag transitions.

    Entry [i, j] is 0.0 when tag i -> tag j is a legal BIOES transition and
    -inf otherwise (row = previous tag, column = current tag).
    """
    C = len(id2label)
    mask = torch.full((C, C), float("-inf"))
    parsed = {i: _parse_bioes(id2label[i]) for i in range(C)}
    for i, (p_prev, c_prev) in parsed.items():
        for j, (p_cur, c_cur) in parsed.items():
            if p_prev in ("O", "E", "S"):
                # Outside an entity (or just closed one): may stay outside
                # or start a new entity of any category.
                ok = p_cur in ("O", "B", "S")
            elif p_prev in ("B", "I"):
                # Inside an entity: must continue (I) or end (E) it with the
                # SAME category.  (B and I have identical successor rules, so
                # the original's two duplicate branches are merged here.)
                ok = p_cur in ("I", "E") and c_cur == c_prev
            else:
                # Unknown prefix: no legal transitions from it.
                ok = False
            if ok:
                mask[i, j] = 0.0
    return mask


def _build_bioes_initial_mask(id2label) -> torch.Tensor:
    """Build a [C] additive mask of tags legal at position 0.

    0.0 for O / B-* / S-*, -inf for I-* / E-* (an entity cannot start
    mid-span).
    """
    C = len(id2label)
    mask = torch.full((C,), float("-inf"))
    for i, lbl in id2label.items():
        p, _ = _parse_bioes(lbl)
        if p in ("O", "B", "S"):
            mask[i] = 0.0
    return mask


def _bioes_viterbi(
    logits: torch.Tensor,
    transition_mask: torch.Tensor,
    initial_mask: torch.Tensor,
) -> torch.Tensor:
    """Single-sequence convenience wrapper around the batched decoder.

    Args:
        logits: [T, C] per-token tag scores.

    Returns:
        [T] LongTensor of best constrained-BIOES tag ids.
    """
    if logits.dim() != 2:
        raise ValueError(f"expected 2D logits, got {logits.shape}")
    T = logits.shape[0]
    mask = torch.ones((1, T), dtype=torch.long, device=logits.device)
    out = _bioes_viterbi_batched(
        logits.unsqueeze(0), mask, transition_mask, initial_mask,
    )
    return out[0]


def _bioes_viterbi_batched(
    logits: torch.Tensor,
    attention_mask: torch.Tensor,
    transition_mask: torch.Tensor,
    initial_mask: torch.Tensor,
) -> torch.Tensor:
    """Vectorized constrained BIOES Viterbi.

    Args:
        logits: [B, T, C] float
        attention_mask: [B, T] {0, 1} long/bool
        transition_mask: [C, C] 0 valid, -inf invalid
        initial_mask: [C] 0 allowed first tag, -inf forbidden

    Returns:
        [B, T] LongTensor of best constrained-BIOES tag id per token;
        padded positions hold -1.
    """
    if logits.dim() != 3:
        raise ValueError(f"expected 3D logits [B,T,C], got {logits.shape}")
    device = logits.device
    B, T, C = logits.shape
    scores = logits.float()
    trans = transition_mask.to(device).float()
    init = initial_mask.to(device).float()
    mask = attention_mask.to(device).bool()

    # Forward pass: dp[b, c] = best score of a legal path ending in tag c.
    # NOTE(review): position 0 is assumed to be a real (unpadded) token,
    # which holds for standard left-aligned attention masks.
    dp = scores[:, 0] + init.unsqueeze(0)
    back = torch.zeros((B, T, C), dtype=torch.long, device=device)
    trans_b = trans.unsqueeze(0)  # [1, C, C], broadcast over the batch
    for t in range(1, T):
        # cand[b, prev, cur] = score ending at `prev`, extended prev -> cur.
        cand = dp.unsqueeze(2) + trans_b
        best_val, best_prev = cand.max(dim=1)
        new_dp = best_val + scores[:, t]
        # Freeze dp at padded steps so the final dp reflects the score at
        # each sequence's last real token.
        keep = mask[:, t].unsqueeze(1)
        dp = torch.where(keep, new_dp, dp)
        back[:, t] = best_prev

    # Backtrace from each sequence's last real token.
    last_t = (mask.sum(dim=1) - 1).clamp_min(0)
    best_last = dp.argmax(dim=1)
    out = torch.full((B, T), -1, dtype=torch.long, device=device)
    batch_idx = torch.arange(B, device=device)
    out[batch_idx, last_t] = best_last
    current = best_last.clone()
    for t in range(T - 1, 0, -1):
        new_current = torch.gather(
            back[:, t, :], 1, current.unsqueeze(1)
        ).squeeze(1)
        # Only rows whose sequence actually reaches step t participate;
        # everything past last_t keeps the -1 padding marker.
        active = (t <= last_t)
        current = torch.where(active, new_current, current)
        out[batch_idx, t - 1] = torch.where(
            active, current, out[batch_idx, t - 1],
        )
    return out


# ---------------------------------------------------------------------------
# Architecture classes
# ---------------------------------------------------------------------------


class HaremPiiModel(OpenAIPrivacyFilterModel):
    """Thin alias of the upstream encoder pinned to HaremPiiConfig."""

    config_class = HaremPiiConfig


class HaremPiiForTokenClassification(OpenAIPrivacyFilterForTokenClassification):
    """1-layer student. Wraps the upstream forward with eval-time
    constrained-BIOES Viterbi decoding."""

    config_class = HaremPiiConfig

    def __init__(self, config: HaremPiiConfig):
        # Bypass GenericForTokenClassification.__init__ because it calls
        # AutoModel.from_config(config), which uses type(config) as the
        # registry key. Under the trust_remote_code Hub-loading path the
        # cached HaremPiiConfig class identity differs from whatever was
        # registered at module import (the cache hosts the class under a
        # synthetic, sha-qualified module name). Constructing the encoder
        # directly avoids the registry dispatch entirely.
        from transformers.modeling_utils import PreTrainedModel as _PreTrainedModel

        _PreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.model = OpenAIPrivacyFilterModel(config)
        # Dropout priority: explicit classifier_dropout, then the encoder's
        # hidden_dropout, then a 0.1 fallback.
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()
        # Viterbi masks are built lazily on first decode; see
        # _ensure_viterbi_masks.
        self._viterbi_trans_mask = None
        self._viterbi_init_mask = None

    def _ensure_viterbi_masks(self):
        """Build (once) and return the (transition, initial) BIOES masks.

        Derived from config.id2label; JSON round-trips can turn the id keys
        into strings, hence the int() normalization.
        """
        if self._viterbi_trans_mask is None:
            id2label = {int(k): v for k, v in self.config.id2label.items()}
            self._viterbi_trans_mask = _build_bioes_transition_mask(id2label)
            self._viterbi_init_mask = _build_bioes_initial_mask(id2label)
        return self._viterbi_trans_mask, self._viterbi_init_mask

    @torch.no_grad()
    def decode_predictions(
        self,
        logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Constrained-BIOES Viterbi decode of `logits`.

        Accepts [T, C] (single sequence) or [B, T, C] (batch); a missing
        attention_mask is treated as all-real tokens. Returns tag ids, with
        -1 at padded positions in the batched case.
        """
        trans, init = self._ensure_viterbi_masks()
        if logits.dim() == 2:
            # Single sequence: delegate to the 1-sequence wrapper instead of
            # duplicating its unsqueeze/ones-mask logic inline.
            return _bioes_viterbi(logits, trans, init)
        if attention_mask is None:
            attention_mask = torch.ones(
                logits.shape[:2], dtype=torch.long, device=logits.device,
            )
        return _bioes_viterbi_batched(logits, attention_mask, trans, init)

    def forward(self, *args, **kwargs):
        """Upstream forward plus eval-time Viterbi decoding.

        In eval mode (and unless config.use_viterbi_decode is False) the
        decoded path is attached to the output as `predicted_labels`, and —
        when config.viterbi_replace_logits (default True) — `logits` is
        replaced with a one-hot-like tensor so `logits.argmax(-1)` yields
        the Viterbi path; the untouched scores are kept in `raw_logits`.
        Training mode and tuple outputs (return_dict=False) pass through
        unchanged.
        """
        outputs = super().forward(*args, **kwargs)
        if self.training:
            return outputs
        if not getattr(self.config, "use_viterbi_decode", True):
            return outputs
        if not hasattr(outputs, "logits"):
            # return_dict=False yields a plain tuple; previously this
            # crashed on `.logits` — pass it through instead.
            return outputs
        attn_mask = kwargs.get("attention_mask", None)
        if attn_mask is None and len(args) >= 2:
            # Upstream signature is (input_ids, attention_mask, ...), so a
            # positionally-passed mask lands in args[1].
            attn_mask = args[1]
        decoded = self.decode_predictions(outputs.logits, attention_mask=attn_mask)
        try:
            outputs.predicted_labels = decoded
        except Exception:
            # Some ModelOutput variants reject new attributes; write the
            # instance dict directly as a best-effort fallback.
            outputs.__dict__["predicted_labels"] = decoded
        if getattr(self.config, "viterbi_replace_logits", True):
            raw = outputs.logits
            # Large finite constants instead of +/-inf; padded positions
            # (decoded == -1) are clamped to label 0, which only affects
            # tokens the attention mask already excludes.
            fake = torch.full_like(raw, fill_value=-1e9)
            fake.scatter_(-1, decoded.clamp_min(0).unsqueeze(-1), 1e9)
            try:
                outputs.raw_logits = raw
                outputs.logits = fake
            except Exception:
                outputs.__dict__["raw_logits"] = raw
                outputs.__dict__["logits"] = fake
        return outputs


# --- Auto-registry ---
AutoConfig.register("haremb_pii", HaremPiiConfig, exist_ok=True)
AutoModel.register(HaremPiiConfig, HaremPiiModel, exist_ok=True)
AutoModelForTokenClassification.register(
    HaremPiiConfig, HaremPiiForTokenClassification, exist_ok=True,
)
HaremPiiConfig.register_for_auto_class("AutoConfig")
HaremPiiForTokenClassification.register_for_auto_class("AutoModelForTokenClassification")

__all__ = [
    "HaremPiiConfig",
    "HaremPiiModel",
    "HaremPiiForTokenClassification",
]