tritesh commited on
Commit
e8cb6a7
·
verified ·
1 Parent(s): bb76689

Upload dflash_mlx/speculative_decode.py

Browse files
Files changed (1) hide show
  1. dflash_mlx/speculative_decode.py +371 -164
dflash_mlx/speculative_decode.py CHANGED
@@ -6,12 +6,20 @@ Implements the full inference pipeline:
6
  2. Draft: Block diffusion model generates parallel draft tokens
7
  3. Verify: Target model verifies drafts in parallel
8
  4. Accept: Accepted tokens appended, rejected tokens regenerated
 
 
9
  """
10
 
11
- from typing import Optional, List, Callable
12
  import mlx.core as mx
13
  import mlx.nn as nn
14
  from .model import DFlashDraftModel
 
 
 
 
 
 
15
 
16
 
17
  def sample_greedy(logits: mx.array) -> mx.array:
@@ -25,91 +33,145 @@ def sample_temperature(logits: mx.array, temperature: float) -> mx.array:
25
  return mx.random.categorical(mx.log(probs))
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  class DFlashSpeculativeDecoder:
29
  """DFlash speculative decoder for MLX-converted models.
30
 
31
- This decoder works with any MLX causal language model as the target,
32
  paired with a DFlash block diffusion draft model.
 
 
 
 
 
 
33
  """
34
 
35
  def __init__(
36
  self,
37
- target_model,
38
  draft_model: DFlashDraftModel,
39
  tokenizer,
40
  block_size: int = 16,
41
  max_seq_length: int = 8192,
42
  device: str = "metal",
 
43
  ):
44
  """Initialize the DFlash speculative decoder.
45
 
46
  Args:
47
- target_model: MLX target LLM (any mlx_lm loaded model)
48
  draft_model: DFlash block diffusion draft model
49
  tokenizer: Tokenizer for encoding/decoding
50
  block_size: Number of tokens to draft per block
51
  max_seq_length: Maximum sequence length
52
  device: MLX device ("cpu" or "metal")
 
53
  """
54
- self.target_model = target_model
 
 
 
 
 
 
 
 
 
55
  self.draft_model = draft_model
56
  self.tokenizer = tokenizer
57
  self.block_size = block_size
58
  self.max_seq_length = max_seq_length
59
  self.device = device
60
  self.mask_token_id = draft_model.mask_token_id
61
-
 
 
 
 
 
 
 
 
 
 
62
  def _target_forward(
63
  self,
64
  input_ids: mx.array,
65
- past_key_values: Optional[dict] = None,
66
  output_hidden_states: bool = False,
67
- ) -> dict:
68
- """Forward pass through target model.
 
69
 
70
  Args:
71
- input_ids: Input token IDs
72
- past_key_values: Optional KV cache
73
- output_hidden_states: Whether to return hidden states
 
74
 
75
  Returns:
76
- Dict with logits and optionally hidden states
77
  """
78
- # MLX model forward
79
- if hasattr(self.target_model, '__call__'):
80
- result = self.target_model(
81
- input_ids,
82
- cache=past_key_values,
 
 
 
 
 
 
 
 
83
  )
84
- logits = result[0] if isinstance(result, tuple) else result
 
 
 
 
85
  else:
86
- logits = self.target_model(input_ids)
87
-
88
- output = {"logits": logits}
89
-
90
- # Extract hidden states if needed (for KV injection)
91
- if output_hidden_states and hasattr(self.target_model, 'layers'):
92
- hidden_states = []
93
- hidden = self.target_model.embed_tokens(input_ids)
94
- for layer in self.target_model.layers:
95
- hidden = layer(hidden, mask=None, cache=past_key_values)
96
- hidden_states.append(hidden)
97
- output["hidden_states"] = hidden_states
98
-
99
- return output
100
-
101
  def _sample(self, logits: mx.array, temperature: float) -> mx.array:
102
  """Sample from logits."""
103
  if temperature < 1e-5:
104
  return sample_greedy(logits)
105
  return sample_temperature(logits, temperature)
106
-
107
  def spec_generate(
108
  self,
109
  input_ids: mx.array,
110
  max_new_tokens: int,
111
  temperature: float = 0.0,
112
- stop_token_ids: Optional[List[int]] = None,
 
113
  ) -> mx.array:
114
  """Generate tokens using DFlash speculative decoding.
