"""
Convert PyTorch DFlash drafter models to MLX format.

Handles weight conversion from PyTorch checkpoints (safetensors, or a
``pytorch_model.bin`` fallback) to MLX arrays, compatible with any z-lab
DFlash drafter.
"""

import json
import os
from pathlib import Path
from typing import Optional, Dict

import mlx.core as mx
from transformers import AutoConfig, AutoModel
from huggingface_hub import hf_hub_download, snapshot_download


def _convert_key(key: str) -> str:
    """Convert a PyTorch parameter name to the MLX model's naming.

    Only the leading ``"model."`` prefix is stripped. Every other component
    (``layers.``, ``self_attn.``, ``q_proj`` ...) keeps its PyTorch name.

    Args:
        key: Parameter name from the PyTorch state dict.

    Returns:
        The parameter name written to the MLX weights file.
    """
    prefix = "model."
    # Strip only the leading prefix: str.replace() would also delete a
    # "model." that happens to occur in the middle of a parameter name.
    if key.startswith(prefix):
        return key[len(prefix):]
    return key


def _transpose_if_needed(key: str, tensor) -> mx.array:
    """Convert one weight to an ``mx.array``, transposing 2-D matrix weights.

    Any 2-D tensor whose name contains ``"proj"``, ``"fc"``, ``"lm_head"`` or
    ``"embed"`` is transposed (PyTorch ``[out, in]`` -> ``[in, out]``);
    everything else is copied unchanged.

    NOTE(review): ``mlx.nn.Linear`` and ``mlx.nn.Embedding`` store their
    weights in the same layout as PyTorch, so this transpose (in particular of
    the embedding table) is only correct if ``DFlashDraftModel`` multiplies by
    raw ``[in, out]`` matrices -- confirm against ``.model``.

    Args:
        key: Original PyTorch parameter name (decides whether to transpose).
        tensor: The weight, as an ``mx.array`` or anything ``mx.array`` accepts.

    Returns:
        The weight as an ``mx.array``.
    """
    is_matrix_weight = any(
        tag in key for tag in ("proj", "fc", "lm_head", "embed")
    )
    if is_matrix_weight and len(tensor.shape) == 2:
        return mx.array(tensor.T)
    return mx.array(tensor)


def _load_torch_bin(bin_file: Path) -> Dict[str, mx.array]:
    """Load a ``pytorch_model.bin`` state dict and convert it to ``mx.array``."""
    import torch

    # weights_only=True stops torch.load from unpickling arbitrary objects:
    # the checkpoint comes from a user-supplied Hub repo and is untrusted.
    state_dict = torch.load(str(bin_file), map_location="cpu", weights_only=True)

    converted = {}
    for name, tensor in state_dict.items():
        tensor = tensor.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            # NumPy has no bfloat16: go through float32, then cast back in MLX.
            converted[name] = mx.array(tensor.float().numpy()).astype(mx.bfloat16)
        else:
            converted[name] = mx.array(tensor.numpy())
    return converted


def _load_checkpoint_weights(repo_path: Path) -> Dict[str, mx.array]:
    """Load every checkpoint tensor found in ``repo_path`` as ``mx.array``.

    Prefers ``model.safetensors``. Otherwise it merges *all* ``*.safetensors``
    files, because sharded checkpoints split the weights across several files.
    It falls back to ``pytorch_model.bin`` only when no safetensors file
    exists.

    Raises:
        FileNotFoundError: If neither safetensors nor ``pytorch_model.bin``
            is present.
    """
    single_file = repo_path / "model.safetensors"
    if single_file.exists():
        shard_files = [single_file]
    else:
        # sorted() makes the merge order deterministic across filesystems.
        shard_files = sorted(repo_path.glob("*.safetensors"))

    if shard_files:
        weights = {}
        for shard in shard_files:
            # mx.load reads safetensors natively, including bfloat16 tensors
            # that cannot be converted through NumPy.
            weights.update(mx.load(str(shard)))
        return weights

    bin_file = repo_path / "pytorch_model.bin"
    if not bin_file.exists():
        raise FileNotFoundError(
            f"No safetensors or pytorch_model.bin file found in {repo_path}"
        )
    return _load_torch_bin(bin_file)


