# Threat_Hunter / tools/exploit_tool.py
# EricChen2005 — commit c8d30bc: "Deploy ThreatHunter - AMD MI300X + Qwen2.5-32B"
# tools/exploit_tool.py
# 功能:GitHub Exploit / PoC 搜尋 Tool
# Harness 支柱:Graceful Degradation(降級瀑布)+ Observability(原子化日誌)
# 擁有者:成員 C(Analyst Agent Pipeline)
#
# 使用方式:
# from tools.exploit_tool import search_exploits
#
# 架構定位:
# Analyst Agent 的「第二隻手」— 搜尋 CVE 的公開 Exploit/PoC
# 有公開 exploit = 攻擊門檻極低 = 風險指標 HIGH
import json
import os
import time
import hashlib
import logging
from datetime import datetime, timezone
import requests
logger = logging.getLogger("ThreatHunter")

# ══════════════════════════════════════════════════════════════
# Constants
# ══════════════════════════════════════════════════════════════
GITHUB_API_BASE = "https://api.github.com/search/repositories"
# Cap the number of repositories per search so the agent's context stays short.
RESULTS_PER_PAGE = 10
REQUEST_TIMEOUT = 20  # seconds

# Minimum spacing between GitHub requests, in seconds.
# GitHub Search API quota: 10 req/min unauthenticated, 30 req/min with a token.
RATE_LIMIT_WITH_TOKEN = 60.0 / 30  # 30 req/min -> one request every 2.0s
RATE_LIMIT_WITHOUT_TOKEN = 60.0 / 10  # 10 req/min -> one request every 6.0s
MAX_RETRIES = 2

# Offline cache: <project root>/data, one level above this module's directory.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
CACHE_DIR = os.path.join(_PROJECT_ROOT, "data")
CACHE_TTL = 24 * 60 * 60  # entries expire after 24 hours

# Time of the previous GitHub request (module-wide rate-limiter state).
_last_request_time = 0.0
# ══════════════════════════════════════════════════════════════
# Helpers
# ══════════════════════════════════════════════════════════════
def _get_github_token() -> str:
    """Return the GitHub API token, or an empty string when none is configured.

    A truthy ``GITHUB_TOKEN`` from the project ``config`` module takes
    precedence; if ``config`` cannot be imported or its value is falsy, the
    ``GITHUB_TOKEN`` environment variable is used instead.
    """
    configured = ""
    try:
        from config import GITHUB_TOKEN as configured
    except ImportError:
        pass
    return configured or os.getenv("GITHUB_TOKEN", "")
def _get_cache_path(cve_id: str) -> str:
    """Return the offline-cache file path for *cve_id*.

    The file name contains a filesystem-safe copy of the ID plus the first
    12 hex digits of the MD5 of the raw ID:

    - ASCII letters, digits, "-" and "_" are kept; every other character
      becomes "_"; the result is capped at 64 characters.
    - The sanitizing keeps path separators or ".." in agent-supplied input
      from escaping CACHE_DIR, and avoids characters some filesystems reject.
    - The hash keeps distinct IDs that sanitize to the same text apart.

    Well-formed IDs such as "CVE-2021-44228" pass through the sanitizer
    unchanged, so their cache file names stay the same.
    """
    # MD5 is only a short, stable file-name discriminator, not a security
    # primitive; declaring that keeps it usable on FIPS-restricted builds.
    digest = hashlib.md5(cve_id.encode("utf-8"), usedforsecurity=False).hexdigest()[:12]
    safe_id = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "-_" else "_"
        for ch in cve_id
    )[:64]
    return os.path.join(CACHE_DIR, f"exploit_cache_{safe_id}_{digest}.json")
def _read_cache(cve_id: str) -> dict | None:
    """Return the cached exploit result for *cve_id*, or None.

    None is returned when:

    - no cache file exists;
    - the file cannot be read or decoded;
    - its content is not a JSON object;
    - the entry is at least CACHE_TTL seconds old.

    Read problems are logged, never raised, so the caller's degradation
    path keeps working.
    """
    cache_path = _get_cache_path(cve_id)
    try:
        with open(cache_path, "r", encoding="utf-8") as f:
            cached = json.load(f)
    except FileNotFoundError:
        return None  # no cache yet: not worth a warning
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("[WARN] Exploit cache read failed: %s", e)
        return None

    if not isinstance(cached, dict):
        logger.warning("[WARN] Exploit cache read failed: %s", "cache content is not a JSON object")
        return None

    cached_time = cached.get("_cached_at", 0)
    if not isinstance(cached_time, (int, float)):
        cached_time = 0  # a malformed timestamp is treated as expired
    if time.time() - cached_time < CACHE_TTL:
        logger.info("[OK] Exploit cache hit: %s", cve_id)
        return cached

    logger.info("[INFO] Exploit cache expired: %s", cve_id)
    return None
