|
|
|
|
|
|
|
|
| import os
|
| import math
|
| import time
|
| import json
|
| from pathlib import Path
|
| from dataclasses import dataclass
|
| from typing import Optional
|
|
|
|
|
| import torch
|
| import torch.nn as nn
|
| from torch.utils.data import Dataset, DataLoader
|
| from torch.optim.lr_scheduler import CosineAnnealingLR
|
| from accelerate import Accelerator
|
| from tqdm import tqdm
|
|
|
|
|
| from model_neo import NeoMini, NeoMiniConfig
|
|
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Hyper-parameters and file locations for one NeoMini training run.

    The defaults are tuned for a single GPU with about 8GB of VRAM. They use
    single-sequence micro-batches, gradient accumulation, bf16 mixed
    precision and gradient checkpointing.
    """

    # ---- data ---------------------------------------------------------------
    data_path: str = "data/tokens/packed_1024.txt"  # one packed token line per sequence
    seq_length: int = 1024  # tokens expected on every packed line

    # ---- model --------------------------------------------------------------
    # None (or a path that does not exist) -> default NeoMiniConfig().
    model_config_path: Optional[str] = None

    # ---- batching / schedule ------------------------------------------------
    batch_size: int = 1  # sequences per micro-batch
    gradient_accumulation_steps: int = 32  # micro-batches per optimizer update
    max_steps: int = 150_000  # training-loop iterations (one micro-batch each)
    warmup_steps: int = 3_750  # length of the linear LR warmup

    # ---- resuming -----------------------------------------------------------
    # Only used when the file actually exists; otherwise training starts fresh.
    resume_from_checkpoint: Optional[str] = "checkpoints/checkpoint_step_15000.pt"

    # ---- AdamW --------------------------------------------------------------
    learning_rate: float = 3e-4
    weight_decay: float = 1e-2
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0  # max global gradient norm

    # ---- precision / memory -------------------------------------------------
    mixed_precision: str = "bf16"
    gradient_checkpointing: bool = True

    # ---- logging & checkpointing --------------------------------------------
    log_interval: int = 10
    eval_interval: int = 500  # NOTE(review): not read anywhere in this file
    save_interval: int = 7_500
    output_dir: str = "checkpoints"

    # ---- misc ---------------------------------------------------------------
    compile_model: bool = False  # NOTE(review): not read anywhere in this file
|
|
|
|
|
|
|
class PackedDataset(Dataset):
    """Dataset for pre-tokenized and packed sequences.

    Each non-blank line of ``data_path`` should contain exactly ``seq_length``
    whitespace-separated integer token ids. Lines of any other length are
    skipped, and how many were skipped is recorded in ``num_skipped`` and
    reported. Blank lines are ignored.

    Item ``i`` is a next-token-prediction pair ``(input_ids, targets)``: two
    ``torch.long`` tensors of length ``seq_length - 1``, where ``targets`` is
    ``input_ids`` shifted left by one position.
    """

    def __init__(self, data_path: str, seq_length: int = 1024):
        self.data_path = Path(data_path)
        self.seq_length = seq_length
        # Each row is kept as a compact int64 tensor. A Python list of ints
        # costs a pointer plus a boxed int object per token, which is several
        # times larger for big corpora.
        self.sequences = []
        # Count of non-blank lines dropped because their length != seq_length.
        self.num_skipped = 0

        print(f"Loading data from {data_path}...")
        with open(self.data_path, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # blank line: not data
                if len(fields) != seq_length:
                    self.num_skipped += 1
                    continue
                self.sequences.append(
                    torch.tensor([int(tok) for tok in fields], dtype=torch.long)
                )

        print(f"Loaded {len(self.sequences)} sequences of length {seq_length}")
        if self.num_skipped:
            # Previously these rows vanished silently, hiding packing bugs.
            print(f"Warning: skipped {self.num_skipped} line(s) whose token count != {seq_length}")

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        tokens = self.sequences[idx]
        # clone() so callers get fresh tensors rather than views into the
        # dataset's storage.
        input_ids = tokens[:-1].clone()
        targets = tokens[1:].clone()
        return input_ids, targets
|
|
|
|
|
|
|
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1):
    """Cosine learning rate schedule with linear warmup and an LR floor.

    The multiplier applied to each parameter group's base LR at scheduler
    step ``s`` is:

    * ``s / num_warmup_steps`` while ``s < num_warmup_steps``;
    * ``min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + cos(pi * progress))``
      afterwards, where ``progress`` goes from 0 at the end of warmup to 1 at
      ``num_training_steps``;
    * ``min_lr_ratio`` for every step past ``num_training_steps``.

    Args:
        optimizer: optimizer whose LR is scheduled.
        num_warmup_steps: number of scheduler steps of linear warmup.
        num_training_steps: scheduler step at which the cosine reaches its floor.
        min_lr_ratio: final LR as a fraction of the base LR.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR``.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return current_step / max(1, num_warmup_steps)

        progress = (current_step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        # Clamp: beyond num_training_steps the cosine would otherwise rise
        # back towards the peak LR.
        progress = min(1.0, progress)
        return min_lr_ratio + (1 - min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
|
|
|
|
|
|
def compute_loss(logits, targets):
    """Compute mean token-level cross-entropy.

    Args:
        logits: tensor whose last dimension is the vocabulary. All leading
            dimensions are flattened into rows.
        targets: integer class ids with one entry per flattened logits row.
            Entries equal to -100 are ignored.

    Returns:
        Scalar loss tensor (mean over non-ignored targets).
    """
    # reshape (not view) so non-contiguous inputs (e.g. after a transpose)
    # also work; it only copies when a view is impossible.
    logits_flat = logits.reshape(-1, logits.size(-1))
    targets_flat = targets.reshape(-1)

    return nn.functional.cross_entropy(logits_flat, targets_flat, ignore_index=-100)
|
|
|
|
|
|
|
def save_checkpoint(model, optimizer, scheduler, step, loss, config, checkpoint_dir):
    """Save a training checkpoint as ``checkpoint_dir/checkpoint_step_{step}.pt``.

    The checkpoint dict holds the model, optimizer and scheduler state dicts,
    plus ``step``, ``loss`` and ``config.__dict__``. The file is written to a
    temporary name and then atomically renamed. An interrupted save therefore
    never leaves a truncated file under the final checkpoint name, which
    resume logic would otherwise try to load.

    If ``model`` has a ``config`` attribute, ``model.config.to_dict()`` is also
    written to ``checkpoint_dir/model_config.json``.

    Returns:
        ``pathlib.Path`` of the written checkpoint.
    """
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'step': step,
        'loss': loss,
        'config': config.__dict__
    }

    checkpoint_path = checkpoint_dir / f"checkpoint_step_{step}.pt"
    tmp_path = checkpoint_path.with_name(checkpoint_path.name + ".tmp")
    try:
        torch.save(checkpoint, tmp_path)
        # os.replace is atomic when source and target are on the same filesystem.
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        # Do not leave partial temp files behind (disk full, Ctrl-C, ...).
        if tmp_path.exists():
            tmp_path.unlink()
        raise

    if hasattr(model, 'config'):
        config_path = checkpoint_dir / "model_config.json"
        with open(config_path, 'w') as f:
            json.dump(model.config.to_dict(), f, indent=2)

    print(f"Checkpoint saved: {checkpoint_path}")
    return checkpoint_path
|
|
|
|
|
|
|
def load_checkpoint(checkpoint_path, model, optimizer=None, scheduler=None):
    """Load a checkpoint written by ``save_checkpoint`` and restore state.

    The model state is always restored. The optimizer and scheduler are
    restored only when passed (not ``None``), so the same helper can load
    weights alone.

    Args:
        checkpoint_path: path to the ``.pt`` file.
        model: module receiving ``checkpoint['model_state_dict']``.
        optimizer: optional optimizer receiving ``checkpoint['optimizer_state_dict']``.
        scheduler: optional LR scheduler receiving ``checkpoint['scheduler_state_dict']``.

    Returns:
        ``(step, loss)`` as stored in the checkpoint. ``loss`` is ``None`` if
        the checkpoint has no ``'loss'`` entry.

    Raises:
        KeyError: if a required state dict or ``'step'`` is missing.
    """
    print(f"Loading checkpoint from {checkpoint_path}...")

    # NOTE(review): torch.load unpickles data; only load checkpoints you trust.
    # map_location='cpu' avoids needing the original device; load_state_dict
    # copies tensors onto each parameter's current device.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    start_step = checkpoint['step']
    last_loss = checkpoint.get('loss')

    print("✅ Checkpoint loaded successfully!")
    print(f" Resuming from step: {start_step}")
    if last_loss is not None:
        print(f" Last loss: {last_loss:.4f}")

    return start_step, last_loss
|
|
|
|
|
|
|
def generate_sample(model, tokenizer, prompt="The future of AI", max_length=100, temperature=0.8):
    """Generate a text sample for qualitative evaluation.

    Tokens are produced one at a time by re-running ``model`` on the whole
    sequence so far.

    Args:
        model: module mapping ``input_ids`` to logits indexed as
            ``logits[batch, position, vocab]``.
        tokenizer: object providing ``encode(prompt, return_tensors="pt")``,
            ``decode(ids, skip_special_tokens=True)`` and ``eos_token_id``.
        prompt: text to continue.
        max_length: maximum number of new tokens to generate.
        temperature: softmax temperature. Values ``<= 0`` select the argmax
            token (greedy decoding) instead of sampling.

    Returns:
        The decoded prompt plus its continuation.

    The model's train/eval mode on entry is restored on exit, even if
    generation raises.
    """
    was_training = model.training
    model.eval()
    try:
        device = next(model.parameters()).device
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

        with torch.no_grad():
            for _ in range(max_length):
                logits = model(input_ids)
                next_token_logits = logits[0, -1, :]

                if temperature > 0:
                    probs = torch.softmax(next_token_logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    # Dividing by a zero temperature would give inf/nan
                    # logits; greedy decoding is the limiting behaviour.
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                # next_token has shape (1,); make it (1, 1) to append a column.
                input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

                if next_token.item() == tokenizer.eos_token_id:
                    break

        generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    finally:
        model.train(was_training)
    return generated_text
|
|
|
|
|
|
|
def main():
    """Train NeoMini with Hugging Face Accelerate using ``TrainingConfig`` defaults.

    1. Build the Accelerator (mixed precision, gradient accumulation, TensorBoard logging).
    2. Load the packed dataset and build the model, optimizer and LR schedule.
    3. Resume from ``config.resume_from_checkpoint`` if that file exists.
    4. Run training-loop iterations up to ``config.max_steps``. Log every
       ``log_interval`` iterations and save a checkpoint every ``save_interval``.
    5. Save a final checkpoint.

    Each loop iteration ("step") processes one micro-batch. The optimizer and
    LR scheduler advance only on gradient-sync iterations, i.e. once every
    ``gradient_accumulation_steps`` micro-batches.
    """
    config = TrainingConfig()

    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_dir=config.output_dir
    )
    # accelerator.log() only writes to trackers created by init_trackers();
    # without this call the training metrics were silently discarded.
    accelerator.init_trackers("neo_mini")

    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("Loading dataset...")
    dataset = PackedDataset(config.data_path, config.seq_length)
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=0,
        persistent_workers=False
    )

    print("Creating model...")
    if config.model_config_path and Path(config.model_config_path).exists():
        model = NeoMini.from_config(config.model_config_path)
    else:
        model_config = NeoMiniConfig()
        model = NeoMini(model_config)

    print(f"Model has {model.get_num_params():,} parameters")

    if config.gradient_checkpointing:
        # Call the model's own hook if it has one. Otherwise say so instead of
        # claiming checkpointing is on; replacing the method with a no-op
        # lambda never enabled anything.
        enable_checkpointing = getattr(model, "gradient_checkpointing_enable", None)
        if callable(enable_checkpointing):
            enable_checkpointing()
            print("Gradient checkpointing enabled")
        else:
            print("Warning: gradient_checkpointing=True but the model has no "
                  "gradient_checkpointing_enable(); training without it")

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(config.beta1, config.beta2),
        weight_decay=config.weight_decay
    )

    # warmup_steps / max_steps count loop iterations (micro-batches), but the
    # accelerate-prepared scheduler only advances when gradients sync, once per
    # `gradient_accumulation_steps` micro-batches. Convert to optimizer updates
    # so warmup and cosine decay cover the run as configured. Previously warmup
    # lasted ~80% of training and the cosine decay barely started.
    # NOTE(review): resuming a checkpoint made with the old schedule restores
    # its scheduler step count under this new curve, so expect an LR jump.
    # NOTE(review): with several processes (and split_batches=False) accelerate
    # steps the scheduler once per process; re-check these totals for multi-GPU.
    accum = max(1, config.gradient_accumulation_steps)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=math.ceil(config.warmup_steps / accum),
        num_training_steps=math.ceil(config.max_steps / accum)
    )

    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler
    )

    start_step = 0
    total_loss = 0.0  # sum of micro-batch losses since step 0
    loss_count = 0    # number of losses summed into total_loss

    if config.resume_from_checkpoint and Path(config.resume_from_checkpoint).exists():
        unwrapped_model = accelerator.unwrap_model(model)
        resumed_step, last_loss = load_checkpoint(
            config.resume_from_checkpoint,
            unwrapped_model,
            optimizer,
            scheduler
        )
        # The checkpoint is written at the end of iteration `resumed_step`.
        # Continue with the next iteration rather than repeating that one
        # (and immediately re-saving the same checkpoint).
        start_step = resumed_step + 1
        # The saved loss is total_loss / loss_count over iterations
        # 0..resumed_step. Seed both so the running mean stays a whole-run
        # mean instead of dividing a whole-run sum by the steps since resume.
        loss_count = start_step
        total_loss = last_loss * loss_count
        print(f"🚀 Resuming training from step {start_step}")
    else:
        print("🚀 Starting fresh training")

    print("Starting training...")
    model.train()

    log_loss = 0.0
    log_count = 0
    last_step = None        # last iteration actually executed
    last_saved_step = None  # iteration of the most recent periodic checkpoint
    last_saved_path = None

    dataloader_iter = iter(dataloader)
    progress_bar = tqdm(range(start_step, config.max_steps), desc="Training")

    for step in progress_bar:
        # Cycle the dataloader indefinitely: the run is bounded by max_steps,
        # not by epochs.
        try:
            batch = next(dataloader_iter)
        except StopIteration:
            dataloader_iter = iter(dataloader)
            batch = next(dataloader_iter)

        input_ids, targets = batch

        with accelerator.accumulate(model):
            logits = model(input_ids)
            loss = compute_loss(logits, targets)
            accelerator.backward(loss)

            # Clip and update only when accumulated gradients are synced,
            # i.e. once per `gradient_accumulation_steps` micro-batches.
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), config.grad_clip)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

        loss_value = loss.item()  # single device->host sync per iteration
        total_loss += loss_value
        loss_count += 1
        log_loss += loss_value
        log_count += 1
        last_step = step

        if step % config.log_interval == 0 and step > 0:
            # Divide by the number of losses actually collected; the first
            # window (and the first one after a resume) is not log_interval long.
            avg_loss = log_loss / log_count
            lr = scheduler.get_last_lr()[0]

            progress_bar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'lr': f'{lr:.2e}',
                'step': step
            })

            accelerator.log({
                'train_loss': avg_loss,
                'learning_rate': lr,
                'step': step
            }, step=step)

            log_loss = 0.0
            log_count = 0

        if step % config.save_interval == 0 and step > 0:
            if accelerator.is_main_process:
                unwrapped_model = accelerator.unwrap_model(model)
                last_saved_path = save_checkpoint(
                    unwrapped_model, optimizer, scheduler,
                    step, total_loss / loss_count, config, output_dir
                )
                last_saved_step = step

    if accelerator.is_main_process:
        if last_step is None:
            # Loop body never ran (start_step >= max_steps); nothing new to save.
            print(f"Nothing to train: start step {start_step} >= max_steps {config.max_steps}")
        else:
            if last_saved_step == last_step:
                # This exact state was just saved by the periodic checkpoint.
                final_checkpoint = last_saved_path
            else:
                unwrapped_model = accelerator.unwrap_model(model)
                final_checkpoint = save_checkpoint(
                    unwrapped_model, optimizer, scheduler,
                    last_step, total_loss / loss_count, config, output_dir
                )
            print(f"Training completed! Final checkpoint: {final_checkpoint}")

    accelerator.end_training()
|
|
|
|
|
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|