File size: 506 Bytes
1b435f0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 | import numpy as np
from sklearn.cluster import KMeans
def cluster_embeddings(
    embeddings: np.ndarray,
    n_clusters: int = 4,
    random_state: int = 42,
) -> list[int]:
    """Assign a k-means cluster label to every embedding.

    Args:
        embeddings: The vectors to cluster, one per row. Passed unchanged to
            ``KMeans.fit_predict`` (presumably a 2-D array of shape
            ``(n_samples, n_features)`` -- confirm against callers).
        n_clusters: Requested number of clusters. Capped at the number of
            embeddings, because k-means cannot form more clusters than
            there are samples.
        random_state: Seed forwarded to ``KMeans`` so repeated runs on the
            same data give the same labels.

    Returns:
        A list with exactly one integer label per embedding, in input
        order. Empty input yields an empty list.

    Raises:
        ValueError: If ``n_clusters`` is less than 1 and ``embeddings`` is
            non-empty.
    """
    n_samples = len(embeddings)
    if n_samples == 0:
        return []

    if n_clusters < 1:
        raise ValueError(f"n_clusters must be >= 1, got {n_clusters}")

    effective_clusters = min(n_clusters, n_samples)
    if effective_clusters == 1:
        # Everything belongs to the single cluster. Emit one label per
        # embedding so the result always lines up with the input rows.
        return [0] * n_samples

    kmeans = KMeans(
        n_clusters=effective_clusters,
        random_state=random_state,
        n_init=10,
    )
    return kmeans.fit_predict(embeddings).tolist()
|