""" CM CV Full Sweep — 8×8 grid from 8 to 2048 """ import torch import torch.nn as nn import torch.nn.functional as F import math import json import time def cayley_menger_vol2(points): B, N, D = points.shape gram = torch.bmm(points, points.transpose(1, 2)) norms = torch.diagonal(gram, dim1=1, dim2=2) d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram d2 = F.relu(d2) cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype) cm[:, 0, 1:] = 1.0 cm[:, 1:, 0] = 1.0 cm[:, 1:, 1:] = d2 k = N - 1 sign = (-1.0) ** (k + 1) fact = math.factorial(k) return sign * torch.linalg.det(cm.float()).to(points.dtype) / ((2 ** k) * (fact ** 2)) def cv_metric(weight, n_samples=300, n_points=5): V, D = weight.shape if V < n_points: return None pool = min(V, 512) indices = torch.stack([ torch.randperm(pool, device=weight.device)[:n_points] for _ in range(n_samples) ]) pts = weight[:pool][indices] vol2 = cayley_menger_vol2(pts) valid = vol2 > 1e-20 if valid.sum() < 10: return None vols = vol2[valid].sqrt() return (vols.std() / (vols.mean() + 1e-8)).item() if __name__ == "__main__": STEP = 8 LOW, HIGH = 8, 2048 BAND_LO, BAND_HI = 0.13, 0.30 dims = list(range(LOW, HIGH + 1, STEP)) vocabs = list(range(LOW, HIGH + 1, STEP)) total = len(dims) * len(vocabs) print(f"CM CV Sweep: {len(vocabs)} vocabs × {len(dims)} dims = {total} configs") print(f"Band: {BAND_LO} < CV < {BAND_HI}") print("=" * 70) all_results = [] band_results = [] t0 = time.time() for i, vocab in enumerate(vocabs): for dim in dims: emb = nn.Embedding(vocab, dim) with torch.no_grad(): cv = cv_metric(emb.weight) if cv is not None: in_band = BAND_LO < cv < BAND_HI entry = {"V": vocab, "D": dim, "CV": round(cv, 4), "in_band": in_band} all_results.append(entry) if in_band: band_results.append(entry) else: all_results.append({"V": vocab, "D": dim, "CV": None, "in_band": False}) elapsed = time.time() - t0 pct = (i + 1) / len(vocabs) * 100 if (i + 1) % 16 == 0 or i == 0 or i == len(vocabs) - 1: print(f"V={vocab:4d} done | {pct:.0f}% | {elapsed:.0f}s") # ── Save JSON ── output = { "sweep": {"step": STEP, "low": LOW, "high": HIGH}, "band": {"lo": BAND_LO, "hi": BAND_HI}, "band_results": sorted(band_results, key=lambda x: x["CV"]), "all_results": all_results, } with open("cm_cv_sweep_8x8.json", "w") as f: json.dump(output, f, indent=2) # ── Summary by D ── print() print("=" * 70) print(f"BAND VALID ({BAND_LO} < CV < {BAND_HI}): {len(band_results)} / {total}") print("=" * 70) by_dim = {} for r in band_results: d = r["D"] if d not in by_dim: by_dim[d] = [] by_dim[d].append(r) for d in sorted(by_dim.keys()): entries = by_dim[d] v_range = f"V={min(e['V'] for e in entries)}-{max(e['V'] for e in entries)}" cv_range = f"CV={min(e['CV'] for e in entries):.4f}-{max(e['CV'] for e in entries):.4f}" print(f" D={d:4d}: {len(entries):3d} configs {v_range:20s} {cv_range}") # ── Band boundaries ── print() print("=" * 70) print("Band boundaries (CV at each D, averaged across all V)") print("=" * 70) by_dim_all = {} for r in all_results: if r["CV"] is not None: d = r["D"] if d not in by_dim_all: by_dim_all[d] = [] by_dim_all[d].append(r["CV"]) for d in sorted(by_dim_all.keys()): cvs = by_dim_all[d] avg = sum(cvs) / len(cvs) mn, mx = min(cvs), max(cvs) marker = " <-- IN BAND" if BAND_LO < avg < BAND_HI else "" if d <= 256: print(f" D={d:4d}: avg={avg:.4f} min={mn:.4f} max={mx:.4f}{marker}") # ── Ratios ── print() print("=" * 70) print("Unique V/D ratios for band-valid configs:") print("=" * 70) ratios = sorted(set(round(r["V"] / r["D"], 2) for r in band_results)) # Show range print(f" Count: {len(ratios)}") print(f" Min ratio: {ratios[0]}") print(f" Max ratio: {ratios[-1]}") print() print(f"Results saved to cm_cv_sweep_8x8.json") print(f"Total time: {time.time() - t0:.1f}s")