File size: 14,788 Bytes
c8d30bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
"""
ThreatHunter 雙層記憶學習 Tool
==============================

Layer 1: JSON 持久化(穩定保底,Day 1 起可用)
Layer 2: LlamaIndex RAG(語義搜尋,ENABLE_MEMORY_RAG=true 啟用)

遵循文件:
  - FINAL_PLAN.md §支柱 3(Feedback Loops:雙層記憶學習系統)
  - leader_plan.md §記憶學習系統
"""

import contextlib
import json
import logging
import os
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from config import MEMORY_DIR, ENABLE_MEMORY_RAG, SIMILARITY_THRESHOLD
from crewai.tools import tool

logger = logging.getLogger("threathunter.memory")

# Sandbox Layer 3: sanitize payloads before they reach the memory cache.
# If the sandbox package is not importable, install a pass-through fallback
# and record (via _MEM_SANITIZER_OK) that the real filter is not active.
try:
    from sandbox.memory_sanitizer import sanitize_memory_write as _sanitize_write
except ImportError:
    _MEM_SANITIZER_OK = False

    def _sanitize_write(data, agent_name=''):  # type: ignore[misc]
        """Pass-through fallback: report every payload as safe and return it unchanged."""
        return True, data, 'ok'
else:
    _MEM_SANITIZER_OK = True

# ── Constants ────────────────────────────────────────────────
# Agent names the memory tools accept; any other name is rejected.
VALID_AGENT_NAMES = {"orchestrator", "critic", "advisor", "analyst", "scout"}


# ── Layer 1: JSON persistence helpers ────────────────────────
def _get_memory_path(agent_name: str) -> Path:
    """Build the JSON memory-file path for *agent_name*.

    The file sits directly under ``MEMORY_DIR`` and is named
    ``<agent_name>_memory.json``. No validation happens here; the name is
    used exactly as given.
    """
    file_name = "{}_memory.json".format(agent_name)
    return MEMORY_DIR / file_name


