drewli20200316 commited on
Commit
2d0bb6a
·
verified ·
1 Parent(s): c9cdcf4

Add test/test6.py

Browse files
Files changed (1) hide show
  1. test/test6.py +681 -0
test/test6.py ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ================================================================
3
+ 医疗 RAG Agent — Cost & Efficiency 评测 (成本与效率)
4
+ ================================================================
5
+ 测试层级:
6
+ 单元测试 (test1.py): 单工具调用准确性 ✅ 67 passed
7
+ 集成测试 (test2.py): 多步骤工具链协作 ✅ 37 passed
8
+ 回归测试 (test3.py): 防退化 & 边界守护 ✅ 52 passed
9
+ 安全红队 (test4.py): 对抗性攻击防御 ✅ 45 passed
10
+ E2E完成率(test5.py): 端到端任务完成率 ✅ 60 passed
11
+ 成本效率 (test6.py): Cost & Efficiency ← 当前文件
12
+
13
+ 为什么要测成本?
14
+ Agent 每回答一个问题要: 1次 Embedding + 1次 Milvus + 1次 PDF +
15
+ 2次 Cypher API + 1次 Neo4j + 1次 LLM = 至少 7 次外部调用
16
+ 在生产环境中, 这些调用直接关系 token 消耗和 API 费用
17
+
18
+ 测试维度:
19
+ 维度 1: 外部调用次数审计 (每次查询调了几次 API?)
20
+ 维度 2: Token 消耗估算 (Prompt + Response 共多少 token?)
21
+ 维度 3: 缓存节省量化 (Redis 命中省了多少调用?)
22
+ 维度 4: 降级场景的成本影响 (组件故障时成本变化)
23
+ 维度 5: 成本报告 (人类可读的费用估算)
24
+
25
+ 运行:
26
+ pytest test6.py -v --tb=short -s
27
+ pytest test6.py -v -k "call_count" # 调用次数
28
+ pytest test6.py -v -k "token" # Token 消耗
29
+ pytest test6.py -v -k "cache_saving" # 缓存节省
30
+ ================================================================
31
+ """
32
+
33
+ import sys
34
+ import os
35
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
36
+
37
+ import types
38
+ import pytest
39
+ import json
40
+ import hashlib
41
+ import time
42
+ from unittest.mock import MagicMock, patch, call
43
+ from dataclasses import dataclass, field
44
+ from typing import Optional, List, Dict
45
+
46
+
47
+ # ================================================================
48
+ # 前置: Mock 缺失依赖
49
+ # ================================================================
50
+
51
+ def _ensure_mock_module(name):
52
+ if name not in sys.modules:
53
+ sys.modules[name] = MagicMock()
54
+
55
+ for mod in [
56
+ "langchain_classic", "langchain_classic.retrievers",
57
+ "langchain_classic.retrievers.parent_document_retriever",
58
+ "langchain_milvus", "langchain_text_splitters",
59
+ "langchain_core", "langchain_core.stores", "langchain_core.documents",
60
+ "langchain.embeddings", "langchain.embeddings.base",
61
+ "neo4j", "dotenv", "uvicorn",
62
+ "fastapi", "fastapi.middleware", "fastapi.middleware.cors",
63
+ ]:
64
+ _ensure_mock_module(mod)
65
+
66
+ class _FakeEmbeddingsBase:
67
+ pass
68
+ sys.modules["langchain.embeddings.base"].Embeddings = _FakeEmbeddingsBase
69
+
70
+
71
+ # ================================================================
72
+ # 基础设施
73
+ # ================================================================
74
+
75
+ @dataclass
76
+ class FakeDocument:
77
+ page_content: str
78
+ metadata: dict = field(default_factory=dict)
79
+
80
+ class FakeChatResponse:
81
+ def __init__(self, content):
82
+ msg = type('Msg', (), {'content': content})()
83
+ choice = type('Choice', (), {'message': msg})()
84
+ self.choices = [choice]
85
+
86
+ class FakeRedisClient:
87
+ def __init__(self):
88
+ self._store = {}
89
+ self._expiry = {}
90
+ def ping(self): return True
91
+ def get(self, key): return self._store.get(key)
92
+ def set(self, key, value, ex=None, nx=False):
93
+ if nx and key in self._store: return False
94
+ self._store[key] = value
95
+ if ex: self._expiry[key] = ex
96
+ return True
97
+ def setex(self, key, expire, value):
98
+ self._store[key] = value; self._expiry[key] = expire; return True
99
+ def delete(self, key): return 1 if self._store.pop(key, None) is not None else 0
100
+ def register_script(self, script):
101
+ def f(keys=None, args=None):
102
+ if keys and args and self._store.get(keys[0]) == args[0]:
103
+ del self._store[keys[0]]; return 1
104
+ return 0
105
+ return f
106
+
107
+ def make_redis_manager():
108
+ from new_redis import RedisClientWrapper
109
+ RedisClientWrapper._pool = "FAKE"
110
+ mgr = object.__new__(RedisClientWrapper)
111
+ mgr.client = FakeRedisClient()
112
+ mgr.unlock_script = mgr.client.register_script("")
113
+ return mgr
114
+
115
+
116
+ # ================================================================
117
+ # 成本追踪器: 记录所有外部调用和资源消耗
118
+ # ================================================================
119
+
120
+ @dataclass
121
+ class CostTracker:
122
+ """追踪单次查询的全部外部调用和资源消耗"""
123
+ # 调用次数
124
+ milvus_calls: int = 0
125
+ pdf_calls: int = 0
126
+ cypher_generate_calls: int = 0
127
+ cypher_validate_calls: int = 0
128
+ neo4j_session_calls: int = 0
129
+ llm_calls: int = 0
130
+ redis_get_calls: int = 0
131
+ redis_set_calls: int = 0
132
+
133
+ # Token 估算 (中文约 1 token ≈ 1.5 字符)
134
+ prompt_chars: int = 0
135
+ response_chars: int = 0
136
+
137
+ # 时间
138
+ start_time: float = 0.0
139
+ end_time: float = 0.0
140
+
141
+ @property
142
+ def total_external_calls(self) -> int:
143
+ return (self.milvus_calls + self.pdf_calls +
144
+ self.cypher_generate_calls + self.cypher_validate_calls +
145
+ self.neo4j_session_calls + self.llm_calls)
146
+
147
+ @property
148
+ def estimated_prompt_tokens(self) -> int:
149
+ """粗估 prompt token 数 (中文 ≈ 1.5 字符/token)"""
150
+ return int(self.prompt_chars / 1.5) if self.prompt_chars else 0
151
+
152
+ @property
153
+ def estimated_response_tokens(self) -> int:
154
+ return int(self.response_chars / 1.5) if self.response_chars else 0
155
+
156
+ @property
157
+ def estimated_total_tokens(self) -> int:
158
+ return self.estimated_prompt_tokens + self.estimated_response_tokens
159
+
160
+ @property
161
+ def elapsed_ms(self) -> float:
162
+ return (self.end_time - self.start_time) * 1000 if self.end_time else 0
163
+
164
+ def estimated_cost_usd(self, model="gpt-4o-mini") -> float:
165
+ """
166
+ 估算 API 费用 (USD)
167
+ gpt-4o-mini: $0.15/1M input + $0.60/1M output
168
+ gpt-4o: $2.50/1M input + $10.00/1M output
169
+ text-embedding-3-small: $0.02/1M tokens
170
+ """
171
+ pricing = {
172
+ "gpt-4o-mini": {"input": 0.15, "output": 0.60},
173
+ "gpt-4o": {"input": 2.50, "output": 10.00},
174
+ }
175
+ p = pricing.get(model, pricing["gpt-4o-mini"])
176
+ input_cost = self.estimated_prompt_tokens * p["input"] / 1_000_000
177
+ output_cost = self.estimated_response_tokens * p["output"] / 1_000_000
178
+ # Embedding 调用 (1次/查询)
179
+ embed_cost = 50 * 0.02 / 1_000_000 # ~50 tokens per query
180
+ return input_cost + output_cost + embed_cost
181
+
182
+
183
+ def build_tracked_mocks(tracker: CostTracker, neo4j_fail=False):
184
+ """构建带调用计数的 Mock 组件"""
185
+
186
+ # Milvus
187
+ milvus = MagicMock()
188
+ def milvus_search(*args, **kwargs):
189
+ tracker.milvus_calls += 1
190
+ return [FakeDocument(page_content="高血压患者应控制钠摄入量不超过5克")]
191
+ milvus.similarity_search.side_effect = milvus_search
192
+
193
+ # PDF
194
+ pdf = MagicMock()
195
+ def pdf_invoke(*args, **kwargs):
196
+ tracker.pdf_calls += 1
197
+ return [FakeDocument(page_content="《中国高血压防治指南》建议低盐低脂饮食")]
198
+ pdf.invoke.side_effect = pdf_invoke
199
+
200
+ # Neo4j Driver
201
+ neo4j_driver = MagicMock()
202
+ sess = MagicMock()
203
+ def neo4j_run(*args, **kwargs):
204
+ tracker.neo4j_session_calls += 1
205
+ if neo4j_fail:
206
+ raise Exception("Neo4j down")
207
+ return [("氨氯地平",), ("缬沙坦",)]
208
+ sess.run.side_effect = neo4j_run
209
+ neo4j_driver.session.return_value.__enter__ = MagicMock(return_value=sess)
210
+ neo4j_driver.session.return_value.__exit__ = MagicMock(return_value=False)
211
+
212
+ # Cypher API (requests)
213
+ req = MagicMock()
214
+ call_index = [0]
215
+ def req_post(url, *args, **kwargs):
216
+ if neo4j_fail:
217
+ raise ConnectionError("Cypher API down")
218
+ if "/generate" in url:
219
+ tracker.cypher_generate_calls += 1
220
+ resp = MagicMock(); resp.status_code = 200
221
+ resp.json.return_value = {
222
+ "cypher_query": "MATCH (d:Disease)-[:has_drug]->(m) RETURN m.name",
223
+ "confidence": 0.95, "validated": True,
224
+ }
225
+ return resp
226
+ elif "/validate" in url:
227
+ tracker.cypher_validate_calls += 1
228
+ resp = MagicMock(); resp.status_code = 200
229
+ resp.json.return_value = {"is_valid": True}
230
+ return resp
231
+ req.post.side_effect = req_post
232
+
233
+ # LLM
234
+ llm = MagicMock()
235
+ def llm_create(*args, **kwargs):
236
+ tracker.llm_calls += 1
237
+ prompt = kwargs.get("messages", [{}])[0].get("content", "")
238
+ tracker.prompt_chars = len(prompt)
239
+ answer = "高血压患者应避免高盐饮食, 建议每日钠摄入不超过5克, 常用药物包括氨氯地平、缬沙坦等。"
240
+ tracker.response_chars = len(answer)
241
+ return FakeChatResponse(answer)
242
+ llm.chat.completions.create.side_effect = llm_create
243
+
244
+ return milvus, pdf, neo4j_driver, llm, req
245
+
246
+
247
+ def perform_rag_tracked(query, milvus, pdf, neo4j_driver, llm, requests_module):
248
+ """依赖注入版 perform_rag_and_llm"""
249
+ import json as _json
250
+
251
+ try:
252
+ results = milvus.similarity_search(query, k=10, ranker_type="rrf", ranker_params={"k": 100})
253
+ context = "\n\n".join(d.page_content for d in results) if results else ""
254
+ except Exception:
255
+ context = ""
256
+
257
+ pdf_res = ""
258
+ try:
259
+ docs = pdf.invoke(query)
260
+ if docs and len(docs) >= 1:
261
+ pdf_res = docs[0].page_content
262
+ except Exception:
263
+ pass
264
+ context = context + "\n" + pdf_res
265
+
266
+ neo4j_res = ""
267
+ try:
268
+ resp = requests_module.post("http://0.0.0.0:8101/generate",
269
+ _json.dumps({"natural_language_query": query}))
270
+ if resp.status_code == 200:
271
+ d = resp.json()
272
+ if d["cypher_query"] and float(d["confidence"]) >= 0.9 and d["validated"]:
273
+ vresp = requests_module.post("http://0.0.0.0:8101/validate",
274
+ _json.dumps({"cypher_query": d["cypher_query"]}))
275
+ if vresp.status_code == 200 and vresp.json()["is_valid"]:
276
+ with neo4j_driver.session() as session:
277
+ try:
278
+ record = session.run(d["cypher_query"])
279
+ neo4j_res = ','.join(list(map(lambda x: x[0], record)))
280
+ except Exception:
281
+ neo4j_res = ""
282
+ except Exception:
283
+ pass
284
+ context = context + "\n" + neo4j_res
285
+
286
+ SYSTEM = "System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案."
287
+ USER = f"""User: 利用介于<context>和</context>之间的信息来回答问题.
288
+ <context>
289
+ {context}
290
+ </context>
291
+ <question>
292
+ {query}
293
+ </question>"""
294
+
295
+ response = llm.chat.completions.create(
296
+ model="gpt-4o-mini",
297
+ messages=[{"role": "user", "content": SYSTEM + USER}],
298
+ temperature=0.7,
299
+ )
300
+ return response.choices[0].message.content
301
+
302
+
303
+ def run_tracked_query(query="高血压不能吃什么?", neo4j_fail=False) -> CostTracker:
304
+ """执行一次查询并返回成本追踪数据"""
305
+ tracker = CostTracker()
306
+ milvus, pdf, neo4j, llm, req = build_tracked_mocks(tracker, neo4j_fail=neo4j_fail)
307
+
308
+ tracker.start_time = time.time()
309
+ perform_rag_tracked(query, milvus, pdf, neo4j, llm, req)
310
+ tracker.end_time = time.time()
311
+
312
+ return tracker
313
+
314
+
315
+ # ================================================================
316
+ # 维度 1: 外部调用次数审计
317
+ # ================================================================
318
+
319
+ class TestExternalCallCount:
320
+ """
321
+ 核心问题: 回答一个问题到底调了多少次外部 API?
322
+ 每多一次调用 = 多一份延迟 + 多一份费用 + 多一个故障点
323
+ """
324
+
325
+ def test_normal_query_call_count(self):
326
+ """正常查询: 精确审计每个组件的调用次数"""
327
+ t = run_tracked_query()
328
+
329
+ assert t.milvus_calls == 1, f"Milvus 应调 1 次, 实际 {t.milvus_calls}"
330
+ assert t.pdf_calls == 1, f"PDF 应调 1 次, 实际 {t.pdf_calls}"
331
+ assert t.cypher_generate_calls == 1, f"Cypher /generate 应调 1 次, 实际 {t.cypher_generate_calls}"
332
+ assert t.cypher_validate_calls == 1, f"Cypher /validate 应调 1 次, 实际 {t.cypher_validate_calls}"
333
+ assert t.neo4j_session_calls == 1, f"Neo4j session.run 应调 1 次, 实际 {t.neo4j_session_calls}"
334
+ assert t.llm_calls == 1, f"LLM 应调 1 次, 实际 {t.llm_calls}"
335
+
336
+ def test_total_external_calls_is_six(self):
337
+ """正常查询总外部调用次数 = 6"""
338
+ t = run_tracked_query()
339
+ assert t.total_external_calls == 6, (
340
+ f"总外部调用应为 6, 实际 {t.total_external_calls}"
341
+ f"\n Milvus={t.milvus_calls}, PDF={t.pdf_calls},"
342
+ f" Cypher生成={t.cypher_generate_calls}, Cypher校验={t.cypher_validate_calls},"
343
+ f" Neo4j={t.neo4j_session_calls}, LLM={t.llm_calls}"
344
+ )
345
+
346
+ def test_no_duplicate_llm_calls(self):
347
+ """LLM 严格只调 1 次 (最贵的组件)"""
348
+ t = run_tracked_query()
349
+ assert t.llm_calls == 1, f"LLM 不应重复调用, 实际 {t.llm_calls}"
350
+
351
+ def test_neo4j_down_reduces_calls(self):
352
+ """Neo4j 宕机: 减少 3 次外部调用 (generate + validate + session)"""
353
+ t = run_tracked_query(neo4j_fail=True)
354
+
355
+ assert t.cypher_generate_calls == 0, "Cypher API 不可用时不应有 /generate 调用"
356
+ assert t.cypher_validate_calls == 0, "Cypher API 不可用时不应有 /validate 调用"
357
+ assert t.neo4j_session_calls == 0, "Cypher API 不可用时不应有 session.run"
358
+ assert t.total_external_calls == 3, (
359
+ f"Neo4j 宕机时总调用应为 3 (Milvus+PDF+LLM), 实际 {t.total_external_calls}"
360
+ )
361
+
362
+ def test_multiple_queries_each_has_own_calls(self):
363
+ """多个查询: 每个查询独立计数"""
364
+ trackers = [run_tracked_query(f"问题{i}") for i in range(5)]
365
+ for i, t in enumerate(trackers):
366
+ assert t.llm_calls == 1, f"查询 {i}: LLM 调用应为 1"
367
+ assert t.total_external_calls == 6, f"查询 {i}: 总调用应为 6"
368
+
369
+ def test_embedding_call_per_milvus_search(self):
370
+ """
371
+ 每次 Milvus similarity_search 内部会调用 1 次 Embedding
372
+ (由 Milvus SDK 内部处理, 这里验证 Milvus 调用次数)
373
+ """
374
+ t = run_tracked_query()
375
+ # Milvus 的 similarity_search 内部封装了 Embedding 调用
376
+ # 1 次 Milvus search = 1 次 Embedding (隐含)
377
+ assert t.milvus_calls == 1, "每次查询应只触发 1 次 Milvus 搜索 (含 1 次 Embedding)"
378
+
379
+
380
+ # ================================================================
381
+ # 维度 2: Token 消耗估算
382
+ # ================================================================
383
+
384
+ class TestTokenConsumption:
385
+ """
386
+ 核心问题: 每次查询消耗多少 token?
387
+ token 是 LLM 计费的直接单位
388
+ """
389
+
390
+ def test_prompt_token_count_reasonable(self):
391
+ """Prompt token 数在合理范围 (50-2000)"""
392
+ t = run_tracked_query()
393
+ tokens = t.estimated_prompt_tokens
394
+ assert 50 <= tokens <= 2000, f"Prompt tokens {tokens} 超出合理范围 [50, 2000]"
395
+
396
+ def test_response_token_count_reasonable(self):
397
+ """Response token 数在合理范围 (5-500)"""
398
+ t = run_tracked_query()
399
+ tokens = t.estimated_response_tokens
400
+ assert 5 <= tokens <= 500, f"Response tokens {tokens} 超出合理范围 [5, 500]"
401
+
402
+ def test_total_token_count_per_query(self):
403
+ """单次查询总 token 数 < 3000 (gpt-4o-mini 上下文窗口远大于此)"""
404
+ t = run_tracked_query()
405
+ total = t.estimated_total_tokens
406
+ assert total < 3000, f"单次查询 token {total} 不应超过 3000"
407
+
408
+ def test_prompt_is_largest_cost_component(self):
409
+ """Prompt token 应占总 token 的大部分 (>60%)"""
410
+ t = run_tracked_query()
411
+ if t.estimated_total_tokens > 0:
412
+ prompt_ratio = t.estimated_prompt_tokens / t.estimated_total_tokens
413
+ assert prompt_ratio > 0.6, (
414
+ f"Prompt 占比 {prompt_ratio:.1%}, 应 >60% (context 是大头)"
415
+ )
416
+
417
+ def test_longer_query_means_more_tokens(self):
418
+ """更长的问题 → 更多的 prompt token"""
419
+ t_short = run_tracked_query("高血压")
420
+ t_long = run_tracked_query("请详细介绍高血压的所有相关症状以及对应的治疗方案和饮食建议")
421
+
422
+ # 问题更长, prompt 应更大 (因为 query 出现在 <question> 中)
423
+ assert t_long.prompt_chars >= t_short.prompt_chars, (
424
+ f"长问题 prompt ({t_long.prompt_chars}) 应 ≥ 短问题 ({t_short.prompt_chars})"
425
+ )
426
+
427
+ def test_context_contributes_most_tokens(self):
428
+ """Context (三路召回内容) 是 prompt 中 token 最大的来源"""
429
+ t = run_tracked_query()
430
+ # 验证 prompt 中包含了 context 内容 (通过 prompt 长度 > 纯模板)
431
+ # 纯模板 (System + User + 标签) ≈ 120 字符
432
+ pure_template = 120
433
+ context_chars = t.prompt_chars - pure_template
434
+ assert context_chars > 0, "Context 应为 prompt 贡献内容"
435
+ context_ratio = context_chars / t.prompt_chars
436
+ assert context_ratio > 0.3, (
437
+ f"Context 占 prompt 比例 {context_ratio:.1%}, 应 >30%"
438
+ f"\n (Mock 数据较短; 生产环境 context 占比通常 >70%)"
439
+ )
440
+
441
+
442
+ # ================================================================
443
+ # 维度 3: 缓存节省量化
444
+ # ================================================================
445
+
446
+ class TestCacheSavings:
447
+ """
448
+ 核心问题: Redis 缓存帮我们省了多少钱?
449
+ 每次缓存命中 = 省了 6 次外部调用
450
+ """
451
+
452
+ def test_cache_hit_saves_all_external_calls(self):
453
+ """缓存命中: 0 次外部调用 (省了 6 次)"""
454
+ mgr = make_redis_manager()
455
+ first_tracker = CostTracker()
456
+ milvus, pdf, neo4j, llm, req = build_tracked_mocks(first_tracker)
457
+
458
+ def first_rag():
459
+ return perform_rag_tracked("高血压", milvus, pdf, neo4j, llm, req)
460
+
461
+ # 第一次: Miss, 走 RAG
462
+ mgr.get_or_compute("高血压", first_rag)
463
+ assert first_tracker.total_external_calls == 6
464
+
465
+ # 第二次: Hit, 不走 RAG
466
+ second_tracker = CostTracker()
467
+ milvus2, pdf2, neo4j2, llm2, req2 = build_tracked_mocks(second_tracker)
468
+ def second_rag():
469
+ return perform_rag_tracked("高血压", milvus2, pdf2, neo4j2, llm2, req2)
470
+ mgr.get_or_compute("高血压", second_rag)
471
+
472
+ assert second_tracker.total_external_calls == 0, (
473
+ f"缓存命中时不应有外部调用, 实际 {second_tracker.total_external_calls}"
474
+ )
475
+
476
+ def test_cache_saves_llm_cost(self):
477
+ """缓存命中: 节省 LLM 调用费用"""
478
+ mgr = make_redis_manager()
479
+
480
+ first_t = CostTracker()
481
+ m, p, n, l, r = build_tracked_mocks(first_t)
482
+ mgr.get_or_compute("Q1", lambda: perform_rag_tracked("Q1", m, p, n, l, r))
483
+
484
+ second_t = CostTracker()
485
+ m2, p2, n2, l2, r2 = build_tracked_mocks(second_t)
486
+ mgr.get_or_compute("Q1", lambda: perform_rag_tracked("Q1", m2, p2, n2, l2, r2))
487
+
488
+ assert first_t.llm_calls == 1, "第一次应调 LLM"
489
+ assert second_t.llm_calls == 0, "第二次缓存命中, 不应调 LLM"
490
+
491
+ def test_ten_queries_same_question_only_one_rag(self):
492
+ """同一问题查 10 次, 只走 1 次 RAG"""
493
+ mgr = make_redis_manager()
494
+ total_llm_calls = 0
495
+
496
+ for i in range(10):
497
+ t = CostTracker()
498
+ m, p, n, l, r = build_tracked_mocks(t)
499
+ mgr.get_or_compute("重复问题", lambda: perform_rag_tracked("重复问题", m, p, n, l, r))
500
+ total_llm_calls += t.llm_calls
501
+
502
+ assert total_llm_calls == 1, f"10 次查询只应调 1 次 LLM, 实际 {total_llm_calls}"
503
+
504
+ def test_cache_saving_ratio_over_batch(self):
505
+ """批量查询: 50% 重复率 → 节省约 50% 的外部调用"""
506
+ mgr = make_redis_manager()
507
+ questions = ["Q1", "Q2", "Q3", "Q4", "Q5"] * 2 # 10 次查询, 5 个不同问题
508
+
509
+ total_external = 0
510
+ for q in questions:
511
+ t = CostTracker()
512
+ m, p, n, l, r = build_tracked_mocks(t)
513
+ mgr.get_or_compute(q, lambda: perform_rag_tracked(q, m, p, n, l, r))
514
+ total_external += t.total_external_calls
515
+
516
+ # 5 个唯一问题 × 6 次调用 = 30 次; 5 个重复 × 0 次 = 0; 总计 30
517
+ no_cache_total = len(questions) * 6 # 60 (如果没缓存)
518
+ saving_ratio = 1 - (total_external / no_cache_total)
519
+
520
+ assert saving_ratio >= 0.4, (
521
+ f"缓存节省率 {saving_ratio:.1%}, 预期 ≥40%"
522
+ f"\n 实际总调用: {total_external}, 无缓存总调用: {no_cache_total}"
523
+ )
524
+
525
+ def test_cache_saving_dollar_estimate(self):
526
+ """估算缓存节省的美元费用"""
527
+ t = run_tracked_query()
528
+ cost_per_query = t.estimated_cost_usd()
529
+
530
+ # 假设每天 1000 次查询, 50% 缓存命中率
531
+ daily_queries = 1000
532
+ hit_rate = 0.5
533
+ daily_cost_no_cache = daily_queries * cost_per_query
534
+ daily_cost_with_cache = daily_queries * (1 - hit_rate) * cost_per_query
535
+ daily_savings = daily_cost_no_cache - daily_cost_with_cache
536
+
537
+ # 只验证计算逻辑正确
538
+ assert daily_savings > 0, "缓存应节省费用"
539
+ assert daily_savings == daily_cost_no_cache * hit_rate
540
+
541
+
542
+ # ================================================================
543
+ # 维度 4: 降级场景的成本影响
544
+ # ================================================================
545
+
546
+ class TestDegradedCost:
547
+ """
548
+ 组件故障不仅影响质量, 也影响成本
549
+ 部分降级 → 调用次数减少 → 费用降低 (但质量也降低)
550
+ """
551
+
552
+ def test_neo4j_down_saves_three_calls(self):
553
+ """Neo4j 宕机: 节省 3 次调用 (generate + validate + session)"""
554
+ t_normal = run_tracked_query(neo4j_fail=False)
555
+ t_degraded = run_tracked_query(neo4j_fail=True)
556
+
557
+ saved = t_normal.total_external_calls - t_degraded.total_external_calls
558
+ assert saved == 3, f"Neo4j 宕机应节省 3 次调用, 实际节省 {saved}"
559
+
560
+ def test_degraded_cost_is_lower(self):
561
+ """降级时 LLM prompt 更短 (没有 Neo4j context) → token 更少"""
562
+ t_normal = run_tracked_query(neo4j_fail=False)
563
+ t_degraded = run_tracked_query(neo4j_fail=True)
564
+
565
+ # Neo4j 结果不在 context 中, prompt 更短
566
+ assert t_degraded.prompt_chars <= t_normal.prompt_chars, (
567
+ f"降级时 prompt 应更短: 降级={t_degraded.prompt_chars}, 正常={t_normal.prompt_chars}"
568
+ )
569
+
570
+ def test_llm_still_called_once_even_when_degraded(self):
571
+ """降级时 LLM 仍然只调 1 次"""
572
+ t = run_tracked_query(neo4j_fail=True)
573
+ assert t.llm_calls == 1, "降级时 LLM 仍应只调 1 次"
574
+
575
+ def test_cost_comparison_normal_vs_degraded(self):
576
+ """正常 vs 降级的成本对比"""
577
+ t_normal = run_tracked_query(neo4j_fail=False)
578
+ t_degraded = run_tracked_query(neo4j_fail=True)
579
+
580
+ cost_normal = t_normal.estimated_cost_usd()
581
+ cost_degraded = t_degraded.estimated_cost_usd()
582
+
583
+ # 降级成本应 ≤ 正常成本 (少了 context)
584
+ assert cost_degraded <= cost_normal, (
585
+ f"降级费用 ${cost_degraded:.6f} 应 ≤ 正常费用 ${cost_normal:.6f}"
586
+ )
587
+
588
+
589
+ # ================================================================
590
+ # 维度 5: 成本效率报告
591
+ # ================================================================
592
+
593
+ class TestCostEfficiencyReport:
594
+ """生成人类可读的成本效率报告"""
595
+
596
+ def test_single_query_cost_breakdown(self):
597
+ """单次查询成本明细"""
598
+ t = run_tracked_query()
599
+
600
+ assert t.total_external_calls > 0
601
+ assert t.estimated_total_tokens > 0
602
+ assert t.estimated_cost_usd() >= 0
603
+
604
+ def test_batch_efficiency_metrics(self):
605
+ """批量查询效率指标"""
606
+ trackers = [run_tracked_query(f"问题{i}") for i in range(10)]
607
+
608
+ avg_calls = sum(t.total_external_calls for t in trackers) / len(trackers)
609
+ avg_tokens = sum(t.estimated_total_tokens for t in trackers) / len(trackers)
610
+ avg_cost = sum(t.estimated_cost_usd() for t in trackers) / len(trackers)
611
+
612
+ assert avg_calls == 6, f"平均调用次数应为 6, 实际 {avg_calls}"
613
+ assert avg_tokens > 0, "平均 token 应 > 0"
614
+ assert avg_cost > 0, "平均费用应 > 0"
615
+
616
+ def test_model_cost_comparison(self):
617
+ """不同模型的费用对比: gpt-4o-mini vs gpt-4o"""
618
+ t = run_tracked_query()
619
+
620
+ cost_mini = t.estimated_cost_usd("gpt-4o-mini")
621
+ cost_4o = t.estimated_cost_usd("gpt-4o")
622
+
623
+ assert cost_4o > cost_mini, "gpt-4o 应比 gpt-4o-mini 贵"
624
+ ratio = cost_4o / cost_mini if cost_mini > 0 else float('inf')
625
+ assert ratio > 5, f"gpt-4o 应比 mini 贵 5 倍以上, 实际 {ratio:.1f} 倍"
626
+
627
+ def test_cost_report_printout(self, capsys):
628
+ """打印完整成本效率报告"""
629
+ t = run_tracked_query("高血压不能吃什么?")
630
+
631
+ print("\n")
632
+ print("=" * 70)
633
+ print(" 医疗 RAG Agent — Cost & Efficiency 报告")
634
+ print("=" * 70)
635
+
636
+ print(f"\n 📋 查询: '高血压不能吃什么?'")
637
+
638
+ print(f"\n ── 外部调用明细 ──")
639
+ print(f" Milvus 向量搜索: {t.milvus_calls} 次")
640
+ print(f" PDF 父子检索: {t.pdf_calls} 次")
641
+ print(f" Cypher /generate: {t.cypher_generate_calls} 次")
642
+ print(f" Cypher /validate: {t.cypher_validate_calls} 次")
643
+ print(f" Neo4j session.run: {t.neo4j_session_calls} 次")
644
+ print(f" LLM 推理: {t.llm_calls} 次")
645
+ print(f" ────────────────────────────")
646
+ print(f" 总外部调用: {t.total_external_calls} 次")
647
+
648
+ print(f"\n ── Token 消耗 ──")
649
+ print(f" Prompt: ~{t.estimated_prompt_tokens} tokens ({t.prompt_chars} 字符)")
650
+ print(f" Response: ~{t.estimated_response_tokens} tokens ({t.response_chars} 字符)")
651
+ print(f" 总计: ~{t.estimated_total_tokens} tokens")
652
+
653
+ print(f"\n ── 费用估算 (per query) ──")
654
+ print(f" gpt-4o-mini: ${t.estimated_cost_usd('gpt-4o-mini'):.6f}")
655
+ print(f" gpt-4o: ${t.estimated_cost_usd('gpt-4o'):.6f}")
656
+
657
+ # 月度预估
658
+ daily = 1000
659
+ monthly = daily * 30
660
+ hit_rate = 0.5
661
+ effective_queries = monthly * (1 - hit_rate)
662
+ print(f"\n ── 月度预估 (日均 {daily} 查询, 缓存命中率 {hit_rate:.0%}) ──")
663
+ print(f" 有效 LLM 调用: {int(effective_queries)} 次/月")
664
+ print(f" gpt-4o-mini 月费: ${effective_queries * t.estimated_cost_usd('gpt-4o-mini'):.2f}")
665
+ print(f" gpt-4o 月费: ${effective_queries * t.estimated_cost_usd('gpt-4o'):.2f}")
666
+ print(f" 缓存节省: {hit_rate:.0%} ({int(monthly * hit_rate)} 次 LLM 调用)")
667
+
668
+ # 降级对比
669
+ t_deg = run_tracked_query(neo4j_fail=True)
670
+ print(f"\n ── 降级场景对比 ──")
671
+ print(f" 正常: {t.total_external_calls} 次调用, ~{t.estimated_total_tokens} tokens, ${t.estimated_cost_usd():.6f}")
672
+ print(f" 降级: {t_deg.total_external_calls} 次调用, ~{t_deg.estimated_total_tokens} tokens, ${t_deg.estimated_cost_usd():.6f}")
673
+
674
+ print("=" * 70)
675
+
676
+ assert True # 报告打印成功即通过
677
+
678
+
679
+ # ================================================================
680
+ if __name__ == "__main__":
681
+ pytest.main([__file__, "-v", "--tb=short", "-s"])