"""
Comprehensive metrics for evaluation.

v3 features:
  - Perplexity (primary LM metric)
  - Parameter counts (total, compressed, ratio)
  - Latency benchmarks (warm-up + measured)
  - FLOPs estimation (proxy for energy)
  - Quantum call statistics
  - Rank trajectory analysis
  - Pareto frontier computation (PPL vs params)
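
Illustrative usage (a sketch; assumes `models` is a dict mapping name -> nn.Module
and `test_loader` yields (input, target) batches; neither is defined in this module):

    results = compare_models(models, test_loader, device="cpu")
    print_comparison_table(results)
    pareto = compute_pareto_frontier(results)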
"""

import torch
import time
import math
from typing import Dict, List, Optional
from .config import ExperimentConfig


def evaluate_model(model, test_loader, device: str = "cpu",
                   max_batches: Optional[int] = None) -> Dict:
    """
    Comprehensive model evaluation.

    Metrics:
      - test_ppl, test_loss: Perplexity / mean token loss on the test set
      - total_params, trainable_params
      - latency_ms_mean, latency_ms_p50, latency_ms_p95 (ms per sample)
      - n_samples_evaluated
      - model_stats, compression_ratio (when exposed by the model)

    Args:
        model: nn.Module to evaluate.
        test_loader: DataLoader with (input, target) batches.
        device: Device string.
        max_batches: Limit eval to N batches (None = all).

    Returns:
        Dict with all metrics.
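
    Example (illustrative; assumes `model` and `test_loader` are already built):

        metrics = evaluate_model(model, test_loader, device="cuda", max_batches=50)
        print(metrics["test_ppl"], metrics["latency_ms_p50"])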
    """
    model.eval()
    model.to(device)

    total_loss = 0.0
    total_tokens = 0
    latencies = []

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(test_loader):
            if max_batches and i >= max_batches:
                break
            inputs, targets = inputs.to(device), targets.to(device)

            # Warm-up pass so the first timed forward is not skewed by
            # kernel compilation / cache effects
            if i == 0:
                _ = model(inputs)
                if device.startswith("cuda"):
                    torch.cuda.synchronize()

            # Timed forward
            t0 = time.time()
            logits = model(inputs)
            if device.startswith("cuda"):
                torch.cuda.synchronize()
            elapsed = (time.time() - t0) * 1000  # ms
            latencies.append(elapsed / inputs.size(0))

            loss = torch.nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,
                reduction="sum",
            )
            total_loss += loss.item()
            # Count only non-padded tokens, matching ignore_index=0 above,
            # so the perplexity denominator is consistent with the summed loss
            total_tokens += (targets != 0).sum().item()

    avg_loss = total_loss / max(total_tokens, 1)
    # Cap the exponent so perplexity stays finite for badly-diverged models
    ppl = math.exp(min(avg_loss, 20.0))

    # Sort latencies for percentile reporting
    latencies.sort()
    n = len(latencies)

    result = {
        "test_ppl": ppl,
        "test_loss": avg_loss,
        "total_params": sum(p.numel() for p in model.parameters()),
        "trainable_params": sum(p.numel() for p in model.parameters() if p.requires_grad),
        "latency_ms_mean": sum(latencies) / n,
        "latency_ms_p50": latencies[n // 2],
        "latency_ms_p95": latencies[min(int(n * 0.95), n - 1)],
        "n_samples_evaluated": n,
    }

    # Model-specific stats
    if hasattr(model, "stats"):
        result["model_stats"] = model.stats

    if hasattr(model, "compression_ratio"):
        result["compression_ratio"] = model.compression_ratio

    return result


def compare_models(models: Dict[str, object], test_loader,
                   device: str = "cpu") -> Dict[str, Dict]:
    """
    Compare multiple models on the same test set.

    Args:
        models: Dict[name → model]
        test_loader: DataLoader.
        device: Device string passed through to evaluate_model.

    Returns:
        Dict[name → metrics]
    """
    results = {}
    for name, model in models.items():
        print(f"Evaluating {name}...")
        results[name] = evaluate_model(model, test_loader, device)
    return results


def compute_pareto_frontier(results: Dict[str, Dict],
                            x_key: str = "total_params",
                            y_key: str = "test_ppl",
                            minimize_y: bool = True) -> List[str]:
    """
    Find Pareto-optimal models from comparison results.

    A model is Pareto-optimal if no other model dominates it, i.e. no other
    model is at least as good on both axes (params and PPL by default) and
    strictly better on at least one of them.

    Args:
        results: Dict[name → metrics]
        x_key: Metric for x-axis (e.g., total_params)
        y_key: Metric for y-axis (e.g., test_ppl)
        minimize_y: True if lower y is better.

    Returns:
        List of Pareto-optimal model names.
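
    Example (illustrative numbers):

        results = {
            "small": {"total_params": 400_000,   "test_ppl": 35.0},
            "large": {"total_params": 1_000_000, "test_ppl": 30.0},
            "worse": {"total_params": 1_200_000, "test_ppl": 31.0},  # dominated by "large"
        }
        compute_pareto_frontier(results)  # -> ["small", "large"]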
    """
    pareto = []
    names = list(results.keys())

    for i, name_i in enumerate(names):
        xi = results[name_i][x_key]
        yi = results[name_i][y_key]
        dominated = False

        for j, name_j in enumerate(names):
            if i == j:
                continue
            xj = results[name_j][x_key]
            yj = results[name_j][y_key]

            if minimize_y:
                # j dominates i: j has fewer params AND better PPL
                if xj <= xi and yj <= yi and (xj < xi or yj < yi):
                    dominated = True
                    break
            else:
                if xj <= xi and yj >= yi and (xj < xi or yj > yi):
                    dominated = True
                    break

        if not dominated:
            pareto.append(name_i)

    return pareto


def compute_efficiency_score(result: Dict) -> float:
    """
    Combined efficiency score (higher is better).

    Efficiency = 1e6 / (PPL × √(params / 1e6) × latency_ms)

    Perplexity, parameter count (in millions, under a square root) and mean
    latency are multiplied, inverted, and scaled by 1e6 for readability.
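
    Worked example (illustrative): PPL = 30, params = 1,000,000, latency = 5 ms
    gives 1e6 / (30 × 1.0 × 5) ≈ 6,666.7.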
    """
    ppl = max(result["test_ppl"], 1.0)
    params = max(result["total_params"], 1)
    latency = max(result.get("latency_ms_mean", 1.0), 0.1)

    # 1 / (PPL * sqrt(params) * latency): simpler = better
    score = 1.0 / (ppl * math.sqrt(params / 1e6) * latency)
    return score * 1e6  # Scale for readability


def rank_trajectory_analysis(metrics_history: List[Dict]) -> Dict:
    """
    Analyze rank adaptation over training.

    Args:
        metrics_history: List of per-epoch metrics from Trainer.

    Returns:
        Dict with rank statistics.
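
    Expected input shape (illustrative layer names and ranks):

        metrics_history = [
            {"model_stats": {"rank_history": {"layer0": 8, "layer1": 4}}},
            {"model_stats": {"rank_history": {"layer0": 6, "layer1": 4}}},
        ]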
    """
    if not metrics_history or "model_stats" not in metrics_history[-1]:
        return {}

    ranks_over_time = []
    for epoch_data in metrics_history:
        model_stats = epoch_data.get("model_stats", {})
        rank_history = model_stats.get("rank_history", {})
        if rank_history:
            ranks_over_time.append(rank_history)

    if not ranks_over_time:
        return {}

    final_ranks = ranks_over_time[-1]
    mean_rank = sum(final_ranks.values()) / len(final_ranks)
    return {
        "final_ranks": final_ranks,
        "rank_variance": sum(
            (r - mean_rank) ** 2 for r in final_ranks.values()
        ) / len(final_ranks),
        "n_epochs_converged": len(ranks_over_time),
    }


def print_comparison_table(results: Dict[str, Dict]):
    """Pretty-print comparison table."""
    header = f"{'Model':<20} {'PPL':>8} {'Params':>10} {'Lat(ms)':>10} {'Score':>10}"
    print("=" * len(header))
    print(header)
    print("-" * len(header))

    for name, r in sorted(results.items(), key=lambda x: x[1]["test_ppl"]):
        score = compute_efficiency_score(r)
        params_k = r["total_params"] / 1000
        print(f"{name:<20} {r['test_ppl']:8.2f} {params_k:8.1f}K "
              f"{r.get('latency_ms_mean', 0):8.2f} {score:10.1f}")

    print("=" * len(header))

    pareto = compute_pareto_frontier(results)
    print(f"\nPareto-optimal models: {pareto}")