| """ |
| ================================================================ |
| 医疗 RAG Agent 单元测试 — 单工具调用准确性 |
| ================================================================ |
| 测试对象: 项目中的 6 个核心组件 ("工具") |
| |
| 工具1: Redis 缓存管理器 (new_redis.py - RedisClientWrapper) |
| 工具2: OpenAI Embedding (vector.py - OpenAIEmbeddings) |
| 工具3: Milvus 向量检索 (agent4.py - similarity_search + format_docs) |
| 工具4: PDF 父子文档检索 (agent4.py - parent_retriever) |
| 工具5: Neo4j 图数据库查询 (agent4.py - Cypher 生成 → 校验 → 执行) |
| 工具6: OpenAI LLM 推理 (agent4.py - generate_openai_answer) |
| |
| 额外覆盖: |
| 工具7: PDF 批处理器 (preprocess.py - PDFBatchProcessor) |
| 工具8: 数据预处理 (vector.py - prepare_document) |
| 工具9: 端到端 RAG 流程编排 (agent4.py - perform_rag_and_llm 逻辑) |
| |
| 测试原则: |
| ✅ 每个组件独立测试, 用 Mock/Patch 隔离外部依赖 |
| ✅ 正常路径 + 异常路径 + 边界条件 |
| ✅ 不需要真实的 Redis / Milvus / Neo4j / OpenAI 连接 |
| ✅ 用 sys.modules 拦截无法安装的第三方包 |
| |
| 运行: |
| pytest test_agent_unit.py -v --tb=short |
| pytest test_agent_unit.py -v -k "Redis" # 只跑 Redis |
| pytest test_agent_unit.py -v -k "Embedding" # 只跑 Embedding |
| pytest test_agent_unit.py -v -k "Neo4j" # 只跑 Neo4j |
| ================================================================ |
| """ |
|
|
| import sys |
| import os |
| |
| |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) |
|
|
| import types |
| import pytest |
| import json |
| import hashlib |
| import time |
| import uuid |
| import random |
| from unittest.mock import MagicMock, patch, PropertyMock |
| from pathlib import Path |
| from dataclasses import dataclass, field |
| from typing import Optional, List |
|
|
|
|
| |
| |
| |
| |
|
|
| def _ensure_mock_module(name): |
| """如果模块不存在, 注入一个 MagicMock 占位""" |
| if name not in sys.modules: |
| sys.modules[name] = MagicMock() |
|
|
| |
# Third-party modules swapped for MagicMock placeholders (unless already
# imported) so the project modules under test can be imported in isolation.
_MOCK_MODULES = [
    # LangChain family
    "langchain_classic",
    "langchain_classic.retrievers",
    "langchain_classic.retrievers.parent_document_retriever",
    "langchain_milvus",
    "langchain_text_splitters",
    "langchain_core",
    "langchain_core.stores",
    "langchain_core.documents",
    "langchain.embeddings",
    "langchain.embeddings.base",
    # Graph database driver
    "neo4j",
    # Environment loading and web-serving stack
    "dotenv",
    "uvicorn",
    "fastapi",
    "fastapi.middleware",
    "fastapi.middleware.cors",
]

for mod in _MOCK_MODULES:
    _ensure_mock_module(mod)
