|
|
| """
|
| Production-ready Supernova training script.
|
| Optimized for stability, monitoring, and memory efficiency.
|
| """
|
|
|
| import argparse
|
| import json
|
| import math
|
| import os
|
| import sys
|
| import time
|
| import logging
|
| from pathlib import Path
|
| from typing import Optional, Dict, Any
|
|
|
| import torch
|
| import torch.nn as nn
|
| from torch.utils.data import DataLoader
|
| from transformers import get_cosine_schedule_with_warmup
|
|
|
|
|
# Make the repo-local `supernova` package importable when this script is
# run from the repository root (the `from supernova...` imports below
# resolve against the current working directory).
sys.path.append('.')
|
|
|
| from supernova.config import ModelConfig
|
| from supernova.model import SupernovaModel
|
| from supernova.tokenizer import load_gpt2_tokenizer
|
| from supernova.data import load_sources_from_yaml, TokenChunkDataset
|
|
|
|
|
def setup_logging(output_dir: str) -> logging.Logger:
    """Configure file + console logging for a training run.

    Creates ``output_dir`` if needed and attaches an INFO-level file
    handler (``training.log`` inside ``output_dir``) plus an INFO-level
    console handler to the shared ``supernova_training`` logger.

    Args:
        output_dir: Directory where ``training.log`` is written.

    Returns:
        The configured ``supernova_training`` logger.
    """
    os.makedirs(output_dir, exist_ok=True)

    logger = logging.getLogger('supernova_training')
    logger.setLevel(logging.INFO)

    # Bug fix: logging.getLogger returns a process-wide singleton, so
    # calling setup_logging() more than once (e.g. several training runs
    # in one process) used to stack duplicate handlers and emit every log
    # line multiple times. Detach and close any stale handlers first.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    file_handler = logging.FileHandler(os.path.join(output_dir, 'training.log'))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return logger
|
|
|
|
|
def compute_grad_norm(model: nn.Module) -> float:
    """Return the global L2 norm over all parameter gradients of *model*.

    Parameters without a populated ``.grad`` are skipped; with no
    gradients at all the result is ``0.0``.
    """
    squared_sum = sum(
        p.grad.data.float().norm(2).item() ** 2
        for p in model.parameters()
        if p.grad is not None
    )
    return math.sqrt(squared_sum)
|
|
|
|
|
def format_time(seconds: float) -> str:
    """Render a duration in seconds as a compact human-readable string.

    Under a minute: ``"12.3s"``; under an hour: ``"5m30s"``; otherwise
    ``"2h15m"``.
    """
    if seconds >= 3600:
        hours, remainder = divmod(seconds, 3600)
        return f"{hours:.0f}h{remainder // 60:.0f}m"
    if seconds >= 60:
        minutes, secs = divmod(seconds, 60)
        return f"{minutes:.0f}m{secs:.0f}s"
    return f"{seconds:.1f}s"
|
|
|
|
|
def get_memory_usage() -> Dict[str, float]:
    """Report current CUDA memory usage in GiB.

    Returns a dict with ``allocated`` (tensor memory) and ``cached``
    (reserved by the caching allocator); both are zero without CUDA.
    """
    if not torch.cuda.is_available():
        return {'allocated': 0, 'cached': 0}
    gib = 1024**3
    return {
        'allocated': torch.cuda.memory_allocated() / gib,
        'cached': torch.cuda.memory_reserved() / gib,
    }
|
|
|
|
|
def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Any,
    step: int,
    loss: float,
    best_loss: float,
    config: Dict[str, Any],
    path: str,
    logger: logging.Logger
) -> None:
    """Serialize the full training state to *path* via ``torch.save``.

    The checkpoint bundles model/optimizer/scheduler state dicts plus the
    config, step counter, current and best loss, and a wall-clock
    timestamp, so training can be resumed exactly.

    Args:
        model: Model whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        scheduler: LR scheduler; must expose ``state_dict()``.
        step: Optimizer step count at save time.
        loss: Most recent (averaged) loss value.
        best_loss: Best loss observed so far.
        config: Model configuration as a plain dict.
        path: Destination file path; parent directories are created.
        logger: Logger for success/failure messages.

    Raises:
        Exception: re-raises any failure after logging it, so callers
            never continue with a silently missing checkpoint.
    """
    try:
        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "config": config,
            "step": step,
            "loss": loss,
            "best_loss": best_loss,
            "timestamp": time.time(),
        }

        # Bug fix: os.path.dirname(path) is '' for a bare filename, and
        # os.makedirs('') raises FileNotFoundError. Only create the
        # directory when the path actually contains one.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

        torch.save(checkpoint, path)
        logger.info(f"💾 Checkpoint saved: {path} (loss: {loss:.4f})")

    except Exception as e:
        logger.error(f"❌ Failed to save checkpoint {path}: {e}")
        raise
