# agentV2/test/test7.py
# Commit a927ee5 (verified) by drewli20200316: "Add test/test7.py"
"""
================================================================
医疗 RAG Agent — Observability & Tracing 验证 (可观测性)
================================================================
测试层级:
单元测试 (test1.py): 单工具调用准确性 ✅ 67 passed
集成测试 (test2.py): 多步骤工具链协作 ✅ 37 passed
回归测试 (test3.py): 防退化 & 边界守护 ✅ 52 passed
安全红队 (test4.py): 对抗性攻击防御 ✅ 45 passed
E2E完成率(test5.py): 端到端任务完成率 ✅ 60 passed
成本效率 (test6.py): Cost & Efficiency ✅ 25 passed
可观测性 (test7.py): Observability & Tracing ← 当前文件
为什么要测可观测性?
Agent 的决策链路是: 用户输入 → Redis → Milvus → PDF → Neo4j → LLM → 响应
出了问题, 你需要在几分钟内定位 "哪一步出错了, 输入是什么, 输出是什么"
如果没有 Tracing, 调试一个 Agent bug 可能需要几小时甚至几天
测试维度:
维度 1: 决策链路完整性 (每一步 thought→action→observation 都被记录)
维度 2: 错误定位能力 (任意一步失败时, 能精确定位到是哪一步)
维度 3: 输入输出追溯 (每步的 input/output 可追溯)
维度 4: 时间线追踪 (每步的开始/结束时间, 耗时分析)
维度 5: 全链路 Trace 报告 (人类可读的调试日志)
运行:
pytest test7.py -v --tb=short -s
pytest test7.py -v -k "TestTraceCompleteness" # 链路完整性
pytest test7.py -v -k "TestErrorLocalization" # 错误定位
pytest test7.py -v -k "timeline" # 时间线
================================================================
"""
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
import types
import pytest
import json
import hashlib
import time
import uuid
from enum import Enum
from unittest.mock import MagicMock
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any
# ================================================================
# 前置: Mock 缺失依赖
# ================================================================
def _ensure_mock_module(name):
    """Register a ``MagicMock`` under ``name`` in ``sys.modules`` if the name is free."""
    if name in sys.modules:
        # Something is already registered under this name; leave it untouched.
        return
    sys.modules[name] = MagicMock()
# Dependencies that may be missing in the test environment. Every dotted path
# is handed to _ensure_mock_module, which only fills names that are not yet
# present in sys.modules.
_MOCKED_MODULE_NAMES = (
    "langchain_classic",
    "langchain_classic.retrievers",
    "langchain_classic.retrievers.parent_document_retriever",
    "langchain_milvus",
    "langchain_text_splitters",
    "langchain_core",
    "langchain_core.stores",
    "langchain_core.documents",
    "langchain.embeddings",
    "langchain.embeddings.base",
    "neo4j",
    "dotenv",
    "uvicorn",
    "fastapi",
    "fastapi.middleware",
    "fastapi.middleware.cors",
)
for mod in _MOCKED_MODULE_NAMES:
    _ensure_mock_module(mod)
class _FakeEmbeddingsBase:
    """Empty stand-in class published as ``Embeddings`` on the ``langchain.embeddings.base`` entry."""


# Expose a real class under the name ``Embeddings`` (presumably so that code
# subclassing it gets a plain class rather than a mock attribute -- confirm
# against the importing module).
setattr(sys.modules["langchain.embeddings.base"], "Embeddings", _FakeEmbeddingsBase)
# ================================================================
# 基础设施
# ================================================================
@dataclass
class FakeDocument:
    """Minimal document stand-in carrying text content and optional metadata."""

    # Text payload of the document.
    page_content: str
    # Free-form metadata; each instance gets its own empty dict by default.
    metadata: dict = field(default_factory=dict)
class FakeChatResponse:
    """Fake chat-completion response readable as ``response.choices[0].message.content``."""

    def __init__(self, content):
        # Throwaway classes are built on the fly: ``Msg`` carries the text and
        # ``Choice`` carries one ``Msg`` instance.
        message_cls = type('Msg', (), {'content': content})
        choice_cls = type('Choice', (), {'message': message_cls()})
        self.choices = [choice_cls()]
