samrat-rm Claude Sonnet 4.6 committed on
Commit
d3b224f
·
1 Parent(s): bf98c78

fix: clamp all rewards and scores to [0.10, 0.90]

Browse files

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

client.py CHANGED
@@ -50,7 +50,7 @@ class WhyDidItFailEnv(EnvClient[WhyDidItFailAction, WhyDidItFailObservation, Why
50
  visible_data=obs_data.get("visible_data", {}),
51
  available_actions=obs_data.get("available_actions", []),
52
  steps_taken=obs_data.get("steps_taken", 0),
53
- reward=obs_data.get("reward", 0.01),
54
  done=obs_data.get("done", False),
55
  feedback=obs_data.get("feedback", ""),
56
  )
 
50
  visible_data=obs_data.get("visible_data", {}),
51
  available_actions=obs_data.get("available_actions", []),
52
  steps_taken=obs_data.get("steps_taken", 0),
53
+ reward=obs_data.get("reward", 0.10),
54
  done=obs_data.get("done", False),
55
  feedback=obs_data.get("feedback", ""),
56
  )
inference.py CHANGED
@@ -13,7 +13,7 @@ Stdout format:
13
  [END] success=<bool> steps=<n> reward=<float> (per episode)
14
  [END] score=<float> (final overall)
15
 
16
- All reward and score values are strictly in (0.01, 0.99).
17
  """
18
 
19
  import asyncio
@@ -197,7 +197,7 @@ async def run_episode(
197
  rewards: List[float] = []
198
  inspection_order: List[str] = []
199
  submit_action: WhyDidItFailAction | None = None
200
- score = 0.01
201
  success = False
202
 
203
  try:
@@ -209,10 +209,10 @@ async def run_episode(
209
  try:
210
  result = await env.step(action)
211
  except ConnectionClosedError as e:
212
- print(f"[STEP] step={step} action={action.action_type} reward=0.01 done=true error={e}", flush=True)
213
  break
214
  obs = result.observation
215
- reward = round(max(0.01, min(0.99, result.reward or 0.01)), 2)
216
  done = result.done
217
  if action.action_type in ("inspect_logs", "inspect_config", "inspect_gradients"):
218
  source = action.action_type.replace("inspect_", "")
@@ -231,7 +231,7 @@ async def run_episode(
231
  break
232
 
233
  # WebSocket is closed — safe to call the judge now
234
- keyword_score = max(0.01, min(0.99, rewards[-1])) if rewards else 0.01
235
  judge_score: float | None = None
236
  if submit_action is not None:
237
  judge_score = llm_judge(
@@ -244,17 +244,17 @@ async def run_episode(
244
  inspection_order=inspection_order,
245
  )
246
  if judge_score is None:
247
- score = round(max(0.01, min(0.99, keyword_score)), 2)
248
  # print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.2f} reasoning=n/a total={score:.2f}", file=sys.stderr, flush=True)
249
  else:
250
- score = round(max(0.01, min(0.99, 0.85 * keyword_score + 0.15 * judge_score)), 2)
251
  # print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning={judge_score:.3f} total={score:.3f}", file=sys.stderr, flush=True)
252
 
253
  success = score >= SUCCESS_THRESHOLD
254
 
255
  finally:
256
  steps_taken = len(rewards)
257
- final_reward = f"{rewards[-1]:.2f}" if rewards else "0.01"
258
  print(f"[END] success={str(success).lower()} steps={steps_taken} reward={final_reward}", flush=True)
259
 
260
  return {"scenario_key": scenario_key, "score": score, "steps": steps_taken, "success": success}, env
@@ -280,7 +280,7 @@ async def run_task(task_name: str, scenario_keys: List[str], env: WhyDidItFailEn
280
  results.append(res)
281
 
282
  scores = [r["score"] for r in results]
283
- task_score = round(max(0.01, min(0.99, sum(scores) / len(scores))), 2) if scores else 0.01
284
  print(f"[END] score={task_score}", flush=True)
285
  return scores
286
 
 
13
  [END] success=<bool> steps=<n> reward=<float> (per episode)
14
  [END] score=<float> (final overall)
15
 
16
+ All reward and score values are in [0.10, 0.90].
17
  """