|
|
| |
| class _FakeEmbeddingsBase: |
| """占位基类, 让 OpenAIEmbeddings 能正常继承""" |
| pass |
|
|
| sys.modules["langchain.embeddings.base"].Embeddings = _FakeEmbeddingsBase |
|
|
|
|
| |
| |
| |
|
|
class FakeRedisClient:
    """In-memory stand-in for a redis-py client, covering the commands the tests use.

    Values are stored as given; hash keys map to a nested dict. TTLs are only
    recorded in ``_expiry`` (seconds) so tests can inspect them. Nothing ever
    actually expires.
    """

    def __init__(self):
        self._store = {}   # key -> value (or field dict for hashes)
        self._expiry = {}  # key -> last TTL set, in seconds

    def ping(self):
        return True

    def get(self, key):
        return self._store.get(key)

    def set(self, key, value, ex=None, nx=False):
        """SET key value [EX ex] [NX]; returns False when NX blocks the write."""
        if nx and key in self._store:
            return False
        self._store[key] = value
        if ex:
            self._expiry[key] = ex
        else:
            # A plain SET discards any previous TTL in real Redis.
            self._expiry.pop(key, None)
        return True

    def setex(self, key, expire, value):
        self._store[key] = value
        self._expiry[key] = expire
        return True

    def delete(self, *keys):
        """DEL key [key ...]; returns how many of the keys existed."""
        removed = 0
        for key in keys:
            if key in self._store:
                del self._store[key]
                removed += 1
            # Drop the TTL too, so a later write of the same key starts clean.
            self._expiry.pop(key, None)
        return removed

    def hset(self, name, key, value):
        """HSET name key value; returns 1 if the field is new, else 0."""
        bucket = self._store.setdefault(name, {})
        is_new = key not in bucket
        bucket[key] = value
        return int(is_new)

    def hget(self, name, key):
        return self._store.get(name, {}).get(key)

    def expire(self, key, seconds):
        """EXPIRE key seconds; like Redis, only existing keys receive a TTL."""
        if key not in self._store:
            return False
        self._expiry[key] = seconds
        return True

    def register_script(self, script):
        """Return a callable emulating the atomic compare-and-delete unlock script.

        The ``script`` source is ignored. ``keys[0]`` is deleted (with its TTL)
        only when its current value equals ``args[0]``; returns 1 on delete,
        0 otherwise.
        """
        def compare_and_delete(keys=None, args=None):
            if keys and args and self._store.get(keys[0]) == args[0]:
                del self._store[keys[0]]
                self._expiry.pop(keys[0], None)
                return 1
            return 0
        return compare_and_delete
|
|
|
|
@dataclass
class FakeDocument:
    """Minimal stand-in for a LangChain ``Document``: text plus a metadata dict."""

    page_content: str
    metadata: dict = field(default_factory=dict)  # a fresh dict per instance
|
|
|
|
class FakeEmbeddingResponse:
    """Mimics an OpenAI embeddings response, exposing ``resp.data[0].embedding``."""

    def __init__(self, embedding):
        item_cls = type('EmbObj', (), {'embedding': embedding})
        self.data = [item_cls()]
|
|
|
|
class FakeChatResponse:
    """Mimics an OpenAI chat completion, exposing ``resp.choices[0].message.content``."""

    def __init__(self, content):
        message_cls = type('Msg', (), {'content': content})
        choice_cls = type('Choice', (), {'message': message_cls()})
        self.choices = [choice_cls()]
|
|
|
|
| |
| |
| |
|
|
def make_redis_manager():
    """Build a RedisClientWrapper wired to an in-memory FakeRedisClient.

    ``__init__`` is bypassed via ``object.__new__``, so only ``client`` and
    ``unlock_script`` are populated on the returned instance.
    """
    from new_redis import RedisClientWrapper

    # Marks the class-level pool as populated; presumably this stops the
    # wrapper from building a real connection pool — TODO confirm in new_redis.py.
    RedisClientWrapper._pool = "FAKE"
    fake_client = FakeRedisClient()
    manager = object.__new__(RedisClientWrapper)
    manager.client = fake_client
    manager.unlock_script = fake_client.register_script("")
    return manager
