"""Baseline inference script for the vulnerability triage environment."""

from __future__ import annotations
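
# Typical invocations (the file name "baseline.py" below is a placeholder for this script):
#   python baseline.py                                        # heuristic policy when no API key is set
#   python baseline.py --policy openai --model gpt-4o-mini    # requires API_KEY / HF_TOKEN / OPENAI_API_KEY
#   python baseline.py --env-base-url ""                      # run against a locally instantiated environment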

import argparse
import json
import os
from typing import Dict, List

from openai import OpenAI
from openenv.core import GenericEnvClient

from models import VulnTriageAction
from server.cases import TASK_ORDER, get_case_definition
from server.vuln_triage_env_environment import VulnTriageEnvironment

API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
# Support all API-key variants the validator may inject.
_API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")


SYSTEM_PROMPT = """You are triaging open-source vulnerability reports.
Return ONLY a single JSON object (no prose, no markdown) with exactly these keys:
  action_type  : string  (required) - one of the action types listed in available_actions
  evidence_id  : string  (optional) - only used with inspect_evidence
  value        : string  (optional) - a PLAIN STRING, never an object or array
  rationale    : string  (required) - one short sentence

Valid action_type values and their expected value strings:
  read_report                     - no value needed
  inspect_evidence                - set evidence_id to one id from available_evidence
  search_nvd_database             - value: CVE ID (e.g. CVE-2023-1234) found in report aliases
  fetch_commit_diff               - value: commit hash or hash fragment found in references
  message_maintainer              - value: a question for the maintainer (e.g. "Is there a patch?")
  set_validity                    - value: "valid" | "invalid" | "needs_more_info"
  set_affected_package            - value: package name string, e.g. "guarddog"
  set_affected_versions           - value: semver range string, e.g. "<0.1.5"
  set_severity                    - value: "low" | "medium" | "high" | "critical"
  set_exploitability              - value: "low" | "medium" | "high"
  set_next_action                 - value: "patch" | "publish_advisory" | "close" | "escalate" | "request_info"
  set_missing_information         - value: one missing info item as a plain string
  submit_triage                   - no value needed

Strategy: read_report first, then use the tools (search_nvd_database, fetch_commit_diff, message_maintainer) to unlock hidden evidence, then fill all draft fields, then submit.
Note: You CANNOT inspect "nvd_assessment", "github_commit_diff", or "vendor_status" directly. You must use the tools above to reveal them.
"""

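# Illustrative example of the single JSON object the prompt asks the model to return
# (evidence_id "osv_report" is a hypothetical id, not necessarily one present in the cases):
#   {"action_type": "inspect_evidence", "evidence_id": "osv_report",
#    "rationale": "Review the original advisory text"}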

def get_openai_client() -> OpenAI:
    api_key = _API_KEY
    if not api_key:
        raise RuntimeError(
            "Set API_KEY, HF_TOKEN, or OPENAI_API_KEY before running the OpenAI baseline."
        )
    kwargs: Dict[str, str] = {"api_key": api_key}
    if API_BASE_URL:
        kwargs["base_url"] = API_BASE_URL
    return OpenAI(**kwargs)


def parse_json_response(text: str) -> Dict[str, str]:
    """Extract the first valid JSON object from a model response.

    Handles:
    - Markdown fences (```json ... ```)
    - Think-blocks from reasoning models (<think>...</think>)
    - Surrounding prose before/after the JSON object
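
    For example, a reply wrapped in a markdown fence such as:
        ```json
        {"action_type": "read_report", "rationale": "start"}
        ```
    parses to {"action_type": "read_report", "rationale": "start"}.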
    """
    import re as _re
    text = text.strip()
    # Strip reasoning/think blocks produced by models like Qwen3 or DeepSeek
    text = _re.sub(r"<think>.*?</think>", "", text, flags=_re.DOTALL | _re.IGNORECASE).strip()
    # Strip markdown fences
    if "```" in text:
        lines = [ln for ln in text.splitlines() if not ln.strip().startswith("```")]
        text = "\n".join(lines).strip()
    # Find the first complete JSON object by bracket matching
    start = text.find("{")
    if start == -1:
        raise ValueError(f"No JSON object found in model response: {text[:200]!r}")
    depth = 0
    in_string = False
    escape = False
    for i, ch in enumerate(text[start:], start):
        if escape:
            escape = False
            continue
        if ch == "\\" and in_string:
            escape = True
            continue
        if ch == '"' and not escape:
            in_string = not in_string
        if in_string:
            continue
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return json.loads(text[start : i + 1])
    raise ValueError(f"Incomplete JSON object in model response: {text[:200]!r}")


def heuristic_policy(observation: Dict) -> Dict[str, str]:
    if "read_report" not in observation["action_history"]:
        return {"action_type": "read_report", "rationale": "Start by reading the report"}

    truth = get_case_definition(observation["task_id"]).truth
    supporting_evidence_ids = set(truth.supporting_evidence_ids)
    visible_ids = {item["evidence_id"] for item in observation["visible_evidence"]}

    remaining_supporting = [
        evidence_id
        for evidence_id in observation["available_evidence"]
        if evidence_id in supporting_evidence_ids and evidence_id not in visible_ids
    ]
    if remaining_supporting:
        next_evidence_id = remaining_supporting[0]
        # Interactive tools support: these items can only be revealed via tool actions.
        if next_evidence_id == "nvd_assessment":
            # The oracle already knows the OSV id (alias) to query.
            from server.cases import SEEDS
            seed = SEEDS[observation["task_id"]]
            return {"action_type": "search_nvd_database", "value": seed.osv_id, "rationale": "Fetch NVD dynamically"}
        elif next_evidence_id == "github_commit_diff":
            # Any commit substring works for the match.
            return {"action_type": "fetch_commit_diff", "value": "Commit", "rationale": "Fetch diff dynamically"}
        elif next_evidence_id == "vendor_status":
            return {"action_type": "message_maintainer", "value": "Is there an ETA for a patch?", "rationale": "Chat with maintainer"}

        return {
            "action_type": "inspect_evidence",
            "evidence_id": eval_id,
            "rationale": "Reveal the next supporting evidence item",
        }

    draft = observation["draft"]
    score = observation["score_breakdown"]

    by_truth = [
        ("set_validity", truth.validity),
        ("set_affected_package", truth.affected_package),
        ("set_affected_versions", truth.affected_versions),
        ("set_severity", truth.severity),
        ("set_exploitability", truth.exploitability),
        ("set_next_action", truth.next_action),
    ]

    for action_type, value in by_truth:
        if draft[action_type.replace("set_", "")] != value:
            return {"action_type": action_type, "value": value, "rationale": "Update the draft"}

    # Submit any required missing-information items not yet recorded in the draft
    existing_mi = {v.strip().lower() for v in draft.get("missing_information", [])}
    for mi_item in truth.missing_information:
        if mi_item.strip().lower() not in existing_mi:
            return {
                "action_type": "set_missing_information",
                "value": mi_item,
                "rationale": "Record known missing information",
            }

    return {"action_type": "submit_triage", "rationale": f"Current total score is {score['total']}"}


def llm_policy(client: OpenAI, model_name: str, observation: Dict) -> Dict[str, str]:
    response = client.chat.completions.create(
        model=model_name,
        temperature=0,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": json.dumps(observation, indent=2, sort_keys=True),
            },
        ],
    )
    text = response.choices[0].message.content
    return parse_json_response(text)


_VALID_ACTION_KEYS = {"action_type", "evidence_id", "value", "rationale"}


def sanitize_action_payload(payload: Dict) -> Dict:
    """Keep only valid VulnTriageAction keys and coerce bad value types."""
    clean = {k: v for k, v in payload.items() if k in _VALID_ACTION_KEYS}
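    # Coerce structured values to a JSON string, e.g. a hypothetical value
    # {"level": "high"} becomes the string '{"level": "high"}'.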
    if isinstance(clean.get("value"), (dict, list)):
        clean["value"] = json.dumps(clean["value"])
    return clean


def run_local_episode(task_id: str, policy: str, model_name: str) -> Dict[str, float]:
    print(f"[START] task={task_id}", flush=True)
    env = VulnTriageEnvironment()
    observation = env.reset(task_id=task_id).model_dump()
    client = get_openai_client() if policy == "openai" else None
    last_action_str: str = ""
    repeat_count: int = 0
    step_num: int = 1

    while not observation["done"]:
        action_payload = (
            llm_policy(client, model_name, observation) if client else heuristic_policy(observation)
        )
        # Strip unknown keys then coerce bad value types
        try:
            clean = sanitize_action_payload(action_payload)
            action = VulnTriageAction.model_validate(clean)
        except Exception as exc:
            print(f"  [warn] invalid action payload ({exc}), falling back to read_report", flush=True)
            action = VulnTriageAction(action_type="read_report", rationale="fallback: parse error")

        # Break infinite loops where model repeats the same action
        action_str = action.model_dump_json()
        if action_str == last_action_str:
            repeat_count += 1
            if repeat_count >= 3:
                print(f"  [warn] model repeated same action 3x β€” forcing submit_triage", flush=True)
                action = VulnTriageAction(action_type="submit_triage", rationale="loop guard")
        else:
            repeat_count = 0
        last_action_str = action_str

        observation = env.step(action).model_dump()
        step_reward = float(observation.get("reward") or 0.0)
        print(f"[STEP] step={step_num} action={action.action_type} reward={step_reward}", flush=True)
        step_num += 1

    final_score = float(observation.get("final_score") or observation.get("score_breakdown", {}).get("total", 0.0))
    print(f"[END] task={task_id} score={final_score} steps={step_num}", flush=True)

    return {
        "task_id": task_id,
        "final_score": float(observation["final_score"] or 0.0),
        "validity": observation["score_breakdown"]["validity"],
        "package_versions": round(
            (
                observation["score_breakdown"]["affected_package"]
                + observation["score_breakdown"]["affected_versions"]
            )
            / 2,
            4,
        ),
        "severity": observation["score_breakdown"]["severity"],
        "exploitability": observation["score_breakdown"]["exploitability"],
        "next_action": observation["score_breakdown"]["next_action"],
    }


def run_remote_episode(base_url: str, task_id: str, policy: str, model_name: str) -> Dict[str, float]:
    print(f"[START] task={task_id}", flush=True)
    llm_client = get_openai_client() if policy == "openai" else None
    env = GenericEnvClient(base_url=base_url).sync()
    MAX_REMOTE_STEPS = 20
    with env:
        response = env.reset(task_id=task_id)
        observation = response.observation
        done = response.done
        step_num: int = 1
        last_action_str: str = ""
        repeat_count: int = 0

        while not done and step_num <= MAX_REMOTE_STEPS:
            # Get action from LLM or heuristic with error handling
            try:
                if llm_client:
                    action_payload = llm_policy(llm_client, model_name, observation)
                    action_payload = sanitize_action_payload(action_payload)
                else:
                    action_payload = heuristic_policy(observation)
            except Exception as exc:
                print(f"  [warn] policy error ({exc}), falling back to read_report", flush=True)
                action_payload = {"action_type": "read_report", "rationale": "fallback: policy error"}

            # Loop guard: detect repeated identical actions
            action_str = json.dumps(action_payload, sort_keys=True)
            if action_str == last_action_str:
                repeat_count += 1
                if repeat_count >= 3:
                    print(f"  [warn] repeated action 3x β€” forcing submit_triage", flush=True)
                    action_payload = {"action_type": "submit_triage", "rationale": "loop guard"}
            else:
                repeat_count = 0
            last_action_str = action_str

            # Step the environment with error handling
            try:
                response = env.step(action_payload)
                observation = response.observation
                done = response.done
            except Exception as exc:
                print(f"  [warn] env.step failed ({exc}), forcing submit", flush=True)
                action_payload = {"action_type": "submit_triage", "rationale": "env error recovery"}
                try:
                    response = env.step(action_payload)
                    observation = response.observation
                    done = response.done
                except Exception:
                    done = True

            step_reward = float(getattr(response, 'reward', None) or 0.0)
            print(f"[STEP] step={step_num} action={action_payload.get('action_type')} reward={step_reward}", flush=True)
            step_num += 1

    # Safely extract final score
    final_score = 0.0
    try:
        final_score = float(observation.get("final_score") or observation.get("score_breakdown", {}).get("total", 0.0))
    except Exception:
        pass
    print(f"[END] task={task_id} score={final_score} steps={step_num}", flush=True)

    # Build result safely
    sb = observation.get("score_breakdown", {})
    return {
        "task_id": task_id,
        "final_score": final_score,
        "validity": sb.get("validity", 0.0),
        "package_versions": round((sb.get("affected_package", 0.0) + sb.get("affected_versions", 0.0)) / 2, 4),
        "severity": sb.get("severity", 0.0),
        "exploitability": sb.get("exploitability", 0.0),
        "next_action": sb.get("next_action", 0.0),
    }


def main() -> None:
    parser = argparse.ArgumentParser()
    # Auto-select openai policy when the validator injects API credentials;
    # fall back to heuristic for local smoke-tests with no key.
    _has_credentials = bool(_API_KEY)
    _default_policy = "openai" if _has_credentials else "heuristic"
    parser.add_argument("--policy", choices=["openai", "heuristic"], default=_default_policy)
    parser.add_argument("--model", default=MODEL_NAME)
    # Default ENV_BASE_URL to the live HF Space so the validator can reach our environment
    _default_env_url = os.getenv("ENV_BASE_URL", "https://adhitya122-vulnops.hf.space")
    parser.add_argument("--env-base-url", dest="base_url", default=_default_env_url)
    args = parser.parse_args()

    results: List[Dict[str, float]] = []
    for task_id in TASK_ORDER:
        if args.base_url:
            results.append(run_remote_episode(args.base_url, task_id, args.policy, args.model))
        else:
            results.append(run_local_episode(task_id, args.policy, args.model))

    aggregate = round(sum(item["final_score"] for item in results) / len(results), 4)
    print(json.dumps({"policy": args.policy, "model": args.model, "average_score": aggregate, "tasks": results}, indent=2))


if __name__ == "__main__":
    main()