| """ |
| ================================================================ |
| 医疗 RAG Agent — Cost & Efficiency 评测 (成本与效率) |
| ================================================================ |
| 测试层级: |
| 单元测试 (test1.py): 单工具调用准确性 ✅ 67 passed |
| 集成测试 (test2.py): 多步骤工具链协作 ✅ 37 passed |
| 回归测试 (test3.py): 防退化 & 边界守护 ✅ 52 passed |
| 安全红队 (test4.py): 对抗性攻击防御 ✅ 45 passed |
| E2E完成率(test5.py): 端到端任务完成率 ✅ 60 passed |
| 成本效率 (test6.py): Cost & Efficiency ← 当前文件 |
| |
| 为什么要测成本? |
Agent 每回答一个问题要: 1次 Embedding + 1次 Milvus + 1次 PDF +
2次 Cypher API + 1次 Neo4j + 1次 LLM = 至少 7 次外部调用
(Embedding 由 Milvus similarity_search 内部触发, 因此本文件追踪到的外部调用为 6 次)
在生产环境中, 这些调用直接关系 token 消耗和 API 费用
| |
| 测试维度: |
| 维度 1: 外部调用次数审计 (每次查询调了几次 API?) |
| 维度 2: Token 消耗估算 (Prompt + Response 共多少 token?) |
| 维度 3: 缓存节省量化 (Redis 命中省了多少调用?) |
| 维度 4: 降级场景的成本影响 (组件故障时成本变化) |
| 维度 5: 成本报告 (人类可读的费用估算) |
| |
| 运行: |
| pytest test6.py -v --tb=short -s |
| pytest test6.py -v -k "call_count" # 调用次数 |
| pytest test6.py -v -k "token" # Token 消耗 |
| pytest test6.py -v -k "cache_saving" # 缓存节省 |
| ================================================================ |
| """ |
|
|
| import sys |
| import os |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) |
|
|
| import types |
| import pytest |
| import json |
| import hashlib |
| import time |
| from unittest.mock import MagicMock, patch, call |
| from dataclasses import dataclass, field |
| from typing import Optional, List, Dict |
|
|
|
|
| |
| |
| |
|
|
| def _ensure_mock_module(name): |
| if name not in sys.modules: |
| sys.modules[name] = MagicMock() |
|
|
# Modules replaced by MagicMock stubs (unless already in sys.modules),
# presumably so project code can be imported without these dependencies
# installed in the test environment.
_MOCKED_MODULES = (
    "langchain_classic",
    "langchain_classic.retrievers",
    "langchain_classic.retrievers.parent_document_retriever",
    "langchain_milvus",
    "langchain_text_splitters",
    "langchain_core",
    "langchain_core.stores",
    "langchain_core.documents",
    "langchain.embeddings",
    "langchain.embeddings.base",
    "neo4j",
    "dotenv",
    "uvicorn",
    "fastapi",
    "fastapi.middleware",
    "fastapi.middleware.cors",
)

for _module_name in _MOCKED_MODULES:
    _ensure_mock_module(_module_name)
|
|
class _FakeEmbeddingsBase:
    """Plain class installed as ``langchain.embeddings.base.Embeddings``.

    A real class rather than a MagicMock attribute, presumably so project
    code can subclass ``Embeddings``. TODO: confirm against the importing module.
    """


setattr(sys.modules["langchain.embeddings.base"], "Embeddings", _FakeEmbeddingsBase)
|
|
|
|
| |
| |
| |
|
|
@dataclass
class FakeDocument:
    """Minimal stand-in for a retrieved document: text plus a metadata dict."""

    page_content: str
    # default_factory gives every instance its own dict (no shared mutable default).
    metadata: dict = field(default_factory=dict)
|
|
class FakeChatResponse:
    """Stand-in for a chat-completion response object.

    Exposes only ``choices[0].message.content``.
    """

    def __init__(self, content):
        message_cls = type('Msg', (), {'content': content})
        choice_cls = type('Choice', (), {'message': message_cls()})
        self.choices = [choice_cls()]
