# test_agents.py
# ============================================================
# Kind: test utility (maintained by 乙)
# Purpose: run the self-tests of both agents in one go and verify that their
#          basic functionality works.
# Usage: python test_agents.py
#
# NOTE: the file name and the test_agent1 / test_agent2 function names match
# pytest's default discovery pattern. Those functions return (passed, failed)
# tuples instead of asserting, so under pytest they cannot fail on a bad
# agent result — run this file directly as a script.
# ============================================================

import sys
import traceback

from llm_utils import fix_windows_encoding

# Mock GitHub search results (fed to the Agent 1 tests).
MOCK_REPOS = [
    {
        "full_name": "openvinotoolkit/anomalib",
        "html_url": "https://github.com/openvinotoolkit/anomalib",
        "description": "An anomaly detection library comprising state-of-the-art algorithms.",
        "stars": 4000,
        "language": "Python",
        "topics": ["anomaly-detection", "pytorch", "benchmarking"],
    },
    {
        "full_name": "hcw-00/PatchCore",
        "html_url": "https://github.com/hcw-00/PatchCore",
        "description": "Official implementation of PatchCore: Towards Total Recall in Industrial Anomaly Detection.",
        "stars": 850,
        "language": "Python",
        "topics": ["memory-bank", "anomaly-detection", "pytorch"],
    },
    {
        "full_name": "taikiinoue45/STFPM",
        "html_url": "https://github.com/taikiinoue45/STFPM",
        "description": "Student-Teacher Feature Pyramid Matching for Anomaly Detection (PyTorch).",
        "stars": 320,
        "language": "Python",
        "topics": ["teacher-student", "anomaly-detection"],
    },
    {
        "full_name": "VitjanZ/DRAEM",
        "html_url": "https://github.com/VitjanZ/DRAEM",
        "description": "DRAEM: Discriminatively trained reconstruction embedding for surface anomaly detection.",
        "stars": 180,
        "language": "Python",
        "topics": ["synthetic-defect", "anomaly-detection", "reconstruction"],
    },
    {
        "full_name": "marugoto/CFLow",
        "html_url": "https://github.com/marugoto/CFLow",
        "description": "CFLow: Real-time unsupervised anomaly detection via conditional normalizing flows.",
        "stars": 130,
        "language": "Python",
        "topics": ["normalizing-flow", "anomaly-detection"],
    },
    {
        "full_name": "someone/anomaly-tutorial",
        "html_url": "https://github.com/someone/anomaly-tutorial",
        "description": "A curated list of anomaly detection papers and resources.",
        "stars": 45,
        "language": "Markdown",
        "topics": ["awesome-list", "anomaly-detection"],
    },
]


# ============================================================
# Shared helpers
# ============================================================

def _score(result, key="overall_score"):
    """Return ``result[key]`` if it is a real number, otherwise NaN.

    Every ordering comparison against NaN is False, so a missing or
    non-numeric score fails *every* bound check (>=, <, range). Reading it
    with ``result.get(key, 0)`` instead would let an "expected low score"
    check (``< 50``) pass on a result that has no score at all.
    """
    value = result.get(key)
    # bool is a subclass of int, but True/False is never a meaningful score.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return float("nan")
    return value


def _run_checks(result, checks):
    """Print ✅/❌ for each ``(name, predicate)`` pair in *checks*.

    A predicate that raises (e.g. ``len(None)`` when the agent returned a
    null field) is reported as that single check failing; the remaining
    checks still run.

    Returns:
        True if every predicate returned a truthy value, else False.
    """
    all_ok = True
    for check_name, check_fn in checks:
        try:
            ok = bool(check_fn(result))
        except Exception as exc:  # malformed agent output fails the check, not the case
            print(f" ❌ {check_name} ({type(exc).__name__}: {exc})")
            all_ok = False
            continue
        if ok:
            print(f" ✅ {check_name}")
        else:
            print(f" ❌ {check_name}")
            all_ok = False
    return all_ok


def _run_cases(test_cases, run_case, describe):
    """Run each test case, print its summary and checks, and tally results.

    Args:
        test_cases: dicts with at least ``"name"`` and ``"checks"`` keys.
        run_case: callable taking one test-case dict and returning the
            agent's result dict.
        describe: callable that prints a short summary of a result.

    Returns:
        ``(passed, failed)`` counts of test cases. A case passes only if
        every one of its checks passes; an exception raised by the agent
        call (or while describing its result) marks the case failed and
        prints a traceback.
    """
    passed = 0
    failed = 0
    for tc in test_cases:
        print(f"\n [{tc['name']}]")
        try:
            result = run_case(tc)
            describe(result)
            case_ok = _run_checks(result, tc["checks"])
        except Exception as e:
            print(f" ❌ 异常: {e}")
            traceback.print_exc()
            case_ok = False
        if case_ok:
            passed += 1
        else:
            failed += 1
    return passed, failed