|
|
|
|
| |
| |
| |
|
|
class TestRedisManager:
    """Tests for new_redis.py - RedisClientWrapper.

    Covers: cache key generation, cache read/write, penetration guard
    (<EMPTY> marker), avalanche guard (TTL jitter), breakdown guard
    (distributed lock) and get_or_compute.
    """

    def setup_method(self):
        self.mgr = make_redis_manager()
        self.fake = self.mgr.client

    # ---- key generation ----

    def test_key_deterministic(self):
        """Same question -> same key."""
        k1 = self.mgr._generate_key("高血压不能吃什么?")
        k2 = self.mgr._generate_key("高血压不能吃什么?")
        assert k1 == k2

    def test_key_unique(self):
        """Different questions -> different keys."""
        k1 = self.mgr._generate_key("高血压不能吃什么?")
        k2 = self.mgr._generate_key("糖尿病怎么治疗?")
        assert k1 != k2

    def test_key_has_prefix(self):
        """Keys carry the 'llm:cache:' prefix."""
        k = self.mgr._generate_key("test")
        assert k.startswith("llm:cache:")

    def test_key_is_md5(self):
        """Key suffix is the MD5 hex digest of the UTF-8 question."""
        q = "测试问题"
        k = self.mgr._generate_key(q)
        expected_hash = hashlib.md5(q.encode('utf-8')).hexdigest()
        assert k == f"llm:cache:{expected_hash}"

    # ---- basic read / write ----

    def test_set_then_get(self):
        """A written answer reads back unchanged."""
        self.mgr.set_answer("Q1", "A1")
        assert self.mgr.get_answer("Q1") == "A1"

    def test_cache_miss_returns_none(self):
        """An unknown question returns None."""
        assert self.mgr.get_answer("不存在的问题") is None

    # ---- cache penetration guard ----

    def test_empty_marker_returns_none(self):
        """The <EMPTY> placeholder makes get_answer return None."""
        key = self.mgr._generate_key("空结果问题")
        self.fake.setex(key, 60, "<EMPTY>")
        assert self.mgr.get_answer("空结果问题") is None

    def test_get_or_compute_writes_empty_on_null(self):
        """An empty LLM result is cached as <EMPTY>."""
        self.mgr.get_or_compute("空问题", lambda: "")
        key = self.mgr._generate_key("空问题")
        assert self.fake.get(key) == "<EMPTY>"

    # ---- cache avalanche guard ----

    def test_random_expiry_jitter(self):
        """Writes with the same nominal TTL end up with jittered TTLs."""
        ttls = set()
        for i in range(30):
            self.mgr.set_answer(f"Q_{i}", f"A_{i}", expire_time=3600)
            k = self.mgr._generate_key(f"Q_{i}")
            ttls.add(self.fake._expiry.get(k))
        assert len(ttls) > 1, "过期时间应存在随机抖动, 防止集体失效"

    # ---- distributed lock (cache breakdown guard) ----

    def test_lock_acquire_success(self):
        """A free lock can be acquired."""
        token = self.mgr.acquire_lock("my_lock", acquire_timeout=1)
        assert token is not None

    def test_lock_mutual_exclusion(self):
        """A second acquire of a held lock times out."""
        t1 = self.mgr.acquire_lock("excl", acquire_timeout=0.1)
        t2 = self.mgr.acquire_lock("excl", acquire_timeout=0.1)
        assert t1 is not None
        assert t2 is None, "互斥: 不应同时获取两把锁"

    def test_lock_release(self):
        """Releasing with the owner's token deletes the lock key."""
        token = self.mgr.acquire_lock("rel_lock")
        assert self.mgr.release_lock("rel_lock", token) is True
        assert self.fake.get("lock:rel_lock") is None

    def test_lock_wrong_token_rejected(self):
        """A wrong token cannot release the lock, and the lock stays held."""
        token = self.mgr.acquire_lock("sec_lock")
        assert token is not None
        assert self.mgr.release_lock("sec_lock", "wrong-uuid") is False
        # The failed attempt must not have deleted the lock: the real owner
        # can still release it afterwards.
        assert self.mgr.release_lock("sec_lock", token) is True

    # ---- get_or_compute ----

    def test_cache_hit_skips_compute(self):
        """Cache hit -> compute_func is never called."""
        self.mgr.set_answer("cached_q", "cached_a")
        called = False

        def spy():
            nonlocal called
            called = True
            return "new"

        result = self.mgr.get_or_compute("cached_q", spy)
        assert result == "cached_a"
        assert called is False

    def test_cache_miss_calls_compute(self):
        """Cache miss -> compute_func runs and its result is cached."""
        result = self.mgr.get_or_compute("new_q", lambda: "LLM答案")
        assert result == "LLM答案"
        assert self.mgr.get_answer("new_q") == "LLM答案"

    def test_double_check_prevents_redundant_compute(self):
        """A value appearing between the first check and the lock is reused, not recomputed."""
        call_count = 0

        def patched_get(q):
            nonlocal call_count
            call_count += 1
            # First lookup misses; later lookups see another worker's write.
            if call_count == 1:
                return None
            return "其他线程写入"

        self.mgr.get_answer = patched_get

        def should_not_call():
            raise AssertionError("Double Check 成功时不应调 LLM")

        result = self.mgr.get_or_compute("dc_q", should_not_call)
        assert result == "其他线程写入"
