# Threat_Hunter / tools/nvd_tool.py
# Last commit by EricChen2005 (c8d30bc): "Deploy ThreatHunter - AMD MI300X + Qwen2.5-32B"
# tools/nvd_tool.py
# Purpose: NVD (National Vulnerability Database) vulnerability lookup tool
# Harness pillars: Graceful Degradation (five-tier fallback waterfall) + Observability (atomic logging)
# Owner: Member B (Scout Agent Pipeline)
#
# Usage:
#   from tools.nvd_tool import search_nvd
#
# Architectural role:
#   The Scout Agent's "hands" - queries the NVD API to obtain a CVE list.
#   The agent decides on its own, via its ReAct loop, when to call this tool.
import json
import os
import time
import hashlib
import logging
from datetime import datetime, timezone
import requests
# Shared "ThreatHunter" logger used by every function in this module.
logger = logging.getLogger("ThreatHunter")
# ======================================================================
# Constants
# ======================================================================
# NVD CVE API 2.0 endpoint
NVD_API_BASE = "https://services.nvd.nist.gov/rest/json/cves/2.0"
RESULTS_PER_PAGE = 10  # The agent's context is limited; too many CVEs make the LLM ignore the tool output
REQUEST_TIMEOUT = 30  # seconds, passed as the HTTP request timeout
# Rate-limit control: minimum spacing between consecutive NVD requests
RATE_LIMIT_WITH_KEY = 0.6  # With an API key: 50 req / 30 s -> 0.6 s spacing
RATE_LIMIT_WITHOUT_KEY = 6.0  # Without an API key: 5 req / 30 s -> 6 s spacing
# Total attempts per query: the query loops run attempt = 1..MAX_RETRIES
# (so this is not "retries after the first attempt").
MAX_RETRIES = 2
# Offline cache: a "data" directory inside the parent of the directory that
# contains this file.
CACHE_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
CACHE_TTL = 3600 * 24  # 24-hour expiry, in seconds
# Package-name alias map: maps an input name to its preferred NVD query name
PACKAGE_MAP_PATH = os.path.join(CACHE_DIR, "package_map.json")
# Timestamp of the most recent NVD request (module-level rate-limiter state);
# 0.0 means no request has been made yet.
_last_request_time = 0.0
# ======================================================================
# Helper functions
# ======================================================================
def _load_package_map() -> dict:
    """Load the package-name alias map from ``PACKAGE_MAP_PATH``.

    The file is expected to hold a JSON object mapping an input name to its
    preferred query name (e.g. ``{"postgres": "postgresql"}``).

    Returns:
        The parsed mapping, or an empty dict when the file is missing,
        unreadable, not valid UTF-8 JSON, or not a JSON object.
    """
    if not os.path.exists(PACKAGE_MAP_PATH):
        return {}
    try:
        with open(PACKAGE_MAP_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (ValueError, OSError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("[WARN] package_map.json load failed: %s", e)
        return {}
    if not isinstance(data, dict):
        # Callers do key lookups and .items(); anything but an object is unusable.
        logger.warning("[WARN] package_map.json is not a JSON object (got %s)", type(data).__name__)
        return {}
    return data
def _normalize_package_name(raw_name: str) -> list[str]:
    """Normalize a user-supplied package name into candidate NVD query names.

    The first candidate is the most likely name; later entries are aliases to
    try if the first yields nothing.

    Examples (given a package map containing ``{"postgres": "postgresql"}``):
        "postgres"   -> ["postgresql", "postgres"]
        "django 4.2" -> ["django"]   (when no alias targets "django")

    Returns:
        A non-empty list of lower-cased candidate names.
    """
    # Keep only the first whitespace-separated token, which drops a trailing
    # version ("django 4.2" -> "django"); split() also handles tabs/newlines.
    tokens = raw_name.strip().lower().split()
    name = tokens[0] if tokens else ""

    pkg_map = _load_package_map()
    mapped = pkg_map.get(name)
    # Only accept usable string targets; anything else would later break the
    # cache-path hashing (.encode()) and the NVD query.
    if isinstance(mapped, str) and mapped:
        candidates = [mapped]
        if mapped != name:
            candidates.append(name)
        return candidates

    candidates = [name]
    # Reverse lookup: also try any alias that points at this (canonical) name.
    for alias, target in pkg_map.items():
        if target == name and alias not in candidates:
            candidates.append(alias)
    return candidates
def _get_cache_path(package_name: str) -> str:
    """Return the offline-cache file path for ``package_name`` inside ``CACHE_DIR``.

    The filename embeds a filesystem-safe copy of the name (for readability,
    truncated to 64 characters) plus a short MD5 of the exact name (for
    uniqueness). Names made only of ASCII letters, digits, ``.``, ``_`` and
    ``-`` produce the same filename as before, so existing cache files remain
    valid.
    """
    # Package names come from agent/LLM input. A "/" or "\\" would otherwise
    # create (or fail on) sub-directories below CACHE_DIR, and characters such
    # as ":" or "*" are invalid in Windows filenames.
    readable = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "._-" else "_"
        for ch in package_name
    )[:64]
    # MD5 is only a filename disambiguator here, not a security primitive.
    digest = hashlib.md5(package_name.encode("utf-8"), usedforsecurity=False).hexdigest()[:12]
    return os.path.join(CACHE_DIR, f"nvd_cache_{readable}_{digest}.json")
def _read_cache(package_name: str, allow_stale: bool = False) -> dict | None:
    """Read the offline cache entry for ``package_name``.

    Args:
        package_name: Candidate name used as the cache key.
        allow_stale: When True, return the entry even if it is older than
            ``CACHE_TTL`` (used as a last-resort fallback when NVD is down).

    Returns:
        The cached dict (still containing ``_cached_at``), or None when the
        entry is missing, expired (and ``allow_stale`` is False), unreadable,
        or malformed.
    """
    cache_path = _get_cache_path(package_name)
    if not os.path.exists(cache_path):
        return None
    try:
        with open(cache_path, "r", encoding="utf-8") as f:
            cached = json.load(f)
    except (ValueError, OSError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("[WARN] NVD cache read failed: %s", e)
        return None
    if not isinstance(cached, dict):
        logger.warning("[WARN] NVD cache malformed (not a JSON object): %s", package_name)
        return None

    try:
        cached_time = float(cached.get("_cached_at", 0))
    except (TypeError, ValueError):
        cached_time = 0.0  # unknown age: treat as expired
    cache_age = time.time() - cached_time

    if cache_age < CACHE_TTL:
        logger.info("[OK] NVD cache hit: %s", package_name)
        return cached
    if allow_stale:
        logger.warning("[WARN] NVD stale cache fallback: %s (age=%.0fs)", package_name, cache_age)
        return cached
    logger.info("[INFO] NVD cache expired: %s", package_name)
    return None
def _write_cache(package_name: str, data: dict) -> None:
    """Persist ``data`` as the offline cache entry for ``package_name``.

    A ``_cached_at`` timestamp is stored alongside the payload, but the
    caller's dict is NOT modified. The JSON is written to a temporary sibling
    file and then swapped in with ``os.replace`` so a reader never sees a
    half-written cache file. Caching is best-effort: failures are logged and
    swallowed.
    """
    cache_path = _get_cache_path(package_name)
    tmp_path = f"{cache_path}.{os.getpid()}.tmp"
    # Copy instead of mutating: callers serialize `data` right after caching
    # it, and the internal timestamp must not leak into the tool output.
    payload = {**data, "_cached_at": time.time()}
    try:
        os.makedirs(CACHE_DIR, exist_ok=True)
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, cache_path)
    except (OSError, TypeError, ValueError) as e:
        # TypeError/ValueError: payload not JSON-serializable (or circular).
        logger.warning("[WARN] NVD cache write failed: %s", e)
        try:
            os.remove(tmp_path)  # best-effort cleanup; the file may not exist
        except OSError:
            pass
def _rate_limit() -> None:
    """Block until enough time has passed since the previous NVD request.

    The minimum spacing is ``RATE_LIMIT_WITH_KEY`` when ``NVD_API_KEY`` is set
    and ``RATE_LIMIT_WITHOUT_KEY`` otherwise. The time of this request is then
    recorded in the module-level ``_last_request_time``.
    """
    global _last_request_time
    interval = RATE_LIMIT_WITH_KEY if os.getenv("NVD_API_KEY", "") else RATE_LIMIT_WITHOUT_KEY
    # time.monotonic() cannot jump backwards (NTP/manual clock changes). With
    # time.time() a backwards jump made `elapsed` negative and the computed
    # sleep arbitrarily long. A value <= 0 means "no request made yet".
    if _last_request_time > 0:
        elapsed = time.monotonic() - _last_request_time
        if elapsed < interval:
            # Clamp so an inconsistent timestamp can never cause > interval wait.
            wait = min(interval - elapsed, interval)
            logger.info("[WAIT] NVD rate limit: waiting %.1fs", wait)
            time.sleep(wait)
    _last_request_time = time.monotonic()
def _pick_cvss_data(entries) -> dict:
    """Return the ``cvssData`` of the preferred metric entry, or {}.

    Prefers the entry whose ``type`` is ``"Primary"``; otherwise uses the
    first well-formed entry.
    """
    if not isinstance(entries, list):
        return {}
    usable = [entry for entry in entries if isinstance(entry, dict)]
    if not usable:
        return {}
    chosen = next((entry for entry in usable if entry.get("type") == "Primary"), usable[0])
    data = chosen.get("cvssData")
    return data if isinstance(data, dict) else {}


def _as_score(value) -> float:
    """Convert a raw ``baseScore`` to float; 0.0 for missing/non-numeric values."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _extract_cvss(metrics: dict) -> tuple[float, str]:
    """Extract the CVSS base score and severity from an NVD ``metrics`` object.

    Preference order: v3.1 -> v3.0 -> v4.0 -> v2 -> default ``(0.0, "LOW")``.
    Within one CVSS version, the ``type == "Primary"`` entry wins over
    ``"Secondary"`` (CNA-supplied) entries; otherwise the first entry is used.

    Returns:
        (cvss_score, severity), severity upper-cased, e.g. ``(9.8, "CRITICAL")``.
    """
    if not isinstance(metrics, dict):
        return 0.0, "LOW"

    for key in ("cvssMetricV31", "cvssMetricV30", "cvssMetricV40"):
        data = _pick_cvss_data(metrics.get(key))
        score = _as_score(data.get("baseScore"))
        severity = data.get("baseSeverity") or ""
        # Need both a non-zero score and a severity label to trust this entry.
        if score and isinstance(severity, str) and severity:
            return score, severity.upper()

    # CVSS v2 (fallback): severity is derived from the score.
    score = _as_score(_pick_cvss_data(metrics.get("cvssMetricV2")).get("baseScore"))
    if score:
        return score, _cvss_to_severity(score)
    return 0.0, "LOW"
def _cvss_to_severity(score: float) -> str:
    """Map a CVSS base score to a severity label.

    Used only when the API supplies a score without a severity (the v2 path).
    Bands: >= 9.0 CRITICAL, >= 7.0 HIGH, >= 4.0 MEDIUM, anything else LOW.

    NOTE(review): NVD's own CVSS v2 ratings have no CRITICAL band (7.0-10.0 is
    HIGH); these are v3-style bands - confirm that is intended.
    """
    bands = ((9.0, "CRITICAL"), (7.0, "HIGH"), (4.0, "MEDIUM"))
    for floor, label in bands:
        if score >= floor:
            return label
    return "LOW"
def _iter_vulnerable_cpe_matches(configurations):
    """Yield every ``cpeMatch`` dict flagged ``vulnerable``, skipping malformed entries."""
    if not isinstance(configurations, list):
        return
    for config in configurations:
        nodes = config.get("nodes") if isinstance(config, dict) else None
        if not isinstance(nodes, list):
            continue
        for node in nodes:
            matches = node.get("cpeMatch") if isinstance(node, dict) else None
            if not isinstance(matches, list):
                continue
            for cpe_match in matches:
                if isinstance(cpe_match, dict) and cpe_match.get("vulnerable", False):
                    yield cpe_match


def _extract_affected_versions(configurations: list) -> str:
    """Summarize the vulnerable version ranges found in NVD ``configurations``.

    Each vulnerable ``cpeMatch`` contributes one entry:
      * its explicit bounds; lower and upper bound are joined by a space when
        both exist (npm-style range), e.g. ">= 4.0 < 4.2.1", "<= 2.0", "> 1.1";
      * otherwise the concrete version in the CPE criteria (6th colon-separated
        field), unless it is "*" or "-".

    Duplicate entries are dropped and at most the first three are kept.

    Returns:
        A comma-separated summary, or "" when nothing could be extracted.
    """
    versions: list[str] = []
    for cpe_match in _iter_vulnerable_cpe_matches(configurations):
        bounds = []
        # Keep BOTH ends of a range: dropping the lower bound would mark
        # versions below it as affected (false positives).
        if cpe_match.get("versionStartIncluding"):
            bounds.append(f">= {cpe_match['versionStartIncluding']}")
        elif cpe_match.get("versionStartExcluding"):
            bounds.append(f"> {cpe_match['versionStartExcluding']}")
        if cpe_match.get("versionEndExcluding"):
            bounds.append(f"< {cpe_match['versionEndExcluding']}")
        elif cpe_match.get("versionEndIncluding"):
            bounds.append(f"<= {cpe_match['versionEndIncluding']}")

        if bounds:
            entry = " ".join(bounds)
        else:
            # cpe:2.3:<part>:<vendor>:<product>:<version>:...
            criteria = cpe_match.get("criteria", "")
            parts = criteria.split(":") if isinstance(criteria, str) else []
            entry = parts[5] if len(parts) > 5 and parts[5] not in ("*", "-") else ""

        if entry and entry not in versions:
            versions.append(entry)
    return ", ".join(versions[:3])
def _extract_description(descriptions: list) -> str:
    """Return the CVE description text, preferring English.

    Uses the first entry with ``lang == "en"``; failing that, the first entry
    of any language; and "" when the list is empty.
    """
    english = next((entry for entry in descriptions if entry.get("lang", "") == "en"), None)
    if english is not None:
        return english.get("value", "")
    return descriptions[0].get("value", "") if descriptions else ""
# ======================================================================
# Core query logic
# ======================================================================
def _nvd_get(params: dict, label: str) -> dict | None:
    """GET the NVD CVE API with rate limiting and retries.

    Shared transport for ``_query_nvd_api`` and ``_query_nvd_api_cpe``. It makes
    up to ``MAX_RETRIES`` attempts in total:

    * 200                   -> decoded JSON object (must be a dict)
    * 403 / 429             -> treated as rate limiting: back off, then retry
    * 5xx                   -> short back-off, then retry
    * timeout / conn. error -> retry (``_rate_limit`` still spaces the calls)
    * anything else         -> give up

    Back-off sleeps are skipped after the final attempt, because no retry follows.

    Args:
        params: Query parameters (search criterion plus ``resultsPerPage``).
        label: Description of the query used in log lines.

    Returns:
        The decoded JSON response, or None on any failure.
    """
    api_key = os.getenv("NVD_API_KEY", "")
    headers = {"apiKey": api_key} if api_key else {}
    for attempt in range(1, MAX_RETRIES + 1):
        retry_left = attempt < MAX_RETRIES
        _rate_limit()
        logger.info("[QUERY] NVD %s (attempt %d)", label, attempt)
        try:
            response = requests.get(NVD_API_BASE, params=params, headers=headers, timeout=REQUEST_TIMEOUT)
        except requests.exceptions.Timeout:
            logger.warning("[WARN] NVD API timeout (%ds)", REQUEST_TIMEOUT)
            continue
        except requests.exceptions.ConnectionError:
            logger.warning("[WARN] NVD API connection failed (network issue)")
            continue
        except requests.exceptions.RequestException as e:
            logger.warning("[WARN] NVD API request error: %s", e)
            return None

        status = response.status_code
        if status == 200:
            try:
                payload = response.json()
            except ValueError as e:
                # Older requests versions raise a plain ValueError here; newer
                # ones raise requests.exceptions.JSONDecodeError, also a ValueError.
                logger.warning("[WARN] NVD API returned invalid JSON: %s", e)
                return None
            if not isinstance(payload, dict):
                logger.warning("[WARN] NVD API returned unexpected JSON type: %s", type(payload).__name__)
                return None
            return payload
        if status in (403, 429):
            logger.warning("[WARN] NVD API %d (rate limited)%s", status, ", retrying..." if retry_left else "")
            if retry_left:
                time.sleep(RATE_LIMIT_WITHOUT_KEY * 2)
            continue
        if status >= 500:
            logger.warning("[WARN] NVD API %d (server error)", status)
            if retry_left:
                time.sleep(2)
            continue
        logger.warning("[WARN] NVD API returned %d: %s", status, response.text[:200])
        return None
    return None


def _query_nvd_api(keyword: str) -> dict | None:
    """Query NVD with a full-text ``keywordSearch``.

    Returns:
        The raw NVD response dict, or None on failure.
    """
    params = {
        "keywordSearch": keyword,
        "resultsPerPage": RESULTS_PER_PAGE,
    }
    return _nvd_get(params, f"keywordSearch: {keyword}")


def _query_nvd_api_cpe(cpe_name: str) -> dict | None:
    """Query NVD by ``cpeName`` for an exact applicability match.

    More precise than ``keywordSearch``. It returns only CVEs whose CPE match
    criteria cover ``cpe_name``, so generic code words ("eval", "html", ...)
    cannot pollute the results.

    NOTE(review): the NVD API docs describe ``cpeName`` as matching a specific
    CPE name and offer ``virtualMatchString`` for broader match strings. Confirm
    that the wildcard-version entries in PACKAGE_CPE_MAP are accepted here; a
    rejected request returns None and ``_search_nvd_impl`` then falls back to
    keyword search.

    Returns:
        The raw NVD response dict, or None on failure.
    """
    params = {
        "cpeName": cpe_name,
        "resultsPerPage": RESULTS_PER_PAGE,
    }
    return _nvd_get(params, f"cpeName: {cpe_name}")
def _extract_cpe_vendors(configurations: list) -> list[str]:
    """Collect the distinct ``vendor:product`` pairs of vulnerable CPEs.

    The Analyst uses these for CPE relevance filtering. Pairs keep their order
    of first appearance and at most 10 are returned, e.g.
    ``["nodejs:node.js", "expressjs:express"]``. Malformed configuration
    entries are skipped instead of aborting the whole extraction.
    """
    vendors: list[str] = []
    if not isinstance(configurations, list):
        return vendors
    for config in configurations:
        nodes = config.get("nodes") if isinstance(config, dict) else None
        if not isinstance(nodes, list):
            continue
        for node in nodes:
            matches = node.get("cpeMatch") if isinstance(node, dict) else None
            if not isinstance(matches, list):
                continue
            for cpe_match in matches:
                if not isinstance(cpe_match, dict) or not cpe_match.get("vulnerable", False):
                    continue
                criteria = cpe_match.get("criteria", "")
                if not isinstance(criteria, str):
                    continue
                # cpe:2.3:<part>:<vendor>:<product>:<version>:...
                parts = criteria.split(":")
                if len(parts) < 5:
                    continue
                vendor_product = f"{parts[3]}:{parts[4]}"
                if vendor_product not in vendors:
                    vendors.append(vendor_product)
                    if len(vendors) == 10:
                        return vendors  # cap reached; the result cannot change
    return vendors
def _parse_nvd_response(raw: dict, package_name: str) -> dict:
    """Convert a raw NVD API response into this tool's output format.

    v3.8: each vulnerability also carries ``cpe_vendors`` so the Analyst can
    verify relevance. Malformed items and entries whose id does not start with
    ``CVE-`` are skipped. Vulnerabilities are sorted by CVSS score, highest first.

    Returns:
        ``{"package", "source": "NVD", "count", "vulnerabilities": [...]}``
    """
    vulnerabilities = []
    raw_vulns = raw.get("vulnerabilities") if isinstance(raw, dict) else None
    for item in raw_vulns if isinstance(raw_vulns, list) else []:
        cve = item.get("cve") if isinstance(item, dict) else None
        if not isinstance(cve, dict):
            continue
        cve_id = cve.get("id")
        # Skip non-standard identifiers (anything not of the form "CVE-...").
        if not isinstance(cve_id, str) or not cve_id.startswith("CVE-"):
            continue

        # `or` guards against keys that are present but null in the payload.
        description = _extract_description(cve.get("descriptions") or [])
        cvss_score, severity = _extract_cvss(cve.get("metrics") or {})
        configurations = cve.get("configurations") or []

        vulnerabilities.append({
            "cve_id": cve_id,
            "cvss_score": cvss_score,
            "severity": severity,
            "description": (description or "")[:500],
            "published": cve.get("published", ""),
            "affected_versions": _extract_affected_versions(configurations),
            "cpe_vendors": _extract_cpe_vendors(configurations),  # v3.8: for CPE relevance filtering
        })

    # Most dangerous first.
    vulnerabilities.sort(key=lambda v: v["cvss_score"], reverse=True)
    return {
        "package": package_name,
        "source": "NVD",
        "count": len(vulnerabilities),
        "vulnerabilities": vulnerabilities,
    }
# CPE lookup table (package name -> NVD CPE match string for vendor:product).
# Packages not listed here fall back to keywordSearch.
# NOTE(review): NVD has re-assigned some vendors over time (newer MySQL CVEs
# appear under oracle:mysql, newer nginx under f5:nginx, newer Spring under
# vmware:spring_framework). A single CPE per package may miss those; confirm
# against the NVD CPE dictionary.
PACKAGE_CPE_MAP: dict[str, str] = {
    # Node.js ecosystem
    "express": "cpe:2.3:a:expressjs:express:*:*:*:*:*:*:*:*",
    "node": "cpe:2.3:a:nodejs:node.js:*:*:*:*:*:*:*:*",
    "nodejs": "cpe:2.3:a:nodejs:node.js:*:*:*:*:*:*:*:*",
    "lodash": "cpe:2.3:a:lodash:lodash:*:*:*:*:*:*:*:*",
    "axios": "cpe:2.3:a:axios:axios:*:*:*:*:*:node.js:*:*",
    "webpack": "cpe:2.3:a:webpack:webpack:*:*:*:*:*:node.js:*:*",
    "moment": "cpe:2.3:a:momentjs:moment.js:*:*:*:*:*:node.js:*:*",
    "next": "cpe:2.3:a:vercel:next.js:*:*:*:*:*:node.js:*:*",
    "nextjs": "cpe:2.3:a:vercel:next.js:*:*:*:*:*:node.js:*:*",
    "react": "cpe:2.3:a:facebook:react:*:*:*:*:*:node.js:*:*",
    "vue": "cpe:2.3:a:vuejs:vue.js:*:*:*:*:*:node.js:*:*",
    # NOTE(review): angular.js is AngularJS 1.x; modern Angular (2+) is listed
    # under a different product name - confirm which one is wanted.
    "angular": "cpe:2.3:a:google:angular.js:*:*:*:*:*:node.js:*:*",
    # Python ecosystem
    "django": "cpe:2.3:a:djangoproject:django:*:*:*:*:*:*:*:*",
    "flask": "cpe:2.3:a:palletsprojects:flask:*:*:*:*:*:*:*:*",
    "requests": "cpe:2.3:a:python-requests:requests:*:*:*:*:*:*:*:*",
    "pillow": "cpe:2.3:a:python:pillow:*:*:*:*:*:*:*:*",
    "pyyaml": "cpe:2.3:a:pyyaml:pyyaml:*:*:*:*:*:*:*:*",
    "cryptography": "cpe:2.3:a:cryptography.io:cryptography:*:*:*:*:*:python:*:*",
    "jinja2": "cpe:2.3:a:palletsprojects:jinja:*:*:*:*:*:python:*:*",
    "werkzeug": "cpe:2.3:a:palletsprojects:werkzeug:*:*:*:*:*:python:*:*",
    "sqlalchemy": "cpe:2.3:a:sqlalchemy:sqlalchemy:*:*:*:*:*:*:*:*",
    # Java ecosystem
    "log4j": "cpe:2.3:a:apache:log4j:*:*:*:*:*:*:*:*",
    "spring": "cpe:2.3:a:pivotal_software:spring_framework:*:*:*:*:*:*:*:*",
    "struts": "cpe:2.3:a:apache:struts:*:*:*:*:*:*:*:*",
    # Go ecosystem
    "go": "cpe:2.3:a:golang:go:*:*:*:*:*:*:*:*",
    # Databases and other infrastructure
    "redis": "cpe:2.3:a:redis:redis:*:*:*:*:*:*:*:*",
    "postgresql": "cpe:2.3:a:postgresql:postgresql:*:*:*:*:*:*:*:*",
    "postgres": "cpe:2.3:a:postgresql:postgresql:*:*:*:*:*:*:*:*",
    "mysql": "cpe:2.3:a:mysql:mysql:*:*:*:*:*:*:*:*",
    "mongodb": "cpe:2.3:a:mongodb:mongodb:*:*:*:*:*:*:*:*",
    "nginx": "cpe:2.3:a:nginx:nginx:*:*:*:*:*:*:*:*",
    "openssl": "cpe:2.3:a:openssl:openssl:*:*:*:*:*:*:*:*",
}
def _nvd_json(payload: dict) -> str:
    """Serialize a tool payload the same way on every return path of search_nvd."""
    return json.dumps(payload, ensure_ascii=False, indent=2)


def _nvd_fresh_result(cache_key: str, result: dict) -> str:
    """Cache a non-empty live NVD result under ``cache_key`` and serialize it."""
    _write_cache(cache_key, result)
    # _write_cache may stamp its internal "_cached_at" field into the dict it
    # receives; strip it so fresh results have the same shape as cache hits.
    result.pop("_cached_at", None)
    result["fallback_used"] = False
    return _nvd_json(result)


def _search_nvd_impl(package_name: str) -> str:
    """Core implementation of ``search_nvd`` (v3.8).

    Search strategy, in priority order:
      1. Fresh offline cache hit -> returned directly (cache-first).
      2. CPE exact search, when ``PACKAGE_CPE_MAP`` knows the primary name
         -> only CVEs that really affect the package.
      3. Keyword full-text search per candidate name (fallback when CPE misses).
      4. If NVD cannot be reached for a candidate, a stale cache entry for it.
      5. Empty result. Its ``error`` distinguishes "NVD answered with no
         matches" from "NVD unreachable and nothing cached", so an outage is
         never reported as a clean bill of health.
    Never raises: unexpected errors are logged and returned as an error payload.

    Returns:
        A JSON string with ``package``, ``source``, ``count``,
        ``vulnerabilities``, ``fallback_used`` and, depending on the path,
        ``search_mode``, ``error`` and ``cache_stale``.
    """
    try:
        if not isinstance(package_name, str) or not package_name.strip():
            # Nothing meaningful to send to NVD.
            logger.warning("[WARN] NVD search called without a package name: %r", package_name)
            return _nvd_json({
                "package": package_name,
                "source": "NVD",
                "count": 0,
                "vulnerabilities": [],
                "search_mode": "none",
                "error": "No package name given; nothing to search",
                "fallback_used": False,
            })

        candidates = _normalize_package_name(package_name)
        logger.info("[QUERY] NVD package: %s -> candidates: %s", package_name, candidates)

        # -- 1. Cache-first -------------------------------------------------
        for keyword in candidates:
            cached = _read_cache(keyword)
            if cached:
                cached.pop("_cached_at", None)
                cached["fallback_used"] = False
                logger.info("[OK] NVD cache hit: %s -> %d CVEs",
                            keyword, len(cached.get("vulnerabilities", [])))
                return _nvd_json(cached)

        # Becomes True once NVD answers any query, so that "no matches" can be
        # told apart from "NVD unreachable" at the end.
        api_reached = False

        # -- 2. CPE exact search (keeps generic code words out of results) ---
        primary = candidates[0]
        cpe_name = PACKAGE_CPE_MAP.get(primary)
        if cpe_name:
            raw = _query_nvd_api_cpe(cpe_name)
            if raw is not None:
                api_reached = True
                result = _parse_nvd_response(raw, package_name)
                result["search_mode"] = "cpe"
                if result["count"] > 0:
                    logger.info("[OK] NVD CPE query: %s -> %d CVEs", package_name, result["count"])
                    return _nvd_fresh_result(primary, result)
                logger.info("[INFO] NVD CPE no results for: %s", primary)

        # -- 3. Keyword search (fallback; package names only, not code words)
        for keyword in candidates:
            raw = _query_nvd_api(keyword)
            if raw is not None:
                api_reached = True
                result = _parse_nvd_response(raw, package_name)
                result["search_mode"] = "keyword"
                if result["count"] > 0:
                    logger.info("[OK] NVD keyword query: %s -> %d CVEs", package_name, result["count"])
                    return _nvd_fresh_result(keyword, result)
                logger.info("[INFO] NVD keyword no results for: %s, trying next alias", keyword)
                continue
            # -- 4. NVD unreachable for this name: degrade to stale cache ----
            cached = _read_cache(keyword, allow_stale=True)
            if cached:
                cached.pop("_cached_at", None)
                cached["fallback_used"] = True
                cached["cache_stale"] = True
                cached["error"] = f"NVD API unavailable, using cached data for '{keyword}'"
                return _nvd_json(cached)

        # -- 5. Nothing found -----------------------------------------------
        if api_reached:
            error = f"No vulnerabilities found for '{package_name}' (tried: {candidates})"
            logger.info("[INFO] NVD no results for: %s", package_name)
        else:
            # Never report an outage as "no vulnerabilities": the package's
            # status is unknown, not clean.
            error = (f"NVD API unavailable and no cached data for '{package_name}' "
                     f"(tried: {candidates}); vulnerability status unknown")
            logger.warning("[WARN] NVD unavailable and no cache for: %s", package_name)
        return _nvd_json({
            "package": package_name,
            "source": "NVD",
            "count": 0,
            "vulnerabilities": [],
            "search_mode": "none",
            "error": error,
            "fallback_used": False,
        })
    except Exception as e:
        # Top-level boundary of an agent tool: log and return, never crash.
        logger.error("[FAIL] NVD Tool unexpected error: %s", e, exc_info=True)
        return _nvd_json({
            "package": package_name, "source": "NVD", "count": 0,
            "vulnerabilities": [], "error": f"Unexpected error: {str(e)}",
            "fallback_used": False,
        })
# ======================================================================
# CrewAI @tool wrapper (called by the agent)
# ======================================================================
def _create_tool():
    """Build the CrewAI ``search_nvd`` tool.

    ``crewai`` is imported only inside this function, i.e. only when the tool
    is first requested, so importing this module does not require crewai.
    """
    from crewai.tools import tool

    # NOTE(review): the @tool decorator is expected to use the docstring below
    # as the LLM-facing tool description (crewai convention - confirm), so it
    # must stay readable. It was mis-encoded text; restored to the intended
    # Traditional Chinese wording.
    @tool("search_nvd")
    def search_nvd(package_name: str) -> str:
        """查詢 NVD (National Vulnerability Database) 中指定套件的已知漏洞。
        輸入套件名稱（如 django、redis、postgresql），回傳該套件的 CVE 漏洞清單（JSON 格式）。
        包含 CVE 編號、CVSS 分數、嚴重度、描述、受影響版本等資訊。
        若 API 不可用會自動使用離線快取。"""
        return _search_nvd_impl(package_name)

    return search_nvd
# -- Lazy-loading mechanism (same pattern as memory_tool.py) ------------
class _LazyToolLoader:
    """Holds the CrewAI ``search_nvd`` tool and builds it only on first access."""

    def __init__(self):
        self._tool = None  # created on demand by _load()

    def _load(self):
        """Create the tool via ``_create_tool()`` unless it already exists."""
        if self._tool is not None:
            return
        self._tool = _create_tool()

    @property
    def search_nvd(self):
        """The lazily constructed ``search_nvd`` tool."""
        self._load()
        return self._tool


# Single shared loader, consulted by the module-level __getattr__ hook.
_loader = _LazyToolLoader()
def __getattr__(name):
    """Module-level attribute hook (PEP 562).

    Resolves ``search_nvd`` lazily through ``_loader`` so that
    ``from tools.nvd_tool import search_nvd`` works without importing crewai
    at module import time. Any other unknown attribute raises AttributeError.
    """
    if name != "search_nvd":
        raise AttributeError(f"module 'tools.nvd_tool' has no attribute {name!r}")
    return _loader.search_nvd