#!/usr/bin/env python3
"""
GEOLIP HYPERSPHERE MANIFOLD VISUALIZATION
==========================================
6-panel manifold view + 3-panel expert perspective divergence.
S^255 projected to S^2 via PCA.
"""

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  kept for older matplotlib versions that need it for projection='3d'
import math

DEVICE = "cpu"

# ══════════════════════════════════════════════════════════════════
# LOAD + EMBED
# ══════════════════════════════════════════════════════════════════

print("Loading soup...")
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary
# objects from the file. Only run this on checkpoints from a trusted source.
ckpt = torch.load("checkpoints/dual_stream_best.pt", map_location="cpu", weights_only=False)
sd = ckpt["state_dict"]
D_ANCHOR = ckpt["config"]["d_anchor"]
N_ANCHORS = ckpt["config"]["n_anchors"]
# Row-normalised anchors, so `unit_x @ anchors.T` below gives cosines.
anchors = F.normalize(sd["constellation.anchors"], dim=-1)

EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]
COCO_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
    "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra",
    "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove",
    "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
    "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch",
    "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink",
    "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
    "hair drier", "toothbrush",
]

print("Loading features...")
from datasets import load_dataset

ref = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split="val")
val_ids = ref["image_id"]
N_val = len(val_ids)
# image_id -> row index in the reference (first expert) ordering.
val_id_map = {iid: i for i, iid in enumerate(val_ids)}

# Multi-hot label matrix (N_val, 80); label ids >= 80 are dropped.
val_labels = torch.zeros(N_val, 80)
for i, labs in enumerate(ref["labels"]):
    for lab in labs:
        if lab < 80:
            val_labels[i, lab] = 1.0

# Raw 768-d features per expert, re-ordered onto the reference row order.
# Reference images missing from an expert's split stay all-zero.
val_raw = {}
for name in EXPERTS:
    ds = load_dataset("AbstractPhil/bulk-coco-features", name, split="val")
    feats = torch.zeros(N_val, 768)
    for row in ds:
        row_idx = val_id_map.get(row["image_id"])
        if row_idx is not None:
            feats[row_idx] = torch.tensor(row["features"], dtype=torch.float32)
    val_raw[name] = feats
    del ds


def project_expert(feats, i):
    """Project raw features of expert ``i`` through its stored projector.

    Uses the state-dict prefix ``projectors.{i}.proj_shared`` when its
    ``.0.weight`` key exists, else ``projectors.{i}.proj``. The projector is a
    linear layer (``.0``) followed by a hand-rolled LayerNorm (``.1``; biased
    variance, eps=1e-5), and the result is L2-normalised per row.

    Args:
        feats: raw expert features, one row per image.
        i: expert index used to build the state-dict key prefix.

    Returns:
        Row-wise unit-norm projected features.
    """
    if f"projectors.{i}.proj_shared.0.weight" in sd:
        prefix = f"projectors.{i}.proj_shared"
    else:
        prefix = f"projectors.{i}.proj"
    W = sd[f"{prefix}.0.weight"]
    b = sd[f"{prefix}.0.bias"]
    lw = sd[f"{prefix}.1.weight"]
    lb = sd[f"{prefix}.1.bias"]
    x = feats @ W.T + b
    mu = x.mean(-1, keepdim=True)
    var = x.var(-1, keepdim=True, unbiased=False)
    x = (x - mu) / (var + 1e-5).sqrt() * lw + lb
    return F.normalize(x, dim=-1)


print("Generating embeddings...")
with torch.no_grad():
    projected = [project_expert(val_raw[name], i) for i, name in enumerate(EXPERTS)]
    # Average of the per-expert unit vectors, pushed back onto the sphere.
    fused = F.normalize(sum(projected) / len(projected), dim=-1)

# ══════════════════════════════════════════════════════════════════
# PCA → 3D
# ══════════════════════════════════════════════════════════════════

