shank commited on
Commit
ba8df98
Β·
1 Parent(s): b37b2eb

Reduce training to 500 steps with tightened curriculum for A10G budget

Browse files
Files changed (1) hide show
  1. training/train_grpo.py +3 -3
training/train_grpo.py CHANGED
@@ -33,7 +33,7 @@ parser.add_argument("--test", action="store_true", help="Run 10 steps for testin
33
  parser.add_argument("--test-local", action="store_true", dest="test_local",
34
  help="Sanity-check reward function locally without any model or GPU")
35
  parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint")
36
- parser.add_argument("--max_steps", type=int, default=1000)
37
  args = parser.parse_args()
38
 
39
  # ── Install dependencies (for Colab/HF Spaces) ───────────────────────────────
@@ -104,7 +104,7 @@ def load_bugs(tier: int) -> list[dict]:
104
 
105
  def get_bugs_for_step(step: int) -> list[dict]:
106
  tier1 = load_bugs(1)
107
- if step < 300:
108
  return tier1
109
  elif step < 600:
110
  return tier1 + load_bugs(2)
@@ -393,7 +393,7 @@ trainer = GRPOTrainer(
393
  class CurriculumCallback(TrainerCallback):
394
  def on_step_end(self, args, state, control, **kwargs):
395
  step = state.global_step
396
- if step in [300, 600]:
397
  trainer.train_dataset = make_dataset(step)
398
  print(f"\nCurriculum advanced at step {step}!")
399
  if WANDB_API_KEY:
 
33
  parser.add_argument("--test-local", action="store_true", dest="test_local",
34
  help="Sanity-check reward function locally without any model or GPU")
35
  parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint")
36
+ parser.add_argument("--max_steps", type=int, default=500)
37
  args = parser.parse_args()
38
 
39
  # ── Install dependencies (for Colab/HF Spaces) ───────────────────────────────
 
104
 
105
  def get_bugs_for_step(step: int) -> list[dict]:
106
  tier1 = load_bugs(1)
107
+ if step < 150:
108
  return tier1
109
  elif step < 600:
110
  return tier1 + load_bugs(2)
 
393
  class CurriculumCallback(TrainerCallback):
394
  def on_step_end(self, args, state, control, **kwargs):
395
  step = state.global_step
396
+ if step in [150, 350]:
397
  trainer.train_dataset = make_dataset(step)
398
  print(f"\nCurriculum advanced at step {step}!")
399
  if WANDB_API_KEY: