harryrobert commited on
Commit
8c4ee68
·
verified ·
1 Parent(s): cc2ddbf

force update

Browse files
Files changed (1) hide show
  1. model_latex_decoder.py +199 -0
model_latex_decoder.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # update
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from typing import Optional
7
+
8
+ from transformers import PreTrainedModel
9
+ from transformers.modeling_outputs import CausalLMOutput
10
+
11
+ from .configuration_latex_decoder import LaTeXDecoderConfig
12
+
13
+
14
+ class RMSNorm(nn.Module):
15
+ def __init__(self, d_model: int, eps: float = 1e-6):
16
+ super().__init__()
17
+ self.eps = eps
18
+ self.weight = nn.Parameter(torch.ones(d_model))
19
+
20
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
21
+ rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
22
+ return x / rms * self.weight
23
+
24
+
25
+ def _build_rope_cache(seq_len, head_dim, theta=10000.0, device=None, dtype=torch.float32):
26
+ half = head_dim // 2
27
+ inv_freq = 1.0 / (theta ** (torch.arange(0, half, device=device, dtype=torch.float32) / half))
28
+ pos = torch.arange(seq_len, device=device, dtype=torch.float32)
29
+ freqs = torch.outer(pos, inv_freq)
30
+ emb = torch.cat([freqs, freqs], dim=-1)
31
+ return emb.cos().to(dtype), emb.sin().to(dtype)
32
+
33
+
34
+ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
35
+ half = x.shape[-1] // 2
36
+ x1, x2 = x[..., :half], x[..., half:]
37
+ return torch.cat([-x2, x1], dim=-1)
38
+
39
+
40
+ def apply_rope(q, k, cos, sin):
41
+ cos = cos.unsqueeze(0).unsqueeze(0)
42
+ sin = sin.unsqueeze(0).unsqueeze(0)
43
+ return q * cos + _rotate_half(q) * sin, k * cos + _rotate_half(k) * sin
44
+
45
+
46
+ class CausalSelfAttention(nn.Module):
47
+ def __init__(self, cfg: LaTeXDecoderConfig):
48
+ super().__init__()
49
+ self.n_heads = cfg.n_heads
50
+ self.head_dim = cfg.head_dim
51
+ self.d_model = cfg.d_model
52
+ self.dropout_p = cfg.dropout
53
+ self.rope_theta = cfg.rope_theta
54
+
55
+ self.qkv_proj = nn.Linear(cfg.d_model, 3 * cfg.d_model, bias=False)
56
+ self.out_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
57
+ self._rope_cache: dict = {}
58
+
59
+ def _get_rope(self, seq_len, device, dtype):
60
+ key = (seq_len, str(device), dtype)
61
+ if key not in self._rope_cache:
62
+ self._rope_cache[key] = _build_rope_cache(seq_len, self.head_dim, self.rope_theta, device, dtype)
63
+ return self._rope_cache[key]
64
+
65
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
66
+ B, T, C = x.shape
67
+ q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
68
+
69
+ q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
70
+ k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
71
+ v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
72
+
73
+ cos, sin = self._get_rope(T, x.device, q.dtype)
74
+ q, k = apply_rope(q, k, cos, sin)
75
+
76
+ dropout_p = self.dropout_p if self.training else 0.0
77
+
78
+ if attention_mask is not None:
79
+ causal = torch.triu(torch.full((T, T), float("-inf"), device=x.device, dtype=q.dtype), diagonal=1)
80
+ pad = (~attention_mask).unsqueeze(1).unsqueeze(2)
81
+ attn_bias = causal.unsqueeze(0).unsqueeze(0).expand(B, 1, T, T).clone()
82
+ attn_bias = attn_bias.masked_fill(pad, float("-inf"))
83
+ out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=dropout_p, is_causal=False)
84
+ else:
85
+ out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p, is_causal=True)
86
+
87
+ return self.out_proj(out.transpose(1, 2).contiguous().view(B, T, C))
88
+
89
+
90
+ class SwiGLUFFN(nn.Module):
91
+ def __init__(self, cfg: LaTeXDecoderConfig):
92
+ super().__init__()
93
+ self.gate_proj = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
94
+ self.up_proj = nn.Linear(cfg.d_model, cfg.d_ff, bias=False)
95
+ self.down_proj = nn.Linear(cfg.d_ff, cfg.d_model, bias=False)
96
+ self.dropout = nn.Dropout(cfg.dropout)
97
+
98
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
99
+ return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))
100
+
101
+
102
+ class TransformerBlock(nn.Module):
103
+ def __init__(self, cfg: LaTeXDecoderConfig):
104
+ super().__init__()
105
+ self.norm1 = RMSNorm(cfg.d_model)
106
+ self.attn = CausalSelfAttention(cfg)
107
+ self.norm2 = RMSNorm(cfg.d_model)
108
+ self.ffn = SwiGLUFFN(cfg)
109
+ self.drop = nn.Dropout(cfg.dropout)
110
+
111
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
112
+ x = x + self.drop(self.attn(self.norm1(x), attention_mask))
113
+ x = x + self.drop(self.ffn(self.norm2(x)))
114
+ return x
115
+
116
+
117
+ class LaTeXDecoderForCausalLM(PreTrainedModel):
118
+ config_class = LaTeXDecoderConfig
119
+ base_model_prefix = "model"
120
+ supports_gradient_checkpointing = False
121
+
122
+ def __init__(self, config: LaTeXDecoderConfig):
123
+ super().__init__(config)
124
+
125
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_id)
126
+ self.embed_drop = nn.Dropout(config.dropout)
127
+ self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
128
+ self.norm_final = RMSNorm(config.d_model)
129
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
130
+
131
+ if config.tie_weights:
132
+ self.lm_head.weight = self.embed_tokens.weight
133
+
134
+ self.post_init()
135
+
136
+ def _init_weights(self, module: nn.Module):
137
+ if isinstance(module, nn.Linear):
138
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
139
+ if module.bias is not None:
140
+ nn.init.zeros_(module.bias)
141
+ elif isinstance(module, nn.Embedding):
142
+ nn.init.normal_(module.weight, mean=0.0, std=0.02)
143
+
144
+ def forward(
145
+ self,
146
+ input_ids: torch.Tensor,
147
+ attention_mask: Optional[torch.Tensor] = None,
148
+ labels: Optional[torch.Tensor] = None,
149
+ **kwargs,
150
+ ) -> CausalLMOutput:
151
+ x = self.embed_drop(self.embed_tokens(input_ids))
152
+ for layer in self.layers:
153
+ x = layer(x, attention_mask)
154
+ logits = self.lm_head(self.norm_final(x))
155
+
156
+ loss = None
157
+ if labels is not None:
158
+ shift_logits = logits[:, :-1, :].contiguous()
159
+ shift_labels = labels[:, 1:].contiguous()
160
+ shift_labels = shift_labels.masked_fill(shift_labels == self.config.pad_id, -100)
161
+ loss = F.cross_entropy(
162
+ shift_logits.view(-1, self.config.vocab_size),
163
+ shift_labels.view(-1),
164
+ ignore_index=-100,
165
+ )
166
+
167
+ return CausalLMOutput(loss=loss, logits=logits)
168
+
169
+ @torch.inference_mode()
170
+ def generate(
171
+ self,
172
+ prompt_ids: torch.Tensor,
173
+ max_new_tokens: int = 200,
174
+ temperature: float = 1.0,
175
+ top_p: float = 0.9,
176
+ eos_id: Optional[int] = None,
177
+ ) -> torch.Tensor:
178
+ eos = eos_id if eos_id is not None else self.config.eos_id
179
+ generated = prompt_ids.clone()
180
+
181
+ for _ in range(max_new_tokens):
182
+ ctx = generated[:, -self.config.max_seq_len:]
183
+ logits = self.forward(ctx).logits[:, -1, :]
184
+
185
+ if temperature == 0.0:
186
+ next_id = logits.argmax(dim=-1, keepdim=True)
187
+ else:
188
+ probs = F.softmax(logits / temperature, dim=-1)
189
+ sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
190
+ cumsum = sorted_probs.cumsum(dim=-1)
191
+ sorted_probs[cumsum - sorted_probs > top_p] = 0.0
192
+ sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
193
+ next_id = sorted_idx.gather(-1, torch.multinomial(sorted_probs, 1))
194
+
195
+ generated = torch.cat([generated, next_id], dim=-1)
196
+ if next_id.item() == eos:
197
+ break
198
+
199
+ return generated