# Standalone smoke test: direct retrieval checks against the unified medical knowledge base.
import asyncio
import json
import logging
import os

from app.services.medical_retriever import get_medical_retriever
from app.services.vector_operations import VectorOperations

# Silence library INFO/WARNING chatter; only surface errors during the run.
logging.basicConfig(level=logging.ERROR)
async def direct_medical_test():
    """Run a direct retrieval smoke test over the canned medical query set.

    Loads a category -> subcategory -> [question] mapping from a fixed JSON
    file, embeds each question, retrieves the single best match from the
    unified medical store, and prints a pass/fail summary. A query "passes"
    when its top-match confidence exceeds 0.85.

    Side effects: reads the test JSON from disk and prints progress and a
    summary report to stdout. Returns None.
    """
    test_json_path = "/Users/mac/Projects/customerAgent/server/datasets/test_medical_queries.json"
    if not os.path.exists(test_json_path):
        print(f"Error: File not found {test_json_path}")
        return

    # Explicit encoding: JSON test data should always be read as UTF-8.
    with open(test_json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Initialize Retriever
    retriever = get_medical_retriever()

    # Flatten the nested category/subcategory structure into one flat list.
    all_questions = []
    for cat, subcats in data.items():
        for subcat, questions in subcats.items():
            for q in questions:
                all_questions.append({
                    "category": cat,
                    "subcategory": subcat,
                    "question": q
                })

    # Guard: an empty dataset would otherwise raise ZeroDivisionError in the
    # average-confidence calculation below.
    if not all_questions:
        print("No questions found in test file; nothing to do.")
        return

    print(f"\n--- Starting Direct Retrieval Test on {len(all_questions)} Queries ---")

    success_count = 0
    fail_count = 0
    total_conf = 0
    results = []

    # Process queries sequentially to avoid terminal noise
    for i, item in enumerate(all_questions):
        query = item['question']
        try:
            embedding = await VectorOperations.get_embedding(query)
            search_results = retriever.search_medical(query, embedding, top_k=1)
            confidence = search_results[0].get('confidence', 0) if search_results else 0

            is_pass = confidence > 0.85  # High bar for exact query matches
            if is_pass:
                success_count += 1
            else:
                fail_count += 1
            total_conf += confidence

            results.append({
                "question": query,
                "confidence": confidence,
                "pass": is_pass
            })

            if (i + 1) % 20 == 0:
                print(f"Progress: {i+1}/{len(all_questions)}...")
        except Exception as e:
            # Best-effort: record the failure and keep testing remaining queries.
            fail_count += 1
            print(f"Error on query '{query[:30]}': {e}")

    # Summary Report
    print("\n--- Medical Retrieval Summary Report ---")
    print(f"Total Queries: {len(all_questions)}")
    print(f"Pass (>0.85 conf): {success_count}")
    print(f"Fail/Low Conf: {fail_count}")
    print(f"Average Confidence: {total_conf / len(all_questions):.4f}")

    # Show Top 3 failures
    failures = [r for r in results if not r['pass']]
    if failures:
        print("\nSample Low Confidence Matches:")
        for f in failures[:3]:
            print(f"- [Conf: {f['confidence']:.2f}] {f['question']}")

    # Verification of Source diversity
    print("\n--- Knowledge Base Integrity ---")
    print(f"Retriever identifies {len(retriever.documents)} documents in unified store.")
    print("Test Complete.")
| if __name__ == "__main__": | |
| asyncio.run(direct_medical_test()) | |