Imsachin010 committed
Commit 876b380 · 1 Parent(s): 1141c48

Fix FP16 AMP crash by explicitly loading base model in float32 for fallback hardware

Files changed (1)
  1. training/grpo_train.py +1 -1
training/grpo_train.py CHANGED

@@ -74,7 +74,7 @@ def _load_model_and_tokenizer(model_name: str, use_unsloth: bool = False):
     tokenizer.pad_token = tokenizer.eos_token
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        torch_dtype=torch.bfloat16 if bf16_supported else torch.float16,
+        torch_dtype=torch.bfloat16 if bf16_supported else torch.float32,
         device_map="auto",
     )
     return model, tokenizer
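
Why the change works: under AMP, GradScaler keeps fp32 master weights and unscales gradients in fp32; if the model parameters are themselves loaded in float16, the unscale step typically fails with "ValueError: Attempting to unscale FP16 gradients." Loading in float32 on GPUs without bfloat16 support lets autocast still run fp16 compute while the parameters stay fp32. Below is a minimal sketch of the loader after this commit; it assumes bf16_supported is derived from torch.cuda.is_bf16_supported() (its actual definition sits above the hunk shown), and the unsloth branch from the full file is omitted.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def _load_model_and_tokenizer(model_name: str, use_unsloth: bool = False):
        # Assumption: bf16 availability is probed like this; the real check
        # lives outside the diff hunk. bf16 needs Ampere (SM 8.0) or newer.
        bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            # float32, not float16, on fallback hardware: fp16 master weights
            # break AMP's GradScaler, which expects fp32 parameters and raises
            # "Attempting to unscale FP16 gradients" otherwise.
            torch_dtype=torch.bfloat16 if bf16_supported else torch.float32,
            device_map="auto",
        )
        return model, tokenizer

The trade-off is higher memory use for the fp32 copy on pre-Ampere GPUs, in exchange for a training loop that no longer crashes when the GradScaler unscales gradients.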