Fix e_score_correction_bias wrong dtype
As per title.
In the checkpoint, `mlp.gate.e_score_correction_bias` is stored in float32.
Simply calling `torch.empty` without specifying the dtype risks initializing this parameter with the wrong dtype (it would default to the current torch default dtype).
- modeling_deepseek.py +1 -1
modeling_deepseek.py
CHANGED
|
@@ -430,7 +430,7 @@ class MoEGate(nn.Module):
|
|
| 430 |
torch.empty((self.n_routed_experts, self.gating_dim)))
|
| 431 |
if self.topk_method == "noaux_tc":
|
| 432 |
self.e_score_correction_bias = nn.Parameter(
|
| 433 |
-
torch.empty((self.n_routed_experts)))
|
| 434 |
self.reset_parameters()
|
| 435 |
|
| 436 |
def reset_parameters(self) -> None:
|
|
|
|
| 430 |
torch.empty((self.n_routed_experts, self.gating_dim)))
|
| 431 |
if self.topk_method == "noaux_tc":
|
| 432 |
self.e_score_correction_bias = nn.Parameter(
|
| 433 |
+
torch.empty((self.n_routed_experts), dtype=torch.float32))
|
| 434 |
self.reset_parameters()
|
| 435 |
|
| 436 |
def reset_parameters(self) -> None:
|