| """ |
| Convert PyTorch DFlash drafter models to MLX format. |
| |
| Handles weight conversion from PyTorch safetensors to MLX arrays, |
| compatible with any z-lab DFlash drafter. |
| """ |
|
|
| import json |
| import os |
| from pathlib import Path |
| from typing import Optional, Dict |
| import mlx.core as mx |
| from transformers import AutoConfig, AutoModel |
| from huggingface_hub import hf_hub_download, snapshot_download |
|
|
|
|
| def _convert_key(key: str) -> str: |
| """Convert PyTorch parameter names to MLX format.""" |
| |
| key = key.replace("model.", "") |
| |
| replacements = { |
| "embed_tokens": "embed_tokens", |
| "layers.": "layers.", |
| "self_attn.": "self_attn.", |
| "mlp.": "mlp.", |
| "input_layernorm": "input_layernorm", |
| "post_attention_layernorm": "post_attention_layernorm", |
| "norm": "norm", |
| "lm_head": "lm_head", |
| "q_proj": "q_proj", |
| "k_proj": "k_proj", |
| "v_proj": "v_proj", |
| "o_proj": "o_proj", |
| "gate_proj": "gate_proj", |
| "up_proj": "up_proj", |
| "down_proj": "down_proj", |
| "fc": "fc", |
| "hidden_norm": "hidden_norm", |
| "q_norm": "q_norm", |
| "k_norm": "k_norm", |
| "weight": "weight", |
| } |
| return key |
|
|
|
|
def _transpose_if_needed(key: str, tensor) -> mx.array:
    """Transpose linear-layer weights from PyTorch to MLX format.

    2-D weights whose names contain a projection / fc / lm_head / embed
    tag are transposed; every other tensor is converted unchanged.

    NOTE(review): mlx.nn.Linear stores its weight as (out_features,
    in_features) — the same layout as torch.nn.Linear — and embedding
    tables are looked up row-wise, so it is worth confirming that the
    consuming DFlashDraftModel really expects transposed weights here.
    """
    needs_transpose = (
        any(tag in key for tag in ("proj", "fc", "lm_head", "embed"))
        and len(tensor.shape) == 2
    )
    if needs_transpose:
        return mx.array(tensor.T)
    return mx.array(tensor)
