File size: 5,332 Bytes
8ede856 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 | import base64
import traceback
from io import BytesIO
from astrbot.api import logger
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from astrbot.core.knowledge_base.kb_helper import KBHelper
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
async def generate_tsne_visualization(
    query: str,
    kb_names: list[str],
    kb_manager: KnowledgeBaseManager,
) -> str | None:
    """Render a t-SNE scatter plot of a knowledge base's vectors plus the query.

    The first name in ``kb_names`` that ``kb_manager`` resolves to a knowledge
    base is used. Every vector stored in that knowledge base's FAISS index is
    reconstructed, the query is embedded with the knowledge base's embedding
    provider, and all points are projected to 2-D with t-SNE and plotted.

    Args:
        query: Query text to embed and highlight in the plot.
        kb_names: Candidate knowledge base names; the first one found is used.
        kb_manager: Knowledge base manager used to look up knowledge bases.

    Returns:
        The PNG image as a base64-encoded string, or ``None`` when no
        knowledge base is found, the index file is missing or empty, or any
        error occurs while building the plot (the error is logged).

    Raises:
        Exception: If one of the optional plotting / numeric dependencies
            cannot be imported.
    """
    try:
        import faiss
        import matplotlib
        import numpy as np

        matplotlib.use("Agg")  # non-interactive backend, no display needed
        import matplotlib.pyplot as plt
        from sklearn.manifold import TSNE
    except ImportError as e:
        # The original message lacked the ``f`` prefix, so ``{e}`` was emitted
        # literally instead of naming the module that failed to import.
        raise Exception(
            f"缺少必要的库以生成 t-SNE 可视化。请安装 matplotlib 和 scikit-learn: {e}",
        ) from e

    import asyncio

    try:
        # Pick the first knowledge base that actually exists.
        kb_helper: KBHelper | None = None
        for kb_name in kb_names:
            kb_helper = await kb_manager.get_kb_by_name(kb_name)
            if kb_helper:
                break
        if not kb_helper:
            logger.warning("未找到知识库")
            return None

        kb = kb_helper.kb
        index_path = kb_helper.kb_dir / "index.faiss"

        # Load the FAISS index from disk.
        if not index_path.exists():
            logger.warning(f"FAISS 索引不存在: {index_path!s}")
            return None
        index = faiss.read_index(str(index_path))
        if index.ntotal == 0:
            logger.warning("索引为空")
            return None

        # Reconstruct every stored vector. For an IndexIDMap the vectors are
        # read from the wrapped base index rather than from the wrapper.
        logger.info(f"提取 {index.ntotal} 个向量用于可视化...")
        if isinstance(index, faiss.IndexIDMap):
            source_index = faiss.downcast_index(index.index)
        else:
            source_index = index
        if hasattr(source_index, "reconstruct_n"):
            vectors = source_index.reconstruct_n(0, index.ntotal)
        else:
            vectors = np.zeros((index.ntotal, index.d), dtype=np.float32)
            for i in range(index.ntotal):
                source_index.reconstruct(i, vectors[i])

        # Embed the query with the knowledge base's own embedding provider.
        vec_db: FaissVecDB = kb_helper.vec_db  # type: ignore
        embedding_provider = vec_db.embedding_provider
        query_embedding = await embedding_provider.get_embedding(query)
        query_vector = np.array([query_embedding], dtype=np.float32)

        # Fail with a clear message instead of an opaque vstack error when the
        # embedding size does not match the stored vectors.
        if query_vector.ndim != 2 or query_vector.shape[1] != vectors.shape[1]:
            logger.warning(
                f"查询向量维度 {query_vector.shape} 与索引向量维度 {vectors.shape[1]} 不一致",
            )
            return None

        # The query is appended as the last row so it can be split off again
        # after the projection.
        all_vectors = np.vstack([vectors, query_vector])

        # t-SNE requires perplexity < n_samples, hence the cap at n - 1.
        logger.info("开始 t-SNE 降维...")
        perplexity = min(30, all_vectors.shape[0] - 1)
        tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
        # fit_transform is CPU-bound; run it in a worker thread so the event
        # loop stays responsive while the projection is computed.
        vectors_2d = await asyncio.to_thread(tsne.fit_transform, all_vectors)

        kb_vectors_2d = vectors_2d[:-1]
        query_vector_2d = vectors_2d[-1]

        logger.info("生成可视化图表...")
        fig = plt.figure(figsize=(14, 10))
        try:
            # Knowledge base vectors, coloured by their position in the index.
            scatter = plt.scatter(
                kb_vectors_2d[:, 0],
                kb_vectors_2d[:, 1],
                alpha=0.5,
                s=40,
                c=range(len(kb_vectors_2d)),
                cmap="viridis",
                label="Knowledge Base Vectors",
            )
            # Query vector as a large red "X" drawn above the other points.
            plt.scatter(
                query_vector_2d[0],
                query_vector_2d[1],
                c="red",
                s=300,
                marker="X",
                edgecolors="black",
                linewidths=2,
                label="Query",
                zorder=5,
            )
            # Text label pointing at the query marker.
            plt.annotate(
                "Query",
                (query_vector_2d[0], query_vector_2d[1]),
                xytext=(10, 10),
                textcoords="offset points",
                fontsize=10,
                bbox={"boxstyle": "round,pad=0.5", "fc": "yellow", "alpha": 0.7},
                arrowprops={"arrowstyle": "->", "connectionstyle": "arc3,rad=0"},
            )
            plt.colorbar(scatter, label="Vector Index")
            plt.title(
                f"t-SNE Visualization: Query in Knowledge Base\n"
                f"({index.ntotal} vectors, {index.d} dimensions, KB: {kb.kb_name})",
                fontsize=14,
                pad=20,
            )
            plt.xlabel("t-SNE Dimension 1", fontsize=12)
            plt.ylabel("t-SNE Dimension 2", fontsize=12)
            plt.grid(True, alpha=0.3)
            plt.legend(fontsize=10, loc="upper right")

            # Serialise the figure to PNG in memory and return it as base64.
            buffer = BytesIO()
            plt.savefig(buffer, format="png", dpi=150, bbox_inches="tight")
            return base64.b64encode(buffer.getvalue()).decode("utf-8")
        finally:
            # Always release the figure, including when plotting or saving
            # raises; pyplot otherwise keeps it alive in its global registry.
            plt.close(fig)
    except Exception as e:
        # Best-effort feature: log the failure and signal "no image".
        logger.error(f"生成 t-SNE 可视化时出错: {e}")
        logger.error(traceback.format_exc())
        return None
|