| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class MLP(nn.Module):
    """SwiGLU-style feed-forward block.

    The input is projected up to ``hidden_dim`` through two parallel
    linear maps: one is passed through SiLU and used as a gate for the
    other. The gated product is then projected back down to ``dim``.
    All three projections are bias-free.
    """

    def __init__(self, dim: int, hidden_dim: int):
        """Build the three linear projections.

        Args:
            dim: Input and output feature dimensionality.
            hidden_dim: Width of the intermediate (hidden) projection.
        """
        super().__init__()
        # w1: gate projection (SiLU-activated), w2: down projection,
        # w3: value projection. Names/order kept for checkpoint compatibility.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform.

        Args:
            x: Input tensor of shape (batch_size, dim).

        Returns:
            Tensor of shape (batch_size, dim).
        """
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
|
|