tritesh commited on
Commit
20edc82
·
verified ·
1 Parent(s): c61f568

Upload dflash_mlx/universal.py

Browse files
Files changed (1) hide show
  1. dflash_mlx/universal.py +97 -47
dflash_mlx/universal.py CHANGED
@@ -3,26 +3,34 @@ Universal DFlash decoder for any MLX-converted model.
3
 
4
  Provides a high-level interface that works with any mlx_lm model,
5
  including those without pre-built DFlash drafters.
 
 
 
6
  """
7
 
8
  from typing import Optional, List, Dict, Any
9
  import mlx.core as mx
10
  from .model import DFlashDraftModel
11
  from .speculative_decode import DFlashSpeculativeDecoder
 
 
12
 
13
 
14
  class UniversalDFlashDecoder:
15
  """Universal DFlash decoder that works with any MLX-converted model.
16
 
17
  This class handles:
18
- 1. Loading pre-converted DFlash drafters
19
  2. Creating generic drafters for unsupported models
20
  3. Training custom drafters on-the-fly
 
 
 
21
  """
22
 
23
  def __init__(
24
  self,
25
- target_model,
26
  tokenizer,
27
  draft_model_path: Optional[str] = None,
28
  draft_layers: int = 5,
@@ -33,7 +41,7 @@ class UniversalDFlashDecoder:
33
  """Initialize the universal decoder.
34
 
35
  Args:
36
- target_model: Any mlx_lm loaded model
37
  tokenizer: Tokenizer for the model
38
  draft_model_path: Optional path to pre-converted DFlash drafter
39
  draft_layers: Number of draft layers (if creating generic drafter)
@@ -41,19 +49,47 @@ class UniversalDFlashDecoder:
41
  block_size: Number of tokens per draft block
42
  device: MLX device
43
  """
44
- self.target_model = target_model
45
  self.tokenizer = tokenizer
46
  self.block_size = block_size
47
  self.device = device
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  # Determine model type and vocab size
50
  self.vocab_size = getattr(tokenizer, "vocab_size", 151936)
51
- self.target_config = self._extract_target_config(target_model)
52
 
53
  # Load or create draft model
54
  if draft_model_path:
55
- print(f"[UniversalDFlash] Loading pre-built drafter from {draft_model_path}")
56
- from .convert import load_mlx_dflash
57
  self.draft_model, self.draft_config = load_mlx_dflash(draft_model_path)
58
  else:
59
  print("[UniversalDFlash] Creating generic drafter for your model...")
@@ -63,9 +99,9 @@ class UniversalDFlashDecoder:
63
  )
64
  self.draft_config = None
65
 
66
- # Create the speculative decoder
67
  self.decoder = DFlashSpeculativeDecoder(
68
- target_model=target_model,
69
  draft_model=self.draft_model,
70
  tokenizer=tokenizer,
71
  block_size=block_size,
@@ -85,6 +121,7 @@ class UniversalDFlashDecoder:
85
  config['intermediate_size'] = getattr(model_config, 'intermediate_size', 14336)
86
  config['num_attention_heads'] = getattr(model_config, 'num_attention_heads', 32)
87
  config['num_key_value_heads'] = getattr(model_config, 'num_key_value_heads', 8)
 
88
  else:
89
  # Default Qwen3-4B-like config
90
  config = {
@@ -94,6 +131,7 @@ class UniversalDFlashDecoder:
94
  'intermediate_size': 14336,
95
  'num_attention_heads': 32,
96
  'num_key_value_heads': 8,
 
97
  }
98
 
99
  return config
@@ -107,16 +145,26 @@ class UniversalDFlashDecoder:
107
 
108
  This creates an untrained drafter that can be trained or used
109
  with pre-trained weights from a similar architecture.
 
 
 
110
  """
111
  # Determine architecture compatibility
112
  hidden_size = self.target_config.get('hidden_size', 4096)
113
  vocab_size = self.target_config.get('vocab_size', 151936)
 
114
 
115
  # Scale drafter based on target model size
 
116
  num_heads = draft_hidden_size // 64 # ~64 dims per head
117
  num_kv_heads = max(1, num_heads // 4)
118
  intermediate_size = int(draft_hidden_size * 2.75) # Standard SwiGLU ratio
119
 
 
 
 
 
 
120
  drafter = DFlashDraftModel(
121
  vocab_size=vocab_size,
122
  hidden_size=draft_hidden_size,
@@ -126,8 +174,9 @@ class UniversalDFlashDecoder:
126
  intermediate_size=intermediate_size,
127
  max_seq_len=8192,
128
  block_size=self.block_size,
129
- mask_token_id=0, # Will be set from tokenizer
130
- num_target_layers=self.target_config.get('num_layers', 32),
 
131
  )
132
 
133
  return drafter
@@ -145,14 +194,20 @@ class UniversalDFlashDecoder:
145
  ) -> str:
146
  """Train a custom DFlash drafter for your target model.
147
 
 
 
 
 
 
 
148
  Args:
149
  dataset: Path to training dataset or HF dataset name
150
  max_seq_length: Maximum sequence length for training
151
- epochs: Number of training epochs
152
  batch_size: Training batch size
153
- lr: Learning rate
154
- warmup_ratio: Warmup ratio for cosine schedule
155
- grad_clip: Gradient clipping threshold
156
  output_path: Where to save the trained drafter
157
 
158
  Returns:
@@ -161,6 +216,9 @@ class UniversalDFlashDecoder:
161
  from .trainer import DFlashTrainer
162
 
163
  print(f"[UniversalDFlash] Training custom drafter...")
 
 
 
164
  trainer = DFlashTrainer(
165
  target_model=self.target_model,
166
  drafter=self.draft_model,
@@ -196,7 +254,17 @@ class UniversalDFlashDecoder:
196
 
197
  # Save weights
198
  weights = dict(self.draft_model.parameters())
199
- mx.save_safetensors(str(path / "weights.safetensors"), weights)
 
 
 
 
 
 
 
 
 
 
200
 
201
  # Save config
202
  config = {
@@ -205,9 +273,11 @@ class UniversalDFlashDecoder:
205
  "num_hidden_layers": self.draft_model.num_layers,
206
  "num_attention_heads": self.draft_model.num_heads,
207
  "num_key_value_heads": self.draft_model.num_heads // 4,
208
- "intermediate_size": self.draft_model.layers[0].mlp.gate_proj.weight.shape[1] if hasattr(self.draft_model.layers[0].mlp.gate_proj, 'weight') else 2816,
 
209
  "max_position_embeddings": self.draft_model.max_seq_len,
210
  "block_size": self.draft_model.block_size,
 
211
  }
212
 
213
  with open(path / "config.json", "w") as f:
@@ -221,7 +291,8 @@ class UniversalDFlashDecoder:
221
  max_tokens: int = 2048,
222
  temperature: float = 0.0,
223
  stop_strings: Optional[List[str]] = None,
224
- ) -> str:
 
225
  """Generate text using DFlash speculative decoding.
226
 
227
  Args:
@@ -229,15 +300,17 @@ class UniversalDFlashDecoder:
229
  max_tokens: Maximum tokens to generate
230
  temperature: Sampling temperature
231
  stop_strings: Optional stop strings
 
232
 
233
  Returns:
234
- Generated text
235
  """
236
  return self.decoder.generate(
237
  prompt=prompt,
238
  max_tokens=max_tokens,
239
  temperature=temperature,
240
  stop_strings=stop_strings,
 
241
  )
242
 
243
  def benchmark(
@@ -256,31 +329,8 @@ class UniversalDFlashDecoder:
256
  Returns:
257
  Dict with speedup metrics
258
  """
259
- import time
260
-
261
- print(f"[Benchmark] Running {num_runs} generations...")
262
-
263
- # Warmup
264
- self.generate(prompt, max_tokens=10)
265
-
266
- # DFlash generation
267
- dflash_times = []
268
- for _ in range(num_runs):
269
- start = time.time()
270
- self.generate(prompt, max_tokens=max_tokens)
271
- dflash_times.append(time.time() - start)
272
-
273
- # Baseline generation (without speculative decoding)
274
- # We estimate based on token count vs time
275
- # In practice you'd run a full baseline comparison
276
-
277
- avg_time = sum(dflash_times) / len(dflash_times)
278
- tokens_per_sec = max_tokens / avg_time
279
-
280
- print(f"[Benchmark] Avg time: {avg_time:.2f}s, Speed: {tokens_per_sec:.1f} tok/s")
281
-
282
- return {
283
- "avg_time_sec": avg_time,
284
- "tokens_per_sec": tokens_per_sec,
285
- "num_runs": num_runs,
286
- }
 
3
 
4
  Provides a high-level interface that works with any mlx_lm model,
5
  including those without pre-built DFlash drafters.
6
+
7
+ Now uses the architecture-agnostic adapter system for proper target model
8
+ interaction across all supported families (Qwen3, Qwen3.5, LLaMA, Mistral, Gemma).
9
  """
10
 
11
  from typing import Optional, List, Dict, Any
12
  import mlx.core as mx
13
  from .model import DFlashDraftModel
14
  from .speculative_decode import DFlashSpeculativeDecoder
15
+ from .adapters import load_target_model, LoadedTargetModel, detect_model_architecture
16
+ from .convert import load_mlx_dflash
17
 
18
 
19
  class UniversalDFlashDecoder:
20
  """Universal DFlash decoder that works with any MLX-converted model.
