"""Tests for the vulnerability-triage environment, its case definitions, and graders."""

import random

from models import TriageDraft, VulnTriageAction
from server.cases import choose_balanced_task_id, CASE_DEFINITIONS
from server.graders import grade_task, version_range_match
from server.vuln_triage_env_environment import VulnTriageEnvironment

# Task identifiers shared by the tests below.
_EASY_TASK = "task_easy_guarddog"
_MEDIUM_TASK = "task_medium_invenio"
_HARD_TASK = "task_hard_gradio"


def _truth_for(task_id):
    """Return the ``truth`` record stored on the case definition for *task_id*."""
    return CASE_DEFINITIONS[task_id].truth


def _evidence_ids_for(task_id):
    """Return the ``evidence_id`` of every evidence item on the case for *task_id*."""
    return [item["evidence_id"] for item in CASE_DEFINITIONS[task_id].evidence]


# ---------------------------------------------------------------------------
# Core environment tests
# ---------------------------------------------------------------------------


def test_easy_task_can_be_solved_deterministically():
    """Ten scripted steps (report, two evidence reads, six sets, submit) score 0.9999."""
    scripted_steps = [
        {"action_type": "read_report"},
        {"action_type": "inspect_evidence", "evidence_id": "osv_advisory"},
        {"action_type": "inspect_evidence", "evidence_id": "affected_versions"},
        {"action_type": "set_validity", "value": "valid"},
        {"action_type": "set_affected_package", "value": "guarddog"},
        {"action_type": "set_affected_versions", "value": "<0.1.5"},
        {"action_type": "set_severity", "value": "medium"},
        {"action_type": "set_exploitability", "value": "low"},
        {"action_type": "set_next_action", "value": "patch"},
    ]

    environment = VulnTriageEnvironment()
    environment.reset(task_id=_EASY_TASK)
    # Each action is built right before it is stepped, mirroring a live episode.
    for action_kwargs in scripted_steps:
        environment.step(VulnTriageAction(**action_kwargs))
    outcome = environment.step(VulnTriageAction(action_type="submit_triage"))

    assert outcome.done is True
    assert outcome.final_score == 0.9999
    assert outcome.score_breakdown["total"] == 0.9999


def test_terminal_submission_score_never_hits_zero():
    """Submitting straight after reset ends the episode with a score of 0.0001."""
    environment = VulnTriageEnvironment()
    environment.reset(task_id=_EASY_TASK)

    outcome = environment.step(VulnTriageAction(action_type="submit_triage"))

    assert outcome.done is True
    assert outcome.final_score == 0.0001
    assert outcome.score_breakdown["total"] == 0.0001


def test_medium_task_uses_real_provider_backed_truth():
    """Setting validity and package on the medium task earns full marks for both fields."""
    environment = VulnTriageEnvironment()
    environment.reset(task_id=_MEDIUM_TASK)
    environment.step(VulnTriageAction(action_type="set_validity", value="valid"))
    environment.step(
        VulnTriageAction(action_type="set_affected_package", value="invenio-records")
    )

    scores = grade_task(_MEDIUM_TASK, environment.state.draft)

    assert scores["validity"] == 1.0
    assert scores["affected_package"] == 1.0


def test_balanced_sampler_is_seed_reproducible():
    """The same first argument (7) yields the same task id for two different RNGs."""
    picked_with_rng_0 = choose_balanced_task_id(7, random.Random(0))
    picked_with_rng_999 = choose_balanced_task_id(7, random.Random(999))
    assert picked_with_rng_0 == picked_with_rng_999


def test_environment_reset_without_task_id_samples_valid_difficulties():
    """Twelve argument-less resets on one environment cover easy, medium and hard."""
    environment = VulnTriageEnvironment()
    observed = set()
    for _ in range(12):
        observed.add(environment.reset().difficulty)
    assert observed == {"easy", "medium", "hard"}


# ---------------------------------------------------------------------------
# Fix 1: version range normalizer accepts equivalent expressions
# ---------------------------------------------------------------------------


def test_version_range_match_accepts_trivial_lower_bound():
    """A redundant ``>=0`` / ``>=0.0.0`` lower bound still matches ``<0.1.5`` fully."""
    for submitted in (">=0,<0.1.5", ">=0.0.0,<0.1.5"):
        assert version_range_match(submitted, "<0.1.5") == 1.0


def test_version_range_match_is_order_insensitive_for_segments():
    """The same ``;``-separated segments in a different order match with score 1.0."""
    ascending = "<1.0.2 ; >=1.1.0,<1.1.1 ; >=1.2.0,<1.2.2"
    descending = ">=1.2.0,<1.2.2 ; >=1.1.0,<1.1.1 ; <1.0.2"
    assert version_range_match(ascending, descending) == 1.0


def test_version_range_match_different_ranges_score_zero():
    """Ranges with different upper bounds score 0.0."""
    assert version_range_match("<0.1.4", "<0.1.5") == 0.0


# ---------------------------------------------------------------------------
# Fix 2: multi-branch affected versions captured correctly
# ---------------------------------------------------------------------------


def test_medium_invenio_ground_truth_includes_all_branches():
    """The medium task's ground-truth affected versions mention all three branches."""
    truth = _truth_for(_MEDIUM_TASK)
    for branch in ("<1.0.2", ">=1.1.0,<1.1.1", ">=1.2.0,<1.2.2"):
        assert branch in truth.affected_versions


