| import os |
| import uvicorn |
| from fastapi import FastAPI, Request |
| from fastapi.middleware.cors import CORSMiddleware |
| import json |
| import requests |
| import datetime |
| from openai import OpenAI |
| from neo4j import GraphDatabase |
| from langchain_milvus import Milvus, BM25BuiltInFunction |
| from vector import OpenAIEmbeddings, get_redis_client, cache_set, cache_get |
| from langchain_text_splitters import RecursiveCharacterTextSplitter |
| from langchain_core.stores import InMemoryStore |
| from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever |
| from dotenv import load_dotenv |
|
|
| |
# Load environment variables (OPENAI_API_KEY, NEO4J_* credentials, etc.)
# from a local .env file before any client is constructed.
load_dotenv()

# Silence the HuggingFace tokenizers fork/parallelism warning in server processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# FastAPI application instance; middleware and routes are registered on it below.
app = FastAPI()
|
def create_openai_client():
    """Build and return an OpenAI client authenticated via the OPENAI_API_KEY env var."""
    api_key = os.getenv("OPENAI_API_KEY")
    return OpenAI(api_key=api_key)
|
|
|
|
def generate_openai_answer(client, prompt):
    """Send *prompt* as a single user turn to gpt-4o-mini and return the reply text."""
    messages = [{"role": "user", "content": prompt}]
    completion = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=messages,
        temperature=0.7,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
|
|
|
|
| |
# Allow browser clients from any origin (all methods/headers, with credentials).
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is very
# permissive — tighten before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Embedding model shared by both Milvus vector stores below.
embedding_model = OpenAIEmbeddings()
print("创建 Embedding 模型成功......")

# Milvus Lite database files: URI for the main knowledge store, URI1 for the PDF store.
URI = "./milvus_agent.db"
URI1 = "./pdf_agent.db"

# Hybrid store: dense embeddings ("dense", IVF_FLAT over inner product) plus
# BM25 sparse vectors ("sparse", inverted index) for keyword-style recall.
milvus_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
        },
        {
            "metric_type": "BM25",
            "index_type": "SPARSE_INVERTED_INDEX"
        }
    ],
    connection_args={"uri": URI},
)

# NOTE(review): `retriever` is created but the endpoint below calls
# similarity_search on the store directly — confirm this is still needed.
retriever = milvus_vectorstore.as_retriever()
print("创建 Milvus 连接成功......")

# In-memory map of parent documents backing the ParentDocumentRetriever.
docstore = InMemoryStore()

# Child chunks: small (200-char) pieces that get embedded and indexed;
# separators include CJK punctuation for Chinese text.
child_splitter = RecursiveCharacterTextSplitter(
    chunk_size=200,
    chunk_overlap=50,
    length_function=len,
    separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""]
)

# Parent chunks: larger (1000-char) context windows returned to the caller.
parent_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200
)

# Second hybrid Milvus store dedicated to PDF-derived child chunks.
pdf_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
        },
        {
            "metric_type": "BM25",
            "index_type": "SPARSE_INVERTED_INDEX"
        }
    ],
    connection_args={"uri": URI1},
    consistency_level="Bounded",
    drop_old=False,
)

# Searches small child chunks but returns their larger parent documents.
parent_retriever = ParentDocumentRetriever(
    vectorstore=pdf_vectorstore,
    docstore=docstore,
    child_splitter=child_splitter,
    parent_splitter=parent_splitter,
)

print("创建 Parent Milvus 连接成功......")

# Neo4j graph database driver (credentials from env, with localhost defaults).
neo4j_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
neo4j_user = os.getenv("NEO4J_USER", "neo4j")
neo4j_password = os.getenv("NEO4J_PASSWORD", "neo4j")
driver = GraphDatabase.driver(uri=neo4j_uri, auth=(neo4j_user, neo4j_password), max_connection_lifetime=1000)
print("创建 Neo4j 连接成功......")

# Shared OpenAI client used to generate final answers.
client_llm = create_openai_client()
print("创建 OpenAI LLM 成功......")

# Redis connection used as a question→answer response cache.
client_redis = get_redis_client()
print("创建 Redis 连接成功......")
|
|
|
|
def format_docs(docs):
    """Join the page_content of each retrieved document, separated by blank lines."""
    contents = [doc.page_content for doc in docs]
    return "\n\n".join(contents)
|
|
|
|
def _build_answer(response):
    """Wrap a response string in the API's standard envelope with a timestamp."""
    now = datetime.datetime.now()
    stamp = now.strftime("%Y-%m-%d %H:%M:%S")
    return {
        "response": response,
        "status": 200,
        "time": stamp
    }


