"""Baseline inference script for the vulnerability triage environment."""
from __future__ import annotations
import argparse
import json
import os
from typing import Dict, List, Optional
from openai import OpenAI
from openenv.core import GenericEnvClient
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
# Support all key variants the validator may inject
_API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
from models import VulnTriageAction
from server.cases import TASK_ORDER, get_case_definition
from server.vuln_triage_env_environment import VulnTriageEnvironment
# System prompt sent verbatim to the chat model. Mojibake separators ("β",
# an encoding artifact) replaced with plain ASCII so the model sees clean text.
SYSTEM_PROMPT = """You are triaging open-source vulnerability reports.
Return ONLY a single JSON object - no prose, no markdown - with exactly these keys:
action_type : string (required) - one of the action types listed in available_actions
evidence_id : string (optional) - only used with inspect_evidence
value : string (optional) - a PLAIN STRING, never an object or array
rationale : string (required) - one short sentence
Valid action_type values and their expected value strings:
read_report -> no value needed
inspect_evidence -> set evidence_id to one id from available_evidence
search_nvd_database -> value: CVE ID (e.g. CVE-2023-1234) found in report aliases
fetch_commit_diff -> value: commit hash or hash fragment found in references
message_maintainer -> value: a question for the maintainer (e.g. "Is there a patch?")
set_validity -> value: "valid" | "invalid" | "needs_more_info"
set_affected_package -> value: package name string, e.g. "guarddog"
set_affected_versions -> value: semver range string, e.g. "<0.1.5"
set_severity -> value: "low" | "medium" | "high" | "critical"
set_exploitability -> value: "low" | "medium" | "high"
set_next_action -> value: "patch" | "publish_advisory" | "close" | "escalate" | "request_info"
set_missing_information -> value: one missing info item as a plain string
submit_triage -> no value needed
Strategy: read_report first, then use tools (search_nvd, fetch_commit, message_maintainer) to unlock hidden evidence, then fill all draft fields, then submit.
Note: You CANNOT inspect "nvd_assessment", "github_commit_diff", or "vendor_status" directly. You must use the tools above to reveal them.
"""
def get_openai_client() -> OpenAI:
    """Construct an OpenAI-compatible client from environment credentials.

    Raises:
        RuntimeError: when no API key was supplied via the environment.
    """
    if not _API_KEY:
        raise RuntimeError(
            "Set API_KEY, HF_TOKEN, or OPENAI_API_KEY before running the OpenAI baseline."
        )
    client_kwargs: Dict[str, str] = {"api_key": _API_KEY}
    # A custom base URL lets this script target any OpenAI-compatible gateway.
    if API_BASE_URL:
        client_kwargs["base_url"] = API_BASE_URL
    return OpenAI(**client_kwargs)
def parse_json_response(text: str) -> Dict[str, str]:
    """Extract the first complete JSON object from a model response.

    Tolerates markdown fences (```json ... ```), reasoning-model think
    blocks (<think>...</think>), and prose surrounding the JSON object.

    Raises:
        ValueError: when no complete JSON object can be located.
    """
    import re

    # Drop <think> blocks emitted by reasoning models (Qwen3, DeepSeek, ...).
    cleaned = re.sub(
        r"<think>.*?</think>", "", text.strip(), flags=re.DOTALL | re.IGNORECASE
    ).strip()
    # Drop markdown fence lines while keeping everything between them.
    if "```" in cleaned:
        kept = (ln for ln in cleaned.splitlines() if not ln.strip().startswith("```"))
        cleaned = "\n".join(kept).strip()
    start = cleaned.find("{")
    if start == -1:
        raise ValueError(f"No JSON object found in model response: {cleaned[:200]!r}")
    # Bracket-match while honoring string literals and backslash escapes.
    depth = 0
    inside_string = False
    skip_next = False
    for idx in range(start, len(cleaned)):
        ch = cleaned[idx]
        if skip_next:
            skip_next = False
        elif ch == "\\" and inside_string:
            skip_next = True
        elif ch == '"':
            inside_string = not inside_string
        elif not inside_string:
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    return json.loads(cleaned[start : idx + 1])
    raise ValueError(f"Incomplete JSON object in model response: {cleaned[:200]!r}")
