# ResearchRadar / direction_analyzer.py
# ZZZyx3587's picture
# Upload direction_analyzer.py with huggingface_hub
# 5855e32 verified
# direction_analyzer.py
# ============================================================
# 类型:AGENT(乙负责)
# 功能:基于 GitHub 搜索结果 + 论文信息 → 归纳方法族谱系
# 用法:python direction_analyzer.py(自测)
# ============================================================
from llm_utils import call_llm_json, parse_json_safe, validate_direction_output, fix_windows_encoding
# System prompt for the direction-analysis agent. The text below is runtime
# data sent verbatim to the LLM (deliberately kept in Chinese), so it must not
# be edited casually. It instructs the model to: (1) pin down the paper's
# subfield, (2) group the supplied GitHub repos into 3-6 method families,
# (3) list each family's member repos, and (4) emit per-family plus broad
# GitHub search queries — all as compact JSON with keys
# subfield / subfield_trend / method_families / broad_queries, which
# validate_direction_output() checks downstream.
DIRECTION_SYSTEM_PROMPT = """你是一个 AI 研究领域的开源生态分析师。你的任务是:拿到一篇论文和 GitHub 上搜到的相关仓库列表,帮研究者理清——这些仓库分别属于什么方法族、哪些方法族有成熟开源实现、哪些还是空白。
## ⚠️ 防混淆规则(最高优先级)
- 你必须**严格基于用户提供的这篇具体论文**进行分析,**绝对禁止**将输入论文与其他相似论文混淆
- 如果你的训练数据中有更知名的论文与输入论文标题相似,**忽略它们**,只分析当前输入的这篇
- 子领域定位必须从**这篇论文的实际标题和摘要**出发,方法族归纳必须基于**提供的 GitHub 仓库列表**
- 不要在分析中引用未在输入中出现的论文标题、方法名或作者
- 如果输入的是一篇"XYZNet for anomaly detection",不要分析成 Transformer、ResNet 或任何训练数据中的其他模型
## 分析流程
### 步骤 1:定位子领域
- 结合论文内容和仓库描述,判断该论文属于哪个具体子领域
- 不要笼统(如"深度学习"),要具体(如"工业图像异常检测")
- 说明该子领域在 2024-2025 年的主流趋势(3-5 句中文,包含:主要技术路线演进方向、关键突破性论文/项目、活跃研究组或机构、当前正在解决的核心问题)
### 步骤 2:归纳方法族
- 浏览所有仓库的 description 和 topics,将同类方法的仓库归为一族
- 每个方法族给出:名称(英文简称)、核心特点(2-3句详细描述,包含技术原理关键词、相比其他族的优势、主要适用场景)、代表论文(含arxiv ID或会议名)/知名开源项目
- 输出 3-6 个方法族
- 如果一个仓库明显不属于任何一族(如只是教程、论文列表),不要强行归类
- 重要:方法族名称尽量使用英文(如"Patch Distribution Modeling"),便于后续搜索
### 步骤 3:为每个方法族列出归属仓库
- matched_repos 填该族包含的仓库 full_name 列表
- 如果一个仓库可以属于多个族(如同时实现了方法A和方法B),选最匹配的那个
### 步骤 4:生成精准搜索词
- 为每个方法族生成 2-3 个 GitHub 搜索查询(英文)
- 格式:方法名 + 技术关键词 + 限定词(implementation / pytorch / official / code)
- 额外生成 3 个宽泛搜索词,用于查找未被当前搜索覆盖的方法族
## 搜索查询质量要求
- 好的查询:"patch distribution modeling anomaly detection pytorch"
- 坏的查询:"anomaly"(太宽泛)
- 坏的查询:"a novel approach for industrial defect detection using deep learning"(太多废话)
## 输出格式(严格 JSON,紧凑格式不要缩进和多余空格)
{
"subfield": "具体子领域名称",
"subfield_trend": "该子领域 2024-2025 的主流趋势(3-5句中文,含技术路线演进、关键突破、活跃研究组/机构、核心问题)",
"method_families": [
{
"family_name": "方法族简称",
"description": "核心特点(2-3句详细描述,含技术原理关键词、优势、适用场景)",
"representative_work": "代表论文或知名项目",
"matched_repos": ["owner/repo1", "owner/repo2"],
"search_queries": ["搜索词1", "搜索词2"]
}
],
"broad_queries": ["宽泛搜索1", "宽泛搜索2", "宽泛搜索3"]
}
"""
def analyze_direction(title: str, abstract: str, repos: list[dict] | None = None, domain_context: str = "") -> dict:
    """Infer the method-family lineage for a paper from GitHub search results.

    Builds a prompt from the paper metadata plus (optionally) a repo list,
    asks the LLM for a strict-JSON analysis, and validates the result.

    Args:
        title: Paper title.
        abstract: Paper abstract (truncated to 2000 chars before prompting).
        repos: GitHub search results; each dict may carry full_name,
            description, stars, language, topics, html_url. Any field may be
            missing OR explicitly None — the GitHub API returns null for empty
            descriptions/languages. If None or empty, the LLM analyzes from
            its own knowledge (degraded mode).
        domain_context: Optional extra domain context (e.g. the abstract of a
            survey paper in the same field) to help the LLM pin down the
            subfield and method families.

    Returns:
        dict: contains subfield, subfield_trend, method_families, broad_queries
            (shape enforced by validate_direction_output).
    """
    abstract_truncated = abstract[:2000] if abstract else "(摘要不可用)"
    title_clean = title.strip() if title else "(标题不可用)"

    # Render the repo list as markdown bullet lines for the prompt.
    if repos:
        repo_lines = []
        for r in repos[:20]:  # cap at 20 repos to keep the prompt bounded
            # Use `or` fallbacks rather than .get() defaults: .get()'s default
            # only applies when the key is absent, not when the value is an
            # explicit None (as the GitHub API emits for empty fields) —
            # r.get("description", "")[:100] raised TypeError on null
            # descriptions, and ", ".join(None[:5]) likewise for topics.
            full_name = r.get("full_name") or ""
            desc = (r.get("description") or "")[:100]
            stars = r.get("stars") or 0
            topics = ", ".join((r.get("topics") or [])[:5])
            language = r.get("language") or ""
            repo_lines.append(
                f"- **{full_name}** (⭐{stars}, {language}) — {desc}\n"
                f" topics: {topics}"
            )
        repos_text = "\n".join(repo_lines)
    else:
        repos_text = "(未提供 GitHub 搜索结果,请基于你的训练知识进行分析)"

    # Optional context section (e.g. survey abstracts from the same field).
    domain_section = ""
    if domain_context:
        domain_section = f"\n## 领域上下文(同领域相关综述/调查论文摘要)\n\n{domain_context}\n"

    user_prompt = f"""## ⚠️ 以下是你需要分析的论文,请务必将你的分析锚定在这篇论文上
**论文标题**: {title_clean}
**论文摘要**: {abstract_truncated}
{domain_section}
## GitHub 搜索结果(共 {len(repos) if repos else 0} 个仓库)
{repos_text}
请严格基于以上这篇论文和仓库数据分析,不要使用训练数据中其他相似论文的信息。"""

    raw = call_llm_json(DIRECTION_SYSTEM_PROMPT, user_prompt, temperature=0.4, max_tokens=16000)
    data = parse_json_safe(raw, "direction_analyzer")
    return validate_direction_output(data)
