Fix e_score_correction_bias wrong dtype

#33
by fxmarty-amd - opened
Files changed (1) hide show
  1. modeling_deepseek.py +1 -1
modeling_deepseek.py CHANGED
@@ -430,7 +430,7 @@ class MoEGate(nn.Module):
430
  torch.empty((self.n_routed_experts, self.gating_dim)))
431
  if self.topk_method == "noaux_tc":
432
  self.e_score_correction_bias = nn.Parameter(
433
- torch.empty((self.n_routed_experts)))
434
  self.reset_parameters()
435
 
436
  def reset_parameters(self) -> None:
 
430
  torch.empty((self.n_routed_experts, self.gating_dim)))
431
  if self.topk_method == "noaux_tc":
432
  self.e_score_correction_bias = nn.Parameter(
433
+ torch.empty((self.n_routed_experts), dtype=torch.float32))
434
  self.reset_parameters()
435
 
436
  def reset_parameters(self) -> None: