| """ |
| ================================================================ |
| 医疗 RAG Agent 回归测试 — 防止已修复问题复现 & 升级不退化 |
| ================================================================ |
| 测试层级: |
| 单元测试 (test1.py): 单工具调用准确性 ✅ 67 passed |
| 集成测试 (test2.py): 多步骤工具链协作 ✅ 37 passed |
| 回归测试 (test3.py): 防退化 & 边界守护 ← 当前文件 |
| 压力测试: Locust 待做 |
| 安全红队测试: test4.py 待做 |
| |
| 回归测试 vs 其他测试: |
| 单元测试问: "每个工具自己对不对?" |
| 集成测试问: "工具串起来后对不对?" |
| 回归测试问: "改了代码之后, 以前对的还对不对?" |
| |
| 回归场景: |
| 场景 1: 输入边界守护 (特殊字符 / 超长输入 / Unicode / 空白) |
| 场景 2: API 契约稳定性 (响应格式不变) |
| 场景 3: Redis 缓存边界 (Key 碰撞 / 编码 / 特殊字符) |
| 场景 4: Neo4j 防御 (Cypher 注入 / 畸形响应 / 边界置信度) |
| 场景 5: Prompt 格式稳定性 (系统提示 / 上下文标签 / 不丢字段) |
| 场景 6: 数据处理回归 (JSONL 脏数据 / PDF 空页 / 编码问题) |
| 场景 7: 模型切换回归 (换模型后格式/参数不变) |
| 场景 8: 已知 Bad Case 集合 (历史上出过的 bug) |
| |
| 运行: |
| pytest test3.py -v --tb=short |
| pytest test3.py -v -k "input_boundary" # 只跑输入边界 |
| pytest test3.py -v -k "contract" # 只跑 API 契约 |
| pytest test3.py -v -k "bad_case" # 只跑已知 Bad Case |
| ================================================================ |
| """ |
|
|
| import sys |
| import os |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) |
|
|
| import types |
| import pytest |
| import json |
| import hashlib |
| import time |
| import uuid |
| import random |
| import datetime |
| from unittest.mock import MagicMock, patch |
| from dataclasses import dataclass, field |
| from typing import Optional |
|
|
|
|
| |
| |
| |
|
|
| def _ensure_mock_module(name): |
| if name not in sys.modules: |
| sys.modules[name] = MagicMock() |
|
|
# Stub out heavy third-party dependencies (LangChain, Neo4j, FastAPI, ...) with
# MagicMock modules so that project modules (new_redis, vector, ...) can be
# imported below without those packages being installed.
for mod in [
    "langchain_classic", "langchain_classic.retrievers",
    "langchain_classic.retrievers.parent_document_retriever",
    "langchain_milvus", "langchain_text_splitters",
    "langchain_core", "langchain_core.stores", "langchain_core.documents",
    "langchain.embeddings", "langchain.embeddings.base",
    "neo4j", "dotenv", "uvicorn",
    "fastapi", "fastapi.middleware", "fastapi.middleware.cors",
]:
    _ensure_mock_module(mod)


# A MagicMock attribute cannot be used as a base class, so expose a real class
# as the Embeddings symbol of the mocked langchain module.
# NOTE(review): presumably some project module subclasses Embeddings — confirm
# against vector.py.
class _FakeEmbeddingsBase:
    pass
sys.modules["langchain.embeddings.base"].Embeddings = _FakeEmbeddingsBase
|
|
|
|
| |
| |
| |
|
|
@dataclass
class FakeDocument:
    """Minimal stand-in for a LangChain Document: text content plus a metadata dict."""
    page_content: str
    metadata: dict = field(default_factory=dict)
|
|
class FakeChatResponse:
    """Mimics the OpenAI chat-completion response shape: resp.choices[0].message.content."""

    def __init__(self, content):
        message = type('Msg', (), {'content': content})()
        self.choices = [type('Choice', (), {'message': message})()]
|
|
class FakeRedisClient:
    """In-memory imitation of the redis client surface the wrapper uses.

    Values live in ``_store``; requested TTLs are recorded (never enforced)
    in ``_expiry`` so tests can inspect the expiry a caller asked for.
    """

    def __init__(self):
        self._store = {}
        self._expiry = {}

    def ping(self):
        return True

    def get(self, key):
        return self._store.get(key)

    def set(self, key, value, ex=None, nx=False):
        # NX semantics: refuse to overwrite an existing key.
        if nx and key in self._store:
            return False
        self._store[key] = value
        if ex:
            self._expiry[key] = ex
        return True

    def setex(self, key, expire, value):
        self._store[key] = value
        self._expiry[key] = expire
        return True

    def delete(self, key):
        removed = self._store.pop(key, None)
        return 1 if removed is not None else 0

    def register_script(self, script):
        # Imitates the compare-and-delete unlock script: the key is removed
        # only when its current value equals the caller-supplied token.
        def unlock(keys=None, args=None):
            if keys and args and self._store.get(keys[0]) == args[0]:
                del self._store[keys[0]]
                return 1
            return 0
        return unlock
