bertran-yorro commited on
Commit
acfb50c
·
verified ·
1 Parent(s): 3b17899

Fix OOM: batch sampling in micro-batches of 4

Browse files
Files changed (1) hide show
  1. sample.py +12 -2
sample.py CHANGED
@@ -168,8 +168,18 @@ def main():
168
  key = jax.random.PRNGKey(args.seed)
169
  print(f"Sampling {args.num_samples} configurations "
170
  f"(compiling on first call) …")
171
- tokens = sample_batch(model, args.num_samples, key)
172
- tokens = np.asarray(tokens)
 
 
 
 
 
 
 
 
 
 
173
 
174
  grids = tokens_to_grids(tokens, lattice_size) # (N, L, L), values {-1, +1}
175
  print(f" shape: {grids.shape} dtype: {grids.dtype}")
 
168
  key = jax.random.PRNGKey(args.seed)
169
  print(f"Sampling {args.num_samples} configurations "
170
  f"(compiling on first call) …")
171
+
172
+ # Sample in fixed-size micro-batches to avoid OOM on GPU.
173
+ # Changing _MICRO_BATCH triggers recompilation, so keep it constant.
174
+ _MICRO_BATCH = 4
175
+ all_tokens = []
176
+ n_calls = -(-args.num_samples // _MICRO_BATCH) # ceiling division
177
+ for i in range(n_calls):
178
+ key, subkey = jax.random.split(key)
179
+ batch = sample_batch(model, _MICRO_BATCH, subkey)
180
+ all_tokens.append(np.asarray(batch))
181
+ print(f" batch {i+1}/{n_calls} done")
182
+ tokens = np.concatenate(all_tokens)[:args.num_samples]
183
 
184
  grids = tokens_to_grids(tokens, lattice_size) # (N, L, L), values {-1, +1}
185
  print(f" shape: {grids.shape} dtype: {grids.dtype}")