aamrinder committed on
Commit
70346e7
·
verified ·
1 Parent(s): 37818f1

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. train/hour1_smoke.py +3 -1
  2. train/train_grpo.py +3 -1
train/hour1_smoke.py CHANGED
@@ -83,17 +83,19 @@ def main():
83
  # 5. Load Qwen2.5-3B-Instruct + LoRA
84
  print("\n[5/6] loading Qwen2.5-3B-Instruct (4-bit + LoRA)")
85
  try:
 
86
  model, tokenizer = FastLanguageModel.from_pretrained(
87
  model_name="unsloth/Qwen2.5-3B-Instruct",
88
  max_seq_length=2048, # smaller than full 4096 for speed
89
  load_in_4bit=True,
 
90
  )
91
  model = FastLanguageModel.get_peft_model(
92
  model,
93
  r=8, # smaller r for the smoke test
94
  lora_alpha=16,
95
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
96
- use_gradient_checkpointing="unsloth",
97
  )
98
  n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
99
  print(f" ✓ model loaded; {n_trainable / 1e6:.1f}M LoRA params trainable")
 
83
  # 5. Load Qwen2.5-3B-Instruct + LoRA
84
  print("\n[5/6] loading Qwen2.5-3B-Instruct (4-bit + LoRA)")
85
  try:
86
+ import torch as _t
87
  model, tokenizer = FastLanguageModel.from_pretrained(
88
  model_name="unsloth/Qwen2.5-3B-Instruct",
89
  max_seq_length=2048, # smaller than full 4096 for speed
90
  load_in_4bit=True,
91
+ dtype=_t.bfloat16, # avoid LoRA dtype mismatch on L4
92
  )
93
  model = FastLanguageModel.get_peft_model(
94
  model,
95
  r=8, # smaller r for the smoke test
96
  lora_alpha=16,
97
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
98
+ use_gradient_checkpointing=True, # plain torch GC, not "unsloth" custom
99
  )
100
  n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
101
  print(f" ✓ model loaded; {n_trainable / 1e6:.1f}M LoRA params trainable")
train/train_grpo.py CHANGED
@@ -283,17 +283,19 @@ def main():
283
  from trl import GRPOTrainer, GRPOConfig
284
 
285
  print(f"[load] {args.model}, 4-bit, max_seq_length={args.seq_length}")
 
286
  model, tokenizer = FastLanguageModel.from_pretrained(
287
  model_name=args.model,
288
  max_seq_length=args.seq_length,
289
  load_in_4bit=True,
 
290
  )
291
  model = FastLanguageModel.get_peft_model(
292
  model,
293
  r=args.lora_r,
294
  lora_alpha=args.lora_r,
295
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
296
- use_gradient_checkpointing="unsloth",
297
  )
298
 
299
  config = GRPOConfig(
 
283
  from trl import GRPOTrainer, GRPOConfig
284
 
285
  print(f"[load] {args.model}, 4-bit, max_seq_length={args.seq_length}")
286
+ import torch as _t
287
  model, tokenizer = FastLanguageModel.from_pretrained(
288
  model_name=args.model,
289
  max_seq_length=args.seq_length,
290
  load_in_4bit=True,
291
+ dtype=_t.bfloat16, # explicit dtype prevents LoRA Half/Float mismatch
292
  )
293
  model = FastLanguageModel.get_peft_model(
294
  model,
295
  r=args.lora_r,
296
  lora_alpha=args.lora_r,
297
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
298
+ use_gradient_checkpointing=True, # plain torch GC; avoids unsloth-zoo dtype bug
299
  )
300
 
301
  config = GRPOConfig(