File size: 506 Bytes
1b435f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import numpy as np
from sklearn.cluster import KMeans


def cluster_embeddings(
    embeddings: np.ndarray,
    n_clusters: int = 4,
    random_state: int = 42,
) -> list[int]:
    """Assign every embedding to a K-means cluster and return the labels.

    Args:
        embeddings: Collection of embedding vectors, one per row.
            NOTE(review): presumably a 2-D ``(n_samples, n_features)``
            array, since it is handed straight to ``KMeans.fit_predict``
            -- confirm against callers.
        n_clusters: Requested number of clusters. It is capped at the
            number of embeddings so K-means is never asked for more
            clusters than there are samples.
        random_state: Seed forwarded to ``KMeans``.

    Returns:
        One integer cluster label per embedding, in input order. An empty
        input yields an empty list.

    Raises:
        ValueError: If ``embeddings`` is non-empty and ``n_clusters`` is
            smaller than 1.
    """
    n_samples = len(embeddings)
    # Explicit length check: the truth value of a multi-element ndarray is
    # ambiguous, so ``if not embeddings`` would raise.
    if n_samples == 0:
        return []

    if n_clusters < 1:
        raise ValueError(
            f"n_clusters must be a positive integer, got {n_clusters!r}"
        )

    effective_clusters = min(n_clusters, n_samples)
    if effective_clusters == 1:
        # A single cluster needs no fitting, but every sample still gets a
        # label -- return one 0 per embedding, not a single-element list.
        return [0] * n_samples

    kmeans = KMeans(
        n_clusters=effective_clusters,
        random_state=random_state,
        n_init=10,
    )
    return kmeans.fit_predict(embeddings).tolist()