| |
| """ |
| GEOLIP HYPERSPHERE MANIFOLD VISUALIZATION |
| ========================================== |
| 6-panel manifold view + 3-panel expert perspective divergence. |
| S^255 projected to S^2 via PCA. |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
| from mpl_toolkits.mplot3d import Axes3D |
| import math |
|
|
# NOTE(review): DEVICE is never referenced below — all work runs on CPU
# tensors implicitly; presumably kept for parity with sibling scripts.
DEVICE = "cpu"

print("Loading soup...")
# weights_only=False is required because the checkpoint stores a config dict;
# only load checkpoints from trusted sources (pickle can execute code).
ckpt = torch.load("checkpoints/dual_stream_best.pt", map_location="cpu", weights_only=False)
sd = ckpt["state_dict"]
D_ANCHOR = ckpt["config"]["d_anchor"]    # per-anchor embedding width
N_ANCHORS = ckpt["config"]["n_anchors"]  # number of constellation anchors
# Unit-normalize anchors so cosine similarity is a plain matrix product.
anchors = F.normalize(sd["constellation.anchors"], dim=-1)
|
|
# Expert encoders, in the order their projector/rotation weights are indexed
# in the checkpoint (index 0..2 below must match training order).
EXPERTS = ["clip_l14_openai", "dinov2_b14", "siglip_b16_384"]
# The 80 COCO detection classes, indexed by label id (0-79).
COCO_CLASSES = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
    "truck", "boat", "traffic light", "fire hydrant", "stop sign",
    "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep",
    "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella",
    "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
    "sports ball", "kite", "baseball bat", "baseball glove", "skateboard",
    "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork",
    "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange",
    "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair",
    "couch", "potted plant", "bed", "dining table", "toilet", "tv",
    "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
    "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase",
    "scissors", "teddy bear", "hair drier", "toothbrush",
]
|
|
print("Loading features...")
from datasets import load_dataset

# Reference split: establishes the canonical image ordering and the labels.
ref = load_dataset("AbstractPhil/bulk-coco-features", EXPERTS[0], split="val")
val_ids = ref["image_id"]; N_val = len(val_ids)
val_id_map = {iid: i for i, iid in enumerate(val_ids)}
# Multi-hot label matrix (N_val, 80); label ids >= 80 are silently dropped.
val_labels = torch.zeros(N_val, 80)
for i, labs in enumerate(ref["labels"]):
    for l in labs:
        if l < 80: val_labels[i, l] = 1.0

# Raw 768-d features per expert, row-aligned to the reference ordering.
# Images missing from an expert's split remain as all-zero rows.
val_raw = {}
for name in EXPERTS:
    ds = load_dataset("AbstractPhil/bulk-coco-features", name, split="val")
    feats = torch.zeros(N_val, 768)
    for row in ds:
        if row["image_id"] in val_id_map:
            feats[val_id_map[row["image_id"]]] = torch.tensor(row["features"], dtype=torch.float32)
    val_raw[name] = feats; del ds
|
|
def project_expert(feats, i):
    """Run expert i's shared projector head on raw features.

    Applies Linear -> LayerNorm with weights taken from the checkpoint
    state dict ``sd`` and returns L2-normalized embeddings. Prefers the
    ``proj_shared`` head when present, otherwise falls back to ``proj``.
    """
    shared_key = f"projectors.{i}.proj_shared"
    prefix = shared_key if f"{shared_key}.0.weight" in sd else f"projectors.{i}.proj"
    weight, bias, ln_w, ln_b = (
        sd[f"{prefix}.0.weight"],
        sd[f"{prefix}.0.bias"],
        sd[f"{prefix}.1.weight"],
        sd[f"{prefix}.1.bias"],
    )
    hidden = feats @ weight.T + bias
    # Manual LayerNorm (biased variance, eps=1e-5 — matches F.layer_norm).
    mean = hidden.mean(-1, keepdim=True)
    variance = hidden.var(-1, keepdim=True, unbiased=False)
    normed = (hidden - mean) / torch.sqrt(variance + 1e-5) * ln_w + ln_b
    return F.normalize(normed, dim=-1)
|
|
print("Generating embeddings...")
with torch.no_grad():
    # One projected embedding set per expert, then a uniform (1/3 each)
    # spherical soup: average of unit vectors, renormalized to the sphere.
    projected = [project_expert(val_raw[name], i) for i, name in enumerate(EXPERTS)]
    fused = F.normalize(sum(projected) / 3, dim=-1)

# PCA via SVD of centered embeddings. Fit on the first 5000 rows only
# (speed); the resulting top-3 principal axes are then applied to all rows.
emb = fused.numpy()
emb_centered = emb - emb.mean(axis=0, keepdims=True)
U, S, Vt = np.linalg.svd(emb_centered[:5000], full_matrices=False)
pca3 = Vt[:3]  # (3, D) projection onto the top-3 principal axes

# NOTE(review): the projections below apply the PCA basis to the
# *uncentered* emb/anchors (fit mean not subtracted) — acceptable for a
# visualization, but worth confirming it is intentional.
emb_3d = emb @ pca3.T
anchors_3d = anchors.numpy() @ pca3.T

# Variance explained is computed on the 5000-row fit subset.
var_explained = S[:3]**2 / (S**2).sum()
print(f"PCA 3D variance: {var_explained.sum()*100:.1f}% "
      f"({var_explained[0]*100:.1f}%, {var_explained[1]*100:.1f}%, {var_explained[2]*100:.1f}%)")
|
|
def to_sphere(pts):
    """Radially project points onto the unit sphere (each row scaled to norm 1).

    A small epsilon in the denominator guards against division by zero
    for all-zero rows (which stay at the origin).
    """
    return pts / (np.linalg.norm(pts, axis=-1, keepdims=True) + 1e-8)
|
|
emb_s = to_sphere(emb_3d)        # embeddings radially projected onto S^2
anchors_s = to_sphere(anchors_3d)

# Wireframe mesh for the reference unit sphere (60 x 30 lon/lat grid).
phi = np.linspace(0, 2*np.pi, 60)
theta = np.linspace(0, np.pi, 30)
xs = np.outer(np.cos(phi), np.sin(theta))
ys = np.outer(np.sin(phi), np.sin(theta))
zs = np.outer(np.ones_like(phi), np.cos(theta))

# Assign each image a "primary" class: the *rarest* class present in its
# label set (rarest = most specific); images with no labels stay class 0.
class_freq = val_labels.sum(0).numpy()
primary_class = np.zeros(N_val, dtype=int)
for i in range(N_val):
    present = np.where(val_labels[i].numpy() > 0)[0]
    if len(present) > 0:
        primary_class[i] = present[class_freq[present].argmin()]

# 80 classes folded onto 20 tab20 colors (class id mod 20), so distinct
# classes can share a color — acceptable for a qualitative scatter.
cmap20 = plt.cm.tab20(np.linspace(0, 1, 20))
class_colors = np.array([cmap20[primary_class[i] % 20] for i in range(N_val)])
|
|
|
|
| |
| |
| |
|
|
def setup_ax(ax, title):
    """Apply the shared dark 3D-panel styling: black panes, gray labels,
    a faint reference-sphere wireframe, and fixed [-1.3, 1.3] limits."""
    ax.set_facecolor('black')
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
        axis.pane.set_edgecolor('gray')
    for setter, label in ((ax.set_xlabel, 'PC1'),
                          (ax.set_ylabel, 'PC2'),
                          (ax.set_zlabel, 'PC3')):
        setter(label, color='gray', fontsize=8)
    ax.tick_params(colors='gray', labelsize=6)
    ax.set_title(title, color='white', fontsize=11, pad=10)
    # Slightly-shrunk wireframe so the data points sit just outside the mesh.
    ax.plot_wireframe(xs*0.98, ys*0.98, zs*0.98, alpha=0.03, color='white', linewidth=0.3)
    ax.set_xlim(-1.3, 1.3)
    ax.set_ylim(-1.3, 1.3)
    ax.set_zlim(-1.3, 1.3)
|
|
|
|
| |
| |
| |
|
|
print("Rendering figure 1...")
fig = plt.figure(figsize=(24, 16), facecolor='black')
# FIX: the title strings were mojibake ("β", "SΒ²β΅β΅", "Γ") from a bad
# encoding round-trip; restored the intended em dash, superscripts ("S²⁵⁵",
# "S²") and multiplication signs via explicit escapes.
fig.suptitle(
    'GeoLIP Hypersphere Manifold \u2014 S\u00b2\u2075\u2075 projected to S\u00b2\n'
    f'{N_ANCHORS} anchors \u00d7 {D_ANCHOR}-d \u00d7 3 experts | mAP={ckpt["mAP"]:.3f} | eff_dim=76.9',
    color='white', fontsize=16, y=0.98)

# Panel 1: every fused embedding (colored by primary class) plus all anchors.
ax1 = fig.add_subplot(231, projection='3d')
setup_ax(ax1, f'Full Manifold \u2014 {N_val} embeddings + {N_ANCHORS} anchors')
ax1.scatter(emb_s[:, 0], emb_s[:, 1], emb_s[:, 2],
            c=class_colors, s=1, alpha=0.3)
ax1.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
            c='red', s=8, alpha=0.6, marker='^')
|
|
| |
# Panel 2: per-class centroids of the fused embeddings, sized by class
# frequency and annotated for a hand-picked subset of classes.
ax2 = fig.add_subplot(232, projection='3d')
setup_ax(ax2, '80 COCO Class Centroids')
centroids = np.zeros((80, emb.shape[1]))
for c in range(80):
    mask = val_labels[:, c].numpy() > 0
    if mask.sum() > 0:
        centroids[c] = emb[mask].mean(0)
centroids_3d = to_sphere(centroids @ pca3.T)
sizes = val_labels.sum(0).numpy()  # images per class
sizes_scaled = 20 + 200 * (sizes / sizes.max())
colors80 = plt.cm.hsv(np.linspace(0, 0.95, 80))
ax2.scatter(centroids_3d[:, 0], centroids_3d[:, 1], centroids_3d[:, 2],
            c=colors80, s=sizes_scaled, alpha=0.8, edgecolors='white', linewidth=0.3)
# Label a few recognizable classes, pushed 15% outward from the sphere;
# skip classes with too few (<= 30) examples.
for c in [0, 2, 14, 15, 16, 22, 23, 56, 62]:
    if sizes[c] > 30:
        ax2.text(centroids_3d[c, 0]*1.15, centroids_3d[c, 1]*1.15,
                 centroids_3d[c, 2]*1.15,
                 COCO_CLASSES[c], color='white', fontsize=7, ha='center')
|
|
| |
# Panel 3: 50 random embeddings with a line to each one's nearest anchor.
# FIX: restored the em dash in the panel title (was mojibake "β").
ax3 = fig.add_subplot(233, projection='3d')
setup_ax(ax3, '50 Random \u2014 nearest anchor connections')
np.random.seed(42)  # reproducible sample
idx50 = np.random.choice(N_val, 50, replace=False)
emb_50 = emb_s[idx50]
colors_50 = class_colors[idx50]
with torch.no_grad():
    cos_50 = fused[idx50] @ anchors.T
    nearest_50 = cos_50.argmax(-1).numpy()
ax3.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
            c='red', s=4, alpha=0.2, marker='^')
ax3.scatter(emb_50[:, 0], emb_50[:, 1], emb_50[:, 2],
            c=colors_50, s=40, alpha=0.9, edgecolors='white', linewidth=0.5)
for i in range(50):
    a = nearest_50[i]
    ax3.plot([emb_50[i, 0], anchors_s[a, 0]],
             [emb_50[i, 1], anchors_s[a, 1]],
             [emb_50[i, 2], anchors_s[a, 2]],
             color='yellow', alpha=0.3, linewidth=0.5)
|
|
| |
# Panel 4: mean anchor-affinity heat for 10 random embeddings (cyan dots).
# FIX: restored the em dash in the panel title (was mojibake "β").
ax4 = fig.add_subplot(234, projection='3d')
setup_ax(ax4, '10 Random \u2014 anchor affinity heatmap')
idx10 = np.random.choice(N_val, 10, replace=False)
emb_10 = emb_s[idx10]
with torch.no_grad():
    cos_10 = (fused[idx10] @ anchors.T).numpy()
mean_cos = cos_10.mean(0)
# Min-max normalize for the colormap; epsilon guards a constant vector.
anchor_heat = (mean_cos - mean_cos.min()) / (mean_cos.max() - mean_cos.min() + 1e-8)
anchor_colors = plt.cm.hot(anchor_heat)
ax4.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
            c=anchor_colors, s=10, alpha=0.6)
ax4.scatter(emb_10[:, 0], emb_10[:, 1], emb_10[:, 2],
            c='cyan', s=80, alpha=1.0, edgecolors='white', linewidth=1, zorder=10)
|
|
| |
# Panel 5: one embedding's anchor-affinity field plus its top-5 anchor links.
# FIX: restored the em dash in the panel title (was mojibake "β").
ax5 = fig.add_subplot(235, projection='3d')
single_idx = 42
single_class = primary_class[single_idx]
setup_ax(ax5, f'Single Encoding: "{COCO_CLASSES[single_class]}" \u2014 top 5 anchors')
with torch.no_grad():
    cos_single = (fused[single_idx] @ anchors.T).numpy()
single_heat = (cos_single - cos_single.min()) / (cos_single.max() - cos_single.min() + 1e-8)
single_colors = plt.cm.plasma(single_heat)
single_sizes = 2 + 50 * single_heat**3  # cubic emphasis on the hottest anchors
ax5.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
            c=single_colors, s=single_sizes, alpha=0.7)
single_pt = emb_s[single_idx]
ax5.scatter([single_pt[0]], [single_pt[1]], [single_pt[2]],
            c='lime', s=150, alpha=1.0, edgecolors='white', linewidth=2,
            zorder=10, marker='*')
top5 = np.argsort(cos_single)[::-1][:5]
for a in top5:
    ax5.plot([single_pt[0], anchors_s[a, 0]],
             [single_pt[1], anchors_s[a, 1]],
             [single_pt[2], anchors_s[a, 2]],
             color='lime', alpha=0.6, linewidth=1.5)
|
|
| |
# Panel 6: raw (pre-normalization) PCA radii, colored by deviation from mean.
# FIX: restored the em dash in the panel title (was mojibake "β").
ax6 = fig.add_subplot(236, projection='3d')
radii = np.linalg.norm(emb_3d, axis=-1)
setup_ax(ax6, f'PCA Projection Radii \u2014 mean={radii.mean():.4f} std={radii.std():.4f}')
radius_dev = radii - radii.mean()
dev_norm = (radius_dev - radius_dev.min()) / (radius_dev.max() - radius_dev.min() + 1e-8)
dev_colors = plt.cm.coolwarm(dev_norm)
scale = 1.0 / radii.max()  # fit the largest radius inside the unit sphere
ax6.scatter(emb_3d[:, 0]*scale, emb_3d[:, 1]*scale, emb_3d[:, 2]*scale,
            c=dev_colors, s=2, alpha=0.4)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("hypersphere_manifold.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: hypersphere_manifold.png")
plt.close()
|
|
|
|
| |
| |
| |
|
|
print("Rendering figure 2...")
fig2 = plt.figure(figsize=(21, 7), facecolor='black')
# FIX: restored the em dash in the suptitle (was mojibake "β").
fig2.suptitle('Expert Perspective Divergence \u2014 Same sphere, three lenses',
              color='white', fontsize=14, y=1.02)

# Per-expert view transforms from the checkpoint when the model was trained
# with them; otherwise identity transforms (every expert then sees the
# fused sphere unchanged and the fallback branches below take over).
has_expert_rot = f"constellation.expert_rotations.0" in sd
if has_expert_rot:
    expert_R = [sd[f"constellation.expert_rotations.{i}"] for i in range(3)]
    expert_W = [sd[f"constellation.expert_whiteners.{i}"] for i in range(3)]
    expert_mu = [sd[f"constellation.expert_means.{i}"] for i in range(3)]
else:
    expert_R = [torch.eye(D_ANCHOR) for _ in range(3)]
    expert_W = [torch.eye(D_ANCHOR) for _ in range(3)]
    expert_mu = [torch.zeros(D_ANCHOR) for _ in range(3)]
|
|
with torch.no_grad():
    for i, name in enumerate(EXPERTS):
        ax = fig2.add_subplot(1, 3, i+1, projection='3d')

        # Obtain this expert's view of the manifold, in order of preference:
        # (1) learned whiten + rotate of the fused embedding,
        # (2) the expert's native projector head applied to raw features,
        # (3) fall back to the shared projection computed earlier.
        if has_expert_rot:
            centered = fused.float() - expert_mu[i]
            whitened = centered @ expert_W[i]
            rotated = F.normalize(whitened @ expert_R[i].T, dim=-1)
        elif f"projectors.{i}.proj_native.0.weight" in sd:
            W = sd[f"projectors.{i}.proj_native.0.weight"]
            b = sd[f"projectors.{i}.proj_native.0.bias"]
            lw = sd[f"projectors.{i}.proj_native.1.weight"]
            lb = sd[f"projectors.{i}.proj_native.1.bias"]
            x = val_raw[name] @ W.T + b
            # Manual LayerNorm (eps=1e-5, biased variance) as in project_expert.
            mu_v = x.mean(-1, keepdim=True); var_v = x.var(-1, keepdim=True, unbiased=False)
            x = (x - mu_v) / (var_v + 1e-5).sqrt() * lw + lb
            rotated = F.normalize(x, dim=-1)
        else:
            rotated = projected[i]

        # Per-expert PCA (fit on first 5000 rows), then projected onto S^2.
        rot_np = rotated.numpy()
        rot_c = rot_np - rot_np.mean(axis=0, keepdims=True)
        _, S_r, Vt_r = np.linalg.svd(rot_c[:5000], full_matrices=False)
        rot_3d = to_sphere(rot_np @ Vt_r[:3].T)

        var_exp = S_r[:3]**2 / (S_r**2).sum()
        setup_ax(ax, f'{name[:25]}\nPC variance: {var_exp.sum()*100:.1f}%')
        ax.scatter(rot_3d[:, 0], rot_3d[:, 1], rot_3d[:, 2],
                   c=class_colors, s=2, alpha=0.4)

plt.tight_layout()
plt.savefig("expert_perspectives.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: expert_perspectives.png")
plt.close()
|
|
|
|
| |
| |
| |
|
|
# FIX: restored em dash / multiplication / superscript glyphs in the
# progress message and suptitle (were mojibake "β", "Γ", "SΒ²β΅β΅").
print("Rendering figure 3 \u2014 anchors only...")

# Vote count: how many fused embeddings claim each anchor as their nearest.
with torch.no_grad():
    cos_all = fused @ anchors.T
    nearest_all = cos_all.argmax(dim=-1)
    # bincount replaces the original per-sample Python loop (same counts,
    # C speed); .float() keeps the original dtype for downstream numpy math.
    vc = torch.bincount(nearest_all, minlength=N_ANCHORS).float()
vc_np = vc.numpy()

fig3 = plt.figure(figsize=(24, 8), facecolor='black')
fig3.suptitle(f'Constellation \u2014 {N_ANCHORS} anchors \u00d7 {D_ANCHOR}-d on S\u00b2\u2075\u2075',
              color='white', fontsize=14, y=1.02)
|
|
| |
# Panel 1: anchor utilization — log-scaled vote heat for active anchors,
# blue crosses for dead (never-nearest) anchors.
# FIX: restored the em dash in the panel title (was mojibake "β").
ax_a1 = fig3.add_subplot(131, projection='3d')
setup_ax(ax_a1, f'Anchor Utilization \u2014 {int((vc_np>0).sum())}/{N_ANCHORS} active')
heat = np.zeros(N_ANCHORS)
active_mask = vc_np > 0
heat[active_mask] = np.log1p(vc_np[active_mask])  # compress heavy-hitter counts
heat = heat / (heat.max() + 1e-8)
a_colors = plt.cm.inferno(heat)
a_sizes = 5 + 60 * heat

dead_mask = vc_np == 0
ax_a1.scatter(anchors_s[dead_mask, 0], anchors_s[dead_mask, 1], anchors_s[dead_mask, 2],
              c='dodgerblue', s=8, alpha=0.4, marker='x', label=f'dead ({int(dead_mask.sum())})')
ax_a1.scatter(anchors_s[active_mask, 0], anchors_s[active_mask, 1], anchors_s[active_mask, 2],
              c=a_colors[active_mask], s=a_sizes[active_mask], alpha=0.8)
|
|
| |
# Panel 2: anchor isolation — how close each anchor's nearest neighbor sits.
# FIX: restored the em dash in the panel title (was mojibake "β").
ax_a2 = fig3.add_subplot(132, projection='3d')
anchor_sim = (anchors.numpy() @ anchors.numpy().T)
np.fill_diagonal(anchor_sim, -1)  # exclude self-similarity from the max
max_neighbor_cos = anchor_sim.max(axis=1)
nn_heat = (max_neighbor_cos - max_neighbor_cos.min()) / (max_neighbor_cos.max() - max_neighbor_cos.min() + 1e-8)
nn_colors = plt.cm.viridis(nn_heat)
setup_ax(ax_a2, f'Anchor Isolation \u2014 nearest neighbor cosine\n'
         f'mean={max_neighbor_cos.mean():.3f} max={max_neighbor_cos.max():.3f}')
ax_a2.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
              c=nn_colors, s=15, alpha=0.8)
|
|
| |
# Panel 3: per-anchor disagreement between the three experts' distance rows.
ax_a3 = fig3.add_subplot(133, projection='3d')
with torch.no_grad():
    expert_tri_stack = []
    # Per-expert anchor distances (1 - cosine), using the same preference
    # order as figure 2: learned transforms > native projectors > shared.
    if has_expert_rot:
        for i in range(3):
            centered = fused.float() - expert_mu[i]
            whitened = centered @ expert_W[i]
            rotated = F.normalize(whitened @ expert_R[i].T, dim=-1)
            expert_tri_stack.append(1.0 - (rotated @ anchors.T))
    elif f"projectors.0.proj_native.0.weight" in sd:
        def _pn(feats, i):
            # Native projector head: Linear -> LayerNorm -> L2 normalize.
            W = sd[f"projectors.{i}.proj_native.0.weight"]
            b = sd[f"projectors.{i}.proj_native.0.bias"]
            lw = sd[f"projectors.{i}.proj_native.1.weight"]
            lb = sd[f"projectors.{i}.proj_native.1.bias"]
            x = feats @ W.T + b
            mu = x.mean(-1, keepdim=True); var = x.var(-1, keepdim=True, unbiased=False)
            x = (x - mu) / (var + 1e-5).sqrt() * lw + lb
            return F.normalize(x, dim=-1)
        for i, name in enumerate(EXPERTS):
            nat = _pn(val_raw[name], i)
            expert_tri_stack.append(1.0 - (nat @ anchors.T))
    else:
        for p in projected:
            expert_tri_stack.append(1.0 - (p @ anchors.T))
    # Stack to (N, anchors, 3), take std across experts, mean over images.
    tri_stack = torch.stack(expert_tri_stack, dim=-1)
    per_anchor_div = tri_stack.std(dim=-1).mean(dim=0).numpy()

div_heat = (per_anchor_div - per_anchor_div.min()) / (per_anchor_div.max() - per_anchor_div.min() + 1e-8)
div_colors = plt.cm.coolwarm(div_heat)
setup_ax(ax_a3, f'Expert Divergence per Anchor\n'
         f'mean={per_anchor_div.mean():.4f} range=[{per_anchor_div.min():.4f}, {per_anchor_div.max():.4f}]')
ax_a3.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
              c=div_colors, s=15, alpha=0.8)
|
|
| |
# Overlay the 20 most-similar anchor pairs on the first two panels.
# FIX: the original kept mutating ONE shared `flat_sim` across both panel
# loops, so ax_a1 showed pairs 1-20 while ax_a2 showed pairs 21-40.
# Extract the top-20 pairs once, then draw the SAME edges on both panels.
flat_sim = anchor_sim.copy()
np.fill_diagonal(flat_sim, -999)
top_pairs = []
for _ in range(20):
    idx_flat = np.argmax(flat_sim)
    i_a, j_a = np.unravel_index(idx_flat, flat_sim.shape)
    flat_sim[i_a, j_a] = -999; flat_sim[j_a, i_a] = -999  # knock out both triangles
    top_pairs.append((i_a, j_a))
for panel_ax in [ax_a1, ax_a2]:
    for i_a, j_a in top_pairs:
        panel_ax.plot([anchors_s[i_a, 0], anchors_s[j_a, 0]],
                      [anchors_s[i_a, 1], anchors_s[j_a, 1]],
                      [anchors_s[i_a, 2], anchors_s[j_a, 2]],
                      color='white', alpha=0.15, linewidth=0.5)

plt.tight_layout()
plt.savefig("anchors_only.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: anchors_only.png")
plt.close()
|
|
|
|
| |
| |
| |
|
|
# FIX: restored the em dash in the progress message (was mojibake "β").
print("Rendering figure 4 \u2014 pairwise expert diffs...")

with torch.no_grad():
    # Per-expert anchor-distance rows (1 - cosine), same preference order
    # as figures 2/3: learned transforms > native projectors > shared.
    expert_tris = []

    if has_expert_rot:
        # Learned whiten + rotate of the fused embedding per expert.
        for i in range(3):
            centered = fused.float() - expert_mu[i]
            whitened = centered @ expert_W[i]
            rotated = F.normalize(whitened @ expert_R[i].T, dim=-1)
            tri = 1.0 - (rotated @ anchors.T)
            expert_tris.append(tri)
    elif f"projectors.0.proj_native.0.weight" in sd:
        def _proj_native(feats, i):
            # Native projector head: Linear -> LayerNorm -> L2 normalize.
            W = sd[f"projectors.{i}.proj_native.0.weight"]
            b = sd[f"projectors.{i}.proj_native.0.bias"]
            lw = sd[f"projectors.{i}.proj_native.1.weight"]
            lb = sd[f"projectors.{i}.proj_native.1.bias"]
            x = feats @ W.T + b
            mu = x.mean(-1, keepdim=True); var = x.var(-1, keepdim=True, unbiased=False)
            x = (x - mu) / (var + 1e-5).sqrt() * lw + lb
            return F.normalize(x, dim=-1)
        for i, name in enumerate(EXPERTS):
            native_emb = _proj_native(val_raw[name], i)
            tri = 1.0 - (native_emb @ anchors.T)
            expert_tris.append(tri)
    else:
        # Fallback: shared projections computed earlier.
        for p in projected:
            tri = 1.0 - (p @ anchors.T)
            expert_tris.append(tri)

    # Pairwise differences between expert distance rows.
    diff_cd = expert_tris[0] - expert_tris[1]
    diff_cs = expert_tris[0] - expert_tris[2]
    diff_ds = expert_tris[1] - expert_tris[2]
    diffs = [diff_cd, diff_cs, diff_ds]
    # FIX: labels were mojibake ("CLIP β DINOv2"); restored plain minus signs.
    diff_names = ["CLIP - DINOv2", "CLIP - SigLIP", "DINOv2 - SigLIP"]

    # Baseline: one expert's absolute distances, for the diff/abs comparison.
    abs_tri = expert_tris[0]
|
|
# Console summary: pairwise-diff statistics vs the absolute baseline.
print(f"\n Pairwise diff statistics:")
for name, d in zip(diff_names, diffs):
    print(f" {name:20s}: mean={d.mean():.6f} std={d.std():.6f} "
          f"min={d.min():.6f} max={d.max():.6f}")
print(f" Absolute tri std: {abs_tri.std():.6f}")
diff_std = diffs[0].std().item()
abs_std = abs_tri.std().item()
print(f" Ratio (diff/abs): {diff_std / abs_std:.4f}" if abs_std > 1e-10 else
      f" Ratio (diff/abs): N/A (zero abs std)")

# PCA of the concatenated diff space (all three pairwise diffs per image).
diff_stacked = torch.cat(diffs, dim=-1).numpy()
diff_centered = diff_stacked - diff_stacked.mean(axis=0, keepdims=True)
_, S_diff, Vt_diff = np.linalg.svd(diff_centered[:5000], full_matrices=False)

# Guard the degenerate all-zero case (experts numerically identical).
s_sum = (S_diff**2).sum()
if s_sum > 1e-20:
    diff_3d = to_sphere(diff_centered @ Vt_diff[:3].T)
    var_diff = S_diff[:3]**2 / s_sum
    # Participation-ratio effective dimensionality: 1 / sum(p_i^2).
    eff_dim_diff = float(((S_diff / S_diff.sum())**2).sum()**-1)
else:
    diff_3d = np.zeros((len(diff_centered), 3))
    var_diff = np.zeros(3)
    eff_dim_diff = 0.0
print(f"\n Diff space effective dim: {eff_dim_diff:.1f}")
print(f" Diff PCA 3D variance: {var_diff.sum()*100:.1f}%")

# Same analysis for the absolute distances...
abs_stacked = abs_tri.numpy()
abs_centered = abs_stacked - abs_stacked.mean(axis=0, keepdims=True)
_, S_abs, Vt_abs = np.linalg.svd(abs_centered[:5000], full_matrices=False)
abs_eff = float(((S_abs / S_abs.sum())**2).sum()**-1) if S_abs.sum() > 1e-20 else 0.0
print(f" Absolute tri effective dim: {abs_eff:.1f}")

# ...and for the combined (absolute + diffs) feature space.
full_stacked = np.concatenate([abs_stacked, diff_stacked], axis=-1)
full_centered = full_stacked - full_stacked.mean(axis=0, keepdims=True)
_, S_full, Vt_full = np.linalg.svd(full_centered[:5000], full_matrices=False)
full_eff = float(((S_full / S_full.sum())**2).sum()**-1) if S_full.sum() > 1e-20 else 0.0
full_3d = to_sphere(full_centered @ Vt_full[:3].T) if S_full.sum() > 1e-20 else np.zeros((len(full_centered), 3))
print(f" Full (abs+diffs) effective dim: {full_eff:.1f}")
print(f" Information gain from diffs: {full_eff - abs_eff:.1f} dimensions")
|
|
fig4 = plt.figure(figsize=(28, 14), facecolor='black')
# FIX: restored the em dash in the suptitle (was mojibake "β").
fig4.suptitle(
    'Expert Pairwise Differences \u2014 Where the discriminative signal lives\n'
    f'Diff eff_dim={eff_dim_diff:.1f} | Abs eff_dim={abs_eff:.1f} | '
    f'Combined eff_dim={full_eff:.1f} | Info gain: +{full_eff-abs_eff:.1f} dims',
    color='white', fontsize=14, y=0.98)
|
|
| |
# Top row, panels 1-3: each pair's diff magnitude painted on the fused sphere.
for col, (name, d) in enumerate(zip(diff_names, diffs)):
    ax = fig4.add_subplot(2, 4, col+1, projection='3d')
    d_np = d.numpy()

    # Per-image L2 magnitude of this pair's diff row, min-max normalized.
    diff_mag = np.linalg.norm(d_np, axis=-1)
    mag_heat = (diff_mag - diff_mag.min()) / (diff_mag.max() - diff_mag.min() + 1e-8)
    mag_colors = plt.cm.magma(mag_heat)

    setup_ax(ax, f'{name}\nstd={d_np.std():.5f}')
    ax.scatter(emb_s[:, 0], emb_s[:, 1], emb_s[:, 2],
               c=mag_colors, s=2, alpha=0.5)

# Top row, panel 4: the diff space itself, PCA-projected onto S^2.
ax_dp = fig4.add_subplot(244, projection='3d')
setup_ax(ax_dp, f'Diff Space PCA\neff_dim={eff_dim_diff:.1f} var={var_diff.sum()*100:.1f}%')
ax_dp.scatter(diff_3d[:, 0], diff_3d[:, 1], diff_3d[:, 2],
              c=class_colors, s=2, alpha=0.4)

# Mean absolute disagreement per anchor, per expert pair and averaged.
with torch.no_grad():
    per_anchor_cd = diff_cd.abs().mean(dim=0).numpy()
    per_anchor_cs = diff_cs.abs().mean(dim=0).numpy()
    per_anchor_ds = diff_ds.abs().mean(dim=0).numpy()
    per_anchor_total = (per_anchor_cd + per_anchor_cs + per_anchor_ds) / 3
|
|
| |
# Bottom row, panel 1: anchors heat-colored/sized by total expert divergence.
ax_a = fig4.add_subplot(245, projection='3d')
total_heat = (per_anchor_total - per_anchor_total.min()) / (per_anchor_total.max() - per_anchor_total.min() + 1e-8)
total_colors = plt.cm.hot(total_heat)
total_sizes = 5 + 40 * total_heat
setup_ax(ax_a, f'Anchor Divergence (all pairs)\n'
         f'range=[{per_anchor_total.min():.5f}, {per_anchor_total.max():.5f}]')
ax_a.scatter(anchors_s[:, 0], anchors_s[:, 1], anchors_s[:, 2],
             c=total_colors, s=total_sizes, alpha=0.8)

# Bottom row, panel 2: PCA of the absolute (single-expert) distance rows.
ax_abs = fig4.add_subplot(246, projection='3d')
abs_3d = to_sphere(abs_centered @ Vt_abs[:3].T)
var_abs_3 = S_abs[:3]**2 / (S_abs**2).sum()
setup_ax(ax_abs, f'Absolute Tri PCA\neff_dim={abs_eff:.1f} var={var_abs_3.sum()*100:.1f}%')
ax_abs.scatter(abs_3d[:, 0], abs_3d[:, 1], abs_3d[:, 2],
               c=class_colors, s=2, alpha=0.4)

# Bottom row, panel 3: PCA of the combined absolute + diff features.
ax_full = fig4.add_subplot(247, projection='3d')
var_full_3 = S_full[:3]**2 / (S_full**2).sum()
setup_ax(ax_full, f'Combined (abs+diffs) PCA\neff_dim={full_eff:.1f} var={var_full_3.sum()*100:.1f}%')
ax_full.scatter(full_3d[:, 0], full_3d[:, 1], full_3d[:, 2],
                c=class_colors, s=2, alpha=0.4)
|
|
| |
# Bottom row, panel 4: 2D histogram of per-image diff magnitudes per pair.
ax_hist = fig4.add_subplot(248)
ax_hist.set_facecolor('black')
for name, d, color in zip(diff_names, diffs,
                          ['#ff6b6b', '#4ecdc4', '#ffe66d']):
    d_np = d.numpy()
    per_image_mag = np.linalg.norm(d_np, axis=-1)
    ax_hist.hist(per_image_mag, bins=50, alpha=0.6, color=color,
                 label=name, density=True)
ax_hist.set_xlabel('Diff magnitude (L2)', color='white', fontsize=9)
ax_hist.set_ylabel('Density', color='white', fontsize=9)
ax_hist.set_title('Per-image diff magnitudes', color='white', fontsize=11)
ax_hist.legend(fontsize=8, facecolor='black', edgecolor='gray',
               labelcolor='white')
ax_hist.tick_params(colors='gray', labelsize=7)
ax_hist.spines['bottom'].set_color('gray'); ax_hist.spines['left'].set_color('gray')
ax_hist.spines['top'].set_visible(False); ax_hist.spines['right'].set_visible(False)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("pairwise_diffs.png", dpi=200, facecolor='black',
            bbox_inches='tight', pad_inches=0.3)
print("Saved: pairwise_diffs.png")
plt.close()

print("\nDone.")