Jayant-Kernel committed on
Commit
d34e286
·
1 Parent(s): d75e720

fix: auto-detect bf16 support

Browse files
Files changed (1) hide show
  1. train.py +3 -0
train.py CHANGED
@@ -175,8 +175,11 @@ trainer = GRPOTrainer(
175
  model=model,
176
  processing_class=tokenizer,
177
  reward_funcs=[reward_fn],
 
178
  args=GRPOConfig(
179
  output_dir="/tmp/deceit-1.5b",
 
 
180
  max_steps=150,
181
  per_device_train_batch_size=4,
182
  num_generations=4,
 
175
  model=model,
176
  processing_class=tokenizer,
177
  reward_funcs=[reward_fn],
178
+ import torch as _torch
179
  args=GRPOConfig(
180
  output_dir="/tmp/deceit-1.5b",
181
+ bf16=_torch.cuda.is_available() and _torch.cuda.is_bf16_supported(),
182
+ fp16=False,
183
  max_steps=150,
184
  per_device_train_batch_size=4,
185
  num_generations=4,