v3.0.0: Scripts
Browse files- scripts/benchmark.py +186 -0
- scripts/distill.py +107 -0
- scripts/sweep.py +167 -0
scripts/benchmark.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Comprehensive benchmark script for Q-TensorFormer v3.
|
| 3 |
+
|
| 4 |
+
Runs multi-model comparison against all baselines and produces
|
| 5 |
+
a full evaluation report with Pareto frontier analysis.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python scripts/benchmark.py --preset small --epochs 5 --output results/
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import sys
import os
import argparse
import json
from pathlib import Path

# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))

import torch

from src.config import ExperimentConfig, ModelConfig, TrainingConfig, PRESETS
from src.models import create_model
from src.baselines import StandardTransformer, DistilledTransformer, PrunedTransformer
from src.data import load_wikitext2, load_synthetic_data
from src.training import Trainer
from src.metrics import (
    evaluate_model, compare_models, compute_pareto_frontier,
    compute_efficiency_score, print_comparison_table,
    rank_trajectory_analysis,
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def parse_args():
    """Parse command-line options for a benchmark run and return the namespace."""
    p = argparse.ArgumentParser(description="Q-TensorFormer Benchmark")
    p.add_argument(
        "--preset",
        type=str,
        default="small",
        choices=["tiny", "small", "medium"],
        help="Configuration preset",
    )
    p.add_argument("--epochs", type=int, default=5, help="Training epochs")
    p.add_argument("--batch-size", type=int, default=16)
    p.add_argument("--seq-len", type=int, default=128)
    p.add_argument("--output", type=str, default="./outputs/benchmark/", help="Output directory")
    p.add_argument("--device", type=str, default="cpu", help="Device (cpu, cuda)")
    p.add_argument("--synthetic", action="store_true", help="Use synthetic data (faster)")
    p.add_argument("--seed", type=int, default=42)
    return p.parse_args()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def main():
    """Train and evaluate Q-TensorFormer plus three baselines, then report.

    Pipeline: parse args -> load preset config -> load data -> build four
    models -> train each with Trainer -> evaluate on the test loader ->
    print comparison / Pareto frontier / efficiency ranking -> save JSON.
    """
    args = parse_args()
    # NOTE(review): torch is not imported at module level in this file
    # (imports are sys/os/argparse/json/pathlib/src.*) -- confirm torch is
    # importable here, otherwise this line raises NameError.
    torch.manual_seed(args.seed)

    # Load config: preset factory, then override with CLI values.
    config = PRESETS[args.preset]()
    config.training.max_epochs = args.epochs
    config.training.batch_size = args.batch_size
    config.model.max_seq_len = args.seq_len

    print(f"Config: {config.experiment_name}")
    print(f"Model: d_model={config.model.d_model}, "
          f"n_layers={config.model.n_layers}, "
          f"tt_rank={config.model.tt_rank}")

    # Load data
    print("\nLoading data...")
    if args.synthetic:
        train_loader = load_synthetic_data(
            vocab_size=config.model.vocab_size,
            seq_len=args.seq_len,
            n_samples=2000,
            batch_size=args.batch_size,
        )
        val_loader = None
        test_loader = train_loader  # Same for synthetic
        tokenizer = None
    else:
        train_loader, val_loader, test_loader, tokenizer = load_wikitext2(
            seq_len=args.seq_len,
            batch_size=args.batch_size,
        )
        # Real data: vocab size comes from the tokenizer, not the preset.
        config.model.vocab_size = tokenizer.vocab_size

    # Create models: the proposed model, an ablation, and two dense baselines.
    print("\nCreating models...")
    models = {}

    # Q-TensorFormer (hybrid)
    models["QTensorFormer"] = create_model(config, "qtensor")
    print(f" QTensorFormer: {models['QTensorFormer'].total_params:,} params")

    # TT-Only (no quantum)
    models["TensorOnly"] = create_model(config, "tensor_only")
    print(f" TensorOnly: {models['TensorOnly'].total_params:,} params")

    # Standard transformer (dense)
    models["StandardTransformer"] = StandardTransformer(
        vocab_size=config.model.vocab_size,
        d_model=config.model.d_model,
        n_heads=config.model.n_heads,
        n_layers=config.model.n_layers,
        max_seq_len=config.model.max_seq_len,
    )
    print(f" StandardTransformer: {models['StandardTransformer'].total_params:,} params")

    # Distilled (smaller dense): half width, floored at 64.
    models["Distilled"] = DistilledTransformer(
        vocab_size=config.model.vocab_size,
        d_model=max(64, config.model.d_model // 2),
        n_heads=config.model.n_heads,
        n_layers=config.model.n_layers,
        max_seq_len=config.model.max_seq_len,
    )
    print(f" Distilled: {models['Distilled'].total_params:,} params")

    # Train all models with the same config/loaders; per-model output dirs.
    print(f"\n{'='*60}")
    print("Training models...")
    print(f"{'='*60}")

    trained_models = {}
    for name, model in models.items():
        print(f"\n--- Training {name} ---")
        trainer = Trainer(
            model, config,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            device=args.device,
            output_dir=f"{args.output}/{name}",
        )
        trainer.train()
        trained_models[name] = model

    # Evaluate
    print(f"\n{'='*60}")
    print("Evaluating models...")
    print(f"{'='*60}")

    results = {}
    for name, model in trained_models.items():
        results[name] = evaluate_model(model, test_loader, args.device)

    # Print comparison
    print_comparison_table(results)

    # Pareto frontier
    pareto = compute_pareto_frontier(results)
    print(f"\nPareto-optimal models: {pareto}")

    # Efficiency ranking
    efficiency = {name: compute_efficiency_score(r) for name, r in results.items()}
    best = max(efficiency, key=efficiency.get)
    print(f"Most efficient: {best} (score={efficiency[best]:.1f})")

    # Save results
    os.makedirs(args.output, exist_ok=True)
    with open(f"{args.output}/results.json", "w") as f:
        # Convert float32 to native float
        # (anything with an .item attribute -- tensor/np scalars -- is coerced
        # so json.dump does not fail on non-native numeric types)
        clean = {}
        for name, r in results.items():
            clean[name] = {k: (float(v) if hasattr(v, 'item') else v) for k, v in r.items()}
        json.dump({
            "config": config.experiment_name,
            "results": clean,
            "pareto": pareto,
            "efficiency": {k: float(v) for k, v in efficiency.items()},
            "best": best,
        }, f, indent=2)

    print(f"\nResults saved to {args.output}/results.json")

    # Summary
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    for name in results:
        # Keys assumed present in evaluate_model() output except latency,
        # which is optional (defaults to 0).
        ppl = results[name]["test_ppl"]
        params = results[name]["total_params"]
        lat = results[name].get("latency_ms_mean", 0)
        print(f" {name:<25} PPL={ppl:.2f} Params={params:,} Lat={lat:.1f}ms")
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# Entry point: run the full benchmark when executed as a script.
if __name__ == "__main__":
    main()
|
scripts/distill.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Knowledge distillation training script.
|
| 3 |
+
|
| 4 |
+
Trains a compressed Q-TensorFormer student using a dense teacher model.
|
| 5 |
+
Matches the student's parameter budget to ~50% of the teacher.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python scripts/distill.py --teacher_config small --student_rank 4
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
import os
|
| 13 |
+
import argparse
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from src.config import ExperimentConfig, PRESETS
|
| 20 |
+
from src.models import create_model
|
| 21 |
+
from src.baselines import StandardTransformer
|
| 22 |
+
from src.data import load_wikitext2, load_synthetic_data
|
| 23 |
+
from src.training import DistillationTrainer
|
| 24 |
+
from src.metrics import evaluate_model
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main():
    """Distill a dense teacher transformer into a compressed Q-TensorFormer.

    Builds a dense StandardTransformer teacher from a preset config, a
    Q-TensorFormer student with a reduced TT rank, trains the student with
    DistillationTrainer (KL weight --alpha, softmax --temperature), and
    reports the student's test perplexity and compression ratio.
    """
    parser = argparse.ArgumentParser(description="KD for Q-TensorFormer")
    parser.add_argument("--teacher_config", type=str, default="small")
    parser.add_argument("--student_rank", type=int, default=4)
    parser.add_argument("--alpha", type=float, default=0.5,
                        help="Distillation loss weight")
    parser.add_argument("--temperature", type=float, default=3.0)
    parser.add_argument("--epochs", type=int, default=8)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--output", type=str, default="./outputs/distill/")
    parser.add_argument("--synthetic", action="store_true")
    args = parser.parse_args()

    torch.manual_seed(42)

    # Teacher: dense baseline
    teacher_config = PRESETS[args.teacher_config]()
    print(f"Teacher config: {teacher_config.experiment_name}")

    # Load data. val_loader is bound in BOTH branches so the trainer call
    # below never depends on lazy evaluation of an unbound name (the old
    # code only worked because `x if cond else None` never evaluated
    # val_loader under --synthetic).
    if args.synthetic:
        train_loader = load_synthetic_data(batch_size=args.batch_size)
        val_loader = None
        test_loader = train_loader
    else:
        train_loader, val_loader, test_loader, tokenizer = load_wikitext2(
            batch_size=args.batch_size
        )
        teacher_config.model.vocab_size = tokenizer.vocab_size

    # Create teacher (dense)
    # NOTE(review): the teacher is freshly initialized and never trained or
    # loaded from a checkpoint before distillation, so its logits come from
    # random weights -- confirm whether a pretrained checkpoint should be
    # loaded here.
    teacher = StandardTransformer(
        vocab_size=teacher_config.model.vocab_size,
        d_model=teacher_config.model.d_model,
        n_heads=teacher_config.model.n_heads,
        n_layers=teacher_config.model.n_layers,
    )
    print(f"Teacher params: {teacher.total_params:,}")

    # Student: compressed Q-TensorFormer. Copy the teacher's model/training
    # config field-by-field so mutating the student config cannot alias the
    # teacher's.
    student_config = ExperimentConfig(
        model=type(teacher_config.model)(
            **{k: v for k, v in teacher_config.model.__dict__.items()}
        ),
        training=type(teacher_config.training)(
            **{k: v for k, v in teacher_config.training.__dict__.items()}
        ),
    )
    student_config.model.tt_rank = args.student_rank
    student_config.model.use_quantum = True
    student_config.training.max_epochs = args.epochs

    student = create_model(student_config, "qtensor")
    print(f"Student params: {student.total_params:,}")
    print(f"Compression: {teacher.total_params / student.total_params:.1f}x")

    # Train with distillation
    trainer = DistillationTrainer(
        student=student,
        teacher=teacher,
        config=student_config,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=args.device,
        output_dir=args.output,
        alpha=args.alpha,
        temperature=args.temperature,
    )
    trainer.train()

    # Evaluate
    print("\nEvaluating knowledge-distilled model...")
    results = evaluate_model(student, test_loader, args.device)
    print(f"Student PPL: {results['test_ppl']:.2f}")
    print(f"Student params: {results['total_params']:,}")
    print(f"Compression vs teacher: {teacher.total_params / results['total_params']:.1f}x")
|
| 105 |
+
|
| 106 |
+
# Entry point: run distillation when executed as a script.
if __name__ == "__main__":
    main()
|
scripts/sweep.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Hyperparameter sweep script for Q-TensorFormer v3.
|
| 3 |
+
|
| 4 |
+
Runs a grid/search over key hyperparameters and produces
|
| 5 |
+
comparative evaluation results.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python scripts/sweep.py --preset sweep --output results/
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
import os
|
| 13 |
+
import json
|
| 14 |
+
import itertools
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from src.config import ExperimentConfig, ModelConfig, TrainingConfig
|
| 21 |
+
from src.models import create_model
|
| 22 |
+
from src.baselines import StandardTransformer
|
| 23 |
+
from src.data import load_wikitext2, load_synthetic_data
|
| 24 |
+
from src.training import Trainer
|
| 25 |
+
from src.metrics import evaluate_model, print_comparison_table, compute_pareto_frontier
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def run_sweep(base_config, sweep_params, train_loader, val_loader, test_loader,
              device="cpu", output_dir="./outputs/sweep/"):
    """
    Run a hyperparameter sweep.

    Builds one config per element of the Cartesian product of the sweep
    values, trains a "qtensor" model for each combination, evaluates it on
    the test loader, writes all results to
    ``<output_dir>/sweep_results.json``, and prints a comparison summary.

    Args:
        base_config: Base ExperimentConfig.
        sweep_params: Dict of param_name → [values]. A plain key is resolved
            against config.model first, then config.training; a dotted key
            ("section.key") targets a config section explicitly.
        train_loader / val_loader / test_loader: Data loaders passed through
            to the Trainer and to evaluate_model.
        device: Device string ("cpu", "cuda").
        output_dir: Directory for per-run outputs and the results JSON.

    Returns:
        Dict mapping run name → evaluate_model() result dict.
    """
    keys = list(sweep_params.keys())
    values = list(sweep_params.values())
    os.makedirs(output_dir, exist_ok=True)

    results = {}
    configs = []

    for combo in itertools.product(*values):
        # Fresh field-by-field copy of the base config for every combination,
        # so runs cannot mutate each other's settings.
        config = ExperimentConfig(
            model=ModelConfig(**base_config.model.__dict__),
            training=TrainingConfig(**base_config.training.__dict__),
        )

        # Apply sweep params
        param_dict = dict(zip(keys, combo))
        for k, v in param_dict.items():
            if "." in k:
                # Dotted key targets an explicit section, e.g. "model.tt_rank".
                # (Removed a leftover no-op getattr chain here that also raised
                # AttributeError for attributes not yet present on the section.)
                section, key = k.split(".", 1)
                setattr(getattr(config, section), key, v)
            else:
                if hasattr(config.model, k):
                    setattr(config.model, k, v)
                elif hasattr(config.training, k):
                    setattr(config.training, k, v)
                # Unknown plain keys are silently ignored (matches prior behavior).

        name = "_".join(f"{k}={v}" for k, v in param_dict.items())
        config.experiment_name = name
        configs.append((name, config))

    print(f"Running {len(configs)} configurations...")
    for i, (name, config) in enumerate(configs):
        print(f"\n[{i+1}/{len(configs)}] {name}")

        # Create model
        model = create_model(config, "qtensor")

        # Train
        trainer = Trainer(
            model, config,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            device=device,
            output_dir=f"{output_dir}/{name}",
        )
        trainer.train()

        # Evaluate
        results[name] = evaluate_model(model, test_loader, device)

    # Save sweep results (coerce tensor/np scalars -- anything with .item --
    # to native floats so json.dump succeeds)
    with open(f"{output_dir}/sweep_results.json", "w") as f:
        clean = {}
        for name, r in results.items():
            clean[name] = {k: (float(v) if hasattr(v, "item") else v) for k, v in r.items()}
        json.dump(clean, f, indent=2)

    # Print summary
    print("\n" + "=" * 70)
    print("SWEEP RESULTS")
    print("=" * 70)
    print_comparison_table(results)

    pareto = compute_pareto_frontier(results)
    print(f"\nPareto-optimal: {pareto}")

    # Best by metric
    best_ppl = min(results.items(), key=lambda x: x[1]["test_ppl"])
    best_params = min(results.items(), key=lambda x: x[1]["total_params"])
    print(f"\nBest PPL: {best_ppl[0]} ({best_ppl[1]['test_ppl']:.2f})")
    print(f"Fewest params: {best_params[0]} ({best_params[1]['total_params']:,})")

    return results
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def main():
    """Configure and launch the default tt_rank x use_quantum sweep.

    Parses CLI options, builds a small base config, loads synthetic or
    WikiText-2 data, and delegates to run_sweep().
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--output", type=str, default="./outputs/sweep/")
    parser.add_argument("--synthetic", action="store_true")
    args = parser.parse_args()

    torch.manual_seed(42)

    # Base config
    config = ExperimentConfig(
        model=ModelConfig(d_model=128, n_layers=2, n_heads=4, tt_rank=8,
                          vocab_size=10000, max_seq_len=128),
        training=TrainingConfig(max_epochs=args.epochs, batch_size=args.batch_size),
    )

    # Load data
    if args.synthetic:
        # NOTE(review): config.model.vocab_size stays at 10000 here -- confirm
        # it matches load_synthetic_data's default vocabulary.
        train_loader = load_synthetic_data(batch_size=args.batch_size)
        val_loader = None
        test_loader = train_loader
    else:
        train_loader, val_loader, test_loader, tokenizer = load_wikitext2(
            seq_len=128, batch_size=args.batch_size
        )
        config.model.vocab_size = tokenizer.vocab_size

    # Sweep parameters. The full grid would be
    #   tt_rank [2, 4, 8, 16] x use_quantum [True, False]
    #   x quantum_sparsity [0.5, 0.7, 0.9] x rank_alpha [1.0, 2.0, 3.0]
    # = 72 combos; sparsity and alpha are pinned to keep runtime manageable.
    # (The previous dead duplicate assignment of the full grid was removed.)
    sweep = {
        "tt_rank": [2, 4, 8, 16],
        "use_quantum": [True, False],
        "quantum_sparsity": [0.7],  # Fixed for now
        "rank_alpha": [2.0],  # Fixed for now
    }
    # 4 * 2 = 8 combos

    run_sweep(config, sweep, train_loader, val_loader, test_loader,
              args.device, args.output)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
# Entry point: run the sweep when executed as a script.
if __name__ == "__main__":
    main()
|