# ============================================================
# Agent 1: direction_analyzer
# ============================================================

def _describe_direction(result):
    """Print the sub-field and method families returned by Agent 1."""
    subfield = result.get("subfield", "?")
    families = result.get("method_families", [])
    print(f" 子领域: {subfield}")
    print(" 方法族:")
    for mf in families:
        print(f" - {mf.get('family_name', '?')}: matched_repos={mf.get('matched_repos', [])}")


def test_agent1():
    """Test Agent 1: direction_analyzer (new version — infers method families from repo data).

    Returns:
        ``(passed, failed)`` test-case counts.
    """
    from direction_analyzer import analyze_direction

    print("=" * 60)
    print("Agent 1 测试:基于 GitHub 仓库的方法族归纳")
    print("=" * 60)

    test_cases = [
        {
            "name": "PaDiM + 真实仓库数据",
            "title": "PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization",
            "abstract": (
                "We present a new framework for anomaly detection and localization based on "
                "patch distribution modeling. PaDiM uses pretrained convolutional neural networks "
                "to extract patch features and models their distribution using multivariate "
                "Gaussian distributions. It achieves state-of-the-art results on MVTec AD and "
                "STC datasets for industrial defect detection."
            ),
            "repos": MOCK_REPOS,
            "checks": [
                ("subfield 非空", lambda r: bool(r.get("subfield"))),
                ("有 >= 2 个方法族", lambda r: len(r.get("method_families", [])) >= 2),
                ("至少 1 个方法族有 matched_repos", lambda r: any(
                    len(mf.get("matched_repos", [])) >= 1
                    for mf in r.get("method_families", [])
                )),
                ("有 broad_queries", lambda r: len(r.get("broad_queries", [])) >= 1),
            ],
        },
        {
            "name": "Transformer (降级模式:无仓库数据)",
            "title": "Attention Is All You Need",
            "abstract": (
                "The dominant sequence transduction models are based on complex recurrent "
                "or convolutional neural networks... We propose a new simple network "
                "architecture, the Transformer, based solely on attention mechanisms."
            ),
            "repos": None,  # degraded mode: no repository data
            "checks": [
                ("subfield 非空", lambda r: bool(r.get("subfield"))),
                ("有 >= 3 个方法族", lambda r: len(r.get("method_families", [])) >= 3),
                ("有 broad_queries", lambda r: len(r.get("broad_queries", [])) >= 1),
            ],
        },
        {
            "name": "EfficientAD (轻量仓库数据)",
            "title": "EfficientAD: Accurate Visual Anomaly Detection at Millisecond-Level Latencies",
            "abstract": (
                "We propose EfficientAD, a lightweight anomaly detection method that achieves "
                "millisecond-level latency. It uses a teacher-student architecture with a simple "
                "student network that learns to mimic the teacher's features on normal images only."
            ),
            "repos": MOCK_REPOS[:4],  # only pass the first 4 repositories
            "checks": [
                ("subfield 非空", lambda r: bool(r.get("subfield"))),
                ("有 subfield_trend", lambda r: len(r.get("subfield_trend", "")) >= 10),
                ("方法族有 search_queries", lambda r: any(
                    len(mf.get("search_queries", [])) >= 1
                    for mf in r.get("method_families", [])
                )),
            ],
        },
    ]

    return _run_cases(
        test_cases,
        lambda tc: analyze_direction(tc["title"], tc["abstract"], tc["repos"]),
        _describe_direction,
    )


# ============================================================
# Agent 2: repo_evaluator
# ============================================================

def _describe_evaluation(result):
    """Print the scores and verdict returned by Agent 2."""
    print(f" 综合: {result.get('overall_score')}/100")
    print(f" 可复现性: {result.get('reproducibility_score')}/80")
    print(f" 适配度: {result.get('benchmark_fitness_score')}/20")
    print(f" 判定: {result.get('verdict')} | 就绪: {result.get('benchmark_readiness')}")


