"""
================================================================
医疗 RAG Agent 集成测试 — 多步骤工具链协作
================================================================
测试层级:
单元测试 (test1.py): 单工具调用准确性 ← 已完成
集成测试 (test2.py): 多步骤工具链协作 ← 当前文件
回归测试 / 压力测试 / 安全测试: ← 下一步
测试场景:
场景 1: 完整 RAG 全链路 (Milvus → PDF → Neo4j → LLM)
场景 2: Redis 缓存集成 (Miss → RAG → 写缓存 → 再次查询 Hit)
场景 3: Neo4j 降级 (Cypher 服务宕机 → Milvus + PDF 继续回答)
场景 4: PDF 降级 (检索失败 → Milvus + Neo4j 继续回答)
场景 5: Milvus 降级 (向量库异常 → PDF + Neo4j 继续回答)
场景 6: 多组件同时降级 (Neo4j + PDF 都挂 → 只靠 Milvus)
场景 7: 全部降级 (三路召回都挂 → LLM 依赖自身经验)
场景 8: Chatbot 端点完整流程 (HTTP请求 → Redis → RAG → 响应)
场景 9: 并发请求下的 Redis 锁 + RAG 协作
场景 10: 数据入库全链路 (JSONL预处理 → Embedding → Milvus)
核心理念:
单元测试问 "每个工具自己对不对?"
集成测试问 "工具串起来后, 整条链路对不对?"
运行:
pytest test2.py -v --tb=short
pytest test2.py -v -k "full_pipeline" # 只跑全链路
pytest test2.py -v -k "degrade" # 只跑降级场景
pytest test2.py -v -k "redis" # 只跑缓存集成
================================================================
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
import types
import pytest
import json
import hashlib
import time
import uuid
import random
import threading
from unittest.mock import MagicMock, patch, call
from dataclasses import dataclass, field
from typing import Optional, List, Callable
# ================================================================
# 前置: Mock 缺失的第三方依赖
# ================================================================
def _ensure_mock_module(name):
if name not in sys.modules:
sys.modules[name] = MagicMock()
# Third-party modules imported by the code under test (agent4.py / vector.py /
# new_redis.py) that need not be installed to run these tests — each gets a
# MagicMock stand-in registered in sys.modules before those imports happen.
_MOCK_MODULES = [
    "langchain_classic", "langchain_classic.retrievers",
    "langchain_classic.retrievers.parent_document_retriever",
    "langchain_milvus", "langchain_text_splitters",
    "langchain_core", "langchain_core.stores", "langchain_core.documents",
    "langchain.embeddings", "langchain.embeddings.base",
    "neo4j", "dotenv", "uvicorn",
    "fastapi", "fastapi.middleware", "fastapi.middleware.cors",
]
for mod in _MOCK_MODULES:
    _ensure_mock_module(mod)
class _FakeEmbeddingsBase:
    # A real class (rather than a MagicMock attribute) so project code that
    # subclasses `Embeddings` gets an ordinary base class.
    pass
sys.modules["langchain.embeddings.base"].Embeddings = _FakeEmbeddingsBase
# ================================================================
# 测试基础设施
# ================================================================
@dataclass
class FakeDocument:
    """Stand-in for a LangChain ``Document``: raw text plus a metadata dict."""
    page_content: str
    metadata: dict = field(default_factory=dict)
class FakeChatResponse:
    """Minimal imitation of an OpenAI chat-completion result.

    Only the attribute path the code under test reads is provided:
    ``response.choices[0].message.content``.
    """
    def __init__(self, content):
        message = types.SimpleNamespace(content=content)
        self.choices = [types.SimpleNamespace(message=message)]
class FakeRedisClient:
    """In-memory stand-in for a redis-py client, backed by plain dicts.

    TTLs are recorded but never enforced — sufficient for the cache-flow
    assertions in this module.
    """

    def __init__(self):
        self._data = {}
        self._ttl = {}

    def ping(self):
        return True

    def get(self, key):
        return self._data.get(key, None)

    def set(self, key, value, ex=None, nx=False):
        # NX semantics: refuse to overwrite a key that already exists.
        if nx and key in self._data:
            return False
        self._data[key] = value
        if ex:
            self._ttl[key] = ex
        return True

    def setex(self, key, expire, value):
        self._data[key] = value
        self._ttl[key] = expire
        return True

    def delete(self, key):
        removed = self._data.pop(key, None)
        return 1 if removed is not None else 0

    def register_script(self, script):
        """Return a callable mimicking a compare-and-delete unlock script."""
        def unlock(keys=None, args=None):
            # Delete only when the stored token matches (safe lock release).
            if keys and args and self._data.get(keys[0]) == args[0]:
                del self._data[keys[0]]
                return 1
            return 0
        return unlock
def make_redis_manager():
    """Build a RedisClientWrapper wired to the in-memory FakeRedisClient.

    ``object.__new__`` bypasses ``__init__`` so no real connection pool is
    created; ``_pool`` is stubbed to satisfy pool-presence checks.
    """
    from new_redis import RedisClientWrapper

    RedisClientWrapper._pool = "FAKE"
    wrapper = object.__new__(RedisClientWrapper)
    wrapper.client = FakeRedisClient()
    wrapper.unlock_script = wrapper.client.register_script("")
    return wrapper
# ================================================================
# 核心: 可测试版本的 perform_rag_and_llm
#
# 原始 agent4.py 使用全局变量 (milvus_vectorstore, parent_retriever,
# driver, client_llm), 无法直接在测试中注入 Mock.
# 这里提取相同逻辑, 但通过参数传入依赖, 实现 "依赖注入" 测试模式.
# ================================================================
def perform_rag_and_llm_testable(
    query: str,
    milvus_vectorstore,
    parent_retriever,
    neo4j_driver,
    llm_client,
    cypher_endpoint: str = "http://0.0.0.0:8101",
    requests_module=None,
) -> str:
    """Same logic as ``perform_rag_and_llm`` in agent4.py, but with every
    external dependency injected as a parameter instead of read from module
    globals, so each collaborator can be mocked independently in tests.

    NOTE(review): the body intentionally mirrors agent4.py statement-for-
    statement (including the non-idiomatic ``== True`` checks) — keep the two
    in sync when the production function changes.
    """
    import json as _json
    if requests_module is None:
        # Default to the real requests module when no mock is injected.
        import requests as requests_module
    # ---- Step 1: Milvus vector recall ----
    try:
        recall_results = milvus_vectorstore.similarity_search(
            query, k=10, ranker_type="rrf", ranker_params={"k": 100}
        )
        context = "\n\n".join(d.page_content for d in recall_results) if recall_results else ""
    except Exception as e:
        # Graceful degradation: an unreachable vector store leaves context empty.
        print(f"Milvus 异常: {e}")
        context = ""
    # ---- Step 2: PDF parent/child document retrieval ----
    pdf_res = ""
    try:
        retrieved_docs = parent_retriever.invoke(query)
        if retrieved_docs is not None and len(retrieved_docs) >= 1:
            # Only the top parent document is used.
            pdf_res = retrieved_docs[0].page_content
    except Exception as e:
        print(f"PDF 检索异常: {e}")
    context = context + "\n" + pdf_res
    # ---- Step 3: Neo4j graph-database precise recall ----
    neo4j_res = ""
    data = {"natural_language_query": query}
    data_json = _json.dumps(data)
    try:
        cypher_response = requests_module.post(f"{cypher_endpoint}/generate", data_json)
        if cypher_response.status_code == 200:
            cypher_response_data = cypher_response.json()
            cypher_query = cypher_response_data["cypher_query"]
            confidence = cypher_response_data["confidence"]
            is_valid = cypher_response_data["validated"]
            # Only execute Cypher that the generator marked valid with high confidence.
            if cypher_query is not None and float(confidence) >= 0.9 and is_valid == True:
                # Second validation pass via the /validate endpoint.
                validate_data = _json.dumps({"cypher_query": cypher_query})
                cypher_valid = requests_module.post(f"{cypher_endpoint}/validate", validate_data)
                if cypher_valid.status_code == 200:
                    if cypher_valid.json()["is_valid"] == True:
                        with neo4j_driver.session() as session:
                            try:
                                record = session.run(cypher_query)
                                # Collect the first column of every returned row.
                                result = list(map(lambda x: x[0], record))
                                neo4j_res = ','.join(result)
                            except Exception as e:
                                print(f"neo4j查询失败: {e}")
                                neo4j_res = ""
    except Exception as e:
        # Any HTTP-level failure skips the whole Neo4j path.
        print(f"neo4j API 服务不可用: {e}")
    # Merge the three recall results into one context string.
    context = context + "\n" + neo4j_res
    # ---- Step 4: LLM inference ----
    SYSTEM_PROMPT = """