|
|
class FakeRedisClient:
    """In-memory stand-in for the subset of a Redis client used by these tests.

    Values are stored exactly as given. TTLs are recorded in ``_expiry`` but
    never enforced, so keys never expire.
    """

    def __init__(self):
        self._store = {}   # key -> stored value
        self._expiry = {}  # key -> TTL in seconds (recorded only, never enforced)

    def ping(self):
        return True

    def get(self, key):
        return self._store.get(key)

    def set(self, key, value, ex=None, nx=False):
        """SET with optional EX (TTL seconds) and NX (only set if absent).

        Returns True when the value was written, False when NX blocked it.
        As in Redis, a successful SET without EX discards any previous TTL.
        """
        if nx and key in self._store:
            return False
        self._store[key] = value
        if ex:
            self._expiry[key] = ex
        else:
            self._expiry.pop(key, None)
        return True

    def setex(self, key, expire, value):
        """Set ``key`` to ``value`` and record a TTL of ``expire`` seconds."""
        self._store[key] = value
        self._expiry[key] = expire
        return True

    def delete(self, key, *more_keys):
        """Delete one or more keys; return how many of them existed.

        Existence is checked by membership, so keys holding falsy values
        (including None) are still counted.
        """
        removed = 0
        for k in (key, *more_keys):
            if k in self._store:
                del self._store[k]
                removed += 1
            self._expiry.pop(k, None)
        return removed

    def register_script(self, script):
        """Return a callable emulating a compare-and-delete script.

        The ``script`` text is ignored. Calling ``f(keys=[k], args=[token])``
        deletes ``k`` and returns 1 only when ``k`` exists and holds
        ``token``; otherwise it returns 0.
        """
        def compare_and_delete(keys=None, args=None):
            if keys and args:
                target = keys[0]
                # Membership check first: a missing key must not match a None token.
                if target in self._store and self._store[target] == args[0]:
                    del self._store[target]
                    self._expiry.pop(target, None)
                    return 1
            return 0
        return compare_and_delete
|
|
def make_redis_manager():
    """Return a ``new_redis.RedisClientWrapper`` backed by an in-memory FakeRedisClient.

    ``__init__`` is bypassed via ``object.__new__``. The class-level ``_pool``
    is set to a placeholder first, presumably so no real connection pool is
    used. TODO: confirm against new_redis.
    """
    from new_redis import RedisClientWrapper

    RedisClientWrapper._pool = "FAKE"
    fake_client = FakeRedisClient()
    manager = object.__new__(RedisClientWrapper)
    manager.client = fake_client
    manager.unlock_script = fake_client.register_script("")
    return manager
|
|
|
|
| |
| |
| |
|
|
@dataclass
class CostTracker:
    """Record every external call and the LLM payload size of one RAG query.

    The tracked mocks increment the counters. The properties and
    ``estimated_cost_usd`` derive call totals, rough token counts and an
    approximate USD cost from them.
    """

    # Per-component external call counters.
    milvus_calls: int = 0
    pdf_calls: int = 0
    cypher_generate_calls: int = 0
    cypher_validate_calls: int = 0
    neo4j_session_calls: int = 0
    llm_calls: int = 0
    # Redis counters are kept separately and are not part of total_external_calls.
    redis_get_calls: int = 0
    redis_set_calls: int = 0

    # Characters sent to / received from the LLM.
    prompt_chars: int = 0
    response_chars: int = 0

    # Clock readings (seconds) taken around the query; end_time == 0 means "not finished".
    start_time: float = 0.0
    end_time: float = 0.0

    # Unannotated class attributes below are constants, not dataclass fields.
    # Rough ratio used for token estimates: Chinese text averages ~1.5 chars/token.
    _CHARS_PER_TOKEN = 1.5
    # LLM prices in USD per 1M tokens.
    _LLM_PRICING = {
        "gpt-4o-mini": {"input": 0.15, "output": 0.60},
        "gpt-4o": {"input": 2.50, "output": 10.00},
    }
    # text-embedding-3-small price in USD per 1M tokens.
    _EMBEDDING_PRICE = 0.02

    @property
    def total_external_calls(self) -> int:
        """Sum of all non-Redis external calls."""
        return (self.milvus_calls + self.pdf_calls
                + self.cypher_generate_calls + self.cypher_validate_calls
                + self.neo4j_session_calls + self.llm_calls)

    @property
    def estimated_prompt_tokens(self) -> int:
        """Rough prompt token count (chars / 1.5, truncated)."""
        return int(self.prompt_chars / self._CHARS_PER_TOKEN)

    @property
    def estimated_response_tokens(self) -> int:
        """Rough response token count (chars / 1.5, truncated)."""
        return int(self.response_chars / self._CHARS_PER_TOKEN)

    @property
    def estimated_total_tokens(self) -> int:
        return self.estimated_prompt_tokens + self.estimated_response_tokens

    @property
    def elapsed_ms(self) -> float:
        """Elapsed time in milliseconds, or 0.0 if the query has not finished."""
        if not self.end_time:
            return 0.0
        return (self.end_time - self.start_time) * 1000

    def estimated_cost_usd(self, model="gpt-4o-mini", embed_tokens_per_call=50) -> float:
        """Estimate the API cost of the tracked query in USD.

        Pricing (USD per 1M tokens):
            gpt-4o-mini: 0.15 input / 0.60 output
            gpt-4o:      2.50 input / 10.00 output
            text-embedding-3-small: 0.02

        Args:
            model: LLM pricing tier. Unknown names fall back to gpt-4o-mini prices.
            embed_tokens_per_call: assumed query-embedding size per Milvus search.

        The embedding cost is charged once per Milvus search (the query is
        embedded inside ``similarity_search``), so a query that made no
        searches, e.g. a cache hit, incurs no embedding cost.
        """
        prices = self._LLM_PRICING.get(model, self._LLM_PRICING["gpt-4o-mini"])
        input_cost = self.estimated_prompt_tokens * prices["input"] / 1_000_000
        output_cost = self.estimated_response_tokens * prices["output"] / 1_000_000
        embed_cost = (self.milvus_calls * embed_tokens_per_call
                      * self._EMBEDDING_PRICE / 1_000_000)
        return input_cost + output_cost + embed_cost