|
|
def make_redis_manager():
    """Build a RedisClientWrapper backed by the in-memory FakeRedisClient.

    ``object.__new__`` bypasses ``__init__`` — presumably to avoid real Redis
    connection setup (TODO confirm against new_redis.py); ``_pool`` is pre-set
    so any pool-initialisation path sees an existing pool.
    NOTE(review): relies on the wrapper exposing ``client`` / ``unlock_script``
    attributes — verify against RedisClientWrapper's implementation.
    """
    from new_redis import RedisClientWrapper
    RedisClientWrapper._pool = "FAKE"
    mgr = object.__new__(RedisClientWrapper)
    mgr.client = FakeRedisClient()
    mgr.unlock_script = mgr.client.register_script("")
    return mgr
|
|
def make_llm_mock(answer="默认回答"):
    """Return a MagicMock OpenAI client whose chat completion always yields *answer*."""
    client = MagicMock()
    client.chat.completions.create.return_value = FakeChatResponse(answer)
    return client
|
|
def make_milvus_mock(docs=None, raise_error=False):
    """Fake Milvus vector store.

    Args:
        docs: documents returned by similarity_search. ``None`` (the default)
            yields one placeholder FakeDocument; an explicit list — including
            the empty list — is returned as-is.
        raise_error: when True, similarity_search raises ConnectionError.

    Fix: previously ``docs or [default]`` meant ``docs=[]`` silently fell back
    to the default document, so test_badcase_empty_milvus_result never actually
    exercised an empty retrieval result.
    """
    mock = MagicMock()
    if raise_error:
        mock.similarity_search.side_effect = ConnectionError("Milvus timeout")
    elif docs is not None:
        # Honor the caller's list exactly, even when it is empty.
        mock.similarity_search.return_value = docs
    else:
        mock.similarity_search.return_value = [
            FakeDocument(page_content="默认Milvus召回内容")
        ]
    return mock
|
|
def make_pdf_mock(content=None, raise_error=False):
    """Fake parent-document retriever; ``content=""`` simulates an empty PDF extraction."""
    retriever = MagicMock()
    if raise_error:
        retriever.invoke.side_effect = Exception("PDF error")
        return retriever
    if content == "":
        # Empty string means "no pages extracted".
        retriever.invoke.return_value = []
        return retriever
    text = content or "默认PDF内容"
    retriever.invoke.return_value = [FakeDocument(page_content=text)]
    return retriever
|
|
def make_neo4j_driver_mock(results=None):
    """Fake neo4j driver: ``with driver.session() as s: s.run(...)`` yields *results* (default [])."""
    session = MagicMock()
    session.run.return_value = results or []
    driver = MagicMock()
    # Wire the context-manager protocol so the `with` statement hands back our session.
    ctx = driver.session.return_value
    ctx.__enter__ = MagicMock(return_value=session)
    ctx.__exit__ = MagicMock(return_value=False)
    return driver
|
|
def make_requests_mock(generate_resp=None, validate_resp=None, error=False):
    """Fake ``requests`` module for the Cypher generate/validate service.

    Successive ``post`` calls return the generate response first, then the
    validate response.  ``error=True`` makes every call raise ConnectionError.
    """
    requests_mod = MagicMock()
    if error:
        requests_mod.post.side_effect = ConnectionError("service down")
        return requests_mod

    generate = MagicMock()
    generate.status_code = 200
    generate.json.return_value = generate_resp or {
        "cypher_query": "MATCH (d) RETURN d", "confidence": 0.95, "validated": True
    }

    validate = MagicMock()
    validate.status_code = 200
    validate.json.return_value = validate_resp or {"is_valid": True}

    # First post -> generate response, second post -> validate response.
    requests_mod.post.side_effect = [generate, validate]
    return requests_mod
