Jayant-Kernel commited on
Commit
e4aea5d
·
1 Parent(s): d34e286

fix: remove misplaced import inside GRPOConfig args

Browse files
Files changed (1) hide show
  1. train.py +1 -2
train.py CHANGED
@@ -175,10 +175,9 @@ trainer = GRPOTrainer(
175
  model=model,
176
  processing_class=tokenizer,
177
  reward_funcs=[reward_fn],
178
- import torch as _torch
179
  args=GRPOConfig(
180
  output_dir="/tmp/deceit-1.5b",
181
- bf16=_torch.cuda.is_available() and _torch.cuda.is_bf16_supported(),
182
  fp16=False,
183
  max_steps=150,
184
  per_device_train_batch_size=4,
 
175
  model=model,
176
  processing_class=tokenizer,
177
  reward_funcs=[reward_fn],
 
178
  args=GRPOConfig(
179
  output_dir="/tmp/deceit-1.5b",
180
+ bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
181
  fp16=False,
182
  max_steps=150,
183
  per_device_train_batch_size=4,