File size: 506 Bytes
09f4a33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from __future__ import annotations

import numpy as np
from sklearn.cluster import KMeans


def cluster_embeddings(
    embeddings: np.ndarray,
    n_clusters: int = 4,
    random_state: int = 42,
) -> list[int]:
    """Assign each embedding row to one of at most ``n_clusters`` k-means clusters.

    Args:
        embeddings: Array-like of embeddings, one row per item. It is passed
            to scikit-learn ``KMeans``, which expects a 2-D array.
        n_clusters: Upper bound on the number of clusters. The number actually
            used is capped at the number of distinct rows in ``embeddings``.
        random_state: Seed forwarded to ``KMeans`` so labels are reproducible.

    Returns:
        One integer cluster label per input row, in input order. Empty input
        yields ``[]``; when only one cluster can be formed (a single distinct
        row, or ``n_clusters == 1``) every row is labelled ``0``.

    Raises:
        ValueError: If ``n_clusters`` is less than 1 and ``embeddings`` is
            non-empty.
    """
    if len(embeddings) == 0:
        return []
    if n_clusters < 1:
        raise ValueError(f"n_clusters must be >= 1, got {n_clusters}")
    data = np.asarray(embeddings)
    # Identical rows can never be separated into different clusters. Asking
    # KMeans for more clusters than distinct points only produces a
    # ConvergenceWarning ("Number of distinct clusters ... smaller than
    # n_clusters"), so cap k at the distinct-row count instead.
    n_distinct = len(np.unique(data, axis=0))
    effective_k = min(n_clusters, n_distinct)
    if effective_k == 1:
        return [0] * len(data)
    kmeans = KMeans(n_clusters=effective_k, random_state=random_state, n_init=10)
    return kmeans.fit_predict(data).tolist()