18
 
19
  import asyncio
 
197
  rewards: List[float] = []
198
  inspection_order: List[str] = []
199
  submit_action: WhyDidItFailAction | None = None
200
+ score = 0.10
201
  success = False
202
 
203
  try:
 
209
  try:
210
  result = await env.step(action)
211
  except ConnectionClosedError as e:
212
+ print(f"[STEP] step={step} action={action.action_type} reward=0.10 done=true error={e}", flush=True)
213
  break
214
  obs = result.observation
215
+ reward = round(max(0.10, min(0.90, result.reward or 0.10)), 2)
216
  done = result.done
217
  if action.action_type in ("inspect_logs", "inspect_config", "inspect_gradients"):
218
  source = action.action_type.replace("inspect_", "")
 
231
  break
232
 
233
  # WebSocket is closed — safe to call the judge now
234
+ keyword_score = max(0.10, min(0.90, rewards[-1])) if rewards else 0.10
235
  judge_score: float | None = None
236
  if submit_action is not None:
237
  judge_score = llm_judge(
 
244
  inspection_order=inspection_order,
245
  )
246
  if judge_score is None:
247
+ score = round(max(0.10, min(0.90, keyword_score)), 2)
248
  # print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.2f} reasoning=n/a total={score:.2f}", file=sys.stderr, flush=True)
249
  else:
250
+ score = round(max(0.10, min(0.90, 0.85 * keyword_score + 0.15 * judge_score)), 2)
251
  # print(f" [JUDGE] scenario={scenario_key} keyword={keyword_score:.3f} reasoning={judge_score:.3f} total={score:.3f}", file=sys.stderr, flush=True)
252
 
253
  success = score >= SUCCESS_THRESHOLD
254
 
255
  finally:
256
  steps_taken = len(rewards)
257
+ final_reward = f"{rewards[-1]:.2f}" if rewards else "0.10"
258
  print(f"[END] success={str(success).lower()} steps={steps_taken} reward={final_reward}", flush=True)
259
 
260
  return {"scenario_key": scenario_key, "score": score, "steps": steps_taken, "success": success}, env
 
280
  results.append(res)
281
 
282
  scores = [r["score"] for r in results]
283
+ task_score = round(max(0.10, min(0.90, sum(scores) / len(scores))), 2) if scores else 0.10
284
  print(f"[END] score={task_score}", flush=True)
285
  return scores
286
 
models.py CHANGED
@@ -42,8 +42,8 @@ class WhyDidItFailObservation(Observation):
42
  "Which action_types are valid on this step.")
43
  steps_taken: int = Field(..., description=
44
  "Number of actions taken so far in this episode.")
45
- reward: float = Field(default=0.01, description= # type: ignore[override]
46
- "Score for the current step. 0.99 = max.")
47
  done: bool = Field(default=False, description=
48
  "True when the episode has ended.")
49
  feedback: str = Field(..., description=
 
42
  "Which action_types are valid on this step.")
43
  steps_taken: int = Field(..., description=
44
  "Number of actions taken so far in this episode.")
45
+ reward: float = Field(default=0.10, description= # type: ignore[override]
46
+ "Score for the current step. 0.90 = max.")
47
  done: bool = Field(default=False, description=
48
  "True when the episode has ended.")
