"""
Phase 1 Validation Test Script
Tests that HF API inference has been removed and local models work correctly
"""


import sys
import os  # NOTE(review): currently unused here — confirm before removing
import asyncio
import logging


# Configure root logging once for the whole script run; every test function
# below reports through this module-level logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

|
| def test_imports(): |
| """Test that all required modules can be imported""" |
| logger.info("Testing imports...") |
| try: |
| from src.llm_router import LLMRouter |
| from src.models_config import LLM_CONFIG |
| from src.local_model_loader import LocalModelLoader |
| logger.info("✅ All imports successful") |
| return True |
| except Exception as e: |
| logger.error(f"❌ Import failed: {e}") |
| return False |
|
|
| def test_models_config(): |
| """Test that models_config is updated correctly""" |
| logger.info("Testing models_config...") |
| try: |
| from src.models_config import LLM_CONFIG |
| |
| |
| assert LLM_CONFIG["primary_provider"] == "local", "Primary provider should be 'local'" |
| logger.info("✅ Primary provider is 'local'") |
| |
| |
| reasoning_model = LLM_CONFIG["models"]["reasoning_primary"]["model_id"] |
| assert ":cerebras" not in reasoning_model, "Model ID should not have API suffix" |
| assert reasoning_model == "Qwen/Qwen2.5-7B-Instruct", "Should use Qwen model" |
| logger.info(f"✅ Reasoning model: {reasoning_model}") |
| |
| |
| assert "API" not in str(LLM_CONFIG["routing_logic"]["fallback_chain"]), "No API in fallback chain" |
| logger.info("✅ Routing logic updated") |
| |
| return True |
| except Exception as e: |
| logger.error(f"❌ Models config test failed: {e}") |
| return False |
|
|
| def test_llm_router_init(): |
| """Test LLM router initialization""" |
| logger.info("Testing LLM router initialization...") |
| try: |
| from src.llm_router import LLMRouter |
| |
| |
| try: |
| router = LLMRouter(hf_token=None, use_local_models=False) |
| logger.error("❌ Should have raised ValueError for use_local_models=False") |
| return False |
| except ValueError: |
| logger.info("✅ Correctly raises error for use_local_models=False") |
| |
| |
| try: |
| router = LLMRouter(hf_token=None, use_local_models=True) |
| logger.info("✅ LLM router initialized (local models)") |
| |
| |
| assert not hasattr(router, '_call_hf_endpoint'), "Should not have _call_hf_endpoint method" |
| assert not hasattr(router, '_is_model_healthy'), "Should not have _is_model_healthy method" |
| assert not hasattr(router, '_get_fallback_model'), "Should not have _get_fallback_model method" |
| logger.info("✅ HF API methods removed") |
| |
| return True |
| except RuntimeError as e: |
| logger.warning(f"⚠️ Local models not available: {e}") |
| logger.warning("This is expected if transformers/torch not installed") |
| return True |
| except Exception as e: |
| logger.error(f"❌ LLM router test failed: {e}") |
| return False |
|
|
| def test_no_api_references(): |
| """Test that no API references remain in code""" |
| logger.info("Testing for API references...") |
| try: |
| import inspect |
| from src.llm_router import LLMRouter |
| |
| router_source = inspect.getsource(LLMRouter) |
| |
| |
| assert "_call_hf_endpoint" not in router_source, "Should not have _call_hf_endpoint" |
| assert "router.huggingface.co" not in router_source, "Should not have HF API URL" |
| assert "HF Inference API" not in router_source or "no API fallback" in router_source, "Should not reference HF API" |
| |
| logger.info("✅ No API references found in LLM router") |
| return True |
| except Exception as e: |
| logger.error(f"❌ API reference test failed: {e}") |
| return False |
|
|
async def test_inference_flow():
    """Run one local inference through the router, if models are available.

    Returns:
        bool: True on success or on an expected unavailability (RuntimeError),
        False when inference runs but yields nothing, or on unexpected errors.
    """
    logger.info("Testing inference flow...")
    try:
        from src.llm_router import LLMRouter

        router = LLMRouter(hf_token=None, use_local_models=True)

        try:
            answer = await router.route_inference(
                task_type="general_reasoning",
                prompt="What is 2+2?",
                max_tokens=50,
            )

            if not answer:
                logger.warning("⚠️ Inference returned None")
                return False
            logger.info(f"✅ Inference successful: {answer[:50]}...")
            return True
        except RuntimeError as e:
            # Models not loaded is an acceptable environment, not a bug.
            logger.warning(f"⚠️ Inference failed (expected if models not loaded): {e}")
            return True
    except RuntimeError as e:
        # Router construction failed for lack of local-model dependencies.
        logger.warning(f"⚠️ Router not available: {e}")
        return True
    except Exception as e:
        logger.error(f"❌ Inference test failed: {e}")
        return False
|
|
def main():
    """Run every Phase-1 validation test and log a pass/fail summary.

    Returns:
        int: 0 when all tests pass, 1 otherwise (suitable for sys.exit).
    """
    banner = "=" * 60
    logger.info(banner)
    logger.info("PHASE 1 VALIDATION TESTS")
    logger.info(banner)

    sync_tests = [
        ("Imports", test_imports),
        ("Models Config", test_models_config),
        ("LLM Router Init", test_llm_router_init),
        ("No API References", test_no_api_references),
    ]

    results = []
    for name, func in sync_tests:
        logger.info(f"\n--- Running {name} Test ---")
        try:
            results.append((name, func()))
        except Exception as e:
            # A crashing test counts as a failure but must not stop the run.
            logger.error(f"Test {name} crashed: {e}")
            results.append((name, False))

    # The inference test is async, so drive it through its own event loop.
    logger.info("\n--- Running Inference Flow Test ---")
    try:
        results.append(("Inference Flow", asyncio.run(test_inference_flow())))
    except Exception as e:
        logger.error(f"Inference flow test crashed: {e}")
        results.append(("Inference Flow", False))

    logger.info("\n" + banner)
    logger.info("TEST SUMMARY")
    logger.info(banner)

    passed = sum(1 for _, ok in results if ok)
    total = len(results)

    for name, ok in results:
        logger.info(f"{'✅ PASS' if ok else '❌ FAIL'}: {name}")

    logger.info(f"\nTotal: {passed}/{total} tests passed")

    if passed == total:
        logger.info("✅ All tests passed!")
        return 0
    logger.warning(f"⚠️ {total - passed} test(s) failed")
    return 1
|
|
| if __name__ == "__main__": |
| sys.exit(main()) |
|
|
|
|