|
|
|
|
def build_tracked_mocks(tracker: CostTracker, neo4j_fail=False):
    """Build mocked RAG dependencies that record their usage on ``tracker``.

    Args:
        tracker: receives per-component call counts and LLM prompt/response sizes.
        neo4j_fail: simulate the graph path being down. The Cypher API mock
            raises ``ConnectionError`` before counting anything, and
            ``session.run`` raises after counting its call.

    Returns:
        Tuple ``(milvus, pdf, neo4j_driver, llm, requests_module)``, in the
        same order as ``perform_rag_tracked``'s dependency parameters.
    """
    # Milvus vector store: one fixed hit per similarity_search.
    milvus = MagicMock()

    def milvus_search(*args, **kwargs):
        tracker.milvus_calls += 1
        return [FakeDocument(page_content="高血压患者应控制钠摄入量不超过5克")]

    milvus.similarity_search.side_effect = milvus_search

    # PDF parent/child retriever: one fixed document per invoke.
    pdf = MagicMock()

    def pdf_invoke(*args, **kwargs):
        tracker.pdf_calls += 1
        return [FakeDocument(page_content="《中国高血压防治指南》建议低盐低脂饮食")]

    pdf.invoke.side_effect = pdf_invoke

    # Neo4j driver: session() is a context manager yielding `sess`.
    neo4j_driver = MagicMock()
    sess = MagicMock()

    def neo4j_run(*args, **kwargs):
        tracker.neo4j_session_calls += 1
        if neo4j_fail:
            raise Exception("Neo4j down")
        return [("氨氯地平",), ("缬沙坦",)]

    sess.run.side_effect = neo4j_run
    neo4j_driver.session.return_value.__enter__ = MagicMock(return_value=sess)
    neo4j_driver.session.return_value.__exit__ = MagicMock(return_value=False)

    # Cypher generate/validate HTTP API (stands in for the `requests` module).
    req = MagicMock()

    def req_post(url, *args, **kwargs):
        if neo4j_fail:
            raise ConnectionError("Cypher API down")
        if "/generate" in url:
            tracker.cypher_generate_calls += 1
            resp = MagicMock()
            resp.status_code = 200
            resp.json.return_value = {
                "cypher_query": "MATCH (d:Disease)-[:has_drug]->(m) RETURN m.name",
                "confidence": 0.95, "validated": True,
            }
            return resp
        elif "/validate" in url:
            tracker.cypher_validate_calls += 1
            resp = MagicMock()
            resp.status_code = 200
            resp.json.return_value = {"is_valid": True}
            return resp

    req.post.side_effect = req_post

    # LLM client: fixed answer; records prompt/response sizes.
    llm = MagicMock()

    def llm_create(*args, **kwargs):
        tracker.llm_calls += 1
        # Count every message and accumulate, like the call counters, so
        # repeated LLM calls or multi-message prompts are fully reflected.
        messages = kwargs.get("messages") or []
        tracker.prompt_chars += sum(len(m.get("content") or "") for m in messages)
        answer = "高血压患者应避免高盐饮食, 建议每日钠摄入不超过5克, 常用药物包括氨氯地平、缬沙坦等。"
        tracker.response_chars += len(answer)
        return FakeChatResponse(answer)

    llm.chat.completions.create.side_effect = llm_create

    return milvus, pdf, neo4j_driver, llm, req
