ronnengmail commited on
Commit
2e0bc42
·
verified ·
1 Parent(s): e06a0c9

Upload model_arch.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model_arch.py +121 -312
model_arch.py CHANGED
@@ -1,343 +1,152 @@
1
- #!/usr/bin/env python3
2
- """
3
- Multilingual 3B GPT — SFT Training
4
-
5
- Fine-tunes the base model on instruction data (Aya + Bactrian-X + FLORES translations).
6
- Uses the same architecture as pretraining with LoRA-free full fine-tuning
7
- (model is 3B params, fits in 24GB A10G in bf16).
8
-
9
- Usage:
10
- python train_sft_3b.py --checkpoint /path/to/best_model.pt \
11
- --tokenizer /path/to/multilingual_32k.model \
12
- --data-dir /path/to/sft_data/ \
13
- --output /path/to/sft_model.pt
14
- """
15
-
16
- import os, sys, json, math, time, random, argparse
17
- sys.stdout.reconfigure(line_buffering=True)
18
- import gc
19
- import numpy as np
20
  import torch
21
  import torch.nn as nn
22
  import torch.nn.functional as F
23
- import sentencepiece as spm
24
 
25
- # ============ MODEL (must match training) ============
26
  VOCAB_SIZE = 32000
27
  DIM = 3072
28
  DEPTH = 26
29
  N_HEADS = 24
 
30
  MAX_SEQ_LEN = 2048
31
- ROPE_THETA = 10000
 
 
32
 
33
  class RMSNorm(nn.Module):
34
  def __init__(self, dim, eps=1e-6):
35
  super().__init__()
36
- self.weight = nn.Parameter(torch.ones(dim))
37
  self.eps = eps
38
- def forward(self, x):
39
- return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps).type_as(x) * self.weight
40
 
41
- class SwiGLU(nn.Module):
42
- def __init__(self, dim, hidden_dim):
43
- super().__init__()
44
- self.gate = nn.Linear(dim, hidden_dim, bias=False)
45
- self.up = nn.Linear(dim, hidden_dim, bias=False)
46
- self.down = nn.Linear(hidden_dim, dim, bias=False)
47
  def forward(self, x):
48
- return self.down(F.silu(self.gate(x)) * self.up(x))
 
 
 
 
 
 
 
 
 
49
 
50
- def apply_rope(x, cos, sin):
51
- x1, x2 = x[..., ::2], x[..., 1::2]
52
- return torch.stack((x1*cos - x2*sin, x1*sin + x2*cos), dim=-1).flatten(-2)
 
 
 
 
53
 
54
- class Attention(nn.Module):
 
55
  def __init__(self, dim, n_heads):
56
  super().__init__()
57
  self.n_heads = n_heads
58
  self.head_dim = dim // n_heads
59
- self.qkv = nn.Linear(dim, 3*dim, bias=False)
60
- self.proj = nn.Linear(dim, dim, bias=False)
61
- def forward(self, x, cos, sin):
62
- B, T, C = x.shape
63
- qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
64
- q, k, v = qkv[0], qkv[1], qkv[2]
65
- q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
66
- y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
67
- return self.proj(y.transpose(1, 2).contiguous().view(B, T, C))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- class Block(nn.Module):
70
- def __init__(self, dim, n_heads, mlp_dim):
 
 
 
 
71
  super().__init__()
72
- self.ln1 = RMSNorm(dim)
73
- self.attn = Attention(dim, n_heads)
74
- self.ln2 = RMSNorm(dim)
75
- self.mlp = SwiGLU(dim, mlp_dim)
76
- def forward(self, x, cos, sin):
77
- x = x + self.attn(self.ln1(x), cos, sin)
78
- x = x + self.mlp(self.ln2(x))
 
79
  return x
80
 
81
- class GPT(nn.Module):
 
82
  def __init__(self):
83
  super().__init__()
