| """ |
| ================================================================ |
| agent5.py — 异步并行优化版 Medical RAG Agent |
| ================================================================ |
| 基于 agent4.py 的三项优化: |
| P0: 三路召回并行化 (asyncio.gather) |
| P1: AsyncOpenAI 客户端 (async LLM 推理) |
| P2: workers=1 (Milvus Lite 不支持多进程, 靠 async 获得并发) |
| |
| 架构对比: |
| agent4.py (串行, 同步): |
| Milvus [3s] → PDF [3s] → Neo4j [4s] → LLM [12s] = 22s |
| 且等 LLM 时 worker 完全阻塞, 无法服务其他请求 |
| |
| agent5.py (并行, 异步): |
| ┌─ Milvus [3s] ─┐ |
| ├─ PDF [3s] ──┼→ 合并 → async LLM [12s] = 16s |
| └─ Neo4j [4s] ─┘ |
| 且等 LLM 时 worker 可处理其他请求 (Cache Hit / 新的召回) |
| |
| 预期效果: |
| Cache Hit: ~8ms (不变, Redis 直接返回) |
| Cache Miss: ~16s (从 ~22s 降到 ~16s, 省 27%) |
| 并发能力: 单 worker async ≈ 等效 5-10 个同步 worker |
| |
| 运行: |
| python agent5.py |
| # Uvicorn running on http://0.0.0.0:8103 (1 worker, async) |
| ================================================================ |
| """ |
|
|
| import os |
| import uvicorn |
| import asyncio |
| from fastapi import FastAPI, Request |
| from fastapi.middleware.cors import CORSMiddleware |
| import json |
| import datetime |
| import hashlib |
| import logging |
|
|
| import httpx |
| from openai import AsyncOpenAI |
| from neo4j import GraphDatabase |
| from langchain_milvus import Milvus, BM25BuiltInFunction |
| from vector import OpenAIEmbeddings |
| from langchain_text_splitters import RecursiveCharacterTextSplitter |
| from langchain_core.stores import InMemoryStore |
| from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever |
| from dotenv import load_dotenv |
|
|
| from new_redis import redis_manager |
|
|
| load_dotenv() |
|
|
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
| logger = logging.getLogger("agent5") |
|
|
app = FastAPI()

# Allow cross-origin requests from any frontend during development.
# NOTE(review): "*" origins combined with allow_credentials=True is very
# permissive — tighten the origin list before production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
| |
| |
| |
|
|
# ----------------------------------------------------------------------
# Shared resources — created once at import time and reused by every
# request. Milvus Lite is an embedded, single-process database, which is
# why the server runs with workers=1 (see __main__).
# ----------------------------------------------------------------------

# Embedding model shared by both Milvus vector stores.
embedding_model = OpenAIEmbeddings()
print("创建 Embedding 模型成功......")

# Local Milvus Lite database files.
URI = "./milvus_agent.db"   # backs milvus_vectorstore (recall path 1)
URI1 = "./pdf_agent.db"     # backs pdf_vectorstore (recall path 2)

# Hybrid store: dense embeddings (IP metric, IVF_FLAT index) plus a
# built-in BM25 sparse inverted index for keyword recall.
milvus_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {"metric_type": "IP", "index_type": "IVF_FLAT"},
        {"metric_type": "BM25", "index_type": "SPARSE_INVERTED_INDEX"},
    ],
    connection_args={"uri": URI},
)
print("创建 Milvus 连接成功......")

# Holds the parent documents for the parent/child retriever below.
# NOTE(review): InMemoryStore is per-process and empty after a restart —
# parent lookups only work once documents have been (re-)ingested;
# confirm where ingestion happens.
docstore = InMemoryStore()

# Child chunks: small (200 chars) for precise vector matching;
# parent chunks: large (1000 chars) so answers get richer context.
child_splitter = RecursiveCharacterTextSplitter(
    chunk_size=200, chunk_overlap=50, length_function=len,
    separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""],
)
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

# Second hybrid store, dedicated to the PDF child chunks.
pdf_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {"metric_type": "IP", "index_type": "IVF_FLAT"},
        {"metric_type": "BM25", "index_type": "SPARSE_INVERTED_INDEX"},
    ],
    connection_args={"uri": URI1},
    consistency_level="Bounded",
    drop_old=False,
)

# Matches small child chunks, then returns their larger parent documents.
parent_retriever = ParentDocumentRetriever(
    vectorstore=pdf_vectorstore,
    docstore=docstore,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)
print("创建 Parent Milvus 连接成功......")

# Neo4j graph database — synchronous driver; call sites wrap session work
# in asyncio.to_thread so the event loop is not blocked.
neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
neo4j_user = os.getenv("NEO4J_USER", "neo4j")
neo4j_password = os.getenv("NEO4J_PASSWORD", "neo4j")
driver = GraphDatabase.driver(
    uri=neo4j_uri, auth=(neo4j_user, neo4j_password), max_connection_lifetime=1000,
)
print("创建 Neo4j 连接成功......")

