Spaces:
Sleeping
Sleeping
| import random | |
| from models import TriageDraft, VulnTriageAction | |
| from server.cases import choose_balanced_task_id, CASE_DEFINITIONS | |
| from server.graders import grade_task, version_range_match | |
| from server.vuln_triage_env_environment import VulnTriageEnvironment | |
| # --------------------------------------------------------------------------- | |
| # Core environment tests | |
| # --------------------------------------------------------------------------- | |
def test_easy_task_can_be_solved_deterministically():
    """Solve the easy guarddog case in 10 steps using only 2 evidence reads.

    The script reads the report, inspects the two supporting evidence items,
    fills every triage field, and submits. A perfect episode is reported by
    the environment as 0.9999 in both ``final_score`` and the breakdown total.
    """
    env = VulnTriageEnvironment()
    env.reset(task_id="task_easy_guarddog")
    # (action_type, extra keyword arguments) for the 9 steps before submission.
    script = [
        ("read_report", {}),
        ("inspect_evidence", {"evidence_id": "osv_advisory"}),
        ("inspect_evidence", {"evidence_id": "affected_versions"}),
        ("set_validity", {"value": "valid"}),
        ("set_affected_package", {"value": "guarddog"}),
        ("set_affected_versions", {"value": "<0.1.5"}),
        ("set_severity", {"value": "medium"}),
        ("set_exploitability", {"value": "low"}),
        ("set_next_action", {"value": "patch"}),
    ]
    for action_type, extra in script:
        env.step(VulnTriageAction(action_type=action_type, **extra))
    final = env.step(VulnTriageAction(action_type="submit_triage"))
    assert final.done is True
    assert final.final_score == 0.9999
    assert final.score_breakdown["total"] == 0.9999
def test_terminal_submission_score_never_hits_zero():
    """Submitting an untouched draft ends the episode at 0.0001, never exactly 0."""
    env = VulnTriageEnvironment()
    env.reset(task_id="task_easy_guarddog")
    outcome = env.step(VulnTriageAction(action_type="submit_triage"))
    assert outcome.done is True
    for reported in (outcome.final_score, outcome.score_breakdown["total"]):
        assert reported == 0.0001
def test_medium_task_uses_real_provider_backed_truth():
    """Answers set via ``env.step`` land in the draft and grade against invenio truth."""
    env = VulnTriageEnvironment()
    env.reset(task_id="task_medium_invenio")
    answers = {
        "set_validity": "valid",
        "set_affected_package": "invenio-records",
    }
    for action_type, value in answers.items():
        env.step(VulnTriageAction(action_type=action_type, value=value))
    breakdown = grade_task("task_medium_invenio", env.state.draft)
    for component in ("validity", "affected_package"):
        assert breakdown[component] == 1.0
def test_balanced_sampler_is_seed_reproducible():
    """For a fixed first argument (7), the pick must not depend on the RNG state."""
    first, second = (
        choose_balanced_task_id(7, random.Random(rng_seed)) for rng_seed in (0, 999)
    )
    assert first == second
def test_environment_reset_without_task_id_samples_valid_difficulties():
    """Twelve task-less resets must surface exactly the easy/medium/hard tiers.

    NOTE(review): this asserts full coverage, not just validity. If reset()
    samples difficulties randomly rather than in a balanced rotation, 12 draws
    can miss a tier by chance — confirm the sampler is deterministic.
    """
    env = VulnTriageEnvironment()
    seen = set()
    for _ in range(12):
        seen.add(env.reset().difficulty)
    assert seen == {"easy", "medium", "hard"}
| # --------------------------------------------------------------------------- | |
| # Fix 1: version range normalizer accepts equivalent expressions | |
| # --------------------------------------------------------------------------- | |
def test_version_range_match_accepts_trivial_lower_bound():
    """An explicit zero lower bound must match the bare upper-bound form."""
    for spelled_out in (">=0,<0.1.5", ">=0.0.0,<0.1.5"):
        assert version_range_match(spelled_out, "<0.1.5") == 1.0
