""" MLX implementation of the DFlash block diffusion draft model. Implements the core architecture from the DFlash paper (arXiv:2602.06036): - Block-level diffusion for parallel token drafting - KV injection of target model hidden features - Causal attention within blocks with cross-block masking - Position-dependent loss decay Architecture-agnostic: works with any target model family via adapters. """ import math from typing import Optional, Tuple, List import mlx.core as mx import mlx.nn as nn class RMSNorm(nn.Module): """RMSNorm as used in Qwen/Llama models.""" def __init__(self, dims: int, eps: float = 1e-6): super().__init__() self.weight = mx.ones((dims,)) self.eps = eps def __call__(self, x): var = mx.mean(mx.square(x), axis=-1, keepdims=True) x = x * mx.rsqrt(var + self.eps) return self.weight * x def apply_rotary_emb(x, cos, sin): """Apply rotary positional embeddings to x. Args: x: [..., seq_len, head_dim] cos, sin: [seq_len, head_dim] Returns: Rotated tensor same shape as x """ x1, x2 = x[..., ::2], x[..., 1::2] rotated = mx.stack([-x2, x1], axis=-1).reshape(x.shape) return x * cos + rotated * sin def build_rope_cache(seq_len: int, head_dim: int, base: float = 10000.0): """Build rotary positional embedding cache. Returns: cos, sin: [seq_len, head_dim] each interleaved for all dims """ theta = 1.0 / (base ** (mx.arange(0, head_dim, 2) / head_dim)) positions = mx.arange(seq_len) angles = mx.outer(positions, theta) cos = mx.cos(angles) sin = mx.sin(angles) # Interleave for all head dimensions cos = mx.repeat(cos, 2, axis=-1) sin = mx.repeat(sin, 2, axis=-1) return cos, sin def create_causal_mask(seq_len: int, dtype=mx.float32) -> mx.array: """Create a causal attention mask for self-attention. Returns [1, 1, seq_len, seq_len] mask with -inf in upper triangle. """ mask = mx.triu(mx.ones((seq_len, seq_len), dtype=dtype), k=1) mask = mx.where(mask == 1, -1e9, 0.0) return mask[None, None, :, :] # [1, 1, seq_len, seq_len] class DFlashAttention(nn.Module): """Multi-head attention with KV injection from target model features. This is the core of DFlash: the draft model's attention keys and values are augmented with projected target model hidden states, providing rich conditioning that enables high acceptance rates. Supports both standard attention and KV-injected cross-attention within the same layer. """ def __init__( self, hidden_size: int, num_heads: int, num_kv_heads: int, head_dim: int, layer_idx: int = 0, ): super().__init__() self.hidden_size = hidden_size self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.head_dim = head_dim self.num_kv_groups = num_heads // num_kv_heads self.layer_idx = layer_idx self.scale = head_dim ** -0.5 # Q, K, V projections for noise tokens self.q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False) self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False) self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False) self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False) # Layer norms for Q, K (Qwen3.5-style pre-norm in attention) self.q_norm = RMSNorm(head_dim, eps=1e-6) self.k_norm = RMSNorm(head_dim, eps=1e-6) def __call__( self, hidden_states: mx.array, target_hidden: mx.array, attention_mask: Optional[mx.array] = None, position_embeddings: Optional[Tuple[mx.array, mx.array]] = None, past_key_values: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: """Forward pass with KV injection. Args: hidden_states: Draft token embeddings [bsz, q_len, hidden_size] target_hidden: Target context features [bsz, ctx_len, hidden_size] attention_mask: Optional mask [1, 1, q_len, kv_len] position_embeddings: Optional (cos, sin) for RoPE past_key_values: Not used in DFlash (diffusion is non-autoregressive) Returns: Attention output [bsz, q_len, hidden_size] """ bsz, q_len = hidden_states.shape[:2] ctx_len = target_hidden.shape[1] # Project noise tokens for queries q = self.q_proj(hidden_states) q = q.reshape(bsz, q_len, self.num_heads, self.head_dim) q = self.q_norm(q).transpose(0, 2, 1, 3) # [bsz, num_heads, q_len, head_dim] # Project target hidden states for context keys/values k_ctx = self.k_proj(target_hidden) v_ctx = self.v_proj(target_hidden) # Project noise tokens for keys/values k_noise = self.k_proj(hidden_states) v_noise = self.v_proj(hidden_states) # Concatenate context + noise for K and V k = mx.concatenate([k_ctx, k_noise], axis=1) v = mx.concatenate([v_ctx, v_noise], axis=1) k = k.reshape(bsz, ctx_len + q_len, self.num_kv_heads, self.head_dim) v = v.reshape(bsz, ctx_len + q_len, self.num_kv_heads, self.head_dim) k = self.k_norm(k).transpose(0, 2, 1, 3) v = v.transpose(0, 2, 1, 3) # Apply rotary embeddings if provided if position_embeddings is not None: cos, sin = position_embeddings q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) # Repeat k/v for grouped query attention if self.num_kv_groups > 1: k = mx.repeat(k, self.num_kv_groups, axis=1) v = mx.repeat(v, self.num_kv_groups, axis=1) # Compute attention scores scores = mx.matmul(q, k.transpose(0, 1, 3, 2)) * self.scale if attention_mask is not None: scores = scores + attention_mask attn_weights = mx.softmax(scores, axis=-1) attn_output = mx.matmul(attn_weights, v) attn_output = attn_output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1) return self.o_proj(attn_output) class DFlashMLP(nn.Module): """Standard SwiGLU MLP as used in modern LLMs.""" def __init__(self, hidden_size: int, intermediate_size: int): super().__init__() self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) def __call__(self, x): return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) class DFlashDecoderLayer(nn.Module): """Single decoder layer with KV-injected attention and MLP.""" def __init__( self, hidden_size: int, num_heads: int, num_kv_heads: int, head_dim: int, intermediate_size: int, layer_idx: int = 0, ): super().__init__() self.hidden_size = hidden_size self.self_attn = DFlashAttention( hidden_size=hidden_size, num_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=head_dim, layer_idx=layer_idx, ) self.mlp = DFlashMLP(hidden_size, intermediate_size) self.input_layernorm = RMSNorm(hidden_size, eps=1e-6) self.post_attention_layernorm = RMSNorm(hidden_size, eps=1e-6) def __call__( self, hidden_states: mx.array, target_hidden: mx.array, attention_mask: Optional[mx.array] = None, position_embeddings: Optional[Tuple[mx.array, mx.array]] = None, ) -> mx.array: # Pre-norm + attention residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states = self.self_attn( hidden_states=hidden_states, target_hidden=target_hidden, attention_mask=attention_mask, position_embeddings=position_embeddings, ) hidden_states = residual + hidden_states # Pre-norm + MLP residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states class DFlashDraftModel(nn.Module): """Complete DFlash block diffusion draft model for MLX. Architecture: - N decoder layers with KV-injected attention - Target context feature projection (fuses cross-layer hidden states) - Rotary position embeddings - Block-wise parallel diffusion Universal: config auto-detected from target model or specified explicitly. """ def __init__( self, vocab_size: int, hidden_size: int = 1024, num_layers: int = 5, num_heads: int = 16, num_kv_heads: int = 4, intermediate_size: int = 2816, max_seq_len: int = 8192, block_size: int = 16, mask_token_id: int = 0, num_target_layers: int = 32, target_layer_ids: Optional[List[int]] = None, rope_base: float = 10000.0, ): super().__init__() self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_layers = num_layers self.num_heads = num_heads self.head_dim = hidden_size // num_heads self.block_size = block_size self.mask_token_id = mask_token_id self.num_target_layers = num_target_layers self.max_seq_len = max_seq_len # Target layer ids for feature extraction if target_layer_ids is None: self.target_layer_ids = self._build_target_layer_ids( num_target_layers, num_layers ) else: self.target_layer_ids = target_layer_ids # Token embeddings for noise/mask tokens self.embed_tokens = nn.Embedding(vocab_size, hidden_size) # Feature projection: fuse multi-layer target features num_target_features = len(self.target_layer_ids) self.fc = nn.Linear(num_target_features * hidden_size, hidden_size, bias=False) self.hidden_norm = RMSNorm(hidden_size, eps=1e-6) # Decoder layers self.layers = [ DFlashDecoderLayer( hidden_size=hidden_size, num_heads=num_heads, num_kv_heads=num_kv_heads, head_dim=self.head_dim, intermediate_size=intermediate_size, layer_idx=i, ) for i in range(num_layers) ] # Final norm self.norm = RMSNorm(hidden_size, eps=1e-6) # Language modeling head (shared with embed_tokens or separate) self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False) # Pre-compute rope cache self.rope_base = rope_base self._rope_cos = None self._rope_sin = None def _build_target_layer_ids(self, num_target_layers: int, num_draft_layers: int) -> List[int]: """Select target model layer indices for feature extraction. Uniformly samples from shallow to deep layers for cross-layer feature fusion, as described in the DFlash paper. """ if num_draft_layers == 1: return [num_target_layers // 2] start = 1 end = num_target_layers - 3 span = end - start return [ int(round(start + (i * span) / (num_draft_layers - 1))) for i in range(num_draft_layers) ] def _get_rope_cache(self, seq_len: int): """Get or build rotary position embedding cache.""" if self._rope_cos is None or self._rope_cos.shape[0] < seq_len: cos, sin = build_rope_cache(seq_len, self.head_dim, self.rope_base) self._rope_cos = cos self._rope_sin = sin return self._rope_cos[:seq_len], self._rope_sin[:seq_len] def extract_context_features( self, hidden_states: List[mx.array], ) -> mx.array: """Extract and fuse target model hidden features. Args: hidden_states: List of hidden states from target model layers. hidden_states[0] is typically embedding layer output. Returns: Fused target context feature [bsz, seq_len, hidden_size] """ offset = 1 # Skip embedding layer (usually index 0) selected = [] for layer_id in self.target_layer_ids: idx = layer_id + offset if idx < len(hidden_states): selected.append(hidden_states[idx]) else: # Fallback: use last available hidden state selected.append(hidden_states[-1]) if not selected: raise RuntimeError("[DFlashDraftModel] No hidden states available for extraction") target_hidden = mx.concatenate(selected, axis=-1) return self.hidden_norm(self.fc(target_hidden)) def __call__( self, noise_embedding: mx.array, target_hidden: mx.array, attention_mask: Optional[mx.array] = None, position_ids: Optional[mx.array] = None, ) -> mx.array: """Forward pass of the DFlash draft model. Args: noise_embedding: Embedded noise/mask tokens [bsz, seq_len, hidden_size] target_hidden: Fused target context features [bsz, ctx_len, hidden_size] attention_mask: Optional attention mask position_ids: Optional position IDs for rotary embeddings Returns: Hidden states [bsz, seq_len, hidden_size] """ bsz, seq_len = noise_embedding.shape[:2] # Build position embeddings if position_ids is None: position_ids = mx.arange(seq_len) cos, sin = self._get_rope_cache(seq_len) position_embeddings = (cos[position_ids], sin[position_ids]) # Pass through decoder layers hidden_states = noise_embedding for layer in self.layers: hidden_states = layer( hidden_states=hidden_states, target_hidden=target_hidden, attention_mask=attention_mask, position_embeddings=position_embeddings, ) return self.norm(hidden_states) def get_logits(self, hidden_states: mx.array) -> mx.array: """Get logits from hidden states.""" return self.lm_head(hidden_states) class DFlashDenoiser: """Block diffusion denoising for parallel token prediction. Implements the iterative denoising process where masked tokens are progressively revealed in parallel within each block. For simplicity, this uses a single-step denoising (the draft model predicts all masked positions at once). The full DFlash paper uses multiple denoising steps with noise scheduling. """ def __init__(self, model: DFlashDraftModel, num_steps: int = 12): self.model = model self.num_steps = num_steps self.mask_token_id = model.mask_token_id def denoise_block( self, draft_tokens: mx.array, target_hidden: mx.array, position_ids: mx.array, temperature: float = 0.0, ) -> mx.array: """Denoise a block of masked tokens in parallel. Single-step: embed tokens, run draft model, sample predictions. Args: draft_tokens: Token IDs with mask tokens [bsz, block_size] target_hidden: Target context features position_ids: Position IDs for the block temperature: Sampling temperature Returns: Predicted token IDs [bsz, block_size] """ # Embed tokens embeddings = self.model.embed_tokens(draft_tokens) # Build causal mask for the block (tokens attend to context + earlier positions) seq_len = draft_tokens.shape[1] mask = create_causal_mask(seq_len) # Run draft model hidden_states = self.model( noise_embedding=embeddings, target_hidden=target_hidden, position_ids=position_ids, attention_mask=mask, ) # Get logits and sample logits = self.model.get_logits(hidden_states) if temperature < 1e-5: # Greedy tokens = mx.argmax(logits, axis=-1) else: # Temperature sampling probs = mx.softmax(logits / temperature, axis=-1) tokens = mx.random.categorical(mx.log(probs)) return tokens