| """ |
| ================================================================ |
| agent6.py — 多 Worker 版 Medical RAG Agent |
| ================================================================ |
| 基于 agent5.py, 新增 P2 优化: |
| P0: 三路召回并行化 (asyncio.gather) ← 继承 agent5 |
| P1: AsyncOpenAI 客户端 (async LLM 推理) ← 继承 agent5 |
| P2: Milvus Lite → Milvus Server + workers=4 ← 新增 |
| |
| 架构变化: |
| agent5.py: 单 worker + async (Milvus Lite 文件锁限制) |
| agent6.py: 4 workers × async (Milvus Server 网络连接, 无文件锁) |
| |
| Worker 1 ──→ ┐ |
| Worker 2 ──→ ├── Milvus Server (:19530) ──→ 数据持久化 |
| Worker 3 ──→ ┤ |
| Worker 4 ──→ ┘ |
| |
| 前置条件: |
| 1. 安装并启动 Milvus Server (Docker): |
| docker run -d --name milvus-standalone \ |
| -p 19530:19530 -p 9091:9091 \ |
| milvusdb/milvus:latest |
| |
| 2. 将已有数据从 Milvus Lite 迁移到 Milvus Server: |
| 参考: https://milvus.io/docs/migrate_overview.md |
| |
| 3. .env 中配置 (可选, 有默认值): |
| MILVUS_URI=http://localhost:19530 |
| |
| 运行: |
| python agent6.py |
| # Uvicorn running on http://0.0.0.0:8103 (4 workers) |
| ================================================================ |
| """ |
|
|
| import os |
| import uvicorn |
| import asyncio |
| from fastapi import FastAPI, Request |
| from fastapi.middleware.cors import CORSMiddleware |
| import json |
| import datetime |
| import hashlib |
| import logging |
|
|
| import httpx |
| from openai import AsyncOpenAI |
| from neo4j import GraphDatabase |
| from langchain_milvus import Milvus, BM25BuiltInFunction |
| from vector import OpenAIEmbeddings |
| from langchain_text_splitters import RecursiveCharacterTextSplitter |
| from langchain_core.stores import InMemoryStore |
| from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever |
| from dotenv import load_dotenv |
|
|
| from new_redis import redis_manager |
|
|
# Load .env so MILVUS_URI / NEO4J_* / OPENAI_API_KEY below pick up overrides.
load_dotenv()


# Silence the HuggingFace tokenizers fork-parallelism warning (relevant once
# uvicorn spawns 4 worker processes).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logger = logging.getLogger("agent6")


app = FastAPI()


# Fully open CORS policy: any origin, method, and header.
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True per the CORS spec — confirm credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
| |
| |
| |
|
|
# Dense-embedding model shared by both Milvus collections below.
embedding_model = OpenAIEmbeddings()
print("创建 Embedding 模型成功......")
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# P2: Milvus Server endpoint (network connection instead of a Lite file, so
# all 4 worker processes can share one database without file locks).
MILVUS_URI = os.getenv("MILVUS_URI", "http://localhost:19530")


# Hybrid store over the "medical_agent" collection: a dense IVF_FLAT index
# (inner-product metric) plus a BM25 sparse inverted index for keyword recall.
milvus_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {"metric_type": "IP", "index_type": "IVF_FLAT"},
        {"metric_type": "BM25", "index_type": "SPARSE_INVERTED_INDEX"},
    ],
    connection_args={"uri": MILVUS_URI},
    collection_name="medical_agent",
)
print(f"创建 Milvus 连接成功...... (URI: {MILVUS_URI})")
|
|
# In-memory store for parent documents used by the ParentDocumentRetriever.
# NOTE(review): with workers=4 each process holds its own InMemoryStore, so
# parents indexed in one worker are invisible to the others — confirm whether
# a shared store (e.g. Redis-backed) is required here.
docstore = InMemoryStore()


# Child chunks: small (200 chars) for precise vector matches; the separator
# list ends with "" so over-long runs are force-split character by character.
child_splitter = RecursiveCharacterTextSplitter(
    chunk_size=200, chunk_overlap=50, length_function=len,
    separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""],
)
# Parent chunks: larger (1000 chars) to hand the LLM fuller context.
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)


# Second hybrid (dense + BM25) collection, dedicated to PDF-derived chunks.
pdf_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {"metric_type": "IP", "index_type": "IVF_FLAT"},
        {"metric_type": "BM25", "index_type": "SPARSE_INVERTED_INDEX"},
    ],
    connection_args={"uri": MILVUS_URI},
    collection_name="medical_pdf",
    consistency_level="Bounded",
    drop_old=False,  # keep existing data across restarts
)


