GeeeekExplorer commited on
Commit
2b2bebc
·
verified ·
1 Parent(s): 6c858e7

Update inference/model.py

Browse files
Files changed (1) hide show
  1. inference/model.py +1 -2
inference/model.py CHANGED
@@ -624,8 +624,7 @@ class MoE(nn.Module):
624
  self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=expert_dtype, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None
625
  for i in range(self.n_routed_experts)])
626
  assert args.n_shared_experts == 1
627
- # no swiglu_limit
628
- self.shared_experts = Expert(args.dim, args.moe_inter_dim)
629
 
630
  def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
631
  shape = x.size()
 
624
  self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim, dtype=expert_dtype, swiglu_limit=args.swiglu_limit) if self.experts_start_idx <= i < self.experts_end_idx else None
625
  for i in range(self.n_routed_experts)])
626
  assert args.n_shared_experts == 1
627
+ self.shared_experts = Expert(args.dim, args.moe_inter_dim, swiglu_limit=args.swiglu_limit)
 
628
 
629
  def forward(self, x: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
630
  shape = x.size()