"""
Tests for DFlash draft model and denoiser.

Validates model creation, forward pass, and KV injection mechanism.
"""

import unittest

import mlx.core as mx
from mlx.utils import tree_flatten

from dflash_mlx.model import (
    RMSNorm,
    DFlashAttention,
    DFlashMLP,
    DFlashDecoderLayer,
    DFlashDraftModel,
    DFlashDenoiser,
    build_rope_cache,
    apply_rotary_emb,
    create_causal_mask,
)


class _SeededTestCase(unittest.TestCase):
    """Base test case that seeds MLX's global RNG before every test.

    The tests draw their inputs from ``mx.random.normal``, and the models they
    build presumably initialize weights from the same global RNG. Seeding makes
    every run see the same numbers, so a failure is reproducible.
    """

    seed = 0

    def setUp(self):
        mx.random.seed(self.seed)


class TestRMSNorm(_SeededTestCase):
    """Test RMSNorm implementation."""

    def test_shape_preservation(self):
        """RMSNorm preserves tensor shape."""
        norm = RMSNorm(dims=128)
        x = mx.random.normal((2, 10, 128))
        out = norm(x)
        self.assertEqual(out.shape, x.shape)

    def test_unit_variance(self):
        """RMSNorm output stays in a bounded range for scaled-up inputs."""
        norm = RMSNorm(dims=64, eps=1e-6)
        x = mx.random.normal((1, 5, 64)) * 5.0
        out = norm(x)
        # Loose sanity bound only: inputs have std ~5, so even an identity
        # function would pass. NOTE(review): tighten to an RMS-close-to-1
        # check once RMSNorm's weight initialization is confirmed.
        self.assertTrue(mx.all(mx.abs(out) < 100).item())


class TestRoPE(_SeededTestCase):
    """Test rotary positional embeddings."""

    def test_cache_shape(self):
        """RoPE cache has correct shape."""
        cos, sin = build_rope_cache(seq_len=100, head_dim=64)
        self.assertEqual(cos.shape, (100, 64))
        self.assertEqual(sin.shape, (100, 64))

    def test_application_shape(self):
        """RoPE application preserves shape."""
        cos, sin = build_rope_cache(seq_len=10, head_dim=64)
        x = mx.random.normal((1, 2, 8, 64))  # [bsz, seq, heads, dim]
        out = apply_rotary_emb(x, cos[:10], sin[:10])
        self.assertEqual(out.shape, x.shape)


class TestCausalMask(_SeededTestCase):
    """Test causal mask creation."""

    def test_mask_shape(self):
        """Mask has correct shape."""
        mask = create_causal_mask(seq_len=16)
        self.assertEqual(mask.shape, (1, 1, 16, 16))

    def test_mask_values(self):
        """Lower triangle is zero; upper triangle has large negative values."""
        mask = create_causal_mask(seq_len=4)
        # Lower triangle (row 3, col 0) should be exactly 0.
        self.assertEqual(mask[0, 0, 3, 0].item(), 0.0)
        # Upper triangle (row 0, col 3) should be very negative (or -inf).
        self.assertLess(mask[0, 0, 0, 3].item(), -1e8)


class TestDFlashAttention(_SeededTestCase):
    """Test DFlash attention with KV injection."""

    def test_forward_shape(self):
        """Attention output preserves shape."""
        attn = DFlashAttention(
            hidden_size=256,
            num_heads=4,
            num_kv_heads=2,
            head_dim=64,
        )
        hidden = mx.random.normal((1, 8, 256))  # draft tokens
        target_h = mx.random.normal((1, 16, 256))  # target context
        out = attn(hidden, target_h)
        self.assertEqual(out.shape, (1, 8, 256))

    def test_kv_injection(self):
        """Different target hidden states produce different outputs."""
        attn = DFlashAttention(
            hidden_size=64,
            num_heads=2,
            num_kv_heads=1,
            head_dim=32,
        )
        hidden = mx.random.normal((1, 4, 64))
        target_h1 = mx.random.normal((1, 8, 64))
        target_h2 = mx.random.normal((1, 8, 64)) * 2.0
        out1 = attn(hidden, target_h1)
        out2 = attn(hidden, target_h2)
        # The target context must influence the output; .item() turns the
        # 0-d MLX result into a Python bool explicitly.
        self.assertFalse(mx.allclose(out1, out2).item())


