shank committed on
Commit ·
ea6fe4e
1
Parent(s): 9487853
Auto-detect GPU: bfloat16+batch2+gen8 on A100, float16+batch1+gen4 on T4 — same script works on both
Browse files- training/train_grpo.py +47 -10
training/train_grpo.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
AgentDebuggerEnv — GRPO Training Script
|
| 3 |
Model: Qwen2.5-Coder-7B-Instruct (4-bit quantized via bitsandbytes)
|
| 4 |
Algorithm: GRPO (Group Relative Policy Optimization) via HuggingFace TRL
|
| 5 |
-
GPU:
|
| 6 |
|
| 7 |
Usage:
|
| 8 |
# Local reward sanity-check (no GPU, no model loading):
|
|
@@ -257,12 +257,49 @@ if args.test_local:
|
|
| 257 |
print("\nLOCAL TEST PASSED")
|
| 258 |
sys.exit(0)
|
| 259 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
# ── Load model ────────────────────────────────────────────────────────────────
|
| 261 |
print(f"Loading {MODEL_NAME}...")
|
| 262 |
bnb_config = BitsAndBytesConfig(
|
| 263 |
load_in_4bit=True,
|
| 264 |
bnb_4bit_quant_type="nf4",
|
| 265 |
-
bnb_4bit_compute_dtype=
|
| 266 |
bnb_4bit_use_double_quant=True,
|
| 267 |
)
|
| 268 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
|
@@ -274,13 +311,13 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
| 274 |
quantization_config=bnb_config,
|
| 275 |
device_map="auto",
|
| 276 |
trust_remote_code=True,
|
| 277 |
-
torch_dtype=
|
| 278 |
)
|
| 279 |
model.config.use_cache = False
|
| 280 |
|
| 281 |
lora_config = LoraConfig(
|
| 282 |
-
r=
|
| 283 |
-
lora_alpha=
|
| 284 |
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 285 |
"gate_proj", "up_proj", "down_proj"],
|
| 286 |
lora_dropout=0.0,
|
|
@@ -407,16 +444,16 @@ def make_dataset(step: int) -> Dataset:
|
|
| 407 |
config = GRPOConfig(
|
| 408 |
output_dir=CHECKPOINT_DIR,
|
| 409 |
max_steps=MAX_STEPS,
|
| 410 |
-
per_device_train_batch_size=
|
| 411 |
-
gradient_accumulation_steps=
|
| 412 |
learning_rate=2e-5,
|
| 413 |
lr_scheduler_type="cosine",
|
| 414 |
warmup_steps=10 if args.test else 30,
|
| 415 |
-
num_generations=
|
| 416 |
-
max_completion_length=
|
| 417 |
temperature=0.9,
|
| 418 |
logging_steps=5,
|
| 419 |
-
save_steps=50
|
| 420 |
report_to="wandb" if WANDB_API_KEY else "none",
|
| 421 |
)
|
| 422 |
|
|
|
|
| 2 |
AgentDebuggerEnv — GRPO Training Script
|
| 3 |
Model: Qwen2.5-Coder-7B-Instruct (4-bit quantized via bitsandbytes)
|
| 4 |
Algorithm: GRPO (Group Relative Policy Optimization) via HuggingFace TRL
|
| 5 |
+
GPU: auto-detected at runtime (A100/H100 → bfloat16+large batch, T4/V100 → float16+small batch)
|
| 6 |
|
| 7 |
Usage:
|
| 8 |
# Local reward sanity-check (no GPU, no model loading):
|
|
|
|
| 257 |
print("\nLOCAL TEST PASSED")
|
| 258 |
sys.exit(0)
|
| 259 |
|
# ── Auto-detect GPU and set optimal config ────────────────────────────────────
# Probe the visible CUDA device (if any) and derive a VRAM figure plus a
# bfloat16-capability flag; every training knob below keys off these values.
_gpu_vram_gb = 0
_is_ampere_plus = False  # A100/H100 support bfloat16 natively (compute cap >= 8.0)
if torch.cuda.is_available():
    _props = torch.cuda.get_device_properties(0)
    _gpu_vram_gb = _props.total_memory / 1e9
    _is_ampere_plus = _props.major >= 8
    print(f"GPU: {_props.name} | VRAM: {_gpu_vram_gb:.1f}GB | "
          f"Compute cap: {_props.major}.{_props.minor} | "
          f"bfloat16: {'yes' if _is_ampere_plus else 'no'}")

COMPUTE_DTYPE = torch.bfloat16 if _is_ampere_plus else torch.float16

# Scale batch/generation config to available VRAM
# (tiers: batch, grad_accum, num_gen, max_completion, lora_r)
if _gpu_vram_gb >= 40:    # A100 40GB / A100 80GB
    _batch, _grad_accum, _num_gen, _max_comp, _lora_r = 2, 4, 8, 256, 16  # effective batch = 8
elif _gpu_vram_gb >= 20:  # V100 32GB
    _batch, _grad_accum, _num_gen, _max_comp, _lora_r = 1, 8, 6, 220, 12
else:                     # T4 15GB / anything smaller
    _batch, _grad_accum, _num_gen, _max_comp, _lora_r = 1, 8, 4, 160, 8

print(f"Training config: batch={_batch} grad_accum={_grad_accum} "
      f"num_gen={_num_gen} max_comp={_max_comp} lora_r={_lora_r} "
      f"dtype={COMPUTE_DTYPE}")
# ── Load model ────────────────────────────────────────────────────────────────
print(f"Loading {MODEL_NAME}...")
# 4-bit NF4 quantization with double quantization; the compute dtype follows
# the GPU capability stored in COMPUTE_DTYPE.
_bnb_kwargs = dict(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=COMPUTE_DTYPE,
    bnb_4bit_use_double_quant=True,
)
bnb_config = BitsAndBytesConfig(**_bnb_kwargs)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
|
|
| 311 |
quantization_config=bnb_config,
|
| 312 |
device_map="auto",
|
| 313 |
trust_remote_code=True,
|
| 314 |
+
torch_dtype=COMPUTE_DTYPE,
|
| 315 |
)
|
| 316 |
model.config.use_cache = False
|
| 317 |
|
| 318 |
lora_config = LoraConfig(
|
| 319 |
+
r=_lora_r,
|
| 320 |
+
lora_alpha=_lora_r * 2,
|
| 321 |
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
| 322 |
"gate_proj", "up_proj", "down_proj"],
|
| 323 |
lora_dropout=0.0,
|
|
|
|
# GRPO trainer configuration — batch/generation sizes come from the VRAM tier
# variables set earlier in this file; the remaining hyperparameters are fixed.
_grpo_kwargs = {
    "output_dir": CHECKPOINT_DIR,
    "max_steps": MAX_STEPS,
    "per_device_train_batch_size": _batch,
    "gradient_accumulation_steps": _grad_accum,
    "learning_rate": 2e-5,
    "lr_scheduler_type": "cosine",
    "warmup_steps": 10 if args.test else 30,
    "num_generations": _num_gen,
    "max_completion_length": _max_comp,
    "temperature": 0.9,
    "logging_steps": 5,
    "save_steps": 50,
    "report_to": "wandb" if WANDB_API_KEY else "none",
}
config = GRPOConfig(**_grpo_kwargs)
|
| 459 |
|