Spaces:
Sleeping
Commit ·
2eada22
1
Parent(s): be8eade
feat: add episode trace fingerprinting for improved trace logging and update reward penalties in GRPO configuration
Browse files
- scripts/modal_train_grpo.py +18 -3
- tests/test_rewards.py +2 -1
- tests/test_trackio_utils.py +30 -0
- training/configs/grpo_small.yaml +4 -4
- training/trackio_utils.py +18 -0
scripts/modal_train_grpo.py
CHANGED
|
@@ -518,6 +518,7 @@ def train_cybersecurity_owasp_grpo(
|
|
| 518 |
from training.trackio_utils import (
|
| 519 |
aggregate_episode_metrics,
|
| 520 |
episode_record_from_state,
|
|
|
|
| 521 |
log_gpu_metrics,
|
| 522 |
log_trace_table,
|
| 523 |
log_trackio_metrics,
|
|
@@ -886,6 +887,7 @@ def train_cybersecurity_owasp_grpo(
|
|
| 886 |
pass
|
| 887 |
|
| 888 |
trace_step = {"value": 0}
|
|
|
|
| 889 |
|
| 890 |
def _completion_to_text(completion) -> str:
|
| 891 |
if completion is None:
|
|
@@ -950,16 +952,28 @@ def train_cybersecurity_owasp_grpo(
|
|
| 950 |
except Exception as exc:
|
| 951 |
print(f"Trackio metric logging skipped: {exc!r}")
|
| 952 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 953 |
try:
|
| 954 |
log_trace_table(
|
| 955 |
-
|
| 956 |
table_name="sample_traces",
|
| 957 |
step=trace_step["value"],
|
| 958 |
)
|
| 959 |
except Exception as exc:
|
| 960 |
print(f"Trackio sample trace table logging skipped: {exc!r}")
|
| 961 |
|
| 962 |
-
for index, env in
|
| 963 |
messages = list(getattr(env, "trace_messages", []))
|
| 964 |
if index < len(completions):
|
| 965 |
completion_text = _completion_to_text(completions[index])
|
|
@@ -974,8 +988,9 @@ def train_cybersecurity_owasp_grpo(
|
|
| 974 |
metadata.update(
|
| 975 |
{
|
| 976 |
"sample_index": index,
|
| 977 |
-
"reward":
|
| 978 |
"trace_step": trace_step["value"],
|
|
|
|
| 979 |
"run_name": run_name,
|
| 980 |
}
|
| 981 |
)
|
|
|
|
| 518 |
from training.trackio_utils import (
|
| 519 |
aggregate_episode_metrics,
|
| 520 |
episode_record_from_state,
|
| 521 |
+
episode_trace_fingerprint,
|
| 522 |
log_gpu_metrics,
|
| 523 |
log_trace_table,
|
| 524 |
log_trackio_metrics,
|
|
|
|
| 887 |
pass
|
| 888 |
|
| 889 |
trace_step = {"value": 0}
|
| 890 |
+
logged_trace_fingerprints: set[str] = set()
|
| 891 |
|
| 892 |
def _completion_to_text(completion) -> str:
|
| 893 |
if completion is None:
|
|
|
|
| 952 |
except Exception as exc:
|
| 953 |
print(f"Trackio metric logging skipped: {exc!r}")
|
| 954 |
|
| 955 |
+
sampled_traces = []
|
| 956 |
+
seen_this_batch: set[str] = set()
|
| 957 |
+
for index, (env, record, reward) in enumerate(zip(environments, episode_records, rewards)):
|
| 958 |
+
fingerprint = episode_trace_fingerprint(record)
|
| 959 |
+
if fingerprint in seen_this_batch or fingerprint in logged_trace_fingerprints:
|
| 960 |
+
continue
|
| 961 |
+
seen_this_batch.add(fingerprint)
|
| 962 |
+
logged_trace_fingerprints.add(fingerprint)
|
| 963 |
+
sampled_traces.append((index, env, record, reward, fingerprint))
|
| 964 |
+
if len(sampled_traces) >= 4:
|
| 965 |
+
break
|
| 966 |
+
|
| 967 |
try:
|
| 968 |
log_trace_table(
|
| 969 |
+
[record for _, _, record, _, _ in sampled_traces],
|
| 970 |
table_name="sample_traces",
|
| 971 |
step=trace_step["value"],
|
| 972 |
)
|
| 973 |
except Exception as exc:
|
| 974 |
print(f"Trackio sample trace table logging skipped: {exc!r}")
|
| 975 |
|
| 976 |
+
for index, env, _record, reward, fingerprint in sampled_traces:
|
| 977 |
messages = list(getattr(env, "trace_messages", []))
|
| 978 |
if index < len(completions):
|
| 979 |
completion_text = _completion_to_text(completions[index])
|
|
|
|
| 988 |
metadata.update(
|
| 989 |
{
|
| 990 |
"sample_index": index,
|
| 991 |
+
"reward": reward,
|
| 992 |
"trace_step": trace_step["value"],
|
| 993 |
+
"trace_fingerprint": fingerprint,
|
| 994 |
"run_name": run_name,
|
| 995 |
}
|
| 996 |
)
|
tests/test_rewards.py
CHANGED
|
@@ -114,8 +114,9 @@ def test_repeated_futile_actions_are_penalized(monkeypatch):
|
|
| 114 |
|
| 115 |
assert first.reward_breakdown["progressive"] > 0.0
|
| 116 |
assert second.reward_breakdown["progressive"] == 0.0
|
| 117 |
-
assert second.reward_breakdown["behavior_penalty"] <= -0.
|
| 118 |
assert second.reward_breakdown["total"] < 0.0
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
def test_dense_episode_reward_cap_blocks_repeated_positive_farming(monkeypatch):
|
|
|
|
| 114 |
|
| 115 |
assert first.reward_breakdown["progressive"] > 0.0
|
| 116 |
assert second.reward_breakdown["progressive"] == 0.0
|
| 117 |
+
assert second.reward_breakdown["behavior_penalty"] <= -0.50
|
| 118 |
assert second.reward_breakdown["total"] < 0.0
|
| 119 |
+
assert env.state.accumulated_reward < 0.0
|
| 120 |
|
| 121 |
|
| 122 |
def test_dense_episode_reward_cap_blocks_repeated_positive_farming(monkeypatch):
|
tests/test_trackio_utils.py
CHANGED
|
@@ -6,6 +6,7 @@ from training.trackio_utils import (
|
|
| 6 |
DERIVED_TRACKIO_METRICS,
|
| 7 |
aggregate_episode_metrics,
|
| 8 |
episode_record_from_state,
|
|
|
|
| 9 |
episode_to_trace_row,
|
| 10 |
episode_to_tracking_fields,
|
| 11 |
)
|
|
@@ -91,3 +92,32 @@ def test_trace_rows_redact_hidden_values_from_action_arguments():
|
|
| 91 |
assert not value or value not in row_text
|
| 92 |
finally:
|
| 93 |
env.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
DERIVED_TRACKIO_METRICS,
|
| 7 |
aggregate_episode_metrics,
|
| 8 |
episode_record_from_state,
|
| 9 |
+
episode_trace_fingerprint,
|
| 10 |
episode_to_trace_row,
|
| 11 |
episode_to_tracking_fields,
|
| 12 |
)
|
|
|
|
| 92 |
assert not value or value not in row_text
|
| 93 |
finally:
|
| 94 |
env.close()
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def test_trace_fingerprint_ignores_episode_id_but_tracks_action_changes():
|
| 98 |
+
base_record = {
|
| 99 |
+
"episode_id": "episode-a",
|
| 100 |
+
"task_id": "task-1",
|
| 101 |
+
"scenario/seed": 123,
|
| 102 |
+
"scenario/split": "train",
|
| 103 |
+
"scenario/difficulty": 0,
|
| 104 |
+
"scenario/bug_type": "bola_idor",
|
| 105 |
+
"action_history": [
|
| 106 |
+
{
|
| 107 |
+
"tool_name": "read_file",
|
| 108 |
+
"arguments": {"path": "app/routes/invoices.py"},
|
| 109 |
+
}
|
| 110 |
+
],
|
| 111 |
+
"observation_history": [{"last_action_valid": True}],
|
| 112 |
+
"reward_breakdown": {"total": 0.0},
|
| 113 |
+
}
|
| 114 |
+
same_trace = dict(base_record)
|
| 115 |
+
same_trace["episode_id"] = "episode-b"
|
| 116 |
+
changed_trace = dict(base_record)
|
| 117 |
+
changed_trace["action_history"] = [
|
| 118 |
+
*base_record["action_history"],
|
| 119 |
+
{"tool_name": "submit_fix", "arguments": {}},
|
| 120 |
+
]
|
| 121 |
+
|
| 122 |
+
assert episode_trace_fingerprint(base_record) == episode_trace_fingerprint(same_trace)
|
| 123 |
+
assert episode_trace_fingerprint(base_record) != episode_trace_fingerprint(changed_trace)
|
training/configs/grpo_small.yaml
CHANGED
|
@@ -90,19 +90,19 @@ reward:
|
|
| 90 |
value: -0.30
|
| 91 |
description: "Penalty for repeating the same failed action."
|
| 92 |
repeated_low_value_action:
|
| 93 |
-
value: -0.
|
| 94 |
description: "Penalty for repeating the exact same non-progress action."
|
| 95 |
no_progress_action:
|
| 96 |
-
value: -0.
|
| 97 |
description: "Penalty for valid tool calls that add no new useful progress."
|
| 98 |
noop_action:
|
| 99 |
value: -0.02
|
| 100 |
description: "Small penalty for spending a step without acting."
|
| 101 |
repeated_file_read:
|
| 102 |
-
value: -0.
|
| 103 |
description: "Penalty for rereading the same file without a patch change."
|
| 104 |
repeated_local_request:
|
| 105 |
-
value: -0.
|
| 106 |
description: "Penalty for repeating the same local request after evidence is known."
|
| 107 |
repeated_visible_tests:
|
| 108 |
value: -0.10
|
|
|
|
| 90 |
value: -0.30
|
| 91 |
description: "Penalty for repeating the same failed action."
|
| 92 |
repeated_low_value_action:
|
| 93 |
+
value: -0.40
|
| 94 |
description: "Penalty for repeating the exact same non-progress action."
|
| 95 |
no_progress_action:
|
| 96 |
+
value: -0.15
|
| 97 |
description: "Penalty for valid tool calls that add no new useful progress."
|
| 98 |
noop_action:
|
| 99 |
value: -0.02
|
| 100 |
description: "Small penalty for spending a step without acting."
|
| 101 |
repeated_file_read:
|
| 102 |
+
value: -0.20
|
| 103 |
description: "Penalty for rereading the same file without a patch change."
|
| 104 |
repeated_local_request:
|
| 105 |
+
value: -0.20
|
| 106 |
description: "Penalty for repeating the same local request after evidence is known."
|
| 107 |
repeated_visible_tests:
|
| 108 |
value: -0.10
|
training/trackio_utils.py
CHANGED
|
@@ -882,6 +882,24 @@ def trace_table_rows(episodes: Sequence[Any]) -> list[dict[str, Any]]:
|
|
| 882 |
return [episode_to_trace_row(episode) for episode in episodes]
|
| 883 |
|
| 884 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 885 |
def log_trace_table(
|
| 886 |
episodes: Sequence[Any],
|
| 887 |
*,
|
|
|
|
| 882 |
return [episode_to_trace_row(episode) for episode in episodes]
|
| 883 |
|
| 884 |
|
| 885 |
+
def episode_trace_fingerprint(episode: Any) -> str:
|
| 886 |
+
"""Return a stable fingerprint for a redacted trace row.
|
| 887 |
+
|
| 888 |
+
The episode id is intentionally excluded so repeated GRPO samples with the
|
| 889 |
+
same scenario/action trace do not appear as separate Trackio examples.
|
| 890 |
+
"""
|
| 891 |
+
|
| 892 |
+
row = episode_to_trace_row(episode)
|
| 893 |
+
return _stable_hash(
|
| 894 |
+
{
|
| 895 |
+
key: row.get(key, "")
|
| 896 |
+
for key in TRACE_TABLE_COLUMNS
|
| 897 |
+
if key != "episode_id"
|
| 898 |
+
},
|
| 899 |
+
length=24,
|
| 900 |
+
)
|
| 901 |
+
|
| 902 |
+
|
| 903 |
def log_trace_table(
|
| 904 |
episodes: Sequence[Any],
|
| 905 |
*,
|