| """ |
| Tests for DFlash draft model and denoiser. |
| |
| Validates model creation, forward pass, and KV injection mechanism. |
| """ |
|
|
| import unittest |
| import mlx.core as mx |
| from dflash_mlx.model import ( |
| RMSNorm, |
| DFlashAttention, |
| DFlashMLP, |
| DFlashDecoderLayer, |
| DFlashDraftModel, |
| DFlashDenoiser, |
| build_rope_cache, |
| apply_rotary_emb, |
| create_causal_mask, |
| ) |
|
|
|
|
class TestRMSNorm(unittest.TestCase):
    """Test RMSNorm implementation."""

    def test_shape_preservation(self):
        """RMSNorm preserves tensor shape."""
        norm = RMSNorm(dims=128)
        x = mx.random.normal((2, 10, 128))
        out = norm(x)
        self.assertEqual(out.shape, x.shape)

    def test_unit_variance(self):
        """RMSNorm normalizes each position to roughly unit RMS."""
        norm = RMSNorm(dims=64, eps=1e-6)
        # Scale the input well away from unit RMS so normalization is observable.
        x = mx.random.normal((1, 5, 64)) * 5.0
        out = norm(x)

        # Sanity: output stays bounded even for a scaled-up input.
        self.assertTrue(mx.all(mx.abs(out) < 100).item())
        # The previous check alone was vacuous for "unit variance"; verify the
        # actual normalization: per-position RMS over the feature axis is ~1.
        # NOTE(review): assumes RMSNorm's gain is initialized to ones — confirm.
        rms = mx.sqrt(mx.mean(out * out, axis=-1))
        self.assertTrue(mx.all(mx.abs(rms - 1.0) < 0.05).item())
|
|
|
|
class TestRoPE(unittest.TestCase):
    """Test rotary positional embeddings."""

    def test_cache_shape(self):
        """Both cos and sin caches cover (seq_len, head_dim)."""
        cos_cache, sin_cache = build_rope_cache(seq_len=100, head_dim=64)
        for cache in (cos_cache, sin_cache):
            self.assertEqual(cache.shape, (100, 64))

    def test_application_shape(self):
        """Applying RoPE leaves the input tensor shape untouched."""
        cos_cache, sin_cache = build_rope_cache(seq_len=10, head_dim=64)
        queries = mx.random.normal((1, 2, 8, 64))
        rotated = apply_rotary_emb(queries, cos_cache[:10], sin_cache[:10])
        self.assertEqual(rotated.shape, queries.shape)
|
|
|
|
class TestCausalMask(unittest.TestCase):
    """Test causal mask creation."""

    def test_mask_shape(self):
        """Mask is broadcast-ready: (1, 1, seq, seq)."""
        causal = create_causal_mask(seq_len=16)
        self.assertEqual(causal.shape, (1, 1, 16, 16))

    def test_mask_values(self):
        """Upper triangle has large negative values."""
        causal = create_causal_mask(seq_len=4)
        # Looking backwards (query 3 -> key 0) is permitted: additive zero.
        self.assertEqual(causal[0, 0, 3, 0].item(), 0.0)
        # Looking forwards (query 0 -> key 3) is blocked with a -inf-like value.
        self.assertLess(causal[0, 0, 0, 3].item(), -1e8)
|
|
|
|
class TestDFlashAttention(unittest.TestCase):
    """Test DFlash attention with KV injection."""

    def test_forward_shape(self):
        """Attention output keeps the draft-side sequence shape."""
        attn = DFlashAttention(
            hidden_size=256,
            num_heads=4,
            num_kv_heads=2,
            head_dim=64,
        )
        draft_h = mx.random.normal((1, 8, 256))
        context_h = mx.random.normal((1, 16, 256))

        self.assertEqual(attn(draft_h, context_h).shape, (1, 8, 256))

    def test_kv_injection(self):
        """Different target hidden states produce different outputs."""
        attn = DFlashAttention(
            hidden_size=64,
            num_heads=2,
            num_kv_heads=1,
            head_dim=32,
        )
        draft_h = mx.random.normal((1, 4, 64))
        context_a = mx.random.normal((1, 8, 64))
        context_b = mx.random.normal((1, 8, 64)) * 2.0

        # If the injected KV context is ignored, both outputs would match.
        out_a = attn(draft_h, context_a)
        out_b = attn(draft_h, context_b)
        self.assertFalse(mx.allclose(out_a, out_b))
