"""
Training utilities with budget-aware scheduling, energy metrics, and sweep support.

v3 features:
- Budget-constrained training (auto-adjusts ranks to meet param/latency targets)
- Energy estimation (FLOPs-based proxy)
- Knowledge distillation support
- Gradient monitoring and NaN detection
- Checkpointing with metadata
"""

import dataclasses
import json
import math
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR

from .config import ExperimentConfig
from .budget import BudgetTracker, EnergyEstimator

# Target id used as the pad token: excluded from every loss and from token counts.
_PAD_ID = 0

# Reported perplexities use exp(min(loss, _PPL_LOSS_CAP)) so that early or
# diverged losses do not overflow math.exp.
_PPL_LOSS_CAP = 20.0


def _capped_ppl(avg_loss: float) -> float:
    """Return exp(avg_loss) with the exponent capped at ``_PPL_LOSS_CAP``."""
    return math.exp(min(avg_loss, _PPL_LOSS_CAP))


def _count_target_tokens(targets: torch.Tensor, ignore_index: int = _PAD_ID) -> int:
    """Return how many target positions actually contribute to the loss."""
    return int((targets != ignore_index).sum().item())


def _to_plain(obj):
    """Recursively convert ``obj`` into plain Python containers and scalars.

    Used for checkpoint metadata and the metrics JSON so that checkpoints can be
    reloaded with ``torch.load(weights_only=True)`` and ``json.dump`` does not
    fail on tensors. Dataclasses become dicts, tensors become (nested) lists or
    Python scalars, other unknown objects fall back to ``str(obj)``.
    """
    if obj is None or isinstance(obj, bool):
        return obj
    # Cast numeric/str subclasses (e.g. numpy scalars) to the exact builtin type:
    # the weights_only unpickler only accepts builtins.
    if isinstance(obj, int):
        return int(obj)
    if isinstance(obj, float):
        return float(obj)
    if isinstance(obj, str):
        return str(obj)
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return {f.name: _to_plain(getattr(obj, f.name)) for f in dataclasses.fields(obj)}
    if isinstance(obj, dict):
        return {str(k): _to_plain(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [_to_plain(v) for v in obj]
    return str(obj)


def create_optimizer(model: nn.Module, lr: float, weight_decay: float,
                     betas: Tuple[float, float] = (0.9, 0.98),
                     eps: float = 1e-8) -> AdamW:
    """Create an AdamW optimizer that skips weight decay for norms and biases.

    Args:
        model: Module whose trainable parameters (``requires_grad``) are optimized.
        lr: Base learning rate.
        weight_decay: Decay applied to every parameter whose name does not match
            one of the no-decay substrings below.
        betas: AdamW beta coefficients.
        eps: AdamW epsilon.

    Returns:
        An ``AdamW`` with two param groups: [decayed, non-decayed].
    """
    no_decay = ("bias", "LayerNorm.weight", "layernorm.weight", "ln.weight")
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    param_groups = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    return AdamW(param_groups, lr=lr, betas=betas, eps=eps)


def create_scheduler(optimizer, warmup_steps: int, max_steps: int,
                     lr_min_factor: float = 0.1, scheduler_type: str = "cosine"):
    """Create a linear-warmup scheduler followed by a decay phase.

    Args:
        optimizer: Optimizer whose param groups are scheduled.
        warmup_steps: Steps of linear warmup from 1e-3 * lr up to lr.
        max_steps: Total number of scheduler steps (warmup included).
        lr_min_factor: Final LR as a fraction of the base LR ("cosine"/"linear").
        scheduler_type: "cosine", "linear"; any other value keeps the LR constant
            after warmup.

    Returns:
        A ``SequentialLR`` switching from warmup to the main schedule at
        ``warmup_steps``.
    """
    # Read the base LR *before* building any scheduler: constructing LinearLR
    # immediately rescales param_groups[*]["lr"] by start_factor, so reading it
    # afterwards would make eta_min 1000x too small.
    first_group = optimizer.param_groups[0]
    base_lr = first_group.get("initial_lr", first_group["lr"])
    # CosineAnnealingWarmRestarts requires a positive T_0; clamp for runs where
    # warmup covers (or exceeds) the whole budget.
    decay_steps = max(1, max_steps - warmup_steps)

    warmup = LinearLR(optimizer, start_factor=1e-3, end_factor=1.0, total_iters=warmup_steps)
    if scheduler_type == "cosine":
        # T_mult=1 with T_0 = remaining steps -> a single cosine cycle.
        main = CosineAnnealingWarmRestarts(
            optimizer, T_0=decay_steps, T_mult=1, eta_min=lr_min_factor * base_lr
        )
    elif scheduler_type == "linear":
        main = LinearLR(optimizer, start_factor=1.0, end_factor=lr_min_factor,
                        total_iters=decay_steps)
    else:
        main = LinearLR(optimizer, start_factor=1.0, end_factor=1.0, total_iters=decay_steps)
    return SequentialLR(optimizer, schedulers=[warmup, main], milestones=[warmup_steps])


def compute_perplexity(logits: torch.Tensor, targets: torch.Tensor,
                       ignore_index: int = 0) -> float:
    """Return the perplexity of ``logits`` w.r.t. ``targets``.

    Positions whose target equals ``ignore_index`` are excluded from the mean
    cross-entropy. Returns ``inf`` instead of raising when exp() overflows.
    """
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=ignore_index,
        reduction="mean",
    )
    try:
        return math.exp(loss.item())
    except OverflowError:
        return float("inf")


class Trainer:
    """
    Budget-aware Q-TensorFormer trainer.

    Tracks:
    - Perplexity (primary metric)
    - Model size (parameters)
    - Latency estimates
    - Energy consumption (FLOPs proxy)
    - Quantum call statistics
    - Rank adaptation trajectories

    Subclasses change the training objective by overriding ``_compute_loss``.
    """

    def __init__(self, model: nn.Module, config: ExperimentConfig,
                 train_loader, val_loader=None, test_loader=None,
                 device: str = "cpu", output_dir: Optional[str] = None):
        """
        Args:
            model: Model to train; moved to ``device``.
            config: Experiment config (reads ``training``, ``budget``, ``model``
                and ``output_dir`` sections).
            train_loader: Iterable of ``(inputs, targets)`` batches; its length is
                used to size the LR schedule.
            val_loader: Optional loader evaluated at the end of every epoch.
            test_loader: Optional loader used by ``evaluate()``.
            device: Torch device string.
            output_dir: Where checkpoints/metrics go; defaults to ``config.output_dir``.
        """
        self.model = model
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = torch.device(device)
        self.output_dir = Path(output_dir or config.output_dir)

        self.model.to(self.device)

        # One scheduler step per optimizer step over the whole run.
        total_steps = len(train_loader) * config.training.max_epochs
        self.optimizer = create_optimizer(
            model, config.training.learning_rate, config.training.weight_decay
        )
        self.scheduler = create_scheduler(
            self.optimizer,
            warmup_steps=config.training.warmup_steps,
            max_steps=total_steps,
            lr_min_factor=config.training.lr_min_factor,
            scheduler_type=config.training.lr_scheduler,
        )

        # Budget tracking
        self.budget_tracker = BudgetTracker(config.budget)
        self.energy_estimator = EnergyEstimator()

        # Logging
        self.metrics_history: List[Dict] = []
        self.grad_norms: List[float] = []  # pre-clip norm of every step, whole run

    # ------------------------------------------------------------------ losses

    def _compute_loss(self, inputs: torch.Tensor,
                      targets: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward one batch and return ``(loss_to_backprop, task_loss)``.

        ``task_loss`` is the mean cross-entropy over non-pad targets and is what
        gets reported as train loss/perplexity. The base trainer optimizes it
        directly, so both entries are the same tensor.
        """
        logits, _stats = self.model(inputs, return_stats=True)
        task_loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=_PAD_ID,
        )
        return task_loss, task_loss

    @staticmethod
    def _summed_loss(logits: torch.Tensor, targets: torch.Tensor) -> Tuple[float, int]:
        """Return (summed cross-entropy over non-pad targets, number of non-pad targets)."""
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=_PAD_ID,
            reduction="sum",
        )
        return loss.item(), _count_target_tokens(targets)

    # ---------------------------------------------------------------- training

    def train_epoch(self, epoch: int) -> Dict:
        """Train for one epoch and return a metrics dict.

        Steps whose gradient norm is NaN/Inf are skipped (no optimizer or
        scheduler step). Train loss is averaged over non-pad target tokens of
        the applied steps. Runs validation when ``val_loader`` is set and
        appends the metrics to ``metrics_history``.
        """
        self.model.train()
        # Model-specific per-epoch reset hook; not every model provides it.
        if hasattr(self.model, "reset_schedulers"):
            self.model.reset_schedulers()

        total_loss = 0.0
        total_tokens = 0
        epoch_grad_norms: List[float] = []
        start_time = time.time()

        for step, (inputs, targets) in enumerate(self.train_loader):
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.optimizer.zero_grad()

            loss, task_loss = self._compute_loss(inputs, targets)
            loss.backward()

            # Gradient monitoring: clip_grad_norm_ returns the pre-clip total norm.
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.config.training.max_grad_norm
            ).item()
            self.grad_norms.append(grad_norm)
            epoch_grad_norms.append(grad_norm)

            if not math.isfinite(grad_norm):
                print(f"[WARN] NaN/Inf gradient at step {step}. Skipping update.")
                self.optimizer.zero_grad()
                continue

            self.optimizer.step()
            self.scheduler.step()

            # task_loss is a mean over non-pad targets, so weight it by that count.
            # An all-pad batch yields a NaN mean and contributes nothing.
            n_tokens = _count_target_tokens(targets)
            if n_tokens:
                total_loss += task_loss.item() * n_tokens
                total_tokens += n_tokens

        elapsed = time.time() - start_time
        avg_loss = total_loss / max(total_tokens, 1)

        # Mean over this epoch's finite norms only; NaN if none were finite.
        finite_norms = [g for g in epoch_grad_norms if math.isfinite(g)]
        grad_norm_mean = sum(finite_norms) / len(finite_norms) if finite_norms else float("nan")

        # Budget metrics
        latency_est = self.budget_tracker.estimate_latency(
            self.model, self.config.model.max_seq_len
        )
        energy_est = self.energy_estimator.estimate(self.model)

        metrics = {
            "epoch": epoch,
            "train_loss": avg_loss,
            "train_ppl": _capped_ppl(avg_loss),
            "lr": self.optimizer.param_groups[0]["lr"],
            "grad_norm_mean": grad_norm_mean,
            "total_params": sum(p.numel() for p in self.model.parameters()),
            "latency_ms": latency_est,
            "energy_uj": energy_est,
            "time_s": elapsed,
        }

        # Extract TT stats
        if hasattr(self.model, "stats"):
            metrics["model_stats"] = self.model.stats

        if self.val_loader is not None:
            metrics.update(self.validate())

        self.metrics_history.append(metrics)
        return metrics

    @torch.no_grad()
    def validate(self) -> Dict:
        """Return ``val_loss``/``val_ppl`` averaged over non-pad targets of ``val_loader``."""
        self.model.eval()
        total_loss, total_tokens = 0.0, 0
        for inputs, targets in self.val_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            logits = self.model(inputs)
            batch_loss, batch_tokens = self._summed_loss(logits, targets)
            total_loss += batch_loss
            total_tokens += batch_tokens

        avg_loss = total_loss / max(total_tokens, 1)
        return {
            "val_loss": avg_loss,
            "val_ppl": _capped_ppl(avg_loss),
        }

    @torch.no_grad()
    def evaluate(self) -> Dict:
        """
        Full evaluation on the test set.

        Returns loss/perplexity (averaged over non-pad targets), mean forward
        latency in ms per sample, parameter count, energy estimate and model
        stats.

        Raises:
            ValueError: if the trainer was built without a ``test_loader``.
        """
        if self.test_loader is None:
            raise ValueError("evaluate() requires a test_loader")

        self.model.eval()
        total_loss, total_tokens = 0.0, 0
        latency_samples: List[float] = []
        # CUDA kernels run asynchronously; synchronize so the timer covers the
        # actual forward pass instead of just the kernel launches.
        sync_cuda = self.device.type == "cuda"

        for inputs, targets in self.test_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            if sync_cuda:
                torch.cuda.synchronize(self.device)
            t0 = time.perf_counter()
            logits = self.model(inputs)
            if sync_cuda:
                torch.cuda.synchronize(self.device)
            elapsed_ms = (time.perf_counter() - t0) * 1000
            latency_samples.append(elapsed_ms / inputs.size(0))  # ms per sample

            batch_loss, batch_tokens = self._summed_loss(logits, targets)
            total_loss += batch_loss
            total_tokens += batch_tokens

        avg_loss = total_loss / max(total_tokens, 1)
        latency_mean = (sum(latency_samples) / len(latency_samples)
                        if latency_samples else float("nan"))
        # Prefer the model's own parameter count; fall back to counting tensors.
        total_params = getattr(self.model, "total_params", None)
        if total_params is None:
            total_params = sum(p.numel() for p in self.model.parameters())

        return {
            "test_loss": avg_loss,
            "test_ppl": _capped_ppl(avg_loss),
            "latency_ms_mean": latency_mean,
            "total_params": total_params,
            "energy_uj": self.energy_estimator.estimate(self.model),
            "model_stats": getattr(self.model, "stats", {}),
        }

    def train(self) -> Dict:
        """Run the full training loop.

        Saves a "best" checkpoint whenever validation perplexity improves, stops
        early when the budget tracker reports a violation, and always saves a
        "last" checkpoint plus the metrics JSON at the end.

        Returns:
            The final epoch's metrics, or ``{}`` if no epoch ran.
        """
        best_val_ppl = float("inf")
        max_epochs = self.config.training.max_epochs

        for epoch in range(max_epochs):
            metrics = self.train_epoch(epoch)

            print(f"Epoch {epoch+1}/{max_epochs}: "
                  f"train_ppl={metrics['train_ppl']:.2f} "
                  f"val_ppl={metrics.get('val_ppl', 'N/A')} "
                  f"lr={metrics['lr']:.2e}")

            # Epochs without validation never count as "best".
            if metrics.get("val_ppl", float("inf")) < best_val_ppl:
                best_val_ppl = metrics["val_ppl"]
                self.save_checkpoint("best")

            # Early stopping checks
            if self.budget_tracker.exceeds_budget(metrics, self.config.model):
                print("[BUDGET] Exceeded constraints. Stopping.")
                break

        self.save_checkpoint("last")
        self.save_metrics()
        return self.metrics_history[-1] if self.metrics_history else {}

    # ------------------------------------------------------------- persistence

    def save_checkpoint(self, tag: str = "checkpoint"):
        """Save model/optimizer/scheduler state plus metadata to ``<output_dir>/<tag>.pt``."""
        self.output_dir.mkdir(parents=True, exist_ok=True)
        path = self.output_dir / f"{tag}.pt"
        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            # Stored as plain containers (not the config object) so that
            # load_checkpoint() can use torch.load(weights_only=True).
            "config": _to_plain(self.config),
            "metrics": _to_plain(self.metrics_history),
        }, path)
        print(f"Checkpoint saved to {path}")

    def load_checkpoint(self, tag: str = "best") -> bool:
        """Load model/optimizer (and scheduler, if saved) state from ``<output_dir>/<tag>.pt``.

        Returns:
            True if the checkpoint was loaded, False if the file does not exist.
        """
        path = self.output_dir / f"{tag}.pt"
        if not path.exists():
            print(f"Checkpoint {path} not found")
            return False
        ckpt = torch.load(path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(ckpt["model_state_dict"])
        self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        # Older checkpoints predate scheduler state.
        if "scheduler_state_dict" in ckpt:
            self.scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        return True

    def save_metrics(self):
        """Write ``metrics_history`` to ``<output_dir>/metrics.json`` (tensors converted to lists)."""
        self.output_dir.mkdir(parents=True, exist_ok=True)
        path = self.output_dir / "metrics.json"
        with open(path, "w", encoding="utf-8") as f:
            json.dump(_to_plain(self.metrics_history), f, indent=2)
        print(f"Metrics saved to {path}")


class DistillationTrainer(Trainer):
    """
    Knowledge distillation trainer.

    Student = compressed Q-TensorFormer.
    Teacher = dense (or larger) model.

    Reuses the base training loop (NaN skipping, grad-norm tracking,
    validation, metrics history); only the per-batch loss differs.
    """

    def __init__(self, student: nn.Module, teacher: nn.Module, *args,
                 alpha: float = 0.5, temperature: float = 3.0, **kwargs):
        """
        Args:
            student: Compressed Q-TensorFormer.
            teacher: Dense baseline (frozen).
            alpha: Weight between distillation loss (α) and task loss (1-α).
            temperature: Softmax temperature.
            *args, **kwargs: Forwarded to ``Trainer.__init__`` after the student.
        """
        super().__init__(student, *args, **kwargs)
        self.teacher = teacher.to(self.device)
        self.teacher.eval()
        self.alpha = alpha
        self.temperature = temperature

        # Freeze teacher
        for p in self.teacher.parameters():
            p.requires_grad = False

    def _compute_loss(self, inputs: torch.Tensor,
                      targets: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``((1-α)·task + α·T²·KL(teacher‖student), task_loss)`` for one batch.

        The KL term is computed per token: logits are flattened to
        ``(tokens, vocab)`` and pad-target positions are dropped, so
        ``reduction="batchmean"`` averages over real tokens rather than summing
        over the sequence dimension.
        """
        logits, _stats = self.model(inputs, return_stats=True)
        vocab = logits.size(-1)
        flat_logits = logits.reshape(-1, vocab)
        flat_targets = targets.reshape(-1)

        task_loss = F.cross_entropy(flat_logits, flat_targets, ignore_index=_PAD_ID)

        with torch.no_grad():
            teacher_logits = self.teacher(inputs).reshape(-1, vocab)

        keep = flat_targets != _PAD_ID
        temp = self.temperature
        if keep.any():
            distill_loss = F.kl_div(
                F.log_softmax(flat_logits[keep] / temp, dim=-1),
                F.softmax(teacher_logits[keep] / temp, dim=-1),
                reduction="batchmean",
            ) * (temp ** 2)  # T^2 keeps gradient scale comparable across temperatures
        else:
            # No real tokens: contribute a zero that stays in the autograd graph.
            distill_loss = flat_logits.sum() * 0.0

        loss = (1 - self.alpha) * task_loss + self.alpha * distill_loss
        return loss, task_loss