|
|
|
|
| |
| |
| |
|
|
class TestEmbedding:
    """Tests for vector.py - OpenAIEmbeddings.

    Covers: embed_query / embed_documents, vector dimension, pass-through of
    the API's vectors (in input order), model name, API error propagation.
    """

    def _make_embedder(self, mock_client):
        """Build an embedder around ``mock_client`` without running __init__."""
        from vector import OpenAIEmbeddings
        embedder = object.__new__(OpenAIEmbeddings)
        embedder.client = mock_client
        return embedder

    def _fake_vec(self, dim=1536):
        """Random vector of length ``dim``."""
        return [random.uniform(-1, 1) for _ in range(dim)]

    def test_embed_query_dimension(self):
        """Single text -> the API's 1536-dim vector, returned unchanged."""
        expected = self._fake_vec()
        mock = MagicMock()
        mock.embeddings.create.return_value = FakeEmbeddingResponse(expected)
        emb = self._make_embedder(mock)

        vec = emb.embed_query("高血压症状")
        assert isinstance(vec, list)
        assert len(vec) == 1536
        # Length alone would also pass for a constant/zero vector.
        assert vec == expected

    def test_embed_documents_batch(self):
        """Three texts -> three vectors, one per text, in input order."""
        expected = [self._fake_vec() for _ in range(3)]
        mock = MagicMock()
        mock.embeddings.create.side_effect = [FakeEmbeddingResponse(v) for v in expected]
        emb = self._make_embedder(mock)

        vecs = emb.embed_documents(["A", "B", "C"])
        assert len(vecs) == 3
        assert all(len(v) == 1536 for v in vecs)
        assert [list(v) for v in vecs] == expected

    def test_embed_query_calls_correct_model(self):
        """The API is called with model='text-embedding-3-small'."""
        mock = MagicMock()
        mock.embeddings.create.return_value = FakeEmbeddingResponse(self._fake_vec())
        emb = self._make_embedder(mock)

        emb.embed_query("test")

        call_kwargs = mock.embeddings.create.call_args.kwargs
        assert call_kwargs.get("model") == "text-embedding-3-small"

    def test_embed_empty_text(self):
        """An empty string still yields a vector (no error)."""
        mock = MagicMock()
        mock.embeddings.create.return_value = FakeEmbeddingResponse(self._fake_vec())
        emb = self._make_embedder(mock)

        vec = emb.embed_query("")
        assert isinstance(vec, list) and len(vec) == 1536

    def test_embed_api_error_propagates(self):
        """API errors propagate to the caller."""
        mock = MagicMock()
        mock.embeddings.create.side_effect = Exception("Rate limit exceeded")
        emb = self._make_embedder(mock)

        with pytest.raises(Exception, match="Rate limit"):
            emb.embed_query("test")

    def test_embed_chinese_medical_text(self):
        """Chinese medical text embeds normally."""
        mock = MagicMock()
        mock.embeddings.create.return_value = FakeEmbeddingResponse(self._fake_vec())
        emb = self._make_embedder(mock)

        vec = emb.embed_query("宫腔异形的治疗方案有哪些?")
        assert len(vec) == 1536
