Spaces:
Runtime error
Runtime error
| import asyncio | |
| import json | |
| import logging | |
| import sys | |
| import os | |
| # Ensure app modules can be imported | |
| sys.path.append(os.path.dirname(os.path.abspath(__file__))) | |
| from app.services.medical_orchestrator import get_medical_orchestrator, MedicalOrchestrator | |
| from app.services.intent_classifier import get_classifier | |
| from app.services.context_manager import EntryContext | |
| from app.services.vector_db import VectorDB | |
# Configure Logging: INFO and above, formatted as "<logger> - <LEVEL> - <message>".
logging.basicConfig(
    format='%(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name="Evaluator")
async def setup_test_data(orchestrator: MedicalOrchestrator):
    """Embed a fixed set of dummy business documents and store them in the VectorDB.

    Each document's ``text`` is embedded with
    ``VectorOperations.get_embedding(..., is_query=False)``. The resulting
    float32 matrix is passed, together with the original document dicts as
    metadata, to ``orchestrator.vector_db.add_vectors`` under website id ``1``.

    Args:
        orchestrator: Orchestrator whose ``vector_db`` receives the vectors.
    """
    print("Injecting test business data into VectorDB...")

    # Dummy data for tenant "1"
    docs = [
        {"text": "Our opening hours are Monday to Friday, 9 AM to 5 PM.", "source": "Business Hours"},
        {"text": "We accept Cigna, BlueCross, and Aetna insurance plans.", "source": "Insurance Policy"},
        {"text": "Dental cleanings start at $99 for new patients.", "source": "Pricing"},
        {"text": "We are located at 123 Main St, New York.", "source": "Location"},
    ]
    target_website = 1

    # Deferred import: only this helper needs the embedding service.
    from app.services.vector_operations import VectorOperations

    # One embedding per document, awaited sequentially in document order,
    # so row i of the matrix lines up with metadata entry i.
    embeddings = [
        await VectorOperations.get_embedding(entry['text'], is_query=False)
        for entry in docs
    ]
    doc_metadata = list(docs)

    import numpy as np

    orchestrator.vector_db.add_vectors(
        np.array(embeddings, dtype=np.float32),
        doc_metadata,
        target_website,
    )
    print("β Test data injected.")
# Trigger word (matched case-insensitively against the query) -> terms of which
# at least one must appear in the response for the RAG check to pass. The terms
# mirror the documents injected by setup_test_data(). When several triggers
# match, the LAST one wins (same precedence as the original if-chain).
_RAG_KEY_TERMS = (
    ("hours", ["9 AM", "5 PM", "Monday"]),
    ("insurance", ["Cigna", "Aetna"]),
    ("cost", ["$99"]),
)


def _intent_matches(expected_intent, actual_intent) -> bool:
    """Return True if the classified intent is acceptable for the expected one.

    Exact matches pass. An expected ``BUSINESS_SPECIFIC`` intent also accepts ``FAQ``.
    """
    if actual_intent == expected_intent:
        return True
    return expected_intent == "BUSINESS_SPECIFIC" and actual_intent in ("FAQ", "BUSINESS_SPECIFIC")


def _risk_matches(expected_risk, actual_risk) -> bool:
    """Return True if the analysed risk level is acceptable for the expected one.

    Exact matches pass. An expected ``"high"`` also accepts ``"critical"``.
    """
    return actual_risk == expected_risk or (expected_risk == "high" and actual_risk == "critical")


def _rag_key_terms(query: str) -> list:
    """Return the key terms the response to *query* must contain ([] if no trigger matches)."""
    lowered = query.lower()
    key_terms = []
    for trigger, terms in _RAG_KEY_TERMS:
        if trigger in lowered:
            key_terms = list(terms)
    return key_terms


def _find_key_term(response, key_terms):
    """Return the first of *key_terms* found in *response* (case-insensitive), else None."""
    haystack = (response or "").lower()
    for term in key_terms:
        if term.lower() in haystack:
            return term
    return None


def _pct(passed: int, total: int) -> float:
    """Return passed/total as a percentage, or 0.0 when there were no cases."""
    return (passed / total) * 100 if total else 0.0


def _print_summary(results: dict) -> None:
    """Print per-check accuracy and the list of recorded failures."""
    total = results["total"]
    print("\n=== Evaluation Summary ===")
    print(f"Total Cases: {total}")
    # Non-FAQ/BUSINESS_SPECIFIC cases count as RAG passes (no RAG check applies),
    # so every accuracy line uses the total case count as its denominator.
    for label, key in (("Intent", "intent_pass"), ("Risk", "risk_pass"), ("RAG", "rag_pass")):
        passed = results[key]
        print(f"{label} Accuracy: {passed}/{total} ({_pct(passed, total):.1f}%)")
    if results["failures"]:
        print("\nFailures:")
        for failure in results["failures"]:
            print(f"- {failure}")


async def run_evaluation():
    """Run the golden-dataset evaluation and print a per-case report and a summary.

    Loads ``datasets/golden_evaluation_dataset.json`` (relative to the current
    working directory) and injects the dummy documents via ``setup_test_data()``.
    Then, for every case, it checks:

    1. intent classification against ``expected_intent``;
    2. ``orchestrator.analyze_risk`` against ``expected_risk`` (optional key);
    3. for FAQ / BUSINESS_SPECIFIC cases whose query contains a trigger word,
       that the end-to-end response mentions one of the injected key terms.

    Each case must provide ``query`` and ``expected_intent`` keys.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open("datasets/golden_evaluation_dataset.json", "r", encoding="utf-8") as f:
        test_cases = json.load(f)

    orchestrator = get_medical_orchestrator()
    classifier = get_classifier()

    await setup_test_data(orchestrator)

    results = {
        "total": 0,
        "intent_pass": 0,
        "risk_pass": 0,
        "rag_pass": 0,
        "failures": [],
    }

    # NOTE(review): the 'β' status markers look like a mis-encoded emoji; pass
    # and fail lines print the same glyph - confirm the intended characters.
    for case in test_cases:
        query = case['query']
        expected_intent = case['expected_intent']
        expected_risk = case.get('expected_risk')
        print(f"\nScanning: '{query}'")
        results["total"] += 1

        # 1. Intent classification
        intent_res = await classifier.classify(query, industry="healthcare", context={})
        actual_intent = intent_res.category.value
        if _intent_matches(expected_intent, actual_intent):
            results["intent_pass"] += 1
            print(f" β Intent: {actual_intent}")
        else:
            print(f" β Intent Mismatch: Expected {expected_intent}, Got {actual_intent}")
            results["failures"].append(f"Intent fail: {query}")

        # 2. Risk analysis (analyze_risk is public on the orchestrator)
        actual_risk, _ = await orchestrator.analyze_risk(query, {})
        if _risk_matches(expected_risk, actual_risk):
            results["risk_pass"] += 1
            print(f" β Risk: {actual_risk}")
        else:
            print(f" β Risk Mismatch: Expected {expected_risk}, Got {actual_risk}")
            results["failures"].append(f"Risk fail: {query}")

        # 3. End-to-end response.
        # NOTE(review): setup_test_data() writes under website_id=1 (int); confirm
        # tenant_id="1" resolves to the same index.
        entry_context = EntryContext(tenant_id="1")
        response, _conf, _ = await orchestrator.process_query(query, entry_context)

        rag_success = True
        if expected_intent in ("FAQ", "BUSINESS_SPECIFIC"):
            key_terms = _rag_key_terms(query)
            if key_terms:
                found = _find_key_term(response, key_terms)
                if found is not None:
                    print(f" β RAG Retrieval Verified (Found '{found}')")
                else:
                    print(f" β RAG Fail: Key terms {key_terms} not found in response: '{(response or '')[:50]}...'")
                    rag_success = False
                    results["failures"].append(f"RAG fail: {query}")
        if rag_success:
            results["rag_pass"] += 1

    _print_summary(results)
def _main() -> None:
    """Script entry point: drive run_evaluation() to completion via asyncio.run()."""
    asyncio.run(run_evaluation())


if __name__ == "__main__":
    _main()