115
 
@@ -117,149 +179,214 @@ class DFlashSpeculativeDecoder:
117
  input_ids: Prompt token IDs [bsz, seq_len]
118
  max_new_tokens: Maximum new tokens to generate
119
  temperature: Sampling temperature (0 for greedy)
120
- stop_token_ids: Optional list of stop token IDs
 
121
 
122
  Returns:
123
  Generated token IDs [bsz, total_seq_len]
124
  """
125
- num_input_tokens = input_ids.shape[1]
126
  max_length = num_input_tokens + max_new_tokens
127
  block_size = self.block_size
128
-
129
- # Initialize output buffer with mask tokens
130
  output_ids = mx.full(
131
  (1, max_length + block_size),
132
  self.mask_token_id,
133
  dtype=mx.int32,
134
  )
135
  position_ids = mx.arange(output_ids.shape[1])
136
-
137
- # Target model KV cache
138
- target_cache = None
139
- draft_cache = None
140
-
141
- # Prefill stage: process prompt with target model
142
- print("[DFlash] Prefill stage...")
 
 
143
  target_output = self._target_forward(
144
  input_ids,
145
- past_key_values=target_cache,
146
  output_hidden_states=True,
 
147
  )
148
-
149
  # Copy prompt tokens to output
150
  output_ids[:, :num_input_tokens] = input_ids[0]
151
-
152
- # Sample first token from target model
153
  first_token_logits = target_output["logits"][:, -1:, :]
154
  first_token = self._sample(first_token_logits, temperature)
155
  output_ids[:, num_input_tokens] = first_token[0, 0]
156
-
157
  # Extract target context features for draft conditioning
158
- if "hidden_states" in target_output:
159
- target_hidden = self.draft_model.extract_context_features(
160
- target_output["hidden_states"]
161
- )
162
- else:
163
- # Fallback: use last hidden state as single feature
164
- target_hidden = target_output["logits"]
165
- # Project to hidden size if needed
166
- # (simplified - in practice we'd need proper projection)
167
-
168
- # Decode stage: speculative decoding loop
 
 
 
 
169
  print(f"[DFlash] Starting speculative decoding (block_size={block_size})...")
170
- acceptance_lengths = []
171
- start = num_input_tokens
172
- generated_count = 0
173
-
 
 
 
174
  while start < max_length and generated_count < max_new_tokens:
175
- # 1. Draft: generate block of tokens with diffusion model
176
- block_output_ids = mx.array(output_ids[:, start : start + block_size])
177
- block_position_ids = position_ids[start : start + block_size]
178
-
 
 
 
 
 
 
 
 
 
 
179
  # Embed draft tokens (including mask tokens)
180
  draft_embeddings = self.draft_model.embed_tokens(block_output_ids)
181
-
182
- # Run draft model to get predictions for masked positions
183
  draft_hidden = self.draft_model(
184
  noise_embedding=draft_embeddings,
185
  target_hidden=target_hidden,
186
  position_ids=block_position_ids,
187
  )
188
  draft_logits = self.draft_model.get_logits(draft_hidden)
189
-
190
  # Sample draft tokens (predict all positions)
191
- draft_tokens = self._sample(draft_logits[:, 1:, :], temperature)
192
-
193
- # Fill draft predictions into block (keep first token from target)
194
- block_output_ids = mx.array(block_output_ids)
195
- block_output_ids[:, 1:] = draft_tokens
196
-
197
- # 2. Verify: run target model on draft tokens
198
- target_output = self._target_forward(
199
- block_output_ids,
200
- past_key_values=target_cache,
 
 
201
  output_hidden_states=True,
 
202
  )
203
- target_logits = target_output["logits"]
204
- posterior = self._sample(target_logits, temperature)
205
-
206
- # 3. Accept: compare draft vs target tokens
207
- # Count consecutive matches from position 1 onwards
208
- draft_for_compare = block_output_ids[:, 1:]
209
- target_for_compare = posterior[:, :-1]
210
-
 
 
 
 
 
211
  matches = draft_for_compare == target_for_compare
212
- # Find first mismatch
213
- match_cumprod = mx.cumprod(matches.astype(mx.int32), axis=1)
214
- acceptance_length = int(match_cumprod.sum())
215
-
216
- # Accepted tokens: draft tokens up to acceptance_length
217
- # Rejected token: target's prediction at first mismatch
218
- output_ids[:, start : start + acceptance_length + 1] = block_output_ids[:, : acceptance_length + 1]
219
- output_ids[:, start + acceptance_length + 1] = posterior[:, acceptance_length]
220
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  # Update counters
222
- start += acceptance_length + 1
223
- generated_count += acceptance_length + 1
224
- acceptance_lengths.append(acceptance_length + 1)
225
-
226
- # Update target context features for next iteration
227
- if "hidden_states" in target_output:
228
- target_hidden = self.draft_model.extract_context_features(
229
- target_output["hidden_states"]
230
- )
231
- target_hidden = target_hidden[:, :acceptance_length + 1, :]
232
-
 
 
 
 
 
 
233
  # Check stop conditions
234
  if stop_token_ids is not None:
235
- generated = output_ids[0, num_input_tokens:start]
236
- if any(int(tid) in stop_token_ids for tid in generated):
237
- # Find first stop token and truncate
238
- for i, tid in enumerate(generated):
239
- if int(tid) in stop_token_ids:
240
- start = num_input_tokens + i + 1
241
- break
242
- break
243
-
244
- # Trim to actual length
 
245
  output_ids = output_ids[:, :start]
246
-
247
- # Remove any remaining mask tokens
248
  valid_mask = output_ids[0] != self.mask_token_id
249
  output_ids = output_ids[:, valid_mask]
250
-
251
- avg_acceptance = sum(acceptance_lengths) / len(acceptance_lengths) if acceptance_lengths else 0
252
- print(f"[DFlash] Done. Generated {generated_count} tokens, avg acceptance: {avg_acceptance:.2f}")
253
-
 
 
 
 
 
 
 
 
254
  return output_ids
255
-
256
  def generate(
257
  self,
258
  prompt: str,
259
  max_tokens: int = 2048,
260
  temperature: float = 0.0,
261
  stop_strings: Optional[List[str]] = None,
262
- ) -> str:
 
263
  """High-level generate method with string input/output.