|
|
|
|
| |
| |
| |
|
|
class TestMilvusRetrieval:
    """Milvus similarity_search + format_docs behaviour.

    Covers: normal recall, empty results, RRF re-ranking parameters, error
    propagation.

    NOTE(review): nothing here imports agent4.py. format_docs is re-implemented
    in ``_format_docs`` and the vector store is a bare MagicMock, so these tests
    pin the assumed contract rather than production code. TODO: call the real
    functions if agent4 can be imported in isolation.
    """

    @staticmethod
    def _format_docs(docs):
        # Joining rule assumed for agent4.format_docs: blank line between docs.
        return "\n\n".join(doc.page_content for doc in docs)

    def test_format_docs_normal(self):
        """Three docs -> two blank-line separators."""
        joined = self._format_docs([
            FakeDocument(page_content="高血压是常见疾病"),
            FakeDocument(page_content="建议低盐饮食"),
            FakeDocument(page_content="定期测量血压"),
        ])
        assert joined.count("\n\n") == 2
        assert "低盐饮食" in joined

    def test_format_docs_empty(self):
        """No docs -> empty string."""
        assert self._format_docs([]) == ""

    def test_format_docs_single(self):
        """One doc -> its text, no separator."""
        assert self._format_docs([FakeDocument(page_content="唯一")]) == "唯一"

    def test_similarity_search_returns_topk(self):
        """A k=10 search yields ten hits."""
        store = MagicMock()
        store.similarity_search.return_value = [
            FakeDocument(page_content=f"doc_{n}") for n in range(10)
        ]
        hits = store.similarity_search(
            "query", k=10, ranker_type="rrf", ranker_params={"k": 100}
        )
        assert len(hits) == 10

    def test_similarity_search_rrf_params_passed(self):
        """RRF ranker settings are passed as keyword arguments."""
        store = MagicMock()
        store.similarity_search.return_value = []

        store.similarity_search("q", k=10, ranker_type="rrf", ranker_params={"k": 100})

        kwargs = store.similarity_search.call_args.kwargs
        assert kwargs["ranker_type"] == "rrf"
        assert kwargs["ranker_params"] == {"k": 100}

    def test_similarity_search_empty(self):
        """No hits -> empty context."""
        store = MagicMock()
        store.similarity_search.return_value = []

        hits = store.similarity_search("xyz无关查询")
        context = self._format_docs(hits) if hits else ""
        assert context == ""

    def test_similarity_search_exception(self):
        """A Milvus failure surfaces as an exception (degradation is the agent's job)."""
        store = MagicMock()
        store.similarity_search.side_effect = ConnectionError("Milvus timeout")

        with pytest.raises(ConnectionError):
            store.similarity_search("test")
|
|
|
|
| |
| |
| |
|
|
class TestPDFRetrieval:
    """parent_retriever.invoke() result handling.

    Covers: normal recall, empty list, None, long documents, top-hit selection.

    NOTE(review): the retriever is a bare MagicMock; these tests check the
    result-handling rule, not the retriever itself.
    """

    @staticmethod
    def _first_content(results):
        # agent4.py uses only the top hit; None and [] both mean "no PDF context".
        return results[0].page_content if results else ""

    def test_retriever_returns_document(self):
        """Normal retrieval returns at least one document."""
        retriever = MagicMock()
        retriever.invoke.return_value = [
            FakeDocument(page_content="根据《高血压防治指南》第三章...")
        ]
        found = retriever.invoke("高血压分级标准")
        assert len(found) >= 1
        assert "高血压" in found[0].page_content

    def test_retriever_empty_list(self):
        """No match -> empty pdf_res."""
        retriever = MagicMock()
        retriever.invoke.return_value = []

        assert self._first_content(retriever.invoke("xyz")) == ""

    def test_retriever_none_safe(self):
        """A None result does not raise and yields empty pdf_res."""
        retriever = MagicMock()
        retriever.invoke.return_value = None

        assert self._first_content(retriever.invoke("test")) == ""

    def test_retriever_long_document(self):
        """Long documents come back in full."""
        long_text = "医学文献内容。" * 500
        retriever = MagicMock()
        retriever.invoke.return_value = [FakeDocument(page_content=long_text)]

        found = retriever.invoke("长文档")
        assert len(found[0].page_content) == len(long_text)

    def test_retriever_multiple_results_takes_first(self):
        """With several hits, only results[0] is used."""
        retriever = MagicMock()
        retriever.invoke.return_value = [
            FakeDocument(page_content="最相关"),
            FakeDocument(page_content="第二篇"),
        ]
        assert self._first_content(retriever.invoke("test")) == "最相关"
