# direction_analyzer.py
# ============================================================
# Kind:  AGENT (owned by partner B)
# Role:  GitHub search results + paper info -> method-family taxonomy
# Usage: python direction_analyzer.py  (runs the self-test below)
# ============================================================
from llm_utils import call_llm_json, parse_json_safe, validate_direction_output, fix_windows_encoding

DIRECTION_SYSTEM_PROMPT = """你是一个 AI 研究领域的开源生态分析师。你的任务是:拿到一篇论文和 GitHub 上搜到的相关仓库列表,帮研究者理清——这些仓库分别属于什么方法族、哪些方法族有成熟开源实现、哪些还是空白。

## ⚠️ 防混淆规则(最高优先级)

- 你必须**严格基于用户提供的这篇具体论文**进行分析,**绝对禁止**将输入论文与其他相似论文混淆
- 如果你的训练数据中有更知名的论文与输入论文标题相似,**忽略它们**,只分析当前输入的这篇
- 子领域定位必须从**这篇论文的实际标题和摘要**出发,方法族归纳必须基于**提供的 GitHub 仓库列表**
- 不要在分析中引用未在输入中出现的论文标题、方法名或作者
- 如果输入的是一篇"XYZNet for anomaly detection",不要分析成 Transformer、ResNet 或任何训练数据中的其他模型

## 分析流程

### 步骤 1:定位子领域

- 结合论文内容和仓库描述,判断该论文属于哪个具体子领域
- 不要笼统(如"深度学习"),要具体(如"工业图像异常检测")
- 说明该子领域在 2024-2025 年的主流趋势(3-5 句中文,包含:主要技术路线演进方向、关键突破性论文/项目、活跃研究组或机构、当前正在解决的核心问题)

### 步骤 2:归纳方法族

- 浏览所有仓库的 description 和 topics,将同类方法的仓库归为一族
- 每个方法族给出:名称(英文简称)、核心特点(2-3句详细描述,包含技术原理关键词、相比其他族的优势、主要适用场景)、代表论文(含arxiv ID或会议名)/知名开源项目
- 输出 3-6 个方法族
- 如果一个仓库明显不属于任何一族(如只是教程、论文列表),不要强行归类
- 重要:方法族名称尽量使用英文(如"Patch Distribution Modeling"),便于后续搜索

### 步骤 3:为每个方法族列出归属仓库

- matched_repos 填该族包含的仓库 full_name 列表
- 如果一个仓库可以属于多个族(如同时实现了方法A和方法B),选最匹配的那个

### 步骤 4:生成精准搜索词

- 为每个方法族生成 2-3 个 GitHub 搜索查询(英文)
- 格式:方法名 + 技术关键词 + 限定词(implementation / pytorch / official / code)
- 额外生成 3 个宽泛搜索词,用于查找未被当前搜索覆盖的方法族

## 搜索查询质量要求

- 好的查询:"patch distribution modeling anomaly detection pytorch"
- 坏的查询:"anomaly"(太宽泛)
- 坏的查询:"a novel approach for industrial defect detection using deep learning"(太多废话)

## 输出格式(严格 JSON,紧凑格式不要缩进和多余空格)

{
  "subfield": "具体子领域名称",
  "subfield_trend": "该子领域 2024-2025 的主流趋势(3-5句中文,含技术路线演进、关键突破、活跃研究组/机构、核心问题)",
  "method_families": [
    {
      "family_name": "方法族简称",
      "description": "核心特点(2-3句详细描述,含技术原理关键词、优势、适用场景)",
      "representative_work": "代表论文或知名项目",
      "matched_repos": ["owner/repo1", "owner/repo2"],
      "search_queries": ["搜索词1", "搜索词2"]
    }
  ],
  "broad_queries": ["宽泛搜索1", "宽泛搜索2", "宽泛搜索3"]
}
"""


def analyze_direction(
    title: str,
    abstract: str,
    repos: list[dict] | None = None,
    domain_context: str = "",
) -> dict:
    """Derive a method-family taxonomy from paper info plus GitHub search results.

    Args:
        title: Paper title.
        abstract: Paper abstract (truncated to 2000 chars before prompting).
        repos: GitHub search results; each dict may contain full_name,
            description, stars, language, topics, html_url. When None or
            empty, the LLM falls back to its own knowledge (degraded mode).
        domain_context: Optional extra domain context (e.g. abstract of a
            survey paper in the same field) to help the LLM pin down the
            subfield and method families.

    Returns:
        dict with keys: subfield, subfield_trend, method_families, broad_queries
        (shape enforced by validate_direction_output).
    """
    abstract_truncated = abstract[:2000] if abstract else "(摘要不可用)"
    title_clean = title.strip() if title else "(标题不可用)"

    # Render the repo list as a compact markdown block for the prompt.
    if repos:
        repo_lines = []
        for r in repos[:20]:  # cap at 20 repos to bound prompt size
            full_name = r.get("full_name", "")
            # The GitHub API returns explicit null for missing description/
            # language/topics; `.get(key, default)` does NOT cover that case,
            # so coerce None to safe defaults before slicing/joining.
            desc = (r.get("description") or "")[:100]
            stars = r.get("stars", 0)
            topics = ", ".join((r.get("topics") or [])[:5])
            language = r.get("language") or ""
            repo_lines.append(
                f"- **{full_name}** (⭐{stars}, {language}) — {desc}\n"
                f"  topics: {topics}"
            )
        repos_text = "\n".join(repo_lines)
    else:
        repos_text = "(未提供 GitHub 搜索结果,请基于你的训练知识进行分析)"

    domain_section = ""
    if domain_context:
        domain_section = f"\n## 领域上下文(同领域相关综述/调查论文摘要)\n\n{domain_context}\n"

    user_prompt = f"""## ⚠️ 以下是你需要分析的论文,请务必将你的分析锚定在这篇论文上

**论文标题**: {title_clean}

**论文摘要**: {abstract_truncated}
{domain_section}
## GitHub 搜索结果(共 {len(repos) if repos else 0} 个仓库)

{repos_text}

请严格基于以上这篇论文和仓库数据分析,不要使用训练数据中其他相似论文的信息。"""

    # Moderate temperature: taxonomy induction benefits from some creativity,
    # but the output must remain valid JSON (parsed + validated below).
    raw = call_llm_json(DIRECTION_SYSTEM_PROMPT, user_prompt, temperature=0.4, max_tokens=16000)
    data = parse_json_safe(raw, "direction_analyzer")
    return validate_direction_output(data)


