shank committed
Commit ea6fe4e · 1 Parent(s): 9487853

Auto-detect GPU: bfloat16+batch2+gen8 on A100, float16+batch1+gen4 on T4 — same script works on both

Files changed (1):
training/train_grpo.py +47 -10
training/train_grpo.py CHANGED
@@ -2,7 +2,7 @@
 AgentDebuggerEnv — GRPO Training Script
 Model: Qwen2.5-Coder-7B-Instruct (4-bit quantized via bitsandbytes)
 Algorithm: GRPO (Group Relative Policy Optimization) via HuggingFace TRL
-GPU: Kaggle P100 (16GB) float16 only, no bfloat16
+GPU: auto-detected at runtime (A100/H100 → bfloat16 + large batch, T4/V100 → float16 + small batch)
 
 Usage:
     # Local reward sanity-check (no GPU, no model loading):
@@ -257,12 +257,49 @@ if args.test_local:
     print("\nLOCAL TEST PASSED")
     sys.exit(0)
 
+# ── Auto-detect GPU and set optimal config ────────────────────────────────────
+_gpu_vram_gb = 0
+_is_ampere_plus = False  # A100/H100 support bfloat16 natively (compute cap >= 8.0)
+if torch.cuda.is_available():
+    _props = torch.cuda.get_device_properties(0)
+    _gpu_vram_gb = _props.total_memory / 1e9
+    _is_ampere_plus = _props.major >= 8
+    print(f"GPU: {_props.name} | VRAM: {_gpu_vram_gb:.1f}GB | "
+          f"Compute cap: {_props.major}.{_props.minor} | "
+          f"bfloat16: {'yes' if _is_ampere_plus else 'no'}")
+
+COMPUTE_DTYPE = torch.bfloat16 if _is_ampere_plus else torch.float16
+
+# Scale batch/generation config to available VRAM
+if _gpu_vram_gb >= 40:    # A100 40GB / A100 80GB
+    _batch = 2
+    _grad_accum = 4       # effective batch = 8
+    _num_gen = 8
+    _max_comp = 256
+    _lora_r = 16
+elif _gpu_vram_gb >= 20:  # V100 32GB
+    _batch = 1
+    _grad_accum = 8
+    _num_gen = 6
+    _max_comp = 220
+    _lora_r = 12
+else:                     # T4 15GB / anything smaller
+    _batch = 1
+    _grad_accum = 8
+    _num_gen = 4
+    _max_comp = 160
+    _lora_r = 8
+
+print(f"Training config: batch={_batch} grad_accum={_grad_accum} "
+      f"num_gen={_num_gen} max_comp={_max_comp} lora_r={_lora_r} "
+      f"dtype={COMPUTE_DTYPE}")
+
 # ── Load model ────────────────────────────────────────────────────────────────
 print(f"Loading {MODEL_NAME}...")
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
-    bnb_4bit_compute_dtype=torch.float16,  # P100 has no bfloat16 hardware support
+    bnb_4bit_compute_dtype=COMPUTE_DTYPE,
     bnb_4bit_use_double_quant=True,
 )
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
@@ -274,13 +311,13 @@ model = AutoModelForCausalLM.from_pretrained(
     quantization_config=bnb_config,
     device_map="auto",
     trust_remote_code=True,
-    torch_dtype=torch.float16,  # P100 has no bfloat16 hardware support
+    torch_dtype=COMPUTE_DTYPE,
 )
 model.config.use_cache = False
 
 lora_config = LoraConfig(
-    r=8,  # P100: 16GB VRAM, halved from r=16
-    lora_alpha=16,
+    r=_lora_r,
+    lora_alpha=_lora_r * 2,
     target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                     "gate_proj", "up_proj", "down_proj"],
     lora_dropout=0.0,
@@ -407,16 +444,16 @@ def make_dataset(step: int) -> Dataset:
 config = GRPOConfig(
     output_dir=CHECKPOINT_DIR,
     max_steps=MAX_STEPS,
-    per_device_train_batch_size=1,  # P100 16GB: must be 1
-    gradient_accumulation_steps=8,  # effective batch = 8 (compensates for batch=1)
+    per_device_train_batch_size=_batch,
+    gradient_accumulation_steps=_grad_accum,
     learning_rate=2e-5,
     lr_scheduler_type="cosine",
     warmup_steps=10 if args.test else 30,
-    num_generations=4,          # P100: halved from 8 to fit in 16GB
-    max_completion_length=160,  # T4: shorter completions = faster generation (bottleneck)
+    num_generations=_num_gen,
+    max_completion_length=_max_comp,
     temperature=0.9,
     logging_steps=5,
-    save_steps=50 if args.test else 50,
+    save_steps=50,
     report_to="wandb" if WANDB_API_KEY else "none",
 )
 
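
For anyone reproducing this on a different machine, below is a minimal standalone sketch (hypothetical, not part of the commit) that mirrors the detection logic in the diff so you can preview which tier a given GPU would land in before launching training. It assumes only that PyTorch is installed; the tier thresholds and values are copied from the diff. It also prints torch.cuda.is_bf16_supported() as a cross-check, which in newer PyTorch builds can report True even on pre-Ampere cards (emulated support), presumably why the script keys off compute capability instead.

# check_gpu.py: hypothetical preview helper, not part of the commit.
import torch

if not torch.cuda.is_available():
    # With no CUDA device, _gpu_vram_gb stays 0 in the script, so it
    # would fall into the smallest (T4-class) tier.
    print("No CUDA device visible; the script would use the T4-class tier.")
else:
    props = torch.cuda.get_device_properties(0)
    vram_gb = props.total_memory / 1e9
    ampere_plus = props.major >= 8  # compute capability 8.0+ has native bfloat16
    dtype = "bfloat16" if ampere_plus else "float16"

    # Same VRAM tiers as the diff: >=40GB (A100), >=20GB (V100 32GB), else T4-class.
    if vram_gb >= 40:
        batch, grad_accum, num_gen, max_comp, lora_r = 2, 4, 8, 256, 16
    elif vram_gb >= 20:
        batch, grad_accum, num_gen, max_comp, lora_r = 1, 8, 6, 220, 12
    else:
        batch, grad_accum, num_gen, max_comp, lora_r = 1, 8, 4, 160, 8

    print(f"{props.name} ({vram_gb:.1f}GB, compute cap {props.major}.{props.minor})")
    print(f"dtype={dtype} batch={batch} grad_accum={grad_accum} "
          f"num_gen={num_gen} max_comp={max_comp} lora_r={lora_r}")
    # Can return True via emulation on pre-Ampere GPUs; shown for comparison only.
    print(f"torch.cuda.is_bf16_supported() -> {torch.cuda.is_bf16_supported()}")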