|
|
def perform_rag_testable(query, milvus, pdf, neo4j_driver, llm, requests_module=None):
    """Dependency-injected version of perform_rag_and_llm (identical to test2.py).

    Runs three retrieval channels — Milvus vector search, PDF parent-document
    retrieval, and a Neo4j lookup via a Cypher generate/validate service —
    concatenates their results into one context string, and asks the LLM.
    Each channel degrades to an empty string on failure instead of raising.
    """
    import json as _json
    if requests_module is None:
        import requests as requests_module

    # Channel 1: Milvus hybrid search; any failure degrades to empty context.
    try:
        results = milvus.similarity_search(query, k=10, ranker_type="rrf", ranker_params={"k": 100})
        context = "\n\n".join(d.page_content for d in results) if results else ""
    except Exception:
        context = ""

    # Channel 2: PDF parent-document retriever; best-effort, failures ignored.
    pdf_res = ""
    try:
        docs = pdf.invoke(query)
        if docs and len(docs) >= 1:
            pdf_res = docs[0].page_content
    except Exception:
        pass
    context = context + "\n" + pdf_res

    # Channel 3: text-to-Cypher service.  The generated query is executed only
    # when it is non-empty, confidence >= 0.9, the generator flags it as
    # validated, AND a second /validate call confirms it — this is the
    # injection-defence gate that TestNeo4jDefenseRegression asserts on.
    neo4j_res = ""
    try:
        resp = requests_module.post("http://0.0.0.0:8101/generate", _json.dumps({"natural_language_query": query}))
        if resp.status_code == 200:
            d = resp.json()
            if d["cypher_query"] and float(d["confidence"]) >= 0.9 and d["validated"]:
                vresp = requests_module.post("http://0.0.0.0:8101/validate", _json.dumps({"cypher_query": d["cypher_query"]}))
                if vresp.status_code == 200 and vresp.json()["is_valid"]:
                    with neo4j_driver.session() as session:
                        # Inner try so a database failure degrades to "" while
                        # the session context manager still exits cleanly.
                        try:
                            record = session.run(d["cypher_query"])
                            neo4j_res = ','.join(list(map(lambda x: x[0], record)))
                        except Exception:
                            neo4j_res = ""
    except Exception:
        pass
    context = context + "\n" + neo4j_res

    # Prompt assembly: TestPromptFormatRegression asserts this exact tag
    # structure, so keep it byte-for-byte stable.
    SYSTEM = "System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案."
    USER = f"""User: 利用介于<context>和</context>之间的从数据库中检索出的信息来回答问题.
<context>
{context}
</context>
<question>
{query}
</question>"""

    response = llm.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": SYSTEM + USER}],
        temperature=0.7,
    )
    return response.choices[0].message.content