emb = fused.numpy()
emb_centered = emb - emb.mean(axis=0, keepdims=True)
# Principal axes are fit on (at most) the first 5000 centred embeddings.
U, S, Vt = np.linalg.svd(emb_centered[:5000], full_matrices=False)
pca3 = Vt[:3]
# Points are projected *uncentred* so the origin of the hypersphere stays at
# the origin of the 3-D view; to_sphere() below then pushes them to radius 1.
emb_3d = emb @ pca3.T
anchors_3d = anchors.numpy() @ pca3.T
var_explained = S[:3]**2 / (S**2).sum()
# Effective dimensionality as a participation ratio of the singular values:
# (sum s)^2 / sum s^2. Computed from the data instead of a hard-coded number.
eff_dim = float(((S / S.sum())**2).sum()**-1) if S.sum() > 1e-20 else 0.0
print(f"PCA 3D variance: {var_explained.sum()*100:.1f}% "
      f"({var_explained[0]*100:.1f}%, {var_explained[1]*100:.1f}%, {var_explained[2]*100:.1f}%)")


def to_sphere(pts):
    """Scale each row of ``pts`` to unit L2 norm (1e-8 guards zero rows)."""
    norms = np.linalg.norm(pts, axis=-1, keepdims=True)
    return pts / (norms + 1e-8)


emb_s = to_sphere(emb_3d)
anchors_s = to_sphere(anchors_3d)

# Reference sphere wireframe
phi = np.linspace(0, 2*np.pi, 60)
theta = np.linspace(0, np.pi, 30)
xs = np.outer(np.cos(phi), np.sin(theta))
ys = np.outer(np.sin(phi), np.sin(theta))
zs = np.outer(np.ones_like(phi), np.cos(theta))

# Primary class per image: the rarest label present (most specific). Absent
# labels are masked to +inf, so argmin picks the lowest-frequency present
# class (first index on ties). Unlabelled images fall back to class 0.
class_freq = val_labels.sum(0).numpy()
label_mask = val_labels.numpy() > 0
masked_freq = np.where(label_mask, class_freq[None, :], np.inf)
primary_class = masked_freq.argmin(axis=1).astype(int)

cmap20 = plt.cm.tab20(np.linspace(0, 1, 20))
class_colors = cmap20[primary_class % 20]

# ══════════════════════════════════════════════════════════════════
# HELPER
# ══════════════════════════════════════════════════════════════════


def setup_ax(ax, title):
    """Apply the shared dark 3-D styling, title, unit-sphere wireframe and limits."""
    ax.set_facecolor('black')
    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    ax.zaxis.pane.fill = False
    ax.xaxis.pane.set_edgecolor('gray')
    ax.yaxis.pane.set_edgecolor('gray')
    ax.zaxis.pane.set_edgecolor('gray')
    ax.set_xlabel('PC1', color='gray', fontsize=8)
    ax.set_ylabel('PC2', color='gray', fontsize=8)
    ax.set_zlabel('PC3', color='gray', fontsize=8)
    ax.tick_params(colors='gray', labelsize=6)
    ax.set_title(title, color='white', fontsize=11, pad=10)
    ax.plot_wireframe(xs*0.98, ys*0.98, zs*0.98, alpha=0.03, color='white', linewidth=0.3)
    ax.set_xlim(-1.3, 1.3)
    ax.set_ylim(-1.3, 1.3)
    ax.set_zlim(-1.3, 1.3)


# ══════════════════════════════════════════════════════════════════
# FIGURE 1: 6-PANEL MANIFOLD VIEW
# ══════════════════════════════════════════════════════════════════

print("Rendering figure 1...")
fig = plt.figure(figsize=(24, 16), facecolor='black')
fig.suptitle(
    'GeoLIP Hypersphere Manifold — S²⁵⁵ projected to S²\n'
    f'{N_ANCHORS} anchors × {D_ANCHOR}-d × {len(EXPERTS)} experts | '
    f'mAP={ckpt["mAP"]:.3f} | eff_dim={eff_dim:.1f}',
    color='white', fontsize=16, y=0.98)

