"""
Universal architecture adapters for DFlash speculative decoding on MLX.
Supports: Qwen2/Qwen3, Qwen3.5, LLaMA (2/3), Mistral, Gemma (1/2), and generic transformers.
Inspired by Aryagm's adapter pattern and bstnxbt's per-family engine approach.
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Tuple, List, Dict
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx_lm import load
from mlx_lm.models import cache as cache_lib
# ──────────────────────────────────────────────────────────────────────────────
# Architecture registry: maps model_type → adapter class
# ──────────────────────────────────────────────────────────────────────────────
ARCH_LAYER_MAP: Dict[str, Dict[str, Any]] = {
"qwen3": {
"layers_attr": "model.layers",
"embed_attr": "model.embed_tokens",
"norm_attr": "model.norm",
"lm_head_attr": "lm_head",
"cache_type": "KVCache",
"make_cache_fn": "make_cache",
"tie_embeddings": True,
"model_type": "qwen3",
},
"qwen2": {
"layers_attr": "model.layers",
"embed_attr": "model.embed_tokens",
"norm_attr": "model.norm",
"lm_head_attr": "lm_head",
"cache_type": "KVCache",
"make_cache_fn": "make_cache",
"tie_embeddings": True,
"model_type": "qwen2",
},
"qwen3_5": {
"layers_attr": "language_model.model.layers",
"embed_attr": "language_model.model.embed_tokens",
"norm_attr": "language_model.model.norm",
"lm_head_attr": "language_model.lm_head",
"cache_type": "ArraysCache",
"make_cache_fn": "make_cache",
"tie_embeddings": True,
"model_type": "qwen3_5",
"has_hybrid_attention": True,
"has_linear_attention": True,
},
"llama": {
"layers_attr": "model.layers",
"embed_attr": "model.embed_tokens",
"norm_attr": "model.norm",
"lm_head_attr": "lm_head",
"cache_type": "KVCache",
"make_cache_fn": "make_cache",
"tie_embeddings": False,
"model_type": "llama",
},
"mistral": {
"layers_attr": "model.layers",
"embed_attr": "model.embed_tokens",
"norm_attr": "model.norm",
"lm_head_attr": "lm_head",
"cache_type": "KVCache",
"make_cache_fn": "make_cache",
"tie_embeddings": False,
"model_type": "mistral",
},
"gemma": {
"layers_attr": "model.layers",
"embed_attr": "model.embed_tokens",
"norm_attr": "model.norm",
"lm_head_attr": "lm_head",
"cache_type": "KVCache",
"make_cache_fn": "make_cache",
"tie_embeddings": True,
"model_type": "gemma",
"norm_eps": 1e-6,
},
"gemma2": {
"layers_attr": "model.layers",
"embed_attr": "model.embed_tokens",
"norm_attr": "model.norm",
"lm_head_attr": "lm_head",
"cache_type": "KVCache",
"make_cache_fn": "make_cache",
"tie_embeddings": True,
"model_type": "gemma2",
"norm_eps": 1e-6,
},
"generic": {
"layers_attr": "layers",
"embed_attr": "embedding",
"norm_attr": "norm",
"lm_head_attr": "lm_head",
"cache_type": "KVCache",
"make_cache_fn": None,
"tie_embeddings": False,
"model_type": "generic",
},
}
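# Extending support is mostly a matter of adding a registry entry here (and, if
# the architecture needs custom behavior, an adapter subclass further below).
# The "phi3" entry sketched in this comment is a hypothetical example, not a
# tested configuration:
#
#   ARCH_LAYER_MAP["phi3"] = {
#       "layers_attr": "model.layers",
#       "embed_attr": "model.embed_tokens",
#       "norm_attr": "model.norm",
#       "lm_head_attr": "lm_head",
#       "cache_type": "KVCache",
#       "make_cache_fn": "make_cache",
#       "tie_embeddings": False,
#       "model_type": "phi3",
#   }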
def resolve_model_path(path_or_repo: str) -> Path:
"""Resolve a model path or HF Hub repo ID to a local path."""
path = Path(path_or_repo)
if path.exists():
return path
return Path(snapshot_download(path_or_repo))
def _get_attr(obj: Any, attr_path: str) -> Any:
"""Get nested attribute by dot-path, e.g. 'language_model.model.layers'."""
for part in attr_path.split("."):
if obj is None:
return None
obj = getattr(obj, part, None)
return obj
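# Minimal sketch of _get_attr semantics using a stand-in namespace object.
# Purely illustrative; nothing in this module calls it.
def _get_attr_example() -> None:
    from types import SimpleNamespace

    fake_model = SimpleNamespace(model=SimpleNamespace(layers=["block0", "block1"]))
    assert _get_attr(fake_model, "model.layers") == ["block0", "block1"]
    # Missing attributes resolve to None instead of raising AttributeError.
    assert _get_attr(fake_model, "model.embed_tokens") is None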
def detect_model_architecture(model, config: Optional[Dict] = None) -> str:
"""Auto-detect model architecture from model structure and config."""
# Try config first
if config is None and hasattr(model, "config"):
if hasattr(model.config, "to_dict"):
config = model.config.to_dict()
elif hasattr(model.config, "model_type"):
config = {"model_type": model.config.model_type}
if config and "model_type" in config:
mt = config["model_type"]
if mt in ARCH_LAYER_MAP:
return mt
# Aliases
if mt.startswith("qwen3_5") or mt == "qwen3.5":
return "qwen3_5"
if mt.startswith("qwen3"):
return "qwen3"
if mt.startswith("qwen2"):
return "qwen2"
if mt.startswith("llama"):
return "llama"
if mt.startswith("mistral"):
return "mistral"
if mt == "gemma2":
return "gemma2"
if mt.startswith("gemma"):
return "gemma"
# Structural detection
if hasattr(model, "language_model"):
return "qwen3_5"
if hasattr(model, "model") and hasattr(model.model, "layers"):
return "llama" # llama, qwen3, mistral all share this
if hasattr(model, "layers"):
return "generic"
return "generic"
# ──────────────────────────────────────────────────────────────────────────────
# Base adapter class: defines the contract all adapters must implement
# ──────────────────────────────────────────────────────────────────────────────
class MLXTargetAdapter:
"""Base adapter for DFlash target model interaction.
Every supported architecture needs an adapter that knows:
- Where embeddings live
- How to iterate layers and extract hidden states
- How to create/manage KV caches
- How to call the LM head
- How to trim/rewind caches on rejection
"""
family: str = "unknown"
arch_info: Dict[str, Any] = {}
def __init__(self, model, config: Optional[Dict] = None):
self.model = model
self.config = config or {}
self._detect_attributes()
def _detect_attributes(self):
"""Resolve embedding, layer, norm, lm_head references."""
arch = ARCH_LAYER_MAP.get(self.family, ARCH_LAYER_MAP["generic"])
self.arch_info = arch.copy()
# Try exact path first
self._embed = _get_attr(self.model, arch["embed_attr"])
self._layers = _get_attr(self.model, arch["layers_attr"])
self._norm = _get_attr(self.model, arch["norm_attr"])
self._lm_head = _get_attr(self.model, arch["lm_head_attr"])
# Fallback: probe common locations
if self._embed is None:
for attr in ("embedding", "token_embedding", "embed_tokens", "wte"):
self._embed = getattr(self.model, attr, None)
if self._embed is not None:
break
if self._layers is None:
self._layers = getattr(self.model, "layers", None)
if self._norm is None:
self._norm = getattr(self.model, "norm", None)
if self._lm_head is None:
self._lm_head = getattr(self.model, "lm_head", None)
    # ── Tokenization / Prompt ───────────────────────────────────────────────
def build_prompt(self, tokenizer, prompt_text: str, enable_thinking: bool = False) -> mx.array:
"""Build prompt tokens from text."""
messages = [{"role": "user", "content": prompt_text}]
try:
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=enable_thinking,
)
except TypeError:
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
tokens = tokenizer.encode(text, add_special_tokens=False)
return mx.array(tokens, dtype=mx.uint32)
    def stop_token_ids(self, tokenizer) -> set[int]:
        """Get the set of stop token IDs (handles int, list/tuple/set, or missing)."""
        eos = getattr(tokenizer, "eos_token_ids", None)
        if eos is None:
            eos = getattr(tokenizer, "eos_token_id", None)
        if eos is None:
            return set()
        if isinstance(eos, int):
            return {eos}
        return set(eos)
    # ── Embeddings / LM Head ────────────────────────────────────────────────
def embed_tokens(self, tokens: mx.array) -> mx.array:
"""Embed token IDs to hidden states."""
if self._embed is None:
raise RuntimeError(f"[{self.family}] Could not find embedding layer")
return self._embed(tokens)
def lm_head_logits(self, hidden_states: mx.array) -> mx.array:
"""Project hidden states to vocab logits."""
if self._lm_head is not None:
return self._lm_head(hidden_states)
# Tie-word-embedding fallback
if self.arch_info.get("tie_embeddings") and self._embed is not None:
if hasattr(self._embed, "as_linear"):
return self._embed.as_linear(hidden_states)
raise RuntimeError(f"[{self.family}] Could not find LM head")
def lm_head_argmax(self, hidden_states: mx.array) -> mx.array:
"""Greedy next-token from hidden states."""
logits = self.lm_head_logits(hidden_states)
return mx.argmax(logits, axis=-1).astype(mx.uint32)
    # ── Cache Management ────────────────────────────────────────────────────
def make_cache(self) -> list[Any]:
"""Create fresh KV cache for all layers."""
cache_type = self.arch_info.get("cache_type", "KVCache")
num_layers = len(self._layers) if self._layers is not None else 0
if cache_type == "KVCache":
return [cache_lib.KVCache() for _ in range(num_layers)]
elif cache_type == "ArraysCache":
return [cache_lib.ArraysCache() for _ in range(num_layers)]
else:
return [None for _ in range(num_layers)]
    def rewind_kv_caches(self, cache: list[Any], num_tokens: int) -> None:
        """Rewind each layer cache by ``num_tokens`` positions (discard the rejected suffix)."""
        for layer_cache in cache:
            if isinstance(layer_cache, cache_lib.KVCache):
                layer_cache.trim(num_tokens)
            elif isinstance(layer_cache, cache_lib.ArraysCache) and hasattr(layer_cache, "trim"):
                layer_cache.trim(num_tokens)
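    # Usage sketch for the verify/rewind step (illustrative; ``draft_block``,
    # ``draft_len`` and ``num_accepted`` are placeholder names, not module API):
    #
    #   logits, _ = adapter.forward_with_hidden_states(draft_block, cache, layer_ids)
    #   num_rejected = draft_len - num_accepted
    #   if num_rejected > 0:
    #       adapter.rewind_kv_caches(cache, num_rejected)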
    # ── Forward with Hidden-State Extraction ────────────────────────────────
def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
"""Build causal attention mask appropriate for this architecture."""
# Default: simple causal mask via triangular structure
# MLX fast attention often handles this internally, but we provide a hook
seq_len = hidden_states.shape[1]
if cache is not None and hasattr(cache, "offset"):
            # Cached generation: no mask needed for single new token
if seq_len == 1:
return None
return None # MLX models typically compute mask internally
def forward_with_hidden_states(
self,
tokens: mx.array,
cache: list[Any],
layer_ids: List[int],
output_rollback_records: bool = False,
) -> Tuple[mx.array, mx.array] | Tuple[mx.array, mx.array, Dict]:
"""
Run target model, returning (logits, target_hidden).
target_hidden = concatenation of hidden states at layer_ids.
Args:
tokens: Input token IDs [bsz, seq_len]
cache: Per-layer KV cache
layer_ids: Target layer indices for DFlash conditioning
output_rollback_records: Whether to return per-layer state for rollback
Returns:
(logits, target_hidden) or (logits, target_hidden, rollback_records)
"""
if self._embed is None or self._layers is None:
raise RuntimeError(f"[{self.family}] Model attributes not resolved")
hidden = self.embed_tokens(tokens)
mask = self.create_attention_mask(hidden, cache[0] if cache else None)
selected: List[mx.array] = []
        rollback_records: Dict[int, Dict[str, mx.array]] = {}  # reserved for per-layer rollback state; not populated yet
target_layer_ids = set(layer_ids)
for idx, (layer, layer_cache) in enumerate(zip(self._layers, cache)):
# Each layer returns updated hidden states
# Some return tuple (hidden, cache_update), some just hidden
layer_out = layer(hidden, mask=mask, cache=layer_cache)
if isinstance(layer_out, tuple):
hidden = layer_out[0]
else:
hidden = layer_out
if idx in target_layer_ids:
selected.append(hidden)
# Final norm + LM head
if self._norm is not None:
hidden = self._norm(hidden)
logits = self.lm_head_logits(hidden)
# Concatenate selected hidden states across feature dim
if selected:
target_hidden = mx.concatenate(selected, axis=-1)
else:
# Fallback: use final hidden state
target_hidden = hidden
if output_rollback_records:
return logits, target_hidden, rollback_records
return logits, target_hidden
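    # Illustrative consumption of the return values (names are placeholders):
    # ``target_hidden`` conditions the DFlash draft model, while the last-position
    # logits give the target model's own next-token prediction.
    #
    #   logits, target_hidden = adapter.forward_with_hidden_states(
    #       prompt_tokens[None], cache, layer_ids=[8, 16, 24]
    #   )
    #   next_token = mx.argmax(logits[:, -1, :], axis=-1)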
def forward_verifier_states(
self,
tokens: mx.array,
cache: list[Any],
layer_ids: List[int],
) -> Tuple[mx.array, mx.array, Dict]:
"""Forward pass that always returns rollback records."""
return self.forward_with_hidden_states(
tokens, cache, layer_ids, output_rollback_records=True
)
def forward_accept_all_block(
self,
tokens: mx.array,
cache: list[Any],
layer_ids: List[int],
) -> Tuple[mx.array, mx.array]:
"""Single-token forward returning last-position logits + target hidden."""
logits, target_hidden = self.forward_with_hidden_states(
tokens, cache, layer_ids, output_rollback_records=False
)
return logits[:, -1:, :], target_hidden
    # ── Cache Summary (for debugging) ───────────────────────────────────────
def cache_summary(self, cache: list[Any]) -> str:
"""Human-readable cache status."""
parts: List[str] = []
for idx, c in enumerate(cache):
if isinstance(c, cache_lib.KVCache):
parts.append(f"{idx}:kv={c.offset}")
elif isinstance(c, cache_lib.ArraysCache):
rec = None if c[1] is None else tuple(c[1].shape)
parts.append(f"{idx}:ssm={rec}")
else:
parts.append(f"{idx}:none")
return " ".join(parts)
# ──────────────────────────────────────────────────────────────────────────────
# Per-family adapter subclasses (for architecture-specific overrides)
# ──────────────────────────────────────────────────────────────────────────────
class Qwen3Adapter(MLXTargetAdapter):
family = "qwen3"
def build_prompt(self, tokenizer, prompt_text: str, enable_thinking: bool = False) -> mx.array:
messages = [{"role": "user", "content": prompt_text}]
try:
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=enable_thinking,
)
except TypeError:
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
tokens = tokenizer.encode(text, add_special_tokens=False)
return mx.array(tokens, dtype=mx.uint32)
def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
try:
from mlx_lm.models import qwen3
return qwen3.create_attention_mask(hidden_states, cache)
except Exception:
return super().create_attention_mask(hidden_states, cache)
def lm_head_logits(self, hidden_states: mx.array) -> mx.array:
# Qwen3 often uses tied embeddings
if self.arch_info.get("tie_embeddings") and self._embed is not None:
if hasattr(self._embed, "as_linear"):
return self._embed.as_linear(hidden_states)
if self._lm_head is not None:
return self._lm_head(hidden_states)
raise RuntimeError("[qwen3] No LM head found")
class Qwen35Adapter(MLXTargetAdapter):
family = "qwen3_5"
def build_prompt(self, tokenizer, prompt_text: str, enable_thinking: bool = False) -> mx.array:
messages = [{"role": "user", "content": prompt_text}]
try:
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=enable_thinking,
)
except TypeError:
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
tokens = tokenizer.encode(text, add_special_tokens=False)
return mx.array(tokens, dtype=mx.uint32)
def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
try:
from mlx_lm.models import qwen3_5
# Qwen3.5 has hybrid attention: full-attention + linear-attention
if cache is not None and hasattr(cache, "__len__") and len(cache) > 0:
# Detect cache type
if hasattr(cache[0], "fa_idx"):
fa_mask = qwen3_5.create_attention_mask(hidden_states, cache[0])
return fa_mask
except Exception:
pass
return super().create_attention_mask(hidden_states, cache)
def forward_with_hidden_states(
self,
tokens: mx.array,
cache: list[Any],
layer_ids: List[int],
output_rollback_records: bool = False,
):
# Qwen3.5 needs special handling for hybrid attention layers
if self._embed is None or self._layers is None:
raise RuntimeError("[qwen3_5] Model attributes not resolved")
hidden = self.embed_tokens(tokens)
# Build masks for full-attention and linear-attention layers
try:
from mlx_lm.models import qwen3_5
            fa_mask = qwen3_5.create_attention_mask(hidden, cache[0] if cache else None)
except Exception:
fa_mask = None
selected: List[mx.array] = []
target_layer_ids = set(layer_ids)
for idx, (layer, layer_cache) in enumerate(zip(self._layers, cache)):
# Qwen3.5 layers have is_linear flag
mask = None
if hasattr(layer, "is_linear") and layer.is_linear:
                # Linear attention layer: uses a different mask or none
pass
else:
mask = fa_mask
layer_out = layer(hidden, mask=mask, cache=layer_cache)
if isinstance(layer_out, tuple):
hidden = layer_out[0]
else:
hidden = layer_out
if idx in target_layer_ids:
selected.append(hidden)
if self._norm is not None:
hidden = self._norm(hidden)
logits = self.lm_head_logits(hidden)
if selected:
target_hidden = mx.concatenate(selected, axis=-1)
else:
target_hidden = hidden
if output_rollback_records:
return logits, target_hidden, {}
return logits, target_hidden
class LlamaAdapter(MLXTargetAdapter):
family = "llama"
def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
try:
from mlx_lm.models import llama
return llama.create_attention_mask(hidden_states, cache)
except Exception:
return super().create_attention_mask(hidden_states, cache)
class MistralAdapter(MLXTargetAdapter):
family = "mistral"
def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
try:
from mlx_lm.models import mistral
return mistral.create_attention_mask(hidden_states, cache)
except Exception:
return super().create_attention_mask(hidden_states, cache)
class GemmaAdapter(MLXTargetAdapter):
family = "gemma"
def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
try:
from mlx_lm.models import gemma
return gemma.create_attention_mask(hidden_states, cache)
except Exception:
return super().create_attention_mask(hidden_states, cache)
# ──────────────────────────────────────────────────────────────────────────────
# Adapter registry and factory
# ──────────────────────────────────────────────────────────────────────────────
ADAPTERS: Dict[str, type[MLXTargetAdapter]] = {
"qwen3": Qwen3Adapter,
"qwen2": Qwen3Adapter, # Shares structure
"qwen3_5": Qwen35Adapter,
"llama": LlamaAdapter,
"mistral": MistralAdapter,
"gemma": GemmaAdapter,
"gemma2": GemmaAdapter,
"generic": MLXTargetAdapter,
}
def adapter_for_model_type(model_type: str) -> Optional[type[MLXTargetAdapter]]:
"""Get adapter class for a model type string."""
# Direct match
if model_type in ADAPTERS:
return ADAPTERS[model_type]
# Aliases
if model_type.startswith("qwen3_5") or model_type == "qwen3.5":
return Qwen35Adapter
if model_type.startswith("qwen3"):
return Qwen3Adapter
if model_type.startswith("qwen2"):
return Qwen3Adapter
if model_type.startswith("llama"):
return LlamaAdapter
if model_type.startswith("mistral"):
return MistralAdapter
if model_type == "gemma2":
return GemmaAdapter
if model_type.startswith("gemma"):
return GemmaAdapter
return None
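# Illustrative alias resolution (not called at import time):
def _adapter_lookup_example() -> None:
    assert adapter_for_model_type("qwen3_moe") is Qwen3Adapter
    assert adapter_for_model_type("llama") is LlamaAdapter
    assert adapter_for_model_type("totally-unknown") is None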
# ──────────────────────────────────────────────────────────────────────────────
# LoadedTargetModel: convenience wrapper binding model + adapter + tokenizer
# ──────────────────────────────────────────────────────────────────────────────
@dataclass
class LoadedTargetModel:
requested_model: str
resolved_model_path: Path
model: Any
tokenizer: Any
adapter: MLXTargetAdapter
def build_prompt(self, prompt_text: str, enable_thinking: bool = False) -> mx.array:
return self.adapter.build_prompt(self.tokenizer, prompt_text, enable_thinking)
def stop_token_ids(self) -> set[int]:
return self.adapter.stop_token_ids(self.tokenizer)
def make_cache(self) -> list[Any]:
return self.adapter.make_cache()
def embed_tokens(self, tokens: mx.array) -> mx.array:
return self.adapter.embed_tokens(tokens)
def lm_head_logits(self, hidden_states: mx.array) -> mx.array:
return self.adapter.lm_head_logits(hidden_states)
def lm_head_argmax(self, hidden_states: mx.array) -> mx.array:
return self.adapter.lm_head_argmax(hidden_states)
def forward_with_hidden_states(
self,
tokens: mx.array,
cache: list[Any],
layer_ids: List[int],
output_rollback_records: bool = False,
):
return self.adapter.forward_with_hidden_states(
tokens, cache, layer_ids, output_rollback_records
)
def forward_verifier_states(self, tokens: mx.array, cache: list[Any], layer_ids: List[int]):
return self.adapter.forward_verifier_states(tokens, cache, layer_ids)
def forward_accept_all_block(self, tokens: mx.array, cache: list[Any], layer_ids: List[int]):
return self.adapter.forward_accept_all_block(tokens, cache, layer_ids)
def rewind_kv_caches(self, cache: list[Any], num_tokens: int) -> None:
self.adapter.rewind_kv_caches(cache, num_tokens)
def cache_summary(self, cache: list[Any]) -> str:
return self.adapter.cache_summary(cache)
def load_target_model(path_or_repo: str) -> LoadedTargetModel:
"""Load an MLX target model with the correct adapter.
Args:
path_or_repo: Local path or HF Hub model ID
Returns:
LoadedTargetModel with architecture-aware adapter
"""
base_path = resolve_model_path(path_or_repo)
# Load config to detect architecture
config_path = base_path / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
config = json.load(f)
else:
config = {}
model_type = config.get("model_type", "generic")
adapter_cls = adapter_for_model_type(model_type)
if adapter_cls is None:
registered = ", ".join(sorted(ADAPTERS.keys()))
raise NotImplementedError(
f"Unsupported MLX DFlash target model_type={model_type!r}. "
f"Registered adapters: {registered}. "
f"You can add one by subclassing MLXTargetAdapter in adapters.py."
)
# Load model + tokenizer via mlx_lm
model, tokenizer = load(str(base_path))
# Instantiate adapter
adapter = adapter_cls(model, config)
return LoadedTargetModel(
requested_model=path_or_repo,
resolved_model_path=base_path,
model=model,
tokenizer=tokenizer,
adapter=adapter,
)
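# End-to-end usage sketch (illustrative). The repo id is a placeholder and
# ``layer_ids`` would normally come from the DFlash draft-model configuration:
#
#   target = load_target_model("mlx-community/Qwen3-8B-4bit")
#   prompt = target.build_prompt("Explain speculative decoding in one sentence.")
#   cache = target.make_cache()
#   logits, target_hidden = target.forward_with_hidden_states(
#       prompt[None], cache, layer_ids=[4, 12, 20]
#   )
#   next_token = mx.argmax(logits[:, -1, :], axis=-1)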