GRPO: expose --learning-rate, --temperature, --curriculum-schedule
Phase-1 submission upgrade for the hackathon: previous run used the
TRL default LR (5e-6) with EASY-only data and saw a flat reward curve
+ KL=0.14. Three new knobs let us re-target without code edits:
* --learning-rate (default 5e-6, set 2e-5 to break the flat-reward
plateau without touching anything else)
* --temperature (default 0.7, set 0.8 for more exploration in Phase 2)
* --curriculum-schedule "easy:200,medium:200,hard:200" — pre-rolls a
step-budget tier sequence so GRPO sees increasing difficulty over
training instead of EASY for all 600 steps. Falls back to the old
--start-tier behavior when the flag isn't passed.
scripts/jobs_grpo_train.sh forwards GRPO_LR / GRPO_TEMP /
GRPO_CURRICULUM env vars; everything else is unchanged. 110/110 unit
tests pass (helper is a no-op when the schedule env var isn't set).
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- scripts/jobs_grpo_train.sh +19 -9
- train/grpo_train.py +79 -2
|
@@ -26,6 +26,9 @@ GRPO_LORA_RANK="${GRPO_LORA_RANK:-16}"
|
|
| 26 |
GRPO_LOG_EVERY="${GRPO_LOG_EVERY:-1}"
|
| 27 |
GRPO_MAX_SEQ_LENGTH="${GRPO_MAX_SEQ_LENGTH:-1024}"
|
| 28 |
GRPO_PUSH_TO_HUB="${GRPO_PUSH_TO_HUB:-0}"
|
|
|
|
|
|
|
|
|
|
| 29 |
HUB_REPO_ID="${HUB_REPO_ID:-helloAK96/chaosops-grpo-lora}"
|
| 30 |
|
| 31 |
OUTPUT_DIR="/workspace/artifacts/chaosops-grpo"
|
|
@@ -63,16 +66,23 @@ mkdir -p "${OUTPUT_DIR}"
|
|
| 63 |
|
| 64 |
GRPO_BACKEND="${GRPO_BACKEND:-transformers}"
|
| 65 |
|
| 66 |
-
echo "==[chaosops]== launching GRPO (backend=$GRPO_BACKEND, $GRPO_EPISODES episodes, group=$GRPO_GROUP_SIZE, lora_rank=$GRPO_LORA_RANK)"
|
| 67 |
-
|
| 68 |
-
--model-name "${GRPO_MODEL}"
|
| 69 |
-
--backend "${GRPO_BACKEND}"
|
| 70 |
-
--total-episodes "${GRPO_EPISODES}"
|
| 71 |
-
--group-size "${GRPO_GROUP_SIZE}"
|
| 72 |
-
--log-every "${GRPO_LOG_EVERY}"
|
| 73 |
-
--max-seq-length "${GRPO_MAX_SEQ_LENGTH}"
|
| 74 |
-
--lora-rank "${GRPO_LORA_RANK}"
|
| 75 |
--output-dir "${OUTPUT_DIR}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
echo "==[chaosops]== training metrics:"
|
| 78 |
cat "${OUTPUT_DIR}/training_metrics.json" || echo "(no metrics file)"
|
|
|
|
GRPO_LOG_EVERY="${GRPO_LOG_EVERY:-1}"
GRPO_MAX_SEQ_LENGTH="${GRPO_MAX_SEQ_LENGTH:-1024}"
GRPO_PUSH_TO_HUB="${GRPO_PUSH_TO_HUB:-0}"
# Phase-2 knobs — forwarded to chaosops.train.grpo_train below.
GRPO_LR="${GRPO_LR:-5e-6}"              # learning rate; 2e-5 to break a flat-reward plateau
GRPO_TEMP="${GRPO_TEMP:-0.7}"           # rollout sampling temperature
GRPO_CURRICULUM="${GRPO_CURRICULUM:-}"  # e.g. "easy:200,medium:200,hard:200"; empty keeps legacy --start-tier path
HUB_REPO_ID="${HUB_REPO_ID:-helloAK96/chaosops-grpo-lora}"

OUTPUT_DIR="/workspace/artifacts/chaosops-grpo"

GRPO_BACKEND="${GRPO_BACKEND:-transformers}"

echo "==[chaosops]== launching GRPO (backend=$GRPO_BACKEND, $GRPO_EPISODES episodes, group=$GRPO_GROUP_SIZE, lora_rank=$GRPO_LORA_RANK, lr=$GRPO_LR, temp=$GRPO_TEMP, curriculum=${GRPO_CURRICULUM:-(none)})"
# Build the trainer CLI as an array so quoting survives word-splitting.
PY_ARGS=(
  --model-name "${GRPO_MODEL}"
  --backend "${GRPO_BACKEND}"
  --total-episodes "${GRPO_EPISODES}"
  --group-size "${GRPO_GROUP_SIZE}"
  --log-every "${GRPO_LOG_EVERY}"
  --max-seq-length "${GRPO_MAX_SEQ_LENGTH}"
  --lora-rank "${GRPO_LORA_RANK}"
  --output-dir "${OUTPUT_DIR}"
  --learning-rate "${GRPO_LR}"
  --temperature "${GRPO_TEMP}"
)
# Pass the curriculum flag only when set, so the trainer falls back to its
# old --start-tier behavior otherwise.
if [ -n "${GRPO_CURRICULUM}" ]; then
  PY_ARGS+=(--curriculum-schedule "${GRPO_CURRICULUM}")
fi
python -m chaosops.train.grpo_train "${PY_ARGS[@]}"

echo "==[chaosops]== training metrics:"
cat "${OUTPUT_DIR}/training_metrics.json" || echo "(no metrics file)"
|
|
@@ -567,6 +567,48 @@ def _collect_scenarios(curriculum: Curriculum, *, total: int) -> list[Scenario]:
|
|
| 567 |
return scenarios[:total]
|
| 568 |
|
| 569 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 570 |
# ---------------------------------------------------------------------------
|
| 571 |
# Training loop — modern TRL GRPO API
|
| 572 |
# ---------------------------------------------------------------------------
|
|
@@ -585,6 +627,8 @@ def run_grpo(
|
|
| 585 |
max_seq_length: int = 1024,
|
| 586 |
max_completion_length: int = 96,
|
| 587 |
learning_rate: float = 5e-6,
|
|
|
|
|
|
|
| 588 |
) -> dict[str, Any]:
|
| 589 |
"""Run GRPO training via TRL's GRPOTrainer.
|
| 590 |
|
|
@@ -597,7 +641,16 @@ def run_grpo(
|
|
| 597 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 598 |
|
| 599 |
scenario_count = max(total_episodes, 8)
|
| 600 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
dataset = build_training_dataset(scenarios)
|
| 602 |
|
| 603 |
# Every optim step: 1 unique prompt × group_size completions.
|
|
@@ -608,7 +661,7 @@ def run_grpo(
|
|
| 608 |
per_device_train_batch_size=per_device_train_batch_size,
|
| 609 |
gradient_accumulation_steps=1,
|
| 610 |
num_generations=group_size,
|
| 611 |
-
temperature=
|
| 612 |
max_prompt_length=max_seq_length,
|
| 613 |
max_completion_length=max_completion_length,
|
| 614 |
learning_rate=learning_rate,
|
|
@@ -710,6 +763,27 @@ def _parse_args() -> argparse.Namespace:
|
|
| 710 |
choices=["auto", "unsloth", "transformers"],
|
| 711 |
help="Model loader. 'auto' tries Unsloth, falls back to transformers.",
|
| 712 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 713 |
return parser.parse_args()
|
| 714 |
|
| 715 |
|
|
@@ -732,6 +806,9 @@ def main() -> None:
|
|
| 732 |
log_every=args.log_every,
|
| 733 |
output_dir=args.output_dir,
|
| 734 |
max_seq_length=args.max_seq_length,
|
|
|
|
|
|
|
|
|
|
| 735 |
)
|
| 736 |
print(json.dumps(summary, indent=2))
|
| 737 |
|
|
|
|
| 567 |
return scenarios[:total]
|
| 568 |
|
| 569 |
|
def _scenarios_from_schedule(schedule: str, *, total: int) -> list[Scenario]:
    """Build a curriculum dataset from a step-budget schedule.

    Format: ``"easy:200,medium:200,hard:200"`` — generates 200 EASY then 200
    MEDIUM then 200 HARD scenarios so TRL's GRPOTrainer (which iterates the
    dataset in order under ``shuffle=False`` semantics for max_steps) sees
    increasing difficulty over training.

    If the schedule's total < ``total``, the last tier is padded by cycling
    its failure types until ``total`` is reached.

    Raises:
        ValueError: if a chunk is malformed (missing ``tier:count``), names
            an unknown tier, or has a non-integer count.
    """
    parsed: list[tuple[DifficultyTier, int]] = []
    for chunk in schedule.split(","):
        chunk = chunk.strip()
        if not chunk:
            # Tolerate trailing / doubled commas, e.g. "easy:200,".
            continue
        tier_name, sep, count = chunk.partition(":")
        if not sep or not count.strip():
            # Fail with a clear message instead of int("") -> ValueError.
            raise ValueError(
                f"malformed curriculum chunk {chunk!r}; expected 'tier:count'"
            )
        tier = DifficultyTier(tier_name.strip().lower())
        parsed.append((tier, int(count.strip())))

    scenarios: list[Scenario] = []
    for tier, count in parsed:
        cycle_seed = 0
        tier_scenarios: list[Scenario] = []
        while len(tier_scenarios) < count:
            batch = scenarios_for_tier(
                tier, seed_offset=cycle_seed, episodes_per_type=1
            )
            if not batch:
                # Guard: an empty batch would otherwise loop forever.
                break
            tier_scenarios.extend(batch)
            cycle_seed += 97
        scenarios.extend(tier_scenarios[:count])

    # Pad with the last tier if the schedule under-shoots ``total``.
    if scenarios and len(scenarios) < total:
        last_tier = parsed[-1][0]
        cycle_seed = 9000  # offset past the schedule's seeds
        while len(scenarios) < total:
            batch = scenarios_for_tier(
                last_tier, seed_offset=cycle_seed, episodes_per_type=1
            )
            if not batch:
                break
            scenarios.extend(batch)
            cycle_seed += 97
    return scenarios[:total]
|
| 610 |
+
|
| 611 |
+
|
| 612 |
# ---------------------------------------------------------------------------
|
| 613 |
# Training loop — modern TRL GRPO API
|
| 614 |
# ---------------------------------------------------------------------------
|
|
|
|
| 627 |
max_seq_length: int = 1024,
|
| 628 |
max_completion_length: int = 96,
|
| 629 |
learning_rate: float = 5e-6,
|
| 630 |
+
temperature: float = 0.7,
|
| 631 |
+
curriculum_schedule: str | None = None,
|
| 632 |
) -> dict[str, Any]:
|
| 633 |
"""Run GRPO training via TRL's GRPOTrainer.
|
| 634 |
|
|
|
|
| 641 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 642 |
|
| 643 |
scenario_count = max(total_episodes, 8)
|
| 644 |
+
if curriculum_schedule:
|
| 645 |
+
scenarios = _scenarios_from_schedule(
|
| 646 |
+
curriculum_schedule, total=scenario_count
|
| 647 |
+
)
|
| 648 |
+
print(
|
| 649 |
+
f"[grpo_train] curriculum schedule active: {curriculum_schedule} "
|
| 650 |
+
f"({len(scenarios)} scenarios across tiers)"
|
| 651 |
+
)
|
| 652 |
+
else:
|
| 653 |
+
scenarios = _collect_scenarios(curriculum, total=scenario_count)
|
| 654 |
dataset = build_training_dataset(scenarios)
|
| 655 |
|
| 656 |
# Every optim step: 1 unique prompt × group_size completions.
|
|
|
|
| 661 |
per_device_train_batch_size=per_device_train_batch_size,
|
| 662 |
gradient_accumulation_steps=1,
|
| 663 |
num_generations=group_size,
|
| 664 |
+
temperature=temperature,
|
| 665 |
max_prompt_length=max_seq_length,
|
| 666 |
max_completion_length=max_completion_length,
|
| 667 |
learning_rate=learning_rate,
|
|
|
|
| 763 |
choices=["auto", "unsloth", "transformers"],
|
| 764 |
help="Model loader. 'auto' tries Unsloth, falls back to transformers.",
|
| 765 |
)
|
| 766 |
+
parser.add_argument(
|
| 767 |
+
"--learning-rate",
|
| 768 |
+
type=float,
|
| 769 |
+
default=5e-6,
|
| 770 |
+
help="GRPO learning rate. Default 5e-6; 2e-5 if reward stays flat.",
|
| 771 |
+
)
|
| 772 |
+
parser.add_argument(
|
| 773 |
+
"--temperature",
|
| 774 |
+
type=float,
|
| 775 |
+
default=0.7,
|
| 776 |
+
help="Sampling temperature for completions during GRPO rollout.",
|
| 777 |
+
)
|
| 778 |
+
parser.add_argument(
|
| 779 |
+
"--curriculum-schedule",
|
| 780 |
+
type=str,
|
| 781 |
+
default=None,
|
| 782 |
+
help=(
|
| 783 |
+
"Step-budget tier schedule, e.g. 'easy:200,medium:200,hard:200'. "
|
| 784 |
+
"Overrides --start-tier when set."
|
| 785 |
+
),
|
| 786 |
+
)
|
| 787 |
return parser.parse_args()
|
| 788 |
|
| 789 |
|
|
|
|
| 806 |
log_every=args.log_every,
|
| 807 |
output_dir=args.output_dir,
|
| 808 |
max_seq_length=args.max_seq_length,
|
| 809 |
+
learning_rate=args.learning_rate,
|
| 810 |
+
temperature=args.temperature,
|
| 811 |
+
curriculum_schedule=args.curriculum_schedule,
|
| 812 |
)
|
| 813 |
print(json.dumps(summary, indent=2))
|
| 814 |
|