| """ |
| Comprehensive metrics for evaluation. |
| |
| v3 features: |
| - Perplexity (primary LM metric) |
| - Parameter counts (total, compressed, ratio) |
| - Latency benchmarks (warm-up + measured) |
| - FLOPs estimation (proxy for energy) |
| - Quantum call statistics |
| - Rank trajectory analysis |
| - Pareto frontier computation (PPL vs params) |
| """ |
|
|
| import torch |
| import time |
| import math |
| from typing import Dict, List, Optional |
| from .config import ExperimentConfig |
|
|
|
|
def evaluate_model(model, test_loader, device: str = "cpu",
                   max_batches: Optional[int] = None) -> Dict:
| """ |
| Comprehensive model evaluation. |
| |
    Metrics:
        - test_ppl, test_loss: perplexity and mean per-token loss on the test set
        - total_params, trainable_params
        - latency_ms_mean, latency_ms_p50, latency_ms_p95 (ms per sample)
        - model_stats, compression_ratio (only if the model exposes them)
| |
| Args: |
| model: nn.Module to evaluate. |
| test_loader: DataLoader with (input, target) batches. |
| device: Device string. |
| max_batches: Limit eval to N batches (None = all). |
| |
| Returns: |
| Dict with all metrics. |
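
    Example (illustrative sketch; the toy model and synthetic loader below
    are stand-ins, not part of this module):

        import torch
        from torch.utils.data import DataLoader, TensorDataset

        vocab = 50
        model = torch.nn.Sequential(torch.nn.Embedding(vocab, 16),
                                    torch.nn.Linear(16, vocab))
        data = torch.randint(1, vocab, (8, 12))
        loader = DataLoader(TensorDataset(data, data), batch_size=4)
        metrics = evaluate_model(model, loader, device="cpu")
        print(metrics["test_ppl"], metrics["latency_ms_p50"])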
| """ |
| model.eval() |
| model.to(device) |
|
|
| total_loss = 0.0 |
| total_tokens = 0 |
| latencies = [] |
|
|
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(test_loader):
            if max_batches is not None and i >= max_batches:
                break
            inputs, targets = inputs.to(device), targets.to(device)

            # Warm-up pass on the first batch so lazy initialization and
            # kernel compilation do not pollute the latency measurements.
            if i == 0:
                _ = model(inputs)
                if str(device).startswith("cuda"):
                    torch.cuda.synchronize()

            # Timed forward pass; latency is recorded per sample in ms.
            t0 = time.time()
            logits = model(inputs)
            if str(device).startswith("cuda"):
                torch.cuda.synchronize()
            elapsed = (time.time() - t0) * 1000
            latencies.append(elapsed / inputs.size(0))

            # Sum of per-token losses; padding (index 0) is ignored, so the
            # token count must exclude padded positions as well.
            loss = torch.nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,
                reduction="sum",
            )
            total_loss += loss.item()
            total_tokens += (targets != 0).sum().item()

    avg_loss = total_loss / max(total_tokens, 1)
    ppl = math.exp(min(avg_loss, 20.0))  # clamp to avoid float overflow on divergent models
|
|
| |
| latencies.sort() |
| n = len(latencies) |
|
|
| result = { |
| "test_ppl": ppl, |
| "test_loss": avg_loss, |
| "total_params": sum(p.numel() for p in model.parameters()), |
| "trainable_params": sum(p.numel() for p in model.parameters() if p.requires_grad), |
| "latency_ms_mean": sum(latencies) / n, |
| "latency_ms_p50": latencies[n // 2], |
| "latency_ms_p95": latencies[min(int(n * 0.95), n - 1)], |
| "n_samples_evaluated": n, |
| } |
|
|
| |
| if hasattr(model, "stats"): |
| result["model_stats"] = model.stats |
|
|
| if hasattr(model, "compression_ratio"): |
| result["compression_ratio"] = model.compression_ratio |
|
|
| return result |
|
|
|
|
| def compare_models(models: Dict[str, object], test_loader, |
| device: str = "cpu") -> Dict[str, Dict]: |
| """ |
| Compare multiple models on the same test set. |
| |
    Args:
        models: Dict[name → model]
        test_loader: DataLoader shared across all models.
        device: Device string forwarded to evaluate_model.
| |
| Returns: |
| Dict[name → metrics] |
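
    Example (hedged sketch; `baseline` and `compressed` are placeholder
    models defined elsewhere):

        results = compare_models({"baseline": baseline,
                                  "compressed": compressed}, test_loader)
        print_comparison_table(results)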
| """ |
| results = {} |
| for name, model in models.items(): |
| print(f"Evaluating {name}...") |
| results[name] = evaluate_model(model, test_loader, device) |
| return results |
|
|
|
|
| def compute_pareto_frontier(results: Dict[str, Dict], |
| x_key: str = "total_params", |
| y_key: str = "test_ppl", |
| minimize_y: bool = True) -> List[str]: |
| """ |
| Find Pareto-optimal models from comparison results. |
| |
    A model is Pareto-optimal if no other model is at least as good on
    both axes and strictly better on at least one (for the defaults: no
    model with fewer-or-equal parameters and lower-or-equal perplexity
    that is strictly better on either).
| |
| Args: |
| results: Dict[name → metrics] |
| x_key: Metric for x-axis (e.g., total_params) |
| y_key: Metric for y-axis (e.g., test_ppl) |
| minimize_y: True if lower y is better. |
| |
| Returns: |
| List of Pareto-optimal model names. |
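
    Example (toy numbers, for illustration only):

        results = {
            "big":   {"total_params": 4_000_000, "test_ppl": 18.0},
            "small": {"total_params": 1_000_000, "test_ppl": 25.0},
            "worse": {"total_params": 2_000_000, "test_ppl": 30.0},
        }
        compute_pareto_frontier(results)
        # → ["big", "small"]; "worse" is dominated by "small"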
| """ |
| pareto = [] |
| names = list(results.keys()) |
|
|
| for i, name_i in enumerate(names): |
| xi = results[name_i][x_key] |
| yi = results[name_i][y_key] |
| dominated = False |
|
|
| for j, name_j in enumerate(names): |
| if i == j: |
| continue |
| xj = results[name_j][x_key] |
| yj = results[name_j][y_key] |
|
|
            if minimize_y:
                # name_j dominates: at least as good on both axes,
                # strictly better on at least one.
                if xj <= xi and yj <= yi and (xj < xi or yj < yi):
                    dominated = True
                    break
            else:
                if xj <= xi and yj >= yi and (xj < xi or yj > yi):
                    dominated = True
                    break
|
|
| if not dominated: |
| pareto.append(name_i) |
|
|
| return pareto |
|
|
|
|
| def compute_efficiency_score(result: Dict) -> float: |
| """ |
| Combined efficiency score (higher is better). |
| |
    Efficiency = 1e6 / (PPL × √(params / 1e6) × latency_ms)

    Parameters are expressed in millions and the score is scaled by 1e6 so
    that better (lower-PPL, smaller, faster) models get higher values.
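
    Example (toy numbers for illustration): PPL = 20, params = 4e6 and
    latency_ms = 5 give 1e6 / (20 × √4 × 5) = 5000.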
| """ |
| ppl = max(result["test_ppl"], 1.0) |
| params = max(result["total_params"], 1) |
| latency = max(result.get("latency_ms_mean", 1.0), 0.1) |
|
|
| |
| score = 1.0 / (ppl * math.sqrt(params / 1e6) * latency) |
| return score * 1e6 |
|
|
|
|
| def rank_trajectory_analysis(metrics_history: List[Dict]) -> Dict: |
| """ |
| Analyze rank adaptation over training. |
| |
| Args: |
| metrics_history: List of per-epoch metrics from Trainer. |
| |
| Returns: |
| Dict with rank statistics. |
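
    Example (hedged sketch of the expected shape, assuming each
    `rank_history` maps a layer name to its current rank; the layer
    names here are placeholders):

        metrics_history = [
            {"model_stats": {"rank_history": {"layer0": 8, "layer1": 16}}},
            {"model_stats": {"rank_history": {"layer0": 4, "layer1": 16}}},
        ]
        rank_trajectory_analysis(metrics_history)
        # → {"final_ranks": {...}, "rank_variance": 36.0, "n_epochs_converged": 2}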
| """ |
| if not metrics_history or "model_stats" not in metrics_history[-1]: |
| return {} |
|
|
| ranks_over_time = [] |
| for epoch_data in metrics_history: |
| model_stats = epoch_data.get("model_stats", {}) |
| rank_history = model_stats.get("rank_history", {}) |
| if rank_history: |
| ranks_over_time.append(rank_history) |
|
|
| if not ranks_over_time: |
| return {} |
|
|
    final_ranks = ranks_over_time[-1]
    mean_rank = sum(final_ranks.values()) / len(final_ranks)
    return {
        "final_ranks": final_ranks,
        "rank_variance": sum(
            (r - mean_rank) ** 2 for r in final_ranks.values()
        ) / len(final_ranks),
        "n_epochs_converged": len(ranks_over_time),
    }
|
|
|
|
| def print_comparison_table(results: Dict[str, Dict]): |
| """Pretty-print comparison table.""" |
| header = f"{'Model':<20} {'PPL':>8} {'Params':>10} {'Lat(ms)':>10} {'Score':>10}" |
| print("=" * len(header)) |
| print(header) |
| print("-" * len(header)) |
|
|
    for name, r in sorted(results.items(), key=lambda x: x[1]["test_ppl"]):
        score = compute_efficiency_score(r)
        params_k = r["total_params"] / 1000
        print(f"{name:<20} {r['test_ppl']:8.2f} {params_k:9.1f}K "
              f"{r.get('latency_ms_mean', 0):10.2f} {score:10.1f}")
|
|
| print("=" * len(header)) |
|
|
| pareto = compute_pareto_frontier(results) |
| print(f"\nPareto-optimal models: {pareto}") |
|
|