| """ |
| Shared evaluation metrics for GAP-CLIP experiments. |
| |
| Provides nearest-neighbor accuracy, separation score, centroid-based accuracy, |
| and confusion matrix generation — used across all evaluation sections. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from collections import defaultdict |
| from typing import List, Optional, Tuple |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import seaborn as sns |
| from sklearn.metrics import accuracy_score, classification_report, confusion_matrix |
| from sklearn.metrics.pairwise import cosine_similarity |
| from sklearn.preprocessing import normalize |
|
|
|
|
def compute_similarity_metrics(
    embeddings: np.ndarray,
    labels: List[str],
    max_samples: int = 5000,
) -> dict:
    """Summarize how well embeddings cluster by class.

    Gathers pairwise cosine similarities within and across classes, then
    derives separation and classification-style accuracy metrics.

    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels.
        max_samples: Random-subsample cap applied to large datasets.

    Returns:
        Dict with keys: intra_class_mean, inter_class_mean, separation_score,
        accuracy (NN), centroid_accuracy, intra_class_similarities,
        inter_class_similarities.
    """
    if len(embeddings) > max_samples:
        keep = np.random.choice(len(embeddings), max_samples, replace=False)
        embeddings = embeddings[keep]
        labels = [labels[k] for k in keep]

    sim_matrix = cosine_similarity(embeddings)

    label_arr = np.array(labels)
    groups = {lbl: np.where(label_arr == lbl)[0] for lbl in np.unique(label_arr)}

    # Upper-triangle similarities within each class (singleton classes have
    # no intra-class pairs and are skipped).
    intra: List[float] = []
    for members in groups.values():
        if len(members) < 2:
            continue
        block = sim_matrix[np.ix_(members, members)]
        intra.extend(block[np.triu_indices_from(block, k=1)].tolist())

    # All cross-class similarities, one rectangular block per unordered
    # class pair.
    inter: List[float] = []
    ordered = list(groups)
    for pos, lbl_a in enumerate(ordered):
        for lbl_b in ordered[pos + 1:]:
            block = sim_matrix[np.ix_(groups[lbl_a], groups[lbl_b])]
            inter.extend(block.ravel().tolist())

    nn_acc = compute_embedding_accuracy(embeddings, labels, sim_matrix)
    centroid_acc = compute_centroid_accuracy(embeddings, labels)

    intra_mean = float(np.mean(intra)) if intra else 0.0
    inter_mean = float(np.mean(inter)) if inter else 0.0
    return {
        "intra_class_similarities": intra,
        "inter_class_similarities": inter,
        "intra_class_mean": intra_mean,
        "inter_class_mean": inter_mean,
        "separation_score": intra_mean - inter_mean if intra and inter else 0.0,
        "accuracy": nn_acc,
        "centroid_accuracy": centroid_acc,
    }
|
|
|
|
def compute_embedding_accuracy(
    embeddings: np.ndarray,
    labels: List[str],
    similarities: Optional[np.ndarray] = None,
) -> float:
    """Nearest-neighbor classification accuracy (leave-one-out).

    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels.
        similarities: Pre-computed cosine similarity matrix (N, N). Computed
            if not provided.

    Returns:
        Fraction of samples whose nearest neighbor shares their label.
    """
    n = len(embeddings)
    if n == 0:
        return 0.0
    if similarities is None:
        similarities = cosine_similarity(embeddings)

    # Work on a copy so a caller-supplied matrix is never mutated.
    sims = np.array(similarities, dtype=float, copy=True)
    # -inf (not -1.0) guarantees a sample can never be its own nearest
    # neighbor, even in the degenerate case where every similarity is -1.
    np.fill_diagonal(sims, -np.inf)

    # Vectorized argmax over rows replaces the former O(n^2) Python loop.
    label_arr = np.asarray(labels)
    nn_labels = label_arr[np.argmax(sims, axis=1)]
    return float(np.mean(nn_labels == label_arr))
|
|
|
|
def compute_centroid_accuracy(
    embeddings: np.ndarray,
    labels: List[str],
) -> float:
    """Centroid-based (1-NN centroid) classification accuracy.

    Embeddings and centroids are L2-normalized, so the plain dot product
    equals cosine similarity — no sklearn round-trip needed.

    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels.

    Returns:
        Fraction of samples classified correctly by nearest centroid.
    """
    if len(embeddings) == 0:
        return 0.0

    emb = np.asarray(embeddings, dtype=float)
    # Guard zero-norm rows: leave them as zero vectors (matches sklearn's
    # normalize behavior) instead of dividing by zero.
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    emb_norm = emb / np.where(norms == 0, 1.0, norms)

    label_arr = np.asarray(labels)
    unique_labels = sorted(set(labels))
    # One normalized centroid per class, rows ordered by sorted label.
    centroids = np.vstack(
        [emb_norm[label_arr == lbl].mean(axis=0) for lbl in unique_labels]
    )
    c_norms = np.linalg.norm(centroids, axis=1, keepdims=True)
    centroids = centroids / np.where(c_norms == 0, 1.0, c_norms)

    # Cosine similarity of unit vectors is just the dot product.
    sims = emb_norm @ centroids.T
    predicted = np.asarray(unique_labels)[np.argmax(sims, axis=1)]
    return float(np.mean(predicted == label_arr))
|
|
|
|
def predict_labels_from_embeddings(
    embeddings: np.ndarray,
    labels: List[str],
) -> List[str]:
    """Predict a label for each embedding using nearest centroid.

    Args:
        embeddings: Array of shape (N, D).
        labels: List of N labels used to build centroids; None entries are
            ignored.

    Returns:
        List of predicted labels (same length as embeddings), or a list of
        None when no valid labels exist.
    """
    # Sort for a deterministic centroid order: raw set iteration order
    # varies across interpreter runs (hash randomization), which made
    # argmax tie-breaking nondeterministic.
    valid_labels = sorted(l for l in set(labels) if l is not None)
    if not valid_labels:
        return [None] * len(embeddings)

    emb = np.asarray(embeddings, dtype=float)
    # Zero-norm rows stay zero vectors instead of triggering a divide-by-zero.
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    emb_norm = emb / np.where(norms == 0, 1.0, norms)

    # Build the label array once (was rebuilt inside the per-label loop);
    # every valid label has at least one member by construction.
    label_arr = np.array(labels, dtype=object)
    centroids = np.vstack(
        [emb_norm[label_arr == lbl].mean(axis=0) for lbl in valid_labels]
    )
    c_norms = np.linalg.norm(centroids, axis=1, keepdims=True)
    centroids = centroids / np.where(c_norms == 0, 1.0, c_norms)

    # Cosine similarity is scale-invariant, so normalizing centroids and
    # taking dot products reproduces the previous sklearn result.
    sims = emb_norm @ centroids.T
    return [valid_labels[int(i)] for i in np.argmax(sims, axis=1)]
|
|
|
|
def compute_worst_group_accuracy(
    embeddings: np.ndarray,
    labels: List[str],
    groups: List[str],
) -> Tuple[float, str, dict]:
    """Compute per-group nearest-neighbor accuracy and return the worst.

    This is the key metric for DFR (Kirichenko et al. 2023): it measures
    robustness to spurious correlations by reporting the accuracy of the
    worst-performing (color x hierarchy) group.

    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels (what we classify).
        groups: List of N group labels (e.g. 'red_dress', 'blue_jeans').

    Returns:
        (worst_accuracy, worst_group_name, per_group_dict)
        where per_group_dict maps group -> accuracy.
    """
    if len(embeddings) == 0:
        return 0.0, "", {}

    # Normalized dot product == cosine similarity; zero-norm rows are left
    # as zero vectors rather than dividing by zero.
    emb = np.asarray(embeddings, dtype=float)
    norms = np.linalg.norm(emb, axis=1, keepdims=True)
    emb_norm = emb / np.where(norms == 0, 1.0, norms)
    similarities = emb_norm @ emb_norm.T
    # -inf (not -1.0) guarantees a sample is never its own nearest neighbor.
    np.fill_diagonal(similarities, -np.inf)

    # Leave-one-out NN correctness, computed once for all samples instead of
    # the former per-sample Python loop. Neighbors come from the WHOLE
    # dataset, not just the sample's own group (matches original behavior).
    label_array = np.array(labels)
    nn_correct = label_array[np.argmax(similarities, axis=1)] == label_array

    group_array = np.array(groups)
    per_group: dict = {}
    for g in np.unique(group_array):
        idxs = np.where(group_array == g)[0]
        if len(idxs) < 2:
            continue  # singleton groups are skipped (too small to score)
        per_group[g] = float(np.mean(nn_correct[idxs]))

    if not per_group:
        return 0.0, "", {}

    worst_group = min(per_group, key=per_group.get)
    return per_group[worst_group], worst_group, per_group
|
|
|
|
def create_confusion_matrix(
    true_labels: List[str],
    predicted_labels: List[str],
    title: str = "Confusion Matrix",
    label_type: str = "Label",
) -> Tuple[plt.Figure, float, np.ndarray]:
    """Render a confusion matrix as a seaborn heatmap figure.

    Args:
        true_labels: Ground-truth labels.
        predicted_labels: Predicted labels.
        title: Plot title prefix.
        label_type: Axis label (e.g. "Color", "Category").

    Returns:
        (fig, accuracy, cm_array)
    """
    # Union of both label sets so every class appears on both axes, even if
    # it occurs only in predictions (or only in ground truth).
    axis_labels = sorted(set(true_labels) | set(predicted_labels))
    cm = confusion_matrix(true_labels, predicted_labels, labels=axis_labels)
    acc = accuracy_score(true_labels, predicted_labels)

    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=axis_labels,
        yticklabels=axis_labels,
    )
    plt.title(f"{title}\nAccuracy: {acc:.3f} ({acc * 100:.1f}%)")
    plt.ylabel(f"True {label_type}")
    plt.xlabel(f"Predicted {label_type}")
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    return fig, acc, cm
|
|