System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案.
"""
    USER_PROMPT = f"""
User: 利用介于和之间的从数据库中检索出的信息来回答问题.
{context}
{query}
"""
    response = llm_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": SYSTEM_PROMPT + USER_PROMPT}],
        temperature=0.7,
    )
    return response.choices[0].message.content
# ================================================================
# 工厂方法: 快速构建各组件的 Mock
# ================================================================
def make_milvus_mock(docs=None, raise_error=False):
    """Build a Milvus vector-store mock.

    Args:
        docs: documents ``similarity_search`` should return. ``None`` selects
            two default medical snippets; an explicit empty list now really
            means "no hits" (the previous ``docs or [...]`` coerced ``[]``
            back to the defaults).
        raise_error: when True, ``similarity_search`` raises ConnectionError
            to simulate an unreachable vector store.
    """
    mock = MagicMock()
    if raise_error:
        mock.similarity_search.side_effect = ConnectionError("Milvus 连接超时")
    else:
        # `is None` check (not truthiness) so callers can request an empty result set.
        mock.similarity_search.return_value = docs if docs is not None else [
            FakeDocument(page_content="高血压患者应控制每日钠摄入量不超过6克"),
            FakeDocument(page_content="建议多食用富含钾的蔬果, 如香蕉、菠菜"),
        ]
    return mock
def make_pdf_mock(content=None, raise_error=False):
    """Build a parent-document-retriever mock.

    ``raise_error=True`` makes ``invoke`` raise; ``content=None`` yields one
    default guideline document; ``content=""`` yields an empty hit list; any
    other string is wrapped in a single FakeDocument.
    """
    mock = MagicMock()
    if raise_error:
        mock.invoke.side_effect = Exception("PDF 索引损坏")
        return mock
    if content is None:
        hits = [
            FakeDocument(page_content="根据《中国高血压防治指南(2024版)》第三章: 高血压分为1级、2级、3级")
        ]
    elif content == "":
        hits = []
    else:
        hits = [FakeDocument(page_content=content)]
    mock.invoke.return_value = hits
    return mock
def make_neo4j_driver_mock(results=None, raise_error=False):
    """Build a Neo4j driver mock whose session works as a context manager.

    Args:
        results: rows ``session.run`` should return. ``None`` selects two
            default rows; an explicit ``[]`` now really means "no rows" —
            the previous ``results or [...]`` silently replaced ``[]`` with
            the defaults, which is what callers like
            ``make_neo4j_driver_mock([])`` in this file did not intend.
        raise_error: when True, ``session.run`` raises to simulate a failing
            Cypher execution.
    """
    mock_driver = MagicMock()
    mock_session = MagicMock()
    if raise_error:
        mock_session.run.side_effect = Exception("Neo4j 节点不存在")
    else:
        # `is None` check so an explicit empty result list is honored.
        mock_session.run.return_value = results if results is not None else [("高血压",), ("动脉硬化",)]
    # Wire `with driver.session() as s:` so that s is mock_session.
    mock_driver.session.return_value.__enter__ = MagicMock(return_value=mock_session)
    mock_driver.session.return_value.__exit__ = MagicMock(return_value=False)
    return mock_driver
def make_llm_mock(answer="高血压患者应避免高盐、高脂肪饮食, 建议低盐低脂饮食。"):
    """Build an OpenAI-client mock whose chat completion always yields ``answer``."""
    llm = MagicMock()
    llm.chat.completions.create.return_value = FakeChatResponse(answer)
    return llm
def make_requests_mock(
    generate_response=None,
    validate_response=None,
    generate_error=False,
):
    """Build a ``requests``-module mock for the Cypher /generate + /validate API.

    The mock answers at most two ``post()`` calls: the first plays the
    /generate reply, the second the /validate reply. With
    ``generate_error=True`` every ``post()`` raises ConnectionError instead,
    simulating a dead Cypher service.
    """
    mock = MagicMock()
    if generate_error:
        mock.post.side_effect = ConnectionError("Cypher 服务不可用")
        return mock
    # Default: a high-confidence, validated Cypher statement.
    generate_reply = MagicMock()
    generate_reply.status_code = 200
    generate_reply.json.return_value = generate_response or {
        "cypher_query": "MATCH (d:Disease {name:'高血压'})-[:has_common_drug]->(m) RETURN m.name",
        "confidence": 0.95,
        "validated": True,
    }
    validate_reply = MagicMock()
    validate_reply.status_code = 200
    validate_reply.json.return_value = validate_response or {"is_valid": True}
    # post() call #1 → /generate reply, call #2 → /validate reply.
    mock.post.side_effect = [generate_reply, validate_reply]
    return mock
# ================================================================
# 场景 1: 完整 RAG 全链路 (Happy Path)
# ================================================================
class TestFullPipeline:
    """All three recall paths (Milvus + PDF + Neo4j) succeed → merged context → LLM answer."""

    def test_all_three_sources_contribute_to_context(self):
        """Each recall path's result must appear in the prompt the LLM receives."""
        milvus = make_milvus_mock([FakeDocument(page_content="MILVUS_低盐饮食")])
        pdf = make_pdf_mock(content="PDF_高血压防治指南")
        neo4j = make_neo4j_driver_mock([("NEO4J_降压药",)])
        llm = make_llm_mock()
        req = make_requests_mock()
        perform_rag_and_llm_testable(
            "高血压不能吃什么?", milvus, pdf, neo4j, llm, requests_module=req
        )
        # Inspect the prompt actually handed to the LLM mock.
        actual_prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        assert "MILVUS_低盐饮食" in actual_prompt, "Milvus 结果应出现在 prompt"
        assert "PDF_高血压防治指南" in actual_prompt, "PDF 结果应出现在 prompt"
        assert "NEO4J_降压药" in actual_prompt, "Neo4j 结果应出现在 prompt"

    def test_full_pipeline_returns_llm_answer(self):
        """The full chain ultimately returns the LLM's answer unchanged."""
        expected_answer = "高血压患者应限制盐分摄入, 每日不超过6克。"
        llm = make_llm_mock(expected_answer)
        result = perform_rag_and_llm_testable(
            "高血压饮食", make_milvus_mock(), make_pdf_mock(), make_neo4j_driver_mock(),
            llm, requests_module=make_requests_mock()
        )
        assert result == expected_answer

    def test_prompt_contains_question(self):
        """The user's original question must appear in the prompt."""
        llm = make_llm_mock()
        perform_rag_and_llm_testable(
            "糖尿病能吃西瓜吗?", make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock()
        )
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        assert "糖尿病能吃西瓜吗?" in prompt

    def test_neo4j_cypher_validation_flow(self):
        """Verify the three-step Cypher flow: generate → validate → execute."""
        req = make_requests_mock()
        neo4j = make_neo4j_driver_mock([("阿司匹林",), ("氨氯地平",)])
        result = perform_rag_and_llm_testable(
            "高血压常用药物", make_milvus_mock(), make_pdf_mock(),
            neo4j, make_llm_mock(), requests_module=req
        )
        # Exactly two HTTP calls are expected: /generate then /validate.
        assert req.post.call_count == 2
        first_call_url = req.post.call_args_list[0][0][0]
        second_call_url = req.post.call_args_list[1][0][0]
        assert "/generate" in first_call_url
        assert "/validate" in second_call_url

    def test_llm_model_and_temperature(self):
        """The LLM must be called with the expected model and temperature."""
        llm = make_llm_mock()
        perform_rag_and_llm_testable(
            "test", make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock()
        )
        kwargs = llm.chat.completions.create.call_args.kwargs
        assert kwargs["model"] == "gpt-4o-mini"
        assert kwargs["temperature"] == 0.7
