Humanlearning committed on
Commit
2eada22
·
1 Parent(s): be8eade

feat: add episode trace fingerprinting for improved trace logging and update reward penalties in GRPO configuration

Browse files
scripts/modal_train_grpo.py CHANGED
@@ -518,6 +518,7 @@ def train_cybersecurity_owasp_grpo(
518
  from training.trackio_utils import (
519
  aggregate_episode_metrics,
520
  episode_record_from_state,
 
521
  log_gpu_metrics,
522
  log_trace_table,
523
  log_trackio_metrics,
@@ -886,6 +887,7 @@ def train_cybersecurity_owasp_grpo(
886
  pass
887
 
888
  trace_step = {"value": 0}
 
889
 
890
  def _completion_to_text(completion) -> str:
891
  if completion is None:
@@ -950,16 +952,28 @@ def train_cybersecurity_owasp_grpo(
950
  except Exception as exc:
951
  print(f"Trackio metric logging skipped: {exc!r}")
952
 
 
 
 
 
 
 
 
 
 
 
 
 
953
  try:
954
  log_trace_table(
955
- episode_records[: min(4, len(episode_records))],
956
  table_name="sample_traces",
957
  step=trace_step["value"],
958
  )
959
  except Exception as exc:
960
  print(f"Trackio sample trace table logging skipped: {exc!r}")
961
 
962
- for index, env in enumerate(environments):
963
  messages = list(getattr(env, "trace_messages", []))
964
  if index < len(completions):
965
  completion_text = _completion_to_text(completions[index])
@@ -974,8 +988,9 @@ def train_cybersecurity_owasp_grpo(
974
  metadata.update(
975
  {
976
  "sample_index": index,
977
- "reward": rewards[index],
978
  "trace_step": trace_step["value"],
 
979
  "run_name": run_name,
980
  }
981
  )
 
518
  from training.trackio_utils import (
519
  aggregate_episode_metrics,
520
  episode_record_from_state,
521
+ episode_trace_fingerprint,
522
  log_gpu_metrics,
523
  log_trace_table,
524
  log_trackio_metrics,
 
887
  pass
888
 
889
  trace_step = {"value": 0}
890
+ logged_trace_fingerprints: set[str] = set()
891
 
892
  def _completion_to_text(completion) -> str:
893
  if completion is None:
 
952
  except Exception as exc:
953
  print(f"Trackio metric logging skipped: {exc!r}")
954
 
955
+ sampled_traces = []
956
+ seen_this_batch: set[str] = set()
957
+ for index, (env, record, reward) in enumerate(zip(environments, episode_records, rewards)):
958
+ fingerprint = episode_trace_fingerprint(record)
959
+ if fingerprint in seen_this_batch or fingerprint in logged_trace_fingerprints:
960
+ continue
961
+ seen_this_batch.add(fingerprint)
962
+ logged_trace_fingerprints.add(fingerprint)
963
+ sampled_traces.append((index, env, record, reward, fingerprint))
964
+ if len(sampled_traces) >= 4:
965
+ break
966
+
967
  try:
968
  log_trace_table(
969
+ [record for _, _, record, _, _ in sampled_traces],
970
  table_name="sample_traces",
971
  step=trace_step["value"],
972
  )
973
  except Exception as exc:
974
  print(f"Trackio sample trace table logging skipped: {exc!r}")
975
 
976
+ for index, env, _record, reward, fingerprint in sampled_traces:
977
  messages = list(getattr(env, "trace_messages", []))
978
  if index < len(completions):
979
  completion_text = _completion_to_text(completions[index])
 
988
  metadata.update(
989
  {
990
  "sample_index": index,
991
+ "reward": reward,
992
  "trace_step": trace_step["value"],
993
+ "trace_fingerprint": fingerprint,
994
  "run_name": run_name,
995
  }
996
  )
tests/test_rewards.py CHANGED
@@ -114,8 +114,9 @@ def test_repeated_futile_actions_are_penalized(monkeypatch):
114
 
115
  assert first.reward_breakdown["progressive"] > 0.0
116
  assert second.reward_breakdown["progressive"] == 0.0
117
- assert second.reward_breakdown["behavior_penalty"] <= -0.10
118
  assert second.reward_breakdown["total"] < 0.0
 
119
 
120
 
121
  def test_dense_episode_reward_cap_blocks_repeated_positive_farming(monkeypatch):
 
114
 
115
  assert first.reward_breakdown["progressive"] > 0.0
116
  assert second.reward_breakdown["progressive"] == 0.0
117
+ assert second.reward_breakdown["behavior_penalty"] <= -0.50
118
  assert second.reward_breakdown["total"] < 0.0
119
+ assert env.state.accumulated_reward < 0.0
120
 
121
 
122
  def test_dense_episode_reward_cap_blocks_repeated_positive_farming(monkeypatch):
tests/test_trackio_utils.py CHANGED
@@ -6,6 +6,7 @@ from training.trackio_utils import (
6
  DERIVED_TRACKIO_METRICS,
7
  aggregate_episode_metrics,
8
  episode_record_from_state,
 
9
  episode_to_trace_row,
10
  episode_to_tracking_fields,
11
  )
@@ -91,3 +92,32 @@ def test_trace_rows_redact_hidden_values_from_action_arguments():
91
  assert not value or value not in row_text
92
  finally:
93
  env.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  DERIVED_TRACKIO_METRICS,
7
  aggregate_episode_metrics,
8
  episode_record_from_state,
9
+ episode_trace_fingerprint,
10
  episode_to_trace_row,
11
  episode_to_tracking_fields,
12
  )
 
92
  assert not value or value not in row_text
93
  finally:
94
  env.close()
95
+
96
+
97
+ def test_trace_fingerprint_ignores_episode_id_but_tracks_action_changes():
98
+ base_record = {
99
+ "episode_id": "episode-a",
100
+ "task_id": "task-1",
101
+ "scenario/seed": 123,
102
+ "scenario/split": "train",
103
+ "scenario/difficulty": 0,
104
+ "scenario/bug_type": "bola_idor",
105
+ "action_history": [
106
+ {
107
+ "tool_name": "read_file",
108
+ "arguments": {"path": "app/routes/invoices.py"},
109
+ }
110
+ ],
111
+ "observation_history": [{"last_action_valid": True}],
112
+ "reward_breakdown": {"total": 0.0},
113
+ }
114
+ same_trace = dict(base_record)
115
+ same_trace["episode_id"] = "episode-b"
116
+ changed_trace = dict(base_record)
117
+ changed_trace["action_history"] = [
118
+ *base_record["action_history"],
119
+ {"tool_name": "submit_fix", "arguments": {}},
120
+ ]
121
+
122
+ assert episode_trace_fingerprint(base_record) == episode_trace_fingerprint(same_trace)
123
+ assert episode_trace_fingerprint(base_record) != episode_trace_fingerprint(changed_trace)
training/configs/grpo_small.yaml CHANGED
@@ -90,19 +90,19 @@ reward:
90
  value: -0.30
91
  description: "Penalty for repeating the same failed action."
92
  repeated_low_value_action:
93
- value: -0.10
94
  description: "Penalty for repeating the exact same non-progress action."
95
  no_progress_action:
96
- value: -0.05
97
  description: "Penalty for valid tool calls that add no new useful progress."
98
  noop_action:
99
  value: -0.02
100
  description: "Small penalty for spending a step without acting."
101
  repeated_file_read:
102
- value: -0.05
103
  description: "Penalty for rereading the same file without a patch change."
104
  repeated_local_request:
105
- value: -0.05
106
  description: "Penalty for repeating the same local request after evidence is known."
107
  repeated_visible_tests:
108
  value: -0.10
 
90
  value: -0.30
91
  description: "Penalty for repeating the same failed action."
92
  repeated_low_value_action:
93
+ value: -0.40
94
  description: "Penalty for repeating the exact same non-progress action."
95
  no_progress_action:
96
+ value: -0.15
97
  description: "Penalty for valid tool calls that add no new useful progress."
98
  noop_action:
99
  value: -0.02
100
  description: "Small penalty for spending a step without acting."
101
  repeated_file_read:
102
+ value: -0.20
103
  description: "Penalty for rereading the same file without a patch change."
104
  repeated_local_request:
105
+ value: -0.20
106
  description: "Penalty for repeating the same local request after evidence is known."
107
  repeated_visible_tests:
108
  value: -0.10
training/trackio_utils.py CHANGED
@@ -882,6 +882,24 @@ def trace_table_rows(episodes: Sequence[Any]) -> list[dict[str, Any]]:
882
  return [episode_to_trace_row(episode) for episode in episodes]
883
 
884
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
885
  def log_trace_table(
886
  episodes: Sequence[Any],
887
  *,
 
882
  return [episode_to_trace_row(episode) for episode in episodes]
883
 
884
 
885
+ def episode_trace_fingerprint(episode: Any) -> str:
886
+ """Return a stable fingerprint for a redacted trace row.
887
+
888
+ The episode id is intentionally excluded so repeated GRPO samples with the
889
+ same scenario/action trace do not appear as separate Trackio examples.
890
+ """
891
+
892
+ row = episode_to_trace_row(episode)
893
+ return _stable_hash(
894
+ {
895
+ key: row.get(key, "")
896
+ for key in TRACE_TABLE_COLUMNS
897
+ if key != "episode_id"
898
+ },
899
+ length=24,
900
+ )
901
+
902
+
903
  def log_trace_table(
904
  episodes: Sequence[Any],
905
  *,