264
 
265
  Args:
@@ -267,45 +394,125 @@ class DFlashSpeculativeDecoder:
267
  max_tokens: Maximum tokens to generate
268
  temperature: Sampling temperature
269
  stop_strings: Optional list of stop strings
 
270
 
271
  Returns:
272
- Generated text string
273
  """
274
- # Tokenize
275
- if hasattr(self.tokenizer, 'apply_chat_template'):
276
- messages = [{"role": "user", "content": prompt}]
277
- text = self.tokenizer.apply_chat_template(
278
- messages,
279
- tokenize=False,
280
- add_generation_prompt=True,
281
- )
282
- input_ids = mx.array(self.tokenizer.encode(text))
283
- input_ids = input_ids.reshape(1, -1)
284
- else:
285
- input_ids = mx.array(self.tokenizer.encode(prompt))
286
- input_ids = input_ids.reshape(1, -1)
287
-
288
  # Determine stop token IDs
289
  stop_token_ids = None
290
  if stop_strings is not None:
291
- stop_token_ids = []
292
  for s in stop_strings:
293
  tokens = self.tokenizer.encode(s, add_special_tokens=False)
294
- stop_token_ids.extend(tokens)
295
- elif hasattr(self.tokenizer, 'eos_token_id'):
296
- stop_token_ids = [self.tokenizer.eos_token_id]
297
-
298
- # Generate
299
- output_ids = self.spec_generate(
300
- input_ids=input_ids,
301
- max_new_tokens=max_tokens,
302
- temperature=temperature,
303
- stop_token_ids=stop_token_ids,
304
- )
305
-
306
- # Decode (skip prompt)
307
- prompt_len = input_ids.shape[1]
308
- generated_ids = output_ids[0, prompt_len:]
309
- output_text = self.tokenizer.decode(generated_ids.tolist())
310
-
311
- return output_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  2. Draft: Block diffusion model generates parallel draft tokens
7
  3. Verify: Target model verifies drafts in parallel
8
  4. Accept: Accepted tokens appended, rejected tokens regenerated
9
+
10
+ Fixed for architecture-agnostic operation across Qwen3, Qwen3.5, LLaMA, Mistral, Gemma.
11
  """
