"""Enhanced training script with comprehensive logging and validation."""

import argparse
import math
import os
import sys
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import get_cosine_schedule_with_warmup

# Make the local package importable when running from the repo root.
sys.path.append('.')

from supernova.config import ModelConfig
from supernova.model import SupernovaModel
from supernova.tokenizer import load_gpt2_tokenizer
from supernova.data import load_sources_from_yaml, TokenChunkDataset


def compute_grad_norm(model: nn.Module) -> float:
    """Return the global L2 norm of all parameter gradients."""
    total = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.data.float().norm(2).item()
            total += param_norm * param_norm
    return math.sqrt(total)
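
# If gradient clipping is ever wanted, torch.nn.utils.clip_grad_norm_ clips in
# place and returns the same pre-clip global norm, so it could replace this
# helper; here the norm is computed purely for logging, with no clipping.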


def format_time(seconds):
    """Format seconds into readable time."""
    if seconds < 60:
        return f"{seconds:.1f}s"
    elif seconds < 3600:
        return f"{seconds//60:.0f}m{seconds%60:.0f}s"
    else:
        return f"{seconds//3600:.0f}h{(seconds%3600)//60:.0f}m"


def train_enhanced(
    config_path: str,
    data_config_path: str,
    seq_len: int = 1024,
    batch_size: int = 16,
    grad_accum: int = 8,
    lr: float = 3e-4,
    warmup_steps: int = 2000,
    max_steps: int = 100_000,
    save_every: int = 10_000,
    out_dir: str = "checkpoints",
    seed: int = 42,
):
    print("🚀 SUPERNOVA ENHANCED TRAINING")
    print("=" * 60)

    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"📱 Device: {device}")
    print(f"🌱 Seed: {seed}")

    cfg = ModelConfig.from_json_file(config_path)
    cfg.assert_exact_params(expected=25_000_000)
    print(f"⚙️ Model: {cfg.n_layers} layers, {cfg.d_model} d_model, {cfg.n_heads} heads")

    tok = load_gpt2_tokenizer()
    assert tok.vocab_size == cfg.vocab_size, (
        f"tokenizer vocab {tok.vocab_size:,} != config vocab {cfg.vocab_size:,}"
    )
    print(f"🔤 Tokenizer: {tok.vocab_size:,} vocab size")

    model = SupernovaModel(cfg).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    assert total_params == 25_000_000, f"expected exactly 25,000,000 params, got {total_params:,}"
    print(f"🧠 Model: {total_params:,} parameters (EXACT)")
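    # sum(p.numel()) counts every registered parameter, embeddings included,
    # so this re-verifies the 25M budget on the instantiated module itself,
    # not just on the config that assert_exact_params checked above.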

    print("📚 Loading datasets...")
    sources = load_sources_from_yaml(data_config_path)
    print(f"📊 Data sources: {len(sources)} sources loaded")
    for i, source in enumerate(sources):
        print(f"  {i+1}. {source.name} (weight: {source.weight})")

    ds = TokenChunkDataset(tok, sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
    print(f"📦 DataLoader: batch_size={batch_size}, seq_len={seq_len}")
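    # shuffle=False assumes TokenChunkDataset handles shuffling/interleaving
    # itself (DataLoader cannot shuffle an IterableDataset); for a map-style
    # dataset that yields chunks in corpus order, shuffle=True would be safer.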

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1
    )
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=max_steps,
    )
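    # The schedule ramps the LR linearly from 0 to `lr` over warmup_steps
    # optimizer steps, then follows a cosine decay toward 0 at max_steps, so
    # scheduler.step() must be called once per optimizer step (as below),
    # not once per micro-batch.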

    print("🎯 Training setup:")
    print(f"  Learning rate: {lr}")
    print(f"  Warmup steps: {warmup_steps:,}")
    print(f"  Max steps: {max_steps:,}")
    print(f"  Grad accumulation: {grad_accum}")
    print(f"  Save every: {save_every:,} steps")
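    # Effective batch per optimizer step: batch_size * grad_accum sequences,
    # i.e. batch_size * grad_accum * seq_len tokens
    # (16 * 8 * 1024 = 131,072 tokens with the defaults).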

    os.makedirs(out_dir, exist_ok=True)
    print(f"💾 Output dir: {out_dir}")
    print()

    model.train()
    step = 0
    micro = 0
    running_loss = 0.0
    avg_loss = float('inf')  # last logged window average; stored in checkpoints
    best_loss = float('inf')
    start_time = time.time()

    print("🚀 Starting training...")
    print("=" * 60)

    try:
        while step < max_steps:
            for batch in dl:
                x, y = batch
                x = x.to(device)
                y = y.to(device)

                logits, loss = model(x, y)
                loss = loss / grad_accum  # scale so accumulated grads match one full batch
                loss.backward()

                micro += 1
                running_loss += loss.item()

                if micro % grad_accum == 0:
                    # Read the grad norm before zero_grad() clears the
                    # gradients; after the update it would always be 0.
                    grad_norm = compute_grad_norm(model)
                    optimizer.step()
                    optimizer.zero_grad(set_to_none=True)
                    scheduler.step()

                    step += 1

                    if step % 10 == 0:
                        # running_loss has accumulated 10 * grad_accum
                        # micro-losses, each already divided by grad_accum, so
                        # the mean raw loss over the window is running_loss / 10.
                        avg_loss = running_loss / 10.0
                        running_loss = 0.0
                        total_elapsed = time.time() - start_time
                        lr_now = scheduler.get_last_lr()[0]

                        tokens_per_step = batch_size * seq_len * grad_accum
                        tokens_per_sec = step * tokens_per_step / total_elapsed

                        print(f"Step {step:5d} | Loss: {avg_loss:.4f} | Grad: {grad_norm:.3f} | "
                              f"LR: {lr_now:.2e} | {tokens_per_sec:.0f} tok/s | {format_time(total_elapsed)}")

                        if avg_loss < best_loss:
                            best_loss = avg_loss
                            print(f"💫 New best loss: {best_loss:.4f}")

                    if save_every and step % save_every == 0:
                        ckpt_path = os.path.join(out_dir, f"supernova_step{step}.pt")
                        torch.save({
                            "model_state_dict": model.state_dict(),
                            "optimizer_state_dict": optimizer.state_dict(),
                            "scheduler_state_dict": scheduler.state_dict(),
                            "config": cfg.__dict__,
                            "step": step,
                            "loss": avg_loss,  # last logged window average
                            "best_loss": best_loss,
                        }, ckpt_path)
                        print(f"💾 Saved checkpoint: {ckpt_path}")
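                        # To resume, torch.load(ckpt_path, map_location=device)
                        # and load_state_dict() the model/optimizer/scheduler
                        # entries, then continue from ckpt["step"]. Depending on
                        # what ModelConfig stores, PyTorch >= 2.6 may need
                        # torch.load(..., weights_only=False) for this dict.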

                    if step >= max_steps:
                        break

    except KeyboardInterrupt:
        print("\n⏹️ Training interrupted by user")
    except Exception as e:
        print(f"\n❌ Training failed with error: {e}")
        raise
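    # Note the asymmetry: KeyboardInterrupt falls through so the final
    # checkpoint below still gets written, while any other exception
    # re-raises and skips it.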

    final_path = os.path.join(out_dir, "supernova_final.pt")
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "config": cfg.__dict__,
        "step": step,
        "loss": avg_loss,  # last logged window average (inf if never logged)
        "best_loss": best_loss,
    }, final_path)

    total_time = time.time() - start_time
    print("\n" + "=" * 60)
    print("🎉 TRAINING COMPLETE!")
    print(f"🏁 Final step: {step:,}")
    print(f"📉 Best loss: {best_loss:.4f}")
    print(f"⏱️ Total time: {format_time(total_time)}")
    print(f"💾 Final checkpoint: {final_path}")
    print("=" * 60)


def main():
    parser = argparse.ArgumentParser(description="Enhanced Supernova Training")
    parser.add_argument("--config", required=True, help="Path to model config JSON")
    parser.add_argument("--data-config", required=True, help="Path to data sources YAML")
    parser.add_argument("--seq-len", type=int, default=1024, help="Sequence length")
    parser.add_argument("--batch-size", type=int, default=16, help="Micro-batch size")
    parser.add_argument("--grad-accum", type=int, default=8, help="Gradient accumulation steps")
    parser.add_argument("--lr", type=float, default=3e-4, help="Peak learning rate")
    parser.add_argument("--warmup-steps", type=int, default=2000, help="Warmup steps")
    parser.add_argument("--max-steps", type=int, default=100000, help="Max training steps")
    parser.add_argument("--save-every", type=int, default=10000, help="Checkpoint interval in steps")
    parser.add_argument("--out-dir", default="checkpoints", help="Output directory")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()

    train_enhanced(
        config_path=args.config,
        data_config_path=args.data_config,
        seq_len=args.seq_len,
        batch_size=args.batch_size,
        grad_accum=args.grad_accum,
        lr=args.lr,
        warmup_steps=args.warmup_steps,
        max_steps=args.max_steps,
        save_every=args.save_every,
        out_dir=args.out_dir,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()
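
# Example invocation (file name and config paths are illustrative):
#   python train_enhanced.py --config configs/model.json --data-config configs/data.yaml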