# dflash-mlx-universal / tests/test_model.py
# (repository upload metadata: tritesh, commit d00960b, verified)
"""
Tests for DFlash draft model and denoiser.
Validates model creation, forward pass, and KV injection mechanism.
"""
import unittest
import mlx.core as mx
from dflash_mlx.model import (
RMSNorm,
DFlashAttention,
DFlashMLP,
DFlashDecoderLayer,
DFlashDraftModel,
DFlashDenoiser,
build_rope_cache,
apply_rotary_emb,
create_causal_mask,
)
class TestRMSNorm(unittest.TestCase):
    """Test RMSNorm implementation."""

    def test_shape_preservation(self):
        """RMSNorm preserves tensor shape."""
        norm = RMSNorm(dims=128)
        x = mx.random.normal((2, 10, 128))
        out = norm(x)
        self.assertEqual(out.shape, x.shape)

    def test_unit_variance(self):
        """RMSNorm normalizes each feature vector to ~unit RMS."""
        norm = RMSNorm(dims=64, eps=1e-6)
        x = mx.random.normal((1, 5, 64)) * 5.0
        out = norm(x)
        # The old check (|out| < 100) passes for almost any function; actually
        # verify the normalization. With the standard ones-initialized scale
        # weight the per-vector root-mean-square of the output should be ~1.
        # NOTE(review): assumes RMSNorm's weight defaults to ones — confirm.
        rms = mx.sqrt(mx.mean(out * out, axis=-1))
        self.assertTrue(mx.allclose(rms, mx.ones_like(rms), atol=1e-3).item())
        # Keep the original loose sanity bound as well.
        self.assertTrue(mx.all(mx.abs(out) < 100).item())
class TestRoPE(unittest.TestCase):
    """Test rotary positional embeddings."""

    def test_cache_shape(self):
        """The precomputed cos/sin tables are both [seq_len, head_dim]."""
        cos_table, sin_table = build_rope_cache(seq_len=100, head_dim=64)
        for table in (cos_table, sin_table):
            self.assertEqual(table.shape, (100, 64))

    def test_application_shape(self):
        """Applying RoPE leaves the input tensor's shape untouched."""
        cache_len = 10
        cos_table, sin_table = build_rope_cache(seq_len=cache_len, head_dim=64)
        q = mx.random.normal((1, 2, 8, 64))  # [bsz, seq, heads, dim]
        rotated = apply_rotary_emb(q, cos_table[:cache_len], sin_table[:cache_len])
        self.assertEqual(rotated.shape, q.shape)
class TestCausalMask(unittest.TestCase):
    """Test causal mask creation."""

    def test_mask_shape(self):
        """Mask is broadcast-ready: [1, 1, seq, seq]."""
        self.assertEqual(create_causal_mask(seq_len=16).shape, (1, 1, 16, 16))

    def test_mask_values(self):
        """Past positions pass through; future positions are heavily masked."""
        mask = create_causal_mask(seq_len=4)
        # Row 3 attending back to column 0 (lower triangle) is unmasked.
        self.assertTrue(mask[0, 0, 3, 0] == 0.0)
        # Row 0 attending forward to column 3 (upper triangle) is very negative.
        self.assertTrue(mask[0, 0, 0, 3] < -1e8)
class TestDFlashAttention(unittest.TestCase):
    """Test DFlash attention with KV injection."""

    @staticmethod
    def _build(hidden_size, num_heads, num_kv_heads, head_dim):
        # One-line attention construction for the tests below.
        return DFlashAttention(
            hidden_size=hidden_size,
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
        )

    def test_forward_shape(self):
        """Attention output keeps the draft-token sequence shape."""
        attn = self._build(256, 4, 2, 64)
        draft = mx.random.normal((1, 8, 256))     # draft tokens
        context = mx.random.normal((1, 16, 256))  # target context
        self.assertEqual(attn(draft, context).shape, (1, 8, 256))

    def test_kv_injection(self):
        """Swapping the injected target hidden states changes the output."""
        attn = self._build(64, 2, 1, 32)
        draft = mx.random.normal((1, 4, 64))
        context_a = mx.random.normal((1, 8, 64))
        context_b = mx.random.normal((1, 8, 64)) * 2.0
        out_a = attn(draft, context_a)
        out_b = attn(draft, context_b)
        self.assertFalse(mx.allclose(out_a, out_b))
