helloAK96 Claude Opus 4.7 committed
Commit 6e35cec · 1 Parent(s): 8878953

GRPO: expose --learning-rate, --temperature, --curriculum-schedule


Phase-1 submission upgrade for the hackathon: the previous run used the
TRL default LR (5e-6) with EASY-only data and saw a flat reward curve
and KL=0.14. Three new knobs let us re-target runs without code edits:

* --learning-rate (default 5e-6; set 2e-5 to break the flat-reward
plateau without touching anything else)
* --temperature (default 0.7; set 0.8 for more exploration in Phase 2)
* --curriculum-schedule "easy:200,medium:200,hard:200" — pre-rolls a
step-budget tier sequence so GRPO sees increasing difficulty over
training instead of EASY for all 600 steps. Falls back to the old
--start-tier behavior when the flag isn't passed.

scripts/jobs_grpo_train.sh forwards the GRPO_LR / GRPO_TEMP /
GRPO_CURRICULUM env vars; everything else is unchanged. 110/110 unit
tests pass (the schedule helper is a no-op when GRPO_CURRICULUM isn't
set).
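
Example Phase-2 launch with all three knobs via the job script (values
are illustrative, not a tuned recipe):

    GRPO_LR=2e-5 GRPO_TEMP=0.8 \
    GRPO_CURRICULUM="easy:200,medium:200,hard:200" \
    bash scripts/jobs_grpo_train.sh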

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Files changed (2):
  1. scripts/jobs_grpo_train.sh +19 -9
  2. train/grpo_train.py +79 -2
scripts/jobs_grpo_train.sh CHANGED
@@ -26,6 +26,9 @@ GRPO_LORA_RANK="${GRPO_LORA_RANK:-16}"
 GRPO_LOG_EVERY="${GRPO_LOG_EVERY:-1}"
 GRPO_MAX_SEQ_LENGTH="${GRPO_MAX_SEQ_LENGTH:-1024}"
 GRPO_PUSH_TO_HUB="${GRPO_PUSH_TO_HUB:-0}"
+GRPO_LR="${GRPO_LR:-5e-6}"
+GRPO_TEMP="${GRPO_TEMP:-0.7}"
+GRPO_CURRICULUM="${GRPO_CURRICULUM:-}"
 HUB_REPO_ID="${HUB_REPO_ID:-helloAK96/chaosops-grpo-lora}"
 
 OUTPUT_DIR="/workspace/artifacts/chaosops-grpo"
@@ -63,16 +66,23 @@ mkdir -p "${OUTPUT_DIR}"
 
 GRPO_BACKEND="${GRPO_BACKEND:-transformers}"
 
-echo "==[chaosops]== launching GRPO (backend=$GRPO_BACKEND, $GRPO_EPISODES episodes, group=$GRPO_GROUP_SIZE, lora_rank=$GRPO_LORA_RANK)"
-python -m chaosops.train.grpo_train \
-  --model-name "${GRPO_MODEL}" \
-  --backend "${GRPO_BACKEND}" \
-  --total-episodes "${GRPO_EPISODES}" \
-  --group-size "${GRPO_GROUP_SIZE}" \
-  --log-every "${GRPO_LOG_EVERY}" \
-  --max-seq-length "${GRPO_MAX_SEQ_LENGTH}" \
-  --lora-rank "${GRPO_LORA_RANK}" \
+echo "==[chaosops]== launching GRPO (backend=$GRPO_BACKEND, $GRPO_EPISODES episodes, group=$GRPO_GROUP_SIZE, lora_rank=$GRPO_LORA_RANK, lr=$GRPO_LR, temp=$GRPO_TEMP, curriculum=${GRPO_CURRICULUM:-(none)})"
+PY_ARGS=(
+  --model-name "${GRPO_MODEL}"
+  --backend "${GRPO_BACKEND}"
+  --total-episodes "${GRPO_EPISODES}"
+  --group-size "${GRPO_GROUP_SIZE}"
+  --log-every "${GRPO_LOG_EVERY}"
+  --max-seq-length "${GRPO_MAX_SEQ_LENGTH}"
+  --lora-rank "${GRPO_LORA_RANK}"
   --output-dir "${OUTPUT_DIR}"
+  --learning-rate "${GRPO_LR}"
+  --temperature "${GRPO_TEMP}"
+)
+if [ -n "${GRPO_CURRICULUM}" ]; then
+  PY_ARGS+=(--curriculum-schedule "${GRPO_CURRICULUM}")
+fi
+python -m chaosops.train.grpo_train "${PY_ARGS[@]}"
 
 echo "==[chaosops]== training metrics:"
 cat "${OUTPUT_DIR}/training_metrics.json" || echo "(no metrics file)"
train/grpo_train.py CHANGED
@@ -567,6 +567,48 @@ def _collect_scenarios(curriculum: Curriculum, *, total: int) -> list[Scenario]:
     return scenarios[:total]
 
 
+def _scenarios_from_schedule(schedule: str, *, total: int) -> list[Scenario]:
+    """Build a curriculum dataset from a step-budget schedule.
+
+    Format: ``"easy:200,medium:200,hard:200"`` — generates 200 EASY then 200
+    MEDIUM then 200 HARD scenarios so TRL's GRPOTrainer (which iterates the
+    dataset in order under ``shuffle=False`` semantics for max_steps) sees
+    increasing difficulty over training.
+
+    If the schedule's total < ``total``, the last tier is padded by cycling
+    its failure types until ``total`` is reached.
+    """
+    parsed: list[tuple[DifficultyTier, int]] = []
+    for chunk in schedule.split(","):
+        tier_name, _, count = chunk.partition(":")
+        tier = DifficultyTier(tier_name.strip().lower())
+        parsed.append((tier, int(count.strip())))
+
+    scenarios: list[Scenario] = []
+    for tier, count in parsed:
+        cycle_seed = 0
+        tier_scenarios: list[Scenario] = []
+        while len(tier_scenarios) < count:
+            batch = scenarios_for_tier(
+                tier, seed_offset=cycle_seed, episodes_per_type=1
+            )
+            tier_scenarios.extend(batch)
+            cycle_seed += 97
+        scenarios.extend(tier_scenarios[:count])
+
+    # Pad with the last tier if the schedule under-shoots ``total``.
+    if scenarios and len(scenarios) < total:
+        last_tier = parsed[-1][0]
+        cycle_seed = 9000  # offset past the schedule's seeds
+        while len(scenarios) < total:
+            batch = scenarios_for_tier(
+                last_tier, seed_offset=cycle_seed, episodes_per_type=1
+            )
+            scenarios.extend(batch)
+            cycle_seed += 97
+    return scenarios[:total]
+
+
 # ---------------------------------------------------------------------------
 # Training loop — modern TRL GRPO API
 # ---------------------------------------------------------------------------
@@ -585,6 +627,8 @@ def run_grpo(
     max_seq_length: int = 1024,
     max_completion_length: int = 96,
     learning_rate: float = 5e-6,
+    temperature: float = 0.7,
+    curriculum_schedule: str | None = None,
 ) -> dict[str, Any]:
     """Run GRPO training via TRL's GRPOTrainer.
 
@@ -597,7 +641,16 @@ def run_grpo(
     output_dir.mkdir(parents=True, exist_ok=True)
 
     scenario_count = max(total_episodes, 8)
-    scenarios = _collect_scenarios(curriculum, total=scenario_count)
+    if curriculum_schedule:
+        scenarios = _scenarios_from_schedule(
+            curriculum_schedule, total=scenario_count
+        )
+        print(
+            f"[grpo_train] curriculum schedule active: {curriculum_schedule} "
+            f"({len(scenarios)} scenarios across tiers)"
+        )
+    else:
+        scenarios = _collect_scenarios(curriculum, total=scenario_count)
     dataset = build_training_dataset(scenarios)
 
     # Every optim step: 1 unique prompt × group_size completions.
@@ -608,7 +661,7 @@ def run_grpo(
         per_device_train_batch_size=per_device_train_batch_size,
         gradient_accumulation_steps=1,
         num_generations=group_size,
-        temperature=0.7,
+        temperature=temperature,
         max_prompt_length=max_seq_length,
         max_completion_length=max_completion_length,
         learning_rate=learning_rate,
@@ -710,6 +763,27 @@ def _parse_args() -> argparse.Namespace:
         choices=["auto", "unsloth", "transformers"],
         help="Model loader. 'auto' tries Unsloth, falls back to transformers.",
     )
+    parser.add_argument(
+        "--learning-rate",
+        type=float,
+        default=5e-6,
+        help="GRPO learning rate. Default 5e-6; 2e-5 if reward stays flat.",
+    )
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=0.7,
+        help="Sampling temperature for completions during GRPO rollout.",
+    )
+    parser.add_argument(
+        "--curriculum-schedule",
+        type=str,
+        default=None,
+        help=(
+            "Step-budget tier schedule, e.g. 'easy:200,medium:200,hard:200'. "
+            "Overrides --start-tier when set."
+        ),
+    )
     return parser.parse_args()
 
 
@@ -732,6 +806,9 @@ def main() -> None:
         log_every=args.log_every,
         output_dir=args.output_dir,
         max_seq_length=args.max_seq_length,
+        learning_rate=args.learning_rate,
+        temperature=args.temperature,
+        curriculum_schedule=args.curriculum_schedule,
     )
     print(json.dumps(summary, indent=2))
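
Usage note: the new flags can also be passed straight to the trainer,
bypassing the job script. A minimal sketch (flag values illustrative;
all other options fall back to their argparse defaults):

    python -m chaosops.train.grpo_train \
      --backend transformers \
      --learning-rate 2e-5 \
      --temperature 0.8 \
      --curriculum-schedule "easy:200,medium:200,hard:200" \
      --output-dir /workspace/artifacts/chaosops-grpo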