cmpatino HF Staff committed on
Commit
d10cc5f
·
verified ·
1 Parent(s): 96016bc

fix: use nn.init.* in _init_weights + recompute freqs_cis buffer on load

Browse files

Two bugs caused gibberish generation when loading via from_pretrained:

1. _init_weights used module.weight.data.*_() which bypasses the HF
init guard (_is_hf_initialized), causing loaded pretrained weights
to be overwritten with random initialization. Fixed by using
nn.init.normal_/nn.init.zeros_/nn.init.ones_ instead.

2. freqs_cis (RoPE frequencies) is a non-persistent buffer that gets
filled with uninitialized memory when from_pretrained constructs
on meta device. Added explicit recomputation in _init_weights.

Files changed (1) hide show
  1. modeling_deepseek_v4.py +17 -8
modeling_deepseek_v4.py CHANGED
@@ -469,21 +469,22 @@ class DeepseekV4PreTrainedModel(PreTrainedModel):
469
  _skip_keys_device_placement = ["past_key_values"]
470
 
471
  def _init_weights(self, module):
 
 
472
  std = self.config.initializer_range
473
  if isinstance(module, nn.Linear):
474
- module.weight.data.normal_(mean=0.0, std=std)
475
  if module.bias is not None:
476
- module.bias.data.zero_()
477
  elif isinstance(module, nn.Embedding):
478
- module.weight.data.normal_(mean=0.0, std=std)
479
  elif isinstance(module, DeepseekV4RMSNorm):
480
- module.weight.data.fill_(1.0)
481
  elif isinstance(module, DeepseekV4Gate):
482
- module.weight.data.normal_(mean=0.0, std=std)
483
  if module.bias is not None:
484
- module.bias.data.zero_()
485
  elif isinstance(module, DeepseekV4Block):
486
- # Initialize HC parameters
487
  nn.init.normal_(module.hc_attn_fn, std=0.01)
488
  nn.init.normal_(module.hc_ffn_fn, std=0.01)
489
  nn.init.zeros_(module.hc_attn_base)
@@ -523,11 +524,19 @@ class DeepseekV4Model(DeepseekV4PreTrainedModel):
523
 
524
  def _init_weights(self, module):
525
  super()._init_weights(module)
526
- # HC head initialization
527
  if module is self:
528
  nn.init.normal_(self.hc_head_fn, std=0.01)
529
  nn.init.zeros_(self.hc_head_base)
530
  nn.init.ones_(self.hc_head_scale)
 
 
 
 
 
 
 
 
 
531
 
532
  def hc_head(self, x):
533
  """Contract hc_mult copies to 1 for final output.
 
469
  _skip_keys_device_placement = ["past_key_values"]
470
 
471
  def _init_weights(self, module):
472
+ # Use nn.init.* (not module.weight.data.*_()) so transformers'
473
+ # init guard skips already-loaded params via the _is_hf_initialized flag.
474
  std = self.config.initializer_range
475
  if isinstance(module, nn.Linear):
476
+ nn.init.normal_(module.weight, mean=0.0, std=std)
477
  if module.bias is not None:
478
+ nn.init.zeros_(module.bias)
479
  elif isinstance(module, nn.Embedding):
480
+ nn.init.normal_(module.weight, mean=0.0, std=std)
481
  elif isinstance(module, DeepseekV4RMSNorm):
482
+ nn.init.ones_(module.weight)
483
  elif isinstance(module, DeepseekV4Gate):
484
+ nn.init.normal_(module.weight, mean=0.0, std=std)
485
  if module.bias is not None:
486
+ nn.init.zeros_(module.bias)
487
  elif isinstance(module, DeepseekV4Block):
 
488
  nn.init.normal_(module.hc_attn_fn, std=0.01)
489
  nn.init.normal_(module.hc_ffn_fn, std=0.01)
490
  nn.init.zeros_(module.hc_attn_base)
 
524
 
525
  def _init_weights(self, module):
526
  super()._init_weights(module)
 
527
  if module is self:
528
  nn.init.normal_(self.hc_head_fn, std=0.01)
529
  nn.init.zeros_(self.hc_head_base)
530
  nn.init.ones_(self.hc_head_scale)
531
+ # Recompute non-persistent freqs_cis buffer. from_pretrained constructs
532
+ # on meta device and refills non-persistent buffers with uninitialized
533
+ # memory; without this, rotary positional encodings would be garbage.
534
+ freqs = precompute_freqs_cis(
535
+ self.config.qk_rope_head_dim,
536
+ self.config.max_position_embeddings,
537
+ self.config.rope_theta,
538
+ ).to(self.freqs_cis.device)
539
+ self.freqs_cis.data.copy_(freqs)
540
 
541
  def hc_head(self, x):
542
  """Contract hc_mult copies to 1 for final output.