tritesh committed on
Commit
81dfca1
·
verified ·
1 Parent(s): 20edc82

Upload dflash_mlx/universal.py

Browse files
Files changed (1) hide show
  1. dflash_mlx/universal.py +19 -3
dflash_mlx/universal.py CHANGED
@@ -16,6 +16,23 @@ from .adapters import load_target_model, LoadedTargetModel, detect_model_archite
16
  from .convert import load_mlx_dflash
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class UniversalDFlashDecoder:
20
  """Universal DFlash decoder that works with any MLX-converted model.
21
 
@@ -161,9 +178,7 @@ class UniversalDFlashDecoder:
161
  intermediate_size = int(draft_hidden_size * 2.75) # Standard SwiGLU ratio
162
 
163
  # Target layer ids for feature extraction
164
- target_layer_ids = DFlashDraftModel._build_target_layer_ids(
165
- None, num_layers, draft_layers
166
- )
167
 
168
  drafter = DFlashDraftModel(
169
  vocab_size=vocab_size,
@@ -248,6 +263,7 @@ class UniversalDFlashDecoder:
248
  """Save the current drafter model."""
249
  import json
250
  from pathlib import Path
 
251
 
252
  path = Path(path)
253
  path.mkdir(parents=True, exist_ok=True)
 
16
  from .convert import load_mlx_dflash
17
 
18
 
19
+ def _build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]:
20
+ """Select target model layer indices for feature extraction.
21
+
22
+ Uniformly samples from shallow to deep layers for cross-layer
23
+ feature fusion, matching the DFlash paper.
24
+ """
25
+ if num_draft_layers == 1:
26
+ return [num_target_layers // 2]
27
+ start = 1
28
+ end = num_target_layers - 3
29
+ span = end - start
30
+ return [
31
+ int(round(start + (i * span) / (num_draft_layers - 1)))
32
+ for i in range(num_draft_layers)
33
+ ]
34
+
35
+
36
  class UniversalDFlashDecoder:
37
  """Universal DFlash decoder that works with any MLX-converted model.
38
 
 
178
  intermediate_size = int(draft_hidden_size * 2.75) # Standard SwiGLU ratio
179
 
180
  # Target layer ids for feature extraction
181
+ target_layer_ids = _build_target_layer_ids(num_layers, draft_layers)
 
 
182
 
183
  drafter = DFlashDraftModel(
184
  vocab_size=vocab_size,
 
263
  """Save the current drafter model."""
264
  import json
265
  from pathlib import Path
266
+ import numpy as np
267
 
268
  path = Path(path)
269
  path.mkdir(parents=True, exist_ok=True)