# ResearchRadar / test_agents.py
# ZZZyx3587's picture
# Upload folder using huggingface_hub
# 03c63b9 verified
# test_agents.py
# ============================================================
# 类型:测试工具(乙负责)
# 功能:一键运行两个 Agent 的自测,验证基础功能正常
# 用法:python test_agents.py
# ============================================================
import sys
from llm_utils import fix_windows_encoding
# Mock GitHub search results (fixture consumed by the Agent 1 tests).
# Every entry's html_url is "https://github.com/<full_name>", so it is derived
# from the repository name rather than repeated for each row.
MOCK_REPOS = [
    {
        "full_name": full_name,
        "html_url": f"https://github.com/{full_name}",
        "description": description,
        "stars": stars,
        "language": language,
        "topics": topics,
    }
    for full_name, description, stars, language, topics in (
        (
            "openvinotoolkit/anomalib",
            "An anomaly detection library comprising state-of-the-art algorithms.",
            4000,
            "Python",
            ["anomaly-detection", "pytorch", "benchmarking"],
        ),
        (
            "hcw-00/PatchCore",
            "Official implementation of PatchCore: Towards Total Recall in Industrial Anomaly Detection.",
            850,
            "Python",
            ["memory-bank", "anomaly-detection", "pytorch"],
        ),
        (
            "taikiinoue45/STFPM",
            "Student-Teacher Feature Pyramid Matching for Anomaly Detection (PyTorch).",
            320,
            "Python",
            ["teacher-student", "anomaly-detection"],
        ),
        (
            "VitjanZ/DRAEM",
            "DRAEM: Discriminatively trained reconstruction embedding for surface anomaly detection.",
            180,
            "Python",
            ["synthetic-defect", "anomaly-detection", "reconstruction"],
        ),
        (
            "marugoto/CFLow",
            "CFLow: Real-time unsupervised anomaly detection via conditional normalizing flows.",
            130,
            "Python",
            ["normalizing-flow", "anomaly-detection"],
        ),
        (
            "someone/anomaly-tutorial",
            "A curated list of anomaly detection papers and resources.",
            45,
            "Markdown",
            ["awesome-list", "anomaly-detection"],
        ),
    )
]
def test_agent1():
    """Test Agent 1: direction_analyzer (new version: method families induced from repo data).

    For every test case this calls ``analyze_direction(title, abstract, repos)``,
    prints the reported sub-field and method families, and then evaluates the
    case's named checks against the returned result.

    Returns:
        A ``(passed, failed)`` tuple counting test cases. A case counts as
        failed when any of its checks is falsy or when an exception is raised
        while the case is being processed.
    """
    import traceback

    from direction_analyzer import analyze_direction

    def family_count(r):
        # Number of method families in the result (0 when the key is missing).
        return len(r.get("method_families", []))

    def some_family_lists(r, key):
        # True when at least one method family carries a non-empty list under `key`.
        return any(
            len(mf.get(key, [])) >= 1
            for mf in r.get("method_families", [])
        )

    # Checks shared by more than one test case.
    has_subfield = ("subfield 非空", lambda r: bool(r.get("subfield")))
    has_broad_queries = ("有 broad_queries", lambda r: len(r.get("broad_queries", [])) >= 1)

    rule = "=" * 60
    print(rule)
    print("Agent 1 测试:基于 GitHub 仓库的方法族归纳")
    print(rule)

    test_cases = [
        {
            "name": "PaDiM + 真实仓库数据",
            "title": "PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization",
            "abstract": (
                "We present a new framework for anomaly detection and localization based on "
                "patch distribution modeling. PaDiM uses pretrained convolutional neural networks "
                "to extract patch features and models their distribution using multivariate "
                "Gaussian distributions. It achieves state-of-the-art results on MVTec AD and "
                "STC datasets for industrial defect detection."
            ),
            "repos": MOCK_REPOS,
            "checks": [
                has_subfield,
                ("有 >= 2 个方法族", lambda r: family_count(r) >= 2),
                ("至少 1 个方法族有 matched_repos", lambda r: some_family_lists(r, "matched_repos")),
                has_broad_queries,
            ],
        },
        {
            "name": "Transformer (降级模式:无仓库数据)",
            "title": "Attention Is All You Need",
            "abstract": (
                "The dominant sequence transduction models are based on complex recurrent "
                "or convolutional neural networks... We propose a new simple network "
                "architecture, the Transformer, based solely on attention mechanisms."
            ),
            "repos": None,  # degraded mode: no repository data supplied
            "checks": [
                has_subfield,
                ("有 >= 3 个方法族", lambda r: family_count(r) >= 3),
                has_broad_queries,
            ],
        },
        {
            "name": "EfficientAD (轻量仓库数据)",
            "title": "EfficientAD: Accurate Visual Anomaly Detection at Millisecond-Level Latencies",
            "abstract": (
                "We propose EfficientAD, a lightweight anomaly detection method that achieves "
                "millisecond-level latency. It uses a teacher-student architecture with a simple "
                "student network that learns to mimic the teacher's features on normal images only."
            ),
            "repos": MOCK_REPOS[:4],  # only the first 4 repositories
            "checks": [
                has_subfield,
                ("有 subfield_trend", lambda r: len(r.get("subfield_trend", "")) >= 10),
                ("方法族有 search_queries", lambda r: some_family_lists(r, "search_queries")),
            ],
        },
    ]

    tally = {"passed": 0, "failed": 0}
    for case in test_cases:
        print(f"\n [{case['name']}]")
        try:
            result = analyze_direction(case["title"], case["abstract"], case["repos"])
            subfield = result.get("subfield", "?")
            families = result.get("method_families", [])
            print(f" 子领域: {subfield}")
            print(" 方法族:")
            for family in families:
                print(f" - {family.get('family_name', '?')}: matched_repos={family.get('matched_repos', [])}")

            # Run every check (no short-circuit) so each one is reported.
            failed_checks = 0
            for label, predicate in case["checks"]:
                if not predicate(result):
                    print(f" ❌ {label}")
                    failed_checks += 1
                    continue
                print(f" ✅ {label}")
        except Exception as exc:
            print(f" ❌ 异常: {exc}")
            traceback.print_exc()
            outcome = "failed"
        else:
            outcome = "failed" if failed_checks else "passed"
        tally[outcome] += 1

    return tally["passed"], tally["failed"]
