Commit ·
9e6eed7
1
Parent(s): 5a49d03
fix: add missing use_deterministic_attn parameter to MoonViT3dEncoder (#22)
Browse files- fix: add missing use_deterministic_attn parameter to MoonViT3dEncoder (f13e4d071155abcae05f85cc3c1b5fa56ea1b9fc)
- also fix weight init issue (bed55d493c07900df87ed51274581025cbf8315a)
Co-authored-by: Ace Eldeib <ace-coreweave@users.noreply.huggingface.co>
- modeling_deepseek.py +1 -1
- modeling_kimi_k25.py +18 -3
modeling_deepseek.py
CHANGED
|
@@ -1244,7 +1244,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
|
|
| 1244 |
|
| 1245 |
def _init_weights(self, module):
|
| 1246 |
std = self.config.initializer_range
|
| 1247 |
-
if isinstance(module, nn.Linear):
|
| 1248 |
module.weight.data.normal_(mean=0.0, std=std)
|
| 1249 |
if module.bias is not None:
|
| 1250 |
module.bias.data.zero_()
|
|
|
|
| 1244 |
|
| 1245 |
def _init_weights(self, module):
|
| 1246 |
std = self.config.initializer_range
|
| 1247 |
+
if isinstance(module, nn.Linear) and hasattr(module, "weight"):
|
| 1248 |
module.weight.data.normal_(mean=0.0, std=std)
|
| 1249 |
if module.bias is not None:
|
| 1250 |
module.bias.data.zero_()
|
modeling_kimi_k25.py
CHANGED
|
@@ -562,7 +562,8 @@ class MoonViT3dEncoder(nn.Module):
|
|
| 562 |
hidden_dim: int,
|
| 563 |
num_layers: int,
|
| 564 |
block_cfg: dict,
|
| 565 |
-
video_attn_type: str = 'spatial_temporal'
|
|
|
|
| 566 |
super().__init__()
|
| 567 |
|
| 568 |
assert video_attn_type == 'spatial_temporal', f'video_attn_type must be "spatial_temporal", got {video_attn_type}'
|
|
@@ -572,7 +573,7 @@ class MoonViT3dEncoder(nn.Module):
|
|
| 572 |
self.blocks = nn.ModuleList([
|
| 573 |
MoonViTEncoderLayer(
|
| 574 |
**block_cfg,
|
| 575 |
-
use_deterministic_attn=
|
| 576 |
for _ in range(num_layers)
|
| 577 |
])
|
| 578 |
self.final_layernorm = nn.LayerNorm(hidden_dim)
|
|
@@ -637,6 +638,20 @@ class MoonViT3dPretrainedModel(PreTrainedModel):
|
|
| 637 |
_no_split_modules = ['PackingTransformer']
|
| 638 |
_supports_flash_attn_2 = True
|
| 639 |
_supports_sdpa = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 640 |
|
| 641 |
def __init__(self, config, *inputs, **kwargs):
|
| 642 |
super().__init__(config, *inputs, **kwargs)
|
|
@@ -785,7 +800,7 @@ class KimiK25PreTrainedModel(PreTrainedModel):
|
|
| 785 |
if hasattr(module, "class_embedding"):
|
| 786 |
module.class_embedding.data.normal_(mean=0.0, std=std)
|
| 787 |
|
| 788 |
-
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
| 789 |
module.weight.data.normal_(mean=0.0, std=std)
|
| 790 |
if module.bias is not None:
|
| 791 |
module.bias.data.zero_()
|
|
|
|
| 562 |
hidden_dim: int,
|
| 563 |
num_layers: int,
|
| 564 |
block_cfg: dict,
|
| 565 |
+
video_attn_type: str = 'spatial_temporal',
|
| 566 |
+
use_deterministic_attn: bool = False) -> None:
|
| 567 |
super().__init__()
|
| 568 |
|
| 569 |
assert video_attn_type == 'spatial_temporal', f'video_attn_type must be "spatial_temporal", got {video_attn_type}'
|
|
|
|
| 573 |
self.blocks = nn.ModuleList([
|
| 574 |
MoonViTEncoderLayer(
|
| 575 |
**block_cfg,
|
| 576 |
+
use_deterministic_attn=use_deterministic_attn)
|
| 577 |
for _ in range(num_layers)
|
| 578 |
])
|
| 579 |
self.final_layernorm = nn.LayerNorm(hidden_dim)
|
|
|
|
| 638 |
_no_split_modules = ['PackingTransformer']
|
| 639 |
_supports_flash_attn_2 = True
|
| 640 |
_supports_sdpa = True
|
| 641 |
+
|
| 642 |
+
def _init_weights(self, module):
|
| 643 |
+
# Default PreTrainedModel._init_weights treats nn.Linear subclasses (e.g.
|
| 644 |
+
# compressed_tensors.CompressedLinear) as Linear but those layers may have
|
| 645 |
+
# no `weight` after compression; skip in that case.
|
| 646 |
+
std = getattr(self.config, "initializer_range", 0.02)
|
| 647 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)) and hasattr(module, "weight"):
|
| 648 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 649 |
+
if module.bias is not None:
|
| 650 |
+
module.bias.data.zero_()
|
| 651 |
+
elif isinstance(module, nn.Embedding):
|
| 652 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 653 |
+
if module.padding_idx is not None:
|
| 654 |
+
module.weight.data[module.padding_idx].zero_()
|
| 655 |
|
| 656 |
def __init__(self, config, *inputs, **kwargs):
|
| 657 |
super().__init__(config, *inputs, **kwargs)
|
|
|
|
| 800 |
if hasattr(module, "class_embedding"):
|
| 801 |
module.class_embedding.data.normal_(mean=0.0, std=std)
|
| 802 |
|
| 803 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)) and hasattr(module, "weight"):
|
| 804 |
module.weight.data.normal_(mean=0.0, std=std)
|
| 805 |
if module.bias is not None:
|
| 806 |
module.bias.data.zero_()
|