class TestDFlashDraftModel(_SeededTestCase):
    """Test complete draft model."""

    def test_model_creation(self):
        """Model can be instantiated."""
        model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=256,
            num_layers=3,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=704,
            block_size=8,
        )
        self.assertEqual(model.num_layers, 3)
        self.assertEqual(model.hidden_size, 256)

    def test_forward_pass(self):
        """Model forward pass works."""
        model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=2,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=352,
            block_size=8,
        )
        # Create noise embedding (simulating embedded tokens)
        noise_embed = mx.random.normal((1, 8, 128))
        target_hidden = mx.random.normal((1, 16, 128))
        out = model(noise_embed, target_hidden)
        self.assertEqual(out.shape, (1, 8, 128))

    def test_target_layer_ids(self):
        """Target layer IDs are computed correctly."""
        model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=5,
            num_target_layers=32,
        )
        ids = model.target_layer_ids
        self.assertEqual(len(ids), 5)
        # Should span from early to late layers
        self.assertGreater(ids[-1], ids[0])

    def test_extract_features(self):
        """Feature extraction works with multi-layer hidden states."""
        model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=3,
            num_target_layers=10,
        )
        # Create fake hidden states from 10 layers
        hidden_states = [mx.random.normal((1, 5, 128)) for _ in range(10)]
        features = model.extract_context_features(hidden_states)
        # Features should have same batch and seq dims, hidden_size dim
        self.assertEqual(features.shape[0], 1)
        self.assertEqual(features.shape[1], 5)
        # Hidden size is projected back to hidden_size
        self.assertEqual(features.shape[2], 128)

    def test_get_logits(self):
        """Logits have correct vocab size."""
        vocab_size = 500
        model = DFlashDraftModel(
            vocab_size=vocab_size,
            hidden_size=128,
            num_layers=2,
        )
        hidden = mx.random.normal((1, 8, 128))
        logits = model.get_logits(hidden)
        self.assertEqual(logits.shape, (1, 8, vocab_size))


class TestDFlashDenoiser(_SeededTestCase):
    """Test denoising block generation."""

    def test_denoise_shape(self):
        """Denoised block has correct shape."""
        model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=2,
            block_size=8,
        )
        denoiser = DFlashDenoiser(model)
        draft_tokens = mx.array([[100, 200, 300, 400, 500, 600, 700, 800]])
        target_hidden = mx.random.normal((1, 16, 128))
        position_ids = mx.arange(8)
        out = denoiser.denoise_block(draft_tokens, target_hidden, position_ids)
        self.assertEqual(out.shape, (1, 8))

    def test_denoise_greedy(self):
        """Greedy denoising is deterministic."""
        model = DFlashDraftModel(
            vocab_size=100,
            hidden_size=64,
            num_layers=1,
            block_size=4,
        )
        denoiser = DFlashDenoiser(model)
        draft_tokens = mx.array([[0, 1, 2, 3]])
        target_hidden = mx.random.normal((1, 8, 64))
        position_ids = mx.arange(4)
        out1 = denoiser.denoise_block(
            draft_tokens, target_hidden, position_ids, temperature=0.0
        )
        out2 = denoiser.denoise_block(
            draft_tokens, target_hidden, position_ids, temperature=0.0
        )
        mx.eval(out1, out2)
        self.assertTrue(mx.array_equal(out1, out2).item())


class TestModelSerialization(_SeededTestCase):
    """Test saving/loading model parameters."""

    def test_parameter_dict(self):
        """Model can export parameters under flat dotted names."""
        model = DFlashDraftModel(
            vocab_size=100,
            hidden_size=64,
            num_layers=2,
        )
        # Assuming DFlashDraftModel is an mlx.nn.Module, parameters() returns
        # a NESTED dict ({"lm_head": {"weight": ...}, ...}), so a dotted key
        # like "lm_head.weight" only exists after flattening the tree.
        params = dict(tree_flatten(model.parameters()))
        self.assertGreater(len(params), 0)
        self.assertIn("lm_head.weight", params)


if __name__ == "__main__":
    unittest.main()