class FakeRedisClient:
    """In-memory stand-in for the small part of the Redis client API used here.

    Values are kept in ``_store`` and the TTL given for a key in ``_expiry``.
    TTLs are only recorded, never enforced: keys do not expire on their own.
    """

    def __init__(self):
        self._store = {}
        self._expiry = {}

    def ping(self):
        """Always report the connection as alive."""
        return True

    def get(self, key):
        """Return the value stored under ``key``, or ``None`` when absent."""
        return self._store.get(key)

    def set(self, key, value, ex=None, nx=False):
        """Store ``value`` under ``key`` and return ``True``.

        With ``nx=True`` an existing key is left untouched and ``False`` is
        returned. A truthy ``ex`` is recorded as the key's TTL; otherwise any
        TTL recorded earlier is dropped, as a plain Redis SET does.
        """
        if nx and key in self._store:
            return False
        self._store[key] = value
        if ex:
            self._expiry[key] = ex
        else:
            # Overwriting without a TTL must not keep the old one around.
            self._expiry.pop(key, None)
        return True

    def setex(self, key, expire, value):
        """Store ``value`` under ``key`` with TTL ``expire`` and return ``True``."""
        self._store[key] = value
        self._expiry[key] = expire
        return True

    def delete(self, key):
        """Remove ``key`` (and its TTL); return 1 if a value was removed, else 0."""
        self._expiry.pop(key, None)
        return 1 if self._store.pop(key, None) is not None else 0

    def register_script(self, script):
        """Return a callable emulating a compare-and-delete script.

        The ``script`` text itself is ignored. The returned callable deletes
        ``keys[0]`` only when it exists and its value equals ``args[0]``;
        it returns 1 on deletion and 0 otherwise.
        """
        def run_script(keys=None, args=None):
            if not (keys and args):
                return 0
            key = keys[0]
            # Membership is checked explicitly so a missing key never matches
            # an ``args[0]`` of ``None``.
            if key in self._store and self._store[key] == args[0]:
                del self._store[key]
                self._expiry.pop(key, None)
                return 1
            return 0
        return run_script
def make_redis_manager():
    """Build a ``RedisClientWrapper`` backed by ``FakeRedisClient``.

    The instance is allocated with ``object.__new__`` so the wrapper's
    ``__init__`` is not executed. The class attribute ``_pool`` is set to a
    placeholder string first (presumably so the wrapper never tries to create
    a real connection pool -- confirm against ``new_redis``).
    """
    from new_redis import RedisClientWrapper

    RedisClientWrapper._pool = "FAKE"
    manager = object.__new__(RedisClientWrapper)
    manager.client = FakeRedisClient()
    # The fake ignores the script text, so an empty script is enough.
    manager.unlock_script = manager.client.register_script("")
    return manager
# ================================================================
# Tracing 框架: 记录 Agent 每一步的 Thought → Action → Observation
# ================================================================
class StepStatus(Enum):
    """Outcome of a single traced step."""

    SUCCESS = "success"  # the step ran and produced its output
    FAILED = "failed"    # the step raised or returned an error
    SKIPPED = "skipped"  # the step was deliberately not executed
@dataclass
class TraceStep:
    """A single recorded step of the agent's decision chain."""

    step_name: str  # step identifier, e.g. "milvus_search" or "pdf_retrieval"
    step_index: int  # position of the step in the chain (starts at 1)
    status: StepStatus  # success / failed / skipped
    input_data: Any  # what was passed into the step
    output_data: Any  # what the step produced
    error: Optional[str]  # error message when the step failed
    start_time: float  # timestamp taken when the step started
    end_time: float  # timestamp taken when the step ended
    metadata: Dict = field(default_factory=dict)  # extra information

    @property
    def duration_ms(self) -> float:
        """Elapsed time between ``start_time`` and ``end_time`` in milliseconds."""
        elapsed_seconds = self.end_time - self.start_time
        return elapsed_seconds * 1000
@dataclass
class TraceRecord:
    """Full trace of one query: identity, ordered steps, timing and answer."""

    trace_id: str  # unique trace identifier
    query: str  # the user's original question
    steps: List[TraceStep] = field(default_factory=list)
    total_start: float = 0.0
    total_end: float = 0.0
    final_answer: str = ""

    def _steps_with_status(self, wanted: StepStatus) -> List[TraceStep]:
        """Collect, in recorded order, the steps whose status equals ``wanted``."""
        matching = []
        for step in self.steps:
            if step.status == wanted:
                matching.append(step)
        return matching

    @property
    def total_duration_ms(self) -> float:
        """Elapsed time between ``total_start`` and ``total_end`` in milliseconds."""
        elapsed_seconds = self.total_end - self.total_start
        return elapsed_seconds * 1000

    @property
    def success_steps(self) -> List[TraceStep]:
        """Steps recorded with status SUCCESS."""
        return self._steps_with_status(StepStatus.SUCCESS)

    @property
    def failed_steps(self) -> List[TraceStep]:
        """Steps recorded with status FAILED."""
        return self._steps_with_status(StepStatus.FAILED)

    @property
    def step_names(self) -> List[str]:
        """Names of all recorded steps, in recorded order."""
        names = []
        for step in self.steps:
            names.append(step.step_name)
        return names

    def get_step(self, name: str) -> Optional[TraceStep]:
        """Return the first step called ``name``, or ``None`` if there is none."""
        for step in self.steps:
            if step.step_name == name:
                return step
        return None
