Spaces:
Running
Running
| # direction_analyzer.py | |
| # ============================================================ | |
| # 类型:AGENT(乙负责) | |
| # 功能:基于 GitHub 搜索结果 + 论文信息 → 归纳方法族谱系 | |
| # 用法:python direction_analyzer.py(自测) | |
| # ============================================================ | |
| from llm_utils import call_llm_json, parse_json_safe, validate_direction_output, fix_windows_encoding | |
# System prompt for the direction-analysis agent; analyze_direction sends it
# together with a per-paper user prompt through call_llm_json.
# The (Chinese) instructions ask the LLM to:
#   1. locate the paper's concrete subfield and summarize its recent trend,
#   2. group the supplied GitHub repos into 3-6 method families,
#   3. list each family's member repos (matched_repos, by full_name),
#   4. generate 2-3 GitHub search queries per family plus 3 broad queries,
# and to answer with compact JSON using the keys subfield / subfield_trend /
# method_families (family_name, description, representative_work,
# matched_repos, search_queries) / broad_queries.
# This text is runtime data: downstream parsing/validation presumably relies on
# these exact key names, so keep them in sync when editing.
# NOTE(review): the "2024-2025" trend window is hard-coded and will age.
DIRECTION_SYSTEM_PROMPT = """你是一个 AI 研究领域的开源生态分析师。你的任务是:拿到一篇论文和 GitHub 上搜到的相关仓库列表,帮研究者理清——这些仓库分别属于什么方法族、哪些方法族有成熟开源实现、哪些还是空白。
## ⚠️ 防混淆规则(最高优先级)
- 你必须**严格基于用户提供的这篇具体论文**进行分析,**绝对禁止**将输入论文与其他相似论文混淆
- 如果你的训练数据中有更知名的论文与输入论文标题相似,**忽略它们**,只分析当前输入的这篇
- 子领域定位必须从**这篇论文的实际标题和摘要**出发,方法族归纳必须基于**提供的 GitHub 仓库列表**
- 不要在分析中引用未在输入中出现的论文标题、方法名或作者
- 如果输入的是一篇"XYZNet for anomaly detection",不要分析成 Transformer、ResNet 或任何训练数据中的其他模型
## 分析流程
### 步骤 1:定位子领域
- 结合论文内容和仓库描述,判断该论文属于哪个具体子领域
- 不要笼统(如"深度学习"),要具体(如"工业图像异常检测")
- 说明该子领域在 2024-2025 年的主流趋势(3-5 句中文,包含:主要技术路线演进方向、关键突破性论文/项目、活跃研究组或机构、当前正在解决的核心问题)
### 步骤 2:归纳方法族
- 浏览所有仓库的 description 和 topics,将同类方法的仓库归为一族
- 每个方法族给出:名称(英文简称)、核心特点(2-3句详细描述,包含技术原理关键词、相比其他族的优势、主要适用场景)、代表论文(含arxiv ID或会议名)/知名开源项目
- 输出 3-6 个方法族
- 如果一个仓库明显不属于任何一族(如只是教程、论文列表),不要强行归类
- 重要:方法族名称尽量使用英文(如"Patch Distribution Modeling"),便于后续搜索
### 步骤 3:为每个方法族列出归属仓库
- matched_repos 填该族包含的仓库 full_name 列表
- 如果一个仓库可以属于多个族(如同时实现了方法A和方法B),选最匹配的那个
### 步骤 4:生成精准搜索词
- 为每个方法族生成 2-3 个 GitHub 搜索查询(英文)
- 格式:方法名 + 技术关键词 + 限定词(implementation / pytorch / official / code)
- 额外生成 3 个宽泛搜索词,用于查找未被当前搜索覆盖的方法族
## 搜索查询质量要求
- 好的查询:"patch distribution modeling anomaly detection pytorch"
- 坏的查询:"anomaly"(太宽泛)
- 坏的查询:"a novel approach for industrial defect detection using deep learning"(太多废话)
## 输出格式(严格 JSON,紧凑格式不要缩进和多余空格)
{
"subfield": "具体子领域名称",
"subfield_trend": "该子领域 2024-2025 的主流趋势(3-5句中文,含技术路线演进、关键突破、活跃研究组/机构、核心问题)",
"method_families": [
{
"family_name": "方法族简称",
"description": "核心特点(2-3句详细描述,含技术原理关键词、优势、适用场景)",
"representative_work": "代表论文或知名项目",
"matched_repos": ["owner/repo1", "owner/repo2"],
"search_queries": ["搜索词1", "搜索词2"]
}
],
"broad_queries": ["宽泛搜索1", "宽泛搜索2", "宽泛搜索3"]
}
"""
def _format_repo_line(repo: dict) -> str:
    """Render one GitHub search result as a Markdown bullet for the LLM prompt."""
    # Use `or` rather than a .get() default: the GitHub API returns keys that
    # are present but null (e.g. repos without a description or language),
    # and None[:n] / ", ".join(None) would raise TypeError.
    full_name = repo.get("full_name") or ""
    # Collapse newlines/whitespace so every repo stays a single bullet.
    desc = " ".join((repo.get("description") or "").split())[:100]
    stars = repo.get("stars") or 0
    topics = ", ".join((repo.get("topics") or [])[:5])
    language = repo.get("language") or ""
    return (
        f"- **{full_name}** (⭐{stars}, {language}) — {desc}\n"
        f" topics: {topics}"
    )


