drewli20200316 committed on
Commit
635abe8
·
verified ·
1 Parent(s): 528ba5d

Upload vector2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. vector2.py +307 -0
vector2.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pydantic import BaseModel
3
+ from tqdm import tqdm
4
+ import json
5
+ import uuid
6
+ import time
7
+ import redis
8
+ import pandas as pd
9
+ from openai import OpenAI
10
+ from langchain.embeddings.base import Embeddings
11
+ from langchain_core.documents import Document
12
+ from langchain_milvus import Milvus, BM25BuiltInFunction
13
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
14
+ from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever
15
+ from langchain_core.stores import InMemoryStore
16
+ from dotenv import load_dotenv
17
+
18
+ # 加载 .env 文件中的环境变量, 隐藏 API Keys
19
+ load_dotenv()
20
+
21
+
22
+ # ============================================================
23
+ # Redis 缓存处理模块
24
+ # ============================================================
25
+
26
def get_redis_client(host='0.0.0.0', port=6379, db=0, password=None, max_connections=10):
    """Create a Redis client backed by a connection pool and ping it once.

    The defaults reproduce the previously hard-coded settings, so existing
    zero-argument callers behave exactly as before.

    Args:
        host: Redis server host.
        port: Redis server port.
        db: Redis database index.
        password: Redis password, or ``None`` for no authentication.
        max_connections: Upper bound on connections held by the pool.

    Returns:
        The ``redis.StrictRedis`` client. It is returned even when the ping
        fails; the failure is only reported with a printed message.
    """
    pool = redis.ConnectionPool(
        host=host,
        port=port,
        db=db,
        password=password,
        max_connections=max_connections,
    )
    r = redis.StrictRedis(connection_pool=pool)

    # Probe the connection so a misconfiguration is visible right away.
    try:
        r.ping()
        print("成功连接到 Redis !")
    except redis.ConnectionError:
        print("无法连接到 Redis !")

    return r
39
+
40
+
41
+ # 将 (question, answer) 问答对, 存入 redis
42
def cache_set(r, question: str, answer: str):
    """Store one question/answer pair in the Redis hash ``"qa"``.

    Args:
        r: Client object exposing ``hset`` and ``expire``.
        question: Hash field under which the answer is stored.
        answer: Value to cache.

    Note:
        The expiry is set on the whole ``"qa"`` key, not on the single field,
        so every write restarts the 3600-second countdown for all entries.
    """
    hash_key = "qa"
    ttl_seconds = 3600

    r.hset(hash_key, question, answer)
    r.expire(hash_key, ttl_seconds)
45
+
46
+
47
+ # 通过 question, 读取存在 redis 中的 answer
48
def cache_get(r, question: str):
    """Look up the cached answer for ``question`` in the Redis hash ``"qa"``.

    Args:
        r: Client object exposing ``hget``.
        question: Hash field to read.

    Returns:
        Whatever ``r.hget`` returns for that field, unmodified.
    """
    cached_answer = r.hget("qa", question)
    return cached_answer
50
+
51
+
52
+ # ============================================================
53
+ # 嵌入模型, 采用 OpenAI text-embedding-3-small
54
+ # ============================================================
55
+
56
class OpenAIEmbeddings(Embeddings):
    """Custom embedding class backed by the OpenAI Embedding API."""

    def __init__(self, model="text-embedding-3-small", batch_size=64):
        """Create the API client.

        Args:
            model: Embedding model name sent with every request.
            batch_size: Maximum number of texts sent in a single request.
                Values below 1 are treated as 1.
        """
        # The API key is read from the environment (populated via .env).
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.model = model
        self.batch_size = max(1, batch_size)

    def embed_documents(self, texts):
        """Embed a sequence of texts.

        Texts are sent in batches instead of one request per text, which
        cuts the number of HTTP round trips.

        Args:
            texts: Iterable of strings to embed.

        Returns:
            A list of embedding vectors, one per input text, in input order.
        """
        texts = list(texts)
        embeddings = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start:start + self.batch_size]
            response = self.client.embeddings.create(
                model=self.model,
                input=batch,
            )
            # Each result carries the index of its input; sort on it so the
            # output order never depends on the response ordering.
            ordered = sorted(response.data, key=lambda item: item.index)
            embeddings.extend(item.embedding for item in ordered)
        return embeddings

    def embed_query(self, text):
        """Embed a single query string and return its vector."""
        return self.embed_documents([text])[0]
