shreyask committed on
Commit
5479f24
·
verified ·
1 Parent(s): 162baf0

Upload needle_torch/model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. needle_torch/model.py +371 -0
needle_torch/model.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Needle Simple Attention Network — PyTorch port.
2
+
3
+ Encoder, Decoder, NeedleModel — parametric on TransformerConfig.
4
+
5
+ Key design decisions:
6
+ - No FFN (no_feedforward=True is the production default; we never implement it).
7
+ - ZCRMSNorm, GQA, RoPE all match architecture.py line-for-line.
8
+ - Decoder.step() is ONNX-traceable: no data-dependent control flow.
9
+ - Tied embedding: decoder logits = hidden @ embedding.weight.T
10
+ """
11
+
12
+ import math
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from .config import TransformerConfig
18
+ from .layers import ZCRMSNorm, RoPE, MultiHeadAttention, make_causal_mask
19
+
20
+
21
+ # ---------------------------------------------------------------------------
22
+ # EncoderBlock
23
+ # ---------------------------------------------------------------------------
24
+
25
class EncoderBlock(nn.Module):
    """One encoder layer: pre-normalised self-attention on a gated residual.

    Intended to mirror the Flax ``EncoderBlock.__call__``. In order:
        normed = ZCRMSNorm(x)
        attn   = self_attn(normed, normed, ...)
        out    = x + sigmoid(attn_gate) * attn
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        # Learnable 0-dim gate. It starts at zero, so the sigmoid initially
        # lets half of the attention output through (sigmoid(0) = 0.5).
        self.attn_gate = nn.Parameter(torch.zeros(()))
        self.norm = ZCRMSNorm(config.d_model)
        self.self_attn = MultiHeadAttention(config, is_cross_attn=False, is_causal=False)

    def forward(self, x: torch.Tensor, mask=None, rope=None):
        """Apply the gated self-attention update to ``x``.

        Args:
            x:    (B, T, d_model) input hidden states.
            mask: (B, 1, T, T) bool attention mask, forwarded unchanged to the
                  attention module.
            rope: (cos, sin) pair taken from the RoPE buffers.

        Returns:
            Tensor shaped like ``x``: the input plus the gated attention output.
        """
        normed = self.norm(x)
        # The attention module also returns a KV pair; the encoder has no use
        # for a cache, so it is dropped.
        attn_out, _ = self.self_attn(normed, normed, mask=mask, rope=rope)
        return x + torch.sigmoid(self.attn_gate) * attn_out
