""" Knowledge distillation training script. Trains a compressed Q-TensorFormer student using a dense teacher model. Matches the student's parameter budget to ~50% of the teacher. Usage: python scripts/distill.py --teacher_config small --student_rank 4 """ import sys import os import argparse from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) import torch from src.config import ExperimentConfig, PRESETS from src.models import create_model from src.baselines import StandardTransformer from src.data import load_wikitext2, load_synthetic_data from src.training import DistillationTrainer from src.metrics import evaluate_model def main(): parser = argparse.ArgumentParser(description="KD for Q-TensorFormer") parser.add_argument("--teacher_config", type=str, default="small") parser.add_argument("--student_rank", type=int, default=4) parser.add_argument("--alpha", type=float, default=0.5, help="Distillation loss weight") parser.add_argument("--temperature", type=float, default=3.0) parser.add_argument("--epochs", type=int, default=8) parser.add_argument("--batch_size", type=int, default=16) parser.add_argument("--device", type=str, default="cpu") parser.add_argument("--output", type=str, default="./outputs/distill/") parser.add_argument("--synthetic", action="store_true") args = parser.parse_args() torch.manual_seed(42) # Teacher: dense baseline teacher_config = PRESETS[args.teacher_config]() print(f"Teacher config: {teacher_config.experiment_name}") # Load data if args.synthetic: train_loader = load_synthetic_data(batch_size=args.batch_size) test_loader = train_loader else: train_loader, val_loader, test_loader, tokenizer = load_wikitext2( batch_size=args.batch_size ) teacher_config.model.vocab_size = tokenizer.vocab_size # Create teacher (dense) teacher = StandardTransformer( vocab_size=teacher_config.model.vocab_size, d_model=teacher_config.model.d_model, n_heads=teacher_config.model.n_heads, n_layers=teacher_config.model.n_layers, ) print(f"Teacher params: {teacher.total_params:,}") # Student: compressed Q-TensorFormer student_config = ExperimentConfig( model=type(teacher_config.model)( **{k: v for k, v in teacher_config.model.__dict__.items()} ), training=type(teacher_config.training)( **{k: v for k, v in teacher_config.training.__dict__.items()} ), ) student_config.model.tt_rank = args.student_rank student_config.model.use_quantum = True student_config.training.max_epochs = args.epochs student = create_model(student_config, "qtensor") print(f"Student params: {student.total_params:,}") print(f"Compression: {teacher.total_params / student.total_params:.1f}x") # Train with distillation trainer = DistillationTrainer( student=student, teacher=teacher, config=student_config, train_loader=train_loader, val_loader=val_loader if not args.synthetic else None, test_loader=test_loader, device=args.device, output_dir=args.output, alpha=args.alpha, temperature=args.temperature, ) trainer.train() # Evaluate print("\nEvaluating knowledge-distilled model...") results = evaluate_model(student, test_loader, args.device) print(f"Student PPL: {results['test_ppl']:.2f}") print(f"Student params: {results['total_params']:,}") print(f"Compression vs teacher: {teacher.total_params / results['total_params']:.1f}x") if __name__ == "__main__": main()