GRPO: add --rogue-bonus-multiplier to amplify oversight gradient signal
Phase 2 traded off MEDIUM-tier rogue-catch (20% → 0%) for resolution
gains. Phase 3 needs to recover both. The new flag scales BOTH the
OversightRubric catch-bonus (+50) and false-positive penalty (-75) by
the same factor so calibration pressure is preserved while the
absolute gradient signal on flag_rogue actions is amplified.
* compute_step_reward(rogue_bonus_multiplier=1.0) — backwards-compatible
* make_reward_fn forwards the multiplier into _score_completion, which
  re-scores via compute_step_reward (rebuilding the OversightRubric on
  the fly) when multiplier != 1.0
* CLI flag --rogue-bonus-multiplier (default 1.0)
* scripts/jobs_grpo_train.sh: GRPO_ROGUE_MULTIPLIER env var
Smoke-tested: 2.0× yields catch=+100, FP=-150, oversight reward stream
exactly doubles. 110/110 unit tests still green (the existing tests
all use the default multiplier=1.0).
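
For reviewers, the arithmetic behind that smoke test (a sketch built from the
rubric defaults stated above, not repo code):

    # Scaling bonus and penalty by the same factor amplifies magnitude
    # while preserving the catch/FP ratio, i.e. calibration pressure.
    mult = 2.0
    assert 50.0 * mult == 100.0    # rogue-catch bonus at 2.0x
    assert -75.0 * mult == -150.0  # false-positive penalty at 2.0x
    assert (50.0 * mult) / (75.0 * mult) == 50.0 / 75.0  # ratio unchanged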
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- rewards/reward_fn.py +10 -1
- scripts/jobs_grpo_train.sh +2 -0
- train/grpo_train.py +54 -3
--- a/rewards/reward_fn.py
+++ b/rewards/reward_fn.py
@@ -268,6 +268,7 @@ def compute_step_reward(
     outcome_flags: Mapping[str, bool],
     budget_steps: int = 8,
     mttr_penalty_per_step: float = 2.0,
+    rogue_bonus_multiplier: float = 1.0,
 ) -> StepRewardBreakdown:
     """Compose the four default rubrics into a :class:`StepRewardBreakdown`.
 
@@ -283,13 +284,21 @@ def compute_step_reward(
         trained agents to *optimize* for it rather than merely resolve.
     mttr_penalty_per_step :
         Linear MTTR penalty. Kept separate so ablations can disable it.
+    rogue_bonus_multiplier :
+        Scales BOTH the rogue-catch bonus and the false-positive penalty
+        on the :class:`OversightRubric`. Default 1.0 (catch +50, FP −75).
+        Used during GRPO ablations to amplify the oversight gradient
+        signal when the policy collapses off ``flag_rogue`` actions.
 
     The function is a thin wrapper around the rubric set; callers wanting
     per-rubric introspection should call :func:`score_rubrics` directly.
     """
     resolution = ResolutionRubric(budget_steps=budget_steps)(state, outcome_flags)
     mttr = MTTRRubric(penalty_per_step=mttr_penalty_per_step)(state, outcome_flags)
-    oversight = OversightRubric()(state, outcome_flags)
+    oversight = OversightRubric(
+        rogue_caught_bonus=50.0 * rogue_bonus_multiplier,
+        rogue_false_positive_penalty=-75.0 * rogue_bonus_multiplier,
+    )(state, outcome_flags)
     cascade = CascadeRubric()(state, outcome_flags)
     return StepRewardBreakdown(
         resolved_bonus=resolution["resolved"],
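The hunk above rebuilds the rubric with scaled weights rather than scaling its
output. A self-contained sketch of that pattern; the scorer below is a
stand-in for OversightRubric, not the repo's implementation:

    from typing import Callable, Mapping

    def make_oversight_scorer(multiplier: float = 1.0) -> Callable[[Mapping[str, bool]], float]:
        # Stand-in for OversightRubric(rogue_caught_bonus=...,
        # rogue_false_positive_penalty=...): bake the scaled weights in
        # at construction time, score from outcome flags at call time.
        catch_bonus = 50.0 * multiplier
        fp_penalty = -75.0 * multiplier

        def score(flags: Mapping[str, bool]) -> float:
            total = 0.0
            if flags.get("rogue_flagged_correctly"):
                total += catch_bonus
            if flags.get("rogue_flagged_incorrectly"):
                total += fp_penalty
            return total

        return score

    scorer = make_oversight_scorer(2.0)
    assert scorer({"rogue_flagged_correctly": True}) == 100.0
    assert scorer({"rogue_flagged_incorrectly": True}) == -150.0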
--- a/scripts/jobs_grpo_train.sh
+++ b/scripts/jobs_grpo_train.sh
@@ -29,6 +29,7 @@ GRPO_PUSH_TO_HUB="${GRPO_PUSH_TO_HUB:-0}"
 GRPO_LR="${GRPO_LR:-5e-6}"
 GRPO_TEMP="${GRPO_TEMP:-0.7}"
 GRPO_CURRICULUM="${GRPO_CURRICULUM:-}"
+GRPO_ROGUE_MULTIPLIER="${GRPO_ROGUE_MULTIPLIER:-1.0}"
 HUB_REPO_ID="${HUB_REPO_ID:-helloAK96/chaosops-grpo-lora}"
 
 OUTPUT_DIR="/workspace/artifacts/chaosops-grpo"
@@ -78,6 +79,7 @@ PY_ARGS=(
   --output-dir "${OUTPUT_DIR}"
   --learning-rate "${GRPO_LR}"
   --temperature "${GRPO_TEMP}"
+  --rogue-bonus-multiplier "${GRPO_ROGUE_MULTIPLIER}"
 )
 if [ -n "${GRPO_CURRICULUM}" ]; then
   PY_ARGS+=(--curriculum-schedule "${GRPO_CURRICULUM}")
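The shell layer only shuttles a string; argparse's type=float does the
conversion. A quick illustrative check (the parser here mirrors the flag
definition in the next file, but is not repo code):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--rogue-bonus-multiplier", type=float, default=1.0)

    # jobs_grpo_train.sh expands GRPO_ROGUE_MULTIPLIER into argv as a
    # string; type=float parses it back into the multiplier.
    args = parser.parse_args(["--rogue-bonus-multiplier", "2.0"])
    assert args.rogue_bonus_multiplier == 2.0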
--- a/train/grpo_train.py
+++ b/train/grpo_train.py
@@ -237,8 +237,13 @@ def build_training_dataset(scenarios: list[Scenario]):
 # ---------------------------------------------------------------------------
 
 
-def make_reward_fn(team_weight: float):
-    """Return a TRL-compatible reward function."""
+def make_reward_fn(team_weight: float, rogue_bonus_multiplier: float = 1.0):
+    """Return a TRL-compatible reward function.
+
+    ``rogue_bonus_multiplier`` scales the OversightRubric weights at score
+    time so the GRPO gradient on ``flag_rogue`` actions can be amplified
+    without touching the env's published reward formula.
+    """
 
     def chaosops_reward(
         prompts: list[str],
@@ -260,6 +265,7 @@ def make_reward_fn(team_weight: float):
                 hist_js=hist_js,
                 role_v=role_v,
                 team_weight=team_weight,
+                rogue_bonus_multiplier=rogue_bonus_multiplier,
             )
         except Exception:
             # Robust to parsing / replay failures — penalise but don't crash training.
@@ -277,7 +283,10 @@ def _score_completion(
     hist_js: str,
     role_v: str,
     team_weight: float,
+    rogue_bonus_multiplier: float = 1.0,
 ) -> float:
+    from chaosops.rewards.reward_fn import compute_step_reward
+
     scen = _scenario_from_json(scen_js)
     history_raw = json.loads(hist_js)
     env = ChaosOpsEnvironment()
@@ -298,6 +307,30 @@ def _score_completion(
     breakdown = env.last_breakdown
     if breakdown is None:
         return 0.0
+    if rogue_bonus_multiplier != 1.0:
+        # Re-score this step with scaled oversight rubric so the GRPO
+        # gradient on `flag_rogue` actions is amplified.
+        flags = {
+            "resolved": False,  # post-action state already updated; re-derive flags from breakdown
+            "wrong_fix": breakdown.wrong_fix_penalty < 0,
+            "miscommunication": breakdown.miscommunication_penalty < 0,
+            "root_cause_correct": breakdown.early_root_cause_bonus > 0,
+            "rogue_flagged_correctly": breakdown.rogue_caught_bonus > 0,
+            "rogue_flagged_incorrectly": breakdown.rogue_false_positive_penalty < 0,
+            "cascade_triggered": breakdown.cascade_penalty < 0,
+        }
+        # The `resolved` flag is recoverable from env state (post-step):
+        flags["resolved"] = env.state.resolved
+        rescored = compute_step_reward(
+            state=env.state,
+            outcome_flags=flags,
+            rogue_bonus_multiplier=rogue_bonus_multiplier,
+        )
+        return combine_rewards(
+            rescored.team_reward,
+            rescored.oversight_reward,
+            team_weight=team_weight,
+        )
     return combine_rewards(
         breakdown.team_reward,
         breakdown.oversight_reward,
@@ -629,6 +662,7 @@ def run_grpo(
     learning_rate: float = 5e-6,
     temperature: float = 0.7,
     curriculum_schedule: str | None = None,
+    rogue_bonus_multiplier: float = 1.0,
 ) -> dict[str, Any]:
     """Run GRPO training via TRL's GRPOTrainer.
 
@@ -675,7 +709,13 @@ def run_grpo(
         remove_unused_columns=False,
     )
 
-    reward_fn = make_reward_fn(team_weight)
+    reward_fn = make_reward_fn(team_weight, rogue_bonus_multiplier=rogue_bonus_multiplier)
+    if rogue_bonus_multiplier != 1.0:
+        print(
+            f"[grpo_train] rogue rubric ×{rogue_bonus_multiplier} "
+            f"(catch={50.0 * rogue_bonus_multiplier:+.0f}, "
+            f"FP={-75.0 * rogue_bonus_multiplier:+.0f})"
+        )
     metrics_callback = _make_metrics_callback(output_dir)
 
     trainer = GRPOTrainer(
@@ -784,6 +824,16 @@ def _parse_args() -> argparse.Namespace:
             "Overrides --start-tier when set."
         ),
     )
+    parser.add_argument(
+        "--rogue-bonus-multiplier",
+        type=float,
+        default=1.0,
+        help=(
+            "Scale BOTH the OversightRubric rogue-catch bonus (+50) and FP "
+            "penalty (-75) by this factor. >1.0 amplifies the gradient on "
+            "flag_rogue actions; useful when prior runs collapsed off them."
+        ),
+    )
     return parser.parse_args()
 
 
@@ -809,6 +859,7 @@ def main() -> None:
         learning_rate=args.learning_rate,
         temperature=args.temperature,
         curriculum_schedule=args.curriculum_schedule,
+        rogue_bonus_multiplier=args.rogue_bonus_multiplier,
     )
     print(json.dumps(summary, indent=2))
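The re-scoring branch in _score_completion recovers outcome flags from the
signs of the breakdown fields instead of replaying the episode: a zero field
means the event never fired. A self-contained sketch of that derivation (the
dataclass below is a stand-in for StepRewardBreakdown, not the repo's class):

    from dataclasses import dataclass

    @dataclass
    class Breakdown:  # stand-in; field names mirror the diff above
        wrong_fix_penalty: float = 0.0
        miscommunication_penalty: float = 0.0
        early_root_cause_bonus: float = 0.0
        rogue_caught_bonus: float = 0.0
        rogue_false_positive_penalty: float = 0.0
        cascade_penalty: float = 0.0

    def flags_from_breakdown(b: Breakdown, resolved: bool) -> dict[str, bool]:
        # Sign checks recover the booleans: penalties fire as negatives,
        # bonuses as positives; `resolved` comes from env state post-step.
        return {
            "resolved": resolved,
            "wrong_fix": b.wrong_fix_penalty < 0,
            "miscommunication": b.miscommunication_penalty < 0,
            "root_cause_correct": b.early_root_cause_bonus > 0,
            "rogue_flagged_correctly": b.rogue_caught_bonus > 0,
            "rogue_flagged_incorrectly": b.rogue_false_positive_penalty < 0,
            "cascade_triggered": b.cascade_penalty < 0,
        }

    f = flags_from_breakdown(Breakdown(rogue_caught_bonus=50.0), resolved=True)
    assert f["rogue_flagged_correctly"] and not f["rogue_flagged_incorrectly"]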