|
|
|
|
|
def validate_training_setup(
    config_path: str,
    data_config_path: str,
    logger: logging.Logger
) -> None:
    """Fail fast on a misconfigured run before building any heavy state.

    Checks that both config files exist, the model builds with exactly
    25M parameters, at least one data source is configured, and the
    tokenizer vocab matches the model config.

    Args:
        config_path: Path to the model config JSON.
        data_config_path: Path to the YAML data-source config.
        logger: Logger for progress messages.

    Raises:
        FileNotFoundError: if either config file is missing.
        ValueError: if the parameter count, data sources, or vocab sizes
            are wrong.
    """
    logger.info("🔍 Validating training setup...")

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Model config not found: {config_path}")
    if not os.path.exists(data_config_path):
        raise FileNotFoundError(f"Data config not found: {data_config_path}")

    cfg = ModelConfig.from_json_file(config_path)
    cfg.assert_exact_params(expected=25_000_000)
    model = SupernovaModel(cfg)
    total_params = sum(p.numel() for p in model.parameters())
    # `assert` statements are stripped under `python -O`; raise explicitly
    # so these validations always run in production.
    if total_params != 25_000_000:
        raise ValueError(
            f"Model has {total_params:,} parameters, expected 25,000,000"
        )

    sources = load_sources_from_yaml(data_config_path)
    if not sources:
        raise ValueError("No data sources configured")

    tok = load_gpt2_tokenizer()
    if tok.vocab_size != cfg.vocab_size:
        raise ValueError(
            f"Tokenizer vocab size {tok.vocab_size} != model vocab size {cfg.vocab_size}"
        )

    logger.info("✅ Training setup validation complete")
