"""
Knowledge distillation training script.
Trains a compressed Q-TensorFormer student using a dense teacher model.
Matches the student's parameter budget to roughly 50% of the teacher's.
Usage:
python scripts/distill.py --teacher_config small --student_rank 4
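python scripts/distill.py --synthetic --epochs 1    (quick smoke test on synthetic data)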
"""
import sys
import os
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
from src.config import ExperimentConfig, PRESETS
from src.models import create_model
from src.baselines import StandardTransformer
from src.data import load_wikitext2, load_synthetic_data
from src.training import DistillationTrainer
from src.metrics import evaluate_model
def main():
    parser = argparse.ArgumentParser(description="KD for Q-TensorFormer")
    parser.add_argument("--teacher_config", type=str, default="small")
    parser.add_argument("--student_rank", type=int, default=4)
    parser.add_argument("--alpha", type=float, default=0.5,
                        help="Distillation loss weight")
    parser.add_argument("--temperature", type=float, default=3.0)
    parser.add_argument("--epochs", type=int, default=8)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--output", type=str, default="./outputs/distill/")
    parser.add_argument("--synthetic", action="store_true")
    args = parser.parse_args()
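
    # --alpha weights the soft (teacher-matching) term against the ordinary
    # hard-label cross-entropy, and --temperature softens both logit
    # distributions before they are compared (Hinton-style KD). The exact
    # objective is assumed to be implemented inside DistillationTrainer,
    # roughly: loss = alpha * soft_loss(temperature) + (1 - alpha) * ce_loss.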
    torch.manual_seed(42)

    # Teacher: dense baseline
    teacher_config = PRESETS[args.teacher_config]()
    print(f"Teacher config: {teacher_config.experiment_name}")
    # Load data
    if args.synthetic:
        train_loader = load_synthetic_data(batch_size=args.batch_size)
        test_loader = train_loader
    else:
        train_loader, val_loader, test_loader, tokenizer = load_wikitext2(
            batch_size=args.batch_size
        )
        teacher_config.model.vocab_size = tokenizer.vocab_size
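
    # Note: on the synthetic path the preset vocab_size is kept as-is; only the
    # WikiText-2 branch resizes the teacher config to the tokenizer vocabulary.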
    # Create teacher (dense)
    teacher = StandardTransformer(
        vocab_size=teacher_config.model.vocab_size,
        d_model=teacher_config.model.d_model,
        n_heads=teacher_config.model.n_heads,
        n_layers=teacher_config.model.n_layers,
    )
    print(f"Teacher params: {teacher.total_params:,}")
    # Student: compressed Q-TensorFormer
    student_config = ExperimentConfig(
        model=type(teacher_config.model)(
            **{k: v for k, v in teacher_config.model.__dict__.items()}
        ),
        training=type(teacher_config.training)(
            **{k: v for k, v in teacher_config.training.__dict__.items()}
        ),
    )
    student_config.model.tt_rank = args.student_rank
    student_config.model.use_quantum = True
    student_config.training.max_epochs = args.epochs
    student = create_model(student_config, "qtensor")
    print(f"Student params: {student.total_params:,}")
    print(f"Compression: {teacher.total_params / student.total_params:.1f}x")
    # Train with distillation
    trainer = DistillationTrainer(
        student=student,
        teacher=teacher,
        config=student_config,
        train_loader=train_loader,
        val_loader=val_loader if not args.synthetic else None,
        test_loader=test_loader,
        device=args.device,
        output_dir=args.output,
        alpha=args.alpha,
        temperature=args.temperature,
    )
    trainer.train()
    # Evaluate
    print("\nEvaluating knowledge-distilled model...")
    results = evaluate_model(student, test_loader, args.device)
    print(f"Student PPL: {results['test_ppl']:.2f}")
    print(f"Student params: {results['total_params']:,}")
    print(f"Compression vs teacher: {teacher.total_params / results['total_params']:.1f}x")
if __name__ == "__main__":
    main()