"""
DomainTransformer Model — GPT-style causal decoder for domain token sequences.

Architecture follows:
  - NoPE (no positional encoding) — Kazemnejad et al. 2023 (arXiv:2305.19466)
  - Pre-norm (LayerNorm before attention and FFN) — GPT-2 style
  - F.scaled_dot_product_attention with is_causal=True — auto FlashAttention
  - Weight tying between token embedding and LM head
  - Scaled residual initialization: 1/sqrt(2*N_layers)

Reference sizes (Nubank nuFormer, arXiv:2507.23267):
  - 24M:  6 layers, d=512, 8 heads
  - 330M: 24 layers, d=1024, 16 heads
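
Example usage (a sketch; the constructor kwargs are assumed to match the config
attribute names read elsewhere in this module, and the 2048 FFN width is an
assumed 4x hidden_size):

    config = DomainTransformerConfig(
        vocab_size=1024, hidden_size=512, num_hidden_layers=6,
        num_attention_heads=8, intermediate_size=2048,
    )
    model = DomainTransformerForCausalLM(config)
    ids = torch.randint(0, config.vocab_size, (2, 16))
    out = model(input_ids=ids, labels=ids)        # out.loss, out.logits
    user_vec = model.get_user_embedding(ids)      # (batch, hidden_size)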
"""

import math
from typing import Optional

import torch
import torch.utils.checkpoint  # ensure the checkpoint submodule is loaded for gradient checkpointing
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

from .configuration import DomainTransformerConfig


class DomainTransformerAttention(nn.Module):
    """Multi-head self-attention with NoPE.

    Uses F.scaled_dot_product_attention for automatic FlashAttention/SDPA dispatch.
    No positional encoding. Causal masking comes from is_causal=True, or from an
    explicit causal-plus-padding mask when an attention_mask is provided.
    """

    def __init__(self, config: DomainTransformerConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.scaling = self.head_dim ** -0.5

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.attn_dropout = config.attention_probs_dropout_prob

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        B, T, C = hidden_states.shape
        q = self.q_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Convert HF-style attention_mask (1=attend, 0=ignore) to a boolean SDPA mask.
        # is_causal must be False whenever an explicit mask is passed, so the causal
        # constraint is folded into that mask to keep the decoder autoregressive.
        sdpa_mask = None
        use_causal = True
        if attention_mask is not None:
            key_padding = attention_mask[:, None, None, :].to(torch.bool)     # (B, 1, 1, T)
            causal = torch.ones(T, T, device=q.device).tril().to(torch.bool)  # (T, T)
            sdpa_mask = key_padding & causal  # broadcasts to (B, 1, T, T); True = attend
            use_causal = False

        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=sdpa_mask,
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=use_causal, scale=self.scaling,
        )
        attn_out = attn_out.transpose(1, 2).contiguous().reshape(B, T, C)
        return self.out_proj(attn_out)


class DomainTransformerMLP(nn.Module):
    """Two-layer FFN with GELU activation (GPT-2 style)."""

    def __init__(self, config: DomainTransformerConfig):
        super().__init__()
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
        self.act = nn.GELU(approximate="tanh")
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.down_proj(self.act(self.up_proj(hidden_states))))


class DomainTransformerBlock(nn.Module):
    """Single transformer block with pre-norm architecture."""

    def __init__(self, config: DomainTransformerConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attn = DomainTransformerAttention(config)
        self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = DomainTransformerMLP(config)

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.attn(self.ln_1(hidden_states), attention_mask)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.mlp(self.ln_2(hidden_states))
        hidden_states = residual + hidden_states
        return hidden_states


class DomainTransformerPreTrainedModel(PreTrainedModel):
    """Base class with weight initialization."""
    config_class = DomainTransformerConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def _init_weights(self, module: nn.Module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.padding_idx is not None:
                nn.init.zeros_(module.weight[module.padding_idx])
        elif isinstance(module, nn.LayerNorm):
            nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)
        # Scaled residual initialization (see module docstring): projections feeding
        # the residual stream are re-initialized with std / sqrt(2 * num_layers).
        for name, param in module.named_parameters():
            if name in ("out_proj.weight", "down_proj.weight"):
                nn.init.normal_(param, mean=0.0, std=std / math.sqrt(2 * self.config.num_hidden_layers))


class DomainTransformerModel(DomainTransformerPreTrainedModel):
    """The bare DomainTransformer: embeddings + blocks + final layernorm."""

    def __init__(self, config: DomainTransformerConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.blocks = nn.ModuleList([DomainTransformerBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(self, input_ids=None, attention_mask=None, inputs_embeds=None, **kwargs):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = self.embed_dropout(inputs_embeds)
        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                hidden_states = torch.utils.checkpoint.checkpoint(block, hidden_states, attention_mask, use_reentrant=False)
            else:
                hidden_states = block(hidden_states, attention_mask)
        hidden_states = self.ln_f(hidden_states)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states)


class DomainTransformerForCausalLM(DomainTransformerPreTrainedModel):
    """DomainTransformer with a causal language modeling head.

    The LM head is weight-tied with the token embedding layer.
    Loss is computed via standard shifted cross-entropy.
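    Labels follow the usual HF convention: pass labels aligned with input_ids
    (positions to ignore set to -100); the one-token shift happens internally.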
    """
    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: DomainTransformerConfig):
        super().__init__(config)
        self.model = DomainTransformerModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(self, input_ids=None, attention_mask=None, labels=None, inputs_embeds=None, **kwargs):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds)
        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1), ignore_index=-100)

        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def get_user_embedding(self, input_ids, attention_mask=None):
        """Extract user-level embedding from the last non-padding token."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        if attention_mask is not None:
            seq_lengths = attention_mask.sum(dim=1) - 1
            batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
            return hidden_states[batch_idx, seq_lengths]
        else:
            return hidden_states[:, -1, :]


DomainTransformerConfig.register_for_auto_class()
DomainTransformerForCausalLM.register_for_auto_class("AutoModelForCausalLM")
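

# Minimal smoke test (a sketch, not part of the library API). It assumes that
# DomainTransformerConfig accepts the standard HF-style kwargs used throughout
# this module; sizes follow the 24M reference (6 layers, d=512, 8 heads) with an
# illustrative vocab size and an assumed 4x FFN width.
if __name__ == "__main__":
    config = DomainTransformerConfig(
        vocab_size=1024,
        hidden_size=512,
        num_hidden_layers=6,
        num_attention_heads=8,
        intermediate_size=2048,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        layer_norm_eps=1e-5,
        initializer_range=0.02,
    )
    model = DomainTransformerForCausalLM(config)

    # Random batch of domain token ids, with right padding on the second row.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    attention_mask[1, 12:] = 0
    labels = input_ids.masked_fill(attention_mask == 0, -100)  # ignore padded positions in the loss

    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print("loss:", out.loss.item(), "logits:", tuple(out.logits.shape))

    user_emb = model.get_user_embedding(input_ids, attention_mask=attention_mask)
    print("user embedding:", tuple(user_emb.shape))  # (batch, hidden_size)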