Fix: call enable_input_require_grads() so gradient checkpointing works with the 4-bit model
Browse files
- run_training.py +1 -0
run_training.py
CHANGED
|
@@ -81,6 +81,7 @@ def run_grpo_training():
|
|
| 81 |
task_type="CAUSAL_LM",
|
| 82 |
)
|
| 83 |
model = get_peft_model(model, lora_config)
|
|
|
|
| 84 |
print(f" Model: {MODEL_NAME}")
|
| 85 |
print(f" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
|
| 86 |
|
|
|
|
| 81 |
task_type="CAUSAL_LM",
|
| 82 |
)
|
| 83 |
model = get_peft_model(model, lora_config)
|
| 84 |
+
model.enable_input_require_grads() # Required for gradient checkpointing + 4-bit
|
| 85 |
print(f" Model: {MODEL_NAME}")
|
| 86 |
print(f" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
|
| 87 |
|