# agentV2 / test/test1.py
# Committed by drewli20200316
# Commit: "Add test/test1.py" (dae033a, verified)
"""
================================================================
医疗 RAG Agent 单元测试 — 单工具调用准确性
================================================================
测试对象: 项目中的 6 个核心组件 ("工具")
工具1: Redis 缓存管理器 (new_redis.py - RedisClientWrapper)
工具2: OpenAI Embedding (vector.py - OpenAIEmbeddings)
工具3: Milvus 向量检索 (agent4.py - similarity_search + format_docs)
工具4: PDF 父子文档检索 (agent4.py - parent_retriever)
工具5: Neo4j 图数据库查询 (agent4.py - Cypher 生成 → 校验 → 执行)
工具6: OpenAI LLM 推理 (agent4.py - generate_openai_answer)
额外覆盖:
工具7: PDF 批处理器 (preprocess.py - PDFBatchProcessor)
工具8: 数据预处理 (vector.py - prepare_document)
工具9: 端到端 RAG 流程编排 (agent4.py - perform_rag_and_llm 逻辑)
测试原则:
✅ 每个组件独立测试, 用 Mock/Patch 隔离外部依赖
✅ 正常路径 + 异常路径 + 边界条件
✅ 不需要真实的 Redis / Milvus / Neo4j / OpenAI 连接
✅ 用 sys.modules 拦截无法安装的第三方包
运行 (在项目根目录下执行; 文件名 test1.py 不匹配 pytest 默认的 test_*.py 收集规则, 需显式指定路径):
pytest test/test1.py -v --tb=short
pytest test/test1.py -v -k "Redis" # 只跑 Redis
pytest test/test1.py -v -k "Embedding" # 只跑 Embedding
pytest test/test1.py -v -k "Neo4j" # 只跑 Neo4j
================================================================
"""
import sys
import os
# 关键: 将项目根目录 (test/ 的上级) 加入 Python 搜索路径
# 这样 test/ 子目录中的测试文件才能找到 new_redis.py, vector.py, preprocess.py 等模块
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
import types
import pytest
import json
import hashlib
import time
import uuid
import random
from unittest.mock import MagicMock, patch, PropertyMock
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, List
# ================================================================
# 前置: 用 sys.modules 拦截无法安装的第三方依赖
# 这样 `from vector import X` 不会因缺少 langchain_classic 而崩溃
# ================================================================
def _ensure_mock_module(name):
    """Register a MagicMock as ``sys.modules[name]`` unless that key is already present.

    Only ``sys.modules`` is consulted, so a package that is installed but has
    not been imported yet is still replaced by a mock. An entry that is
    already registered is left untouched.
    """
    already_registered = name in sys.modules
    if already_registered:
        return
    sys.modules[name] = MagicMock()
# Third-party modules that are replaced by MagicMock placeholders before the
# project modules are imported. Presumably these are imported at load time by
# vector.py / preprocess.py / new_redis.py / agent4.py (TODO confirm); mocking
# them lets those imports succeed without the real packages or services.
_MOCK_MODULES = [
    # LangChain: classic retrievers, Milvus integration, splitters, core types
    "langchain_classic",
    "langchain_classic.retrievers",
    "langchain_classic.retrievers.parent_document_retriever",
    "langchain_milvus",
    "langchain_text_splitters",
    "langchain_core",
    "langchain_core.stores",
    "langchain_core.documents",
    "langchain.embeddings",
    "langchain.embeddings.base",
    # Graph database driver
    "neo4j",
    # Environment loading and web-server stack
    "dotenv",
    "uvicorn",
    "fastapi",
    "fastapi.middleware",
    "fastapi.middleware.cors",
]
for mod in _MOCK_MODULES:
    _ensure_mock_module(mod)
# Key fix: the LangChain ``Embeddings`` base must be a real class, otherwise
# inheriting from it fails (a MagicMock attribute is not a usable base class).
class _FakeEmbeddingsBase:
    """Placeholder base class so that ``OpenAIEmbeddings`` can subclass it."""


# Only patch the placeholder injected by _ensure_mock_module. If the real
# LangChain module was already imported, overwriting its ``Embeddings`` would
# leak a fake base class into every later importer. If the entry is missing
# entirely, there is nothing to patch and no KeyError is raised.
if isinstance(sys.modules.get("langchain.embeddings.base"), MagicMock):
    sys.modules["langchain.embeddings.base"].Embeddings = _FakeEmbeddingsBase
# ================================================================
# 测试辅助: 模拟对象和数据工厂
# ================================================================
class FakeRedisClient:
    """In-memory dict stand-in for a Redis client, covering the calls the unit tests make.

    ``_store`` maps keys to values (hash values are nested dicts).
    ``_expiry`` records the last TTL in seconds passed for each key. TTLs are
    only recorded, never enforced: keys never disappear on their own.
    """

    def __init__(self):
        self._store = {}
        self._expiry = {}

    def ping(self):
        """Always report the fake server as reachable."""
        return True

    def get(self, key):
        """Return the stored value for *key*, or ``None`` when it is absent."""
        return self._store.get(key, None)

    def set(self, key, value, ex=None, nx=False):
        """Emulate ``SET key value [EX ex] [NX]``.

        Returns ``False`` without writing when *nx* is set and *key* exists;
        otherwise stores *value* and returns ``True``.
        """
        if nx and key in self._store:
            return False
        self._store[key] = value
        if ex:
            self._expiry[key] = ex
        else:
            # A SET without EX discards any previous TTL in real Redis, so a
            # stale expiry must not survive an overwrite.
            self._expiry.pop(key, None)
        return True

    def setex(self, key, expire, value):
        """Emulate ``SETEX key seconds value``: store *value* and record its TTL."""
        self._store[key] = value
        self._expiry[key] = expire
        return True

    def delete(self, key, *more_keys):
        """Emulate ``DEL key [key ...]`` and return how many keys were removed.

        The recorded TTL of every named key is dropped as well.
        """
        removed = 0
        for name in (key, *more_keys):
            if self._store.pop(name, None) is not None:
                removed += 1
            self._expiry.pop(name, None)
        return removed

    def hset(self, name, key, value):
        """Set field *key* of hash *name* to *value*."""
        self._store.setdefault(name, {})[key] = value

    def hget(self, name, key):
        """Return field *key* of hash *name*, or ``None`` when absent."""
        return self._store.get(name, {}).get(key, None)

    def expire(self, key, seconds):
        """Record a TTL of *seconds* for *key* (not enforced)."""
        self._expiry[key] = seconds

    def register_script(self, script):
        """Emulate ``register_script`` for the lock-release Lua script.

        The *script* source is ignored. The returned callable performs an
        atomic compare-and-delete: it deletes ``keys[0]`` only when the key
        exists and its value equals ``args[0]``, returning 1 on delete and 0
        otherwise.
        """
        def fake_script(keys=None, args=None):
            if keys and args and keys[0] in self._store and self._store[keys[0]] == args[0]:
                del self._store[keys[0]]
                self._expiry.pop(keys[0], None)
                return 1
            return 0

        return fake_script
@dataclass
class FakeDocument:
    """Minimal stand-in for a LangChain ``Document``: text plus a metadata dict."""

    # Document body text.
    page_content: str
    # Each instance gets its own fresh dict, so metadata is never shared.
    metadata: dict = field(default_factory=dict)
class FakeEmbeddingResponse:
    """Mimic the shape of an OpenAI embeddings response.

    ``data`` is a one-element list whose item exposes the given vector as
    its ``embedding`` attribute.
    """

    def __init__(self, embedding):
        item_cls = type('EmbObj', (), {'embedding': embedding})
        self.data = [item_cls()]
class FakeChatResponse:
    """Mimic the shape of an OpenAI chat-completion response.

    ``choices[0].message.content`` yields the given text.
    """

    def __init__(self, content):
        message_cls = type('Msg', (), {'content': content})
        choice_cls = type('Choice', (), {'message': message_cls()})
        self.choices = [choice_cls()]
# ================================================================
# 辅助: 创建被测 RedisClientWrapper 实例 (注入假 Redis)
# ================================================================
def make_redis_manager():
    """Build a ``RedisClientWrapper`` backed by an in-memory ``FakeRedisClient``.

    ``__init__`` is bypassed via ``object.__new__`` so no real connection is
    attempted. The instance receives the fake as ``client`` and the fake's
    compare-and-delete callable as ``unlock_script``.

    Note: ``RedisClientWrapper._pool`` is overwritten with ``"FAKE"`` at class
    level (to skip connection-pool creation) and is not restored afterwards.
    """
    from new_redis import RedisClientWrapper

    RedisClientWrapper._pool = "FAKE"
    wrapper = object.__new__(RedisClientWrapper)
    wrapper.client = FakeRedisClient()
    wrapper.unlock_script = wrapper.client.register_script("")
    return wrapper