# ================================================================
# 场景 2: Redis 缓存 + RAG 集成
# ================================================================
class TestRedisCacheIntegration:
    """Verify that Redis caching and the full RAG chain cooperate correctly."""

    def test_cache_miss_triggers_full_rag(self):
        """First query: Redis miss → run RAG → cache the answer → return it."""
        mgr = make_redis_manager()
        rag_called = False
        def fake_rag():
            nonlocal rag_called
            rag_called = True
            return "LLM生成: 高血压要低盐饮食"
        result = mgr.get_or_compute("高血压不能吃什么?", fake_rag)
        assert rag_called is True, "缓存未命中时应调用 RAG"
        assert result == "LLM生成: 高血压要低盐饮食"

    def test_second_query_hits_cache(self):
        """Repeat query: identical question → served from cache, RAG not re-run."""
        mgr = make_redis_manager()
        call_count = 0
        def counting_rag():
            nonlocal call_count
            call_count += 1
            return "RAG结果"
        # First call goes through RAG.
        mgr.get_or_compute("高血压饮食", counting_rag)
        assert call_count == 1
        # Second call is served from the cache.
        result = mgr.get_or_compute("高血压饮食", counting_rag)
        assert call_count == 1, "第二次应命中缓存, RAG 不应再被调用"
        assert result == "RAG结果"

    def test_different_questions_both_execute_rag(self):
        """Different questions: each runs RAG and is cached independently."""
        # (Removed dead code: an unused `questions` list and `tracking_rag`
        # helper that never influenced this test.)
        mgr = make_redis_manager()
        mgr.get_or_compute("问题A", lambda: "答案A")
        mgr.get_or_compute("问题B", lambda: "答案B")
        assert mgr.get_answer("问题A") == "答案A"
        assert mgr.get_answer("问题B") == "答案B"

    def test_cache_miss_full_rag_then_cache_hit(self):
        """Full scenario: miss → Milvus+PDF+Neo4j+LLM → cache write → re-query → hit."""
        mgr = make_redis_manager()
        def real_rag():
            return perform_rag_and_llm_testable(
                "高血压吃什么药?",
                make_milvus_mock([FakeDocument(page_content="降压药推荐")]),
                make_pdf_mock(content="指南建议"),
                make_neo4j_driver_mock([("氨氯地平",)]),
                make_llm_mock("建议服用氨氯地平, 同时控制饮食。"),
                requests_module=make_requests_mock(),
            )
        # First call: miss → RAG.
        r1 = mgr.get_or_compute("高血压吃什么药?", real_rag)
        assert "氨氯地平" in r1
        # Second call: hit → cache; the compute callback must never fire.
        rag_should_not_run = MagicMock(side_effect=AssertionError("不应调用"))
        r2 = mgr.get_or_compute("高血压吃什么药?", lambda: rag_should_not_run())
        assert r2 == r1, "缓存命中结果应与首次一致"

    def test_rag_returns_empty_caches_empty_marker(self):
        """RAG returns "" → an empty marker is cached → recomputation stays bounded."""
        mgr = make_redis_manager()
        call_count = 0
        def empty_rag():
            nonlocal call_count
            call_count += 1
            return ""
        mgr.get_or_compute("无效查询xyz", empty_rag)
        assert call_count == 1
        # Second call: get_answer may map the cached empty marker to None, in
        # which case get_or_compute recomputes once more.  The marker's short
        # TTL makes that acceptable in production; here we pin the bound —
        # at most one extra RAG call, never unbounded recomputation.
        mgr.get_or_compute("无效查询xyz", empty_rag)
        assert 1 <= call_count <= 2
