TinmanLabSL commited on
Commit
771bbf0
·
verified ·
1 Parent(s): d83d26e

Add production readiness test suite — validates all 6 capabilities end-to-end

Browse files
Files changed (1) hide show
  1. test_production_readiness.py +212 -0
test_production_readiness.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tinman-SmolOmni-MLA: Production Readiness Test Suite
3
+ ======================================================
4
+
5
+ Tests everything a new user needs to verify:
6
+ 1. pip install works
7
+ 2. Load checkpoint from HuggingFace Hub
8
+ 3. Text understanding inference
9
+ 4. Image generation pipeline
10
+ 5. Moonshine audio integration
11
+ 6. KV cache verification
12
+
13
+ Run:
14
+ python test_production_readiness.py
15
+
16
+ Requires:
17
+ pip install git+https://huggingface.co/TinmanLabSL/SmolOmni-MLA-Toolkit
18
+ pip install transformers pillow soundfile librosa
19
+ """
20
+ import torch
21
+ import warnings
22
+ import sys
23
+
24
+ # Global imports
25
+ try:
26
+ import smolomni
27
+ from smolomni import SmolOmni, get_model_config
28
+ from smolomni.config import SmolOmniConfig
29
+ IMPORTS_OK = True
30
+ except Exception as e:
31
+ IMPORTS_OK = False
32
+ IMPORT_ERROR = e
33
+ SmolOmni = None
34
+
35
+ def main():
36
+ print("=" * 60)
37
+ print("Tinman-SmolOmni-MLA: Production Readiness Test Suite")
38
+ print("=" * 60)
39
+ print(f"PyTorch: {torch.__version__}")
40
+ print(f"CUDA available: {torch.cuda.is_available()}")
41
+
42
+ results = {}
43
+
44
+ # Test 1: Imports
45
+ print("\n" + "=" * 60)
46
+ print("TEST 1: Package Import")
47
+ print("=" * 60)
48
+ if IMPORTS_OK:
49
+ print(f" ✅ Package: smolomni v{smolomni.__version__}")
50
+ print(f" ✅ get_model_config: {get_model_config('mla-hybrid-ar-flow-500M').hidden_size} hidden")
51
+ results['imports'] = True
52
+ else:
53
+ print(f" ❌ FAILED: {IMPORT_ERROR}")
54
+ results['imports'] = False
55
+ return results
56
+
57
+ # Test 2: Load 500M checkpoint
58
+ print("\n" + "=" * 60)
59
+ print("TEST 2: Load 500M Checkpoint from Hub")
60
+ print("=" * 60)
61
+ print(" Downloading 1.1GB checkpoint... (may take 2-3 minutes)")
62
+
63
+ try:
64
+ model = SmolOmni.from_hub(
65
+ 'TinmanLabSL/SmolOmni-MLA-500M',
66
+ checkpoint='stage2_final/model.pt',
67
+ config='mla-hybrid-ar-flow-500M',
68
+ device='cpu',
69
+ dtype=torch.float32,
70
+ strict=False,
71
+ )
72
+ n_params = sum(p.numel() for p in model.parameters())
73
+ print(f" ✅ Model loaded: {n_params/1e6:.1f}M parameters")
74
+ print(f" ✅ Config variant: {model.config.model_variant}")
75
+ print(f" ✅ Layers: {model.config.num_hidden_layers}")
76
+ gqa_count = sum(1 for l in model.layers if not l.is_mla)
77
+ mla_count = sum(1 for l in model.layers if l.is_mla)
78
+ print(f" ✅ GQA layers: {gqa_count}, MLA layers: {mla_count}")
79
+ results['load_checkpoint'] = True
80
+ except Exception as e:
81
+ print(f" ❌ FAILED: {e}")
82
+ import traceback
83
+ traceback.print_exc()
84
+ results['load_checkpoint'] = False
85
+ return results
86
+
87
+ # Test 3: Text understanding
88
+ print("\n" + "=" * 60)
89
+ print("TEST 3: Text Understanding Inference")
90
+ print("=" * 60)
91
+ try:
92
+ from transformers import AutoTokenizer
93
+ tokenizer = AutoTokenizer.from_pretrained('HuggingFaceTB/SmolVLM-500M-Instruct')
94
+
95
+ prompt = "The capital of France is"
96
+ inputs = tokenizer(prompt, return_tensors='pt')
97
+
98
+ with torch.no_grad():
99
+ result = model.forward_understanding(input_ids=inputs['input_ids'])
100
+ logits = result['logits']
101
+
102
+ next_token = logits[0, -1, :].argmax()
103
+ prediction = tokenizer.decode([next_token])
104
+
105
+ print(f" ✅ Input: '{prompt}'")
106
+ print(f" ✅ Logits shape: {logits.shape}")
107
+ print(f" ✅ Next token: '{prediction}'")
108
+ results['text_understanding'] = True
109
+ except Exception as e:
110
+ print(f" ❌ FAILED: {e}")
111
+ import traceback
112
+ traceback.print_exc()
113
+ results['text_understanding'] = False
114
+
115
+ # Test 4: Image generation
116
+ print("\n" + "=" * 60)
117
+ print("TEST 4: Image Generation Pipeline")
118
+ print("=" * 60)
119
+ try:
120
+ from transformers import AutoTokenizer
121
+ tokenizer = AutoTokenizer.from_pretrained('HuggingFaceTB/SmolVLM-500M-Instruct')
122
+
123
+ prompt = "a red apple"
124
+ inputs = tokenizer(prompt, return_tensors='pt')
125
+
126
+ with torch.no_grad():
127
+ latents = model.generate_image(
128
+ input_ids=inputs['input_ids'],
129
+ num_steps=5,
130
+ latent_shape=(1, 4, 32, 32),
131
+ )
132
+
133
+ print(f" ✅ Prompt: '{prompt}'")
134
+ print(f" ✅ Latents shape: {latents.shape}")
135
+ print(f" ✅ Latents mean: {latents.mean().item():.4f}, std: {latents.std().item():.4f}")
136
+ print(f" ⚠️ Run through VAE decoder for actual image")
137
+ results['image_generation'] = True
138
+ except Exception as e:
139
+ print(f" ❌ FAILED: {e}")
140
+ import traceback
141
+ traceback.print_exc()
142
+ results['image_generation'] = False
143
+
144
+ # Test 5: KV cache
145
+ print("\n" + "=" * 60)
146
+ print("TEST 5: KV Cache Verification")
147
+ print("=" * 60)
148
+ try:
149
+ kv = model.kv_cache_info()
150
+ print(f" ✅ Original GQA: {kv['original_gqa']} floats/token")
151
+ print(f" ✅ Hybrid cache: {kv['hybrid']} floats/token")
152
+ print(f" ✅ Reduction: {kv['hybrid_reduction_pct']}%")
153
+ results['kv_cache'] = True
154
+ except Exception as e:
155
+ print(f" ❌ FAILED: {e}")
156
+ results['kv_cache'] = False
157
+
158
+ # Test 6: Moonshine audio
159
+ print("\n" + "=" * 60)
160
+ print("TEST 6: Moonshine Audio Integration")
161
+ print("=" * 60)
162
+ try:
163
+ import numpy as np
164
+ from moonshine_integration import SmolOmniAudio
165
+
166
+ audio_model = SmolOmniAudio(device='cpu')
167
+
168
+ sr = 16000
169
+ t = np.linspace(0, 1, sr)
170
+ audio = 0.3 * np.sin(2 * np.pi * 440 * t).astype(np.float32)
171
+
172
+ result = audio_model.transcribe(audio)
173
+ print(f" ✅ ASR model: {audio_model.asr_params:.1f}M params")
174
+ print(f" ✅ Transcription: '{result}'")
175
+
176
+ chat = audio_model.chat(audio=audio, question="What is this?")
177
+ print(f" ✅ Chat pipeline: {len(chat['full_prompt'])} chars")
178
+ results['moonshine_audio'] = True
179
+ except Exception as e:
180
+ print(f" ❌ FAILED: {e}")
181
+ import traceback
182
+ traceback.print_exc()
183
+ results['moonshine_audio'] = False
184
+
185
+ # Summary
186
+ print("\n" + "=" * 60)
187
+ print("SUMMARY")
188
+ print("=" * 60)
189
+
190
+ passed = sum(1 for v in results.values() if v)
191
+ total = len(results)
192
+
193
+ for test_name, passed_test in results.items():
194
+ status = "✅ PASS" if passed_test else "❌ FAIL"
195
+ print(f" {test_name:25s} {status}")
196
+
197
+ print(f"\n{passed}/{total} tests passed")
198
+
199
+ if passed == total:
200
+ print("\n🎉 Production ready!")
201
+ elif passed >= 4:
202
+ print("\n⚠️ Mostly ready — investigate failures")
203
+ else:
204
+ print("\n❌ Not ready — multiple critical failures")
205
+
206
+ return results
207
+
208
+ if __name__ == "__main__":
209
+ with warnings.catch_warnings():
210
+ warnings.simplefilter("ignore")
211
+ results = main()
212
+ sys.exit(0 if all(results.values()) else 1)