# Source: ResearchRadar / supervisor.py (uploaded via huggingface_hub, revision 10b540d)
# supervisor.py
# ============================================================
# Type:  review layer (owner: engineer B)
# Role:  audit the outputs of Agent 1 and Agent 2 for quality problems —
#        laziness, errors, unreliable sources, etc.
# Usage: invoked automatically by run.py after Agent 1 and Agent 2 finish
# ============================================================
from datetime import datetime, timedelta, timezone

from llm_utils import call_llm_json, parse_json_safe, fix_windows_encoding
# System prompt for an LLM-based quality reviewer, describing the three audit
# dimensions (direction analysis, repo evaluations, source reliability) and
# the strict-JSON report schema.
# NOTE(review): none of the rule-based audit functions in this file reference
# this prompt or the imported LLM helpers (call_llm_json / parse_json_safe) —
# presumably it is consumed elsewhere (e.g. by run.py); confirm before removal.
SUPERVISOR_SYSTEM_PROMPT = """你是 ResearchRadar 的质量审核员。你的任务是审核一份研究报告的各个部分,判断它们是否达到了可发布的标准。
## 审核维度
### 1. 方向分析审核 (Agent 1 输出)
- 子领域是否足够具体?("深度学习"→不合格,"工业图像异常检测"→合格)
- 趋势分析是否包含具体的技术路线、关键突破或活跃研究组?(纯套话→不合格)
- 方法族是否完整覆盖了仓库列表?(有明显遗漏→不合格)
- 方法族描述是否包含技术原理?
### 2. 仓库评估审核 (Agent 2 输出)
- reasoning 是否足够详细?(按维度逐条分析→合格,一句话概括→不合格)
- risks 是否具体?(引用文件名→合格,"无依赖文件"笼统→不合格)
- 评分是否自洽?(README 完善但 env_score=2 等矛盾→不合格)
- suggested_use 是否可操作?
### 3. 来源可靠性审核
- GitHub 仓库 Stars < 5 且超过 2 年未更新 → 低可靠性
- 仓库缺少 README 或依赖文件 → 低可靠性
- 方法族没有任何仓库归属 → 标记为"研究空白"
## 输出格式(严格 JSON)
{
"overall_score": 85,
"direction_audit": {
"passed": true,
"issues": [],
"subfield_score": 8,
"trend_score": 7,
"family_score": 9
},
"evaluation_audit": {
"passed": true,
"issues": [],
"lazy_repos": [],
"avg_reasoning_length": 350
},
"source_audit": {
"passed": true,
"issues": [],
"unreliable_repos": []
},
"actions": []
}"""
def audit_direction(title: str, abstract: str, direction: dict) -> dict:
    """Rule-based audit of Agent 1's direction-analysis output.

    Args:
        title: Paper title (currently unused by the rule checks).
        abstract: Paper abstract (currently unused by the rule checks).
        direction: Agent 1 output with keys ``subfield``,
            ``subfield_trend`` and ``method_families``.

    Returns:
        dict: {passed, issues, subfield_score, trend_score, family_score}
    """
    problems: list[str] = []
    score_subfield = score_trend = score_family = 10

    subfield = direction.get("subfield", "")
    trend_text = direction.get("subfield_trend", "")
    method_families = direction.get("method_families", [])

    # --- subfield specificity ---
    generic_terms = ("deep learning", "machine learning", "ai", "computer vision", "nlp")
    if not subfield or subfield == "未知":
        problems.append("子领域为空")
        score_subfield = 0
    elif len(subfield) < 6 or subfield.lower() in generic_terms:
        problems.append(f"子领域过于笼统: {subfield}")
        score_subfield = 3

    # --- trend-analysis depth (normalize Chinese full stops, then split) ---
    sentence_list = [part.strip() for part in trend_text.replace("。", ".").split(".") if part.strip()]
    if len(sentence_list) < 2:
        problems.append(f"趋势分析过于简短,仅 {len(sentence_list)} 句")
        score_trend = 3
    elif len(trend_text) < 80:
        problems.append("趋势分析不足 80 字符")
        score_trend = 4

    # --- method families: each needs a real description and matched repos ---
    if not method_families:
        problems.append("未识别出任何方法族")
        score_family = 0
    else:
        for family in method_families:
            if len(family.get("description", "")) < 15:
                problems.append(f"方法族 '{family.get('family_name', '?')}' 描述过于简短")
                score_family = min(score_family, 5)
            if not family.get("matched_repos"):
                # Flagged as an issue only; per the audit spec this marks a
                # "research gap" rather than lowering the family score.
                problems.append(f"方法族 '{family.get('family_name', '?')}' 无归属仓库")

    # Pass if nothing was flagged, or the combined score is still decent.
    passed = not problems or (score_subfield + score_trend + score_family >= 20)
    return {
        "passed": passed,
        "issues": problems,
        "subfield_score": score_subfield,
        "trend_score": score_trend,
        "family_score": score_family,
    }
