mini-agent / src /models /mlp_module.py
qninhdt's picture
cc
8db8077
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLPProjection(nn.Module):
    """Two-layer MLP applied to the first position of a sequence output.

    The forward pass selects index 0 along the second dimension of its input
    (the original author's note calls this the ``[CLS]`` token) and maps it
    through ``Linear -> SiLU -> Dropout -> Linear``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        *,
        dropout: float = 0.5,
    ) -> None:
        """Build the projection layers.

        Args:
            input_dim: Size of the last dimension of the incoming features.
            hidden_dim: Width of the hidden layer.
            output_dim: Size of the projected output.
            dropout: Drop probability applied after the activation. Defaults
                to 0.5, the value that used to be hard-coded.
        """
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x_output: torch.Tensor) -> torch.Tensor:
        """Project the first position of ``x_output``.

        Args:
            x_output: Tensor indexed as ``[:, 0, :]``, so it needs at least
                three dimensions; presumably ``(batch, seq_len, input_dim)``
                -- confirm against the caller.

        Returns:
            The projected tensor whose last dimension is ``output_dim``.
        """
        # Only use the first token ([CLS]) of each output.
        first_token = x_output[:, 0, :]
        hidden = F.silu(self.linear1(first_token))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)
class MLPPrediction(nn.Module):
    """MLP head that maps a pair of feature tensors to a single value.

    The two inputs are concatenated along the feature (last) dimension,
    optionally together with their element-wise absolute difference and
    their element-wise product, and passed through a four-layer MLP whose
    final layer has one output unit.
    """

    def __init__(
        self,
        input_dim: int,
        use_abs_diff: bool = False,
        use_mult: bool = False,
        *,
        dropout: float = 0.5,
    ) -> None:
        """Build the prediction head.

        Args:
            input_dim: Feature size of each of the two inputs.
            use_abs_diff: If true, append ``|x1 - x2|`` to the features.
            use_mult: If true, append ``x1 * x2`` to the features.
            dropout: Drop probability of the two dropout layers. Defaults
                to 0.5, the value that used to be hard-coded.
        """
        super().__init__()
        self.use_abs_diff = use_abs_diff
        self.use_mult = use_mult

        # x1 and x2 are always present; each enabled interaction term adds
        # one more chunk of size input_dim.
        real_input_dim = input_dim * (2 + int(use_abs_diff) + int(use_mult))

        # Layer order (and therefore state-dict indices) is kept as-is so
        # previously saved weights still load.
        self.mlp = nn.Sequential(
            nn.Linear(real_input_dim, 512),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.SiLU(),
            nn.Linear(128, 1),
        )

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Compute the prediction for the pair ``(x1, x2)``.

        Args:
            x1: First feature tensor; its last dimension is expected to be
                ``input_dim``.
            x2: Second feature tensor, same shape as ``x1``.

        Returns:
            Tensor with the same leading dimensions as the inputs and a last
            dimension of size 1.
        """
        features = [x1, x2]
        if self.use_abs_diff:
            features.append(torch.abs(x1 - x2))
        if self.use_mult:
            features.append(x1 * x2)

        # Concatenate once, along the last axis: identical to dim=1 for the
        # usual 2-D (batch, dim) input, but also valid for unbatched 1-D
        # vectors or inputs with extra leading dimensions.
        return self.mlp(torch.cat(features, dim=-1))