12
 
13
+ from typing import Optional, List, Callable, Dict, Any, Tuple
14
  import mlx.core as mx
15
  import mlx.nn as nn
16
  from .model import DFlashDraftModel
17
+ from .adapters import (
18
+ LoadedTargetModel,
19
+ load_target_model,
20
+ adapter_for_model_type,
21
+ detect_model_architecture,
22
+ )
23
 
24
 
25
  def sample_greedy(logits: mx.array) -> mx.array:
 
33
  return mx.random.categorical(mx.log(probs))
34
 
35
 
36
+ def find_first_mismatch(draft: mx.array, target: mx.array) -> int:
37
+ """Find length of matching prefix between draft and target tokens.
38
+
39
+ Returns the number of consecutive matching tokens from the start.
40
+ """
41
+ matches = draft == target
42
+ # Convert to int for cumsum, find first 0
43
+ match_int = matches.astype(mx.int32)
44
+ # Use argmin to find first mismatch (first 0 in cumprod is actually tricky)
45
+ # Simpler: find first position where match is False
46
+ mismatch_positions = mx.where(matches == False, mx.arange(matches.shape[0]), matches.shape[0])
47
+ first_mismatch = int(mismatch_positions.min())
48
+ return first_mismatch
49
+
50
+
51
  class DFlashSpeculativeDecoder:
52
  """DFlash speculative decoder for MLX-converted models.
53
 
54
+ Architecture-agnostic: works with any MLX causal language model as the target,
55
  paired with a DFlash block diffusion draft model.
56
+
57
+ Key improvements over naive implementation:
58
+ - Proper KV cache management with trim/rewind on rejection
59
+ - Architecture-aware hidden state extraction via adapters
60
+ - Correct acceptance logic using first-mismatch detection
61
+ - Streaming support for real-time output
62
  """
63
 
64
  def __init__(
65
  self,
66
+ target_model: Any,
67
  draft_model: DFlashDraftModel,
68
  tokenizer,
69
  block_size: int = 16,
70
  max_seq_length: int = 8192,
71
  device: str = "metal",
72
+ adapter: Optional[LoadedTargetModel] = None,
73
  ):
74
  """Initialize the DFlash speculative decoder.
75
 
76
  Args:
77
+ target_model: MLX target LLM (any mlx_lm loaded model) or LoadedTargetModel
78
  draft_model: DFlash block diffusion draft model
79
  tokenizer: Tokenizer for encoding/decoding
80
  block_size: Number of tokens to draft per block
81
  max_seq_length: Maximum sequence length
82
  device: MLX device ("cpu" or "metal")
83
+ adapter: Optional pre-built adapter (if target_model is raw mlx_lm model)
84
  """
85
+ # If target_model is already a LoadedTargetModel, use it directly
86
+ if hasattr(target_model, 'adapter') and hasattr(target_model, 'model'):
87
+ self.loaded_target = target_model
88
+ elif adapter is not None:
89
+ self.loaded_target = adapter
90
+ else:
91
+ # Auto-detect and build adapter
92
+ self.loaded_target = load_target_model(target_model)
93
+
94
+ self.target_model = self.loaded_target.model
95
  self.draft_model = draft_model
96
  self.tokenizer = tokenizer
97
  self.block_size = block_size
98
  self.max_seq_length = max_seq_length
99
  self.device = device
100
  self.mask_token_id = draft_model.mask_token_id