|
|
|
|
def _recall_from_graph(query, neo4j_driver, requests_module):
    """Run the graph-retrieval path used by ``perform_rag_tracked``.

    Asks the Cypher service to generate a query. The query is accepted only
    when it is non-empty, has confidence >= 0.9 and is flagged ``validated``.
    The service is then asked to validate it, and it is executed on Neo4j.
    The first column of every returned row is joined with ``,``.

    Returns:
        The joined values, or ``""`` if any step rejects or fails.
    """
    import json as _json

    graph_text = ""
    try:
        gen_resp = requests_module.post(
            "http://0.0.0.0:8101/generate",
            _json.dumps({"natural_language_query": query}),
        )
        if gen_resp.status_code != 200:
            return graph_text
        plan = gen_resp.json()
        accepted = (
            plan["cypher_query"]
            and float(plan["confidence"]) >= 0.9
            and plan["validated"]
        )
        if not accepted:
            return graph_text
        check_resp = requests_module.post(
            "http://0.0.0.0:8101/validate",
            _json.dumps({"cypher_query": plan["cypher_query"]}),
        )
        if not (check_resp.status_code == 200 and check_resp.json()["is_valid"]):
            return graph_text
        with neo4j_driver.session() as session:
            try:
                rows = session.run(plan["cypher_query"])
                graph_text = ",".join([row[0] for row in rows])
            except Exception:
                # Failed Cypher execution degrades to "no graph context".
                graph_text = ""
    except Exception:
        # Cypher service unreachable or malformed reply: degrade silently.
        pass
    return graph_text


def perform_rag_tracked(query, milvus, pdf, neo4j_driver, llm, requests_module):
    """Dependency-injected copy of ``perform_rag_and_llm``.

    Runs three retrieval paths in order: Milvus hybrid search, the PDF
    retriever, then generated Cypher on Neo4j. Each path silently degrades
    to an empty string on error. The results are joined with newlines into
    one context, and a single prompt is sent to the LLM.

    Returns:
        The content of the first choice of the LLM response.
    """
    # Path 1: Milvus hybrid search (RRF ranker, top 10).
    try:
        hits = milvus.similarity_search(
            query, k=10, ranker_type="rrf", ranker_params={"k": 100}
        )
        vector_text = "\n\n".join(doc.page_content for doc in hits) if hits else ""
    except Exception:
        vector_text = ""

    # Path 2: PDF parent/child retriever; only the top document is kept.
    pdf_text = ""
    try:
        pdf_docs = pdf.invoke(query)
        if pdf_docs and len(pdf_docs) >= 1:
            pdf_text = pdf_docs[0].page_content
    except Exception:
        pass
    context = "\n".join((vector_text, pdf_text))

    # Path 3: generated Cypher executed against Neo4j.
    graph_text = _recall_from_graph(query, neo4j_driver, requests_module)
    # Empty paths still contribute their "\n" separator.
    context = "\n".join((context, graph_text))

    system_prompt = "System: 你是一个非常得力的医学助手, 你可以通过从数据库中检索出的信息找到问题的答案."
    user_prompt = f"""User: 利用介于<context>和</context>之间的信息来回答问题.
<context>
{context}
</context>
<question>
{query}
</question>"""

    # System text is prepended to the single user message (no separate system role).
    completion = llm.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": system_prompt + user_prompt}],
        temperature=0.7,
    )
    return completion.choices[0].message.content
|
|
|
|
def run_tracked_query(query="高血压不能吃什么?", neo4j_fail=False) -> CostTracker:
    """Run one RAG query against freshly built tracked mocks.

    Args:
        query: question passed to ``perform_rag_tracked``.
        neo4j_fail: forwarded to ``build_tracked_mocks`` to simulate the
            Cypher API / Neo4j path being down.

    Returns:
        The CostTracker holding this query's call counts, payload sizes and timing.
    """
    tracker = CostTracker()
    milvus, pdf, neo4j, llm, req = build_tracked_mocks(tracker, neo4j_fail=neo4j_fail)

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with system clock adjustments and is unsuitable for measuring durations.
    tracker.start_time = time.perf_counter()
    perform_rag_tracked(query, milvus, pdf, neo4j, llm, req)
    tracker.end_time = time.perf_counter()

    return tracker
