import os
import json
import datetime

import uvicorn
import requests
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from openai import OpenAI
from neo4j import GraphDatabase
from langchain_milvus import Milvus, BM25BuiltInFunction
from vector import OpenAIEmbeddings  # Embedding only; the old Redis helpers were removed
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.stores import InMemoryStore
from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever
from dotenv import load_dotenv

# New Redis utility class (replaces the old get_redis_client / cache_set / cache_get)
from new_redis import redis_manager

# Load environment variables from .env — keeps API keys out of the source.
load_dotenv()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

app = FastAPI()

# Timeout (seconds) for the local Cypher-generation service. The original code
# had no timeout, so a hung service would block every request forever.
CYPHER_SERVICE_TIMEOUT = 10


# ============================================================
# OpenAI LLM client wrappers
# ============================================================
def create_openai_client():
    """Create an OpenAI client configured from the OPENAI_API_KEY env var."""
    return OpenAI(api_key=os.getenv("OPENAI_API_KEY"))


def generate_openai_answer(client, prompt):
    """Run a single-turn chat completion against gpt-4o-mini and return its text.

    Args:
        client: an ``OpenAI`` client instance.
        prompt: the full prompt string (system + user text concatenated).
    """
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.7,
    )
    return response.choices[0].message.content


# Allow cross-origin requests from any domain.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Embedding model shared by both Milvus stores.
embedding_model = OpenAIEmbeddings()
print("创建 Embedding 模型成功......")

# Default Milvus (Milvus Lite) database file paths.
URI = "./milvus_agent.db"
URI1 = "./pdf_agent.db"

# Hybrid (dense IP + sparse BM25) Milvus store for the general knowledge base.
milvus_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
        },
        {
            "metric_type": "BM25",
            "index_type": "SPARSE_INVERTED_INDEX",
        },
    ],
    connection_args={"uri": URI},
)
retriever = milvus_vectorstore.as_retriever()
print("创建 Milvus 连接成功......")

# In-memory store for parent documents.
# NOTE(review): parents vanish on restart while child vectors persist in
# Milvus — confirm this is acceptable for the deployment.
docstore = InMemoryStore()

# Child chunks are small (precise retrieval); parent chunks are large (context).
child_splitter = RecursiveCharacterTextSplitter(
    chunk_size=200,
    chunk_overlap=50,
    length_function=len,
    separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""],
)
parent_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
)

# Hybrid Milvus store dedicated to PDF documents.
pdf_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
        },
        {
            "metric_type": "BM25",
            "index_type": "SPARSE_INVERTED_INDEX",
        },
    ],
    connection_args={"uri": URI1},
    consistency_level="Bounded",
    drop_old=False,
)

# Parent-document retriever: searches on child chunks, returns parent chunks.
parent_retriever = ParentDocumentRetriever(
    vectorstore=pdf_vectorstore,
    docstore=docstore,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)
print("创建 Parent Milvus 连接成功......")

# Neo4j graph database connection (credentials from env, with local defaults).
neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
neo4j_user = os.getenv("NEO4J_USER", "neo4j")
neo4j_password = os.getenv("NEO4J_PASSWORD", "neo4j")
driver = GraphDatabase.driver(
    uri=neo4j_uri,
    auth=(neo4j_user, neo4j_password),
    max_connection_lifetime=1000,
)
print("创建 Neo4j 连接成功......")

# Large-language-model client (OpenAI).
client_llm = create_openai_client()
print("创建 OpenAI LLM 成功......")

# NOTE: the Redis connection is initialised automatically (singleton) when
# new_redis is imported above.
print("创建 Redis 连接成功......")


def format_docs(docs):
    """Join the page contents of retrieved documents into one context string."""
    return "\n\n".join(doc.page_content for doc in docs)


# ============================================================
# Retrieval helpers for perform_rag_and_llm
# ============================================================
def _retrieve_milvus_context(query: str) -> str:
    """Step 1: fuzzy recall from Milvus with RRF reranking; "" when empty."""
    docs = milvus_vectorstore.similarity_search(
        query, k=10, ranker_type="rrf", ranker_params={"k": 100}
    )
    return format_docs(docs) if docs else ""


def _retrieve_pdf_context(query: str) -> str:
    """Step 2: parent-document retrieval over the PDF store; "" when empty."""
    retrieved_docs = parent_retriever.invoke(query)
    if retrieved_docs:  # covers both None and the empty list
        pdf_res = retrieved_docs[0].page_content
        print("PDF res: ", pdf_res)
        return pdf_res
    return ""


