Spaces:
Running
Running
File size: 13,470 Bytes
# supervisor.py
# ============================================================
# 类型:审核层(乙负责)
# 功能:审核 Agent 1 和 Agent 2 的输出质量,检测偷懒/错误/来源不可靠等问题
# 用法:被 run.py 在 Agent 1 和 Agent 2 完成后自动调用
# ============================================================
from datetime import datetime, timedelta

from llm_utils import call_llm_json, parse_json_safe, fix_windows_encoding
# System prompt for an LLM-based supervision pass.
# NOTE(review): none of the rule-based audit functions in this module
# reference this constant (or the imported LLM helpers) — presumably
# run.py or a future LLM audit mode consumes it; confirm before removing.
SUPERVISOR_SYSTEM_PROMPT = """你是 ResearchRadar 的质量审核员。你的任务是审核一份研究报告的各个部分,判断它们是否达到了可发布的标准。
## 审核维度
### 1. 方向分析审核 (Agent 1 输出)
- 子领域是否足够具体?("深度学习"→不合格,"工业图像异常检测"→合格)
- 趋势分析是否包含具体的技术路线、关键突破或活跃研究组?(纯套话→不合格)
- 方法族是否完整覆盖了仓库列表?(有明显遗漏→不合格)
- 方法族描述是否包含技术原理?
### 2. 仓库评估审核 (Agent 2 输出)
- reasoning 是否足够详细?(按维度逐条分析→合格,一句话概括→不合格)
- risks 是否具体?(引用文件名→合格,"无依赖文件"笼统→不合格)
- 评分是否自洽?(README 完善但 env_score=2 等矛盾→不合格)
- suggested_use 是否可操作?
### 3. 来源可靠性审核
- GitHub 仓库 Stars < 5 且超过 2 年未更新 → 低可靠性
- 仓库缺少 README 或依赖文件 → 低可靠性
- 方法族没有任何仓库归属 → 标记为"研究空白"
## 输出格式(严格 JSON)
{
"overall_score": 85,
"direction_audit": {
"passed": true,
"issues": [],
"subfield_score": 8,
"trend_score": 7,
"family_score": 9
},
"evaluation_audit": {
"passed": true,
"issues": [],
"lazy_repos": [],
"avg_reasoning_length": 350
},
"source_audit": {
"passed": true,
"issues": [],
"unreliable_repos": []
},
"actions": []
}"""
def audit_direction(title: str, abstract: str, direction: dict) -> dict:
    """Rule-based audit of Agent 1's direction analysis.

    Checks that the subfield is specific, the trend analysis has real
    depth, and every method family is described and mapped to repos.
    ``title`` and ``abstract`` are part of the audit interface but are
    not consulted by the current rule checks.

    Returns:
        dict: {passed, issues, subfield_score, trend_score, family_score}
    """
    problems: list[str] = []
    scores = {"subfield": 10, "trend": 10, "family": 10}

    subfield = direction.get("subfield", "")
    trend_text = direction.get("subfield_trend", "")
    method_families = direction.get("method_families", [])

    # --- subfield specificity ---
    generic_terms = ("deep learning", "machine learning", "ai", "computer vision", "nlp")
    if not subfield or subfield == "未知":
        problems.append("子领域为空")
        scores["subfield"] = 0
    elif len(subfield) < 6 or subfield.lower() in generic_terms:
        problems.append(f"子领域过于笼统: {subfield}")
        scores["subfield"] = 3

    # --- trend-analysis depth (normalize Chinese full stops before splitting) ---
    sentence_list = [part.strip() for part in trend_text.replace("。", ".").split(".") if part.strip()]
    if len(sentence_list) < 2:
        problems.append(f"趋势分析过于简短,仅 {len(sentence_list)} 句")
        scores["trend"] = 3
    elif len(trend_text) < 80:
        problems.append(f"趋势分析不足 80 字符")
        scores["trend"] = 4

    # --- method families: each needs a substantive description and repos ---
    if not method_families:
        problems.append("未识别出任何方法族")
        scores["family"] = 0
    else:
        for family in method_families:
            if len(family.get("description", "")) < 15:
                problems.append(f"方法族 '{family.get('family_name', '?')}' 描述过于简短")
                scores["family"] = min(scores["family"], 5)
            if not family.get("matched_repos"):
                problems.append(f"方法族 '{family.get('family_name', '?')}' 无归属仓库")

    # Pass when clean, or when the combined score clears the bar anyway.
    combined = scores["subfield"] + scores["trend"] + scores["family"]
    return {
        "passed": not problems or combined >= 20,
        "issues": problems,
        "subfield_score": scores["subfield"],
        "trend_score": scores["trend"],
        "family_score": scores["family"],
    }