def heuristic_policy(observation: Dict) -> Dict[str, str]:
    """Oracle policy: reveal supporting evidence, fill the draft, then submit.

    Uses the ground-truth case definition to pick the single next action that
    moves the episode toward a perfect score.
    """
    # Always open the report before doing anything else.
    if "read_report" not in observation["action_history"]:
        return {"action_type": "read_report", "rationale": "Start by reading the report"}
    truth = get_case_definition(observation["task_id"]).truth
    wanted_ids = set(truth.supporting_evidence_ids)
    revealed_ids = {entry["evidence_id"] for entry in observation["visible_evidence"]}
    pending = [
        eid
        for eid in observation["available_evidence"]
        if eid in wanted_ids and eid not in revealed_ids
    ]
    if pending:
        target = pending[0]
        # Tool-gated evidence cannot be inspected directly; use the matching tool.
        if target == "nvd_assessment":
            # The oracle magically knows the OSV ID to query (alias).
            from server.cases import SEEDS

            osv_id = SEEDS[observation["task_id"]].osv_id
            return {"action_type": "search_nvd_database", "value": osv_id, "rationale": "Fetch NVD dynamically"}
        if target == "github_commit_diff":
            # Any commit-hash substring matches.
            return {"action_type": "fetch_commit_diff", "value": "Commit", "rationale": "Fetch Diff dynamically"}
        if target == "vendor_status":
            return {"action_type": "message_maintainer", "value": "Is there an ETA for a patch?", "rationale": "Chat with maintainer"}
        return {
            "action_type": "inspect_evidence",
            "evidence_id": target,
            "rationale": "Reveal the next supporting evidence item",
        }
    draft = observation["draft"]
    # Correct the first draft field that disagrees with the ground truth.
    for setter, expected in (
        ("set_validity", truth.validity),
        ("set_affected_package", truth.affected_package),
        ("set_affected_versions", truth.affected_versions),
        ("set_severity", truth.severity),
        ("set_exploitability", truth.exploitability),
        ("set_next_action", truth.next_action),
    ):
        if draft[setter.replace("set_", "")] != expected:
            return {"action_type": setter, "value": expected, "rationale": "Update the draft"}
    # Record any required missing-information items not yet in the draft.
    recorded = {entry.strip().lower() for entry in draft.get("missing_information", [])}
    for needed in truth.missing_information:
        if needed.strip().lower() not in recorded:
            return {
                "action_type": "set_missing_information",
                "value": needed,
                "rationale": "Record known missing information",
            }
    score = observation["score_breakdown"]
    return {"action_type": "submit_triage", "rationale": f"Current total score is {score['total']}"}
def llm_policy(client: OpenAI, model_name: str, observation: Dict) -> Dict[str, str]:
    """Ask the chat model for the next action given the full observation.

    The observation is serialized as sorted, indented JSON so the model sees
    a stable layout; the reply is parsed with ``parse_json_response``.
    """
    observation_json = json.dumps(observation, indent=2, sort_keys=True)
    completion = client.chat.completions.create(
        model=model_name,
        temperature=0,  # deterministic decoding for reproducible baselines
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": observation_json},
        ],
    )
    return parse_json_response(completion.choices[0].message.content)