def test_agent2():
    """Test Agent 2: repo_evaluator.

    Runs ``evaluate_repo(repo, readme, deps, family)`` on three hand-written
    repository fixtures, prints the scores found in the returned result and
    evaluates each fixture's named checks against it.

    Checks are written so that a field the evaluator failed to produce can
    never satisfy them: a missing or non-numeric ``overall_score`` fails every
    score check (including the "low score" one), and a missing ``verdict``
    fails the "verdict is not error" check.

    Returns:
        A ``(passed, failed)`` tuple counting test cases. A case counts as
        failed when any of its checks is falsy or when an exception is raised
        while the case is being processed.
    """
    import traceback

    from repo_evaluator import evaluate_repo

    def score_satisfies(r, predicate):
        # Apply `predicate` to overall_score. A missing, None, boolean or
        # otherwise non-numeric score never satisfies a check; previously a
        # missing score defaulted to 0 and made "overall_score < 50" pass.
        score = r.get("overall_score")
        if isinstance(score, bool) or not isinstance(score, (int, float)):
            return False
        return predicate(score)

    print("\n" + "=" * 60)
    print("Agent 2 测试:仓库评估")
    print("=" * 60)

    test_cases = [
        {
            "name": "anomalib (预期高分)",
            "repo": {
                "full_name": "openvinotoolkit/anomalib",
                "html_url": "https://github.com/openvinotoolkit/anomalib",
                "stars": 4000,
                "description": "An anomaly detection library comprising state-of-the-art algorithms.",
                "updated_at": "2026-01-15T00:00:00Z",
                "language": "Python",
            },
            "readme": (
                "# Anomalib\nA library for benchmarking anomaly detection.\n\n"
                "## Installation\n```\npip install anomalib\n```\n\n"
                "## Training\n```python\nfrom anomalib.engine import Engine\nengine = Engine()\nengine.train()\n```\n\n"
                "## Supported Datasets\n- MVTec AD\n- BTAD\n\n"
                "## Benchmarking\nUse `tools/benchmark.py` for standardized evaluation."
            ),
            "deps": {"requirements.txt": "torch>=1.10\npytorch-lightning>=1.7\nopencv-python\nnumpy"},
            "family": "Patch Distribution Modeling",
            "checks": [
                ("overall_score >= 55", lambda r: score_satisfies(r, lambda s: s >= 55)),
                ("verdict 非 error", lambda r: r.get("verdict") in ("reproducible", "partially")),
                ("benchmark_readiness 非 not_ready", lambda r: r.get("benchmark_readiness") in ("ready", "partial")),
                # `or ""` keeps an explicit None from raising inside len().
                ("reasoning 非空", lambda r: len(r.get("reasoning") or "") >= 10),
            ],
        },
        {
            "name": "推理 Demo 项目 (预期低分)",
            "repo": {
                "full_name": "someone/anomaly-demo",
                "html_url": "https://github.com/someone/anomaly-demo",
                "stars": 15,
                "description": "A simple demo of anomaly detection.",
                "updated_at": "2024-03-01T00:00:00Z",
                "language": "Python",
            },
            "readme": "# Anomaly Detection Demo\nJust a demo.\n\n```\npython demo.py --image path/to/image.jpg\n```\nPretrained weights: Google Drive link.",
            "deps": {},
            "family": "",
            "checks": [
                ("overall_score < 50", lambda r: score_satisfies(r, lambda s: s < 50)),
                ("benchmark_readiness 为 not_ready", lambda r: r.get("benchmark_readiness") == "not_ready"),
                ("有 risks 列表", lambda r: isinstance(r.get("risks"), list)),
            ],
        },
        {
            "name": "中等项目 (部分可复现)",
            "repo": {
                "full_name": "research-lab/anomaly-pytorch",
                "html_url": "https://github.com/research-lab/anomaly-pytorch",
                "stars": 200,
                "description": "PyTorch implementation of anomaly detection methods.",
                "updated_at": "2025-08-01T00:00:00Z",
                "language": "Python",
            },
            "readme": (
                "# Anomaly Detection in PyTorch\n\n"
                "## Install\n```\npip install -r requirements.txt\n```\n\n"
                "## Train\n```\npython train.py --config config.yaml\n```"
            ),
            "deps": {"requirements.txt": "torch>=1.8\nnumpy\nopencv-python\nmatplotlib"},
            "family": "Teacher-Student",
            "checks": [
                ("overall_score 在 40-85 之间", lambda r: score_satisfies(r, lambda s: 40 <= s <= 85)),
                # A missing verdict must not count as "not an error".
                ("verdict 非 error", lambda r: r.get("verdict") not in (None, "error")),
            ],
        },
    ]

    passed = 0
    failed = 0
    for tc in test_cases:
        print(f"\n [{tc['name']}]")
        try:
            result = evaluate_repo(tc["repo"], tc["readme"], tc["deps"], tc["family"])
            print(f" 综合: {result.get('overall_score')}/100")
            print(f" 可复现性: {result.get('reproducibility_score')}/80")
            print(f" 适配度: {result.get('benchmark_fitness_score')}/20")
            print(f" 判定: {result.get('verdict')} | 就绪: {result.get('benchmark_readiness')}")

            # Run every check (no short-circuit) so each one is reported.
            tc_failed = False
            for check_name, check_fn in tc["checks"]:
                if check_fn(result):
                    print(f" ✅ {check_name}")
                else:
                    print(f" ❌ {check_name}")
                    tc_failed = True
        except Exception as e:
            # Evaluator or check crashed: report it and count the case as failed.
            print(f" ❌ 异常: {e}")
            traceback.print_exc()
            failed += 1
        else:
            if tc_failed:
                failed += 1
            else:
                passed += 1
    return passed, failed
