| |
| from smol import DistillationTrainer |
| from transformers import AutoModel, AutoTokenizer |
| from transformers import DistilBERTForSequenceClassification |
| from transformers import AdamW |
| import torch |
| import torch.nn as nn |
|
|
| |
| |
# Teacher: large frozen instruct model whose soft predictions the student mimics.
teacher_model = AutoModel.from_pretrained("swiss-ai/Apertus-8B-Instruct-2509")
teacher_model.eval()  # teacher is never trained; disable dropout etc. for stable targets

# NOTE(review): this tokenizer checkpoint matches neither the teacher
# ("swiss-ai/Apertus-8B-...") nor the student ("distilbert-base-uncased"),
# so token ids / vocab will disagree between the two models — confirm intent.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-Nemo-Base-2407")

# Student: small classifier that learns from the teacher's outputs.
# NOTE(review): the class imported at the top of the file is spelled
# `DistilBERTForSequenceClassification`; the real transformers class is
# `DistilBertForSequenceClassification` — the import will fail as written.
student_model = DistilBERTForSequenceClassification.from_pretrained("distilbert-base-uncased")
|
|
| |
class DistillationLoss(nn.Module):
    """Soft-target knowledge-distillation loss (Hinton et al., 2015).

    Computes alpha * T^2 * KL(softmax(teacher/T) || softmax(student/T))
    on raw (pre-softmax) logits.

    Args:
        temperature: softening temperature T applied to both logit sets.
        alpha: weight of the distillation term in the total loss.
    """

    def __init__(self, temperature, alpha):
        super().__init__()
        # BUG FIX: nn.KLDivLoss has no `temperature` keyword — the original
        # raised TypeError at construction. Temperature is applied to the
        # logits in forward() instead; `batchmean` is the mathematically
        # correct KL reduction per the PyTorch docs.
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, student_output, teacher_output):
        """Return the weighted KL divergence between softened distributions.

        Args:
            student_output: raw student logits, shape (..., num_classes).
            teacher_output: raw teacher logits, same shape.
        """
        t = self.temperature
        soft_student = (student_output / t).log_softmax(-1)  # KLDivLoss wants log-probs
        soft_teacher = (teacher_output / t).softmax(-1)      # ...and probs for the target
        # T^2 rescaling keeps gradient magnitudes comparable across temperatures.
        return self.kl_loss(soft_student, soft_teacher) * (t * t) * self.alpha
|
|
| |
def train_step(model, batch, optimizer, loss_fn, device):
    """Run one distillation training step on a single batch.

    Args:
        model: the student model being trained.
        batch: mapping with "input_ids" (raw text for the tokenizer — TODO
            confirm; the key name suggests pre-tokenized ids) and "labels".
        optimizer: optimizer over the student's parameters.
        loss_fn: distillation loss taking (student_logits, teacher_logits),
            both raw logits — it applies softmax internally.
        device: torch device to run the forward/backward passes on.

    Returns:
        (total_loss_value, student_output, teacher_logits) where
        total_loss_value is a Python float.
    """
    # BUG FIX: original referenced an undefined `tokenizer_args` (NameError);
    # sensible defaults are supplied here.
    inputs = tokenizer(batch["input_ids"], return_tensors="pt",
                       padding=True, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    labels = batch["labels"].to(device)

    # Teacher forward pass — no gradients, teacher stays frozen.
    # BUG FIX: original called `model(**inputs)` here, i.e. the *student*,
    # so the distillation targets were the student's own outputs.
    with torch.no_grad():
        teacher_output = teacher_model(**inputs)
        teacher_logits = (teacher_output.logits
                          if hasattr(teacher_output, "logits") else teacher_output)

    # Student forward pass (gradients flow only through this branch).
    student_output = model(**inputs)
    student_logits = (student_output.logits
                      if hasattr(student_output, "logits") else student_output)

    # BUG FIX: original softmaxed the teacher logits before handing them to
    # loss_fn, which softmaxes again — pass raw logits exactly once.
    distillation_loss = loss_fn(student_logits, teacher_logits)

    # Supervised task loss on the hard labels.
    # BUG FIX: original called an undefined `loss_function` (NameError).
    task_loss = nn.functional.cross_entropy(student_logits, labels)
    total_loss = distillation_loss + task_loss

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    return total_loss.item(), student_output, teacher_logits
|
|
| |
# NOTE(review): this shadows the `DistillationTrainer` imported at the top of
# the file (`from smol import DistillationTrainer`) — confirm both names
# resolve to the same class, otherwise the earlier import is dead.
from smol.trainer import DistillationTrainer

# Configure the distillation run. `your_train_dataset` / `your_eval_dataset`
# are unfilled placeholders — they must be defined before this line executes
# or it raises NameError.
trainer = DistillationTrainer(
    student_model,
    optimizer=AdamW(student_model.parameters(), lr=1e-5),
    loss_fn=DistillationLoss(temperature=1.0, alpha=0.5),
    train_dataset=your_train_dataset,
    eval_dataset=your_eval_dataset,
    device="cuda" if torch.cuda.is_available() else "cpu",
    num_epochs=5,
    batch_size=16,
    log_dir="distillation_logs",
)

# Launch the full training loop (epochs/batching handled by the trainer).
trainer.train()
|
|
| |
| |
|
|