# Panel 1: Full manifold
ax1 = fig.add_subplot(231, projection='3d')
setup_ax(ax1, f'Full Manifold — {N_val} embeddings + {N_ANCHORS} anchors')
ax1.scatter(emb_s[:, 0], emb_s[:, 1], emb_s[:, 2], c=class_colors, s=1, alpha=0.3)
ax1.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
            c='red', s=8, alpha=0.6, marker='^')

# Panel 2: Class centroids
ax2 = fig.add_subplot(232, projection='3d')
setup_ax(ax2, '80 COCO Class Centroids')
centroids = np.zeros((80, emb.shape[1]))
for c in range(80):
    mask = val_labels[:, c].numpy() > 0
    if mask.sum() > 0:
        centroids[c] = emb[mask].mean(0)
centroids_3d = to_sphere(centroids @ pca3.T)
sizes = val_labels.sum(0).numpy()
sizes_scaled = 20 + 200 * (sizes / sizes.max())
colors80 = plt.cm.hsv(np.linspace(0, 0.95, 80))
# Classes with no val images have a zero centroid that would sit at the
# origin; leave them out of the scatter.
has_class = sizes > 0
ax2.scatter(centroids_3d[has_class, 0], centroids_3d[has_class, 1], centroids_3d[has_class, 2],
            c=colors80[has_class], s=sizes_scaled[has_class], alpha=0.8,
            edgecolors='white', linewidth=0.3)
for c in [0, 2, 14, 15, 16, 22, 23, 56, 62]:
    if sizes[c] > 30:
        ax2.text(centroids_3d[c, 0]*1.15, centroids_3d[c, 1]*1.15, centroids_3d[c, 2]*1.15,
                 COCO_CLASSES[c], color='white', fontsize=7, ha='center')

# Panel 3: 50 random with anchor connections
ax3 = fig.add_subplot(233, projection='3d')
setup_ax(ax3, '50 Random — nearest anchor connections')
np.random.seed(42)
idx50 = np.random.choice(N_val, 50, replace=False)
emb_50 = emb_s[idx50]
colors_50 = class_colors[idx50]
with torch.no_grad():
    cos_50 = fused[idx50] @ anchors.T
    nearest_50 = cos_50.argmax(-1).numpy()
ax3.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
            c='red', s=4, alpha=0.2, marker='^')
ax3.scatter(emb_50[:, 0], emb_50[:, 1], emb_50[:, 2],
            c=colors_50, s=40, alpha=0.9, edgecolors='white', linewidth=0.5)
for i in range(50):
    a = nearest_50[i]
    ax3.plot([emb_50[i, 0], anchors_s[a, 0]],
             [emb_50[i, 1], anchors_s[a, 1]],
             [emb_50[i, 2], anchors_s[a, 2]],
             color='yellow', alpha=0.3, linewidth=0.5)

# Panel 4: 10 random — triangulation heatmap
ax4 = fig.add_subplot(234, projection='3d')
setup_ax(ax4, '10 Random — anchor affinity heatmap')
idx10 = np.random.choice(N_val, 10, replace=False)
emb_10 = emb_s[idx10]
with torch.no_grad():
    cos_10 = (fused[idx10] @ anchors.T).numpy()
mean_cos = cos_10.mean(0)
anchor_heat = (mean_cos - mean_cos.min()) / (mean_cos.max() - mean_cos.min() + 1e-8)
anchor_colors = plt.cm.hot(anchor_heat)
ax4.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
            c=anchor_colors, s=10, alpha=0.6)
ax4.scatter(emb_10[:, 0], emb_10[:, 1], emb_10[:, 2],
            c='cyan', s=80, alpha=1.0, edgecolors='white', linewidth=1, zorder=10)

# Panel 5: Single encoding
ax5 = fig.add_subplot(235, projection='3d')
single_idx = 42
single_class = primary_class[single_idx]
setup_ax(ax5, f'Single Encoding: "{COCO_CLASSES[single_class]}" — top 5 anchors')
with torch.no_grad():
    cos_single = (fused[single_idx] @ anchors.T).numpy()