# Keys accepted by the VulnTriageAction schema; everything else is dropped.
_VALID_ACTION_KEYS = {"action_type", "evidence_id", "value", "rationale"}
def sanitize_action_payload(payload: Dict) -> Dict:
    """Keep only valid VulnTriageAction keys and coerce bad value types.

    The action contract requires ``value`` to be a plain string. Models
    sometimes emit structured values (dict/list) or bare scalars (int,
    float, bool); both are coerced to strings here so downstream schema
    validation does not reject an otherwise-usable action.
    """
    clean = {k: v for k, v in payload.items() if k in _VALID_ACTION_KEYS}
    value = clean.get("value")
    if isinstance(value, (dict, list)):
        # Flatten structured values to their JSON text.
        clean["value"] = json.dumps(value)
    elif value is not None and not isinstance(value, str):
        # Stringify scalar non-strings (numbers, booleans) as well.
        clean["value"] = str(value)
    return clean
def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, float]:
    """Run one triage episode against an in-process VulnTriageEnvironment.

    Args:
        task_id: Case identifier from ``TASK_ORDER``.
        policy: ``"openai"`` to use the LLM policy; anything else runs the
            deterministic heuristic oracle.
        model_name: Chat model to query when the OpenAI policy is active.

    Returns:
        Per-task score summary consumed by ``main`` for aggregation.
    """
    print(f"[START] task={task_id}", flush=True)
    env = VulnTriageEnvironment()
    observation = env.reset(task_id=task_id).model_dump()
    client = get_openai_client() if policy == "openai" else None
    last_action_str: str = ""
    repeat_count: int = 0
    step_num: int = 1
    while not observation["done"]:
        action_payload = (
            llm_policy(client, model_name, observation) if client else heuristic_policy(observation)
        )
        # Strip unknown keys then coerce bad value types before validation.
        try:
            clean = sanitize_action_payload(action_payload)
            action = VulnTriageAction.model_validate(clean)
        except Exception as exc:
            print(f" [warn] invalid action payload ({exc}), falling back to read_report", flush=True)
            action = VulnTriageAction(action_type="read_report", rationale="fallback: parse error")
        # Break infinite loops where the model repeats the identical action.
        action_str = action.model_dump_json()
        if action_str == last_action_str:
            repeat_count += 1
            if repeat_count >= 3:
                print(" [warn] model repeated same action 3x - forcing submit_triage", flush=True)
                action = VulnTriageAction(action_type="submit_triage", rationale="loop guard")
        else:
            repeat_count = 0
        last_action_str = action_str
        observation = env.step(action).model_dump()
        step_reward = float(observation.get("reward") or 0.0)
        print(f"[STEP] step={step_num} action={action.action_type} reward={step_reward}", flush=True)
        step_num += 1
    # Fall back to the score-breakdown total when the env omits final_score.
    final_score = float(observation.get("final_score") or observation.get("score_breakdown", {}).get("total", 0.0))
    # step_num was incremented past the last executed step, so subtract one.
    print(f"[END] task={task_id} score={final_score} steps={step_num - 1}", flush=True)
    breakdown = observation["score_breakdown"]
    return {
        "task_id": task_id,
        # Reuse the safely-computed score instead of re-indexing observation,
        # which could raise KeyError if final_score is absent.
        "final_score": final_score,
        "validity": breakdown["validity"],
        "package_versions": round(
            (breakdown["affected_package"] + breakdown["affected_versions"]) / 2,
            4,
        ),
        "severity": breakdown["severity"],
        "exploitability": breakdown["exploitability"],
        "next_action": breakdown["next_action"],
    }