# ================================================================
# 场景 3: Neo4j 降级
# ================================================================
class TestNeo4jDegradation:
    """Any Neo4j failure → graceful degradation; Milvus + PDF still answer."""

    def test_cypher_service_down_still_answers(self):
        """Cypher API down → skip Neo4j, continue with Milvus + PDF."""
        llm = make_llm_mock("基于文献, 高血压应低盐饮食")
        req = make_requests_mock(generate_error=True)  # simulate a dead service
        result = perform_rag_and_llm_testable(
            "高血压饮食", make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=req
        )
        assert result == "基于文献, 高血压应低盐饮食"
        # Milvus and PDF are still consulted.
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        assert "控制每日钠摄入量" in prompt or "高血压防治指南" in prompt

    def test_low_confidence_cypher_skipped(self):
        """Low Cypher confidence (0.3) → the Neo4j execution step is skipped."""
        req = make_requests_mock(generate_response={
            "cypher_query": "MATCH (n) RETURN n",
            "confidence": 0.3,
            "validated": True,
        })
        neo4j = make_neo4j_driver_mock()
        llm = make_llm_mock()
        perform_rag_and_llm_testable(
            "test", make_milvus_mock(), make_pdf_mock(), neo4j, llm, requests_module=req
        )
        # session.run must never be reached.
        neo4j.session.return_value.__enter__.return_value.run.assert_not_called()

    def test_cypher_validation_fails_skipped(self):
        """Cypher validation fails → execution is skipped."""
        req = make_requests_mock(validate_response={"is_valid": False})
        neo4j = make_neo4j_driver_mock()
        llm = make_llm_mock()
        perform_rag_and_llm_testable(
            "test", make_milvus_mock(), make_pdf_mock(), neo4j, llm, requests_module=req
        )
        neo4j.session.return_value.__enter__.return_value.run.assert_not_called()

    def test_neo4j_query_exception_graceful(self):
        """Neo4j raises during execution → neo4j_res stays empty, LLM unaffected."""
        neo4j = make_neo4j_driver_mock(raise_error=True)
        llm = make_llm_mock("虽然图数据库查询失败, 但基于其他信息...")
        result = perform_rag_and_llm_testable(
            "高血压", make_milvus_mock(), make_pdf_mock(), neo4j, llm,
            requests_module=make_requests_mock()
        )
        assert isinstance(result, str) and len(result) > 0

    def test_neo4j_500_error_skipped(self):
        """Cypher service returns HTTP 500 → Neo4j path is skipped."""
        gen_resp = MagicMock()
        gen_resp.status_code = 500
        req = MagicMock()
        req.post.return_value = gen_resp
        llm = make_llm_mock("回答")
        result = perform_rag_and_llm_testable(
            "test", make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=req
        )
        assert result == "回答"