single_heat = (cos_single - cos_single.min()) / (cos_single.max() - cos_single.min() + 1e-8)
single_colors = plt.cm.plasma(single_heat)
single_sizes = 2 + 50 * single_heat**3
ax5.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
            c=single_colors, s=single_sizes, alpha=0.7)
single_pt = emb_s[single_idx]
ax5.scatter([single_pt[0]], [single_pt[1]], [single_pt[2]],
            c='lime', s=150, alpha=1.0, edgecolors='white', linewidth=2, zorder=10, marker='*')
top5 = np.argsort(cos_single)[::-1][:5]
for a in top5:
    ax5.plot([single_pt[0], anchors_s[a, 0]],
             [single_pt[1], anchors_s[a, 1]],
             [single_pt[2], anchors_s[a, 2]],
             color='lime', alpha=0.6, linewidth=1.5)

# Panel 6: Radial deviation
ax6 = fig.add_subplot(236, projection='3d')
radii = np.linalg.norm(emb_3d, axis=-1)
setup_ax(ax6, f'PCA Projection Radii — mean={radii.mean():.4f} std={radii.std():.4f}')
radius_dev = radii - radii.mean()
dev_norm = (radius_dev - radius_dev.min()) / (radius_dev.max() - radius_dev.min() + 1e-8)
dev_colors = plt.cm.coolwarm(dev_norm)
scale = 1.0 / radii.max()
ax6.scatter(emb_3d[:, 0]*scale, emb_3d[:, 1]*scale, emb_3d[:, 2]*scale,
            c=dev_colors, s=2, alpha=0.4)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("hypersphere_manifold.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: hypersphere_manifold.png")
plt.close()
# ══════════════════════════════════════════════════════════════════
# FIGURE 2: EXPERT PERSPECTIVES
# ══════════════════════════════════════════════════════════════════
# One 3-D panel per expert: each expert's view of the validation set,
# projected to 3-D with that view's own PCA axes and coloured by primary class.

print("Rendering figure 2...")
n_experts = len(EXPERTS)
fig2 = plt.figure(figsize=(7 * n_experts, 7), facecolor='black')
fig2.suptitle('Expert Perspective Divergence — Same sphere, three lenses',
              color='white', fontsize=14, y=1.02)

# Fused-constellation checkpoints store per-expert rotation / whitening /
# mean tensors; otherwise fall back to identity transforms (unused below
# unless has_expert_rot is True, but kept defined for later figures).
has_expert_rot = "constellation.expert_rotations.0" in sd
if has_expert_rot:
    expert_R = [sd[f"constellation.expert_rotations.{i}"] for i in range(n_experts)]
    expert_W = [sd[f"constellation.expert_whiteners.{i}"] for i in range(n_experts)]
    expert_mu = [sd[f"constellation.expert_means.{i}"] for i in range(n_experts)]
else:
    expert_R = [torch.eye(D_ANCHOR) for _ in range(n_experts)]
    expert_W = [torch.eye(D_ANCHOR) for _ in range(n_experts)]
    expert_mu = [torch.zeros(D_ANCHOR) for _ in range(n_experts)]

