""" Hugging Face model definition for the Sentiment Transformer. This file is **self-contained** — it depends only on ``torch`` and ``transformers``. It is copied verbatim into every HF export directory so that ``AutoModelForSequenceClassification.from_pretrained()`` works with ``trust_remote_code=True``. Architecture ------------ Token Embedding + RoPE (Rotary Positional Embedding) -> N x TransformerEncoderBlock (pre-layer-norm, SwiGLU FFN) -> Final LayerNorm -> Mean pooling (masked) -> 2-layer MLP classification head (num_labels-class logits) """ from __future__ import annotations import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel from transformers.modeling_outputs import SequenceClassifierOutput from configuration_sentiment_transformer import SentimentTransformerConfig # --------------------------------------------------------------------------- # Rotary Positional Embedding (RoPE) # --------------------------------------------------------------------------- class RotaryEmbedding(nn.Module): """Precompute and cache the sin/cos frequencies for RoPE. RoPE encodes absolute position through *rotation* applied to pairs of dimensions in Q and K. This gives the dot-product between Q_i and K_j a natural dependence on relative position (i - j) without any learnable parameters. 
""" def __init__(self, head_dim: int, max_seq_len: int, base: float = 10000.0) -> None: super().__init__() assert head_dim % 2 == 0, "head_dim must be even for RoPE" inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) t = torch.arange(max_seq_len).float() freqs = torch.outer(t, inv_freq) self.register_buffer("cos_cached", freqs.cos(), persistent=False) self.register_buffer("sin_cached", freqs.sin(), persistent=False) def forward(self, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]: """Return (cos, sin) each of shape (seq_len, head_dim // 2).""" return self.cos_cached[:seq_len], self.sin_cached[:seq_len] def _apply_rope( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> torch.Tensor: """Apply rotary embedding to a Q or K tensor. Parameters ---------- x : Tensor, shape ``(B, num_heads, S, head_dim)`` cos, sin : Tensor, shape ``(S, head_dim // 2)`` Returns ------- Tensor, same shape as ``x``. """ x1 = x[..., 0::2] # even indices x2 = x[..., 1::2] # odd indices cos = cos.unsqueeze(0).unsqueeze(0) sin = sin.unsqueeze(0).unsqueeze(0) out1 = x1 * cos - x2 * sin out2 = x1 * sin + x2 * cos return torch.stack((out1, out2), dim=-1).flatten(-2) # --------------------------------------------------------------------------- # Building blocks # --------------------------------------------------------------------------- class MultiHeadSelfAttention(nn.Module): """Multi-head self-attention with RoPE and fused SDPA kernel. Automatically dispatches to FlashAttention or Memory-Efficient Attention when running on a compatible GPU. 
""" def __init__( self, hidden_dim: int, num_heads: int, dropout: float, rope: RotaryEmbedding, ) -> None: super().__init__() assert hidden_dim % num_heads == 0, ( f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})" ) self.num_heads = num_heads self.head_dim = hidden_dim // num_heads self.dropout = dropout self.rope = rope self.q_proj = nn.Linear(hidden_dim, hidden_dim) self.k_proj = nn.Linear(hidden_dim, hidden_dim) self.v_proj = nn.Linear(hidden_dim, hidden_dim) self.out_proj = nn.Linear(hidden_dim, hidden_dim) def forward( self, x: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: B, S, H = x.shape q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) # Apply RoPE to Q and K cos, sin = self.rope(S) q = _apply_rope(q, cos, sin) k = _apply_rope(k, cos, sin) attn_mask = None if attention_mask is not None: attn_mask = attention_mask.bool().unsqueeze(1).unsqueeze(2) attn_out = F.scaled_dot_product_attention( q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0.0, ) attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, H) return self.out_proj(attn_out) class SwiGLUFeedForward(nn.Module): """SwiGLU feed-forward network (as used in LLaMA / Gemma). 
SwiGLU(x) = W_down · (SiLU(W_gate · x) ⊙ W_up · x) """ def __init__(self, hidden_dim: int, ffn_dim: int, dropout: float) -> None: super().__init__() inner_dim = int(2 / 3 * ffn_dim) inner_dim = ((inner_dim + 7) // 8) * 8 # round up to multiple of 8 self.w_gate = nn.Linear(hidden_dim, inner_dim, bias=False) self.w_up = nn.Linear(hidden_dim, inner_dim, bias=False) self.w_down = nn.Linear(inner_dim, hidden_dim, bias=False) self.dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))) class TransformerEncoderBlock(nn.Module): """Single transformer encoder block with **pre-layer-norm** and SwiGLU. Pre-LN applies LayerNorm *before* each sub-layer: x = x + Attention(LayerNorm(x)) x = x + SwiGLU_FFN(LayerNorm(x)) """ def __init__( self, hidden_dim: int, num_heads: int, ffn_dim: int, dropout: float, rope: RotaryEmbedding, ) -> None: super().__init__() self.norm1 = nn.LayerNorm(hidden_dim) self.attn = MultiHeadSelfAttention(hidden_dim, num_heads, dropout, rope) self.norm2 = nn.LayerNorm(hidden_dim) self.ffn = SwiGLUFeedForward(hidden_dim, ffn_dim, dropout) self.dropout = nn.Dropout(dropout) def forward( self, x: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: x = x + self.dropout(self.attn(self.norm1(x), attention_mask)) x = x + self.dropout(self.ffn(self.norm2(x))) return x class SentimentTransformerBackbone(nn.Module): """Transformer encoder for sentiment classification. Uses mean pooling over non-padding tokens and a 2-layer MLP classification head. Returns raw logits (no softmax). 
""" def __init__( self, vocab_size: int, hidden_dim: int, ffn_dim: int, num_layers: int, num_heads: int, max_seq_len: int, num_classes: int, dropout: float = 0.1, ) -> None: super().__init__() self.token_embedding = nn.Embedding(vocab_size, hidden_dim, padding_idx=0) self.embedding_dropout = nn.Dropout(dropout) # Shared RoPE module head_dim = hidden_dim // num_heads self.rope = RotaryEmbedding(head_dim, max_seq_len) self.layers = nn.ModuleList([ TransformerEncoderBlock( hidden_dim=hidden_dim, num_heads=num_heads, ffn_dim=ffn_dim, dropout=dropout, rope=self.rope, ) for _ in range(num_layers) ]) self.final_norm = nn.LayerNorm(hidden_dim) # 2-layer MLP classification head self.classifier = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden_dim, num_classes), ) self._init_weights() def _init_weights(self) -> None: """Xavier-uniform for linear layers, normal for embeddings.""" for module in self.modules(): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.padding_idx is not None: with torch.no_grad(): module.weight[module.padding_idx].fill_(0) elif isinstance(module, nn.LayerNorm): nn.init.ones_(module.weight) nn.init.zeros_(module.bias) def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, ) -> torch.Tensor: B, S = input_ids.shape # Token embeddings only — positional information injected via RoPE x = self.embedding_dropout(self.token_embedding(input_ids)) for layer in self.layers: x = layer(x, attention_mask) x = self.final_norm(x) # Mean pooling over non-padding tokens mask = attention_mask.unsqueeze(-1).float() # (B, S, 1) pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9) # (B, H) logits = self.classifier(pooled) return logits # --------------------------------------------------------------------------- # 
# HuggingFace PreTrainedModel wrapper
# ---------------------------------------------------------------------------
class SentimentTransformerForSequenceClassification(PreTrainedModel):
    """HuggingFace-compatible sequence classification wrapper.

    This class bridges the custom transformer backbone with the HF
    ecosystem.  It accepts the standard ``input_ids``, ``attention_mask``,
    and ``labels`` arguments and returns a
    :class:`~transformers.modeling_outputs.SequenceClassifierOutput`.

    Usage::

        from transformers import AutoModelForSequenceClassification, pipeline
        model = AutoModelForSequenceClassification.from_pretrained(
            "path/to/export", trust_remote_code=True
        )
        pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
        pipe("This movie was amazing!")
    """

    config_class = SentimentTransformerConfig
    base_model_prefix = "backbone"
    main_input_name = "input_ids"

    def __init__(self, config: SentimentTransformerConfig) -> None:
        super().__init__(config)
        self.backbone = SentimentTransformerBackbone(
            vocab_size=config.vocab_size,
            hidden_dim=config.hidden_size,
            ffn_dim=config.intermediate_size,
            num_layers=config.num_hidden_layers,
            num_heads=config.num_attention_heads,
            max_seq_len=config.max_position_embeddings,
            num_classes=config.num_labels,
            dropout=config.hidden_dropout_prob,
        )
        self.post_init()

    def _recompute_rope_buffers(self) -> None:
        """Recompute all RoPE sin/cos buffers on the current device.

        HF's ``from_pretrained`` uses meta-device initialization which
        leaves non-persistent buffers as uninitialised memory.  This method
        rebuilds them from scratch after weights are loaded.
        """
        for module in self.modules():
            if not isinstance(module, RotaryEmbedding):
                continue
            device = module.inv_freq.device
            head_dim = module.inv_freq.shape[0] * 2
            # NOTE: the base must match RotaryEmbedding's default (10000.0).
            # It cannot be recovered from inv_freq here because that buffer
            # may itself contain garbage after meta-device loading.
            inv_freq = 1.0 / (
                10000.0
                ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)
            )
            module.inv_freq = inv_freq  # assigns into the registered buffer
            max_seq_len = module.cos_cached.shape[0]
            positions = torch.arange(max_seq_len, device=device).float()
            angles = torch.outer(positions, inv_freq)
            module.cos_cached = angles.cos()
            module.sin_cached = angles.sin()
        self._rope_valid = True

    def _ensure_rope_valid(self) -> None:
        """Lazily recompute RoPE buffers if they were corrupted by HF loading.

        Uninitialised memory is *not* guaranteed to be NaN/inf, so a pure
        finiteness test can accept garbage.  We additionally check the
        mathematical invariant ``|cos| <= 1``, which a valid cosine table
        always satisfies and random memory almost never does.
        """
        if getattr(self, "_rope_valid", False):
            return
        cos = self.backbone.rope.cos_cached
        if cos.isfinite().all() and (cos.abs() <= 1.0).all():
            self._rope_valid = True
        else:
            self._recompute_rope_buffers()

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        **_kwargs,
    ) -> SequenceClassifierOutput | tuple[torch.Tensor, ...]:
        """Run sequence classification and return HF-style outputs.

        Parameters
        ----------
        input_ids : LongTensor, shape ``(B, S)`` -- required.
        attention_mask : Tensor, shape ``(B, S)``; defaults to all-ones.
        labels : LongTensor, shape ``(B,)``; when given, a cross-entropy
            loss is computed and included in the output.
        return_dict : overrides ``config.use_return_dict`` when not None.

        Raises
        ------
        ValueError
            If ``input_ids`` is missing.
        """
        self._ensure_rope_valid()
        if input_ids is None:
            raise ValueError("`input_ids` is required.")
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        logits = self.backbone(input_ids=input_ids, attention_mask=attention_mask)

        loss = None
        if labels is not None:
            # Flatten per the standard HF pattern so both (B,) and (B, 1)
            # label shapes work.
            loss = F.cross_entropy(
                logits.view(-1, self.config.num_labels), labels.view(-1)
            )

        # use_return_dict (not bare return_dict) also accounts for
        # torchscript export, per PretrainedConfig semantics.
        use_return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if not use_return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(loss=loss, logits=logits)