def analyze_direction(title: str, abstract: str, repos: list[dict] | None = None, domain_context: str = "") -> dict:
    """Summarize the method-family landscape around a paper with an LLM.

    Builds a user prompt from the paper title/abstract, the GitHub search
    results and optional domain context, sends it with
    ``DIRECTION_SYSTEM_PROMPT`` via ``call_llm_json``, then parses and
    validates the JSON reply.

    Args:
        title: Paper title. ``None``/empty/whitespace-only titles are
            replaced by a placeholder.
        abstract: Paper abstract; only the first 2000 characters are sent.
        repos: GitHub search results. Each dict may contain ``full_name``,
            ``description``, ``stars``, ``language`` and ``topics``
            (``html_url`` is accepted but not sent). Missing or ``None``
            values are tolerated. At most the first 20 repos are sent. When
            ``None`` or empty, the LLM is told to rely on its own knowledge
            (fallback mode).
        domain_context: Optional extra context (e.g. abstracts of surveys in
            the same field) to help the LLM locate the subfield.

    Returns:
        The result of ``validate_direction_output`` on the parsed reply;
        per the system prompt it is expected to hold ``subfield``,
        ``subfield_trend``, ``method_families`` and ``broad_queries``.
    """
    abstract_truncated = abstract[:2000] if abstract else "(摘要不可用)"
    # Whitespace-only titles fall back to the placeholder as well.
    title_clean = (title or "").strip() or "(标题不可用)"

    # Send at most 20 repos to keep the prompt bounded.
    shown_repos = (repos or [])[:20]
    if shown_repos:
        repos_text = "\n".join(_format_repo_line(r) for r in shown_repos)
    else:
        repos_text = "(未提供 GitHub 搜索结果,请基于你的训练知识进行分析)"

    domain_section = ""
    if domain_context:
        domain_section = f"\n## 领域上下文(同领域相关综述/调查论文摘要)\n\n{domain_context}\n"

    # Report the number of repos actually listed (not the total searched),
    # so the LLM is not told about repos it never sees.
    user_prompt = f"""## ⚠️ 以下是你需要分析的论文,请务必将你的分析锚定在这篇论文上
**论文标题**: {title_clean}
**论文摘要**: {abstract_truncated}
{domain_section}
## GitHub 搜索结果(共 {len(shown_repos)} 个仓库)
{repos_text}
请严格基于以上这篇论文和仓库数据分析,不要使用训练数据中其他相似论文的信息。"""

    raw = call_llm_json(DIRECTION_SYSTEM_PROMPT, user_prompt, temperature=0.4, max_tokens=16000)
    data = parse_json_safe(raw, "direction_analyzer")
    return validate_direction_output(data)