21
 
22
  This class handles:
23
+ 1. Loading pre-converted DFlash drafters with architecture detection
24
  2. Creating generic drafters for unsupported models
25
  3. Training custom drafters on-the-fly
26
+
27
+ Key improvement: Automatically detects target model architecture and
28
+ selects the correct adapter for hidden state extraction and KV cache management.
29
  """
30
 
31
  def __init__(
32
  self,
33
+ target_model: Any,
34
  tokenizer,
35
  draft_model_path: Optional[str] = None,
36
  draft_layers: int = 5,
 
41
  """Initialize the universal decoder.
42
 
43
  Args:
44
+ target_model: Any mlx_lm loaded model, or path/ID to load
45
  tokenizer: Tokenizer for the model
46
  draft_model_path: Optional path to pre-converted DFlash drafter
47
  draft_layers: Number of draft layers (if creating generic drafter)
 
49
  block_size: Number of tokens per draft block
50
  device: MLX device
51
  """
 
52
  self.tokenizer = tokenizer
53
  self.block_size = block_size
54
  self.device = device
55
 
56
+ # Resolve target model
57
+ if isinstance(target_model, str):
58
+ print(f"[UniversalDFlash] Loading target model: {target_model}...")
59
+ self.loaded_target = load_target_model(target_model)
60
+ self.target_model = self.loaded_target.model
61
+ elif hasattr(target_model, 'adapter'):
62
+ # Already a LoadedTargetModel
63
+ self.loaded_target = target_model
64
+ self.target_model = target_model.model
65
+ else:
66
+ # Raw mlx_lm model — detect architecture
67
+ print("[UniversalDFlash] Detecting model architecture...")
68
+ self.target_model = target_model
69
+ # Try to build adapter from model attributes
70
+ arch = detect_model_architecture(target_model)
71
+ print(f"[UniversalDFlash] Detected architecture: {arch}")
72
+ # Create minimal LoadedTargetModel wrapper
73
+ from .adapters import MLXTargetAdapter, adapter_for_model_type
74
+ adapter_cls = adapter_for_model_type(arch)
75
+ if adapter_cls is None:
76
+ adapter_cls = MLXTargetAdapter
77
+ adapter = adapter_cls(model=target_model, config={"model_type": arch})
78
+ self.loaded_target = LoadedTargetModel(
79
+ requested_model="unknown",
80
+ resolved_model_path=None,
81
+ model=target_model,
82
+ tokenizer=tokenizer,
83
+ adapter=adapter,
84
+ )
85
+
86
  # Determine model type and vocab size
87
  self.vocab_size = getattr(tokenizer, "vocab_size", 151936)
88
+ self.target_config = self._extract_target_config(self.target_model)
89
 
90
  # Load or create draft model
91
  if draft_model_path:
92
+ print(f"[UniversalDFlash] Loading pre-built drafter from {draft_model_path}...")
 
93
  self.draft_model, self.draft_config = load_mlx_dflash(draft_model_path)
94
  else:
95
  print("[UniversalDFlash] Creating generic drafter for your model...")
 
99
  )
100
  self.draft_config = None
101
 
102
+ # Create the speculative decoder with architecture-aware adapter
103
  self.decoder = DFlashSpeculativeDecoder(
104
+ target_model=self.loaded_target,
105
  draft_model=self.draft_model,
106
  tokenizer=tokenizer,
107
  block_size=block_size,
 
121
  config['intermediate_size'] = getattr(model_config, 'intermediate_size', 14336)
122
  config['num_attention_heads'] = getattr(model_config, 'num_attention_heads', 32)
123
  config['num_key_value_heads'] = getattr(model_config, 'num_key_value_heads', 8)
124
+ config['model_type'] = getattr(model_config, 'model_type', 'unknown')
125
  else:
126
  # Default Qwen3-4B-like config
127
  config = {
 
131
  'intermediate_size': 14336,
132
  'num_attention_heads': 32,
133
  'num_key_value_heads': 8,
134
+ 'model_type': 'unknown',
135
  }
136
 
137
  return config
 
145
 
146
  This creates an untrained drafter that can be trained or used
147
  with pre-trained weights from a similar architecture.
148
+
149
+ The draft model is sized proportionally to the target model's
150
+ hidden dimension for feature compatibility.
151
  """
152
  # Determine architecture compatibility
153
  hidden_size = self.target_config.get('hidden_size', 4096)
154
  vocab_size = self.target_config.get('vocab_size', 151936)
