File size: 9,405 Bytes
d63a1ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6eb49cc
 
 
 
 
 
 
 
 
 
 
d63a1ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import random

from models import TriageDraft, VulnTriageAction
from server.cases import choose_balanced_task_id, CASE_DEFINITIONS
from server.graders import grade_task, version_range_match
from server.vuln_triage_env_environment import VulnTriageEnvironment


# ---------------------------------------------------------------------------
# Core environment tests
# ---------------------------------------------------------------------------

def test_easy_task_can_be_solved_deterministically():
    """Solve the easy guarddog task with a fixed 10-action script.

    Only two evidence items are inspected before the draft fields are filled
    in; the closing ``submit_triage`` must end the episode with the top score.
    """
    env = VulnTriageEnvironment()
    env.reset(task_id="task_easy_guarddog")
    script = [
        ("read_report", {}),
        ("inspect_evidence", {"evidence_id": "osv_advisory"}),
        ("inspect_evidence", {"evidence_id": "affected_versions"}),
        ("set_validity", {"value": "valid"}),
        ("set_affected_package", {"value": "guarddog"}),
        ("set_affected_versions", {"value": "<0.1.5"}),
        ("set_severity", {"value": "medium"}),
        ("set_exploitability", {"value": "low"}),
        ("set_next_action", {"value": "patch"}),
        ("submit_triage", {}),
    ]
    outcome = None
    for action_type, fields in script:
        outcome = env.step(VulnTriageAction(action_type=action_type, **fields))
    assert outcome.done is True
    # The perfect score is reported as 0.9999 rather than 1.0; together with the
    # 0.0001 floor checked below, terminal scores appear to be clamped to the
    # open interval (0, 1).
    assert outcome.final_score == 0.9999
    assert outcome.score_breakdown["total"] == 0.9999


def test_terminal_submission_score_never_hits_zero():
    """Submitting an untouched draft right after reset yields the floor score, not 0."""
    env = VulnTriageEnvironment()
    env.reset(task_id="task_easy_guarddog")
    outcome = env.step(VulnTriageAction(action_type="submit_triage"))
    floor = 0.0001
    assert outcome.done is True
    assert outcome.final_score == floor
    assert outcome.score_breakdown["total"] == floor


def test_medium_task_uses_real_provider_backed_truth():
    """Setting validity and package on the medium task earns full marks for both fields."""
    env = VulnTriageEnvironment()
    env.reset(task_id="task_medium_invenio")
    for action_type, value in (
        ("set_validity", "valid"),
        ("set_affected_package", "invenio-records"),
    ):
        env.step(VulnTriageAction(action_type=action_type, value=value))
    scores = grade_task("task_medium_invenio", env.state.draft)
    assert scores["validity"] == 1.0
    assert scores["affected_package"] == 1.0


def test_balanced_sampler_is_seed_reproducible():
    """The same seed picks the same task even when two differently seeded RNGs are supplied.

    NOTE(review): passing this implies ``choose_balanced_task_id`` derives its
    choice from the seed argument rather than from the supplied
    ``random.Random`` instance — confirm against the sampler implementation.
    """
    seed = 7
    from_rng_a = choose_balanced_task_id(seed, random.Random(0))
    from_rng_b = choose_balanced_task_id(seed, random.Random(999))
    assert from_rng_a == from_rng_b


def test_environment_reset_without_task_id_samples_valid_difficulties():
    """Twelve resets without a task id must together cover exactly easy, medium and hard.

    NOTE(review): this is deterministic only if the default sampler is truly
    balanced; with independent uniform draws, 12 resets would miss a tier
    about 2% of the time — confirm the sampler's behaviour.
    """
    env = VulnTriageEnvironment()
    seen = set()
    for _ in range(12):
        seen.add(env.reset().difficulty)
    assert seen == {"easy", "medium", "hard"}


# ---------------------------------------------------------------------------
# Fix 1: version range normalizer accepts equivalent expressions
# ---------------------------------------------------------------------------

def test_version_range_match_accepts_trivial_lower_bound():
    """A redundant ``>=0`` / ``>=0.0.0`` lower bound must not affect matching."""
    for with_lower_bound in (">=0,<0.1.5", ">=0.0.0,<0.1.5"):
        assert version_range_match(with_lower_bound, "<0.1.5") == 1.0


