Prasham.Jain and Claude Sonnet 4.6 committed
Commit d11066d · 1 Parent(s): 18a3fbf

feat(rewards): Phase C1 — all 9 reward components implemented


- FormatGate: schema-validates all ToolCall args via jsonschema + TerminalAction bounds
- DiagnosisReward: asymmetric 7×7 confusion-matrix with operational consequence weights
- ActionQualityReward: secondary action × failure-family matrix; quarantine-on-bug is most catastrophic
- CostEfficiencyReward: linear reward that decreases as budget is consumed (zero spend → +1.0, full budget → -1.0)
- InvestigationReward: coverage × ordering × redundancy shaping reward
- TimePenaltyReward: per-step penalty beyond 6-step reference
- AntiGamingReward: no-info-action guard + rolling quarantine-rate guard + Brier calibration probe
- MinimalEvidenceReward: bonus for correct diagnosis using only the minimal evidence set (weight=0 in v1)
- CounterfactualPredictReward: dormant in v1 (weight=0); implementation preserved for v2
- 67 unit tests across 9 test files (329 total, all passing)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
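How the nine components are meant to compose, per the FormatGate docstring (total = format_gate × weighted sum) — a minimal sketch only, assuming no-arg construction for each component; the `composite_total` helper below is hypothetical, and the actual aggregator is slated for Phase C2:

# Hypothetical composition sketch — the real composite lands in Phase C2.
from ci_triage_env.rewards import (
    ActionQualityReward, AntiGamingReward, CostEfficiencyReward,
    CounterfactualPredictReward, DiagnosisReward, FormatGate,
    InvestigationReward, MinimalEvidenceReward, TimePenaltyReward,
)

def composite_total(trace, scenario) -> float:
    # Multiplicative gate: any schema violation zeroes the whole episode reward.
    gate = FormatGate().score(trace, scenario).raw
    components = [
        DiagnosisReward(), ActionQualityReward(), AntiGamingReward(),
        CostEfficiencyReward(), InvestigationReward(), TimePenaltyReward(),
        MinimalEvidenceReward(), CounterfactualPredictReward(),  # both weight 0.0 in v1
    ]
    weighted_sum = sum(c.score(trace, scenario).weighted for c in components)
    return gate * weighted_sum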

src/ci_triage_env/rewards/__init__.py CHANGED
@@ -0,0 +1,21 @@
+ from ci_triage_env.rewards.action_quality import ActionQualityReward
+ from ci_triage_env.rewards.anti_gaming import AntiGamingReward
+ from ci_triage_env.rewards.cost_efficiency import CostEfficiencyReward
+ from ci_triage_env.rewards.counterfactual_predict import CounterfactualPredictReward
+ from ci_triage_env.rewards.diagnosis import DiagnosisReward
+ from ci_triage_env.rewards.format_gate import FormatGate
+ from ci_triage_env.rewards.investigation import InvestigationReward
+ from ci_triage_env.rewards.minimal_evidence import MinimalEvidenceReward
+ from ci_triage_env.rewards.time_penalty import TimePenaltyReward
+
+ __all__ = [
+     "ActionQualityReward",
+     "AntiGamingReward",
+     "CostEfficiencyReward",
+     "CounterfactualPredictReward",
+     "DiagnosisReward",
+     "FormatGate",
+     "InvestigationReward",
+     "MinimalEvidenceReward",
+     "TimePenaltyReward",
+ ]
src/ci_triage_env/rewards/action_quality.py ADDED
@@ -0,0 +1,96 @@
+ """ActionQualityReward — secondary action × failure-family matrix.
+
+ Raw score range: [-2.0, 1.5] (capped). Default weight: 0.20.
+ """
+
+ from __future__ import annotations
+
+ from ci_triage_env.rewards.base import RewardComponent
+ from ci_triage_env.schemas.episode import EpisodeTrace
+ from ci_triage_env.schemas.reward import ComponentScore
+ from ci_triage_env.schemas.scenario import Scenario
+
+ # (action_name, ground_truth_family) → reward
+ ACTION_REWARD_MATRIX: dict[tuple[str, str], float] = {
+     ("file_bug", "real_bug"): 1.0,
+     ("file_bug", "dependency_drift"): 0.7,
+     ("file_bug", "race_flake"): -0.5,
+     ("file_bug", "timing_flake"): -0.3,
+     ("file_bug", "infra_network"): -0.5,
+     ("file_bug", "infra_resource"): -0.5,
+     ("file_bug", "ambiguous"): -0.2,
+     # Quarantine: ideal for flakes, catastrophic for real bugs
+     ("quarantine_test", "race_flake"): 1.0,
+     ("quarantine_test", "timing_flake"): 0.8,
+     ("quarantine_test", "real_bug"): -1.5,
+     ("quarantine_test", "infra_network"): -0.3,
+     ("quarantine_test", "infra_resource"): -0.3,
+     ("quarantine_test", "dependency_drift"): -0.5,
+     ("quarantine_test", "ambiguous"): -0.3,
+     # Rerun: right for transient failures, bad for bugs
+     ("rerun_test", "race_flake"): 0.6,
+     ("rerun_test", "timing_flake"): 0.6,
+     ("rerun_test", "infra_network"): 0.8,
+     ("rerun_test", "infra_resource"): 0.5,
+     ("rerun_test", "real_bug"): -0.6,
+     ("rerun_test", "dependency_drift"): -0.3,
+     ("rerun_test", "ambiguous"): 0.2,
+     # Ping owner: escalates to the right team
+     ("ping_owner", "infra_resource"): 0.7,
+     ("ping_owner", "infra_network"): 0.5,
+     ("ping_owner", "real_bug"): 0.4,
+     ("ping_owner", "dependency_drift"): 0.6,
+     ("ping_owner", "race_flake"): 0.0,
+     ("ping_owner", "timing_flake"): 0.0,
+     ("ping_owner", "ambiguous"): 0.3,
+ }
+
+ _RAW_MIN = -2.0
+ _RAW_MAX = 1.5
+
+
+ class ActionQualityReward(RewardComponent):
+     """Reward for secondary actions taken alongside the diagnosis.
+
+     Multiple secondary actions are summed then capped to [-2.0, 1.5].
+     No secondary actions → neutral (0.0). No terminal action → -0.5.
+     """
+
+     name = "action_quality"
+     default_weight = 0.20
+
+     def score(self, trace: EpisodeTrace, scenario: Scenario) -> ComponentScore:
+         if trace.episode.final_action is None:
+             raw = -0.5
+             return ComponentScore(
+                 raw=raw,
+                 weighted=raw * self.default_weight,
+                 weight=self.default_weight,
+                 sub_scores={"no_action": -0.5},
+             )
+
+         true = scenario.ground_truth.label.value
+         secondary = trace.episode.final_action.secondary_actions
+
+         if not secondary:
+             return ComponentScore(
+                 raw=0.0,
+                 weighted=0.0,
+                 weight=self.default_weight,
+                 sub_scores={"no_secondary": 0.0},
+             )
+
+         sub_scores: dict[str, float] = {}
+         total = 0.0
+         for sa in secondary:
+             r = ACTION_REWARD_MATRIX.get((sa.name, true), 0.0)
+             sub_scores[sa.name] = r
+             total += r
+
+         capped = max(min(total, _RAW_MAX), _RAW_MIN)
+         return ComponentScore(
+             raw=capped,
+             weighted=capped * self.default_weight,
+             weight=self.default_weight,
+             sub_scores=sub_scores,
+         )
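A quick worked example of the sum-then-cap rule, using values straight from ACTION_REWARD_MATRIX above (illustrative, not part of the diff):

# Two good secondary actions on a real_bug scenario sum before capping.
from ci_triage_env.rewards.action_quality import ACTION_REWARD_MATRIX

total = (ACTION_REWARD_MATRIX[("file_bug", "real_bug")]       # +1.0
         + ACTION_REWARD_MATRIX[("ping_owner", "real_bug")])  # +0.4
capped = max(min(total, 1.5), -2.0)  # 1.4 — under the +1.5 cap, so no clipping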
src/ci_triage_env/rewards/anti_gaming.py ADDED
@@ -0,0 +1,86 @@
+ """AntiGamingReward — three guards against reward exploitation.
+
+ Guards:
+ 1. No-info-action: terminal with < 2 tool calls → -0.5
+ 2. Quarantine-rate: rolling-window over-use of quarantine_test → penalty
+ 3. Brier calibration: on ambiguous scenarios, penalises mis-calibrated confidence
+
+ Raw score range: [-1.5, 1.0]. Default weight: 0.15.
+ Quarantine-rate state is injected at construction; empty list → no penalty.
+ """
+
+ from __future__ import annotations
+
+ from ci_triage_env.rewards.base import RewardComponent
+ from ci_triage_env.schemas.action import ToolCall
+ from ci_triage_env.schemas.episode import EpisodeTrace
+ from ci_triage_env.schemas.reward import ComponentScore
+ from ci_triage_env.schemas.scenario import Scenario
+
+ _QUARANTINE_THRESHOLD = 0.30
+ _RAW_MIN = -1.5
+ _RAW_MAX = 1.0
+
+
+ class AntiGamingReward(RewardComponent):
+     """Guards against common reward-gaming strategies.
+
+     Raw score range: [-1.5, 1.0].
+
+     Args:
+         recent_episode_actions: Names of the secondary actions taken in
+             the last N episodes. Supplied by the trainer's rolling-window state.
+             Pass an empty list for unit tests (no quarantine-rate pressure).
+     """
+
+     name = "anti_gaming"
+     default_weight = 0.15
+
+     def __init__(self, recent_episode_actions: list[str] | None = None) -> None:
+         self.recent_actions: list[str] = recent_episode_actions or []
+
+     def score(self, trace: EpisodeTrace, scenario: Scenario) -> ComponentScore:
+         sub: dict[str, float] = {}
+
+         # Guard 1: must gather at least 2 tool calls before diagnosing
+         n_tool_calls = sum(1 for r in trace.episode.history if isinstance(r.action, ToolCall))
+         if trace.episode.final_action is not None and n_tool_calls < 2:
+             no_info_penalty = -0.5
+         else:
+             no_info_penalty = 0.0
+         sub["no_info_penalty"] = no_info_penalty
+
+         # Guard 2: quarantine over-use relative to a rolling window
+         quarantine_rate = self._compute_quarantine_rate()
+         if quarantine_rate > _QUARANTINE_THRESHOLD:
+             quarantine_penalty = -(quarantine_rate - _QUARANTINE_THRESHOLD) * 2.0
+         else:
+             quarantine_penalty = 0.0
+         sub["quarantine_rate"] = quarantine_rate
+         sub["quarantine_penalty"] = quarantine_penalty
+
+         # Guard 3: Brier calibration probe (ambiguous scenarios only)
+         brier_bonus = 0.0
+         if scenario.ground_truth.is_ambiguous:
+             target = scenario.ground_truth.confidence_target
+             if trace.episode.final_action is not None:
+                 pred_conf = trace.episode.final_action.confidence
+                 brier = (pred_conf - target) ** 2
+                 brier_bonus = 0.5 * (1.0 - brier)
+             else:
+                 brier_bonus = -0.5
+         sub["brier_bonus"] = brier_bonus
+
+         raw = no_info_penalty + quarantine_penalty + brier_bonus
+         raw = max(min(raw, _RAW_MAX), _RAW_MIN)
+         return ComponentScore(
+             raw=raw,
+             weighted=raw * self.default_weight,
+             weight=self.default_weight,
+             sub_scores=sub,
+         )
+
+     def _compute_quarantine_rate(self) -> float:
+         if not self.recent_actions:
+             return 0.0
+         return sum(1 for a in self.recent_actions if a == "quarantine_test") / len(self.recent_actions)
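Worked arithmetic for Guards 2 and 3 (illustrative; the rate and confidence values below are made up):

# Guard 2: 50% of recent episodes quarantined, vs. the 30% threshold.
rate = 0.5
quarantine_penalty = -(rate - 0.30) * 2.0    # = -0.4, linear above the threshold

# Guard 3: confidence 0.8 vs. confidence_target 0.5 on an ambiguous scenario.
brier = (0.8 - 0.5) ** 2                     # = 0.09
brier_bonus = 0.5 * (1.0 - brier)            # = 0.455 (perfect calibration would give 0.5)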
src/ci_triage_env/rewards/cost_efficiency.py ADDED
@@ -0,0 +1,36 @@
+ """CostEfficiencyReward — penalises high tool-call cost spend.
+
+ Raw score range: [-1.0, 1.0]. Default weight: 0.15.
+ Mapping: 0 cost → 1.0; full BUDGET_REFERENCE spend → -1.0.
+ Over-budget episodes are not possible (env enforces budget), so ratio is clamped at 1.0.
+ """
+
+ from __future__ import annotations
+
+ from ci_triage_env.rewards.base import RewardComponent
+ from ci_triage_env.schemas.episode import EpisodeTrace
+ from ci_triage_env.schemas.reward import ComponentScore
+ from ci_triage_env.schemas.scenario import Scenario
+
+
+ class CostEfficiencyReward(RewardComponent):
+     """Linear reward that decreases as total cost spent increases.
+
+     Raw score range: [-1.0, 1.0].
+     """
+
+     name = "cost_efficiency"
+     default_weight = 0.15
+
+     BUDGET_REFERENCE: float = 5.0
+
+     def score(self, trace: EpisodeTrace, scenario: Scenario) -> ComponentScore:
+         total_spent = sum(rec.cost_charged for rec in trace.episode.history)
+         ratio = total_spent / self.BUDGET_REFERENCE
+         raw = 1.0 - 2.0 * min(ratio, 1.0)
+         return ComponentScore(
+             raw=raw,
+             weighted=raw * self.default_weight,
+             weight=self.default_weight,
+             sub_scores={"total_cost": total_spent, "ratio": ratio},
+         )
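The linear cost-to-reward mapping at three spend levels (illustrative):

BUDGET_REFERENCE = 5.0
for spent in (0.0, 2.5, 5.0):
    raw = 1.0 - 2.0 * min(spent / BUDGET_REFERENCE, 1.0)
    print(spent, raw)  # 0.0 → 1.0, 2.5 → 0.0, 5.0 → -1.0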
src/ci_triage_env/rewards/counterfactual_predict.py ADDED
@@ -0,0 +1,56 @@
+ """CounterfactualPredictReward — DORMANT in v1.
+
+ Counterfactual probe is deferred to v2. In v1 the env never fires probes
+ (trace.counterfactual_replay is always None), so this component always returns
+ (raw=0.0, weight=0.0). The implementation is preserved so v2 re-enable is a
+ purely additive change: set default_weight to 0.10 in weights.py.
+
+ Raw score range: [-0.5, 1.0]. Default weight: 0.0 (dormant).
+ """
+
+ from __future__ import annotations
+
+ from ci_triage_env.rewards.base import RewardComponent
+ from ci_triage_env.schemas.episode import EpisodeTrace
+ from ci_triage_env.schemas.reward import ComponentScore
+ from ci_triage_env.schemas.scenario import Scenario
+
+
+ class CounterfactualPredictReward(RewardComponent):
+     """Rewards correct prediction of the counterfactual probe outcome.
+
+     DORMANT in v1: default_weight=0.0 and trace.counterfactual_replay is always
+     None, so score() always returns zero contribution.
+     Raw score range: [-0.5, 1.0].
+     """
+
+     name = "counterfactual"
+     default_weight = 0.0
+
+     def score(self, trace: EpisodeTrace, scenario: Scenario) -> ComponentScore:
+         # v1: probes never fire; replay list is None or empty
+         if not trace.counterfactual_replay:
+             return ComponentScore(
+                 raw=0.0,
+                 weighted=0.0,
+                 weight=self.default_weight,
+                 sub_scores={"fired": 0.0},
+             )
+
+         # v2 path (reachable only when probes are enabled):
+         # The replay records encode the probe action and its observed outcome.
+         # Compare the agent's predicted outcome (last record) vs actual terminal.
+         predicted_record = trace.counterfactual_replay[-1]
+         actual_record = trace.episode.history[-1] if trace.episode.history else None
+
+         if actual_record is not None and predicted_record.action == actual_record.action:
+             raw = 1.0
+         else:
+             raw = -0.5
+
+         return ComponentScore(
+             raw=raw,
+             weighted=raw * self.default_weight,
+             weight=self.default_weight,
+             sub_scores={"fired": 1.0},
+         )
src/ci_triage_env/rewards/diagnosis.py ADDED
@@ -0,0 +1,88 @@
+ """DiagnosisReward — asymmetric confusion-matrix reward.
+
+ Raw score range: [-1.0, 1.0]. Default weight: 0.25.
+ """
+
+ from __future__ import annotations
+
+ from ci_triage_env.rewards.base import RewardComponent
+ from ci_triage_env.schemas.episode import EpisodeTrace
+ from ci_triage_env.schemas.reward import ComponentScore
+ from ci_triage_env.schemas.scenario import Scenario
+
+ # (predicted, true) → reward
+ # Diagonal = 1.0; off-diagonal is asymmetric based on operational consequence.
+ # Worst: calling a real bug a flake (it ships to prod); best: correct identification.
+ DIAGNOSIS_REWARD_MATRIX: dict[tuple[str, str], float] = {
+     ("real_bug", "real_bug"): 1.0,
+     ("race_flake", "race_flake"): 1.0,
+     ("timing_flake", "timing_flake"): 1.0,
+     ("infra_network", "infra_network"): 1.0,
+     ("infra_resource", "infra_resource"): 1.0,
+     ("dependency_drift", "dependency_drift"): 1.0,
+     ("ambiguous", "ambiguous"): 1.0,
+     # Worst: predicting flake when it's a real bug (ships to prod)
+     ("race_flake", "real_bug"): -1.0,
+     ("timing_flake", "real_bug"): -1.0,
+     ("ambiguous", "real_bug"): -0.7,
+     # Bad: predicting infra when it's a real bug (file with wrong team)
+     ("infra_network", "real_bug"): -0.5,
+     ("infra_resource", "real_bug"): -0.5,
+     ("dependency_drift", "real_bug"): -0.4,
+     # Bad: predicting bug when it's a flake (false-alarm noise)
+     ("real_bug", "race_flake"): -0.3,
+     ("real_bug", "timing_flake"): -0.3,
+     # Bad: predicting bug when it's infra (wastes engineering time)
+     ("real_bug", "infra_network"): -0.4,
+     ("real_bug", "infra_resource"): -0.4,
+     ("real_bug", "dependency_drift"): -0.2,
+     # Mild: confusing similar families
+     ("race_flake", "timing_flake"): 0.2,
+     ("timing_flake", "race_flake"): 0.2,
+     ("infra_network", "infra_resource"): 0.1,
+     ("infra_resource", "infra_network"): 0.1,
+     # Abstaining on clear non-real-bug causes
+     ("ambiguous", "race_flake"): 0.0,
+     ("ambiguous", "timing_flake"): 0.0,
+     ("ambiguous", "infra_network"): 0.0,
+     ("ambiguous", "infra_resource"): 0.0,
+     ("ambiguous", "dependency_drift"): 0.0,
+ }
+
+ _DEFAULT_OFF_DIAGONAL = -0.5
+
+
+ def lookup_reward(predicted: str, true: str) -> float:
+     return DIAGNOSIS_REWARD_MATRIX.get((predicted, true), _DEFAULT_OFF_DIAGONAL)
+
+
+ class DiagnosisReward(RewardComponent):
+     """Reward based on predicted vs. true failure family.
+
+     Raw score range: [-1.0, 1.0]. No-terminal penalty: -1.0.
+     """
+
+     name = "diagnosis"
+     default_weight = 0.25
+
+     def score(self, trace: EpisodeTrace, scenario: Scenario) -> ComponentScore:
+         if trace.episode.final_action is None:
+             raw = -1.0
+             return ComponentScore(
+                 raw=raw,
+                 weighted=raw * self.default_weight,
+                 weight=self.default_weight,
+                 sub_scores={"no_diagnosis": -1.0},
+             )
+         predicted = trace.episode.final_action.diagnosis.value
+         true = scenario.ground_truth.label.value
+         raw = lookup_reward(predicted, true)
+         return ComponentScore(
+             raw=raw,
+             weighted=raw * self.default_weight,
+             weight=self.default_weight,
+             sub_scores={
+                 "matrix_lookup": raw,
+                 "predicted": 1.0 if predicted == true else 0.0,
+             },
+         )
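Example lookups, including the -0.5 default for pairs not listed in the matrix (illustrative):

from ci_triage_env.rewards.diagnosis import lookup_reward

assert lookup_reward("real_bug", "real_bug") == 1.0     # diagonal
assert lookup_reward("race_flake", "real_bug") == -1.0  # worst: flake call on a real bug
assert lookup_reward("timing_flake", "infra_network") == -0.5  # unlisted pair → default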
src/ci_triage_env/rewards/format_gate.py ADDED
@@ -0,0 +1,71 @@
+ """FormatGate — validates trajectory schema compliance.
+
+ Returns 1.0 (all records valid) or 0.0 (first violation found).
+ Raw score range: {0.0, 1.0}. Used as a multiplicative gate in composite.
+ """
+
+ from __future__ import annotations
+
+ import jsonschema
+
+ from ci_triage_env.rewards.base import RewardComponent
+ from ci_triage_env.schemas.action import TerminalAction, ToolCall
+ from ci_triage_env.schemas.diagnosis import DiagnosisLabel
+ from ci_triage_env.schemas.episode import EpisodeTrace
+ from ci_triage_env.schemas.reward import ComponentScore
+ from ci_triage_env.schemas.scenario import Scenario
+ from ci_triage_env.schemas.tools import ALL_TOOLS
+
+ TOOL_DEF_BY_NAME: dict = {t.name: t for t in ALL_TOOLS}
+
+
+ class FormatGate(RewardComponent):
+     """Validates every ToolCall args against the tool's args_schema and every
+     TerminalAction against the DiagnosisLabel enum + confidence bounds.
+
+     Returns 0.0 (gate fails) or 1.0 (passes). The composite uses this as a
+     multiplicative gate: total = format_gate * weighted_sum.
+     """
+
+     name = "format_gate"
+     default_weight = 1.0
+
+     def score(self, trace: EpisodeTrace, scenario: Scenario) -> ComponentScore:
+         for record in trace.episode.history:
+             if isinstance(record.action, ToolCall):
+                 tool_def = TOOL_DEF_BY_NAME.get(record.action.tool_name)
+                 if tool_def is None:
+                     return self._fail("unknown_tool")
+                 try:
+                     jsonschema.validate(record.action.args, tool_def.args_schema)
+                 except jsonschema.ValidationError:
+                     return self._fail("args_invalid")
+             elif isinstance(record.action, TerminalAction):
+                 if record.action.diagnosis not in DiagnosisLabel:
+                     return self._fail("invalid_diagnosis")
+                 if not (0.0 <= record.action.confidence <= 1.0):
+                     return self._fail("confidence_oob")
+
+         # v1: counterfactual_replay is a list of StepRecords or None; probes never fire
+         if trace.counterfactual_replay is not None and len(trace.counterfactual_replay) > 0:
+             # Any probe records must themselves contain valid actions
+             for record in trace.counterfactual_replay:
+                 if isinstance(record.action, ToolCall):
+                     tool_def = TOOL_DEF_BY_NAME.get(record.action.tool_name)
+                     if tool_def is None:
+                         return self._fail("probe_unknown_tool")
+
+         return ComponentScore(
+             raw=1.0,
+             weighted=1.0,
+             weight=self.default_weight,
+             sub_scores={"valid": 1.0},
+         )
+
+     def _fail(self, reason: str) -> ComponentScore:
+         # Record the failure reason as the sub-score key (values must be floats).
+         return ComponentScore(
+             raw=0.0,
+             weighted=0.0,
+             weight=self.default_weight,
+             sub_scores={reason: 0.0},
+         )
src/ci_triage_env/rewards/investigation.py ADDED
@@ -0,0 +1,87 @@
+ """InvestigationReward — shaping reward for evidence-gathering quality.
+
+ Combines:
+ - coverage: fraction of informative_tools that were called (weight 0.6)
+ - ordering: cheap-before-expensive bonus (weight 0.2)
+ - redundancy_penalty: -0.1 per duplicate (tool_name, args) call
+
+ Raw score range: [-1.0, 1.0]. Default weight: 0.15.
+ """
+
+ from __future__ import annotations
+
+ import json
+
+ from ci_triage_env.rewards.base import RewardComponent
+ from ci_triage_env.schemas.action import ToolCall
+ from ci_triage_env.schemas.episode import EpisodeTrace
+ from ci_triage_env.schemas.reward import ComponentScore
+ from ci_triage_env.schemas.scenario import Scenario
+
+ _CHEAP_TOOLS = frozenset({
+     "read_logs", "query_flake_history", "recent_commits",
+     "check_owner", "inspect_test_code", "cluster_metrics",
+ })
+ _EXPENSIVE_TOOLS = frozenset({
+     "rerun_test", "run_diagnostic", "file_bug", "ping_owner", "quarantine_test",
+ })
+
+
+ class InvestigationReward(RewardComponent):
+     """Shaping reward for how well the agent investigates the failure.
+
+     Raw score range: [-1.0, 1.0].
+     """
+
+     name = "investigation"
+     default_weight = 0.15
+
+     def score(self, trace: EpisodeTrace, scenario: Scenario) -> ComponentScore:
+         called_tools = [
+             rec.action.tool_name
+             for rec in trace.episode.history
+             if isinstance(rec.action, ToolCall)
+         ]
+
+         # Coverage: fraction of informative_tools called (unique tools only,
+         # so repeated calls cannot inflate coverage past 1.0)
+         informative = set(scenario.informative_tools)
+         called_informative = len(set(called_tools) & informative)
+         coverage = called_informative / max(len(informative), 1)
+
+         # Redundancy: duplicate (tool_name, sorted-args-json) calls
+         seen_calls: set[tuple[str, str]] = set()
+         redundancy_count = 0
+         for rec in trace.episode.history:
+             if isinstance(rec.action, ToolCall):
+                 key = (rec.action.tool_name, json.dumps(rec.action.args, sort_keys=True))
+                 if key in seen_calls:
+                     redundancy_count += 1
+                 seen_calls.add(key)
+         redundancy_penalty = -0.1 * redundancy_count
+
+         # Ordering: cheap tools should precede expensive tools
+         ordering = self._compute_ordering_score(called_tools)
+
+         raw = 0.6 * coverage + 0.2 * ordering + redundancy_penalty
+         raw = max(min(raw, 1.0), -1.0)
+
+         return ComponentScore(
+             raw=raw,
+             weighted=raw * self.default_weight,
+             weight=self.default_weight,
+             sub_scores={
+                 "coverage": coverage,
+                 "ordering": ordering,
+                 "redundancy_penalty": redundancy_penalty,
+             },
+         )
+
+     def _compute_ordering_score(self, tools: list[str]) -> float:
+         violations = 0
+         seen_expensive = False
+         for t in tools:
+             if t in _EXPENSIVE_TOOLS:
+                 seen_expensive = True
+             elif t in _CHEAP_TOOLS and seen_expensive:
+                 violations += 1
+         return max(1.0 - 0.2 * violations, 0.0)
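The shaping formula on a small hypothetical trajectory (illustrative arithmetic only):

coverage = 2 / 3                    # 2 of 3 informative tools called
ordering = max(1.0 - 0.2 * 1, 0.0)  # one cheap-after-expensive violation → 0.8
redundancy_penalty = -0.1 * 1       # one duplicate (tool, args) call
raw = 0.6 * coverage + 0.2 * ordering + redundancy_penalty  # ≈ 0.46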
src/ci_triage_env/rewards/minimal_evidence.py ADDED
@@ -0,0 +1,65 @@
+ """MinimalEvidenceReward — bonus for diagnosing correctly with the minimal tool set.
+
+ Default weight: 0.0 — this component is NOT in the additive composite directly.
+ In Phase C2 its score modifies the InvestigationReward via a multiplier.
+ Raw score range: [-0.5, 1.0].
+ """
+
+ from __future__ import annotations
+
+ from ci_triage_env.rewards.base import RewardComponent
+ from ci_triage_env.schemas.action import ToolCall
+ from ci_triage_env.schemas.episode import EpisodeTrace
+ from ci_triage_env.schemas.reward import ComponentScore
+ from ci_triage_env.schemas.scenario import Scenario
+
+
+ class MinimalEvidenceReward(RewardComponent):
+     """Bonus when the agent reaches the correct diagnosis using only the minimal evidence set.
+
+     If minimal_evidence_set is empty (ambiguous scenarios), returns 0.0.
+     Raw score range: [-0.5, 1.0]. Default weight: 0.0 (folded into InvestigationReward in C2).
+     """
+
+     name = "minimal_evidence"
+     default_weight = 0.0
+
+     def score(self, trace: EpisodeTrace, scenario: Scenario) -> ComponentScore:
+         min_set = set(scenario.minimal_evidence_set)
+         if not min_set:
+             return ComponentScore(
+                 raw=0.0, weighted=0.0, weight=self.default_weight, sub_scores={}
+             )
+
+         called = {
+             rec.action.tool_name
+             for rec in trace.episode.history
+             if isinstance(rec.action, ToolCall)
+         }
+
+         final = trace.episode.final_action
+         correct_diagnosis = (
+             final is not None
+             and final.diagnosis.value == scenario.ground_truth.label.value
+         )
+
+         if correct_diagnosis:
+             min_used = called & min_set
+             extra = called - min_set
+             if min_used == min_set:
+                 # All minimal evidence used; small penalty for extras
+                 bonus = max(min(1.0 - 0.1 * len(extra), 1.0), -0.5)
+             else:
+                 bonus = 0.3  # correct answer but didn't use all key evidence
+         else:
+             bonus = 0.0
+
+         return ComponentScore(
+             raw=bonus,
+             weighted=bonus * self.default_weight,
+             weight=self.default_weight,
+             sub_scores={
+                 "min_set_used": float(len(called & min_set)),
+                 "extras": float(len(called - min_set)),
+             },
+         )
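The bonus arithmetic for a correct diagnosis that used the full minimal set plus one extra tool (illustrative; the tool names are taken from the mock comments elsewhere in this commit):

min_set = {"read_logs", "query_flake_history"}
called = {"read_logs", "query_flake_history", "recent_commits"}  # one extra tool
extras = called - min_set
bonus = max(min(1.0 - 0.1 * len(extras), 1.0), -0.5)  # = 0.9
# A correct diagnosis that skipped part of the min set would instead score the flat 0.3.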
src/ci_triage_env/rewards/time_penalty.py ADDED
@@ -0,0 +1,37 @@
+ """TimePenaltyReward — penalises episodes that take more than REFERENCE_STEPS tool calls.
+
+ Raw score range: [-1.0, 0.0]. Default weight: 0.10.
+ """
+
+ from __future__ import annotations
+
+ from ci_triage_env.rewards.base import RewardComponent
+ from ci_triage_env.schemas.action import ToolCall
+ from ci_triage_env.schemas.episode import EpisodeTrace
+ from ci_triage_env.schemas.reward import ComponentScore
+ from ci_triage_env.schemas.scenario import Scenario
+
+
+ class TimePenaltyReward(RewardComponent):
+     """Linear per-step penalty beyond REFERENCE_STEPS tool calls.
+
+     0 to REFERENCE_STEPS calls → 0.0. Each extra step → -PER_STEP_PENALTY.
+     Floor at -1.0. Raw score range: [-1.0, 0.0].
+     """
+
+     name = "time"
+     default_weight = 0.10
+
+     PER_STEP_PENALTY: float = 0.02
+     REFERENCE_STEPS: int = 6
+
+     def score(self, trace: EpisodeTrace, scenario: Scenario) -> ComponentScore:
+         steps = sum(1 for r in trace.episode.history if isinstance(r.action, ToolCall))
+         excess = max(0, steps - self.REFERENCE_STEPS)
+         raw = max(-self.PER_STEP_PENALTY * excess, -1.0)
+         return ComponentScore(
+             raw=raw,
+             weighted=raw * self.default_weight,
+             weight=self.default_weight,
+             sub_scores={"steps": float(steps), "excess": float(excess)},
+         )
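The penalty at 9 tool calls against the 6-step reference (illustrative):

excess = max(0, 9 - 6)               # 3 steps over
raw = max(-0.02 * excess, -1.0)      # = -0.06; weighted = -0.006 at weight 0.10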
tests/rewards/__init__.py ADDED
File without changes
tests/rewards/test_action_quality.py ADDED
@@ -0,0 +1,90 @@
+ """Tests for ActionQualityReward component."""
+
+ from __future__ import annotations
+
+ from ci_triage_env.mock import make_mock_scenario, make_mock_trajectory
+ from ci_triage_env.rewards.action_quality import ACTION_REWARD_MATRIX, ActionQualityReward
+ from ci_triage_env.schemas.action import SecondaryAction, TerminalAction
+
+
+ def _patch_secondary(trace, secondary_actions):
+     new_terminal = TerminalAction(
+         action_type="submit_diagnosis",
+         diagnosis=trace.episode.final_action.diagnosis,
+         confidence=trace.episode.final_action.confidence,
+         secondary_actions=secondary_actions,
+     )
+     return trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"final_action": new_terminal})}
+     )
+
+
+ def test_action_quality_correct_case_returns_high_score() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     patched = _patch_secondary(trace, [SecondaryAction(name="file_bug", args={})])
+     score = ActionQualityReward().score(patched, scenario)
+     assert score.raw > 0.5
+
+
+ def test_action_quality_wrong_case_returns_low_score() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     # quarantine_test on real_bug is catastrophically bad
+     patched = _patch_secondary(trace, [SecondaryAction(name="quarantine_test", args={})])
+     score = ActionQualityReward().score(patched, scenario)
+     assert score.raw < 0.0
+
+
+ def test_action_quality_handles_no_terminal_action() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     no_terminal = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"final_action": None})}
+     )
+     score = ActionQualityReward().score(no_terminal, scenario)
+     assert score.raw == -0.5
+
+
+ def test_action_quality_deterministic() -> None:
+     scenario = make_mock_scenario("race_flake")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     comp = ActionQualityReward()
+     s1 = comp.score(trace, scenario)
+     s2 = comp.score(trace, scenario)
+     assert s1.raw == s2.raw
+
+
+ def test_action_quality_score_is_in_documented_range() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     # Maximum stacking: multiple good actions
+     patched = _patch_secondary(trace, [
+         SecondaryAction(name="file_bug", args={}),
+         SecondaryAction(name="ping_owner", args={}),
+     ])
+     score = ActionQualityReward().score(patched, scenario)
+     assert -2.0 <= score.raw <= 1.5
+
+
+ def test_action_quality_subscores_are_meaningful() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     patched = _patch_secondary(trace, [SecondaryAction(name="file_bug", args={})])
+     score = ActionQualityReward().score(patched, scenario)
+     assert "file_bug" in score.sub_scores
+
+
+ def test_quarantine_real_bug_is_worst() -> None:
+     worst = ACTION_REWARD_MATRIX[("quarantine_test", "real_bug")]
+     assert worst == -1.5
+     all_values = list(ACTION_REWARD_MATRIX.values())
+     assert all(v >= worst for v in all_values)
+
+
+ def test_action_quality_no_secondary_neutral() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     # Default mock trajectory has no secondary actions
+     score = ActionQualityReward().score(trace, scenario)
+     assert score.raw == 0.0
tests/rewards/test_anti_gaming.py ADDED
@@ -0,0 +1,102 @@
+ """Tests for AntiGamingReward component."""
+
+ from __future__ import annotations
+
+ import pytest
+
+ from ci_triage_env.mock import make_mock_scenario, make_mock_trajectory
+ from ci_triage_env.rewards.anti_gaming import AntiGamingReward
+ from ci_triage_env.schemas.action import ToolCall
+ from ci_triage_env.schemas.episode import StepRecord
+ from ci_triage_env.schemas.observation import BudgetState, Observation
+
+
+ def _dummy_obs() -> Observation:
+     return Observation(
+         episode_id="test",
+         step=0,
+         failure_summary=None,
+         tool_response=None,
+         budget_remaining=BudgetState(tool_calls_remaining=10, cost_remaining=1.0),
+         is_terminal=False,
+         probe_question=None,
+     )
+
+
+ def test_anti_gaming_correct_case_returns_high_score() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     score = AntiGamingReward().score(trace, scenario)
+     # Good trajectory: ≥ 2 tool calls, no quarantine abuse, non-ambiguous
+     assert score.raw >= 0.0
+
+
+ def test_anti_gaming_wrong_case_returns_low_score() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     # Simulate no-info-action: only 1 tool call before terminal
+     one_tool = [
+         StepRecord(step=0, action=ToolCall(tool_name="read_logs", args={"scope": "full"}),
+                    observation=_dummy_obs(), cost_charged=0.001)
+     ]
+     patched = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"history": one_tool})}
+     )
+     score = AntiGamingReward().score(patched, scenario)
+     assert score.raw <= -0.5
+
+
+ def test_anti_gaming_handles_no_terminal_action() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     no_terminal = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"final_action": None})}
+     )
+     score = AntiGamingReward().score(no_terminal, scenario)
+     # No terminal → no no-info-action penalty; result depends on other guards
+     assert -1.5 <= score.raw <= 1.0
+
+
+ def test_anti_gaming_deterministic() -> None:
+     scenario = make_mock_scenario("race_flake")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     comp = AntiGamingReward()
+     s1 = comp.score(trace, scenario)
+     s2 = comp.score(trace, scenario)
+     assert s1.raw == s2.raw
+
+
+ def test_anti_gaming_score_is_in_documented_range() -> None:
+     for family in ["real_bug", "race_flake", "ambiguous"]:
+         scenario = make_mock_scenario(family)
+         for outcome in ["good", "bad"]:
+             trace = make_mock_trajectory(scenario, outcome=outcome)
+             score = AntiGamingReward().score(trace, scenario)
+             assert -1.5 <= score.raw <= 1.0
+
+
+ def test_anti_gaming_subscores_are_meaningful() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     score = AntiGamingReward().score(trace, scenario)
+     assert "no_info_penalty" in score.sub_scores
+     assert "quarantine_rate" in score.sub_scores
+     assert "brier_bonus" in score.sub_scores
+
+
+ def test_brier_calibration_perfect_match_bonus() -> None:
+     scenario = make_mock_scenario("ambiguous")
+     # confidence_target=0.5 for ambiguous mock; abstain trajectory uses confidence=0.5
+     trace = make_mock_trajectory(scenario, outcome="abstain")
+     score = AntiGamingReward().score(trace, scenario)
+     # Perfect match: brier=(0.5-0.5)^2=0; bonus=0.5*(1-0)=0.5
+     assert score.sub_scores["brier_bonus"] == pytest.approx(0.5, abs=1e-6)
+
+
+ def test_quarantine_rate_above_threshold_penalizes() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     # 100% quarantine rate → well above 30% threshold
+     comp = AntiGamingReward(recent_episode_actions=["quarantine_test"] * 50)
+     score = comp.score(trace, scenario)
+     assert score.sub_scores["quarantine_penalty"] < 0.0
tests/rewards/test_cost_efficiency.py ADDED
@@ -0,0 +1,70 @@
+ """Tests for CostEfficiencyReward component."""
+
+ from __future__ import annotations
+
+ from ci_triage_env.mock import make_mock_scenario, make_mock_trajectory
+ from ci_triage_env.rewards.cost_efficiency import CostEfficiencyReward
+
+
+ def test_cost_efficiency_correct_case_returns_high_score() -> None:
+     # Low-cost trajectory (mock costs ~0.016) → should be positive
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     score = CostEfficiencyReward().score(trace, scenario)
+     assert score.raw > 0.0
+
+
+ def test_cost_efficiency_wrong_case_returns_low_score() -> None:
+     # Simulate a trajectory that spent the full budget
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     # Patch each step so the history sums to the full BUDGET_REFERENCE
+     budget_each = CostEfficiencyReward.BUDGET_REFERENCE / max(len(trace.episode.history), 1)
+     patched_history = [
+         r.model_copy(update={"cost_charged": budget_each})
+         for r in trace.episode.history
+     ]
+     patched = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"history": patched_history})}
+     )
+     score = CostEfficiencyReward().score(patched, scenario)
+     assert score.raw <= -0.9
+
+
+ def test_cost_efficiency_handles_no_terminal_action() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     no_terminal = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"final_action": None})}
+     )
+     # Still scores based on cost; no terminal doesn't affect this component
+     score = CostEfficiencyReward().score(no_terminal, scenario)
+     assert -1.0 <= score.raw <= 1.0
+
+
+ def test_cost_efficiency_deterministic() -> None:
+     scenario = make_mock_scenario("race_flake")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     comp = CostEfficiencyReward()
+     s1 = comp.score(trace, scenario)
+     s2 = comp.score(trace, scenario)
+     assert s1.raw == s2.raw
+
+
+ def test_cost_efficiency_score_is_in_documented_range() -> None:
+     for family in ["real_bug", "race_flake", "ambiguous"]:
+         scenario = make_mock_scenario(family)
+         for outcome in ["good", "bad"]:
+             trace = make_mock_trajectory(scenario, outcome=outcome)
+             score = CostEfficiencyReward().score(trace, scenario)
+             assert -1.0 <= score.raw <= 1.0
+
+
+ def test_cost_efficiency_subscores_are_meaningful() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     score = CostEfficiencyReward().score(trace, scenario)
+     assert "total_cost" in score.sub_scores
+     assert "ratio" in score.sub_scores
+     assert score.sub_scores["total_cost"] >= 0.0
tests/rewards/test_counterfactual_predict.py ADDED
@@ -0,0 +1,78 @@
+ """Tests for CounterfactualPredictReward component — dormant in v1."""
+
+ from __future__ import annotations
+
+ from ci_triage_env.mock import make_mock_scenario, make_mock_trajectory
+ from ci_triage_env.rewards.counterfactual_predict import CounterfactualPredictReward
+
+
+ def test_counterfactual_correct_case_returns_high_score() -> None:
+     # In v1 probes never fire → always returns 0.0 even in "good" trajectory
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     score = CounterfactualPredictReward().score(trace, scenario)
+     assert score.raw == 0.0
+
+
+ def test_counterfactual_wrong_case_returns_low_score() -> None:
+     # v1: still 0.0 since counterfactual_replay is always None
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="bad")
+     score = CounterfactualPredictReward().score(trace, scenario)
+     assert score.raw == 0.0
+
+
+ def test_no_probe_returns_zero() -> None:
+     scenario = make_mock_scenario("race_flake")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     assert trace.counterfactual_replay is None
+     score = CounterfactualPredictReward().score(trace, scenario)
+     assert score.raw == 0.0
+     assert score.weighted == 0.0
+
+
+ def test_counterfactual_handles_no_terminal_action() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     no_terminal = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"final_action": None})}
+     )
+     score = CounterfactualPredictReward().score(no_terminal, scenario)
+     assert score.raw == 0.0
+
+
+ def test_counterfactual_deterministic() -> None:
+     scenario = make_mock_scenario("race_flake")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     comp = CounterfactualPredictReward()
+     s1 = comp.score(trace, scenario)
+     s2 = comp.score(trace, scenario)
+     assert s1.raw == s2.raw
+
+
+ def test_counterfactual_score_is_in_documented_range() -> None:
+     for family in ["real_bug", "race_flake", "ambiguous"]:
+         scenario = make_mock_scenario(family)
+         trace = make_mock_trajectory(scenario, outcome="good")
+         score = CounterfactualPredictReward().score(trace, scenario)
+         # v1: always 0.0; generally in [-0.5, 1.0]
+         assert -0.5 <= score.raw <= 1.0
+
+
+ def test_counterfactual_subscores_are_meaningful() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     score = CounterfactualPredictReward().score(trace, scenario)
+     assert "fired" in score.sub_scores
+
+
+ def test_v1_default_weight_is_zero() -> None:
+     assert CounterfactualPredictReward.default_weight == 0.0
+
+
+ def test_v1_weighted_always_zero() -> None:
+     for family in ["real_bug", "ambiguous"]:
+         scenario = make_mock_scenario(family)
+         trace = make_mock_trajectory(scenario, outcome="good")
+         score = CounterfactualPredictReward().score(trace, scenario)
+         assert score.weighted == 0.0
tests/rewards/test_diagnosis.py ADDED
@@ -0,0 +1,75 @@
+ """Tests for DiagnosisReward component."""
+
+ from __future__ import annotations
+
+ from ci_triage_env.mock import make_mock_scenario, make_mock_trajectory
+ from ci_triage_env.rewards.diagnosis import DIAGNOSIS_REWARD_MATRIX, DiagnosisReward
+
+ ALL_FAMILIES = ["real_bug", "race_flake", "timing_flake", "infra_network", "infra_resource",
+                 "dependency_drift", "ambiguous"]
+
+
+ def test_diagnosis_correct_case_returns_high_score() -> None:
+     for family in ALL_FAMILIES:
+         scenario = make_mock_scenario(family)
+         trace = make_mock_trajectory(scenario, outcome="good")
+         score = DiagnosisReward().score(trace, scenario)
+         assert score.raw == 1.0, f"family={family}: expected 1.0 got {score.raw}"
+
+
+ def test_diagnosis_wrong_case_returns_low_score() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="bad")
+     score = DiagnosisReward().score(trace, scenario)
+     assert score.raw < 0.0
+
+
+ def test_diagnosis_handles_no_terminal_action() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     no_terminal = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"final_action": None})}
+     )
+     score = DiagnosisReward().score(no_terminal, scenario)
+     assert score.raw == -1.0
+     assert score.sub_scores.get("no_diagnosis") == -1.0
+
+
+ def test_diagnosis_deterministic() -> None:
+     scenario = make_mock_scenario("race_flake")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     comp = DiagnosisReward()
+     s1 = comp.score(trace, scenario)
+     s2 = comp.score(trace, scenario)
+     assert s1.raw == s2.raw
+
+
+ def test_diagnosis_score_is_in_documented_range() -> None:
+     for family in ALL_FAMILIES:
+         scenario = make_mock_scenario(family)
+         for outcome in ["good", "bad"]:
+             trace = make_mock_trajectory(scenario, outcome=outcome)
+             score = DiagnosisReward().score(trace, scenario)
+             assert -1.0 <= score.raw <= 1.0, f"out of range: family={family} outcome={outcome}"
+
+
+ def test_diagnosis_subscores_are_meaningful() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     score = DiagnosisReward().score(trace, scenario)
+     assert "matrix_lookup" in score.sub_scores
+
+
+ def test_diagonal_matches_return_one() -> None:
+     families = ALL_FAMILIES
+     for f in families:
+         assert DIAGNOSIS_REWARD_MATRIX.get((f, f)) == 1.0, f"diagonal {f} is not 1.0"
+
+
+ def test_flake_on_real_bug_is_most_negative_diagnosis() -> None:
+     # The diagnosis matrix's worst entry is predicting flake on a real_bug: -1.0
+     flake_on_real = DIAGNOSIS_REWARD_MATRIX[("race_flake", "real_bug")]
+     assert flake_on_real == -1.0
+     # Every other (predicted, "real_bug") entry should be >= flake_on_real
+     real_bug_penalties = [v for (p, t), v in DIAGNOSIS_REWARD_MATRIX.items() if t == "real_bug" and p != "real_bug"]
+     assert all(v >= flake_on_real for v in real_bug_penalties)
tests/rewards/test_format_gate.py ADDED
@@ -0,0 +1,144 @@
+ """Tests for FormatGate reward component."""
+
+ from __future__ import annotations
+
+ from ci_triage_env.mock import make_mock_scenario, make_mock_trajectory
+ from ci_triage_env.rewards.format_gate import FormatGate
+ from ci_triage_env.schemas.action import ToolCall
+ from ci_triage_env.schemas.episode import StepRecord
+ from ci_triage_env.schemas.observation import BudgetState, Observation
+
+
+ def _dummy_obs() -> Observation:
+     return Observation(
+         episode_id="test",
+         step=0,
+         failure_summary=None,
+         tool_response=None,
+         budget_remaining=BudgetState(tool_calls_remaining=10, cost_remaining=1.0),
+         is_terminal=False,
+         probe_question=None,
+     )
+
+
+ def test_format_gate_correct_case_returns_high_score() -> None:
+     # Build a trajectory with only valid read_logs calls (scope is required)
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     # Replace history with a single valid tool call
+     valid_record = StepRecord(
+         step=0,
+         action=ToolCall(tool_name="read_logs", args={"scope": "full"}),
+         observation=_dummy_obs(),
+         cost_charged=0.001,
+     )
+     patched = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"history": [valid_record]})}
+     )
+     score = FormatGate().score(patched, scenario)
+     assert score.raw == 1.0
+
+
+ def test_format_gate_wrong_case_returns_low_score() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+
+     # Inject an unknown tool name
+     bad_action = ToolCall(tool_name="__nonexistent_tool__", args={})
+     bad_record = StepRecord(step=99, action=bad_action, observation=_dummy_obs(), cost_charged=0.0)
+     patched = trace.model_copy(
+         update={"episode": trace.episode.model_copy(
+             update={"history": trace.episode.history + [bad_record]}
+         )}
+     )
+     score = FormatGate().score(patched, scenario)
+     assert score.raw == 0.0
+
+
+ def test_format_gate_handles_no_terminal_action() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     # Build trajectory with only valid tool calls and no terminal action
+     valid_record = StepRecord(
+         step=0,
+         action=ToolCall(tool_name="read_logs", args={"scope": "full"}),
+         observation=_dummy_obs(),
+         cost_charged=0.001,
+     )
+     patched = trace.model_copy(
+         update={"episode": trace.episode.model_copy(
+             update={"history": [valid_record], "final_action": None}
+         )}
+     )
+     score = FormatGate().score(patched, scenario)
+     # No terminal action → still valid if all tool calls are valid
+     assert score.raw == 1.0
+
+
+ def test_format_gate_deterministic() -> None:
+     scenario = make_mock_scenario("race_flake")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     gate = FormatGate()
+     s1 = gate.score(trace, scenario)
+     s2 = gate.score(trace, scenario)
+     assert s1.raw == s2.raw
+
+
+ def test_format_gate_score_is_in_documented_range() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     # Valid trajectory (single valid tool call)
+     valid_record = StepRecord(
+         step=0,
+         action=ToolCall(tool_name="read_logs", args={"scope": "full"}),
+         observation=_dummy_obs(),
+         cost_charged=0.001,
+     )
+     valid_trace = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"history": [valid_record]})}
+     )
+     # Invalid trajectory (unknown tool)
+     bad_record = StepRecord(
+         step=0,
+         action=ToolCall(tool_name="__bad__", args={}),
+         observation=_dummy_obs(),
+         cost_charged=0.0,
+     )
+     invalid_trace = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"history": [bad_record]})}
+     )
+     assert FormatGate().score(valid_trace, scenario).raw == 1.0
+     assert FormatGate().score(invalid_trace, scenario).raw == 0.0
+
+
+ def test_format_gate_subscores_are_meaningful() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     # Use a valid trajectory to get the "valid" key
+     valid_record = StepRecord(
+         step=0,
+         action=ToolCall(tool_name="read_logs", args={"scope": "full"}),
+         observation=_dummy_obs(),
+         cost_charged=0.001,
+     )
+     patched = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"history": [valid_record]})}
+     )
+     score = FormatGate().score(patched, scenario)
+     assert "valid" in score.sub_scores
+
+
+ def test_format_gate_invalid_args_fails() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+
+     # read_logs requires "scope" arg — inject one without it
+     bad_action = ToolCall(tool_name="read_logs", args={})  # missing required "scope"
+     bad_record = StepRecord(step=99, action=bad_action, observation=_dummy_obs(), cost_charged=0.0)
+     patched = trace.model_copy(
+         update={"episode": trace.episode.model_copy(
+             update={"history": [bad_record]}
+         )}
+     )
+     score = FormatGate().score(patched, scenario)
+     assert score.raw == 0.0
tests/rewards/test_investigation.py ADDED
@@ -0,0 +1,111 @@
+ """Tests for InvestigationReward component."""
+
+ from __future__ import annotations
+
+ from ci_triage_env.mock import make_mock_scenario, make_mock_trajectory
+ from ci_triage_env.rewards.investigation import InvestigationReward
+ from ci_triage_env.schemas.action import ToolCall
+ from ci_triage_env.schemas.episode import StepRecord
+ from ci_triage_env.schemas.observation import BudgetState, Observation
+
+
+ def _dummy_obs(step: int = 0) -> Observation:
+     return Observation(
+         episode_id="test",
+         step=step,
+         failure_summary=None,
+         tool_response=None,
+         budget_remaining=BudgetState(tool_calls_remaining=10, cost_remaining=1.0),
+         is_terminal=False,
+         probe_question=None,
+     )
+
+
+ def _make_tool_record(tool_name: str, args: dict, step: int = 0) -> StepRecord:
+     return StepRecord(
+         step=step,
+         action=ToolCall(tool_name=tool_name, args=args),
+         observation=_dummy_obs(step),
+         cost_charged=0.001,
+     )
+
+
+ def test_investigation_correct_case_returns_high_score() -> None:
+     # Trajectory that calls all informative tools in cheap-before-expensive order
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     score = InvestigationReward().score(trace, scenario)
+     # Mock trajectory calls read_logs, query_flake_history, recent_commits, rerun_test;
+     # informative tools for the real_bug mock: read_logs, query_flake_history, rerun_test
+     assert score.raw > 0.0
+
+
+ def test_investigation_wrong_case_returns_low_score() -> None:
+     # Trajectory with no informative tools called
+     # (mock informative_tools = read_logs, query_flake_history, rerun_test)
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     # Use only tools NOT in informative_tools and in wrong order (expensive first, cheap second)
+     no_informative = [
+         _make_tool_record("ping_owner", {}, step=0),      # expensive, not informative
+         _make_tool_record("recent_commits", {}, step=1),  # cheap after expensive = ordering violation
+     ]
+     patched = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"history": no_informative})}
+     )
+     score = InvestigationReward().score(patched, scenario)
+     # coverage=0.0, ordering=0.8 (1 violation), redundancy=0 → raw=0.6*0+0.2*0.8=0.16
+     assert score.raw <= 0.2
+
+
+ def test_investigation_handles_no_terminal_action() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     no_terminal = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"final_action": None})}
+     )
+     score = InvestigationReward().score(no_terminal, scenario)
+     assert -1.0 <= score.raw <= 1.0
+
+
+ def test_investigation_deterministic() -> None:
+     scenario = make_mock_scenario("race_flake")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     comp = InvestigationReward()
+     s1 = comp.score(trace, scenario)
+     s2 = comp.score(trace, scenario)
+     assert s1.raw == s2.raw
+
+
+ def test_investigation_score_is_in_documented_range() -> None:
+     for family in ["real_bug", "race_flake", "timing_flake", "ambiguous"]:
+         scenario = make_mock_scenario(family)
+         for outcome in ["good", "bad"]:
+             trace = make_mock_trajectory(scenario, outcome=outcome)
+             score = InvestigationReward().score(trace, scenario)
+             assert -1.0 <= score.raw <= 1.0
+
+
+ def test_investigation_subscores_are_meaningful() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     score = InvestigationReward().score(trace, scenario)
+     assert "coverage" in score.sub_scores
+     assert "ordering" in score.sub_scores
+     assert "redundancy_penalty" in score.sub_scores
+
+
+ def test_investigation_redundancy_penalised() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     # Duplicate tool call with same args
+     dup = _make_tool_record("read_logs", {"scope": "full"}, step=0)
+     dup2 = _make_tool_record("read_logs", {"scope": "full"}, step=1)
+     patched = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"history": [dup, dup2]})}
+     )
+     score_dup = InvestigationReward().score(patched, scenario)
+     only_one = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"history": [dup]})}
+     )
+     score_single = InvestigationReward().score(only_one, scenario)
+     assert score_dup.sub_scores["redundancy_penalty"] < score_single.sub_scores["redundancy_penalty"]
tests/rewards/test_minimal_evidence.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """Tests for MinimalEvidenceReward component."""
+
+ from __future__ import annotations
+
+ import pytest
+
+ from ci_triage_env.mock import make_mock_scenario, make_mock_trajectory
+ from ci_triage_env.rewards.minimal_evidence import MinimalEvidenceReward
+ from ci_triage_env.schemas.action import TerminalAction, ToolCall
+ from ci_triage_env.schemas.episode import StepRecord
+ from ci_triage_env.schemas.observation import BudgetState, Observation
+
+
+ def _dummy_obs(step: int = 0) -> Observation:
+     return Observation(
+         episode_id="test",
+         step=step,
+         failure_summary=None,
+         tool_response=None,
+         budget_remaining=BudgetState(tool_calls_remaining=10, cost_remaining=1.0),
+         is_terminal=False,
+         probe_question=None,
+     )
+
+
+ def _make_tool_records(tool_names: list[str]) -> list[StepRecord]:
+     return [
+         StepRecord(step=i, action=ToolCall(tool_name=t, args={}),
+                    observation=_dummy_obs(i), cost_charged=0.001)
+         for i, t in enumerate(tool_names)
+     ]
+
+
+ def test_minimal_evidence_correct_case_returns_high_score() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     score = MinimalEvidenceReward().score(trace, scenario)
+     # Weighted score may be 0.0 (weight=0 in v1), but raw should be
+     # non-negative for a correct diagnosis.
+     assert score.raw >= 0.0
+
+
+ def test_minimal_evidence_wrong_case_returns_low_score() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="bad")  # wrong diagnosis
+     score = MinimalEvidenceReward().score(trace, scenario)
+     assert score.raw == 0.0  # wrong diagnosis → no bonus
+
+
+ def test_minimal_evidence_handles_no_terminal_action() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     no_terminal = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"final_action": None})}
+     )
+     score = MinimalEvidenceReward().score(no_terminal, scenario)
+     assert score.raw == 0.0  # no correct diagnosis → no bonus
+
+
+ def test_minimal_evidence_deterministic() -> None:
+     scenario = make_mock_scenario("race_flake")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     comp = MinimalEvidenceReward()
+     s1 = comp.score(trace, scenario)
+     s2 = comp.score(trace, scenario)
+     assert s1.raw == s2.raw
+
+
+ def test_minimal_evidence_score_is_in_documented_range() -> None:
+     for family in ["real_bug", "race_flake", "timing_flake"]:
+         scenario = make_mock_scenario(family)
+         for outcome in ["good", "bad"]:
+             trace = make_mock_trajectory(scenario, outcome=outcome)
+             score = MinimalEvidenceReward().score(trace, scenario)
+             assert -0.5 <= score.raw <= 1.0
+
+
+ def test_minimal_evidence_subscores_are_meaningful() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     score = MinimalEvidenceReward().score(trace, scenario)
+     # Sub-scores are either empty (no minimal evidence set) or carry both keys.
+     if score.sub_scores:
+         assert "min_set_used" in score.sub_scores
+         assert "extras" in score.sub_scores
+
+
+ def test_using_only_min_set_max_bonus() -> None:
+     scenario = make_mock_scenario("real_bug")
+     assert scenario.minimal_evidence_set  # must have a min set for this test
+     min_set = scenario.minimal_evidence_set
+     trace = make_mock_trajectory(scenario, outcome="good")
+
+     # Build a history using ONLY the minimal-evidence tools
+     minimal_records = _make_tool_records(min_set)
+     correct_terminal = TerminalAction(
+         action_type="submit_diagnosis",
+         diagnosis=scenario.ground_truth.label,
+         confidence=1.0,
+         secondary_actions=[],
+     )
+     patched = trace.model_copy(
+         update={"episode": trace.episode.model_copy(
+             update={"history": minimal_records, "final_action": correct_terminal}
+         )}
+     )
+     score = MinimalEvidenceReward().score(patched, scenario)
+     # Only min-set tools → extras = 0 → bonus = 1.0 - 0.1 * 0 = 1.0
+     assert score.raw == pytest.approx(1.0, abs=1e-6)
+
+
+ def test_empty_min_set_returns_zero() -> None:
+     # Use a real ambiguous scenario from the generator (which correctly has
+     # an empty min set); make_mock_trajectory is already imported at module level.
+     from ci_triage_env.data.generators import GENERATOR_REGISTRY
+
+     scenario = GENERATOR_REGISTRY["ambiguous"]().generate(seed=42)
+     assert scenario.minimal_evidence_set == []
+     trace = make_mock_trajectory(scenario, outcome="good")
+     score = MinimalEvidenceReward().score(trace, scenario)
+     assert score.raw == 0.0
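Taken together, these tests constrain MinimalEvidenceReward tightly: no bonus without a correct terminal diagnosis and a non-empty minimal_evidence_set, and a bonus of 1.0 - 0.1 * extras otherwise (the slope comes straight from the comment in test_using_only_min_set_max_bonus). A sketch consistent with those constraints; the full-set requirement and the -0.5 clamp are assumptions rather than verified implementation details:

    # Hypothetical sketch; mirrors what the tests above pin down, not
    # necessarily the committed MinimalEvidenceReward.
    def minimal_evidence_raw(used_tools: set[str], min_set: list[str],
                             diagnosis_correct: bool) -> float:
        # No bonus for a wrong/missing diagnosis or an empty min set
        # (wrong_case, no_terminal, and empty_min_set tests).
        if not diagnosis_correct or not min_set:
            return 0.0
        # Assumption: the bonus requires the whole minimal set to be used.
        if not set(min_set) <= used_tools:
            return 0.0
        extras = len(used_tools - set(min_set))
        # Slope from the test comment (bonus = 1.0 - 0.1 * extras), clamped
        # to the documented range floor of -0.5 (the clamp is an assumption).
        return max(-0.5, 1.0 - 0.1 * extras)

Under this sketch, calling only the min-set tools gives extras = 0 and a raw score of exactly 1.0, matching test_using_only_min_set_max_bonus.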
tests/rewards/test_time_penalty.py ADDED
@@ -0,0 +1,87 @@
+ """Tests for TimePenaltyReward component."""
+
+ from __future__ import annotations
+
+ from ci_triage_env.mock import make_mock_scenario, make_mock_trajectory
+ from ci_triage_env.rewards.time_penalty import TimePenaltyReward
+ from ci_triage_env.schemas.action import ToolCall
+ from ci_triage_env.schemas.episode import StepRecord
+ from ci_triage_env.schemas.observation import BudgetState, Observation
+
+
+ def _make_tool_records(n: int) -> list[StepRecord]:
+     obs = Observation(
+         episode_id="test",
+         step=0,
+         failure_summary=None,
+         tool_response=None,
+         budget_remaining=BudgetState(tool_calls_remaining=10, cost_remaining=1.0),
+         is_terminal=False,
+         probe_question=None,
+     )
+     return [
+         StepRecord(step=i, action=ToolCall(tool_name="read_logs", args={"scope": "full"}),
+                    observation=obs, cost_charged=0.001)
+         for i in range(n)
+     ]
+
+
+ def test_time_penalty_correct_case_returns_high_score() -> None:
+     # <= REFERENCE_STEPS tool calls → no penalty
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     patched = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"history": _make_tool_records(4)})}
+     )
+     score = TimePenaltyReward().score(patched, scenario)
+     assert score.raw == 0.0
+
+
+ def test_time_penalty_wrong_case_returns_low_score() -> None:
+     # Many tool calls → large penalty
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     patched = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"history": _make_tool_records(56)})}
+     )
+     score = TimePenaltyReward().score(patched, scenario)
+     assert score.raw == -1.0  # capped at floor
+
+
+ def test_time_penalty_handles_no_terminal_action() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     no_terminal = trace.model_copy(
+         update={"episode": trace.episode.model_copy(update={"final_action": None})}
+     )
+     score = TimePenaltyReward().score(no_terminal, scenario)
+     assert -1.0 <= score.raw <= 0.0
+
+
+ def test_time_penalty_deterministic() -> None:
+     scenario = make_mock_scenario("race_flake")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     comp = TimePenaltyReward()
+     s1 = comp.score(trace, scenario)
+     s2 = comp.score(trace, scenario)
+     assert s1.raw == s2.raw
+
+
+ def test_time_penalty_score_is_in_documented_range() -> None:
+     scenario = make_mock_scenario("real_bug")
+     for n in [0, 3, 6, 10, 60]:
+         trace = make_mock_trajectory(scenario, outcome="good")
+         patched = trace.model_copy(
+             update={"episode": trace.episode.model_copy(update={"history": _make_tool_records(n)})}
+         )
+         score = TimePenaltyReward().score(patched, scenario)
+         assert -1.0 <= score.raw <= 0.0
+
+
+ def test_time_penalty_subscores_are_meaningful() -> None:
+     scenario = make_mock_scenario("real_bug")
+     trace = make_mock_trajectory(scenario, outcome="good")
+     score = TimePenaltyReward().score(trace, scenario)
+     assert "steps" in score.sub_scores
+     assert "excess" in score.sub_scores
+     assert score.sub_scores["steps"] >= 0
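The boundary tests above (4 calls → 0.0, 56 calls → -1.0, everything in [-1.0, 0.0]) fix the shape of the penalty but not its exact slope. A sketch that satisfies them, where REFERENCE_STEPS = 6 and the 0.05 per-step slope are assumptions about the committed constants (any slope of at least 0.02 would also hit the floor at 56 calls):

    # Hypothetical sketch; consistent with the assertions above, not
    # necessarily the committed TimePenaltyReward.
    REFERENCE_STEPS = 6       # assumed reference step count
    PER_STEP_PENALTY = 0.05   # assumed per-step slope

    def time_penalty_raw(steps: int) -> tuple[float, dict[str, float]]:
        # Linear penalty for each step past the reference, floored at -1.0;
        # at or under the reference the penalty is exactly 0.0.
        excess = max(0, steps - REFERENCE_STEPS)
        raw = -min(1.0, PER_STEP_PENALTY * excess)
        return raw, {"steps": float(steps), "excess": float(excess)}

Plugging in the two boundary cases gives time_penalty_raw(4)[0] == 0.0 and time_penalty_raw(56)[0] == -1.0, matching the first two tests, and the returned dict carries the same "steps" and "excess" keys the sub-score test checks.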