# Source: ising-transformer / sample.py
# Author: bertran-yorro — commit acfb50c (verified)
# Commit message: "Fix OOM: batch sampling in micro-batches of 4"
#!/usr/bin/env python
# /// script
# dependencies = [
# "jax[cuda12]",
# "equinox",
# "matplotlib",
# "jaxtyping",
# ]
# ///
"""Sample spin configurations from a trained Ising Generator checkpoint.
Usage:
python sample.py --checkpoint model.eqx [--num-samples 16]
[--output samples.npy] [--plot] [--seed 0]
How autoregressive sampling works
----------------------------------
The model is trained with a causal (lower-triangular) attention mask, so at
position t the output logits[t] are a function of spins s_0 … s_t only.
We exploit this to sample the full sequence one spin at a time:
1. Sample s_0 uniformly (the model has no BOS token).
2. For t = 0, 1, …, L²-2:
a. Run the full forward pass on the current token buffer.
Spins at positions > t are still placeholder zeros, but causal
masking prevents the network from attending to them.
b. Draw s_{t+1} ~ Categorical(softmax(logits[t])).
c. Write s_{t+1} into the buffer.
This is O(L⁴) in compute (L² steps × L² attention), which is 1 B ops for a
32×32 lattice. `jax.lax.scan` compiles the loop body once so subsequent
calls are fast.
"""
import argparse
import functools
from pathlib import Path
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Int
from model import Generator, gen_config, snake_order
# ---------------------------------------------------------------------------
# Checkpoint I/O
# ---------------------------------------------------------------------------
def load_checkpoint(
    path: Path,
    config: dict = gen_config,
    key: jax.Array | None = None,
) -> Generator:
    """Load a serialised Generator checkpoint from *path*.

    A throwaway model is first built from *config* purely to obtain a PyTree
    with the correct structure; every one of its leaves is then replaced by
    the weights stored on disk.  Since the skeleton's random weights are
    discarded, the exact *key* used to build it has no effect on the result —
    any key (including the default) works.
    """
    rng = jax.random.PRNGKey(0) if key is None else key
    skeleton = Generator(config=config, key=rng)
    return eqx.tree_deserialise_leaves(path, skeleton)
# ---------------------------------------------------------------------------
# Sampling
# ---------------------------------------------------------------------------
def sample_sequence(
    model: Generator,
    key: jax.random.PRNGKey,
) -> Int[Array, " seq_len"]:
    """Draw one spin sequence from *model*, autoregressively, in snake order.

    Fully JAX-traceable: safe inside ``jax.vmap`` or ``jax.lax.scan``.
    JIT-compile it (or go through ``sample_batch``) for best performance —
    the first call pays the compilation cost.
    """
    side = model.encoder.embedder_block.lattice_size
    n_tokens = side * side

    def body(state, pos):
        buf, rng = state
        rng, draw_rng = jax.random.split(rng)
        # Full forward pass over the buffer.  Positions > pos still hold
        # placeholder zeros, but the causal mask stops logits[pos] from
        # attending to them, so the distribution for s_{pos+1} is valid.
        logits = model({"token_ids": buf}, enable_dropout=False, key=None)
        drawn = jax.random.categorical(draw_rng, logits[pos])
        return (buf.at[pos + 1].set(drawn), rng), None

    key, init_rng = jax.random.split(key)
    # The model has no BOS token, so the first spin is sampled uniformly.
    first = jax.random.randint(init_rng, shape=(), minval=0, maxval=2)
    buf0 = jnp.zeros(n_tokens, dtype=jnp.int32).at[0].set(first)
    (final_buf, _), _ = jax.lax.scan(body, (buf0, key), jnp.arange(n_tokens - 1))
    return final_buf
@eqx.filter_jit
def sample_batch(
    model: Generator,
    num_samples: int,
    key: jax.random.PRNGKey,
) -> Int[Array, "num_samples seq_len"]:
    """Draw *num_samples* independent configurations at once.

    ``vmap`` runs one ``sample_sequence`` per key; ``filter_jit`` compiles
    the whole thing on the first call, and subsequent calls with the same
    ``num_samples`` reuse the cached executable.
    """
    per_sample_keys = jax.random.split(key, num_samples)
    batched_sampler = jax.vmap(sample_sequence, in_axes=(None, 0))
    return batched_sampler(model, per_sample_keys)
