import torch
import torch.nn as nn
from torch.nn import functional as F


class MLP(nn.Module):
    """Gated MLP block (GEGLU-style): a GELU-activated gate modulates the up projection."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.dim
        self.intermediate_size = config.dim * config.mlp_scale
        # Gate and up projections expand to the intermediate size; down projects back.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        gate = self.gate_proj(x)
        gate = F.gelu(gate, approximate="tanh")  # tanh-approximate GELU on the gate branch
        up = self.up_proj(x)
        fuse = gate * up  # elementwise gating of the up projection
        outputs = self.down_proj(fuse)
        outputs = self.dropout(outputs)
        return outputs
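

# --- Usage sketch (not from the original source) ---
# A minimal forward pass to show the expected config fields and tensor shapes.
# The field names (dim, mlp_scale, bias, dropout) mirror what MLP reads from `config`;
# the SimpleConfig dataclass and the concrete values are assumptions for illustration only.
from dataclasses import dataclass


@dataclass
class SimpleConfig:
    dim: int = 64
    mlp_scale: int = 4
    bias: bool = False
    dropout: float = 0.0


if __name__ == "__main__":
    config = SimpleConfig()
    mlp = MLP(config)
    x = torch.randn(2, 8, config.dim)  # (batch, sequence, hidden)
    y = mlp(x)
    print(y.shape)  # torch.Size([2, 8, 64]) -- output keeps the input hidden size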