def test_version_range_match_is_order_insensitive_for_segments():
    """Reordering the ' ; '-separated range segments must not change the match."""
    segments = ["<1.0.2", ">=1.1.0,<1.1.1", ">=1.2.0,<1.2.2"]
    forward = " ; ".join(segments)
    backward = " ; ".join(reversed(segments))
    assert version_range_match(forward, backward) == 1.0
def test_version_range_match_different_ranges_score_zero():
    """Ranges with different upper bounds earn no credit."""
    score = version_range_match("<0.1.4", "<0.1.5")
    assert score == 0.0
| # --------------------------------------------------------------------------- | |
| # Fix 2: multi-branch affected versions captured correctly | |
| # --------------------------------------------------------------------------- | |
def test_medium_invenio_ground_truth_includes_all_branches():
    """The invenio ground truth must mention every affected release branch."""
    affected = CASE_DEFINITIONS["task_medium_invenio"].truth.affected_versions
    for branch in ("<1.0.2", ">=1.1.0,<1.1.1", ">=1.2.0,<1.2.2"):
        assert branch in affected
def test_medium_invenio_all_branches_score_full_points():
    """Listing every invenio branch, in any order, earns full affected_versions credit.

    Only the ``affected_versions`` component is asserted. The remaining fields
    mirror a correct medium answer, including ``exploitability="medium"``,
    which the medium case truth requires. The previous "low" value was the
    EPSS-only answer that the difficulty redesign made wrong, and it made
    this draft a misleading example.
    """
    draft = TriageDraft(
        validity="valid",
        affected_package="invenio-records",
        # Branches deliberately listed in reverse order of the ground truth.
        affected_versions=">=1.2.0,<1.2.2 ; >=1.1.0,<1.1.1 ; <1.0.2",
        severity="medium",
        exploitability="medium",
        next_action="publish_advisory",
    )
    breakdown = grade_task("task_medium_invenio", draft)
    assert breakdown["affected_versions"] == 1.0
| # --------------------------------------------------------------------------- | |
| # Difficulty redesign — Easy task | |
| # --------------------------------------------------------------------------- | |
def test_easy_task_only_needs_two_evidence_items():
    """The easy case is supported by exactly two evidence items (formerly four)."""
    supporting = CASE_DEFINITIONS["task_easy_guarddog"].truth.supporting_evidence_ids
    expected = ["osv_advisory", "affected_versions"]
    assert supporting == expected
    assert len(supporting) == 2
def test_easy_task_max_steps_is_tight():
    """The easy case's step budget is exactly 10."""
    easy_case = CASE_DEFINITIONS["task_easy_guarddog"]
    assert easy_case.max_steps == 10
| # --------------------------------------------------------------------------- | |
| # Difficulty redesign — Medium task | |
| # --------------------------------------------------------------------------- | |
def test_medium_task_has_threat_intel_evidence():
    """The medium case must carry a ``threat_intel_signal`` evidence item."""
    medium_case = CASE_DEFINITIONS["task_medium_invenio"]
    ids = []
    for item in medium_case.evidence:
        ids.append(item["evidence_id"])
    assert "threat_intel_signal" in ids
def test_medium_task_exploitability_is_medium_not_low():
    """Truth exploitability is 'medium': threat intel overrides the low EPSS signal."""
    actual = CASE_DEFINITIONS["task_medium_invenio"].truth.exploitability
    failure_message = (
        "Medium task exploitability must be 'medium' (overriding EPSS) "
        "so any model that only reads the EPSS evidence gets it wrong."
    )
    assert actual == "medium", failure_message
def test_medium_task_exploitability_costs_points_if_epss_only():
    """A model that reads only EPSS and submits 'low' exploitability loses points.

    Every other graded field is copied from the case ground truth, so the only
    deviation from a correct answer is the EPSS-only exploitability. This makes
    the ``total < 1.0`` check attributable to that mistake. With hard-coded
    values, an unrelated mismatch (e.g. in severity or next_action) could
    satisfy it instead.
    """
    truth = CASE_DEFINITIONS["task_medium_invenio"].truth
    draft = TriageDraft(
        validity=truth.validity,
        affected_package=truth.affected_package,
        affected_versions=truth.affected_versions,
        severity=truth.severity,
        exploitability="low",  # wrong — EPSS-only answer
        next_action=truth.next_action,
    )
    breakdown = grade_task("task_medium_invenio", draft)
    assert breakdown["exploitability"] == 0.0
    assert breakdown["total"] < 1.0