| # ============================================================ | |
| # 自测 | |
| # ============================================================ | |
if __name__ == "__main__":
    # Manual smoke test: calls the real LLM, so it needs working llm_utils
    # configuration (credentials etc.) and prints results for eyeballing.
    fix_windows_encoding()

    # Mock GitHub search results covering several anomaly-detection families,
    # plus one tutorial/awesome-list repo that should NOT be assigned a family.
    mock_repos = [
        {
            "full_name": "openvinotoolkit/anomalib",
            "description": "An anomaly detection library comprising state-of-the-art algorithms.",
            "stars": 4000, "language": "Python",
            "topics": ["anomaly-detection", "pytorch", "benchmarking", "industrial-inspection"],
            "html_url": "https://github.com/openvinotoolkit/anomalib",
        },
        {
            "full_name": "hcw-00/PatchCore",
            "description": "Official implementation of PatchCore: Towards Total Recall in Industrial Anomaly Detection.",
            "stars": 850, "language": "Python",
            "topics": ["anomaly-detection", "memory-bank", "pytorch"],
            "html_url": "https://github.com/hcw-00/PatchCore",
        },
        {
            "full_name": "taikiinoue45/STFPM",
            "description": "Student-Teacher Feature Pyramid Matching for Anomaly Detection (PyTorch).",
            "stars": 320, "language": "Python",
            "topics": ["teacher-student", "anomaly-detection", "pytorch"],
            "html_url": "https://github.com/taikiinoue45/STFPM",
        },
        {
            "full_name": "VitjanZ/DRAEM",
            "description": "DRAEM: Discriminatively trained reconstruction embedding for surface anomaly detection.",
            "stars": 180, "language": "Python",
            "topics": ["synthetic-defect", "anomaly-detection", "reconstruction"],
            "html_url": "https://github.com/VitjanZ/DRAEM",
        },
        {
            "full_name": "marugoto/CFLow",
            "description": "CFLow: Real-time unsupervised anomaly detection via conditional normalizing flows.",
            "stars": 130, "language": "Python",
            "topics": ["normalizing-flow", "anomaly-detection", "pytorch"],
            "html_url": "https://github.com/marugoto/CFLow",
        },
        {
            "full_name": "xiahaifeng1995/PaDiM-Anomaly-Detection",
            "description": "Unofficial implementation of PaDiM: Patch Distribution Modeling for anomaly detection.",
            "stars": 150, "language": "Python",
            "topics": ["padim", "anomaly-detection", "mvtec-ad"],
            "html_url": "https://github.com/xiahaifeng1995/PaDiM-Anomaly-Detection",
        },
        {
            "full_name": "someone/anomaly-detection-tutorial",
            "description": "A tutorial collection of anomaly detection papers and resources.",
            "stars": 45, "language": "Markdown",
            "topics": ["awesome-list", "anomaly-detection"],
            "html_url": "https://github.com/someone/anomaly-detection-tutorial",
        },
    ]

    # --- Test 1: normal mode, with GitHub repos ---
    print("=" * 60)
    print("Agent 1 测试:基于 GitHub 仓库的方法族归纳")
    print("=" * 60)
    result = analyze_direction(
        "PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization",
        "We present a new framework for anomaly detection and localization based on "
        "patch distribution modeling. PaDiM uses pretrained convolutional neural networks "
        "to extract patch features and models their distribution using multivariate "
        "Gaussian distributions.",
        mock_repos,
    )
    print(f"子领域: {result.get('subfield')}")
    print(f"趋势: {result.get('subfield_trend')}")

    # LLM output may omit fields or return null; `or` guards every access
    # that slices or takes len() so the smoke test reports instead of crashing.
    families = result.get("method_families") or []
    broad_queries = result.get("broad_queries") or []
    print(f"方法族数量: {len(families)}")
    for mf in families:
        repos_in_family = mf.get("matched_repos", [])
        description = mf.get("description") or ""
        search_queries = mf.get("search_queries") or []
        print(f" - {mf.get('family_name')}: {description[:60]}...")
        print(f" 归属仓库: {repos_in_family}")
        print(f" 搜索词: {search_queries}")
    print(f"宽泛搜索词: {broad_queries}")
    total_queries = sum(len(mf.get("search_queries") or []) for mf in families) + len(broad_queries)
    print(f"总搜索词数: {total_queries}")
    print()

    # --- Test 2: fallback mode, no GitHub data ---
    print("=" * 60)
    print("Agent 1 测试:降级模式(无 GitHub 数据)")
    print("=" * 60)
    result2 = analyze_direction(
        "Attention Is All You Need",
        "The dominant sequence transduction models are based on complex recurrent "
        "or convolutional neural networks...",
        None,
    )
    print(f"子领域: {result2.get('subfield')}")
    print(f"方法族: {[mf.get('family_name') for mf in (result2.get('method_families') or [])]}")