|
|
|
|
| |
| |
| |
|
|
class TestExternalCallCount:
    """
    How many external API calls does answering one question take?
    Every extra call adds latency, cost, and another point of failure.
    """

    # (label used in the failure message, CostTracker counter attribute),
    # in the order the pipeline calls them.
    _COMPONENT_COUNTERS = (
        ("Milvus", "milvus_calls"),
        ("PDF", "pdf_calls"),
        ("Cypher /generate", "cypher_generate_calls"),
        ("Cypher /validate", "cypher_validate_calls"),
        ("Neo4j session.run", "neo4j_session_calls"),
        ("LLM", "llm_calls"),
    )

    def test_normal_query_call_count(self):
        """Normal query: every component is called exactly once."""
        tracker = run_tracked_query()
        for label, attr in self._COMPONENT_COUNTERS:
            actual = getattr(tracker, attr)
            assert actual == 1, f"{label} 应调 1 次, 实际 {actual}"

    def test_total_external_calls_is_six(self):
        """Normal query: total external calls == 6."""
        tracker = run_tracked_query()
        assert tracker.total_external_calls == 6, (
            f"总外部调用应为 6, 实际 {tracker.total_external_calls}"
            f"\n Milvus={tracker.milvus_calls}, PDF={tracker.pdf_calls},"
            f" Cypher生成={tracker.cypher_generate_calls}, Cypher校验={tracker.cypher_validate_calls},"
            f" Neo4j={tracker.neo4j_session_calls}, LLM={tracker.llm_calls}"
        )

    def test_no_duplicate_llm_calls(self):
        """The LLM (most expensive component) is called exactly once."""
        llm_count = run_tracked_query().llm_calls
        assert llm_count == 1, f"LLM 不应重复调用, 实际 {llm_count}"

    def test_neo4j_down_reduces_calls(self):
        """Neo4j down: generate + validate + session.run are all skipped."""
        tracker = run_tracked_query(neo4j_fail=True)
        skipped = (
            ("cypher_generate_calls", "/generate 调用"),
            ("cypher_validate_calls", "/validate 调用"),
            ("neo4j_session_calls", "session.run"),
        )
        for attr, what in skipped:
            assert getattr(tracker, attr) == 0, f"Cypher API 不可用时不应有 {what}"
        assert tracker.total_external_calls == 3, (
            f"Neo4j 宕机时总调用应为 3 (Milvus+PDF+LLM), 实际 {tracker.total_external_calls}"
        )

    def test_multiple_queries_each_has_own_calls(self):
        """Several queries: each one is counted independently."""
        trackers = list(map(run_tracked_query, (f"问题{i}" for i in range(5))))
        for idx, tracker in enumerate(trackers):
            assert tracker.llm_calls == 1, f"查询 {idx}: LLM 调用应为 1"
            assert tracker.total_external_calls == 6, f"查询 {idx}: 总调用应为 6"

    def test_embedding_call_per_milvus_search(self):
        """
        Each Milvus similarity_search embeds the query once internally
        (handled inside the Milvus SDK), so the Milvus call count is checked here.
        """
        searches = run_tracked_query().milvus_calls
        assert searches == 1, "每次查询应只触发 1 次 Milvus 搜索 (含 1 次 Embedding)"
