# Source: dflash-mlx-universal/examples/convert_drafter.py
# Uploaded by tritesh via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 0433390 (verified).
"""
Convert a PyTorch DFlash drafter from Hugging Face to MLX format.
Usage:
python convert_drafter.py --model z-lab/Qwen3-4B-DFlash-b16 --output ./Qwen3-4B-DFlash-mlx
python convert_drafter.py --model z-lab/Qwen3-8B-DFlash-b16 --output ./Qwen3-8B-DFlash-mlx
python convert_drafter.py --model z-lab/Qwen3.5-9B-DFlash --output ./Qwen3.5-9B-DFlash-mlx
"""
import argparse
import sys
from pathlib import Path

from dflash_mlx.convert import convert_dflash_to_mlx
# Hugging Face model IDs of DFlash drafters this script recognizes.
# An ID outside this list only triggers a warning; conversion is still attempted.
SUPPORTED_DRAFTERS = [
    # Qwen family
    "z-lab/Qwen3-4B-DFlash-b16",
    "z-lab/Qwen3-8B-DFlash-b16",
    "z-lab/Qwen3.5-9B-DFlash",
    "z-lab/Qwen3.5-27B-DFlash",
    "z-lab/Qwen3.6-27B-DFlash",
    "z-lab/Qwen3.6-35B-A3B-DFlash",
    "z-lab/Qwen3-Coder-30B-A3B-DFlash",
    # Llama family
    "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat",
    # Gemma family
    "z-lab/gemma-4-31B-it-DFlash",
    "z-lab/gemma-4-26B-A4B-it-DFlash",
    # Other model families
    "z-lab/gpt-oss-20b-DFlash",
    "z-lab/Kimi-K2.5-DFlash",
    "z-lab/MiniMax-M2.5-DFlash",
]
def _build_parser():
    """Return the argument parser for this script's command-line interface."""
    parser = argparse.ArgumentParser(description="Convert DFlash drafter to MLX")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Hugging Face model ID of the DFlash drafter",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output directory for converted MLX model",
    )
    # `action="store_true"` combined with `default=True` yields a flag that can
    # never be switched off, so remote code would always be trusted. Keep the
    # existing flag and its enabled-by-default behavior for compatibility, but
    # add an explicit opt-out that writes to the same destination.
    parser.add_argument(
        "--trust-remote-code",
        dest="trust_remote_code",
        action="store_true",
        help="Trust remote code for custom modeling (default: enabled)",
    )
    parser.add_argument(
        "--no-trust-remote-code",
        dest="trust_remote_code",
        action="store_false",
        help="Do not trust remote code from the model repository",
    )
    parser.set_defaults(trust_remote_code=True)
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="Hugging Face API token (for gated/private models)",
    )
    return parser


def main(argv=None):
    """Parse command-line options and convert a DFlash drafter to MLX format.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``.

    Side effects:
        Calls ``convert_dflash_to_mlx`` with the parsed options, targeting the
        ``--output`` directory. Prints progress and a usage snippet to stdout.
        A warning for model IDs not in ``SUPPORTED_DRAFTERS`` is printed to
        stderr. Conversion is attempted either way.
    """
    args = _build_parser().parse_args(argv)

    if args.model not in SUPPORTED_DRAFTERS:
        # Diagnostics go to stderr so they do not mix with normal progress output.
        print(
            f"Warning: {args.model} not in known supported list. "
            "Attempting conversion anyway.",
            file=sys.stderr,
        )
        print("Known models:", file=sys.stderr)
        for model_id in SUPPORTED_DRAFTERS:
            print(f" - {model_id}", file=sys.stderr)

    print(f"Converting {args.model} to MLX format...")
    print(f"Output: {args.output}")
    output_path = convert_dflash_to_mlx(
        pytorch_model_id=args.model,
        output_path=args.output,
        trust_remote_code=args.trust_remote_code,
        token=args.token,
    )

    print("\n✓ Conversion complete!")
    print(f" Model saved to: {output_path}")
    print("\nTo use:")
    print(" from dflash_mlx.convert import load_mlx_dflash")
    # `!r` renders the path as a valid Python string literal (quotes and
    # backslashes escaped), so the printed snippet can be pasted as-is.
    print(f" model, config = load_mlx_dflash({args.output!r})")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()