Spaces:
Running
Running
| # test_agents.py | |
| # ============================================================ | |
| # 类型:测试工具(乙负责) | |
| # 功能:一键运行两个 Agent 的自测,验证基础功能正常 | |
| # 用法:python test_agents.py | |
| # ============================================================ | |
| import sys | |
| from llm_utils import fix_windows_encoding | |
# Simulated GitHub search results, used as the ``repos`` input for the
# Agent 1 test cases. Each entry carries the same six keys: full_name,
# html_url, description, stars, language and topics.
MOCK_REPOS = [
    {
        "full_name": "openvinotoolkit/anomalib",
        "html_url": "https://github.com/openvinotoolkit/anomalib",
        "description": (
            "An anomaly detection library comprising "
            "state-of-the-art algorithms."
        ),
        "stars": 4000,
        "language": "Python",
        "topics": [
            "anomaly-detection",
            "pytorch",
            "benchmarking",
        ],
    },
    {
        "full_name": "hcw-00/PatchCore",
        "html_url": "https://github.com/hcw-00/PatchCore",
        "description": (
            "Official implementation of PatchCore: Towards Total Recall "
            "in Industrial Anomaly Detection."
        ),
        "stars": 850,
        "language": "Python",
        "topics": [
            "memory-bank",
            "anomaly-detection",
            "pytorch",
        ],
    },
    {
        "full_name": "taikiinoue45/STFPM",
        "html_url": "https://github.com/taikiinoue45/STFPM",
        "description": (
            "Student-Teacher Feature Pyramid Matching "
            "for Anomaly Detection (PyTorch)."
        ),
        "stars": 320,
        "language": "Python",
        "topics": [
            "teacher-student",
            "anomaly-detection",
        ],
    },
    {
        "full_name": "VitjanZ/DRAEM",
        "html_url": "https://github.com/VitjanZ/DRAEM",
        "description": (
            "DRAEM: Discriminatively trained reconstruction embedding "
            "for surface anomaly detection."
        ),
        "stars": 180,
        "language": "Python",
        "topics": [
            "synthetic-defect",
            "anomaly-detection",
            "reconstruction",
        ],
    },
    {
        "full_name": "marugoto/CFLow",
        "html_url": "https://github.com/marugoto/CFLow",
        "description": (
            "CFLow: Real-time unsupervised anomaly detection "
            "via conditional normalizing flows."
        ),
        "stars": 130,
        "language": "Python",
        "topics": [
            "normalizing-flow",
            "anomaly-detection",
        ],
    },
    {
        # The only non-Python entry: a curated paper/resource list.
        "full_name": "someone/anomaly-tutorial",
        "html_url": "https://github.com/someone/anomaly-tutorial",
        "description": (
            "A curated list of anomaly detection papers and resources."
        ),
        "stars": 45,
        "language": "Markdown",
        "topics": [
            "awesome-list",
            "anomaly-detection",
        ],
    },
]
def test_agent1():
    """Self-test for Agent 1 (``direction_analyzer``): repo-based method-family analysis.

    Calls ``analyze_direction(title, abstract, repos)`` once per test case and
    runs that case's named predicate checks against the returned result. A
    case passes only if every one of its checks passes; an exception raised
    by the agent call (or while printing its summary) fails the case.

    Each check is evaluated in isolation: a predicate that raises (for example
    ``len(None)`` when a field comes back as ``None``) is reported as a failed
    check, and the remaining checks of that case still run.

    Returns:
        tuple[int, int]: ``(passed, failed)`` counts over the test cases.
    """
    import traceback

    from direction_analyzer import analyze_direction

    print("=" * 60)
    print("Agent 1 测试:基于 GitHub 仓库的方法族归纳")
    print("=" * 60)

    test_cases = [
        {
            # Full mock repository list.
            "name": "PaDiM + 真实仓库数据",
            "title": "PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization",
            "abstract": (
                "We present a new framework for anomaly detection and localization based on "
                "patch distribution modeling. PaDiM uses pretrained convolutional neural networks "
                "to extract patch features and models their distribution using multivariate "
                "Gaussian distributions. It achieves state-of-the-art results on MVTec AD and "
                "STC datasets for industrial defect detection."
            ),
            "repos": MOCK_REPOS,
            "checks": [
                ("subfield 非空", lambda r: bool(r.get("subfield"))),
                ("有 >= 2 个方法族", lambda r: len(r.get("method_families", [])) >= 2),
                ("至少 1 个方法族有 matched_repos", lambda r: any(
                    len(mf.get("matched_repos", [])) >= 1
                    for mf in r.get("method_families", [])
                )),
                ("有 broad_queries", lambda r: len(r.get("broad_queries", [])) >= 1),
            ],
        },
        {
            "name": "Transformer (降级模式:无仓库数据)",
            "title": "Attention Is All You Need",
            "abstract": (
                "The dominant sequence transduction models are based on complex recurrent "
                "or convolutional neural networks... We propose a new simple network "
                "architecture, the Transformer, based solely on attention mechanisms."
            ),
            "repos": None,  # degraded mode: no repository data is supplied
            "checks": [
                ("subfield 非空", lambda r: bool(r.get("subfield"))),
                ("有 >= 3 个方法族", lambda r: len(r.get("method_families", [])) >= 3),
                ("有 broad_queries", lambda r: len(r.get("broad_queries", [])) >= 1),
            ],
        },
        {
            "name": "EfficientAD (轻量仓库数据)",
            "title": "EfficientAD: Accurate Visual Anomaly Detection at Millisecond-Level Latencies",
            "abstract": (
                "We propose EfficientAD, a lightweight anomaly detection method that achieves "
                "millisecond-level latency. It uses a teacher-student architecture with a simple "
                "student network that learns to mimic the teacher's features on normal images only."
            ),
            "repos": MOCK_REPOS[:4],  # only the first 4 repositories
            "checks": [
                ("subfield 非空", lambda r: bool(r.get("subfield"))),
                ("有 subfield_trend", lambda r: len(r.get("subfield_trend", "")) >= 10),
                ("方法族有 search_queries", lambda r: any(
                    len(mf.get("search_queries", [])) >= 1
                    for mf in r.get("method_families", [])
                )),
            ],
        },
    ]

    passed = 0
    failed = 0
    for tc in test_cases:
        print(f"\n [{tc['name']}]")

        # Harness boundary: any failure of the agent call or of the summary
        # printing fails this case but must not stop the remaining cases.
        try:
            result = analyze_direction(tc["title"], tc["abstract"], tc["repos"])
            subfield = result.get("subfield", "?")
            families = result.get("method_families", [])
            print(f" 子领域: {subfield}")
            print(" 方法族:")
            for mf in families:
                print(f" - {mf.get('family_name', '?')}: matched_repos={mf.get('matched_repos', [])}")
        except Exception as e:
            print(f" ❌ 异常: {e}")
            traceback.print_exc()
            failed += 1
            continue

        tc_failed = False
        for check_name, check_fn in tc["checks"]:
            # Isolate every predicate so that one raising check (e.g. on a
            # None-valued field) does not hide the outcome of the others.
            try:
                ok = check_fn(result)
            except Exception as e:
                print(f" ❌ {check_name} (检查异常: {e!r})")
                tc_failed = True
                continue
            if ok:
                print(f" ✅ {check_name}")
            else:
                print(f" ❌ {check_name}")
                tc_failed = True

        if tc_failed:
            failed += 1
        else:
            passed += 1

    return passed, failed