def test_version_range_match_is_order_insensitive_for_segments():
    """``" ; "``-separated segments match regardless of the order they are listed in."""
    segments = ["<1.0.2", ">=1.1.0,<1.1.1", ">=1.2.0,<1.2.2"]
    forward = " ; ".join(segments)
    backward = " ; ".join(reversed(segments))
    assert version_range_match(forward, backward) == 1.0


def test_version_range_match_different_ranges_score_zero():
    """Ranges with different upper bounds are not equivalent and score 0."""
    left, right = "<0.1.4", "<0.1.5"
    assert version_range_match(left, right) == 0.0


# ---------------------------------------------------------------------------
# Fix 2: multi-branch affected versions captured correctly
# ---------------------------------------------------------------------------

def test_medium_invenio_ground_truth_includes_all_branches():
    """The invenio ground truth lists every vulnerable release branch."""
    affected = CASE_DEFINITIONS["task_medium_invenio"].truth.affected_versions
    for branch in ("<1.0.2", ">=1.1.0,<1.1.1", ">=1.2.0,<1.2.2"):
        assert branch in affected


def test_medium_invenio_all_branches_score_full_points():
    """Listing all three branches, in shuffled order, earns full affected_versions credit.

    Only the ``affected_versions`` criterion is asserted; the remaining draft
    fields are filler and are not checked here.
    """
    shuffled_branches = ">=1.2.0,<1.2.2 ; >=1.1.0,<1.1.1 ; <1.0.2"
    scores = grade_task(
        "task_medium_invenio",
        TriageDraft(
            validity="valid",
            affected_package="invenio-records",
            affected_versions=shuffled_branches,
            severity="medium",
            exploitability="low",
            next_action="publish_advisory",
        ),
    )
    assert scores["affected_versions"] == 1.0


# ---------------------------------------------------------------------------
# Difficulty redesign — Easy task
# ---------------------------------------------------------------------------

def test_easy_task_only_needs_two_evidence_items():
    """The easy task's supporting evidence is trimmed to exactly two items, not four."""
    evidence = CASE_DEFINITIONS["task_easy_guarddog"].truth.supporting_evidence_ids
    assert len(evidence) == 2
    assert evidence == ["osv_advisory", "affected_versions"]


def test_easy_task_max_steps_is_tight():
    """The easy case's step budget is 10, the length of this file's easy solution script."""
    easy_case = CASE_DEFINITIONS["task_easy_guarddog"]
    assert easy_case.max_steps == 10


# ---------------------------------------------------------------------------
# Difficulty redesign — Medium task
# ---------------------------------------------------------------------------

def test_medium_task_has_threat_intel_evidence():
    """The medium task must include a ``threat_intel_signal`` evidence item."""
    ids = {item["evidence_id"] for item in CASE_DEFINITIONS["task_medium_invenio"].evidence}
    assert "threat_intel_signal" in ids


def test_medium_task_exploitability_is_medium_not_low():
    """Threat intel overrides the low EPSS reading — the medium tier's key trap."""
    expected = "medium"
    actual = CASE_DEFINITIONS["task_medium_invenio"].truth.exploitability
    assert actual == expected, (
        "Medium task exploitability must be 'medium' (overriding EPSS) "
        "so any model that only reads the EPSS evidence gets it wrong."
    )


def test_medium_task_exploitability_costs_points_if_epss_only():
    """A model that trusts EPSS alone and submits 'low' exploitability loses points.

    Every other field is copied from the ground truth, so the two graded drafts
    differ only in ``exploitability``. Any score gap between them is therefore
    caused by that field alone, not by unrelated hard-coded guesses.
    """
    truth = CASE_DEFINITIONS["task_medium_invenio"].truth

    def _draft(exploitability):
        return TriageDraft(
            validity=truth.validity,
            affected_package=truth.affected_package,
            affected_versions=truth.affected_versions,
            severity=truth.severity,
            exploitability=exploitability,
            next_action=truth.next_action,
        )

    epss_only = grade_task("task_medium_invenio", _draft("low"))  # wrong — EPSS-only answer
    correct = grade_task("task_medium_invenio", _draft(truth.exploitability))
    assert epss_only["exploitability"] == 0.0
    assert epss_only["total"] < 1.0
    # The wrong answer must score strictly below an otherwise identical correct one.
    assert epss_only["total"] < correct["total"]