def perform_rag_with_tracing(
    query: str,
    milvus, pdf, neo4j_driver, llm, requests_module,
) -> TraceRecord:
    """Run the RAG pipeline once and record every step in a ``TraceRecord``.

    Steps, in order: ``redis_check`` -> ``milvus_search`` -> ``pdf_retrieval``
    -> ``cypher_generate`` -> ``neo4j_execute`` -> ``llm_inference``. Each step
    stores its input, output, status and start/end timestamps. A step that
    raises is recorded as FAILED with the exception text, and the pipeline
    carries on with whatever context has been gathered so far.
    ``neo4j_execute`` is only recorded when ``cypher_generate`` succeeded.

    Args:
        query: The user's question.
        milvus: Object exposing ``similarity_search(query, k=...,
            ranker_type=..., ranker_params=...)``.
        pdf: Object exposing ``invoke(query)``; the ``page_content`` of the
            first returned document is used.
        neo4j_driver: Object whose ``session()`` is a context manager offering
            ``run(cypher)``.
        llm: Object exposing ``chat.completions.create(...)``.
        requests_module: Object exposing ``post(url, data)``, used for the
            Cypher generate / validate endpoints.

    Returns:
        The populated ``TraceRecord``. ``final_answer`` stays empty when the
        LLM step fails.
    """
    import json as _json

    trace = TraceRecord(
        trace_id=str(uuid.uuid4()),
        query=query,
        total_start=time.time(),
    )

    def record(name, index, status, started, input_data, output_data=None,
               error=None, metadata=None):
        """Append one step to the trace; the end time is taken right here."""
        trace.steps.append(TraceStep(
            step_name=name,
            step_index=index,
            status=status,
            input_data=input_data,
            output_data=output_data,
            error=error,
            start_time=started,
            end_time=time.time(),
            metadata={} if metadata is None else metadata,
        ))

    # ---- Step 1: Redis cache check ----
    # No cache is actually consulted here: the key is derived and a miss is
    # always recorded.
    step_idx = 1
    started = time.time()
    cache_key = hashlib.md5(query.encode('utf-8')).hexdigest()
    record(
        "redis_check", step_idx, StepStatus.SUCCESS, started,
        input_data={"query": query, "cache_key": f"llm:cache:{cache_key}"},
        output_data={"hit": False},
        metadata={"action": "cache_lookup"},
    )

    # ---- Step 2: Milvus vector recall ----
    step_idx += 1
    started = time.time()
    context = ""
    try:
        results = milvus.similarity_search(query, k=10, ranker_type="rrf", ranker_params={"k": 100})
        context = "\n\n".join(d.page_content for d in results) if results else ""
        record(
            "milvus_search", step_idx, StepStatus.SUCCESS, started,
            input_data={"query": query, "k": 10, "ranker": "rrf"},
            output_data={"doc_count": len(results) if results else 0, "context_chars": len(context)},
            metadata={"ranker_params": {"k": 100}},
        )
    except Exception as e:
        record(
            "milvus_search", step_idx, StepStatus.FAILED, started,
            input_data={"query": query, "k": 10},
            error=str(e),
        )

    # ---- Step 3: PDF parent/child document retrieval ----
    step_idx += 1
    started = time.time()
    pdf_res = ""
    try:
        docs = pdf.invoke(query)
        if docs and len(docs) >= 1:
            pdf_res = docs[0].page_content
        record(
            "pdf_retrieval", step_idx, StepStatus.SUCCESS, started,
            input_data={"query": query},
            output_data={"doc_count": len(docs) if docs else 0, "content_chars": len(pdf_res)},
        )
    except Exception as e:
        record(
            "pdf_retrieval", step_idx, StepStatus.FAILED, started,
            input_data={"query": query},
            error=str(e),
        )
    context = context + "\n" + pdf_res

    # ---- Step 4: Cypher generation ----
    step_idx += 1
    started = time.time()
    neo4j_res = ""
    cypher_query = None
    confidence = 0
    cypher_generated = False
    run_graph_query = False
    try:
        resp = requests_module.post("http://0.0.0.0:8101/generate",
                                    _json.dumps({"natural_language_query": query}))
        if resp.status_code == 200:
            d = resp.json()
            cypher_query = d.get("cypher_query")
            confidence = d.get("confidence", 0)
            validated = d.get("validated", False)
            # Decide whether step 5 may run BEFORE recording success: a
            # malformed confidence makes float() raise, and the step must then
            # be recorded once as FAILED instead of SUCCESS followed by FAILED.
            run_graph_query = bool(
                cypher_query and float(confidence) >= 0.9 and validated
            )
            record(
                "cypher_generate", step_idx, StepStatus.SUCCESS, started,
                input_data={"query": query},
                output_data={"cypher": cypher_query, "confidence": confidence, "validated": validated},
            )
            cypher_generated = True
        else:
            record(
                "cypher_generate", step_idx, StepStatus.FAILED, started,
                input_data={"query": query},
                output_data={"status_code": resp.status_code},
                error=f"HTTP {resp.status_code}",
            )
    except Exception as e:
        record(
            "cypher_generate", step_idx, StepStatus.FAILED, started,
            input_data={"query": query},
            error=str(e),
        )

    # ---- Step 5: Cypher validation + Neo4j execution ----
    if cypher_generated:
        step_idx += 1
        if run_graph_query:
            started = time.time()
            try:
                vresp = requests_module.post("http://0.0.0.0:8101/validate",
                                             _json.dumps({"cypher_query": cypher_query}))
                if vresp.status_code == 200 and vresp.json()["is_valid"]:
                    with neo4j_driver.session() as session:
                        rows = session.run(cypher_query)
                        # Only the first column of every row is kept.
                        result = [row[0] for row in rows]
                    neo4j_res = ','.join(result)
                    record(
                        "neo4j_execute", step_idx, StepStatus.SUCCESS, started,
                        input_data={"cypher": cypher_query},
                        output_data={"result_count": len(result), "results": result},
                    )
                else:
                    record(
                        "neo4j_execute", step_idx, StepStatus.SKIPPED, started,
                        input_data={"cypher": cypher_query},
                        output_data={"reason": "validation_failed"},
                    )
            except Exception as e:
                record(
                    "neo4j_execute", step_idx, StepStatus.FAILED, started,
                    input_data={"cypher": cypher_query},
                    error=str(e),
                )
        else:
            record(
                "neo4j_execute", step_idx, StepStatus.SKIPPED, time.time(),
                input_data={"cypher": cypher_query, "confidence": confidence},
                output_data={"reason": "low_confidence_or_invalid"},
            )
    context = context + "\n" + neo4j_res

    # ---- Step 6: LLM inference ----
    step_idx += 1
    started = time.time()
    SYSTEM = "System: 你是一个非常得力的医学助手."
    USER = f"User: <context>{context}</context><question>{query}</question>"
    full_prompt = SYSTEM + USER
    try:
        response = llm.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": full_prompt}],
            temperature=0.7,
        )
        answer = response.choices[0].message.content
        record(
            "llm_inference", step_idx, StepStatus.SUCCESS, started,
            input_data={"prompt_chars": len(full_prompt), "model": "gpt-4o-mini", "temperature": 0.7},
            output_data={"answer_chars": len(answer), "answer_preview": answer[:80]},
        )
        trace.final_answer = answer
    except Exception as e:
        record(
            "llm_inference", step_idx, StepStatus.FAILED, started,
            input_data={"prompt_chars": len(full_prompt)},
            error=str(e),
        )

    trace.total_end = time.time()
    return trace