def _retrieve_neo4j_context(query: str) -> str:
    """Step 3: exact recall via the Cypher-generation service and Neo4j.

    Asks the local service to translate *query* into Cypher, accepts it only
    when confidence >= 0.9 and it is marked valid, re-validates it, then runs
    it on Neo4j. Returns a comma-joined result string, or "" on any failure —
    the pipeline degrades gracefully when the graph side is unavailable.
    """
    try:
        # BUGFIX: the original posted a pre-dumped JSON string as form `data`
        # (no Content-Type header) and had no timeout. `json=` sends a proper
        # application/json body.
        cypher_response = requests.post(
            "http://0.0.0.0:8101/generate",
            json={"natural_language_query": query},
            timeout=CYPHER_SERVICE_TIMEOUT,
        )
        if cypher_response.status_code != 200:
            return ""
        payload = cypher_response.json()
        cypher_query = payload["cypher_query"]
        confidence = payload["confidence"]
        # Guard confidence against null before float() (TypeError otherwise).
        if not (
            cypher_query is not None
            and confidence is not None
            and float(confidence) >= 0.9
            and payload["validated"]
        ):
            print("生成Cypher查询失败 !!")
            return ""
        print("neo4j Cypher 初步生成成功 !!!")
        # Double-check: have the service fully validate the generated Cypher.
        cypher_valid = requests.post(
            "http://0.0.0.0:8101/validate",
            json={"cypher_query": cypher_query},
            timeout=CYPHER_SERVICE_TIMEOUT,
        )
        if cypher_valid.status_code != 200 or not cypher_valid.json()["is_valid"]:
            return ""
        with driver.session() as session:
            try:
                records = session.run(cypher_query)
                # First column of each record; str() because Neo4j values are
                # not guaranteed to be strings (original join could TypeError).
                return ','.join(str(record[0]) for record in records)
            except Exception as e:
                print(e)
                print("neo4j查询失败 !!")
                return ""
    except Exception as e:
        print(f"neo4j API 服务不可用: {e}")
        return ""


# ============================================================
# Core: the expensive RAG + Neo4j + LLM pipeline in one function
# ============================================================
def perform_rag_and_llm(query: str) -> str:
    """Run the full RAG pipeline for *query* and return the LLM answer.

    1. Milvus vector recall & reranking
    2. PDF parent-document retrieval
    3. Neo4j graph-database exact recall
    4. OpenAI LLM inference over the merged context
    """
    # Steps 1-3: gather context from all three sources.
    context = _retrieve_milvus_context(query)
    context = context + "\n" + _retrieve_pdf_context(query)
    context = context + "\n" + _retrieve_neo4j_context(query)

    # Step 4: system and user prompts for the LLM.
    SYSTEM_PROMPT = """
    System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
    """
    USER_PROMPT = f"""
    User: 利用介于之间的从数据库中检索出的信息来回答问题, 具体的问题介于之间.
    如果提供的信息为空, 则按照你的经验知识来给出尽可能严谨准确的回答, 不知道的时候坦诚的承认不了解, 不要编造不真实的信息.
    {context}
    {query}
    """

    # Step 5: generate the answer with the OpenAI model.
    return generate_openai_answer(client_llm, SYSTEM_PROMPT + USER_PROMPT)


@app.post("/")
async def chatbot(request: Request):
    """POST endpoint: answer a medical question via the cached RAG pipeline.

    Expects a JSON body with a "question" field; tolerates a double-encoded
    (stringified) JSON body. Returns {"response", "status", "time"} on success,
    or an error object with status 400/500.
    """
    try:
        json_post_raw = await request.json()
        # Handle a possibly double-encoded body (JSON string instead of object).
        if isinstance(json_post_raw, str):
            json_post_list = json.loads(json_post_raw)
        else:
            json_post_list = json_post_raw
        query = json_post_list.get('question')
        if not query:
            return {"status": 400, "error": "Question is required"}

        # redis_manager owns the entire caching flow:
        #   1. try the cache
        #   2. on miss, take a distributed mutex (anti-breakdown)
        #   3. double-check the cache (avoid duplicate LLM calls)
        #   4. run the expensive callback
        #   5. write the result with a randomised TTL (anti-avalanche)
        #   6. release the lock
        response = redis_manager.get_or_compute(
            query, lambda: perform_rag_and_llm(query)
        )

        now = datetime.datetime.now()
        time_str = now.strftime("%Y-%m-%d %H:%M:%S")
        return {
            "response": response,
            "status": 200,
            "time": time_str,
        }
    except Exception as e:
        print(f"Server Error: {e}")
        return {"status": 500, "error": str(e)}


if __name__ == '__main__':
    # NOTE: passing the app object (not an import string) requires workers=1.
    uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)