akhiilll commited on
Commit
9c7f0da
·
verified ·
1 Parent(s): 77cbc5a

fix: capture rewards/reward_repair_function/mean for plot

Browse files
Files changed (1) hide show
  1. scripts/jobs/train_repair_agent.py +18 -1
scripts/jobs/train_repair_agent.py CHANGED
@@ -211,11 +211,28 @@ training_rewards: list[float] = []
211
  if trainer_state.exists():
212
  state = json.loads(trainer_state.read_text())
213
  for log in state.get("log_history", []):
214
- for k in ("rewards/mean", "reward", "train/reward"):
 
 
 
 
 
 
 
 
 
 
 
 
215
  if k in log:
216
  training_rewards.append(float(log[k]))
217
  break
218
  print(f"[job] {len(training_rewards)} reward log points", flush=True)
 
 
 
 
 
219
 
220
  plot_reward_curve(
221
  training_rewards or [0.0],
 
211
  if trainer_state.exists():
212
  state = json.loads(trainer_state.read_text())
213
  for log in state.get("log_history", []):
214
+ # TRL emits a few different reward keys depending on version;
215
+ # try the most specific first, then fall back.
216
+ candidates = [
217
+ "rewards/reward_repair_function/mean",
218
+ "rewards/mean",
219
+ "reward",
220
+ "train/reward",
221
+ ]
222
+ # also pick up any key matching rewards/<name>/mean
223
+ for k in list(log.keys()):
224
+ if k.startswith("rewards/") and k.endswith("/mean") and k not in candidates:
225
+ candidates.append(k)
226
+ for k in candidates:
227
  if k in log:
228
  training_rewards.append(float(log[k]))
229
  break
230
  print(f"[job] {len(training_rewards)} reward log points", flush=True)
231
+ if training_rewards:
232
+ print(
233
+ f"[job] reward range: {min(training_rewards):.3f}..{max(training_rewards):.3f}",
234
+ flush=True,
235
+ )
236
 
237
  plot_reward_curve(
238
  training_rewards or [0.0],