import torch.nn as nn


def build_mlp_(
    hidden_size: int = 640,
    projector_dim: int = 1024,
    z_dim: int = 768,
) -> nn.Sequential:
    """Build a three-layer MLP projector with SiLU activations.

    Layer layout::

        Linear(hidden_size -> projector_dim) -> SiLU
        -> Linear(projector_dim -> projector_dim) -> SiLU
        -> Linear(projector_dim -> z_dim)

    Args:
        hidden_size: Size of the input feature dimension.
        projector_dim: Width of both hidden layers.
        z_dim: Size of the output feature dimension.

    Returns:
        A freshly initialised ``nn.Sequential``. Because ``nn.Linear`` acts on
        the last dimension, it maps tensors of shape ``(..., hidden_size)`` to
        ``(..., z_dim)``.
    """
    return nn.Sequential(
        nn.Linear(hidden_size, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, projector_dim),
        nn.SiLU(),
        nn.Linear(projector_dim, z_dim),
    )