| """ |
| Universal DFlash decoder for any MLX-converted model. |
| |
| Provides a high-level interface that works with any mlx_lm model, |
| including those without pre-built DFlash drafters. |
| |
Target-model interaction goes through an architecture-agnostic adapter system,
covering all supported families (Qwen3, Qwen3.5, LLaMA, Mistral, Gemma).
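
Example (a minimal sketch; the model ID is illustrative):

    from mlx_lm import load

    model, tokenizer = load("mlx-community/Qwen3-4B-4bit")
    decoder = UniversalDFlashDecoder(model, tokenizer)
    print(decoder.generate("Explain speculative decoding.", max_tokens=128))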
| """ |
|
|
from typing import Optional, List, Dict, Any, Iterator
| import mlx.core as mx |
| from .model import DFlashDraftModel |
| from .speculative_decode import DFlashSpeculativeDecoder |
| from .adapters import load_target_model, LoadedTargetModel, detect_model_architecture |
| from .convert import load_mlx_dflash |
|
|
|
|
| def _build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]: |
| """Select target model layer indices for feature extraction. |
| |
| Uniformly samples from shallow to deep layers for cross-layer |
| feature fusion, matching the DFlash paper. |
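
    Example::

        >>> _build_target_layer_ids(32, 5)
        [1, 8, 15, 22, 29]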
| """ |
| if num_draft_layers == 1: |
| return [num_target_layers // 2] |
    start = 1
    # Clamp so very shallow targets still produce a valid, non-negative span.
    end = max(num_target_layers - 3, start)
    span = end - start
| return [ |
| int(round(start + (i * span) / (num_draft_layers - 1))) |
| for i in range(num_draft_layers) |
| ] |
|
|
|
|
| class UniversalDFlashDecoder: |
| """Universal DFlash decoder that works with any MLX-converted model. |
| |
| This class handles: |
| 1. Loading pre-converted DFlash drafters with architecture detection |
| 2. Creating generic drafters for unsupported models |
| 3. Training custom drafters on-the-fly |
| |
    The target model's architecture is detected automatically, and the matching
    adapter is selected for hidden-state extraction and KV-cache management.
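
    Example (sketch; the drafter path is illustrative):

        decoder = UniversalDFlashDecoder(
            model, tokenizer, draft_model_path="./qwen3_drafter"
        )
        text = decoder.generate("Write a haiku.", max_tokens=64)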
| """ |
|
|
| def __init__( |
| self, |
| target_model: Any, |
| tokenizer, |
| draft_model_path: Optional[str] = None, |
| draft_layers: int = 5, |
| draft_hidden_size: int = 1024, |
| block_size: int = 16, |
| device: str = "metal", |
| ): |
| """Initialize the universal decoder. |
| |
| Args: |
| target_model: Any mlx_lm loaded model, or path/ID to load |
| tokenizer: Tokenizer for the model |
| draft_model_path: Optional path to pre-converted DFlash drafter |
| draft_layers: Number of draft layers (if creating generic drafter) |
| draft_hidden_size: Hidden size for generic drafter |
| block_size: Number of tokens per draft block |
| device: MLX device |
| """ |
| self.tokenizer = tokenizer |
| self.block_size = block_size |
| self.device = device |

        # Resolve the target model: a path/repo ID, an adapter-wrapped
        # LoadedTargetModel, or a raw mlx_lm model.
| if isinstance(target_model, str): |
| print(f"[UniversalDFlash] Loading target model: {target_model}...") |
| self.loaded_target = load_target_model(target_model) |
| self.target_model = self.loaded_target.model |
| elif hasattr(target_model, 'adapter'): |
            # Already wrapped with an adapter; use it as-is.
| self.loaded_target = target_model |
| self.target_model = target_model.model |
| else: |
            # Raw mlx_lm model: detect the architecture and build an adapter.
| print("[UniversalDFlash] Detecting model architecture...") |
| self.target_model = target_model |
| |
| arch = detect_model_architecture(target_model) |
| print(f"[UniversalDFlash] Detected architecture: {arch}") |
| |
| from .adapters import MLXTargetAdapter, adapter_for_model_type |
| adapter_cls = adapter_for_model_type(arch) |
| if adapter_cls is None: |
| adapter_cls = MLXTargetAdapter |
| adapter = adapter_cls(model=target_model, config={"model_type": arch}) |
| self.loaded_target = LoadedTargetModel( |
| requested_model="unknown", |
| resolved_model_path=None, |
| model=target_model, |
| tokenizer=tokenizer, |
| adapter=adapter, |
| ) |

        # Cache basic target-model facts; 151936 is the Qwen-family vocab fallback.
| self.vocab_size = getattr(tokenizer, "vocab_size", 151936) |
| self.target_config = self._extract_target_config(self.target_model) |
|
        # Load a pre-built drafter if given; otherwise create an untrained one.
| if draft_model_path: |
| print(f"[UniversalDFlash] Loading pre-built drafter from {draft_model_path}...") |
| self.draft_model, self.draft_config = load_mlx_dflash(draft_model_path) |
| else: |
| print("[UniversalDFlash] Creating generic drafter for your model...") |
| self.draft_model = self._create_generic_drafter( |
| draft_layers=draft_layers, |
| draft_hidden_size=draft_hidden_size, |
| ) |
| self.draft_config = None |
|
        # Wire the target, drafter, and tokenizer into the speculative decoder.
| self.decoder = DFlashSpeculativeDecoder( |
| target_model=self.loaded_target, |
| draft_model=self.draft_model, |
| tokenizer=tokenizer, |
| block_size=block_size, |
| device=device, |
| ) |
|
|
| def _extract_target_config(self, target_model) -> Dict[str, Any]: |
| """Extract configuration from target model.""" |
| config = {} |

        # Prefer values from the model's own config when present.
| if hasattr(target_model, 'config'): |
| model_config = target_model.config |
| config['hidden_size'] = getattr(model_config, 'hidden_size', 4096) |
| config['num_layers'] = getattr(model_config, 'num_hidden_layers', 32) |
| config['vocab_size'] = getattr(model_config, 'vocab_size', 151936) |
| config['intermediate_size'] = getattr(model_config, 'intermediate_size', 14336) |
| config['num_attention_heads'] = getattr(model_config, 'num_attention_heads', 32) |
| config['num_key_value_heads'] = getattr(model_config, 'num_key_value_heads', 8) |
| config['model_type'] = getattr(model_config, 'model_type', 'unknown') |
| else: |
            # No config attribute: fall back to common transformer defaults.
| config = { |
| 'hidden_size': 4096, |
| 'num_layers': 32, |
| 'vocab_size': 151936, |
| 'intermediate_size': 14336, |
| 'num_attention_heads': 32, |
| 'num_key_value_heads': 8, |
| 'model_type': 'unknown', |
| } |
|
|
| return config |
|
|
| def _create_generic_drafter( |
| self, |
| draft_layers: int, |
| draft_hidden_size: int, |
| ) -> DFlashDraftModel: |
| """Create a generic DFlash drafter compatible with the target model. |
| |
| This creates an untrained drafter that can be trained or used |
| with pre-trained weights from a similar architecture. |
| |
| The draft model is sized proportionally to the target model's |
| hidden dimension for feature compatibility. |
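
        Example: with the defaults (draft_hidden_size=1024, draft_layers=5)
        against a 32-layer target, this yields 16 heads, 4 KV heads, an
        intermediate size of 2816, and target_layer_ids [1, 8, 15, 22, 29].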
| """ |
| |
| hidden_size = self.target_config.get('hidden_size', 4096) |
| vocab_size = self.target_config.get('vocab_size', 151936) |
| num_layers = self.target_config.get('num_layers', 32) |

        # Sizing heuristics: 64-dim heads, 4:1 GQA ratio, ~2.75x MLP expansion.
| num_heads = draft_hidden_size // 64 |
| num_kv_heads = max(1, num_heads // 4) |
| intermediate_size = int(draft_hidden_size * 2.75) |

        # Choose which target layers feed the drafter's cross-layer fusion.
| target_layer_ids = _build_target_layer_ids(num_layers, draft_layers) |
|
|
| drafter = DFlashDraftModel( |
| vocab_size=vocab_size, |
| hidden_size=draft_hidden_size, |
| num_layers=draft_layers, |
| num_heads=num_heads, |
| num_kv_heads=num_kv_heads, |
| intermediate_size=intermediate_size, |
| max_seq_len=8192, |
| block_size=self.block_size, |
| mask_token_id=0, |
| num_target_layers=num_layers, |
| target_layer_ids=target_layer_ids, |
| ) |
|
|
| return drafter |
|
|
| def train_drafter( |
| self, |
| dataset: str, |
| max_seq_length: int = 3072, |
| epochs: int = 6, |
| batch_size: int = 32, |
| lr: float = 6e-4, |
| warmup_ratio: float = 0.04, |
| grad_clip: float = 1.0, |
| output_path: Optional[str] = None, |
| ) -> str: |
| """Train a custom DFlash drafter for your target model. |
| |
| Uses the training recipe from the DFlash paper: |
| - KV injection with target model features |
| - Random anchor sampling for block construction |
| - Sparse attention masking within blocks |
| - Position-dependent loss decay |
| |
| Args: |
| dataset: Path to training dataset or HF dataset name |
| max_seq_length: Maximum sequence length for training |
| epochs: Number of training epochs (paper: 6) |
| batch_size: Training batch size |
| lr: Learning rate (paper: 6e-4) |
| warmup_ratio: Warmup ratio for cosine schedule (paper: 0.04) |
| grad_clip: Gradient clipping threshold (paper: 1.0) |
| output_path: Where to save the trained drafter |
| |
| Returns: |
| Path to saved drafter |
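
        Example (sketch; the dataset path is illustrative):
            path = decoder.train_drafter(
                "data/sharegpt.jsonl", output_path="./my_drafter"
            )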
| """ |
| from .trainer import DFlashTrainer |
|
|
| print(f"[UniversalDFlash] Training custom drafter...") |
| print(f" Dataset: {dataset}") |
| print(f" Epochs: {epochs}, Batch size: {batch_size}, LR: {lr}") |
| |
| trainer = DFlashTrainer( |
| target_model=self.target_model, |
| drafter=self.draft_model, |
| tokenizer=self.tokenizer, |
| ) |
|
|
| trained_model = trainer.train( |
| dataset=dataset, |
| max_seq_length=max_seq_length, |
| epochs=epochs, |
| batch_size=batch_size, |
| lr=lr, |
| warmup_ratio=warmup_ratio, |
| grad_clip=grad_clip, |
| ) |
|
        # Swap the freshly trained drafter into the live decoder.
| self.draft_model = trained_model |
| self.decoder.draft_model = trained_model |

        # Persist the trained drafter, defaulting to a local path, so the
        # returned path always points at saved weights.
        save_path = output_path or "./trained_dflash_drafter"
        self.save_drafter(save_path)
        return save_path
|
|
| def save_drafter(self, path: str): |
| """Save the current drafter model.""" |
| import json |
| from pathlib import Path |
| import numpy as np |
|
|
| path = Path(path) |
| path.mkdir(parents=True, exist_ok=True) |
|
|

        # Flatten the nested MLX parameter tree into flat dotted keys so the
        # arrays can be passed to savez.
        from mlx.utils import tree_flatten
        weights = dict(tree_flatten(self.draft_model.parameters()))

        # Prefer NumPy serialization; fall back to mx.savez on failure.
        try:
            np_weights = {k: np.array(v) for k, v in weights.items()}
            np.savez(str(path / "weights.npz"), **np_weights)
        except Exception:
            try:
                mx.savez(str(path / "weights.npz"), **weights)
            except Exception as e:
                print(f"[Save] Error saving weights: {e}")
                raise
|
        # Write a config.json describing the drafter architecture.
| config = { |
| "vocab_size": self.draft_model.vocab_size, |
| "hidden_size": self.draft_model.hidden_size, |
| "num_hidden_layers": self.draft_model.num_layers, |
| "num_attention_heads": self.draft_model.num_heads, |
| "num_key_value_heads": self.draft_model.num_heads // 4, |
| "intermediate_size": self.draft_model.layers[0].mlp.gate_proj.weight.shape[1] |
| if hasattr(self.draft_model.layers[0].mlp.gate_proj, 'weight') else 2816, |
| "max_position_embeddings": self.draft_model.max_seq_len, |
| "block_size": self.draft_model.block_size, |
| "target_layer_ids": self.draft_model.target_layer_ids, |
| } |
|
|
| with open(path / "config.json", "w") as f: |
| json.dump(config, f, indent=2) |
|
|
| print(f"[UniversalDFlash] Drafter saved to {path}") |
|
|
| def generate( |
| self, |
| prompt: str, |
| max_tokens: int = 2048, |
| temperature: float = 0.0, |
| stop_strings: Optional[List[str]] = None, |
| stream: bool = False, |
    ) -> str | Iterator[str]:
| """Generate text using DFlash speculative decoding. |
| |
| Args: |
| prompt: Text prompt |
| max_tokens: Maximum tokens to generate |
| temperature: Sampling temperature |
| stop_strings: Optional stop strings |
| stream: If True, returns a generator yielding text deltas |
| |
| Returns: |
| Generated text string, or generator if stream=True |
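
        Example (streaming sketch):
            for delta in decoder.generate("Hi there", stream=True):
                print(delta, end="", flush=True)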
| """ |
| return self.decoder.generate( |
| prompt=prompt, |
| max_tokens=max_tokens, |
| temperature=temperature, |
| stop_strings=stop_strings, |
| stream=stream, |
| ) |
|
|
| def benchmark( |
| self, |
| prompt: str = "Write a quicksort in Python.", |
| max_tokens: int = 512, |
| num_runs: int = 5, |
| ) -> Dict[str, float]: |
| """Benchmark DFlash speculative decoding. |
| |
| Args: |
| prompt: Test prompt |
| max_tokens: Tokens per run |
| num_runs: Number of benchmark runs |
| |
| Returns: |
| Dict with speedup metrics |
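
        Example:
            stats = decoder.benchmark(max_tokens=256, num_runs=3)
            print(stats)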
| """ |
| return self.decoder.benchmark( |
| prompt=prompt, |
| max_tokens=max_tokens, |
| num_runs=num_runs, |
| ) |
|
|