with torch.no_grad():
    for i, name in enumerate(EXPERTS):
        ax = fig2.add_subplot(1, n_experts, i + 1, projection='3d')

        if has_expert_rot:
            # Centre, whiten, rotate the fused embedding into expert i's frame.
            centered = fused.float() - expert_mu[i]
            whitened = centered @ expert_W[i]
            rotated = F.normalize(whitened @ expert_R[i].T, dim=-1)
        elif f"projectors.{i}.proj_native.0.weight" in sd:
            # Dual-stream checkpoint: use the expert's native projector
            # (Linear -> hand-rolled LayerNorm, biased var, eps=1e-5).
            W = sd[f"projectors.{i}.proj_native.0.weight"]
            b = sd[f"projectors.{i}.proj_native.0.bias"]
            lw = sd[f"projectors.{i}.proj_native.1.weight"]
            lb = sd[f"projectors.{i}.proj_native.1.bias"]
            x = val_raw[name] @ W.T + b
            mu_v = x.mean(-1, keepdim=True)
            var_v = x.var(-1, keepdim=True, unbiased=False)
            x = (x - mu_v) / (var_v + 1e-5).sqrt() * lw + lb
            rotated = F.normalize(x, dim=-1)
        else:
            rotated = projected[i]

        rot_np = rotated.numpy()
        rot_c = rot_np - rot_np.mean(axis=0, keepdims=True)
        # PCA axes fit on the first 5000 centred rows; points projected uncentred.
        _, S_r, Vt_r = np.linalg.svd(rot_c[:5000], full_matrices=False)
        rot_3d = to_sphere(rot_np @ Vt_r[:3].T)
        var_exp = S_r[:3]**2 / (S_r**2).sum()

        setup_ax(ax, f'{name[:25]}\nPC variance: {var_exp.sum()*100:.1f}%')
        ax.scatter(rot_3d[:, 0], rot_3d[:, 1], rot_3d[:, 2], c=class_colors, s=2, alpha=0.4)

plt.tight_layout()
plt.savefig("expert_perspectives.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: expert_perspectives.png")
plt.close()
# ══════════════════════════════════════════════════════════════════
# FIGURE 3: ANCHORS ONLY
# ══════════════════════════════════════════════════════════════════

print("Rendering figure 3 — anchors only...")

# Anchor visit counts for coloring: number of images whose nearest anchor
# (highest cosine) is each anchor.
with torch.no_grad():
    cos_all = fused @ anchors.T
    nearest_all = cos_all.argmax(dim=-1)
vc = torch.bincount(nearest_all, minlength=N_ANCHORS).float()
vc_np = vc.numpy()

fig3 = plt.figure(figsize=(24, 8), facecolor='black')
fig3.suptitle(f'Constellation — {N_ANCHORS} anchors × {D_ANCHOR}-d on S²⁵⁵',
              color='white', fontsize=14, y=1.02)

# Panel 1: Anchors colored by visit count
ax_a1 = fig3.add_subplot(131, projection='3d')
setup_ax(ax_a1, f'Anchor Utilization — {int((vc_np>0).sum())}/{N_ANCHORS} active')
heat = np.zeros(N_ANCHORS)
active_mask = vc_np > 0
heat[active_mask] = np.log1p(vc_np[active_mask])
heat = heat / (heat.max() + 1e-8)
a_colors = plt.cm.inferno(heat)
a_sizes = 5 + 60 * heat
# Dead anchors in blue
dead_mask = vc_np == 0
ax_a1.scatter(anchors_s[dead_mask, 0], anchors_s[dead_mask, 1], anchors_s[dead_mask, 2],
              c='dodgerblue', s=8, alpha=0.4, marker='x', label=f'dead ({int(dead_mask.sum())})')
ax_a1.scatter(anchors_s[active_mask, 0], anchors_s[active_mask, 1], anchors_s[active_mask, 2],
              c=a_colors[active_mask], s=a_sizes[active_mask], alpha=0.8)

# Panel 2: Anchors colored by nearest neighbor distance
ax_a2 = fig3.add_subplot(132, projection='3d')
anchor_sim = (anchors.numpy() @ anchors.numpy().T)
np.fill_diagonal(anchor_sim, -1)
max_neighbor_cos = anchor_sim.max(axis=1)
nn_heat = (max_neighbor_cos - max_neighbor_cos.min()) / (max_neighbor_cos.max() - max_neighbor_cos.min() + 1e-8)
nn_colors = plt.cm.viridis(nn_heat)
setup_ax(ax_a2, f'Anchor Isolation — nearest neighbor cosine\n'
         f'mean={max_neighbor_cos.mean():.3f} max={max_neighbor_cos.max():.3f}')
ax_a2.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2], c=nn_colors, s=15, alpha=0.8)


