# agentV2 / agent5.py
# (Hugging Face upload-page residue preserved as comments so the file parses:
#  uploaded by drewli20200316 via huggingface_hub, commit 90b3130 verified)
"""
================================================================
agent5.py — 异步并行优化版 Medical RAG Agent
================================================================
基于 agent4.py 的两项优化:
P0: 三路召回并行化 (asyncio.gather)
P1: AsyncOpenAI 客户端 (async LLM 推理)
P2: workers=1 (Milvus Lite 不支持多进程, 靠 async 获得并发)
架构对比:
agent4.py (串行, 同步):
Milvus [3s] → PDF [3s] → Neo4j [4s] → LLM [12s] = 22s
且等 LLM 时 worker 完全阻塞, 无法服务其他请求
agent5.py (并行, 异步):
┌─ Milvus [3s] ─┐
├─ PDF [3s] ──┼→ 合并 → async LLM [12s] = 16s
└─ Neo4j [4s] ─┘
且等 LLM 时 worker 可处理其他请求 (Cache Hit / 新的召回)
预期效果:
Cache Hit: ~8ms (不变, Redis 直接返回)
Cache Miss: ~16s (从 ~22s 降到 ~16s, 省 27%)
并发能力: 单 worker async ≈ 等效 5-10 个同步 worker
运行:
python agent5.py
# Uvicorn running on http://0.0.0.0:8103 (1 worker; concurrency via async)
================================================================
"""
import os
import uvicorn
import asyncio
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import json
import datetime
import hashlib
import logging
import httpx
from openai import AsyncOpenAI
from neo4j import GraphDatabase
from langchain_milvus import Milvus, BM25BuiltInFunction
from vector import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.stores import InMemoryStore
from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever
from dotenv import load_dotenv
from new_redis import redis_manager
load_dotenv()
# Prevent HuggingFace tokenizers from forking worker threads inside the server
# process (avoids the "tokenizers parallelism" deadlock warning).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logger = logging.getLogger("agent5")

app = FastAPI()
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True per the CORS spec — confirm whether credentialed
# cross-origin requests are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ============================================================
# Global resource initialisation (one copy per worker process)
# ============================================================
embedding_model = OpenAIEmbeddings()
print("创建 Embedding 模型成功......")

# Milvus Lite local database files — single-process only; see startup notes
# at the bottom of the file for why workers must stay at 1.
URI = "./milvus_agent.db"
URI1 = "./pdf_agent.db"

# Recall path 1 store: hybrid search over dense vectors (IP / IVF_FLAT)
# plus BM25 sparse vectors.
milvus_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {"metric_type": "IP", "index_type": "IVF_FLAT"},
        {"metric_type": "BM25", "index_type": "SPARSE_INVERTED_INDEX"},
    ],
    connection_args={"uri": URI},
)
print("创建 Milvus 连接成功......")

# Recall path 2 stores: small child chunks are embedded for precise recall,
# larger parent chunks are what gets returned as context.
docstore = InMemoryStore()
child_splitter = RecursiveCharacterTextSplitter(
    chunk_size=200, chunk_overlap=50, length_function=len,
    separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""],
)
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
pdf_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {"metric_type": "IP", "index_type": "IVF_FLAT"},
        {"metric_type": "BM25", "index_type": "SPARSE_INVERTED_INDEX"},
    ],
    connection_args={"uri": URI1},
    consistency_level="Bounded",
    drop_old=False,
)
parent_retriever = ParentDocumentRetriever(
    vectorstore=pdf_vectorstore,
    docstore=docstore,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)
print("创建 Parent Milvus 连接成功......")

# Recall path 3: Neo4j graph database (synchronous driver, wrapped in
# asyncio.to_thread at call sites).
neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
neo4j_user = os.getenv("NEO4J_USER", "neo4j")
neo4j_password = os.getenv("NEO4J_PASSWORD", "neo4j")
driver = GraphDatabase.driver(
    uri=neo4j_uri, auth=(neo4j_user, neo4j_password), max_connection_lifetime=1000,
)
print("创建 Neo4j 连接成功......")

# ============================================================
# P1: AsyncOpenAI client (replaces the synchronous OpenAI client)
# ============================================================
async_openai_client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
print("创建 AsyncOpenAI LLM 成功......")

