| """ |
| MLX implementation of the DFlash block diffusion draft model. |
| |
| Implements the core architecture from the DFlash paper (arXiv:2602.06036): |
| - Block-level diffusion for parallel token drafting |
| - KV injection of target model hidden features |
| - Causal attention within blocks with cross-block masking |
| - Position-dependent loss decay |
| |
| Architecture-agnostic: works with any target model family via adapters. |
| """ |
|
|
import math
from typing import Optional, Tuple, List

import mlx.core as mx
import mlx.nn as nn
|
|
|
|
class RMSNorm(nn.Module):
    """RMSNorm as used in Qwen/Llama models."""

    def __init__(self, dims: int, eps: float = 1e-6):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        var = mx.mean(mx.square(x), axis=-1, keepdims=True)
        x = x * mx.rsqrt(var + self.eps)
        return self.weight * x
|
|
|
|
def apply_rotary_emb(x, cos, sin):
    """Apply rotary positional embeddings to x.

    Args:
        x: [..., seq_len, head_dim]
        cos, sin: [seq_len, head_dim]

    Returns:
        Rotated tensor same shape as x
    """
    x1, x2 = x[..., ::2], x[..., 1::2]
    rotated = mx.stack([-x2, x1], axis=-1).reshape(x.shape)
    return x * cos + rotated * sin
|
|
|
|
def build_rope_cache(seq_len: int, head_dim: int, base: float = 10000.0):
    """Build rotary positional embedding cache.

    Returns:
        cos, sin: [seq_len, head_dim] each, with every frequency repeated for
        the two interleaved dims it rotates.
    """
    theta = 1.0 / (base ** (mx.arange(0, head_dim, 2) / head_dim))
    positions = mx.arange(seq_len)
    angles = mx.outer(positions, theta)
    cos = mx.cos(angles)
    sin = mx.sin(angles)

    # Repeat each frequency for its interleaved (even, odd) dimension pair.
    cos = mx.repeat(cos, 2, axis=-1)
    sin = mx.repeat(sin, 2, axis=-1)
    return cos, sin
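
# Shape sketch (illustrative values, not from the paper): with head_dim=16 and
# 12 total positions, `build_rope_cache(12, 16)` returns cos/sin of shape
# [12, 16], and `apply_rotary_emb(q, cos[-4:], sin[-4:])` rotates a query
# tensor of shape [bsz, num_heads, 4, 16] using the last four positions.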
|
|
|
|
def create_causal_mask(seq_len: int, dtype=mx.float32) -> mx.array:
    """Create a causal attention mask for self-attention.

    Returns a [1, 1, seq_len, seq_len] mask with a large negative value
    (effectively -inf) above the diagonal.
    """
    mask = mx.triu(mx.ones((seq_len, seq_len), dtype=dtype), k=1)
    mask = mx.where(mask == 1, -1e9, 0.0)
    return mask[None, None, :, :]
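
# For example, create_causal_mask(3)[0, 0] is (writing -1e9 as -inf):
#   [[  0., -inf, -inf],
#    [  0.,   0., -inf],
#    [  0.,   0.,   0.]]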
|
|
|
|
class DFlashAttention(nn.Module):
    """Multi-head attention with KV injection from target model features.

    This is the core of DFlash: the draft model's attention keys and values
    are augmented with projected target model hidden states, providing rich
    conditioning that enables high acceptance rates.

    Supports both standard attention and KV-injected cross-attention within
    the same layer.
    """
|
|
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.num_kv_groups = num_heads // num_kv_heads
        self.layer_idx = layer_idx
        self.scale = head_dim ** -0.5

        # Q/K/V/output projections (grouped-query attention layout).
        self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

        # Per-head query/key RMSNorm.
        self.q_norm = RMSNorm(head_dim, eps=1e-6)
        self.k_norm = RMSNorm(head_dim, eps=1e-6)
|
|
    def __call__(
        self,
        hidden_states: mx.array,
        target_hidden: mx.array,
        attention_mask: Optional[mx.array] = None,
        position_embeddings: Optional[Tuple[mx.array, mx.array]] = None,
        past_key_values: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        """Forward pass with KV injection.

        Args:
            hidden_states: Draft token embeddings [bsz, q_len, hidden_size]
            target_hidden: Target context features [bsz, ctx_len, hidden_size]
            attention_mask: Optional mask [1, 1, q_len, ctx_len + q_len]
            position_embeddings: Optional (cos, sin) covering the context
                followed by the block, i.e. [ctx_len + q_len, head_dim] each;
                queries use the trailing q_len rows.
            past_key_values: Not used in DFlash (diffusion is non-autoregressive)

        Returns:
            Attention output [bsz, q_len, hidden_size]
        """
        bsz, q_len = hidden_states.shape[:2]
        ctx_len = target_hidden.shape[1]

        # Queries come from the draft (noise) tokens only.
        q = self.q_proj(hidden_states)
        q = q.reshape(bsz, q_len, self.num_heads, self.head_dim)
        q = self.q_norm(q).transpose(0, 2, 1, 3)

        # Keys/values from the injected target context features...
        k_ctx = self.k_proj(target_hidden)
        v_ctx = self.v_proj(target_hidden)

        # ...and from the draft tokens themselves.
        k_noise = self.k_proj(hidden_states)
        v_noise = self.v_proj(hidden_states)

        # KV layout along the sequence axis: [context, block].
        k = mx.concatenate([k_ctx, k_noise], axis=1)
        v = mx.concatenate([v_ctx, v_noise], axis=1)
        k = k.reshape(bsz, ctx_len + q_len, self.num_kv_heads, self.head_dim)
        v = v.reshape(bsz, ctx_len + q_len, self.num_kv_heads, self.head_dim)
        k = self.k_norm(k).transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        # RoPE: cos/sin span context + block positions; queries take the
        # trailing q_len rows (the block positions).
        if position_embeddings is not None:
            cos, sin = position_embeddings
            q = apply_rotary_emb(q, cos[-q_len:], sin[-q_len:])
            k = apply_rotary_emb(k, cos, sin)

        # Expand KV heads for grouped-query attention.
        if self.num_kv_groups > 1:
            k = mx.repeat(k, self.num_kv_groups, axis=1)
            v = mx.repeat(v, self.num_kv_groups, axis=1)

        # Scaled dot-product attention over [q_len, ctx_len + q_len].
        scores = mx.matmul(q, k.transpose(0, 1, 3, 2)) * self.scale

        if attention_mask is not None:
            scores = scores + attention_mask

        attn_weights = mx.softmax(scores, axis=-1)
        attn_output = mx.matmul(attn_weights, v)
        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
        return self.o_proj(attn_output)
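
# Shape walk-through (a sketch with assumed sizes, not values from the paper):
# with hidden_size=64, num_heads=4, num_kv_heads=2, head_dim=16, a context of
# ctx_len=8 and a block of q_len=4:
#   q:      [bsz, 4, 4, 16] -> transposed to [bsz, 4 heads, 4, 16]
#   k, v:   context (8) + block (4) concatenated -> [bsz, 2 kv heads, 12, 16]
#   scores: [bsz, 4 heads, 4, 12], masked by a [1, 1, 4, 12] attention mask.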
|
|
|
|
class DFlashMLP(nn.Module):
    """Standard SwiGLU MLP as used in modern LLMs."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def __call__(self, x):
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
|
|
|
|
class DFlashDecoderLayer(nn.Module):
    """Single decoder layer with KV-injected attention and MLP."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        intermediate_size: int,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.self_attn = DFlashAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            layer_idx=layer_idx,
        )
        self.mlp = DFlashMLP(hidden_size, intermediate_size)
        self.input_layernorm = RMSNorm(hidden_size, eps=1e-6)
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6)

    def __call__(
        self,
        hidden_states: mx.array,
        target_hidden: mx.array,
        attention_mask: Optional[mx.array] = None,
        position_embeddings: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        # Self-attention with KV injection (pre-norm residual block).
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            target_hidden=target_hidden,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        hidden_states = residual + hidden_states

        # Feed-forward (pre-norm residual block).
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
|
|
|
|
class DFlashDraftModel(nn.Module):
    """Complete DFlash block diffusion draft model for MLX.

    Architecture:
    - N decoder layers with KV-injected attention
    - Target context feature projection (fuses cross-layer hidden states)
    - Rotary position embeddings
    - Block-wise parallel diffusion

    Universal: config auto-detected from target model or specified explicitly.
    """
|
|
    def __init__(
        self,
        vocab_size: int,
        hidden_size: int = 1024,
        num_layers: int = 5,
        num_heads: int = 16,
        num_kv_heads: int = 4,
        intermediate_size: int = 2816,
        max_seq_len: int = 8192,
        block_size: int = 16,
        mask_token_id: int = 0,
        num_target_layers: int = 32,
        target_layer_ids: Optional[List[int]] = None,
        rope_base: float = 10000.0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.block_size = block_size
        self.mask_token_id = mask_token_id
        self.num_target_layers = num_target_layers
        self.max_seq_len = max_seq_len

        # Which target-model layers to tap for cross-layer feature fusion.
        if target_layer_ids is None:
            self.target_layer_ids = self._build_target_layer_ids(
                num_target_layers, num_layers
            )
        else:
            self.target_layer_ids = target_layer_ids

        # Token embedding for draft/mask tokens.
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

        # Projects concatenated target-layer features down to hidden_size.
        num_target_features = len(self.target_layer_ids)
        self.fc = nn.Linear(num_target_features * hidden_size, hidden_size, bias=False)
        self.hidden_norm = RMSNorm(hidden_size, eps=1e-6)

        # Decoder stack with KV-injected attention.
        self.layers = [
            DFlashDecoderLayer(
                hidden_size=hidden_size,
                num_heads=num_heads,
                num_kv_heads=num_kv_heads,
                head_dim=self.head_dim,
                intermediate_size=intermediate_size,
                layer_idx=i,
            )
            for i in range(num_layers)
        ]

        # Final norm and output head.
        self.norm = RMSNorm(hidden_size, eps=1e-6)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        # Lazily built rotary embedding cache.
        self.rope_base = rope_base
        self._rope_cos = None
        self._rope_sin = None
|
|
    def _build_target_layer_ids(self, num_target_layers: int, num_draft_layers: int) -> List[int]:
        """Select target model layer indices for feature extraction.

        Uniformly samples from shallow to deep layers for cross-layer
        feature fusion, as described in the DFlash paper.
        """
        if num_draft_layers == 1:
            return [num_target_layers // 2]
        start = 1
        end = num_target_layers - 3
        span = end - start
        return [
            int(round(start + (i * span) / (num_draft_layers - 1)))
            for i in range(num_draft_layers)
        ]
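
    # Worked example: with a 32-layer target and 5 draft layers, the default
    # selection spans layers 1 to 29 uniformly, giving [1, 8, 15, 22, 29].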
|
|
    def _get_rope_cache(self, seq_len: int):
        """Get or build rotary position embedding cache."""
        if self._rope_cos is None or self._rope_cos.shape[0] < seq_len:
            cos, sin = build_rope_cache(seq_len, self.head_dim, self.rope_base)
            self._rope_cos = cos
            self._rope_sin = sin
        return self._rope_cos[:seq_len], self._rope_sin[:seq_len]
|
|
    def extract_context_features(
        self,
        hidden_states: List[mx.array],
    ) -> mx.array:
        """Extract and fuse target model hidden features.

        Args:
            hidden_states: List of hidden states from target model layers.
                hidden_states[0] is typically the embedding layer output.

        Returns:
            Fused target context feature [bsz, seq_len, hidden_size]
        """
        # hidden_states[0] is the embedding output, so layer i sits at index i + 1.
        offset = 1
        selected = []
        for layer_id in self.target_layer_ids:
            idx = layer_id + offset
            if idx < len(hidden_states):
                selected.append(hidden_states[idx])
            else:
                # Fall back to the deepest available hidden state.
                selected.append(hidden_states[-1])

        if not selected:
            raise RuntimeError("[DFlashDraftModel] No hidden states available for extraction")

        # Concatenate along the feature dim, then project back to hidden_size.
        target_hidden = mx.concatenate(selected, axis=-1)
        return self.hidden_norm(self.fc(target_hidden))
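
    # Shape sketch (assumed sizes): with 5 tapped layers and hidden_size=1024,
    # the concatenation is [bsz, seq_len, 5 * 1024] and `fc` maps it back to
    # [bsz, seq_len, 1024]. This assumes the tapped target features have (or
    # were already projected to) the draft model's hidden_size.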
|
|
    def __call__(
        self,
        noise_embedding: mx.array,
        target_hidden: mx.array,
        attention_mask: Optional[mx.array] = None,
        position_ids: Optional[mx.array] = None,
    ) -> mx.array:
        """Forward pass of the DFlash draft model.

        Args:
            noise_embedding: Embedded noise/mask tokens [bsz, seq_len, hidden_size]
            target_hidden: Fused target context features [bsz, ctx_len, hidden_size]
            attention_mask: Optional attention mask [1, 1, seq_len, ctx_len + seq_len]
            position_ids: Optional positions of the drafted block (defaults to
                the ctx_len..ctx_len+seq_len-1 slots right after the context)

        Returns:
            Hidden states [bsz, seq_len, hidden_size]
        """
        bsz, seq_len = noise_embedding.shape[:2]
        ctx_len = target_hidden.shape[1]

        # Rotary embeddings must cover the context keys (positions 0..ctx_len-1)
        # followed by the block positions, since keys span both.
        if position_ids is None:
            position_ids = mx.arange(ctx_len, ctx_len + seq_len)
        all_position_ids = mx.concatenate([mx.arange(ctx_len), position_ids])
        cos, sin = self._get_rope_cache(int(mx.max(all_position_ids).item()) + 1)
        position_embeddings = (cos[all_position_ids], sin[all_position_ids])

        # Run the decoder stack with KV injection at every layer.
        hidden_states = noise_embedding
        for layer in self.layers:
            hidden_states = layer(
                hidden_states=hidden_states,
                target_hidden=target_hidden,
                attention_mask=attention_mask,
                position_embeddings=position_embeddings,
            )

        return self.norm(hidden_states)
|
|
    def get_logits(self, hidden_states: mx.array) -> mx.array:
        """Get logits from hidden states."""
        return self.lm_head(hidden_states)
|
|
|
|
class DFlashDenoiser:
    """Block diffusion denoising for parallel token prediction.

    Implements the iterative denoising process where masked tokens
    are progressively revealed in parallel within each block.

    For simplicity, this uses single-step denoising (the draft model
    predicts all masked positions at once). The full DFlash paper
    uses multiple denoising steps with noise scheduling.
    """

    def __init__(self, model: DFlashDraftModel, num_steps: int = 12):
        self.model = model
        self.num_steps = num_steps
        self.mask_token_id = model.mask_token_id
|
    def denoise_block(
        self,
        draft_tokens: mx.array,
        target_hidden: mx.array,
        position_ids: mx.array,
        temperature: float = 0.0,
    ) -> mx.array:
        """Denoise a block of masked tokens in parallel.

        Single-step: embed tokens, run the draft model, sample predictions.

        Args:
            draft_tokens: Token IDs with mask tokens [bsz, block_size]
            target_hidden: Target context features [bsz, ctx_len, hidden_size]
            position_ids: Position IDs for the block
            temperature: Sampling temperature

        Returns:
            Predicted token IDs [bsz, block_size]
        """
        # Embed the (mostly masked) draft tokens.
        embeddings = self.model.embed_tokens(draft_tokens)

        # Attention mask over the [context, block] key layout: the block sees
        # the whole target context and is causal within itself.
        seq_len = draft_tokens.shape[1]
        ctx_len = target_hidden.shape[1]
        block_mask = create_causal_mask(seq_len)
        ctx_mask = mx.zeros((1, 1, seq_len, ctx_len))
        mask = mx.concatenate([ctx_mask, block_mask], axis=-1)

        # Single denoising pass through the draft model.
        hidden_states = self.model(
            noise_embedding=embeddings,
            target_hidden=target_hidden,
            position_ids=position_ids,
            attention_mask=mask,
        )

        # Project to the vocabulary and sample every position in parallel.
        logits = self.model.get_logits(hidden_states)

        if temperature < 1e-5:
            # Greedy decoding.
            tokens = mx.argmax(logits, axis=-1)
        else:
            # Temperature sampling.
            probs = mx.softmax(logits / temperature, axis=-1)
            tokens = mx.random.categorical(mx.log(probs))

        return tokens
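

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch. The sizes below are illustrative assumptions, not
# values from the DFlash paper, and the target context features are random
# stand-ins for real target-model hidden states.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    ctx_len, block_size, hidden = 8, 4, 64
    model = DFlashDraftModel(
        vocab_size=1000,
        hidden_size=hidden,
        num_layers=2,
        num_heads=4,
        num_kv_heads=2,
        intermediate_size=128,
        block_size=block_size,
        num_target_layers=8,
    )
    denoiser = DFlashDenoiser(model)

    # Stand-in for fused context features (normally from extract_context_features).
    target_hidden = mx.random.normal((1, ctx_len, hidden))

    # A fully masked block whose positions continue the context.
    draft_tokens = mx.full((1, block_size), model.mask_token_id, dtype=mx.int32)
    position_ids = mx.arange(ctx_len, ctx_len + block_size)

    tokens = denoiser.denoise_block(draft_tokens, target_hidden, position_ids)
    print("drafted block:", tokens)  # expected shape: [1, block_size]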
|
|