Spaces:
Sleeping
Sleeping
Fix OOM: batch sampling in micro-batches of 4
Browse files
sample.py
CHANGED
|
@@ -168,8 +168,18 @@ def main():
|
|
| 168 |
key = jax.random.PRNGKey(args.seed)
|
| 169 |
print(f"Sampling {args.num_samples} configurations "
|
| 170 |
f"(compiling on first call) …")
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
grids = tokens_to_grids(tokens, lattice_size) # (N, L, L), values {-1, +1}
|
| 175 |
print(f" shape: {grids.shape} dtype: {grids.dtype}")
|
|
|
|
| 168 |
key = jax.random.PRNGKey(args.seed)
|
| 169 |
print(f"Sampling {args.num_samples} configurations "
|
| 170 |
f"(compiling on first call) …")
|
| 171 |
+
|
| 172 |
+
# Sample in fixed-size micro-batches to avoid OOM on GPU.
|
| 173 |
+
# Changing _MICRO_BATCH triggers recompilation, so keep it constant.
|
| 174 |
+
_MICRO_BATCH = 4
|
| 175 |
+
all_tokens = []
|
| 176 |
+
n_calls = -(-args.num_samples // _MICRO_BATCH) # ceiling division
|
| 177 |
+
for i in range(n_calls):
|
| 178 |
+
key, subkey = jax.random.split(key)
|
| 179 |
+
batch = sample_batch(model, _MICRO_BATCH, subkey)
|
| 180 |
+
all_tokens.append(np.asarray(batch))
|
| 181 |
+
print(f" batch {i+1}/{n_calls} done")
|
| 182 |
+
tokens = np.concatenate(all_tokens)[:args.num_samples]
|
| 183 |
|
| 184 |
grids = tokens_to_grids(tokens, lattice_size) # (N, L, L), values {-1, +1}
|
| 185 |
print(f" shape: {grids.shape} dtype: {grids.dtype}")
|