File size: 5,283 Bytes
d30a2f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
"""
Hyperparameter sweep script for Q-TensorFormer v3.

Runs a grid search over key model/training hyperparameters, trains one
model per grid point, and produces comparative evaluation results.

Usage:
    python scripts/sweep.py --preset sweep --output results/
"""

import copy
import itertools
import json
import os
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

import torch
from src.config import ExperimentConfig, ModelConfig, TrainingConfig
from src.models import create_model
from src.baselines import StandardTransformer
from src.data import load_wikitext2, load_synthetic_data
from src.training import Trainer
from src.metrics import evaluate_model, print_comparison_table, compute_pareto_frontier


def _apply_sweep_param(config, key, value):
    """Set one sweep parameter on ``config`` in place.

    Dotted keys (``"training.lr"``, or deeper paths) are resolved attribute by
    attribute starting from ``config``. Bare keys are looked up on
    ``config.model`` first, then on ``config.training``.

    Raises:
        AttributeError: if a dotted key names a section or attribute that does
            not exist.
    """
    if "." in key:
        *path, attr = key.split(".")
        target = config
        for part in path:
            target = getattr(target, part)  # AttributeError on an unknown section
        if not hasattr(target, attr):
            raise AttributeError(
                f"Sweep parameter {key!r}: {type(target).__name__} "
                f"has no attribute {attr!r}"
            )
        setattr(target, attr, value)
        return

    for section in (config.model, config.training):
        if hasattr(section, key):
            setattr(section, key, value)
            return
    # Not fatal (keeps the original best-effort behavior), but make it visible:
    # otherwise the run silently uses the base value under a misleading name.
    print(f"Warning: sweep parameter {key!r} matches no model/training field; ignored")


def _build_sweep_configs(base_config, sweep_params):
    """Return ``[(name, config), ...]``, one entry per point of the parameter grid."""
    keys = list(sweep_params)
    configs = []
    for combo in itertools.product(*sweep_params.values()):
        # Deep-copy the whole base config so fields other than model/training
        # are carried over and no run shares mutable state with another.
        config = copy.deepcopy(base_config)
        param_dict = dict(zip(keys, combo))
        for k, v in param_dict.items():
            _apply_sweep_param(config, k, v)
        # An empty sweep_params produces a single base run; give it a usable name.
        name = "_".join(f"{k}={v}" for k, v in param_dict.items()) or "base"
        config.experiment_name = name
        configs.append((name, config))
    return configs


