drewli20200316 committed on
Commit
528ba5d
·
verified ·
1 Parent(s): f0531e2

Upload agent6.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. agent6.py +356 -0
agent6.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ================================================================
3
+ agent6.py — 多 Worker 版 Medical RAG Agent
4
+ ================================================================
5
+ 基于 agent5.py, 新增 P2 优化:
6
+ P0: 三路召回并行化 (asyncio.gather) ← 继承 agent5
7
+ P1: AsyncOpenAI 客户端 (async LLM 推理) ← 继承 agent5
8
+ P2: Milvus Lite → Milvus Server + workers=4 ← 新增
9
+
10
+ 架构变化:
11
+ agent5.py: 单 worker + async (Milvus Lite 文件锁限制)
12
+ agent6.py: 4 workers × async (Milvus Server 网络连接, 无文件锁)
13
+
14
+ Worker 1 ──→ ┐
15
+ Worker 2 ──→ ├── Milvus Server (:19530) ──→ 数据持久化
16
+ Worker 3 ──→ ┤
17
+ Worker 4 ──→ ┘
18
+
19
+ 前置条件:
20
+ 1. 安装并启动 Milvus Server (Docker):
21
+ docker run -d --name milvus-standalone \
22
+ -p 19530:19530 -p 9091:9091 \
23
+ milvusdb/milvus:latest
24
+
25
+ 2. 将已有数据从 Milvus Lite 迁移到 Milvus Server:
26
+ 参考: https://milvus.io/docs/migrate_overview.md
27
+
28
+ 3. .env 中配置 (可选, 有默认值):
29
+ MILVUS_URI=http://localhost:19530
30
+
31
+ 运行:
32
+ python agent6.py
33
+ # Uvicorn running on http://0.0.0.0:8103 (4 workers)
34
+ ================================================================
35
+ """
36
+
37
+ import os
38
+ import uvicorn
39
+ import asyncio
40
+ from fastapi import FastAPI, Request
41
+ from fastapi.middleware.cors import CORSMiddleware
42
+ import json
43
+ import datetime
44
+ import hashlib
45
+ import logging
46
+
47
+ import httpx
48
+ from openai import AsyncOpenAI
49
+ from neo4j import GraphDatabase
50
+ from langchain_milvus import Milvus, BM25BuiltInFunction
51
+ from vector import OpenAIEmbeddings
52
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
53
+ from langchain_core.stores import InMemoryStore
54
+ from langchain_classic.retrievers.parent_document_retriever import ParentDocumentRetriever
55
+ from dotenv import load_dotenv
56
+
57
+ from new_redis import redis_manager
58
+
59
# Pull settings from a local .env file into the process environment.
load_dotenv()

# Turn off tokenizers parallelism through its environment switch.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logger = logging.getLogger("agent6")

app = FastAPI()

# CORS is fully permissive: any origin, method and header, credentials allowed.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
73
+
74
+
75
# ============================================================
# Global resources (every worker process builds its own copy)
# ============================================================

embedding_model = OpenAIEmbeddings()
print("创建 Embedding 模型成功......")

# ============================================================
# P2: Milvus Lite -> Milvus Server
# ============================================================
# agent4/agent5: URI = "./milvus_agent.db"       (local file, owned by one process)
# agent6:        URI = "http://localhost:19530"  (network connection, shared by processes)
#
# Milvus Server runs as a separate process and serves requests on gRPC port 19530.
# Each of the 4 workers opens its own network connection, so there is no
# file lock to contend for.

MILVUS_URI = os.environ.get("MILVUS_URI", "http://localhost:19530")

# Hybrid store: a dense embedding field plus a BM25 sparse field, one index each.
milvus_vectorstore = Milvus(
    collection_name="medical_agent",  # collection name is given explicitly
    connection_args={"uri": MILVUS_URI},
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {"metric_type": "IP", "index_type": "IVF_FLAT"},
        {"metric_type": "BM25", "index_type": "SPARSE_INVERTED_INDEX"},
    ],
)
print(f"创建 Milvus 连接成功...... (URI: {MILVUS_URI})")
105
+
106
# Parent documents are kept in process memory.
docstore = InMemoryStore()

# Child chunks: 200 characters with 50 overlap, split on the listed separators.
child_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", "。", "!", "?", ";", ",", " ", ""],
    chunk_size=200,
    chunk_overlap=50,
    length_function=len,
)
# Parent chunks: 1000 characters with 200 overlap, default separators.
parent_splitter = RecursiveCharacterTextSplitter(chunk_overlap=200, chunk_size=1000)

# Vector store for the PDF collection, with the same dense + BM25 layout.
pdf_vectorstore = Milvus(
    collection_name="medical_pdf",  # collection name is given explicitly
    connection_args={"uri": MILVUS_URI},
    consistency_level="Bounded",
    drop_old=False,
    embedding_function=embedding_model,
    builtin_function=BM25BuiltInFunction(),
    vector_field=["dense", "sparse"],
    index_params=[
        {"metric_type": "IP", "index_type": "IVF_FLAT"},
        {"metric_type": "BM25", "index_type": "SPARSE_INVERTED_INDEX"},
    ],
)

# Retriever wired to the PDF vector store, the in-memory docstore and both splitters.
parent_retriever = ParentDocumentRetriever(
    docstore=docstore,
    vectorstore=pdf_vectorstore,
    parent_splitter=parent_splitter,
    child_splitter=child_splitter,
)
print("创建 Parent Milvus 连接成功......")
135
+
136
# Neo4j connection settings come from the environment, with local defaults.
neo4j_uri = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
neo4j_user = os.environ.get("NEO4J_USER", "neo4j")
neo4j_password = os.environ.get("NEO4J_PASSWORD", "neo4j")
driver = GraphDatabase.driver(
    uri=neo4j_uri,
    auth=(neo4j_user, neo4j_password),
    max_connection_lifetime=1000,
)
print("创建 Neo4j 连接成功......")

# P1: asynchronous OpenAI client
async_openai_client = AsyncOpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
print("创建 AsyncOpenAI LLM 成功......")

# The Cypher API is called through an httpx.AsyncClient (30 second timeout).
cypher_http_client = httpx.AsyncClient(timeout=30.0)

# NOTE(review): no Redis call is made here; redis_manager comes from the
# new_redis import, so this line presumably just reports that import — confirm.
print("创建 Redis 连接成功......")
152
+
153
+
154
+ # ============================================================
155
+ # P0: 三路召回 — 各自独立的 async 函数
156
+ # ============================================================
157
+
158
def format_docs(docs):
    """Concatenate the ``page_content`` of every document.

    Args:
        docs: Iterable of objects that expose a ``page_content`` string.

    Returns:
        The contents in input order, separated by one blank line.
        An empty iterable yields an empty string.
    """
    contents = [document.page_content for document in docs]
    return "\n\n".join(contents)
160
+
161
+
162
async def retrieve_milvus(query: str) -> str:
    """Recall path 1: search the ``milvus_vectorstore`` collection.

    The blocking ``similarity_search`` call runs in a worker thread through
    ``asyncio.to_thread``. It asks for 10 hits using the "rrf" ranker with
    ``k=100``.

    Args:
        query: The user question used as the search text.

    Returns:
        The hits formatted by ``format_docs``. Returns "" when nothing is
        found or when any exception occurs; the exception is logged as a
        warning, not raised.
    """
    try:
        hits = await asyncio.to_thread(
            milvus_vectorstore.similarity_search,
            query,
            k=10,
            ranker_type="rrf",
            ranker_params={"k": 100},
        )
        if not hits:
            return ""
        return format_docs(hits)
    except Exception as exc:
        # Best effort: a failed recall path must not break the whole request.
        logger.warning(f"Milvus 召回失败: {exc}")
        return ""
173
+
174
+
175
async def retrieve_pdf(query: str) -> str:
    """Recall path 2: parent/child document retrieval over the PDF collection.

    ``parent_retriever.invoke`` is blocking, so it runs in a worker thread.
    Only the first returned document is used.

    Args:
        query: The user question used as the search text.

    Returns:
        The ``page_content`` of the first document. Returns "" when no
        document comes back or when any exception occurs; the exception is
        logged as a warning.
    """
    try:
        parents = await asyncio.to_thread(parent_retriever.invoke, query)
        return parents[0].page_content if parents else ""
    except Exception as exc:
        # Best effort: a failed recall path must not break the whole request.
        logger.warning(f"PDF 检索失败: {exc}")
        return ""
185
+
186
+
187
async def retrieve_neo4j(query: str) -> str:
    """Recall path 3: turn the question into Cypher and run it on Neo4j.

    Steps:
      1. POST the question to the ``/generate`` endpoint. Continue only on
         HTTP 200 with a non-empty ``cypher_query``, ``confidence`` >= 0.9
         and a truthy ``validated`` flag.
      2. POST the generated query to the ``/validate`` endpoint. Continue
         only on HTTP 200 with a truthy ``is_valid``.
      3. Run the query in a worker thread and join the first column of
         every record with commas.

    Args:
        query: The user question in natural language.

    Returns:
        Comma-joined first-column values, or "" when any gate fails or any
        exception occurs; the exception is logged as a warning.
    """
    try:
        payload = json.dumps({"natural_language_query": query})
        resp = await cypher_http_client.post("http://0.0.0.0:8101/generate", content=payload)

        if resp.status_code != 200:
            return ""

        data = resp.json()
        cypher_query = data.get("cypher_query")
        # A JSON null confidence is treated as 0 instead of crashing float().
        confidence = float(data.get("confidence") or 0)
        is_valid = data.get("validated", False)

        if not cypher_query or confidence < 0.9 or not is_valid:
            return ""

        print("neo4j Cypher 初步生成成功 !!!")

        val_payload = json.dumps({"cypher_query": cypher_query})
        val_resp = await cypher_http_client.post("http://0.0.0.0:8101/validate", content=val_payload)

        if val_resp.status_code != 200 or not val_resp.json().get("is_valid"):
            return ""

        def _run_neo4j():
            # NOTE(review): the query text comes from the generation service
            # and is executed as-is; consider a read-only session/transaction.
            with driver.session() as session:
                # Read all records while the session is still open.
                first_columns = [record[0] for record in session.run(cypher_query)]
            # The first column may hold numbers, nodes or nulls. str.join only
            # accepts strings, so convert each value and skip nulls; otherwise
            # the TypeError would discard the whole result.
            return ",".join(str(value) for value in first_columns if value is not None)

        return await asyncio.to_thread(_run_neo4j)

    except Exception as e:
        # Best effort: a failed recall path must not break the whole request.
        logger.warning(f"Neo4j 召回失败: {e}")
        return ""
223
+
224
+
225
+ # ============================================================
226
+ # P0 + P1: 异步并行 RAG + 异步 LLM 推理
227
+ # ============================================================
228
+
229
async def perform_rag_and_llm_async(query: str) -> str:
    """Run the three recall paths concurrently, then ask the LLM.

    The Milvus, PDF and Neo4j retrievals are awaited together with
    ``asyncio.gather``. Their non-empty results are joined with newlines
    into one context. The context and the question go to ``gpt-4o-mini``
    as a single user message.

    Args:
        query: The user question.

    Returns:
        The content of the first completion choice.
    """
    recall_tasks = (
        retrieve_milvus(query),
        retrieve_pdf(query),
        retrieve_neo4j(query),
    )
    recalled = await asyncio.gather(*recall_tasks)

    # Empty recall results are dropped before joining.
    context = "\n".join(piece for piece in recalled if piece)

    SYSTEM_PROMPT = """
