import os
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
import json
import requests
import datetime
from openai import OpenAI
from neo4j import GraphDatabase
from langchain_milvus import Milvus, BM25BuiltInFunction
from vector import OpenAIEmbeddings # 只保留 Embedding, Redis 相关全部移除
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.stores import InMemoryStore
from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever
from dotenv import load_dotenv
# 导入新的 redis 工具类 (替代旧的 get_redis_client, cache_set, cache_get)
from new_redis import redis_manager
# 加载 .env 文件中的环境变量, 隐藏 API Keys
load_dotenv()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
app = FastAPI()
# ============================================================
# OpenAI LLM 客户端封装
# ============================================================
def create_openai_client():
"""创建 OpenAI 客户端"""
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
return client
def generate_openai_answer(client, prompt):
"""使用 OpenAI 生成回复"""
response = client.chat.completions.create(
model="gpt-4o-mini",
messages=[
{"role": "user", "content": prompt}
],
temperature=0.7,
)
return response.choices[0].message.content
# 允许所有域的请求
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 创建 Embedding 模型
embedding_model = OpenAIEmbeddings()
print("创建 Embedding 模型成功......")
# 设置默认的 Milvus 数据库文件路径
URI = "./milvus_agent.db"
URI1 = "./pdf_agent.db"
# 创建 Milvus 连接
milvus_vectorstore = Milvus(
embedding_function=embedding_model,
builtin_function=BM25BuiltInFunction(),
vector_field=["dense", "sparse"],
index_params=[
{
"metric_type": "IP",
"index_type": "IVF_FLAT",
},
{
"metric_type": "BM25",
"index_type": "SPARSE_INVERTED_INDEX"
}
],
connection_args={"uri": URI},
)
retriever = milvus_vectorstore.as_retriever()
print("创建 Milvus 连接成功......")
docstore = InMemoryStore()
# 文本分割器
child_splitter = RecursiveCharacterTextSplitter(
chunk_size=200,
chunk_overlap=50,
length_function=len,
separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""]
)
parent_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000,
chunk_overlap=200
)
pdf_vectorstore = Milvus(
embedding_function=embedding_model,
builtin_function=BM25BuiltInFunction(),
vector_field=["dense", "sparse"],
index_params=[
{
"metric_type": "IP",
"index_type": "IVF_FLAT",
},
{
"metric_type": "BM25",
"index_type": "SPARSE_INVERTED_INDEX"
}
],
connection_args={"uri": URI1},
consistency_level="Bounded",
drop_old=False,
)
# 设置父子文档检索器
parent_retriever = ParentDocumentRetriever(
vectorstore=pdf_vectorstore,
docstore=docstore,
child_splitter=child_splitter,
parent_splitter=parent_splitter,
)
print("创建 Parent Milvus 连接成功......")
# 获取 neo4j 图数据库的连接
neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
neo4j_user = os.getenv("NEO4J_USER", "neo4j")
neo4j_password = os.getenv("NEO4J_PASSWORD", "neo4j")
driver = GraphDatabase.driver(uri=neo4j_uri, auth=(neo4j_user, neo4j_password), max_connection_lifetime=1000)
print("创建 Neo4j 连接成功......")
# 创建大语言模型, 采用 OpenAI
client_llm = create_openai_client()
print("创建 OpenAI LLM 成功......")
# 注意: Redis 连接已经在 new_redis 导入时自动初始化了 (单例模式)
print("创建 Redis 连接成功......")
def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)
# ============================================================
# 核心封装: 将 RAG + Neo4j + LLM 的耗时逻辑封装为一个函数
# ============================================================
def perform_rag_and_llm(query: str) -> str:
"""
执行完整的 RAG 流程:
1. Milvus 向量召回 & 重排序
2. PDF 父子文档检索
3. Neo4j 图数据库精准召回
4. OpenAI LLM 推理
"""
global milvus_vectorstore, retriever
# ============================================================
# 1: 向量数据库 Milvus 模糊召回 & 重排序
# ============================================================
recall_rerank_milvus = milvus_vectorstore.similarity_search(
query,
k=10,
ranker_type="rrf",
ranker_params={"k": 100}
)
if recall_rerank_milvus:
context = format_docs(recall_rerank_milvus)
else:
context = ""
# ============================================================
# 2: PDF 文档的 Milvus 召回 (父子文档检索器)
# ============================================================
pdf_res = ""
retrieved_docs = parent_retriever.invoke(query)
if retrieved_docs is not None and len(retrieved_docs) >= 1:
pdf_res = retrieved_docs[0].page_content
print("PDF res: ", pdf_res)
context = context + "\n" + pdf_res
# ============================================================
# 3: 图数据库 neo4j 精准召回
# ============================================================
neo4j_res = ""
data = {"natural_language_query": query}
data_json = json.dumps(data)
try:
cypher_response = requests.post("http://0.0.0.0:8101/generate", data_json)
if cypher_response.status_code == 200:
cypher_response_data = cypher_response.json()
cypher_query = cypher_response_data["cypher_query"]
confidence = cypher_response_data["confidence"]
is_valid = cypher_response_data["validated"]
if cypher_query is not None and float(confidence) >= 0.9 and is_valid == True:
print("neo4j Cypher 初步生成成功 !!!")
# 验证 neo4j 生成的 Cypher 命令完全正确
data = {"cypher_query": cypher_query}
data_json = json.dumps(data)
cypher_valid = requests.post("http://0.0.0.0:8101/validate", data_json)
if cypher_valid.status_code == 200:
cypher_valid_data = cypher_valid.json()
if cypher_valid_data["is_valid"] == True:
with driver.session() as session:
try:
record = session.run(cypher_query)
result = list(map(lambda x: x[0], record))
neo4j_res = ','.join(result)
except Exception as e:
print(e)
print("neo4j查询失败 !!")
neo4j_res = ""
else:
print("生成Cypher查询失败 !!")
except Exception as e:
print(f"neo4j API 服务不可用: {e}")
# 合并 Milvus、PDF 和 neo4j 的召回结果, 共同作为 LLM 的输入 prompt
context = context + "\n" + neo4j_res
# ============================================================
# 4: 为LLM定义系统和用户提示
# ============================================================
SYSTEM_PROMPT = """
System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
"""
USER_PROMPT = f"""
User: 利用介于和之间的从数据库中检索出的信息来回答问题, 具体的问题介于和之间. 如果提供的信息为空, 则按照你的经验知识来给出尽可能严谨准确的回答, 不知道的时候坦诚的承认不了解, 不要编造不真实的信息.
{context}
{query}
"""
# ============================================================
# 5: 使用 OpenAI 模型, 根据提示生成回复
# ============================================================
response = generate_openai_answer(client_llm, SYSTEM_PROMPT + USER_PROMPT)
return response
@app.post("/")
async def chatbot(request: Request):
try:
json_post_raw = await request.json()
# 处理可能的双重编码问题
if isinstance(json_post_raw, str):
json_post_list = json.loads(json_post_raw)
else:
json_post_list = json_post_raw
query = json_post_list.get('question')
if not query:
return {"status": 400, "error": "Question is required"}
# ============================================================
# 使用 redis_manager 接管全部缓存流程:
# 1. 尝试从缓存读
# 2. 如果没有, 自动加锁 (分布式互斥锁, 防击穿)
# 3. Double Check (防重复调用 LLM)
# 4. 执行 perform_rag_and_llm 回调函数
# 5. 将结果写入缓存 (带随机过期时间, 防雪崩)
# 6. 自动释放锁
# ============================================================
compute_callback = lambda: perform_rag_and_llm(query)
response = redis_manager.get_or_compute(query, compute_callback)
now = datetime.datetime.now()
time_str = now.strftime("%Y-%m-%d %H:%M:%S")
return {
"response": response,
"status": 200,
"time": time_str
}
except Exception as e:
print(f"Server Error: {e}")
return {"status": 500, "error": str(e)}
if __name__ == '__main__':
uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)