# ---------------------------------------------------------------------------
# Grid conversion
# ---------------------------------------------------------------------------
def tokens_to_grid(
    tokens: np.ndarray | Array,
    lattice_size: int,
) -> np.ndarray:
    """Map a snake-ordered {0, 1} token sequence onto a {-1, +1} L×L grid."""
    seq = np.asarray(tokens)
    spins = (seq * 2 - 1).astype(np.int8)  # {0, 1} → {-1, +1}
    rr, cc = snake_order(lattice_size)
    out = np.empty((lattice_size, lattice_size), dtype=np.int8)
    out[rr, cc] = spins
    return out
def tokens_to_grids(
    tokens: np.ndarray | Array,
    lattice_size: int,
) -> np.ndarray:
    """Batch version of ``tokens_to_grid``.

    Parameters
    ----------
    tokens : (N, L²) array with values in {0, 1}, in snake order.
    lattice_size : lattice side length L.

    Returns
    -------
    (N, L, L) int8 array with values in {-1, +1}.

    The scatter below places every sample in one vectorized NumPy operation
    instead of a Python loop over samples, and — unlike ``np.stack`` on an
    empty list — an empty batch returns an empty (0, L, L) array rather than
    raising.
    """
    tokens = np.asarray(tokens)
    rows, cols = snake_order(lattice_size)
    grids = np.empty((len(tokens), lattice_size, lattice_size), dtype=np.int8)
    # Fancy-index scatter: the (L²,) row/col indices broadcast over the batch
    # axis, exactly matching the per-sample assignment in tokens_to_grid.
    grids[:, rows, cols] = (tokens * 2 - 1).astype(np.int8)
    return grids
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def parse_args():
    """Build and evaluate the command-line interface for this script."""
    parser = argparse.ArgumentParser(
        description="Sample from a trained Ising Generator.",
    )
    parser.add_argument(
        "--checkpoint", type=Path, required=True,
        help="Path to the .eqx checkpoint file.",
    )
    parser.add_argument("--num-samples", type=int, default=16)
    parser.add_argument(
        "--output", type=Path, default=None,
        help="Save sampled {-1,+1} grids as a .npy file (N, L, L).",
    )
    parser.add_argument(
        "--plot", action="store_true",
        help="Display a grid of sampled configurations with matplotlib.",
    )
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args()
def main():
    """CLI entry point: restore the model, sample in micro-batches, report
    magnetization statistics, then optionally save and/or plot the grids."""
    args = parse_args()

    print(f"Loading checkpoint from {args.checkpoint} …")
    model = load_checkpoint(args.checkpoint)
    lattice_size = model.encoder.embedder_block.lattice_size
    print(f" lattice_size={lattice_size}, seq_len={lattice_size**2}")

    key = jax.random.PRNGKey(args.seed)
    print(f"Sampling {args.num_samples} configurations "
          f"(compiling on first call) …")

    # Fixed-size micro-batches keep GPU memory bounded.  The size must stay
    # constant: a different value would make sample_batch recompile.
    _MICRO_BATCH = 4
    n_calls = (args.num_samples + _MICRO_BATCH - 1) // _MICRO_BATCH  # ceil div
    chunks = []
    for i in range(n_calls):
        key, subkey = jax.random.split(key)
        chunks.append(np.asarray(sample_batch(model, _MICRO_BATCH, subkey)))
        print(f" batch {i+1}/{n_calls} done")
    # The last micro-batch may overshoot; trim to the requested count.
    tokens = np.concatenate(chunks)[:args.num_samples]

    grids = tokens_to_grids(tokens, lattice_size)  # (N, L, L), values {-1, +1}
    print(f" shape: {grids.shape} dtype: {grids.dtype}")
    print(f" mean magnetization : {grids.mean():.4f}")
    print(f" mean |magnetization|: {np.abs(grids.mean(axis=(1, 2))).mean():.4f}")

    if args.output is not None:
        np.save(args.output, grids)
        print(f"Saved → {args.output}")

    if args.plot:
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("matplotlib not available; skipping plot (pip install matplotlib).")
            return
        n_cols = min(8, args.num_samples)
        n_rows = (args.num_samples + n_cols - 1) // n_cols
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 1.5, n_rows * 1.5))
        for idx, ax in enumerate(np.array(axes).reshape(-1)):
            if idx < len(grids):
                ax.imshow(grids[idx], cmap="gray", vmin=-1, vmax=1,
                          interpolation="nearest")
            ax.axis("off")
        fig.suptitle(
            f"Sampled Ising configurations "
            f"(L={lattice_size}, n={args.num_samples})",
            fontsize=10,
        )
        plt.tight_layout()
        plt.show()
if __name__ == "__main__":
main()