System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
"""

    USER_PROMPT = f"""
User: 利用介于<context>和</context>之间的从数据库中检索出的信息来回答问题, 具体的问题介于<question>和</question>之间. 如果提供的信息为空, 则按照你的经验知识来给出尽可能严谨准确的回答, 不知道的时候坦诚的承认不了解, 不要编造不真实的信息.
<context>
{context}
</context>

<question>
{query}
</question>
"""

    # Both prompt parts are sent together as one "user" message.
    full_prompt = SYSTEM_PROMPT + USER_PROMPT
    completion = await async_openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": full_prompt}],
        temperature=0.7,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
262
+
263
+
264
+ # ============================================================
265
+ # Redis 缓存 + 异步 RAG 的衔接
266
+ # ============================================================
267
+
268
async def get_or_compute_async(question: str) -> str:
    """Return the cached answer, or compute it under a per-question lock.

    Flow:
      1. Look the question up through ``redis_manager.get_answer``.
      2. On a miss, try to acquire a lock keyed by the MD5 of the question.
      3. Lock holder: check the cache again, then run the async RAG + LLM
         pipeline. A truthy answer is stored with ``set_answer``. A falsy
         one is stored as the "<EMPTY>" placeholder for 60 seconds. The
         lock is always released.
      4. Everyone else: wait 0.1 s, check the cache once more and fall back
         to a "busy" message.

    All ``redis_manager`` calls are blocking and run in worker threads.

    Args:
        question: The user question; also the cache key source.

    Returns:
        The cached or freshly computed answer, or the busy message.
    """
    hit = await asyncio.to_thread(redis_manager.get_answer, question)
    if hit:
        print("REDIS HIT !!!✅😊")
        return hit

    hash_key = hashlib.md5(question.encode("utf-8")).hexdigest()
    lock_token = await asyncio.to_thread(redis_manager.acquire_lock, hash_key)

    if not lock_token:
        # Another caller holds the lock: wait briefly, then check once.
        # NOTE(review): a single 0.1 s wait is likely shorter than a full
        # RAG + LLM run, so this path mostly returns the busy message.
        await asyncio.sleep(0.1)
        late_hit = await asyncio.to_thread(redis_manager.get_answer, question)
        return late_hit or "System busy, calculating..."

    try:
        # Second lookup: the answer may have been stored while we waited for the lock.
        second_hit = await asyncio.to_thread(redis_manager.get_answer, question)
        if second_hit:
            print("REDIS HIT (Double Check) !!!✅😊")
            return second_hit

        print("Cache Miss ❌, Computing async RAG + LLM...")
        answer = await perform_rag_and_llm_async(question)

        if not answer:
            # Store a short-lived placeholder for empty answers.
            # NOTE(review): whether get_answer hides "<EMPTY>" is not visible here.
            empty_key = redis_manager._generate_key(question)
            await asyncio.to_thread(redis_manager.client.setex, empty_key, 60, "<EMPTY>")
        else:
            await asyncio.to_thread(redis_manager.set_answer, question, answer)

        return answer
    finally:
        await asyncio.to_thread(redis_manager.release_lock, hash_key, lock_token)
304
+
305
+
306
+ # ============================================================
307
+ # FastAPI 路由
308
+ # ============================================================
309
+
310
@app.post("/")
async def chatbot(request: Request):
    """Answer a medical question posted as JSON.

    The body must be a JSON object with a non-empty string ``question``.
    A body that is itself a JSON string is decoded a second time before
    the lookup.

    Args:
        request: The incoming request whose body carries the question.

    Returns:
        On success: ``{"response": <answer>, "status": 200, "time": <local
        time as YYYY-MM-DD HH:MM:SS>}``.
        On a bad request (malformed JSON, non-object body, missing or
        non-string question): ``{"status": 400, "error": <message>}``.
        On any other failure: ``{"status": 500, "error": <exception text>}``.
    """
    try:
        try:
            payload = await request.json()
            if isinstance(payload, str):
                payload = json.loads(payload)
        except ValueError:
            # JSONDecodeError and UnicodeDecodeError both derive from
            # ValueError: the client sent a body that is not valid JSON.
            return {"status": 400, "error": "Request body must be valid JSON"}

        # Without this check a list or number body would fail on .get() and
        # be reported as a server error.
        if not isinstance(payload, dict):
            return {"status": 400, "error": "Request body must be a JSON object"}

        query = payload.get("question")

        if not query:
            return {"status": 400, "error": "Question is required"}

        # The cache layer calls question.encode(), so only strings are valid.
        if not isinstance(query, str):
            return {"status": 400, "error": "Question must be a string"}

        response = await get_or_compute_async(query)

        time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        return {
            "response": response,
            "status": 200,
            "time": time_str,
        }

    except Exception as e:
        # Top-level boundary: report the failure in the body instead of raising.
        print(f"Server Error: {e}")
        return {"status": 500, "error": str(e)}
339
+
340
+
341
+ # ============================================================
342
+ # P2: 多 Worker 启动
343
+ # ============================================================
344
+ # Milvus Server 通过网络端口提供服务, 不再有文件锁限制,
345
+ # 4 个 worker 进程各自建立独立连接, 互不干扰.
346
+ #
347
+ # 每个 worker 内部仍然是 async (P0 + P1),
348
+ # 所以总并发能力 = 4 workers × 每 worker ~5 并发 ≈ 20 并发用户
349
+
350
if __name__ == "__main__":
    # P2: start 4 worker processes. The app is passed as the import string
    # "agent6:app" because multi-worker mode requires that form.
    uvicorn.run("agent6:app", host="0.0.0.0", port=8103, workers=4)