tritesh commited on
Commit
c61f568
·
verified ·
1 Parent(s): 4e22dea

Upload dflash_mlx/convert.py

Browse files
Files changed (1) hide show
  1. dflash_mlx/convert.py +135 -34
dflash_mlx/convert.py CHANGED
@@ -3,19 +3,24 @@ Convert PyTorch DFlash drafter models to MLX format.
3
 
4
  Handles weight conversion from PyTorch safetensors to MLX arrays,
5
  compatible with any z-lab DFlash drafter.
 
 
6
  """
7
 
8
  import json
9
  import os
10
  from pathlib import Path
11
- from typing import Optional, Dict
12
  import mlx.core as mx
13
- from transformers import AutoConfig, AutoModel
14
  from huggingface_hub import hf_hub_download, snapshot_download
15
 
16
 
17
  def _convert_key(key: str) -> str:
18
- """Convert PyTorch parameter names to MLX format."""
 
 
 
19
  # Replace PyTorch-specific prefixes
20
  key = key.replace("model.", "")
21
  # Standardize naming
@@ -45,8 +50,10 @@ def _convert_key(key: str) -> str:
45
 
46
 
47
  def _transpose_if_needed(key: str, tensor) -> mx.array:
48
- """Transpose linear layer weights from PyTorch to MLX format."""
49
- # Linear layers in PyTorch are [out, in], MLX expects [in, out]
 
 
50
  if "proj" in key or "fc" in key or "lm_head" in key or "embed" in key:
51
  if len(tensor.shape) == 2:
52
  return mx.array(tensor.T)
@@ -79,12 +86,12 @@ def convert_dflash_to_mlx(
79
  repo_path = snapshot_download(
80
  repo_id=pytorch_model_id,
81
  token=token,
82
- ignore_patterns=["*.md", "*.png", "*.jpg"],
83
  )
84
  repo_path = Path(repo_path)
85
 
86
- # Load PyTorch model to extract config
87
- print("[Convert] Loading PyTorch model for config extraction...")
88
  config = AutoConfig.from_pretrained(
89
  repo_path,
90
  trust_remote_code=trust_remote_code,
@@ -103,26 +110,43 @@ def convert_dflash_to_mlx(
103
  "block_size": getattr(config, "block_size", 16),
104
  "rope_base": getattr(config, "rope_theta", 10000.0),
105
  }
 
 
 
 
 
 
106
 
107
  # Load weights from safetensors
108
  print("[Convert] Loading weights from safetensors...")
109
  try:
110
  from safetensors.torch import load_file
111
- weights_file = repo_path / "model.safetensors"
112
- if weights_file.exists():
113
- pt_weights = load_file(str(weights_file))
 
 
 
 
 
 
 
114
  else:
115
- # Try to find any .safetensors file
116
- safetensors_files = list(repo_path.glob("*.safetensors"))
117
- if safetensors_files:
118
- pt_weights = load_file(str(safetensors_files[0]))
 
119
  else:
120
- raise FileNotFoundError("No safetensors file found")
121
  except ImportError:
122
  # Fallback to torch load
123
  import torch
124
  weights_file = repo_path / "pytorch_model.bin"
125
- pt_weights = torch.load(str(weights_file), map_location="cpu")
 
 
 
126
 
127
  # Convert weights
128
  print(f"[Convert] Converting {len(pt_weights)} parameters...")
@@ -131,72 +155,114 @@ def convert_dflash_to_mlx(
131
  mlx_key = _convert_key(key)
132
  mlx_weights[mlx_key] = _transpose_if_needed(key, tensor)
133
 
134
- # Save MLX weights
135
- weights_path = output_path / "weights.safetensors"
136
- print(f"[Convert] Saving to {weights_path}...")
137
-
138
- # Save using MLX
139
- mx.save_safetensors(str(weights_path), mlx_weights)
 
 
 
 
 
 
 
 
 
 
140
 
141
  # Save config
142
  config_path = output_path / "config.json"
143
  with open(config_path, "w") as f:
144
  json.dump(dflash_config, f, indent=2)
145
 
146
- # Save target model info
147
  target_info = {
148
  "source_model": pytorch_model_id,
149
- "target_model": _infer_target_model(pytorch_model_id),
 
150
  }
151
  info_path = output_path / "model_info.json"
152
  with open(info_path, "w") as f:
153
  json.dump(target_info, f, indent=2)
154
 
155
  print(f"[Convert] Done! Model saved to {output_path}")
 
 
156
  return str(output_path)
157
 
158
 
159
- def _infer_target_model(dflash_model_id: str) -> str:
160
- """Infer the target model from DFlash drafter ID."""
 
 
 
 
161
  # Map drafter IDs to target models
162
  mapping = {
 
163
  "Qwen3-4B-DFlash": "Qwen/Qwen3-4B",
164
  "Qwen3-8B-DFlash": "Qwen/Qwen3-8B",
 
 
 
165
  "Qwen3.5-9B-DFlash": "Qwen/Qwen3.5-9B",
166
  "Qwen3.5-27B-DFlash": "Qwen/Qwen3.5-27B",
 
 
 
167
  "Qwen3.6-27B-DFlash": "Qwen/Qwen3.6-27B",
168
  "Qwen3.6-35B-A3B-DFlash": "Qwen/Qwen3.6-35B-A3B",
 
 
169
  "Qwen3-Coder-30B-A3B-DFlash": "Qwen/Qwen3-Coder-30B-A3B",
170
- "Qwen3.5-122B-A10B-DFlash": "Qwen/Qwen3.5-122B-A10B",
171
  "LLaMA3.1-8B-Instruct-DFlash": "meta-llama/Llama-3.1-8B-Instruct",
 
 
172
  "gemma-4-31B-it-DFlash": "google/gemma-4-31b-it",
 
 
173
  "gpt-oss-20b-DFlash": "openai/gpt-oss-20b",
 
 
174
  "Kimi-K2.5-DFlash": "moonshotai/Kimi-K2.5",
 
175
  "MiniMax-M2.5-DFlash": "MiniMax/MiniMax-M2.5",
176
  }
177
 
 
178
  for key, target in mapping.items():
179
  if key in dflash_model_id:
180
  return target
181
 
182
- # Generic inference
183
  if "Qwen3.6" in dflash_model_id:
184
  return "Qwen/Qwen3.6-27B"
185
  elif "Qwen3.5" in dflash_model_id:
186
  return "Qwen/Qwen3.5-9B"
 
 
187
  elif "Qwen3" in dflash_model_id:
188
  return "Qwen/Qwen3-4B"
189
- elif "LLaMA" in dflash_model_id or "Llama" in dflash_model_id:
190
  return "meta-llama/Llama-3.1-8B-Instruct"
191
- elif "gemma" in dflash_model_id:
192
  return "google/gemma-4-31b-it"
 
 
 
 
 
 
193
 
194
  return "unknown"
195
 
196
 
197
  def load_mlx_dflash(
198
  model_path: str,
199
- ) -> tuple:
200
  """Load a converted MLX DFlash model.