84
  self.tok_emb = nn.Embedding(VOCAB_SIZE, DIM)
85
- mlp_dim = ((int(2 * DIM * 4 / 3) + 63) // 64) * 64
86
- self.blocks = nn.ModuleList([Block(DIM, N_HEADS, mlp_dim) for _ in range(DEPTH)])
87
- self.ln_f = RMSNorm(DIM)
 
88
  self.head = nn.Linear(DIM, VOCAB_SIZE, bias=False)
 
89
  self.head.weight = self.tok_emb.weight
90
- hd = DIM // N_HEADS
91
- freqs = 1.0 / (ROPE_THETA ** (torch.arange(0, hd, 2).float() / hd))
92
- angles = torch.outer(torch.arange(MAX_SEQ_LEN).float(), freqs)
93
- self.register_buffer('rope_cos', angles.cos())
94
- self.register_buffer('rope_sin', angles.sin())
95
-
96
- def forward(self, idx):
97
- B, T = idx.shape
98
- x = self.tok_emb(idx)
99
- cos = self.rope_cos[:T][None, None]
100
- sin = self.rope_sin[:T][None, None]
101
- for block in self.blocks:
102
- x = block(x, cos, sin)
103
- return self.head(self.ln_f(x))
104
-
105
- @torch.no_grad()
106
- def generate(self, idx, max_new=200, temp=0.7, top_k=40, rep_penalty=1.2):
107
- for _ in range(max_new):
108
- idx_c = idx[:, -MAX_SEQ_LEN:]
109
- logits = self(idx_c)[:, -1, :]
110
- if rep_penalty > 1.0:
111
- for token_id in set(idx[0].tolist()[-50:]):
112
- logits[0, token_id] /= rep_penalty
113
- logits = logits / temp
114
- if top_k > 0:
115
- v, _ = torch.topk(logits, top_k)
116
- logits[logits < v[:, [-1]]] = float('-inf')
117
- probs = F.softmax(logits, dim=-1)
118
- nx = torch.multinomial(probs, 1)
119
- idx = torch.cat([idx, nx], dim=1)
120
- if nx.item() == 2:
121
- break
122
- return idx
123
-
124
-
125
- # ============ DATA LOADING ============
126
- USER_PREFIX = "### User:\n"
127
- ASSISTANT_PREFIX = "### Assistant:\n"
128
- TURN_END = "\n\n"
129
-
130
- def load_sft_data(data_dir, split='train'):
131
- """Load tokenized SFT data."""
132
- filepath = os.path.join(data_dir, f'{split}_sft.bin')
133
- data = np.fromfile(filepath, dtype=np.uint16)
134
- return torch.from_numpy(data.astype(np.int64))
135
-
136
- def get_batch(data, batch_size, seq_len, device):
137
- """Get a random batch of sequences."""
138
- ix = torch.randint(len(data) - seq_len - 1, (batch_size,))
139
- x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
140
- y = torch.stack([data[i+1:i+seq_len+1] for i in ix]).to(device)
141
- return x, y
142
-
143
-
144
- # ============ TRAINING ============
145
- def train(args):
146
- device = args.device
147
- print(f"Device: {device}")
148
-
149
- # Load tokenizer
150
- print(f"Loading tokenizer: {args.tokenizer}")
151
- sp = spm.SentencePieceProcessor(args.tokenizer)
152
-
153
- # Load model
154
- print(f"Loading base model: {args.checkpoint}")
155
- model = GPT()
156
- ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)
157
- state_dict = ckpt.get('model_state_dict', ckpt.get('model', ckpt))
158
- clean_sd = {}
159
- for k, v in state_dict.items():
160
- k = k.replace('_orig_mod.', '').replace('module.', '')
161
- clean_sd[k] = v
162
- model.load_state_dict(clean_sd, strict=False)
163
- del ckpt, state_dict, clean_sd
164
- gc.collect()
165
-
166
- model = model.to(device).train()
167
- # Use bf16 for memory efficiency
168
- model = model.to(torch.bfloat16)
169
-
170
- param_count = sum(p.numel() for p in model.parameters())
171
- print(f"Model loaded: {param_count/1e9:.2f}B parameters")
172
-
173
- # Load data
174
- print(f"Loading SFT data from: {args.data_dir}")
175
- train_data = load_sft_data(args.data_dir, 'train')
176
- val_data = load_sft_data(args.data_dir, 'val')
177
- print(f"Train: {len(train_data)} tokens, Val: {len(val_data)} tokens")
178
-
179
- # Optimizer — 8-bit Adam for memory efficiency (halves optimizer states)
180
- try:
181
- import bitsandbytes as bnb
182
- optimizer = bnb.optim.AdamW8bit(
183
- model.parameters(),
184
- lr=args.lr,
185
- betas=(0.9, 0.95),
186
- weight_decay=0.01,
187
- )
188
- print("Using 8-bit AdamW (bitsandbytes)")
189
- except ImportError:
190
- optimizer = torch.optim.AdamW(
191
- model.parameters(),
192
- lr=args.lr,
193
- betas=(0.9, 0.95),
194
- weight_decay=0.01,
195
- )
196
- print("Using standard AdamW")
197
-
198
- # Cosine schedule with warmup
199
- def get_lr(step):
200
- if step < args.warmup_steps:
201
- return args.lr * step / args.warmup_steps
202
- decay_ratio = (step - args.warmup_steps) / (args.max_steps - args.warmup_steps)
203
- return args.lr * 0.1 + 0.9 * args.lr * 0.5 * (1 + math.cos(math.pi * decay_ratio))
204
-
205
- # Enable gradient checkpointing to save VRAM
206
- for block in model.blocks:
207
- block._gradient_checkpointing = True
208
- original_block_forward = Block.forward
209
- def checkpointed_forward(self, x, cos, sin):
210
- if self.training and hasattr(self, '_gradient_checkpointing') and self._gradient_checkpointing:
211
- return torch.utils.checkpoint.checkpoint(original_block_forward, self, x, cos, sin, use_reentrant=False)
212
- return original_block_forward(self, x, cos, sin)
213
- Block.forward = checkpointed_forward
214
-
215
- # Training loop
216
- best_val_loss = float('inf')
217
- grad_accum = args.grad_accum
218
- print(f"\nStarting SFT training for {args.max_steps} steps...")
219
- print(f"Batch size: {args.batch_size} x {grad_accum} accum = {args.batch_size * grad_accum} effective, Seq len: {MAX_SEQ_LEN}, LR: {args.lr}")
220
-
221
- t0 = time.time()
222
- for step in range(1, args.max_steps + 1):
223
- # LR schedule
224
- lr = get_lr(step)
225
- for pg in optimizer.param_groups:
226
- pg['lr'] = lr
227
-
228
- # Gradient accumulation
229
- optimizer.zero_grad(set_to_none=True)
230
- accum_loss = 0.0
231
- for micro in range(grad_accum):
232
- x, y = get_batch(train_data, args.batch_size, MAX_SEQ_LEN, device)
233
- with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
234
- logits = model(x)
235
- loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), y.view(-1)) / grad_accum
236
- loss.backward()
237
- accum_loss += loss.item()
238
-
239
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
240
- optimizer.step()
241
- loss = type('obj', (object,), {'item': lambda self: accum_loss})() # For logging
242
-
243
- # Logging
244
- if step % 10 == 0:
245
- elapsed = time.time() - t0
246
- tps = step * args.batch_size * grad_accum * MAX_SEQ_LEN / elapsed
247
- print(f"Step {step}/{args.max_steps} | Loss: {accum_loss:.4f} | LR: {lr:.6f} | TPS: {tps:.0f} | {elapsed:.0f}s")
248
-
249
- # Eval
250
- if step % args.eval_every == 0 or step == args.max_steps:
251
- model.eval()
252
- val_losses = []
253
- for _ in range(20):
254
- x, y = get_batch(val_data, args.batch_size, MAX_SEQ_LEN, device)
255
- with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
256
- logits = model(x)
257
- val_loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), y.view(-1))
258
- val_losses.append(val_loss.item())
259
- avg_val = sum(val_losses) / len(val_losses)
260
- print(f" 📊 Val loss: {avg_val:.4f} {'(NEW BEST!)' if avg_val < best_val_loss else ''}")
261
-
262
- if avg_val < best_val_loss:
263
- best_val_loss = avg_val
264
- torch.save({
265
- 'model_state_dict': model.state_dict(),
266
- 'step': step,
267
- 'val_loss': avg_val,
268
- 'config': {
269
- 'vocab_size': VOCAB_SIZE, 'dim': DIM, 'depth': DEPTH,
270
- 'n_heads': N_HEADS, 'max_seq_len': MAX_SEQ_LEN,
271
- }
272
- }, args.output)
273
- print(f" 💾 Best model saved to {args.output}")
274
-
275
- model.train()
276
-
277
- # Generate samples periodically
278
- if step % args.sample_every == 0 or step == args.max_steps:
279
- model.eval()
280
- prompts = [
281
- ("EN", "### User:\nWhat is the capital of France?\n\n### Assistant:\n"),
282
- ("HE", "### User:\nמה בירת צרפת?\n\n### Assistant:\n"),
283
- ("AR", "### User:\nما هي عاصمة فرنسا؟\n\n### Assistant:\n"),
284
- ("FA", "### User:\nپایتخت فرانسه کجاست؟\n\n### Assistant:\n"),
285
- ("TRANSLATE", "### User:\nTranslate the following Hebrew text to English:\nשלום עולם, איך אתה היום?\n\n### Assistant:\n"),
286
- ]
287
- print(f"\n 🔤 Generation samples (step {step}):")
288
- for label, prompt in prompts:
289
- ids = torch.tensor([sp.encode(prompt)], device=device, dtype=torch.long)
290
- with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
291
- out = model.generate(ids, max_new=100, temp=0.7, top_k=40)
292
- text = sp.decode(out[0].tolist())
293
- # Just show the assistant response
294
- if "### Assistant:" in text:
295
- response = text.split("### Assistant:")[-1].strip()[:200]
296
- else:
297
- response = text[len(prompt):].strip()[:200]
298
- print(f" [{label}] {response}")
299
- print()
300
- model.train()
301
-
302
- # Final save
303
- elapsed = time.time() - t0
304
- print(f"\n{'='*60}")
305
- print(f"SFT TRAINING COMPLETE")
306
- print(f"Steps: {args.max_steps}, Time: {elapsed:.0f}s ({elapsed/60:.1f}min)")
307
- print(f"Best val loss: {best_val_loss:.4f}")
308
- print(f"Model saved to: {args.output}")
309
- print(f"{'='*60}")
310
-
311
- # Upload to S3
312
- print("Uploading to S3...")
313
- os.system(f"aws s3 cp {args.output} s3://autoresearch-dashboard-196766918360/multilingual-7b/checkpoints/3b-v1-fsdp/sft_model.pt --quiet")
314
- os.system(f"aws s3 cp /tmp/sft/sft.log s3://autoresearch-dashboard-196766918360/multilingual-7b/eval/sft_3b.log --quiet 2>/dev/null")
315
- print("Done!")
316
-
317
-
318
- def main():
319
- parser = argparse.ArgumentParser()
320
- parser.add_argument('--checkpoint', required=True)
321
- parser.add_argument('--tokenizer', required=True)
322
- parser.add_argument('--data-dir', required=True)
323
- parser.add_argument('--output', default='/tmp/sft/sft_model.pt')
324
- parser.add_argument('--device', default='cuda')
325
- parser.add_argument('--batch-size', type=int, default=1) # 1 for 24GB GPU
326
- parser.add_argument('--grad-accum', type=int, default=4) # Effective batch = 4
327
- parser.add_argument('--lr', type=float, default=2e-5)
328
- parser.add_argument('--max-steps', type=int, default=2000)
329
- parser.add_argument('--warmup-steps', type=int, default=100)
330
- parser.add_argument('--eval-every', type=int, default=200)
331
- parser.add_argument('--sample-every', type=int, default=500)
332
- parser.add_argument('--seed', type=int, default=42)
333
- args = parser.parse_args()
334
-
335
- random.seed(args.seed)
336
- torch.manual_seed(args.seed)
337
- os.makedirs(os.path.dirname(args.output), exist_ok=True)
338
-
339
- train(args)
340
-
341
-
342
- if __name__ == '__main__':
343
- main()
 
