bigmoyan ace-coreweave committed on
Commit
9e6eed7
·
1 Parent(s): 5a49d03

fix: add missing use_deterministic_attn parameter to MoonViT3dEncoder (#22)

Browse files

- fix: add missing use_deterministic_attn parameter to MoonViT3dEncoder (f13e4d071155abcae05f85cc3c1b5fa56ea1b9fc)
- also fix weight init issue (bed55d493c07900df87ed51274581025cbf8315a)


Co-authored-by: Ace Eldeib <ace-coreweave@users.noreply.huggingface.co>

Files changed (2) hide show
  1. modeling_deepseek.py +1 -1
  2. modeling_kimi_k25.py +18 -3
modeling_deepseek.py CHANGED
@@ -1244,7 +1244,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
1244
 
1245
  def _init_weights(self, module):
1246
  std = self.config.initializer_range
1247
- if isinstance(module, nn.Linear):
1248
  module.weight.data.normal_(mean=0.0, std=std)
1249
  if module.bias is not None:
1250
  module.bias.data.zero_()
 
1244
 
1245
  def _init_weights(self, module):
1246
  std = self.config.initializer_range
1247
+ if isinstance(module, nn.Linear) and hasattr(module, "weight"):
1248
  module.weight.data.normal_(mean=0.0, std=std)
1249
  if module.bias is not None:
1250
  module.bias.data.zero_()
modeling_kimi_k25.py CHANGED
@@ -562,7 +562,8 @@ class MoonViT3dEncoder(nn.Module):
562
  hidden_dim: int,
563
  num_layers: int,
564
  block_cfg: dict,
565
- video_attn_type: str = 'spatial_temporal') -> None:
 
566
  super().__init__()
567
 
568
  assert video_attn_type == 'spatial_temporal', f'video_attn_type must be "spatial_temporal", got {video_attn_type}'
@@ -572,7 +573,7 @@ class MoonViT3dEncoder(nn.Module):
572
  self.blocks = nn.ModuleList([
573
  MoonViTEncoderLayer(
574
  **block_cfg,
575
- use_deterministic_attn=self.use_deterministic_attn)
576
  for _ in range(num_layers)
577
  ])
578
  self.final_layernorm = nn.LayerNorm(hidden_dim)
@@ -637,6 +638,20 @@ class MoonViT3dPretrainedModel(PreTrainedModel):
637
  _no_split_modules = ['PackingTransformer']
638
  _supports_flash_attn_2 = True
639
  _supports_sdpa = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
640
 
641
  def __init__(self, config, *inputs, **kwargs):
642
  super().__init__(config, *inputs, **kwargs)
@@ -785,7 +800,7 @@ class KimiK25PreTrainedModel(PreTrainedModel):
785
  if hasattr(module, "class_embedding"):
786
  module.class_embedding.data.normal_(mean=0.0, std=std)
787
 
788
- if isinstance(module, (nn.Linear, nn.Conv2d)):
789
  module.weight.data.normal_(mean=0.0, std=std)
790
  if module.bias is not None:
791
  module.bias.data.zero_()
 
562
  hidden_dim: int,
563
  num_layers: int,
564
  block_cfg: dict,
565
+ video_attn_type: str = 'spatial_temporal',
566
+ use_deterministic_attn: bool = False) -> None:
567
  super().__init__()
568
 
569
  assert video_attn_type == 'spatial_temporal', f'video_attn_type must be "spatial_temporal", got {video_attn_type}'
 
573
  self.blocks = nn.ModuleList([
574
  MoonViTEncoderLayer(
575
  **block_cfg,
576
+ use_deterministic_attn=use_deterministic_attn)
577
  for _ in range(num_layers)
578
  ])
579
  self.final_layernorm = nn.LayerNorm(hidden_dim)
 
638
  _no_split_modules = ['PackingTransformer']
639
  _supports_flash_attn_2 = True
640
  _supports_sdpa = True
641
+
642
+ def _init_weights(self, module):
643
+ # Default PreTrainedModel._init_weights treats nn.Linear subclasses (e.g.
644
+ # compressed_tensors.CompressedLinear) as Linear but those layers may have
645
+ # no `weight` after compression; skip in that case.
646
+ std = getattr(self.config, "initializer_range", 0.02)
647
+ if isinstance(module, (nn.Linear, nn.Conv2d)) and hasattr(module, "weight"):
648
+ module.weight.data.normal_(mean=0.0, std=std)
649
+ if module.bias is not None:
650
+ module.bias.data.zero_()
651
+ elif isinstance(module, nn.Embedding):
652
+ module.weight.data.normal_(mean=0.0, std=std)
653
+ if module.padding_idx is not None:
654
+ module.weight.data[module.padding_idx].zero_()
655
 
656
  def __init__(self, config, *inputs, **kwargs):
657
  super().__init__(config, *inputs, **kwargs)
 
800
  if hasattr(module, "class_embedding"):
801
  module.class_embedding.data.normal_(mean=0.0, std=std)
802
 
803
+ if isinstance(module, (nn.Linear, nn.Conv2d)) and hasattr(module, "weight"):
804
  module.weight.data.normal_(mean=0.0, std=std)
805
  if module.bias is not None:
806
  module.bias.data.zero_()