|
|
| """
|
| Comprehensive validation test suite for Supernova training.
|
| Runs while user trains on VM to ensure system integrity.
|
| """
|
|
|
| import sys
|
| import os
|
| import time
|
| import traceback
|
| from pathlib import Path
|
|
|
| sys.path.append('.')
|
|
|
def test_1_model_architecture():
    """Test 1: Model architecture & exact parameter count.

    Builds the 25M-parameter Supernova model from its JSON config and checks
    that the parameter count and core hyperparameters match the spec exactly.

    Returns:
        bool: True if all checks pass; False on any failure (the error is
        printed, never raised, so the suite can continue).
    """
    print("🧪 TEST 1: Model Architecture & Parameter Count")
    try:
        from supernova.config import ModelConfig
        from supernova.model import SupernovaModel

        cfg = ModelConfig.from_json_file('./configs/supernova_25m.json')
        model = SupernovaModel(cfg)
        total_params = sum(p.numel() for p in model.parameters())

        # The 25M figure is exact by design, not approximate.
        assert total_params == 25_000_000, f"Expected 25M, got {total_params}"
        assert cfg.n_layers == 6, f"Expected 6 layers, got {cfg.n_layers}"
        assert cfg.d_model == 320, f"Expected d_model=320, got {cfg.d_model}"
        assert cfg.n_heads == 10, f"Expected 10 heads, got {cfg.n_heads}"

        print(f"   ✅ Parameter count: {total_params:,} (EXACT)")
        print(f"   ✅ Architecture: {cfg.n_layers}L, {cfg.d_model}D, {cfg.n_heads}H")
        return True

    except Exception as e:
        print(f"   ❌ FAILED: {e}")
        return False
|
|
|
def test_2_data_pipeline():
    """Test 2: Data loading & processing.

    Loads the YAML data sources, the GPT-2 tokenizer, and pulls one sample
    from TokenChunkDataset to confirm shapes match the configured seq_len.

    Returns:
        bool: True if all checks pass, False otherwise (error printed).
    """
    print("🧪 TEST 2: Data Pipeline Validation")
    try:
        from supernova.data import load_sources_from_yaml, TokenChunkDataset
        from supernova.tokenizer import load_gpt2_tokenizer

        sources = load_sources_from_yaml('./configs/data_sources.yaml')
        assert len(sources) > 0, "No data sources loaded"

        # GPT-2's fixed vocabulary size.
        tok = load_gpt2_tokenizer()
        assert tok.vocab_size == 50257, f"Expected vocab=50257, got {tok.vocab_size}"

        ds = TokenChunkDataset(tok, sources, seq_len=256, eos_token_id=tok.eos_token_id)

        # One sample: (input, target) pair, each of length seq_len.
        batch = next(iter(ds))
        x, y = batch
        assert x.shape == (256,), f"Expected shape (256,), got {x.shape}"
        assert y.shape == (256,), f"Expected shape (256,), got {y.shape}"

        print(f"   ✅ Data sources: {len(sources)} sources loaded")
        print(f"   ✅ Tokenizer: {tok.vocab_size:,} vocab size")
        print(f"   ✅ Dataset: Batch shape {x.shape}")
        return True

    except Exception as e:
        print(f"   ❌ FAILED: {e}")
        return False
|
|
|
def test_3_training_mechanics():
    """Test 3: Training forward/backward pass.

    Runs one forward pass on random token batches, checks logits/loss shapes,
    then backpropagates and verifies every parameter received a gradient.

    Returns:
        bool: True if all checks pass, False otherwise (error printed).
    """
    print("🧪 TEST 3: Training Mechanics")
    try:
        import torch
        from supernova.config import ModelConfig
        from supernova.model import SupernovaModel
        from supernova.tokenizer import load_gpt2_tokenizer

        cfg = ModelConfig.from_json_file('./configs/supernova_25m.json')
        model = SupernovaModel(cfg)
        tok = load_gpt2_tokenizer()

        # Small random batch — we only care about shapes and gradient flow.
        batch_size, seq_len = 2, 128
        x = torch.randint(0, tok.vocab_size, (batch_size, seq_len))
        y = torch.randint(0, tok.vocab_size, (batch_size, seq_len))

        model.train()
        logits, loss = model(x, y)
        assert logits.shape == (batch_size, seq_len, tok.vocab_size)
        assert loss.numel() == 1, "Loss should be scalar"

        loss.backward()

        # Every trainable parameter must have received a gradient.
        grad_count = sum(1 for p in model.parameters() if p.grad is not None)
        total_params = len(list(model.parameters()))
        assert grad_count == total_params, f"Missing gradients: {grad_count}/{total_params}"

        print(f"   ✅ Forward pass: logits shape {logits.shape}")
        print(f"   ✅ Loss computation: {loss.item():.4f}")
        print(f"   ✅ Backward pass: {grad_count}/{total_params} gradients")
        return True

    except Exception as e:
        print(f"   ❌ FAILED: {e}")
        return False