49
  feedback: str = Field(..., description=
openenv.yaml CHANGED
@@ -25,17 +25,17 @@ tasks:
25
  {response}
26
 
27
  You MUST reply with exactly one of these four numbers and nothing else:
28
- 0.95
29
- 0.70
30
  0.30
31
- 0.05
32
 
33
  Rules:
34
- - 0.95: Correct failure mode with reasoning that cites specific numeric values from the logs
35
- - 0.70: Correct failure mode but reasoning is vague or missing specific numbers
36
  - 0.30: Wrong label but description matches a related concept
37
- - 0.05: Wrong failure mode, no diagnosis submitted, or empty response
38
- - If in doubt, return 0.05. NEVER return 0, 1, 0.0, 1.0, or any value not in the list above.
39
 
40
  - id: task_medium
41
  difficulty: medium
@@ -56,17 +56,17 @@ tasks:
56
  {response}
57
 
58
  You MUST reply with exactly one of these four numbers and nothing else:
59
- 0.95
60
- 0.70
61
  0.30
62
- 0.05
63
 
64
  Rules:
65
- - 0.95: Correct failure mode with reasoning citing both log values AND config parameters
66
- - 0.70: Correct failure mode but reasoning only references logs or config, not both
67
  - 0.30: Wrong label but description matches a related concept
68
- - 0.05: Wrong failure mode, no diagnosis submitted, or empty response
69
- - If in doubt, return 0.05. NEVER return 0, 1, 0.0, 1.0, or any value not in the list above.
70
 
71
  - id: task_hard
72
  difficulty: hard
@@ -88,16 +88,16 @@ tasks:
88
  {response}
89
 
90
  You MUST reply with exactly one of these five numbers and nothing else:
91
- 0.95
92
- 0.80
93
  0.50
94
  0.20
95
- 0.05
96
 
97
  Rules:
98
- - 0.95: Correct failure mode AND a specific actionable fix addressing the root cause
99
- - 0.80: Correct failure mode with a reasonable fix that lacks specifics
100
  - 0.50: Correct failure mode but fix is vague, wrong, or missing
101
  - 0.20: Wrong failure mode but fix is incidentally relevant
102
- - 0.05: Wrong failure mode, no useful fix, no diagnosis submitted, or empty response
103
- - If in doubt, return 0.05. NEVER return 0, 1, 0.0, 1.0, or any value not in the list above.
 
25
  {response}
26
 
27
  You MUST reply with exactly one of these four numbers and nothing else:
28
+ 0.85
29
+ 0.65
30
  0.30
31
+ 0.15
32
 
33
  Rules:
34
+ - 0.85: Correct failure mode with reasoning that cites specific numeric values from the logs
35
+ - 0.65: Correct failure mode but reasoning is vague or missing specific numbers
36
  - 0.30: Wrong label but description matches a related concept
37
+ - 0.15: Wrong failure mode, no diagnosis submitted, or empty response
38
+ - If in doubt, return 0.15. NEVER return 0, 1, 0.0, 1.0, or any value outside [0.10, 0.90].
39
 
40
  - id: task_medium
41
  difficulty: medium
 
56
  {response}
57
 
58
  You MUST reply with exactly one of these four numbers and nothing else:
59
+ 0.85
60
+ 0.65
61
  0.30
62
+ 0.15
63
 
64
  Rules:
65
+ - 0.85: Correct failure mode with reasoning citing both log values AND config parameters
66
+ - 0.65: Correct failure mode but reasoning only references logs or config, not both
67
  - 0.30: Wrong label but description matches a related concept
68
+ - 0.15: Wrong failure mode, no diagnosis submitted, or empty response
69
+ - If in doubt, return 0.15. NEVER return 0, 1, 0.0, 1.0, or any value outside [0.10, 0.90].
70
 
71
  - id: task_hard
72
  difficulty: hard
 
88
  {response}
89
 
90
  You MUST reply with exactly one of these five numbers and nothing else:
91
+ 0.85
92
+ 0.75
93
  0.50
94
  0.20
95
+ 0.15
96
 
97
  Rules:
98
+ - 0.85: Correct failure mode AND a specific actionable fix addressing the root cause
99
+ - 0.75: Correct failure mode with a reasonable fix that lacks specifics
100
  - 0.50: Correct failure mode but fix is vague, wrong, or missing
101
  - 0.20: Wrong failure mode but fix is incidentally relevant
102
+ - 0.15: Wrong failure mode, no useful fix, no diagnosis submitted, or empty response
103
+ - If in doubt, return 0.15. NEVER return 0, 1, 0.0, 1.0, or any value outside [0.10, 0.90].
server/WhyDidItFail_environment.py CHANGED
@@ -56,7 +56,7 @@ class WhyDidItFailEnvironment(Environment):
56
  visible_data={"hint": "Start by inspecting the training logs."},
57
  available_actions=["inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis"],
58
  steps_taken=0,
59
- reward=0.01,
60
  done=False,
61
  feedback="Investigation started.",
62
  )