|
|
|
|
| |
| |
| |
|
|
class TestTokenConsumption:
    """
    How many tokens does each query consume?
    Tokens are the direct billing unit for the LLM.
    """

    def test_prompt_token_count_reasonable(self):
        """Prompt tokens fall within a sane range (50-2000)."""
        t = run_tracked_query()
        tokens = t.estimated_prompt_tokens
        assert 50 <= tokens <= 2000, f"Prompt tokens {tokens} 超出合理范围 [50, 2000]"

    def test_response_token_count_reasonable(self):
        """Response tokens fall within a sane range (5-500)."""
        t = run_tracked_query()
        tokens = t.estimated_response_tokens
        assert 5 <= tokens <= 500, f"Response tokens {tokens} 超出合理范围 [5, 500]"

    def test_total_token_count_per_query(self):
        """Total tokens per query stay below 3000 (far below gpt-4o-mini's context window)."""
        t = run_tracked_query()
        total = t.estimated_total_tokens
        assert total < 3000, f"单次查询 token {total} 不应超过 3000"

    def test_prompt_is_largest_cost_component(self):
        """Prompt tokens make up most (>60%) of the total."""
        t = run_tracked_query()
        total = t.estimated_total_tokens
        # Fail loudly rather than silently skipping the ratio check.
        assert total > 0, "总 token 应 > 0"
        prompt_ratio = t.estimated_prompt_tokens / total
        assert prompt_ratio > 0.6, (
            f"Prompt 占比 {prompt_ratio:.1%}, 应 >60% (context 是大头)"
        )

    def test_longer_query_means_more_tokens(self):
        """A longer question yields a longer prompt."""
        short_query = "高血压"
        long_query = "请详细介绍高血压的所有相关症状以及对应的治疗方案和饮食建议"
        t_short = run_tracked_query(short_query)
        t_long = run_tracked_query(long_query)

        assert t_long.prompt_chars > t_short.prompt_chars, (
            f"长问题 prompt ({t_long.prompt_chars}) 应 > 短问题 ({t_short.prompt_chars})"
        )
        # The mocks return the same context for every query, so the prompt
        # must grow by exactly the extra question characters.
        assert (t_long.prompt_chars - t_short.prompt_chars
                == len(long_query) - len(short_query))

    def test_context_contributes_most_tokens(self):
        """Retrieved context (three recall paths) is a significant share of the prompt."""
        query = "高血压不能吃什么?"
        t = run_tracked_query(query)

        # Measure the template + question size empirically rather than with a
        # hand-counted constant: re-run the same query with every retrieval
        # path returning nothing.
        baseline = CostTracker()
        milvus, pdf, neo4j, llm, req = build_tracked_mocks(baseline, neo4j_fail=True)
        milvus.similarity_search.side_effect = lambda *args, **kwargs: []
        pdf.invoke.side_effect = lambda *args, **kwargs: []
        perform_rag_tracked(query, milvus, pdf, neo4j, llm, req)

        context_chars = t.prompt_chars - baseline.prompt_chars
        assert context_chars > 0, "Context 应为 prompt 贡献内容"
        context_ratio = context_chars / t.prompt_chars
        # Mock retrieval texts are short, so their true share is only ~23%.
        assert context_ratio > 0.2, (
            f"Context 占 prompt 比例 {context_ratio:.1%}, 应 >20%"
            f"\n (Mock 数据较短; 生产环境 context 占比通常 >70%)"
        )
|
|
|
|
| |
| |
| |
|
|
class TestCacheSavings:
    """
    How much does the Redis cache save?
    Every cache hit skips the whole RAG pipeline (6 external calls).
    """

    def test_cache_hit_saves_all_external_calls(self):
        """Cache hit: 0 external calls (6 saved)."""
        mgr = make_redis_manager()

        first_tracker = CostTracker()
        first_deps = build_tracked_mocks(first_tracker)
        # First query misses the cache and runs the full pipeline.
        mgr.get_or_compute("高血压", lambda: perform_rag_tracked("高血压", *first_deps))
        assert first_tracker.total_external_calls == 6

        second_tracker = CostTracker()
        second_deps = build_tracked_mocks(second_tracker)
        # Same question again: must be served from the cache.
        mgr.get_or_compute("高血压", lambda: perform_rag_tracked("高血压", *second_deps))

        assert second_tracker.total_external_calls == 0, (
            f"缓存命中时不应有外部调用, 实际 {second_tracker.total_external_calls}"
        )

    def test_cache_saves_llm_cost(self):
        """Cache hit: the LLM call (and its cost) is skipped."""
        mgr = make_redis_manager()

        first_t = CostTracker()
        first_deps = build_tracked_mocks(first_t)
        mgr.get_or_compute("Q1", lambda: perform_rag_tracked("Q1", *first_deps))

        second_t = CostTracker()
        second_deps = build_tracked_mocks(second_t)
        mgr.get_or_compute("Q1", lambda: perform_rag_tracked("Q1", *second_deps))

        assert first_t.llm_calls == 1, "第一次应调 LLM"
        assert second_t.llm_calls == 0, "第二次缓存命中, 不应调 LLM"

    def test_ten_queries_same_question_only_one_rag(self):
        """The same question asked 10 times runs RAG only once."""
        mgr = make_redis_manager()
        total_llm_calls = 0

        for _ in range(10):
            t = CostTracker()
            deps = build_tracked_mocks(t)
            # Bind this iteration's mocks now (default arg) instead of closing
            # over the loop variable (late binding).
            mgr.get_or_compute(
                "重复问题", lambda deps=deps: perform_rag_tracked("重复问题", *deps)
            )
            total_llm_calls += t.llm_calls

        assert total_llm_calls == 1, f"10 次查询只应调 1 次 LLM, 实际 {total_llm_calls}"

    def test_cache_saving_ratio_over_batch(self):
        """Batch with 50% repeats saves roughly 50% of external calls."""
        mgr = make_redis_manager()
        questions = ["Q1", "Q2", "Q3", "Q4", "Q5"] * 2

        total_external = 0
        for q in questions:
            t = CostTracker()
            deps = build_tracked_mocks(t)
            mgr.get_or_compute(q, lambda q=q, deps=deps: perform_rag_tracked(q, *deps))
            total_external += t.total_external_calls

        # Without a cache every question would cost 6 external calls.
        no_cache_total = len(questions) * 6
        saving_ratio = 1 - (total_external / no_cache_total)

        assert saving_ratio >= 0.4, (
            f"缓存节省率 {saving_ratio:.1%}, 预期 ≥40%"
            f"\n 实际总调用: {total_external}, 无缓存总调用: {no_cache_total}"
        )

    def test_cache_saving_dollar_estimate(self):
        """Estimate the daily USD saved by the cache."""
        t = run_tracked_query()
        cost_per_query = t.estimated_cost_usd()
        assert cost_per_query > 0

        # Assumed traffic profile: 1000 queries/day, 50% cache hit rate.
        daily_queries = 1000
        hit_rate = 0.5
        daily_cost_no_cache = daily_queries * cost_per_query
        daily_cost_with_cache = daily_queries * (1 - hit_rate) * cost_per_query
        daily_savings = daily_cost_no_cache - daily_cost_with_cache

        assert daily_savings > 0, "缓存应节省费用"
        # Floating-point arithmetic: compare approximately, never with exact ==.
        assert daily_savings == pytest.approx(daily_cost_no_cache * hit_rate)
