# geolip-deep-embedding-analysis / big_vocabulary_tests.py
# Author: AbstractPhil — commit 185899c (verified): "Create big_vocabulary_tests.py"
"""
CV at D=32 with absurd vocabulary sizes.
Does V matter at scale? We say no.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
def cayley_menger_vol2(points):
    """Return the squared volume of each simplex in a batch (Cayley-Menger determinant).

    Args:
        points: Tensor of shape (B, N, D): B simplices, each with N vertices in
            D-dimensional space. N vertices span an (N-1)-dimensional simplex.

    Returns:
        Tensor of shape (B,) in ``points.dtype`` holding the squared
        (N-1)-dimensional volume of each simplex. Degenerate simplices give
        values near zero, possibly slightly negative from round-off.
    """
    B, N, D = points.shape
    # Pairwise squared distances from the Gram matrix: |a-b|^2 = |a|^2 + |b|^2 - 2<a,b>.
    gram = torch.bmm(points, points.transpose(1, 2))
    norms = torch.diagonal(gram, dim1=1, dim2=2)
    # relu clamps tiny negative distances produced by floating-point cancellation.
    d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)

    # Bordered distance matrix [[0, 1^T], [1, D2]].
    cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2

    k = N - 1  # simplex dimension
    sign = (-1.0) ** (k + 1)
    fact = math.factorial(k)

    # torch.linalg.det only accepts float32/float64 (and complex) inputs, so
    # low-precision dtypes are promoted to float32. float64 inputs keep full
    # precision instead of being truncated to float32.
    det_dtype = torch.float64 if points.dtype == torch.float64 else torch.float32
    det = torch.linalg.det(cm.to(det_dtype)).to(points.dtype)
    return sign * det / ((2 ** k) * (fact ** 2))
def cv_metric(weight, n_samples=500, n_points=5, pool_size=512):
    """Return the coefficient of variation of random simplex volumes drawn from ``weight``.

    Samples ``n_samples`` simplices, each made of ``n_points`` distinct rows taken
    from the first ``min(V, pool_size)`` rows of ``weight``. Returns
    ``std(volume) / mean(volume)`` over the non-degenerate ones.

    Args:
        weight: 2-D tensor of shape (V, D), e.g. an embedding matrix.
        n_samples: Number of simplices to sample.
        n_points: Vertices per simplex (an ``n_points - 1``-dimensional simplex).
        pool_size: Only the first ``min(V, pool_size)`` rows are sampled from.

    Returns:
        The CV as a Python float, or ``None`` if fewer than 10 sampled simplices
        have squared volume above 1e-20.

    Raises:
        ValueError: if the pool has fewer than ``n_points`` rows, so simplices
            with ``n_points`` distinct vertices cannot be formed.
    """
    V, _ = weight.shape
    pool = min(V, pool_size)
    if pool < n_points:
        # randperm(pool)[:n_points] would silently return only `pool` indices,
        # measuring a lower-dimensional simplex than requested.
        raise ValueError(
            f"need at least n_points={n_points} rows to sample from, got pool={pool}"
        )
    # Pure measurement: don't record an autograd graph if `weight` is trainable.
    with torch.no_grad():
        # An independent permutation per sample keeps vertices distinct within a simplex.
        indices = torch.stack([
            torch.randperm(pool, device=weight.device)[:n_points]
            for _ in range(n_samples)
        ])
        pts = weight[:pool][indices]  # (n_samples, n_points, D)
        vol2 = cayley_menger_vol2(pts)
        # Drop degenerate (or round-off-negative) simplices before taking sqrt.
        valid = vol2 > 1e-20
        if valid.sum() < 10:
            return None
        vols = vol2[valid].sqrt()
        return (vols.std() / (vols.mean() + 1e-8)).item()
if __name__ == "__main__":
    # Sweep vocabulary size V at fixed width D and report the simplex-volume CV.
    torch.manual_seed(0)  # reproducible sweep
    D = 32
    POOL = 512
    vocabs = [32, 512, 8192, 65536, 131072, 500000, 1000000, 4000000, 13000000]

    def _fmt_cv(cv):
        # cv_metric returns None when too few sampled simplices are non-degenerate.
        return "n/a" if cv is None else f"{cv:.4f}"

    print(f"D={D} fixed. CV across vocab sizes.")
    print(f"Pool capped at {POOL} for fair comparison.")
    print("=" * 60)
    for V in vocabs:
        t0 = time.perf_counter()
        # Use raw tensor instead of nn.Embedding for huge sizes
        weight = torch.randn(V, D)
        cv = cv_metric(weight, n_samples=500, pool_size=POOL)
        elapsed = time.perf_counter() - t0  # includes the randn allocation
        mem_mb = weight.numel() * weight.element_size() / 1e6
        print(f" V={V:>10,} D={D} CV={_fmt_cv(cv)} {elapsed:.1f}s {mem_mb:.0f}MB")
        # Free this table before allocating the next (larger) one, so peak memory
        # holds one table rather than two.
        del weight

    # Also uncap the pool for the big ones
    print()
    print("=" * 60)
    print("Now uncapped pool (sample from ALL embeddings):")
    print("=" * 60)
    for V in [512, 8192, 65536, 500000]:
        weight = torch.randn(V, D)
        cv = cv_metric(weight, n_samples=500, pool_size=V)
        print(f" V={V:>10,} D={D} CV={_fmt_cv(cv)} pool={V}")
        del weight