# ============================================================
# Self-test (run: python direction_analyzer.py)
# ============================================================
if __name__ == "__main__":
    fix_windows_encoding()

    banner = "=" * 60

    # Fabricated GitHub search results covering several anomaly-detection
    # method families, plus one tutorial repo that should stay unclassified.
    sample_repos = [
        {
            "full_name": "openvinotoolkit/anomalib",
            "description": "An anomaly detection library comprising state-of-the-art algorithms.",
            "stars": 4000, "language": "Python",
            "topics": ["anomaly-detection", "pytorch", "benchmarking", "industrial-inspection"],
            "html_url": "https://github.com/openvinotoolkit/anomalib",
        },
        {
            "full_name": "hcw-00/PatchCore",
            "description": "Official implementation of PatchCore: Towards Total Recall in Industrial Anomaly Detection.",
            "stars": 850, "language": "Python",
            "topics": ["anomaly-detection", "memory-bank", "pytorch"],
            "html_url": "https://github.com/hcw-00/PatchCore",
        },
        {
            "full_name": "taikiinoue45/STFPM",
            "description": "Student-Teacher Feature Pyramid Matching for Anomaly Detection (PyTorch).",
            "stars": 320, "language": "Python",
            "topics": ["teacher-student", "anomaly-detection", "pytorch"],
            "html_url": "https://github.com/taikiinoue45/STFPM",
        },
        {
            "full_name": "VitjanZ/DRAEM",
            "description": "DRAEM: Discriminatively trained reconstruction embedding for surface anomaly detection.",
            "stars": 180, "language": "Python",
            "topics": ["synthetic-defect", "anomaly-detection", "reconstruction"],
            "html_url": "https://github.com/VitjanZ/DRAEM",
        },
        {
            "full_name": "marugoto/CFLow",
            "description": "CFLow: Real-time unsupervised anomaly detection via conditional normalizing flows.",
            "stars": 130, "language": "Python",
            "topics": ["normalizing-flow", "anomaly-detection", "pytorch"],
            "html_url": "https://github.com/marugoto/CFLow",
        },
        {
            "full_name": "xiahaifeng1995/PaDiM-Anomaly-Detection",
            "description": "Unofficial implementation of PaDiM: Patch Distribution Modeling for anomaly detection.",
            "stars": 150, "language": "Python",
            "topics": ["padim", "anomaly-detection", "mvtec-ad"],
            "html_url": "https://github.com/xiahaifeng1995/PaDiM-Anomaly-Detection",
        },
        {
            "full_name": "someone/anomaly-detection-tutorial",
            "description": "A tutorial collection of anomaly detection papers and resources.",
            "stars": 45, "language": "Markdown",
            "topics": ["awesome-list", "anomaly-detection"],
            "html_url": "https://github.com/someone/anomaly-detection-tutorial",
        },
    ]

    # --- Case 1: full mode, repos provided -----------------------------
    print(banner)
    print("Agent 1 测试:基于 GitHub 仓库的方法族归纳")
    print(banner)

    analysis = analyze_direction(
        "PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization",
        "We present a new framework for anomaly detection and localization based on "
        "patch distribution modeling. PaDiM uses pretrained convolutional neural networks "
        "to extract patch features and models their distribution using multivariate "
        "Gaussian distributions.",
        sample_repos,
    )

    families = analysis.get("method_families", [])
    print(f"子领域: {analysis.get('subfield')}")
    print(f"趋势: {analysis.get('subfield_trend')}")
    print(f"方法族数量: {len(families)}")

    # Accumulate the query total while printing each family.
    query_total = len(analysis.get("broad_queries", []))
    for family in families:
        print(f" - {family.get('family_name')}: {family.get('description')[:60]}...")
        print(f" 归属仓库: {family.get('matched_repos', [])}")
        print(f" 搜索词: {family.get('search_queries', [])}")
        query_total += len(family.get("search_queries", []))
    print(f"宽泛搜索词: {analysis.get('broad_queries', [])}")
    print(f"总搜索词数: {query_total}")
    print()

    # --- Case 2: degraded mode, no GitHub data -------------------------
    print(banner)
    print("Agent 1 测试:降级模式(无 GitHub 数据)")
    print(banner)

    fallback = analyze_direction(
        "Attention Is All You Need",
        "The dominant sequence transduction models are based on complex recurrent "
        "or convolutional neural networks...",
        None,
    )
    print(f"子领域: {fallback.get('subfield')}")
    family_names = [family.get("family_name") for family in fallback.get("method_families", [])]
    print(f"方法族: {family_names}")