def audit_evaluations(repos: list[dict]) -> dict:
    """Rule-based audit of Agent 2's per-repo evaluations.

    Flags lazy behaviour: too-short reasoning, missing risks, and
    internally inconsistent scores.

    Args:
        repos: Repo dicts, each carrying an ``evaluation`` sub-dict.

    Returns:
        dict: {passed, issues, lazy_repos, avg_reasoning_length}
    """
    findings: list[str] = []
    lazy: list[str] = []
    reasoning_sizes: list[int] = []

    for repo in repos:
        evaluation = repo.get("evaluation", {})
        name = repo.get("full_name", "?")
        reasoning = evaluation.get("reasoning", "")
        risks = evaluation.get("risks", [])
        reasoning_sizes.append(len(reasoning))

        lazy_here = False
        local_findings: list[str] = []

        # reasoning length: < 80 chars is lazy, < 200 merely short
        if len(reasoning) < 80:
            local_findings.append(f"reasoning 仅 {len(reasoning)} 字符")
            lazy_here = True
        elif len(reasoning) < 200:
            local_findings.append(f"reasoning 偏短 ({len(reasoning)} 字符)")

        # at least two concrete risks are expected
        if not risks or len(risks) < 2:
            local_findings.append(f"risks 仅 {len(risks)} 个")
            lazy_here = True

        # score self-consistency: high env score with near-zero doc score
        if evaluation.get("env_score", 0) >= 10 and evaluation.get("doc_score", 0) <= 2:
            local_findings.append("env_score 高但 doc_score 低,可能矛盾")

        # suspicious: high overall score on a near-unknown repo
        overall = evaluation.get("overall_score", 0)
        stars = repo.get("stars", 0)
        if overall >= 80 and stars < 10:
            local_findings.append(f"高评分 ({overall}) 但仅 {stars} Stars,值得怀疑")

        if lazy_here:
            lazy.append(name)
        if local_findings:
            findings.append(f"[{name}] " + "; ".join(local_findings))

    # max(..., 1) guards the empty-repo-list case
    avg_len = int(sum(reasoning_sizes) / max(len(reasoning_sizes), 1))
    return {
        "passed": not lazy and len(findings) <= 1,
        "issues": findings,
        "lazy_repos": lazy,
        "avg_reasoning_length": avg_len,
    }
