Spaces:
Sleeping
Sleeping
| """Baseline inference script for the vulnerability triage environment.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| from typing import Dict, List, Optional | |
| from openai import OpenAI | |
| from openenv.core import GenericEnvClient | |
# Endpoint and model defaults; both overridable via environment variables.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
# Support all key variants the validator may inject
_API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
| from models import VulnTriageAction | |
| from server.cases import TASK_ORDER, get_case_definition | |
| from server.vuln_triage_env_environment import VulnTriageEnvironment | |
# System message for the LLM policy: demands a bare JSON action object whose keys
# match the VulnTriageAction schema, and enumerates the legal action_type values.
# NOTE(review): the "β" characters below look like a mojibake'd dash from the
# original prompt — left untouched because this string is sent to the model verbatim.
SYSTEM_PROMPT = """You are triaging open-source vulnerability reports.
Return ONLY a single JSON object β no prose, no markdown β with exactly these keys:
action_type : string (required) β one of the action types listed in available_actions
evidence_id : string (optional) β only used with inspect_evidence
value : string (optional) β a PLAIN STRING, never an object or array
rationale : string (required) β one short sentence
Valid action_type values and their expected value strings:
read_report β no value needed
inspect_evidence β set evidence_id to one id from available_evidence
search_nvd_database β value: CVE ID (e.g. CVE-2023-1234) found in report aliases
fetch_commit_diff β value: commit hash or hash fragment found in references
message_maintainer β value: a question for the maintainer (e.g. "Is there a patch?")
set_validity β value: "valid" | "invalid" | "needs_more_info"
set_affected_package β value: package name string, e.g. "guarddog"
set_affected_versions β value: semver range string, e.g. "<0.1.5"
set_severity β value: "low" | "medium" | "high" | "critical"
set_exploitability β value: "low" | "medium" | "high"
set_next_action β value: "patch" | "publish_advisory" | "close" | "escalate" | "request_info"
set_missing_information β value: one missing info item as a plain string
submit_triage β no value needed
Strategy: read_report first, then use tools (search_nvd, fetch_commit, message_maintainer) to unlock hidden evidence, then fill all draft fields, then submit.
Note: You CANNOT inspect "nvd_assessment", "github_commit_diff", or "vendor_status" directly. You must use the tools above to reveal them.
"""
def get_openai_client() -> OpenAI:
    """Build an OpenAI client from environment-provided credentials.

    Raises:
        RuntimeError: when none of the supported API-key variables are set.
    """
    if not _API_KEY:
        raise RuntimeError(
            "Set API_KEY, HF_TOKEN, or OPENAI_API_KEY before running the OpenAI baseline."
        )
    client_kwargs: Dict[str, str] = {"api_key": _API_KEY}
    if API_BASE_URL:
        # Allows pointing the baseline at any OpenAI-compatible endpoint.
        client_kwargs["base_url"] = API_BASE_URL
    return OpenAI(**client_kwargs)
def parse_json_response(text: str) -> Dict[str, str]:
    """Extract the first valid JSON object from a model response.

    Handles:
    - Markdown fences (```json ... ```)
    - Think-blocks from reasoning models (<think>...</think>)
    - Surrounding prose before/after the JSON object
    """
    import re as _re

    cleaned = text.strip()
    # Drop reasoning/think blocks produced by models like Qwen3 or DeepSeek.
    cleaned = _re.sub(r"<think>.*?</think>", "", cleaned, flags=_re.DOTALL | _re.IGNORECASE).strip()
    # Drop markdown fence lines while keeping the fenced content.
    if "```" in cleaned:
        kept = (line for line in cleaned.splitlines() if not line.strip().startswith("```"))
        cleaned = "\n".join(kept).strip()
    # Locate the first complete JSON object via brace matching that is
    # aware of string literals and backslash escapes.
    start = cleaned.find("{")
    if start == -1:
        raise ValueError(f"No JSON object found in model response: {cleaned[:200]!r}")
    depth = 0
    in_string = False
    escape = False
    for idx in range(start, len(cleaned)):
        ch = cleaned[idx]
        if escape:
            # Previous char was a backslash inside a string: skip this one.
            escape = False
            continue
        if in_string:
            if ch == "\\":
                escape = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return json.loads(cleaned[start : idx + 1])
    raise ValueError(f"Incomplete JSON object in model response: {cleaned[:200]!r}")
def heuristic_policy(observation: Dict) -> Dict[str, str]:
    """Oracle policy that replays the ground-truth triage for the current task."""
    if "read_report" not in observation["action_history"]:
        return {"action_type": "read_report", "rationale": "Start by reading the report"}
    truth = get_case_definition(observation["task_id"]).truth
    supporting = set(truth.supporting_evidence_ids)
    revealed = {entry["evidence_id"] for entry in observation["visible_evidence"]}
    pending = [
        eid
        for eid in observation["available_evidence"]
        if eid in supporting and eid not in revealed
    ]
    if pending:
        next_id = pending[0]
        # Interactive Tools Support:
        if next_id == "nvd_assessment":
            # The oracle magically knows the OSV ID to query (alias)
            from server.cases import SEEDS
            seed = SEEDS[observation["task_id"]]
            return {"action_type": "search_nvd_database", "value": seed.osv_id, "rationale": "Fetch NVD dynamically"}
        if next_id == "github_commit_diff":
            # Match any random commit substring
            return {"action_type": "fetch_commit_diff", "value": "Commit", "rationale": "Fetch Diff dynamically"}
        if next_id == "vendor_status":
            return {"action_type": "message_maintainer", "value": "Is there an ETA for a patch?", "rationale": "Chat with maintainer"}
        return {
            "action_type": "inspect_evidence",
            "evidence_id": next_id,
            "rationale": "Reveal the next supporting evidence item",
        }
    draft = observation["draft"]
    breakdown = observation["score_breakdown"]
    field_targets = (
        ("set_validity", truth.validity),
        ("set_affected_package", truth.affected_package),
        ("set_affected_versions", truth.affected_versions),
        ("set_severity", truth.severity),
        ("set_exploitability", truth.exploitability),
        ("set_next_action", truth.next_action),
    )
    # Fill the first draft field that still disagrees with the ground truth.
    for action_type, target in field_targets:
        if draft[action_type.replace("set_", "")] != target:
            return {"action_type": action_type, "value": target, "rationale": "Update the draft"}
    # Submit any required missing-information items not yet recorded in the draft
    recorded = {item.strip().lower() for item in draft.get("missing_information", [])}
    for needed in truth.missing_information:
        if needed.strip().lower() not in recorded:
            return {
                "action_type": "set_missing_information",
                "value": needed,
                "rationale": "Record known missing information",
            }
    return {"action_type": "submit_triage", "rationale": f"Current total score is {breakdown['total']}"}
def llm_policy(client: OpenAI, model_name: str, observation: Dict) -> Dict[str, str]:
    """Ask the chat model for one action and parse its JSON reply."""
    completion = client.chat.completions.create(
        model=model_name,
        temperature=0,  # deterministic decoding for reproducible baselines
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": json.dumps(observation, indent=2, sort_keys=True)},
        ],
    )
    return parse_json_response(completion.choices[0].message.content)
# Keys accepted by the VulnTriageAction schema; everything else is dropped.
_VALID_ACTION_KEYS = {"action_type", "evidence_id", "value", "rationale"}


def sanitize_action_payload(payload: Dict) -> Dict:
    """Keep only valid VulnTriageAction keys and coerce bad value types."""
    cleaned = {key: val for key, val in payload.items() if key in _VALID_ACTION_KEYS}
    value = cleaned.get("value")
    # Models sometimes emit structured values; the schema wants a plain string.
    if isinstance(value, (dict, list)):
        cleaned["value"] = json.dumps(value)
    return cleaned
def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, float]:
    """Run one in-process episode and return its per-field score summary.

    Args:
        task_id: Case identifier understood by VulnTriageEnvironment.
        policy: "openai" to use the LLM policy, anything else for the oracle heuristic.
        model_name: Chat model used when policy == "openai".

    Returns:
        Dict with the task id, final score, and the individual score-breakdown fields.
    """
    print(f"[START] task={task_id}", flush=True)
    env = VulnTriageEnvironment()
    observation = env.reset(task_id=task_id).model_dump()
    client = get_openai_client() if policy == "openai" else None
    last_action_str: str = ""
    repeat_count: int = 0
    steps_taken: int = 0
    while not observation["done"]:
        action_payload = (
            llm_policy(client, model_name, observation) if client else heuristic_policy(observation)
        )
        # Strip unknown keys then coerce bad value types
        try:
            clean = sanitize_action_payload(action_payload)
            action = VulnTriageAction.model_validate(clean)
        except Exception as exc:
            print(f" [warn] invalid action payload ({exc}), falling back to read_report", flush=True)
            action = VulnTriageAction(action_type="read_report", rationale="fallback: parse error")
        # Break infinite loops where model repeats the same action
        action_str = action.model_dump_json()
        if action_str == last_action_str:
            repeat_count += 1
            if repeat_count >= 3:
                print(f" [warn] model repeated same action 3x β forcing submit_triage", flush=True)
                action = VulnTriageAction(action_type="submit_triage", rationale="loop guard")
        else:
            repeat_count = 0
        last_action_str = action_str
        observation = env.step(action).model_dump()
        steps_taken += 1
        step_reward = float(observation.get("reward") or 0.0)
        print(f"[STEP] step={steps_taken} action={action.action_type} reward={step_reward}", flush=True)
    # Fall back to the score-breakdown total when final_score is absent or None.
    # FIX: the original recomputed the score in the return dict by indexing
    # observation["final_score"] directly, which raises KeyError when the key is
    # missing; reuse this safely-computed value instead. Also log the actual
    # number of steps taken (the original printed a value one too high).
    final_score = float(
        observation.get("final_score") or observation.get("score_breakdown", {}).get("total", 0.0)
    )
    print(f"[END] task={task_id} score={final_score} steps={steps_taken}", flush=True)
    sb = observation["score_breakdown"]
    return {
        "task_id": task_id,
        "final_score": final_score,
        "validity": sb["validity"],
        "package_versions": round((sb["affected_package"] + sb["affected_versions"]) / 2, 4),
        "severity": sb["severity"],
        "exploitability": sb["exploitability"],
        "next_action": sb["next_action"],
    }
def run_remote_episode(base_url: str, task_id: str, policy: str, model_name: str) -> Dict[str, float]:
    """Run one episode against a remote environment server and return its score summary.

    Args:
        base_url: Base URL of the GenericEnvClient-compatible environment server.
        task_id: Case identifier forwarded to the remote reset call.
        policy: "openai" to use the LLM policy, anything else for the oracle heuristic.
        model_name: Chat model used when policy == "openai".

    Returns:
        Dict with the task id, final score, and the score-breakdown fields
        (defaulting to 0.0 wherever the remote observation lacks data).
    """
    print(f"[START] task={task_id}", flush=True)
    llm_client = get_openai_client() if policy == "openai" else None
    env = GenericEnvClient(base_url=base_url).sync()
    MAX_REMOTE_STEPS = 20  # hard cap so a misbehaving policy cannot loop forever
    with env:
        response = env.reset(task_id=task_id)
        observation = response.observation
        done = response.done
        step_num: int = 1
        last_action_str: str = ""
        repeat_count: int = 0
        while not done and step_num <= MAX_REMOTE_STEPS:
            # Get action from LLM or heuristic with error handling
            try:
                if llm_client:
                    action_payload = llm_policy(llm_client, model_name, observation)
                    action_payload = sanitize_action_payload(action_payload)
                else:
                    action_payload = heuristic_policy(observation)
            except Exception as exc:
                print(f" [warn] policy error ({exc}), falling back to read_report", flush=True)
                action_payload = {"action_type": "read_report", "rationale": "fallback: policy error"}
            # Loop guard: detect repeated identical actions
            action_str = json.dumps(action_payload, sort_keys=True)
            if action_str == last_action_str:
                repeat_count += 1
                if repeat_count >= 3:
                    print(f" [warn] repeated action 3x β forcing submit_triage", flush=True)
                    action_payload = {"action_type": "submit_triage", "rationale": "loop guard"}
            else:
                repeat_count = 0
            last_action_str = action_str
            # Step the environment with error handling
            try:
                response = env.step(action_payload)
                observation = response.observation
                done = response.done
            except Exception as exc:
                print(f" [warn] env.step failed ({exc}), forcing submit", flush=True)
                action_payload = {"action_type": "submit_triage", "rationale": "env error recovery"}
                try:
                    # One recovery attempt: force a submit so the episode can score.
                    response = env.step(action_payload)
                    observation = response.observation
                    done = response.done
                except Exception:
                    # Give up; fall through with the last good observation/response.
                    done = True
            step_reward = float(getattr(response, 'reward', None) or 0.0)
            print(f"[STEP] step={step_num} action={action_payload.get('action_type')} reward={step_reward}", flush=True)
            step_num += 1
    # Safely extract final score
    # NOTE(review): observation is assumed to be a dict-like object with .get —
    # confirm GenericEnvClient returns plain dict observations.
    final_score = 0.0
    try:
        final_score = float(observation.get("final_score") or observation.get("score_breakdown", {}).get("total", 0.0))
    except Exception:
        pass
    print(f"[END] task={task_id} score={final_score} steps={step_num}", flush=True)
    # Build result safely
    sb = observation.get("score_breakdown", {})
    return {
        "task_id": task_id,
        "final_score": final_score,
        "validity": sb.get("validity", 0.0),
        "package_versions": round((sb.get("affected_package", 0.0) + sb.get("affected_versions", 0.0)) / 2, 4),
        "severity": sb.get("severity", 0.0),
        "exploitability": sb.get("exploitability", 0.0),
        "next_action": sb.get("next_action", 0.0),
    }
def main() -> None:
    """CLI entry point: run every task once and print aggregate scores as JSON."""
    parser = argparse.ArgumentParser()
    # Auto-select openai policy when the validator injects API credentials;
    # fall back to heuristic for local smoke-tests with no key.
    parser.add_argument(
        "--policy",
        choices=["openai", "heuristic"],
        default="openai" if _API_KEY else "heuristic",
    )
    parser.add_argument("--model", default=MODEL_NAME)
    # Default ENV_BASE_URL to the live HF Space so the validator can reach our environment
    parser.add_argument(
        "--env-base-url",
        dest="base_url",
        default=os.getenv("ENV_BASE_URL", "https://adhitya122-vulnops.hf.space"),
    )
    args = parser.parse_args()
    results: List[Dict[str, float]] = []
    for task_id in TASK_ORDER:
        if args.base_url:
            episode = run_remote_episode(args.base_url, task_id, args.policy, args.model)
        else:
            episode = run_local_episode(task_id, args.policy, args.model)
        results.append(episode)
    aggregate = round(sum(item["final_score"] for item in results) / len(results), 4)
    print(json.dumps({"policy": args.policy, "model": args.model, "average_score": aggregate, "tasks": results}, indent=2))


if __name__ == "__main__":
    main()