Jayant-Kernel committed on
Commit ·
d34e286
1
Parent(s): d75e720
fix: auto-detect bf16 support
Browse files
train.py
CHANGED
|
@@ -175,8 +175,11 @@ trainer = GRPOTrainer(
|
|
| 175 |
model=model,
|
| 176 |
processing_class=tokenizer,
|
| 177 |
reward_funcs=[reward_fn],
|
|
|
|
| 178 |
args=GRPOConfig(
|
| 179 |
output_dir="/tmp/deceit-1.5b",
|
|
|
|
|
|
|
| 180 |
max_steps=150,
|
| 181 |
per_device_train_batch_size=4,
|
| 182 |
num_generations=4,
|
|
|
|
| 175 |
model=model,
|
| 176 |
processing_class=tokenizer,
|
| 177 |
reward_funcs=[reward_fn],
|
| 178 |
+
import torch as _torch
|
| 179 |
args=GRPOConfig(
|
| 180 |
output_dir="/tmp/deceit-1.5b",
|
| 181 |
+
bf16=_torch.cuda.is_available() and _torch.cuda.is_bf16_supported(),
|
| 182 |
+
fp16=False,
|
| 183 |
max_steps=150,
|
| 184 |
per_device_train_batch_size=4,
|
| 185 |
num_generations=4,
|