File size: 1,879 Bytes
0433390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Tests for DFlash MLX model architecture."""

import unittest
import mlx.core as mx
from dflash_mlx.model import (
    RMSNorm,
    DFlashAttention,
    DFlashMLP,
    DFlashDecoderLayer,
    DFlashDraftModel,
)


class TestRMSNorm(unittest.TestCase):
    """Checks for the ``RMSNorm`` layer."""

    def test_shape_preservation(self):
        # The layer's output must have exactly the same shape as its input.
        feature_dim = 128
        layer = RMSNorm(dims=feature_dim)
        inputs = mx.random.normal(shape=(2, 10, feature_dim))
        self.assertEqual(layer(inputs).shape, inputs.shape)


class TestDFlashAttention(unittest.TestCase):
    """Checks for the ``DFlashAttention`` module."""

    def test_forward(self):
        # Two inputs with sequence lengths 10 and 5 are fed in; the output is
        # expected to have the same shape as the first input.
        hidden_size = 256
        layer_config = {
            "hidden_size": hidden_size,
            "num_heads": 4,
            "num_kv_heads": 2,
            "head_dim": 64,
            "layer_idx": 0,
        }
        attention = DFlashAttention(**layer_config)

        hidden_states = mx.random.normal(shape=(1, 10, hidden_size))
        target_states = mx.random.normal(shape=(1, 5, hidden_size))

        result = attention(hidden_states, target_states)
        self.assertEqual(result.shape, (1, 10, hidden_size))


class TestDFlashDraftModel(unittest.TestCase):
    """Output-shape checks for ``DFlashDraftModel``."""

    _HIDDEN_SIZE = 256
    _VOCAB_SIZE = 1000

    # Constructor arguments common to every test in this class.
    _SHARED_CONFIG = {
        "vocab_size": _VOCAB_SIZE,
        "hidden_size": _HIDDEN_SIZE,
        "num_layers": 2,
        "num_heads": 4,
        "num_kv_heads": 2,
        "intermediate_size": 512,
    }

    def test_forward(self):
        # Calling the model on a 16-step input plus a 5-step second input
        # should give an output shaped like the first input.
        draft_model = DFlashDraftModel(
            **self._SHARED_CONFIG,
            max_seq_len=128,
            block_size=16,
        )
        noise_states = mx.random.normal(shape=(1, 16, self._HIDDEN_SIZE))
        target_states = mx.random.normal(shape=(1, 5, self._HIDDEN_SIZE))

        result = draft_model(noise_states, target_states)
        self.assertEqual(result.shape, (1, 16, self._HIDDEN_SIZE))

    def test_logits(self):
        # get_logits should replace the trailing hidden dimension with the
        # vocabulary size while keeping the leading dimensions.
        draft_model = DFlashDraftModel(**self._SHARED_CONFIG)
        hidden_states = mx.random.normal(shape=(1, 8, self._HIDDEN_SIZE))

        scores = draft_model.get_logits(hidden_states)
        self.assertEqual(scores.shape, (1, 8, self._VOCAB_SIZE))


# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()