dflash-mlx-universal / tests /test_model.py
tritesh's picture
Upload folder using huggingface_hub
0433390 verified
raw
history blame
1.88 kB
"""Tests for DFlash MLX model architecture."""
import unittest
import mlx.core as mx
from dflash_mlx.model import (
RMSNorm,
DFlashAttention,
DFlashMLP,
DFlashDecoderLayer,
DFlashDraftModel,
)
class TestRMSNorm(unittest.TestCase):
    """Smoke tests for ``RMSNorm``."""

    def test_shape_preservation(self):
        """RMSNorm returns an array with the same shape as its input.

        MLX records a lazy compute graph, and an array's ``.shape`` is known
        before any computation runs. The output is therefore evaluated
        explicitly, so the normalization actually executes, and then checked
        for NaNs.
        """
        norm = RMSNorm(dims=128)
        x = mx.random.normal(shape=(2, 10, 128))
        out = norm(x)
        # Force the lazy graph to run; a shape comparison alone would not.
        mx.eval(out)
        self.assertEqual(out.shape, x.shape)
        self.assertFalse(mx.any(mx.isnan(out)).item(), "RMSNorm output contains NaN")
class TestDFlashAttention(unittest.TestCase):
    """Smoke tests for ``DFlashAttention``."""

    def test_forward(self):
        """Attention over (hidden, target_hidden) returns hidden's shape.

        The query sequence (10) and the target context (5) deliberately have
        different lengths. The output is evaluated explicitly because MLX is
        lazy and the shape check alone would not execute the attention
        computation.
        """
        attn = DFlashAttention(
            hidden_size=256,
            num_heads=4,
            num_kv_heads=2,
            head_dim=64,
            layer_idx=0,
        )
        hidden = mx.random.normal(shape=(1, 10, 256))
        target_hidden = mx.random.normal(shape=(1, 5, 256))
        out = attn(hidden, target_hidden)
        # Force the lazy graph to run so kernel-level failures surface here.
        mx.eval(out)
        self.assertEqual(out.shape, (1, 10, 256))
        self.assertFalse(mx.any(mx.isnan(out)).item(), "attention output contains NaN")
class TestDFlashDraftModel(unittest.TestCase):
    """Smoke tests for ``DFlashDraftModel`` (forward pass and ``get_logits``)."""

    # Constructor arguments shared by every test. Individual tests add extra
    # keyword arguments through ``_make_model``.
    _CONFIG = {
        "vocab_size": 1000,
        "hidden_size": 256,
        "num_layers": 2,
        "num_heads": 4,
        "num_kv_heads": 2,
        "intermediate_size": 512,
    }

    def _make_model(self, **overrides):
        """Build a small draft model from ``_CONFIG`` updated with *overrides*."""
        return DFlashDraftModel(**{**self._CONFIG, **overrides})

    def _assert_evaluated_without_nan(self, arr):
        """Force lazy evaluation of *arr* and assert it contains no NaN."""
        # MLX is lazy: shapes are known without running anything, so evaluate
        # explicitly to make sure the forward computation actually executes.
        mx.eval(arr)
        self.assertFalse(mx.any(mx.isnan(arr)).item(), "model output contains NaN")

    def test_forward(self):
        """Forward pass returns one hidden vector per noise position."""
        block_size = 16
        hidden_size = self._CONFIG["hidden_size"]
        model = self._make_model(max_seq_len=128, block_size=block_size)
        # The noise sequence is sized to match block_size, as in the original test.
        noise = mx.random.normal(shape=(1, block_size, hidden_size))
        target = mx.random.normal(shape=(1, 5, hidden_size))
        out = model(noise, target)
        self._assert_evaluated_without_nan(out)
        self.assertEqual(out.shape, (1, block_size, hidden_size))

    def test_logits(self):
        """get_logits maps hidden states to vocabulary-sized logits."""
        hidden_size = self._CONFIG["hidden_size"]
        vocab_size = self._CONFIG["vocab_size"]
        model = self._make_model()
        hidden = mx.random.normal(shape=(1, 8, hidden_size))
        logits = model.get_logits(hidden)
        self._assert_evaluated_without_nan(logits)
        self.assertEqual(logits.shape, (1, 8, vocab_size))
if __name__ == "__main__":
    # Run every TestCase in this module when the file is executed directly.
    # module="__main__" is unittest.main()'s own default, spelled out here.
    unittest.main(module="__main__")