# ============================================================
# Self-test
# ============================================================
if __name__ == "__main__":
    fix_windows_encoding()

    # Mocked GitHub search results covering several anomaly-detection families.
    mock_repos = [
        {
            "full_name": "openvinotoolkit/anomalib",
            "description": "An anomaly detection library comprising state-of-the-art algorithms.",
            "stars": 4000,
            "language": "Python",
            "topics": ["anomaly-detection", "pytorch", "benchmarking", "industrial-inspection"],
            "html_url": "https://github.com/openvinotoolkit/anomalib",
        },
        {
            "full_name": "hcw-00/PatchCore",
            "description": "Official implementation of PatchCore: Towards Total Recall in Industrial Anomaly Detection.",
            "stars": 850,
            "language": "Python",
            "topics": ["anomaly-detection", "memory-bank", "pytorch"],
            "html_url": "https://github.com/hcw-00/PatchCore",
        },
        {
            "full_name": "taikiinoue45/STFPM",
            "description": "Student-Teacher Feature Pyramid Matching for Anomaly Detection (PyTorch).",
            "stars": 320,
            "language": "Python",
            "topics": ["teacher-student", "anomaly-detection", "pytorch"],
            "html_url": "https://github.com/taikiinoue45/STFPM",
        },
        {
            "full_name": "VitjanZ/DRAEM",
            "description": "DRAEM: Discriminatively trained reconstruction embedding for surface anomaly detection.",
            "stars": 180,
            "language": "Python",
            "topics": ["synthetic-defect", "anomaly-detection", "reconstruction"],
            "html_url": "https://github.com/VitjanZ/DRAEM",
        },
        {
            "full_name": "marugoto/CFLow",
            "description": "CFLow: Real-time unsupervised anomaly detection via conditional normalizing flows.",
            "stars": 130,
            "language": "Python",
            "topics": ["normalizing-flow", "anomaly-detection", "pytorch"],
            "html_url": "https://github.com/marugoto/CFLow",
        },
        {
            "full_name": "xiahaifeng1995/PaDiM-Anomaly-Detection",
            "description": "Unofficial implementation of PaDiM: Patch Distribution Modeling for anomaly detection.",
            "stars": 150,
            "language": "Python",
            "topics": ["padim", "anomaly-detection", "mvtec-ad"],
            "html_url": "https://github.com/xiahaifeng1995/PaDiM-Anomaly-Detection",
        },
        {
            "full_name": "someone/anomaly-detection-tutorial",
            "description": "A tutorial collection of anomaly detection papers and resources.",
            "stars": 45,
            "language": "Markdown",
            "topics": ["awesome-list", "anomaly-detection"],
            "html_url": "https://github.com/someone/anomaly-detection-tutorial",
        },
    ]

    print("=" * 60)
    print("Agent 1 测试:基于 GitHub 仓库的方法族归纳")
    print("=" * 60)
    result = analyze_direction(
        "PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization",
        "We present a new framework for anomaly detection and localization based on "
        "patch distribution modeling. PaDiM uses pretrained convolutional neural networks "
        "to extract patch features and models their distribution using multivariate "
        "Gaussian distributions.",
        mock_repos,
    )
    print(f"子领域: {result.get('subfield')}")
    print(f"趋势: {result.get('subfield_trend')}")
    print(f"方法族数量: {len(result.get('method_families', []))}")
    for mf in result.get("method_families", []):
        repos_in_family = mf.get("matched_repos", [])
        # `or ''` guards against a missing/None description from the LLM.
        print(f"  - {mf.get('family_name')}: {(mf.get('description') or '')[:60]}...")
        print(f"    归属仓库: {repos_in_family}")
        print(f"    搜索词: {mf.get('search_queries', [])}")
    print(f"宽泛搜索词: {result.get('broad_queries', [])}")
    total_queries = sum(
        len(mf.get('search_queries', [])) for mf in result.get('method_families', [])
    ) + len(result.get('broad_queries', []))
    print(f"总搜索词数: {total_queries}")

    print()
    print("=" * 60)
    print("Agent 1 测试:降级模式(无 GitHub 数据)")
    print("=" * 60)
    result2 = analyze_direction(
        "Attention Is All You Need",
        "The dominant sequence transduction models are based on complex recurrent "
        "or convolutional neural networks...",
        None,
    )
    print(f"子领域: {result2.get('subfield')}")
    print(f"方法族: {[mf.get('family_name') for mf in result2.get('method_families', [])]}")