File size: 5,332 Bytes
8ede856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import base64
import traceback
from io import BytesIO

from astrbot.api import logger
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.kb_helper import KBHelper
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager


async def generate_tsne_visualization(
    query: str,
    kb_names: list[str],
    kb_manager: KnowledgeBaseManager,
) -> str | None:
    """Render a t-SNE scatter plot of a knowledge base's vectors and a query.

    The first name in ``kb_names`` that resolves to a knowledge base is used.
    Its ``index.faiss`` file is read from the knowledge base directory, every
    stored vector is reconstructed, the query is embedded with the knowledge
    base's embedding provider, and all vectors are projected to 2-D with
    t-SNE. The query is drawn as a red "X" on top of the stored vectors.

    Args:
        query: Query text to embed and highlight in the plot.
        kb_names: Candidate knowledge base names, tried in order.
        kb_manager: Manager used to look knowledge bases up by name.

    Returns:
        The rendered PNG image as a base64-encoded string, or ``None`` when no
        knowledge base is found, the index file is missing or empty, the query
        embedding does not match the index dimension, or any other error
        occurs while building the plot (the error is logged).

    Raises:
        Exception: If faiss, matplotlib, numpy or scikit-learn cannot be
            imported. The original ``ImportError`` is chained as the cause.

    """
    try:
        import faiss
        import matplotlib
        import numpy as np

        matplotlib.use("Agg")  # non-interactive backend; no display required
        import matplotlib.pyplot as plt
        from sklearn.manifold import TSNE
    except ImportError as e:
        raise Exception(
            f"缺少必要的库以生成 t-SNE 可视化。请安装 matplotlib 和 scikit-learn: {e}",
        ) from e

    try:
        # Use the first knowledge base that can be resolved by name.
        kb_helper: KBHelper | None = None
        for kb_name in kb_names:
            kb_helper = await kb_manager.get_kb_by_name(kb_name)
            if kb_helper:
                break

        if not kb_helper:
            logger.warning("未找到知识库")
            return None

        kb = kb_helper.kb
        index_path = kb_helper.kb_dir / "index.faiss"

        # Load the FAISS index from disk.
        if not index_path.exists():
            logger.warning(f"FAISS 索引不存在: {index_path!s}")
            return None

        index = faiss.read_index(str(index_path))

        if index.ntotal == 0:
            logger.warning("索引为空")
            return None

        # Reconstruct every stored vector.
        logger.info(f"提取 {index.ntotal} 个向量用于可视化...")
        # An IndexIDMap wraps another index; vectors are reconstructed from
        # the wrapped index by position (presumably because the wrapper is
        # addressed by external IDs rather than positions -- confirm).
        if isinstance(index, faiss.IndexIDMap):
            source_index = faiss.downcast_index(index.index)
        else:
            source_index = index

        if hasattr(source_index, "reconstruct_n"):
            vectors = source_index.reconstruct_n(0, index.ntotal)
        else:
            # Fallback: reconstruct one vector at a time into a preallocated
            # array.
            vectors = np.zeros((index.ntotal, index.d), dtype=np.float32)
            for i in range(index.ntotal):
                source_index.reconstruct(i, vectors[i])

        # Embed the query with the knowledge base's embedding provider.
        vec_db: FaissVecDB = kb_helper.vec_db  # type: ignore
        embedding_provider = vec_db.embedding_provider
        query_embedding = await embedding_provider.get_embedding(query)
        query_vector = np.array([query_embedding], dtype=np.float32)

        # Fail with a clear message instead of an opaque vstack error when
        # the embedding does not match the dimension of the stored vectors.
        if query_vector.ndim != 2 or query_vector.shape[1] != vectors.shape[1]:
            logger.warning(
                f"查询向量维度 {query_vector.shape} 与索引维度 {vectors.shape[1]} 不匹配",
            )
            return None

        # Append the query as the last row so it can be split off again after
        # the projection.
        all_vectors = np.vstack([vectors, query_vector])

        # t-SNE projection to two dimensions.
        # NOTE(review): fit_transform is CPU-bound and runs directly inside
        # this coroutine; consider offloading it if it stalls the event loop.
        logger.info("开始 t-SNE 降维...")
        # scikit-learn requires perplexity < n_samples; capping at
        # n_samples - 1 keeps small knowledge bases valid.
        perplexity = min(30, all_vectors.shape[0] - 1)
        tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
        vectors_2d = tsne.fit_transform(all_vectors)

        # Split the projected points back into knowledge base rows and query.
        kb_vectors_2d = vectors_2d[:-1]
        query_vector_2d = vectors_2d[-1]

        # Plot.
        logger.info("生成可视化图表...")
        fig = plt.figure(figsize=(14, 10))
        try:
            # Knowledge base vectors, coloured by their position in the index.
            scatter = plt.scatter(
                kb_vectors_2d[:, 0],
                kb_vectors_2d[:, 1],
                alpha=0.5,
                s=40,
                c=range(len(kb_vectors_2d)),
                cmap="viridis",
                label="Knowledge Base Vectors",
            )

            # Query vector (red X), drawn above the other points.
            plt.scatter(
                query_vector_2d[0],
                query_vector_2d[1],
                c="red",
                s=300,
                marker="X",
                edgecolors="black",
                linewidths=2,
                label="Query",
                zorder=5,
            )

            # Text label pointing at the query marker.
            plt.annotate(
                "Query",
                (query_vector_2d[0], query_vector_2d[1]),
                xytext=(10, 10),
                textcoords="offset points",
                fontsize=10,
                bbox={"boxstyle": "round,pad=0.5", "fc": "yellow", "alpha": 0.7},
                arrowprops={"arrowstyle": "->", "connectionstyle": "arc3,rad=0"},
            )

            plt.colorbar(scatter, label="Vector Index")
            plt.title(
                f"t-SNE Visualization: Query in Knowledge Base\n"
                f"({index.ntotal} vectors, {index.d} dimensions, KB: {kb.kb_name})",
                fontsize=14,
                pad=20,
            )
            plt.xlabel("t-SNE Dimension 1", fontsize=12)
            plt.ylabel("t-SNE Dimension 2", fontsize=12)
            plt.grid(True, alpha=0.3)
            plt.legend(fontsize=10, loc="upper right")

            # Render the figure into an in-memory PNG.
            buffer = BytesIO()
            fig.savefig(buffer, format="png", dpi=150, bbox_inches="tight")
        finally:
            # Always release the figure, even when plotting or saving fails,
            # so failed calls do not accumulate open figures in pyplot.
            plt.close(fig)

        # Return the image as base64 text.
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    except Exception as e:
        # Best-effort feature: log the failure and signal it with None.
        logger.error(f"生成 t-SNE 可视化时出错: {e}")
        logger.error(traceback.format_exc())
        return None