| """ |
| Universal architecture adapters for DFlash speculative decoding on MLX. |
| |
| Supports: Qwen3, Qwen3.5, LLaMA (2/3), Mistral, Gemma, and generic transformers. |
| Inspired by Aryagm's adapter pattern and bstnxbt's per-family engine approach. |
| """ |
|
|
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import mlx.core as mx
from huggingface_hub import snapshot_download
from mlx_lm import load
from mlx_lm.models import cache as cache_lib


# ---------------------------------------------------------------------------
# Architecture layout registry
# ---------------------------------------------------------------------------

# Per-family attribute paths and cache defaults. These are family-level
# defaults; a specific checkpoint's config may still differ (e.g. larger
# Qwen3 variants untie their embeddings).
ARCH_LAYER_MAP: Dict[str, Dict[str, Any]] = {
| "qwen3": { |
| "layers_attr": "model.layers", |
| "embed_attr": "model.embed_tokens", |
| "norm_attr": "model.norm", |
| "lm_head_attr": "lm_head", |
| "cache_type": "KVCache", |
| "make_cache_fn": "make_cache", |
| "tie_embeddings": True, |
| "model_type": "qwen3", |
| }, |
| "qwen2": { |
| "layers_attr": "model.layers", |
| "embed_attr": "model.embed_tokens", |
| "norm_attr": "model.norm", |
| "lm_head_attr": "lm_head", |
| "cache_type": "KVCache", |
| "make_cache_fn": "make_cache", |
| "tie_embeddings": True, |
| "model_type": "qwen2", |
| }, |
| "qwen3_5": { |
| "layers_attr": "language_model.model.layers", |
| "embed_attr": "language_model.model.embed_tokens", |
| "norm_attr": "language_model.model.norm", |
| "lm_head_attr": "language_model.lm_head", |
| "cache_type": "ArraysCache", |
| "make_cache_fn": "make_cache", |
| "tie_embeddings": True, |
| "model_type": "qwen3_5", |
| "has_hybrid_attention": True, |
| "has_linear_attention": True, |
| }, |
| "llama": { |
| "layers_attr": "model.layers", |
| "embed_attr": "model.embed_tokens", |
| "norm_attr": "model.norm", |
| "lm_head_attr": "lm_head", |
| "cache_type": "KVCache", |
| "make_cache_fn": "make_cache", |
| "tie_embeddings": False, |
| "model_type": "llama", |
| }, |
| "mistral": { |
| "layers_attr": "model.layers", |
| "embed_attr": "model.embed_tokens", |
| "norm_attr": "model.norm", |
| "lm_head_attr": "lm_head", |
| "cache_type": "KVCache", |
| "make_cache_fn": "make_cache", |
| "tie_embeddings": False, |
| "model_type": "mistral", |
| }, |
| "gemma": { |
| "layers_attr": "model.layers", |
| "embed_attr": "model.embed_tokens", |
| "norm_attr": "model.norm", |
| "lm_head_attr": "lm_head", |
| "cache_type": "KVCache", |
| "make_cache_fn": "make_cache", |
| "tie_embeddings": True, |
| "model_type": "gemma", |
| "norm_eps": 1e-6, |
| }, |
| "gemma2": { |
| "layers_attr": "model.layers", |
| "embed_attr": "model.embed_tokens", |
| "norm_attr": "model.norm", |
| "lm_head_attr": "lm_head", |
| "cache_type": "KVCache", |
| "make_cache_fn": "make_cache", |
| "tie_embeddings": True, |
| "model_type": "gemma2", |
| "norm_eps": 1e-6, |
| }, |
| "generic": { |
| "layers_attr": "layers", |
| "embed_attr": "embedding", |
| "norm_attr": "norm", |
| "lm_head_attr": "lm_head", |
| "cache_type": "KVCache", |
| "make_cache_fn": None, |
| "tie_embeddings": False, |
| "model_type": "generic", |
| }, |
| } |
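
# Adding support for a new family is a registry entry plus (optionally) an
# adapter subclass. A hypothetical Phi-3 entry might look like this; the
# attribute paths below are illustrative, not verified against mlx_lm:
#
#     ARCH_LAYER_MAP["phi3"] = {
#         "layers_attr": "model.layers",
#         "embed_attr": "model.embed_tokens",
#         "norm_attr": "model.norm",
#         "lm_head_attr": "lm_head",
#         "cache_type": "KVCache",
#         "make_cache_fn": "make_cache",
#         "tie_embeddings": False,
#         "model_type": "phi3",
#     }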


def resolve_model_path(path_or_repo: str) -> Path:
    """Resolve a model path or HF Hub repo ID to a local path."""
    path = Path(path_or_repo)
    if path.exists():
        return path
    return Path(snapshot_download(path_or_repo))
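
# Resolution sketch: an existing local directory resolves to itself; anything
# else is treated as a Hub repo ID and fetched into the HF cache. The repo ID
# below is only an example:
#
#     resolve_model_path("./my-model")                   # -> Path("./my-model")
#     resolve_model_path("mlx-community/Qwen3-4B-4bit")  # -> snapshot directory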


def _get_attr(obj: Any, attr_path: str) -> Any:
    """Get nested attribute by dot-path, e.g. 'language_model.model.layers'."""
    for part in attr_path.split("."):
        if obj is None:
            return None
        obj = getattr(obj, part, None)
    return obj
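
# Each dot hop is a getattr that turns missing attributes into None instead
# of raising, so probing candidate layouts stays cheap:
#
#     _get_attr(model, "model.layers")   # model.model.layers, or None if absent
#     _get_attr(model, "no.such.path")   # -> None rather than AttributeError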


def detect_model_architecture(model, config: Optional[Dict] = None) -> str:
    """Auto-detect model architecture from model structure and config."""
    if config is None and hasattr(model, "config"):
        if hasattr(model.config, "to_dict"):
            config = model.config.to_dict()
        elif hasattr(model.config, "model_type"):
            config = {"model_type": model.config.model_type}

    if config and "model_type" in config:
        mt = config["model_type"]
        if mt in ARCH_LAYER_MAP:
            return mt
        # Prefix matching; "qwen3_5" must be checked before "qwen3".
        if mt.startswith("qwen3_5") or mt == "qwen3.5":
            return "qwen3_5"
        if mt.startswith("qwen3"):
            return "qwen3"
        if mt.startswith("qwen2"):
            return "qwen2"
        if mt.startswith("llama"):
            return "llama"
        if mt.startswith("mistral"):
            return "mistral"
        if mt == "gemma2":
            return "gemma2"
        if mt.startswith("gemma"):
            return "gemma"

    # Structural heuristics when no config is available. These are loose:
    # a `language_model` wrapper is treated as Qwen3.5-style, and anything
    # with `model.layers` falls back to the LLaMA layout.
    if hasattr(model, "language_model"):
        return "qwen3_5"
    if hasattr(model, "model") and hasattr(model.model, "layers"):
        return "llama"
    if hasattr(model, "layers"):
        return "generic"

    return "generic"


# ---------------------------------------------------------------------------
# Base adapter
# ---------------------------------------------------------------------------


class MLXTargetAdapter:
| """Base adapter for DFlash target model interaction. |
| |
| Every supported architecture needs an adapter that knows: |
| - Where embeddings live |
| - How to iterate layers and extract hidden states |
| - How to create/manage KV caches |
| - How to call the LM head |
| - How to trim/rewind caches on rejection |
| """ |
| |
| family: str = "unknown" |
| arch_info: Dict[str, Any] = {} |
| |
| def __init__(self, model, config: Optional[Dict] = None): |
| self.model = model |
| self.config = config or {} |
| self._detect_attributes() |
| |
| def _detect_attributes(self): |
| """Resolve embedding, layer, norm, lm_head references.""" |
| arch = ARCH_LAYER_MAP.get(self.family, ARCH_LAYER_MAP["generic"]) |
| self.arch_info = arch.copy() |
| |
| |
| self._embed = _get_attr(self.model, arch["embed_attr"]) |
| self._layers = _get_attr(self.model, arch["layers_attr"]) |
| self._norm = _get_attr(self.model, arch["norm_attr"]) |
| self._lm_head = _get_attr(self.model, arch["lm_head_attr"]) |
| |
| |
| if self._embed is None: |
| for attr in ("embedding", "token_embedding", "embed_tokens", "wte"): |
| self._embed = getattr(self.model, attr, None) |
| if self._embed is not None: |
| break |
| |
| if self._layers is None: |
| self._layers = getattr(self.model, "layers", None) |
| |
| if self._norm is None: |
| self._norm = getattr(self.model, "norm", None) |
| |
| if self._lm_head is None: |
| self._lm_head = getattr(self.model, "lm_head", None) |
| |
| |
| |
    def build_prompt(self, tokenizer, prompt_text: str, enable_thinking: bool = False) -> mx.array:
        """Build prompt tokens from text."""
        messages = [{"role": "user", "content": prompt_text}]
        try:
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=enable_thinking,
            )
        except TypeError:
            # Templates without an `enable_thinking` kwarg raise TypeError.
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        tokens = tokenizer.encode(text, add_special_tokens=False)
        return mx.array(tokens, dtype=mx.uint32)

    def stop_token_ids(self, tokenizer) -> set[int]:
        """Get the set of stop token IDs."""
        # mlx_lm's TokenizerWrapper exposes `eos_token_ids`; plain HF
        # tokenizers only have `eos_token_id`, so accept either.
        eos = getattr(tokenizer, "eos_token_ids", None)
        if eos is None:
            eos = getattr(tokenizer, "eos_token_id", None)
        if isinstance(eos, int):
            return {eos}
        if isinstance(eos, (list, tuple, set)):
            return set(eos)
        return set()

    # ---- Embedding and LM head ----

    def embed_tokens(self, tokens: mx.array) -> mx.array:
        """Embed token IDs to hidden states."""
        if self._embed is None:
            raise RuntimeError(f"[{self.family}] Could not find embedding layer")
        return self._embed(tokens)

    def lm_head_logits(self, hidden_states: mx.array) -> mx.array:
        """Project hidden states to vocab logits."""
        if self._lm_head is not None:
            return self._lm_head(hidden_states)
        # Tied-embedding models reuse the input embedding as the output head.
        if self.arch_info.get("tie_embeddings") and self._embed is not None:
            if hasattr(self._embed, "as_linear"):
                return self._embed.as_linear(hidden_states)
        raise RuntimeError(f"[{self.family}] Could not find LM head")

    def lm_head_argmax(self, hidden_states: mx.array) -> mx.array:
        """Greedy next-token from hidden states."""
        logits = self.lm_head_logits(hidden_states)
        return mx.argmax(logits, axis=-1).astype(mx.uint32)

    # ---- Cache management ----

    def make_cache(self) -> list[Any]:
        """Create a fresh per-layer cache for the target model."""
        # Prefer mlx_lm's own factory when available: it respects each
        # model's `make_cache`, including hybrid layers that need non-KV
        # caches. Fall back to the registry's cache_type otherwise.
        if hasattr(cache_lib, "make_prompt_cache"):
            try:
                return cache_lib.make_prompt_cache(self.model)
            except Exception:
                pass

        cache_type = self.arch_info.get("cache_type", "KVCache")
        num_layers = len(self._layers) if self._layers is not None else 0

        if cache_type == "KVCache":
            return [cache_lib.KVCache() for _ in range(num_layers)]
        elif cache_type == "ArraysCache":
            return [cache_lib.ArraysCache() for _ in range(num_layers)]
        else:
            return [None for _ in range(num_layers)]

    def rewind_kv_caches(self, cache: list[Any], num_tokens: int) -> None:
        """Trim `num_tokens` tokens from the end of each layer cache.

        Called after verification rejects part of a draft block;
        `num_tokens` is the number of rejected positions, not the target
        prefix length.
        """
        for layer_cache in cache:
            if hasattr(layer_cache, "trim"):
                layer_cache.trim(num_tokens)
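
    # Rollback arithmetic (sketch): if the engine drafted a block of k
    # tokens but only m were accepted, the caches advanced by k and must
    # give back k - m positions:
    #
    #     drafted, accepted = 8, 5
    #     adapter.rewind_kv_caches(cache, drafted - accepted)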

    # ---- Attention masks and forward passes ----

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        """Build a causal attention mask appropriate for this architecture."""
        # Single-token decode steps need no mask. Multi-token steps
        # (prefill, draft blocks) get mlx_lm's standard causal mask; the
        # helper expects the per-layer cache list, hence the wrapping.
        if hidden_states.shape[1] <= 1:
            return None
        try:
            from mlx_lm.models.base import create_attention_mask

            return create_attention_mask(
                hidden_states, [cache] if cache is not None else None
            )
        except Exception:
            return None

    def forward_with_hidden_states(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
        output_rollback_records: bool = False,
    ) -> Tuple[mx.array, mx.array] | Tuple[mx.array, mx.array, Dict]:
        """
        Run the target model, returning (logits, target_hidden), where
        target_hidden is the concatenation of hidden states at layer_ids.

        Args:
            tokens: Input token IDs [bsz, seq_len]
            cache: Per-layer KV cache
            layer_ids: Target layer indices for DFlash conditioning
            output_rollback_records: Whether to also return per-layer state

        Returns:
            (logits, target_hidden) or (logits, target_hidden, rollback_records)
        """
        if self._embed is None or self._layers is None:
            raise RuntimeError(f"[{self.family}] Model attributes not resolved")

        hidden = self.embed_tokens(tokens)
        mask = self.create_attention_mask(hidden, cache[0] if cache else None)

        selected: List[mx.array] = []
        # Rollback is handled by rewinding the KV caches, so no per-layer
        # records are captured here; the dict stays empty.
        rollback_records: Dict[int, Dict[str, mx.array]] = {}
        target_layer_ids = set(layer_ids)

        for idx, (layer, layer_cache) in enumerate(zip(self._layers, cache)):
            layer_out = layer(hidden, mask=mask, cache=layer_cache)
            # Some layer implementations return (hidden, aux); keep hidden.
            hidden = layer_out[0] if isinstance(layer_out, tuple) else layer_out

            if idx in target_layer_ids:
                selected.append(hidden)

        if self._norm is not None:
            hidden = self._norm(hidden)
        logits = self.lm_head_logits(hidden)

        if selected:
            target_hidden = mx.concatenate(selected, axis=-1)
        else:
            # No layer matched: fall back to the final (normed) hidden state.
            target_hidden = hidden

        if output_rollback_records:
            return logits, target_hidden, rollback_records
        return logits, target_hidden
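
    # Shape sketch: with hidden size d and n = len(layer_ids), target_hidden
    # is [bsz, seq_len, n * d]: one d-wide slice per selected layer, in layer
    # order. A DFlash draft head consuming it must be sized to that width.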

    def forward_verifier_states(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
    ) -> Tuple[mx.array, mx.array, Dict]:
        """Forward pass that always returns rollback records."""
        return self.forward_with_hidden_states(
            tokens, cache, layer_ids, output_rollback_records=True
        )

    def forward_accept_all_block(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
    ) -> Tuple[mx.array, mx.array]:
        """Forward a token block, keeping only the last position's logits.

        Used on the accept-all path, where every draft token was verified
        and only the logits for the next position are needed.
        """
        logits, target_hidden = self.forward_with_hidden_states(
            tokens, cache, layer_ids, output_rollback_records=False
        )
        return logits[:, -1:, :], target_hidden

    # ---- Diagnostics ----

    def cache_summary(self, cache: list[Any]) -> str:
        """Human-readable cache status."""
        parts: List[str] = []
        for idx, c in enumerate(cache):
            if isinstance(c, cache_lib.KVCache):
                parts.append(f"{idx}:kv={c.offset}")
            elif isinstance(c, cache_lib.ArraysCache):
                # Slot 1 holds the recurrent state in the hybrid caches
                # this summary targets.
                rec = None if c[1] is None else tuple(c[1].shape)
                parts.append(f"{idx}:ssm={rec}")
            else:
                parts.append(f"{idx}:none")
        return " ".join(parts)


# ---------------------------------------------------------------------------
# Per-family adapters
# ---------------------------------------------------------------------------


class Qwen3Adapter(MLXTargetAdapter):
    family = "qwen3"

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        try:
            from mlx_lm.models import qwen3

            # qwen3 re-exports mlx_lm's base mask helper, which expects the
            # per-layer cache list rather than one layer's cache.
            return qwen3.create_attention_mask(
                hidden_states, [cache] if cache is not None else None
            )
        except Exception:
            return super().create_attention_mask(hidden_states, cache)

    def lm_head_logits(self, hidden_states: mx.array) -> mx.array:
        # Smaller Qwen3 checkpoints tie input/output embeddings, so prefer
        # the tied projection and fall back to an explicit lm_head.
        if self.arch_info.get("tie_embeddings") and self._embed is not None:
            if hasattr(self._embed, "as_linear"):
                return self._embed.as_linear(hidden_states)
        if self._lm_head is not None:
            return self._lm_head(hidden_states)
        raise RuntimeError("[qwen3] No LM head found")


class Qwen35Adapter(MLXTargetAdapter):
    family = "qwen3_5"

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        try:
            from mlx_lm.models import qwen3_5

            # Accept either the full cache list or a single layer cache;
            # hybrid caches expose `fa_idx` on their full-attention layers.
            layer_cache = cache[0] if isinstance(cache, list) and cache else cache
            if layer_cache is not None and hasattr(layer_cache, "fa_idx"):
                return qwen3_5.create_attention_mask(hidden_states, layer_cache)
        except Exception:
            pass
        return super().create_attention_mask(hidden_states, cache)

    def forward_with_hidden_states(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
        output_rollback_records: bool = False,
    ):
        # Qwen3.5 interleaves full-attention and linear-attention layers,
        # so the base forward is overridden to mask only the former.
        if self._embed is None or self._layers is None:
            raise RuntimeError("[qwen3_5] Model attributes not resolved")

        hidden = self.embed_tokens(tokens)

        # Mask for the full-attention layers; fall back to the generic
        # causal mask if the qwen3_5 module is unavailable.
        try:
            from mlx_lm.models import qwen3_5

            fa_mask = qwen3_5.create_attention_mask(hidden, cache[0] if cache else None)
        except Exception:
            fa_mask = super().create_attention_mask(hidden, cache[0] if cache else None)

        selected: List[mx.array] = []
        target_layer_ids = set(layer_ids)

        for idx, (layer, layer_cache) in enumerate(zip(self._layers, cache)):
            # Linear-attention layers run unmasked; full-attention layers
            # get the shared causal mask.
            mask = None if getattr(layer, "is_linear", False) else fa_mask

            layer_out = layer(hidden, mask=mask, cache=layer_cache)
            hidden = layer_out[0] if isinstance(layer_out, tuple) else layer_out

            if idx in target_layer_ids:
                selected.append(hidden)

        if self._norm is not None:
            hidden = self._norm(hidden)

        logits = self.lm_head_logits(hidden)

        target_hidden = mx.concatenate(selected, axis=-1) if selected else hidden

        if output_rollback_records:
            # No per-layer records are captured; rollback uses cache rewind.
            return logits, target_hidden, {}
        return logits, target_hidden


class LlamaAdapter(MLXTargetAdapter):
    family = "llama"

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        try:
            from mlx_lm.models import llama

            # The helper expects the per-layer cache list.
            return llama.create_attention_mask(
                hidden_states, [cache] if cache is not None else None
            )
        except Exception:
            return super().create_attention_mask(hidden_states, cache)


class MistralAdapter(MLXTargetAdapter):
    family = "mistral"

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        try:
            # mlx_lm serves Mistral checkpoints through its llama module
            # (there is no mlx_lm.models.mistral), so borrow the mask
            # helper from there.
            from mlx_lm.models import llama

            return llama.create_attention_mask(
                hidden_states, [cache] if cache is not None else None
            )
        except Exception:
            return super().create_attention_mask(hidden_states, cache)


class GemmaAdapter(MLXTargetAdapter):
    family = "gemma"

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        try:
            from mlx_lm.models import gemma

            return gemma.create_attention_mask(
                hidden_states, [cache] if cache is not None else None
            )
        except Exception:
            return super().create_attention_mask(hidden_states, cache)


# ---------------------------------------------------------------------------
# Adapter registry and dispatch
# ---------------------------------------------------------------------------

ADAPTERS: Dict[str, type[MLXTargetAdapter]] = {
| "qwen3": Qwen3Adapter, |
| "qwen2": Qwen3Adapter, |
| "qwen3_5": Qwen35Adapter, |
| "llama": LlamaAdapter, |
| "mistral": MistralAdapter, |
| "gemma": GemmaAdapter, |
| "gemma2": GemmaAdapter, |
| "generic": MLXTargetAdapter, |
| } |


def adapter_for_model_type(model_type: str) -> Optional[type[MLXTargetAdapter]]:
    """Get the adapter class for a model_type string."""
    if model_type in ADAPTERS:
        return ADAPTERS[model_type]

    # Prefix fallbacks; "qwen3_5" must be tested before "qwen3".
    if model_type.startswith("qwen3_5") or model_type == "qwen3.5":
        return Qwen35Adapter
    if model_type.startswith("qwen3"):
        return Qwen3Adapter
    if model_type.startswith("qwen2"):
        return Qwen3Adapter
    if model_type.startswith("llama"):
        return LlamaAdapter
    if model_type.startswith("mistral"):
        return MistralAdapter
    if model_type.startswith("gemma"):
        return GemmaAdapter
    return None
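
# Dispatch sketch: exact registry hits first, then prefixes, so variant
# model_type strings still land on a family adapter ("qwen3_moe" is just an
# illustrative string):
#
#     adapter_for_model_type("qwen3")      # -> Qwen3Adapter (exact)
#     adapter_for_model_type("qwen3_moe")  # -> Qwen3Adapter (prefix)
#     adapter_for_model_type("rwkv")       # -> None (load_target_model raises)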


# ---------------------------------------------------------------------------
# Loaded-model facade
# ---------------------------------------------------------------------------


@dataclass
class LoadedTargetModel:
    """A target model, its tokenizer, and the matching adapter.

    Every method is a thin delegation to the adapter, so callers never
    touch architecture-specific attributes directly.
    """

    requested_model: str
    resolved_model_path: Path
    model: Any
    tokenizer: Any
    adapter: MLXTargetAdapter

    def build_prompt(self, prompt_text: str, enable_thinking: bool = False) -> mx.array:
        return self.adapter.build_prompt(self.tokenizer, prompt_text, enable_thinking)

    def stop_token_ids(self) -> set[int]:
        return self.adapter.stop_token_ids(self.tokenizer)

    def make_cache(self) -> list[Any]:
        return self.adapter.make_cache()

    def embed_tokens(self, tokens: mx.array) -> mx.array:
        return self.adapter.embed_tokens(tokens)

    def lm_head_logits(self, hidden_states: mx.array) -> mx.array:
        return self.adapter.lm_head_logits(hidden_states)

    def lm_head_argmax(self, hidden_states: mx.array) -> mx.array:
        return self.adapter.lm_head_argmax(hidden_states)

    def forward_with_hidden_states(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
        output_rollback_records: bool = False,
    ):
        return self.adapter.forward_with_hidden_states(
            tokens, cache, layer_ids, output_rollback_records
        )

    def forward_verifier_states(self, tokens: mx.array, cache: list[Any], layer_ids: List[int]):
        return self.adapter.forward_verifier_states(tokens, cache, layer_ids)

    def forward_accept_all_block(self, tokens: mx.array, cache: list[Any], layer_ids: List[int]):
        return self.adapter.forward_accept_all_block(tokens, cache, layer_ids)

    def rewind_kv_caches(self, cache: list[Any], num_tokens: int) -> None:
        self.adapter.rewind_kv_caches(cache, num_tokens)

    def cache_summary(self, cache: list[Any]) -> str:
        return self.adapter.cache_summary(cache)


def load_target_model(path_or_repo: str) -> LoadedTargetModel:
    """Load an MLX target model with the correct adapter.

    Args:
        path_or_repo: Local path or HF Hub model ID

    Returns:
        LoadedTargetModel with an architecture-aware adapter
    """
    base_path = resolve_model_path(path_or_repo)

    # Read the checkpoint config to pick the adapter before loading weights.
    config_path = base_path / "config.json"
    if config_path.exists():
        with open(config_path, "r") as f:
            config = json.load(f)
    else:
        config = {}

    model_type = config.get("model_type", "generic")
    adapter_cls = adapter_for_model_type(model_type)

    if adapter_cls is None:
        registered = ", ".join(sorted(ADAPTERS.keys()))
        raise NotImplementedError(
            f"Unsupported MLX DFlash target model_type={model_type!r}. "
            f"Registered adapters: {registered}. "
            f"You can add one by subclassing MLXTargetAdapter in adapters.py."
        )

    # Load weights and tokenizer via mlx_lm, then bind the adapter.
    model, tokenizer = load(str(base_path))
    adapter = adapter_cls(model, config)

    return LoadedTargetModel(
        requested_model=path_or_repo,
        resolved_model_path=base_path,
        model=model,
        tokenizer=tokenizer,
        adapter=adapter,
    )
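

if __name__ == "__main__":
    # Minimal smoke test (sketch): load a model, run one prefill step, and
    # print the cache summary. The default model ID and the layer_ids are
    # examples only; substitute any MLX checkpoint and the layer indices
    # your DFlash draft head was trained against.
    import sys

    model_id = sys.argv[1] if len(sys.argv) > 1 else "mlx-community/Qwen3-0.6B-4bit"
    target = load_target_model(model_id)
    prompt = target.build_prompt("Briefly explain speculative decoding.")
    cache = target.make_cache()
    logits, hidden = target.forward_with_hidden_states(
        prompt[None], cache, layer_ids=[4, 12, 20]
    )
    print("logits:", logits.shape, "target_hidden:", hidden.shape)
    print("cache:", target.cache_summary(cache))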