|
|
|
|
class TestDFlashDraftModel(unittest.TestCase):
    """Test complete draft model."""

    @staticmethod
    def _build(**overrides):
        """Construct a small DFlashDraftModel, applying per-test overrides."""
        kwargs = {"vocab_size": 1000, "hidden_size": 128, "num_layers": 2}
        kwargs.update(overrides)
        return DFlashDraftModel(**kwargs)

    def test_model_creation(self):
        """Model can be instantiated."""
        model = self._build(
            hidden_size=256,
            num_layers=3,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=704,
            block_size=8,
        )
        self.assertEqual(model.num_layers, 3)
        self.assertEqual(model.hidden_size, 256)

    def test_forward_pass(self):
        """Model forward pass works."""
        model = self._build(
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=352,
            block_size=8,
        )
        noise = mx.random.normal((1, 8, 128))
        context = mx.random.normal((1, 16, 128))

        self.assertEqual(model(noise, context).shape, (1, 8, 128))

    def test_target_layer_ids(self):
        """Target layer IDs are computed correctly."""
        model = self._build(num_layers=5, num_target_layers=32)
        layer_ids = model.target_layer_ids
        self.assertEqual(len(layer_ids), 5)
        # IDs should span the target stack from shallow to deep.
        self.assertGreater(layer_ids[-1], layer_ids[0])

    def test_extract_features(self):
        """Feature extraction works with multi-layer hidden states."""
        model = self._build(num_layers=3, num_target_layers=10)
        per_layer = [mx.random.normal((1, 5, 128)) for _ in range(10)]

        features = model.extract_context_features(per_layer)
        # Batch and sequence dims preserved; features projected to hidden_size.
        self.assertEqual(features.shape[0], 1)
        self.assertEqual(features.shape[1], 5)
        self.assertEqual(features.shape[2], 128)

    def test_get_logits(self):
        """Logits have correct vocab size."""
        n_vocab = 500
        model = self._build(vocab_size=n_vocab)

        logits = model.get_logits(mx.random.normal((1, 8, 128)))
        self.assertEqual(logits.shape, (1, 8, n_vocab))
|
|
|
|
class TestDFlashDenoiser(unittest.TestCase):
    """Test denoising block generation."""

    def test_denoise_shape(self):
        """Denoised block has correct shape."""
        draft_model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=2,
            block_size=8,
        )
        denoiser = DFlashDenoiser(draft_model)

        tokens = mx.array([[100, 200, 300, 400, 500, 600, 700, 800]])
        context = mx.random.normal((1, 16, 128))
        positions = mx.arange(8)

        result = denoiser.denoise_block(tokens, context, positions)
        self.assertEqual(result.shape, (1, 8))

    def test_denoise_greedy(self):
        """Greedy denoising is deterministic."""
        draft_model = DFlashDraftModel(
            vocab_size=100,
            hidden_size=64,
            num_layers=1,
            block_size=4,
        )
        denoiser = DFlashDenoiser(draft_model)

        tokens = mx.array([[0, 1, 2, 3]])
        context = mx.random.normal((1, 8, 64))
        positions = mx.arange(4)

        # Temperature 0 means argmax decoding, so two runs must agree exactly.
        runs = [
            denoiser.denoise_block(tokens, context, positions, temperature=0.0)
            for _ in range(2)
        ]
        mx.eval(*runs)
        self.assertTrue(mx.array_equal(runs[0], runs[1]))
|
|
|
|
class TestModelSerialization(unittest.TestCase):
    """Test saving/loading model parameters."""

    def test_parameter_dict(self):
        """Model can export parameters."""
        model = DFlashDraftModel(
            vocab_size=100,
            hidden_size=64,
            num_layers=2,
        )
        exported = dict(model.parameters())
        # A non-empty export that includes the output projection weight.
        self.assertGreater(len(exported), 0)
        self.assertIn("lm_head.weight", exported)
|
|
|
|
# Allow running this test module directly (e.g. `python test_model.py`).
if __name__ == "__main__":
    unittest.main()
|
|