# Q-TensorFormer — scripts/sweep.py
# (v3.0.0: Scripts; commit d30a2f9, author Premchan369)
"""
Hyperparameter sweep script for Q-TensorFormer v3.
Runs a grid/search over key hyperparameters and produces
comparative evaluation results.
Usage:
python scripts/sweep.py --preset sweep --output results/
"""
import sys
import os
import json
import itertools
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
from src.config import ExperimentConfig, ModelConfig, TrainingConfig
from src.models import create_model
from src.baselines import StandardTransformer
from src.data import load_wikitext2, load_synthetic_data
from src.training import Trainer
from src.metrics import evaluate_model, print_comparison_table, compute_pareto_frontier
def run_sweep(base_config, sweep_params, train_loader, val_loader, test_loader,
              device="cpu", output_dir="./outputs/sweep/"):
    """
    Run a hyperparameter sweep over the Cartesian product of parameter values.

    Args:
        base_config: Base ExperimentConfig; copied for each combination so
            runs do not share mutable state.
        sweep_params: Dict of param_name -> list of values. A name may be
            dotted (e.g. "model.tt_rank") to target a config section
            explicitly; otherwise it is matched against config.model first,
            then config.training.
        train_loader / val_loader / test_loader: Data loaders handed to Trainer.
        device: Torch device string (default "cpu").
        output_dir: Directory receiving per-run outputs and sweep_results.json.

    Returns:
        Dict mapping run name -> metrics dict from evaluate_model().
    """
    keys = list(sweep_params.keys())
    values = list(sweep_params.values())
    os.makedirs(output_dir, exist_ok=True)

    results = {}
    configs = []
    for combo in itertools.product(*values):
        # Fresh shallow copy of the base config per combination.
        config = ExperimentConfig(
            model=ModelConfig(**base_config.model.__dict__),
            training=TrainingConfig(**base_config.training.__dict__),
        )
        # Apply sweep params
        param_dict = dict(zip(keys, combo))
        for k, v in param_dict.items():
            if "." in k:
                # Dotted name targets an explicit section; maxsplit=1 keeps
                # any further dots as part of the attribute name.
                section, key = k.split(".", 1)
                setattr(getattr(config, section), key, v)
            elif hasattr(config.model, k):
                setattr(config.model, k, v)
            elif hasattr(config.training, k):
                setattr(config.training, k, v)
            else:
                # Previously unknown names were silently dropped; surface them.
                print(f"Warning: sweep param '{k}' not found on model or training config")
        name = "_".join(f"{k}={v}" for k, v in param_dict.items())
        config.experiment_name = name
        configs.append((name, config))

    print(f"Running {len(configs)} configurations...")
    for i, (name, config) in enumerate(configs):
        print(f"\n[{i+1}/{len(configs)}] {name}")
        # Create model
        model = create_model(config, "qtensor")
        # Train
        trainer = Trainer(
            model, config,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            device=device,
            output_dir=f"{output_dir}/{name}",
        )
        trainer.train()
        # Evaluate
        results[name] = evaluate_model(model, test_loader, device)

    # Persist results; coerce tensor/numpy scalars (anything with .item)
    # to plain floats so json.dump accepts them.
    with open(f"{output_dir}/sweep_results.json", "w") as f:
        clean = {
            name: {k: (float(v) if hasattr(v, "item") else v) for k, v in r.items()}
            for name, r in results.items()
        }
        json.dump(clean, f, indent=2)

    # Print summary
    print("\n" + "=" * 70)
    print("SWEEP RESULTS")
    print("=" * 70)
    print_comparison_table(results)

    pareto = compute_pareto_frontier(results)
    print(f"\nPareto-optimal: {pareto}")

    # Best by metric (assumes evaluate_model reports these keys — TODO confirm)
    best_ppl = min(results.items(), key=lambda x: x[1]["test_ppl"])
    best_params = min(results.items(), key=lambda x: x[1]["total_params"])
    print(f"\nBest PPL: {best_ppl[0]} ({best_ppl[1]['test_ppl']:.2f})")
    print(f"Fewest params: {best_params[0]} ({best_params[1]['total_params']:,})")
    return results
def main():
    """CLI entry point: build the base config, load data, and run the sweep."""
    import argparse
    parser = argparse.ArgumentParser(description="Q-TensorFormer v3 hyperparameter sweep")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--output", type=str, default="./outputs/sweep/")
    parser.add_argument("--synthetic", action="store_true",
                        help="Use synthetic data instead of WikiText-2")
    args = parser.parse_args()

    torch.manual_seed(42)  # fixed seed so sweep runs are comparable

    # Base config; per-run sweep values override individual fields.
    config = ExperimentConfig(
        model=ModelConfig(d_model=128, n_layers=2, n_heads=4, tt_rank=8,
                          vocab_size=10000, max_seq_len=128),
        training=TrainingConfig(max_epochs=args.epochs, batch_size=args.batch_size),
    )

    # Load data
    if args.synthetic:
        train_loader = load_synthetic_data(batch_size=args.batch_size)
        val_loader = None
        test_loader = train_loader  # no held-out split for synthetic data
    else:
        train_loader, val_loader, test_loader, tokenizer = load_wikitext2(
            seq_len=128, batch_size=args.batch_size
        )
        config.model.vocab_size = tokenizer.vocab_size

    # Sweep parameters, reduced for manageable runtime. The full grid would
    # be tt_rank(4) * use_quantum(2) * quantum_sparsity(3) * rank_alpha(3)
    # = 72 combos; sparsity and alpha are pinned to a single value for now.
    sweep = {
        "tt_rank": [2, 4, 8, 16],
        "use_quantum": [True, False],
        "quantum_sparsity": [0.7],  # Fixed for now
        "rank_alpha": [2.0],  # Fixed for now
    }
    # 4 * 2 = 8 combos
    run_sweep(config, sweep, train_loader, val_loader, test_loader,
              args.device, args.output)


if __name__ == "__main__":
    main()