# Source: gap-clip — evaluation/utils/metrics.py
# Uploaded via huggingface_hub by Leacb4 (commit b3bf2b7, verified).
"""
Shared evaluation metrics for GAP-CLIP experiments.
Provides nearest-neighbor accuracy, separation score, centroid-based accuracy,
and confusion matrix generation — used across all evaluation sections.
"""
from __future__ import annotations
from collections import defaultdict
from typing import List, Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
def compute_similarity_metrics(
    embeddings: np.ndarray,
    labels: List[str],
    max_samples: int = 5000,
    seed: Optional[int] = None,
) -> dict:
    """Compute intra/inter-class similarities and nearest-neighbor accuracy.
    Uses vectorized numpy operations for efficiency.
    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels.
        max_samples: Cap for large datasets (random subsample).
        seed: Optional seed for reproducible subsampling. None preserves
            the previous behavior (draws from numpy's global RNG state).
    Returns:
        Dict with keys: intra_class_mean, inter_class_mean, separation_score,
        accuracy (NN), centroid_accuracy, intra_class_similarities,
        inter_class_similarities.
    """
    if len(embeddings) > max_samples:
        # Subsample once so every downstream metric sees the same subset.
        if seed is None:
            sample_idx = np.random.choice(len(embeddings), max_samples, replace=False)
        else:
            sample_idx = np.random.default_rng(seed).choice(
                len(embeddings), max_samples, replace=False
            )
        embeddings = embeddings[sample_idx]
        labels = [labels[i] for i in sample_idx]
    similarities = cosine_similarity(embeddings)
    label_array = np.array(labels)
    unique_labels = np.unique(label_array)
    label_groups = {label: np.where(label_array == label)[0] for label in unique_labels}
    # Intra-class: upper triangle (k=1) of each class's similarity sub-matrix,
    # so each unordered pair is counted once and self-similarity is excluded.
    intra_class_similarities: List[float] = []
    for group_idx in label_groups.values():
        if len(group_idx) > 1:
            sub = similarities[np.ix_(group_idx, group_idx)]
            triu = np.triu_indices_from(sub, k=1)
            intra_class_similarities.extend(sub[triu].tolist())
    # Inter-class: every cross pair between each unordered pair of classes.
    inter_class_similarities: List[float] = []
    keys = list(label_groups.keys())
    for i in range(len(keys)):
        for j in range(i + 1, len(keys)):
            inter = similarities[np.ix_(label_groups[keys[i]], label_groups[keys[j]])]
            inter_class_similarities.extend(inter.flatten().tolist())
    nn_acc = compute_embedding_accuracy(embeddings, labels, similarities)
    centroid_acc = compute_centroid_accuracy(embeddings, labels)
    # Compute each mean exactly once; empty lists (single class, or one
    # sample per class) fall back to 0.0 instead of a nan-mean warning.
    intra_mean = float(np.mean(intra_class_similarities)) if intra_class_similarities else 0.0
    inter_mean = float(np.mean(inter_class_similarities)) if inter_class_similarities else 0.0
    return {
        "intra_class_similarities": intra_class_similarities,
        "inter_class_similarities": inter_class_similarities,
        "intra_class_mean": intra_mean,
        "inter_class_mean": inter_mean,
        "separation_score": (
            intra_mean - inter_mean
            if intra_class_similarities and inter_class_similarities
            else 0.0
        ),
        "accuracy": nn_acc,
        "centroid_accuracy": centroid_acc,
    }
def compute_embedding_accuracy(
    embeddings: np.ndarray,
    labels: List[str],
    similarities: Optional[np.ndarray] = None,
) -> float:
    """Nearest-neighbor classification accuracy (leave-one-out).
    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels.
        similarities: Pre-computed cosine similarity matrix (N, N). Computed
            if not provided.
    Returns:
        Fraction of samples whose nearest neighbor shares their label.
    """
    n = len(embeddings)
    if n == 0:
        return 0.0
    if similarities is None:
        similarities = cosine_similarity(embeddings)
    # Copy so the caller's matrix is untouched, then mask the diagonal with
    # -inf: unlike the previous -1.0 sentinel, -inf can never tie with a
    # genuine cosine similarity (which is bounded below by -1.0), so a
    # sample can never be its own "nearest neighbor".
    sims = np.array(similarities, dtype=float)
    np.fill_diagonal(sims, -np.inf)
    label_arr = np.asarray(labels)
    nearest = sims.argmax(axis=1)  # leave-one-out NN index per row
    return float(np.mean(label_arr[nearest] == label_arr))
