| """ |
| HaremPii — 1-layer surgical inference wrapper over OpenAI Privacy Filter. |
| |
| Defines: |
| * HaremPiiForTokenClassification — subclass of |
| OpenAIPrivacyFilterForTokenClassification. Reuses the upstream forward |
| pass and adds eval-time constrained-BIOES Viterbi decoding so |
| `outputs.logits.argmax(-1)` returns the Viterbi path. |
| * HaremPiiModel — encoder alias pinned to HaremPiiConfig. |
| |
| The model class is auto-registered so |
| `AutoModelForTokenClassification.from_pretrained(repo, trust_remote_code=True)` |
| dispatches to us via `config.auto_map` (model_type "haremb_pii"). |
| |
| This file is the released, inference-only copy. It contains no |
| training-related utilities. |
| """ |
| from __future__ import annotations |
|
|
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import ( |
| AutoConfig, |
| AutoModel, |
| AutoModelForTokenClassification, |
| ) |
| from transformers.models.openai_privacy_filter.modeling_openai_privacy_filter import ( |
| OpenAIPrivacyFilterForTokenClassification, |
| OpenAIPrivacyFilterModel, |
| ) |
|
|
| from configuration_haremb_pii import HaremPiiConfig |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| def _parse_bioes(label: str): |
| if label == "O" or "-" not in label: |
| return "O", None |
| pref, cat = label.split("-", 1) |
| return pref, cat |
|
|
|
|
def _build_bioes_transition_mask(id2label) -> torch.Tensor:
    """Build a ``[C, C]`` additive mask of legal BIOES transitions.

    Entry ``[prev, cur]`` is ``0.0`` when tag ``cur`` may follow tag
    ``prev`` under BIOES rules, ``-inf`` otherwise.
    """
    num_tags = len(id2label)
    parsed = [_parse_bioes(id2label[i]) for i in range(num_tags)]
    mask = torch.full((num_tags, num_tags), float("-inf"))
    for prev_id, (prev_prefix, prev_cat) in enumerate(parsed):
        for cur_id, (cur_prefix, cur_cat) in enumerate(parsed):
            if prev_prefix in ("O", "E", "S"):
                # Outside an entity (or just closed one): only O/B/S may start.
                allowed = cur_prefix in ("O", "B", "S")
            elif prev_prefix in ("B", "I"):
                # Inside an open entity: must continue/close the same category.
                allowed = cur_prefix in ("I", "E") and cur_cat == prev_cat
            else:
                # Unrecognized prefix: no outgoing transitions.
                allowed = False
            if allowed:
                mask[prev_id, cur_id] = 0.0
    return mask
|
|
|
|
def _build_bioes_initial_mask(id2label) -> torch.Tensor:
    """Build a ``[C]`` additive mask for the first tag of a sequence.

    Entries are ``0.0`` for tags legal at sequence start (prefix O, B, or S)
    and ``-inf`` for I/E, which would continue a nonexistent entity.
    """
    start_prefixes = ("O", "B", "S")
    mask = torch.full((len(id2label),), float("-inf"))
    for tag_id, label in id2label.items():
        prefix, _ = _parse_bioes(label)
        if prefix in start_prefixes:
            mask[tag_id] = 0.0
    return mask
|
|
|
|
def _bioes_viterbi(
    logits: torch.Tensor,
    transition_mask: torch.Tensor,
    initial_mask: torch.Tensor,
) -> torch.Tensor:
    """Decode a single unbatched ``[T, C]`` logits matrix.

    Thin convenience wrapper: adds a batch dimension with an all-ones
    attention mask, delegates to ``_bioes_viterbi_batched``, and strips the
    batch dimension off the result.
    """
    if logits.dim() != 2:
        raise ValueError(f"expected 2D logits, got {logits.shape}")
    seq_len = logits.size(0)
    full_mask = torch.ones((1, seq_len), dtype=torch.long, device=logits.device)
    decoded = _bioes_viterbi_batched(
        logits.unsqueeze(0), full_mask, transition_mask, initial_mask,
    )
    return decoded[0]