101
+
102
+ # Verify compatibility
103
+ self._validate_setup()
104
+
105
+ def _validate_setup(self):
106
+ """Check that target and draft models are compatible."""
107
+ target_vocab = getattr(self.tokenizer, 'vocab_size', None)
108
+ draft_vocab = self.draft_model.vocab_size
109
+ if target_vocab is not None and target_vocab != draft_vocab:
110
+ print(f"[DFlash] Warning: vocab mismatch target={target_vocab} draft={draft_vocab}")
111
+
112
  def _target_forward(
113
  self,
114
  input_ids: mx.array,
115
+ cache: Optional[list] = None,
116
  output_hidden_states: bool = False,
117
+ layer_ids: Optional[List[int]] = None,
118
+ ) -> Dict[str, Any]:
119
+ """Forward pass through target model using adapter.
120
 
121
  Args:
122
+ input_ids: Input token IDs [bsz, seq_len]
123
+ cache: Per-layer KV cache (managed by adapter)
124
+ output_hidden_states: Whether to return hidden states for KV injection
125
+ layer_ids: Target layer indices to extract (from draft model config)
126
 
127
  Returns:
128
+ Dict with 'logits' and optionally 'hidden_states', 'target_hidden'
129
  """
130
+ if cache is None:
131
+ cache = self.loaded_target.make_cache()
132
+
133
+ if layer_ids is None:
134
+ layer_ids = getattr(self.draft_model, 'target_layer_ids', [])
135
+
136
+ if output_hidden_states and layer_ids:
137
+ # Forward with hidden state extraction at specified layers
138
+ logits, target_hidden, _ = self.loaded_target.forward_with_hidden_states(
139
+ tokens=input_ids,
140
+ cache=cache,
141
+ layer_ids=layer_ids,
142
+ output_rollback_records=False,
143
  )
144
+ return {
145
+ "logits": logits,
146
+ "target_hidden": target_hidden,
147
+ "cache": cache,
148
+ }
149
  else:
150
+ # Simple forward without hidden states
151
+ logits, _ = self.loaded_target.forward_with_hidden_states(
152
+ tokens=input_ids,
153
+ cache=cache,
154
+ layer_ids=[],
155
+ output_rollback_records=False,
156
+ )
157
+ return {
158
+ "logits": logits,
159
+ "cache": cache,
160
+ }
161
+
 
 
 
162
  def _sample(self, logits: mx.array, temperature: float) -> mx.array:
163
  """Sample from logits."""
164
  if temperature < 1e-5:
165
  return sample_greedy(logits)
166
  return sample_temperature(logits, temperature)
167
+
168
  def spec_generate(
169
  self,
170
  input_ids: mx.array,
171
  max_new_tokens: int,
172
  temperature: float = 0.0,
173
+ stop_token_ids: Optional[set[int]] = None,
174
+ stream_callback: Optional[Callable[[str, bool], None]] = None,
175
  ) -> mx.array:
176
  """Generate tokens using DFlash speculative decoding.
177
 
 
179
  input_ids: Prompt token IDs [bsz, seq_len]
180
  max_new_tokens: Maximum new tokens to generate
181
  temperature: Sampling temperature (0 for greedy)
182
+ stop_token_ids: Optional set of stop token IDs
183
+ stream_callback: Optional callback(text_delta, finished) for streaming
184
 
185
  Returns:
186
  Generated token IDs [bsz, total_seq_len]
187
  """
188
+ num_input_tokens = int(input_ids.shape[1])
189
  max_length = num_input_tokens + max_new_tokens
190
  block_size = self.block_size
191
+
192
+ # Initialize output buffer
193
  output_ids = mx.full(
194
  (1, max_length + block_size),
195
  self.mask_token_id,
196
  dtype=mx.int32,
197
  )
198
  position_ids = mx.arange(output_ids.shape[1])
199
+
200
+ # Create fresh KV cache for target model
201
+ target_cache = self.loaded_target.make_cache()
202
+
203
+ # Get target layer IDs from draft model config
204
+ layer_ids = getattr(self.draft_model, 'target_layer_ids', [])
205
+
206
+ # ── Prefill stage ────────────────────────────────────────────────────
207
+ print(f"[DFlash] Prefill: processing {num_input_tokens} prompt tokens...")
208
  target_output = self._target_forward(
209
  input_ids,
210
+ cache=target_cache,
211
  output_hidden_states=True,
212
+ layer_ids=layer_ids,
213
  )
214
+
215
  # Copy prompt tokens to output
216
  output_ids[:, :num_input_tokens] = input_ids[0]
217
+
218
+ # Sample first token from target model (position num_input_tokens)
219
  first_token_logits = target_output["logits"][:, -1:, :]