def audit_evaluations(repos: list[dict]) -> dict:
    """Audit Agent 2's per-repository evaluation output.

    Detects lazy behavior (reasoning too short, too few risks) and
    internally inconsistent scores. Explicit ``null`` values in the
    evaluation JSON are tolerated and treated like missing fields —
    previously ``len(None)`` / ``None >= 10`` would raise ``TypeError``.

    Args:
        repos: repo dicts, each optionally carrying an "evaluation" dict.

    Returns:
        dict: {passed, issues, lazy_repos, avg_reasoning_length}
    """
    issues = []
    lazy_repos = []
    reasoning_lengths = []
    for r in repos:
        # `or` guards: LLM-produced JSON may carry explicit nulls,
        # which .get(key, default) does not convert to the default.
        ev = r.get("evaluation") or {}
        full_name = r.get("full_name", "?")
        reasoning = ev.get("reasoning") or ""
        risks = ev.get("risks") or []
        reasoning_lengths.append(len(reasoning))
        is_lazy = False
        repo_issues = []
        # reasoning length: < 80 chars is lazy, < 200 is merely short
        if len(reasoning) < 80:
            repo_issues.append(f"reasoning 仅 {len(reasoning)} 字符")
            is_lazy = True
        elif len(reasoning) < 200:
            repo_issues.append(f"reasoning 偏短 ({len(reasoning)} 字符)")
        # risks: fewer than two named risks counts as lazy
        if not risks or len(risks) < 2:
            repo_issues.append(f"risks 仅 {len(risks)} 个")
            is_lazy = True
        # score self-consistency
        env_score = ev.get("env_score") or 0
        doc_score = ev.get("doc_score") or 0
        if env_score >= 10 and doc_score <= 2:
            repo_issues.append("env_score 高但 doc_score 低,可能矛盾")
        overall = ev.get("overall_score") or 0
        stars = r.get("stars") or 0
        if overall >= 80 and stars < 10:
            repo_issues.append(f"高评分 ({overall}) 但仅 {stars} Stars,值得怀疑")
        if is_lazy:
            lazy_repos.append(full_name)
        if repo_issues:
            issues.append(f"[{full_name}] " + "; ".join(repo_issues))
    # max(..., 1) avoids ZeroDivisionError on an empty repo list
    avg_len = int(sum(reasoning_lengths) / max(len(reasoning_lengths), 1))
    passed = len(lazy_repos) == 0 and len(issues) <= 1
    return {
        "passed": passed,
        "issues": issues,
        "lazy_repos": lazy_repos,
        "avg_reasoning_length": avg_len,
    }
