tritesh committed on
Commit
d00960b
·
verified ·
1 Parent(s): 9579572

Upload tests/test_model.py

Browse files
Files changed (1) hide show
  1. tests/test_model.py +204 -23
tests/test_model.py CHANGED
@@ -1,4 +1,8 @@
1
- """Tests for DFlash MLX model architecture."""
 
 
 
 
2
 
3
  import unittest
4
  import mlx.core as mx
@@ -8,61 +12,238 @@ from dflash_mlx.model import (
8
  DFlashMLP,
9
  DFlashDecoderLayer,
10
  DFlashDraftModel,
 
 
 
 
11
  )
12
 
13
 
14
  class TestRMSNorm(unittest.TestCase):
 
 
15
  def test_shape_preservation(self):
 
16
  norm = RMSNorm(dims=128)
17
- x = mx.random.normal(shape=(2, 10, 128))
 
 
 
 
 
 
 
18
  out = norm(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  self.assertEqual(out.shape, x.shape)
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  class TestDFlashAttention(unittest.TestCase):
23
- def test_forward(self):
 
 
 
24
  attn = DFlashAttention(
25
  hidden_size=256,
26
  num_heads=4,
27
  num_kv_heads=2,
28
  head_dim=64,
29
- layer_idx=0,
30
  )
31
- hidden = mx.random.normal(shape=(1, 10, 256))
32
- target_hidden = mx.random.normal(shape=(1, 5, 256))
33
- out = attn(hidden, target_hidden)
34
- self.assertEqual(out.shape, (1, 10, 256))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
  class TestDFlashDraftModel(unittest.TestCase):
38
- def test_forward(self):
 
 
 
39
  model = DFlashDraftModel(
40
  vocab_size=1000,
41
  hidden_size=256,
42
- num_layers=2,
43
  num_heads=4,
44
  num_kv_heads=2,
45
- intermediate_size=512,
46
- max_seq_len=128,
47
- block_size=16,
48
  )
49
- noise = mx.random.normal(shape=(1, 16, 256))
50
- target = mx.random.normal(shape=(1, 5, 256))
51
- out = model(noise, target)
52
- self.assertEqual(out.shape, (1, 16, 256))
53
-
54
- def test_logits(self):
55
  model = DFlashDraftModel(
56
  vocab_size=1000,
57
- hidden_size=256,
58
  num_layers=2,
59
  num_heads=4,
60
  num_kv_heads=2,
61
- intermediate_size=512,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  )
63
- hidden = mx.random.normal(shape=(1, 8, 256))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  logits = model.get_logits(hidden)
65
- self.assertEqual(logits.shape, (1, 8, 1000))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
 
68
  if __name__ == "__main__":
 
1
+ """
2
+ Tests for DFlash draft model and denoiser.
3
+
4
+ Validates model creation, forward pass, and KV injection mechanism.
5
+ """
6
 
7
  import unittest
8
  import mlx.core as mx
 
12
  DFlashMLP,
13
  DFlashDecoderLayer,
14
  DFlashDraftModel,
15
+ DFlashDenoiser,
16
+ build_rope_cache,
17
+ apply_rotary_emb,
18
+ create_causal_mask,
19
  )
20
 
21
 
22
class TestRMSNorm(unittest.TestCase):
    """Unit tests for the RMSNorm layer."""

    def test_shape_preservation(self):
        """The normalized output keeps the input tensor's shape."""
        norm = RMSNorm(dims=128)
        x = mx.random.normal((2, 10, 128))
        out = norm(x)
        self.assertEqual(out.shape, x.shape)

    def test_unit_variance(self):
        """Normalizing a widely scaled input yields reasonably bounded values."""
        norm = RMSNorm(dims=64, eps=1e-6)
        x = mx.random.normal((1, 5, 64)) * 5.0
        out = norm(x)
        # A correct RMS normalization cannot blow values up arbitrarily.
        self.assertTrue(mx.all(mx.abs(out) < 100).item())
40
+
41
class TestRoPE(unittest.TestCase):
    """Unit tests for the rotary positional embedding helpers."""

    def test_cache_shape(self):
        """The precomputed cos/sin tables are [seq_len, head_dim]."""
        cos, sin = build_rope_cache(seq_len=100, head_dim=64)
        self.assertEqual(cos.shape, (100, 64))
        self.assertEqual(sin.shape, (100, 64))

    def test_application_shape(self):
        """Applying RoPE leaves the input shape untouched."""
        cos, sin = build_rope_cache(seq_len=10, head_dim=64)
        x = mx.random.normal((1, 2, 8, 64))  # [bsz, seq, heads, dim]
        out = apply_rotary_emb(x, cos[:10], sin[:10])
        self.assertEqual(out.shape, x.shape)
57
 
58
class TestCausalMask(unittest.TestCase):
    """Unit tests for causal attention mask construction."""

    def test_mask_shape(self):
        """Mask is broadcastable over batch and heads: (1, 1, S, S)."""
        mask = create_causal_mask(seq_len=16)
        self.assertEqual(mask.shape, (1, 1, 16, 16))

    def test_mask_values(self):
        """Lower triangle passes attention; upper triangle is blocked."""
        mask = create_causal_mask(seq_len=4)
        # Extract Python scalars with .item() so unittest reports the actual
        # values on failure instead of an opaque array-truthiness assertTrue.
        self.assertEqual(mask[0, 0, 3, 0].item(), 0.0)  # past position: allowed
        self.assertLess(mask[0, 0, 0, 3].item(), -1e8)  # future position: masked