class TestDFlashDraftModel(unittest.TestCase):
    """Test complete draft model."""

    def test_model_creation(self):
        """The model instantiates and records its configured dimensions."""
        draft = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=256,
            num_layers=3,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=704,
            block_size=8,
        )
        self.assertEqual(draft.num_layers, 3)
        self.assertEqual(draft.hidden_size, 256)

    def test_forward_pass(self):
        """A forward pass maps noise embeddings to same-shape hidden states."""
        draft = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=2,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=352,
            block_size=8,
        )
        # Noise embedding simulating embedded tokens, plus target context.
        noise = mx.random.normal((1, 8, 128))
        context = mx.random.normal((1, 16, 128))
        result = draft(noise, context)
        self.assertEqual(result.shape, (1, 8, 128))

    def test_target_layer_ids(self):
        """One target layer id per draft layer, spanning early to late."""
        draft = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=5,
            num_target_layers=32,
        )
        layer_ids = draft.target_layer_ids
        self.assertEqual(len(layer_ids), 5)
        # The selection should reach from early layers toward late ones.
        self.assertGreater(layer_ids[-1], layer_ids[0])

    def test_extract_features(self):
        """Feature extraction keeps batch/seq dims and projects to hidden."""
        draft = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=3,
            num_target_layers=10,
        )
        # Fake hidden states, one per simulated target layer.
        layer_states = [mx.random.normal((1, 5, 128)) for _ in range(10)]
        features = draft.extract_context_features(layer_states)
        # Batch and sequence dims survive; features land back in hidden_size.
        self.assertEqual(tuple(features.shape[:3]), (1, 5, 128))

    def test_get_logits(self):
        """Logit projection produces [batch, seq, vocab_size]."""
        vocab = 500
        draft = DFlashDraftModel(
            vocab_size=vocab,
            hidden_size=128,
            num_layers=2,
        )
        hidden = mx.random.normal((1, 8, 128))
        self.assertEqual(draft.get_logits(hidden).shape, (1, 8, vocab))
class TestDFlashDenoiser(unittest.TestCase):
    """Test denoising block generation."""

    def test_denoise_shape(self):
        """Denoising yields one token id per block position."""
        draft = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=2,
            block_size=8,
        )
        denoiser = DFlashDenoiser(draft)
        block = mx.array([[100, 200, 300, 400, 500, 600, 700, 800]])
        context = mx.random.normal((1, 16, 128))
        positions = mx.arange(8)
        denoised = denoiser.denoise_block(block, context, positions)
        self.assertEqual(denoised.shape, (1, 8))

    def test_denoise_greedy(self):
        """Two passes at temperature 0 agree token-for-token."""
        draft = DFlashDraftModel(
            vocab_size=100,
            hidden_size=64,
            num_layers=1,
            block_size=4,
        )
        denoiser = DFlashDenoiser(draft)
        block = mx.array([[0, 1, 2, 3]])
        context = mx.random.normal((1, 8, 64))
        positions = mx.arange(4)
        first = denoiser.denoise_block(block, context, positions, temperature=0.0)
        second = denoiser.denoise_block(block, context, positions, temperature=0.0)
        mx.eval(first, second)
        self.assertTrue(mx.array_equal(first, second))
class TestModelSerialization(unittest.TestCase):
    """Test saving/loading model parameters."""

    def test_parameter_dict(self):
        """Model exports a non-empty, flattenable parameter tree."""
        # Local import: mlx.utils is needed only by this test.
        from mlx.utils import tree_flatten

        model = DFlashDraftModel(
            vocab_size=100,
            hidden_size=64,
            num_layers=2,
        )
        # MLX's nn.Module.parameters() returns a *nested* dict keyed by
        # submodule ({"lm_head": {"weight": ...}}), so the dotted key
        # "lm_head.weight" only exists after flattening — the previous
        # dict(model.parameters()) lookup could never find it.
        params = dict(tree_flatten(model.parameters()))
        self.assertTrue(len(params) > 0)
        self.assertIn("lm_head.weight", params)
# Allow running this file directly: `python tests/test_model.py`.
if __name__ == "__main__":
    unittest.main()