"""
Convert PyTorch DFlash drafter models to MLX format.
Handles weight conversion from PyTorch safetensors to MLX arrays,
compatible with any z-lab DFlash drafter.
"""
import json
from pathlib import Path
from typing import Optional

import mlx.core as mx
from huggingface_hub import snapshot_download
from transformers import AutoConfig
def _convert_key(key: str) -> str:
    """Convert PyTorch parameter names to MLX format."""
    # The drafter's parameter names (embed_tokens, layers.*, self_attn.*,
    # q_norm/k_norm, mlp.*, input_layernorm, post_attention_layernorm,
    # norm, fc, hidden_norm, lm_head) already match the MLX module
    # attributes, so only the PyTorch-specific "model." prefix is stripped.
    return key.replace("model.", "")
def _transpose_if_needed(key: str, tensor) -> mx.array:
    """Transpose linear layer weights from PyTorch to MLX format."""
    # PyTorch stores these weights as [out_features, in_features]; the MLX
    # DFlash model consumes [in_features, out_features], so 2-D projection,
    # fc, lm_head, and embedding matrices are transposed. Everything else
    # (norm weights, biases) is converted as-is.
    if "proj" in key or "fc" in key or "lm_head" in key or "embed" in key:
        if len(tensor.shape) == 2:
            return mx.array(tensor.T)
    return mx.array(tensor)
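
# For example, a q_proj weight saved by PyTorch with shape
# (out_features, in_features) comes back with shape (in_features, out_features),
# while 1-D tensors such as norm weights are only converted to mx.array.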
def convert_dflash_to_mlx(
pytorch_model_id: str,
output_path: str,
trust_remote_code: bool = True,
token: Optional[str] = None,
) -> str:
"""Convert a PyTorch DFlash drafter to MLX format.
Args:
pytorch_model_id: Hugging Face model ID (e.g., "z-lab/Qwen3-4B-DFlash-b16")
output_path: Local directory to save converted model
trust_remote_code: Whether to trust custom modeling code
token: HF API token for gated/private models
Returns:
Path to the converted model directory
"""
output_path = Path(output_path)
output_path.mkdir(parents=True, exist_ok=True)
print(f"[Convert] Downloading {pytorch_model_id}...")
# Download model files
repo_path = snapshot_download(
repo_id=pytorch_model_id,
token=token,
ignore_patterns=["*.md", "*.png", "*.jpg"],
)
repo_path = Path(repo_path)
    # Load the config only; the weights are read directly from disk below
    print("[Convert] Loading model config...")
    config = AutoConfig.from_pretrained(
        repo_path,
        trust_remote_code=trust_remote_code,
    )
# Extract DFlash-specific config
dflash_config = {
"vocab_size": getattr(config, "vocab_size", 151936),
"hidden_size": getattr(config, "hidden_size", 1024),
"num_hidden_layers": getattr(config, "num_hidden_layers", 5),
"num_attention_heads": getattr(config, "num_attention_heads", 16),
"num_key_value_heads": getattr(config, "num_key_value_heads", 4),
"intermediate_size": getattr(config, "intermediate_size", 2816),
"max_position_embeddings": getattr(config, "max_position_embeddings", 32768),
"rms_norm_eps": getattr(config, "rms_norm_eps", 1e-6),
"block_size": getattr(config, "block_size", 16),
"rope_base": getattr(config, "rope_theta", 10000.0),
}
    # Load weights from safetensors (fall back to pytorch_model.bin)
    print("[Convert] Loading weights from safetensors...")
    try:
        from safetensors.torch import load_file
        safetensors_files = sorted(repo_path.glob("*.safetensors"))
        if not safetensors_files:
            raise FileNotFoundError("No safetensors file found")
        # Merge every shard in case the checkpoint is split across files
        pt_weights = {}
        for shard in safetensors_files:
            pt_weights.update(load_file(str(shard)))
    except (ImportError, FileNotFoundError):
        # Fallback to torch load
        import torch
        weights_file = repo_path / "pytorch_model.bin"
        pt_weights = torch.load(str(weights_file), map_location="cpu")
# Convert weights
print(f"[Convert] Converting {len(pt_weights)} parameters...")
mlx_weights = {}
for key, tensor in pt_weights.items():
mlx_key = _convert_key(key)
mlx_weights[mlx_key] = _transpose_if_needed(key, tensor)
# Save MLX weights
weights_path = output_path / "weights.safetensors"
print(f"[Convert] Saving to {weights_path}...")
# Save using MLX
mx.save_safetensors(str(weights_path), mlx_weights)
# Save config
config_path = output_path / "config.json"
with open(config_path, "w") as f:
json.dump(dflash_config, f, indent=2)
# Save target model info
target_info = {
"source_model": pytorch_model_id,
"target_model": _infer_target_model(pytorch_model_id),
}
info_path = output_path / "model_info.json"
with open(info_path, "w") as f:
json.dump(target_info, f, indent=2)
print(f"[Convert] Done! Model saved to {output_path}")
return str(output_path)
def _infer_target_model(dflash_model_id: str) -> str:
"""Infer the target model from DFlash drafter ID."""
# Map drafter IDs to target models
mapping = {
"Qwen3-4B-DFlash": "Qwen/Qwen3-4B",
"Qwen3-8B-DFlash": "Qwen/Qwen3-8B",
"Qwen3.5-9B-DFlash": "Qwen/Qwen3.5-9B",
"Qwen3.5-27B-DFlash": "Qwen/Qwen3.5-27B",
"Qwen3.6-27B-DFlash": "Qwen/Qwen3.6-27B",
"Qwen3.6-35B-A3B-DFlash": "Qwen/Qwen3.6-35B-A3B",
"Qwen3-Coder-30B-A3B-DFlash": "Qwen/Qwen3-Coder-30B-A3B",
"Qwen3.5-122B-A10B-DFlash": "Qwen/Qwen3.5-122B-A10B",
"LLaMA3.1-8B-Instruct-DFlash": "meta-llama/Llama-3.1-8B-Instruct",
"gemma-4-31B-it-DFlash": "google/gemma-4-31b-it",
"gpt-oss-20b-DFlash": "openai/gpt-oss-20b",
"Kimi-K2.5-DFlash": "moonshotai/Kimi-K2.5",
"MiniMax-M2.5-DFlash": "MiniMax/MiniMax-M2.5",
}
for key, target in mapping.items():
if key in dflash_model_id:
return target
# Generic inference
if "Qwen3.6" in dflash_model_id:
return "Qwen/Qwen3.6-27B"
elif "Qwen3.5" in dflash_model_id:
return "Qwen/Qwen3.5-9B"
elif "Qwen3" in dflash_model_id:
return "Qwen/Qwen3-4B"
elif "LLaMA" in dflash_model_id or "Llama" in dflash_model_id:
return "meta-llama/Llama-3.1-8B-Instruct"
elif "gemma" in dflash_model_id:
return "google/gemma-4-31b-it"
return "unknown"
def load_mlx_dflash(
model_path: str,
) -> tuple:
"""Load a converted MLX DFlash model.
Args:
model_path: Path to converted MLX model directory
Returns:
Tuple of (model, config)
"""
from .model import DFlashDraftModel
model_path = Path(model_path)
# Load config
with open(model_path / "config.json", "r") as f:
config = json.load(f)
# Load weights
weights = mx.load(str(model_path / "weights.safetensors"))
# Build model
model = DFlashDraftModel(
vocab_size=config["vocab_size"],
hidden_size=config["hidden_size"],
num_layers=config["num_hidden_layers"],
num_heads=config["num_attention_heads"],
num_kv_heads=config["num_key_value_heads"],
intermediate_size=config["intermediate_size"],
max_seq_len=config["max_position_embeddings"],
block_size=config.get("block_size", 16),
rope_base=config.get("rope_base", 10000.0),
)
    # Load the flat, dot-keyed weights into the module tree (load_weights
    # unflattens the keys and checks them against the model parameters)
    model.load_weights(list(weights.items()))
return model, config
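

# Minimal command-line entry point. This is a sketch added for convenience;
# the argument names below are illustrative and not part of the original
# conversion API.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert a PyTorch DFlash drafter to MLX format."
    )
    parser.add_argument("model_id", help="Hugging Face ID of the DFlash drafter")
    parser.add_argument("output_path", help="Directory for the converted MLX model")
    parser.add_argument("--token", default=None, help="HF token for gated/private models")
    args = parser.parse_args()

    convert_dflash_to_mlx(args.model_id, args.output_path, token=args.token)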