74
+
75
class TestDFlashAttention(unittest.TestCase):
    """Unit tests for DFlash attention with target-KV injection."""

    def test_forward_shape(self):
        """Attention output matches the draft-token sequence shape."""
        attn = DFlashAttention(
            hidden_size=256,
            num_heads=4,
            num_kv_heads=2,
            head_dim=64,
        )
        hidden = mx.random.normal((1, 8, 256))  # draft tokens
        target_h = mx.random.normal((1, 16, 256))  # target context

        out = attn(hidden, target_h)
        self.assertEqual(out.shape, (1, 8, 256))

    def test_kv_injection(self):
        """Different target hidden states produce different outputs."""
        # Seed the RNG: this test asserts an inequality between outputs built
        # from random inputs, so an unlucky draw could otherwise make it flaky.
        mx.random.seed(0)
        attn = DFlashAttention(
            hidden_size=64,
            num_heads=2,
            num_kv_heads=1,
            head_dim=32,
        )
        hidden = mx.random.normal((1, 4, 64))
        target_h1 = mx.random.normal((1, 8, 64))
        target_h2 = mx.random.normal((1, 8, 64)) * 2.0

        out1 = attn(hidden, target_h1)
        out2 = attn(hidden, target_h2)

        # The target context feeds K/V, so changing it must change the output.
        self.assertFalse(mx.allclose(out1, out2).item())
109
 
110
 
111
class TestDFlashDraftModel(unittest.TestCase):
    """Unit tests for the complete draft model."""

    def test_model_creation(self):
        """Model can be instantiated and exposes its configuration."""
        model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=256,
            num_layers=3,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=704,
            block_size=8,
        )
        self.assertEqual(model.num_layers, 3)
        self.assertEqual(model.hidden_size, 256)

    def test_forward_pass(self):
        """Forward pass maps a noise block plus target context to a block."""
        model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=2,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=352,
            block_size=8,
        )

        # Noise embedding stands in for embedded draft tokens.
        noise_embed = mx.random.normal((1, 8, 128))
        target_hidden = mx.random.normal((1, 16, 128))

        out = model(noise_embed, target_hidden)
        self.assertEqual(out.shape, (1, 8, 128))

    def test_target_layer_ids(self):
        """Target layer IDs are computed correctly."""
        model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=5,
            num_target_layers=32,
        )
        ids = model.target_layer_ids
        self.assertEqual(len(ids), 5)
        # IDs should span from early to late target layers.
        self.assertGreater(ids[-1], ids[0])

    def test_extract_features(self):
        """Feature extraction works with multi-layer hidden states."""
        model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=3,
            num_target_layers=10,
        )

        # Fake per-layer hidden states from a 10-layer target model.
        hidden_states = [mx.random.normal((1, 5, 128)) for _ in range(10)]

        features = model.extract_context_features(hidden_states)
        # Batch and sequence dims survive; channels project back to hidden_size.
        self.assertEqual(features.shape[0], 1)
        self.assertEqual(features.shape[1], 5)
        self.assertEqual(features.shape[2], 128)

    def test_get_logits(self):
        """Logits have correct vocab size."""
        vocab_size = 500
        model = DFlashDraftModel(
            vocab_size=vocab_size,
            hidden_size=128,
            num_layers=2,
        )

        hidden = mx.random.normal((1, 8, 128))
        logits = model.get_logits(hidden)
        self.assertEqual(logits.shape, (1, 8, vocab_size))
191
+
192
+
193
class TestDFlashDenoiser(unittest.TestCase):
    """Unit tests for denoising block generation."""

    def test_denoise_shape(self):
        """Denoised block has correct shape."""
        model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=2,
            block_size=8,
        )
        denoiser = DFlashDenoiser(model)

        draft_tokens = mx.array([[100, 200, 300, 400, 500, 600, 700, 800]])
        target_hidden = mx.random.normal((1, 16, 128))
        position_ids = mx.arange(8)

        out = denoiser.denoise_block(draft_tokens, target_hidden, position_ids)
        self.assertEqual(out.shape, (1, 8))

    def test_denoise_greedy(self):
        """Greedy denoising (temperature=0) is deterministic."""
        model = DFlashDraftModel(
            vocab_size=100,
            hidden_size=64,
            num_layers=1,
            block_size=4,
        )
        denoiser = DFlashDenoiser(model)

        draft_tokens = mx.array([[0, 1, 2, 3]])
        target_hidden = mx.random.normal((1, 8, 64))
        position_ids = mx.arange(4)

        # Two greedy runs on identical inputs must agree token-for-token.
        out1 = denoiser.denoise_block(draft_tokens, target_hidden, position_ids, temperature=0.0)
        out2 = denoiser.denoise_block(draft_tokens, target_hidden, position_ids, temperature=0.0)

        mx.eval(out1, out2)
        self.assertTrue(mx.array_equal(out1, out2))
232
+
233
+
234
class TestModelSerialization(unittest.TestCase):
    """Unit tests for exporting model parameters."""

    def test_parameter_dict(self):
        """Model exposes a non-empty parameter dict including the LM head."""
        model = DFlashDraftModel(
            vocab_size=100,
            hidden_size=64,
            num_layers=2,
        )
        params = dict(model.parameters())
        self.assertTrue(len(params) > 0)
        self.assertIn("lm_head.weight", params)
247
 
248
 
249
  if __name__ == "__main__":