# Async HTTP client for the Cypher generation/validation API
# (replaces synchronous requests calls).
cypher_http_client = httpx.AsyncClient(timeout=30.0)
# NOTE(review): this print claims a Redis connection, but no Redis setup
# happens on this line — redis_manager arrives ready-made from new_redis.
# The message is misleading; left unchanged because it is runtime output.
print("创建 Redis 连接成功......")
# ============================================================
# P0: 三路召回 — 各自独立的 async 函数
# ============================================================
def format_docs(docs):
    """Concatenate the page_content of each retrieved document.

    Chunks are separated by a blank line so the LLM sees clear
    boundaries in the assembled context.
    """
    chunks = [doc.page_content for doc in docs]
    return "\n\n".join(chunks)
async def retrieve_milvus(query: str) -> str:
    """Recall path 1: Milvus hybrid (dense + BM25) similarity search.

    The Milvus SDK is synchronous, so the call runs on a worker thread via
    asyncio.to_thread to keep the event loop free. Returns the formatted
    context string, or "" when nothing is found or the search fails.
    """
    try:
        docs = await asyncio.to_thread(
            milvus_vectorstore.similarity_search,
            query,
            k=10,
            ranker_type="rrf",
            ranker_params={"k": 100},
        )
        if docs:
            return format_docs(docs)
        return ""
    except Exception as e:
        logger.warning(f"Milvus 召回失败: {e}")
        return ""
async def retrieve_pdf(query: str) -> str:
    """Recall path 2: parent/child PDF document retrieval.

    The synchronous ParentDocumentRetriever runs on a worker thread
    (asyncio.to_thread) so the event loop stays responsive. Only the
    top-ranked parent chunk is used as context.

    Args:
        query: natural-language question.
    Returns:
        page_content of the best-matching parent document, or "" on
        miss or failure (errors are logged, never raised).
    """
    try:
        docs = await asyncio.to_thread(parent_retriever.invoke, query)
        # `docs and len(docs) >= 1` collapsed to plain truthiness.
        if docs:
            return docs[0].page_content
        return ""
    except Exception as e:
        # Lazy %-formatting: message is only built if the record is emitted.
        logger.warning("PDF 检索失败: %s", e)
        return ""
async def retrieve_neo4j(query: str) -> str:
    """
    Recall path 3: Neo4j graph lookup (async httpx + sync session via to_thread).

    Pipeline: natural-language question -> Cypher generation service on
    port 8101 -> independent /validate check -> execution against Neo4j.
    Returns the first column of every result row joined with commas, or ""
    whenever any stage fails or the generated Cypher is rejected.
    """
    try:
        # Step 1: ask the generation API to translate the question into Cypher.
        payload = json.dumps({"natural_language_query": query})
        resp = await cypher_http_client.post("http://0.0.0.0:8101/generate", content=payload)
        if resp.status_code != 200:
            return ""
        data = resp.json()
        cypher_query = data.get("cypher_query")
        confidence = data.get("confidence", 0)
        is_valid = data.get("validated", False)
        # Only trust high-confidence (>= 0.9), pre-validated generations.
        if not cypher_query or float(confidence) < 0.9 or not is_valid:
            return ""
        print("neo4j Cypher 初步生成成功 !!!")
        # Step 2: re-validate the Cypher with a second API call before executing it.
        val_payload = json.dumps({"cypher_query": cypher_query})
        val_resp = await cypher_http_client.post("http://0.0.0.0:8101/validate", content=val_payload)
        if val_resp.status_code != 200 or not val_resp.json().get("is_valid"):
            return ""

        # Step 3: run the query with the synchronous Neo4j driver on a thread.
        def _run_neo4j():
            # Takes the first column of each record; assumes those values are
            # strings — TODO confirm, join() would raise on non-str results.
            with driver.session() as session:
                record = session.run(cypher_query)
                result = list(map(lambda x: x[0], record))
                return ",".join(result)

        return await asyncio.to_thread(_run_neo4j)
    except Exception as e:
        logger.warning(f"Neo4j 召回失败: {e}")
        return ""
# ============================================================
# P0 + P1: 异步并行 RAG + 异步 LLM 推理
# ============================================================
async def perform_rag_and_llm_async(query: str) -> str:
    """
    Async RAG pipeline:
      1. run the three recall paths concurrently with asyncio.gather
      2. merge the non-empty contexts into one block
      3. answer with a non-blocking AsyncOpenAI chat completion
    """
    # -- P0: fan the three recall paths out in parallel --
    contexts = await asyncio.gather(
        retrieve_milvus(query),
        retrieve_pdf(query),
        retrieve_neo4j(query),
    )
    context = "\n".join(ctx for ctx in contexts if ctx)

    # -- Assemble the prompt --
    SYSTEM_PROMPT = """
System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
"""
    USER_PROMPT = f"""
User: 利用介于<context>和</context>之间的从数据库中检索出的信息来回答问题, 具体的问题介于<question>和</question>之间. 如果提供的信息为空, 则按照你的经验知识来给出尽可能严谨准确的回答, 不知道的时候坦诚的承认不了解, 不要编造不真实的信息.
<context>
{context}
</context>
<question>
{query}
</question>
"""

    # -- P1: non-blocking LLM inference (event loop stays free while waiting) --
    messages = [{"role": "user", "content": SYSTEM_PROMPT + USER_PROMPT}]
    response = await async_openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        temperature=0.7,
    )
    return response.choices[0].message.content