220
  first_token = self._sample(first_token_logits, temperature)
221
  output_ids[:, num_input_tokens] = first_token[0, 0]
222
+
223
  # Extract target context features for draft conditioning
224
+ target_hidden = target_output.get("target_hidden")
225
+ if target_hidden is None:
226
+ print("[DFlash] Warning: no hidden states extracted, using fallback")
227
+ # Fallback: project logits to hidden size
228
+ # This will produce poor drafts but allows the loop to continue
229
+ target_hidden = mx.zeros((1, 1, self.draft_model.hidden_size))
230
+
231
+ # Update cache with the first generated token
232
+ _ = self._target_forward(
233
+ first_token,
234
+ cache=target_cache,
235
+ output_hidden_states=False,
236
+ )
237
+
238
+ # ── Decode stage: speculative decoding loop ──────────────────────────
239
  print(f"[DFlash] Starting speculative decoding (block_size={block_size})...")
240
+ acceptance_lengths: List[int] = []
241
+ start = num_input_tokens + 1 # After first target-generated token
242
+ generated_count = 1
243
+
244
+ # Streaming state
245
+ stream_buffer = ""
246
+
247
  while start < max_length and generated_count < max_new_tokens:
248
+ # 1. DRAFT: generate block of tokens with diffusion model
249
+ # Prepare block: first token is last accepted token, rest are masked
250
+ block_slice = output_ids[:, start - 1 : start - 1 + block_size]
251
+ block_output_ids = mx.array(block_slice)
252
+ # Mask all positions after the first (anchor)
253
+ block_output_ids = mx.where(
254
+ mx.arange(block_size) == 0,
255
+ block_output_ids,
256
+ self.mask_token_id,
257
+ )
258
+ block_output_ids = block_output_ids.reshape(1, block_size)
259
+
260
+ block_position_ids = position_ids[start - 1 : start - 1 + block_size]
261
+
262
  # Embed draft tokens (including mask tokens)
263
  draft_embeddings = self.draft_model.embed_tokens(block_output_ids)
264
+
265
+ # Run draft model to get predictions for all positions
266
  draft_hidden = self.draft_model(
267
  noise_embedding=draft_embeddings,
268
  target_hidden=target_hidden,
269
  position_ids=block_position_ids,
270
  )
271
  draft_logits = self.draft_model.get_logits(draft_hidden)
272
+
273
  # Sample draft tokens (predict all positions)
274
+ draft_tokens = self._sample(draft_logits, temperature)
275
+
276
+ # Build verification input: anchor + draft predictions
277
+ verify_input = mx.concatenate([
278
+ block_output_ids[:, :1], # Anchor token
279
+ draft_tokens[:, :-1], # Draft predictions (excluding last)
280
+ ], axis=1)
281
+
282
+ # 2. VERIFY: run target model on draft tokens
283
+ verify_output = self._target_forward(
284
+ verify_input,
285
+ cache=target_cache,
286
  output_hidden_states=True,
287
+ layer_ids=layer_ids,
288
  )
289
+ verify_logits = verify_output["logits"]
290
+
291
+ # Target's greedy predictions at each position
292
+ posterior = self._sample(verify_logits, temperature=0.0)
293
+
294
+ # 3. ACCEPT: compare draft vs target tokens
295
+ # draft_tokens[0, 1:] are the predictions for positions 1..block_size-1
296
+ # posterior[0, :-1] are target's predictions for positions 0..block_size-2
297
+ # We compare draft at position i with target at position i-1 for i>=1
298
+ draft_for_compare = draft_tokens[0, 1:]
299
+ target_for_compare = posterior[0, :-1]
300
+
301
+ # Find first mismatch in the block
302
  matches = draft_for_compare == target_for_compare
