# scibert-certainty-ordinal / modeling_bert_ordinal.py
# Cbelem's picture
# Upload 2 files
# cc52b39 verified
"""
bert_ordinal.py
---------------
BERT-based ordinal regression model, fully integrated with the HuggingFace
Transformers API:
model.save_pretrained("my-checkpoint/")
model = BertOrdinal.from_pretrained("my-checkpoint/")
Architecture
------------
1. A (optionally frozen) BERT backbone.
2. A projection head on the [CLS] token:
Linear(hidden_size → hidden_dim) → ReLU → Dropout(p) → Linear(hidden_dim → 1)
producing a single latent score s ∈ ℝ.
3. K-1 learnable raw_threshold parameters enforcing monotonicity via
cumsum(softplus(·)).
4. Cumulative-link probabilities:
P(Y ≤ j | x) = σ(θ_j − s)
Usage
-----
from bert_ordinal import BertOrdinalConfig, BertOrdinal
# ── Create from scratch ──────────────────────────────────────────────────
cfg = BertOrdinalConfig(
bert_model_name="bert-base-uncased",
num_classes=3,
hidden_dim=128,
dropout=0.1,
freeze_bert=True,
)
model = BertOrdinal(cfg)
# ── Save ────────────────────────────────────────────────────────────────
model.save_pretrained("my-checkpoint/")
tokenizer.save_pretrained("my-checkpoint/") # keep tokenizer alongside
# ── Reload ──────────────────────────────────────────────────────────────
model = BertOrdinal.from_pretrained("my-checkpoint/")
tokenizer = AutoTokenizer.from_pretrained("my-checkpoint/")
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from .configuration_bert_ordinal import BertOrdinalConfig
# ---------------------------------------------------------------------------
# 1. Output dataclass
# ---------------------------------------------------------------------------
@dataclass
class BertOrdinalOutput(ModelOutput):
    """
    Container for everything :class:`BertOrdinal` returns from ``forward``.

    All fields default to ``None``; the field order below is the order in
    which they are declared on the dataclass.

    Attributes
    ----------
    loss : torch.Tensor or None
        Scalar ordinal cross-entropy loss. Only populated when ``labels``
        are passed to the model.
    logits : torch.Tensor (B,)
        Raw latent score emitted by the projection head.
    predictions : torch.Tensor (B,)
        Predicted class index — the argmax over ``class_probs``.
    cum_probs : torch.Tensor (B, K-1)
        Cumulative probabilities P(Y ≤ j | x).
    class_probs : torch.Tensor (B, K)
        Per-class probabilities P(Y = j | x).
    """

    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    predictions: Optional[torch.Tensor] = None
    cum_probs: Optional[torch.Tensor] = None
    class_probs: Optional[torch.Tensor] = None
# ---------------------------------------------------------------------------
# 2. Model — subclass PreTrainedModel for save / from_pretrained
# ---------------------------------------------------------------------------
class BertOrdinal(PreTrainedModel):
    """
    BERT encoder with an ordinal-regression (cumulative-link) head.

    Compatible with the HuggingFace checkpoint API::

        model.save_pretrained("my-checkpoint/")
        model = BertOrdinal.from_pretrained("my-checkpoint/")

    What gets saved
    ~~~~~~~~~~~~~~~
    ``save_pretrained`` writes two files:

    * ``config.json`` — the full :class:`BertOrdinalConfig` (including
      ``bert_model_name`` and the cached ``hidden_size``).
    * ``model.safetensors`` (or ``pytorch_model.bin``) — a **single flat
      state_dict** containing both the BERT backbone weights and the
      head/threshold parameters.

    Reloading
    ~~~~~~~~~
    ``from_pretrained`` builds the model from the config, loads the saved
    state_dict on top of it, and the constructor re-applies the
    ``freeze_bert`` setting.

    Note that the constructor always instantiates the backbone with
    ``AutoModel.from_pretrained(config.bert_model_name)``, so that name must
    be resolvable (local directory, local cache, or Hub access) every time
    the model is constructed — including when reloading a saved checkpoint.
    """

    config_class = BertOrdinalConfig

    def __init__(self, config: BertOrdinalConfig) -> None:
        """
        Build the backbone, the projection head and the ordinal thresholds.

        Parameters
        ----------
        config : BertOrdinalConfig
            Reads ``num_classes``, ``bert_model_name``, ``freeze_bert``,
            ``hidden_dim`` and ``dropout``. ``config.hidden_size`` is
            overwritten with the backbone's hidden size.
        """
        super().__init__(config)
        num_classes = config.num_classes

        # ── 1. BERT backbone ────────────────────────────────────────────────
        # The backbone is always created from ``bert_model_name``. When this
        # constructor is invoked through BertOrdinal.from_pretrained, the
        # saved state_dict is loaded over these weights afterwards.
        self.bert = AutoModel.from_pretrained(config.bert_model_name)
        hidden_size: int = self.bert.config.hidden_size
        # Record the backbone width on the config so it is serialised with it.
        config.hidden_size = hidden_size

        if config.freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        # ── 2. Projection head ──────────────────────────────────────────────
        self.head = nn.Sequential(
            nn.Linear(hidden_size, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, 1),
        )
        self._init_head()

        # ── 3. Ordinal thresholds ───────────────────────────────────────────
        # K-1 raw values; monotonicity enforced via cumsum(softplus(·)).
        self.raw_thresholds = nn.Parameter(torch.zeros(num_classes - 1))
        with torch.no_grad():
            targets = torch.linspace(-1.0, 1.0, num_classes - 1)
            # First entry is the starting point, the rest are the gaps.
            diffs = torch.cat([targets[:1], targets[1:] - targets[:-1]])
            # log(expm1(d)) is the inverse of softplus, so softplus(raw) == d.
            # Every increment must be positive, hence the clamp: the leading
            # target of -1 becomes 1e-3, so the initial thresholds are the
            # linspace grid shifted to start at 1e-3 instead of -1.
            self.raw_thresholds.copy_(
                torch.log(torch.expm1(diffs.clamp(min=1e-3)))
            )

        # Finalises weight init bookkeeping required by PreTrainedModel.
        self.post_init()

    # -----------------------------------------------------------------------
    # Helpers
    # -----------------------------------------------------------------------
    def _init_head(self) -> None:
        """Kaiming-normal weights and zero biases for every Linear in the head."""
        for module in self.head.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                nn.init.zeros_(module.bias)

    @property
    def thresholds(self) -> torch.Tensor:
        """Monotone thresholds θ₁ ≤ … ≤ θ_{K-1} (shape: K-1)."""
        return torch.cumsum(F.softplus(self.raw_thresholds), dim=0)

    # -----------------------------------------------------------------------
    # Forward
    # -----------------------------------------------------------------------
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> BertOrdinalOutput:
        """
        Run the encoder and the ordinal head.

        Parameters
        ----------
        input_ids : (B, L)
        attention_mask : (B, L)
        token_type_ids : (B, L) optional
            Forwarded to the backbone only when provided.
        labels : (B,) long — class indices in {0, …, K-1}
            When given, the ordinal cross-entropy loss is computed with
            ``config.loss_reduction``.
        **kwargs
            Accepted for call-site compatibility and ignored.

        Returns
        -------
        BertOrdinalOutput
        """
        # ── Encode ──────────────────────────────────────────────────────────
        bert_kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        if token_type_ids is not None:
            bert_kwargs["token_type_ids"] = token_type_ids
        # Representation of the first token ([CLS]) of every sequence.
        cls_repr = self.bert(**bert_kwargs).last_hidden_state[:, 0, :]  # (B, H)

        # ── Latent score ────────────────────────────────────────────────────
        score = self.head(cls_repr).squeeze(-1)  # (B,)

        # ── Cumulative probs P(Y ≤ j) = σ(θ_j − score) ──────────────────────
        cum_logits = self.thresholds.unsqueeze(0) - score.unsqueeze(1)  # (B, K-1)
        cum_probs = torch.sigmoid(cum_logits)  # (B, K-1)

        # ── Class probs P(Y = j) = P(Y ≤ j) − P(Y ≤ j-1) ────────────────────
        # Pad the CDF with P(Y ≤ -1) = 0 on the left and P(Y ≤ K-1) = 1 on
        # the right, then take adjacent differences.
        batch_size, device = cum_probs.size(0), cum_probs.device
        padded_cdf = torch.cat(
            [
                torch.zeros(batch_size, 1, device=device),
                cum_probs,
                torch.ones(batch_size, 1, device=device),
            ],
            dim=1,
        )  # (B, K+1)
        # Clamp keeps every probability strictly positive for the log in the loss.
        class_probs = (padded_cdf[:, 1:] - padded_cdf[:, :-1]).clamp(min=1e-9)  # (B, K)

        # ── Predictions ─────────────────────────────────────────────────────
        predictions = class_probs.argmax(dim=-1)  # (B,)

        # ── Loss ────────────────────────────────────────────────────────────
        loss: Optional[torch.Tensor] = None
        if labels is not None:
            loss = ordinal_cross_entropy(
                class_probs, labels, reduction=self.config.loss_reduction
            )

        return BertOrdinalOutput(
            loss=loss,
            logits=score,
            predictions=predictions,
            cum_probs=cum_probs,
            class_probs=class_probs,
        )

    # -----------------------------------------------------------------------
    # Convenience
    # -----------------------------------------------------------------------
    @torch.no_grad()
    def predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Return predicted class indices (no loss computed).

        Runs under ``torch.no_grad()``. The module's train/eval mode is not
        changed here, so call ``model.eval()`` first if dropout in the head
        should be disabled.
        """
        # Call the module itself rather than ``forward`` so that registered
        # forward hooks run as they do for a regular model call.
        output = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        return output.predictions
