# scibert-certainty-ordinal / configuration_bert_ordinal.py
# Uploaded by Cbelem ("Upload 3 files", commit c1f6b2a, verified).
# NOTE: these header lines are Hugging Face Hub file-viewer residue; they are
# commented out so the module parses as valid Python.
"""
configuration_bert_ordinal.py
-----------------------------
BERT-based ordinal regression model, fully integrated with the HuggingFace
Transformers API:
model.save_pretrained("my-checkpoint/")
model = BertOrdinal.from_pretrained("my-checkpoint/")
Architecture
------------
1. A (optionally frozen) BERT backbone.
2. A projection head on the [CLS] token:
Linear(hidden_size β†’ hidden_dim) β†’ ReLU β†’ Dropout(p) β†’ Linear(hidden_dim β†’ 1)
producing a single latent score s ∈ ℝ.
3. K-1 learnable raw_threshold parameters enforcing monotonicity via
cumsum(softplus(Β·)).
4. Cumulative-link probabilities:
P(Y ≀ j | x) = Οƒ(ΞΈ_j βˆ’ s)
Usage
-----
from configuration_bert_ordinal import BertOrdinalConfig
from modeling_bert_ordinal import BertOrdinal
# ── Create from scratch ──────────────────────────────────────────────────
cfg = BertOrdinalConfig(
bert_model_name="bert-base-uncased",
num_classes=3,
hidden_dim=128,
dropout=0.1,
freeze_bert=True,
)
model = BertOrdinal(cfg)
# ── Save ────────────────────────────────────────────────────────────────
model.save_pretrained("my-checkpoint/")
tokenizer.save_pretrained("my-checkpoint/") # keep tokenizer alongside
# ── Reload ──────────────────────────────────────────────────────────────
model = BertOrdinal.from_pretrained("my-checkpoint/")
tokenizer = AutoTokenizer.from_pretrained("my-checkpoint/")
"""
from __future__ import annotations
from typing import Optional
from transformers import PretrainedConfig
# ---------------------------------------------------------------------------
# 1. Config β€” subclass PretrainedConfig for full HF serialisation
# ---------------------------------------------------------------------------
class BertOrdinalConfig(PretrainedConfig):
    """
    Configuration for :class:`BertOrdinal`.

    Because this inherits from :class:`~transformers.PretrainedConfig`,
    ``save_pretrained`` writes a ``config.json`` that ``from_pretrained``
    can read back without any extra bookkeeping.

    Parameters
    ----------
    bert_model_name : str
        HuggingFace model name or local path for the BERT backbone.
    num_classes : int
        Number of ordinal classes K (must be >= 2). Creates K-1 learnable
        thresholds.
    hidden_dim : int
        Inner dimension of the projection head.
    dropout : float
        Dropout probability inside the projection head; must lie in [0, 1].
    freeze_bert : bool
        Freeze backbone weights at construction time.
    loss_reduction : str
        ``'mean'`` or ``'sum'``.
    hidden_size : int, optional
        Backbone hidden size. Set automatically by the model after loading
        BERT; stored here so ``from_pretrained`` can rebuild the head offline.

    Raises
    ------
    ValueError
        If ``num_classes < 2``, ``dropout`` is outside ``[0, 1]``, or
        ``loss_reduction`` is not ``'mean'`` / ``'sum'``.
    """

    # Tells HF which class owns this config (written into config.json).
    model_type = "bert_ordinal"
    problem_type = "single_label_classification"

    def __init__(
        self,
        bert_model_name: str = "allenai/scibert_scivocab_uncased",
        num_classes: int = 3,
        hidden_dim: int = 256,
        dropout: float = 0.1,
        freeze_bert: bool = True,
        loss_reduction: str = "mean",
        # hidden_size is set automatically by the model after loading BERT;
        # it is stored here so from_pretrained can rebuild the head offline.
        hidden_size: Optional[int] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

        # Fail fast on invalid hyper-parameters instead of surfacing much
        # later (K-1 = 0 thresholds, nn.Dropout range error, or an invalid
        # reduction string inside the loss call).
        if num_classes < 2:
            raise ValueError(
                f"num_classes must be >= 2 for ordinal regression, got {num_classes}"
            )
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {dropout}")
        if loss_reduction not in ("mean", "sum"):
            raise ValueError(
                f"loss_reduction must be 'mean' or 'sum', got {loss_reduction!r}"
            )

        self.bert_model_name = bert_model_name
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.dropout = dropout
        self.freeze_bert = freeze_bert
        self.loss_reduction = loss_reduction
        self.hidden_size = hidden_size  # filled in by BertOrdinal.__init__
        # Lets AutoConfig / AutoModel(ForSequenceClassification) resolve the
        # custom classes when the checkpoint is loaded with
        # trust_remote_code=True.
        self.auto_map = {
            "AutoConfig": "configuration_bert_ordinal.BertOrdinalConfig",
            "AutoModel": "modeling_bert_ordinal.BertOrdinal",
            "AutoModelForSequenceClassification": "modeling_bert_ordinal.BertOrdinal",
        }