# vulnops/scripts/run_mlx_training.py
# Author: Adhitya-Vardhan
# Initial commit: VulnOps OpenEnv benchmark
# Commit: d63a1ba
"""Run MLX LoRA training as the default local Mac training path."""
from __future__ import annotations
import argparse
import json
import shlex
import subprocess
import sys
from pathlib import Path
# Project root: the directory that contains this script's own directory.
ROOT = Path(__file__).resolve().parents[1]

# Register the project root on the import path exactly once.
# NOTE(review): presumably so project-local modules can be imported when the
# script is launched directly; nothing visible here imports one -- confirm.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
def write_json(path: Path, payload: dict) -> None:
    """Write *payload* to *path* as UTF-8 JSON.

    The output is indented by two spaces, has its keys sorted, and ends with
    a single newline. Missing parent directories of *path* are created first;
    an existing file at *path* is overwritten.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    rendered = json.dumps(payload, indent=2, sort_keys=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.write(rendered + "\n")
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the MLX LoRA training wrapper."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Qwen/Qwen3.5-4B")
    parser.add_argument("--source-root", default="artifacts/lora_qwen3_4b/data")
    parser.add_argument("--output-root", default="artifacts/mlx_qwen3_4b")
    parser.add_argument("--iters", type=int, default=120)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--grad-accumulation-steps", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=5e-5)
    parser.add_argument("--num-layers", type=int, default=8)
    parser.add_argument("--max-seq-length", type=int, default=1024)
    parser.add_argument("--steps-per-report", type=int, default=1)
    parser.add_argument("--save-every", type=int, default=20)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--fresh-start", action="store_true")
    parser.add_argument("--include-valid", action="store_true")
    return parser


def _path_for_child(path: Path) -> str:
    """Return *path* relative to ROOT when it lies inside ROOT, else absolute.

    ``Path.relative_to`` raises ``ValueError`` for a path outside ROOT, which
    happens whenever ``--output-root`` is absolute or climbs out of the
    repository with ``..``.
    """
    try:
        return str(path.relative_to(ROOT))
    except ValueError:
        # NOTE(review): assumes scripts/prepare_mlx_data.py accepts an absolute
        # --output-root -- confirm against that script.
        return str(path)


def _build_train_command(args: argparse.Namespace, data_root: Path, adapter_root: Path) -> list[str]:
    """Assemble the ``python -m mlx_lm lora`` command line (without resume flags)."""
    return [
        sys.executable,
        "-m",
        "mlx_lm",
        "lora",
        "--model", args.model,
        "--train",
        "--data", str(data_root),
        "--mask-prompt",
        "--num-layers", str(args.num_layers),
        "--batch-size", str(args.batch_size),
        "--iters", str(args.iters),
        "--learning-rate", str(args.learning_rate),
        "--steps-per-report", str(args.steps_per_report),
        # Fixed, very large interval -- presumably to keep periodic
        # evaluation from triggering during the run; confirm intent.
        "--steps-per-eval", "1000000",
        "--save-every", str(args.save_every),
        "--grad-accumulation-steps", str(args.grad_accumulation_steps),
        "--grad-checkpoint",
        "--adapter-path", str(adapter_root),
        "--max-seq-length", str(args.max_seq_length),
        "--seed", str(args.seed),
    ]


def main() -> None:
    """Prepare data, run ``mlx_lm lora`` training, and record the outcome.

    Steps, in order:

    1. Parse CLI arguments and derive every output path under
       ``ROOT / --output-root``.
    2. With ``--fresh-start``, delete the previous log, speed metrics,
       training summary and adapter file.
    3. Run ``scripts/prepare_mlx_data.py`` (must succeed).
    4. Run ``python -m mlx_lm lora``, appending its combined stdout/stderr to
       the log file; resume from an existing adapter file unless
       ``--fresh-start`` was given.
    5. Run ``scripts/save_mlx_speed.py`` on the log (failure is tolerated).
    6. Write the final summary to both ``training_summary.json`` and
       ``run_manifest.json``.

    Raises:
        subprocess.CalledProcessError: if the data preparation script exits
            with a non-zero status.
        SystemExit: with the trainer's return code when training fails.
    """
    args = _build_parser().parse_args()

    output_root = (ROOT / args.output_root).resolve()
    data_root = output_root / "data"
    log_path = output_root / "logs" / "mlx_train.log"
    manifest_path = output_root / "run_manifest.json"
    summary_path = output_root / "training_summary.json"
    adapter_root = output_root / "adapters"
    adapter_file = adapter_root / "adapters.safetensors"
    speed_path = output_root / "metrics" / "speed_mlx.json"
    output_root.mkdir(parents=True, exist_ok=True)

    if args.fresh_start:
        # Drop artifacts of any previous run; absent files are not an error.
        for stale in (log_path, speed_path, summary_path, adapter_file):
            stale.unlink(missing_ok=True)

    prepare_cmd = [
        sys.executable,
        "scripts/prepare_mlx_data.py",
        "--source-root", args.source_root,
        # Relative to ROOT (the child's cwd) when possible; an output root
        # outside the repository is passed through as an absolute path
        # instead of crashing on relative_to().
        "--output-root", _path_for_child(data_root),
        "--model", args.model,
        "--max-seq-length", str(args.max_seq_length),
        "--force",
    ]
    if args.include_valid:
        prepare_cmd.append("--include-valid")
    subprocess.run(prepare_cmd, cwd=ROOT, check=True)

    cmd = _build_train_command(args, data_root, adapter_root)
    if not args.fresh_start and adapter_file.exists():
        cmd.extend(["--resume-adapter-file", str(adapter_file)])

    # Record what is about to run, before the training process starts.
    write_json(
        manifest_path,
        {
            "status": "starting_training",
            "trainer": "mlx_lm_lora",
            "model": args.model,
            "data_root": str(data_root),
            "output_root": str(output_root),
            "command": cmd,
            "fresh_start": args.fresh_start,
        },
    )

    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("a", encoding="utf-8") as handle:
        handle.write("\n===== mlx_lm_lora =====\n")
        handle.write("COMMAND: " + " ".join(shlex.quote(part) for part in cmd) + "\n")
        # Flush our header first so the child's output lands after it.
        handle.flush()
        process = subprocess.run(cmd, cwd=ROOT, stdout=handle, stderr=subprocess.STDOUT, text=True)

    # Best effort: a failure of the speed script must not mask the training result.
    subprocess.run(
        [sys.executable, "scripts/save_mlx_speed.py", "--log-path", str(log_path), "--output-path", str(speed_path)],
        cwd=ROOT,
        check=False,
    )

    summary = {
        "status": "finished" if process.returncode == 0 else "failed",
        "trainer": "mlx_lm_lora",
        "return_code": process.returncode,
        "log_path": str(log_path),
        "speed_path": str(speed_path),
        "adapter_root": str(adapter_root),
    }
    write_json(summary_path, summary)
    write_json(manifest_path, summary)

    if process.returncode != 0:
        # Propagate the trainer's failure to whoever invoked this script.
        raise SystemExit(process.returncode)
# Run the training workflow only when executed as a script, not on import.
if __name__ == "__main__":
    main()