def run_remote_episode(base_url: str, task_id: str, policy: str, model_name: str) -> Dict[str, float]:
    """Run one triage episode against a remote environment over HTTP.

    Mirrors ``run_local_episode`` but talks to a GenericEnvClient endpoint,
    caps the episode at MAX_REMOTE_STEPS, and degrades gracefully on policy
    or transport errors so one bad task cannot crash the whole evaluation.
    """
    print(f"[START] task={task_id}", flush=True)
    llm_client = get_openai_client() if policy == "openai" else None
    env = GenericEnvClient(base_url=base_url).sync()
    # Hard cap so a remote env that never reports done cannot hang the run.
    MAX_REMOTE_STEPS = 20
    with env:
        response = env.reset(task_id=task_id)
        observation = response.observation
        done = response.done
        step_num: int = 1
        last_action_str: str = ""
        repeat_count: int = 0
        while not done and step_num <= MAX_REMOTE_STEPS:
            # Get action from LLM or heuristic with error handling
            try:
                if llm_client:
                    action_payload = llm_policy(llm_client, model_name, observation)
                    action_payload = sanitize_action_payload(action_payload)
                else:
                    action_payload = heuristic_policy(observation)
            except Exception as exc:
                # Any policy failure degrades to a harmless read_report action.
                print(f" [warn] policy error ({exc}), falling back to read_report", flush=True)
                action_payload = {"action_type": "read_report", "rationale": "fallback: policy error"}
            # Loop guard: detect repeated identical actions
            action_str = json.dumps(action_payload, sort_keys=True)
            if action_str == last_action_str:
                repeat_count += 1
                if repeat_count >= 3:
                    print(f" [warn] repeated action 3x β forcing submit_triage", flush=True)
                    action_payload = {"action_type": "submit_triage", "rationale": "loop guard"}
            else:
                repeat_count = 0
            last_action_str = action_str
            # Step the environment with error handling
            try:
                response = env.step(action_payload)
                observation = response.observation
                done = response.done
            except Exception as exc:
                # One retry with a forced submit; if that also fails, end the
                # episode and score whatever was last observed.
                print(f" [warn] env.step failed ({exc}), forcing submit", flush=True)
                action_payload = {"action_type": "submit_triage", "rationale": "env error recovery"}
                try:
                    response = env.step(action_payload)
                    observation = response.observation
                    done = response.done
                except Exception:
                    done = True
            step_reward = float(getattr(response, 'reward', None) or 0.0)
            print(f"[STEP] step={step_num} action={action_payload.get('action_type')} reward={step_reward}", flush=True)
            step_num += 1
    # Safely extract final score
    final_score = 0.0
    try:
        final_score = float(observation.get("final_score") or observation.get("score_breakdown", {}).get("total", 0.0))
    except Exception:
        pass
    print(f"[END] task={task_id} score={final_score} steps={step_num}", flush=True)
    # Build result safely
    sb = observation.get("score_breakdown", {})
    return {
        "task_id": task_id,
        "final_score": final_score,
        "validity": sb.get("validity", 0.0),
        "package_versions": round((sb.get("affected_package", 0.0) + sb.get("affected_versions", 0.0)) / 2, 4),
        "severity": sb.get("severity", 0.0),
        "exploitability": sb.get("exploitability", 0.0),
        "next_action": sb.get("next_action", 0.0),
    }
def main() -> None:
    """Parse CLI flags, run every task in TASK_ORDER, print a JSON summary."""
    parser = argparse.ArgumentParser()
    # Prefer the OpenAI policy whenever credentials were injected via the
    # environment; otherwise fall back to the heuristic for local smoke-tests.
    parser.add_argument(
        "--policy",
        choices=["openai", "heuristic"],
        default="openai" if _API_KEY else "heuristic",
    )
    parser.add_argument("--model", default=MODEL_NAME)
    # Default to the live HF Space so the validator can reach our environment.
    parser.add_argument(
        "--env-base-url",
        dest="base_url",
        default=os.getenv("ENV_BASE_URL", "https://adhitya122-vulnops.hf.space"),
    )
    args = parser.parse_args()
    episode_summaries: List[Dict[str, float]] = []
    for task_id in TASK_ORDER:
        if args.base_url:
            summary = run_remote_episode(args.base_url, task_id, args.policy, args.model)
        else:
            summary = run_local_episode(task_id, args.policy, args.model)
        episode_summaries.append(summary)
    average = round(sum(entry["final_score"] for entry in episode_summaries) / len(episode_summaries), 4)
    print(json.dumps({"policy": args.policy, "model": args.model, "average_score": average, "tasks": episode_summaries}, indent=2))
if __name__ == "__main__":
    main()
|