Spaces:
Runtime error
Runtime error
import sys
from unittest.mock import MagicMock

import asyncio
import json
import logging
import os

# 1. Register stand-ins for spaCy and its model packages in the import cache
#    BEFORE anything from `app` is imported, so those imports resolve to mocks
#    instead of the real (heavy / possibly unavailable) packages.
mock_spacy = MagicMock()
sys.modules.update({
    "spacy": mock_spacy,
    "spacy.cli": MagicMock(),
    "en_core_web_md": MagicMock(),
    "en_core_web_sm": MagicMock(),
})

# 2. Swap the model-backed service classes for the MagicMock class itself, so
#    code that instantiates them gets a MagicMock object. Each module is
#    imported and patched in turn; AIEngine is imported only afterwards so it
#    picks up the patched names.
import app.services.smart_intent_classifier  # noqa: E402
app.services.smart_intent_classifier.SmartIntentClassifier = MagicMock
import app.services.nlp_processor  # noqa: E402
app.services.nlp_processor.NLPProcessor = MagicMock
from app.services.ai_engine import AIEngine  # noqa: E402

# Only warnings and above are shown, keeping the report output readable.
logging.basicConfig(level=logging.WARNING)
async def stress_test_medical(
    test_json_path=(
        "/Users/mac/Projects/customerAgent/server/datasets/"
        "test_medical_queries.json"
    ),
    pass_threshold=0.6,
    concurrency_limit=10,
):
    """Run every medical test query through ``AIEngine`` and print a report.

    The dataset is read as nested mappings ``category -> subcategory ->
    [question, ...]`` (the shape the loader walks). Each question is sent to
    ``engine.generate_response`` with ``industry='healthcare'`` and a forced
    ``'medical_consult'`` intent. A query is PASS when the returned confidence
    is strictly greater than ``pass_threshold``, FAIL otherwise, and ERROR if
    the call (or the comparison) raised.

    Args:
        test_json_path: Path to the JSON dataset. Defaults to the original
            hard-coded path for backward compatibility.
        pass_threshold: Confidence a response must exceed to count as PASS.
        concurrency_limit: Maximum number of queries awaited at once.

    Returns:
        The list of per-query result dicts, or ``None`` when the dataset is
        missing, unparsable, or contains no questions.

    Raises:
        ValueError: If ``concurrency_limit`` is less than 1 (a zero-permit
            semaphore would block forever).
    """
    if concurrency_limit < 1:
        raise ValueError(f"concurrency_limit must be >= 1, got {concurrency_limit}")

    # --- Load and flatten the dataset -------------------------------------
    if not os.path.isfile(test_json_path):
        print(f"Error: File not found {test_json_path}")
        return None
    try:
        with open(test_json_path, "r", encoding="utf-8") as fh:
            data = json.load(fh)
    except json.JSONDecodeError as exc:
        print(f"Error: Could not parse {test_json_path}: {exc}")
        return None

    all_questions = [
        {"category": cat, "subcategory": subcat, "question": q}
        for cat, subcats in data.items()
        for subcat, questions in subcats.items()
        for q in questions
    ]
    total = len(all_questions)
    if not total:
        # Bail out before building the engine; the average below would
        # otherwise divide by zero.
        print(f"Error: No queries found in {test_json_path}")
        return None

    # --- Build the engine with a stubbed NLP helper -----------------------
    # The tokenizing/extraction methods return empty results, so no real NLP
    # model is consulted during the run.
    engine = AIEngine()
    nlp_stub = MagicMock()
    nlp_stub.tokenize_and_clean.return_value = []
    nlp_stub.extract_medical_details.return_value = {}
    nlp_stub.extract_entities.return_value = {}
    engine.industry_ai.nlp = nlp_stub

    print(f"\n--- Starting Stress Test on {total} Queries ---")

    # Caps how many generate_response calls are in flight at once.
    semaphore = asyncio.Semaphore(concurrency_limit)

    async def run_query(item):
        """Send one query and return its result dict.

        Exceptions raised by the engine (or by comparing its confidence) are
        captured as an ERROR result rather than aborting the whole run.
        """
        async with semaphore:
            try:
                # Per the original author: industry='healthcare' triggers the
                # new retriever logic; the forced intent bypasses classification.
                _response, confidence, _needs_review = await engine.generate_response(
                    query=item["question"],
                    context=[],
                    industry="healthcare",
                    config={"intent": "medical_consult"},
                )
                status = "PASS" if confidence > pass_threshold else "FAIL"
            except Exception as exc:  # per-query boundary: record, keep going
                return {
                    "question": item["question"],
                    "status": f"ERROR: {type(exc).__name__}: {exc}",
                    "cat": item["category"],
                }
        return {
            "question": item["question"],
            "confidence": confidence,
            "status": status,
            "cat": item["category"],
        }

    final_results = await asyncio.gather(*(run_query(item) for item in all_questions))

    # Derive the tallies from the results instead of shared counters.
    passed = [r for r in final_results if r["status"] == "PASS"]
    failed = [r for r in final_results if r["status"] != "PASS"]
    # ERROR results carry no confidence and so contribute 0 to the average.
    total_conf = sum(r.get("confidence", 0) for r in final_results)

    # --- Summary report ---------------------------------------------------
    print("\n--- Stress Test Summary Report ---")
    print(f"Total Queries: {total}")
    print(f"Passed (>{pass_threshold} conf): {len(passed)}")
    print(f"Failed/Low Conf: {len(failed)}")
    print(f"Average Confidence: {total_conf / total:.4f}")

    if failed:
        print(f"\nSample Failures/Low Confidence ({len(failed)} total):")
        for res in failed[:5]:
            print(f"- [{res.get('status')} @ {res.get('confidence', 0):.2f}] {res['question'][:80]}...")

    # Show the first result seen for up to five distinct categories.
    print("\n--- Detailed Results for Sample Categories ---")
    categories_shown = set()
    for res in final_results:
        if len(categories_shown) >= 5:
            break
        cat = res.get("cat")
        if cat in categories_shown:
            continue
        print(f"\n[Category: {cat}]")
        print(f" Q: {res['question'][:80]}...")
        print(f" A: {res['status']} (Conf: {res.get('confidence', 0):.2f})")
        categories_shown.add(cat)

    return final_results
if __name__ == "__main__":
    # Script entry point: build the coroutine, then drive it to completion
    # on a fresh event loop.
    stress_test_coro = stress_test_medical()
    asyncio.run(stress_test_coro)