| """ |
| bert_ordinal.py |
| --------------- |
| BERT-based ordinal regression model, fully integrated with the HuggingFace |
| Transformers API: |
| |
| model.save_pretrained("my-checkpoint/") |
| model = BertOrdinal.from_pretrained("my-checkpoint/") |
| |
| Architecture |
| ------------ |
| 1. A (optionally frozen) BERT backbone. |
| 2. A projection head on the [CLS] token: |
     Linear(hidden_size → hidden_dim) → ReLU → Dropout(p) → Linear(hidden_dim → 1)
   producing a single latent score s ∈ ℝ.
3. K-1 learnable raw_threshold parameters enforcing monotonicity via
   cumsum(softplus(·)).
4. Cumulative-link probabilities:
       P(Y ≤ j | x) = σ(θ_j − s)
| |
| Usage |
| ----- |
| from bert_ordinal import BertOrdinalConfig, BertOrdinal |
| |
# ── Create from scratch ──────────────────────────────────────────────────
| cfg = BertOrdinalConfig( |
| bert_model_name="bert-base-uncased", |
| num_classes=3, |
| hidden_dim=128, |
| dropout=0.1, |
| freeze_bert=True, |
| ) |
| model = BertOrdinal(cfg) |
| |
# ── Save ────────────────────────────────────────────────────────────────
| model.save_pretrained("my-checkpoint/") |
| tokenizer.save_pretrained("my-checkpoint/") # keep tokenizer alongside |
| |
# ── Reload ──────────────────────────────────────────────────────────────
| model = BertOrdinal.from_pretrained("my-checkpoint/") |
| tokenizer = AutoTokenizer.from_pretrained("my-checkpoint/") |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import AutoModel, PreTrainedModel |
| from transformers.modeling_outputs import ModelOutput |
|
|
| from .configuration_bert_ordinal import BertOrdinalConfig |
|
|
| |
| |
| |
|
|
@dataclass
class BertOrdinalOutput(ModelOutput):
    """
    Return type of :class:`BertOrdinal`.

    Declaration order matters: ``ModelOutput`` exposes the non-None fields
    positionally (tuple indexing / unpacking) in the order they are declared
    here, so do not reorder these fields.

    Attributes
    ----------
    loss : torch.Tensor or None
        Ordinal cross-entropy loss (scalar). Present only when ``labels``
        are supplied.
    logits : torch.Tensor (B,)
        Raw latent score from the projection head.
    predictions : torch.Tensor (B,)
        Predicted class index — argmax of ``class_probs``.
    cum_probs : torch.Tensor (B, K-1)
        Cumulative probabilities P(Y ≤ j | x).
    class_probs : torch.Tensor (B, K)
        Per-class probabilities P(Y = j | x).
    """

    loss: Optional[torch.Tensor] = None
    logits: Optional[torch.Tensor] = None
    predictions: Optional[torch.Tensor] = None
    cum_probs: Optional[torch.Tensor] = None
    class_probs: Optional[torch.Tensor] = None
|
|
|
|
| |
| |
| |
|
|
class BertOrdinal(PreTrainedModel):
    """
    BERT encoder with an ordinal-regression head.

    Fully compatible with the HuggingFace checkpoint API::

        model.save_pretrained("my-checkpoint/")
        model = BertOrdinal.from_pretrained("my-checkpoint/")

    What gets saved
    ~~~~~~~~~~~~~~~
    ``save_pretrained`` writes two files:

    * ``config.json`` — the full :class:`BertOrdinalConfig` (including
      ``bert_model_name``, ``hidden_size``, thresholds shape, …).
    * ``model.safetensors`` (or ``pytorch_model.bin``) — a **single flat
      state_dict** containing both the BERT backbone weights and the
      head/threshold parameters.

    ``from_pretrained`` reconstructs the model from the config (which
    already has ``hidden_size`` cached), loads the state_dict, and
    re-applies the ``freeze_bert`` setting.

    NOTE(review): ``__init__`` loads the backbone with
    ``AutoModel.from_pretrained(config.bert_model_name)``, so even inside
    ``from_pretrained`` the backbone is first resolved by name (weights are
    then overwritten by the checkpoint state_dict). Truly offline reloads
    therefore rely on a warm HF cache for that name — confirm if strict
    offline operation is required.
    """

    # Tells PreTrainedModel which config class save/from_pretrained uses.
    config_class = BertOrdinalConfig

    def __init__(self, config: BertOrdinalConfig) -> None:
        super().__init__(config)
        K = config.num_classes  # number of ordinal classes (K >= 2 assumed)

        # Backbone encoder, resolved by model name.
        self.bert = AutoModel.from_pretrained(config.bert_model_name)
        hidden_size: int = self.bert.config.hidden_size

        # Cache the backbone width on our own config so a reload can size
        # the head straight from config.json.
        config.hidden_size = hidden_size

        # Optionally freeze the backbone so only head/thresholds train.
        if config.freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

        # Projection head: [CLS] representation -> single latent score.
        self.head = nn.Sequential(
            nn.Linear(hidden_size, config.hidden_dim),
            nn.ReLU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_dim, 1),
        )
        self._init_head()

        # K-1 raw parameters; monotone thresholds are produced as
        # cumsum(softplus(raw)) — see the ``thresholds`` property.
        # Initialization inverts softplus (log(expm1(.))) on the successive
        # differences of linspace(-1, 1) targets.
        # NOTE(review): cumsum(softplus(.)) is strictly positive, so the
        # first target (-1.0) is clamped to 1e-3 and the initial theta_1 is
        # ~1e-3, not -1; only the *spacing* of the targets survives —
        # confirm this is intended.
        self.raw_thresholds = nn.Parameter(torch.zeros(K - 1))
        with torch.no_grad():
            targets = torch.linspace(-1.0, 1.0, K - 1)
            diffs = torch.cat([targets[:1], targets[1:] - targets[:-1]])
            self.raw_thresholds.copy_(
                torch.log(torch.expm1(diffs.clamp(min=1e-3)))
            )

        # Standard HF post-construction hook (weight init / tying).
        self.post_init()

    def _init_head(self) -> None:
        # Kaiming init matches the ReLU nonlinearity in the head.
        for m in self.head.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                nn.init.zeros_(m.bias)

    @property
    def thresholds(self) -> torch.Tensor:
        """Monotone thresholds θ₁ ≤ … ≤ θ_{K-1} (shape: K-1)."""
        # softplus(.) > 0, so the cumulative sum is strictly increasing.
        return torch.cumsum(F.softplus(self.raw_thresholds), dim=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> BertOrdinalOutput:
        """
        Score a batch with the cumulative-link ordinal head.

        Parameters
        ----------
        input_ids : (B, L)
        attention_mask : (B, L)
        token_type_ids : (B, L) optional
        labels : (B,) long — class indices in {0, …, K-1}

        Returns
        -------
        BertOrdinalOutput
            ``loss`` is populated only when ``labels`` is given.
        """
        # Forward token_type_ids only when supplied, since not every
        # AutoModel backbone accepts that argument.
        bert_kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        if token_type_ids is not None:
            bert_kwargs["token_type_ids"] = token_type_ids

        # [CLS] token representation: (B, hidden_size).
        cls_repr = self.bert(**bert_kwargs).last_hidden_state[:, 0, :]

        # Scalar latent score s: (B,).
        score = self.head(cls_repr).squeeze(-1)

        # Cumulative-link: P(Y <= j | x) = sigmoid(theta_j - s), (B, K-1).
        cum_logits = self.thresholds.unsqueeze(0) - score.unsqueeze(1)
        cum_probs = torch.sigmoid(cum_logits)

        # Per-class probabilities as differences of the CDF padded with
        # F(-inf)=0 and F(+inf)=1; clamp keeps log() finite in the loss.
        B, dev = cum_probs.size(0), cum_probs.device
        F_ = torch.cat(
            [torch.zeros(B, 1, device=dev), cum_probs, torch.ones(B, 1, device=dev)],
            dim=1,
        )
        class_probs = (F_[:, 1:] - F_[:, :-1]).clamp(min=1e-9)

        # Most probable class per example.
        predictions = class_probs.argmax(dim=-1)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            loss = ordinal_cross_entropy(
                class_probs, labels, reduction=self.config.loss_reduction
            )

        return BertOrdinalOutput(
            loss=loss,
            logits=score,
            predictions=predictions,
            cum_probs=cum_probs,
            class_probs=class_probs,
        )

    @torch.no_grad()
    def predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return predicted class indices (no loss computed)."""
        return self.forward(input_ids, attention_mask, token_type_ids).predictions
|
|
|
|
| |
| |
| |
|
|
def ordinal_cross_entropy(
    class_probs: torch.Tensor,
    labels: torch.Tensor,
    reduction: str = "mean",
) -> torch.Tensor:
    """
    Ordinal (cumulative-link) cross-entropy: NLL of the per-class probabilities.

    Parameters
    ----------
    class_probs : (B, K)
        Per-class probabilities P(Y=j|x). Callers are expected to pass
        values already clamped > 0, but they are re-clamped here so that a
        stray zero cannot produce -inf loss / NaN gradients.
    labels : (B,)
        Ground-truth indices in {0, …, K-1}.
    reduction : 'mean' | 'sum' | 'none'
        Forwarded to :func:`torch.nn.functional.nll_loss`.

    Returns
    -------
    torch.Tensor
        Scalar for 'mean'/'sum'; shape (B,) for 'none'.
    """
    # Defensive clamp before the log: log(0) = -inf would silently poison
    # training if an unclamped distribution sneaks in.
    log_probs = torch.log(class_probs.clamp(min=1e-9))
    return F.nll_loss(log_probs, labels, reduction=reduction)
|
|