# customeragent-api/server/tests/direct_medical_test.py
# anasraza526's picture
# Clean deploy to Hugging Face
# ac90985
import asyncio
import os
import json
import logging
from app.services.medical_retriever import get_medical_retriever
from app.services.vector_operations import VectorOperations
# Configure root logging at ERROR level so that lower-severity (INFO/WARNING)
# records from imported libraries do not clutter the printed test report.
# NOTE(review): basicConfig does nothing if the root logger already has
# handlers (e.g. if the app imports above configure logging) — confirm that
# this call actually takes effect.
logging.basicConfig(level=logging.ERROR)
async def direct_medical_test(
    test_json_path: str = "/Users/mac/Projects/customerAgent/server/datasets/test_medical_queries.json",
) -> None:
    """Run every query of a JSON test set through the medical retriever and print a report.

    The JSON file is expected to be a two-level mapping of
    ``{category: {subcategory: [question, ...]}}``. Each question is embedded
    with ``VectorOperations.get_embedding`` and looked up with
    ``retriever.search_medical(..., top_k=1)``; a query passes when the
    ``confidence`` of the top hit is above 0.85. Queries that raise are
    reported and counted as failures, and contribute 0 to the average.

    Args:
        test_json_path: Path of the JSON query file. Defaults to the path the
            script was originally written against, so existing callers keep
            working; pass a different path to run on another machine.

    Returns:
        None. All results are printed to stdout. If the file does not exist,
        an error line is printed and the function returns immediately.
    """
    if not os.path.exists(test_json_path):
        print(f"Error: File not found {test_json_path}")
        return

    # Explicit UTF-8 so decoding does not depend on the platform's locale.
    with open(test_json_path, 'r', encoding="utf-8") as f:
        data = json.load(f)

    # Initialize Retriever
    retriever = get_medical_retriever()

    # Flatten the nested category -> subcategory -> questions mapping.
    all_questions = []
    for cat, subcats in data.items():
        for subcat, questions in subcats.items():
            for q in questions:
                all_questions.append({
                    "category": cat,
                    "subcategory": subcat,
                    "question": q,
                })

    total_queries = len(all_questions)
    print(f"\n--- Starting Direct Retrieval Test on {total_queries} Queries ---")

    success_count = 0
    fail_count = 0
    total_conf = 0
    results = []

    # Process queries sequentially to avoid terminal noise
    for i, item in enumerate(all_questions):
        query = item['question']
        try:
            embedding = await VectorOperations.get_embedding(query)
            search_results = retriever.search_medical(query, embedding, top_k=1)

            # No hit at all is treated as confidence 0.
            confidence = search_results[0].get('confidence', 0) if search_results else 0
            is_pass = confidence > 0.85  # High bar for exact query matches

            if is_pass:
                success_count += 1
            else:
                fail_count += 1

            total_conf += confidence
            results.append({
                "question": query,
                "confidence": confidence,
                "pass": is_pass,
            })

            if (i + 1) % 20 == 0:
                print(f"Progress: {i + 1}/{total_queries}...")
        except Exception as e:
            # Best-effort run: one failing query must not abort the whole
            # report, so record it as a failure and carry on.
            fail_count += 1
            print(f"Error on query '{query[:30]}': {e}")

    # An empty dataset would otherwise raise ZeroDivisionError here.
    average_conf = total_conf / total_queries if total_queries else 0.0

    # Summary Report
    print("\n--- Medical Retrieval Summary Report ---")
    print(f"Total Queries: {total_queries}")
    print(f"Pass (>0.85 conf): {success_count}")
    print(f"Fail/Low Conf: {fail_count}")
    print(f"Average Confidence: {average_conf:.4f}")

    # Show Top 3 failures
    failures = [r for r in results if not r['pass']]
    if failures:
        print("\nSample Low Confidence Matches:")
        for failure in failures[:3]:
            print(f"- [Conf: {failure['confidence']:.2f}] {failure['question']}")

    # Verification of Source diversity
    print("\n--- Knowledge Base Integrity ---")
    print(f"Retriever identifies {len(retriever.documents)} documents in unified store.")
    print("Test Complete.")
# Script entry point: run the retrieval test only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    test_coroutine = direct_medical_test()
    asyncio.run(test_coroutine)