def _load_json(path: Path) -> dict:
    """Load a memory JSON file, returning ``{}`` when it is missing, empty or unusable.

    "Unusable" covers unreadable files (``OSError``), bytes that are not valid
    UTF-8, malformed JSON, and JSON whose top level is not an object. Callers
    treat the result as a dict, so anything else is discarded with a warning
    instead of crashing them later.

    Args:
        path: Path of the JSON file to read.

    Returns:
        The decoded JSON object, or an empty dict.
    """
    if not path.exists():
        return {}
    try:
        content = path.read_text(encoding="utf-8")
        if not content.strip():
            return {}
        data = json.loads(content)
    except (ValueError, OSError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError
        # (a corrupted file that is not valid UTF-8).
        logger.warning("[WARN] Memory file read failed %s: %s, returning empty", path, e)
        return {}
    if not isinstance(data, dict):
        logger.warning("[WARN] Memory file %s does not contain a JSON object, returning empty", path)
        return {}
    return data


def _save_json(path: Path, data: dict) -> None:
    """Atomically write *data* as pretty-printed UTF-8 JSON to *path*.

    The payload is written to a sibling ``.tmp`` file, flushed and fsync'ed,
    and then moved over *path* with ``Path.replace``. An interrupted write
    therefore never leaves a truncated memory file behind. Parent directories
    are created as needed.

    Raises:
        OSError: the file could not be written or moved into place. The
            temporary file is removed first, best-effort.
        TypeError / ValueError: *data* is not JSON-serialisable. This is
            raised before anything is written.
    """
    # Serialise first, so a non-serialisable payload fails before any file is touched.
    payload = json.dumps(data, ensure_ascii=False, indent=2).encode("utf-8")

    path.parent.mkdir(parents=True, exist_ok=True)
    temp_path = path.with_suffix(".tmp")
    try:
        with open(temp_path, "wb") as fh:
            fh.write(payload)
            fh.flush()
            # Push the bytes to disk before the rename publishes them. On some
            # filesystems a crash could otherwise leave the renamed file empty.
            os.fsync(fh.fileno())
        temp_path.replace(path)
    except OSError as e:
        logger.error("[FAIL] Memory file write failed %s: %s", path, e)
        # Best-effort cleanup; a failure here must not mask the original error.
        with contextlib.suppress(OSError):
            temp_path.unlink(missing_ok=True)
        raise


# ── Layer 2: LlamaIndex RAG (enabled conditionally) ──────────
# Module-level RAG state. Both names start as None and are only
# assigned inside _init_rag().
_rag_query_engine = None
_rag_index = None


def _init_rag() -> None:
    """Lazily initialise the LlamaIndex RAG layer (Layer 2).

    Returns immediately when ``ENABLE_MEMORY_RAG`` is false or ``_rag_index``
    is already set. Otherwise it:

    - configures a local HuggingFace embedding model, when that package is
      installed;
    - configures a mock LLM, since retrieval needs embeddings only;
    - loads the persisted index from ``MEMORY_DIR / "vector_store"`` if a
      ``docstore.json`` exists there, or else creates an empty index and
      backfills it from the per-agent JSON memory files.

    It assigns the module globals ``_rag_index`` and ``_rag_query_engine``.
    Import errors and any other exceptions are logged as warnings and
    swallowed.

    NOTE(review): if an exception is raised after ``_rag_index`` has been
    assigned, ``_rag_index`` stays set, so later calls return early and never
    retry. If the failure happened inside ``as_query_engine``,
    ``_rag_query_engine`` may still be ``None``.
    """
    global _rag_index, _rag_query_engine

    if not ENABLE_MEMORY_RAG:
        return
    if _rag_index is not None:
        return

    try:
        from llama_index.core import VectorStoreIndex, StorageContext, Settings
        from llama_index.core import load_index_from_storage

        # ── Configure a free local embedding model (no OpenAI API key needed) ──
        # Uses HuggingFace BAAI/bge-small-en-v1.5: lightweight, fast, free.
        try:
            from llama_index.embeddings.huggingface import HuggingFaceEmbedding
            Settings.embed_model = HuggingFaceEmbedding(
                model_name="BAAI/bge-small-en-v1.5"
            )
            logger.info("[OK] Embedding: HuggingFace BAAI/bge-small-en-v1.5 (local free)")
        except ImportError:
            # NOTE(review): nothing else is configured here, so LlamaIndex keeps
            # its default embedding model (presumably OpenAI, as the log says) — confirm.
            logger.warning("[WARN] HuggingFace embedding not installed, trying OpenAI embedding")

        # Disable the LLM: the RAG memory layer only needs embeddings, not LLM generation.
        try:
            from llama_index.core.llms import MockLLM
            Settings.llm = MockLLM()
        except Exception:
            # Fallback when MockLLM cannot be imported or constructed.
            Settings.llm = None  # type: ignore

        vector_store_path = MEMORY_DIR / "vector_store"

        # A persisted docstore means an index was saved earlier: load it.
        if (vector_store_path / "docstore.json").exists():
            storage_context = StorageContext.from_defaults(
                persist_dir=str(vector_store_path)
            )
            _rag_index = load_index_from_storage(storage_context)
            logger.info("[OK] LlamaIndex vector index loaded")
            # Top-3 retrieval per query.
            _rag_query_engine = _rag_index.as_query_engine(similarity_top_k=3)
        else:
            _rag_index = VectorStoreIndex([])
            logger.info("[OK] LlamaIndex vector index created (empty)")
            # Set the query engine first, then backfill (_backfill needs _rag_index to be ready).
            _rag_query_engine = _rag_index.as_query_engine(similarity_top_k=3)
            _backfill_from_json_history()
            # Rebuild the query engine after the backfill (the index now has data).
            _rag_query_engine = _rag_index.as_query_engine(similarity_top_k=3)

    except ImportError:
        logger.warning("[WARN] LlamaIndex not installed, RAG disabled")
    except Exception as e:
        logger.warning("[WARN] LlamaIndex init failed: %s", e)


def _backfill_from_json_history() -> None:
    """
    Bulk-load every existing ``*_memory.json`` record into the LlamaIndex index.

    Intended to run once, when a new (empty) vector store is created; after
    that the records live in the persisted index. For each agent in
    ``VALID_AGENT_NAMES`` it inserts the latest record and then every
    non-empty dict in that record's ``history`` list. Agents are processed in
    sorted order so the insertion order is the same on every run.

    Design principles (Harness Engineering — Feedback Loops pillar):
    - Historical scan records should not lose semantic recall because of a
      RAG cold start.
    - Failures never block (graceful degradation). A missing, empty or
      non-object memory file, or a non-list ``history`` field, is skipped
      instead of aborting the backfill for the remaining agents.
    """
    if _rag_index is None:
        return

    total_inserted = 0
    for agent_name in sorted(VALID_AGENT_NAMES):
        data = _load_json(_get_memory_path(agent_name))
        if not isinstance(data, dict) or not data:
            continue

        # Backfill the latest scan.
        _rag_insert(agent_name, data)
        total_inserted += 1

        # Backfill every historical scan; tolerate a corrupted/missing history field.
        history = data.get("history")
        if not isinstance(history, list):
            history = []
        for hist_scan in history:
            if isinstance(hist_scan, dict) and hist_scan:
                _rag_insert(agent_name, hist_scan)
                total_inserted += 1

    if total_inserted > 0:
        logger.info("[OK] RAG history backfill done: %d scan records vectorized", total_inserted)
    else:
        logger.info("[INFO] RAG backfill: no historical JSON memory to backfill (first scan)")


# Version-pinned tokens such as "django==4.2", "express@4.18.2", "lodash^4.17"
# or "@types/node@18"; group 1 is the package name in front of the specifier.
_VERSION_SPEC_RE = re.compile(r"(.+?)(?:[=<>!~]=|[@^~<>])v?\d[\w.\-*]*")
# Separators between tech-stack entries: whitespace, ASCII and full-width
# commas/semicolons, and the CJK enumeration comma.
_TECH_STACK_SPLIT_RE = re.compile(r"[\s,，;；、]+")
# Connector words that join entries but are not package names.
_TECH_STACK_STOPWORDS = frozenset({"and", "&", "+", "和", "與", "及"})


def _extract_package_names(tech_stack: str) -> set[str]:
    """
    Extract lower-cased package names (without versions) from a tech-stack string.

    Entries are split on whitespace and on ASCII/full-width commas and
    semicolons. Connector words ("and", "和", ...) are dropped. Pinned
    specifiers ("==", ">=", "@", "^", "~", ...) are stripped from the name.
    Bare version numbers (digits joined by "." or "-") are skipped.

    Examples:
      'Django 4.2, Redis 7.0'           -> {'django', 'redis'}
      'Spring Boot 3.1 和 Node.js 18'   -> {'spring', 'boot', 'node.js'}
      'django==4.2, express@4.18.2'     -> {'django', 'express'}

    Returns:
        The set of package names; empty for empty/None input.
    """
    if not tech_stack:
        return set()
    names = set()
    for part in _TECH_STACK_SPLIT_RE.split(tech_stack):
        clean = part.strip().lower()
        if not clean or clean in _TECH_STACK_STOPWORDS:
            continue
        pinned = _VERSION_SPEC_RE.fullmatch(clean)
        if pinned:
            clean = pinned.group(1)
        # Skip bare version numbers (pure digits or x.y.z / x-y formats).
        if clean.replace(".", "").replace("-", "").isdigit():
            continue
        names.add(clean)
    return names


def _rag_insert(agent_name: str, data: dict) -> None:
    """Insert one memory record into the LlamaIndex vector index (Layer 2 of the dual write).

    The record is stored as a JSON document with ``agent``, ``timestamp``,
    ``scan_id`` and ``tech_stack`` metadata. The index is persisted to
    ``MEMORY_DIR / "vector_store"`` after every insert.

    The nested ``history`` list is not indexed. Each historical record is
    inserted on its own, either by the backfill or when it was written as the
    latest record. Embedding the whole history inside every new record would
    only duplicate content and grow the index with every write.

    Best-effort: this is a no-op when RAG is disabled or the index is not
    initialised. Any failure is logged rather than raised, so the JSON layer
    is never affected.

    Args:
        agent_name: Agent that owns the record (stored as metadata).
        data: Memory record to index; it is not mutated.
    """
    if not ENABLE_MEMORY_RAG or _rag_index is None:
        return
    try:
        from llama_index.core import Document

        # Index only this record's own fields (see docstring for why history is dropped).
        record = {k: v for k, v in data.items() if k != "history"}

        # tech_stack may live under either key. Normalise a list to a comma string
        # and anything else to str, so the metadata value and the log slice are safe.
        tech_stack = data.get("tech_stack") or data.get("tech_stack_input") or ""
        if isinstance(tech_stack, list):
            tech_stack = ", ".join(str(t) for t in tech_stack)
        tech_stack = str(tech_stack)

        doc = Document(
            text=json.dumps(record, ensure_ascii=False),
            metadata={
                "agent": agent_name,
                "timestamp": data.get("timestamp", ""),
                "scan_id": data.get("scan_id", ""),
                "tech_stack": tech_stack,
            },
        )
        _rag_index.insert(doc)
        vector_store_path = MEMORY_DIR / "vector_store"
        _rag_index.storage_context.persist(persist_dir=str(vector_store_path))
        logger.info("[OK] RAG index updated: %s (tech_stack=%s)", agent_name, tech_stack[:50])
    except Exception as e:
        logger.warning("[WARN] RAG write failed (JSON layer unaffected): %s", e)


def _rag_search(query: str, tech_stack: str | None = None) -> str:
    """
    語義搜尋(帶安全閥 + 技術棧相關性過濾)。

    Args:
        query: 搜尋查詢
        tech_stack: 當前掃描的技術棧(用於過濾不相關歷史)
    """
    if not ENABLE_MEMORY_RAG:
        return "RAG disabled (ENABLE_MEMORY_RAG=false)"

    _init_rag()
    if _rag_index is None:
        return "RAG index unavailable"

    try:
        doc_count = len(_rag_index.docstore.docs) if hasattr(_rag_index, "docstore") else 0
        if doc_count == 0:
            return "No history available (vector index empty)"

        response = _rag_query_engine.query(query)

        # 相關性門檻過濾
        if hasattr(response, "source_nodes") and response.source_nodes:
            scores = [n.score for n in response.source_nodes if n.score is not None]
            max_score = max(scores) if scores else 0
            if max_score < SIMILARITY_THRESHOLD:
                return (
                    f"No relevant history found"
                    f" (max_similarity {max_score:.2f} < threshold {SIMILARITY_THRESHOLD})"
                )

            # 技術棧相關性過濾:只保留與當前掃描套件有交集的歷史
            if tech_stack:
                current_packages = _extract_package_names(tech_stack)
                if current_packages:
                    filtered_nodes = []
                    for node in response.source_nodes:
                        node_tech = node.metadata.get("tech_stack", "")
                        if not node_tech:
                            # 無 tech_stack 元資料的舊記錄,保守保留
                            filtered_nodes.append(node)
                            continue
                        node_packages = _extract_package_names(node_tech)
                        # 有套件名稱交集才保留
                        if current_packages & node_packages:
                            filtered_nodes.append(node)
                        else:
                            logger.info(
                                "[FILTER] Excluded history: %s (no overlap with %s)",
                                node_tech[:50], tech_stack[:50],
                            )

                    if not filtered_nodes:
                        return (
                            f"No relevant history for current tech stack"
                            f" (filtered {len(response.source_nodes)} results, 0 matched)"
                        )
                    response.source_nodes = filtered_nodes

        return str(response)

    except Exception as e:
        logger.warning("[WARN] RAG search failed: %s", e)
        return f"RAG search failed: {e}"


# ── CrewAI tool definitions ──────────────────────────────────
# NOTE(review): crewai's @tool presumably exposes each tool's docstring to the
# LLM as its description, so the docstrings below are kept verbatim.
@tool("read_memory")
def read_memory(agent_name: str) -> str:
    """
    讀取指定 Agent 的歷史記憶(JSON Layer 1:穩定保底)。
    0 份歷史回傳空 JSON,Agent 可據此判斷是否為第一次掃描。

    Args:
        agent_name: Agent 名稱(scout / analyst / advisor / critic / orchestrator)

    Returns:
        JSON 字串格式的歷史記憶
    """
    # Read one agent's JSON memory (Layer 1). Returns "{}" for an unknown agent
    # name; returns an empty object as pretty JSON when the agent has no memory yet.
    normalized = agent_name.strip().lower()
    if normalized in VALID_AGENT_NAMES:
        memory = _load_json(_get_memory_path(normalized))
        if memory:
            logger.info("[OK] %s memory loaded (scan_id: %s)", normalized, memory.get('scan_id', 'N/A'))
        else:
            logger.info("[INFO] %s has no history (first scan)", normalized)
        return json.dumps(memory, ensure_ascii=False, indent=2)

    logger.warning("[WARN] Invalid agent name: %s", normalized)
    return json.dumps({}, ensure_ascii=False)


@tool("write_memory")
def write_memory(agent_name: str, data: str) -> str:
    """
    寫入 Agent 記憶(雙寫:JSON + LlamaIndex)。自動添加 timestamp。

    Args:
        agent_name: Agent 名稱(scout / analyst / advisor / critic / orchestrator)
        data: JSON 字串格式的記憶資料

    Returns:
        寫入結果訊息
    """
    # Persist one agent's memory record. The steps are:
    #   1. validate the agent name and payload;
    #   2. run the sandbox sanitizer;
    #   3. stamp a UTC timestamp;
    #   4. rotate the previous record into a bounded history[] list in the JSON file;
    #   5. index the new record in the RAG layer.
    # Validation and JSON-write problems are returned as "[FAIL]"/"[BLOCKED]" strings.
    # NOTE(review): the docstring above is presumably the crewai tool description,
    # so it is intentionally left unchanged.
    agent_name = agent_name.strip().lower()
    if agent_name not in VALID_AGENT_NAMES:
        return f"[FAIL] Invalid agent name: {agent_name} (allowed: {VALID_AGENT_NAMES})"

    try:
        memory_data = json.loads(data) if isinstance(data, str) else data
    except json.JSONDecodeError as e:
        return f"[FAIL] JSON format error: {e}"
    # Only a JSON object can be stored. The timestamp/history assignments below
    # would raise TypeError (crashing the tool) on a list, string or number.
    if not isinstance(memory_data, dict):
        return (
            "[FAIL] JSON format error: expected a JSON object, "
            f"got {type(memory_data).__name__}"
        )

    # Sandbox Layer 3: poison filter before write
    is_safe, clean_data, reason = _sanitize_write(memory_data, agent_name)
    if not is_safe:
        logger.warning('[MEMORY][SANDBOX] Write BLOCKED: %s', reason)
        return '[BLOCKED] Memory write rejected by Sandbox: ' + reason
    memory_data = clean_data

    memory_data["timestamp"] = datetime.now(timezone.utc).isoformat()

    # Initialise the RAG layer before the JSON file changes.
    # - _rag_insert() is a no-op until the index exists, so without this step a
    #   write made before the first history_search would never reach the index.
    # - On a cold start the backfill then reads the *previous* JSON state, so the
    #   new record is indexed exactly once, by _rag_insert() below.
    _init_rag()

    # Layer 1: JSON — keep the previous record in a bounded history[] list.
    max_history = 50
    memory_path = _get_memory_path(agent_name)
    try:
        existing = _load_json(memory_path)
        history = existing.get("history", [])
        if not isinstance(history, list):
            # A corrupted history field must not block the new write.
            history = []

        # Rotate the previous latest record (if any) into history.
        if existing and "scan_id" in existing:
            history.append({k: v for k, v in existing.items() if k != "history"})
        # Enforce the cap every time, even if the file already held more entries.
        history = history[-max_history:]

        memory_data["history"] = history
        _save_json(memory_path, memory_data)
        logger.info("[OK] %s memory saved to JSON (Layer 1 | history=%d records)", agent_name, len(history))
    except Exception as e:
        return f"[FAIL] JSON write failed: {e}"

    # Layer 2: LlamaIndex (dual write).
    _rag_insert(agent_name, memory_data)

    return f"[OK] {agent_name} memory saved (timestamp: {memory_data['timestamp']})"


@tool("history_search")
def history_search(query: str, tech_stack: str = "") -> str:
    """
    語義搜尋歷史安全報告(帶技術棧過濾)。
    帶安全閥:索引為空 / 分數太低 / 技術棧不匹配 / RAG 未啟用 → 回傳提示。

    Args:
        query: 搜尋查詢(例如:"Django SSRF 歷史")
        tech_stack: 當前掃描的技術棧(例如:"Django 4.2, Redis 7.0"),
                    用於過濾不相關的歷史記錄

    Returns:
        搜尋結果或安全提示
    """
    # Thin tool wrapper around _rag_search. An empty tech_stack is passed as
    # None, which means "do not filter by tech stack".
    return _rag_search(query, tech_stack or None)