def compute_centroid_accuracy(
    embeddings: np.ndarray,
    labels: List[str],
) -> float:
    """Centroid-based (1-NN centroid) classification accuracy.
    Embeddings are L2-normalized before averaging and the centroids are
    re-normalized, so cosine comparisons against centroids are well-defined.
    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels.
    Returns:
        Fraction of samples classified correctly by nearest centroid.
    """
    if len(embeddings) == 0:
        return 0.0
    unit = normalize(embeddings, norm="l2")
    label_arr = np.asarray(labels)
    classes = sorted(set(labels))
    # One mean row per class, in sorted-class order, normalized in a batch.
    centroid_matrix = normalize(
        np.vstack([unit[label_arr == c].mean(axis=0) for c in classes]),
        norm="l2",
    )
    sims = cosine_similarity(unit, centroid_matrix)
    winners = sims.argmax(axis=1)
    hits = sum(classes[w] == truth for w, truth in zip(winners, labels))
    return hits / len(labels)
def predict_labels_from_embeddings(
    embeddings: np.ndarray,
    labels: List[str],
) -> List[str]:
    """Predict a label for each embedding using nearest centroid.
    Args:
        embeddings: Array of shape (N, D).
        labels: List of N reference labels (None entries are ignored when
            building centroids).
    Returns:
        List of predicted labels (same length as embeddings); all None when
        no valid labels exist.
    """
    # Sort so centroid order — and therefore argmax tie-breaking — is
    # deterministic across runs; raw set iteration order is not (string
    # hash randomization).
    valid_labels = sorted(l for l in set(labels) if l is not None)
    if not valid_labels:
        return [None] * len(embeddings)
    emb_norm = normalize(embeddings, norm="l2")
    label_array = np.array(labels)  # hoisted: built once, not once per class
    centroids = {}
    for label in valid_labels:
        mask = label_array == label
        if np.any(mask):
            centroids[label] = np.mean(emb_norm[mask], axis=0)
    centroid_labels = list(centroids.keys())
    centroid_matrix = np.vstack([centroids[l] for l in centroid_labels])
    sims = cosine_similarity(emb_norm, centroid_matrix)
    return [centroid_labels[int(np.argmax(row))] for row in sims]
def compute_worst_group_accuracy(
    embeddings: np.ndarray,
    labels: List[str],
    groups: List[str],
) -> Tuple[float, str, dict]:
    """Per-group nearest-neighbor accuracy; report the weakest group.
    This is the key metric for DFR (Kirichenko et al. 2023): it measures
    robustness to spurious correlations by reporting the accuracy of the
    worst-performing (color x hierarchy) group.
    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels (what we classify).
        groups: List of N group labels (e.g. 'red_dress', 'blue_jeans').
    Returns:
        (worst_accuracy, worst_group_name, per_group_dict)
        where per_group_dict maps group -> accuracy.
    """
    sims = cosine_similarity(embeddings)
    label_arr = np.asarray(labels)
    group_arr = np.asarray(groups)
    per_group: dict = {}
    for name in np.unique(group_arr):
        members = np.flatnonzero(group_arr == name)
        if len(members) < 2:
            continue  # a singleton group gives no meaningful NN score
        hits = 0
        for i in members:
            # NN is searched over ALL samples, not just the group; only
            # the self-match is excluded.
            row = sims[i].copy()
            row[i] = -1.0
            hits += int(label_arr[np.argmax(row)] == label_arr[i])
        per_group[name] = hits / len(members)
    if not per_group:
        return 0.0, "", {}
    worst = min(per_group, key=per_group.get)
    return per_group[worst], worst, per_group
def create_confusion_matrix(
    true_labels: List[str],
    predicted_labels: List[str],
    title: str = "Confusion Matrix",
    label_type: str = "Label",
) -> Tuple[plt.Figure, float, np.ndarray]:
    """Create and return a seaborn confusion-matrix heatmap figure.
    Args:
        true_labels: Ground-truth labels.
        predicted_labels: Predicted labels.
        title: Plot title prefix.
        label_type: Axis label (e.g. "Color", "Category").
    Returns:
        (fig, accuracy, cm_array)
    """
    # Union of both label sets so the matrix is square even when a class
    # appears only on one side.
    classes = sorted(set(true_labels + predicted_labels))
    cm = confusion_matrix(true_labels, predicted_labels, labels=classes)
    acc = accuracy_score(true_labels, predicted_labels)
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=classes,
        yticklabels=classes,
        ax=ax,
    )
    ax.set_title(f"{title}\nAccuracy: {acc:.3f} ({acc * 100:.1f}%)")
    ax.set_ylabel(f"True {label_type}")
    ax.set_xlabel(f"Predicted {label_type}")
    for tick in ax.get_xticklabels():
        tick.set_rotation(45)
    for tick in ax.get_yticklabels():
        tick.set_rotation(0)
    fig.tight_layout()
    return fig, acc, cm