helloAK96 (Claude Opus 4.7) committed
Commit 6f963e5 · 1 Parent(s): f89a0e8

GRPO: add --rogue-bonus-multiplier to amplify oversight gradient signal


Phase 2 traded MEDIUM-tier rogue-catch (20% → 0%) for resolution gains;
Phase 3 needs to recover the catch rate without giving back those gains.
The new flag scales BOTH the OversightRubric catch bonus (+50) and the
false-positive penalty (-75) by the same factor, so calibration pressure
(the bonus-to-penalty ratio) is preserved while the absolute gradient
signal on flag_rogue actions is amplified.

* compute_step_reward(rogue_bonus_multiplier=1.0): backwards-compatible default
* make_reward_fn forwards the multiplier into _score_completion, which
  re-scores the step via compute_step_reward (rebuilding the
  OversightRubric on the fly) when multiplier != 1.0
* new CLI flag --rogue-bonus-multiplier (default 1.0)
* scripts/jobs_grpo_train.sh: new GRPO_ROGUE_MULTIPLIER env var

Smoke-tested: at 2.0× the catch bonus is +100, the FP penalty is -150,
and the oversight reward stream exactly doubles. 110/110 unit tests
still green (the existing tests all use the default multiplier=1.0).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
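
As a sanity sketch of the scaling arithmetic described above (the 50/-75
base weights come from the commit message; scaled_oversight_weights is an
illustrative helper, not code from this repo):

BASE_CATCH, BASE_FP = 50.0, -75.0

def scaled_oversight_weights(multiplier: float = 1.0) -> tuple[float, float]:
    # Scale the catch bonus and FP penalty together so their ratio,
    # and hence the calibration pressure, is unchanged.
    return BASE_CATCH * multiplier, BASE_FP * multiplier

catch, fp = scaled_oversight_weights(2.0)
assert (catch, fp) == (100.0, -150.0)      # matches the smoke test
assert fp / catch == BASE_FP / BASE_CATCH  # calibration ratio preserved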

rewards/reward_fn.py CHANGED
@@ -268,6 +268,7 @@ def compute_step_reward(
     outcome_flags: Mapping[str, bool],
     budget_steps: int = 8,
     mttr_penalty_per_step: float = 2.0,
+    rogue_bonus_multiplier: float = 1.0,
 ) -> StepRewardBreakdown:
     """Compose the four default rubrics into a :class:`StepRewardBreakdown`.
 
@@ -283,13 +284,21 @@ def compute_step_reward(
         trained agents to *optimize* for it rather than merely resolve.
     mttr_penalty_per_step :
         Linear MTTR penalty. Kept separate so ablations can disable it.
+    rogue_bonus_multiplier :
+        Scales BOTH the rogue-catch bonus and the false-positive penalty
+        on the :class:`OversightRubric`. Default 1.0 (catch +50, FP -75).
+        Used during GRPO ablations to amplify the oversight gradient
+        signal when the policy collapses off ``flag_rogue`` actions.
 
     The function is a thin wrapper around the rubric set; callers wanting
     per-rubric introspection should call :func:`score_rubrics` directly.
     """
     resolution = ResolutionRubric(budget_steps=budget_steps)(state, outcome_flags)
     mttr = MTTRRubric(penalty_per_step=mttr_penalty_per_step)(state, outcome_flags)
-    oversight = OversightRubric()(state, outcome_flags)
+    oversight = OversightRubric(
+        rogue_caught_bonus=50.0 * rogue_bonus_multiplier,
+        rogue_false_positive_penalty=-75.0 * rogue_bonus_multiplier,
+    )(state, outcome_flags)
     cascade = CascadeRubric()(state, outcome_flags)
     return StepRewardBreakdown(
         resolved_bonus=resolution["resolved"],
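
For callers outside the trainer, the new keyword slots in alongside the
existing defaults. A hypothetical call site, assuming a post-step
env.state and a minimal flag mapping (the flag keys mirror those in the
train/grpo_train.py diff below; the values here are stand-ins):

from chaosops.rewards.reward_fn import compute_step_reward

# Stand-in inputs: `env.state` and the flag values are illustrative only.
breakdown = compute_step_reward(
    state=env.state,
    outcome_flags={"resolved": False, "rogue_flagged_correctly": True},
    rogue_bonus_multiplier=2.0,  # catch bonus +100, FP penalty -150
)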
scripts/jobs_grpo_train.sh CHANGED
@@ -29,6 +29,7 @@ GRPO_PUSH_TO_HUB="${GRPO_PUSH_TO_HUB:-0}"
 GRPO_LR="${GRPO_LR:-5e-6}"
 GRPO_TEMP="${GRPO_TEMP:-0.7}"
 GRPO_CURRICULUM="${GRPO_CURRICULUM:-}"
+GRPO_ROGUE_MULTIPLIER="${GRPO_ROGUE_MULTIPLIER:-1.0}"
 HUB_REPO_ID="${HUB_REPO_ID:-helloAK96/chaosops-grpo-lora}"
 
 OUTPUT_DIR="/workspace/artifacts/chaosops-grpo"
@@ -78,6 +79,7 @@ PY_ARGS=(
   --output-dir "${OUTPUT_DIR}"
   --learning-rate "${GRPO_LR}"
   --temperature "${GRPO_TEMP}"
+  --rogue-bonus-multiplier "${GRPO_ROGUE_MULTIPLIER}"
 )
 if [ -n "${GRPO_CURRICULUM}" ]; then
   PY_ARGS+=(--curriculum-schedule "${GRPO_CURRICULUM}")
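
Operationally the knob is opt-in: setting the env var on the invocation,
e.g. GRPO_ROGUE_MULTIPLIER=2.0 scripts/jobs_grpo_train.sh, forwards the
value to --rogue-bonus-multiplier, while leaving it unset keeps the 1.0
default so existing job configs are untouched.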
train/grpo_train.py CHANGED
@@ -237,8 +237,13 @@ def build_training_dataset(scenarios: list[Scenario]):
 # ---------------------------------------------------------------------------
 
 
-def make_reward_fn(team_weight: float):
-    """Return a TRL-compatible reward function closed over ``team_weight``."""
+def make_reward_fn(team_weight: float, rogue_bonus_multiplier: float = 1.0):
+    """Return a TRL-compatible reward function.
+
+    ``rogue_bonus_multiplier`` scales the OversightRubric weights at score
+    time so the GRPO gradient on ``flag_rogue`` actions can be amplified
+    without touching the env's published reward formula.
+    """
 
     def chaosops_reward(
         prompts: list[str],
@@ -260,6 +265,7 @@ def make_reward_fn(team_weight: float):
                 hist_js=hist_js,
                 role_v=role_v,
                 team_weight=team_weight,
+                rogue_bonus_multiplier=rogue_bonus_multiplier,
             )
         except Exception:
             # Robust to parsing / replay failures — penalise but don't crash training.
@@ -277,7 +283,10 @@ def _score_completion(
     hist_js: str,
     role_v: str,
     team_weight: float,
+    rogue_bonus_multiplier: float = 1.0,
 ) -> float:
+    from chaosops.rewards.reward_fn import compute_step_reward
+
     scen = _scenario_from_json(scen_js)
     history_raw = json.loads(hist_js)
     env = ChaosOpsEnvironment()
@@ -298,6 +307,30 @@ def _score_completion(
     breakdown = env.last_breakdown
     if breakdown is None:
         return 0.0
+    if rogue_bonus_multiplier != 1.0:
+        # Re-score this step with scaled oversight rubric so the GRPO
+        # gradient on `flag_rogue` actions is amplified.
+        flags = {
+            "resolved": False,  # post-action state already updated; re-derive flags from breakdown
+            "wrong_fix": breakdown.wrong_fix_penalty < 0,
+            "miscommunication": breakdown.miscommunication_penalty < 0,
+            "root_cause_correct": breakdown.early_root_cause_bonus > 0,
+            "rogue_flagged_correctly": breakdown.rogue_caught_bonus > 0,
+            "rogue_flagged_incorrectly": breakdown.rogue_false_positive_penalty < 0,
+            "cascade_triggered": breakdown.cascade_penalty < 0,
+        }
+        # The `resolved` flag is recoverable from env state (post-step):
+        flags["resolved"] = env.state.resolved
+        rescored = compute_step_reward(
+            state=env.state,
+            outcome_flags=flags,
+            rogue_bonus_multiplier=rogue_bonus_multiplier,
+        )
+        return combine_rewards(
+            rescored.team_reward,
+            rescored.oversight_reward,
+            team_weight=team_weight,
+        )
     return combine_rewards(
         breakdown.team_reward,
         breakdown.oversight_reward,
@@ -629,6 +662,7 @@ def run_grpo(
     learning_rate: float = 5e-6,
     temperature: float = 0.7,
     curriculum_schedule: str | None = None,
+    rogue_bonus_multiplier: float = 1.0,
 ) -> dict[str, Any]:
     """Run GRPO training via TRL's GRPOTrainer.
 
@@ -675,7 +709,13 @@ def run_grpo(
         remove_unused_columns=False,
     )
 
-    reward_fn = make_reward_fn(team_weight)
+    reward_fn = make_reward_fn(team_weight, rogue_bonus_multiplier=rogue_bonus_multiplier)
+    if rogue_bonus_multiplier != 1.0:
+        print(
+            f"[grpo_train] rogue rubric ×{rogue_bonus_multiplier} "
+            f"(catch={50.0 * rogue_bonus_multiplier:+.0f}, "
+            f"FP={-75.0 * rogue_bonus_multiplier:+.0f})"
+        )
     metrics_callback = _make_metrics_callback(output_dir)
 
     trainer = GRPOTrainer(
@@ -784,6 +824,16 @@ def _parse_args() -> argparse.Namespace:
             "Overrides --start-tier when set."
         ),
     )
+    parser.add_argument(
+        "--rogue-bonus-multiplier",
+        type=float,
+        default=1.0,
+        help=(
+            "Scale BOTH the OversightRubric rogue-catch bonus (+50) and FP "
+            "penalty (-75) by this factor. >1.0 amplifies the gradient on "
+            "flag_rogue actions; useful when prior runs collapsed off them."
+        ),
+    )
     return parser.parse_args()
 
 
@@ -809,6 +859,7 @@ def main() -> None:
         learning_rate=args.learning_rate,
         temperature=args.temperature,
         curriculum_schedule=args.curriculum_schedule,
+        rogue_bonus_multiplier=args.rogue_bonus_multiplier,
     )
     print(json.dumps(summary, indent=2))
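
Why the oversight stream "exactly doubles" in the smoke test: the
oversight term is linear in the two scaled weights, so multiplying both
by m scales the whole oversight reward by m whichever flag fired. A
minimal self-contained check of that linearity (stand-in arithmetic
only, not the real rubric):

BASE_CATCH, BASE_FP, m = 50.0, -75.0, 2.0
for caught, false_pos in [(True, False), (False, True), (False, False)]:
    base = BASE_CATCH * caught + BASE_FP * false_pos
    scaled = (BASE_CATCH * m) * caught + (BASE_FP * m) * false_pos
    assert scaled == m * base  # oversight reward scales exactly with m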