def test_agent2():
    """Self-test for Agent 2 (``repo_evaluator``): repository evaluation.

    Calls ``evaluate_repo(repo, readme, deps, family)`` once per test case,
    prints the returned scores, and runs that case's named predicate checks
    against the result. A case passes only if every one of its checks passes;
    an exception raised by the agent call (or while printing its summary)
    fails the case.

    Each check is evaluated in isolation: a predicate that raises (for example
    comparing a ``None`` score with an int) is reported as a failed check, and
    the remaining checks of that case still run.

    Returns:
        tuple[int, int]: ``(passed, failed)`` counts over the test cases.
    """
    import traceback

    from repo_evaluator import evaluate_repo

    print("\n" + "=" * 60)
    print("Agent 2 测试:仓库评估")
    print("=" * 60)

    # NOTE(review): several predicates use ``r.get(key, default)`` and so pass
    # vacuously when the key is missing (e.g. ``overall_score < 50`` with the
    # default 0, ``verdict != "error"`` with None). Confirm whether a missing
    # field should count as a failure instead.
    test_cases = [
        {
            # Expected to score high.
            "name": "anomalib (预期高分)",
            "repo": {
                "full_name": "openvinotoolkit/anomalib",
                "html_url": "https://github.com/openvinotoolkit/anomalib",
                "stars": 4000,
                "description": "An anomaly detection library comprising state-of-the-art algorithms.",
                "updated_at": "2026-01-15T00:00:00Z",
                "language": "Python",
            },
            "readme": (
                "# Anomalib\nA library for benchmarking anomaly detection.\n\n"
                "## Installation\n```\npip install anomalib\n```\n\n"
                "## Training\n```python\nfrom anomalib.engine import Engine\nengine = Engine()\nengine.train()\n```\n\n"
                "## Supported Datasets\n- MVTec AD\n- BTAD\n\n"
                "## Benchmarking\nUse `tools/benchmark.py` for standardized evaluation."
            ),
            "deps": {"requirements.txt": "torch>=1.10\npytorch-lightning>=1.7\nopencv-python\nnumpy"},
            "family": "Patch Distribution Modeling",
            "checks": [
                ("overall_score >= 55", lambda r: r.get("overall_score", 0) >= 55),
                ("verdict 非 error", lambda r: r.get("verdict") in ("reproducible", "partially")),
                ("benchmark_readiness 非 not_ready", lambda r: r.get("benchmark_readiness") in ("ready", "partial")),
                ("reasoning 非空", lambda r: len(r.get("reasoning", "")) >= 10),
            ],
        },
        {
            # Inference-only demo project, expected to score low.
            "name": "推理 Demo 项目 (预期低分)",
            "repo": {
                "full_name": "someone/anomaly-demo",
                "html_url": "https://github.com/someone/anomaly-demo",
                "stars": 15,
                "description": "A simple demo of anomaly detection.",
                "updated_at": "2024-03-01T00:00:00Z",
                "language": "Python",
            },
            "readme": "# Anomaly Detection Demo\nJust a demo.\n\n```\npython demo.py --image path/to/image.jpg\n```\nPretrained weights: Google Drive link.",
            "deps": {},
            "family": "",
            "checks": [
                ("overall_score < 50", lambda r: r.get("overall_score", 0) < 50),
                ("benchmark_readiness 为 not_ready", lambda r: r.get("benchmark_readiness") == "not_ready"),
                ("有 risks 列表", lambda r: isinstance(r.get("risks"), list)),
            ],
        },
        {
            # Mid-range project, expected to be partially reproducible.
            "name": "中等项目 (部分可复现)",
            "repo": {
                "full_name": "research-lab/anomaly-pytorch",
                "html_url": "https://github.com/research-lab/anomaly-pytorch",
                "stars": 200,
                "description": "PyTorch implementation of anomaly detection methods.",
                "updated_at": "2025-08-01T00:00:00Z",
                "language": "Python",
            },
            "readme": (
                "# Anomaly Detection in PyTorch\n\n"
                "## Install\n```\npip install -r requirements.txt\n```\n\n"
                "## Train\n```\npython train.py --config config.yaml\n```"
            ),
            "deps": {"requirements.txt": "torch>=1.8\nnumpy\nopencv-python\nmatplotlib"},
            "family": "Teacher-Student",
            "checks": [
                ("overall_score 在 40-85 之间", lambda r: 40 <= r.get("overall_score", 0) <= 85),
                ("verdict 非 error", lambda r: r.get("verdict") != "error"),
            ],
        },
    ]

    passed = 0
    failed = 0
    for tc in test_cases:
        print(f"\n [{tc['name']}]")

        # Harness boundary: any failure of the agent call or of the summary
        # printing fails this case but must not stop the remaining cases.
        try:
            result = evaluate_repo(tc["repo"], tc["readme"], tc["deps"], tc["family"])
            print(f" 综合: {result.get('overall_score')}/100")
            print(f" 可复现性: {result.get('reproducibility_score')}/80")
            print(f" 适配度: {result.get('benchmark_fitness_score')}/20")
            print(f" 判定: {result.get('verdict')} | 就绪: {result.get('benchmark_readiness')}")
        except Exception as e:
            print(f" ❌ 异常: {e}")
            traceback.print_exc()
            failed += 1
            continue

        tc_failed = False
        for check_name, check_fn in tc["checks"]:
            # Isolate every predicate so that one raising check (e.g. a None
            # score compared with an int) does not hide the others.
            try:
                ok = check_fn(result)
            except Exception as e:
                print(f" ❌ {check_name} (检查异常: {e!r})")
                tc_failed = True
                continue
            if ok:
                print(f" ✅ {check_name}")
            else:
                print(f" ❌ {check_name}")
                tc_failed = True

        if tc_failed:
            failed += 1
        else:
            passed += 1

    return passed, failed
# ============================================================
# Entry point
# ============================================================
if __name__ == "__main__":
    fix_windows_encoding()

    divider = "=" * 60

    print("ResearchRadar · Agent 自检套件")
    print(divider)

    # Run both agent self-tests; each returns a (passed, failed) pair.
    a1_pass, a1_fail = test_agent1()
    a2_pass, a2_fail = test_agent2()
    total_pass = a1_pass + a2_pass
    total_fail = a1_fail + a2_fail

    print("\n" + divider)
    print("测试汇总")
    print(divider)
    print(f" Agent 1 (方法族归纳): {a1_pass} 通过, {a1_fail} 失败")
    print(f" Agent 2 (仓库评估): {a2_pass} 通过, {a2_fail} 失败")
    print(f" 总计: {total_pass} 通过, {total_fail} 失败")

    # The process exit status mirrors the outcome: 1 if any case failed, else 0.
    any_failed = total_fail > 0
    if any_failed:
        print("\n⚠ 部分测试未通过,请检查对应 Agent 的输出。")
    else:
        print("\n✅ 所有测试通过!")
    sys.exit(1 if any_failed else 0)