tritesh commited on
Commit
9579572
·
verified ·
1 Parent(s): 81dfca1

Upload tests/test_adapters.py

Browse files
Files changed (1) hide show
  1. tests/test_adapters.py +165 -0
tests/test_adapters.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for the architecture-agnostic adapter system.
3
+
4
+ Validates adapter creation, model type detection, and forward pass
5
+ for different MLX model families.
6
+ """
7
+
8
+ import unittest
9
+ from unittest.mock import MagicMock, PropertyMock
10
+ import mlx.core as mx
11
+
12
+
13
+ class TestArchitectureDetection(unittest.TestCase):
14
+ """Test model type detection from config and structure."""
15
+
16
+ def test_detect_qwen3(self):
17
+ """Detect Qwen3 from config."""
18
+ from dflash_mlx.adapters import detect_model_architecture
19
+
20
+ mock_model = MagicMock()
21
+ mock_model.config = MagicMock()
22
+ mock_model.config.model_type = "qwen3"
23
+
24
+ arch = detect_model_architecture(mock_model)
25
+ self.assertEqual(arch, "qwen3")
26
+
27
+ def test_detect_qwen35(self):
28
+ """Detect Qwen3.5 from config."""
29
+ from dflash_mlx.adapters import detect_model_architecture
30
+
31
+ mock_model = MagicMock()
32
+ mock_model.config = MagicMock()
33
+ mock_model.config.model_type = "qwen3_5"
34
+
35
+ arch = detect_model_architecture(mock_model)
36
+ self.assertEqual(arch, "qwen3_5")
37
+
38
+ def test_detect_llama(self):
39
+ """Detect LLaMA from config."""
40
+ from dflash_mlx.adapters import detect_model_architecture
41
+
42
+ mock_model = MagicMock()
43
+ mock_model.config = MagicMock()
44
+ mock_model.config.model_type = "llama"
45
+
46
+ arch = detect_model_architecture(mock_model)
47
+ self.assertEqual(arch, "llama")
48
+
49
+ def test_structural_qwen35(self):
50
+ """Detect Qwen3.5 from structure (language_model attr)."""
51
+ from dflash_mlx.adapters import detect_model_architecture
52
+
53
+ mock_model = MagicMock()
54
+ mock_model.language_model = MagicMock() # No config, but has language_model
55
+
56
+ arch = detect_model_architecture(mock_model)
57
+ self.assertEqual(arch, "qwen3_5")
58
+
59
+ def test_structural_generic(self):
60
+ """Detect generic transformer from layers."""
61
+ from dflash_mlx.adapters import detect_model_architecture
62
+
63
+ mock_model = MagicMock()
64
+ mock_model.model = MagicMock()
65
+ mock_model.model.layers = []
66
+
67
+ arch = detect_model_architecture(mock_model)
68
+ self.assertEqual(arch, "generic_transformer")
69
+
70
+ def test_unknown(self):
71
+ """Unknown model type falls back to generic."""
72
+ from dflash_mlx.adapters import detect_model_architecture
73
+
74
+ mock_model = MagicMock()
75
+
76
+ arch = detect_model_architecture(mock_model)
77
+ self.assertEqual(arch, "generic")
78
+
79
+
80
+ class TestAdapterRegistry(unittest.TestCase):
81
+ """Test adapter class lookup."""
82
+
83
+ def test_adapter_lookup(self):
84
+ """Look up adapters by model type."""
85
+ from dflash_mlx.adapters import adapter_for_model_type, Qwen3Adapter, LlamaAdapter
86
+
87
+ self.assertEqual(adapter_for_model_type("qwen3"), Qwen3Adapter)
88
+ self.assertEqual(adapter_for_model_type("llama"), LlamaAdapter)
89
+ self.assertIsNone(adapter_for_model_type("unknown_model"))
90
+
91
+ def test_adapter_aliases(self):
92
+ """Test aliases for model type variations."""
93
+ from dflash_mlx.adapters import adapter_for_model_type, Qwen35Adapter, Qwen3Adapter
94
+
95
+ self.assertEqual(adapter_for_model_type("qwen3.5"), Qwen35Adapter)
96
+ self.assertEqual(adapter_for_model_type("qwen3_5_instruct"), Qwen35Adapter)
97
+ self.assertEqual(adapter_for_model_type("qwen3-instruct"), Qwen3Adapter)
98
+
99
+
100
+ class TestAdapterAttributes(unittest.TestCase):
101
+ """Test adapter attribute resolution."""
102
+
103
+ def test_embed_resolution(self):
104
+ """Test embedding layer resolution."""
105
+ from dflash_mlx.adapters import MLXTargetAdapter
106
+
107
+ mock_model = MagicMock()
108
+ mock_embed = MagicMock()
109
+ mock_model.embed_tokens = mock_embed
110
+
111
+ adapter = MLXTargetAdapter(mock_model)
112
+ # embed_tokens at top level should be found
113
+ self.assertIsNotNone(adapter._embed)
114
+
115
+ def test_layer_resolution(self):
116
+ """Test layer list resolution."""
117
+ from dflash_mlx.adapters import MLXTargetAdapter
118
+
119
+ mock_model = MagicMock()
120
+ mock_layers = [MagicMock() for _ in range(3)]
121
+ mock_model.layers = mock_layers
122
+
123
+ adapter = MLXTargetAdapter(mock_model)
124
+ self.assertEqual(len(adapter._layers), 3)
125
+
126
+
127
+ class TestCacheManagement(unittest.TestCase):
128
+ """Test KV cache creation and rewind."""
129
+
130
+ def test_make_kv_cache(self):
131
+ """Create KV cache for standard transformer."""
132
+ from dflash_mlx.adapters import MLXTargetAdapter
133
+ from mlx_lm.models import cache as cache_lib
134
+
135
+ mock_model = MagicMock()
136
+ mock_layers = [MagicMock() for _ in range(4)]
137
+ mock_model.layers = mock_layers
138
+
139
+ adapter = MLXTargetAdapter(mock_model)
140
+ adapter.family = "qwen3"
141
+ adapter.arch_info = {"cache_type": "KVCache"}
142
+
143
+ cache = adapter.make_cache()
144
+ self.assertEqual(len(cache), 4)
145
+ self.assertIsInstance(cache[0], cache_lib.KVCache)
146
+
147
+ def test_rewind_cache(self):
148
+ """Trim KV cache to accepted length."""
149
+ from dflash_mlx.adapters import MLXTargetAdapter
150
+ from mlx_lm.models import cache as cache_lib
151
+
152
+ mock_model = MagicMock()
153
+
154
+ adapter = MLXTargetAdapter(mock_model)
155
+
156
+ # Mock cache
157
+ mock_cache = MagicMock(spec=cache_lib.KVCache)
158
+ mock_cache.offset = 100
159
+
160
+ adapter.rewind_kv_caches([mock_cache], 50)
161
+ mock_cache.trim.assert_called_once_with(50)
162
+
163
+
164
+ if __name__ == "__main__":
165
+ unittest.main()