| import base64 |
| import traceback |
| from io import BytesIO |
|
|
| from astrbot.api import logger |
| from astrbot.core.db.vec_db.faiss_impl import FaissVecDB |
| from astrbot.core.knowledge_base.kb_helper import KBHelper |
| from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager |
|
|
|
|
async def generate_tsne_visualization(
    query: str,
    kb_names: list[str],
    kb_manager: KnowledgeBaseManager,
) -> str | None:
    """Render a t-SNE scatter plot of a knowledge base's vectors and a query.

    The first knowledge base in ``kb_names`` that ``kb_manager`` can resolve
    is used. Every vector stored in its ``index.faiss`` file is reconstructed,
    the query is embedded with that knowledge base's embedding provider, and
    all points are projected to 2D together with t-SNE before being plotted.

    Args:
        query: Query text to embed and highlight in the plot.
        kb_names: Candidate knowledge base names, tried in order.
        kb_manager: Manager used to look up knowledge bases by name.

    Returns:
        The PNG image encoded as a base64 string, or ``None`` when no
        knowledge base is found, the FAISS index file is missing or empty,
        or any error occurs while building the plot (the error is logged).

    Raises:
        Exception: If faiss, numpy, matplotlib or scikit-learn cannot be
            imported. The original ``ImportError`` is chained as the cause.

    """
    try:
        import faiss
        import matplotlib
        import numpy as np

        # The backend is selected before pyplot is imported; "Agg" renders
        # to in-memory buffers and needs no display.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
        from sklearn.manifold import TSNE
    except ImportError as e:
        # The f-prefix is required so the underlying import error is
        # interpolated instead of emitting a literal "{e}".
        raise Exception(
            f"缺少必要的库以生成 t-SNE 可视化。请安装 matplotlib 和 scikit-learn: {e}",
        ) from e

    try:
        # Use the first knowledge base that can be resolved by name.
        kb_helper: KBHelper | None = None
        for kb_name in kb_names:
            kb_helper = await kb_manager.get_kb_by_name(kb_name)
            if kb_helper:
                break

        if not kb_helper:
            logger.warning("未找到知识库")
            return None

        kb = kb_helper.kb
        index_path = kb_helper.kb_dir / "index.faiss"

        if not index_path.exists():
            logger.warning(f"FAISS 索引不存在: {index_path!s}")
            return None

        index = faiss.read_index(str(index_path))

        if index.ntotal == 0:
            logger.warning("索引为空")
            return None

        logger.info(f"提取 {index.ntotal} 个向量用于可视化...")

        # For an IndexIDMap the vectors are read from the wrapped index;
        # otherwise they are read from the index itself.
        if isinstance(index, faiss.IndexIDMap):
            vector_source = faiss.downcast_index(index.index)
        else:
            vector_source = index

        if hasattr(vector_source, "reconstruct_n"):
            vectors = vector_source.reconstruct_n(0, index.ntotal)
        else:
            # Fall back to reconstructing one vector at a time into a
            # preallocated float32 array.
            vectors = np.zeros((index.ntotal, index.d), dtype=np.float32)
            for i in range(index.ntotal):
                vector_source.reconstruct(i, vectors[i])

        # Embed the query with the knowledge base's own embedding provider.
        vec_db: FaissVecDB = kb_helper.vec_db
        embedding_provider = vec_db.embedding_provider
        query_embedding = await embedding_provider.get_embedding(query)
        query_vector = np.array([query_embedding], dtype=np.float32)

        # The query is appended as the last row so it can be split off again
        # after the projection.
        all_vectors = np.vstack([vectors, query_vector])

        logger.info("开始 t-SNE 降维...")
        # Cap perplexity at (sample count - 1) so small knowledge bases
        # stay below the number of samples.
        perplexity = min(30, all_vectors.shape[0] - 1)
        tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
        # NOTE(review): this call is synchronous inside a coroutine; it
        # presumably blocks the event loop for large indexes - confirm
        # whether it should be offloaded to a worker thread.
        vectors_2d = tsne.fit_transform(all_vectors)

        kb_vectors_2d = vectors_2d[:-1]
        query_vector_2d = vectors_2d[-1]

        logger.info("生成可视化图表...")
        fig = plt.figure(figsize=(14, 10))
        try:
            # Knowledge base points, colored by their position in the index.
            scatter = plt.scatter(
                kb_vectors_2d[:, 0],
                kb_vectors_2d[:, 1],
                alpha=0.5,
                s=40,
                c=range(len(kb_vectors_2d)),
                cmap="viridis",
                label="Knowledge Base Vectors",
            )

            # The query point, drawn on top of the other points.
            plt.scatter(
                query_vector_2d[0],
                query_vector_2d[1],
                c="red",
                s=300,
                marker="X",
                edgecolors="black",
                linewidths=2,
                label="Query",
                zorder=5,
            )

            plt.annotate(
                "Query",
                (query_vector_2d[0], query_vector_2d[1]),
                xytext=(10, 10),
                textcoords="offset points",
                fontsize=10,
                bbox={"boxstyle": "round,pad=0.5", "fc": "yellow", "alpha": 0.7},
                arrowprops={"arrowstyle": "->", "connectionstyle": "arc3,rad=0"},
            )

            plt.colorbar(scatter, label="Vector Index")
            plt.title(
                f"t-SNE Visualization: Query in Knowledge Base\n"
                f"({index.ntotal} vectors, {index.d} dimensions, KB: {kb.kb_name})",
                fontsize=14,
                pad=20,
            )
            plt.xlabel("t-SNE Dimension 1", fontsize=12)
            plt.ylabel("t-SNE Dimension 2", fontsize=12)
            plt.grid(True, alpha=0.3)
            plt.legend(fontsize=10, loc="upper right")

            buffer = BytesIO()
            fig.savefig(buffer, format="png", dpi=150, bbox_inches="tight")
        finally:
            # Close the figure even when plotting or saving fails, so a
            # failed call does not leave a figure open in pyplot.
            plt.close(fig)

        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    except Exception as e:
        logger.error(f"生成 t-SNE 可视化时出错: {e}")
        logger.error(traceback.format_exc())
        return None
|
|