# ================================================================
# 工具 1: Redis 缓存管理器
# ================================================================
class TestRedisManager:
    """
    Tests for new_redis.py - RedisClientWrapper.

    Covered areas:
    - key generation and basic read/write;
    - the cache-penetration guard (``<EMPTY>`` marker);
    - the avalanche guard (TTL jitter);
    - the breakdown guard (distributed lock);
    - the full ``get_or_compute`` flow.

    Each test runs against a fresh wrapper backed by an in-memory
    FakeRedisClient (see make_redis_manager).
    """

    def setup_method(self):
        self.mgr = make_redis_manager()
        self.fake = self.mgr.client  # direct handle on the underlying fake Redis

    # ---- 1.1 Key generation ----
    def test_key_deterministic(self):
        """Same question -> same key."""
        k1 = self.mgr._generate_key("高血压不能吃什么?")
        k2 = self.mgr._generate_key("高血压不能吃什么?")
        assert k1 == k2

    def test_key_unique(self):
        """Different questions -> different keys."""
        k1 = self.mgr._generate_key("高血压不能吃什么?")
        k2 = self.mgr._generate_key("糖尿病怎么治疗?")
        assert k1 != k2

    def test_key_has_prefix(self):
        """Keys carry the 'llm:cache:' prefix."""
        k = self.mgr._generate_key("test")
        assert k.startswith("llm:cache:")

    def test_key_is_md5(self):
        """The key suffix is the MD5 hex digest of the UTF-8 encoded question."""
        q = "测试问题"
        k = self.mgr._generate_key(q)
        expected_hash = hashlib.md5(q.encode('utf-8')).hexdigest()
        assert k == f"llm:cache:{expected_hash}"

    # ---- 1.2 Basic read/write ----
    def test_set_then_get(self):
        """A written answer reads back unchanged."""
        self.mgr.set_answer("Q1", "A1")
        assert self.mgr.get_answer("Q1") == "A1"

    def test_cache_miss_returns_none(self):
        """A question that was never written returns None."""
        assert self.mgr.get_answer("不存在的问题") is None

    # ---- 1.3 Cache-penetration guard ----
    def test_empty_marker_returns_none(self):
        """An '<EMPTY>' placeholder makes get_answer return None (no fall-through to the LLM)."""
        key = self.mgr._generate_key("空结果问题")
        self.fake.setex(key, 60, "<EMPTY>")
        assert self.mgr.get_answer("空结果问题") is None

    def test_get_or_compute_writes_empty_on_null(self):
        """An empty LLM result is cached as '<EMPTY>' to block penetration."""
        self.mgr.get_or_compute("空问题", lambda: "")
        key = self.mgr._generate_key("空问题")
        assert self.fake.get(key) == "<EMPTY>"

    # ---- 1.4 Avalanche guard ----
    def test_random_expiry_jitter(self):
        """Repeated writes with the same expire_time record differing TTLs."""
        ttls = set()
        for i in range(30):
            self.mgr.set_answer(f"Q_{i}", f"A_{i}", expire_time=3600)
            k = self.mgr._generate_key(f"Q_{i}")
            ttls.add(self.fake._expiry.get(k))
        # Every write must have recorded a TTL; a missing one shows up as None.
        assert None not in ttls
        assert len(ttls) > 1, "过期时间应存在随机抖动, 防止集体失效"

    # ---- 1.5 Distributed lock (breakdown guard) ----
    def test_lock_acquire_success(self):
        """A free lock can be acquired."""
        token = self.mgr.acquire_lock("my_lock", acquire_timeout=1)
        assert token is not None

    def test_lock_mutual_exclusion(self):
        """While the lock is held, a second acquire times out."""
        t1 = self.mgr.acquire_lock("excl", acquire_timeout=0.1)
        t2 = self.mgr.acquire_lock("excl", acquire_timeout=0.1)
        assert t1 is not None
        assert t2 is None, "互斥: 不应同时获取两把锁"

    def test_lock_release(self):
        """Releasing with the right token frees the lock."""
        token = self.mgr.acquire_lock("rel_lock")
        assert self.mgr.release_lock("rel_lock", token) is True
        assert self.fake.get("lock:rel_lock") is None
        # Verify through the public API that the lock really is free again;
        # the raw-key check above would also pass if the key name differed.
        assert self.mgr.acquire_lock("rel_lock", acquire_timeout=0.1) is not None

    def test_lock_wrong_token_rejected(self):
        """Releasing with a wrong token fails and the lock stays held."""
        self.mgr.acquire_lock("sec_lock")
        assert self.mgr.release_lock("sec_lock", "wrong-uuid") is False
        # The rejected release must not have freed the lock.
        assert self.mgr.acquire_lock("sec_lock", acquire_timeout=0.1) is None

    # ---- 1.6 Full get_or_compute flow ----
    def test_cache_hit_skips_compute(self):
        """A cache hit does not call compute_func."""
        self.mgr.set_answer("cached_q", "cached_a")
        called = False

        def spy():
            nonlocal called
            called = True
            return "new"

        result = self.mgr.get_or_compute("cached_q", spy)
        assert result == "cached_a"
        assert called is False

    def test_cache_miss_calls_compute(self):
        """A cache miss calls compute_func and caches its result."""
        result = self.mgr.get_or_compute("new_q", lambda: "LLM答案")
        assert result == "LLM答案"
        assert self.mgr.get_answer("new_q") == "LLM答案"

    def test_double_check_prevents_redundant_compute(self):
        """Re-checking the cache after taking the lock avoids a redundant LLM call."""
        call_count = 0

        def patched_get(q):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return None  # first read: miss
            return "其他线程写入"  # second read (double check): filled by another worker

        self.mgr.get_answer = patched_get

        def should_not_call():
            raise AssertionError("Double Check 成功时不应调 LLM")

        result = self.mgr.get_or_compute("dc_q", should_not_call)
        assert result == "其他线程写入"