|
|
|
|
def _bioes_viterbi_batched(
    logits: torch.Tensor,
    attention_mask: torch.Tensor,
    transition_mask: torch.Tensor,
    initial_mask: torch.Tensor,
) -> torch.Tensor:
    """Vectorized constrained BIOES Viterbi.

    Args:
        logits: [B, T, C] float
        attention_mask: [B, T] {0, 1} long/bool
        transition_mask: [C, C] 0 valid, -inf invalid
        initial_mask: [C] 0 allowed first tag, -inf forbidden

    Returns:
        [B, T] LongTensor of best constrained-BIOES tag id per token; padded
        positions hold -1.

    NOTE(review): assumes RIGHT padding — the last real token index is taken
    as ``attention_mask.sum() - 1``; left-padded batches would decode
    incorrectly. Confirm against the tokenizer's padding side.
    """
    if logits.dim() != 3:
        raise ValueError(f"expected 3D logits [B,T,C], got {logits.shape}")
    device = logits.device
    B, T, C = logits.shape
    # Work in float32 so the -inf constraint masks behave predictably under
    # addition and max, regardless of the incoming logits dtype.
    scores = logits.float()
    trans = transition_mask.to(device).float()
    init = initial_mask.to(device).float()
    mask = attention_mask.to(device).bool()

    # Forward pass. dp[b, c] = best score of any constrained path that ends
    # in tag c at the current step; the first step only applies the
    # initial-tag constraint.
    dp = scores[:, 0] + init.unsqueeze(0)
    # back[b, t, c] = argmax previous tag leading into tag c at step t.
    back = torch.zeros((B, T, C), dtype=torch.long, device=device)
    trans_b = trans.unsqueeze(0)  # [1, C, C], broadcast over the batch
    for t in range(1, T):
        # cand[b, c_prev, c_cur] = dp[b, c_prev] + transition(c_prev, c_cur)
        cand = dp.unsqueeze(2) + trans_b
        best_val, best_prev = cand.max(dim=1)
        new_dp = best_val + scores[:, t]
        # Freeze dp at padded positions so the final dp reflects each
        # sequence's last REAL token; back entries written at padded steps
        # are never consulted by the backtrace below.
        keep = mask[:, t].unsqueeze(1)
        dp = torch.where(keep, new_dp, dp)
        back[:, t] = best_prev

    # Backtrace, right to left, per sequence.
    last_t = (mask.sum(dim=1) - 1).clamp_min(0)  # index of last real token
    best_last = dp.argmax(dim=1)
    out = torch.full((B, T), -1, dtype=torch.long, device=device)
    batch_idx = torch.arange(B, device=device)
    out[batch_idx, last_t] = best_last
    current = best_last.clone()
    for t in range(T - 1, 0, -1):
        # Predecessor of the currently selected tag at step t.
        new_current = torch.gather(
            back[:, t, :], 1, current.unsqueeze(1)
        ).squeeze(1)
        # Only sequences whose real length reaches step t move their
        # pointer; shorter sequences keep waiting at their own last_t.
        active = (t <= last_t)
        current = torch.where(active, new_current, current)
        out[batch_idx, t - 1] = torch.where(
            active, current, out[batch_idx, t - 1],
        )
    return out
|
|
|
|
| |
| |
| |
|
|
class HaremPiiModel(OpenAIPrivacyFilterModel):
    """Thin alias of the upstream encoder pinned to HaremPiiConfig.

    Adds no behavior of its own; it exists so Auto* registration can map
    the "haremb_pii" config type to a bare-encoder class.
    """
    config_class = HaremPiiConfig
