# PULSE-code / experiments / analysis / generate_action_labels.py
# (Hugging Face page header: "velvet-pine-22's picture",
#  "Upload folder using huggingface_hub", commit b4b2877 verified)
#!/usr/bin/env python3
"""
Generate action labels by clustering task descriptions using text embeddings.
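No manual rules — each unique task text is embedded with sentence-transformers
(falling back to TF-IDF character n-grams if that fails) and grouped with
K-Means, with the number of clusters chosen by silhouette score.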
"""
import os
import json
import glob
import argparse
import numpy as np
from collections import Counter
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
# Root directory that is searched for annotation files.
# Environment variables are expanded so the "${PULSE_ROOT}" placeholder
# resolves to a real path when PULSE_ROOT is set; when it is unset the
# string is left unchanged, exactly as before.
ANNOTATION_DIR = os.path.expandvars("${PULSE_ROOT}")
def collect_tasks(annotation_dir=None):
    """Collect every segment's task description from the annotation files.

    Args:
        annotation_dir: Directory searched for ``v*/s*.json`` files. When
            omitted, the module-level ``ANNOTATION_DIR`` is used.

    Returns:
        A list with one task string per segment (duplicates are kept),
        ordered by sorted file path and then by segment order in each file.
        Files without a ``'segments'`` entry contribute nothing.

    Raises:
        KeyError: If a segment has no ``'task'`` entry.
    """
    if annotation_dir is None:
        annotation_dir = ANNOTATION_DIR
    pattern = os.path.join(annotation_dir, 'v*/s*.json')

    tasks = []
    for path in sorted(glob.glob(pattern)):
        # JSON is UTF-8 by specification; do not rely on the platform's
        # locale encoding, which can mis-decode non-ASCII task text.
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
        tasks.extend(seg['task'] for seg in data.get('segments', []))
    return tasks
def embed_texts(texts):
    """Turn a list of strings into a feature matrix, one row per string.

    The primary path encodes the texts with the multilingual
    sentence-transformers model ``paraphrase-multilingual-MiniLM-L12-v2``.
    If any exception is raised on that path, the failure is printed and the
    texts are instead encoded as dense TF-IDF vectors over character
    1- to 3-grams (at most 3000 features).
    """
    try:
        from sentence_transformers import SentenceTransformer

        encoder = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
        vectors = encoder.encode(texts, show_progress_bar=True, batch_size=128)
        print(f"Encoded {len(texts)} texts with sentence-transformers, dim={vectors.shape[1]}")
        return vectors
    except Exception as e:
        # Best-effort fallback: any failure above switches to TF-IDF.
        print(f"sentence-transformers failed ({e}), falling back to TF-IDF")
        from sklearn.feature_extraction.text import TfidfVectorizer

        tfidf = TfidfVectorizer(analyzer='char', ngram_range=(1, 3), max_features=3000)
        dense = tfidf.fit_transform(texts).toarray()
        print(f"Encoded {len(texts)} texts with TF-IDF char n-grams, dim={dense.shape[1]}")
        return dense
def cluster_tasks(tasks, k_range=(10, 30)):
    """Cluster the unique task descriptions with K-Means, picking K by silhouette.

    Args:
        tasks: Task description strings, one per segment (duplicates allowed).
        k_range: Inclusive ``(k_min, k_max)`` range of cluster counts to try.
            It is narrowed to ``[2, n_unique - 1]``, the only values for which
            a silhouette score is defined; a notice is printed when narrowed.

    Returns:
        A tuple ``(task_to_cluster, cluster_representatives, cluster_members,
        best_k, scores)``:

        * ``task_to_cluster``: unique task text -> cluster id (``int``).
        * ``cluster_representatives``: cluster id -> the member text whose
          embedding is closest to the cluster centroid.
        * ``cluster_members``: cluster id -> list of unique member texts.
        * ``best_k``: the K with the highest silhouette score.
        * ``scores``: every tried K -> its silhouette score (plain ``float``).

    Raises:
        ValueError: If no K in the range can be scored, i.e. ``k_max < 2`` or
            there are fewer than three unique task texts.
    """
    unique_tasks = sorted(set(tasks))
    n_unique = len(unique_tasks)
    print(f"Total segments: {len(tasks)}, Unique task texts: {len(unique_tasks)}")

    # Validate the K range before paying for the embedding step.
    k_hi = min(k_range[1], n_unique - 1)
    if k_hi < 2:
        raise ValueError(
            f"Need k_max >= 2 and at least 3 unique task texts to cluster "
            f"(got k_max={k_range[1]}, {n_unique} unique texts)"
        )
    k_lo = max(2, min(k_range[0], k_hi))
    if (k_lo, k_hi) != (k_range[0], k_range[1]):
        print(f"Adjusted K range from {k_range[0]}..{k_range[1]} to {k_lo}..{k_hi} "
              f"for {n_unique} unique task texts")

    X = embed_texts(unique_tasks)

    # Find optimal K via silhouette score
    best_k, best_score, best_km = k_lo, -1.0, None
    scores = {}
    for k in range(k_lo, k_hi + 1):
        km = KMeans(n_clusters=k, random_state=42, n_init=10)
        labels = km.fit_predict(X)
        # Seed the subsampling so K selection is reproducible, and convert to
        # a built-in float so the score stays JSON-serializable even when the
        # embeddings (and hence the score) are float32.
        score = float(silhouette_score(
            X, labels, sample_size=min(2000, n_unique), random_state=42
        ))
        scores[k] = score
        is_best = best_km is None or score > best_score
        if is_best:
            best_k, best_score, best_km = k, score, km
        print(f" K={k}: silhouette={score:.4f}" + (" *" if is_best else ""))
    print(f"\nBest K={best_k} (silhouette={best_score:.4f})")

    # Final clustering: reuse the winning fit. Refitting with the same seed
    # and data would reproduce it, so the extra fit is unnecessary.
    labels = best_km.labels_
    task_to_cluster = {
        task: int(label) for task, label in zip(unique_tasks, labels)
    }

    # Representative task per cluster (closest to centroid)
    cluster_representatives = {}
    cluster_members = {}
    for cid in range(best_k):
        member_idx = np.flatnonzero(labels == cid)
        cluster_members[cid] = [unique_tasks[i] for i in member_idx]
        centroid = best_km.cluster_centers_[cid]
        dists = np.linalg.norm(X[member_idx] - centroid, axis=1)
        closest = member_idx[np.argmin(dists)]
        cluster_representatives[cid] = unique_tasks[closest]

    return task_to_cluster, cluster_representatives, cluster_members, best_k, scores
def main():
    """Command-line entry point: cluster task descriptions and save the labels.

    Parses ``--output_dir``, ``--k_min`` and ``--k_max``, collects the task
    descriptions, clusters them, prints a per-cluster summary, and writes
    ``action_labels.json`` into the output directory.

    Raises:
        SystemExit: If no task descriptions were found to cluster.
    """
    parser = argparse.ArgumentParser()
    # Expand the ${PULSE_ROOT} placeholder in the default so the script does
    # not create a directory literally named "${PULSE_ROOT}".
    parser.add_argument('--output_dir', type=str,
                        default=os.path.expandvars('${PULSE_ROOT}/results/pred'))
    parser.add_argument('--k_min', type=int, default=10)
    parser.add_argument('--k_max', type=int, default=30)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    tasks = collect_tasks()
    if not tasks:
        # Fail with a readable message instead of an error from deep inside
        # the embedding/clustering libraries.
        raise SystemExit(
            f"No task descriptions found under {ANNOTATION_DIR!r} (pattern v*/s*.json)"
        )

    task_to_cluster, representatives, members, K, scores = cluster_tasks(
        tasks, k_range=(args.k_min, args.k_max)
    )

    # Print summary
    segment_counts = Counter(task_to_cluster[t] for t in tasks)
    print(f"\n{'='*60}")
    print(f"Clusters (K={K}):")
    for cid in range(K):
        rep = representatives[cid]
        n_unique = len(members[cid])
        n_segs = segment_counts.get(cid, 0)
        # Show up to three members other than the representative itself.
        examples = [m for m in members[cid] if m != rep][:3]
        print(f"\n [{cid:2d}] ({n_segs:4d} segs, {n_unique:3d} unique) \"{rep}\"")
        for ex in examples:
            print(f" - {ex}")

    # Save (JSON object keys must be strings, hence the str(k) conversions)
    output = {
        'num_classes': K,
        'task_to_cluster': task_to_cluster,
        'cluster_representatives': {str(k): v for k, v in representatives.items()},
        'cluster_sizes_unique': {str(k): len(v) for k, v in members.items()},
        'cluster_sizes_segments': {str(k): v for k, v in segment_counts.items()},
        # float() guards against NumPy scalar types json cannot serialize.
        'silhouette_scores': {str(k): float(v) for k, v in scores.items()},
    }
    out_path = os.path.join(args.output_dir, 'action_labels.json')
    # ensure_ascii=False emits raw non-ASCII characters, so the file must be
    # written as UTF-8 regardless of the platform's default encoding.
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2, ensure_ascii=False)
    print(f"\nSaved to {out_path}")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()