# ================================================================
# 工具 2: OpenAI Embedding 模型
# ================================================================
class TestEmbedding:
    """
    Tests for vector.py - OpenAIEmbeddings.

    Covers embed_query / embed_documents, the returned vector length, the
    model name passed to the API and error propagation. The OpenAI client is
    a MagicMock, and ``__init__`` is bypassed so no real client is built.
    """

    def _make_embedder(self, mock_client):
        """Return an OpenAIEmbeddings whose ``client`` is *mock_client* (``__init__`` skipped)."""
        from vector import OpenAIEmbeddings

        instance = object.__new__(OpenAIEmbeddings)
        instance.client = mock_client
        return instance

    def _fake_vec(self, dim=1536):
        """Random list of *dim* floats drawn from [-1, 1]."""
        return [random.uniform(-1, 1) for _ in range(dim)]

    def test_embed_query_dimension(self):
        """A single query embeds to a 1536-element list."""
        client = MagicMock()
        client.embeddings.create.return_value = FakeEmbeddingResponse(self._fake_vec())
        vector = self._make_embedder(client).embed_query("高血压症状")
        assert isinstance(vector, list)
        assert len(vector) == 1536

    def test_embed_documents_batch(self):
        """Three texts embed to three vectors."""
        client = MagicMock()
        client.embeddings.create.side_effect = [
            FakeEmbeddingResponse(self._fake_vec()) for _ in range(3)
        ]
        vectors = self._make_embedder(client).embed_documents(["A", "B", "C"])
        assert len(vectors) == 3
        assert all(len(v) == 1536 for v in vectors)

    def test_embed_query_calls_correct_model(self):
        """create() is called with model='text-embedding-3-small'."""
        client = MagicMock()
        client.embeddings.create.return_value = FakeEmbeddingResponse(self._fake_vec())
        self._make_embedder(client).embed_query("test")
        used_kwargs = client.embeddings.create.call_args.kwargs
        assert used_kwargs.get("model") == "text-embedding-3-small"

    def test_embed_empty_text(self):
        """An empty string still yields a vector instead of an error."""
        client = MagicMock()
        client.embeddings.create.return_value = FakeEmbeddingResponse(self._fake_vec())
        vector = self._make_embedder(client).embed_query("")
        assert isinstance(vector, list) and len(vector) == 1536

    def test_embed_api_error_propagates(self):
        """An API error surfaces to the caller unchanged."""
        client = MagicMock()
        client.embeddings.create.side_effect = Exception("Rate limit exceeded")
        embedder = self._make_embedder(client)
        with pytest.raises(Exception, match="Rate limit"):
            embedder.embed_query("test")

    def test_embed_chinese_medical_text(self):
        """Chinese medical text embeds like any other input."""
        client = MagicMock()
        client.embeddings.create.return_value = FakeEmbeddingResponse(self._fake_vec())
        vector = self._make_embedder(client).embed_query("宫腔异形的治疗方案有哪些?")
        assert len(vector) == 1536