def audit_sources(repos: list[dict]) -> dict:
    """Audit the reliability of each repository as an information source.

    Flags low-star, stale, undocumented, or dependency-less repos.
    Repos with >= 100 stars are trusted and skip all checks.

    The staleness cutoff is computed as a rolling "two years before
    today" instead of the previous hard-coded "2024-01-01", which made
    the "超过 2 年未维护" message drift wrong as time passed.

    Args:
        repos: repo dicts with stars / updated_at / readme / dependencies.

    Returns:
        dict: {passed, issues, unreliable_repos}
    """
    issues = []
    unreliable = []
    # ISO-8601 date strings order correctly under plain string comparison,
    # so comparing against a formatted cutoff date is sufficient.
    stale_cutoff = (datetime.now() - timedelta(days=2 * 365)).strftime("%Y-%m-%d")
    for r in repos:
        full_name = r.get("full_name", "?")
        stars = r.get("stars", 0)
        updated = r.get("updated_at", "")
        readme = r.get("readme", "")
        deps = r.get("dependencies", {})
        is_unreliable = False
        # High-star repos skip the basic checks entirely
        if stars >= 100:
            continue
        if stars < 5:
            issues.append(f"[{full_name}] 仅 {stars} Stars,低影响力")
            is_unreliable = True
        if updated and updated < stale_cutoff:
            issues.append(f"[{full_name}] 最后更新 {updated[:10]},超过 2 年未维护")
            is_unreliable = True
        if not readme or len(str(readme)) < 100:
            issues.append(f"[{full_name}] README 缺失或过短")
            is_unreliable = True
        if not deps:
            issues.append(f"[{full_name}] 无依赖文件")
            is_unreliable = True
        if is_unreliable:
            unreliable.append(full_name)
    # Tolerate up to one third of the repos being low quality
    passed = len(unreliable) <= len(repos) // 3
    return {
        "passed": passed,
        "issues": issues,
        "unreliable_repos": unreliable,
    }
def supervise(title: str, abstract: str, direction: dict, repos: list[dict]) -> dict:
    """Run every audit and assemble the overall quality report.

    Combines the three audits (direction, evaluation, source) into a
    0-100 weighted score, a list of suggested follow-up actions, and a
    one-line human-readable summary.

    Returns:
        dict: {overall_score, direction_audit, evaluation_audit,
               source_audit, actions, summary}
    """
    direction_audit = audit_direction(title, abstract, direction)
    evaluation_audit = audit_evaluations(repos)
    source_audit = audit_sources(repos)

    # Weighted composite score: direction 40%, evaluation 30%, source 30%.
    direction_avg = (
        direction_audit["subfield_score"]
        + direction_audit["trend_score"]
        + direction_audit["family_score"]
    ) / 3
    evaluation_penalty = len(evaluation_audit["lazy_repos"]) * 3 + len(evaluation_audit["issues"])
    evaluation_avg = 10 - min(10, evaluation_penalty)
    source_avg = 10 - min(10, len(source_audit["unreliable_repos"]) * 2)
    overall = int(0.4 * direction_avg * 10 + 0.3 * evaluation_avg * 10 + 0.3 * source_avg * 10)

    # Improvement suggestions, one per failing audit (top-3 items each).
    actions = []
    if direction_audit["issues"]:
        actions.append(
            f"方向分析存在问题: {'; '.join(direction_audit['issues'][:3])}。建议调整 Agent 1 温度参数重试。"
        )
    if evaluation_audit["lazy_repos"]:
        actions.append(
            f"以下仓库的评估疑似偷懒: {', '.join(evaluation_audit['lazy_repos'][:3])}。建议重跑 Agent 2。"
        )
    if source_audit["unreliable_repos"]:
        actions.append(
            f"以下仓库来源可靠性低: {', '.join(source_audit['unreliable_repos'][:3])}。考虑降低其权重。"
        )

    # Human-readable summary keyed off the total issue count.
    issue_count = (
        len(direction_audit["issues"])
        + len(evaluation_audit["issues"])
        + len(source_audit["issues"])
    )
    if issue_count == 0:
        summary = "✅ 所有审核通过,报告质量良好。"
    elif issue_count <= 2:
        summary = f"⚠️ 发现 {issue_count} 个小问题,不影响整体质量。"
    else:
        summary = f"🔴 发现 {issue_count} 个问题,建议关注改进建议。"

    return {
        "overall_score": overall,
        "direction_audit": direction_audit,
        "evaluation_audit": evaluation_audit,
        "source_audit": source_audit,
        "actions": actions,
        "summary": summary,
    }
