tritesh committed
Commit 4e22dea · verified · 1 Parent(s): 834cedc

Upload dflash_mlx/model.py

Files changed (1): dflash_mlx/model.py (+73 -8)
dflash_mlx/model.py CHANGED
@@ -1,10 +1,13 @@
 """
 MLX implementation of the DFlash block diffusion draft model.
 
-This implements the core architecture from the DFlash paper (arXiv:2602.06036):
+Implements the core architecture from the DFlash paper (arXiv:2602.06036):
 - Block-level diffusion for parallel token drafting
 - KV injection of target model hidden features
 - Causal attention within blocks with cross-block masking
+- Position-dependent loss decay
+
+Architecture-agnostic: works with any target model family via adapters.
 """
 
 import math
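
For orientation, a rough sketch of the drafting flow this file implements. Apart from extract_target_features, the names here (target_model's API, mask_id, the denoise method) are illustrative assumptions, not part of this commit:

import mlx.core as mx

def draft_one_block(target_model, draft_model, denoiser, prompt_ids, block_size, mask_id):
    # One target forward pass supplies hidden states for KV injection (assumed API).
    layer_states = target_model(prompt_ids)
    target_hidden = draft_model.extract_target_features(layer_states)
    # The draft block starts fully masked and is denoised in parallel.
    masked = mx.full((1, block_size), mask_id, dtype=mx.int32)
    pos = mx.arange(prompt_ids.shape[1], prompt_ids.shape[1] + block_size)[None, :]
    return denoiser.denoise(masked, target_hidden, pos)  # assumed method name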
@@ -28,14 +31,26 @@ class RMSNorm(nn.Module):
 
 
 def apply_rotary_emb(x, cos, sin):
-    """Apply rotary positional embeddings."""
+    """Apply rotary positional embeddings to x.
+
+    Args:
+        x: [..., seq_len, head_dim]
+        cos, sin: [seq_len, head_dim]
+
+    Returns:
+        Rotated tensor same shape as x
+    """
     x1, x2 = x[..., ::2], x[..., 1::2]
     rotated = mx.stack([-x2, x1], axis=-1).reshape(x.shape)
     return x * cos + rotated * sin
 
 
 def build_rope_cache(seq_len: int, head_dim: int, base: float = 10000.0):
-    """Build rotary positional embedding cache."""
+    """Build rotary positional embedding cache.
+
+    Returns:
+        cos, sin: [seq_len, head_dim] each interleaved for all dims
+    """
     theta = 1.0 / (base ** (mx.arange(0, head_dim, 2) / head_dim))
     positions = mx.arange(seq_len)
     angles = mx.outer(positions, theta)
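
A quick usage sketch for the two RoPE helpers above; shapes follow the docstrings, and the random query tensor is illustrative:

import mlx.core as mx

cos, sin = build_rope_cache(seq_len=16, head_dim=64)  # [16, 64] each, per the docstring
q = mx.random.normal((1, 8, 16, 64))                  # [bsz, n_heads, seq_len, head_dim]
q_rot = apply_rotary_emb(q, cos, sin)                 # same shape, positions encoded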
@@ -47,12 +62,25 @@ def build_rope_cache(seq_len: int, head_dim: int, base: float = 10000.0):
     return cos, sin
 
 
+def create_causal_mask(seq_len: int, dtype=mx.float32) -> mx.array:
+    """Create a causal attention mask for self-attention.
+
+    Returns [1, 1, seq_len, seq_len] mask with -1e9 in the upper triangle.
+    """
+    mask = mx.triu(mx.ones((seq_len, seq_len), dtype=dtype), k=1)
+    mask = mx.where(mask == 1, -1e9, 0.0)
+    return mask[None, None, :, :]  # [1, 1, seq_len, seq_len]
+
+
 class DFlashAttention(nn.Module):
     """Multi-head attention with KV injection from target model features.
 
     This is the core of DFlash: the draft model's attention keys and values
     are augmented with projected target model hidden states, providing rich
     conditioning that enables high acceptance rates.
+
+    Supports both standard attention and KV-injected cross-attention within
+    the same layer.
     """
 
     def __init__(
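
For intuition, the new helper's output for a 3-token block, with -1e9 acting as an effective -inf under softmax:

mask = create_causal_mask(3)  # shape [1, 1, 3, 3]
# mask[0, 0]:
# [[ 0.0, -1e9, -1e9],
#  [ 0.0,  0.0, -1e9],
#  [ 0.0,  0.0,  0.0]]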
@@ -78,7 +106,7 @@ class DFlashAttention(nn.Module):
         self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
         self.o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)
 
-        # Layer norms
+        # Layer norms for Q, K (Qwen3.5-style pre-norm in attention)
         self.q_norm = RMSNorm(head_dim, eps=1e-6)
         self.k_norm = RMSNorm(head_dim, eps=1e-6)
 
@@ -90,6 +118,18 @@ class DFlashAttention(nn.Module):
         position_embeddings: Optional[Tuple[mx.array, mx.array]] = None,
         past_key_values: Optional[Tuple[mx.array, mx.array]] = None,
     ) -> mx.array:
+        """Forward pass with KV injection.
+
+        Args:
+            hidden_states: Draft token embeddings [bsz, q_len, hidden_size]
+            target_hidden: Target context features [bsz, ctx_len, hidden_size]
+            attention_mask: Optional mask [1, 1, q_len, kv_len]
+            position_embeddings: Optional (cos, sin) for RoPE
+            past_key_values: Not used in DFlash (diffusion is non-autoregressive)
+
+        Returns:
+            Attention output [bsz, q_len, hidden_size]
+        """
         bsz, q_len = hidden_states.shape[:2]
         ctx_len = target_hidden.shape[1]
 
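The forward body is not shown in this diff. Given the docstring's kv_len = ctx_len + q_len, the injection step presumably prepends context keys and values; a sketch of that pattern, as an assumption rather than the commit's exact code:

import mlx.core as mx

def inject_kv(k_blk, v_blk, k_ctx, v_ctx):
    # k_ctx/v_ctx are projected from target_hidden; prepending them lets each
    # draft query attend over [context, block], so kv_len = ctx_len + q_len.
    k = mx.concatenate([k_ctx, k_blk], axis=2)  # [bsz, kv_heads, ctx_len + q_len, head_dim]
    v = mx.concatenate([v_ctx, v_blk], axis=2)
    return k, v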
 
@@ -209,6 +249,8 @@ class DFlashDraftModel(nn.Module):
     - Target context feature projection (fuses cross-layer hidden states)
     - Rotary position embeddings
     - Block-wise parallel diffusion
+
+    Universal: config auto-detected from target model or specified explicitly.
     """
 
     def __init__(
@@ -281,7 +323,7 @@ class DFlashDraftModel(nn.Module):
         """Select target model layer indices for feature extraction.
 
         Uniformly samples from shallow to deep layers for cross-layer
-        feature fusion.
+        feature fusion, as described in the DFlash paper.
         """
         if num_draft_layers == 1:
             return [num_target_layers // 2]
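
Only the single-layer branch is visible in this hunk; the uniform shallow-to-deep rule the docstring describes might look like this (the multi-layer branch is an assumption):

def select_layers(num_target_layers: int, num_draft_layers: int) -> list:
    if num_draft_layers == 1:
        return [num_target_layers // 2]  # midpoint, as in the diff
    step = num_target_layers / (num_draft_layers + 1)  # even spacing, endpoints excluded
    return [round(step * (i + 1)) for i in range(num_draft_layers)]

# select_layers(36, 1) -> [18]; select_layers(36, 3) -> [9, 18, 27]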
@@ -308,13 +350,25 @@ class DFlashDraftModel(nn.Module):
         """Extract and fuse target model hidden features.
 
         Args:
-            hidden_states: List of hidden states from target model layers
+            hidden_states: List of hidden states from target model layers.
+                hidden_states[0] is typically embedding layer output.
 
         Returns:
             Fused target context feature [bsz, seq_len, hidden_size]
         """
-        offset = 1  # Skip embedding layer
-        selected = [hidden_states[layer_id + offset] for layer_id in self.target_layer_ids]
+        offset = 1  # Skip embedding layer (usually index 0)
+        selected = []
+        for layer_id in self.target_layer_ids:
+            idx = layer_id + offset
+            if idx < len(hidden_states):
+                selected.append(hidden_states[idx])
+            else:
+                # Fallback: use last available hidden state
+                selected.append(hidden_states[-1])
+
+        if not selected:
+            raise RuntimeError("[DFlashDraftModel] No hidden states available for extraction")
+
         target_hidden = mx.concatenate(selected, axis=-1)
         return self.hidden_norm(self.fc(target_hidden))
 
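Shape arithmetic for the fusion above, with illustrative sizes (three selected layers of width 2048): concatenation widens the feature axis, and self.fc plus hidden_norm map it back to the draft hidden size.

import mlx.core as mx

selected = [mx.zeros((1, 10, 2048)) for _ in range(3)]  # [bsz, seq_len, target_width] each
fused = mx.concatenate(selected, axis=-1)
assert tuple(fused.shape) == (1, 10, 6144)  # fc would then be nn.Linear(6144, hidden_size)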
 
@@ -366,6 +420,10 @@ class DFlashDenoiser:
 
     Implements the iterative denoising process where masked tokens
     are progressively revealed in parallel within each block.
+
+    For simplicity, this uses single-step denoising (the draft model
+    predicts all masked positions at once). The full DFlash paper
+    uses multiple denoising steps with noise scheduling.
     """
 
     def __init__(self, model: DFlashDraftModel, num_steps: int = 12):
@@ -382,6 +440,8 @@ class DFlashDenoiser:
     ) -> mx.array:
         """Denoise a block of masked tokens in parallel.
 
+        Single-step: embed tokens, run draft model, sample predictions.
+
         Args:
             draft_tokens: Token IDs with mask tokens [bsz, block_size]
             target_hidden: Target context features
@@ -394,11 +454,16 @@
         # Embed tokens
         embeddings = self.model.embed_tokens(draft_tokens)
 
+        # Build causal mask for the block (tokens attend to context + earlier positions)
+        seq_len = draft_tokens.shape[1]
+        mask = create_causal_mask(seq_len)
+
         # Run draft model
         hidden_states = self.model(
             noise_embedding=embeddings,
             target_hidden=target_hidden,
             position_ids=position_ids,
+            attention_mask=mask,
         )
 
         # Get logits and sample
 
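One detail worth flagging: create_causal_mask(seq_len) covers only the block, while DFlashAttention's docstring expects a [1, 1, q_len, kv_len] mask with kv_len = ctx_len + q_len after KV injection. The model presumably pads the mask with zeros over the context columns, since every draft token may see the full context; a sketch of that padding, assuming code not shown in this diff:

import mlx.core as mx

def pad_mask_for_context(block_mask: mx.array, ctx_len: int) -> mx.array:
    # Zeros over the context columns: context is fully visible to every draft
    # token; causality applies only within the block.
    bsz, n_heads, q_len, _ = block_mask.shape
    ctx_part = mx.zeros((bsz, n_heads, q_len, ctx_len))
    return mx.concatenate([ctx_part, block_mask], axis=-1)  # [..., q_len, ctx_len + q_len]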