Jayant-Kernel commited on
Commit ·
e4aea5d
1
Parent(s): d34e286
fix: remove misplaced import inside GRPOConfig args
Browse files
train.py
CHANGED
|
@@ -175,10 +175,9 @@ trainer = GRPOTrainer(
|
|
| 175 |
model=model,
|
| 176 |
processing_class=tokenizer,
|
| 177 |
reward_funcs=[reward_fn],
|
| 178 |
-
import torch as _torch
|
| 179 |
args=GRPOConfig(
|
| 180 |
output_dir="/tmp/deceit-1.5b",
|
| 181 |
-
bf16=
|
| 182 |
fp16=False,
|
| 183 |
max_steps=150,
|
| 184 |
per_device_train_batch_size=4,
|
|
|
|
| 175 |
model=model,
|
| 176 |
processing_class=tokenizer,
|
| 177 |
reward_funcs=[reward_fn],
|
|
|
|
| 178 |
args=GRPOConfig(
|
| 179 |
output_dir="/tmp/deceit-1.5b",
|
| 180 |
+
bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
|
| 181 |
fp16=False,
|
| 182 |
max_steps=150,
|
| 183 |
per_device_train_batch_size=4,
|