File size: 4,365 Bytes
d63a1ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""Evaluate base or MLX-adapted Qwen models on the local vulnops environment."""

from __future__ import annotations

import argparse
import json
import re
import sys
from pathlib import Path
from typing import Dict, List

# ROOT is the parent of the directory that contains this file (resolved to an
# absolute path). It is also used as the base for relative output paths.
ROOT = Path(__file__).resolve().parents[1]
# Put ROOT at the front of sys.path (only once) before the imports below run.
# NOTE(review): presumably this lets the project-local modules imported below
# resolve when the file is executed directly as a script - confirm.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from mlx_lm import generate, load
from mlx_lm.sample_utils import make_sampler

from models import VulnTriageAction
from server.cases import TASK_ORDER
from server.vuln_triage_env_environment import VulnTriageEnvironment
from training_utils import render_prompt


THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)


def extract_last_json_object(text: str) -> str | None:
    cleaned = THINK_BLOCK_RE.sub("", text).strip()
    start = cleaned.find("{")
    if start == -1:
        return None
    depth = 0
    in_string = False
    escape = False
    last_candidate = None
    candidate_start = None
    for index, ch in enumerate(cleaned):
        if ch == "\\" and in_string and not escape:
            escape = True
            continue
        if ch == '"' and not escape:
            in_string = not in_string
        escape = False
        if in_string:
            continue
        if ch == "{":
            if depth == 0:
                candidate_start = index
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0 and candidate_start is not None:
                last_candidate = cleaned[candidate_start : index + 1]
    return last_candidate


def parse_action_output(text: str) -> Dict[str, object] | None:
    """Turn raw model output into a validated action dictionary.

    The last brace-balanced span in *text* is decoded as JSON and validated
    with ``VulnTriageAction.model_validate``.

    Args:
        text: Raw text produced by the model.

    Returns:
        The validated action dumped with ``exclude_none=True``, or ``None``
        when no span is found or decoding/validation raises.
    """
    raw_json = extract_last_json_object(text)
    if raw_json is None:
        return None
    try:
        validated = VulnTriageAction.model_validate(json.loads(raw_json))
    except Exception:
        # Best-effort parsing: any decode or validation failure means
        # "no usable action" and is reported to the caller as None.
        return None
    else:
        return validated.model_dump(exclude_none=True)


def next_action(model, tokenizer, observation: Dict[str, object]) -> Dict[str, object]:
    """Ask the model for the next action given the current observation.

    A prompt is rendered from *observation*, a completion of at most 192
    tokens is generated with a sampler built from ``temp=0.0``, and the
    completion is parsed into an action dictionary.

    Args:
        model: Model object passed through to ``generate``.
        tokenizer: Tokenizer object passed through to ``generate``.
        observation: Current environment observation as a plain dictionary.

    Returns:
        The parsed action dictionary. If the completion cannot be parsed, a
        ``submit_triage`` fallback is returned whose rationale quotes the
        first 120 characters of the completion.
    """
    prompt_text = render_prompt(
        observation=observation,
        prompt_variant="Return only the best next action in JSON.",
    )
    generation_options = {
        "prompt": prompt_text,
        "verbose": False,
        "max_tokens": 192,
        "sampler": make_sampler(temp=0.0),
    }
    output = generate(model, tokenizer, **generation_options)

    parsed = parse_action_output(output)
    if parsed is not None:
        return parsed
    # Unparseable completion: return a submit_triage action with the reason.
    return {
        "action_type": "submit_triage",
        "rationale": f"Fallback because model output could not be parsed: {output[:120]}",
    }


def run_episode(model, tokenizer, task_id: str) -> Dict[str, object]:
    """Play one task to completion and summarise the outcome.

    A fresh ``VulnTriageEnvironment`` is reset for *task_id*; actions are then
    requested from the model and stepped through the environment until the
    observation reports ``done``.

    Args:
        model: Model object forwarded to ``next_action``.
        tokenizer: Tokenizer object forwarded to ``next_action``.
        task_id: Identifier of the task to reset the environment with.

    Returns:
        A dictionary with the task id, the final observation's difficulty and
        score breakdown, the final score as a float (0.0 when missing or
        falsy), the number of steps taken, and the list of actions taken.
    """
    environment = VulnTriageEnvironment()
    state = environment.reset(task_id=task_id).model_dump()
    history: List[Dict[str, object]] = []

    while not state["done"]:
        proposed = next_action(model, tokenizer, state)
        # Re-validate so the environment receives an action object, not a dict.
        validated = VulnTriageAction.model_validate(proposed)
        history.append(validated.model_dump(exclude_none=True))
        state = environment.step(validated).model_dump()

    summary: Dict[str, object] = {"task_id": task_id}
    summary["difficulty"] = state["difficulty"]
    summary["final_score"] = float(state.get("final_score") or 0.0)
    summary["score_breakdown"] = state["score_breakdown"]
    summary["steps_used"] = len(history)
    summary["actions"] = history
    return summary


def main() -> None:
    """Evaluate a model on every task in ``TASK_ORDER`` and report the scores.

    Command-line options:
        --model: Model identifier passed to ``load`` (default
            ``Qwen/Qwen3.5-4B``).
        --adapter-path: Optional adapter path passed to ``load``.
        --output-json: Optional file to write the report to; a relative path
            is resolved against ``ROOT``.

    The report (model, adapter path, average score rounded to four decimals,
    and per-episode results) is always printed to stdout as indented JSON with
    sorted keys.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--model", default="Qwen/Qwen3.5-4B")
    cli.add_argument("--adapter-path")
    cli.add_argument("--output-json")
    options = cli.parse_args()

    model, tokenizer = load(options.model, adapter_path=options.adapter_path)
    episodes = [run_episode(model, tokenizer, task_id) for task_id in TASK_ORDER]

    score_total = sum(episode["final_score"] for episode in episodes)
    report = {
        "model": options.model,
        "adapter_path": options.adapter_path,
        "average_score": round(score_total / len(episodes), 4),
        "episodes": episodes,
    }
    dump_options = {"indent": 2, "sort_keys": True}

    if options.output_json:
        target = Path(options.output_json)
        # Relative output paths are anchored at ROOT, not the working directory.
        target = target if target.is_absolute() else (ROOT / target).resolve()
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(json.dumps(report, **dump_options) + "\n", encoding="utf-8")
    print(json.dumps(report, **dump_options))


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()