K446 committed on
Commit
c505237
·
1 Parent(s): c09f4cb

Fix: enable_input_require_grads for gradient checkpointing + 4-bit

Browse files
Files changed (1) hide show
  1. run_training.py +1 -0
run_training.py CHANGED
@@ -81,6 +81,7 @@ def run_grpo_training():
81
  task_type="CAUSAL_LM",
82
  )
83
  model = get_peft_model(model, lora_config)
 
84
  print(f" Model: {MODEL_NAME}")
85
  print(f" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
86
 
 
81
  task_type="CAUSAL_LM",
82
  )
83
  model = get_peft_model(model, lora_config)
84
+ model.enable_input_require_grads() # Required for gradient checkpointing + 4-bit
85
  print(f" Model: {MODEL_NAME}")
86
  print(f" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
87