File size: 506 Bytes
09f4a33 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | from __future__ import annotations
import numpy as np
from sklearn.cluster import KMeans
def cluster_embeddings(
    embeddings: np.ndarray,
    n_clusters: int = 4,
    random_state: int = 42,
    *,
    n_init: int = 10,
) -> list[int]:
    """Assign each embedding row to a k-means cluster and return the labels.

    The requested cluster count is capped at the number of rows, so asking
    for more clusters than there are samples does not fail.

    Args:
        embeddings: Array-like of samples, one per row. It is passed straight
            to ``KMeans.fit_predict``, so presumably it has shape
            ``(n_samples, n_features)``; confirm against callers.
        n_clusters: Desired number of clusters. Must be at least 1 when
            ``embeddings`` is non-empty. Clamped to ``len(embeddings)``.
        random_state: Seed forwarded to ``KMeans`` for reproducible labels.
        n_init: Number of k-means runs with different centroid seeds,
            forwarded to ``KMeans``. Defaults to 10, the value previously
            hard-coded here.

    Returns:
        One integer cluster label per row, in input order. Empty input yields
        an empty list. When only one cluster is possible, every label is 0.

    Raises:
        ValueError: If ``n_clusters < 1`` and ``embeddings`` is non-empty.
    """
    n_samples = len(embeddings)
    if n_samples == 0:
        return []
    if n_clusters < 1:
        # Fail with a clear message instead of sklearn's parameter-validation error.
        raise ValueError(f"n_clusters must be >= 1, got {n_clusters}")
    # KMeans rejects n_clusters > n_samples, so clamp to the sample count.
    effective_k = min(n_clusters, n_samples)
    if effective_k == 1:
        # Single cluster: every label is trivially 0, so skip fitting entirely.
        return [0] * n_samples
    kmeans = KMeans(n_clusters=effective_k, random_state=random_state, n_init=n_init)
    return kmeans.fit_predict(embeddings).tolist()
|