def test_agent2():
    """Test Agent 2: repo_evaluator.

    Returns:
        ``(passed, failed)`` test-case counts.
    """
    from repo_evaluator import evaluate_repo

    print("\n" + "=" * 60)
    print("Agent 2 测试:仓库评估")
    print("=" * 60)

    test_cases = [
        {
            "name": "anomalib (预期高分)",
            "repo": {
                "full_name": "openvinotoolkit/anomalib",
                "html_url": "https://github.com/openvinotoolkit/anomalib",
                "stars": 4000,
                "description": "An anomaly detection library comprising state-of-the-art algorithms.",
                "updated_at": "2026-01-15T00:00:00Z",
                "language": "Python",
            },
            "readme": (
                "# Anomalib\nA library for benchmarking anomaly detection.\n\n"
                "## Installation\n```\npip install anomalib\n```\n\n"
                "## Training\n```python\nfrom anomalib.engine import Engine\nengine = Engine()\nengine.train()\n```\n\n"
                "## Supported Datasets\n- MVTec AD\n- BTAD\n\n"
                "## Benchmarking\nUse `tools/benchmark.py` for standardized evaluation."
            ),
            "deps": {"requirements.txt": "torch>=1.10\npytorch-lightning>=1.7\nopencv-python\nnumpy"},
            "family": "Patch Distribution Modeling",
            "checks": [
                ("overall_score >= 55", lambda r: _score(r) >= 55),
                ("verdict 非 error", lambda r: r.get("verdict") in ("reproducible", "partially")),
                ("benchmark_readiness 非 not_ready",
                 lambda r: r.get("benchmark_readiness") in ("ready", "partial")),
                ("reasoning 非空", lambda r: len(r.get("reasoning", "")) >= 10),
            ],
        },
        {
            "name": "推理 Demo 项目 (预期低分)",
            "repo": {
                "full_name": "someone/anomaly-demo",
                "html_url": "https://github.com/someone/anomaly-demo",
                "stars": 15,
                "description": "A simple demo of anomaly detection.",
                "updated_at": "2024-03-01T00:00:00Z",
                "language": "Python",
            },
            "readme": "# Anomaly Detection Demo\nJust a demo.\n\n```\npython demo.py --image path/to/image.jpg\n```\nPretrained weights: Google Drive link.",
            "deps": {},
            "family": "",
            "checks": [
                # _score() yields NaN for a missing score, so an empty result no
                # longer passes this "expected low score" check by default.
                ("overall_score < 50", lambda r: _score(r) < 50),
                ("benchmark_readiness 为 not_ready", lambda r: r.get("benchmark_readiness") == "not_ready"),
                ("有 risks 列表", lambda r: isinstance(r.get("risks"), list)),
            ],
        },
        {
            "name": "中等项目 (部分可复现)",
            "repo": {
                "full_name": "research-lab/anomaly-pytorch",
                "html_url": "https://github.com/research-lab/anomaly-pytorch",
                "stars": 200,
                "description": "PyTorch implementation of anomaly detection methods.",
                "updated_at": "2025-08-01T00:00:00Z",
                "language": "Python",
            },
            "readme": (
                "# Anomaly Detection in PyTorch\n\n"
                "## Install\n```\npip install -r requirements.txt\n```\n\n"
                "## Train\n```\npython train.py --config config.yaml\n```"
            ),
            "deps": {"requirements.txt": "torch>=1.8\nnumpy\nopencv-python\nmatplotlib"},
            "family": "Teacher-Student",
            "checks": [
                ("overall_score 在 40-85 之间", lambda r: 40 <= _score(r) <= 85),
                ("verdict 非 error", lambda r: r.get("verdict") != "error"),
            ],
        },
    ]

    return _run_cases(
        test_cases,
        lambda tc: evaluate_repo(tc["repo"], tc["readme"], tc["deps"], tc["family"]),
        _describe_evaluation,
    )


# ============================================================
# Entry point
# ============================================================
if __name__ == "__main__":
    fix_windows_encoding()

    print("ResearchRadar · Agent 自检套件")
    print("=" * 60)

    a1_pass, a1_fail = test_agent1()
    a2_pass, a2_fail = test_agent2()

    print("\n" + "=" * 60)
    print("测试汇总")
    print("=" * 60)
    print(f" Agent 1 (方法族归纳): {a1_pass} 通过, {a1_fail} 失败")
    print(f" Agent 2 (仓库评估): {a2_pass} 通过, {a2_fail} 失败")
    total_pass = a1_pass + a2_pass
    total_fail = a1_fail + a2_fail
    print(f" 总计: {total_pass} 通过, {total_fail} 失败")

    # Non-zero exit status lets CI / shell scripts detect a failing suite.
    if total_fail > 0:
        print("\n⚠ 部分测试未通过,请检查对应 Agent 的输出。")
        sys.exit(1)
    else:
        print("\n✅ 所有测试通过!")
        sys.exit(0)