| """Tests for DFlash MLX model architecture.""" |
|
|
| import unittest |
| import mlx.core as mx |
| from dflash_mlx.model import ( |
| RMSNorm, |
| DFlashAttention, |
| DFlashMLP, |
| DFlashDecoderLayer, |
| DFlashDraftModel, |
| ) |
|
|
|
|
| class TestRMSNorm(unittest.TestCase): |
| def test_shape_preservation(self): |
| norm = RMSNorm(dims=128) |
| x = mx.random.normal(shape=(2, 10, 128)) |
| out = norm(x) |
| self.assertEqual(out.shape, x.shape) |
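
    # A minimal extra check, assuming the scale weight is ones-initialised
    # (the usual default): each normalized feature vector should then have
    # RMS close to 1, up to the epsilon inside the norm.
    def test_unit_rms(self):
        norm = RMSNorm(dims=128)
        x = mx.random.normal(shape=(2, 10, 128))
        out = norm(x)
        rms = mx.sqrt(mx.mean(out**2, axis=-1))
        self.assertTrue(mx.allclose(rms, mx.ones_like(rms), atol=1e-2).item())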


class TestDFlashAttention(unittest.TestCase):
    def test_forward(self):
        # The draft block (length 10) attends over the target hidden states
        # (length 5); the output keeps the query sequence's shape.
        attn = DFlashAttention(
            hidden_size=256,
            num_heads=4,
            num_kv_heads=2,
            head_dim=64,
            layer_idx=0,
        )
        hidden = mx.random.normal(shape=(1, 10, 256))
        target_hidden = mx.random.normal(shape=(1, 5, 256))
        out = attn(hidden, target_hidden)
        self.assertEqual(out.shape, (1, 10, 256))
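

# DFlashMLP is imported above but otherwise untested. A minimal sketch,
# assuming its constructor takes the same hidden_size/intermediate_size
# hyperparameters used for DFlashDraftModel below and that it maps
# (batch, seq, hidden) -> (batch, seq, hidden).
class TestDFlashMLP(unittest.TestCase):
    def test_shape_preservation(self):
        mlp = DFlashMLP(hidden_size=256, intermediate_size=512)
        x = mx.random.normal(shape=(2, 10, 256))
        out = mlp(x)
        self.assertEqual(out.shape, x.shape)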


class TestDFlashDraftModel(unittest.TestCase):
    def test_forward(self):
        # A block of noise embeddings plus target context goes in; the output
        # matches the noise block's shape.
        model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=256,
            num_layers=2,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=512,
            max_seq_len=128,
            block_size=16,
        )
        noise = mx.random.normal(shape=(1, 16, 256))
        target = mx.random.normal(shape=(1, 5, 256))
        out = model(noise, target)
        self.assertEqual(out.shape, (1, 16, 256))

    def test_logits(self):
        # get_logits projects hidden states onto the vocabulary.
        model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=256,
            num_layers=2,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=512,
        )
        hidden = mx.random.normal(shape=(1, 8, 256))
        logits = model.get_logits(hidden)
        self.assertEqual(logits.shape, (1, 8, 1000))
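

# DFlashDecoderLayer is imported above but otherwise untested. A minimal
# sketch, assuming the layer wraps DFlashAttention and DFlashMLP, is built
# from the same hyperparameters they take, and shares the attention module's
# (hidden, target_hidden) call signature.
class TestDFlashDecoderLayer(unittest.TestCase):
    def test_forward(self):
        layer = DFlashDecoderLayer(
            hidden_size=256,
            num_heads=4,
            num_kv_heads=2,
            head_dim=64,
            intermediate_size=512,
            layer_idx=0,
        )
        hidden = mx.random.normal(shape=(1, 10, 256))
        target_hidden = mx.random.normal(shape=(1, 5, 256))
        out = layer(hidden, target_hidden)
        self.assertEqual(out.shape, (1, 10, 256))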


if __name__ == "__main__":
    unittest.main()