|
|
|
| |
| |
| |
|
|
class TestDegradedCost:
    """
    Component failures affect cost as well as quality:
    partial degradation means fewer calls and lower cost (but lower quality too).
    """

    def test_neo4j_down_saves_three_calls(self):
        """Neo4j down: 3 calls saved (generate + validate + session.run)."""
        t_normal = run_tracked_query(neo4j_fail=False)
        t_degraded = run_tracked_query(neo4j_fail=True)

        saved = t_normal.total_external_calls - t_degraded.total_external_calls
        assert saved == 3, f"Neo4j 宕机应节省 3 次调用, 实际节省 {saved}"

    def test_degraded_cost_is_lower(self):
        """Degraded mode: the prompt is strictly shorter (no Neo4j context)."""
        t_normal = run_tracked_query(neo4j_fail=False)
        t_degraded = run_tracked_query(neo4j_fail=True)

        # Strict: `<=` would also pass if the Neo4j context were never included.
        assert t_degraded.prompt_chars < t_normal.prompt_chars, (
            f"降级时 prompt 应更短: 降级={t_degraded.prompt_chars}, 正常={t_normal.prompt_chars}"
        )

    def test_llm_still_called_once_even_when_degraded(self):
        """Degraded mode still calls the LLM exactly once."""
        t = run_tracked_query(neo4j_fail=True)
        assert t.llm_calls == 1, "降级时 LLM 仍应只调 1 次"

    def test_cost_comparison_normal_vs_degraded(self):
        """Degraded mode costs strictly less than normal mode."""
        t_normal = run_tracked_query(neo4j_fail=False)
        t_degraded = run_tracked_query(neo4j_fail=True)

        cost_normal = t_normal.estimated_cost_usd()
        cost_degraded = t_degraded.estimated_cost_usd()

        # A shorter prompt with the same response must be strictly cheaper.
        assert cost_degraded < cost_normal, (
            f"降级费用 ${cost_degraded:.6f} 应 < 正常费用 ${cost_normal:.6f}"
        )