# ================================================================
# Mock 工厂
# ================================================================
def make_mocks(milvus_fail=False, pdf_fail=False, neo4j_fail=False, llm_fail=False,
               neo4j_low_confidence=False):
    """Build the mocked pipeline components, each with a switchable failure.

    ``neo4j_fail`` makes both the graph session and the Cypher HTTP API raise.
    ``neo4j_low_confidence`` lowers the confidence reported by the generate
    endpoint from 0.95 to 0.5.

    Returns:
        The tuple ``(milvus, pdf, neo4j_driver, llm, req)``.
    """
    # Vector store
    milvus = MagicMock()
    if not milvus_fail:
        milvus.similarity_search.return_value = [
            FakeDocument(page_content="高血压需控制钠摄入不超过5g")
        ]
    else:
        milvus.similarity_search.side_effect = ConnectionError("Milvus timeout")

    # PDF retriever
    pdf = MagicMock()
    if not pdf_fail:
        pdf.invoke.return_value = [FakeDocument(page_content="《高血压防治指南》建议低盐低脂")]
    else:
        pdf.invoke.side_effect = Exception("PDF index corrupted")

    # Graph database: ``driver.session()`` works as a context manager that
    # yields ``graph_session``.
    graph_session = MagicMock()
    if not neo4j_fail:
        graph_session.run.return_value = [("氨氯地平",), ("缬沙坦",)]
    else:
        graph_session.run.side_effect = Exception("Neo4j ServiceUnavailable")
    neo4j_driver = MagicMock()
    session_cm = neo4j_driver.session.return_value
    session_cm.__enter__ = MagicMock(return_value=graph_session)
    session_cm.__exit__ = MagicMock(return_value=False)

    # Cypher HTTP API: the first POST answers "generate", the second "validate".
    req = MagicMock()
    if not neo4j_fail:
        generate_resp = MagicMock()
        generate_resp.status_code = 200
        generate_resp.json.return_value = {
            "cypher_query": "MATCH (d:Disease)-[:has_drug]->(m) RETURN m.name",
            "confidence": 0.5 if neo4j_low_confidence else 0.95,
            "validated": True,
        }
        validate_resp = MagicMock()
        validate_resp.status_code = 200
        validate_resp.json.return_value = {"is_valid": True}
        req.post.side_effect = [generate_resp, validate_resp]
    else:
        req.post.side_effect = ConnectionError("Cypher API down")

    # LLM client
    llm = MagicMock()
    if not llm_fail:
        llm.chat.completions.create.return_value = FakeChatResponse(
            "高血压患者应限制盐分摄入, 常用药物包括氨氯地平等。"
        )
    else:
        llm.chat.completions.create.side_effect = TimeoutError("LLM timeout")

    return milvus, pdf, neo4j_driver, llm, req
