"""
Convert a PyTorch DFlash drafter from Hugging Face to MLX format.

Usage:
    python convert_drafter.py --model z-lab/Qwen3-4B-DFlash-b16 --output ./Qwen3-4B-DFlash-mlx
    python convert_drafter.py --model z-lab/Qwen3-8B-DFlash-b16 --output ./Qwen3-8B-DFlash-mlx
    python convert_drafter.py --model z-lab/Qwen3.5-9B-DFlash --output ./Qwen3.5-9B-DFlash-mlx
"""

import argparse
from pathlib import Path
from dflash_mlx.convert import convert_dflash_to_mlx


# Advisory list of drafter checkpoints known to convert cleanly. Used only to
# emit a warning in main() — unknown model IDs are still attempted, so this
# list does not gate conversion.
SUPPORTED_DRAFTERS = [
    "z-lab/Qwen3-4B-DFlash-b16",
    "z-lab/Qwen3-8B-DFlash-b16",
    "z-lab/Qwen3.5-9B-DFlash",
    "z-lab/Qwen3.5-27B-DFlash",
    "z-lab/Qwen3.6-27B-DFlash",
    "z-lab/Qwen3.6-35B-A3B-DFlash",
    "z-lab/Qwen3-Coder-30B-A3B-DFlash",
    "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat",
    "z-lab/gemma-4-31B-it-DFlash",
    "z-lab/gemma-4-26B-A4B-it-DFlash",
    "z-lab/gpt-oss-20b-DFlash",
    "z-lab/Kimi-K2.5-DFlash",
    "z-lab/MiniMax-M2.5-DFlash",
]


def main():
    """Parse CLI arguments and convert a DFlash drafter checkpoint to MLX.

    Warns (but proceeds) when the model ID is not in SUPPORTED_DRAFTERS,
    then delegates the actual conversion to convert_dflash_to_mlx and
    prints usage instructions for the converted model.
    """
    parser = argparse.ArgumentParser(description="Convert DFlash drafter to MLX")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Hugging Face model ID of the DFlash drafter",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output directory for converted MLX model",
    )
    # BUG FIX: the original used action="store_true" together with default=True,
    # which made the flag a no-op — it was always True and could never be
    # disabled. BooleanOptionalAction keeps the default and the positive flag
    # (backward compatible) while also generating --no-trust-remote-code.
    parser.add_argument(
        "--trust-remote-code",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Trust remote code for custom modeling",
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="Hugging Face API token (for gated/private models)",
    )

    args = parser.parse_args()

    # Unknown IDs are not fatal: warn and attempt the conversion anyway.
    if args.model not in SUPPORTED_DRAFTERS:
        print(f"Warning: {args.model} not in known supported list. Attempting conversion anyway.")
        print("Known models:")
        for m in SUPPORTED_DRAFTERS:
            print(f"  - {m}")

    print(f"Converting {args.model} to MLX format...")
    print(f"Output: {args.output}")

    output_path = convert_dflash_to_mlx(
        pytorch_model_id=args.model,
        output_path=args.output,
        trust_remote_code=args.trust_remote_code,
        token=args.token,
    )

    print("\n✓ Conversion complete!")
    print(f"  Model saved to: {output_path}")
    print("\nTo use:")
    print("  from dflash_mlx.convert import load_mlx_dflash")
    print(f"  model, config = load_mlx_dflash('{args.output}')")


if __name__ == "__main__":
    main()