def _proj_native(feats, i):
    """Embed raw features through expert ``i``'s native projector.

    Linear layer ``projectors.{i}.proj_native.0`` followed by a hand-rolled
    LayerNorm ``.1`` (biased variance, eps=1e-5), then row-wise L2 norm.
    """
    W = sd[f"projectors.{i}.proj_native.0.weight"]
    b = sd[f"projectors.{i}.proj_native.0.bias"]
    lw = sd[f"projectors.{i}.proj_native.1.weight"]
    lb = sd[f"projectors.{i}.proj_native.1.bias"]
    x = feats @ W.T + b
    mu = x.mean(-1, keepdim=True)
    var = x.var(-1, keepdim=True, unbiased=False)
    x = (x - mu) / (var + 1e-5).sqrt() * lw + lb
    return F.normalize(x, dim=-1)


def _expert_triangulations():
    """Return one ``1 - cosine(embedding, anchors)`` matrix per expert.

    Source of each expert's embedding, in priority order:
      * fused constellation: rotate ``fused`` through the stored mu/W/R;
      * dual-stream: the expert's native projector on its raw features;
      * fallback: the shared projections (expected to be near-identical).
    """
    tris = []
    if has_expert_rot:
        for i in range(len(EXPERTS)):
            centered = fused.float() - expert_mu[i]
            whitened = centered @ expert_W[i]
            rotated = F.normalize(whitened @ expert_R[i].T, dim=-1)
            tris.append(1.0 - (rotated @ anchors.T))
    elif "projectors.0.proj_native.0.weight" in sd:
        for i, name in enumerate(EXPERTS):
            tris.append(1.0 - (_proj_native(val_raw[name], i) @ anchors.T))
    else:
        for p in projected:
            tris.append(1.0 - (p @ anchors.T))
    return tris


# Panel 3: Anchors colored by expert divergence at that anchor
ax_a3 = fig3.add_subplot(133, projection='3d')
with torch.no_grad():
    # Computed once here and reused for figure 4.
    expert_tri_stack = _expert_triangulations()
    tri_stack = torch.stack(expert_tri_stack, dim=-1)
    # Std across experts, averaged over images -> one value per anchor.
    per_anchor_div = tri_stack.std(dim=-1).mean(dim=0).numpy()
div_heat = (per_anchor_div - per_anchor_div.min()) / (per_anchor_div.max() - per_anchor_div.min() + 1e-8)
div_colors = plt.cm.coolwarm(div_heat)
setup_ax(ax_a3, f'Expert Divergence per Anchor\n'
         f'mean={per_anchor_div.mean():.4f} range=[{per_anchor_div.min():.4f}, {per_anchor_div.max():.4f}]')
ax_a3.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2], c=div_colors, s=15, alpha=0.8)

# Add connections between closest anchor pairs (top 20). The pairs are picked
# once (masking flat_sim in place as they are taken) and then drawn on both
# panels, so both panels show the same top-20 set.
flat_sim = anchor_sim.copy()
np.fill_diagonal(flat_sim, -999)
top_pairs = []
for _ in range(20):
    i_a, j_a = np.unravel_index(np.argmax(flat_sim), flat_sim.shape)
    flat_sim[i_a, j_a] = -999
    flat_sim[j_a, i_a] = -999
    top_pairs.append((i_a, j_a))
for panel_ax in (ax_a1, ax_a2):
    for i_a, j_a in top_pairs:
        panel_ax.plot([anchors_s[i_a, 0], anchors_s[j_a, 0]],
                      [anchors_s[i_a, 1], anchors_s[j_a, 1]],
                      [anchors_s[i_a, 2], anchors_s[j_a, 2]],
                      color='white', alpha=0.15, linewidth=0.5)

plt.tight_layout()
plt.savefig("anchors_only.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: anchors_only.png")
plt.close()

# ══════════════════════════════════════════════════════════════════
# FIGURE 4: PAIRWISE EXPERT DIFFERENCES
# ══════════════════════════════════════════════════════════════════

print("Rendering figure 4 — pairwise expert diffs...")

