"""Run MLX LoRA training as the default local Mac training path."""

from __future__ import annotations

import argparse
import json
import shlex
import subprocess
import sys
from pathlib import Path


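# Repository root (parent of this script's directory); put it on sys.path so
# repo-local modules are importable when the script is run directly.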
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))


def write_json(path: Path, payload: dict) -> None:
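    """Write payload as pretty-printed, key-sorted JSON, creating parent dirs first."""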
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")


def main() -> None:
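    """Prepare MLX data, run mlx_lm LoRA fine-tuning, and record logs and metrics."""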
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Qwen/Qwen3.5-4B")
    parser.add_argument("--source-root", default="artifacts/lora_qwen3_4b/data")
    parser.add_argument("--output-root", default="artifacts/mlx_qwen3_4b")
    parser.add_argument("--iters", type=int, default=120)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--grad-accumulation-steps", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=5e-5)
    parser.add_argument("--num-layers", type=int, default=8)
    parser.add_argument("--max-seq-length", type=int, default=1024)
    parser.add_argument("--steps-per-report", type=int, default=1)
    parser.add_argument("--save-every", type=int, default=20)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--fresh-start", action="store_true")
    parser.add_argument("--include-valid", action="store_true")
    args = parser.parse_args()

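    # Resolve every run artifact under the output root: prepared data, logs,
    # manifest, adapter checkpoints, and speed metrics.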
    output_root = (ROOT / args.output_root).resolve()
    data_root = output_root / "data"
    log_path = output_root / "logs" / "mlx_train.log"
    manifest_path = output_root / "run_manifest.json"
    adapter_root = output_root / "adapters"
    adapter_file = adapter_root / "adapters.safetensors"
    speed_path = output_root / "metrics" / "speed_mlx.json"

    output_root.mkdir(parents=True, exist_ok=True)
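    # --fresh-start discards any previous log, speed metrics, summary, and adapter file.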
    if args.fresh_start:
        for rel in [log_path, speed_path, output_root / "training_summary.json", adapter_file]:
            if rel.exists():
                rel.unlink()

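    # Step 1: run scripts/prepare_mlx_data.py to (re)build the MLX-format dataset under data_root.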
    prepare_cmd = [
        sys.executable,
        "scripts/prepare_mlx_data.py",
        "--source-root",
        args.source_root,
        "--output-root",
        str(data_root.relative_to(ROOT)),
        "--model",
        args.model,
        "--max-seq-length",
        str(args.max_seq_length),
        "--force",
    ]
    if args.include_valid:
        prepare_cmd.append("--include-valid")
    subprocess.run(prepare_cmd, cwd=ROOT, check=True)

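    # Step 2: assemble the mlx_lm LoRA training command; the very large
    # --steps-per-eval value effectively disables mid-run evaluation.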
    cmd = [
        sys.executable,
        "-m",
        "mlx_lm",
        "lora",
        "--model",
        args.model,
        "--train",
        "--data",
        str(data_root),
        "--mask-prompt",
        "--num-layers",
        str(args.num_layers),
        "--batch-size",
        str(args.batch_size),
        "--iters",
        str(args.iters),
        "--learning-rate",
        str(args.learning_rate),
        "--steps-per-report",
        str(args.steps_per_report),
        "--steps-per-eval",
        "1000000",
        "--save-every",
        str(args.save_every),
        "--grad-accumulation-steps",
        str(args.grad_accumulation_steps),
        "--grad-checkpoint",
        "--adapter-path",
        str(adapter_root),
        "--max-seq-length",
        str(args.max_seq_length),
        "--seed",
        str(args.seed),
    ]
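    # Resume from the existing adapter checkpoint unless a fresh start was requested.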
    if not args.fresh_start and adapter_file.exists():
        cmd.extend(["--resume-adapter-file", str(adapter_file)])

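    # Record the run configuration and the exact training command before launching.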
    write_json(
        manifest_path,
        {
            "status": "starting_training",
            "trainer": "mlx_lm_lora",
            "model": args.model,
            "data_root": str(data_root),
            "output_root": str(output_root),
            "command": cmd,
            "fresh_start": args.fresh_start,
        },
    )

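    # Step 3: run training, appending stdout and stderr to the shared log file.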
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("a", encoding="utf-8") as handle:
        handle.write("\n===== mlx_lm_lora =====\n")
        handle.write("COMMAND: " + " ".join(shlex.quote(part) for part in cmd) + "\n")
        handle.flush()
        process = subprocess.run(cmd, cwd=ROOT, stdout=handle, stderr=subprocess.STDOUT, text=True)

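    # Step 4: extract throughput numbers from the training log (best effort; check=False).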
    speed_cmd = [
        sys.executable,
        "scripts/save_mlx_speed.py",
        "--log-path",
        str(log_path),
        "--output-path",
        str(speed_path),
    ]
    subprocess.run(speed_cmd, cwd=ROOT, check=False)

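    # Step 5: persist the final status and exit non-zero if training failed.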
    summary = {
        "status": "finished" if process.returncode == 0 else "failed",
        "trainer": "mlx_lm_lora",
        "return_code": process.returncode,
        "log_path": str(log_path),
        "speed_path": str(speed_path),
        "adapter_root": str(adapter_root),
    }
    write_json(output_root / "training_summary.json", summary)
    write_json(manifest_path, summary)
    if process.returncode != 0:
        raise SystemExit(process.returncode)


if __name__ == "__main__":
    main()