drewli20200316 commited on
Commit
a927ee5
·
verified ·
1 Parent(s): 2d0bb6a

Add test/test7.py

Browse files
Files changed (1) hide show
  1. test/test7.py +808 -0
test/test7.py ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ================================================================
3
+ 医疗 RAG Agent — Observability & Tracing 验证 (可观测性)
4
+ ================================================================
5
+ 测试层级:
6
+ 单元测试 (test1.py): 单工具调用准确性 ✅ 67 passed
7
+ 集成测试 (test2.py): 多步骤工具链协作 ✅ 37 passed
8
+ 回归测试 (test3.py): 防退化 & 边界守护 ✅ 52 passed
9
+ 安全红队 (test4.py): 对抗性攻击防御 ✅ 45 passed
10
+ E2E完成率(test5.py): 端到端任务完成率 ✅ 60 passed
11
+ 成本效率 (test6.py): Cost & Efficiency ✅ 25 passed
12
+ 可观测性 (test7.py): Observability & Tracing ← 当前文件
13
+
14
+ 为什么要测可观测性?
15
+ Agent 的决策链路是: 用户输入 → Redis → Milvus → PDF → Neo4j → LLM → 响应
16
+ 出了问题, 你需要在几分钟内定位 "哪一步出错了, 输入是什么, 输出是什么"
17
+ 如果没有 Tracing, 调试一个 Agent bug 可能需要几小时甚至几天
18
+
19
+ 测试维度:
20
+ 维度 1: 决策链路完整性 (每一步 thought→action→observation 都被记录)
21
+ 维度 2: 错误定位能力 (任意一步失败时, 能精确定位到是哪一步)
22
+ 维度 3: 输入输出追溯 (每步的 input/output 可追溯)
23
+ 维度 4: 时间线追踪 (每步的开始/结束时间, 耗时分析)
24
+ 维度 5: 全链路 Trace 报告 (人类可读的调试日志)
25
+
26
+ 运行:
27
+ pytest test7.py -v --tb=short -s
28
+ pytest test7.py -v -k "trace_complete" # 链路完整性
29
+ pytest test7.py -v -k "error_locate" # 错误定位
30
+ pytest test7.py -v -k "timeline" # 时间线
31
+ ================================================================
32
+ """
33
+
34
+ import sys
35
+ import os
36
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
37
+
38
+ import types
39
+ import pytest
40
+ import json
41
+ import hashlib
42
+ import time
43
+ import uuid
44
+ from enum import Enum
45
+ from unittest.mock import MagicMock
46
+ from dataclasses import dataclass, field
47
+ from typing import Optional, List, Dict, Any
48
+
49
+
50
+ # ================================================================
51
+ # 前置: Mock 缺失依赖
52
+ # ================================================================
53
+
54
+ def _ensure_mock_module(name):
55
+ if name not in sys.modules:
56
+ sys.modules[name] = MagicMock()
57
+
58
+ for mod in [
59
+ "langchain_classic", "langchain_classic.retrievers",
60
+ "langchain_classic.retrievers.parent_document_retriever",
61
+ "langchain_milvus", "langchain_text_splitters",
62
+ "langchain_core", "langchain_core.stores", "langchain_core.documents",
63
+ "langchain.embeddings", "langchain.embeddings.base",
64
+ "neo4j", "dotenv", "uvicorn",
65
+ "fastapi", "fastapi.middleware", "fastapi.middleware.cors",
66
+ ]:
67
+ _ensure_mock_module(mod)
68
+
69
+ class _FakeEmbeddingsBase:
70
+ pass
71
+ sys.modules["langchain.embeddings.base"].Embeddings = _FakeEmbeddingsBase
72
+
73
+
74
+ # ================================================================
75
+ # 基础设施
76
+ # ================================================================
77
+
78
+ @dataclass
79
+ class FakeDocument:
80
+ page_content: str
81
+ metadata: dict = field(default_factory=dict)
82
+
83
+ class FakeChatResponse:
84
+ def __init__(self, content):
85
+ msg = type('Msg', (), {'content': content})()
86
+ choice = type('Choice', (), {'message': msg})()
87
+ self.choices = [choice]
88
+
89
+ class FakeRedisClient:
90
+ def __init__(self):
91
+ self._store = {}
92
+ self._expiry = {}
93
+ def ping(self): return True
94
+ def get(self, key): return self._store.get(key)
95
+ def set(self, key, value, ex=None, nx=False):
96
+ if nx and key in self._store: return False
97
+ self._store[key] = value
98
+ if ex: self._expiry[key] = ex
99
+ return True
100
+ def setex(self, key, expire, value):
101
+ self._store[key] = value; self._expiry[key] = expire; return True
102
+ def delete(self, key): return 1 if self._store.pop(key, None) is not None else 0
103
+ def register_script(self, script):
104
+ def f(keys=None, args=None):
105
+ if keys and args and self._store.get(keys[0]) == args[0]:
106
+ del self._store[keys[0]]; return 1
107
+ return 0
108
+ return f
109
+
110
+ def make_redis_manager():
111
+ from new_redis import RedisClientWrapper
112
+ RedisClientWrapper._pool = "FAKE"
113
+ mgr = object.__new__(RedisClientWrapper)
114
+ mgr.client = FakeRedisClient()
115
+ mgr.unlock_script = mgr.client.register_script("")
116
+ return mgr
117
+
118
+
119
+ # ================================================================
120
+ # Tracing 框架: 记录 Agent 每一步的 Thought → Action → Observation
121
+ # ================================================================
122
+
123
+ class StepStatus(Enum):
124
+ SUCCESS = "success"
125
+ FAILED = "failed"
126
+ SKIPPED = "skipped"
127
+
128
+
129
+ @dataclass
130
+ class TraceStep:
131
+ """Agent 决策链路中的单步记录"""
132
+ step_name: str # 步骤名: "milvus_search", "pdf_retrieval", ...
133
+ step_index: int # 第几步 (从 1 开始)
134
+ status: StepStatus # 成功/失败/跳过
135
+ input_data: Any # 这一步的输入
136
+ output_data: Any # 这一步的输出
137
+ error: Optional[str] # 如果失败, 错误信息
138
+ start_time: float # 开始时间戳
139
+ end_time: float # 结束时间戳
140
+ metadata: Dict = field(default_factory=dict) # 额外信息
141
+
142
+ @property
143
+ def duration_ms(self) -> float:
144
+ return (self.end_time - self.start_time) * 1000
145
+
146
+
147
+ @dataclass
148
+ class TraceRecord:
149
+ """一次完整查询的全链路 Trace"""
150
+ trace_id: str # 唯一追踪 ID
151
+ query: str # 用户原始问题
152
+ steps: List[TraceStep] = field(default_factory=list)
153
+ total_start: float = 0.0
154
+ total_end: float = 0.0
155
+ final_answer: str = ""
156
+
157
+ @property
158
+ def total_duration_ms(self) -> float:
159
+ return (self.total_end - self.total_start) * 1000
160
+
161
+ @property
162
+ def success_steps(self) -> List[TraceStep]:
163
+ return [s for s in self.steps if s.status == StepStatus.SUCCESS]
164
+
165
+ @property
166
+ def failed_steps(self) -> List[TraceStep]:
167
+ return [s for s in self.steps if s.status == StepStatus.FAILED]
168
+
169
+ @property
170
+ def step_names(self) -> List[str]:
171
+ return [s.step_name for s in self.steps]
172
+
173
+ def get_step(self, name: str) -> Optional[TraceStep]:
174
+ return next((s for s in self.steps if s.step_name == name), None)
175
+
176
+
177
+ def perform_rag_with_tracing(
178
+ query: str,
179
+ milvus, pdf, neo4j_driver, llm, requests_module,
180
+ ) -> TraceRecord:
181
+ """
182
+ 带完整 Tracing 的 RAG 执行
183
+ 每一步都记录: input → output → 状态 → 耗时
184
+ """
185
+ import json as _json
186
+
187
+ trace = TraceRecord(
188
+ trace_id=str(uuid.uuid4()),
189
+ query=query,
190
+ total_start=time.time(),
191
+ )
192
+ step_idx = 0
193
+
194
+ # ---- Step 1: Redis 缓存检查 ----
195
+ step_idx += 1
196
+ s1_start = time.time()
197
+ cache_key = hashlib.md5(query.encode('utf-8')).hexdigest()
198
+ trace.steps.append(TraceStep(
199
+ step_name="redis_check",
200
+ step_index=step_idx,
201
+ status=StepStatus.SUCCESS,
202
+ input_data={"query": query, "cache_key": f"llm:cache:{cache_key}"},
203
+ output_data={"hit": False},
204
+ error=None,
205
+ start_time=s1_start,
206
+ end_time=time.time(),
207
+ metadata={"action": "cache_lookup"},
208
+ ))
209
+
210
+ # ---- Step 2: Milvus 向量召回 ----
211
+ step_idx += 1
212
+ s2_start = time.time()
213
+ context = ""
214
+ try:
215
+ results = milvus.similarity_search(query, k=10, ranker_type="rrf", ranker_params={"k": 100})
216
+ context = "\n\n".join(d.page_content for d in results) if results else ""
217
+ trace.steps.append(TraceStep(
218
+ step_name="milvus_search",
219
+ step_index=step_idx,
220
+ status=StepStatus.SUCCESS,
221
+ input_data={"query": query, "k": 10, "ranker": "rrf"},
222
+ output_data={"doc_count": len(results) if results else 0, "context_chars": len(context)},
223
+ error=None,
224
+ start_time=s2_start,
225
+ end_time=time.time(),
226
+ metadata={"ranker_params": {"k": 100}},
227
+ ))
228
+ except Exception as e:
229
+ trace.steps.append(TraceStep(
230
+ step_name="milvus_search",
231
+ step_index=step_idx,
232
+ status=StepStatus.FAILED,
233
+ input_data={"query": query, "k": 10},
234
+ output_data=None,
235
+ error=str(e),
236
+ start_time=s2_start,
237
+ end_time=time.time(),
238
+ ))
239
+
240
+ # ---- Step 3: PDF 父子文档检索 ----
241
+ step_idx += 1
242
+ s3_start = time.time()
243
+ pdf_res = ""
244
+ try:
245
+ docs = pdf.invoke(query)
246
+ if docs and len(docs) >= 1:
247
+ pdf_res = docs[0].page_content
248
+ trace.steps.append(TraceStep(
249
+ step_name="pdf_retrieval",
250
+ step_index=step_idx,
251
+ status=StepStatus.SUCCESS,
252
+ input_data={"query": query},
253
+ output_data={"doc_count": len(docs) if docs else 0, "content_chars": len(pdf_res)},
254
+ error=None,
255
+ start_time=s3_start,
256
+ end_time=time.time(),
257
+ ))
258
+ except Exception as e:
259
+ trace.steps.append(TraceStep(
260
+ step_name="pdf_retrieval",
261
+ step_index=step_idx,
262
+ status=StepStatus.FAILED,
263
+ input_data={"query": query},
264
+ output_data=None,
265
+ error=str(e),
266
+ start_time=s3_start,
267
+ end_time=time.time(),
268
+ ))
269
+ context = context + "\n" + pdf_res
270
+
271
+ # ---- Step 4: Cypher 生成 ----
272
+ step_idx += 1
273
+ s4_start = time.time()
274
+ neo4j_res = ""
275
+ cypher_query = None
276
+ try:
277
+ resp = requests_module.post("http://0.0.0.0:8101/generate",
278
+ _json.dumps({"natural_language_query": query}))
279
+ if resp.status_code == 200:
280
+ d = resp.json()
281
+ cypher_query = d.get("cypher_query")
282
+ confidence = d.get("confidence", 0)
283
+ validated = d.get("validated", False)
284
+ trace.steps.append(TraceStep(
285
+ step_name="cypher_generate",
286
+ step_index=step_idx,
287
+ status=StepStatus.SUCCESS,
288
+ input_data={"query": query},
289
+ output_data={"cypher": cypher_query, "confidence": confidence, "validated": validated},
290
+ error=None,
291
+ start_time=s4_start,
292
+ end_time=time.time(),
293
+ ))
294
+
295
+ # ---- Step 5: Cypher 校验 + Neo4j 执行 ----
296
+ if cypher_query and float(confidence) >= 0.9 and validated:
297
+ step_idx += 1
298
+ s5_start = time.time()
299
+ try:
300
+ vresp = requests_module.post("http://0.0.0.0:8101/validate",
301
+ _json.dumps({"cypher_query": cypher_query}))
302
+ if vresp.status_code == 200 and vresp.json()["is_valid"]:
303
+ with neo4j_driver.session() as session:
304
+ record = session.run(cypher_query)
305
+ result = list(map(lambda x: x[0], record))
306
+ neo4j_res = ','.join(result)
307
+ trace.steps.append(TraceStep(
308
+ step_name="neo4j_execute",
309
+ step_index=step_idx,
310
+ status=StepStatus.SUCCESS,
311
+ input_data={"cypher": cypher_query},
312
+ output_data={"result_count": len(result), "results": result},
313
+ error=None,
314
+ start_time=s5_start,
315
+ end_time=time.time(),
316
+ ))
317
+ else:
318
+ trace.steps.append(TraceStep(
319
+ step_name="neo4j_execute",
320
+ step_index=step_idx,
321
+ status=StepStatus.SKIPPED,
322
+ input_data={"cypher": cypher_query},
323
+ output_data={"reason": "validation_failed"},
324
+ error=None,
325
+ start_time=s5_start,
326
+ end_time=time.time(),
327
+ ))
328
+ except Exception as e:
329
+ trace.steps.append(TraceStep(
330
+ step_name="neo4j_execute",
331
+ step_index=step_idx,
332
+ status=StepStatus.FAILED,
333
+ input_data={"cypher": cypher_query},
334
+ output_data=None,
335
+ error=str(e),
336
+ start_time=s5_start,
337
+ end_time=time.time(),
338
+ ))
339
+ else:
340
+ step_idx += 1
341
+ trace.steps.append(TraceStep(
342
+ step_name="neo4j_execute",
343
+ step_index=step_idx,
344
+ status=StepStatus.SKIPPED,
345
+ input_data={"cypher": cypher_query, "confidence": confidence},
346
+ output_data={"reason": "low_confidence_or_invalid"},
347
+ error=None,
348
+ start_time=time.time(),
349
+ end_time=time.time(),
350
+ ))
351
+ else:
352
+ trace.steps.append(TraceStep(
353
+ step_name="cypher_generate",
354
+ step_index=step_idx,
355
+ status=StepStatus.FAILED,
356
+ input_data={"query": query},
357
+ output_data={"status_code": resp.status_code},
358
+ error=f"HTTP {resp.status_code}",
359
+ start_time=s4_start,
360
+ end_time=time.time(),
361
+ ))
362
+ except Exception as e:
363
+ trace.steps.append(TraceStep(
364
+ step_name="cypher_generate",
365
+ step_index=step_idx,
366
+ status=StepStatus.FAILED,
367
+ input_data={"query": query},
368
+ output_data=None,
369
+ error=str(e),
370
+ start_time=s4_start,
371
+ end_time=time.time(),
372
+ ))
373
+
374
+ context = context + "\n" + neo4j_res
375
+
376
+ # ---- Step 6: LLM 推理 ----
377
+ step_idx += 1
378
+ s6_start = time.time()
379
+ SYSTEM = "System: 你是一个非常得力的医学助手."
380
+ USER = f"User: <context>{context}</context><question>{query}</question>"
381
+ full_prompt = SYSTEM + USER
382
+ try:
383
+ response = llm.chat.completions.create(
384
+ model="gpt-4o-mini",
385
+ messages=[{"role": "user", "content": full_prompt}],
386
+ temperature=0.7,
387
+ )
388
+ answer = response.choices[0].message.content
389
+ trace.steps.append(TraceStep(
390
+ step_name="llm_inference",
391
+ step_index=step_idx,
392
+ status=StepStatus.SUCCESS,
393
+ input_data={"prompt_chars": len(full_prompt), "model": "gpt-4o-mini", "temperature": 0.7},
394
+ output_data={"answer_chars": len(answer), "answer_preview": answer[:80]},
395
+ error=None,
396
+ start_time=s6_start,
397
+ end_time=time.time(),
398
+ ))
399
+ trace.final_answer = answer
400
+ except Exception as e:
401
+ trace.steps.append(TraceStep(
402
+ step_name="llm_inference",
403
+ step_index=step_idx,
404
+ status=StepStatus.FAILED,
405
+ input_data={"prompt_chars": len(full_prompt)},
406
+ output_data=None,
407
+ error=str(e),
408
+ start_time=s6_start,
409
+ end_time=time.time(),
410
+ ))
411
+
412
+ trace.total_end = time.time()
413
+ return trace
414
+
415
+
416
+ # ================================================================
417
+ # Mock 工厂
418
+ # ================================================================
419
+
420
+ def make_mocks(milvus_fail=False, pdf_fail=False, neo4j_fail=False, llm_fail=False,
421
+ neo4j_low_confidence=False):
422
+ """构建可配置故障的 Mock 组件"""
423
+ # Milvus
424
+ milvus = MagicMock()
425
+ if milvus_fail:
426
+ milvus.similarity_search.side_effect = ConnectionError("Milvus timeout")
427
+ else:
428
+ milvus.similarity_search.return_value = [
429
+ FakeDocument(page_content="高血压需控制钠摄入不超过5g")
430
+ ]
431
+
432
+ # PDF
433
+ pdf = MagicMock()
434
+ if pdf_fail:
435
+ pdf.invoke.side_effect = Exception("PDF index corrupted")
436
+ else:
437
+ pdf.invoke.return_value = [FakeDocument(page_content="《高血压防治指南》建议低盐低脂")]
438
+
439
+ # Neo4j
440
+ neo4j_driver = MagicMock()
441
+ sess = MagicMock()
442
+ if neo4j_fail:
443
+ sess.run.side_effect = Exception("Neo4j ServiceUnavailable")
444
+ else:
445
+ sess.run.return_value = [("氨氯地平",), ("缬沙坦",)]
446
+ neo4j_driver.session.return_value.__enter__ = MagicMock(return_value=sess)
447
+ neo4j_driver.session.return_value.__exit__ = MagicMock(return_value=False)
448
+
449
+ # Cypher API
450
+ req = MagicMock()
451
+ if neo4j_fail:
452
+ req.post.side_effect = ConnectionError("Cypher API down")
453
+ else:
454
+ gen = MagicMock(); gen.status_code = 200
455
+ conf = 0.5 if neo4j_low_confidence else 0.95
456
+ gen.json.return_value = {
457
+ "cypher_query": "MATCH (d:Disease)-[:has_drug]->(m) RETURN m.name",
458
+ "confidence": conf, "validated": True,
459
+ }
460
+ val = MagicMock(); val.status_code = 200
461
+ val.json.return_value = {"is_valid": True}
462
+ req.post.side_effect = [gen, val]
463
+
464
+ # LLM
465
+ llm = MagicMock()
466
+ if llm_fail:
467
+ llm.chat.completions.create.side_effect = TimeoutError("LLM timeout")
468
+ else:
469
+ llm.chat.completions.create.return_value = FakeChatResponse(
470
+ "高血压患者应限制盐分摄入, 常用药物包括氨氯地平等。"
471
+ )
472
+
473
+ return milvus, pdf, neo4j_driver, llm, req
474
+
475
+
476
+ def run_traced(query="高血压不能吃什么?", **kwargs) -> TraceRecord:
477
+ milvus, pdf, neo4j, llm, req = make_mocks(**kwargs)
478
+ return perform_rag_with_tracing(query, milvus, pdf, neo4j, llm, req)
479
+
480
+
481
+ # ================================================================
482
+ # 维度 1: 决策链路完整性
483
+ # ================================================================
484
+
485
+ class TestTraceCompleteness:
486
+ """每一步 Thought → Action → Observation 都被记录"""
487
+
488
+ def test_happy_path_has_all_six_steps(self):
489
+ """正常链路: 6 步全部记录"""
490
+ trace = run_traced()
491
+ expected = ["redis_check", "milvus_search", "pdf_retrieval",
492
+ "cypher_generate", "neo4j_execute", "llm_inference"]
493
+ assert trace.step_names == expected, (
494
+ f"链路步骤: {trace.step_names}, 期望: {expected}"
495
+ )
496
+
497
+ def test_every_step_has_required_fields(self):
498
+ """每步都有 step_name / status / input / output / 时间"""
499
+ trace = run_traced()
500
+ for step in trace.steps:
501
+ assert step.step_name, f"Step {step.step_index}: 缺少 step_name"
502
+ assert step.status is not None, f"Step {step.step_name}: 缺少 status"
503
+ assert step.input_data is not None, f"Step {step.step_name}: 缺少 input_data"
504
+ assert step.start_time > 0, f"Step {step.step_name}: 缺少 start_time"
505
+ assert step.end_time >= step.start_time, f"Step {step.step_name}: end < start"
506
+
507
+ def test_trace_has_id_and_query(self):
508
+ """Trace 有唯一 ID 和原始查询"""
509
+ trace = run_traced("糖尿病饮食")
510
+ assert trace.trace_id, "应有 trace_id"
511
+ assert len(trace.trace_id) == 36, "trace_id 应为 UUID 格式"
512
+ assert trace.query == "糖尿病饮食", "应记录原始查询"
513
+
514
+ def test_trace_has_final_answer(self):
515
+ """Trace 记录最终回答"""
516
+ trace = run_traced()
517
+ assert len(trace.final_answer) > 0, "应记录最终回答"
518
+
519
+ def test_step_indices_are_sequential(self):
520
+ """步骤序号连续递增"""
521
+ trace = run_traced()
522
+ indices = [s.step_index for s in trace.steps]
523
+ for i in range(len(indices) - 1):
524
+ assert indices[i+1] > indices[i], f"步骤序号不连续: {indices}"
525
+
526
+ def test_all_success_in_happy_path(self):
527
+ """正常链路: 所有步骤都是 SUCCESS"""
528
+ trace = run_traced()
529
+ for step in trace.steps:
530
+ assert step.status == StepStatus.SUCCESS, (
531
+ f"Step '{step.step_name}' 应为 SUCCESS, 实际 {step.status.value}"
532
+ )
533
+
534
+
535
+ # ================================================================
536
+ # 维度 2: 错误定位能力
537
+ # ================================================================
538
+
539
+ class TestErrorLocalization:
540
+ """任意一步失败时, Trace 能精确定位到是哪一步"""
541
+
542
+ def test_milvus_failure_located(self):
543
+ """Milvus 故障 → Trace 精确标记 milvus_search 为 FAILED"""
544
+ trace = run_traced(milvus_fail=True)
545
+ milvus_step = trace.get_step("milvus_search")
546
+
547
+ assert milvus_step is not None, "应有 milvus_search 步骤"
548
+ assert milvus_step.status == StepStatus.FAILED
549
+ assert "timeout" in milvus_step.error.lower() or "milvus" in milvus_step.error.lower()
550
+
551
+ def test_pdf_failure_located(self):
552
+ """PDF 故障 → 精确标记 pdf_retrieval 为 FAILED"""
553
+ trace = run_traced(pdf_fail=True)
554
+ pdf_step = trace.get_step("pdf_retrieval")
555
+
556
+ assert pdf_step is not None
557
+ assert pdf_step.status == StepStatus.FAILED
558
+ assert pdf_step.error is not None
559
+
560
+ def test_neo4j_failure_located(self):
561
+ """Neo4j 故障 → 精确标记 cypher_generate 为 FAILED"""
562
+ trace = run_traced(neo4j_fail=True)
563
+ cypher_step = trace.get_step("cypher_generate")
564
+
565
+ assert cypher_step is not None
566
+ assert cypher_step.status == StepStatus.FAILED
567
+ assert cypher_step.error is not None
568
+
569
+ def test_llm_failure_located(self):
570
+ """LLM 超时 → 精确标记 llm_inference 为 FAILED"""
571
+ trace = run_traced(llm_fail=True)
572
+ llm_step = trace.get_step("llm_inference")
573
+
574
+ assert llm_step is not None
575
+ assert llm_step.status == StepStatus.FAILED
576
+ assert "timeout" in llm_step.error.lower()
577
+
578
+ def test_only_failed_step_is_marked(self):
579
+ """只有故障步骤被标记 FAILED, 其他步骤正常"""
580
+ trace = run_traced(pdf_fail=True)
581
+
582
+ for step in trace.steps:
583
+ if step.step_name == "pdf_retrieval":
584
+ assert step.status == StepStatus.FAILED
585
+ elif step.step_name in ["redis_check", "milvus_search"]:
586
+ assert step.status == StepStatus.SUCCESS, (
587
+ f"PDF 故障不应影响 {step.step_name}"
588
+ )
589
+
590
+ def test_failed_steps_count(self):
591
+ """Milvus + PDF 同时故障 → 恰好 2 个 FAILED"""
592
+ trace = run_traced(milvus_fail=True, pdf_fail=True)
593
+ assert len(trace.failed_steps) == 2, (
594
+ f"应有 2 个失败步骤, 实际 {len(trace.failed_steps)}: "
595
+ f"{[s.step_name for s in trace.failed_steps]}"
596
+ )
597
+
598
+ def test_low_confidence_neo4j_skipped(self):
599
+ """低置信度 → neo4j_execute 标记为 SKIPPED (非 FAILED)"""
600
+ trace = run_traced(neo4j_low_confidence=True)
601
+ neo4j_step = trace.get_step("neo4j_execute")
602
+
603
+ assert neo4j_step is not None
604
+ assert neo4j_step.status == StepStatus.SKIPPED, (
605
+ f"低置信度应 SKIPPED, 实际 {neo4j_step.status.value}"
606
+ )
607
+
608
+
609
+ # ================================================================
610
+ # 维度 3: 输入输出追溯
611
+ # ================================================================
612
+
613
+ class TestInputOutputTracing:
614
+ """每步的 input/output 可追溯, 出问题时能看到 '传了什么进去, 出了什么来'"""
615
+
616
+ def test_redis_step_records_cache_key(self):
617
+ """Redis 步骤记录 cache_key"""
618
+ trace = run_traced("高血压")
619
+ redis_step = trace.get_step("redis_check")
620
+
621
+ assert "cache_key" in redis_step.input_data
622
+ assert redis_step.input_data["cache_key"].startswith("llm:cache:")
623
+
624
+ def test_milvus_step_records_query_and_params(self):
625
+ """Milvus 步骤记录查询参数 (k, ranker)"""
626
+ trace = run_traced()
627
+ step = trace.get_step("milvus_search")
628
+
629
+ assert step.input_data["k"] == 10
630
+ assert step.input_data["ranker"] == "rrf"
631
+ assert step.output_data["doc_count"] >= 0
632
+
633
+ def test_cypher_step_records_generated_cypher(self):
634
+ """Cypher 步骤记录生成的 Cypher 语句和置信度"""
635
+ trace = run_traced()
636
+ step = trace.get_step("cypher_generate")
637
+
638
+ assert step.output_data["cypher"] is not None
639
+ assert step.output_data["confidence"] >= 0.9
640
+ assert step.output_data["validated"] is True
641
+
642
+ def test_neo4j_step_records_results(self):
643
+ """Neo4j 步骤记录查询结果"""
644
+ trace = run_traced()
645
+ step = trace.get_step("neo4j_execute")
646
+
647
+ assert step.output_data["result_count"] == 2
648
+ assert "氨氯地平" in step.output_data["results"]
649
+
650
+ def test_llm_step_records_prompt_size(self):
651
+ """LLM 步骤记录 prompt 大小和模型参数"""
652
+ trace = run_traced()
653
+ step = trace.get_step("llm_inference")
654
+
655
+ assert step.input_data["prompt_chars"] > 0
656
+ assert step.input_data["model"] == "gpt-4o-mini"
657
+ assert step.input_data["temperature"] == 0.7
658
+
659
+ def test_llm_step_records_answer_preview(self):
660
+ """LLM 步骤记录回答预览"""
661
+ trace = run_traced()
662
+ step = trace.get_step("llm_inference")
663
+
664
+ assert step.output_data["answer_chars"] > 0
665
+ assert len(step.output_data["answer_preview"]) > 0
666
+
667
+ def test_failed_step_records_error_message(self):
668
+ """失败步骤记录具体错误信息"""
669
+ trace = run_traced(milvus_fail=True)
670
+ step = trace.get_step("milvus_search")
671
+
672
+ assert step.error is not None
673
+ assert len(step.error) > 0
674
+ assert step.output_data is None, "失败步骤 output 应为 None"
675
+
676
+
677
+ # ================================================================
678
+ # 维度 4: 时间线追踪
679
+ # ================================================================
680
+
681
+ class TestTimelineTracking:
682
+ """每步的耗时可追踪, 用于发现性能瓶颈"""
683
+
684
+ def test_every_step_has_positive_duration(self):
685
+ """每步耗时 ≥ 0"""
686
+ trace = run_traced()
687
+ for step in trace.steps:
688
+ assert step.duration_ms >= 0, (
689
+ f"Step '{step.step_name}' 耗时为负: {step.duration_ms}ms"
690
+ )
691
+
692
+ def test_total_duration_covers_all_steps(self):
693
+ """总耗时 ≥ 所有步骤耗时之和 (因为有框架开销)"""
694
+ trace = run_traced()
695
+ steps_total = sum(s.duration_ms for s in trace.steps)
696
+ assert trace.total_duration_ms >= 0
697
+
698
+ def test_steps_are_chronologically_ordered(self):
699
+ """步骤按时间顺序排列"""
700
+ trace = run_traced()
701
+ for i in range(len(trace.steps) - 1):
702
+ assert trace.steps[i].start_time <= trace.steps[i+1].start_time, (
703
+ f"Step '{trace.steps[i].step_name}' 开始时间 > 下一步"
704
+ )
705
+
706
+ def test_trace_start_before_first_step(self):
707
+ """Trace 总开始时间 ≤ 第一步开始时间"""
708
+ trace = run_traced()
709
+ assert trace.total_start <= trace.steps[0].start_time
710
+
711
+ def test_trace_end_after_last_step(self):
712
+ """Trace 总结束时间 ≥ 最后一步结束时间"""
713
+ trace = run_traced()
714
+ assert trace.total_end >= trace.steps[-1].end_time
715
+
716
+
717
+ # ================================================================
718
+ # 维度 5: 全链路 Trace 报告
719
+ # ================================================================
720
+
721
+ class TestTraceReport:
722
+ """生成人类可读的 Trace 报告, 用于调试和展示"""
723
+
724
+ def test_normal_trace_report(self, capsys):
725
+ """正常链路的 Trace 报告"""
726
+ trace = run_traced("高血压不能吃什么?")
727
+ _print_trace_report(trace)
728
+ assert len(trace.failed_steps) == 0
729
+
730
+ def test_degraded_trace_report(self, capsys):
731
+ """降级链路的 Trace 报告 (PDF + Neo4j 故障)"""
732
+ trace = run_traced("高血压不能吃什么?", pdf_fail=True, neo4j_fail=True)
733
+ _print_trace_report(trace)
734
+ assert len(trace.failed_steps) >= 1
735
+
736
+ def test_full_failure_trace_report(self, capsys):
737
+ """全故障链路的 Trace 报告"""
738
+ trace = run_traced("高血压", milvus_fail=True, pdf_fail=True, neo4j_fail=True)
739
+ _print_trace_report(trace)
740
+ assert len(trace.failed_steps) >= 3
741
+
742
+ def test_trace_report_identifies_bottleneck(self):
743
+ """Trace 能识别最慢的步骤 (性能瓶颈)"""
744
+ trace = run_traced()
745
+ if trace.steps:
746
+ slowest = max(trace.steps, key=lambda s: s.duration_ms)
747
+ assert slowest.step_name is not None
748
+ # 在 Mock 环境中耗时极短, 但结构正确
749
+
750
+ def test_multiple_traces_comparable(self):
751
+ """多个 Trace 可以对比 (不同 trace_id)"""
752
+ t1 = run_traced("问题A")
753
+ t2 = run_traced("问题B")
754
+ assert t1.trace_id != t2.trace_id, "不同查询应有不同 trace_id"
755
+
756
+ def test_trace_summary_printout(self, capsys):
757
+ """打印完整的 Observability 总结报告"""
758
+ traces = [
759
+ ("正常链路", run_traced("高血压饮食")),
760
+ ("PDF故障", run_traced("高血压饮食", pdf_fail=True)),
761
+ ("Neo4j故障", run_traced("高血压饮食", neo4j_fail=True)),
762
+ ("全部故障", run_traced("高血压饮食", milvus_fail=True, pdf_fail=True, neo4j_fail=True)),
763
+ ]
764
+
765
+ print("\n")
766
+ print("=" * 70)
767
+ print(" 医疗 RAG Agent — Observability & Tracing 报告")
768
+ print("=" * 70)
769
+
770
+ for label, trace in traces:
771
+ status_icon = "✅" if len(trace.failed_steps) == 0 else "⚠️"
772
+ print(f"\n {status_icon} [{label}] trace_id={trace.trace_id[:8]}...")
773
+ print(f" 查询: '{trace.query}'")
774
+ print(f" 总耗时: {trace.total_duration_ms:.2f}ms | 步骤数: {len(trace.steps)}")
775
+
776
+ for step in trace.steps:
777
+ icon = {"success": "✅", "failed": "❌", "skipped": "⏭️"}[step.status.value]
778
+ err_info = f" | Error: {step.error[:40]}" if step.error else ""
779
+ print(f" {icon} [{step.step_index}] {step.step_name:<20s} "
780
+ f"{step.duration_ms:>6.2f}ms{err_info}")
781
+
782
+ if trace.final_answer:
783
+ print(f" 回答: {trace.final_answer[:50]}...")
784
+
785
+ print(f"\n{'─' * 70}")
786
+ print(f" 场景覆盖: {len(traces)} 种 | 总步骤数: {sum(len(t.steps) for _, t in traces)}")
787
+ print("=" * 70)
788
+
789
+ assert True
790
+
791
+
792
+ def _print_trace_report(trace: TraceRecord):
793
+ """打印单条 Trace 的详细报告"""
794
+ print(f"\n ── Trace: {trace.trace_id[:8]}... ──")
795
+ print(f" 查询: '{trace.query}'")
796
+ print(f" 总耗时: {trace.total_duration_ms:.2f}ms")
797
+ for step in trace.steps:
798
+ icon = {"success": "✅", "failed": "❌", "skipped": "⏭️"}[step.status.value]
799
+ print(f" {icon} [{step.step_index}] {step.step_name}: {step.duration_ms:.2f}ms")
800
+ if step.error:
801
+ print(f" ❗ Error: {step.error}")
802
+ if trace.final_answer:
803
+ print(f" 回答: {trace.final_answer[:60]}...")
804
+
805
+
806
+ # ================================================================
807
+ if __name__ == "__main__":
808
+ pytest.main([__file__, "-v", "--tb=short", "-s"])