# ============================================================
# Redis 缓存 + 异步 RAG 的衔接
# ============================================================
async def get_or_compute_async(question: str) -> str:
    """
    Async version of get_or_compute:
      1. check the Redis cache (synchronous but fast; wrapped in to_thread)
      2. hit  -> return immediately
      3. miss -> acquire lock -> async RAG -> write cache -> release lock

    Why not call redis_manager.get_or_compute directly?
    Its compute_func parameter is a synchronous callback, while our
    perform_rag_and_llm_async is a coroutine, so it cannot be passed in.
    The cache logic is therefore re-implemented here, keeping the full
    anti-stampede / anti-avalanche / double-check semantics.
    """
    # 1. Cache lookup (Redis op is <5ms; to_thread keeps the loop unblocked).
    cached = await asyncio.to_thread(redis_manager.get_answer, question)
    if cached:
        print("REDIS HIT !!!✅😊")
        return cached
    # 2. Acquire a per-question lock (prevents cache stampede on misses).
    hash_key = hashlib.md5(question.encode("utf-8")).hexdigest()
    lock_token = await asyncio.to_thread(redis_manager.acquire_lock, hash_key)
    if lock_token:
        try:
            # Double-check after acquiring the lock: another task may have
            # computed and cached the answer while we were waiting.
            cached_retry = await asyncio.to_thread(redis_manager.get_answer, question)
            if cached_retry:
                print("REDIS HIT (Double Check) !!!✅😊")
                return cached_retry
            print("Cache Miss ❌, Computing async RAG + LLM...")
            # 3. Run the async RAG + LLM pipeline (the core optimisation).
            answer = await perform_rag_and_llm_async(question)
            # 4. Write-back: real answers are cached normally; empty answers
            #    get a short-lived "<EMPTY>" sentinel (60s) so repeated misses
            #    for the same question don't hammer the pipeline.
            if answer:
                await asyncio.to_thread(redis_manager.set_answer, question, answer)
            else:
                await asyncio.to_thread(
                    redis_manager.client.setex,
                    redis_manager._generate_key(question), 60, "<EMPTY>",
                )
            return answer
        finally:
            # 5. Always release the lock, even if the pipeline raised.
            await asyncio.to_thread(redis_manager.release_lock, hash_key, lock_token)
    else:
        # Lock held by another task: brief wait, then retry the cache once.
        await asyncio.sleep(0.1)
        cached_fallback = await asyncio.to_thread(redis_manager.get_answer, question)
        return cached_fallback or "System busy, calculating..."
# ============================================================
# FastAPI 路由
# ============================================================
@app.post("/")
async def chatbot(request: Request):
    """Single chat endpoint: {"question": ...} -> cached or computed answer.

    Accepts either a JSON object or a double-encoded JSON string body.
    Errors are reported in-band through the "status" field rather than
    via HTTP status codes.
    """
    try:
        body = await request.json()
        # Some clients double-encode the payload as a JSON string; unwrap it.
        payload = json.loads(body) if isinstance(body, str) else body
        query = payload.get("question")
        if not query:
            return {"status": 400, "error": "Question is required"}
        # Redis-cached async RAG + LLM pipeline.
        answer = await get_or_compute_async(query)
        return {
            "response": answer,
            "status": 200,
            "time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }
    except Exception as e:
        print(f"Server Error: {e}")
        return {"status": 500, "error": str(e)}
# ============================================================
# Startup notes
# ============================================================
# Milvus Lite stores data in local .db files and cannot be opened by several
# processes at once, so workers must stay at 1. Concurrency comes from the
# P0 + P1 async design instead:
#   - P0: asyncio.gather runs the three recall paths in parallel
#   - P1: AsyncOpenAI releases the event loop while waiting on I/O
#   - one async worker ~ the throughput of 5-10 synchronous workers
#
# If this ever migrates to a full Milvus Server (not Lite), workers=4 is fine.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8103,
        workers=1,
    )