|
|
|
|
| |
| |
| |
|
|
class TestCostEfficiencyReport:
    """Produce a human-readable cost and efficiency report."""

    def test_single_query_cost_breakdown(self):
        """Single query: calls, tokens and cost are all populated."""
        t = run_tracked_query()

        assert t.total_external_calls > 0
        assert t.estimated_total_tokens > 0
        # `>= 0` was vacuous; a query with tokens must have a positive cost.
        assert t.estimated_cost_usd() > 0

    def test_batch_efficiency_metrics(self):
        """Average calls, tokens and cost over a batch of queries."""
        trackers = [run_tracked_query(f"问题{i}") for i in range(10)]

        avg_calls = sum(t.total_external_calls for t in trackers) / len(trackers)
        avg_tokens = sum(t.estimated_total_tokens for t in trackers) / len(trackers)
        avg_cost = sum(t.estimated_cost_usd() for t in trackers) / len(trackers)

        assert avg_calls == 6, f"平均调用次数应为 6, 实际 {avg_calls}"
        assert avg_tokens > 0, "平均 token 应 > 0"
        assert avg_cost > 0, "平均费用应 > 0"

    def test_model_cost_comparison(self):
        """Model cost comparison: gpt-4o-mini vs gpt-4o."""
        t = run_tracked_query()

        cost_mini = t.estimated_cost_usd("gpt-4o-mini")
        cost_4o = t.estimated_cost_usd("gpt-4o")

        assert cost_4o > cost_mini, "gpt-4o 应比 gpt-4o-mini 贵"
        ratio = cost_4o / cost_mini if cost_mini > 0 else float('inf')
        assert ratio > 5, f"gpt-4o 应比 mini 贵 5 倍以上, 实际 {ratio:.1f} 倍"

    def test_cost_report_printout(self):
        """Print the full cost and efficiency report (visible with ``pytest -s``).

        This test deliberately does not request ``capsys``. That fixture
        captures stdout even under ``-s``, which would hide the report.
        """
        query = "高血压不能吃什么?"
        t = run_tracked_query(query)
        cost_mini = t.estimated_cost_usd('gpt-4o-mini')
        cost_4o = t.estimated_cost_usd('gpt-4o')

        print("\n")
        print("=" * 70)
        print(" 医疗 RAG Agent — Cost & Efficiency 报告")
        print("=" * 70)

        print(f"\n 📋 查询: '{query}'")

        print("\n ── 外部调用明细 ──")
        print(f" Milvus 向量搜索: {t.milvus_calls} 次")
        print(f" PDF 父子检索: {t.pdf_calls} 次")
        print(f" Cypher /generate: {t.cypher_generate_calls} 次")
        print(f" Cypher /validate: {t.cypher_validate_calls} 次")
        print(f" Neo4j session.run: {t.neo4j_session_calls} 次")
        print(f" LLM 推理: {t.llm_calls} 次")
        print(" ────────────────────────────")
        print(f" 总外部调用: {t.total_external_calls} 次")

        print("\n ── Token 消耗 ──")
        print(f" Prompt: ~{t.estimated_prompt_tokens} tokens ({t.prompt_chars} 字符)")
        print(f" Response: ~{t.estimated_response_tokens} tokens ({t.response_chars} 字符)")
        print(f" 总计: ~{t.estimated_total_tokens} tokens")

        print("\n ── 费用估算 (per query) ──")
        print(f" gpt-4o-mini: ${cost_mini:.6f}")
        print(f" gpt-4o: ${cost_4o:.6f}")

        # Monthly projection under an assumed traffic profile.
        daily = 1000
        monthly = daily * 30
        hit_rate = 0.5
        effective_queries = monthly * (1 - hit_rate)
        print(f"\n ── 月度预估 (日均 {daily} 查询, 缓存命中率 {hit_rate:.0%}) ──")
        print(f" 有效 LLM 调用: {int(effective_queries)} 次/月")
        print(f" gpt-4o-mini 月费: ${effective_queries * cost_mini:.2f}")
        print(f" gpt-4o 月费: ${effective_queries * cost_4o:.2f}")
        print(f" 缓存节省: {hit_rate:.0%} ({int(monthly * hit_rate)} 次 LLM 调用)")

        t_deg = run_tracked_query(neo4j_fail=True)
        print("\n ── 降级场景对比 ──")
        print(f" 正常: {t.total_external_calls} 次调用, ~{t.estimated_total_tokens} tokens, ${t.estimated_cost_usd():.6f}")
        print(f" 降级: {t_deg.total_external_calls} 次调用, ~{t_deg.estimated_total_tokens} tokens, ${t_deg.estimated_cost_usd():.6f}")

        print("=" * 70)

        # The report is informational, but sanity-check the numbers it shows
        # instead of a no-op `assert True`.
        assert t.total_external_calls == 6
        assert t_deg.total_external_calls < t.total_external_calls
        assert cost_4o > cost_mini > 0
|
|
|
|
| |
if __name__ == "__main__":
    # Propagate pytest's exit status so `python test6.py` fails (non-zero) on test failures.
    sys.exit(pytest.main([__file__, "-v", "--tb=short", "-s"]))