"""
================================================================
agent6.py — 多 Worker 版 Medical RAG Agent
================================================================
基于 agent5.py, 新增 P2 优化:
P0: 三路召回并行化 (asyncio.gather) ← 继承 agent5
P1: AsyncOpenAI 客户端 (async LLM 推理) ← 继承 agent5
P2: Milvus Lite → Milvus Server + workers=4 ← 新增
架构变化:
agent5.py: 单 worker + async (Milvus Lite 文件锁限制)
agent6.py: 4 workers × async (Milvus Server 网络连接, 无文件锁)
Worker 1 ──→ ┐
Worker 2 ──→ ├── Milvus Server (:19530) ──→ 数据持久化
Worker 3 ──→ ┤
Worker 4 ──→ ┘
前置条件:
1. 安装并启动 Milvus Server (Docker):
docker run -d --name milvus-standalone \
-p 19530:19530 -p 9091:9091 \
milvusdb/milvus:latest
2. 将已有数据从 Milvus Lite 迁移到 Milvus Server:
参考: https://milvus.io/docs/migrate_overview.md
3. .env 中配置 (可选, 有默认值):
MILVUS_URI=http://localhost:19530
运行:
python agent6.py
# Uvicorn running on http://0.0.0.0:8103 (4 workers)
================================================================
"""
import os
import uvicorn
import asyncio
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import json
import datetime
import hashlib
import logging
import httpx
from openai import AsyncOpenAI
from neo4j import GraphDatabase
from langchain_milvus import Milvus, BM25BuiltInFunction
from vector import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.stores import InMemoryStore
from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever
from dotenv import load_dotenv
from new_redis import redis_manager
load_dotenv()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logger = logging.getLogger("agent6")
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ============================================================
# Global resource initialization (each worker process builds its own copy)
# ============================================================
embedding_model = OpenAIEmbeddings()
print("创建 Embedding 模型成功......")
# ============================================================
# P2: Milvus Lite → Milvus Server
# ============================================================
# agent4/agent5: URI = "./milvus_agent.db" (local file, exclusive to one process)
# agent6: URI = "http://localhost:19530" (network connection, shared across processes)
#
# Milvus Server runs as a standalone process and serves gRPC on port 19530.
# Each of the 4 workers opens its own network connection; no more file-lock contention.
MILVUS_URI = os.getenv("MILVUS_URI", "http://localhost:19530")
milvus_vectorstore = Milvus(
embedding_function=embedding_model,
builtin_function=BM25BuiltInFunction(),
vector_field=["dense", "sparse"],
index_params=[
{"metric_type": "IP", "index_type": "IVF_FLAT"},
{"metric_type": "BM25", "index_type": "SPARSE_INVERTED_INDEX"},
],
connection_args={"uri": MILVUS_URI},
collection_name="medical_agent", # 显式指定 collection 名称
)
print(f"创建 Milvus 连接成功...... (URI: {MILVUS_URI})")
docstore = InMemoryStore()
child_splitter = RecursiveCharacterTextSplitter(
chunk_size=200, chunk_overlap=50, length_function=len,
separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""],
)
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
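# Quick illustration of the two-level split (a sketch, not called at
# startup): ~200-char child chunks are embedded and indexed in Milvus,
# while the ~1000-char parent chunks are what get handed back to the LLM.
def _split_demo(text: str) -> None:
    parents = parent_splitter.split_text(text)
    children = child_splitter.split_text(text)
    print(f"{len(parents)} parent chunks / {len(children)} child chunks")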
pdf_vectorstore = Milvus(
embedding_function=embedding_model,
builtin_function=BM25BuiltInFunction(),
vector_field=["dense", "sparse"],
index_params=[
{"metric_type": "IP", "index_type": "IVF_FLAT"},
{"metric_type": "BM25", "index_type": "SPARSE_INVERTED_INDEX"},
],
connection_args={"uri": MILVUS_URI},
collection_name="medical_pdf", # 显式指定 collection 名称
consistency_level="Bounded",
drop_old=False,
)
parent_retriever = ParentDocumentRetriever(
vectorstore=pdf_vectorstore,
docstore=docstore,
child_splitter=child_splitter,
parent_splitter=parent_splitter,
)
print("创建 Parent Milvus 连接成功......")
neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
neo4j_user = os.getenv("NEO4J_USER", "neo4j")
neo4j_password = os.getenv("NEO4J_PASSWORD", "neo4j")
driver = GraphDatabase.driver(
uri=neo4j_uri, auth=(neo4j_user, neo4j_password), max_connection_lifetime=1000,
)
print("创建 Neo4j 连接成功......")
# P1: AsyncOpenAI 客户端
async_openai_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
print("创建 AsyncOpenAI LLM 成功......")
# Cypher API 用 httpx.AsyncClient
cypher_http_client = httpx.AsyncClient(timeout=30.0)
print("创建 Redis 连接成功......")
# ============================================================
# P0: three-way retrieval — each path is its own async function
# ============================================================
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
async def retrieve_milvus(query: str) -> str:
    """Path 1: Milvus vector retrieval"""
    try:
        # similarity_search is blocking, so run it in a worker thread to keep
        # the event loop free for the other two retrieval paths.
        results = await asyncio.to_thread(
            milvus_vectorstore.similarity_search,
            query, k=10, ranker_type="rrf", ranker_params={"k": 100},
        )
        return format_docs(results) if results else ""
    except Exception as e:
        logger.warning(f"Milvus retrieval failed: {e}")
        return ""
async def retrieve_pdf(query: str) -> str:
    """Path 2: PDF parent-child document retrieval"""
    try:
        docs = await asyncio.to_thread(parent_retriever.invoke, query)
        if docs:
            return docs[0].page_content
        return ""
    except Exception as e:
        logger.warning(f"PDF retrieval failed: {e}")
        return ""
async def retrieve_neo4j(query: str) -> str:
    """Path 3: Neo4j graph database retrieval"""
    try:
        payload = json.dumps({"natural_language_query": query})
        # 0.0.0.0 is a bind address, not a portable client target; use loopback.
        resp = await cypher_http_client.post("http://127.0.0.1:8101/generate", content=payload)
        if resp.status_code != 200:
            return ""
        data = resp.json()
        cypher_query = data.get("cypher_query")
        confidence = data.get("confidence", 0)
        is_valid = data.get("validated", False)
        if not cypher_query or float(confidence) < 0.9 or not is_valid:
            return ""
        print("Neo4j Cypher draft generated successfully !!!")
        val_payload = json.dumps({"cypher_query": cypher_query})
        val_resp = await cypher_http_client.post("http://127.0.0.1:8101/validate", content=val_payload)
        if val_resp.status_code != 200 or not val_resp.json().get("is_valid"):
            return ""
        def _run_neo4j():
            with driver.session() as session:
                records = session.run(cypher_query)
                # str() guards the join against non-string values in the result
                return ",".join(str(record[0]) for record in records)
        return await asyncio.to_thread(_run_neo4j)
    except Exception as e:
        logger.warning(f"Neo4j retrieval failed: {e}")
        return ""
# ============================================================
# P0 + P1: parallel async retrieval + async LLM inference
# ============================================================
async def perform_rag_and_llm_async(query: str) -> str:
    """Async RAG pipeline"""
    milvus_ctx, pdf_ctx, neo4j_ctx = await asyncio.gather(
        retrieve_milvus(query),
        retrieve_pdf(query),
        retrieve_neo4j(query),
    )
    context = "\n".join(filter(None, [milvus_ctx, pdf_ctx, neo4j_ctx]))
    SYSTEM_PROMPT = """
You are a highly capable medical assistant. You answer questions using information retrieved from the databases.
"""
    USER_PROMPT = f"""
Answer the question between <question> and </question> using the retrieved database context between <context> and </context>. If the context is empty, answer as rigorously and accurately as you can from your own knowledge; when you do not know, admit it honestly rather than fabricating information.
<context>
{context}
</context>
<question>
{query}
</question>
"""
    response = await async_openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            # Proper role separation instead of concatenating "System:"/"User:"
            # prefixes into a single user message.
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": USER_PROMPT},
        ],
        temperature=0.7,
    )
    return response.choices[0].message.content
# ============================================================
# Bridging the Redis cache with the async RAG pipeline
# ============================================================
async def get_or_compute_async(question: str) -> str:
"""异步版 get_or_compute (防击穿/防雪崩/双重检查)"""
cached = await asyncio.to_thread(redis_manager.get_answer, question)
if cached:
print("REDIS HIT !!!✅😊")
return cached
hash_key = hashlib.md5(question.encode("utf-8")).hexdigest()
lock_token = await asyncio.to_thread(redis_manager.acquire_lock, hash_key)
if lock_token:
try:
cached_retry = await asyncio.to_thread(redis_manager.get_answer, question)
if cached_retry:
print("REDIS HIT (Double Check) !!!✅😊")
return cached_retry
print("Cache Miss ❌, Computing async RAG + LLM...")
answer = await perform_rag_and_llm_async(question)
if answer:
await asyncio.to_thread(redis_manager.set_answer, question, answer)
else:
await asyncio.to_thread(
redis_manager.client.setex,
redis_manager._generate_key(question), 60, "<EMPTY>",
)
return answer
finally:
await asyncio.to_thread(redis_manager.release_lock, hash_key, lock_token)
else:
await asyncio.sleep(0.1)
cached_fallback = await asyncio.to_thread(redis_manager.get_answer, question)
        return cached_fallback or "System busy, still computing; please retry..."
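# The lock primitives used above come from the (unshown) new_redis module.
# Below is a minimal sketch of the semantics this code assumes, written with
# standard redis-py calls (an illustration of the pattern, not the actual
# new_redis implementation):
_RELEASE_LOCK_LUA = """
if redis.call('get', KEYS[1]) == ARGV[1] then
    return redis.call('del', KEYS[1])
else
    return 0
end
"""
def _acquire_lock_sketch(client, hash_key: str, ttl: int = 30):
    import uuid
    token = uuid.uuid4().hex
    # SET NX EX: exactly one worker wins; the TTL frees the lock after a crash.
    return token if client.set(f"lock:{hash_key}", token, nx=True, ex=ttl) else None
def _release_lock_sketch(client, hash_key: str, token: str) -> None:
    # Atomic compare-and-delete so a worker never deletes another worker's lock.
    client.eval(_RELEASE_LOCK_LUA, 1, f"lock:{hash_key}", token)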
# ============================================================
# FastAPI routes
# ============================================================
@app.post("/")
async def chatbot(request: Request):
try:
json_post_raw = await request.json()
if isinstance(json_post_raw, str):
json_post_list = json.loads(json_post_raw)
else:
json_post_list = json_post_raw
query = json_post_list.get("question")
if not query:
return {"status": 400, "error": "Question is required"}
response = await get_or_compute_async(query)
now = datetime.datetime.now()
time_str = now.strftime("%Y-%m-%d %H:%M:%S")
return {
"response": response,
"status": 200,
"time": time_str,
}
except Exception as e:
print(f"Server Error: {e}")
return {"status": 500, "error": str(e)}
# ============================================================
# P2: multi-worker startup
# ============================================================
# Milvus Server is reached over the network, so the file-lock limit is gone:
# the 4 worker processes each hold independent connections without interfering.
#
# Each worker is still fully async internally (P0 + P1), so total capacity is
# roughly 4 workers × ~5 concurrent requests per worker ≈ 20 concurrent users
# (a hedged smoke test for this estimate follows below).
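# A concurrency smoke test for the estimate above (a sketch added here, not
# part of the original file; run it from a separate process while the server
# is up, e.g. asyncio.run(_concurrency_smoke_test())):
async def _concurrency_smoke_test(n: int = 20, url: str = "http://localhost:8103/") -> None:
    async with httpx.AsyncClient(timeout=120.0) as client:
        tasks = [client.post(url, json={"question": f"test question {i}"}) for i in range(n)]
        responses = await asyncio.gather(*tasks)
        ok = sum(1 for r in responses if r.status_code == 200)
        print(f"{ok}/{n} requests succeeded")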
if __name__ == "__main__":
uvicorn.run(
"agent6:app", # 字符串形式, 多 worker 必须这样写
host="0.0.0.0",
port=8103,
        workers=4,  # P2: 4 worker processes
)