File size: 4,655 Bytes
d63a1ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Evaluate a base or LoRA-adapted model on the local vulnops environment."""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from server.cases import TASK_ORDER
from training_utils import (
    detect_device,
    maybe_parse_action,
    preferred_torch_dtype,
    render_prompt,
    set_default_env,
)
from models import VulnTriageAction
from server.vuln_triage_env_environment import VulnTriageEnvironment


def load_model(model_name: str, adapter_path: str | None, output_root: Path):
    """Load tokenizer and model (optionally wrapped with a LoRA adapter).

    Returns a ``(model, tokenizer, device)`` triple with the model moved onto
    the detected accelerator (when there is one) and switched to eval mode.
    """
    set_default_env(output_root)
    device = detect_device()
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Some checkpoints ship without a pad token; reuse EOS so generation works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=preferred_torch_dtype(device),
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )

    if adapter_path:
        try:
            from peft import PeftModel
        except ImportError as exc:
            raise RuntimeError("peft is required to evaluate a LoRA adapter.") from exc
        model = PeftModel.from_pretrained(model, adapter_path)

    # CPU placement is the default; only accelerators need an explicit move.
    if device == "cuda" or device == "mps":
        model.to(device)
    model.eval()
    return model, tokenizer, device


@torch.inference_mode()
def next_action(model, tokenizer, device: str, observation: Dict[str, object]) -> Dict[str, object]:
    """Greedy-decode one JSON action payload for the given observation.

    When the model output cannot be parsed into an action, a fallback
    ``submit_triage`` payload is returned instead of raising.
    """
    prompt = render_prompt(
        observation=observation,
        prompt_variant="Return only the best next action in JSON.",
    )
    inputs = {
        name: tensor.to(device)
        for name, tensor in tokenizer(prompt, return_tensors="pt").items()
    }
    sequences = model.generate(
        **inputs,
        max_new_tokens=192,
        do_sample=False,
        temperature=None,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Drop the echoed prompt tokens; decode only the newly generated tail.
    completion = tokenizer.decode(
        sequences[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    ).strip()
    parsed = maybe_parse_action(completion)
    if parsed is not None:
        return parsed
    return {
        "action_type": "submit_triage",
        "rationale": f"Fallback because model output could not be parsed: {completion[:120]}",
    }


def run_episode(model, tokenizer, device: str, task_id: str) -> Dict[str, object]:
    """Roll out one full episode on ``task_id`` and summarize the outcome."""
    env = VulnTriageEnvironment()
    observation = env.reset(task_id=task_id).model_dump()
    taken: List[Dict[str, object]] = []
    while not observation["done"]:
        payload = next_action(model, tokenizer, device, observation)
        action = VulnTriageAction.model_validate(payload)
        taken.append(action.model_dump(exclude_none=True))
        observation = env.step(action).model_dump()
    return {
        "task_id": task_id,
        "difficulty": observation["difficulty"],
        # ``final_score`` may be missing or None on the terminal observation.
        "final_score": float(observation.get("final_score") or 0.0),
        "score_breakdown": observation["score_breakdown"],
        "steps_used": len(taken),
        "actions": taken,
    }


def main() -> None:
    """CLI entry point: evaluate the model on every task in ``TASK_ORDER``.

    Always prints the JSON payload to stdout; additionally writes it to
    ``--output-json`` (resolved relative to the repo root) when provided.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # NOTE(review): "Qwen3.5-4B" looks inconsistent with the default output
    # root "lora_qwen3_4b" — confirm the intended checkpoint id.
    parser.add_argument("--model", default="Qwen/Qwen3.5-4B")
    parser.add_argument("--adapter-path")
    parser.add_argument("--output-root", default="artifacts/lora_qwen3_4b")
    parser.add_argument("--output-json")
    args = parser.parse_args()

    output_root = (ROOT / args.output_root).resolve()
    model, tokenizer, device = load_model(args.model, args.adapter_path, output_root)
    episodes = [run_episode(model, tokenizer, device, task_id) for task_id in TASK_ORDER]
    # Guard against an empty task list: avoid ZeroDivisionError in the mean.
    if episodes:
        average_score = round(sum(item["final_score"] for item in episodes) / len(episodes), 4)
    else:
        average_score = 0.0
    payload = {
        "model": args.model,
        "adapter_path": args.adapter_path,
        "device": device,
        "average_score": average_score,
        "episodes": episodes,
    }
    if args.output_json:
        output_path = Path(args.output_json)
        if not output_path.is_absolute():
            output_path = (ROOT / output_path).resolve()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        output_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
    print(json.dumps(payload, indent=2, sort_keys=True))


if __name__ == "__main__":
    main()