# Same inputs and branch as figure 3's triangulations, so reuse them.
expert_tris = expert_tri_stack

with torch.no_grad():
    # Pairwise diffs
    diff_cd = expert_tris[0] - expert_tris[1]
    diff_cs = expert_tris[0] - expert_tris[2]
    diff_ds = expert_tris[1] - expert_tris[2]
diffs = [diff_cd, diff_cs, diff_ds]
diff_names = ["CLIP − DINOv2", "CLIP − SigLIP", "DINOv2 − SigLIP"]
abs_tri = expert_tris[0]

print("\n Pairwise diff statistics:")
for name, d in zip(diff_names, diffs):
    print(f" {name:20s}: mean={d.mean():.6f} std={d.std():.6f} "
          f"min={d.min():.6f} max={d.max():.6f}")
print(f" Absolute tri std: {abs_tri.std():.6f}")
diff_std = diffs[0].std().item()
abs_std = abs_tri.std().item()
print(f" Ratio (diff/abs): {diff_std / abs_std:.4f}" if abs_std > 1e-10
      else " Ratio (diff/abs): N/A (zero abs std)")


def _pca3(centered):
    """PCA summary of ``centered`` (axes fit on its first 5000 rows).

    Returns ``(S, Vt, eff_dim, pts_3d, var3)`` where ``eff_dim`` is the
    participation ratio (sum s)^2 / sum s^2, ``pts_3d`` are all rows projected
    on the top-3 axes and scaled to unit norm, and ``var3`` the variance
    fraction of those axes. A (near-)zero spectrum yields zeros instead of
    NaNs from division by zero.
    """
    _, s, vt = np.linalg.svd(centered[:5000], full_matrices=False)
    s_sq_sum = (s**2).sum()
    if s_sq_sum <= 1e-20:
        return s, vt, 0.0, np.zeros((len(centered), 3)), np.zeros(3)
    eff = float(((s / s.sum())**2).sum()**-1)
    return s, vt, eff, to_sphere(centered @ vt[:3].T), s[:3]**2 / s_sq_sum


# PCA of the diff space
diff_stacked = torch.cat(diffs, dim=-1).numpy()
diff_centered = diff_stacked - diff_stacked.mean(axis=0, keepdims=True)
S_diff, Vt_diff, eff_dim_diff, diff_3d, var_diff = _pca3(diff_centered)
print(f"\n Diff space effective dim: {eff_dim_diff:.1f}")
print(f" Diff PCA 3D variance: {var_diff.sum()*100:.1f}%")

abs_stacked = abs_tri.numpy()
abs_centered = abs_stacked - abs_stacked.mean(axis=0, keepdims=True)
S_abs, Vt_abs, abs_eff, abs_3d, var_abs_3 = _pca3(abs_centered)
print(f" Absolute tri effective dim: {abs_eff:.1f}")

full_stacked = np.concatenate([abs_stacked, diff_stacked], axis=-1)
full_centered = full_stacked - full_stacked.mean(axis=0, keepdims=True)
S_full, Vt_full, full_eff, full_3d, var_full_3 = _pca3(full_centered)
print(f" Full (abs+diffs) effective dim: {full_eff:.1f}")
print(f" Information gain from diffs: {full_eff - abs_eff:.1f} dimensions")

fig4 = plt.figure(figsize=(28, 14), facecolor='black')
fig4.suptitle(
    'Expert Pairwise Differences — Where the discriminative signal lives\n'
    f'Diff eff_dim={eff_dim_diff:.1f} | Abs eff_dim={abs_eff:.1f} | '
    f'Combined eff_dim={full_eff:.1f} | Info gain: +{full_eff-abs_eff:.1f} dims',
    color='white', fontsize=14, y=0.98)

