# tritesh's picture
# Upload dflash_mlx/adapters.py
# bb76689 verified
"""
Universal architecture adapters for DFlash speculative decoding on MLX.
Supports: Qwen3, Qwen3.5, LLaMA (2/3), Mistral, Gemma, and generic transformers.
Inspired by Aryagm's adapter pattern and bstnxbt's per-family engine approach.
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Tuple, List, Dict
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx_lm import load
from mlx_lm.models import cache as cache_lib
# ──────────────────────────────────────────────────────────────────────────────
# Architecture registry — maps model_type → attribute-path layout for that family
# ──────────────────────────────────────────────────────────────────────────────
def _decoder_layout(model_type: str, tie_embeddings: bool, **overrides: Any) -> Dict[str, Any]:
    """Return one ARCH_LAYER_MAP entry.

    Starts from the ``model.*`` attribute layout shared by most families in
    this registry and applies ``overrides`` on top (later keys win).
    """
    entry: Dict[str, Any] = {
        "layers_attr": "model.layers",
        "embed_attr": "model.embed_tokens",
        "norm_attr": "model.norm",
        "lm_head_attr": "lm_head",
        "cache_type": "KVCache",
        "make_cache_fn": "make_cache",
        "tie_embeddings": tie_embeddings,
        "model_type": model_type,
    }
    entry.update(overrides)
    return entry


# Per-family attribute paths (dot-separated, resolved against the loaded model),
# cache kind, cache-factory method name and default weight-tying flag.
ARCH_LAYER_MAP: Dict[str, Dict[str, Any]] = {
    "qwen3": _decoder_layout("qwen3", True),
    "qwen2": _decoder_layout("qwen2", True),
    # Qwen3.5 nests the text decoder under `language_model` and mixes
    # full-attention with linear-attention layers.
    "qwen3_5": _decoder_layout(
        "qwen3_5",
        True,
        layers_attr="language_model.model.layers",
        embed_attr="language_model.model.embed_tokens",
        norm_attr="language_model.model.norm",
        lm_head_attr="language_model.lm_head",
        cache_type="ArraysCache",
        has_hybrid_attention=True,
        has_linear_attention=True,
    ),
    "llama": _decoder_layout("llama", False),
    "mistral": _decoder_layout("mistral", False),
    "gemma": _decoder_layout("gemma", True, norm_eps=1e-6),
    "gemma2": _decoder_layout("gemma2", True, norm_eps=1e-6),
    # Last-resort layout: attributes live directly on the top-level model.
    "generic": _decoder_layout(
        "generic",
        False,
        layers_attr="layers",
        embed_attr="embedding",
        norm_attr="norm",
        make_cache_fn=None,
    ),
}
def resolve_model_path(path_or_repo: str) -> Path:
    """Map a local path or a Hugging Face Hub repo id to a local directory.

    An existing filesystem path is returned unchanged; anything else is treated
    as a Hub repo id and fetched with ``snapshot_download``.
    """
    candidate = Path(path_or_repo)
    if not candidate.exists():
        candidate = Path(snapshot_download(path_or_repo))
    return candidate
def _get_attr(obj: Any, attr_path: str) -> Any:
    """Follow a dotted attribute path such as ``'language_model.model.layers'``.

    Returns ``None`` as soon as any hop is missing (or is ``None``) instead of
    raising ``AttributeError``.
    """
    current = obj
    for name in attr_path.split("."):
        if current is None:
            break
        current = getattr(current, name, None)
    return current
def detect_model_architecture(model, config: Optional[Dict] = None) -> str:
    """Guess the ARCH_LAYER_MAP key for ``model``.

    The ``model_type`` from ``config`` is used first. When ``config`` is None,
    it is read from ``model.config`` (via ``to_dict()`` or its ``model_type``
    attribute). Known families are matched exactly, then by prefix. If no usable
    string ``model_type`` is available, the module structure decides:
      - a ``language_model`` attribute means ``"qwen3_5"``;
      - ``model.model.layers`` means ``"llama"`` (shared by llama/qwen/mistral);
      - anything else is ``"generic"``.

    Args:
        model: The loaded model object.
        config: Optional parsed config dict.

    Returns:
        A key of ``ARCH_LAYER_MAP``.
    """
    if config is None:
        model_config = getattr(model, "config", None)
        if hasattr(model_config, "to_dict"):
            config = model_config.to_dict()
        elif hasattr(model_config, "model_type"):
            config = {"model_type": model_config.model_type}

    model_type = config.get("model_type") if config else None
    # Guard against `"model_type": null` (or other non-strings) in a config.
    if isinstance(model_type, str):
        if model_type in ARCH_LAYER_MAP:
            return model_type
        # "qwen3_5"/"qwen3.5" must be tested before the plain "qwen3" prefix.
        if model_type.startswith("qwen3_5") or model_type == "qwen3.5":
            return "qwen3_5"
        for family in ("qwen3", "qwen2", "llama", "mistral", "gemma"):
            if model_type.startswith(family):
                return family

    # Structural detection.
    if hasattr(model, "language_model"):
        return "qwen3_5"
    if hasattr(getattr(model, "model", None), "layers"):
        return "llama"  # llama, qwen3, mistral all share this layout
    return "generic"
# ──────────────────────────────────────────────────────────────────────────────
# Base adapter class — defines the contract all adapters must implement
# ──────────────────────────────────────────────────────────────────────────────
class MLXTargetAdapter:
    """Base adapter for DFlash target model interaction.

    Every supported architecture needs an adapter that knows:
    - Where embeddings live
    - How to iterate layers and extract hidden states
    - How to create/manage per-layer caches
    - How to call the LM head
    - How to trim/rewind caches on rejection

    Subclasses set ``family`` to a key of ``ARCH_LAYER_MAP`` and override the
    hooks whose default behaviour does not fit their architecture.
    """

    family: str = "unknown"
    # Class-level placeholder only; _detect_attributes() assigns a per-instance copy.
    arch_info: Dict[str, Any] = {}

    def __init__(self, model, config: Optional[Dict] = None):
        """Bind ``model`` and resolve its submodules.

        Args:
            model: The loaded target model.
            config: Parsed ``config.json`` of the checkpoint, if available.
        """
        self.model = model
        self.config = config or {}
        self._detect_attributes()

    def _detect_attributes(self):
        """Resolve embedding, layer, norm and lm_head references on ``self.model``.

        The dotted paths come from ``ARCH_LAYER_MAP[self.family]``, falling back
        to the ``generic`` entry. A few common names on the top-level model are
        then probed for anything still missing. Unresolved parts stay ``None``.
        """
        arch = ARCH_LAYER_MAP.get(self.family, ARCH_LAYER_MAP["generic"])
        self.arch_info = arch.copy()
        # The checkpoint's own config decides whether embeddings are tied; the
        # per-family default in ARCH_LAYER_MAP is only a fallback guess.
        tied = self.config.get("tie_word_embeddings")
        if tied is not None:
            self.arch_info["tie_embeddings"] = bool(tied)

        # Try exact paths first.
        self._embed = _get_attr(self.model, arch["embed_attr"])
        self._layers = _get_attr(self.model, arch["layers_attr"])
        self._norm = _get_attr(self.model, arch["norm_attr"])
        self._lm_head = _get_attr(self.model, arch["lm_head_attr"])

        # Fallback: probe common locations on the top-level model.
        if self._embed is None:
            for attr in ("embedding", "token_embedding", "embed_tokens", "wte"):
                self._embed = getattr(self.model, attr, None)
                if self._embed is not None:
                    break
        if self._layers is None:
            self._layers = getattr(self.model, "layers", None)
        if self._norm is None:
            self._norm = getattr(self.model, "norm", None)
        if self._lm_head is None:
            self._lm_head = getattr(self.model, "lm_head", None)

    # ── Tokenization / Prompt ───────────────────────────────────────────────

    def build_prompt(self, tokenizer, prompt_text: str, enable_thinking: bool = False) -> mx.array:
        """Wrap ``prompt_text`` as a single user turn and return its token ids (uint32).

        ``enable_thinking`` is forwarded to the chat template. Templates that
        reject that keyword (TypeError) are retried without it.
        """
        messages = [{"role": "user", "content": prompt_text}]
        template_kwargs = {"tokenize": False, "add_generation_prompt": True}
        try:
            text = tokenizer.apply_chat_template(
                messages, enable_thinking=enable_thinking, **template_kwargs
            )
        except TypeError:
            text = tokenizer.apply_chat_template(messages, **template_kwargs)
        tokens = tokenizer.encode(text, add_special_tokens=False)
        return mx.array(tokens, dtype=mx.uint32)

    def stop_token_ids(self, tokenizer) -> set[int]:
        """Return the set of end-of-sequence token ids for ``tokenizer``.

        Reads ``tokenizer.eos_token_ids`` (an int or a list/tuple/set of ints).
        If that attribute is missing or None, falls back to
        ``tokenizer.eos_token_id``. Returns an empty set when neither is usable.
        """
        eos = getattr(tokenizer, "eos_token_ids", None)
        if eos is None:
            eos = getattr(tokenizer, "eos_token_id", None)
        if isinstance(eos, int):
            return {eos}
        # Set-valued ids must be accepted too; treating them as "unknown" would
        # return no stop tokens at all.
        if isinstance(eos, (list, tuple, set, frozenset)):
            return set(eos)
        return set()

    # ── Embeddings / LM Head ────────────────────────────────────────────────

    def embed_tokens(self, tokens: mx.array) -> mx.array:
        """Look up the raw token embeddings for ``tokens``.

        Raises:
            RuntimeError: If no embedding module was resolved.
        """
        if self._embed is None:
            raise RuntimeError(f"[{self.family}] Could not find embedding layer")
        return self._embed(tokens)

    def lm_head_logits(self, hidden_states: mx.array) -> mx.array:
        """Project hidden states to vocab logits.

        Uses the resolved ``lm_head`` when present. Otherwise, when embeddings
        are tied, uses the embedding's ``as_linear`` projection.

        Raises:
            RuntimeError: If neither projection is available.
        """
        if self._lm_head is not None:
            return self._lm_head(hidden_states)
        if self.arch_info.get("tie_embeddings") and self._embed is not None:
            if hasattr(self._embed, "as_linear"):
                return self._embed.as_linear(hidden_states)
        raise RuntimeError(f"[{self.family}] Could not find LM head")

    def lm_head_argmax(self, hidden_states: mx.array) -> mx.array:
        """Greedy next-token ids (uint32) from hidden states."""
        logits = self.lm_head_logits(hidden_states)
        return mx.argmax(logits, axis=-1).astype(mx.uint32)

    # ── Cache Management ──────────────────────────────────────────────────────

    def make_cache(self) -> list[Any]:
        """Create a fresh per-layer cache list.

        Prefers the model's own cache factory (the method named by
        ``arch_info["make_cache_fn"]``). Only the model knows which layers need
        which cache kind (e.g. hybrid or sliding-window layers). Without such a
        factory, one cache of ``arch_info["cache_type"]`` is built per layer.
        """
        factory_name = self.arch_info.get("make_cache_fn")
        factory = getattr(self.model, factory_name, None) if factory_name else None
        if callable(factory):
            return list(factory())

        num_layers = len(self._layers) if self._layers is not None else 0
        cache_type = self.arch_info.get("cache_type", "KVCache")
        if cache_type == "KVCache":
            return [cache_lib.KVCache() for _ in range(num_layers)]
        if cache_type == "ArraysCache":
            # NOTE(review): recent mlx_lm ArraysCache takes a required `size`
            # argument; this fallback only works where it does not. The model
            # factory above is the expected path for ArraysCache models.
            return [cache_lib.ArraysCache() for _ in range(num_layers)]
        return [None for _ in range(num_layers)]

    def rewind_kv_caches(self, cache: list[Any], num_tokens: int) -> None:
        """Call ``trim(num_tokens)`` on every layer cache that supports it.

        For mlx_lm's KVCache, ``trim(n)`` drops the most recent ``n`` positions.
        Caches that report ``is_trimmable() == False`` are skipped, as are
        entries without ``trim``. Recurrent / linear-attention state cannot be
        rewound this way, and callers must restore it separately. A
        non-positive ``num_tokens`` is a no-op; a negative value would otherwise
        grow the cache offset.
        """
        if num_tokens <= 0:
            return
        for layer_cache in cache:
            trim = getattr(layer_cache, "trim", None)
            if trim is None:
                continue
            is_trimmable = getattr(layer_cache, "is_trimmable", None)
            if callable(is_trimmable) and not is_trimmable():
                continue
            trim(num_tokens)

    # ── Forward with Hidden-State Extraction ─────────────────────────────────

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        """Build the causal attention mask for a forward pass.

        Args:
            hidden_states: Embedded input; ``shape[1]`` is the number of new
                positions.
            cache: A single layer cache (the forward pass passes the first
                layer's cache), or ``None``.

        Returns:
            ``None`` for a single new position, since it may attend to
            everything cached. Otherwise returns the result of mlx_lm's shared
            ``mlx_lm.models.base.create_attention_mask``, or ``None`` when that
            helper is unavailable or fails. A ``None`` mask applies no causal
            masking to multi-token inputs.
        """
        if hidden_states.shape[1] == 1:
            return None
        try:
            from mlx_lm.models.base import create_attention_mask as build_mask

            return build_mask(hidden_states, cache)
        except Exception:
            # Best effort, like the per-family adapters: fall back to no mask.
            return None

    def forward_with_hidden_states(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
        output_rollback_records: bool = False,
    ) -> Tuple[mx.array, mx.array] | Tuple[mx.array, mx.array, Dict]:
        """Run the target model, returning ``(logits, target_hidden)``.

        ``target_hidden`` concatenates, along the last axis, the outputs of the
        layers listed in ``layer_ids``. The order follows ``layer_ids``;
        duplicates and indices outside the model are ignored. If no listed layer
        exists, the final (post-norm) hidden state is returned instead.

        Args:
            tokens: Input token ids; assumed ``[bsz, seq_len]``.
            cache: Per-layer caches, one per decoder layer, or ``None`` for an
                uncached pass.
            layer_ids: Target layer indices for DFlash conditioning.
            output_rollback_records: Also return a rollback-record dict. The
                base class records nothing, so the dict is always empty.

        Returns:
            ``(logits, target_hidden)`` or ``(logits, target_hidden, rollback_records)``.

        Raises:
            RuntimeError: If embeddings or layers were not resolved.
            ValueError: If ``cache`` does not have one entry per layer.
        """
        if self._embed is None or self._layers is None:
            raise RuntimeError(f"[{self.family}] Model attributes not resolved")
        layers = list(self._layers)
        layer_caches = [None] * len(layers) if cache is None else list(cache)
        if len(layer_caches) != len(layers):
            # zip() would otherwise silently skip the trailing layers.
            raise ValueError(
                f"[{self.family}] cache has {len(layer_caches)} entries "
                f"but the model has {len(layers)} layers"
            )

        hidden = self.embed_tokens(tokens)
        mask = self.create_attention_mask(hidden, layer_caches[0] if layer_caches else None)

        wanted = set(layer_ids)
        captured: Dict[int, mx.array] = {}
        for idx, (layer, layer_cache) in enumerate(zip(layers, layer_caches)):
            # Some layers return (hidden, extra); others return hidden only.
            layer_out = layer(hidden, mask=mask, cache=layer_cache)
            hidden = layer_out[0] if isinstance(layer_out, tuple) else layer_out
            if idx in wanted:
                captured[idx] = hidden

        if self._norm is not None:
            hidden = self._norm(hidden)
        logits = self.lm_head_logits(hidden)

        # Concatenate in the caller's layer_ids order (first occurrence wins).
        ordered = [captured[i] for i in dict.fromkeys(layer_ids) if i in captured]
        target_hidden = mx.concatenate(ordered, axis=-1) if ordered else hidden

        if output_rollback_records:
            return logits, target_hidden, {}
        return logits, target_hidden

    def forward_verifier_states(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
    ) -> Tuple[mx.array, mx.array, Dict]:
        """Forward pass that always returns ``(logits, target_hidden, rollback_records)``."""
        return self.forward_with_hidden_states(
            tokens, cache, layer_ids, output_rollback_records=True
        )

    def forward_accept_all_block(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
    ) -> Tuple[mx.array, mx.array]:
        """Forward ``tokens`` and return the last position's logits plus target hidden.

        Logits are sliced to the final position (``[:, -1:, :]``), while
        ``target_hidden`` covers every input position.
        """
        logits, target_hidden = self.forward_with_hidden_states(
            tokens, cache, layer_ids, output_rollback_records=False
        )
        return logits[:, -1:, :], target_hidden

    # ── Cache Summary (for debugging) ───────────────────────────────────────

    def cache_summary(self, cache: list[Any]) -> str:
        """Return a one-line, human-readable description of ``cache``.

        Each entry has one of these forms:
          - ``<layer>:ssm=<shape>`` for ``ArraysCache`` (shape of slot 1, or None);
          - ``<layer>:kv=<offset>`` for any other cache exposing ``offset``;
          - ``<layer>:none`` otherwise.
        """
        arrays_cache_cls = getattr(cache_lib, "ArraysCache", None)
        parts: List[str] = []
        for idx, layer_cache in enumerate(cache):
            if arrays_cache_cls is not None and isinstance(layer_cache, arrays_cache_cls):
                try:
                    state = layer_cache[1]
                except IndexError:
                    state = None
                shape = None if state is None else tuple(state.shape)
                parts.append(f"{idx}:ssm={shape}")
            elif layer_cache is not None and hasattr(layer_cache, "offset"):
                parts.append(f"{idx}:kv={layer_cache.offset}")
            else:
                parts.append(f"{idx}:none")
        return " ".join(parts)
# ──────────────────────────────────────────────────────────────────────────────
# Per-family adapter subclasses (for architecture-specific overrides)
# ──────────────────────────────────────────────────────────────────────────────
class Qwen3Adapter(MLXTargetAdapter):
    """Adapter for Qwen3 targets (also registered for Qwen2, which shares the layout).

    Prompt building is inherited unchanged from the base class, which already
    forwards ``enable_thinking`` to the chat template.
    """

    family = "qwen3"

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        """Delegate to mlx_lm's qwen3 mask helper, falling back to the base class."""
        try:
            from mlx_lm.models import qwen3
            return qwen3.create_attention_mask(hidden_states, cache)
        except Exception:
            return super().create_attention_mask(hidden_states, cache)

    def lm_head_logits(self, hidden_states: mx.array) -> mx.array:
        """Project hidden states to vocab logits.

        A resolved ``lm_head`` always wins. The tied embedding matrix
        (``as_linear``) is used only when no ``lm_head`` exists and embeddings
        are tied. Using the embedding first would give wrong logits for
        checkpoints that ship a separate, untied ``lm_head``.

        Raises:
            RuntimeError: If neither projection is available.
        """
        if self._lm_head is not None:
            return self._lm_head(hidden_states)
        if self.arch_info.get("tie_embeddings") and self._embed is not None:
            if hasattr(self._embed, "as_linear"):
                return self._embed.as_linear(hidden_states)
        raise RuntimeError("[qwen3] No LM head found")
class Qwen35Adapter(MLXTargetAdapter):
    """Adapter for Qwen3.5 hybrid targets (full attention + linear attention).

    Layers whose ``is_linear`` attribute is truthy are treated as
    linear-attention layers and receive no mask. Every other layer receives the
    causal mask built against the cache of the first such full-attention
    layer. Prompt building is inherited from the base class.
    """

    family = "qwen3_5"

    def _first_full_attention_index(self) -> Optional[int]:
        """Index of the first layer not flagged ``is_linear``, or None if none exists."""
        for idx, layer in enumerate(self._layers or []):
            if not getattr(layer, "is_linear", False):
                return idx
        return None

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        """Build the causal mask used by the full-attention layers.

        Args:
            hidden_states: Embedded input.
            cache: Either the whole per-layer cache list or a single layer
                cache. For a list, the entry of the first full-attention layer
                is used. Layer 0 may be a linear-attention layer whose cache
                carries no attention offset.
        """
        import importlib

        layer_cache = cache
        if isinstance(cache, (list, tuple)):
            fa_idx = self._first_full_attention_index()
            in_range = fa_idx is not None and fa_idx < len(cache)
            layer_cache = cache[fa_idx] if in_range else None
        # Pass arguments positionally: the helper's parameter names are not
        # `hidden_states=`/`cache=` in every mlx_lm version.
        for module_name in ("mlx_lm.models.qwen3_5", "mlx_lm.models.base"):
            try:
                module = importlib.import_module(module_name)
                return module.create_attention_mask(hidden_states, layer_cache)
            except Exception:
                continue
        return super().create_attention_mask(hidden_states, layer_cache)

    def forward_with_hidden_states(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
        output_rollback_records: bool = False,
    ):
        """Run the hybrid target model, returning ``(logits, target_hidden)``.

        Behaves like the base implementation, except that linear-attention
        layers get ``mask=None``. ``target_hidden`` concatenates the listed
        layers' outputs in ``layer_ids`` order, or falls back to the final
        normed hidden state.

        NOTE(review): rollback records are not implemented (always ``{}``).
        Linear-attention state cannot be trimmed like a KV cache, so rejected
        draft tokens are not undone in those layers by this method.

        Raises:
            RuntimeError: If embeddings or layers were not resolved.
            ValueError: If ``cache`` does not have one entry per layer.
        """
        if self._embed is None or self._layers is None:
            raise RuntimeError("[qwen3_5] Model attributes not resolved")
        layers = list(self._layers)
        layer_caches = [None] * len(layers) if cache is None else list(cache)
        if len(layer_caches) != len(layers):
            raise ValueError(
                f"[qwen3_5] cache has {len(layer_caches)} entries "
                f"but the model has {len(layers)} layers"
            )

        hidden = self.embed_tokens(tokens)
        fa_mask = self.create_attention_mask(hidden, layer_caches)

        wanted = set(layer_ids)
        captured: Dict[int, mx.array] = {}
        for idx, (layer, layer_cache) in enumerate(zip(layers, layer_caches)):
            # Linear-attention layers take no attention mask.
            mask = None if getattr(layer, "is_linear", False) else fa_mask
            layer_out = layer(hidden, mask=mask, cache=layer_cache)
            hidden = layer_out[0] if isinstance(layer_out, tuple) else layer_out
            if idx in wanted:
                captured[idx] = hidden

        if self._norm is not None:
            hidden = self._norm(hidden)
        logits = self.lm_head_logits(hidden)

        ordered = [captured[i] for i in dict.fromkeys(layer_ids) if i in captured]
        target_hidden = mx.concatenate(ordered, axis=-1) if ordered else hidden

        if output_rollback_records:
            return logits, target_hidden, {}
        return logits, target_hidden
class LlamaAdapter(MLXTargetAdapter):
    """Adapter for Llama-family targets; only the mask builder is specialised."""

    family = "llama"

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        """Use mlx_lm's llama mask helper; on any failure defer to the base class."""
        try:
            from mlx_lm.models import llama as llama_impl
            mask = llama_impl.create_attention_mask(hidden_states, cache)
        except Exception:
            mask = super().create_attention_mask(hidden_states, cache)
        return mask
class MistralAdapter(MLXTargetAdapter):
    """Adapter for Mistral targets.

    mlx_lm serves ``mistral`` checkpoints through its llama implementation
    (model-type remapping), so there is no ``mlx_lm.models.mistral`` module to
    take the mask helper from. Importing it would always fail and silently
    leave multi-token passes without a mask. The llama module's helper is used
    instead.
    """

    family = "mistral"

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        """Build the causal mask via mlx_lm's llama helper, falling back to the base class."""
        try:
            from mlx_lm.models import llama
            return llama.create_attention_mask(hidden_states, cache)
        except Exception:
            return super().create_attention_mask(hidden_states, cache)
class GemmaAdapter(MLXTargetAdapter):
    """Adapter for Gemma / Gemma 2 targets.

    Gemma differs from the Llama-style default in two places:
      - the embedded input is multiplied by sqrt(hidden_size) before the first
        decoder layer;
      - Gemma 2 soft-caps final logits as ``tanh(logits / c) * c``, with ``c``
        taken from ``final_logit_softcapping`` in config.json.
    ``embed_tokens`` still returns the raw, unscaled embedding. The scaling is
    applied only inside the forward pass, in case other callers reuse the
    embedding table directly.
    """

    family = "gemma"

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        """Delegate to mlx_lm's gemma mask helper, falling back to the base class."""
        try:
            from mlx_lm.models import gemma
            return gemma.create_attention_mask(hidden_states, cache)
        except Exception:
            return super().create_attention_mask(hidden_states, cache)

    def _embedding_scale(self) -> Optional[float]:
        """Return sqrt(hidden_size), or None if hidden_size cannot be found.

        Reads ``config["hidden_size"]`` first, then an ``args.hidden_size`` on
        the model or its inner ``model`` attribute.
        """
        hidden_size = self.config.get("hidden_size")
        if hidden_size is None:
            for owner in (self.model, getattr(self.model, "model", None)):
                hidden_size = getattr(getattr(owner, "args", None), "hidden_size", None)
                if hidden_size is not None:
                    break
        return None if hidden_size is None else float(hidden_size) ** 0.5

    def lm_head_logits(self, hidden_states: mx.array) -> mx.array:
        """Project to vocab logits, applying Gemma 2 final-logit soft-capping if configured.

        Soft-capping is monotonic, so ``lm_head_argmax`` is unaffected; it
        matters for callers that use the logit values themselves.
        """
        logits = super().lm_head_logits(hidden_states)
        cap = self.config.get("final_logit_softcapping")
        if cap:
            logits = mx.tanh(logits / cap) * cap
        return logits

    def forward_with_hidden_states(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
        output_rollback_records: bool = False,
    ):
        """Gemma forward pass returning ``(logits, target_hidden)``.

        Same contract as the base implementation, except that the embedded
        input is scaled by ``_embedding_scale()`` (when known) before the
        first layer. Rollback records, when requested, are always ``{}``.

        Raises:
            RuntimeError: If embeddings or layers were not resolved.
            ValueError: If ``cache`` does not have one entry per layer.
        """
        if self._embed is None or self._layers is None:
            raise RuntimeError(f"[{self.family}] Model attributes not resolved")
        layers = list(self._layers)
        layer_caches = [None] * len(layers) if cache is None else list(cache)
        if len(layer_caches) != len(layers):
            raise ValueError(
                f"[{self.family}] cache has {len(layer_caches)} entries "
                f"but the model has {len(layers)} layers"
            )

        hidden = self.embed_tokens(tokens)
        scale = self._embedding_scale()
        if scale is not None:
            hidden = hidden * scale
        mask = self.create_attention_mask(hidden, layer_caches[0] if layer_caches else None)

        wanted = set(layer_ids)
        captured: Dict[int, mx.array] = {}
        for idx, (layer, layer_cache) in enumerate(zip(layers, layer_caches)):
            layer_out = layer(hidden, mask=mask, cache=layer_cache)
            hidden = layer_out[0] if isinstance(layer_out, tuple) else layer_out
            if idx in wanted:
                captured[idx] = hidden

        if self._norm is not None:
            hidden = self._norm(hidden)
        logits = self.lm_head_logits(hidden)

        ordered = [captured[i] for i in dict.fromkeys(layer_ids) if i in captured]
        target_hidden = mx.concatenate(ordered, axis=-1) if ordered else hidden

        if output_rollback_records:
            return logits, target_hidden, {}
        return logits, target_hidden
# ──────────────────────────────────────────────────────────────────────────────
# Adapter registry and factory
# ──────────────────────────────────────────────────────────────────────────────
# Exact model_type -> adapter class. Qwen2 shares Qwen3's module layout, and
# Gemma 2 reuses the Gemma adapter.
ADAPTERS: Dict[str, type[MLXTargetAdapter]] = dict(
    qwen3=Qwen3Adapter,
    qwen2=Qwen3Adapter,
    qwen3_5=Qwen35Adapter,
    llama=LlamaAdapter,
    mistral=MistralAdapter,
    gemma=GemmaAdapter,
    gemma2=GemmaAdapter,
    generic=MLXTargetAdapter,
)
def adapter_for_model_type(model_type: str) -> Optional[type[MLXTargetAdapter]]:
    """Return the adapter class for ``model_type``, or ``None`` if unsupported.

    Exact keys of ``ADAPTERS`` win. ``"qwen3.5"`` is an explicit alias.
    Otherwise the first matching prefix in the table below is used; its order
    matters, because ``qwen3_5`` must be tested before ``qwen3``.
    """
    exact = ADAPTERS.get(model_type)
    if exact is not None:
        return exact
    if model_type == "qwen3.5":
        return Qwen35Adapter
    prefix_table = (
        ("qwen3_5", Qwen35Adapter),
        ("qwen3", Qwen3Adapter),
        ("qwen2", Qwen3Adapter),
        ("llama", LlamaAdapter),
        ("mistral", MistralAdapter),
        ("gemma", GemmaAdapter),
    )
    return next(
        (adapter_cls for prefix, adapter_cls in prefix_table if model_type.startswith(prefix)),
        None,
    )
# ──────────────────────────────────────────────────────────────────────────────
# LoadedTargetModel — convenience wrapper binding model + adapter + tokenizer
# ──────────────────────────────────────────────────────────────────────────────
@dataclass
class LoadedTargetModel:
    """A loaded target model bundled with its tokenizer and adapter.

    Every method forwards to ``adapter``, supplying ``tokenizer`` where the
    adapter needs it, so callers never have to touch the adapter directly.
    """

    requested_model: str  # the path or repo id the caller asked for
    resolved_model_path: Path  # local directory the weights were loaded from
    model: Any
    tokenizer: Any
    adapter: MLXTargetAdapter

    def build_prompt(self, prompt_text: str, enable_thinking: bool = False) -> mx.array:
        """Chat-template ``prompt_text`` with this tokenizer and return its token ids."""
        adapter, tokenizer = self.adapter, self.tokenizer
        return adapter.build_prompt(tokenizer, prompt_text, enable_thinking)

    def stop_token_ids(self) -> set[int]:
        """End-of-sequence ids of this tokenizer."""
        adapter, tokenizer = self.adapter, self.tokenizer
        return adapter.stop_token_ids(tokenizer)

    def make_cache(self) -> list[Any]:
        """Fresh per-layer cache list for the target."""
        adapter = self.adapter
        return adapter.make_cache()

    def embed_tokens(self, tokens: mx.array) -> mx.array:
        """Raw token embeddings of the target."""
        adapter = self.adapter
        return adapter.embed_tokens(tokens)

    def lm_head_logits(self, hidden_states: mx.array) -> mx.array:
        """Vocab logits for ``hidden_states``."""
        adapter = self.adapter
        return adapter.lm_head_logits(hidden_states)

    def lm_head_argmax(self, hidden_states: mx.array) -> mx.array:
        """Greedy next-token ids for ``hidden_states``."""
        adapter = self.adapter
        return adapter.lm_head_argmax(hidden_states)

    def forward_with_hidden_states(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
        output_rollback_records: bool = False,
    ):
        """Target forward pass returning logits plus the selected layers' hidden states."""
        adapter = self.adapter
        return adapter.forward_with_hidden_states(
            tokens, cache, layer_ids, output_rollback_records
        )

    def forward_verifier_states(self, tokens: mx.array, cache: list[Any], layer_ids: List[int]):
        """Target forward pass that also returns rollback records."""
        adapter = self.adapter
        return adapter.forward_verifier_states(tokens, cache, layer_ids)

    def forward_accept_all_block(self, tokens: mx.array, cache: list[Any], layer_ids: List[int]):
        """Target forward pass returning last-position logits plus target hidden."""
        adapter = self.adapter
        return adapter.forward_accept_all_block(tokens, cache, layer_ids)

    def rewind_kv_caches(self, cache: list[Any], num_tokens: int) -> None:
        """Trim ``num_tokens`` from the per-layer caches via the adapter."""
        adapter = self.adapter
        adapter.rewind_kv_caches(cache, num_tokens)

    def cache_summary(self, cache: list[Any]) -> str:
        """Human-readable description of ``cache``."""
        adapter = self.adapter
        return adapter.cache_summary(cache)
def load_target_model(path_or_repo: str) -> LoadedTargetModel:
    """Load an MLX target model with the correct adapter.

    Args:
        path_or_repo: Local path or HF Hub model ID.

    Returns:
        LoadedTargetModel binding the model, tokenizer and architecture-aware adapter.

    Raises:
        NotImplementedError: If no adapter matches the checkpoint's ``model_type``.
            This is checked before the weights are loaded.
    """
    base_path = resolve_model_path(path_or_repo)

    # config.json drives adapter selection. A missing file, a non-object top
    # level, or a missing/null/non-string model_type all mean "generic".
    config: Dict[str, Any] = {}
    config_path = base_path / "config.json"
    if config_path.exists():
        with config_path.open("r", encoding="utf-8") as f:
            loaded = json.load(f)
        if isinstance(loaded, dict):
            config = loaded
    model_type = config.get("model_type")
    if not isinstance(model_type, str) or not model_type:
        model_type = "generic"

    adapter_cls = adapter_for_model_type(model_type)
    if adapter_cls is None:
        registered = ", ".join(sorted(ADAPTERS.keys()))
        raise NotImplementedError(
            f"Unsupported MLX DFlash target model_type={model_type!r}. "
            f"Registered adapters: {registered}. "
            f"You can add one by subclassing MLXTargetAdapter in adapters.py."
        )

    # Load model + tokenizer via mlx_lm, then bind the adapter to them.
    model, tokenizer = load(str(base_path))
    adapter = adapter_cls(model, config)
    return LoadedTargetModel(
        requested_model=path_or_repo,
        resolved_model_path=base_path,
        model=model,
        tokenizer=tokenizer,
        adapter=adapter,
    )