""" Universal architecture adapters for DFlash speculative decoding on MLX. Supports: Qwen3, Qwen3.5, LLaMA (2/3), Mistral, Gemma, and generic transformers. Inspired by Aryagm's adapter pattern and bstnxbt's per-family engine approach. """ from __future__ import annotations import json from dataclasses import dataclass from pathlib import Path from typing import Any, Optional, Tuple, List, Dict import mlx.core as mx import mlx.nn as nn from huggingface_hub import snapshot_download from mlx_lm import load from mlx_lm.models import cache as cache_lib # ────────────────────────────────────────────────────────────────────────────── # Architecture registry — maps model_type → adapter class # ────────────────────────────────────────────────────────────────────────────── ARCH_LAYER_MAP: Dict[str, Dict[str, Any]] = { "qwen3": { "layers_attr": "model.layers", "embed_attr": "model.embed_tokens", "norm_attr": "model.norm", "lm_head_attr": "lm_head", "cache_type": "KVCache", "make_cache_fn": "make_cache", "tie_embeddings": True, "model_type": "qwen3", }, "qwen2": { "layers_attr": "model.layers", "embed_attr": "model.embed_tokens", "norm_attr": "model.norm", "lm_head_attr": "lm_head", "cache_type": "KVCache", "make_cache_fn": "make_cache", "tie_embeddings": True, "model_type": "qwen2", }, "qwen3_5": { "layers_attr": "language_model.model.layers", "embed_attr": "language_model.model.embed_tokens", "norm_attr": "language_model.model.norm", "lm_head_attr": "language_model.lm_head", "cache_type": "ArraysCache", "make_cache_fn": "make_cache", "tie_embeddings": True, "model_type": "qwen3_5", "has_hybrid_attention": True, "has_linear_attention": True, }, "llama": { "layers_attr": "model.layers", "embed_attr": "model.embed_tokens", "norm_attr": "model.norm", "lm_head_attr": "lm_head", "cache_type": "KVCache", "make_cache_fn": "make_cache", "tie_embeddings": False, "model_type": "llama", }, "mistral": { "layers_attr": "model.layers", "embed_attr": "model.embed_tokens", "norm_attr": "model.norm", "lm_head_attr": "lm_head", "cache_type": "KVCache", "make_cache_fn": "make_cache", "tie_embeddings": False, "model_type": "mistral", }, "gemma": { "layers_attr": "model.layers", "embed_attr": "model.embed_tokens", "norm_attr": "model.norm", "lm_head_attr": "lm_head", "cache_type": "KVCache", "make_cache_fn": "make_cache", "tie_embeddings": True, "model_type": "gemma", "norm_eps": 1e-6, }, "gemma2": { "layers_attr": "model.layers", "embed_attr": "model.embed_tokens", "norm_attr": "model.norm", "lm_head_attr": "lm_head", "cache_type": "KVCache", "make_cache_fn": "make_cache", "tie_embeddings": True, "model_type": "gemma2", "norm_eps": 1e-6, }, "generic": { "layers_attr": "layers", "embed_attr": "embedding", "norm_attr": "norm", "lm_head_attr": "lm_head", "cache_type": "KVCache", "make_cache_fn": None, "tie_embeddings": False, "model_type": "generic", }, } def resolve_model_path(path_or_repo: str) -> Path: """Resolve a model path or HF Hub repo ID to a local path.""" path = Path(path_or_repo) if path.exists(): return path return Path(snapshot_download(path_or_repo)) def _get_attr(obj: Any, attr_path: str) -> Any: """Get nested attribute by dot-path, e.g. 'language_model.model.layers'.""" for part in attr_path.split("."): if obj is None: return None obj = getattr(obj, part, None) return obj def detect_model_architecture(model, config: Optional[Dict] = None) -> str: """Auto-detect model architecture from model structure and config.""" # Try config first if config is None and hasattr(model, "config"): if hasattr(model.config, "to_dict"): config = model.config.to_dict() elif hasattr(model.config, "model_type"): config = {"model_type": model.config.model_type} if config and "model_type" in config: mt = config["model_type"] if mt in ARCH_LAYER_MAP: return mt # Aliases if mt.startswith("qwen3_5") or mt == "qwen3.5": return "qwen3_5" if mt.startswith("qwen3"): return "qwen3" if mt.startswith("qwen2"): return "qwen2" if mt.startswith("llama"): return "llama" if mt.startswith("mistral"): return "mistral" if mt == "gemma2": return "gemma2" if mt.startswith("gemma"): return "gemma" # Structural detection if hasattr(model, "language_model"): return "qwen3_5" if hasattr(model, "model") and hasattr(model.model, "layers"): return "llama" # llama, qwen3, mistral all share this if hasattr(model, "layers"): return "generic" return "generic" # ────────────────────────────────────────────────────────────────────────────── # Base adapter class — defines the contract all adapters must implement # ────────────────────────────────────────────────────────────────────────────── class MLXTargetAdapter: """Base adapter for DFlash target model interaction. Every supported architecture needs an adapter that knows: - Where embeddings live - How to iterate layers and extract hidden states - How to create/manage KV caches - How to call the LM head - How to trim/rewind caches on rejection """ family: str = "unknown" arch_info: Dict[str, Any] = {} def __init__(self, model, config: Optional[Dict] = None): self.model = model self.config = config or {} self._detect_attributes() def _detect_attributes(self): """Resolve embedding, layer, norm, lm_head references.""" arch = ARCH_LAYER_MAP.get(self.family, ARCH_LAYER_MAP["generic"]) self.arch_info = arch.copy() # Try exact path first self._embed = _get_attr(self.model, arch["embed_attr"]) self._layers = _get_attr(self.model, arch["layers_attr"]) self._norm = _get_attr(self.model, arch["norm_attr"]) self._lm_head = _get_attr(self.model, arch["lm_head_attr"]) # Fallback: probe common locations if self._embed is None: for attr in ("embedding", "token_embedding", "embed_tokens", "wte"): self._embed = getattr(self.model, attr, None) if self._embed is not None: break if self._layers is None: self._layers = getattr(self.model, "layers", None) if self._norm is None: self._norm = getattr(self.model, "norm", None) if self._lm_head is None: self._lm_head = getattr(self.model, "lm_head", None) # ── Tokenization / Prompt ─────────────────────────────────────────────── def build_prompt(self, tokenizer, prompt_text: str, enable_thinking: bool = False) -> mx.array: """Build prompt tokens from text.""" messages = [{"role": "user", "content": prompt_text}] try: text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=enable_thinking, ) except TypeError: text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) tokens = tokenizer.encode(text, add_special_tokens=False) return mx.array(tokens, dtype=mx.uint32) def stop_token_ids(self, tokenizer) -> set[int]: """Get set of stop token IDs.""" eos = tokenizer.eos_token_ids if isinstance(eos, int): return {eos} if isinstance(eos, (list, tuple)): return set(eos) return set() # ── Embeddings / LM Head ──────────────────────────────────────────────── def embed_tokens(self, tokens: mx.array) -> mx.array: """Embed token IDs to hidden states.""" if self._embed is None: raise RuntimeError(f"[{self.family}] Could not find embedding layer") return self._embed(tokens) def lm_head_logits(self, hidden_states: mx.array) -> mx.array: """Project hidden states to vocab logits.""" if self._lm_head is not None: return self._lm_head(hidden_states) # Tie-word-embedding fallback if self.arch_info.get("tie_embeddings") and self._embed is not None: if hasattr(self._embed, "as_linear"): return self._embed.as_linear(hidden_states) raise RuntimeError(f"[{self.family}] Could not find LM head") def lm_head_argmax(self, hidden_states: mx.array) -> mx.array: """Greedy next-token from hidden states.""" logits = self.lm_head_logits(hidden_states) return mx.argmax(logits, axis=-1).astype(mx.uint32) # ── Cache Management ────────────────────────────────────────────────────── def make_cache(self) -> list[Any]: """Create fresh KV cache for all layers.""" cache_type = self.arch_info.get("cache_type", "KVCache") num_layers = len(self._layers) if self._layers is not None else 0 if cache_type == "KVCache": return [cache_lib.KVCache() for _ in range(num_layers)] elif cache_type == "ArraysCache": return [cache_lib.ArraysCache() for _ in range(num_layers)] else: return [None for _ in range(num_layers)] def rewind_kv_caches(self, cache: list[Any], num_tokens: int) -> None: """Trim cache to accepted prefix length.""" for layer_cache in cache: if isinstance(layer_cache, cache_lib.KVCache): layer_cache.trim(num_tokens) elif isinstance(layer_cache, cache_lib.ArraysCache) and hasattr(layer_cache, "trim"): layer_cache.trim(num_tokens) # ── Forward with Hidden-State Extraction ───────────────────────────────── def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]: """Build causal attention mask appropriate for this architecture.""" # Default: simple causal mask via triangular structure # MLX fast attention often handles this internally, but we provide a hook seq_len = hidden_states.shape[1] if cache is not None and hasattr(cache, "offset"): # Cached generation — no mask needed for single new token if seq_len == 1: return None return None # MLX models typically compute mask internally def forward_with_hidden_states( self, tokens: mx.array, cache: list[Any], layer_ids: List[int], output_rollback_records: bool = False, ) -> Tuple[mx.array, mx.array] | Tuple[mx.array, mx.array, Dict]: """ Run target model, returning (logits, target_hidden). target_hidden = concatenation of hidden states at layer_ids. Args: tokens: Input token IDs [bsz, seq_len] cache: Per-layer KV cache layer_ids: Target layer indices for DFlash conditioning output_rollback_records: Whether to return per-layer state for rollback Returns: (logits, target_hidden) or (logits, target_hidden, rollback_records) """ if self._embed is None or self._layers is None: raise RuntimeError(f"[{self.family}] Model attributes not resolved") hidden = self.embed_tokens(tokens) mask = self.create_attention_mask(hidden, cache[0] if cache else None) selected: List[mx.array] = [] rollback_records: Dict[int, Dict[str, mx.array]] = {} target_layer_ids = set(layer_ids) for idx, (layer, layer_cache) in enumerate(zip(self._layers, cache)): # Each layer returns updated hidden states # Some return tuple (hidden, cache_update), some just hidden layer_out = layer(hidden, mask=mask, cache=layer_cache) if isinstance(layer_out, tuple): hidden = layer_out[0] else: hidden = layer_out if idx in target_layer_ids: selected.append(hidden) # Final norm + LM head if self._norm is not None: hidden = self._norm(hidden) logits = self.lm_head_logits(hidden) # Concatenate selected hidden states across feature dim if selected: target_hidden = mx.concatenate(selected, axis=-1) else: # Fallback: use final hidden state target_hidden = hidden if output_rollback_records: return logits, target_hidden, rollback_records return logits, target_hidden def forward_verifier_states( self, tokens: mx.array, cache: list[Any], layer_ids: List[int], ) -> Tuple[mx.array, mx.array, Dict]: """Forward pass that always returns rollback records.""" return self.forward_with_hidden_states( tokens, cache, layer_ids, output_rollback_records=True ) def forward_accept_all_block( self, tokens: mx.array, cache: list[Any], layer_ids: List[int], ) -> Tuple[mx.array, mx.array]: """Single-token forward returning last-position logits + target hidden.""" logits, target_hidden = self.forward_with_hidden_states( tokens, cache, layer_ids, output_rollback_records=False ) return logits[:, -1:, :], target_hidden # ── Cache Summary (for debugging) ─────────────────────────────────────── def cache_summary(self, cache: list[Any]) -> str: """Human-readable cache status.""" parts: List[str] = [] for idx, c in enumerate(cache): if isinstance(c, cache_lib.KVCache): parts.append(f"{idx}:kv={c.offset}") elif isinstance(c, cache_lib.ArraysCache): rec = None if c[1] is None else tuple(c[1].shape) parts.append(f"{idx}:ssm={rec}") else: parts.append(f"{idx}:none") return " ".join(parts) # ────────────────────────────────────────────────────────────────────────────── # Per-family adapter subclasses (for architecture-specific overrides) # ────────────────────────────────────────────────────────────────────────────── class Qwen3Adapter(MLXTargetAdapter): family = "qwen3" def build_prompt(self, tokenizer, prompt_text: str, enable_thinking: bool = False) -> mx.array: messages = [{"role": "user", "content": prompt_text}] try: text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=enable_thinking, ) except TypeError: text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) tokens = tokenizer.encode(text, add_special_tokens=False) return mx.array(tokens, dtype=mx.uint32) def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]: try: from mlx_lm.models import qwen3 return qwen3.create_attention_mask(hidden_states, cache) except Exception: return super().create_attention_mask(hidden_states, cache) def lm_head_logits(self, hidden_states: mx.array) -> mx.array: # Qwen3 often uses tied embeddings if self.arch_info.get("tie_embeddings") and self._embed is not None: if hasattr(self._embed, "as_linear"): return self._embed.as_linear(hidden_states) if self._lm_head is not None: return self._lm_head(hidden_states) raise RuntimeError("[qwen3] No LM head found") class Qwen35Adapter(MLXTargetAdapter): family = "qwen3_5" def build_prompt(self, tokenizer, prompt_text: str, enable_thinking: bool = False) -> mx.array: messages = [{"role": "user", "content": prompt_text}] try: text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=enable_thinking, ) except TypeError: text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) tokens = tokenizer.encode(text, add_special_tokens=False) return mx.array(tokens, dtype=mx.uint32) def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]: try: from mlx_lm.models import qwen3_5 # Qwen3.5 has hybrid attention: full-attention + linear-attention if cache is not None and hasattr(cache, "__len__") and len(cache) > 0: # Detect cache type if hasattr(cache[0], "fa_idx"): fa_mask = qwen3_5.create_attention_mask(hidden_states, cache[0]) return fa_mask except Exception: pass return super().create_attention_mask(hidden_states, cache) def forward_with_hidden_states( self, tokens: mx.array, cache: list[Any], layer_ids: List[int], output_rollback_records: bool = False, ): # Qwen3.5 needs special handling for hybrid attention layers if self._embed is None or self._layers is None: raise RuntimeError("[qwen3_5] Model attributes not resolved") hidden = self.embed_tokens(tokens) # Build masks for full-attention and linear-attention layers try: from mlx_lm.models import qwen3_5 fa_mask = qwen3_5.create_attention_mask(hidden_states=hidden, cache=cache[0] if cache else None) except Exception: fa_mask = None selected: List[mx.array] = [] target_layer_ids = set(layer_ids) for idx, (layer, layer_cache) in enumerate(zip(self._layers, cache)): # Qwen3.5 layers have is_linear flag mask = None if hasattr(layer, "is_linear") and layer.is_linear: # Linear attention layer — uses different mask or none pass else: mask = fa_mask layer_out = layer(hidden, mask=mask, cache=layer_cache) if isinstance(layer_out, tuple): hidden = layer_out[0] else: hidden = layer_out if idx in target_layer_ids: selected.append(hidden) if self._norm is not None: hidden = self._norm(hidden) logits = self.lm_head_logits(hidden) if selected: target_hidden = mx.concatenate(selected, axis=-1) else: target_hidden = hidden if output_rollback_records: return logits, target_hidden, {} return logits, target_hidden class LlamaAdapter(MLXTargetAdapter): family = "llama" def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]: try: from mlx_lm.models import llama return llama.create_attention_mask(hidden_states, cache) except Exception: return super().create_attention_mask(hidden_states, cache) class MistralAdapter(MLXTargetAdapter): family = "mistral" def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]: try: from mlx_lm.models import mistral return mistral.create_attention_mask(hidden_states, cache) except Exception: return super().create_attention_mask(hidden_states, cache) class GemmaAdapter(MLXTargetAdapter): family = "gemma" def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]: try: from mlx_lm.models import gemma return gemma.create_attention_mask(hidden_states, cache) except Exception: return super().create_attention_mask(hidden_states, cache) # ────────────────────────────────────────────────────────────────────────────── # Adapter registry and factory # ────────────────────────────────────────────────────────────────────────────── ADAPTERS: Dict[str, type[MLXTargetAdapter]] = { "qwen3": Qwen3Adapter, "qwen2": Qwen3Adapter, # Shares structure "qwen3_5": Qwen35Adapter, "llama": LlamaAdapter, "mistral": MistralAdapter, "gemma": GemmaAdapter, "gemma2": GemmaAdapter, "generic": MLXTargetAdapter, } def adapter_for_model_type(model_type: str) -> Optional[type[MLXTargetAdapter]]: """Get adapter class for a model type string.""" # Direct match if model_type in ADAPTERS: return ADAPTERS[model_type] # Aliases if model_type.startswith("qwen3_5") or model_type == "qwen3.5": return Qwen35Adapter if model_type.startswith("qwen3"): return Qwen3Adapter if model_type.startswith("qwen2"): return Qwen3Adapter if model_type.startswith("llama"): return LlamaAdapter if model_type.startswith("mistral"): return MistralAdapter if model_type == "gemma2": return GemmaAdapter if model_type.startswith("gemma"): return GemmaAdapter return None # ────────────────────────────────────────────────────────────────────────────── # LoadedTargetModel — convenience wrapper binding model + adapter + tokenizer # ────────────────────────────────────────────────────────────────────────────── @dataclass class LoadedTargetModel: requested_model: str resolved_model_path: Path model: Any tokenizer: Any adapter: MLXTargetAdapter def build_prompt(self, prompt_text: str, enable_thinking: bool = False) -> mx.array: return self.adapter.build_prompt(self.tokenizer, prompt_text, enable_thinking) def stop_token_ids(self) -> set[int]: return self.adapter.stop_token_ids(self.tokenizer) def make_cache(self) -> list[Any]: return self.adapter.make_cache() def embed_tokens(self, tokens: mx.array) -> mx.array: return self.adapter.embed_tokens(tokens) def lm_head_logits(self, hidden_states: mx.array) -> mx.array: return self.adapter.lm_head_logits(hidden_states) def lm_head_argmax(self, hidden_states: mx.array) -> mx.array: return self.adapter.lm_head_argmax(hidden_states) def forward_with_hidden_states( self, tokens: mx.array, cache: list[Any], layer_ids: List[int], output_rollback_records: bool = False, ): return self.adapter.forward_with_hidden_states( tokens, cache, layer_ids, output_rollback_records ) def forward_verifier_states(self, tokens: mx.array, cache: list[Any], layer_ids: List[int]): return self.adapter.forward_verifier_states(tokens, cache, layer_ids) def forward_accept_all_block(self, tokens: mx.array, cache: list[Any], layer_ids: List[int]): return self.adapter.forward_accept_all_block(tokens, cache, layer_ids) def rewind_kv_caches(self, cache: list[Any], num_tokens: int) -> None: self.adapter.rewind_kv_caches(cache, num_tokens) def cache_summary(self, cache: list[Any]) -> str: return self.adapter.cache_summary(cache) def load_target_model(path_or_repo: str) -> LoadedTargetModel: """Load an MLX target model with the correct adapter. Args: path_or_repo: Local path or HF Hub model ID Returns: LoadedTargetModel with architecture-aware adapter """ base_path = resolve_model_path(path_or_repo) # Load config to detect architecture config_path = base_path / "config.json" if config_path.exists(): with open(config_path, "r") as f: config = json.load(f) else: config = {} model_type = config.get("model_type", "generic") adapter_cls = adapter_for_model_type(model_type) if adapter_cls is None: registered = ", ".join(sorted(ADAPTERS.keys())) raise NotImplementedError( f"Unsupported MLX DFlash target model_type={model_type!r}. " f"Registered adapters: {registered}. " f"You can add one by subclassing MLXTargetAdapter in adapters.py." ) # Load model + tokenizer via mlx_lm model, tokenizer = load(str(base_path)) # Instantiate adapter adapter = adapter_cls(model, config) return LoadedTargetModel( requested_model=path_or_repo, resolved_model_path=base_path, model=model, tokenizer=tokenizer, adapter=adapter, )