@@ -67,18 +67,18 @@ class WhyDidItFailEnvironment(Environment):
67
 
68
  self._state.step_count += 1
69
 
70
- # Hard step limit — terminate immediately, grade() will return 0.01.
71
  if self._state.step_count > self.max_steps and action.action_type != "submit_diagnosis":
72
  return WhyDidItFailObservation(
73
  task_description="Step limit reached. Episode terminated.",
74
  visible_data={},
75
  available_actions=[],
76
  steps_taken=self._state.step_count,
77
- reward=0.01,
78
  done=True,
79
  feedback=(
80
  f"Step limit ({self.max_steps}) reached without a diagnosis. "
81
- f"Score: 0.01. Actual failure: '{self.scenario['correct_diagnosis']}'."
82
  ),
83
  )
84
  required: list[str] = self.scenario.get("required_sources", ["logs"])
@@ -143,30 +143,32 @@ class WhyDidItFailEnvironment(Environment):
143
  visible_data={},
144
  available_actions=["inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis"],
145
  steps_taken=self._state.step_count,
146
- reward=-0.05,
147
  done=False,
148
- feedback=f"Unknown action '{action.action_type}'. No reward.",
149
  )
150
 
151
  # Rewards decay as more required sources are discovered — first clue is worth most.
152
- _REQUIRED_STEP_REWARDS = [0.10, 0.07, 0.05]
 
153
 
154
  def _inspect_reward(self, source: str, required: list[str]) -> float:
155
  """Return step reward for an inspect action.
156
 
157
- Required sources: progressive — +0.10 / +0.07 / +0.05 for 1st/2nd/3rd discovery.
158
- Irrelevant sources: -0.03 (mild; some exploration is acceptable).
159
- Re-inspection: -0.05 (waste).
 
160
  """
161
  if source in self.inspection_order:
162
- return -0.05 # redundant inspection
163
 
164
  if source in required:
165
  n_found = sum(1 for s in self.inspection_order if s in required)
166
  idx = min(n_found, len(self._REQUIRED_STEP_REWARDS) - 1)
167
  return self._REQUIRED_STEP_REWARDS[idx]
168
 
169
- return -0.03 # irrelevant source
170
 
171
  def _inspect_feedback(self, source: str, required: list[str], reward: float) -> str:
172
  label = {"logs": "training logs", "config": "hyperparameter config", "gradients": "gradient statistics"}[source]
 
56
  visible_data={"hint": "Start by inspecting the training logs."},
57
  available_actions=["inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis"],
58
  steps_taken=0,
59
+ reward=0.10,
60
  done=False,
61
  feedback="Investigation started.",
62
  )
 
67
 
68
  self._state.step_count += 1
69
 
70
+ # Hard step limit — terminate immediately, grade() will return 0.10.
71
  if self._state.step_count > self.max_steps and action.action_type != "submit_diagnosis":
