# dflash-mlx-universal — tests/test_model.py
"""
Tests for DFlash draft model and denoiser.
Validates model creation, forward pass, and KV injection mechanism.
"""
import unittest
import mlx.core as mx
from dflash_mlx.model import (
RMSNorm,
DFlashAttention,
DFlashMLP,
DFlashDecoderLayer,
DFlashDraftModel,
DFlashDenoiser,
build_rope_cache,
apply_rotary_emb,
create_causal_mask,
)
class TestRMSNorm(unittest.TestCase):
    """Test RMSNorm implementation."""

    def test_shape_preservation(self):
        """RMSNorm preserves tensor shape."""
        layer = RMSNorm(dims=128)
        inp = mx.random.normal((2, 10, 128))
        result = layer(inp)
        self.assertEqual(result.shape, inp.shape)

    def test_unit_variance(self):
        """RMSNorm roughly normalizes to unit variance."""
        layer = RMSNorm(dims=64, eps=1e-6)
        scaled = mx.random.normal((1, 5, 64)) * 5.0
        result = layer(scaled)
        # After normalization, values should sit in a sane magnitude band.
        self.assertTrue(mx.all(mx.abs(result) < 100).item())
class TestRoPE(unittest.TestCase):
    """Test rotary positional embeddings."""

    def test_cache_shape(self):
        """RoPE cache has correct shape."""
        cos_table, sin_table = build_rope_cache(seq_len=100, head_dim=64)
        self.assertEqual(cos_table.shape, (100, 64))
        self.assertEqual(sin_table.shape, (100, 64))

    def test_application_shape(self):
        """RoPE application preserves shape."""
        cos_table, sin_table = build_rope_cache(seq_len=10, head_dim=64)
        q = mx.random.normal((1, 2, 8, 64))  # [bsz, seq, heads, dim]
        rotated = apply_rotary_emb(q, cos_table[:10], sin_table[:10])
        self.assertEqual(rotated.shape, q.shape)
class TestCausalMask(unittest.TestCase):
    """Test causal mask creation."""

    def test_mask_shape(self):
        """Mask has correct shape."""
        self.assertEqual(create_causal_mask(seq_len=16).shape, (1, 1, 16, 16))

    def test_mask_values(self):
        """Upper triangle has large negative values."""
        mask = create_causal_mask(seq_len=4)
        # At/below the diagonal: no bias, the position is attendable.
        self.assertTrue(mask[0, 0, 3, 0] == 0.0)
        # Above the diagonal: effectively -inf, the position is hidden.
        self.assertTrue(mask[0, 0, 0, 3] < -1e8)
class TestDFlashAttention(unittest.TestCase):
    """Test DFlash attention with KV injection."""

    def test_forward_shape(self):
        """Attention output preserves shape."""
        layer = DFlashAttention(
            hidden_size=256,
            num_heads=4,
            num_kv_heads=2,
            head_dim=64,
        )
        draft = mx.random.normal((1, 8, 256))     # draft tokens
        context = mx.random.normal((1, 16, 256))  # target context
        result = layer(draft, context)
        self.assertEqual(result.shape, (1, 8, 256))

    def test_kv_injection(self):
        """Different target hidden states produce different outputs."""
        layer = DFlashAttention(
            hidden_size=64,
            num_heads=2,
            num_kv_heads=1,
            head_dim=32,
        )
        draft = mx.random.normal((1, 4, 64))
        ctx_a = mx.random.normal((1, 8, 64))
        ctx_b = mx.random.normal((1, 8, 64)) * 2.0
        # The injected context must actually influence the attention output.
        out_a = layer(draft, ctx_a)
        out_b = layer(draft, ctx_b)
        self.assertFalse(mx.allclose(out_a, out_b))
