"""
Shared evaluation metrics for GAP-CLIP experiments.

Provides nearest-neighbor accuracy, separation score, centroid-based accuracy,
and confusion matrix generation — used across all evaluation sections.
"""

from __future__ import annotations

from collections import defaultdict
from itertools import combinations
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize


# --------------------------------------------------------------------------- #
# Private helpers
# --------------------------------------------------------------------------- #


def _leave_one_out_nearest_neighbors(
    similarities: np.ndarray,
    chunk_size: int = 1024,
) -> np.ndarray:
    """Return, for every row, the column index of its most similar *other* sample.

    The diagonal (self-similarity) is masked with ``-inf`` so a sample can never
    be picked as its own neighbor, even when every other similarity is -1 (or
    rounds slightly below it). Rows are processed in chunks so that only a
    ``chunk_size x N`` copy is held at a time; the caller's matrix is never
    modified.

    Args:
        similarities: Square (N, N) similarity matrix.
        chunk_size: Number of rows handled per iteration.

    Returns:
        Integer array of shape (N,) holding each row's nearest-neighbor index.
        For N == 1 the only candidate is the sample itself, so callers must
        guard against that case.
    """
    sims = np.asarray(similarities)
    n = sims.shape[0]
    nearest = np.empty(n, dtype=np.intp)
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        # Slicing yields a view; take a private float copy that can hold -inf.
        block = sims[start:stop].astype(np.float64, copy=True)
        local_rows = np.arange(stop - start)
        block[local_rows, local_rows + start] = -np.inf
        nearest[start:stop] = np.argmax(block, axis=1)
    return nearest


def _nearest_centroid_predict(
    embeddings: np.ndarray,
    labels: List[str],
    classes: List[str],
) -> list:
    """Assign each embedding to the class with the most cosine-similar centroid.

    Embeddings are L2-normalized, each class centroid is the mean of its
    normalized members, and centroids are L2-normalized again before
    comparison.

    Args:
        embeddings: Array of shape (N, D).
        labels: N labels used to build the centroids.
        classes: Ordered classes to build centroids for. On an exact tie the
            class listed first wins, so pass a deterministic order.

    Returns:
        List of N predicted classes (elements of ``classes``).
    """
    emb_norm = normalize(embeddings, norm="l2")
    label_array = np.asarray(labels)
    centroid_matrix = normalize(
        np.vstack([emb_norm[label_array == c].mean(axis=0) for c in classes]),
        norm="l2",
    )
    sims = cosine_similarity(emb_norm, centroid_matrix)
    return [classes[int(i)] for i in np.argmax(sims, axis=1)]


def _split_similarities_by_class(
    similarities: np.ndarray,
    labels: List[str],
) -> Tuple[List[float], List[float]]:
    """Collect intra-class and inter-class pairwise similarities.

    Each unordered pair of distinct samples is counted exactly once: intra-class
    pairs come from the strict upper triangle of each class block, inter-class
    pairs from every (class_a, class_b) block with class_a before class_b in
    ``np.unique`` order.

    Returns:
        (intra_class_similarities, inter_class_similarities) as flat lists.
    """
    label_array = np.asarray(labels)
    groups = [np.flatnonzero(label_array == c) for c in np.unique(label_array)]

    intra: List[float] = []
    for members in groups:
        if len(members) > 1:
            block = similarities[np.ix_(members, members)]
            intra.extend(block[np.triu_indices_from(block, k=1)].tolist())

    inter: List[float] = []
    for members_a, members_b in combinations(groups, 2):
        inter.extend(similarities[np.ix_(members_a, members_b)].ravel().tolist())

    return intra, inter


# --------------------------------------------------------------------------- #
# Public metrics
# --------------------------------------------------------------------------- #


def compute_similarity_metrics(
    embeddings: np.ndarray,
    labels: List[str],
    max_samples: int = 5000,
    random_state: Optional[int] = None,
) -> dict:
    """Compute intra/inter-class similarities and nearest-neighbor accuracy.

    The full (N, N) cosine-similarity matrix is computed once and shared by the
    intra/inter-class split and the nearest-neighbor accuracy.

    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels.
        max_samples: Cap for large datasets; larger inputs are randomly
            subsampled (without replacement) to this many rows.
        random_state: Optional seed for the subsampling. ``None`` (default)
            draws from NumPy's global RNG, as before.

    Returns:
        Dict with keys: intra_class_mean, inter_class_mean, separation_score
        (intra mean minus inter mean; 0.0 if either set is empty), accuracy
        (leave-one-out NN), centroid_accuracy, intra_class_similarities,
        inter_class_similarities. Empty input yields zeros and empty lists.
    """
    embeddings = np.asarray(embeddings)
    labels = list(labels)

    if len(embeddings) > max_samples:
        rng = np.random if random_state is None else np.random.RandomState(random_state)
        keep = rng.choice(len(embeddings), max_samples, replace=False)
        embeddings = embeddings[keep]
        labels = [labels[i] for i in keep]

    if len(embeddings) == 0:
        # cosine_similarity rejects empty input; report neutral values instead.
        intra: List[float] = []
        inter: List[float] = []
        nn_acc = centroid_acc = 0.0
    else:
        similarities = cosine_similarity(embeddings)
        intra, inter = _split_similarities_by_class(similarities, labels)
        nn_acc = compute_embedding_accuracy(embeddings, labels, similarities)
        centroid_acc = compute_centroid_accuracy(embeddings, labels)

    intra_mean = float(np.mean(intra)) if intra else 0.0
    inter_mean = float(np.mean(inter)) if inter else 0.0

    return {
        "intra_class_similarities": intra,
        "inter_class_similarities": inter,
        "intra_class_mean": intra_mean,
        "inter_class_mean": inter_mean,
        "separation_score": intra_mean - inter_mean if intra and inter else 0.0,
        "accuracy": nn_acc,
        "centroid_accuracy": centroid_acc,
    }