155
+ num_layers = self.target_config.get('num_layers', 32)
156
 
157
  # Scale drafter based on target model size
158
+ # Aim for ~1B params (common for draft models)
159
  num_heads = draft_hidden_size // 64 # ~64 dims per head
160
  num_kv_heads = max(1, num_heads // 4)
161
  intermediate_size = int(draft_hidden_size * 2.75) # Standard SwiGLU ratio
162
 
163
+ # Target layer ids for feature extraction
164
+ target_layer_ids = DFlashDraftModel._build_target_layer_ids(
165
+ None, num_layers, draft_layers
166
+ )
167
+
168
  drafter = DFlashDraftModel(
169
  vocab_size=vocab_size,
170
  hidden_size=draft_hidden_size,
 
174
  intermediate_size=intermediate_size,
175
  max_seq_len=8192,
176
  block_size=self.block_size,
177
+ mask_token_id=0, # Will be overridden by tokenizer
178
+ num_target_layers=num_layers,
179
+ target_layer_ids=target_layer_ids,
180
  )
181
 
182
  return drafter
 
194
  ) -> str:
195
  """Train a custom DFlash drafter for your target model.
196
 
197
+ Uses the training recipe from the DFlash paper:
198
+ - KV injection with target model features
199
+ - Random anchor sampling for block construction
200
+ - Sparse attention masking within blocks
201
+ - Position-dependent loss decay
202
+
203
  Args:
204
  dataset: Path to training dataset or HF dataset name
205
  max_seq_length: Maximum sequence length for training
206
+ epochs: Number of training epochs (paper: 6)
207
  batch_size: Training batch size
208
+ lr: Learning rate (paper: 6e-4)
209
+ warmup_ratio: Warmup ratio for cosine schedule (paper: 0.04)
210
+ grad_clip: Gradient clipping threshold (paper: 1.0)
211
  output_path: Where to save the trained drafter
212
 
213
  Returns:
 
216
  from .trainer import DFlashTrainer
217
 
218
  print(f"[UniversalDFlash] Training custom drafter...")
219
+ print(f" Dataset: {dataset}")
220
+ print(f" Epochs: {epochs}, Batch size: {batch_size}, LR: {lr}")
221
+
222
  trainer = DFlashTrainer(
223
  target_model=self.target_model,
224
  drafter=self.draft_model,
 
254
 
255
  # Save weights
256
  weights = dict(self.draft_model.parameters())
257
+
258
+ # Try multiple formats
259
+ try:
260
+ np_weights = {k: np.array(v) for k, v in weights.items()}
261
+ np.savez(str(path / "weights.npz"), **np_weights)
262
+ except Exception:
263
+ try:
264
+ mx.savez(str(path / "weights.npz"), **weights)
265
+ except Exception as e:
266
+ print(f"[Save] Error saving weights: {e}")
267
+ raise
268
 
269
  # Save config
270
  config = {
 
273
  "num_hidden_layers": self.draft_model.num_layers,
274
  "num_attention_heads": self.draft_model.num_heads,
275
  "num_key_value_heads": self.draft_model.num_heads // 4,
276
+ "intermediate_size": self.draft_model.layers[0].mlp.gate_proj.weight.shape[1]
277
+ if hasattr(self.draft_model.layers[0].mlp.gate_proj, 'weight') else 2816,
278
  "max_position_embeddings": self.draft_model.max_seq_len,
279
  "block_size": self.draft_model.block_size,
280
+ "target_layer_ids": self.draft_model.target_layer_ids,
281
  }
282
 
283
  with open(path / "config.json", "w") as f:
 
291
  max_tokens: int = 2048,
292
  temperature: float = 0.0,
293
  stop_strings: Optional[List[str]] = None,
294
+ stream: bool = False,
295
+ ) -> str | Any:
296
  """Generate text using DFlash speculative decoding.
297
 
298
  Args:
 
300
  max_tokens: Maximum tokens to generate
301
  temperature: Sampling temperature
302
  stop_strings: Optional stop strings
303
+ stream: If True, returns a generator yielding text deltas
304
 
305
  Returns:
306
+ Generated text string, or generator if stream=True
307
  """
308
  return self.decoder.generate(
309
  prompt=prompt,
310
  max_tokens=max_tokens,
311
  temperature=temperature,
312
  stop_strings=stop_strings,
313
+ stream=stream,
314
  )
315
 
316
  def benchmark(
 
329
  Returns:
330
  Dict with speedup metrics
331
  """
332
+ return self.decoder.benchmark(
333
+ prompt=prompt,
334
+ max_tokens=max_tokens,
335
+ num_runs=num_runs,
336
+ )