| """ |
| Tests for the architecture-agnostic adapter system. |
| |
| Validates adapter creation, model type detection, and forward pass |
| for different MLX model families. |
| """ |
|
|
| import unittest |
| from unittest.mock import MagicMock, PropertyMock |
| import mlx.core as mx |
|
|
|
|
class TestArchitectureDetection(unittest.TestCase):
    """Checks on ``detect_model_architecture`` using mocked models.

    NOTE(review): ``MagicMock`` creates any attribute on first access, so
    attribute-presence probes succeed for every name on these mocks. Confirm
    that the structural and "unknown" cases really isolate the structure
    they describe.
    """

    @staticmethod
    def _detect(model):
        """Import the detector lazily and return its verdict for *model*."""
        from dflash_mlx.adapters import detect_model_architecture

        return detect_model_architecture(model)

    @staticmethod
    def _model_with_type(model_type):
        """Build a mock model whose ``config.model_type`` is *model_type*."""
        model = MagicMock()
        model.config = MagicMock(model_type=model_type)
        return model

    def test_detect_qwen3(self):
        """Config model_type 'qwen3' is reported as 'qwen3'."""
        detected = self._detect(self._model_with_type("qwen3"))
        self.assertEqual(detected, "qwen3")

    def test_detect_qwen35(self):
        """Config model_type 'qwen3_5' is reported as 'qwen3_5'."""
        detected = self._detect(self._model_with_type("qwen3_5"))
        self.assertEqual(detected, "qwen3_5")

    def test_detect_llama(self):
        """Config model_type 'llama' is reported as 'llama'."""
        detected = self._detect(self._model_with_type("llama"))
        self.assertEqual(detected, "llama")

    def test_structural_qwen35(self):
        """A model given an explicit ``language_model`` attr is 'qwen3_5'."""
        model = MagicMock()
        model.language_model = MagicMock()

        self.assertEqual(self._detect(model), "qwen3_5")

    def test_structural_generic(self):
        """A model given ``model.layers`` is 'generic_transformer'."""
        model = MagicMock()
        model.model = MagicMock(layers=[])

        self.assertEqual(self._detect(model), "generic_transformer")

    def test_unknown(self):
        """A bare mock with nothing configured is reported as 'generic'."""
        self.assertEqual(self._detect(MagicMock()), "generic")
|
|
|
|
class TestAdapterRegistry(unittest.TestCase):
    """Checks on ``adapter_for_model_type`` lookups."""

    def test_adapter_lookup(self):
        """Known type names yield their adapter class; an unknown one, None."""
        from dflash_mlx.adapters import (
            LlamaAdapter,
            Qwen3Adapter,
            adapter_for_model_type,
        )

        expected = (
            ("qwen3", Qwen3Adapter),
            ("llama", LlamaAdapter),
        )
        for model_type, adapter_cls in expected:
            self.assertEqual(adapter_for_model_type(model_type), adapter_cls)

        self.assertIsNone(adapter_for_model_type("unknown_model"))

    def test_adapter_aliases(self):
        """Spelling variants of a model type map to the same adapter class."""
        from dflash_mlx.adapters import (
            Qwen35Adapter,
            Qwen3Adapter,
            adapter_for_model_type,
        )

        aliases = (
            ("qwen3.5", Qwen35Adapter),
            ("qwen3_5_instruct", Qwen35Adapter),
            ("qwen3-instruct", Qwen3Adapter),
        )
        for alias, adapter_cls in aliases:
            self.assertEqual(adapter_for_model_type(alias), adapter_cls)
|
|
|
|
class TestAdapterAttributes(unittest.TestCase):
    """Checks on the private attributes ``MLXTargetAdapter`` exposes."""

    @staticmethod
    def _wrap(model):
        """Import ``MLXTargetAdapter`` lazily and wrap *model* in it."""
        from dflash_mlx.adapters import MLXTargetAdapter

        return MLXTargetAdapter(model)

    def test_embed_resolution(self):
        """An adapter over a model with ``embed_tokens`` has a non-None ``_embed``."""
        model = MagicMock()
        model.embed_tokens = MagicMock()

        self.assertIsNotNone(self._wrap(model)._embed)

    def test_layer_resolution(self):
        """An adapter over a model with three layers reports three ``_layers``."""
        model = MagicMock()
        model.layers = [MagicMock() for _ in range(3)]

        self.assertEqual(len(self._wrap(model)._layers), 3)
|
|
|
|
class TestCacheManagement(unittest.TestCase):
    """Checks on adapter cache construction and rewinding."""

    def test_make_kv_cache(self):
        """``make_cache`` yields one entry per layer, the first a ``KVCache``."""
        from dflash_mlx.adapters import MLXTargetAdapter
        from mlx_lm.models import cache as mlx_cache

        layer_count = 4
        model = MagicMock()
        model.layers = [MagicMock() for _ in range(layer_count)]

        adapter = MLXTargetAdapter(model)
        # Pin the family and cache type explicitly rather than relying on
        # whatever the adapter derived from the mock.
        adapter.family = "qwen3"
        adapter.arch_info = {"cache_type": "KVCache"}

        caches = adapter.make_cache()

        self.assertEqual(len(caches), layer_count)
        self.assertIsInstance(caches[0], mlx_cache.KVCache)

    def test_rewind_cache(self):
        """Rewinding a cache at offset 100 with argument 50 calls ``trim(50)`` once.

        NOTE(review): here ``offset - 50`` and ``50`` are the same number, so
        this assertion cannot tell which of the two the adapter passes to
        ``trim`` -- consider asymmetric values once the contract is confirmed.
        """
        from dflash_mlx.adapters import MLXTargetAdapter
        from mlx_lm.models import cache as mlx_cache

        adapter = MLXTargetAdapter(MagicMock())

        kv_cache = MagicMock(spec=mlx_cache.KVCache)
        kv_cache.offset = 100

        adapter.rewind_kv_caches([kv_cache], 50)

        kv_cache.trim.assert_called_once_with(50)
|
|
|
|
if __name__ == "__main__":
    # Run every TestCase defined in this module when the file is executed
    # directly instead of through a test runner.
    unittest.main()
|
|