AbstractPhil committed on
Commit
9c5e7b8
·
verified ·
1 Parent(s): 6fb2607

Create cv_sweep_8x8_scaling_65k_runs.py

Browse files
Files changed (1) hide show
  1. cv_sweep_8x8_scaling_65k_runs.py +147 -0
cv_sweep_8x8_scaling_65k_runs.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CM CV Full Sweep — 8×8 grid from 8 to 2048
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import math
9
+ import json
10
+ import time
11
+
12
+
13
def cayley_menger_vol2(points):
    """Return the squared simplex volume for each batch of vertices.

    Uses the Cayley-Menger determinant: for a simplex with ``N`` vertices
    (``k = N - 1``), ``vol^2 = (-1)^(k+1) * det(CM) / (2^k * (k!)^2)``, where
    ``CM`` is the squared-distance matrix bordered by a row and column of ones.

    Args:
        points: Tensor of shape ``(B, N, D)``; ``B`` simplices, each given by
            ``N`` vertices with ``D`` coordinates.

    Returns:
        Tensor of shape ``(B,)`` in ``points.dtype`` holding the squared
        volume of every simplex. Degenerate simplices give values near zero,
        and round-off may make them slightly negative.
    """
    B, N, _ = points.shape

    # Squared pairwise distances from the Gram matrix:
    # |a - b|^2 = |a|^2 + |b|^2 - 2 <a, b>
    gram = torch.bmm(points, points.transpose(1, 2))
    norms = torch.diagonal(gram, dim1=1, dim2=2)
    d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
    # Round-off can push tiny distances below zero; clamp them.
    d2 = F.relu(d2)

    # Bordered matrix: zero corner, ones on the first row/column, d2 inside.
    cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2

    k = N - 1
    sign = (-1.0) ** (k + 1)
    fact = math.factorial(k)

    # Run the determinant in at least float32 (half/bfloat16 are upcast), but
    # never downcast float64 input: the determinant is cancellation-heavy and
    # forcing float32 would throw away the caller's precision.
    compute_dtype = torch.promote_types(points.dtype, torch.float32)
    det = torch.linalg.det(cm.to(compute_dtype))
    return sign * det.to(points.dtype) / ((2 ** k) * (fact ** 2))
27
+
28
+
29
def cv_metric(weight, n_samples=300, n_points=5, pool_size=512, min_valid=10):
    """Estimate the coefficient of variation of random simplex volumes.

    Draws ``n_samples`` random subsets of ``n_points`` rows from the first
    ``pool_size`` rows of ``weight``, computes each subset's simplex volume
    with ``cayley_menger_vol2``, and returns ``std / mean`` over the
    non-degenerate volumes.

    Args:
        weight: 2-D tensor of shape ``(V, D)``, one row per embedding.
        n_samples: Number of random simplices to draw.
        n_points: Number of vertices per simplex.
        pool_size: Only the first ``pool_size`` rows are eligible for
            sampling (defaults to the previously hard-coded 512).
        min_valid: Minimum number of non-degenerate simplices required for a
            result (defaults to the previously hard-coded 10). Should be at
            least 2, since the standard deviation of one value is NaN.

    Returns:
        The coefficient of variation as a Python float, or ``None`` when the
        pool holds fewer than ``n_points`` rows, when ``n_samples`` cannot
        reach ``min_valid``, or when too few simplices have a squared volume
        above ``1e-20``.
    """
    V, _ = weight.shape
    pool = min(V, pool_size)

    # A simplex needs n_points distinct rows; slicing a shorter permutation
    # would silently produce simplices with fewer vertices than requested.
    if pool < n_points:
        return None
    # No draw can satisfy the validity threshold (also avoids stacking an
    # empty list, which raises).
    if n_samples <= 0 or n_samples < min_valid:
        return None

    # One independent permutation per sample gives distinct rows per simplex.
    indices = torch.stack([
        torch.randperm(pool, device=weight.device)[:n_points]
        for _ in range(n_samples)
    ])
    pts = weight[:pool][indices]  # (n_samples, n_points, D)

    vol2 = cayley_menger_vol2(pts)
    # Drop degenerate simplices and negative round-off before the sqrt.
    valid = vol2 > 1e-20
    if valid.sum() < min_valid:
        return None

    vols = vol2[valid].sqrt()
    # Small epsilon keeps the ratio finite when the mean volume is tiny.
    return (vols.std() / (vols.mean() + 1e-8)).item()
