K446 committed
Commit 8dab919 · 1 Parent(s): c505237

QLoRA best practices: prepare_model_for_kbit_training, paged_adamw_8bit, cosine LR, faster iteration

Files changed (1)
run_training.py +26 −10
run_training.py CHANGED
@@ -61,7 +61,7 @@ def run_grpo_training():
     # ── 1. Load model ──
     print("\n[1/6] Loading model with bitsandbytes 4-bit...")
     from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
-    from peft import LoraConfig, get_peft_model
+    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

     MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
     bnb_config = BitsAndBytesConfig(
@@ -71,9 +71,23 @@ def run_grpo_training():
         bnb_4bit_use_double_quant=True,
     )
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_NAME, quantization_config=bnb_config, device_map="auto",
     )
+
+    # Critical for bnb-4bit + LoRA + gradient checkpointing: cast norms to fp32,
+    # enable input grads, and wire up non-reentrant checkpointing.
+    model = prepare_model_for_kbit_training(
+        model,
+        use_gradient_checkpointing=True,
+        gradient_checkpointing_kwargs={"use_reentrant": False},
+    )
+    model.config.pad_token_id = tokenizer.pad_token_id
+    model.config.use_cache = False  # silences the warning loop during training
+
     lora_config = LoraConfig(
         r=16, lora_alpha=16, lora_dropout=0,
         target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
@@ -81,13 +95,10 @@ def run_grpo_training():
         task_type="CAUSAL_LM",
     )
     model = get_peft_model(model, lora_config)
-    model.enable_input_require_grads()  # Required for gradient checkpointing + 4-bit
+    model.enable_input_require_grads()
     print(f" Model: {MODEL_NAME}")
     print(f" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
     # ── 2. Baseline evaluation ──
     print("\n[2/6] Running baseline evaluation...")
     import re
@@ -205,24 +216,29 @@ def run_grpo_training():
             else:
                 obs_dicts.append(ctx)

-        return compute_grpo_reward_env(texts, obs_dicts, task_config, horizon=3)
+        return compute_grpo_reward_env(texts, obs_dicts, task_config, horizon=1)

     grpo_config = GRPOConfig(
         output_dir="training/outputs/grpo_checkpoints",
         num_train_epochs=3,
         per_device_train_batch_size=2,
-        gradient_accumulation_steps=8,
+        gradient_accumulation_steps=2,  # was 8 — first visible step lands ~4x sooner
         learning_rate=1e-5,
-        logging_steps=5,
+        logging_steps=1,  # was 5 — see loss every step
         save_steps=50,
-        max_completion_length=128,
+        max_prompt_length=1024,  # default 512 truncates Karnataka prompts
+        max_completion_length=96,  # was 128 — ~25% faster generation
         num_generations=2,
+        temperature=0.7,  # was 0.9 default — less wasted sampling
         report_to="none",
         remove_unused_columns=False,
         bf16=_bf16,
         fp16=_fp16,
         gradient_checkpointing=True,
-        optim="adafactor",
+        gradient_checkpointing_kwargs={"use_reentrant": False},
+        optim="paged_adamw_8bit",  # canonical for QLoRA; adafactor fights bf16+bnb
+        warmup_ratio=0.05,
+        lr_scheduler_type="cosine",
     )

     train_dataset = Dataset.from_dict({"prompt": prompts, "obs_context": obs_contexts})
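For reference, the post-commit load-and-prepare sequence reads as one block like the sketch below. It is assembled from the hunks above; anything the diff does not show (load_in_4bit, the NF4 quant type, the compute dtype, the tail of target_modules) is filled in with common QLoRA defaults and should be read as assumptions, not as what run_training.py actually contains.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # assumed; not visible in the hunks above
    bnb_4bit_quant_type="nf4",              # assumed; the usual QLoRA choice
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumed; consistent with bf16=_bf16
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # now set before the model config needs it

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, quantization_config=bnb_config, device_map="auto",
)

# prepare_model_for_kbit_training casts norms to fp32, enables input grads,
# and wires up non-reentrant gradient checkpointing on the 4-bit base model.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False  # the KV cache is dead weight under checkpointing

lora_config = LoraConfig(
    r=16, lora_alpha=16, lora_dropout=0,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],  # tail assumed; the diff truncates this list
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.enable_input_require_grads()
model.print_trainable_parameters()  # sanity check: only the LoRA adapters should be trainable

The config side then plugs into TRL as usual. The wiring below is a hypothetical sketch: the trainer construction sits outside the hunks shown, and reward_fn stands in for the closure that calls compute_grpo_reward_env.

from trl import GRPOTrainer

trainer = GRPOTrainer(
    model=model,
    reward_funcs=reward_fn,       # assumed name for the reward closure above
    args=grpo_config,
    train_dataset=train_dataset,
)
trainer.train()

On iteration speed, the "~4x sooner" comment is plain arithmetic: one optimizer step consumes gradient_accumulation_steps micro-batches, so dropping that from 8 to 2 means the first logged loss (logging_steps=1) arrives after roughly a quarter of the generation work it previously required.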