def _write_cache(cve_id: str, data: dict) -> None:
    """Persist *data* as the offline cache entry for *cve_id* (best effort).

    The timestamp goes on a shallow copy (``_cached_at`` = now) rather than
    on *data*. Stamping the caller's dict in place would leak the internal
    ``_cached_at`` field into the JSON the tool returns.

    The JSON is written to a ``.tmp`` sibling and moved into place with
    os.replace, so an interrupted write cannot leave a truncated cache file.
    Failures are logged and swallowed; caching must never break a search.
    """
    payload = {**data, "_cached_at": time.time()}
    cache_path = _get_cache_path(cve_id)
    tmp_path = cache_path + ".tmp"
    try:
        os.makedirs(CACHE_DIR, exist_ok=True)
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, cache_path)
    except (OSError, TypeError, ValueError) as e:
        # OSError: filesystem problems; TypeError/ValueError: unserializable data.
        logger.warning("[WARN] Exploit cache write failed: %s", e)
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file was never created or cannot be removed; stay best-effort
def _rate_limit() -> None:
    """Block until enough time has passed since the previous GitHub request.

    The minimum spacing is RATE_LIMIT_WITH_TOKEN seconds when a GitHub token
    is configured, otherwise RATE_LIMIT_WITHOUT_TOKEN.

    Elapsed time is measured with time.monotonic(), so wall-clock changes
    (NTP corrections, manual adjustments) cannot produce a negative elapsed
    time and an oversized sleep. The state lives in the module-level
    ``_last_request_time``; its initial value 0.0 means "no request yet".
    """
    global _last_request_time
    interval = RATE_LIMIT_WITH_TOKEN if _get_github_token() else RATE_LIMIT_WITHOUT_TOKEN
    if _last_request_time:
        wait = interval - (time.monotonic() - _last_request_time)
        if wait > 0:
            logger.info("[WAIT] GitHub rate limit: waiting %.1fs", wait)
            time.sleep(wait)
    _last_request_time = time.monotonic()
def _classify_exploit_type(name: str, description: str) -> str:
    """Classify a repository as "scanner", "weaponized" or "poc".

    The repo name and description are joined, lower-cased, and searched for
    keyword substrings:

    - "scanner": scanner / detect / check / scan / checker / finder.
    - "weaponized": weaponized / exploit-db / metasploit / payload /
      shellcode / reverse shell.
    - "poc": anything matching neither group (the default).

    The scanner group is checked first, so a repo matching both groups is
    reported as "scanner".
    """
    haystack = f"{name} {description}".lower()
    keyword_groups = (
        ("scanner", ("scanner", "detect", "check", "scan", "checker", "finder")),
        ("weaponized", ("weaponized", "exploit-db", "metasploit", "payload", "shellcode", "reverse shell")),
    )
    for label, keywords in keyword_groups:
        if any(keyword in haystack for keyword in keywords):
            return label
    return "poc"
def _determine_risk_indicator(exploit_count: int, api_available: bool) -> str:
    """Map search results to a risk indicator string.

    Returns:
        "UNKNOWN (API limited)" when the API was not available, whatever the
        count; otherwise "HIGH" when at least one public exploit was found
        (a public exploit lowers the attack barrier) and "LOW" when none was.
    """
    if api_available:
        return "HIGH" if exploit_count > 0 else "LOW"
    return "UNKNOWN (API limited)"