|
|
|
|
| |
| |
| |
|
|
class TestNeo4jCypherPipeline:
    """Neo4j three-stage flow in agent4.py.

    Stage 1: POST /generate -> cypher_query + confidence + validated
    Stage 2: POST /validate -> is_valid
    Stage 3: session.run(cypher) -> first column of each row, comma-joined

    NOTE(review): the execution gate and the degradation paths are restated
    here rather than imported from agent4.py. Keep ``_should_execute`` in
    sync with the real condition (TODO confirm against agent4.py).
    """

    @staticmethod
    def _should_execute(payload):
        """Return True when a /generate payload may be run against Neo4j.

        Requires a non-None query, confidence >= 0.9 (``float`` also accepts
        numeric strings) and a truthy ``validated`` flag.
        """
        return (
            payload["cypher_query"] is not None
            and float(payload["confidence"]) >= 0.9
            and bool(payload["validated"])
        )

    # ---- Stage 1: execution gate ----

    def test_high_confidence_valid_executes(self):
        """0.95 + validated=True -> execute."""
        d = {"cypher_query": "MATCH (d:Disease) RETURN d", "confidence": 0.95, "validated": True}
        assert self._should_execute(d) is True

    def test_low_confidence_skips(self):
        """0.5 < 0.9 -> skip."""
        d = {"cypher_query": "MATCH", "confidence": 0.5, "validated": True}
        assert self._should_execute(d) is False

    def test_invalid_skips(self):
        """validated=False -> skip."""
        d = {"cypher_query": "BAD", "confidence": 0.99, "validated": False}
        assert self._should_execute(d) is False

    def test_null_cypher_skips(self):
        """cypher_query=None -> skip, even with high confidence."""
        d = {"cypher_query": None, "confidence": 0.95, "validated": True}
        assert self._should_execute(d) is False

    def test_boundary_089_skips(self):
        """Boundary 0.89 -> skip."""
        d = {"cypher_query": "MATCH (n) RETURN n", "confidence": 0.89, "validated": True}
        assert self._should_execute(d) is False

    def test_boundary_090_executes(self):
        """Boundary 0.90 -> execute (threshold is inclusive)."""
        d = {"cypher_query": "MATCH (n) RETURN n", "confidence": 0.90, "validated": True}
        assert self._should_execute(d) is True

    # ---- Stage 2: /validate response ----

    def test_validate_pass(self):
        resp = MagicMock()
        resp.json.return_value = {"is_valid": True}
        assert resp.json()["is_valid"] is True

    def test_validate_fail(self):
        resp = MagicMock()
        resp.json.return_value = {"is_valid": False}
        assert resp.json()["is_valid"] is False

    # ---- Stage 3: execution and degradation ----

    def test_neo4j_run_success(self):
        """Rows -> first column, comma-joined."""
        mock_session = MagicMock()
        mock_session.run.return_value = [("高血压",), ("糖尿病",)]
        values = [record[0] for record in mock_session.run("MATCH ...")]
        assert ','.join(values) == "高血压,糖尿病"

    def test_neo4j_run_empty(self):
        """No rows -> empty string."""
        mock_session = MagicMock()
        mock_session.run.return_value = []
        values = [record[0] for record in mock_session.run("MATCH ...")]
        assert ','.join(values) == ""

    def test_neo4j_run_exception_graceful(self):
        """A query failure degrades to an empty result via the except branch."""
        mock_session = MagicMock()
        mock_session.run.side_effect = Exception("Connection lost")
        degraded = False
        neo4j_res = None  # sentinel: must be replaced by the degradation path
        try:
            neo4j_res = ','.join(record[0] for record in mock_session.run("BAD"))
        except Exception:
            neo4j_res = ""
            degraded = True
        assert degraded, "session.run failure should take the degradation path"
        assert neo4j_res == ""

    def test_cypher_service_down(self):
        """Cypher API unreachable -> the error is caught and the result degrades to empty."""
        requests = pytest.importorskip("requests")
        degraded = False
        neo4j_res = None  # sentinel: must be replaced by the degradation path
        with patch.object(requests, "post", side_effect=ConnectionError("refused")) as mock_post:
            try:
                requests.post("http://0.0.0.0:8101/generate", "{}")
            except Exception:
                neo4j_res = ""
                degraded = True
        assert degraded, "a refused connection should take the degradation path"
        assert neo4j_res == ""
        mock_post.assert_called_once()