# ================================================================
# 场景 4: PDF 降级
# ================================================================
class TestPDFDegradation:
    """PDF retrieval failures → the system still answers via Milvus + Neo4j."""

    def test_pdf_exception_still_answers(self):
        """Corrupted PDF index → skip PDF, Milvus + Neo4j continue."""
        pdf = make_pdf_mock(raise_error=True)
        llm = make_llm_mock("基于向量检索和知识图谱...")
        result = perform_rag_and_llm_testable(
            "高血压", make_milvus_mock(), pdf, make_neo4j_driver_mock(),
            llm, requests_module=make_requests_mock()
        )
        assert result == "基于向量检索和知识图谱..."

    def test_pdf_empty_result_still_answers(self):
        """PDF returns no hits → other paths are unaffected."""
        pdf = make_pdf_mock(content="")  # empty hit list
        llm = make_llm_mock("回答")
        result = perform_rag_and_llm_testable(
            "test", make_milvus_mock(), pdf, make_neo4j_driver_mock(),
            llm, requests_module=make_requests_mock()
        )
        assert result == "回答"

    def test_pdf_down_context_still_has_milvus_and_neo4j(self):
        """With PDF down, the prompt still carries Milvus and Neo4j results."""
        milvus = make_milvus_mock([FakeDocument(page_content="MILVUS_结果")])
        pdf = make_pdf_mock(raise_error=True)
        neo4j = make_neo4j_driver_mock([("NEO4J_结果",)])
        llm = make_llm_mock()
        perform_rag_and_llm_testable(
            "test", milvus, pdf, neo4j, llm, requests_module=make_requests_mock()
        )
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        assert "MILVUS_结果" in prompt
        assert "NEO4J_结果" in prompt