# ══════════════════════════════════════════════════════════════
# Core query logic
# ══════════════════════════════════════════════════════════════
def _query_github_api(cve_id: str) -> dict | None:
    """Search GitHub repositories for public exploits of *cve_id*.

    Returns:
        The decoded JSON object from the GitHub Search API, or None when all
        attempts fail, the API answers with a non-retryable status, or the
        body is not a JSON object.

    Behavior:
      - every attempt is spaced by _rate_limit();
      - up to MAX_RETRIES attempts; 403/429 (rate limited), 5xx, timeouts
        and connection errors are retried, other request errors are not;
      - a Retry-After value that is not an integer number of seconds (e.g.
        the HTTP-date form) falls back to 60s, and the wait is clamped to
        0..30s (a negative value would make time.sleep raise);
      - no back-off sleep after the final attempt, since nothing follows.
    """
    token = _get_github_token()
    headers = {"Accept": "application/vnd.github.v3+json"}
    if token:
        headers["Authorization"] = f"token {token}"

    # Search strategy: the CVE ID as keyword, narrowed with "exploit".
    params = {
        "q": f"{cve_id} exploit",
        "sort": "stars",
        "order": "desc",
        "per_page": RESULTS_PER_PAGE,
    }

    for attempt in range(1, MAX_RETRIES + 1):
        has_next_attempt = attempt < MAX_RETRIES
        _rate_limit()
        try:
            logger.info("[QUERY] GitHub API search: %s (attempt %d)", cve_id, attempt)
            response = requests.get(
                GITHUB_API_BASE,
                params=params,
                headers=headers,
                timeout=REQUEST_TIMEOUT,
            )
        except requests.exceptions.Timeout:
            logger.warning("[WARN] GitHub API timeout (%ds)", REQUEST_TIMEOUT)
            continue
        except requests.exceptions.ConnectionError:
            logger.warning("[WARN] GitHub API connection failed (network issue)")
            continue
        except requests.exceptions.RequestException as e:
            logger.warning("[WARN] GitHub API request error: %s", e)
            return None

        status = response.status_code

        if status == 200:
            try:
                payload = response.json()
            except ValueError as e:
                # response.json() raises a ValueError subclass on a malformed body.
                logger.warning("[WARN] GitHub API returned invalid JSON: %s", e)
                return None
            if not isinstance(payload, dict):
                logger.warning(
                    "[WARN] GitHub API returned unexpected payload type: %s",
                    type(payload).__name__,
                )
                return None
            return payload

        if status in (403, 429):
            # Rate limit triggered.
            retry_after = response.headers.get("Retry-After", "60")
            logger.warning(
                "[WARN] GitHub API %d (rate limited), Retry-After: %ss",
                status, retry_after,
            )
            if has_next_attempt:
                try:
                    delay = int(retry_after)
                except (TypeError, ValueError):
                    delay = 60  # HTTP-date or garbage: use the default
                time.sleep(min(max(delay, 0), 30))  # wait at most 30 seconds
            continue

        if status >= 500:
            logger.warning("[WARN] GitHub API %d (server error)", status)
            if has_next_attempt:
                time.sleep(2)
            continue

        # Any other status code is treated as non-retryable.
        logger.warning("[WARN] GitHub API returned %d: %s", status, response.text[:200])
        return None

    return None  # every attempt failed
def _parse_github_response(raw: dict, cve_id: str) -> dict:
    """Convert a GitHub Search API response into this tool's output format.

    Field mapping for each entry of ``raw["items"]``:
        full_name        -> repo_name
        html_url         -> url
        stargazers_count -> stars
        language         -> language ("Unknown" when missing)
        updated_at       -> last_updated (date part only)
        description      -> description (truncated to 300 chars)

    Handling of missing data:
      - Fields may arrive as JSON null; each one is coalesced to a
        type-stable default ("" / 0 / "Unknown").
      - Entries that are not JSON objects are skipped.
      This way a single sparse repository cannot abort the whole parse.

    Returns:
        dict with keys cve_id, source, exploit_count, exploits and
        risk_indicator; risk_indicator is computed with the API treated as
        available.
    """
    exploits = []
    for repo in raw.get("items") or []:
        if not isinstance(repo, dict):
            continue
        repo_name = repo.get("full_name") or ""
        description = repo.get("description") or ""
        exploits.append({
            "repo_name": repo_name,
            "url": repo.get("html_url") or "",
            "stars": repo.get("stargazers_count") or 0,
            "language": repo.get("language") or "Unknown",
            # Keep only the date portion of an ISO-8601 timestamp.
            "last_updated": (repo.get("updated_at") or "").partition("T")[0],
            "description": description[:300],  # truncate long descriptions
            "type": _classify_exploit_type(repo_name, description),
        })

    exploit_count = len(exploits)
    return {
        "cve_id": cve_id,
        "source": "GitHub API",
        "exploit_count": exploit_count,
        "exploits": exploits,
        "risk_indicator": _determine_risk_indicator(exploit_count, api_available=True),
    }
def _no_data_result(cve_id: str, source: str, risk_indicator: str, error: str) -> str:
    """Serialize a zero-exploit result that carries an *error* explanation."""
    return json.dumps({
        "cve_id": cve_id,
        "source": source,
        "exploit_count": 0,
        "exploits": [],
        "risk_indicator": risk_indicator,
        "error": error,
    }, ensure_ascii=False, indent=2)


