"""Tests for DFlash MLX model architecture.""" import unittest import mlx.core as mx from dflash_mlx.model import ( RMSNorm, DFlashAttention, DFlashMLP, DFlashDecoderLayer, DFlashDraftModel, ) class TestRMSNorm(unittest.TestCase): def test_shape_preservation(self): norm = RMSNorm(dims=128) x = mx.random.normal(shape=(2, 10, 128)) out = norm(x) self.assertEqual(out.shape, x.shape) class TestDFlashAttention(unittest.TestCase): def test_forward(self): attn = DFlashAttention( hidden_size=256, num_heads=4, num_kv_heads=2, head_dim=64, layer_idx=0, ) hidden = mx.random.normal(shape=(1, 10, 256)) target_hidden = mx.random.normal(shape=(1, 5, 256)) out = attn(hidden, target_hidden) self.assertEqual(out.shape, (1, 10, 256)) class TestDFlashDraftModel(unittest.TestCase): def test_forward(self): model = DFlashDraftModel( vocab_size=1000, hidden_size=256, num_layers=2, num_heads=4, num_kv_heads=2, intermediate_size=512, max_seq_len=128, block_size=16, ) noise = mx.random.normal(shape=(1, 16, 256)) target = mx.random.normal(shape=(1, 5, 256)) out = model(noise, target) self.assertEqual(out.shape, (1, 16, 256)) def test_logits(self): model = DFlashDraftModel( vocab_size=1000, hidden_size=256, num_layers=2, num_heads=4, num_kv_heads=2, intermediate_size=512, ) hidden = mx.random.normal(shape=(1, 8, 256)) logits = model.get_logits(hidden) self.assertEqual(logits.shape, (1, 8, 1000)) if __name__ == "__main__": unittest.main()