File size: 7,954 Bytes
d00960b
 
 
 
 
0433390
 
 
 
 
 
 
 
 
d00960b
 
 
 
0433390
 
 
 
d00960b
 
0433390
d00960b
0433390
d00960b
 
 
 
 
 
 
 
0433390
d00960b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0433390
 
 
d00960b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0433390
d00960b
 
 
 
0433390
 
 
 
 
 
d00960b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0433390
 
 
d00960b
 
 
 
0433390
 
 
d00960b
0433390
 
d00960b
 
0433390
d00960b
 
 
 
 
0433390
 
d00960b
0433390
 
 
d00960b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0433390
d00960b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0433390
d00960b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0433390
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
"""
Tests for DFlash draft model and denoiser.

Validates model creation, forward pass, and KV injection mechanism.
"""

import unittest
import mlx.core as mx
from dflash_mlx.model import (
    RMSNorm,
    DFlashAttention,
    DFlashMLP,
    DFlashDecoderLayer,
    DFlashDraftModel,
    DFlashDenoiser,
    build_rope_cache,
    apply_rotary_emb,
    create_causal_mask,
)


class TestRMSNorm(unittest.TestCase):
    """Test RMSNorm implementation."""

    def setUp(self):
        # Fixed seed so the random inputs (and any failure) are reproducible.
        mx.random.seed(0)

    def test_shape_preservation(self):
        """RMSNorm preserves tensor shape."""
        norm = RMSNorm(dims=128)
        x = mx.random.normal((2, 10, 128))
        out = norm(x)
        self.assertEqual(out.shape, x.shape)

    def test_unit_variance(self):
        """RMSNorm normalizes away the magnitude of its input.

        Each vector is divided by its root-mean-square before the learned
        gain is applied. Scaling the input by a positive constant must
        therefore leave the output (essentially) unchanged, whatever values
        the gain holds. Only the tiny ``eps`` term breaks exact invariance.
        """
        norm = RMSNorm(dims=64, eps=1e-6)
        x = mx.random.normal((1, 5, 64))
        out_unit = norm(x)
        out_scaled = norm(x * 5.0)
        # allclose is False on NaN/inf, so this also guards against a
        # numerically broken normalization.
        self.assertTrue(mx.allclose(out_unit, out_scaled, atol=1e-4).item())


class TestRoPE(unittest.TestCase):
    """Rotary positional embedding (RoPE) helpers."""

    def test_cache_shape(self):
        """build_rope_cache returns (seq_len, head_dim) cos and sin tables."""
        n_pos, d_head = 100, 64
        cos_tab, sin_tab = build_rope_cache(seq_len=n_pos, head_dim=d_head)
        for table in (cos_tab, sin_tab):
            self.assertEqual(table.shape, (n_pos, d_head))

    def test_application_shape(self):
        """apply_rotary_emb leaves the input shape untouched."""
        n_pos, d_head = 10, 64
        cos_tab, sin_tab = build_rope_cache(seq_len=n_pos, head_dim=d_head)
        q = mx.random.normal((1, 2, 8, d_head))  # [bsz, seq, heads, dim]
        rotated = apply_rotary_emb(q, cos_tab[:n_pos], sin_tab[:n_pos])
        self.assertEqual(rotated.shape, q.shape)


class TestCausalMask(unittest.TestCase):
    """Test causal mask creation."""

    def test_mask_shape(self):
        """Mask has correct shape."""
        mask = create_causal_mask(seq_len=16)
        self.assertEqual(mask.shape, (1, 1, 16, 16))

    def test_mask_values(self):
        """Entries with key <= query are 0; all future entries are very negative."""
        seq_len = 4
        mask = create_causal_mask(seq_len=seq_len)[0, 0]
        rows = mx.arange(seq_len)[:, None]
        cols = mx.arange(seq_len)[None, :]
        # Query i may attend to keys 0..i (diagonal included).
        visible = cols <= rows
        # Select with `where` rather than multiplying by a triangle: the
        # mask may hold -inf, and -inf * 0 would be NaN.
        self.assertTrue(mx.all(mx.where(visible, mask == 0.0, True)).item())
        self.assertTrue(mx.all(mx.where(visible, True, mask < -1e8)).item())


class TestDFlashAttention(unittest.TestCase):
    """DFlash attention whose keys/values are injected from target context."""

    def test_forward_shape(self):
        """Output keeps the draft tokens' (batch, seq, hidden) shape."""
        layer = DFlashAttention(
            hidden_size=256,
            num_heads=4,
            num_kv_heads=2,
            head_dim=64,
        )
        draft = mx.random.normal((1, 8, 256))     # draft tokens
        context = mx.random.normal((1, 16, 256))  # target context
        self.assertEqual(layer(draft, context).shape, (1, 8, 256))

    def test_kv_injection(self):
        """Changing only the target context changes the attention output."""
        layer = DFlashAttention(
            hidden_size=64,
            num_heads=2,
            num_kv_heads=1,
            head_dim=32,
        )
        draft = mx.random.normal((1, 4, 64))
        context_a = mx.random.normal((1, 8, 64))
        context_b = 2.0 * mx.random.normal((1, 8, 64))

        same = mx.allclose(layer(draft, context_a), layer(draft, context_b))
        self.assertFalse(same.item())