|
|
|
|
| |
| |
| |
|
|
class TestLLMInference:
    """agent4.py - generate_openai_answer.

    Covers: normal generation, prompt layout, empty reply, errors.

    NOTE(review): the OpenAI client is a bare MagicMock and the prompt is
    assembled inline, so generate_openai_answer itself is not invoked here.
    """

    def test_generate_success(self):
        """A normal reply is returned as message content."""
        client = MagicMock()
        client.chat.completions.create.return_value = FakeChatResponse(
            "高血压患者应避免高盐饮食, 每日钠 <6g."
        )
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "高血压饮食"}],
            temperature=0.7,
        )
        answer = response.choices[0].message.content
        assert "高血压" in answer
        assert len(answer) > 10

    def test_prompt_structure(self):
        """The prompt contains the system role, <context> and <question>."""
        query = "高血压不能吃什么?"
        context = "低盐饮食"
        system_part = "System: 你是一个非常得力的医学助手."
        user_part = f"<context>\n{context}\n</context>\n<question>\n{query}\n</question>"
        prompt = system_part + user_part
        for fragment in ("医学助手", query, context):
            assert fragment in prompt

    def test_prompt_empty_context(self):
        """An empty context still yields a complete prompt."""
        prompt = "<context>\n\n</context>\n<question>\n糖尿病?\n</question>"
        assert "<context>" in prompt
        assert "糖尿病" in prompt

    def test_llm_timeout(self):
        """A timeout propagates to the caller."""
        client = MagicMock()
        client.chat.completions.create.side_effect = TimeoutError("timeout")
        with pytest.raises(TimeoutError):
            client.chat.completions.create(model="gpt-4o-mini", messages=[])

    def test_llm_empty_response(self):
        """An empty reply comes back as an empty string."""
        client = MagicMock()
        client.chat.completions.create.return_value = FakeChatResponse("")
        reply = client.chat.completions.create(model="m", messages=[])
        assert reply.choices[0].message.content == ""

    def test_generate_with_temperature(self):
        """temperature=0.7 is passed through as a keyword argument."""
        client = MagicMock()
        client.chat.completions.create.return_value = FakeChatResponse("ok")

        client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "test"}],
            temperature=0.7,
        )
        sent = client.chat.completions.create.call_args.kwargs
        assert sent["temperature"] == 0.7
|
|
|
|
| |
| |
| |
|
|
class TestPDFProcessor:
    """Tests for preprocess.py - PDFBatchProcessor (file discovery and extraction).

    Every test works under pytest's ``tmp_path``. Nothing touches fixed
    locations such as /tmp or the current working directory, which keeps the
    tests portable and free of cross-run state.
    """

    @staticmethod
    def _make_processor(tmp_path):
        """Create a processor whose output directory lives under ``tmp_path``."""
        from preprocess import PDFBatchProcessor
        return PDFBatchProcessor(output_dir=str(tmp_path / "out"))

    def test_invalid_path_raises(self, tmp_path):
        """A non-existent path is rejected with ValueError."""
        proc = self._make_processor(tmp_path)
        missing = tmp_path / "nonexistent" / "xyz.txt"
        with pytest.raises(ValueError, match="路径不存在"):
            proc.find_pdf_files(str(missing))

    def test_empty_dir(self, tmp_path):
        """A directory with no PDFs yields an empty list."""
        proc = self._make_processor(tmp_path)
        assert proc.find_pdf_files(str(tmp_path)) == []

    def test_finds_pdf_only(self, tmp_path):
        """Only .pdf files are discovered."""
        for name in ("a.pdf", "b.pdf", "c.txt"):
            (tmp_path / name).touch()
        proc = self._make_processor(tmp_path)
        files = proc.find_pdf_files(str(tmp_path))
        assert len(files) == 2
        assert sorted(Path(f).name for f in files) == ["a.pdf", "b.pdf"]

    def test_single_pdf_file(self, tmp_path):
        """A path to one PDF file yields just that file."""
        pdf = tmp_path / "x.pdf"
        pdf.touch()
        proc = self._make_processor(tmp_path)
        assert proc.find_pdf_files(str(pdf)) == [pdf]

    def test_extract_nonexistent_has_error(self, tmp_path):
        """Extracting a missing file reports an error instead of raising."""
        proc = self._make_processor(tmp_path)
        r = proc.extract_pdf_content(tmp_path / "missing.pdf")
        assert r["error"] is not None

    def test_result_has_required_keys(self, tmp_path):
        """The result dict always exposes the documented keys."""
        proc = self._make_processor(tmp_path)
        r = proc.extract_pdf_content(tmp_path / "dummy.pdf")
        for k in ["file_name", "file_path", "metadata", "pages", "error"]:
            assert k in r