def _search_exploits_impl(cve_id: str) -> str:
    """Core implementation behind the ``search_exploits`` CrewAI tool.

    Kept separate from the @tool wrapper so it can be unit-tested directly.
    Always returns a JSON string and never raises.

    Degradation waterfall:
      1. Query the GitHub Search API; on success, cache the result.
      2. If the API fails, fall back to a non-expired offline cache entry.
      3. With no cache either, report exploit_count 0 and risk_indicator
         "UNKNOWN (API limited)".
      4. Any unexpected error yields a safe default result.

    Args:
        cve_id: CVE identifier such as "CVE-2021-44228". Surrounding
            whitespace is stripped and the ID is upper-cased. None or a
            blank string yields a "No CVE ID provided" result; other
            non-string values are converted with str().
    """
    # Bound before the try so the error path can always report it.
    normalized = ""
    try:
        normalized = "" if cve_id is None else str(cve_id).strip().upper()
        if not normalized:
            logger.warning("[WARN] Exploit Tool received empty CVE ID input")
            return _no_data_result("", "GitHub API (error)", "UNKNOWN", "No CVE ID provided")

        logger.info("[QUERY] Exploit search: %s", normalized)

        # Try the GitHub API first.
        raw = _query_github_api(normalized)
        if raw is not None:
            result = _parse_github_response(raw, normalized)
            # Cache a copy: the cache writer may stamp bookkeeping fields
            # (e.g. _cached_at) that must not appear in the tool output.
            _write_cache(normalized, dict(result))
            logger.info(
                "[OK] Exploit search success: %s -> %d exploit(s), risk_indicator=%s",
                normalized, result['exploit_count'], result['risk_indicator']
            )
            return json.dumps(result, ensure_ascii=False, indent=2)

        # API failed: fall back to the offline cache.
        cached = _read_cache(normalized)
        if cached:
            cached.pop("_cached_at", None)
            cached["source"] = "GitHub API (cache)"
            cached["error"] = f"GitHub API unavailable, using cached data for '{normalized}'"
            logger.info("[OK] Exploit using cache: %s", normalized)
            return json.dumps(cached, ensure_ascii=False, indent=2)

        # No data at all.
        logger.info("[INFO] Exploit no data for: %s", normalized)
        return _no_data_result(
            normalized,
            "GitHub API (unavailable)",
            "UNKNOWN (API limited)",
            f"GitHub API unavailable and no cache for '{normalized}'",
        )
    except Exception as e:
        # Last line of defense: an unexpected error must never crash the agent.
        logger.error("[FAIL] Exploit Tool unexpected error: %s", e, exc_info=True)
        return _no_data_result(normalized, "GitHub API (error)", "UNKNOWN", f"Unexpected error: {e}")
# ══════════════════════════════════════════════════════════════
# CrewAI @tool wrapper (used when the Agent calls the tool)
# ══════════════════════════════════════════════════════════════
# ⚠️ Important: uses a lazy-loading pattern (LazyToolLoader).
# Reason: avoid triggering CrewAI's tool registration at import time.
def _create_tool():
    """Build the CrewAI ``search_exploits`` tool.

    ``crewai`` is imported inside this function, so the dependency is only
    loaded when the tool is first requested, not when this module is imported.

    Returns:
        The object produced by crewai's ``@tool("search_exploits")``
        decorator, wrapping a function that delegates to
        ``_search_exploits_impl``.
    """
    from crewai.tools import tool

    # NOTE(review): the inner docstring is presumably read by CrewAI as the
    # tool description shown to the agent (TODO confirm). Treat it as
    # runtime text and do not translate or edit it casually.
    @tool("search_exploits")
    def search_exploits(cve_id: str) -> str:
        """搜尋指定 CVE 的公開 Exploit 和 PoC 程式碼(透過 GitHub API)。
        輸入單一 CVE ID(如 "CVE-2021-44228"),回傳公開 exploit 的數量、連結、星數等資訊。
        有公開 exploit = 攻擊門檻極低 = 風險指標 HIGH。"""
        return _search_exploits_impl(cve_id)

    return search_exploits
# ── Lazy-loading mechanism (same pattern as nvd_tool.py) ──────────────────
class _LazyToolLoader:
    """Build the CrewAI ``search_exploits`` tool on first access, then reuse it."""

    def __init__(self):
        # Cached tool instance; stays None until first requested.
        self._tool = None

    def _load(self):
        """Create the tool via _create_tool() unless it already exists."""
        if self._tool is not None:
            return
        self._tool = _create_tool()

    @property
    def search_exploits(self):
        """The CrewAI tool, created lazily on first access."""
        self._load()
        return self._tool


_loader = _LazyToolLoader()
def __getattr__(name):
    """Module-level attribute hook (PEP 562).

    Lets ``from tools.exploit_tool import search_exploits`` resolve to the
    lazily built CrewAI tool; any other missing attribute raises
    AttributeError as usual.
    """
    if name != "search_exploits":
        raise AttributeError(f"module 'tools.exploit_tool' has no attribute {name!r}")
    return _loader.search_exploits