# ================================================================
# 工具 3: Milvus 向量检索
# ================================================================
class TestMilvusRetrieval:
    """
    Tests around Milvus ``similarity_search`` and the ``format_docs`` join.

    Covers normal recall, empty results, the RRF re-ranking arguments and an
    exception from the store.
    NOTE(review): agent4.py is never imported here. The vector store is a
    MagicMock and the double-newline join is re-implemented locally, so these
    tests pin the expected contract rather than the project's own code.
    """

    @staticmethod
    def _format_docs(docs):
        """Join each document's page_content with a blank line between them."""
        return "\n\n".join(doc.page_content for doc in docs)

    def test_format_docs_normal(self):
        """Three documents are joined by two double-newline separators."""
        texts = ("高血压是常见疾病", "建议低盐饮食", "定期测量血压")
        joined = self._format_docs([FakeDocument(page_content=t) for t in texts])
        assert joined.count("\n\n") == 2
        assert "低盐饮食" in joined

    def test_format_docs_empty(self):
        """No documents -> empty string."""
        assert self._format_docs([]) == ""

    def test_format_docs_single(self):
        """A single document has no separator."""
        assert self._format_docs([FakeDocument(page_content="唯一")]) == "唯一"

    def test_similarity_search_returns_topk(self):
        """The store returns the k=10 top hits."""
        store = MagicMock()
        store.similarity_search.return_value = [
            FakeDocument(page_content=f"doc_{i}") for i in range(10)
        ]
        hits = store.similarity_search("query", k=10, ranker_type="rrf", ranker_params={"k": 100})
        assert len(hits) == 10

    def test_similarity_search_rrf_params_passed(self):
        """The RRF re-ranking arguments reach similarity_search."""
        store = MagicMock()
        store.similarity_search.return_value = []
        store.similarity_search("q", k=10, ranker_type="rrf", ranker_params={"k": 100})
        passed = store.similarity_search.call_args.kwargs
        assert passed["ranker_type"] == "rrf"
        assert passed["ranker_params"] == {"k": 100}

    def test_similarity_search_empty(self):
        """No matches -> empty context."""
        store = MagicMock()
        store.similarity_search.return_value = []
        hits = store.similarity_search("xyz无关查询")
        context = self._format_docs(hits) if hits else ""
        assert context == ""

    def test_similarity_search_exception(self):
        """A store failure is raised; the agent layer decides how to degrade."""
        store = MagicMock()
        store.similarity_search.side_effect = ConnectionError("Milvus timeout")
        with pytest.raises(ConnectionError):
            store.similarity_search("test")
# ================================================================
# 工具 4: PDF 父子文档检索
# ================================================================
class TestPDFRetrieval:
    """
    Tests around ``parent_retriever.invoke()`` for PDF parent/child retrieval.

    Covers normal recall, an empty list, a ``None`` result and long documents.
    The retriever is a MagicMock. The "take the first hit or empty string"
    rule (per the original note, agent4.py only uses results[0]) is mirrored
    in ``_first_page_content``.
    """

    @staticmethod
    def _first_page_content(results):
        """Return the first document's text, or "" when *results* is None or empty."""
        if not results:
            return ""
        return results[0].page_content

    def test_retriever_returns_document(self):
        """A normal query returns at least one document."""
        retriever = MagicMock()
        retriever.invoke.return_value = [
            FakeDocument(page_content="根据《高血压防治指南》第三章...")
        ]
        hits = retriever.invoke("高血压分级标准")
        assert len(hits) >= 1
        assert "高血压" in hits[0].page_content

    def test_retriever_empty_list(self):
        """No match -> empty pdf_res."""
        retriever = MagicMock()
        retriever.invoke.return_value = []
        assert self._first_page_content(retriever.invoke("xyz")) == ""

    def test_retriever_none_safe(self):
        """A None result does not raise and yields an empty pdf_res."""
        retriever = MagicMock()
        retriever.invoke.return_value = None
        assert self._first_page_content(retriever.invoke("test")) == ""

    def test_retriever_long_document(self):
        """A long document comes back in full."""
        text = "医学文献内容。" * 500
        retriever = MagicMock()
        retriever.invoke.return_value = [FakeDocument(page_content=text)]
        hits = retriever.invoke("长文档")
        assert len(hits[0].page_content) == len(text)

    def test_retriever_multiple_results_takes_first(self):
        """Only results[0] is used when several documents come back."""
        retriever = MagicMock()
        retriever.invoke.return_value = [
            FakeDocument(page_content="最相关"),
            FakeDocument(page_content="第二篇"),
        ]
        assert self._first_page_content(retriever.invoke("test")) == "最相关"