|
|
|
def test_4_advanced_reasoning():
    """Test 4: Advanced reasoning system.

    Instantiates AdvancedSupernovaChat and smoke-tests its math engine
    (7*8 must contain "56") and its reasoning engine (non-trivial response).

    Returns:
        bool: True if all checks pass, False otherwise (error printed).
    """
    print("🧪 TEST 4: Advanced Reasoning System")
    try:
        from chat_advanced import AdvancedSupernovaChat

        chat = AdvancedSupernovaChat(
            config_path="./configs/supernova_25m.json",
            api_keys_path="./configs/api_keys.yaml"
        )

        # Deterministic arithmetic — the answer must appear in the reply.
        math_response = chat.respond("what is 7 * 8?")
        assert "56" in math_response, f"Math engine failed: {math_response}"

        # Open-ended prompt — only check that a substantive reply came back.
        reasoning_response = chat.respond("analyze the benefits of solar energy")
        assert len(reasoning_response) > 50, "Reasoning response too short"

        print("   ✅ Math engine: Working (7*8=56)")
        print("   ✅ Reasoning engine: Response generated")
        print("   ✅ Tool coordination: Functional")
        return True

    except Exception as e:
        print(f"   ❌ FAILED: {e}")
        return False
|
|
|
def test_5_checkpoint_system():
    """Test 5: Checkpoint save/load round trip.

    Saves model weights + config + metadata to a temporary checkpoint file,
    reloads it, restores the state dict into a fresh model, then cleans up.

    Returns:
        bool: True if all checks pass, False otherwise (error printed).
    """
    print("🧪 TEST 5: Checkpoint System")
    try:
        import torch
        from supernova.config import ModelConfig
        from supernova.model import SupernovaModel

        cfg = ModelConfig.from_json_file('./configs/supernova_25m.json')
        model = SupernovaModel(cfg)

        test_dir = "./test_checkpoint"
        os.makedirs(test_dir, exist_ok=True)
        checkpoint_path = os.path.join(test_dir, "test.pt")

        torch.save({
            "model_state_dict": model.state_dict(),
            "config": cfg.__dict__,
            "step": 100,
            "test": True
        }, checkpoint_path)

        # Reload on CPU and verify every saved key survived the round trip.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert "model_state_dict" in checkpoint
        assert "config" in checkpoint
        assert checkpoint["step"] == 100
        assert checkpoint["test"] is True

        # load_state_dict raises if any key or shape mismatches.
        new_model = SupernovaModel(cfg)
        new_model.load_state_dict(checkpoint["model_state_dict"])

        # Clean up the temporary checkpoint artifacts.
        os.remove(checkpoint_path)
        os.rmdir(test_dir)

        print("   ✅ Checkpoint save: Working")
        print("   ✅ Checkpoint load: Working")
        print("   ✅ Model state restoration: Working")
        return True

    except Exception as e:
        print(f"   ❌ FAILED: {e}")
        return False
|
|
|
def test_6_memory_efficiency():
    """Test 6: Memory usage & efficiency.

    Measures process RSS before/after model construction and after one
    forward+backward pass, and asserts total usage stays under 1 GiB.
    Requires psutil.

    Returns:
        bool: True if all checks pass, False otherwise (error printed).
    """
    print("🧪 TEST 6: Memory Efficiency")
    try:
        import torch
        import psutil
        from supernova.config import ModelConfig
        from supernova.model import SupernovaModel

        process = psutil.Process()
        initial_memory = process.memory_info().rss / 1024 / 1024  # MB

        cfg = ModelConfig.from_json_file('./configs/supernova_25m.json')
        model = SupernovaModel(cfg)

        model_memory = process.memory_info().rss / 1024 / 1024
        model_overhead = model_memory - initial_memory

        # 25M params * 4 bytes (fp32), expressed in MB, for comparison only.
        expected_size = 25_000_000 * 4 / 1024 / 1024

        x = torch.randint(0, 50257, (4, 256))
        y = torch.randint(0, 50257, (4, 256))

        logits, loss = model(x, y)
        loss.backward()

        grad_memory = process.memory_info().rss / 1024 / 1024
        grad_overhead = grad_memory - model_memory

        print(f"   ✅ Model memory: {model_overhead:.1f}MB (expected ~{expected_size:.1f}MB)")
        print(f"   ✅ Gradient memory: {grad_overhead:.1f}MB")
        print(f"   ✅ Total memory: {grad_memory:.1f}MB")

        # Hard ceiling: training this model should fit well under 1 GiB RSS.
        assert grad_memory < 1024, f"Memory usage too high: {grad_memory:.1f}MB"

        return True

    except Exception as e:
        print(f"   ❌ FAILED: {e}")
        return False
