# Q-TensorFormer — src/training.py
# Author: Premchan369
# v3.0.0: Source files (commit b9c4adf, verified)
"""
Training utilities with budget-aware scheduling, energy metrics, and sweep support.
v3 features:
- Budget-constrained training (auto-adjusts ranks to meet param/latency targets)
- Energy estimation (FLOPs-based proxy)
- Knowledge distillation support
- Gradient monitoring and NaN detection
- Checkpointing with metadata
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, LinearLR, SequentialLR
import math
import time
from typing import Optional, Dict, Tuple, List
from pathlib import Path
import json
from .config import ExperimentConfig
from .budget import BudgetTracker, EnergyEstimator
def create_optimizer(model: nn.Module, lr: float, weight_decay: float,
                     betas: Tuple[float, float] = (0.9, 0.98),
                     eps: float = 1e-8) -> AdamW:
    """Build an AdamW optimizer that exempts norm/bias parameters from decay.

    Parameters whose name contains any of the substrings below get
    weight_decay=0.0; all other trainable parameters decay normally.
    """
    skip_decay = ("bias", "LayerNorm.weight", "layernorm.weight", "ln.weight")
    decayed, undecayed = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(token in name for token in skip_decay):
            undecayed.append(param)
        else:
            decayed.append(param)
    groups = [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": undecayed, "weight_decay": 0.0},
    ]
    return AdamW(groups, lr=lr, betas=betas, eps=eps)
def create_scheduler(optimizer, warmup_steps: int, max_steps: int,
                     lr_min_factor: float = 0.1, scheduler_type: str = "cosine"):
    """Create a learning-rate scheduler with linear warmup.

    Args:
        optimizer: Optimizer whose param-group lr will be scheduled.
        warmup_steps: Steps of linear warmup from ~0 up to the base lr.
        max_steps: Total training steps (warmup + main phase).
        lr_min_factor: Minimum/final lr as a fraction of the base lr.
        scheduler_type: "cosine", "linear", or anything else for constant.

    Returns:
        SequentialLR chaining the warmup phase and the main schedule.
    """
    # Capture the base lr BEFORE constructing the warmup scheduler:
    # LinearLR's __init__ immediately scales the group lr by start_factor,
    # so reading it afterwards would anchor eta_min to 1e-3 * base_lr.
    base_lr = optimizer.param_groups[0]["lr"]
    warmup = LinearLR(optimizer, start_factor=1e-3, end_factor=1.0,
                      total_iters=warmup_steps)
    # Guard against max_steps <= warmup_steps: CosineAnnealingWarmRestarts
    # requires T_0 > 0 and LinearLR needs a positive total_iters.
    main_iters = max(1, max_steps - warmup_steps)
    if scheduler_type == "cosine":
        main = CosineAnnealingWarmRestarts(
            optimizer, T_0=main_iters,
            T_mult=1, eta_min=lr_min_factor * base_lr,
        )
    elif scheduler_type == "linear":
        main = LinearLR(optimizer, start_factor=1.0,
                        end_factor=lr_min_factor,
                        total_iters=main_iters)
    else:
        # Unknown type: hold the lr constant after warmup.
        main = LinearLR(optimizer, start_factor=1.0, end_factor=1.0,
                        total_iters=main_iters)
    return SequentialLR(optimizer, schedulers=[warmup, main],
                        milestones=[warmup_steps])
def compute_perplexity(logits: torch.Tensor, targets: torch.Tensor,
                       ignore_index: int = 0) -> float:
    """Compute perplexity of `logits` against `targets`.

    Args:
        logits: (..., vocab) unnormalized scores; flattened to 2-D internally.
        targets: (...) integer class indices; entries equal to `ignore_index`
            are excluded from the loss.
        ignore_index: Target value to ignore (pad token).

    Returns:
        exp(mean cross-entropy) with the exponent capped at 20.0 — the same
        cap used for train/val/test perplexity elsewhere in this module —
        or inf when no valid tokens remain (loss would be NaN).
    """
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=ignore_index,
        reduction="mean",
    )
    loss_val = loss.item()
    # An all-ignored batch yields NaN; report inf instead of propagating it.
    if not math.isfinite(loss_val):
        return float("inf")
    # Cap the exponent so extreme losses cannot overflow exp().
    return math.exp(min(loss_val, 20.0))
class Trainer:
    """
    Budget-aware Q-TensorFormer trainer.

    Tracks:
    - Perplexity (primary metric)
    - Model size (parameters)
    - Latency estimates
    - Energy consumption (FLOPs proxy)
    - Quantum call statistics
    - Rank adaptation trajectories
    """

    def __init__(self, model: nn.Module, config: ExperimentConfig,
                 train_loader, val_loader=None, test_loader=None,
                 device: str = "cpu", output_dir: str = None):
        """
        Args:
            model: Model to train; moved onto `device` here.
            config: Experiment configuration exposing .training/.model/.budget
                sections and .output_dir.
            train_loader: Iterable of (inputs, targets) batches.
            val_loader: Optional validation loader (enables per-epoch val).
            test_loader: Optional test loader for `evaluate()`.
            device: Torch device string (e.g. "cpu", "cuda").
            output_dir: Checkpoint/metrics directory; defaults to
                config.output_dir.
        """
        self.model = model
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = torch.device(device)
        self.output_dir = Path(output_dir or config.output_dir)
        self.model.to(self.device)

        # One scheduler step per optimizer step, across all epochs.
        total_steps = len(train_loader) * config.training.max_epochs
        self.optimizer = create_optimizer(
            model, config.training.learning_rate, config.training.weight_decay
        )
        self.scheduler = create_scheduler(
            self.optimizer,
            warmup_steps=config.training.warmup_steps,
            max_steps=total_steps,
            lr_min_factor=config.training.lr_min_factor,
            scheduler_type=config.training.lr_scheduler,
        )
        # Budget tracking
        self.budget_tracker = BudgetTracker(config.budget)
        self.energy_estimator = EnergyEstimator()
        # Logging
        self.metrics_history: List[Dict] = []
        self.grad_norms: List[float] = []  # one entry per APPLIED update

    def train_epoch(self, epoch: int) -> Dict:
        """Train for one epoch. Returns metrics dict (train + optional val)."""
        self.model.train()
        self.model.reset_schedulers()
        total_loss = 0.0
        total_tokens = 0
        updates_this_epoch = 0  # optimizer steps actually applied this epoch
        start_time = time.time()
        for step, (inputs, targets) in enumerate(self.train_loader):
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.optimizer.zero_grad()
            logits, stats = self.model(inputs, return_stats=True)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,  # pad token
            )
            loss.backward()
            # Gradient monitoring: clip_grad_norm_ returns the PRE-clip norm.
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.config.training.max_grad_norm
            )
            # NaN check BEFORE recording, so grad_norm_mean stays finite.
            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                print(f"[WARN] NaN/Inf gradient at step {step}. Skipping update.")
                self.optimizer.zero_grad()
                continue
            self.grad_norms.append(grad_norm.item())
            updates_this_epoch += 1
            self.optimizer.step()
            self.scheduler.step()
            # loss is a mean over non-pad tokens; weighting by the batch token
            # count gives an approximate token-weighted epoch average.
            total_loss += loss.item() * inputs.size(0) * inputs.size(1)
            total_tokens += inputs.size(0) * inputs.size(1)
        elapsed = time.time() - start_time
        avg_loss = total_loss / max(total_tokens, 1)
        ppl = math.exp(min(avg_loss, 20.0))  # Cap for stability
        # Budget metrics
        latency_est = self.budget_tracker.estimate_latency(
            self.model, self.config.model.max_seq_len
        )
        energy_est = self.energy_estimator.estimate(self.model)
        # Average over THIS epoch's norms only; the history spans epochs, so
        # dividing by len(self.grad_norms) would skew the mean after epoch 0.
        recent_norms = (self.grad_norms[-updates_this_epoch:]
                        if updates_this_epoch else [])
        metrics = {
            "epoch": epoch,
            "train_loss": avg_loss,
            "train_ppl": ppl,
            "lr": self.optimizer.param_groups[0]["lr"],
            "grad_norm_mean": sum(recent_norms) / max(len(recent_norms), 1),
            "total_params": sum(p.numel() for p in self.model.parameters()),
            "latency_ms": latency_est,
            "energy_uj": energy_est,
            "time_s": elapsed,
        }
        # Extract TT stats if the model exposes them
        if hasattr(self.model, "stats"):
            metrics["model_stats"] = self.model.stats
        # Validation
        if self.val_loader is not None:
            metrics.update(self.validate())
        self.metrics_history.append(metrics)
        return metrics

    @torch.no_grad()
    def validate(self) -> Dict:
        """Run validation; returns token-averaged loss and capped perplexity."""
        self.model.eval()
        total_loss = 0.0
        total_tokens = 0
        for inputs, targets in self.val_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            logits = self.model(inputs)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,
                reduction="sum",
            )
            total_loss += loss.item()
            total_tokens += inputs.numel()
        avg_loss = total_loss / max(total_tokens, 1)
        return {
            "val_loss": avg_loss,
            "val_ppl": math.exp(min(avg_loss, 20.0)),
        }

    @torch.no_grad()
    def evaluate(self) -> Dict:
        """
        Full evaluation on the test set.

        Returns a dict with loss/perplexity, mean per-sample latency (ms),
        parameter count, energy estimate, and model stats if exposed.
        """
        self.model.eval()
        total_loss = 0.0
        total_tokens = 0
        latency_samples = []
        for inputs, targets in self.test_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            t0 = time.time()
            logits = self.model(inputs)
            t1 = time.time()
            latency_samples.append((t1 - t0) * 1000 / inputs.size(0))  # ms per sample
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,
                reduction="sum",
            )
            total_loss += loss.item()
            total_tokens += inputs.numel()
        avg_loss = total_loss / max(total_tokens, 1)
        # Guard: an empty test loader must not divide by zero.
        latency_mean = (sum(latency_samples) / len(latency_samples)
                        if latency_samples else 0.0)
        return {
            "test_loss": avg_loss,
            "test_ppl": math.exp(min(avg_loss, 20.0)),
            "latency_ms_mean": latency_mean,
            # Fall back to counting parameters when the model lacks the attr.
            "total_params": getattr(
                self.model, "total_params",
                sum(p.numel() for p in self.model.parameters())
            ),
            "energy_uj": self.energy_estimator.estimate(self.model),
            "model_stats": getattr(self.model, "stats", {}),
        }

    def train(self) -> Dict:
        """Full training loop with best-checkpoint tracking and budget stop."""
        best_val_ppl = float("inf")
        for epoch in range(self.config.training.max_epochs):
            metrics = self.train_epoch(epoch)
            # Logging
            print(f"Epoch {epoch+1}/{self.config.training.max_epochs}: "
                  f"train_ppl={metrics['train_ppl']:.2f} "
                  f"val_ppl={metrics.get('val_ppl', 'N/A')} "
                  f"lr={metrics['lr']:.2e}")
            if metrics.get("val_ppl", float("inf")) < best_val_ppl:
                best_val_ppl = metrics["val_ppl"]
                self.save_checkpoint("best")
            # Early stopping checks
            if self.budget_tracker.exceeds_budget(metrics, self.config.model):
                print(f"[BUDGET] Exceeded constraints. Stopping.")
                break
        self.save_checkpoint("last")
        self.save_metrics()
        return self.metrics_history[-1] if self.metrics_history else {}

    def save_checkpoint(self, tag: str = "checkpoint"):
        """Save model/optimizer state plus the config and metric history."""
        self.output_dir.mkdir(parents=True, exist_ok=True)
        path = self.output_dir / f"{tag}.pt"
        torch.save({
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "config": self.config,
            "metrics": self.metrics_history,
        }, path)
        print(f"Checkpoint saved to {path}")

    def load_checkpoint(self, tag: str = "best"):
        """Load a checkpoint written by `save_checkpoint` (no-op if missing)."""
        path = self.output_dir / f"{tag}.pt"
        if not path.exists():
            print(f"Checkpoint {path} not found")
            return
        # weights_only=True would reject the pickled ExperimentConfig that
        # save_checkpoint stores; these are our own trusted checkpoints, so a
        # full (pickle) load is intended here.
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(ckpt["model_state_dict"])
        self.optimizer.load_state_dict(ckpt["optimizer_state_dict"])

    def save_metrics(self):
        """Dump the accumulated per-epoch metrics history to metrics.json."""
        self.output_dir.mkdir(parents=True, exist_ok=True)
        path = self.output_dir / "metrics.json"
        with open(path, "w") as f:
            json.dump(self.metrics_history, f, indent=2)
        print(f"Metrics saved to {path}")
class DistillationTrainer(Trainer):
    """
    Knowledge distillation trainer.

    Student = compressed Q-TensorFormer.
    Teacher = dense (or larger) model, frozen in eval mode.
    """

    def __init__(self, student: nn.Module, teacher: nn.Module, *args,
                 alpha: float = 0.5, temperature: float = 3.0, **kwargs):
        """
        Args:
            student: Compressed Q-TensorFormer.
            teacher: Dense baseline (frozen).
            alpha: Weight between distillation loss (α) and task loss (1-α).
            temperature: Softmax temperature.
        """
        super().__init__(student, *args, **kwargs)
        self.teacher = teacher.to(self.device)
        self.teacher.eval()
        self.alpha = alpha
        self.temperature = temperature
        # Freeze teacher
        for p in self.teacher.parameters():
            p.requires_grad = False

    def train_epoch(self, epoch: int) -> Dict:
        """One distillation epoch: (1-α)·CE(student) + α·T²·KL(student‖teacher)."""
        self.model.train()
        total_loss = 0.0
        total_tokens = 0
        for step, (inputs, targets) in enumerate(self.train_loader):
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.optimizer.zero_grad()
            # Student forward
            logits, stats = self.model(inputs, return_stats=True)
            # Task loss
            task_loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,
            )
            # Distillation loss (teacher is frozen; no grads needed)
            with torch.no_grad():
                teacher_logits = self.teacher(inputs)
            # T^2 compensates for the 1/T^2 gradient scaling from softened
            # logits (Hinton et al., "Distilling the Knowledge in a NN").
            distill_loss = F.kl_div(
                F.log_softmax(logits / self.temperature, dim=-1),
                F.softmax(teacher_logits / self.temperature, dim=-1),
                reduction="batchmean",
            ) * (self.temperature ** 2)
            loss = (1 - self.alpha) * task_loss + self.alpha * distill_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.config.training.max_grad_norm
            )
            self.optimizer.step()
            self.scheduler.step()
            # Report the task loss so perplexity stays comparable to Trainer's
            total_loss += task_loss.item() * inputs.numel()
            total_tokens += inputs.numel()
        avg_loss = total_loss / max(total_tokens, 1)
        ppl = math.exp(min(avg_loss, 20.0))
        metrics = {
            "epoch": epoch,
            "train_loss": avg_loss,
            "train_ppl": ppl,
            "lr": self.optimizer.param_groups[0]["lr"],
            "total_params": sum(p.numel() for p in self.model.parameters()),
        }
        # Run validation and record history so the inherited Trainer.train()
        # can select the best checkpoint and save_metrics() has data — the
        # previous override skipped both, leaving metrics.json empty and
        # the "best" checkpoint never written.
        if self.val_loader is not None:
            metrics.update(self.validate())
        self.metrics_history.append(metrics)
        return metrics