# ================================================================
# 场景 5: Milvus 降级
# ================================================================
class TestMilvusDegradation:
    """Milvus vector-store failures → the system still answers via PDF + Neo4j."""

    def test_milvus_exception_still_answers(self):
        """Milvus connection timeout → skip it, PDF + Neo4j continue."""
        milvus = make_milvus_mock(raise_error=True)
        llm = make_llm_mock("基于PDF和知识图谱的回答")
        result = perform_rag_and_llm_testable(
            "高血压", milvus, make_pdf_mock(), make_neo4j_driver_mock(),
            llm, requests_module=make_requests_mock()
        )
        assert result == "基于PDF和知识图谱的回答"

    def test_milvus_down_prompt_has_pdf_and_neo4j(self):
        """With Milvus down, the prompt still carries PDF and Neo4j content."""
        milvus = make_milvus_mock(raise_error=True)
        pdf = make_pdf_mock(content="PDF_内容")
        neo4j = make_neo4j_driver_mock([("NEO4J_内容",)])
        llm = make_llm_mock()
        perform_rag_and_llm_testable(
            "test", milvus, pdf, neo4j, llm, requests_module=make_requests_mock()
        )
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        assert "PDF_内容" in prompt
        assert "NEO4J_内容" in prompt
        # Milvus content (the factory's default snippet) must be absent.
        assert "控制每日钠摄入量" not in prompt
# ================================================================
# 场景 6: 多组件同时降级
# ================================================================
class TestMultipleDegradation:
    """Scenarios where several components fail at the same time."""

    def test_neo4j_and_pdf_both_down(self):
        """Neo4j + PDF both down → only Milvus + LLM remain."""
        milvus = make_milvus_mock([FakeDocument(page_content="MILVUS_唯一来源")])
        pdf = make_pdf_mock(raise_error=True)
        req = make_requests_mock(generate_error=True)
        llm = make_llm_mock()
        perform_rag_and_llm_testable(
            "test", milvus, pdf, make_neo4j_driver_mock(), llm, requests_module=req
        )
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        assert "MILVUS_唯一来源" in prompt

    def test_milvus_and_neo4j_both_down(self):
        """Milvus + Neo4j both down → only PDF + LLM remain."""
        milvus = make_milvus_mock(raise_error=True)
        pdf = make_pdf_mock(content="PDF_唯一来源")
        req = make_requests_mock(generate_error=True)
        llm = make_llm_mock()
        perform_rag_and_llm_testable(
            "test", milvus, pdf, make_neo4j_driver_mock(), llm, requests_module=req
        )
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        assert "PDF_唯一来源" in prompt

    def test_all_three_sources_down_llm_uses_own_knowledge(self):
        """All three recall paths down → LLM gets an empty context and relies on itself."""
        milvus = make_milvus_mock(raise_error=True)
        pdf = make_pdf_mock(raise_error=True)
        req = make_requests_mock(generate_error=True)
        llm = make_llm_mock("作为医学AI, 根据我的知识...")
        result = perform_rag_and_llm_testable(
            "高血压怎么治疗", milvus, pdf, make_neo4j_driver_mock(), llm, requests_module=req
        )
        assert result == "作为医学AI, 根据我的知识..."
        # None of the recall markers may leak into the prompt after full degradation.
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        assert "MILVUS" not in prompt, "Milvus 降级后不应有 Milvus 内容"
        assert "NEO4J" not in prompt, "Neo4j 降级后不应有 Neo4j 内容"
        # The PDF mock raises, so it injects no content either.

    def test_degradation_never_crashes(self):
        """No combination of component failures may crash the pipeline."""
        for milvus_fail in [True, False]:
            for pdf_fail in [True, False]:
                for neo4j_fail in [True, False]:
                    milvus = make_milvus_mock(raise_error=milvus_fail)
                    # Simplified: the former
                    # `make_pdf_mock(raise_error=pdf_fail) if pdf_fail else make_pdf_mock()`
                    # ternary was redundant — raise_error=False is the default.
                    pdf = make_pdf_mock(raise_error=pdf_fail)
                    req = make_requests_mock(generate_error=neo4j_fail)
                    llm = make_llm_mock("OK")
                    result = perform_rag_and_llm_testable(
                        "测试", milvus, pdf, make_neo4j_driver_mock(),
                        llm, requests_module=req
                    )
                    assert result == "OK", (
                        f"milvus_fail={milvus_fail}, pdf_fail={pdf_fail}, "
                        f"neo4j_fail={neo4j_fail} 组合不应崩溃"
                    )
