fix: use nn.init.* in _init_weights + recompute freqs_cis buffer on load
Two bugs caused gibberish generation when loading via from_pretrained:
1. _init_weights used module.weight.data.*_() which bypasses the HF
init guard (_is_hf_initialized), causing loaded pretrained weights
to be overwritten with random initialization. Fixed by using
nn.init.normal_/nn.init.zeros_/nn.init.ones_ instead.
2. freqs_cis (RoPE frequencies) is a non-persistent buffer that gets
filled with uninitialized memory when from_pretrained constructs
on meta device. Added explicit recomputation in _init_weights.
- modeling_deepseek_v4.py +17 -8
modeling_deepseek_v4.py
CHANGED
|
@@ -469,21 +469,22 @@ class DeepseekV4PreTrainedModel(PreTrainedModel):
|
|
| 469 |
_skip_keys_device_placement = ["past_key_values"]
|
| 470 |
|
| 471 |
def _init_weights(self, module):
|
|
|
|
|
|
|
| 472 |
std = self.config.initializer_range
|
| 473 |
if isinstance(module, nn.Linear):
|
| 474 |
-
|
| 475 |
if module.bias is not None:
|
| 476 |
-
|
| 477 |
elif isinstance(module, nn.Embedding):
|
| 478 |
-
|
| 479 |
elif isinstance(module, DeepseekV4RMSNorm):
|
| 480 |
-
|
| 481 |
elif isinstance(module, DeepseekV4Gate):
|
| 482 |
-
|
| 483 |
if module.bias is not None:
|
| 484 |
-
|
| 485 |
elif isinstance(module, DeepseekV4Block):
|
| 486 |
-
# Initialize HC parameters
|
| 487 |
nn.init.normal_(module.hc_attn_fn, std=0.01)
|
| 488 |
nn.init.normal_(module.hc_ffn_fn, std=0.01)
|
| 489 |
nn.init.zeros_(module.hc_attn_base)
|
|
@@ -523,11 +524,19 @@ class DeepseekV4Model(DeepseekV4PreTrainedModel):
|
|
| 523 |
|
| 524 |
def _init_weights(self, module):
|
| 525 |
super()._init_weights(module)
|
| 526 |
-
# HC head initialization
|
| 527 |
if module is self:
|
| 528 |
nn.init.normal_(self.hc_head_fn, std=0.01)
|
| 529 |
nn.init.zeros_(self.hc_head_base)
|
| 530 |
nn.init.ones_(self.hc_head_scale)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 531 |
|
| 532 |
def hc_head(self, x):
|
| 533 |
"""Contract hc_mult copies to 1 for final output.
|
|
|
|
| 469 |
_skip_keys_device_placement = ["past_key_values"]
|
| 470 |
|
| 471 |
def _init_weights(self, module):
|
| 472 |
+
# Use nn.init.* (not module.weight.data.*_()) so transformers'
|
| 473 |
+
# init guard skips already-loaded params via the _is_hf_initialized flag.
|
| 474 |
std = self.config.initializer_range
|
| 475 |
if isinstance(module, nn.Linear):
|
| 476 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 477 |
if module.bias is not None:
|
| 478 |
+
nn.init.zeros_(module.bias)
|
| 479 |
elif isinstance(module, nn.Embedding):
|
| 480 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 481 |
elif isinstance(module, DeepseekV4RMSNorm):
|
| 482 |
+
nn.init.ones_(module.weight)
|
| 483 |
elif isinstance(module, DeepseekV4Gate):
|
| 484 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
| 485 |
if module.bias is not None:
|
| 486 |
+
nn.init.zeros_(module.bias)
|
| 487 |
elif isinstance(module, DeepseekV4Block):
|
|
|
|
| 488 |
nn.init.normal_(module.hc_attn_fn, std=0.01)
|
| 489 |
nn.init.normal_(module.hc_ffn_fn, std=0.01)
|
| 490 |
nn.init.zeros_(module.hc_attn_base)
|
|
|
|
| 524 |
|
def _init_weights(self, module):
    """Initialize model-level parameters and non-persistent buffers.

    Delegates per-submodule init to the parent class, then handles the
    parameters owned directly by this model (the HC head) when the module
    being visited is the model itself.
    """
    super()._init_weights(module)
    if module is self:
        # HC head parameters: small-std normal mixing function, zero base,
        # unit scale.
        nn.init.normal_(self.hc_head_fn, std=0.01)
        nn.init.zeros_(self.hc_head_base)
        nn.init.ones_(self.hc_head_scale)
        # Recompute the non-persistent freqs_cis buffer. from_pretrained
        # constructs the model on the meta device and refills non-persistent
        # buffers with uninitialized memory; without this recomputation the
        # rotary positional encodings would be garbage.
        freqs = precompute_freqs_cis(
            self.config.qk_rope_head_dim,
            self.config.max_position_embeddings,
            self.config.rope_theta,
        ).to(self.freqs_cis.device)
        self.freqs_cis.data.copy_(freqs)
| 540 |
|
| 541 |
def hc_head(self, x):
|
| 542 |
"""Contract hc_mult copies to 1 for final output.
|