# ================================================================
# 工具 5: Neo4j 图数据库查询 (Cypher 生成 → 校验 → 执行)
# ================================================================
class TestNeo4jCypherPipeline:
    """
    Tests for the three-stage Neo4j flow in agent4.py:
        Stage 1: POST /generate -> cypher_query + confidence + validated
        Stage 2: POST /validate -> is_valid
        Stage 3: session.run(cypher) -> result extraction

    NOTE(review): agent4.py is not imported; the execute decision is mirrored
    in ``_should_execute`` so that every decision test checks the full rule.
    """

    # Minimum confidence at which a generated Cypher query is executed.
    CONFIDENCE_THRESHOLD = 0.9

    @classmethod
    def _should_execute(cls, generated):
        """Return True when the generated query should run.

        Rule (as encoded by the high-confidence test; confirm against
        agent4.py): a query is present, ``float(confidence)`` is at least the
        threshold, and ``validated`` is truthy. Always returns a real bool.
        """
        return bool(
            generated["cypher_query"] is not None
            and float(generated["confidence"]) >= cls.CONFIDENCE_THRESHOLD
            and generated["validated"]
        )

    # ---- Stage 1: execute decision ----
    def test_high_confidence_valid_executes(self):
        """0.95 + validated=True -> execute."""
        d = {"cypher_query": "MATCH (d:Disease) RETURN d", "confidence": 0.95, "validated": True}
        assert self._should_execute(d) is True

    def test_low_confidence_skips(self):
        """0.5 < 0.9 -> skip."""
        d = {"cypher_query": "MATCH", "confidence": 0.5, "validated": True}
        assert self._should_execute(d) is False

    def test_invalid_skips(self):
        """validated=False -> skip."""
        d = {"cypher_query": "BAD", "confidence": 0.99, "validated": False}
        assert self._should_execute(d) is False

    def test_null_cypher_skips(self):
        """cypher_query=None -> skip."""
        d = {"cypher_query": None, "confidence": 0.95, "validated": True}
        assert self._should_execute(d) is False

    def test_boundary_089_skips(self):
        """Boundary 0.89 -> skip."""
        d = {"cypher_query": "MATCH (n) RETURN n", "confidence": 0.89, "validated": True}
        assert self._should_execute(d) is False

    def test_boundary_090_executes(self):
        """Boundary 0.90 -> execute."""
        d = {"cypher_query": "MATCH (n) RETURN n", "confidence": 0.90, "validated": True}
        assert self._should_execute(d) is True

    def test_string_confidence_is_parsed(self):
        """The rule applies float(), so a numeric string confidence is accepted."""
        d = {"cypher_query": "MATCH (n) RETURN n", "confidence": "0.95", "validated": True}
        assert self._should_execute(d) is True

    # ---- Stage 2: Cypher validation ----
    def test_validate_pass(self):
        resp = MagicMock()
        resp.json.return_value = {"is_valid": True}
        assert resp.json()["is_valid"] is True

    def test_validate_fail(self):
        resp = MagicMock()
        resp.json.return_value = {"is_valid": False}
        assert resp.json()["is_valid"] is False

    # ---- Stage 3: Cypher execution ----
    def test_neo4j_run_success(self):
        """Record values are joined with commas."""
        session = MagicMock()
        session.run.return_value = [("高血压",), ("糖尿病",)]
        values = [record[0] for record in session.run("MATCH ...")]
        assert ','.join(values) == "高血压,糖尿病"

    def test_neo4j_run_empty(self):
        """No records -> empty string."""
        session = MagicMock()
        session.run.return_value = []
        values = [record[0] for record in session.run("MATCH ...")]
        assert ','.join(values) == ""

    def test_neo4j_run_exception_graceful(self):
        """A failing query degrades to an empty result instead of raising."""
        session = MagicMock()
        session.run.side_effect = Exception("Connection lost")
        degraded = False
        try:
            neo4j_res = ','.join(record[0] for record in session.run("BAD"))
        except Exception:
            neo4j_res = ""
            degraded = True
        # ``neo4j_res == ""`` alone would also hold for an empty successful
        # query, so also check that the fallback branch actually ran.
        assert degraded
        assert neo4j_res == ""

    def test_cypher_service_down(self):
        """An unreachable Cypher API degrades to an empty result."""
        requests = pytest.importorskip("requests")
        # Start from None rather than "" so the final assertion only passes
        # when the fallback assignment actually ran.
        neo4j_res = None
        with patch.object(requests, "post", side_effect=ConnectionError("refused")) as mock_post:
            try:
                requests.post("http://0.0.0.0:8101/generate", "{}")
            except Exception:
                neo4j_res = ""  # fallback: degrade to an empty Neo4j result
        mock_post.assert_called_once()
        assert neo4j_res == ""
