Spaces:
Sleeping
Sleeping
| """Run MLX LoRA training as the default local Mac training path.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import shlex | |
| import subprocess | |
| import sys | |
| from pathlib import Path | |
# Repository root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Put the root at the front of the import path (once) so modules located
# under it can be imported when this file is executed directly.
if sys.path.count(str(ROOT)) == 0:
    sys.path.insert(0, str(ROOT))
def write_json(path: Path, payload: dict) -> None:
    """Write *payload* to *path* as human-readable JSON.

    Any missing parent directories of *path* are created first. The output
    uses a two-space indent, sorted keys, a trailing newline and UTF-8
    encoding; an existing file at *path* is overwritten.
    """
    directory = path.parent
    directory.mkdir(parents=True, exist_ok=True)
    rendered = json.dumps(payload, indent=2, sort_keys=True)
    path.write_text(f"{rendered}\n", encoding="utf-8")
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options that configure the MLX LoRA run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Qwen/Qwen3.5-4B")
    parser.add_argument("--source-root", default="artifacts/lora_qwen3_4b/data")
    parser.add_argument("--output-root", default="artifacts/mlx_qwen3_4b")
    parser.add_argument("--iters", type=int, default=120)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--grad-accumulation-steps", type=int, default=8)
    parser.add_argument("--learning-rate", type=float, default=5e-5)
    parser.add_argument("--num-layers", type=int, default=8)
    parser.add_argument("--max-seq-length", type=int, default=1024)
    parser.add_argument("--steps-per-report", type=int, default=1)
    parser.add_argument("--save-every", type=int, default=20)
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument("--fresh-start", action="store_true")
    parser.add_argument("--include-valid", action="store_true")
    return parser.parse_args()


def _path_for_child(path: Path) -> str:
    """Return *path* relative to ROOT when it lies inside ROOT, else absolute."""
    try:
        return str(path.relative_to(ROOT))
    except ValueError:
        # --output-root may resolve outside the repository (absolute path,
        # "..", or a symlink). The child process runs with cwd=ROOT, so an
        # absolute path still identifies the same directory.
        # NOTE(review): assumes scripts/prepare_mlx_data.py accepts an
        # absolute --output-root — confirm against that script.
        return str(path)


def _build_prepare_cmd(args: argparse.Namespace, data_root: Path) -> list[str]:
    """Build the argv for the dataset preparation script."""
    prepare_cmd = [
        sys.executable,
        "scripts/prepare_mlx_data.py",
        "--source-root",
        args.source_root,
        "--output-root",
        _path_for_child(data_root),
        "--model",
        args.model,
        "--max-seq-length",
        str(args.max_seq_length),
        "--force",
    ]
    if args.include_valid:
        prepare_cmd.append("--include-valid")
    return prepare_cmd


def _build_train_cmd(
    args: argparse.Namespace,
    data_root: Path,
    adapter_root: Path,
    adapter_file: Path,
) -> list[str]:
    """Build the argv for the ``python -m mlx_lm lora`` training invocation."""
    cmd = [
        sys.executable,
        "-m",
        "mlx_lm",
        "lora",
        "--model",
        args.model,
        "--train",
        "--data",
        str(data_root),
        "--mask-prompt",
        "--num-layers",
        str(args.num_layers),
        "--batch-size",
        str(args.batch_size),
        "--iters",
        str(args.iters),
        "--learning-rate",
        str(args.learning_rate),
        "--steps-per-report",
        str(args.steps_per_report),
        # Very large interval; presumably meant to keep periodic evaluation
        # from triggering during a run — confirm against mlx_lm.
        "--steps-per-eval",
        "1000000",
        "--save-every",
        str(args.save_every),
        "--grad-accumulation-steps",
        str(args.grad_accumulation_steps),
        "--grad-checkpoint",
        "--adapter-path",
        str(adapter_root),
        "--max-seq-length",
        str(args.max_seq_length),
        "--seed",
        str(args.seed),
    ]
    # Continue from the existing adapter unless a fresh start was requested.
    if not args.fresh_start and adapter_file.exists():
        cmd.extend(["--resume-adapter-file", str(adapter_file)])
    return cmd


def main() -> None:
    """Prepare the data, run ``mlx_lm lora`` training, and record the outcome.

    Steps, all executed with ``cwd=ROOT``:

    1. Resolve the output layout under ``--output-root`` and, with
       ``--fresh-start``, delete the previous log, speed metrics, training
       summary and adapter file.
    2. Run ``scripts/prepare_mlx_data.py`` (failure aborts the run).
    3. Write a ``starting_training`` manifest, then run the trainer with its
       stdout/stderr appended to the log file.
    4. Run ``scripts/save_mlx_speed.py`` on the log (best effort).
    5. Write the final summary to both ``training_summary.json`` and the
       manifest path.

    Raises:
        subprocess.CalledProcessError: if the data preparation step fails.
        SystemExit: with the trainer's return code when training fails.
    """
    args = _parse_args()

    output_root = (ROOT / args.output_root).resolve()
    data_root = output_root / "data"
    log_path = output_root / "logs" / "mlx_train.log"
    manifest_path = output_root / "run_manifest.json"
    summary_path = output_root / "training_summary.json"
    adapter_root = output_root / "adapters"
    adapter_file = adapter_root / "adapters.safetensors"
    speed_path = output_root / "metrics" / "speed_mlx.json"
    output_root.mkdir(parents=True, exist_ok=True)

    if args.fresh_start:
        for stale in (log_path, speed_path, summary_path, adapter_file):
            stale.unlink(missing_ok=True)

    subprocess.run(_build_prepare_cmd(args, data_root), cwd=ROOT, check=True)

    # Built after the prepare step so the resume check sees the current
    # state of the adapter file.
    cmd = _build_train_cmd(args, data_root, adapter_root, adapter_file)

    write_json(
        manifest_path,
        {
            "status": "starting_training",
            "trainer": "mlx_lm_lora",
            "model": args.model,
            "data_root": str(data_root),
            "output_root": str(output_root),
            "command": cmd,
            "fresh_start": args.fresh_start,
        },
    )

    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("a", encoding="utf-8") as handle:
        handle.write("\n===== mlx_lm_lora =====\n")
        handle.write("COMMAND: " + " ".join(shlex.quote(part) for part in cmd) + "\n")
        # Flush our header before the child starts writing to the same file.
        handle.flush()
        process = subprocess.run(cmd, cwd=ROOT, stdout=handle, stderr=subprocess.STDOUT, text=True)

    # Best effort: runs whether or not training succeeded, and a failure here
    # (check=False) does not change the outcome of the run.
    subprocess.run(
        [
            sys.executable,
            "scripts/save_mlx_speed.py",
            "--log-path",
            str(log_path),
            "--output-path",
            str(speed_path),
        ],
        cwd=ROOT,
        check=False,
    )

    summary = {
        "status": "finished" if process.returncode == 0 else "failed",
        "trainer": "mlx_lm_lora",
        "return_code": process.returncode,
        "log_path": str(log_path),
        "speed_path": str(speed_path),
        "adapter_root": str(adapter_root),
    }
    write_json(summary_path, summary)
    # The manifest is overwritten with the final summary.
    write_json(manifest_path, summary)

    if process.returncode != 0:
        raise SystemExit(process.returncode)
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()