# Small-to-big retrieval: search the child chunks, return their parent chunks.
parent_retriever = ParentDocumentRetriever(
    vectorstore=pdf_vectorstore,
    docstore=docstore,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)
print("创建 Parent Milvus 连接成功......")
|
|
# Neo4j driver for the knowledge-graph recall path (path 3).
neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
neo4j_user = os.getenv("NEO4J_USER", "neo4j")
neo4j_password = os.getenv("NEO4J_PASSWORD", "neo4j")
driver = GraphDatabase.driver(
    uri=neo4j_uri, auth=(neo4j_user, neo4j_password), max_connection_lifetime=1000,
)
print("创建 Neo4j 连接成功......")


# P1: async OpenAI client so LLM calls never block the event loop.
async_openai_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
print("创建 AsyncOpenAI LLM 成功......")


# Shared async HTTP client for the text-to-Cypher service on port 8101.
cypher_http_client = httpx.AsyncClient(timeout=30.0)


# redis_manager arrives already connected from the new_redis module.
print("创建 Redis 连接成功......")
|
|
|
|
| |
| |
| |
|
|
def format_docs(docs):
    """Concatenate the page_content of each document, separated by blank lines."""
    chunks = [doc.page_content for doc in docs]
    return "\n\n".join(chunks)
|
|
|
|
async def retrieve_milvus(query: str) -> str:
    """Path 1: hybrid (dense + BM25) recall from the main Milvus collection.

    Runs the synchronous similarity search in a worker thread; RRF fuses the
    dense and sparse rankings. Returns "" on miss or on any failure.
    """
    try:
        hits = await asyncio.to_thread(
            milvus_vectorstore.similarity_search,
            query,
            k=10,
            ranker_type="rrf",
            ranker_params={"k": 100},
        )
        if not hits:
            return ""
        return format_docs(hits)
    except Exception as e:
        logger.warning(f"Milvus 召回失败: {e}")
        return ""
|
|
|
|
async def retrieve_pdf(query: str) -> str:
    """Path 2: parent-document retrieval over the PDF collection.

    Runs the synchronous retriever in a worker thread and returns the
    top-ranked parent chunk's text, or "" on miss or failure.
    """
    try:
        docs = await asyncio.to_thread(parent_retriever.invoke, query)
        # Idiom fix: `docs and len(docs) >= 1` was redundant — a non-empty
        # list is already truthy.
        if docs:
            return docs[0].page_content
        return ""
    except Exception as e:
        logger.warning(f"PDF 检索失败: {e}")
        return ""
|
|
|
|
async def retrieve_neo4j(query: str) -> str:
    """Path 3: knowledge-graph recall via the text-to-Cypher service + Neo4j.

    Steps: (1) POST the natural-language query to :8101/generate; (2) accept
    the Cypher only if confidence >= 0.9 and it is marked validated;
    (3) re-validate via :8101/validate; (4) execute it against Neo4j in a
    worker thread. Returns "" at any rejection point or on failure.
    """
    try:
        payload = json.dumps({"natural_language_query": query})
        resp = await cypher_http_client.post("http://0.0.0.0:8101/generate", content=payload)

        if resp.status_code != 200:
            return ""

        data = resp.json()
        cypher_query = data.get("cypher_query")
        confidence = data.get("confidence", 0)
        is_valid = data.get("validated", False)

        # High bar: require both a confident generation and a validated flag.
        if not cypher_query or float(confidence) < 0.9 or not is_valid:
            return ""

        print("neo4j Cypher 初步生成成功 !!!")

        # Second, explicit validation round-trip before touching the database.
        val_payload = json.dumps({"cypher_query": cypher_query})
        val_resp = await cypher_http_client.post("http://0.0.0.0:8101/validate", content=val_payload)

        if val_resp.status_code != 200 or not val_resp.json().get("is_valid"):
            return ""

        def _run_neo4j() -> str:
            # Synchronous driver call, executed via asyncio.to_thread below.
            with driver.session() as session:
                records = session.run(cypher_query)
                # Bug fix: Cypher result values are not guaranteed to be
                # strings; joining non-str values raised TypeError, which the
                # outer except silently turned into "". Coerce with str().
                return ",".join(str(record[0]) for record in records)

        return await asyncio.to_thread(_run_neo4j)

    except Exception as e:
        logger.warning(f"Neo4j 召回失败: {e}")
        return ""
