import torch
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutput

from .configuration_gpjtgpt2 import GPJTGPT2Config
from .gpt import GPTModel


class GPJTGPT2Model(PreTrainedModel):
    """Bare GPT backbone exposed through the Hugging Face PreTrainedModel API."""

    config_class = GPJTGPT2Config

    def __init__(self, config):
        super().__init__(config)
        self.model = GPTModel(config.cfg)
        self.post_init()

    def forward(self, input_ids, **kwargs):
        # Call the module directly (not .forward) so registered hooks still run.
        return self.model(input_ids)


class GPJTGPT2ModelForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal-LM wrapper around the GPT backbone; GenerationMixin enables .generate()."""

    config_class = GPJTGPT2Config

    def __init__(self, config):
        super().__init__(config)
        self.model = GPTModel(config.cfg)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        logits = self.model(input_ids)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: the logit at position t predicts token t + 1.
            shifted_logits = logits[:, :-1, :]
            shifted_labels = labels[:, 1:]

            if attention_mask is not None:
                # Exclude padding positions from the loss via the -100 ignore index.
                shifted_mask = attention_mask[:, 1:]
                shifted_labels = shifted_labels.masked_fill(shifted_mask == 0, -100)

            loss = torch.nn.functional.cross_entropy(
                shifted_logits.flatten(0, 1),
                shifted_labels.flatten(),
                ignore_index=-100,
            )

        return CausalLMOutput(loss=loss, logits=logits)
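

# A minimal usage sketch, not exercised on import. Assumptions (not from the
# source): the config's model_type is "gpjtgpt2", a checkpoint lives at the
# hypothetical path "./gpjt-gpt2", and the installed transformers version is
# recent enough that GenerationMixin supplies a default
# prepare_inputs_for_generation.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

    # Register the custom classes so the Auto* factories can resolve them.
    AutoConfig.register("gpjtgpt2", GPJTGPT2Config)  # assumed model_type
    AutoModel.register(GPJTGPT2Config, GPJTGPT2Model)
    AutoModelForCausalLM.register(GPJTGPT2Config, GPJTGPT2ModelForCausalLM)

    model = AutoModelForCausalLM.from_pretrained("./gpjt-gpt2")  # hypothetical path
    prompt_ids = torch.tensor([[1, 2, 3]])  # placeholder token ids
    print(model.generate(prompt_ids, max_new_tokens=8))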