def run_traced(query="高血压不能吃什么?", **kwargs) -> TraceRecord:
    """Run the traced pipeline for ``query`` against freshly built mocks.

    Keyword arguments are forwarded to ``make_mocks`` to switch failures on.
    """
    # make_mocks returns its components in the positional order expected by
    # perform_rag_with_tracing after ``query``.
    return perform_rag_with_tracing(query, *make_mocks(**kwargs))
# ================================================================
# 维度 1: 决策链路完整性
# ================================================================
class TestTraceCompleteness:
    """Every step of the decision chain is recorded in the trace."""

    def test_happy_path_has_all_six_steps(self):
        """Happy path: all six steps are recorded, in order."""
        trace = run_traced()
        expected = ["redis_check", "milvus_search", "pdf_retrieval",
                    "cypher_generate", "neo4j_execute", "llm_inference"]
        assert trace.step_names == expected, (
            f"链路步骤: {trace.step_names}, 期望: {expected}"
        )

    def test_every_step_has_required_fields(self):
        """Each step carries name, status, input, output and timestamps."""
        trace = run_traced()
        for step in trace.steps:
            assert step.step_name, f"Step {step.step_index}: 缺少 step_name"
            assert step.status is not None, f"Step {step.step_name}: 缺少 status"
            assert step.input_data is not None, f"Step {step.step_name}: 缺少 input_data"
            # On the happy path no step fails, so every step must have output.
            assert step.output_data is not None, f"Step {step.step_name}: missing output_data"
            assert step.start_time > 0, f"Step {step.step_name}: 缺少 start_time"
            assert step.end_time >= step.start_time, f"Step {step.step_name}: end < start"

    def test_trace_has_id_and_query(self):
        """The trace has a UUID identifier and keeps the original query."""
        trace = run_traced("糖尿病饮食")
        assert trace.trace_id, "应有 trace_id"
        assert len(trace.trace_id) == 36, "trace_id 应为 UUID 格式"
        # Round-tripping through uuid.UUID rejects 36-character non-UUID strings.
        assert str(uuid.UUID(trace.trace_id)) == trace.trace_id, "trace_id 应为 UUID 格式"
        assert trace.query == "糖尿病饮食", "应记录原始查询"

    def test_trace_has_final_answer(self):
        """The trace records the final answer."""
        trace = run_traced()
        assert len(trace.final_answer) > 0, "应记录最终回答"

    def test_step_indices_are_sequential(self):
        """Step indices start at 1 and increase by exactly one per step."""
        trace = run_traced()
        indices = [s.step_index for s in trace.steps]
        assert indices == list(range(1, len(indices) + 1)), f"步骤序号不连续: {indices}"

    def test_all_success_in_happy_path(self):
        """Happy path: every step has status SUCCESS."""
        trace = run_traced()
        for step in trace.steps:
            assert step.status == StepStatus.SUCCESS, (
                f"Step '{step.step_name}' 应为 SUCCESS, 实际 {step.status.value}"
            )