# ================================================================
# 场景 7: Chatbot 端点完整流程
# ================================================================
class TestChatbotEndpointFlow:
    """Simulate the chatbot endpoint's complete request-handling flow."""

    def _simulate_chatbot(self, request_body, redis_mgr, rag_func):
        """Re-implementation of the agent4.py chatbot() handler logic for testing."""
        import datetime
        try:
            # Accept both a raw JSON string and an already-parsed dict body.
            if isinstance(request_body, str):
                json_post_list = json.loads(request_body)
            else:
                json_post_list = request_body
            query = json_post_list.get('question')
            if not query:
                return {"status": 400, "error": "Question is required"}
            compute_callback = lambda: rag_func(query)
            response = redis_mgr.get_or_compute(query, compute_callback)
            now = datetime.datetime.now()
            return {
                "response": response,
                "status": 200,
                "time": now.strftime("%Y-%m-%d %H:%M:%S")
            }
        except Exception as e:
            # Any failure inside parsing / cache / RAG maps to a 500 payload.
            return {"status": 500, "error": str(e)}

    def test_normal_request_returns_200(self):
        """Well-formed request → status=200 plus a response body."""
        mgr = make_redis_manager()
        rag = lambda q: "医学回答"
        resp = self._simulate_chatbot({"question": "高血压饮食"}, mgr, rag)
        assert resp["status"] == 200
        assert resp["response"] == "医学回答"
        assert "time" in resp

    def test_missing_question_returns_400(self):
        """Missing 'question' field → status=400."""
        mgr = make_redis_manager()
        resp = self._simulate_chatbot({"query": "错误字段"}, mgr, lambda q: "")
        assert resp["status"] == 400

    def test_empty_question_returns_400(self):
        """Empty 'question' value → status=400."""
        mgr = make_redis_manager()
        resp = self._simulate_chatbot({"question": ""}, mgr, lambda q: "")
        assert resp["status"] == 400

    def test_double_encoded_json(self):
        """String (double-encoded) JSON body → parsed correctly."""
        mgr = make_redis_manager()
        double_encoded = json.dumps({"question": "高血压"})
        resp = self._simulate_chatbot(double_encoded, mgr, lambda q: "回答")
        assert resp["status"] == 200
        assert resp["response"] == "回答"

    def test_rag_exception_returns_500(self):
        """Exception inside RAG → status=500."""
        mgr = make_redis_manager()
        def exploding_rag(q):
            raise RuntimeError("GPU OOM")
        resp = self._simulate_chatbot({"question": "test"}, mgr, exploding_rag)
        # The exception escaping get_or_compute is caught by the outer try/except.
        assert resp["status"] == 500 or "error" in resp

    def test_sequential_requests_cache_behavior(self):
        """Three requests: two identical ones share the cache, the third runs RAG."""
        mgr = make_redis_manager()
        call_log = []
        def logging_rag(q):
            call_log.append(q)
            return f"答案: {q}"
        self._simulate_chatbot({"question": "Q1"}, mgr, logging_rag)
        self._simulate_chatbot({"question": "Q1"}, mgr, logging_rag)  # should hit the cache
        self._simulate_chatbot({"question": "Q2"}, mgr, logging_rag)
        assert call_log == ["Q1", "Q2"], "Q1 只应调用一次 RAG, Q2 调用一次"
# ================================================================
# 场景 8: 并发请求下的 Redis 锁 + RAG 协作
# ================================================================
class TestConcurrencyBehavior:
    """Verify the Redis lock's protection under concurrent access."""

    def test_concurrent_same_question_only_one_rag(self):
        """Many threads ask the same question → ideally only one executes RAG."""
        mgr = make_redis_manager()
        rag_call_count = 0
        lock = threading.Lock()
        def slow_rag():
            nonlocal rag_call_count
            with lock:
                rag_call_count += 1
            time.sleep(0.05)  # simulate slow RAG work
            return "RAG结果"
        threads = []
        results = []
        def worker():
            r = mgr.get_or_compute("相同的问题", slow_rag)
            results.append(r)
        for _ in range(5):
            t = threading.Thread(target=worker)
            threads.append(t)
            t.start()
        for t in threads:
            t.join(timeout=5)
        # With a real distributed lock the ideal is exactly 1 RAG call, but
        # FakeRedisClient is not thread-safe, so allow a little slack.
        assert rag_call_count <= 3, f"预期 ≤3 次 RAG 调用, 实际 {rag_call_count}"
        assert all(r is not None for r in results), "所有线程都应获得结果"

    def test_concurrent_different_questions_all_run_rag(self):
        """Threads asking different questions → each executes RAG once."""
        mgr = make_redis_manager()
        call_log = []
        lock = threading.Lock()
        def tracking_rag():
            tid = threading.current_thread().name
            with lock:
                call_log.append(tid)
            return f"答案_{tid}"
        threads = []
        for i in range(3):
            # Default argument binds the per-iteration question (late-binding fix).
            def worker(q=f"问题_{i}"):
                mgr.get_or_compute(q, tracking_rag)
            t = threading.Thread(target=worker, name=f"T{i}")
            threads.append(t)
            t.start()
        for t in threads:
            t.join(timeout=5)
        assert len(call_log) == 3, "不同问题应各自执行 RAG"
