| """Common test fixtures and configuration."""
|
|
|
| import pytest
|
| import asyncio
|
| import os
|
| from unittest.mock import Mock, MagicMock, patch
|
| from typing import Dict, Any, Generator
|
|
|
|
|
# Placeholder credentials and provider selection used by the test suite.
TEST_ENV_VARS = {
    "TAVILY_API_KEY": "tvly-test-key-12345",
    "NEBIUS_API_KEY": "test-nebius-key",
    "OPENAI_API_KEY": "test-openai-key",
    "ANTHROPIC_API_KEY": "test-anthropic-key",
    "HUGGINGFACE_API_KEY": "test-hf-key",
    "LLM_PROVIDER": "nebius",
}

# Export every entry into the process environment as soon as this module is
# imported, overwriting any value that is already set.
os.environ.update(TEST_ENV_VARS)
|
|
|
@pytest.fixture
def mock_tavily_client():
    """Provide a Tavily client double whose ``search`` returns canned data.

    The canned payload holds two entries under ``"results"`` and a summary
    string under ``"answer"``.
    """
    first_hit = {
        "title": "Test Result 1",
        "url": "https://example.com/1",
        "content": "Test content 1",
        "score": 0.9,
    }
    second_hit = {
        "title": "Test Result 2",
        "url": "https://example.com/2",
        "content": "Test content 2",
        "score": 0.8,
    }
    canned_payload = {
        "results": [first_hit, second_hit],
        "answer": "Test search summary",
    }

    client = Mock()
    client.search.return_value = canned_payload
    return client
|
|
|
@pytest.fixture
def mock_llm_response():
    """Provide a canned LLM completion: a JSON string listing three sub-questions."""
    completion_text = '{"sub_questions": ["Question 1?", "Question 2?", "Question 3?"]}'
    return completion_text
|
|
|
@pytest.fixture
def mock_modal_sandbox():
    """Provide a sandbox double whose ``exec`` reports a successful run.

    The object returned by ``exec`` exposes ``stdout``, ``stderr`` and
    ``returncode`` attributes with fixed values.
    """
    exec_result = Mock(stdout="Test output", stderr="", returncode=0)

    sandbox = Mock()
    sandbox.exec.return_value = exec_result
    return sandbox
|
|
|
@pytest.fixture
def sample_user_request():
    """Provide a fixed natural-language request string for tests."""
    request_text = "Create a Python script to analyze CSV data and generate charts"
    return request_text
|
|
|
@pytest.fixture
def sample_search_results():
    """Provide two fixed search hits, each with title, url, content and score."""
    analysis_hit = {
        "title": "Python Data Analysis Tutorial",
        "url": "https://example.com/pandas-tutorial",
        "content": "Learn how to analyze CSV data with pandas and matplotlib...",
        "score": 0.95,
    }
    charts_hit = {
        "title": "Chart Generation with Python",
        "url": "https://example.com/charts",
        "content": "Create stunning charts and visualizations...",
        "score": 0.87,
    }
    return [analysis_hit, charts_hit]
|
|
|
@pytest.fixture
def sample_code():
    """Provide a short Python snippet (as source text) for tests.

    The snippet imports pandas and matplotlib, reads ``data.csv`` and draws a
    bar chart. It is returned as a string and is not executed here.
    """
    # NOTE(review): the literal's inner indentation and blank lines were
    # reconstructed from a whitespace-mangled source - confirm against the
    # canonical file before relying on exact whitespace.
    snippet = '''
import pandas as pd
import matplotlib.pyplot as plt

# Load data
df = pd.read_csv('data.csv')

# Generate chart
df.plot(kind='bar')
plt.show()
'''
    return snippet
|
|
|
@pytest.fixture
def mock_config():
    """Provide an ``(api_config, model_config)`` pair of configuration doubles.

    ``api_config`` carries fixed key/provider attributes; ``model_config``
    answers ``get_model_for_provider`` with a fixed model identifier.
    """
    api_double = Mock(
        tavily_api_key="tvly-test-key",
        llm_provider="nebius",
        nebius_api_key="test-nebius-key",
    )

    model_double = Mock()
    model_double.get_model_for_provider.return_value = "meta-llama/llama-3.1-8b-instruct"

    return api_double, model_double
|
|
|
@pytest.fixture
def event_loop():
    """Yield a freshly created asyncio event loop and close it afterwards.

    Yields:
        A new event loop created with ``asyncio.new_event_loop()``.

    The loop is closed in a ``finally`` clause so it is released even when
    the generator is closed or an exception is thrown into it at the
    ``yield`` point, not only on normal resumption.
    """
    # NOTE(review): the fixture name matches the one pytest-asyncio has
    # historically looked up - confirm the installed version still supports
    # overriding it.
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        loop.close()
|
|
|
class MockAgent:
    """Callable stand-in for an agent that counts how often it is invoked."""

    def __init__(self, name: str):
        # No calls have been recorded yet for a freshly built agent.
        self.call_count = 0
        self.name = name

    def __call__(self, *args, **kwargs):
        """Record one more invocation and return a fixed success payload.

        Any positional or keyword arguments are accepted and ignored.
        """
        self.call_count = self.call_count + 1
        payload = {
            "success": True,
            "agent": self.name,
            "calls": self.call_count,
        }
        return payload
|
|
|
@pytest.fixture
def mock_agents():
    """Provide a dict mapping each agent role name to a ``MockAgent`` of that name."""
    role_names = (
        "question_enhancer",
        "web_search",
        "llm_processor",
        "citation_formatter",
        "code_generator",
        "code_runner",
    )
    return {role: MockAgent(role) for role in role_names}
|
|
|
@pytest.fixture
def disable_advanced_features():
    """Patch ``app.ADVANCED_FEATURES_AVAILABLE`` to ``False`` while a test runs.

    The patch is active for as long as the fixture is suspended at its
    ``yield`` and is undone when the ``with`` block exits.
    """
    with patch("app.ADVANCED_FEATURES_AVAILABLE", new=False):
        yield