|
|
|
|
def convert_dflash_to_mlx(
    pytorch_model_id: str,
    output_path: str,
    trust_remote_code: bool = True,
    token: Optional[str] = None,
) -> str:
    """Convert a PyTorch DFlash drafter to MLX format.

    Args:
        pytorch_model_id: Hugging Face model ID (e.g., "z-lab/Qwen3-4B-DFlash-b16")
        output_path: Local directory to save converted model
        trust_remote_code: Whether to trust custom modeling code
        token: HF API token for gated/private models

    Returns:
        Path to the converted model directory

    Raises:
        FileNotFoundError: If the downloaded snapshot contains no weights file.
    """
    out_dir = Path(output_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Download the full repo snapshot (weights + config + any custom code).
    print(f"[Convert] Downloading {pytorch_model_id}...")
    repo_path = Path(
        snapshot_download(
            repo_id=pytorch_model_id,
            token=token,
            ignore_patterns=["*.md", "*.png", "*.jpg"],
        )
    )

    print("[Convert] Loading PyTorch model for config extraction...")
    config = AutoConfig.from_pretrained(
        repo_path,
        trust_remote_code=trust_remote_code,
    )

    # Extract only the fields the MLX drafter needs. The fallback values
    # match the z-lab Qwen3-4B drafter and are used when the upstream
    # config omits a field.
    dflash_config = {
        "vocab_size": getattr(config, "vocab_size", 151936),
        "hidden_size": getattr(config, "hidden_size", 1024),
        "num_hidden_layers": getattr(config, "num_hidden_layers", 5),
        "num_attention_heads": getattr(config, "num_attention_heads", 16),
        "num_key_value_heads": getattr(config, "num_key_value_heads", 4),
        "intermediate_size": getattr(config, "intermediate_size", 2816),
        "max_position_embeddings": getattr(config, "max_position_embeddings", 32768),
        "rms_norm_eps": getattr(config, "rms_norm_eps", 1e-6),
        "block_size": getattr(config, "block_size", 16),
        "rope_base": getattr(config, "rope_theta", 10000.0),
    }

    print("[Convert] Loading weights from safetensors...")
    pt_weights = _load_pytorch_weights(repo_path)

    print(f"[Convert] Converting {len(pt_weights)} parameters...")
    mlx_weights = {}
    for key, tensor in pt_weights.items():
        mlx_weights[_convert_key(key)] = _transpose_if_needed(key, tensor)

    weights_path = out_dir / "weights.safetensors"
    print(f"[Convert] Saving to {weights_path}...")
    mx.save_safetensors(str(weights_path), mlx_weights)

    with open(out_dir / "config.json", "w") as f:
        json.dump(dflash_config, f, indent=2)

    # Record provenance so loaders can pair the drafter with its target model.
    target_info = {
        "source_model": pytorch_model_id,
        "target_model": _infer_target_model(pytorch_model_id),
    }
    with open(out_dir / "model_info.json", "w") as f:
        json.dump(target_info, f, indent=2)

    print(f"[Convert] Done! Model saved to {out_dir}")
    return str(out_dir)


def _load_pytorch_weights(repo_path: Path) -> dict:
    """Load checkpoint tensors from a downloaded snapshot directory.

    Prefers safetensors. A sharded checkpoint (model-00001-of-0000N
    style) is merged into a single dict — the previous implementation
    loaded only the first glob hit, silently dropping all other shards.
    Falls back to ``pytorch_model.bin`` when the safetensors package is
    unavailable.

    Raises:
        FileNotFoundError: If no ``*.safetensors`` file exists.
    """
    try:
        from safetensors.torch import load_file
    except ImportError:
        # Last resort: plain torch checkpoint.
        import torch
        return torch.load(str(repo_path / "pytorch_model.bin"), map_location="cpu")

    # sorted() makes shard order deterministic across filesystems.
    shard_files = sorted(repo_path.glob("*.safetensors"))
    if not shard_files:
        raise FileNotFoundError("No safetensors file found")

    weights: dict = {}
    for shard in shard_files:
        weights.update(load_file(str(shard)))
    return weights
|
|
|
|
| def _infer_target_model(dflash_model_id: str) -> str: |
| """Infer the target model from DFlash drafter ID.""" |
| |
| mapping = { |
| "Qwen3-4B-DFlash": "Qwen/Qwen3-4B", |
| "Qwen3-8B-DFlash": "Qwen/Qwen3-8B", |
| "Qwen3.5-9B-DFlash": "Qwen/Qwen3.5-9B", |
| "Qwen3.5-27B-DFlash": "Qwen/Qwen3.5-27B", |
| "Qwen3.6-27B-DFlash": "Qwen/Qwen3.6-27B", |
| "Qwen3.6-35B-A3B-DFlash": "Qwen/Qwen3.6-35B-A3B", |
| "Qwen3-Coder-30B-A3B-DFlash": "Qwen/Qwen3-Coder-30B-A3B", |
| "Qwen3.5-122B-A10B-DFlash": "Qwen/Qwen3.5-122B-A10B", |
| "LLaMA3.1-8B-Instruct-DFlash": "meta-llama/Llama-3.1-8B-Instruct", |
| "gemma-4-31B-it-DFlash": "google/gemma-4-31b-it", |
| "gpt-oss-20b-DFlash": "openai/gpt-oss-20b", |
| "Kimi-K2.5-DFlash": "moonshotai/Kimi-K2.5", |
| "MiniMax-M2.5-DFlash": "MiniMax/MiniMax-M2.5", |
| } |
| |
| for key, target in mapping.items(): |
| if key in dflash_model_id: |
| return target |
| |
| |
| if "Qwen3.6" in dflash_model_id: |
| return "Qwen/Qwen3.6-27B" |
| elif "Qwen3.5" in dflash_model_id: |
| return "Qwen/Qwen3.5-9B" |
| elif "Qwen3" in dflash_model_id: |
| return "Qwen/Qwen3-4B" |
| elif "LLaMA" in dflash_model_id or "Llama" in dflash_model_id: |
| return "meta-llama/Llama-3.1-8B-Instruct" |
| elif "gemma" in dflash_model_id: |
| return "google/gemma-4-31b-it" |
| |
| return "unknown" |
|
|
|
|
def load_mlx_dflash(
    model_path: str,
) -> tuple:
    """Load a converted MLX DFlash model.

    Args:
        model_path: Path to converted MLX model directory

    Returns:
        Tuple of (model, config)
    """
    from .model import DFlashDraftModel

    base = Path(model_path)

    # Config written alongside the weights by convert_dflash_to_mlx().
    with open(base / "config.json", "r") as cfg_file:
        config = json.load(cfg_file)

    weights = mx.load(str(base / "weights.safetensors"))

    model = DFlashDraftModel(
        vocab_size=config["vocab_size"],
        hidden_size=config["hidden_size"],
        num_layers=config["num_hidden_layers"],
        num_heads=config["num_attention_heads"],
        num_kv_heads=config["num_key_value_heads"],
        intermediate_size=config["intermediate_size"],
        max_seq_len=config["max_position_embeddings"],
        block_size=config.get("block_size", 16),
        rope_base=config.get("rope_base", 10000.0),
    )

    # NOTE(review): mx.load returns a flat dict keyed by dotted paths;
    # confirm DFlashDraftModel.update accepts that layout (mlx nn.Modules
    # typically need tree_unflatten first).
    model.update(weights)

    return model, config
|
|