# Async LLM client — the worker can serve other requests while awaiting
# a completion (the point of the agent5 optimization).
async_openai_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
print("创建 AsyncOpenAI LLM 成功......")

# Shared async HTTP client for the text-to-Cypher service on port 8101.
cypher_http_client = httpx.AsyncClient(timeout=30.0)

print("创建 Redis 连接成功......")
|
|
|
|
| |
| |
| |
|
|
def format_docs(docs):
    """Concatenate the page contents of retrieved documents into one
    context string, separated by blank lines."""
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)
|
|
|
|
async def retrieve_milvus(query: str) -> str:
    """Recall path 1: hybrid (dense + BM25) search against Milvus.

    The Milvus SDK is synchronous, so the blocking call runs on a worker
    thread via asyncio.to_thread to keep the event loop free. Returns the
    formatted context string, or "" when nothing is found or the search
    fails (failures are logged, never raised, so gather() keeps going).
    """
    try:
        hits = await asyncio.to_thread(
            milvus_vectorstore.similarity_search,
            query,
            k=10,
            ranker_type="rrf",
            ranker_params={"k": 100},
        )
        return format_docs(hits) if hits else ""
    except Exception as exc:
        logger.warning(f"Milvus 召回失败: {exc}")
        return ""
|
|
|
|
async def retrieve_pdf(query: str) -> str:
    """Recall path 2: parent/child document retrieval over the PDF store.

    Runs the synchronous ParentDocumentRetriever in a worker thread and
    returns the page content of the top-ranked parent document, or ""
    when nothing is retrieved or the retrieval fails (failures are
    logged, never raised, so the gather() in the caller keeps going).
    """
    try:
        docs = await asyncio.to_thread(parent_retriever.invoke, query)
        # `docs and len(docs) >= 1` was redundant — a plain truthiness
        # check covers both None and the empty list.
        if docs:
            return docs[0].page_content
        return ""
    except Exception as e:
        logger.warning(f"PDF 检索失败: {e}")
        return ""
|
|
|
|
async def retrieve_neo4j(query: str) -> str:
    """Recall path 3: natural language -> Cypher -> Neo4j graph lookup.

    Flow:
      1. POST the question to the Cypher-generation service (port 8101)
         through the shared async httpx client.
      2. Discard results that are missing, low-confidence (< 0.9), or
         not marked validated by the generator.
      3. Re-validate the Cypher via the /validate endpoint.
      4. Execute the Cypher with the synchronous Neo4j driver inside a
         worker thread and join the first column of every record.

    Returns "" on any failure so the caller can merge contexts with
    filter(None, ...).
    """
    try:
        payload = json.dumps({"natural_language_query": query})
        resp = await cypher_http_client.post("http://0.0.0.0:8101/generate", content=payload)

        if resp.status_code != 200:
            return ""

        data = resp.json()
        cypher_query = data.get("cypher_query")
        confidence = data.get("confidence", 0)
        is_valid = data.get("validated", False)

        # Only trust high-confidence, generator-validated Cypher.
        if not cypher_query or float(confidence) < 0.9 or not is_valid:
            return ""

        print("neo4j Cypher 初步生成成功 !!!")

        # Second, independent validation pass before touching the database.
        val_payload = json.dumps({"cypher_query": cypher_query})
        val_resp = await cypher_http_client.post("http://0.0.0.0:8101/validate", content=val_payload)

        if val_resp.status_code != 200 or not val_resp.json().get("is_valid"):
            return ""

        def _run_neo4j():
            # Synchronous driver session, executed in a thread so the event
            # loop stays responsive. FIX: str() guards against non-string
            # column values (ints, nodes, ...) which previously made
            # ",".join raise TypeError.
            with driver.session() as session:
                records = session.run(cypher_query)
                return ",".join(str(rec[0]) for rec in records)

        return await asyncio.to_thread(_run_neo4j)

    except Exception as e:
        logger.warning(f"Neo4j 召回失败: {e}")
        return ""