def compute_embedding_accuracy(
    embeddings: np.ndarray,
    labels: List[str],
    similarities: Optional[np.ndarray] = None,
) -> float:
    """Nearest-neighbor classification accuracy (leave-one-out).

    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels.
        similarities: Pre-computed cosine similarity matrix (N, N). Computed
            if not provided. It is not modified.

    Returns:
        Fraction of samples whose nearest *other* sample shares their label.
        Returns 0.0 when fewer than two samples are given, since leave-one-out
        has no neighbor to compare against.
    """
    n = len(embeddings)
    if n < 2:
        return 0.0
    if similarities is None:
        similarities = cosine_similarity(embeddings)

    label_array = np.asarray(labels)
    nearest = _leave_one_out_nearest_neighbors(similarities)
    correct = int(np.count_nonzero(label_array[nearest] == label_array))
    return correct / n


def compute_centroid_accuracy(
    embeddings: np.ndarray,
    labels: List[str],
) -> float:
    """Centroid-based (1-NN centroid) classification accuracy.

    Uses L2-normalized embeddings and centroids for a proper cosine comparison.
    Centroids are built from the same samples that are then classified.

    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels.

    Returns:
        Fraction of samples classified correctly by nearest centroid
        (0.0 for empty input).
    """
    if len(embeddings) == 0:
        return 0.0
    classes = sorted(set(labels))
    predicted = _nearest_centroid_predict(embeddings, labels, classes)
    return sum(p == t for p, t in zip(predicted, labels)) / len(labels)


def predict_labels_from_embeddings(
    embeddings: np.ndarray,
    labels: List[str],
) -> List[Optional[str]]:
    """Predict a label for each embedding using nearest centroid.

    ``None`` labels are ignored when building centroids. Classes are visited in
    sorted order so tie-breaking is deterministic across runs (iterating a
    ``set`` of strings depends on hash randomization).

    Returns:
        List of predicted labels (same length as embeddings), or a list of
        ``None`` when no non-``None`` label exists.
    """
    classes = sorted({label for label in labels if label is not None})
    if not classes:
        return [None] * len(embeddings)
    return _nearest_centroid_predict(embeddings, labels, classes)


def compute_worst_group_accuracy(
    embeddings: np.ndarray,
    labels: List[str],
    groups: List[str],
) -> Tuple[float, str, dict]:
    """Compute per-group nearest-neighbor accuracy and return the worst.

    This is the key metric for DFR (Kirichenko et al. 2023): it measures
    robustness to spurious correlations by reporting the accuracy of the
    worst-performing (color x hierarchy) group.

    Each sample's leave-one-out nearest neighbor is searched over the whole
    set (not only its own group); accuracy is then averaged per group. Groups
    with fewer than two members are excluded.

    Args:
        embeddings: Array of shape (N, D).
        labels: List of N class labels (what we classify).
        groups: List of N group labels (e.g. 'red_dress', 'blue_jeans').

    Returns:
        (worst_accuracy, worst_group_name, per_group_dict) where per_group_dict
        maps group -> accuracy. ``(0.0, "", {})`` when no group qualifies or
        the input is empty.
    """
    if len(embeddings) == 0:
        return 0.0, "", {}

    similarities = cosine_similarity(embeddings)
    label_array = np.asarray(labels)
    group_array = np.asarray(groups)

    nearest = _leave_one_out_nearest_neighbors(similarities)
    correct = label_array[nearest] == label_array

    per_group: dict = {}
    for group in np.unique(group_array):
        members = np.flatnonzero(group_array == group)
        if len(members) < 2:
            continue
        per_group[group] = int(np.count_nonzero(correct[members])) / len(members)

    if not per_group:
        return 0.0, "", {}

    worst_group = min(per_group, key=per_group.get)
    return per_group[worst_group], worst_group, per_group


def create_confusion_matrix(
    true_labels: List[str],
    predicted_labels: List[str],
    title: str = "Confusion Matrix",
    label_type: str = "Label",
) -> Tuple[plt.Figure, float, np.ndarray]:
    """Create and return a seaborn confusion-matrix heatmap figure.

    Args:
        true_labels: Ground-truth labels (any sequence, including NumPy arrays).
        predicted_labels: Predicted labels.
        title: Plot title prefix.
        label_type: Axis label (e.g. "Color", "Category").

    Returns:
        (fig, accuracy, cm_array). Rows/columns of ``cm_array`` follow the
        sorted union of true and predicted labels.
    """
    # Set union instead of ``a + b`` so array/tuple inputs don't get
    # element-wise added or raise.
    unique_labels = sorted(set(true_labels) | set(predicted_labels))
    cm = confusion_matrix(true_labels, predicted_labels, labels=unique_labels)
    acc = accuracy_score(true_labels, predicted_labels)

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=unique_labels,
        yticklabels=unique_labels,
        ax=ax,
    )
    ax.set_title(f"{title}\nAccuracy: {acc:.3f} ({acc * 100:.1f}%)")
    ax.set_ylabel(f"True {label_type}")
    ax.set_xlabel(f"Predicted {label_type}")
    plt.setp(ax.get_xticklabels(), rotation=45)
    plt.setp(ax.get_yticklabels(), rotation=0)
    fig.tight_layout()
    return fig, acc, cm