| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
def build_moe_connector(num_experts, num_selected, *, mm_hidden_size=1024, hidden_size=4096):
    """Build an ``MLPMoE`` connector from ``mm_hidden_size`` to ``hidden_size`` features.

    Args:
        num_experts: Number of expert MLPs in the connector.
        num_selected: Number of experts each token is routed to (top-k).
        mm_hidden_size: Feature size of the connector's input tokens.
            Defaults to 1024, the value this function always used before.
        hidden_size: Feature size of the connector's output tokens.
            Defaults to 4096, the value this function always used before.

    Returns:
        A freshly initialised ``MLPMoE`` module.
    """
    return MLPMoE(
        num_experts=num_experts,
        num_selected=num_selected,
        mm_channels=mm_hidden_size,
        channels=hidden_size,
    )
|
|
|
|
class MLPMoE(nn.Module):
    """Mixture-of-experts MLP with top-k softmax gating.

    A bias-free linear gate scores every expert for each input token. Each
    token is routed to its ``num_selected`` highest-probability experts.
    Every expert is ``Linear(mm_channels, channels) -> GELU ->
    Linear(channels, channels)``. The selected experts' outputs are summed,
    weighted by the gate probabilities renormalised over the selected experts.

    Args:
        num_experts: Number of expert MLPs.
        num_selected: Number of experts per token (top-k). Must satisfy
            ``1 <= num_selected <= num_experts``.
        mm_channels: Feature size of the input tokens.
        channels: Feature size of the output tokens.

    Raises:
        ValueError: If ``num_selected`` is outside ``[1, num_experts]``.
    """

    def __init__(self, num_experts, num_selected, mm_channels, channels):
        super().__init__()
        # Fail at construction time rather than deep inside torch.topk
        # (k > num_experts), or silently producing all-zero outputs (k == 0).
        if not 1 <= num_selected <= num_experts:
            raise ValueError(
                f"num_selected must be in [1, num_experts={num_experts}], "
                f"got {num_selected}"
            )
        self.num_experts = num_experts
        self.num_selected = num_selected
        self.mm_channels = mm_channels
        self.channels = channels

        # Router: one logit per expert for every token. Created before the
        # experts so parameter init order and state_dict keys stay unchanged.
        self.gate = nn.Linear(mm_channels, num_experts, bias=False)
        self.experts = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(mm_channels, channels, bias=True),
                    nn.GELU(),
                    nn.Linear(channels, channels, bias=True),
                )
                for _ in range(num_experts)
            ]
        )

    def forward(self, x_img):
        """Route each token to its top-k experts and mix their outputs.

        Args:
            x_img: Tensor whose last dimension is ``mm_channels``. All leading
                dimensions (for example ``(batch, tokens)``) are treated as
                independent tokens.

        Returns:
            Tensor with the same leading dimensions as ``x_img`` and a last
            dimension of ``channels``, on ``x_img``'s device and in its dtype.
        """
        lead_shape = x_img.shape[:-1]
        # Flatten to (num_tokens, mm_channels) so each expert runs once over
        # all of its tokens, instead of once per batch element.
        tokens = x_img.reshape(-1, x_img.shape[-1])

        gate_logits = self.gate(tokens)
        # Softmax in float32 for numerical stability, then back to the input dtype.
        gate_probs = F.softmax(gate_logits, dim=-1, dtype=torch.float).to(x_img.dtype)

        weights, selected_experts = torch.topk(gate_probs, self.num_selected)
        # Renormalise so each token's selected-expert weights sum to 1.
        weights = weights / weights.sum(dim=-1, keepdim=True)

        # Allocate directly on the input's device and dtype (no CPU round trip).
        results = tokens.new_zeros((tokens.shape[0], self.channels))
        for i, expert in enumerate(self.experts):
            token_idx, nth_expert = torch.where(selected_experts == i)
            if token_idx.numel() == 0:
                continue  # no token was routed to this expert
            # topk picks distinct experts per token, so token_idx has no
            # duplicates and the indexed accumulation below is well-defined.
            expert_out = expert(tokens[token_idx])
            results[token_idx] += weights[token_idx, nth_expert, None] * expert_out

        return results.reshape(*lead_shape, self.channels)