| import os |
| import uvicorn |
| from fastapi import FastAPI, Request |
| from fastapi.middleware.cors import CORSMiddleware |
| import json |
| import datetime |
| from langchain_milvus import Milvus, BM25BuiltInFunction |
| from model import OpenAIEmbeddings |
| from dotenv import load_dotenv |
|
|
| |
# Load environment variables (e.g. API keys) from a local .env file.
load_dotenv()


# Silence the HuggingFace tokenizers fork warning; must be set before tokenizers load.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


# Backend switch: True -> locally hosted vLLM, False -> ChatGPT API.
USE_LOCAL_LLM = False


# Bind a uniform (create_client, generate_answer) pair to whichever backend
# was selected, so the request handler below is backend-agnostic.
if USE_LOCAL_LLM:
    from model import create_local_llm_client as create_client
    from model import generate_local_answer as generate_answer
else:
    from model import create_chatgpt_client as create_client
    from model import generate_chatgpt_answer as generate_answer


# Single module-level LLM client shared by all requests.
client_llm = create_client()
print(f"创建 {'本地 vLLM' if USE_LOCAL_LLM else 'ChatGPT'} 客户端成功......")
|
|
# FastAPI application instance served by uvicorn at the bottom of the file.
app = FastAPI()


# Allow cross-origin requests from any origin so browser front-ends hosted
# elsewhere can call this API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# very permissive — consider restricting origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
| |
# Embedding model used to produce the dense vectors stored in Milvus.
embedding_model = OpenAIEmbeddings()
print("创建 Embedding 模型成功......")


# Local Milvus Lite database file (embedded mode — no standalone server).
URI = "./milvus_agent.db"


# Hybrid-search vector store: dense embeddings (inner product over IVF_FLAT)
# plus BM25 sparse keyword vectors (SPARSE_INVERTED_INDEX). The two entries
# in index_params correspond positionally to the "dense" and "sparse" fields.
milvus_vectorstore = Milvus(
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
        },
        {
            "metric_type": "BM25",
            "index_type": "SPARSE_INVERTED_INDEX"
        }
    ],
    connection_args={"uri": URI},
)


# Retriever view over the store. NOTE(review): not actually used by the
# request handler below, which calls similarity_search() directly.
retriever = milvus_vectorstore.as_retriever()
print("创建 Milvus 连接成功......")
|
|
|
|
def format_docs(docs):
    """Join retrieved documents into a single context string.

    Concatenates each document's ``page_content``, separated by a blank
    line, for insertion into the LLM prompt.
    """
    chunks = [doc.page_content for doc in docs]
    return "\n\n".join(chunks)
|
|
|
|
@app.post("/")
async def chatbot(request: Request):
    """Answer a question via RAG: retrieve context from Milvus, ask the LLM.

    Expects a JSON body like ``{"question": "..."}``. Returns a dict with
    the generated ``response``, an HTTP-style ``status`` code, and a
    ``time`` stamp (``YYYY-MM-DD HH:MM:SS``).
    """
    # request.json() already returns the parsed payload — the old
    # dumps()/loads() round-trip was a no-op and has been removed.
    payload = await request.json()
    query = payload.get('question')

    # Guard against a missing/empty question instead of crashing inside
    # the vector search.
    if not query:
        return {
            "response": "请求缺少 'question' 字段",
            "status": 400,
            "time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }

    # Hybrid (dense + BM25) retrieval, fused with Reciprocal Rank Fusion.
    recall_rerank_milvus = milvus_vectorstore.similarity_search(
        query,
        k=10,
        ranker_type="rrf",
        ranker_params={"k": 100}
    )

    # An empty context must be an empty string, not a list: the old
    # `context = []` interpolated a literal "[]" into the prompt.
    context = format_docs(recall_rerank_milvus) if recall_rerank_milvus else ""

    SYSTEM_PROMPT = """
    System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
    """

    USER_PROMPT = f"""
    User: 利用介于<context>和</context>之间的从数据库中检索出的信息来回答问题, 具体的问题介于<question>和</question>之间. 如果提供的信息为空, 则按照你的经验知识来给出尽可能严谨准确的回答, 不知道的时候坦诚的承认不了解, 不要编造不真实的信息.
    <context>
    {context}
    </context>

    <question>
    {query}
    </question>
    """

    # Delegate generation to whichever backend was bound at import time.
    response = generate_answer(client_llm, SYSTEM_PROMPT + USER_PROMPT)

    # Timestamp of when the answer was produced (renamed from `time`, which
    # shadowed the common stdlib module name).
    now = datetime.datetime.now()
    answered_at = now.strftime("%Y-%m-%d %H:%M:%S")

    return {
        "response": response,
        "status": 200,
        "time": answered_at,
    }
|
|
|
|
if __name__ == '__main__':
    # Serve on all interfaces, port 8103.
    # NOTE(review): with an app object (not an import string) uvicorn cannot
    # spawn multiple workers — workers=1 is the only safe value here; confirm
    # against uvicorn docs before raising it.
    uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)