"""Universal DFlash decoder for any MLX-converted model.

Provides a high-level interface that works with any mlx_lm model, including
those without pre-built DFlash drafters. Uses the architecture-agnostic
adapter system for target model interaction across all supported families
(Qwen3, Qwen3.5, LLaMA, Mistral, Gemma).
"""

from collections.abc import Mapping
from typing import Any, Dict, Iterator, List, Optional, Union

import mlx.core as mx

from .model import DFlashDraftModel
from .speculative_decode import DFlashSpeculativeDecoder
from .adapters import load_target_model, LoadedTargetModel, detect_model_architecture
from .convert import load_mlx_dflash


def _build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]:
    """Select target model layer indices for feature extraction.

    Uniformly samples from shallow to deep layers (layer 1 through layer
    ``num_target_layers - 3``) for cross-layer feature fusion, matching the
    DFlash paper. A single draft layer reads from the middle target layer.

    Args:
        num_target_layers: Number of decoder layers in the target model.
        num_draft_layers: Number of layer indices to select.

    Returns:
        ``num_draft_layers`` layer indices (empty when ``num_draft_layers <= 0``).
        In the multi-layer case every index is clamped into
        ``[0, num_target_layers - 1]`` so very shallow targets never produce
        negative (wrap-around) indices; for targets with 3+ layers the clamp
        never changes a value.
    """
    if num_draft_layers == 1:
        return [num_target_layers // 2]

    start = 1
    end = num_target_layers - 3
    span = end - start
    last_layer = max(num_target_layers - 1, 0)

    layer_ids = []
    for i in range(num_draft_layers):
        layer_id = int(round(start + (i * span) / (num_draft_layers - 1)))
        # Only has an effect when num_target_layers < 3 (``end`` goes negative).
        layer_ids.append(min(max(layer_id, 0), last_layer))
    return layer_ids


class UniversalDFlashDecoder:
    """Universal DFlash decoder that works with any MLX-converted model.

    This class handles:
    1. Loading pre-converted DFlash drafters with architecture detection
    2. Creating generic drafters for unsupported models
    3. Training custom drafters on-the-fly

    The target model is always wrapped in a ``LoadedTargetModel`` whose adapter
    is chosen from the detected architecture; that wrapper is what the
    speculative decoder receives for hidden state extraction and KV cache
    management.
    """

    def __init__(
        self,
        target_model: Any,
        tokenizer,
        draft_model_path: Optional[str] = None,
        draft_layers: int = 5,
        draft_hidden_size: int = 1024,
        block_size: int = 16,
        device: str = "metal",
    ):
        """Initialize the universal decoder.

        Args:
            target_model: Any mlx_lm loaded model, a ``LoadedTargetModel``,
                or a path/ID to load.
            tokenizer: Tokenizer for the model. May be None when the target
                is given as a path/ID or ``LoadedTargetModel``; the loaded
                target's tokenizer is then used.
            draft_model_path: Optional path to pre-converted DFlash drafter
            draft_layers: Number of draft layers (if creating generic drafter)
            draft_hidden_size: Hidden size for generic drafter
            block_size: Number of tokens per draft block
            device: MLX device
        """
        self.block_size = block_size
        self.device = device

        # Resolve target model (path/ID, LoadedTargetModel, or raw mlx_lm model)
        self.loaded_target = self._resolve_target(target_model, tokenizer)
        self.target_model = self.loaded_target.model

        if tokenizer is None:
            # A target loaded by path/ID carries its own tokenizer; reuse it
            # instead of handing None to the speculative decoder.
            tokenizer = getattr(self.loaded_target, "tokenizer", None)
        self.tokenizer = tokenizer

        # Determine model type and vocab size
        self.vocab_size = getattr(tokenizer, "vocab_size", 151936)
        self.target_config = self._extract_target_config(self.target_model)

        # Load or create draft model
        if draft_model_path:
            print(f"[UniversalDFlash] Loading pre-built drafter from {draft_model_path}...")
            self.draft_model, self.draft_config = load_mlx_dflash(draft_model_path)
        else:
            print("[UniversalDFlash] Creating generic drafter for your model...")
            self.draft_model = self._create_generic_drafter(
                draft_layers=draft_layers,
                draft_hidden_size=draft_hidden_size,
            )
            self.draft_config = None

        # Create the speculative decoder with architecture-aware adapter
        self.decoder = DFlashSpeculativeDecoder(
            target_model=self.loaded_target,
            draft_model=self.draft_model,
            tokenizer=self.tokenizer,
            block_size=block_size,
            device=device,
        )

    @staticmethod
    def _resolve_target(target_model: Any, tokenizer) -> LoadedTargetModel:
        """Wrap ``target_model`` in a ``LoadedTargetModel``.

        Strings are loaded via ``load_target_model``; objects exposing an
        ``adapter`` attribute are treated as an existing ``LoadedTargetModel``;
        anything else is treated as a raw mlx_lm model whose architecture is
        detected and paired with a matching adapter (``MLXTargetAdapter`` when
        no specific adapter is registered).
        """
        if isinstance(target_model, str):
            print(f"[UniversalDFlash] Loading target model: {target_model}...")
            return load_target_model(target_model)

        if hasattr(target_model, 'adapter'):
            # Already a LoadedTargetModel
            return target_model

        # Raw mlx_lm model — detect architecture
        print("[UniversalDFlash] Detecting model architecture...")
        arch = detect_model_architecture(target_model)
        print(f"[UniversalDFlash] Detected architecture: {arch}")

        # Create minimal LoadedTargetModel wrapper
        from .adapters import MLXTargetAdapter, adapter_for_model_type

        adapter_cls = adapter_for_model_type(arch)
        if adapter_cls is None:
            adapter_cls = MLXTargetAdapter
        adapter = adapter_cls(model=target_model, config={"model_type": arch})

        return LoadedTargetModel(
            requested_model="unknown",
            resolved_model_path=None,
            model=target_model,
            tokenizer=tokenizer,
            adapter=adapter,
        )

    @staticmethod
    def _config_value(config: Any, key: str, default: Any) -> Any:
        """Read ``key`` from a mapping-style or attribute-style config object."""
        if isinstance(config, Mapping):
            return config.get(key, default)
        return getattr(config, key, default)

    def _extract_target_config(self, target_model) -> Dict[str, Any]:
        """Extract architecture dimensions from the target model.

        The config is looked up on ``target_model.config`` and, failing that,
        on ``target_model.args`` (where mlx_lm models keep their ``ModelArgs``).
        Both attribute-style objects and plain dicts are accepted. When the
        config has no ``hidden_size`` but carries a nested ``text_config`` (as
        some multimodal configs do), each field is read from ``text_config``
        first and from the outer config second. Missing or ``None`` fields
        fall back to fixed defaults; the layer count first falls back to
        ``len(target_model.layers)`` when that is a list/tuple.

        Returns:
            Dict with keys ``hidden_size``, ``num_layers``, ``vocab_size``,
            ``intermediate_size``, ``num_attention_heads``,
            ``num_key_value_heads`` and ``model_type``.
        """
        layers = getattr(target_model, "layers", None)
        default_num_layers = len(layers) if isinstance(layers, (list, tuple)) else 32

        model_config = getattr(target_model, "config", None)
        if model_config is None:
            model_config = getattr(target_model, "args", None)

        if model_config is None:
            # No config available: fallback dimensions (a LLaMA-3-8B-sized
            # body paired with a Qwen-sized vocabulary).
            return {
                'hidden_size': 4096,
                'num_layers': default_num_layers,
                'vocab_size': 151936,
                'intermediate_size': 14336,
                'num_attention_heads': 32,
                'num_key_value_heads': 8,
                'model_type': 'unknown',
            }

        dims = model_config
        text_config = self._config_value(model_config, "text_config", None)
        if text_config is not None and self._config_value(model_config, "hidden_size", None) is None:
            dims = text_config

        def read(key: str, default: Any) -> Any:
            # Prefer the dimension source, then the outer config, then the default.
            for source in (dims, model_config):
                value = self._config_value(source, key, None)
                if value is not None:
                    return value
            return default

        return {
            'hidden_size': read('hidden_size', 4096),
            'num_layers': read('num_hidden_layers', default_num_layers),
            'vocab_size': read('vocab_size', 151936),
            'intermediate_size': read('intermediate_size', 14336),
            'num_attention_heads': read('num_attention_heads', 32),
            'num_key_value_heads': read('num_key_value_heads', 8),
            'model_type': self._config_value(model_config, 'model_type', 'unknown'),
        }

    def _create_generic_drafter(
        self,
        draft_layers: int,
        draft_hidden_size: int,
    ) -> DFlashDraftModel:
        """Create a generic, untrained DFlash drafter for the target model.

        The drafter shares the target's vocabulary size and reads features
        from target layers chosen by ``_build_target_layer_ids``. Its own
        width is ``draft_hidden_size`` (independent of the target's hidden
        size); heads and MLP width are derived from it. The result can be
        trained with ``train_drafter``.

        Args:
            draft_layers: Number of drafter decoder layers.
            draft_hidden_size: Drafter hidden size.

        Returns:
            A freshly initialised ``DFlashDraftModel``.
        """
        vocab_size = self.target_config.get('vocab_size', 151936)
        num_target_layers = self.target_config.get('num_layers', 32)

        # ~64 dims per head; keep at least one head so small drafters stay valid.
        num_heads = max(1, draft_hidden_size // 64)
        num_kv_heads = max(1, num_heads // 4)
        # MLP width ratio of 2.75x hidden (e.g. 1024 -> 2816).
        intermediate_size = int(draft_hidden_size * 2.75)

        # Target layer ids for feature extraction
        target_layer_ids = _build_target_layer_ids(num_target_layers, draft_layers)

        return DFlashDraftModel(
            vocab_size=vocab_size,
            hidden_size=draft_hidden_size,
            num_layers=draft_layers,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            intermediate_size=intermediate_size,
            max_seq_len=8192,
            block_size=self.block_size,
            # NOTE(review): placeholder id; nothing in this class derives it
            # from the tokenizer — confirm where the real mask token is set.
            mask_token_id=0,
            num_target_layers=num_target_layers,
            target_layer_ids=target_layer_ids,
        )

    def train_drafter(
        self,
        dataset: str,
        max_seq_length: int = 3072,
        epochs: int = 6,
        batch_size: int = 32,
        lr: float = 6e-4,
        warmup_ratio: float = 0.04,
        grad_clip: float = 1.0,
        output_path: Optional[str] = None,
    ) -> str:
        """Train a custom DFlash drafter for your target model.

        Uses the training recipe from the DFlash paper:
        - KV injection with target model features
        - Random anchor sampling for block construction
        - Sparse attention masking within blocks
        - Position-dependent loss decay

        The trained drafter replaces the current one (on both this object and
        the speculative decoder) and is saved to disk.

        Args:
            dataset: Path to training dataset or HF dataset name
            max_seq_length: Maximum sequence length for training
            epochs: Number of training epochs (paper: 6)
            batch_size: Training batch size
            lr: Learning rate (paper: 6e-4)
            warmup_ratio: Warmup ratio for cosine schedule (paper: 0.04)
            grad_clip: Gradient clipping threshold (paper: 1.0)
            output_path: Where to save the trained drafter; defaults to
                ``./trained_dflash_drafter``.

        Returns:
            Path to saved drafter
        """
        from .trainer import DFlashTrainer

        print("[UniversalDFlash] Training custom drafter...")
        print(f"  Dataset: {dataset}")
        print(f"  Epochs: {epochs}, Batch size: {batch_size}, LR: {lr}")

        trainer = DFlashTrainer(
            target_model=self.target_model,
            drafter=self.draft_model,
            tokenizer=self.tokenizer,
        )

        trained_model = trainer.train(
            dataset=dataset,
            max_seq_length=max_seq_length,
            epochs=epochs,
            batch_size=batch_size,
            lr=lr,
            warmup_ratio=warmup_ratio,
            grad_clip=grad_clip,
        )

        # Update the draft model
        self.draft_model = trained_model
        self.decoder.draft_model = trained_model

        # Always persist so the returned path actually contains the drafter.
        save_path = output_path or "./trained_dflash_drafter"
        self.save_drafter(save_path)
        return save_path

    def save_drafter(self, path: str):
        """Save the current drafter's weights and config to a directory.

        Writes ``weights.npz`` (the flattened parameter tree, keys such as
        ``layers.0.mlp.gate_proj.weight``) and ``config.json`` into ``path``,
        creating the directory if needed.

        Args:
            path: Output directory.

        Raises:
            Exception: Whatever ``mx.savez`` raises if the weights cannot be
                written (logged, then re-raised).
        """
        import json
        from pathlib import Path

        from mlx.utils import tree_flatten

        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        draft = self.draft_model

        # parameters() is a nested dict/list tree; savez needs flat name -> array pairs.
        weights = dict(tree_flatten(draft.parameters()))
        try:
            mx.savez(str(path / "weights.npz"), **weights)
        except Exception as e:
            print(f"[Save] Error saving weights: {e}")
            raise

        num_kv_heads = getattr(draft, "num_kv_heads", None)
        if num_kv_heads is None:
            # Same rule _create_generic_drafter uses to derive KV heads.
            num_kv_heads = max(1, draft.num_heads // 4)

        intermediate_size = getattr(draft, "intermediate_size", None)
        if intermediate_size is None:
            gate_proj = draft.layers[0].mlp.gate_proj if draft.layers else None
            gate_weight = getattr(gate_proj, "weight", None)
            # mlx.nn.Linear stores weight as (output_dims, input_dims); gate_proj
            # projects hidden -> intermediate, so axis 0 is the intermediate size.
            intermediate_size = gate_weight.shape[0] if gate_weight is not None else 2816

        config = {
            "vocab_size": draft.vocab_size,
            "hidden_size": draft.hidden_size,
            "num_hidden_layers": draft.num_layers,
            "num_attention_heads": draft.num_heads,
            "num_key_value_heads": int(num_kv_heads),
            "intermediate_size": int(intermediate_size),
            "max_position_embeddings": draft.max_seq_len,
            "block_size": draft.block_size,
            # Coerce to plain ints so json can serialize tuples / numpy / mx scalars.
            "target_layer_ids": [int(i) for i in draft.target_layer_ids],
        }

        with open(path / "config.json", "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        print(f"[UniversalDFlash] Drafter saved to {path}")

    def generate(
        self,
        prompt: str,
        max_tokens: int = 2048,
        temperature: float = 0.0,
        stop_strings: Optional[List[str]] = None,
        stream: bool = False,
    ) -> Union[str, Iterator[str]]:
        """Generate text using DFlash speculative decoding.

        Args:
            prompt: Text prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            stop_strings: Optional stop strings
            stream: If True, returns a generator yielding text deltas

        Returns:
            Generated text string, or generator if stream=True
        """
        return self.decoder.generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            stop_strings=stop_strings,
            stream=stream,
        )

    def benchmark(
        self,
        prompt: str = "Write a quicksort in Python.",
        max_tokens: int = 512,
        num_runs: int = 5,
    ) -> Dict[str, float]:
        """Benchmark DFlash speculative decoding.

        Delegates to the underlying speculative decoder's ``benchmark``.

        Args:
            prompt: Test prompt
            max_tokens: Tokens per run
            num_runs: Number of benchmark runs

        Returns:
            Dict with speedup metrics
        """
        return self.decoder.benchmark(
            prompt=prompt,
            max_tokens=max_tokens,
            num_runs=num_runs,
        )