File size: 3,740 Bytes
d30a2f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""
Knowledge distillation training script.

Trains a compressed Q-TensorFormer student using a dense teacher model.
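The student's size is set by --student_rank (its TT rank); the intended budget is
roughly 50% of the teacher's parameters, and the achieved compression ratio is
printed at runtime.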

Usage:
    python scripts/distill.py --teacher_config small --student_rank 4
"""

import argparse
import copy
import os
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

import torch
from src.config import ExperimentConfig, PRESETS
from src.models import create_model
from src.baselines import StandardTransformer
from src.data import load_wikitext2, load_synthetic_data
from src.training import DistillationTrainer
from src.metrics import evaluate_model


def _parse_args():
    """Parse the command-line options for a distillation run."""
    parser = argparse.ArgumentParser(description="KD for Q-TensorFormer")
    # Validate the preset name up front instead of failing later with a bare KeyError.
    parser.add_argument("--teacher_config", type=str, default="small",
                        choices=sorted(PRESETS),
                        help="Teacher preset name from src.config.PRESETS")
    parser.add_argument("--teacher_checkpoint", type=str, default=None,
                        help="Path to a trained teacher state_dict. Without it "
                             "the teacher is randomly initialised and "
                             "distillation is not meaningful.")
    parser.add_argument("--student_rank", type=int, default=4)
    parser.add_argument("--alpha", type=float, default=0.5,
                        help="Distillation loss weight")
    parser.add_argument("--temperature", type=float, default=3.0)
    parser.add_argument("--epochs", type=int, default=8)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--output", type=str, default="./outputs/distill/")
    parser.add_argument("--synthetic", action="store_true")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed passed to torch.manual_seed")
    return parser.parse_args()


def _load_data(args, teacher_config):
    """Return ``(train_loader, val_loader, test_loader)``.

    ``val_loader`` is ``None`` in synthetic mode. For WikiText-2, this also
    overwrites ``teacher_config.model.vocab_size`` with the tokenizer's size.
    """
    if args.synthetic:
        train_loader = load_synthetic_data(batch_size=args.batch_size)
        # Smoke-test mode: there is no held-out split, so "test" reuses train.
        return train_loader, None, train_loader

    train_loader, val_loader, test_loader, tokenizer = load_wikitext2(
        batch_size=args.batch_size
    )
    teacher_config.model.vocab_size = tokenizer.vocab_size
    return train_loader, val_loader, test_loader


def _build_teacher(model_config, checkpoint=None):
    """Build the dense teacher and load its weights from *checkpoint* if given."""
    teacher = StandardTransformer(
        vocab_size=model_config.vocab_size,
        d_model=model_config.d_model,
        n_heads=model_config.n_heads,
        n_layers=model_config.n_layers,
    )
    if checkpoint is None:
        print("WARNING: no --teacher_checkpoint given; distilling from a "
              "randomly initialised teacher.", file=sys.stderr)
    else:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        state = torch.load(checkpoint, map_location="cpu")
        # Accept either a raw state_dict or a training checkpoint that wraps one.
        if isinstance(state, dict) and "model_state_dict" in state:
            state = state["model_state_dict"]
        teacher.load_state_dict(state)
    # The teacher only supplies targets; disable dropout and other train-time behaviour.
    teacher.eval()
    return teacher


def _build_student_config(teacher_config, args):
    """Derive the student config from the teacher's, with compression and quantum enabled."""
    # Deep copies so that mutating the student's config never leaks into the teacher's.
    student_config = ExperimentConfig(
        model=copy.deepcopy(teacher_config.model),
        training=copy.deepcopy(teacher_config.training),
    )
    student_config.model.tt_rank = args.student_rank
    student_config.model.use_quantum = True
    student_config.training.max_epochs = args.epochs
    return student_config


def main():
    """Distill a dense teacher into a compressed Q-TensorFormer student and report results."""
    args = _parse_args()
    torch.manual_seed(args.seed)

    # Teacher: dense baseline
    teacher_config = PRESETS[args.teacher_config]()
    print(f"Teacher config: {teacher_config.experiment_name}")

    # Data loading must precede model construction: it may update vocab_size.
    train_loader, val_loader, test_loader = _load_data(args, teacher_config)

    teacher = _build_teacher(teacher_config.model, args.teacher_checkpoint)
    print(f"Teacher params: {teacher.total_params:,}")

    # Student: compressed Q-TensorFormer
    student_config = _build_student_config(teacher_config, args)
    student = create_model(student_config, "qtensor")
    print(f"Student params: {student.total_params:,}")
    print(f"Compression: {teacher.total_params / student.total_params:.1f}x")

    # Train with distillation
    trainer = DistillationTrainer(
        student=student,
        teacher=teacher,
        config=student_config,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        device=args.device,
        output_dir=args.output,
        alpha=args.alpha,
        temperature=args.temperature,
    )
    trainer.train()

    # Evaluate
    print("\nEvaluating knowledge-distilled model...")
    results = evaluate_model(student, test_loader, args.device)
    print(f"Student PPL: {results['test_ppl']:.2f}")
    print(f"Student params: {results['total_params']:,}")
    print(f"Compression vs teacher: {teacher.total_params / results['total_params']:.1f}x")


if __name__ == "__main__":
    main()