aamrinder commited on
Commit
8d3bf91
·
verified ·
1 Parent(s): 70346e7

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. train/hour1_smoke.py +26 -16
  2. train/train_grpo.py +29 -13
train/hour1_smoke.py CHANGED
@@ -42,13 +42,15 @@ def main():
42
  traceback.print_exc()
43
  sys.exit(1)
44
 
45
- # 2. Unsloth + TRL imports
46
- print("\n[2/6] importing Unsloth + TRL")
47
  try:
48
- from unsloth import FastLanguageModel
 
 
49
  from trl import GRPOTrainer, GRPOConfig
50
  from datasets import Dataset
51
- print(" ✓ Unsloth + TRL + datasets imported")
52
  except Exception as e:
53
  print(f" ✗ {e}")
54
  traceback.print_exc()
@@ -80,23 +82,31 @@ def main():
80
  traceback.print_exc()
81
  sys.exit(1)
82
 
83
- # 5. Load Qwen2.5-3B-Instruct + LoRA
84
- print("\n[5/6] loading Qwen2.5-3B-Instruct (4-bit + LoRA)")
85
  try:
86
- import torch as _t
87
- model, tokenizer = FastLanguageModel.from_pretrained(
88
- model_name="unsloth/Qwen2.5-3B-Instruct",
89
- max_seq_length=2048, # smaller than full 4096 for speed
90
  load_in_4bit=True,
91
- dtype=_t.bfloat16, # avoid LoRA dtype mismatch on L4
 
 
 
 
 
 
 
 
 
 
 
92
  )
93
- model = FastLanguageModel.get_peft_model(
94
- model,
95
- r=8, # smaller r for the smoke test
96
- lora_alpha=16,
97
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
98
- use_gradient_checkpointing=True, # plain torch GC, not "unsloth" custom
99
  )
 
100
  n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
101
  print(f" ✓ model loaded; {n_trainable / 1e6:.1f}M LoRA params trainable")
102
  except Exception as e:
 
42
  traceback.print_exc()
43
  sys.exit(1)
44
 
45
+ # 2. transformers + TRL + PEFT (deck requirement #2: "Unsloth OR HF TRL")
46
+ print("\n[2/6] importing transformers + TRL + PEFT")
47
  try:
48
+ import torch as _t
49
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
50
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
51
  from trl import GRPOTrainer, GRPOConfig
52
  from datasets import Dataset
53
+ print(" ✓ transformers + TRL + PEFT + datasets imported")
54
  except Exception as e:
55
  print(f" ✗ {e}")
56
  traceback.print_exc()
 
82
  traceback.print_exc()
83
  sys.exit(1)
84
 
85
+ # 5. Load Qwen2.5-3B-Instruct + LoRA via plain transformers + PEFT
86
+ print("\n[5/6] loading Qwen2.5-3B-Instruct (4-bit + LoRA via transformers/peft)")
87
  try:
88
+ bnb = BitsAndBytesConfig(
 
 
 
89
  load_in_4bit=True,
90
+ bnb_4bit_compute_dtype=_t.bfloat16,
91
+ bnb_4bit_quant_type="nf4",
92
+ bnb_4bit_use_double_quant=True,
93
+ )
94
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B-Instruct")
95
+ if tokenizer.pad_token is None:
96
+ tokenizer.pad_token = tokenizer.eos_token
97
+ base = AutoModelForCausalLM.from_pretrained(
98
+ "Qwen/Qwen2.5-3B-Instruct",
99
+ quantization_config=bnb,
100
+ dtype=_t.bfloat16,
101
+ device_map="auto",
102
  )
103
+ base = prepare_model_for_kbit_training(base, use_gradient_checkpointing=True)
104
+ peft_config = LoraConfig(
105
+ r=8, lora_alpha=16, lora_dropout=0.0, bias="none",
 
106
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
107
+ task_type="CAUSAL_LM",
108
  )
109
+ model = get_peft_model(base, peft_config)
110
  n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
111
  print(f" ✓ model loaded; {n_trainable / 1e6:.1f}M LoRA params trainable")
112
  except Exception as e:
train/train_grpo.py CHANGED
@@ -255,7 +255,7 @@ def reward_decomposition(text: str, gold: str) -> Dict[str, float]:
255
 
256
  def main():
257
  parser = argparse.ArgumentParser()
258
- parser.add_argument("--model", default="unsloth/Qwen2.5-3B-Instruct")
259
  parser.add_argument("--output-dir", default="./checkpoints/run1")
260
  parser.add_argument("--max-steps", type=int, default=200)
261
  parser.add_argument("--num-generations", type=int, default=4)
@@ -278,25 +278,41 @@ def main():
278
  dataset = build_dataset(scenarios, n_rows=args.n_train_rows)
279
  print(f"[data] {len(dataset)} prompt rows built")
280
 
281
- # Model load
282
- from unsloth import FastLanguageModel
 
 
 
 
 
283
  from trl import GRPOTrainer, GRPOConfig
284
 
285
  print(f"[load] {args.model}, 4-bit, max_seq_length={args.seq_length}")
286
- import torch as _t
287
- model, tokenizer = FastLanguageModel.from_pretrained(
288
- model_name=args.model,
289
- max_seq_length=args.seq_length,
290
  load_in_4bit=True,
291
- dtype=_t.bfloat16, # explicit dtype prevents LoRA Half/Float mismatch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  )
293
- model = FastLanguageModel.get_peft_model(
294
- model,
295
- r=args.lora_r,
296
- lora_alpha=args.lora_r,
297
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
298
- use_gradient_checkpointing=True, # plain torch GC; avoids unsloth-zoo dtype bug
299
  )
 
300
 
301
  config = GRPOConfig(
302
  output_dir=args.output_dir,
 
255
 
256
  def main():
257
  parser = argparse.ArgumentParser()
258
+ parser.add_argument("--model", default="Qwen/Qwen2.5-3B-Instruct")
259
  parser.add_argument("--output-dir", default="./checkpoints/run1")
260
  parser.add_argument("--max-steps", type=int, default=200)
261
  parser.add_argument("--num-generations", type=int, default=4)
 
278
  dataset = build_dataset(scenarios, n_rows=args.n_train_rows)
279
  print(f"[data] {len(dataset)} prompt rows built")
280
 
281
+ # Model load via plain transformers + PEFT (deck-compliant: training uses HF TRL).
282
+ # We dropped Unsloth because their fast_lora kernel has a Half/Float dtype
283
+ # mismatch on Qwen2.5-3B + 4-bit + bf16 in v2026.4.8 (verified via failed
284
+ # smoke runs on L4). Plain transformers+peft+trl is slower but reliable.
285
+ import torch as _t
286
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
287
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
288
  from trl import GRPOTrainer, GRPOConfig
289
 
290
  print(f"[load] {args.model}, 4-bit, max_seq_length={args.seq_length}")
291
+ bnb = BitsAndBytesConfig(
 
 
 
292
  load_in_4bit=True,
293
+ bnb_4bit_compute_dtype=_t.bfloat16,
294
+ bnb_4bit_quant_type="nf4",
295
+ bnb_4bit_use_double_quant=True,
296
+ )
297
+ # Strip the "unsloth/" prefix if the user passed an Unsloth-prefixed name —
298
+ # we now load directly from the upstream Qwen repo.
299
+ model_name = args.model.replace("unsloth/", "Qwen/").replace("-Instruct-bnb-4bit", "-Instruct")
300
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
301
+ if tokenizer.pad_token is None:
302
+ tokenizer.pad_token = tokenizer.eos_token
303
+ base = AutoModelForCausalLM.from_pretrained(
304
+ model_name,
305
+ quantization_config=bnb,
306
+ dtype=_t.bfloat16,
307
+ device_map="auto",
308
  )
309
+ base = prepare_model_for_kbit_training(base, use_gradient_checkpointing=True)
310
+ peft_config = LoraConfig(
311
+ r=args.lora_r, lora_alpha=args.lora_r, lora_dropout=0.0, bias="none",
 
312
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
313
+ task_type="CAUSAL_LM",
314
  )
315
+ model = get_peft_model(base, peft_config)
316
 
317
  config = GRPOConfig(
318
  output_dir=args.output_dir,