""" Universal DFlash decoder for any MLX-converted model. Provides a high-level interface that works with any mlx_lm model, including those without pre-built DFlash drafters. """ from typing import Optional, List, Dict, Any import mlx.core as mx from .model import DFlashDraftModel from .speculative_decode import DFlashSpeculativeDecoder class UniversalDFlashDecoder: """Universal DFlash decoder that works with any MLX-converted model. This class handles: 1. Loading pre-converted DFlash drafters 2. Creating generic drafters for unsupported models 3. Training custom drafters on-the-fly """ def __init__( self, target_model, tokenizer, draft_model_path: Optional[str] = None, draft_layers: int = 5, draft_hidden_size: int = 1024, block_size: int = 16, device: str = "metal", ): """Initialize the universal decoder. Args: target_model: Any mlx_lm loaded model tokenizer: Tokenizer for the model draft_model_path: Optional path to pre-converted DFlash drafter draft_layers: Number of draft layers (if creating generic drafter) draft_hidden_size: Hidden size for generic drafter block_size: Number of tokens per draft block device: MLX device """ self.target_model = target_model self.tokenizer = tokenizer self.block_size = block_size self.device = device # Determine model type and vocab size self.vocab_size = getattr(tokenizer, "vocab_size", 151936) self.target_config = self._extract_target_config(target_model) # Load or create draft model if draft_model_path: print(f"[UniversalDFlash] Loading pre-built drafter from {draft_model_path}") from .convert import load_mlx_dflash self.draft_model, self.draft_config = load_mlx_dflash(draft_model_path) else: print("[UniversalDFlash] Creating generic drafter for your model...") self.draft_model = self._create_generic_drafter( draft_layers=draft_layers, draft_hidden_size=draft_hidden_size, ) self.draft_config = None # Create the speculative decoder self.decoder = DFlashSpeculativeDecoder( target_model=target_model, draft_model=self.draft_model, tokenizer=tokenizer, block_size=block_size, device=device, ) def _extract_target_config(self, target_model) -> Dict[str, Any]: """Extract configuration from target model.""" config = {} # Try to extract from model attributes if hasattr(target_model, 'config'): model_config = target_model.config config['hidden_size'] = getattr(model_config, 'hidden_size', 4096) config['num_layers'] = getattr(model_config, 'num_hidden_layers', 32) config['vocab_size'] = getattr(model_config, 'vocab_size', 151936) config['intermediate_size'] = getattr(model_config, 'intermediate_size', 14336) config['num_attention_heads'] = getattr(model_config, 'num_attention_heads', 32) config['num_key_value_heads'] = getattr(model_config, 'num_key_value_heads', 8) else: # Default Qwen3-4B-like config config = { 'hidden_size': 4096, 'num_layers': 32, 'vocab_size': 151936, 'intermediate_size': 14336, 'num_attention_heads': 32, 'num_key_value_heads': 8, } return config def _create_generic_drafter( self, draft_layers: int, draft_hidden_size: int, ) -> DFlashDraftModel: """Create a generic DFlash drafter compatible with the target model. This creates an untrained drafter that can be trained or used with pre-trained weights from a similar architecture. 
""" # Determine architecture compatibility hidden_size = self.target_config.get('hidden_size', 4096) vocab_size = self.target_config.get('vocab_size', 151936) # Scale drafter based on target model size num_heads = draft_hidden_size // 64 # ~64 dims per head num_kv_heads = max(1, num_heads // 4) intermediate_size = int(draft_hidden_size * 2.75) # Standard SwiGLU ratio drafter = DFlashDraftModel( vocab_size=vocab_size, hidden_size=draft_hidden_size, num_layers=draft_layers, num_heads=num_heads, num_kv_heads=num_kv_heads, intermediate_size=intermediate_size, max_seq_len=8192, block_size=self.block_size, mask_token_id=0, # Will be set from tokenizer num_target_layers=self.target_config.get('num_layers', 32), ) return drafter def train_drafter( self, dataset: str, max_seq_length: int = 3072, epochs: int = 6, batch_size: int = 32, lr: float = 6e-4, warmup_ratio: float = 0.04, grad_clip: float = 1.0, output_path: Optional[str] = None, ) -> str: """Train a custom DFlash drafter for your target model. Args: dataset: Path to training dataset or HF dataset name max_seq_length: Maximum sequence length for training epochs: Number of training epochs batch_size: Training batch size lr: Learning rate warmup_ratio: Warmup ratio for cosine schedule grad_clip: Gradient clipping threshold output_path: Where to save the trained drafter Returns: Path to saved drafter """ from .trainer import DFlashTrainer print(f"[UniversalDFlash] Training custom drafter...") trainer = DFlashTrainer( target_model=self.target_model, drafter=self.draft_model, tokenizer=self.tokenizer, ) trained_model = trainer.train( dataset=dataset, max_seq_length=max_seq_length, epochs=epochs, batch_size=batch_size, lr=lr, warmup_ratio=warmup_ratio, grad_clip=grad_clip, ) # Update the draft model self.draft_model = trained_model self.decoder.draft_model = trained_model if output_path: self.save_drafter(output_path) return output_path or "./trained_dflash_drafter" def save_drafter(self, path: str): """Save the current drafter model.""" import json from pathlib import Path path = Path(path) path.mkdir(parents=True, exist_ok=True) # Save weights weights = dict(self.draft_model.parameters()) mx.save_safetensors(str(path / "weights.safetensors"), weights) # Save config config = { "vocab_size": self.draft_model.vocab_size, "hidden_size": self.draft_model.hidden_size, "num_hidden_layers": self.draft_model.num_layers, "num_attention_heads": self.draft_model.num_heads, "num_key_value_heads": self.draft_model.num_heads // 4, "intermediate_size": self.draft_model.layers[0].mlp.gate_proj.weight.shape[1] if hasattr(self.draft_model.layers[0].mlp.gate_proj, 'weight') else 2816, "max_position_embeddings": self.draft_model.max_seq_len, "block_size": self.draft_model.block_size, } with open(path / "config.json", "w") as f: json.dump(config, f, indent=2) print(f"[UniversalDFlash] Drafter saved to {path}") def generate( self, prompt: str, max_tokens: int = 2048, temperature: float = 0.0, stop_strings: Optional[List[str]] = None, ) -> str: """Generate text using DFlash speculative decoding. Args: prompt: Text prompt max_tokens: Maximum tokens to generate temperature: Sampling temperature stop_strings: Optional stop strings Returns: Generated text """ return self.decoder.generate( prompt=prompt, max_tokens=max_tokens, temperature=temperature, stop_strings=stop_strings, ) def benchmark( self, prompt: str = "Write a quicksort in Python.", max_tokens: int = 512, num_runs: int = 5, ) -> Dict[str, float]: """Benchmark DFlash speculative decoding. 
    def generate(
        self,
        prompt: str,
        max_tokens: int = 2048,
        temperature: float = 0.0,
        stop_strings: Optional[List[str]] = None,
    ) -> str:
        """Generate text using DFlash speculative decoding.

        Args:
            prompt: Text prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            stop_strings: Optional stop strings

        Returns:
            Generated text
        """
        return self.decoder.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            stop_strings=stop_strings,
        )

    def benchmark(
        self,
        prompt: str = "Write a quicksort in Python.",
        max_tokens: int = 512,
        num_runs: int = 5,
    ) -> Dict[str, float]:
        """Benchmark DFlash speculative decoding.

        Args:
            prompt: Test prompt
            max_tokens: Tokens per run
            num_runs: Number of benchmark runs

        Returns:
            Dict with timing and throughput metrics. No baseline run is
            performed, so this reports raw speed rather than a speedup.
        """
        import time

        print(f"[Benchmark] Running {num_runs} generations...")

        # Warmup
        self.generate(prompt, max_tokens=10)

        # Timed DFlash generations
        dflash_times = []
        for _ in range(num_runs):
            start = time.perf_counter()
            self.generate(prompt, max_tokens=max_tokens)
            dflash_times.append(time.perf_counter() - start)

        # For a true speedup figure, run the same prompt through plain
        # (non-speculative) decoding and compare; here we only report
        # throughput, assuming each run emits the full max_tokens.
        avg_time = sum(dflash_times) / len(dflash_times)
        tokens_per_sec = max_tokens / avg_time

        print(f"[Benchmark] Avg time: {avg_time:.2f}s, Speed: {tokens_per_sec:.1f} tok/s")

        return {
            "avg_time_sec": avg_time,
            "tokens_per_sec": tokens_per_sec,
            "num_runs": num_runs,
        }
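

# --- Usage sketch ---
# A minimal example of wiring this class up. It assumes mlx_lm's `load()`
# helper; the model repo below is illustrative, not a requirement.
if __name__ == "__main__":
    from mlx_lm import load

    model, tokenizer = load("mlx-community/Qwen3-4B-4bit")  # any MLX-converted LM

    decoder = UniversalDFlashDecoder(
        target_model=model,
        tokenizer=tokenizer,
        draft_model_path=None,  # None -> build an untrained generic drafter
        block_size=16,
    )

    # An untrained drafter will draft poorly; call train_drafter() or pass
    # draft_model_path for real speedups.
    print(decoder.generate("Write a quicksort in Python.", max_tokens=128))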