| """ |
| Training utilities with budget-aware scheduling, energy metrics, and sweep support. |
| |
| v3 features: |
| - Budget-constrained training (auto-adjusts ranks to meet param/latency targets) |
| - Energy estimation (FLOPs-based proxy) |
| - Knowledge distillation support |
| - Gradient monitoring and NaN detection |
| - Checkpointing with metadata |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.optim import AdamW |
| from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR |
| import math |
| import time |
| from typing import Optional, Dict, Tuple, List |
| from pathlib import Path |
| import json |
|
|
| from .config import ExperimentConfig |
| from .budget import BudgetTracker, EnergyEstimator |
|
|
|
|
def create_optimizer(model: nn.Module, lr: float, weight_decay: float,
                     betas: Tuple[float, float] = (0.9, 0.98),
                     eps: float = 1e-8) -> AdamW:
    """Build an AdamW optimizer that exempts biases and LayerNorm weights from decay.

    A trainable parameter skips weight decay when its qualified name contains
    any of ``"bias"``, ``"LayerNorm.weight"``, ``"layernorm.weight"`` or
    ``"ln.weight"`` (plain substring match). Every other trainable parameter
    decays with ``weight_decay``. Parameters with ``requires_grad=False`` are
    not handed to the optimizer at all.

    Returns:
        AdamW with exactly two param groups, in this order:
        ``[decayed params, non-decayed params (weight_decay=0.0)]``.
    """
    exempt_markers = ("bias", "LayerNorm.weight", "layernorm.weight", "ln.weight")
    decayed: List[nn.Parameter] = []
    undecayed: List[nn.Parameter] = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        exempt = any(marker in name for marker in exempt_markers)
        (undecayed if exempt else decayed).append(param)

    param_groups = [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": undecayed, "weight_decay": 0.0},
    ]
    return AdamW(param_groups, lr=lr, betas=betas, eps=eps)
|
|
|
|
def create_scheduler(optimizer, warmup_steps: int, max_steps: int,
                     lr_min_factor: float = 0.1, scheduler_type: str = "cosine"):
    """Create a per-step learning-rate scheduler with optional linear warmup.

    Args:
        optimizer: Optimizer whose param groups are scheduled.
        warmup_steps: Steps of linear warmup from ``1e-3 * lr`` to ``lr``.
            When ``<= 0`` no warmup is used and the main scheduler is
            returned directly.
        max_steps: Total number of scheduler steps planned for training.
        lr_min_factor: Final LR as a fraction of the base LR ("cosine" floor,
            "linear" end point).
        scheduler_type: ``"cosine"`` (CosineAnnealingWarmRestarts, one cycle
            over the post-warmup steps), ``"linear"`` (linear decay), anything
            else keeps the LR constant after warmup.

    Returns:
        ``SequentialLR([warmup, main])`` when ``warmup_steps > 0``, otherwise
        the main scheduler alone.
    """
    # Read the base LR *before* creating any scheduler: constructing LinearLR
    # immediately rescales group["lr"] by its start_factor (1e-3), so reading
    # it afterwards made the cosine floor 1000x too small. "initial_lr" is what
    # torch itself uses as the base LR when a scheduler already ran.
    first_group = optimizer.param_groups[0]
    base_lr = first_group.get("initial_lr", first_group["lr"])

    # Non-positive cycle lengths make CosineAnnealingWarmRestarts raise.
    decay_steps = max(1, max_steps - warmup_steps)

    # With warmup_steps == 0 a LinearLR warmup would still scale the LR by
    # 1e-3 on construction, and a chainable main scheduler never undoes it.
    warmup = None
    if warmup_steps > 0:
        warmup = LinearLR(optimizer, start_factor=1e-3, end_factor=1.0,
                          total_iters=warmup_steps)

    if scheduler_type == "cosine":
        main = CosineAnnealingWarmRestarts(
            optimizer, T_0=decay_steps,
            T_mult=1, eta_min=lr_min_factor * base_lr,
        )
    elif scheduler_type == "linear":
        main = LinearLR(optimizer, start_factor=1.0,
                        end_factor=lr_min_factor,
                        total_iters=decay_steps)
    else:
        # Constant LR after warmup.
        main = LinearLR(optimizer, start_factor=1.0, end_factor=1.0,
                        total_iters=decay_steps)

    if warmup is None:
        return main
    return SequentialLR(optimizer, schedulers=[warmup, main],
                        milestones=[warmup_steps])
|
|
|
|
def compute_perplexity(logits: torch.Tensor, targets: torch.Tensor,
                       ignore_index: int = 0) -> float:
    """Return ``exp(mean cross-entropy)`` over targets not equal to ``ignore_index``.

    ``logits`` is flattened to ``(-1, logits.size(-1))`` and ``targets`` to
    ``(-1,)``, so any leading shape works as long as the two agree.

    Returns:
        The perplexity as a Python float; ``float("inf")`` when the loss is
        too large for ``math.exp`` (instead of raising ``OverflowError``).
        If every target equals ``ignore_index`` the mean loss is NaN and so
        is the result.
    """
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=ignore_index,
        reduction="mean",
    )
    try:
        return math.exp(loss.item())
    except OverflowError:
        # math.exp overflows for arguments above ~709.
        return float("inf")
|
|
|
|
class Trainer:
    """
    Budget-aware Q-TensorFormer trainer.

    Tracks:
    - Perplexity (primary metric)
    - Model size (parameters)
    - Latency estimates
    - Energy consumption (FLOPs proxy)
    - Quantum call statistics
    - Rank adaptation trajectories

    All losses are per-token averages over targets whose id differs from
    ``ignore_index``. Train, validation and test share the same token
    accounting, so their perplexities are directly comparable.
    """

    # Target id excluded from every loss and every token count.
    # NOTE(review): assumes id 0 is the tokenizer's padding id -- confirm.
    ignore_index: int = 0

    def __init__(self, model: nn.Module, config: "ExperimentConfig",
                 train_loader, val_loader=None, test_loader=None,
                 device: str = "cpu", output_dir: Optional[str] = None):
        """
        Args:
            model: Model to train; moved to ``device``.
            config: Experiment config. ``config.training`` supplies optimizer,
                scheduler and epoch settings; ``config.budget`` configures the
                budget tracker; ``config.output_dir`` is the default output dir.
            train_loader: Iterable of ``(inputs, targets)`` batches. ``len()``
                is used to size the LR schedule.
            val_loader: Optional validation batches, evaluated every epoch.
            test_loader: Optional test batches, used by :meth:`evaluate`.
            device: Torch device string.
            output_dir: Directory for checkpoints and metrics.
        """
        self.model = model
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = torch.device(device)
        self.output_dir = Path(output_dir or config.output_dir)

        self.model.to(self.device)

        # The scheduler is stepped once per optimizer step, not per epoch.
        total_steps = len(train_loader) * config.training.max_epochs
        self.optimizer = create_optimizer(
            model, config.training.learning_rate, config.training.weight_decay
        )
        self.scheduler = create_scheduler(
            self.optimizer,
            warmup_steps=config.training.warmup_steps,
            max_steps=total_steps,
            lr_min_factor=config.training.lr_min_factor,
            scheduler_type=config.training.lr_scheduler,
        )

        self.budget_tracker = BudgetTracker(config.budget)
        self.energy_estimator = EnergyEstimator()

        # One metrics dict per finished epoch, plus the total grad norm
        # returned by clip_grad_norm_ for every training step ever taken.
        self.metrics_history: List[Dict] = []
        self.grad_norms: List[float] = []

    def train_epoch(self, epoch: int) -> Dict:
        """Train for one epoch; return its metrics and append them to history.

        Steps whose gradient norm is NaN/Inf are skipped: no optimizer or
        scheduler step, and their loss is not counted. ``grad_norm_mean`` is
        the mean over this epoch's finite norms; ``skipped_steps`` counts the
        rest. Validation metrics are merged in when ``val_loader`` is set.
        """
        self.model.train()
        # Not every model (e.g. a dense baseline) exposes this hook.
        reset_schedulers = getattr(self.model, "reset_schedulers", None)
        if callable(reset_schedulers):
            reset_schedulers()

        total_loss = 0.0
        total_tokens = 0
        finite_grad_norms: List[float] = []
        skipped_steps = 0
        start_time = time.perf_counter()

        for step, (inputs, targets) in enumerate(self.train_loader):
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()

            logits, _stats = self.model(inputs, return_stats=True)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=self.ignore_index,
            )
            loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.config.training.max_grad_norm
            ).item()
            self.grad_norms.append(grad_norm)

            if not math.isfinite(grad_norm):
                print(f"[WARN] NaN/Inf gradient at step {step}. Skipping update.")
                self.optimizer.zero_grad()
                skipped_steps += 1
                continue
            finite_grad_norms.append(grad_norm)

            self.optimizer.step()
            self.scheduler.step()

            # `loss` is a mean over non-ignored targets; re-weight by that count
            # so the epoch average is per real token (padding excluded).
            n_tokens = self._count_tokens(targets)
            if n_tokens:
                total_loss += loss.item() * n_tokens
                total_tokens += n_tokens

        elapsed = time.perf_counter() - start_time
        avg_loss = total_loss / max(total_tokens, 1)
        # Clamp the exponent so a diverged run logs a large finite ppl.
        ppl = math.exp(min(avg_loss, 20.0))

        latency_est = self.budget_tracker.estimate_latency(
            self.model, self.config.model.max_seq_len
        )
        energy_est = self.energy_estimator.estimate(self.model)

        metrics = {
            "epoch": epoch,
            "train_loss": avg_loss,
            "train_ppl": ppl,
            "lr": self.optimizer.param_groups[0]["lr"],
            "grad_norm_mean": sum(finite_grad_norms) / max(len(finite_grad_norms), 1),
            "skipped_steps": skipped_steps,
            "total_params": sum(p.numel() for p in self.model.parameters()),
            "latency_ms": latency_est,
            "energy_uj": energy_est,
            "time_s": elapsed,
        }

        if hasattr(self.model, "stats"):
            metrics["model_stats"] = self.model.stats

        if self.val_loader is not None:
            metrics.update(self.validate())

        self.metrics_history.append(metrics)
        return metrics

    @torch.no_grad()
    def validate(self) -> Dict:
        """Return ``{"val_loss", "val_ppl"}`` computed per real token over ``val_loader``."""
        self.model.eval()
        avg_loss = self._loss_over(self.val_loader)
        return {
            "val_loss": avg_loss,
            "val_ppl": math.exp(min(avg_loss, 20.0)),
        }

    @torch.no_grad()
    def evaluate(self) -> Dict:
        """
        Full evaluation on the test set.

        Returns a dict with ``test_loss``/``test_ppl`` (per real token),
        ``latency_ms_mean`` (mean over batches of forward time divided by
        ``inputs.size(0)``), ``total_params`` (``model.total_params`` if the
        model defines it, else the raw parameter count), ``energy_uj`` and
        ``model_stats``.

        Raises:
            ValueError: if the trainer was built without a ``test_loader``.
        """
        if self.test_loader is None:
            raise ValueError("evaluate() requires a test_loader")
        self.model.eval()

        latency_samples: List[float] = []
        avg_loss = self._loss_over(self.test_loader, latency_samples)

        total_params = getattr(self.model, "total_params", None)
        if total_params is None:
            total_params = sum(p.numel() for p in self.model.parameters())

        return {
            "test_loss": avg_loss,
            "test_ppl": math.exp(min(avg_loss, 20.0)),
            "latency_ms_mean": sum(latency_samples) / max(len(latency_samples), 1),
            "total_params": total_params,
            "energy_uj": self.energy_estimator.estimate(self.model),
            "model_stats": getattr(self.model, "stats", {}),
        }

    def train(self) -> Dict:
        """Run all epochs, checkpointing "best" (by val ppl) and "last".

        Stops early when the budget tracker reports a violation. Returns the
        last epoch's metrics, or ``{}`` if no epoch ran.
        """
        best_val_ppl = float("inf")

        for epoch in range(self.config.training.max_epochs):
            metrics = self.train_epoch(epoch)

            print(f"Epoch {epoch+1}/{self.config.training.max_epochs}: "
                  f"train_ppl={metrics['train_ppl']:.2f} "
                  f"val_ppl={metrics.get('val_ppl', 'N/A')} "
                  f"lr={metrics['lr']:.2e}")

            if metrics.get("val_ppl", float("inf")) < best_val_ppl:
                best_val_ppl = metrics["val_ppl"]
                self.save_checkpoint("best")

            if self.budget_tracker.exceeds_budget(metrics, self.config.model):
                print("[BUDGET] Exceeded constraints. Stopping.")
                break

        self.save_checkpoint("last")
        self.save_metrics()
        return self.metrics_history[-1] if self.metrics_history else {}

    def save_checkpoint(self, tag: str = "checkpoint"):
        """Save model/optimizer state, config and metrics to ``<output_dir>/<tag>.pt``."""
        self.output_dir.mkdir(parents=True, exist_ok=True)
        path = self.output_dir / f"{tag}.pt"
        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "config": self.config,
            "metrics": self.metrics_history,
        }, path)
        print(f"Checkpoint saved to {path}")

    def load_checkpoint(self, tag: str = "best"):
        """Restore model and optimizer state from ``<output_dir>/<tag>.pt``.

        Prints a message and returns without changes if the file is missing.
        """
        import pickle

        path = self.output_dir / f"{tag}.pt"
        if not path.exists():
            print(f"Checkpoint {path} not found")
            return
        try:
            ckpt = torch.load(path, map_location=self.device, weights_only=True)
        except pickle.UnpicklingError:
            # save_checkpoint() pickles the config object (and whatever is in
            # model_stats), which the weights_only unpickler rejects.
            # SECURITY: full unpickling can execute arbitrary code -- only
            # load checkpoints this trainer wrote itself.
            print(f"[WARN] {path} holds non-tensor objects; loading with weights_only=False")
            ckpt = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(ckpt["model_state_dict"])
        self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])

    def save_metrics(self):
        """Write ``metrics_history`` as JSON to ``<output_dir>/metrics.json``."""
        self.output_dir.mkdir(parents=True, exist_ok=True)
        path = self.output_dir / "metrics.json"
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.metrics_history, f, indent=2, default=self._json_default)
        print(f"Metrics saved to {path}")

    # -- private helpers ---------------------------------------------------

    def _count_tokens(self, targets: torch.Tensor) -> int:
        """Number of targets that contribute to the loss (not ``ignore_index``)."""
        return int((targets != self.ignore_index).sum().item())

    def _sync_device(self) -> None:
        """Wait for queued CUDA work so wall-clock timings cover the kernels."""
        if self.device.type == "cuda":
            torch.cuda.synchronize(self.device)

    def _loss_over(self, loader, latency_samples: Optional[List[float]] = None) -> float:
        """Return the model's per-real-token cross-entropy over ``loader``.

        Both callers run it under ``torch.no_grad()``. When ``latency_samples``
        is given, one value per batch (forward ms / ``inputs.size(0)``) is
        appended to it.
        """
        total_loss = 0.0
        total_tokens = 0
        for inputs, targets in loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            if latency_samples is None:
                logits = self.model(inputs)
            else:
                self._sync_device()
                t0 = time.perf_counter()
                logits = self.model(inputs)
                self._sync_device()
                latency_samples.append((time.perf_counter() - t0) * 1000 / inputs.size(0))
            total_loss += F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=self.ignore_index,
                reduction="sum",
            ).item()
            # Divide by real tokens only -- dividing by inputs.numel() would
            # count padding and underestimate the loss.
            total_tokens += self._count_tokens(targets)
        return total_loss / max(total_tokens, 1)

    @staticmethod
    def _json_default(obj):
        """json.dump fallback: tensors become (nested) lists; anything else is an error."""
        if isinstance(obj, torch.Tensor):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
|
|
|
|
class DistillationTrainer(Trainer):
    """
    Knowledge distillation trainer.

    Student = compressed Q-TensorFormer.
    Teacher = dense (or larger) model.

    Per-step objective::

        loss = (1 - alpha) * CE(student, targets)
               + alpha * T**2 * KL(softmax(teacher / T) || softmax(student / T))

    Both terms are averaged over target tokens whose id is not 0 (the
    ``ignore_index`` used for the task loss), so ``alpha`` trades off two
    quantities on the same per-token scale.
    """

    def __init__(self, student: nn.Module, teacher: nn.Module, *args,
                 alpha: float = 0.5, temperature: float = 3.0, **kwargs):
        """
        Args:
            student: Compressed Q-TensorFormer; trained as ``self.model``.
            teacher: Dense baseline; moved to the trainer's device, put in
                eval mode and frozen (``requires_grad=False``).
            alpha: Weight of the distillation loss (α); the task loss gets 1-α.
            temperature: Softmax temperature applied to both logit sets.
            *args, **kwargs: Forwarded to ``Trainer.__init__`` after ``student``.
        """
        super().__init__(student, *args, **kwargs)
        self.teacher = teacher.to(self.device)
        self.teacher.eval()
        self.alpha = alpha
        self.temperature = temperature

        for p in self.teacher.parameters():
            p.requires_grad = False

    def _distillation_loss(self, student_logits: torch.Tensor,
                           teacher_logits: torch.Tensor,
                           mask: torch.Tensor) -> torch.Tensor:
        """Temperature-scaled KL(teacher || student), averaged over tokens where ``mask`` is True.

        Logits are flattened to ``(-1, vocab)`` so the average is per token.
        ``reduction="batchmean"`` on the unflattened logits divided by the
        batch size only, which summed over sequence positions and padding.
        Multiplying by ``T**2`` keeps gradient magnitudes comparable across
        temperatures.
        """
        vocab = student_logits.size(-1)
        t = self.temperature
        log_p_student = F.log_softmax(student_logits.reshape(-1, vocab) / t, dim=-1)
        p_teacher = F.softmax(teacher_logits.reshape(-1, vocab) / t, dim=-1)
        per_token = F.kl_div(log_p_student, p_teacher, reduction="none").sum(dim=-1)
        keep = mask.reshape(-1).to(per_token.dtype)
        return (per_token * keep).sum() / keep.sum().clamp_min(1.0) * (t ** 2)

    def train_epoch(self, epoch: int) -> Dict:
        """Run one distillation epoch; return its metrics and append them to history.

        Mirrors ``Trainer.train_epoch``:
        - steps with a NaN/Inf gradient norm are skipped;
        - ``train_loss``/``train_ppl`` report the task (CE) loss per real token;
        - validation metrics are merged in when ``val_loader`` is set, so
          ``Trainer.train`` can track the best checkpoint.
        """
        self.model.train()
        # Same per-epoch hook the base trainer calls, when the model has it.
        reset_schedulers = getattr(self.model, "reset_schedulers", None)
        if callable(reset_schedulers):
            reset_schedulers()

        total_loss = 0.0
        total_tokens = 0
        finite_grad_norms: List[float] = []

        for step, (inputs, targets) in enumerate(self.train_loader):
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()

            logits, _stats = self.model(inputs, return_stats=True)

            # Hard-label loss on the real targets (id 0 ignored).
            task_loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,
            )

            # Soft-label loss against the frozen teacher.
            with torch.no_grad():
                teacher_logits = self.teacher(inputs)
            mask = targets != 0
            distill_loss = self._distillation_loss(logits, teacher_logits, mask)

            loss = (1 - self.alpha) * task_loss + self.alpha * distill_loss
            loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.config.training.max_grad_norm
            ).item()
            self.grad_norms.append(grad_norm)
            if not math.isfinite(grad_norm):
                print(f"[WARN] NaN/Inf gradient at step {step}. Skipping update.")
                self.optimizer.zero_grad()
                continue
            finite_grad_norms.append(grad_norm)

            self.optimizer.step()
            self.scheduler.step()

            # task_loss is a mean over real tokens; re-weight by that count.
            n_tokens = int(mask.sum().item())
            if n_tokens:
                total_loss += task_loss.item() * n_tokens
                total_tokens += n_tokens

        avg_loss = total_loss / max(total_tokens, 1)
        ppl = math.exp(min(avg_loss, 20.0))
        metrics = {
            "epoch": epoch,
            "train_loss": avg_loss,
            "train_ppl": ppl,
            "lr": self.optimizer.param_groups[0]["lr"],
            "grad_norm_mean": sum(finite_grad_norms) / max(len(finite_grad_norms), 1),
        }

        if self.val_loader is not None:
            metrics.update(self.validate())

        # Trainer.train() and save_metrics() read metrics_history.
        self.metrics_history.append(metrics)
        return metrics
|
|