# agentV2 / agent3.py — shenli
# Add GraphDatabase module with Neo4j + Redis caching (commit 8a17806)
import os
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import json
import requests
import datetime
from openai import OpenAI
from neo4j import GraphDatabase
from langchain_milvus import Milvus, BM25BuiltInFunction
from vector import OpenAIEmbeddings, get_redis_client, cache_set, cache_get
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.stores import InMemoryStore
from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever
from dotenv import load_dotenv
# Load environment variables from the .env file so API keys stay out of source control.
load_dotenv()
# Silence the HuggingFace tokenizers fork-safety warning under multiprocess servers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# FastAPI application instance; the chat route is registered on it below.
app = FastAPI()
# ============================================================
# OpenAI LLM client wrapper (replaces the DeepSeek client from the lecture notes)
# ============================================================
def create_openai_client():
    """Build and return an OpenAI client configured from the OPENAI_API_KEY env var."""
    return OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
def generate_openai_answer(client, prompt, model="gpt-4o-mini", temperature=0.7):
    """Generate a single chat completion for *prompt* using OpenAI.

    Args:
        client: An OpenAI-compatible client exposing ``chat.completions.create``.
        prompt: Full prompt text, sent as one user-role message.
        model: Chat model name. New keyword; defaults to the previously
            hard-coded "gpt-4o-mini", so existing callers are unaffected.
        temperature: Sampling temperature. New keyword; defaults to the
            previously hard-coded 0.7.

    Returns:
        The assistant's reply text (content of the first choice).
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "user", "content": prompt}
        ],
        temperature=temperature,
    )
    return response.choices[0].message.content
# Allow cross-origin requests from any domain.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers under the CORS spec and is unsafe for production —
# confirm whether credentialed requests are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Embedding model shared by both Milvus collections below.
embedding_model = OpenAIEmbeddings()
print("创建 Embedding 模型成功......")  # "Embedding model created"
# Milvus Lite database file paths: URI holds the main knowledge collection,
# URI1 holds the PDF (parent-document) collection.
URI = "./milvus_agent.db"
URI1 = "./pdf_agent.db"
# Hybrid-search Milvus connection: dense embeddings ("dense" field, inner-product
# metric with an IVF_FLAT index) plus BM25 full-text scoring ("sparse" field,
# sparse inverted index).
milvus_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
        },
        {
            "metric_type": "BM25",
            "index_type": "SPARSE_INVERTED_INDEX"
        }
    ],
    connection_args={"uri": URI},
)
# NOTE(review): `retriever` is never invoked below (the handler calls
# similarity_search directly) — confirm whether it can be removed.
retriever = milvus_vectorstore.as_retriever()
print("创建 Milvus 连接成功......")  # "Milvus connection created"
# In-memory store mapping parent-document ids to full parent chunks for the
# ParentDocumentRetriever. Lost on restart — parents must be re-ingested then.
docstore = InMemoryStore()
# Child splitter: small 200-char chunks (50 overlap) that get embedded and
# searched; separators include Chinese punctuation for CJK text.
child_splitter = RecursiveCharacterTextSplitter(
    chunk_size=200,
    chunk_overlap=50,
    length_function=len,
    separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""]
)
# Parent splitter: larger 1000-char chunks (200 overlap) returned as context.
parent_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200
)
# Second Milvus connection for PDF content: same hybrid (dense + BM25) schema
# as the main collection but a separate database file; existing data is kept
# across restarts (drop_old=False).
pdf_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
        },
        {
            "metric_type": "BM25",
            "index_type": "SPARSE_INVERTED_INDEX"
        }
    ],
    connection_args={"uri": URI1},
    consistency_level="Bounded",
    drop_old=False,
)
# Parent/child document retriever: searches the small child chunks stored in
# Milvus, then returns the enclosing parent chunk from the docstore as context.
parent_retriever = ParentDocumentRetriever(
    vectorstore=pdf_vectorstore,
    docstore=docstore,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)
print("创建 Parent Milvus 连接成功......")  # "Parent Milvus connection created"
# Neo4j graph database connection (credentials from env vars, local defaults).
neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
neo4j_user = os.getenv("NEO4J_USER", "neo4j")
neo4j_password = os.getenv("NEO4J_PASSWORD", "neo4j")
# max_connection_lifetime is in seconds — pooled connections older than ~16
# minutes are recycled.
driver = GraphDatabase.driver(uri=neo4j_uri, auth=(neo4j_user, neo4j_password), max_connection_lifetime=1000)
print("创建 Neo4j 连接成功......")  # "Neo4j connection created"
# Large language model client (OpenAI).
client_llm = create_openai_client()
print("创建 OpenAI LLM 成功......")  # "OpenAI LLM created"
# Redis connection used as a query -> answer response cache.
client_redis = get_redis_client()
print("创建 Redis 连接成功......")  # "Redis connection created"
def format_docs(docs):
    """Join the page_content of each retrieved document, separated by blank lines."""
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)
@app.post("/")
async def chatbot(request: Request):
global milvus_vectorstore, retriever
json_post_raw = await request.json()
json_post = json.dumps(json_post_raw)
json_post_list = json.loads(json_post)
query = json_post_list.get('question')
# ============================================================
# 1: 先查 Redis 缓存, 如果缓存命中, 直接返回结果
# ============================================================
response_redis = cache_get(client_redis, query)
if response_redis is not None:
# redis 返回的字符串是以十六进制显示的, 需要按 utf-8 解码
response = response_redis.decode('utf-8')
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": response,
"status": 200,
"time": time
}
print('REDIS HIT !!!')
return answer
# ============================================================
# 2: 向量数据库 Milvus 模糊召回 & 重排序
# ============================================================
# 在集合中搜索问题并检索语义 top-10 匹配项, 而且已经配置了 reranker 的处理, 采用RRF算法
recall_rerank_milvus = milvus_vectorstore.similarity_search(
query,
k=10,
ranker_type="rrf",
ranker_params={"k": 100}
)
if recall_rerank_milvus:
# 检索结果存放在列表中
context = format_docs(recall_rerank_milvus)
else:
context = ""
# ============================================================
# 2.5: PDF 文档的 Milvus 召回 (父子文档检索器)
# ============================================================
pdf_res = ""
retrieved_docs = parent_retriever.invoke(query)
if retrieved_docs is not None and len(retrieved_docs) >= 1:
pdf_res = retrieved_docs[0].page_content
print("PDF res: ", pdf_res)
context = context + "\n" + pdf_res
# ============================================================
# 3: 图数据库 neo4j 精准召回
# ============================================================
# 访问 neo4j API 服务, 生成 Cypher 命令
neo4j_res = ""
data = {"natural_language_query": query}
data_json = json.dumps(data)
try:
cypher_response = requests.post("http://0.0.0.0:8101/generate", data_json)
if cypher_response.status_code == 200:
cypher_response_data = cypher_response.json()
cypher_query = cypher_response_data["cypher_query"]
confidence = cypher_response_data["confidence"]
is_valid = cypher_response_data["validated"]
if cypher_query is not None and float(confidence) >= 0.9 and is_valid == True:
print("neo4j Cypher 初步生成成功 !!!")
# 验证 neo4j 生成的 Cypher 命令完全正确
data = {"cypher_query": cypher_query}
data_json = json.dumps(data)
cypher_valid = requests.post("http://0.0.0.0:8101/validate", data_json)
if cypher_valid.status_code == 200:
cypher_valid_data = cypher_valid.json()
if cypher_valid_data["is_valid"] == True:
with driver.session() as session:
try:
record = session.run(cypher_query)
result = list(map(lambda x: x[0], record))
neo4j_res = ','.join(result)
except Exception as e:
print(e)
print("neo4j查询失败 !!")
neo4j_res = ""
else:
print("生成Cypher查询失败 !!")
except Exception as e:
print(f"neo4j API 服务不可用: {e}")
# 合并 Milvus、PDF 和 neo4j 的召回结果, 共同作为 LLM 的输入 prompt
context = context + "\n" + neo4j_res
# ============================================================
# 4: 为LLM定义系统和用户提示
# ============================================================
SYSTEM_PROMPT = """
System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
"""
USER_PROMPT = f"""
User: 利用介于<context>和</context>之间的从数据库中检索出的信息来回答问题, 具体的问题介于<question>和</question>之间. 如果提供的信息为空, 则按照你的经验知识来给出尽可能严谨准确的回答, 不知道的时候坦诚的承认不了解, 不要编造不真实的信息.
<context>
{context}
</context>
<question>
{query}
</question>
"""
# ============================================================
# 5: 使用 OpenAI 最新版本模型, 根据提示生成回复
# ============================================================
response = generate_openai_answer(client_llm, SYSTEM_PROMPT + USER_PROMPT.format(context, query))
# ============================================================
# 6: 写入缓存
# ============================================================
cache_set(client_redis, query, response)
# ============================================================
# 7: 组装服务返回数据
# ============================================================
now = datetime.datetime.now()
time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": response,
"status": 200,
"time": time
}
return answer
if __name__ == '__main__':
    # Launch the FastAPI service directly with uvicorn.
    # NOTE(review): passing the app object (not an "module:app" import string)
    # means uvicorn cannot spawn extra workers; workers=1 keeps this consistent.
    uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)