1
+ """Shared model architecture for multilingual 3B GPT — must match training exactly."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
+ import math
6
 
 
7
  VOCAB_SIZE = 32000
8
  DIM = 3072
9
  DEPTH = 26
10
  N_HEADS = 24
11
+ HEAD_DIM = DIM // N_HEADS # 128
12
  MAX_SEQ_LEN = 2048
13
+ ROPE_THETA = 10000.0
14
+ HIDDEN_DIM = ((int(2 * DIM * 4 / 3) + 63) // 64) * 64 # SwiGLU hidden
15
+
16
 
17
  class RMSNorm(nn.Module):
18
  def __init__(self, dim, eps=1e-6):
19
  super().__init__()
 
20
  self.eps = eps
21
+ self.weight = nn.Parameter(torch.ones(dim))
 
22
 
 
 
 
 
 
 
23
  def forward(self, x):
24
+ norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
25
+ return (x.float() * norm).type_as(x) * self.weight
26
+
27
+
28
+ def precompute_freqs_cis(dim, max_seq_len, theta=ROPE_THETA):
29
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
30
+ t = torch.arange(max_seq_len, dtype=torch.float32)
31
+ freqs = torch.outer(t, freqs)
32
+ return torch.polar(torch.ones_like(freqs), freqs)
33
+
34
 
35
+ def apply_rotary_emb(x, freqs_cis):
36
+ # x: (B, n_heads, S, head_dim)
37
+ B, H, S, D = x.shape
38
+ x_complex = torch.view_as_complex(x.float().reshape(B, H, S, D // 2, 2))
39
+ freqs = freqs_cis[:S].unsqueeze(0).unsqueeze(1) # (1, 1, S, D//2)
40
+ x_rot = torch.view_as_real(x_complex * freqs).reshape(B, H, S, D)
41
+ return x_rot.type_as(x)
42
 
43
+
44
+ class FusedAttention(nn.Module):
45
  def __init__(self, dim, n_heads):
46
  super().__init__()
47
  self.n_heads = n_heads
48
  self.head_dim = dim // n_heads
49
+ self.qkv = nn.Linear(dim, 3 * dim, bias=False)
50
+ self.out_proj = nn.Linear(dim, dim, bias=False)
51
+
52
+ def forward(self, x, freqs_cis, mask=None):
53
+ B, S, D = x.shape
54
+ qkv = self.qkv(x).reshape(B, S, 3, self.n_heads, self.head_dim)
55
+ q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
56
+ q = q.transpose(1, 2) # (B, H, S, D)
57
+ k = k.transpose(1, 2)
58
+ v = v.transpose(1, 2)
59
+ q = apply_rotary_emb(q, freqs_cis)
60
+ k = apply_rotary_emb(k, freqs_cis)
61
+ # Scaled dot-product attention
62
+ scale = math.sqrt(self.head_dim)
63
+ attn = (q @ k.transpose(-2, -1)) / scale
64
+ if mask is not None:
65
+ attn = attn + mask
66
+ attn = F.softmax(attn, dim=-1)
67
+ out = (attn @ v).transpose(1, 2).reshape(B, S, D)
68
+ return self.out_proj(out)
69
+
70
+
71
+ class SwiGLUFFN(nn.Module):
72
+ def __init__(self, dim, hidden_dim):
73
+ super().__init__()
74
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
75
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
76
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
77
 
78
+ def forward(self, x):
79
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
80
+
81
+
82
+ class TransformerBlock(nn.Module):
83
+ def __init__(self, dim, n_heads, hidden_dim):
84
  super().__init__()
85
+ self.attn_norm = RMSNorm(dim)
86
+ self.attn = FusedAttention(dim, n_heads)
87
+ self.ffn_norm = RMSNorm(dim)
88
+ self.ffn = SwiGLUFFN(dim, hidden_dim)
89
+
90
+ def forward(self, x, freqs_cis, mask=None):
91
+ x = x + self.attn(self.attn_norm(x), freqs_cis, mask)
92
+ x = x + self.ffn(self.ffn_norm(x))
93
  return x
94
 
95
+
96
+ class MultilingualGPT(nn.Module):
97
  def __init__(self):
98
  super().__init__()
99
  self.tok_emb = nn.Embedding(VOCAB_SIZE, DIM)
100
+ self.layers = nn.ModuleList([
101
+ TransformerBlock(DIM, N_HEADS, HIDDEN_DIM) for _ in range(DEPTH)
102
+ ])
103
+ self.norm = RMSNorm(DIM)
104
  self.head = nn.Linear(DIM, VOCAB_SIZE, bias=False)
105
+ # Tied embeddings
106
  self.head.weight = self.tok_emb.weight
107
+ # Precompute RoPE
108
+ self.register_buffer('freqs_cis', precompute_freqs_cis(HEAD_DIM, MAX_SEQ_LEN))
109
+
110
+ def forward(self, tokens, targets=None):
111
+ B, S = tokens.shape
112
+ x = self.tok_emb(tokens)
113
+ mask = torch.triu(torch.full((S, S), float('-inf'), device=tokens.device), diagonal=1)
114
+ mask = mask.unsqueeze(0).unsqueeze(0) # (1, 1, S, S)
115
+ for layer in self.layers:
116
+ x = layer(x, self.freqs_cis, mask)
117
+ x = self.norm(x)
118
+ logits = self.head(x)
119
+ loss = None
120
+ if targets is not None:
121
+ loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), targets.view(-1))
122
+ return logits, loss
123
+
124
+
125
+ def load_model(path, device='cuda'):
126
+ """Load model from checkpoint, stripping prefixes."""
127
+ model = MultilingualGPT()
128
+ ckpt = torch.load(path, map_location='cpu', weights_only=False)
129
+ state = ckpt.get('model_state_dict', ckpt)
130
+ # Strip prefixes
131
+ cleaned = {}
132
+ for k, v in state.items():
133
+ new_k = k
134
+ for prefix in ['_orig_mod.', 'module.']:
135
+ if new_k.startswith(prefix):
136
+ new_k = new_k[len(prefix):]
137
+ cleaned[new_k] = v
138
+ # Handle tied weights - remove head.weight if present (will be tied)
139
+ if 'head.weight' in cleaned and 'tok_emb.weight' in cleaned:
140
+ if torch.equal(cleaned['head.weight'], cleaned['tok_emb.weight']):
141
+ del cleaned['head.weight']
142
+ model.load_state_dict(cleaned, strict=False)
143
+ model = model.to(device).eval()
144
+ return model
145
+
146
+
147
+ def load_tokenizer(path):
148
+ """Load SentencePiece tokenizer."""
149
+ import sentencepiece as spm
150
+ sp = spm.SentencePieceProcessor()
151
+ sp.Load(path)
152
+ return sp