| import torch |
| from torch import nn |
| from typing import Optional |
| from dataclasses import dataclass |
| from transformers import PreTrainedModel |
| from .configuration_mlp import MLPConfig |
| from transformers.utils import ModelOutput |
| from transformers.activations import ACT2FN |
|
|
|
|
@dataclass
class MLPOutput(ModelOutput):
    """Structured model output holding an optional ``loss`` and ``logits``.

    Both fields default to ``None`` so a caller may populate only the
    tensors it actually computed.
    """

    # Loss tensor, if one was computed.
    loss: Optional[torch.FloatTensor] = None
    # Logits tensor produced by the model.
    logits: Optional[torch.FloatTensor] = None
|
|
|
|
class MLPPreTrainedModel(PreTrainedModel):
    """Shared base class binding ``MLPConfig`` to ``PreTrainedModel``.

    Subclasses inherit the config class association and the weight
    initialisation scheme defined in ``_init_weights``.
    """

    config_class = MLPConfig

    def _init_weights(self, module):
        """Initialize the weights of ``module`` in place.

        Only ``nn.Linear`` modules are touched: weights are drawn from a
        normal distribution with mean 0 and standard deviation
        ``self.config.initializer_range``, and biases (when present) are
        zeroed. Any other module type is returned from untouched.
        """
        if not isinstance(module, nn.Linear):
            return

        std = self.config.initializer_range
        module.weight.data.normal_(mean=0.0, std=std)

        # Linear layers built with bias=False have ``bias`` set to None.
        if module.bias is not None:
            module.bias.data.zero_()
|
|
|
|
class MLPModel(MLPPreTrainedModel):
    """Fully connected feed-forward network configured by an ``MLPConfig``.

    The network is a stack of ``nn.Linear`` layers mapping
    ``config.input_size`` -> each entry of ``config.hidden_size`` ->
    ``config.output_size``. The activation named by ``config.hidden_act``
    is applied after every layer except the last, whose output is returned
    unchanged as the logits.
    """

    def __init__(self, config):
        """Build the layer stack from ``config``.

        Args:
            config: An ``MLPConfig``. Reads ``input_size``, ``hidden_size``
                (iterated, so presumably a sequence of ints — confirm against
                the config class), ``output_size``, ``num_hidden_layers`` and
                ``hidden_act``.

        Raises:
            ValueError: If ``config.num_hidden_layers`` does not equal the
                number of entries in ``config.hidden_size``.
        """
        super().__init__(config)
        self.act_fn = ACT2FN[config.hidden_act]

        hidden_sizes = list(config.hidden_size)
        if config.num_hidden_layers != len(hidden_sizes):
            # Previously a mismatch either raised an opaque IndexError (too
            # many layers requested) or silently built a network whose last
            # Linear did not project to output_size (too few requested).
            raise ValueError(
                f"config.num_hidden_layers ({config.num_hidden_layers}) must "
                f"equal len(config.hidden_size) ({len(hidden_sizes)})"
            )

        sizes = [config.input_size, *hidden_sizes, config.output_size]
        # One Linear per consecutive (in_features, out_features) pair; the
        # ModuleList indices (and thus state_dict keys) match the original.
        self.linears = nn.ModuleList(
            nn.Linear(in_features, out_features)
            for in_features, out_features in zip(sizes[:-1], sizes[1:])
        )
        self.loss_fn = nn.CrossEntropyLoss()

        # Hugging Face post-construction hook (applies _init_weights, among
        # other things).
        self.post_init()

    def forward(self, inputs, labels=None):
        """Run the network and optionally compute the loss.

        Args:
            inputs: Tensor whose last dimension is ``config.input_size``.
            labels: Optional targets, passed unchanged to
                ``nn.CrossEntropyLoss`` together with the logits.

        Returns:
            MLPOutput with ``logits`` always set and ``loss`` set only when
            ``labels`` is given.
        """
        *hidden_layers, output_layer = self.linears
        hidden = inputs
        for layer in hidden_layers:
            hidden = self.act_fn(layer(hidden))
        logits = output_layer(hidden)

        loss = None if labels is None else self.loss_fn(logits, labels)
        # Use the dedicated dataclass rather than the bare ModelOutput base:
        # the base keeps ``loss=None`` as a real entry, so ``outputs[0]`` /
        # ``to_tuple()`` would yield None instead of the logits.
        return MLPOutput(loss=loss, logits=logits)