def audit_sources(repos: list[dict]) -> dict:
    """Audit the reliability of repository sources.

    A repo is flagged unreliable when it has very few stars, has not been
    updated for over two years, lacks a README, or ships no dependency
    files.  Repos with >= 100 stars skip all of these baseline checks.

    Args:
        repos: Repo dicts with ``full_name``, ``stars``, ``updated_at``
            (ISO-format date string), ``readme`` and ``dependencies``.

    Returns:
        dict: {passed, issues, unreliable_repos}
    """
    issues = []
    unreliable = []
    # Cutoff for "unmaintained": two years (730 days) before today.
    # BUGFIX: previously hardcoded to the literal "2024-01-01", which
    # silently drifts out of date while the message keeps claiming
    # "more than 2 years without maintenance".
    stale_cutoff = (datetime.now(timezone.utc) - timedelta(days=730)).strftime("%Y-%m-%d")
    for r in repos:
        full_name = r.get("full_name", "?")
        stars = r.get("stars", 0)
        updated = r.get("updated_at", "")
        readme = r.get("readme", "")
        deps = r.get("dependencies", {})
        is_unreliable = False
        # High-star repos are trusted; skip the baseline checks entirely.
        if stars >= 100:
            continue
        if stars < 5:
            issues.append(f"[{full_name}] 仅 {stars} Stars,低影响力")
            is_unreliable = True
        # ISO-8601 date strings compare correctly as plain strings.
        if updated and updated < stale_cutoff:
            issues.append(f"[{full_name}] 最后更新 {updated[:10]},超过 2 年未维护")
            is_unreliable = True
        if not readme or len(str(readme)) < 100:
            issues.append(f"[{full_name}] README 缺失或过短")
            is_unreliable = True
        if not deps:
            issues.append(f"[{full_name}] 无依赖文件")
            is_unreliable = True
        if is_unreliable:
            unreliable.append(full_name)
    passed = len(unreliable) <= len(repos) // 3  # tolerate up to 1/3 low-quality repos
    return {
        "passed": passed,
        "issues": issues,
        "unreliable_repos": unreliable,
    }
def supervise(title: str, abstract: str, direction: dict, repos: list[dict]) -> dict:
    """Top-level orchestration: run every audit and build a quality report.

    Args:
        title: Paper title, forwarded to the direction audit.
        abstract: Paper abstract, forwarded to the direction audit.
        direction: Agent 1 output dict.
        repos: List of evaluated repo dicts from Agent 2.

    Returns:
        dict: {overall_score, direction_audit, evaluation_audit,
               source_audit, actions, summary}
    """
    direction_audit = audit_direction(title, abstract, direction)
    evaluation_audit = audit_evaluations(repos)
    source_audit = audit_sources(repos)

    # Weighted composite (direction 40%, evaluations 30%, sources 30%),
    # each component mapped from a 0-10 average onto 0-100.
    direction_avg = (
        direction_audit["subfield_score"]
        + direction_audit["trend_score"]
        + direction_audit["family_score"]
    ) / 3
    evaluation_avg = 10 - min(10, len(evaluation_audit["lazy_repos"]) * 3 + len(evaluation_audit["issues"]))
    source_avg = 10 - min(10, len(source_audit["unreliable_repos"]) * 2)
    overall = int(0.4 * direction_avg * 10 + 0.3 * evaluation_avg * 10 + 0.3 * source_avg * 10)

    # Concrete remediation suggestions, one per failing audit.
    actions = []
    if direction_audit["issues"]:
        actions.append(f"方向分析存在问题: {'; '.join(direction_audit['issues'][:3])}。建议调整 Agent 1 温度参数重试。")
    if evaluation_audit["lazy_repos"]:
        actions.append(f"以下仓库的评估疑似偷懒: {', '.join(evaluation_audit['lazy_repos'][:3])}。建议重跑 Agent 2。")
    if source_audit["unreliable_repos"]:
        actions.append(f"以下仓库来源可靠性低: {', '.join(source_audit['unreliable_repos'][:3])}。考虑降低其权重。")

    # Human-readable summary keyed off the total number of findings.
    all_findings = direction_audit["issues"] + evaluation_audit["issues"] + source_audit["issues"]
    if not all_findings:
        summary = "✅ 所有审核通过,报告质量良好。"
    elif len(all_findings) <= 2:
        summary = f"⚠️ 发现 {len(all_findings)} 个小问题,不影响整体质量。"
    else:
        summary = f"🔴 发现 {len(all_findings)} 个问题,建议关注改进建议。"

    return {
        "overall_score": overall,
        "direction_audit": direction_audit,
        "evaluation_audit": evaluation_audit,
        "source_audit": source_audit,
        "actions": actions,
        "summary": summary,
    }