# ================================================================
# 维度 2: 错误定位能力
# ================================================================
class TestErrorLocalization:
    """When any step fails, the trace pinpoints exactly which one."""

    def test_milvus_failure_located(self):
        """Milvus outage: ``milvus_search`` is marked FAILED with the error text."""
        trace = run_traced(milvus_fail=True)
        milvus_step = trace.get_step("milvus_search")
        assert milvus_step is not None, "应有 milvus_search 步骤"
        assert milvus_step.status == StepStatus.FAILED
        assert "timeout" in milvus_step.error.lower() or "milvus" in milvus_step.error.lower()

    def test_pdf_failure_located(self):
        """PDF outage: ``pdf_retrieval`` is marked FAILED."""
        trace = run_traced(pdf_fail=True)
        pdf_step = trace.get_step("pdf_retrieval")
        assert pdf_step is not None
        assert pdf_step.status == StepStatus.FAILED
        assert pdf_step.error is not None

    def test_neo4j_failure_located(self):
        """Neo4j outage: ``cypher_generate`` is marked FAILED.

        With ``neo4j_fail`` the mocked Cypher API raises first, so the failure
        shows up on the generation step.
        """
        trace = run_traced(neo4j_fail=True)
        cypher_step = trace.get_step("cypher_generate")
        assert cypher_step is not None
        assert cypher_step.status == StepStatus.FAILED
        assert cypher_step.error is not None

    def test_llm_failure_located(self):
        """LLM timeout: ``llm_inference`` is marked FAILED."""
        trace = run_traced(llm_fail=True)
        llm_step = trace.get_step("llm_inference")
        assert llm_step is not None
        assert llm_step.status == StepStatus.FAILED
        assert "timeout" in llm_step.error.lower()

    def test_only_failed_step_is_marked(self):
        """Only the broken step is FAILED; every other step is SUCCESS."""
        trace = run_traced(pdf_fail=True)
        assert [s.step_name for s in trace.failed_steps] == ["pdf_retrieval"]
        for step in trace.steps:
            if step.step_name == "pdf_retrieval":
                assert step.status == StepStatus.FAILED
            else:
                # Steps after the failure are checked too, not only the
                # ones that ran before it.
                assert step.status == StepStatus.SUCCESS, (
                    f"PDF 故障不应影响 {step.step_name}"
                )

    def test_failed_steps_count(self):
        """Milvus and PDF failing together yield exactly two FAILED steps."""
        trace = run_traced(milvus_fail=True, pdf_fail=True)
        assert len(trace.failed_steps) == 2, (
            f"应有 2 个失败步骤, 实际 {len(trace.failed_steps)}: "
            f"{[s.step_name for s in trace.failed_steps]}"
        )

    def test_low_confidence_neo4j_skipped(self):
        """Low confidence: ``neo4j_execute`` is SKIPPED rather than FAILED."""
        trace = run_traced(neo4j_low_confidence=True)
        neo4j_step = trace.get_step("neo4j_execute")
        assert neo4j_step is not None
        assert neo4j_step.status == StepStatus.SKIPPED, (
            f"低置信度应 SKIPPED, 实际 {neo4j_step.status.value}"
        )
