fix: capture rewards/reward_repair_function/mean for plot
Browse files
scripts/jobs/train_repair_agent.py
CHANGED
|
@@ -211,11 +211,28 @@ training_rewards: list[float] = []
|
|
| 211 |
if trainer_state.exists():
|
| 212 |
state = json.loads(trainer_state.read_text())
|
| 213 |
for log in state.get("log_history", []):
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
if k in log:
|
| 216 |
training_rewards.append(float(log[k]))
|
| 217 |
break
|
| 218 |
print(f"[job] {len(training_rewards)} reward log points", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
|
| 220 |
plot_reward_curve(
|
| 221 |
training_rewards or [0.0],
|
|
|
|
| 211 |
if trainer_state.exists():
|
| 212 |
state = json.loads(trainer_state.read_text())
|
| 213 |
for log in state.get("log_history", []):
|
| 214 |
+
# TRL emits a few different reward keys depending on version;
|
| 215 |
+
# try the most specific first, then fall back.
|
| 216 |
+
candidates = [
|
| 217 |
+
"rewards/reward_repair_function/mean",
|
| 218 |
+
"rewards/mean",
|
| 219 |
+
"reward",
|
| 220 |
+
"train/reward",
|
| 221 |
+
]
|
| 222 |
+
# also pick up any key matching rewards/<name>/mean
|
| 223 |
+
for k in list(log.keys()):
|
| 224 |
+
if k.startswith("rewards/") and k.endswith("/mean") and k not in candidates:
|
| 225 |
+
candidates.append(k)
|
| 226 |
+
for k in candidates:
|
| 227 |
if k in log:
|
| 228 |
training_rewards.append(float(log[k]))
|
| 229 |
break
|
| 230 |
print(f"[job] {len(training_rewards)} reward log points", flush=True)
|
| 231 |
+
if training_rewards:
|
| 232 |
+
print(
|
| 233 |
+
f"[job] reward range: {min(training_rewards):.3f}..{max(training_rewards):.3f}",
|
| 234 |
+
flush=True,
|
| 235 |
+
)
|
| 236 |
|
| 237 |
plot_reward_curve(
|
| 238 |
training_rewards or [0.0],
|