class TestDFlashDraftModel(unittest.TestCase):
    """Test complete draft model."""

    def test_model_creation(self):
        """Model can be instantiated."""
        draft_model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=256,
            num_layers=3,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=704,
            block_size=8,
        )
        self.assertEqual(draft_model.num_layers, 3)
        self.assertEqual(draft_model.hidden_size, 256)

    def test_forward_pass(self):
        """Model forward pass works."""
        draft_model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=2,
            num_heads=4,
            num_kv_heads=2,
            intermediate_size=352,
            block_size=8,
        )
        # Noise embedding stands in for embedded draft tokens.
        noise = mx.random.normal((1, 8, 128))
        context = mx.random.normal((1, 16, 128))
        result = draft_model(noise, context)
        self.assertEqual(result.shape, (1, 8, 128))

    def test_target_layer_ids(self):
        """Target layer IDs are computed correctly."""
        draft_model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=5,
            num_target_layers=32,
        )
        layer_ids = draft_model.target_layer_ids
        self.assertEqual(len(layer_ids), 5)
        # IDs should span from early to late target layers.
        self.assertGreater(layer_ids[-1], layer_ids[0])

    def test_extract_features(self):
        """Feature extraction works with multi-layer hidden states."""
        draft_model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=3,
            num_target_layers=10,
        )
        # Fake per-layer hidden states from a 10-layer target model.
        layer_outputs = [mx.random.normal((1, 5, 128)) for _ in range(10)]
        feats = draft_model.extract_context_features(layer_outputs)
        # Batch and sequence dims are preserved; last dim is hidden_size.
        self.assertEqual(feats.shape[0], 1)
        self.assertEqual(feats.shape[1], 5)
        self.assertEqual(feats.shape[2], 128)

    def test_get_logits(self):
        """Logits have correct vocab size."""
        n_vocab = 500
        draft_model = DFlashDraftModel(
            vocab_size=n_vocab,
            hidden_size=128,
            num_layers=2,
        )
        states = mx.random.normal((1, 8, 128))
        logits = draft_model.get_logits(states)
        self.assertEqual(logits.shape, (1, 8, n_vocab))
class TestDFlashDenoiser(unittest.TestCase):
    """Test denoising block generation."""

    def test_denoise_shape(self):
        """Denoised block has correct shape."""
        draft_model = DFlashDraftModel(
            vocab_size=1000,
            hidden_size=128,
            num_layers=2,
            block_size=8,
        )
        sampler = DFlashDenoiser(draft_model)
        tokens = mx.array([[100, 200, 300, 400, 500, 600, 700, 800]])
        context = mx.random.normal((1, 16, 128))
        positions = mx.arange(8)
        result = sampler.denoise_block(tokens, context, positions)
        self.assertEqual(result.shape, (1, 8))

    def test_denoise_greedy(self):
        """Greedy denoising is deterministic."""
        draft_model = DFlashDraftModel(
            vocab_size=100,
            hidden_size=64,
            num_layers=1,
            block_size=4,
        )
        sampler = DFlashDenoiser(draft_model)
        tokens = mx.array([[0, 1, 2, 3]])
        context = mx.random.normal((1, 8, 64))
        positions = mx.arange(4)
        # Temperature 0 means argmax: two runs must agree exactly.
        first = sampler.denoise_block(tokens, context, positions, temperature=0.0)
        second = sampler.denoise_block(tokens, context, positions, temperature=0.0)
        mx.eval(first, second)
        self.assertTrue(mx.array_equal(first, second))
class TestModelSerialization(unittest.TestCase):
    """Test saving/loading model parameters."""

    def test_parameter_dict(self):
        """Model exports a flat parameter dict with dotted key paths."""
        from mlx.utils import tree_flatten

        model = DFlashDraftModel(
            vocab_size=100,
            hidden_size=64,
            num_layers=2,
        )
        # mlx nn.Module.parameters() returns a *nested* dict, so
        # dict(model.parameters()) only exposes top-level keys such as
        # "lm_head" — a dotted path like "lm_head.weight" would never be
        # present. Flatten the tree first so the membership check is
        # actually exercised against full parameter paths.
        params = dict(tree_flatten(model.parameters()))
        self.assertTrue(len(params) > 0)
        self.assertIn("lm_head.weight", params)
# Allow running this file directly: executes every TestCase defined above.
if __name__ == "__main__":
    unittest.main()