# ============================================================
# Self-test: run this file directly to exercise the audits
# against hand-crafted mock data.
# ============================================================
if __name__ == "__main__":
    fix_windows_encoding()
    # Mock direction analysis: specific subfield, substantive trend text,
    # and two method families — the second deliberately has no matched
    # repos, so the direction audit should flag a "research gap" issue.
    mock_direction = {
        "subfield": "工业图像异常检测",
        "subfield_trend": "2024-2025年该领域主流趋势包括:1) 从基于重建的方法转向基于嵌入的方法,如PatchCore、PaDiM等利用预训练CNN提取特征;2) 多模态方法的兴起,如AnomalyGPT和WinCLIP结合视觉-语言模型;3) 从单类检测向多类统一检测发展。活跃研究组包括AWS、Intel OpenVINO团队、MVTec等。",
        "method_families": [
            {
                "family_name": "Patch Distribution Modeling",
                "description": "利用预训练CNN提取图像块级特征,建模多元高斯分布,通过马氏距离计算异常分数。优势在于无需训练、推理速度快,适用于工业部署场景。",
                "representative_work": "PaDiM (ICPR 2021)",
                "matched_repos": ["openvinotoolkit/anomalib", "xiahaifeng1995/PaDiM-Anomaly-Detection"],
                "search_queries": ["padim anomaly detection pytorch"],
            },
            {
                "family_name": "Memory Bank",
                "description": "构建正常样本的特征记忆库,测试时通过最近邻检索判断异常。优势是可解释性强,但内存开销大。",
                "representative_work": "PatchCore (CVPR 2022)",
                "matched_repos": [],
                "search_queries": ["patchcore anomaly detection pytorch"],
            },
        ],
        "broad_queries": ["anomaly detection pytorch benchmark", "industrial defect detection deep learning"],
    }
    # Two contrasting repos: a healthy, thoroughly evaluated one, and a
    # tiny stale one with a lazy evaluation (one-word reasoning, no
    # risks) that should trip both the evaluation and source audits.
    mock_repos = [
        {
            "full_name": "openvinotoolkit/anomalib",
            "stars": 4000, "updated_at": "2026-01-15",
            "readme": "# Anomalib\nA library for anomaly detection...",
            "dependencies": {"requirements.txt": "torch>=1.10"},
            "evaluation": {
                "reasoning": "【环境配置】提供 requirements.txt 含 torch>=1.10...【文档】README 含 pip install 步骤...【代码】提供 Engine 类封装训练流程...【社区】4000 Stars...", # noqa
                "risks": ["部分依赖版本号使用>=范围", "部分模型预训练权重需要单独下载", "仅支持图像检测"],
                "overall_score": 93, "env_score": 14, "doc_score": 18, "code_score": 18,
                "community_score": 10, "dep_score": 15, "benchmark_score": 18,
                "verdict": "reproducible", "benchmark_readiness": "ready",
                "suggested_use": "可直接 pip install anomalib 安装,使用 tools/benchmark.py 评估",
            },
        },
        {
            "full_name": "someone/tiny-demo",
            "stars": 3, "updated_at": "2023-01-01",
            "readme": "",
            "dependencies": {},
            "evaluation": {
                "reasoning": "还行",
                "risks": [],
                "overall_score": 80, "env_score": 15, "doc_score": 15,
                "benchmark_score": 20,
                "suggested_use": "可以用来跑对比实验",
            },
        },
    ]
    print("=" * 60)
    print("Supervisor Agent 自测")
    print("=" * 60)
    result = supervise("PaDiM: Patch Distribution Modeling", "anomaly detection", mock_direction, mock_repos)
    print(f"综合评分: {result['overall_score']}/100")
    print(f"摘要: {result['summary']}")
    print(f"\n方向审核: {'✅' if result['direction_audit']['passed'] else '❌'}")
    for issue in result["direction_audit"]["issues"]:
        print(f" - {issue}")
    print(f"\n评估审核: {'✅' if result['evaluation_audit']['passed'] else '❌'}")
    print(f" 疑似偷懒: {result['evaluation_audit']['lazy_repos']}")
    print(f" 平均 reasoning 长度: {result['evaluation_audit']['avg_reasoning_length']} 字符")
    print(f"\n来源审核: {'✅' if result['source_audit']['passed'] else '❌'}")
    print(f" 不可靠来源: {result['source_audit']['unreliable_repos']}")
    print(f"\n改进建议:")
    for action in result["actions"]:
        print(f" - {action}")
|