def convert_dflash_to_mlx(
    pytorch_model_id: str,
    output_path: str,
    trust_remote_code: bool = True,
    token: Optional[str] = None,
) -> str:
    """Convert a PyTorch DFlash drafter to MLX format.

    Writes ``weights.safetensors``, ``config.json`` and ``model_info.json``
    into ``output_path``.

    Args:
        pytorch_model_id: Hugging Face model ID (e.g., "z-lab/Qwen3-4B-DFlash-b16")
        output_path: Local directory to save converted model
        trust_remote_code: Whether to trust custom modeling code
        token: HF API token for gated/private models

    Returns:
        Path to the converted model directory

    Raises:
        FileNotFoundError: If the downloaded repo contains no supported
            weights file.
    """
    output_path = Path(output_path)
    output_path.mkdir(parents=True, exist_ok=True)

    print(f"[Convert] Downloading {pytorch_model_id}...")
    repo_path = Path(
        snapshot_download(
            repo_id=pytorch_model_id,
            token=token,
            ignore_patterns=["*.md", "*.png", "*.jpg"],
        )
    )

    print("[Convert] Loading PyTorch model for config extraction...")
    # NOTE(review): with trust_remote_code=True (the default) any Python code
    # shipped in the repo is executed -- only use it with trusted model IDs.
    config = AutoConfig.from_pretrained(
        repo_path,
        trust_remote_code=trust_remote_code,
    )

    # DFlash-specific config; the fallbacks are used when the HF config lacks
    # the attribute.
    dflash_config = {
        "vocab_size": getattr(config, "vocab_size", 151936),
        "hidden_size": getattr(config, "hidden_size", 1024),
        "num_hidden_layers": getattr(config, "num_hidden_layers", 5),
        "num_attention_heads": getattr(config, "num_attention_heads", 16),
        "num_key_value_heads": getattr(config, "num_key_value_heads", 4),
        "intermediate_size": getattr(config, "intermediate_size", 2816),
        "max_position_embeddings": getattr(config, "max_position_embeddings", 32768),
        "rms_norm_eps": getattr(config, "rms_norm_eps", 1e-6),
        "block_size": getattr(config, "block_size", 16),
        "rope_base": getattr(config, "rope_theta", 10000.0),
    }

    print("[Convert] Loading weights from safetensors...")
    pt_weights = _load_checkpoint_weights(repo_path)

    print(f"[Convert] Converting {len(pt_weights)} parameters...")
    mlx_weights = {
        _convert_key(key): _transpose_if_needed(key, tensor)
        for key, tensor in pt_weights.items()
    }

    weights_path = output_path / "weights.safetensors"
    print(f"[Convert] Saving to {weights_path}...")
    mx.save_safetensors(str(weights_path), mlx_weights)

    with open(output_path / "config.json", "w", encoding="utf-8") as f:
        json.dump(dflash_config, f, indent=2)

    target_info = {
        "source_model": pytorch_model_id,
        "target_model": _infer_target_model(pytorch_model_id),
    }
    with open(output_path / "model_info.json", "w", encoding="utf-8") as f:
        json.dump(target_info, f, indent=2)

    print(f"[Convert] Done! Model saved to {output_path}")
    return str(output_path)


# Known drafter-name fragments -> target model IDs.  Matched as substrings
# of the drafter ID, in insertion order.
_DRAFTER_TO_TARGET = {
    "Qwen3-4B-DFlash": "Qwen/Qwen3-4B",
    "Qwen3-8B-DFlash": "Qwen/Qwen3-8B",
    "Qwen3.5-9B-DFlash": "Qwen/Qwen3.5-9B",
    "Qwen3.5-27B-DFlash": "Qwen/Qwen3.5-27B",
    "Qwen3.6-27B-DFlash": "Qwen/Qwen3.6-27B",
    "Qwen3.6-35B-A3B-DFlash": "Qwen/Qwen3.6-35B-A3B",
    "Qwen3-Coder-30B-A3B-DFlash": "Qwen/Qwen3-Coder-30B-A3B",
    "Qwen3.5-122B-A10B-DFlash": "Qwen/Qwen3.5-122B-A10B",
    "LLaMA3.1-8B-Instruct-DFlash": "meta-llama/Llama-3.1-8B-Instruct",
    "gemma-4-31B-it-DFlash": "google/gemma-4-31b-it",
    "gpt-oss-20b-DFlash": "openai/gpt-oss-20b",
    "Kimi-K2.5-DFlash": "moonshotai/Kimi-K2.5",
    "MiniMax-M2.5-DFlash": "MiniMax/MiniMax-M2.5",
}


def _infer_target_model(dflash_model_id: str) -> str:
    """Infer the target model ID from a DFlash drafter ID.

    Tries the exact fragments in ``_DRAFTER_TO_TARGET`` first, then falls back
    to a default model per model family.  Returns ``"unknown"`` when nothing
    matches.
    """
    for fragment, target in _DRAFTER_TO_TARGET.items():
        if fragment in dflash_model_id:
            return target

    # Family-level fallbacks; "Qwen3.6"/"Qwen3.5" must be tested before the
    # broader "Qwen3" because they contain it.
    if "Qwen3.6" in dflash_model_id:
        return "Qwen/Qwen3.6-27B"
    if "Qwen3.5" in dflash_model_id:
        return "Qwen/Qwen3.5-9B"
    if "Qwen3" in dflash_model_id:
        return "Qwen/Qwen3-4B"
    if "LLaMA" in dflash_model_id or "Llama" in dflash_model_id:
        return "meta-llama/Llama-3.1-8B-Instruct"
    if "gemma" in dflash_model_id:
        return "google/gemma-4-31b-it"
    return "unknown"


def load_mlx_dflash(
    model_path: str,
) -> tuple:
    """Load a converted MLX DFlash model.

    Args:
        model_path: Path to converted MLX model directory

    Returns:
        Tuple of (model, config)
    """
    from mlx.utils import tree_unflatten

    from .model import DFlashDraftModel

    model_path = Path(model_path)

    with open(model_path / "config.json", "r", encoding="utf-8") as f:
        config = json.load(f)

    weights = mx.load(str(model_path / "weights.safetensors"))

    model = DFlashDraftModel(
        vocab_size=config["vocab_size"],
        hidden_size=config["hidden_size"],
        num_layers=config["num_hidden_layers"],
        num_heads=config["num_attention_heads"],
        num_kv_heads=config["num_key_value_heads"],
        intermediate_size=config["intermediate_size"],
        max_seq_len=config["max_position_embeddings"],
        block_size=config.get("block_size", 16),
        rope_base=config.get("rope_base", 10000.0),
    )

    # mx.load returns a flat {"layers.0.self_attn.q_proj.weight": array}
    # mapping, while Module.update expects the nested parameter tree, so
    # the dotted names must be unflattened first.
    model.update(tree_unflatten(list(weights.items())))
    return model, config