# agentV2 / agent4.py
# (uploaded by drewli20200316 — commit 21497ab, "Add agent4.py")
import os
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import json
import requests
import datetime
from openai import OpenAI
from neo4j import GraphDatabase
from langchain_milvus import Milvus, BM25BuiltInFunction
from vector import OpenAIEmbeddings  # keep only the Embedding helper; all Redis bits were removed
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.stores import InMemoryStore
from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever
from dotenv import load_dotenv
# Import the new Redis helper class (replaces the old get_redis_client, cache_set, cache_get)
from new_redis import redis_manager
# Load environment variables from .env so API keys stay out of the source.
load_dotenv()
# Disable tokenizer parallelism (silences fork-related warnings from the
# HuggingFace `tokenizers` library).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
app = FastAPI()
# ============================================================
# OpenAI LLM 客户端封装
# ============================================================
def create_openai_client():
    """Build and return an OpenAI API client.

    The API key is read from the ``OPENAI_API_KEY`` environment variable
    (populated by ``load_dotenv`` at import time).
    """
    return OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
def generate_openai_answer(client, prompt):
    """Send a single-turn user prompt to gpt-4o-mini and return the reply text.

    Args:
        client: An OpenAI-compatible client exposing ``chat.completions.create``.
        prompt: The full prompt string sent as one user message.

    Returns:
        The content string of the first completion choice.
    """
    chat_messages = [{"role": "user", "content": prompt}]
    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0.7,
        messages=chat_messages,
    )
    return completion.choices[0].message.content
# Allow cross-origin requests from any domain.
# NOTE(review): `allow_origins=["*"]` combined with `allow_credentials=True`
# is rejected by browsers under the CORS spec — confirm whether credentialed
# requests are actually needed, or pin explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Create the embedding model (OpenAI-backed, from the local `vector` module).
embedding_model = OpenAIEmbeddings()
print("创建 Embedding 模型成功......")
# Local Milvus Lite database files (created on first use).
URI = "./milvus_agent.db"    # main knowledge collection
URI1 = "./pdf_agent.db"      # PDF parent/child collection
# Hybrid Milvus store: a dense vector field (inner-product / IVF_FLAT) plus a
# sparse BM25 full-text field, enabling dense + keyword hybrid retrieval.
milvus_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
        },
        {
            "metric_type": "BM25",
            "index_type": "SPARSE_INVERTED_INDEX"
        }
    ],
    connection_args={"uri": URI},
)
retriever = milvus_vectorstore.as_retriever()
print("创建 Milvus 连接成功......")
# In-memory store mapping parent-document ids to their full text; backs the
# ParentDocumentRetriever below. NOTE(review): contents are lost on restart,
# so PDFs must be re-ingested each run — confirm this is intentional.
docstore = InMemoryStore()
# Child splitter: small 200-char chunks used for embedding and recall, split
# preferentially at paragraph/sentence boundaries (incl. CJK punctuation).
child_splitter = RecursiveCharacterTextSplitter(
    chunk_size=200,
    chunk_overlap=50,
    length_function=len,
    separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""]
)
# Parent splitter: larger 1000-char chunks returned as answer context.
parent_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200
)
# Second hybrid Milvus store dedicated to PDF content (separate DB file),
# with the same dense + BM25 sparse configuration as the main store.
pdf_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
        },
        {
            "metric_type": "BM25",
            "index_type": "SPARSE_INVERTED_INDEX"
        }
    ],
    connection_args={"uri": URI1},
    consistency_level="Bounded",
    drop_old=False,  # keep previously indexed PDF vectors across restarts
)
# Parent/child document retriever: searches over the small child chunks but
# returns the larger parent chunks for richer context.
parent_retriever = ParentDocumentRetriever(
    vectorstore=pdf_vectorstore,
    docstore=docstore,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)
print("创建 Parent Milvus 连接成功......")
# Neo4j graph-database connection (defaults target a local instance; override
# via NEO4J_URI / NEO4J_USER / NEO4J_PASSWORD environment variables).
neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
neo4j_user = os.getenv("NEO4J_USER", "neo4j")
neo4j_password = os.getenv("NEO4J_PASSWORD", "neo4j")
driver = GraphDatabase.driver(uri=neo4j_uri, auth=(neo4j_user, neo4j_password), max_connection_lifetime=1000)
print("创建 Neo4j 连接成功......")
# Create the large-language-model client (OpenAI).
client_llm = create_openai_client()
print("创建 OpenAI LLM 成功......")
# NOTE: the Redis connection is already initialized as a singleton when
# `new_redis` is imported above; nothing further to do here.
print("创建 Redis 连接成功......")
def format_docs(docs):
    """Concatenate the text of retrieved documents into one context string.

    Each document's ``page_content`` is included, separated by blank lines.
    """
    pieces = []
    for document in docs:
        pieces.append(document.page_content)
    return "\n\n".join(pieces)
