wangzhengtao committed on
Commit
bed55d4
·
1 Parent(s): f13e4d0

also fix weight init issue

Browse files
Files changed (2) hide show
  1. modeling_deepseek.py +1 -1
  2. modeling_kimi_k25.py +15 -1
modeling_deepseek.py CHANGED
@@ -1244,7 +1244,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
1244
 
1245
  def _init_weights(self, module):
1246
  std = self.config.initializer_range
1247
- if isinstance(module, nn.Linear):
1248
  module.weight.data.normal_(mean=0.0, std=std)
1249
  if module.bias is not None:
1250
  module.bias.data.zero_()
 
1244
 
1245
  def _init_weights(self, module):
1246
  std = self.config.initializer_range
1247
+ if isinstance(module, nn.Linear) and hasattr(module, "weight"):
1248
  module.weight.data.normal_(mean=0.0, std=std)
1249
  if module.bias is not None:
1250
  module.bias.data.zero_()
modeling_kimi_k25.py CHANGED
@@ -638,6 +638,20 @@ class MoonViT3dPretrainedModel(PreTrainedModel):
638
  _no_split_modules = ['PackingTransformer']
639
  _supports_flash_attn_2 = True
640
  _supports_sdpa = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
641
 
642
  def __init__(self, config, *inputs, **kwargs):
643
  super().__init__(config, *inputs, **kwargs)
@@ -786,7 +800,7 @@ class KimiK25PreTrainedModel(PreTrainedModel):
786
  if hasattr(module, "class_embedding"):
787
  module.class_embedding.data.normal_(mean=0.0, std=std)
788
 
789
- if isinstance(module, (nn.Linear, nn.Conv2d)):
790
  module.weight.data.normal_(mean=0.0, std=std)
791
  if module.bias is not None:
792
  module.bias.data.zero_()
 
638
  _no_split_modules = ['PackingTransformer']
639
  _supports_flash_attn_2 = True
640
  _supports_sdpa = True
641
+
642
+ def _init_weights(self, module):
643
+ # Default PreTrainedModel._init_weights treats nn.Linear subclasses (e.g.
644
+ # compressed_tensors.CompressedLinear) as Linear but those layers may have
645
+ # no `weight` after compression; skip in that case.
646
+ std = getattr(self.config, "initializer_range", 0.02)
647
+ if isinstance(module, (nn.Linear, nn.Conv2d)) and hasattr(module, "weight"):
648
+ module.weight.data.normal_(mean=0.0, std=std)
649
+ if module.bias is not None:
650
+ module.bias.data.zero_()
651
+ elif isinstance(module, nn.Embedding):
652
+ module.weight.data.normal_(mean=0.0, std=std)
653
+ if module.padding_idx is not None:
654
+ module.weight.data[module.padding_idx].zero_()
655
 
656
  def __init__(self, config, *inputs, **kwargs):
657
  super().__init__(config, *inputs, **kwargs)
 
800
  if hasattr(module, "class_embedding"):
801
  module.class_embedding.data.normal_(mean=0.0, std=std)
802
 
803
+ if isinstance(module, (nn.Linear, nn.Conv2d)) and hasattr(module, "weight"):
804
  module.weight.data.normal_(mean=0.0, std=std)
805
  if module.bias is not None:
806
  module.bias.data.zero_()