# ---------------------------------------------------------------------------
# Loss function
# ---------------------------------------------------------------------------
def ordinal_cross_entropy(
    class_probs: torch.Tensor,
    labels: torch.Tensor,
    reduction: str = "mean",
) -> torch.Tensor:
    """
    Ordinal cross-entropy: negative log-likelihood of the true class.

    Parameters
    ----------
    class_probs : (B, K) — P(Y=j|x)
        Expected to be strictly positive. As a safeguard, floating-point
        inputs are clamped from below to the smallest positive normal value
        of their dtype, so an exact zero yields a large finite loss instead
        of ``inf`` (and a zero gradient instead of ``NaN``). Inputs that are
        already clamped above that value are unaffected.
    labels : (B,) — ground-truth indices in {0, …, K-1}
    reduction : 'mean' | 'sum' | 'none'
        Passed straight through to ``torch.nn.functional.nll_loss``.

    Returns
    -------
    torch.Tensor
        Scalar for 'mean' / 'sum', shape (B,) for 'none'.
    """
    if class_probs.is_floating_point():
        # Clamp *before* the log: clamping afterwards would still
        # back-propagate through log at 0 and produce NaN gradients.
        floor = torch.finfo(class_probs.dtype).tiny
        class_probs = class_probs.clamp_min(floor)
    log_probs = torch.log(class_probs)
    return F.nll_loss(log_probs, labels, reduction=reduction)