tritesh committed
Commit bb76689 · verified · 1 Parent(s): 7aca493

Upload dflash_mlx/adapters.py

Files changed (1):
  dflash_mlx/adapters.py (+706, -0)

dflash_mlx/adapters.py (new file, 706 lines):
"""
Universal architecture adapters for DFlash speculative decoding on MLX.

Supports: Qwen3, Qwen3.5, LLaMA (2/3), Mistral, Gemma, and generic transformers.
Inspired by Aryagm's adapter pattern and bstnxbt's per-family engine approach.
"""

from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Tuple, List, Dict

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx_lm import load
from mlx_lm.models import cache as cache_lib
# ──────────────────────────────────────────────────────────────────────────────
# Architecture registry — maps model_type → attribute paths and cache type
# (adapter classes themselves are registered in ADAPTERS below)
# ──────────────────────────────────────────────────────────────────────────────

ARCH_LAYER_MAP: Dict[str, Dict[str, Any]] = {
    "qwen3": {
        "layers_attr": "model.layers",
        "embed_attr": "model.embed_tokens",
        "norm_attr": "model.norm",
        "lm_head_attr": "lm_head",
        "cache_type": "KVCache",
        "make_cache_fn": "make_cache",
        "tie_embeddings": True,
        "model_type": "qwen3",
    },
    "qwen2": {
        "layers_attr": "model.layers",
        "embed_attr": "model.embed_tokens",
        "norm_attr": "model.norm",
        "lm_head_attr": "lm_head",
        "cache_type": "KVCache",
        "make_cache_fn": "make_cache",
        "tie_embeddings": True,
        "model_type": "qwen2",
    },
    "qwen3_5": {
        "layers_attr": "language_model.model.layers",
        "embed_attr": "language_model.model.embed_tokens",
        "norm_attr": "language_model.model.norm",
        "lm_head_attr": "language_model.lm_head",
        "cache_type": "ArraysCache",
        "make_cache_fn": "make_cache",
        "tie_embeddings": True,
        "model_type": "qwen3_5",
        "has_hybrid_attention": True,
        "has_linear_attention": True,
    },
    "llama": {
        "layers_attr": "model.layers",
        "embed_attr": "model.embed_tokens",
        "norm_attr": "model.norm",
        "lm_head_attr": "lm_head",
        "cache_type": "KVCache",
        "make_cache_fn": "make_cache",
        "tie_embeddings": False,
        "model_type": "llama",
    },
    "mistral": {
        "layers_attr": "model.layers",
        "embed_attr": "model.embed_tokens",
        "norm_attr": "model.norm",
        "lm_head_attr": "lm_head",
        "cache_type": "KVCache",
        "make_cache_fn": "make_cache",
        "tie_embeddings": False,
        "model_type": "mistral",
    },
    "gemma": {
        "layers_attr": "model.layers",
        "embed_attr": "model.embed_tokens",
        "norm_attr": "model.norm",
        "lm_head_attr": "lm_head",
        "cache_type": "KVCache",
        "make_cache_fn": "make_cache",
        "tie_embeddings": True,
        "model_type": "gemma",
        "norm_eps": 1e-6,
    },
    "gemma2": {
        "layers_attr": "model.layers",
        "embed_attr": "model.embed_tokens",
        "norm_attr": "model.norm",
        "lm_head_attr": "lm_head",
        "cache_type": "KVCache",
        "make_cache_fn": "make_cache",
        "tie_embeddings": True,
        "model_type": "gemma2",
        "norm_eps": 1e-6,
    },
    "generic": {
        "layers_attr": "layers",
        "embed_attr": "embedding",
        "norm_attr": "norm",
        "lm_head_attr": "lm_head",
        "cache_type": "KVCache",
        "make_cache_fn": None,
        "tie_embeddings": False,
        "model_type": "generic",
    },
}


def resolve_model_path(path_or_repo: str) -> Path:
    """Resolve a local model path or HF Hub repo ID to a local path."""
    path = Path(path_or_repo)
    if path.exists():
        return path
    return Path(snapshot_download(path_or_repo))


def _get_attr(obj: Any, attr_path: str) -> Any:
    """Get a nested attribute by dot-path, e.g. 'language_model.model.layers'."""
    for part in attr_path.split("."):
        if obj is None:
            return None
        obj = getattr(obj, part, None)
    return obj

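# A minimal illustration of how the registry entries are consumed (the
# `model` object here is hypothetical; the paths come from ARCH_LAYER_MAP):
#
#     info = ARCH_LAYER_MAP["llama"]
#     layers = _get_attr(model, info["layers_attr"])      # -> model.model.layers
#     lm_head = _get_attr(model, info["lm_head_attr"])    # -> model.lm_head

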
def detect_model_architecture(model, config: Optional[Dict] = None) -> str:
    """Auto-detect the model architecture from model structure and config."""
    # Try config first
    if config is None and hasattr(model, "config"):
        if hasattr(model.config, "to_dict"):
            config = model.config.to_dict()
        elif hasattr(model.config, "model_type"):
            config = {"model_type": model.config.model_type}

    if config and "model_type" in config:
        mt = config["model_type"]
        if mt in ARCH_LAYER_MAP:
            return mt
        # Aliases
        if mt.startswith("qwen3_5") or mt == "qwen3.5":
            return "qwen3_5"
        if mt.startswith("qwen3"):
            return "qwen3"
        if mt.startswith("qwen2"):
            return "qwen2"
        if mt.startswith("llama"):
            return "llama"
        if mt.startswith("mistral"):
            return "mistral"
        if mt == "gemma2":
            return "gemma2"
        if mt.startswith("gemma"):
            return "gemma"

    # Structural detection
    if hasattr(model, "language_model"):
        return "qwen3_5"
    if hasattr(model, "model") and hasattr(model.model, "layers"):
        return "llama"  # llama, qwen3, mistral all share this layout
    if hasattr(model, "layers"):
        return "generic"

    return "generic"

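# For example, a hypothetical {"model_type": "qwen3_moe"} has no exact registry
# entry, so the alias pass maps it to "qwen3"; a model exposing only
# model.model.layers falls through to the structural "llama" guess.
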
# ──────────────────────────────────────────────────────────────────────────────
# Base adapter class — defines the contract all adapters must implement
# ──────────────────────────────────────────────────────────────────────────────

class MLXTargetAdapter:
    """Base adapter for DFlash target model interaction.

    Every supported architecture needs an adapter that knows:
    - Where embeddings live
    - How to iterate layers and extract hidden states
    - How to create/manage KV caches
    - How to call the LM head
    - How to trim/rewind caches on rejection
    """

    family: str = "unknown"
    arch_info: Dict[str, Any] = {}

    def __init__(self, model, config: Optional[Dict] = None):
        self.model = model
        self.config = config or {}
        self._detect_attributes()

    def _detect_attributes(self):
        """Resolve embedding, layer, norm, and lm_head references."""
        arch = ARCH_LAYER_MAP.get(self.family, ARCH_LAYER_MAP["generic"])
        self.arch_info = arch.copy()

        # Try the exact registry paths first
        self._embed = _get_attr(self.model, arch["embed_attr"])
        self._layers = _get_attr(self.model, arch["layers_attr"])
        self._norm = _get_attr(self.model, arch["norm_attr"])
        self._lm_head = _get_attr(self.model, arch["lm_head_attr"])

        # Fallback: probe common locations
        if self._embed is None:
            for attr in ("embedding", "token_embedding", "embed_tokens", "wte"):
                self._embed = getattr(self.model, attr, None)
                if self._embed is not None:
                    break

        if self._layers is None:
            self._layers = getattr(self.model, "layers", None)

        if self._norm is None:
            self._norm = getattr(self.model, "norm", None)

        if self._lm_head is None:
            self._lm_head = getattr(self.model, "lm_head", None)

    # ── Tokenization / Prompt ───────────────────────────────────────────────

    def build_prompt(self, tokenizer, prompt_text: str, enable_thinking: bool = False) -> mx.array:
        """Build prompt tokens from text via the chat template."""
        messages = [{"role": "user", "content": prompt_text}]
        try:
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=enable_thinking,
            )
        except TypeError:
            # Older chat templates don't accept enable_thinking
            text = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        tokens = tokenizer.encode(text, add_special_tokens=False)
        return mx.array(tokens, dtype=mx.uint32)

    def stop_token_ids(self, tokenizer) -> set[int]:
        """Get the set of stop token IDs."""
        # mlx_lm's TokenizerWrapper exposes eos_token_ids; fall back to the
        # HF-style singular eos_token_id when that attribute is missing.
        eos = getattr(tokenizer, "eos_token_ids", None)
        if eos is None:
            eos = getattr(tokenizer, "eos_token_id", None)
        if isinstance(eos, int):
            return {eos}
        if isinstance(eos, (list, tuple, set)):
            return set(eos)
        return set()

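    # Usage sketch (tokenizer comes from mlx_lm's load; shapes noted inline):
    #     toks = adapter.build_prompt(tokenizer, "Hello")    # uint32, [seq_len]
    #     stops = adapter.stop_token_ids(tokenizer)          # set of EOS ids
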
    # ── Embeddings / LM Head ────────────────────────────────────────────────

    def embed_tokens(self, tokens: mx.array) -> mx.array:
        """Embed token IDs to hidden states."""
        if self._embed is None:
            raise RuntimeError(f"[{self.family}] Could not find embedding layer")
        return self._embed(tokens)

    def lm_head_logits(self, hidden_states: mx.array) -> mx.array:
        """Project hidden states to vocab logits."""
        if self._lm_head is not None:
            return self._lm_head(hidden_states)
        # Tied-word-embedding fallback
        if self.arch_info.get("tie_embeddings") and self._embed is not None:
            if hasattr(self._embed, "as_linear"):
                return self._embed.as_linear(hidden_states)
        raise RuntimeError(f"[{self.family}] Could not find LM head")

    def lm_head_argmax(self, hidden_states: mx.array) -> mx.array:
        """Greedy next-token from hidden states."""
        logits = self.lm_head_logits(hidden_states)
        return mx.argmax(logits, axis=-1).astype(mx.uint32)

    # ── Cache Management ──────────────────────────────────────────────────────

    def make_cache(self) -> list[Any]:
        """Create a fresh KV cache for all layers."""
        # Prefer the model's own cache constructor when the registry names one;
        # models with hybrid/linear attention need non-standard caches.
        fn_name = self.arch_info.get("make_cache_fn")
        if fn_name and callable(getattr(self.model, fn_name, None)):
            return getattr(self.model, fn_name)()

        cache_type = self.arch_info.get("cache_type", "KVCache")
        num_layers = len(self._layers) if self._layers is not None else 0

        if cache_type == "KVCache":
            return [cache_lib.KVCache() for _ in range(num_layers)]
        elif cache_type == "ArraysCache":
            # Fallback only; ArraysCache constructor arguments vary across
            # mlx_lm versions, so the model-provided path above is preferred.
            return [cache_lib.ArraysCache() for _ in range(num_layers)]
        else:
            return [None for _ in range(num_layers)]

    def rewind_kv_caches(self, cache: list[Any], num_tokens: int) -> None:
        """Trim the last `num_tokens` entries from each layer cache."""
        for layer_cache in cache:
            if isinstance(layer_cache, cache_lib.KVCache):
                layer_cache.trim(num_tokens)
            elif isinstance(layer_cache, cache_lib.ArraysCache) and hasattr(layer_cache, "trim"):
                layer_cache.trim(num_tokens)

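    # Rollback sketch: if a drafted block of B tokens was written into the
    # cache but the verifier accepted only k of them, drop the unaccepted
    # suffix:
    #     adapter.rewind_kv_caches(kv, B - k)
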
    # ── Forward with Hidden-State Extraction ─────────────────────────────────

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        """Build a causal attention mask appropriate for this architecture."""
        # MLX model layers usually compute the causal mask internally, so this
        # base hook returns None; subclasses override it when an architecture
        # needs an explicit mask.
        seq_len = hidden_states.shape[1]
        if cache is not None and hasattr(cache, "offset"):
            # Cached generation — no mask needed for a single new token
            if seq_len == 1:
                return None
        return None

    def forward_with_hidden_states(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
        output_rollback_records: bool = False,
    ) -> Tuple[mx.array, mx.array] | Tuple[mx.array, mx.array, Dict]:
        """
        Run the target model, returning (logits, target_hidden).
        target_hidden is the concatenation of hidden states at layer_ids.

        Args:
            tokens: Input token IDs [bsz, seq_len]
            cache: Per-layer KV cache
            layer_ids: Target layer indices for DFlash conditioning
            output_rollback_records: Whether to return per-layer state for rollback

        Returns:
            (logits, target_hidden) or (logits, target_hidden, rollback_records)
        """
        if self._embed is None or self._layers is None:
            raise RuntimeError(f"[{self.family}] Model attributes not resolved")

        hidden = self.embed_tokens(tokens)
        mask = self.create_attention_mask(hidden, cache[0] if cache else None)

        selected: List[mx.array] = []
        # Rollback records are a reserved hook; the base adapter leaves them empty
        rollback_records: Dict[int, Dict[str, mx.array]] = {}
        target_layer_ids = set(layer_ids)

        for idx, (layer, layer_cache) in enumerate(zip(self._layers, cache)):
            # Each layer returns updated hidden states; some return a tuple
            # (hidden, cache_update), others just the hidden array
            layer_out = layer(hidden, mask=mask, cache=layer_cache)
            if isinstance(layer_out, tuple):
                hidden = layer_out[0]
            else:
                hidden = layer_out

            if idx in target_layer_ids:
                selected.append(hidden)

        # Final norm + LM head
        if self._norm is not None:
            hidden = self._norm(hidden)
        logits = self.lm_head_logits(hidden)

        # Concatenate selected hidden states across the feature dim
        if selected:
            target_hidden = mx.concatenate(selected, axis=-1)
        else:
            # Fallback: use the final hidden state
            target_hidden = hidden

        if output_rollback_records:
            return logits, target_hidden, rollback_records
        return logits, target_hidden

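    # Shape sketch (batch 1, hidden size H, two requested layers):
    #     logits, th = adapter.forward_with_hidden_states(toks[None], kv, [4, 9])
    #     # logits: [1, seq, vocab]; th: [1, seq, 2 * H]
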
    def forward_verifier_states(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
    ) -> Tuple[mx.array, mx.array, Dict]:
        """Forward pass that always returns rollback records."""
        return self.forward_with_hidden_states(
            tokens, cache, layer_ids, output_rollback_records=True
        )

    def forward_accept_all_block(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
    ) -> Tuple[mx.array, mx.array]:
        """Forward an accepted block, returning last-position logits + target hidden."""
        logits, target_hidden = self.forward_with_hidden_states(
            tokens, cache, layer_ids, output_rollback_records=False
        )
        return logits[:, -1:, :], target_hidden

    # ── Cache Summary (for debugging) ───────────────────────────────────────

    def cache_summary(self, cache: list[Any]) -> str:
        """Human-readable cache status."""
        parts: List[str] = []
        for idx, c in enumerate(cache):
            if isinstance(c, cache_lib.KVCache):
                parts.append(f"{idx}:kv={c.offset}")
            elif isinstance(c, cache_lib.ArraysCache):
                # Slot 1 is assumed to hold the recurrent (SSM) state
                rec = None if c[1] is None else tuple(c[1].shape)
                parts.append(f"{idx}:ssm={rec}")
            else:
                parts.append(f"{idx}:none")
        return " ".join(parts)

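# Example cache_summary output for two KV layers prefilled to offset 42:
#     "0:kv=42 1:kv=42"
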
# ──────────────────────────────────────────────────────────────────────────────
# Per-family adapter subclasses (for architecture-specific overrides)
# ──────────────────────────────────────────────────────────────────────────────

class Qwen3Adapter(MLXTargetAdapter):
    family = "qwen3"

    # build_prompt and stop_token_ids are inherited unchanged from
    # MLXTargetAdapter (chat template with optional enable_thinking).

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        try:
            from mlx_lm.models import qwen3
            return qwen3.create_attention_mask(hidden_states, cache)
        except Exception:
            return super().create_attention_mask(hidden_states, cache)

    def lm_head_logits(self, hidden_states: mx.array) -> mx.array:
        # Qwen3 often ties word embeddings, so try the tied path first
        if self.arch_info.get("tie_embeddings") and self._embed is not None:
            if hasattr(self._embed, "as_linear"):
                return self._embed.as_linear(hidden_states)
        if self._lm_head is not None:
            return self._lm_head(hidden_states)
        raise RuntimeError("[qwen3] No LM head found")

class Qwen35Adapter(MLXTargetAdapter):
    family = "qwen3_5"

    # build_prompt and stop_token_ids are inherited unchanged from
    # MLXTargetAdapter.

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        try:
            from mlx_lm.models import qwen3_5
            # Qwen3.5 mixes full-attention and linear-attention layers
            if cache is not None and hasattr(cache, "__len__") and len(cache) > 0:
                # Only build an explicit mask for full-attention caches
                if hasattr(cache[0], "fa_idx"):
                    fa_mask = qwen3_5.create_attention_mask(hidden_states, cache[0])
                    return fa_mask
        except Exception:
            pass
        return super().create_attention_mask(hidden_states, cache)

    def forward_with_hidden_states(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
        output_rollback_records: bool = False,
    ):
        # Qwen3.5 needs special handling for its hybrid attention layers
        if self._embed is None or self._layers is None:
            raise RuntimeError("[qwen3_5] Model attributes not resolved")

        hidden = self.embed_tokens(tokens)

        # Build the mask for full-attention layers; linear-attention layers
        # run without one. Note: the helper takes the hidden array positionally.
        try:
            from mlx_lm.models import qwen3_5
            fa_mask = qwen3_5.create_attention_mask(hidden, cache[0] if cache else None)
        except Exception:
            fa_mask = None

        selected: List[mx.array] = []
        target_layer_ids = set(layer_ids)

        for idx, (layer, layer_cache) in enumerate(zip(self._layers, cache)):
            # Qwen3.5 layers carry an is_linear flag; linear-attention layers
            # take no causal mask
            mask = None if getattr(layer, "is_linear", False) else fa_mask

            layer_out = layer(hidden, mask=mask, cache=layer_cache)
            if isinstance(layer_out, tuple):
                hidden = layer_out[0]
            else:
                hidden = layer_out

            if idx in target_layer_ids:
                selected.append(hidden)

        if self._norm is not None:
            hidden = self._norm(hidden)

        logits = self.lm_head_logits(hidden)

        if selected:
            target_hidden = mx.concatenate(selected, axis=-1)
        else:
            target_hidden = hidden

        if output_rollback_records:
            return logits, target_hidden, {}
        return logits, target_hidden

class LlamaAdapter(MLXTargetAdapter):
    family = "llama"

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        try:
            from mlx_lm.models import llama
            return llama.create_attention_mask(hidden_states, cache)
        except Exception:
            return super().create_attention_mask(hidden_states, cache)


class MistralAdapter(MLXTargetAdapter):
    family = "mistral"

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        try:
            from mlx_lm.models import mistral
            return mistral.create_attention_mask(hidden_states, cache)
        except Exception:
            return super().create_attention_mask(hidden_states, cache)


class GemmaAdapter(MLXTargetAdapter):
    family = "gemma"

    def create_attention_mask(self, hidden_states: mx.array, cache: Any = None) -> Optional[mx.array]:
        try:
            from mlx_lm.models import gemma
            return gemma.create_attention_mask(hidden_states, cache)
        except Exception:
            return super().create_attention_mask(hidden_states, cache)

# ──────────────────────────────────────────────────────────────────────────────
# Adapter registry and factory
# ──────────────────────────────────────────────────────────────────────────────

ADAPTERS: Dict[str, type[MLXTargetAdapter]] = {
    "qwen3": Qwen3Adapter,
    "qwen2": Qwen3Adapter,  # shares the Qwen3 layout
    "qwen3_5": Qwen35Adapter,
    "llama": LlamaAdapter,
    "mistral": MistralAdapter,
    "gemma": GemmaAdapter,
    "gemma2": GemmaAdapter,
    "generic": MLXTargetAdapter,
}

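# Adding a family is a subclass plus two registry entries. A hedged sketch
# (the "phi3" layout below is illustrative, not verified against mlx_lm):
#
#     class Phi3Adapter(MLXTargetAdapter):
#         family = "phi3"
#
#     ARCH_LAYER_MAP["phi3"] = {**ARCH_LAYER_MAP["llama"], "model_type": "phi3"}
#     ADAPTERS["phi3"] = Phi3Adapter
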
def adapter_for_model_type(model_type: str) -> Optional[type[MLXTargetAdapter]]:
    """Get the adapter class for a model_type string."""
    # Direct match
    if model_type in ADAPTERS:
        return ADAPTERS[model_type]
    # Aliases (mirrors detect_model_architecture)
    if model_type.startswith("qwen3_5") or model_type == "qwen3.5":
        return Qwen35Adapter
    if model_type.startswith("qwen3"):
        return Qwen3Adapter
    if model_type.startswith("qwen2"):
        return Qwen3Adapter
    if model_type.startswith("llama"):
        return LlamaAdapter
    if model_type.startswith("mistral"):
        return MistralAdapter
    if model_type.startswith("gemma"):  # covers gemma2 as well
        return GemmaAdapter
    return None


# ──────────────────────────────────────────────────────────────────────────────
# LoadedTargetModel — convenience wrapper binding model + adapter + tokenizer
# ──────────────────────────────────────────────────────────────────────────────

@dataclass
class LoadedTargetModel:
    requested_model: str
    resolved_model_path: Path
    model: Any
    tokenizer: Any
    adapter: MLXTargetAdapter

    def build_prompt(self, prompt_text: str, enable_thinking: bool = False) -> mx.array:
        return self.adapter.build_prompt(self.tokenizer, prompt_text, enable_thinking)

    def stop_token_ids(self) -> set[int]:
        return self.adapter.stop_token_ids(self.tokenizer)

    def make_cache(self) -> list[Any]:
        return self.adapter.make_cache()

    def embed_tokens(self, tokens: mx.array) -> mx.array:
        return self.adapter.embed_tokens(tokens)

    def lm_head_logits(self, hidden_states: mx.array) -> mx.array:
        return self.adapter.lm_head_logits(hidden_states)

    def lm_head_argmax(self, hidden_states: mx.array) -> mx.array:
        return self.adapter.lm_head_argmax(hidden_states)

    def forward_with_hidden_states(
        self,
        tokens: mx.array,
        cache: list[Any],
        layer_ids: List[int],
        output_rollback_records: bool = False,
    ):
        return self.adapter.forward_with_hidden_states(
            tokens, cache, layer_ids, output_rollback_records
        )

    def forward_verifier_states(self, tokens: mx.array, cache: list[Any], layer_ids: List[int]):
        return self.adapter.forward_verifier_states(tokens, cache, layer_ids)

    def forward_accept_all_block(self, tokens: mx.array, cache: list[Any], layer_ids: List[int]):
        return self.adapter.forward_accept_all_block(tokens, cache, layer_ids)

    def rewind_kv_caches(self, cache: list[Any], num_tokens: int) -> None:
        self.adapter.rewind_kv_caches(cache, num_tokens)

    def cache_summary(self, cache: list[Any]) -> str:
        return self.adapter.cache_summary(cache)

def load_target_model(path_or_repo: str) -> LoadedTargetModel:
    """Load an MLX target model with the correct adapter.

    Args:
        path_or_repo: Local path or HF Hub model ID

    Returns:
        LoadedTargetModel with an architecture-aware adapter
    """
    base_path = resolve_model_path(path_or_repo)

    # Load the config to detect the architecture
    config_path = base_path / "config.json"
    if config_path.exists():
        with open(config_path, "r") as f:
            config = json.load(f)
    else:
        config = {}

    model_type = config.get("model_type", "generic")
    adapter_cls = adapter_for_model_type(model_type)

    if adapter_cls is None:
        registered = ", ".join(sorted(ADAPTERS.keys()))
        raise NotImplementedError(
            f"Unsupported MLX DFlash target model_type={model_type!r}. "
            f"Registered adapters: {registered}. "
            f"You can add one by subclassing MLXTargetAdapter in adapters.py."
        )

    # Load the model + tokenizer via mlx_lm
    model, tokenizer = load(str(base_path))

    # Instantiate the adapter
    adapter = adapter_cls(model, config)

    return LoadedTargetModel(
        requested_model=path_or_repo,
        resolved_model_path=base_path,
        model=model,
        tokenizer=tokenizer,
        adapter=adapter,
    )
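

# End-to-end sketch (hypothetical repo ID; layer_ids must match the DFlash
# draft head the target is paired with):
#
#     target = load_target_model("mlx-community/Qwen3-4B-4bit")
#     toks = target.build_prompt("Explain speculative decoding briefly.")
#     kv = target.make_cache()
#     logits, hidden = target.forward_with_hidden_states(
#         toks[None], kv, layer_ids=[8, 16, 24]
#     )
#     next_id = int(mx.argmax(logits[:, -1, :], axis=-1).item())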