|
|
|
|
|
def train_production(
    config_path: str,
    data_config_path: str,
    seq_len: int = 1024,
    batch_size: int = 16,
    grad_accum: int = 8,
    lr: float = 3e-4,
    warmup_steps: int = 2000,
    max_steps: int = 100_000,
    save_every: int = 10_000,
    log_every: int = 50,
    out_dir: str = "checkpoints",
    seed: int = 42,
    max_grad_norm: float = 1.0,
    enable_mixed_precision: bool = True,
) -> None:
    """Production training with full monitoring and optimization.

    Runs the complete loop: setup validation, model/tokenizer/data
    construction, AdamW + cosine-warmup optimization with gradient
    accumulation, optional CUDA mixed precision, periodic logging and
    checkpointing, and a final checkpoint on exit.

    Args:
        config_path: Path to the model config JSON.
        data_config_path: Path to the YAML data-source config.
        seq_len: Token sequence length per sample.
        batch_size: Micro-batch size fed to the model per forward pass.
        grad_accum: Micro-batches accumulated per optimizer step.
        lr: Peak AdamW learning rate.
        warmup_steps: Warmup steps for the cosine schedule.
        max_steps: Total optimizer steps to run.
        save_every: Checkpoint every N optimizer steps (falsy disables).
        log_every: Log metrics every N optimizer steps.
        out_dir: Directory for logs and checkpoints.
        seed: RNG seed for torch (and CUDA when available).
        max_grad_norm: Gradient-clipping threshold.
        enable_mixed_precision: Use CUDA AMP when a GPU is present.
    """

    logger = setup_logging(out_dir)
    logger.info("🚀 SUPERNOVA PRODUCTION TRAINING STARTED")
    logger.info("=" * 60)

    # Fail fast on missing configs / wrong model size before allocating
    # any heavy state.
    validate_training_setup(config_path, data_config_path, logger)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"📱 Device: {device}")
    logger.info(f"🌱 Seed: {seed}")

    cfg = ModelConfig.from_json_file(config_path)
    cfg.assert_exact_params(expected=25_000_000)  # hard 25M-parameter budget
    logger.info(f"⚙️ Model: {cfg.n_layers} layers, {cfg.d_model} d_model, {cfg.n_heads} heads")

    tok = load_gpt2_tokenizer()
    logger.info(f"🔤 Tokenizer: {tok.vocab_size:,} vocab size")

    model = SupernovaModel(cfg).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"🧠 Model: {total_params:,} parameters")

    # GradScaler only applies to CUDA AMP; CPU runs train in full precision.
    scaler = torch.cuda.amp.GradScaler() if enable_mixed_precision and torch.cuda.is_available() else None
    if scaler:
        logger.info("⚡ Mixed precision training enabled")

    logger.info("📚 Loading datasets...")
    sources = load_sources_from_yaml(data_config_path)
    logger.info(f"📊 Data sources: {len(sources)} sources loaded")
    for i, source in enumerate(sources):
        logger.info(f"   {i+1}. {source.name} (weight: {source.weight})")

    ds = TokenChunkDataset(tok, sources, seq_len=seq_len, eos_token_id=tok.eos_token_id)
    # shuffle=False / num_workers=0: presumably TokenChunkDataset streams
    # sequential chunks and handles its own ordering — TODO confirm.
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0)
    logger.info(f"🔄 DataLoader: batch_size={batch_size}, seq_len={seq_len}")

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1
    )
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps
    )

    logger.info(f"🎯 Training configuration:")
    logger.info(f"   Learning rate: {lr}")
    logger.info(f"   Warmup steps: {warmup_steps:,}")
    logger.info(f"   Max steps: {max_steps:,}")
    logger.info(f"   Gradient accumulation: {grad_accum}")
    logger.info(f"   Max gradient norm: {max_grad_norm}")
    logger.info(f"   Save every: {save_every:,} steps")
    logger.info(f"   Log every: {log_every} steps")

    model.train()
    step = 0            # optimizer steps taken so far
    micro = 0           # micro-batches processed (grad_accum per optimizer step)
    running_loss = 0.0  # sum of (loss / grad_accum) values since the last log
    best_loss = float('inf')
    start_time = time.time()

    logger.info("🏃 Starting training loop...")
    logger.info("=" * 60)

    try:
        # Outer loop restarts the DataLoader when the dataset is exhausted
        # before max_steps is reached.
        while step < max_steps:
            for batch in dl:
                x, y = batch
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)

                # Model returns (logits, loss); the loss is divided by
                # grad_accum so accumulated gradients average correctly.
                if scaler:
                    with torch.cuda.amp.autocast():
                        logits, loss = model(x, y)
                        loss = loss / grad_accum
                else:
                    logits, loss = model(x, y)
                    loss = loss / grad_accum

                if scaler:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                micro += 1
                running_loss += loss.item()

                # One optimizer step per grad_accum micro-batches.
                if micro % grad_accum == 0:
                    if scaler:
                        # Unscale before clipping so the norm threshold
                        # applies to true gradient magnitudes.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                        optimizer.step()

                    optimizer.zero_grad(set_to_none=True)
                    scheduler.step()
                    step += 1

                    if step % log_every == 0:
                        # NOTE(review): zero_grad(set_to_none=True) just ran,
                        # so every .grad is None here and this always reads
                        # 0.0 — the norm should probably be captured before
                        # the optimizer step. Verify.
                        grad_norm = compute_grad_norm(model)
                        # NOTE(review): running_loss already sums
                        # loss/grad_accum over grad_accum micro-batches per
                        # step, so multiplying by grad_accum appears to
                        # overstate the mean loss by a factor of grad_accum
                        # (assuming model() returns a mean loss) — verify.
                        avg_loss = running_loss * grad_accum / log_every
                        running_loss = 0.0
                        lr_now = scheduler.get_last_lr()[0]
                        elapsed = time.time() - start_time

                        memory = get_memory_usage()

                        # Throughput over the whole run: each optimizer step
                        # consumed grad_accum micro-batches of batch_size
                        # sequences of seq_len tokens.
                        tokens_per_sec = (step * batch_size * seq_len * grad_accum) / elapsed

                        log_msg = (
                            f"Step {step:6d} | Loss: {avg_loss:.4f} | Grad: {grad_norm:.3f} | "
                            f"LR: {lr_now:.2e} | {tokens_per_sec:.0f} tok/s"
                        )

                        if memory['allocated'] > 0:
                            log_msg += f" | Mem: {memory['allocated']:.1f}GB"

                        logger.info(log_msg)

                        if avg_loss < best_loss:
                            best_loss = avg_loss
                            logger.info(f"💫 New best loss: {best_loss:.4f}")

                    if save_every and step % save_every == 0:
                        ckpt_path = os.path.join(out_dir, f"supernova_step{step}.pt")
                        # avg_loss only exists once a logging step has run;
                        # fall back to 0.0 before the first log.
                        save_checkpoint(
                            model, optimizer, scheduler, step, avg_loss if 'avg_loss' in locals() else 0.0,
                            best_loss, cfg.__dict__, ckpt_path, logger
                        )

                    # Break out of the epoch; the outer while-condition then
                    # ends the run.
                    if step >= max_steps:
                        break

                # Periodically release cached CUDA blocks to limit allocator
                # growth on long runs.
                if torch.cuda.is_available() and micro % 100 == 0:
                    torch.cuda.empty_cache()

    except KeyboardInterrupt:
        # Ctrl-C falls through to the final checkpoint below.
        logger.info("\n⏹️ Training interrupted by user")
    except Exception as e:
        # Any other failure is logged and re-raised, skipping the final save.
        logger.error(f"\n❌ Training failed: {e}")
        raise

    final_path = os.path.join(out_dir, "supernova_final.pt")
    # NOTE(review): micro % grad_accum counts leftover micro-batches since
    # the last optimizer step, not the number of micro-batches summed into
    # running_loss since the last log — this estimate looks off; verify.
    final_loss = running_loss * grad_accum / max(1, micro % grad_accum) if running_loss > 0 else best_loss
    save_checkpoint(model, optimizer, scheduler, step, final_loss, best_loss, cfg.__dict__, final_path, logger)

    total_time = time.time() - start_time
    total_tokens = step * batch_size * seq_len * grad_accum

    logger.info("\n" + "=" * 60)
    logger.info("🎉 TRAINING COMPLETE!")
    logger.info(f"📈 Final step: {step:,}")
    logger.info(f"🏆 Best loss: {best_loss:.4f}")
    logger.info(f"⏱️ Total time: {format_time(total_time)}")
    logger.info(f"🔢 Total tokens: {total_tokens:,}")
    logger.info(f"⚡ Average throughput: {total_tokens/total_time:.0f} tokens/sec")
    logger.info(f"💾 Final checkpoint: {final_path}")
    logger.info("=" * 60)
