| import torch |
| import torch.nn as nn |
| from models.gating_network import GatingNetwork |
| from models.vision_expert import VisionExpert |
| from models.audio_expert import AudioExpert |
| from models.sensor_expert import SensorExpert |
|
|
class MoEModel(nn.Module):
    """Mixture-of-Experts model over three fixed modalities (vision, audio, sensor).

    Each modality is encoded by its dedicated expert, the encoded features are
    concatenated and fed to a gating network, and the gating weights are used to
    mix per-expert outputs before a final linear classification head.
    """

    # The expert list below is fixed: one expert per modality, in this order.
    NUM_FIXED_EXPERTS = 3

    def __init__(self, input_dim, num_experts):
        """
        Args:
            input_dim: Feature dimension fed to the gating network (must equal
                the size of the concatenated expert features).
            num_experts: Number of gating outputs. Must equal 3, since the
                expert list is fixed to [vision, audio, sensor].

        Raises:
            ValueError: If ``num_experts`` does not match the fixed expert
                count — otherwise ``forward`` would fail later with an opaque
                einsum shape-mismatch error.
        """
        super(MoEModel, self).__init__()
        if num_experts != self.NUM_FIXED_EXPERTS:
            raise ValueError(
                f"num_experts must be {self.NUM_FIXED_EXPERTS} "
                f"(one per modality: vision, audio, sensor); got {num_experts}"
            )
        self.gating_network = GatingNetwork(input_dim=input_dim, num_experts=num_experts)
        self.experts = nn.ModuleList([VisionExpert(), AudioExpert(), SensorExpert()])
        # Final head assumes each expert (and hence the mixed output) is 128-dim.
        self.fc_final = nn.Linear(128, 10)

    def forward(self, vision_input, audio_input, sensor_input):
        """Run the MoE forward pass.

        Args:
            vision_input: Batch of vision-modality inputs for VisionExpert.
            audio_input: Batch of audio-modality inputs for AudioExpert.
            sensor_input: Batch of sensor-modality inputs for SensorExpert.

        Returns:
            Logits of shape (batch, 10).
        """
        # Encode each modality with its dedicated expert; these features are
        # used only to build the gating network's input.
        vision_features = self.experts[0](vision_input)
        audio_features = self.experts[1](audio_input)
        sensor_features = self.experts[2](sensor_input)

        combined_features = torch.cat((vision_features, audio_features, sensor_features), dim=1)
        gating_weights = self.gating_network(combined_features)  # (batch, num_experts)

        # NOTE(review): each expert is applied a second time, now on the
        # concatenated features rather than its own modality input — this
        # assumes every expert also accepts the combined feature vector.
        # Verify against the expert implementations.
        expert_outputs = torch.stack(
            [expert(combined_features) for expert in self.experts], dim=1
        )  # (batch, num_experts, 128)

        # Gate-weighted sum over the expert axis: (batch, 128).
        final_output = torch.einsum('ij,ijk->ik', gating_weights, expert_outputs)

        return self.fc_final(final_output)
|
|