# haremb-privacy-filter-opennemo / modeling_haremb_pii.py
"""
HaremPii — 1-layer surgical inference wrapper over OpenAI Privacy Filter.
Defines:
* HaremPiiForTokenClassification — subclass of
OpenAIPrivacyFilterForTokenClassification. Reuses the upstream forward
pass and adds eval-time constrained-BIOES Viterbi decoding so
`outputs.logits.argmax(-1)` returns the Viterbi path.
* HaremPiiModel — encoder alias pinned to HaremPiiConfig.
The model class is auto-registered so
`AutoModelForTokenClassification.from_pretrained(repo, trust_remote_code=True)`
dispatches to us via `config.auto_map` (model_type "haremb_pii").
This file is the released, inference-only copy. It contains no
training-related utilities.
"""
from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForTokenClassification,
)
from transformers.models.openai_privacy_filter.modeling_openai_privacy_filter import (
    OpenAIPrivacyFilterForTokenClassification,
    OpenAIPrivacyFilterModel,
)

from configuration_haremb_pii import HaremPiiConfig

# ---------------------------------------------------------------------------
# Constrained BIOES Viterbi (inlined so the checkpoint is self-contained)
# ---------------------------------------------------------------------------
# Transition rules:
#   O   -> {O, B-X, S-X}
#   B-X -> {I-X, E-X}
#   I-X -> {I-X, E-X}
#   E-X -> {O, B-Y, S-Y}
#   S-X -> {O, B-Y, S-Y}
# Initial state allows {O, B-X, S-X} only.
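# For example, under these rules
#   O  B-EMAIL  I-EMAIL  E-EMAIL  O  S-PHONE
# is a valid sequence, while
#   O  B-EMAIL  O         (an entity opened with B- must be closed with E-)
#   B-EMAIL  I-PHONE      (I-/E- must keep the opener's category)
# are both rejected. Category names here are illustrative; the real label
# set comes from config.id2label.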
def _parse_bioes(label: str):
    """Split a BIOES label into (prefix, category).

    "B-EMAIL" -> ("B", "EMAIL"); "O" (and any label without "-") -> ("O", None).
    """
    if label == "O" or "-" not in label:
        return "O", None
    pref, cat = label.split("-", 1)
    return pref, cat

def _build_bioes_transition_mask(id2label) -> torch.Tensor:
    """Return a [C, C] additive mask: 0.0 where transition i -> j is legal, -inf otherwise."""
    C = len(id2label)
    mask = torch.full((C, C), float("-inf"))
    parsed = {i: _parse_bioes(id2label[i]) for i in range(C)}
    for i, (p_prev, c_prev) in parsed.items():
        for j, (p_cur, c_cur) in parsed.items():
            ok = False
            if p_prev == "O":
                if p_cur in ("O", "B", "S"):
                    ok = True
            elif p_prev in ("B", "I"):
                # Inside an entity: may only continue or close it, same category.
                if p_cur in ("I", "E") and c_cur == c_prev:
                    ok = True
            elif p_prev in ("E", "S"):
                if p_cur in ("O", "B", "S"):
                    ok = True
            if ok:
                mask[i, j] = 0.0
    return mask
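
# A minimal sanity check of the transition mask (assumes an illustrative
# single-category label set; the real one comes from config.id2label):
#
#   id2label = {0: "O", 1: "B-X", 2: "I-X", 3: "E-X", 4: "S-X"}
#   m = _build_bioes_transition_mask(id2label)
#   assert m[1, 2] == 0.0              # B-X -> I-X is allowed
#   assert m[1, 0] == float("-inf")    # B-X -> O is forbidden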
def _build_bioes_initial_mask(id2label) -> torch.Tensor:
    """Return a [C] additive mask: 0.0 for tags legal at t=0 (O, B-, S-), -inf otherwise."""
    C = len(id2label)
    mask = torch.full((C,), float("-inf"))
    for i, lbl in id2label.items():
        p, _ = _parse_bioes(lbl)
        if p in ("O", "B", "S"):
            mask[i] = 0.0
    return mask

def _bioes_viterbi(
    logits: torch.Tensor,
    transition_mask: torch.Tensor,
    initial_mask: torch.Tensor,
) -> torch.Tensor:
    """Single-sequence convenience wrapper: [T, C] logits -> [T] tag ids."""
    if logits.dim() != 2:
        raise ValueError(f"expected 2D logits, got {logits.shape}")
    T = logits.shape[0]
    mask = torch.ones((1, T), dtype=torch.long, device=logits.device)
    out = _bioes_viterbi_batched(
        logits.unsqueeze(0), mask, transition_mask, initial_mask,
    )
    return out[0]

def _bioes_viterbi_batched(
    logits: torch.Tensor,
    attention_mask: torch.Tensor,
    transition_mask: torch.Tensor,
    initial_mask: torch.Tensor,
) -> torch.Tensor:
    """Vectorized constrained BIOES Viterbi.

    Args:
        logits: [B, T, C] float
        attention_mask: [B, T] {0, 1} long/bool
        transition_mask: [C, C] 0 valid, -inf invalid
        initial_mask: [C] 0 allowed first tag, -inf forbidden

    Returns:
        [B, T] LongTensor of best constrained-BIOES tag id per token; padded
        positions hold -1.
    """
    if logits.dim() != 3:
        raise ValueError(f"expected 3D logits [B,T,C], got {logits.shape}")
    device = logits.device
    B, T, C = logits.shape
    scores = logits.float()
    trans = transition_mask.to(device).float()
    init = initial_mask.to(device).float()
    mask = attention_mask.to(device).bool()

    # Forward pass: dp[b, j] = best score of any valid path ending in tag j.
    dp = scores[:, 0] + init.unsqueeze(0)
    back = torch.zeros((B, T, C), dtype=torch.long, device=device)
    trans_b = trans.unsqueeze(0)  # [1, C, C], broadcast over the batch
    for t in range(1, T):
        # cand[b, i, j] = dp[b, i] + trans[i, j]; forbidden transitions are -inf.
        cand = dp.unsqueeze(2) + trans_b
        best_val, best_prev = cand.max(dim=1)
        new_dp = best_val + scores[:, t]
        # Freeze dp on padded positions so each sequence keeps the scores of
        # its last real token.
        keep = mask[:, t].unsqueeze(1)
        dp = torch.where(keep, new_dp, dp)
        back[:, t] = best_prev

    # Backtrace from each sequence's last real token.
    last_t = (mask.sum(dim=1) - 1).clamp_min(0)
    best_last = dp.argmax(dim=1)
    out = torch.full((B, T), -1, dtype=torch.long, device=device)
    batch_idx = torch.arange(B, device=device)
    out[batch_idx, last_t] = best_last
    current = best_last.clone()
    for t in range(T - 1, 0, -1):
        new_current = torch.gather(
            back[:, t, :], 1, current.unsqueeze(1)
        ).squeeze(1)
        active = (t <= last_t)
        current = torch.where(active, new_current, current)
        out[batch_idx, t - 1] = torch.where(
            active, current, out[batch_idx, t - 1],
        )
    return out
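
# A minimal usage sketch of the batched decoder (illustrative shapes and
# label set only):
#
#   id2label = {0: "O", 1: "B-X", 2: "I-X", 3: "E-X", 4: "S-X"}
#   trans = _build_bioes_transition_mask(id2label)
#   init = _build_bioes_initial_mask(id2label)
#   logits = torch.randn(2, 8, len(id2label))
#   attn = torch.ones(2, 8, dtype=torch.long)
#   attn[1, 5:] = 0    # second sequence is padded after 5 tokens
#   paths = _bioes_viterbi_batched(logits, attn, trans, init)
#   # paths: [2, 8]; paths[1, 5:] == -1, and each row's non-padded prefix is
#   # a valid BIOES sequence.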
# ---------------------------------------------------------------------------
# Architecture classes
# ---------------------------------------------------------------------------
class HaremPiiModel(OpenAIPrivacyFilterModel):
    """Thin alias of the upstream encoder pinned to HaremPiiConfig."""

    config_class = HaremPiiConfig


class HaremPiiForTokenClassification(OpenAIPrivacyFilterForTokenClassification):
    """1-layer student. Wraps the upstream forward pass with eval-time
    constrained-BIOES Viterbi decoding."""

    config_class = HaremPiiConfig

    def __init__(self, config: HaremPiiConfig):
        # Bypass GenericForTokenClassification.__init__ because it calls
        # AutoModel.from_config(config), which uses type(config) as the
        # registry key. Under the trust_remote_code Hub-loading path the
        # cached HaremPiiConfig class identity differs from whatever was
        # registered at module import (the cache hosts the class under a
        # synthetic, sha-qualified module name). Constructing the encoder
        # directly avoids the registry dispatch entirely.
        from transformers.modeling_utils import PreTrainedModel as _PreTrainedModel

        _PreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.model = OpenAIPrivacyFilterModel(config)
        if getattr(config, "classifier_dropout", None) is not None:
            classifier_dropout = config.classifier_dropout
        elif getattr(config, "hidden_dropout", None) is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.score = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()
        # Viterbi masks are built lazily on first decode and cached.
        self._viterbi_trans_mask = None
        self._viterbi_init_mask = None

    def _ensure_viterbi_masks(self):
        if self._viterbi_trans_mask is None:
            id2label = {int(k): v for k, v in self.config.id2label.items()}
            self._viterbi_trans_mask = _build_bioes_transition_mask(id2label)
            self._viterbi_init_mask = _build_bioes_initial_mask(id2label)
        return self._viterbi_trans_mask, self._viterbi_init_mask

    @torch.no_grad()
    def decode_predictions(
        self,
        logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Viterbi-decode [T, C] or [B, T, C] logits into constrained tag ids."""
        trans, init = self._ensure_viterbi_masks()
        if logits.dim() == 2:
            T = logits.shape[0]
            mask = torch.ones((1, T), dtype=torch.long, device=logits.device)
            return _bioes_viterbi_batched(
                logits.unsqueeze(0), mask, trans, init,
            )[0]
        if attention_mask is None:
            attention_mask = torch.ones(
                logits.shape[:2], dtype=torch.long, device=logits.device,
            )
        return _bioes_viterbi_batched(logits, attention_mask, trans, init)

    def forward(self, *args, **kwargs):
        outputs = super().forward(*args, **kwargs)
        if self.training:
            return outputs
        if not getattr(self.config, "use_viterbi_decode", True):
            return outputs
        # Recover the attention mask from kwargs or from the standard
        # (input_ids, attention_mask, ...) positional layout.
        attn_mask = kwargs.get("attention_mask", None)
        if attn_mask is None and len(args) >= 2:
            attn_mask = args[1]
        decoded = self.decode_predictions(outputs.logits, attention_mask=attn_mask)
        try:
            outputs.predicted_labels = decoded
        except Exception:
            outputs.__dict__["predicted_labels"] = decoded
        if getattr(self.config, "viterbi_replace_logits", True):
            # Overwrite the logits with a near-one-hot tensor so the usual
            # `outputs.logits.argmax(-1)` recovers the Viterbi path; the raw
            # head scores stay available as `outputs.raw_logits`. Padded
            # positions (-1) are clamped to 0 before scattering.
            raw = outputs.logits
            fake = torch.full_like(raw, fill_value=-1e9)
            fake.scatter_(-1, decoded.clamp_min(0).unsqueeze(-1), 1e9)
            try:
                outputs.raw_logits = raw
                outputs.logits = fake
            except Exception:
                outputs.__dict__["raw_logits"] = raw
                outputs.__dict__["logits"] = fake
        return outputs


# --- Auto-registry ---
AutoConfig.register("haremb_pii", HaremPiiConfig, exist_ok=True)
AutoModel.register(HaremPiiConfig, HaremPiiModel, exist_ok=True)
AutoModelForTokenClassification.register(
    HaremPiiConfig, HaremPiiForTokenClassification, exist_ok=True,
)
HaremPiiConfig.register_for_auto_class("AutoConfig")
HaremPiiForTokenClassification.register_for_auto_class("AutoModelForTokenClassification")
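
# End-to-end usage sketch. Assumes this checkpoint lives at some Hub repo id
# REPO_ID and ships a tokenizer; the input text is illustrative:
#
#   from transformers import AutoModelForTokenClassification, AutoTokenizer
#
#   tok = AutoTokenizer.from_pretrained(REPO_ID)
#   model = AutoModelForTokenClassification.from_pretrained(
#       REPO_ID, trust_remote_code=True
#   ).eval()
#   enc = tok("Contact Ada at ada@example.com", return_tensors="pt")
#   with torch.no_grad():
#       out = model(**enc)
#   # With config.viterbi_replace_logits=True (the default), argmax over the
#   # replaced logits is the constrained Viterbi path; the unconstrained head
#   # scores remain in out.raw_logits.
#   tags = [model.config.id2label[i] for i in out.logits.argmax(-1)[0].tolist()]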
__all__ = [
    "HaremPiiConfig",
    "HaremPiiModel",
    "HaremPiiForTokenClassification",
]