def run_sweep(base_config, sweep_params, train_loader, val_loader, test_loader,
              device="cpu", output_dir="./outputs/sweep/"):
    """
    Run a hyperparameter sweep.

    Every combination of ``sweep_params`` values (Cartesian product) is applied
    to a copy of ``base_config``. A "qtensor" model is created, trained and
    evaluated on ``test_loader`` for each. Per-run outputs go to
    ``<output_dir>/<run name>``, and aggregated metrics go to
    ``<output_dir>/sweep_results.json``.

    Args:
        base_config: Base ExperimentConfig; it is never mutated.
        sweep_params: Dict of param_name → [values]. Names are either dotted
            paths (``"training.lr"``) or bare field names, which are resolved on
            ``model`` first and then on ``training``.
        train_loader, val_loader, test_loader: Data loaders passed to Trainer.
            ``test_loader`` is also used for the final evaluation.
        device: Device string forwarded to Trainer and evaluate_model.
        output_dir: Root directory for sweep outputs; it is created if missing.

    Returns:
        Dict mapping run name → metrics dict from ``evaluate_model``. The dict
        is empty when any value list in ``sweep_params`` is empty.

    Raises:
        AttributeError: if a dotted sweep key does not exist on the config.
    """
    os.makedirs(output_dir, exist_ok=True)
    configs = _build_sweep_configs(base_config, sweep_params)
    if not configs:
        # itertools.product yields nothing when any value list is empty. There
        # is nothing to summarize (and min() below would raise).
        print("No configurations to run (a sweep value list is empty).")
        return {}

    results = {}
    print(f"Running {len(configs)} configurations...")
    for i, (name, config) in enumerate(configs, start=1):
        print(f"\n[{i}/{len(configs)}] {name}")

        model = create_model(config, "qtensor")
        trainer = Trainer(
            model, config,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            device=device,
            output_dir=os.path.join(output_dir, name),
        )
        trainer.train()

        results[name] = evaluate_model(model, test_loader, device)

    # Convert tensor/numpy scalars (anything with .item()) to plain floats
    # *before* opening the file, so a conversion error cannot leave a
    # truncated JSON behind.
    clean = {
        name: {k: (float(v) if hasattr(v, "item") else v) for k, v in r.items()}
        for name, r in results.items()
    }
    with open(os.path.join(output_dir, "sweep_results.json"), "w", encoding="utf-8") as f:
        json.dump(clean, f, indent=2)

    print("\n" + "=" * 70)
    print("SWEEP RESULTS")
    print("=" * 70)
    print_comparison_table(results)

    pareto = compute_pareto_frontier(results)
    print(f"\nPareto-optimal: {pareto}")

    best_ppl = min(results.items(), key=lambda x: x[1]["test_ppl"])
    best_params = min(results.items(), key=lambda x: x[1]["total_params"])
    print(f"\nBest PPL: {best_ppl[0]} ({best_ppl[1]['test_ppl']:.2f})")
    print(f"Fewest params: {best_params[0]} ({best_params[1]['total_params']:,})")

    return results


def main():
    """Parse CLI arguments, load data, and run the Q-TensorFormer sweep.

    ``--preset sweep`` (the default) runs the reduced 8-point grid, varying
    tt_rank and use_quantum. ``--preset full`` runs the full 72-point grid.
    ``--synthetic`` uses synthetic data instead of WikiText-2.
    """
    import argparse
    parser = argparse.ArgumentParser(description="Hyperparameter sweep for Q-TensorFormer v3.")
    parser.add_argument("--preset", choices=("sweep", "full"), default="sweep",
                        help="'sweep' = reduced 8-point grid (default), 'full' = 72-point grid")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--output", type=str, default="./outputs/sweep/")
    parser.add_argument("--synthetic", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    # Single source of truth for the sequence length, used by both the model
    # and the WikiText-2 loader.
    seq_len = 128

    config = ExperimentConfig(
        model=ModelConfig(d_model=128, n_layers=2, n_heads=4, tt_rank=8,
                          vocab_size=10000, max_seq_len=seq_len),
        training=TrainingConfig(max_epochs=args.epochs, batch_size=args.batch_size),
    )

    if args.synthetic:
        train_loader = load_synthetic_data(batch_size=args.batch_size)
        val_loader = None
        # NOTE(review): there is no held-out split here, so "test" metrics are
        # computed on the training data.
        test_loader = train_loader
    else:
        train_loader, val_loader, test_loader, tokenizer = load_wikitext2(
            seq_len=seq_len, batch_size=args.batch_size
        )
        # Size the model's vocabulary to the tokenizer actually used.
        config.model.vocab_size = tokenizer.vocab_size

    presets = {
        # Full grid: 4 * 2 * 3 * 3 = 72 combos.
        "full": {
            "tt_rank": [2, 4, 8, 16],
            "use_quantum": [True, False],
            "quantum_sparsity": [0.5, 0.7, 0.9],
            "rank_alpha": [1.0, 2.0, 3.0],
        },
        # Reduced grid for manageable runtime: 4 * 2 = 8 combos.
        "sweep": {
            "tt_rank": [2, 4, 8, 16],
            "use_quantum": [True, False],
            "quantum_sparsity": [0.7],   # Fixed for now
            "rank_alpha": [2.0],         # Fixed for now
        },
    }

    run_sweep(config, presets[args.preset], train_loader, val_loader, test_loader,
              args.device, args.output)


# Run the sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()