# ================================================================
# 工具 6: OpenAI LLM 推理
# ================================================================
class TestLLMInference:
    """
    Tests for agent4.py - generate_openai_answer.

    Covers a normal reply, prompt assembly, an empty reply, a timeout and
    the temperature argument. The OpenAI client is a MagicMock throughout.
    NOTE(review): generate_openai_answer itself is not imported; these tests
    drive the mock client directly.
    """

    @staticmethod
    def _client_replying(content):
        """Return a MagicMock client whose chat.completions.create yields *content*."""
        client = MagicMock()
        client.chat.completions.create.return_value = FakeChatResponse(content)
        return client

    def test_generate_success(self):
        """A normal call returns the mocked reply text."""
        client = self._client_replying("高血压患者应避免高盐饮食, 每日钠 <6g.")
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "高血压饮食"}],
            temperature=0.7,
        )
        answer = response.choices[0].message.content
        assert "高血压" in answer and len(answer) > 10

    def test_prompt_structure(self):
        """The prompt holds the system role, a <context> block and a <question> block."""
        query, context = "高血压不能吃什么?", "低盐饮食"
        system_part = "System: 你是一个非常得力的医学助手."
        user_part = f"<context>\n{context}\n</context>\n<question>\n{query}\n</question>"
        prompt = system_part + user_part
        assert "医学助手" in prompt and query in prompt and context in prompt

    def test_prompt_empty_context(self):
        """An empty context still leaves the prompt skeleton intact."""
        prompt = "<context>\n\n</context>\n<question>\n糖尿病?\n</question>"
        assert "<context>" in prompt and "糖尿病" in prompt

    def test_llm_timeout(self):
        """A timeout raised by the client propagates to the caller."""
        client = MagicMock()
        client.chat.completions.create.side_effect = TimeoutError("timeout")
        with pytest.raises(TimeoutError):
            client.chat.completions.create(model="gpt-4o-mini", messages=[])

    def test_llm_empty_response(self):
        """An empty reply comes back as an empty string."""
        client = self._client_replying("")
        reply = client.chat.completions.create(model="m", messages=[])
        assert reply.choices[0].message.content == ""

    def test_generate_with_temperature(self):
        """temperature=0.7 reaches the create() call unchanged."""
        client = self._client_replying("ok")
        client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "test"}],
            temperature=0.7,
        )
        assert client.chat.completions.create.call_args.kwargs["temperature"] == 0.7
# ================================================================
# 工具 7: PDF 批处理器 (preprocess.py)
# ================================================================
class TestPDFProcessor:
    """
    Tests for preprocess.py - PDFBatchProcessor.

    Covered areas:
    - PDF discovery: missing path, empty directory, extension filter and a
      single-file path;
    - the result dict returned by extract_pdf_content.

    Output directories and probe paths are placed under pytest's
    ``tmp_path`` instead of fixed or cwd-relative locations.
    """

    def test_invalid_path_raises(self, tmp_path):
        """A missing path raises ValueError mentioning '路径不存在'."""
        from preprocess import PDFBatchProcessor

        proc = PDFBatchProcessor(output_dir=str(tmp_path / "out"))
        with pytest.raises(ValueError, match="路径不存在"):
            proc.find_pdf_files("/nonexistent/xyz.txt")

    def test_empty_dir(self, tmp_path):
        """A directory without PDFs yields an empty list."""
        from preprocess import PDFBatchProcessor

        proc = PDFBatchProcessor(output_dir=str(tmp_path / "out"))
        assert proc.find_pdf_files(str(tmp_path)) == []

    def test_finds_pdf_only(self, tmp_path):
        """Only .pdf files are returned."""
        from preprocess import PDFBatchProcessor

        (tmp_path / "a.pdf").touch()
        (tmp_path / "b.pdf").touch()
        (tmp_path / "c.txt").touch()
        proc = PDFBatchProcessor(output_dir=str(tmp_path / "out"))
        files = proc.find_pdf_files(str(tmp_path))
        assert len(files) == 2
        # Check which files were found, not just how many.
        assert sorted(p.name for p in files) == ["a.pdf", "b.pdf"]

    def test_single_pdf_file(self, tmp_path):
        """A path to one PDF yields a list containing just that path."""
        from preprocess import PDFBatchProcessor

        pdf = tmp_path / "x.pdf"
        pdf.touch()
        proc = PDFBatchProcessor(output_dir=str(tmp_path / "out"))
        assert proc.find_pdf_files(str(pdf)) == [pdf]

    def test_extract_nonexistent_has_error(self, tmp_path):
        """Extracting a missing file reports an error in the result dict."""
        from preprocess import PDFBatchProcessor

        proc = PDFBatchProcessor(output_dir=str(tmp_path / "out"))
        r = proc.extract_pdf_content(tmp_path / "missing.pdf")
        assert r["error"] is not None

    def test_result_has_required_keys(self, tmp_path):
        """The extraction result always carries the expected keys."""
        from preprocess import PDFBatchProcessor

        proc = PDFBatchProcessor(output_dir=str(tmp_path / "out"))
        # A path under tmp_path cannot accidentally match a real file in the cwd.
        r = proc.extract_pdf_content(tmp_path / "dummy.pdf")
        for k in ["file_name", "file_path", "metadata", "pages", "error"]:
            assert k in r
