vulnops / scripts /dump_mlx_generation.py
Adhitya-Vardhan
Initial commit: VulnOps OpenEnv benchmark
d63a1ba
"""Dump a full raw generation from the MLX model for one vulnops observation."""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from mlx_lm import generate, load
from mlx_lm.sample_utils import make_sampler
from server.vuln_triage_env_environment import VulnTriageEnvironment
from training_utils import render_prompt
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="Qwen/Qwen3.5-4B")
parser.add_argument("--adapter-path", default="artifacts/mlx_qwen3_4b/adapters")
parser.add_argument("--task-id", default="task_easy_guarddog")
parser.add_argument("--max-tokens", type=int, default=2048)
parser.add_argument(
"--output-file",
default="artifacts/mlx_qwen3_4b/inspection/task_easy_guarddog_latest_raw_output.json",
)
args = parser.parse_args()
model, tokenizer = load(args.model, adapter_path=args.adapter_path)
env = VulnTriageEnvironment()
observation = env.reset(task_id=args.task_id).model_dump()
prompt = render_prompt(observation, "Return only the best next action in JSON.")
raw_output = generate(
model,
tokenizer,
prompt=prompt,
verbose=False,
max_tokens=args.max_tokens,
sampler=make_sampler(temp=0.0),
)
output_path = Path(args.output_file)
if not output_path.is_absolute():
output_path = (ROOT / output_path).resolve()
output_path.parent.mkdir(parents=True, exist_ok=True)
payload = {
"task_id": args.task_id,
"model": args.model,
"adapter_path": args.adapter_path,
"max_tokens": args.max_tokens,
"prompt": prompt,
"raw_output": raw_output,
}
output_path.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")
print(output_path)
if __name__ == "__main__":
main()