| """ |
| CM CV Full Sweep — 8×8 grid from 8 to 2048 |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| import json |
| import time |
|
|
|
|
def cayley_menger_vol2(points):
    """Return squared simplex volumes via the Cayley-Menger determinant.

    Args:
        points: Tensor of shape ``(B, N, D)`` -- ``B`` simplices, each
            described by ``N`` vertices with ``D`` coordinates.

    Returns:
        Tensor of shape ``(B,)`` in ``points.dtype`` holding
        ``(-1)**(k + 1) * det(CM) / (2**k * (k!)**2)`` with ``k = N - 1``.
        The value is a *squared* volume; for degenerate vertex sets it may
        come out zero or slightly negative due to rounding, so callers
        should filter before taking a square root.

    Note:
        The determinant is always evaluated in float32 and cast back to
        ``points.dtype`` afterwards.
    """
    batch, n_pts, _ = points.shape

    # Pairwise squared distances from the Gram matrix:
    # |a - b|^2 = |a|^2 + |b|^2 - 2 <a, b>
    gram = torch.bmm(points, points.transpose(1, 2))
    sq_norms = torch.diagonal(gram, dim1=1, dim2=2)
    d2 = sq_norms.unsqueeze(2) + sq_norms.unsqueeze(1) - 2 * gram
    # Cancellation can leave tiny negative entries; clamp them to zero.
    d2 = F.relu(d2)

    # Bordered distance matrix: zero corner, a border of ones, d2 inside.
    cm = torch.zeros(
        batch, n_pts + 1, n_pts + 1, device=points.device, dtype=points.dtype
    )
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2

    k = n_pts - 1
    sign = (-1.0) ** (k + 1)
    # Keep the normaliser as a float: as a Python int it outgrows int64 once
    # k >= 12, and torch refuses to combine a tensor with such an int scalar.
    denom = float((2 ** k) * math.factorial(k) ** 2)

    det = torch.linalg.det(cm.float()).to(points.dtype)
    return sign * det / denom
|
|
|
|
def cv_metric(weight, n_samples=300, n_points=5):
    """Coefficient of variation of random simplex volumes among weight rows.

    Draws ``n_samples`` random subsets of ``n_points`` rows, restricted to
    the first ``min(V, 512)`` rows of ``weight``, computes each subset's
    squared volume with ``cayley_menger_vol2``, and returns ``std / mean``
    of the volumes that are usable.

    Args:
        weight: 2-D tensor of shape ``(V, D)``.
        n_samples: Number of random subsets to draw.
        n_points: Number of rows per subset.

    Returns:
        The coefficient of variation as a Python float, or ``None`` when
        ``weight`` has fewer than ``n_points`` rows or fewer than 10 sampled
        subsets have a squared volume above ``1e-20``.
    """
    n_rows, _ = weight.shape
    if n_rows < n_points:
        return None

    # Only the leading rows (at most 512) are candidates for sampling.
    pool = min(n_rows, 512)
    candidates = weight[:pool]

    # One independent permutation per sample; keep its first n_points ids.
    draws = []
    for _ in range(n_samples):
        perm = torch.randperm(pool, device=weight.device)
        draws.append(perm[:n_points])
    indices = torch.stack(draws)

    vol2 = cayley_menger_vol2(candidates[indices])

    # Discard degenerate / numerically non-positive squared volumes.
    keep = vol2 > 1e-20
    if keep.sum() < 10:
        return None

    vols = torch.sqrt(vol2[keep])
    spread = vols.std()
    centre = vols.mean() + 1e-8  # epsilon guards against division by zero
    return (spread / centre).item()
|
|
|
|
if __name__ == "__main__":
    # Sweep every (vocab size, embedding dim) pair on a LOW..HIGH grid with
    # spacing STEP, measure cv_metric on a freshly constructed nn.Embedding,
    # dump everything to JSON, and print summaries of the configurations
    # whose CV falls strictly inside (BAND_LO, BAND_HI).
    STEP = 8
    LOW, HIGH = 8, 2048
    BAND_LO, BAND_HI = 0.13, 0.30
    OUT_PATH = "cm_cv_sweep_8x8.json"
    RULE = "=" * 70

    dims = list(range(LOW, HIGH + 1, STEP))
    vocabs = list(range(LOW, HIGH + 1, STEP))

    total = len(dims) * len(vocabs)
    print(f"CM CV Sweep: {len(vocabs)} vocabs × {len(dims)} dims = {total} configs")
    print(f"Band: {BAND_LO} < CV < {BAND_HI}")
    print(RULE)

    all_results = []
    band_results = []
    t0 = time.time()

    for i, vocab in enumerate(vocabs):
        for dim in dims:
            emb = nn.Embedding(vocab, dim)
            with torch.no_grad():
                cv = cv_metric(emb.weight)

            if cv is None:
                # cv_metric could not produce a value; record a placeholder.
                all_results.append(
                    {"V": vocab, "D": dim, "CV": None, "in_band": False}
                )
                continue

            # Band membership is decided on the unrounded value.
            in_band = BAND_LO < cv < BAND_HI
            entry = {"V": vocab, "D": dim, "CV": round(cv, 4), "in_band": in_band}
            all_results.append(entry)
            if in_band:
                band_results.append(entry)

        # Progress line for the first, every 16th, and the last vocab size.
        elapsed = time.time() - t0
        pct = (i + 1) / len(vocabs) * 100
        if (i + 1) % 16 == 0 or i == 0 or i == len(vocabs) - 1:
            print(f"V={vocab:4d} done | {pct:.0f}% | {elapsed:.0f}s")

    # Persist the raw sweep before printing any summary.
    output = {
        "sweep": {"step": STEP, "low": LOW, "high": HIGH},
        "band": {"lo": BAND_LO, "hi": BAND_HI},
        "band_results": sorted(band_results, key=lambda x: x["CV"]),
        "all_results": all_results,
    }

    with open(OUT_PATH, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    # Summary 1: band-valid configs grouped by embedding dim.
    print()
    print(RULE)
    print(f"BAND VALID ({BAND_LO} < CV < {BAND_HI}): {len(band_results)} / {total}")
    print(RULE)

    by_dim = {}
    for r in band_results:
        by_dim.setdefault(r["D"], []).append(r)

    for d in sorted(by_dim):
        entries = by_dim[d]
        v_vals = [e["V"] for e in entries]
        cv_vals = [e["CV"] for e in entries]
        v_range = f"V={min(v_vals)}-{max(v_vals)}"
        cv_range = f"CV={min(cv_vals):.4f}-{max(cv_vals):.4f}"
        print(f" D={d:4d}: {len(entries):3d} configs {v_range:20s} {cv_range}")

    # Summary 2: per-dim CV statistics over every vocab size with a value.
    print()
    print(RULE)
    print("Band boundaries (CV at each D, averaged across all V)")
    print(RULE)
    by_dim_all = {}
    for r in all_results:
        if r["CV"] is not None:
            by_dim_all.setdefault(r["D"], []).append(r["CV"])

    for d in sorted(by_dim_all):
        if d > 256:
            # Only dims up to 256 are shown in this table.
            continue
        cvs = by_dim_all[d]
        avg = sum(cvs) / len(cvs)
        mn, mx = min(cvs), max(cvs)
        marker = " <-- IN BAND" if BAND_LO < avg < BAND_HI else ""
        print(f" D={d:4d}: avg={avg:.4f} min={mn:.4f} max={mx:.4f}{marker}")

    # Summary 3: distinct V/D ratios (rounded to 2 decimals) inside the band.
    print()
    print(RULE)
    print("Unique V/D ratios for band-valid configs:")
    print(RULE)
    ratios = sorted({round(r["V"] / r["D"], 2) for r in band_results})

    print(f" Count: {len(ratios)}")
    if ratios:
        print(f" Min ratio: {ratios[0]}")
        print(f" Max ratio: {ratios[-1]}")
    else:
        # No configuration landed in the band, so there is no min/max to
        # report; indexing the empty list would raise IndexError.
        print(" Min ratio: n/a")
        print(" Max ratio: n/a")

    print()
    print(f"Results saved to {OUT_PATH}")
    print(f"Total time: {time.time() - t0:.1f}s")