54
+
55
+
56
+ # ---------------------------------------------------------------------------
57
+ # DecoderBlock
58
+ # ---------------------------------------------------------------------------
59
+
60
class DecoderBlock(nn.Module):
    """One decoder layer: causal self-attention, then cross-attention.

    Each sub-layer is pre-normalised and added back through its own
    sigmoid-gated residual. Intended to mirror the Flax
    ``DecoderBlock.__call__``:

        x = x + sigmoid(self_attn_gate)  * self_attn(ZCRMSNorm(x), ZCRMSNorm(x))
        x = x + sigmoid(cross_attn_gate) * cross_attn(ZCRMSNorm(x), encoder_out)
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        # Two independent learnable 0-dim gates, both starting at zero.
        self.self_attn_gate = nn.Parameter(torch.zeros(()))
        self.cross_attn_gate = nn.Parameter(torch.zeros(()))

        # Pre-norms (ZCRMSNorm_0 / ZCRMSNorm_1 in the source param tree):
        # the first feeds self-attention, the second feeds cross-attention.
        self.self_norm = ZCRMSNorm(config.d_model)
        self.cross_norm = ZCRMSNorm(config.d_model)

        self.self_attn = MultiHeadAttention(config, is_cross_attn=False, is_causal=True)
        self.cross_attn = MultiHeadAttention(config, is_cross_attn=True, is_causal=False)

    def forward(
        self,
        x: torch.Tensor,
        encoder_out: torch.Tensor,
        self_mask=None,
        cross_mask=None,
        rope=None,
        past_self_kv=None,
    ):
        """Run both attention sub-layers.

        Args:
            x:            (B, T_dec, d_model) decoder hidden states.
            encoder_out:  (B, T_enc, d_model) encoder hidden states.
            self_mask:    (B, 1, T_dec, T_total) bool mask for self-attention.
            cross_mask:   (B, 1, T_dec, T_enc) bool mask for cross-attention.
            rope:         (cos, sin) pair; only passed to self-attention.
            past_self_kv: optional (k, v) cache, each
                          (B, num_kv_heads, past_T, head_dim).

        Returns:
            Tuple ``(x, present_self_kv)`` where ``x`` is (B, T_dec, d_model)
            and ``present_self_kv`` is the (k, v) pair returned by the
            self-attention module, each (B, num_kv_heads, T_total, head_dim).
        """
        # Sub-layer 1: causal self-attention, optionally extending a KV cache.
        normed = self.self_norm(x)
        self_out, present_self_kv = self.self_attn(
            normed, normed, mask=self_mask, rope=rope, past_kv=past_self_kv
        )
        x = x + torch.sigmoid(self.self_attn_gate) * self_out

        # Sub-layer 2: cross-attention over the encoder output. No RoPE and no
        # cache are passed here, and the returned KV pair is discarded.
        cross_out, _ = self.cross_attn(self.cross_norm(x), encoder_out, mask=cross_mask)
        x = x + torch.sigmoid(self.cross_attn_gate) * cross_out

        return x, present_self_kv
123
+
124
+
125
+ # ---------------------------------------------------------------------------
126
+ # Encoder
127
+ # ---------------------------------------------------------------------------
128
+
129
class Encoder(nn.Module):
    """Token embedding, a stack of EncoderBlocks, and a closing ZCRMSNorm.

    Produces encoder hidden states of shape (B, T_enc, d_model).
    The embedding table is not created here: it is shared with the Decoder
    and must be assigned to ``.embedding`` by the owning model.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        # Placeholder until NeedleModel attaches the shared embedding.
        self.embedding: nn.Embedding | None = None
        # Embeddings are multiplied by sqrt(d_model) before the first layer.
        self.embed_scale = math.sqrt(config.d_model)

        blocks = [EncoderBlock(config) for _ in range(config.num_encoder_layers)]
        self.layers = nn.ModuleList(blocks)
        self.final_norm = ZCRMSNorm(config.d_model)

        # Rotary tables are sized per attention head.
        self.rope = RoPE(
            config.d_model // config.num_heads,
            config.max_seq_len,
            config.rope_theta,
        )

    def forward(self, input_ids: torch.Tensor, mask=None) -> torch.Tensor:
        """Encode a batch of token ids.

        Args:
            input_ids: (B, T_enc) long tensor of token ids.
            mask:      optional (B, 1, 1, T_enc) bool padding mask, handed to
                       every layer as-is.

        Returns:
            (B, T_enc, d_model) normalised hidden states.
        """
        assert self.embedding is not None, "Encoder.embedding must be set by NeedleModel"
        hidden = self.embedding(input_ids) * self.embed_scale

        # One (cos, sin) pair covering the whole sequence, reused by all layers.
        cos_table, sin_table = self.rope.get_cos_sin(input_ids.size(1))
        rope = (cos_table, sin_table)

        for block in self.layers:
            hidden = block(hidden, mask=mask, rope=rope)

        return self.final_norm(hidden)
