ronnengmail commited on
Commit
21a2488
·
verified ·
1 Parent(s): 452f44f

Upload training_scripts/train_sft_3b.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training_scripts/train_sft_3b.py +343 -0
training_scripts/train_sft_3b.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Multilingual 3B GPT — SFT Training
4
+
5
+ Fine-tunes the base model on instruction data (Aya + Bactrian-X + FLORES translations).
6
+ Uses the same architecture as pretraining with LoRA-free full fine-tuning
7
+ (model is 3B params, fits in 24GB A10G in bf16).
8
+
9
+ Usage:
10
+ python train_sft_3b.py --checkpoint /path/to/best_model.pt \
11
+ --tokenizer /path/to/multilingual_32k.model \
12
+ --data-dir /path/to/sft_data/ \
13
+ --output /path/to/sft_model.pt
14
+ """
15
+
16
+ import os, sys, json, math, time, random, argparse
17
+ sys.stdout.reconfigure(line_buffering=True)
18
+ import gc
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ import sentencepiece as spm
24
+
25
+ # ============ MODEL (must match training) ============
26
+ VOCAB_SIZE = 32000
27
+ DIM = 3072
28
+ DEPTH = 26
29
+ N_HEADS = 24
30
+ MAX_SEQ_LEN = 2048
31
+ ROPE_THETA = 10000
32
+
33
+ class RMSNorm(nn.Module):
34
+ def __init__(self, dim, eps=1e-6):
35
+ super().__init__()
36
+ self.weight = nn.Parameter(torch.ones(dim))
37
+ self.eps = eps
38
+ def forward(self, x):
39
+ return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps).type_as(x) * self.weight
40
+
41
+ class SwiGLU(nn.Module):
42
+ def __init__(self, dim, hidden_dim):
43
+ super().__init__()
44
+ self.gate = nn.Linear(dim, hidden_dim, bias=False)
45
+ self.up = nn.Linear(dim, hidden_dim, bias=False)
46
+ self.down = nn.Linear(hidden_dim, dim, bias=False)
47
+ def forward(self, x):
48
+ return self.down(F.silu(self.gate(x)) * self.up(x))
49
+
50
+ def apply_rope(x, cos, sin):
51
+ x1, x2 = x[..., ::2], x[..., 1::2]
52
+ return torch.stack((x1*cos - x2*sin, x1*sin + x2*cos), dim=-1).flatten(-2)
53
+
54
+ class Attention(nn.Module):
55
+ def __init__(self, dim, n_heads):
56
+ super().__init__()
57
+ self.n_heads = n_heads
58
+ self.head_dim = dim // n_heads
59
+ self.qkv = nn.Linear(dim, 3*dim, bias=False)
60
+ self.proj = nn.Linear(dim, dim, bias=False)
61
+ def forward(self, x, cos, sin):
62
+ B, T, C = x.shape
63
+ qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
64
+ q, k, v = qkv[0], qkv[1], qkv[2]
65
+ q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
66
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
67
+ return self.proj(y.transpose(1, 2).contiguous().view(B, T, C))
68
+
69
+ class Block(nn.Module):
70
+ def __init__(self, dim, n_heads, mlp_dim):
71
+ super().__init__()
72
+ self.ln1 = RMSNorm(dim)
73
+ self.attn = Attention(dim, n_heads)
74
+ self.ln2 = RMSNorm(dim)
75
+ self.mlp = SwiGLU(dim, mlp_dim)
76
+ def forward(self, x, cos, sin):
77
+ x = x + self.attn(self.ln1(x), cos, sin)
78
+ x = x + self.mlp(self.ln2(x))
79
+ return x
80
+
81
+ class GPT(nn.Module):
82
+ def __init__(self):
83
+ super().__init__()
84
+ self.tok_emb = nn.Embedding(VOCAB_SIZE, DIM)
85
+ mlp_dim = ((int(2 * DIM * 4 / 3) + 63) // 64) * 64
86
+ self.blocks = nn.ModuleList([Block(DIM, N_HEADS, mlp_dim) for _ in range(DEPTH)])
87
+ self.ln_f = RMSNorm(DIM)
88
+ self.head = nn.Linear(DIM, VOCAB_SIZE, bias=False)
89
+ self.head.weight = self.tok_emb.weight
90
+ hd = DIM // N_HEADS
91
+ freqs = 1.0 / (ROPE_THETA ** (torch.arange(0, hd, 2).float() / hd))
92
+ angles = torch.outer(torch.arange(MAX_SEQ_LEN).float(), freqs)
93
+ self.register_buffer('rope_cos', angles.cos())
94
+ self.register_buffer('rope_sin', angles.sin())
95
+
96
+ def forward(self, idx):
97
+ B, T = idx.shape
98
+ x = self.tok_emb(idx)
99
+ cos = self.rope_cos[:T][None, None]
100
+ sin = self.rope_sin[:T][None, None]
101
+ for block in self.blocks:
102
+ x = block(x, cos, sin)
103
+ return self.head(self.ln_f(x))
104
+
105
+ @torch.no_grad()
106
+ def generate(self, idx, max_new=200, temp=0.7, top_k=40, rep_penalty=1.2):
107
+ for _ in range(max_new):
108
+ idx_c = idx[:, -MAX_SEQ_LEN:]
109
+ logits = self(idx_c)[:, -1, :]
110
+ if rep_penalty > 1.0:
111
+ for token_id in set(idx[0].tolist()[-50:]):
112
+ logits[0, token_id] /= rep_penalty
113
+ logits = logits / temp
114
+ if top_k > 0:
115
+ v, _ = torch.topk(logits, top_k)
116
+ logits[logits < v[:, [-1]]] = float('-inf')
117
+ probs = F.softmax(logits, dim=-1)
118
+ nx = torch.multinomial(probs, 1)
119
+ idx = torch.cat([idx, nx], dim=1)
120
+ if nx.item() == 2:
121
+ break
122
+ return idx
123
+
124
+
125
+ # ============ DATA LOADING ============
126
+ USER_PREFIX = "### User:\n"
127
+ ASSISTANT_PREFIX = "### Assistant:\n"
128
+ TURN_END = "\n\n"
129
+
130
+ def load_sft_data(data_dir, split='train'):
131
+ """Load tokenized SFT data."""
132
+ filepath = os.path.join(data_dir, f'{split}_sft.bin')
133
+ data = np.fromfile(filepath, dtype=np.uint16)
134
+ return torch.from_numpy(data.astype(np.int64))
135
+
136
+ def get_batch(data, batch_size, seq_len, device):
137
+ """Get a random batch of sequences."""
138
+ ix = torch.randint(len(data) - seq_len - 1, (batch_size,))
139
+ x = torch.stack([data[i:i+seq_len] for i in ix]).to(device)
140
+ y = torch.stack([data[i+1:i+seq_len+1] for i in ix]).to(device)
141
+ return x, y
142
+
143
+
144
+ # ============ TRAINING ============
145
+ def train(args):
146
+ device = args.device
147
+ print(f"Device: {device}")
148
+
149
+ # Load tokenizer
150
+ print(f"Loading tokenizer: {args.tokenizer}")
151
+ sp = spm.SentencePieceProcessor(args.tokenizer)
152
+
153
+ # Load model
154
+ print(f"Loading base model: {args.checkpoint}")
155
+ model = GPT()
156
+ ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)
157
+ state_dict = ckpt.get('model_state_dict', ckpt.get('model', ckpt))
158
+ clean_sd = {}
159
+ for k, v in state_dict.items():
160
+ k = k.replace('_orig_mod.', '').replace('module.', '')
161
+ clean_sd[k] = v
162
+ model.load_state_dict(clean_sd, strict=False)
163
+ del ckpt, state_dict, clean_sd
164
+ gc.collect()
165
+
166
+ model = model.to(device).train()
167
+ # Use bf16 for memory efficiency
168
+ model = model.to(torch.bfloat16)
169
+
170
+ param_count = sum(p.numel() for p in model.parameters())
171
+ print(f"Model loaded: {param_count/1e9:.2f}B parameters")
172
+
173
+ # Load data
174
+ print(f"Loading SFT data from: {args.data_dir}")
175
+ train_data = load_sft_data(args.data_dir, 'train')
176
+ val_data = load_sft_data(args.data_dir, 'val')
177
+ print(f"Train: {len(train_data)} tokens, Val: {len(val_data)} tokens")
178
+
179
+ # Optimizer — 8-bit Adam for memory efficiency (halves optimizer states)
180
+ try:
181
+ import bitsandbytes as bnb
182
+ optimizer = bnb.optim.AdamW8bit(
183
+ model.parameters(),
184
+ lr=args.lr,
185
+ betas=(0.9, 0.95),
186
+ weight_decay=0.01,
187
+ )
188
+ print("Using 8-bit AdamW (bitsandbytes)")
189
+ except ImportError:
190
+ optimizer = torch.optim.AdamW(
191
+ model.parameters(),
192
+ lr=args.lr,
193
+ betas=(0.9, 0.95),
194
+ weight_decay=0.01,
195
+ )
196
+ print("Using standard AdamW")
197
+
198
+ # Cosine schedule with warmup
199
+ def get_lr(step):
200
+ if step < args.warmup_steps:
201
+ return args.lr * step / args.warmup_steps
202
+ decay_ratio = (step - args.warmup_steps) / (args.max_steps - args.warmup_steps)
203
+ return args.lr * 0.1 + 0.9 * args.lr * 0.5 * (1 + math.cos(math.pi * decay_ratio))
204
+
205
+ # Enable gradient checkpointing to save VRAM
206
+ for block in model.blocks:
207
+ block._gradient_checkpointing = True
208
+ original_block_forward = Block.forward
209
+ def checkpointed_forward(self, x, cos, sin):
210
+ if self.training and hasattr(self, '_gradient_checkpointing') and self._gradient_checkpointing:
211
+ return torch.utils.checkpoint.checkpoint(original_block_forward, self, x, cos, sin, use_reentrant=False)
212
+ return original_block_forward(self, x, cos, sin)
213
+ Block.forward = checkpointed_forward
214
+
215
+ # Training loop
216
+ best_val_loss = float('inf')
217
+ grad_accum = args.grad_accum
218
+ print(f"\nStarting SFT training for {args.max_steps} steps...")
219
+ print(f"Batch size: {args.batch_size} x {grad_accum} accum = {args.batch_size * grad_accum} effective, Seq len: {MAX_SEQ_LEN}, LR: {args.lr}")
220
+
221
+ t0 = time.time()
222
+ for step in range(1, args.max_steps + 1):
223
+ # LR schedule
224
+ lr = get_lr(step)
225
+ for pg in optimizer.param_groups:
226
+ pg['lr'] = lr
227
+
228
+ # Gradient accumulation
229
+ optimizer.zero_grad(set_to_none=True)
230
+ accum_loss = 0.0
231
+ for micro in range(grad_accum):
232
+ x, y = get_batch(train_data, args.batch_size, MAX_SEQ_LEN, device)
233
+ with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
234
+ logits = model(x)
235
+ loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), y.view(-1)) / grad_accum
236
+ loss.backward()
237
+ accum_loss += loss.item()
238
+
239
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
240
+ optimizer.step()
241
+ loss = type('obj', (object,), {'item': lambda self: accum_loss})() # For logging
242
+
243
+ # Logging
244
+ if step % 10 == 0:
245
+ elapsed = time.time() - t0
246
+ tps = step * args.batch_size * grad_accum * MAX_SEQ_LEN / elapsed
247
+ print(f"Step {step}/{args.max_steps} | Loss: {accum_loss:.4f} | LR: {lr:.6f} | TPS: {tps:.0f} | {elapsed:.0f}s")
248
+
249
+ # Eval
250
+ if step % args.eval_every == 0 or step == args.max_steps:
251
+ model.eval()
252
+ val_losses = []
253
+ for _ in range(20):
254
+ x, y = get_batch(val_data, args.batch_size, MAX_SEQ_LEN, device)
255
+ with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
256
+ logits = model(x)
257
+ val_loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), y.view(-1))
258
+ val_losses.append(val_loss.item())
259
+ avg_val = sum(val_losses) / len(val_losses)
260
+ print(f" 📊 Val loss: {avg_val:.4f} {'(NEW BEST!)' if avg_val < best_val_loss else ''}")
261
+
262
+ if avg_val < best_val_loss:
263
+ best_val_loss = avg_val
264
+ torch.save({
265
+ 'model_state_dict': model.state_dict(),
266
+ 'step': step,
267
+ 'val_loss': avg_val,
268
+ 'config': {
269
+ 'vocab_size': VOCAB_SIZE, 'dim': DIM, 'depth': DEPTH,
270
+ 'n_heads': N_HEADS, 'max_seq_len': MAX_SEQ_LEN,
271
+ }
272
+ }, args.output)
273
+ print(f" 💾 Best model saved to {args.output}")
274
+
275
+ model.train()
276
+
277
+ # Generate samples periodically
278
+ if step % args.sample_every == 0 or step == args.max_steps:
279
+ model.eval()
280
+ prompts = [
281
+ ("EN", "### User:\nWhat is the capital of France?\n\n### Assistant:\n"),
282
+ ("HE", "### User:\nמה בירת צרפת?\n\n### Assistant:\n"),
283
+ ("AR", "### User:\nما هي عاصمة فرنسا؟\n\n### Assistant:\n"),
284
+ ("FA", "### User:\nپایتخت فرانسه کجاست؟\n\n### Assistant:\n"),
285
+ ("TRANSLATE", "### User:\nTranslate the following Hebrew text to English:\nשלום עולם, איך אתה היום?\n\n### Assistant:\n"),
286
+ ]
287
+ print(f"\n 🔤 Generation samples (step {step}):")
288
+ for label, prompt in prompts:
289
+ ids = torch.tensor([sp.encode(prompt)], device=device, dtype=torch.long)
290
+ with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
291
+ out = model.generate(ids, max_new=100, temp=0.7, top_k=40)
292
+ text = sp.decode(out[0].tolist())
293
+ # Just show the assistant response
294
+ if "### Assistant:" in text:
295
+ response = text.split("### Assistant:")[-1].strip()[:200]
296
+ else:
297
+ response = text[len(prompt):].strip()[:200]
298
+ print(f" [{label}] {response}")
299
+ print()
300
+ model.train()
301
+
302
+ # Final save
303
+ elapsed = time.time() - t0
304
+ print(f"\n{'='*60}")
305
+ print(f"SFT TRAINING COMPLETE")
306
+ print(f"Steps: {args.max_steps}, Time: {elapsed:.0f}s ({elapsed/60:.1f}min)")
307
+ print(f"Best val loss: {best_val_loss:.4f}")
308
+ print(f"Model saved to: {args.output}")
309
+ print(f"{'='*60}")
310
+
311
+ # Upload to S3
312
+ print("Uploading to S3...")
313
+ os.system(f"aws s3 cp {args.output} s3://autoresearch-dashboard-196766918360/multilingual-7b/checkpoints/3b-v1-fsdp/sft_model.pt --quiet")
314
+ os.system(f"aws s3 cp /tmp/sft/sft.log s3://autoresearch-dashboard-196766918360/multilingual-7b/eval/sft_3b.log --quiet 2>/dev/null")
315
+ print("Done!")
316
+
317
+
318
+ def main():
319
+ parser = argparse.ArgumentParser()
320
+ parser.add_argument('--checkpoint', required=True)
321
+ parser.add_argument('--tokenizer', required=True)
322
+ parser.add_argument('--data-dir', required=True)
323
+ parser.add_argument('--output', default='/tmp/sft/sft_model.pt')
324
+ parser.add_argument('--device', default='cuda')
325
+ parser.add_argument('--batch-size', type=int, default=1) # 1 for 24GB GPU
326
+ parser.add_argument('--grad-accum', type=int, default=4) # Effective batch = 4
327
+ parser.add_argument('--lr', type=float, default=2e-5)
328
+ parser.add_argument('--max-steps', type=int, default=2000)
329
+ parser.add_argument('--warmup-steps', type=int, default=100)
330
+ parser.add_argument('--eval-every', type=int, default=200)
331
+ parser.add_argument('--sample-every', type=int, default=500)
332
+ parser.add_argument('--seed', type=int, default=42)
333
+ args = parser.parse_args()
334
+
335
+ random.seed(args.seed)
336
+ torch.manual_seed(args.seed)
337
+ os.makedirs(os.path.dirname(args.output), exist_ok=True)
338
+
339
+ train(args)
340
+
341
+
342
+ if __name__ == '__main__':
343
+ main()