# Q-TensorFormer/src/metrics.py
"""
Comprehensive metrics for evaluation.
v3 features:
- Perplexity (primary LM metric)
- Parameter counts (total, compressed, ratio)
- Latency benchmarks (warm-up + measured)
- FLOPs estimation (proxy for energy)
- Quantum call statistics
- Rank trajectory analysis
- Pareto frontier computation (PPL vs params)
"""
import torch
import time
import math
from typing import Dict, List, Optional
from .config import ExperimentConfig
def evaluate_model(model, test_loader, device: str = "cpu",
                   max_batches: Optional[int] = None) -> Dict:
"""
Comprehensive model evaluation.
    Metrics:
    - test_ppl, test_loss: perplexity / average loss per non-padding token
    - total_params, trainable_params
    - latency_ms_mean, latency_ms_p50, latency_ms_p95 (ms per sample)
    - model_stats, compression_ratio (only when the model exposes them)
Args:
model: nn.Module to evaluate.
test_loader: DataLoader with (input, target) batches.
device: Device string.
max_batches: Limit eval to N batches (None = all).
Returns:
Dict with all metrics.
"""
model.eval()
model.to(device)
total_loss = 0.0
total_tokens = 0
latencies = []
    # Evaluation needs no gradients; disabling autograd keeps memory low
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(test_loader):
            if max_batches is not None and i >= max_batches:
                break
            inputs, targets = inputs.to(device), targets.to(device)
            # Warm-up pass so lazy initialization does not skew the first timing
            if i == 0:
                _ = model(inputs)
                if str(device).startswith("cuda"):
                    torch.cuda.synchronize()
            # Timed forward
            t0 = time.time()
            logits = model(inputs)
            if str(device).startswith("cuda"):
                torch.cuda.synchronize()
            elapsed = (time.time() - t0) * 1000  # ms
            latencies.append(elapsed / inputs.size(0))
            loss = torch.nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,
                reduction="sum",
            )
            total_loss += loss.item()
            # Count only non-padding tokens so avg_loss matches the summed loss
            total_tokens += (targets != 0).sum().item()
avg_loss = total_loss / max(total_tokens, 1)
ppl = math.exp(min(avg_loss, 20.0))
# Sort latencies for percentile reporting
latencies.sort()
n = len(latencies)
result = {
"test_ppl": ppl,
"test_loss": avg_loss,
"total_params": sum(p.numel() for p in model.parameters()),
"trainable_params": sum(p.numel() for p in model.parameters() if p.requires_grad),
"latency_ms_mean": sum(latencies) / n,
"latency_ms_p50": latencies[n // 2],
"latency_ms_p95": latencies[min(int(n * 0.95), n - 1)],
"n_samples_evaluated": n,
}
# Model-specific stats
if hasattr(model, "stats"):
result["model_stats"] = model.stats
if hasattr(model, "compression_ratio"):
result["compression_ratio"] = model.compression_ratio
return result
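
# A minimal usage sketch for evaluate_model (assumes a trained `model` and a
# DataLoader yielding (input, target) LongTensor batches; the variable names
# below are illustrative, not part of this module):
#
#     metrics = evaluate_model(model, test_loader, device="cuda", max_batches=50)
#     print(f"PPL: {metrics['test_ppl']:.2f}  "
#           f"p95 latency: {metrics['latency_ms_p95']:.2f} ms/sample")
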
def compare_models(models: Dict[str, object], test_loader,
device: str = "cpu") -> Dict[str, Dict]:
"""
Compare multiple models on the same test set.
Args:
models: Dict[name → model]
test_loader: DataLoader.
Returns:
Dict[name → metrics]
"""
results = {}
for name, model in models.items():
print(f"Evaluating {name}...")
results[name] = evaluate_model(model, test_loader, device)
return results
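
# Sketch of a typical comparison run (the model names and constructors are
# hypothetical; any mapping of name -> nn.Module works):
#
#     models = {"baseline": build_baseline(), "q_tensorformer": build_qtf()}
#     results = compare_models(models, test_loader, device="cuda")
#     print_comparison_table(results)
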
def compute_pareto_frontier(results: Dict[str, Dict],
x_key: str = "total_params",
y_key: str = "test_ppl",
minimize_y: bool = True) -> List[str]:
"""
Find Pareto-optimal models from comparison results.
    A model is Pareto-optimal if no other model is at least as good on both
    axes and strictly better on at least one (e.g. equal-or-fewer parameters
    AND equal-or-better perplexity, with at least one strict improvement).
Args:
results: Dict[name → metrics]
x_key: Metric for x-axis (e.g., total_params)
y_key: Metric for y-axis (e.g., test_ppl)
minimize_y: True if lower y is better.
Returns:
List of Pareto-optimal model names.
"""
pareto = []
names = list(results.keys())
for i, name_i in enumerate(names):
xi = results[name_i][x_key]
yi = results[name_i][y_key]
dominated = False
for j, name_j in enumerate(names):
if i == j:
continue
xj = results[name_j][x_key]
yj = results[name_j][y_key]
if minimize_y:
# j dominates i: j has fewer params AND better PPL
if xj <= xi and yj <= yi and (xj < xi or yj < yi):
dominated = True
break
else:
if xj <= xi and yj >= yi and (xj < xi or yj > yi):
dominated = True
break
if not dominated:
pareto.append(name_i)
return pareto
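
# Worked toy example (values are made up purely to illustrate dominance):
#
#     toy = {
#         "small":   {"total_params": 1_000_000, "test_ppl": 40.0},
#         "medium":  {"total_params": 5_000_000, "test_ppl": 32.0},
#         "bloated": {"total_params": 9_000_000, "test_ppl": 33.0},
#     }
#     compute_pareto_frontier(toy)  # -> ["small", "medium"]
#     # "bloated" is dominated by "medium": more params AND worse PPL.
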
def compute_efficiency_score(result: Dict) -> float:
"""
Combined efficiency score (higher is better).
Efficiency = 1 / (PPL × √params × latency_ms)
Normalized so that better models get higher scores.
"""
ppl = max(result["test_ppl"], 1.0)
params = max(result["total_params"], 1)
latency = max(result.get("latency_ms_mean", 1.0), 0.1)
    # Parameters expressed in millions; smaller and faster models score higher
    score = 1.0 / (ppl * math.sqrt(params / 1e6) * latency)
    return score * 1e6  # Scale for readability
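
# Worked example with made-up numbers: PPL = 25.0, 4M parameters,
# 2.0 ms/sample mean latency:
#     score = 1e6 / (25.0 * sqrt(4.0) * 2.0) = 1e6 / 100.0 = 10000.0
#
#     compute_efficiency_score(
#         {"test_ppl": 25.0, "total_params": 4_000_000, "latency_ms_mean": 2.0}
#     )  # -> 10000.0
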
def rank_trajectory_analysis(metrics_history: List[Dict]) -> Dict:
"""
Analyze rank adaptation over training.
Args:
metrics_history: List of per-epoch metrics from Trainer.
Returns:
Dict with rank statistics.
"""
if not metrics_history or "model_stats" not in metrics_history[-1]:
return {}
ranks_over_time = []
for epoch_data in metrics_history:
model_stats = epoch_data.get("model_stats", {})
rank_history = model_stats.get("rank_history", {})
if rank_history:
ranks_over_time.append(rank_history)
if not ranks_over_time:
return {}
    final_ranks = ranks_over_time[-1]
    mean_rank = sum(final_ranks.values()) / len(final_ranks)
    return {
        "final_ranks": final_ranks,
        "rank_variance": sum(
            (r - mean_rank) ** 2 for r in final_ranks.values()
        ) / len(final_ranks),
"n_epochs_converged": len(ranks_over_time),
}
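
# Expected input shape (a sketch; the exact keys come from the Trainer's
# per-epoch logging, assumed to expose `model_stats["rank_history"]` as a
# mapping layer name -> current rank):
#
#     metrics_history = [
#         {"model_stats": {"rank_history": {"layer_0": 8, "layer_1": 8}}},
#         {"model_stats": {"rank_history": {"layer_0": 4, "layer_1": 6}}},
#     ]
#     rank_trajectory_analysis(metrics_history)
#     # -> {"final_ranks": {"layer_0": 4, "layer_1": 6},
#     #     "rank_variance": 1.0, "n_epochs_converged": 2}
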
def print_comparison_table(results: Dict[str, Dict]):
"""Pretty-print comparison table."""
header = f"{'Model':<20} {'PPL':>8} {'Params':>10} {'Lat(ms)':>10} {'Score':>10}"
print("=" * len(header))
print(header)
print("-" * len(header))
for name, r in sorted(results.items(), key=lambda x: x[1]["test_ppl"]):
score = compute_efficiency_score(r)
params_k = r["total_params"] / 1000
print(f"{name:<20} {r['test_ppl']:8.2f} {params_k:8.1f}K "
f"{r.get('latency_ms_mean', 0):8.2f} {score:10.1f}")
print("=" * len(header))
pareto = compute_pareto_frontier(results)
print(f"\nPareto-optimal models: {pareto}")