Threat_Hunter / tools /memory_tool.py
EricChen2005's picture
Deploy ThreatHunter - AMD MI300X + Qwen2.5-32B
c8d30bc
"""
ThreatHunter 雙層記憶學習 Tool
==============================
Layer 1: JSON 持久化(穩定保底,Day 1 起可用)
Layer 2: LlamaIndex RAG(語義搜尋,ENABLE_MEMORY_RAG=true 啟用)
遵循文件:
- FINAL_PLAN.md §支柱 3(Feedback Loops:雙層記憶學習系統)
- leader_plan.md §記憶學習系統
"""
import json
import logging
from pathlib import Path
from datetime import datetime, timezone
from typing import Any
from config import MEMORY_DIR, ENABLE_MEMORY_RAG, SIMILARITY_THRESHOLD
from crewai.tools import tool
logger = logging.getLogger("threathunter.memory")
# Sandbox Layer 3: Memory cache sanitization
try:
from sandbox.memory_sanitizer import sanitize_memory_write as _sanitize_write
_MEM_SANITIZER_OK = True
except ImportError:
def _sanitize_write(data, agent_name=''): # type: ignore[misc]
return True, data, 'ok'
_MEM_SANITIZER_OK = False
# ── 常數 ─────────────────────────────────────────────────────
VALID_AGENT_NAMES = {"scout", "analyst", "advisor", "critic", "orchestrator"}
# ── Layer 1: JSON 持久化工具函式 ─────────────────────────────
def _get_memory_path(agent_name: str) -> Path:
"""取得指定 Agent 的記憶 JSON 路徑"""
return MEMORY_DIR / f"{agent_name}_memory.json"
def _load_json(path: Path) -> dict:
"""安全載入 JSON,檔案不存在或損壞時回傳空 dict"""
if not path.exists():
return {}
try:
content = path.read_text(encoding="utf-8")
if not content.strip():
return {}
return json.loads(content)
except (json.JSONDecodeError, OSError) as e:
logger.warning("[WARN] Memory file read failed %s: %s, returning empty", path, e)
return {}
def _save_json(path: Path, data: dict) -> None:
"""安全寫入 JSON(先寫臨時檔再 rename,防止寫入中斷導致損壞)"""
path.parent.mkdir(parents=True, exist_ok=True)
temp_path = path.with_suffix(".tmp")
try:
temp_path.write_text(
json.dumps(data, ensure_ascii=False, indent=2),
encoding="utf-8",
)
temp_path.replace(path)
except OSError as e:
logger.error("[FAIL] Memory file write failed %s: %s", path, e)
if temp_path.exists():
temp_path.unlink()
raise
# ── Layer 2: LlamaIndex RAG(條件性啟用)─────────────────────
_rag_index = None
_rag_query_engine = None
def _init_rag() -> None:
"""延遲初始化 LlamaIndex RAG(只在第一次呼叫時執行)"""
global _rag_index, _rag_query_engine
if not ENABLE_MEMORY_RAG:
return
if _rag_index is not None:
return
try:
from llama_index.core import VectorStoreIndex, StorageContext, Settings
from llama_index.core import load_index_from_storage
# ── 設定 Free Local Embedding(不需要 OpenAI API Key)──
# 使用 HuggingFace BAAI/bge-small-en-v1.5:輕量、快速、免費
try:
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
Settings.embed_model = HuggingFaceEmbedding(
model_name="BAAI/bge-small-en-v1.5"
)
logger.info("[OK] Embedding: HuggingFace BAAI/bge-small-en-v1.5 (local free)")
except ImportError:
logger.warning("[WARN] HuggingFace embedding not installed, trying OpenAI embedding")
# 停用 LLM(RAG 記憶層只需要 embedding,不需要 LLM 生成)
try:
from llama_index.core.llms import MockLLM
Settings.llm = MockLLM()
except Exception:
Settings.llm = None # type: ignore
vector_store_path = MEMORY_DIR / "vector_store"
if (vector_store_path / "docstore.json").exists():
storage_context = StorageContext.from_defaults(
persist_dir=str(vector_store_path)
)
_rag_index = load_index_from_storage(storage_context)
logger.info("[OK] LlamaIndex vector index loaded")
_rag_query_engine = _rag_index.as_query_engine(similarity_top_k=3)
else:
_rag_index = VectorStoreIndex([])
logger.info("[OK] LlamaIndex vector index created (empty)")
# 先設 query_engine,再回填(_backfill 需要 _rag_index 已就緒)
_rag_query_engine = _rag_index.as_query_engine(similarity_top_k=3)
_backfill_from_json_history()
# 回填後重建 query_engine(索引已有資料)
_rag_query_engine = _rag_index.as_query_engine(similarity_top_k=3)
except ImportError:
logger.warning("[WARN] LlamaIndex not installed, RAG disabled")
except Exception as e:
logger.warning("[WARN] LlamaIndex init failed: %s", e)
def _backfill_from_json_history() -> None:
"""
首次啟用 RAG 時,將所有已有的 *_memory.json 批次回填進 LlamaIndex。
只在 vector_store 新建(非載入)時執行一次,之後透過 persist 保存。
設計原則(Harness Engineering — Feedback Loops 支柱):
- 確保歷史掃描記錄不因 RAG 冷啟動而遺失語義感知
- 失敗不阻塞(Graceful Degradation)
"""
if _rag_index is None:
return
total_inserted = 0
for agent_name in VALID_AGENT_NAMES:
json_path = _get_memory_path(agent_name)
if not json_path.exists():
continue
data = _load_json(json_path)
if not data:
continue
# 回填 latest(最新掃描)
_rag_insert(agent_name, data)
total_inserted += 1
# 回填 history[] 陣列中的每一筆歷史掃描
for hist_scan in data.get("history", []):
if isinstance(hist_scan, dict) and hist_scan:
_rag_insert(agent_name, hist_scan)
total_inserted += 1
if total_inserted > 0:
logger.info("[OK] RAG history backfill done: %d scan records vectorized", total_inserted)
else:
logger.info("[INFO] RAG backfill: no historical JSON memory to backfill (first scan)")
def _extract_package_names(tech_stack: str) -> set[str]:
"""
從技術棧字串中提取套件名稱(小寫、去版本號)。
範例:
'Django 4.2, Redis 7.0' -> {'django', 'redis'}
'Spring Boot 3.1 和 Node.js 18' -> {'spring', 'boot', 'node.js'}
"""
if not tech_stack:
return set()
names = set()
for part in tech_stack.replace(",", " ").split():
clean = part.strip().lower()
# 跳過版本號(純數字或 x.y.z 格式)
if clean and not clean.replace(".", "").replace("-", "").isdigit():
names.add(clean)
return names
def _rag_insert(agent_name: str, data: dict) -> None:
"""將資料插入 LlamaIndex 向量索引(雙寫的 Layer 2),含 tech_stack 元資料"""
if not ENABLE_MEMORY_RAG or _rag_index is None:
return
try:
from llama_index.core import Document
# 提取 tech_stack:可能在 data 的不同欄位中
tech_stack = (
data.get("tech_stack", "")
or data.get("tech_stack_input", "")
or ""
)
# 如果 tech_stack 是 list,轉成字串
if isinstance(tech_stack, list):
tech_stack = ", ".join(str(t) for t in tech_stack)
doc = Document(
text=json.dumps(data, ensure_ascii=False),
metadata={
"agent": agent_name,
"timestamp": data.get("timestamp", ""),
"scan_id": data.get("scan_id", ""),
"tech_stack": str(tech_stack),
},
)
_rag_index.insert(doc)
vector_store_path = MEMORY_DIR / "vector_store"
_rag_index.storage_context.persist(persist_dir=str(vector_store_path))
logger.info("[OK] RAG index updated: %s (tech_stack=%s)", agent_name, tech_stack[:50])
except Exception as e:
logger.warning("[WARN] RAG write failed (JSON layer unaffected): %s", e)
def _rag_search(query: str, tech_stack: str | None = None) -> str:
"""
語義搜尋(帶安全閥 + 技術棧相關性過濾)。
Args:
query: 搜尋查詢
tech_stack: 當前掃描的技術棧(用於過濾不相關歷史)
"""
if not ENABLE_MEMORY_RAG:
return "RAG disabled (ENABLE_MEMORY_RAG=false)"
_init_rag()
if _rag_index is None:
return "RAG index unavailable"
try:
doc_count = len(_rag_index.docstore.docs) if hasattr(_rag_index, "docstore") else 0
if doc_count == 0:
return "No history available (vector index empty)"
response = _rag_query_engine.query(query)
# 相關性門檻過濾
if hasattr(response, "source_nodes") and response.source_nodes:
scores = [n.score for n in response.source_nodes if n.score is not None]
max_score = max(scores) if scores else 0
if max_score < SIMILARITY_THRESHOLD:
return (
f"No relevant history found"
f" (max_similarity {max_score:.2f} < threshold {SIMILARITY_THRESHOLD})"
)
# 技術棧相關性過濾:只保留與當前掃描套件有交集的歷史
if tech_stack:
current_packages = _extract_package_names(tech_stack)
if current_packages:
filtered_nodes = []
for node in response.source_nodes:
node_tech = node.metadata.get("tech_stack", "")
if not node_tech:
# 無 tech_stack 元資料的舊記錄,保守保留
filtered_nodes.append(node)
continue
node_packages = _extract_package_names(node_tech)
# 有套件名稱交集才保留
if current_packages & node_packages:
filtered_nodes.append(node)
else:
logger.info(
"[FILTER] Excluded history: %s (no overlap with %s)",
node_tech[:50], tech_stack[:50],
)
if not filtered_nodes:
return (
f"No relevant history for current tech stack"
f" (filtered {len(response.source_nodes)} results, 0 matched)"
)
response.source_nodes = filtered_nodes
return str(response)
except Exception as e:
logger.warning("[WARN] RAG search failed: %s", e)
return f"RAG search failed: {e}"
# ── CrewAI Tool 定義 ─────────────────────────────────────────
@tool("read_memory")
def read_memory(agent_name: str) -> str:
"""
讀取指定 Agent 的歷史記憶(JSON Layer 1:穩定保底)。
0 份歷史回傳空 JSON,Agent 可據此判斷是否為第一次掃描。
Args:
agent_name: Agent 名稱(scout / analyst / advisor / critic / orchestrator)
Returns:
JSON 字串格式的歷史記憶
"""
agent_name = agent_name.strip().lower()
if agent_name not in VALID_AGENT_NAMES:
logger.warning("[WARN] Invalid agent name: %s", agent_name)
return json.dumps({}, ensure_ascii=False)
data = _load_json(_get_memory_path(agent_name))
if not data:
logger.info("[INFO] %s has no history (first scan)", agent_name)
else:
logger.info("[OK] %s memory loaded (scan_id: %s)", agent_name, data.get('scan_id', 'N/A'))
return json.dumps(data, ensure_ascii=False, indent=2)
@tool("write_memory")
def write_memory(agent_name: str, data: str) -> str:
"""
寫入 Agent 記憶(雙寫:JSON + LlamaIndex)。自動添加 timestamp。
Args:
agent_name: Agent 名稱(scout / analyst / advisor / critic / orchestrator)
data: JSON 字串格式的記憶資料
Returns:
寫入結果訊息
"""
agent_name = agent_name.strip().lower()
if agent_name not in VALID_AGENT_NAMES:
return f"[FAIL] Invalid agent name: {agent_name} (allowed: {VALID_AGENT_NAMES})"
try:
memory_data = json.loads(data) if isinstance(data, str) else data
except json.JSONDecodeError as e:
return f"[FAIL] JSON format error: {e}"
# Sandbox Layer 3: poison filter before write
is_safe, clean_data, reason = _sanitize_write(memory_data, agent_name)
if not is_safe:
logger.warning('[MEMORY][SANDBOX] Write BLOCKED: %s', reason)
return '[BLOCKED] Memory write rejected by Sandbox: ' + reason
memory_data = clean_data
memory_data["timestamp"] = datetime.now(timezone.utc).isoformat()
# Layer 1: JSON — 累積 history[] 陣列
try:
existing = _load_json(_get_memory_path(agent_name))
history = existing.get("history", [])
# 若已有舊的 latest,推入 history(最多保留 50 筆,防止無限增長)
if existing and "scan_id" in existing:
old_entry = {k: v for k, v in existing.items() if k != "history"}
history.append(old_entry)
if len(history) > 50:
history = history[-50:] # 保留最新 50 筆
memory_data["history"] = history
_save_json(_get_memory_path(agent_name), memory_data)
logger.info("[OK] %s memory saved to JSON (Layer 1 | history=%d records)", agent_name, len(history))
except Exception as e:
return f"[FAIL] JSON write failed: {e}"
# Layer 2: LlamaIndex(雙寫)
_rag_insert(agent_name, memory_data)
return f"[OK] {agent_name} memory saved (timestamp: {memory_data['timestamp']})"
@tool("history_search")
def history_search(query: str, tech_stack: str = "") -> str:
"""
語義搜尋歷史安全報告(帶技術棧過濾)。
帶安全閥:索引為空 / 分數太低 / 技術棧不匹配 / RAG 未啟用 → 回傳提示。
Args:
query: 搜尋查詢(例如:"Django SSRF 歷史")
tech_stack: 當前掃描的技術棧(例如:"Django 4.2, Redis 7.0"),
用於過濾不相關的歷史記錄
Returns:
搜尋結果或安全提示
"""
return _rag_search(query, tech_stack if tech_stack else None)