kshitijthakkar committed on
Commit
fc8b993
·
verified ·
1 Parent(s): e0e7a3d

modeling: enable HF gradient_checkpointing — declare the gradient_checkpointing attribute on DeepseekV4Model and DeepseekV4ForCausalLM, and wrap each layer call in self._gradient_checkpointing_func when checkpointing is enabled and the model is in training mode

Browse files
code/deepseek_v4/modeling_deepseek_v4.py CHANGED
@@ -1386,6 +1386,8 @@ class DeepseekV4Model(DeepseekV4PreTrainedModel):
1386
  self.mtp = nn.ModuleList([
1387
  DeepseekV4MTPModule(config) for _ in range(config.num_nextn_predict_layers)
1388
  ])
 
 
1389
  self.post_init()
1390
 
1391
  def _build_rope(self, max_len: int, device, dtype):
@@ -1417,8 +1419,14 @@ class DeepseekV4Model(DeepseekV4PreTrainedModel):
1417
  pad_mask = attention_mask.bool() if attention_mask is not None else None
1418
 
1419
  for layer in self.layers:
1420
- X = layer(X, self._mhc, input_ids, positions,
1421
- rope_cos, rope_sin, rope_cos_c, rope_sin_c, pad_mask)
 
 
 
 
 
 
1422
 
1423
  # Head-side mHC: collapse residual back to [B,S,d] using A_l
1424
  # Head mHC: pre-only collapse hc -> 1, then final norm
@@ -1459,6 +1467,8 @@ class DeepseekV4ForCausalLM(DeepseekV4PreTrainedModel):
1459
  self.mtp = nn.ModuleList([
1460
  DeepseekV4MTPModule(config) for _ in range(config.num_nextn_predict_layers)
1461
  ])
 
 
1462
  self.post_init()
1463
 
1464
  # HF auto methods
@@ -1554,8 +1564,14 @@ class DeepseekV4ForCausalLM(DeepseekV4PreTrainedModel):
1554
  pad_mask = attention_mask.bool() if attention_mask is not None else None
1555
 
1556
  for layer in self.layers:
1557
- X = layer(X, self._mhc, input_ids, positions,
1558
- rope_cos, rope_sin, rope_cos_c, rope_sin_c, pad_mask)
 
 
 
 
 
 
1559
 
1560
  # Head mHC: pre-only collapse hc -> 1, then final norm
1561
  head_pre = self._mhc.gen_head_pre(X, self.hc_head_fn, self.hc_head_base,
 
1386
  self.mtp = nn.ModuleList([
1387
  DeepseekV4MTPModule(config) for _ in range(config.num_nextn_predict_layers)
1388
  ])
1389
+ # HF Trainer flips this via gradient_checkpointing_enable; checked in forward.
1390
+ self.gradient_checkpointing = False
1391
  self.post_init()
1392
 
1393
  def _build_rope(self, max_len: int, device, dtype):
 
1419
  pad_mask = attention_mask.bool() if attention_mask is not None else None
1420
 
1421
  for layer in self.layers:
1422
+ if self.gradient_checkpointing and self.training:
1423
+ X = self._gradient_checkpointing_func(
1424
+ layer.__call__, X, self._mhc, input_ids, positions,
1425
+ rope_cos, rope_sin, rope_cos_c, rope_sin_c, pad_mask,
1426
+ )
1427
+ else:
1428
+ X = layer(X, self._mhc, input_ids, positions,
1429
+ rope_cos, rope_sin, rope_cos_c, rope_sin_c, pad_mask)
1430
 
1431
  # Head-side mHC: collapse residual back to [B,S,d] using A_l
1432
  # Head mHC: pre-only collapse hc -> 1, then final norm
 
1467
  self.mtp = nn.ModuleList([
1468
  DeepseekV4MTPModule(config) for _ in range(config.num_nextn_predict_layers)
1469
  ])
1470
+ # HF Trainer flips this via gradient_checkpointing_enable; checked in _backbone.
1471
+ self.gradient_checkpointing = False
1472
  self.post_init()
1473
 
1474
  # HF auto methods
 
1564
  pad_mask = attention_mask.bool() if attention_mask is not None else None
1565
 
1566
  for layer in self.layers:
1567
+ if self.gradient_checkpointing and self.training:
1568
+ X = self._gradient_checkpointing_func(
1569
+ layer.__call__, X, self._mhc, input_ids, positions,
1570
+ rope_cos, rope_sin, rope_cos_c, rope_sin_c, pad_mask,
1571
+ )
1572
+ else:
1573
+ X = layer(X, self._mhc, input_ids, positions,
1574
+ rope_cos, rope_sin, rope_cos_c, rope_sin_c, pad_mask)
1575
 
1576
  # Head mHC: pre-only collapse hc -> 1, then final norm
1577
  head_pre = self._mhc.gen_head_pre(X, self.hc_head_fn, self.hc_head_base,