72
  return WhyDidItFailObservation(
73
  task_description="Step limit reached. Episode terminated.",
74
  visible_data={},
75
  available_actions=[],
76
  steps_taken=self._state.step_count,
77
+ reward=0.10,
78
  done=True,
79
  feedback=(
80
  f"Step limit ({self.max_steps}) reached without a diagnosis. "
81
+ f"Score: 0.10. Actual failure: '{self.scenario['correct_diagnosis']}'."
82
  ),
83
  )
84
  required: list[str] = self.scenario.get("required_sources", ["logs"])
 
143
  visible_data={},
144
  available_actions=["inspect_logs", "inspect_config", "inspect_gradients", "submit_diagnosis"],
145
  steps_taken=self._state.step_count,
146
+ reward=0.10,
147
  done=False,
148
+ feedback=f"Unknown action '{action.action_type}'. Minimum reward.",
149
  )
150
 
151
  # Rewards decay as more required sources are discovered — first clue is worth most.
152
+ # All values are in [0.10, 0.90] — no negative rewards.
153
+ _REQUIRED_STEP_REWARDS = [0.50, 0.30, 0.15]
154
 
155
  def _inspect_reward(self, source: str, required: list[str]) -> float:
156
  """Return step reward for an inspect action.
157
 
158
+ Required sources: progressive — 0.50 / 0.30 / 0.15 for 1st/2nd/3rd discovery.
159
+ Irrelevant sources: 0.10 (minimum; mild penalty via contrast with required rewards).
160
+ Re-inspection: 0.10 (minimum; waste with no new information).
161
+ All values are strictly in [0.10, 0.90].
162
  """
163
  if source in self.inspection_order:
164
+ return 0.10 # redundant inspection — minimum reward
165
 
166
  if source in required:
167
  n_found = sum(1 for s in self.inspection_order if s in required)
168
  idx = min(n_found, len(self._REQUIRED_STEP_REWARDS) - 1)
169
  return self._REQUIRED_STEP_REWARDS[idx]
170
 
171
+ return 0.10 # irrelevant source — minimum reward
172
 
173
  def _inspect_feedback(self, source: str, required: list[str], reward: float) -> str:
174
  label = {"logs": "training logs", "config": "hyperparameter config", "gradients": "gradient statistics"}[source]
server/graders.py CHANGED
@@ -6,7 +6,7 @@ grade() is the single entry point. It scores the full episode trajectory:
6
  diagnosis_score (0.00 – 0.70) was the diagnosis correct?
7
  evidence_score (0.00 – 0.15) did the agent inspect the right sources?
8
  efficiency_score (0.00 – 0.15) did the agent act without waste?
9
- fix_bonus (0.00 – 0.15) did the agent suggest a valid fix? (bonus, capped at 0.99)
10
 
11
  Step-level partial rewards are returned by the environment's step() on every action,
12
  giving the agent a signal over the full trajectory before the episode ends.