75
+
76
+
77
+ # ============================================================
78
+ # Milvus 向量数据库封装类 (第一路召回: JSONL 文本数据)
79
+ # ============================================================
80
+
81
class Milvus_vector():
    """Milvus vector-store wrapper for the first recall path (JSONL text data)."""

    def __init__(self, uri="./milvus_agent.db", collection_name="LangChainCollection"):
        """Remember the connection settings and build the embedding client.

        Args:
            uri: Milvus connection URI (local file path or server URL).
            collection_name: Name of the collection to write to.
        """
        self.URI = uri
        self.collection_name = collection_name
        self.embeddings = OpenAIEmbeddings()

        # Index definitions: one for the dense vectors, one for BM25 sparse.
        self.dense_index = {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
        }
        self.sparse_index = {
            "metric_type": "BM25",
            "index_type": "SPARSE_INVERTED_INDEX"
        }

    def create_vector_store(self, docs):
        """Create the collection from the first documents, then insert the rest.

        The first 10 documents initialise the store through
        ``Milvus.from_documents``; the remaining ones are inserted in batches
        of 5 with a one-second pause after each full batch.

        Args:
            docs: List of LangChain ``Document`` objects.

        Returns:
            The populated ``Milvus`` vector store (also kept on
            ``self.vectorstore``).
        """
        init_docs = docs[:10]
        self.vectorstore = Milvus.from_documents(
            documents=init_docs,
            embedding=self.embeddings,
            builtin_function=BM25BuiltInFunction(),  # output_field_names="sparse",
            index_params=[self.dense_index, self.sparse_index],
            vector_field=["dense", "sparse"],
            connection_args={
                "uri": self.URI,
            },
            collection_name=self.collection_name,
            # Supported levels: ("Strong", "Session", "Bounded", "Eventually")
            consistency_level="Bounded",
            drop_old=False,
        )
        print("已初始化创建 Milvus ‼")

        # Count what was really inserted; docs may hold fewer than 10 items.
        count = len(init_docs)
        pending = []
        for doc in tqdm(docs[10:]):
            pending.append(doc)
            if len(pending) >= 5:
                # Use the synchronous API: `aadd_documents` returns a
                # coroutine, which does nothing unless it is awaited.
                self.vectorstore.add_documents(pending)
                count += len(pending)
                pending = []
                print(f"已插入 {count} 条数据......")
                time.sleep(1)

        # Flush the final partial batch so the last few documents are kept.
        if pending:
            self.vectorstore.add_documents(pending)
            count += len(pending)
            print(f"已插入 {count} 条数据......")

        print(f"总共插入 {count} 条数据......")
        print("已创建 Milvus 索引完成 ‼")

        return self.vectorstore
130
+
131
+
132
+ # ============================================================
133
+ # PDF 父子文档检索器 (第二路召回: PDF 文档数据)
134
+ # ============================================================
135
+
136
class Pdf_retriever():
    """Parent/child document retriever for the second recall path (PDF data)."""

    def __init__(self, uri="./pdf_agent.db", collection_name="LangChainCollection"):
        """Remember the connection settings and build the splitters and stores.

        Args:
            uri: Milvus connection URI (local file path or server URL).
            collection_name: Name of the collection to write to.
        """
        self.URI = uri
        self.collection_name = collection_name
        self.embeddings = OpenAIEmbeddings()

        # Index definitions: one for the dense vectors, one for BM25 sparse.
        self.dense_index = {
            "metric_type": "IP",
            "index_type": "IVF_FLAT",
        }
        self.sparse_index = {
            "metric_type": "BM25",
            "index_type": "SPARSE_INVERTED_INDEX"
        }

        # Parent documents are kept in this in-memory store.
        # NOTE(review): an in-memory store presumably does not outlive this
        # process - confirm how the serving side obtains parent documents.
        self.docstore = InMemoryStore()

        # Text splitters: small child chunks, larger parent chunks.
        self.child_splitter = RecursiveCharacterTextSplitter(
            chunk_size=200,
            chunk_overlap=50,
            length_function=len,
            separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""]
        )
        self.parent_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )

    def create_pdf_vector_store(self, docs):
        """Build the Milvus-backed parent/child retriever and add all documents.

        Documents are added in batches of 10 with a one-second pause after
        each full batch; the final partial batch is added as well.

        Args:
            docs: List of LangChain ``Document`` objects.

        Returns:
            The ``ParentDocumentRetriever`` (also kept on ``self.retriever``).
        """
        self.milvus_vectorstore = Milvus(
            embedding_function=self.embeddings,
            builtin_function=BM25BuiltInFunction(),
            vector_field=["dense", "sparse"],
            # Reuse the index definitions from __init__ (same values as the
            # inline copies that were here before).
            index_params=[self.dense_index, self.sparse_index],
            connection_args={"uri": self.URI},
            collection_name=self.collection_name,
            consistency_level="Bounded",
            drop_old=False,
        )

        # Set up the parent/child document retriever.
        self.retriever = ParentDocumentRetriever(
            vectorstore=self.milvus_vectorstore,
            docstore=self.docstore,
            child_splitter=self.child_splitter,
            parent_splitter=self.parent_splitter,
        )

        # Add the documents batch by batch.
        count = 0
        pending = []
        for doc in tqdm(docs):
            pending.append(doc)
            if len(pending) >= 10:
                # The synchronous API is used on purpose (per the original
                # author's note, async add is not supported here).
                self.retriever.add_documents(pending)
                count += len(pending)
                pending = []
                print(f"已插入 {count} 条数据......")
                time.sleep(1)

        # Flush the final partial batch so the last few documents are kept.
        if pending:
            self.retriever.add_documents(pending)
            count += len(pending)
            print(f"已插入 {count} 条数据......")

        print(f"总共插入 {count} 条数据......")
        print("基于PDF文档数据的 Milvus 索引完成 ‼")

        return self.retriever