# ================================================================
# 维度 3: 输入输出追溯
# ================================================================
class TestInputOutputTracing:
    """Inputs and outputs of each step are traceable: what went in, what came out."""

    def test_redis_step_records_cache_key(self):
        """The Redis step records the cache key."""
        recorded_input = run_traced("高血压").get_step("redis_check").input_data
        assert "cache_key" in recorded_input
        assert recorded_input["cache_key"].startswith("llm:cache:")

    def test_milvus_step_records_query_and_params(self):
        """The Milvus step records the search parameters (k, ranker)."""
        milvus_step = run_traced().get_step("milvus_search")
        sent, received = milvus_step.input_data, milvus_step.output_data
        assert sent["k"] == 10
        assert sent["ranker"] == "rrf"
        assert received["doc_count"] >= 0

    def test_cypher_step_records_generated_cypher(self):
        """The Cypher step records the generated statement and its confidence."""
        generated = run_traced().get_step("cypher_generate").output_data
        assert generated["cypher"] is not None
        assert generated["confidence"] >= 0.9
        assert generated["validated"] is True

    def test_neo4j_step_records_results(self):
        """The Neo4j step records the query results."""
        graph_output = run_traced().get_step("neo4j_execute").output_data
        assert graph_output["result_count"] == 2
        assert "氨氯地平" in graph_output["results"]

    def test_llm_step_records_prompt_size(self):
        """The LLM step records the prompt size and the model parameters."""
        llm_input = run_traced().get_step("llm_inference").input_data
        assert llm_input["prompt_chars"] > 0
        assert llm_input["model"] == "gpt-4o-mini"
        assert llm_input["temperature"] == 0.7

    def test_llm_step_records_answer_preview(self):
        """The LLM step records a preview of the answer."""
        llm_output = run_traced().get_step("llm_inference").output_data
        assert llm_output["answer_chars"] > 0
        assert len(llm_output["answer_preview"]) > 0

    def test_failed_step_records_error_message(self):
        """A failed step records its error message and no output."""
        failed = run_traced(milvus_fail=True).get_step("milvus_search")
        assert failed.error is not None
        assert len(failed.error) > 0
        assert failed.output_data is None, "失败步骤 output 应为 None"
# ================================================================
# 维度 4: 时间线追踪
# ================================================================
class TestTimelineTracking:
    """Per-step timing is tracked so that bottlenecks can be found."""

    def test_every_step_has_positive_duration(self):
        """No step has a negative duration."""
        trace = run_traced()
        for step in trace.steps:
            assert step.duration_ms >= 0, (
                f"Step '{step.step_name}' 耗时为负: {step.duration_ms}ms"
            )

    def test_total_duration_covers_all_steps(self):
        """Total duration is at least the sum of all step durations."""
        trace = run_traced()
        steps_total = sum(s.duration_ms for s in trace.steps)
        assert trace.total_duration_ms >= 0
        # Steps run one after another inside the total window; the tiny
        # tolerance only absorbs floating-point rounding in the summation.
        assert trace.total_duration_ms + 1e-6 >= steps_total, (
            f"total {trace.total_duration_ms}ms < sum of steps {steps_total}ms"
        )

    def test_steps_are_chronologically_ordered(self):
        """Steps are stored in chronological order of their start time."""
        trace = run_traced()
        for current, following in zip(trace.steps, trace.steps[1:]):
            assert current.start_time <= following.start_time, (
                f"Step '{current.step_name}' 开始时间 > 下一步"
            )

    def test_trace_start_before_first_step(self):
        """The trace starts no later than its first step."""
        trace = run_traced()
        assert trace.total_start <= trace.steps[0].start_time

    def test_trace_end_after_last_step(self):
        """The trace ends no earlier than its last step."""
        trace = run_traced()
        assert trace.total_end >= trace.steps[-1].end_time