# ================================================================
# 场景 9: 数据入库全链路 (JSONL → Embedding → Milvus)
# ================================================================
class TestDataIngestionPipeline:
    """Exercise preprocessing → embedding → vector-store ingestion end to end."""

    def test_jsonl_to_documents_to_embedding(self, tmp_path):
        """JSONL parsing → Document wrapping → embedding call."""
        # Prepare fixture data
        jsonl = tmp_path / "test.jsonl"
        jsonl.write_text(
            json.dumps({"query": "高血压症状", "response": "头晕头痛"}, ensure_ascii=False) + "\n"
            + json.dumps({"query": "糖尿病饮食", "response": "低糖低脂"}, ensure_ascii=False) + "\n"
        )
        # Step 1: parse each JSONL line into a document
        docs = []
        with open(jsonl, 'r', encoding='utf-8') as f:
            for line in f:
                c = json.loads(line.strip())
                docs.append(FakeDocument(
                    page_content=c['query'] + "\n" + c['response'],
                    metadata={"doc_id": str(uuid.uuid4())}
                ))
        assert len(docs) == 2
        # Step 2: run the embedding client with a mocked transport
        from vector import OpenAIEmbeddings
        embedder = object.__new__(OpenAIEmbeddings)
        mock_client = MagicMock()
        mock_client.embeddings.create.return_value = type('R', (), {
            'data': [type('E', (), {'embedding': [0.1] * 1536})()]
        })()
        embedder.client = mock_client
        embeddings = embedder.embed_documents([d.page_content for d in docs])
        assert len(embeddings) == 2
        assert len(embeddings[0]) == 1536
        # Step 3: verify ingestion against a mocked Milvus store
        mock_vs = MagicMock()
        mock_vs.add_documents.return_value = None
        mock_vs.add_documents(docs)
        mock_vs.add_documents.assert_called_once_with(docs)

    def test_pdf_preprocessing_to_retriever(self, tmp_path):
        """PDF extraction → DataFrame → Document wrapping → retriever."""
        import pandas as pd
        # DataFrame shaped like the PDF-extraction output
        df = pd.DataFrame({
            "file_name": ["指南.pdf", "指南.pdf"],
            "page_number": [1, 2],
            "text_content": [
                "高血压定义: 收缩压≥140mmHg或舒张压≥90mmHg",
                "高血压分级: 1级(140-159/90-99)"
            ]
        })
        # Step 1: DataFrame rows → Documents
        documents = []
        for _, row in df.iterrows():
            documents.append(FakeDocument(
                page_content=str(row['text_content']).strip(),
                metadata={"doc_id": str(uuid.uuid4())}
            ))
        assert len(documents) == 2
        assert "140mmHg" in documents[0].page_content
        # Step 2: hand the documents to a mocked retriever
        mock_retriever = MagicMock()
        mock_retriever.add_documents(documents)
        mock_retriever.add_documents.assert_called_once()
# ================================================================
# 场景 10: 上下文质量验证
# ================================================================
class TestContextQuality:
    """Check the quality of the context the LLM receives for various recall mixes."""

    def test_milvus_topk_ordering_preserved(self):
        """The ranking order of Milvus top-k hits must survive into the context."""
        docs = [FakeDocument(page_content=f"排名{i}的文档") for i in range(1, 6)]
        milvus = make_milvus_mock(docs)
        llm = make_llm_mock()
        # Low-confidence Cypher reply keeps the Neo4j path out of the picture.
        perform_rag_and_llm_testable(
            "test", milvus, make_pdf_mock(content=""), make_neo4j_driver_mock([]),
            llm, requests_module=make_requests_mock(generate_response={
                "cypher_query": None, "confidence": 0.1, "validated": False
            })
        )
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        pos1 = prompt.find("排名1的文档")
        pos5 = prompt.find("排名5的文档")
        assert pos1 < pos5, "Milvus 排名顺序应被保留"

    def test_three_sources_have_correct_order(self):
        """Context concatenation order must be Milvus → PDF → Neo4j."""
        milvus = make_milvus_mock([FakeDocument(page_content="AAA_MILVUS")])
        pdf = make_pdf_mock(content="BBB_PDF")
        neo4j = make_neo4j_driver_mock([("CCC_NEO4J",)])
        llm = make_llm_mock()
        perform_rag_and_llm_testable(
            "test", milvus, pdf, neo4j, llm, requests_module=make_requests_mock()
        )
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        pos_a = prompt.find("AAA_MILVUS")
        pos_b = prompt.find("BBB_PDF")
        pos_c = prompt.find("CCC_NEO4J")
        assert pos_a < pos_b < pos_c, "上下文顺序应为 Milvus → PDF → Neo4j"

    def test_duplicate_content_not_deduplicated(self):
        """The current implementation does not deduplicate — pin that behavior (future optimization)."""
        same_content = "高血压要低盐饮食"
        milvus = make_milvus_mock([FakeDocument(page_content=same_content)])
        pdf = make_pdf_mock(content=same_content)
        llm = make_llm_mock()
        perform_rag_and_llm_testable(
            "test", milvus, pdf, make_neo4j_driver_mock([]),
            llm, requests_module=make_requests_mock(generate_response={
                "cypher_query": None, "confidence": 0.1, "validated": False
            })
        )
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        assert prompt.count(same_content) == 2, "当前不去重, 内容出现两次"
# ================================================================
if __name__ == "__main__":
    # Allow running this file directly (python test2.py) instead of via pytest CLI.
    pytest.main([__file__, "-v", "--tb=short"])