|
|
|
|
class HaremPiiForTokenClassification(OpenAIPrivacyFilterForTokenClassification):
    """1-layer student. Wraps the upstream forward with eval-time
    constrained-BIOES Viterbi decoding.

    In eval mode (unless ``config.use_viterbi_decode`` is falsy) the forward
    pass attaches a ``predicted_labels`` tensor to the output and, when
    ``config.viterbi_replace_logits`` (default True), replaces
    ``outputs.logits`` with +/-1e9 one-hot scores so that
    ``outputs.logits.argmax(-1)`` returns the Viterbi path; the untouched
    logits are kept as ``outputs.raw_logits``.
    """
    config_class = HaremPiiConfig

    def __init__(self, config: HaremPiiConfig):
        # NOTE(review): deliberately skips the direct parent's __init__ and
        # calls PreTrainedModel.__init__ instead, then rebuilds the encoder
        # and classification head by hand — presumably to avoid upstream-side
        # module allocations that don't fit this 1-layer student; confirm
        # against the upstream class before changing.
        from transformers.modeling_utils import PreTrainedModel as _PreTrainedModel
        _PreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.model = OpenAIPrivacyFilterModel(config)
        # Dropout resolution order: classifier_dropout > hidden_dropout > 0.1.
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        # Head is named "score" — presumably to match the upstream
        # checkpoint's parameter names; TODO confirm against the state dict.
        self.score = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()

        # Lazily-built Viterbi constraint masks; kept on CPU here and moved
        # to the logits' device inside _bioes_viterbi_batched.
        self._viterbi_trans_mask: Optional[torch.Tensor] = None
        self._viterbi_init_mask: Optional[torch.Tensor] = None

    def _ensure_viterbi_masks(self):
        # Build once and cache. id2label keys may be strings after a JSON
        # round-trip through the config file, hence the int() normalization.
        if self._viterbi_trans_mask is None:
            id2label = {int(k): v for k, v in self.config.id2label.items()}
            self._viterbi_trans_mask = _build_bioes_transition_mask(id2label)
            self._viterbi_init_mask = _build_bioes_initial_mask(id2label)
        return self._viterbi_trans_mask, self._viterbi_init_mask

    @torch.no_grad()
    def decode_predictions(
        self,
        logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Constrained-BIOES Viterbi decode of ``logits``.

        Accepts ``[T, C]`` (single sequence) or ``[B, T, C]`` logits; a
        missing attention mask is treated as all-ones. Returns per-token tag
        ids (``[T]`` or ``[B, T]``); padded positions hold -1.
        """
        trans, init = self._ensure_viterbi_masks()
        if logits.dim() == 2:
            T = logits.shape[0]
            mask = torch.ones((1, T), dtype=torch.long, device=logits.device)
            return _bioes_viterbi_batched(
                logits.unsqueeze(0), mask, trans, init,
            )[0]
        if attention_mask is None:
            attention_mask = torch.ones(
                logits.shape[:2], dtype=torch.long, device=logits.device,
            )
        return _bioes_viterbi_batched(logits, attention_mask, trans, init)

    def forward(self, *args, **kwargs):
        """Upstream forward plus eval-time Viterbi post-processing."""
        outputs = super().forward(*args, **kwargs)
        # Training mode and explicit opt-out both return the raw output.
        if self.training:
            return outputs
        if not getattr(self.config, "use_viterbi_decode", True):
            return outputs

        # Recover the attention mask. Reading args[1] assumes the upstream
        # positional order is (input_ids, attention_mask, ...) — TODO confirm
        # against the upstream forward signature.
        attn_mask = kwargs.get("attention_mask", None)
        if attn_mask is None and len(args) >= 2:
            attn_mask = args[1]

        decoded = self.decode_predictions(outputs.logits, attention_mask=attn_mask)
        # ModelOutput may reject unknown attributes depending on the
        # transformers version; fall back to writing __dict__ directly.
        try:
            outputs.predicted_labels = decoded
        except Exception:
            outputs.__dict__["predicted_labels"] = decoded

        # Optionally make argmax(logits) reproduce the Viterbi path by
        # swapping in one-hot-like fake logits. Padded positions (decoded
        # == -1) are clamped to class 0 here; callers are expected to mask
        # padded positions out anyway.
        if getattr(self.config, "viterbi_replace_logits", True):
            raw = outputs.logits
            fake = torch.full_like(raw, fill_value=-1e9)
            fake.scatter_(-1, decoded.clamp_min(0).unsqueeze(-1), 1e9)
            try:
                outputs.raw_logits = raw
                outputs.logits = fake
            except Exception:
                outputs.__dict__["raw_logits"] = raw
                outputs.__dict__["logits"] = fake
        return outputs
|
|
|
|
| |
# In-process Auto* registration: lets AutoConfig/AutoModel/
# AutoModelForTokenClassification resolve model_type "haremb_pii" when this
# module has already been imported (no trust_remote_code needed).
AutoConfig.register("haremb_pii", HaremPiiConfig, exist_ok=True)
AutoModel.register(HaremPiiConfig, HaremPiiModel, exist_ok=True)
AutoModelForTokenClassification.register(
    HaremPiiConfig, HaremPiiForTokenClassification, exist_ok=True,
)

# auto_map registration: records the class paths in the saved config so
# remote trust_remote_code loads dispatch back to these classes. Previously
# only AutoConfig and AutoModelForTokenClassification were recorded; the
# bare-encoder alias is now registered too, for parity with the
# AutoModel.register call above.
HaremPiiConfig.register_for_auto_class("AutoConfig")
HaremPiiModel.register_for_auto_class("AutoModel")
HaremPiiForTokenClassification.register_for_auto_class("AutoModelForTokenClassification")
|
|
|
|
# Public API of this module. HaremPiiConfig is re-exported from the
# companion configuration module for caller convenience.
__all__ = [
    "HaremPiiConfig",
    "HaremPiiModel",
    "HaremPiiForTokenClassification",
]
|
|