Upload dflash_mlx/universal.py
Browse files- dflash_mlx/universal.py +19 -3
dflash_mlx/universal.py
CHANGED
|
@@ -16,6 +16,23 @@ from .adapters import load_target_model, LoadedTargetModel, detect_model_archite
|
|
| 16 |
from .convert import load_mlx_dflash
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
class UniversalDFlashDecoder:
|
| 20 |
"""Universal DFlash decoder that works with any MLX-converted model.
|
| 21 |
|
|
@@ -161,9 +178,7 @@ class UniversalDFlashDecoder:
|
|
| 161 |
intermediate_size = int(draft_hidden_size * 2.75) # Standard SwiGLU ratio
|
| 162 |
|
| 163 |
# Target layer ids for feature extraction
|
| 164 |
-
target_layer_ids =
|
| 165 |
-
None, num_layers, draft_layers
|
| 166 |
-
)
|
| 167 |
|
| 168 |
drafter = DFlashDraftModel(
|
| 169 |
vocab_size=vocab_size,
|
|
@@ -248,6 +263,7 @@ class UniversalDFlashDecoder:
|
|
| 248 |
"""Save the current drafter model."""
|
| 249 |
import json
|
| 250 |
from pathlib import Path
|
|
|
|
| 251 |
|
| 252 |
path = Path(path)
|
| 253 |
path.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 16 |
from .convert import load_mlx_dflash
|
| 17 |
|
| 18 |
|
| 19 |
+
def _build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]:
|
| 20 |
+
"""Select target model layer indices for feature extraction.
|
| 21 |
+
|
| 22 |
+
Uniformly samples from shallow to deep layers for cross-layer
|
| 23 |
+
feature fusion, matching the DFlash paper.
|
| 24 |
+
"""
|
| 25 |
+
if num_draft_layers == 1:
|
| 26 |
+
return [num_target_layers // 2]
|
| 27 |
+
start = 1
|
| 28 |
+
end = num_target_layers - 3
|
| 29 |
+
span = end - start
|
| 30 |
+
return [
|
| 31 |
+
int(round(start + (i * span) / (num_draft_layers - 1)))
|
| 32 |
+
for i in range(num_draft_layers)
|
| 33 |
+
]
|
| 34 |
+
|
| 35 |
+
|
| 36 |
class UniversalDFlashDecoder:
|
| 37 |
"""Universal DFlash decoder that works with any MLX-converted model.
|
| 38 |
|
|
|
|
| 178 |
intermediate_size = int(draft_hidden_size * 2.75) # Standard SwiGLU ratio
|
| 179 |
|
| 180 |
# Target layer ids for feature extraction
|
| 181 |
+
target_layer_ids = _build_target_layer_ids(num_layers, draft_layers)
|
|
|
|
|
|
|
| 182 |
|
| 183 |
drafter = DFlashDraftModel(
|
| 184 |
vocab_size=vocab_size,
|
|
|
|
| 263 |
"""Save the current drafter model."""
|
| 264 |
import json
|
| 265 |
from pathlib import Path
|
| 266 |
+
import numpy as np
|
| 267 |
|
| 268 |
path = Path(path)
|
| 269 |
path.mkdir(parents=True, exist_ok=True)
|