170
+
171
+
172
+ # ---------------------------------------------------------------------------
173
+ # Decoder
174
+ # ---------------------------------------------------------------------------
175
+
176
class Decoder(nn.Module):
    """Embedding lookup + N DecoderBlocks + final ZCRMSNorm + tied LM head.

    The LM head has no weights of its own: logits = hidden @ embedding.weight.T.
    The embedding is shared with the Encoder and must be assigned to
    ``.embedding`` by the owning model after construction.
    """

    def __init__(self, config: "TransformerConfig"):
        super().__init__()
        self.config = config
        # Embedding is shared; set by NeedleModel after construction.
        self.embedding: nn.Embedding | None = None
        # Embeddings are multiplied by sqrt(d_model) before the first layer.
        self.embed_scale = math.sqrt(config.d_model)

        self.layers = nn.ModuleList([
            DecoderBlock(config) for _ in range(config.num_decoder_layers)
        ])
        # Final norm applied after all decoder layers.
        self.final_norm = ZCRMSNorm(config.d_model)

        head_dim = config.d_model // config.num_heads
        self.rope = RoPE(head_dim, config.max_seq_len, config.rope_theta)

    def _tied_logits(self, hidden: torch.Tensor) -> torch.Tensor:
        """Project hidden states onto the vocabulary with the tied embedding.

        Both operands are upcast to float32. Upcasting only the activations
        raises a dtype-mismatch error when the model weights are half or
        bfloat16; for a float32 model the casts are no-ops.

        Args:
            hidden: (B, T, d_model) normalised decoder states.

        Returns:
            (B, T, vocab_size) float32 logits.
        """
        return hidden.float() @ self.embedding.weight.float().T

    def forward(
        self,
        input_ids: torch.Tensor,
        encoder_out: torch.Tensor,
        self_mask=None,
        cross_mask=None,
    ) -> torch.Tensor:
        """Full-sequence decode (training / teacher-forcing), without a KV cache.

        Args:
            input_ids:   (B, T_dec) long
            encoder_out: (B, T_enc, d_model)
            self_mask:   (B, 1, T_dec, T_dec) bool causal mask
            cross_mask:  (B, 1, T_dec, T_enc) bool

        Returns:
            logits: (B, T_dec, vocab_size), float32
        """
        assert self.embedding is not None, "Decoder.embedding must be set by NeedleModel"
        x = self.embedding(input_ids) * self.embed_scale

        T = input_ids.shape[1]
        cos, sin = self.rope.get_cos_sin(T)
        rope = (cos, sin)

        for layer in self.layers:
            # The per-layer KV pair is not needed when decoding a full sequence.
            x, _ = layer(x, encoder_out, self_mask=self_mask, cross_mask=cross_mask,
                         rope=rope, past_self_kv=None)

        x = self.final_norm(x)
        return self._tied_logits(x)

    # ------------------------------------------------------------------
    # Autoregressive step — the entry point for ONNX export
    # ------------------------------------------------------------------

    def initial_past_kv(self, batch: int = 1, device=None, dtype=None) -> torch.Tensor:
        """Return an empty past-KV tensor for the first call to ``step``.

        Shape: (num_decoder_layers, 2, batch, num_kv_heads, 0, head_dim)

        The sequence dimension has length 0, so concatenating the first
        step's K/V onto it yields just that step's K/V.

        Args:
            batch:  batch size of the cache.
            device: device for the cache. Defaults to the device of the shared
                    embedding weight, so the cache lives where the model does;
                    falls back to the torch default when no embedding is set.
            dtype:  dtype for the cache. Defaults to the dtype of the shared
                    embedding weight; float32 when no embedding is set.

        Returns:
            A zero-element tensor with the shape given above.
        """
        cfg = self.config
        head_dim = cfg.d_model // cfg.num_heads
        if self.embedding is not None:
            weight = self.embedding.weight
            if device is None:
                device = weight.device
            if dtype is None:
                dtype = weight.dtype
        elif dtype is None:
            dtype = torch.float32
        return torch.zeros(
            cfg.num_decoder_layers, 2, batch, cfg.num_kv_heads, 0, head_dim,
            dtype=dtype,
            device=device,
        )

    def step(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_kv: torch.Tensor,
        past_self_kv: torch.Tensor,
    ):
        """Single autoregressive decoder step with an explicit KV cache.

        This is the signature traced by torch.onnx.export.

        Args:
            decoder_input_ids: (B, 1) long — single token per step
            encoder_kv:        (B, T_enc, d_model) — encoder output, passed to
                               every layer's cross-attention unchanged
            past_self_kv:      (num_decoder_layers, 2, B, num_kv_heads, past_T, head_dim)
                               Use initial_past_kv() for the first step.

        Returns:
            logits:     (B, 1, vocab_size), float32
            present_kv: (num_decoder_layers, 2, B, num_kv_heads, past_T+1, head_dim)

        NOTE: Python control flow here depends only on shapes and on the
        number of layers, never on tensor values, which keeps the method
        traceable.
        """
        assert self.embedding is not None, "Decoder.embedding must be set by NeedleModel"

        x = self.embedding(decoder_input_ids) * self.embed_scale  # (B, 1, d_model)

        # The new token sits at position past_T (0-based), so take the single
        # row of the rotary tables at that index.
        past_T = past_self_kv.shape[4]
        cos_full, sin_full = self.rope.get_cos_sin(past_T + 1)
        cos = cos_full[past_T:past_T + 1]
        sin = sin_full[past_T:past_T + 1]
        rope = (cos, sin)

        # Mask for one query over past_T cached positions plus itself;
        # expected shape (1, 1, 1, past_T+1) — defined by make_causal_mask.
        self_mask = make_causal_mask(1, past_T, device=x.device)

        present_layers = []
        for i, layer in enumerate(self.layers):
            # This layer's cached K and V: each (B, num_kv_heads, past_T, head_dim)
            layer_past = (past_self_kv[i, 0], past_self_kv[i, 1])

            x, (k_new, v_new) = layer(
                x, encoder_kv,
                self_mask=self_mask,
                cross_mask=None,
                rope=rope,
                past_self_kv=layer_past,
            )
            # k_new, v_new: (B, num_kv_heads, past_T+1, head_dim)
            present_layers.append(torch.stack([k_new, v_new], dim=0))

        # (num_decoder_layers, 2, B, num_kv_heads, past_T+1, head_dim)
        present_kv = torch.stack(present_layers, dim=0)

        x = self.final_norm(x)
        logits = self._tied_logits(x)  # (B, 1, vocab_size)
        return logits, present_kv
