File size: 2,009 Bytes
d63a1ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""Dump a full raw generation from the MLX model for one vulnops observation."""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

# ROOT is the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]

# Put ROOT at the front of the import search path exactly once; presumably this
# is what lets the project-local imports further down resolve when the file is
# run directly as a script.
if sys.path.count(str(ROOT)) == 0:
    sys.path.insert(0, str(ROOT))

from mlx_lm import generate, load
from mlx_lm.sample_utils import make_sampler

from server.vuln_triage_env_environment import VulnTriageEnvironment
from training_utils import render_prompt


def main() -> None:
    """Run one generation for a single task observation and dump it as JSON.

    Steps, in order:

    1. Parse the command-line options (``--model``, ``--adapter-path``,
       ``--task-id``, ``--max-tokens``, ``--output-file``).
    2. Load the model and tokenizer together with the adapter.
    3. Reset a ``VulnTriageEnvironment`` to the requested task and render a
       prompt from the resulting observation.
    4. Generate with a sampler built with ``temp=0.0``.
    5. Write the run settings, the prompt and the raw output to the output
       file and print that file's path to stdout.

    Path handling: a relative ``--output-file`` is anchored at ``ROOT``. A
    relative ``--adapter-path`` is used as given when it exists relative to
    the current working directory; otherwise it is anchored at ``ROOT`` when
    it exists there, so the repo-relative default also works when the script
    is started from another directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Qwen/Qwen3.5-4B")
    parser.add_argument("--adapter-path", default="artifacts/mlx_qwen3_4b/adapters")
    parser.add_argument("--task-id", default="task_easy_guarddog")
    parser.add_argument("--max-tokens", type=int, default=2048)
    parser.add_argument(
        "--output-file",
        default="artifacts/mlx_qwen3_4b/inspection/task_easy_guarddog_latest_raw_output.json",
    )
    args = parser.parse_args()

    model, tokenizer = load(
        args.model,
        adapter_path=_resolve_adapter_path(args.adapter_path),
    )
    env = VulnTriageEnvironment()
    observation = env.reset(task_id=args.task_id).model_dump()
    prompt = render_prompt(observation, "Return only the best next action in JSON.")
    raw_output = generate(
        model,
        tokenizer,
        prompt=prompt,
        verbose=False,
        max_tokens=args.max_tokens,
        sampler=make_sampler(temp=0.0),
    )

    output_path = Path(args.output_file)
    if not output_path.is_absolute():
        # Relative output locations are repo-relative, not cwd-relative.
        output_path = (ROOT / output_path).resolve()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "task_id": args.task_id,
        "model": args.model,
        # Record the value exactly as given on the command line.
        "adapter_path": args.adapter_path,
        "max_tokens": args.max_tokens,
        "prompt": prompt,
        "raw_output": raw_output,
    }
    output_path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")
    print(output_path)


def _resolve_adapter_path(raw_path: str) -> str:
    """Pick the adapter location to pass to ``load`` for *raw_path*.

    Absolute paths, and relative paths that exist relative to the current
    working directory, are returned unchanged. A relative path that does not
    exist there is anchored at ``ROOT`` if that location exists. When neither
    location exists the original value is returned untouched, so whatever
    error the loader raises for a missing adapter is the same as before.
    """
    candidate = Path(raw_path)
    if candidate.is_absolute() or candidate.exists():
        return raw_path
    anchored = ROOT / candidate
    if anchored.exists():
        return str(anchored)
    return raw_path


# Run the dump only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()