drewli20200316 commited on
Commit
dae033a
·
verified ·
1 Parent(s): 21497ab

Add test/test1.py

Browse files
Files changed (1) hide show
  1. test/test1.py +811 -0
test/test1.py ADDED
@@ -0,0 +1,811 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ================================================================
3
+ 医疗 RAG Agent 单元测试 — 单工具调用准确性
4
+ ================================================================
5
+ 测试对象: 项目中的 6 个核心组件 ("工具")
6
+
7
+ 工具1: Redis 缓存管理器 (new_redis.py - RedisClientWrapper)
8
+ 工具2: OpenAI Embedding (vector.py - OpenAIEmbeddings)
9
+ 工具3: Milvus 向量检索 (agent4.py - similarity_search + format_docs)
10
+ 工具4: PDF 父子文档检索 (agent4.py - parent_retriever)
11
+ 工具5: Neo4j 图数据库查询 (agent4.py - Cypher 生成 → 校验 → 执行)
12
+ 工具6: OpenAI LLM 推理 (agent4.py - generate_openai_answer)
13
+
14
+ 额外覆盖:
15
+ 工具7: PDF 批处理器 (preprocess.py - PDFBatchProcessor)
16
+ 工具8: 数据预处理 (vector.py - prepare_document)
17
+ 工具9: 端到端 RAG 流程编排 (agent4.py - perform_rag_and_llm 逻辑)
18
+
19
+ 测试原则:
20
+ ✅ 每个组件独立测试, 用 Mock/Patch 隔离外部依赖
21
+ ✅ 正常路径 + 异常路径 + 边界条件
22
+ ✅ 不需要真实的 Redis / Milvus / Neo4j / OpenAI 连接
23
+ ✅ 用 sys.modules 拦截无法安装的第三方包
24
+
25
+ 运行:
26
+ pytest test_agent_unit.py -v --tb=short
27
+ pytest test_agent_unit.py -v -k "Redis" # 只跑 Redis
28
+ pytest test_agent_unit.py -v -k "Embedding" # 只跑 Embedding
29
+ pytest test_agent_unit.py -v -k "Neo4j" # 只跑 Neo4j
30
+ ================================================================
31
+ """
32
+
33
+ import sys
34
+ import os
35
+ # 关键: 将项目根目录 (test/ 的上级) 加入 Python 搜索路径
36
+ # 这样 test/ 子目录中的测试文件才能找到 new_redis.py, vector.py, preprocess.py 等模块
37
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
38
+
39
+ import types
40
+ import pytest
41
+ import json
42
+ import hashlib
43
+ import time
44
+ import uuid
45
+ import random
46
+ from unittest.mock import MagicMock, patch, PropertyMock
47
+ from pathlib import Path
48
+ from dataclasses import dataclass, field
49
+ from typing import Optional, List
50
+
51
+
52
+ # ================================================================
53
+ # 前置: 用 sys.modules 拦截无法安装的第三方依赖
54
+ # 这样 `from vector import X` 不会因缺少 langchain_classic 而崩溃
55
+ # ================================================================
56
+
57
+ def _ensure_mock_module(name):
58
+ """如果模块不存在, 注入一个 MagicMock 占位"""
59
+ if name not in sys.modules:
60
+ sys.modules[name] = MagicMock()
61
+
62
+ # 拦截所有可能缺失的依赖
63
+ _MOCK_MODULES = [
64
+ "langchain_classic",
65
+ "langchain_classic.retrievers",
66
+ "langchain_classic.retrievers.parent_document_retriever",
67
+ "langchain_milvus",
68
+ "langchain_text_splitters",
69
+ "langchain_core",
70
+ "langchain_core.stores",
71
+ "langchain_core.documents",
72
+ "langchain.embeddings",
73
+ "langchain.embeddings.base",
74
+ "neo4j",
75
+ "dotenv",
76
+ "uvicorn",
77
+ "fastapi",
78
+ "fastapi.middleware",
79
+ "fastapi.middleware.cors",
80
+ ]
81
+
82
+ for mod in _MOCK_MODULES:
83
+ _ensure_mock_module(mod)
84
+
85
+ # 关键修复: langchain Embeddings 基类必须是真正的 class, 否则继承会失败
86
+ class _FakeEmbeddingsBase:
87
+ """占位基类, 让 OpenAIEmbeddings 能正常继承"""
88
+ pass
89
+
90
+ sys.modules["langchain.embeddings.base"].Embeddings = _FakeEmbeddingsBase
91
+
92
+
93
+ # ================================================================
94
+ # 测试辅助: 模拟对象和数据工厂
95
+ # ================================================================
96
+
97
+ class FakeRedisClient:
98
+ """内存字典模拟 Redis, 完整实现单元测试所需的全部方法"""
99
+
100
+ def __init__(self):
101
+ self._store = {}
102
+ self._expiry = {}
103
+
104
+ def ping(self):
105
+ return True
106
+
107
+ def get(self, key):
108
+ return self._store.get(key, None)
109
+
110
+ def set(self, key, value, ex=None, nx=False):
111
+ if nx and key in self._store:
112
+ return False
113
+ self._store[key] = value
114
+ if ex:
115
+ self._expiry[key] = ex
116
+ return True
117
+
118
+ def setex(self, key, expire, value):
119
+ self._store[key] = value
120
+ self._expiry[key] = expire
121
+ return True
122
+
123
+ def delete(self, key):
124
+ return 1 if self._store.pop(key, None) is not None else 0
125
+
126
+ def hset(self, name, key, value):
127
+ self._store.setdefault(name, {})[key] = value
128
+
129
+ def hget(self, name, key):
130
+ return self._store.get(name, {}).get(key, None)
131
+
132
+ def expire(self, key, seconds):
133
+ self._expiry[key] = seconds
134
+
135
+ def register_script(self, script):
136
+ """模拟 Lua 脚本: 原子 CAS 删除"""
137
+ def fake_script(keys=None, args=None):
138
+ if keys and args and self._store.get(keys[0]) == args[0]:
139
+ del self._store[keys[0]]
140
+ return 1
141
+ return 0
142
+ return fake_script
143
+
144
+
145
+ @dataclass
146
+ class FakeDocument:
147
+ """模拟 LangChain Document"""
148
+ page_content: str
149
+ metadata: dict = field(default_factory=dict)
150
+
151
+
152
+ class FakeEmbeddingResponse:
153
+ """模拟 OpenAI Embedding API 响应"""
154
+ def __init__(self, embedding):
155
+ obj = type('EmbObj', (), {'embedding': embedding})()
156
+ self.data = [obj]
157
+
158
+
159
+ class FakeChatResponse:
160
+ """模拟 OpenAI Chat Completion 响应"""
161
+ def __init__(self, content):
162
+ msg = type('Msg', (), {'content': content})()
163
+ choice = type('Choice', (), {'message': msg})()
164
+ self.choices = [choice]
165
+
166
+
167
+ # ================================================================
168
+ # 辅助: 创建被测 RedisClientWrapper 实例 (注入假 Redis)
169
+ # ================================================================
170
+
171
+ def make_redis_manager():
172
+ """构造一个使用内存假 Redis 的 RedisClientWrapper"""
173
+ from new_redis import RedisClientWrapper
174
+ RedisClientWrapper._pool = "FAKE" # 跳过连接池创建
175
+ mgr = object.__new__(RedisClientWrapper) # 跳过 __init__
176
+ mgr.client = FakeRedisClient()
177
+ mgr.unlock_script = mgr.client.register_script("")
178
+ return mgr
179
+
180
+
181
+ # ================================================================
182
+ # 工具 1: Redis 缓存管理器
183
+ # ================================================================
184
+
185
+ class TestRedisManager:
186
+ """
187
+ 测试 new_redis.py - RedisClientWrapper
188
+ 覆盖: 缓存读写 / 防穿透 / 防雪崩 / 防击穿(分布式锁) / get_or_compute
189
+ """
190
+
191
+ def setup_method(self):
192
+ self.mgr = make_redis_manager()
193
+ self.fake = self.mgr.client # 直接访问底层假 Redis
194
+
195
+ # ---- 1.1 Key 生成 ----
196
+
197
+ def test_key_deterministic(self):
198
+ """相同问题 → 相同 Key"""
199
+ k1 = self.mgr._generate_key("高血压不能吃什么?")
200
+ k2 = self.mgr._generate_key("高血压不能吃什么?")
201
+ assert k1 == k2
202
+
203
+ def test_key_unique(self):
204
+ """不同问题 → 不同 Key"""
205
+ k1 = self.mgr._generate_key("高血压不能吃什么?")
206
+ k2 = self.mgr._generate_key("糖尿病怎么治疗?")
207
+ assert k1 != k2
208
+
209
+ def test_key_has_prefix(self):
210
+ """Key 应带 'llm:cache:' 前缀"""
211
+ k = self.mgr._generate_key("test")
212
+ assert k.startswith("llm:cache:")
213
+
214
+ def test_key_is_md5(self):
215
+ """Key 后缀应为 MD5 哈希"""
216
+ q = "测试问题"
217
+ k = self.mgr._generate_key(q)
218
+ expected_hash = hashlib.md5(q.encode('utf-8')).hexdigest()
219
+ assert k == f"llm:cache:{expected_hash}"
220
+
221
+ # ---- 1.2 基础读写 ----
222
+
223
+ def test_set_then_get(self):
224
+ """写入后读取, 值一致"""
225
+ self.mgr.set_answer("Q1", "A1")
226
+ assert self.mgr.get_answer("Q1") == "A1"
227
+
228
+ def test_cache_miss_returns_none(self):
229
+ """未写入的 Key 返回 None"""
230
+ assert self.mgr.get_answer("不存在的问题") is None
231
+
232
+ # ---- 1.3 防缓存穿透 ----
233
+
234
+ def test_empty_marker_returns_none(self):
235
+ """<EMPTY> 占位符 → get_answer 返回 None (不穿透到 LLM)"""
236
+ key = self.mgr._generate_key("空结果问题")
237
+ self.fake.setex(key, 60, "<EMPTY>")
238
+ assert self.mgr.get_answer("空结果问题") is None
239
+
240
+ def test_get_or_compute_writes_empty_on_null(self):
241
+ """LLM 返回空 → 写入 <EMPTY> 防穿透"""
242
+ self.mgr.get_or_compute("空问题", lambda: "")
243
+ key = self.mgr._generate_key("空问题")
244
+ assert self.fake.get(key) == "<EMPTY>"
245
+
246
+ # ---- 1.4 防缓存雪崩 ----
247
+
248
+ def test_random_expiry_jitter(self):
249
+ """多次写入同一过期时间, 实际 TTL 应有随机抖动"""
250
+ ttls = set()
251
+ for i in range(30):
252
+ self.mgr.set_answer(f"Q_{i}", f"A_{i}", expire_time=3600)
253
+ k = self.mgr._generate_key(f"Q_{i}")
254
+ ttls.add(self.fake._expiry.get(k))
255
+ assert len(ttls) > 1, "过期时间应存在随机抖动, 防止集体失效"
256
+
257
+ # ---- 1.5 分布式锁 (防击穿) ----
258
+
259
+ def test_lock_acquire_success(self):
260
+ """正常获取锁"""
261
+ token = self.mgr.acquire_lock("my_lock", acquire_timeout=1)
262
+ assert token is not None
263
+
264
+ def test_lock_mutual_exclusion(self):
265
+ """已持有锁时, 二次获取应超时失败"""
266
+ t1 = self.mgr.acquire_lock("excl", acquire_timeout=0.1)
267
+ t2 = self.mgr.acquire_lock("excl", acquire_timeout=0.1)
268
+ assert t1 is not None
269
+ assert t2 is None, "互斥: 不应同时获取两把锁"
270
+
271
+ def test_lock_release(self):
272
+ """释放锁后, Key 被删除"""
273
+ token = self.mgr.acquire_lock("rel_lock")
274
+ assert self.mgr.release_lock("rel_lock", token) is True
275
+ assert self.fake.get("lock:rel_lock") is None
276
+
277
+ def test_lock_wrong_token_rejected(self):
278
+ """用错误 token 释放锁应失败"""
279
+ self.mgr.acquire_lock("sec_lock")
280
+ assert self.mgr.release_lock("sec_lock", "wrong-uuid") is False
281
+
282
+ # ---- 1.6 get_or_compute 完整流程 ----
283
+
284
+ def test_cache_hit_skips_compute(self):
285
+ """缓存命中 → 不调用 compute_func"""
286
+ self.mgr.set_answer("cached_q", "cached_a")
287
+ called = False
288
+
289
+ def spy():
290
+ nonlocal called; called = True; return "new"
291
+
292
+ result = self.mgr.get_or_compute("cached_q", spy)
293
+ assert result == "cached_a"
294
+ assert called is False
295
+
296
+ def test_cache_miss_calls_compute(self):
297
+ """缓存未命中 → 调用 compute_func 并缓存"""
298
+ result = self.mgr.get_or_compute("new_q", lambda: "LLM答案")
299
+ assert result == "LLM答案"
300
+ assert self.mgr.get_answer("new_q") == "LLM答案"
301
+
302
+ def test_double_check_prevents_redundant_compute(self):
303
+ """Double Check: 获取锁后再次检查, 避免重复调用 LLM"""
304
+ call_count = 0
305
+ original_get = self.mgr.get_answer
306
+
307
+ def patched_get(q):
308
+ nonlocal call_count; call_count += 1
309
+ if call_count == 1:
310
+ return None # 第一次: miss
311
+ return "其他线程写入" # 第二次 (Double Check): hit
312
+
313
+ self.mgr.get_answer = patched_get
314
+
315
+ def should_not_call():
316
+ raise AssertionError("Double Check 成功时不应调 LLM")
317
+
318
+ result = self.mgr.get_or_compute("dc_q", should_not_call)
319
+ assert result == "其他线程写入"
320
+
321
+
322
+ # ================================================================
323
+ # 工具 2: OpenAI Embedding 模型
324
+ # ================================================================
325
+
326
+ class TestEmbedding:
327
+ """
328
+ 测试 vector.py - OpenAIEmbeddings
329
+ 覆盖: embed_query / embed_documents / 维度一致性 / API 异常
330
+ """
331
+
332
+ def _make_embedder(self, mock_client):
333
+ """用 Mock OpenAI client 构造 embedder, 绕过真实连接"""
334
+ from vector import OpenAIEmbeddings
335
+ embedder = object.__new__(OpenAIEmbeddings)
336
+ embedder.client = mock_client
337
+ return embedder
338
+
339
+ def _fake_vec(self, dim=1536):
340
+ return [random.uniform(-1, 1) for _ in range(dim)]
341
+
342
+ def test_embed_query_dimension(self):
343
+ """单条嵌入: 返回 1536 维向量"""
344
+ mock = MagicMock()
345
+ mock.embeddings.create.return_value = FakeEmbeddingResponse(self._fake_vec())
346
+ emb = self._make_embedder(mock)
347
+
348
+ vec = emb.embed_query("高血压症状")
349
+ assert isinstance(vec, list)
350
+ assert len(vec) == 1536
351
+
352
+ def test_embed_documents_batch(self):
353
+ """批量嵌入: 3 条文本 → 3 个向量"""
354
+ mock = MagicMock()
355
+ mock.embeddings.create.side_effect = [
356
+ FakeEmbeddingResponse(self._fake_vec()),
357
+ FakeEmbeddingResponse(self._fake_vec()),
358
+ FakeEmbeddingResponse(self._fake_vec()),
359
+ ]
360
+ emb = self._make_embedder(mock)
361
+
362
+ vecs = emb.embed_documents(["A", "B", "C"])
363
+ assert len(vecs) == 3
364
+ assert all(len(v) == 1536 for v in vecs)
365
+
366
+ def test_embed_query_calls_correct_model(self):
367
+ """验证调用时传入 model='text-embedding-3-small'"""
368
+ mock = MagicMock()
369
+ mock.embeddings.create.return_value = FakeEmbeddingResponse(self._fake_vec())
370
+ emb = self._make_embedder(mock)
371
+
372
+ emb.embed_query("test")
373
+
374
+ # 检查 create() 被调用时的参数
375
+ call_kwargs = mock.embeddings.create.call_args.kwargs
376
+ assert call_kwargs.get("model") == "text-embedding-3-small"
377
+
378
+ def test_embed_empty_text(self):
379
+ """空字符串也应返回向量 (不报错)"""
380
+ mock = MagicMock()
381
+ mock.embeddings.create.return_value = FakeEmbeddingResponse(self._fake_vec())
382
+ emb = self._make_embedder(mock)
383
+
384
+ vec = emb.embed_query("")
385
+ assert isinstance(vec, list) and len(vec) == 1536
386
+
387
+ def test_embed_api_error_propagates(self):
388
+ """API 报错时异常应向上传播"""
389
+ mock = MagicMock()
390
+ mock.embeddings.create.side_effect = Exception("Rate limit exceeded")
391
+ emb = self._make_embedder(mock)
392
+
393
+ with pytest.raises(Exception, match="Rate limit"):
394
+ emb.embed_query("test")
395
+
396
+ def test_embed_chinese_medical_text(self):
397
+ """中文医学文本嵌入应正常工作"""
398
+ mock = MagicMock()
399
+ mock.embeddings.create.return_value = FakeEmbeddingResponse(self._fake_vec())
400
+ emb = self._make_embedder(mock)
401
+
402
+ vec = emb.embed_query("宫腔异形的治疗方案有哪些?")
403
+ assert len(vec) == 1536
404
+
405
+
406
+ # ================================================================
407
+ # 工具 3: Milvus 向量检索
408
+ # ================================================================
409
+
410
+ class TestMilvusRetrieval:
411
+ """
412
+ 测试 Milvus similarity_search + format_docs
413
+ 覆盖: 正常召回 / 空结果 / 重排序参数 / 异常降级
414
+ """
415
+
416
+ def test_format_docs_normal(self):
417
+ """3 篇文档 → 用双换行拼接"""
418
+ docs = [
419
+ FakeDocument(page_content="高血压是常见疾病"),
420
+ FakeDocument(page_content="建议低盐饮食"),
421
+ FakeDocument(page_content="定期测量血压"),
422
+ ]
423
+ result = "\n\n".join(d.page_content for d in docs)
424
+ assert result.count("\n\n") == 2
425
+ assert "低盐饮食" in result
426
+
427
+ def test_format_docs_empty(self):
428
+ """空列表 → 空字符串"""
429
+ assert "\n\n".join(d.page_content for d in []) == ""
430
+
431
+ def test_format_docs_single(self):
432
+ """单篇文档 → 无分隔符"""
433
+ result = "\n\n".join(d.page_content for d in [FakeDocument(page_content="唯一")])
434
+ assert result == "唯一"
435
+
436
+ def test_similarity_search_returns_topk(self):
437
+ """Milvus 应返回 k=10 的 top-k 结果"""
438
+ mock_vs = MagicMock()
439
+ mock_vs.similarity_search.return_value = [
440
+ FakeDocument(page_content=f"doc_{i}") for i in range(10)
441
+ ]
442
+ results = mock_vs.similarity_search("query", k=10, ranker_type="rrf", ranker_params={"k": 100})
443
+ assert len(results) == 10
444
+
445
+ def test_similarity_search_rrf_params_passed(self):
446
+ """验证 RRF 重排序参数被正确传递"""
447
+ mock_vs = MagicMock()
448
+ mock_vs.similarity_search.return_value = []
449
+
450
+ mock_vs.similarity_search("q", k=10, ranker_type="rrf", ranker_params={"k": 100})
451
+
452
+ call_kwargs = mock_vs.similarity_search.call_args.kwargs
453
+ assert call_kwargs["ranker_type"] == "rrf"
454
+ assert call_kwargs["ranker_params"] == {"k": 100}
455
+
456
+ def test_similarity_search_empty(self):
457
+ """无匹配时 → context 为空"""
458
+ mock_vs = MagicMock()
459
+ mock_vs.similarity_search.return_value = []
460
+
461
+ results = mock_vs.similarity_search("xyz无关查询")
462
+ context = "\n\n".join(d.page_content for d in results) if results else ""
463
+ assert context == ""
464
+
465
+ def test_similarity_search_exception(self):
466
+ """Milvus 服务异常 → 应抛出异常 (agent 层决定降级策略)"""
467
+ mock_vs = MagicMock()
468
+ mock_vs.similarity_search.side_effect = ConnectionError("Milvus timeout")
469
+
470
+ with pytest.raises(ConnectionError):
471
+ mock_vs.similarity_search("test")
472
+
473
+
474
+ # ================================================================
475
+ # 工具 4: PDF 父子文档检索
476
+ # ================================================================
477
+
478
+ class TestPDFRetrieval:
479
+ """
480
+ 测试 parent_retriever.invoke()
481
+ 覆盖: 正常召回 / 空结果 / None / 长文档
482
+ """
483
+
484
+ def test_retriever_returns_document(self):
485
+ """正常检索 → 返回至少 1 篇文档"""
486
+ mock_ret = MagicMock()
487
+ mock_ret.invoke.return_value = [
488
+ FakeDocument(page_content="根据《高血压防治指南》第三章...")
489
+ ]
490
+ results = mock_ret.invoke("高血压分级标准")
491
+ assert len(results) >= 1
492
+ assert "高血压" in results[0].page_content
493
+
494
+ def test_retriever_empty_list(self):
495
+ """无匹配 → pdf_res 为空"""
496
+ mock_ret = MagicMock()
497
+ mock_ret.invoke.return_value = []
498
+
499
+ results = mock_ret.invoke("xyz")
500
+ pdf_res = results[0].page_content if results else ""
501
+ assert pdf_res == ""
502
+
503
+ def test_retriever_none_safe(self):
504
+ """返回 None → 不报错, pdf_res 为空"""
505
+ mock_ret = MagicMock()
506
+ mock_ret.invoke.return_value = None
507
+
508
+ results = mock_ret.invoke("test")
509
+ pdf_res = ""
510
+ if results is not None and len(results) >= 1:
511
+ pdf_res = results[0].page_content
512
+ assert pdf_res == ""
513
+
514
+ def test_retriever_long_document(self):
515
+ """长文档应完整返回"""
516
+ long = "医学文献内容。" * 500
517
+ mock_ret = MagicMock()
518
+ mock_ret.invoke.return_value = [FakeDocument(page_content=long)]
519
+
520
+ r = mock_ret.invoke("长文档")
521
+ assert len(r[0].page_content) == len(long)
522
+
523
+ def test_retriever_multiple_results_takes_first(self):
524
+ """agent4.py 只取 results[0], 验证此行为"""
525
+ mock_ret = MagicMock()
526
+ mock_ret.invoke.return_value = [
527
+ FakeDocument(page_content="最相关"),
528
+ FakeDocument(page_content="第二篇"),
529
+ ]
530
+ results = mock_ret.invoke("test")
531
+ pdf_res = results[0].page_content if results else ""
532
+ assert pdf_res == "最相关"
533
+
534
+
535
+ # ================================================================
536
+ # 工具 5: Neo4j 图数据库查询 (Cypher 生成 → 校验 → 执行)
537
+ # ================================================================
538
+
539
+ class TestNeo4jCypherPipeline:
540
+ """
541
+ 测试 agent4.py 中 Neo4j 三阶段流程:
542
+ Stage 1: POST /generate → Cypher + confidence + validated
543
+ Stage 2: POST /validate → is_valid
544
+ Stage 3: session.run(cypher) → 结果提取
545
+ """
546
+
547
+ # ---- Stage 1: Cypher 生成决策逻辑 ----
548
+
549
+ def test_high_confidence_valid_executes(self):
550
+ """0.95 + validated=True → 执行"""
551
+ d = {"cypher_query": "MATCH (d:Disease) RETURN d", "confidence": 0.95, "validated": True}
552
+ assert (d["cypher_query"] is not None and float(d["confidence"]) >= 0.9 and d["validated"]) is True
553
+
554
+ def test_low_confidence_skips(self):
555
+ """0.5 < 0.9 → 不执行"""
556
+ d = {"cypher_query": "MATCH", "confidence": 0.5, "validated": True}
557
+ assert (float(d["confidence"]) >= 0.9 and d["validated"]) is False
558
+
559
+ def test_invalid_skips(self):
560
+ """validated=False → 不执行"""
561
+ d = {"cypher_query": "BAD", "confidence": 0.99, "validated": False}
562
+ assert (float(d["confidence"]) >= 0.9 and d["validated"]) is False
563
+
564
+ def test_null_cypher_skips(self):
565
+ """cypher_query=None → 不执行"""
566
+ d = {"cypher_query": None, "confidence": 0.95, "validated": True}
567
+ assert (d["cypher_query"] is not None) is False
568
+
569
+ def test_boundary_089_skips(self):
570
+ """边界 0.89 → 不执行"""
571
+ assert (0.89 >= 0.9) is False
572
+
573
+ def test_boundary_090_executes(self):
574
+ """边界 0.90 → 执行"""
575
+ assert (0.90 >= 0.9) is True
576
+
577
+ # ---- Stage 2: Cypher 校验 ----
578
+
579
+ def test_validate_pass(self):
580
+ resp = MagicMock(); resp.json.return_value = {"is_valid": True}
581
+ assert resp.json()["is_valid"] is True
582
+
583
+ def test_validate_fail(self):
584
+ resp = MagicMock(); resp.json.return_value = {"is_valid": False}
585
+ assert resp.json()["is_valid"] is False
586
+
587
+ # ---- Stage 3: Cypher 执行 ----
588
+
589
+ def test_neo4j_run_success(self):
590
+ """正常执行 → 逗号拼接"""
591
+ mock_session = MagicMock()
592
+ mock_session.run.return_value = [("高血压",), ("糖尿病",)]
593
+ result = list(map(lambda x: x[0], mock_session.run("MATCH ...")))
594
+ assert ','.join(result) == "高血压,糖尿病"
595
+
596
+ def test_neo4j_run_empty(self):
597
+ """空结果 → 空字符串"""
598
+ mock_session = MagicMock()
599
+ mock_session.run.return_value = []
600
+ result = list(map(lambda x: x[0], mock_session.run("MATCH ...")))
601
+ assert ','.join(result) == ""
602
+
603
+ def test_neo4j_run_exception_graceful(self):
604
+ """查询异常 → 降级为空"""
605
+ mock_session = MagicMock()
606
+ mock_session.run.side_effect = Exception("Connection lost")
607
+ neo4j_res = ""
608
+ try:
609
+ result = list(map(lambda x: x[0], mock_session.run("BAD")))
610
+ neo4j_res = ','.join(result)
611
+ except Exception:
612
+ neo4j_res = ""
613
+ assert neo4j_res == ""
614
+
615
+ def test_cypher_service_down(self):
616
+ """Cypher API 宕机 → 降级为空"""
617
+ with patch('requests.post', side_effect=ConnectionError("refused")):
618
+ neo4j_res = ""
619
+ try:
620
+ import requests
621
+ requests.post("http://0.0.0.0:8101/generate", "{}")
622
+ except Exception:
623
+ neo4j_res = ""
624
+ assert neo4j_res == ""
625
+
626
+
627
+ # ================================================================
628
+ # 工具 6: OpenAI LLM 推理
629
+ # ================================================================
630
+
631
+ class TestLLMInference:
632
+ """
633
+ 测试 agent4.py - generate_openai_answer
634
+ 覆盖: 正常生成 / Prompt 构建 / 空返回 / 异常
635
+ """
636
+
637
+ def test_generate_success(self):
638
+ """正常生成回复"""
639
+ mock = MagicMock()
640
+ mock.chat.completions.create.return_value = FakeChatResponse(
641
+ "高血压患者应避免高盐饮食, 每日钠 <6g."
642
+ )
643
+ answer = mock.chat.completions.create(
644
+ model="gpt-4o-mini",
645
+ messages=[{"role": "user", "content": "高血压饮食"}],
646
+ temperature=0.7,
647
+ ).choices[0].message.content
648
+ assert "高血压" in answer and len(answer) > 10
649
+
650
+ def test_prompt_structure(self):
651
+ """Prompt 包含: 系统角色 + <context> + <question>"""
652
+ query, context = "高血压不能吃什么?", "低盐饮食"
653
+ SYSTEM = "System: 你是一个非常得力的医学助手."
654
+ USER = f"<context>\n{context}\n</context>\n<question>\n{query}\n</question>"
655
+ full = SYSTEM + USER
656
+ assert "医学助手" in full and query in full and context in full
657
+
658
+ def test_prompt_empty_context(self):
659
+ """上下文为空 → Prompt 仍完整"""
660
+ p = "<context>\n\n</context>\n<question>\n糖尿病?\n</question>"
661
+ assert "<context>" in p and "糖尿病" in p
662
+
663
+ def test_llm_timeout(self):
664
+ """LLM 超时 → 异常传播"""
665
+ mock = MagicMock()
666
+ mock.chat.completions.create.side_effect = TimeoutError("timeout")
667
+ with pytest.raises(TimeoutError):
668
+ mock.chat.completions.create(model="gpt-4o-mini", messages=[])
669
+
670
+ def test_llm_empty_response(self):
671
+ """LLM 返回空"""
672
+ mock = MagicMock()
673
+ mock.chat.completions.create.return_value = FakeChatResponse("")
674
+ answer = mock.chat.completions.create(model="m", messages=[]).choices[0].message.content
675
+ assert answer == ""
676
+
677
+ def test_generate_with_temperature(self):
678
+ """验证 temperature=0.7 被正确传递"""
679
+ mock = MagicMock()
680
+ mock.chat.completions.create.return_value = FakeChatResponse("ok")
681
+
682
+ mock.chat.completions.create(
683
+ model="gpt-4o-mini",
684
+ messages=[{"role": "user", "content": "test"}],
685
+ temperature=0.7,
686
+ )
687
+ assert mock.chat.completions.create.call_args.kwargs["temperature"] == 0.7
688
+
689
+
690
+ # ================================================================
691
+ # 工具 7: PDF 批处理器 (preprocess.py)
692
+ # ================================================================
693
+
694
+ class TestPDFProcessor:
695
+ """
696
+ 测试 preprocess.py - PDFBatchProcessor
697
+ """
698
+
699
+ def test_invalid_path_raises(self):
700
+ from preprocess import PDFBatchProcessor
701
+ proc = PDFBatchProcessor(output_dir="/tmp/test_pdf_out")
702
+ with pytest.raises(ValueError, match="路径不存在"):
703
+ proc.find_pdf_files("/nonexistent/xyz.txt")
704
+
705
+ def test_empty_dir(self, tmp_path):
706
+ from preprocess import PDFBatchProcessor
707
+ proc = PDFBatchProcessor(output_dir=str(tmp_path / "out"))
708
+ assert proc.find_pdf_files(str(tmp_path)) == []
709
+
710
+ def test_finds_pdf_only(self, tmp_path):
711
+ """只查找 .pdf 文件"""
712
+ from preprocess import PDFBatchProcessor
713
+ (tmp_path / "a.pdf").touch()
714
+ (tmp_path / "b.pdf").touch()
715
+ (tmp_path / "c.txt").touch()
716
+ proc = PDFBatchProcessor(output_dir=str(tmp_path / "out"))
717
+ files = proc.find_pdf_files(str(tmp_path))
718
+ assert len(files) == 2
719
+
720
+ def test_single_pdf_file(self, tmp_path):
721
+ from preprocess import PDFBatchProcessor
722
+ pdf = tmp_path / "x.pdf"; pdf.touch()
723
+ proc = PDFBatchProcessor(output_dir=str(tmp_path / "out"))
724
+ assert proc.find_pdf_files(str(pdf)) == [pdf]
725
+
726
+ def test_extract_nonexistent_has_error(self, tmp_path):
727
+ from preprocess import PDFBatchProcessor
728
+ proc = PDFBatchProcessor(output_dir=str(tmp_path / "out"))
729
+ r = proc.extract_pdf_content(Path("/no/file.pdf"))
730
+ assert r["error"] is not None
731
+
732
+ def test_result_has_required_keys(self, tmp_path):
733
+ from preprocess import PDFBatchProcessor
734
+ proc = PDFBatchProcessor(output_dir=str(tmp_path / "out"))
735
+ r = proc.extract_pdf_content(Path("dummy.pdf"))
736
+ for k in ["file_name", "file_path", "metadata", "pages", "error"]:
737
+ assert k in r
738
+
739
+
740
+ # ================================================================
741
+ # 工具 8: 数据预处理 (JSONL → Document)
742
+ # ================================================================
743
+
744
+ class TestDataPreprocessing:
745
+
746
+ def test_jsonl_parse(self, tmp_path):
747
+ f = tmp_path / "t.jsonl"
748
+ f.write_text(
749
+ json.dumps({"query": "Q1", "response": "A1"}, ensure_ascii=False) + "\n"
750
+ + json.dumps({"query": "Q2", "response": "A2"}, ensure_ascii=False) + "\n"
751
+ )
752
+ docs = []
753
+ with open(f) as fh:
754
+ for line in fh:
755
+ c = json.loads(line.strip())
756
+ docs.append(c["query"] + "\n" + c["response"])
757
+ assert len(docs) == 2 and "Q1" in docs[0]
758
+
759
+ def test_jsonl_empty(self, tmp_path):
760
+ f = tmp_path / "e.jsonl"; f.write_text("")
761
+ assert sum(1 for line in open(f) if line.strip()) == 0
762
+
763
+ def test_jsonl_bad_line(self, tmp_path):
764
+ f = tmp_path / "b.jsonl"
765
+ f.write_text('{"query":"ok","response":"r"}\n{bad}\n')
766
+ ok, bad = 0, 0
767
+ for line in open(f):
768
+ try:
769
+ json.loads(line); ok += 1
770
+ except json.JSONDecodeError:
771
+ bad += 1
772
+ assert ok == 1 and bad == 1
773
+
774
+
775
+ # ================================================================
776
+ # 工具 9: 端到端 RAG 编排逻辑
777
+ # ================================================================
778
+
779
+ class TestRAGOrchestration:
780
+
781
+ def test_three_way_merge(self):
782
+ ctx = "M结果" + "\n" + "P结果" + "\n" + "N结果"
783
+ assert "M结果" in ctx and "P结果" in ctx and "N结果" in ctx
784
+
785
+ def test_partial_empty_merge(self):
786
+ ctx = "有结果" + "\n" + "" + "\n" + ""
787
+ assert "有结果" in ctx
788
+
789
+ def test_all_empty_merge(self):
790
+ ctx = "" + "\n" + "" + "\n" + ""
791
+ assert ctx.strip() == ""
792
+
793
+ def test_request_valid(self):
794
+ assert {"question": "Q"}.get("question") == "Q"
795
+
796
+ def test_request_missing_question(self):
797
+ assert {"query": "x"}.get("question") is None
798
+
799
+ def test_redis_caching_in_chatbot(self):
800
+ """chatbot 使用 redis get_or_compute: 缓存命中 → 跳过 RAG"""
801
+ mgr = make_redis_manager()
802
+ mgr.set_answer("Q", "缓存A")
803
+ called = False
804
+ def rag(): nonlocal called; called = True; return "new"
805
+ assert mgr.get_or_compute("Q", rag) == "缓存A"
806
+ assert called is False
807
+
808
+
809
+ # ================================================================
810
+ if __name__ == "__main__":
811
+ pytest.main([__file__, "-v", "--tb=short"])