"""Tests for DFlash MLX model architecture."""
import unittest
import mlx.core as mx
from dflash_mlx.model import (
RMSNorm,
DFlashAttention,
DFlashMLP,
DFlashDecoderLayer,
DFlashDraftModel,
)
class TestRMSNorm(unittest.TestCase):
    """Unit tests for the RMSNorm normalization layer."""

    def test_shape_preservation(self):
        """Normalizing a (batch, seq, dims) tensor must not change its shape."""
        layer = RMSNorm(dims=128)
        sample = mx.random.normal(shape=(2, 10, 128))
        normalized = layer(sample)
        self.assertEqual(normalized.shape, sample.shape)
class TestDFlashAttention(unittest.TestCase):
    """Unit tests for the DFlashAttention cross-attention module."""

    def test_forward(self):
        """A forward pass preserves the (batch, seq, hidden_size) shape of the
        draft-side input, regardless of the target sequence length."""
        module = DFlashAttention(
            hidden_size=256,
            num_heads=4,
            num_kv_heads=2,
            head_dim=64,
            layer_idx=0,
        )
        draft_states = mx.random.normal(shape=(1, 10, 256))
        target_states = mx.random.normal(shape=(1, 5, 256))
        result = module(draft_states, target_states)
        self.assertEqual(result.shape, (1, 10, 256))
class TestDFlashDraftModel(unittest.TestCase):
    """Unit tests for the end-to-end DFlashDraftModel."""

    def test_forward(self):
        """The model maps a (batch, block, hidden) noise tensor, conditioned on
        target hidden states, back to the same (batch, block, hidden) shape."""
        draft = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=256,
            num_layers=2,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=512,
            max_seq_len=128,
            block_size=16,
        )
        noise_block = mx.random.normal(shape=(1, 16, 256))
        target_states = mx.random.normal(shape=(1, 5, 256))
        prediction = draft(noise_block, target_states)
        self.assertEqual(prediction.shape, (1, 16, 256))

    def test_logits(self):
        """get_logits projects hidden states to (batch, seq, vocab_size)."""
        draft = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=256,
            num_layers=2,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=512,
        )
        states = mx.random.normal(shape=(1, 8, 256))
        vocab_scores = draft.get_logits(states)
        self.assertEqual(vocab_scores.shape, (1, 8, 1000))
# Allow running this test module directly (e.g. `python test_model.py`).
if __name__ == "__main__":
    unittest.main()