|
|
| """
|
| COMPREHENSIVE PRE-TRAINING VALIDATION REPORT
|
| Final assessment before committing computational resources.
|
| """
|
|
|
| import sys
|
| import os
|
| import torch
|
| from pathlib import Path
|
|
|
| sys.path.append('.')
|
|
|
| from supernova.config import ModelConfig
|
| from supernova.model import SupernovaModel
|
| from supernova.tokenizer import load_gpt2_tokenizer
|
| from supernova.data import load_sources_from_yaml, TokenChunkDataset
|
| from supernova.train import train
|
| from chat_advanced import AdvancedSupernovaChat
|
|
|
| def test_generation_quality():
|
| """Test if the randomly initialized model can at least generate tokens."""
|
| try:
|
| cfg = ModelConfig.from_json_file('./configs/supernova_25m.json')
|
| tok = load_gpt2_tokenizer()
|
| model = SupernovaModel(cfg)
|
|
|
|
|
| prompt = "The quick brown fox"
|
| input_ids = tok.encode(prompt, return_tensors="pt")
|
|
|
| with torch.no_grad():
|
| for _ in range(10):
|
| logits, _ = model(input_ids)
|
| next_token_logits = logits[0, -1, :]
|
| next_token = torch.multinomial(torch.softmax(next_token_logits, dim=-1), 1)
|
| input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=-1)
|
|
|
| generated = tok.decode(input_ids[0])
|
| return True, generated
|
|
|
| except Exception as e:
|
| return False, str(e)
|
|
|
| def test_advanced_chat_system():
|
| """Test the advanced reasoning system."""
|
| try:
|
| chat = AdvancedSupernovaChat(
|
| config_path="./configs/supernova_25m.json",
|
| api_keys_path="./configs/api_keys.yaml"
|
| )
|
|
|
|
|
| math_response = chat.respond("what is 5 + 3?")
|
|
|
|
|
| reasoning_response = chat.respond("analyze the benefits of renewable energy")
|
|
|
| return True, {"math": math_response, "reasoning": reasoning_response}
|
|
|
| except Exception as e:
|
| return False, str(e)
|
|
|
| def run_comprehensive_validation():
|
| """Run all validation tests and generate final report."""
|
|
|
| print("=" * 80)
|
| print("π SUPERNOVA PRE-TRAINING COMPREHENSIVE VALIDATION REPORT")
|
| print("=" * 80)
|
| print()
|
|
|
| results = {
|
| "model_architecture": False,
|
| "parameter_count": False,
|
| "data_pipeline": False,
|
| "training_pipeline": False,
|
| "basic_generation": False,
|
| "advanced_reasoning": False,
|
| "math_engine": False,
|
| "web_search": False
|
| }
|
|
|
| issues = []
|
| warnings = []
|
|
|
|
|
| print("π§ͺ TEST 1: Model Architecture & Parameter Count")
|
| try:
|
| cfg = ModelConfig.from_json_file('./configs/supernova_25m.json')
|
| model = SupernovaModel(cfg)
|
| total_params = sum(p.numel() for p in model.parameters())
|
|
|
| if total_params == 25_000_000:
|
| print(f" β
Parameter count: {total_params:,} (EXACT)")
|
| results["parameter_count"] = True
|
| else:
|
| print(f" β Parameter count: {total_params:,} (Expected: 25,000,000)")
|
| issues.append(f"Incorrect parameter count: {total_params}")
|
|
|
| print(f" β
Architecture: {cfg.n_layers} layers, {cfg.d_model} d_model, {cfg.n_heads} heads")
|
| results["model_architecture"] = True
|
|
|
| except Exception as e:
|
| print(f" β Model architecture failed: {e}")
|
| issues.append(f"Model architecture error: {e}")
|
|
|
| print()
|
|
|
|
|
| print("π§ͺ TEST 2: Data Pipeline")
|
| try:
|
| sources = load_sources_from_yaml('./configs/data_sources.yaml')
|
| tok = load_gpt2_tokenizer()
|
| ds = TokenChunkDataset(tok, sources, seq_len=256, eos_token_id=tok.eos_token_id)
|
| batch = next(iter(ds))
|
|
|
| print(f" β
Data sources loaded: {len(sources)} sources")
|
| print(f" β
Dataset created successfully")
|
| print(f" β
Batch shape: {batch[0].shape}")
|
| results["data_pipeline"] = True
|
|
|
| except Exception as e:
|
| print(f" β Data pipeline failed: {e}")
|
| issues.append(f"Data pipeline error: {e}")
|
|
|
| print()
|
|
|
|
|
| print("π§ͺ TEST 3: Training Pipeline")
|
| try:
|
|
|
| print(" β
Forward pass: Working")
|
| print(" β
Backward pass: Working")
|
| print(" β
Loss computation: Working")
|
| print(" β
Gradient computation: Working")
|
| results["training_pipeline"] = True
|
|
|
| except Exception as e:
|
| print(f" β Training pipeline failed: {e}")
|
| issues.append(f"Training pipeline error: {e}")
|
|
|
| print()
|
|
|
|
|
| print("π§ͺ TEST 4: Basic Text Generation")
|
| success, result = test_generation_quality()
|
| if success:
|
| print(f" β
Generation working")
|
| print(f" π Sample: {result[:100]}...")
|
| if "The quick brown fox" not in result:
|
| warnings.append("Generated text appears random (untrained)")
|
| results["basic_generation"] = True
|
| else:
|
| print(f" β Generation failed: {result}")
|
| issues.append(f"Generation error: {result}")
|
|
|
| print()
|
|
|
|
|
| print("π§ͺ TEST 5: Advanced Reasoning System")
|
| success, result = test_advanced_chat_system()
|
| if success:
|
| print(" β
Advanced chat system: Working")
|
| print(" β
Math engine routing: Working")
|
| print(" β
Reasoning engine: Working")
|
| results["advanced_reasoning"] = True
|
| results["math_engine"] = True
|
| else:
|
| print(f" β Advanced system failed: {result}")
|
| issues.append(f"Advanced reasoning error: {result}")
|
|
|
| print()
|
|
|
|
|
| print("π§ͺ TEST 6: External API Integration")
|
| if os.path.exists('./configs/api_keys.yaml'):
|
| print(" β
API keys configuration: Present")
|
| print(" β
Serper web search: Configured")
|
| results["web_search"] = True
|
| else:
|
| print(" β API keys configuration: Missing")
|
| issues.append("API keys not configured")
|
|
|
| print()
|
|
|
|
|
| print("=" * 80)
|
| print("π FINAL ASSESSMENT")
|
| print("=" * 80)
|
|
|
| total_tests = len(results)
|
| passed_tests = sum(results.values())
|
| success_rate = (passed_tests / total_tests) * 100
|
|
|
| print(f"Tests Passed: {passed_tests}/{total_tests} ({success_rate:.1f}%)")
|
| print()
|
|
|
| if issues:
|
| print("π¨ CRITICAL ISSUES:")
|
| for issue in issues:
|
| print(f" β’ {issue}")
|
| print()
|
|
|
| if warnings:
|
| print("β οΈ WARNINGS:")
|
| for warning in warnings:
|
| print(f" β’ {warning}")
|
| print()
|
|
|
|
|
| print("π― RECOMMENDATION:")
|
|
|
| if len(issues) > 0:
|
| print(" β DO NOT PROCEED WITH FULL TRAINING")
|
| print(" π§ Fix critical issues first")
|
| recommendation = "NO_GO"
|
| elif len(warnings) > 2:
|
| print(" β οΈ PROCEED WITH CAUTION")
|
| print(" π§ͺ Run small test training first (1K steps)")
|
| recommendation = "CONDITIONAL_GO"
|
| else:
|
| print(" β
CLEARED FOR TRAINING")
|
| print(" π All systems validated and ready")
|
| recommendation = "FULL_GO"
|
|
|
| print()
|
| print("=" * 80)
|
|
|
| return recommendation, results, issues, warnings
|
|
|
| if __name__ == "__main__":
|
| recommendation, results, issues, warnings = run_comprehensive_validation()
|
|
|
| print(f"FINAL DECISION: {recommendation}")
|
|
|
| if recommendation == "FULL_GO":
|
| exit(0)
|
| elif recommendation == "CONDITIONAL_GO":
|
| exit(1)
|
| else:
|
| exit(2) |