"""
Tests for the architecture-agnostic adapter system.

Validates adapter creation, model type detection, attribute resolution,
and KV cache management for different MLX model families.
"""
import unittest
from unittest.mock import MagicMock, PropertyMock
import mlx.core as mx
class TestArchitectureDetection(unittest.TestCase):
"""Test model type detection from config and structure."""
def test_detect_qwen3(self):
"""Detect Qwen3 from config."""
from dflash_mlx.adapters import detect_model_architecture
mock_model = MagicMock()
mock_model.config = MagicMock()
mock_model.config.model_type = "qwen3"
arch = detect_model_architecture(mock_model)
self.assertEqual(arch, "qwen3")
def test_detect_qwen35(self):
"""Detect Qwen3.5 from config."""
from dflash_mlx.adapters import detect_model_architecture
mock_model = MagicMock()
mock_model.config = MagicMock()
mock_model.config.model_type = "qwen3_5"
arch = detect_model_architecture(mock_model)
self.assertEqual(arch, "qwen3_5")
def test_detect_llama(self):
"""Detect LLaMA from config."""
from dflash_mlx.adapters import detect_model_architecture
mock_model = MagicMock()
mock_model.config = MagicMock()
mock_model.config.model_type = "llama"
arch = detect_model_architecture(mock_model)
self.assertEqual(arch, "llama")
def test_structural_qwen35(self):
"""Detect Qwen3.5 from structure (language_model attr)."""
from dflash_mlx.adapters import detect_model_architecture
mock_model = MagicMock()
mock_model.language_model = MagicMock() # No config, but has language_model
arch = detect_model_architecture(mock_model)
self.assertEqual(arch, "qwen3_5")
def test_structural_generic(self):
"""Detect generic transformer from layers."""
from dflash_mlx.adapters import detect_model_architecture
mock_model = MagicMock()
mock_model.model = MagicMock()
mock_model.model.layers = []
arch = detect_model_architecture(mock_model)
self.assertEqual(arch, "generic_transformer")
def test_unknown(self):
"""Unknown model type falls back to generic."""
from dflash_mlx.adapters import detect_model_architecture
mock_model = MagicMock()
arch = detect_model_architecture(mock_model)
self.assertEqual(arch, "generic")
class TestAdapterRegistry(unittest.TestCase):
"""Test adapter class lookup."""
def test_adapter_lookup(self):
"""Look up adapters by model type."""
from dflash_mlx.adapters import adapter_for_model_type, Qwen3Adapter, LlamaAdapter
self.assertEqual(adapter_for_model_type("qwen3"), Qwen3Adapter)
self.assertEqual(adapter_for_model_type("llama"), LlamaAdapter)
self.assertIsNone(adapter_for_model_type("unknown_model"))
def test_adapter_aliases(self):
"""Test aliases for model type variations."""
from dflash_mlx.adapters import adapter_for_model_type, Qwen35Adapter, Qwen3Adapter
self.assertEqual(adapter_for_model_type("qwen3.5"), Qwen35Adapter)
self.assertEqual(adapter_for_model_type("qwen3_5_instruct"), Qwen35Adapter)
self.assertEqual(adapter_for_model_type("qwen3-instruct"), Qwen3Adapter)
class TestAdapterAttributes(unittest.TestCase):
"""Test adapter attribute resolution."""
def test_embed_resolution(self):
"""Test embedding layer resolution."""
from dflash_mlx.adapters import MLXTargetAdapter
mock_model = MagicMock()
mock_embed = MagicMock()
mock_model.embed_tokens = mock_embed
adapter = MLXTargetAdapter(mock_model)
# embed_tokens at top level should be found
self.assertIsNotNone(adapter._embed)
def test_layer_resolution(self):
"""Test layer list resolution."""
from dflash_mlx.adapters import MLXTargetAdapter
mock_model = MagicMock()
mock_layers = [MagicMock() for _ in range(3)]
mock_model.layers = mock_layers
adapter = MLXTargetAdapter(mock_model)
self.assertEqual(len(adapter._layers), 3)
class TestCacheManagement(unittest.TestCase):
"""Test KV cache creation and rewind."""
def test_make_kv_cache(self):
"""Create KV cache for standard transformer."""
from dflash_mlx.adapters import MLXTargetAdapter
from mlx_lm.models import cache as cache_lib
mock_model = MagicMock()
mock_layers = [MagicMock() for _ in range(4)]
mock_model.layers = mock_layers
adapter = MLXTargetAdapter(mock_model)
adapter.family = "qwen3"
adapter.arch_info = {"cache_type": "KVCache"}
cache = adapter.make_cache()
self.assertEqual(len(cache), 4)
self.assertIsInstance(cache[0], cache_lib.KVCache)
def test_rewind_cache(self):
"""Trim KV cache to accepted length."""
from dflash_mlx.adapters import MLXTargetAdapter
from mlx_lm.models import cache as cache_lib
mock_model = MagicMock()
adapter = MLXTargetAdapter(mock_model)
# Mock cache
mock_cache = MagicMock(spec=cache_lib.KVCache)
mock_cache.offset = 100
adapter.rewind_kv_caches([mock_cache], 50)
mock_cache.trim.assert_called_once_with(50)
if __name__ == "__main__":
    # Allow running this module directly as well as through a test runner.
    unittest.main()