|
|
|
|
| |
| |
| |
|
|
async def perform_rag_and_llm_async(query: str) -> str:
    """Async RAG pipeline: parallel recall -> context merge -> async LLM.

    1. Run the three recall paths concurrently with asyncio.gather; each
       path swallows its own errors and returns "" on failure, so no
       exception can escape the gather.
    2. Merge the non-empty contexts into one block.
    3. Call the model through AsyncOpenAI so the worker can serve other
       requests while the completion is pending.

    Returns the model's answer text.
    """
    # Fan out the three retrievals; total latency ≈ max of the three
    # instead of their sum.
    milvus_ctx, pdf_ctx, neo4j_ctx = await asyncio.gather(
        retrieve_milvus(query),
        retrieve_pdf(query),
        retrieve_neo4j(query),
    )

    # Drop empty contexts so the prompt stays compact.
    context = "\n".join(filter(None, [milvus_ctx, pdf_ctx, neo4j_ctx]))

    SYSTEM_PROMPT = """
System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
"""

    USER_PROMPT = f"""
User: 利用介于<context>和</context>之间的从数据库中检索出的信息来回答问题, 具体的问题介于<question>和</question>之间. 如果提供的信息为空, 则按照你的经验知识来给出尽可能严谨准确的回答, 不知道的时候坦诚的承认不了解, 不要编造不真实的信息.
<context>
{context}
</context>

<question>
{query}
</question>
"""

    # FIX: the system instructions were previously concatenated into a
    # single user message; the Chat Completions API expects them in a
    # dedicated "system" role message, which models weight more reliably.
    response = await async_openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": USER_PROMPT},
        ],
        temperature=0.7,
    )

    return response.choices[0].message.content
|
|
|
|
| |
| |
| |
|
|
async def get_or_compute_async(question: str) -> str:
    """Cache-aside lookup with a distributed lock (async variant).

    1. Check Redis (synchronous client, wrapped in to_thread — fast).
    2. Hit  -> return the cached answer immediately.
    3. Miss -> acquire a per-question lock -> double-check the cache ->
       run the async RAG pipeline -> write the cache -> release the lock.

    Why not use redis_manager.get_or_compute directly?  Its compute_func
    parameter is a synchronous callback, while perform_rag_and_llm_async
    is a coroutine and cannot be passed in, so the cache logic is
    re-implemented here with the same stampede-protection / avalanche /
    double-check semantics.

    NOTE(review): the negative-cache sentinel "<EMPTY>" written below is
    presumably returned verbatim by redis_manager.get_answer on a later
    hit and would then be served to the user as the answer — confirm
    whether get_answer filters it out.
    """
    # Fast path: cache lookup (sync Redis call on a worker thread).
    cached = await asyncio.to_thread(redis_manager.get_answer, question)
    if cached:
        print("REDIS HIT !!!✅😊")
        return cached

    # Cache miss: take a per-question lock keyed by the question's MD5
    # so concurrent identical questions trigger only one computation.
    hash_key = hashlib.md5(question.encode("utf-8")).hexdigest()
    lock_token = await asyncio.to_thread(redis_manager.acquire_lock, hash_key)

    if lock_token:
        try:
            # Double-check: another worker may have filled the cache
            # while we were acquiring the lock.
            cached_retry = await asyncio.to_thread(redis_manager.get_answer, question)
            if cached_retry:
                print("REDIS HIT (Double Check) !!!✅😊")
                return cached_retry

            print("Cache Miss ❌, Computing async RAG + LLM...")

            # Full pipeline: parallel recall + async LLM completion.
            answer = await perform_rag_and_llm_async(question)

            if answer:
                await asyncio.to_thread(redis_manager.set_answer, question, answer)
            else:
                # Negative caching: store a short-lived (60 s) sentinel so
                # an unanswerable question does not hammer the LLM.
                await asyncio.to_thread(
                    redis_manager.client.setex,
                    redis_manager._generate_key(question), 60, "<EMPTY>",
                )

            return answer
        finally:
            # Always release the lock, even if the pipeline raised.
            await asyncio.to_thread(redis_manager.release_lock, hash_key, lock_token)
    else:
        # Another worker holds the lock: wait briefly, then retry the
        # cache once before giving up with a busy message.
        await asyncio.sleep(0.1)
        cached_fallback = await asyncio.to_thread(redis_manager.get_answer, question)
        return cached_fallback or "System busy, calculating..."
|
|
|
|
| |
| |
| |
|
|
@app.post("/")
async def chatbot(request: Request):
    """POST / — answer a medical question through the cached async RAG flow.

    Expects a JSON body with a "question" field; returns the answer, a
    status code, and a timestamp. Errors are reported in-band via the
    "status"/"error" fields rather than HTTP status codes.
    """
    try:
        body = await request.json()

        # Some clients double-encode the payload as a JSON string —
        # decode it a second time in that case.
        payload = json.loads(body) if isinstance(body, str) else body

        query = payload.get("question")
        if not query:
            return {"status": 400, "error": "Question is required"}

        response = await get_or_compute_async(query)

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return {
            "response": response,
            "status": 200,
            "time": timestamp,
        }
    except Exception as e:
        print(f"Server Error: {e}")
        return {"status": 500, "error": str(e)}
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
if __name__ == "__main__":
    # workers=1 on purpose: Milvus Lite is an embedded single-process
    # database, so concurrency comes from asyncio inside one worker, not
    # from multiple processes. (uvicorn.run with an app *object* only
    # supports a single worker anyway.)
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8103,
        workers=1,
    )