def _retrieve_milvus_context(query):
    """Hybrid (dense + BM25, RRF-fused) search in the main Milvus store.

    Returns the joined page contents of the top-k hits, or "" when empty.
    """
    docs = milvus_vectorstore.similarity_search(
        query,
        k=10,
        ranker_type="rrf",
        ranker_params={"k": 100}
    )
    return format_docs(docs) if docs else ""


def _retrieve_pdf_context(query):
    """Return the top parent document from the PDF store, or "" when nothing matches."""
    retrieved_docs = parent_retriever.invoke(query)
    if retrieved_docs:
        pdf_res = retrieved_docs[0].page_content
        print("PDF res: ", pdf_res)
        return pdf_res
    return ""


def _retrieve_neo4j_context(query):
    """Translate the question to Cypher via the sidecar service, re-validate it,
    run it against Neo4j, and return the first-column values comma-joined.

    Returns "" on any failure (service down, low confidence, invalid Cypher,
    query error) so the caller can always concatenate the result.
    """
    neo4j_res = ""
    try:
        # json= serializes the body AND sets the application/json Content-Type
        # header (the old positional `data` string sent no content type);
        # timeout prevents a dead sidecar from hanging the request forever.
        cypher_response = requests.post(
            "http://0.0.0.0:8101/generate",
            json={"natural_language_query": query},
            timeout=30,
        )
        if cypher_response.status_code == 200:
            payload = cypher_response.json()
            cypher_query = payload["cypher_query"]
            confidence = payload["confidence"]
            is_valid = payload["validated"]

            # Only execute queries the generator is confident in and marked valid.
            if cypher_query is not None and float(confidence) >= 0.9 and is_valid:
                print("neo4j Cypher 初步生成成功 !!!")

                # Second, explicit validation pass before touching the database.
                cypher_valid = requests.post(
                    "http://0.0.0.0:8101/validate",
                    json={"cypher_query": cypher_query},
                    timeout=30,
                )
                if cypher_valid.status_code == 200 and cypher_valid.json()["is_valid"]:
                    with driver.session() as session:
                        try:
                            record = session.run(cypher_query)
                            neo4j_res = ','.join(row[0] for row in record)
                        except Exception as e:
                            print(e)
                            print("neo4j查询失败 !!")
                            neo4j_res = ""
            else:
                print("生成Cypher查询失败 !!")
    except Exception as e:
        print(f"neo4j API 服务不可用: {e}")
    return neo4j_res


@app.post("/")
async def chatbot(request: Request):
    """Answer a medical question via retrieval-augmented generation.

    Expects a JSON body {"question": ...}. Checks the Redis response cache
    first; on a miss, gathers context from Milvus, the PDF parent-document
    store, and Neo4j, asks the LLM, caches the answer, and returns a dict
    with "response", "status", and "time" keys.
    """
    payload = await request.json()
    query = payload.get('question')

    # Guard: reject requests without a question instead of crashing downstream.
    if not query:
        return {
            "response": "missing 'question' field",
            "status": 400,
            "time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }

    # 1. Response cache: identical questions get the previous answer back.
    response_redis = cache_get(client_redis, query)
    if response_redis is not None:
        print('REDIS HIT !!!')
        return _build_answer(response_redis.decode('utf-8'))

    # 2. Assemble context from the three retrieval backends.
    context = _retrieve_milvus_context(query)
    context = context + "\n" + _retrieve_pdf_context(query)
    context = context + "\n" + _retrieve_neo4j_context(query)

    SYSTEM_PROMPT = """
    System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
    """

    USER_PROMPT = f"""
    User: 利用介于<context>和</context>之间的从数据库中检索出的信息来回答问题, 具体的问题介于<question>和</question>之间. 如果提供的信息为空, 则按照你的经验知识来给出尽可能严谨准确的回答, 不知道的时候坦诚的承认不了解, 不要编造不真实的信息.
    <context>
    {context}
    </context>

    <question>
    {query}
    </question>
    """

    # 3. Generate. USER_PROMPT is an f-string and already interpolated; the
    # previous extra .format(...) call was redundant and raised whenever the
    # retrieved context contained literal '{' or '}' characters.
    response = generate_openai_answer(client_llm, SYSTEM_PROMPT + USER_PROMPT)

    # 4. Cache the fresh answer for future identical questions.
    cache_set(client_redis, query, response)

    return _build_answer(response)
|
|
|
|
if __name__ == '__main__':
    # Serve the app directly. NOTE(review): uvicorn only supports workers > 1
    # when given an import string, not an app object — keep workers=1 here.
    uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)