class TestDFlashDraftModel(unittest.TestCase):
    """End-to-end checks for the complete DFlash draft model."""

    @staticmethod
    def _build(**config):
        """Instantiate a DFlashDraftModel from keyword configuration."""
        return DFlashDraftModel(**config)

    def test_model_creation(self):
        """Constructor records the requested depth and width."""
        model = self._build(
            vocab_size=1000,
            hidden_size=256,
            num_layers=3,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=704,
            block_size=8,
        )
        self.assertEqual(model.num_layers, 3)
        self.assertEqual(model.hidden_size, 256)

    def test_forward_pass(self):
        """Forward pass maps (1, 8, H) draft embeddings to (1, 8, H)."""
        width = 128
        model = self._build(
            vocab_size=1000,
            hidden_size=width,
            num_layers=2,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=352,
            block_size=8,
        )
        # Stand-ins for embedded draft tokens and the target's context states.
        noise_embed = mx.random.normal((1, 8, width))
        target_hidden = mx.random.normal((1, 16, width))

        self.assertEqual(model(noise_embed, target_hidden).shape, (1, 8, width))

    def test_target_layer_ids(self):
        """One target layer id per draft layer, spanning early to late."""
        model = self._build(
            vocab_size=1000,
            hidden_size=128,
            num_layers=5,
            num_target_layers=32,
        )
        layer_ids = model.target_layer_ids
        self.assertEqual(len(layer_ids), 5)
        last, first = layer_ids[-1], layer_ids[0]
        self.assertGreater(last, first)

    def test_extract_features(self):
        """Context features keep batch/seq dims and project to hidden_size."""
        model = self._build(
            vocab_size=1000,
            hidden_size=128,
            num_layers=3,
            num_target_layers=10,
        )
        # One fake (1, 5, 128) hidden state per target layer.
        per_layer = [mx.random.normal((1, 5, 128)) for _ in range(10)]

        features = model.extract_context_features(per_layer)
        for axis, expected in enumerate((1, 5, 128)):
            self.assertEqual(features.shape[axis], expected)

    def test_get_logits(self):
        """get_logits yields one vocab-sized score vector per position."""
        n_vocab = 500
        model = self._build(
            vocab_size=n_vocab,
            hidden_size=128,
            num_layers=2,
        )
        states = mx.random.normal((1, 8, 128))
        self.assertEqual(model.get_logits(states).shape, (1, 8, n_vocab))


class TestDFlashDenoiser(unittest.TestCase):
    """Block-level token denoising with DFlashDenoiser."""

    def test_denoise_shape(self):
        """denoise_block returns one token id per draft position."""
        block = 8
        draft_model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=2,
            block_size=block,
        )
        denoiser = DFlashDenoiser(draft_model)

        tokens = mx.array([[100 * (k + 1) for k in range(block)]])
        context = mx.random.normal((1, 16, 128))
        positions = mx.arange(block)

        result = denoiser.denoise_block(tokens, context, positions)
        self.assertEqual(result.shape, (1, block))

    def test_denoise_greedy(self):
        """With temperature 0 the same inputs always give the same tokens."""
        block = 4
        draft_model = DFlashDraftModel(
            vocab_size=100,
            hidden_size=64,
            num_layers=1,
            block_size=block,
        )
        denoiser = DFlashDenoiser(draft_model)

        tokens = mx.array([list(range(block))])
        context = mx.random.normal((1, 8, 64))
        positions = mx.arange(block)

        runs = [
            denoiser.denoise_block(tokens, context, positions, temperature=0.0)
            for _ in range(2)
        ]
        mx.eval(*runs)
        self.assertTrue(mx.array_equal(runs[0], runs[1]))


class TestModelSerialization(unittest.TestCase):
    """Test saving/loading model parameters."""

    def test_parameter_dict(self):
        """Model exports its parameters under flat, dot-separated names."""
        # Module.parameters() returns a nested dict/list tree, so calling
        # dict() on it only exposes top-level member names (e.g. "lm_head").
        # tree_flatten yields (dotted_name, array) pairs such as
        # ("lm_head.weight", ...).
        from mlx.utils import tree_flatten

        model = DFlashDraftModel(
            vocab_size=100,
            hidden_size=64,
            num_layers=2,
        )
        params = dict(tree_flatten(model.parameters()))
        self.assertGreater(len(params), 0)
        self.assertIn("lm_head.weight", params)
        # Every flattened leaf should be an actual array, not a sub-tree.
        for name, value in params.items():
            with self.subTest(name=name):
                self.assertIsInstance(value, mx.array)


if __name__ == "__main__":
    unittest.main()