| """ |
| Custom Student Model for Knowledge Distillation |
| """ |
| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel, PretrainedConfig |
|
|
| class StudentModelConfig(PretrainedConfig): |
| model_type = "distilled_student" |
|
|
| def __init__( |
| self, |
| hidden_size=768, |
| num_layers=12, |
| num_attention_heads=12, |
| intermediate_size=3072, |
| vocab_size=30522, |
| max_position_embeddings=512, |
        modalities=None,
| **kwargs |
| ): |
| super().__init__(**kwargs) |
| self.hidden_size = hidden_size |
| self.num_layers = num_layers |
| self.num_attention_heads = num_attention_heads |
| self.intermediate_size = intermediate_size |
| self.vocab_size = vocab_size |
| self.max_position_embeddings = max_position_embeddings |
        # Avoid a mutable default argument shared across instances
        self.modalities = modalities if modalities is not None else ["text"]
|
|
| class StudentModel(PreTrainedModel): |
| config_class = StudentModelConfig |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| self.config = config |
| self.hidden_size = config.hidden_size |
| self.num_layers = config.num_layers |
        self.modalities = config.modalities

        # Token embeddings
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        # Learned position embeddings: config.max_position_embeddings was otherwise
        # unused, and nn.TransformerEncoderLayer adds no positional information itself
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        # Stack of standard Transformer encoder layers
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size,
                nhead=config.num_attention_heads,
                dim_feedforward=config.intermediate_size,
                batch_first=True
            ) for _ in range(config.num_layers)
        ])
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)

        # Initialize weights through the PreTrainedModel machinery
        self.post_init()
|
|
    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        if input_ids is not None:
            embeddings = self.embeddings(input_ids)
        else:
            embeddings = kwargs.get('inputs_embeds')
        if embeddings is None:
            raise ValueError("Either input_ids or inputs_embeds must be provided")

        # Add learned position embeddings (broadcast over the batch dimension)
        position_ids = torch.arange(embeddings.size(1), device=embeddings.device)
        embeddings = embeddings + self.position_embeddings(position_ids)

        # Hugging Face attention masks use 1 for real tokens and 0 for padding;
        # src_key_padding_mask expects True for positions that should be ignored
        key_padding_mask = (attention_mask == 0) if attention_mask is not None else None
        for layer in self.layers:
            embeddings = layer(embeddings, src_key_padding_mask=key_padding_mask)

        # Mean-pool over real (non-padding) tokens, then project
        if attention_mask is not None:
            mask = attention_mask.unsqueeze(-1).to(embeddings.dtype)
            mean_hidden = (embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        else:
            mean_hidden = embeddings.mean(dim=1)
        pooled = self.pooler(mean_hidden)

| return { |
| 'last_hidden_state': embeddings, |
| 'pooler_output': pooled |
| } |
|
|