|
|
|
|
|
def main():
    """Parse command-line arguments and launch a production training run."""
    parser = argparse.ArgumentParser(description="Production Supernova Training")
    parser.add_argument("--config", required=True, help="Path to model config")
    parser.add_argument("--data-config", required=True, help="Path to data config")
    parser.add_argument("--seq-len", type=int, default=1024, help="Sequence length")
    parser.add_argument("--batch-size", type=int, default=16, help="Batch size")
    parser.add_argument("--grad-accum", type=int, default=8, help="Gradient accumulation")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--warmup-steps", type=int, default=2000, help="Warmup steps")
    parser.add_argument("--max-steps", type=int, default=100000, help="Max training steps")
    parser.add_argument("--save-every", type=int, default=10000, help="Save frequency")
    parser.add_argument("--log-every", type=int, default=50, help="Log frequency")
    parser.add_argument("--out-dir", default="checkpoints", help="Output directory")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--max-grad-norm", type=float, default=1.0, help="Gradient clipping")
    parser.add_argument("--no-mixed-precision", action="store_true", help="Disable mixed precision")

    # Every remaining option name matches a train_production parameter
    # one-to-one (argparse turns dashes into underscores), so the parsed
    # namespace can be forwarded directly after handling the three
    # arguments whose names differ or are inverted.
    opts = vars(parser.parse_args())
    enable_amp = not opts.pop("no_mixed_precision")
    train_production(
        config_path=opts.pop("config"),
        data_config_path=opts.pop("data_config"),
        enable_mixed_precision=enable_amp,
        **opts,
    )
|
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|