# Source: geolip-hypersphere-experiments / hypersphere_convergence_analysis.py
# Author: AbstractPhil
# Commit: Create hypersphere_convergence_analysis.py (#1)
# Revision: 5945b22
#!/usr/bin/env python3
"""
GEOLIP HYPERSPHERE MANIFOLD VISUALIZATION
==========================================
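Renders four figures: a 6-panel manifold view, a 3-panel expert perspective
divergence view, a 3-panel anchors-only view, and an 8-panel pairwise expert
difference analysis.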
S^255 projected to S^2 via PCA.
"""
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import math
# NOTE(review): DEVICE is not referenced anywhere else in the visible script.
DEVICE = "cpu"

# ══════════════════════════════════════════════════════════════════
# LOAD + EMBED
# ══════════════════════════════════════════════════════════════════
print("Loading soup...")
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so only point this at checkpoints from a trusted source.
ckpt = torch.load(
    "checkpoints/dual_stream_best.pt",
    map_location="cpu",
    weights_only=False,
)
sd = ckpt["state_dict"]
_config = ckpt["config"]
D_ANCHOR, N_ANCHORS = _config["d_anchor"], _config["n_anchors"]
# Anchors are L2-normalized along the last axis so that dot products with
# normalized embeddings are cosine similarities.
anchors = F.normalize(sd["constellation.anchors"], dim=-1)
# Expert names double as dataset config names when the features are loaded.
EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]
# Class names indexed by label id; the list position is the label index used
# for titles and centroid annotations further down.
COCO_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]
print("Loading features...")
from datasets import load_dataset

# The first expert's validation split fixes the row order and provides labels.
ref = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split="val")
val_ids = ref["image_id"]
N_val = len(val_ids)
val_id_map = {image_id: row_idx for row_idx, image_id in enumerate(val_ids)}

# Multi-hot label matrix; label ids of 80 or above are skipped.
val_labels = torch.zeros(N_val, 80)
for row_idx, label_ids in enumerate(ref["labels"]):
    for label_id in label_ids:
        if label_id >= 80:
            continue
        val_labels[row_idx, label_id] = 1.0

# Per-expert feature matrices aligned to the reference row order. Rows whose
# image_id is missing from an expert's split stay all-zero.
val_raw = {}
for name in EXPERTS:
    ds = load_dataset("AbstractPhil/bulk-coco-features", name, split="val")
    feats = torch.zeros(N_val, 768)
    for row in ds:
        slot = val_id_map.get(row["image_id"])
        if slot is None:
            continue
        feats[slot] = torch.tensor(row["features"], dtype=torch.float32)
    val_raw[name] = feats
    del ds
def project_expert(feats, i, state_dict=None, eps=1e-5):
    """Project raw expert features onto the unit hypersphere.

    Applies the expert's affine layer, then a layer normalization over the
    last axis, then L2-normalizes the result.

    Args:
        feats: Feature tensor whose last axis matches the projector's input
            width.
        i: Index of the expert, used to build the ``projectors.{i}.*`` keys.
        state_dict: Mapping holding the projector tensors. Defaults to the
            module-level ``sd`` loaded from the checkpoint.
        eps: Constant added to the variance inside the layer normalization.

    Returns:
        Tensor of unit-norm rows.

    Raises:
        KeyError: If neither the ``proj_shared`` nor the ``proj`` tensors for
            expert ``i`` are present.
    """
    weights = sd if state_dict is None else state_dict
    # Prefer the shared-stream projector when the checkpoint has one;
    # otherwise fall back to the plain "proj" layout.
    shared = f"projectors.{i}.proj_shared"
    prefix = shared if f"{shared}.0.weight" in weights else f"projectors.{i}.proj"

    # Sub-module 0: affine map (feats @ W.T + b).
    x = F.linear(feats, weights[f"{prefix}.0.weight"], weights[f"{prefix}.0.bias"])
    # Sub-module 1: layer norm over the last axis (biased variance) with a
    # learned scale and shift.
    x = F.layer_norm(
        x,
        (x.shape[-1],),
        weights[f"{prefix}.1.weight"],
        weights[f"{prefix}.1.bias"],
        eps,
    )
    return F.normalize(x, dim=-1)
print("Generating embeddings...")
with torch.no_grad():
    # One unit-norm projection per expert, in EXPERTS order.
    projected = []
    for i, name in enumerate(EXPERTS):
        projected.append(project_expert(val_raw[name], i))
    # Fuse by averaging the three projections and re-normalizing.
    fused = F.normalize(sum(projected) / 3, dim=-1)
# ══════════════════════════════════════════════════════════════════
# PCA → 3D
# ══════════════════════════════════════════════════════════════════
emb = fused.numpy()
emb_centered = emb - emb.mean(axis=0, keepdims=True)
# The principal axes are estimated from the first 5000 centered rows only.
U, S, Vt = np.linalg.svd(emb_centered[:5000], full_matrices=False)
pca3 = Vt[:3]
# Embeddings and anchors are projected uncentered onto the same three axes.
emb_3d = emb @ pca3.T
anchors_3d = anchors.numpy() @ pca3.T
# Share of the squared singular values captured by the first three axes.
var_explained = S[:3]**2 / (S**2).sum()
print(
    f"PCA 3D variance: {var_explained.sum()*100:.1f}% "
    f"({var_explained[0]*100:.1f}%, {var_explained[1]*100:.1f}%, {var_explained[2]*100:.1f}%)"
)
def to_sphere(pts, eps=1e-8):
    """Rescale each point to (approximately) unit length along the last axis.

    Args:
        pts: Array-like of points; the last axis holds the coordinates.
        eps: Small constant added to each norm so zero vectors map to zero
            instead of dividing by zero.

    Returns:
        Array of the same shape as ``pts`` with every point divided by its
        Euclidean norm plus ``eps``.
    """
    pts = np.asarray(pts)
    norms = np.linalg.norm(pts, axis=-1, keepdims=True)
    return pts / (norms + eps)
emb_s = to_sphere(emb_3d)
anchors_s = to_sphere(anchors_3d)

# Reference sphere wireframe
phi = np.linspace(0, 2*np.pi, 60)
theta = np.linspace(0, np.pi, 30)
_sin_theta = np.sin(theta)
xs = np.outer(np.cos(phi), _sin_theta)
ys = np.outer(np.sin(phi), _sin_theta)
zs = np.outer(np.ones_like(phi), np.cos(theta))

# Primary class per image (most specific): among the classes present in an
# image, take the one with the lowest dataset-wide count. Absent classes are
# masked with +inf so argmin only sees present ones; ties go to the lowest
# class index, and an image without labels falls back to class 0.
class_freq = val_labels.sum(0).numpy()
_masked_freq = np.where(val_labels.numpy() > 0, class_freq[None, :], np.inf)
primary_class = _masked_freq.argmin(axis=1).astype(int)

# 20-color palette, reused cyclically across the class ids.
cmap20 = plt.cm.tab20(np.linspace(0, 1, 20))
class_colors = cmap20[primary_class % 20]
# ══════════════════════════════════════════════════════════════════
# HELPER
# ══════════════════════════════════════════════════════════════════
def setup_ax(ax, title):
    """Apply the shared dark styling to a 3D axes and draw the reference sphere.

    Args:
        ax: 3D matplotlib axes to style in place.
        title: Title text shown above the panel.
    """
    ax.set_facecolor('black')
    panes = [ax.xaxis.pane, ax.yaxis.pane, ax.zaxis.pane]
    for pane in panes:
        pane.fill = False
    for pane in panes:
        pane.set_edgecolor('gray')

    label_style = dict(color='gray', fontsize=8)
    ax.set_xlabel('PC1', **label_style)
    ax.set_ylabel('PC2', **label_style)
    ax.set_zlabel('PC3', **label_style)
    ax.tick_params(colors='gray', labelsize=6)
    ax.set_title(title, color='white', fontsize=11, pad=10)

    # Faint wireframe drawn slightly inside the unit sphere.
    shrink = 0.98
    ax.plot_wireframe(xs*shrink, ys*shrink, zs*shrink,
                      alpha=0.03, color='white', linewidth=0.3)

    lo, hi = -1.3, 1.3
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_zlim(lo, hi)
# ══════════════════════════════════════════════════════════════════
# FIGURE 1: 6-PANEL MANIFOLD VIEW
# ══════════════════════════════════════════════════════════════════
print("Rendering figure 1...")
fig = plt.figure(figsize=(24, 16), facecolor='black')
# The title text was mis-decoded (UTF-8 read as a Greek code page); the dash,
# superscripts and multiplication signs are restored here.
# NOTE(review): eff_dim=76.9 is a literal, not computed by this script —
# confirm it still matches the checkpoint being visualized.
fig.suptitle(
    'GeoLIP Hypersphere Manifold — S²⁵⁵ projected to S²\n'
    f'{N_ANCHORS} anchors × {D_ANCHOR}-d × 3 experts | mAP={ckpt["mAP"]:.3f} | eff_dim=76.9',
    color='white', fontsize=16, y=0.98)

# Panel 1: Full manifold — every embedding colored by its primary class, with
# the anchors overlaid as red triangles.
ax1 = fig.add_subplot(231, projection='3d')
setup_ax(ax1, f'Full Manifold — {N_val} embeddings + {N_ANCHORS} anchors')
ax1.scatter(emb_s[:, 0], emb_s[:, 1], emb_s[:, 2],
            c=class_colors, s=1, alpha=0.3)
ax1.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
            c='red', s=8, alpha=0.6, marker='^')
# Panel 2: Class centroids
ax2 = fig.add_subplot(232, projection='3d')
setup_ax(ax2, '80 COCO Class Centroids')

# Mean embedding of every image carrying each class; a class with no images
# keeps an all-zero centroid.
centroids = np.zeros((80, emb.shape[1]))
_has_class = val_labels.numpy() > 0
for c, members in enumerate(_has_class.T):
    if members.any():
        centroids[c] = emb[members].mean(0)
centroids_3d = to_sphere(centroids @ pca3.T)

# Marker area grows with the number of images per class.
sizes = val_labels.sum(0).numpy()
sizes_scaled = 20 + 200 * (sizes / sizes.max())
colors80 = plt.cm.hsv(np.linspace(0, 0.95, 80))
ax2.scatter(centroids_3d[:, 0], centroids_3d[:, 1], centroids_3d[:, 2],
            c=colors80, s=sizes_scaled, alpha=0.8, edgecolors='white', linewidth=0.3)

# Annotate a hand-picked subset of classes, pushed slightly off the sphere.
for c in [0, 2, 14, 15, 16, 22, 23, 56, 62]:
    if not sizes[c] > 30:
        continue
    label_x, label_y, label_z = centroids_3d[c] * 1.15
    ax2.text(label_x, label_y, label_z,
             COCO_CLASSES[c], color='white', fontsize=7, ha='center')
# Panel 3: 50 random with anchor connections
# (panel titles below had mis-decoded em dashes; restored)
ax3 = fig.add_subplot(233, projection='3d')
setup_ax(ax3, '50 Random — nearest anchor connections')
# Fixed seed so the sampled images are reproducible; panel 4 continues from
# the same random stream.
np.random.seed(42)
idx50 = np.random.choice(N_val, 50, replace=False)
emb_50 = emb_s[idx50]
colors_50 = class_colors[idx50]
with torch.no_grad():
    cos_50 = fused[idx50] @ anchors.T
    # Nearest anchor = highest cosine similarity in the full space.
    nearest_50 = cos_50.argmax(-1).numpy()
ax3.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
            c='red', s=4, alpha=0.2, marker='^')
ax3.scatter(emb_50[:, 0], emb_50[:, 1], emb_50[:, 2],
            c=colors_50, s=40, alpha=0.9, edgecolors='white', linewidth=0.5)
# Link each sampled embedding to its nearest anchor.
for point, anchor_pt in zip(emb_50, anchors_s[nearest_50]):
    ax3.plot([point[0], anchor_pt[0]],
             [point[1], anchor_pt[1]],
             [point[2], anchor_pt[2]],
             color='yellow', alpha=0.3, linewidth=0.5)

# Panel 4: 10 random — triangulation heatmap
ax4 = fig.add_subplot(234, projection='3d')
setup_ax(ax4, '10 Random — anchor affinity heatmap')
idx10 = np.random.choice(N_val, 10, replace=False)
emb_10 = emb_s[idx10]
with torch.no_grad():
    cos_10 = (fused[idx10] @ anchors.T).numpy()
# Mean cosine of the ten samples to every anchor, min-max scaled to [0, 1].
mean_cos = cos_10.mean(0)
anchor_heat = (mean_cos - mean_cos.min()) / (mean_cos.max() - mean_cos.min() + 1e-8)
anchor_colors = plt.cm.hot(anchor_heat)
ax4.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
            c=anchor_colors, s=10, alpha=0.6)
ax4.scatter(emb_10[:, 0], emb_10[:, 1], emb_10[:, 2],
            c='cyan', s=80, alpha=1.0, edgecolors='white', linewidth=1, zorder=10)
# Panel 5: Single encoding — one fixed image and its five closest anchors.
# (the title's em dash was mis-decoded; restored)
ax5 = fig.add_subplot(235, projection='3d')
single_idx = 42
single_class = primary_class[single_idx]
setup_ax(ax5, f'Single Encoding: "{COCO_CLASSES[single_class]}" — top 5 anchors')
with torch.no_grad():
    cos_single = (fused[single_idx] @ anchors.T).numpy()
# Min-max scaled affinity drives both color and (cubed) marker size, so only
# the closest anchors stand out.
single_heat = (cos_single - cos_single.min()) / (cos_single.max() - cos_single.min() + 1e-8)
single_colors = plt.cm.plasma(single_heat)
single_sizes = 2 + 50 * single_heat**3
ax5.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
            c=single_colors, s=single_sizes, alpha=0.7)
single_pt = emb_s[single_idx]
ax5.scatter([single_pt[0]], [single_pt[1]], [single_pt[2]],
            c='lime', s=150, alpha=1.0, edgecolors='white', linewidth=2,
            zorder=10, marker='*')
# Indices of the five highest-cosine anchors, best first.
top5 = np.argsort(cos_single)[::-1][:5]
for a in top5:
    ax5.plot([single_pt[0], anchors_s[a, 0]],
             [single_pt[1], anchors_s[a, 1]],
             [single_pt[2], anchors_s[a, 2]],
             color='lime', alpha=0.6, linewidth=1.5)
# Panel 6: Radial deviation — length of each embedding after the 3D
# projection, before re-normalization.
# (the title's em dash was mis-decoded; restored)
ax6 = fig.add_subplot(236, projection='3d')
radii = np.linalg.norm(emb_3d, axis=-1)
setup_ax(ax6, f'PCA Projection Radii — mean={radii.mean():.4f} std={radii.std():.4f}')
# Color by deviation from the mean radius, min-max scaled to [0, 1].
radius_dev = radii - radii.mean()
dev_norm = (radius_dev - radius_dev.min()) / (radius_dev.max() - radius_dev.min() + 1e-8)
dev_colors = plt.cm.coolwarm(dev_norm)
# Scale so the longest projection just touches the unit sphere.
scale = 1.0 / radii.max()
ax6.scatter(emb_3d[:, 0]*scale, emb_3d[:, 1]*scale, emb_3d[:, 2]*scale,
            c=dev_colors, s=2, alpha=0.4)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("hypersphere_manifold.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: hypersphere_manifold.png")
plt.close()
# ══════════════════════════════════════════════════════════════════
# FIGURE 2: EXPERT PERSPECTIVES
# ══════════════════════════════════════════════════════════════════
print("Rendering figure 2...")
fig2 = plt.figure(figsize=(21, 7), facecolor='black')
# The title's em dash was mis-decoded; restored.
fig2.suptitle('Expert Perspective Divergence — Same sphere, three lenses',
              color='white', fontsize=14, y=1.02)

# Per-expert rotation / whitening / mean tensors are optional in the
# checkpoint. When absent, identity transforms are used so later code can
# index the lists uniformly. The lists are sized from EXPERTS rather than a
# literal 3.
_n_experts = len(EXPERTS)
has_expert_rot = "constellation.expert_rotations.0" in sd
if has_expert_rot:
    expert_R = [sd[f"constellation.expert_rotations.{i}"] for i in range(_n_experts)]
    expert_W = [sd[f"constellation.expert_whiteners.{i}"] for i in range(_n_experts)]
    expert_mu = [sd[f"constellation.expert_means.{i}"] for i in range(_n_experts)]
else:
    expert_R = [torch.eye(D_ANCHOR) for _ in range(_n_experts)]
    expert_W = [torch.eye(D_ANCHOR) for _ in range(_n_experts)]
    expert_mu = [torch.zeros(D_ANCHOR) for _ in range(_n_experts)]
with torch.no_grad():
    for i, name in enumerate(EXPERTS):
        ax = fig2.add_subplot(1, 3, i+1, projection='3d')

        # Pick this expert's view of the data: rotated fused embedding,
        # native-stream projection, or the shared projection as a fallback.
        if has_expert_rot:
            centered = fused.float() - expert_mu[i]
            whitened = centered @ expert_W[i]
            rotated = F.normalize(whitened @ expert_R[i].T, dim=-1)
        elif f"projectors.{i}.proj_native.0.weight" in sd:
            weight = sd[f"projectors.{i}.proj_native.0.weight"]
            bias = sd[f"projectors.{i}.proj_native.0.bias"]
            gamma = sd[f"projectors.{i}.proj_native.1.weight"]
            beta = sd[f"projectors.{i}.proj_native.1.bias"]
            x = val_raw[name] @ weight.T + bias
            # Layer norm over the last axis (biased variance), then scale/shift.
            mu_v = x.mean(-1, keepdim=True)
            var_v = x.var(-1, keepdim=True, unbiased=False)
            x = (x - mu_v) / (var_v + 1e-5).sqrt() * gamma + beta
            rotated = F.normalize(x, dim=-1)
        else:
            rotated = projected[i]

        # Each panel gets its own PCA, fitted on the first 5000 centered rows.
        rot_np = rotated.numpy()
        rot_c = rot_np - rot_np.mean(axis=0, keepdims=True)
        _, S_r, Vt_r = np.linalg.svd(rot_c[:5000], full_matrices=False)
        rot_3d = to_sphere(rot_np @ Vt_r[:3].T)
        var_exp = S_r[:3]**2 / (S_r**2).sum()

        setup_ax(ax, f'{name[:25]}\nPC variance: {var_exp.sum()*100:.1f}%')
        ax.scatter(rot_3d[:, 0], rot_3d[:, 1], rot_3d[:, 2],
                   c=class_colors, s=2, alpha=0.4)

plt.tight_layout()
plt.savefig("expert_perspectives.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: expert_perspectives.png")
plt.close()
# ══════════════════════════════════════════════════════════════════
# FIGURE 3: ANCHORS ONLY
# ══════════════════════════════════════════════════════════════════
# (em dash, multiplication sign and superscripts below were mis-decoded; restored)
print("Rendering figure 3 — anchors only...")

# Anchor visit counts for coloring: how many embeddings have each anchor as
# their nearest (highest-cosine) anchor.
with torch.no_grad():
    cos_all = fused @ anchors.T
    nearest_all = cos_all.argmax(dim=-1)
    # One vectorized count instead of a Python loop over every embedding;
    # minlength keeps never-visited anchors in the result as zeros.
    vc = torch.bincount(nearest_all, minlength=N_ANCHORS).float()
vc_np = vc.numpy()

fig3 = plt.figure(figsize=(24, 8), facecolor='black')
fig3.suptitle(f'Constellation — {N_ANCHORS} anchors × {D_ANCHOR}-d on S²⁵⁵',
              color='white', fontsize=14, y=1.02)
# Panel 1: Anchors colored by visit count
# (the title's em dash was mis-decoded; restored)
ax_a1 = fig3.add_subplot(131, projection='3d')
active_mask = vc_np > 0
setup_ax(ax_a1, f'Anchor Utilization — {int(active_mask.sum())}/{N_ANCHORS} active')

# Log-compressed visit count, scaled to [0, 1]; never-visited anchors stay 0.
heat = np.zeros(N_ANCHORS)
heat[active_mask] = np.log1p(vc_np[active_mask])
heat = heat / (heat.max() + 1e-8)
a_colors = plt.cm.inferno(heat)
a_sizes = 5 + 60 * heat

# Dead anchors in blue
dead_mask = vc_np == 0
ax_a1.scatter(anchors_s[dead_mask, 0], anchors_s[dead_mask, 1], anchors_s[dead_mask, 2],
              c='dodgerblue', s=8, alpha=0.4, marker='x', label=f'dead ({int(dead_mask.sum())})')
ax_a1.scatter(anchors_s[active_mask, 0], anchors_s[active_mask, 1], anchors_s[active_mask, 2],
              c=a_colors[active_mask], s=a_sizes[active_mask], alpha=0.8)
# Panel 2: Anchors colored by nearest neighbor distance
# (the title's em dash was mis-decoded; restored)
ax_a2 = fig3.add_subplot(132, projection='3d')
# Pairwise anchor cosine similarity; convert to NumPy once instead of twice.
_anchors_np = anchors.numpy()
anchor_sim = _anchors_np @ _anchors_np.T
# Mask self-similarity so the row maximum is the closest *other* anchor.
np.fill_diagonal(anchor_sim, -1)
max_neighbor_cos = anchor_sim.max(axis=1)
nn_heat = (max_neighbor_cos - max_neighbor_cos.min()) / (max_neighbor_cos.max() - max_neighbor_cos.min() + 1e-8)
nn_colors = plt.cm.viridis(nn_heat)
setup_ax(ax_a2, f'Anchor Isolation — nearest neighbor cosine\n'
                f'mean={max_neighbor_cos.mean():.3f} max={max_neighbor_cos.max():.3f}')
ax_a2.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
              c=nn_colors, s=15, alpha=0.8)
# Panel 3: Anchors colored by expert divergence at that anchor
ax_a3 = fig3.add_subplot(133, projection='3d')
with torch.no_grad():
    # Step 1: one unit-norm embedding matrix per expert, from whichever
    # representation the checkpoint provides.
    if has_expert_rot:
        _views = []
        for i in range(3):
            centered = fused.float() - expert_mu[i]
            whitened = centered @ expert_W[i]
            _views.append(F.normalize(whitened @ expert_R[i].T, dim=-1))
    elif "projectors.0.proj_native.0.weight" in sd:
        def _pn(feats, i):
            """Project raw features through expert i's native-stream layers."""
            key = f"projectors.{i}.proj_native"
            x = feats @ sd[f"{key}.0.weight"].T + sd[f"{key}.0.bias"]
            # Layer norm over the last axis (biased variance), then scale/shift.
            mu = x.mean(-1, keepdim=True)
            var = x.var(-1, keepdim=True, unbiased=False)
            x = (x - mu) / (var + 1e-5).sqrt() * sd[f"{key}.1.weight"] + sd[f"{key}.1.bias"]
            return F.normalize(x, dim=-1)

        _views = [_pn(val_raw[name], i) for i, name in enumerate(EXPERTS)]
    else:
        _views = list(projected)

    # Step 2: cosine distance from every embedding to every anchor.
    expert_tri_stack = [1.0 - (view @ anchors.T) for view in _views]
    tri_stack = torch.stack(expert_tri_stack, dim=-1)
    # Spread across experts, averaged over images: one value per anchor.
    per_anchor_div = tri_stack.std(dim=-1).mean(dim=0).numpy()

div_heat = (per_anchor_div - per_anchor_div.min()) / (per_anchor_div.max() - per_anchor_div.min() + 1e-8)
div_colors = plt.cm.coolwarm(div_heat)
setup_ax(ax_a3, f'Expert Divergence per Anchor\n'
                f'mean={per_anchor_div.mean():.4f} range=[{per_anchor_div.min():.4f}, {per_anchor_div.max():.4f}]')
ax_a3.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
              c=div_colors, s=15, alpha=0.8)
# Add connections between closest anchor pairs (top 20)
# The pairs are ranked once and the same list is drawn on both panels.
# Previously the similarity matrix was consumed while drawing on the first
# panel, so the second panel received the pairs ranked 21-40 instead.
# Only the strict upper triangle is ranked, so each unordered pair appears
# once and the diagonal is excluded.
_pair_i, _pair_j = np.triu_indices(anchor_sim.shape[0], k=1)
# Stable sort on negated similarity: descending order, with ties kept in
# row-major order.
_ranked = np.argsort(-anchor_sim[_pair_i, _pair_j], kind="stable")[:20]
closest_pairs = list(zip(_pair_i[_ranked], _pair_j[_ranked]))
for panel_ax in [ax_a1, ax_a2]:
    for i_a, j_a in closest_pairs:
        panel_ax.plot([anchors_s[i_a, 0], anchors_s[j_a, 0]],
                      [anchors_s[i_a, 1], anchors_s[j_a, 1]],
                      [anchors_s[i_a, 2], anchors_s[j_a, 2]],
                      color='white', alpha=0.15, linewidth=0.5)

plt.tight_layout()
plt.savefig("anchors_only.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: anchors_only.png")
plt.close()
# ══════════════════════════════════════════════════════════════════
# FIGURE 4: PAIRWISE EXPERT DIFFERENCES
# ══════════════════════════════════════════════════════════════════
# (em dashes and minus signs in the strings below were mis-decoded; restored)
print("Rendering figure 4 — pairwise expert diffs...")
with torch.no_grad():
    # Compute per-expert triangulations (cosine distance to every anchor).
    # For dual-stream: use native projectors (the actual expert perspectives)
    # For fused constellation: use expert rotations
    expert_tris = []
    if has_expert_rot:
        # Fused constellation: rotate through R/W/mu
        for i in range(3):
            centered = fused.float() - expert_mu[i]
            whitened = centered @ expert_W[i]
            rotated = F.normalize(whitened @ expert_R[i].T, dim=-1)
            expert_tris.append(1.0 - (rotated @ anchors.T))
    elif "projectors.0.proj_native.0.weight" in sd:
        # Dual-stream: use native projector embeddings
        def _proj_native(feats, i):
            """Project raw features through expert i's native-stream layers."""
            W = sd[f"projectors.{i}.proj_native.0.weight"]
            b = sd[f"projectors.{i}.proj_native.0.bias"]
            lw = sd[f"projectors.{i}.proj_native.1.weight"]
            lb = sd[f"projectors.{i}.proj_native.1.bias"]
            x = feats @ W.T + b
            # Layer norm over the last axis (biased variance), then scale/shift.
            mu = x.mean(-1, keepdim=True)
            var = x.var(-1, keepdim=True, unbiased=False)
            x = (x - mu) / (var + 1e-5).sqrt() * lw + lb
            return F.normalize(x, dim=-1)

        for i, name in enumerate(EXPERTS):
            native_emb = _proj_native(val_raw[name], i)
            expert_tris.append(1.0 - (native_emb @ anchors.T))
    else:
        # Fallback: use shared projections (will be near-identical)
        for p in projected:
            expert_tris.append(1.0 - (p @ anchors.T))

    # Pairwise diffs between the three experts' triangulations.
    diff_cd = expert_tris[0] - expert_tris[1]
    diff_cs = expert_tris[0] - expert_tris[2]
    diff_ds = expert_tris[1] - expert_tris[2]
    diffs = [diff_cd, diff_cs, diff_ds]
    diff_names = ["CLIP − DINOv2", "CLIP − SigLIP", "DINOv2 − SigLIP"]
    # The first expert's triangulation is the absolute reference.
    abs_tri = expert_tris[0]

    print(f"\n Pairwise diff statistics:")
    for pair_name, d in zip(diff_names, diffs):
        print(f" {pair_name:20s}: mean={d.mean():.6f} std={d.std():.6f} "
              f"min={d.min():.6f} max={d.max():.6f}")
    print(f" Absolute tri std: {abs_tri.std():.6f}")
    diff_std = diffs[0].std().item()
    abs_std = abs_tri.std().item()
    # The ratio is only reported when the denominator is not effectively zero.
    print(f" Ratio (diff/abs): {diff_std / abs_std:.4f}" if abs_std > 1e-10 else
          f" Ratio (diff/abs): N/A (zero abs std)")
def _participation_ratio(singular_values):
    """Return 1 / sum(p**2), where p is the singular values normalized to sum to 1."""
    return float(((singular_values / singular_values.sum())**2).sum()**-1)


# PCA of the diff space: the three pairwise diffs concatenated per image.
diff_stacked = torch.cat(diffs, dim=-1).numpy()
diff_centered = diff_stacked - diff_stacked.mean(axis=0, keepdims=True)
_, S_diff, Vt_diff = np.linalg.svd(diff_centered[:5000], full_matrices=False)

# Guard against zero SVDs
s_sum = (S_diff**2).sum()
if s_sum <= 1e-20:
    diff_3d = np.zeros((len(diff_centered), 3))
    var_diff = np.zeros(3)
    eff_dim_diff = 0.0
else:
    diff_3d = to_sphere(diff_centered @ Vt_diff[:3].T)
    var_diff = S_diff[:3]**2 / s_sum
    eff_dim_diff = _participation_ratio(S_diff)
print(f"\n Diff space effective dim: {eff_dim_diff:.1f}")
print(f" Diff PCA 3D variance: {var_diff.sum()*100:.1f}%")

# Same analysis on the absolute triangulation alone.
abs_stacked = abs_tri.numpy()
abs_centered = abs_stacked - abs_stacked.mean(axis=0, keepdims=True)
_, S_abs, Vt_abs = np.linalg.svd(abs_centered[:5000], full_matrices=False)
if S_abs.sum() > 1e-20:
    abs_eff = _participation_ratio(S_abs)
else:
    abs_eff = 0.0
print(f" Absolute tri effective dim: {abs_eff:.1f}")

# ...and on the absolute triangulation concatenated with the diffs.
full_stacked = np.concatenate([abs_stacked, diff_stacked], axis=-1)
full_centered = full_stacked - full_stacked.mean(axis=0, keepdims=True)
_, S_full, Vt_full = np.linalg.svd(full_centered[:5000], full_matrices=False)
if S_full.sum() > 1e-20:
    full_eff = _participation_ratio(S_full)
    full_3d = to_sphere(full_centered @ Vt_full[:3].T)
else:
    full_eff = 0.0
    full_3d = np.zeros((len(full_centered), 3))
print(f" Full (abs+diffs) effective dim: {full_eff:.1f}")
print(f" Information gain from diffs: {full_eff - abs_eff:.1f} dimensions")
fig4 = plt.figure(figsize=(28, 14), facecolor='black')
# The title's em dash was mis-decoded; restored.
fig4.suptitle(
    'Expert Pairwise Differences — Where the discriminative signal lives\n'
    f'Diff eff_dim={eff_dim_diff:.1f} | Abs eff_dim={abs_eff:.1f} | '
    f'Combined eff_dim={full_eff:.1f} | Info gain: +{full_eff-abs_eff:.1f} dims',
    color='white', fontsize=14, y=0.98)

# Row 1: Three pairwise diff distributions on sphere
for col, (name, d) in enumerate(zip(diff_names, diffs)):
    ax = fig4.add_subplot(2, 4, col+1, projection='3d')
    d_np = d.numpy()
    # Per-image: magnitude of diff vector, min-max scaled for the colormap.
    diff_mag = np.linalg.norm(d_np, axis=-1)
    mag_heat = (diff_mag - diff_mag.min()) / (diff_mag.max() - diff_mag.min() + 1e-8)
    mag_colors = plt.cm.magma(mag_heat)
    setup_ax(ax, f'{name}\nstd={d_np.std():.5f}')
    # Points sit at the fused-embedding positions; only the color encodes
    # the diff.
    ax.scatter(emb_s[:, 0], emb_s[:, 1], emb_s[:, 2],
               c=mag_colors, s=2, alpha=0.5)

# Panel 4: Diff space PCA
ax_dp = fig4.add_subplot(244, projection='3d')
setup_ax(ax_dp, f'Diff Space PCA\neff_dim={eff_dim_diff:.1f} var={var_diff.sum()*100:.1f}%')
ax_dp.scatter(diff_3d[:, 0], diff_3d[:, 1], diff_3d[:, 2],
              c=class_colors, s=2, alpha=0.4)
# Row 2: Per-anchor diff analysis
# Per-anchor mean absolute diff (where do experts disagree most?)
with torch.no_grad():
    per_anchor_cd = diff_cd.abs().mean(dim=0).numpy()
    per_anchor_cs = diff_cs.abs().mean(dim=0).numpy()
    per_anchor_ds = diff_ds.abs().mean(dim=0).numpy()
per_anchor_total = (per_anchor_cd + per_anchor_cs + per_anchor_ds) / 3

# Panel 5: Anchor-level divergence map (total)
ax_a = fig4.add_subplot(245, projection='3d')
total_heat = (per_anchor_total - per_anchor_total.min()) / (per_anchor_total.max() - per_anchor_total.min() + 1e-8)
total_colors = plt.cm.hot(total_heat)
total_sizes = 5 + 40 * total_heat
setup_ax(ax_a, f'Anchor Divergence (all pairs)\n'
               f'range=[{per_anchor_total.min():.5f}, {per_anchor_total.max():.5f}]')
ax_a.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
             c=total_colors, s=total_sizes, alpha=0.8)

# Panel 6: Abs tri PCA vs diff PCA side by side
ax_abs = fig4.add_subplot(246, projection='3d')
abs_3d = to_sphere(abs_centered @ Vt_abs[:3].T)
# Same zero-spectrum guard as the diff-space PCA, so a degenerate input
# reports 0% instead of dividing by zero and printing nan.
_abs_energy = (S_abs**2).sum()
var_abs_3 = S_abs[:3]**2 / _abs_energy if _abs_energy > 1e-20 else np.zeros(3)
setup_ax(ax_abs, f'Absolute Tri PCA\neff_dim={abs_eff:.1f} var={var_abs_3.sum()*100:.1f}%')
ax_abs.scatter(abs_3d[:, 0], abs_3d[:, 1], abs_3d[:, 2],
               c=class_colors, s=2, alpha=0.4)

# Panel 7: Combined PCA
ax_full = fig4.add_subplot(247, projection='3d')
_full_energy = (S_full**2).sum()
var_full_3 = S_full[:3]**2 / _full_energy if _full_energy > 1e-20 else np.zeros(3)
setup_ax(ax_full, f'Combined (abs+diffs) PCA\neff_dim={full_eff:.1f} var={var_full_3.sum()*100:.1f}%')
ax_full.scatter(full_3d[:, 0], full_3d[:, 1], full_3d[:, 2],
                c=class_colors, s=2, alpha=0.4)
# Panel 8: Histogram of diff magnitudes (one overlaid series per expert pair)
ax_hist = fig4.add_subplot(248)
ax_hist.set_facecolor('black')
_hist_palette = ['#ff6b6b', '#4ecdc4', '#ffe66d']
for name, d, color in zip(diff_names, diffs, _hist_palette):
    d_np = d.numpy()
    # L2 norm of each image's diff vector across all anchors.
    per_image_mag = np.linalg.norm(d_np, axis=-1)
    ax_hist.hist(per_image_mag, bins=50, alpha=0.6, color=color,
                 label=name, density=True)

ax_hist.set_xlabel('Diff magnitude (L2)', color='white', fontsize=9)
ax_hist.set_ylabel('Density', color='white', fontsize=9)
ax_hist.set_title('Per-image diff magnitudes', color='white', fontsize=11)
ax_hist.legend(fontsize=8, facecolor='black', edgecolor='gray',
               labelcolor='white')
ax_hist.tick_params(colors='gray', labelsize=7)
# Keep only the bottom/left spines, in gray.
for side in ('bottom', 'left'):
    ax_hist.spines[side].set_color('gray')
for side in ('top', 'right'):
    ax_hist.spines[side].set_visible(False)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("pairwise_diffs.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: pairwise_diffs.png")
plt.close()
print("\nDone.")