QLoRA best practices: prepare_model_for_kbit_training, paged_adamw_8bit, cosine LR, faster iteration
run_training.py  CHANGED  +26 -10
@@ -61,7 +61,7 @@ def run_grpo_training():
     # ── 1. Load model ──
     print("\n[1/6] Loading model with bitsandbytes 4-bit...")
     from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
-    from peft import LoraConfig, get_peft_model
+    from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

     MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
     bnb_config = BitsAndBytesConfig(
@@ -71,9 +71,23 @@ def run_grpo_training():
         bnb_4bit_use_double_quant=True,
     )
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_NAME, quantization_config=bnb_config, device_map="auto",
     )
+
+    # Critical for bnb-4bit + LoRA + gradient checkpointing: cast norms to fp32,
+    # enable input grads, and wire up non-reentrant checkpointing.
+    model = prepare_model_for_kbit_training(
+        model,
+        use_gradient_checkpointing=True,
+        gradient_checkpointing_kwargs={"use_reentrant": False},
+    )
+    model.config.pad_token_id = tokenizer.pad_token_id
+    model.config.use_cache = False  # silences the warning loop during training
+
     lora_config = LoraConfig(
         r=16, lora_alpha=16, lora_dropout=0,
         target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
@@ -81,13 +95,10 @@ def run_grpo_training():
         task_type="CAUSAL_LM",
     )
     model = get_peft_model(model, lora_config)
-    model.enable_input_require_grads()
+    model.enable_input_require_grads()
     print(f" Model: {MODEL_NAME}")
     print(f" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
     # ── 2. Baseline evaluation ──
     print("\n[2/6] Running baseline evaluation...")
     import re
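The prepare_model_for_kbit_training call added above replaces the hand-rolled setup this file used before. A rough sketch of what it does internally, based on peft's implementation (details vary across peft versions, and the function name below is illustrative, not the peft API):

import torch

def prepare_for_kbit_sketch(model):  # illustrative name, not the peft API
    # Freeze the whole base model; the LoRA adapters attached afterwards
    # by get_peft_model remain the only trainable parameters.
    for param in model.parameters():
        param.requires_grad = False
        # Upcast the non-quantized fp16/bf16 leftovers (norms, embeddings)
        # to fp32 for numerical stability next to the 4-bit weights.
        if param.dtype in (torch.float16, torch.bfloat16):
            param.data = param.data.to(torch.float32)
    # Checkpointed inputs must require grad or backprop never reaches
    # the adapters through the frozen 4-bit base.
    model.enable_input_require_grads()
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": False}
    )
    return model

The non-reentrant checkpointing variant matters with a fully frozen base: the reentrant path is a common source of the "element 0 of tensors does not require grad" failure that enable_input_require_grads works around.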
@@ -205,24 +216,29 @@ def run_grpo_training():
             else:
                 obs_dicts.append(ctx)

-        return compute_grpo_reward_env(texts, obs_dicts, task_config, horizon=
+        return compute_grpo_reward_env(texts, obs_dicts, task_config, horizon=1)

     grpo_config = GRPOConfig(
         output_dir="training/outputs/grpo_checkpoints",
         num_train_epochs=3,
         per_device_train_batch_size=2,
-        gradient_accumulation_steps=8,
+        gradient_accumulation_steps=2,  # was 8 — first visible step lands ~4x sooner
         learning_rate=1e-5,
-        logging_steps=5,
+        logging_steps=1,  # was 5 — see loss every step
         save_steps=50,
-        max_completion_length=128,
+        max_prompt_length=1024,  # default 512 truncates Karnataka prompts
+        max_completion_length=96,  # was 128 — ~25% faster generation
         num_generations=2,
+        temperature=0.7,  # was 0.9 default — less wasted sampling
         report_to="none",
         remove_unused_columns=False,
         bf16=_bf16,
         fp16=_fp16,
         gradient_checkpointing=True,
-        optim="adafactor",
+        gradient_checkpointing_kwargs={"use_reentrant": False},
+        optim="paged_adamw_8bit",  # canonical for QLoRA; adafactor fights bf16+bnb
+        warmup_ratio=0.05,
+        lr_scheduler_type="cosine",
     )

     train_dataset = Dataset.from_dict({"prompt": prompts, "obs_context": obs_contexts})
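Spelled out manually, the new optimizer and schedule settings resolve to roughly the following inside the Trainer (a sketch; the step count below is a placeholder, since the Trainer derives it from dataset length, batch size, accumulation, and epochs):

import bitsandbytes as bnb
from transformers import get_cosine_schedule_with_warmup

num_training_steps = 300  # placeholder value for the sketch

# optim="paged_adamw_8bit": AdamW with 8-bit optimizer states, over the
# trainable (LoRA) parameters only.
optimizer = bnb.optim.PagedAdamW8bit(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-5,
)

# warmup_ratio=0.05 + lr_scheduler_type="cosine": linear warmup over the
# first 5% of steps, then cosine decay to zero.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.05 * num_training_steps),
    num_training_steps=num_training_steps,
)

The paged variants keep optimizer state in unified memory so transient spikes can spill to CPU RAM instead of OOMing; this is the optimizer the QLoRA paper pairs with 4-bit bases.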
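On the iteration-speed claims in the new comments: with per_device_train_batch_size=2, dropping gradient_accumulation_steps from 8 to 2 means an optimizer step (and, with logging_steps=1, a log line) after 4 sampled sequences instead of 16, which is the quoted ~4x. Trimming max_completion_length from 128 to 96 drops 32 of 128 worst-case decode tokens per generation, i.e. the quoted ~25%.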