@@ -197,7 +197,7 @@ def grade(
197
  Single unified grade function. Scores every scenario identically.
198
 
199
  Total score = diagnosis_score + evidence_score + efficiency_score + fix_bonus
200
- clamped to [0.01, 0.99].
201
 
202
  Max achievable without fix: 0.70 + 0.15 + 0.15 = 1.00
203
  Max achievable with fix: 0.70 + 0.15 + 0.15 + 0.15 = 1.00 (capped)
@@ -210,7 +210,7 @@ def grade(
210
  max_steps = len(required) * 3 + 2 # hard ceiling; exceeding it = total failure
211
 
212
  if steps_taken > max_steps:
213
- return 0.01
214
 
215
  d_score = _diagnosis_score(diagnosis, scenario)
216
  ed_penalty = _evidence_diagnosis_penalty(diagnosis, scenario, inspection_order)
@@ -221,4 +221,4 @@ def grade(
221
 
222
  total = d_score + ed_penalty + e_score + f_score + b_score + o_bonus
223
 
224
- return round(max(0.01, min(0.99, total)), 2)
 
6
  diagnosis_score (0.00 – 0.70) was the diagnosis correct?
7
  evidence_score (0.00 – 0.15) did the agent inspect the right sources?
8
  efficiency_score (0.00 – 0.15) did the agent act without waste?
9
+ fix_bonus (0.00 – 0.15) did the agent suggest a valid fix? (bonus, capped at 0.90)
10
 
11
  Step-level partial rewards are returned by the environment's step() on every action,
12
  giving the agent a signal over the full trajectory before the episode ends.
 
197
  Single unified grade function. Scores every scenario identically.
198
 
199
  Total score = diagnosis_score + evidence_score + efficiency_score + fix_bonus
200
+ clamped to [0.10, 0.90].
201
 
202
  Max achievable without fix: 0.70 + 0.15 + 0.15 = 1.00
203
  Max achievable with fix: 0.70 + 0.15 + 0.15 + 0.15 = 1.00 (capped)
 
210
  max_steps = len(required) * 3 + 2 # hard ceiling; exceeding it = total failure
211
 
212
  if steps_taken > max_steps:
213
+ return 0.10
214
 
215
  d_score = _diagnosis_score(diagnosis, scenario)
216
  ed_penalty = _evidence_diagnosis_penalty(diagnosis, scenario, inspection_order)
 
221
 
222
  total = d_score + ed_penalty + e_score + f_score + b_score + o_bonus
223
 
224
+ return round(max(0.10, min(0.90, total)), 2)
server/llm_judge.py CHANGED
@@ -3,16 +3,16 @@ LLM Judge — reasoning quality scorer for WhyDidItFail.
3
 
4
  Called from inference.py after submit_diagnosis.
5
  Uses the same OpenAI-compatible client and model as the agent.
6
- Returns a normalized score in [0.0, 1.0] representing reasoning quality.
7
- Returns 0.0 silently if reasoning is absent or the call fails.
8
 
9
- Scoring criteria (0–5 each, total 0–15 → normalized to 0.0–1.0):
10
  evidence_grounding — does the reasoning cite specific observed values?
11
  causal_chain — does it connect evidence to the failure mode logically?
12
  fix_rationale — is the fix justified by the evidence?
13
 
14
  Final score in inference.py:
15
- total = 0.85 * keyword_score + 0.15 * judge_score → always in [0.0, 1.0]
16
  """
17
 
18
  import json
@@ -90,8 +90,8 @@ def judge(
90
  + data.get("causal_chain", 0)
91
  + data.get("fix_rationale", 0)
92
  )
93
- # normalize: raw 0–15 → 0.01–0.99 (never exact 0 or 1)
94
- return round(max(0.01, min(0.99, raw / 15)), 2)
95
 
96
  except Exception as exc:
97
  print(f" [JUDGE] failed: {exc}", flush=True)
 
3
 
4
  Called from inference.py after submit_diagnosis.
5
  Uses the same OpenAI-compatible client and model as the agent.
6
+ Returns a normalized score in [0.10, 0.90] representing reasoning quality.
7
+ Returns None silently if reasoning is absent or the call fails.
8
 
9
+ Scoring criteria (0–5 each, total 0–15 → normalized to 0.10–0.90):
10
  evidence_grounding — does the reasoning cite specific observed values?
11
  causal_chain — does it connect evidence to the failure mode logically?
12
  fix_rationale — is the fix justified by the evidence?
13
 
14
  Final score in inference.py:
15
+ total = 0.85 * keyword_score + 0.15 * judge_score → always in [0.10, 0.90]
16
  """
17
 
18
  import json
 
90
  + data.get("causal_chain", 0)
91
  + data.get("fix_rationale", 0)
92
  )
93
+ # normalize: raw 0–15 → 0.10–0.90 (never below 0.10 or above 0.90)
94
+ return round(max(0.10, min(0.90, raw / 15)), 2)
95
 
96
  except Exception as exc:
97
  print(f" [JUDGE] failed: {exc}", flush=True)