# ============================================================
# Self-test
# ============================================================
if __name__ == "__main__":
    fix_windows_encoding()
    # Mock fixtures: a well-formed direction analysis plus two repos — one
    # healthy high-star library and one deliberately low-quality — so both
    # the pass and fail paths of every audit are exercised.
    mock_direction = {
        "subfield": "工业图像异常检测",
        "subfield_trend": "2024-2025年该领域主流趋势包括:1) 从基于重建的方法转向基于嵌入的方法,如PatchCore、PaDiM等利用预训练CNN提取特征;2) 多模态方法的兴起,如AnomalyGPT和WinCLIP结合视觉-语言模型;3) 从单类检测向多类统一检测发展。活跃研究组包括AWS、Intel OpenVINO团队、MVTec等。",
        "method_families": [
            {
                "family_name": "Patch Distribution Modeling",
                "description": "利用预训练CNN提取图像块级特征,建模多元高斯分布,通过马氏距离计算异常分数。优势在于无需训练、推理速度快,适用于工业部署场景。",
                "representative_work": "PaDiM (ICPR 2021)",
                "matched_repos": ["openvinotoolkit/anomalib", "xiahaifeng1995/PaDiM-Anomaly-Detection"],
                "search_queries": ["padim anomaly detection pytorch"],
            },
            {
                # Empty matched_repos on purpose: exercises the
                # "research gap" issue path in audit_direction.
                "family_name": "Memory Bank",
                "description": "构建正常样本的特征记忆库,测试时通过最近邻检索判断异常。优势是可解释性强,但内存开销大。",
                "representative_work": "PatchCore (CVPR 2022)",
                "matched_repos": [],
                "search_queries": ["patchcore anomaly detection pytorch"],
            },
        ],
        "broad_queries": ["anomaly detection pytorch benchmark", "industrial defect detection deep learning"],
    }
    mock_repos = [
        {
            "full_name": "openvinotoolkit/anomalib",
            "stars": 4000, "updated_at": "2026-01-15",
            "readme": "# Anomalib\nA library for anomaly detection...",
            "dependencies": {"requirements.txt": "torch>=1.10"},
            "evaluation": {
                "reasoning": "【环境配置】提供 requirements.txt 含 torch>=1.10...【文档】README 含 pip install 步骤...【代码】提供 Engine 类封装训练流程...【社区】4000 Stars...", # noqa
                "risks": ["部分依赖版本号使用>=范围", "部分模型预训练权重需要单独下载", "仅支持图像检测"],
                "overall_score": 93, "env_score": 14, "doc_score": 18, "code_score": 18,
                "community_score": 10, "dep_score": 15, "benchmark_score": 18,
                "verdict": "reproducible", "benchmark_readiness": "ready",
                "suggested_use": "可直接 pip install anomalib 安装,使用 tools/benchmark.py 评估",
            },
        },
        {
            # Deliberately bad repo: one-phrase reasoning, no risks, and an
            # inflated overall score on a 3-star repo — should trip the
            # laziness, consistency and source-reliability checks.
            "full_name": "someone/tiny-demo",
            "stars": 3, "updated_at": "2023-01-01",
            "readme": "",
            "dependencies": {},
            "evaluation": {
                "reasoning": "还行",
                "risks": [],
                "overall_score": 80, "env_score": 15, "doc_score": 15,
                "benchmark_score": 20,
                "suggested_use": "可以用来跑对比实验",
            },
        },
    ]
    print("=" * 60)
    print("Supervisor Agent 自测")
    print("=" * 60)
    result = supervise("PaDiM: Patch Distribution Modeling", "anomaly detection", mock_direction, mock_repos)
    print(f"综合评分: {result['overall_score']}/100")
    print(f"摘要: {result['summary']}")
    print(f"\n方向审核: {'✅' if result['direction_audit']['passed'] else '❌'}")
    for issue in result["direction_audit"]["issues"]:
        print(f" - {issue}")
    print(f"\n评估审核: {'✅' if result['evaluation_audit']['passed'] else '❌'}")
    print(f" 疑似偷懒: {result['evaluation_audit']['lazy_repos']}")
    print(f" 平均 reasoning 长度: {result['evaluation_audit']['avg_reasoning_length']} 字符")
    print(f"\n来源审核: {'✅' if result['source_audit']['passed'] else '❌'}")
    print(f" 不可靠来源: {result['source_audit']['unreliable_repos']}")
    print(f"\n改进建议:")
    for action in result["actions"]:
        print(f" - {action}")