""" Comprehensive metrics for evaluation. v3 features: - Perplexity (primary LM metric) - Parameter counts (total, compressed, ratio) - Latency benchmarks (warm-up + measured) - FLOPs estimation (proxy for energy) - Quantum call statistics - Rank trajectory analysis - Pareto frontier computation (PPL vs params) """ import torch import time import math from typing import Dict, List, Optional from .config import ExperimentConfig def evaluate_model(model, test_loader, device: str = "cpu", max_batches: int = None) -> Dict: """ Comprehensive model evaluation. Metrics: - test_ppl: Perplexity on test set - total_params, trainable_params - latency_p50, latency_p95 (ms per sample) - peak_memory_mb - flops_estimate Args: model: nn.Module to evaluate. test_loader: DataLoader with (input, target) batches. device: Device string. max_batches: Limit eval to N batches (None = all). Returns: Dict with all metrics. """ model.eval() model.to(device) total_loss = 0.0 total_tokens = 0 latencies = [] for i, (inputs, targets) in enumerate(test_loader): if max_batches and i >= max_batches: break inputs, targets = inputs.to(device), targets.to(device) # Warm-up GPU if i == 0: _ = model(inputs) if device != "cpu": torch.cuda.synchronize() # Timed forward t0 = time.time() logits = model(inputs) if device != "cpu": torch.cuda.synchronize() elapsed = (time.time() - t0) * 1000 # ms latencies.append(elapsed / inputs.size(0)) loss = torch.nn.functional.cross_entropy( logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=0, reduction="sum", ) total_loss += loss.item() total_tokens += inputs.numel() avg_loss = total_loss / max(total_tokens, 1) ppl = math.exp(min(avg_loss, 20.0)) # Sort latencies for percentile reporting latencies.sort() n = len(latencies) result = { "test_ppl": ppl, "test_loss": avg_loss, "total_params": sum(p.numel() for p in model.parameters()), "trainable_params": sum(p.numel() for p in model.parameters() if p.requires_grad), "latency_ms_mean": sum(latencies) / n, "latency_ms_p50": latencies[n // 2], "latency_ms_p95": latencies[min(int(n * 0.95), n - 1)], "n_samples_evaluated": n, } # Model-specific stats if hasattr(model, "stats"): result["model_stats"] = model.stats if hasattr(model, "compression_ratio"): result["compression_ratio"] = model.compression_ratio return result def compare_models(models: Dict[str, object], test_loader, device: str = "cpu") -> Dict[str, Dict]: """ Compare multiple models on the same test set. Args: models: Dict[name → model] test_loader: DataLoader. Returns: Dict[name → metrics] """ results = {} for name, model in models.items(): print(f"Evaluating {name}...") results[name] = evaluate_model(model, test_loader, device) return results def compute_pareto_frontier(results: Dict[str, Dict], x_key: str = "total_params", y_key: str = "test_ppl", minimize_y: bool = True) -> List[str]: """ Find Pareto-optimal models from comparison results. A model is Pareto-optimal if no other model has: - Fewer parameters AND better perplexity Args: results: Dict[name → metrics] x_key: Metric for x-axis (e.g., total_params) y_key: Metric for y-axis (e.g., test_ppl) minimize_y: True if lower y is better. Returns: List of Pareto-optimal model names. """ pareto = [] names = list(results.keys()) for i, name_i in enumerate(names): xi = results[name_i][x_key] yi = results[name_i][y_key] dominated = False for j, name_j in enumerate(names): if i == j: continue xj = results[name_j][x_key] yj = results[name_j][y_key] if minimize_y: # j dominates i: j has fewer params AND better PPL if xj <= xi and yj <= yi and (xj < xi or yj < yi): dominated = True break else: if xj <= xi and yj >= yi and (xj < xi or yj > yi): dominated = True break if not dominated: pareto.append(name_i) return pareto def compute_efficiency_score(result: Dict) -> float: """ Combined efficiency score (higher is better). Efficiency = 1 / (PPL × √params × latency_ms) Normalized so that better models get higher scores. """ ppl = max(result["test_ppl"], 1.0) params = max(result["total_params"], 1) latency = max(result.get("latency_ms_mean", 1.0), 0.1) # 1 / (PPL * sqrt(params) * latency): simpler = better score = 1.0 / (ppl * math.sqrt(params / 1e6) * latency) return score * 1e6 # Scale for readability def rank_trajectory_analysis(metrics_history: List[Dict]) -> Dict: """ Analyze rank adaptation over training. Args: metrics_history: List of per-epoch metrics from Trainer. Returns: Dict with rank statistics. """ if not metrics_history or "model_stats" not in metrics_history[-1]: return {} ranks_over_time = [] for epoch_data in metrics_history: model_stats = epoch_data.get("model_stats", {}) rank_history = model_stats.get("rank_history", {}) if rank_history: ranks_over_time.append(rank_history) if not ranks_over_time: return {} final_ranks = ranks_over_time[-1] return { "final_ranks": final_ranks, "rank_variance": sum( (r - sum(final_ranks.values()) / len(final_ranks)) ** 2 for r in final_ranks.values() ) / len(final_ranks), "n_epochs_converged": len(ranks_over_time), } def print_comparison_table(results: Dict[str, Dict]): """Pretty-print comparison table.""" header = f"{'Model':<20} {'PPL':>8} {'Params':>10} {'Lat(ms)':>10} {'Score':>10}" print("=" * len(header)) print(header) print("-" * len(header)) for name, r in sorted(results.items(), key=lambda x: x[1]["test_ppl"]): score = compute_efficiency_score(r) params_k = r["total_params"] / 1000 print(f"{name:<20} {r['test_ppl']:8.2f} {params_k:8.1f}K " f"{r.get('latency_ms_mean', 0):8.2f} {score:10.1f}") print("=" * len(header)) pareto = compute_pareto_frontier(results) print(f"\nPareto-optimal models: {pareto}")