316
+
317
+
318
+ # ---------------------------------------------------------------------------
319
+ # NeedleModel
320
+ # ---------------------------------------------------------------------------
321
+
322
class NeedleModel(nn.Module):
    """Top-level Needle Simple Attention Network (PyTorch port).

    Counterpart of the Flax ``SimpleAttentionNetwork``. Owns the shared token
    embedding, the encoder, the decoder, and the contrastive-head parameters.

    Parameters
    ----------
    config : TransformerConfig
        Architecture hyperparameters. Pass production dims to get the 26M model.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        width = config.d_model

        # A single embedding table feeds the encoder and the decoder, and also
        # acts as the decoder's tied output projection.
        self.embedding = nn.Embedding(config.vocab_size, width)
        nn.init.normal_(self.embedding.weight, std=0.02)

        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

        # Hand the shared table to both consumers.
        for consumer in (self.encoder, self.decoder):
            consumer.embedding = self.embedding

        # Contrastive head, kept so the parameter set lines up with the Flax
        # param tree:
        #   contrastive_hidden: d_model -> d_model // 4, with bias
        #   contrastive_proj:   d_model // 4 -> contrastive_dim, no bias
        bottleneck = width // 4
        self.contrastive_hidden = nn.Linear(width, bottleneck, bias=True)
        self.contrastive_proj = nn.Linear(bottleneck, config.contrastive_dim, bias=False)

        # 0-dim learnable log-temperature for the contrastive objective.
        self.log_temp = nn.Parameter(torch.zeros(()))

    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask=None,
        tgt_mask=None,
        cross_mask=None,
    ) -> torch.Tensor:
        """Run the encoder, then the decoder, over full sequences (training).

        Args:
            src:        source token ids, given to the encoder.
            tgt:        target token ids, given to the decoder.
            src_mask:   optional mask forwarded to the encoder.
            tgt_mask:   optional decoder self-attention mask.
            cross_mask: optional decoder cross-attention mask.

        Returns:
            logits: (B, T_dec, vocab_size)
        """
        memory = self.encoder(src, mask=src_mask)
        return self.decoder(tgt, memory, self_mask=tgt_mask, cross_mask=cross_mask)