"""Deterministic graders for the vulnerability triage benchmark."""
from __future__ import annotations
import re
from typing import Dict, Iterable, List
try:
from ..models import TriageDraft
from .cases import CASE_DEFINITIONS, CaseDefinition, get_case_definition
except ImportError:
from models import TriageDraft
from server.cases import CASE_DEFINITIONS, CaseDefinition, get_case_definition
# Relative weight of each graded field in the combined "total" score.
# Values sum to 1.0, so the weighted total stays in [0, 1].
WEIGHTS: Dict[str, float] = {
    "validity": 0.20,
    "affected_package": 0.10,
    "affected_versions": 0.10,
    "severity": 0.20,
    "exploitability": 0.15,
    "next_action": 0.15,
    "missing_information": 0.10,
}
# Offset used to keep terminal scores strictly inside the open interval (0, 1),
# as required by the downstream validator (see normalize_terminal_score).
TERMINAL_SCORE_EPSILON = 0.0001
def normalize_text(value: str) -> str:
    """Lowercase *value* and collapse every run of whitespace to a single space.

    Used throughout the graders so that comparisons are insensitive to case
    and incidental whitespace.
    """
    # str.split() with no argument already discards leading/trailing
    # whitespace, so the previous explicit .strip() was redundant.
    return " ".join(value.lower().split())
def normalize_list(values: Iterable[str]) -> List[str]:
    """Normalize each string, drop blanks and duplicates, and sort the result."""
    unique = set()
    for raw in values:
        text = normalize_text(raw)
        if text:
            unique.add(text)
    return sorted(unique)
def set_similarity(actual: Iterable[str], expected: Iterable[str]) -> float:
    """Return the Jaccard similarity of the two normalized string collections.

    Two empty collections are considered a perfect match (1.0); one empty and
    one non-empty collection score 0.0.
    """
    left = set(normalize_list(actual))
    right = set(normalize_list(expected))
    if not left and not right:
        return 1.0
    if not (left and right):
        return 0.0
    intersection = left & right
    union = left | right
    return len(intersection) / len(union)
def field_match(actual: str, expected: str) -> float:
    """Score 1.0 on an exact match after normalization, otherwise 0.0."""
    if normalize_text(actual) == normalize_text(expected):
        return 1.0
    return 0.0
def _normalize_version_range(value: str) -> str:
    """Canonicalize a version-range string so equivalent forms compare equal.

    Normalizations applied:
    - A trivial lower bound (``>=0``, ``>=0.0``, ``>=0.0.0``) followed by a
      comma is dropped, so ``>=0,<0.1.5`` equals ``<0.1.5``.
    - Whitespace around comparison operators is removed.
    - Semicolon-separated branches are sorted so their submission order is
      irrelevant.
    """
    text = normalize_text(value)
    canonical: List[str] = []
    for raw in text.split(";"):
        segment = raw.strip()
        if not segment:
            continue
        # Drop the no-op lower-bound prefix before a comma.
        segment = re.sub(r">=\s*0(?:\.0)*\s*,\s*", "", segment)
        # Tighten whitespace around comparison operators.
        segment = re.sub(r"\s*([><=!]+)\s*", r"\1", segment).strip()
        if segment:
            canonical.append(segment)
    return " ; ".join(sorted(canonical))
def version_range_match(actual: str, expected: str) -> float:
    """Score 1.0 when the two version ranges canonicalize identically."""
    same = _normalize_version_range(actual) == _normalize_version_range(expected)
    return 1.0 if same else 0.0
def normalize_terminal_score(score: float) -> float:
    """Clamp a terminal task score into the validator's required open interval.

    The result is rounded to four decimal places and nudged off the exact
    endpoints 0.0 and 1.0 by ``TERMINAL_SCORE_EPSILON``.
    """
    clamped = min(max(float(score), 0.0), 1.0)
    rounded = round(clamped, 4)
    if rounded <= 0.0:
        return TERMINAL_SCORE_EPSILON
    if rounded >= 1.0:
        return round(1.0 - TERMINAL_SCORE_EPSILON, 4)
    return rounded
def grade_case(case: CaseDefinition, draft: TriageDraft) -> Dict[str, float]:
    """Grade *draft* against *case*'s ground truth.

    Returns a per-field score breakdown plus a weighted, rounded "total".
    Difficulty-based penalties are then applied to the total:
    - hard: halved when truth demands ``request_info`` but the draft commits
      to publishing or patching anyway (a critical logic failure);
    - medium: multiplied by 0.75 when both severity and exploitability were
      missed, since decoding those signals is the core medium challenge.
    """
    truth = case.truth
    # (field name, scoring function, draft value, truth value)
    field_specs = [
        ("validity", field_match, draft.validity, truth.validity),
        ("affected_package", field_match, draft.affected_package, truth.affected_package),
        ("affected_versions", version_range_match, draft.affected_versions, truth.affected_versions),
        ("severity", field_match, draft.severity, truth.severity),
        ("exploitability", field_match, draft.exploitability, truth.exploitability),
        ("next_action", field_match, draft.next_action, truth.next_action),
        ("missing_information", set_similarity, draft.missing_information, truth.missing_information),
    ]
    breakdown = {name: score(got, want) for name, score, got, want in field_specs}
    weighted_total = sum(breakdown[name] * WEIGHTS[name] for name in WEIGHTS)
    breakdown["total"] = round(weighted_total, 4)

    # Cascading difficulty penalties (cases without a difficulty count as hard).
    difficulty = getattr(case, "difficulty", "hard")
    overcommitted = draft.next_action in {"publish_advisory", "patch"}
    if difficulty == "hard" and truth.next_action == "request_info" and overcommitted:
        breakdown["total"] = round(breakdown["total"] * 0.5, 4)
    elif (
        difficulty == "medium"
        and breakdown["severity"] == 0.0
        and breakdown["exploitability"] == 0.0
    ):
        breakdown["total"] = round(breakdown["total"] * 0.75, 4)
    return breakdown
def grade_task(task_id: str, draft: TriageDraft) -> Dict[str, float]:
    """Look up the case for *task_id* and grade *draft* against it."""
    case = get_case_definition(task_id)
    return grade_case(case, draft)