| import numpy as np |
| from chromadb import Client, Settings |
| from sklearn.decomposition import PCA |
| import joblib |
| import os |
| from datetime import datetime |
| import warnings |
| import cupy as cp |
| from cuml.cluster import KMeans as cuKMeans |
| from tqdm import tqdm |
| |
# Silence FutureWarning/UserWarning globally so third-party library noise
# does not clutter the progress output printed by the pipeline below.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)
|
|
class TopicClusterer:
    """Cluster healthcare Q&A embeddings stored in a ChromaDB collection.

    Pipeline: load embeddings (cache or database) -> PCA dimensionality
    reduction -> GPU-accelerated KMeans clustering -> write cluster labels
    back to the database as metadata.
    """

    def __init__(self, chroma_uri: str = "./Data/database"):
        """Initialize the clusterer.

        Args:
            chroma_uri: Path to the persistent ChromaDB database directory.
        """
        self.chroma_uri = chroma_uri
        self.client = Client(Settings(
            persist_directory=chroma_uri,
            anonymized_telemetry=False,
            is_persistent=True
        ))

        # Expected embedding dimensionality (informational; not enforced on load).
        self.vector_dim = 768

        # Best-effort lookup: on failure, record None explicitly so later
        # methods fail with a clear state rather than an AttributeError
        # from a never-assigned attribute.
        try:
            self.collection = self.client.get_collection("healthcare_qa")
        except Exception as e:
            self.collection = None
            print(f"集合不存在: {e}")

        self.embeddings = None          # np.ndarray, shape (n_docs, dim)
        self.reduced_embeddings = None  # np.ndarray, shape (n_docs, n_components)
        self.labels = None              # np.ndarray of integer cluster ids
        self.document_ids = None        # list[str], aligned with self.embeddings

    def load_embeddings(self, cache_file: str = None) -> np.ndarray:
        """Load embeddings, preferring a local ``.npy`` cache over the database.

        Falls back to a full ChromaDB fetch when the cache is missing or
        unreadable (the original code disabled the cache with a leftover
        ``and 0`` and never actually fell back after a failed cache read).

        Args:
            cache_file: Optional path to a cached ``.npy`` embeddings file.
                Defaults to the historical cache location.

        Returns:
            Array of shape (n_docs, dim); also stored on ``self.embeddings``.
        """
        if cache_file is None:
            cache_file = ('/home/dyvm6xra/dyvm6xrauser11/workspace/projects/HKU/'
                          'Chatbot/Data/Embeddings/'
                          'embeddings_703df19c43bd6565563071b97e7172ce.npy')

        if os.path.exists(cache_file):
            print("发现缓存的embeddings,正在加载...")
            try:
                self.embeddings = np.load(cache_file)
                # The cache stores no document ids; synthesize positional ones.
                self.document_ids = [str(i) for i in range(len(self.embeddings))]
                print(f"从缓存加载完成,数据形状: {self.embeddings.shape}")
                return self.embeddings
            except Exception as e:
                print(f"加载缓存失败: {e},将从数据库重新加载")

        # Cache missing or failed to load: pull every embedding from ChromaDB.
        print("正在加载embeddings...")
        print(self.collection.count())
        result = self.collection.get(include=["embeddings"])
        self.embeddings = np.array(result["embeddings"])
        self.document_ids = result["ids"]

        print(f"加载完成,数据形状: {self.embeddings.shape}")
        return self.embeddings

    def reduce_dimensions(self, n_components: int = 2) -> np.ndarray:
        """Reduce embedding dimensionality with randomized-solver PCA.

        Args:
            n_components: Target dimensionality after reduction.

        Returns:
            Array of shape (n_docs, n_components); also cached to disk and
            stored on ``self.reduced_embeddings``.
        """
        if self.embeddings is None:
            self.load_embeddings()

        print("使用PCA进行降维...")

        reducer = PCA(
            n_components=n_components,
            random_state=42,
            svd_solver='randomized'
        )
        self.reduced_embeddings = reducer.fit_transform(self.embeddings)
        cumulative_variance = np.cumsum(reducer.explained_variance_ratio_)
        print(f"PCA累积解释方差比: {cumulative_variance[-1]:.4f}")

        print(f"降维完成,降维后形状: {self.reduced_embeddings.shape}")

        # Persist the reduced matrix next to the embeddings cache
        # (e.g. ./Data/database -> ./Data/Embeddings).
        cache_dir = os.path.dirname(os.path.dirname(self.chroma_uri)) + '/Embeddings'
        os.makedirs(cache_dir, exist_ok=True)
        cache_file = os.path.join(cache_dir, f'pca_reduced_{n_components}d.npy')
        np.save(cache_file, self.reduced_embeddings)
        print(f"降维结果已缓存到: {cache_file}")

        return self.reduced_embeddings

    def cluster_kmeans(self, n_clusters: int = 4) -> np.ndarray:
        """Cluster the reduced embeddings with GPU-accelerated KMeans.

        Args:
            n_clusters: Number of clusters to fit.

        Returns:
            Integer label array of shape (n_docs,); also stored on ``self.labels``.
        """
        print("使用GPU加速的KMeans进行聚类...")

        if self.reduced_embeddings is None:
            self.reduce_dimensions()

        # Move data to device memory for cuML.
        data_gpu = cp.array(self.reduced_embeddings)

        kmeans = cuKMeans(
            n_clusters=n_clusters,
            random_state=42,
            n_init=10,
            max_iter=300,
            verbose=1
        )
        kmeans.fit(data_gpu)
        self.labels = cp.asnumpy(kmeans.labels_)

        # Report per-cluster sizes without shadowing the n_clusters parameter.
        unique_labels = np.unique(self.labels)
        n_found = len(unique_labels)

        print(f"发现 {n_found} 个聚类")
        for label in unique_labels:
            count = np.sum(self.labels == label)
            percentage = count / len(self.labels) * 100
            print(f"簇 {label}: {count} 样本 ({percentage:.2f}%)")

        return self.labels

    def update_database(self) -> None:
        """Write cluster labels back to the database as document metadata.

        Raises:
            ValueError: If clustering has not been run yet.
        """
        if self.labels is None or self.document_ids is None:
            raise ValueError("请先进行聚类")

        print("正在更新数据库...")

        label_strings = [f"cluster_{label}" for label in self.labels]

        # Update in batches to keep individual requests small.
        batch_size = 500
        total_docs = len(self.document_ids)

        for i in tqdm(range(0, total_docs, batch_size), desc="批量更新数据库"):
            batch_end = min(i + batch_size, total_docs)
            batch_ids = self.document_ids[i:batch_end]
            batch_labels = label_strings[i:batch_end]

            # The original loop body was a bare `continue`, so labels were
            # never persisted. NOTE(review): metadata key "cluster" assumed —
            # confirm against whatever downstream code reads these documents.
            self.collection.update(
                ids=batch_ids,
                metadatas=[{"cluster": label} for label in batch_labels]
            )

        print("数据库更新完成")
|
|
def main():
    """Run the full clustering pipeline end to end.

    Steps: load embeddings, project them to 2-D with PCA, fit a
    4-cluster KMeans, then persist the labels to the database.
    """
    pipeline = TopicClusterer()

    pipeline.load_embeddings()
    pipeline.reduce_dimensions(n_components=2)
    pipeline.cluster_kmeans(n_clusters=4)
    pipeline.update_database()
|
|
| if __name__ == "__main__": |
| main() |
|
|