|
|
|
|
| |
| |
| |
|
|
class TestDataPreprocessing:
    """JSONL parsing behaviour assumed by the data-preparation step.

    Files are read inside ``with`` blocks, so no handle is leaked, and use
    explicit UTF-8 so the tests do not depend on the platform's default
    encoding.

    NOTE(review): parsing is re-implemented inline rather than calling
    vector.prepare_document — TODO exercise the real function if it can be
    imported in isolation.
    """

    def test_jsonl_parse(self, tmp_path):
        """Each JSONL record becomes 'query\\nresponse'."""
        f = tmp_path / "t.jsonl"
        records = [{"query": "Q1", "response": "A1"}, {"query": "Q2", "response": "A2"}]
        f.write_text(
            "".join(json.dumps(rec, ensure_ascii=False) + "\n" for rec in records),
            encoding="utf-8",
        )
        docs = []
        with open(f, encoding="utf-8") as fh:
            for line in fh:
                c = json.loads(line.strip())
                docs.append(c["query"] + "\n" + c["response"])
        assert len(docs) == 2 and "Q1" in docs[0]

    def test_jsonl_empty(self, tmp_path):
        """An empty file contains no non-blank lines."""
        f = tmp_path / "e.jsonl"
        f.write_text("", encoding="utf-8")
        with open(f, encoding="utf-8") as fh:
            non_blank = sum(1 for line in fh if line.strip())
        assert non_blank == 0

    def test_jsonl_bad_line(self, tmp_path):
        """A malformed line raises JSONDecodeError without affecting good lines."""
        f = tmp_path / "b.jsonl"
        f.write_text('{"query":"ok","response":"r"}\n{bad}\n', encoding="utf-8")
        ok, bad = 0, 0
        with open(f, encoding="utf-8") as fh:
            for line in fh:
                try:
                    json.loads(line)
                except json.JSONDecodeError:
                    bad += 1
                else:
                    ok += 1
        assert ok == 1 and bad == 1
|
|
|
|
| |
| |
| |
|
|
class TestRAGOrchestration:
    """perform_rag_and_llm orchestration: merging the three retrieval results,
    request parsing, and the Redis cache in front of the RAG pipeline.

    NOTE(review): apart from the cache test, these build strings/dicts inline
    rather than calling agent4.py code.
    """

    @staticmethod
    def _merge(milvus_res, pdf_res, neo4j_res):
        # Newline-joined Milvus / PDF / Neo4j results, as the tests assume.
        return "\n".join((milvus_res, pdf_res, neo4j_res))

    def test_three_way_merge(self):
        merged = self._merge("M结果", "P结果", "N结果")
        for part in ("M结果", "P结果", "N结果"):
            assert part in merged

    def test_partial_empty_merge(self):
        assert "有结果" in self._merge("有结果", "", "")

    def test_all_empty_merge(self):
        assert self._merge("", "", "").strip() == ""

    def test_request_valid(self):
        payload = {"question": "Q"}
        assert payload.get("question") == "Q"

    def test_request_missing_question(self):
        payload = {"query": "x"}
        assert payload.get("question") is None

    def test_redis_caching_in_chatbot(self):
        """The chatbot wraps RAG in get_or_compute: a cache hit skips RAG."""
        manager = make_redis_manager()
        manager.set_answer("Q", "缓存A")
        rag_ran = False

        def rag():
            nonlocal rag_ran
            rag_ran = True
            return "new"

        assert manager.get_or_compute("Q", rag) == "缓存A"
        assert rag_ran is False
|
|
|
|
| |
if __name__ == "__main__":
    # Propagate pytest's exit status so scripted/CI runs fail when tests fail.
    sys.exit(pytest.main([__file__, "-v", "--tb=short"]))