# geolip-deep-embedding-analysis / cv_sweep_8x8_scaling_65k_runs.py
# (Hugging Face listing artifact: "AbstractPhil's picture / Create cv_sweep_8x8_scaling_65k_runs.py / 9c5e7b8 verified")
"""
CM CV Full Sweep β€” 8Γ—8 grid from 8 to 2048
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import json
import time
def cayley_menger_vol2(points):
B, N, D = points.shape
gram = torch.bmm(points, points.transpose(1, 2))
norms = torch.diagonal(gram, dim1=1, dim2=2)
d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
d2 = F.relu(d2)
cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype)
cm[:, 0, 1:] = 1.0
cm[:, 1:, 0] = 1.0
cm[:, 1:, 1:] = d2
k = N - 1
sign = (-1.0) ** (k + 1)
fact = math.factorial(k)
return sign * torch.linalg.det(cm.float()).to(points.dtype) / ((2 ** k) * (fact ** 2))
def cv_metric(weight, n_samples=300, n_points=5):
    """Coefficient of variation of simplex volumes sampled from `weight` rows.

    Draws `n_samples` random `n_points`-subsets (without replacement, from at
    most the first 512 rows), computes each subset's simplex volume, and
    returns std/mean over the numerically valid volumes. Returns None when
    there are too few rows or too few valid volumes.
    """
    n_rows, _ = weight.shape
    if n_rows < n_points:
        return None  # cannot form an (n_points-1)-simplex
    pool = min(n_rows, 512)  # cap candidate rows for speed
    # One fresh permutation per sample -> uniform subsets without replacement.
    sample_idx = torch.stack([
        torch.randperm(pool, device=weight.device)[:n_points]
        for _ in range(n_samples)
    ])
    subsets = weight[:pool][sample_idx]
    sq_vols = cayley_menger_vol2(subsets)
    keep = sq_vols > 1e-20  # discard degenerate / numerically-zero simplices
    if keep.sum() < 10:
        return None  # too few valid samples for a stable statistic
    vols = sq_vols[keep].sqrt()
    # Small epsilon guards against division by a near-zero mean.
    return (vols.std() / (vols.mean() + 1e-8)).item()
if __name__ == "__main__":
    # ── Sweep configuration: square (vocab, dim) grid, step 8 ──
    STEP = 8
    LOW, HIGH = 8, 2048
    BAND_LO, BAND_HI = 0.13, 0.30  # target CV band
    dims = list(range(LOW, HIGH + 1, STEP))
    vocabs = list(range(LOW, HIGH + 1, STEP))
    total = len(dims) * len(vocabs)
    print(f"CM CV Sweep: {len(vocabs)} vocabs × {len(dims)} dims = {total} configs")
    print(f"Band: {BAND_LO} < CV < {BAND_HI}")
    print("=" * 70)
    all_results = []
    band_results = []
    t0 = time.time()
    for i, vocab in enumerate(vocabs):
        for dim in dims:
            emb = nn.Embedding(vocab, dim)  # default init; CV measured on fresh weights
            with torch.no_grad():
                cv = cv_metric(emb.weight)
            if cv is not None:
                in_band = BAND_LO < cv < BAND_HI
                entry = {"V": vocab, "D": dim, "CV": round(cv, 4), "in_band": in_band}
                all_results.append(entry)
                if in_band:
                    band_results.append(entry)
            else:
                # Record failed configs too so the grid stays complete in the JSON.
                all_results.append({"V": vocab, "D": dim, "CV": None, "in_band": False})
        elapsed = time.time() - t0
        pct = (i + 1) / len(vocabs) * 100
        # Progress line every 16 vocab rows, plus first and last.
        if (i + 1) % 16 == 0 or i == 0 or i == len(vocabs) - 1:
            print(f"V={vocab:4d} done | {pct:.0f}% | {elapsed:.0f}s")
    # ── Save JSON ──
    output = {
        "sweep": {"step": STEP, "low": LOW, "high": HIGH},
        "band": {"lo": BAND_LO, "hi": BAND_HI},
        "band_results": sorted(band_results, key=lambda x: x["CV"]),
        "all_results": all_results,
    }
    with open("cm_cv_sweep_8x8.json", "w") as f:
        json.dump(output, f, indent=2)
    # ── Summary by D ──
    print()
    print("=" * 70)
    print(f"BAND VALID ({BAND_LO} < CV < {BAND_HI}): {len(band_results)} / {total}")
    print("=" * 70)
    by_dim = {}
    for r in band_results:
        by_dim.setdefault(r["D"], []).append(r)
    for d in sorted(by_dim.keys()):
        entries = by_dim[d]
        v_range = f"V={min(e['V'] for e in entries)}-{max(e['V'] for e in entries)}"
        cv_range = f"CV={min(e['CV'] for e in entries):.4f}-{max(e['CV'] for e in entries):.4f}"
        print(f"  D={d:4d}: {len(entries):3d} configs  {v_range:20s}  {cv_range}")
    # ── Band boundaries ──
    print()
    print("=" * 70)
    print("Band boundaries (CV at each D, averaged across all V)")
    print("=" * 70)
    by_dim_all = {}
    for r in all_results:
        if r["CV"] is not None:
            by_dim_all.setdefault(r["D"], []).append(r["CV"])
    for d in sorted(by_dim_all.keys()):
        cvs = by_dim_all[d]
        avg = sum(cvs) / len(cvs)
        mn, mx = min(cvs), max(cvs)
        marker = "  <-- IN BAND" if BAND_LO < avg < BAND_HI else ""
        if d <= 256:  # keep the console table readable; full data is in the JSON
            print(f"  D={d:4d}: avg={avg:.4f}  min={mn:.4f}  max={mx:.4f}{marker}")
    # ── Ratios ──
    print()
    print("=" * 70)
    print("Unique V/D ratios for band-valid configs:")
    print("=" * 70)
    ratios = sorted(set(round(r["V"] / r["D"], 2) for r in band_results))
    print(f"  Count: {len(ratios)}")
    if ratios:  # guard: band can be empty, indexing would raise IndexError
        print(f"  Min ratio: {ratios[0]}")
        print(f"  Max ratio: {ratios[-1]}")
    print()
    print("Results saved to cm_cv_sweep_8x8.json")
    print(f"Total time: {time.time() - t0:.1f}s")