# Row 1: Three pairwise diff distributions on sphere
for col, (name, d) in enumerate(zip(diff_names, diffs)):
    ax = fig4.add_subplot(2, 4, col + 1, projection='3d')
    d_np = d.numpy()
    # Per-image: magnitude of diff vector
    diff_mag = np.linalg.norm(d_np, axis=-1)
    mag_heat = (diff_mag - diff_mag.min()) / (diff_mag.max() - diff_mag.min() + 1e-8)
    mag_colors = plt.cm.magma(mag_heat)
    setup_ax(ax, f'{name}\nstd={d_np.std():.5f}')
    ax.scatter(emb_s[:, 0], emb_s[:, 1], emb_s[:, 2], c=mag_colors, s=2, alpha=0.5)

# Panel 4: Diff space PCA
ax_dp = fig4.add_subplot(244, projection='3d')
setup_ax(ax_dp, f'Diff Space PCA\neff_dim={eff_dim_diff:.1f} var={var_diff.sum()*100:.1f}%')
ax_dp.scatter(diff_3d[:, 0], diff_3d[:, 1], diff_3d[:, 2], c=class_colors, s=2, alpha=0.4)

# Row 2: Per-anchor diff analysis
# Per-anchor mean absolute diff (where do experts disagree most?)
with torch.no_grad():
    per_anchor_cd = diff_cd.abs().mean(dim=0).numpy()
    per_anchor_cs = diff_cs.abs().mean(dim=0).numpy()
    per_anchor_ds = diff_ds.abs().mean(dim=0).numpy()
per_anchor_total = (per_anchor_cd + per_anchor_cs + per_anchor_ds) / 3

# Panel 5: Anchor-level divergence map (total)
ax_a = fig4.add_subplot(245, projection='3d')
total_heat = (per_anchor_total - per_anchor_total.min()) / (per_anchor_total.max() - per_anchor_total.min() + 1e-8)
total_colors = plt.cm.hot(total_heat)
total_sizes = 5 + 40 * total_heat
setup_ax(ax_a, f'Anchor Divergence (all pairs)\n'
         f'range=[{per_anchor_total.min():.5f}, {per_anchor_total.max():.5f}]')
ax_a.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
             c=total_colors, s=total_sizes, alpha=0.8)

# Panel 6: Abs tri PCA vs diff PCA side by side
ax_abs = fig4.add_subplot(246, projection='3d')
setup_ax(ax_abs, f'Absolute Tri PCA\neff_dim={abs_eff:.1f} var={var_abs_3.sum()*100:.1f}%')
ax_abs.scatter(abs_3d[:, 0], abs_3d[:, 1], abs_3d[:, 2], c=class_colors, s=2, alpha=0.4)

# Panel 7: Combined PCA
ax_full = fig4.add_subplot(247, projection='3d')
setup_ax(ax_full, f'Combined (abs+diffs) PCA\neff_dim={full_eff:.1f} var={var_full_3.sum()*100:.1f}%')
ax_full.scatter(full_3d[:, 0], full_3d[:, 1], full_3d[:, 2], c=class_colors, s=2, alpha=0.4)

# Panel 8: Histogram of diff magnitudes
ax_hist = fig4.add_subplot(248)
ax_hist.set_facecolor('black')
for name, d, color in zip(diff_names, diffs, ['#ff6b6b', '#4ecdc4', '#ffe66d']):
    per_image_mag = np.linalg.norm(d.numpy(), axis=-1)
    ax_hist.hist(per_image_mag, bins=50, alpha=0.6, color=color, label=name, density=True)
ax_hist.set_xlabel('Diff magnitude (L2)', color='white', fontsize=9)
ax_hist.set_ylabel('Density', color='white', fontsize=9)
ax_hist.set_title('Per-image diff magnitudes', color='white', fontsize=11)
ax_hist.legend(fontsize=8, facecolor='black', edgecolor='gray', labelcolor='white')
ax_hist.tick_params(colors='gray', labelsize=7)
ax_hist.spines['bottom'].set_color('gray')
ax_hist.spines['left'].set_color('gray')
ax_hist.spines['top'].set_visible(False)
ax_hist.spines['right'].set_visible(False)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("pairwise_diffs.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: pairwise_diffs.png")
plt.close()

print("\nDone.")