# ================================================================
# 维度 5: 全链路 Trace 报告
# ================================================================
class TestTraceReport:
    """Human-readable trace reports, used for debugging and demos.

    The ``capsys`` fixture captures everything printed inside a test, even
    when pytest runs with ``-s``. Each report test therefore reads the
    captured text, checks its content, and echoes it to the real stdout.
    """

    @staticmethod
    def _captured_report(capsys):
        """Return the text printed so far and echo it with capture disabled."""
        report = capsys.readouterr().out
        with capsys.disabled():
            print(report, end="")
        return report

    @staticmethod
    def _assert_report_covers(report, trace):
        """Check that the report names the trace, every step and every error."""
        assert trace.trace_id[:8] in report
        for step in trace.steps:
            assert step.step_name in report, f"report lacks step {step.step_name}"
            if step.error:
                assert step.error in report, f"report lacks error of {step.step_name}"

    def test_normal_trace_report(self, capsys):
        """Report for the happy path."""
        trace = run_traced("高血压不能吃什么?")
        _print_trace_report(trace)
        report = self._captured_report(capsys)
        assert len(trace.failed_steps) == 0
        self._assert_report_covers(report, trace)

    def test_degraded_trace_report(self, capsys):
        """Report for a degraded run (PDF + Neo4j failing)."""
        trace = run_traced("高血压不能吃什么?", pdf_fail=True, neo4j_fail=True)
        _print_trace_report(trace)
        report = self._captured_report(capsys)
        assert len(trace.failed_steps) >= 1
        self._assert_report_covers(report, trace)

    def test_full_failure_trace_report(self, capsys):
        """Report for a run where Milvus, PDF and Neo4j all fail."""
        trace = run_traced("高血压", milvus_fail=True, pdf_fail=True, neo4j_fail=True)
        _print_trace_report(trace)
        report = self._captured_report(capsys)
        assert len(trace.failed_steps) >= 3
        self._assert_report_covers(report, trace)

    def test_trace_report_identifies_bottleneck(self):
        """The slowest step (the bottleneck) can be picked out of a trace."""
        trace = run_traced()
        # Fail loudly on an empty trace instead of silently asserting nothing.
        assert trace.steps, "trace should contain steps"
        slowest = max(trace.steps, key=lambda s: s.duration_ms)
        assert slowest.step_name is not None
        # Durations are tiny with mocks; only the structure is checked here.

    def test_multiple_traces_comparable(self):
        """Separate runs get distinct trace ids, so traces can be compared."""
        t1 = run_traced("问题A")
        t2 = run_traced("问题B")
        assert t1.trace_id != t2.trace_id, "不同查询应有不同 trace_id"

    def test_trace_summary_printout(self, capsys):
        """Print and check the observability summary across four scenarios."""
        traces = [
            ("正常链路", run_traced("高血压饮食")),
            ("PDF故障", run_traced("高血压饮食", pdf_fail=True)),
            ("Neo4j故障", run_traced("高血压饮食", neo4j_fail=True)),
            ("全部故障", run_traced("高血压饮食", milvus_fail=True, pdf_fail=True, neo4j_fail=True)),
        ]
        status_icons = {"success": "✅", "failed": "❌", "skipped": "⏭️"}
        print("\n")
        print("=" * 70)
        print(" 医疗 RAG Agent — Observability & Tracing 报告")
        print("=" * 70)
        for label, trace in traces:
            status_icon = "✅" if len(trace.failed_steps) == 0 else "⚠️"
            print(f"\n {status_icon} [{label}] trace_id={trace.trace_id[:8]}...")
            print(f" 查询: '{trace.query}'")
            print(f" 总耗时: {trace.total_duration_ms:.2f}ms | 步骤数: {len(trace.steps)}")
            for step in trace.steps:
                icon = status_icons[step.status.value]
                err_info = f" | Error: {step.error[:40]}" if step.error else ""
                print(f" {icon} [{step.step_index}] {step.step_name:<20s} "
                      f"{step.duration_ms:>6.2f}ms{err_info}")
            if trace.final_answer:
                print(f" 回答: {trace.final_answer[:50]}...")
        print(f"\n{'─' * 70}")
        print(f" 场景覆盖: {len(traces)} 种 | 总步骤数: {sum(len(t.steps) for _, t in traces)}")
        print("=" * 70)

        report = self._captured_report(capsys)
        for label, trace in traces:
            assert f"[{label}]" in report
            assert trace.trace_id[:8] in report
            for step in trace.steps:
                assert step.step_name in report
def _print_trace_report(trace: TraceRecord):
    """Print a detailed, human-readable report for a single trace."""
    status_icons = {"success": "✅", "failed": "❌", "skipped": "⏭️"}

    # Header: shortened trace id, query and total time.
    print(f"\n ── Trace: {trace.trace_id[:8]}... ──")
    print(f" 查询: '{trace.query}'")
    print(f" 总耗时: {trace.total_duration_ms:.2f}ms")

    # One line per step, followed by its error text when it has one.
    for step in trace.steps:
        icon = status_icons[step.status.value]
        print(f" {icon} [{step.step_index}] {step.step_name}: {step.duration_ms:.2f}ms")
        if not step.error:
            continue
        print(f" ❗ Error: {step.error}")

    # The answer preview is only printed when an answer exists.
    if not trace.final_answer:
        return
    print(f" 回答: {trace.final_answer[:60]}...")
# ================================================================
if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly ends
    # with a non-zero exit code when any test fails.
    sys.exit(pytest.main([__file__, "-v", "--tb=short", "-s"]))