# ============================================================
# 核心封装: 将 RAG + Neo4j + LLM 的耗时逻辑封装为一个函数
# ============================================================
def perform_rag_and_llm(query: str) -> str:
    """
    Run the full RAG pipeline for one question and return the LLM answer.

    Steps:
      1. Hybrid (dense + BM25) recall from Milvus with RRF re-ranking.
      2. Parent/child document retrieval over the PDF collection.
      3. Exact recall from Neo4j via the external text-to-Cypher service
         on port 8101 (best-effort: any failure just yields empty context).
      4. Answer generation with the OpenAI LLM.

    Fixes vs. the previous version:
      - `requests.post` now sends proper JSON (`json=`) with a timeout, so a
        hung Cypher service can no longer block the worker forever (the old
        code posted the JSON string as a form body with no Content-Type).
      - Response fields are read with `dict.get` to avoid KeyError crashes.
      - Neo4j row values are coerced to `str` before joining (a non-string
        value previously raised inside the join and discarded the result).
      - Dropped the useless `global` statement (the globals are only read).
    """
    # ---- 1: fuzzy recall + RRF re-rank from the main Milvus collection ----
    recalled_docs = milvus_vectorstore.similarity_search(
        query,
        k=10,
        ranker_type="rrf",
        ranker_params={"k": 100},
    )
    context = format_docs(recalled_docs) if recalled_docs else ""

    # ---- 2: PDF recall through the parent/child document retriever ----
    pdf_res = ""
    retrieved_docs = parent_retriever.invoke(query)
    if retrieved_docs:
        pdf_res = retrieved_docs[0].page_content
        print("PDF res: ", pdf_res)
    context = context + "\n" + pdf_res

    # ---- 3: exact recall from Neo4j via the text-to-Cypher service ----
    neo4j_res = ""
    try:
        cypher_response = requests.post(
            "http://0.0.0.0:8101/generate",
            json={"natural_language_query": query},
            timeout=30,
        )
        if cypher_response.status_code == 200:
            gen_data = cypher_response.json()
            cypher_query = gen_data.get("cypher_query")
            confidence = gen_data.get("confidence", 0.0)
            is_valid = gen_data.get("validated", False)
            # Only run queries the generator is confident in and pre-validated.
            if cypher_query is not None and float(confidence) >= 0.9 and is_valid:
                print("neo4j Cypher 初步生成成功 !!!")
                # Double-check the generated Cypher against /validate.
                cypher_valid = requests.post(
                    "http://0.0.0.0:8101/validate",
                    json={"cypher_query": cypher_query},
                    timeout=30,
                )
                if cypher_valid.status_code == 200 and cypher_valid.json().get("is_valid"):
                    with driver.session() as session:
                        try:
                            record = session.run(cypher_query)
                            # Take the first column of each row; coerce to str
                            # so join cannot fail on numeric/node values.
                            result = [str(row[0]) for row in record]
                            neo4j_res = ','.join(result)
                        except Exception as e:
                            print(e)
                            print("neo4j查询失败 !!")
                            neo4j_res = ""
            else:
                print("生成Cypher查询失败 !!")
    except Exception as e:
        # Best-effort: a missing/unreachable Cypher service must not break RAG.
        print(f"neo4j API 服务不可用: {e}")

    # Merge the Milvus, PDF and Neo4j recall results into one LLM prompt.
    context = context + "\n" + neo4j_res

    # ---- 4: system and user prompts for the LLM ----
    SYSTEM_PROMPT = """
System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
"""
    USER_PROMPT = f"""
User: 利用介于<context>和</context>之间的从数据库中检索出的信息来回答问题, 具体的问题介于<question>和</question>之间. 如果提供的信息为空, 则按照你的经验知识来给出尽可能严谨准确的回答, 不知道的时候坦诚的承认不了解, 不要编造不真实的信息.
<context>
{context}
</context>
<question>
{query}
</question>
"""

    # ---- 5: generate the answer with the OpenAI model ----
    return generate_openai_answer(client_llm, SYSTEM_PROMPT + USER_PROMPT)
@app.post("/")
async def chatbot(request: Request):
    """POST handler: answer a question, with the full cache flow delegated
    to redis_manager (read-through cache, distributed mutex against cache
    stampede, double-check before the LLM call, randomized TTL on write,
    automatic lock release)."""
    try:
        payload = await request.json()
        # Some clients double-encode the body; decode the inner string too.
        if isinstance(payload, str):
            payload = json.loads(payload)
        query = payload.get('question')
        if not query:
            return {"status": 400, "error": "Question is required"}
        # Cache hit returns immediately; on a miss redis_manager invokes the
        # callback (the expensive RAG + LLM pipeline) exactly once.
        answer = redis_manager.get_or_compute(query, lambda: perform_rag_and_llm(query))
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return {
            "response": answer,
            "status": 200,
            "time": timestamp,
        }
    except Exception as e:
        print(f"Server Error: {e}")
        return {"status": 500, "error": str(e)}
if __name__ == '__main__':
    # Single worker: the module-level connections (Milvus, Neo4j, Redis) are
    # process-local, and `workers > 1` would require passing the app as an
    # import string rather than an object anyway.
    uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)