Commit e45316c
AgentV2: Medical AI Agent with multi-route RAG
- .gitattributes +5 -0
- .gitignore +1 -0
- .ipynb_checkpoints/test-checkpoint.py +21 -0
- .milvus_agent.db.lock +0 -0
- .pdf_agent.db.lock +0 -0
- __pycache__/vector.cpython-312.pyc +0 -0
- agent.py +131 -0
- agent2.py +206 -0
- data/dialog.jsonl +3 -0
- data/train.json +3 -0
- milvus_agent.db +3 -0
- model.py +102 -0
- pdf_agent.db +3 -0
- pdf_documents/01.内科学_第9版_全书签_可复制检索.pdf +3 -0
- pdf_documents/02.外科学_第9版_全书签_可复制检索.pdf +3 -0
- pdf_documents/03.妇产科学_第9版_全书签_可复制检索.pdf +3 -0
- pdf_documents/04.儿科学_第9版_全书签_可复制检索.pdf +3 -0
- pdf_documents/05.神经病学_第9版_全书签_可复制检索.pdf +3 -0
- pdf_documents/06.系统解剖学_第9版_全书签_可复制检索.pdf +3 -0
- pdf_documents/07.局部解剖学_第9版_全书签_可部分复制检索.pdf +3 -0
- pdf_documents/08.组织学与胚胎学_第9版_全书签_可复制检索.pdf +3 -0
- pdf_documents/09.生物化学与分子生物学_第9版_全书签_可复制检索.pdf +3 -0
- pdf_documents/10.生理学(可复制).pdf +3 -0
- pdf_output/pdf_detailed_text.xlsx +3 -0
- pdf_output/pdf_extraction_summary.xlsx +3 -0
- pdf_output/pdf_processing.log +24 -0
- pdf_output/progress_batch_10.csv +11 -0
- preprocess.py +259 -0
- test.py +21 -0
- vector.py +262 -0
.gitattributes
ADDED
@@ -0,0 +1,5 @@
+*.pdf filter=lfs diff=lfs merge=lfs -text
+*.db filter=lfs diff=lfs merge=lfs -text
+*.xlsx filter=lfs diff=lfs merge=lfs -text
+*.jsonl filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
+.env
.ipynb_checkpoints/test-checkpoint.py
ADDED
@@ -0,0 +1,21 @@
+import requests
+import time
+import json
+
+url = "http://0.0.0.0:8103/"
+data = {"question": "平日里蜂蜜加白醋一起喝有什么疗效?"}  # "What are the effects of regularly drinking honey mixed with white vinegar?"
+#data = {"question": "听说用酸枣仁泡水喝能养生,是真的吗?"}  # "Is it true that water steeped with sour jujube seeds is good for your health?"
+
+start_time = time.time()
+
+data = json.dumps(data)
+
+# Send the request to the service
+res = requests.post(url, data=data)
+
+cost_time = time.time() - start_time
+
+print('单次查询的耗时:', cost_time, 's')
+
+res = json.loads(res.text)
+print(res)
.milvus_agent.db.lock
ADDED
File without changes
.pdf_agent.db.lock
ADDED
File without changes
__pycache__/vector.cpython-312.pyc
ADDED
Binary file (8.68 kB)
agent.py
ADDED
@@ -0,0 +1,131 @@
+import os
+import uvicorn
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+import json
+import datetime
+from langchain_milvus import Milvus, BM25BuiltInFunction
+from model import OpenAIEmbeddings
+from dotenv import load_dotenv
+
+# Load environment variables from the .env file to keep API keys out of the source
+load_dotenv()
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+# ====== Switch: True = local vLLM, False = ChatGPT ======
+USE_LOCAL_LLM = False
+
+if USE_LOCAL_LLM:
+    from model import create_local_llm_client as create_client
+    from model import generate_local_answer as generate_answer
+else:
+    from model import create_chatgpt_client as create_client
+    from model import generate_chatgpt_answer as generate_answer
+
+client_llm = create_client()
+print(f"创建 {'本地 vLLM' if USE_LOCAL_LLM else 'ChatGPT'} 客户端成功......")
+
+app = FastAPI()
+
+# Allow requests from all origins
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Create the embedding model
+embedding_model = OpenAIEmbeddings()
+print("创建 Embedding 模型成功......")
+
+# Default Milvus database file path
+URI = "./milvus_agent.db"
+
+# Create the Milvus connection
+milvus_vectorstore = Milvus(
+    embedding_function=embedding_model,
+    builtin_function=BM25BuiltInFunction(),
+    vector_field=["dense", "sparse"],
+    index_params=[
+        {
+            "metric_type": "IP",
+            "index_type": "IVF_FLAT",
+        },
+        {
+            "metric_type": "BM25",
+            "index_type": "SPARSE_INVERTED_INDEX"
+        }
+    ],
+    connection_args={"uri": URI},
+)
+
+retriever = milvus_vectorstore.as_retriever()
+print("创建 Milvus 连接成功......")
+
+
+def format_docs(docs):
+    return "\n\n".join(doc.page_content for doc in docs)
+
+
+@app.post("/")
+async def chatbot(request: Request):
+    global milvus_vectorstore, retriever
+
+    json_post_raw = await request.json()
+    json_post = json.dumps(json_post_raw)
+    json_post_list = json.loads(json_post)
+
+    query = json_post_list.get('question')
+
+    # Recall & rerank:
+    # search the collection for the top-10 semantic matches; the configured reranker fuses dense and sparse results with the RRF algorithm
+    recall_rerank_milvus = milvus_vectorstore.similarity_search(
+        query,
+        k=10,
+        ranker_type="rrf",
+        ranker_params={"k": 100}
+    )
+
+    if recall_rerank_milvus:
+        # Join the retrieved documents into one context string
+        context = format_docs(recall_rerank_milvus)
+    else:
+        context = ""  # empty string (not a list), so the f-string below stays clean
+
+    # Define the system and user prompts for the LLM; the user prompt is assembled from the documents retrieved from Milvus. (Prompts stay in Chinese to match the corpus; roughly: "You are a capable medical assistant; answer from the retrieved <context>, fall back to your own knowledge if it is empty, admit what you don't know, and never fabricate information.")
+    SYSTEM_PROMPT = """
+System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
+"""
+
+    USER_PROMPT = f"""
+User: 利用介于<context>和</context>之间的从数据库中检索出的信息来回答问题, 具体的问题介于<question>和</question>之间. 如果提供的信息为空, 则按照你的经验知识来给出尽可能严谨准确的回答, 不知道的时候坦诚的承认不了解, 不要编造不真实的信息.
+<context>
+{context}
+</context>
+
+<question>
+{query}
+</question>
+"""
+
+    # Generate the reply with the LLM (the USE_LOCAL_LLM switch picks vLLM or ChatGPT)
+    response = generate_answer(client_llm, SYSTEM_PROMPT + USER_PROMPT)
+
+    now = datetime.datetime.now()
+    time = now.strftime("%Y-%m-%d %H:%M:%S")
+
+    answer = {
+        "response": response,
+        "status": 200,
+        "time": time
+    }
+
+    return answer
+
+
+if __name__ == '__main__':
+    # Start the FastAPI service directly from the main entry point
+    uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)
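
For reference, the hybrid recall step above can be exercised outside the service. The following is a minimal sketch, not part of the commit, assuming ./milvus_agent.db has already been populated by vector.py; the sample question is illustrative only:

from langchain_milvus import Milvus, BM25BuiltInFunction
from model import OpenAIEmbeddings

store = Milvus(
    embedding_function=OpenAIEmbeddings(),
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    connection_args={"uri": "./milvus_agent.db"},
)

# Dense (IP) and sparse (BM25) rankings are fused with Reciprocal Rank Fusion
docs = store.similarity_search(
    "高血压的诊断标准是什么?",  # illustrative query: "What are the diagnostic criteria for hypertension?"
    k=10,
    ranker_type="rrf",
    ranker_params={"k": 100},
)
for doc in docs:
    print(doc.page_content[:80])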
agent2.py
ADDED
@@ -0,0 +1,206 @@
+import os
+import uvicorn
+from fastapi import FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+import json
+import datetime
+from openai import OpenAI
+from langchain_milvus import Milvus, BM25BuiltInFunction
+from vector import OpenAIEmbeddings
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_core.stores import InMemoryStore
+from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever
+from dotenv import load_dotenv
+
+# Load environment variables from the .env file to keep API keys out of the source
+load_dotenv()
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+app = FastAPI()
+
+
+# ============================================================
+# OpenAI LLM client wrapper (replacing the DeepSeek client from the lecture notes)
+# ============================================================
+
+def create_openai_client():
+    """Create an OpenAI client"""
+    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+    return client
+
+
+def generate_openai_answer(client, prompt):
+    """Generate a reply with OpenAI"""
+    response = client.chat.completions.create(
+        model="gpt-4o-mini",
+        messages=[
+            {"role": "user", "content": prompt}
+        ],
+        temperature=0.7,
+    )
+    return response.choices[0].message.content
+
+
+# Allow requests from all origins
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Create the embedding model
+embedding_model = OpenAIEmbeddings()
+print("创建 Embedding 模型成功......")
+
+# Default Milvus database file paths
+URI = "./milvus_agent.db"
+URI1 = "./pdf_agent.db"
+
+# Create the Milvus connection
+milvus_vectorstore = Milvus(
+    embedding_function=embedding_model,
+    builtin_function=BM25BuiltInFunction(),
+    vector_field=["dense", "sparse"],
+    index_params=[
+        {
+            "metric_type": "IP",
+            "index_type": "IVF_FLAT",
+        },
+        {
+            "metric_type": "BM25",
+            "index_type": "SPARSE_INVERTED_INDEX"
+        }
+    ],
+    connection_args={"uri": URI},
+)
+
+retriever = milvus_vectorstore.as_retriever()
+print("创建 Milvus 连接成功......")
+
+
+docstore = InMemoryStore()
+
+# Text splitters
+child_splitter = RecursiveCharacterTextSplitter(
+    chunk_size=200,
+    chunk_overlap=50,
+    length_function=len,
+    separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""]
+)
+
+parent_splitter = RecursiveCharacterTextSplitter(
+    chunk_size=1000,
+    chunk_overlap=200
+)
+
+pdf_vectorstore = Milvus(
+    embedding_function=embedding_model,
+    builtin_function=BM25BuiltInFunction(),
+    vector_field=["dense", "sparse"],
+    index_params=[
+        {
+            "metric_type": "IP",
+            "index_type": "IVF_FLAT",
+        },
+        {
+            "metric_type": "BM25",
+            "index_type": "SPARSE_INVERTED_INDEX"
+        }
+    ],
+    connection_args={"uri": URI1},
+    consistency_level="Bounded",
+    drop_old=False,
+)
+
+# Set up the parent-child document retriever
+parent_retriever = ParentDocumentRetriever(
+    vectorstore=pdf_vectorstore,
+    docstore=docstore,
+    child_splitter=child_splitter,
+    parent_splitter=parent_splitter,
+)
+
+print("创建 Parent Milvus 连接成功......")
+
+# Create the LLM client, using OpenAI
+client_llm = create_openai_client()
+print("创建 OpenAI LLM 成功......")
+
+
+def format_docs(docs):
+    return "\n\n".join(doc.page_content for doc in docs)
+
+
+@app.post("/")
+async def chatbot(request: Request):
+    global milvus_vectorstore, retriever
+
+    json_post_raw = await request.json()
+    json_post = json.dumps(json_post_raw)
+    json_post_list = json.loads(json_post)
+
+    query = json_post_list.get('question')
+
+    # Route 1: Milvus recall & rerank
+    # search the collection for the top-10 semantic matches; the configured reranker fuses dense and sparse results with the RRF algorithm
+    recall_rerank_milvus = milvus_vectorstore.similarity_search(
+        query,
+        k=10,
+        ranker_type="rrf",
+        ranker_params={"k": 100}
+    )
+
+    if recall_rerank_milvus:
+        # Join the retrieved documents
+        # into one context string
+        context = format_docs(recall_rerank_milvus)
+    else:
+        context = ""
+
+    # Route 2: Milvus recall over the PDF documents
+    # the parent-document retriever recalls against the same query
+    res = ""
+    retrieved_docs = parent_retriever.invoke(query)
+
+    if retrieved_docs is not None and len(retrieved_docs) >= 1:
+        res = retrieved_docs[0].page_content
+        print("PDF res: ", res)
+
+    context = context + "\n" + res
+
+    # Define the system and user prompts for the LLM; the user prompt is assembled from the documents retrieved from Milvus. (Prompts stay in Chinese to match the corpus; see agent.py for a gloss.)
+    SYSTEM_PROMPT = """
+System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
+"""
+
+    USER_PROMPT = f"""
+User: 利用介于<context>和</context>之间的从数据库中检索出的信息来回答问题, 具体的问题介于<question>和</question>之间. 如果提供的信息为空, 则按照你的经验知识来给出尽可能严谨准确的回答, 不知道的时候坦诚的承认不了解, 不要编造不真实的信息.
+<context>
+{context}
+</context>
+
+<question>
+{query}
+</question>
+"""
+
+    # 3. Generate the reply with the OpenAI model (USER_PROMPT is an f-string, already interpolated; no extra .format() call is needed)
+    response = generate_openai_answer(client_llm, SYSTEM_PROMPT + USER_PROMPT)
+
+    now = datetime.datetime.now()
+    time = now.strftime("%Y-%m-%d %H:%M:%S")
+
+    answer = {
+        "response": response,
+        "status": 200,
+        "time": time
+    }
+
+    return answer
+
+
+if __name__ == '__main__':
+    # Start the FastAPI service directly from the main entry point
+    uvicorn.run(app, host='0.0.0.0', port=8103, workers=1)
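
The second route can be probed on its own as well. ParentDocumentRetriever searches the small child chunks in Milvus and maps hits back to their larger parents through the docstore; note that InMemoryStore does not persist across processes, so a retriever pointed at an existing pdf_agent.db but a fresh docstore will come back empty. A minimal sketch, not part of the commit, assuming it runs in the same process that populated the retriever:

# Illustrative query against the parent_retriever instance built above
retrieved = parent_retriever.invoke("心房颤动有哪些治疗原则?")  # "What are the treatment principles for atrial fibrillation?"
if retrieved:
    # invoke() returns parent-sized chunks (~1000 chars), not the 200-char children
    print(retrieved[0].page_content[:200])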
data/dialog.jsonl
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:879089741db2827a13e1a6f61716be405a95b40a168623a577ad0f22615c7911
+size 15035121
data/train.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cbd0d98a702e355753bf9f1a2f14f034323464156654ca0cd7d1d8b2e97f6864
+size 1834602
milvus_agent.db
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3fb6f3a55a098a6eac5d6b916bb55ac65827941084f862007ded0669e2671f8e
+size 28672
model.py
ADDED
@@ -0,0 +1,102 @@
+import os
+from openai import OpenAI
+from langchain.embeddings.base import Embeddings
+from dotenv import load_dotenv
+
+# Load environment variables from the .env file to keep API keys out of the source
+load_dotenv()
+
+
+# ============================================================
+# Model 1: embedding model, OpenAI text-embedding-3-small
+# ============================================================
+
+class OpenAIEmbeddings(Embeddings):
+    """Custom embedding class backed by the OpenAI Embedding API"""
+
+    def __init__(self):
+        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+
+    def embed_documents(self, texts):
+        embeddings = []
+        for text in texts:
+            response = self.client.embeddings.create(
+                model="text-embedding-3-small",
+                input=[text],
+            )
+            embeddings.append(response.data[0].embedding)
+        return embeddings
+
+    def embed_query(self, text):
+        # Embed a single query string
+        return self.embed_documents([text])[0]
+
+
+# ============================================================
+# Model 2: local LLM served by vLLM (OpenAI-compatible API)
+# Launch command: vllm serve ./Qwen3-Next-80B-A3B-Thinking-AWQ-4bit --dtype auto --trust-remote-code --max-model-len 4096 --port 8000
+# ============================================================
+
+VLLM_MODEL_NAME = "./Qwen3-Next-80B-A3B-Thinking-AWQ-4bit"
+VLLM_BASE_URL = "http://localhost:8000/v1"
+
+
+def create_local_llm_client():
+    """Create a local vLLM client (OpenAI-compatible API)"""
+    client = OpenAI(api_key="none", base_url=VLLM_BASE_URL)
+    return client
+
+
+def generate_local_answer(client, question):
+    """Generate an answer with the local vLLM server"""
+    response = client.chat.completions.create(
+        model=VLLM_MODEL_NAME,
+        messages=[
+            {"role": "system", "content": "你是一个能力非常强大的助手."},  # "You are a very capable assistant."
+            {"role": "user", "content": question}
+        ],
+        max_tokens=2048,
+        stream=False
+    )
+    return response.choices[0].message.content
+
+
+# ============================================================
+# Model 3: remote LLM, ChatGPT via the OpenAI API
+# ============================================================
+
+def create_chatgpt_client():
+    """Create an OpenAI ChatGPT client"""
+    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+    return client
+
+
+def generate_chatgpt_answer(client, question):
+    """Generate an answer with ChatGPT"""
+    response = client.chat.completions.create(
+        model="gpt-4o",
+        messages=[
+            {"role": "system", "content": "你是一个能力非常强大的助手."},  # "You are a very capable assistant."
+            {"role": "user", "content": question}
+        ],
+        stream=False
+    )
+    return response.choices[0].message.content
+
+
+# ============================================================
+# Test entry point
+# ============================================================
+
+if __name__ == "__main__":
+    # Test the local vLLM model
+    client = create_local_llm_client()
+    output = generate_local_answer(client, "你好啊,千与千寻")  # "Hello, Spirited Away"
+    print('-' * 50)
+    print(output)
+
+    # Test remote ChatGPT
+    #client = create_chatgpt_client()
+    #output = generate_chatgpt_answer(client, "你好啊,千与千寻")
+    #print('-' * 50)
+    #print(output)
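
The embedding class can be smoke-tested in isolation. A minimal sketch, not part of the commit, assuming OPENAI_API_KEY is set in .env:

from model import OpenAIEmbeddings

emb = OpenAIEmbeddings()
vec = emb.embed_query("糖尿病")  # "diabetes"
print(len(vec))  # text-embedding-3-small returns 1536-dimensional vectors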
pdf_agent.db
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:dd578ee2a4c3f4c9a3834d1486815e1f0002cd5b1de31784fa714cea1168f134
+size 452308992
pdf_documents/01.内科学_第9版_全书签_可复制检索.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3cf35d1734c7816e3ec58545d28ce67d5155c962a0472f3f5da1357b6a10b41e
+size 486582267
pdf_documents/02.外科学_第9版_全书签_可复制检索.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fdad0ad721ab78a5bf07507ff22d0e25ca8ab0e0f7041b770c7d35f90194711f
+size 433303834
pdf_documents/03.妇产科学_第9版_全书签_可复制检索.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:10f84dadb1feba7c255d95d02df9447e2d48e828dec7d99d1d152e8a763a389a
+size 248708226
pdf_documents/04.儿科学_第9版_全书签_可复制检索.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28b99f77d6f44dc9e43697cf58168cdca9db74965efaa8764a89567066c13010
+size 102781905
pdf_documents/05.神经病学_第9版_全书签_可复制检索.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:064bb8ce4efa9715680713906992e05dc1a99dc8feabefbe8edf13bea53022e5
+size 812868931
pdf_documents/06.系统解剖学_第9版_全书签_可复制检索.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8356c6cacfa84b04df444c3714e8391790da78682e82c4b4a484f6f0975bd617
+size 128118784
pdf_documents/07.局部解剖学_第9版_全书签_可部分复制检索.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:cf5762f141bcc8fc4ae2ae0540d64001293e30cd9e2a72d80d1946e798b56286
+size 103522604
pdf_documents/08.组织学与胚胎学_第9版_全书签_可复制检索.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:eac553218200f3ca71c49cae4c56cfc323e9f11f3d3f03340d8f33ee53406e42
+size 49850201
pdf_documents/09.生物化学与分子生物学_第9版_全书签_可复制检索.pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:63bb2e1cf400dc6727834a2d0d8a25a8aa12795586f3b7b16eaed5f7f04db214
+size 284284534
pdf_documents/10.生理学(可复制).pdf
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9d6423f99687a487e52479589a81ce9cea69cc4832d9c3b065b8f44d8a00111c
+size 412517165
pdf_output/pdf_detailed_text.xlsx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2fb7a2bba1bc86d83071e35c1f523341ed14162e9b8d9f2f8a56610d3e2a6a7
+size 6481056
pdf_output/pdf_extraction_summary.xlsx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:13c43bc65f29f3b878879ea496d369e7b14b238efc8f54a09a0019e739375bac
+size 5700
pdf_output/pdf_processing.log
ADDED
@@ -0,0 +1,24 @@
+2026-02-09 02:07:34,106 - INFO - 在 ./pdf_documents 中找到 10 个PDF文件
+2026-02-09 02:07:34,109 - INFO - 处理进度: 1/10 - 01.内科学_第9版_全书签_可复制检索.pdf
+2026-02-09 02:10:28,298 - INFO - 成功处理: 01.内科学_第9版_全书签_可复制检索.pdf - 962 页
+2026-02-09 02:10:28,312 - INFO - 处理进度: 2/10 - 02.外科学_第9版_全书签_可复制检索.pdf
+2026-02-09 02:13:01,894 - INFO - 成功处理: 02.外科学_第9版_全书签_可复制检索.pdf - 830 页
+2026-02-09 02:13:01,911 - INFO - 处理进度: 3/10 - 03.妇产科学_第9版_全书签_可复制检索.pdf
+2026-02-09 02:14:33,064 - INFO - 成功处理: 03.妇产科学_第9版_全书签_可复制检索.pdf - 495 页
+2026-02-09 02:14:33,080 - INFO - 处理进度: 4/10 - 04.儿科学_第9版_全书签_可复制检索.pdf
+2026-02-09 02:15:26,939 - INFO - 成功处理: 04.儿科学_第9版_全书签_可复制检索.pdf - 478 页
+2026-02-09 02:15:26,955 - INFO - 处理进度: 5/10 - 05.神经病学_第9版_全书签_可复制检索.pdf
+2026-02-09 02:16:26,155 - INFO - 成功处理: 05.神经病学_第9版_全书签_可复制检索.pdf - 494 页
+2026-02-09 02:16:26,162 - INFO - 处理进度: 6/10 - 06.系统解剖学_第9版_全书签_可复制检索.pdf
+2026-02-09 02:16:29,044 - ERROR - 处理文件失败 pdf_documents/06.系统解剖学_第9版_全书签_可复制检索.pdf: Unexpected EOF
+2026-02-09 02:16:29,044 - INFO - 处理进度: 7/10 - 07.局部解剖学_第9版_全书签_可部分复制检索.pdf
+2026-02-09 02:16:32,187 - INFO - 成功处理: 07.局部解剖学_第9版_全书签_可部分复制检索.pdf - 318 页
+2026-02-09 02:16:32,192 - INFO - 处理进度: 8/10 - 08.组织学与胚胎学_第9版_全书签_可复制检索.pdf
+2026-02-09 02:18:05,497 - INFO - 成功处理: 08.组织学与胚胎学_第9版_全书签_可复制检索.pdf - 300 页
+2026-02-09 02:18:05,501 - INFO - 处理进度: 9/10 - 09.生物化学与分子生物学_第9版_全书签_可复制检索.pdf
+2026-02-09 02:19:53,250 - INFO - 成功处理: 09.生物化学与分子生物学_第9版_全书签_可复制检索.pdf - 559 页
+2026-02-09 02:19:53,270 - INFO - 处理进度: 10/10 - 10.生理学(可复制).pdf
+2026-02-09 02:20:26,372 - INFO - 成功处理: 10.生理学(可复制).pdf - 466 页
+2026-02-09 02:20:27,907 - INFO - 结果已保存到 pdf_output
+2026-02-09 02:20:27,913 - INFO - 处理完成: 9/10 个文件成功
+2026-02-09 02:20:27,914 - INFO - 平均每文件: 707778 字符, 2.0 个表格
pdf_output/progress_batch_10.csv
ADDED
@@ -0,0 +1,11 @@
+file_name,status,pages_processed
+01.内科学_第9版_全书签_可复制检索.pdf,Success,962
+02.外科学_第9版_全书签_可复制检索.pdf,Success,830
+03.妇产科学_第9版_全书签_可复制检索.pdf,Success,495
+04.儿科学_第9版_全书签_可复制检索.pdf,Success,478
+05.神经病学_第9版_全书签_可复制检索.pdf,Success,494
+06.系统解剖学_第9版_全书签_可复制检索.pdf,Error,0
+07.局部解剖学_第9版_全书签_可部分复制检索.pdf,Success,318
+08.组织学与胚胎学_第9版_全书签_可复制检索.pdf,Success,300
+09.生物化学与分子生物学_第9版_全书签_可复制检索.pdf,Success,559
+10.生理学(可复制).pdf,Success,466
preprocess.py
ADDED
@@ -0,0 +1,259 @@
+import os
+import glob
+import logging
+import pandas as pd
+from tqdm import tqdm
+from typing import List, Dict, Optional
+from pathlib import Path
+import pdfplumber
+
+
+# Industrial-grade batch PDF processor, production-line quality code
+class PDFBatchProcessor:
+    def __init__(self, output_dir: str = "./output"):
+        self.output_dir = Path(output_dir)
+        self.output_dir.mkdir(exist_ok=True)
+
+        # Configure the logging system
+        logging.basicConfig(
+            level=logging.INFO,
+            format='%(asctime)s - %(levelname)s - %(message)s',
+            handlers=[
+                logging.FileHandler(self.output_dir / "pdf_processing.log"),
+                logging.StreamHandler()
+            ]
+        )
+        self.logger = logging.getLogger(__name__)
+
+    # Find all PDF files under the given path
+    def find_pdf_files(self, input_path: str) -> List[Path]:
+        path = Path(input_path)
+        if path.is_file() and path.suffix.lower() == '.pdf':
+            return [path]
+        elif path.is_dir():
+            # Recursively find all PDF files
+            pdf_files = list(path.glob("**/*.pdf"))
+            self.logger.info(f"在 {input_path} 中找到 {len(pdf_files)} 个PDF文件")
+            return pdf_files
+        else:
+            raise ValueError(f"路径不存在,或不是PDF文件: {input_path}")
+
+    # Extract the content of a single PDF file
+    def extract_pdf_content(self,
+                            pdf_path: Path,
+                            extract_text: bool = True,
+                            extract_tables: bool = True,
+                            table_settings: Optional[dict] = None) -> Dict:
+        """
+        Args:
+            pdf_path: path to the PDF file
+            extract_text: whether to extract text
+            extract_tables: whether to extract tables
+            table_settings: table extraction settings
+        """
+        result = {
+            "file_name": pdf_path.name,
+            "file_path": str(pdf_path),
+            "metadata": {},
+            "pages": [],
+            "error": None
+        }
+
+        try:
+            with pdfplumber.open(pdf_path) as pdf:
+                # Extract metadata
+                result["metadata"] = pdf.metadata
+
+                for page_num, page in enumerate(pdf.pages, 1):
+                    page_result = {"page_number": page_num, "text": "", "tables": []}
+
+                    # Extract text
+                    if extract_text:
+                        try:
+                            # adjust the layout mode as needed
+                            text = page.extract_text(layout=False)
+                            page_result["text"] = text if text else ""
+                        except Exception as e:
+                            self.logger.warning(f"页面 {page_num} 文本提取失败: {str(e)}")
+                            pass
+
+                    # Extract tables
+                    if extract_tables:
+                        try:
+                            tables = page.extract_tables(table_settings or {})
+                            if tables:
+                                page_result["tables"] = tables
+                        except Exception as e:
+                            self.logger.warning(f"页面 {page_num} 表格提取失败: {str(e)}")
+                            pass
+
+                    # Append the extraction result for the current page
+                    result["pages"].append(page_result)
+
+                # Log once a single PDF document has been fully extracted
+                self.logger.info(f"成功处理: {pdf_path.name} - {len(pdf.pages)} 页")
+
+        # Log when extraction of a single PDF document fails
+        except Exception as e:
+            # Record exactly which PDF failed and why, for later tracing and bad-case analysis
+            error_msg = f"处理文件失败 {pdf_path}: {str(e)}"
+            result["error"] = error_msg
+            self.logger.error(error_msg)
+
+        return result
+
+    # Batch-process PDF files
+    def process_batch(self, pdf_files: List[Path],
+                      save_format: str = "excel",
+                      **extract_kwargs) -> pd.DataFrame:
+        """
+        Args:
+            pdf_files: list of PDF files
+            save_format: output format (excel, csv, parquet)
+            **extract_kwargs: extraction parameters
+        """
+        all_results = []
+
+        for i, pdf_file in tqdm(enumerate(pdf_files, 1)):
+            self.logger.info(f"处理进度: {i}/{len(pdf_files)} - {pdf_file.name}")
+
+            result = self.extract_pdf_content(pdf_file, **extract_kwargs)
+            all_results.append(result)
+
+            # Save progress as we go (for large batches)
+            if i % 10 == 0:
+                self._save_intermediate_results(all_results, f"batch_{i}")
+
+        # Save the final results
+        return self._save_results(all_results, save_format)
+
+    # Save the processing results
+    def _save_results(self, results: List[Dict], format: str) -> pd.DataFrame:
+        # Flatten the results for saving
+        flat_data = []
+
+        for result in results:
+            if result["error"]:
+                flat_data.append(
+                    {
+                        "file_name": result["file_name"],
+                        "status": "Error",
+                        "error_message": result["error"],
+                        "page_count": 0,
+                        "text_length": 0,
+                        "table_count": 0
+                    }
+                )
+                continue
+
+            total_text = ""
+            total_tables = 0
+
+            for page in result["pages"]:
+                total_text += page["text"]
+                total_tables += len(page["tables"])
+
+            flat_data.append({
+                "file_name": result["file_name"],
+                "status": "Success",
+                "error_message": "",
+                "page_count": len(result["pages"]),
+                "text_length": len(total_text),
+                "table_count": total_tables,
+                "author": result["metadata"].get("Author", ""),
+                "creation_date": result["metadata"].get("CreationDate", "")
+            })
+
+        # After the loop, wrap all the data in a pandas DataFrame
+        df = pd.DataFrame(flat_data)
+
+        # Save in the requested format
+        if format.lower() == "excel":
+            df.to_excel(self.output_dir / "pdf_extraction_summary.xlsx", index=False)
+
+            # Also save the detailed text content
+            detailed_results = []
+            for result in results:
+                if not result["error"]:
+                    for page in result["pages"]:
+                        if page["text"]:
+                            detailed_results.append({
+                                "file_name": result["file_name"],
+                                "page_number": page["page_number"],
+                                "text_content": page["text"]
+                            })
+
+            if detailed_results:
+                pd.DataFrame(detailed_results).to_excel(
+                    self.output_dir / "pdf_detailed_text.xlsx", index=False
+                )
+
+        elif format.lower() == "csv":
+            df.to_csv(self.output_dir / "pdf_extraction_summary.csv", index=False)
+
+        self.logger.info(f"结果已保存到 {self.output_dir}")
+        return df
+
+    # Save intermediate results (production environments fail in many ways; avoid losing data when a run is interrupted)
+    def _save_intermediate_results(self, results: List[Dict], batch_name: str):
+        try:
+            temp_df = pd.DataFrame([{
+                "file_name": r["file_name"],
+                "status": "Error" if r["error"] else "Success",
+                "pages_processed": len(r["pages"])
+            } for r in results])
+
+            temp_df.to_csv(self.output_dir / f"progress_{batch_name}.csv", index=False)
+        except Exception as e:
+            self.logger.warning(f"保存中间结果失败: {str(e)}")
+
+
+# Advanced table extraction settings
+ADVANCED_TABLE_SETTINGS = {
+    "vertical_strategy": "lines",
+    "horizontal_strategy": "lines",
+    "snap_tolerance": 4,
+    "join_tolerance": 10,
+    "edge_min_length": 3,
+    "min_words_vertical": 2,
+    "min_words_horizontal": 1
+}
+
+
+def main():
+    # Instantiate the PDF processor
+    processor = PDFBatchProcessor(output_dir="./pdf_output")
+
+    try:
+        # Find the PDF files
+        pdf_files = processor.find_pdf_files("./pdf_documents")
+
+        if not pdf_files:
+            processor.logger.warning("未找到PDF文件")
+            return
+
+        # Batch processing
+        results_df = processor.process_batch(
+            pdf_files,
+            save_format="excel",
+            extract_text=True,
+            extract_tables=True,
+            table_settings=ADVANCED_TABLE_SETTINGS
+        )
+
+        # Print summary statistics
+        success_count = len(results_df[results_df["status"] == "Success"])
+        processor.logger.info(f"处理完成: {success_count}/{len(pdf_files)} 个文件成功")
+
+        if success_count > 0:
+            avg_text_length = results_df[results_df["status"] == "Success"]["text_length"].mean()
+            avg_tables = results_df[results_df["status"] == "Success"]["table_count"].mean()
+            processor.logger.info(f"平均每文件: {avg_text_length:.0f} 字符, {avg_tables:.1f} 个表格")
+
+    # Log any error raised during processing
+    except Exception as e:
+        processor.logger.error(f"处理过程发生错误: {str(e)}")
+
+
+if __name__ == "__main__":
+    main()
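
After a run, the committed workbooks can be inspected directly. A minimal sketch, not part of the commit, assuming the default ./pdf_output directory:

import pandas as pd

summary = pd.read_excel("./pdf_output/pdf_extraction_summary.xlsx")
print(summary[["file_name", "status", "page_count", "table_count"]])
print("success rate:", (summary["status"] == "Success").mean())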
test.py
ADDED
@@ -0,0 +1,21 @@
+import requests
+import time
+import json
+
+url = "http://0.0.0.0:8103/"
+data = {"question": "平日里蜂蜜加白醋一起喝有什么疗效?"}  # "What are the effects of regularly drinking honey mixed with white vinegar?"
+#data = {"question": "听说用酸枣仁泡水喝能养生,是真的吗?"}  # "Is it true that water steeped with sour jujube seeds is good for your health?"
+
+start_time = time.time()
+
+data = json.dumps(data)
+
+# Send the request to the service
+res = requests.post(url, data=data)
+
+cost_time = time.time() - start_time
+
+print('单次查询的耗时:', cost_time, 's')
+
+res = json.loads(res.text)
+print(res)
vector.py
ADDED
@@ -0,0 +1,262 @@
+import os
+from pydantic import BaseModel
+from tqdm import tqdm
+import json
+import uuid
+import time
+import pandas as pd
+from openai import OpenAI
+from langchain.embeddings.base import Embeddings
+from langchain_core.documents import Document
+from langchain_milvus import Milvus, BM25BuiltInFunction
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever
+from langchain_core.stores import InMemoryStore
+from dotenv import load_dotenv
+
+# Load environment variables from the .env file to keep API keys out of the source
+load_dotenv()
+
+
+# ============================================================
+# Embedding model, OpenAI text-embedding-3-small
+# ============================================================
+
+class OpenAIEmbeddings(Embeddings):
+    """Custom embedding class backed by the OpenAI Embedding API"""
+
+    def __init__(self):
+        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+
+    def embed_documents(self, texts):
+        embeddings = []
+        for text in texts:
+            response = self.client.embeddings.create(
+                model="text-embedding-3-small",
+                input=[text],
+            )
+            embeddings.append(response.data[0].embedding)
+        return embeddings
+
+    def embed_query(self, text):
+        # Embed a single query string
+        return self.embed_documents([text])[0]
+
+
+# ============================================================
+# Milvus vector store wrapper (recall route 1: JSONL text data)
+# ============================================================
+
+class Milvus_vector():
+    def __init__(self, uri="./milvus_agent.db"):
+        self.URI = uri
+        self.embeddings = OpenAIEmbeddings()
+
+        # Index definitions
+        self.dense_index = {
+            "metric_type": "IP",
+            "index_type": "IVF_FLAT",
+        }
+        self.sparse_index = {
+            "metric_type": "BM25",
+            "index_type": "SPARSE_INVERTED_INDEX"
+        }
+
+    def create_vector_store(self, docs):
+        init_docs = docs[:10]
+        self.vectorstore = Milvus.from_documents(
+            documents=init_docs,
+            embedding=self.embeddings,
+            builtin_function=BM25BuiltInFunction(),  # output_field_names="sparse",
+            index_params=[self.dense_index, self.sparse_index],
+            vector_field=["dense", "sparse"],
+            connection_args={
+                "uri": self.URI,
+            },
+            # supported levels: ("Strong", "Session", "Bounded", "Eventually")
+            consistency_level="Bounded",
+            drop_old=False,
+        )
+        print("已初始化创建 Milvus ‼")
+
+        count = 10
+        temp = []
+        for doc in tqdm(docs[10:]):
+            temp.append(doc)
+            if len(temp) >= 5:
+                # add_documents (synchronous); aadd_documents is the async variant and would return an un-awaited coroutine here
+                self.vectorstore.add_documents(temp)
+                count += len(temp)
+                temp = []
+                print(f"已插入 {count} 条数据......")
+                time.sleep(1)
+        # Flush any remaining documents smaller than one batch
+        if temp:
+            self.vectorstore.add_documents(temp)
+            count += len(temp)
+
+        print(f"总共插入 {count} 条数据......")
+        print("已创建 Milvus 索引完成 ‼")
+
+        return self.vectorstore
+
+
+# ============================================================
+# PDF parent-child document retriever (recall route 2: PDF document data)
+# ============================================================
+
+class Pdf_retriever():
+    def __init__(self, uri="./pdf_agent.db"):
+        self.URI = uri
+        self.embeddings = OpenAIEmbeddings()
+
+        # Index definitions
+        self.dense_index = {
+            "metric_type": "IP",
+            "index_type": "IVF_FLAT",
+        }
+        self.sparse_index = {
+            "metric_type": "BM25",
+            "index_type": "SPARSE_INVERTED_INDEX"
+        }
+
+        self.docstore = InMemoryStore()
+
+        # Text splitters
+        self.child_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=200,
+            chunk_overlap=50,
+            length_function=len,
+            separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""]
+        )
+        self.parent_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=1000,
+            chunk_overlap=200
+        )
+
+    def create_pdf_vector_store(self, docs):
+        self.milvus_vectorstore = Milvus(
+            embedding_function=self.embeddings,
+            builtin_function=BM25BuiltInFunction(),
+            vector_field=["dense", "sparse"],
+            index_params=[
+                {
+                    "metric_type": "IP",
+                    "index_type": "IVF_FLAT",
+                },
+                {
+                    "metric_type": "BM25",
+                    "index_type": "SPARSE_INVERTED_INDEX"
+                }
+            ],
+            connection_args={"uri": self.URI},
+            consistency_level="Bounded",
+            drop_old=False,
+        )
+
+        # Set up the parent-child document retriever
+        self.retriever = ParentDocumentRetriever(
+            vectorstore=self.milvus_vectorstore,
+            docstore=self.docstore,
+            child_splitter=self.child_splitter,
+            parent_splitter=self.parent_splitter,
+        )
+
+        # Add the documents
+        count = 0
+        temp = []
+        for doc in tqdm(docs):
+            temp.append(doc)
+            if len(temp) >= 10:
+                # ParentDocumentRetriever() has no awaitable add here; use the synchronous add_documents
+                self.retriever.add_documents(temp)
+                count += len(temp)
+                temp = []
+                print(f"已插入 {count} 条数据......")
+                time.sleep(1)
+        # Flush any remaining documents smaller than one batch
+        if temp:
+            self.retriever.add_documents(temp)
+            count += len(temp)
+
+        print(f"总共插入 {count} 条数据......")
+        print("基于PDF文档数据的 Milvus 索引完成 ‼")
+
+        return self.retriever
+
+
+# ============================================================
+# Data preprocessing: load documents from the JSONL file (route 1)
+# ============================================================
+
+def prepare_document(file_path=['./data/dialog.jsonl', './data/train.jsonl']):
+    # Read the text data record by record; only the first file is loaded, and the documents are embedded when inserted into Milvus
+    file_path1 = file_path[0]
+
+    count = 0
+    docs = []
+
+    with open(file_path1, 'r', encoding='utf-8') as f:
+        for line in f:
+            content = json.loads(line.strip())
+            prompt = content['query'] + "\n" + content['response']
+
+            temp_doc = Document(page_content=prompt, metadata={"doc_id": str(uuid.uuid4())})
+            docs.append(temp_doc)
+
+            count += 1
+
+    print(f"已加载 {count} 条数据!")
+
+    return docs
+
+
+# ============================================================
+# Data preprocessing: load documents from the PDF extraction results (route 2)
+# ============================================================
+
+def prepare_pdf_document(file_path="./pdf_output/pdf_detailed_text.xlsx"):
+    df = pd.read_excel(file_path)
+
+    # Drop empty rows outright, otherwise later processing raises errors
+    df = df.dropna(subset=['text_content'])
+
+    # Convert the DataFrame into LangChain documents
+    documents = []
+    for _, row in df.iterrows():
+        # Make sure text_content is a string and not NaN
+        text_content = str(row['text_content']) if pd.notna(row['text_content']) else ""
+
+        doc = Document(
+            page_content=text_content.strip(),
+            metadata={"doc_id": str(uuid.uuid4())}
+        )
+        documents.append(doc)
+
+    print(f"成功加载 {len(documents)} 个文档")
+
+    return documents
+
+
+# ============================================================
+# Main entry point: run the data ingestion pipeline
+# ============================================================
+
+if __name__ == "__main__":
+    '''
+    # Preprocess the documents to be inserted into Milvus
+    docs = prepare_document()
+    print("预处理文档数据成功......")
+
+    # Create the Milvus connection
+    milvus_vectorstore = Milvus_vector()
+    print("创建Milvus连接成功......")
+
+    # Build the vector index
+    vectorstore = milvus_vectorstore.create_vector_store(docs)
+    print("全部初始化完成, 可以开始问答了......")
+    '''
+
+    # Wrap the post-processed PDF data into Document objects
+    docs = prepare_pdf_document()
+    print("预处理 PDF 文档数据成功......")
+    # print(docs[0])
+
+    pdf_vectorstore = Pdf_retriever()
+    print("创建 PDF Milvus 连接成功......")
+
+    retriever = pdf_vectorstore.create_pdf_vector_store(docs)
+    print("创建基于 Milvus 数据库的父子文档检索器成功......")
+    print(retriever)
+    print("全部初始化完成, 可以开始问答了......")
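
To preview what the parent/child split will produce before a full ingestion run, the splitters can be exercised in isolation. A minimal sketch, not part of the commit, using the same settings as Pdf_retriever on a synthetic passage:

from langchain_text_splitters import RecursiveCharacterTextSplitter

child_splitter = RecursiveCharacterTextSplitter(
    chunk_size=200,
    chunk_overlap=50,
    separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""],
)
parent_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)

text = "高血压的诊断需要结合多次诊室血压测量。" * 200  # synthetic passage, ~3800 characters
parents = parent_splitter.split_text(text)
children = child_splitter.split_text(parents[0])
print(len(parents), "parent chunks;", len(children), "child chunks from the first parent")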