"""
Convert a PyTorch DFlash drafter from Hugging Face to MLX format.

Usage:
    python convert_drafter.py --model z-lab/Qwen3-4B-DFlash-b16 --output ./Qwen3-4B-DFlash-mlx
    python convert_drafter.py --model z-lab/Qwen3-8B-DFlash-b16 --output ./Qwen3-8B-DFlash-mlx
    python convert_drafter.py --model z-lab/Qwen3.5-9B-DFlash --output ./Qwen3.5-9B-DFlash-mlx

Remote code is trusted by default; pass ``--no-trust-remote-code`` to disable it.
"""

import argparse
from pathlib import Path
from typing import Optional, Sequence

from dflash_mlx.convert import convert_dflash_to_mlx

# Drafter repositories known to convert with this script. Other model IDs are
# still attempted; this list only drives a warning and the printed hint.
SUPPORTED_DRAFTERS = [
    "z-lab/Qwen3-4B-DFlash-b16",
    "z-lab/Qwen3-8B-DFlash-b16",
    "z-lab/Qwen3.5-9B-DFlash",
    "z-lab/Qwen3.5-27B-DFlash",
    "z-lab/Qwen3.6-27B-DFlash",
    "z-lab/Qwen3.6-35B-A3B-DFlash",
    "z-lab/Qwen3-Coder-30B-A3B-DFlash",
    "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat",
    "z-lab/gemma-4-31B-it-DFlash",
    "z-lab/gemma-4-26B-A4B-it-DFlash",
    "z-lab/gpt-oss-20b-DFlash",
    "z-lab/Kimi-K2.5-DFlash",
    "z-lab/MiniMax-M2.5-DFlash",
]


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the drafter converter."""
    parser = argparse.ArgumentParser(description="Convert DFlash drafter to MLX")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Hugging Face model ID of the DFlash drafter",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output directory for converted MLX model",
    )

    # Previously this was `store_true` with `default=True`, so the flag could
    # never be switched off. Keep the same default for backward compatibility,
    # but provide an explicit opt-out. The value is forwarded unchanged to
    # convert_dflash_to_mlx (presumably to the HF loader's trust_remote_code,
    # i.e. it allows executing code from the model repo — confirm).
    remote_code = parser.add_mutually_exclusive_group()
    remote_code.add_argument(
        "--trust-remote-code",
        dest="trust_remote_code",
        action="store_true",
        help="Trust remote code for custom modeling (default: enabled)",
    )
    remote_code.add_argument(
        "--no-trust-remote-code",
        dest="trust_remote_code",
        action="store_false",
        help="Do not trust remote code from the model repository",
    )
    parser.set_defaults(trust_remote_code=True)

    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="Hugging Face API token (for gated/private models)",
    )
    return parser


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Parse CLI arguments, convert the drafter to MLX, and print a usage hint.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the original command-line behaviour.
    """
    args = _build_parser().parse_args(argv)

    # Unknown drafters are not rejected, only flagged.
    if args.model not in SUPPORTED_DRAFTERS:
        print(
            f"Warning: {args.model} not in known supported list. "
            "Attempting conversion anyway."
        )
        print("Known models:")
        for m in SUPPORTED_DRAFTERS:
            print(f" - {m}")

    print(f"Converting {args.model} to MLX format...")
    print(f"Output: {args.output}")

    output_path = convert_dflash_to_mlx(
        pytorch_model_id=args.model,
        output_path=args.output,
        trust_remote_code=args.trust_remote_code,
        token=args.token,
    )

    print("\n✓ Conversion complete!")
    print(f" Model saved to: {output_path}")
    print("\nTo use:")
    print(" from dflash_mlx.convert import load_mlx_dflash")
    # `!r` renders a valid Python string literal even when the path contains
    # quotes or backslashes, so the printed snippet can be pasted as-is.
    print(f" model, config = load_mlx_dflash({args.output!r})")


if __name__ == "__main__":
    main()