|
|
|
def test_7_training_script():
    """Test 7: Training script validation.

    Confirms supernova/train.py exists, that its public functions import,
    and that train() exposes the expected keyword parameters.

    Returns:
        bool: True if all checks pass, False otherwise (error printed).
    """
    print("🧪 TEST 7: Training Script")
    try:
        assert os.path.exists("supernova/train.py"), "Training script not found"

        from supernova.train import train, compute_grad_norm

        # Validate the train() signature without actually running training.
        import inspect
        train_sig = inspect.signature(train)
        expected_params = ['config_path', 'data_config_path', 'seq_len', 'batch_size', 'grad_accum']

        for param in expected_params:
            assert param in train_sig.parameters, f"Missing parameter: {param}"

        print("   ✅ Training script: Found")
        print("   ✅ Function imports: Working")
        print("   ✅ Parameter validation: Complete")
        return True

    except Exception as e:
        print(f"   ❌ FAILED: {e}")
        return False
|
|
|
def test_8_configuration_files():
    """Test 8: Configuration files.

    Checks that the model/data/API config files exist and parse, and that
    the API config contains the Serper key entry.

    Returns:
        bool: True if all checks pass, False otherwise (error printed).
    """
    print("🧪 TEST 8: Configuration Files")
    try:
        assert os.path.exists("./configs/supernova_25m.json"), "Model config missing"
        assert os.path.exists("./configs/data_sources.yaml"), "Data config missing"
        assert os.path.exists("./configs/api_keys.yaml"), "API config missing"

        from supernova.config import ModelConfig
        from supernova.data import load_sources_from_yaml
        import yaml

        # Parsing each config is the real validation — these raise on bad files.
        cfg = ModelConfig.from_json_file('./configs/supernova_25m.json')
        sources = load_sources_from_yaml('./configs/data_sources.yaml')

        with open('./configs/api_keys.yaml', 'r') as f:
            api_config = yaml.safe_load(f)

        assert 'serper_api_key' in api_config, "Serper API key missing"
        assert len(sources) > 0, "No data sources configured"

        print("   ✅ Model config: Valid")
        print("   ✅ Data config: Valid")
        print("   ✅ API config: Valid")
        return True

    except Exception as e:
        print(f"   ❌ FAILED: {e}")
        return False
|
|
|
def run_full_validation_suite():
    """Run all eight validation tests and print a summary report.

    Each test returns True/False; a test that raises an uncaught exception
    is recorded as a failure (with traceback) rather than aborting the run.

    Returns:
        bool: True only when every test passed.
    """
    print("🚀 SUPERNOVA TRAINING VALIDATION SUITE")
    print("=" * 60)
    print("Running comprehensive tests while VM training initiates...")
    print()

    tests = [
        test_1_model_architecture,
        test_2_data_pipeline,
        test_3_training_mechanics,
        test_4_advanced_reasoning,
        test_5_checkpoint_system,
        test_6_memory_efficiency,
        test_7_training_script,
        test_8_configuration_files,
    ]

    results = []
    start_time = time.time()

    for i, test_func in enumerate(tests, 1):
        print(f"\n{'='*20} TEST {i}/{len(tests)} {'='*20}")
        try:
            result = test_func()
            results.append(result)
            print(f"   {'✅ PASSED' if result else '❌ FAILED'}")
        except Exception as e:
            # A crash inside a test counts as a failure; keep the suite going.
            print(f"   ❌ CRITICAL ERROR: {e}")
            traceback.print_exc()
            results.append(False)
        print()

    passed = sum(results)
    total = len(results)
    success_rate = (passed / total) * 100
    elapsed = time.time() - start_time

    print("=" * 60)
    print("📊 VALIDATION SUMMARY")
    print("=" * 60)
    print(f"Tests Passed: {passed}/{total} ({success_rate:.1f}%)")
    print(f"Validation Time: {elapsed:.1f}s")
    print()

    if passed == total:
        print("🎉 ALL TESTS PASSED - TRAINING SYSTEM VALIDATED")
        print("✅ VM training can proceed with confidence")
        print("✅ No blocking issues detected")
    else:
        print("⚠️ SOME TESTS FAILED")
        print("❌ Review failed tests before continuing VM training")
        # Report 1-based test numbers to match the per-test banners above.
        failed_tests = [i + 1 for i, result in enumerate(results) if not result]
        print(f"❌ Failed test numbers: {failed_tests}")

    print("=" * 60)
    return passed == total
|
|
|
if __name__ == "__main__":
    # Exit status mirrors the suite outcome: 0 = all passed, 1 = failures.
    raise SystemExit(0 if run_full_validation_suite() else 1)