def test_medium_invenio_all_branches_score_full_points():
    """Listing all three branches (in any order) earns 1.0 for affected_versions."""
    submission = TriageDraft(
        validity="valid",
        affected_package="invenio-records",
        affected_versions=">=1.2.0,<1.2.2 ; >=1.1.0,<1.1.1 ; <1.0.2",
        severity="medium",
        exploitability="low",
        next_action="publish_advisory",
    )
    scores = grade_task(_MEDIUM_TASK, submission)
    assert scores["affected_versions"] == 1.0


# ---------------------------------------------------------------------------
# Difficulty redesign — Easy task
# ---------------------------------------------------------------------------


def test_easy_task_only_needs_two_evidence_items():
    """The easy task's supporting evidence is exactly the advisory and the versions."""
    supporting = _truth_for(_EASY_TASK).supporting_evidence_ids
    assert supporting == ["osv_advisory", "affected_versions"]
    assert len(supporting) == 2


def test_easy_task_max_steps_is_tight():
    """The easy case definition allows at most 10 steps."""
    assert CASE_DEFINITIONS[_EASY_TASK].max_steps == 10


# ---------------------------------------------------------------------------
# Difficulty redesign — Medium task
# ---------------------------------------------------------------------------


def test_medium_task_has_threat_intel_evidence():
    """The medium case carries an evidence item with id ``threat_intel_signal``."""
    assert "threat_intel_signal" in _evidence_ids_for(_MEDIUM_TASK)


def test_medium_task_exploitability_is_medium_not_low():
    """Ground-truth exploitability for the medium task is 'medium' (see message below)."""
    truth = _truth_for(_MEDIUM_TASK)
    failure_hint = (
        "Medium task exploitability must be 'medium' (overriding EPSS) "
        "so any model that only reads the EPSS evidence gets it wrong."
    )
    assert truth.exploitability == "medium", failure_hint


def test_medium_task_exploitability_costs_points_if_epss_only():
    """Answering 'low' exploitability on the medium task zeroes that field's score."""
    submission = TriageDraft(
        validity="valid",
        affected_package="invenio-records",
        affected_versions="<1.0.2 ; >=1.1.0,<1.1.1 ; >=1.2.0,<1.2.2",
        severity="medium",
        exploitability="low",  # deliberately wrong: the EPSS-only answer
        next_action="publish_advisory",
    )
    scores = grade_task(_MEDIUM_TASK, submission)
    assert scores["exploitability"] == 0.0
    assert scores["total"] < 1.0


# ---------------------------------------------------------------------------
# Difficulty redesign — Hard task
# ---------------------------------------------------------------------------


def test_hard_task_correct_next_action_is_request_info():
    """Ground-truth next action for the hard task is ``request_info``."""
    truth = _truth_for(_HARD_TASK)
    failure_hint = "Hard task next_action must be 'request_info' — no patch exists yet."
    assert truth.next_action == "request_info", failure_hint


def test_hard_task_has_vendor_status_evidence():
    """The hard case carries an evidence item with id ``vendor_status``."""
    assert "vendor_status" in _evidence_ids_for(_HARD_TASK)


def test_hard_task_affected_versions_covers_all():
    """Ground-truth affected versions for the hard task is the open range ``>=0``."""
    assert _truth_for(_HARD_TASK).affected_versions == ">=0"


def test_hard_task_publish_advisory_costs_next_action_points():
    """Choosing ``publish_advisory`` on the hard task zeroes the next_action score."""
    truth = _truth_for(_HARD_TASK)
    submission = TriageDraft(
        validity="valid",
        affected_package="gradio",
        affected_versions=">=0",
        severity="medium",
        exploitability="low",
        next_action="publish_advisory",  # deliberately wrong for this case
        missing_information=list(truth.missing_information),
    )
    scores = grade_task(_HARD_TASK, submission)
    assert scores["next_action"] == 0.0
    assert scores["total"] < 1.0


def test_hard_task_request_info_scores_full():
    """A draft copied from the hard task's ground truth grades to a total of 1.0."""
    truth = _truth_for(_HARD_TASK)
    submission = TriageDraft(
        validity=truth.validity,
        affected_package=truth.affected_package,
        affected_versions=truth.affected_versions,
        severity=truth.severity,
        exploitability=truth.exploitability,
        next_action="request_info",
        missing_information=list(truth.missing_information),
    )
    scores = grade_task(_HARD_TASK, submission)
    assert scores["next_action"] == 1.0
    assert scores["total"] == 1.0


def test_hard_task_has_non_empty_missing_information():
    """The hard task's ground truth lists at least three missing-information items."""
    assert len(_truth_for(_HARD_TASK).missing_information) >= 3


def test_hard_task_empty_missing_info_costs_points():
    """An empty missing_information list on the hard task scores 0.0 for that field."""
    submission = TriageDraft(
        validity="valid",
        affected_package="gradio",
        affected_versions=">=0",
        severity="medium",
        exploitability="low",
        next_action="request_info",
        missing_information=[],
    )
    scores = grade_task(_HARD_TASK, submission)
    assert scores["missing_information"] == 0.0
    assert scores["total"] < 1.0