# ---------------------------------------------------------------------------
# Difficulty redesign — Hard task
# ---------------------------------------------------------------------------

def test_hard_task_correct_next_action_is_request_info():
    """With no patch available, the hard task expects request_info, not publish_advisory."""
    hard_truth = CASE_DEFINITIONS["task_hard_gradio"].truth
    assert hard_truth.next_action == "request_info", (
        "Hard task next_action must be 'request_info' — no patch exists yet."
    )


def test_hard_task_has_vendor_status_evidence():
    """The hard task must include a ``vendor_status`` evidence item explaining why no patch exists."""
    ids = {item["evidence_id"] for item in CASE_DEFINITIONS["task_hard_gradio"].evidence}
    assert "vendor_status" in ids


def test_hard_task_affected_versions_covers_all():
    """With no fixed release, every version is affected, expressed as ``>=0``."""
    every_version = ">=0"
    assert CASE_DEFINITIONS["task_hard_gradio"].truth.affected_versions == every_version


def test_hard_task_publish_advisory_costs_next_action_points():
    """Publishing an advisory instead of requesting info must cost the next_action points.

    Every field except ``next_action`` is copied from the ground truth, so the
    only deviation from the correct answer is the wrong ``publish_advisory``
    choice. The score drop is therefore attributable to that field alone.
    """
    truth = CASE_DEFINITIONS["task_hard_gradio"].truth

    def _draft(next_action):
        return TriageDraft(
            validity=truth.validity,
            affected_package=truth.affected_package,
            affected_versions=truth.affected_versions,
            severity=truth.severity,
            exploitability=truth.exploitability,
            next_action=next_action,
            missing_information=list(truth.missing_information),
        )

    published = grade_task("task_hard_gradio", _draft("publish_advisory"))  # wrong — no patch exists
    correct = grade_task("task_hard_gradio", _draft(truth.next_action))
    assert published["next_action"] == 0.0
    assert published["total"] < 1.0
    # The wrong answer must score strictly below an otherwise identical correct one.
    assert published["total"] < correct["total"]


def test_hard_task_request_info_scores_full():
    """A draft that mirrors the hard-task truth and requests info scores a perfect 1.0."""
    truth = CASE_DEFINITIONS["task_hard_gradio"].truth
    copied_fields = {
        name: getattr(truth, name)
        for name in (
            "validity",
            "affected_package",
            "affected_versions",
            "severity",
            "exploitability",
        )
    }
    draft = TriageDraft(
        **copied_fields,
        next_action="request_info",
        missing_information=list(truth.missing_information),
    )
    scores = grade_task("task_hard_gradio", draft)
    assert scores["next_action"] == 1.0
    assert scores["total"] == 1.0


def test_hard_task_has_non_empty_missing_information():
    """The hard task's ground truth names at least three missing-information items."""
    missing = CASE_DEFINITIONS["task_hard_gradio"].truth.missing_information
    assert len(missing) >= 3


def test_hard_task_empty_missing_info_costs_points():
    """Leaving missing_information empty zeroes that criterion and lowers the total.

    All other fields come from the ground truth, so comparing against the fully
    correct draft isolates the cost of the empty ``missing_information`` list.
    """
    truth = CASE_DEFINITIONS["task_hard_gradio"].truth

    def _draft(missing_information):
        return TriageDraft(
            validity=truth.validity,
            affected_package=truth.affected_package,
            affected_versions=truth.affected_versions,
            severity=truth.severity,
            exploitability=truth.exploitability,
            next_action=truth.next_action,
            missing_information=missing_information,
        )

    empty = grade_task("task_hard_gradio", _draft([]))
    correct = grade_task("task_hard_gradio", _draft(list(truth.missing_information)))
    assert empty["missing_information"] == 0.0
    assert empty["total"] < 1.0
    # The empty list must score strictly below an otherwise identical correct draft.
    assert empty["total"] < correct["total"]