# ================================================================
# 工具 8: 数据预处理 (JSONL → Document)
# ================================================================
class TestDataPreprocessing:
    """
    JSONL -> document-text parsing checks (one JSON object per line with
    "query" and "response" fields).

    All files are written and read as UTF-8, and every file handle is closed
    via ``with``.
    NOTE(review): vector.py's prepare_document is not called; the parsing is
    re-implemented inline.
    """

    def test_jsonl_parse(self, tmp_path):
        """Each line becomes "query\\nresponse"."""
        f = tmp_path / "t.jsonl"
        f.write_text(
            json.dumps({"query": "Q1", "response": "A1"}, ensure_ascii=False) + "\n"
            + json.dumps({"query": "Q2", "response": "A2"}, ensure_ascii=False) + "\n",
            encoding="utf-8",
        )
        docs = []
        with open(f, encoding="utf-8") as fh:
            for line in fh:
                c = json.loads(line.strip())
                docs.append(c["query"] + "\n" + c["response"])
        assert len(docs) == 2 and "Q1" in docs[0]

    def test_jsonl_empty(self, tmp_path):
        """An empty file has no non-blank lines."""
        f = tmp_path / "e.jsonl"
        f.write_text("", encoding="utf-8")
        with open(f, encoding="utf-8") as fh:
            non_blank = sum(1 for line in fh if line.strip())
        assert non_blank == 0

    def test_jsonl_bad_line(self, tmp_path):
        """A malformed line is counted as bad without stopping the scan."""
        f = tmp_path / "b.jsonl"
        f.write_text('{"query":"ok","response":"r"}\n{bad}\n', encoding="utf-8")
        ok, bad = 0, 0
        with open(f, encoding="utf-8") as fh:
            for line in fh:
                try:
                    json.loads(line)
                except json.JSONDecodeError:
                    bad += 1
                else:
                    ok += 1
        assert ok == 1 and bad == 1
# ================================================================
# 工具 9: 端到端 RAG 编排逻辑
# ================================================================
class TestRAGOrchestration:
    """
    Checks for the end-to-end RAG orchestration pieces:
    - merging three retrieval results (presumably Milvus / PDF / Neo4j,
      labelled M / P / N here) into one newline-joined context;
    - reading the request's "question" field;
    - the Redis get_or_compute short-circuit.
    """

    def test_three_way_merge(self):
        parts = ["M结果", "P结果", "N结果"]
        ctx = "\n".join(parts)
        assert all(part in ctx for part in parts)

    def test_partial_empty_merge(self):
        ctx = "\n".join(["有结果", "", ""])
        assert "有结果" in ctx

    def test_all_empty_merge(self):
        ctx = "\n".join(["", "", ""])
        assert ctx.strip() == ""

    def test_request_valid(self):
        payload = {"question": "Q"}
        assert payload.get("question") == "Q"

    def test_request_missing_question(self):
        payload = {"query": "x"}
        assert payload.get("question") is None

    def test_redis_caching_in_chatbot(self):
        """The chatbot's get_or_compute returns a cached answer without running RAG."""
        mgr = make_redis_manager()
        mgr.set_answer("Q", "缓存A")
        invocations = []

        def rag():
            invocations.append(True)
            return "new"

        assert mgr.get_or_compute("Q", rag) == "缓存A"
        assert not invocations
# ================================================================
if __name__ == "__main__":
    # Propagate pytest's exit status so `python test1.py` fails (non-zero)
    # when any test fails, instead of always exiting 0.
    sys.exit(pytest.main([__file__, "-v", "--tb=short"]))