45
+
46
+
47
if __name__ == "__main__":
    # Sweep every (vocab, dim) pair on a STEP-spaced grid from LOW to HIGH,
    # measure the CV of a freshly initialised nn.Embedding, and record which
    # configurations fall strictly inside the (BAND_LO, BAND_HI) band.
    STEP = 8
    LOW, HIGH = 8, 2048
    BAND_LO, BAND_HI = 0.13, 0.30

    dims = list(range(LOW, HIGH + 1, STEP))
    vocabs = list(range(LOW, HIGH + 1, STEP))

    total = len(dims) * len(vocabs)
    print(f"CM CV Sweep: {len(vocabs)} vocabs × {len(dims)} dims = {total} configs")
    print(f"Band: {BAND_LO} < CV < {BAND_HI}")
    print("=" * 70)

    all_results = []
    band_results = []
    t0 = time.time()

    for i, vocab in enumerate(vocabs):
        for dim in dims:
            emb = nn.Embedding(vocab, dim)
            with torch.no_grad():
                cv = cv_metric(emb.weight)
            if cv is not None:
                in_band = BAND_LO < cv < BAND_HI
                entry = {"V": vocab, "D": dim, "CV": round(cv, 4), "in_band": in_band}
                all_results.append(entry)
                if in_band:
                    band_results.append(entry)
            else:
                # cv_metric could not produce a value; keep a placeholder so
                # every grid cell is present in the output.
                all_results.append({"V": vocab, "D": dim, "CV": None, "in_band": False})

        # Progress line on the first, every 16th, and the last vocab row.
        elapsed = time.time() - t0
        pct = (i + 1) / len(vocabs) * 100
        if (i + 1) % 16 == 0 or i == 0 or i == len(vocabs) - 1:
            print(f"V={vocab:4d} done | {pct:.0f}% | {elapsed:.0f}s")

    # -- Save JSON --
    output = {
        "sweep": {"step": STEP, "low": LOW, "high": HIGH},
        "band": {"lo": BAND_LO, "hi": BAND_HI},
        "band_results": sorted(band_results, key=lambda x: x["CV"]),
        "all_results": all_results,
    }

    with open("cm_cv_sweep_8x8.json", "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    # -- Summary by D --
    print()
    print("=" * 70)
    print(f"BAND VALID ({BAND_LO} < CV < {BAND_HI}): {len(band_results)} / {total}")
    print("=" * 70)

    by_dim = {}
    for r in band_results:
        by_dim.setdefault(r["D"], []).append(r)

    for d in sorted(by_dim):
        entries = by_dim[d]
        v_range = f"V={min(e['V'] for e in entries)}-{max(e['V'] for e in entries)}"
        cv_range = f"CV={min(e['CV'] for e in entries):.4f}-{max(e['CV'] for e in entries):.4f}"
        print(f"  D={d:4d}: {len(entries):3d} configs  {v_range:20s}  {cv_range}")

    # -- Band boundaries --
    print()
    print("=" * 70)
    print("Band boundaries (CV at each D, averaged across all V)")
    print("=" * 70)
    by_dim_all = {}
    for r in all_results:
        if r["CV"] is not None:
            by_dim_all.setdefault(r["D"], []).append(r["CV"])

    for d in sorted(by_dim_all):
        cvs = by_dim_all[d]
        avg = sum(cvs) / len(cvs)
        mn, mx = min(cvs), max(cvs)
        marker = " <-- IN BAND" if BAND_LO < avg < BAND_HI else ""
        # Only the low-dimensional rows are printed to keep the report short.
        if d <= 256:
            print(f"  D={d:4d}: avg={avg:.4f}  min={mn:.4f}  max={mx:.4f}{marker}")

    # -- Ratios --
    print()
    print("=" * 70)
    print("Unique V/D ratios for band-valid configs:")
    print("=" * 70)
    ratios = sorted(set(round(r["V"] / r["D"], 2) for r in band_results))
    print(f"  Count: {len(ratios)}")
    # With no band-valid configs there is no min/max to report; indexing an
    # empty list here would crash the script at the very end of the sweep.
    if ratios:
        print(f"  Min ratio: {ratios[0]}")
        print(f"  Max ratio: {ratios[-1]}")

    print()
    print(f"Results saved to cm_cv_sweep_8x8.json")
    print(f"Total time: {time.time() - t0:.1f}s")