# sentiment-transformer / configuration_sentiment_transformer.py
# Uploaded by Impulse2000 — "Upload sentiment-transformer model" (commit 21035f8, verified)
"""
Hugging Face configuration for the Sentiment Transformer.
This file is **self-contained** — it has no dependency on the project's
``config.py`` or ``config.toml``. It is copied verbatim into every HF
export directory so that ``AutoConfig.from_pretrained()`` works with
``trust_remote_code=True``.
"""
from __future__ import annotations
from transformers import PretrainedConfig
class SentimentTransformerConfig(PretrainedConfig):
    """Configuration for the custom sentiment transformer encoder classifier.

    Bridges the project's internal hyperparameter names and the canonical
    Hugging Face field names so that ``AutoConfig`` / ``AutoModel`` can
    resolve this model via ``trust_remote_code=True``.

    Parameters
    ----------
    vocab_size : int
        Size of the BPE vocabulary.
    hidden_size : int
        Embedding / hidden dimension of the transformer.
    intermediate_size : int
        Inner (expanded) dimension of the position-wise FFN.
    num_hidden_layers : int
        Number of stacked transformer encoder blocks.
    num_attention_heads : int
        Number of parallel attention heads.
    max_position_embeddings : int
        Maximum supported input sequence length.
    hidden_dropout_prob : float
        Dropout probability used throughout the model.
    num_labels : int
        Number of output classes (2 for binary, 3 for ternary, etc.).
    pad_token_id : int
        Token id used for padding.
    id2label, label2id : dict | None
        Optional label maps; when ``id2label`` is given it is treated as
        the source of truth for the class count.
    """

    model_type = "sentiment-transformer"

    def __init__(
        self,
        vocab_size: int = 16_000,
        hidden_size: int = 256,
        intermediate_size: int = 1024,
        num_hidden_layers: int = 6,
        num_attention_heads: int = 8,
        max_position_embeddings: int = 256,
        hidden_dropout_prob: float = 0.1,
        num_labels: int = 2,
        pad_token_id: int = 0,
        id2label: dict[int, str] | None = None,
        label2id: dict[str, int] | None = None,
        **kwargs,
    ) -> None:
        # A round-tripped config.json can carry both `id2label` and
        # `num_labels`, and the HF base class quietly defaults the
        # latter to 2, clobbering the saved label map. Treat the label
        # map as authoritative and recompute the count from it.
        if id2label is not None and num_labels != len(id2label):
            num_labels = len(id2label)

        # Respect a `problem_type` already present in a deserialized
        # config (avoids a duplicate-kwarg error); otherwise default to
        # single-label classification.
        kwargs.setdefault("problem_type", "single_label_classification")

        super().__init__(
            pad_token_id=pad_token_id,
            num_labels=num_labels,
            id2label=id2label,
            label2id=label2id,
            **kwargs,
        )

        # Architecture hyperparameters, stored under canonical HF names.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dropout_prob = hidden_dropout_prob