| from transformers import PreTrainedModel |
| from .custom_config import LinearConfig |
| import torch.nn as nn |
| import torch |
|
|
class BasicLinear(PreTrainedModel):
    """Single affine layer, ``y = x @ W.T + b``, exposed as a ``PreTrainedModel``.

    The layer is sized from three attributes read off the config object:
    ``in_features``, ``out_features`` and ``bias``.

    Attributes:
        weight: Learnable matrix of shape ``(out_features, in_features)``.
        bias: Learnable vector of shape ``(out_features,)``, or ``None`` when
            ``config.bias`` is falsy.
    """

    config_class = LinearConfig

    def __init__(self, config: LinearConfig) -> None:
        """Create the layer's parameters from ``config``.

        Args:
            config: Configuration providing ``in_features``, ``out_features``
                and ``bias``.
        """
        super().__init__(config)
        # Small random start: standard-normal samples scaled down by 0.01.
        self.weight = nn.Parameter(
            torch.randn(config.out_features, config.in_features) * 0.01
        )
        if config.bias:
            self.bias = nn.Parameter(torch.zeros(config.out_features))
        else:
            # Register the slot explicitly (as nn.Linear does) so `bias` is a
            # known, empty parameter rather than an ad-hoc plain attribute.
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine map to ``x``.

        Args:
            x: Input tensor whose last dimension equals ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features``.
        """
        # linear() computes x @ weight.T and adds the bias only when it is not None.
        return nn.functional.linear(x, self.weight, self.bias)
|
|