212
+
213
+
214
+ # ============================================================
215
+ # 数据预处理: 从 JSONL 文件加载文档 (第一路)
216
+ # ============================================================
217
+
218
def prepare_document(file_path=('./data/dialog.jsonl', './data/train.jsonl')):
    """Load question/answer documents from a JSONL file.

    Only the first path of ``file_path`` is read. Each non-blank line must be
    a JSON object with ``query`` and ``response`` keys; the two values are
    joined with a newline to form the document text.

    Args:
        file_path: A sequence of paths (only the first one is used) or a
            single path given as a string / ``os.PathLike``.

    Returns:
        A list of ``Document`` objects, each tagged with a fresh UUID in
        ``metadata["doc_id"]``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``query`` or ``response``.
    """
    # Accept a bare path too; indexing a string would yield its first character.
    if isinstance(file_path, (str, os.PathLike)):
        first_path = file_path
    else:
        first_path = file_path[0]

    docs = []
    with open(first_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing newline) are not records; skip them
            # instead of letting json.loads fail on an empty string.
            if not line:
                continue

            content = json.loads(line)
            prompt = content['query'] + "\n" + content['response']
            docs.append(
                Document(page_content=prompt, metadata={"doc_id": str(uuid.uuid4())})
            )

    print(f"已加载 {len(docs)} 条数据!")

    return docs
238
+
239
+
240
+ # ============================================================
241
+ # 数据预处理: 从 PDF 提取结果加载文档 (第二路)
242
+ # ============================================================
243
+
244
def prepare_pdf_document(file_path="./pdf_output/pdf_detailed_text.xlsx"):
    """Load extracted PDF text from an Excel sheet into documents.

    Rows whose ``text_content`` cell is missing, or is empty after stripping
    whitespace, are skipped so that no empty document reaches the later
    splitting and embedding steps.

    Args:
        file_path: Path to the Excel file; it must have a ``text_content``
            column.

    Returns:
        A list of ``Document`` objects, each tagged with a fresh UUID in
        ``metadata["doc_id"]``.
    """
    df = pd.read_excel(file_path)

    # Drop rows without text; they would break the later processing steps.
    df = df.dropna(subset=['text_content'])

    # Convert the remaining rows to LangChain documents.
    documents = []
    for raw_text in df['text_content']:
        # Cells may hold non-string values (e.g. numbers); normalise to str.
        text_content = str(raw_text).strip()
        if not text_content:
            # Whitespace-only cells survive dropna but carry no content.
            continue

        documents.append(
            Document(
                page_content=text_content,
                metadata={"doc_id": str(uuid.uuid4())}
            )
        )

    print(f"成功加载 {len(documents)} 个文档")

    return documents
265
+
266
+
267
+ # ============================================================
268
+ # 主入口: 执行数据入库流程
269
+ # ============================================================
270
+
271
if __name__ == "__main__":
    # Entry point: load both data sources into a Milvus server
    # (agent6 multi-worker mode).
    #
    # The collection names must match the ones used in agent6.py:
    #   medical_agent -> first path, JSONL medical Q&A
    #   medical_pdf   -> second path, PDF documents

    MILVUS_SERVER_URI = os.getenv("MILVUS_URI", "http://localhost:19530")

    # --- First path: JSONL data -> medical_agent ---
    jsonl_docs = prepare_document()
    print("预处理文档数据成功......")

    jsonl_store = Milvus_vector(uri=MILVUS_SERVER_URI, collection_name="medical_agent")
    print("创建 Milvus 连接成功......")

    vectorstore = jsonl_store.create_vector_store(jsonl_docs)
    print("第一路 (JSONL) 数据灌入完成 ✅")

    # --- Second path: PDF data -> medical_pdf ---
    pdf_docs = prepare_pdf_document()
    print("预处理 PDF 文档数据成功......")

    pdf_store = Pdf_retriever(uri=MILVUS_SERVER_URI, collection_name="medical_pdf")
    print("创建 PDF Milvus 连接成功......")

    retriever = pdf_store.create_pdf_vector_store(pdf_docs)
    print("第二路 (PDF) 数据灌入完成 ✅")

    print("全部数据灌入 Milvus Server 完成, 可以启动 agent6.py 了 ✅")