"""
Convert a PyTorch DFlash drafter from Hugging Face to MLX format.

Usage:
    python convert_drafter.py --model z-lab/Qwen3-4B-DFlash-b16 --output ./Qwen3-4B-DFlash-mlx
    python convert_drafter.py --model z-lab/Qwen3-8B-DFlash-b16 --output ./Qwen3-8B-DFlash-mlx
    python convert_drafter.py --model z-lab/Qwen3.5-9B-DFlash --output ./Qwen3.5-9B-DFlash-mlx
"""
|
|
| import argparse |
| from pathlib import Path |
| from dflash_mlx.convert import convert_dflash_to_mlx |
|
|
|
|
# Hugging Face model IDs this script recognises as DFlash drafters.
# The list is advisory: main() only prints a warning for IDs that are not
# listed here and still attempts the conversion. Order is preserved because
# the list is printed verbatim in that warning.
SUPPORTED_DRAFTERS = [
    # Qwen family
    "z-lab/Qwen3-4B-DFlash-b16",
    "z-lab/Qwen3-8B-DFlash-b16",
    "z-lab/Qwen3.5-9B-DFlash",
    "z-lab/Qwen3.5-27B-DFlash",
    "z-lab/Qwen3.6-27B-DFlash",
    "z-lab/Qwen3.6-35B-A3B-DFlash",
    "z-lab/Qwen3-Coder-30B-A3B-DFlash",
    # LLaMA family
    "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat",
    # Gemma family
    "z-lab/gemma-4-31B-it-DFlash",
    "z-lab/gemma-4-26B-A4B-it-DFlash",
    # Other drafters
    "z-lab/gpt-oss-20b-DFlash",
    "z-lab/Kimi-K2.5-DFlash",
    "z-lab/MiniMax-M2.5-DFlash",
]
|
|
|
|
def _build_parser():
    """Build the command-line parser for the drafter conversion script."""
    parser = argparse.ArgumentParser(description="Convert DFlash drafter to MLX")
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Hugging Face model ID of the DFlash drafter",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output directory for converted MLX model",
    )
    # The two switches below share one destination. Previously only the
    # store_true flag existed with default=True, so the value could never
    # become False; the paired --no-... switch makes opting out possible
    # while keeping the default (True) and the existing flag unchanged.
    parser.add_argument(
        "--trust-remote-code",
        dest="trust_remote_code",
        action="store_true",
        default=True,
        help="Trust remote code for custom modeling (enabled by default)",
    )
    parser.add_argument(
        "--no-trust-remote-code",
        dest="trust_remote_code",
        action="store_false",
        default=True,
        help="Do not trust remote code when loading the drafter",
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="Hugging Face API token (for gated/private models)",
    )
    return parser


def main(argv=None) -> None:
    """Parse command-line options and convert a DFlash drafter to MLX.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behaviour.

    Raises:
        SystemExit: Raised by argparse for ``--help`` or invalid arguments.

    A model ID missing from ``SUPPORTED_DRAFTERS`` only produces a warning;
    the conversion is still attempted.
    """
    args = _build_parser().parse_args(argv)

    if args.model not in SUPPORTED_DRAFTERS:
        print(f"Warning: {args.model} not in known supported list. Attempting conversion anyway.")
        print("Known models:")
        for m in SUPPORTED_DRAFTERS:
            print(f" - {m}")

    print(f"Converting {args.model} to MLX format...")
    print(f"Output: {args.output}")

    output_path = convert_dflash_to_mlx(
        pytorch_model_id=args.model,
        output_path=args.output,
        trust_remote_code=args.trust_remote_code,
        token=args.token,
    )

    print("\n✓ Conversion complete!")
    print(f" Model saved to: {output_path}")
    print("\nTo use:")
    print(" from dflash_mlx.convert import load_mlx_dflash")
    # !r emits a valid Python string literal even when the path contains
    # quotes or backslashes (e.g. Windows paths), so the hint can be pasted
    # as-is.
    print(f" model, config = load_mlx_dflash({args.output!r})")
|
|
|
|
# Run the converter only when this file is executed as a script; importing
# the module must not trigger a conversion.
if __name__ == "__main__":
    main()
|
|