| """ |
| Universal DFlash decoder for any MLX-converted model. |
| |
| Provides a high-level interface that works with any mlx_lm model, |
| including those without pre-built DFlash drafters. |
| """ |
|
|
| from typing import Optional, List, Dict, Any |
| import mlx.core as mx |
| from .model import DFlashDraftModel |
| from .speculative_decode import DFlashSpeculativeDecoder |
|
|
|
|


class UniversalDFlashDecoder:
    """Universal DFlash decoder that works with any MLX-converted model.

    This class handles:
    1. Loading pre-converted DFlash drafters
    2. Creating generic drafters for unsupported models
    3. Training custom drafters on the fly
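
    Example (the model path is illustrative; assumes a model/tokenizer pair
    loaded with mlx_lm):

        >>> from mlx_lm import load
        >>> model, tokenizer = load("path/to/mlx-model")
        >>> decoder = UniversalDFlashDecoder(model, tokenizer)
        >>> print(decoder.generate("Explain speculative decoding.", max_tokens=128))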
| """ |

    def __init__(
        self,
        target_model,
        tokenizer,
        draft_model_path: Optional[str] = None,
        draft_layers: int = 5,
        draft_hidden_size: int = 1024,
        block_size: int = 16,
        device: str = "metal",
    ):
| """Initialize the universal decoder. |
| |
| Args: |
| target_model: Any mlx_lm loaded model |
| tokenizer: Tokenizer for the model |
| draft_model_path: Optional path to pre-converted DFlash drafter |
| draft_layers: Number of draft layers (if creating generic drafter) |
| draft_hidden_size: Hidden size for generic drafter |
| block_size: Number of tokens per draft block |
| device: MLX device |
| """ |
        self.target_model = target_model
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.device = device

        # Fall back to a Qwen2-style vocabulary size if the tokenizer does not expose one.
        self.vocab_size = getattr(tokenizer, "vocab_size", 151936)
        self.target_config = self._extract_target_config(target_model)

        if draft_model_path:
            print(f"[UniversalDFlash] Loading pre-built drafter from {draft_model_path}")
            from .convert import load_mlx_dflash

            self.draft_model, self.draft_config = load_mlx_dflash(draft_model_path)
        else:
            print("[UniversalDFlash] Creating generic drafter for your model...")
            self.draft_model = self._create_generic_drafter(
                draft_layers=draft_layers,
                draft_hidden_size=draft_hidden_size,
            )
            self.draft_config = None

        self.decoder = DFlashSpeculativeDecoder(
            target_model=target_model,
            draft_model=self.draft_model,
            tokenizer=tokenizer,
            block_size=block_size,
            device=device,
        )

    def _extract_target_config(self, target_model) -> Dict[str, Any]:
        """Extract architecture configuration from the target model."""
        # mlx_lm models usually expose their configuration as `args`; fall back to
        # `config` for wrappers that use the Hugging Face attribute name.
        model_config = getattr(target_model, "config", None) or getattr(target_model, "args", None)

        if model_config is not None:
            config = {
                "hidden_size": getattr(model_config, "hidden_size", 4096),
                "num_layers": getattr(model_config, "num_hidden_layers", 32),
                "vocab_size": getattr(model_config, "vocab_size", 151936),
                "intermediate_size": getattr(model_config, "intermediate_size", 14336),
                "num_attention_heads": getattr(model_config, "num_attention_heads", 32),
                "num_key_value_heads": getattr(model_config, "num_key_value_heads", 8),
            }
        else:
            # No configuration attached to the model; assume typical 7B/8B-class defaults.
            config = {
                "hidden_size": 4096,
                "num_layers": 32,
                "vocab_size": 151936,
                "intermediate_size": 14336,
                "num_attention_heads": 32,
                "num_key_value_heads": 8,
            }

        return config

    def _create_generic_drafter(
        self,
        draft_layers: int,
        draft_hidden_size: int,
    ) -> DFlashDraftModel:
        """Create a generic DFlash drafter compatible with the target model.

        The returned drafter is untrained; it can be trained with
        train_drafter() or loaded with pre-trained weights from a similar
        architecture.
        """
        vocab_size = self.target_config.get("vocab_size", 151936)

        # Derive the attention shape from the hidden size: 64-dim heads,
        # 4x grouped-query attention, and a ~2.75x MLP expansion.
        num_heads = draft_hidden_size // 64
        num_kv_heads = max(1, num_heads // 4)
        intermediate_size = int(draft_hidden_size * 2.75)

        drafter = DFlashDraftModel(
            vocab_size=vocab_size,
            hidden_size=draft_hidden_size,
            num_layers=draft_layers,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            intermediate_size=intermediate_size,
            max_seq_len=8192,
            block_size=self.block_size,
            mask_token_id=0,
            num_target_layers=self.target_config.get("num_layers", 32),
        )

        return drafter

    def train_drafter(
        self,
        dataset: str,
        max_seq_length: int = 3072,
        epochs: int = 6,
        batch_size: int = 32,
        lr: float = 6e-4,
        warmup_ratio: float = 0.04,
        grad_clip: float = 1.0,
        output_path: Optional[str] = None,
    ) -> str:
        """Train a custom DFlash drafter for the target model.

        Args:
            dataset: Path to a training dataset or a Hugging Face dataset name
            max_seq_length: Maximum sequence length for training
            epochs: Number of training epochs
            batch_size: Training batch size
            lr: Learning rate
            warmup_ratio: Warmup ratio for the cosine schedule
            grad_clip: Gradient clipping threshold
            output_path: Where to save the trained drafter

        Returns:
            Path to the saved drafter
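
        Example (dataset path and output directory are illustrative):

            >>> decoder.train_drafter("data/train.jsonl", output_path="./my_drafter")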
| """ |
        from .trainer import DFlashTrainer

        print("[UniversalDFlash] Training custom drafter...")
        trainer = DFlashTrainer(
            target_model=self.target_model,
            drafter=self.draft_model,
            tokenizer=self.tokenizer,
        )

        trained_model = trainer.train(
            dataset=dataset,
            max_seq_length=max_seq_length,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            warmup_ratio=warmup_ratio,
            grad_clip=grad_clip,
        )

        # Swap the freshly trained drafter into the decoder.
        self.draft_model = trained_model
        self.decoder.draft_model = trained_model

        # Save to the provided path, or to a default location, so the returned
        # path always points at saved weights.
        output_path = output_path or "./trained_dflash_drafter"
        self.save_drafter(output_path)

        return output_path

    def save_drafter(self, path: str):
        """Save the current drafter weights and configuration to `path`."""
        import json
        from pathlib import Path

        from mlx.utils import tree_flatten

        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Flatten the nested parameter tree into a flat {name: array} dict,
        # as required by mx.save_safetensors.
        weights = dict(tree_flatten(self.draft_model.parameters()))
        mx.save_safetensors(str(path / "weights.safetensors"), weights)

        # mlx.nn.Linear stores its weight as (output_dims, input_dims), so the
        # first dimension of gate_proj gives the MLP intermediate size.
        gate_proj = self.draft_model.layers[0].mlp.gate_proj
        intermediate_size = gate_proj.weight.shape[0] if hasattr(gate_proj, "weight") else 2816

        config = {
            "vocab_size": self.draft_model.vocab_size,
            "hidden_size": self.draft_model.hidden_size,
            "num_hidden_layers": self.draft_model.num_layers,
            "num_attention_heads": self.draft_model.num_heads,
            "num_key_value_heads": getattr(
                self.draft_model, "num_kv_heads", max(1, self.draft_model.num_heads // 4)
            ),
            "intermediate_size": intermediate_size,
            "max_position_embeddings": self.draft_model.max_seq_len,
            "block_size": self.draft_model.block_size,
        }

        with open(path / "config.json", "w") as f:
            json.dump(config, f, indent=2)

        print(f"[UniversalDFlash] Drafter saved to {path}")

    def generate(
        self,
        prompt: str,
        max_tokens: int = 2048,
        temperature: float = 0.0,
        stop_strings: Optional[List[str]] = None,
    ) -> str:
        """Generate text using DFlash speculative decoding.

        Args:
            prompt: Text prompt
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (0.0 for greedy decoding)
            stop_strings: Optional list of strings that stop generation

        Returns:
            The generated text
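
        Example (prompt is illustrative):

            >>> decoder.generate("Write a haiku about the sea.", max_tokens=64)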
| """ |
        return self.decoder.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            stop_strings=stop_strings,
        )

    def benchmark(
        self,
        prompt: str = "Write a quicksort in Python.",
        max_tokens: int = 512,
        num_runs: int = 5,
    ) -> Dict[str, float]:
        """Benchmark DFlash speculative decoding throughput.

        Args:
            prompt: Test prompt
            max_tokens: Tokens generated per run
            num_runs: Number of timed runs

        Returns:
            Dict with timing and throughput metrics
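
        Example:

            >>> stats = decoder.benchmark(max_tokens=256, num_runs=3)
            >>> print(f"{stats['tokens_per_sec']:.1f} tok/s")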
| """ |
        import time

        print(f"[Benchmark] Running {num_runs} generations...")

        # Warm-up run so one-time setup costs are not included in the timings.
        self.generate(prompt, max_tokens=10)

        dflash_times = []
        for _ in range(num_runs):
            start = time.time()
            self.generate(prompt, max_tokens=max_tokens)
            dflash_times.append(time.time() - start)

        # Throughput assumes each run generates the full max_tokens
        # (i.e. generation is not cut short by a stop string or EOS).
        avg_time = sum(dflash_times) / len(dflash_times)
        tokens_per_sec = max_tokens / avg_time

        print(f"[Benchmark] Avg time: {avg_time:.2f}s, Speed: {tokens_per_sec:.1f} tok/s")

        return {
            "avg_time_sec": avg_time,
            "tokens_per_sec": tokens_per_sec,
            "num_runs": num_runs,
        }
|