Update inference/model.py
Browse files- inference/model.py +1 -2
inference/model.py
CHANGED
|
@@ -624,8 +624,7 @@ class MoE(nn.Module):
|
|
| 624 |
self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=expert_dtype, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None
|
| 625 |
for i in range(self.n_routed_experts)])
|
| 626 |
assert args.n_shared_experts == 1
|
| 627 |
-
|
| 628 |
-
self.shared_experts = Expert(args.dim, args.moe_inter_dim)
|
| 629 |
|
| 630 |
def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
|
| 631 |
shape = x.size()
|
|
|
|
| 624 |
self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=expert_dtype, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None
|
| 625 |
for i in range(self.n_routed_experts)])
|
| 626 |
assert args.n_shared_experts == 1
|
| 627 |
+
self.shared_experts = Expert(args.dim, args.moe_inter_dim, swiglu_limit=args.swiglu_limit)
|
|
|
|
| 628 |
|
| 629 |
def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
|
| 630 |
shape = x.size()
|