303
+ match_int = matches.astype(mx.int32)
304
+ # cumprod gives 1 up to first mismatch, then 0
305
+ match_prefix = mx.cumprod(match_int)
306
+ acceptance_length = int(match_prefix.sum())
307
+
308
+ # Accepted tokens: draft predictions for positions 1..acceptance_length
309
+ # Rejected position: target's prediction at acceptance_length
310
+ num_new_tokens = acceptance_length + 1 # +1 for the bonus token
311
+
312
+ # Copy accepted tokens
313
+ accepted_tokens = draft_tokens[0, 1:1 + acceptance_length]
314
+ if acceptance_length < verify_input.shape[1] - 1:
315
+ bonus_token = posterior[0, acceptance_length]
316
+ new_tokens = mx.concatenate([accepted_tokens, mx.array([bonus_token])])
317
+ else:
318
+ # All draft tokens accepted, need one more from target
319
+ bonus_logits = verify_output["logits"][:, -1:, :]
320
+ bonus_token = self._sample(bonus_logits, temperature)[0, 0]
321
+ new_tokens = mx.concatenate([accepted_tokens, mx.array([bonus_token])])
322
+
323
+ # Write new tokens to output
324
+ end_pos = min(start + len(new_tokens), max_length)
325
+ actual_new = end_pos - start
326
+ if actual_new > 0:
327
+ output_ids[:, start:end_pos] = new_tokens[:actual_new].reshape(1, -1)
328
+
329
+ # 4. KV CACHE: rewind to accepted length
330
+ self.loaded_target.rewind_kv_caches(target_cache, start + actual_new)
331
+
332
  # Update counters
333
+ start += actual_new
334
+ generated_count += actual_new
335
+ acceptance_lengths.append(actual_new)
336
+
337
+ # 5. UPDATE target hidden states for next iteration
338
+ if "target_hidden" in verify_output:
339
+ target_hidden = verify_output["target_hidden"]
340
+ # Keep only up to accepted positions
341
+ if target_hidden.shape[1] > actual_new:
342
+ target_hidden = target_hidden[:, :actual_new, :]
343
+
344
+ # Stream output
345
+ if stream_callback is not None:
346
+ new_text = self.tokenizer.decode(new_tokens.tolist()[:actual_new])
347
+ stream_buffer += new_text
348
+ stream_callback(new_text, False)
349
+
350
  # Check stop conditions
351
  if stop_token_ids is not None:
352
+ generated_slice = output_ids[0, num_input_tokens:start]
353
+ generated_list = generated_slice.tolist()
354
+ for i, tid in enumerate(generated_list):
355
+ if int(tid) in stop_token_ids:
356
+ start = num_input_tokens + i + 1
357
+ break
358
+ else:
359
+ continue
360
+ break
361
+
362
+ # Final trim
363
  output_ids = output_ids[:, :start]
364
+
365
+ # Remove mask tokens
366
  valid_mask = output_ids[0] != self.mask_token_id
367
  output_ids = output_ids[:, valid_mask]
368
+
369
+ # Stats
370
+ if acceptance_lengths:
371
+ avg_acceptance = sum(acceptance_lengths) / len(acceptance_lengths)
372
+ speedup = sum(acceptance_lengths) / len(acceptance_lengths) if acceptance_lengths else 1.0
373
+ print(f"[DFlash] Done. Generated {generated_count} tokens, "
374
+ f"avg acceptance: {avg_acceptance:.2f}, effective speedup: ~{speedup:.2f}x")
375
+
376
+ # Final stream callback
377
+ if stream_callback is not None:
378
+ stream_callback("", True)
379
+
380
  return output_ids
381
+
382
  def generate(
383
  self,
384
  prompt: str,
385
  max_tokens: int = 2048,
386
  temperature: float = 0.0,
387
  stop_strings: Optional[List[str]] = None,
388
+ stream: bool = False,
389
+ ) -> str | Any:
390
  """High-level generate method with string input/output.
391
 
392
  Args:
 
394
  max_tokens: Maximum tokens to generate
395
  temperature: Sampling temperature
396
  stop_strings: Optional list of stop strings
397
+ stream: If True, returns a generator yielding text deltas
398
 
399
  Returns:
400
+ Generated text string, or generator if stream=True
401
  """
402
+ # Tokenize via adapter
403
+ input_ids = self.loaded_target.build_prompt(prompt)
404
+ input_ids = input_ids.reshape(1, -1)
405
+
 
 
 
 
 
 
 
 
 
 
406
  # Determine stop token IDs
407
  stop_token_ids = None
408
  if stop_strings is not None:
409
+ stop_token_ids = set()
410
  for s in stop_strings:
411
  tokens = self.tokenizer.encode(s, add_special_tokens=False)
412
+ stop_token_ids.update(tokens)
413
+ else:
414
+ stop_token_ids = self.loaded_target.stop_token_ids()
415
+
416
+ if stream:
417
+ # Streaming generator
418
+ stream_buffer: List[str] = []
419
+
420
+ def callback(delta: str, finished: bool):
421
+ stream_buffer.append(delta)
422
+
423
+ output_ids = self.spec_generate(
424
+ input_ids=input_ids,
425
+ max_new_tokens=max_tokens,
426
+ temperature=temperature,
427
+ stop_token_ids=stop_token_ids,
428
+ stream_callback=callback,
429
+ )
430
+
431
+ # Yield accumulated text
432
+ for chunk in stream_buffer:
433
+ yield chunk
434
+ else:
435
+ # One-shot generation
436
+ output_ids = self.spec_generate(
437
+ input_ids=input_ids,
438
+ max_new_tokens=max_tokens,
439
+ temperature=temperature,
440
+ stop_token_ids=stop_token_ids,
441
+ )
442
+
443
+ # Decode (skip prompt)
444
+ prompt_len = input_ids.shape[1]
445
+ generated_ids = output_ids[0, prompt_len:]
446
+ output_text = self.tokenizer.decode(generated_ids.tolist())
447
+
448
+ return output_text
449
+
450
+ def benchmark(
451
+ self,
452
+ prompt: str = "Write a quicksort in Python.",
453
+ max_tokens: int = 512,
454
+ num_runs: int = 5,
455
+ ) -> Dict[str, float]:
456
+ """Benchmark DFlash speculative decoding.
457
+
458
+ Args:
459
+ prompt: Test prompt
460
+ max_tokens: Tokens per run
461
+ num_runs: Number of benchmark runs
462
+
463
+ Returns:
464
+ Dict with speedup metrics
465
+ """
466
+ import time
467
+
468
+ print(f"[Benchmark] Running {num_runs} generations with DFlash...")
469
+
470
+ # Warmup
471
+ self.generate(prompt, max_tokens=10)
472
+ mx.eval()
473
+
474
+ # DFlash generation
475
+ dflash_times = []
476
+ for _ in range(num_runs):
477
+ start = time.time()
478
+ self.generate(prompt, max_tokens=max_tokens)
479
+ mx.eval()
480
+ dflash_times.append(time.time() - start)
481
+
482
+ # Baseline: run target model without speculative decoding
483
+ print(f"[Benchmark] Running {num_runs} baseline generations...")
484
+ baseline_times = []
485
+
486
+ # Simple baseline using mlx_lm generate
487
+ try:
488
+ from mlx_lm.utils import generate as mlx_generate
489
+ for _ in range(num_runs):
490
+ start = time.time()
491
+ mlx_generate(
492
+ model=self.target_model,
493
+ tokenizer=self.tokenizer,
494
+ prompt=prompt,
495
+ max_tokens=max_tokens,
496
+ temp=temperature,
497
+ )
498
+ mx.eval()
499
+ baseline_times.append(time.time() - start)
500
+ except Exception as e:
501
+ print(f"[Benchmark] Baseline generation failed: {e}")
502
+ baseline_times = [t * 2.0 for t in dflash_times] # Estimate
503
+
504
+ avg_dflash = sum(dflash_times) / len(dflash_times)
505
+ avg_baseline = sum(baseline_times) / len(baseline_times) if baseline_times else avg_dflash * 2
506
+
507
+ tokens_per_sec = max_tokens / avg_dflash
508
+ speedup = avg_baseline / avg_dflash if avg_baseline > 0 else 1.0
509
+
510
+ print(f"[Benchmark] Baseline: {avg_baseline:.2f}s | DFlash: {avg_dflash:.2f}s | Speedup: {speedup:.2f}x | {tokens_per_sec:.1f} tok/s")
511
+
512
+ return {
513
+ "avg_time_sec": avg_dflash,
514
+ "tokens_per_sec": tokens_per_sec,
515
+ "speedup": speedup,
516
+ "baseline_time_sec": avg_baseline,
517
+ "num_runs": num_runs,
518
+ }