# ============================================================
# Main entry point
# ============================================================
if __name__ == "__main__":
    # Helper imported from llm_utils; presumably adjusts console encoding on
    # Windows so the non-ASCII output below prints correctly (TODO confirm).
    fix_windows_encoding()

    rule = "=" * 60
    print("ResearchRadar · Agent 自检套件")
    print(rule)

    # Agent 1 runs before Agent 2; each returns (passed, failed) case counts.
    a1_pass, a1_fail = test_agent1()
    a2_pass, a2_fail = test_agent2()

    print("\n" + rule)
    print("测试汇总")
    print(rule)
    print(f" Agent 1 (方法族归纳): {a1_pass} 通过, {a1_fail} 失败")
    print(f" Agent 2 (仓库评估): {a2_pass} 通过, {a2_fail} 失败")

    total_pass = a1_pass + a2_pass
    total_fail = a1_fail + a2_fail
    print(f" 总计: {total_pass} 通过, {total_fail} 失败")

    # Exit status is 1 when any case failed, 0 otherwise.
    if total_fail > 0:
        closing_line, exit_code = "\n⚠ 部分测试未通过,请检查对应 Agent 的输出。", 1
    else:
        closing_line, exit_code = "\n✅ 所有测试通过!", 0
    print(closing_line)
    sys.exit(exit_code)