|
|
|
|
| |
| |
| |
|
|
async def perform_rag_and_llm_async(query: str) -> str:
    """Async RAG pipeline: fan out the three retrievals, then query the LLM.

    The three recall paths (Milvus hybrid search, PDF parent-document
    retrieval, Neo4j graph query) run concurrently via asyncio.gather; paths
    that fail contribute "" and are filtered out of the prompt context.
    """

    milvus_ctx, pdf_ctx, neo4j_ctx = await asyncio.gather(
        retrieve_milvus(query),
        retrieve_pdf(query),
        retrieve_neo4j(query),
    )

    # Drop empty recall results before assembling the context block.
    context = "\n".join(filter(None, [milvus_ctx, pdf_ctx, neo4j_ctx]))

    SYSTEM_PROMPT = """
    System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
    """

    USER_PROMPT = f"""
    User: 利用介于<context>和</context>之间的从数据库中检索出的信息来回答问题, 具体的问题介于<question>和</question>之间. 如果提供的信息为空, 则按照你的经验知识来给出尽可能严谨准确的回答, 不知道的时候坦诚的承认不了解, 不要编造不真实的信息.
    <context>
    {context}
    </context>

    <question>
    {query}
    </question>
    """

    # Fix: the system instructions were previously concatenated into a single
    # "user" message; send them through the dedicated "system" role so the
    # chat API weighs them as instructions rather than user text.
    response = await async_openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": USER_PROMPT},
        ],
        temperature=0.7,
    )

    return response.choices[0].message.content
|
|
|
|
| |
| |
| |
|
|
async def get_or_compute_async(question: str) -> str:
    """Async cache-or-compute with a distributed lock (anti-stampede, double-checked).

    Flow: Redis lookup → acquire a per-question lock → re-check the cache →
    run the RAG+LLM pipeline → cache the answer. Every synchronous Redis call
    is pushed to a worker thread via asyncio.to_thread.
    """

    # Fast path: serve a cached answer without taking the lock.
    # NOTE(review): a cached "<EMPTY>" sentinel (written below) is truthy and
    # is returned verbatim to the caller here — confirm this is intended.
    cached = await asyncio.to_thread(redis_manager.get_answer, question)
    if cached:
        print("REDIS HIT !!!✅😊")
        return cached

    # Lock key is the MD5 of the question, so every worker process contends
    # on the same key for the same question.
    hash_key = hashlib.md5(question.encode("utf-8")).hexdigest()
    lock_token = await asyncio.to_thread(redis_manager.acquire_lock, hash_key)

    if lock_token:
        try:
            # Double-check: another worker may have filled the cache while we
            # were acquiring the lock.
            cached_retry = await asyncio.to_thread(redis_manager.get_answer, question)
            if cached_retry:
                print("REDIS HIT (Double Check) !!!✅😊")
                return cached_retry

            print("Cache Miss ❌, Computing async RAG + LLM...")
            answer = await perform_rag_and_llm_async(question)

            if answer:
                await asyncio.to_thread(redis_manager.set_answer, question, answer)
            else:
                # Cache a short-lived (60 s) placeholder for empty answers to
                # guard against cache penetration by repeated misses.
                await asyncio.to_thread(
                    redis_manager.client.setex,
                    redis_manager._generate_key(question), 60, "<EMPTY>",
                )

            return answer
        finally:
            # Always release the lock, even if the pipeline raised.
            await asyncio.to_thread(redis_manager.release_lock, hash_key, lock_token)
    else:
        # Lost the lock race: wait briefly, poll the cache once, then give up
        # with a "busy" message rather than recomputing.
        await asyncio.sleep(0.1)
        cached_fallback = await asyncio.to_thread(redis_manager.get_answer, question)
        return cached_fallback or "System busy, calculating..."
|
|
|
|
| |
| |
| |
|
|
@app.post("/")
async def chatbot(request: Request):
    """POST / — answer a medical question through the cached RAG pipeline.

    Accepts either a JSON object or a JSON-encoded string body containing a
    "question" field; always replies HTTP 200 with an app-level status code
    in the payload.
    """
    try:
        raw_body = await request.json()

        # Some clients double-encode the body as a JSON string; unwrap it.
        payload = json.loads(raw_body) if isinstance(raw_body, str) else raw_body

        query = payload.get("question")
        if not query:
            return {"status": 400, "error": "Question is required"}

        answer = await get_or_compute_async(query)

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return {
            "response": answer,
            "status": 200,
            "time": timestamp,
        }
    except Exception as e:
        print(f"Server Error: {e}")
        return {"status": 500, "error": str(e)}
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
if __name__ == "__main__":
    # P2: workers=4 requires passing the app as an import string
    # ("module:attr") so uvicorn can re-import it in each spawned process;
    # every worker then builds its own Milvus/Neo4j/Redis connections.
    uvicorn.run(
        "agent6:app",
        host="0.0.0.0",
        port=8103,
        workers=4,
    )