| """ |
| Coefficient of variation (CV) of random Cayley-Menger simplex volumes at D=32, |
| across very large vocabulary sizes V. Does V matter at scale? Hypothesis: no. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| import time |
|
|
|
|
def cayley_menger_vol2(points):
    """Return the squared volume of the simplex spanned by each set of points.

    Uses the Cayley-Menger determinant. For ``N`` points (a ``k = N - 1``
    dimensional simplex)::

        vol^2 = (-1)^(k+1) / (2^k * (k!)^2) * det(CM)

    where ``CM`` is the ``(N+1) x (N+1)`` matrix of pairwise squared
    distances, bordered by a row and column of ones with a zero corner.

    Args:
        points: tensor of shape ``(B, N, D)``, i.e. ``B`` independent sets
            of ``N`` points in ``D`` dimensions.

    Returns:
        Tensor of shape ``(B,)`` holding the squared volumes, in
        ``points.dtype``. No clamping is applied, so degenerate simplices
        may come out as tiny (possibly negative) values from rounding.
    """
    B, N, _ = points.shape
    # Work in at least float32 (half/bfloat16 are too coarse for det and
    # unsupported by it on some backends), but never downcast float64.
    compute_dtype = torch.promote_types(points.dtype, torch.float32)
    pts = points.to(compute_dtype)

    gram = torch.bmm(pts, pts.transpose(1, 2))
    norms = torch.diagonal(gram, dim1=1, dim2=2)
    # ||x_i - x_j||^2 = |x_i|^2 + |x_j|^2 - 2 <x_i, x_j>; relu clips
    # small negative values caused by rounding.
    d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)

    cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=compute_dtype)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2

    k = N - 1
    sign = (-1.0) ** (k + 1)
    denom = (2 ** k) * (math.factorial(k) ** 2)
    # Normalise in the compute dtype, then cast back to the caller's dtype once.
    vol2 = sign * torch.linalg.det(cm) / denom
    return vol2.to(points.dtype)
|
|
|
|
def cv_metric(weight, n_samples=500, n_points=5, pool_size=512, *,
              min_valid=10, generator=None):
    """Return the coefficient of variation of random simplex volumes in ``weight``.

    Draws ``n_samples`` random subsets of ``n_points`` distinct rows from the
    first ``min(V, pool_size)`` rows of ``weight``. It computes the volume of
    the simplex each subset spans (via ``cayley_menger_vol2``) and returns
    ``std / mean`` of the non-degenerate volumes.

    Args:
        weight: 2-D tensor of shape ``(V, D)``; each row is treated as a point.
        n_samples: number of random simplices to draw.
        n_points: points per simplex (a simplex of dimension ``n_points - 1``).
        pool_size: only the first ``pool_size`` rows of ``weight`` are sampled.
        min_valid: minimum number of non-degenerate simplices required to
            report a CV.
        generator: optional ``torch.Generator`` for reproducible sampling.
            It presumably needs to live on ``weight.device``.

    Returns:
        The CV as a Python float, or ``None`` when fewer than ``min_valid``
        simplices have squared volume above ``1e-20``.

    Raises:
        ValueError: if the pool holds fewer than ``n_points`` rows. In that
            case ``n_points`` distinct rows cannot be drawn, and the result
            would silently describe a lower-dimensional simplex.
    """
    V, _ = weight.shape
    pool = min(V, pool_size)
    if pool < n_points:
        raise ValueError(
            f"need at least n_points={n_points} rows to sample from, got pool={pool}"
        )
    # One independent permutation per sample, so each simplex uses
    # n_points distinct rows.
    indices = torch.stack([
        torch.randperm(pool, device=weight.device, generator=generator)[:n_points]
        for _ in range(n_samples)
    ])
    pts = weight[:pool][indices]  # (n_samples, n_points, D)
    vol2 = cayley_menger_vol2(pts)
    # Drop degenerate (or rounding-negative) simplices before taking sqrt.
    valid = vol2 > 1e-20
    if valid.sum() < min_valid:
        return None
    vols = vol2[valid].sqrt()
    # Tensor.std() is Bessel-corrected; the 1e-8 guards against a zero mean.
    return (vols.std() / (vols.mean() + 1e-8)).item()
|
|
|
|
if __name__ == "__main__":
    D = 32
    vocabs = [32, 512, 8192, 65536, 131072, 500000, 1000000, 4000000, 13000000]

    print(f"D={D} fixed. CV across vocab sizes.")
    print("Pool capped at 512 for fair comparison.")
    print("=" * 60)

    # NOTE(review): with the default pool_size=512, cv_metric only samples
    # from weight[:512]. For every V the sampled rows are therefore the same
    # kind of torch.randn draws, and line-to-line CV differences reflect
    # sampling noise rather than V. The full (V, D) matrix is still generated:
    # the time column includes that generation, and the MB column reports
    # its size.
    for V in vocabs:
        t0 = time.perf_counter()
        weight = torch.randn(V, D)
        cv = cv_metric(weight, n_samples=500)
        elapsed = time.perf_counter() - t0
        mem_mb = weight.numel() * weight.element_size() / 1e6
        # Free this matrix before the next (larger) one is allocated, so peak
        # memory holds one matrix instead of two.
        del weight
        # cv_metric returns None when too few simplices are non-degenerate.
        cv_str = "n/a" if cv is None else f"{cv:.4f}"
        print(f" V={V:>10,} D={D} CV={cv_str} {elapsed:.1f}s {mem_mb:.0f}MB")

    print()
    print("=" * 60)
    print("Now uncapped pool (sample from ALL embeddings):")
    print("=" * 60)

    for V in [512, 8192, 65536, 500000]:
        weight = torch.randn(V, D)
        cv = cv_metric(weight, n_samples=500, pool_size=V)
        del weight
        cv_str = "n/a" if cv is None else f"{cv:.4f}"
        print(f" V={V:>10,} D={D} CV={cv_str} pool={V}")