|
|
def simulate_chatbot(request_body, redis_mgr, rag_func):
    """Simulate the chatbot endpoint: parse body, validate, answer via the cache.

    Returns a dict with response/status/time on success, status/error otherwise.
    """
    try:
        data = json.loads(request_body) if isinstance(request_body, str) else request_body
        query = data.get('question')
        # NOTE: a whitespace-only question passes this check — known gap,
        # deliberately preserved (see test_whitespace_only_input_is_known_gap).
        if not query:
            return {"status": 400, "error": "Question is required"}
        answer = redis_mgr.get_or_compute(query, lambda: rag_func(query))
        return {
            "response": answer,
            "status": 200,
            "time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        }
    except Exception as e:
        return {"status": 500, "error": str(e)}
|
|
|
|
| |
| |
| |
|
|
class TestInputBoundary:
    """
    Regression focus: abnormal input of any kind must never crash the system.
    Lessons learned: special characters once broke JSON serialization, and
    over-long input once triggered an OOM.
    """


    def test_chinese_medical_terms(self):
        """Chinese medical terminology (including uncommon characters)."""
        queries = [
            "急性心肌梗死的溶栓治疗方案",
            "肾上腺嗜铬细胞瘤的鉴别诊断",
            "2型糖尿病患者的胰岛素泵使用指南",
            "幽门螺杆菌(Hp)的根除方案",
        ]
        for q in queries:
            llm = make_llm_mock(f"回答: {q}")
            # Cypher service is down (error=True): the pipeline must still answer.
            result = perform_rag_testable(
                q, make_milvus_mock(), make_pdf_mock(), make_neo4j_driver_mock(),
                llm, requests_module=make_requests_mock(error=True)
            )
            assert isinstance(result, str) and len(result) > 0, f"输入 '{q}' 不应崩溃"


    def test_special_characters(self):
        """Special characters: quotes, newlines, tabs, HTML tags, NUL bytes."""
        dangerous_inputs = [
            '高血压"用药"方案',
            "高血压'用药'方案",
            "高血压\n用药\n方案",
            "高血压\t用药",
            "<script>alert('xss')</script>",
            "高血压&糖尿病",
            "DROP TABLE diseases;--",
            '{"key": "value"}',
            "高血压\x00用药",
        ]
        for inp in dangerous_inputs:
            mgr = make_redis_manager()
            resp = simulate_chatbot(
                {"question": inp}, mgr, lambda q: "安全回答"
            )
            assert resp["status"] == 200, f"输入 '{repr(inp)}' 不应崩溃"


    def test_very_long_input(self):
        """Over-long input (9000 characters)."""
        long_query = "高血压" * 3000
        llm = make_llm_mock("长输入的回答")
        result = perform_rag_testable(
            long_query, make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock(error=True)
        )
        assert isinstance(result, str)


    def test_empty_and_whitespace_inputs(self):
        """Empty / missing question must return 400."""
        mgr = make_redis_manager()
        for inp in ["", None]:
            body = {"question": inp} if inp is not None else {}
            resp = simulate_chatbot(body, mgr, lambda q: "不应到这里")
            assert resp["status"] == 400, f"空输入 {repr(inp)} 应返回 400"


    def test_whitespace_only_input_is_known_gap(self):
        """
        ⚠️ Known gap: a whitespace-only " " question passes the `if not query` check.
        Cause: " " is truthy in Python, and agent4.py never calls .strip().
        Suggested fix: change agent4.py L273-275 to `if not query or not query.strip()`.
        This test records the CURRENT behavior; after the fix it should be
        changed to assert status == 400.
        """
        mgr = make_redis_manager()
        resp = simulate_chatbot({"question": " "}, mgr, lambda q: "空白回答")

        assert resp["status"] == 200, "当前未做 strip, 纯空格会通过校验 (已知缺陷)"


    def test_unicode_emoji_input(self):
        """Emoji and unusual Unicode code points."""
        queries = ["高血压😷怎么治疗", "🏥心脏病💊", "测试️⃣🔬"]
        for q in queries:
            mgr = make_redis_manager()
            resp = simulate_chatbot({"question": q}, mgr, lambda q: "emoji回答")
            assert resp["status"] == 200


    def test_mixed_language_input(self):
        """Mixed Chinese / English / Korean / Japanese input."""
        query = "高血压 hypertension 고혈압 高血圧 treatment方案は?"
        llm = make_llm_mock("混合语言回答")
        result = perform_rag_testable(
            query, make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock(error=True)
        )
        assert isinstance(result, str)


    def test_numeric_only_input(self):
        """Digits-only input."""
        mgr = make_redis_manager()
        resp = simulate_chatbot({"question": "12345678"}, mgr, lambda q: "数字回答")
        assert resp["status"] == 200


    def test_single_character_input(self):
        """Single-character input."""
        mgr = make_redis_manager()
        resp = simulate_chatbot({"question": "高"}, mgr, lambda q: "单字回答")
        assert resp["status"] == 200
|
|
|
|
| |
| |
| |
|
|
class TestAPIContractStability:
    """
    Regression focus: the API response format must not break as code changes —
    the frontend depends on the fixed response/status/time fields.
    """


    def test_success_response_has_all_fields(self):
        """A success response must carry response + status + time."""
        mgr = make_redis_manager()
        resp = simulate_chatbot({"question": "高血压"}, mgr, lambda q: "回答")


        assert "response" in resp, "必须有 response 字段"
        assert "status" in resp, "必须有 status 字段"
        assert "time" in resp, "必须有 time 字段"


    def test_success_response_types(self):
        """Field type contract: response=str, status=int, time=str."""
        mgr = make_redis_manager()
        resp = simulate_chatbot({"question": "高血压"}, mgr, lambda q: "回答")


        assert isinstance(resp["response"], str)
        assert isinstance(resp["status"], int)
        assert isinstance(resp["time"], str)


    def test_success_status_is_200(self):
        """On success, status must be exactly 200."""
        mgr = make_redis_manager()
        resp = simulate_chatbot({"question": "test"}, mgr, lambda q: "ok")
        assert resp["status"] == 200


    def test_error_response_has_status_and_error(self):
        """An error response must carry status + error."""
        mgr = make_redis_manager()
        resp = simulate_chatbot({"question": ""}, mgr, lambda q: "")
        assert resp["status"] == 400
        assert "error" in resp


    def test_time_format_is_valid(self):
        """The time field must use the YYYY-MM-DD HH:MM:SS format."""
        mgr = make_redis_manager()
        resp = simulate_chatbot({"question": "test"}, mgr, lambda q: "ok")


        time_str = resp["time"]
        try:
            parsed = datetime.datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
            assert parsed.year >= 2024
        except ValueError:
            pytest.fail(f"time 格式不合法: {time_str}")


    def test_response_not_none(self):
        """The response field must never be None."""
        mgr = make_redis_manager()
        resp = simulate_chatbot({"question": "test"}, mgr, lambda q: "有内容的回答")
        assert resp["response"] is not None


    def test_double_encoded_json_still_works(self):
        """Double JSON encoding: the frontend may accidentally encode twice."""
        mgr = make_redis_manager()
        double = json.dumps({"question": "高血压"})
        resp = simulate_chatbot(double, mgr, lambda q: "回答")
        assert resp["status"] == 200
|
|
|
|
| |
| |
| |
|
|
class TestRedisCacheRegression:
    """
    Regression focus: cache keys must not collide, special characters must not
    corrupt entries, and TTLs must stay within a sane range.
    """


    def test_similar_questions_different_cache(self):
        """Similar-but-different questions must map to distinct cache entries."""
        mgr = make_redis_manager()
        mgr.set_answer("高血压吃什么?", "答案A")
        mgr.set_answer("高血压吃什么好?", "答案B")


        assert mgr.get_answer("高血压吃什么?") == "答案A"
        assert mgr.get_answer("高血压吃什么好?") == "答案B"


    def test_question_with_special_chars_cacheable(self):
        """Questions containing special characters cache and read back correctly."""
        mgr = make_redis_manager()
        special_questions = [
            '引号"测试"',
            "换行\n测试",
            "斜杠/反斜杠\\测试",
            "空格 多个 空格",
            "emoji🏥测试",
        ]
        for i, q in enumerate(special_questions):
            mgr.set_answer(q, f"答案_{i}")
            assert mgr.get_answer(q) == f"答案_{i}", f"特殊字符缓存失败: {repr(q)}"


    def test_ttl_within_reasonable_range(self):
        """Expiry must fall within [3240, 3960], i.e. 3600 ± 10% jitter."""
        mgr = make_redis_manager()
        for i in range(50):
            mgr.set_answer(f"ttl_test_{i}", "v", expire_time=3600)
            key = mgr._generate_key(f"ttl_test_{i}")
            # FakeRedisClient records (but never enforces) the requested TTL.
            ttl = mgr.client._expiry.get(key)
            assert 3240 <= ttl <= 3960, f"TTL {ttl} 超出 ±10% 范围"


    def test_cache_key_deterministic_across_calls(self):
        """The same question must always hash to the exact same key."""
        mgr = make_redis_manager()
        q = "高血压不能吃什么食物?"
        keys = [mgr._generate_key(q) for _ in range(100)]
        assert len(set(keys)) == 1, "同一问题的 Key 必须确定性一致"


    def test_cache_key_prefix_consistency(self):
        """Every generated key must keep the same "llm:cache:" prefix."""
        mgr = make_redis_manager()
        for q in ["Q1", "Q2", "Q3", "高血压", "糖尿病"]:
            key = mgr._generate_key(q)
            assert key.startswith("llm:cache:"), f"Key 前缀异常: {key}"


    def test_overwrite_existing_cache(self):
        """Re-writing the same question keeps the latest answer."""
        mgr = make_redis_manager()
        mgr.set_answer("Q", "旧答案")
        mgr.set_answer("Q", "新答案")
        assert mgr.get_answer("Q") == "新答案"


    def test_get_or_compute_lock_released_after_exception(self):
        """After compute_func raises, the lock must be released (finally block)."""
        mgr = make_redis_manager()


        def exploding():
            raise RuntimeError("GPU OOM")


        try:
            mgr.get_or_compute("爆炸问题", exploding)
        except RuntimeError:
            pass


        # If the lock leaked, this acquire would time out and return None.
        hash_key = hashlib.md5("爆炸问题".encode('utf-8')).hexdigest()
        token = mgr.acquire_lock(hash_key, acquire_timeout=0.1)
        assert token is not None, "异常后锁应被释放"
|
|
|
|
| |
| |
| |
|
|
class TestNeo4jDefenseRegression:
    """
    Regression focus: Cypher injection, malformed service responses, and
    boundary confidence values.
    """


    def test_cypher_injection_blocked_by_validation(self):
        """A destructive Cypher statement must be stopped by the /validate step."""
        malicious_cypher = "MATCH (n) DETACH DELETE n"
        req = make_requests_mock(
            generate_resp={
                "cypher_query": malicious_cypher,
                "confidence": 0.95,
                "validated": True,
            },
            validate_resp={"is_valid": False},
        )
        neo4j = make_neo4j_driver_mock()
        llm = make_llm_mock()


        perform_rag_testable(
            "删除所有", make_milvus_mock(), make_pdf_mock(),
            neo4j, llm, requests_module=req,
        )


        # The malicious query must never reach the database session.
        neo4j.session.return_value.__enter__.return_value.run.assert_not_called()


    def test_confidence_string_type_handled(self):
        """confidence arriving as the string "0.95" must survive float() conversion."""
        req = make_requests_mock(generate_resp={
            "cypher_query": "MATCH (d) RETURN d",
            "confidence": "0.95",
            "validated": True,
        })
        llm = make_llm_mock()

        result = perform_rag_testable(
            "test", make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock([("结果",)]), llm, requests_module=req,
        )
        assert isinstance(result, str)


    def test_missing_fields_in_cypher_response(self):
        """A generate response missing fields must not crash the pipeline."""
        gen = MagicMock(); gen.status_code = 200
        gen.json.return_value = {"cypher_query": None}


        req = MagicMock()
        req.post.return_value = gen
        llm = make_llm_mock("降级回答")


        # A falsy cypher_query short-circuits before confidence/validated are read.
        result = perform_rag_testable(
            "test", make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=req,
        )
        assert isinstance(result, str)


    def test_boundary_confidence_values(self):
        """Execution decision at confidence 0.89 / 0.90 / 0.91 (threshold is >= 0.9)."""
        for conf, should_exec in [(0.89, False), (0.90, True), (0.91, True)]:
            req = make_requests_mock(generate_resp={
                "cypher_query": "MATCH (d) RETURN d",
                "confidence": conf,
                "validated": True,
            })
            neo4j = make_neo4j_driver_mock([("结果",)])
            llm = make_llm_mock()


            perform_rag_testable(
                "test", make_milvus_mock(), make_pdf_mock(),
                neo4j, llm, requests_module=req,
            )


            sess = neo4j.session.return_value.__enter__.return_value
            if should_exec:
                sess.run.assert_called()
            else:
                sess.run.assert_not_called()


    def test_neo4j_returns_non_string_results(self):
        """Joining Neo4j row values must not crash.

        NOTE(review): despite the name, the fixture rows hold the *strings*
        "120"/"80", not ints/bools — genuinely non-string values would make
        ','.join raise (downgraded to "" by the inner except in the pipeline).
        """
        neo4j = make_neo4j_driver_mock([("120",), ("80",)])
        req = make_requests_mock()
        llm = make_llm_mock()


        result = perform_rag_testable(
            "血压", make_milvus_mock(), make_pdf_mock(),
            neo4j, llm, requests_module=req,
        )
        assert isinstance(result, str)
|
|
|
|
| |
| |
| |
|
|
class TestPromptFormatRegression:
    """
    Regression focus: after prompt-template changes, no key structural element
    may go missing.  LLMs are highly sensitive to prompt format, so format
    regressions are a high-frequency bug class.
    """


    def test_prompt_has_system_role(self):
        """The system prompt must contain the '医学助手' role definition."""
        llm = make_llm_mock()
        perform_rag_testable(
            "test", make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock(error=True),
        )
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        assert "医学助手" in prompt, "系统提示必须包含角色定义"


    def test_prompt_has_context_tags(self):
        """The <context></context> tag pair must be present."""
        llm = make_llm_mock()
        perform_rag_testable(
            "test", make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock(error=True),
        )
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        assert "<context>" in prompt and "</context>" in prompt


    def test_prompt_has_question_tags(self):
        """The <question></question> tag pair must be present."""
        llm = make_llm_mock()
        perform_rag_testable(
            "test", make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock(error=True),
        )
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        assert "<question>" in prompt and "</question>" in prompt


    def test_question_appears_after_context(self):
        """The <question> tag must come after </context>."""
        llm = make_llm_mock()
        perform_rag_testable(
            "原始问题", make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock(error=True),
        )
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        ctx_pos = prompt.find("</context>")
        q_pos = prompt.find("<question>")
        assert ctx_pos < q_pos, "<question> 必须在 </context> 之后"


    def test_user_query_not_altered_in_prompt(self):
        """The user's original question must appear verbatim in the prompt."""
        original = "宫腔异形的治疗方案有哪些?"
        llm = make_llm_mock()
        perform_rag_testable(
            original, make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock(error=True),
        )
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        assert original in prompt, "用户原始问题不应被修改"


    def test_context_contains_retrieval_content(self):
        """Retrieved content must land between the <context> tags."""
        milvus = make_milvus_mock([FakeDocument(page_content="特定内容_ABC123")])
        llm = make_llm_mock()
        perform_rag_testable(
            "test", milvus, make_pdf_mock(content=""),
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock(error=True),
        )
        prompt = llm.chat.completions.create.call_args.kwargs["messages"][0]["content"]
        # Slice out only the text between the outermost context tags.
        start = prompt.find("<context>") + len("<context>")
        end = prompt.rfind("</context>")
        context_content = prompt[start:end]
        assert "特定内容_ABC123" in context_content
|
|
|
|
| |
| |
| |
|
|
class TestDataProcessingRegression:
    """
    Regression focus: dirty data, encoding issues and null values must not
    break the ingestion pipeline.
    """


    def test_jsonl_with_extra_fields_ignored(self):
        """Extra fields on a JSONL line do not affect parsing."""
        record = json.loads(json.dumps({
            "query": "Q", "response": "A",
            "extra_field": "ignored", "score": 0.95
        }))
        assert record["query"] + "\n" + record["response"] == "Q\nA"


    def test_jsonl_with_html_in_content(self):
        """JSONL content containing HTML tags survives a round-trip."""
        record = json.loads(json.dumps({
            "query": "<b>高血压</b>怎么治?",
            "response": "<p>建议低盐饮食</p>"
        }))
        joined = "\n".join([record["query"], record["response"]])
        assert "<b>高血压</b>" in joined


    def test_jsonl_with_very_long_response(self):
        """A very long response (10KB+) in JSONL parses intact."""
        big = "详细医学解释。" * 2000
        record = json.loads(json.dumps({"query": "Q", "response": big}))
        assert len(record["response"]) > 10000


    def test_jsonl_with_unicode_escapes(self):
        """JSONL with \\uXXXX escape sequences decodes to the right characters."""
        record = json.loads('{"query": "\\u9ad8\\u8840\\u538b", "response": "\\u4f4e\\u76d0"}')
        assert record["query"] == "高血压"
        assert record["response"] == "低盐"


    def test_pdf_processor_handles_empty_pages(self, tmp_path):
        """Empty pages extracted from a PDF must not break overall processing."""
        import pandas as pd
        frame = pd.DataFrame({
            "file_name": ["doc.pdf"] * 3,
            "page_number": [1, 2, 3],
            "text_content": ["正常内容", None, "另一页内容"]
        })
        # dropna keeps only the two pages with real text.
        assert len(frame.dropna(subset=["text_content"])) == 2


    def test_pdf_processor_handles_nan_text(self, tmp_path):
        """NaN text extracted from a PDF must be filtered out."""
        import pandas as pd
        frame = pd.DataFrame({
            "text_content": ["正常", float('nan'), "也正常", None]
        })
        assert len(frame.dropna(subset=["text_content"])) == 2
|
|
|
|
| |
| |
| |
|
|
class TestModelSwitchRegression:
    """
    Regression focus: after swapping the LLM, call parameters and message
    format must remain unchanged.
    """


    def test_model_name_parameter(self):
        """Currently gpt-4o-mini; the model name must be forwarded as-is."""
        llm = make_llm_mock()
        perform_rag_testable(
            "test", make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock(error=True),
        )
        model = llm.chat.completions.create.call_args.kwargs["model"]
        assert model == "gpt-4o-mini"


    def test_temperature_parameter(self):
        """temperature must stay at 0.7."""
        llm = make_llm_mock()
        perform_rag_testable(
            "test", make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock(error=True),
        )
        temp = llm.chat.completions.create.call_args.kwargs["temperature"]
        assert temp == 0.7


    def test_messages_format_single_user_message(self):
        """messages must be a list holding exactly one user-role message."""
        llm = make_llm_mock()
        perform_rag_testable(
            "test", make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock(error=True),
        )
        msgs = llm.chat.completions.create.call_args.kwargs["messages"]
        assert len(msgs) == 1
        assert msgs[0]["role"] == "user"
        assert isinstance(msgs[0]["content"], str)


    def test_embedding_model_name(self):
        """The embedding model must be text-embedding-3-small."""
        from vector import OpenAIEmbeddings
        emb = object.__new__(OpenAIEmbeddings)  # skip __init__ so no real API client is built
        mock_client = MagicMock()
        mock_client.embeddings.create.return_value = type('R', (), {
            'data': [type('E', (), {'embedding': [0.1] * 1536})()]
        })()
        emb.client = mock_client


        emb.embed_query("test")


        call_kwargs = mock_client.embeddings.create.call_args.kwargs
        assert call_kwargs["model"] == "text-embedding-3-small"
|
|
|
|
| |
| |
| |
|
|
class TestKnownBadCases:
    """
    Regression core: bugs that actually happened — each case represents one
    production incident.  Every time a bug is fixed a case is added here,
    and cases are never deleted.
    """


    def test_badcase_double_json_encoding(self):
        """
        Bad Case #1: frontend double-JSON-encoded the request body.
        Symptom: body arrived as a JSON *string* of JSON, so parsing failed.
        Fix: agent4.py L268 - isinstance(json_post_raw, str) check.
        """
        mgr = make_redis_manager()
        inner = json.dumps({"question": "高血压不能吃什么?"})
        resp = simulate_chatbot(inner, mgr, lambda q: "正确回答")
        assert resp["status"] == 200, "双重编码应能正确解析"


    def test_badcase_neo4j_none_cypher(self):
        """
        Bad Case #2: Cypher generation returned null.
        Symptom: cypher_query was None, and a direct .join() raised TypeError.
        Fix: agent4.py L208 - cypher_query is not None check.
        """
        req = make_requests_mock(generate_resp={
            "cypher_query": None,
            "confidence": 0.95,
            "validated": True,
        })
        llm = make_llm_mock("回答")


        # Falsy cypher_query must skip the whole Neo4j branch.
        result = perform_rag_testable(
            "test", make_milvus_mock(), make_pdf_mock(),
            make_neo4j_driver_mock(), llm, requests_module=req,
        )
        assert isinstance(result, str)


    def test_badcase_empty_milvus_result(self):
        """
        Bad Case #3: Milvus returned an empty list.
        Symptom: format_docs([]) did not raise, but context was empty.
        Fix: agent4.py L174 - if recall_rerank_milvus check.
        """
        milvus = make_milvus_mock(docs=[])
        llm = make_llm_mock("经验回答")
        result = perform_rag_testable(
            "test", milvus, make_pdf_mock(content=""),
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock(error=True),
        )
        assert isinstance(result, str)


    def test_badcase_pdf_returns_none(self):
        """
        Bad Case #4: parent_retriever.invoke() returned None.
        Symptom: len(None) raised TypeError.
        Fix: agent4.py L185 - retrieved_docs is not None check.
        """
        pdf = MagicMock()
        pdf.invoke.return_value = None
        llm = make_llm_mock("回答")


        result = perform_rag_testable(
            "test", make_milvus_mock(), pdf,
            make_neo4j_driver_mock(), llm, requests_module=make_requests_mock(error=True),
        )
        assert isinstance(result, str)


    def test_badcase_redis_connection_lost_mid_request(self):
        """
        Bad Case #5: Redis dropped the connection mid-request.
        Symptom: RedisError during get_answer — should yield None, not crash.
        """
        mgr = make_redis_manager()
        # First get() simulates the dropped connection by returning None.
        original_get = mgr.client.get
        call_count = 0


        def flaky_get(key):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return None
            return original_get(key)


        mgr.client.get = flaky_get
        result = mgr.get_or_compute("flaky_q", lambda: "降级回答")
        assert result == "降级回答"


    def test_badcase_neo4j_session_exception(self):
        """
        Bad Case #6: session.run raised ServiceUnavailable.
        Symptom: an exception inside `with driver.session()` left the lock held.
        Fix: agent4.py L220-227 - try-except placed inside the session block.
        """
        neo4j = make_neo4j_driver_mock()
        sess = neo4j.session.return_value.__enter__.return_value
        sess.run.side_effect = Exception("ServiceUnavailable")


        llm = make_llm_mock("降级回答")
        result = perform_rag_testable(
            "test", make_milvus_mock(), make_pdf_mock(),
            neo4j, llm, requests_module=make_requests_mock(),
        )
        assert result == "降级回答"


    def test_badcase_concurrent_same_key_race_condition(self):
        """
        Bad Case #7: race condition on concurrent writes to the same key.
        Symptom: two threads both missed the cache and both called the LLM,
        wasting resources.
        Fix: new_redis.py - acquire_lock + double-check.
        """
        mgr = make_redis_manager()


        # Simulate the double-check: the first get_answer misses, the second
        # (after acquiring the lock) sees the value another thread just wrote,
        # so compute must NOT run.
        call_count = 0
        def patched_get(q):
            nonlocal call_count; call_count += 1
            if call_count <= 1: return None
            return "其他线程的结果"


        mgr.get_answer = patched_get
        result = mgr.get_or_compute("race_q", lambda: "不应到这里")
        assert result == "其他线程的结果"


    def test_badcase_question_field_is_list(self):
        """
        Bad Case #8: the question field is not a string (e.g. a list).
        Symptom: {"question": ["Q1", "Q2"]} crashed downstream processing.
        """
        mgr = make_redis_manager()
        resp = simulate_chatbot(
            {"question": ["Q1", "Q2"]}, mgr, lambda q: "不应到这里"
        )
        # Any handled status is acceptable; an unhandled traceback is not.
        assert resp["status"] in [200, 400, 500], "不应导致未处理的崩溃"
|
|
|
|
| |
# Allow running the regression suite directly: `python test3.py`.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])