201
 
202
  Args:
@@ -214,8 +280,20 @@ def load_mlx_dflash(
214
  config = json.load(f)
215
 
216
  # Load weights
217
- weights = mx.load(str(model_path / "weights.safetensors"))
218
-
 
 
 
 
 
 
 
 
 
 
 
 
219
  # Build model
220
  model = DFlashDraftModel(
221
  vocab_size=config["vocab_size"],
@@ -227,9 +305,32 @@ def load_mlx_dflash(
227
  max_seq_len=config["max_position_embeddings"],
228
  block_size=config.get("block_size", 16),
229
  rope_base=config.get("rope_base", 10000.0),
 
230
  )
231
 
232
  # Load weights into model
233
  model.update(weights)
234
 
235
  return model, config
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  Handles weight conversion from PyTorch safetensors to MLX arrays,
5
  compatible with any z-lab DFlash drafter.
6
+
7
+ Updated to work with the universal adapter system for any target model family.
8
  """
9
 
10
  import json
11
  import os
12
  from pathlib import Path
13
+ from typing import Optional, Dict, Tuple
14
  import mlx.core as mx
15
+ from transformers import AutoConfig
16
  from huggingface_hub import hf_hub_download, snapshot_download
17
 
18
 
19
  def _convert_key(key: str) -> str:
20
+ """Convert PyTorch parameter names to MLX format.
21
+
22
+ Handles various naming conventions across model families.
23
+ """
24
  # Replace PyTorch-specific prefixes
25
  key = key.replace("model.", "")
26
  # Standardize naming
 
50
 
51
 
52
  def _transpose_if_needed(key: str, tensor) -> mx.array:
53
+ """Transpose linear layer weights from PyTorch to MLX format.
54
+
55
+ Linear layers in PyTorch are [out, in], MLX expects [in, out].
56
+ """
57
  if "proj" in key or "fc" in key or "lm_head" in key or "embed" in key:
58
  if len(tensor.shape) == 2:
59
  return mx.array(tensor.T)
 
86
  repo_path = snapshot_download(
87
  repo_id=pytorch_model_id,
88
  token=token,
89
+ ignore_patterns=["*.md", "*.png", "*.jpg", "*.gif", "*.jpeg"],
90
  )
91
  repo_path = Path(repo_path)
92
 
93
+ # Load PyTorch model config
94
+ print("[Convert] Loading PyTorch config...")
95
  config = AutoConfig.from_pretrained(
96
  repo_path,
97
  trust_remote_code=trust_remote_code,
 
110
  "block_size": getattr(config, "block_size", 16),
111
  "rope_base": getattr(config, "rope_theta", 10000.0),
112
  }
113
+
114
+ # Extract target layer IDs if present in config
115
+ if hasattr(config, "target_layer_ids"):
116
+ dflash_config["target_layer_ids"] = config.target_layer_ids
117
+ elif hasattr(config, "dflash_config") and hasattr(config.dflash_config, "target_layer_ids"):
118
+ dflash_config["target_layer_ids"] = config.dflash_config.target_layer_ids
119
 
120
  # Load weights from safetensors
121
  print("[Convert] Loading weights from safetensors...")
122
  try:
123
  from safetensors.torch import load_file
124
+
125
+ # Find all safetensors files
126
+ safetensors_files = sorted(repo_path.glob("*.safetensors"))
127
+
128
+ if safetensors_files:
129
+ pt_weights = {}
130
+ for st_file in safetensors_files:
131
+ print(f" Loading {st_file.name}...")
132
+ partial = load_file(str(st_file))
133
+ pt_weights.update(partial)
134
  else:
135
+ # Try pytorch_model.bin
136
+ bin_file = repo_path / "pytorch_model.bin"
137
+ if bin_file.exists():
138
+ import torch
139
+ pt_weights = torch.load(str(bin_file), map_location="cpu")
140
  else:
141
+ raise FileNotFoundError("No safetensors or pytorch_model.bin found")
142
  except ImportError:
143
  # Fallback to torch load
144
  import torch
145
  weights_file = repo_path / "pytorch_model.bin"
146
+ if weights_file.exists():
147
+ pt_weights = torch.load(str(weights_file), map_location="cpu")
148
+ else:
149
+ raise FileNotFoundError("No weight files found and safetensors not installed")
150
 
151
  # Convert weights
152
  print(f"[Convert] Converting {len(pt_weights)} parameters...")
 
155
  mlx_key = _convert_key(key)
156
  mlx_weights[mlx_key] = _transpose_if_needed(key, tensor)
157
 
158
+ # Save MLX weights (try safetensors, fallback to npz)
159
+ weights_path = output_path / "weights.npz"
160
+ try:
161
+ # Use numpy format if safetensors save is problematic
162
+ import numpy as np
163
+ np_weights = {k: np.array(v) for k, v in mlx_weights.items()}
164
+ np.savez(str(weights_path), **np_weights)
165
+ print(f"[Convert] Saved weights to {weights_path}")
166
+ except Exception as e:
167
+ print(f"[Convert] Warning: Could not save weights: {e}")
168
+ # Try direct mlx save
169
+ try:
170
+ mx.savez(str(weights_path), **mlx_weights)
171
+ except Exception as e2:
172
+ print(f"[Convert] Error saving weights: {e2}")
173
+ raise
174
 
175
  # Save config
176
  config_path = output_path / "config.json"
177
  with open(config_path, "w") as f:
178
  json.dump(dflash_config, f, indent=2)
179
 
180
+ # Save target model mapping
181
  target_info = {
182
  "source_model": pytorch_model_id,
183
+ "target_model": infer_target_model(pytorch_model_id),
184
+ "conversion_date": str(Path(__file__).stat().st_mtime),
185
  }
186
  info_path = output_path / "model_info.json"
187
  with open(info_path, "w") as f:
188
  json.dump(target_info, f, indent=2)
189
 
190
  print(f"[Convert] Done! Model saved to {output_path}")
191
+ print(f" Config: {dflash_config}")
192
+ print(f" Target: {target_info['target_model']}")
193
  return str(output_path)
194
 
195
 
196
+ def infer_target_model(dflash_model_id: str) -> str:
197
+ """Infer the target model from DFlash drafter ID.
198
+
199
+ Maps known drafter checkpoints to their corresponding target models.
200
+ Supports all official z-lab DFlash models plus community variants.
201
+ """
202
  # Map drafter IDs to target models
203
  mapping = {
204
+ # Qwen3 series
205
  "Qwen3-4B-DFlash": "Qwen/Qwen3-4B",
206
  "Qwen3-8B-DFlash": "Qwen/Qwen3-8B",
207
+ "Qwen3-32B-DFlash": "Qwen/Qwen3-32B",
208
+ # Qwen3.5 series
209
+ "Qwen3.5-4B-DFlash": "Qwen/Qwen3.5-4B",
210
  "Qwen3.5-9B-DFlash": "Qwen/Qwen3.5-9B",
211
  "Qwen3.5-27B-DFlash": "Qwen/Qwen3.5-27B",
212
+ "Qwen3.5-35B-A3B-DFlash": "Qwen/Qwen3.5-35B-A3B",
213
+ "Qwen3.5-122B-A10B-DFlash": "Qwen/Qwen3.5-122B-A10B",
214
+ # Qwen3.6 series
215
  "Qwen3.6-27B-DFlash": "Qwen/Qwen3.6-27B",
216
  "Qwen3.6-35B-A3B-DFlash": "Qwen/Qwen3.6-35B-A3B",
217
+ # Qwen Coder
218
+ "Qwen3-Coder-Next-DFlash": "Qwen/Qwen3-Coder-Next",
219
  "Qwen3-Coder-30B-A3B-DFlash": "Qwen/Qwen3-Coder-30B-A3B",
220
+ # LLaMA
221
  "LLaMA3.1-8B-Instruct-DFlash": "meta-llama/Llama-3.1-8B-Instruct",
222
+ "LLaMA3.1-70B-Instruct-DFlash": "meta-llama/Llama-3.1-70B-Instruct",
223
+ # Gemma
224
  "gemma-4-31B-it-DFlash": "google/gemma-4-31b-it",
225
+ "gemma-4-26B-A4B-it-DFlash": "google/gemma-4-26b-a4b-it",
226
+ # GPT-OSS
227
  "gpt-oss-20b-DFlash": "openai/gpt-oss-20b",
228
+ "gpt-oss-120b-DFlash": "openai/gpt-oss-120b",
229
+ # Kimi
230
  "Kimi-K2.5-DFlash": "moonshotai/Kimi-K2.5",
231
+ # MiniMax
232
  "MiniMax-M2.5-DFlash": "MiniMax/MiniMax-M2.5",
233
  }
234
 
235
+ # Direct mapping lookup
236
  for key, target in mapping.items():
237
  if key in dflash_model_id:
238
  return target
239
 
240
+ # Generic inference by model family
241
  if "Qwen3.6" in dflash_model_id:
242
  return "Qwen/Qwen3.6-27B"
243
  elif "Qwen3.5" in dflash_model_id:
244
  return "Qwen/Qwen3.5-9B"
245
+ elif "Qwen3-Coder" in dflash_model_id:
246
+ return "Qwen/Qwen3-Coder-Next"
247
  elif "Qwen3" in dflash_model_id:
248
  return "Qwen/Qwen3-4B"
249
+ elif "LLaMA" in dflash_model_id or "Llama" in dflash_model_id or "llama" in dflash_model_id:
250
  return "meta-llama/Llama-3.1-8B-Instruct"
251
+ elif "gemma" in dflash_model_id.lower():
252
  return "google/gemma-4-31b-it"
253
+ elif "gpt-oss" in dflash_model_id.lower():
254
+ return "openai/gpt-oss-20b"
255
+ elif "Kimi" in dflash_model_id:
256
+ return "moonshotai/Kimi-K2.5"
257
+ elif "MiniMax" in dflash_model_id:
258
+ return "MiniMax/MiniMax-M2.5"
259
 
260
  return "unknown"
261
 
262
 
263
  def load_mlx_dflash(
264
  model_path: str,
265
+ ) -> Tuple:
266
  """Load a converted MLX DFlash model.
267
 
268
  Args:
 
280
  config = json.load(f)
281
 
282
  # Load weights
283
+ weights_path = model_path / "weights.npz"
284
+ if not weights_path.exists():
285
+ # Try alternative extensions
286
+ for ext in [".safetensors", ".mlx", ".npz"]:
287
+ alt = model_path / f"weights{ext}"
288
+ if alt.exists():
289
+ weights_path = alt
290
+ break
291
+
292
+ if not weights_path.exists():
293
+ raise FileNotFoundError(f"No weights found in {model_path}")
294
+
295
+ weights = mx.load(str(weights_path))
296
+
297
  # Build model
298
  model = DFlashDraftModel(
299
  vocab_size=config["vocab_size"],
 
305
  max_seq_len=config["max_position_embeddings"],
306
  block_size=config.get("block_size", 16),
307
  rope_base=config.get("rope_base", 10000.0),
308
+ target_layer_ids=config.get("target_layer_ids", None),
309
  )
310
 
311
  # Load weights into model
312
  model.update(weights)
313
 
314
  return model, config
315
+
316
+
317
+ def main():
318
+ """CLI entry point for conversion."""
319
+ import argparse
320
+ parser = argparse.ArgumentParser(description="Convert PyTorch DFlash drafter to MLX")
321
+ parser.add_argument("--model", required=True, help="HF model ID of PyTorch drafter")
322
+ parser.add_argument("--output", required=True, help="Output directory")
323
+ parser.add_argument("--trust-remote-code", action="store_true", default=True)
324
+ parser.add_argument("--token", default=None, help="HF token for gated models")
325
+ args = parser.parse_args()
326
+
327
+ convert_dflash_to_mlx(
328
+ pytorch_model_id=args.model,
329
+ output_path=args.output,
330
+ trust_remote_code=args.trust_remote_code,
331
+ token=args.token,
332
+ )
333
+
334
+
335
+ if __name__ == "__main__":
336
+ main()