| # --------------------------------------------------------------------------- | |
| # Difficulty redesign — Hard task | |
| # --------------------------------------------------------------------------- | |
def test_hard_task_correct_next_action_is_request_info():
    """The hard case's correct next action is request_info, not publish_advisory."""
    expected_action = "request_info"
    actual_action = CASE_DEFINITIONS["task_hard_gradio"].truth.next_action
    assert actual_action == expected_action, (
        "Hard task next_action must be 'request_info' — no patch exists yet."
    )
def test_hard_task_has_vendor_status_evidence():
    """The hard case must carry a ``vendor_status`` evidence item."""
    hard_case = CASE_DEFINITIONS["task_hard_gradio"]
    ids = []
    for item in hard_case.evidence:
        ids.append(item["evidence_id"])
    assert "vendor_status" in ids
def test_hard_task_affected_versions_covers_all():
    """With no fixed release, the hard case's affected range is ">=0"."""
    hard_truth = CASE_DEFINITIONS["task_hard_gradio"].truth
    assert hard_truth.affected_versions == ">=0"
def test_hard_task_publish_advisory_costs_next_action_points():
    """A model that naively publishes instead of requesting info loses points.

    All fields other than ``next_action`` come from the ground truth, so a
    ``total`` below 1.0 can only be caused by the wrong next action, not by an
    unrelated mismatch in a hard-coded severity or exploitability value.

    NOTE(review): this was previously described as a 15% loss; the exact
    weight of the next_action component is not asserted here.
    """
    truth = CASE_DEFINITIONS["task_hard_gradio"].truth
    draft = TriageDraft(
        validity=truth.validity,
        affected_package=truth.affected_package,
        affected_versions=truth.affected_versions,
        severity=truth.severity,
        exploitability=truth.exploitability,
        next_action="publish_advisory",  # wrong — no patch exists
        missing_information=list(truth.missing_information),
    )
    breakdown = grade_task("task_hard_gradio", draft)
    assert breakdown["next_action"] == 0.0
    assert breakdown["total"] < 1.0
def test_hard_task_request_info_scores_full():
    """A draft copied from the hard-case truth with request_info scores 1.0."""
    truth = CASE_DEFINITIONS["task_hard_gradio"].truth
    copied_fields = {
        name: getattr(truth, name)
        for name in (
            "validity",
            "affected_package",
            "affected_versions",
            "severity",
            "exploitability",
        )
    }
    draft = TriageDraft(
        **copied_fields,
        next_action="request_info",
        missing_information=list(truth.missing_information),
    )
    scores = grade_task("task_hard_gradio", draft)
    assert scores["next_action"] == 1.0
    assert scores["total"] == 1.0
def test_hard_task_has_non_empty_missing_information():
    """The hard case's truth lists at least three missing-information items."""
    missing = CASE_DEFINITIONS["task_hard_gradio"].truth.missing_information
    assert len(missing) >= 3
def test_hard_task_empty_missing_info_costs_points():
    """Choosing request_info without listing any missing information loses points.

    Every other field matches the ground truth (and the correct
    ``request_info`` action), so a ``total`` below 1.0 reflects only the empty
    ``missing_information`` list. With hard-coded values, an unrelated
    severity or exploitability mismatch could satisfy the check instead.
    """
    truth = CASE_DEFINITIONS["task_hard_gradio"].truth
    draft = TriageDraft(
        validity=truth.validity,
        affected_package=truth.affected_package,
        affected_versions=truth.affected_versions,
        severity=truth.severity,
        exploitability=truth.exploitability,
        next_action="request_info",
        missing_information=[],
    )
    breakdown = grade_task("task_hard_gradio", draft)
    assert breakdown["missing_information"] == 0.0
    assert breakdown["total"] < 1.0